This commit is contained in:
Izalia Mae 2021-08-29 22:14:55 -04:00
parent 11aaaf3499
commit cd20eaea97
25 changed files with 1155 additions and 64 deletions

3
.gitignore vendored
View file

@ -117,10 +117,13 @@ dmypy.json
test*.py test*.py
reload.cfg reload.cfg
# symlinks
/izzylib /izzylib
/base/izzylib/dbus /base/izzylib/dbus
/base/izzylib/hasher /base/izzylib/hasher
/base/izzylib/http_requests_client /base/izzylib/http_requests_client
/base/izzylib/http_server
/base/izzylib/mbus
/base/izzylib/sql /base/izzylib/sql
/base/izzylib/template /base/izzylib/template
/base/izzylib/tinydb /base/izzylib/tinydb

View file

@ -4,7 +4,7 @@ Licensed under the CNPL: https://git.pixie.town/thufie/CNPL
https://git.barkshark.xyz/izaliamae/izzylib https://git.barkshark.xyz/izaliamae/izzylib
''' '''
import sys, traceback import os, sys, traceback
assert sys.version_info >= (3, 7) assert sys.version_info >= (3, 7)
__version_tpl__ = (0, 6, 0) __version_tpl__ = (0, 6, 0)
@ -13,6 +13,7 @@ __version__ = '.'.join([str(v) for v in __version_tpl__])
from . import logging from . import logging
izzylog = logging.logger['IzzyLib'] izzylog = logging.logger['IzzyLib']
izzylog.set_config('level', os.environ.get('IZZYLOG_LEVEL', 'INFO'))
from .path import Path from .path import Path
from .dotdict import DotDict, LowerDotDict, DefaultDotDict, MultiDotDict, JsonEncoder from .dotdict import DotDict, LowerDotDict, DefaultDotDict, MultiDotDict, JsonEncoder

View file

@ -42,6 +42,10 @@ def parse_ttl(ttl):
return multiplier * int(amount) return multiplier * int(amount)
class DefaultValue(object):
pass
class BaseCache(OrderedDict): class BaseCache(OrderedDict):
_get = OrderedDict.get _get = OrderedDict.get
_items = OrderedDict.items _items = OrderedDict.items
@ -120,6 +124,20 @@ class BaseCache(OrderedDict):
return item return item
def pop(self, key, default=DefaultValue):
try:
item = self.get(key)
del self[key]
return item
except Exception as e:
if default == DefaultValue:
raise e from None
return default
## This doesn't work for some reason ## This doesn't work for some reason
def CacheDecorator(cache): def CacheDecorator(cache):
def decorator(func): def decorator(func):

View file

@ -59,6 +59,13 @@ class DotDict(dict):
raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None
@classmethod
def new_from_json_file(cls, path):
data = cls()
data.load_json(path)
return data
def copy(self): def copy(self):
return DotDict(self) return DotDict(self)

View file

