Move database classes to subdir and create half-implemented sqlite server

This commit is contained in:
Izalia Mae 2021-05-17 17:38:18 -04:00
parent 820a5b4772
commit 649507fdaa
6 changed files with 447 additions and 30 deletions

View file

@ -0,0 +1,17 @@
from .. import logging
try:
from .sql import SqlDatabase
from .sqlite_server import SqliteClient, SqliteServer
except ImportError as e:
logging.verbose('Failed to load SqlDatabase, SqliteClient, and SqliteServer. Is sqlalchemy installed?')
try:
from .tiny import TinyDatabase
except ImportError as e:
logging.verbose('Failed to import TinyDatabase. Is tinydb and tinydb-serialization installed?')
try:
from .pysondb import PysonDatabase
except ImportError as e:
logging.verbose('Failed to import PysonDatabase. Is pysondb installed?')

View file

@ -11,10 +11,10 @@ import random
from pysondb.db import JsonDatabase, IdNotFoundError
from . import misc
from .. import misc
class Database(multiprocessing.Process):
class PysonDatabase(multiprocessing.Process):
def __init__(self, dbpath: misc.Path, tables: dict=None):
multiprocessing.Process.__init__(self, daemon=True)

View file

@ -7,25 +7,25 @@ from sqlalchemy import Column as SqlColumn, types as Types
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import scoped_session, sessionmaker
from . import logging
from .cache import LRUCache
from .misc import DotDict, RandomGen, NfsCheck, PrintMethods, Path
from .. import logging
from ..cache import LRUCache
from ..misc import DotDict, RandomGen, NfsCheck, PrintMethods, Path
SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')})
class DataBase():
def __init__(self, dbtype='postgresql+pg8000', tables={}, **kwargs):
class SqlDatabase:
def __init__(self, dbtype='sqlite', tables={}, **kwargs):
self.db = self.__create_engine(dbtype, kwargs)
self.table = Tables(self, tables)
self.table_names = tables.keys()
self.table = None
self.table_names = None
self.classes = kwargs.get('row_classes', CustomRows())
self.cache = None
session_class = kwargs.get('session_class', Session)
self.session = lambda trans=True: session_class(self, trans)
self.session_class = kwargs.get('session_class', Session)
self.sessions = {}
self.SetupTables(tables)
self.SetupCache()
@ -42,7 +42,7 @@ class DataBase():
if NfsCheck(kwargs.get('database')):
logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
engine_string += '/' + kwargs.get('database')
engine_string += '/' + str(kwargs.get('database'))
engine_kwargs['connect_args'] = {'check_same_thread': False}
else:
@ -68,6 +68,11 @@ class DataBase():
return create_engine(engine_string, *engine_args, **engine_kwargs)
@property
def session(self):
return self.session_class(self)
def close(self):
self.SetupCache()
@ -102,19 +107,25 @@ class DataBase():
self.table.meta.create_all(self.db)
def SetupTables(self, tables):
self.table = Tables(self, tables)
self.table_names = tables.keys()
def execute(self, *args, **kwargs):
with self.session() as s:
return s.execute(*args, **kwargs)
class Session(object):
def __init__(self, db, trans=True):
def __init__(self, db):
self.closed = False
self.db = db
self.classes = self.db.classes
self.session = sessionmaker(bind=db.db)()
self.table = self.db.table
self.cache = self.db.cache
self.trans = trans
# session aliases
self.s = self.session
@ -123,14 +134,12 @@ class Session(object):
self.rollback = self.s.rollback
self.query = self.s.query
self.execute = self.s.execute
self.close = self.s.close
self._setup()
def __enter__(self):
self.sessionid = RandomGen(10)
self.db.sessions[self.sessionid] = self
self.open()
return self
@ -138,10 +147,23 @@ class Session(object):
if tb:
self.rollback()
self.commit()
self.close()
def open(self):
self.sessionid = RandomGen(10)
self.db.sessions[self.sessionid] = self
def close(self):
self.commit()
self.s.close()
self.closed = True
del self.db.sessions[self.sessionid]
self.sessionid = None
def _setup(self):
pass
@ -216,7 +238,6 @@ class Session(object):
def remove(self, table=None, rowid=None, row=None):
if row:
rowid = row.id
table = row._table_name
if not rowid or not table:
raise ValueError('Missing row ID or table')
@ -297,12 +318,15 @@ class CustomRows(object):
return
super().__init__()
self._update(row._asdict())
try:
self._update(row._asdict())
except:
self._update(row)
self._db = session.db
self._table_name = table
self._columns = self.keys()
#self._columns = self._filter_columns(row)
self.__run__(session)
@ -345,7 +369,7 @@ class CustomRows(object):
def delete_session(self, s):
return s.remove(row=self)
return s.remove(table=self._table_name, row=self)
def update(self, dict_data={}, s=None, **data):
@ -362,7 +386,7 @@ class CustomRows(object):
def update_session(self, s, dict_data={}, **data):
dict_data.update(data)
self._update(dict_data)
return s.update(row=self, **dict_data)
return s.update(table=self._table_name, row=self, **dict_data)
class Tables(DotDict):
@ -378,7 +402,8 @@ class Tables(DotDict):
def __setup_table(self, name, table):
self[name] = Table(name, self.meta, *table)
columns = [col if type(col) == SqlColumn else Column(*col.get('args'), **col.get('kwargs')) for col in table]
self[name] = Table(name, self.meta, *columns)
def Column(name, stype=None, fkey=None, **kwargs):

