izzylib/IzzyLib/database.py

334 lines
7.5 KiB
Python

import sys
from contextlib import contextmanager
from datetime import datetime
from sqlalchemy import create_engine, ForeignKey, MetaData, Table
from sqlalchemy import Column as SqlColumn, types as Types
#from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import sessionmaker
from . import logging
from .cache import LRUCache
from .misc import DotDict, RandomGen, NfsCheck
# Lower-cased sqlalchemy.types name -> type object, e.g. SqlTypes['string'].
# dir() returns names in sorted order, so when two public names lower-case to
# the same key, the one that sorts later overwrites the earlier one.
SqlTypes = DotDict({
    type_name.lower(): getattr(Types, type_name)
    for type_name in dir(Types)
    if not type_name.startswith('_')
})
class DataBase():
    '''Wrapper around a SQLAlchemy engine and the tables registered with it.

    Keyword options:
        database: required (MissingDatabaseError otherwise); for sqlite it is the file path
        user, pass, host, port, name: postgresql connection settings
        maxconnections: accepted but currently unused
        row_classes: object whose .get(name) returns a row class (default: CustomRows())
        session_class: class instantiated by self.session() (default: Session)
    '''

    def __init__(self, dbtype='postgresql+psycopg2', tables={}, **kwargs):
        '''Create the engine and the table registry, then build one cache per table.

        tables maps each table name to the list of Columns for that table.
        '''
        self.engine_string = self.__engine_string(dbtype, kwargs)
        self.db = create_engine(self.engine_string)
        self.table = Tables(self, tables)
        self.table_names = tables.keys()
        self.classes = kwargs.get('row_classes', CustomRows())
        self.cache = None

        session_class = kwargs.get('session_class', Session)
        # Each call builds a fresh session object bound to this database.
        self.session = lambda trans=True: session_class(self, trans)

        self.SetupCache()

    def __engine_string(self, dbtype, kwargs):
        '''Build the SQLAlchemy URL string for `dbtype` from the connection kwargs.

        Raises MissingDatabaseError if kwargs has no truthy "database".
        '''
        database = kwargs.get('database')

        if not database:
            raise MissingDatabaseError('Database not set')

        engine_string = dbtype + '://'

        if dbtype == 'sqlite':
            if NfsCheck(database):
                logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')

            # "sqlite://" + "/" + path; an absolute path therefore yields "sqlite:////..."
            engine_string += '/' + database
            return engine_string

        user = kwargs.get('user')
        password = kwargs.get('pass')
        host = kwargs.get('host', '/var/run/postgresql')
        port = kwargs.get('port', 5432)
        name = kwargs.get('name', 'postgres')

        if user:
            engine_string += f'{user}:{password}@' if password else f'{user}@'

        # For the default socket directory no host/port is written, presumably so
        # the driver falls back to its default local socket connection.
        if host == '/var/run/postgresql':
            engine_string += '/' + name
        else:
            engine_string += f'{host}:{port}/{name}'

        return engine_string

    def close(self):
        '''Reset the per-table caches and release the engine's pooled connections.

        The engine stays usable afterwards; it opens new connections on demand.
        '''
        self.SetupCache()
        self.db.dispose()

    def SetupCache(self):
        '''Replace self.cache with a fresh LRUCache for every registered table name.'''
        self.cache = DotDict({table: LRUCache() for table in self.table_names})

    def _admin_engine_string(self):
        '''Return (URL for the "postgres" maintenance database, target database name).

        Only the final "/<name>" path segment of self.engine_string is swapped, so a
        user or host that happens to equal the database name is left untouched.
        '''
        base, _, name = self.engine_string.rpartition('/')
        return base + '/postgres', name

    def CreateDatabase(self):
        '''Create the configured database (postgresql only), then all registered tables.

        For postgresql this connects to the "postgres" maintenance database and issues
        CREATE DATABASE. A ProgrammingError (the database already exists) is ignored.
        Tables are created with MetaData.create_all for every backend.
        '''
        if self.engine_string.startswith('postgresql'):
            admin_string, name = self._admin_engine_string()
            predb = create_engine(admin_string)
            conn = predb.connect()

            try:
                # End the implicit transaction first: PostgreSQL refuses to run
                # CREATE DATABASE inside a transaction block.
                conn.execute('commit')

                try:
                    conn.execute(f'CREATE DATABASE {name}')
                except ProgrammingError:
                    # The database already exists, so just move along
                    pass

            finally:
                conn.close()
                predb.dispose()

        self.table.meta.create_all(self.db)

    def execute(self, *args, **kwargs):
        '''Run a statement in a short-lived session and return its result.

        With the default Session class the session commits on exit, or rolls back
        if execute raised.
        '''
        with self.session() as s:
            return s.execute(*args, **kwargs)
