database: track sessions and allow threads with sql, misc: new functions

This commit is contained in:
Izalia Mae 2021-04-14 08:12:34 -04:00
parent 3887a68db9
commit 0d18205d17
2 changed files with 80 additions and 28 deletions

View file

@ -1,25 +1,22 @@
import sys
import json, sys, threading, time
from contextlib import contextmanager
from datetime import datetime
from sqlalchemy import create_engine, ForeignKey, MetaData, Table
from sqlalchemy import Column as SqlColumn, types as Types
#from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import scoped_session, sessionmaker
from . import logging
from .cache import LRUCache
from .misc import DotDict, RandomGen, NfsCheck, PrintMethods
from .misc import DotDict, RandomGen, NfsCheck, PrintMethods, Path
SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')})
class DataBase():
def __init__(self, dbtype='postgresql+psycopg2', tables={}, **kwargs):
self.engine_string = self.__engine_string(dbtype, kwargs)
self.db = create_engine(self.engine_string)
def __init__(self, dbtype='postgresql+pg8000', tables={}, **kwargs):
self.db = self.__create_engine(dbtype, kwargs)
self.table = Tables(self, tables)
self.table_names = tables.keys()
self.classes = kwargs.get('row_classes', CustomRows())
@ -32,7 +29,10 @@ class DataBase():
self.SetupCache()
def __engine_string(self, dbtype, kwargs):
def __create_engine(self, dbtype, kwargs):
engine_args = []
engine_kwargs = {}
if not kwargs.get('database'):
raise MissingDatabaseError('Database not set')
@ -43,6 +43,7 @@ class DataBase():
logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
engine_string += '/' + kwargs.get('database')
engine_kwargs['connect_args'] = {'check_same_thread': False}
else:
user = kwargs.get('user')
@ -64,7 +65,7 @@ class DataBase():
else:
engine_string += f'{host}:{port}/{name}'
return engine_string
return create_engine(engine_string, *engine_args, **engine_kwargs)
def close(self):
@ -110,7 +111,7 @@ class Session(object):
def __init__(self, db, trans=True):
self.db = db
self.classes = self.db.classes
self.session = scoped_session(sessionmaker(bind=db.db))()
self.session = sessionmaker(bind=db.db)()
self.table = self.db.table
self.cache = self.db.cache
self.trans = trans
@ -126,9 +127,6 @@ class Session(object):
self._setup()
if not self.trans:
self.commit()
def __enter__(self):
self.sessionid = RandomGen(10)
@ -149,6 +147,11 @@ class Session(object):
pass
@property
def dirty(self):
return any([self.s.new, self.s.dirty, self.s.deleted])
def count(self, table_name, **kwargs):
return self.query(self.table[table_name]).filter_by(**kwargs).count()
@ -178,6 +181,11 @@ class Session(object):
return [RowClass(table_name, row, self) for row in rows]
def search(self, *args, **kwargs):
kwargs.pop('single', None)
return self.fetch(*args, single=False, **kwargs)
def insert(self, table_name, **kwargs):
row = self.fetch(table_name, **kwargs)
@ -203,10 +211,8 @@ class Session(object):
raise ValueError('Missing row ID or table')
tclass = self.table[table]
self.execute(tclass.update().where(tclass.c.id == rowid).values(**data))
def remove(self, table=None, rowid=None, row=None):
if row:
rowid = row.id
@ -348,7 +354,10 @@ class CustomRows(object):
def update_session(self, s, dict_data={}, **data):
return s.update(row=self, **dict_data, **data)
dict_data.update(data)
self._update(dict_data)
print(dict_data)
return s.update(row=self, **dict_data)
class Tables(DotDict):
@ -387,4 +396,4 @@ def Column(name, stype=None, fkey=None, **kwargs):
class MissingDatabaseError(Exception):
'''raise when the "database" kwargs is not set'''
'''raise when the "database" kwarg is not set'''

View file

@ -153,6 +153,16 @@ def NfsCheck(path):
return False
def PortCheck(port, address='127.0.0.1', tcp=True):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM) as s:
try:
return not s.connect_ex((address, port)) == 0
except socket.error as e:
print(e)
return False
def PrintMethods(object, include_underscore=False):
for line in dir(object):
if line.startswith('_'):
@ -162,6 +172,31 @@ def PrintMethods(object, include_underscore=False):
else:
print(line)
class Connection(socket.socket):
def __init__(self, address='127.0.0.1', port=8080, tcp=True):
super().__init__(socket.AF_INET, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
self.address = address
self.port = port
def __enter__(self):
self.connect((self.address, self.port))
return self
def __exit__(self, exctype, value, tb):
self.close()
def send(self, msg):
self.sendall(msg)
def recieve(self, size=8192):
return self.recv(size)
class DotDict(dict):
def __init__(self, value=None, **kwargs):
'''Python dictionary, but variables can be set/get via attributes
@ -171,7 +206,6 @@ class DotDict(dict):
kwargs: key/value pairs to set on init. Overrides identical keys set by 'value'
'''
super().__init__()
data = {}
if isinstance(value, (str, bytes)):
self.fromJson(value)
@ -265,7 +299,7 @@ class DotDict(dict):
def toJson(self, indent=None, **kwargs):
if 'cls' not in kwargs:
kwargs['cls'] = DotDictEncoder
kwargs['cls'] = JsonEncoder
return json.dumps(dict(self), indent=indent, **kwargs)
@ -275,6 +309,15 @@ class DotDict(dict):
self.update(data)
def load_json(self, path: str=None):
self.update(Path(path).load_json())
def save_json(self, path: str, **kwargs):
with Path(path).open(w) as fd:
write(self.toJson(*kwargs))
class DefaultDict(DotDict):
def __getattr__(self, key):
try:
@ -325,7 +368,7 @@ class Path(object):
if str(path).startswith('~'):
self.__path = self.__path.expanduser()
self.json = DotDict({})
self.json = DotDict()
self.exist = exist
self.missing = missing
self.parents = parents
@ -475,24 +518,24 @@ class Path(object):
return True if self.__path.exists() else False
def loadJson(self):
def load_json(self):
self.json = DotDict(self.read())
return self.json
def updateJson(self, data={}):
def save_json(self, indent=None):
with self.__path.open('w') as fp:
fp.write(json.dumps(self.json.asDict(), indent=indent, cls=JsonEncoder))
def update_json(self, data={}):
if type(data) == str:
data = json.loads(data)
self.json.update(data)
def storeJson(self, indent=None):
with self.__path.open('w') as fp:
fp.write(json.dumps(self.json.asDict(), indent=indent))
# This needs to be extended to handle dirs with files/sub-dirs
def delete(self):
if self.isdir():
@ -516,7 +559,7 @@ class Path(object):
return self.open().readlines()
class DotDictEncoder(json.JSONEncoder):
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
if type(obj) not in [str, int, float, dict]:
return str(obj)