View file

@ -0,0 +1,374 @@
import asyncio, json, socket, sqlite3, ssl, time, traceback
from . import SqlDatabase
from .sql import CustomRows
from .. import logging, misc
commands = [
'insert', 'update', 'remove', 'query', 'execute', 'dirty', 'count',
'DropTables', 'GetTables', 'AppendColumn', 'RemoveColumn'
]
class SqliteClient(object):
def __init__(self, database: str='metadata', host: str='localhost', port: int=3926, password: str=None, session_class=None):
self.ssl = None
self.data = misc.DotDict({
'host': host,
'port': int(port),
'password': password,
'database': database
})
self.session_class = session_class or SqliteSession
self.classes = CustomRows()
self._setup()
@property
def session(self):
return self.session_class(self)
def setup_ssl(self, certfile, keyfile, password=None):
self.ssl = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
self.ssl.load_cert_chain(certfile, keyfile, password)
def switch_database(self, database):
self.data.database = database
def _setup(self):
pass
class SqliteSession(socket.socket):
def __init__(self, client):
super().__init__(socket.AF_INET, socket.SOCK_STREAM)
self.connected = False
self.client = client
self.classes = client.classes
self.data = client.data
self.begin = lambda: self.send('begin')
self.commit = lambda: self.send('commit')
self.rollback = lambda: self.send('rollback')
for cmd in commands:
self.setup_command(cmd)
def __enter__(self):
self.open()
return self
def __exit__(self, exctype, value, tb):
if tb:
self.rollback()
self.commit()
self.close()
def fetch(self, table, *args, **kwargs):
RowClass = self.classes.get(table.capitalize())
data = self.send('fetch', table, *args, **kwargs)
if isinstance(data, dict):
return RowClass(table, data, self)
elif isinstance(data, list):
return [RowClass(table, row, self) for row in data]
def search(self, *args, **kwargs):
return self.fetch(*args, **kwargs, single=False)
def setup_command(self, name):
setattr(self, name, lambda *args, **kwargs: self.send(name, *args, **kwargs))
def send(self, command, *args, **kwargs):
self.sendall(json.dumps({'database': self.data.database, 'command': command, 'args': list(args), 'kwargs': dict(kwargs)}).encode('utf8'))
data = self.recv(8*1024*1024).decode()
try:
data = misc.DotDict(data)
except ValueError:
data = json.loads(data)
if isinstance(data, dict) and data.get('error'):
raise ServerError(data.get('error'))
return data
def open(self):
try:
self.connect((self.data.host, self.data.port))
except ConnectionRefusedError:
time.sleep(2)
self.connect((self.data.host, self.data.port))
if self.data.password:
login = self.send('login', self.data.password)
if not login.get('message') == 'OK':
logging.error('Server error:', login.error)
return
self.connected = True
def close(self):
self.send('close')
super().close()
self.connected = False
def is_transaction(self):
self.send('trans_state')
def is_connected(self):
return self.connected
def _setup(self):
pass
def Column(*args, **kwargs):
return {'args': list(args), 'kwargs': dict(kwargs)}
class SqliteServer(misc.DotDict):
def __init__(self, path, host='localhost', port=3926, password=None):
self.server = None
self.database = misc.DotDict()
self.path = misc.Path(path).resolve()
self.ssl = None
self.password = password
self.host = host
self.port = int(port)
self.metadata_layout = {
'databases': [
Column('id'),
Column('name', 'text', nullable=False),
Column('layout', 'text', nullable=False)
]
}
if not self.path.exists():
raise FileNotFoundError('Database directory not found')
if not self.path.isdir():
raise NotADirectoryError('Database directory is a file')
try:
self.open('metadata')
except:
self.setup_metadata()
for path in self.path.listdir(False):
if path.str().endswith('.sqlite3') and path.stem != 'metadata':
self.open(path.stem)
def open(self, database, new=False):
db = SqlDatabase(dbtype='sqlite', database=self.path.join(database + '.sqlite3'))
if database != 'metadata' and not new:
with self.get_database('metadata').session() as s:
row = s.fetch('databases', name=database)
if not row:
logging.error('Database not found:', database)
return
db.SetupTables(row.layout)
else:
db.SetupTables(self.metadata_layout)
setattr(db, 'name', database)
self[database] = db
return db
def close(self, database):
del self[database]
def delete(self, database):
self.close(database)
path.join(database + '.sqlite3').unlink()
def get_database(self, database):
return self[database]
def asyncio_run(self):
self.server = asyncio.start_server(self.handle_connection, self.host, self.port, ssl=self.ssl)
return self.server
def run(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(self.asyncio_run())
try:
logging.info('Starting Sqlite Server')
loop.run_forever()
except KeyboardInterrupt:
print()
logging.info('Closing...')
return
def setup_metadata(self):
meta = self.open('metadata')
tables = {
'databases': [
Column('id'),
Column('name', 'text', nullable=False),
Column('layout', 'text', nullable=False)
]
}
db = self.open('metadata')
db.SetupTables(tables)
db.CreateDatabase()
def setup_ssl(self, certfile, keyfile, password=None):
self.ssl = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
self.ssl.load_cert_chain(certfile, keyfile, password)
async def handle_connection(self, reader, writer):
session = None
database = None
valid = None
close = False
try:
while not close:
raw_data = await asyncio.wait_for(reader.read(8*1024*1024), timeout=60)
if not raw_data:
break
try:
data = misc.DotDict(raw_data)
if self.password:
if valid == None and data.command == 'login':
valid = self.login(*data.get('args'))
if not valid:
response = {'error': 'Missing or invalid password'}
elif data.command in ['session']:
response = {'error': 'Invalid command'}
else:
if not database:
database = data.database
if data.command == 'close' and session:
session.commit()
else:
if not session:
session = self[database].session()
session.open()
response = self.run_command(session, database, data.command, *data.get('args'), **data.get('kwargs'))
except Exception as e:
traceback.print_exc()
response = {'error': f'{e.__class__.__name__}: {str(e)}'}
writer.write(json.dumps(response or {'message': 'OK'}, cls=misc.JsonEncoder).encode('utf8'))
await writer.drain()
logging.info(f'{writer.get_extra_info("peername")[0]}: [{database}] {data.command} {data.args} {data.kwargs}')
if data.command == 'delete':
writer.close()
break
except ConnectionResetError:
pass
if session:
session.close()
writer.close()
def login(self, password):
return self.password == password
def run_command(self, session, database, command, *args, **kwargs):
if command == 'update':
return self.cmd_update(*args, **kwargs)
if command == 'dropdb':
return self.cmd_delete(session, database)
elif command == 'createdb':
return self.cmd_createdb(session, database, *args)
elif command == 'test':
return
elif command == 'trans_state':
return {'trans_state': session.dirty}
cmd = getattr(session, command, None)
if not cmd:
return {'error': f'Command not found: {command}'}
return cmd(*args, **kwargs)
def cmd_delete(self, session, database):
session.rollback()
session.close()
self.delete(database)
def cmd_createdb(self, session, database, name, tables):
if session.fetch('databases', name=name):
raise ValueError('Database already exists:', database)
session.insert('databases', name=name, layout=json.dumps(tables))
db = self.open(name, new=True)
db.SetupTables(tables)
db.CreateDatabase()
self[name] = db
def cmd_update(self, table=None, rowid=None, row=None, **data):
if row:
row = misc.DotDict(row)
return self.update(table, rowid, row, **data)
class ServerError(Exception):
pass

View file

@ -8,14 +8,14 @@ import time
import tinydb
import tinydb_serialization
from . import misc
from .. import misc
class AwaitingResult(object):
pass
class DataBase(tinydb.TinyDB):
class TinyDatabase(tinydb.TinyDB):
def __init__(self, dbfile: misc.Path, queue_limit: int=64, serializers: list=[]):
options = {
'indent': 2,

View file

@ -210,11 +210,11 @@ class DotDict(dict):
if isinstance(value, (str, bytes)):
self.fromJson(value)
elif isinstance(value, dict):
elif isinstance(value, dict) or isinstance(value, list):
self.update(value)
elif value:
raise TypeError('The value must be a JSON string, dict, or another DotDict object, not', value.__class__)
raise TypeError('The value must be a JSON string, list, dict, or another DotDict object, not', value.__class__)
if kwargs:
self.update(kwargs)
@ -479,8 +479,9 @@ class Path(object):
return self.__path.is_symlink()
def listdir(self):
return [Path(path) for path in self.__path.iterdir()]
def listdir(self, recursive=True):
paths = self.__path.iterdir() if recursive else os.listdir(self.__path)
return [Path(path) for path in paths]
def exists(self):