major update

parent 221beb7670
commit 0e59542626
@@ -8,4 +8,4 @@ import sys

 assert sys.version_info >= (3, 6)

-__version__ = (0, 3, 1)
+__version__ = (0, 4, 0)
@@ -3,19 +3,14 @@ import re

 from colour import Color

 from . import logging


 check = lambda color: Color(f'#{str(color)}' if re.search(r'^(?:[0-9a-fA-F]{3}){1,2}$', color) else color)


 def _multi(multiplier):
-	if multiplier > 100:
+	if multiplier >= 1:
 		return 1

-	elif multiplier > 1:
-		return multiplier/100
-
-	elif multiplier < 0:
+	elif multiplier <= 0:
 		return 0

 	return multiplier
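
For reference, a quick sketch of how the updated helpers behave (not part of the commit; it assumes this hunk is IzzyLib/color.py and that _multi, though private, is importable):

from IzzyLib.color import check, _multi

print(check('e3e3e3'))  # bare hex strings get a '#' prepended before being passed to Color
print(check('red'))     # anything else is handed to Color as-is
print(_multi(0.5))      # fractions pass through unchanged -> 0.5
print(_multi(3))        # values of 1 or more now clamp to 1 (previously only values over 100 did)
print(_multi(-2))       # values at or below zero clamp to 0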
536 IzzyLib/database.py Normal file

@@ -0,0 +1,536 @@
## Probably gonna replace all of this with a custom sqlalchemy setup tbh
## It'll look like the db classes and functions in https://git.barkshark.xyz/izaliamae/social
import shutil, traceback, importlib, sqlite3, sys, json

from contextlib import contextmanager
from datetime import datetime

from .cache import LRUCache
from .misc import Boolean, DotDict, Path
from . import logging, sql

try:
	from dbutils.pooled_db import PooledDB

except ImportError:
	from DBUtils.PooledDB import PooledDB


## Only sqlite3 has been tested
## Use other modules at your own risk.
class DB():
	def __init__(self, tables, dbmodule='sqlite', cursor=None, **kwargs):
		cursor = Cursor if not cursor else cursor

		if dbmodule in ['sqlite', 'sqlite3']:
			self.dbmodule = sqlite3
			self.dbtype = 'sqlite'

		else:
			self.dbmodule = None
			self.__setup_module(dbmodule)

		self.db = None
		self.cursor = lambda: cursor(self).begin()
		self.kwargs = kwargs
		self.tables = tables
		self.cache = DotDict()
		self.__setup_database()

		for table in tables.keys():
			self.__setup_cache(table)

	def query(self, *args, **kwargs):
		return self.__cursor_cmd('query', *args, **kwargs)


	def fetch(self, *args, **kwargs):
		return self.__cursor_cmd('fetch', *args, **kwargs)


	def insert(self, *args, **kwargs):
		return self.__cursor_cmd('insert', *args, **kwargs)


	def update(self, *args, **kwargs):
		return self.__cursor_cmd('update', *args, **kwargs)


	def remove(self, *args, **kwargs):
		return self.__cursor_cmd('remove', *args, **kwargs)


	def __cursor_cmd(self, name, *args, **kwargs):
		with self.cursor() as cur:
			## Use a default of None so the check below actually runs instead of raising AttributeError
			method = getattr(cur, name, None)

			if not method:
				raise KeyError('Not a valid cursor method:', name)

			return method(*args, **kwargs)


	def __setup_module(self, dbtype):
		modules = []
		modtypes = {
			'sqlite': ['sqlite3'],
			'postgresql': ['psycopg3', 'pgdb', 'psycopg2', 'pg8000'],
			'mysql': ['mysqldb', 'trio_mysql'],
			'mssql': ['pymssql', 'adodbapi']
		}

		for dbmod, mods in modtypes.items():
			if dbtype == dbmod:
				self.dbtype = dbmod
				modules = mods
				break

			elif dbtype in mods:
				self.dbtype = dbmod
				modules = [dbtype]
				break

		if not modules:
			logging.verbose('Not a database type. Checking if it is a module...')

		for mod in modules:
			try:
				self.dbmodule = importlib.import_module(mod)

			except ImportError:
				logging.verbose('Module not installed:', mod)

			if self.dbmodule:
				break

		if not self.dbmodule:
			if modtypes.get(dbtype):
				logging.error('Failed to find module for database type:', dbtype)
				logging.error(f'Please install one of these modules to use a {dbtype} database:', ', '.join(modules))

			else:
				logging.error('Failed to import module:', dbtype)
				logging.error('Install one of the following modules:')

				for key, modules in modtypes.items():
					logging.error(f'{key}:')

					for mod in modules:
						logging.error(f'\t{mod}')

			sys.exit()

	def __setup_database(self):
		if self.dbtype == 'sqlite':
			## No path given, so fall back to an in-memory database
			if not self.kwargs.get('database'):
				self.kwargs['database'] = ':memory:'

			else:
				dbfile = Path(self.kwargs['database'])
				dbfile.parent().mkdir()

				if not dbfile.parent().isdir():
					logging.error('Invalid path to database file:', dbfile.parent().str())
					sys.exit()

				self.kwargs['database'] = dbfile.str()

			self.kwargs['check_same_thread'] = False

		else:
			if not self.kwargs.get('password'):
				self.kwargs.pop('password', None)

		self.db = PooledDB(self.dbmodule, **self.kwargs)


	def connection(self):
		if self.dbtype == 'sqlite':
			return self.db.connection()


	def __setup_cache(self, table):
		self.cache[table] = LRUCache(128)


	def close(self):
		self.db.close()


	def count(self, table):
		tabledict = self.tables.get(table)

		if not tabledict:
			logging.debug('Table does not exist:', table)
			return 0

		data = self.query(f'SELECT COUNT(*) FROM {table}')
		return data[0][0]


	#def query(self, string, values=[], cursor=None):
		#if not string.endswith(';'):
			#string += ';'

		#if not cursor:
			#with self.Cursor() as cursor:
				#cursor.execute(string, values)
				#return cursor.fetchall()

		#else:
			#cursor.execute(string,value)
			#return cursor.fetchall()

		#return False


	@contextmanager
	def Cursor(self):
		conn = self.db.connection()
		conn.begin()
		cursor = conn.cursor()

		try:
			yield cursor

		except self.dbmodule.OperationalError:
			traceback.print_exc()
			conn.rollback()

		finally:
			cursor.close()
			conn.commit()
			conn.close()

	def CreateTable(self, table):
		layout = DotDict(self.tables.get(table))

		if not layout:
			logging.error('Table config doesn\'t exist:', table)
			return

		cmd = f'CREATE TABLE IF NOT EXISTS {table}('
		items = []

		for k, v in layout.items():
			options = ' '.join(v.get('options', []))
			default = v.get('default')
			item = f'{k} {v["type"].upper()} {options}'

			if default:
				item += f' DEFAULT {default}'

			items.append(item)

		cmd += ', '.join(items) + ')'

		return True if self.query(cmd) != False else False


	def RenameTable(self, table, newname):
		self.query(f'ALTER TABLE {table} RENAME TO {newname}')


	def DropTable(self, table):
		self.query(f'DROP TABLE {table}')


	def AddColumn(self, table, name, datatype, default=None, options=None):
		query = f'ALTER TABLE {table} ADD COLUMN {name} {datatype.upper()}'

		if default:
			query += f' DEFAULT {default}'

		if options:
			query += f' {options}'

		self.query(query)


	def CheckDatabase(self, database):
		if self.dbtype == 'postgresql':
			tables = self.query('SELECT datname FROM pg_database')

		else:
			tables = []

		tables = [table[0] for table in tables]
		logging.debug(database in tables, database, tables)

		return database in tables


	def CreateTables(self):
		dbname = self.kwargs['database']

		for name, table in self.tables.items():
			if len(self.query(sql.CheckTable(self.dbtype, name))) < 1:
				logging.info('Creating table:', name)
				self.query(table.sql())

class Table(dict):
	def __init__(self, name, dbtype='sqlite'):
		super().__init__({})

		self.sqlstr = 'CREATE TABLE IF NOT EXISTS {} ({});'
		self.dbtype = dbtype
		self.name = name
		self.primary = None
		self.columns = {}
		self.fkeys = {}


	def addColumn(self, name, datatype=None, null=True, unique=None, primary=None, fkey=None):
		if name == 'id':
			if self.dbtype == 'sqlite':
				datatype = 'integer'
				primary = True if primary == None else primary
				unique = True
				null = False

			else:
				datatype = 'serial'
				primary = 'true'

		if name == 'timestamp':
			datatype = 'float'

		elif not datatype:
			raise MissingTypeError(f'Missing a data type for column: {name}')

		colsql = f'{name} {datatype.upper()}'

		if unique:
			colsql += ' UNIQUE'

		if not null:
			colsql += ' NOT NULL'

		if primary:
			self.primary = name
			colsql += ' PRIMARY KEY'

			if name == 'id' and self.dbtype == 'sqlite':
				colsql += ' AUTOINCREMENT'

		if fkey:
			if self.dbtype == 'postgresql':
				colsql += f' REFERENCES {fkey[0]}({fkey[1]})'

			elif self.dbtype == 'sqlite':
				self.fkeys[name] = f'FOREIGN KEY ({name}) REFERENCES {fkey[0]} ({fkey[1]})'

		self.columns.update({name: colsql})


	def sql(self):
		if not self.primary:
			logging.error('Please specify a primary column')
			return

		data = ', '.join(list(self.columns.values()))

		if self.fkeys:
			data += ', '
			data += ', '.join(list(self.fkeys.values()))

		sqldata = self.sqlstr.format(self.name, data)
		logging.debug(sqldata)
		return sqldata

class Cursor(object):
	def __init__(self, db):
		self.main = db
		self.db = db.db

	@contextmanager
	def begin(self):
		self.conn = self.db.connection()
		self.conn.begin()
		self.cursor = self.conn.cursor()

		try:
			yield self

		except self.main.dbmodule.OperationalError:
			self.conn.rollback()
			raise

		finally:
			self.cursor.close()
			self.conn.commit()
			self.conn.close()


	def query(self, string, values=[]):
		#if not string.endswith(';'):
			#string += ';'

		self.cursor.execute(string, values)
		return self.cursor.fetchall()


	def fetch(self, table, single=True, sort=None, **kwargs):
		rowid = kwargs.get('id')
		querysort = f'ORDER BY {sort}'

		resultOpts = [self, table, self.cursor]

		if rowid:
			self.cursor.execute(f'SELECT * FROM {table} WHERE id = ?', [rowid])

		elif kwargs:
			placeholders = [f'{k} = ?' for k in kwargs.keys()]
			values = kwargs.values()

			where = ' and '.join(placeholders)
			query = f"SELECT * FROM {table} WHERE {where} {querysort if sort else ''}"
			self.cursor.execute(query, list(values))

		else:
			self.cursor.execute(f'SELECT * FROM {table} {querysort if sort else ""}')

		rows = self.cursor.fetchall() if not single else self.cursor.fetchone()

		if rows:
			if single:
				return DBResult(rows, *resultOpts)

			return [DBResult(row, *resultOpts) for row in rows]

		return None if single else []


	def insert(self, table, data={}, **kwargs):
		data.update(kwargs)

		## Add the timestamp before building the query so it actually gets inserted
		if 'timestamp' in self.main.tables[table].keys() and 'timestamp' not in data:
			data['timestamp'] = datetime.now()

		placeholders = ','.join(['?' for _ in data.keys()])
		values = tuple(data.values())
		keys = ','.join(data.keys())

		query = f'INSERT INTO {table} ({keys}) VALUES ({placeholders})'
		self.query(query, values)
		return True


	def remove(self, table, **kwargs):
		keys = []
		values = []

		for k, v in kwargs.items():
			keys.append(k)
			values.append(v)

		keydata = ' AND '.join([f'{k} = ?' for k in keys])
		query = f'DELETE FROM {table} WHERE {keydata}'

		self.query(query, values)


	def update(self, table, rowid, data={}, **kwargs):
		data.update(kwargs)
		newdata = {k: v for k, v in data.items() if k in self.main.tables[table].keys()}
		keys = list(newdata.keys())
		values = list(newdata.values())

		if len(newdata) < 1:
			logging.debug('No data provided to update row')
			return False

		query_data = ', '.join(f'{k} = ?' for k in keys)
		query = f'UPDATE {table} SET {query_data} WHERE id = ?'

		self.query(query, values + [rowid])

class DBResult(DotDict):
	def __init__(self, row, db, table, cursor):
		super().__init__()

		self.db = db
		self.table = table

		for idx, col in enumerate(cursor.description):
			self[col[0]] = row[idx]


	def __setattr__(self, name, value):
		if name not in ['db', 'table']:
			return self.__setitem__(name, value)

		else:
			return super().__setattr__(name, value)


	def __delattr__(self, name):
		if name not in ['db', 'table']:
			return self.__delitem__(name)

		else:
			return super().__delattr__(name)


	def __getattr__(self, value, default=None):
		if value in self.keys():
			val = self[value]
			return DotDict(val) if isinstance(val, dict) else val

		if default != None:
			return default

		raise AttributeError(f'Invalid key: {value}')


	# Kept for backwards compatibility. Delete later.
	def asdict(self):
		return self


	def Update(self, data={}):
		with self.db.Cursor().begin() as cursor:
			self.update(data)
			cursor.update(self.table, self.id, self.asdict())


	def Remove(self):
		with self.db.Cursor().begin() as cursor:
			cursor.remove(self.table, id=self.id)


def ParseData(table, row):
	## Note: this expects a module-level `tables` mapping that doesn't exist here,
	## so it is effectively dead code until it gets ported
	tbdata = tables.get(table)
	types = []
	result = []

	if not tbdata:
		logging.error('Invalid table:', table)
		return

	for v in tbdata.values():
		dtype = v.split()[0].upper()

		if dtype == 'BOOLEAN':
			types.append(Boolean)

		elif dtype in ['INTEGER', 'INT']:
			types.append(int)

		else:
			types.append(str)

	for idx, v in enumerate(row):
		result.append(types[idx](v) if v else None)

	return result

	## Leftover from an earlier revision; `db` and `dbversion` are undefined at module level
	#db.insert('config', {'key': 'version', 'value': dbversion})


class MissingTypeError(Exception):
	pass


class TooManyConnectionsError(Exception):
	pass
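
A minimal usage sketch of the new DB/Table API (not from the commit; the table layout and file path are made up for illustration):

from IzzyLib.database import DB, Table

## Describe a table; Table.sql() builds the CREATE TABLE statement
config = Table('config')
config.addColumn('id')
config.addColumn('key', 'text', null=False, unique=True)
config.addColumn('value', 'text')

db = DB({'config': config}, dbmodule='sqlite', database='/tmp/izzylib-test.sqlite3')
db.CreateTables()

db.insert('config', {'key': 'version', 'value': '1'})
row = db.fetch('config', key='version')
print(row.value)  # rows come back as DBResult objects with attribute access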
23 IzzyLib/error.py Normal file

@@ -0,0 +1,23 @@
class MissingHeadersError(Exception):
	def __init__(self, headers: list):
		self.headers = ', '.join(headers)
		self.message = f'Missing required headers for verification: {self.headers}'
		super().__init__(self.message)


	def __str__(self):
		return self.message


class VerificationError(Exception):
	def __init__(self, string=None):
		self.message = 'Failed to verify hash'

		if string:
			self.message += ' for ' + string

		super().__init__(self.message)


	def __str__(self):
		return self.message
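
These exceptions are raised by the verification helpers in http.py when fail=True; a brief sketch of catching one (hypothetical calling code):

from IzzyLib import error

try:
	raise error.MissingHeadersError(['date', 'host'])

except error.MissingHeadersError as e:
	print(e)  # Missing required headers for verification: date, host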
435 IzzyLib/http.py

@@ -1,207 +1,342 @@
-import traceback, urllib3, json
+import functools, json, sys
+
+from IzzyLib import logging
+from IzzyLib.misc import DefaultDict, DotDict
+from base64 import b64decode, b64encode
+from datetime import datetime
+from urllib.error import HTTPError
 from urllib.parse import urlparse
+from urllib.request import Request, urlopen

-import httpsig
+from . import error

-from Crypto.PublicKey import RSA
-from Crypto.Hash import SHA, SHA256, SHA384, SHA512
-from Crypto.Signature import PKCS1_v1_5
+try:
+	from Crypto.Hash import SHA256
+	from Crypto.PublicKey import RSA
+	from Crypto.Signature import PKCS1_v1_5
+	crypto_enabled = True
+
+except ImportError:
+	logging.verbose('Pycryptodome module not found. HTTP header signing and verifying is disabled')
+	crypto_enabled = False

-from . import logging, __version__
-from .cache import TTLCache, LRUCache
-from .misc import formatUTC
+try:
+	from sanic.request import Request as SanicRequest
+	sanic_enabled = True
+
+except ImportError:
+	logging.verbose('Sanic module not found. Request verification is disabled')
+	sanic_enabled = False


-version = '.'.join([str(num) for num in __version__])
+Client = None
-class httpClient:
-	def __init__(self, pool=100, timeout=30, headers={}, agent=None):
-		self.cache = LRUCache()
-		self.pool = pool
-		self.timeout = timeout
-		self.agent = agent if agent else f'IzzyLib/{version}'
-		self.headers = headers
-
-		self.client = urllib3.PoolManager(num_pools=self.pool, timeout=self.timeout)
-		self.headers['User-Agent'] = self.agent
+def VerifyRequest(request: SanicRequest, actor: dict=None):
+	'''Verify a header signature from a sanic request
+
+	request: The request with the headers to verify
+	actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
+	'''
+	if not sanic_enabled:
+		logging.error('Sanic request verification disabled')
+		return
+
+	if not actor:
+		actor = request.ctx.actor
+
+	body = request.body if request.body else None
+	## Pass actor and body by keyword; positionally the body would land in the actor parameter
+	return VerifyHeaders(request.headers, request.method, request.path, actor=actor, body=body)
-	def _fetch(self, url, headers={}, method='GET', data=None, cached=True):
-		cached_data = self.cache.fetch(url)
-		#url = url.split('#')[0]
-
-		if cached and cached_data:
-			logging.debug(f'Returning cached data for {url}')
-			return cached_data
-
-		if not headers.get('User-Agent'):
-			headers.update({'User-Agent': self.agent})
-
-		logging.debug(f'Fetching new data for {url}')
-
-		try:
-			if data:
-				if isinstance(data, dict):
-					data = json.dumps(data)
-
-				resp = self.client.request(method, url, headers=headers, body=data)
-
-			else:
-				resp = self.client.request(method, url, headers=headers)
-
-		except Exception as e:
-			logging.debug(f'Failed to fetch url: {e}')
-
-		if cached:
-			logging.debug(f'Caching {url}')
-			self.cache.store(url, resp)
-
-		return resp
+def VerifyHeaders(headers: dict, method: str, path: str, actor: dict=None, body=None, fail: bool=False):
+	'''Verify a header signature
+
+	headers: A dictionary containing all the headers from a request
+	method: The HTTP method of the request
+	path: The path of the HTTP request
+	actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
+	body (optional): The body of the request. Only needed if the signature includes the digest header
+	fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
+	'''
+	if not crypto_enabled:
+		logging.error('Crypto functions disabled')
+		return
+
+	headers = {k.lower(): v for k,v in headers.items()}
+	headers['(request-target)'] = f'{method.lower()} {path}'
+	signature = ParseSig(headers.get('signature'))
+	digest = headers.get('digest')
+	## Check the required header names themselves, not the keys that are already present
+	missing_headers = [k for k in ['date', 'host'] if headers.get(k) == None]
+
+	if not signature:
+		if fail:
+			raise error.MissingHeadersError(['signature'])
+
+		return False
+
+	if not actor:
+		actor = FetchActor(signature.keyid)
+
+	## Add digest header to missing headers list if it doesn't exist
+	if method.lower() == 'post' and not headers.get('digest'):
+		missing_headers.append('digest')
+
+	## Fail if missing date, host or digest (if POST) headers
+	if missing_headers:
+		if fail:
+			raise error.MissingHeadersError(missing_headers)
+
+		return False
+
+	## Fail if body verification fails
+	if digest and not VerifyString(body, digest):
+		if fail:
+			raise error.VerificationError('digest header')
+
+		return False
+
+	pubkey = actor.publicKey['publicKeyPem']
+
+	if PkcsHeaders(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature):
+		return True
+
+	if fail:
+		raise error.VerificationError('headers')
+
+	return False
+
+
+@functools.lru_cache(maxsize=512)
+def VerifyString(string, enc_string, alg='SHA256', fail=False):
+	if not crypto_enabled:
+		logging.error('Crypto functions disabled')
+		return
+
+	if type(string) != bytes:
+		string = string.encode('UTF-8')
+
+	body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
+
+	if body_hash == enc_string:
+		return True
+
+	if fail:
+		raise error.VerificationError()
+
+	return False
+
+
+def PkcsHeaders(key: str, headers: dict, sig=None):
+	if not crypto_enabled:
+		logging.error('Crypto functions disabled')
+		return
+
+	if sig:
+		head_items = [f'{item}: {headers[item]}' for item in sig.headers]
+
+	else:
+		head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
+
+	head_string = '\n'.join(head_items)
+	head_bytes = head_string.encode('UTF-8')
+
+	KEY = RSA.importKey(key)
+	pkcs = PKCS1_v1_5.new(KEY)
+	h = SHA256.new(head_bytes)
+
+	if sig:
+		return pkcs.verify(h, b64decode(sig.signature))
+
+	else:
+		return pkcs.sign(h)
-	def raw(self, *args, **kwargs):
-		'''
-		Return a response object
-		'''
-		return self._fetch(*args, **kwargs)
-
-
-	def text(self, *args, **kwargs):
-		'''
-		Return the body as text
-		'''
-		resp = self._fetch(*args, **kwargs)
-
-		return resp.data.decode() if resp else None
-
-
-	def json(self, *args, **kwargs):
-		'''
-		Return the body as a dict if it's json
-		'''
-		headers = kwargs.get('headers')
-
-		if not headers:
-			kwargs['headers'] = {}
-
-		kwargs['headers'].update({'Accept': 'application/json'})
-		resp = self._fetch(*args, **kwargs)
-
-		try:
-			data = json.loads(resp.data.decode())
-
-		except Exception as e:
-			logging.debug(f'Failed to load json: {e}')
-			return
-
-		return data
-def ParseSig(headers):
-	sig_header = headers.get('signature')
-
-	if not sig_header:
+def ParseSig(signature: str):
+	if not signature:
 		logging.verbose('Missing signature header')
 		return

-	split_sig = sig_header.split(',')
-	signature = {}
+	split_sig = signature.split(',')
+	sig = DefaultDict({})

 	for part in split_sig:
 		key, value = part.split('=', 1)
-		signature[key.lower()] = value.replace('"', '')
+		sig[key.lower()] = value.replace('"', '')

-	if not signature.get('headers'):
+	if not sig.headers:
 		logging.verbose('Missing headers section in signature')
 		return

-	signature['headers'] = signature['headers'].split()
+	sig.headers = sig.headers.split()

-	return signature
+	return sig
-def SignHeaders(headers, keyid, privkey, url, method='GET'):
-	'''
-	Signs headers and returns them with a signature header
-
-	headers (dict): Headers to be signed
-	keyid (str): Url to the public key used to verify the signature
-	privkey (str): Private key used to sign the headers
-	url (str): Url of the request for the signed headers
-	method (str): Http method of the request for the signed headers
-	'''
-	RSAkey = RSA.import_key(privkey)
-	key_size = int(RSAkey.size_in_bytes()/2)
-	logging.debug('Signing key size:', key_size)
-
-	parsed_url = urlparse(url)
-	logging.debug(parsed_url)
-
-	raw_headers = {'date': formatUTC(), 'host': parsed_url.netloc, '(request-target)': ' '.join([method, parsed_url.path])}
-	raw_headers.update(dict(headers))
-	header_keys = raw_headers.keys()
-
-	signer = httpsig.HeaderSigner(keyid, privkey, f'rsa-sha{key_size}', headers=header_keys, sign_header='signature')
-	new_headers = signer.sign(raw_headers, parsed_url.netloc, method, parsed_url.path)
-	logging.debug('Signed headers:', new_headers)
-
-	del new_headers['(request-target)']
-
-	return new_headers
+@functools.lru_cache(maxsize=512)
+def FetchActor(keyid, client=None):
+	if not client:
+		client = Client if Client else HttpClient()
+
+	## Use the resolved local client, not the module-level Client (which may be None)
+	actor = client.request(keyid).json()
+	actor.domain = urlparse(actor.id).netloc
+	actor.shared_inbox = actor.inbox
+	actor.pubkey = None
+
+	if actor.get('endpoints'):
+		actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
+
+	if actor.get('publicKey'):
+		actor.pubkey = actor.publicKey.get('publicKeyPem')
+
+	return actor
-def ValidateSignature(headers, method, path, client=None, agent=None):
-	'''
-	Validates the signature header.
-
-	headers (dict): All of the headers to be used to check a signature. The signature header must be included too
-	method (str): The http method used in relation to the headers
-	path (str): The path of the request in relation to the headers
-	client (pool object): Specify a httpClient to use for fetching the actor. optional
-	agent (str): User agent used for fetching actor data. optional
-	'''
-	client = httpClient(agent=agent) if not client else client
-	headers = {k.lower(): v for k,v in headers.items()}
-
-	signature = ParseSig(headers)
-
-	actor_data = client.json(signature['keyid'])
-	logging.debug(actor_data)
-
-	try:
-		pubkey = actor_data['publicKey']['publicKeyPem']
-
-	except Exception as e:
-		logging.verbose(f'Failed to get public key for actor {signature["keyid"]}')
-		return
-
-	valid = httpsig.HeaderVerifier(headers, pubkey, signature['headers'], method, path, sign_header='signature').verify()
-
-	if not valid:
-		if not isinstance(valid, tuple):
-			logging.verbose('Signature validation failed for unknown actor')
-			logging.verbose(valid)
-			if fail:
-				raise e from None
-
-		else:
-			logging.verbose(f'Signature validation failed for actor: {valid[1]}')
-
-		return
-
-	else:
-		return True
-
-
-def ValidateRequest(request, client=None, agent=None):
-	'''
-	Validates the headers in a Sanic or Aiohttp request (other frameworks may be supported)
-	See ValidateSignature for 'client' and 'agent' usage
-	'''
-	return ValidateSignature(request.headers, request.method, request.path, client, agent)
+class HttpClient(object):
+	def __init__(self, headers={}, useragent='IzzyLib/0.3', proxy_type='https', proxy_host=None, proxy_port=None):
+		proxy_ports = {
+			'http': 80,
+			'https': 443
+		}
+
+		if proxy_type not in ['http', 'https']:
+			raise ValueError(f'Not a valid proxy type: {proxy_type}')
+
+		self.headers = headers
+		self.agent = useragent
+		self.proxy = DotDict({
+			'enabled': True if proxy_host else False,
+			'ptype': proxy_type,
+			'host': proxy_host,
+			'port': proxy_ports[proxy_type] if not proxy_port else proxy_port
+		})
+
+
+	def __sign_request(self, request, privkey, keyid):
+		if not crypto_enabled:
+			logging.error('Crypto functions disabled')
+			return
+
+		request.add_header('(request-target)', f'{request.method.lower()} {request.path}')
+		request.add_header('host', request.host)
+		request.add_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
+
+		if request.body:
+			body_hash = b64encode(SHA256.new(request.body).digest()).decode('UTF-8')
+			request.add_header('digest', f'SHA-256={body_hash}')
+			request.add_header('content-length', len(request.body))
+
+		sig = {
+			'keyId': keyid,
+			'algorithm': 'rsa-sha256',
+			'headers': ' '.join([k.lower() for k in request.headers.keys()]),
+			'signature': b64encode(PkcsHeaders(privkey, request.headers)).decode('UTF-8')
+		}
+
+		sig_items = [f'{k}="{v}"' for k,v in sig.items()]
+		sig_string = ','.join(sig_items)
+
+		request.add_header('signature', sig_string)
+
+		request.remove_header('(request-target)')
+		request.remove_header('host')
+
+
+	def __build_request(self, url, data=None, headers={}, method='GET'):
+		new_headers = self.headers.copy()
+		new_headers.update(headers)
+
+		parsed_headers = {k.lower(): v.lower() for k,v in new_headers.items()}
+
+		if not parsed_headers.get('user-agent'):
+			parsed_headers['user-agent'] = self.agent
+
+		if isinstance(data, dict):
+			data = json.dumps(data)
+
+		if isinstance(data, str):
+			data = data.encode('UTF-8')
+
+		request = HttpRequest(url, data=data, headers=parsed_headers, method=method)
+
+		if self.proxy.enabled:
+			## Use the configured port here; the original passed the host twice
+			request.set_proxy(f'{self.proxy.host}:{self.proxy.port}', self.proxy.ptype)
+
+		return request
+
+
+	def request(self, *args, **kwargs):
+		request = self.__build_request(*args, **kwargs)
+
+		try:
+			response = urlopen(request)
+		except HTTPError as e:
+			response = e.fp
+
+		return HttpResponse(response)
+
+
+	def json(self, *args, headers={}, activity=True, **kwargs):
+		json_type = 'activity+json' if activity else 'json'
+		headers.update({
+			'accept': f'application/{json_type}'
+		})
+		return self.request(*args, headers=headers, **kwargs)
+
+
+	def signed_request(self, privkey, keyid, *args, **kwargs):
+		request = self.__build_request(*args, **kwargs)
+
+		self.__sign_request(request, privkey, keyid)
+
+		try:
+			response = urlopen(request)
+		except HTTPError as e:
+			response = e
+
+		return HttpResponse(response)
+
+
+class HttpRequest(Request):
+	def __init__(self, *args, **kwargs):
+		super().__init__(*args, **kwargs)
+		parsed = urlparse(self.full_url)
+
+		self.scheme = parsed.scheme
+		self.host = parsed.netloc
+		self.domain = parsed.hostname
+		self.port = parsed.port
+		self.path = parsed.path
+		self.query = parsed.query
+		self.body = self.data if self.data else b''
+
+
+class HttpResponse(object):
+	def __init__(self, response):
+		self.body = response.read()
+		self.headers = DefaultDict({k.lower(): v.lower() for k,v in response.headers.items()})
+		self.status = response.status
+		self.url = response.url
+
+
+	def text(self):
+		return self.body.decode('UTF-8')
+
+
+	def json(self, fail=False):
+		try:
+			return DotDict(self.text())
+		except Exception as e:
+			return DotDict()
+
+
+	def json_pretty(self, indent=4):
+		return json.dumps(self.json().asDict(), indent=indent)
+
+
+def SetClient(client: HttpClient):
+	global Client
+	Client = client
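
A rough sketch of how the new client and signing helpers fit together (not from the commit; the key, key id, and URLs are placeholders):

from Crypto.PublicKey import RSA
from IzzyLib import http

privkey = RSA.generate(2048).export_key().decode('UTF-8')
keyid = 'https://example.com/actor#main-key'

client = http.HttpClient(useragent='MyApp/0.1')
http.SetClient(client)  # used as the default by FetchActor

## Plain fetch
resp = client.request('https://example.com/.well-known/nodeinfo')
print(resp.status, resp.json())

## Signed POST; __sign_request adds date, host, digest, and signature headers
resp = client.signed_request(privkey, keyid, 'https://example.com/inbox', data={'type': 'Follow'}, method='POST')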
IzzyLib/logging.py

@@ -25,7 +25,17 @@ class Log():
 			'MERP': 0
 		}

-		self.config = dict()
+		self.long_levels = {
+			'CRITICAL': 'CRIT',
+			'ERROR': 'ERROR',
+			'WARNING': 'WARN',
+			'INFO': 'INFO',
+			'VERBOSE': 'VERB',
+			'DEBUG': 'DEBUG',
+			'MERP': 'MERP'
+		}
+
+		self.config = {'windows': sys.executable.endswith('pythonw.exe')}
 		self.setConfig(self._parseConfig(config))
@@ -35,10 +45,11 @@ class Log():
 			value = int(level)

 		except ValueError:
-			value = self.levels.get(level.upper())
+			level = self.long_levels.get(level.upper(), level)
+			value = self.levels.get(level)

 		if value not in self.levels.values():
-			raise error.InvalidLevel(f'Invalid logging level: {level}')
+			raise InvalidLevel(f'Invalid logging level: {level}')

 		return value
@@ -48,13 +59,14 @@ class Log():
 			if level == num:
 				return name

-		raise error.InvalidLevel(f'Invalid logging level: {level}')
+		raise InvalidLevel(f'Invalid logging level: {level}')


 	def _parseConfig(self, config):
 		'''parse the new config and update the old values'''
 		date = config.get('date', self.config.get('date', True))
 		systemd = config.get('systemd', self.config.get('systemd', True))
+		windows = config.get('windows', self.config.get('windows', False))

 		if not isinstance(date, bool):
 			raise TypeError(f'value for "date" is not a boolean: {date}')
@@ -64,14 +76,18 @@ class Log():

 		level_num = self._lvlCheck(config.get('level', self.config.get('level', 'INFO')))

-		return {
+		newconfig = {
 			'level': self._getLevelName(level_num),
 			'levelnum': level_num,
 			'datefmt': config.get('datefmt', self.config.get('datefmt', '%Y-%m-%d %H:%M:%S')),
 			'date': date,
-			'systemd': systemd
+			'systemd': systemd,
+			'windows': windows,
+			'systemnotif': config.get('systemnotif', None)
 		}

+		return newconfig


 	def setConfig(self, config):
 		'''set the config'''
@@ -81,8 +97,8 @@ class Log():
 	def getConfig(self, key=None):
 		'''return the current config'''
 		if key:
-			if self.get(key):
-				return self.get(key)
+			if self.config.get(key):
+				return self.config.get(key)
 			else:
 				raise ValueError(f'Invalid config option: {key}')
 		return self.config
@@ -95,7 +111,14 @@ class Log():
 		stdout.flush()


+	def setLevel(self, level):
+		self.minimum = self._lvlCheck(level)
+
+
 	def log(self, level, *msg):
 		'''log to the console'''
+		if self.config['windows']:
+			return
+
 		levelNum = self._lvlCheck(level)
@@ -108,6 +131,9 @@ class Log():
 		message = ' '.join([str(message) for message in msg])
 		output = f'{level}: {message}\n'

+		if self.config['systemnotif']:
+			self.config['systemnotif'].New(level, message)
+
 		if self.config['date'] and (self.config['systemd'] and not env.get('INVOCATION_ID')):
 			'''only show date when not running in systemd and date var is True'''
 			date = datetime.now().strftime(self.config['datefmt'])
@@ -148,17 +174,14 @@ def getLogger(loginst, config=None):
 		logger[loginst] = Log(config)

 	else:
-		raise error.InvalidLogger(f'logger "{loginst}" doesn\'t exist')
+		raise InvalidLogger(f'logger "{loginst}" doesn\'t exist')

 	return logger[loginst]


-class error:
-	'''base class for all errors'''
-
-	class InvalidLevel(Exception):
-		'''Raise when an invalid logging level was specified'''
-
-	class InvalidLogger(Exception):
-		'''Raise when the specified logger doesn't exist'''
+class InvalidLevel(Exception):
+	'''Raise when an invalid logging level was specified'''
+
+class InvalidLogger(Exception):
+	'''Raise when the specified logger doesn't exist'''
@@ -182,4 +205,5 @@ merp = DefaultLog.merp
 '''aliases for the default logger's config functions'''
 setConfig = DefaultLog.setConfig
 getConfig = DefaultLog.getConfig
+setLevel = DefaultLog.setLevel
 printConfig = DefaultLog.printConfig
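
A short usage sketch of the logger with the new options (assumes these hunks are IzzyLib/logging.py and that the usual per-level methods exist on Log; systemnotif is any object exposing a New(level, message) method):

from IzzyLib import logging

log = logging.getLogger('myapp', {'level': 'DEBUG', 'date': True})
log.setLevel('VERBOSE')   # new in this commit; also aliased on the default logger
log.info('server started')
log.debug('value:', 42)   # multiple args are joined with spaces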
480 IzzyLib/misc.py
@@ -1,16 +1,16 @@
 '''Miscellaneous functions'''
-import random, string, sys, os, shlex, subprocess, socket, traceback
+import random, string, sys, os, json, socket

 from os.path import abspath, dirname, basename, isdir, isfile
 from os import environ as env
 from datetime import datetime
 from collections import namedtuple
+from getpass import getpass
+from pathlib import Path as Pathlib

 from . import logging


-def boolean(v, fail=True):
-	if type(v) in [dict, list, tuple]:
+def Boolean(v, return_value=False):
+	if type(v) not in [str, bool, int, type(None)]:
+		raise ValueError(f'Value is not a string, boolean, int, or nonetype: {v}')

 	'''make the value lowercase if it's a string'''
@@ -20,101 +20,50 @@ def Boolean(v, return_value=False):
 		'''convert string to True'''
 		return True

-	elif value in [0, False, None, 'off', 'n', 'no', 'false', 'disable', 'none', '']:
+	if value in [0, False, None, 'off', 'n', 'no', 'false', 'disable', '']:
 		'''convert string to False'''
 		return False

-	elif not fail:
+	if return_value:
 		'''just return the value'''
-		return value
+		return v

-	else:
-		raise ValueError(f'Value cannot be converted to a boolean: {value}')
+	return True


-def randomgen(chars=20):
-	if not isinstance(chars, int):
-		raise TypeError(f'Character length must be an integer, not a {type(char)}')
+def RandomGen(length=20, chars=None):
+	if not isinstance(length, int):
+		raise TypeError(f'Character length must be an integer, not {type(length)}')

-	return ''.join(random.choices(string.ascii_letters + string.digits, k=chars))
+	characters = string.ascii_letters + string.digits
+
+	if chars:
+		characters += chars
+
+	return ''.join(random.choices(characters, k=length))


-def formatUTC(timestamp=None, ap=False):
-	date = datetime.fromtimestamp(timestamp) if timestamp else datetime.utcnow()
+def Timestamp(dtobj=None, utc=False):
+	dtime = dtobj if dtobj else datetime
+	date = dtime.utcnow() if utc else dtime.now()

-	if ap:
-		return date.strftime('%Y-%m-%dT%H:%M:%SZ')
-
-	return date.strftime('%a, %d %b %Y %H:%M:%S GMT')
+	return date.timestamp()


-def config_dir(modpath=None):
-	if env.get('CONFDIR'):
-		'''set the storage path to the environment variable if it exists'''
-		stor_path = abspath(env['CONFDIR'])
-
-	else:
-		stor_path = f'{os.getcwd()}'
-
-	if modpath and not env.get('CONFDIR'):
-		modname = basename(dirname(modpath))
-
-		if isdir(f'{stor_path}/{modname}'):
-			'''set the storage path to CWD/data if the module or script is in the working dir'''
-			stor_path += '/data'
-
-	if not isdir(stor_path):
-		os.makedirs(stor_path, exist_ok=True)
-
-	return stor_path
+def ApDate(date=None, alt=False):
+	if not date:
+		date = datetime.utcnow()
+
+	elif type(date) == int:
+		date = datetime.fromtimestamp(date)
+
+	elif type(date) != datetime:
+		raise TypeError(f'Unsupported object type for ApDate: {type(date)}')
+
+	return date.strftime('%a, %d %b %Y %H:%M:%S GMT' if alt else '%Y-%m-%dT%H:%M:%SZ')


-def getBin(filename):
-	for pathdir in env['PATH'].split(':'):
-		fullpath = os.path.join(pathdir, filename)
-
-		if os.path.isfile(fullpath):
-			return fullpath
-
-	raise FileNotFoundError(f'Cannot find {filename} in path.')
-
-
-def Try(funct, *args, **kwargs):
-	Result = namedtuple('Result', 'result exception')
-	out = None
-	exc = None
-
-	try:
-		out = funct(*args, **kwargs)
-	except Exception as e:
-		traceback.print_exc()
-		exc = e
-
-	return Result(out, exc)
-
-
-def sudo(cmd, password, user=None):
-	### Please don't put your password in plain text in a script
-	### Use a module like 'getpass' to get the password instead
-	if isinstance(cmd, list):
-		cmd = ' '.join(cmd)
-
-	elif not isinstance(cmd, str):
-		raise ValueError('Command is not a list or string')
-
-	euser = os.environ.get('USER')
-
-	cmd = ' '.join(['sudo', '-u', user, cmd]) if user and euser != user else 'sudo ' + cmd
-	sudocmd = ' '.join(['echo', f'{password}', '|', cmd])
-	proc = subprocess.Popen(['/usr/bin/env', 'bash', '-c', sudocmd])
-
-	return proc
-
-
-def getip():
-	# Get the main IP address of the machine
+def GetIp():
 	s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

 	try:
@@ -131,6 +80,367 @@ def GetIp():
 	return ip


-def merp():
-	log = logging.getLogger('merp-heck', {'level': 'merp', 'date': False})
-	log.merp('heck')
+def Input(prompt, default=None, valtype=str, options=[], password=False):
+	input_func = getpass if password else input
+
+	if default != None:
+		prompt += ' [-redacted-]' if password else f' [{default}]'
+
+	prompt += '\n'
+
+	if options:
+		opt = '/'.join(options)
+		prompt += f'[{opt}]'
+
+	prompt += ': '
+
+	value = input_func(prompt)
+
+	while value and options and value not in options:
+		## Print the error; input_func would block waiting for input here
+		print('Invalid value:', value)
+		value = input(prompt)
+
+	if not value:
+		return default
+
+	ret = valtype(value)
+
+	while valtype == Path and not ret.parent().exists():
+		print('Parent directory doesn\'t exist')
+		ret = Path(input(prompt))
+
+	return ret
+
+
+class DotDict(dict):
+	def __init__(self, value=None, **kwargs):
+		super().__init__()
+
+		if type(value) in [str, bytes]:
+			self.fromJson(value)
+
+		elif type(value) in [dict, DotDict, DefaultDict]:
+			self.update(value)
+
+		elif value:
+			raise TypeError('The value must be a JSON string, dict, or another DotDict object, not', value.__class__)
+
+		if kwargs:
+			self.update(kwargs)
+
+
+	def __getattr__(self, key):
+		try:
+			val = super().__getattribute__(key)
+
+		except AttributeError:
+			val = self.get(key, InvalidKey())
+
+		if type(val) == InvalidKey:
+			raise KeyError(f'Invalid key: {key}')
+
+		return DotDict(val) if type(val) == dict else val
+
+
+	def __delattr__(self, key):
+		if self.get(key):
+			del self[key]
+			return
+
+		super().__delattr__(key)
+
+
+	#def __delitem__(self, key):
+		#print('delitem', key)
+		#self.__delattr__(key)
+
+
+	def __setattr__(self, key, value):
+		if key.startswith('_'):
+			super().__setattr__(key, value)
+
+		else:
+			super().__setitem__(key, value)
+
+
+	def __parse_item__(self, k, v):
+		if type(v) == dict:
+			v = DotDict(v)
+
+		if not k.startswith('_'):
+			return (k, v)
+
+
+	def get(self, key, default=None):
+		value = super().get(key, default)
+		return DotDict(value) if type(value) == dict else value
+
+
+	def items(self):
+		data = []
+
+		for k, v in super().items():
+			new = self.__parse_item__(k, v)
+
+			if new:
+				data.append(new)
+
+		return data
+
+
+	def values(self):
+		return list(super().values())
+
+
+	def keys(self):
+		return list(super().keys())
+
+
+	def asDict(self):
+		return dict(self)
+
+
+	def toJson(self, indent=None, **kwargs):
+		data = {}
+
+		for k, v in self.items():
+			if type(k) in [DotDict, Path, Pathlib]:
+				k = str(k)
+
+			if type(v) in [DotDict, Path, Pathlib]:
+				v = str(v)
+
+			data[k] = v
+
+		return json.dumps(data, indent=indent, **kwargs)
+
+
+	def fromJson(self, string):
+		data = json.loads(string)
+		self.update(data)
+
+
+class DefaultDict(DotDict):
+	def __getattr__(self, key):
+		try:
+			val = super().__getattribute__(key)
+
+		except AttributeError:
+			val = self.get(key, DefaultDict())
+
+		return DotDict(val) if type(val) == dict else val
+
+
+class Path(object):
+	def __init__(self, path, exist=True, missing=True, parents=True):
+		self.__path = Pathlib(str(path)).resolve()
+		self.json = DotDict({})
+		self.exist = exist
+		self.missing = missing
+		self.parents = parents
+		self.name = self.__path.name
+
+
+	#def __getattr__(self, key):
+		#try:
+			#attr = getattr(self.__path, key)
+
+		#except AttributeError:
+			#attr = getattr(self, key)
+
+		#return attr
+
+
+	def __str__(self):
+		return str(self.__path)
+
+
+	def str(self):
+		return self.__str__()
+
+
+	def __check_dir(self, path=None):
+		target = self if not path else Path(path)
+
+		if not self.parents and not target.parent().exists():
+			raise FileNotFoundError('Parent directories do not exist:', target.str())
+
+		if not self.exist and target.exists():
+			raise FileExistsError('File or directory already exists:', target.str())
+
+
+	def size(self):
+		return self.__path.stat().st_size
+
+
+	def mtime(self):
+		return self.__path.stat().st_mtime
+
+
+	def mkdir(self, mode=0o755):
+		self.__path.mkdir(mode, self.parents, self.exist)
+
+		return True if self.__path.exists() else False
+
+
+	def new(self):
+		return Path(self.__path)
+
+
+	def parent(self, new=True):
+		path = Pathlib(self.__path).parent
+
+		if new:
+			return Path(path)
+
+		self.__path = path
+
+		return self
+
+
+	def move(self, path):
+		target = Path(path)
+
+		self.__check_dir(path)
+
+		if target.exists() and not target.isdir():
+			target.delete()
+
+
+	def join(self, path, new=True):
+		new_path = self.__path.joinpath(path).resolve()
+
+		if new:
+			return Path(new_path)
+
+		self.__path = new_path
+
+		return self
+
+
+	def home(self, path=None, new=True):
+		new_path = Pathlib.home()
+
+		if path:
+			new_path = new_path.joinpath(path)
+
+		if new:
+			return Path(new_path)
+
+		self.__path = new_path
+		return self
+
+
+	def isdir(self):
+		return self.__path.is_dir()
+
+
+	def isfile(self):
+		return self.__path.is_file()
+
+
+	def islink(self):
+		return self.__path.is_symlink()
+
+
+	def listdir(self):
+		return [Path(path) for path in self.__path.iterdir()]
+
+
+	def exists(self):
+		return self.__path.exists()
+
+
+	def link(self, path):
+		target = Path(path)
+
+		self.__check_dir(path)
+
+		if target.exists():
+			target.delete()
+
+		self.__path.symlink_to(path, target.isdir())
+
+
+	def resolve(self, new=True):
+		path = self.__path.resolve()
+
+		if new:
+			return Path(path)
+
+		self.__path = path
+		return self
+
+
+	def touch(self, mode=0o666):
+		self.__path.touch(mode, self.exist)
+
+		return True if self.__path.exists() else False
+
+
+	def loadJson(self):
+		self.json = DotDict(self.read())
+
+		return self.json
+
+
+	def updateJson(self, data={}):
+		if type(data) == str:
+			data = json.loads(data)
+
+		self.json.update(data)
+
+
+	def storeJson(self, indent=None):
+		with self.__path.open('w') as fp:
+			fp.write(json.dumps(self.json.asDict(), indent=indent))
+
+
+	# This needs to be extended to handle dirs with files/sub-dirs
+	def delete(self):
+		if self.isdir():
+			## Pathlib's rmdir and unlink take no positional arguments
+			self.__path.rmdir()
+
+		else:
+			self.__path.unlink()
+
+		return not self.exists()
+
+
+	def open(self, *args):
+		return self.__path.open(*args)
+
+
+	def read(self, *args):
+		return self.open().read(*args)
+
+
+	def readlines(self):
+		return self.open().readlines()
+
+
+	## def rmdir():
+
+
+def NfsCheck(path):
+	proc = Path('/proc/mounts')
+	path = Path(path).resolve()
+
+	if not proc.exists():
+		return True
+
+	with proc.open() as fd:
+		for line in fd:
+			line = line.split()
+
+			if line[2] == 'nfs' and line[1] in path.str():
+				return False
+
+	return True
+
+
+class InvalidKey(object):
+	pass
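
A quick sketch of the new helpers above (not from the commit; paths and values are made up):

from IzzyLib.misc import DotDict, Path, Boolean, RandomGen

data = DotDict('{"name": "izzy", "info": {"role": "admin"}}')
print(data.name)       # keys are attributes
print(data.info.role)  # nested dicts come back as DotDicts
print(data.toJson())

cfg = Path('/tmp/myapp/config.json')
cfg.parent().mkdir()   # creates parents by default

print(Boolean('false'), Boolean(True))  # False True (assuming the usual truthy list)
print(RandomGen(8))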
IzzyLib/template.py

@@ -1,233 +1,131 @@
 '''functions for web template management and rendering'''
-import codecs, traceback, os, json
+import codecs, traceback, os, json, xml.etree.ElementTree

-from os import listdir, makedirs
-from os.path import isfile, isdir, getmtime, abspath
-
-from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape
-from hamlpy.hamlpy import Compiler
+from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape, Markup
+from hamlish_jinja import HamlishExtension
 from markdown import markdown
-from watchdog.observers import Observer
-from watchdog.events import FileSystemEventHandler
+from xml.dom import minidom
+
+try:
+	from sanic import response as Response
+
+except ModuleNotFoundError:
+	Response = None

 from . import logging
 from .color import *

-framework = 'sanic'
-
-try:
-	import sanic
-except:
-	logging.debug('Cannot find Sanic')
-
-try:
-	import aiohttp
-except:
-	logging.debug('Cannot find aioHTTP')
+from .misc import Path, DotDict

-env = None
-global_variables = {}
-search_path = list()
-build_path_pairs = dict()
-
-
-def addSearchPath(path):
-	tplPath = abspath(path)
-
-	if not isdir(tplPath):
-		raise FileNotFoundError(f'Cannot find template directory: {tplPath}')
-
-	if tplPath not in search_path:
-		search_path.append(tplPath)
-
-
-def delSearchPath(path):
-	tplPath = abspath(path)
-
-	if tplPath in search_path:
-		search_path.remove(tplPath)
-
-
-def addBuildPath(name, source, destination):
-	src = abspath(source)
-	dest = abspath(destination)
-
-	if not isdir(src):
-		raise FileNotFoundError(f'Source path doesn\'t exist: {src}')
-
-	build_path_pairs.update({
-		name: {
-			'source': src,
-			'destination': dest
-		}
-	})
-
-	addSearchPath(dest)
-
-
-def delBuildPath(name):
-	if not build_path_pairs.get(name):
-		raise ValueError(f'"{name}" not in build paths')
-
-	del build_path_pairs[src]
-
-
-def getBuildPath(name=None):
-	template = build_path_pairs.get(name)
-
-	if name:
-		if template:
-			return template
-
-		else:
-			raise ValueError(f'"{name}" not in build paths')
-
-	return build_path_pairs
-
-
-def addEnv(data):
-	if not isinstance(data, dict):
-		raise TypeError(f'environment data is not a dict')
-
-	global_variables.update(data)
-
-
-def delEnv(var):
-	if not global_variables.get(var):
-		raise ValueError(f'"{var}" not in global variables')
-
-	del global_variables[var]
-
-
-def setup(fwork='sanic'):
-	global env
-	global framework
-
-	framework = fwork
-	env = Environment(
-		loader=ChoiceLoader([FileSystemLoader(path) for path in search_path]),
-		autoescape=select_autoescape(['html', 'css']),
-		lstrip_blocks=True,
-		trim_blocks=True
-	)
-
-
-def renderTemplate(tplfile, context={}, request=None, headers=dict(), cookies=dict(), **kwargs):
-	if not isinstance(context, dict):
-		raise TypeError(f'context for {tplfile} not a dict')
-
-	data = global_variables.copy()
-	data['request'] = request if request else {'headers': headers, 'cookies': cookies}
-	data.update(context)
-
-	return env.get_template(tplfile).render(data)
-
-
-def sendResponse(template, request, context=dict(), status=200, ctype='text/html', headers=dict(), **kwargs):
-	context['request'] = request
-	html = renderTemplate(template, context, **kwargs)
-
-	if framework == 'sanic':
-		return sanic.response.text(html, status=status, headers=headers, content_type=ctype)
-
-	elif framework == 'aiohttp':
-		return aiohttp.web.Response(body=html, status=status, headers=headers, content_type=ctype)
-
-	else:
-		logging.error('Please install aiohttp or sanic. Response not sent.')
-
-
-# delete me later
-aiohttpTemplate = sendResponse
-
-
-def buildTemplates(name=None):
-	paths = getBuildPath(name)
-
-	for k, tplPaths in paths.items():
-		src = tplPaths['source']
-		dest = tplPaths['destination']
-
-		timefile = f'{dest}/times.json'
-		updated = False
-
-		if not isdir(f'{dest}'):
-			makedirs(f'{dest}')
-
-		if isfile(timefile):
-			try:
-				times = json.load(open(timefile))
-
-			except:
-				times = {}
-
-		else:
-			times = {}
-
-		for filename in listdir(src):
-			fullPath = f'{src}/{filename}'
-			modtime = getmtime(fullPath)
-			base, ext = filename.split('.', 1)
-
-			if ext != 'haml':
-				pass
-
-			elif base not in times or times.get(base) != modtime:
-				updated = True
-				logging.verbose(f"Template '{filename}' was changed. Building...")
-
-				try:
-					destination = f'{dest}/{base}.html'
-					haml_lines = codecs.open(fullPath, 'r', encoding='utf-8').read().splitlines()
-
-					compiler = Compiler()
-					output = compiler.process_lines(haml_lines)
-					outfile = codecs.open(destination, 'w', encoding='utf-8')
-					outfile.write(output)
-
-					logging.info(f"Template '{filename}' has been built")
-
-				except Exception as e:
-					'''I'm actually not sure what sort of errors can happen here, so generic catch-all for now'''
-					traceback.print_exc()
-					logging.error(f'Failed to build {filename}: {e}')
-
-			times[base] = modtime
-
-		if updated:
-			with open(timefile, 'w') as filename:
-				filename.write(json.dumps(times))
-
-
-def templateWatcher():
-	watchPaths = [path['source'] for k, path in build_path_pairs.items()]
-	logging.info('Starting template watcher')
-	observer = Observer()
-
-	for tplpath in watchPaths:
-		logging.debug(f'Watching template dir for changes: {tplpath}')
-		observer.schedule(templateWatchHandler(), tplpath, recursive=True)
-
-	return observer
-
-
-class templateWatchHandler(FileSystemEventHandler):
-	def on_any_event(self, event):
-		filename, ext = os.path.splitext(os.path.relpath(event.src_path))
-
-		if event.event_type in ['modified', 'created'] and ext[1:] == 'haml':
-			logging.info('Rebuilding templates')
-			buildTemplates()
-
-
-__all__ = ['addSearchPath', 'delSearchPath', 'addBuildPath', 'delSearchPath', 'addEnv', 'delEnv', 'setup', 'renderTemplate', 'sendResponse', 'buildTemplates', 'templateWatcher']
+class Template(Environment):
+	def __init__(self, search=[], global_vars={}, autoescape=True):
+		self.autoescape = autoescape
+		self.watcher = None
+		self.search = []
+		self.func_context = None
+
+		for path in search:
+			self.__add_search_path(path)
+
+		super().__init__(
+			loader=ChoiceLoader([FileSystemLoader(path) for path in self.search]),
+			extensions=[HamlishExtension],
+			autoescape=self.autoescape,
+			lstrip_blocks=True,
+			trim_blocks=True
+		)
+
+		self.hamlish_file_extensions=('.haml',)
+		self.hamlish_enable_div_shortcut=True
+		self.hamlish_mode = 'indented'
+
+		self.globals.update({
+			'markdown': markdown,
+			'markup': Markup,
+			'cleanhtml': lambda text: ''.join(xml.etree.ElementTree.fromstring(text).itertext()),
+			'lighten': lighten,
+			'darken': darken,
+			'saturate': saturate,
+			'desaturate': desaturate,
+			'rgba': rgba
+		})
+
+		self.globals.update(global_vars)
+
+
+	def __add_search_path(self, path):
+		tpl_path = Path(path)
+
+		if not tpl_path.exists():
+			raise FileNotFoundError('Cannot find search path:', tpl_path.str())
+
+		if tpl_path.str() not in self.search:
+			self.search.append(tpl_path.str())
+
+
+	def addEnv(self, k, v):
+		self.globals[k] = v
+
+
+	def delEnv(self, var):
+		if not self.globals.get(var):
+			raise ValueError(f'"{var}" not in global variables')
+
+		del self.globals[var]
+
+
+	def updateEnv(self, data):
+		if not isinstance(data, dict):
+			raise ValueError(f'Environment data not a dict')
+
+		self.globals.update(data)
+
+
+	def addFilter(self, funct, name=None):
+		name = funct.__name__ if not name else name
+		self.filters[name] = funct
+
+
+	def delFilter(self, name):
+		if not self.filters.get(name):
+			raise ValueError(f'"{name}" not in global filters')
+
+		del self.filters[name]
+
+
+	def updateFilter(self, data):
+		if not isinstance(data, dict):
+			raise ValueError(f'Filter data not a dict')
+
+		self.filters.update(data)
+
+
+	def render(self, tplfile, request=None, context={}, headers={}, cookies={}, pprint=False, **kwargs):
+		if not isinstance(context, dict):
+			raise TypeError(f'context for {tplfile} not a dict: {type(context)} {context}')
+
+		context['request'] = request if request else {'headers': headers, 'cookies': cookies}
+
+		if self.func_context:
+			context.update(self.func_context(DotDict(context), DotDict(self.globals)))
+
+		result = self.get_template(tplfile).render(context)
+
+		if pprint and any(map(tplfile.endswith, ['haml', 'html', 'xml'])):
+			return minidom.parseString(result).toprettyxml(indent="  ")
+
+		else:
+			return result
+
+
+	def response(self, *args, ctype='text/html', status=200, **kwargs):
+		if not Response:
+			raise ModuleNotFoundError('Sanic is not installed')
+
+		html = self.render(*args, **kwargs)
+		return Response.HTTPResponse(body=html, status=status, content_type=ctype, headers=kwargs.get('headers', {}))
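
A minimal usage sketch of the new Template class (not from the commit; the search path and template name are made up):

from IzzyLib.template import Template

tpl = Template(search=['/srv/myapp/templates'], global_vars={'sitename': 'My App'})
tpl.addFilter(lambda text: text.upper(), 'shout')

## renders /srv/myapp/templates/index.haml with the given context
html = tpl.render('index.haml', context={'title': 'Home'})

## or, with sanic installed, build the response directly
#resp = tpl.response('index.haml', context={'title': 'Home'})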
requirements.txt

@@ -1,12 +1,10 @@
+pycryptodome>=3.9.1
 colour>=0.1.5
-aiohttp>=3.6.2
 envbash>=1.2.0
-hamlpy3>=0.84.0
+Hamlish-Jinja==0.3.3
 Jinja2>=2.10.1
-jinja2-markdown>=0.0.3
 Mastodon.py>=1.5.0
-pycryptodome>=3.9.1
 python-magic>=0.4.18
 sanic>=19.12.2
-urllib3>=1.25.7
 watchdog>=0.8.3
+httpsig>=1.3.0