rework external http client and database

This commit is contained in:
Izalia Mae 2021-09-17 08:59:29 -04:00
parent 56e3e30db1
commit f401cc7e1a
28 changed files with 892 additions and 1682 deletions

2
.gitignore vendored
View file

@ -121,7 +121,7 @@ reload.cfg
/izzylib
/base/izzylib/dbus
/base/izzylib/hasher
/base/izzylib/http_requests_client
/base/izzylib/http_urllib_client
/base/izzylib/http_server
/base/izzylib/mbus
/base/izzylib/sql

View file

@ -16,7 +16,7 @@ You only need to install the base and whatever sub-modules you want to use
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-server&subdirectory=http_server"
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-requests-client&subdirectory=requests_client"
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-urllib-client&subdirectory=http_urllib_client"
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-sql&subdirectory=sql"
@ -26,7 +26,7 @@ You only need to install the base and whatever sub-modules you want to use
### From Source
$(venv)/bin/python setup.py install ['all' or a combination of these: dbus hasher http_server requests_client sql template tinydb]
$(venv)/bin/python setup.py install ['all' or a combination of these: dbus hasher http_server http_urllib_client sql template tinydb]
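Once installed, the renamed client is importable under the same izzylib namespace package; a minimal sketch, assuming the base and http_urllib_client sub-modules are in the same environment (the appagent string is a placeholder):

from izzylib.http_urllib_client import HttpUrllibClient

client = HttpUrllibClient(appagent='MyApp/0.1')
print(client.agent)  # "IzzyLib/<version> (MyApp/0.1)"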
## Documentation

View file

@ -21,7 +21,7 @@ from .misc import *
from .cache import CacheDecorator, LruCache, TtlCache
from .connection import Connection
from .http_urllib_client import HttpUrllibClient, HttpUrllibResponse
from .http_client import HttpClient, HttpResponse
def log_import_error(package, *message):
@ -48,10 +48,10 @@ except ImportError:
log_import_error('template', 'Failed to import http template classes. Jinja and HAML templates disabled')
try:
from izzylib.http_requests_client import *
from izzylib.http_urllib_client import *
except ImportError:
log_import_error('http_requests_client', 'Failed to import Requests http client classes. Requests http client is disabled')
log_import_error('http_urllib_client', 'Failed to import urllib3 http client classes. Urllib3 http client is disabled')
try:
from izzylib.http_server import PasswordHasher, HttpServer, HttpServerRequest, HttpServerResponse

View file

@ -110,18 +110,12 @@ class DefaultDotDict(DotDict):
class LowerDotDict(DotDict):
def __getattr__(self, key):
return super().__getattr__(self, key.lower())
def __getitem__(self, key):
return super().__getitem__(key.lower())
def __setattr__(self, key, value):
return super().__setattr__(key.lower(), value)
def update(self, data):
data = {k.lower(): v for k,v in self.items()}
return super().update(data)
def __setitem__(self, key, value):
return super().__setitem__(key.lower(), value)
class MultiDotDict(DotDict):

View file

@ -22,7 +22,7 @@ except ImportError:
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
class HttpUrllibClient:
class HttpClient:
def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None):
proxy_ports = {
'http': 80,
@ -74,7 +74,7 @@ class HttpUrllibClient:
except HTTPError as e:
response = e.fp
return HttpUrllibResponse(response)
return HttpResponse(response)
def file(self, url, filepath, *args, filename=None, size=2048, create_dirs=True, **kwargs):
@ -141,7 +141,7 @@ class HttpUrllibClient:
return self.request(*args, headers=headers, **kwargs)
class HttpUrllibResponse(object):
class HttpResponse(object):
def __init__(self, response):
self.body = response.read()
self.headers = DefaultDotDict({k.lower(): v.lower() for k,v in response.headers.items()})

View file

@ -4,6 +4,7 @@ from datetime import datetime
from getpass import getpass, getuser
from importlib import util
from pathlib import Path
from urllib.parse import urlparse
from . import izzylog
from .dotdict import DotDict
@ -27,7 +28,8 @@ __all__ = [
'time_function',
'time_function_pprint',
'timestamp',
'var_name'
'var_name',
'Url'
]
@ -460,3 +462,26 @@ def var_name(single=True, **kwargs):
keys = list(kwargs.keys())
return keys[0] if single else keys
class Url(str):
protocols = {
'http': 80,
'https': 443,
'ftp': 21,
'ftps': 990
}
def __init__(self, url):
str.__new__(Url, url)
parsed = urlparse(url)
self.__parsed = parsed
self.proto = parsed.scheme
self.host = parsed.netloc
self.path = parsed.path
self.query = parsed.query
self.username = parsed.username
self.password = parsed.password
self.port = parsed.port if parsed.port else self.protocols.get(self.proto)
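A quick sketch of the new Url helper (the address is a placeholder). Since it subclasses str, it can be passed anywhere a plain URL string is expected:

from izzylib import Url

url = Url('https://example.com/inbox?page=2')
print(url.proto, url.host, url.path, url.query)  # https example.com /inbox page=2
print(url.port)  # 443 -- falls back to the default port for the scheme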

View file

@ -78,6 +78,5 @@ class AccessLog(MiddlewareBase):
async def handler(self, request, response):
uagent = request.headers.get('user-agent', 'None')
address = request.headers.get('x-real-ip', request.forwarded.get('for', request.remote_addr))
applog.info(f'({multiprocessing.current_process().name}) {address} {request.method} {request.path} {response.status} "{uagent}"')
applog.info(f'({multiprocessing.current_process().name}) {request.address} {request.method} {request.path} {response.status} "{uagent}"')

View file

@ -11,6 +11,7 @@ class Request(sanic.request.Request):
super().__init__(url_bytes, headers, version, method, transport, app)
self.Headers = Headers(headers)
self.address = self.headers.get('x-real-ip', self.forwarded.get('for', self.remote_addr))
self.data = Data(self)
self.template = self.app.template
self.user_level = 0

View file

@ -0,0 +1,30 @@
from .signatures import (
verify_request,
verify_headers,
parse_signature,
fetch_actor,
fetch_instance,
fetch_nodeinfo,
fetch_webfinger_account,
generate_rsa_key
)
from .client import HttpUrllibClient, set_default_client
from .request import HttpUrllibRequest
from .response import HttpUrllibResponse
#__all__ = [
#'HttpRequestsClient',
#'HttpRequestsRequest',
#'HttpRequestsResponse',
#'fetch_actor',
#'fetch_instance',
#'fetch_nodeinfo',
#'fetch_webfinger_account',
#'generate_rsa_key',
#'parse_signature',
#'set_requests_client',
#'verify_headers',
#'verify_request',
#]

View file

@ -0,0 +1,128 @@
import json, sys, urllib3
from PIL import Image
from base64 import b64encode
from datetime import datetime
from functools import cached_property
from io import BytesIO
from izzylib import DefaultDotDict, DotDict, LowerDotDict, Path, izzylog, __version__
from izzylib.exceptions import HttpFileDownloadedError
from ssl import SSLCertVerificationError
from urllib.parse import urlparse
from .request import HttpUrllibRequest
from .response import HttpUrllibResponse
from .signatures import set_client
Client = None
proxy_ports = {
'http': 80,
'https': 443
}
class HttpUrllibClient:
def __init__(self, headers={}, useragent=None, appagent=None, proxy_type='https', proxy_host=None, proxy_port=None, num_pools=20):
if not useragent:
useragent = f'IzzyLib/{__version__}'
self.headers = {k.lower(): v for k,v in headers.items()}
self.agent = f'{useragent} ({appagent})' if appagent else useragent
if proxy_type not in ['http', 'https']:
raise ValueError(f'Not a valid proxy type: {proxy_type}')
if proxy_host:
proxy = f'{proxy_type}://{proxy_host}:{proxy_ports[proxy_type] if not proxy_port else proxy_port}'
self.pool = urllib3.ProxyManager(proxy, num_pools=num_pools)
else:
self.pool = urllib3.PoolManager(num_pools=num_pools)
@property
def agent(self):
return self.headers['user-agent']
@agent.setter
def agent(self, value):
self.headers['user-agent'] = value
def set_global(self):
set_default_client(self)
def build_request(self, *args, **kwargs):
return HttpUrllibRequest(*args, **kwargs)
def handle_request(self, request):
request.headers.update(self.headers)
response = self.pool.urlopen(*request._args, **request._kwargs)
return HttpUrllibResponse(response)
def request(self, *args, **kwargs):
return self.handle_request(self.build_request(*args, **kwargs))
def signed_request(self, privkey, keyid, *args, **kwargs):
return self.request(*args, privkey=privkey, keyid=keyid, **kwargs)
def download(self, url, filepath, *args, filename=None, **kwargs):
resp = self.request(url, *args, **kwargs)
if resp.status != 200:
raise HttpFileDownloadedError(f'Failed to download {url}: Status: {resp.status}, Body: {resp.body}')
return resp.save(filepath)
def image(self, url, filepath, *args, filename=None, ext='png', dimensions=(50, 50), **kwargs):
if not Image:
izzylog.error('Pillow module is not installed')
return
resp = self.request(url, *args, **kwargs)
if resp.status != 200:
izzylog.error(f'Failed to download {url}:', resp.status, resp.body)
return False
if not filename:
filename = Path(url).stem()
path = Path(filepath)
if not path.exists:
izzylog.error('Path does not exist:', path)
return False
byte = BytesIO()
image = Image.open(BytesIO(resp.body))
image.thumbnail(dimensions)
image.save(byte, format=ext.upper())
with path.join(filename).open('wb') as fd:
fd.write(byte.getvalue())
def json(self, *args, headers={}, activity=True, **kwargs):
json_type = 'activity+json' if activity else 'json'
headers.update({
'accept': f'application/{json_type}'
})
return self.request(*args, headers=headers, **kwargs)
def set_default_client(client=None):
global Client
Client = client or HttpUrllibClient()
set_client(Client)
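For reference, a hedged usage sketch of the urllib3-backed client (the URL is a placeholder; json() just sets the Accept header before delegating to request()):

from izzylib.http_urllib_client import HttpUrllibClient

client = HttpUrllibClient(appagent='MyApp/0.1')
client.set_global()  # register this instance as the default for the signature helpers

resp = client.json('https://example.com/users/izalia')  # Accept: application/activity+json
if resp.status == 200:
    print(resp.dict)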

View file

@ -0,0 +1,111 @@
import json
from Crypto.Hash import SHA256
from izzylib import DotDict, LowerDotDict, Url, boolean
from base64 import b64decode, b64encode
from datetime import datetime
from izzylib import izzylog as logging
from .signatures import sign_pkcs_headers
methods = ['delete', 'get', 'head', 'options', 'patch', 'post', 'put']
class HttpUrllibRequest:
def __init__(self, url, **kwargs):
self._body = b''
method = kwargs.get('method', 'get').lower()
if method not in methods:
raise ValueError(f'Invalid method: {method}')
self.url = Url(url)
self.body = kwargs.get('body')
self.method = method
self.headers = LowerDotDict(kwargs.get('headers', {}))
self.redirect = boolean(kwargs.get('redirect', True))
self.retries = int(kwargs.get('retries', 10))
self.timeout = int(kwargs.get('timeout', 5))
privkey = kwargs.get('privkey')
keyid = kwargs.get('keyid')
if privkey and keyid:
self.sign(privkey, keyid)
@property
def _args(self):
return [self.method.upper(), self.url]
@property
def _kwargs(self):
return {
'body': self.body,
'headers': self.headers,
'redirect': self.redirect,
'retries': self.retries,
'timeout': self.timeout
}
@property
def body(self):
return self._body
@body.setter
def body(self, data):
if isinstance(data, dict):
data = DotDict(data).to_json()
elif isinstance(data, (list, tuple)):
data = json.dumps(data)
if data is None:
data = b''
elif not isinstance(data, bytes):
data = bytes(data, 'utf-8')
self._body = data
def set_header(self, key, value):
self.headers[key] = value
def unset_header(self, key):
self.headers.pop(key, None)
def sign(self, privkey, keyid):
self.unset_header('signature')
self.set_header('(request-target)', f'{self.method.lower()} {self.url.path}')
self.set_header('host', self.url.host)
self.set_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
if self.body:
body_hash = b64encode(SHA256.new(self.body).digest()).decode("UTF-8")
self.set_header('digest', f'SHA-256={body_hash}')
self.set_header('content-length', str(len(self.body)))
sig = {
'keyId': keyid,
'algorithm': 'rsa-sha256',
'headers': ' '.join([k.lower() for k in self.headers.keys()]),
'signature': b64encode(sign_pkcs_headers(privkey, self.headers)).decode('UTF-8')
}
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
sig_string = ','.join(sig_items)
self.set_header('signature', sig_string)
self.unset_header('(request-target)')
self.unset_header('host')
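A sketch of sending a signed request. generate_rsa_key() is re-exported from the signatures module; its exact return type, the key id, and the URLs are assumptions here:

from izzylib.http_urllib_client import HttpUrllibClient, generate_rsa_key

client = HttpUrllibClient(appagent='MyApp/0.1')

privkey = generate_rsa_key()  # assumed to yield a key object accepted by sign_pkcs_headers()
keyid = 'https://example.com/actor#main-key'  # placeholder

resp = client.signed_request(privkey, keyid, 'https://remote.example/inbox',
    method='post', body={'type': 'Follow'})
print(resp.status)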

View file

@ -0,0 +1,97 @@
import json
from io import BytesIO
from izzylib import DefaultDotDict, DotDict, Path, Url
class HttpUrllibResponse:
def __init__(self, response):
self.response = response
self._dict = None
def __getitem__(self, key):
return self.dict[key]
def __setitem__(self, key, value):
self.dict[key] = value
@property
def encoding(self):
for line in self.headers.get('content-type', '').split(';'):
try:
k,v = line.split('=')
if k.strip().lower() == 'charset':
return v.lower()
except:
pass
return 'utf-8'
@property
def headers(self):
return self.response.headers
@property
def status(self):
return self.response.status
@property
def url(self):
return Url(self.response.geturl())
@property
def body(self):
data = self.response.read(cache_content=True)
if not data:
data = self.response.data
return data
@property
def text(self):
return self.body.decode(self.encoding)
@property
def dict(self):
if not self._dict:
self._dict = DotDict(self.text)
return self._dict
def json_pretty(self, indent=4):
return self.dict.to_json(indent)
def chunks(self, size=1024):
return self.response.stream(amt=size)
def save(self, path, overwrite=True, create_parents=True):
path = Path(path)
if not path.parent.exists:
if not create_parents:
raise ValueError(f'Path does not exist: {path.parent}')
path.parent.mkdir()
if overwrite and path.exists:
path.delete()
with path.open('wb') as fd:
for chunk in self.chunks():
fd.write(chunk)
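And a short sketch of the response wrapper itself (URL and file path are placeholders; save() streams the body in chunks and creates missing parent directories unless told not to):

from izzylib.http_urllib_client import HttpUrllibClient

resp = HttpUrllibClient().request('https://example.com/nodeinfo/2.0.json')

print(resp.status, resp.headers.get('content-type'))
print(resp.text[:80])

resp.save('/tmp/nodeinfo.json')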

View file

@ -5,21 +5,21 @@ from setuptools import setup, find_namespace_packages
requires = [
'pillow==8.2.0',
'pycryptodome==3.10.1',
'requests==2.25.1',
'urllib3==1.26.5',
'tldextract==3.1.0'
]
setup(
name="IzzyLib Requests Client",
name="IzzyLib Urllib3 Client",
version='0.6.0',
packages=find_namespace_packages(include=['izzylib.http_requests_client']),
packages=find_namespace_packages(include=['izzylib.http_urllib_client']),
python_requires='>=3.7.0',
install_requires=requires,
include_package_data=False,
author='Zoey Mae',
author_email='admin@barkshark.xyz',
description='A Requests client with support for http header signing and verifying',
description='A Urllib3 client with support for http header signing and verifying',
keywords='web http client',
url='https://git.barkshark.xyz/izaliamae/izzylib',
project_urls={

View file

@ -1,33 +0,0 @@
from .signature import (
verify_request,
verify_headers,
parse_signature,
fetch_actor,
fetch_instance,
fetch_nodeinfo,
fetch_webfinger_account,
generate_rsa_key
)
from .client import (
HttpRequestsClient,
HttpRequestsRequest,
HttpRequestsResponse,
set_requests_client
)
__all__ = [
'HttpRequestsClient',
'HttpRequestsRequest',
'HttpRequestsResponse',
'fetch_actor',
'fetch_instance',
'fetch_nodeinfo',
'fetch_webfinger_account',
'generate_rsa_key',
'parse_signature',
'set_requests_client',
'verify_headers',
'verify_request',
]

View file

@ -1,227 +0,0 @@
import json, requests, sys
from PIL import Image
from base64 import b64encode
from datetime import datetime
from functools import cached_property
from io import BytesIO
from izzylib import DefaultDotDict, DotDict, Path, izzylog as logging, __version__
from izzylib.exceptions import HttpFileDownloadedError
from ssl import SSLCertVerificationError
from urllib.parse import urlparse
from .signature import sign_request, set_client
Client = None
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
class HttpRequestsClient(object):
def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None):
proxy_ports = {
'http': 80,
'https': 443
}
if proxy_type not in ['http', 'https']:
raise ValueError(f'Not a valid proxy type: {proxy_type}')
self.headers=headers
self.agent = f'{useragent} ({appagent})' if appagent else useragent
self.proxy = DotDict({
'enabled': True if proxy_host else False,
'ptype': proxy_type,
'host': proxy_host,
'port': proxy_ports[proxy_type] if not proxy_port else proxy_port
})
def set_global(self):
set_requests_client(self)
def build_request(self, *args, method='get', privkey=None, keyid=None, **kwargs):
if method.lower() not in methods:
raise ValueError(f'Invalid method: {method}')
request = HttpRequestsRequest(self, *args, method=method.lower(), **kwargs)
if privkey and keyid:
request.sign(privkey, keyid)
return request
def request(self, *args, **kwargs):
request = self.build_request(*args, **kwargs)
return HttpRequestsResponse(request.send())
def signed_request(self, privkey, keyid, *args, **kwargs):
return self.request(*args, privkey=privkey, keyid=keyid, **kwargs)
def download(self, url, filepath, *args, filename=None, **kwargs):
resp = self.request(url, *args, **kwargs)
if resp.status != 200:
raise HttpFileDownloadedError(f'Failed to download {url}: Status: {resp.status}, Body: {resp.body}')
return resp.save(filepath)
def image(self, url, filepath, *args, filename=None, ext='png', dimensions=(50, 50), **kwargs):
if not Image:
izzylog.error('Pillow module is not installed')
return
resp = self.request(url, *args, **kwargs)
if resp.status != 200:
izzylog.error(f'Failed to download {url}:', resp.status, resp.body)
return False
if not filename:
filename = Path(url).stem()
path = Path(filepath)
if not path.exists:
izzylog.error('Path does not exist:', path)
return False
byte = BytesIO()
image = Image.open(BytesIO(resp.body))
image.thumbnail(dimensions)
image.save(byte, format=ext.upper())
with path.join(filename).open('wb') as fd:
fd.write(byte.getvalue())
def json(self, *args, headers={}, activity=True, **kwargs):
json_type = 'activity+json' if activity else 'json'
headers.update({
'accept': f'application/{json_type}'
})
return self.request(*args, headers=headers, **kwargs)
class HttpRequestsRequest(object):
def __init__(self, client, url, data=b'', headers={}, query={}, method='get'):
parsed = urlparse(url)
self.args = [url]
self.kwargs = DotDict({'params': query})
self.method = method.lower()
self.client = client
self.path = parsed.path
self.host = parsed.netloc
self.body = data
new_headers = client.headers.copy()
new_headers.update(headers)
parsed_headers = {k.lower(): v for k,v in new_headers.items()}
if not parsed_headers.get('user-agent'):
parsed_headers['user-agent'] = client.agent
self.kwargs['headers'] = DotDict(new_headers)
if client.proxy.enabled:
self.kwargs['proxies'] = DotDict({self.proxy.ptype: f'{self.proxy.ptype}://{self.proxy.host}:{self.proxy.port}'})
@property
def body(self):
return self.kwargs.data
@body.setter
def body(self, data):
self.kwargs.data = data.encode('utf-8') if isinstance(data, str) else data
@property
def headers(self):
return self.kwargs.headers
def add_header(self, key, value):
self.kwargs.headers[key] = value
def remove_header(self, key):
self.kwargs.headers.pop(key, None)
def send(self):
func = getattr(requests, self.method)
return func(*self.args, **self.kwargs)
def sign(self, privkey, keyid):
sign_request(self, privkey, keyid)
class HttpRequestsResponse(object):
def __init__(self, response):
self.response = response
self.data = b''
self.headers = DefaultDotDict({k.lower(): v.lower() for k,v in response.headers.items()})
self.status = response.status_code
self.url = response.url
def chunks(self, size=256):
return self.response.iter_content(chunk_size=256)
@property
def body(self):
for chunk in self.chunks():
self.data += chunk
return self.data
@cached_property
def text(self):
return self.data.decode(self.response.encoding)
@cached_property
def json(self):
try:
return DotDict(self.text)
except:
return DotDict(self.body)
@cached_property
def json_pretty(self, indent=4):
return json.dumps(self.json, indent=indent)
def save(self, path, overwrite=True):
path = Path(path)
if not path.parent.exists:
raise ValueError(f'Path does not exist: {path.parent}')
if overwrite and path.exists:
path.delete()
with path.open('wb') as fd:
for chunk in self.chunks():
fd.write(chunk)
def set_requests_client(client=None):
global Client
Client = client or RequestsClient()
set_client(Client)

View file

@ -1,6 +1,13 @@
# old sql classes
from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables, OperationalError, ProgrammingError
from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession
## Normal SQL client
from .database import Database, OperationalError, ProgrammingError
from .session import Session
from .column import Column
#from .database import Database, Session
#from .queries import Column, Insert, Select, Table, Tables, Update
## Sqlite server
#from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession
## Compat
SqlDatabase = Database
SqlSession = Session
SqlColumn = Column

54
sql/izzylib/sql/column.py Normal file
View file

@ -0,0 +1,54 @@
from sqlalchemy import ForeignKey
from sqlalchemy import (
Column as sqlalchemy_column,
types as Types
)
SqlTypes = {t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')}
class Column(sqlalchemy_column):
def __init__(self, name, stype=None, fkey=None, **kwargs):
if not stype and not kwargs:
if name == 'id':
stype = 'integer'
kwargs['primary_key'] = True
kwargs['autoincrement'] = True
elif name == 'timestamp':
stype = 'datetime'
else:
raise ValueError('Missing column type and options')
stype = (stype.lower() if type(stype) == str else stype) or 'string'
if type(stype) == str:
try:
stype = SqlTypes[stype.lower()]
except KeyError:
raise KeyError(f'Invalid SQL data type: {stype}')
options = [name, stype]
if fkey:
options.append(ForeignKey(fkey))
super().__init__(*options, **kwargs)
def compile(self):
sql = f'{self.name} {self.type}'
if not self.nullable:
sql += ' NOT NULL'
if self.primary_key:
sql += ' PRIMARY KEY'
if self.unique:
sql += ' UNIQUE'
return sql
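A sketch of the shortcuts this wrapper provides: 'id' and 'timestamp' columns get sensible defaults, string type names are mapped onto SQLAlchemy types, and fkey adds a ForeignKey (all names below are placeholders):

from izzylib.sql import Column

columns = [
    Column('id'),                                    # integer primary key, autoincrement
    Column('timestamp'),                             # datetime
    Column('username', 'string', unique=True, nullable=False),
    Column('actor_id', 'integer', fkey='actor.id'),  # foreign key to actor.id
]
print(columns[2].compile())  # roughly: username VARCHAR NOT NULL UNIQUE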

View file

@ -1,100 +0,0 @@
import importlib, sqlite3, ssl
from getpass import getuser
from izzylib import DotDict, Path, izzylog
defaults = {
'name': (None, str),
'host': (None, str),
'port': (None, int),
'username': (getuser(), str),
'password': (None, str),
'ssl': ('allow', str),
'ssl_context': (ssl.create_default_context(), ssl.SSLContext),
'ssl_key': (None, Path),
'ssl_cert': (None, Path),
'max_connections': (25, int),
'type': ('sqlite', str),
'module': (sqlite3, None),
'mod_name': ('sqlite3', str),
'timeout': (5, int),
'args': ([], list),
'kwargs': ({}, dict)
}
modtypes = {
'sqlite': ['sqlite3'],
'postgresql': ['pg8000', 'psycopg2', 'psycopg3', 'pgdb'],
'mysql': ['mysqldb', 'trio_mysql'],
'mssql': ['pymssql', 'adodbapi']
}
sslmodes = ['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']
class Config(DotDict):
def __init__(self, **kwargs):
super().__init__({k: v[0] for k,v in defaults.items()})
module = kwargs.pop('module', None)
if module:
self.parse_module(module)
self.update(kwargs)
if self.ssl != 'disable' and (self.ssl_key or self.ssl_cert):
self.ssl_context.load_cert_chain(self.ssl_cert, self.ssl_key)
def __setitem__(self, key, value):
if key not in defaults:
raise KeyError(f'Invalid config option: {key}')
valtype = defaults[key][1]
if valtype and value and not isinstance(value, valtype):
raise TypeError(f'{key} should be a {valtype}, not a {value.__class__.__name__}')
if key == 'ssl' and value == True:
value = ssl.create_default_context()
super().__setitem__(key, value)
def parse_module(self, name):
module = None
module_type = None
module_name = None
if name == 'sqlite3':
name = 'sqlite'
for mtype, modules in modtypes.items():
if name == mtype:
module_type = name
for mod in modules:
try:
module = importlib.import_module(mod)
module_name = mod
break
except ImportError:
izzylog.verbose(f'Database module not installed:', mod)
elif name in modules:
try:
module = importlib.import_module(name)
module_type = mtype
module_name = name
break
except ImportError:
izzylog.error(f'Database module not installed:', name)
if None in (module, module_name, module_type):
raise ValueError(f'Failed to find module for {name}')
self.module = module
self.mod_name = module_name
self.type = module_type

View file

@ -1,360 +1,186 @@
import sqlite3, traceback
import json, pkgutil, sys, threading, time
from functools import partial
from getpass import getuser
from izzylib import DotDict, izzylog, boolean, random_gen
from contextlib import contextmanager
from datetime import datetime
from izzylib import LruCache, DotDict, Path, nfs_check, izzylog
from sqlalchemy import Table, create_engine
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.engine import URL
from sqlalchemy.schema import MetaData
from . import error
from .config import Config
from .queries import Column, Delete, Insert, Select, Table, Tables, Update
from .rows import RowClasses
from .session import Session
modules = dict(
postgresql = ['pygresql', 'pg8000', 'psycopg2', 'psycopg3']
)
class Database:
def __init__(self, tables=None, **kwargs):
self.tables = tables
self.cfg = Config(**kwargs)
self.sessions = DotDict()
def __init__(self, dbtype='sqlite', **kwargs):
self._connect_args = [dbtype, kwargs]
self.db = None
self.cache = None
self.config = DotDict()
self.meta = MetaData()
self.classes = RowClasses(*kwargs.get('row_classes', []))
self.cache = None
self.session_class = kwargs.get('session_class', Session)
self.sessions = {}
self.open()
def _setup_cache(self):
self.cache = DotDict({table: LruCache() for table in self.get_tables()})
@property
def session(self):
return self.get_session(False)
return self.session_class(self)
@property
def session_trans(self):
return self.get_session(True)
def dbtype(self):
return self.db.url.get_backend_name()
def connect(self, sid, session):
if len(self.sessions) >= self.cfg.max_connections:
raise error.MaxConnectionsError(f'Cannot start a new session with id {sid}. Reach max connection count of {self.cfg.max_connections}.')
self.sessions[sid] = session
def disconnect(self, sid):
self.sessions[sid].disconnect()
del self.sessions[sid]
def disconnect_all(self):
sids = []
for sid in self.sessions.keys():
sids.append(sid)
for sid in sids:
self.disconnect(sid)
def get_session(self, trans=True):
session = Session(self, trans)
self.sessions[session.id] = session
return session
def execute(self, *args):
with self.session as s:
s.execute(*args)
def load_tables(self, path):
self.tables = Tables.new_from_json_file(path)
def pre_setup(self):
if self.cfg.type != 'postgresql':
izzylog.verbose(f'Database not supported for pre_setup: {self.cfg.type}')
return
original_database = self.cfg.name
self.cfg.name = 'postgres'
with self.session as s:
s.conn.autocommit = True
s.rollback()
if original_database not in s.get_databases():
#s.execute('SET AUTOCOMMIT = OFF')
s.cursor.execute(f'CREATE DATABASE {original_database}')
s.conn.autocommit = False
self.cfg.name = original_database
def set_row_class(self, table, row_class):
pass
class Session:
def __init__(self, db, trans):
self.id = random_gen()
self.db = db
self.cfg = db.cfg
self.trans = trans
self.trans_state = False
self.conn = None
self.cursor = None
def __del__(self):
try:
izzylog.verbose('Deleting session:', self.id)
except ModuleNotFoundError:
if izzylog.get_config('level') >= 20:
print('[izzylib] VERBOSE: Deleting session:', self.id)
self.db.sessions.pop(self.id, None)
if self.conn:
self.disconnect()
def __enter__(self):
self.connect()
if self.trans:
self.begin()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
if exc_traceback:
self.rollback()
else:
self.commit()
self.disconnect()
self.db.disconnect(self.id)
def connect(self):
if self.conn:
return
self.db.connect(self.id, self)
if self.cfg.type == 'sqlite':
self.conn = self.cfg.module.connect(self.cfg.name, self.cfg.timeout, check_same_thread=True)
elif self.cfg.type == 'postgresql':
options = dict(
host = self.cfg.host or '/var/run/postgresql',
port = self.cfg.port or 5432,
database = self.cfg.name or 'postgresql',
user = self.cfg.username or getuser(),
password = self.cfg.password,
)
if self.cfg.mod_name == 'pg8000':
if options['host'] in [None, '/var/run/postgresql']:
port = options.pop('port')
options['unix_sock'] = options.pop('host') + f'/.s.PGSQL.{port}'
## SSL is a pain in the ass tbh. Gonna deal with this later
#if self.cfg.mod_name == 'pg8000':
#options['sslmode'] = self.cfg.ssl
#options['ssl_context'] = self.cfg.ssl_context
#elif self.cfg.mod_name == 'psycopg2':
#options['sslcert'] = self.cfg.ssl_cert
#options['sslkey'] = self.cfg.ssl_key
self.conn = self.cfg.module.connect(**options)
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
try:
self.conn.autocommit = False
except:
izzylog.verbose('Failed to turn off autocommit')
self.cursor = self.conn.cursor()
def disconnect(self):
if not self.conn:
return
self.cursor.close()
self.conn.close()
self.cursor = None
self.conn = None
def begin(self):
if self.trans_state:
return
#self.conn.begin()
self.execute('BEGIN TRANSACTION')
self.trans_state = True
def rollback(self):
if not self.trans_state:
return
self.conn.rollback()
#self.execute('ROLLBACK TRANSACTION')
self.trans_state = False
def commit(self):
if not self.trans_state:
return
self.conn.commit()
#self.execute('COMMIT TRANSACTION')
self.trans_state = False
## data management functions
def execute(self, string, values=[]):
if any(map(string.lower().startswith, ['insert', 'update', 'remove', 'create', 'drop'])) and not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
self.cursor.execute(string, values)
return self.cursor
def fetch(self, table, single=True, **kwargs):
rows = []
data = Select(table, type=self.cfg.type, **kwargs).exec(self)
for line in data:
row = Row(table, self.cursor.description, line)
if single:
return row
rows.append(row)
return rows if not single else None
def search(self, table, **kwargs):
return self.fetch(table, single=False, **kwargs)
def insert(self, table, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Insert(table, type=self.cfg.type, **kwargs).exec(self)
return self.fetch(table, **kwargs)
def update(self, table, rowid, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Update(table, rowid, type=self.cfg.type, **kwargs).exec(self)
return self.fetch(table, id=rowid)
def delete(self, table, **kwargs):
if not self.trans_state:
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
Delete(table, type=self.cfg.type, **kwargs).exec(self)
## helper functions
def get_columns(self, table):
if table not in self.get_tables():
raise KeyError(f'Not an existing table: {table}')
if self.cfg.type == 'sqlite':
rows = self.execute(f'PRAGMA table_info({table})')
return [row[1] for row in rows]
elif self.cfg.type == 'postgresql':
rows = self.execute(f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table}'")
return [row[0] for row in rows]
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
@property
def table(self):
return DotDict(self.meta.tables)
def get_tables(self):
if self.cfg.type == 'sqlite':
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
return list(self.table.keys())
elif self.cfg.type == 'postgresql':
rows = self.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name")
def get_columns(self, table):
return list(col.name for col in self.table[table].columns)
def new_session(self, trans=True):
return self.session_class(self, trans=trans)
## Leaving link to example code for read-only sqlite for later use
## https://github.com/pudo/dataset/issues/136#issuecomment-128693122
def open(self):
dbtype, kwargs = self._connect_args
engine_kwargs = {
'future': True,
#'maxconnections': 25
}
if not kwargs.get('name'):
raise KeyError('Database "name" is not set')
if dbtype == 'sqlite':
database = kwargs['name']
if nfs_check(database):
izzylog.warning('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
engine_kwargs['connect_args'] = {'check_same_thread': False}
elif dbtype == 'postgresql':
ssl_context = kwargs.get('ssl')
if ssl_context:
engine_kwargs['ssl_context'] = ssl_context
if not kwargs.get('host'):
kwargs['unix_socket'] = '/var/run/postgresql'
if kwargs.get('host') and Path(kwargs['host']).exists():
kwargs['unix_socket'] = kwargs.pop('host')
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
raise TypeError(f'Unsupported database type: {dbtype}')
return [row[0] for row in rows]
self.config.update(kwargs)
def get_databases(self):
if self.cfg.type == 'sqlite':
izzylog.verbose('This function is useless with sqlite')
return
elif self.cfg.type == 'postgresql':
databases = [row[0] for row in self.execute('SELECT datname FROM pg_database')]
if dbtype == 'sqlite':
url = URL.create(
drivername='sqlite',
database=kwargs.pop('name')
)
else:
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
try:
for module in modules[dbtype]:
if pkgutil.get_loader(module):
dbtype = f'{dbtype}+{module}'
return databases
except KeyError:
pass
url = URL.create(
drivername = dbtype,
username = kwargs.pop('user', None),
password = kwargs.pop('password', None),
host = kwargs.pop('host', None),
port = kwargs.pop('port', None),
database = kwargs.pop('name'),
)
self.db = create_engine(url, **engine_kwargs)
self.meta = MetaData()
self.meta.reflect(bind=self.db, resolve_fks=True, views=True)
self._setup_cache()
def cursor_description(self):
return [row[0] for row in self.cursor.description]
def close(self):
for sid in list(self.sessions):
self.sessions[sid].commit()
self.sessions[sid].close()
self.config = DotDict()
self.cache = DotDict()
self.meta = None
self.db = None
def setup_database(self):
if not self.db.tables:
raise ValueError('Tables have not been specified.')
def load_tables(self, **tables):
self.meta = MetaData()
current_tables = self.get_tables()
for name, columns in tables.items():
Table(name, self.meta, *columns)
for name, table in self.db.tables.items():
if name in current_tables:
izzylog.verbose(f'Skipping table creation since it already exists: {name}')
continue
izzylog.verbose(f'Creating table: {name}')
self.execute(table.build(self.cfg.type))
self._setup_cache()
class Row(DotDict):
def __init__(self, table, keys, values):
self._db = None
self._table = table
def create_database(self, tables={}):
if tables:
self.load_tables(**tables)
super().__init__()
if self.db.url.get_backend_name() == 'postgresql':
predb = create_engine(self.db.engine_string.replace(self.config.name, 'postgres', -1), future=True)
conn = predb.connect()
for idx, key in enumerate([key[0] for key in keys]):
self[key] = values[idx]
try:
conn.execute(text(f'CREATE DATABASE {database}'))
except ProgrammingError:
'The database already exists, so just move along'
except Exception as e:
conn.close()
raise e from None
conn.close()
self.meta.create_all(bind=self.db)
def update(self, data):
for k, v in data.items():
if k not in self:
raise KeyError(f'Not a column for {self._table}')
def drop_tables(self, *tables):
if not tables:
raise ValueError('No tables specified')
self[k] = v
self.meta.drop_all(bind=self.db, tables=tables)
def delete(self):
with self._db.session as s:
s.delete(self._table, id=self.id)
def update(self, **kwargs):
self.update(kwargs)
with self._db.session as s:
s.update(self._table, id=self.id, **kwargs)
def execute(self, string, **kwargs):
with self.session as s:
s.execute(string, **kwargs)
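Tying the database rework together, a hedged sketch of opening a database and creating its tables with the SQLAlchemy-backed Database (file path and schema are placeholders):

from izzylib.sql import Database, Column

db = Database(dbtype='sqlite', name='/tmp/example.sqlite3')
db.load_tables(
    user = [Column('id'), Column('username', 'string', unique=True), Column('timestamp')]
)
db.create_database()
print(db.get_tables())  # ['user']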

View file

@ -1,10 +0,0 @@
class MaxConnectionsError(Exception):
'raise when the max amount of connections has been reached'
class NoTransactionError(Exception):
'raise when a write command is executed outside a transaction'
class DatabaseNotSupportedError(Exception):
'raise when the action being performed is not supported by the database in use'

View file

@ -1,508 +0,0 @@
import json, sys, threading, time
from contextlib import contextmanager
from datetime import datetime
from sqlalchemy import create_engine, ForeignKey, MetaData, Table
from sqlalchemy import Column as sqlalchemy_column, types as Types
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import scoped_session, sessionmaker
from izzylib import (
LruCache,
DotDict,
Path,
random_gen,
nfs_check,
izzylog
)
SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')})
class SqlDatabase:
def __init__(self, dbtype='sqlite', tables={}, **kwargs):
self.db = self.__create_engine(dbtype, kwargs)
self.table = None
self.table_names = None
self.classes = kwargs.get('row_classes', CustomRows())
self.cache = None
self.session_class = kwargs.get('session_class', SqlSession)
self.sessions = {}
self.setup_tables(tables)
self.setup_cache()
## Leaving link to example code for read-only sqlite for later use
## https://github.com/pudo/dataset/issues/136#issuecomment-128693122
def __create_engine(self, dbtype, kwargs):
engine_args = []
engine_kwargs = {}
if not kwargs.get('name'):
raise KeyError('Database "name" is not set')
engine_string = dbtype + '://'
if dbtype == 'sqlite':
database = kwargs['name']
if nfs_check(database):
izzylog.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
engine_string += f'/{database}'
engine_kwargs['connect_args'] = {'check_same_thread': False}
elif dbtype == 'postgresql':
ssl_context = kwargs.get('ssl')
if ssl_context:
engine_kwargs['ssl_context'] = ssl_context
else:
user = kwargs.get('user')
password = kwargs.get('pass')
host = kwargs.get('host', '/var/run/postgresql')
port = kwargs.get('port', 5432)
name = kwargs.get('name', 'postgres')
maxconn = kwargs.get('maxconnections', 25)
if user:
if password:
engine_string += f'{user}:{password}@'
else:
engine_string += user + '@'
if host == '/var/run/postgresql':
engine_string += f'/{name}:{port}/{name}'
else:
engine_string += f'{host}:{port}/{name}'
return create_engine(engine_string, *engine_args, **engine_kwargs)
@property
def session(self):
return self.session_class(self)
def close(self):
self.setup_cache()
def setup_cache(self):
self.cache = DotDict({table: LruCache() for table in self.table_names})
def create_tables(self, *tables):
if not tables:
raise ValueError('No tables specified')
new_tables = [self.table[table] for table in tables]
self.table.meta.create_all(bind=self.db, tables=new_tables)
def create_database(self):
if self.db.url.get_backend_name() == 'postgresql':
predb = create_engine(db.engine_string.replace(config.db.name, 'postgres', -1))
conn = predb.connect()
conn.execute('commit')
try:
conn.execute(f'CREATE DATABASE {config.db.name}')
except ProgrammingError:
'The database already exists, so just move along'
except Exception as e:
conn.close()
raise e from None
conn.close()
self.table.meta.create_all(self.db)
def setup_tables(self, tables):
self.table = Tables(self, tables)
self.table_names = tables.keys()
def execute(self, string, values=[]):
with self.session as s:
s.execute(string, values)
class SqlSession(object):
def __init__(self, db):
self.closed = False
self.database = db
self.classes = db.classes
self.session = sessionmaker(bind=db.db)()
self.table = db.table
self.cache = db.cache
# session aliases
self.s = self.session
self.begin = self.s.begin
self.commit = self.s.commit
self.rollback = self.s.rollback
self.query = self.s.query
self.execute = self.s.execute
# remove in the future
self.db = db
self._setup()
def __enter__(self):
self.open()
return self
def __exit__(self, exctype, value, tb):
if tb:
self.rollback()
self.close()
def open(self):
self.sessionid = random_gen(10)
self.db.sessions[self.sessionid] = self
def close(self):
self.commit()
self.s.close()
self.closed = True
del self.db.sessions[self.sessionid]
self.sessionid = None
def _setup(self):
pass
@property
def dirty(self):
return any([self.s.new, self.s.dirty, self.s.deleted])
def count(self, table_name, **kwargs):
return self.query(self.table[table_name]).filter_by(**kwargs).count()
def fetch(self, table_name, single=True, orderby=None, orderdir='asc', **kwargs):
table = self.table[table_name]
RowClass = self.classes.get(table_name.capitalize())
query = self.query(table).filter_by(**kwargs)
if not orderby:
rows = query.all()
else:
if orderdir == 'asc':
rows = query.order_by(getattr(table.c, orderby).asc()).all()
elif orderdir == 'desc':
rows = query.order_by(getattr(table.c, orderby).desc()).all()
else:
raise ValueError(f'Unsupported order direction: {orderdir}')
if single:
return RowClass(table_name, rows[0], self) if len(rows) > 0 else None
return [RowClass(table_name, row, self) for row in rows]
def search(self, *args, **kwargs):
kwargs.pop('single', None)
return self.fetch(*args, single=False, **kwargs)
def insert(self, table_name, return_row=False, **kwargs):
row = self.fetch(table_name, **kwargs)
if row:
row.update_session(self, **kwargs)
return
table = self.table[table_name]
if getattr(table, 'timestamp', None) and not kwargs.get('timestamp'):
kwargs['timestamp'] = datetime.now()
self.execute(table.insert().values(**kwargs))
if return_row:
return self.fetch(table_name, **kwargs)
def update(self, table=None, rowid=None, row=None, return_row=False, **data):
if row:
if not getattr(row, '_table_name', None):
print(row)
print(dir(row))
rowid = row.id
table = row._table_name
if not rowid or not table:
raise ValueError('Missing row ID or table')
tclass = self.table[table]
self.execute(tclass.update().where(tclass.c.id == rowid).values(**data))
if return_row:
return self.fetch(table, id=rowid)
def remove(self, table=None, rowid=None, row=None):
if row:
rowid = row.id
table = row._table_name
if not rowid or not table:
raise ValueError('Missing row ID or table')
self.execute(f'DELETE FROM {table} WHERE id={rowid}')
def drop_table(self, name):
if name not in self.get_tables():
raise KeyError(f'Table does not exist: {name}')
self.execute(f'DROP TABLE {name}')
def drop_tables(self):
tables = self.get_tables()
for table in tables:
self.drop_table(table)
def get_columns(self, table):
if table not in self.get_tables():
raise KeyError(f'Not an existing table: {table}')
rows = self.execute('PRAGMA table_info(user)')
return [row[1] for row in rows]
def get_tables(self):
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
return [row[0] for row in rows]
def append_column(self, table, column):
if column.name in self.get_columns(table):
logging.warning(f'Table "{table}" already has column "{column.name}"')
return
self.execute(f'ALTER TABLE {table} ADD COLUMN {column.compile()}')
def append_column2(self, tbl, col):
table = self.table[tbl]
try:
column = getattr(table.c, col)
except AttributeError:
izzylog.error(f'Table "{tbl}" does not have column "{col}"')
return
columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]
if col in self.get_columns(tbl):
izzylog.info(f'Column "{col}" already exists')
return
sql = f'ALTER TABLE {tbl} ADD COLUMN {col} {column.type}'
if not column.nullable:
sql += ' NOT NULL'
if column.primary_key:
sql += ' PRIMARY KEY'
if column.unique:
sql += ' UNIQUE'
self.execute(sql)
def remove_column(self, tbl, col):
table = self.table[tbl]
column = getattr(table, col, None)
columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]
if col not in columns:
izzylog.info(f'Column "{col}" already exists')
return
columns.remove(col)
coltext = ', '.join(columns)
self.execute(f'CREATE TABLE {tbl}_temp AS SELECT {coltext} FROM {tbl}')
self.execute(f'DROP TABLE {tbl}')
self.execute(f'ALTER TABLE {tbl}_temp RENAME TO {tbl}')
def clear_table(self, table):
self.execute(f'DELETE FROM {table}')
class CustomRows(object):
def get(self, name):
return getattr(self, name, self.Row)
class Row(DotDict):
#_filter_columns = lambda self, row: [attr for attr in dir(row) if not attr.startswith('_') and attr != 'metadata']
def __init__(self, table, row, session):
super().__init__()
if row:
try:
self._update(row._asdict())
except:
self._update(row)
self._db = session.db
self._table_name = table
self._columns = self.keys()
self.__run__(session)
## Subclass Row and redefine this function
def __run__(self, s):
pass
def _filter_data(self):
data = {k: v for k,v in self.items() if k in self._columns}
for k,v in self.items():
if v.__class__ == DotDict:
data[k] = v.asDict()
return data
def asDict(self):
return self._filter_data()
def _update(self, new_data={}, **kwargs):
kwargs.update(new_data)
for k,v in kwargs.items():
if type(v) == dict:
self[k] = DotDict(v)
self[k] = v
def delete(self, s=None):
if s:
return self.delete_session(s)
with self._db.session as s:
return self.delete_session(s)
def delete_session(self, s):
return s.remove(table=self._table_name, row=self)
def update(self, dict_data={}, s=None, **data):
dict_data.update(data)
self._update(dict_data)
if s:
return self.update_session(s, **self._filter_data())
with self._db.session as s:
s.update(row=self, **self._filter_data())
def update_session(self, s, dict_data={}, **data):
dict_data.update(data)
self._update(dict_data)
return s.update(table=self._table_name, row=self, **dict_data)
class Tables(DotDict):
def __init__(self, db, tables={}):
'"tables" should be a dict with the table names for keys and a list of Columns for values'
super().__init__()
self.db = db
self.meta = MetaData()
for name, table in tables.items():
self.__setup_table(name, table)
def __setup_table(self, name, table):
columns = [col if type(col) == SqlColumn else SqlColumn(*col.get('args'), **col.get('kwargs')) for col in table]
self[name] = Table(name, self.meta, *columns)
class SqlColumn(sqlalchemy_column):
def __init__(self, name, stype=None, fkey=None, **kwargs):
if not stype and not kwargs:
if name == 'id':
stype = 'integer'
kwargs['primary_key'] = True
kwargs['autoincrement'] = True
elif name == 'timestamp':
stype = 'datetime'
else:
raise ValueError('Missing column type and options')
stype = (stype.lower() if type(stype) == str else stype) or 'string'
if type(stype) == str:
try:
stype = SqlTypes[stype.lower()]
except KeyError:
raise KeyError(f'Invalid SQL data type: {stype}')
options = [name, stype]
if fkey:
options.append(ForeignKey(fkey))
super().__init__(*options, **kwargs)
def compile(self):
sql = f'{self.name} {self.type}'
if not self.nullable:
sql += ' NOT NULL'
if self.primary_key:
sql += ' PRIMARY KEY'
if self.unique:
sql += ' UNIQUE'
return sql

View file

@ -1,415 +0,0 @@
from datetime import datetime
from functools import partial
from izzylib import DotDict, Path
from .types import BaseType, Type
placeholders = dict(
sqlite = '?',
postgresql = '%s'
)
## Data queries
class Delete:
def __init__(self, table, type='sqlite', **kwargs):
self.table = table
self.placeholder = placeholders[type]
self.keys = []
self.values = []
for k,v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
self.build(embed_values=True)
def build(self, comp_type='AND', embed_values=False):
sql = 'DELETE FROM {table} WHERE {kstring}'
if not embed_values:
kstring = f' {comp_type.upper()} '.join([f'{k} = {self.placeholder}' for k in self.keys])
return sql.format(table=self.table, kstring=kstring), self.values
values = []
for idx, value in enumerate(self.values):
if type(value) == str:
values.append(f"{self.keys[idx]} = '{value}'")
else:
values.append(f"{self.keys[idx]} = {value}")
kstring = ','.join(values)
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
def exec(self, session, comp_type='AND'):
return session.execute(*self.build(comp_type))
class Insert:
def __init__(self, table, type='sqlite', **kwargs):
self.table = table
self.placeholder = placeholders[type]
self.keys = []
self.values = []
for k, v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
return self.build(embed_values=True)
def build(self, embed_values=False):
kstring = ','.join(self.keys)
if not embed_values:
vstring = ','.join([self.placeholder for k in self.keys])
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})', self.values
else:
vstring = ','.join(self.values)
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})'
def exec(self, session):
return session.execute(*self.build())
class Select:
def __init__(self, table, columns=[], type='sqlite', **kwargs):
self.placeholder = placeholders[type]
self.columns = columns
self.table = table
self.where = []
self.where_build = []
self._order = []
self.keys = []
self.values = []
self.equals = partial(self.__comparison, '=')
self.less = partial(self.__comparison, '<')
self.greater = partial(self.__comparison, '>')
self.like = partial(self.__comparison, 'LIKE')
for k,v in kwargs.items():
self.equals(k, v)
def __str__(self):
return self.build(embed_values=True)
def __comparison(self, comp, key, value):
self.values.append(value)
self.keys.append(key)
self.where.append(f'{key} {comp.upper()} {self.placeholder}')
self.where_build.append(f"{key} {comp.upper()} '{value}'" if type(key) == str else f"{key} {comp.upper()} {value}")
return self
def order(self, column, asc=True):
self._order = [column, 'ASC' if asc else 'DESC']
return self
def build(self, comp_type='AND', embed_values=False):
if not self.columns:
cols = '*'
else:
cols = ','.join('columns')
sql_query = f'SELECT {cols} FROM {self.table}'
if self.where:
where = f' {comp_type.upper()} '.join(self.where if not embed_values else self.where_build)
sql_query += f' WHERE {where}'
if self._order:
col, order = self._order
sql_query += f' ORDER BY {col} {order}'
if embed_values:
return sql_query
return sql_query, self.values
def exec(self, session, comp_type='AND'):
return session.execute(*self.build(comp_type))
class Update:
def __init__(self, table, rowid, type='sqlite', **kwargs):
self.placeholder = placeholders[type]
self.table = table
self.rowid = rowid
self.keys = []
self.values = []
for k,v in kwargs.items():
self.keys.append(k)
self.values.append(v)
def __str__(self):
return self.build(embed_values=True)
def build(self, embed_values=False):
sql = 'UPDATE {table} SET {kstring} WHERE id={rowid}'
if not embed_values:
kstring = ','.join([f'{k} = {self.placeholder}' for k in self.keys])
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid), self.values
values = []
for idx, value in enumerate(self.values):
if type(value) == str:
values.append(f"{self.keys[idx]} = '{value}'")
else:
values.append(f"{self.keys[idx]} = {value}")
kstring = ','.join(values)
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
def exec(self, session):
return session.execute(*self.build())
## Database objects
class Column:
def __init__(self, name, type='STRING', unique=False, nullable=True, default=None, primary_key=False, autoincrement=False, foreign_key=None):
self.name = name
self.type = type
self.nullable = nullable
self.default = default
self.primary_key = primary_key
self.autoincrement = autoincrement
self.unique = unique
if any(map(isinstance, [foreign_key], [list, tuple, set])):
self.foreign_key = foreign_key
else:
self.foreign_key = foreign_key.split('.') if foreign_key else None
if autoincrement:
self.primary_key = True
self.type = Type['INTEGER']
if isinstance(self.type, BaseType):
self.type = self.type.name
else:
if self.type.upper() in Type.keys():
self.type = self.type.upper()
else:
raise TypeError(f'Invalid SQL type: {self.type}')
if foreign_key and len(self.foreign_key) != 2:
raise ValueError('Invalid foreign key. Must be in the format "table.column".')
def __str__(self):
return self.build()
def build(self, dbtype='sqlite'):
if dbtype == 'postgresql':
if self.type.lower() == 'string':
self.type = 'TEXT'
elif self.type.lower() == 'datetime':
self.type = 'TIMESTAMPTZ'
if self.autoincrement:
self.type = 'SERIAL'
self.autoincrement = False
sql = f'{self.name} {self.type}'
if self.primary_key:
sql += ' PRIMARY KEY'
if self.autoincrement:
sql += ' AUTOINCREMENT'
if self.unique:
sql += ' UNIQUE'
if not self.nullable:
sql += ' NOT NULL'
if self.default:
def_type = type(self.default)
if self.default == 'CURRENT_TIMESTAMP':
if dbtype == 'sqlite':
sql += " DEFAULT (datetime('now', 'localtime'))"
elif dbtype == 'postgresql':
sql += ' DEFAULT now()'
else:
sql += f' DEFAULT {datetime.now().timestamp()}'
elif def_type == str:
sql += f" DEFAULT '{self.default}'"
elif def_type in [int, float]:
sql += f' DEFAULT {self.default}'
elif def_type == bool and dbtype == 'sqlite':
sql += f' DEFAULT {int(self.default)}'
else:
sql += f' DEFAULT {self.default}'
print(sql)
return sql
def json(self):
return DotDict({
'type': self.type,
'nullable': self.nullable,
'default': self.default,
'primary_key': self.primary_key,
'autoincrement': self.autoincrement,
'unique': self.unique,
'foreign_key': self.foreign_key
})
class Table(DotDict):
def __init__(self, name, *columns):
super().__init__()
self._name = name
self._foreign_keys = {}
self.add_column(Column('id', autoincrement=True))
for column in columns:
self.add_column(column)
def __str__(self):
return self.build()
# this'll be useful later
def __call__(self, *args, **kwargs):
pass
@property
def name(self):
return self._name
def add_column(self, column):
self[column.name] = column
if column.foreign_key:
self._foreign_keys[column.name] = column.foreign_key
def build(self, dbtype='sqlite'):
column_string = ',\n'.join([f'\t{col.build(dbtype)}' for col in self.values()])
if self._foreign_keys:
column_string += ',\n'
column_string += ',\n'.join([f'\tFOREIGN KEY ({column}) REFERENCES {key[0]} ({key[1]})' for column, key in self._foreign_keys.items()])
return f'''CREATE TABLE {self.name} (
{column_string}
);'''
def json(self):
data = {}
for name, column in self.items():
data[name] = column.json()
return data
class Tables(DotDict):
def __init__(self, *tables, data={}):
super().__init__()
for table in tables:
self.add_table(table)
if data:
self.from_dict(data)
def __str__(self):
return self.build()
@classmethod
def new_from_json_file(cls, path):
return cls(data=DotDict.new_from_json_file(path))
def add_table(self, table):
self[table.name] = table
def build(self):
return '\n\n'.join([str(table) for table in self.values()])
def load_json(self, path):
data = DotDict()
data.load_json(path)
self.from_dict(data)
def save_json(self, path, indent='\t'):
self.to_dict().save_json(path, indent=indent)
def from_dict(self, data):
for name, columns in data.items():
table = Table(name)
for col, kwargs in columns.items():
table.add_column(Column(col,
type = kwargs.get('type', 'STRING'),
nullable = kwargs.get('nullable', True),
default = kwargs.get('default'),
primary_key = kwargs.get('primary_key', False),
autoincrement = kwargs.get('autoincrement', False),
unique = kwargs.get('unique', False),
foreign_key = kwargs.get('foreign_key')
))
self.add_table(table)
def to_dict(self):
data = DotDict()
for name, table in self.items():
data[name] = table.json()
return data

View file

@ -1,19 +0,0 @@
from izzylib import DotDict
class DbRow(DotDict):
def __init__(self, table, keys, values):
self.table = table
super().__init__()
for idx, key in enumerate(keys):
self[key] = values[idx]
def delete(self):
pass
def update(self, **kwargs):
pass

90
sql/izzylib/sql/rows.py Normal file
View file

@ -0,0 +1,90 @@
from izzylib import DotDict, izzylog
class RowClasses(DotDict):
def __init__(self, *classes):
super().__init__()
for rowclass in classes:
self.update({rowclass.__name__.lower(): rowclass})
def get_class(self, name):
return self.get(name, Row)
class Row(DotDict):
def __init__(self, table, row, session):
super().__init__()
if row:
try:
self._update(row._asdict())
except:
self._update(row)
self.__db = session.db
self.__table_name = table
self.__run__(session)
@property
def db(self):
return self.__db
@property
def table(self):
return self.__table_name
@property
def columns(self):
return self.keys()
## Subclass Row and redefine this function
def __run__(self, s):
pass
def _update(self, *args, **kwargs):
super().update(*args, **kwargs)
def delete(self, s=None):
izzylog.warning('deprecated function: Row.delete')
if s:
return self.delete_session(s)
with self.db.session as s:
return self.delete_session(s)
def delete_session(self, s):
izzylog.warning('deprecated function: Row.delete_session')
return s.remove(table=self.table, row=self)
def update(self, dict_data={}, s=None, **data):
izzylog.warning('deprecated function: Row.update')
dict_data.update(data)
self._update(dict_data)
if s:
return self.update_session(s, **self)
with self.db.session as s:
s.update(row=self, **self)
def update_session(self, s, dict_data={}, **data):
izzylog.warning('deprecated function: Row.update_session')
dict_data.update(data)
self._update(dict_data)
return s.update(table=self.table, row=self, **dict_data)
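Custom row classes appear to be registered by class name and matched case-insensitively against table names; a hedged sketch (class, table, and attribute names are placeholders):

from izzylib.sql import Database, Column
from izzylib.sql.rows import Row

class User(Row):
    def __run__(self, session):
        # hook called after the row data is loaded
        self.display_name = self.get('username', '').title()

# rows fetched from the "user" table now come back as User instances
db = Database(dbtype='sqlite', name='/tmp/example.sqlite3', row_classes=[User])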

179
sql/izzylib/sql/session.py Normal file
View file

@ -0,0 +1,179 @@
from datetime import datetime
from izzylib import DotDict, random_gen, izzylog
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session as sqlalchemy_session
class Session(sqlalchemy_session):
def __init__(self, db, trans=False):
super().__init__(bind=db.db, future=True)
self.closed = False
self.trans = trans
self.database = db
self.classes = db.classes
self.cache = db.cache
self.sessionid = random_gen(10)
self.database.sessions[self.sessionid] = self
# remove in the future
self.db = db
self._setup()
def __enter__(self):
if self.trans:
self.begin()
return self
def __exit__(self, exctype, value, tb):
if self.in_transaction():
if tb:
self.rollback()
self.commit()
self.close()
def _setup(self):
pass
@property
def table(self):
return self.db.table
def commit(self):
if not self.in_transaction():
return
super().commit()
def close(self):
super().close()
self.closed = True
del self.db.sessions[self.sessionid]
self.sessionid = None
def run(self, expression, **kwargs):
result = self.execute(text(expression), params=kwargs)
try:
return result.mappings().all()
except Exception as e:
izzylog.verbose(f'Session.run: {e.__class__.__name__}: {e}')
return result
def count(self, table_name, **kwargs):
return self.query(self.table[table_name]).filter_by(**kwargs).count()
def fetch(self, table, single=True, orderby=None, orderdir='asc', **kwargs):
RowClass = self.classes.get_class(table.lower())
query = self.query(self.table[table]).filter_by(**kwargs)
if not orderby:
rows = query.all()
else:
if orderdir == 'asc':
rows = query.order_by(getattr(self.table[table].c, orderby).asc()).all()
elif orderdir == 'desc':
rows = query.order_by(getattr(self.table[table].c, orderby).desc()).all()
else:
raise ValueError(f'Unsupported order direction: {orderdir}')
if single:
return RowClass(table, rows[0], self) if len(rows) > 0 else None
return [RowClass(table, row, self) for row in rows]
def search(self, *args, **kwargs):
kwargs.pop('single', None)
return self.fetch(*args, single=False, **kwargs)
def insert(self, table, return_row=False, **kwargs):
row = self.fetch(table, **kwargs)
if row:
row.update_session(self, **kwargs)
return
if getattr(self.table[table], 'timestamp', None) and not kwargs.get('timestamp'):
kwargs['timestamp'] = datetime.now()
result = self.execute(self.table[table].insert().values(**kwargs))
if return_row:
return self.fetch(table, **kwargs)
return result
def update(self, table=None, rowid=None, row=None, return_row=False, **kwargs):
if row:
rowid = row.id
table = row.table
if not rowid or not table:
raise ValueError('Missing row ID or table')
self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs))
if return_row:
return self.fetch(table, id=rowid)
def remove(self, table=None, rowid=None, row=None):
if row:
rowid = row.id
table = row.table
if not rowid or not table:
raise ValueError('Missing row ID or table')
self.run(f'DELETE FROM {table} WHERE id=:id', id=rowid)
def append_column(self, table, column):
if column.name in self.db.get_columns(table):
izzylog.warning(f'Table "{table}" already has column "{column.name}"')
return
self.run(f'ALTER TABLE {table} ADD COLUMN {column.compile()}')
def remove_column(self, tbl, col):
table = self.table[tbl]
column = getattr(table, col, None)
columns = self.db.get_columns(tbl)
if col not in columns:
izzylog.info(f'Column "{col}" does not exist')
return
columns.remove(col)
coltext = ','.join(columns)
self.run(f'CREATE TABLE {tbl}_temp AS SELECT {coltext} FROM {tbl}')
self.run(f'DROP TABLE {tbl}')
self.run(f'ALTER TABLE {tbl}_temp RENAME TO {tbl}')
def clear_table(self, table):
self.run(f'DELETE FROM {table}')
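Finally, a hedged sketch of the reworked Session in use: db.session hands out a fresh session, and the context manager commits on a clean exit and rolls back on error (schema and values are placeholders):

from izzylib.sql import Database, Column

db = Database(dbtype='sqlite', name='/tmp/example.sqlite3')
db.create_database(tables={'user': [Column('id'), Column('username', 'string', unique=True)]})

with db.session as s:
    s.insert('user', username='izalia')
    row = s.fetch('user', username='izalia')
    print(row.id, row.username)
    s.update(row=row, username='izzy')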

View file

@ -1,19 +0,0 @@
from enum import Enum
from izzylib import DotDict
class BaseType(Enum):
INTEGER = int
TEXT = str
BLOB = bytes
REAL = float
NUMERIC = float
Type = DotDict(
**{v: BaseType.INTEGER for v in ['INT', 'INTEGER', 'TINYINT', 'SMALLINT', 'MEDIUMINT', 'BIGINT', 'UNSIGNED BIG INT', 'INT2', 'INT8']},
**{v: BaseType.TEXT for v in ['CHARACTER', 'VARCHAR', 'VARYING CHARACTER', 'NCHAR', 'NATIVE CHARACTER', 'NVARCHAR', 'TEXT', 'CLOB', 'STRING', 'JSON']},
**{v: BaseType.BLOB for v in ['BYTES', 'BLOB']},
**{v: BaseType.REAL for v in ['REAL', 'DOUBLE', 'DOUBLE PRECISION', 'FLOAT']},
**{v: BaseType.NUMERIC for v in ['NUMERIC', 'DECIMAL', 'BOOLEAN', 'DATE', 'DATETIME']}
)