database: track sessions and allow threads with sql, misc: new functions

This commit is contained in:
Izalia Mae 2021-04-14 08:12:34 -04:00
parent 3887a68db9
commit 0d18205d17
2 changed files with 80 additions and 28 deletions

View file

@ -1,25 +1,22 @@
import sys import json, sys, threading, time
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from sqlalchemy import create_engine, ForeignKey, MetaData, Table from sqlalchemy import create_engine, ForeignKey, MetaData, Table
from sqlalchemy import Column as SqlColumn, types as Types from sqlalchemy import Column as SqlColumn, types as Types
#from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from . import logging from . import logging
from .cache import LRUCache from .cache import LRUCache
from .misc import DotDict, RandomGen, NfsCheck, PrintMethods from .misc import DotDict, RandomGen, NfsCheck, PrintMethods, Path
SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')}) SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')})
class DataBase(): class DataBase():
def __init__(self, dbtype='postgresql+psycopg2', tables={}, **kwargs): def __init__(self, dbtype='postgresql+pg8000', tables={}, **kwargs):
self.engine_string = self.__engine_string(dbtype, kwargs) self.db = self.__create_engine(dbtype, kwargs)
self.db = create_engine(self.engine_string)
self.table = Tables(self, tables) self.table = Tables(self, tables)
self.table_names = tables.keys() self.table_names = tables.keys()
self.classes = kwargs.get('row_classes', CustomRows()) self.classes = kwargs.get('row_classes', CustomRows())
@ -32,7 +29,10 @@ class DataBase():
self.SetupCache() self.SetupCache()
def __engine_string(self, dbtype, kwargs): def __create_engine(self, dbtype, kwargs):
engine_args = []
engine_kwargs = {}
if not kwargs.get('database'): if not kwargs.get('database'):
raise MissingDatabaseError('Database not set') raise MissingDatabaseError('Database not set')
@ -43,6 +43,7 @@ class DataBase():
logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail') logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
engine_string += '/' + kwargs.get('database') engine_string += '/' + kwargs.get('database')
engine_kwargs['connect_args'] = {'check_same_thread': False}
else: else:
user = kwargs.get('user') user = kwargs.get('user')
@ -64,7 +65,7 @@ class DataBase():
else: else:
engine_string += f'{host}:{port}/{name}' engine_string += f'{host}:{port}/{name}'
return engine_string return create_engine(engine_string, *engine_args, **engine_kwargs)
def close(self): def close(self):
@ -110,7 +111,7 @@ class Session(object):
def __init__(self, db, trans=True): def __init__(self, db, trans=True):
self.db = db self.db = db
self.classes = self.db.classes self.classes = self.db.classes
self.session = scoped_session(sessionmaker(bind=db.db))() self.session = sessionmaker(bind=db.db)()
self.table = self.db.table self.table = self.db.table
self.cache = self.db.cache self.cache = self.db.cache
self.trans = trans self.trans = trans
@ -126,9 +127,6 @@ class Session(object):
self._setup() self._setup()
if not self.trans:
self.commit()
def __enter__(self): def __enter__(self):
self.sessionid = RandomGen(10) self.sessionid = RandomGen(10)
@ -149,6 +147,11 @@ class Session(object):
pass pass
@property
def dirty(self):
return any([self.s.new, self.s.dirty, self.s.deleted])
def count(self, table_name, **kwargs): def count(self, table_name, **kwargs):
return self.query(self.table[table_name]).filter_by(**kwargs).count() return self.query(self.table[table_name]).filter_by(**kwargs).count()
@ -178,6 +181,11 @@ class Session(object):
return [RowClass(table_name, row, self) for row in rows] return [RowClass(table_name, row, self) for row in rows]
def search(self, *args, **kwargs):
kwargs.pop('single', None)
return self.fetch(*args, single=False, **kwargs)
def insert(self, table_name, **kwargs): def insert(self, table_name, **kwargs):
row = self.fetch(table_name, **kwargs) row = self.fetch(table_name, **kwargs)
@ -203,10 +211,8 @@ class Session(object):
raise ValueError('Missing row ID or table') raise ValueError('Missing row ID or table')
tclass = self.table[table] tclass = self.table[table]
self.execute(tclass.update().where(tclass.c.id == rowid).values(**data)) self.execute(tclass.update().where(tclass.c.id == rowid).values(**data))
def remove(self, table=None, rowid=None, row=None): def remove(self, table=None, rowid=None, row=None):
if row: if row:
rowid = row.id rowid = row.id
@ -348,7 +354,10 @@ class CustomRows(object):
def update_session(self, s, dict_data={}, **data): def update_session(self, s, dict_data={}, **data):
return s.update(row=self, **dict_data, **data) dict_data.update(data)
self._update(dict_data)
print(dict_data)
return s.update(row=self, **dict_data)
class Tables(DotDict): class Tables(DotDict):
@ -387,4 +396,4 @@ def Column(name, stype=None, fkey=None, **kwargs):
class MissingDatabaseError(Exception): class MissingDatabaseError(Exception):
'''raise when the "database" kwargs is not set''' '''raise when the "database" kwarg is not set'''

