python-computer-craft/computercraft/server.py
2020-06-28 08:19:01 +03:00

250 lines
7.9 KiB
Python

import asyncio
import json
import string
from aiohttp import web, WSMsgType
from traceback import print_exc
from os.path import join, dirname, abspath
from importlib import import_module
import argparse
from .subapis.root import RootAPIMixin
from .errors import LuaException
from .subapis.colors import ColorsAPI
from .subapis.commands import CommandsAPI
from .subapis.disk import DiskAPI
from .subapis.fs import FSAPI
from .subapis.gps import GpsAPI
from .subapis.help import HelpAPI
from .subapis.keys import KeysAPI
from .subapis.multishell import MultishellAPI
from .subapis.os import OSAPI
from .subapis.peripheral import PeripheralAPI
from .subapis.rednet import RednetAPI
from .subapis.redstone import RedstoneAPI
from .subapis.settings import SettingsAPI
from .subapis.shell import ShellAPI
from .subapis.term import TermAPI
from .subapis.textutils import TextutilsAPI
from .subapis.turtle import TurtleAPI
from .subapis.window import WindowAPI
THIS_DIR = dirname(abspath(__file__))
LUA_FILE = join(THIS_DIR, 'back.lua')
DIGITS = string.digits + string.ascii_lowercase
def base36(n):
r = ''
while n:
r += DIGITS[n % 36]
n //= 36
return r[::-1]
class CCAPI(RootAPIMixin):
def __init__(self, nid, program, cleanup_callback):
self._id = nid
self._task_autoid = 1
self._cmd = asyncio.Queue(maxsize=1)
self._result_locks = {}
self._result_values = {}
self._result_queues = {}
self._event_to_tids = {}
self._tid_to_event = {}
self.colors = ColorsAPI(self)
self.commands = CommandsAPI(self)
self.disk = DiskAPI(self)
self.fs = FSAPI(self)
self.gps = GpsAPI(self)
self.help = HelpAPI(self)
self.keys = KeysAPI(self)
self.multishell = MultishellAPI(self)
self.os = OSAPI(self)
self.peripheral = PeripheralAPI(self)
self.rednet = RednetAPI(self)
self.redstone = RedstoneAPI(self)
self.settings = SettingsAPI(self)
self.shell = ShellAPI(self) # TODO: autocomplete functions
self.term = TermAPI(self) # TODO: window redirections
self.textutils = TextutilsAPI(self)
self.turtle = TurtleAPI(self)
self.window = WindowAPI(self) # TODO: unimplemented
async def prog_wrap():
err = None
try:
await program(self)
except asyncio.CancelledError:
print('program {} cancelled'.format(self._id))
print_exc()
err = 'program has been cancelled'
except Exception as e:
print('program {} crashed: {} {}'.format(self._id, type(e), e))
print_exc()
err = type(e).__name__ + ': ' + str(e)
else:
print('program {} finished'.format(self._id))
finally:
c = {'action': 'close'}
if err is not None:
c['error'] = err
await self._cmd.put(c)
cleanup_callback()
self._task = asyncio.create_task(prog_wrap())
def _new_task_id(self) -> str:
task_id = base36(self._task_autoid)
self._task_autoid += 1
return task_id
async def raw_eval(self, lua_code):
task_id = self._new_task_id()
self._result_locks[task_id] = asyncio.Event()
await self._cmd.put({
'action': 'task',
'task_id': task_id,
'code': lua_code,
})
await self._result_locks[task_id].wait()
del self._result_locks[task_id]
return self._result_values.pop(task_id)
async def _send_cmd(self, lua):
result = await self.raw_eval(lua)
print('{}{}'.format(lua, result))
if not result[0]:
raise LuaException(*result[1:])
return result[1:]
async def _start_queue(self, event):
task_id = self._new_task_id()
self._result_queues[task_id] = asyncio.Queue()
es = self._event_to_tids.setdefault(event, set())
if not es:
await self._cmd.put({
'action': 'sub',
'event': event,
})
es.add(task_id)
self._tid_to_event[task_id] = event
return self._result_queues[task_id], task_id
async def _stop_queue(self, task_id):
event = self._tid_to_event[task_id]
del self._result_queues[task_id]
del self._tid_to_event[task_id]
self._event_to_tids[event].discard(task_id)
if not self._event_to_tids[event]:
await self._cmd.put({
'action': 'unsub',
'event': event,
})
class CCApplication(web.Application):
@staticmethod
async def _sender(ws, api):
while not ws.closed:
cmd = await api._cmd.get()
# print(f'_sender: {cmd}')
if not ws.closed:
await ws.send_json(cmd)
if cmd['action'] == 'close':
break
@staticmethod
async def _json_messages(ws):
async for msg in ws:
# print('ws received', msg)
if msg.type != WSMsgType.TEXT:
continue
# print('ws received', msg.data)
yield json.loads(msg.data.replace('\\\n', '\\n'))
async def _launch_program(self, ws):
async for msg in self._json_messages(ws):
if msg['action'] != 'run':
await ws.send_json({
'action': 'close',
'error': 'protocol error',
})
return None
module = import_module(self['source_module'])
program = getattr(module, msg['args'][0], None)
if program is None:
await ws.send_json({
'action': 'close',
'error': "program doesn't exist",
})
return None
return CCAPI(msg['computer'], program, lambda: None)
async def ws(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
api = await self._launch_program(ws)
if api is not None:
asyncio.create_task(self._sender(ws, api))
async for msg in self._json_messages(ws):
if msg['action'] == 'event':
for task_id in api._event_to_tids.get(msg['event'], ()):
await api._result_queues[task_id].put(msg['params'])
elif msg['action'] == 'task_result':
api._result_values[msg['task_id']] = msg['result']
api._result_locks[msg['task_id']].set()
else:
await ws.send_json({
'action': 'close',
'error': 'protocol error',
})
break
return ws
@staticmethod
def backdoor(request):
with open(LUA_FILE, 'r') as f:
fcont = f.read()
h = request.host
if ':' not in h:
# fix for malformed Host header
h += ':{}'.format(request.app['port'])
fcont = fcont.replace(
"local url = 'http://127.0.0.1:4343/'",
"local url = '{}://{}/'".format('ws', h)
)
return web.Response(text=fcont)
def initialize(self, source_module):
self['source_module'] = source_module
self['exchange'] = {}
self.router.add_get('/', self.backdoor)
self.router.add_get('/ws/', self.ws)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('module', help='Module used as source for programs')
parser.add_argument('--host')
parser.add_argument('--port', type=int, default=8080)
args = parser.parse_args()
app_kw = {}
if args.host is not None:
app_kw['host'] = args.host
app_kw['port'] = args.port
app = CCApplication()
app['port'] = args.port
app.initialize(args.module)
web.run_app(app, **app_kw)
if __name__ == '__main__':
main()