@ -22,6 +22,7 @@ __all__ = [
'print_methods', 'print_methods',
'prompt', 'prompt',
'random_gen', 'random_gen',
'remove',
'signal_handler', 'signal_handler',
'time_function', 'time_function',
'time_function_pprint', 'time_function_pprint',
@ -348,6 +349,13 @@ def random_gen(length=20, letters=True, numbers=True, extra=None):
return ''.join(random.choices(characters, k=length)) return ''.join(random.choices(characters, k=length))
def remove(string: str, junk: list):
for line in junk:
string = string.replace(line, '')
return string
def signal_handler(func, *args, original_args=True, **kwargs): def signal_handler(func, *args, original_args=True, **kwargs):
if original_args: if original_args:
handler = lambda signum, frame: func(signum, frame, *args, **kwargs) handler = lambda signum, frame: func(signum, frame, *args, **kwargs)

View file

@ -1,4 +1,4 @@
import os, shutil import json, os, shutil
from datetime import datetime from datetime import datetime
from functools import cached_property from functools import cached_property
@ -105,9 +105,9 @@ class Path(str):
return json.load(s) return json.load(s)
def json_dump(self, data): def json_dump(self, data, indent=None):
with self.open('w') as s: with self.open('w') as s:
s.write(json.dumps(data)) s.write(json.dumps(data, indent=indent))
def link(self, path): def link(self, path):

View file

@ -4,6 +4,7 @@ start_time = datetime.now()
from .application import Application from .application import Application
from .config import Config, UserLevel from .config import Config, UserLevel
from .middleware import MiddlewareBase, Headers, AccessLog
from .request import Request from .request import Request
from .response import Response from .response import Response
from .view import View from .view import View

View file

@ -12,7 +12,7 @@ from izzylib.template import Template
from .config import Config, UserLevel from .config import Config, UserLevel
from .error_handlers import GenericError, MissingTemplateError from .error_handlers import GenericError, MissingTemplateError
from .middleware import AccessLog, Headers from .middleware import AccessLog, Headers
from .view import Manifest, Style from .view import Manifest, Robots, Style
log_path_ignore = [ log_path_ignore = [
@ -46,10 +46,12 @@ class Application(sanic.Sanic):
) )
self.template.add_env('cfg', self.cfg) self.template.add_env('cfg', self.cfg)
self.template.add_env('len', len)
if self.cfg.tpl_default: if self.cfg.tpl_default:
self.template.add_search_path(frontend) self.template.add_search_path(frontend)
self.add_class_route(Manifest) self.add_class_route(Manifest)
self.add_class_route(Robots)
self.add_class_route(Style) self.add_class_route(Style)
self.static('/favicon.ico', frontend.join('static/icon64.png')) self.static('/favicon.ico', frontend.join('static/icon64.png'))
self.static('/framework/static', frontend.join('static')) self.static('/framework/static', frontend.join('static'))
@ -93,6 +95,11 @@ class Application(sanic.Sanic):
del self.cfg.menu[name] del self.cfg.menu[name]
def get_route_by_path(self, path, method='get', host=None,):
route, handler, _ = self.router.get(path, method, host)
return handler
def start(self): def start(self):
# register built-in middleware now so they're last in the chain # register built-in middleware now so they're last in the chain
self.add_middleware(Headers) self.add_middleware(Headers)

View file

@ -11,6 +11,7 @@ class UserLevel(IntEnum):
USER = 10 USER = 10
MODERATOR = 20 MODERATOR = 20
ADMIN = 30 ADMIN = 30
AUTH = 1000
class Config(DotDict): class Config(DotDict):
@ -25,6 +26,7 @@ class Config(DotDict):
port = 8080, port = 8080,
proto = 'http', proto = 'http',
workers = cpu_count(), workers = cpu_count(),
access_log = True,
request_class = Request, request_class = Request,
response_class = Response, response_class = Response,
sig_handler = None, sig_handler = None,

View file

@ -39,7 +39,6 @@ a:hover {
} }
input:not([type='checkbox']), select, textarea { input:not([type='checkbox']), select, textarea {
margin: 1px 0;
color: var(--text); color: var(--text);
background-color: var(--background); background-color: var(--background);
border: 1px solid var(--background); border: 1px solid var(--background);
@ -263,7 +262,7 @@ details:focus, summary:focus {
transition: background-color var(--trans-speed); transition: background-color var(--trans-speed);
width: 55px; width: 55px;
margin-left: var(--gap); margin-left: var(--gap);
background-image: url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/menu.svg'); background-image: url('/framework/static/menu.svg');
background-size: 50px; background-size: 50px;
background-position: center center; background-position: center center;
background-repeat: no-repeat; background-repeat: no-repeat;
@ -412,8 +411,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken
@font-face { @font-face {
font-family: 'sans undertale'; font-family: 'sans undertale';
src: local('Nunito Sans Bold'), src: local('Nunito Sans Bold'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-SemiBold.woff2') format('woff2'), url('/framework/static/nunito/NunitoSans-SemiBold.woff2') format('woff2'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-SemiBold.ttf') format('ttf'); url('/framework/static/nunito/NunitoSans-SemiBold.ttf') format('ttf');
font-weight: bold; font-weight: bold;
font-style: normal; font-style: normal;
} }
@ -421,8 +420,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken
@font-face { @font-face {
font-family: 'sans undertale'; font-family: 'sans undertale';
src: local('Nunito Sans Light Italic'), src: local('Nunito Sans Light Italic'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-ExtraLightItalic.woff2') format('woff2'), url('/framework/static/nunito/NunitoSans-ExtraLightItalic.woff2') format('woff2'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-ExtraLightItalic.ttf') format('ttf'); url('/framework/static/nunito/NunitoSans-ExtraLightItalic.ttf') format('ttf');
font-weight: normal; font-weight: normal;
font-style: italic; font-style: italic;
} }
@ -430,8 +429,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken
@font-face { @font-face {
font-family: 'sans undertale'; font-family: 'sans undertale';
src: local('Nunito Sans Bold Italic'), src: local('Nunito Sans Bold Italic'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Italic.woff2') format('woff2'), url('/framework/static/nunito/NunitoSans-Italic.woff2') format('woff2'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Italic.ttf') format('ttf'); url('/framework/static/nunito/NunitoSans-Italic.ttf') format('ttf');
font-weight: bold; font-weight: bold;
font-style: italic; font-style: italic;
} }
@ -439,8 +438,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken
@font-face { @font-face {
font-family: 'sans undertale'; font-family: 'sans undertale';
src: local('Nunito Sans Light'), src: local('Nunito Sans Light'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Light.woff2') format('woff2'), url('/framework/static/nunito/NunitoSans-Light.woff2') format('woff2'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Light.ttf') format('ttf'); url('/framework/static/nunito/NunitoSans-Light.ttf') format('ttf');
font-weight: normal; font-weight: normal;
font-style: normal; font-style: normal;
} }

View file

@ -2,20 +2,21 @@
%html %html
%head %head
%title << {{cfg.name}}: {{page}} %title << {{cfg.name}}: {{page}}
%link rel='stylesheet' type='text/css' href='{{cfg.proto}}://{{cfg.web_host}}/framework/style.css' %link rel='stylesheet' type='text/css' href='/framework/style.css'
%link rel='manifest' href='{{cfg.proto}}://{{cfg.web_host}}/framework/manifest.json' %link rel='manifest' href='/framework/manifest.json'
%meta charset='UTF-8' %meta charset='UTF-8'
%meta name='viewport' content='width=device-width, initial-scale=1' %meta name='viewport' content='width=device-width, initial-scale=1'
-block head
%body %body
#body #body
#header.flex-container #header.flex-container
-if menu_left -if menu_left
#btn.section #btn.section
.page-title.section -> %a.title href='{{cfg.proto}}://{{cfg.web_host}}/' << {{cfg.name}} .page-title.section -> %a.title href='/' << {{cfg.name}}
-else -else
.page-title.section -> %a.title href='{{cfg.proto}}://{{cfg.web_host}}/' << {{cfg.name}} .page-title.section -> %a.title href='/' << {{cfg.name}}
#btn.section #btn.section
-if message -if message
@ -27,9 +28,15 @@
#menu.section #menu.section
.title-item.item << Menu .title-item.item << Menu
#items #items
-for label, data in cfg.menu.items() -if not len(cfg.menu):
-if request.user_level >= data[1] -include 'menu.haml'
.item -> %a href='{{cfg.proto}}://{{cfg.web_host}}{{data[0]}}' << {{label}}
-else:
-for label, path_data in cfg.menu.items()
-if path_data[1] == 1000 and request.user_level == 0:
.item -> %a href='{{path_data[0]}}' << {{label}}
-elif request.user_level >= path_data[1]
.item -> %a href='{{path_data[0]}}' << {{label}}
#content-body.section #content-body.section
-block content -block content
@ -40,4 +47,4 @@
.source .source
%a href='{{cfg.git_repo}}' target='_new' << {{cfg.name}}/{{cfg.version}} %a href='{{cfg.git_repo}}' target='_new' << {{cfg.name}}/{{cfg.version}}
%script type='application/javascript' src='{{cfg.proto}}://{{cfg.web_host}}/framework/static/menu.js' %script type='application/javascript' src='/framework/static/menu.js'

View file

@ -0,0 +1 @@
.item => %a(href='/') << Home

View file

@ -6,6 +6,34 @@ from izzylib import izzylog as logging, logging as applog
from . import start_time from . import start_time
cache_types = [
'text/css',
'application/javascript',
]
cache_base_types = [
'image',
'audio',
'video'
]
def cache_check(request, response):
content_type = response.headers.get('content-type')
if request.path.startswith('/framework'):
return True
if not content_type:
return False
if content_type in cache_types:
return True
if any(map(content_type.startswith, cache_base_types)):
return True
class MiddlewareBase: class MiddlewareBase:
attach = 'request' attach = 'request'
@ -22,7 +50,14 @@ class Headers(MiddlewareBase):
attach = 'response' attach = 'response'
async def handler(self, request, response): async def handler(self, request, response):
if request.path.startswith('/framework') or request.path == '/favicon.ico': if not response.headers.get('content-type'):
if request.path.endswith('.css'):
response.headers['content-type'] = 'text/css'
elif request.path.endswith('.js'):
response.headers['content-type'] = 'application/javascript'
if cache_check(request, response):
max_age = int(timedelta(weeks=2).total_seconds()) max_age = int(timedelta(weeks=2).total_seconds())
response.headers['Cache-Control'] = f'immutable,private,max-age={max_age}' response.headers['Cache-Control'] = f'immutable,private,max-age={max_age}'

View file

@ -1,14 +1,16 @@
import sanic import sanic
from izzylib import DotDict
from .misc import Headers from .misc import Headers
class Request(sanic.request.Request): class Request(sanic.request.Request):
def __init__(self, url_bytes, headers, version, method, transport, app): def __init__(self, url_bytes, headers, version, method, transport, app, **kwargs):
super().__init__(url_bytes, headers, version, method, transport, app) super().__init__(url_bytes, headers, version, method, transport, app)
self.Headers = Headers(headers) self.Headers = Headers(headers)
self.Data = Data(self) self.data = Data(self)
self.template = self.app.template self.template = self.app.template
self.user_level = 0 self.user_level = 0
self.setup() self.setup()

View file

@ -1,6 +1,7 @@
import json, sanic import json, sanic
from izzylib import DotDict from datetime import datetime
from izzylib import DotDict, izzylog
from izzylib.template import Color from izzylib.template import Color
from sanic.compat import Header from sanic.compat import Header
from sanic.cookies import CookieJar from sanic.cookies import CookieJar
@ -27,6 +28,18 @@ class Response:
'speed': 250 'speed': 250
}) })
cookie_keys = {
'expires': 'Expires',
'path': 'Path',
'comment': 'Comment',
'domain': 'Domain',
'max_age': 'Max-Age',
'secure': 'Secure',
'httponly': 'HttpOnly',
'version': 'Version',
'samesite': 'SameSite'
}
def __init__(self, app, request, body=None, headers={}, cookies={}, status=200, content_type='text/html'): def __init__(self, app, request, body=None, headers={}, cookies={}, status=200, content_type='text/html'):
# server objects # server objects
@ -110,6 +123,8 @@ class Response:
data.update(kwargs) data.update(kwargs)
for k,v in data.items(): for k,v in data.items():
k = self.cookie_keys.get(k, k)
if k.lower() == 'max-age': if k.lower() == 'max-age':
if isinstance(v, timedelta): if isinstance(v, timedelta):
v = int(v.total_seconds()) v = int(v.total_seconds())
@ -118,13 +133,16 @@ class Response:
raise TypeError('Max-Age must be an integer or timedelta') raise TypeError('Max-Age must be an integer or timedelta')
elif k.lower() == 'expires': elif k.lower() == 'expires':
if isinstance(v, datetime): if isinstance(v, str):
v = v.strftime('%a, %d-%b-%Y %T GMT') v = datetime.strptime(v, '%a, %d-%b-%Y %T GMT')
elif not isinstance(v, str): elif not isinstance(v, datetime):
raise TypeError('Expires must be a string or datetime') raise TypeError('Expires must be a string or datetime')
try:
self.cookies[key][k] = v self.cookies[key][k] = v
except KeyError as e:
izzylog.error('Invalid cookie key:', k)
def get_cookie(self, key): def get_cookie(self, key):
@ -140,7 +158,7 @@ class Response:
del self.cookies[key] del self.cookies[key]
def template(self, tplfile, context={}, headers={}, status=200, content_type='text/html', cookies={}, pprint=False): def template(self, tplfile, context={}, headers={}, status=200, content_type='text/html', pprint=False):
self.status = status self.status = status
context.update({ context.update({
'response': self, 'response': self,
@ -148,7 +166,7 @@ class Response:
}) })
html = self.app.template.render(tplfile, context, request=self.request, pprint=pprint) html = self.app.template.render(tplfile, context, request=self.request, pprint=pprint)
return self.html(html, headers=headers, status=status, content_type=content_type, cookies=cookies) return self.html(html, headers=headers, status=status, content_type=content_type)
def error(self, message, status=500, **kwargs): def error(self, message, status=500, **kwargs):
@ -158,13 +176,13 @@ class Response:
return self.template('error.haml', {'error_message': message}, status=status, **kwargs) return self.template('error.haml', {'error_message': message}, status=status, **kwargs)
def json(self, body={}, headers={}, status=200, content_type='application/json', cookies={}): def json(self, body={}, headers={}, status=200, content_type='application/json'):
body = json.dumps(body) body = json.dumps(body)
return self.get_response(body, headers, status, content_type, cookies) return self.get_response(body, headers, status, content_type)
def text(self, body, headers={}, status=200, content_type='text/plain', cookies={}): def text(self, body, headers={}, status=200, content_type='text/plain'):
return self.get_response(body, headers, status, content_type, cookies) return self.get_response(body, headers, status, content_type)
def html(self, *args, **kwargs): def html(self, *args, **kwargs):
@ -174,24 +192,26 @@ class Response:
def css(self, *args, **kwargs): def css(self, *args, **kwargs):
self.content_type = 'text/css' self.content_type = 'text/css'
return self.text.text(*args, **kwargs) return self.text(*args, **kwargs)
def javascript(self, *args, **kwargs): def javascript(self, *args, **kwargs):
self.content_type = 'application/javascript' self.content_type = 'application/javascript'
return self.text.text(*args, **kwargs) return self.text(*args, **kwargs)
def activitypub(self, *args, **kwargs): def activitypub(self, *args, **kwargs):
self.content_type = 'application/activity+json' self.content_type = 'application/activity+json'
return self.text.text(*args, **kwargs) return self.text(*args, **kwargs)
def redir(self, path, status=302, headers={}): def redir(self, path, status=302, headers={}):
return sanic.response.redirect(path, status=status, headers={}) headers.update(dict(location=path))
return self.text(body=None, status=status, headers=headers)
#return sanic.response.redirect(path, status=status, headers={})
def set_data(self, body=None, headers={}, status=200, content_type='text/html', cookies={}): def set_data(self, body=None, headers={}, status=200, content_type='text/html'):
ctype = self.content_types.get(content_type, content_type) ctype = self.content_types.get(content_type, content_type)
self.body = body self.body = body

