387 lines
8.8 KiB
Python
387 lines
8.8 KiB
Python
import sys
|
|
|
|
from contextlib import contextmanager
|
|
from datetime import datetime
|
|
from sqlalchemy import create_engine, ForeignKey, MetaData, Table
|
|
from sqlalchemy import Column as SqlColumn, types as Types
|
|
#from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.exc import OperationalError, ProgrammingError
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from . import logging
|
|
from .cache import LRUCache
|
|
from .misc import DotDict, RandomGen, NfsCheck, PrintMethods
|
|
|
|
|
|
# Every public SQLAlchemy type keyed by its lowercase name, e.g. SqlTypes.integer
SqlTypes = DotDict({
    name.lower(): getattr(Types, name)
    for name in dir(Types)
    if not name.startswith('_')
})
|
|
|
|
|
|
class DataBase():
    '''Wrapper around an SQLAlchemy engine that manages the table definitions,
    per-table row caches, and short-lived Session objects.'''

    def __init__(self, dbtype='postgresql+psycopg2', tables=None, **kwargs):
        '''
        dbtype: SQLAlchemy dialect+driver string ('sqlite' for file databases)
        tables: dict mapping table names to lists of Columns
        kwargs: connection options ('database' is required), plus optional
                'row_classes' (CustomRows instance) and 'session_class'

        Raises MissingDatabaseError if 'database' is not given.
        '''
        tables = tables if tables is not None else {}  # avoid the shared mutable default

        self.engine_string = self.__engine_string(dbtype, kwargs)
        self.db = create_engine(self.engine_string)
        self.table = Tables(self, tables)
        self.table_names = list(tables.keys())  # snapshot, not a live view
        self.classes = kwargs.get('row_classes', CustomRows())
        self.cache = None

        session_class = kwargs.get('session_class', Session)
        self.session = lambda trans=True: session_class(self, trans)

        self.SetupCache()

    def __engine_string(self, dbtype, kwargs):
        '''Build the SQLAlchemy connection URL from the keyword options.

        Raises MissingDatabaseError if 'database' is not set.
        '''
        if not kwargs.get('database'):
            raise MissingDatabaseError('Database not set')

        engine_string = dbtype + '://'

        if dbtype == 'sqlite':
            # SQLite locking does not work on NFS mounts; warn loudly but continue
            if NfsCheck(kwargs.get('database')):
                logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')

            engine_string += '/' + kwargs.get('database')

        else:
            user = kwargs.get('user')
            password = kwargs.get('pass')
            host = kwargs.get('host', '/var/run/postgresql')
            port = kwargs.get('port', 5432)
            name = kwargs.get('name', 'postgres')

            if user:
                if password:
                    engine_string += f'{user}:{password}@'
                else:
                    engine_string += user + '@'

            # the default host means a local unix socket, so no host:port part
            if host == '/var/run/postgresql':
                engine_string += '/' + name
            else:
                engine_string += f'{host}:{port}/{name}'

        return engine_string

    def close(self):
        '''Drop all cached rows. (Does not dispose of the engine.)'''
        self.SetupCache()

    def SetupCache(self):
        '''Create a fresh LRU cache for every configured table.'''
        self.cache = DotDict({table: LRUCache() for table in self.table_names})

    def CreateTables(self, *tables):
        '''Create only the named tables in the database.'''
        new_tables = [self.table[table] for table in tables]
        self.table.meta.create_all(bind=self.db, tables=new_tables)

    def CreateDatabase(self):
        '''Create the database itself (postgresql only), then create all tables.'''
        if self.engine_string.startswith('postgresql'):
            # Fix: the original referenced undefined globals `db` and `config`.
            # The database name is the path component after the final '/'.
            base, _, name = self.engine_string.rpartition('/')
            predb = create_engine(base + '/postgres')
            conn = predb.connect()
            conn.execute('commit')  # CREATE DATABASE cannot run inside a transaction

            try:
                conn.execute(f'CREATE DATABASE {name}')

            except ProgrammingError:
                # The database already exists, so just move along
                pass

            except Exception as e:
                conn.close()
                raise e from None

            conn.close()

        self.table.meta.create_all(self.db)

    def execute(self, *args, **kwargs):
        '''Run a statement inside a short-lived session and return its result.'''
        with self.session() as s:
            return s.execute(*args, **kwargs)
