http_server_async: add transport class, sql2: new sub-module
This commit is contained in:
parent
3e56885a49
commit
1e62836754
|
@ -1,4 +1,4 @@
|
|||
import json, mimetypes
|
||||
import json, mimetypes, traceback
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
|
@ -30,26 +30,45 @@ url_keys = [
|
|||
]
|
||||
|
||||
|
||||
def parse_privacy_level(to: list=[], cc: list=[]):
|
||||
if to == [pubstr] and len(cc) == 1:
|
||||
def parse_privacy_level(to: list=[], cc: list=[], followers=None):
|
||||
if pubstr in to and followers in cc:
|
||||
return 'public'
|
||||
|
||||
elif to and self.actor in to[0] and not cc:
|
||||
elif followers in to and pubstr in cc:
|
||||
return 'unlisted'
|
||||
|
||||
elif pubstr not in to:
|
||||
elif pubstr not in to and pubstr not in cc and followers in cc:
|
||||
return 'private'
|
||||
|
||||
elif not tuple(item for item in [*to, *cc] if item not in [pubstr, followers]):
|
||||
return 'direct'
|
||||
|
||||
else:
|
||||
logging.warning('Not sure what this privacy level is')
|
||||
logging.debug(f'to: {json.dumps(to)}')
|
||||
logging.debug(f'cc: {json.dumps(cc)}')
|
||||
logging.debug(f'followers: {followers}')
|
||||
|
||||
|
||||
def generate_privacy_fields(privacy='public'):
|
||||
def generate_privacy_fields(privacy='public', followers=None, to=[], cc=[]):
|
||||
if privacy == 'public':
|
||||
return ([pubstr])
|
||||
to = [pubstr, *to]
|
||||
cc = [followers, *to]
|
||||
|
||||
elif privacy == 'unlisted':
|
||||
to = [followers, *to]
|
||||
cc = [pubstr, *to]
|
||||
|
||||
elif privacy == 'private':
|
||||
cc = [followers, *cc]
|
||||
|
||||
elif privacy == 'direct':
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ValueError(f'Unknown privacy level: {privacy}')
|
||||
|
||||
return to, cc
|
||||
|
||||
class Object(DotDict):
|
||||
def __setitem__(self, key, value):
|
||||
|
@ -384,7 +403,11 @@ class Object(DotDict):
|
|||
|
||||
@property
|
||||
def privacy_level(self):
|
||||
return parse_privacy_level(self.get('to', []), self.get('cc', []))
|
||||
return parse_privacy_level(
|
||||
self.get('to', []),
|
||||
self.get('cc', []),
|
||||
self.get('attributedTo', '') + '/followers'
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
|
@ -416,7 +439,7 @@ class Object(DotDict):
|
|||
|
||||
@property
|
||||
def info_table(self):
|
||||
return DotDict({p.name: p.value for p in self.get('attachment', {})})
|
||||
return DotDict({p['name']: p['value'] for p in self.get('attachment', {})})
|
||||
|
||||
|
||||
@property
|
||||
|
@ -429,6 +452,16 @@ class Object(DotDict):
|
|||
return self.get('summary')
|
||||
|
||||
|
||||
@property
|
||||
def avatar(self):
|
||||
return self.icon.url
|
||||
|
||||
|
||||
@property
|
||||
def header(self):
|
||||
return self.image.url
|
||||
|
||||
|
||||
class Collection(Object):
|
||||
@classmethod
|
||||
def new_replies(cls, statusid):
|
||||
|
|
|
@ -24,3 +24,23 @@ class MethodNotHandledException(Exception):
|
|||
|
||||
class NoBlueprintForPath(Exception):
|
||||
'raise when no blueprint is found for a specific path'
|
||||
|
||||
|
||||
class NoConnectionError(Exception):
|
||||
'Raise when a function requiring a connection gets called when there is no connection'
|
||||
|
||||
|
||||
class MaxConnectionsError(Exception):
|
||||
'Raise when the max amount of connections has been reached'
|
||||
|
||||
|
||||
class NoTransactionError(Exception):
|
||||
'Raise when trying to execute an SQL write statement outside a transaction'
|
||||
|
||||
|
||||
class NoTableLayoutError(Exception):
|
||||
'Raise when a table layout is necessary, but not loaded'
|
||||
|
||||
|
||||
class UpdateAllRowsError(Exception):
|
||||
'Raise when an UPDATE tries to modify all rows in a table'
|
||||
|
|
|
@ -17,7 +17,6 @@ def create_app(appname, **kwargs):
|
|||
|
||||
from .application import Application, Blueprint
|
||||
from .middleware import MediaCacheControl
|
||||
from .misc import Cookies, Headers
|
||||
from .request import Request
|
||||
from .response import Response
|
||||
from .view import View, Static
|
||||
|
|
|
@ -8,6 +8,7 @@ from .config import Config
|
|||
from .response import Response
|
||||
#from .router import Router
|
||||
from .view import Static, Manifest, Robots, Style
|
||||
from .transport import Transport
|
||||
|
||||
from .. import logging
|
||||
from ..dotdict import DotDict
|
||||
|
@ -16,7 +17,7 @@ from ..misc import signal_handler
|
|||
from ..path import Path
|
||||
|
||||
try:
|
||||
from ..sql import Database
|
||||
from ..sql2 import Database
|
||||
except ImportError:
|
||||
Database = NotImplementedError('Failed to import SQL database class')
|
||||
|
||||
|
@ -52,7 +53,7 @@ class ApplicationBase:
|
|||
if isinstance(Database, Exception):
|
||||
raise Database from None
|
||||
|
||||
self.db = dbclass(dbtype, **dbargs)
|
||||
self.db = dbclass(dbtype, **dbargs, app=self)
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
|
@ -320,13 +321,20 @@ class Application(ApplicationBase):
|
|||
|
||||
|
||||
async def handle_client(self, reader, writer):
|
||||
transport = Transport(self, reader, writer)
|
||||
request = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
request = self.cfg.request_class(self, reader, writer.get_extra_info('peername')[0])
|
||||
request = self.cfg.request_class(self, transport)
|
||||
response = self.cfg.response_class(request=request)
|
||||
await request.parse_headers()
|
||||
|
||||
try:
|
||||
await request.parse_headers()
|
||||
|
||||
except asyncio.exceptions.IncompleteReadError as e:
|
||||
request = None
|
||||
raise e from None
|
||||
|
||||
try:
|
||||
# this doesn't work all the time for some reason
|
||||
|
@ -336,8 +344,8 @@ class Application(ApplicationBase):
|
|||
except NoBlueprintForPath:
|
||||
response = await self.handle_request(request, response)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
#except Exception as e:
|
||||
#traceback.print_exc()
|
||||
|
||||
except NotFound:
|
||||
response = self.cfg.response_class(request=request).set_error('Not Found', 404)
|
||||
|
@ -357,22 +365,22 @@ class Application(ApplicationBase):
|
|||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
## Don't use a custom response class here just in case it caused the error
|
||||
response = Response(request=request).set_error('Server Error', 500)
|
||||
if not response.streaming:
|
||||
## Don't use a custom response class here just in case it caused the error
|
||||
response = Response(request=request).set_error('Server Error', 500)
|
||||
|
||||
try:
|
||||
response.headers.update(self.cfg.default_headers)
|
||||
writer.write(response.compile())
|
||||
await writer.drain()
|
||||
if not response.streaming:
|
||||
try:
|
||||
response.headers.update(self.cfg.default_headers)
|
||||
await transport.write(response.compile())
|
||||
|
||||
if request and not request.path.startswith('/framework'):
|
||||
logging.info(f'{request.remote} {request.method} {request.path} {response.status} {len(response.body)} {request.agent}')
|
||||
if request and request.log and not request.path.startswith('/framework'):
|
||||
logging.info(f'{request.remote} {request.method} {request.path} {response.status} {len(response.body)} {request.agent}')
|
||||
|
||||
except:
|
||||
traceback.print_exc()
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
await transport.close()
|
||||
|
||||
|
||||
class Blueprint(ApplicationBase):
|
||||
|
|
|
@ -17,28 +17,28 @@ LocalTime = datetime.now(UtcTime).astimezone().tzinfo
|
|||
|
||||
class Request:
|
||||
__slots__ = [
|
||||
'_body', '_form', '_reader', '_method', '_app', '_params',
|
||||
'_body', '_form', '_method', '_app', '_params',
|
||||
'address', 'path', 'version', 'headers', 'cookies',
|
||||
'query', 'raw_query'
|
||||
'query', 'raw_query', 'transport', 'log'
|
||||
]
|
||||
|
||||
ctx = DotDict()
|
||||
|
||||
def __init__(self, app, reader, address):
|
||||
def __init__(self, app, transport):
|
||||
super().__init__()
|
||||
|
||||
self._app = app
|
||||
self._reader = reader
|
||||
self._body = b''
|
||||
self._form = DotDict()
|
||||
self._method = None
|
||||
self._params = None
|
||||
|
||||
self.transport = transport
|
||||
self.headers = Headers()
|
||||
self.cookies = Cookies()
|
||||
self.query = DotDict()
|
||||
|
||||
self.address = address
|
||||
self.address = transport.client_address
|
||||
self.path = None
|
||||
self.version = None
|
||||
self.raw_query = None
|
||||
|
@ -141,14 +141,9 @@ class Request:
|
|||
return self._params
|
||||
|
||||
|
||||
async def read(self, length=2048, timeout=None):
|
||||
try: return await asyncio.wait_for(self._reader.read(length), timeout or self.app.cfg.timeout)
|
||||
except: return
|
||||
|
||||
|
||||
async def body(self):
|
||||
if not self._body and self.length:
|
||||
self._body = await self.read(self.length)
|
||||
self._body = await self.transport.read(self.length)
|
||||
|
||||
return self._body
|
||||
|
||||
|
@ -178,7 +173,7 @@ class Request:
|
|||
|
||||
|
||||
async def parse_headers(self):
|
||||
data = (await self._reader.readuntil(b'\r\n\r\n')).decode('utf-8')
|
||||
data = (await self.transport.readuntil(b'\r\n\r\n')).decode('utf-8')
|
||||
|
||||
for idx, line in enumerate(data.splitlines()):
|
||||
if idx == 0:
|
||||
|
|
|
@ -64,6 +64,11 @@ class Response:
|
|||
return len(self.body)
|
||||
|
||||
|
||||
@property
|
||||
def streaming(self):
|
||||
return self.headers.getone('Transfer-Encoding') == 'chunked'
|
||||
|
||||
|
||||
def append(self, data):
|
||||
self._body += self._parse_body_data(data)
|
||||
|
||||
|
@ -137,7 +142,7 @@ class Response:
|
|||
return self
|
||||
|
||||
|
||||
def set_json(self, body={}, status=None, activity=False,):
|
||||
def set_json(self, body={}, status=None, activity=False):
|
||||
self.content_type = 'application/activity+json' if activity else 'application/json'
|
||||
self.body = body
|
||||
|
||||
|
@ -177,11 +182,23 @@ class Response:
|
|||
return self
|
||||
|
||||
|
||||
async def set_streaming(self, transport, headers={}):
|
||||
self.headers.update(headers)
|
||||
self.headers.update(transport.app.cfg.default_headers)
|
||||
self.headers.setall('Transfer-encoding', 'chunked')
|
||||
|
||||
transport.write(self._compile_headers())
|
||||
|
||||
|
||||
def set_cookie(self, key, value, **kwargs):
|
||||
self.cookies[key] = CookieItem(key, value, **kwargs)
|
||||
|
||||
|
||||
def compile(self):
|
||||
return self._compile_headers() + self.body
|
||||
|
||||
|
||||
def _compile_headers(self):
|
||||
data = bytes(f'HTTP/1.1 {self.status}', 'utf-8')
|
||||
|
||||
for k,v in self.headers.items():
|
||||
|
@ -198,9 +215,7 @@ class Response:
|
|||
data += bytes(f'\r\nDate: {datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")}', 'utf-8')
|
||||
|
||||
data += bytes(f'\r\nContent-Length: {len(self.body)}', 'utf-8')
|
||||
|
||||
data += b'\r\n\r\n'
|
||||
data += self.body
|
||||
|
||||
return data
|
||||
|
||||
|
|
65
izzylib/http_server_async/transport.py
Normal file
65
izzylib/http_server_async/transport.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
import asyncio
|
||||
|
||||
from ..dotdict import DotDict
|
||||
|
||||
|
||||
class Transport:
|
||||
def __init__(self, app, reader, writer):
|
||||
self.app = app
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
|
||||
|
||||
@property
|
||||
def client_address(self):
|
||||
return self.writer.get_extra_info('peername')[0]
|
||||
|
||||
|
||||
@property
|
||||
def client_port(self):
|
||||
return self.writer.get_extra_info('peername')[1]
|
||||
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
return self.writer.is_closing()
|
||||
|
||||
|
||||
async def read(self, length=2048, timeout=None):
|
||||
return await asyncio.wait_for(self.reader.read(length), timeout or self.app.cfg.timeout)
|
||||
|
||||
|
||||
async def readuntil(self, bytes, timeout=None):
|
||||
return await asyncio.wait_for(self.reader.readuntil(bytes), timeout or self.app.cfg.timeout)
|
||||
|
||||
|
||||
async def write(self, data):
|
||||
if isinstance(data, DotDict):
|
||||
data = data.to_json()
|
||||
|
||||
elif any(map(isinstance, [data], [dict, list, tuple])):
|
||||
data = json.dumps(data)
|
||||
|
||||
# not sure if there's a better type to use, but this should be fine for now
|
||||
elif any(map(isinstance, [data], [float, int])):
|
||||
data = str(data)
|
||||
|
||||
elif isinstance(data, bytearray):
|
||||
data = str(data)
|
||||
|
||||
elif not any(map(isinstance, [data], [bytes, str])):
|
||||
raise TypeError('Data must be or a str, bytes, bytearray, float, it, dict, list, or tuple')
|
||||
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
self.writer.write(data)
|
||||
await self.writer.drain()
|
||||
|
||||
|
||||
async def close(self):
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
|
@ -24,6 +24,7 @@ class Row(DotDict):
|
|||
except:
|
||||
self._update(row)
|
||||
|
||||
self.__session = session
|
||||
self.__db = session.db
|
||||
self.__table_name = table
|
||||
|
||||
|
@ -40,6 +41,11 @@ class Row(DotDict):
|
|||
return self.__table_name
|
||||
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.__session
|
||||
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self.keys()
|
||||
|
|
|
@ -91,23 +91,21 @@ class Session(sqlalchemy_session):
|
|||
|
||||
query = self.query(self.table[table]).filter_by(**kwargs)
|
||||
|
||||
if not orderby:
|
||||
rows = query.all()
|
||||
|
||||
else:
|
||||
if orderby:
|
||||
if orderdir == 'asc':
|
||||
rows = query.order_by(getattr(self.table[table].c, orderby).asc()).all()
|
||||
query = query.order_by(getattr(self.table[table].c, orderby).asc())
|
||||
|
||||
elif orderdir == 'desc':
|
||||
rows = query.order_by(getattr(self.table[table].c, orderby).desc()).all()
|
||||
query = query.order_by(getattr(self.table[table].c, orderby).desc())
|
||||
|
||||
else:
|
||||
raise ValueError(f'Unsupported order direction: {orderdir}')
|
||||
|
||||
if single:
|
||||
return RowClass(table, rows[0], self) if len(rows) > 0 else None
|
||||
row = query.first()
|
||||
return RowClass(table, row, self) if row else None
|
||||
|
||||
return [RowClass(table, row, self) for row in rows]
|
||||
return [RowClass(table, row, self) for row in query.all()]
|
||||
|
||||
|
||||
def search(self, *args, **kwargs):
|
||||
|
@ -131,10 +129,10 @@ class Session(sqlalchemy_session):
|
|||
if getattr(self.table[table], 'timestamp', None) and not kwargs.get('timestamp'):
|
||||
kwargs['timestamp'] = datetime.now()
|
||||
|
||||
self.execute(self.table[table].insert().values(**kwargs))
|
||||
cursor = self.execute(self.table[table].insert().values(**kwargs))
|
||||
|
||||
if return_row:
|
||||
return self.fetch(table, **kwargs)
|
||||
return self.fetch(table, id=cursor.inserted_primary_key[0])
|
||||
|
||||
|
||||
def update(self, table=None, rowid=None, row=None, return_row=False, **kwargs):
|
||||
|
@ -145,7 +143,7 @@ class Session(sqlalchemy_session):
|
|||
if not rowid or not table:
|
||||
raise ValueError('Missing row ID or table')
|
||||
|
||||
self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs))
|
||||
row = self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs))
|
||||
|
||||
if return_row:
|
||||
return self.fetch(table, id=rowid)
|
||||
|
|
6
izzylib/sql2/__init__.py
Normal file
6
izzylib/sql2/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from .database import Database, Connection
|
||||
from .result import Result
|
||||
from .row import Row
|
||||
from .session import Session
|
||||
from .statements import Comparison, Statement, Select, Insert, Update, Delete, Count
|
||||
from .table import Column
|
102
izzylib/sql2/config.py
Normal file
102
izzylib/sql2/config.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
import sqlite3
|
||||
|
||||
from getpass import getuser
|
||||
from importlib import import_module
|
||||
|
||||
from .result import Result
|
||||
from .row import Row
|
||||
from .session import Session
|
||||
|
||||
from ..config import BaseConfig
|
||||
from ..path import Path
|
||||
|
||||
|
||||
class Config(BaseConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(
|
||||
appname = 'IzzyLib SQL Client',
|
||||
type = 'sqlite',
|
||||
module = None,
|
||||
module_name = None,
|
||||
tables = {},
|
||||
row_classes = {},
|
||||
session_class = Session,
|
||||
result_class = Result,
|
||||
host = 'localhost',
|
||||
port = 0,
|
||||
database = None,
|
||||
username = getuser(),
|
||||
password = None,
|
||||
minconnections = 4,
|
||||
maxconnections = 25,
|
||||
engine_args = {},
|
||||
auto_trans = True,
|
||||
connect_function = None,
|
||||
autocommit = False
|
||||
)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
self[k] = v
|
||||
|
||||
if not self.database:
|
||||
if self.type == 'sqlite':
|
||||
self.database = ':memory:'
|
||||
|
||||
else:
|
||||
raise ValueError('Missing database name')
|
||||
|
||||
if not self.port:
|
||||
if self.type == 'postgresql':
|
||||
self.port = 5432
|
||||
|
||||
elif self.type == 'mysql':
|
||||
self.port = 3306
|
||||
|
||||
if not self.module and not self.connect_function:
|
||||
if self.type == 'sqlite':
|
||||
self.module = sqlite3
|
||||
self.module_name = 'sqlite3'
|
||||
|
||||
elif self.type == 'postgresql':
|
||||
for mod in ['pg8000.dbapi', 'pgdb', 'psycopg2']:
|
||||
try:
|
||||
self.module = import_module(mod)
|
||||
self.module_name = mod
|
||||
break
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
elif self.type == 'mysql':
|
||||
try:
|
||||
self.module = import_module('mysql.connector')
|
||||
self.module_name = 'mysql.connector'
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if not self.module:
|
||||
raise ImportError(f'Cannot find module for "{self.type}"')
|
||||
|
||||
self.module.paramstyle = 'qmark'
|
||||
|
||||
|
||||
@property
|
||||
def dbargs(self):
|
||||
return {key: self[key] for key in ['host', 'port', 'database', 'username', 'password']}
|
||||
|
||||
|
||||
def parse_value(self, key, value):
|
||||
if key == 'type':
|
||||
if value not in ['sqlite', 'postgresql', 'mysql', 'mssql']:
|
||||
raise ValueError(f'Invalid database type: {value}')
|
||||
|
||||
if key == 'port':
|
||||
if not isinstance(value, int):
|
||||
raise TypeError('Port is not an integer')
|
||||
|
||||
if key == 'row_classes':
|
||||
for row_class in value.values():
|
||||
if not issubclass(row_class, Row):
|
||||
raise TypeError(f'Row classes must be izzylib.sql2.row.Row, not {row_class.__name__}')
|
||||
|
||||
return value
|
244
izzylib/sql2/database.py
Normal file
244
izzylib/sql2/database.py
Normal file
|
@ -0,0 +1,244 @@
|
|||
import itertools
|
||||
|
||||
from .config import Config
|
||||
from .row import Row
|
||||
from .table import DbTables
|
||||
from .types import Types
|
||||
|
||||
from .. import izzylog
|
||||
from ..dotdict import DotDict
|
||||
from ..exceptions import MaxConnectionsError, NoTableLayoutError, NoConnectionError
|
||||
from ..path import Path
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, autoconnect=True, app=None, **kwargs):
|
||||
tables = kwargs.pop('tables', None)
|
||||
|
||||
self.cfg = Config(**kwargs)
|
||||
self.tables = DbTables(self)
|
||||
self.types = Types(self)
|
||||
self.connections = []
|
||||
self.app = app
|
||||
|
||||
if tables:
|
||||
self.load_tables(tables)
|
||||
|
||||
if autoconnect:
|
||||
self.connect()
|
||||
|
||||
|
||||
def connect(self):
|
||||
for _ in itertools.repeat(None, self.cfg.minconnections):
|
||||
self.get_connection()
|
||||
|
||||
|
||||
def disconnect(self):
|
||||
for conn in self.connections:
|
||||
conn.disconnect()
|
||||
|
||||
self.connections = []
|
||||
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.get_connection().session
|
||||
|
||||
|
||||
def new_connection(self):
|
||||
if len(self.connections) >= self.cfg.maxconnections:
|
||||
raise MaxConnectionsError('Too many connections')
|
||||
|
||||
conn = Connection(self)
|
||||
conn.connect()
|
||||
self.connections.append(conn)
|
||||
|
||||
return conn
|
||||
|
||||
|
||||
def close_connection(self, conn):
|
||||
print('close connection')
|
||||
conn.close_sessions()
|
||||
conn.disconnect()
|
||||
|
||||
if not conn.conn:
|
||||
try: self.connections.remove(conn)
|
||||
except: pass
|
||||
|
||||
|
||||
def get_connection(self):
|
||||
if not len(self.connections):
|
||||
return self.new_connection()
|
||||
|
||||
if len(self.connections) < self.cfg.minconnections:
|
||||
return self.new_connection()
|
||||
|
||||
for conn in self.connections:
|
||||
if not len(conn.sessions):
|
||||
return conn
|
||||
|
||||
if len(self.connections) < self.cfg.maxconnections:
|
||||
return self.new_connection()
|
||||
|
||||
conns = {(conn, len(conn.sessions)) for conn in self.connections}
|
||||
return min(conns, key=lambda x: x[1])[0]
|
||||
|
||||
|
||||
def new_predb(self, database='postgres'):
|
||||
dbconfig = Config(**self.cfg)
|
||||
dbconfig['database'] = database
|
||||
dbconfig['autocommit'] = True
|
||||
|
||||
return Database(**dbconfig)
|
||||
|
||||
|
||||
def set_row_class(self, name, row_class):
|
||||
if not issubclass(row_class, Row):
|
||||
raise TypeError(f'Row classes must be izzylib.sql2.row.Row, not {row_class.__name__}')
|
||||
|
||||
self.cfg.row_classes[name] = row_class
|
||||
|
||||
|
||||
def get_row_class(self, name):
|
||||
return self.cfg.row_classes.get(name, Row)
|
||||
|
||||
|
||||
def load_tables(self, tables=None):
|
||||
if tables:
|
||||
self.tables.load_tables(tables)
|
||||
|
||||
else:
|
||||
with self.session as s:
|
||||
self.tables.load_tables(s.table_layout())
|
||||
|
||||
def create_tables(self):
|
||||
if self.tables.empty:
|
||||
raise NoTableLayoutError('Table layout not loaded yet')
|
||||
|
||||
with self.session as s:
|
||||
for table in self.tables.names:
|
||||
s.execute(self.tables.compile_table(table))
|
||||
|
||||
|
||||
def create_database(self):
|
||||
if self.cfg.type == 'postgresql':
|
||||
with self.new_predb().session as s:
|
||||
if not s.raw_execute('SELECT datname FROM pg_database WHERE datname = ?', [self.cfg.database]).fetchone():
|
||||
s.raw_execute(f'CREATE DATABASE {self.cfg.database}')
|
||||
|
||||
elif self.cfg.type != 'sqlite':
|
||||
raise NotImplementedError(f'Database type not supported yet: {self.cfg.type}')
|
||||
|
||||
self.create_tables(tables)
|
||||
|
||||
|
||||
def drop_database(self, database):
|
||||
if self.cfg.type == 'sqlite':
|
||||
izzylog.verbose('drop_database not needed for SQLite')
|
||||
return
|
||||
|
||||
with self.session as s:
|
||||
if self.cfg.type == 'postgresql':
|
||||
s.raw_execute(f'DROP DATABASE {database}')
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Database type not supported yet: {self.cfg.type}')
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.cfg = db.cfg
|
||||
self.sessions = []
|
||||
self.conn = None
|
||||
|
||||
self.connect()
|
||||
|
||||
if db.tables.empty:
|
||||
with self.session as s:
|
||||
db.load_tables(s.table_layout())
|
||||
|
||||
|
||||
@property
|
||||
def autocommit(self):
|
||||
return self.conn.autocommit
|
||||
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.cfg.session_class(self)
|
||||
|
||||
|
||||
def connect(self):
|
||||
if self.conn:
|
||||
return
|
||||
|
||||
dbconfig = self.cfg.dbargs
|
||||
|
||||
if self.cfg.type == 'sqlite':
|
||||
if self.cfg.autocommit:
|
||||
self.conn = self.cfg.module.connect(dbconfig['database'], isolation_level=None)
|
||||
|
||||
else:
|
||||
self.conn = self.cfg.module.connect(dbconfig['database'])
|
||||
|
||||
elif self.cfg.type == 'postgresql':
|
||||
if Path(self.cfg.host).exists():
|
||||
dbconfig['unix_sock'] = dbconfig.pop('host')
|
||||
|
||||
dbconfig['user'] = dbconfig.pop('username')
|
||||
dbconfig['application_name'] = self.cfg.appname
|
||||
self.conn = self.cfg.module.connect(**dbconfig)
|
||||
|
||||
else:
|
||||
self.conn = self.cfg.module.connect(**self.cfg.dbargs)
|
||||
|
||||
try:
|
||||
self.conn.autocommit = self.cfg.autocommit
|
||||
|
||||
except AttributeError:
|
||||
if self.cfg.module_name not in ['sqlite']:
|
||||
izzylog.verbose('Module does not support autocommit:', self.cfg.module_name)
|
||||
|
||||
return self.conn
|
||||
|
||||
|
||||
def disconnect(self):
|
||||
if not self.conn:
|
||||
return
|
||||
|
||||
self.close_sessions()
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
|
||||
|
||||
def close_sessions(self):
|
||||
for session in self.sessions:
|
||||
self.close_session(session)
|
||||
|
||||
|
||||
def close_session(session):
|
||||
try: self.sessions.remove(session)
|
||||
except: pass
|
||||
|
||||
session.close()
|
||||
|
||||
if not len(self.sessions) and len(self.db.connections) > self.cfg.minconnections:
|
||||
self.disconnect()
|
||||
|
||||
|
||||
def cursor(self):
|
||||
if not self.conn:
|
||||
raise
|
||||
return self.conn.cursor()
|
||||
|
||||
|
||||
def dump_database(self, path='database.sql'):
|
||||
if self.cfg.type == 'sqlite':
|
||||
path = Path(path)
|
||||
|
||||
with path.open('w') as fd:
|
||||
fd.write('\n\n'.join(list(self.conn.iterdump())[1:-1]))
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Only SQLite supported atm :/')
|
68
izzylib/sql2/result.py
Normal file
68
izzylib/sql2/result.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
from .row import Row
|
||||
|
||||
|
||||
class Result:
|
||||
def __init__(self, session):
|
||||
self.table = None
|
||||
self.session = session
|
||||
self.cursor = session.cursor
|
||||
|
||||
try:
|
||||
self.keys = [desc[0] for desc in session.cursor.description]
|
||||
|
||||
except TypeError:
|
||||
self.keys = []
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
yield from self.all_iter()
|
||||
|
||||
|
||||
@property
|
||||
def row_class(self):
|
||||
return self.session.db.get_row_class(self.table)
|
||||
|
||||
|
||||
@property
|
||||
def last_row_id(self):
|
||||
if self.session.cfg.type == 'postgresql':
|
||||
try:
|
||||
return self.one().id
|
||||
|
||||
except:
|
||||
return None
|
||||
|
||||
return self.cursor.lastrowid
|
||||
|
||||
|
||||
@property
|
||||
def row_count(self):
|
||||
return self.cursor.rowcount
|
||||
|
||||
|
||||
def set_table(self, table):
|
||||
self.table = table
|
||||
|
||||
|
||||
def one(self):
|
||||
data = self.cursor.fetchone()
|
||||
|
||||
if not data:
|
||||
return
|
||||
|
||||
return self.row_class(
|
||||
self.session,
|
||||
self.table,
|
||||
{self.keys[idx]: value for idx, value in enumerate(data)},
|
||||
)
|
||||
|
||||
|
||||
def all(self):
|
||||
return [row for row in self.all_iter()]
|
||||
|
||||
|
||||
def all_iter(self):
|
||||
for row in self.cursor:
|
||||
yield self.row_class(self.session, self.table,
|
||||
{self.keys[idx]: value for idx, value in enumerate(row)}
|
||||
)
|
29
izzylib/sql2/row.py
Normal file
29
izzylib/sql2/row.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from ..dotdict import DotDict
|
||||
|
||||
|
||||
class Row(DotDict):
|
||||
def __init__(self, session, table, data):
|
||||
super().__init__(session._parse_data('serialize', table, data))
|
||||
|
||||
self._table = table
|
||||
self._session = session
|
||||
self.__run__(session)
|
||||
|
||||
|
||||
def __run__(self, session):
|
||||
pass
|
||||
|
||||
|
||||
@property
|
||||
def table(self):
|
||||
return self._table
|
||||
|
||||
|
||||
@property
|
||||
def rowid(self):
|
||||
return self.id
|
||||
|
||||
|
||||
@property
|
||||
def rowid2(self):
|
||||
return self.get('rowid', self.id)
|
346
izzylib/sql2/session.py
Normal file
346
izzylib/sql2/session.py
Normal file
|
@ -0,0 +1,346 @@
|
|||
import json
|
||||
|
||||
from pathlib import Path as PyPath
|
||||
|
||||
from .result import Result
|
||||
from .row import Row
|
||||
from .statements import Select, Insert, Delete, Count, Update, Statement
|
||||
from .table import SessionTables
|
||||
|
||||
from .. import izzylog
|
||||
from ..dotdict import DotDict
|
||||
from ..exceptions import NoTransactionError, UpdateAllRowsError
|
||||
from ..misc import boolean, random_gen
|
||||
from ..path import Path
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self, conn):
|
||||
self.db = conn.db
|
||||
self.cfg = conn.db.cfg
|
||||
self.conn = conn
|
||||
self.sid = random_gen()
|
||||
self.tables = SessionTables(self)
|
||||
|
||||
self.cursor = conn.cursor()
|
||||
self.trans = False
|
||||
|
||||
self.__setup__()
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exctype, excvalue, traceback):
|
||||
if traceback:
|
||||
self.rollback()
|
||||
|
||||
else:
|
||||
self.commit()
|
||||
|
||||
|
||||
def __setup__(self):
|
||||
pass
|
||||
|
||||
|
||||
def close():
|
||||
if not self.cursor:
|
||||
return
|
||||
|
||||
self.conn.close_session(self)
|
||||
|
||||
|
||||
def _parse_data(self, action, table, kwargs):
|
||||
data = {}
|
||||
|
||||
if self.db.tables:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
coltype = self.db.tables[table][key].type
|
||||
|
||||
except KeyError:
|
||||
data[key] = value
|
||||
continue
|
||||
|
||||
parser = self.db.types.get_type(coltype)
|
||||
|
||||
try:
|
||||
data[key] = parser(action, self.cfg.type, value)
|
||||
except Exception as e:
|
||||
izzylog.error(f'Failed to parse data from the table "{table}": {key} = {value}')
|
||||
izzylog.debug(f'Parser: {parser}, Type: {coltype}')
|
||||
raise e from None
|
||||
|
||||
else:
|
||||
data = kwargs
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def dump_database(self, path):
|
||||
import sqlparse
|
||||
path = Path(path)
|
||||
|
||||
with path.open('w') as fd:
|
||||
line = '\n\n'.join(list(self.conn.iterdump())[1:-1])
|
||||
fd.write(sqlparse.format(line,
|
||||
reindent = False,
|
||||
keyword_case = 'upper',
|
||||
))
|
||||
|
||||
def dump_database2(self, path):
|
||||
path = Path(path)
|
||||
|
||||
with path.open('w') as fd:
|
||||
fd.write('\n\n'.join(list(self.conn.iterdump())[1:-1]))
|
||||
|
||||
|
||||
def begin(self):
|
||||
if self.trans or self.cfg.autocommit:
|
||||
return
|
||||
|
||||
self.execute('BEGIN')
|
||||
self.trans = True
|
||||
|
||||
|
||||
def commit(self):
|
||||
if not self.trans:
|
||||
return
|
||||
|
||||
self.execute('COMMIT')
|
||||
self.trans = False
|
||||
|
||||
|
||||
def rollback(self):
|
||||
if not self.trans:
|
||||
return
|
||||
|
||||
self.execute('ROLLBACK')
|
||||
self.trans = False
|
||||
|
||||
|
||||
def raw_execute(self, string, values=None):
|
||||
if type(string) == Path:
|
||||
string = string.read()
|
||||
|
||||
elif type(string) == PyPath:
|
||||
with string.open() as fd:
|
||||
string = fd.read()
|
||||
|
||||
if values:
|
||||
self.cursor.execute(string, values)
|
||||
|
||||
else:
|
||||
self.cursor.execute(string)
|
||||
|
||||
return self.cursor
|
||||
|
||||
|
||||
def execute(self, string, *values):
|
||||
if isinstance(string, Statement):
|
||||
raise TypeError('String must be a str not a Statement')
|
||||
|
||||
action = string.split()[0].upper()
|
||||
|
||||
if not self.trans and action in ['CREATE', 'INSERT', 'UPDATE', 'UPSERT', 'DROP', 'DELETE', 'ALTER']:
|
||||
if self.cfg.auto_trans:
|
||||
self.begin()
|
||||
|
||||
else:
|
||||
raise NoTransactionError(f'Command not supported outside a transaction: {action}')
|
||||
|
||||
try:
|
||||
self.raw_execute(string, values)
|
||||
|
||||
except Exception as e:
|
||||
if type(e).__name__ in ['DatabaseError', 'OperationalError']:
|
||||
print(string, values)
|
||||
|
||||
raise e from None
|
||||
|
||||
return Result(self)
|
||||
|
||||
|
||||
def run(self, query):
|
||||
result = self.execute(query.compile(self.cfg.type), *query.values)
|
||||
|
||||
if type(query) == Count:
|
||||
return list(result.one().values())[0]
|
||||
|
||||
result.set_table(query.table)
|
||||
return result
|
||||
|
||||
|
||||
def run_count(self, query):
|
||||
return list(self.run(query).one().values())[0]
|
||||
|
||||
|
||||
def count(self, table, **kwargs):
|
||||
if self.db.tables and table not in self.db.tables:
|
||||
raise KeyError(f'Table does not exist: {table}')
|
||||
|
||||
query = Count(table, **kwargs)
|
||||
|
||||
return self.run_count(query)
|
||||
|
||||
|
||||
def fetch(self, table, orderby=None, orderdir='ASC', limit=None, offset=None, **kwargs):
|
||||
if self.db.tables and table not in self.db.tables:
|
||||
raise KeyError(f'Table does not exist: {table}')
|
||||
|
||||
query = Select(table, **kwargs)
|
||||
|
||||
if orderby:
|
||||
query.order(orderby, orderdir)
|
||||
|
||||
if limit:
|
||||
query.limit(limist)
|
||||
|
||||
if offset:
|
||||
query.offset(offset)
|
||||
|
||||
return self.run(query)
|
||||
|
||||
|
||||
def insert(self, table, return_row=False, **kwargs):
|
||||
if self.db.tables and table not in self.db.tables:
|
||||
raise KeyError(f'Table does not exist: {table}')
|
||||
|
||||
result = self.run(Insert(table, **self._parse_data('deserialize', table, kwargs)))
|
||||
|
||||
if return_row:
|
||||
return self.fetch(table, id=result.last_row_id).one()
|
||||
|
||||
return result.last_row_id
|
||||
|
||||
|
||||
def update(self, table, data, return_row=False, **kwargs):
|
||||
query = Update(table, **data)
|
||||
|
||||
for pair in kwargs.items():
|
||||
query.where(*pair)
|
||||
|
||||
if not query._where:
|
||||
raise UpdateAllRowsError(f'Refusing to update all rows in table: {table}')
|
||||
|
||||
result = self.run(query)
|
||||
|
||||
if return_row:
|
||||
return self.fetch(table, id=result.last_row_id).one()
|
||||
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def update_row(self, row, return_row=False, **kwargs):
|
||||
return self.update(row.table, kwargs, id=row.id, return_row=return_row)
|
||||
|
||||
|
||||
def remove(self, table, **kwargs):
|
||||
if self.db.tables and table not in self.db.tables:
|
||||
raise KeyError(f'Table does not exist: {table}')
|
||||
|
||||
self.run(Delete(table, self._parse_data('deserialize', table, kwargs)))
|
||||
|
||||
|
||||
def remove_row(self, row):
|
||||
if not row.table:
|
||||
raise ValueError('Row not associated with a table')
|
||||
|
||||
self.remove(row.table, id=row.id)
|
||||
|
||||
|
||||
def create_tables(self, tables=None):
|
||||
if tables:
|
||||
self.load_tables(tables)
|
||||
|
||||
if not self.tables:
|
||||
raise NoTableLayoutError('No table layout available')
|
||||
|
||||
for table in self.tables.values():
|
||||
self.execute(table.compile(self.cfg.type))
|
||||
|
||||
|
||||
def table_layout(self):
|
||||
tables = {}
|
||||
|
||||
if self.cfg.type == 'sqlite':
|
||||
rows = self.execute("SELECT name, sql FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%'")
|
||||
|
||||
for row in rows:
|
||||
name = row.name
|
||||
tables[name] = {}
|
||||
fkeys = {fkey['from']: f'{fkey.table}.{fkey["to"]}' for fkey in self.execute(f'PRAGMA foreign_key_list({name})')}
|
||||
columns = [col for col in self.execute(f'PRAGMA table_info({name})')]
|
||||
|
||||
unique_list = parse_unique(row.sql)
|
||||
|
||||
for column in columns:
|
||||
tables[name][column.name] = dict(
|
||||
type = column.type.upper(),
|
||||
nullable = not column.notnull,
|
||||
default = parse_default(column.dflt_value),
|
||||
primary_key = bool(column.pk),
|
||||
foreign_key = fkeys.get(column.name),
|
||||
unique = column.name in unique_list
|
||||
)
|
||||
|
||||
elif self.cfg.type == 'postgresql':
|
||||
for row in self.execute("SELECT * FROM information_schema.columns WHERE table_schema not in ('information_schema', 'pg_catalog') ORDER BY table_schema, table_name, ordinal_position"):
|
||||
table = row.table_name
|
||||
column = row.column_name
|
||||
|
||||
if not tables.get(table):
|
||||
tables[table] = {}
|
||||
|
||||
if not tables[table].get(column):
|
||||
tables[table][column] = {}
|
||||
|
||||
tables[table][column] = dict(
|
||||
type = row.data_type.upper(),
|
||||
nullable = boolean(row.is_nullable),
|
||||
default = row.column_default if row.column_default and not row.column_default.startswith('nextval') else None,
|
||||
primary_key = None,
|
||||
foreign_key = None,
|
||||
unique = None
|
||||
)
|
||||
|
||||
return tables
|
||||
|
||||
|
||||
def parse_unique(sql):
|
||||
unique_list = []
|
||||
|
||||
try:
|
||||
for raw_line in sql.splitlines():
|
||||
if 'UNIQUE' not in raw_line:
|
||||
continue
|
||||
|
||||
for line in raw_line.replace('UNIQUE', '').replace('(', '').replace(')', '').split(','):
|
||||
line = line.strip()
|
||||
|
||||
if line:
|
||||
unique_list.append(line)
|
||||
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
return unique_list
|
||||
|
||||
|
||||
def parse_default(value):
|
||||
if value == None:
|
||||
return
|
||||
|
||||
if value.startswith("'") and value.endswith("'"):
|
||||
value = value[1:-1]
|
||||
|
||||
else:
|
||||
try:
|
||||
value = int(value)
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return value
|
185
izzylib/sql2/statements.py
Normal file
185
izzylib/sql2/statements.py
Normal file
|
@ -0,0 +1,185 @@
|
|||
from ..dotdict import DotDict
|
||||
|
||||
|
||||
Comparison = DotDict(
|
||||
LESS = lambda key: f'{key} < ?',
|
||||
GREATER = lambda key: f'{key} > ?',
|
||||
LESS_EQUAL = lambda key: f'{key} <= ?',
|
||||
GREATER_EQUAL = lambda key: f'{key} >= ?',
|
||||
EQUAL = lambda key: f'{key} = ?',
|
||||
NOT_EQUAL = lambda key: f'{key} != ?',
|
||||
IN = lambda key: f'{key} IN (?)',
|
||||
NOT_IN = lambda key: f'{key} NOT IN (?)',
|
||||
LIKE = lambda key: f'{key} LIKE ?',
|
||||
NOT_LIKE = lambda key: f'{key} NOT LIKE ?'
|
||||
)
|
||||
|
||||
|
||||
class Statement:
|
||||
def __init__(self, table):
|
||||
self.table = table
|
||||
self.values = []
|
||||
|
||||
self._where = ''
|
||||
self._order = None
|
||||
self._limit = None
|
||||
self._offset = None
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.compile('sqlite')
|
||||
|
||||
|
||||
def where(self, key, value, comparison='equal', operator='and'):
|
||||
try:
|
||||
comp = Comparison[comparison.upper().replace('-', '_')]
|
||||
|
||||
except KeyError:
|
||||
raise KeyError(f'Invalid comparison: {comparison}')
|
||||
|
||||
prefix = f' {operator} ' if self._where else ' '
|
||||
|
||||
self._where += f'{prefix}{comp(key)}'
|
||||
self.values.append(value)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def order(self, column, direction='ASC'):
|
||||
direction = direction.upper()
|
||||
assert direction in ['ASC', 'DESC']
|
||||
self._order = (column, direction)
|
||||
return self
|
||||
|
||||
|
||||
def limit(self, limit_num):
|
||||
self._limit = int(limit_num)
|
||||
return self
|
||||
|
||||
|
||||
def offset(self, offset_num):
|
||||
self._offset = int(offset_num)
|
||||
return self
|
||||
|
||||
|
||||
def compile(self, dbtype):
|
||||
raise NotImplementedError('Do not use the Statement class directly.')
|
||||
|
||||
|
||||
class Select(Statement):
|
||||
def __init__(self, table, *columns, **kwargs):
|
||||
super().__init__(table)
|
||||
|
||||
self.columns = columns
|
||||
|
||||
for key, value in kwargs.items():
|
||||
self.where(key, value)
|
||||
|
||||
|
||||
def compile(self, dbtype):
|
||||
data = f'SELECT'
|
||||
|
||||
if self.columns:
|
||||
columns = ','.join(self.columns)
|
||||
|
||||
else:
|
||||
columns = '*'
|
||||
|
||||
data += f' {columns} FROM {self.table}'
|
||||
|
||||
if self._where:
|
||||
data += f' WHERE {self._where}'
|
||||
|
||||
if self._order:
|
||||
col, direc = self._order
|
||||
data += f' ORDER BY {col} {direc}'
|
||||
|
||||
if self._limit:
|
||||
data += f' LIMIT {self._limit}'
|
||||
|
||||
if self._offset:
|
||||
data += f' OFFSET {self._offset}'
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class Insert(Statement):
|
||||
def __init__(self, table, **kwargs):
|
||||
super().__init__(table)
|
||||
|
||||
self.keys = []
|
||||
|
||||
for pair in kwargs.items():
|
||||
self.add_data(*pair)
|
||||
|
||||
|
||||
def add_data(self, key, value):
|
||||
self.keys.append(key)
|
||||
self.values.append(value)
|
||||
|
||||
|
||||
def remove_data(self, key):
|
||||
index = self.keys.index(key)
|
||||
|
||||
del self.keys[index]
|
||||
del self.values[index]
|
||||
|
||||
|
||||
def compile(self, dbtype):
|
||||
keys = ','.join(self.keys)
|
||||
values = ','.join('?' for value in self.values)
|
||||
data = f'INSERT INTO {self.table} ({keys}) VALUES ({values})'
|
||||
|
||||
if dbtype == 'postgresql':
|
||||
data += f' RETURNING id'
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class Update(Statement):
|
||||
def __init__(self, table, **kwargs):
|
||||
super().__init__(table)
|
||||
self.keys = []
|
||||
|
||||
for key, value in kwargs.items():
|
||||
self.keys.append(key)
|
||||
self.values.append(value)
|
||||
|
||||
|
||||
def compile(self, dbtype):
|
||||
pairs = ','.join(f'{key} = ?' for key in self.keys)
|
||||
data = f'UPDATE {self.table} SET {pairs} WHERE {self._where}'
|
||||
|
||||
if dbtype == 'postgresql':
|
||||
data += f' RETURNING id'
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class Delete(Statement):
|
||||
def __init__(self, table, **kwargs):
|
||||
super().__init__(table)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
self.where(key, value)
|
||||
|
||||
|
||||
def compile(self, dbtype):
|
||||
return f'DELETE FROM {self.table} WHERE {self._where}'
|
||||
|
||||
|
||||
class Count(Statement):
|
||||
def __init__(self, table, **kwargs):
|
||||
super().__init__(table)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
self.where(key, value)
|
||||
|
||||
|
||||
def compile(self, dbtype):
|
||||
data = f'SELECT COUNT(*) FROM {self.table}'
|
||||
|
||||
if self._where:
|
||||
data += f' WHERE {self._where}'
|
||||
|
||||
return data
|
244
izzylib/sql2/table.py
Normal file
244
izzylib/sql2/table.py
Normal file
|
@ -0,0 +1,244 @@
|
|||
from ..dotdict import DotDict
|
||||
|
||||
|
||||
class SessionTables:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
self._db = session.db
|
||||
self._tables = session.db.tables
|
||||
|
||||
|
||||
def __getattr__(self, key):
|
||||
return SessionTable(session, key, self._tables[key])
|
||||
|
||||
|
||||
def names(self):
|
||||
return tuple(self._tables.keys())
|
||||
|
||||
|
||||
class SessionTable(DotDict):
|
||||
def __init__(self, session, name, columns):
|
||||
super().__init__(columns)
|
||||
|
||||
self._name = name
|
||||
self._session = session
|
||||
self._db = session.db
|
||||
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return tuple(self.keys())
|
||||
|
||||
|
||||
def fetch(self, **kwargs):
|
||||
self._check_columns(**kwargs)
|
||||
return self.session.fetch(self.name, **kwargs)
|
||||
|
||||
|
||||
def insert(self, **kwargs):
|
||||
self._check_columns(**kwargs)
|
||||
return self.session.insert(self.name, **kwargs)
|
||||
|
||||
|
||||
def remove(self, **kwargs):
|
||||
self._check_columns(**kwargs)
|
||||
return self.session.remove(self.name, **kwargs)
|
||||
|
||||
|
||||
def _check_columns(self, **kwargs):
|
||||
for key in kwargs.keys():
|
||||
if key not in self.columns:
|
||||
raise KeyError(f'Not a column for table "{self.name}": {key}')
|
||||
|
||||
|
||||
class DbTables(DotDict):
|
||||
def __init__(self, db):
|
||||
super().__init__()
|
||||
|
||||
self._db = db
|
||||
self._cfg = db.cfg
|
||||
|
||||
|
||||
@property
|
||||
def empty(self):
|
||||
return not len(self.keys())
|
||||
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
return tuple(self.keys())
|
||||
|
||||
|
||||
def load_tables(self, tables):
|
||||
for name, columns in tables.items():
|
||||
self.add_table(name, columns)
|
||||
|
||||
|
||||
def unload_tables(self):
|
||||
for key in self.names:
|
||||
del self[key]
|
||||
|
||||
|
||||
def add_table(self, name, columns):
|
||||
self[name] = {}
|
||||
|
||||
if type(columns) == list:
|
||||
self[name] = {col.name: col for col in columns}
|
||||
|
||||
elif isinstance(columns, dict):
|
||||
for column, data in columns.items():
|
||||
self[name][column] = DbColumn(self._cfg.type, column, **data)
|
||||
|
||||
else:
|
||||
raise TypeError('Columns must be a list of Column objects or a dict')
|
||||
|
||||
|
||||
def remove_table(self, name):
|
||||
return self.pop(name)
|
||||
|
||||
|
||||
def get_columns(self, name):
|
||||
return tuple(self[name].values())
|
||||
|
||||
|
||||
def compile_table(self, table_name, dbtype):
|
||||
table = self[table_name]
|
||||
columns = []
|
||||
foreign_keys = []
|
||||
|
||||
for column in self.get_columns(table_name):
|
||||
columns.append(column.compile(dbtype))
|
||||
|
||||
if column.foreign_key:
|
||||
fkey_table, fkey_col = column.foreign_key
|
||||
foreign_keys.append(f'FOREIGN KEY ({column.name}) REFERENCES {fkey_table} ({fkey_col})')
|
||||
|
||||
return f'CREATE TABLE IF NOT EXISTS {self.name} ({",".join(columns)}{",".join(foreign_keys)})'
|
||||
|
||||
|
||||
def compile_all(self, dbtype):
|
||||
return [self.compile_table(name, dbtype) for name in self.keys()]
|
||||
|
||||
|
||||
class DbColumn(DotDict):
|
||||
def __init__(self, dbtype, name, type=None, default=None, primary_key=False, unique=False, nullable=True, autoincrement=False, foreign_key=None):
|
||||
super().__init__(
|
||||
name = name,
|
||||
type = type,
|
||||
default = default,
|
||||
primary_key = primary_key,
|
||||
unique = unique,
|
||||
nullable = nullable,
|
||||
autoincrement = autoincrement,
|
||||
foreign_key = foreign_key
|
||||
)
|
||||
|
||||
if self.name == 'id':
|
||||
if dbtype == 'sqlite':
|
||||
self.type = 'INTEGER'
|
||||
self.autoincrement = True
|
||||
|
||||
elif dbtype == 'postgresql':
|
||||
self.type = 'SERIAL'
|
||||
self.autoincrement = False
|
||||
|
||||
self.primary_key = True
|
||||
self.unique = False
|
||||
self.nullable = False
|
||||
self.default = None
|
||||
self.foreign_key = None
|
||||
|
||||
elif self.name in ['created', 'modified', 'accessed'] and not self.type:
|
||||
self.type = 'DATETIME'
|
||||
|
||||
if not self.type:
|
||||
raise ValueError(f'Must provide a column type for column: {name}')
|
||||
|
||||
try:
|
||||
self.fkey
|
||||
|
||||
except ValueError:
|
||||
raise ValueError(f'Invalid foreign_key format. Must be "table.column"')
|
||||
|
||||
|
||||
@property
|
||||
def fkey(self):
|
||||
try:
|
||||
return self.foreign_key.split('.')
|
||||
|
||||
except AttributeError:
|
||||
return
|
||||
|
||||
|
||||
def compile(self, dbtype):
|
||||
line = f'{self.name} {self.type}'
|
||||
|
||||
if self.primary_key:
|
||||
line += ' PRIMARY KEY'
|
||||
|
||||
if not self.nullable:
|
||||
line += ' NOT NULL'
|
||||
|
||||
if self.unique:
|
||||
line += ' UNIQUE'
|
||||
|
||||
if self.autoincrement and dbtype != 'postgresql':
|
||||
line += ' AUTOINCREMENT'
|
||||
|
||||
if self.default:
|
||||
line += f" DEFAULT {parse_default(self.default)}"
|
||||
|
||||
return line
|
||||
|
||||
|
||||
class Column(DotDict):
|
||||
def __init__(self, name, type=None, default=None, primary_key=False, unique=False, nullable=True, autoincrement=False, foreign_key=None):
|
||||
super().__init__(
|
||||
name = name,
|
||||
type = type.upper() if type else None,
|
||||
default = default,
|
||||
primary_key = primary_key,
|
||||
unique = unique,
|
||||
nullable = nullable,
|
||||
autoincrement = autoincrement,
|
||||
foreign_key = foreign_key
|
||||
)
|
||||
|
||||
if self.name == 'id':
|
||||
self.type = 'SERIAL'
|
||||
|
||||
elif self.name in ['created', 'modified', 'accessed'] and not self.type:
|
||||
self.type = 'DATETIME'
|
||||
|
||||
if not self.type:
|
||||
raise ValueError(f'Must provide a column type for column: {name}')
|
||||
|
||||
try:
|
||||
self.fkey
|
||||
|
||||
except ValueError:
|
||||
raise ValueError(f'Invalid foreign_key format. Must be "table.column"')
|
||||
|
||||
|
||||
@property
|
||||
def fkey(self):
|
||||
try:
|
||||
return self.foreign_key.split('.')
|
||||
|
||||
except AttributeError:
|
||||
return
|
||||
|
||||
|
||||
def parse_default(default):
|
||||
if isinstance(default, dict) or isinstance(default, list):
|
||||
default = json.dumps(default)
|
||||
|
||||
if type(default) == str:
|
||||
default = f"'{default}'"
|
||||
|
||||
return default
|
219
izzylib/sql2/types.py
Normal file
219
izzylib/sql2/types.py
Normal file
|
@ -0,0 +1,219 @@
|
|||
from datetime import date, time, datetime
|
||||
|
||||
from .. import izzylog
|
||||
from ..dotdict import DotDict, LowerDotDict
|
||||
|
||||
|
||||
Standard = {
|
||||
'INTEGER',
|
||||
'INT',
|
||||
'TINYINT',
|
||||
'SMALLINT',
|
||||
'MEDIUMINT',
|
||||
'BIGINT',
|
||||
'UNSIGNED BIG INT',
|
||||
'INT2',
|
||||
'INT8',
|
||||
'TEXT',
|
||||
'CHARACTER',
|
||||
'CHAR',
|
||||
'VARCHAR',
|
||||
'BLOB',
|
||||
'CLOB',
|
||||
'REAL',
|
||||
'DOUBLE',
|
||||
'DOUBLE PRECISION',
|
||||
'FLOAT',
|
||||
'NUMERIC',
|
||||
'DEC',
|
||||
'DECIMAL',
|
||||
'BOOLEAN',
|
||||
'DATE',
|
||||
'TIME',
|
||||
'JSON'
|
||||
}
|
||||
|
||||
|
||||
Sqlite = {
|
||||
*Standard,
|
||||
'DATETIME'
|
||||
}
|
||||
|
||||
|
||||
Postgresql = {
|
||||
*Standard,
|
||||
'SMALLSERIAL',
|
||||
'SERIAL',
|
||||
'BIGSERIAL',
|
||||
'VARYING',
|
||||
'BYTEA',
|
||||
'TIMESTAMP',
|
||||
'INTERVAL',
|
||||
'POINT',
|
||||
'LINE',
|
||||
'LSEG',
|
||||
'BOX',
|
||||
'PATH',
|
||||
'POLYGON',
|
||||
'CIRCLE',
|
||||
}
|
||||
|
||||
|
||||
Mysql = {
|
||||
*Standard,
|
||||
'FIXED',
|
||||
'BIT',
|
||||
'YEAR',
|
||||
'VARBINARY',
|
||||
'ENUM',
|
||||
'SET'
|
||||
}
|
||||
|
||||
|
||||
class Type:
|
||||
sqlite = None
|
||||
postgresql = None
|
||||
mysql = None
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in ['sqlite', 'postgresql', 'mysql']:
|
||||
return getattr(self, key)
|
||||
|
||||
raise KeyError(f'Invalid database type: {key}')
|
||||
|
||||
|
||||
def __call__(self, action, dbtype, value):
|
||||
return getattr(self, action)(dbtype, value)
|
||||
|
||||
|
||||
def name(self, dbtype='sqlite'):
|
||||
return self[dbtype]
|
||||
|
||||
|
||||
def serialize(self, dbtype, value):
|
||||
return value
|
||||
|
||||
|
||||
def deserialize(self, dbtype, value):
|
||||
return value
|
||||
|
||||
|
||||
class Json(Type):
|
||||
sqlite = 'JSON'
|
||||
postgresql = 'JSON'
|
||||
mysql = 'JSON'
|
||||
|
||||
|
||||
def serialize(self, dbtype, value):
|
||||
izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value)
|
||||
|
||||
if type(value) == str:
|
||||
return DotDict(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def deserialize(self, dbtype, value):
|
||||
izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value)
|
||||
|
||||
return DotDict(value).to_json()
|
||||
|
||||
|
||||
class Datetime(Type):
|
||||
sqlite = 'DATETIME'
|
||||
postgresql = 'TIMESTAMP'
|
||||
mysql = 'DATETIME'
|
||||
|
||||
|
||||
def serialize(self, dbtype, value):
|
||||
izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value)
|
||||
|
||||
if type(value) == str:
|
||||
return datetime.fromisoformat(value)
|
||||
|
||||
elif type(value) == int:
|
||||
return datetime.fromtimestamp(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def deserialize(self, dbtype, value):
|
||||
izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value)
|
||||
|
||||
if dbtype == 'sqlite':
|
||||
return value.isoformat()
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class Date(Type):
|
||||
sqlite = 'DATE'
|
||||
postgresql = 'DATE'
|
||||
mysql = 'DATE'
|
||||
|
||||
|
||||
def serialize(self, dbtype, value):
|
||||
izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value)
|
||||
|
||||
if type(value) == str:
|
||||
return date.fromisoformat(value)
|
||||
|
||||
elif type(value) == int:
|
||||
return date.fromtimestamp(value)
|
||||
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def deserialize(self, dbtype, value):
|
||||
izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value)
|
||||
|
||||
if dbtype == 'sqlite':
|
||||
return value.isoformat()
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class Time(Type):
|
||||
sqlite = 'TIME'
|
||||
postgresql = 'TIME'
|
||||
mysql = 'TIME'
|
||||
|
||||
|
||||
def serialize(self, dbtype, value):
|
||||
izzylog.debug(f'serialize {type(self).__name__}: {type(value).__name__}', value)
|
||||
|
||||
if type(value) == str:
|
||||
return time.fromisoformat(value)
|
||||
|
||||
elif type(value) == int:
|
||||
return time.fromtimestamp(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def deserialize(self, dbtype, value):
|
||||
izzylog.debug(f'deserialize {type(self).__name__}: {type(value).__name__}', value)
|
||||
|
||||
if dbtype == 'sqlite':
|
||||
return value.isoformat()
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class Types(DotDict):
|
||||
def __init__(self, db):
|
||||
self._db = db
|
||||
|
||||
self.set_type(Json, Date, Time, Datetime)
|
||||
|
||||
|
||||
def get_type(self, name):
|
||||
return self.get(name.upper(), Type())
|
||||
|
||||
|
||||
def set_type(self, *types):
|
||||
for type_object in types:
|
||||
typeclass = type_object()
|
||||
self[typeclass.name(self._db.cfg.type)] = typeclass
|
|
@ -33,6 +33,7 @@ packages =
|
|||
izzylib.http_server_async
|
||||
izzylib.http_urllib_client
|
||||
izzylib.sql
|
||||
izzylib.sql2
|
||||
setup_requires =
|
||||
setuptools >= 38.3.0
|
||||
|
||||
|
@ -59,6 +60,8 @@ http_urllib_client =
|
|||
sql =
|
||||
SQLAlchemy == 1.4.23
|
||||
SQLAlchemy-Paginator == 0.2
|
||||
sql2 =
|
||||
sql-metadata == 2.3.0
|
||||
template =
|
||||
colour == 0.1.5
|
||||
Hamlish-Jinja == 0.3.3
|
||||
|
|
Loading…
Reference in a new issue