class Session(object):
    '''Wrapper around a SQLAlchemy ORM session for one DataBase.

    As a context manager it commits when the block exits cleanly, rolls back
    when the block raised, and closes the underlying session in both cases.
    '''

    def __init__(self, db, trans=True):
        self.db = db
        self.classes = self.db.classes
        self.session = sessionmaker(bind=db.db)()
        self.table = self.db.table
        self.cache = self.db.cache
        self.trans = trans

        # session aliases
        self.s = self.session
        self.commit = self.s.commit
        self.rollback = self.s.rollback
        self.query = self.s.query
        self.execute = self.s.execute

        self._setup()

        # NOTE(review): with trans=False the session commits right after _setup();
        # confirm this is the intended non-transactional behaviour.
        if not self.trans:
            self.commit()

    def __enter__(self):
        self.sessionid = RandomGen(10)
        return self

    def __exit__(self, exctype, value, tb):
        try:
            if tb:
                self.rollback()
            else:
                self.commit()
        finally:
            # Release the ORM session's resources; the session may still be reused.
            self.session.close()

    def _setup(self):
        '''Hook for subclasses, called at the end of __init__ before any auto-commit.'''
        pass

    def count(self, table_name, **kwargs):
        '''Return the number of rows in `table_name` matching the keyword filters.'''
        return self.query(self.table[table_name]).filter_by(**kwargs).count()

    def fetch(self, table_name, single=True, orderby=None, orderdir='asc', **kwargs):
        '''Fetch rows of `table_name` matching the keyword filters, wrapped in row classes.

        orderby: optional column name to sort by; orderdir is 'asc' or 'desc'.
        Returns the first row (or None) when single is true, otherwise a list.
        Raises ValueError for any other orderdir when orderby is given.
        '''
        table = self.table[table_name]
        RowClass = self.classes.get(table_name.capitalize())
        query = self.query(table).filter_by(**kwargs)

        if orderby:
            if orderdir not in ('asc', 'desc'):
                raise ValueError(f'Unsupported order direction: {orderdir}')

            column = getattr(table.c, orderby)
            query = query.order_by(column.asc() if orderdir == 'asc' else column.desc())

        rows = query.all()

        if single:
            return RowClass(table_name, rows[0], self) if rows else None

        return [RowClass(table_name, row, self) for row in rows]

    def insert(self, table_name, **kwargs):
        '''Insert a row built from kwargs. If a row matching all kwargs already exists,
        update that row instead.

        If the table has a "timestamp" column and no truthy timestamp was supplied,
        it is filled with datetime.now(). Returns None.
        '''
        row = self.fetch(table_name, **kwargs)

        if row:
            row.update_session(self, **kwargs)
            return

        table = self.table[table_name]

        # Columns live on table.c; a Table object itself has no "timestamp" attribute.
        if 'timestamp' in table.c.keys() and not kwargs.get('timestamp'):
            kwargs['timestamp'] = datetime.now()

        self.execute(table.insert().values(**kwargs))

    def update(self, table=None, rowid=None, row=None, **data):
        '''Set `data` on the row whose id is `rowid` in `table` (or taken from `row`).

        Raises ValueError when the row id or table name is missing.
        '''
        if row:
            rowid = row.id
            table = row._table_name

        if not rowid or not table:
            raise ValueError('Missing row ID or table')

        tclass = self.table[table]
        self.execute(tclass.update().where(tclass.c.id == rowid).values(**data))

    def remove(self, table=None, rowid=None, row=None):
        '''Delete the row whose id is `rowid` from `table` (or taken from `row`).

        Uses the registered Table object and a bound parameter rather than
        interpolating values into SQL text, so `table` must be a registered table.
        Raises ValueError when the row id or table name is missing.
        '''
        if row:
            rowid = row.id
            table = row._table_name

        if not rowid or not table:
            raise ValueError('Missing row ID or table')

        tclass = self.table[table]
        self.execute(tclass.delete().where(tclass.c.id == rowid))

    def DropTables(self):
        '''Issue DROP TABLE for every name returned by GetTables() (sqlite only).'''
        tables = self.GetTables()

        for table in tables:
            self.execute(f'DROP TABLE {table}')

    def GetTables(self):
        '''Return the names of user tables and views by reading sqlite_master.

        This query is SQLite-specific; on other backends it will fail.
        '''
        rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
        return [row[0] for row in rows]