View file

@ -48,6 +48,14 @@ class Manifest(View):
return response.json(data) return response.json(data)
class Robots(View):
paths = ['/robots.txt']
async def get(self, request, response):
data = '# Disallow all\nUser-agent: *\nDisallow: /'
return response.text(data)
class Style(View): class Style(View):
paths = ['/framework/style.css'] paths = ['/framework/style.css']

View file

@ -5,6 +5,7 @@ from setuptools import setup, find_namespace_packages
requires = [ requires = [
'sanic>=20.12.3', 'sanic>=20.12.3',
'sanic-cors>=1.0.0', 'sanic-cors>=1.0.0',
'envbash>=1.0.0'
] ]

View file

@ -1,2 +1,6 @@
from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables # old sql classes
from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables, OperationalError, ProgrammingError
from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession
from .database import Database, Session
from .queries import Column, Insert, Select, Table, Tables, Update

100
sql/izzylib/sql/config.py Normal file
View file

@ -0,0 +1,100 @@
import importlib, sqlite3, ssl
from getpass import getuser
from izzylib import DotDict, Path, izzylog
defaults = {
'name': (None, str),
'host': (None, str),
'port': (None, int),
'username': (getuser(), str),
'password': (None, str),
'ssl': ('allow', str),
'ssl_context': (ssl.create_default_context(), ssl.SSLContext),
'ssl_key': (None, Path),
'ssl_cert': (None, Path),
'max_connections': (25, int),
'type': ('sqlite', str),
'module': (sqlite3, None),
'mod_name': ('sqlite3', str),
'timeout': (5, int),
'args': ([], list),
'kwargs': ({}, dict)
}
modtypes = {
'sqlite': ['sqlite3'],
'postgresql': ['pg8000', 'psycopg2', 'psycopg3', 'pgdb'],
'mysql': ['mysqldb', 'trio_mysql'],
'mssql': ['pymssql', 'adodbapi']
}
sslmodes = ['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']
class Config(DotDict):
def __init__(self, **kwargs):
super().__init__({k: v[0] for k,v in defaults.items()})
module = kwargs.pop('module', None)
if module:
self.parse_module(module)
self.update(kwargs)
if self.ssl != 'disable' and (self.ssl_key or self.ssl_cert):
self.ssl_context.load_cert_chain(self.ssl_cert, self.ssl_key)
def __setitem__(self, key, value):
if key not in defaults:
raise KeyError(f'Invalid config option: {key}')
valtype = defaults[key][1]
if valtype and value and not isinstance(value, valtype):
raise TypeError(f'{key} should be a {valtype}, not a {value.__class__.__name__}')
if key == 'ssl' and value == True:
value = ssl.create_default_context()
super().__setitem__(key, value)
def parse_module(self, name):
module = None
module_type = None
module_name = None
if name == 'sqlite3':
name = 'sqlite'
for mtype, modules in modtypes.items():
if name == mtype:
module_type = name
for mod in modules:
try:
module = importlib.import_module(mod)
module_name = mod
break
except ImportError:
izzylog.verbose(f'Database module not installed:', mod)
elif name in modules:
try:
module = importlib.import_module(name)
module_type = mtype
module_name = name
break
except ImportError:
izzylog.error(f'Database module not installed:', name)
if None in (module, module_name, module_type):
raise ValueError(f'Failed to find module for {name}')
self.module = module
self.mod_name = module_name
self.type = module_type

