major update

This commit is contained in:
Izalia Mae 2021-01-21 19:31:38 -05:00
parent 221beb7670
commit 0e59542626
9 changed files with 1375 additions and 456 deletions

View file

@ -8,4 +8,4 @@ import sys
assert sys.version_info >= (3, 6)
__version__ = (0, 3, 1)
__version__ = (0, 4, 0)

View file

@ -3,19 +3,14 @@ import re
from colour import Color
from . import logging
check = lambda color: Color(f'#{str(color)}' if re.search(r'^(?:[0-9a-fA-F]{3}){1,2}$', color) else color)
def _multi(multiplier):
if multiplier > 100:
if multiplier >= 1:
return 1
elif multiplier > 1:
return multiplier/100
elif multiplier < 0:
elif multiplier <= 0:
return 0
return multiplier

536
IzzyLib/database.py Normal file
View file

@ -0,0 +1,536 @@
## Probably gonna replace all of this with a custom sqlalchemy setup tbh
## It'll look like the db classes and functions in https://git.barkshark.xyz/izaliamae/social
import shutil, traceback, importlib, sqlite3, sys, json
from contextlib import contextmanager
from datetime import datetime
from .cache import LRUCache
from .misc import Boolean, DotDict, Path
from . import logging, sql
try:
from dbutils.pooled_db import PooledDB
except ImportError:
from DBUtils.PooledDB import PooledDB
## Only sqlite3 has been tested
## Use other modules at your own risk.
class DB():
def __init__(self, tables, dbmodule='sqlite', cursor=None, **kwargs):
cursor = Cursor if not cursor else cursor
if dbmodule in ['sqlite', 'sqlite3']:
self.dbmodule = sqlite3
self.dbtype = 'sqlite'
else:
self.dbmodule = None
self.__setup_module(dbmodule)
self.db = None
self.cursor = lambda : cursor(self).begin()
self.kwargs = kwargs
self.tables = tables
self.cache = DotDict()
self.__setup_database()
for table in tables.keys():
self.__setup_cache(table)
def query(self, *args, **kwargs):
return self.__cursor_cmd('query', *args, **kwargs)
def fetch(self, *args, **kwargs):
return self.__cursor_cmd('fetch', *args, **kwargs)
def insert(self, *args, **kwargs):
return self.__cursor_cmd('insert', *args, **kwargs)
def update(self, *args, **kwargs):
return self.__cursor_cmd('update', *args, **kwargs)
def remove(self, *args, **kwargs):
return self.__cursor_cmd('remove', *args, **kwargs)
def __cursor_cmd(self, name, *args, **kwargs):
with self.cursor() as cur:
method = getattr(cur, name)
if not method:
raise KeyError('Not a valid cursor method:', name)
return method(*args, **kwargs)
def __setup_module(self, dbtype):
modules = []
modtypes = {
'sqlite': ['sqlite3'],
'postgresql': ['psycopg3', 'pgdb', 'psycopg2', 'pg8000'],
'mysql': ['mysqldb', 'trio_mysql'],
'mssql': ['pymssql', 'adodbapi']
}
for dbmod, mods in modtypes.items():
if dbtype == dbmod:
self.dbtype = dbmod
modules = mods
break
elif dbtype in mods:
self.dbtype = dbmod
modules = [dbtype]
break
if not modules:
logging.verbose('Not a database type. Checking if it is a module...')
for mod in modules:
try:
self.dbmodule = importlib.import_module(mod)
except ImportError as e:
logging.verbose('Module not installed:', mod)
if self.dbmodule:
break
if not self.dbmodule:
if modtypes.get(dbtype):
logging.error('Failed to find module for database type:', dbtype)
logging.error(f'Please install one of these modules to use a {dbtype} database:', ', '.join(modules))
else:
logging.error('Failed to import module:', dbtype)
logging.error('Install one of the following modules:')
for key, modules in modtypes.items():
logging.error(f'{key}:')
for mod in modules:
logging.error(f'\t{mod}')
sys.exit()
def __setup_database(self):
if self.dbtype == 'sqlite':
if not self.kwargs.get('database'):
dbfile = ':memory:'
dbfile = Path(self.kwargs['database'])
dbfile.parent().mkdir()
if not dbfile.parent().isdir():
logging.error('Invalid path to database file:', dbfile.parent().str())
sys.exit()
self.kwargs['database'] = dbfile.str()
self.kwargs['check_same_thread'] = False
else:
if not self.kwargs.get('password'):
self.kwargs.pop('password', None)
self.db = PooledDB(self.dbmodule, **self.kwargs)
def connection(self):
if self.dbtype == 'sqlite':
return self.db.connection()
def __setup_cache(self, table):
self.cache[table] = LRUCache(128)
def close(self):
self.db.close()
def count(self, table):
tabledict = tables.get(table)
if not tabledict:
logging.debug('Table does not exist:', table)
return 0
data = self.query(f'SELECT COUNT(*) FROM {table}')
return data[0][0]
#def query(self, string, values=[], cursor=None):
#if not string.endswith(';'):
#string += ';'
#if not cursor:
#with self.Cursor() as cursor:
#cursor.execute(string, values)
#return cursor.fetchall()
#else:
#cursor.execute(string,value)
#return cursor.fetchall()
#return False
@contextmanager
def Cursor(self):
conn = self.db.connection()
conn.begin()
cursor = conn.cursor()
try:
yield cursor
except self.dbmodule.OperationalError:
traceback.print_exc()
conn.rollback()
finally:
cursor.close()
conn.commit()
conn.close()
def CreateTable(self, table):
layout = DotDict(self.tables.get(table))
if not layout:
logging.error('Table config doesn\'t exist:', table)
return
cmd = f'CREATE TABLE IF NOT EXISTS {table}('
items = []
for k, v in layout.items():
options = ' '.join(v.get('options', []))
default = v.get('default')
item = f'{k} {v["type"].upper()} {options}'
if default:
item += f'DEFAULT {default}'
items.append(item)
cmd += ', '.join(items) + ')'
return True if self.query(cmd) != False else False
def RenameTable(self, table, newname):
self.query(f'ALTER TABLE {table} RENAME TO {newname}')
def DropTable(self, table):
self.query(f'DROP TABLE {table}')
def AddColumn(self, table, name, datatype, default=None, options=None):
query = f'ALTER TABLE {table} ADD COLUMN {name} {datatype.upper()}'
if default:
query += f' DEFAULT {default}'
if options:
query += f' {options}'
self.query(query)
def CheckDatabase(self, database):
if self.dbtype == 'postgresql':
tables = self.query('SELECT datname FROM pg_database')
else:
tables = []
tables = [table[0] for table in tables]
print(database in tables, database, tables)
return database in tables
def CreateTables(self):
dbname = self.kwargs['database']
for name, table in self.tables.items():
if len(self.query(sql.CheckTable(self.dbtype, name))) < 1:
logging.info('Creating table:', name)
self.query(table.sql())
class Table(dict):
def __init__(self, name, dbtype='sqlite'):
super().__init__({})
self.sqlstr = 'CREATE TABLE IF NOT EXISTS {} ({});'
self.dbtype = dbtype
self.name = name
self.columns = {}
self.fkeys = {}
def addColumn(self, name, datatype=None, null=True, unique=None, primary=None, fkey=None):
if name == 'id':
if self.dbtype == 'sqlite':
datatype = 'integer'
primary = True if primary == None else primary
unique = True
null = False
else:
datatype = 'serial'
primary = 'true'
if name == 'timestamp':
datatype = 'float'
elif not datatype:
raise MissingTypeError(f'Missing a data type for column: {name}')
colsql = f'{name} {datatype.upper()}'
if unique:
colsql += ' UNIQUE'
if not null:
colsql += ' NOT NULL'
if primary:
self.primary = name
colsql += ' PRIMARY KEY'
if name == 'id' and self.dbtype == 'sqlite':
colsql += ' AUTOINCREMENT'
if fkey:
if self.dbtype == 'postgresql':
colsql += f' REFERENCES {fkey[0]}({fkey[1]})'
elif self.dbtype == 'sqlite':
self.fkeys[name] += f'FOREIGN KEY ({name}) REFERENCES {fkey[0]} ({fkey[1]})'
self.columns.update({name: colsql})
def sql(self):
if not self.primary:
logging.error('Please specify a primary column')
return
data = ', '.join(list(self.columns.values()))
if self.fkeys:
data += ', '
data += ', '.join(list(self.fkeys.values()))
sqldata = self.sqlstr.format(self.name, data)
print(sqldata)
return sqldata
class Cursor(object):
def __init__(self, db):
self.main = db
self.db = db.db
@contextmanager
def begin(self):
self.conn = self.db.connection()
self.conn.begin()
self.cursor = self.conn.cursor()
try:
yield self
except self.main.dbmodule.OperationalError:
self.conn.rollback()
raise
finally:
self.cursor.close()
self.conn.commit()
self.conn.close()
def query(self, string, values=[], cursor=None):
#if not string.endswith(';'):
#string += ';'
self.cursor.execute(string, values)
data = self.cursor.fetchall()
def fetch(self, table, single=True, sort=None, **kwargs):
rowid = kwargs.get('id')
querysort = f'ORDER BY {sort}'
resultOpts = [self, table, self.cursor]
if rowid:
cursor.execute(f"SELECT * FROM {table} WHERE id = ?", [rowid])
elif kwargs:
placeholders = [f'{k} = ?' for k in kwargs.keys()]
values = kwargs.values()
where = ' and '.join(placeholders)
query = f"SELECT * FROM {table} WHERE {where} {querysort if sort else ''}"
self.cursor.execute(query, list(values))
else:
self.cursor.execute(f'SELECT * FROM {table} {querysort if sort else ""}')
rows = self.cursor.fetchall() if not single else self.cursor.fetchone()
if rows:
if single:
return DBResult(rows, *resultOpts)
return [DBResult(row, *resultOpts) for row in rows]
return None if single else []
def insert(self, table, data={}, **kwargs):
data.update(kwargs)
placeholders = ",".join(['?' for _ in data.keys()])
values = tuple(data.values())
keys = ','.join(data.keys())
if 'timestamp' in self.main.tables[table].keys() and 'timestamp' not in keys:
data['timestamp'] = datetime.now()
query = f'INSERT INTO {table} ({keys}) VALUES ({placeholders})'
self.query(query, values)
return True
def remove(self, table, **kwargs):
keys = []
values = []
for k,v in kwargs.items():
keys.append(k)
values.append(v)
keydata = ','.join([f'{k} = ?' for k in keys])
query = f'DELETE FROM {table} WHERE {keydata}'
self.query(query, values)
def update(self, table, rowid, data={}, **kwargs):
data.update(kwargs)
newdata = {k: v for k, v in data.items() if k in self.main.tables[table].keys()}
keys = list(newdata.keys())
values = list(newdata.values())
if len(newdata) < 1:
logging.debug('No data provided to update row')
return False
query_data = ', '.join(f'{k} = ?' for k in keys)
query = f'UPDATE {table} SET {query_data} WHERE id = {rowid}'
self.query(query, values)
class DBResult(DotDict):
def __init__(self, row, db, table, cursor):
super().__init__()
self.db = db
self.table = table
for idx, col in enumerate(cursor.description):
self[col[0]] = row[idx]
def __setattr__(self, name, value):
if name not in ['db', 'table']:
return self.__setitem__(name, value)
else:
return super().__setattr__(name, value)
def __delattr__(self, name):
if name not in ['db', 'table']:
return self.__delitem__(name)
else:
return super().__delattr__(name)
def __getattr__(self, value, default=None):
options = [value]
if default:
options.append(default)
if value in self.keys():
val = super().__getitem__(*options)
return DotDict(val) if isinstance(val, dict) else val
else:
return dict.__getattr__(*options)
# Kept for backwards compatibility. Delete later.
def asdict(self):
return self
def Update(self, data={}):
with self.db.Cursor().begin() as cursor:
self.update(data)
cursor.update(self.table, self.id, self.AsDict())
def Remove(self):
with self.db.Cursor().begin() as cursor:
cursor.remove(self.table, id=self.id)
def ParseData(table, row):
tbdata = tables.get(table)
types = []
result = []
if not tbdata:
logging.error('Invalid table:', table)
return
for v in tbdata.values():
dtype = v.split()[0].upper()
if dtype == 'BOOLEAN':
types.append(boolean)
elif dtype in ['INTEGER', 'INT']:
types.append(int)
else:
types.append(str)
for idx, v in enumerate(row):
result.append(types[idx](v) if v else None)
return row
db.insert('config', {'key': 'version', 'value': dbversion})
class MissingTypeError(Exception):
pass
class TooManyConnectionsError(Exception):
pass

