354 lines
7.4 KiB
Python
354 lines
7.4 KiB
Python
import asyncio
import base64
import datetime
import json
import logging
import multiprocessing
import operator
import queue
import random

import filelock
import pysondb
from pysondb.db import JsonDatabase, IdNotFoundError

from . import misc
|
|
|
|
|
|
class Database(multiprocessing.Process):
    """Host a DatabaseProcess in a daemon child process reachable over localhost TCP.

    The constructor picks a port, creates an auth token and starts the child
    immediately. ``fetch``/``search``/``insert``/``remove`` forward their
    arguments to the child as one JSON message and return whatever
    ``misc.Connection`` receives back.
    """

    # Actions process_queue will dispatch to the child's DatabaseProcess.
    # Anything else is rejected so a client cannot invoke arbitrary attributes.
    ACTIONS = frozenset({'fetch', 'search', 'insert', 'remove', 'migrate'})

    # Upper bound (bytes) for a single request or response read.
    MAX_MESSAGE = 16 * 1024 * 1024

    def __init__(self, dbpath: misc.Path, tables: dict = None):
        super().__init__(daemon=True)
        self.dbpath = dbpath
        self.tables = tables
        self.shutdown = False
        # NOTE(review): the port is only probed here and bound later in run(),
        # so another process could take it in between.
        self.port = self._setup_port()
        self.token = misc.RandomGen()
        # Client helpers are real methods (below), not lambda attributes: lambdas
        # cannot be pickled, which breaks start() under the 'spawn' start method.
        self.start()

    def fetch(self, *args, **kwargs):
        """Forward to DatabaseProcess.fetch(table, ...) in the child process."""
        return self.send_message('fetch', *args, **kwargs)

    def search(self, *args, **kwargs):
        """Forward to DatabaseProcess.search(table, ...) in the child process."""
        return self.send_message('search', *args, **kwargs)

    def insert(self, *args, **kwargs):
        """Forward to DatabaseProcess.insert(table, ...) in the child process."""
        return self.send_message('insert', *args, **kwargs)

    def remove(self, *args, **kwargs):
        """Forward to DatabaseProcess.remove(table, ...) in the child process."""
        return self.send_message('remove', *args, **kwargs)

    def run(self):
        """Child-process entry point: open the database and serve requests."""
        self.db = DatabaseProcess(self.dbpath, self.tables)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # start_server is the stream API whose callback receives (reader, writer),
            # which is the signature process_queue implements.
            server = loop.run_until_complete(
                asyncio.start_server(self.process_queue, '127.0.0.1', self.port))
            # Returns only when process_queue stops the loop on a 'close' request.
            loop.run_forever()
            server.close()
        finally:
            loop.close()

    def close(self):
        """Terminate the child process and wait up to 5 seconds for it to exit."""
        self.terminate()
        self.join(timeout=5)

    def _setup_port(self):
        """Return a random port in [8096, 16394] that misc.PortCheck accepts."""
        while True:
            port = random.randint(8096, 16394)
            if misc.PortCheck(port):
                return port

    def get_action(self, action):
        """Return the bound DatabaseProcess method named *action* (child process only)."""
        return getattr(self.db, action)

    def send_message(self, action, table, *args, **kwargs):
        """Send one JSON request to the child and return the raw reply.

        The reply is whatever ``misc.Connection.recieve`` returns (its type is
        defined in ``misc``, not here).
        """
        data = {
            'token': self.token,
            'action': action,
            'table': table,
            'args': args,
            'kwargs': kwargs,
        }
        with self.socket as s:
            s.send(json.dumps(data))
            # NOTE(review): method name spelled as in misc.Connection's API as used here.
            return s.recieve(self.MAX_MESSAGE)

    @property
    def socket(self):
        """A new misc.Connection to the child's port."""
        return misc.Connection(port=self.port)

    async def process_queue(self, reader, writer):
        """Handle one client connection: authenticate, dispatch, reply, close.

        The request is a JSON object with keys token/action/table/args/kwargs.
        Requests with a wrong token, malformed JSON or an action outside
        ``ACTIONS`` are dropped. A 'close' action closes the database and stops
        the event loop. dict/list results are sent back as UTF-8 JSON.
        """
        try:
            raw = await reader.read(self.MAX_MESSAGE)
            try:
                data = misc.DotDict(json.loads(raw))
            except (ValueError, TypeError):
                logging.warning('Database: ignoring malformed request')
                return

            if data.get('token') != self.token:
                return

            action = data.get('action')
            if action == 'close':
                self.shutdown = True
                self.db.close()
                asyncio.get_running_loop().stop()
                return

            if action not in self.ACTIONS:
                logging.warning('Database: rejecting unknown action %r', action)
                return

            args = data.get('args') or ()
            kwargs = data.get('kwargs') or {}
            result = self.get_action(action)(data.get('table'), *args, **kwargs)

            # search() returns a list, fetch() a row/dict; both are sent back.
            if isinstance(result, (dict, list, tuple)):
                writer.write(json.dumps(result).encode('utf-8'))
                await writer.drain()
        finally:
            writer.close()

    async def pipe_listener(self):
        """Placeholder; not implemented."""
        pass
