rework #3
28 changed files with 892 additions and 1682 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -121,7 +121,7 @@ reload.cfg
|
|||
/izzylib
|
||||
/base/izzylib/dbus
|
||||
/base/izzylib/hasher
|
||||
/base/izzylib/http_requests_client
|
||||
/base/izzylib/http_urllib_client
|
||||
/base/izzylib/http_server
|
||||
/base/izzylib/mbus
|
||||
/base/izzylib/sql
|
||||
|
|
|
@ -16,7 +16,7 @@ You only need to install the base and whatever sub-modules you want to use
|
|||
|
||||
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-server&subdirectory=http_server"
|
||||
|
||||
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-requests-client&subdirectory=requests_client"
|
||||
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-http-urllib-client&subdirectory=http_urllib_client"
|
||||
|
||||
$(venv)/bin/python -m pip install -e "git+https://git.barkshark.xyz/izaliamae/izzylib.git@rework#egg=izzylib-sql&subdirectory=sql"
|
||||
|
||||
|
@ -26,7 +26,7 @@ You only need to install the base and whatever sub-modules you want to use
|
|||
|
||||
### From Source
|
||||
|
||||
$(venv)/bin/python setup.py install ['all' or a combination of these: dbus hasher http_server requests_client sql template tinydb]
|
||||
$(venv)/bin/python setup.py install ['all' or a combination of these: dbus hasher http_server http_urllib_client sql template tinydb]
|
||||
|
||||
## Documentation
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from .misc import *
|
|||
from .cache import CacheDecorator, LruCache, TtlCache
|
||||
from .connection import Connection
|
||||
|
||||
from .http_urllib_client import HttpUrllibClient, HttpUrllibResponse
|
||||
from .http_client import HttpClient, HttpResponse
|
||||
|
||||
|
||||
def log_import_error(package, *message):
|
||||
|
@ -48,10 +48,10 @@ except ImportError:
|
|||
log_import_error('template', 'Failed to import http template classes. Jinja and HAML templates disabled')
|
||||
|
||||
try:
|
||||
from izzylib.http_requests_client import *
|
||||
from izzylib.http_urllib_client import *
|
||||
|
||||
except ImportError:
|
||||
log_import_error('http_requests_client', 'Failed to import Requests http client classes. Requests http client is disabled')
|
||||
log_import_error('http_urllib_client', 'Failed to import Requests http client classes. Requests http client is disabled')
|
||||
|
||||
try:
|
||||
from izzylib.http_server import PasswordHasher, HttpServer, HttpServerRequest, HttpServerResponse
|
||||
|
|
|
@ -110,18 +110,12 @@ class DefaultDotDict(DotDict):
|
|||
|
||||
|
||||
class LowerDotDict(DotDict):
|
||||
def __getattr__(self, key):
|
||||
return super().__getattr__(self, key.lower())
|
||||
def __getitem__(self, key):
|
||||
return super().__getitem__(key.lower())
|
||||
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
return super().__setattr__(key.lower(), value)
|
||||
|
||||
|
||||
def update(self, data):
|
||||
data = {k.lower(): v for k,v in self.items()}
|
||||
|
||||
return super().update(data)
|
||||
def __setitem__(self, key, value):
|
||||
return super().__setitem__(key.lower(), value)
|
||||
|
||||
|
||||
class MultiDotDict(DotDict):
|
||||
|
|
|
@ -22,7 +22,7 @@ except ImportError:
|
|||
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
|
||||
|
||||
|
||||
class HttpUrllibClient:
|
||||
class HttpClient:
|
||||
def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None):
|
||||
proxy_ports = {
|
||||
'http': 80,
|
||||
|
@ -74,7 +74,7 @@ class HttpUrllibClient:
|
|||
except HTTPError as e:
|
||||
response = e.fp
|
||||
|
||||
return HttpUrllibResponse(response)
|
||||
return HttpResponse(response)
|
||||
|
||||
|
||||
def file(self, url, filepath, *args, filename=None, size=2048, create_dirs=True, **kwargs):
|
||||
|
@ -141,7 +141,7 @@ class HttpUrllibClient:
|
|||
return self.request(*args, headers=headers, **kwargs)
|
||||
|
||||
|
||||
class HttpUrllibResponse(object):
|
||||
class HttpResponse(object):
|
||||
def __init__(self, response):
|
||||
self.body = response.read()
|
||||
self.headers = DefaultDotDict({k.lower(): v.lower() for k,v in response.headers.items()})
|
|
@ -4,6 +4,7 @@ from datetime import datetime
|
|||
from getpass import getpass, getuser
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import izzylog
|
||||
from .dotdict import DotDict
|
||||
|
@ -27,7 +28,8 @@ __all__ = [
|
|||
'time_function',
|
||||
'time_function_pprint',
|
||||
'timestamp',
|
||||
'var_name'
|
||||
'var_name',
|
||||
'Url'
|
||||
]
|
||||
|
||||
|
||||
|
@ -460,3 +462,26 @@ def var_name(single=True, **kwargs):
|
|||
|
||||
keys = list(kwargs.keys())
|
||||
return key[0] if single else keys
|
||||
|
||||
|
||||
class Url(str):
|
||||
protocols = {
|
||||
'http': 80,
|
||||
'https': 443,
|
||||
'ftp': 21,
|
||||
'ftps': 990
|
||||
}
|
||||
|
||||
def __init__(self, url):
|
||||
str.__new__(Url, url)
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
self.__parsed = parsed
|
||||
self.proto = parsed.scheme
|
||||
self.host = parsed.netloc
|
||||
self.path = parsed.path
|
||||
self.query = parsed.query
|
||||
self.username = parsed.username
|
||||
self.password = parsed.password
|
||||
self.port = self.protocols.get(self.proto) if not parsed.port else None
|
||||
|
|
|
@ -78,6 +78,5 @@ class AccessLog(MiddlewareBase):
|
|||
|
||||
async def handler(self, request, response):
|
||||
uagent = request.headers.get('user-agent', 'None')
|
||||
address = request.headers.get('x-real-ip', request.forwarded.get('for', request.remote_addr))
|
||||
|
||||
applog.info(f'({multiprocessing.current_process().name}) {address} {request.method} {request.path} {response.status} "{uagent}"')
|
||||
applog.info(f'({multiprocessing.current_process().name}) {request.address} {request.method} {request.path} {response.status} "{uagent}"')
|
||||
|
|
|
@ -11,6 +11,7 @@ class Request(sanic.request.Request):
|
|||
super().__init__(url_bytes, headers, version, method, transport, app)
|
||||
|
||||
self.Headers = Headers(headers)
|
||||
self.address = self.headers.get('x-real-ip', self.forwarded.get('for', self.remote_addr))
|
||||
self.data = Data(self)
|
||||
self.template = self.app.template
|
||||
self.user_level = 0
|
||||
|
|
30
http_urllib_client/izzylib/http_urllib_client/__init__.py
Normal file
30
http_urllib_client/izzylib/http_urllib_client/__init__.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from .signatures import (
|
||||
verify_request,
|
||||
verify_headers,
|
||||
parse_signature,
|
||||
fetch_actor,
|
||||
fetch_instance,
|
||||
fetch_nodeinfo,
|
||||
fetch_webfinger_account,
|
||||
generate_rsa_key
|
||||
)
|
||||
|
||||
|
||||
from .client import HttpUrllibClient, set_default_client
|
||||
from .request import HttpUrllibRequest
|
||||
from .response import HttpUrllibResponse
|
||||
|
||||
#__all__ = [
|
||||
#'HttpRequestsClient',
|
||||
#'HttpRequestsRequest',
|
||||
#'HttpRequestsResponse',
|
||||
#'fetch_actor',
|
||||
#'fetch_instance',
|
||||
#'fetch_nodeinfo',
|
||||
#'fetch_webfinger_account',
|
||||
#'generate_rsa_key',
|
||||
#'parse_signature',
|
||||
#'set_requests_client',
|
||||
#'verify_headers',
|
||||
#'verify_request',
|
||||
#]
|
128
http_urllib_client/izzylib/http_urllib_client/client.py
Normal file
128
http_urllib_client/izzylib/http_urllib_client/client.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
import json, sys, urllib3
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from io import BytesIO
|
||||
from izzylib import DefaultDotDict, DotDict, LowerDotDict, Path, izzylog as logging, __version__
|
||||
from izzylib.exceptions import HttpFileDownloadedError
|
||||
from ssl import SSLCertVerificationError
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .request import HttpUrllibRequest
|
||||
from .response import HttpUrllibResponse
|
||||
from .signatures import set_client
|
||||
|
||||
|
||||
Client = None
|
||||
|
||||
proxy_ports = {
|
||||
'http': 80,
|
||||
'https': 443
|
||||
}
|
||||
|
||||
|
||||
class HttpUrllibClient:
|
||||
def __init__(self, headers={}, useragent=None, appagent=None, proxy_type='https', proxy_host=None, proxy_port=None, num_pools=20):
|
||||
if not useragent:
|
||||
useragent = f'IzzyLib/{__version__}'
|
||||
|
||||
self.headers = {k:v.lower() for k,v in headers.items()}
|
||||
self.agent = f'{useragent} ({appagent})' if appagent else useragent
|
||||
|
||||
if proxy_type not in ['http', 'https']:
|
||||
raise ValueError(f'Not a valid proxy type: {proxy_type}')
|
||||
|
||||
if proxy_host:
|
||||
proxy = f'{proxy_type}://{proxy_host}:{proxy_ports[proxy_type] if not proxy_port else proxy_port}'
|
||||
self.pool = urllib3.ProxyManager(proxy, num_pools=num_pools)
|
||||
|
||||
else:
|
||||
self.pool = urllib3.PoolManager(num_pools=num_pools)
|
||||
|
||||
|
||||
@property
|
||||
def agent(self):
|
||||
return self.headers['user-agent']
|
||||
|
||||
|
||||
@agent.setter
|
||||
def agent(self, value):
|
||||
self.headers['user-agent'] = value
|
||||
|
||||
|
||||
def set_global(self):
|
||||
set_default_client(self)
|
||||
|
||||
|
||||
def build_request(self, *args, **kwargs):
|
||||
return HttpUrllibRequest(*args, **kwargs)
|
||||
|
||||
|
||||
def handle_request(self, request):
|
||||
request.headers.update(self.headers)
|
||||
response = self.pool.urlopen(*request._args, **request._kwargs)
|
||||
return HttpUrllibResponse(response)
|
||||
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
return self.handle_request(self.build_request(*args, **kwargs))
|
||||
|
||||
|
||||
def signed_request(self, privkey, keyid, *args, **kwargs):
|
||||
return self.request(*args, privkey=privkey, keyid=keyid, **kwargs)
|
||||
|
||||
|
||||
def download(self, url, filepath, *args, filename=None, **kwargs):
|
||||
resp = self.request(url, *args, **kwargs)
|
||||
|
||||
if resp.status != 200:
|
||||
raise HttpFileDownloadedError(f'Failed to download {url}: Status: {resp.status}, Body: {resp.body}')
|
||||
|
||||
return resp.save(filepath)
|
||||
|
||||
|
||||
def image(self, url, filepath, *args, filename=None, ext='png', dimensions=(50, 50), **kwargs):
|
||||
if not Image:
|
||||
izzylog.error('Pillow module is not installed')
|
||||
return
|
||||
|
||||
resp = self.request(url, *args, **kwargs)
|
||||
|
||||
if resp.status != 200:
|
||||
izzylog.error(f'Failed to download {url}:', resp.status, resp.body)
|
||||
return False
|
||||
|
||||
if not filename:
|
||||
filename = Path(url).stem()
|
||||
|
||||
path = Path(filepath)
|
||||
|
||||
if not path.exists:
|
||||
izzylog.error('Path does not exist:', path)
|
||||
return False
|
||||
|
||||
byte = BytesIO()
|
||||
image = Image.open(BytesIO(resp.body))
|
||||
image.thumbnail(dimensions)
|
||||
image.save(byte, format=ext.upper())
|
||||
|
||||
with path.join(filename).open('wb') as fd:
|
||||
fd.write(byte.getvalue())
|
||||
|
||||
|
||||
def json(self, *args, headers={}, activity=True, **kwargs):
|
||||
json_type = 'activity+json' if activity else 'json'
|
||||
headers.update({
|
||||
'accept': f'application/{json_type}'
|
||||
})
|
||||
|
||||
return self.request(*args, headers=headers, **kwargs)
|
||||
|
||||
|
||||
def set_default_client(client=None):
|
||||
global Client
|
||||
Client = client or HttpClient()
|
||||
set_client(Client)
|
111
http_urllib_client/izzylib/http_urllib_client/request.py
Normal file
111
http_urllib_client/izzylib/http_urllib_client/request.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
import json
|
||||
|
||||
from Crypto.Hash import SHA256
|
||||
from izzylib import DotDict, LowerDotDict, Url, boolean
|
||||
from base64 import b64decode, b64encode
|
||||
from datetime import datetime
|
||||
from izzylib import izzylog as logging
|
||||
|
||||
from .signatures import sign_pkcs_headers
|
||||
|
||||
|
||||
methods = ['delete', 'get', 'head', 'options', 'patch', 'post', 'put']
|
||||
|
||||
|
||||
class HttpUrllibRequest:
|
||||
def __init__(self, url, **kwargs):
|
||||
self._body = b''
|
||||
|
||||
method = kwargs.get('method', 'get').lower()
|
||||
|
||||
if method not in methods:
|
||||
raise ValueError(f'Invalid method: {method}')
|
||||
|
||||
self.url = Url(url)
|
||||
self.body = kwargs.get('body')
|
||||
self.method = method
|
||||
self.headers = LowerDotDict(kwargs.get('headers', {}))
|
||||
self.redirect = boolean(kwargs.get('redirect', True))
|
||||
self.retries = int(kwargs.get('retries', 10))
|
||||
self.timeout = int(kwargs.get('timeout', 5))
|
||||
|
||||
privkey = kwargs.get('privkey')
|
||||
keyid = kwargs.get('keyid')
|
||||
|
||||
if privkey and keyid:
|
||||
self.sign(privkey, keyid)
|
||||
|
||||
|
||||
@property
|
||||
def _args(self):
|
||||
return [self.method.upper(), self.url]
|
||||
|
||||
|
||||
@property
|
||||
def _kwargs(self):
|
||||
return {
|
||||
'body': self.body,
|
||||
'headers': self.headers,
|
||||
'redirect': self.redirect,
|
||||
'retries': self.retries,
|
||||
'timeout': self.timeout
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self._body
|
||||
|
||||
|
||||
@body.setter
|
||||
def body(self, data):
|
||||
if isinstance(data, dict):
|
||||
data = DotDict(data).to_json()
|
||||
|
||||
elif any(map(isinstance, [data], [list, tuple])):
|
||||
data = json.dumps(data)
|
||||
|
||||
if data == None:
|
||||
data = b''
|
||||
|
||||
elif not isinstance(data, bytes):
|
||||
data = bytes(data, 'utf-8')
|
||||
|
||||
self._body = data
|
||||
|
||||
|
||||
def set_header(self, key, value):
|
||||
self.headers[key] = value
|
||||
|
||||
|
||||
def unset_header(self, key):
|
||||
self.headers.pop(key, None)
|
||||
|
||||
|
||||
def sign(self, privkey, keyid):
|
||||
self.unset_header('signature')
|
||||
|
||||
self.set_header('(request-target)', f'{self.method.lower()} {self.url.path}')
|
||||
self.set_header('host', self.url.host)
|
||||
self.set_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
|
||||
|
||||
if self.body:
|
||||
body_hash = b64encode(SHA256.new(self.body).digest()).decode("UTF-8")
|
||||
|
||||
self.set_header('digest', f'SHA-256={body_hash}')
|
||||
self.set_header('content-length', str(len(self.body)))
|
||||
|
||||
sig = {
|
||||
'keyId': keyid,
|
||||
'algorithm': 'rsa-sha256',
|
||||
'headers': ' '.join([k.lower() for k in self.headers.keys()]),
|
||||
'signature': b64encode(sign_pkcs_headers(privkey, self.headers)).decode('UTF-8')
|
||||
}
|
||||
|
||||
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
|
||||
sig_string = ','.join(sig_items)
|
||||
|
||||
self.set_header('signature', sig_string)
|
||||
|
||||
self.unset_header('(request-target)')
|
||||
self.unset_header('host')
|
97
http_urllib_client/izzylib/http_urllib_client/response.py
Normal file
97
http_urllib_client/izzylib/http_urllib_client/response.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
import json
|
||||
|
||||
from io import BytesIO
|
||||
from izzylib import DefaultDotDict, DotDict, Path, Url
|
||||
|
||||
|
||||
class HttpUrllibResponse:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
self._dict = None
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.dict[key]
|
||||
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.dict[key] = value
|
||||
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
for line in self.headers.get('content-type', '').split(';'):
|
||||
try:
|
||||
k,v = line.split('=')
|
||||
|
||||
if k.lower == 'charset':
|
||||
return v.lower()
|
||||
|
||||
except:
|
||||
pass
|
||||
|
||||
return 'utf-8'
|
||||
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self.response.headers
|
||||
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self.response.status
|
||||
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return Url(self.response.geturl())
|
||||
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
data = self.response.read(cache_content=True)
|
||||
|
||||
if not data:
|
||||
data = self.response.data
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
return self.body.decode(self.encoding)
|
||||
|
||||
|
||||
@property
|
||||
def dict(self):
|
||||
if not self._dict:
|
||||
self._dict = DotDict(self.text)
|
||||
|
||||
return self._dict
|
||||
|
||||
|
||||
def json_pretty(self, indent=4):
|
||||
return self.dict.to_json(indent)
|
||||
|
||||
|
||||
def chunks(self, size=1024):
|
||||
return self.response.stream(amt=size)
|
||||
|
||||
|
||||
def save(self, path, overwrite=True, create_parents=True):
|
||||
path = Path(path)
|
||||
|
||||
if not path.parent.exists:
|
||||
if not create_parents:
|
||||
raise ValueError(f'Path does not exist: {path.parent}')
|
||||
|
||||
path.parent.mkdir()
|
||||
|
||||
if overwrite and path.exists:
|
||||
path.delete()
|
||||
|
||||
with path.open('wb') as fd:
|
||||
for chunk in self.chunks():
|
||||
fd.write(chunk)
|
|
@ -5,21 +5,21 @@ from setuptools import setup, find_namespace_packages
|
|||
requires = [
|
||||
'pillow==8.2.0',
|
||||
'pycryptodome==3.10.1',
|
||||
'requests==2.25.1',
|
||||
'urllib==1.26.5',
|
||||
'tldextract==3.1.0'
|
||||
]
|
||||
|
||||
|
||||
setup(
|
||||
name="IzzyLib Requests Client",
|
||||
name="IzzyLib Urllib3 Client",
|
||||
version='0.6.0',
|
||||
packages=find_namespace_packages(include=['izzylib.http_requests_client']),
|
||||
packages=find_namespace_packages(include=['izzylib.http_urllib_client']),
|
||||
python_requires='>=3.7.0',
|
||||
install_requires=requires,
|
||||
include_package_data=False,
|
||||
author='Zoey Mae',
|
||||
author_email='admin@barkshark.xyz',
|
||||
description='A Requests client with support for http header signing and verifying',
|
||||
description='A Urllib3 client with support for http header signing and verifying',
|
||||
keywords='web http client',
|
||||
url='https://git.barkshark.xyz/izaliamae/izzylib',
|
||||
project_urls={
|
|
@ -1,33 +0,0 @@
|
|||
from .signature import (
|
||||
verify_request,
|
||||
verify_headers,
|
||||
parse_signature,
|
||||
fetch_actor,
|
||||
fetch_instance,
|
||||
fetch_nodeinfo,
|
||||
fetch_webfinger_account,
|
||||
generate_rsa_key
|
||||
)
|
||||
|
||||
|
||||
from .client import (
|
||||
HttpRequestsClient,
|
||||
HttpRequestsRequest,
|
||||
HttpRequestsResponse,
|
||||
set_requests_client
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'HttpRequestsClient',
|
||||
'HttpRequestsRequest',
|
||||
'HttpRequestsResponse',
|
||||
'fetch_actor',
|
||||
'fetch_instance',
|
||||
'fetch_nodeinfo',
|
||||
'fetch_webfinger_account',
|
||||
'generate_rsa_key',
|
||||
'parse_signature',
|
||||
'set_requests_client',
|
||||
'verify_headers',
|
||||
'verify_request',
|
||||
]
|
|
@ -1,227 +0,0 @@
|
|||
import json, requests, sys
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from io import BytesIO
|
||||
from izzylib import DefaultDotDict, DotDict, Path, izzylog as logging, __version__
|
||||
from izzylib.exceptions import HttpFileDownloadedError
|
||||
from ssl import SSLCertVerificationError
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .signature import sign_request, set_client
|
||||
|
||||
|
||||
Client = None
|
||||
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
|
||||
|
||||
|
||||
class HttpRequestsClient(object):
|
||||
def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None):
|
||||
proxy_ports = {
|
||||
'http': 80,
|
||||
'https': 443
|
||||
}
|
||||
|
||||
if proxy_type not in ['http', 'https']:
|
||||
raise ValueError(f'Not a valid proxy type: {proxy_type}')
|
||||
|
||||
self.headers=headers
|
||||
self.agent = f'{useragent} ({appagent})' if appagent else useragent
|
||||
self.proxy = DotDict({
|
||||
'enabled': True if proxy_host else False,
|
||||
'ptype': proxy_type,
|
||||
'host': proxy_host,
|
||||
'port': proxy_ports[proxy_type] if not proxy_port else proxy_port
|
||||
})
|
||||
|
||||
|
||||
def set_global(self):
|
||||
set_requests_client(self)
|
||||
|
||||
|
||||
def build_request(self, *args, method='get', privkey=None, keyid=None, **kwargs):
|
||||
if method.lower() not in methods:
|
||||
raise ValueError(f'Invalid method: {method}')
|
||||
|
||||
request = HttpRequestsRequest(self, *args, method=method.lower(), **kwargs)
|
||||
|
||||
if privkey and keyid:
|
||||
request.sign(privkey, keyid)
|
||||
|
||||
return request
|
||||
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
request = self.build_request(*args, **kwargs)
|
||||
return HttpRequestsResponse(request.send())
|
||||
|
||||
|
||||
def signed_request(self, privkey, keyid, *args, **kwargs):
|
||||
return self.request(*args, privkey=privkey, keyid=keyid, **kwargs)
|
||||
|
||||
|
||||
def download(self, url, filepath, *args, filename=None, **kwargs):
|
||||
resp = self.request(url, *args, **kwargs)
|
||||
|
||||
if resp.status != 200:
|
||||
raise HttpFileDownloadedError(f'Failed to download {url}: Status: {resp.status}, Body: {resp.body}')
|
||||
|
||||
return resp.save(filepath)
|
||||
|
||||
|
||||
def image(self, url, filepath, *args, filename=None, ext='png', dimensions=(50, 50), **kwargs):
|
||||
if not Image:
|
||||
izzylog.error('Pillow module is not installed')
|
||||
return
|
||||
|
||||
resp = self.request(url, *args, **kwargs)
|
||||
|
||||
if resp.status != 200:
|
||||
izzylog.error(f'Failed to download {url}:', resp.status, resp.body)
|
||||
return False
|
||||
|
||||
if not filename:
|
||||
filename = Path(url).stem()
|
||||
|
||||
path = Path(filepath)
|
||||
|
||||
if not path.exists:
|
||||
izzylog.error('Path does not exist:', path)
|
||||
return False
|
||||
|
||||
byte = BytesIO()
|
||||
image = Image.open(BytesIO(resp.body))
|
||||
image.thumbnail(dimensions)
|
||||
image.save(byte, format=ext.upper())
|
||||
|
||||
with path.join(filename).open('wb') as fd:
|
||||
fd.write(byte.getvalue())
|
||||
|
||||
|
||||
def json(self, *args, headers={}, activity=True, **kwargs):
|
||||
json_type = 'activity+json' if activity else 'json'
|
||||
headers.update({
|
||||
'accept': f'application/{json_type}'
|
||||
})
|
||||
|
||||
return self.request(*args, headers=headers, **kwargs)
|
||||
|
||||
|
||||
class HttpRequestsRequest(object):
|
||||
def __init__(self, client, url, data=b'', headers={}, query={}, method='get'):
|
||||
parsed = urlparse(url)
|
||||
self.args = [url]
|
||||
self.kwargs = DotDict({'params': query})
|
||||
self.method = method.lower()
|
||||
self.client = client
|
||||
self.path = parsed.path
|
||||
self.host = parsed.netloc
|
||||
self.body = data
|
||||
|
||||
new_headers = client.headers.copy()
|
||||
new_headers.update(headers)
|
||||
|
||||
parsed_headers = {k.lower(): v for k,v in new_headers.items()}
|
||||
|
||||
if not parsed_headers.get('user-agent'):
|
||||
parsed_headers['user-agent'] = client.agent
|
||||
|
||||
self.kwargs['headers'] = DotDict(new_headers)
|
||||
|
||||
if client.proxy.enabled:
|
||||
self.kwargs['proxies'] = DotDict({self.proxy.ptype: f'{self.proxy.ptype}://{self.proxy.host}:{self.proxy.port}'})
|
||||
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.kwargs.data
|
||||
|
||||
|
||||
@body.setter
|
||||
def body(self, data):
|
||||
self.kwargs.data = data.encode('utf-8') if isinstance(data, str) else data
|
||||
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self.kwargs.headers
|
||||
|
||||
|
||||
def add_header(self, key, value):
|
||||
self.kwargs.headers[key] = value
|
||||
|
||||
|
||||
def remove_header(self, key):
|
||||
self.kwargs.headers.pop(key, None)
|
||||
|
||||
|
||||
def send(self):
|
||||
func = getattr(requests, self.method)
|
||||
return func(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
def sign(self, privkey, keyid):
|
||||
sign_request(self, privkey, keyid)
|
||||
|
||||
|
||||
class HttpRequestsResponse(object):
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.data = b''
|
||||
self.headers = DefaultDotDict({k.lower(): v.lower() for k,v in response.headers.items()})
|
||||
self.status = response.status_code
|
||||
self.url = response.url
|
||||
|
||||
|
||||
def chunks(self, size=256):
|
||||
return self.response.iter_content(chunk_size=256)
|
||||
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
for chunk in self.chunks():
|
||||
self.data += chunk
|
||||
|
||||
return self.data
|
||||
|
||||
|
||||
@cached_property
|
||||
def text(self):
|
||||
return self.data.decode(self.response.encoding)
|
||||
|
||||
|
||||
@cached_property
|
||||
def json(self):
|
||||
try:
|
||||
return DotDict(self.text)
|
||||
|
||||
except:
|
||||
return DotDict(self.body)
|
||||
|
||||
|
||||
@cached_property
|
||||
def json_pretty(self, indent=4):
|
||||
return json.dumps(self.json, indent=indent)
|
||||
|
||||
|
||||
def save(self, path, overwrite=True):
|
||||
path = Path(path)
|
||||
|
||||
if not path.parent.exists:
|
||||
raise ValueError(f'Path does not exist: {path.parent}')
|
||||
|
||||
if overwrite and path.exists:
|
||||
path.delete()
|
||||
|
||||
with path.open('wb') as fd:
|
||||
for chunk in self.chunks():
|
||||
fd.write(chunk)
|
||||
|
||||
|
||||
def set_requests_client(client=None):
|
||||
global Client
|
||||
Client = client or RequestsClient()
|
||||
set_client(Client)
|
|
@ -1,6 +1,13 @@
|
|||
# old sql classes
|
||||
from .generic import SqlColumn, CustomRows, SqlSession, SqlDatabase, Tables, OperationalError, ProgrammingError
|
||||
from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession
|
||||
## Normal SQL client
|
||||
from .database import Database, OperationalError, ProgrammingError
|
||||
from .session import Session
|
||||
from .column import Column
|
||||
|
||||
#from .database import Database, Session
|
||||
#from .queries import Column, Insert, Select, Table, Tables, Update
|
||||
## Sqlite server
|
||||
#from .sqlite_server import SqliteClient, SqliteColumn, SqliteServer, SqliteSession
|
||||
|
||||
|
||||
## Compat
|
||||
SqlDatabase = Database
|
||||
SqlSession = Session
|
||||
SqlColumn = Column
|
||||
|
|
54
sql/izzylib/sql/column.py
Normal file
54
sql/izzylib/sql/column.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import (
|
||||
Column as sqlalchemy_column,
|
||||
types as Types
|
||||
)
|
||||
|
||||
|
||||
SqlTypes = {t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')}
|
||||
|
||||
|
||||
class Column(sqlalchemy_column):
|
||||
def __init__(self, name, stype=None, fkey=None, **kwargs):
|
||||
if not stype and not kwargs:
|
||||
if name == 'id':
|
||||
stype = 'integer'
|
||||
kwargs['primary_key'] = True
|
||||
kwargs['autoincrement'] = True
|
||||
|
||||
elif name == 'timestamp':
|
||||
stype = 'datetime'
|
||||
|
||||
else:
|
||||
raise ValueError('Missing column type and options')
|
||||
|
||||
stype = (stype.lower() if type(stype) == str else stype) or 'string'
|
||||
|
||||
if type(stype) == str:
|
||||
try:
|
||||
stype = SqlTypes[stype.lower()]
|
||||
|
||||
except KeyError:
|
||||
raise KeyError(f'Invalid SQL data type: {stype}')
|
||||
|
||||
options = [name, stype]
|
||||
|
||||
if fkey:
|
||||
options.append(ForeignKey(fkey))
|
||||
|
||||
super().__init__(*options, **kwargs)
|
||||
|
||||
|
||||
def compile(self):
|
||||
sql = f'{self.name} {self.type}'
|
||||
|
||||
if not self.nullable:
|
||||
sql += ' NOT NULL'
|
||||
|
||||
if self.primary_key:
|
||||
sql += ' PRIMARY KEY'
|
||||
|
||||
if self.unique:
|
||||
sql += ' UNIQUE'
|
||||
|
||||
return sql
|
|
@ -1,100 +0,0 @@
|
|||
import importlib, sqlite3, ssl
|
||||
|
||||
from getpass import getuser
|
||||
from izzylib import DotDict, Path, izzylog
|
||||
|
||||
|
||||
defaults = {
|
||||
'name': (None, str),
|
||||
'host': (None, str),
|
||||
'port': (None, int),
|
||||
'username': (getuser(), str),
|
||||
'password': (None, str),
|
||||
'ssl': ('allow', str),
|
||||
'ssl_context': (ssl.create_default_context(), ssl.SSLContext),
|
||||
'ssl_key': (None, Path),
|
||||
'ssl_cert': (None, Path),
|
||||
'max_connections': (25, int),
|
||||
'type': ('sqlite', str),
|
||||
'module': (sqlite3, None),
|
||||
'mod_name': ('sqlite3', str),
|
||||
'timeout': (5, int),
|
||||
'args': ([], list),
|
||||
'kwargs': ({}, dict)
|
||||
}
|
||||
|
||||
modtypes = {
|
||||
'sqlite': ['sqlite3'],
|
||||
'postgresql': ['pg8000', 'psycopg2', 'psycopg3', 'pgdb'],
|
||||
'mysql': ['mysqldb', 'trio_mysql'],
|
||||
'mssql': ['pymssql', 'adodbapi']
|
||||
}
|
||||
|
||||
sslmodes = ['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']
|
||||
|
||||
|
||||
class Config(DotDict):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__({k: v[0] for k,v in defaults.items()})
|
||||
|
||||
module = kwargs.pop('module', None)
|
||||
|
||||
if module:
|
||||
self.parse_module(module)
|
||||
|
||||
self.update(kwargs)
|
||||
|
||||
if self.ssl != 'disable' and (self.ssl_key or self.ssl_cert):
|
||||
self.ssl_context.load_cert_chain(self.ssl_cert, self.ssl_key)
|
||||
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key not in defaults:
|
||||
raise KeyError(f'Invalid config option: {key}')
|
||||
|
||||
valtype = defaults[key][1]
|
||||
|
||||
if valtype and value and not isinstance(value, valtype):
|
||||
raise TypeError(f'{key} should be a {valtype}, not a {value.__class__.__name__}')
|
||||
|
||||
if key == 'ssl' and value == True:
|
||||
value = ssl.create_default_context()
|
||||
|
||||
super().__setitem__(key, value)
|
||||
|
||||
|
||||
def parse_module(self, name):
|
||||
module = None
|
||||
module_type = None
|
||||
module_name = None
|
||||
|
||||
if name == 'sqlite3':
|
||||
name = 'sqlite'
|
||||
|
||||
for mtype, modules in modtypes.items():
|
||||
if name == mtype:
|
||||
module_type = name
|
||||
|
||||
for mod in modules:
|
||||
try:
|
||||
module = importlib.import_module(mod)
|
||||
module_name = mod
|
||||
break
|
||||
except ImportError:
|
||||
izzylog.verbose(f'Database module not installed:', mod)
|
||||
|
||||
elif name in modules:
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
module_type = mtype
|
||||
module_name = name
|
||||
break
|
||||
except ImportError:
|
||||
izzylog.error(f'Database module not installed:', name)
|
||||
|
||||
if None in (module, module_name, module_type):
|
||||
raise ValueError(f'Failed to find module for {name}')
|
||||
|
||||
self.module = module
|
||||
self.mod_name = module_name
|
||||
self.type = module_type
|
|
@ -1,360 +1,186 @@
|
|||
import sqlite3, traceback
|
||||
import json, pkgutil, sys, threading, time
|
||||
|
||||
from functools import partial
|
||||
from getpass import getuser
|
||||
from izzylib import DotDict, izzylog, boolean, random_gen
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from izzylib import LruCache, DotDict, Path, nfs_check, izzylog
|
||||
from sqlalchemy import Table, create_engine
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy.schema import MetaData
|
||||
|
||||
from . import error
|
||||
from .config import Config
|
||||
from .queries import Column, Delete, Insert, Select, Table, Tables, Update
|
||||
from .rows import RowClasses
|
||||
from .session import Session
|
||||
|
||||
|
||||
modules = dict(
|
||||
postgresql = ['pygresql', 'pg8000', 'psycopg2', 'psycopg3']
|
||||
)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, tables=None, **kwargs):
|
||||
self.tables = tables
|
||||
self.cfg = Config(**kwargs)
|
||||
self.sessions = DotDict()
|
||||
def __init__(self, dbtype='sqlite', **kwargs):
|
||||
self._connect_args = [dbtype, kwargs]
|
||||
self.db = None
|
||||
self.cache = None
|
||||
self.config = DotDict()
|
||||
self.meta = MetaData()
|
||||
self.classes = RowClasses(*kwargs.get('row_classes', []))
|
||||
self.cache = None
|
||||
|
||||
self.session_class = kwargs.get('session_class', Session)
|
||||
self.sessions = {}
|
||||
|
||||
self.open()
|
||||
|
||||
|
||||
def _setup_cache(self):
|
||||
self.cache = DotDict({table: LruCache() for table in self.get_tables()})
|
||||
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.get_session(False)
|
||||
return self.session_class(self)
|
||||
|
||||
|
||||
@property
|
||||
def session_trans(self):
|
||||
return self.get_session(True)
|
||||
def dbtype(self):
|
||||
return self.db.url.get_backend_name()
|
||||
|
||||
|
||||
def connect(self, sid, session):
|
||||
if len(self.sessions) >= self.cfg.max_connections:
|
||||
raise error.MaxConnectionsError(f'Cannot start a new session with id {sid}. Reach max connection count of {self.cfg.max_connections}.')
|
||||
|
||||
self.sessions[sid] = session
|
||||
|
||||
|
||||
def disconnect(self, sid):
|
||||
self.sessions[sid].disconnect()
|
||||
del self.sessions[sid]
|
||||
|
||||
|
||||
def disconnect_all(self):
|
||||
sids = []
|
||||
|
||||
for sid in self.sessions.keys():
|
||||
sids.append(sid)
|
||||
|
||||
for sid in sids:
|
||||
self.disconnect(sid)
|
||||
|
||||
|
||||
def get_session(self, trans=True):
|
||||
session = Session(self, trans)
|
||||
self.sessions[session.id] = session
|
||||
return session
|
||||
|
||||
|
||||
def execute(self, *args):
|
||||
with self.session as s:
|
||||
s.execute(*args)
|
||||
|
||||
|
||||
def load_tables(self, path):
|
||||
self.tables = Tables.new_from_json_file(path)
|
||||
|
||||
|
||||
def pre_setup(self):
|
||||
if self.cfg.type != 'postgresql':
|
||||
izzylog.verbose(f'Database not supported for pre_setup: {self.cfg.type}')
|
||||
return
|
||||
|
||||
original_database = self.cfg.name
|
||||
self.cfg.name = 'postgres'
|
||||
|
||||
with self.session as s:
|
||||
s.conn.autocommit = True
|
||||
s.rollback()
|
||||
|
||||
if original_database not in s.get_databases():
|
||||
#s.execute('SET AUTOCOMMIT = OFF')
|
||||
s.cursor.execute(f'CREATE DATABASE {original_database}')
|
||||
|
||||
s.conn.autocommit = False
|
||||
|
||||
self.cfg.name = original_database
|
||||
|
||||
|
||||
def set_row_class(self, table, row_class):
|
||||
pass
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self, db, trans):
|
||||
self.id = random_gen()
|
||||
self.db = db
|
||||
self.cfg = db.cfg
|
||||
self.trans = trans
|
||||
self.trans_state = False
|
||||
self.conn = None
|
||||
self.cursor = None
|
||||
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
izzylog.verbose('Deleting session:', self.id)
|
||||
except ModuleNotFoundError:
|
||||
if izzylog.get_config('level') >= 20:
|
||||
print('[izzylib] VERBOSE: Deleting session:', self.id)
|
||||
|
||||
self.db.sessions.pop(self.id, None)
|
||||
|
||||
if self.conn:
|
||||
self.disconnect()
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
|
||||
if self.trans:
|
||||
self.begin()
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
if exc_traceback:
|
||||
self.rollback()
|
||||
|
||||
else:
|
||||
self.commit()
|
||||
|
||||
self.disconnect()
|
||||
self.db.disconnect(self.id)
|
||||
|
||||
|
||||
def connect(self):
|
||||
if self.conn:
|
||||
return
|
||||
|
||||
self.db.connect(self.id, self)
|
||||
|
||||
if self.cfg.type == 'sqlite':
|
||||
self.conn = self.cfg.module.connect(self.cfg.name, self.cfg.timeout, check_same_thread=True)
|
||||
|
||||
elif self.cfg.type == 'postgresql':
|
||||
options = dict(
|
||||
host = self.cfg.host or '/var/run/postgresql',
|
||||
port = self.cfg.port or 5432,
|
||||
database = self.cfg.name or 'postgresql',
|
||||
user = self.cfg.username or getuser(),
|
||||
password = self.cfg.password,
|
||||
)
|
||||
|
||||
if self.cfg.mod_name == 'pg8000':
|
||||
if options['host'] in [None, '/var/run/postgresql']:
|
||||
port = options.pop('port')
|
||||
options['unix_sock'] = options.pop('host') + f'/.s.PGSQL.{port}'
|
||||
|
||||
## SSL is a pain in the ass tbh. Gonna deal with this later
|
||||
#if self.cfg.mod_name == 'pg8000':
|
||||
#options['sslmode'] = self.cfg.ssl
|
||||
#options['ssl_context'] = self.cfg.ssl_context
|
||||
|
||||
#elif self.cfg.mod_name == 'psycopg2':
|
||||
#options['sslcert'] = self.cfg.ssl_cert
|
||||
#options['sslkey'] = self.cfg.ssl_key
|
||||
|
||||
self.conn = self.cfg.module.connect(**options)
|
||||
|
||||
else:
|
||||
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
|
||||
|
||||
try:
|
||||
self.conn.autocommit = False
|
||||
except:
|
||||
izzylog.verbose('Failed to turn off autocommit')
|
||||
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
|
||||
def disconnect(self):
|
||||
if not self.conn:
|
||||
return
|
||||
|
||||
self.cursor.close()
|
||||
self.conn.close()
|
||||
|
||||
self.cursor = None
|
||||
self.conn = None
|
||||
|
||||
|
||||
def begin(self):
|
||||
if self.trans_state:
|
||||
return
|
||||
|
||||
#self.conn.begin()
|
||||
self.execute('BEGIN TRANSACTION')
|
||||
self.trans_state = True
|
||||
|
||||
|
||||
def rollback(self):
|
||||
if not self.trans_state:
|
||||
return
|
||||
|
||||
self.conn.rollback()
|
||||
#self.execute('ROLLBACK TRANSACTION')
|
||||
self.trans_state = False
|
||||
|
||||
|
||||
def commit(self):
|
||||
if not self.trans_state:
|
||||
return
|
||||
|
||||
self.conn.commit()
|
||||
#self.execute('COMMIT TRANSACTION')
|
||||
self.trans_state = False
|
||||
|
||||
|
||||
## data management functions
|
||||
def execute(self, string, values=[]):
|
||||
if any(map(string.lower().startswith, ['insert', 'update', 'remove', 'create', 'drop'])) and not self.trans_state:
|
||||
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
|
||||
|
||||
self.cursor.execute(string, values)
|
||||
return self.cursor
|
||||
|
||||
|
||||
def fetch(self, table, single=True, **kwargs):
|
||||
rows = []
|
||||
data = Select(table, type=self.cfg.type, **kwargs).exec(self)
|
||||
|
||||
for line in data:
|
||||
row = Row(table, self.cursor.description, line)
|
||||
|
||||
if single:
|
||||
return row
|
||||
|
||||
rows.append(row)
|
||||
|
||||
return rows if not single else None
|
||||
|
||||
|
||||
def search(self, table, **kwargs):
|
||||
return self.fetch(table, single=False, **kwargs)
|
||||
|
||||
|
||||
def insert(self, table, **kwargs):
|
||||
if not self.trans_state:
|
||||
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
|
||||
|
||||
Insert(table, type=self.cfg.type, **kwargs).exec(self)
|
||||
return self.fetch(table, **kwargs)
|
||||
|
||||
|
||||
def update(self, table, rowid, **kwargs):
|
||||
if not self.trans_state:
|
||||
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
|
||||
|
||||
Update(table, rowid, type=self.cfg.type, **kwargs).exec(self)
|
||||
return self.fetch(table, id=rowid)
|
||||
|
||||
|
||||
def delete(self, table, **kwargs):
|
||||
if not self.trans_state:
|
||||
raise error.NoTransactionError('Please start a transaction with "session.begin()" before using a write command.')
|
||||
|
||||
Delete(table, type=self.cfg.type, **kwargs).exec(self)
|
||||
|
||||
|
||||
## helper functions
|
||||
def get_columns(self, table):
|
||||
if table not in self.get_tables():
|
||||
raise KeyError(f'Not an existing table: {table}')
|
||||
|
||||
if self.cfg.type == 'sqlite':
|
||||
rows = self.execute(f'PRAGMA table_info({table})')
|
||||
return [row[1] for row in rows]
|
||||
|
||||
elif self.cfg.type == 'postgresql':
|
||||
rows = self.execute(f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table}'")
|
||||
return [row[0] for row in rows]
|
||||
|
||||
else:
|
||||
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
|
||||
@property
|
||||
def table(self):
|
||||
return DotDict(self.meta.tables)
|
||||
|
||||
|
||||
def get_tables(self):
|
||||
if self.cfg.type == 'sqlite':
|
||||
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
|
||||
return list(self.table.keys())
|
||||
|
||||
elif self.cfg.type == 'postgresql':
|
||||
rows = self.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name")
|
||||
|
||||
def get_columns(self, table):
|
||||
return list(col.name for col in self.table[table].columns)
|
||||
|
||||
|
||||
def new_session(self, trans=True):
|
||||
return self.session_class(self, trans=trans)
|
||||
|
||||
|
||||
## Leaving link to example code for read-only sqlite for later use
|
||||
## https://github.com/pudo/dataset/issues/136#issuecomment-128693122
|
||||
def open(self):
|
||||
dbtype, kwargs = self._connect_args
|
||||
engine_kwargs = {
|
||||
'future': True,
|
||||
#'maxconnections': 25
|
||||
}
|
||||
|
||||
if not kwargs.get('name'):
|
||||
raise KeyError('Database "name" is not set')
|
||||
|
||||
if dbtype == 'sqlite':
|
||||
database = kwargs['name']
|
||||
|
||||
if nfs_check(database):
|
||||
izzylog.warning('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
|
||||
|
||||
engine_kwargs['connect_args'] = {'check_same_thread': False}
|
||||
|
||||
elif dbtype == 'postgresql':
|
||||
ssl_context = kwargs.get('ssl')
|
||||
|
||||
if ssl_context:
|
||||
engine_kwargs['ssl_context'] = ssl_context
|
||||
|
||||
if not kwargs.get('host'):
|
||||
kwargs['unix_socket'] = '/var/run/postgresql'
|
||||
|
||||
if kwargs.get('host') and Path(kwargs['host']).exists():
|
||||
kwargs['unix_socket'] = kwargs.pop('host')
|
||||
|
||||
else:
|
||||
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
|
||||
raise TypeError(f'Unsupported database type: {dbtype}')
|
||||
|
||||
return [row[0] for row in rows]
|
||||
self.config.update(kwargs)
|
||||
|
||||
|
||||
def get_databases(self):
|
||||
if self.cfg.type == 'sqlite':
|
||||
izzylog.verbose('This function is useless with sqlite')
|
||||
return
|
||||
|
||||
elif self.cfg.type == 'postgresql':
|
||||
databases = [row[0] for row in self.execute('SELECT datname FROM pg_database')]
|
||||
if dbtype == 'sqlite':
|
||||
url = URL.create(
|
||||
drivername='sqlite',
|
||||
database=kwargs.pop('name')
|
||||
)
|
||||
|
||||
else:
|
||||
raise error.DatabaseNotSupportedError(f'Database not supported yet: {self.cfg.type}')
|
||||
try:
|
||||
for module in modules[dbtype]:
|
||||
if pkgutil.get_loader(module):
|
||||
dbtype = f'{dbtype}+{module}'
|
||||
|
||||
return databases
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
url = URL.create(
|
||||
drivername = dbtype,
|
||||
username = kwargs.pop('user', None),
|
||||
password = kwargs.pop('password', None),
|
||||
host = kwargs.pop('host', None),
|
||||
port = kwargs.pop('port', None),
|
||||
database = kwargs.pop('name'),
|
||||
)
|
||||
|
||||
self.db = create_engine(url, **engine_kwargs)
|
||||
self.meta = MetaData()
|
||||
self.meta.reflect(bind=self.db, resolve_fks=True, views=True)
|
||||
self._setup_cache()
|
||||
|
||||
|
||||
def cursor_description(self):
|
||||
return [row[0] for row in self.cursor.description]
|
||||
def close(self):
|
||||
for sid in list(self.sessions):
|
||||
self.sessions[sid].commit()
|
||||
self.sessions[sid].close()
|
||||
|
||||
self.config = DotDict()
|
||||
self.cache = DotDict()
|
||||
self.meta = None
|
||||
self.db = None
|
||||
|
||||
|
||||
def setup_database(self):
|
||||
if not self.db.tables:
|
||||
raise ValueError('Tables have not been specified.')
|
||||
def load_tables(self, **tables):
|
||||
self.meta = MetaData()
|
||||
|
||||
current_tables = self.get_tables()
|
||||
for name, columns in tables.items():
|
||||
Table(name, self.meta, *columns)
|
||||
|
||||
for name, table in self.db.tables.items():
|
||||
if name in current_tables:
|
||||
izzylog.verbose(f'Skipping table creation since it already exists: {name}')
|
||||
continue
|
||||
|
||||
izzylog.verbose(f'Creating table: {name}')
|
||||
self.execute(table.build(self.cfg.type))
|
||||
self._setup_cache()
|
||||
|
||||
|
||||
class Row(DotDict):
|
||||
def __init__(self, table, keys, values):
|
||||
self._db = None
|
||||
self._table = table
|
||||
def create_database(self, tables={}):
|
||||
if tables:
|
||||
self.load_tables(**tables)
|
||||
|
||||
super().__init__()
|
||||
if self.db.url.get_backend_name() == 'postgresql':
|
||||
predb = create_engine(self.db.engine_string.replace(self.config.name, 'postgres', -1), future=True)
|
||||
conn = predb.connect()
|
||||
|
||||
for idx, key in enumerate([key[0] for key in keys]):
|
||||
self[key] = values[idx]
|
||||
try:
|
||||
conn.execute(text(f'CREATE DATABASE {database}'))
|
||||
|
||||
except ProgrammingError:
|
||||
'The database already exists, so just move along'
|
||||
|
||||
except Exception as e:
|
||||
conn.close()
|
||||
raise e from None
|
||||
|
||||
conn.close()
|
||||
|
||||
self.meta.create_all(bind=self.db)
|
||||
|
||||
|
||||
def update(self, data):
|
||||
for k, v in data.items():
|
||||
if k not in self:
|
||||
raise KeyError(f'Not a column for {self._table}')
|
||||
def drop_tables(self, *tables):
|
||||
if not tables:
|
||||
raise ValueError('No tables specified')
|
||||
|
||||
self[k] = v
|
||||
self.meta.drop_all(bind=self.db, tables=tables)
|
||||
|
||||
|
||||
def delete(self):
|
||||
with self._db.session as s:
|
||||
s.delete(self._table, id=self.id)
|
||||
|
||||
|
||||
def update(self, **kwargs):
|
||||
self.update(kwargs)
|
||||
|
||||
with self._db.session as s:
|
||||
s.update(self._table, id=self.id, **kwargs)
|
||||
def execute(self, string, **kwargs):
|
||||
with self.session as s:
|
||||
s.execute(string, **kwargs)
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
class MaxConnectionsError(Exception):
|
||||
'raise when the max amount of connections has been reached'
|
||||
|
||||
|
||||
class NoTransactionError(Exception):
|
||||
'raise when a write command is executed outside a transaction'
|
||||
|
||||
|
||||
class DatabaseNotSupportedError(Exception):
|
||||
'raise when the action being performed is not supported by the database in use'
|
|
@ -1,508 +0,0 @@
|
|||
import json, sys, threading, time
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from sqlalchemy import create_engine, ForeignKey, MetaData, Table
|
||||
from sqlalchemy import Column as sqlalchemy_column, types as Types
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
from izzylib import (
|
||||
LruCache,
|
||||
DotDict,
|
||||
Path,
|
||||
random_gen,
|
||||
nfs_check,
|
||||
izzylog
|
||||
)
|
||||
|
||||
SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')})
|
||||
|
||||
|
||||
class SqlDatabase:
|
||||
def __init__(self, dbtype='sqlite', tables={}, **kwargs):
|
||||
self.db = self.__create_engine(dbtype, kwargs)
|
||||
self.table = None
|
||||
self.table_names = None
|
||||
self.classes = kwargs.get('row_classes', CustomRows())
|
||||
self.cache = None
|
||||
|
||||
self.session_class = kwargs.get('session_class', SqlSession)
|
||||
self.sessions = {}
|
||||
|
||||
self.setup_tables(tables)
|
||||
self.setup_cache()
|
||||
|
||||
|
||||
## Leaving link to example code for read-only sqlite for later use
|
||||
## https://github.com/pudo/dataset/issues/136#issuecomment-128693122
|
||||
def __create_engine(self, dbtype, kwargs):
|
||||
engine_args = []
|
||||
engine_kwargs = {}
|
||||
|
||||
if not kwargs.get('name'):
|
||||
raise KeyError('Database "name" is not set')
|
||||
|
||||
engine_string = dbtype + '://'
|
||||
|
||||
if dbtype == 'sqlite':
|
||||
database = kwargs['name']
|
||||
|
||||
if nfs_check(database):
|
||||
izzylog.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
|
||||
|
||||
engine_string += f'/{database}'
|
||||
engine_kwargs['connect_args'] = {'check_same_thread': False}
|
||||
|
||||
elif dbtype == 'postgresql':
|
||||
ssl_context = kwargs.get('ssl')
|
||||
|
||||
if ssl_context:
|
||||
engine_kwargs['ssl_context'] = ssl_context
|
||||
|
||||
else:
|
||||
user = kwargs.get('user')
|
||||
password = kwargs.get('pass')
|
||||
host = kwargs.get('host', '/var/run/postgresql')
|
||||
port = kwargs.get('port', 5432)
|
||||
name = kwargs.get('name', 'postgres')
|
||||
maxconn = kwargs.get('maxconnections', 25)
|
||||
|
||||
if user:
|
||||
if password:
|
||||
engine_string += f'{user}:{password}@'
|
||||
else:
|
||||
engine_string += user + '@'
|
||||
|
||||
if host == '/var/run/postgresql':
|
||||
engine_string += f'/{name}:{port}/{name}'
|
||||
|
||||
else:
|
||||
engine_string += f'{host}:{port}/{name}'
|
||||
|
||||
return create_engine(engine_string, *engine_args, **engine_kwargs)
|
||||
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.session_class(self)
|
||||
|
||||
|
||||
def close(self):
|
||||
self.setup_cache()
|
||||
|
||||
|
||||
def setup_cache(self):
|
||||
self.cache = DotDict({table: LruCache() for table in self.table_names})
|
||||
|
||||
|
||||
def create_tables(self, *tables):
|
||||
if not tables:
|
||||
raise ValueError('No tables specified')
|
||||
|
||||
new_tables = [self.table[table] for table in tables]
|
||||
self.table.meta.create_all(bind=self.db, tables=new_tables)
|
||||
|
||||
|
||||
def create_database(self):
|
||||
if self.db.url.get_backend_name() == 'postgresql':
|
||||
predb = create_engine(db.engine_string.replace(config.db.name, 'postgres', -1))
|
||||
conn = predb.connect()
|
||||
conn.execute('commit')
|
||||
|
||||
try:
|
||||
conn.execute(f'CREATE DATABASE {config.db.name}')
|
||||
|
||||
except ProgrammingError:
|
||||
'The database already exists, so just move along'
|
||||
|
||||
except Exception as e:
|
||||
conn.close()
|
||||
raise e from None
|
||||
|
||||
conn.close()
|
||||
|
||||
self.table.meta.create_all(self.db)
|
||||
|
||||
|
||||
def setup_tables(self, tables):
|
||||
self.table = Tables(self, tables)
|
||||
self.table_names = tables.keys()
|
||||
|
||||
|
||||
def execute(self, string, values=[]):
|
||||
with self.session as s:
|
||||
s.execute(string, values)
|
||||
|
||||
|
||||
class SqlSession(object):
|
||||
def __init__(self, db):
|
||||
self.closed = False
|
||||
|
||||
self.database = db
|
||||
self.classes = db.classes
|
||||
self.session = sessionmaker(bind=db.db)()
|
||||
self.table = db.table
|
||||
self.cache = db.cache
|
||||
|
||||
# session aliases
|
||||
self.s = self.session
|
||||
self.begin = self.s.begin
|
||||
self.commit = self.s.commit
|
||||
self.rollback = self.s.rollback
|
||||
self.query = self.s.query
|
||||
self.execute = self.s.execute
|
||||
|
||||
# remove in the future
|
||||
self.db = db
|
||||
|
||||
self._setup()
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
self.open()
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exctype, value, tb):
|
||||
if tb:
|
||||
self.rollback()
|
||||
|
||||
self.close()
|
||||
|
||||
|
||||
def open(self):
|
||||
self.sessionid = random_gen(10)
|
||||
self.db.sessions[self.sessionid] = self
|
||||
|
||||
|
||||
def close(self):
|
||||
self.commit()
|
||||
self.s.close()
|
||||
self.closed = True
|
||||
|
||||
del self.db.sessions[self.sessionid]
|
||||
|
||||
self.sessionid = None
|
||||
|
||||
|
||||
def _setup(self):
|
||||
pass
|
||||
|
||||
|
||||
@property
|
||||
def dirty(self):
|
||||
return any([self.s.new, self.s.dirty, self.s.deleted])
|
||||
|
||||
|
||||
def count(self, table_name, **kwargs):
|
||||
return self.query(self.table[table_name]).filter_by(**kwargs).count()
|
||||
|
||||
|
||||
def fetch(self, table_name, single=True, orderby=None, orderdir='asc', **kwargs):
|
||||
table = self.table[table_name]
|
||||
RowClass = self.classes.get(table_name.capitalize())
|
||||
|
||||
query = self.query(table).filter_by(**kwargs)
|
||||
|
||||
if not orderby:
|
||||
rows = query.all()
|
||||
|
||||
else:
|
||||
if orderdir == 'asc':
|
||||
rows = query.order_by(getattr(table.c, orderby).asc()).all()
|
||||
|
||||
elif orderdir == 'desc':
|
||||
rows = query.order_by(getattr(table.c, orderby).desc()).all()
|
||||
|
||||
else:
|
||||
raise ValueError(f'Unsupported order direction: {orderdir}')
|
||||
|
||||
if single:
|
||||
return RowClass(table_name, rows[0], self) if len(rows) > 0 else None
|
||||
|
||||
return [RowClass(table_name, row, self) for row in rows]
|
||||
|
||||
|
||||
def search(self, *args, **kwargs):
|
||||
kwargs.pop('single', None)
|
||||
return self.fetch(*args, single=False, **kwargs)
|
||||
|
||||
|
||||
def insert(self, table_name, return_row=False, **kwargs):
|
||||
row = self.fetch(table_name, **kwargs)
|
||||
|
||||
if row:
|
||||
row.update_session(self, **kwargs)
|
||||
return
|
||||
|
||||
table = self.table[table_name]
|
||||
|
||||
if getattr(table, 'timestamp', None) and not kwargs.get('timestamp'):
|
||||
kwargs['timestamp'] = datetime.now()
|
||||
|
||||
self.execute(table.insert().values(**kwargs))
|
||||
|
||||
if return_row:
|
||||
return self.fetch(table_name, **kwargs)
|
||||
|
||||
|
||||
def update(self, table=None, rowid=None, row=None, return_row=False, **data):
|
||||
if row:
|
||||
if not getattr(row, '_table_name', None):
|
||||
print(row)
|
||||
print(dir(row))
|
||||
rowid = row.id
|
||||
table = row._table_name
|
||||
|
||||
if not rowid or not table:
|
||||
raise ValueError('Missing row ID or table')
|
||||
|
||||
tclass = self.table[table]
|
||||
self.execute(tclass.update().where(tclass.c.id == rowid).values(**data))
|
||||
|
||||
if return_row:
|
||||
return self.fetch(table, id=rowid)
|
||||
|
||||
|
||||
def remove(self, table=None, rowid=None, row=None):
|
||||
if row:
|
||||
rowid = row.id
|
||||
table = row._table_name
|
||||
|
||||
if not rowid or not table:
|
||||
raise ValueError('Missing row ID or table')
|
||||
|
||||
self.execute(f'DELETE FROM {table} WHERE id={rowid}')
|
||||
|
||||
|
||||
def drop_table(self, name):
|
||||
if name not in self.get_tables():
|
||||
raise KeyError(f'Table does not exist: {name}')
|
||||
|
||||
self.execute(f'DROP TABLE {name}')
|
||||
|
||||
|
||||
def drop_tables(self):
|
||||
tables = self.get_tables()
|
||||
|
||||
for table in tables:
|
||||
self.drop_table(table)
|
||||
|
||||
|
||||
def get_columns(self, table):
|
||||
if table not in self.get_tables():
|
||||
raise KeyError(f'Not an existing table: {table}')
|
||||
|
||||
rows = self.execute('PRAGMA table_info(user)')
|
||||
return [row[1] for row in rows]
|
||||
|
||||
|
||||
def get_tables(self):
|
||||
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
|
||||
return [row[0] for row in rows]
|
||||
|
||||
|
||||
def append_column(self, table, column):
|
||||
if column.name in self.get_columns(table):
|
||||
logging.warning(f'Table "{table}" already has column "{column.name}"')
|
||||
return
|
||||
|
||||
self.execute(f'ALTER TABLE {table} ADD COLUMN {column.compile()}')
|
||||
|
||||
|
||||
def append_column2(self, tbl, col):
|
||||
table = self.table[tbl]
|
||||
|
||||
try:
|
||||
column = getattr(table.c, col)
|
||||
|
||||
except AttributeError:
|
||||
izzylog.error(f'Table "{tbl}" does not have column "{col}"')
|
||||
return
|
||||
|
||||
columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]
|
||||
|
||||
if col in self.get_columns(tbl):
|
||||
izzylog.info(f'Column "{col}" already exists')
|
||||
return
|
||||
|
||||
sql = f'ALTER TABLE {tbl} ADD COLUMN {col} {column.type}'
|
||||
|
||||
if not column.nullable:
|
||||
sql += ' NOT NULL'
|
||||
|
||||
if column.primary_key:
|
||||
sql += ' PRIMARY KEY'
|
||||
|
||||
if column.unique:
|
||||
sql += ' UNIQUE'
|
||||
|
||||
self.execute(sql)
|
||||
|
||||
|
||||
def remove_column(self, tbl, col):
|
||||
table = self.table[tbl]
|
||||
column = getattr(table, col, None)
|
||||
columns = [row[1] for row in self.execute(f'PRAGMA table_info({tbl})')]
|
||||
|
||||
if col not in columns:
|
||||
izzylog.info(f'Column "{col}" already exists')
|
||||
return
|
||||
|
||||
columns.remove(col)
|
||||
coltext = ', '.join(columns)
|
||||
|
||||
self.execute(f'CREATE TABLE {tbl}_temp AS SELECT {coltext} FROM {tbl}')
|
||||
self.execute(f'DROP TABLE {tbl}')
|
||||
self.execute(f'ALTER TABLE {tbl}_temp RENAME TO {tbl}')
|
||||
|
||||
|
||||
def clear_table(self, table):
|
||||
self.execute(f'DELETE FROM {table}')
|
||||
|
||||
|
||||
class CustomRows(object):
|
||||
def get(self, name):
|
||||
return getattr(self, name, self.Row)
|
||||
|
||||
|
||||
class Row(DotDict):
|
||||
#_filter_columns = lambda self, row: [attr for attr in dir(row) if not attr.startswith('_') and attr != 'metadata']
|
||||
|
||||
|
||||
def __init__(self, table, row, session):
|
||||
super().__init__()
|
||||
|
||||
if row:
|
||||
try:
|
||||
self._update(row._asdict())
|
||||
except:
|
||||
self._update(row)
|
||||
|
||||
self._db = session.db
|
||||
self._table_name = table
|
||||
self._columns = self.keys()
|
||||
|
||||
self.__run__(session)
|
||||
|
||||
|
||||
## Subclass Row and redefine this function
|
||||
def __run__(self, s):
|
||||
pass
|
||||
|
||||
|
||||
def _filter_data(self):
|
||||
data = {k: v for k,v in self.items() if k in self._columns}
|
||||
|
||||
for k,v in self.items():
|
||||
if v.__class__ == DotDict:
|
||||
data[k] = v.asDict()
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def asDict(self):
|
||||
return self._filter_data()
|
||||
|
||||
|
||||
def _update(self, new_data={}, **kwargs):
|
||||
kwargs.update(new_data)
|
||||
|
||||
for k,v in kwargs.items():
|
||||
if type(v) == dict:
|
||||
self[k] = DotDict(v)
|
||||
|
||||
self[k] = v
|
||||
|
||||
|
||||
def delete(self, s=None):
|
||||
if s:
|
||||
return self.delete_session(s)
|
||||
|
||||
with self._db.session as s:
|
||||
return self.delete_session(s)
|
||||
|
||||
|
||||
def delete_session(self, s):
|
||||
return s.remove(table=self._table_name, row=self)
|
||||
|
||||
|
||||
def update(self, dict_data={}, s=None, **data):
|
||||
dict_data.update(data)
|
||||
self._update(dict_data)
|
||||
|
||||
if s:
|
||||
return self.update_session(s, **self._filter_data())
|
||||
|
||||
with self._db.session as s:
|
||||
s.update(row=self, **self._filter_data())
|
||||
|
||||
|
||||
def update_session(self, s, dict_data={}, **data):
|
||||
dict_data.update(data)
|
||||
self._update(dict_data)
|
||||
return s.update(table=self._table_name, row=self, **dict_data)
|
||||
|
||||
|
||||
class Tables(DotDict):
|
||||
def __init__(self, db, tables={}):
|
||||
'"tables" should be a dict with the table names for keys and a list of Columns for values'
|
||||
super().__init__()
|
||||
|
||||
self.db = db
|
||||
self.meta = MetaData()
|
||||
|
||||
for name, table in tables.items():
|
||||
self.__setup_table(name, table)
|
||||
|
||||
|
||||
def __setup_table(self, name, table):
|
||||
columns = [col if type(col) == SqlColumn else SqlColumn(*col.get('args'), **col.get('kwargs')) for col in table]
|
||||
self[name] = Table(name, self.meta, *columns)
|
||||
|
||||
|
||||
class SqlColumn(sqlalchemy_column):
|
||||
def __init__(self, name, stype=None, fkey=None, **kwargs):
|
||||
if not stype and not kwargs:
|
||||
if name == 'id':
|
||||
stype = 'integer'
|
||||
kwargs['primary_key'] = True
|
||||
kwargs['autoincrement'] = True
|
||||
|
||||
elif name == 'timestamp':
|
||||
stype = 'datetime'
|
||||
|
||||
else:
|
||||
raise ValueError('Missing column type and options')
|
||||
|
||||
stype = (stype.lower() if type(stype) == str else stype) or 'string'
|
||||
|
||||
if type(stype) == str:
|
||||
try:
|
||||
stype = SqlTypes[stype.lower()]
|
||||
|
||||
except KeyError:
|
||||
raise KeyError(f'Invalid SQL data type: {stype}')
|
||||
|
||||
options = [name, stype]
|
||||
|
||||
if fkey:
|
||||
options.append(ForeignKey(fkey))
|
||||
|
||||
super().__init__(*options, **kwargs)
|
||||
|
||||
|
||||
def compile(self):
|
||||
sql = f'{self.name} {self.type}'
|
||||
|
||||
if not self.nullable:
|
||||
sql += ' NOT NULL'
|
||||
|
||||
if self.primary_key:
|
||||
sql += ' PRIMARY KEY'
|
||||
|
||||
if self.unique:
|
||||
sql += ' UNIQUE'
|
||||
|
||||
return sql
|
|
@ -1,415 +0,0 @@
|
|||
from datetime import datetime
|
||||
from functools import partial
|
||||
from izzylib import DotDict, Path
|
||||
|
||||
from .types import BaseType, Type
|
||||
|
||||
|
||||
placeholders = dict(
|
||||
sqlite = '?',
|
||||
postgresql = '%s'
|
||||
)
|
||||
|
||||
|
||||
## Data queries
|
||||
class Delete:
|
||||
def __init__(self, table, type='sqlite', **kwargs):
|
||||
self.table = table
|
||||
self.placeholder = placeholders[type]
|
||||
self.keys = []
|
||||
self.values = []
|
||||
|
||||
for k,v in kwargs.items():
|
||||
self.keys.append(k)
|
||||
self.values.append(v)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
self.build(embed_values=True)
|
||||
|
||||
|
||||
def build(self, comp_type='AND', embed_values=False):
|
||||
sql = 'DELETE FROM {table} WHERE {kstring}'
|
||||
|
||||
if not embed_values:
|
||||
kstring = f' {comp_type.upper()} '.join([f'{k} = {self.placeholder}' for k in self.keys])
|
||||
return sql.format(table=self.table, kstring=kstring), self.values
|
||||
|
||||
values = []
|
||||
|
||||
for idx, value in enumerate(self.values):
|
||||
if type(value) == str:
|
||||
values.append(f"{self.keys[idx]} = '{value}'")
|
||||
|
||||
else:
|
||||
values.append(f"{self.keys[idx]} = {value}")
|
||||
|
||||
kstring = ','.join(values)
|
||||
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
|
||||
|
||||
|
||||
def exec(self, session, comp_type='AND'):
|
||||
return session.execute(*self.build(comp_type))
|
||||
|
||||
|
||||
class Insert:
|
||||
def __init__(self, table, type='sqlite', **kwargs):
|
||||
self.table = table
|
||||
self.placeholder = placeholders[type]
|
||||
self.keys = []
|
||||
self.values = []
|
||||
|
||||
for k, v in kwargs.items():
|
||||
self.keys.append(k)
|
||||
self.values.append(v)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.build(embed_values=True)
|
||||
|
||||
|
||||
def build(self, embed_values=False):
|
||||
kstring = ','.join(self.keys)
|
||||
|
||||
if not embed_values:
|
||||
vstring = ','.join([self.placeholder for k in self.keys])
|
||||
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})', self.values
|
||||
|
||||
else:
|
||||
vstring = ','.join(self.values)
|
||||
return f'INSERT INTO {self.table} ({kstring}) VALUES({vstring})'
|
||||
|
||||
|
||||
def exec(self, session):
|
||||
return session.execute(*self.build())
|
||||
|
||||
|
||||
class Select:
|
||||
def __init__(self, table, columns=[], type='sqlite', **kwargs):
|
||||
self.placeholder = placeholders[type]
|
||||
self.columns = columns
|
||||
self.table = table
|
||||
self.where = []
|
||||
self.where_build = []
|
||||
self._order = []
|
||||
self.keys = []
|
||||
self.values = []
|
||||
|
||||
self.equals = partial(self.__comparison, '=')
|
||||
self.less = partial(self.__comparison, '<')
|
||||
self.greater = partial(self.__comparison, '>')
|
||||
self.like = partial(self.__comparison, 'LIKE')
|
||||
|
||||
for k,v in kwargs.items():
|
||||
self.equals(k, v)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.build(embed_values=True)
|
||||
|
||||
|
||||
def __comparison(self, comp, key, value):
|
||||
self.values.append(value)
|
||||
self.keys.append(key)
|
||||
self.where.append(f'{key} {comp.upper()} {self.placeholder}')
|
||||
self.where_build.append(f"{key} {comp.upper()} '{value}'" if type(key) == str else f"{key} {comp.upper()} {value}")
|
||||
return self
|
||||
|
||||
|
||||
def order(self, column, asc=True):
|
||||
self._order = [column, 'ASC' if asc else 'DESC']
|
||||
return self
|
||||
|
||||
|
||||
def build(self, comp_type='AND', embed_values=False):
|
||||
if not self.columns:
|
||||
cols = '*'
|
||||
|
||||
else:
|
||||
cols = ','.join('columns')
|
||||
|
||||
sql_query = f'SELECT {cols} FROM {self.table}'
|
||||
|
||||
if self.where:
|
||||
where = f' {comp_type.upper()} '.join(self.where if not embed_values else self.where_build)
|
||||
sql_query += f' WHERE {where}'
|
||||
|
||||
if self._order:
|
||||
col, order = self._order
|
||||
sql_query += f' ORDER BY {col} {order}'
|
||||
|
||||
if embed_values:
|
||||
return sql_query
|
||||
|
||||
return sql_query, self.values
|
||||
|
||||
|
||||
def exec(self, session, comp_type='AND'):
|
||||
return session.execute(*self.build(comp_type))
|
||||
|
||||
|
||||
class Update:
|
||||
def __init__(self, table, rowid, type='sqlite', **kwargs):
|
||||
self.placeholder = placeholders[type]
|
||||
self.table = table
|
||||
self.rowid = rowid
|
||||
self.keys = []
|
||||
self.values = []
|
||||
|
||||
for k,v in kwargs.items():
|
||||
self.keys.append(k)
|
||||
self.values.append(v)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.build(embed_values=True)
|
||||
|
||||
|
||||
def build(self, embed_values=False):
|
||||
sql = 'UPDATE {table} SET {kstring} WHERE id={rowid}'
|
||||
|
||||
if not embed_values:
|
||||
kstring = ','.join([f'{k} = {self.placeholder}' for k in self.keys])
|
||||
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid), self.values
|
||||
|
||||
values = []
|
||||
|
||||
for idx, value in enumerate(self.values):
|
||||
if type(value) == str:
|
||||
values.append(f"{self.keys[idx]} = '{value}'")
|
||||
|
||||
else:
|
||||
values.append(f"{self.keys[idx]} = {value}")
|
||||
|
||||
kstring = ','.join(values)
|
||||
return sql.format(table=self.table, kstring=kstring, rowid=self.rowid)
|
||||
|
||||
|
||||
def exec(self, session):
|
||||
return session.execute(*self.build())
|
||||
|
||||
|
||||
## Database objects
|
||||
class Column:
|
||||
def __init__(self, name, type='STRING', unique=False, nullable=True, default=None, primary_key=False, autoincrement=False, foreign_key=None):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.nullable = nullable
|
||||
self.default = default
|
||||
self.primary_key = primary_key
|
||||
self.autoincrement = autoincrement
|
||||
self.unique = unique
|
||||
|
||||
if any(map(isinstance, [foreign_key], [list, tuple, set])):
|
||||
self.foreign_key = foreign_key
|
||||
|
||||
else:
|
||||
self.foreign_key = foreign_key.split('.') if foreign_key else None
|
||||
|
||||
if autoincrement:
|
||||
self.primary_key = True
|
||||
self.type = Type['INTEGER']
|
||||
|
||||
if isinstance(self.type, BaseType):
|
||||
self.type = self.type.name
|
||||
|
||||
else:
|
||||
if self.type.upper() in Type.keys():
|
||||
self.type = self.type.upper()
|
||||
|
||||
else:
|
||||
raise TypeError(f'Invalid SQL type: {self.type}')
|
||||
|
||||
if foreign_key and len(self.foreign_key) != 2:
|
||||
raise ValueError('Invalid foreign key. Must be in the format "table.column".')
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.build()
|
||||
|
||||
|
||||
def build(self, dbtype='sqlite'):
|
||||
if dbtype == 'postgresql':
|
||||
if self.type.lower() == 'string':
|
||||
self.type = 'TEXT'
|
||||
|
||||
elif self.type.lower() == 'datetime':
|
||||
self.type = 'TIMESTAMPTZ'
|
||||
|
||||
if self.autoincrement:
|
||||
self.type = 'SERIAL'
|
||||
self.autoincrement = False
|
||||
|
||||
sql = f'{self.name} {self.type}'
|
||||
|
||||
if self.primary_key:
|
||||
sql += ' PRIMARY KEY'
|
||||
|
||||
if self.autoincrement:
|
||||
sql += ' AUTOINCREMENT'
|
||||
|
||||
if self.unique:
|
||||
sql += ' UNIQUE'
|
||||
|
||||
if not self.nullable:
|
||||
sql += ' NOT NULL'
|
||||
|
||||
if self.default:
|
||||
def_type = type(self.default)
|
||||
|
||||
if self.default == 'CURRENT_TIMESTAMP':
|
||||
if dbtype == 'sqlite':
|
||||
sql += " DEFAULT (datetime('now', 'localtime'))"
|
||||
|
||||
elif dbtype == 'postgresql':
|
||||
sql += ' DEFAULT now()'
|
||||
|
||||
else:
|
||||
sql += f' DEFAULT {datetime.now().timestamp()}'
|
||||
|
||||
elif def_type == str:
|
||||
sql += f" DEFAULT '{self.default}'"
|
||||
|
||||
elif def_type in [int, float]:
|
||||
sql += f' DEFAULT {self.default}'
|
||||
|
||||
elif def_type == bool and dbtype == 'sqlite':
|
||||
sql += f' DEFAULT {int(self.default)}'
|
||||
|
||||
else:
|
||||
sql += f' DEFAULT {self.default}'
|
||||
|
||||
print(sql)
|
||||
return sql
|
||||
|
||||
|
||||
def json(self):
|
||||
return DotDict({
|
||||
'type': self.type,
|
||||
'nullable': self.nullable,
|
||||
'default': self.default,
|
||||
'primary_key': self.primary_key,
|
||||
'autoincrement': self.autoincrement,
|
||||
'unique': self.unique,
|
||||
'foreign_key': self.foreign_key
|
||||
})
|
||||
|
||||
|
||||
class Table(DotDict):
|
||||
def __init__(self, name, *columns):
|
||||
super().__init__()
|
||||
self._name = name
|
||||
self._foreign_keys = {}
|
||||
|
||||
self.add_column(Column('id', autoincrement=True))
|
||||
|
||||
for column in columns:
|
||||
self.add_column(column)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.build()
|
||||
|
||||
|
||||
# this'll be useful later
|
||||
def __call__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
|
||||
def add_column(self, column):
|
||||
self[column.name] = column
|
||||
|
||||
if column.foreign_key:
|
||||
self._foreign_keys[column.name] = column.foreign_key
|
||||
|
||||
|
||||
def build(self, dbtype='sqlite'):
|
||||
column_string = ',\n'.join([f'\t{col.build(dbtype)}' for col in self.values()])
|
||||
|
||||
if self._foreign_keys:
|
||||
column_string += ',\n'
|
||||
column_string += ',\n'.join([f'\tFOREIGN KEY ({column}) REFERENCES {key[0]} ({key[1]})' for column, key in self._foreign_keys.items()])
|
||||
|
||||
return f'''CREATE TABLE {self.name} (
|
||||
{column_string}
|
||||
);'''
|
||||
|
||||
|
||||
def json(self):
|
||||
data = {}
|
||||
|
||||
for name, column in self.items():
|
||||
data[name] = column.json()
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class Tables(DotDict):
|
||||
def __init__(self, *tables, data={}):
|
||||
super().__init__()
|
||||
|
||||
for table in tables:
|
||||
self.add_table(table)
|
||||
|
||||
if data:
|
||||
self.from_dict(data)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.build()
|
||||
|
||||
|
||||
@classmethod
|
||||
def new_from_json_file(cls, path):
|
||||
return cls(data=DotDict.new_from_json_file(path))
|
||||
|
||||
|
||||
def add_table(self, table):
|
||||
self[table.name] = table
|
||||
|
||||
|
||||
def build(self):
|
||||
return '\n\n'.join([str(table) for table in self.values()])
|
||||
|
||||
|
||||
def load_json(self, path):
|
||||
data = DotDict()
|
||||
data.load_json(path)
|
||||
|
||||
self.from_dict(data)
|
||||
|
||||
|
||||
def save_json(self, path, indent='\t'):
|
||||
self.to_dict().save_json(path, indent=indent)
|
||||
|
||||
|
||||
def from_dict(self, data):
|
||||
for name, columns in data.items():
|
||||
table = Table(name)
|
||||
|
||||
for col, kwargs in columns.items():
|
||||
table.add_column(Column(col,
|
||||
type = kwargs.get('type', 'STRING'),
|
||||
nullable = kwargs.get('nullable', True),
|
||||
default = kwargs.get('default'),
|
||||
primary_key = kwargs.get('primary_key', False),
|
||||
autoincrement = kwargs.get('autoincrement', False),
|
||||
unique = kwargs.get('unique', False),
|
||||
foreign_key = kwargs.get('foreign_key')
|
||||
))
|
||||
|
||||
self.add_table(table)
|
||||
|
||||
|
||||
def to_dict(self):
|
||||
data = DotDict()
|
||||
|
||||
for name, table in self.items():
|
||||
data[name] = table.json()
|
||||
|
||||
return data
|
|
@ -1,19 +0,0 @@
|
|||
from izzylib import DotDict
|
||||
|
||||
|
||||
class DbRow(DotDict):
|
||||
def __init__(self, table, keys, values):
|
||||
self.table = table
|
||||
|
||||
super().__init__()
|
||||
|
||||
for idx, key in enumerate(keys):
|
||||
self[key] = values[idx]
|
||||
|
||||
|
||||
def delete(self):
|
||||
pass
|
||||
|
||||
|
||||
def update(self, **kwargs):
|
||||
pass
|
90
sql/izzylib/sql/rows.py
Normal file
90
sql/izzylib/sql/rows.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
from izzylib import DotDict
|
||||
|
||||
|
||||
class RowClasses(DotDict):
|
||||
def __init__(self, *classes):
|
||||
super().__init__()
|
||||
|
||||
for rowclass in classes:
|
||||
self.update({rowclass.__name__.lower(): rowclass})
|
||||
|
||||
|
||||
def get_class(self, name):
|
||||
return self.get(name, Row)
|
||||
|
||||
|
||||
class Row(DotDict):
|
||||
def __init__(self, table, row, session):
|
||||
super().__init__()
|
||||
|
||||
if row:
|
||||
try:
|
||||
self._update(row._asdict())
|
||||
except:
|
||||
self._update(row)
|
||||
|
||||
self.__db = session.db
|
||||
self.__table_name = table
|
||||
|
||||
self.__run__(session)
|
||||
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
return self.__db
|
||||
|
||||
|
||||
@property
|
||||
def table(self):
|
||||
return self.__table_name
|
||||
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self.keys()
|
||||
|
||||
|
||||
## Subclass Row and redefine this function
|
||||
def __run__(self, s):
|
||||
pass
|
||||
|
||||
|
||||
def _update(self, *args, **kwargs):
|
||||
super().update(*args, **kwargs)
|
||||
|
||||
|
||||
def delete(self, s=None):
|
||||
izzylog.warning('deprecated function: Row.delete')
|
||||
|
||||
if s:
|
||||
return self.delete_session(s)
|
||||
|
||||
with self.db.session as s:
|
||||
return self.delete_session(s)
|
||||
|
||||
|
||||
def delete_session(self, s):
|
||||
izzylog.warning('deprecated function: Row.delete_session')
|
||||
|
||||
return s.remove(table=self.table, row=self)
|
||||
|
||||
|
||||
def update(self, dict_data={}, s=None, **data):
|
||||
izzylog.warning('deprecated function: Row.update')
|
||||
|
||||
dict_data.update(data)
|
||||
self._update(dict_data)
|
||||
|
||||
if s:
|
||||
return self.update_session(s, **self)
|
||||
|
||||
with self.db.session as s:
|
||||
s.update(row=self, **self)
|
||||
|
||||
|
||||
def update_session(self, s, dict_data={}, **data):
|
||||
izzylog.warning('deprecated function: Row.update_session')
|
||||
|
||||
dict_data.update(data)
|
||||
self._update(dict_data)
|
||||
return s.update(table=self.table, row=self, **dict_data)
|
179
sql/izzylib/sql/session.py
Normal file
179
sql/izzylib/sql/session.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
from izzylib import DotDict, random_gen, izzylog
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm.session import Session as sqlalchemy_session
|
||||
|
||||
|
||||
class Session(sqlalchemy_session):
|
||||
def __init__(self, db, trans=False):
|
||||
super().__init__(bind=db.db, future=True)
|
||||
|
||||
self.closed = False
|
||||
self.trans = trans
|
||||
|
||||
self.database = db
|
||||
self.classes = db.classes
|
||||
self.cache = db.cache
|
||||
|
||||
self.sessionid = random_gen(10)
|
||||
self.database.sessions[self.sessionid] = self
|
||||
|
||||
# remove in the future
|
||||
self.db = db
|
||||
|
||||
self._setup()
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
if self.trans:
|
||||
self.begin()
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exctype, value, tb):
|
||||
if self.in_transaction():
|
||||
if tb:
|
||||
self.rollback()
|
||||
|
||||
self.commit()
|
||||
|
||||
self.close()
|
||||
|
||||
|
||||
def _setup(self):
|
||||
pass
|
||||
|
||||
|
||||
@property
|
||||
def table(self):
|
||||
return self.db.table
|
||||
|
||||
|
||||
def commit(self):
|
||||
if not self.in_transaction():
|
||||
return
|
||||
|
||||
super().commit()
|
||||
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
self.closed = True
|
||||
|
||||
del self.db.sessions[self.sessionid]
|
||||
|
||||
self.sessionid = None
|
||||
|
||||
|
||||
def run(self, expression, **kwargs):
|
||||
result = self.execute(text(expression), params=kwargs)
|
||||
|
||||
try:
|
||||
return result.mappings().all()
|
||||
except Exception as e:
|
||||
izzylog.verbose(f'Session.run: {e.__class__.__name__}: {e}')
|
||||
return result
|
||||
|
||||
|
||||
def count(self, table_name, **kwargs):
|
||||
return self.query(self.table[table_name]).filter_by(**kwargs).count()
|
||||
|
||||
|
||||
def fetch(self, table, single=True, orderby=None, orderdir='asc', **kwargs):
|
||||
RowClass = self.classes.get_class(table.lower())
|
||||
|
||||
query = self.query(self.table[table]).filter_by(**kwargs)
|
||||
|
||||
if not orderby:
|
||||
rows = query.all()
|
||||
|
||||
else:
|
||||
if orderdir == 'asc':
|
||||
rows = query.order_by(getattr(self.table[table].c, orderby).asc()).all()
|
||||
|
||||
elif orderdir == 'desc':
|
||||
rows = query.order_by(getattr(self.table[table].c, orderby).desc()).all()
|
||||
|
||||
else:
|
||||
raise ValueError(f'Unsupported order direction: {orderdir}')
|
||||
|
||||
if single:
|
||||
return RowClass(table, rows[0], self) if len(rows) > 0 else None
|
||||
|
||||
return [RowClass(table, row, self) for row in rows]
|
||||
|
||||
|
||||
def search(self, *args, **kwargs):
|
||||
kwargs.pop('single', None)
|
||||
return self.fetch(*args, single=False, **kwargs)
|
||||
|
||||
|
||||
def insert(self, table, return_row=False, **kwargs):
|
||||
row = self.fetch(table, **kwargs)
|
||||
|
||||
if row:
|
||||
row.update_session(self, **kwargs)
|
||||
return
|
||||
|
||||
if getattr(self.table[table], 'timestamp', None) and not kwargs.get('timestamp'):
|
||||
kwargs['timestamp'] = datetime.now()
|
||||
|
||||
return self.execute(self.table[table].insert().values(**kwargs))
|
||||
|
||||
if return_row:
|
||||
return self.fetch(table, **kwargs)
|
||||
|
||||
|
||||
def update(self, table=None, rowid=None, row=None, return_row=False, **kwargs):
|
||||
if row:
|
||||
rowid = row.id
|
||||
table = row.table
|
||||
|
||||
if not rowid or not table:
|
||||
raise ValueError('Missing row ID or table')
|
||||
|
||||
self.execute(self.table[table].update().where(self.table[table].c.id == rowid).values(**kwargs))
|
||||
|
||||
if return_row:
|
||||
return self.fetch(table, id=rowid)
|
||||
|
||||
|
||||
def remove(self, table=None, rowid=None, row=None):
|
||||
if row:
|
||||
rowid = row.id
|
||||
table = row.table
|
||||
|
||||
if not rowid or not table:
|
||||
raise ValueError('Missing row ID or table')
|
||||
|
||||
self.run(f'DELETE FROM {table} WHERE id=:id', id=rowid)
|
||||
|
||||
|
||||
def append_column(self, table, column):
|
||||
if column.name in self.db.get_columns(table):
|
||||
logging.warning(f'Table "{table}" already has column "{column.name}"')
|
||||
return
|
||||
|
||||
self.run(f'ALTER TABLE {table} ADD COLUMN {column.compile()}')
|
||||
|
||||
|
||||
def remove_column(self, tbl, col):
|
||||
table = self.table[tbl]
|
||||
column = getattr(table, col, None)
|
||||
columns = self.db.get_columns(tbl)
|
||||
|
||||
if col not in columns:
|
||||
izzylog.info(f'Column "{col}" already exists')
|
||||
return
|
||||
|
||||
columns.remove(col)
|
||||
coltext = ','.join(columns)
|
||||
|
||||
self.run(f'CREATE TABLE {tbl}_temp AS SELECT {coltext} FROM {tbl}')
|
||||
self.run(f'DROP TABLE {tbl}')
|
||||
self.run(f'ALTER TABLE {tbl}_temp RENAME TO {tbl}')
|
||||
|
||||
|
||||
def clear_table(self, table):
|
||||
self.run(f'DELETE FROM {table}')
|
|
@ -1,19 +0,0 @@
|
|||
from enum import Enum
|
||||
from izzylib import DotDict
|
||||
|
||||
|
||||
class BaseType(Enum):
|
||||
INTEGER = int
|
||||
TEXT = str
|
||||
BLOB = bytes
|
||||
REAL = float
|
||||
NUMERIC = float
|
||||
|
||||
|
||||
Type = DotDict(
|
||||
**{v: BaseType.INTEGER for v in ['INT', 'INTEGER', 'TINYINT', 'SMALLINT', 'MEDIUMINT', 'BIGINT', 'UNSIGNED BIG INT', 'INT2', 'INT8']},
|
||||
**{v: BaseType.TEXT for v in ['CHARACTER', 'VARCHAR', 'VARYING CHARACTER', 'NCHAR', 'NATIVE CHARACTER', 'NVARCHAR', 'TEXT', 'CLOB', 'STRING', 'JSON']},
|
||||
**{v: BaseType.BLOB for v in ['BYTES', 'BLOB']},
|
||||
**{v: BaseType.REAL for v in ['REAL', 'DOUBLE', 'DOUBLE PRECISION', 'FLOAT']},
|
||||
**{v: BaseType.NUMERIC for v in ['NUMERIC', 'DECIMAL', 'BOOLEAN', 'DATE', 'DATETIME']}
|
||||
)
|
Loading…
Reference in a new issue