23
IzzyLib/error.py Normal file
View file

@ -0,0 +1,23 @@
class MissingHeadersError(Exception):
def __init__(self, headers: list):
self.headers = ', '.join(headers)
self.message = f'Missing required headers for verificaton: {self.headers}'
super().init(self.message)
def __str__(self):
return self.message
class VerificationError(Exception):
def __init__(self, string=None):
self.message = f'Failed to verify hash'
if string:
self.message += ' for ' + string
super().__init__(self.message)
def __str__(self):
return self.message

View file

@ -1,207 +1,342 @@
import traceback, urllib3, json
import functools, json, sys
from IzzyLib import logging
from IzzyLib.misc import DefaultDict, DotDict
from base64 import b64decode, b64encode
from urllib.parse import urlparse
from datetime import datetime
from urllib.error import HTTPError
from urllib.parse import urlparse
from urllib.request import Request, urlopen
import httpsig
from . import error
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA, SHA256, SHA384, SHA512
from Crypto.Signature import PKCS1_v1_5
try:
from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
crypto_enabled = True
except ImportError:
logging.verbose('Pycryptodome module not found. HTTP header signing and verifying is disabled')
crypto_enabled = False
from . import logging, __version__
from .cache import TTLCache, LRUCache
from .misc import formatUTC
try:
from sanic.request import Request as SanicRequest
sanic_enabled = True
except ImportError:
logging.verbose('Sanic module not found. Request verification is disabled')
sanic_enabled = False
version = '.'.join([str(num) for num in __version__])
Client = None
class httpClient:
def __init__(self, pool=100, timeout=30, headers={}, agent=None):
self.cache = LRUCache()
self.pool = pool
self.timeout = timeout
self.agent = agent if agent else f'IzzyLib/{version}'
self.headers = headers
def VerifyRequest(request: SanicRequest, actor: dict=None):
'''Verify a header signature from a sanic request
self.client = urllib3.PoolManager(num_pools=self.pool, timeout=self.timeout)
self.headers['User-Agent'] = self.agent
request: The request with the headers to verify
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
'''
if not sanic_enabled:
logging.error('Sanic request verification disabled')
return
if not actor:
actor = request.ctx.actor
body = request.body if request.body else None
return VerifyHeaders(request.headers, request.method, request.path, body, actor, False)
def _fetch(self, url, headers={}, method='GET', data=None, cached=True):
cached_data = self.cache.fetch(url)
#url = url.split('#')[0]
def VerifyHeaders(headers: dict, method: str, path: str, actor: dict=None, body=None, fail: bool=False):
'''Verify a header signature
if cached and cached_data:
logging.debug(f'Returning cached data for {url}')
return cached_data
headers: A dictionary containing all the headers from a request
method: The HTTP method of the request
path: The path of the HTTP request
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
body (optional): The body of the request. Only needed if the signature includes the digest header
fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
'''
if not crypto_enabled:
logging.error('Crypto functions disabled')
return
if not headers.get('User-Agent'):
headers.update({'User-Agent': self.agent})
headers = {k.lower(): v for k,v in headers.items()}
headers['(request-target)'] = f'{method.lower()} {path}'
signature = ParseSig(headers.get('signature'))
digest = headers.get('digest')
missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None]
logging.debug(f'Fetching new data for {url}')
if not signature:
if fail:
raise MissingSignatureError()
try:
if data:
if isinstance(data, dict):
data = json.dumps(data)
return False
resp = self.client.request(method, url, headers=headers, body=data)
if not actor:
actor = FetchActor(signature.keyid)
else:
resp = self.client.request(method, url, headers=headers)
## Add digest header to missing headers list if it doesn't exist
if method.lower() == 'post' and not headers.get('digest'):
missing_headers.append('digest')
except Exception as e:
logging.debug(f'Failed to fetch url: {e}')
return
## Fail if missing date, host or digest (if POST) headers
if missing_headers:
if fail:
raise error.MissingHeadersError(missing_headers)
if cached:
logging.debug(f'Caching {url}')
self.cache.store(url, resp)
return False
return resp
## Fail if body verification fails
if digest and not VerifyString(body, digest):
if fail:
raise error.VerificationError('digest header')
return False
pubkey = actor.publicKey['publicKeyPem']
if PkcsHeaders(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature):
return True
if fail:
raise error.VerificationError('headers')
return False
def raw(self, *args, **kwargs):
'''
Return a response object
'''
return self._fetch(*args, **kwargs)
@functools.lru_cache(maxsize=512)
def VerifyString(string, enc_string, alg='SHA256', fail=False):
if not crypto_enabled:
logging.error('Crypto functions disabled')
return
if type(string) != bytes:
string = string.encode('UTF-8')
body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
if body_hash == enc_string:
return True
if fail:
raise error.VerificationError()
else:
return False
def text(self, *args, **kwargs):
'''
Return the body as text
'''
resp = self._fetch(*args, **kwargs)
def PkcsHeaders(key: str, headers: dict, sig=None):
if not crypto_enabled:
logging.error('Crypto functions disabled')
return
return resp.data.decode() if resp else None
if sig:
head_items = [f'{item}: {headers[item]}' for item in sig.headers]
else:
head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
head_string = '\n'.join(head_items)
head_bytes = head_string.encode('UTF-8')
KEY = RSA.importKey(key)
pkcs = PKCS1_v1_5.new(KEY)
h = SHA256.new(head_bytes)
if sig:
return pkcs.verify(h, b64decode(sig.signature))
else:
return pkcs.sign(h)
def json(self, *args, **kwargs):
'''
Return the body as a dict if it's json
'''
headers = kwargs.get('headers')
if not headers:
kwargs['headers'] = {}
kwargs['headers'].update({'Accept': 'application/json'})
resp = self._fetch(*args, **kwargs)
try:
data = json.loads(resp.data.decode())
except Exception as e:
logging.debug(f'Failed to load json: {e}')
return
return data
def ParseSig(headers):
sig_header = headers.get('signature')
if not sig_header:
def ParseSig(signature: str):
if not signature:
logging.verbose('Missing signature header')
return
split_sig = sig_header.split(',')
signature = {}
split_sig = signature.split(',')
sig = DefaultDict({})
for part in split_sig:
key, value = part.split('=', 1)
signature[key.lower()] = value.replace('"', '')
sig[key.lower()] = value.replace('"', '')
if not signature.get('headers'):
if not sig.headers:
logging.verbose('Missing headers section in signature')
return
signature['headers'] = signature['headers'].split()
sig.headers = sig.headers.split()
return signature
return sig
def SignHeaders(headers, keyid, privkey, url, method='GET'):
'''
Signs headers and returns them with a signature header
@functools.lru_cache(maxsize=512)
def FetchActor(keyid, client=None):
if not client:
client = Client if Client else HttpClient()
headers (dict): Headers to be signed
keyid (str): Url to the public key used to verify the signature
privkey (str): Private key used to sign the headers
url (str): Url of the request for the signed headers
method (str): Http method of the request for the signed headers
'''
actor = Client.request(keyid).json()
actor.domain = urlparse(actor.id).netloc
actor.shared_inbox = actor.inbox
actor.pubkey = None
RSAkey = RSA.import_key(privkey)
key_size = int(RSAkey.size_in_bytes()/2)
logging.debug('Signing key size:', key_size)
if actor.get('endpoints'):
actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
parsed_url = urlparse(url)
logging.debug(parsed_url)
if actor.get('publicKey'):
actor.pubkey = actor.publicKey.get('publicKeyPem')
raw_headers = {'date': formatUTC(), 'host': parsed_url.netloc, '(request-target)': ' '.join([method, parsed_url.path])}
raw_headers.update(dict(headers))
header_keys = raw_headers.keys()
signer = httpsig.HeaderSigner(keyid, privkey, f'rsa-sha{key_size}', headers=header_keys, sign_header='signature')
new_headers = signer.sign(raw_headers, parsed_url.netloc, method, parsed_url.path)
logging.debug('Signed headers:', new_headers)
del new_headers['(request-target)']
return new_headers
return actor
def ValidateSignature(headers, method, path, client=None, agent=None):
'''
Validates the signature header.
class HttpClient(object):
def __init__(self, headers={}, useragent='IzzyLib/0.3', proxy_type='https', proxy_host=None, proxy_port=None):
proxy_ports = {
'http': 80,
'https': 443
}
headers (dict): All of the headers to be used to check a signature. The signature header must be included too
method (str): The http method used in relation to the headers
path (str): The path of the request in relation to the headers
client (pool object): Specify a httpClient to use for fetching the actor. optional
agent (str): User agent used for fetching actor data. optional
'''
if proxy_type not in ['http', 'https']:
raise ValueError(f'Not a valid proxy type: {proxy_type}')
client = httpClient(agent=agent) if not client else client
headers = {k.lower(): v for k,v in headers.items()}
signature = ParseSig(headers)
actor_data = client.json(signature['keyid'])
logging.debug(actor_data)
try:
pubkey = actor_data['publicKey']['publicKeyPem']
except Exception as e:
logging.verbose(f'Failed to get public key for actor {signature["keyid"]}')
return
valid = httpsig.HeaderVerifier(headers, pubkey, signature['headers'], method, path, sign_header='signature').verify()
if not valid:
if not isinstance(valid, tuple):
logging.verbose('Signature validation failed for unknown actor')
logging.verbose(valid)
else:
logging.verbose(f'Signature validation failed for actor: {valid[1]}')
return
else:
return True
self.headers=headers
self.agent=useragent
self.proxy = DotDict({
'enabled': True if proxy_host else False,
'ptype': proxy_type,
'host': proxy_host,
'port': proxy_ports[proxy_type] if not proxy_port else proxy_port
})
def ValidateRequest(request, client=None, agent=None):
'''
Validates the headers in a Sanic or Aiohttp request (other frameworks may be supported)
See ValidateSignature for 'client' and 'agent' usage
'''
return ValidateSignature(request.headers, request.method, request.path, client, agent)
def __sign_request(self, request, privkey, keyid):
if not crypto_enabled:
logging.error('Crypto functions disabled')
return
request.add_header('(request-target)', f'{request.method.lower()} {request.path}')
request.add_header('host', request.host)
request.add_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
if request.body:
body_hash = b64encode(SHA256.new(request.body).digest()).decode("UTF-8")
request.add_header('digest', f'SHA-256={body_hash}')
request.add_header('content-length', len(request.body))
sig = {
'keyId': keyid,
'algorithm': 'rsa-sha256',
'headers': ' '.join([k.lower() for k in request.headers.keys()]),
'signature': b64encode(PkcsHeaders(privkey, request.headers)).decode('UTF-8')
}
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
sig_string = ','.join(sig_items)
request.add_header('signature', sig_string)
request.remove_header('(request-target)')
request.remove_header('host')
def __build_request(self, url, data=None, headers={}, method='GET'):
new_headers = self.headers.copy()
new_headers.update(headers)
parsed_headers = {k.lower(): v.lower() for k,v in new_headers.items()}
if not parsed_headers.get('user-agent'):
parsed_headers['user-agent'] = self.agent
if isinstance(data, dict):
data = json.dumps(data)
if isinstance(data, str):
data = data.encode('UTF-8')
request = HttpRequest(url, data=data, headers=parsed_headers, method=method)
if self.proxy.enabled:
request.set_proxy(f'{self.proxy.host}:{self.proxy.host}', self.proxy.ptype)
return request
def request(self, *args, **kwargs):
request = self.__build_request(*args, **kwargs)
try:
response = urlopen(request)
except HTTPError as e:
response = e.fp
return HttpResponse(response)
def json(self, *args, headers={}, activity=True, **kwargs):
json_type = 'activity+json' if activity else 'json'
headers.update({
'accept': f'application/{json_type}'
})
return self.request(*args, headers=headers, **kwargs)
def signed_request(self, privkey, keyid, *args, **kwargs):
request = self.__build_request(*args, **kwargs)
self.__sign_request(request, privkey, keyid)
try:
response = urlopen(request)
except HTTPError as e:
response = e
return HttpResponse(response)
class HttpRequest(Request):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
parsed = urlparse(self.full_url)
self.scheme = parsed.scheme
self.host = parsed.netloc
self.domain = parsed.hostname
self.port = parsed.port
self.path = parsed.path
self.query = parsed.query
self.body = self.data if self.data else b''
class HttpResponse(object):
def __init__(self, response):
self.body = response.read()
self.headers = DefaultDict({k.lower(): v.lower() for k,v in response.headers.items()})
self.status = response.status
self.url = response.url
def text(self):
return self.body.decode('UTF-8')
def json(self, fail=False):
try:
return DotDict(self.text())
except Exception as e:
if fail:
raise e from None
else:
return DotDict()
def json_pretty(self, indent=4):
return json.dumps(self.json().asDict(), indent=indent)
def SetClient(client: HttpClient):
global Client
Client = client