|
|
|
|
|
|
class DatabaseProcess(misc.DotDict):
    """A directory of per-table JSON files plus a metadata file.

    Each declared table is stored as an item of this mapping (``self[name]``)
    and bound to this database via ``Table.setup``. Metadata (currently a
    ``version`` counter) is kept in ``<dbpath>/metadata.json`` and guarded by a
    ``filelock`` lock file next to it.
    """

    def __init__(self, dbpath: misc.Path, tables: dict = None):
        super().__init__()
        self.path = misc.Path(dbpath).resolve()
        self.metadata = misc.DotDict(**{
            'path': self.path.join('metadata.json'),
            'lock': self.path.join('metadata.json.lock'),
            'version': 0,
        })
        # Checked by every Table operation before touching its file.
        self._closed = False
        self.__setup_database(tables)

    def __setup_database(self, tables):
        """Create the directory, load metadata and bind every declared table.

        *tables* maps a table name to either a ready ``Table`` or a column
        spec accepted by ``Table(name, columns)``. ``None`` means no tables.
        """
        self.path.mkdir()
        self.load_meta()

        for name, columns in (tables or {}).items():
            table = columns if isinstance(columns, Table) else Table(name, columns)
            self[name] = table

            if table.db is None:
                table.setup(self)

    def load_meta(self):
        """Merge metadata.json into ``self.metadata`` if the file exists."""
        if self.metadata.path.exists():
            with filelock.FileLock(self.metadata.lock.str()):
                data = self.metadata.path.load_json()
                self.metadata.update(data)

    def save_meta(self):
        """Persist ``self.metadata`` (minus the path/lock entries) to metadata.json."""
        with filelock.FileLock(self.metadata.lock.str()):
            data = self.metadata.copy()
            data.pop('path')
            data.pop('lock')
            # NOTE(review): relies on misc.Path.update_json/save_json semantics;
            # save_json() is called with no arguments — confirm against misc.
            self.metadata.path.update_json(data)
            self.metadata.path.save_json()

    def close(self):
        """Save metadata and mark the database closed for all tables."""
        self.save_meta()
        self._closed = True

    def fetch(self, table, *args, **kwargs):
        """Delegate to ``self[table].fetch``."""
        return self[table].fetch(*args, **kwargs)

    def search(self, table, *args, **kwargs):
        """Delegate to ``self[table].search``."""
        return self[table].search(*args, **kwargs)

    def insert(self, table, *args, **kwargs):
        """Delegate to ``self[table].insert``."""
        return self[table].insert(*args, **kwargs)

    def remove(self, table, *args, **kwargs):
        """Delegate to ``self[table].remove``."""
        return self[table].remove(*args, **kwargs)

    def migrate(self, table=None):
        """Rewrite stored rows so they match the current column definitions.

        Every row of *table* (or of every Table in this database when *table*
        is falsy) is read back and written again through ``Table.insert``, so
        values are re-serialized per column type and columns missing from the
        stored row receive their defaults.
        NOTE(review): the original body was unfinished; this is the assumed intent.
        """
        if table:
            targets = [self[table]]
        else:
            # Other entries (path, metadata, ...) may live in this mapping too.
            targets = [value for value in self.values() if isinstance(value, Table)]

        for tbl in targets:
            for row in tbl.search() or []:
                tbl.insert(rowid=row.id, **row)