class CustomRows(object):
    '''Lookup of row classes by name.

    get(name) returns the attribute `name` when this object (or a subclass) defines
    one, and the generic Row class otherwise. Subclass this and add nested Row
    subclasses to customise rows; presumably named after the capitalised table name.
    '''

    def get(self, name):
        '''Return the row class registered as `name`, falling back to Row.'''
        return getattr(self, name, self.Row)

    class Row(DotDict):
        '''A fetched table row exposed as a DotDict, with helpers to write it back.'''

        def __init__(self, table, row, session):
            # An empty/None row leaves the object uninitialised (no _db, _columns, ...).
            if not row:
                return

            super().__init__()
            self._update(row._asdict())
            self._db = session.db
            self._table_name = table
            # Snapshot the column names now, so keys added later (e.g. by __run__ or
            # update()) are not written back to the database as columns.
            self._columns = list(self.keys())
            self.__run__(session)

        ## Subclass Row and redefine this function
        def __run__(self, s):
            '''Hook called at the end of __init__ with the session that fetched the row.'''
            pass

        def _filter_data(self):
            '''Return {column: value} for the row's columns only; DotDict values are
            converted back to plain dicts with asDict().'''
            return {
                key: (value.asDict() if type(value) is DotDict else value)
                for key, value in self.items()
                if key in self._columns
            }

        def asDict(self):
            '''Return the row's column data as a plain dict.'''
            return self._filter_data()

        def _update(self, new_data=None, **kwargs):
            '''Merge new_data and kwargs into the row (new_data wins on clashes).

            Plain dict values are stored wrapped in DotDict.
            '''
            kwargs.update(new_data or {})

            for key, value in kwargs.items():
                self[key] = DotDict(value) if type(value) == dict else value

        def delete(self):
            '''Delete this row's database record in a new session.'''
            with self._db.session() as s:
                return self.delete_session(s)

        def delete_session(self, s):
            '''Delete this row's database record using the existing session `s`.'''
            return s.remove(row=self)

        def update(self, dict_data=None, **data):
            '''Apply dict_data and keyword data (keywords win) to this row, then write
            all of its columns back in a new session. The caller's dict is not modified.
            '''
            changes = dict(dict_data or {})
            changes.update(data)
            self._update(changes)

            with self._db.session() as s:
                s.update(row=self, **self._filter_data())

        def update_session(self, s, dict_data=None, **data):
            '''Apply dict_data and keyword data (keywords win) to this row and write just
            those fields to its database record using the existing session `s`.
            '''
            changes = dict(dict_data or {})
            changes.update(data)
            self._update(changes)
            return s.update(row=self, **changes)
class Tables(DotDict):
    '''Registry mapping table name -> sqlalchemy Table, all sharing one MetaData.'''

    def __init__(self, db, tables={}):
        '''"tables" maps each table name to the list of Columns for that table.'''
        super().__init__()
        self.db = db
        self.meta = MetaData()

        for table_name, columns in tables.items():
            self.__setup_table(table_name, columns)

    def __setup_table(self, name, table):
        '''Register a Table called `name`, built from the column list `table`.'''
        self[name] = Table(name, self.meta, *table)
def Column(name, stype=None, fkey=None, **kwargs):
    '''Build a SQLAlchemy Column from a type name.

    Column('id') and Column('timestamp'), called with no other arguments, are shortcuts
    for an auto-incrementing integer primary key and a datetime column. `stype` is
    looked up case-insensitively in SqlTypes. An unknown or omitted type falls back to
    the string type. If `fkey` is given, ForeignKey(fkey) is added. The remaining
    kwargs are passed through to SQLAlchemy's Column.

    Raises ValueError if neither a type nor any options are given for any other name.
    '''
    if not stype and not kwargs:
        if name == 'id':
            return Column('id', 'integer', primary_key=True, autoincrement=True)

        if name == 'timestamp':
            return Column('timestamp', 'datetime')

        raise ValueError('Missing column type and options')

    # stype may be None when only options were passed; treat it like an unknown type
    # instead of crashing on None.lower().
    column_type = SqlTypes.get((stype or 'string').lower(), SqlTypes['string'])
    options = [name, column_type]

    if fkey:
        options.append(ForeignKey(fkey))

    return SqlColumn(*options, **kwargs)
class MissingDatabaseError(Exception):
    """Raised by DataBase when its required "database" keyword argument is missing or empty."""