New serialization protocol to replace json

This commit is contained in:
neumond 2020-07-20 03:53:56 +03:00
parent db07997335
commit cb66ca9741
8 changed files with 458 additions and 78 deletions

View file

@ -11,47 +11,115 @@ ws = http.websocket(url..'ws/')
if ws == false then
error('unable to connect to server '..url..'ws/')
end
ws.send(textutils.serializeJSON{
local serialize
do
    -- Recursively encodes `v` into the wire format:
    --   nil -> 'N', false -> 'F', true -> 'T'
    --   number -> '[' .. tostring(n) .. ']'
    --   string -> '<' .. byte length .. '>' .. raw bytes
    --   table  -> '{' .. (':' .. key .. value for every pair) .. '}'
    -- `active` holds the tables on the current descent path; meeting one of
    -- them again means the structure is cyclic.
    local function s_rec(v, active)
        local t = type(v)
        if v == nil then
            return 'N'
        elseif v == false then
            return 'F'
        elseif v == true then
            return 'T'
        elseif t == 'number' then
            -- Plain brackets: '\[' is only tolerated by Lua 5.1 lexers.
            return '[' .. tostring(v) .. ']'
        elseif t == 'string' then
            return string.format('<%d>', #v) .. v
        elseif t == 'table' then
            if active[v] ~= nil then
                error('Cannot serialize table with recursive entries', 0)
            end
            active[v] = true
            -- Collect pieces and join once instead of re-copying the
            -- accumulated string for every entry.
            local parts = {'{'}
            for k, x in pairs(v) do
                parts[#parts + 1] = ':'
                parts[#parts + 1] = s_rec(k, active)
                parts[#parts + 1] = s_rec(x, active)
            end
            -- Leave the path so the same table may legitimately appear
            -- again elsewhere (shared, non-cyclic reference).
            active[v] = nil
            parts[#parts + 1] = '}'
            return table.concat(parts)
        else
            error('Cannot serialize type ' .. t, 0)
        end
    end

    -- Encodes `v` as a string; raises (level 0) on cycles and on values of
    -- unsupported types (functions, coroutines, userdata).
    serialize = function(v) return s_rec(v, {}) end
end
local deserialize
do
    -- Decodes one value from `s` starting at byte `idx`.
    -- Returns the value and the index of the first byte after it.
    -- Raises (level 0) on malformed or truncated input.
    local function d_rec(s, idx)
        local tok = s:sub(idx, idx)
        idx = idx + 1
        if tok == 'N' then
            return nil, idx
        elseif tok == 'F' then
            return false, idx
        elseif tok == 'T' then
            return true, idx
        elseif tok == '[' then
            local close = s:find(']', idx, true)
            if close == nil then
                error('Unterminated number', 0)
            end
            local text = s:sub(idx, close - 1)
            local num = tonumber(text)
            if num == nil then
                -- tonumber() rejects the inf/nan spellings on Lua 5.2+,
                -- so map them by hand before giving up.
                if text == 'inf' then
                    num = math.huge
                elseif text == '-inf' then
                    num = -math.huge
                elseif text == 'nan' or text == '-nan' then
                    num = 0 / 0
                else
                    error('Malformed number ' .. text, 0)
                end
            end
            return num, close + 1
        elseif tok == '<' then
            local close = s:find('>', idx, true)
            if close == nil then
                error('Unterminated string length', 0)
            end
            local slen = tonumber(s:sub(idx, close - 1))
            if slen == nil or slen < 0 or slen % 1 ~= 0 then
                error('Malformed string length', 0)
            end
            -- The payload occupies bytes close+1 .. close+slen.
            if close + slen > #s then
                error('Truncated string', 0)
            end
            return s:sub(close + 1, close + slen), close + slen + 1
        elseif tok == '{' then
            local r = {}
            while true do
                tok = s:sub(idx, idx)
                idx = idx + 1
                if tok == '}' then break end
                if tok ~= ':' then
                    error('Expected ":" or "}" but got "' .. tok .. '"', 0)
                end
                local key, value
                key, idx = d_rec(s, idx)
                value, idx = d_rec(s, idx)
                if key == nil then
                    error('Table key cannot be nil', 0)
                end
                r[key] = value
            end
            return r, idx
        else
            error('Unknown token ' .. tok, 0)
        end
    end

    -- Decodes the first value found in `s`; anything following it is ignored.
    deserialize = function(s)
        local r = d_rec(s, 1)
        return r
    end
end
-- Encodes `data` with `serialize` and hands the resulting string to `ws.send`.
-- The second argument `true` is forwarded to ws.send as-is; presumably it
-- requests a binary websocket frame -- TODO confirm against the
-- ComputerCraft websocket API.
function ws_send(data)
ws.send(serialize(data), true)
end
ws_send{
action='run',
computer=os.getComputerID(),
args={...},
})
function nullify_array(a, size)
local r = {}
for k=1,size do
if a[k] == nil then
r[k] = textutils.json_null
else
r[k] = a[k]
end
end
return r
end
}
while true do
local event, p1, p2, p3, p4, p5 = os.pullEvent()
if event == 'websocket_message' then
msg = textutils.unserializeJSON(p2)
msg = deserialize(p2)
if msg.action == 'task' then
local fn, err = loadstring(msg.code)
if fn == nil then
ws.send(textutils.serializeJSON{
ws_send{
action='task_result',
task_id=msg.task_id,
result={false, err},
yields=0,
})
}
else
setfenv(fn, genv)
if msg.immediate then
ws.send(textutils.serializeJSON{
ws_send{
action='task_result',
task_id=msg.task_id,
result={fn()},
yields=0,
})
}
else
tasks[msg.task_id] = coroutine.create(fn)
ycounts[msg.task_id] = 0
@ -74,11 +142,11 @@ while true do
break
end
elseif event_sub[event] == true then
ws.send(textutils.serializeJSON{
ws_send{
action='event',
event=event,
params=nullify_array({p1, p2, p3, p4, p5}, 5),
})
params={p1, p2, p3, p4, p5},
}
end
local del_tasks = {}
@ -86,12 +154,12 @@ while true do
if filters[task_id] == nil or filters[task_id] == event then
local r = {coroutine.resume(tasks[task_id], event, p1, p2, p3, p4, p5)}
if coroutine.status(tasks[task_id]) == 'dead' then
ws.send(textutils.serializeJSON{
ws_send{
action='task_result',
task_id=task_id,
result=r,
yields=ycounts[task_id],
})
}
del_tasks[task_id] = true
else
if r[1] == true then

View file

@ -18,21 +18,25 @@ class ArbLuaExpr(LuaExpr):
return self._code
# Escape table for Lua string literals. str.maketrans turns the
# single-character keys into the code points str.translate looks up.
_tmap = str.maketrans({
    '\\': '\\\\',
    '\a': '\\a',
    '\b': '\\b',
    '\f': '\\f',
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
    '\v': '\\v',
    '"': '\\"',
    "'": "\\'",
    '[': '\\[',
    ']': '\\]',
})


def lua_string(v):
    """Return ``v`` as a double-quoted Lua string literal.

    Every character listed in ``_tmap`` is replaced by its backslash escape
    in a single pass, so backslashes introduced by one escape are never
    escaped a second time.
    """
    return '"{}"'.format(v.translate(_tmap))
def lua_list(v):

82
computercraft/ser.py Normal file
View file

@ -0,0 +1,82 @@
from typing import Any, Tuple
__all__ = (
'serialize',
'deserialize',
)
# latin1 maps code points 0..255 one-to-one onto bytes 0..255, so arbitrary
# byte strings coming from the Lua side survive a decode/encode round trip.
_ENC = 'latin1'
# encoding fast check
assert [bytes([i]) for i in range(256)] == [chr(i).encode(_ENC) for i in range(256)]


def _serialize_items(pairs) -> bytes:
    """Encode an iterable of ``(key, value)`` pairs as a table body."""
    body = b''.join(b':' + serialize(k) + serialize(x) for k, x in pairs)
    return b'{' + body + b'}'


def serialize(v: Any) -> bytes:
    """Encode ``v`` into the wire format.

    * ``None`` / ``False`` / ``True`` -> ``N`` / ``F`` / ``T``
    * ``int`` / ``float`` -> ``[<text>]``
    * ``str`` -> ``<<byte length>>`` followed by the latin1 bytes
    * ``list`` / ``tuple`` -> table keyed ``1..n``
    * ``dict`` -> table with the dict's own keys

    Raises ``ValueError`` for unsupported types and ``UnicodeEncodeError``
    for strings with characters outside latin1.
    """
    # Identity checks come first: bool is a subclass of int and must not be
    # emitted as a number.
    if v is None:
        return b'N'
    elif v is False:
        return b'F'
    elif v is True:
        return b'T'
    elif isinstance(v, (int, float)):
        return '[{}]'.format(v).encode(_ENC)
    elif isinstance(v, str):
        v = v.encode(_ENC)
        return '<{}>'.format(len(v)).encode(_ENC) + v
    elif isinstance(v, (list, tuple)):
        return _serialize_items(enumerate(v, start=1))
    elif isinstance(v, dict):
        return _serialize_items(v.items())
    else:
        raise ValueError('Cannot serialize type {}'.format(type(v).__name__))


def _deserialize(b: bytes, _idx: int) -> Tuple[Any, int]:
    """Decode one value from ``b`` starting at ``_idx``.

    Returns ``(value, index of the first byte after the value)``.
    Raises ``ValueError`` on malformed input and ``IndexError`` when the
    input ends where a token was expected.
    """
    tok = b[_idx]
    _idx += 1
    if tok == 78:  # N
        return None, _idx
    elif tok == 70:  # F
        return False, _idx
    elif tok == 84:  # T
        return True, _idx
    elif tok == 91:  # [
        newidx = b.index(b']', _idx)
        text = b[_idx:newidx].decode(_ENC)
        # Integer text is parsed exactly; going through float first would
        # round anything above 2**53.
        try:
            return int(text), newidx + 1
        except ValueError:
            pass
        f = float(text)
        if f.is_integer():
            f = int(f)
        return f, newidx + 1
    elif tok == 60:  # <
        newidx = b.index(b'>', _idx)
        ln = int(b[_idx:newidx].decode(_ENC))
        start = newidx + 1
        end = start + ln
        # Slicing would silently hand back a shorter string.
        if ln < 0 or end > len(b):
            raise ValueError('Truncated or malformed string payload')
        return b[start:end].decode(_ENC), end
    elif tok == 123:  # {
        r = {}
        while True:
            tok = b[_idx]
            _idx += 1
            if tok == 125:  # }
                break
            if tok != 58:  # :
                raise ValueError('Expected ":" or "}" inside table')
            key, _idx = _deserialize(b, _idx)
            value, _idx = _deserialize(b, _idx)
            r[key] = value
        # A non-empty table whose keys are exactly 1..n becomes a list.
        # The type check keeps a True key (which equals 1) from qualifying.
        if (
            r
            and all(type(k) is int for k in r)
            and all(i in r for i in range(1, len(r) + 1))
        ):
            r = [r[i + 1] for i in range(len(r))]
        return r, _idx
    else:
        raise ValueError('Unknown token {!r}'.format(bytes([tok])))


def deserialize(b: bytes) -> Any:
    """Decode the first value in ``b``; bytes following it are ignored."""
    return _deserialize(b, 0)[0]

View file

@ -1,11 +1,12 @@
import argparse
import asyncio
import json
import sys
from os.path import join, dirname, abspath
from aiohttp import web, WSMsgType
from .sess import CCSession
from . import ser
THIS_DIR = dirname(abspath(__file__))
@ -14,25 +15,26 @@ LUA_FILE = join(THIS_DIR, 'back.lua')
class CCApplication(web.Application):
@staticmethod
async def _json_messages(ws):
async def _bin_messages(ws):
async for msg in ws:
# print('ws received', msg)
if msg.type != WSMsgType.TEXT:
if msg.type != WSMsgType.BINARY:
continue
# print('ws received', msg.data)
yield json.loads(msg.data.replace('\\\n', '\\n'))
sys.__stdout__.write('ws received ' + repr(msg.data) + '\n')
yield msg.data
async def _launch_program(self, ws):
async for msg in self._json_messages(ws):
async for msg in self._bin_messages(ws):
msg = ser.deserialize(msg)
if msg['action'] != 'run':
await ws.send_json({
await ws.send_bytes(ser.serialize({
'action': 'close',
'error': 'protocol error',
})
}))
return None
def sender(data):
asyncio.create_task(ws.send_json(data))
sys.__stdout__.write('ws send ' + repr(data) + '\n')
asyncio.create_task(ws.send_bytes(data))
sess = CCSession(msg['computer'], sender)
if msg['args']:
@ -47,16 +49,17 @@ class CCApplication(web.Application):
sess = await self._launch_program(ws)
if sess is not None:
async for msg in self._json_messages(ws):
async for msg in self._bin_messages(ws):
msg = ser.deserialize(msg)
if msg['action'] == 'event':
sess.on_event(msg['event'], msg['params'])
elif msg['action'] == 'task_result':
sess.on_task_result(msg['task_id'], msg['result'])
else:
await ws.send_json({
await ws.send_bytes(ser.serialize({
'action': 'close',
'error': 'protocol error',
})
}))
break
return ws

View file

@ -15,7 +15,7 @@ from types import ModuleType
from greenlet import greenlet, getcurrent as get_current_greenlet
from .lua import lua_string, lua_call, return_lua_call
from . import rproc
from . import rproc, ser
__all__ = (
@ -141,10 +141,11 @@ sys.stderr = StdFileProxy(sys.__stderr__, True)
def eval_lua(lua_code, immediate=False):
result = get_current_session()._server_greenlet.switch({
request = ser.serialize({
'code': lua_code,
'immediate': immediate,
})
result = get_current_session()._server_greenlet.switch(request)
# debug('{} → {}'.format(lua_code, repr(result)))
if not immediate:
result = rproc.coro(result)
@ -204,7 +205,7 @@ class CCGreenlet:
if error is not None:
if error is True:
error = {}
self._sess._sender({'action': 'close', **error})
self._sess._sender(ser.serialize({'action': 'close', **error}))
if self._parent is not None:
self._parent._children.discard(self._task_id)
@ -225,16 +226,14 @@ class CCGreenlet:
return
# lua_eval call or simply idle
if isinstance(task, dict):
if isinstance(task, bytes):
x = self
while x._g.dead:
x = x._parent
tid = x._task_id
self._sess._sender({
'action': 'task',
'task_id': tid,
**task,
})
assert task[-1] == 125 # }
tid = ser.serialize(x._task_id)
task = task[:-1] + b':<6>action<4>task:<7>task_id' + tid + b'}'
self._sess._sender(task)
if self._g.dead:
if self._parent is None:
@ -303,8 +302,8 @@ class CCSession:
self._server_greenlet = get_current_greenlet()
self._program_greenlet = None
self._evr = CCEventRouter(
lambda event: self._sender({'action': 'sub', 'event': event}),
lambda event: self._sender({'action': 'unsub', 'event': event}),
lambda event: self._sender(ser.serialize({'action': 'sub', 'event': event})),
lambda event: self._sender(ser.serialize({'action': 'unsub', 'event': event})),
lambda task_id: self._greenlets[task_id].defer_switch('event'),
)
@ -332,10 +331,10 @@ class CCSession:
for task_id in task_ids:
all_tids.extend(collect(task_id))
self._sender({
self._sender(ser.serialize({
'action': 'drop',
'task_ids': all_tids,
})
}))
def _run_sandboxed_greenlet(self, fn):
self._program_greenlet = CCGreenlet(fn, sess=self)

View file

@ -1,12 +1,12 @@
import builtins
from contextlib import contextmanager
from typing import Optional, List, Union
from typing import Optional, List
from .base import BaseSubAPI
from ..errors import LuaException
from ..lua import lua_call
from ..rproc import boolean, string, integer, nil, array_string, option_string, option_integer, fact_scheme_dict
from ..sess import eval_lua_method_factory, lua_context_object
from ..lua import lua_call, lua_args, lua_string
from ..rproc import boolean, string, integer, nil, array_string, option_string, fact_scheme_dict
from ..sess import eval_lua, eval_lua_method_factory, lua_context_object
attribute = fact_scheme_dict({
@ -28,17 +28,28 @@ class SeekMixin:
class ReadHandle(BaseSubAPI):
# TODO: binary handle must return bytes instead string
def _decode(self, b):
return b.decode('utf-8')
def read(self, count: int = None) -> Optional[Union[str, int]]:
r = self._method('read', count)
return option_integer(r) if count is None else option_string(r)
def _read(self, name, params, val):
code = '''
local s = {}.{}({})
if s == nil then return nil end
s = s:gsub('.', function(c) return string.format('%02X', string.byte(c)) end)
return s
'''.lstrip().format(
self.get_expr_code(), name, lua_args(*params),
)
return self._decode(bytes.fromhex(val(eval_lua(code))))
def read(self, count: int = 1) -> Optional[str]:
return self._read('read', (count, ), option_string)
def readLine(self) -> Optional[str]:
return option_string(self._method('readLine'))
return self._read('readLine', (), option_string)
def readAll(self) -> str:
return string(self._method('readAll'))
return self._read('readAll', (), string)
def __iter__(self):
return self
@ -51,10 +62,25 @@ class ReadHandle(BaseSubAPI):
class BinaryReadHandle(ReadHandle, SeekMixin):
pass
def _decode(self, b):
return b
class WriteHandle(BaseSubAPI):
def _encode(self, s):
return s.encode('utf-8')
def _write(self, name, text, val):
code = '''
local s = {}
s = s:gsub('..', function(cc) return string.char(tonumber(cc, 16)) end)
return {}.{}(s)
'''.lstrip().format(
lua_string(self._encode(text).hex()),
self.get_expr_code(), name,
)
return val(eval_lua(code))
def write(self, text: str):
return nil(self._method('write', text))
@ -66,7 +92,8 @@ class WriteHandle(BaseSubAPI):
class BinaryWriteHandle(WriteHandle, SeekMixin):
pass
def _encode(self, s):
return s
method = eval_lua_method_factory('fs.')
@ -161,7 +188,7 @@ def open(path: str, mode: str):
...
'''
with lua_context_object(
lua_call('fs.open', path, mode),
lua_call('fs.open', path, mode.replace('b', '') + 'b'),
'{e}.close()',
) as var:
if 'b' in mode:

148
tests/serialization.lua Normal file
View file

@ -0,0 +1,148 @@
local serialize
do
    -- Recursively encodes `v` into the wire format:
    --   nil -> 'N', false -> 'F', true -> 'T'
    --   number -> '[' .. tostring(n) .. ']'
    --   string -> '<' .. byte length .. '>' .. raw bytes
    --   table  -> '{' .. (':' .. key .. value for every pair) .. '}'
    -- `active` holds the tables on the current descent path; meeting one of
    -- them again means the structure is cyclic.
    local function s_rec(v, active)
        local t = type(v)
        if v == nil then
            return 'N'
        elseif v == false then
            return 'F'
        elseif v == true then
            return 'T'
        elseif t == 'number' then
            -- Plain brackets: '\[' is only tolerated by Lua 5.1 lexers.
            return '[' .. tostring(v) .. ']'
        elseif t == 'string' then
            return string.format('<%d>', #v) .. v
        elseif t == 'table' then
            if active[v] ~= nil then
                error('Cannot serialize table with recursive entries', 0)
            end
            active[v] = true
            -- Collect pieces and join once instead of re-copying the
            -- accumulated string for every entry.
            local parts = {'{'}
            for k, x in pairs(v) do
                parts[#parts + 1] = ':'
                parts[#parts + 1] = s_rec(k, active)
                parts[#parts + 1] = s_rec(x, active)
            end
            -- Leave the path so the same table may legitimately appear
            -- again elsewhere (shared, non-cyclic reference).
            active[v] = nil
            parts[#parts + 1] = '}'
            return table.concat(parts)
        else
            error('Cannot serialize type ' .. t, 0)
        end
    end

    -- Encodes `v` as a string; raises (level 0) on cycles and on values of
    -- unsupported types (functions, coroutines, userdata).
    serialize = function(v) return s_rec(v, {}) end
end
local deserialize
do
    -- Decodes one value from `s` starting at byte `idx`.
    -- Returns the value and the index of the first byte after it.
    -- Raises (level 0) on malformed or truncated input.
    local function d_rec(s, idx)
        local tok = s:sub(idx, idx)
        idx = idx + 1
        if tok == 'N' then
            return nil, idx
        elseif tok == 'F' then
            return false, idx
        elseif tok == 'T' then
            return true, idx
        elseif tok == '[' then
            local close = s:find(']', idx, true)
            if close == nil then
                error('Unterminated number', 0)
            end
            local text = s:sub(idx, close - 1)
            local num = tonumber(text)
            if num == nil then
                -- tonumber() rejects the inf/nan spellings on Lua 5.2+,
                -- so map them by hand before giving up.
                if text == 'inf' then
                    num = math.huge
                elseif text == '-inf' then
                    num = -math.huge
                elseif text == 'nan' or text == '-nan' then
                    num = 0 / 0
                else
                    error('Malformed number ' .. text, 0)
                end
            end
            return num, close + 1
        elseif tok == '<' then
            local close = s:find('>', idx, true)
            if close == nil then
                error('Unterminated string length', 0)
            end
            local slen = tonumber(s:sub(idx, close - 1))
            if slen == nil or slen < 0 or slen % 1 ~= 0 then
                error('Malformed string length', 0)
            end
            -- The payload occupies bytes close+1 .. close+slen.
            if close + slen > #s then
                error('Truncated string', 0)
            end
            return s:sub(close + 1, close + slen), close + slen + 1
        elseif tok == '{' then
            local r = {}
            while true do
                tok = s:sub(idx, idx)
                idx = idx + 1
                if tok == '}' then break end
                if tok ~= ':' then
                    error('Expected ":" or "}" but got "' .. tok .. '"', 0)
                end
                local key, value
                key, idx = d_rec(s, idx)
                value, idx = d_rec(s, idx)
                if key == nil then
                    error('Table key cannot be nil', 0)
                end
                r[key] = value
            end
            return r, idx
        else
            error('Unknown token ' .. tok, 0)
        end
    end

    -- Decodes the first value found in `s`; anything following it is ignored.
    deserialize = function(s)
        local r = d_rec(s, 1)
        return r
    end
end
-- nil cannot live inside the list below (it would end the sequence), so it
-- is checked on its own.
print(serialize(nil))
assert(deserialize(serialize(nil)) == nil)

-- Scalars that must come back equal after a round trip.
-- The infinities are spelled with math.huge: tonumber('inf') is nil on
-- Lua 5.2+, which would punch a hole in the list and make ipairs silently
-- skip every value after it.
local roundtrip_vals = {
    true,
    false,
    0,
    -1,
    1,
    1e6,
    1.5,
    2.4e-9,
    math.huge,
    -math.huge,
    '',
    'string',
    '\n\r\0',
    '\0',
    '2',
}

for _, v in ipairs(roundtrip_vals) do
    print(serialize(v))
    assert(v == deserialize(serialize(v)))
end

-- NaN never equals itself, so it gets a dedicated check. Testing x ~= x
-- avoids depending on how the platform prints NaN ('nan' vs '-nan').
local nan = 0 / 0
print(serialize(nan))
local nan_back = deserialize(serialize(nan))
assert(nan_back ~= nan_back)
-- Deep structural comparison of two tables.
-- Both top-level arguments must be tables (asserted). Nested values are
-- compared recursively when the value in `a` is a table, with == otherwise.
-- Returns true when both tables hold the same keys with equal values.
function areTablesEqual(a, b)
    assert(type(a) == 'table')
    assert(type(b) == 'table')
    for k, v in pairs(a) do
        local other = b[k]
        if type(v) == 'table' then
            -- A table on one side and a non-table on the other is simply a
            -- mismatch, not a usage error.
            if type(other) ~= 'table' or not areTablesEqual(v, other) then
                return false
            end
        elseif other ~= v then
            return false
        end
    end
    -- Every entry of `a` matched; `b` must not carry extra keys.
    for k in pairs(b) do
        if a[k] == nil then return false end
    end
    return true
end
-- Tables that must come back structurally identical after a round trip.
local roundtrip_tables = {
    {},
    {[2]=4},
    {a=1, b=true, c={}, d={x=8}},
    {1, 2, 3},
    {1},
    {'abc'},
    {[1]='a', [2]='b', [3]='c'},
    {'a', 'b', 'c'},
}

for idx = 1, #roundtrip_tables do
    local original = roundtrip_tables[idx]
    local encoded = serialize(original)
    print(encoded)
    assert(areTablesEqual(original, deserialize(encoded)))
end

print('ALL OK')

49
tests/serialization.py Normal file
View file

@ -0,0 +1,49 @@
from math import inf, nan, isnan
from computercraft.ser import serialize, deserialize
# Values that must compare equal to themselves after serialize -> deserialize.
# nan is left out here because nan != nan; it is checked separately below.
roundtrip_vals = [
    None,
    True,
    False,
    0,
    -1,
    1,
    1e6,
    1.5,
    2.4e-9,
    # nan,
    inf,
    -inf,
    '',
    'string',
    '\n\r\0',
    '\0',
    '2',
    {},
    {2: 4},
    {'a': 1, 'b': None, 'c': {}, 'd': {'x': 8}},
    [1, 2, 3],
    [1],
    ['abc'],
]

for original in roundtrip_vals:
    encoded = serialize(original)
    print(encoded)
    assert original == deserialize(encoded)

encoded_nan = serialize(nan)
print(encoded_nan)
assert isnan(deserialize(encoded_nan))

# Inputs whose decoded form differs from what was encoded: a dict keyed
# 1..n comes back as a list.
oneway_vals = [
    ({1: 'a', 2: 'b', 3: 'c'}, ['a', 'b', 'c']),
]

for source, expected in oneway_vals:
    encoded = serialize(source)
    print(encoded)
    assert expected == deserialize(encoded)