#!/usr/bin/python3
# Copyright 2023 Bas Wijnen <wijnen@debian.org> {{{
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# }}}

import lua

def print_dict(d):
	'''Return a dict-like string for d with a reproducible entry order.

	Entries are ordered by the str() of their key, so keys of mixed types
	(for example ints and strs) can be ordered without a TypeError.  Keys
	are shown with repr(), values with str().
	'''
	entries = []
	for key in sorted(d.keys(), key = str):
		entries.append('%r: %s' % (key, d[key]))
	return '{%s}' % ', '.join(entries)

# First interpreter instance; it is reused by all the tests further down.
code = lua.Lua()

# Set the Lua variable "a" to the string 'x' through the first instance.
code.run(var = 'a', value = 'x')

# Read "a" back through a second, separately constructed instance.
# NOTE(review): the label suggests that "x" is the expected result, which
# would mean the variable is visible from code2 as well -- confirm against
# the lua module's documentation.
code2 = lua.Lua()
print('global "x"', code2.run(b'return a'))

# features: debug, loadlib, searchers, dofile, loadfile, os, io
# NOTE(review): none of the names in the list above is referenced by the
# code visible in this script; presumably they are optional features of
# lua.Lua that still lack a test -- confirm.

# In every print() below the first argument is a label; where it contains a
# value, that appears to be the expected result of the call next to it.

# run_file: execute the script file foo.lua and show the returned table,
# converted with its dict() method.
print('file {a: 3, b: "bar"}', print_dict(code.run_file('foo.lua').dict()))

# run: the script is passed as a str (the other calls pass bytes).
print('1 + 2 =', code.run('return 1 + 2'))

# run: only set a variable (no script), then read it back in a second call.
code.run(var = 'x', value = 15)
print('x = 15:', code.run(b'return x'))

# run: set a variable and run a script in the same call.  The label implies
# that the script already sees the new value.
print('new x = "foo":', code.run(b'return x', var = 'x', value = 'foo'))

# run: the value is a Python str; the script calls the string method byte()
# on it.  48 is the character code of "0".
print('use str = "48":', code.run(b'return x:byte(1)', var = 'x', value = '012'))

# run: the value is a Python bytes object; the script indexes it with [1].
# 49 is the character code of "1", which is what b'012'[1] gives in Python.
print('use bytes = "49":', code.run(b'return x[1]', var = 'x', value = b'012'))

# module: register a Python dict under the name 'baz', then load it from Lua
# with require and read one of its entries.  Note that the script stores the
# result of require in the Lua variable "baz".
module = {
	'foo': 12,
	'bar': 42
}
code.module('baz', module)
print('module foo = 12:', code.run(b'baz = require "baz" return baz.foo'))

# metatable:
class Obj:
	'''Wrapper around one Python value, used to exercise operators from Lua.

	Every special method forwards the operation to the wrapped value and
	returns the plain result; results are never wrapped in a new Obj.
	Operands of the bitwise operators and list indices are passed through
	int() before use.
	'''

	def __init__(self, v):
		# The wrapped value; all other methods operate on it.
		self.value = v

	# Arithmetic operators.
	def __add__(self, other):
		return self.value + other

	def __sub__(self, other):
		return self.value - other

	def __mul__(self, other):
		return self.value * other

	def __truediv__(self, other):
		return self.value / other

	def __floordiv__(self, other):
		return self.value // other

	def __mod__(self, other):
		return self.value % other

	def __pow__(self, other):
		return self.value ** other

	def __neg__(self):
		return -self.value

	# Bitwise operators; the right hand operand is converted to int first.
	def __and__(self, other):
		return self.value & int(other)

	def __or__(self, other):
		return self.value | int(other)

	def __xor__(self, other):
		return self.value ^ int(other)

	def __invert__(self):
		return ~self.value

	def __lshift__(self, other):
		return self.value << int(other)

	def __rshift__(self, other):
		return self.value >> int(other)

	# The matrix multiplication operator is abused as the concat operator.
	def __matmul__(self, other):
		return ' concat '.join((str(self.value), str(other)))

	def __len__(self):
		return len(self.value)

	# Comparisons; each one compares the wrapped value to the operand.
	def __eq__(self, other):
		return self.value == other

	def __ne__(self, other):
		return self.value != other

	def __lt__(self, other):
		return self.value < other

	def __le__(self, other):
		return self.value <= other

	def __gt__(self, other):
		return self.value > other

	def __ge__(self, other):
		return self.value >= other

	# Item access.
	def __getitem__(self, other):
		# str and bytes keys are used unchanged; anything else becomes an int.
		key = other if isinstance(other, (str, bytes)) else int(other)
		return self.value[key]

	def __setitem__(self, other, newval):
		# Unlike __getitem__, the key is always converted to int here.
		index = int(other)
		self.value[index] = newval

	def __call__(self, *args):
		return self.value(*args)

	# Uncomment to see when objects are destroyed.
	#def __del__(self):
	#	print('killing', self.value)

	def __repr__(self):
		return repr(self.value)

	# Not a Python special method; used by the "close" test in this file.
	def __close__(self, arg):
		print('closing', self.value, arg)

