rework #3
28 changed files with 892 additions and 1682 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -121,7 +121,7 @@ reload.cfg
|
||||||
/izzylib
|
/izzylib
|
||||||
/base/izzylib/dbus
|
/base/izzylib/dbus
|
||||||
/base/izzylib/hasher
|
/base/izzylib/hasher
|
||||||
/base/izzylib/http_requests_client
|
/base/izzylib/http_urllib_client
|
||||||
/base/izzylib/http_server
|
/base/izzylib/http_server
|
||||||
/base/izzylib/mbus
|
/base/izzylib/mbus
|
||||||
/base/izzylib/sql
|
/base/izzylib/sql
|
||||||
|
|
|
@ -16,7 +16,7 @@ You only need to install the base and whatever sub-modules you want to use
|
||||||
|
|
||||||
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-server&subdirectory=http_server"
|
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-server&subdirectory=http_server"
|
||||||
|
|
||||||
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-requests-client&subdirectory=requests_client"
|
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-urllib-client&subdirectory=http_urllib_client"
|
||||||
|
|
||||||
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-sql&subdirectory=sql"
|
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-sql&subdirectory=sql"
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ You only need to install the base and whatever sub-modules you want to use
|
||||||
|
|
||||||
### From Source
|
### From Source
|
||||||
|
|
||||||
$(venv)/bin/python setup.py install ['all' or a combination of these: dbus hasher http_server requests_client sql template tinydb]
|
$(venv)/bin/python setup.py install ['all' or a combination of these: dbus hasher http_server http_urllib_client sql template tinydb]
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ from .misc import *
|
||||||
from .cache import CacheDecorator, LruCache, TtlCache
|
from .cache import CacheDecorator, LruCache, TtlCache
|
||||||
from .connection import Connection
|
from .connection import Connection
|
||||||
|
|
||||||
from .http_urllib_client import HttpUrllibClient, HttpUrllibResponse
|
from .http_client import HttpClient, HttpResponse
|
||||||
|
|
||||||
|
|
||||||
def log_import_error(package, *message):
|
def log_import_error(package, *message):
|
||||||
|
@ -48,10 +48,10 @@ except ImportError:
|
||||||
log_import_error('template', 'Failed to import http template classes. Jinja and HAML templates disabled')
|
log_import_error('template', 'Failed to import http template classes. Jinja and HAML templates disabled')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from izzylib.http_requests_client import *
|
from izzylib.http_urllib_client import *
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
log_import_error('http_requests_client', 'Failed to import Requests http client classes. Requests http client is disabled')
|
log_import_error('http_urllib_client', 'Failed to import Requests http client classes. Requests http client is disabled')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from izzylib.http_server import PasswordHasher, HttpServer, HttpServerRequest, HttpServerResponse
|
from izzylib.http_server import PasswordHasher, HttpServer, HttpServerRequest, HttpServerResponse
|
||||||
|
|
|
@ -110,18 +110,12 @@ class DefaultDotDict(DotDict):
|
||||||
|
|
||||||
|
|
||||||
class LowerDotDict(DotDict):
|
class LowerDotDict(DotDict):
|
||||||
def __getattr__(self, key):
|
def __getitem__(self, key):
|
||||||
return super().__getattr__(self, key.lower())
|
return super().__getitem__(key.lower())
|
||||||
|
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
return super().__setattr__(key.lower(), value)
|
return super().__setitem__(key.lower(), value)
|
||||||
|
|
||||||
|
|
||||||
def update(self, data):
|
|
||||||
data = {k.lower(): v for k,v in self.items()}
|
|
||||||
|
|
||||||
return super().update(data)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiDotDict(DotDict):
|
class MultiDotDict(DotDict):
|
||||||
|
|
|
@ -22,7 +22,7 @@ except ImportError:
|
||||||
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
|
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
|
||||||
|
|
||||||
|
|
||||||
class HttpUrllibClient:
|
class HttpClient:
|
||||||
def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None):
|
def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None):
|
||||||
proxy_ports = {
|
proxy_ports = {
|
||||||
'http': 80,
|
'http': 80,
|
||||||
|
@ -74,7 +74,7 @@ class HttpUrllibClient:
|
||||||
except HTTPError as e:
|
except HTTPError as e:
|
||||||
response = e.fp
|
response = e.fp
|
||||||
|
|
||||||
return HttpUrllibResponse(response)
|
return HttpResponse(response)
|
||||||
|
|
||||||
|
|
||||||
def file(self, url, filepath, *args, filename=None, size=2048, create_dirs=True, **kwargs):
|
def file(self, url, filepath, *args, filename=None, size=2048, create_dirs=True, **kwargs):
|
||||||
|
@ -141,7 +141,7 @@ class HttpUrllibClient:
|
||||||
return self.request(*args, headers=headers, **kwargs)
|
return self.request(*args, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class HttpUrllibResponse(object):
|
class HttpResponse(object):
|
||||||
def __init__(self, response):
|
def __init__(self, response):
|
||||||
self.body = response.read()
|
self.body = response.read()
|
||||||
self.headers = DefaultDotDict({k.lower(): v.lower() for k,v in response.headers.items()})
|
self.headers = DefaultDotDict({k.lower(): v.lower() for k,v in response.headers.items()})
|
|
@ -4,6 +4,7 @@ from datetime import datetime
|
||||||
from getpass import getpass, getuser
|
from getpass import getpass, getuser
|
||||||
from importlib import util
|
from importlib import util
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from . import izzylog
|
from . import izzylog
|
||||||
from .dotdict import DotDict
|
from .dotdict import DotDict
|
||||||
|
@ -27,7 +28,8 @@ __all__ = [
|
||||||
'time_function',
|
'time_function',
|
||||||
'time_function_pprint',
|
'time_function_pprint',
|
||||||
'timestamp',
|
'timestamp',
|
||||||
'var_name'
|
'var_name',
|
||||||
|
'Url'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -460,3 +462,26 @@ def var_name(single=True, **kwargs):
|
||||||
|
|
||||||
keys = list(kwargs.keys())
|
keys = list(kwargs.keys())
|
||||||
return key[0] if single else keys
|
return key[0] if single else keys
|
||||||
|
|
||||||
|
|
||||||
|
class Url(str):
|
||||||
|
protocols = {
|
||||||
|
'http': 80,
|
||||||
|
'https': 443,
|
||||||
|
'ftp': 21,
|
||||||
|
'ftps': 990
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, url):
|
||||||
|
str.__new__(Url, url)
|
||||||
|
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
self.__parsed = parsed
|
||||||
|
self.proto = parsed.scheme
|
||||||
|
self.host = parsed.netloc
|
||||||
|
self.path = parsed.path
|
||||||
|
self.query = parsed.query
|
||||||
|
self.username = parsed.username
|
||||||
|
self.password = parsed.password
|
||||||
|
self.port = self.protocols.get(self.proto) if not parsed.port else None
|
||||||
|
|
|
@ -78,6 +78,5 @@ class AccessLog(MiddlewareBase):
|
||||||
|
|
||||||
async def handler(self, request, response):
|
async def handler(self, request, response):
|
||||||
uagent = request.headers.get('user-agent', 'None')
|
uagent = request.headers.get('user-agent', 'None')
|
||||||
address = request.headers.get('x-real-ip', request.forwarded.get('for', request.remote_addr))
|
|
||||||
|
|
||||||
applog.info(f'({multiprocessing.current_process().name}) {address} {request.method} {request.path} {response.status} "{uagent}"')
|
applog.info(f'({multiprocessing.current_process().name}) {request.address} {request.method} {request.path} {response.status} "{uagent}"')
|
||||||
|
|
|
@ -11,6 +11,7 @@ class Request(sanic.request.Request):
|
||||||
super().__init__(url_bytes, headers, version, method, transport, app)
|
super().__init__(url_bytes, headers, version, method, transport, app)
|
||||||
|
|
||||||
self.Headers = Headers(headers)
|
self.Headers = Headers(headers)
|
||||||
|
self.address = self.headers.get('x-real-ip', self.forwarded.get('for', self.remote_addr))
|
||||||
self.data = Data(self)
|
self.data = Data(self)
|
||||||
self.template = self.app.template
|
self.template = self.app.template
|
||||||
self.user_level = 0
|
self.user_level = 0
|
||||||
|
|
30
http_urllib_client/izzylib/http_urllib_client/__init__.py
Normal file
30
http_urllib_client/izzylib/http_urllib_client/__init__.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
from .signatures import (
|
||||||
|
verify_request,
|
||||||
|
verify_headers,
|
||||||
|
parse_signature,
|
||||||
|
fetch_actor,
|
||||||
|
fetch_instance,
|
||||||
|
fetch_nodeinfo,
|
||||||
|
fetch_webfinger_account,
|
||||||
|
generate_rsa_key
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
from .client import HttpUrllibClient, set_default_client
|
||||||
|
from .request import HttpUrllibRequest
|
||||||
|
from .response import HttpUrllibResponse
|
||||||
|
|
||||||
|
#__all__ = [
|
||||||
|
#'HttpRequestsClient',
|
||||||
|
#'HttpRequestsRequest',
|
||||||
|
#'HttpRequestsResponse',
|
||||||
|
#'fetch_actor',
|
||||||
|
#'fetch_instance',
|
||||||
|
#'fetch_nodeinfo',
|
||||||
|
#'fetch_webfinger_account',
|
||||||
|
#'generate_rsa_key',
|
||||||
|
#'parse_signature',
|
||||||
|
#'set_requests_client',
|
||||||
|
#'verify_headers',
|
||||||
|
#'verify_request',
|
||||||
|
#]
|
128
http_urllib_client/izzylib/http_urllib_client/client.py
Normal file
128
http_urllib_client/izzylib/http_urllib_client/client.py
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
import json, sys, urllib3
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from base64 import b64encode
|
||||||
|
from datetime import datetime
|
||||||
|
from functools import cached_property
|
||||||
|
from io import BytesIO
|
||||||
|
from izzylib import DefaultDotDict, DotDict, LowerDotDict, Path, izzylog as logging, __version__
|
||||||
|
from izzylib.exceptions import HttpFileDownloadedError
|
||||||
|
from ssl import SSLCertVerificationError
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from .request import HttpUrllibRequest
|
||||||
|
from .response import HttpUrllibResponse
|
||||||
|
from .signatures import set_client
|
||||||
|
|
||||||
|
|
||||||
|
Client = None
|
||||||
|
|
||||||
|
proxy_ports = {
|
||||||
|
'http': 80,
|
||||||
|
'https': 443
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class HttpUrllibClient:
|
||||||
|
def __init__(self, headers={}, useragent=None, appagent=None, proxy_type='https', proxy_host=None, proxy_port=None, num_pools=20):
|
||||||
|
if not useragent:
|
||||||
|
useragent = f'IzzyLib/{__version__}'
|
||||||
|
|
||||||
|
self.headers = {k:v.lower() for k,v in headers.items()}
|
||||||
|
self.agent = f'{useragent} ({appagent})' if appagent else useragent
|
||||||
|
|
||||||
|
if proxy_type not in ['http', 'https']:
|
||||||
|
raise ValueError(f'Not a valid proxy type: {proxy_type}')
|
||||||
|
|
||||||
|
if proxy_host:
|
||||||
|
proxy = f'{proxy_type}://{proxy_host}:{proxy_ports[proxy_type] if not proxy_port else proxy_port}'
|
||||||
|
self.pool = urllib3.ProxyManager(proxy, num_pools=num_pools)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.pool = urllib3.PoolManager(num_pools=num_pools)
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agent(self):
|
||||||
|
return self.headers['user-agent']
|
||||||
|
|
||||||
|
|
||||||
|
@agent.setter
|
||||||
|
def agent(self, value):
|
||||||
|
self.headers['user-agent'] = value
|
||||||
|
|
||||||
|
|
||||||
|
def set_global(self):
|
||||||
|
set_default_client(self)
|
||||||
|
|
||||||
|
|
||||||
|
def build_request(self, *args, **kwargs):
|
||||||
|
return HttpUrllibRequest(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_request(self, request):
|
||||||
|
request.headers.update(self.headers)
|
||||||
|
response = self.pool.urlopen(*request._args, **request._kwargs)
|
||||||
|
return HttpUrllibResponse(response)
|
||||||
|
|
||||||
|
|
||||||
|
def request(self, *args, **kwargs):
|
||||||
|
return self.handle_request(self.build_request(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
def signed_request(self, privkey, keyid, *args, **kwargs):
|
||||||
|
return self.request(*args, privkey=privkey, keyid=keyid, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def download(self, url, filepath, *args, filename=None, **kwargs):
|
||||||
|
resp = self.request(url, *args, **kwargs)
|
||||||
|
|
||||||
|
if resp.status != 200:
|
||||||
|
raise HttpFileDownloadedError(f'Failed to download {url}: Status: {resp.status}, Body: {resp.body}')
|
||||||
|
|
||||||
|
return resp.save(filepath)
|
||||||
|
|
||||||
|
|
||||||
|
def image(self, url, filepath, *args, filename=None, ext='png', dimensions=(50, 50), **kwargs):
|
||||||
|
if not Image:
|
||||||
|
izzylog.error('Pillow module is not installed')
|
||||||
|
return
|
||||||
|
|
||||||
|
resp = self.request(url, *args, **kwargs)
|
||||||
|
|
||||||
|
if resp.status != 200:
|
||||||
|
izzylog.error(f'Failed to download {url}:', resp.status, resp.body)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
filename = Path(url).stem()
|
||||||
|
|
||||||
|
path = Path(filepath)
|
||||||
|
|
||||||
|
if not path.exists:
|
||||||
|
izzylog.error('Path does not exist:', path)
|
||||||
|
return False
|
||||||
|
|
||||||
|
byte = BytesIO()
|
||||||
|
image = Image.open(BytesIO(resp.body))
|
||||||
|
image.thumbnail(dimensions)
|
||||||
|
image.save(byte, format=ext.upper())
|
||||||
|
|
||||||
|
with path.join(filename).open('wb') as fd:
|
||||||
|
fd.write(byte.getvalue())
|
||||||
|
|
||||||
|
|
||||||
|
def json(self, *args, headers={}, activity=True, **kwargs):
|
||||||
|
json_type = 'activity+json' if activity else 'json'
|
||||||
|
headers.update({
|
||||||
|
'accept': f'application/{json_type}'
|
||||||
|
})
|
||||||
|
|
||||||
|
return self.request(*args, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_client(client=None):
|
||||||
|
global Client
|
||||||
|
Client = client or HttpClient()
|
||||||
|
set_client(Client)
|
111
http_urllib_client/izzylib/http_urllib_client/request.py
Normal file
111
http_urllib_client/izzylib/http_urllib_client/request.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from Crypto.Hash import SHA256
|
||||||
|
from izzylib import DotDict, LowerDotDict, Url, boolean
|
||||||
|
from base64 import b64decode, b64encode
|
||||||
|
from datetime import datetime
|
||||||
|
from izzylib import izzylog as logging
|
||||||
|
|
||||||
|
from .signatures import sign_pkcs_headers
|
||||||
|
|
||||||
|
|
||||||
|
methods = ['delete', 'get', 'head', 'options', 'patch', 'post', 'put']
|
||||||
|
|
||||||
|
|
||||||
|
class HttpUrllibRequest:
|
||||||
|
def __init__(self, url, **kwargs):
|
||||||
|
self._body = b''
|
||||||
|
|
||||||
|
method = kwargs.get('method', 'get').lower()
|
||||||
|
|
||||||
|
if method not in methods:
|
||||||
|
raise ValueError(f'Invalid method: {method}')
|
||||||
|
|
||||||
|
self.url = Url(url)
|
||||||
|
self.body = kwargs.get('body')
|
||||||
|
self.method = method
|
||||||
|
self.headers = LowerDotDict(kwargs.get('headers', {}))
|
||||||
|
self.redirect = boolean(kwargs.get('redirect', True))
|
||||||
|
self.retries = int(kwargs.get('retries', 10))
|
||||||
|
self.timeout = int(kwargs.get('timeout', 5))
|
||||||
|
|
||||||
|
privkey = kwargs.get('privkey')
|
||||||
|
keyid = kwargs.get('keyid')
|
||||||
|
|
||||||
|
if privkey and keyid:
|
||||||
|
self.sign(privkey, keyid)
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _args(self):
|
||||||
|
return [self.method.upper(), self.url]
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _kwargs(self):
|
||||||
|
return {
|
||||||
|
'body': self.body,
|
||||||
|
'headers': self.headers,
|
||||||
|
'redirect': self.redirect,
|
||||||
|
'retries': self.retries,
|
||||||
|
'timeout': self.timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def body(self):
|
||||||
|
return self._body
|
||||||
|
|
||||||
|
|
||||||
|
@body.setter
|
||||||
|
def body(self, data):
|
||||||
|
if isinstance(data, dict):
|
||||||
|
data = DotDict(data).to_json()
|
||||||
|
|
||||||
|
elif any(map(isinstance, [data], [list, tuple])):
|
||||||
|
data = json.dumps(data)
|
||||||
|
|
||||||
|
if data == None:
|
||||||
|
data = b''
|
||||||
|
|
||||||
|
elif not isinstance(data, bytes):
|
||||||
|
data = bytes(data, 'utf-8')
|
||||||
|
|
||||||
|
self._body = data
|
||||||
|
|
||||||
|
|
||||||
|
def set_header(self, key, value):
|
||||||
|
self.headers[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def unset_header(self, key):
|
||||||
|
self.headers.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
def sign(self, privkey, keyid):
|
||||||
|
self.unset_header('signature')
|
||||||
|
|
||||||
|
self.set_header('(request-target)', f'{self.method.lower()} {self.url.path}')
|
||||||
|
self.set_header('host', self.url.host)
|
||||||
|
self.set_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
|
||||||
|
|
||||||
|
if self.body:
|
||||||
|
body_hash = b64encode(SHA256.new(self.body).digest()).decode("UTF-8")
|
||||||
|
|
||||||
|
self.set_header('digest', f'SHA-256={body_hash}')
|
||||||
|
self.set_header('content-length', str(len(self.body)))
|
||||||
|
|
||||||
|
sig = {
|
||||||
|
'keyId': keyid,
|
||||||
|
'algorithm': 'rsa-sha256',
|
||||||
|
'headers': ' '.join([k.lower() for k in self.headers.keys()]),
|
||||||
|
'signature': b64encode(sign_pkcs_headers(privkey, self.headers)).decode('UTF-8')
|
||||||
|
}
|
||||||
|
|
||||||
|
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
|
||||||
|
sig_string = ','.join(sig_items)
|
||||||
|
|
||||||
|
self.set_header('signature', sig_string)
|
||||||
|
|
||||||
|
self.unset_header('(request-target)')
|
||||||
|
self.unset_header('host')
|
97
http_urllib_client/izzylib/http_urllib_client/response.py
Normal file
97
http_urllib_client/izzylib/http_urllib_client/response.py
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
from izzylib import DefaultDotDict, DotDict, Path, Url
|
||||||
|
|
||||||
|
|
||||||
|
class HttpUrllibResponse:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
|
||||||
|
self._dict = None
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.dict[key]
|
||||||
|
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self.dict[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
for line in self.headers.get('content-type', '').split(';'):
|
||||||
|
try:
|
||||||
|
k,v = line.split('=')
|
||||||
|
|
||||||
|
if k.lower == 'charset':
|
||||||
|
return v.lower()
|
||||||
|
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return 'utf-8'
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def headers(self):
|
||||||
|
return self.response.headers
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def status(self):
|
||||||
|
return self.response.status
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def url(self):
|
||||||
|
return Url(self.response.geturl())
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def body(self):
|
||||||
|
data = self.response.read(cache_content=True)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
data = self.response.data
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self):
|
||||||
|
return self.body.decode(self.encoding)
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dict(self):
|
||||||
|
if not self._dict:
|
||||||
|
self._dict = DotDict(self.text)
|
||||||
|
|
||||||
|
return self._dict
|
||||||
|
|
||||||
|
|
||||||
|
def json_pretty(self, indent=4):
|
||||||
|
return self.dict.to_json(indent)
|
||||||
|
|
||||||
|
|
||||||
|
def chunks(self, size=1024):
|
||||||
|
return self.response.stream(amt=size)
|
||||||
|
|
||||||
|
|
||||||
|
def save(self, path, overwrite=True, create_parents=True):
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.parent.exists:
|
||||||
|
if not create_parents:
|
||||||
|
raise ValueError(f'Path does not exist: {path.parent}')
|
||||||
|
|
||||||
|
path.parent.mkdir()
|
||||||
|
|
||||||
|
if overwrite and path.exists:
|
||||||
|
path.delete()
|
||||||
|
|
||||||
|
with path.open('wb') as fd:
|
||||||
|
for chunk in self.chunks():
|
||||||
|
fd.write(chunk)
|
|
@ -5,21 +5,21 @@ from setuptools import setup, find_namespace_packages
|
||||||
requires = [
|
requires = [
|
||||||
'pillow==8.2.0',
|
'pillow==8.2.0',
|
||||||
'pycryptodome==3.10.1',
|
'pycryptodome==3.10.1',
|
||||||
'requests==2.25.1',
|
'urllib==1.26.5',
|
||||||
'tldextract==3.1.0'
|
'tldextract==3.1.0'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="IzzyLib Requests Client",
|
name="IzzyLib Urllib3 Client",
|
||||||
version='0.6.0',
|
version='0.6.0',
|
||||||
packages=find_namespace_packages(include=['izzylib.http_requests_client']),
|
packages=find_namespace_packages(include=['izzylib.http_urllib_client']),
|
||||||
python_requires='>=3.7.0',
|
python_requires='>=3.7.0',
|
||||||
install_requires=requires,
|
install_requires=requires,
|
||||||
include_package_data=False,
|
include_package_data=False,
|
||||||
author='Zoey Mae',
|
author='Zoey Mae',
|
||||||
author_email='admin@barkshark.xyz',
|
author_email='admin@barkshark.xyz',
|
||||||
description='A Requests client with support for http header signing and verifying',
|
description='A Urllib3 client with support for http header signing and verifying',
|
||||||
keywords='web http client',
|
keywords='web http client',
|
||||||
url='https://git.barkshark.xyz/izaliamae/izzylib',
|
url='https://git.barkshark.xyz/izaliamae/izzylib',
|
||||||
project_urls={
|
project_urls={
|
|
@ -1,33 +0,0 @@
|
||||||
from .signature import (
|
|
||||||
verify_request,
|
|
||||||
verify_headers,
|
|
||||||
parse_signature,
|
|
||||||
fetch_actor,
|
|
||||||
fetch_instance,
|
|
||||||
fetch_nodeinfo,
|
|
||||||
fetch_webfinger_account,
|
|
||||||
generate_rsa_key
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
from .client import (
|
|
||||||
HttpRequestsClient,
|
|
||||||
HttpRequestsRequest,
|
|
||||||
HttpRequestsResponse,
|
|
||||||
set_requests_client
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'HttpRequestsClient',
|
|
||||||
'HttpRequestsRequest',
|
|
||||||
'HttpRequestsResponse',
|
|
||||||
'fetch_actor',
|
|
||||||
'fetch_instance',
|
|
||||||
'fetch_nodeinfo',
|
|
||||||
'fetch_webfinger_account',
|
|
||||||
'generate_rsa_key',
|
|
||||||
'parse_signature',
|
|
||||||
'set_requests_client',
|
|
||||||
'verify_headers',
|
|
||||||
'verify_request',
|
|
||||||
]
|
|
|
@ -1,227 +0,0 @@
|
||||||
import json, requests, sys
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from base64 import b64encode
|
|
||||||
from datetime import datetime
|
|
||||||
from functools import cached_property
|
|
||||||
from io import BytesIO
|
|
||||||
from izzylib import DefaultDotDict, DotDict, Path, izzylog as logging, __version__
|
|
||||||
from izzylib.exceptions import HttpFileDownloadedError
|
|
||||||
from ssl import SSLCertVerificationError
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from .signature import sign_request, set_client
|
|
||||||
|
|
||||||
|
|
||||||
Client = None
|
|
||||||
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
|
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestsClient(object):
|
|
||||||
def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None):
|
|
||||||
proxy_ports = {
|
|
||||||
'http': 80,
|
|
||||||
'https': 443
|
|
||||||
}
|
|
||||||
|
|
||||||
if proxy_type not in ['http', 'https']:
|
|
||||||
raise ValueError(f'Not a valid proxy type: {proxy_type}')
|
|
||||||
|
|
||||||
self.headers=headers
|
|
||||||
self.agent = f'{useragent} ({appagent})' if appagent else useragent
|
|
||||||
self.proxy = DotDict({
|
|
||||||
'enabled': True if proxy_host else False,
|
|
||||||
'ptype': proxy_type,
|
|
||||||
'host': proxy_host,
|
|
||||||
'port': proxy_ports[proxy_type] if not proxy_port else proxy_port
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def set_global(self):
|
|
||||||
set_requests_client(self)
|
|
||||||
|
|
||||||
|
|
||||||
def build_request(self, *args, method='get', privkey=None, keyid=None, **kwargs):
|
|
||||||
if method.lower() not in methods:
|
|
||||||
raise ValueError(f'Invalid method: {method}')
|
|
||||||
|
|
||||||
request = HttpRequestsRequest(self, *args, method=method.lower(), **kwargs)
|
|
||||||
|
|
||||||
if privkey and keyid:
|
|
||||||
request.sign(privkey, keyid)
|
|
||||||
|
|
||||||
return request
|
|
||||||
|
|
||||||
|
|
||||||
def request(self, *args, **kwargs):
|
|
||||||
request = self.build_request(*args, **kwargs)
|
|
||||||
return HttpRequestsResponse(request.send())
|
|
||||||
|
|
||||||
|
|
||||||
def signed_request(self, privkey, keyid, *args, **kwargs):
|
|
||||||
return self.request(*args, privkey=privkey, keyid=keyid, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def download(self, url, filepath, *args, filename=None, **kwargs):
|
|
||||||
resp = self.request(url, *args, **kwargs)
|
|
||||||
|
|
||||||
if resp.status != 200:
|
|
||||||
raise HttpFileDownloadedError(f'Failed to download {url}: Status: {resp.status}, Body: {resp.body}')
|
|
||||||
|
|
||||||
return resp.save(filepath)
|
|
||||||
|
|
||||||
|
|
||||||
def image(self, url, filepath, *args, filename=None, ext='png', dimensions=(50, 50), **kwargs):
|
|
||||||
if not Image:
|
|
||||||
izzylog.error('Pillow module is not installed')
|
|
||||||
return
|
|
||||||
|
|
||||||
resp = self.request(url, *args, **kwargs)
|
|
||||||
|
|
||||||
if resp.status != 200:
|
|
||||||
izzylog.error(f'Failed to download {url}:', resp.status, resp.body)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not filename:
|
|
||||||
filename = Path(url).stem()
|
|
||||||
|
|
||||||
path = Path(filepath)
|
|
||||||
|
|
||||||
if not path.exists:
|
|
||||||
izzylog.error('Path does not exist:', path)
|
|
||||||
return False
|
|
||||||
|
|
||||||
byte = BytesIO()
|
|
||||||
image = Image.open(BytesIO(resp.body))
|
|
||||||
image.thumbnail(dimensions)
|
|
||||||
image.save(byte, format=ext.upper())
|
|
||||||
|
|
||||||
with path.join(filename).open('wb') as fd:
|
|
||||||
fd.write(byte.getvalue())
|
|
||||||
|
|
||||||
|
|
||||||
def json(self, *args, headers={}, activity=True, **kwargs):
|
|
||||||
json_type = 'activity+json' if activity else 'json'
|
|
||||||
headers.update({
|
|
||||||
'accept': f'application/{json_type}'
|
|
||||||
})
|
|
||||||
|
|
||||||
return self.request(*args, headers=headers, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestsRequest(object):
|
|
||||||
def __init__(self, client, url, data=b'', headers={}, query={}, method='get'):
|
|
||||||
parsed = urlparse(url)
|
|
||||||
self.args = [url]
|
|
||||||
self.kwargs = DotDict({'params': query})
|
|
||||||
self.method = method.lower()
|
|
||||||
self.client = client
|
|
||||||
self.path = parsed.path
|
|
||||||
self.host = parsed.netloc
|
|
||||||
self.body = data
|
|
||||||
|
|
||||||
new_headers = client.headers.copy()
|
|
||||||
new_headers.update(headers)
|
|
||||||
|
|
||||||
parsed_headers = {k.lower(): v for k,v in new_headers.items()}
|
|
||||||
|
|
||||||
if not parsed_headers.get('user-agent'):
|
|
||||||
parsed_headers['user-agent'] = client.agent
|
|
||||||
|
|
||||||
self.kwargs['headers'] = DotDict(new_headers)
|
|
||||||
|
|
||||||
if client.proxy.enabled:
|
|
||||||
self.kwargs['proxies'] = DotDict({self.proxy.ptype: f'{self.proxy.ptype}://{self.proxy.host}:{self.proxy.port}'})
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def body(self):
|
|
||||||
return self.kwargs.data
|
|
||||||
|
|
||||||
|
|
||||||
@body.setter
|
|
||||||
def body(self, data):
|
|
||||||
self.kwargs.data = data.encode('utf-8') if isinstance(data, str) else data
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def headers(self):
|
|
||||||
return self.kwargs.headers
|
|
||||||
|
|
||||||
|
|
||||||
def add_header(self, key, value):
|
|
||||||
self.kwargs.headers[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
def remove_header(self, key):
|
|
||||||
self.kwargs.headers.pop(key, None)
|
|
||||||
|
|
||||||
|
|
||||||
def send(self):
|
|
||||||
func = getattr(requests, self.method)
|
|
||||||
return func(*self.args, **self.kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def sign(self, privkey, keyid):
|
|
||||||
sign_request(self, privkey, keyid)
|
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestsResponse(object):
|
|
||||||
def __init__(self, response):
|
|
||||||
self.response = response
|
|
||||||
self.data = b''
|
|
||||||
self.headers = DefaultDotDict({k.lower(): v.lower() for k,v in response.headers.items()})
|
|
||||||
self.status = response.status_code
|
|
||||||
self.url = response.url
|
|
||||||
|
|
||||||
|
|
||||||
def chunks(self, size=256):
|
|
||||||
return self.response.iter_content(chunk_size=256)
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def body(self):
|
|
||||||
for chunk in self.chunks():
|
|
||||||
self.data += chunk
|
|
||||||
|
|
||||||
return self.data
|
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def text(self):
|
|
||||||
return self.data.decode(self.response.encoding)
|
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def json(self):
|
|
||||||
try:
|
|
||||||
return DotDict(self.text)
|
|
||||||
|
|
||||||
except:
|
|
||||||
return DotDict(self.body)
|
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def json_pretty(self, indent=4):
|
|
||||||
return json.dumps(self.json, indent=indent)
|
|
||||||
|
|
||||||
|
|
||||||
def save(self, path, overwrite=True):
|
|
||||||
path = Path(path)
|
|
||||||
|
|
||||||
if not path.parent.exists:
|
|
||||||
raise ValueError(f'Path does not exist: {path.parent}')
|
|
||||||
|
|
||||||
if overwrite and path.exists:
|
|
||||||
path.delete()
|
|
||||||
|
|
||||||
with path.open('wb') as fd:
|
|
||||||
for chunk in self.chunks():
|
|
||||||
fd.write(chunk)
|
|
||||||
|
|
||||||
|
|
||||||
def set_requests_client(client=None):
|
|
||||||
global Client
|
|
||||||
Client = client or RequestsClient()
|
|
||||||
set_client(Client)
|
|
|
@ -1,6 +1,13 @@
|
||||||
# old sql classes
|
## Normal SQL client
|
||||||
from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables, OperationalError, ProgrammingError
|
from .database import Database, OperationalError, ProgrammingError
|
||||||
from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession
|
from .session import Session
|
||||||
|
from .column import Column
|
||||||
|
|
||||||
#from .database import Database, Session
|
## Sqlite server
|
||||||
#from .queries import Column, Insert, Select, Table, Tables, Update
|
#from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession
|
||||||
|
|
||||||
|
|
||||||
|
## Compat
|
||||||
|
SqlDatabase = Database
|
||||||
|
SqlSession = Session
|
||||||
|
SqlColumn = Column
|
||||||
|
|
54
sql/izzylib/sql/column.py
Normal file
54
sql/izzylib/sql/column.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
from sqlalchemy import ForeignKey
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column as sqlalchemy_column,
|
||||||
|
types as Types
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
SqlTypes = {t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')}
|
||||||
|
|
||||||
|
|
||||||
|
class Column(sqlalchemy_column):
|
||||||
|
def __init__(self, name, stype=None, fkey=None, **kwargs):
|
||||||
|
if not stype and not kwargs:
|
||||||
|
if name == 'id':
|
||||||
|
stype = 'integer'
|
||||||
|
kwargs['primary_key'] = True
|
||||||
|
kwargs['autoincrement'] = True
|
||||||
|
|
||||||
|
elif name == 'timestamp':
|
||||||
|
stype = 'datetime'
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError('Missing column type and options')
|
||||||
|
|
||||||
|
stype = (stype.lower() if type(stype) == str else stype) or 'string'
|
||||||
|
|
||||||
|
if type(stype) == str:
|
||||||
|
try:
|
||||||
|
stype = SqlTypes[stype.lower()]
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(f'Invalid SQL data type: {stype}')
|
||||||
|
|
||||||
|
options = [name, stype]
|
||||||
|
|
||||||
|
if fkey:
|
||||||
|
options.append(ForeignKey(fkey))
|
||||||
|
|
||||||
|
super().__init__(*options, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def compile(self):
|
||||||
|
sql = f'{self.name} {self.type}'
|
||||||
|
|
||||||
|
if not self.nullable:
|
||||||
|
sql += ' NOT NULL'
|
||||||
|
|
||||||
|
if self.primary_key:
|
||||||
|
sql += ' PRIMARY KEY'
|
||||||
|
|
||||||
|
if self.unique:
|
||||||
|
sql += ' UNIQUE'
|
||||||
|
|
||||||
|
return sql
|
|
@ -1,100 +0,0 @@
|
||||||
import importlib, sqlite3, ssl
|
|
||||||
|
|
||||||
from getpass import getuser
|
|
||||||
from izzylib import DotDict, Path, izzylog
|
|
||||||
|
|
||||||
|
|
||||||
defaults = {
|
|
||||||
'name': (None, str),
|
|
||||||
'host': (None, str),
|
|
||||||
'port': (None, int),
|
|
||||||
'username': (getuser(), str),
|
|
||||||
'password': (None, str),
|
|
||||||
'ssl': ('allow', str),
|
|
||||||
'ssl_context': (ssl.create_default_context(), ssl.SSLContext),
|
|
||||||
'ssl_key': (None, Path),
|
|
||||||
'ssl_cert': (None, Path),
|
|
||||||
'max_connections': (25, int),
|
|
||||||
'type': ('sqlite', str),
|
|
||||||
'module': (sqlite3, None),
|
|
||||||
'mod_name': ('sqlite3', str),
|
|
||||||
'timeout': (5, int),
|
|
||||||
'args': ([], list),
|
|
||||||
'kwargs': ({}, dict)
|
|
||||||
}
|
|
||||||
|
|
||||||
modtypes = {
|
|
||||||
'sqlite': ['sqlite3'],
|
|
||||||
'postgresql': ['pg8000', 'psycopg2', 'psycopg3', 'pgdb'],
|
|
||||||
'mysql': ['mysqldb', 'trio_mysql'],
|
|
||||||
'mssql': ['pymssql', 'adodbapi']
|
|
||||||
}
|
|
||||||
|
|
||||||
sslmodes = ['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']
|
|
||||||
|
|
||||||
|
|
||||||
class Config(DotDict):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__({k: v[0] for k,v in defaults.items()})
|
|
||||||
|
|
||||||
module = kwargs.pop('module', None)
|
|
||||||
|
|
||||||
if module:
|
|
||||||
self.parse_module(module)
|
|
||||||
|
|
||||||
self.update(kwargs)
|
|
||||||
|
|
||||||
if self.ssl != 'disable' and (self.ssl_key or self.ssl_cert):
|
|
||||||
self.ssl_context.load_cert_chain(self.ssl_cert, self.ssl_key)
|
|
||||||
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
if key not in defaults:
|
|
||||||
raise KeyError(f'Invalid config option: {key}')
|
|
||||||
|
|
||||||
valtype = defaults[key][1]
|
|
||||||
|
|
||||||
if valtype and value and not isinstance(value, valtype):
|
|
||||||
raise TypeError(f'{key} should be a {valtype}, not a {value.__class__.__name__}')
|
|
||||||
|
|
||||||
if key == 'ssl' and value == True:
|
|
||||||
value = ssl.create_default_context()
|
|
||||||
|
|
||||||
super().__setitem__(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_module(self, name):
|
|
||||||
module = None
|
|
||||||
module_type = None
|
|
||||||
module_name = None
|
|
||||||
|
|
||||||
if name == 'sqlite3':
|
|
||||||
name = 'sqlite'
|
|
||||||
|
|
||||||
for mtype, modules in modtypes.items():
|
|
||||||
if name == mtype:
|
|
||||||
module_type = name
|
|
||||||
|
|
||||||
for mod in modules:
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(mod)
|
|
||||||
module_name = mod
|
|
||||||
break
|
|
||||||
except ImportError:
|
|
||||||
izzylog.verbose(f'Database module not installed:', mod)
|
|
||||||
|
|
||||||
elif name in modules:
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(name)
|
|
||||||
module_type = mtype
|
|
||||||
module_name = name
|
|
||||||
break
|
|
||||||
except ImportError:
|
|
||||||
izzylog.error(f'Database module not installed:', name)
|
|
||||||
|
|
||||||
if None in (module, module_name, module_type):
|
|
||||||
raise ValueError(f'Failed to find module for {name}')
|
|
||||||
|
|
||||||
self.module = module
|
|
||||||
self.mod_name = module_name
|
|
||||||
self.type = module_type
|
|
|
@ -1,360 +1,186 @@
|
||||||
import sqlite3, traceback
|
import json, pkgutil, sys, threading, time
|
||||||
|
|
||||||
from functools import partial
|
from contextlib import contextmanager
|
||||||
from getpass import getuser
|
from datetime import datetime
|
||||||
from izzylib import DotDict, izzylog, boolean, random_gen
|
from izzylib import LruCache, DotDict, Path, nfs_check, izzylog
|
||||||
|
from sqlalchemy import Table, create_engine
|
||||||
|
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||||
|
from sqlalchemy.engine import URL
|
||||||
|
from sqlalchemy.schema import MetaData
|
||||||
|
|
||||||
from . import error
|
from .rows import RowClasses
|
||||||
from .config import Config
|
from .session import Session
|
||||||
from .queries import Column, Delete, Insert, Select, Table, Tables, Update
|
|
||||||
|
|
||||||
|
modules = dict(
|
||||||
|
postgresql = ['pygresql', 'pg8000', 'psycopg2', 'psycopg3']
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
def __init__(self, tables=None, **kwargs):
|
def __init__(self, dbtype='sqlite', **kwargs):
|
||||||
self.tables = tables
|
self._connect_args = [dbtype, kwargs]
|
||||||
self.cfg = Config(**kwargs)
|
self.db = None
|
||||||
self.sessions = DotDict()
|
self.cache = None
|
||||||
|
self.config = DotDict()
|
||||||
|
self.meta = MetaData()
|
||||||
|
self.classes = RowClasses(*kwargs.get('row_classes', []))
|
||||||
|
self.cache = None
|
||||||
|
|
||||||
|
self.session_class = kwargs.get('session_class', Session)
|
||||||
|
self.sessions = {}
|
||||||
|
|
||||||
|
self.open()
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_cache(self):
|
||||||
|
self.cache = DotDict({table: LruCache() for table in self.get_tables()})
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session(self):
|
def session(self):
|
||||||
return self.get_session(False)
|
return self.session_class(self)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session_trans(self):
|
def dbtype(self):
|
||||||
return self.get_session(True)
|
return self.db.url.get_backend_name()
|
||||||
|
|
||||||
|
|
||||||
def connect(self, sid, session):
|
@property
|
||||||
if len(self.sessions) >= self.cfg.max_connections:
|
def table(self):
|
||||||
raise error.MaxConnectionsError(f'Cannot start a new session with id {sid}. Reach max connection count of {self.cfg.max_connections}.')
|
return DotDict(self.meta.tables)
|
||||||
|
|
||||||
self.sessions[sid] = session
|
|
||||||
|
|
||||||
|
|
||||||
def disconnect(self, sid):
|
|
||||||
self.sessions[sid].disconnect()
|
|
||||||
del self.sessions[sid]
|
|
||||||
|
|
||||||
|
|
||||||
def disconnect_all(self):
|
|
||||||
sids = []
|
|
||||||
|
|
||||||
for sid in self.sessions.keys():
|
|
||||||
sids.append(sid)
|
|
||||||
|
|
||||||
for sid in sids:
|
|
||||||
self.disconnect(sid)
|
|
||||||
|
|
||||||
|
|
||||||
def get_session(self, trans=True):
|
|
||||||
session = Session(self, trans)
|
|
||||||
self.sessions[session.id] = session
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
def execute(self, *args):
|
|
||||||
with self.session as s:
|
|
||||||
s.execute(*args)
|
|
||||||
|
|
||||||
|
|
||||||
def load_tables(self, path):
|
|
||||||
self.tables = Tables.new_from_json_file(path)
|
|
||||||
|
|
||||||
|
|
||||||
def pre_setup(self):
|
|
||||||
if self.cfg.type != 'postgresql':
|
|
||||||
izzylog.verbose(f'Database not supported for pre_setup: {self.cfg.type}')
|
|
||||||
return
|
|
||||||
|
|
||||||
original_database = self.cfg.name
|
|
||||||
self.cfg.name = 'postgres'
|
|
||||||
|
|
||||||
with self.session as s:
|
|
||||||
s.conn.autocommit = True
|
|
||||||
s.rollback()
|
|
||||||
|
|
||||||
if original_database not in s.get_databases():
|
|
||||||
#s.execute('SET AUTOCOMMIT = OFF')
|
|
||||||
s.cursor.execute(f'CREATE DATABASE {original_database}')
|
|
||||||
|
|
||||||
s.conn.autocommit = False
|
|
||||||
|
|
||||||
self.cfg.name = original_database
|
|
||||||
|
|
||||||
|
|
||||||
def set_row_class(self, table, row_class):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Session:
|
|
||||||
def __init__(self, db, trans):
|
|
||||||
self.id = random_gen()
|
|
||||||
self.db = db
|
|
||||||
self.cfg = db.cfg
|
|
||||||
self.trans = trans
|
|
||||||
self.trans_state = False
|
|
||||||
self.conn = None
|
|
||||||
self.cursor = None
|
|
||||||
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
try:
|
|
||||||
izzylog.verbose('Deleting session:', self.id)
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
if izzylog.get_config('level') >= 20:
|
|
||||||
print('[izzylib] VERBOSE: Deleting session:', self.id)
|
|
||||||
|
|
||||||
self.db.sessions.pop(self.id, None)
|
|
||||||
|
|
||||||
if self.conn:
|
|
||||||
self.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.connect()
|
|
||||||
|
|
||||||
if self.trans:
|
|
||||||
self.begin()
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
|
||||||
if exc_traceback:
|
|
||||||
self.rollback()
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.commit()
|
|
||||||
|
|
||||||
self.disconnect()
|
|
||||||
self.db.disconnect(self.id)
|
|
||||||
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
if self.conn:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.db.connect(self.id, self)
|
|
||||||
|
|
||||||
if self.cfg.type == 'sqlite':
|
|
||||||
self.conn = self.cfg.module.connect(self.cfg.name, self.cfg.timeout, check_same_thread=True)
|
|
||||||
|
|
||||||
elif self.cfg.type == 'postgresql':
|
|
||||||
options = dict(
|
|
||||||
host = self.cfg.host or '/var/run/postgresql',
|
|
||||||
port = self.cfg.port or 5432,
|
|
||||||
database = self.cfg.name or 'postgresql',
|
|
||||||
user = self.cfg.username or getuser(),
|
|
||||||
password = self.cfg.password,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.mod_name == 'pg8000':
|
|
||||||
if options['host'] in [None, '/var/run/postgresql']:
|
|
||||||
port = options.pop('port')
|
|
||||||
options['unix_sock'] = options.pop('host') + f'/.s.PGSQL.{port}'
|
|
||||||
|
|
||||||
## SSL is a pain in the ass tbh. Gonna deal with this later
|
|
||||||
#if self.cfg.mod_name == 'pg8000':
|
|
||||||
#options['sslmode'] = self.cfg.ssl
|
|
||||||
#options['ssl_context'] = self.cfg.ssl_context
|
|
||||||
|
|
||||||
#elif self.cfg.mod_name == 'psycopg2':
|
|
||||||
#options['sslcert'] = self.cfg.ssl_cert
|
|
||||||
#options['sslkey'] = self.cfg.ssl_key
|
|
||||||
|
|
||||||
self.conn = self.cfg.module.connect(**options)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.conn.autocommit = False
|
|
||||||
except:
|
|
||||||
izzylog.verbose('Failed to turn off autocommit')
|
|
||||||
|
|
||||||
self.cursor = self.conn.cursor()
|
|
||||||
|
|
||||||
|
|
||||||
def disconnect(self):
|
|
||||||
if not self.conn:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.cursor.close()
|
|
||||||
self.conn.close()
|
|
||||||
|
|
||||||
self.cursor = None
|
|
||||||
self.conn = None
|
|
||||||
|
|
||||||
|
|
||||||
def begin(self):
|
|
||||||
if self.trans_state:
|
|
||||||
return
|
|
||||||
|
|
||||||
#self.conn.begin()
|
|
||||||
self.execute('BEGIN TRANSACTION')
|
|
||||||
self.trans_state = True
|
|
||||||
|
|
||||||
|
|
||||||
def rollback(self):
|
|
||||||
if not self.trans_state:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.conn.rollback()
|
|
||||||
#self.execute('ROLLBACK TRANSACTION')
|
|
||||||
self.trans_state = False
|
|
||||||
|
|
||||||
|
|
||||||
def commit(self):
|
|
||||||
if not self.trans_state:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.conn.commit()
|
|
||||||
#self.execute('COMMIT TRANSACTION')
|
|
||||||
self.trans_state = False
|
|
||||||
|
|
||||||
|
|
||||||
## data management functions
|
|
||||||
def execute(self, string, values=[]):
|
|
||||||
if any(map(string.lower().startswith, ['insert', 'update', 'remove', 'create', 'drop'])) and not self.trans_state:
|
|
||||||
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
|
|
||||||
|
|
||||||
self.cursor.execute(string, values)
|
|
||||||
return self.cursor
|
|
||||||
|
|
||||||
|
|
||||||
def fetch(self, table, single=True, **kwargs):
|
|
||||||
rows = []
|
|
||||||
data = Select(table, type=self.cfg.type, **kwargs).exec(self)
|
|
||||||
|
|
||||||
for line in data:
|
|
||||||
row = Row(table, self.cursor.description, line)
|
|
||||||
|
|
||||||
if single:
|
|
||||||
return row
|
|
||||||
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
return rows if not single else None
|
|
||||||
|
|
||||||
|
|
||||||
def search(self, table, **kwargs):
|
|
||||||
return self.fetch(table, single=False, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def insert(self, table, **kwargs):
|
|
||||||
if not self.trans_state:
|
|
||||||
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
|
|
||||||
|
|
||||||
Insert(table, type=self.cfg.type, **kwargs).exec(self)
|
|
||||||
return self.fetch(table, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def update(self, table, rowid, **kwargs):
|
|
||||||
if not self.trans_state:
|
|
||||||
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
|
|
||||||
|
|
||||||
Update(table, rowid, type=self.cfg.type, **kwargs).exec(self)
|
|
||||||
return self.fetch(table, id=rowid)
|
|
||||||
|
|
||||||
|
|
||||||
def delete(self, table, **kwargs):
|
|
||||||
if not self.trans_state:
|
|
||||||
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
|
|
||||||
|
|
||||||
Delete(table, type=self.cfg.type, **kwargs).exec(self)
|
|
||||||
|
|
||||||
|
|
||||||
## helper functions
|
|
||||||
def get_columns(self, table):
|
|
||||||
if table not in self.get_tables():
|
|
||||||
raise KeyError(f'Not an existing table: {table}')
|
|
||||||
|
|
||||||
if self.cfg.type == 'sqlite':
|
|
||||||
rows = self.execute(f'PRAGMA table_info({table})')
|
|
||||||
return [row[1] for row in rows]
|
|
||||||
|
|
||||||
elif self.cfg.type == 'postgresql':
|
|
||||||
rows = self.execute(f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table}'")
|
|
||||||
return [row[0] for row in rows]
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
|
|
||||||
|
|
||||||
|
|
||||||
def get_tables(self):
|
def get_tables(self):
|
||||||
if self.cfg.type == 'sqlite':
|
return list(self.table.keys())
|
||||||
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
|
|
||||||
|
|
||||||
elif self.cfg.type == 'postgresql':
|
|
||||||
rows = self.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name")
|
def get_columns(self, table):
|
||||||
|
return list(col.name for col in self.table[table].columns)
|
||||||
|
|
||||||
|
|
||||||
|
def new_session(self, trans=True):
|
||||||
|
return self.session_class(self, trans=trans)
|
||||||
|
|
||||||
|
|
||||||
|
## Leaving link to example code for read-only sqlite for later use
|
||||||
|
## https://github.com/pudo/dataset/issues/136#issuecomment-128693122
|
||||||
|
def open(self):
|
||||||
|
dbtype, kwargs = self._connect_args
|
||||||
|
engine_kwargs = {
|
||||||
|
'future': True,
|
||||||
|
#'maxconnections': 25
|
||||||
|
}
|
||||||
|
|
||||||
|
if not kwargs.get('name'):
|
||||||
|
raise KeyError('Database "name" is not set')
|
||||||
|
|
||||||
|
if dbtype == 'sqlite':
|
||||||
|
database = kwargs['name']
|
||||||
|
|
||||||
|
if nfs_check(database):
|
||||||
|
izzylog.warning('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
|
||||||
|
|
||||||
|
engine_kwargs['connect_args'] = {'check_same_thread': False}
|
||||||
|
|
||||||
|
elif dbtype == 'postgresql':
|
||||||
|
ssl_context = kwargs.get('ssl')
|
||||||
|
|
||||||
|
if ssl_context:
|
||||||
|
engine_kwargs['ssl_context'] = ssl_context
|
||||||
|
|
||||||
|
if not kwargs.get('host'):
|
||||||
|
kwargs['unix_socket'] = '/var/run/postgresql'
|
||||||
|
|
||||||
|
if kwargs.get('host') and Path(kwargs['host']).exists():
|
||||||
|
kwargs['unix_socket'] = kwargs.pop('host')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
|
raise TypeError(f'Unsupported database type: {dbtype}')
|
||||||
|
|
||||||
return [row[0] for row in rows]
|
self.config.update(kwargs)
|
||||||
|
|
||||||
|
if dbtype == 'sqlite':
|
||||||
def get_databases(self):
|
url = URL.create(
|
||||||
if self.cfg.type == 'sqlite':
|
drivername='sqlite',
|
||||||
izzylog.verbose('This function is useless with sqlite')
|
database=kwargs.pop('name')
|
||||||
return
|
)
|
||||||
|
|
||||||
elif self.cfg.type == 'postgresql':
|
|
||||||
databases = [row[0] for row in self.execute('SELECT datname FROM pg_database')]
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
|
try:
|
||||||
|
for module in modules[dbtype]:
|
||||||
|
if pkgutil.get_loader(module):
|
||||||
|
dbtype = f'{dbtype}+{module}'
|
||||||
|
|
||||||
return databases
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
url = URL.create(
|
||||||
|
drivername = dbtype,
|
||||||
|
username = kwargs.pop('user', None),
|
||||||
|
password = kwargs.pop('password', None),
|
||||||
|
host = kwargs.pop('host', None),
|
||||||
|
port = kwargs.pop('port', None),
|
||||||
|
database = kwargs.pop('name'),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db = create_engine(url, **engine_kwargs)
|
||||||
|
self.meta = MetaData()
|
||||||
|
self.meta.reflect(bind=self.db, resolve_fks=True, views=True)
|
||||||
|
self._setup_cache()
|
||||||
|
|
||||||
|
|
||||||
def cursor_description(self):
|
def close(self):
|
||||||
return [row[0] for row in self.cursor.description]
|
for sid in list(self.sessions):
|
||||||
|
self.sessions[sid].commit()
|
||||||
|
self.sessions[sid].close()
|
||||||
|
|
||||||
|
self.config = DotDict()
|
||||||
|
self.cache = DotDict()
|
||||||
|
self.meta = None
|
||||||
|
self.db = None
|
||||||
|
|
||||||
|
|
||||||
def setup_database(self):
|
def load_tables(self, **tables):
|
||||||
if not self.db.tables:
|
self.meta = MetaData()
|
||||||
raise ValueError('Tables have not been specified.')
|
|
||||||
|
|
||||||
current_tables = self.get_tables()
|
for name, columns in tables.items():
|
||||||
|
Table(name, self.meta, *columns)
|
||||||
|
|
||||||
for name, table in self.db.tables.items():
|
self._setup_cache()
|
||||||
if name in current_tables:
|
|
||||||
izzylog.verbose(f'Skipping table creation since it already exists: {name}')
|
|
||||||
continue
|
|
||||||
|
|
||||||
izzylog.verbose(f'Creating table: {name}')
|
|
||||||
self.execute(table.build(self.cfg.type))
|
|
||||||
|
|
||||||
|
|
||||||
class Row(DotDict):
|
def create_database(self, tables={}):
|
||||||
def __init__(self, table, keys, values):
|
if tables:
|
||||||
self._db = None
|
self.load_tables(**tables)
|
||||||
self._table = table
|
|
||||||
|
|
||||||
super().__init__()
|
if self.db.url.get_backend_name() == 'postgresql':
|
||||||
|
predb = create_engine(self.db.engine_string.replace(self.config.name, 'postgres', -1), future=True)
|
||||||
|
conn = predb.connect()
|
||||||
|
|
||||||
for idx, key in enumerate([key[0] for key in keys]):
|
try:
|
||||||
self[key] = values[idx]
|
conn.execute(text(f'CREATE DATABASE {database}'))
|
||||||
|
|
||||||
|
except ProgrammingError:
|
||||||
|
'The database already exists, so just move along'
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
conn.close()
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
self.meta.create_all(bind=self.db)
|
||||||
|
|
||||||
|
|
||||||
def update(self, data):
|
def drop_tables(self, *tables):
|
||||||
for k, v in data.items():
|
if not tables:
|
||||||
if k not in self:
|
raise ValueError('No tables specified')
|
||||||
raise KeyError(f'Not a column for {self._table}')
|
|
||||||
|
|
||||||
self[k] = v
|
self.meta.drop_all(bind=self.db, tables=tables)
|
||||||
|
|
||||||
|
|
||||||
def delete(self):
|
def execute(self, string, **kwargs):
|
||||||
with self._db.session as s:
|
with self.session as s:
|
||||||
s.delete(self._table, id=self.id)
|
s.execute(string, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def update(self, **kwargs):
|
|
||||||
self.update(kwargs)
|
|
||||||
|
|
||||||
with self._db.session as s:
|
|
||||||
s.update(self._table, id=self.id, **kwargs)
|
|
||||||
|
|
|
@ -1,10 +0,0 @@
|
||||||
class MaxConnectionsError(Exception):
|
|
||||||
'raise when the max amount of connections has been reached'
|
|
||||||
|
|
||||||
|
|
||||||
class NoTransactionError(Exception):
|
|
||||||
'raise when a write command is executed outside a transaction'
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseNotSupportedError(Exception):
|
|
||||||
'raise when the action being performed is not supported by the database in use'
|
|
|
@ -1,508 +0,0 @@
|
||||||
import json, sys, threading, time
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from datetime import datetime
|
|
||||||
from sqlalchemy import create_engine, ForeignKey, MetaData, Table
|
|
||||||
from sqlalchemy import Column as sqlalchemy_column, types as Types
|
|
||||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
|
||||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
|
||||||
|
|
||||||
from izzylib import (
|
|
||||||
LruCache,
|
|
||||||
DotDict,
|
|
||||||
Path,
|
|
||||||
random_gen,
|
|
||||||
nfs_check,
|
|
||||||
izzylog
|
|
||||||
)
|
|
||||||
|
|
||||||
SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')})
|
|
||||||
|
|
||||||
|
|
||||||
class SqlDatabase:
|
|
||||||
def __init__(self, dbtype='sqlite', tables={}, **kwargs):
|
|
||||||
self.db = self.__create_engine(dbtype, kwargs)
|
|
||||||
self.table = None
|
|
||||||
self.table_names = None
|
|
||||||
self.classes = kwargs.get('row_classes', CustomRows())
|
|
||||||
self.cache = None
|
|
||||||
|
|
||||||
self.session_class = kwargs.get('session_class', SqlSession)
|
|
||||||
self.sessions = {}
|
|
||||||
|
|
||||||
self.setup_tables(tables)
|
|
||||||
self.setup_cache()
|
|
||||||
|
|
||||||
|
|
||||||
## Leaving link to example code for read-only sqlite for later use
|
|
||||||
## https://github.com/pudo/dataset/issues/136#issuecomment-128693122
|
|
||||||
def __create_engine(self, dbtype, kwargs):
|
|
||||||
engine_args = []
|
|
||||||
engine_kwargs = {}
|
|
||||||
|
|
||||||
if not kwargs.get('name'):
|
|
||||||
raise KeyError('Database "name" is not set')
|
|
||||||
|
|
||||||
engine_string = dbtype + '://'
|
|
||||||
|
|
||||||
if dbtype == 'sqlite':
|
|
||||||
database = kwargs['name']
|
|
||||||
|
|
||||||
if nfs_check(database):
|
|
||||||
izzylog.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
|
|
||||||
|
|
||||||
engine_string += f'/{database}'
|
|
||||||
engine_kwargs['connect_args'] = {'check_same_thread': False}
|
|
||||||
|
|
||||||
elif dbtype == 'postgresql':
|
|
||||||
ssl_context = kwargs.get('ssl')
|
|
||||||
|
|
||||||
if ssl_context:
|
|
||||||
engine_kwargs['ssl_context'] = ssl_context
|
|
||||||
|
|
||||||
else:
|
|
||||||
user = kwargs.get('user')
|
|
||||||
password = kwargs.get('pass')
|
|
||||||
host = kwargs.get('host', '/var/run/postgresql')
|
|
||||||
port = kwargs.get('port', 5432)
|
|
||||||
name = kwargs.get('name', 'postgres')
|
|
||||||
maxconn = kwargs.get('maxconnections', 25)
|
|
||||||
|
|
||||||
if user:
|
|
||||||
if password:
|
|
||||||
engine_string += f'{user}:{password}@'
|
|
||||||
else:
|
|
||||||
engine_string += user + '@'
|
|
||||||
|
|
||||||
if host == '/var/run/postgresql':
|
|
||||||
engine_string += f'/{name}:{port}/{name}'
|
|
||||||
|
|
||||||
else:
|
|
||||||
engine_string += f'{host}:{port}/{name}'
|
|
||||||
|
|
||||||
return create_engine(engine_string, *engine_args, **engine_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def session(self):
|
|
||||||
return self.session_class(self)
|
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.setup_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def setup_cache(self):
|
|
||||||
self.cache = DotDict({table: LruCache() for table in self.table_names})
|
|
||||||
|
|
||||||
|
|
||||||
def create_tables(self, *tables):
|
|
||||||
if not tables:
|
|
||||||
raise ValueError('No tables specified')
|
|
||||||
|
|
||||||
new_tables = [self.table[table] for table in tables]
|
|
||||||
self.table.meta.create_all(bind=self.db, tables=new_tables)
|
|
||||||
|
|
||||||
|
|
||||||
def create_database(self):
|
|
||||||
if self.db.url.get_backend_name() == 'postgresql':
|
|
||||||
predb = create_engine(db.engine_string.replace(config.db.name, 'postgres', -1))
|
|
||||||
conn = predb.connect()
|
|
||||||
conn.execute('commit')
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn.execute(f'CREATE DATABASE {config.db.name}')
|
|
||||||
|
|
||||||
except ProgrammingError:
|
|
||||||
'The database already exists, so just move along'
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
conn.close()
|
|
||||||
raise e from None
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
self.table.meta.create_all(self.db)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_tables(self, tables):
|
|
||||||
self.table = Tables(self, tables)
|
|
||||||
self.table_names = tables.keys()
|
|
||||||
|
|
||||||
|
|
||||||
def execute(self, string, values=[]):
|
|
||||||
with self.session as s:
|
|
||||||
s.execute(string, values)
|
|
||||||
|
|
||||||
|
|
||||||
class SqlSession(object):
|
|
||||||
def __init__(self, db):
|
|
||||||
self.closed = False
|
|
||||||
|
|
||||||
self.database = db
|
|
||||||
self.classes = db.classes
|
|
||||||
self.session = sessionmaker(bind=db.db)()
|
|
||||||
self.table = db.table
|
|
||||||
self.cache = db.cache
|
|
||||||
|
|
||||||
# session aliases
|
|
||||||
self.s = self.session
|
|
||||||
self.begin = self.s.begin
|
|
||||||
self.commit = self.s.commit
|
|
||||||
self.rollback = self.s.rollback
|
|
||||||
self.query = self.s.query
|
|
||||||
self.execute = self.s.execute
|
|
||||||
|
|
||||||
# remove in the future
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
self._setup()
|
|
||||||
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.open()
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def __exit__(self, exctype, value, tb):
|
|
||||||
if tb:
|
|
||||||
self.rollback()
|
|
||||||
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
|
|
||||||
def open(self):
|
|
||||||
self.sessionid = random_gen(10)
|
|
||||||
self.db.sessions[self.sessionid] = self
|
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.commit()
|
|
||||||
self.s.close()
|
|
||||||
self.closed = True
|
|
||||||
|
|
||||||
del self.db.sessions[self.sessionid]
|
|
||||||
|
|
||||||
self.sessionid = None
|
|
||||||
|
|
||||||
|
|
||||||
def _setup(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dirty(self):
|
|
||||||
return any([self.s.new, self.s.dirty, self.s.deleted])
|
|
||||||
|
|
||||||
|
|
||||||
def count(self, table_name, **kwargs):
|
|
||||||
return self.query(self.table[table_name]).filter_by(**kwargs).count()
|
|
||||||
|
|
||||||
|
|
||||||
def fetch(self, table_name, single=True, orderby=None, orderdir='asc', **kwargs):
|
|
||||||
table = self.table[table_name]
|
|
||||||
RowClass = self.classes.get(table_name.capitalize())
|
|
||||||
|
|
||||||
query = self.query(table).filter_by(**kwargs)
|
|
||||||
|
|
||||||
if not orderby:
|
|
||||||
rows = query.all()
|
|
||||||
|
|
||||||
else:
|
|
||||||
if orderdir == 'asc':
|
|
||||||
rows = query.order_by(getattr(table.c, orderby).asc()).all()
|
|
||||||
|
|
||||||
elif orderdir == 'desc':
|
|
||||||
rows = query.order_by(getattr(table.c, orderby).desc()).all()
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported order direction: {orderdir}')
|
|
||||||
|
|
||||||
if single:
|
|
||||||
return RowClass(table_name, rows[0], self) if len(rows) > 0 else None
|
|
||||||
|
|
||||||
return [RowClass(table_name, row, self) for row in rows]
|
|
||||||
|
|
||||||
|
|
||||||
def search(self, *args, **kwargs):
|
|
||||||
kwargs.pop('single', None)
|
|
||||||
return self.fetch(*args, single=False, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def insert(self, table_name, return_row=False, **kwargs):
|
|
||||||
row = self.fetch(table_name, **kwargs)
|
|
||||||
|
|
||||||
if row:
|
|
||||||
row.update_session(self, **kwargs)
|
|
||||||
return
|
|
||||||
|
|
||||||
table = self.table[table_name]
|
|
||||||
|
|
||||||
if getattr(table, 'timestamp', None) and not kwargs.get('timestamp'):
|
|
||||||
kwargs['timestamp'] = datetime.now()
|
|
||||||
|
|
||||||
self.execute(table.insert().values(**kwargs))
|
|
||||||
|
|
||||||
if return_row:
|
|
||||||
return self.fetch(table_name, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def update(self, table=None, rowid=None, row=None, return_row=False, **data):
|
|
||||||
if row:
|
|
||||||
if not getattr(row, '_table_name', None):
|
|
||||||
print(row)
|
|
||||||
print(dir(row))
|
|
||||||
rowid = row.id
|
|
||||||
table = row._table_name
|
|
||||||
|
|
||||||
if not rowid or not table:
|
|
||||||
raise ValueError('Missing row ID or table')
|
|
||||||
|
|
||||||
tclass = self.table[table]
|
|
||||||
self.execute(tclass.update().where(tclass.c.id == rowid).values(**data))
|
|
||||||
|
|
||||||
if return_row:
|
|
||||||
return self.fetch(table, id=rowid)
|
|
||||||
|
|
||||||
|
|
||||||
def remove(self, table=None, rowid=None, row=None):
|
|
||||||
if row:
|
|
||||||
rowid = row.id
|
|
||||||
table = row._table_name
|
|
||||||
|
|
||||||
if not rowid or not table:
|
|
||||||
raise ValueError('Missing row ID or table')
|
|
||||||
|
|
||||||
self.execute(f'DELETE FROM {table} WHERE id={rowid}')
|
|
||||||
|
|
||||||
|
|
||||||
def drop_table(self, name):
|
|
||||||
if name not in self.get_tables():
|
|
||||||
raise KeyError(f'Table does not exist: {name}')
|
|
||||||
|
|
||||||
self.execute(f'DROP TABLE {name}')
|
|
||||||
|
|
||||||
|
|
||||||
def drop_tables(self):
|
|
||||||
tables = self.get_tables()
|
|
||||||
|
|
||||||
for table in tables:
|
|
||||||
self.drop_table(table)
|
|
||||||
|
|
||||||
|
|
||||||
def get_columns(self, table):
|
|
||||||
if table not in self.get_tables():
|
|
||||||
raise KeyError(f'Not an existing table: {table}')
|
|
||||||
|
|
||||||
rows = self.execute('PRAGMA table_info(user)')
|
|
||||||
return [row[1] for row in rows]
|
|
||||||
|
|
||||||
|
|
||||||
def get_tables(self):
|
|
||||||
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
|
|
||||||
return [row[0] for row in rows]
|
|
||||||
|
|
||||||
|
|
||||||
def append_column(self, table, column):
|
|
||||||
if column.name in self.get_columns(table):
|
|
||||||
logging.warning(f'Table "{table}" already has column "{column.name}"')
|
|
||||||
return
|
|
||||||
|
|
||||||
self.execute(f'ALTER TABLE {table} ADD COLUMN {column.compile()}')
|
|
||||||
|
|
||||||
|
|
||||||
def append_column2(self, tbl, col):
|
|
||||||
table = self.table[tbl]
|
|
||||||
|
|
||||||
try:
|
|
||||||
column = getattr(table.c, col)
|
|
||||||
|
|
||||||
except AttributeError:
|
|
||||||
izzylog.error(f'Table "{tbl}" does not have column "{col}"')
|
|
||||||
return
|
|
||||||
|
|
||||||
columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]
|
|
||||||
|
|
||||||
if col in self.get_columns(tbl):
|
|
||||||
izzylog.info(f'Column "{col}" already exists')
|
|
||||||
return
|
|
||||||
|
|
||||||
sql = f'ALTER TABLE {tbl} ADD COLUMN {col} {column.type}'
|
|
||||||
|
|
||||||
if not column.nullable:
|
|
||||||
sql += ' NOT NULL'
|
|
||||||
|
|
||||||
if column.primary_key:
|
|
||||||
sql += ' PRIMARY KEY'
|
|
||||||
|
|
||||||
if column.unique:
|
|
||||||
sql += ' UNIQUE'
|
|
||||||
|
|
||||||
self.execute(sql)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_column(self, tbl, col):
|
|
||||||
table = self.table[tbl]
|
|
||||||
column = getattr(table, col, None)
|
|
||||||
columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]
|
|
||||||
|
|
||||||
if col not in columns:
|
|
||||||
izzylog.info(f'Column "{col}" already exists')
|
|
||||||
return
|
|
||||||
|
|
||||||
columns.remove(col)
|
|
||||||
coltext = ', '.join(columns)
|
|
||||||
|
|
||||||
self.execute(f'CREATE TABLE {tbl}_temp AS SELECT {coltext} FROM {tbl}')
|
|
||||||
self.execute(f'DROP TABLE {tbl}')
|
|
||||||
self.execute(f'ALTER TABLE {tbl}_temp RENAME TO {tbl}')
|
|
||||||
|
|
||||||
|
|
||||||
def clear_table(self, table):
|
|
||||||
self.execute(f'DELETE FROM {table}')
|
|
||||||
|
|
||||||
|
|
||||||
class CustomRows(object):
|
|
||||||
def get(self, name):
|
|
||||||
return getattr(self, name, self.Row)
|
|
||||||
|
|
||||||
|
|
||||||
class Row(DotDict):
|
|
||||||
#_filter_columns = lambda self, row: [attr for attr in dir(row) if not attr.startswith('_') and attr != 'metadata']
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, table, row, session):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if row:
|
|
||||||
try:
|
|
||||||
self._update(row._asdict())
|
|
||||||
except:
|
|
||||||
self._update(row)
|
|
||||||
|
|
||||||
self._db = session.db
|
|
||||||
self._table_name = table
|
|
||||||
self._columns = self.keys()
|
|
||||||
|
|
||||||
self.__run__(session)
|
|
||||||
|
|
||||||
|
|
||||||
## Subclass Row and redefine this function
|
|
||||||
def __run__(self, s):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _filter_data(self):
|
|
||||||
data = {k: v for k,v in self.items() if k in self._columns}
|
|
||||||
|
|
||||||
for k,v in self.items():
|
|
||||||
if v.__class__ == DotDict:
|
|
||||||
data[k] = v.asDict()
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def asDict(self):
|
|
||||||
return self._filter_data()
|
|
||||||
|
|
||||||
|
|
||||||
def _update(self, new_data={}, **kwargs):
|
|
||||||
kwargs.update(new_data)
|
|
||||||
|
|
||||||
for k,v in kwargs.items():
|
|
||||||
if type(v) == dict:
|
|
||||||
self[k] = DotDict(v)
|
|
||||||
|
|
||||||
self[k] = v
|
|
||||||
|
|
||||||
|
|
||||||
def delete(self, s=None):
|
|
||||||
if s:
|
|
||||||
return self.delete_session(s)
|
|
||||||
|
|
||||||
with self._db.session as s:
|
|
||||||
return self.delete_session(s)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_session(self, s):
|
|
||||||
return s.remove(table=self._table_name, row=self)
|
|
||||||
|
|
||||||
|
|
||||||
def update(self, dict_data={}, s=None, **data):
|
|
||||||
dict_data.update(data)
|
|
||||||
self._update(dict_data)
|
|
||||||
|
|
||||||
if s:
|
|
||||||
return self.update_session(s, **self._filter_data())
|
|
||||||
|
|
||||||
with self._db.session as s:
|
|
||||||
s.update(row=self, **self._filter_data())
|
|
||||||
|
|
||||||
|
|
||||||
def update_session(self, s, dict_data={}, **data):
|
|
||||||
dict_data.update(data)
|
|
||||||
self._update(dict_data)
|
|
||||||
return s.update(table=self._table_name, row=self, **dict_data)
|
|
||||||
|
|
||||||
|
|
||||||
class Tables(DotDict):
|
|
||||||
def __init__(self, db, tables={}):
|
|
||||||
'"tables" should be a dict with the table names for keys and a list of Columns for values'
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.db = db
|
|
||||||
self.meta = MetaData()
|
|
||||||
|
|
||||||
for name, table in tables.items():
|
|
||||||
self.__setup_table(name, table)
|
|
||||||
|
|
||||||
|
|
||||||
def __setup_table(self, name, table):
|
|
||||||
columns = [col if type(col) == SqlColumn else SqlColumn(*col.get('args'), **col.get('kwargs')) for col in table]
|
|
||||||
self[name] = Table(name, self.meta, *columns)
|
|
||||||
|
|
||||||
|
|
||||||
class SqlColumn(sqlalchemy_column):
|
|
||||||
def __init__(self, name, stype=None, fkey=None, **kwargs):
|
|
||||||
if not stype and not kwargs:
|
|
||||||
if name == 'id':
|
|
||||||
stype = 'integer'
|
|
||||||
kwargs['primary_key'] = True
|
|
||||||
kwargs['autoincrement'] = True
|
|
||||||
|
|
||||||
elif name == 'timestamp':
|
|
||||||
stype = 'datetime'
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError('Missing column type and options')
|
|
||||||
|
|
||||||
stype = (stype.lower() if type(stype) == str else stype) or 'string'
|
|
||||||
|
|
||||||
if type(stype) == str:
|
|
||||||
try:
|
|
||||||
stype = SqlTypes[stype.lower()]
|
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
raise KeyError(f'Invalid SQL data type: {stype}')
|
|
||||||
|
|
||||||
options = [name, stype]
|
|
||||||
|
|
||||||
if fkey:
|
|
||||||
options.append(ForeignKey(fkey))
|
|
||||||
|
|
||||||
super().__init__(*options, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def compile(self):
|
|
||||||
sql = f'{self.name} {self.type}'
|
|
||||||
|
|
||||||
if not self.nullable:
|
|
||||||
sql += ' NOT NULL'
|
|
||||||
|
|
||||||
if self.primary_key:
|
|
||||||
sql += ' PRIMARY KEY'
|
|
||||||
|
|
||||||
if self.unique:
|
|
||||||
sql += ' UNIQUE'
|
|
||||||
|
|
||||||
return sql
|
|
|
@ -1,415 +0,0 @@
|
||||||
from datetime import datetime
|
|
||||||
from functools import partial
|
|
||||||
from izzylib import DotDict, Path
|
|
||||||
|
|
||||||
from .types import BaseType, Type
|
|
||||||
|
|
||||||
|
|
||||||
placeholders = dict(
|
|
||||||
sqlite = '?',
|
|
||||||
postgresql = '%s'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
## Data queries
|
|
||||||
class Delete:
|
|
||||||
def __init__(self, table, type='sqlite', **kwargs):
|
|
||||||
self.table = table
|
|
||||||
self.placeholder = placeholders[type]
|
|
||||||
self.keys = []
|
|
||||||
self.values = []
|
|
||||||
|
|
||||||
for k,v in kwargs.items():
|
|
||||||
self.keys.append(k)
|
|
||||||
self.values.append(v)
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
self.build(embed_values=True)
|
|
||||||
|
|
||||||
|
|
||||||
def build(self, comp_type='AND', embed_values=False):
|
|
||||||
sql = 'DELETE FROM {table} WHERE {kstring}'
|
|
||||||
|
|
||||||
if not embed_values:
|
|
||||||
kstring = f' {comp_type.upper()} '.join([f'{k} = {self.placeholder}' for k in self.keys])
|
|
||||||
return sql.format(table=self.table, kstring=kstring), self.values
|
|
||||||
|
|
||||||
values = []
|
|
||||||
|
|
||||||
for idx, value in enumerate(self.values):
|
|
||||||
if type(value) == str:
|
|
||||||
values.append(f"{self.keys[idx]} = '{value}'")
|
|
||||||
|
|
||||||
else:
|
|
||||||
values.append(f"{self.keys[idx]} = {value}")
|
|
||||||
|
|
||||||
kstring = ','.join(values)
|
|
||||||
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
|
|
||||||
|
|
||||||
|
|
||||||
def exec(self, session, comp_type='AND'):
|
|
||||||
return session.execute(*self.build(comp_type))
|
|
||||||
|
|
||||||
|
|
||||||
class Insert:
|
|
||||||
def __init__(self, table, type='sqlite', **kwargs):
|
|
||||||
self.table = table
|
|
||||||
self.placeholder = placeholders[type]
|
|
||||||
self.keys = []
|
|
||||||
self.values = []
|
|
||||||
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
self.keys.append(k)
|
|
||||||
self.values.append(v)
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.build(embed_values=True)
|
|
||||||
|
|
||||||
|
|
||||||
def build(self, embed_values=False):
|
|
||||||
kstring = ','.join(self.keys)
|
|
||||||
|
|
||||||
if not embed_values:
|
|
||||||
vstring = ','.join([self.placeholder for k in self.keys])
|
|
||||||
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})', self.values
|
|
||||||
|
|
||||||
else:
|
|
||||||
vstring = ','.join(self.values)
|
|
||||||
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})'
|
|
||||||
|
|
||||||
|
|
||||||
def exec(self, session):
|
|
||||||
return session.execute(*self.build())
|
|
||||||
|
|
||||||
|
|
||||||
class Select:
|
|
||||||
def __init__(self, table, columns=[], type='sqlite', **kwargs):
|
|
||||||
self.placeholder = placeholders[type]
|
|
||||||
self.columns = columns
|
|
||||||
self.table = table
|
|
||||||
self.where = []
|
|
||||||
self.where_build = []
|
|
||||||
self._order = []
|
|
||||||
self.keys = []
|
|
||||||
self.values = []
|
|
||||||
|
|
||||||
self.equals = partial(self.__comparison, '=')
|
|
||||||
self.less = partial(self.__comparison, '<')
|
|
||||||
self.greater = partial(self.__comparison, '>')
|
|
||||||
self.like = partial(self.__comparison, 'LIKE')
|
|
||||||
|
|
||||||
for k,v in kwargs.items():
|
|
||||||
self.equals(k, v)
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.build(embed_values=True)
|
|
||||||
|
|
||||||
|
|
||||||
def __comparison(self, comp, key, value):
|
|
||||||
self.values.append(value)
|
|
||||||
self.keys.append(key)
|
|
||||||
self.where.append(f'{key} {comp.upper()} {self.placeholder}')
|
|
||||||
self.where_build.append(f"{key} {comp.upper()} '{value}'" if type(key) == str else f"{key} {comp.upper()} {value}")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def order(self, column, asc=True):
|
|
||||||
self._order = [column, 'ASC' if asc else 'DESC']
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def build(self, comp_type='AND', embed_values=False):
|
|
||||||
if not self.columns:
|
|
||||||
cols = '*'
|
|
||||||
|
|
||||||
else:
|
|
||||||
cols = ','.join('columns')
|
|
||||||
|
|
||||||
sql_query = f'SELECT {cols} FROM {self.table}'
|
|
||||||
|
|
||||||
if self.where:
|
|
||||||
where = f' {comp_type.upper()} '.join(self.where if not embed_values else self.where_build)
|
|
||||||
sql_query += f' WHERE {where}'
|
|
||||||
|
|
||||||
if self._order:
|
|
||||||
col, order = self._order
|
|
||||||
sql_query += f' ORDER BY {col} {order}'
|
|
||||||
|
|
||||||
if embed_values:
|
|
||||||
return sql_query
|
|
||||||
|
|
||||||
return sql_query, self.values
|
|
||||||
|
|
||||||
|
|
||||||
def exec(self, session, comp_type='AND'):
|
|
||||||
return session.execute(*self.build(comp_type))
|
|
||||||
|
|
||||||
|
|
||||||
class Update:
|
|
||||||
def __init__(self, table, rowid, type='sqlite', **kwargs):
|
|
||||||
self.placeholder = placeholders[type]
|
|
||||||
self.table = table
|
|
||||||
self.rowid = rowid
|
|
||||||
self.keys = []
|
|
||||||
self.values = []
|
|
||||||
|
|
||||||
for k,v in kwargs.items():
|
|
||||||
self.keys.append(k)
|
|
||||||
self.values.append(v)
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.build(embed_values=True)
|
|
||||||
|
|
||||||
|
|
||||||
def build(self, embed_values=False):
|
|
||||||
sql = 'UPDATE {table} SET {kstring} WHERE id={rowid}'
|
|
||||||
|
|
||||||
if not embed_values:
|
|
||||||
kstring = ','.join([f'{k} = {self.placeholder}' for k in self.keys])
|
|
||||||
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid), self.values
|
|
||||||
|
|
||||||
values = []
|
|
||||||
|
|
||||||
for idx, value in enumerate(self.values):
|
|
||||||
if type(value) == str:
|
|
||||||
values.append(f"{self.keys[idx]} = '{value}'")
|
|
||||||
|
|
||||||
else:
|
|
||||||
values.append(f"{self.keys[idx]} = {value}")
|
|
||||||
|
|
||||||
kstring = ','.join(values)
|
|
||||||
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
|
|
||||||
|
|
||||||
|
|
||||||
def exec(self, session):
|
|
||||||
return session.execute(*self.build())
|
|
||||||
|
|
||||||
|
|
||||||
## Database objects
|
|
||||||
class Column:
|
|
||||||
def __init__(self, name, type='STRING', unique=False, nullable=True, default=None, primary_key=False, autoincrement=False, foreign_key=None):
|
|
||||||
self.name = name
|
|
||||||
self.type = type
|
|
||||||
self.nullable = nullable
|
|
||||||
self.default = default
|
|
||||||
self.primary_key = primary_key
|
|
||||||
self.autoincrement = autoincrement
|
|
||||||
self.unique = unique
|
|
||||||
|
|
||||||
if any(map(isinstance, [foreign_key], [list, tuple, set])):
|
|
||||||
self.foreign_key = foreign_key
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.foreign_key = foreign_key.split('.') if foreign_key else None
|
|
||||||
|
|
||||||
if autoincrement:
|
|
||||||
self.primary_key = True
|
|
||||||
self.type = Type['INTEGER']
|
|
||||||
|
|
||||||
if isinstance(self.type, BaseType):
|
|
||||||
self.type = self.type.name
|
|
||||||
|
|
||||||
else:
|
|
||||||
if self.type.upper() in Type.keys():
|
|
||||||
self.type = self.type.upper()
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise TypeError(f'Invalid SQL type: {self.type}')
|
|
||||||
|
|
||||||
if foreign_key and len(self.foreign_key) != 2:
|
|
||||||
raise ValueError('Invalid foreign key. Must be in the format "table.column".')
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.build()
|
|
||||||
|
|
||||||
|
|
||||||
def build(self, dbtype='sqlite'):
|
|
||||||
if dbtype == 'postgresql':
|
|
||||||
if self.type.lower() == 'string':
|
|
||||||
self.type = 'TEXT'
|
|
||||||
|
|
||||||
elif self.type.lower() == 'datetime':
|
|
||||||
self.type = 'TIMESTAMPTZ'
|
|
||||||
|
|
||||||
if self.autoincrement:
|
|
||||||
self.type = 'SERIAL'
|
|
||||||
self.autoincrement = False
|
|
||||||
|
|
||||||
sql = f'{self.name} {self.type}'
|
|
||||||
|
|
||||||
if self.primary_key:
|
|
||||||
sql += ' PRIMARY KEY'
|
|
||||||
|
|
||||||
if self.autoincrement:
|
|
||||||
sql += ' AUTOINCREMENT'
|
|
||||||
|
|
||||||
if self.unique:
|
|
||||||
sql += ' UNIQUE'
|
|
||||||
|
|
||||||
if not self.nullable:
|
|
||||||
sql += ' NOT NULL'
|
|
||||||
|
|
||||||
if self.default:
|
|
||||||
def_type = type(self.default)
|
|
||||||
|
|
||||||
if self.default == 'CURRENT_TIMESTAMP':
|
|
||||||
if dbtype == 'sqlite':
|
|
||||||
sql += " DEFAULT (datetime('now', 'localtime'))"
|
|
||||||
|
|
||||||
elif dbtype == 'postgresql':
|
|
||||||
sql += ' DEFAULT now()'
|
|
||||||
|
|
||||||
else:
|
|
||||||
sql += f' DEFAULT {datetime.now().timestamp()}'
|
|
||||||
|
|
||||||
elif def_type == str:
|
|
||||||
sql += f" DEFAULT '{self.default}'"
|
|
||||||
|
|
||||||
elif def_type in [int, float]:
|
|
||||||
sql += f' DEFAULT {self.default}'
|
|
||||||
|
|
||||||
elif def_type == bool and dbtype == 'sqlite':
|
|
||||||
sql += f' DEFAULT {int(self.default)}'
|
|
||||||
|
|
||||||
else:
|
|
||||||
sql += f' DEFAULT {self.default}'
|
|
||||||
|
|
||||||
print(sql)
|
|
||||||
return sql
|
|
||||||
|
|
||||||
|
|
||||||
def json(self):
|
|
||||||
return DotDict({
|
|
||||||
'type': self.type,
|
|
||||||
'nullable': self.nullable,
|
|
||||||
'default': self.default,
|
|
||||||
'primary_key': self.primary_key,
|
|
||||||
'autoincrement': self.autoincrement,
|
|
||||||
'unique': self.unique,
|
|
||||||
'foreign_key': self.foreign_key
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class Table(DotDict):
|
|
||||||
def __init__(self, name, *columns):
|
|
||||||
super().__init__()
|
|
||||||
self._name = name
|
|
||||||
self._foreign_keys = {}
|
|
||||||
|
|
||||||
self.add_column(Column('id', autoincrement=True))
|
|
||||||
|
|
||||||
for column in columns:
|
|
||||||
self.add_column(column)
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.build()
|
|
||||||
|
|
||||||
|
|
||||||
# this'll be useful later
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
|
|
||||||
def add_column(self, column):
|
|
||||||
self[column.name] = column
|
|
||||||
|
|
||||||
if column.foreign_key:
|
|
||||||
self._foreign_keys[column.name] = column.foreign_key
|
|
||||||
|
|
||||||
|
|
||||||
def build(self, dbtype='sqlite'):
|
|
||||||
column_string = ',\n'.join([f'\t{col.build(dbtype)}' for col in self.values()])
|
|
||||||
|
|
||||||
if self._foreign_keys:
|
|
||||||
column_string += ',\n'
|
|
||||||
column_string += ',\n'.join([f'\tFOREIGN KEY ({column}) REFERENCES {key[0]} ({key[1]})' for column, key in self._foreign_keys.items()])
|
|
||||||
|
|
||||||
return f'''CREATE TABLE {self.name} (
|
|
||||||
{column_string}
|
|
||||||
);'''
|
|
||||||
|
|
||||||
|
|
||||||
def json(self):
|
|
||||||
data = {}
|
|
||||||
|
|
||||||
for name, column in self.items():
|
|
||||||
data[name] = column.json()
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class Tables(DotDict):
|
|
||||||
def __init__(self, *tables, data={}):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
for table in tables:
|
|
||||||
self.add_table(table)
|
|
||||||
|
|
||||||
if data:
|
|
||||||
self.from_dict(data)
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.build()
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def new_from_json_file(cls, path):
|
|
||||||
return cls(data=DotDict.new_from_json_file(path))
|
|
||||||
|
|
||||||
|
|
||||||
def add_table(self, table):
|
|
||||||
self[table.name] = table
|
|
||||||
|
|
||||||
|
|
||||||
def build(self):
|
|
||||||
return '\n\n'.join([str(table) for table in self.values()])
|
|
||||||
|
|
||||||
|
|
||||||
def load_json(self, path):
|
|
||||||
data = DotDict()
|
|
||||||
data.load_json(path)
|
|
||||||
|
|
||||||
self.from_dict(data)
|
|
||||||
|
|
||||||
|
|
||||||
def save_json(self, path, indent='\t'):
|
|
||||||
self.to_dict().save_json(path, indent=indent)
|
|
||||||
|
|
||||||
|
|
||||||
def from_dict(self, data):
|
|
||||||
for name, columns in data.items():
|
|
||||||
table = Table(name)
|
|
||||||
|
|
||||||
for col, kwargs in columns.items():
|
|
||||||
table.add_column(Column(col,
|
|
||||||
type = kwargs.get('type', 'STRING'),
|
|
||||||
nullable = kwargs.get('nullable', True),
|
|
||||||
default = kwargs.get('default'),
|
|
||||||
primary_key = kwargs.get('primary_key', False),
|
|
||||||
autoincrement = kwargs.get('autoincrement', False),
|
|
||||||
unique = kwargs.get('unique', False),
|
|
||||||
foreign_key = kwargs.get('foreign_key')
|
|
||||||
))
|
|
||||||
|
|
||||||
self.add_table(table)
|
|
||||||
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
data = DotDict()
|
|
||||||
|
|
||||||
for name, table in self.items():
|
|
||||||
data[name] = table.json()
|
|
||||||
|
|
||||||
return data
|
|
|
@ -1,19 +0,0 @@
|
||||||
from izzylib import DotDict
|
|
||||||
|
|
||||||
|
|
||||||
class DbRow(DotDict):
|
|
||||||
def __init__(self, table, keys, values):
|
|
||||||
self.table = table
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
for idx, key in enumerate(keys):
|
|
||||||
self[key] = values[idx]
|
|
||||||
|
|
||||||
|
|
||||||
def delete(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def update(self, **kwargs):
|
|
||||||
pass
|
|
90
sql/izzylib/sql/rows.py
Normal file
90
sql/izzylib/sql/rows.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
from izzylib import DotDict
|
||||||
|
|
||||||
|
|
||||||
|
class RowClasses(DotDict):
|
||||||
|
def __init__(self, *classes):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
for rowclass in classes:
|
||||||
|
self.update({rowclass.__name__.lower(): rowclass})
|
||||||
|
|
||||||
|
|
||||||
|
def get_class(self, name):
|
||||||
|
return self.get(name, Row)
|
||||||
|
|
||||||
|
|
||||||
|
class Row(DotDict):
|
||||||
|
def __init__(self, table, row, session):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if row:
|
||||||
|
try:
|
||||||
|
self._update(row._asdict())
|
||||||
|
except:
|
||||||
|
self._update(row)
|
||||||
|
|
||||||
|
self.__db = session.db
|
||||||
|
self.__table_name = table
|
||||||
|
|
||||||
|
self.__run__(session)
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db(self):
|
||||||
|
return self.__db
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def table(self):
|
||||||
|
return self.__table_name
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def columns(self):
|
||||||
|
return self.keys()
|
||||||
|
|
||||||
|
|
||||||
|
## Subclass Row and redefine this function
|
||||||
|
def __run__(self, s):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _update(self, *args, **kwargs):
|
||||||
|
super().update(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def delete(self, s=None):
|
||||||
|
izzylog.warning('deprecated function: Row.delete')
|
||||||
|
|
||||||
|
if s:
|
||||||
|
return self.delete_session(s)
|
||||||
|
|
||||||
|
with self.db.session as s:
|
||||||
|
return self.delete_session(s)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_session(self, s):
|
||||||
|
izzylog.warning('deprecated function: Row.delete_session')
|
||||||
|
|
||||||
|
return s.remove(table=self.table, row=self)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, dict_data={}, s=None, **data):
|
||||||
|
izzylog.warning('deprecated function: Row.update')
|
||||||
|
|
||||||
|
dict_data.update(data)
|
||||||
|
self._update(dict_data)
|
||||||
|
|
||||||
|
if s:
|
||||||
|
return self.update_session(s, **self)
|
||||||
|
|
||||||
|
with self.db.session as s:
|
||||||
|
s.update(row=self, **self)
|
||||||
|
|
||||||
|
|
||||||
|
def update_session(self, s, dict_data={}, **data):
|
||||||
|
izzylog.warning('deprecated function: Row.update_session')
|
||||||
|
|
||||||
|
dict_data.update(data)
|
||||||
|
self._update(dict_data)
|
||||||
|
return s.update(table=self.table, row=self, **dict_data)
|
179
sql/izzylib/sql/session.py
Normal file
179
sql/izzylib/sql/session.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
from izzylib import DotDict, random_gen, izzylog
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.orm.session import Session as sqlalchemy_session
|
||||||
|
|
||||||
|
|
||||||
|
class Session(sqlalchemy_session):
|
||||||
|
def __init__(self, db, trans=False):
|
||||||
|
super().__init__(bind=db.db, future=True)
|
||||||
|
|
||||||
|
self.closed = False
|
||||||
|
self.trans = trans
|
||||||
|
|
||||||
|
self.database = db
|
||||||
|
self.classes = db.classes
|
||||||
|
self.cache = db.cache
|
||||||
|
|
||||||
|
self.sessionid = random_gen(10)
|
||||||
|
self.database.sessions[self.sessionid] = self
|
||||||
|
|
||||||
|
# remove in the future
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
self._setup()
|
||||||
|
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if self.trans:
|
||||||
|
self.begin()
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def __exit__(self, exctype, value, tb):
|
||||||
|
if self.in_transaction():
|
||||||
|
if tb:
|
||||||
|
self.rollback()
|
||||||
|
|
||||||
|
self.commit()
|
||||||
|
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _setup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def table(self):
|
||||||
|
return self.db.table
|
||||||
|
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
if not self.in_transaction():
|
||||||
|
return
|
||||||
|
|
||||||
|
super().commit()
|
||||||
|
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
super().close()
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
del self.db.sessions[self.sessionid]
|
||||||
|
|
||||||
|
self.sessionid = None
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, expression, **kwargs):
|
||||||
|
result = self.execute(text(expression), params=kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return result.mappings().all()
|
||||||
|
except Exception as e:
|
||||||
|
izzylog.verbose(f'Session.run: {e.__class__.__name__}: {e}')
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def count(self, table_name, **kwargs):
|
||||||
|
return self.query(self.table[table_name]).filter_by(**kwargs).count()
|
||||||
|
|
||||||
|
|
||||||
|
def fetch(self, table, single=True, orderby=None, orderdir='asc', **kwargs):
|
||||||
|
RowClass = self.classes.get_class(table.lower())
|
||||||
|
|
||||||
|
query = self.query(self.table[table]).filter_by(**kwargs)
|
||||||
|
|
||||||
|
if not orderby:
|
||||||
|
rows = query.all()
|
||||||
|
|
||||||
|
else:
|
||||||
|
if orderdir == 'asc':
|
||||||
|
rows = query.order_by(getattr(self.table[table].c, orderby).asc()).all()
|
||||||
|
|
||||||
|
elif orderdir == 'desc':
|
||||||
|
rows = query.order_by(getattr(self.table[table].c, orderby).desc()).all()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unsupported order direction: {orderdir}')
|
||||||
|
|
||||||
|
if single:
|
||||||
|
return RowClass(table, rows[0], self) if len(rows) > 0 else None
|
||||||
|
|
||||||
|
return [RowClass(table, row, self) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def search(self, *args, **kwargs):
|
||||||
|
kwargs.pop('single', None)
|
||||||
|
return self.fetch(*args, single=False, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def insert(self, table, return_row=False, **kwargs):
|
||||||
|
row = self.fetch(table, **kwargs)
|
||||||
|
|
||||||
|
if row:
|
||||||
|
row.update_session(self, **kwargs)
|
||||||
|
return
|
||||||
|
|
||||||
|
if getattr(self.table[table], 'timestamp', None) and not kwargs.get('timestamp'):
|
||||||
|
kwargs['timestamp'] = datetime.now()
|
||||||
|
|
||||||
|
return self.execute(self.table[table].insert().values(**kwargs))
|
||||||
|
|
||||||
|
if return_row:
|
||||||
|
return self.fetch(table, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, table=None, rowid=None, row=None, return_row=False, **kwargs):
|
||||||
|
if row:
|
||||||
|
rowid = row.id
|
||||||
|
table = row.table
|
||||||
|
|
||||||
|
if not rowid or not table:
|
||||||
|
raise ValueError('Missing row ID or table')
|
||||||
|
|
||||||
|
self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs))
|
||||||
|
|
||||||
|
if return_row:
|
||||||
|
return self.fetch(table, id=rowid)
|
||||||
|
|
||||||
|
|
||||||
|
def remove(self, table=None, rowid=None, row=None):
|
||||||
|
if row:
|
||||||
|
rowid = row.id
|
||||||
|
table = row.table
|
||||||
|
|
||||||
|
if not rowid or not table:
|
||||||
|
raise ValueError('Missing row ID or table')
|
||||||
|
|
||||||
|
self.run(f'DELETE FROM {table} WHERE id=:id', id=rowid)
|
||||||
|
|
||||||
|
|
||||||
|
def append_column(self, table, column):
|
||||||
|
if column.name in self.db.get_columns(table):
|
||||||
|
logging.warning(f'Table "{table}" already has column "{column.name}"')
|
||||||
|
return
|
||||||
|
|
||||||
|
self.run(f'ALTER TABLE {table} ADD COLUMN {column.compile()}')
|
||||||
|
|
||||||
|
|
||||||
|
def remove_column(self, tbl, col):
|
||||||
|
table = self.table[tbl]
|
||||||
|
column = getattr(table, col, None)
|
||||||
|
columns = self.db.get_columns(tbl)
|
||||||
|
|
||||||
|
if col not in columns:
|
||||||
|
izzylog.info(f'Column "{col}" already exists')
|
||||||
|
return
|
||||||
|
|
||||||
|
columns.remove(col)
|
||||||
|
coltext = ','.join(columns)
|
||||||
|
|
||||||
|
self.run(f'CREATE TABLE {tbl}_temp AS SELECT {coltext} FROM {tbl}')
|
||||||
|
self.run(f'DROP TABLE {tbl}')
|
||||||
|
self.run(f'ALTER TABLE {tbl}_temp RENAME TO {tbl}')
|
||||||
|
|
||||||
|
|
||||||
|
def clear_table(self, table):
|
||||||
|
self.run(f'DELETE FROM {table}')
|
|
@ -1,19 +0,0 @@
|
||||||
from enum import Enum
|
|
||||||
from izzylib import DotDict
|
|
||||||
|
|
||||||
|
|
||||||
class BaseType(Enum):
|
|
||||||
INTEGER = int
|
|
||||||
TEXT = str
|
|
||||||
BLOB = bytes
|
|
||||||
REAL = float
|
|
||||||
NUMERIC = float
|
|
||||||
|
|
||||||
|
|
||||||
Type = DotDict(
|
|
||||||
**{v: BaseType.INTEGER for v in ['INT', 'INTEGER', 'TINYINT', 'SMALLINT', 'MEDIUMINT', 'BIGINT', 'UNSIGNED BIG INT', 'INT2', 'INT8']},
|
|
||||||
**{v: BaseType.TEXT for v in ['CHARACTER', 'VARCHAR', 'VARYING CHARACTER', 'NCHAR', 'NATIVE CHARACTER', 'NVARCHAR', 'TEXT', 'CLOB', 'STRING', 'JSON']},
|
|
||||||
**{v: BaseType.BLOB for v in ['BYTES', 'BLOB']},
|
|
||||||
**{v: BaseType.REAL for v in ['REAL', 'DOUBLE', 'DOUBLE PRECISION', 'FLOAT']},
|
|
||||||
**{v: BaseType.NUMERIC for v in ['NUMERIC', 'DECIMAL', 'BOOLEAN', 'DATE', 'DATETIME']}
|
|
||||||
)
|
|
Loading…
Reference in a new issue