http_server_async: add transport class, sql2: new sub-module

This commit is contained in:
Izalia Mae 2021-12-11 18:46:12 -05:00
parent 3e56885a49
commit 1e62836754
19 changed files with 1639 additions and 54 deletions

View file

@ -1,4 +1,4 @@
import json, mimetypes
import json, mimetypes, traceback
from datetime import datetime, timezone
from functools import partial
@ -30,26 +30,45 @@ url_keys = [
]
def parse_privacy_level(to: list=[], cc: list=[]):
if to == [pubstr] and len(cc) == 1:
def parse_privacy_level(to: list=[], cc: list=[], followers=None):
if pubstr in to and followers in cc:
return 'public'
elif to and self.actor in to[0] and not cc:
elif followers in to and pubstr in cc:
return 'unlisted'
elif pubstr not in to:
elif pubstr not in to and pubstr not in cc and followers in cc:
return 'private'
elif not tuple(item for item in [*to, *cc] if item in [pubstr, followers]):
return 'direct'
else:
logging.warning('Not sure what this privacy level is')
logging.debug(f'to: {json.dumps(to)}')
logging.debug(f'cc: {json.dumps(cc)}')
logging.debug(f'followers: {followers}')
def generate_privacy_fields(privacy='public'):
def generate_privacy_fields(privacy='public', followers=None, to=[], cc=[]):
if privacy == 'public':
return ([pubstr])
to = [pubstr, *to]
cc = [followers, *cc]
elif privacy == 'unlisted':
to = [followers, *to]
cc = [pubstr, *cc]
elif privacy == 'private':
cc = [followers, *cc]
elif privacy == 'direct':
pass
else:
raise ValueError(f'Unknown privacy level: {privacy}')
return to, cc
class Object(DotDict):
def __setitem__(self, key, value):
@ -384,7 +403,11 @@ class Object(DotDict):
@property
def privacy_level(self):
return parse_privacy_level(self.get('to', []), self.get('cc', []))
return parse_privacy_level(
self.get('to', []),
self.get('cc', []),
self.get('attributedTo', '') + '/followers'
)
@property
@ -416,7 +439,7 @@ class Object(DotDict):
@property
def info_table(self):
return DotDict({p.name: p.value for p in self.get('attachment', {})})
return DotDict({p['name']: p['value'] for p in self.get('attachment', {})})
@property
@ -429,6 +452,16 @@ class Object(DotDict):
return self.get('summary')
@property
def avatar(self):
return self.icon.url
@property
def header(self):
return self.image.url
class Collection(Object):
@classmethod
def new_replies(cls, statusid):
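
A hedged round-trip sketch of the new recipient helpers above; the followers URL is made up, pubstr is the public-addressing constant these helpers already use, and the import path for them is not shown in this diff.

followers = 'https://example.com/users/izzy/followers'   # hypothetical followers collection

to, cc = generate_privacy_fields('public', followers=followers)
parse_privacy_level(to, cc, followers)      # -> 'public'

to, cc = generate_privacy_fields('unlisted', followers=followers)
parse_privacy_level(to, cc, followers)      # -> 'unlisted'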

View file

@ -24,3 +24,23 @@ class MethodNotHandledException(Exception):
class NoBlueprintForPath(Exception):
'Raise when no blueprint is found for a specific path'
class NoConnectionError(Exception):
'Raise when a function requiring a connection gets called when there is no connection'
class MaxConnectionsError(Exception):
'Raise when the maximum number of connections has been reached'
class NoTransactionError(Exception):
'Raise when trying to execute an SQL write statement outside a transaction'
class NoTableLayoutError(Exception):
'Raise when a table layout is necessary, but not loaded'
class UpdateAllRowsError(Exception):
'Raise when an UPDATE tries to modify all rows in a table'
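
A hedged sketch of where the new sql2 exceptions surface, based on the code later in this commit; the `db` handle and the table name are illustrative.

from izzylib.exceptions import MaxConnectionsError, NoTransactionError

try:
    conn = db.get_connection()
except MaxConnectionsError:
    # raised by Database.new_connection() once cfg.maxconnections is reached
    conn = None

try:
    with db.session as s:
        s.execute('DELETE FROM token WHERE expired = ?', True)
except NoTransactionError:
    # raised by Session.execute() for write statements when auto_trans is disabled
    pass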

View file

@ -17,7 +17,6 @@ def create_app(appname, **kwargs):
from .application import Application, Blueprint
from .middleware import MediaCacheControl
from .misc import Cookies, Headers
from .request import Request
from .response import Response
from .view import View, Static

View file

@ -8,6 +8,7 @@ from .config import Config
from .response import Response
#from .router import Router
from .view import Static, Manifest, Robots, Style
from .transport import Transport
from .. import logging
from ..dotdict import DotDict
@ -16,7 +17,7 @@ from ..misc import signal_handler
from ..path import Path
try:
from ..sql import Database
from ..sql2 import Database
except ImportError:
Database = NotImplementedError('Failed to import SQL database class')
@ -52,7 +53,7 @@ class ApplicationBase:
if isinstance(Database, Exception):
raise Database from None
self.db = dbclass(dbtype, **dbargs)
self.db = dbclass(type=dbtype, **dbargs, app=self)
def __getitem__(self, key):
@ -320,13 +321,20 @@ class Application(ApplicationBase):
async def handle_client(self, reader, writer):
transport = Transport(self, reader, writer)
request = None
response = None
try:
request = self.cfg.request_class(self, reader, writer.get_extra_info('peername')[0])
request = self.cfg.request_class(self, transport)
response = self.cfg.response_class(request=request)
await request.parse_headers()
try:
await request.parse_headers()
except asyncio.exceptions.IncompleteReadError as e:
request = None
raise e from None
try:
# this doesn't work all the time for some reason
@ -336,8 +344,8 @@ class Application(ApplicationBase):
except NoBlueprintForPath:
response = await self.handle_request(request, response)
except Exception as e:
traceback.print_exc()
#except Exception as e:
#traceback.print_exc()
except NotFound:
response = self.cfg.response_class(request=request).set_error('Not Found', 404)
@ -357,22 +365,22 @@ class Application(ApplicationBase):
except:
traceback.print_exc()
## Don't use a custom response class here just in case it caused the error
response = Response(request=request).set_error('Server Error', 500)
if not response.streaming:
## Don't use a custom response class here just in case it caused the error
response = Response(request=request).set_error('Server Error', 500)
try:
response.headers.update(self.cfg.default_headers)
writer.write(response.compile())
await writer.drain()
if not response.streaming:
try:
response.headers.update(self.cfg.default_headers)
await transport.write(response.compile())
if request and not request.path.startswith('/framework'):
logging.info(f'{request.remote} {request.method} {request.path} {response.status} {len(response.body)} {request.agent}')
if request and request.log and not request.path.startswith('/framework'):
logging.info(f'{request.remote} {request.method} {request.path} {response.status} {len(response.body)} {request.agent}')
except:
traceback.print_exc()
except:
traceback.print_exc()
writer.close()
await writer.wait_closed()
await transport.close()
class Blueprint(ApplicationBase):

View file

@ -17,28 +17,28 @@ LocalTime = datetime.now(UtcTime).astimezone().tzinfo
class Request:
__slots__ = [
'_body', '_form', '_reader', '_method', '_app', '_params',
'_body', '_form', '_method', '_app', '_params',
'address', 'path', 'version', 'headers', 'cookies',
'query', 'raw_query'
'query', 'raw_query', 'transport', 'log'
]
ctx = DotDict()
def __init__(self, app, reader, address):
def __init__(self, app, transport):
super().__init__()
self._app = app
self._reader = reader
self._body = b''
self._form = DotDict()
self._method = None
self._params = None
self.transport = transport
self.headers = Headers()
self.cookies = Cookies()
self.query = DotDict()
self.address = address
self.address = transport.client_address
self.path = None
self.version = None
self.raw_query = None
@ -141,14 +141,9 @@ class Request:
return self._params
async def read(self, length=2048, timeout=None):
try: return await asyncio.wait_for(self._reader.read(length), timeout or self.app.cfg.timeout)
except: return
async def body(self):
if not self._body and self.length:
self._body = await self.read(self.length)
self._body = await self.transport.read(self.length)
return self._body
@ -178,7 +173,7 @@ class Request:
async def parse_headers(self):
data = (await self._reader.readuntil(b'\r\n\r\n')).decode('utf-8')
data = (await self.transport.readuntil(b'\r\n\r\n')).decode('utf-8')
for idx, line in enumerate(data.splitlines()):
if idx == 0:
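
In short, the request no longer owns the raw asyncio reader. A hedged sketch of the new wiring, with `app`, `reader` and `writer` standing in for whatever Application.handle_client receives:

transport = Transport(app, reader, writer)          # wraps both asyncio streams once
request = app.cfg.request_class(app, transport)     # replaces Request(app, reader, address)
await request.parse_headers()                       # readuntil(b'\r\n\r\n') now goes through the transport
body = await request.body()                         # body() reads via transport.read(self.length)
print(request.address)                              # still the peer IP, taken from transport.client_address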

View file

@ -64,6 +64,11 @@ class Response:
return len(self.body)
@property
def streaming(self):
return self.headers.getone('Transfer-Encoding') == 'chunked'
def append(self, data):
self._body += self._parse_body_data(data)
@ -137,7 +142,7 @@ class Response:
return self
def set_json(self, body={}, status=None, activity=False,):
def set_json(self, body={}, status=None, activity=False):
self.content_type = 'application/activity+json' if activity else 'application/json'
self.body = body
@ -177,11 +182,23 @@ class Response:
return self
async def set_streaming(self, transport, headers={}):
self.headers.update(headers)
self.headers.update(transport.app.cfg.default_headers)
self.headers.setall('Transfer-Encoding', 'chunked')
await transport.write(self._compile_headers())
def set_cookie(self, key, value, **kwargs):
self.cookies[key] = CookieItem(key, value, **kwargs)
def compile(self):
return self._compile_headers() + self.body
def _compile_headers(self):
data = bytes(f'HTTP/1.1 {self.status}', 'utf-8')
for k,v in self.headers.items():
@ -198,9 +215,7 @@ class Response:
data += bytes(f'\r\nDate: {datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")}', 'utf-8')
data += bytes(f'\r\nContent-Length: {len(self.body)}', 'utf-8')
data += b'\r\n\r\n'
data += self.body
return data
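
A hedged sketch of the new streaming path; only Response.streaming and Response.set_streaming(transport, headers) come from this diff, the handler context and header values are assumptions.

# Illustrative only: `response` is a Response and `transport` the client's Transport.
await response.set_streaming(transport, headers={'Content-Type': 'text/plain'})
# set_streaming() merges the app's default headers, marks the transfer as chunked and
# writes the header block immediately; handle_client then skips the normal
# response.compile() write because Response.streaming is True.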

View file

@ -0,0 +1,65 @@
import asyncio, json
from ..dotdict import DotDict
class Transport:
def __init__(self, app, reader, writer):
self.app = app
self.reader = reader
self.writer = writer
@property
def client_address(self):
return self.writer.get_extra_info('peername')[0]
@property
def client_port(self):
return self.writer.get_extra_info('peername')[1]
@property
def closed(self):
return self.writer.is_closing()
async def read(self, length=2048, timeout=None):
return await asyncio.wait_for(self.reader.read(length), timeout or self.app.cfg.timeout)
async def readuntil(self, separator, timeout=None):
return await asyncio.wait_for(self.reader.readuntil(separator), timeout or self.app.cfg.timeout)
async def write(self, data):
if isinstance(data, DotDict):
data = data.to_json()
elif isinstance(data, (dict, list, tuple)):
data = json.dumps(data)
# not sure if there's a better type to use, but this should be fine for now
elif isinstance(data, (float, int)):
data = str(data)
elif isinstance(data, bytearray):
data = bytes(data)
elif not isinstance(data, (bytes, str)):
raise TypeError('Data must be a str, bytes, bytearray, float, int, dict, list, or tuple')
if isinstance(data, str):
data = data.encode('utf-8')
self.writer.write(data)
await self.writer.drain()
async def close(self):
if self.closed:
return
self.writer.close()
await self.writer.wait_closed()
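
A hedged sketch of the coercion rules in Transport.write(); `transport` stands in for the instance handle_client creates and the payloads are illustrative.

await transport.write({'status': 'ok'})    # dict/list/tuple -> json.dumps, DotDict -> to_json()
await transport.write(42)                  # int/float -> str -> utf-8 bytes
await transport.write('hello')             # str -> utf-8 bytes
await transport.write(b'raw')              # bytes pass straight through to writer.write()
await transport.close()                    # no-op if the writer is already closing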

View file

@ -24,6 +24,7 @@ class Row(DotDict):
except:
self._update(row)
self.__session = session
self.__db = session.db
self.__table_name = table
@ -40,6 +41,11 @@ class Row(DotDict):
return self.__table_name
@property
def session(self):
return self.__session
@property
def columns(self):
return self.keys()

View file

@ -91,23 +91,21 @@ class Session(sqlalchemy_session):
query = self.query(self.table[table]).filter_by(**kwargs)
if not orderby:
rows = query.all()
else:
if orderby:
if orderdir == 'asc':
rows = query.order_by(getattr(self.table[table].c, orderby).asc()).all()
query = query.order_by(getattr(self.table[table].c, orderby).asc())
elif orderdir == 'desc':
rows = query.order_by(getattr(self.table[table].c, orderby).desc()).all()
query = query.order_by(getattr(self.table[table].c, orderby).desc())
else:
raise ValueError(f'Unsupported order direction: {orderdir}')
if single:
return RowClass(table, rows[0], self) if len(rows) > 0 else None
row = query.first()
return RowClass(table, row, self) if row else None
return [RowClass(table, row, self) for row in rows]
return [RowClass(table, row, self) for row in query.all()]
def search(self, *args, **kwargs):
@ -131,10 +129,10 @@ class Session(sqlalchemy_session):
if getattr(self.table[table], 'timestamp', None) and not kwargs.get('timestamp'):
kwargs['timestamp'] = datetime.now()
self.execute(self.table[table].insert().values(**kwargs))
cursor = self.execute(self.table[table].insert().values(**kwargs))
if return_row:
return self.fetch(table, **kwargs)
return self.fetch(table, id=cursor.inserted_primary_key[0])
def update(self, table=None, rowid=None, row=None, return_row=False, **kwargs):
@ -145,7 +143,7 @@ class Session(sqlalchemy_session):
if not rowid or not table:
raise ValueError('Missing row ID or table')
self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs))
row = self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs))
if return_row:
return self.fetch(table, id=rowid)

6
izzylib/sql2/__init__.py Normal file
View file

@ -0,0 +1,6 @@
from .database import Database, Connection
from .result import Result
from .row import Row
from .session import Session
from .statements import Comparison, Statement, Select, Insert, Update, Delete, Count
from .table import Column
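
A hedged end-to-end sketch of the surface exported here; the table layout, column names and values are made up, and only the sqlite defaults (an in-memory database) are relied on.

from izzylib.sql2 import Database

db = Database(
    type = 'sqlite',                 # cfg.database falls back to ':memory:'
    tables = {
        'note': {
            'content': {'type': 'TEXT', 'nullable': False},
            'created': {'type': 'DATETIME'}
        }
    }
)

db.create_tables()                   # compiles CREATE TABLE statements from the layout

with db.session as s:                # __exit__ commits, or rolls back on an exception
    rowid = s.insert('note', content='hello')
    row = s.fetch('note', content='hello').one()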

102
izzylib/sql2/config.py Normal file
View file

@ -0,0 +1,102 @@
import sqlite3
from getpass import getuser
from importlib import import_module
from .result import Result
from .row import Row
from .session import Session
from ..config import BaseConfig
from ..path import Path
class Config(BaseConfig):
def __init__(self, **kwargs):
super().__init__(
appname = 'IzzyLib SQL Client',
type = 'sqlite',
module = None,
module_name = None,
tables = {},
row_classes = {},
session_class = Session,
result_class = Result,
host = 'localhost',
port = 0,
database = None,
username = getuser(),
password = None,
minconnections = 4,
maxconnections = 25,
engine_args = {},
auto_trans = True,
connect_function = None,
autocommit = False
)
for k, v in kwargs.items():
self[k] = v
if not self.database:
if self.type == 'sqlite':
self.database = ':memory:'
else:
raise ValueError('Missing database name')
if not self.port:
if self.type == 'postgresql':
self.port = 5432
elif self.type == 'mysql':
self.port = 3306
if not self.module and not self.connect_function:
if self.type == 'sqlite':
self.module = sqlite3
self.module_name = 'sqlite3'
elif self.type == 'postgresql':
for mod in ['pg8000.dbapi', 'pgdb', 'psycopg2']:
try:
self.module = import_module(mod)
self.module_name = mod
break
except ImportError:
pass
elif self.type == 'mysql':
try:
self.module = import_module('mysql.connector')
self.module_name = 'mysql.connector'
except ImportError:
pass
if not self.module:
raise ImportError(f'Cannot find module for "{self.type}"')
self.module.paramstyle = 'qmark'
@property
def dbargs(self):
return {key: self[key] for key in ['host', 'port', 'database', 'username', 'password']}
def parse_value(self, key, value):
if key == 'type':
if value not in ['sqlite', 'postgresql', 'mysql', 'mssql']:
raise ValueError(f'Invalid database type: {value}')
if key == 'port':
if not isinstance(value, int):
raise TypeError('Port is not an integer')
if key == 'row_classes':
for row_class in value.values():
if not issubclass(row_class, Row):
raise TypeError(f'Row classes must be izzylib.sql2.row.Row, not {row_class.__name__}')
return value
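
A hedged sketch of how the Config defaults above are meant to be overridden through Database(); the credentials and database name are placeholders.

from izzylib.sql2 import Database

db = Database(
    autoconnect = False,             # skip opening the connection pool immediately
    type = 'postgresql',
    database = 'izzylib_test',       # hypothetical
    username = 'izzy',
    password = 'hunter2',
    minconnections = 2,
    maxconnections = 10
)
# port falls back to 5432 for postgresql, and the first importable driver out of
# pg8000.dbapi, pgdb or psycopg2 is picked unless `module` or `connect_function` is set.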

244
izzylib/sql2/database.py Normal file
View file

@ -0,0 +1,244 @@
import itertools
from .config import Config
from .row import Row
from .table import DbTables
from .types import Types
from .. import izzylog
from ..dotdict import DotDict
from ..exceptions import MaxConnectionsError, NoTableLayoutError, NoConnectionError
from ..path import Path
class Database:
def __init__(self, autoconnect=True, app=None, **kwargs):
tables = kwargs.pop('tables', None)
self.cfg = Config(**kwargs)
self.tables = DbTables(self)
self.types = Types(self)
self.connections = []
self.app = app
if tables:
self.load_tables(tables)
if autoconnect:
self.connect()
def connect(self):
for _ in itertools.repeat(None, self.cfg.minconnections):
self.get_connection()
def disconnect(self):
for conn in self.connections:
conn.disconnect()
self.connections = []
@property
def session(self):
return self.get_connection().session
def new_connection(self):
if len(self.connections) >= self.cfg.maxconnections:
raise MaxConnectionsError('Too many connections')
conn = Connection(self)
conn.connect()
self.connections.append(conn)
return conn
def close_connection(self, conn):
izzylog.debug('close connection')
conn.close_sessions()
conn.disconnect()
if not conn.conn:
try: self.connections.remove(conn)
except: pass
def get_connection(self):
if not len(self.connections):
return self.new_connection()
if len(self.connections) < self.cfg.minconnections:
return self.new_connection()
for conn in self.connections:
if not len(conn.sessions):
return conn
if len(self.connections) < self.cfg.maxconnections:
return self.new_connection()
conns = {(conn, len(conn.sessions)) for conn in self.connections}
return min(conns, key=lambda x: x[1])[0]
def new_predb(self, database='postgres'):
dbconfig = Config(**self.cfg)
dbconfig['database'] = database
dbconfig['autocommit'] = True
return Database(**dbconfig)
def set_row_class(self, name, row_class):
if not issubclass(row_class, Row):
raise TypeError(f'Row classes must be izzylib.sql2.row.Row, not {row_class.__name__}')
self.cfg.row_classes[name] = row_class
def get_row_class(self, name):
return self.cfg.row_classes.get(name, Row)
def load_tables(self, tables=None):
if tables:
self.tables.load_tables(tables)
else:
with self.session as s:
self.tables.load_tables(s.table_layout())
def create_tables(self):
if self.tables.empty:
raise NoTableLayoutError('Table layout not loaded yet')
with self.session as s:
for table in self.tables.names:
s.execute(self.tables.compile_table(table, self.cfg.type))
def create_database(self):
if self.cfg.type == 'postgresql':
with self.new_predb().session as s:
if not s.raw_execute('SELECT datname FROM pg_database WHERE datname = ?', [self.cfg.database]).fetchone():
s.raw_execute(f'CREATE DATABASE {self.cfg.database}')
elif self.cfg.type != 'sqlite':
raise NotImplementedError(f'Database type not supported yet: {self.cfg.type}')
self.create_tables()
def drop_database(self, database):
if self.cfg.type == 'sqlite':
izzylog.verbose('drop_database not needed for SQLite')
return
with self.session as s:
if self.cfg.type == 'postgresql':
s.raw_execute(f'DROP DATABASE {database}')
else:
raise NotImplementedError(f'Database type not supported yet: {self.cfg.type}')
class Connection:
def __init__(self, db):
self.db = db
self.cfg = db.cfg
self.sessions = []
self.conn = None
self.connect()
if db.tables.empty:
with self.session as s:
db.load_tables(s.table_layout())
@property
def autocommit(self):
return self.conn.autocommit
@property
def session(self):
return self.cfg.session_class(self)
def connect(self):
if self.conn:
return
dbconfig = self.cfg.dbargs
if self.cfg.type == 'sqlite':
if self.cfg.autocommit:
self.conn = self.cfg.module.connect(dbconfig['database'], isolation_level=None)
else:
self.conn = self.cfg.module.connect(dbconfig['database'])
elif self.cfg.type == 'postgresql':
if Path(self.cfg.host).exists():
dbconfig['unix_sock'] = dbconfig.pop('host')
dbconfig['user'] = dbconfig.pop('username')
dbconfig['application_name'] = self.cfg.appname
self.conn = self.cfg.module.connect(**dbconfig)
else:
self.conn = self.cfg.module.connect(**self.cfg.dbargs)
try:
self.conn.autocommit = self.cfg.autocommit
except AttributeError:
if self.cfg.module_name not in ['sqlite3']:
izzylog.verbose('Module does not support autocommit:', self.cfg.module_name)
return self.conn
def disconnect(self):
if not self.conn:
return
self.close_sessions()
self.conn.close()
self.conn = None
def close_sessions(self):
for session in self.sessions:
self.close_session(session)
def close_session(self, session):
try: self.sessions.remove(session)
except: pass
session.close()
if not len(self.sessions) and len(self.db.connections) > self.cfg.minconnections:
self.disconnect()
def cursor(self):
if not self.conn:
raise NoConnectionError('Not connected to the database')
return self.conn.cursor()
def dump_database(self, path='database.sql'):
if self.cfg.type == 'sqlite':
path = Path(path)
with path.open('w') as fd:
fd.write('\n\n'.join(list(self.conn.iterdump())[1:-1]))
else:
raise NotImplementedError('Only SQLite supported atm :/')
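
A hedged sketch of the reflection path: when no `tables` layout is passed, the first Connection loads the layout from the live database via Session.table_layout(). The file path is hypothetical.

from izzylib.sql2 import Database

db = Database(type='sqlite', database='/tmp/existing.db')   # hypothetical path
print(db.tables.names)     # table names discovered from sqlite_master / information_schema
print(db.tables.empty)     # False once a layout has been loaded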

68
izzylib/sql2/result.py Normal file
View file

@ -0,0 +1,68 @@
from .row import Row
class Result:
def __init__(self, session):
self.table = None
self.session = session
self.cursor = session.cursor
try:
self.keys = [desc[0] for desc in session.cursor.description]
except TypeError:
self.keys = []
def __iter__(self):
yield from self.all_iter()
@property
def row_class(self):
return self.session.db.get_row_class(self.table)
@property
def last_row_id(self):
if self.session.cfg.type == 'postgresql':
try:
return self.one().id
except:
return None
return self.cursor.lastrowid
@property
def row_count(self):
return self.cursor.rowcount
def set_table(self, table):
self.table = table
def one(self):
data = self.cursor.fetchone()
if not data:
return
return self.row_class(
self.session,
self.table,
{self.keys[idx]: value for idx, value in enumerate(data)},
)
def all(self):
return [row for row in self.all_iter()]
def all_iter(self):
for row in self.cursor:
yield self.row_class(self.session, self.table,
{self.keys[idx]: value for idx, value in enumerate(row)}
)
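
A hedged sketch of consuming a Result; `s` is an open Session and the `note` table comes from the earlier illustrative layout.

result = s.fetch('note', content='hello')
first = result.one()          # a single Row built through db.get_row_class(table), or None
for row in result:            # __iter__ delegates to all_iter(), yielding further Row objects
    print(row.table, row.content)
print(result.row_count)       # raw cursor rowcount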

29
izzylib/sql2/row.py Normal file
View file

@ -0,0 +1,29 @@
from ..dotdict import DotDict
class Row(DotDict):
def __init__(self, session, table, data):
super().__init__(session._parse_data('serialize', table, data))
self._table = table
self._session = session
self.__run__(session)
def __run__(self, session):
pass
@property
def table(self):
return self._table
@property
def rowid(self):
return self.id
@property
def rowid2(self):
return self.get('rowid', self.id)
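
A hedged sketch of the subclass hook: __run__ runs at the end of __init__, and Database.set_row_class() (defined earlier in this commit) decides which class Result uses per table; the derived field and column are hypothetical.

from izzylib.sql2 import Row

class NoteRow(Row):
    def __run__(self, session):
        # hypothetical derived field built from a column of the illustrative layout
        self.preview = self.get('content', '')[:40]

db.set_row_class('note', NoteRow)      # Result.row_class now returns NoteRow for 'note'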

346
izzylib/sql2/session.py Normal file
View file

@ -0,0 +1,346 @@
import json
from pathlib import Path as PyPath
from .result import Result
from .row import Row
from .statements import Select, Insert, Delete, Count, Update, Statement
from .table import SessionTables
from .. import izzylog
from ..dotdict import DotDict
from ..exceptions import NoTableLayoutError, NoTransactionError, UpdateAllRowsError
from ..misc import boolean, random_gen
from ..path import Path
class Session:
def __init__(self, conn):
self.db = conn.db
self.cfg = conn.db.cfg
self.conn = conn
self.sid = random_gen()
self.tables = SessionTables(self)
self.cursor = conn.cursor()
self.trans = False
self.__setup__()
def __enter__(self):
return self
def __exit__(self, exctype, excvalue, traceback):
if traceback:
self.rollback()
else:
self.commit()
def __setup__(self):
pass
def close(self):
if not self.cursor:
return
self.conn.close_session(self)
def _parse_data(self, action, table, kwargs):
data = {}
if self.db.tables:
for key, value in kwargs.items():
try:
coltype = self.db.tables[table][key].type
except KeyError:
data[key] = value
continue
parser = self.db.types.get_type(coltype)
try:
data[key] = parser(action, self.cfg.type, value)
except Exception as e:
izzylog.error(f'Failed to parse data from the table "{table}": {key} = {value}')
izzylog.debug(f'Parser: {parser}, Type: {coltype}')
raise e from None
else:
data = kwargs
return data
def dump_database(self, path):
import sqlparse
path = Path(path)
with path.open('w') as fd:
line = '\n\n'.join(list(self.conn.conn.iterdump())[1:-1])
fd.write(sqlparse.format(line,
reindent = False,
keyword_case = 'upper',
))
def dump_database2(self, path):
path = Path(path)
with path.open('w') as fd:
fd.write('\n\n'.join(list(self.conn.conn.iterdump())[1:-1]))
def begin(self):
if self.trans or self.cfg.autocommit:
return
self.execute('BEGIN')
self.trans = True
def commit(self):
if not self.trans:
return
self.execute('COMMIT')
self.trans = False
def rollback(self):
if not self.trans:
return
self.execute('ROLLBACK')
self.trans = False
def raw_execute(self, string, values=None):
if type(string) == Path:
string = string.read()
elif type(string) == PyPath:
with string.open() as fd:
string = fd.read()
if values:
self.cursor.execute(string, values)
else:
self.cursor.execute(string)
return self.cursor
def execute(self, string, *values):
if isinstance(string, Statement):
raise TypeError('String must be a str not a Statement')
action = string.split()[0].upper()
if not self.trans and action in ['CREATE', 'INSERT', 'UPDATE', 'UPSERT', 'DROP', 'DELETE', 'ALTER']:
if self.cfg.auto_trans:
self.begin()
else:
raise NoTransactionError(f'Command not supported outside a transaction: {action}')
try:
self.raw_execute(string, values)
except Exception as e:
if type(e).__name__ in ['DatabaseError', 'OperationalError']:
print(string, values)
raise e from None
return Result(self)
def run(self, query):
result = self.execute(query.compile(self.cfg.type), *query.values)
if type(query) == Count:
return list(result.one().values())[0]
result.set_table(query.table)
return result
def run_count(self, query):
return self.run(query)
def count(self, table, **kwargs):
if self.db.tables and table not in self.db.tables:
raise KeyError(f'Table does not exist: {table}')
query = Count(table, **kwargs)
return self.run_count(query)
def fetch(self, table, orderby=None, orderdir='ASC', limit=None, offset=None, **kwargs):
if self.db.tables and table not in self.db.tables:
raise KeyError(f'Table does not exist: {table}')
query = Select(table, **kwargs)
if orderby:
query.order(orderby, orderdir)
if limit:
query.limit(limit)
if offset:
query.offset(offset)
return self.run(query)
def insert(self, table, return_row=False, **kwargs):
if self.db.tables and table not in self.db.tables:
raise KeyError(f'Table does not exist: {table}')
result = self.run(Insert(table, **self._parse_data('deserialize', table, kwargs)))
if return_row:
return self.fetch(table, id=result.last_row_id).one()
return result.last_row_id
def update(self, table, data, return_row=False, **kwargs):
query = Update(table, **data)
for pair in kwargs.items():
query.where(*pair)
if not query._where:
raise UpdateAllRowsError(f'Refusing to update all rows in table: {table}')
result = self.run(query)
if return_row:
return self.fetch(table, id=result.last_row_id).one()
else:
return result
def update_row(self, row, return_row=False, **kwargs):
return self.update(row.table, kwargs, id=row.id, return_row=return_row)
def remove(self, table, **kwargs):
if self.db.tables and table not in self.db.tables:
raise KeyError(f'Table does not exist: {table}')
self.run(Delete(table, **self._parse_data('deserialize', table, kwargs)))
def remove_row(self, row):
if not row.table:
raise ValueError('Row not associated with a table')
self.remove(row.table, id=row.id)
def create_tables(self, tables=None):
if tables:
self.db.load_tables(tables)
if self.db.tables.empty:
raise NoTableLayoutError('No table layout available')
for name in self.db.tables.names:
self.execute(self.db.tables.compile_table(name, self.cfg.type))
def table_layout(self):
tables = {}
if self.cfg.type == 'sqlite':
rows = self.execute("SELECT name, sql FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%'")
for row in rows:
name = row.name
tables[name] = {}
fkeys = {fkey['from']: f'{fkey["table"]}.{fkey["to"]}' for fkey in self.execute(f'PRAGMA foreign_key_list({name})')}
columns = [col for col in self.execute(f'PRAGMA table_info({name})')]
unique_list = parse_unique(row.sql)
for column in columns:
tables[name][column.name] = dict(
type = column.type.upper(),
nullable = not column.notnull,
default = parse_default(column.dflt_value),
primary_key = bool(column.pk),
foreign_key = fkeys.get(column.name),
unique = column.name in unique_list
)
elif self.cfg.type == 'postgresql':
for row in self.execute("SELECT * FROM information_schema.columns WHERE table_schema not in ('information_schema', 'pg_catalog') ORDER BY table_schema, table_name, ordinal_position"):
table = row.table_name
column = row.column_name
if not tables.get(table):
tables[table] = {}
if not tables[table].get(column):
tables[table][column] = {}
tables[table][column] = dict(
type = row.data_type.upper(),
nullable = boolean(row.is_nullable),
default = row.column_default if row.column_default and not row.column_default.startswith('nextval') else None,
primary_key = None,
foreign_key = None,
unique = None
)
return tables
def parse_unique(sql):
unique_list = []
try:
for raw_line in sql.splitlines():
if 'UNIQUE' not in raw_line:
continue
for line in raw_line.replace('UNIQUE', '').replace('(', '').replace(')', '').split(','):
line = line.strip()
if line:
unique_list.append(line)
except IndexError:
pass
return unique_list
def parse_default(value):
if value == None:
return
if value.startswith("'") and value.endswith("'"):
value = value[1:-1]
else:
try:
value = int(value)
except ValueError:
pass
return value
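
A hedged sketch of the write helpers above; `db` is a connected Database whose `note` table has an `id` column (hypothetical here).

with db.session as s:
    rowid = s.insert('note', content='hello')               # returns Result.last_row_id
    s.update('note', {'content': 'hello world'}, id=rowid)  # kwargs become WHERE conditions
    s.remove('note', id=rowid)
# the context manager commits on success and rolls back if the block raised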

185
izzylib/sql2/statements.py Normal file
View file

@ -0,0 +1,185 @@
from ..dotdict import DotDict
Comparison = DotDict(
LESS = lambda key: f'{key} < ?',
GREATER = lambda key: f'{key} > ?',
LESS_EQUAL = lambda key: f'{key} <= ?',
GREATER_EQUAL = lambda key: f'{key} >= ?',
EQUAL = lambda key: f'{key} = ?',
NOT_EQUAL = lambda key: f'{key} != ?',
IN = lambda key: f'{key} IN (?)',
NOT_IN = lambda key: f'{key} NOT IN (?)',
LIKE = lambda key: f'{key} LIKE ?',
NOT_LIKE = lambda key: f'{key} NOT LIKE ?'
)
class Statement:
def __init__(self, table):
self.table = table
self.values = []
self._where = ''
self._order = None
self._limit = None
self._offset = None
def __str__(self):
return self.compile('sqlite')
def where(self, key, value, comparison='equal', operator='and'):
try:
comp = Comparison[comparison.upper().replace('-', '_')]
except KeyError:
raise KeyError(f'Invalid comparison: {comparison}')
prefix = f' {operator} ' if self._where else ' '
self._where += f'{prefix}{comp(key)}'
self.values.append(value)
return self
def order(self, column, direction='ASC'):
direction = direction.upper()
assert direction in ['ASC', 'DESC']
self._order = (column, direction)
return self
def limit(self, limit_num):
self._limit = int(limit_num)
return self
def offset(self, offset_num):
self._offset = int(offset_num)
return self
def compile(self, dbtype):
raise NotImplementedError('Do not use the Statement class directly.')
class Select(Statement):
def __init__(self, table, *columns, **kwargs):
super().__init__(table)
self.columns = columns
for key, value in kwargs.items():
self.where(key, value)
def compile(self, dbtype):
data = f'SELECT'
if self.columns:
columns = ','.join(self.columns)
else:
columns = '*'
data += f' {columns} FROM {self.table}'
if self._where:
data += f' WHERE {self._where}'
if self._order:
col, direc = self._order
data += f' ORDER BY {col} {direc}'
if self._limit:
data += f' LIMIT {self._limit}'
if self._offset:
data += f' OFFSET {self._offset}'
return data
class Insert(Statement):
def __init__(self, table, **kwargs):
super().__init__(table)
self.keys = []
for pair in kwargs.items():
self.add_data(*pair)
def add_data(self, key, value):
self.keys.append(key)
self.values.append(value)
def remove_data(self, key):
index = self.keys.index(key)
del self.keys[index]
del self.values[index]
def compile(self, dbtype):
keys = ','.join(self.keys)
values = ','.join('?' for value in self.values)
data = f'INSERT INTO {self.table} ({keys}) VALUES ({values})'
if dbtype == 'postgresql':
data += f' RETURNING id'
return data
class Update(Statement):
def __init__(self, table, **kwargs):
super().__init__(table)
self.keys = []
for key, value in kwargs.items():
self.keys.append(key)
self.values.append(value)
def compile(self, dbtype):
pairs = ','.join(f'{key} = ?' for key in self.keys)
data = f'UPDATE {self.table} SET {pairs} WHERE {self._where}'
if dbtype == 'postgresql':
data += f' RETURNING id'
return data
class Delete(Statement):
def __init__(self, table, **kwargs):
super().__init__(table)
for key, value in kwargs.items():
self.where(key, value)
def compile(self, dbtype):
return f'DELETE FROM {self.table} WHERE {self._where}'
class Count(Statement):
def __init__(self, table, **kwargs):
super().__init__(table)
for key, value in kwargs.items():
self.where(key, value)
def compile(self, dbtype):
data = f'SELECT COUNT(*) FROM {self.table}'
if self._where:
data += f' WHERE {self._where}'
return data
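
A hedged sketch of composing a statement by hand; the table and column names are illustrative, and the compiled string and values are what Session.run() feeds to execute() in the code above.

from izzylib.sql2 import Select

query = Select('note', 'id', 'content')                  # explicit columns instead of *
query.where('domain', 'example.com')
query.where('sensitive', True, comparison='not-equal')
query.order('id', 'desc').limit(20).offset(40)

print(query.compile('sqlite'))   # parameterised SQL with ? placeholders
print(query.values)              # ['example.com', True], in placeholder order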

244
izzylib/sql2/table.py Normal file
View file

@ -0,0 +1,244 @@
import json
from ..dotdict import DotDict
class SessionTables:
def __init__(self, session):
self._session = session
self._db = session.db
self._tables = session.db.tables
def __getattr__(self, key):
return SessionTable(self._session, key, self._tables[key])
def names(self):
return tuple(self._tables.keys())
class SessionTable(DotDict):
def __init__(self, session, name, columns):
super().__init__(columns)
self._name = name
self._session = session
self._db = session.db
@property
def name(self):
return self._name
@property
def columns(self):
return tuple(self.keys())
def fetch(self, **kwargs):
self._check_columns(**kwargs)
return self._session.fetch(self.name, **kwargs)
def insert(self, **kwargs):
self._check_columns(**kwargs)
return self._session.insert(self.name, **kwargs)
def remove(self, **kwargs):
self._check_columns(**kwargs)
return self._session.remove(self.name, **kwargs)
def _check_columns(self, **kwargs):
for key in kwargs.keys():
if key not in self.columns:
raise KeyError(f'Not a column for table "{self.name}": {key}')
class DbTables(DotDict):
def __init__(self, db):
super().__init__()
self._db = db
self._cfg = db.cfg
@property
def empty(self):
return not len(self.keys())
@property
def names(self):
return tuple(self.keys())
def load_tables(self, tables):
for name, columns in tables.items():
self.add_table(name, columns)
def unload_tables(self):
for key in self.names:
del self[key]
def add_table(self, name, columns):
self[name] = {}
if type(columns) == list:
self[name] = {col.name: DbColumn(self._cfg.type, **col) for col in columns}
elif isinstance(columns, dict):
for column, data in columns.items():
self[name][column] = DbColumn(self._cfg.type, column, **data)
else:
raise TypeError('Columns must be a list of Column objects or a dict')
def remove_table(self, name):
return self.pop(name)
def get_columns(self, name):
return tuple(self[name].values())
def compile_table(self, table_name, dbtype):
table = self[table_name]
columns = []
foreign_keys = []
for column in self.get_columns(table_name):
columns.append(column.compile(dbtype))
if column.foreign_key:
fkey_table, fkey_col = column.fkey
foreign_keys.append(f'FOREIGN KEY ({column.name}) REFERENCES {fkey_table} ({fkey_col})')
return f'CREATE TABLE IF NOT EXISTS {table_name} ({",".join([*columns, *foreign_keys])})'
def compile_all(self, dbtype):
return [self.compile_table(name, dbtype) for name in self.keys()]
class DbColumn(DotDict):
def __init__(self, dbtype, name, type=None, default=None, primary_key=False, unique=False, nullable=True, autoincrement=False, foreign_key=None):
super().__init__(
name = name,
type = type,
default = default,
primary_key = primary_key,
unique = unique,
nullable = nullable,
autoincrement = autoincrement,
foreign_key = foreign_key
)
if self.name == 'id':
if dbtype == 'sqlite':
self.type = 'INTEGER'
self.autoincrement = True
elif dbtype == 'postgresql':
self.type = 'SERIAL'
self.autoincrement = False
self.primary_key = True
self.unique = False
self.nullable = False
self.default = None
self.foreign_key = None
elif self.name in ['created', 'modified', 'accessed'] and not self.type:
self.type = 'DATETIME'
if not self.type:
raise ValueError(f'Must provide a column type for column: {name}')
try:
self.fkey
except ValueError:
raise ValueError(f'Invalid foreign_key format. Must be "table.column"')
@property
def fkey(self):
try:
return self.foreign_key.split('.')
except AttributeError:
return
def compile(self, dbtype):
line = f'{self.name} {self.type}'
if self.primary_key:
line += ' PRIMARY KEY'
if not self.nullable:
line += ' NOT NULL'
if self.unique:
line += ' UNIQUE'
if self.autoincrement and dbtype != 'postgresql':
line += ' AUTOINCREMENT'
if self.default:
line += f" DEFAULT {parse_default(self.default)}"
return line
class Column(DotDict):
def __init__(self, name, type=None, default=None, primary_key=False, unique=False, nullable=True, autoincrement=False, foreign_key=None):
super().__init__(
name = name,
type = type.upper() if type else None,
default = default,
primary_key = primary_key,
unique = unique,
nullable = nullable,
autoincrement = autoincrement,
foreign_key = foreign_key
)
if self.name == 'id':
self.type = 'SERIAL'
elif self.name in ['created', 'modified', 'accessed'] and not self.type:
self.type = 'DATETIME'
if not self.type:
raise ValueError(f'Must provide a column type for column: {name}')
try:
self.fkey
except ValueError:
raise ValueError(f'Invalid foreign_key format. Must be "table.column"')
@property
def fkey(self):
try:
return self.foreign_key.split('.')
except AttributeError:
return
def parse_default(default):
if isinstance(default, dict) or isinstance(default, list):
default = json.dumps(default)
if type(default) == str:
default = f"'{default}'"
return default

219
izzylib/sql2/types.py Normal file
View file

@ -0,0 +1,219 @@
from datetime import date, time, datetime
from .. import izzylog
from ..dotdict import DotDict, LowerDotDict
Standard = {
'INTEGER',
'INT',
'TINYINT',
'SMALLINT',
'MEDIUMINT',
'BIGINT',
'UNSIGNED BIG INT',
'INT2',
'INT8',
'TEXT',
'CHARACTER',
'CHAR',
'VARCHAR',
'BLOB',
'CLOB',
'REAL',
'DOUBLE',
'DOUBLE PRECISION',
'FLOAT',
'NUMERIC',
'DEC',
'DECIMAL',
'BOOLEAN',
'DATE',
'TIME',
'JSON'
}
Sqlite = {
*Standard,
'DATETIME'
}
Postgresql = {
*Standard,
'SMALLSERIAL',
'SERIAL',
'BIGSERIAL',
'VARYING',
'BYTEA',
'TIMESTAMP',
'INTERVAL',
'POINT',
'LINE',
'LSEG',
'BOX',
'PATH',
'POLYGON',
'CIRCLE',
}
Mysql = {
*Standard,
'FIXED',
'BIT',
'YEAR',
'VARBINARY',
'ENUM',
'SET'
}
class Type:
sqlite = None
postgresql = None
mysql = None
def __getitem__(self, key):
if key in ['sqlite', 'postgresql', 'mysql']:
return getattr(self, key)
raise KeyError(f'Invalid database type: {key}')
def __call__(self, action, dbtype, value):
return getattr(self, action)(dbtype, value)
def name(self, dbtype='sqlite'):
return self[dbtype]
def serialize(self, dbtype, value):
return value
def deserialize(self, dbtype, value):
return value
class Json(Type):
sqlite = 'JSON'
postgresql = 'JSON'
mysql = 'JSON'
def serialize(self, dbtype, value):
izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value)
if type(value) == str:
return DotDict(value)
return value
def deserialize(self, dbtype, value):
izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value)
return DotDict(value).to_json()
class Datetime(Type):
sqlite = 'DATETIME'
postgresql = 'TIMESTAMP'
mysql = 'DATETIME'
def serialize(self, dbtype, value):
izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value)
if type(value) == str:
return datetime.fromisoformat(value)
elif type(value) == int:
return datetime.fromtimestamp(value)
return value
def deserialize(self, dbtype, value):
izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value)
if dbtype == 'sqlite':
return value.isoformat()
return value
class Date(Type):
sqlite = 'DATE'
postgresql = 'DATE'
mysql = 'DATE'
def serialize(self, dbtype, value):
izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value)
if type(value) == str:
return date.fromisoformat(value)
elif type(value) == int:
return date.fromtimestamp(value)
return value
def deserialize(self, dbtype, value):
izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value)
if dbtype == 'sqlite':
return value.isoformat()
return value
class Time(Type):
sqlite = 'TIME'
postgresql = 'TIME'
mysql = 'TIME'
def serialize(self, dbtype, value):
izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value)
if type(value) == str:
return time.fromisoformat(value)
elif type(value) == int:
return time.fromtimestamp(value)
return value
def deserialize(self, dbtype, value):
izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value)
if dbtype == 'sqlite':
return value.isoformat()
return value
class Types(DotDict):
def __init__(self, db):
self._db = db
self.set_type(Json, Date, Time, Datetime)
def get_type(self, name):
return self.get(name.upper(), Type())
def set_type(self, *types):
for type_object in types:
typeclass = type_object()
self[typeclass.name(self._db.cfg.type)] = typeclass
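
A hedged sketch of extending the registry with a custom Type; the Boolean type itself is hypothetical.

from izzylib.sql2.types import Type

class Boolean(Type):
    sqlite = 'BOOLEAN'
    postgresql = 'BOOLEAN'
    mysql = 'BOOLEAN'

    def serialize(self, dbtype, value):      # database -> python
        return bool(value)

    def deserialize(self, dbtype, value):    # python -> database
        return int(value) if dbtype == 'sqlite' else value

db.types.set_type(Boolean)
# Session._parse_data() now routes BOOLEAN columns through these converters.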

View file

@ -33,6 +33,7 @@ packages =
izzylib.http_server_async
izzylib.http_urllib_client
izzylib.sql
izzylib.sql2
setup_requires =
setuptools >= 38.3.0
@ -59,6 +60,8 @@ http_urllib_client =
sql =
SQLAlchemy == 1.4.23
SQLAlchemy-Paginator == 0.2
sql2 =
sql-metadata == 2.3.0
template =
colour == 0.1.5
Hamlish-Jinja == 0.3.3