|
|
|
|
|
|
class Table(JsonDatabase):
    """A pysondb JSON table with a declared column schema.

    The pysondb ``JsonDatabase.__init__`` is deferred to ``setup``, which is
    called once the owning DatabaseProcess (and therefore the file path) is
    known. The ``id`` column always exists; pysondb generates its value.
    """

    def __init__(self, name: str, columns: dict = None):
        self.db = None
        self.name = name
        self.columns = {}
        self.add_column('id')

        # Each value is a sequence of positional arguments for add_column.
        for colname, col in (columns or {}).items():
            if colname != 'id':
                self.add_column(colname, *col)

    def setup(self, db):
        """Bind to *db* and open ``table_<name>.json`` in its directory, creating it if needed."""
        self.db = db
        tablefile = db.path.join(f'table_{self.name}.json')

        if not tablefile.exists():
            tablefile.touch(mode=0o644)

            with tablefile.open('w') as fd:
                fd.write('{"data": []}')

        super().__init__(tablefile.str())

    def add_column(self, name: str, type: str = 'str', default=None, nullable: bool = True, primary_key: bool = False):
        """Declare a column. The ``id`` column is forced to a non-null int primary key."""
        if name == 'id':
            type = 'int'
            nullable = False
            primary_key = True

        self.columns[name] = misc.DotDict({
            'default': default,
            'type': type,
            'primary_key': primary_key,
            'nullable': nullable,
        })

    def _is_open(self):
        """Return True if bound to an open database; otherwise log and return False."""
        if self.db is None or self.db._closed:
            logging.error('Database closed')
            return False
        return True

    def fetch(self, single=True, orderby=None, reverse=False, **kwargs):
        """Return rows matching all ``column=value`` filters in *kwargs*.

        With no filters every row is returned as a list (``single`` is ignored).
        Otherwise ``single=True`` returns the first match or None. ``orderby``
        sorts the list by that column. Returns None if the database is closed.
        """
        if not self._is_open():
            return None

        if kwargs:
            rows = DBRows(self, self.getBy(kwargs))
        else:
            rows = DBRows(self, self.getAll())
            single = False

        if single:
            return rows[0] if rows else None

        if orderby:
            # sorted() requires a real bool/int for reverse; None would raise.
            return sorted(rows, key=operator.itemgetter(orderby), reverse=bool(reverse))

        return rows

    def search(self, orderby=None, reverse=False, **kwargs):
        """Return a list of all rows matching *kwargs* (see ``fetch``)."""
        return self.fetch(single=False, orderby=orderby, reverse=reverse, **kwargs)

    def insert(self, row=None, rowid=None, **kwargs):
        """Add a new row, or update the row identified by *row*/*rowid*.

        Every non-id column is set from *kwargs*, falling back to its default,
        and serialized per its declared type. The ``id`` column is skipped:
        pysondb assigns it on ``add`` and it must not be overwritten on
        ``update``. Returns pysondb's ``add``/``update`` result, or None if the
        database is closed.

        Raises:
            ValueError: a non-nullable column would be stored as None.
        """
        if not self._is_open():
            return None

        new_data = {}

        for colname, col in self.columns.items():
            if colname == 'id':
                continue

            value = serialize(kwargs.get(colname, col.default), col.type)

            # Only None counts as empty; 0, False and '' are legitimate values.
            if value is None and not col.nullable:
                raise ValueError(f'Column "{colname}" cannot be empty')

            new_data[colname] = value

        if row is not None:
            rowid = row.id

        if rowid is not None:
            return self.update({'id': rowid}, new_data)

        return self.add(new_data)

    def delete(self, rowid):
        """Delete every stored row matching *rowid*.

        *rowid* is either a dict of ``column: value`` criteria (all must match)
        or a bare id value.

        Raises:
            ValueError: the criteria are empty (would delete every row).
            IdNotFoundError: no row matched.
        """
        criteria = rowid if isinstance(rowid, dict) else {'id': rowid}

        if not criteria:
            raise ValueError('Refusing to delete without criteria')

        def matches(record):
            return all(key in record and record[key] == value for key, value in criteria.items())

        with self.lock:
            with open(self.filename, "r+") as db_file:
                db_data = self._get_load_function()(db_file)
                kept = [record for record in db_data["data"] if not matches(record)]

                if len(kept) == len(db_data["data"]):
                    raise IdNotFoundError(criteria)

                db_data["data"] = kept
                db_file.seek(0)
                db_file.truncate()
                self._get_dump_function()(db_data, db_file)

        return True

    def remove(self, row=None, rowid=None, **kwargs):
        """Delete a row by *row*/*rowid*, or all rows matching *kwargs*.

        Returns True on success, or None if the database is closed.
        """
        if not self._is_open():
            return None

        if row is not None or rowid is not None:
            return self.delete({'id': rowid if rowid is not None else row.id})

        return self.delete(kwargs)

    def _get_dump_function(self):
        """Dump JSON with 2-space indentation so table files stay human-readable."""
        return lambda *args, **kwargs: json.dump(*args, indent=2, **kwargs)