View file

@ -153,6 +153,16 @@ def NfsCheck(path):
return False return False
def PortCheck(port, address='127.0.0.1', tcp=True):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM) as s:
try:
return not s.connect_ex((address, port)) == 0
except socket.error as e:
print(e)
return False
def PrintMethods(object, include_underscore=False): def PrintMethods(object, include_underscore=False):
for line in dir(object): for line in dir(object):
if line.startswith('_'): if line.startswith('_'):
@ -162,6 +172,31 @@ def PrintMethods(object, include_underscore=False):
else: else:
print(line) print(line)
class Connection(socket.socket):
def __init__(self, address='127.0.0.1', port=8080, tcp=True):
super().__init__(socket.AF_INET, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
self.address = address
self.port = port
def __enter__(self):
self.connect((self.address, self.port))
return self
def __exit__(self, exctype, value, tb):
self.close()
def send(self, msg):
self.sendall(msg)
def recieve(self, size=8192):
return self.recv(size)
class DotDict(dict): class DotDict(dict):
def __init__(self, value=None, **kwargs): def __init__(self, value=None, **kwargs):
'''Python dictionary, but variables can be set/get via attributes '''Python dictionary, but variables can be set/get via attributes
@ -171,7 +206,6 @@ class DotDict(dict):
kwargs: key/value pairs to set on init. Overrides identical keys set by 'value' kwargs: key/value pairs to set on init. Overrides identical keys set by 'value'
''' '''
super().__init__() super().__init__()
data = {}
if isinstance(value, (str, bytes)): if isinstance(value, (str, bytes)):
self.fromJson(value) self.fromJson(value)
@ -265,7 +299,7 @@ class DotDict(dict):
def toJson(self, indent=None, **kwargs): def toJson(self, indent=None, **kwargs):
if 'cls' not in kwargs: if 'cls' not in kwargs:
kwargs['cls'] = DotDictEncoder kwargs['cls'] = JsonEncoder
return json.dumps(dict(self), indent=indent, **kwargs) return json.dumps(dict(self), indent=indent, **kwargs)
@ -275,6 +309,15 @@ class DotDict(dict):
self.update(data) self.update(data)
def load_json(self, path: str=None):
self.update(Path(path).load_json())
def save_json(self, path: str, **kwargs):
with Path(path).open(w) as fd:
write(self.toJson(*kwargs))
class DefaultDict(DotDict): class DefaultDict(DotDict):
def __getattr__(self, key): def __getattr__(self, key):
try: try:
@ -325,7 +368,7 @@ class Path(object):
if str(path).startswith('~'): if str(path).startswith('~'):
self.__path = self.__path.expanduser() self.__path = self.__path.expanduser()
self.json = DotDict({}) self.json = DotDict()
self.exist = exist self.exist = exist
self.missing = missing self.missing = missing
self.parents = parents self.parents = parents
@ -475,24 +518,24 @@ class Path(object):
return True if self.__path.exists() else False return True if self.__path.exists() else False
def loadJson(self): def load_json(self):
self.json = DotDict(self.read()) self.json = DotDict(self.read())
return self.json return self.json
def updateJson(self, data={}): def save_json(self, indent=None):
with self.__path.open('w') as fp:
fp.write(json.dumps(self.json.asDict(), indent=indent, cls=JsonEncoder))
def update_json(self, data={}):
if type(data) == str: if type(data) == str:
data = json.loads(data) data = json.loads(data)
self.json.update(data) self.json.update(data)
def storeJson(self, indent=None):
with self.__path.open('w') as fp:
fp.write(json.dumps(self.json.asDict(), indent=indent))
# This needs to be extended to handle dirs with files/sub-dirs # This needs to be extended to handle dirs with files/sub-dirs
def delete(self): def delete(self):
if self.isdir(): if self.isdir():
@ -516,7 +559,7 @@ class Path(object):
return self.open().readlines() return self.open().readlines()
class DotDictEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if type(obj) not in [str, int, float, dict]: if type(obj) not in [str, int, float, dict]:
return str(obj) return str(obj)