360
sql/izzylib/sql/database.py Normal file
View file

@ -0,0 +1,360 @@
import sqlite3, traceback
from functools import partial
from getpass import getuser
from izzylib import DotDict, izzylog, boolean, random_gen
from . import error
from .config import Config
from .queries import Column, Delete, Insert, Select, Table, Tables, Update
class Database:
def __init__(self, tables=None, **kwargs):
self.tables = tables
self.cfg = Config(**kwargs)
self.sessions = DotDict()
@property
def session(self):
return self.get_session(False)
@property
def session_trans(self):
return self.get_session(True)
def connect(self, sid, session):
if len(self.sessions) >= self.cfg.max_connections:
raise error.MaxConnectionsError(f'Cannot start a new session with id {sid}. Reach max connection count of {self.cfg.max_connections}.')
self.sessions[sid] = session
def disconnect(self, sid):
self.sessions[sid].disconnect()
del self.sessions[sid]
def disconnect_all(self):
sids = []
for sid in self.sessions.keys():
sids.append(sid)
for sid in sids:
self.disconnect(sid)
def get_session(self, trans=True):
session = Session(self, trans)
self.sessions[session.id] = session
return session
def execute(self, *args):
with self.session as s:
s.execute(*args)
def load_tables(self, path):
self.tables = Tables.new_from_json_file(path)
def pre_setup(self):
if self.cfg.type != 'postgresql':
izzylog.verbose(f'Database not supported for pre_setup: {self.cfg.type}')
return
original_database = self.cfg.name
self.cfg.name = 'postgres'
with self.session as s:
s.conn.autocommit = True
s.rollback()
if original_database not in s.get_databases():
#s.execute('SET AUTOCOMMIT = OFF')
s.cursor.execute(f'CREATE DATABASE {original_database}')
s.conn.autocommit = False
self.cfg.name = original_database
def set_row_class(self, table, row_class):
pass
class Session:
def __init__(self, db, trans):
self.id = random_gen()
self.db = db
self.cfg = db.cfg
self.trans = trans
self.trans_state = False
self.conn = None
self.cursor = None
def __del__(self):
try:
izzylog.verbose('Deleting session:', self.id)
except ModuleNotFoundError:
if izzylog.get_config('level') >= 20:
print('[izzylib] VERBOSE: Deleting session:', self.id)
self.db.sessions.pop(self.id, None)
if self.conn:
self.disconnect()
def __enter__(self):
self.connect()
if self.trans:
self.begin()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
if exc_traceback:
self.rollback()
else:
self.commit()
self.disconnect()
self.db.disconnect(self.id)
def connect(self):
if self.conn:
return
self.db.connect(self.id, self)
if self.cfg.type == 'sqlite':
self.conn = self.cfg.module.connect(self.cfg.name, self.cfg.timeout, check_same_thread=True)
elif self.cfg.type == 'postgresql':
options = dict(
host = self.cfg.host or '/var/run/postgresql',
port = self.cfg.port or 5432,
database = self.cfg.name or 'postgresql',
user = self.cfg.username or getuser(),
password = self.cfg.password,
)
if self.cfg.mod_name == 'pg8000':
if options['host'] in [None, '/var/run/postgresql']:
port = options.pop('port')
options['unix_sock'] = options.pop('host') + f'/.s.PGSQL.{port}'
## SSL is a pain in the ass tbh. Gonna deal with this later
#if self.cfg.mod_name == 'pg8000':
#options['sslmode'] = self.cfg.ssl
#options['ssl_context'] = self.cfg.ssl_context
#elif self.cfg.mod_name == 'psycopg2':
#options['sslcert'] = self.cfg.ssl_cert
#options['sslkey'] = self.cfg.ssl_key
self.conn = self.cfg.module.connect(**options)
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
try:
self.conn.autocommit = False
except:
izzylog.verbose('Failed to turn off autocommit')
self.cursor = self.conn.cursor()
def disconnect(self):
if not self.conn:
return
self.cursor.close()
self.conn.close()
self.cursor = None
self.conn = None
def begin(self):
if self.trans_state:
return
#self.conn.begin()
self.execute('BEGIN TRANSACTION')
self.trans_state = True
def rollback(self):
if not self.trans_state:
return
self.conn.rollback()
#self.execute('ROLLBACK TRANSACTION')
self.trans_state = False
def commit(self):
if not self.trans_state:
return
self.conn.commit()
#self.execute('COMMIT TRANSACTION')
self.trans_state = False
## data management functions
def execute(self, string, values=[]):
if any(map(string.lower().startswith, ['insert', 'update', 'remove', 'create', 'drop'])) and not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
self.cursor.execute(string, values)
return self.cursor
def fetch(self, table, single=True, **kwargs):
rows = []
data = Select(table, type=self.cfg.type, **kwargs).exec(self)
for line in data:
row = Row(table, self.cursor.description, line)
if single:
return row
rows.append(row)
return rows if not single else None
def search(self, table, **kwargs):
return self.fetch(table, single=False, **kwargs)
def insert(self, table, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Insert(table, type=self.cfg.type, **kwargs).exec(self)
return self.fetch(table, **kwargs)
def update(self, table, rowid, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Update(table, rowid, type=self.cfg.type, **kwargs).exec(self)
return self.fetch(table, id=rowid)
def delete(self, table, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Delete(table, type=self.cfg.type, **kwargs).exec(self)
## helper functions
def get_columns(self, table):
if table not in self.get_tables():
raise KeyError(f'Not an existing table: {table}')
if self.cfg.type == 'sqlite':
rows = self.execute(f'PRAGMA table_info({table})')
return [row[1] for row in rows]
elif self.cfg.type == 'postgresql':
rows = self.execute(f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table}'")
return [row[0] for row in rows]
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
def get_tables(self):
if self.cfg.type == 'sqlite':
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
elif self.cfg.type == 'postgresql':
rows = self.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name")
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
return [row[0] for row in rows]
def get_databases(self):
if self.cfg.type == 'sqlite':
izzylog.verbose('This function is useless with sqlite')
return
elif self.cfg.type == 'postgresql':
databases = [row[0] for row in self.execute('SELECT datname FROM pg_database')]
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
return databases
def cursor_description(self):
return [row[0] for row in self.cursor.description]
def setup_database(self):
if not self.db.tables:
raise ValueError('Tables have not been specified.')
current_tables = self.get_tables()
for name, table in self.db.tables.items():
if name in current_tables:
izzylog.verbose(f'Skipping table creation since it already exists: {name}')
continue
izzylog.verbose(f'Creating table: {name}')
self.execute(table.build(self.cfg.type))
class Row(DotDict):
def __init__(self, table, keys, values):
self._db = None
self._table = table
super().__init__()
for idx, key in enumerate([key[0] for key in keys]):
self[key] = values[idx]
def update(self, data):
for k, v in data.items():
if k not in self:
raise KeyError(f'Not a column for {self._table}')
self[k] = v
def delete(self):
with self._db.session as s:
s.delete(self._table, id=self.id)
def update(self, **kwargs):
self.update(kwargs)
with self._db.session as s:
s.update(self._table, id=self.id, **kwargs)

10
sql/izzylib/sql/error.py Normal file
View file

@ -0,0 +1,10 @@
class MaxConnectionsError(Exception):
'raise when the max amount of connections has been reached'
class NoTransactionError(Exception):
'raise when a write command is executed outside a transaction'
class DatabaseNotSupportedError(Exception):
'raise when the action being performed is not supported by the database in use'

View file

@ -3,7 +3,7 @@ import json, sys, threading, time
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from sqlalchemy import create_engine, ForeignKey, MetaData, Table from sqlalchemy import create_engine, ForeignKey, MetaData, Table
from sqlalchemy import Column, types as Types from sqlalchemy import Column as sqlalchemy_column, types as Types
from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
@ -57,6 +57,12 @@ class SqlDatabase:
engine_string += f'/{database}' engine_string += f'/{database}'
engine_kwargs['connect_args'] = {'check_same_thread': False} engine_kwargs['connect_args'] = {'check_same_thread': False}
elif dbtype == 'postgresql':
ssl_context = kwargs.get('ssl')
if ssl_context:
engine_kwargs['ssl_context'] = ssl_context
else: else:
user = kwargs.get('user') user = kwargs.get('user')
password = kwargs.get('pass') password = kwargs.get('pass')
@ -124,9 +130,9 @@ class SqlDatabase:
self.table_names = tables.keys() self.table_names = tables.keys()
def execute(self, *args, **kwargs): def execute(self, string, values=[]):
with self.session as s: with self.session as s:
return s.execute(*args, **kwargs) s.execute(string, values)
class SqlSession(object): class SqlSession(object):
@ -267,7 +273,7 @@ class SqlSession(object):
if not rowid or not table: if not rowid or not table:
raise ValueError('Missing row ID or table') raise ValueError('Missing row ID or table')
row = self.execute(f'DELETE FROM {table} WHERE id={rowid}') self.execute(f'DELETE FROM {table} WHERE id={rowid}')
def drop_table(self, name): def drop_table(self, name):
@ -284,12 +290,28 @@ class SqlSession(object):
self.drop_table(table) self.drop_table(table)
def get_columns(self, table):
if table not in self.get_tables():
raise KeyError(f'Not an existing table: {table}')
rows = self.execute('PRAGMA table_info(user)')
return [row[1] for row in rows]
def get_tables(self): def get_tables(self):
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'") rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
return [row[0] for row in rows] return [row[0] for row in rows]
def append_column(self, tbl, col): def append_column(self, table, column):
if column.name in self.get_columns(table):
logging.warning(f'Table "{table}" already has column "{column.name}"')
return
self.execute(f'ALTER TABLE {table} ADD COLUMN {column.compile()}')
def append_column2(self, tbl, col):
table = self.table[tbl] table = self.table[tbl]
try: try:
@ -301,7 +323,7 @@ class SqlSession(object):
columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')] columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]
if col in columns: if col in self.get_columns(tbl):
izzylog.info(f'Column "{col}" already exists') izzylog.info(f'Column "{col}" already exists')
return return
@ -436,29 +458,51 @@ class Tables(DotDict):
def __setup_table(self, name, table): def __setup_table(self, name, table):
columns = [col if type(col) == Column else Column(*col.get('args'), **col.get('kwargs')) for col in table] columns = [col if type(col) == SqlColumn else SqlColumn(*col.get('args'), **col.get('kwargs')) for col in table]
self[name] = Table(name, self.meta, *columns) self[name] = Table(name, self.meta, *columns)
def SqlColumn(name, stype=None, fkey=None, **kwargs): class SqlColumn(sqlalchemy_column):
def __init__(self, name, stype=None, fkey=None, **kwargs):
if not stype and not kwargs: if not stype and not kwargs:
if name == 'id': if name == 'id':
return Column('id', SqlTypes['integer'], primary_key=True, autoincrement=True) stype = 'integer'
kwargs['primary_key'] = True
kwargs['autoincrement'] = True
elif name == 'timestamp': elif name == 'timestamp':
return Column('timestamp', SqlTypes['datetime']) stype = 'datetime'
raise ValueError('Missing column type and options')
else: else:
raise ValueError('Missing column type and options')
stype = (stype.lower() if type(stype) == str else stype) or 'string'
if type(stype) == str:
try: try:
stype = stype or 'string' stype = SqlTypes[stype.lower()]
options = [name, SqlTypes[stype.lower()]]
except KeyError: except KeyError:
raise KeyError(f'Invalid SQL data type: {stype}') raise KeyError(f'Invalid SQL data type: {stype}')
options = [name, stype]
if fkey: if fkey:
options.append(ForeignKey(fkey)) options.append(ForeignKey(fkey))
return Column(*options, **kwargs) super().__init__(*options, **kwargs)
def compile(self):
sql = f'{self.name} {self.type}'
if not self.nullable:
sql += ' NOT NULL'
if self.primary_key:
sql += ' PRIMARY KEY'
if self.unique:
sql += ' UNIQUE'
return sql

415
sql/izzylib/sql/queries.py Normal file
View file

@ -0,0 +1,415 @@
from datetime import datetime
from functools import partial
from izzylib import DotDict, Path
from .types import BaseType, Type
placeholders = dict(
sqlite = '?',
postgresql = '%s'
)
## Data queries
class Delete:
def __init__(self, table, type='sqlite', **kwargs):
self.table = table
self.placeholder = placeholders[type]
self.keys = []
self.values = []
for k,v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
self.build(embed_values=True)
def build(self, comp_type='AND', embed_values=False):
sql = 'DELETE FROM {table} WHERE {kstring}'
if not embed_values:
kstring = f' {comp_type.upper()} '.join([f'{k} = {self.placeholder}' for k in self.keys])
return sql.format(table=self.table, kstring=kstring), self.values
values = []
for idx, value in enumerate(self.values):
if type(value) == str:
values.append(f"{self.keys[idx]} = '{value}'")
else:
values.append(f"{self.keys[idx]} = {value}")
kstring = ','.join(values)
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
def exec(self, session, comp_type='AND'):
return session.execute(*self.build(comp_type))
class Insert:
def __init__(self, table, type='sqlite', **kwargs):
self.table = table
self.placeholder = placeholders[type]
self.keys = []
self.values = []
for k, v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
return self.build(embed_values=True)
def build(self, embed_values=False):
kstring = ','.join(self.keys)
if not embed_values:
vstring = ','.join([self.placeholder for k in self.keys])
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})', self.values
else:
vstring = ','.join(self.values)
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})'
def exec(self, session):
return session.execute(*self.build())
class Select:
def __init__(self, table, columns=[], type='sqlite', **kwargs):
self.placeholder = placeholders[type]
self.columns = columns
self.table = table
self.where = []
self.where_build = []
self._order = []
self.keys = []
self.values = []
self.equals = partial(self.__comparison, '=')
self.less = partial(self.__comparison, '<')
self.greater = partial(self.__comparison, '>')
self.like = partial(self.__comparison, 'LIKE')
for k,v in kwargs.items():
self.equals(k, v)
def __str__(self):
return self.build(embed_values=True)
def __comparison(self, comp, key, value):
self.values.append(value)
self.keys.append(key)
self.where.append(f'{key} {comp.upper()} {self.placeholder}')
self.where_build.append(f"{key} {comp.upper()} '{value}'" if type(key) == str else f"{key} {comp.upper()} {value}")
return self
def order(self, column, asc=True):
self._order = [column, 'ASC' if asc else 'DESC']
return self
def build(self, comp_type='AND', embed_values=False):
if not self.columns:
cols = '*'
else:
cols = ','.join('columns')
sql_query = f'SELECT {cols} FROM {self.table}'
if self.where:
where = f' {comp_type.upper()} '.join(self.where if not embed_values else self.where_build)
sql_query += f' WHERE {where}'
if self._order:
col, order = self._order
sql_query += f' ORDER BY {col} {order}'
if embed_values:
return sql_query
return sql_query, self.values
def exec(self, session, comp_type='AND'):
return session.execute(*self.build(comp_type))
class Update:
def __init__(self, table, rowid, type='sqlite', **kwargs):
self.placeholder = placeholders[type]
self.table = table
self.rowid = rowid
self.keys = []
self.values = []
for k,v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
return self.build(embed_values=True)
def build(self, embed_values=False):
sql = 'UPDATE {table} SET {kstring} WHERE id={rowid}'
if not embed_values:
kstring = ','.join([f'{k} = {self.placeholder}' for k in self.keys])
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid), self.values
values = []
for idx, value in enumerate(self.values):
if type(value) == str:
values.append(f"{self.keys[idx]} = '{value}'")
else:
values.append(f"{self.keys[idx]} = {value}")
kstring = ','.join(values)
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
def exec(self, session):
return session.execute(*self.build())
## Database objects
class Column:
def __init__(self, name, type='STRING', unique=False, nullable=True, default=None, primary_key=False, autoincrement=False, foreign_key=None):
self.name = name
self.type = type
self.nullable = nullable
self.default = default
self.primary_key = primary_key
self.autoincrement = autoincrement
self.unique = unique
if any(map(isinstance, [foreign_key], [list, tuple, set])):
self.foreign_key = foreign_key
else:
self.foreign_key = foreign_key.split('.') if foreign_key else None
if autoincrement:
self.primary_key = True
self.type = Type['INTEGER']
if isinstance(self.type, BaseType):
self.type = self.type.name
else:
if self.type.upper() in Type.keys():
self.type = self.type.upper()
else:
raise TypeError(f'Invalid SQL type: {self.type}')
if foreign_key and len(self.foreign_key) != 2:
raise ValueError('Invalid foreign key. Must be in the format "table.column".')
def __str__(self):
return self.build()
def build(self, dbtype='sqlite'):
if dbtype == 'postgresql':
if self.type.lower() == 'string':
self.type = 'TEXT'
elif self.type.lower() == 'datetime':
self.type = 'TIMESTAMPTZ'
if self.autoincrement:
self.type = 'SERIAL'
self.autoincrement = False
sql = f'{self.name} {self.type}'
if self.primary_key:
sql += ' PRIMARY KEY'
if self.autoincrement:
sql += ' AUTOINCREMENT'
if self.unique:
sql += ' UNIQUE'
if not self.nullable:
sql += ' NOT NULL'
if self.default:
def_type = type(self.default)
if self.default == 'CURRENT_TIMESTAMP':
if dbtype == 'sqlite':
sql += " DEFAULT (datetime('now', 'localtime'))"
elif dbtype == 'postgresql':
sql += ' DEFAULT now()'
else:
sql += f' DEFAULT {datetime.now().timestamp()}'
elif def_type == str:
sql += f" DEFAULT '{self.default}'"
elif def_type in [int, float]:
sql += f' DEFAULT {self.default}'
elif def_type == bool and dbtype == 'sqlite':
sql += f' DEFAULT {int(self.default)}'
else:
sql += f' DEFAULT {self.default}'
print(sql)
return sql
def json(self):
return DotDict({
'type': self.type,
'nullable': self.nullable,
'default': self.default,
'primary_key': self.primary_key,
'autoincrement': self.autoincrement,
'unique': self.unique,
'foreign_key': self.foreign_key
})
class Table(DotDict):
def __init__(self, name, *columns):
super().__init__()
self._name = name
self._foreign_keys = {}
self.add_column(Column('id', autoincrement=True))
for column in columns:
self.add_column(column)
def __str__(self):
return self.build()
# this'll be useful later
def __call__(self, *args, **kwargs):
pass
@property
def name(self):
return self._name
def add_column(self, column):
self[column.name] = column
if column.foreign_key:
self._foreign_keys[column.name] = column.foreign_key
def build(self, dbtype='sqlite'):
column_string = ',\n'.join([f'\t{col.build(dbtype)}' for col in self.values()])
if self._foreign_keys:
column_string += ',\n'
column_string += ',\n'.join([f'\tFOREIGN KEY ({column}) REFERENCES {key[0]} ({key[1]})' for column, key in self._foreign_keys.items()])
return f'''CREATE TABLE {self.name} (
{column_string}
);'''
def json(self):
data = {}
for name, column in self.items():
data[name] = column.json()
return data
class Tables(DotDict):
def __init__(self, *tables, data={}):
super().__init__()
for table in tables:
self.add_table(table)
if data:
self.from_dict(data)
def __str__(self):
return self.build()
@classmethod
def new_from_json_file(cls, path):
return cls(data=DotDict.new_from_json_file(path))
def add_table(self, table):
self[table.name] = table
def build(self):
return '\n\n'.join([str(table) for table in self.values()])
def load_json(self, path):
data = DotDict()
data.load_json(path)
self.from_dict(data)
def save_json(self, path, indent='\t'):
self.to_dict().save_json(path, indent=indent)
def from_dict(self, data):
for name, columns in data.items():
table = Table(name)
for col, kwargs in columns.items():
table.add_column(Column(col,
type = kwargs.get('type', 'STRING'),
nullable = kwargs.get('nullable', True),
default = kwargs.get('default'),
primary_key = kwargs.get('primary_key', False),
autoincrement = kwargs.get('autoincrement', False),
unique = kwargs.get('unique', False),
foreign_key = kwargs.get('foreign_key')
))
self.add_table(table)
def to_dict(self):
data = DotDict()
for name, table in self.items():
data[name] = table.json()
return data

