From 1e62836754b3ee3dacb5718928a3e4157e3edd7a Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Sat, 11 Dec 2021 18:46:12 -0500 Subject: [PATCH] http_server_async: add transport class, sql2: new sub-module --- izzylib/activitypub.py | 51 +++- izzylib/exceptions.py | 20 ++ izzylib/http_server_async/__init__.py | 1 - izzylib/http_server_async/application.py | 44 +-- izzylib/http_server_async/request.py | 19 +- izzylib/http_server_async/response.py | 21 +- izzylib/http_server_async/transport.py | 65 +++++ izzylib/sql/rows.py | 6 + izzylib/sql/session.py | 20 +- izzylib/sql2/__init__.py | 6 + izzylib/sql2/config.py | 102 +++++++ izzylib/sql2/database.py | 244 ++++++++++++++++ izzylib/sql2/result.py | 68 +++++ izzylib/sql2/row.py | 29 ++ izzylib/sql2/session.py | 346 +++++++++++++++++++++++ izzylib/sql2/statements.py | 185 ++++++++++++ izzylib/sql2/table.py | 244 ++++++++++++++++ izzylib/sql2/types.py | 219 ++++++++++++++ setup.cfg | 3 + 19 files changed, 1639 insertions(+), 54 deletions(-) create mode 100644 izzylib/http_server_async/transport.py create mode 100644 izzylib/sql2/__init__.py create mode 100644 izzylib/sql2/config.py create mode 100644 izzylib/sql2/database.py create mode 100644 izzylib/sql2/result.py create mode 100644 izzylib/sql2/row.py create mode 100644 izzylib/sql2/session.py create mode 100644 izzylib/sql2/statements.py create mode 100644 izzylib/sql2/table.py create mode 100644 izzylib/sql2/types.py diff --git a/izzylib/activitypub.py b/izzylib/activitypub.py index a7d7527..346ceb6 100644 --- a/izzylib/activitypub.py +++ b/izzylib/activitypub.py @@ -1,4 +1,4 @@ -import json, mimetypes +import json, mimetypes, traceback from datetime import datetime, timezone from functools import partial @@ -30,26 +30,45 @@ url_keys = [ ] -def parse_privacy_level(to: list=[], cc: list=[]): - if to == [pubstr] and len(cc) == 1: +def parse_privacy_level(to: list=[], cc: list=[], followers=None): + if pubstr in to and followers in cc: return 'public' - elif to and self.actor in to[0] and not cc: + elif followers in to and pubstr in cc: return 'unlisted' - elif pubstr not in to: + elif pubstr not in to and pubstr not in cc and followers in cc: return 'private' + elif not tuple(item for item in [*to, *cc] if item not in [pubstr, followers]): + return 'direct' + else: logging.warning('Not sure what this privacy level is') logging.debug(f'to: {json.dumps(to)}') logging.debug(f'cc: {json.dumps(cc)}') + logging.debug(f'followers: {followers}') -def generate_privacy_fields(privacy='public'): +def generate_privacy_fields(privacy='public', followers=None, to=[], cc=[]): if privacy == 'public': - return ([pubstr]) + to = [pubstr, *to] + cc = [followers, *to] + elif privacy == 'unlisted': + to = [followers, *to] + cc = [pubstr, *to] + + elif privacy == 'private': + cc = [followers, *cc] + + elif privacy == 'direct': + pass + + else: + raise ValueError(f'Unknown privacy level: {privacy}') + + return to, cc class Object(DotDict): def __setitem__(self, key, value): @@ -384,7 +403,11 @@ class Object(DotDict): @property def privacy_level(self): - return parse_privacy_level(self.get('to', []), self.get('cc', [])) + return parse_privacy_level( + self.get('to', []), + self.get('cc', []), + self.get('attributedTo', '') + '/followers' + ) @property @@ -416,7 +439,7 @@ class Object(DotDict): @property def info_table(self): - return DotDict({p.name: p.value for p in self.get('attachment', {})}) + return DotDict({p['name']: p['value'] for p in self.get('attachment', {})}) @property @@ -429,6 +452,16 @@ class Object(DotDict): return self.get('summary') + @property + def avatar(self): + return self.icon.url + + + @property + def header(self): + return self.image.url + + class Collection(Object): @classmethod def new_replies(cls, statusid): diff --git a/izzylib/exceptions.py b/izzylib/exceptions.py index 5a151f9..e5ad8af 100644 --- a/izzylib/exceptions.py +++ b/izzylib/exceptions.py @@ -24,3 +24,23 @@ class MethodNotHandledException(Exception): class NoBlueprintForPath(Exception): 'raise when no blueprint is found for a specific path' + + +class NoConnectionError(Exception): + 'Raise when a function requiring a connection gets called when there is no connection' + + +class MaxConnectionsError(Exception): + 'Raise when the max amount of connections has been reached' + + +class NoTransactionError(Exception): + 'Raise when trying to execute an SQL write statement outside a transaction' + + +class NoTableLayoutError(Exception): + 'Raise when a table layout is necessary, but not loaded' + + +class UpdateAllRowsError(Exception): + 'Raise when an UPDATE tries to modify all rows in a table' diff --git a/izzylib/http_server_async/__init__.py b/izzylib/http_server_async/__init__.py index 72d8453..634eace 100644 --- a/izzylib/http_server_async/__init__.py +++ b/izzylib/http_server_async/__init__.py @@ -17,7 +17,6 @@ def create_app(appname, **kwargs): from .application import Application, Blueprint from .middleware import MediaCacheControl -from .misc import Cookies, Headers from .request import Request from .response import Response from .view import View, Static diff --git a/izzylib/http_server_async/application.py b/izzylib/http_server_async/application.py index d33e7dc..6206e30 100644 --- a/izzylib/http_server_async/application.py +++ b/izzylib/http_server_async/application.py @@ -8,6 +8,7 @@ from .config import Config from .response import Response #from .router import Router from .view import Static, Manifest, Robots, Style +from .transport import Transport from .. import logging from ..dotdict import DotDict @@ -16,7 +17,7 @@ from ..misc import signal_handler from ..path import Path try: - from ..sql import Database + from ..sql2 import Database except ImportError: Database = NotImplementedError('Failed to import SQL database class') @@ -52,7 +53,7 @@ class ApplicationBase: if isinstance(Database, Exception): raise Database from None - self.db = dbclass(dbtype, **dbargs) + self.db = dbclass(dbtype, **dbargs, app=self) def __getitem__(self, key): @@ -320,13 +321,20 @@ class Application(ApplicationBase): async def handle_client(self, reader, writer): + transport = Transport(self, reader, writer) request = None response = None try: - request = self.cfg.request_class(self, reader, writer.get_extra_info('peername')[0]) + request = self.cfg.request_class(self, transport) response = self.cfg.response_class(request=request) - await request.parse_headers() + + try: + await request.parse_headers() + + except asyncio.exceptions.IncompleteReadError as e: + request = None + raise e from None try: # this doesn't work all the time for some reason @@ -336,8 +344,8 @@ class Application(ApplicationBase): except NoBlueprintForPath: response = await self.handle_request(request, response) - except Exception as e: - traceback.print_exc() + #except Exception as e: + #traceback.print_exc() except NotFound: response = self.cfg.response_class(request=request).set_error('Not Found', 404) @@ -357,22 +365,22 @@ class Application(ApplicationBase): except: traceback.print_exc() - ## Don't use a custom response class here just in case it caused the error - response = Response(request=request).set_error('Server Error', 500) + if not response.streaming: + ## Don't use a custom response class here just in case it caused the error + response = Response(request=request).set_error('Server Error', 500) - try: - response.headers.update(self.cfg.default_headers) - writer.write(response.compile()) - await writer.drain() + if not response.streaming: + try: + response.headers.update(self.cfg.default_headers) + await transport.write(response.compile()) - if request and not request.path.startswith('/framework'): - logging.info(f'{request.remote} {request.method} {request.path} {response.status} {len(response.body)} {request.agent}') + if request and request.log and not request.path.startswith('/framework'): + logging.info(f'{request.remote} {request.method} {request.path} {response.status} {len(response.body)} {request.agent}') - except: - traceback.print_exc() + except: + traceback.print_exc() - writer.close() - await writer.wait_closed() + await transport.close() class Blueprint(ApplicationBase): diff --git a/izzylib/http_server_async/request.py b/izzylib/http_server_async/request.py index 39705dd..4a73b94 100644 --- a/izzylib/http_server_async/request.py +++ b/izzylib/http_server_async/request.py @@ -17,28 +17,28 @@ LocalTime = datetime.now(UtcTime).astimezone().tzinfo class Request: __slots__ = [ - '_body', '_form', '_reader', '_method', '_app', '_params', + '_body', '_form', '_method', '_app', '_params', 'address', 'path', 'version', 'headers', 'cookies', - 'query', 'raw_query' + 'query', 'raw_query', 'transport', 'log' ] ctx = DotDict() - def __init__(self, app, reader, address): + def __init__(self, app, transport): super().__init__() self._app = app - self._reader = reader self._body = b'' self._form = DotDict() self._method = None self._params = None + self.transport = transport self.headers = Headers() self.cookies = Cookies() self.query = DotDict() - self.address = address + self.address = transport.client_address self.path = None self.version = None self.raw_query = None @@ -141,14 +141,9 @@ class Request: return self._params - async def read(self, length=2048, timeout=None): - try: return await asyncio.wait_for(self._reader.read(length), timeout or self.app.cfg.timeout) - except: return - - async def body(self): if not self._body and self.length: - self._body = await self.read(self.length) + self._body = await self.transport.read(self.length) return self._body @@ -178,7 +173,7 @@ class Request: async def parse_headers(self): - data = (await self._reader.readuntil(b'\r\n\r\n')).decode('utf-8') + data = (await self.transport.readuntil(b'\r\n\r\n')).decode('utf-8') for idx, line in enumerate(data.splitlines()): if idx == 0: diff --git a/izzylib/http_server_async/response.py b/izzylib/http_server_async/response.py index 86031bb..d116802 100644 --- a/izzylib/http_server_async/response.py +++ b/izzylib/http_server_async/response.py @@ -64,6 +64,11 @@ class Response: return len(self.body) + @property + def streaming(self): + return self.headers.getone('Transfer-Encoding') == 'chunked' + + def append(self, data): self._body += self._parse_body_data(data) @@ -137,7 +142,7 @@ class Response: return self - def set_json(self, body={}, status=None, activity=False,): + def set_json(self, body={}, status=None, activity=False): self.content_type = 'application/activity+json' if activity else 'application/json' self.body = body @@ -177,11 +182,23 @@ class Response: return self + async def set_streaming(self, transport, headers={}): + self.headers.update(headers) + self.headers.update(transport.app.cfg.default_headers) + self.headers.setall('Transfer-encoding', 'chunked') + + transport.write(self._compile_headers()) + + def set_cookie(self, key, value, **kwargs): self.cookies[key] = CookieItem(key, value, **kwargs) def compile(self): + return self._compile_headers() + self.body + + + def _compile_headers(self): data = bytes(f'HTTP/1.1 {self.status}', 'utf-8') for k,v in self.headers.items(): @@ -198,9 +215,7 @@ class Response: data += bytes(f'\r\nDate: {datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")}', 'utf-8') data += bytes(f'\r\nContent-Length: {len(self.body)}', 'utf-8') - data += b'\r\n\r\n' - data += self.body return data diff --git a/izzylib/http_server_async/transport.py b/izzylib/http_server_async/transport.py new file mode 100644 index 0000000..2c75d62 --- /dev/null +++ b/izzylib/http_server_async/transport.py @@ -0,0 +1,65 @@ +import asyncio + +from ..dotdict import DotDict + + +class Transport: + def __init__(self, app, reader, writer): + self.app = app + self.reader = reader + self.writer = writer + + + @property + def client_address(self): + return self.writer.get_extra_info('peername')[0] + + + @property + def client_port(self): + return self.writer.get_extra_info('peername')[1] + + + @property + def closed(self): + return self.writer.is_closing() + + + async def read(self, length=2048, timeout=None): + return await asyncio.wait_for(self.reader.read(length), timeout or self.app.cfg.timeout) + + + async def readuntil(self, bytes, timeout=None): + return await asyncio.wait_for(self.reader.readuntil(bytes), timeout or self.app.cfg.timeout) + + + async def write(self, data): + if isinstance(data, DotDict): + data = data.to_json() + + elif any(map(isinstance, [data], [dict, list, tuple])): + data = json.dumps(data) + + # not sure if there's a better type to use, but this should be fine for now + elif any(map(isinstance, [data], [float, int])): + data = str(data) + + elif isinstance(data, bytearray): + data = str(data) + + elif not any(map(isinstance, [data], [bytes, str])): + raise TypeError('Data must be or a str, bytes, bytearray, float, it, dict, list, or tuple') + + if isinstance(data, str): + data = data.encode('utf-8') + + self.writer.write(data) + await self.writer.drain() + + + async def close(self): + if self.closed: + return + + self.writer.close() + await self.writer.wait_closed() diff --git a/izzylib/sql/rows.py b/izzylib/sql/rows.py index 5280aed..48aac68 100644 --- a/izzylib/sql/rows.py +++ b/izzylib/sql/rows.py @@ -24,6 +24,7 @@ class Row(DotDict): except: self._update(row) + self.__session = session self.__db = session.db self.__table_name = table @@ -40,6 +41,11 @@ class Row(DotDict): return self.__table_name + @property + def session(self): + return self.__session + + @property def columns(self): return self.keys() diff --git a/izzylib/sql/session.py b/izzylib/sql/session.py index d5bf2c9..b0a8db4 100644 --- a/izzylib/sql/session.py +++ b/izzylib/sql/session.py @@ -91,23 +91,21 @@ class Session(sqlalchemy_session): query = self.query(self.table[table]).filter_by(**kwargs) - if not orderby: - rows = query.all() - - else: + if orderby: if orderdir == 'asc': - rows = query.order_by(getattr(self.table[table].c, orderby).asc()).all() + query = query.order_by(getattr(self.table[table].c, orderby).asc()) elif orderdir == 'desc': - rows = query.order_by(getattr(self.table[table].c, orderby).desc()).all() + query = query.order_by(getattr(self.table[table].c, orderby).desc()) else: raise ValueError(f'Unsupported order direction: {orderdir}') if single: - return RowClass(table, rows[0], self) if len(rows) > 0 else None + row = query.first() + return RowClass(table, row, self) if row else None - return [RowClass(table, row, self) for row in rows] + return [RowClass(table, row, self) for row in query.all()] def search(self, *args, **kwargs): @@ -131,10 +129,10 @@ class Session(sqlalchemy_session): if getattr(self.table[table], 'timestamp', None) and not kwargs.get('timestamp'): kwargs['timestamp'] = datetime.now() - self.execute(self.table[table].insert().values(**kwargs)) + cursor = self.execute(self.table[table].insert().values(**kwargs)) if return_row: - return self.fetch(table, **kwargs) + return self.fetch(table, id=cursor.inserted_primary_key[0]) def update(self, table=None, rowid=None, row=None, return_row=False, **kwargs): @@ -145,7 +143,7 @@ class Session(sqlalchemy_session): if not rowid or not table: raise ValueError('Missing row ID or table') - self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs)) + row = self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs)) if return_row: return self.fetch(table, id=rowid) diff --git a/izzylib/sql2/__init__.py b/izzylib/sql2/__init__.py new file mode 100644 index 0000000..d8dce96 --- /dev/null +++ b/izzylib/sql2/__init__.py @@ -0,0 +1,6 @@ +from .database import Database, Connection +from .result import Result +from .row import Row +from .session import Session +from .statements import Comparison, Statement, Select, Insert, Update, Delete, Count +from .table import Column diff --git a/izzylib/sql2/config.py b/izzylib/sql2/config.py new file mode 100644 index 0000000..7d131f6 --- /dev/null +++ b/izzylib/sql2/config.py @@ -0,0 +1,102 @@ +import sqlite3 + +from getpass import getuser +from importlib import import_module + +from .result import Result +from .row import Row +from .session import Session + +from ..config import BaseConfig +from ..path import Path + + +class Config(BaseConfig): + def __init__(self, **kwargs): + super().__init__( + appname = 'IzzyLib SQL Client', + type = 'sqlite', + module = None, + module_name = None, + tables = {}, + row_classes = {}, + session_class = Session, + result_class = Result, + host = 'localhost', + port = 0, + database = None, + username = getuser(), + password = None, + minconnections = 4, + maxconnections = 25, + engine_args = {}, + auto_trans = True, + connect_function = None, + autocommit = False + ) + + for k, v in kwargs.items(): + self[k] = v + + if not self.database: + if self.type == 'sqlite': + self.database = ':memory:' + + else: + raise ValueError('Missing database name') + + if not self.port: + if self.type == 'postgresql': + self.port = 5432 + + elif self.type == 'mysql': + self.port = 3306 + + if not self.module and not self.connect_function: + if self.type == 'sqlite': + self.module = sqlite3 + self.module_name = 'sqlite3' + + elif self.type == 'postgresql': + for mod in ['pg8000.dbapi', 'pgdb', 'psycopg2']: + try: + self.module = import_module(mod) + self.module_name = mod + break + + except ImportError: + pass + + elif self.type == 'mysql': + try: + self.module = import_module('mysql.connector') + self.module_name = 'mysql.connector' + except ImportError: + pass + + if not self.module: + raise ImportError(f'Cannot find module for "{self.type}"') + + self.module.paramstyle = 'qmark' + + + @property + def dbargs(self): + return {key: self[key] for key in ['host', 'port', 'database', 'username', 'password']} + + + def parse_value(self, key, value): + if key == 'type': + if value not in ['sqlite', 'postgresql', 'mysql', 'mssql']: + raise ValueError(f'Invalid database type: {value}') + + if key == 'port': + if not isinstance(value, int): + raise TypeError('Port is not an integer') + + if key == 'row_classes': + for row_class in value.values(): + if not issubclass(row_class, Row): + raise TypeError(f'Row classes must be izzylib.sql2.row.Row, not {row_class.__name__}') + + return value diff --git a/izzylib/sql2/database.py b/izzylib/sql2/database.py new file mode 100644 index 0000000..4b249a2 --- /dev/null +++ b/izzylib/sql2/database.py @@ -0,0 +1,244 @@ +import itertools + +from .config import Config +from .row import Row +from .table import DbTables +from .types import Types + +from .. import izzylog +from ..dotdict import DotDict +from ..exceptions import MaxConnectionsError, NoTableLayoutError, NoConnectionError +from ..path import Path + + +class Database: + def __init__(self, autoconnect=True, app=None, **kwargs): + tables = kwargs.pop('tables', None) + + self.cfg = Config(**kwargs) + self.tables = DbTables(self) + self.types = Types(self) + self.connections = [] + self.app = app + + if tables: + self.load_tables(tables) + + if autoconnect: + self.connect() + + + def connect(self): + for _ in itertools.repeat(None, self.cfg.minconnections): + self.get_connection() + + + def disconnect(self): + for conn in self.connections: + conn.disconnect() + + self.connections = [] + + + @property + def session(self): + return self.get_connection().session + + + def new_connection(self): + if len(self.connections) >= self.cfg.maxconnections: + raise MaxConnectionsError('Too many connections') + + conn = Connection(self) + conn.connect() + self.connections.append(conn) + + return conn + + + def close_connection(self, conn): + print('close connection') + conn.close_sessions() + conn.disconnect() + + if not conn.conn: + try: self.connections.remove(conn) + except: pass + + + def get_connection(self): + if not len(self.connections): + return self.new_connection() + + if len(self.connections) < self.cfg.minconnections: + return self.new_connection() + + for conn in self.connections: + if not len(conn.sessions): + return conn + + if len(self.connections) < self.cfg.maxconnections: + return self.new_connection() + + conns = {(conn, len(conn.sessions)) for conn in self.connections} + return min(conns, key=lambda x: x[1])[0] + + + def new_predb(self, database='postgres'): + dbconfig = Config(**self.cfg) + dbconfig['database'] = database + dbconfig['autocommit'] = True + + return Database(**dbconfig) + + + def set_row_class(self, name, row_class): + if not issubclass(row_class, Row): + raise TypeError(f'Row classes must be izzylib.sql2.row.Row, not {row_class.__name__}') + + self.cfg.row_classes[name] = row_class + + + def get_row_class(self, name): + return self.cfg.row_classes.get(name, Row) + + + def load_tables(self, tables=None): + if tables: + self.tables.load_tables(tables) + + else: + with self.session as s: + self.tables.load_tables(s.table_layout()) + + def create_tables(self): + if self.tables.empty: + raise NoTableLayoutError('Table layout not loaded yet') + + with self.session as s: + for table in self.tables.names: + s.execute(self.tables.compile_table(table)) + + + def create_database(self): + if self.cfg.type == 'postgresql': + with self.new_predb().session as s: + if not s.raw_execute('SELECT datname FROM pg_database WHERE datname = ?', [self.cfg.database]).fetchone(): + s.raw_execute(f'CREATE DATABASE {self.cfg.database}') + + elif self.cfg.type != 'sqlite': + raise NotImplementedError(f'Database type not supported yet: {self.cfg.type}') + + self.create_tables(tables) + + + def drop_database(self, database): + if self.cfg.type == 'sqlite': + izzylog.verbose('drop_database not needed for SQLite') + return + + with self.session as s: + if self.cfg.type == 'postgresql': + s.raw_execute(f'DROP DATABASE {database}') + + else: + raise NotImplementedError(f'Database type not supported yet: {self.cfg.type}') + + +class Connection: + def __init__(self, db): + self.db = db + self.cfg = db.cfg + self.sessions = [] + self.conn = None + + self.connect() + + if db.tables.empty: + with self.session as s: + db.load_tables(s.table_layout()) + + + @property + def autocommit(self): + return self.conn.autocommit + + + @property + def session(self): + return self.cfg.session_class(self) + + + def connect(self): + if self.conn: + return + + dbconfig = self.cfg.dbargs + + if self.cfg.type == 'sqlite': + if self.cfg.autocommit: + self.conn = self.cfg.module.connect(dbconfig['database'], isolation_level=None) + + else: + self.conn = self.cfg.module.connect(dbconfig['database']) + + elif self.cfg.type == 'postgresql': + if Path(self.cfg.host).exists(): + dbconfig['unix_sock'] = dbconfig.pop('host') + + dbconfig['user'] = dbconfig.pop('username') + dbconfig['application_name'] = self.cfg.appname + self.conn = self.cfg.module.connect(**dbconfig) + + else: + self.conn = self.cfg.module.connect(**self.cfg.dbargs) + + try: + self.conn.autocommit = self.cfg.autocommit + + except AttributeError: + if self.cfg.module_name not in ['sqlite']: + izzylog.verbose('Module does not support autocommit:', self.cfg.module_name) + + return self.conn + + + def disconnect(self): + if not self.conn: + return + + self.close_sessions() + self.conn.close() + self.conn = None + + + def close_sessions(self): + for session in self.sessions: + self.close_session(session) + + + def close_session(session): + try: self.sessions.remove(session) + except: pass + + session.close() + + if not len(self.sessions) and len(self.db.connections) > self.cfg.minconnections: + self.disconnect() + + + def cursor(self): + if not self.conn: + raise + return self.conn.cursor() + + + def dump_database(self, path='database.sql'): + if self.cfg.type == 'sqlite': + path = Path(path) + + with path.open('w') as fd: + fd.write('\n\n'.join(list(self.conn.iterdump())[1:-1])) + + else: + raise NotImplementedError('Only SQLite supported atm :/') diff --git a/izzylib/sql2/result.py b/izzylib/sql2/result.py new file mode 100644 index 0000000..2ae2c1e --- /dev/null +++ b/izzylib/sql2/result.py @@ -0,0 +1,68 @@ +from .row import Row + + +class Result: + def __init__(self, session): + self.table = None + self.session = session + self.cursor = session.cursor + + try: + self.keys = [desc[0] for desc in session.cursor.description] + + except TypeError: + self.keys = [] + + + def __iter__(self): + yield from self.all_iter() + + + @property + def row_class(self): + return self.session.db.get_row_class(self.table) + + + @property + def last_row_id(self): + if self.session.cfg.type == 'postgresql': + try: + return self.one().id + + except: + return None + + return self.cursor.lastrowid + + + @property + def row_count(self): + return self.cursor.rowcount + + + def set_table(self, table): + self.table = table + + + def one(self): + data = self.cursor.fetchone() + + if not data: + return + + return self.row_class( + self.session, + self.table, + {self.keys[idx]: value for idx, value in enumerate(data)}, + ) + + + def all(self): + return [row for row in self.all_iter()] + + + def all_iter(self): + for row in self.cursor: + yield self.row_class(self.session, self.table, + {self.keys[idx]: value for idx, value in enumerate(row)} + ) diff --git a/izzylib/sql2/row.py b/izzylib/sql2/row.py new file mode 100644 index 0000000..b9e42d4 --- /dev/null +++ b/izzylib/sql2/row.py @@ -0,0 +1,29 @@ +from ..dotdict import DotDict + + +class Row(DotDict): + def __init__(self, session, table, data): + super().__init__(session._parse_data('serialize', table, data)) + + self._table = table + self._session = session + self.__run__(session) + + + def __run__(self, session): + pass + + + @property + def table(self): + return self._table + + + @property + def rowid(self): + return self.id + + + @property + def rowid2(self): + return self.get('rowid', self.id) diff --git a/izzylib/sql2/session.py b/izzylib/sql2/session.py new file mode 100644 index 0000000..0367fe5 --- /dev/null +++ b/izzylib/sql2/session.py @@ -0,0 +1,346 @@ +import json + +from pathlib import Path as PyPath + +from .result import Result +from .row import Row +from .statements import Select, Insert, Delete, Count, Update, Statement +from .table import SessionTables + +from .. import izzylog +from ..dotdict import DotDict +from ..exceptions import NoTransactionError, UpdateAllRowsError +from ..misc import boolean, random_gen +from ..path import Path + + +class Session: + def __init__(self, conn): + self.db = conn.db + self.cfg = conn.db.cfg + self.conn = conn + self.sid = random_gen() + self.tables = SessionTables(self) + + self.cursor = conn.cursor() + self.trans = False + + self.__setup__() + + + def __enter__(self): + return self + + + def __exit__(self, exctype, excvalue, traceback): + if traceback: + self.rollback() + + else: + self.commit() + + + def __setup__(self): + pass + + + def close(): + if not self.cursor: + return + + self.conn.close_session(self) + + + def _parse_data(self, action, table, kwargs): + data = {} + + if self.db.tables: + for key, value in kwargs.items(): + try: + coltype = self.db.tables[table][key].type + + except KeyError: + data[key] = value + continue + + parser = self.db.types.get_type(coltype) + + try: + data[key] = parser(action, self.cfg.type, value) + except Exception as e: + izzylog.error(f'Failed to parse data from the table "{table}": {key} = {value}') + izzylog.debug(f'Parser: {parser}, Type: {coltype}') + raise e from None + + else: + data = kwargs + + return data + + + def dump_database(self, path): + import sqlparse + path = Path(path) + + with path.open('w') as fd: + line = '\n\n'.join(list(self.conn.iterdump())[1:-1]) + fd.write(sqlparse.format(line, + reindent = False, + keyword_case = 'upper', + )) + + def dump_database2(self, path): + path = Path(path) + + with path.open('w') as fd: + fd.write('\n\n'.join(list(self.conn.iterdump())[1:-1])) + + + def begin(self): + if self.trans or self.cfg.autocommit: + return + + self.execute('BEGIN') + self.trans = True + + + def commit(self): + if not self.trans: + return + + self.execute('COMMIT') + self.trans = False + + + def rollback(self): + if not self.trans: + return + + self.execute('ROLLBACK') + self.trans = False + + + def raw_execute(self, string, values=None): + if type(string) == Path: + string = string.read() + + elif type(string) == PyPath: + with string.open() as fd: + string = fd.read() + + if values: + self.cursor.execute(string, values) + + else: + self.cursor.execute(string) + + return self.cursor + + + def execute(self, string, *values): + if isinstance(string, Statement): + raise TypeError('String must be a str not a Statement') + + action = string.split()[0].upper() + + if not self.trans and action in ['CREATE', 'INSERT', 'UPDATE', 'UPSERT', 'DROP', 'DELETE', 'ALTER']: + if self.cfg.auto_trans: + self.begin() + + else: + raise NoTransactionError(f'Command not supported outside a transaction: {action}') + + try: + self.raw_execute(string, values) + + except Exception as e: + if type(e).__name__ in ['DatabaseError', 'OperationalError']: + print(string, values) + + raise e from None + + return Result(self) + + + def run(self, query): + result = self.execute(query.compile(self.cfg.type), *query.values) + + if type(query) == Count: + return list(result.one().values())[0] + + result.set_table(query.table) + return result + + + def run_count(self, query): + return list(self.run(query).one().values())[0] + + + def count(self, table, **kwargs): + if self.db.tables and table not in self.db.tables: + raise KeyError(f'Table does not exist: {table}') + + query = Count(table, **kwargs) + + return self.run_count(query) + + + def fetch(self, table, orderby=None, orderdir='ASC', limit=None, offset=None, **kwargs): + if self.db.tables and table not in self.db.tables: + raise KeyError(f'Table does not exist: {table}') + + query = Select(table, **kwargs) + + if orderby: + query.order(orderby, orderdir) + + if limit: + query.limit(limist) + + if offset: + query.offset(offset) + + return self.run(query) + + + def insert(self, table, return_row=False, **kwargs): + if self.db.tables and table not in self.db.tables: + raise KeyError(f'Table does not exist: {table}') + + result = self.run(Insert(table, **self._parse_data('deserialize', table, kwargs))) + + if return_row: + return self.fetch(table, id=result.last_row_id).one() + + return result.last_row_id + + + def update(self, table, data, return_row=False, **kwargs): + query = Update(table, **data) + + for pair in kwargs.items(): + query.where(*pair) + + if not query._where: + raise UpdateAllRowsError(f'Refusing to update all rows in table: {table}') + + result = self.run(query) + + if return_row: + return self.fetch(table, id=result.last_row_id).one() + + else: + return result + + + def update_row(self, row, return_row=False, **kwargs): + return self.update(row.table, kwargs, id=row.id, return_row=return_row) + + + def remove(self, table, **kwargs): + if self.db.tables and table not in self.db.tables: + raise KeyError(f'Table does not exist: {table}') + + self.run(Delete(table, self._parse_data('deserialize', table, kwargs))) + + + def remove_row(self, row): + if not row.table: + raise ValueError('Row not associated with a table') + + self.remove(row.table, id=row.id) + + + def create_tables(self, tables=None): + if tables: + self.load_tables(tables) + + if not self.tables: + raise NoTableLayoutError('No table layout available') + + for table in self.tables.values(): + self.execute(table.compile(self.cfg.type)) + + + def table_layout(self): + tables = {} + + if self.cfg.type == 'sqlite': + rows = self.execute("SELECT name, sql FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%'") + + for row in rows: + name = row.name + tables[name] = {} + fkeys = {fkey['from']: f'{fkey.table}.{fkey["to"]}' for fkey in self.execute(f'PRAGMA foreign_key_list({name})')} + columns = [col for col in self.execute(f'PRAGMA table_info({name})')] + + unique_list = parse_unique(row.sql) + + for column in columns: + tables[name][column.name] = dict( + type = column.type.upper(), + nullable = not column.notnull, + default = parse_default(column.dflt_value), + primary_key = bool(column.pk), + foreign_key = fkeys.get(column.name), + unique = column.name in unique_list + ) + + elif self.cfg.type == 'postgresql': + for row in self.execute("SELECT * FROM information_schema.columns WHERE table_schema not in ('information_schema', 'pg_catalog') ORDER BY table_schema, table_name, ordinal_position"): + table = row.table_name + column = row.column_name + + if not tables.get(table): + tables[table] = {} + + if not tables[table].get(column): + tables[table][column] = {} + + tables[table][column] = dict( + type = row.data_type.upper(), + nullable = boolean(row.is_nullable), + default = row.column_default if row.column_default and not row.column_default.startswith('nextval') else None, + primary_key = None, + foreign_key = None, + unique = None + ) + + return tables + + +def parse_unique(sql): + unique_list = [] + + try: + for raw_line in sql.splitlines(): + if 'UNIQUE' not in raw_line: + continue + + for line in raw_line.replace('UNIQUE', '').replace('(', '').replace(')', '').split(','): + line = line.strip() + + if line: + unique_list.append(line) + + except IndexError: + pass + + return unique_list + + +def parse_default(value): + if value == None: + return + + if value.startswith("'") and value.endswith("'"): + value = value[1:-1] + + else: + try: + value = int(value) + + except ValueError: + pass + + return value diff --git a/izzylib/sql2/statements.py b/izzylib/sql2/statements.py new file mode 100644 index 0000000..353bc54 --- /dev/null +++ b/izzylib/sql2/statements.py @@ -0,0 +1,185 @@ +from ..dotdict import DotDict + + +Comparison = DotDict( + LESS = lambda key: f'{key} < ?', + GREATER = lambda key: f'{key} > ?', + LESS_EQUAL = lambda key: f'{key} <= ?', + GREATER_EQUAL = lambda key: f'{key} >= ?', + EQUAL = lambda key: f'{key} = ?', + NOT_EQUAL = lambda key: f'{key} != ?', + IN = lambda key: f'{key} IN (?)', + NOT_IN = lambda key: f'{key} NOT IN (?)', + LIKE = lambda key: f'{key} LIKE ?', + NOT_LIKE = lambda key: f'{key} NOT LIKE ?' +) + + +class Statement: + def __init__(self, table): + self.table = table + self.values = [] + + self._where = '' + self._order = None + self._limit = None + self._offset = None + + + def __str__(self): + return self.compile('sqlite') + + + def where(self, key, value, comparison='equal', operator='and'): + try: + comp = Comparison[comparison.upper().replace('-', '_')] + + except KeyError: + raise KeyError(f'Invalid comparison: {comparison}') + + prefix = f' {operator} ' if self._where else ' ' + + self._where += f'{prefix}{comp(key)}' + self.values.append(value) + + return self + + + def order(self, column, direction='ASC'): + direction = direction.upper() + assert direction in ['ASC', 'DESC'] + self._order = (column, direction) + return self + + + def limit(self, limit_num): + self._limit = int(limit_num) + return self + + + def offset(self, offset_num): + self._offset = int(offset_num) + return self + + + def compile(self, dbtype): + raise NotImplementedError('Do not use the Statement class directly.') + + +class Select(Statement): + def __init__(self, table, *columns, **kwargs): + super().__init__(table) + + self.columns = columns + + for key, value in kwargs.items(): + self.where(key, value) + + + def compile(self, dbtype): + data = f'SELECT' + + if self.columns: + columns = ','.join(self.columns) + + else: + columns = '*' + + data += f' {columns} FROM {self.table}' + + if self._where: + data += f' WHERE {self._where}' + + if self._order: + col, direc = self._order + data += f' ORDER BY {col} {direc}' + + if self._limit: + data += f' LIMIT {self._limit}' + + if self._offset: + data += f' OFFSET {self._offset}' + + return data + + +class Insert(Statement): + def __init__(self, table, **kwargs): + super().__init__(table) + + self.keys = [] + + for pair in kwargs.items(): + self.add_data(*pair) + + + def add_data(self, key, value): + self.keys.append(key) + self.values.append(value) + + + def remove_data(self, key): + index = self.keys.index(key) + + del self.keys[index] + del self.values[index] + + + def compile(self, dbtype): + keys = ','.join(self.keys) + values = ','.join('?' for value in self.values) + data = f'INSERT INTO {self.table} ({keys}) VALUES ({values})' + + if dbtype == 'postgresql': + data += f' RETURNING id' + + return data + + +class Update(Statement): + def __init__(self, table, **kwargs): + super().__init__(table) + self.keys = [] + + for key, value in kwargs.items(): + self.keys.append(key) + self.values.append(value) + + + def compile(self, dbtype): + pairs = ','.join(f'{key} = ?' for key in self.keys) + data = f'UPDATE {self.table} SET {pairs} WHERE {self._where}' + + if dbtype == 'postgresql': + data += f' RETURNING id' + + return data + + +class Delete(Statement): + def __init__(self, table, **kwargs): + super().__init__(table) + + for key, value in kwargs.items(): + self.where(key, value) + + + def compile(self, dbtype): + return f'DELETE FROM {self.table} WHERE {self._where}' + + +class Count(Statement): + def __init__(self, table, **kwargs): + super().__init__(table) + + for key, value in kwargs.items(): + self.where(key, value) + + + def compile(self, dbtype): + data = f'SELECT COUNT(*) FROM {self.table}' + + if self._where: + data += f' WHERE {self._where}' + + return data diff --git a/izzylib/sql2/table.py b/izzylib/sql2/table.py new file mode 100644 index 0000000..f9b577e --- /dev/null +++ b/izzylib/sql2/table.py @@ -0,0 +1,244 @@ +from ..dotdict import DotDict + + +class SessionTables: + def __init__(self, session): + self._session = session + self._db = session.db + self._tables = session.db.tables + + + def __getattr__(self, key): + return SessionTable(session, key, self._tables[key]) + + + def names(self): + return tuple(self._tables.keys()) + + +class SessionTable(DotDict): + def __init__(self, session, name, columns): + super().__init__(columns) + + self._name = name + self._session = session + self._db = session.db + + + @property + def name(self): + return self._name + + + @property + def columns(self): + return tuple(self.keys()) + + + def fetch(self, **kwargs): + self._check_columns(**kwargs) + return self.session.fetch(self.name, **kwargs) + + + def insert(self, **kwargs): + self._check_columns(**kwargs) + return self.session.insert(self.name, **kwargs) + + + def remove(self, **kwargs): + self._check_columns(**kwargs) + return self.session.remove(self.name, **kwargs) + + + def _check_columns(self, **kwargs): + for key in kwargs.keys(): + if key not in self.columns: + raise KeyError(f'Not a column for table "{self.name}": {key}') + + +class DbTables(DotDict): + def __init__(self, db): + super().__init__() + + self._db = db + self._cfg = db.cfg + + + @property + def empty(self): + return not len(self.keys()) + + + @property + def names(self): + return tuple(self.keys()) + + + def load_tables(self, tables): + for name, columns in tables.items(): + self.add_table(name, columns) + + + def unload_tables(self): + for key in self.names: + del self[key] + + + def add_table(self, name, columns): + self[name] = {} + + if type(columns) == list: + self[name] = {col.name: col for col in columns} + + elif isinstance(columns, dict): + for column, data in columns.items(): + self[name][column] = DbColumn(self._cfg.type, column, **data) + + else: + raise TypeError('Columns must be a list of Column objects or a dict') + + + def remove_table(self, name): + return self.pop(name) + + + def get_columns(self, name): + return tuple(self[name].values()) + + + def compile_table(self, table_name, dbtype): + table = self[table_name] + columns = [] + foreign_keys = [] + + for column in self.get_columns(table_name): + columns.append(column.compile(dbtype)) + + if column.foreign_key: + fkey_table, fkey_col = column.foreign_key + foreign_keys.append(f'FOREIGN KEY ({column.name}) REFERENCES {fkey_table} ({fkey_col})') + + return f'CREATE TABLE IF NOT EXISTS {self.name} ({",".join(columns)}{",".join(foreign_keys)})' + + + def compile_all(self, dbtype): + return [self.compile_table(name, dbtype) for name in self.keys()] + + +class DbColumn(DotDict): + def __init__(self, dbtype, name, type=None, default=None, primary_key=False, unique=False, nullable=True, autoincrement=False, foreign_key=None): + super().__init__( + name = name, + type = type, + default = default, + primary_key = primary_key, + unique = unique, + nullable = nullable, + autoincrement = autoincrement, + foreign_key = foreign_key + ) + + if self.name == 'id': + if dbtype == 'sqlite': + self.type = 'INTEGER' + self.autoincrement = True + + elif dbtype == 'postgresql': + self.type = 'SERIAL' + self.autoincrement = False + + self.primary_key = True + self.unique = False + self.nullable = False + self.default = None + self.foreign_key = None + + elif self.name in ['created', 'modified', 'accessed'] and not self.type: + self.type = 'DATETIME' + + if not self.type: + raise ValueError(f'Must provide a column type for column: {name}') + + try: + self.fkey + + except ValueError: + raise ValueError(f'Invalid foreign_key format. Must be "table.column"') + + + @property + def fkey(self): + try: + return self.foreign_key.split('.') + + except AttributeError: + return + + + def compile(self, dbtype): + line = f'{self.name} {self.type}' + + if self.primary_key: + line += ' PRIMARY KEY' + + if not self.nullable: + line += ' NOT NULL' + + if self.unique: + line += ' UNIQUE' + + if self.autoincrement and dbtype != 'postgresql': + line += ' AUTOINCREMENT' + + if self.default: + line += f" DEFAULT {parse_default(self.default)}" + + return line + + +class Column(DotDict): + def __init__(self, name, type=None, default=None, primary_key=False, unique=False, nullable=True, autoincrement=False, foreign_key=None): + super().__init__( + name = name, + type = type.upper() if type else None, + default = default, + primary_key = primary_key, + unique = unique, + nullable = nullable, + autoincrement = autoincrement, + foreign_key = foreign_key + ) + + if self.name == 'id': + self.type = 'SERIAL' + + elif self.name in ['created', 'modified', 'accessed'] and not self.type: + self.type = 'DATETIME' + + if not self.type: + raise ValueError(f'Must provide a column type for column: {name}') + + try: + self.fkey + + except ValueError: + raise ValueError(f'Invalid foreign_key format. Must be "table.column"') + + + @property + def fkey(self): + try: + return self.foreign_key.split('.') + + except AttributeError: + return + + +def parse_default(default): + if isinstance(default, dict) or isinstance(default, list): + default = json.dumps(default) + + if type(default) == str: + default = f"'{default}'" + + return default diff --git a/izzylib/sql2/types.py b/izzylib/sql2/types.py new file mode 100644 index 0000000..2d4ed11 --- /dev/null +++ b/izzylib/sql2/types.py @@ -0,0 +1,219 @@ +from datetime import date, time, datetime + +from .. import izzylog +from ..dotdict import DotDict, LowerDotDict + + +Standard = { + 'INTEGER', + 'INT', + 'TINYINT', + 'SMALLINT', + 'MEDIUMINT', + 'BIGINT', + 'UNSIGNED BIG INT', + 'INT2', + 'INT8', + 'TEXT', + 'CHARACTER', + 'CHAR', + 'VARCHAR', + 'BLOB', + 'CLOB', + 'REAL', + 'DOUBLE', + 'DOUBLE PRECISION', + 'FLOAT', + 'NUMERIC', + 'DEC', + 'DECIMAL', + 'BOOLEAN', + 'DATE', + 'TIME', + 'JSON' +} + + +Sqlite = { + *Standard, + 'DATETIME' +} + + +Postgresql = { + *Standard, + 'SMALLSERIAL', + 'SERIAL', + 'BIGSERIAL', + 'VARYING', + 'BYTEA', + 'TIMESTAMP', + 'INTERVAL', + 'POINT', + 'LINE', + 'LSEG', + 'BOX', + 'PATH', + 'POLYGON', + 'CIRCLE', +} + + +Mysql = { + *Standard, + 'FIXED', + 'BIT', + 'YEAR', + 'VARBINARY', + 'ENUM', + 'SET' +} + + +class Type: + sqlite = None + postgresql = None + mysql = None + + + def __getitem__(self, key): + if key in ['sqlite', 'postgresql', 'mysql']: + return getattr(self, key) + + raise KeyError(f'Invalid database type: {key}') + + + def __call__(self, action, dbtype, value): + return getattr(self, action)(dbtype, value) + + + def name(self, dbtype='sqlite'): + return self[dbtype] + + + def serialize(self, dbtype, value): + return value + + + def deserialize(self, dbtype, value): + return value + + +class Json(Type): + sqlite = 'JSON' + postgresql = 'JSON' + mysql = 'JSON' + + + def serialize(self, dbtype, value): + izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value) + + if type(value) == str: + return DotDict(value) + + return value + + + def deserialize(self, dbtype, value): + izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value) + + return DotDict(value).to_json() + + +class Datetime(Type): + sqlite = 'DATETIME' + postgresql = 'TIMESTAMP' + mysql = 'DATETIME' + + + def serialize(self, dbtype, value): + izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value) + + if type(value) == str: + return datetime.fromisoformat(value) + + elif type(value) == int: + return datetime.fromtimestamp(value) + + return value + + + def deserialize(self, dbtype, value): + izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value) + + if dbtype == 'sqlite': + return value.isoformat() + + return value + + +class Date(Type): + sqlite = 'DATE' + postgresql = 'DATE' + mysql = 'DATE' + + + def serialize(self, dbtype, value): + izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value) + + if type(value) == str: + return date.fromisoformat(value) + + elif type(value) == int: + return date.fromtimestamp(value) + + + return value + + + def deserialize(self, dbtype, value): + izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value) + + if dbtype == 'sqlite': + return value.isoformat() + + return value + + +class Time(Type): + sqlite = 'TIME' + postgresql = 'TIME' + mysql = 'TIME' + + + def serialize(self, dbtype, value): + izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value) + + if type(value) == str: + return time.fromisoformat(value) + + elif type(value) == int: + return time.fromtimestamp(value) + + return value + + + def deserialize(self, dbtype, value): + izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value) + + if dbtype == 'sqlite': + return value.isoformat() + + return value + + +class Types(DotDict): + def __init__(self, db): + self._db = db + + self.set_type(Json, Date, Time, Datetime) + + + def get_type(self, name): + return self.get(name.upper(), Type()) + + + def set_type(self, *types): + for type_object in types: + typeclass = type_object() + self[typeclass.name(self._db.cfg.type)] = typeclass diff --git a/setup.cfg b/setup.cfg index 864d7a1..007d66e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,6 +33,7 @@ packages = izzylib.http_server_async izzylib.http_urllib_client izzylib.sql + izzylib.sql2 setup_requires = setuptools >= 38.3.0 @@ -59,6 +60,8 @@ http_urllib_client = sql = SQLAlchemy == 1.4.23 SQLAlchemy-Paginator == 0.2 +sql2 = + sql-metadata == 2.3.0 template = colour == 0.1.5 Hamlish-Jinja == 0.3.3