|
|
|
|
|
|
def serialize(data, dtype):
    """Convert *data* into a JSON-storable value for a column of type *dtype*.

    Supported types:
      - datetime: POSIX timestamp.
      - dotdict: ``toDict()``.
      - bytes: base64 ASCII string.
      - bool: ``misc.Boolean``.
      - int: ``int()``.
      - path: ``.str()``.

    Any other type returns *data* unchanged. ``None`` is always returned as
    ``None``.
    """
    if data is None:
        return None

    converters = {
        'datetime': lambda arg: arg.timestamp(),
        'dotdict': lambda arg: arg.toDict(),
        'bytes': lambda arg: base64.b64encode(arg).decode('ascii'),
        # Resolved lazily so misc is only needed for bool columns.
        'bool': lambda arg: misc.Boolean(arg),
        'int': int,
        'path': lambda arg: arg.str(),
    }

    convert = converters.get(dtype)
    return convert(data) if convert else data
|
|
|
|
|
|
def deserialize(data, dtype):
    """Convert a stored JSON value back into a Python value for column type *dtype*.

    This is the inverse of ``serialize`` for the following types:
      - datetime: local-time datetime from a timestamp.
      - dotdict and dict: ``misc.DotDict``.
      - bytes: base64-decoded.
      - path: ``misc.Path``.

    Other types return *data* unchanged. Only ``None`` maps to ``None``;
    falsy values such as 0, False or '' are kept.
    """
    if data is None:
        return None

    converters = {
        'datetime': datetime.datetime.fromtimestamp,
        # misc is resolved lazily so it is only needed for these types.
        'dotdict': lambda arg: misc.DotDict(arg),
        'dict': lambda arg: misc.DotDict(arg),
        'bytes': base64.b64decode,
        'path': lambda arg: misc.Path(arg),
    }

    convert = converters.get(dtype)
    return convert(data) if convert else data
|
|
|
|
|
|
|
|
|
|
def DBRows(table, rows):
    """Wrap each raw record from *rows* in a DBRow bound to *table*; return a list."""
    wrapped = []
    for record in rows:
        wrapped.append(DBRow(table, record))
    return wrapped
|
|
|
|
|
|
class DBRow(misc.DotDict):
    """One table record.

    Holds the table's declared columns, each deserialized per its column
    type, plus a reference to the owning Table.
    """

    def __init__(self, table, row):
        # row is the raw stored dict. Columns absent from it (e.g. added to the
        # schema later) become None instead of raising KeyError.
        super().__init__(**{
            name: deserialize(row.get(name), col.type)
            for name, col in table.columns.items()
        })
        # NOTE(review): if misc.DotDict maps attributes to keys, this also adds
        # a 'table' key; __str__ skips it.
        self.table = table

    def __str__(self):
        data = ', '.join(f'{k}={v}' for k, v in self.items() if k != 'table')
        return f'{type(self).__name__}({data})'

    def update(self, data=None):
        """Merge *data* into this row and write the whole row back to its table."""
        super().update(data or {})
        # Table.insert with rowid performs the update; pysondb's own update()
        # takes (db_dataset, new_dataset) and has no rowid keyword.
        self.table.insert(rowid=self.id, **self)

    def remove(self):
        """Delete this row from its table."""
        self.table.remove(rowid=self.id)
|