# Operator tests.  Every line stores a fresh Obj in the Lua variable "obj",
# runs a one-line script that applies a single operator to it, and prints
# the result after a label.  The comment above each test names the operation
# being exercised; the value in the label appears to be the expected result.

# add
print('+ 8:', code.run(b'return obj + 5', var = 'obj', value = Obj(3)))
# sub
print('-2 -2:', code.run(b'return obj - 5', var = 'obj', value = Obj(3)))
# mul
print('* 15:', code.run(b'return obj * 5', var = 'obj', value = Obj(3)))
# div
print('/ .6:', code.run(b'return obj / 5', var = 'obj', value = Obj(3)))
# mod
print('% 3:', code.run(b'return obj % 5', var = 'obj', value = Obj(3)))
# pow (written as ^ in Lua)
print('** 243:', code.run(b'return obj ^ 5', var = 'obj', value = Obj(3)))
# unm (unary minus)
print('-1 -3:', code.run(b'return -obj', var = 'obj', value = Obj(3)))
# idiv
print('// 0:', code.run(b'return obj // 5', var = 'obj', value = Obj(3)))
# band
print('& 1:', code.run(b'return obj & 5', var = 'obj', value = Obj(3)))
# bor
print('| 7:', code.run(b'return obj | 5', var = 'obj', value = Obj(3)))
# bxor (binary ~ in Lua)
print('~ 6:', code.run(b'return obj ~ 5', var = 'obj', value = Obj(3)))
# bnot (unary ~ in Lua)
print('~ -4:', code.run(b'return ~obj', var = 'obj', value = Obj(3)))
# shl
print('<< 96:', code.run(b'return obj << 5', var = 'obj', value = Obj(3)))
# shr
print('>> 128:', code.run(b'return obj >> 1', var = 'obj', value = Obj(257)))
# concat
# NOTE(review): the label says "35", but Obj.__matmul__ (marked in the class
# as the concat operator) builds '3 concat 5' -- confirm which is expected.
print('.. "35"', code.run(b'return obj .. 5', var = 'obj', value = Obj(3)))
# len
print('# 4:', code.run(b'return #obj', var = 'obj', value = Obj([4, 6, 1, 2])))
# eq (also tested through its negation ~=)
print('== False:', code.run(b'return obj == 5', var = 'obj', value = Obj(3)))
print('~= True:', code.run(b'return obj ~= 5', var = 'obj', value = Obj(3)))
# lt (also tested with the operator reversed)
print('< True:', code.run(b'return obj < 5', var = 'obj', value = Obj(3)))
print('> False:', code.run(b'return obj > 5', var = 'obj', value = Obj(3)))
# le (also tested with the operator reversed)
print('<= True:', code.run(b'return obj <= 5', var = 'obj', value = Obj(3)))
print('>= False:', code.run(b'return obj >= 5', var = 'obj', value = Obj(3)))
# index, once with a numeric key on a list and once with a string key on a
# dict.  The expected 3 is the list element at Python index 5.
print('[int] 3:', code.run(b'return obj[5]', var = 'obj', value = Obj([4, 5, 2, 7, 9, 3])))
print('[str] 25:', code.run(b'return obj["cheese"]', var = 'obj', value = Obj({'milk': 12, 'cheese': 25, 'eggs': 3})))
# newindex; the expected list has the element at Python index 2 replaced.
print('[]= [4,7,42,2,33,10]', code.run(b'obj[2] = 42 return obj', var = 'obj', value = Obj([4, 7, 3, 2, 33, 10])))
# call; the wrapped value is print, so the call itself prints its arguments
# and returns None.
print('() None + 5,"foo":', code.run(b'return obj(5, "foo")', var = 'obj', value = Obj(print)))
# close; obj is bound to a local variable marked <close> inside a block, and
# Obj.__close__ prints a line starting with 'closing'.  The script returns 8.
print('close 8 and 12', code.run(b'if true then local x <close> = obj end return 8', var = 'obj', value = Obj(12)))
# tostring
print('str "abc"', code.run(b'return tostring(obj)', var = 'obj', value = Obj(b'abc')))

