This commit is contained in:
Izalia Mae 2021-08-29 22:14:55 -04:00
parent 11aaaf3499
commit cd20eaea97
25 changed files with 1155 additions and 64 deletions

3
.gitignore vendored
View file

@ -117,10 +117,13 @@ dmypy.json
test*.py
reload.cfg
# symlinks
/izzylib
/base/izzylib/dbus
/base/izzylib/hasher
/base/izzylib/http_requests_client
/base/izzylib/http_server
/base/izzylib/mbus
/base/izzylib/sql
/base/izzylib/template
/base/izzylib/tinydb

View file

@ -4,7 +4,7 @@ Licensed under the CNPL: https://git.pixie.town/thufie/CNPL
https://git.barkshark.xyz/izaliamae/izzylib
'''
import sys, traceback
import os, sys, traceback
assert sys.version_info >= (3, 7)
__version_tpl__ = (0, 6, 0)
@ -13,6 +13,7 @@ __version__ = '.'.join([str(v) for v in __version_tpl__])
from . import logging
izzylog = logging.logger['IzzyLib']
izzylog.set_config('level', os.environ.get('IZZYLOG_LEVEL', 'INFO'))
from .path import Path
from .dotdict import DotDict, LowerDotDict, DefaultDotDict, MultiDotDict, JsonEncoder

View file

@ -42,6 +42,10 @@ def parse_ttl(ttl):
return multiplier * int(amount)
class DefaultValue(object):
pass
class BaseCache(OrderedDict):
_get = OrderedDict.get
_items = OrderedDict.items
@ -120,6 +124,20 @@ class BaseCache(OrderedDict):
return item
def pop(self, key, default=DefaultValue):
try:
item = self.get(key)
del self[key]
return item
except Exception as e:
if default == DefaultValue:
raise e from None
return default
## This doesn't work for some reason
def CacheDecorator(cache):
def decorator(func):

View file

@ -59,6 +59,13 @@ class DotDict(dict):
raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None
@classmethod
def new_from_json_file(cls, path):
data = cls()
data.load_json(path)
return data
def copy(self):
return DotDict(self)

View file

@ -22,6 +22,7 @@ __all__ = [
'print_methods',
'prompt',
'random_gen',
'remove',
'signal_handler',
'time_function',
'time_function_pprint',
@ -348,6 +349,13 @@ def random_gen(length=20, letters=True, numbers=True, extra=None):
return ''.join(random.choices(characters, k=length))
def remove(string: str, junk: list):
for line in junk:
string = string.replace(line, '')
return string
def signal_handler(func, *args, original_args=True, **kwargs):
if original_args:
handler = lambda signum, frame: func(signum, frame, *args, **kwargs)

View file

@ -1,4 +1,4 @@
import os, shutil
import json, os, shutil
from datetime import datetime
from functools import cached_property
@ -105,9 +105,9 @@ class Path(str):
return json.load(s)
def json_dump(self, data):
def json_dump(self, data, indent=None):
with self.open('w') as s:
s.write(json.dumps(data))
s.write(json.dumps(data, indent=indent))
def link(self, path):

View file

@ -4,6 +4,7 @@ start_time = datetime.now()
from .application import Application
from .config import Config, UserLevel
from .middleware import MiddlewareBase, Headers, AccessLog
from .request import Request
from .response import Response
from .view import View

View file

@ -12,7 +12,7 @@ from izzylib.template import Template
from .config import Config, UserLevel
from .error_handlers import GenericError, MissingTemplateError
from .middleware import AccessLog, Headers
from .view import Manifest, Style
from .view import Manifest, Robots, Style
log_path_ignore = [
@ -46,10 +46,12 @@ class Application(sanic.Sanic):
)
self.template.add_env('cfg', self.cfg)
self.template.add_env('len', len)
if self.cfg.tpl_default:
self.template.add_search_path(frontend)
self.add_class_route(Manifest)
self.add_class_route(Robots)
self.add_class_route(Style)
self.static('/favicon.ico', frontend.join('static/icon64.png'))
self.static('/framework/static', frontend.join('static'))
@ -93,6 +95,11 @@ class Application(sanic.Sanic):
del self.cfg.menu[name]
def get_route_by_path(self, path, method='get', host=None,):
route, handler, _ = self.router.get(path, method, host)
return handler
def start(self):
# register built-in middleware now so they're last in the chain
self.add_middleware(Headers)

View file

@ -11,6 +11,7 @@ class UserLevel(IntEnum):
USER = 10
MODERATOR = 20
ADMIN = 30
AUTH = 1000
class Config(DotDict):
@ -25,6 +26,7 @@ class Config(DotDict):
port = 8080,
proto = 'http',
workers = cpu_count(),
access_log = True,
request_class = Request,
response_class = Response,
sig_handler = None,

View file

@ -39,7 +39,6 @@ a:hover {
}
input:not([type='checkbox']), select, textarea {
margin: 1px 0;
color: var(--text);
background-color: var(--background);
border: 1px solid var(--background);
@ -263,7 +262,7 @@ details:focus, summary:focus {
transition: background-color var(--trans-speed);
width: 55px;
margin-left: var(--gap);
background-image: url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/menu.svg');
background-image: url('/framework/static/menu.svg');
background-size: 50px;
background-position: center center;
background-repeat: no-repeat;
@ -412,8 +411,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken
@font-face {
font-family: 'sans undertale';
src: local('Nunito Sans Bold'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-SemiBold.woff2') format('woff2'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-SemiBold.ttf') format('ttf');
url('/framework/static/nunito/NunitoSans-SemiBold.woff2') format('woff2'),
url('/framework/static/nunito/NunitoSans-SemiBold.ttf') format('ttf');
font-weight: bold;
font-style: normal;
}
@ -421,8 +420,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken
@font-face {
font-family: 'sans undertale';
src: local('Nunito Sans Light Italic'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-ExtraLightItalic.woff2') format('woff2'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-ExtraLightItalic.ttf') format('ttf');
url('/framework/static/nunito/NunitoSans-ExtraLightItalic.woff2') format('woff2'),
url('/framework/static/nunito/NunitoSans-ExtraLightItalic.ttf') format('ttf');
font-weight: normal;
font-style: italic;
}
@ -430,8 +429,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken
@font-face {
font-family: 'sans undertale';
src: local('Nunito Sans Bold Italic'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Italic.woff2') format('woff2'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Italic.ttf') format('ttf');
url('/framework/static/nunito/NunitoSans-Italic.woff2') format('woff2'),
url('/framework/static/nunito/NunitoSans-Italic.ttf') format('ttf');
font-weight: bold;
font-style: italic;
}
@ -439,8 +438,8 @@ body {scrollbar-width: 15px; scrollbar-color: var(--primary) {{background.darken
@font-face {
font-family: 'sans undertale';
src: local('Nunito Sans Light'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Light.woff2') format('woff2'),
url('{{cfg.proto}}://{{cfg.web_host}}/framework/static/nunito/NunitoSans-Light.ttf') format('ttf');
url('/framework/static/nunito/NunitoSans-Light.woff2') format('woff2'),
url('/framework/static/nunito/NunitoSans-Light.ttf') format('ttf');
font-weight: normal;
font-style: normal;
}

View file

@ -2,20 +2,21 @@
%html
%head
%title << {{cfg.name}}: {{page}}
%link rel='stylesheet' type='text/css' href='{{cfg.proto}}://{{cfg.web_host}}/framework/style.css'
%link rel='manifest' href='{{cfg.proto}}://{{cfg.web_host}}/framework/manifest.json'
%link rel='stylesheet' type='text/css' href='/framework/style.css'
%link rel='manifest' href='/framework/manifest.json'
%meta charset='UTF-8'
%meta name='viewport' content='width=device-width, initial-scale=1'
-block head
%body
#body
#header.flex-container
-if menu_left
#btn.section
.page-title.section -> %a.title href='{{cfg.proto}}://{{cfg.web_host}}/' << {{cfg.name}}
.page-title.section -> %a.title href='/' << {{cfg.name}}
-else
.page-title.section -> %a.title href='{{cfg.proto}}://{{cfg.web_host}}/' << {{cfg.name}}
.page-title.section -> %a.title href='/' << {{cfg.name}}
#btn.section
-if message
@ -27,9 +28,15 @@
#menu.section
.title-item.item << Menu
#items
-for label, data in cfg.menu.items()
-if request.user_level >= data[1]
.item -> %a href='{{cfg.proto}}://{{cfg.web_host}}{{data[0]}}' << {{label}}
-if not len(cfg.menu):
-include 'menu.haml'
-else:
-for label, path_data in cfg.menu.items()
-if path_data[1] == 1000 and request.user_level == 0:
.item -> %a href='{{path_data[0]}}' << {{label}}
-elif request.user_level >= path_data[1]
.item -> %a href='{{path_data[0]}}' << {{label}}
#content-body.section
-block content
@ -40,4 +47,4 @@
.source
%a href='{{cfg.git_repo}}' target='_new' << {{cfg.name}}/{{cfg.version}}
%script type='application/javascript' src='{{cfg.proto}}://{{cfg.web_host}}/framework/static/menu.js'
%script type='application/javascript' src='/framework/static/menu.js'

View file

@ -0,0 +1 @@
.item => %a(href='/') << Home

View file

@ -6,6 +6,34 @@ from izzylib import izzylog as logging, logging as applog
from . import start_time
cache_types = [
'text/css',
'application/javascript',
]
cache_base_types = [
'image',
'audio',
'video'
]
def cache_check(request, response):
content_type = response.headers.get('content-type')
if request.path.startswith('/framework'):
return True
if not content_type:
return False
if content_type in cache_types:
return True
if any(map(content_type.startswith, cache_base_types)):
return True
class MiddlewareBase:
attach = 'request'
@ -22,7 +50,14 @@ class Headers(MiddlewareBase):
attach = 'response'
async def handler(self, request, response):
if request.path.startswith('/framework') or request.path == '/favicon.ico':
if not response.headers.get('content-type'):
if request.path.endswith('.css'):
response.headers['content-type'] = 'text/css'
elif request.path.endswith('.js'):
response.headers['content-type'] = 'application/javascript'
if cache_check(request, response):
max_age = int(timedelta(weeks=2).total_seconds())
response.headers['Cache-Control'] = f'immutable,private,max-age={max_age}'

View file

@ -1,14 +1,16 @@
import sanic
from izzylib import DotDict
from .misc import Headers
class Request(sanic.request.Request):
def __init__(self, url_bytes, headers, version, method, transport, app):
def __init__(self, url_bytes, headers, version, method, transport, app, **kwargs):
super().__init__(url_bytes, headers, version, method, transport, app)
self.Headers = Headers(headers)
self.Data = Data(self)
self.data = Data(self)
self.template = self.app.template
self.user_level = 0
self.setup()

View file

@ -1,6 +1,7 @@
import json, sanic
from izzylib import DotDict
from datetime import datetime
from izzylib import DotDict, izzylog
from izzylib.template import Color
from sanic.compat import Header
from sanic.cookies import CookieJar
@ -27,6 +28,18 @@ class Response:
'speed': 250
})
cookie_keys = {
'expires': 'Expires',
'path': 'Path',
'comment': 'Comment',
'domain': 'Domain',
'max_age': 'Max-Age',
'secure': 'Secure',
'httponly': 'HttpOnly',
'version': 'Version',
'samesite': 'SameSite'
}
def __init__(self, app, request, body=None, headers={}, cookies={}, status=200, content_type='text/html'):
# server objects
@ -110,6 +123,8 @@ class Response:
data.update(kwargs)
for k,v in data.items():
k = self.cookie_keys.get(k, k)
if k.lower() == 'max-age':
if isinstance(v, timedelta):
v = int(v.total_seconds())
@ -118,13 +133,16 @@ class Response:
raise TypeError('Max-Age must be an integer or timedelta')
elif k.lower() == 'expires':
if isinstance(v, datetime):
v = v.strftime('%a, %d-%b-%Y %T GMT')
if isinstance(v, str):
v = datetime.strptime(v, '%a, %d-%b-%Y %T GMT')
elif not isinstance(v, str):
elif not isinstance(v, datetime):
raise TypeError('Expires must be a string or datetime')
self.cookies[key][k] = v
try:
self.cookies[key][k] = v
except KeyError as e:
izzylog.error('Invalid cookie key:', k)
def get_cookie(self, key):
@ -140,7 +158,7 @@ class Response:
del self.cookies[key]
def template(self, tplfile, context={}, headers={}, status=200, content_type='text/html', cookies={}, pprint=False):
def template(self, tplfile, context={}, headers={}, status=200, content_type='text/html', pprint=False):
self.status = status
context.update({
'response': self,
@ -148,7 +166,7 @@ class Response:
})
html = self.app.template.render(tplfile, context, request=self.request, pprint=pprint)
return self.html(html, headers=headers, status=status, content_type=content_type, cookies=cookies)
return self.html(html, headers=headers, status=status, content_type=content_type)
def error(self, message, status=500, **kwargs):
@ -158,13 +176,13 @@ class Response:
return self.template('error.haml', {'error_message': message}, status=status, **kwargs)
def json(self, body={}, headers={}, status=200, content_type='application/json', cookies={}):
def json(self, body={}, headers={}, status=200, content_type='application/json'):
body = json.dumps(body)
return self.get_response(body, headers, status, content_type, cookies)
return self.get_response(body, headers, status, content_type)
def text(self, body, headers={}, status=200, content_type='text/plain', cookies={}):
return self.get_response(body, headers, status, content_type, cookies)
def text(self, body, headers={}, status=200, content_type='text/plain'):
return self.get_response(body, headers, status, content_type)
def html(self, *args, **kwargs):
@ -174,24 +192,26 @@ class Response:
def css(self, *args, **kwargs):
self.content_type = 'text/css'
return self.text.text(*args, **kwargs)
return self.text(*args, **kwargs)
def javascript(self, *args, **kwargs):
self.content_type = 'application/javascript'
return self.text.text(*args, **kwargs)
return self.text(*args, **kwargs)
def activitypub(self, *args, **kwargs):
self.content_type = 'application/activity+json'
return self.text.text(*args, **kwargs)
return self.text(*args, **kwargs)
def redir(self, path, status=302, headers={}):
return sanic.response.redirect(path, status=status, headers={})
headers.update(dict(location=path))
return self.text(body=None, status=status, headers=headers)
#return sanic.response.redirect(path, status=status, headers={})
def set_data(self, body=None, headers={}, status=200, content_type='text/html', cookies={}):
def set_data(self, body=None, headers={}, status=200, content_type='text/html'):
ctype = self.content_types.get(content_type, content_type)
self.body = body

View file

@ -48,6 +48,14 @@ class Manifest(View):
return response.json(data)
class Robots(View):
paths = ['/robots.txt']
async def get(self, request, response):
data = '# Disallow all\nUser-agent: *\nDisallow: /'
return response.text(data)
class Style(View):
paths = ['/framework/style.css']

View file

@ -5,6 +5,7 @@ from setuptools import setup, find_namespace_packages
requires = [
'sanic>=20.12.3',
'sanic-cors>=1.0.0',
'envbash>=1.0.0'
]

View file

@ -1,2 +1,6 @@
from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables
# old sql classes
from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables, OperationalError, ProgrammingError
from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession
from .database import Database, Session
from .queries import Column, Insert, Select, Table, Tables, Update

100
sql/izzylib/sql/config.py Normal file
View file

@ -0,0 +1,100 @@
import importlib, sqlite3, ssl
from getpass import getuser
from izzylib import DotDict, Path, izzylog
defaults = {
'name': (None, str),
'host': (None, str),
'port': (None, int),
'username': (getuser(), str),
'password': (None, str),
'ssl': ('allow', str),
'ssl_context': (ssl.create_default_context(), ssl.SSLContext),
'ssl_key': (None, Path),
'ssl_cert': (None, Path),
'max_connections': (25, int),
'type': ('sqlite', str),
'module': (sqlite3, None),
'mod_name': ('sqlite3', str),
'timeout': (5, int),
'args': ([], list),
'kwargs': ({}, dict)
}
modtypes = {
'sqlite': ['sqlite3'],
'postgresql': ['pg8000', 'psycopg2', 'psycopg3', 'pgdb'],
'mysql': ['mysqldb', 'trio_mysql'],
'mssql': ['pymssql', 'adodbapi']
}
sslmodes = ['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']
class Config(DotDict):
def __init__(self, **kwargs):
super().__init__({k: v[0] for k,v in defaults.items()})
module = kwargs.pop('module', None)
if module:
self.parse_module(module)
self.update(kwargs)
if self.ssl != 'disable' and (self.ssl_key or self.ssl_cert):
self.ssl_context.load_cert_chain(self.ssl_cert, self.ssl_key)
def __setitem__(self, key, value):
if key not in defaults:
raise KeyError(f'Invalid config option: {key}')
valtype = defaults[key][1]
if valtype and value and not isinstance(value, valtype):
raise TypeError(f'{key} should be a {valtype}, not a {value.__class__.__name__}')
if key == 'ssl' and value == True:
value = ssl.create_default_context()
super().__setitem__(key, value)
def parse_module(self, name):
module = None
module_type = None
module_name = None
if name == 'sqlite3':
name = 'sqlite'
for mtype, modules in modtypes.items():
if name == mtype:
module_type = name
for mod in modules:
try:
module = importlib.import_module(mod)
module_name = mod
break
except ImportError:
izzylog.verbose(f'Database module not installed:', mod)
elif name in modules:
try:
module = importlib.import_module(name)
module_type = mtype
module_name = name
break
except ImportError:
izzylog.error(f'Database module not installed:', name)
if None in (module, module_name, module_type):
raise ValueError(f'Failed to find module for {name}')
self.module = module
self.mod_name = module_name
self.type = module_type

360
sql/izzylib/sql/database.py Normal file
View file

@ -0,0 +1,360 @@
import sqlite3, traceback
from functools import partial
from getpass import getuser
from izzylib import DotDict, izzylog, boolean, random_gen
from . import error
from .config import Config
from .queries import Column, Delete, Insert, Select, Table, Tables, Update
class Database:
def __init__(self, tables=None, **kwargs):
self.tables = tables
self.cfg = Config(**kwargs)
self.sessions = DotDict()
@property
def session(self):
return self.get_session(False)
@property
def session_trans(self):
return self.get_session(True)
def connect(self, sid, session):
if len(self.sessions) >= self.cfg.max_connections:
raise error.MaxConnectionsError(f'Cannot start a new session with id {sid}. Reach max connection count of {self.cfg.max_connections}.')
self.sessions[sid] = session
def disconnect(self, sid):
self.sessions[sid].disconnect()
del self.sessions[sid]
def disconnect_all(self):
sids = []
for sid in self.sessions.keys():
sids.append(sid)
for sid in sids:
self.disconnect(sid)
def get_session(self, trans=True):
session = Session(self, trans)
self.sessions[session.id] = session
return session
def execute(self, *args):
with self.session as s:
s.execute(*args)
def load_tables(self, path):
self.tables = Tables.new_from_json_file(path)
def pre_setup(self):
if self.cfg.type != 'postgresql':
izzylog.verbose(f'Database not supported for pre_setup: {self.cfg.type}')
return
original_database = self.cfg.name
self.cfg.name = 'postgres'
with self.session as s:
s.conn.autocommit = True
s.rollback()
if original_database not in s.get_databases():
#s.execute('SET AUTOCOMMIT = OFF')
s.cursor.execute(f'CREATE DATABASE {original_database}')
s.conn.autocommit = False
self.cfg.name = original_database
def set_row_class(self, table, row_class):
pass
class Session:
def __init__(self, db, trans):
self.id = random_gen()
self.db = db
self.cfg = db.cfg
self.trans = trans
self.trans_state = False
self.conn = None
self.cursor = None
def __del__(self):
try:
izzylog.verbose('Deleting session:', self.id)
except ModuleNotFoundError:
if izzylog.get_config('level') >= 20:
print('[izzylib] VERBOSE: Deleting session:', self.id)
self.db.sessions.pop(self.id, None)
if self.conn:
self.disconnect()
def __enter__(self):
self.connect()
if self.trans:
self.begin()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
if exc_traceback:
self.rollback()
else:
self.commit()
self.disconnect()
self.db.disconnect(self.id)
def connect(self):
if self.conn:
return
self.db.connect(self.id, self)
if self.cfg.type == 'sqlite':
self.conn = self.cfg.module.connect(self.cfg.name, self.cfg.timeout, check_same_thread=True)
elif self.cfg.type == 'postgresql':
options = dict(
host = self.cfg.host or '/var/run/postgresql',
port = self.cfg.port or 5432,
database = self.cfg.name or 'postgresql',
user = self.cfg.username or getuser(),
password = self.cfg.password,
)
if self.cfg.mod_name == 'pg8000':
if options['host'] in [None, '/var/run/postgresql']:
port = options.pop('port')
options['unix_sock'] = options.pop('host') + f'/.s.PGSQL.{port}'
## SSL is a pain in the ass tbh. Gonna deal with this later
#if self.cfg.mod_name == 'pg8000':
#options['sslmode'] = self.cfg.ssl
#options['ssl_context'] = self.cfg.ssl_context
#elif self.cfg.mod_name == 'psycopg2':
#options['sslcert'] = self.cfg.ssl_cert
#options['sslkey'] = self.cfg.ssl_key
self.conn = self.cfg.module.connect(**options)
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
try:
self.conn.autocommit = False
except:
izzylog.verbose('Failed to turn off autocommit')
self.cursor = self.conn.cursor()
def disconnect(self):
if not self.conn:
return
self.cursor.close()
self.conn.close()
self.cursor = None
self.conn = None
def begin(self):
if self.trans_state:
return
#self.conn.begin()
self.execute('BEGIN TRANSACTION')
self.trans_state = True
def rollback(self):
if not self.trans_state:
return
self.conn.rollback()
#self.execute('ROLLBACK TRANSACTION')
self.trans_state = False
def commit(self):
if not self.trans_state:
return
self.conn.commit()
#self.execute('COMMIT TRANSACTION')
self.trans_state = False
## data management functions
def execute(self, string, values=[]):
if any(map(string.lower().startswith, ['insert', 'update', 'remove', 'create', 'drop'])) and not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
self.cursor.execute(string, values)
return self.cursor
def fetch(self, table, single=True, **kwargs):
rows = []
data = Select(table, type=self.cfg.type, **kwargs).exec(self)
for line in data:
row = Row(table, self.cursor.description, line)
if single:
return row
rows.append(row)
return rows if not single else None
def search(self, table, **kwargs):
return self.fetch(table, single=False, **kwargs)
def insert(self, table, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Insert(table, type=self.cfg.type, **kwargs).exec(self)
return self.fetch(table, **kwargs)
def update(self, table, rowid, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Update(table, rowid, type=self.cfg.type, **kwargs).exec(self)
return self.fetch(table, id=rowid)
def delete(self, table, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Delete(table, type=self.cfg.type, **kwargs).exec(self)
## helper functions
def get_columns(self, table):
if table not in self.get_tables():
raise KeyError(f'Not an existing table: {table}')
if self.cfg.type == 'sqlite':
rows = self.execute(f'PRAGMA table_info({table})')
return [row[1] for row in rows]
elif self.cfg.type == 'postgresql':
rows = self.execute(f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table}'")
return [row[0] for row in rows]
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
def get_tables(self):
if self.cfg.type == 'sqlite':
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
elif self.cfg.type == 'postgresql':
rows = self.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name")
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
return [row[0] for row in rows]
def get_databases(self):
if self.cfg.type == 'sqlite':
izzylog.verbose('This function is useless with sqlite')
return
elif self.cfg.type == 'postgresql':
databases = [row[0] for row in self.execute('SELECT datname FROM pg_database')]
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
return databases
def cursor_description(self):
return [row[0] for row in self.cursor.description]
def setup_database(self):
if not self.db.tables:
raise ValueError('Tables have not been specified.')
current_tables = self.get_tables()
for name, table in self.db.tables.items():
if name in current_tables:
izzylog.verbose(f'Skipping table creation since it already exists: {name}')
continue
izzylog.verbose(f'Creating table: {name}')
self.execute(table.build(self.cfg.type))
class Row(DotDict):
def __init__(self, table, keys, values):
self._db = None
self._table = table
super().__init__()
for idx, key in enumerate([key[0] for key in keys]):
self[key] = values[idx]
def update(self, data):
for k, v in data.items():
if k not in self:
raise KeyError(f'Not a column for {self._table}')
self[k] = v
def delete(self):
with self._db.session as s:
s.delete(self._table, id=self.id)
def update(self, **kwargs):
self.update(kwargs)
with self._db.session as s:
s.update(self._table, id=self.id, **kwargs)

10
sql/izzylib/sql/error.py Normal file
View file

@ -0,0 +1,10 @@
class MaxConnectionsError(Exception):
'raise when the max amount of connections has been reached'
class NoTransactionError(Exception):
'raise when a write command is executed outside a transaction'
class DatabaseNotSupportedError(Exception):
'raise when the action being performed is not supported by the database in use'

View file

@ -3,7 +3,7 @@ import json, sys, threading, time
from contextlib import contextmanager
from datetime import datetime
from sqlalchemy import create_engine, ForeignKey, MetaData, Table
from sqlalchemy import Column, types as Types
from sqlalchemy import Column as sqlalchemy_column, types as Types
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import scoped_session, sessionmaker
@ -57,6 +57,12 @@ class SqlDatabase:
engine_string += f'/{database}'
engine_kwargs['connect_args'] = {'check_same_thread': False}
elif dbtype == 'postgresql':
ssl_context = kwargs.get('ssl')
if ssl_context:
engine_kwargs['ssl_context'] = ssl_context
else:
user = kwargs.get('user')
password = kwargs.get('pass')
@ -124,9 +130,9 @@ class SqlDatabase:
self.table_names = tables.keys()
def execute(self, *args, **kwargs):
def execute(self, string, values=[]):
with self.session as s:
return s.execute(*args, **kwargs)
s.execute(string, values)
class SqlSession(object):
@ -267,7 +273,7 @@ class SqlSession(object):
if not rowid or not table:
raise ValueError('Missing row ID or table')
row = self.execute(f'DELETE FROM {table} WHERE id={rowid}')
self.execute(f'DELETE FROM {table} WHERE id={rowid}')
def drop_table(self, name):
@ -284,12 +290,28 @@ class SqlSession(object):
self.drop_table(table)
def get_columns(self, table):
if table not in self.get_tables():
raise KeyError(f'Not an existing table: {table}')
rows = self.execute('PRAGMA table_info(user)')
return [row[1] for row in rows]
def get_tables(self):
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
return [row[0] for row in rows]
def append_column(self, tbl, col):
def append_column(self, table, column):
if column.name in self.get_columns(table):
logging.warning(f'Table "{table}" already has column "{column.name}"')
return
self.execute(f'ALTER TABLE {table} ADD COLUMN {column.compile()}')
def append_column2(self, tbl, col):
table = self.table[tbl]
try:
@ -301,7 +323,7 @@ class SqlSession(object):
columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]
if col in columns:
if col in self.get_columns(tbl):
izzylog.info(f'Column "{col}" already exists')
return
@ -436,29 +458,51 @@ class Tables(DotDict):
def __setup_table(self, name, table):
columns = [col if type(col) == Column else Column(*col.get('args'), **col.get('kwargs')) for col in table]
columns = [col if type(col) == SqlColumn else SqlColumn(*col.get('args'), **col.get('kwargs')) for col in table]
self[name] = Table(name, self.meta, *columns)
def SqlColumn(name, stype=None, fkey=None, **kwargs):
if not stype and not kwargs:
if name == 'id':
return Column('id', SqlTypes['integer'], primary_key=True, autoincrement=True)
class SqlColumn(sqlalchemy_column):
def __init__(self, name, stype=None, fkey=None, **kwargs):
if not stype and not kwargs:
if name == 'id':
stype = 'integer'
kwargs['primary_key'] = True
kwargs['autoincrement'] = True
elif name == 'timestamp':
return Column('timestamp', SqlTypes['datetime'])
elif name == 'timestamp':
stype = 'datetime'
raise ValueError('Missing column type and options')
else:
raise ValueError('Missing column type and options')
else:
try:
stype = stype or 'string'
options = [name, SqlTypes[stype.lower()]]
stype = (stype.lower() if type(stype) == str else stype) or 'string'
except KeyError:
raise KeyError(f'Invalid SQL data type: {stype}')
if type(stype) == str:
try:
stype = SqlTypes[stype.lower()]
except KeyError:
raise KeyError(f'Invalid SQL data type: {stype}')
options = [name, stype]
if fkey:
options.append(ForeignKey(fkey))
return Column(*options, **kwargs)
super().__init__(*options, **kwargs)
def compile(self):
sql = f'{self.name} {self.type}'
if not self.nullable:
sql += ' NOT NULL'
if self.primary_key:
sql += ' PRIMARY KEY'
if self.unique:
sql += ' UNIQUE'
return sql

415
sql/izzylib/sql/queries.py Normal file
View file

@ -0,0 +1,415 @@
from datetime import datetime
from functools import partial
from izzylib import DotDict, Path
from .types import BaseType, Type
placeholders = dict(
sqlite = '?',
postgresql = '%s'
)
## Data queries
class Delete:
def __init__(self, table, type='sqlite', **kwargs):
self.table = table
self.placeholder = placeholders[type]
self.keys = []
self.values = []
for k,v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
self.build(embed_values=True)
def build(self, comp_type='AND', embed_values=False):
sql = 'DELETE FROM {table} WHERE {kstring}'
if not embed_values:
kstring = f' {comp_type.upper()} '.join([f'{k} = {self.placeholder}' for k in self.keys])
return sql.format(table=self.table, kstring=kstring), self.values
values = []
for idx, value in enumerate(self.values):
if type(value) == str:
values.append(f"{self.keys[idx]} = '{value}'")
else:
values.append(f"{self.keys[idx]} = {value}")
kstring = ','.join(values)
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
def exec(self, session, comp_type='AND'):
return session.execute(*self.build(comp_type))
class Insert:
def __init__(self, table, type='sqlite', **kwargs):
self.table = table
self.placeholder = placeholders[type]
self.keys = []
self.values = []
for k, v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
return self.build(embed_values=True)
def build(self, embed_values=False):
kstring = ','.join(self.keys)
if not embed_values:
vstring = ','.join([self.placeholder for k in self.keys])
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})', self.values
else:
vstring = ','.join(self.values)
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})'
def exec(self, session):
return session.execute(*self.build())
class Select:
def __init__(self, table, columns=[], type='sqlite', **kwargs):
self.placeholder = placeholders[type]
self.columns = columns
self.table = table
self.where = []
self.where_build = []
self._order = []
self.keys = []
self.values = []
self.equals = partial(self.__comparison, '=')
self.less = partial(self.__comparison, '<')
self.greater = partial(self.__comparison, '>')
self.like = partial(self.__comparison, 'LIKE')
for k,v in kwargs.items():
self.equals(k, v)
def __str__(self):
return self.build(embed_values=True)
def __comparison(self, comp, key, value):
self.values.append(value)
self.keys.append(key)
self.where.append(f'{key} {comp.upper()} {self.placeholder}')
self.where_build.append(f"{key} {comp.upper()} '{value}'" if type(key) == str else f"{key} {comp.upper()} {value}")
return self
def order(self, column, asc=True):
self._order = [column, 'ASC' if asc else 'DESC']
return self
def build(self, comp_type='AND', embed_values=False):
if not self.columns:
cols = '*'
else:
cols = ','.join('columns')
sql_query = f'SELECT {cols} FROM {self.table}'
if self.where:
where = f' {comp_type.upper()} '.join(self.where if not embed_values else self.where_build)
sql_query += f' WHERE {where}'
if self._order:
col, order = self._order
sql_query += f' ORDER BY {col} {order}'
if embed_values:
return sql_query
return sql_query, self.values
def exec(self, session, comp_type='AND'):
return session.execute(*self.build(comp_type))
class Update:
def __init__(self, table, rowid, type='sqlite', **kwargs):
self.placeholder = placeholders[type]
self.table = table
self.rowid = rowid
self.keys = []
self.values = []
for k,v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
return self.build(embed_values=True)
def build(self, embed_values=False):
sql = 'UPDATE {table} SET {kstring} WHERE id={rowid}'
if not embed_values:
kstring = ','.join([f'{k} = {self.placeholder}' for k in self.keys])
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid), self.values
values = []
for idx, value in enumerate(self.values):
if type(value) == str:
values.append(f"{self.keys[idx]} = '{value}'")
else:
values.append(f"{self.keys[idx]} = {value}")
kstring = ','.join(values)
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
def exec(self, session):
return session.execute(*self.build())
## Database objects
class Column:
def __init__(self, name, type='STRING', unique=False, nullable=True, default=None, primary_key=False, autoincrement=False, foreign_key=None):
self.name = name
self.type = type
self.nullable = nullable
self.default = default
self.primary_key = primary_key
self.autoincrement = autoincrement
self.unique = unique
if any(map(isinstance, [foreign_key], [list, tuple, set])):
self.foreign_key = foreign_key
else:
self.foreign_key = foreign_key.split('.') if foreign_key else None
if autoincrement:
self.primary_key = True
self.type = Type['INTEGER']
if isinstance(self.type, BaseType):
self.type = self.type.name
else:
if self.type.upper() in Type.keys():
self.type = self.type.upper()
else:
raise TypeError(f'Invalid SQL type: {self.type}')
if foreign_key and len(self.foreign_key) != 2:
raise ValueError('Invalid foreign key. Must be in the format "table.column".')
def __str__(self):
return self.build()
def build(self, dbtype='sqlite'):
if dbtype == 'postgresql':
if self.type.lower() == 'string':
self.type = 'TEXT'
elif self.type.lower() == 'datetime':
self.type = 'TIMESTAMPTZ'
if self.autoincrement:
self.type = 'SERIAL'
self.autoincrement = False
sql = f'{self.name} {self.type}'
if self.primary_key:
sql += ' PRIMARY KEY'
if self.autoincrement:
sql += ' AUTOINCREMENT'
if self.unique:
sql += ' UNIQUE'
if not self.nullable:
sql += ' NOT NULL'
if self.default:
def_type = type(self.default)
if self.default == 'CURRENT_TIMESTAMP':
if dbtype == 'sqlite':
sql += " DEFAULT (datetime('now', 'localtime'))"
elif dbtype == 'postgresql':
sql += ' DEFAULT now()'
else:
sql += f' DEFAULT {datetime.now().timestamp()}'
elif def_type == str:
sql += f" DEFAULT '{self.default}'"
elif def_type in [int, float]:
sql += f' DEFAULT {self.default}'
elif def_type == bool and dbtype == 'sqlite':
sql += f' DEFAULT {int(self.default)}'
else:
sql += f' DEFAULT {self.default}'
print(sql)
return sql
def json(self):
return DotDict({
'type': self.type,
'nullable': self.nullable,
'default': self.default,
'primary_key': self.primary_key,
'autoincrement': self.autoincrement,
'unique': self.unique,
'foreign_key': self.foreign_key
})
class Table(DotDict):
def __init__(self, name, *columns):
super().__init__()
self._name = name
self._foreign_keys = {}
self.add_column(Column('id', autoincrement=True))
for column in columns:
self.add_column(column)
def __str__(self):
return self.build()
# this'll be useful later
def __call__(self, *args, **kwargs):
pass
@property
def name(self):
return self._name
def add_column(self, column):
self[column.name] = column
if column.foreign_key:
self._foreign_keys[column.name] = column.foreign_key
def build(self, dbtype='sqlite'):
column_string = ',\n'.join([f'\t{col.build(dbtype)}' for col in self.values()])
if self._foreign_keys:
column_string += ',\n'
column_string += ',\n'.join([f'\tFOREIGN KEY ({column}) REFERENCES {key[0]} ({key[1]})' for column, key in self._foreign_keys.items()])
return f'''CREATE TABLE {self.name} (
{column_string}
);'''
def json(self):
data = {}
for name, column in self.items():
data[name] = column.json()
return data
class Tables(DotDict):
def __init__(self, *tables, data={}):
super().__init__()
for table in tables:
self.add_table(table)
if data:
self.from_dict(data)
def __str__(self):
return self.build()
@classmethod
def new_from_json_file(cls, path):
return cls(data=DotDict.new_from_json_file(path))
def add_table(self, table):
self[table.name] = table
def build(self):
return '\n\n'.join([str(table) for table in self.values()])
def load_json(self, path):
data = DotDict()
data.load_json(path)
self.from_dict(data)
def save_json(self, path, indent='\t'):
self.to_dict().save_json(path, indent=indent)
def from_dict(self, data):
for name, columns in data.items():
table = Table(name)
for col, kwargs in columns.items():
table.add_column(Column(col,
type = kwargs.get('type', 'STRING'),
nullable = kwargs.get('nullable', True),
default = kwargs.get('default'),
primary_key = kwargs.get('primary_key', False),
autoincrement = kwargs.get('autoincrement', False),
unique = kwargs.get('unique', False),
foreign_key = kwargs.get('foreign_key')
))
self.add_table(table)
def to_dict(self):
data = DotDict()
for name, table in self.items():
data[name] = table.json()
return data

19
sql/izzylib/sql/row.py Normal file
View file

@ -0,0 +1,19 @@
from izzylib import DotDict
class DbRow(DotDict):
def __init__(self, table, keys, values):
self.table = table
super().__init__()
for idx, key in enumerate(keys):
self[key] = values[idx]
def delete(self):
pass
def update(self, **kwargs):
pass

19
sql/izzylib/sql/types.py Normal file
View file

@ -0,0 +1,19 @@
from enum import Enum
from izzylib import DotDict
class BaseType(Enum):
INTEGER = int
TEXT = str
BLOB = bytes
REAL = float
NUMERIC = float
Type = DotDict(
**{v: BaseType.INTEGER for v in ['INT', 'INTEGER', 'TINYINT', 'SMALLINT', 'MEDIUMINT', 'BIGINT', 'UNSIGNED BIG INT', 'INT2', 'INT8']},
**{v: BaseType.TEXT for v in ['CHARACTER', 'VARCHAR', 'VARYING CHARACTER', 'NCHAR', 'NATIVE CHARACTER', 'NVARCHAR', 'TEXT', 'CLOB', 'STRING', 'JSON']},
**{v: BaseType.BLOB for v in ['BYTES', 'BLOB']},
**{v: BaseType.REAL for v in ['REAL', 'DOUBLE', 'DOUBLE PRECISION', 'FLOAT']},
**{v: BaseType.NUMERIC for v in ['NUMERIC', 'DECIMAL', 'BOOLEAN', 'DATE', 'DATETIME']}
)