|
|
|
|
|
|
class Session(object):
    '''A single unit of work against the database, usable as a context manager.

    On clean exit of the `with` block the transaction is committed; if an
    exception escapes, it is rolled back instead.
    '''

    def __init__(self, db, trans=True):
        '''
        db: the owning DataBase instance
        trans: when False, commit immediately after setup instead of holding
               a transaction open
        '''
        self.db = db
        self.classes = self.db.classes
        self.session = sessionmaker(bind=db.db)()
        self.table = self.db.table
        self.cache = self.db.cache
        self.trans = trans

        # session aliases
        self.s = self.session
        self.begin = self.s.begin
        self.commit = self.s.commit
        self.rollback = self.s.rollback
        self.query = self.s.query
        self.execute = self.s.execute

        self._setup()

        if not self.trans:
            self.commit()

    def __enter__(self):
        self.sessionid = RandomGen(10)
        return self

    def __exit__(self, exctype, value, tb):
        # roll back when an exception escaped the with-block, otherwise commit
        if tb:
            self.rollback()
        else:
            self.commit()

    def _setup(self):
        '''Hook for subclasses; runs at the end of __init__.'''
        pass

    def count(self, table_name, **kwargs):
        '''Return the number of rows in table_name matching the keyword filters.'''
        return self.query(self.table[table_name]).filter_by(**kwargs).count()

    def fetch(self, table_name, single=True, orderby=None, orderdir='asc', **kwargs):
        '''Fetch rows from table_name matching the keyword filters.

        single: return only the first row (or None) instead of a list
        orderby: column name to sort on, if any
        orderdir: 'asc' or 'desc'

        Raises ValueError for any other order direction.
        '''
        table = self.table[table_name]
        RowClass = self.classes.get(table_name.capitalize())

        query = self.query(table).filter_by(**kwargs)

        if not orderby:
            rows = query.all()

        elif orderdir == 'asc':
            rows = query.order_by(getattr(table.c, orderby).asc()).all()

        elif orderdir == 'desc':
            # Fix: the original called .asc() here too, silently ignoring 'desc'
            rows = query.order_by(getattr(table.c, orderby).desc()).all()

        else:
            raise ValueError(f'Unsupported order direction: {orderdir}')

        if single:
            return RowClass(table_name, rows[0], self) if len(rows) > 0 else None

        return [RowClass(table_name, row, self) for row in rows]

    def insert(self, table_name, **kwargs):
        '''Insert a row, or update the existing row matching all of kwargs.'''
        row = self.fetch(table_name, **kwargs)

        if row:
            row.update_session(self, **kwargs)
            return

        table = self.table[table_name]

        # Fix: Table exposes columns via .c, so the original
        # getattr(table, 'timestamp', None) was always None and the
        # timestamp default never fired
        if 'timestamp' in table.c and not kwargs.get('timestamp'):
            kwargs['timestamp'] = datetime.now()

        self.execute(table.insert().values(**kwargs))

    def update(self, table=None, rowid=None, row=None, **data):
        '''Update columns of a row identified by (table, rowid) or by row.

        Raises ValueError when neither identification is complete.
        '''
        if row:
            rowid = row.id
            table = row._table_name

        if not rowid or not table:
            raise ValueError('Missing row ID or table')

        tclass = self.table[table]
        self.execute(tclass.update().where(tclass.c.id == rowid).values(**data))

    def remove(self, table=None, rowid=None, row=None):
        '''Delete a row identified by (table, rowid) or by row.

        Raises ValueError when neither identification is complete.
        '''
        if row:
            rowid = row.id
            table = row._table_name

        if not rowid or not table:
            raise ValueError('Missing row ID or table')

        # Fix: use the expression API (as update() does) instead of
        # interpolating the id into a raw DELETE statement
        tclass = self.table[table]
        self.execute(tclass.delete().where(tclass.c.id == rowid))

    def DropTables(self):
        '''Drop every table in the database (sqlite only; relies on GetTables).'''
        for table in self.GetTables():
            self.execute(f'DROP TABLE {table}')

    def GetTables(self):
        '''Return the names of all tables and views (sqlite only; reads sqlite_master).'''
        rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
        return [row[0] for row in rows]

    def AppendColumn(self, tbl, col):
        '''Add column col (from the schema definition) to the live table (sqlite only).'''
        table = self.table[tbl]

        try:
            column = getattr(table.c, col)

        except AttributeError:
            logging.error(f'Table "{tbl}" does not have column "{col}"')
            return

        columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]

        if col in columns:
            logging.info(f'Column "{col}" already exists')
            return

        sql = f'ALTER TABLE {tbl} ADD COLUMN {col} {column.type}'

        if not column.nullable:
            sql += ' NOT NULL'

        if column.primary_key:
            sql += ' PRIMARY KEY'

        if column.unique:
            sql += ' UNIQUE'

        self.execute(sql)

    def RemoveColumn(self, tbl, col):
        '''Drop column col by rebuilding the table without it (sqlite only).'''
        columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]

        if col not in columns:
            # Fix: the original logged 'already exists' for a *missing* column
            logging.info(f'Column "{col}" does not exist')
            return

        columns.remove(col)
        coltext = ', '.join(columns)

        # sqlite cannot drop a column directly, so copy, drop, and rename
        self.execute(f'CREATE TABLE {tbl}_temp AS SELECT {coltext} FROM {tbl}')
        self.execute(f'DROP TABLE {tbl}')
        self.execute(f'ALTER TABLE {tbl}_temp RENAME TO {tbl}')
