diff --git a/.gitignore b/.gitignore index 8c8a459..73dcf22 100644 --- a/.gitignore +++ b/.gitignore @@ -117,10 +117,13 @@ dmypy.json test*.py reload.cfg +# symlinks /izzylib /base/izzylib/dbus /base/izzylib/hasher /base/izzylib/http_requests_client +/base/izzylib/http_server +/base/izzylib/mbus /base/izzylib/sql /base/izzylib/template /base/izzylib/tinydb diff --git a/base/izzylib/__init__.py b/base/izzylib/__init__.py index 0d710f8..c8f8bfe 100644 --- a/base/izzylib/__init__.py +++ b/base/izzylib/__init__.py @@ -4,7 +4,7 @@ Licensed under the CNPL: https://git.pixie.town/thufie/CNPL https://git.barkshark.xyz/izaliamae/izzylib ''' -import sys, traceback +import os, sys, traceback assert sys.version_info >= (3, 7) __version_tpl__ = (0, 6, 0) @@ -13,6 +13,7 @@ __version__ = '.'.join([str(v) for v in __version_tpl__]) from . import logging izzylog = logging.logger['IzzyLib'] +izzylog.set_config('level', os.environ.get('IZZYLOG_LEVEL', 'INFO')) from .path import Path from .dotdict import DotDict, LowerDotDict, DefaultDotDict, MultiDotDict, JsonEncoder diff --git a/base/izzylib/cache.py b/base/izzylib/cache.py index ed9e1e2..a8d746f 100644 --- a/base/izzylib/cache.py +++ b/base/izzylib/cache.py @@ -42,6 +42,10 @@ def parse_ttl(ttl): return multiplier * int(amount) +class DefaultValue(object): + pass + + class BaseCache(OrderedDict): _get = OrderedDict.get _items = OrderedDict.items @@ -120,6 +124,20 @@ class BaseCache(OrderedDict): return item + def pop(self, key, default=DefaultValue): + try: + item = self.get(key) + del self[key] + + return item + + except Exception as e: + if default == DefaultValue: + raise e from None + + return default + + ## This doesn't work for some reason def CacheDecorator(cache): def decorator(func): diff --git a/base/izzylib/dotdict.py b/base/izzylib/dotdict.py index d44222e..34003a4 100644 --- a/base/izzylib/dotdict.py +++ b/base/izzylib/dotdict.py @@ -59,6 +59,13 @@ class DotDict(dict): raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None + @classmethod + def new_from_json_file(cls, path): + data = cls() + data.load_json(path) + return data + + def copy(self): return DotDict(self) diff --git a/base/izzylib/misc.py b/base/izzylib/misc.py index 94a428b..e3cedc0 100644 --- a/base/izzylib/misc.py +++ b/base/izzylib/misc.py @@ -22,6 +22,7 @@ __all__ = [ 'print_methods', 'prompt', 'random_gen', + 'remove', 'signal_handler', 'time_function', 'time_function_pprint', @@ -348,6 +349,13 @@ def random_gen(length=20, letters=True, numbers=True, extra=None): return ''.join(random.choices(characters, k=length)) +def remove(string: str, junk: list): + for line in junk: + string = string.replace(line, '') + + return string + + def signal_handler(func, *args, original_args=True, **kwargs): if original_args: handler = lambda signum, frame: func(signum, frame, *args, **kwargs) diff --git a/base/izzylib/path.py b/base/izzylib/path.py index 0b6ada7..db11acb 100644 --- a/base/izzylib/path.py +++ b/base/izzylib/path.py @@ -1,4 +1,4 @@ -import os, shutil +import json, os, shutil from datetime import datetime from functools import cached_property @@ -105,9 +105,9 @@ class Path(str): return json.load(s) - def json_dump(self, data): + def json_dump(self, data, indent=None): with self.open('w') as s: - s.write(json.dumps(data)) + s.write(json.dumps(data, indent=indent)) def link(self, path): diff --git a/http_server/izzylib/http_server/__init__.py b/http_server/izzylib/http_server/__init__.py index aabd388..e9b50b5 100644 --- a/http_server/izzylib/http_server/__init__.py +++ b/http_server/izzylib/http_server/__init__.py @@ -4,6 +4,7 @@ start_time = datetime.now() from .application import Application from .config import Config, UserLevel +from .middleware import MiddlewareBase, Headers, AccessLog from .request import Request from .response import Response from .view import View diff --git a/http_server/izzylib/http_server/application.py b/http_server/izzylib/http_server/application.py index 40179f7..9502747 100644 --- a/http_server/izzylib/http_server/application.py +++ b/http_server/izzylib/http_server/application.py @@ -12,7 +12,7 @@ from izzylib.template import Template from .config import Config, UserLevel from .error_handlers import GenericError, MissingTemplateError from .middleware import AccessLog, Headers -from .view import Manifest, Style +from .view import Manifest, Robots, Style log_path_ignore = [ @@ -46,10 +46,12 @@ class Application(sanic.Sanic): ) self.template.add_env('cfg', self.cfg) + self.template.add_env('len', len) if self.cfg.tpl_default: self.template.add_search_path(frontend) self.add_class_route(Manifest) + self.add_class_route(Robots) self.add_class_route(Style) self.static('/favicon.ico', frontend.join('static/icon64.png')) self.static('/framework/static', frontend.join('static')) @@ -93,6 +95,11 @@ class Application(sanic.Sanic): del self.cfg.menu[name] + def get_route_by_path(self, path, method='get', host=None,): + route, handler, _ = self.router.get(path, method, host) + return handler + + def start(self): # register built-in middleware now so they're last in the chain self.add_middleware(Headers) diff --git a/http_server/izzylib/http_server/config.py b/http_server/izzylib/http_server/config.py index ebcf6ec..e6e508a 100644 --- a/http_server/izzylib/http_server/config.py +++ b/http_server/izzylib/http_server/config.py @@ -11,6 +11,7 @@ class UserLevel(IntEnum): USER = 10 MODERATOR = 20 ADMIN = 30 + AUTH = 1000 class Config(DotDict): @@ -25,6 +26,7 @@ class Config(DotDict): port = 8080, proto = 'http', workers = cpu_count(), + access_log = True, request_class = Request, response_class = Response, sig_handler = None, diff --git a/http_server/izzylib/http_server/frontend/base.css b/http_server/izzylib/http_server/frontend/base.css index 5f35466..e61122f 100644 --- a/http_server/izzylib/http_server/frontend/base.css +++ b/http_server/izzylib/http_server/frontend/base.css @@ -39,7 +39,6 @@ a:hover { } input:not([type='checkbox']), select, textarea { - margin: 1px 0; color: var(--text); background-color: var(--background); border: 1px solid var(--background); @@ -263,7 +262,7 @@ details:focus, summary:focus { transition: background-color var(--trans-speed); width: 55px; margin-left: var(--gap); - background-image: url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/menu.svg'); + background-image: url('/framework/static/menu.svg'); background-size: 50px; background-position: center center; background-repeat: no-repeat; @@ -412,8 +411,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken @font-face { font-family: 'sans undertale'; src: local('Nunito Sans Bold'), - url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-SemiBold.woff2') format('woff2'), - url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-SemiBold.ttf') format('ttf'); + url('/framework/static/nunito/NunitoSans-SemiBold.woff2') format('woff2'), + url('/framework/static/nunito/NunitoSans-SemiBold.ttf') format('ttf'); font-weight: bold; font-style: normal; } @@ -421,8 +420,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken @font-face { font-family: 'sans undertale'; src: local('Nunito Sans Light Italic'), - url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-ExtraLightItalic.woff2') format('woff2'), - url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-ExtraLightItalic.ttf') format('ttf'); + url('/framework/static/nunito/NunitoSans-ExtraLightItalic.woff2') format('woff2'), + url('/framework/static/nunito/NunitoSans-ExtraLightItalic.ttf') format('ttf'); font-weight: normal; font-style: italic; } @@ -430,8 +429,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken @font-face { font-family: 'sans undertale'; src: local('Nunito Sans Bold Italic'), - url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Italic.woff2') format('woff2'), - url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Italic.ttf') format('ttf'); + url('/framework/static/nunito/NunitoSans-Italic.woff2') format('woff2'), + url('/framework/static/nunito/NunitoSans-Italic.ttf') format('ttf'); font-weight: bold; font-style: italic; } @@ -439,8 +438,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken @font-face { font-family: 'sans undertale'; src: local('Nunito Sans Light'), - url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Light.woff2') format('woff2'), - url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Light.ttf') format('ttf'); + url('/framework/static/nunito/NunitoSans-Light.woff2') format('woff2'), + url('/framework/static/nunito/NunitoSans-Light.ttf') format('ttf'); font-weight: normal; font-style: normal; } diff --git a/http_server/izzylib/http_server/frontend/base.haml b/http_server/izzylib/http_server/frontend/base.haml index 87a3754..d2da0e6 100644 --- a/http_server/izzylib/http_server/frontend/base.haml +++ b/http_server/izzylib/http_server/frontend/base.haml @@ -2,20 +2,21 @@ %html %head %title << {{cfg.name}}: {{page}} - %link rel='stylesheet' type='text/css' href='{{cfg.proto}}://{{cfg.web_host}}/framework/style.css' - %link rel='manifest' href='{{cfg.proto}}://{{cfg.web_host}}/framework/manifest.json' + %link rel='stylesheet' type='text/css' href='/framework/style.css' + %link rel='manifest' href='/framework/manifest.json' %meta charset='UTF-8' %meta name='viewport' content='width=device-width, initial-scale=1' + -block head %body #body #header.flex-container -if menu_left #btn.section - .page-title.section -> %a.title href='{{cfg.proto}}://{{cfg.web_host}}/' << {{cfg.name}} + .page-title.section -> %a.title href='/' << {{cfg.name}} -else - .page-title.section -> %a.title href='{{cfg.proto}}://{{cfg.web_host}}/' << {{cfg.name}} + .page-title.section -> %a.title href='/' << {{cfg.name}} #btn.section -if message @@ -27,9 +28,15 @@ #menu.section .title-item.item << Menu #items - -for label, data in cfg.menu.items() - -if request.user_level >= data[1] - .item -> %a href='{{cfg.proto}}://{{cfg.web_host}}{{data[0]}}' << {{label}} + -if not len(cfg.menu): + -include 'menu.haml' + + -else: + -for label, path_data in cfg.menu.items() + -if path_data[1] == 1000 and request.user_level == 0: + .item -> %a href='{{path_data[0]}}' << {{label}} + -elif request.user_level >= path_data[1] + .item -> %a href='{{path_data[0]}}' << {{label}} #content-body.section -block content @@ -40,4 +47,4 @@ .source %a href='{{cfg.git_repo}}' target='_new' << {{cfg.name}}/{{cfg.version}} - %script type='application/javascript' src='{{cfg.proto}}://{{cfg.web_host}}/framework/static/menu.js' + %script type='application/javascript' src='/framework/static/menu.js' diff --git a/http_server/izzylib/http_server/frontend/menu.haml b/http_server/izzylib/http_server/frontend/menu.haml new file mode 100644 index 0000000..c883b18 --- /dev/null +++ b/http_server/izzylib/http_server/frontend/menu.haml @@ -0,0 +1 @@ +.item => %a(href='/') << Home diff --git a/http_server/izzylib/http_server/middleware.py b/http_server/izzylib/http_server/middleware.py index 6c7c86f..8555d87 100644 --- a/http_server/izzylib/http_server/middleware.py +++ b/http_server/izzylib/http_server/middleware.py @@ -6,6 +6,34 @@ from izzylib import izzylog as logging, logging as applog from . import start_time +cache_types = [ + 'text/css', + 'application/javascript', +] + +cache_base_types = [ + 'image', + 'audio', + 'video' +] + + +def cache_check(request, response): + content_type = response.headers.get('content-type') + + if request.path.startswith('/framework'): + return True + + if not content_type: + return False + + if content_type in cache_types: + return True + + if any(map(content_type.startswith, cache_base_types)): + return True + + class MiddlewareBase: attach = 'request' @@ -22,7 +50,14 @@ class Headers(MiddlewareBase): attach = 'response' async def handler(self, request, response): - if request.path.startswith('/framework') or request.path == '/favicon.ico': + if not response.headers.get('content-type'): + if request.path.endswith('.css'): + response.headers['content-type'] = 'text/css' + + elif request.path.endswith('.js'): + response.headers['content-type'] = 'application/javascript' + + if cache_check(request, response): max_age = int(timedelta(weeks=2).total_seconds()) response.headers['Cache-Control'] = f'immutable,private,max-age={max_age}' diff --git a/http_server/izzylib/http_server/request.py b/http_server/izzylib/http_server/request.py index 4590388..7d288f8 100644 --- a/http_server/izzylib/http_server/request.py +++ b/http_server/izzylib/http_server/request.py @@ -1,14 +1,16 @@ import sanic +from izzylib import DotDict + from .misc import Headers class Request(sanic.request.Request): - def __init__(self, url_bytes, headers, version, method, transport, app): + def __init__(self, url_bytes, headers, version, method, transport, app, **kwargs): super().__init__(url_bytes, headers, version, method, transport, app) self.Headers = Headers(headers) - self.Data = Data(self) + self.data = Data(self) self.template = self.app.template self.user_level = 0 self.setup() diff --git a/http_server/izzylib/http_server/response.py b/http_server/izzylib/http_server/response.py index 5b9e69e..5e7f5b0 100644 --- a/http_server/izzylib/http_server/response.py +++ b/http_server/izzylib/http_server/response.py @@ -1,6 +1,7 @@ import json, sanic -from izzylib import DotDict +from datetime import datetime +from izzylib import DotDict, izzylog from izzylib.template import Color from sanic.compat import Header from sanic.cookies import CookieJar @@ -27,6 +28,18 @@ class Response: 'speed': 250 }) + cookie_keys = { + 'expires': 'Expires', + 'path': 'Path', + 'comment': 'Comment', + 'domain': 'Domain', + 'max_age': 'Max-Age', + 'secure': 'Secure', + 'httponly': 'HttpOnly', + 'version': 'Version', + 'samesite': 'SameSite' + } + def __init__(self, app, request, body=None, headers={}, cookies={}, status=200, content_type='text/html'): # server objects @@ -110,6 +123,8 @@ class Response: data.update(kwargs) for k,v in data.items(): + k = self.cookie_keys.get(k, k) + if k.lower() == 'max-age': if isinstance(v, timedelta): v = int(v.total_seconds()) @@ -118,13 +133,16 @@ class Response: raise TypeError('Max-Age must be an integer or timedelta') elif k.lower() == 'expires': - if isinstance(v, datetime): - v = v.strftime('%a, %d-%b-%Y %T GMT') + if isinstance(v, str): + v = datetime.strptime(v, '%a, %d-%b-%Y %T GMT') - elif not isinstance(v, str): + elif not isinstance(v, datetime): raise TypeError('Expires must be a string or datetime') - self.cookies[key][k] = v + try: + self.cookies[key][k] = v + except KeyError as e: + izzylog.error('Invalid cookie key:', k) def get_cookie(self, key): @@ -140,7 +158,7 @@ class Response: del self.cookies[key] - def template(self, tplfile, context={}, headers={}, status=200, content_type='text/html', cookies={}, pprint=False): + def template(self, tplfile, context={}, headers={}, status=200, content_type='text/html', pprint=False): self.status = status context.update({ 'response': self, @@ -148,7 +166,7 @@ class Response: }) html = self.app.template.render(tplfile, context, request=self.request, pprint=pprint) - return self.html(html, headers=headers, status=status, content_type=content_type, cookies=cookies) + return self.html(html, headers=headers, status=status, content_type=content_type) def error(self, message, status=500, **kwargs): @@ -158,13 +176,13 @@ class Response: return self.template('error.haml', {'error_message': message}, status=status, **kwargs) - def json(self, body={}, headers={}, status=200, content_type='application/json', cookies={}): + def json(self, body={}, headers={}, status=200, content_type='application/json'): body = json.dumps(body) - return self.get_response(body, headers, status, content_type, cookies) + return self.get_response(body, headers, status, content_type) - def text(self, body, headers={}, status=200, content_type='text/plain', cookies={}): - return self.get_response(body, headers, status, content_type, cookies) + def text(self, body, headers={}, status=200, content_type='text/plain'): + return self.get_response(body, headers, status, content_type) def html(self, *args, **kwargs): @@ -174,24 +192,26 @@ class Response: def css(self, *args, **kwargs): self.content_type = 'text/css' - return self.text.text(*args, **kwargs) + return self.text(*args, **kwargs) def javascript(self, *args, **kwargs): self.content_type = 'application/javascript' - return self.text.text(*args, **kwargs) + return self.text(*args, **kwargs) def activitypub(self, *args, **kwargs): self.content_type = 'application/activity+json' - return self.text.text(*args, **kwargs) + return self.text(*args, **kwargs) def redir(self, path, status=302, headers={}): - return sanic.response.redirect(path, status=status, headers={}) + headers.update(dict(location=path)) + return self.text(body=None, status=status, headers=headers) + #return sanic.response.redirect(path, status=status, headers={}) - def set_data(self, body=None, headers={}, status=200, content_type='text/html', cookies={}): + def set_data(self, body=None, headers={}, status=200, content_type='text/html'): ctype = self.content_types.get(content_type, content_type) self.body = body diff --git a/http_server/izzylib/http_server/view.py b/http_server/izzylib/http_server/view.py index 0eaa691..5db3de3 100644 --- a/http_server/izzylib/http_server/view.py +++ b/http_server/izzylib/http_server/view.py @@ -48,6 +48,14 @@ class Manifest(View): return response.json(data) +class Robots(View): + paths = ['/robots.txt'] + + async def get(self, request, response): + data = '# Disallow all\nUser-agent: *\nDisallow: /' + return response.text(data) + + class Style(View): paths = ['/framework/style.css'] diff --git a/http_server/setup.py b/http_server/setup.py index 46a3668..d3dd724 100644 --- a/http_server/setup.py +++ b/http_server/setup.py @@ -5,6 +5,7 @@ from setuptools import setup, find_namespace_packages requires = [ 'sanic>=20.12.3', 'sanic-cors>=1.0.0', + 'envbash>=1.0.0' ] diff --git a/sql/izzylib/sql/__init__.py b/sql/izzylib/sql/__init__.py index 5228cfa..3aa54d0 100644 --- a/sql/izzylib/sql/__init__.py +++ b/sql/izzylib/sql/__init__.py @@ -1,2 +1,6 @@ -from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables +# old sql classes +from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables, OperationalError, ProgrammingError from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession + +from .database import Database, Session +from .queries import Column, Insert, Select, Table, Tables, Update diff --git a/sql/izzylib/sql/config.py b/sql/izzylib/sql/config.py new file mode 100644 index 0000000..8202a1b --- /dev/null +++ b/sql/izzylib/sql/config.py @@ -0,0 +1,100 @@ +import importlib, sqlite3, ssl + +from getpass import getuser +from izzylib import DotDict, Path, izzylog + + +defaults = { + 'name': (None, str), + 'host': (None, str), + 'port': (None, int), + 'username': (getuser(), str), + 'password': (None, str), + 'ssl': ('allow', str), + 'ssl_context': (ssl.create_default_context(), ssl.SSLContext), + 'ssl_key': (None, Path), + 'ssl_cert': (None, Path), + 'max_connections': (25, int), + 'type': ('sqlite', str), + 'module': (sqlite3, None), + 'mod_name': ('sqlite3', str), + 'timeout': (5, int), + 'args': ([], list), + 'kwargs': ({}, dict) +} + +modtypes = { + 'sqlite': ['sqlite3'], + 'postgresql': ['pg8000', 'psycopg2', 'psycopg3', 'pgdb'], + 'mysql': ['mysqldb', 'trio_mysql'], + 'mssql': ['pymssql', 'adodbapi'] +} + +sslmodes = ['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full'] + + +class Config(DotDict): + def __init__(self, **kwargs): + super().__init__({k: v[0] for k,v in defaults.items()}) + + module = kwargs.pop('module', None) + + if module: + self.parse_module(module) + + self.update(kwargs) + + if self.ssl != 'disable' and (self.ssl_key or self.ssl_cert): + self.ssl_context.load_cert_chain(self.ssl_cert, self.ssl_key) + + + def __setitem__(self, key, value): + if key not in defaults: + raise KeyError(f'Invalid config option: {key}') + + valtype = defaults[key][1] + + if valtype and value and not isinstance(value, valtype): + raise TypeError(f'{key} should be a {valtype}, not a {value.__class__.__name__}') + + if key == 'ssl' and value == True: + value = ssl.create_default_context() + + super().__setitem__(key, value) + + + def parse_module(self, name): + module = None + module_type = None + module_name = None + + if name == 'sqlite3': + name = 'sqlite' + + for mtype, modules in modtypes.items(): + if name == mtype: + module_type = name + + for mod in modules: + try: + module = importlib.import_module(mod) + module_name = mod + break + except ImportError: + izzylog.verbose(f'Database module not installed:', mod) + + elif name in modules: + try: + module = importlib.import_module(name) + module_type = mtype + module_name = name + break + except ImportError: + izzylog.error(f'Database module not installed:', name) + + if None in (module, module_name, module_type): + raise ValueError(f'Failed to find module for {name}') + + self.module = module + self.mod_name = module_name + self.type = module_type diff --git a/sql/izzylib/sql/database.py b/sql/izzylib/sql/database.py new file mode 100644 index 0000000..2f2ce9e --- /dev/null +++ b/sql/izzylib/sql/database.py @@ -0,0 +1,360 @@ +import sqlite3, traceback + +from functools import partial +from getpass import getuser +from izzylib import DotDict, izzylog, boolean, random_gen + +from . import error +from .config import Config +from .queries import Column, Delete, Insert, Select, Table, Tables, Update + + +class Database: + def __init__(self, tables=None, **kwargs): + self.tables = tables + self.cfg = Config(**kwargs) + self.sessions = DotDict() + + + @property + def session(self): + return self.get_session(False) + + + @property + def session_trans(self): + return self.get_session(True) + + + def connect(self, sid, session): + if len(self.sessions) >= self.cfg.max_connections: + raise error.MaxConnectionsError(f'Cannot start a new session with id {sid}. Reach max connection count of {self.cfg.max_connections}.') + + self.sessions[sid] = session + + + def disconnect(self, sid): + self.sessions[sid].disconnect() + del self.sessions[sid] + + + def disconnect_all(self): + sids = [] + + for sid in self.sessions.keys(): + sids.append(sid) + + for sid in sids: + self.disconnect(sid) + + + def get_session(self, trans=True): + session = Session(self, trans) + self.sessions[session.id] = session + return session + + + def execute(self, *args): + with self.session as s: + s.execute(*args) + + + def load_tables(self, path): + self.tables = Tables.new_from_json_file(path) + + + def pre_setup(self): + if self.cfg.type != 'postgresql': + izzylog.verbose(f'Database not supported for pre_setup: {self.cfg.type}') + return + + original_database = self.cfg.name + self.cfg.name = 'postgres' + + with self.session as s: + s.conn.autocommit = True + s.rollback() + + if original_database not in s.get_databases(): + #s.execute('SET AUTOCOMMIT = OFF') + s.cursor.execute(f'CREATE DATABASE {original_database}') + + s.conn.autocommit = False + + self.cfg.name = original_database + + + def set_row_class(self, table, row_class): + pass + + +class Session: + def __init__(self, db, trans): + self.id = random_gen() + self.db = db + self.cfg = db.cfg + self.trans = trans + self.trans_state = False + self.conn = None + self.cursor = None + + + def __del__(self): + try: + izzylog.verbose('Deleting session:', self.id) + except ModuleNotFoundError: + if izzylog.get_config('level') >= 20: + print('[izzylib] VERBOSE: Deleting session:', self.id) + + self.db.sessions.pop(self.id, None) + + if self.conn: + self.disconnect() + + + def __enter__(self): + self.connect() + + if self.trans: + self.begin() + + return self + + + def __exit__(self, exc_type, exc_value, exc_traceback): + if exc_traceback: + self.rollback() + + else: + self.commit() + + self.disconnect() + self.db.disconnect(self.id) + + + def connect(self): + if self.conn: + return + + self.db.connect(self.id, self) + + if self.cfg.type == 'sqlite': + self.conn = self.cfg.module.connect(self.cfg.name, self.cfg.timeout, check_same_thread=True) + + elif self.cfg.type == 'postgresql': + options = dict( + host = self.cfg.host or '/var/run/postgresql', + port = self.cfg.port or 5432, + database = self.cfg.name or 'postgresql', + user = self.cfg.username or getuser(), + password = self.cfg.password, + ) + + if self.cfg.mod_name == 'pg8000': + if options['host'] in [None, '/var/run/postgresql']: + port = options.pop('port') + options['unix_sock'] = options.pop('host') + f'/.s.PGSQL.{port}' + + ## SSL is a pain in the ass tbh. Gonna deal with this later + #if self.cfg.mod_name == 'pg8000': + #options['sslmode'] = self.cfg.ssl + #options['ssl_context'] = self.cfg.ssl_context + + #elif self.cfg.mod_name == 'psycopg2': + #options['sslcert'] = self.cfg.ssl_cert + #options['sslkey'] = self.cfg.ssl_key + + self.conn = self.cfg.module.connect(**options) + + else: + raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}') + + try: + self.conn.autocommit = False + except: + izzylog.verbose('Failed to turn off autocommit') + + self.cursor = self.conn.cursor() + + + def disconnect(self): + if not self.conn: + return + + self.cursor.close() + self.conn.close() + + self.cursor = None + self.conn = None + + + def begin(self): + if self.trans_state: + return + + #self.conn.begin() + self.execute('BEGIN TRANSACTION') + self.trans_state = True + + + def rollback(self): + if not self.trans_state: + return + + self.conn.rollback() + #self.execute('ROLLBACK TRANSACTION') + self.trans_state = False + + + def commit(self): + if not self.trans_state: + return + + self.conn.commit() + #self.execute('COMMIT TRANSACTION') + self.trans_state = False + + + ## data management functions + def execute(self, string, values=[]): + if any(map(string.lower().startswith, ['insert', 'update', 'remove', 'create', 'drop'])) and not self.trans_state: + raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.') + + self.cursor.execute(string, values) + return self.cursor + + + def fetch(self, table, single=True, **kwargs): + rows = [] + data = Select(table, type=self.cfg.type, **kwargs).exec(self) + + for line in data: + row = Row(table, self.cursor.description, line) + + if single: + return row + + rows.append(row) + + return rows if not single else None + + + def search(self, table, **kwargs): + return self.fetch(table, single=False, **kwargs) + + + def insert(self, table, **kwargs): + if not self.trans_state: + raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.') + + Insert(table, type=self.cfg.type, **kwargs).exec(self) + return self.fetch(table, **kwargs) + + + def update(self, table, rowid, **kwargs): + if not self.trans_state: + raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.') + + Update(table, rowid, type=self.cfg.type, **kwargs).exec(self) + return self.fetch(table, id=rowid) + + + def delete(self, table, **kwargs): + if not self.trans_state: + raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.') + + Delete(table, type=self.cfg.type, **kwargs).exec(self) + + + ## helper functions + def get_columns(self, table): + if table not in self.get_tables(): + raise KeyError(f'Not an existing table: {table}') + + if self.cfg.type == 'sqlite': + rows = self.execute(f'PRAGMA table_info({table})') + return [row[1] for row in rows] + + elif self.cfg.type == 'postgresql': + rows = self.execute(f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table}'") + return [row[0] for row in rows] + + else: + raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}') + + + def get_tables(self): + if self.cfg.type == 'sqlite': + rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'") + + elif self.cfg.type == 'postgresql': + rows = self.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name") + + else: + raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}') + + return [row[0] for row in rows] + + + def get_databases(self): + if self.cfg.type == 'sqlite': + izzylog.verbose('This function is useless with sqlite') + return + + elif self.cfg.type == 'postgresql': + databases = [row[0] for row in self.execute('SELECT datname FROM pg_database')] + + else: + raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}') + + return databases + + + def cursor_description(self): + return [row[0] for row in self.cursor.description] + + + def setup_database(self): + if not self.db.tables: + raise ValueError('Tables have not been specified.') + + current_tables = self.get_tables() + + for name, table in self.db.tables.items(): + if name in current_tables: + izzylog.verbose(f'Skipping table creation since it already exists: {name}') + continue + + izzylog.verbose(f'Creating table: {name}') + self.execute(table.build(self.cfg.type)) + + +class Row(DotDict): + def __init__(self, table, keys, values): + self._db = None + self._table = table + + super().__init__() + + for idx, key in enumerate([key[0] for key in keys]): + self[key] = values[idx] + + + def update(self, data): + for k, v in data.items(): + if k not in self: + raise KeyError(f'Not a column for {self._table}') + + self[k] = v + + + def delete(self): + with self._db.session as s: + s.delete(self._table, id=self.id) + + + def update(self, **kwargs): + self.update(kwargs) + + with self._db.session as s: + s.update(self._table, id=self.id, **kwargs) diff --git a/sql/izzylib/sql/error.py b/sql/izzylib/sql/error.py new file mode 100644 index 0000000..4c60314 --- /dev/null +++ b/sql/izzylib/sql/error.py @@ -0,0 +1,10 @@ +class MaxConnectionsError(Exception): + 'raise when the max amount of connections has been reached' + + +class NoTransactionError(Exception): + 'raise when a write command is executed outside a transaction' + + +class DatabaseNotSupportedError(Exception): + 'raise when the action being performed is not supported by the database in use' diff --git a/sql/izzylib/sql/generic.py b/sql/izzylib/sql/generic.py index 48d01fe..aa313ed 100644 --- a/sql/izzylib/sql/generic.py +++ b/sql/izzylib/sql/generic.py @@ -3,7 +3,7 @@ import json, sys, threading, time from contextlib import contextmanager from datetime import datetime from sqlalchemy import create_engine, ForeignKey, MetaData, Table -from sqlalchemy import Column, types as Types +from sqlalchemy import Column as sqlalchemy_column, types as Types from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.orm import scoped_session, sessionmaker @@ -57,6 +57,12 @@ class SqlDatabase: engine_string += f'/{database}' engine_kwargs['connect_args'] = {'check_same_thread': False} + elif dbtype == 'postgresql': + ssl_context = kwargs.get('ssl') + + if ssl_context: + engine_kwargs['ssl_context'] = ssl_context + else: user = kwargs.get('user') password = kwargs.get('pass') @@ -124,9 +130,9 @@ class SqlDatabase: self.table_names = tables.keys() - def execute(self, *args, **kwargs): + def execute(self, string, values=[]): with self.session as s: - return s.execute(*args, **kwargs) + s.execute(string, values) class SqlSession(object): @@ -267,7 +273,7 @@ class SqlSession(object): if not rowid or not table: raise ValueError('Missing row ID or table') - row = self.execute(f'DELETE FROM {table} WHERE id={rowid}') + self.execute(f'DELETE FROM {table} WHERE id={rowid}') def drop_table(self, name): @@ -284,12 +290,28 @@ class SqlSession(object): self.drop_table(table) + def get_columns(self, table): + if table not in self.get_tables(): + raise KeyError(f'Not an existing table: {table}') + + rows = self.execute('PRAGMA table_info(user)') + return [row[1] for row in rows] + + def get_tables(self): rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'") return [row[0] for row in rows] - def append_column(self, tbl, col): + def append_column(self, table, column): + if column.name in self.get_columns(table): + logging.warning(f'Table "{table}" already has column "{column.name}"') + return + + self.execute(f'ALTER TABLE {table} ADD COLUMN {column.compile()}') + + + def append_column2(self, tbl, col): table = self.table[tbl] try: @@ -301,7 +323,7 @@ class SqlSession(object): columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')] - if col in columns: + if col in self.get_columns(tbl): izzylog.info(f'Column "{col}" already exists') return @@ -436,29 +458,51 @@ class Tables(DotDict): def __setup_table(self, name, table): - columns = [col if type(col) == Column else Column(*col.get('args'), **col.get('kwargs')) for col in table] + columns = [col if type(col) == SqlColumn else SqlColumn(*col.get('args'), **col.get('kwargs')) for col in table] self[name] = Table(name, self.meta, *columns) -def SqlColumn(name, stype=None, fkey=None, **kwargs): - if not stype and not kwargs: - if name == 'id': - return Column('id', SqlTypes['integer'], primary_key=True, autoincrement=True) +class SqlColumn(sqlalchemy_column): + def __init__(self, name, stype=None, fkey=None, **kwargs): + if not stype and not kwargs: + if name == 'id': + stype = 'integer' + kwargs['primary_key'] = True + kwargs['autoincrement'] = True - elif name == 'timestamp': - return Column('timestamp', SqlTypes['datetime']) + elif name == 'timestamp': + stype = 'datetime' - raise ValueError('Missing column type and options') + else: + raise ValueError('Missing column type and options') - else: - try: - stype = stype or 'string' - options = [name, SqlTypes[stype.lower()]] + stype = (stype.lower() if type(stype) == str else stype) or 'string' - except KeyError: - raise KeyError(f'Invalid SQL data type: {stype}') + if type(stype) == str: + try: + stype = SqlTypes[stype.lower()] + + except KeyError: + raise KeyError(f'Invalid SQL data type: {stype}') + + options = [name, stype] if fkey: options.append(ForeignKey(fkey)) - return Column(*options, **kwargs) + super().__init__(*options, **kwargs) + + + def compile(self): + sql = f'{self.name} {self.type}' + + if not self.nullable: + sql += ' NOT NULL' + + if self.primary_key: + sql += ' PRIMARY KEY' + + if self.unique: + sql += ' UNIQUE' + + return sql diff --git a/sql/izzylib/sql/queries.py b/sql/izzylib/sql/queries.py new file mode 100644 index 0000000..b780e9b --- /dev/null +++ b/sql/izzylib/sql/queries.py @@ -0,0 +1,415 @@ +from datetime import datetime +from functools import partial +from izzylib import DotDict, Path + +from .types import BaseType, Type + + +placeholders = dict( + sqlite = '?', + postgresql = '%s' +) + + +## Data queries +class Delete: + def __init__(self, table, type='sqlite', **kwargs): + self.table = table + self.placeholder = placeholders[type] + self.keys = [] + self.values = [] + + for k,v in kwargs.items(): + self.keys.append(k) + self.values.append(v) + + + def __str__(self): + self.build(embed_values=True) + + + def build(self, comp_type='AND', embed_values=False): + sql = 'DELETE FROM {table} WHERE {kstring}' + + if not embed_values: + kstring = f' {comp_type.upper()} '.join([f'{k} = {self.placeholder}' for k in self.keys]) + return sql.format(table=self.table, kstring=kstring), self.values + + values = [] + + for idx, value in enumerate(self.values): + if type(value) == str: + values.append(f"{self.keys[idx]} = '{value}'") + + else: + values.append(f"{self.keys[idx]} = {value}") + + kstring = ','.join(values) + return sql.format(table=self.table, kstring=kstring, rowid=self.rowid) + + + def exec(self, session, comp_type='AND'): + return session.execute(*self.build(comp_type)) + + +class Insert: + def __init__(self, table, type='sqlite', **kwargs): + self.table = table + self.placeholder = placeholders[type] + self.keys = [] + self.values = [] + + for k, v in kwargs.items(): + self.keys.append(k) + self.values.append(v) + + + def __str__(self): + return self.build(embed_values=True) + + + def build(self, embed_values=False): + kstring = ','.join(self.keys) + + if not embed_values: + vstring = ','.join([self.placeholder for k in self.keys]) + return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})', self.values + + else: + vstring = ','.join(self.values) + return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})' + + + def exec(self, session): + return session.execute(*self.build()) + + +class Select: + def __init__(self, table, columns=[], type='sqlite', **kwargs): + self.placeholder = placeholders[type] + self.columns = columns + self.table = table + self.where = [] + self.where_build = [] + self._order = [] + self.keys = [] + self.values = [] + + self.equals = partial(self.__comparison, '=') + self.less = partial(self.__comparison, '<') + self.greater = partial(self.__comparison, '>') + self.like = partial(self.__comparison, 'LIKE') + + for k,v in kwargs.items(): + self.equals(k, v) + + + def __str__(self): + return self.build(embed_values=True) + + + def __comparison(self, comp, key, value): + self.values.append(value) + self.keys.append(key) + self.where.append(f'{key} {comp.upper()} {self.placeholder}') + self.where_build.append(f"{key} {comp.upper()} '{value}'" if type(key) == str else f"{key} {comp.upper()} {value}") + return self + + + def order(self, column, asc=True): + self._order = [column, 'ASC' if asc else 'DESC'] + return self + + + def build(self, comp_type='AND', embed_values=False): + if not self.columns: + cols = '*' + + else: + cols = ','.join('columns') + + sql_query = f'SELECT {cols} FROM {self.table}' + + if self.where: + where = f' {comp_type.upper()} '.join(self.where if not embed_values else self.where_build) + sql_query += f' WHERE {where}' + + if self._order: + col, order = self._order + sql_query += f' ORDER BY {col} {order}' + + if embed_values: + return sql_query + + return sql_query, self.values + + + def exec(self, session, comp_type='AND'): + return session.execute(*self.build(comp_type)) + + +class Update: + def __init__(self, table, rowid, type='sqlite', **kwargs): + self.placeholder = placeholders[type] + self.table = table + self.rowid = rowid + self.keys = [] + self.values = [] + + for k,v in kwargs.items(): + self.keys.append(k) + self.values.append(v) + + + def __str__(self): + return self.build(embed_values=True) + + + def build(self, embed_values=False): + sql = 'UPDATE {table} SET {kstring} WHERE id={rowid}' + + if not embed_values: + kstring = ','.join([f'{k} = {self.placeholder}' for k in self.keys]) + return sql.format(table=self.table, kstring=kstring, rowid=self.rowid), self.values + + values = [] + + for idx, value in enumerate(self.values): + if type(value) == str: + values.append(f"{self.keys[idx]} = '{value}'") + + else: + values.append(f"{self.keys[idx]} = {value}") + + kstring = ','.join(values) + return sql.format(table=self.table, kstring=kstring, rowid=self.rowid) + + + def exec(self, session): + return session.execute(*self.build()) + + +## Database objects +class Column: + def __init__(self, name, type='STRING', unique=False, nullable=True, default=None, primary_key=False, autoincrement=False, foreign_key=None): + self.name = name + self.type = type + self.nullable = nullable + self.default = default + self.primary_key = primary_key + self.autoincrement = autoincrement + self.unique = unique + + if any(map(isinstance, [foreign_key], [list, tuple, set])): + self.foreign_key = foreign_key + + else: + self.foreign_key = foreign_key.split('.') if foreign_key else None + + if autoincrement: + self.primary_key = True + self.type = Type['INTEGER'] + + if isinstance(self.type, BaseType): + self.type = self.type.name + + else: + if self.type.upper() in Type.keys(): + self.type = self.type.upper() + + else: + raise TypeError(f'Invalid SQL type: {self.type}') + + if foreign_key and len(self.foreign_key) != 2: + raise ValueError('Invalid foreign key. Must be in the format "table.column".') + + + def __str__(self): + return self.build() + + + def build(self, dbtype='sqlite'): + if dbtype == 'postgresql': + if self.type.lower() == 'string': + self.type = 'TEXT' + + elif self.type.lower() == 'datetime': + self.type = 'TIMESTAMPTZ' + + if self.autoincrement: + self.type = 'SERIAL' + self.autoincrement = False + + sql = f'{self.name} {self.type}' + + if self.primary_key: + sql += ' PRIMARY KEY' + + if self.autoincrement: + sql += ' AUTOINCREMENT' + + if self.unique: + sql += ' UNIQUE' + + if not self.nullable: + sql += ' NOT NULL' + + if self.default: + def_type = type(self.default) + + if self.default == 'CURRENT_TIMESTAMP': + if dbtype == 'sqlite': + sql += " DEFAULT (datetime('now', 'localtime'))" + + elif dbtype == 'postgresql': + sql += ' DEFAULT now()' + + else: + sql += f' DEFAULT {datetime.now().timestamp()}' + + elif def_type == str: + sql += f" DEFAULT '{self.default}'" + + elif def_type in [int, float]: + sql += f' DEFAULT {self.default}' + + elif def_type == bool and dbtype == 'sqlite': + sql += f' DEFAULT {int(self.default)}' + + else: + sql += f' DEFAULT {self.default}' + + print(sql) + return sql + + + def json(self): + return DotDict({ + 'type': self.type, + 'nullable': self.nullable, + 'default': self.default, + 'primary_key': self.primary_key, + 'autoincrement': self.autoincrement, + 'unique': self.unique, + 'foreign_key': self.foreign_key + }) + + +class Table(DotDict): + def __init__(self, name, *columns): + super().__init__() + self._name = name + self._foreign_keys = {} + + self.add_column(Column('id', autoincrement=True)) + + for column in columns: + self.add_column(column) + + + def __str__(self): + return self.build() + + + # this'll be useful later + def __call__(self, *args, **kwargs): + pass + + + @property + def name(self): + return self._name + + + def add_column(self, column): + self[column.name] = column + + if column.foreign_key: + self._foreign_keys[column.name] = column.foreign_key + + + def build(self, dbtype='sqlite'): + column_string = ',\n'.join([f'\t{col.build(dbtype)}' for col in self.values()]) + + if self._foreign_keys: + column_string += ',\n' + column_string += ',\n'.join([f'\tFOREIGN KEY ({column}) REFERENCES {key[0]} ({key[1]})' for column, key in self._foreign_keys.items()]) + + return f'''CREATE TABLE {self.name} ( +{column_string} +);''' + + + def json(self): + data = {} + + for name, column in self.items(): + data[name] = column.json() + + return data + + +class Tables(DotDict): + def __init__(self, *tables, data={}): + super().__init__() + + for table in tables: + self.add_table(table) + + if data: + self.from_dict(data) + + + def __str__(self): + return self.build() + + + @classmethod + def new_from_json_file(cls, path): + return cls(data=DotDict.new_from_json_file(path)) + + + def add_table(self, table): + self[table.name] = table + + + def build(self): + return '\n\n'.join([str(table) for table in self.values()]) + + + def load_json(self, path): + data = DotDict() + data.load_json(path) + + self.from_dict(data) + + + def save_json(self, path, indent='\t'): + self.to_dict().save_json(path, indent=indent) + + + def from_dict(self, data): + for name, columns in data.items(): + table = Table(name) + + for col, kwargs in columns.items(): + table.add_column(Column(col, + type = kwargs.get('type', 'STRING'), + nullable = kwargs.get('nullable', True), + default = kwargs.get('default'), + primary_key = kwargs.get('primary_key', False), + autoincrement = kwargs.get('autoincrement', False), + unique = kwargs.get('unique', False), + foreign_key = kwargs.get('foreign_key') + )) + + self.add_table(table) + + + def to_dict(self): + data = DotDict() + + for name, table in self.items(): + data[name] = table.json() + + return data diff --git a/sql/izzylib/sql/row.py b/sql/izzylib/sql/row.py new file mode 100644 index 0000000..2a5416d --- /dev/null +++ b/sql/izzylib/sql/row.py @@ -0,0 +1,19 @@ +from izzylib import DotDict + + +class DbRow(DotDict): + def __init__(self, table, keys, values): + self.table = table + + super().__init__() + + for idx, key in enumerate(keys): + self[key] = values[idx] + + + def delete(self): + pass + + + def update(self, **kwargs): + pass diff --git a/sql/izzylib/sql/types.py b/sql/izzylib/sql/types.py new file mode 100644 index 0000000..e53b400 --- /dev/null +++ b/sql/izzylib/sql/types.py @@ -0,0 +1,19 @@ +from enum import Enum +from izzylib import DotDict + + +class BaseType(Enum): + INTEGER = int + TEXT = str + BLOB = bytes + REAL = float + NUMERIC = float + + +Type = DotDict( + **{v: BaseType.INTEGER for v in ['INT', 'INTEGER', 'TINYINT', 'SMALLINT', 'MEDIUMINT', 'BIGINT', 'UNSIGNED BIG INT', 'INT2', 'INT8']}, + **{v: BaseType.TEXT for v in ['CHARACTER', 'VARCHAR', 'VARYING CHARACTER', 'NCHAR', 'NATIVE CHARACTER', 'NVARCHAR', 'TEXT', 'CLOB', 'STRING', 'JSON']}, + **{v: BaseType.BLOB for v in ['BYTES', 'BLOB']}, + **{v: BaseType.REAL for v in ['REAL', 'DOUBLE', 'DOUBLE PRECISION', 'FLOAT']}, + **{v: BaseType.NUMERIC for v in ['NUMERIC', 'DECIMAL', 'BOOLEAN', 'DATE', 'DATETIME']} +)