diff --git a/IzzyLib/database/__init__.py b/IzzyLib/database/__init__.py new file mode 100644 index 0000000..9765c09 --- /dev/null +++ b/IzzyLib/database/__init__.py @@ -0,0 +1,17 @@ +from .. import logging + +try: + from .sql import SqlDatabase + from .sqlite_server import SqliteClient, SqliteServer +except ImportError as e: + logging.verbose('Failed to load SqlDatabase, SqliteClient, and SqliteServer. Is sqlalchemy installed?') + +try: + from .tiny import TinyDatabase +except ImportError as e: + logging.verbose('Failed to import TinyDatabase. Is tinydb and tinydb-serialization installed?') + +try: + from .pysondb import PysonDatabase +except ImportError as e: + logging.verbose('Failed to import PysonDatabase. Is pysondb installed?') diff --git a/IzzyLib/database_pysondb.py b/IzzyLib/database/pysondb.py similarity index 99% rename from IzzyLib/database_pysondb.py rename to IzzyLib/database/pysondb.py index c601edb..4887895 100644 --- a/IzzyLib/database_pysondb.py +++ b/IzzyLib/database/pysondb.py @@ -11,10 +11,10 @@ import random from pysondb.db import JsonDatabase, IdNotFoundError -from . import misc +from .. import misc -class Database(multiprocessing.Process): +class PysonDatabase(multiprocessing.Process): def __init__(self, dbpath: misc.Path, tables: dict=None): multiprocessing.Process.__init__(self, daemon=True) diff --git a/IzzyLib/database.py b/IzzyLib/database/sql.py similarity index 89% rename from IzzyLib/database.py rename to IzzyLib/database/sql.py index f2c7039..614fe0c 100644 --- a/IzzyLib/database.py +++ b/IzzyLib/database/sql.py @@ -7,25 +7,25 @@ from sqlalchemy import Column as SqlColumn, types as Types from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.orm import scoped_session, sessionmaker -from . import logging -from .cache import LRUCache -from .misc import DotDict, RandomGen, NfsCheck, PrintMethods, Path +from .. import logging +from ..cache import LRUCache +from ..misc import DotDict, RandomGen, NfsCheck, PrintMethods, Path SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')}) -class DataBase(): - def __init__(self, dbtype='postgresql+pg8000', tables={}, **kwargs): +class SqlDatabase: + def __init__(self, dbtype='sqlite', tables={}, **kwargs): self.db = self.__create_engine(dbtype, kwargs) - self.table = Tables(self, tables) - self.table_names = tables.keys() + self.table = None + self.table_names = None self.classes = kwargs.get('row_classes', CustomRows()) self.cache = None - session_class = kwargs.get('session_class', Session) - self.session = lambda trans=True: session_class(self, trans) + self.session_class = kwargs.get('session_class', Session) self.sessions = {} + self.SetupTables(tables) self.SetupCache() @@ -42,7 +42,7 @@ class DataBase(): if NfsCheck(kwargs.get('database')): logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail') - engine_string += '/' + kwargs.get('database') + engine_string += '/' + str(kwargs.get('database')) engine_kwargs['connect_args'] = {'check_same_thread': False} else: @@ -68,6 +68,11 @@ class DataBase(): return create_engine(engine_string, *engine_args, **engine_kwargs) + @property + def session(self): + return self.session_class(self) + + def close(self): self.SetupCache() @@ -102,19 +107,25 @@ class DataBase(): self.table.meta.create_all(self.db) + def SetupTables(self, tables): + self.table = Tables(self, tables) + self.table_names = tables.keys() + + def execute(self, *args, **kwargs): with self.session() as s: return s.execute(*args, **kwargs) class Session(object): - def __init__(self, db, trans=True): + def __init__(self, db): + self.closed = False + self.db = db self.classes = self.db.classes self.session = sessionmaker(bind=db.db)() self.table = self.db.table self.cache = self.db.cache - self.trans = trans # session aliases self.s = self.session @@ -123,14 +134,12 @@ class Session(object): self.rollback = self.s.rollback self.query = self.s.query self.execute = self.s.execute - self.close = self.s.close self._setup() def __enter__(self): - self.sessionid = RandomGen(10) - self.db.sessions[self.sessionid] = self + self.open() return self @@ -138,10 +147,23 @@ class Session(object): if tb: self.rollback() - self.commit() self.close() + + + def open(self): + self.sessionid = RandomGen(10) + self.db.sessions[self.sessionid] = self + + + def close(self): + self.commit() + self.s.close() + self.closed = True + del self.db.sessions[self.sessionid] + self.sessionid = None + def _setup(self): pass @@ -216,7 +238,6 @@ class Session(object): def remove(self, table=None, rowid=None, row=None): if row: rowid = row.id - table = row._table_name if not rowid or not table: raise ValueError('Missing row ID or table') @@ -297,12 +318,15 @@ class CustomRows(object): return super().__init__() - self._update(row._asdict()) + + try: + self._update(row._asdict()) + except: + self._update(row) self._db = session.db self._table_name = table self._columns = self.keys() - #self._columns = self._filter_columns(row) self.__run__(session) @@ -345,7 +369,7 @@ class CustomRows(object): def delete_session(self, s): - return s.remove(row=self) + return s.remove(table=self._table_name, row=self) def update(self, dict_data={}, s=None, **data): @@ -362,7 +386,7 @@ class CustomRows(object): def update_session(self, s, dict_data={}, **data): dict_data.update(data) self._update(dict_data) - return s.update(row=self, **dict_data) + return s.update(table=self._table_name, row=self, **dict_data) class Tables(DotDict): @@ -378,7 +402,8 @@ class Tables(DotDict): def __setup_table(self, name, table): - self[name] = Table(name, self.meta, *table) + columns = [col if type(col) == SqlColumn else Column(*col.get('args'), **col.get('kwargs')) for col in table] + self[name] = Table(name, self.meta, *columns) def Column(name, stype=None, fkey=None, **kwargs): diff --git a/IzzyLib/database/sqlite_server.py b/IzzyLib/database/sqlite_server.py new file mode 100644 index 0000000..8dddb0b --- /dev/null +++ b/IzzyLib/database/sqlite_server.py @@ -0,0 +1,374 @@ +import asyncio, json, socket, sqlite3, ssl, time, traceback + +from . import SqlDatabase +from .sql import CustomRows +from .. import logging, misc + + +commands = [ + 'insert', 'update', 'remove', 'query', 'execute', 'dirty', 'count', + 'DropTables', 'GetTables', 'AppendColumn', 'RemoveColumn' +] + + +class SqliteClient(object): + def __init__(self, database: str='metadata', host: str='localhost', port: int=3926, password: str=None, session_class=None): + self.ssl = None + self.data = misc.DotDict({ + 'host': host, + 'port': int(port), + 'password': password, + 'database': database + }) + + self.session_class = session_class or SqliteSession + self.classes = CustomRows() + + self._setup() + + + @property + def session(self): + return self.session_class(self) + + + def setup_ssl(self, certfile, keyfile, password=None): + self.ssl = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + self.ssl.load_cert_chain(certfile, keyfile, password) + + + def switch_database(self, database): + self.data.database = database + + + def _setup(self): + pass + + +class SqliteSession(socket.socket): + def __init__(self, client): + super().__init__(socket.AF_INET, socket.SOCK_STREAM) + self.connected = False + + self.client = client + self.classes = client.classes + self.data = client.data + self.begin = lambda: self.send('begin') + self.commit = lambda: self.send('commit') + self.rollback = lambda: self.send('rollback') + + for cmd in commands: + self.setup_command(cmd) + + + def __enter__(self): + self.open() + return self + + + def __exit__(self, exctype, value, tb): + if tb: + self.rollback() + + self.commit() + self.close() + + + def fetch(self, table, *args, **kwargs): + RowClass = self.classes.get(table.capitalize()) + data = self.send('fetch', table, *args, **kwargs) + + if isinstance(data, dict): + return RowClass(table, data, self) + + elif isinstance(data, list): + return [RowClass(table, row, self) for row in data] + + + def search(self, *args, **kwargs): + return self.fetch(*args, **kwargs, single=False) + + + def setup_command(self, name): + setattr(self, name, lambda *args, **kwargs: self.send(name, *args, **kwargs)) + + + def send(self, command, *args, **kwargs): + self.sendall(json.dumps({'database': self.data.database, 'command': command, 'args': list(args), 'kwargs': dict(kwargs)}).encode('utf8')) + data = self.recv(8*1024*1024).decode() + + try: + data = misc.DotDict(data) + except ValueError: + data = json.loads(data) + + if isinstance(data, dict) and data.get('error'): + raise ServerError(data.get('error')) + + return data + + + def open(self): + try: + self.connect((self.data.host, self.data.port)) + except ConnectionRefusedError: + time.sleep(2) + self.connect((self.data.host, self.data.port)) + + if self.data.password: + login = self.send('login', self.data.password) + + if not login.get('message') == 'OK': + logging.error('Server error:', login.error) + return + + self.connected = True + + + def close(self): + self.send('close') + super().close() + self.connected = False + + + def is_transaction(self): + self.send('trans_state') + + + def is_connected(self): + return self.connected + + + def _setup(self): + pass + + +def Column(*args, **kwargs): + return {'args': list(args), 'kwargs': dict(kwargs)} + + +class SqliteServer(misc.DotDict): + def __init__(self, path, host='localhost', port=3926, password=None): + self.server = None + self.database = misc.DotDict() + + self.path = misc.Path(path).resolve() + self.ssl = None + self.password = password + self.host = host + self.port = int(port) + + self.metadata_layout = { + 'databases': [ + Column('id'), + Column('name', 'text', nullable=False), + Column('layout', 'text', nullable=False) + ] + } + + if not self.path.exists(): + raise FileNotFoundError('Database directory not found') + + if not self.path.isdir(): + raise NotADirectoryError('Database directory is a file') + + try: + self.open('metadata') + except: + self.setup_metadata() + + for path in self.path.listdir(False): + if path.str().endswith('.sqlite3') and path.stem != 'metadata': + self.open(path.stem) + + + def open(self, database, new=False): + db = SqlDatabase(dbtype='sqlite', database=self.path.join(database + '.sqlite3')) + + if database != 'metadata' and not new: + with self.get_database('metadata').session() as s: + row = s.fetch('databases', name=database) + + if not row: + logging.error('Database not found:', database) + return + + db.SetupTables(row.layout) + + else: + db.SetupTables(self.metadata_layout) + + setattr(db, 'name', database) + self[database] = db + return db + + + def close(self, database): + del self[database] + + + def delete(self, database): + self.close(database) + path.join(database + '.sqlite3').unlink() + + + def get_database(self, database): + return self[database] + + + def asyncio_run(self): + self.server = asyncio.start_server(self.handle_connection, self.host, self.port, ssl=self.ssl) + return self.server + + + def run(self): + loop = asyncio.get_event_loop() + loop.run_until_complete(self.asyncio_run()) + + try: + logging.info('Starting Sqlite Server') + loop.run_forever() + except KeyboardInterrupt: + print() + logging.info('Closing...') + return + + + def setup_metadata(self): + meta = self.open('metadata') + tables = { + 'databases': [ + Column('id'), + Column('name', 'text', nullable=False), + Column('layout', 'text', nullable=False) + ] + } + + db = self.open('metadata') + db.SetupTables(tables) + db.CreateDatabase() + + + def setup_ssl(self, certfile, keyfile, password=None): + self.ssl = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + self.ssl.load_cert_chain(certfile, keyfile, password) + + + async def handle_connection(self, reader, writer): + session = None + database = None + valid = None + close = False + + try: + while not close: + raw_data = await asyncio.wait_for(reader.read(8*1024*1024), timeout=60) + + if not raw_data: + break + + try: + data = misc.DotDict(raw_data) + + if self.password: + if valid == None and data.command == 'login': + valid = self.login(*data.get('args')) + + if not valid: + response = {'error': 'Missing or invalid password'} + + elif data.command in ['session']: + response = {'error': 'Invalid command'} + + else: + if not database: + database = data.database + + if data.command == 'close' and session: + session.commit() + + else: + if not session: + session = self[database].session() + session.open() + + response = self.run_command(session, database, data.command, *data.get('args'), **data.get('kwargs')) + + except Exception as e: + traceback.print_exc() + response = {'error': f'{e.__class__.__name__}: {str(e)}'} + + writer.write(json.dumps(response or {'message': 'OK'}, cls=misc.JsonEncoder).encode('utf8')) + await writer.drain() + logging.info(f'{writer.get_extra_info("peername")[0]}: [{database}] {data.command} {data.args} {data.kwargs}') + + if data.command == 'delete': + writer.close() + break + + except ConnectionResetError: + pass + + if session: + session.close() + + writer.close() + + + def login(self, password): + return self.password == password + + + def run_command(self, session, database, command, *args, **kwargs): + if command == 'update': + return self.cmd_update(*args, **kwargs) + + if command == 'dropdb': + return self.cmd_delete(session, database) + + elif command == 'createdb': + return self.cmd_createdb(session, database, *args) + + elif command == 'test': + return + + elif command == 'trans_state': + return {'trans_state': session.dirty} + + cmd = getattr(session, command, None) + + if not cmd: + return {'error': f'Command not found: {command}'} + + return cmd(*args, **kwargs) + + + def cmd_delete(self, session, database): + session.rollback() + session.close() + + self.delete(database) + + + def cmd_createdb(self, session, database, name, tables): + if session.fetch('databases', name=name): + raise ValueError('Database already exists:', database) + + session.insert('databases', name=name, layout=json.dumps(tables)) + + db = self.open(name, new=True) + db.SetupTables(tables) + db.CreateDatabase() + + self[name] = db + + + def cmd_update(self, table=None, rowid=None, row=None, **data): + if row: + row = misc.DotDict(row) + + return self.update(table, rowid, row, **data) + + +class ServerError(Exception): + pass diff --git a/IzzyLib/database_tiny.py b/IzzyLib/database/tiny.py similarity index 99% rename from IzzyLib/database_tiny.py rename to IzzyLib/database/tiny.py index 258b465..a9cafe4 100644 --- a/IzzyLib/database_tiny.py +++ b/IzzyLib/database/tiny.py @@ -8,14 +8,14 @@ import time import tinydb import tinydb_serialization -from . import misc +from .. import misc class AwaitingResult(object): pass -class DataBase(tinydb.TinyDB): +class TinyDatabase(tinydb.TinyDB): def __init__(self, dbfile: misc.Path, queue_limit: int=64, serializers: list=[]): options = { 'indent': 2, diff --git a/IzzyLib/misc.py b/IzzyLib/misc.py index 657e558..268c474 100644 --- a/IzzyLib/misc.py +++ b/IzzyLib/misc.py @@ -210,11 +210,11 @@ class DotDict(dict): if isinstance(value, (str, bytes)): self.fromJson(value) - elif isinstance(value, dict): + elif isinstance(value, dict) or isinstance(value, list): self.update(value) elif value: - raise TypeError('The value must be a JSON string, dict, or another DotDict object, not', value.__class__) + raise TypeError('The value must be a JSON string, list, dict, or another DotDict object, not', value.__class__) if kwargs: self.update(kwargs) @@ -479,8 +479,9 @@ class Path(object): return self.__path.is_symlink() - def listdir(self): - return [Path(path) for path in self.__path.iterdir()] + def listdir(self, recursive=True): + paths = self.__path.iterdir() if recursive else os.listdir(self.__path) + return [Path(path) for path in paths] def exists(self):