# Function object: a function returned by a Lua script is called from Python.
#	call; 4 + 5 + 2 = 11.
f = code.run(b'return function(a, b) return a + b + 2 end')
print('code call 11:', f(4, 5))

# Table object: a table returned by a Lua script is manipulated from Python.
# The statements below change t step by step, so their order matters.
t = code.run(b'return {4, 5, 6}')
#	len()
print('table len 3', len(t))
#	+= (extend the table with the items of a Python tuple)
t += (3, 4, 5)
#	[] (read an item)
# The expected 3 is the first appended item, so index 4 seems to be counted
# from 1 as in Lua.
print('table [] 3:', t[4])
#	[] = (assign an item)
t[2] = 'boo'
#	del (remove an item)
del t[4]
print('table {None,4,"boo",6,None,4,5}', print_dict(t.dict()))
# Fill the deleted slot again before the remaining tests.
t[4] = 0
#	in
# NOTE(review): the label reads as False, False, True, True; 'boo' is stored
# in t as a value at this point, so "in" presumably looks at keys -- confirm.
print('contain FFTT', 'boo' in t, 'foo' in t, 3 in t, 3.0 in t)
#	for (each item produced by iterating t is used as a subscript into t)
print('for', ['%s: %s' % (x, t[x]) for x in t])
#	==
#	!=
# u is a separate table that receives the same entries as t.  The label
# expects t == u to be False, so equality is presumably not decided by
# content -- confirm.
u = code.run(b'return {}')
for i in t:
	u[i] = t[i]
print('eq TFFT:', t == t, t == u, t != t, t != u)
# Ordering comparisons between tables are not exercised (test disabled):
#print('lte', t < t, t <= u, t >= t, t > u)
#	dict
print('dict', print_dict(t.dict()))
#	list
print('list', t.list())
#	pop, first without an argument and then with argument 1; the list is
#	printed after each call to show the effect.
t.pop()
print('pop', t.list())
t.pop(1)
print('pop', t.list())
#	del (drop the Python reference to the table)
del t
print("del'd")

# Python objects: scripts load the "python" module with require and use it to
# build Python objects from Lua values.  Each label shows the expected result.
#	list from a table
print('python list [1, 2, 3]', code.run(b'python = require("python") return python.list{1, 2, 3}'))
#	dict from a table with one positional and one named entry; the result
#	is passed to print_dict directly, without calling dict() on it.
print('python dict {1: "foo", "bar": 42}', print_dict(code.run(b'python = require("python") return python.dict{"foo", bar = 42}')))
#	bytes from a string, from a table of numbers and from a single number
print("python bytes from string = b'Hello!'", code.run(b'python = require("python") return python.bytes("Hello!")'))
print("python bytes from table = b'AB'", code.run(b'python = require("python") return python.bytes{65, 66}'))
print("python bytes from int = b'\\x00\\x00'", code.run(b'python = require("python") return python.bytes(2)'))

# Disabled test: require a module named "foo" and show the first result.
#print(code.run(b'return require "foo"')[0].dict())