View file

@ -25,7 +25,17 @@ class Log():
'MERP': 0
}
self.config = dict()
self.long_levels = {
'CRITICAL': 'CRIT',
'ERROR': 'ERROR',
'WARNING': 'WARN',
'INFO': 'INFO',
'VERBOSE': 'VERB',
'DEBUG': 'DEBUG',
'MERP': 'MERP'
}
self.config = {'windows': sys.executable.endswith('pythonw.exe')}
self.setConfig(self._parseConfig(config))
@ -35,10 +45,11 @@ class Log():
value = int(level)
except ValueError:
value = self.levels.get(level.upper())
level = self.long_levels.get(level.upper(), level)
value = self.levels.get(level)
if value not in self.levels.values():
raise error.InvalidLevel(f'Invalid logging level: {level}')
raise InvalidLevel(f'Invalid logging level: {level}')
return value
@ -48,13 +59,14 @@ class Log():
if level == num:
return name
raise error.InvalidLevel(f'Invalid logging level: {level}')
raise InvalidLevel(f'Invalid logging level: {level}')
def _parseConfig(self, config):
'''parse the new config and update the old values'''
date = config.get('date', self.config.get('date',True))
systemd = config.get('systemd', self.config.get('systemd,', True))
windows = config.get('windows', self.config.get('windows', False))
if not isinstance(date, bool):
raise TypeError(f'value for "date" is not a boolean: {date}')
@ -64,14 +76,18 @@ class Log():
level_num = self._lvlCheck(config.get('level', self.config.get('level', 'INFO')))
return {
newconfig = {
'level': self._getLevelName(level_num),
'levelnum': level_num,
'datefmt': config.get('datefmt', self.config.get('datefmt', '%Y-%m-%d %H:%M:%S')),
'date': date,
'systemd': systemd
'systemd': systemd,
'windows': windows,
'systemnotif': config.get('systemnotif', None)
}
return newconfig
def setConfig(self, config):
'''set the config'''
@ -81,8 +97,8 @@ class Log():
def getConfig(self, key=None):
'''return the current config'''
if key:
if self.get(key):
return self.get(key)
if self.config.get(key):
return self.config.get(key)
else:
raise ValueError(f'Invalid config option: {key}')
return self.config
@ -95,7 +111,14 @@ class Log():
stdout.flush()
def setLevel(self, level):
self.minimum = self._lvlCheck(level)
def log(self, level, *msg):
if self.config['windows']:
return
'''log to the console'''
levelNum = self._lvlCheck(level)
@ -108,6 +131,9 @@ class Log():
message = ' '.join([str(message) for message in msg])
output = f'{level}: {message}\n'
if self.config['systemnotif']:
self.config['systemnotif'].New(level, message)
if self.config['date'] and (self.config['systemd'] and not env.get('INVOCATION_ID')):
'''only show date when not running in systemd and date var is True'''
date = datetime.now().strftime(self.config['datefmt'])
@ -148,18 +174,15 @@ def getLogger(loginst, config=None):
logger[loginst] = Log(config)
else:
raise error.InvalidLogger(f'logger "{loginst}" doesn\'t exist')
raise InvalidLogger(f'logger "{loginst}" doesn\'t exist')
return logger[loginst]
class error:
'''base class for all errors'''
class InvalidLevel(Exception):
'''Raise when an invalid logging level was specified'''
class InvalidLevel(Exception):
'''Raise when an invalid logging level was specified'''
class InvalidLogger(Exception):
'''Raise when the specified logger doesn't exist'''
class InvalidLogger(Exception):
'''Raise when the specified logger doesn't exist'''
'''create a default logger'''
@ -182,4 +205,5 @@ merp = DefaultLog.merp
'''aliases for the default logger's config functions'''
setConfig = DefaultLog.setConfig
getConfig = DefaultLog.getConfig
setLevel = DefaultLog.setLevel
printConfig = DefaultLog.printConfig

View file

@ -1,16 +1,16 @@
'''Miscellaneous functions'''
import random, string, sys, os, shlex, subprocess, socket, traceback
import random, string, sys, os, json, socket
from os.path import abspath, dirname, basename, isdir, isfile
from os import environ as env
from datetime import datetime
from collections import namedtuple
from getpass import getpass
from pathlib import Path as Pathlib
from . import logging
def boolean(v, fail=True):
if type(v) in [dict, list, tuple]:
def Boolean(v, return_value=False):
if type(v) not in [str, bool, int, type(None)]:
raise ValueError(f'Value is not a string, boolean, int, or nonetype: {value}')
'''make the value lowercase if it's a string'''
@ -20,101 +20,50 @@ def boolean(v, fail=True):
'''convert string to True'''
return True
elif value in [0, False, None, 'off', 'n', 'no', 'false', 'disable', 'none', '']:
if value in [0, False, None, 'off', 'n', 'no', 'false', 'disable', '']:
'''convert string to False'''
return False
elif not fail:
if return_value:
'''just return the value'''
return value
return v
else:
raise ValueError(f'Value cannot be converted to a boolean: {value}')
return True
def randomgen(chars=20):
if not isinstance(chars, int):
raise TypeError(f'Character length must be an integer, not a {type(char)}')
def RandomGen(length=20, chars=None):
if not isinstance(length, int):
raise TypeError(f'Character length must be an integer, not {type(length)}')
return ''.join(random.choices(string.ascii_letters + string.digits, k=chars))
characters = string.ascii_letters + string.digits
if chars:
characters += chars
return ''.join(random.choices(characters, k=length))
def formatUTC(timestamp=None, ap=False):
date = datetime.fromtimestamp(timestamp) if timestamp else datetime.utcnow()
def Timestamp(dtobj=None, utc=False):
dtime = dtobj if dtobj else datetime
date = dtime.utcnow() if utc else dtime.now()
if ap:
return date.strftime('%Y-%m-%dT%H:%M:%SZ')
return date.strftime('%a, %d %b %Y %H:%M:%S GMT')
return date.timestamp()
def config_dir(modpath=None):
if env.get('CONFDIR'):
'''set the storage path to the environment variable if it exists'''
stor_path = abspath(env['CONFDIR'])
def ApDate(date=None, alt=False):
if not date:
date = datetime.utcnow()
else:
stor_path = f'{os.getcwd()}'
elif type(date) == int:
date = datetime.fromtimestamp(date)
if modpath and not env.get('CONFDIR'):
modname = basename(dirname(modpath))
elif type(date) != datetime:
raise TypeError(f'Unsupported object type for ApDate: {type(date)}')
if isdir(f'{stor_path}/{modname}'):
'''set the storage path to CWD/data if the module or script is in the working dir'''
stor_path += '/data'
if not isdir (stor_path):
os.makedirs(stor_path, exist_ok=True)
return stor_path
return date.strftime('%a, %d %b %Y %H:%M:%S GMT' if alt else '%Y-%m-%dT%H:%M:%SZ')
def getBin(filename):
for pathdir in env['PATH'].split(':'):
fullpath = os.path.join(pathdir, filename)
if os.path.isfile(fullpath):
return fullpath
raise FileNotFoundError(f'Cannot find {filename} in path.')
def Try(funct, *args, **kwargs):
Result = namedtuple('Result', 'result exception')
out = None
exc = None
try:
out = funct(*args, **kwargs)
except Exception as e:
traceback.print_exc()
exc = e
return Result(out, exc)
def sudo(cmd, password, user=None):
### Please don't pur your password in plain text in a script
### Use a module like 'getpass' to get the password instead
if isinstance(cmd, list):
cmd = ' '.join(cmd)
elif not isinstance(cmd, str):
raise ValueError('Command is not a list or string')
euser = os.environ.get('USER')
cmd = ' '.join(['sudo', '-u', user, cmd]) if user and euser != user else 'sudo ' + cmd
sudocmd = ' '.join(['echo', f'{password}', '|', cmd])
proc = subprocess.Popen(['/usr/bin/env', 'bash', '-c', sudocmd])
return proc
def getip():
# Get the main IP address of the machine
def GetIp():
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
@ -131,6 +80,367 @@ def getip():
return ip
def merp():
log = logging.getLogger('merp-heck', {'level': 'merp', 'date': False})
log.merp('heck')
def Input(prompt, default=None, valtype=str, options=[], password=False):
input_func = getpass if password else input
if default != None:
prompt += ' [-redacted-]' if password else f' [{default}]'
prompt += '\n'
if options:
opt = '/'.join(options)
prompt += f'[{opt}]'
prompt += ': '
value = input_func(prompt)
while value and options and value not in options:
input_func('Invalid value:', value)
value = input(prompt)
if not value:
return default
ret = valtype(value)
while valtype == Path and not ret.parent().exists():
input_func('Parent directory doesn\'t exist')
ret = Path(input(prompt))
return ret
class DotDict(dict):
def __init__(self, value=None, **kwargs):
super().__init__()
if type(value) in [str, bytes]:
self.fromJson(value)
elif type(value) in [dict, DotDict, DefaultDict]:
self.update(value)
elif value:
raise TypeError('The value must be a JSON string, dict, or another DotDict object, not', value.__class__)
if kwargs:
self.update(kwargs)
def __getattr__(self, key):
try:
val = super().__getattribute__(key)
except AttributeError:
val = self.get(key, InvalidKey())
if type(val) == InvalidKey:
raise KeyError(f'Invalid key: {key}')
return DotDict(val) if type(val) == dict else val
def __delattr__(self, key):
if self.get(key):
del self[key]
super().__delattr__(key)
#def __delitem__(self, key):
#print('delitem', key)
#self.__delattr__(key)
def __setattr__(self, key, value):
if key.startswith('_'):
super().__setattr__(key, value)
else:
super().__setitem__(key, value)
def __parse_item__(self, k, v):
if type(v) == dict:
v = DotDict(v)
if not k.startswith('_'):
return (k, v)
def get(self, key, default=None):
value = super().get(key, default)
return DotDict(value) if type(value) == dict else value
def items(self):
data = []
for k, v in super().items():
new = self.__parse_item__(k, v)
if new:
data.append(new)
return data
def values(self):
return list(super().values())
def keys(self):
return list(super().keys())
def asDict(self):
return dict(self)
def toJson(self, indent=None, **kwargs):
data = {}
for k,v in self.items():
if type(k) in [DotDict, Path, Pathlib]:
k = str(k)
if type(v) in [DotDict, Path, Pathlib]:
v = str(v)
data[k] = v
return json.dumps(data, indent=indent, **kwargs)
def fromJson(self, string):
data = json.loads(string)
self.update(data)
class DefaultDict(DotDict):
def __getattr__(self, key):
try:
val = super().__getattribute__(key)
except AttributeError:
val = self.get(key, DefaultDict())
return DotDict(val) if type(val) == dict else val
class Path(object):
def __init__(self, path, exist=True, missing=True, parents=True):
self.__path = Pathlib(str(path)).resolve()
self.json = DotDict({})
self.exist = exist
self.missing = missing
self.parents = parents
self.name = self.__path.name
#def __getattr__(self, key):
#try:
#attr = getattr(self.__path, key)
#except AttributeError:
#attr = getattr(self, key)
#return attr
def __str__(self):
return str(self.__path)
def str(self):
return self.__str__()
def __check_dir(self, path=None):
target = self if not path else Path(path)
if not self.parents and not target.parent().exists():
raise FileNotFoundError('Parent directories do not exist:', target.str())
if not self.exist and target.exists():
raise FileExistsError('File or directory already exists:', target.str())
def size(self):
return self.__path.stat().st_size
def mtime(self):
return self.__path.stat().st_mtime
def mkdir(self, mode=0o755):
self.__path.mkdir(mode, self.parents, self.exist)
return True if self.__path.exists() else False
def new(self):
return Path(self.__path)
def parent(self, new=True):
path = Pathlib(self.__path).parent
if new:
return Path(path)
self.__path = path
return self
def move(self, path):
target = Path(path)
self.__check_dir(path)
if target.exists() and not target.isdir():
target.delete()
def join(self, path, new=True):
new_path = self.__path.joinpath(path).resolve()
if new:
return Path(new_path)
self.__path = new_path
return self
def home(self, path=None, new=True):
new_path = Pathlib.home()
if path:
new_path = new_path.joinpath(path)
if new:
return Path(new_path)
self.__path = new_path
return self
def isdir(self):
return self.__path.is_dir()
def isfile(self):
return self.__path.is_file()
def islink(self):
return self.__path.is_symlink()
def listdir(self):
return [Path(path) for path in self.__path.iterdir()]
def exists(self):
return self.__path.exists()
def mtime(self):
return os.path.getmtime(self.str())
def link(self, path):
target = Path(path)
self.__check_dir(path)
if target.exists():
target.delete()
self.__path.symlink_to(path, target.isdir())
def resolve(self, new=True):
path = self.__path.resolve()
if new:
return Path(path)
self.__path = path
return self
def touch(self, mode=0o666):
self.__path.touch(mode, self.exist)
return True if self.__path.exists() else False
def loadJson(self):
self.json = DotDict(self.read())
return self.json
def updateJson(self, data={}):
if type(data) == str:
data = json.loads(data)
self.json.update(data)
def storeJson(self, indent=None):
with self.__path.open('w') as fp:
fp.write(json.dumps(self.json.asDict(), indent=indent))
# This needs to be extended to handle dirs with files/sub-dirs
def delete(self):
if self.isdir():
self.__path.rmdir(self.missing)
else:
self.__path.unlink(self.missing)
return not self.exists()
def open(self, *args):
return self.__path.open(*args)
def read(self, *args):
return self.open().read(*args)
def readlines(self):
return self.open().readlines()
## def rmdir():
def NfsCheck(path):
proc = Path('/proc/mounts')
path = Path(path).resolve()
if not proc.exists():
return True
with proc.open() as fd:
for line in fd:
line = line.split()
if line[2] == 'nfs' and line[1] in path.str():
return False
return True
class InvalidKey(object):
pass

View file

@ -1,233 +1,131 @@
'''functions for web template management and rendering'''
import codecs, traceback, os, json
import codecs, traceback, os, json, xml
from os import listdir, makedirs
from os.path import isfile, isdir, getmtime, abspath
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape
from hamlpy.hamlpy import Compiler
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape, Markup
from hamlish_jinja import HamlishExtension
from markdown import markdown
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from xml.dom import minidom
try:
from sanic import response as Response
except ModuleNotFoundError:
Response = None
from . import logging
from .color import *
framework = 'sanic'
try:
import sanic
except:
logging.debug('Cannot find Sanic')
try:
import aiohttp
except:
logging.debug('Cannot find aioHTTP')
from .misc import Path, DotDict
env = None
class Template(Environment):
def __init__(self, search=[], global_vars={}, autoescape=True):
self.autoescape = autoescape
self.watcher = None
self.search = []
self.func_context = None
global_variables = {
'markdown': markdown,
'lighten': lighten,
'darken': darken,
'saturate': saturate,
'desaturate': desaturate,
'rgba': rgba
}
for path in search:
self.__add_search_path(path)
search_path = list()
build_path_pairs = dict()
super().__init__(
loader=ChoiceLoader([FileSystemLoader(path) for path in self.search]),
extensions=[HamlishExtension],
autoescape=self.autoescape,
lstrip_blocks=True,
trim_blocks=True
)
self.hamlish_file_extensions=('.haml',)
self.hamlish_enable_div_shortcut=True
self.hamlish_mode = 'indented'
self.globals.update({
'markdown': markdown,
'markup': Markup,
'cleanhtml': lambda text: ''.join(xml.etree.ElementTree.fromstring(text).itertext()),
'lighten': lighten,
'darken': darken,
'saturate': saturate,
'desaturate': desaturate,
'rgba': rgba
})
self.globals.update(global_vars)
def addSearchPath(path):
tplPath = abspath(path)
def __add_search_path(self, path):
tpl_path = Path(path)
if not isdir(tplPath):
raise FileNotFoundError(f'Cannot find template directory: {tplPath}')
if not tpl_path.exists():
raise FileNotFoundError('Cannot find search path:', tpl_path.str())
if tplPath not in search_path:
search_path.append(tplPath)
if tpl_path.str() not in self.search:
self.search.append(tpl_path.str())
def delSearchPath(path):
tplPath = abspath(path)
if tplPath in search_path:
search_path.remove(tplPath)
def addEnv(self, k, v):
self.globals[k] = v
def addBuildPath(name, source, destination):
src = abspath(source)
dest = abspath(destination)
def delEnv(self, var):
if not self.globals.get(var):
raise ValueError(f'"{var}" not in global variables')
if not isdir(src):
raise FileNotFoundError(f'Source path doesn\'t exist: {src}')
build_path_pairs.update({
name: {
'source': src,
'destination': dest
}
})
addSearchPath(dest)
del self.var[var]
def delBuildPath(name):
if not build_path_pairs.get(name):
raise ValueError(f'"{name}" not in build paths')
def updateEnv(self, data):
if not isinstance(data, dict):
raise ValueError(f'Environment data not a dict')
del build_path_pairs[src]
self.globals.update(data)
def getBuildPath(name=None):
template = build_path_pairs.get(name)
if name:
if template:
return template
else:
raise ValueError(f'"{name}" not in build paths')
return build_path_pairs
def addFilter(self, funct, name=None):
name = funct.__name__ if not name else name
self.filters[name] = funct
def addEnv(data):
if not isinstance(data, dict):
raise TypeError(f'environment data is not a dict')
def delFilter(self, name):
if not self.filters.get(name):
raise valueError(f'"{name}" not in global filters')
global_variables.update(data)
del self.filters[name]
def delEnv(var):
if not global_variables.get(var):
raise ValueError(f'"{var}" not in global variables')
def updateFilter(self, data):
if not isinstance(context, dict):
raise ValueError(f'Filter data not a dict')
del global_variables[var]
self.filters.update(data)
def setup(fwork='sanic'):
global env
global framework
def render(self, tplfile, request=None, context={}, headers={}, cookies={}, pprint=False, **kwargs):
if not isinstance(context, dict):
raise TypeError(f'context for {tplfile} not a dict: {type(context)} {context}')
framework = fwork
env = Environment(
loader=ChoiceLoader([FileSystemLoader(path) for path in search_path]),
autoescape=select_autoescape(['html', 'css']),
lstrip_blocks=True,
trim_blocks=True
)
context['request'] = request if request else {'headers': headers, 'cookies': cookies}
if self.func_context:
context.update(self.func_context(DotDict(context), DotDict(self.globals)))
def renderTemplate(tplfile, context={}, request=None, headers=dict(), cookies=dict(), **kwargs):
if not isinstance(context, dict):
raise TypeError(f'context for {tplfile} not a dict')
result = self.get_template(tplfile).render(context)
data = global_variables.copy()
data['request'] = request if request else {'headers': headers, 'cookies': cookies}
data.update(context)
return env.get_template(tplfile).render(data)
def sendResponse(template, request, context=dict(), status=200, ctype='text/html', headers=dict(), **kwargs):
context['request'] = request
html = renderTemplate(template, context, **kwargs)
if framework == 'sanic':
return sanic.response.text(html, status=status, headers=headers, content_type=ctype)
elif framework == 'aiohttp':
return aiohttp.web.Response(body=html, status=status, headers=headers, content_type=ctype)
else:
logging.error('Please install aiohttp or sanic. Response not sent.')
# delete me later
aiohttpTemplate = sendResponse
def buildTemplates(name=None):
paths = getBuildPath(name)
for k, tplPaths in paths.items():
src = tplPaths['source']
dest = tplPaths['destination']
timefile = f'{dest}/times.json'
updated = False
if not isdir(f'{dest}'):
makedirs(f'{dest}')
if isfile(timefile):
try:
times = json.load(open(timefile))
except:
times = {}
if pprint and any(map(tplfile.endswith, ['haml', 'html', 'xml'])):
return minidom.parseString(result).toprettyxml(indent=" ")
else:
times = {}
for filename in listdir(src):
fullPath = f'{src}/{filename}'
modtime = getmtime(fullPath)
base, ext = filename.split('.', 1)
if ext != 'haml':
pass
elif base not in times or times.get(base) != modtime:
updated = True
logging.verbose(f"Template '{filename}' was changed. Building...")
try:
destination = f'{dest}/{base}.html'
haml_lines = codecs.open(fullPath, 'r', encoding='utf-8').read().splitlines()
compiler = Compiler()
output = compiler.process_lines(haml_lines)
outfile = codecs.open(destination, 'w', encoding='utf-8')
outfile.write(output)
logging.info(f"Template '{filename}' has been built")
except Exception as e:
'''I'm actually not sure what sort of errors can happen here, so generic catch-all for now'''
traceback.print_exc()
logging.error(f'Failed to build {filename}: {e}')
times[base] = modtime
if updated:
with open(timefile, 'w') as filename:
filename.write(json.dumps(times))
return result
def templateWatcher():
watchPaths = [path['source'] for k, path in build_path_pairs.items()]
logging.info('Starting template watcher')
observer = Observer()
def response(self, *args, ctype='text/html', status=200, **kwargs):
if not Response:
raise ModuleNotFoundError('Sanic is not installed')
for tplpath in watchPaths:
logging.debug(f'Watching template dir for changes: {tplpath}')
observer.schedule(templateWatchHandler(), tplpath, recursive=True)
return observer
class templateWatchHandler(FileSystemEventHandler):
def on_any_event(self, event):
filename, ext = os.path.splitext(os.path.relpath(event.src_path))
if event.event_type in ['modified', 'created'] and ext[1:] == 'haml':
logging.info('Rebuilding templates')
buildTemplates()
__all__ = ['addSearchPath', 'delSearchPath', 'addBuildPath', 'delSearchPath', 'addEnv', 'delEnv', 'setup', 'renderTemplate', 'sendResponse', 'buildTemplates', 'templateWatcher']
html = self.render(*args, **kwargs)
return Response.HTTPResponse(body=html, status=status, content_type=ctype, headers=kwargs.get('headers', {}))

View file

@ -1,12 +1,10 @@
pycryptodome>=3.9.1
colour>=0.1.5
aiohttp>=3.6.2
envbash>=1.2.0
hamlpy3>=0.84.0
Hamlish-Jinja==0.3.3
Jinja2>=2.10.1
jinja2-markdown>=0.0.3
Mastodon.py>=1.5.0
pycryptodome>=3.9.1
python-magic>=0.4.18
sanic>=19.12.2
urllib3>=1.25.7
watchdog>=0.8.3
httpsig>=1.3.0