19
sql/izzylib/sql/row.py Normal file
View file

@ -0,0 +1,19 @@
from izzylib import DotDict
class DbRow(DotDict):
def __init__(self, table, keys, values):
self.table = table
super().__init__()
for idx, key in enumerate(keys):
self[key] = values[idx]
def delete(self):
pass
def update(self, **kwargs):
pass

19
sql/izzylib/sql/types.py Normal file
View file

@ -0,0 +1,19 @@
from enum import Enum
from izzylib import DotDict
class BaseType(Enum):
INTEGER = int
TEXT = str
BLOB = bytes
REAL = float
NUMERIC = float
Type = DotDict(
**{v: BaseType.INTEGER for v in ['INT', 'INTEGER', 'TINYINT', 'SMALLINT', 'MEDIUMINT', 'BIGINT', 'UNSIGNED BIG INT', 'INT2', 'INT8']},
**{v: BaseType.TEXT for v in ['CHARACTER', 'VARCHAR', 'VARYING CHARACTER', 'NCHAR', 'NATIVE CHARACTER', 'NVARCHAR', 'TEXT', 'CLOB', 'STRING', 'JSON']},
**{v: BaseType.BLOB for v in ['BYTES', 'BLOB']},
**{v: BaseType.REAL for v in ['REAL', 'DOUBLE', 'DOUBLE PRECISION', 'FLOAT']},
**{v: BaseType.NUMERIC for v in ['NUMERIC', 'DECIMAL', 'BOOLEAN', 'DATE', 'DATETIME']}
)