From c4e5c9b6b4cc8c5e0554024d06e9cbd1fb9a3958 Mon Sep 17 00:00:00 2001 From: Izalia Mae Date: Tue, 2 Mar 2021 12:13:18 -0500 Subject: [PATCH] a lot of changes --- IzzyLib/__init__.py | 3 +- IzzyLib/cache.py | 90 +++-- IzzyLib/database.py | 764 +++++++++++++++-------------------------- IzzyLib/http.py | 385 ++++++++++++--------- IzzyLib/http_server.py | 326 ++++++++++++++++++ IzzyLib/misc.py | 201 ++++++++--- IzzyLib/template.py | 13 +- requirements.txt | 1 + 8 files changed, 1039 insertions(+), 744 deletions(-) create mode 100644 IzzyLib/http_server.py diff --git a/IzzyLib/__init__.py b/IzzyLib/__init__.py index d4e26fb..2c72f04 100644 --- a/IzzyLib/__init__.py +++ b/IzzyLib/__init__.py @@ -8,4 +8,5 @@ import sys assert sys.version_info >= (3, 6) -__version__ = (0, 4, 0) +__version_tpl__ = (0, 4, 0) +__version__ = '.'.join([str(v) for v in __version_tpl__]) diff --git a/IzzyLib/cache.py b/IzzyLib/cache.py index 75f70d3..780e748 100644 --- a/IzzyLib/cache.py +++ b/IzzyLib/cache.py @@ -5,8 +5,16 @@ import re from datetime import datetime from collections import OrderedDict +from .misc import DotDict + def parse_ttl(ttl): + if not ttl: + return 0 + + if type(ttl) == int: + return ttl * 60 + m = re.match(r'^(\d+)([smhdw]?)$', ttl) if not m: @@ -34,25 +42,58 @@ def parse_ttl(ttl): return multiplier * int(amount) -class TTLCache(OrderedDict): - def __init__(self, maxsize=1024, ttl='1h'): +class BaseCache(OrderedDict): + def __init__(self, maxsize=1024, ttl=None): self.ttl = parse_ttl(ttl) self.maxsize = maxsize + self.set = self.store + + + def __str__(self): + data = ', '.join([f'{k}="{v["data"]}"' for k,v in self.items()]) + return f'BaseCache({data})' + + + def get(self, key): + while len(self) >= self.maxsize and self.maxsize != 0: + self.popitem(last=False) + + item = DotDict.get(self, key) + + if not item: + return + + if self.ttl > 0: + timestamp = int(datetime.timestamp(datetime.now())) + + if timestamp >= self[key].timestamp: + del self[key] + return + + self[key].timestamp = timestamp + self.ttl + + self.move_to_end(key) + return item['data'] + def remove(self, key): if self.get(key): del self[key] + def store(self, key, value): - timestamp = int(datetime.timestamp(datetime.now())) - item = self.get(key) + if not self.get(key): + self[key] = DotDict() - while len(self) >= self.maxsize and self.maxsize != 0: - self.popitem(last=False) + self[key].data = value + + if self.ttl: + timestamp = int(datetime.timestamp(datetime.now())) + self[key].timestamp = timestamp + self.ttl - self[key] = {'data': value, 'timestamp': timestamp + self.ttl} self.move_to_end(key) + def fetch(self, key): item = self.get(key) timestamp = int(datetime.timestamp(datetime.now())) @@ -60,29 +101,22 @@ class TTLCache(OrderedDict): if not item: return - if timestamp >= self[key]['timestamp']: - del self[key] - return + if self.ttl: + if timestamp >= self[key].timestamp: + del self[key] + return + + self[key]['timestamp'] = timestamp + self.ttl - self[key]['timestamp'] = timestamp + self.ttl self.move_to_end(key) - return self[key]['data'] + return self[key].data -class LRUCache(OrderedDict): +class TTLCache(BaseCache): + def __init__(self, maxsize=1024, ttl='1h'): + super().__init__(maxsize, ttl) + + +class LRUCache(BaseCache): def __init__(self, maxsize=1024): - self.maxsize = maxsize - - def remove(self, key): - if key in self: - del self[key] - - def store(self, key, value): - while len(self) >= self.maxsize and self.maxsize != 0: - self.popitem(last=False) - - self[key] = value - self.move_to_end(key) - - def fetch(self, key): - return self.get(key) + super().__init__(maxsize) diff --git a/IzzyLib/database.py b/IzzyLib/database.py index 571ba84..bfd9672 100644 --- a/IzzyLib/database.py +++ b/IzzyLib/database.py @@ -1,536 +1,308 @@ -## Probably gonna replace all of this with a custom sqlalchemy setup tbh -## It'll look like the db classes and functions in https://git.barkshark.xyz/izaliamae/social -import shutil, traceback, importlib, sqlite3, sys, json +import sys from contextlib import contextmanager from datetime import datetime +from sqlalchemy import create_engine, ForeignKey, MetaData, Table +from sqlalchemy import Column as SqlColumn, types as Types +#from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.exc import OperationalError, ProgrammingError +from sqlalchemy.orm import sessionmaker +from . import logging from .cache import LRUCache -from .misc import Boolean, DotDict, Path -from . import logging, sql - -try: - from dbutils.pooled_db import PooledDB -except ImportError: - from DBUtils.PooledDB import PooledDB +from .misc import DotDict, RandomGen, NfsCheck -## Only sqlite3 has been tested -## Use other modules at your own risk. -class DB(): - def __init__(self, tables, dbmodule='sqlite', cursor=None, **kwargs): - cursor = Cursor if not cursor else cursor +SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')}) - if dbmodule in ['sqlite', 'sqlite3']: - self.dbmodule = sqlite3 - self.dbtype = 'sqlite' + +class DataBase(): + def __init__(self, dbtype='postgresql+psycopg2', tables={}, **kwargs): + self.engine_string = self.__engine_string(dbtype, kwargs) + self.db = create_engine(self.engine_string) + self.table = Tables(self, tables) + self.cache = DotDict({table: LRUCache() for table in tables.keys()}) + self.classes = kwargs.get('row_classes', CustomRows()) + + session_class = kwargs.get('session_class', Session) + self.session = lambda trans=True: session_class(self, trans) + + + def __engine_string(self, dbtype, kwargs): + if not kwargs.get('database'): + raise MissingDatabaseError('Database not set') + + engine_string = dbtype + '://' + + if dbtype == 'sqlite': + if NfsCheck(kwargs.get('database')): + logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail') + + engine_string += '/' + kwargs.get('database') else: - self.dbmodule = None - self.__setup_module(dbmodule) + user = kwargs.get('user') + password = kwargs.get('pass') + host = kwargs.get('host', '/var/run/postgresql') + port = kwargs.get('port', 5432) + name = kwargs.get('name', 'postgres') + maxconn = kwargs.get('maxconnections', 25) - self.db = None - self.cursor = lambda : cursor(self).begin() - self.kwargs = kwargs - self.tables = tables - self.cache = DotDict() - self.__setup_database() + if user: + if password: + engine_string += f'{user}:{password}@' + else: + engine_string += user + '@' - for table in tables.keys(): - self.__setup_cache(table) - - - def query(self, *args, **kwargs): - return self.__cursor_cmd('query', *args, **kwargs) - - - def fetch(self, *args, **kwargs): - return self.__cursor_cmd('fetch', *args, **kwargs) - - - def insert(self, *args, **kwargs): - return self.__cursor_cmd('insert', *args, **kwargs) - - - def update(self, *args, **kwargs): - return self.__cursor_cmd('update', *args, **kwargs) - - - def remove(self, *args, **kwargs): - return self.__cursor_cmd('remove', *args, **kwargs) - - - def __cursor_cmd(self, name, *args, **kwargs): - with self.cursor() as cur: - method = getattr(cur, name) - - if not method: - raise KeyError('Not a valid cursor method:', name) - - return method(*args, **kwargs) - - - def __setup_module(self, dbtype): - modules = [] - modtypes = { - 'sqlite': ['sqlite3'], - 'postgresql': ['psycopg3', 'pgdb', 'psycopg2', 'pg8000'], - 'mysql': ['mysqldb', 'trio_mysql'], - 'mssql': ['pymssql', 'adodbapi'] - } - - for dbmod, mods in modtypes.items(): - if dbtype == dbmod: - self.dbtype = dbmod - modules = mods - break - - elif dbtype in mods: - self.dbtype = dbmod - modules = [dbtype] - break - - if not modules: - logging.verbose('Not a database type. Checking if it is a module...') - - for mod in modules: - try: - self.dbmodule = importlib.import_module(mod) - except ImportError as e: - logging.verbose('Module not installed:', mod) - - if self.dbmodule: - break - - if not self.dbmodule: - if modtypes.get(dbtype): - logging.error('Failed to find module for database type:', dbtype) - logging.error(f'Please install one of these modules to use a {dbtype} database:', ', '.join(modules)) + if host == '/var/run/postgresql': + engine_string += '/' + name else: - logging.error('Failed to import module:', dbtype) - logging.error('Install one of the following modules:') + engine_string += f'{host}:{port}/{name}' - for key, modules in modtypes.items(): - logging.error(f'{key}:') - for mod in modules: - logging.error(f'\t{mod}') - - sys.exit() + return engine_string - def __setup_database(self): - if self.dbtype == 'sqlite': - if not self.kwargs.get('database'): - dbfile = ':memory:' + def CreateDatabase(self): + if self.engine_string.startswith('postgresql'): + predb = create_engine(db.engine_string.replace(config.db.name, 'postgres', -1)) + conn = predb.connect() + conn.execute('commit') - dbfile = Path(self.kwargs['database']) - dbfile.parent().mkdir() + try: + conn.execute(f'CREATE DATABASE {config.db.name}') - if not dbfile.parent().isdir(): - logging.error('Invalid path to database file:', dbfile.parent().str()) - sys.exit() + except ProgrammingError: + 'The database already exists, so just move along' - self.kwargs['database'] = dbfile.str() - self.kwargs['check_same_thread'] = False + except Exception as e: + conn.close() + raise e from None - else: - if not self.kwargs.get('password'): - self.kwargs.pop('password', None) - - self.db = PooledDB(self.dbmodule, **self.kwargs) - - - def connection(self): - if self.dbtype == 'sqlite': - return self.db.connection() - - - def __setup_cache(self, table): - self.cache[table] = LRUCache(128) - - - def close(self): - self.db.close() - - - def count(self, table): - tabledict = tables.get(table) - - if not tabledict: - logging.debug('Table does not exist:', table) - return 0 - - data = self.query(f'SELECT COUNT(*) FROM {table}') - return data[0][0] - - - #def query(self, string, values=[], cursor=None): - #if not string.endswith(';'): - #string += ';' - - #if not cursor: - #with self.Cursor() as cursor: - #cursor.execute(string, values) - #return cursor.fetchall() - - #else: - #cursor.execute(string,value) - #return cursor.fetchall() - - #return False - - - @contextmanager - def Cursor(self): - conn = self.db.connection() - conn.begin() - cursor = conn.cursor() - - try: - yield cursor - except self.dbmodule.OperationalError: - traceback.print_exc() - conn.rollback() - finally: - cursor.close() - conn.commit() conn.close() + self.table.meta.create_all(self.db) - def CreateTable(self, table): - layout = DotDict(self.tables.get(table)) - if not layout: - logging.error('Table config doesn\'t exist:', table) - return + def execute(self, *args, **kwargs): + with self.session() as s: + return s.execute(*args, **kwargs) - cmd = f'CREATE TABLE IF NOT EXISTS {table}(' - items = [] - - for k, v in layout.items(): - options = ' '.join(v.get('options', [])) - default = v.get('default') - item = f'{k} {v["type"].upper()} {options}' - - if default: - item += f'DEFAULT {default}' - - items.append(item) - - cmd += ', '.join(items) + ')' - - return True if self.query(cmd) != False else False - - - def RenameTable(self, table, newname): - self.query(f'ALTER TABLE {table} RENAME TO {newname}') - - - def DropTable(self, table): - self.query(f'DROP TABLE {table}') - - - def AddColumn(self, table, name, datatype, default=None, options=None): - query = f'ALTER TABLE {table} ADD COLUMN {name} {datatype.upper()}' - - if default: - query += f' DEFAULT {default}' - - if options: - query += f' {options}' - - self.query(query) - - - def CheckDatabase(self, database): - if self.dbtype == 'postgresql': - tables = self.query('SELECT datname FROM pg_database') - - else: - tables = [] - - tables = [table[0] for table in tables] - print(database in tables, database, tables) - - return database in tables - - - def CreateTables(self): - dbname = self.kwargs['database'] - - for name, table in self.tables.items(): - if len(self.query(sql.CheckTable(self.dbtype, name))) < 1: - logging.info('Creating table:', name) - self.query(table.sql()) - - -class Table(dict): - def __init__(self, name, dbtype='sqlite'): - super().__init__({}) - - self.sqlstr = 'CREATE TABLE IF NOT EXISTS {} ({});' - self.dbtype = dbtype - self.name = name - self.columns = {} - self.fkeys = {} - - - def addColumn(self, name, datatype=None, null=True, unique=None, primary=None, fkey=None): - if name == 'id': - if self.dbtype == 'sqlite': - datatype = 'integer' - primary = True if primary == None else primary - unique = True - null = False - - else: - datatype = 'serial' - primary = 'true' - - if name == 'timestamp': - datatype = 'float' - - elif not datatype: - raise MissingTypeError(f'Missing a data type for column: {name}') - - colsql = f'{name} {datatype.upper()}' - - if unique: - colsql += ' UNIQUE' - - if not null: - colsql += ' NOT NULL' - - if primary: - self.primary = name - colsql += ' PRIMARY KEY' - - if name == 'id' and self.dbtype == 'sqlite': - colsql += ' AUTOINCREMENT' - - if fkey: - if self.dbtype == 'postgresql': - colsql += f' REFERENCES {fkey[0]}({fkey[1]})' - - elif self.dbtype == 'sqlite': - self.fkeys[name] += f'FOREIGN KEY ({name}) REFERENCES {fkey[0]} ({fkey[1]})' - - self.columns.update({name: colsql}) - - - def sql(self): - if not self.primary: - logging.error('Please specify a primary column') - return - - data = ', '.join(list(self.columns.values())) - - if self.fkeys: - data += ', ' - data += ', '.join(list(self.fkeys.values())) - - sqldata = self.sqlstr.format(self.name, data) - print(sqldata) - return sqldata - - -class Cursor(object): - def __init__(self, db): - self.main = db - self.db = db.db - - @contextmanager - def begin(self): - self.conn = self.db.connection() - self.conn.begin() - self.cursor = self.conn.cursor() - - try: - yield self - - except self.main.dbmodule.OperationalError: - self.conn.rollback() - raise - - finally: - self.cursor.close() - self.conn.commit() - self.conn.close() - - - def query(self, string, values=[], cursor=None): - #if not string.endswith(';'): - #string += ';' - - self.cursor.execute(string, values) - data = self.cursor.fetchall() - - - def fetch(self, table, single=True, sort=None, **kwargs): - rowid = kwargs.get('id') - querysort = f'ORDER BY {sort}' - - resultOpts = [self, table, self.cursor] - - if rowid: - cursor.execute(f"SELECT * FROM {table} WHERE id = ?", [rowid]) - - elif kwargs: - placeholders = [f'{k} = ?' for k in kwargs.keys()] - values = kwargs.values() - - where = ' and '.join(placeholders) - query = f"SELECT * FROM {table} WHERE {where} {querysort if sort else ''}" - self.cursor.execute(query, list(values)) - - else: - self.cursor.execute(f'SELECT * FROM {table} {querysort if sort else ""}') - - rows = self.cursor.fetchall() if not single else self.cursor.fetchone() - - if rows: - if single: - return DBResult(rows, *resultOpts) - - return [DBResult(row, *resultOpts) for row in rows] - - return None if single else [] - - - def insert(self, table, data={}, **kwargs): - data.update(kwargs) - placeholders = ",".join(['?' for _ in data.keys()]) - values = tuple(data.values()) - keys = ','.join(data.keys()) - - if 'timestamp' in self.main.tables[table].keys() and 'timestamp' not in keys: - data['timestamp'] = datetime.now() - - query = f'INSERT INTO {table} ({keys}) VALUES ({placeholders})' - self.query(query, values) - return True - - - def remove(self, table, **kwargs): - keys = [] - values = [] - - for k,v in kwargs.items(): - keys.append(k) - values.append(v) - - keydata = ','.join([f'{k} = ?' for k in keys]) - query = f'DELETE FROM {table} WHERE {keydata}' - - self.query(query, values) - - - def update(self, table, rowid, data={}, **kwargs): - data.update(kwargs) - newdata = {k: v for k, v in data.items() if k in self.main.tables[table].keys()} - keys = list(newdata.keys()) - values = list(newdata.values()) - - if len(newdata) < 1: - logging.debug('No data provided to update row') - return False - - query_data = ', '.join(f'{k} = ?' for k in keys) - query = f'UPDATE {table} SET {query_data} WHERE id = {rowid}' - - self.query(query, values) - - -class DBResult(DotDict): - def __init__(self, row, db, table, cursor): - super().__init__() +class Session(object): + def __init__(self, db, trans=True): self.db = db - self.table = table + self.classes = self.db.classes + self.session = sessionmaker(bind=db.db)() + self.table = self.db.table + self.cache = self.db.cache + self.trans = trans - for idx, col in enumerate(cursor.description): - self[col[0]] = row[idx] + # session aliases + self.s = self.session + self.commit = self.s.commit + self.rollback = self.s.rollback + self.query = self.s.query + self.execute = self.s.execute + + self._setup() + + if not self.trans: + self.commit() - def __setattr__(self, name, value): - if name not in ['db', 'table']: - return self.__setitem__(name, value) - - else: - return super().__setattr__(name, value) - - - def __delattr__(self, name): - if name not in ['db', 'table']: - return self.__delitem__(name) - - else: - return super().__delattr__(name) - - - def __getattr__(self, value, default=None): - options = [value] - - if default: - options.append(default) - - if value in self.keys(): - val = super().__getitem__(*options) - return DotDict(val) if isinstance(val, dict) else val - - else: - return dict.__getattr__(*options) - - - # Kept for backwards compatibility. Delete later. - def asdict(self): + def __enter__(self): + self.sessionid = RandomGen(10) return self - def Update(self, data={}): - with self.db.Cursor().begin() as cursor: - self.update(data) - cursor.update(self.table, self.id, self.AsDict()) - - - def Remove(self): - with self.db.Cursor().begin() as cursor: - cursor.remove(self.table, id=self.id) - - -def ParseData(table, row): - tbdata = tables.get(table) - types = [] - result = [] - - if not tbdata: - logging.error('Invalid table:', table) - return - - for v in tbdata.values(): - dtype = v.split()[0].upper() - - if dtype == 'BOOLEAN': - types.append(boolean) - - elif dtype in ['INTEGER', 'INT']: - types.append(int) + def __exit__(self, exctype, value, tb): + if tb: + self.rollback() else: - types.append(str) - - for idx, v in enumerate(row): - result.append(types[idx](v) if v else None) - - return row - - db.insert('config', {'key': 'version', 'value': dbversion}) + self.commit() -class MissingTypeError(Exception): - pass + def _setup(self): + pass -class TooManyConnectionsError(Exception): - pass + def count(self, table_name, **kwargs): + return self.query(self.table[table_name]).filter_by(**kwargs).count() + + + def fetch(self, table_name, single=True, **kwargs): + RowClass = self.classes.get(table_name.capitalize()) + + rows = self.query(self.table[table_name]).filter_by(**kwargs).all() + + if single: + return RowClass(table_name, rows[0], self) if len(rows) > 0 else None + + return [RowClass(table_name, row, self) for row in rows] + + + def insert(self, table_name, **kwargs): + row = self.fetch(table_name, **kwargs) + + if row: + row.update_session(self, **kwargs) + return + + table = self.table[table_name] + + if getattr(table, 'timestamp', None) and not kwargs.get('timestamp'): + kwargs['timestamp'] = datetime.now() + + res = self.execute(table.insert().values(**kwargs)) + #return self.fetch(table_name, **kwargs) + + + def update(self, table=None, rowid=None, row=None, **data): + if row: + rowid = row.id + table = row._table_name + + if not rowid or not table: + raise ValueError('Missing row ID or table') + + tclass = self.table[table] + + self.execute(tclass.update().where(tclass.c.id == rowid).values(**data)) + + + def remove(self, table=None, rowid=None, row=None): + if row: + rowid = row.id + table = row._table_name + + if not rowid or not table: + raise ValueError('Missing row ID or table') + + row = self.execute(f'DELETE FROM {table} WHERE id={rowid}') + + + def DropTables(self): + tables = self.GetTables() + + for table in tables: + self.execute(f'DROP TABLE {table}') + + + def GetTables(self): + rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'") + return [row[0] for row in rows] + + +class CustomRows(object): + def get(self, name): + return getattr(self, name, self.Row) + + + class Row(DotDict): + #_filter_columns = lambda self, row: [attr for attr in dir(row) if not attr.startswith('_') and attr != 'metadata'] + + + def __init__(self, table, row, session): + if not row: + return + + super().__init__() + self._update(row._asdict()) + + self._db = session.db + self._table_name = table + self._columns = self.keys() + #self._columns = self._filter_columns(row) + + self.__run__(session) + + + ## Subclass Row and redefine this function + def __run__(self, s): + pass + + + def _filter_data(self): + data = {k: v for k,v in self.items() if k in self._columns} + + for k,v in self.items(): + if v.__class__ == DotDict: + data[k] = v.asDict() + + return data + + + def asDict(self): + return self._filter_data() + + + def _update(self, new_data={}, **kwargs): + kwargs.update(new_data) + + for k,v in kwargs.items(): + if type(v) == dict: + self[k] = DotDict(v) + + self[k] = v + + + def delete(self): + with self._db.session() as s: + return self.delete_session(s) + + + def delete_session(self, s): + return s.remove(row=self) + + + def update(self, dict_data={}, **data): + dict_data.update(data) + self._update(dict_data) + + with self._db.session() as s: + s.update(row=self, **self._filter_data()) + + + def update_session(self, s, dict_data={}, **data): + return s.update(row=self, **dict_data, **data) + + +class Tables(DotDict): + def __init__(self, db, tables={}): + '"tables" should be a dict with the table names for keys and a list of Columns for values' + super().__init__() + + self.db = db + self.meta = MetaData() + + for name, table in tables.items(): + self.__setup_table(name, table) + + + def __setup_table(self, name, table): + self[name] = Table(name, self.meta, *table) + + +def Column(name, stype=None, fkey=None, **kwargs): + if not stype and not kwargs: + if name == 'id': + return Column('id', 'integer', primary_key=True, autoincrement=True) + + elif name == 'timestamp': + return Column('timestamp', 'datetime') + + raise ValueError('Missing column type and options') + + else: + options = [name, SqlTypes.get(stype.lower(), SqlTypes['string'])] + + if fkey: + options.append(ForeignKey(fkey)) + + return SqlColumn(*options, **kwargs) + + +class MissingDatabaseError(Exception): + '''raise when the "database" kwargs is not set''' diff --git a/IzzyLib/http.py b/IzzyLib/http.py index d55d456..3012e96 100644 --- a/IzzyLib/http.py +++ b/IzzyLib/http.py @@ -4,11 +4,12 @@ from IzzyLib import logging from IzzyLib.misc import DefaultDict, DotDict from base64 import b64decode, b64encode from datetime import datetime +from ssl import SSLCertVerificationError from urllib.error import HTTPError from urllib.parse import urlparse from urllib.request import Request, urlopen -from . import error +from . import error, __version__ try: from Crypto.Hash import SHA256 @@ -21,6 +22,7 @@ except ImportError: try: from sanic.request import Request as SanicRequest + from sanic.exceptions import SanicException sanic_enabled = True except ImportError: logging.verbose('Sanic module not found. Request verification is disabled') @@ -30,169 +32,8 @@ except ImportError: Client = None -def VerifyRequest(request: SanicRequest, actor: dict=None): - '''Verify a header signature from a sanic request - - request: The request with the headers to verify - actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification - ''' - if not sanic_enabled: - logging.error('Sanic request verification disabled') - return - - if not actor: - actor = request.ctx.actor - - body = request.body if request.body else None - return VerifyHeaders(request.headers, request.method, request.path, body, actor, False) - - -def VerifyHeaders(headers: dict, method: str, path: str, actor: dict=None, body=None, fail: bool=False): - '''Verify a header signature - - headers: A dictionary containing all the headers from a request - method: The HTTP method of the request - path: The path of the HTTP request - actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification - body (optional): The body of the request. Only needed if the signature includes the digest header - fail (optional): If set to True, raise an error instead of returning False if any step of the process fails - ''' - if not crypto_enabled: - logging.error('Crypto functions disabled') - return - - headers = {k.lower(): v for k,v in headers.items()} - headers['(request-target)'] = f'{method.lower()} {path}' - signature = ParseSig(headers.get('signature')) - digest = headers.get('digest') - missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None] - - if not signature: - if fail: - raise MissingSignatureError() - - return False - - if not actor: - actor = FetchActor(signature.keyid) - - ## Add digest header to missing headers list if it doesn't exist - if method.lower() == 'post' and not headers.get('digest'): - missing_headers.append('digest') - - ## Fail if missing date, host or digest (if POST) headers - if missing_headers: - if fail: - raise error.MissingHeadersError(missing_headers) - - return False - - ## Fail if body verification fails - if digest and not VerifyString(body, digest): - if fail: - raise error.VerificationError('digest header') - - return False - - pubkey = actor.publicKey['publicKeyPem'] - - if PkcsHeaders(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature): - return True - - if fail: - raise error.VerificationError('headers') - - return False - - -@functools.lru_cache(maxsize=512) -def VerifyString(string, enc_string, alg='SHA256', fail=False): - if not crypto_enabled: - logging.error('Crypto functions disabled') - return - - if type(string) != bytes: - string = string.encode('UTF-8') - - body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8') - - if body_hash == enc_string: - return True - - if fail: - raise error.VerificationError() - - else: - return False - - -def PkcsHeaders(key: str, headers: dict, sig=None): - if not crypto_enabled: - logging.error('Crypto functions disabled') - return - - if sig: - head_items = [f'{item}: {headers[item]}' for item in sig.headers] - - else: - head_items = [f'{k.lower()}: {v}' for k,v in headers.items()] - - head_string = '\n'.join(head_items) - head_bytes = head_string.encode('UTF-8') - - KEY = RSA.importKey(key) - pkcs = PKCS1_v1_5.new(KEY) - h = SHA256.new(head_bytes) - - if sig: - return pkcs.verify(h, b64decode(sig.signature)) - - else: - return pkcs.sign(h) - - -def ParseSig(signature: str): - if not signature: - logging.verbose('Missing signature header') - return - - split_sig = signature.split(',') - sig = DefaultDict({}) - - for part in split_sig: - key, value = part.split('=', 1) - sig[key.lower()] = value.replace('"', '') - - if not sig.headers: - logging.verbose('Missing headers section in signature') - return - - sig.headers = sig.headers.split() - - return sig - - -@functools.lru_cache(maxsize=512) -def FetchActor(keyid, client=None): - if not client: - client = Client if Client else HttpClient() - - actor = Client.request(keyid).json() - actor.domain = urlparse(actor.id).netloc - actor.shared_inbox = actor.inbox - actor.pubkey = None - - if actor.get('endpoints'): - actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox) - - if actor.get('publicKey'): - actor.pubkey = actor.publicKey.get('publicKeyPem') - - return actor - - class HttpClient(object): - def __init__(self, headers={}, useragent='IzzyLib/0.3', proxy_type='https', proxy_host=None, proxy_port=None): + def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None): proxy_ports = { 'http': 80, 'https': 443 @@ -202,7 +43,7 @@ class HttpClient(object): raise ValueError(f'Not a valid proxy type: {proxy_type}') self.headers=headers - self.agent=useragent + self.agent = f'{useragent} ({appagent})' if appagent else useragent self.proxy = DotDict({ 'enabled': True if proxy_host else False, 'ptype': proxy_type, @@ -210,6 +51,8 @@ class HttpClient(object): 'port': proxy_ports[proxy_type] if not proxy_port else proxy_port }) + self.SetGlobal = SetClient + def __sign_request(self, request, privkey, keyid): if not crypto_enabled: @@ -269,9 +112,14 @@ class HttpClient(object): try: response = urlopen(request) + except HTTPError as e: response = e.fp + except SSLCertVerificationError as e: + logging.error('HttpClient.request: Certificate error:', e) + return + return HttpResponse(response) @@ -337,6 +185,211 @@ class HttpResponse(object): return json.dumps(self.json().asDict(), indent=indent) -def SetClient(client: HttpClient): +def VerifyRequest(request: SanicRequest, actor: dict): + '''Verify a header signature from a sanic request + + request: The request with the headers to verify + actor: A dictionary containing the activitypub actor and the link to the pubkey used for verification + ''' + if not sanic_enabled: + logging.error('Sanic request verification disabled') + return + + body = request.body if request.body else None + return VerifyHeaders(request.headers, request.method, request.path, body, actor) + + +def VerifyHeaders(headers: dict, method: str, path: str, actor: dict=None, body=None): + '''Verify a header signature + + headers: A dictionary containing all the headers from a request + method: The HTTP method of the request + path: The path of the HTTP request + actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification + body (optional): The body of the request. Only needed if the signature includes the digest header + fail (optional): If set to True, raise an error instead of returning False if any step of the process fails + ''' + if not crypto_enabled: + logging.error('Crypto functions disabled') + return + + headers = {k.lower(): v for k,v in headers.items()} + headers['(request-target)'] = f'{method.lower()} {path}' + signature = ParseSig(headers.get('signature')) + digest = ParseBodyDigest(headers.get('digest')) + missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None] + + if not signature: + logging.verbose('Missing signature') + return False + + if not actor: + actor = FetchActor(signature.keyid) + + ## Add digest header to missing headers list if it doesn't exist + if method.lower() == 'post' and not digest: + missing_headers.append('digest') + + ## Fail if missing date, host or digest (if POST) headers + if missing_headers: + logging.verbose('Missing headers:', missing_headers) + return False + + ## Fail if body verification fails + if digest and not VerifyString(body, digest.sig, digest.alg): + logging.verbose('Failed body digest verification') + return False + + pubkey = actor.publicKey['publicKeyPem'] + + if PkcsHeaders(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature): + return True + + logging.verbose('Failed header verification') + return False + + +def ParseBodyDigest(digest): + if not digest: + return + + parsed = DotDict() + parts = digest.split('=', 1) + + if len(parts) != 2: + return + + parsed.sig = parts[1] + parsed.alg = parts[0].replace('-', '') + + return parsed + + +def VerifyString(string, enc_string, alg='SHA256', fail=False): + if not crypto_enabled: + logging.error('Crypto functions disabled') + return + + if type(string) != bytes: + string = string.encode('UTF-8') + + body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8') + + if body_hash == enc_string: + return True + + if fail: + raise error.VerificationError() + + else: + return False + + +def PkcsHeaders(key: str, headers: dict, sig=None): + if not crypto_enabled: + logging.error('Crypto functions disabled') + return + + if sig: + head_items = [f'{item}: {headers[item]}' for item in sig.headers] + + else: + head_items = [f'{k.lower()}: {v}' for k,v in headers.items()] + + head_string = '\n'.join(head_items) + head_bytes = head_string.encode('UTF-8') + + KEY = RSA.importKey(key) + pkcs = PKCS1_v1_5.new(KEY) + h = SHA256.new(head_bytes) + + if sig: + return pkcs.verify(h, b64decode(sig.signature)) + + else: + return pkcs.sign(h) + + +def ParseSig(signature: str): + if not signature: + logging.verbose('Missing signature header') + return + + split_sig = signature.split(',') + sig = DefaultDict({}) + + for part in split_sig: + key, value = part.split('=', 1) + sig[key.lower()] = value.replace('"', '') + + if not sig.headers: + logging.verbose('Missing headers section in signature') + return + + sig.headers = sig.headers.split() + + return sig + + +def FetchActor(url): + if not Client: + logging.error('IzzyLib.http: Please set global client with "SetClient(client)"') + return {} + + url = url.split('#')[0] + headers = {'Accept': 'application/activity+json'} + resp = Client.request(url, headers=headers) + + if not resp.json(): + logging.verbose('functions.FetchActor: Failed to fetch actor:', url) + logging.debug(f'Error {resp.status}: {resp.body}') + return {} + + actor = resp.json() + actor.web_domain = urlparse(url).netloc + actor.shared_inbox = actor.inbox + actor.pubkey = None + actor.handle = actor.preferredUsername + + if actor.get('endpoints'): + actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox) + + if actor.get('publicKey'): + actor.pubkey = actor.publicKey.get('publicKeyPem') + + return actor + + +@functools.lru_cache(maxsize=512) +def FetchWebfingerAcct(handle, domain): + if not Client: + logging.error('IzzyLib.http: Please set global client with "SetClient(client)"') + return {} + + data = DefaultDict() + webfinger = Client.request(f'https://{domain}/.well-known/webfinger?resource=acct:{handle}@{domain}') + + if not webfinger.body: + return + + data.handle, data.domain = webfinger.json().subject.replace('acct:', '').split('@') + + for link in webfinger.json().links: + if link['rel'] == 'self' and link['type'] == 'application/activity+json': + data.actor = link['href'] + + return data + + +def SetClient(client=None): global Client - Client = client + Client = client or HttpClient() + + +def GenRsaKey(): + privkey = RSA.generate(2048) + + key = DotDict({'PRIVKEY': privkey, 'PUBKEY': privkey.publickey()}) + key.update({'privkey': key.PRIVKEY.export_key().decode(), 'pubkey': key.PUBKEY.export_key().decode()}) + + return key diff --git a/IzzyLib/http_server.py b/IzzyLib/http_server.py new file mode 100644 index 0000000..5a467c4 --- /dev/null +++ b/IzzyLib/http_server.py @@ -0,0 +1,326 @@ +import multiprocessing, sanic, signal, traceback +import logging as pylog + +from jinja2.exceptions import TemplateNotFound +from multidict import CIMultiDict +from multiprocessing import cpu_count, current_process +from urllib.parse import parse_qsl, urlparse + +from . import http, logging +from .misc import DotDict, DefaultDict, LowerDotDict +from .template import Template + + +log_path_ignore = [ + '/media', + '/static' +] + +log_ext_ignore = [ + 'js', 'ttf', 'woff2', + 'ac3', 'aiff', 'flac', 'm4a', 'mp3', 'ogg', 'wav', 'wma', + 'apng', 'ico', 'jpeg', 'jpg', 'png', 'svg', + 'divx', 'mov', 'mp4', 'webm', 'wmv' +] + + +class HttpServer(sanic.Sanic): + def __init__(self, name='sanic', host='0.0.0.0', port='4080', **kwargs): + self.host = host + self.port = int(port) + self.workers = int(kwargs.get('workers', cpu_count())) + self.sig_handler = kwargs.get('sig_handler') + self.ctx = DotDict() + + super().__init__(name, request_class=kwargs.get('request_class', HttpRequest)) + + #for log in ['sanic.root', 'sanic.access']: + #pylog.getLogger(log).setLevel(pylog.CRITICAL) + + self.template = Template( + kwargs.get('tpl_search', []), + kwargs.get('tpl_globals', {}), + kwargs.get('tpl_context'), + kwargs.get('tpl_autoescape', True) + ) + + self.template.addEnv('app', self) + + self.error_handler.add(TemplateNotFound, NoTemplateError) + self.error_handler.add(Exception, kwargs.get('error_handler', GenericError)) + self.register_middleware(MiddlewareAccessLog, attach_to='response') + + signal.signal(signal.SIGHUP, self.finish) + signal.signal(signal.SIGINT, self.finish) + signal.signal(signal.SIGQUIT, self.finish) + signal.signal(signal.SIGTERM, self.finish) + + + def add_method_route(self, method, *routes): + for route in routes: + self.add_route(method.as_view(), route) + + + def add_method_routes(self, routes: list): + for route in routes: + self.add_method_route(*route) + + + def start(self): + options = { + 'host': self.host, + 'port': self.port, + 'workers': self.workers, + 'access_log': False, + 'debug': False + } + + msg = f'Starting {self.name} at {self.host}:{self.port}' + + if self.workers > 1: + msg += f' with {self.workers} workers' + + logging.info(msg) + self.run(**options) + + + def finish(self): + if self.sig_handler: + self.sig_handler() + + self.stop() + logging.info('Bye! :3') + + +class HttpRequest(sanic.request.Request): + def __init__(self, url_bytes, headers, version, method, transport, app): + super().__init__(url_bytes, headers, version, method, transport, app) + + self.Headers = Headers(headers) + self.Data = Data(self) + self.template = self.app.template + self.__setup_defaults() + self.__parse_path() + + #if self.paths.media: + #return + + self.__parse_signature() + self.Run() + + + def Run(self): + pass + + + def response(self, tpl, *args, **kwargs): + return self.template.response(self, tpl, *args, **kwargs) + + + def alldata(self): + return self.__combine_dicts(self.content.json, self.data.query, self.data.form) + + + def verify(self, actor=None): + self.ap.valid = http.VerifyHeaders(self.headers, self.method, self.path, actor, self.body) + return self.ap.valid + + + def __combine_dicts(self, *dicts): + data = DotDict() + + for item in dicts: + data.update(item) + + return data + + + def __setup_defaults(self): + self.paths = DotDict({'media': False, 'json': False, 'ap': False, 'cookie': False}) + self.ap = DotDict({'valid': False, 'signature': {}, 'actor': None, 'inbox': None, 'domain': None}) + + + def __parse_path(self): + self.paths.media = any(map(self.path.startswith, log_path_ignore)) or any(map(self.path.startswith, log_ext_ignore)) + self.paths.json = self.__json_check() + + + def __parse_signature(self): + sig = self.headers.getone('signature', None) + + if sig: + self.ap.signature = http.ParseSig(sig) + + if self.ap.signature: + self.ap.actor = self.ap.signature.get('keyid', '').split('#', 1)[0] + self.ap.domain = urlparse(self.ap.actor).netloc + + + def __json_check(self): + if self.path.endswith('.json'): + return True + + accept = self.headers.getone('Accept', None) + + if accept: + mimes = [v.strip() for v in accept.split(',')] + + if any(mime in ['application/json', 'application/activity+json'] for mime in mimes): + return True + + return False + + +class Headers(LowerDotDict): + def __init__(self, headers): + super().__init__() + + for k,v in headers.items(): + if not self.get(k): + self[k] = [] + + self[k].append(v) + + + def getone(self, key, default=None): + value = self.get(key) + + if not value: + return default + + return value[0] + + + def getall(self, key, default=[]): + return self.get(key.lower(), default) + + +class Data(object): + def __init__(self, request): + self.request = request + + + @property + def combined(self): + return DotDict(**self.form.asDict(), **self.query.asDict(), **self.json.asDict()) + + + @property + def query(self): + data = {k: v for k,v in parse_qsl(self.request.query_string)} + return DotDict(data) + + + @property + def form(self): + data = {k: v[0] for k,v in self.request.form.items()} + return DotDict(data) + + + @property + def files(self): + return DotDict({k:v[0] for k,v in self.request.files.items()}) + + + ### body functions + @property + def raw(self): + try: + return self.request.body + except Exception as e: + logging.verbose('IzzyLib.http_server.Data.raw: failed to get body') + logging.debug(f'{e.__class__.__name__}: {e}') + return b'' + + + @property + def text(self): + try: + return self.raw.decode() + except Exception as e: + logging.verbose('IzzyLib.http_server.Data.text: failed to get body') + logging.debug(f'{e.__class__.__name__}: {e}') + return '' + + + @property + def json(self): + try: + return DotDict(self.text) + except Exception as e: + logging.verbose('IzzyLib.http_server.Data.json: failed to get body') + logging.debug(f'{e.__class__.__name__}: {e}') + data = '{}' + return {} + + +async def MiddlewareAccessLog(request, response): + if request.paths.media: + return + + uagent = request.headers.get('user-agent') + address = request.headers.get('x-real-ip', request.forwarded.get('for', request.remote_addr)) + + logging.info(f'({multiprocessing.current_process().name}) {address} {request.method} {request.path} {response.status} "{uagent}"') + + +def GenericError(request, exception): + try: + status = exception.status_code + except: + status = 500 + + if status not in range(200, 499): + traceback.print_exc() + + msg = f'{exception.__class__.__name__}: {str(exception)}' + + if request.paths.json: + return sanic.response.json({'error': {'status': status, 'message': msg}}) + + try: + return request.response('server_error.haml', status=status, context={'status': str(status), 'error': msg}) + + except TemplateNotFound: + return sanic.response.text(f'Error {status}: {msg}') + + +def NoTemplateError(request, exception): + logging.error('TEMPLATE_ERROR:', f'{exception.__class__.__name__}: {str(exception)}') + return sanic.response.html('I\'m a dumbass and forgot to create a template for this page', 500) + + +def ReplaceHeader(headers, key, value): + for k,v in headers.items(): + if k.lower() == header.lower(): + del headers[k] + + +class Response: + Text = sanic.response.text + Html = sanic.response.html + Json = sanic.response.json + Redir = sanic.response.redirect + + + def Css(*args, headers={}, **kwargs): + ReplaceHeader(headers, 'content-type', 'text/css') + return sanic.response.text(*args, headers=headers, **kwargs) + + + def Js(*args, headers={}, **kwargs): + ReplaceHeader(headers, 'content-type', 'application/javascript') + return sanic.response.text(*args, headers=headers, **kwargs) + + + def Ap(*args, headers={}, **kwargs): + ReplaceHeader(headers, 'content-type', 'application/activity+json') + return sanic.response.json(*args, headers=headers, **kwargs) + + + def Jrd(*args, headers={}, **kwargs): + ReplaceHeader(headers, 'content-type', 'application/jrd+json') + return sanic.response.json(*args, headers=headers, **kwargs) + + +Resp = Response diff --git a/IzzyLib/misc.py b/IzzyLib/misc.py index bb52a2f..2348ae0 100644 --- a/IzzyLib/misc.py +++ b/IzzyLib/misc.py @@ -1,5 +1,5 @@ '''Miscellaneous functions''' -import random, string, sys, os, json, socket +import hashlib, random, string, sys, os, json, socket, time from os import environ as env from datetime import datetime @@ -8,6 +8,11 @@ from pathlib import Path as Pathlib from . import logging +try: + from passlib.hash import argon2 +except ImportError: + argon2 = None + def Boolean(v, return_value=False): if type(v) not in [str, bool, int, type(None)]: @@ -43,6 +48,20 @@ def RandomGen(length=20, chars=None): return ''.join(random.choices(characters, k=length)) +def HashString(string, alg='blake2s'): + if alg not in hashlib.__always_supported: + logging.error('Unsupported hash algorithm:', alg) + logging.error('Supported algs:', ', '.join(hashlib.__always_supported)) + return + + string = string.encode('UTF-8') if type(string) != bytes else string + salt = salt.encode('UTF-8') if type(salt) != bytes else salt + + newhash = hashlib.new(alg) + newhash.update(string) + return newhash.hexdigest() + + def Timestamp(dtobj=None, utc=False): dtime = dtobj if dtobj else datetime date = dtime.utcnow() if utc else dtime.now() @@ -50,6 +69,11 @@ def Timestamp(dtobj=None, utc=False): return date.timestamp() +def GetVarName(*kwargs, single=True): + keys = list(kwargs.keys()) + return key[0] if single else keys + + def ApDate(date=None, alt=False): if not date: date = datetime.utcnow() @@ -93,14 +117,13 @@ def Input(prompt, default=None, valtype=str, options=[], password=False): prompt += f'[{opt}]' prompt += ': ' - value = input_func(prompt) - while value and options and value not in options: + while value and len(options) > 0 and value not in options: input_func('Invalid value:', value) value = input(prompt) - if not value: + if not value or value == '': return default ret = valtype(value) @@ -112,14 +135,38 @@ def Input(prompt, default=None, valtype=str, options=[], password=False): return ret +def NfsCheck(path): + proc = Path('/proc/mounts') + path = Path(path).resolve() + + if not proc.exists(): + return True + + with proc.open() as fd: + for line in fd: + line = line.split() + + if line[2] == 'nfs' and line[1] in path.str(): + return True + + return False + + class DotDict(dict): def __init__(self, value=None, **kwargs): - super().__init__() + '''Python dictionary, but variables can be set/get via attributes - if type(value) in [str, bytes]: + value [str, bytes, dict]: JSON or dict of values to init with + case_insensitive [bool]: Wether keys should be case sensitive or not + kwargs: key/value pairs to set on init. Overrides identical keys set by 'value' + ''' + super().__init__() + data = {} + + if isinstance(value, (str, bytes)): self.fromJson(value) - elif type(value) in [dict, DotDict, DefaultDict]: + elif isinstance(value, dict): self.update(value) elif value: @@ -134,10 +181,14 @@ class DotDict(dict): val = super().__getattribute__(key) except AttributeError: - val = self.get(key, InvalidKey()) + val = self.get(key, KeyError) - if type(val) == InvalidKey: - raise KeyError(f'Invalid key: {key}') + try: + if val == KeyError: + raise KeyError(f'Invalid key: {key}') + + except AttributeError: + 'PyCryptodome.PublicKey.RSA.RsaKey.__eq__ does not seem to play nicely' return DotDict(val) if type(val) == dict else val @@ -149,11 +200,6 @@ class DotDict(dict): super().__delattr__(key) - #def __delitem__(self, key): - #print('delitem', key) - #self.__delattr__(key) - - def __setattr__(self, key, value): if key.startswith('_'): super().__setattr__(key, value) @@ -162,6 +208,10 @@ class DotDict(dict): super().__setitem__(key, value) + def __str__(self): + return self.toJson() + + def __parse_item__(self, k, v): if type(v) == dict: v = DotDict(v) @@ -170,8 +220,12 @@ class DotDict(dict): return (k, v) + def update(self, data): + super().update(data) + + def get(self, key, default=None): - value = super().get(key, default) + value = dict.get(self, key, default) return DotDict(value) if type(value) == dict else value @@ -200,13 +254,18 @@ class DotDict(dict): def toJson(self, indent=None, **kwargs): + kwargs.pop('cls', None) + return json.dumps(dict(self), indent=indent, cls=DotDictEncoder, **kwargs) + + + def toJson2(self, indent=None, **kwargs): data = {} for k,v in self.items(): - if type(k) in [DotDict, Path, Pathlib]: + if k and not type(k) in [str, int, float, dict]: k = str(k) - if type(v) in [DotDict, Path, Pathlib]: + if v and not type(k) in [str, int, float, dict]: v = str(v) data[k] = v @@ -230,9 +289,41 @@ class DefaultDict(DotDict): return DotDict(val) if type(val) == dict else val +class LowerDotDict(DotDict): + def __getattr__(self, key): + key = key.lower() + + try: + val = super().__getattribute__(key) + + except AttributeError: + val = self.get(key, KeyError) + + if val == KeyError: + raise KeyError(f'Invalid key: {key}') + + return DotDict(val) if type(val) == dict else val + + + def __setattr__(self, key, value): + key = key.lower() + + if key.startswith('_'): + super().__setattr__(key, value) + + else: + super().__setitem__(key, value) + + + def update(self, data): + data = {k.lower(): v for k,v in self.items()} + + super().update(data) + + class Path(object): def __init__(self, path, exist=True, missing=True, parents=True): - self.__path = Pathlib(str(path)).resolve() + self.__path = Pathlib(str(path)) self.json = DotDict({}) self.exist = exist self.missing = missing @@ -240,20 +331,14 @@ class Path(object): self.name = self.__path.name - #def __getattr__(self, key): - #try: - #attr = getattr(self.__path, key) - - #except AttributeError: - #attr = getattr(self, key) - - #return attr - - def __str__(self): return str(self.__path) + def __repr__(self): + return f'Path({str(self.__path)})' + + def str(self): return self.__str__() @@ -307,7 +392,7 @@ class Path(object): def join(self, path, new=True): - new_path = self.__path.joinpath(path).resolve() + new_path = self.__path.joinpath(path) if new: return Path(new_path) @@ -422,25 +507,53 @@ class Path(object): return self.open().readlines() - ## def rmdir(): +class DotDictEncoder(json.JSONEncoder): + def default(self, obj): + if type(obj) not in [str, int, float, dict]: + return str(obj) + + return json.JSONEncoder.default(self, obj) -def NfsCheck(path): - proc = Path('/proc/mounts') - path = Path(path).resolve() +class PasswordHash(object): + def __init__(self, salt=None, rounds=8, bsize=50, threads=os.cpu_count(), length=64): + if type(salt) == Path: + if salt.exists(): + with salt.open() as fd: + self.salt = fd.read() - if not proc.exists(): - return True + else: + newsalt = RandomGen(40) - with proc.open() as fd: - for line in fd: - line = line.split() + with salt.open('w') as fd: + fd.write(newsalt) - if line[2] == 'nfs' and line[1] in path.str(): - return False + self.salt = newsalt - return True + else: + self.salt = salt or RandomGen(40) + + self.rounds = rounds + self.bsize = bsize * 1024 + self.threads = threads + self.length = length -class InvalidKey(object): - pass + def hash(self, password): + return argon2.using( + salt = self.salt.encode('UTF-8'), + rounds = self.rounds, + memory_cost = self.bsize, + max_threads = self.threads, + digest_size = self.length + ).hash(password) + + + def verify(self, password, passhash): + return argon2.using( + salt = self.salt.encode('UTF-8'), + rounds = self.rounds, + memory_cost = self.bsize, + max_threads = self.threads, + digest_size = self.length + ).verify(password, passhash) diff --git a/IzzyLib/template.py b/IzzyLib/template.py index 0d48a3e..63da6f1 100644 --- a/IzzyLib/template.py +++ b/IzzyLib/template.py @@ -6,9 +6,6 @@ from os.path import isfile, isdir, getmtime, abspath from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape, Markup from hamlish_jinja import HamlishExtension -from markdown import markdown -from watchdog.observers import Observer -from watchdog.events import FileSystemEventHandler from xml.dom import minidom try: @@ -23,11 +20,10 @@ from .misc import Path, DotDict class Template(Environment): - def __init__(self, search=[], global_vars={}, autoescape=True): + def __init__(self, search=[], global_vars={}, context=None, autoescape=True): self.autoescape = autoescape - self.watcher = None self.search = [] - self.func_context = None + self.func_context = context for path in search: self.__add_search_path(path) @@ -45,7 +41,6 @@ class Template(Environment): self.hamlish_mode = 'indented' self.globals.update({ - 'markdown': markdown, 'markup': Markup, 'cleanhtml': lambda text: ''.join(xml.etree.ElementTree.fromstring(text).itertext()), 'lighten': lighten, @@ -123,9 +118,9 @@ class Template(Environment): return result - def response(self, *args, ctype='text/html', status=200, **kwargs): + def response(self, request, tpl, ctype='text/html', status=200, **kwargs): if not Response: raise ModuleNotFoundError('Sanic is not installed') - html = self.render(*args, **kwargs) + html = self.render(tpl, request=request, **kwargs) return Response.HTTPResponse(body=html, status=status, content_type=ctype, headers=kwargs.get('headers', {})) diff --git a/requirements.txt b/requirements.txt index b486ed0..42a861c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ Hamlish-Jinja==0.3.3 Jinja2>=2.10.1 jinja2-markdown>=0.0.3 Mastodon.py>=1.5.0 +multidict>=5.1.0 pycryptodome>=3.9.1 python-magic>=0.4.18 sanic>=19.12.2