|
|
|
|
|
|
class CustomRows(object):
    '''Registry of per-table row classes.

    Subclass and add attributes named after the capitalized table name to
    provide custom Row behavior; unknown names fall back to the generic Row.
    '''

    def get(self, name):
        '''Return the row class registered under name, or the generic Row.'''
        return getattr(self, name, self.Row)

    class Row(DotDict):
        #_filter_columns = lambda self, row: [attr for attr in dir(row) if not attr.startswith('_') and attr != 'metadata']

        def __init__(self, table, row, session):
            '''
            table: the table name this row came from
            row: a SQLAlchemy result row (or falsy for an empty result)
            session: the Session that fetched the row
            '''
            # an empty result leaves the object unpopulated
            if not row:
                return

            super().__init__()
            self._update(row._asdict())

            # Fix: snapshot the column names before the bookkeeping attributes
            # are set; the original stored a live keys() view, which made the
            # filter in _filter_data a no-op
            self._columns = list(self.keys())

            self._db = session.db
            self._table_name = table

            self.__run__(session)

        ## Subclass Row and redefine this function
        def __run__(self, s):
            pass

        def _filter_data(self):
            '''Return only the column data, converting DotDict values to plain dicts.'''
            data = {}

            for k, v in self.items():
                if k in self._columns:
                    data[k] = v.asDict() if type(v) is DotDict else v

            return data

        def asDict(self):
            '''Return this row's column data as a plain dict.'''
            return self._filter_data()

        def _update(self, new_data=None, **kwargs):
            '''Merge new_data and kwargs into this row without touching the database.'''
            if new_data:
                kwargs.update(new_data)

            for k, v in kwargs.items():
                # Fix: the original converted dicts to DotDict and then
                # immediately overwrote the result with the raw dict
                self[k] = DotDict(v) if type(v) is dict else v

        def delete(self):
            '''Delete this row using a fresh session.'''
            with self._db.session() as s:
                return self.delete_session(s)

        def delete_session(self, s):
            '''Delete this row using the given session.'''
            return s.remove(row=self)

        def update(self, dict_data=None, **data):
            '''Update this row locally and persist the change in a fresh session.'''
            # Fix: copy instead of mutating a shared default dict
            merged = dict(dict_data) if dict_data else {}
            merged.update(data)
            self._update(merged)

            with self._db.session() as s:
                s.update(row=self, **self._filter_data())

        def update_session(self, s, dict_data=None, **data):
            '''Persist the given column changes using an existing session.'''
            return s.update(row=self, **(dict_data or {}), **data)
|
|
|
|
|
|
class Tables(DotDict):
    '''Maps table names to SQLAlchemy Table objects sharing one MetaData.'''

    def __init__(self, db, tables=None):
        '''"tables" should be a dict with the table names for keys and a list of Columns for values'''
        super().__init__()

        self.db = db
        self.meta = MetaData()

        # `tables or {}` avoids the shared mutable default of the original
        for name, columns in (tables or {}).items():
            self.__setup_table(name, columns)

    def __setup_table(self, name, columns):
        # Table() consumes the Column objects, so a column list can only
        # ever be used for a single table
        self[name] = Table(name, self.meta, *columns)
|
|
|
|
|
|
def Column(name, stype=None, fkey=None, **kwargs):
    '''Convenience constructor for SQLAlchemy columns.

    name: column name; bare 'id' and 'timestamp' expand to sensible defaults
    stype: type name looked up in SqlTypes (falls back to string)
    fkey: optional foreign key target passed to ForeignKey
    kwargs: forwarded to sqlalchemy.Column

    Raises ValueError when neither a type nor any options are given for a
    name without a default expansion.
    '''
    if not stype and not kwargs:
        if name == 'id':
            return Column('id', 'integer', primary_key=True, autoincrement=True)

        elif name == 'timestamp':
            return Column('timestamp', 'datetime')

        raise ValueError('Missing column type and options')

    else:
        # Fix: the original crashed with AttributeError when options were
        # given without a type (e.g. Column('flag', nullable=False));
        # default to the string type in that case
        if stype:
            ctype = SqlTypes.get(stype.lower(), SqlTypes['string'])
        else:
            ctype = SqlTypes['string']

        options = [name, ctype]

        if fkey:
            options.append(ForeignKey(fkey))

        return SqlColumn(*options, **kwargs)
|
|
|
|
|
|
class MissingDatabaseError(Exception):
    '''Raised when the required "database" keyword argument is not provided.'''