a lot of changes
This commit is contained in:
parent
0e59542626
commit
c4e5c9b6b4
|
@ -8,4 +8,5 @@ import sys
|
||||||
assert sys.version_info >= (3, 6)
|
assert sys.version_info >= (3, 6)
|
||||||
|
|
||||||
|
|
||||||
__version__ = (0, 4, 0)
|
__version_tpl__ = (0, 4, 0)
|
||||||
|
__version__ = '.'.join([str(v) for v in __version_tpl__])
|
||||||
|
|
|
@ -5,8 +5,16 @@ import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from .misc import DotDict
|
||||||
|
|
||||||
|
|
||||||
def parse_ttl(ttl):
|
def parse_ttl(ttl):
|
||||||
|
if not ttl:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if type(ttl) == int:
|
||||||
|
return ttl * 60
|
||||||
|
|
||||||
m = re.match(r'^(\d+)([smhdw]?)$', ttl)
|
m = re.match(r'^(\d+)([smhdw]?)$', ttl)
|
||||||
|
|
||||||
if not m:
|
if not m:
|
||||||
|
@ -34,25 +42,58 @@ def parse_ttl(ttl):
|
||||||
return multiplier * int(amount)
|
return multiplier * int(amount)
|
||||||
|
|
||||||
|
|
||||||
class TTLCache(OrderedDict):
|
class BaseCache(OrderedDict):
|
||||||
def __init__(self, maxsize=1024, ttl='1h'):
|
def __init__(self, maxsize=1024, ttl=None):
|
||||||
self.ttl = parse_ttl(ttl)
|
self.ttl = parse_ttl(ttl)
|
||||||
self.maxsize = maxsize
|
self.maxsize = maxsize
|
||||||
|
self.set = self.store
|
||||||
|
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
data = ', '.join([f'{k}="{v["data"]}"' for k,v in self.items()])
|
||||||
|
return f'BaseCache({data})'
|
||||||
|
|
||||||
|
|
||||||
|
def get(self, key):
|
||||||
|
while len(self) >= self.maxsize and self.maxsize != 0:
|
||||||
|
self.popitem(last=False)
|
||||||
|
|
||||||
|
item = DotDict.get(self, key)
|
||||||
|
|
||||||
|
if not item:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.ttl > 0:
|
||||||
|
timestamp = int(datetime.timestamp(datetime.now()))
|
||||||
|
|
||||||
|
if timestamp >= self[key].timestamp:
|
||||||
|
del self[key]
|
||||||
|
return
|
||||||
|
|
||||||
|
self[key].timestamp = timestamp + self.ttl
|
||||||
|
|
||||||
|
self.move_to_end(key)
|
||||||
|
return item['data']
|
||||||
|
|
||||||
|
|
||||||
def remove(self, key):
|
def remove(self, key):
|
||||||
if self.get(key):
|
if self.get(key):
|
||||||
del self[key]
|
del self[key]
|
||||||
|
|
||||||
|
|
||||||
def store(self, key, value):
|
def store(self, key, value):
|
||||||
|
if not self.get(key):
|
||||||
|
self[key] = DotDict()
|
||||||
|
|
||||||
|
self[key].data = value
|
||||||
|
|
||||||
|
if self.ttl:
|
||||||
timestamp = int(datetime.timestamp(datetime.now()))
|
timestamp = int(datetime.timestamp(datetime.now()))
|
||||||
item = self.get(key)
|
self[key].timestamp = timestamp + self.ttl
|
||||||
|
|
||||||
while len(self) >= self.maxsize and self.maxsize != 0:
|
|
||||||
self.popitem(last=False)
|
|
||||||
|
|
||||||
self[key] = {'data': value, 'timestamp': timestamp + self.ttl}
|
|
||||||
self.move_to_end(key)
|
self.move_to_end(key)
|
||||||
|
|
||||||
|
|
||||||
def fetch(self, key):
|
def fetch(self, key):
|
||||||
item = self.get(key)
|
item = self.get(key)
|
||||||
timestamp = int(datetime.timestamp(datetime.now()))
|
timestamp = int(datetime.timestamp(datetime.now()))
|
||||||
|
@ -60,29 +101,22 @@ class TTLCache(OrderedDict):
|
||||||
if not item:
|
if not item:
|
||||||
return
|
return
|
||||||
|
|
||||||
if timestamp >= self[key]['timestamp']:
|
if self.ttl:
|
||||||
|
if timestamp >= self[key].timestamp:
|
||||||
del self[key]
|
del self[key]
|
||||||
return
|
return
|
||||||
|
|
||||||
self[key]['timestamp'] = timestamp + self.ttl
|
self[key]['timestamp'] = timestamp + self.ttl
|
||||||
|
|
||||||
self.move_to_end(key)
|
self.move_to_end(key)
|
||||||
return self[key]['data']
|
return self[key].data
|
||||||
|
|
||||||
|
|
||||||
class LRUCache(OrderedDict):
|
class TTLCache(BaseCache):
|
||||||
|
def __init__(self, maxsize=1024, ttl='1h'):
|
||||||
|
super().__init__(maxsize, ttl)
|
||||||
|
|
||||||
|
|
||||||
|
class LRUCache(BaseCache):
|
||||||
def __init__(self, maxsize=1024):
|
def __init__(self, maxsize=1024):
|
||||||
self.maxsize = maxsize
|
super().__init__(maxsize)
|
||||||
|
|
||||||
def remove(self, key):
|
|
||||||
if key in self:
|
|
||||||
del self[key]
|
|
||||||
|
|
||||||
def store(self, key, value):
|
|
||||||
while len(self) >= self.maxsize and self.maxsize != 0:
|
|
||||||
self.popitem(last=False)
|
|
||||||
|
|
||||||
self[key] = value
|
|
||||||
self.move_to_end(key)
|
|
||||||
|
|
||||||
def fetch(self, key):
|
|
||||||
return self.get(key)
|
|
||||||
|
|
|
@ -1,536 +1,308 @@
|
||||||
## Probably gonna replace all of this with a custom sqlalchemy setup tbh
|
import sys
|
||||||
## It'll look like the db classes and functions in https://git.barkshark.xyz/izaliamae/social
|
|
||||||
import shutil, traceback, importlib, sqlite3, sys, json
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from sqlalchemy import create_engine, ForeignKey, MetaData, Table
|
||||||
|
from sqlalchemy import Column as SqlColumn, types as Types
|
||||||
|
#from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from . import logging
|
||||||
from .cache import LRUCache
|
from .cache import LRUCache
|
||||||
from .misc import Boolean, DotDict, Path
|
from .misc import DotDict, RandomGen, NfsCheck
|
||||||
from . import logging, sql
|
|
||||||
|
|
||||||
try:
|
|
||||||
from dbutils.pooled_db import PooledDB
|
|
||||||
except ImportError:
|
|
||||||
from DBUtils.PooledDB import PooledDB
|
|
||||||
|
|
||||||
|
|
||||||
## Only sqlite3 has been tested
|
SqlTypes = DotDict({t.lower(): getattr(Types, t) for t in dir(Types) if not t.startswith('_')})
|
||||||
## Use other modules at your own risk.
|
|
||||||
class DB():
|
|
||||||
def __init__(self, tables, dbmodule='sqlite', cursor=None, **kwargs):
|
|
||||||
cursor = Cursor if not cursor else cursor
|
|
||||||
|
|
||||||
if dbmodule in ['sqlite', 'sqlite3']:
|
|
||||||
self.dbmodule = sqlite3
|
class DataBase():
|
||||||
self.dbtype = 'sqlite'
|
def __init__(self, dbtype='postgresql+psycopg2', tables={}, **kwargs):
|
||||||
|
self.engine_string = self.__engine_string(dbtype, kwargs)
|
||||||
|
self.db = create_engine(self.engine_string)
|
||||||
|
self.table = Tables(self, tables)
|
||||||
|
self.cache = DotDict({table: LRUCache() for table in tables.keys()})
|
||||||
|
self.classes = kwargs.get('row_classes', CustomRows())
|
||||||
|
|
||||||
|
session_class = kwargs.get('session_class', Session)
|
||||||
|
self.session = lambda trans=True: session_class(self, trans)
|
||||||
|
|
||||||
|
|
||||||
|
def __engine_string(self, dbtype, kwargs):
|
||||||
|
if not kwargs.get('database'):
|
||||||
|
raise MissingDatabaseError('Database not set')
|
||||||
|
|
||||||
|
engine_string = dbtype + '://'
|
||||||
|
|
||||||
|
if dbtype == 'sqlite':
|
||||||
|
if NfsCheck(kwargs.get('database')):
|
||||||
|
logging.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
|
||||||
|
|
||||||
|
engine_string += '/' + kwargs.get('database')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.dbmodule = None
|
user = kwargs.get('user')
|
||||||
self.__setup_module(dbmodule)
|
password = kwargs.get('pass')
|
||||||
|
host = kwargs.get('host', '/var/run/postgresql')
|
||||||
|
port = kwargs.get('port', 5432)
|
||||||
|
name = kwargs.get('name', 'postgres')
|
||||||
|
maxconn = kwargs.get('maxconnections', 25)
|
||||||
|
|
||||||
self.db = None
|
if user:
|
||||||
self.cursor = lambda : cursor(self).begin()
|
if password:
|
||||||
self.kwargs = kwargs
|
engine_string += f'{user}:{password}@'
|
||||||
self.tables = tables
|
else:
|
||||||
self.cache = DotDict()
|
engine_string += user + '@'
|
||||||
self.__setup_database()
|
|
||||||
|
|
||||||
for table in tables.keys():
|
if host == '/var/run/postgresql':
|
||||||
self.__setup_cache(table)
|
engine_string += '/' + name
|
||||||
|
|
||||||
|
|
||||||
def query(self, *args, **kwargs):
|
|
||||||
return self.__cursor_cmd('query', *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def fetch(self, *args, **kwargs):
|
|
||||||
return self.__cursor_cmd('fetch', *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def insert(self, *args, **kwargs):
|
|
||||||
return self.__cursor_cmd('insert', *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def update(self, *args, **kwargs):
|
|
||||||
return self.__cursor_cmd('update', *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def remove(self, *args, **kwargs):
|
|
||||||
return self.__cursor_cmd('remove', *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def __cursor_cmd(self, name, *args, **kwargs):
|
|
||||||
with self.cursor() as cur:
|
|
||||||
method = getattr(cur, name)
|
|
||||||
|
|
||||||
if not method:
|
|
||||||
raise KeyError('Not a valid cursor method:', name)
|
|
||||||
|
|
||||||
return method(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def __setup_module(self, dbtype):
|
|
||||||
modules = []
|
|
||||||
modtypes = {
|
|
||||||
'sqlite': ['sqlite3'],
|
|
||||||
'postgresql': ['psycopg3', 'pgdb', 'psycopg2', 'pg8000'],
|
|
||||||
'mysql': ['mysqldb', 'trio_mysql'],
|
|
||||||
'mssql': ['pymssql', 'adodbapi']
|
|
||||||
}
|
|
||||||
|
|
||||||
for dbmod, mods in modtypes.items():
|
|
||||||
if dbtype == dbmod:
|
|
||||||
self.dbtype = dbmod
|
|
||||||
modules = mods
|
|
||||||
break
|
|
||||||
|
|
||||||
elif dbtype in mods:
|
|
||||||
self.dbtype = dbmod
|
|
||||||
modules = [dbtype]
|
|
||||||
break
|
|
||||||
|
|
||||||
if not modules:
|
|
||||||
logging.verbose('Not a database type. Checking if it is a module...')
|
|
||||||
|
|
||||||
for mod in modules:
|
|
||||||
try:
|
|
||||||
self.dbmodule = importlib.import_module(mod)
|
|
||||||
except ImportError as e:
|
|
||||||
logging.verbose('Module not installed:', mod)
|
|
||||||
|
|
||||||
if self.dbmodule:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not self.dbmodule:
|
|
||||||
if modtypes.get(dbtype):
|
|
||||||
logging.error('Failed to find module for database type:', dbtype)
|
|
||||||
logging.error(f'Please install one of these modules to use a {dbtype} database:', ', '.join(modules))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.error('Failed to import module:', dbtype)
|
engine_string += f'{host}:{port}/{name}'
|
||||||
logging.error('Install one of the following modules:')
|
|
||||||
|
|
||||||
for key, modules in modtypes.items():
|
return engine_string
|
||||||
logging.error(f'{key}:')
|
|
||||||
for mod in modules:
|
|
||||||
logging.error(f'\t{mod}')
|
|
||||||
|
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
|
|
||||||
def __setup_database(self):
|
def CreateDatabase(self):
|
||||||
if self.dbtype == 'sqlite':
|
if self.engine_string.startswith('postgresql'):
|
||||||
if not self.kwargs.get('database'):
|
predb = create_engine(db.engine_string.replace(config.db.name, 'postgres', -1))
|
||||||
dbfile = ':memory:'
|
conn = predb.connect()
|
||||||
|
conn.execute('commit')
|
||||||
dbfile = Path(self.kwargs['database'])
|
|
||||||
dbfile.parent().mkdir()
|
|
||||||
|
|
||||||
if not dbfile.parent().isdir():
|
|
||||||
logging.error('Invalid path to database file:', dbfile.parent().str())
|
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
self.kwargs['database'] = dbfile.str()
|
|
||||||
self.kwargs['check_same_thread'] = False
|
|
||||||
|
|
||||||
else:
|
|
||||||
if not self.kwargs.get('password'):
|
|
||||||
self.kwargs.pop('password', None)
|
|
||||||
|
|
||||||
self.db = PooledDB(self.dbmodule, **self.kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def connection(self):
|
|
||||||
if self.dbtype == 'sqlite':
|
|
||||||
return self.db.connection()
|
|
||||||
|
|
||||||
|
|
||||||
def __setup_cache(self, table):
|
|
||||||
self.cache[table] = LRUCache(128)
|
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def count(self, table):
|
|
||||||
tabledict = tables.get(table)
|
|
||||||
|
|
||||||
if not tabledict:
|
|
||||||
logging.debug('Table does not exist:', table)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
data = self.query(f'SELECT COUNT(*) FROM {table}')
|
|
||||||
return data[0][0]
|
|
||||||
|
|
||||||
|
|
||||||
#def query(self, string, values=[], cursor=None):
|
|
||||||
#if not string.endswith(';'):
|
|
||||||
#string += ';'
|
|
||||||
|
|
||||||
#if not cursor:
|
|
||||||
#with self.Cursor() as cursor:
|
|
||||||
#cursor.execute(string, values)
|
|
||||||
#return cursor.fetchall()
|
|
||||||
|
|
||||||
#else:
|
|
||||||
#cursor.execute(string,value)
|
|
||||||
#return cursor.fetchall()
|
|
||||||
|
|
||||||
#return False
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def Cursor(self):
|
|
||||||
conn = self.db.connection()
|
|
||||||
conn.begin()
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield cursor
|
conn.execute(f'CREATE DATABASE {config.db.name}')
|
||||||
except self.dbmodule.OperationalError:
|
|
||||||
traceback.print_exc()
|
except ProgrammingError:
|
||||||
conn.rollback()
|
'The database already exists, so just move along'
|
||||||
finally:
|
|
||||||
cursor.close()
|
except Exception as e:
|
||||||
conn.commit()
|
conn.close()
|
||||||
|
raise e from None
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
self.table.meta.create_all(self.db)
|
||||||
|
|
||||||
def CreateTable(self, table):
|
|
||||||
layout = DotDict(self.tables.get(table))
|
|
||||||
|
|
||||||
if not layout:
|
def execute(self, *args, **kwargs):
|
||||||
logging.error('Table config doesn\'t exist:', table)
|
with self.session() as s:
|
||||||
return
|
return s.execute(*args, **kwargs)
|
||||||
|
|
||||||
cmd = f'CREATE TABLE IF NOT EXISTS {table}('
|
|
||||||
items = []
|
|
||||||
|
|
||||||
for k, v in layout.items():
|
|
||||||
options = ' '.join(v.get('options', []))
|
|
||||||
default = v.get('default')
|
|
||||||
item = f'{k} {v["type"].upper()} {options}'
|
|
||||||
|
|
||||||
if default:
|
|
||||||
item += f'DEFAULT {default}'
|
|
||||||
|
|
||||||
items.append(item)
|
|
||||||
|
|
||||||
cmd += ', '.join(items) + ')'
|
|
||||||
|
|
||||||
return True if self.query(cmd) != False else False
|
|
||||||
|
|
||||||
|
|
||||||
def RenameTable(self, table, newname):
|
|
||||||
self.query(f'ALTER TABLE {table} RENAME TO {newname}')
|
|
||||||
|
|
||||||
|
|
||||||
def DropTable(self, table):
|
|
||||||
self.query(f'DROP TABLE {table}')
|
|
||||||
|
|
||||||
|
|
||||||
def AddColumn(self, table, name, datatype, default=None, options=None):
|
|
||||||
query = f'ALTER TABLE {table} ADD COLUMN {name} {datatype.upper()}'
|
|
||||||
|
|
||||||
if default:
|
|
||||||
query += f' DEFAULT {default}'
|
|
||||||
|
|
||||||
if options:
|
|
||||||
query += f' {options}'
|
|
||||||
|
|
||||||
self.query(query)
|
|
||||||
|
|
||||||
|
|
||||||
def CheckDatabase(self, database):
|
|
||||||
if self.dbtype == 'postgresql':
|
|
||||||
tables = self.query('SELECT datname FROM pg_database')
|
|
||||||
|
|
||||||
else:
|
|
||||||
tables = []
|
|
||||||
|
|
||||||
tables = [table[0] for table in tables]
|
|
||||||
print(database in tables, database, tables)
|
|
||||||
|
|
||||||
return database in tables
|
|
||||||
|
|
||||||
|
|
||||||
def CreateTables(self):
|
|
||||||
dbname = self.kwargs['database']
|
|
||||||
|
|
||||||
for name, table in self.tables.items():
|
|
||||||
if len(self.query(sql.CheckTable(self.dbtype, name))) < 1:
|
|
||||||
logging.info('Creating table:', name)
|
|
||||||
self.query(table.sql())
|
|
||||||
|
|
||||||
|
|
||||||
class Table(dict):
|
|
||||||
def __init__(self, name, dbtype='sqlite'):
|
|
||||||
super().__init__({})
|
|
||||||
|
|
||||||
self.sqlstr = 'CREATE TABLE IF NOT EXISTS {} ({});'
|
|
||||||
self.dbtype = dbtype
|
|
||||||
self.name = name
|
|
||||||
self.columns = {}
|
|
||||||
self.fkeys = {}
|
|
||||||
|
|
||||||
|
|
||||||
def addColumn(self, name, datatype=None, null=True, unique=None, primary=None, fkey=None):
|
|
||||||
if name == 'id':
|
|
||||||
if self.dbtype == 'sqlite':
|
|
||||||
datatype = 'integer'
|
|
||||||
primary = True if primary == None else primary
|
|
||||||
unique = True
|
|
||||||
null = False
|
|
||||||
|
|
||||||
else:
|
|
||||||
datatype = 'serial'
|
|
||||||
primary = 'true'
|
|
||||||
|
|
||||||
if name == 'timestamp':
|
|
||||||
datatype = 'float'
|
|
||||||
|
|
||||||
elif not datatype:
|
|
||||||
raise MissingTypeError(f'Missing a data type for column: {name}')
|
|
||||||
|
|
||||||
colsql = f'{name} {datatype.upper()}'
|
|
||||||
|
|
||||||
if unique:
|
|
||||||
colsql += ' UNIQUE'
|
|
||||||
|
|
||||||
if not null:
|
|
||||||
colsql += ' NOT NULL'
|
|
||||||
|
|
||||||
if primary:
|
|
||||||
self.primary = name
|
|
||||||
colsql += ' PRIMARY KEY'
|
|
||||||
|
|
||||||
if name == 'id' and self.dbtype == 'sqlite':
|
|
||||||
colsql += ' AUTOINCREMENT'
|
|
||||||
|
|
||||||
if fkey:
|
|
||||||
if self.dbtype == 'postgresql':
|
|
||||||
colsql += f' REFERENCES {fkey[0]}({fkey[1]})'
|
|
||||||
|
|
||||||
elif self.dbtype == 'sqlite':
|
|
||||||
self.fkeys[name] += f'FOREIGN KEY ({name}) REFERENCES {fkey[0]} ({fkey[1]})'
|
|
||||||
|
|
||||||
self.columns.update({name: colsql})
|
|
||||||
|
|
||||||
|
|
||||||
def sql(self):
|
|
||||||
if not self.primary:
|
|
||||||
logging.error('Please specify a primary column')
|
|
||||||
return
|
|
||||||
|
|
||||||
data = ', '.join(list(self.columns.values()))
|
|
||||||
|
|
||||||
if self.fkeys:
|
|
||||||
data += ', '
|
|
||||||
data += ', '.join(list(self.fkeys.values()))
|
|
||||||
|
|
||||||
sqldata = self.sqlstr.format(self.name, data)
|
|
||||||
print(sqldata)
|
|
||||||
return sqldata
|
|
||||||
|
|
||||||
|
|
||||||
class Cursor(object):
|
|
||||||
def __init__(self, db):
|
|
||||||
self.main = db
|
|
||||||
self.db = db.db
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def begin(self):
|
|
||||||
self.conn = self.db.connection()
|
|
||||||
self.conn.begin()
|
|
||||||
self.cursor = self.conn.cursor()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield self
|
|
||||||
|
|
||||||
except self.main.dbmodule.OperationalError:
|
|
||||||
self.conn.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
|
||||||
self.cursor.close()
|
|
||||||
self.conn.commit()
|
|
||||||
self.conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def query(self, string, values=[], cursor=None):
|
|
||||||
#if not string.endswith(';'):
|
|
||||||
#string += ';'
|
|
||||||
|
|
||||||
self.cursor.execute(string, values)
|
|
||||||
data = self.cursor.fetchall()
|
|
||||||
|
|
||||||
|
|
||||||
def fetch(self, table, single=True, sort=None, **kwargs):
|
|
||||||
rowid = kwargs.get('id')
|
|
||||||
querysort = f'ORDER BY {sort}'
|
|
||||||
|
|
||||||
resultOpts = [self, table, self.cursor]
|
|
||||||
|
|
||||||
if rowid:
|
|
||||||
cursor.execute(f"SELECT * FROM {table} WHERE id = ?", [rowid])
|
|
||||||
|
|
||||||
elif kwargs:
|
|
||||||
placeholders = [f'{k} = ?' for k in kwargs.keys()]
|
|
||||||
values = kwargs.values()
|
|
||||||
|
|
||||||
where = ' and '.join(placeholders)
|
|
||||||
query = f"SELECT * FROM {table} WHERE {where} {querysort if sort else ''}"
|
|
||||||
self.cursor.execute(query, list(values))
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.cursor.execute(f'SELECT * FROM {table} {querysort if sort else ""}')
|
|
||||||
|
|
||||||
rows = self.cursor.fetchall() if not single else self.cursor.fetchone()
|
|
||||||
|
|
||||||
if rows:
|
|
||||||
if single:
|
|
||||||
return DBResult(rows, *resultOpts)
|
|
||||||
|
|
||||||
return [DBResult(row, *resultOpts) for row in rows]
|
|
||||||
|
|
||||||
return None if single else []
|
|
||||||
|
|
||||||
|
|
||||||
def insert(self, table, data={}, **kwargs):
|
|
||||||
data.update(kwargs)
|
|
||||||
placeholders = ",".join(['?' for _ in data.keys()])
|
|
||||||
values = tuple(data.values())
|
|
||||||
keys = ','.join(data.keys())
|
|
||||||
|
|
||||||
if 'timestamp' in self.main.tables[table].keys() and 'timestamp' not in keys:
|
|
||||||
data['timestamp'] = datetime.now()
|
|
||||||
|
|
||||||
query = f'INSERT INTO {table} ({keys}) VALUES ({placeholders})'
|
|
||||||
self.query(query, values)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def remove(self, table, **kwargs):
|
|
||||||
keys = []
|
|
||||||
values = []
|
|
||||||
|
|
||||||
for k,v in kwargs.items():
|
|
||||||
keys.append(k)
|
|
||||||
values.append(v)
|
|
||||||
|
|
||||||
keydata = ','.join([f'{k} = ?' for k in keys])
|
|
||||||
query = f'DELETE FROM {table} WHERE {keydata}'
|
|
||||||
|
|
||||||
self.query(query, values)
|
|
||||||
|
|
||||||
|
|
||||||
def update(self, table, rowid, data={}, **kwargs):
|
|
||||||
data.update(kwargs)
|
|
||||||
newdata = {k: v for k, v in data.items() if k in self.main.tables[table].keys()}
|
|
||||||
keys = list(newdata.keys())
|
|
||||||
values = list(newdata.values())
|
|
||||||
|
|
||||||
if len(newdata) < 1:
|
|
||||||
logging.debug('No data provided to update row')
|
|
||||||
return False
|
|
||||||
|
|
||||||
query_data = ', '.join(f'{k} = ?' for k in keys)
|
|
||||||
query = f'UPDATE {table} SET {query_data} WHERE id = {rowid}'
|
|
||||||
|
|
||||||
self.query(query, values)
|
|
||||||
|
|
||||||
|
|
||||||
class DBResult(DotDict):
|
|
||||||
def __init__(self, row, db, table, cursor):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
|
class Session(object):
|
||||||
|
def __init__(self, db, trans=True):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.table = table
|
self.classes = self.db.classes
|
||||||
|
self.session = sessionmaker(bind=db.db)()
|
||||||
|
self.table = self.db.table
|
||||||
|
self.cache = self.db.cache
|
||||||
|
self.trans = trans
|
||||||
|
|
||||||
for idx, col in enumerate(cursor.description):
|
# session aliases
|
||||||
self[col[0]] = row[idx]
|
self.s = self.session
|
||||||
|
self.commit = self.s.commit
|
||||||
|
self.rollback = self.s.rollback
|
||||||
|
self.query = self.s.query
|
||||||
|
self.execute = self.s.execute
|
||||||
|
|
||||||
|
self._setup()
|
||||||
|
|
||||||
|
if not self.trans:
|
||||||
|
self.commit()
|
||||||
|
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __enter__(self):
|
||||||
if name not in ['db', 'table']:
|
self.sessionid = RandomGen(10)
|
||||||
return self.__setitem__(name, value)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return super().__setattr__(name, value)
|
|
||||||
|
|
||||||
|
|
||||||
def __delattr__(self, name):
|
|
||||||
if name not in ['db', 'table']:
|
|
||||||
return self.__delitem__(name)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return super().__delattr__(name)
|
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(self, value, default=None):
|
|
||||||
options = [value]
|
|
||||||
|
|
||||||
if default:
|
|
||||||
options.append(default)
|
|
||||||
|
|
||||||
if value in self.keys():
|
|
||||||
val = super().__getitem__(*options)
|
|
||||||
return DotDict(val) if isinstance(val, dict) else val
|
|
||||||
|
|
||||||
else:
|
|
||||||
return dict.__getattr__(*options)
|
|
||||||
|
|
||||||
|
|
||||||
# Kept for backwards compatibility. Delete later.
|
|
||||||
def asdict(self):
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def Update(self, data={}):
|
def __exit__(self, exctype, value, tb):
|
||||||
with self.db.Cursor().begin() as cursor:
|
if tb:
|
||||||
self.update(data)
|
self.rollback()
|
||||||
cursor.update(self.table, self.id, self.AsDict())
|
|
||||||
|
|
||||||
|
|
||||||
def Remove(self):
|
|
||||||
with self.db.Cursor().begin() as cursor:
|
|
||||||
cursor.remove(self.table, id=self.id)
|
|
||||||
|
|
||||||
|
|
||||||
def ParseData(table, row):
|
|
||||||
tbdata = tables.get(table)
|
|
||||||
types = []
|
|
||||||
result = []
|
|
||||||
|
|
||||||
if not tbdata:
|
|
||||||
logging.error('Invalid table:', table)
|
|
||||||
return
|
|
||||||
|
|
||||||
for v in tbdata.values():
|
|
||||||
dtype = v.split()[0].upper()
|
|
||||||
|
|
||||||
if dtype == 'BOOLEAN':
|
|
||||||
types.append(boolean)
|
|
||||||
|
|
||||||
elif dtype in ['INTEGER', 'INT']:
|
|
||||||
types.append(int)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
types.append(str)
|
self.commit()
|
||||||
|
|
||||||
for idx, v in enumerate(row):
|
|
||||||
result.append(types[idx](v) if v else None)
|
|
||||||
|
|
||||||
return row
|
|
||||||
|
|
||||||
db.insert('config', {'key': 'version', 'value': dbversion})
|
|
||||||
|
|
||||||
|
|
||||||
class MissingTypeError(Exception):
|
def _setup(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TooManyConnectionsError(Exception):
|
def count(self, table_name, **kwargs):
|
||||||
|
return self.query(self.table[table_name]).filter_by(**kwargs).count()
|
||||||
|
|
||||||
|
|
||||||
|
def fetch(self, table_name, single=True, **kwargs):
|
||||||
|
RowClass = self.classes.get(table_name.capitalize())
|
||||||
|
|
||||||
|
rows = self.query(self.table[table_name]).filter_by(**kwargs).all()
|
||||||
|
|
||||||
|
if single:
|
||||||
|
return RowClass(table_name, rows[0], self) if len(rows) > 0 else None
|
||||||
|
|
||||||
|
return [RowClass(table_name, row, self) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def insert(self, table_name, **kwargs):
|
||||||
|
row = self.fetch(table_name, **kwargs)
|
||||||
|
|
||||||
|
if row:
|
||||||
|
row.update_session(self, **kwargs)
|
||||||
|
return
|
||||||
|
|
||||||
|
table = self.table[table_name]
|
||||||
|
|
||||||
|
if getattr(table, 'timestamp', None) and not kwargs.get('timestamp'):
|
||||||
|
kwargs['timestamp'] = datetime.now()
|
||||||
|
|
||||||
|
res = self.execute(table.insert().values(**kwargs))
|
||||||
|
#return self.fetch(table_name, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, table=None, rowid=None, row=None, **data):
|
||||||
|
if row:
|
||||||
|
rowid = row.id
|
||||||
|
table = row._table_name
|
||||||
|
|
||||||
|
if not rowid or not table:
|
||||||
|
raise ValueError('Missing row ID or table')
|
||||||
|
|
||||||
|
tclass = self.table[table]
|
||||||
|
|
||||||
|
self.execute(tclass.update().where(tclass.c.id == rowid).values(**data))
|
||||||
|
|
||||||
|
|
||||||
|
def remove(self, table=None, rowid=None, row=None):
|
||||||
|
if row:
|
||||||
|
rowid = row.id
|
||||||
|
table = row._table_name
|
||||||
|
|
||||||
|
if not rowid or not table:
|
||||||
|
raise ValueError('Missing row ID or table')
|
||||||
|
|
||||||
|
row = self.execute(f'DELETE FROM {table} WHERE id={rowid}')
|
||||||
|
|
||||||
|
|
||||||
|
def DropTables(self):
|
||||||
|
tables = self.GetTables()
|
||||||
|
|
||||||
|
for table in tables:
|
||||||
|
self.execute(f'DROP TABLE {table}')
|
||||||
|
|
||||||
|
|
||||||
|
def GetTables(self):
|
||||||
|
rows = self.execute("SELECT name FROM sqlite_master WHERE type IN ('table','view') and name NOT LIKE 'sqlite_%'")
|
||||||
|
return [row[0] for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomRows(object):
|
||||||
|
def get(self, name):
|
||||||
|
return getattr(self, name, self.Row)
|
||||||
|
|
||||||
|
|
||||||
|
class Row(DotDict):
|
||||||
|
#_filter_columns = lambda self, row: [attr for attr in dir(row) if not attr.startswith('_') and attr != 'metadata']
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, table, row, session):
|
||||||
|
if not row:
|
||||||
|
return
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self._update(row._asdict())
|
||||||
|
|
||||||
|
self._db = session.db
|
||||||
|
self._table_name = table
|
||||||
|
self._columns = self.keys()
|
||||||
|
#self._columns = self._filter_columns(row)
|
||||||
|
|
||||||
|
self.__run__(session)
|
||||||
|
|
||||||
|
|
||||||
|
## Subclass Row and redefine this function
|
||||||
|
def __run__(self, s):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_data(self):
|
||||||
|
data = {k: v for k,v in self.items() if k in self._columns}
|
||||||
|
|
||||||
|
for k,v in self.items():
|
||||||
|
if v.__class__ == DotDict:
|
||||||
|
data[k] = v.asDict()
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def asDict(self):
|
||||||
|
return self._filter_data()
|
||||||
|
|
||||||
|
|
||||||
|
def _update(self, new_data={}, **kwargs):
|
||||||
|
kwargs.update(new_data)
|
||||||
|
|
||||||
|
for k,v in kwargs.items():
|
||||||
|
if type(v) == dict:
|
||||||
|
self[k] = DotDict(v)
|
||||||
|
|
||||||
|
self[k] = v
|
||||||
|
|
||||||
|
|
||||||
|
def delete(self):
|
||||||
|
with self._db.session() as s:
|
||||||
|
return self.delete_session(s)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_session(self, s):
|
||||||
|
return s.remove(row=self)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, dict_data={}, **data):
|
||||||
|
dict_data.update(data)
|
||||||
|
self._update(dict_data)
|
||||||
|
|
||||||
|
with self._db.session() as s:
|
||||||
|
s.update(row=self, **self._filter_data())
|
||||||
|
|
||||||
|
|
||||||
|
def update_session(self, s, dict_data={}, **data):
|
||||||
|
return s.update(row=self, **dict_data, **data)
|
||||||
|
|
||||||
|
|
||||||
|
class Tables(DotDict):
|
||||||
|
def __init__(self, db, tables={}):
|
||||||
|
'"tables" should be a dict with the table names for keys and a list of Columns for values'
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.db = db
|
||||||
|
self.meta = MetaData()
|
||||||
|
|
||||||
|
for name, table in tables.items():
|
||||||
|
self.__setup_table(name, table)
|
||||||
|
|
||||||
|
|
||||||
|
def __setup_table(self, name, table):
|
||||||
|
self[name] = Table(name, self.meta, *table)
|
||||||
|
|
||||||
|
|
||||||
|
def Column(name, stype=None, fkey=None, **kwargs):
|
||||||
|
if not stype and not kwargs:
|
||||||
|
if name == 'id':
|
||||||
|
return Column('id', 'integer', primary_key=True, autoincrement=True)
|
||||||
|
|
||||||
|
elif name == 'timestamp':
|
||||||
|
return Column('timestamp', 'datetime')
|
||||||
|
|
||||||
|
raise ValueError('Missing column type and options')
|
||||||
|
|
||||||
|
else:
|
||||||
|
options = [name, SqlTypes.get(stype.lower(), SqlTypes['string'])]
|
||||||
|
|
||||||
|
if fkey:
|
||||||
|
options.append(ForeignKey(fkey))
|
||||||
|
|
||||||
|
return SqlColumn(*options, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class MissingDatabaseError(Exception):
|
||||||
|
'''raise when the "database" kwargs is not set'''
|
||||||
|
|
385
IzzyLib/http.py
385
IzzyLib/http.py
|
@ -4,11 +4,12 @@ from IzzyLib import logging
|
||||||
from IzzyLib.misc import DefaultDict, DotDict
|
from IzzyLib.misc import DefaultDict, DotDict
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode, b64encode
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from ssl import SSLCertVerificationError
|
||||||
from urllib.error import HTTPError
|
from urllib.error import HTTPError
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
from . import error
|
from . import error, __version__
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from Crypto.Hash import SHA256
|
from Crypto.Hash import SHA256
|
||||||
|
@ -21,6 +22,7 @@ except ImportError:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sanic.request import Request as SanicRequest
|
from sanic.request import Request as SanicRequest
|
||||||
|
from sanic.exceptions import SanicException
|
||||||
sanic_enabled = True
|
sanic_enabled = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.verbose('Sanic module not found. Request verification is disabled')
|
logging.verbose('Sanic module not found. Request verification is disabled')
|
||||||
|
@ -30,169 +32,8 @@ except ImportError:
|
||||||
Client = None
|
Client = None
|
||||||
|
|
||||||
|
|
||||||
def VerifyRequest(request: SanicRequest, actor: dict=None):
|
|
||||||
'''Verify a header signature from a sanic request
|
|
||||||
|
|
||||||
request: The request with the headers to verify
|
|
||||||
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
|
||||||
'''
|
|
||||||
if not sanic_enabled:
|
|
||||||
logging.error('Sanic request verification disabled')
|
|
||||||
return
|
|
||||||
|
|
||||||
if not actor:
|
|
||||||
actor = request.ctx.actor
|
|
||||||
|
|
||||||
body = request.body if request.body else None
|
|
||||||
return VerifyHeaders(request.headers, request.method, request.path, body, actor, False)
|
|
||||||
|
|
||||||
|
|
||||||
def VerifyHeaders(headers: dict, method: str, path: str, actor: dict=None, body=None, fail: bool=False):
|
|
||||||
'''Verify a header signature
|
|
||||||
|
|
||||||
headers: A dictionary containing all the headers from a request
|
|
||||||
method: The HTTP method of the request
|
|
||||||
path: The path of the HTTP request
|
|
||||||
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
|
||||||
body (optional): The body of the request. Only needed if the signature includes the digest header
|
|
||||||
fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
|
|
||||||
'''
|
|
||||||
if not crypto_enabled:
|
|
||||||
logging.error('Crypto functions disabled')
|
|
||||||
return
|
|
||||||
|
|
||||||
headers = {k.lower(): v for k,v in headers.items()}
|
|
||||||
headers['(request-target)'] = f'{method.lower()} {path}'
|
|
||||||
signature = ParseSig(headers.get('signature'))
|
|
||||||
digest = headers.get('digest')
|
|
||||||
missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None]
|
|
||||||
|
|
||||||
if not signature:
|
|
||||||
if fail:
|
|
||||||
raise MissingSignatureError()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not actor:
|
|
||||||
actor = FetchActor(signature.keyid)
|
|
||||||
|
|
||||||
## Add digest header to missing headers list if it doesn't exist
|
|
||||||
if method.lower() == 'post' and not headers.get('digest'):
|
|
||||||
missing_headers.append('digest')
|
|
||||||
|
|
||||||
## Fail if missing date, host or digest (if POST) headers
|
|
||||||
if missing_headers:
|
|
||||||
if fail:
|
|
||||||
raise error.MissingHeadersError(missing_headers)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
## Fail if body verification fails
|
|
||||||
if digest and not VerifyString(body, digest):
|
|
||||||
if fail:
|
|
||||||
raise error.VerificationError('digest header')
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
pubkey = actor.publicKey['publicKeyPem']
|
|
||||||
|
|
||||||
if PkcsHeaders(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if fail:
|
|
||||||
raise error.VerificationError('headers')
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=512)
|
|
||||||
def VerifyString(string, enc_string, alg='SHA256', fail=False):
|
|
||||||
if not crypto_enabled:
|
|
||||||
logging.error('Crypto functions disabled')
|
|
||||||
return
|
|
||||||
|
|
||||||
if type(string) != bytes:
|
|
||||||
string = string.encode('UTF-8')
|
|
||||||
|
|
||||||
body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
|
|
||||||
|
|
||||||
if body_hash == enc_string:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if fail:
|
|
||||||
raise error.VerificationError()
|
|
||||||
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def PkcsHeaders(key: str, headers: dict, sig=None):
|
|
||||||
if not crypto_enabled:
|
|
||||||
logging.error('Crypto functions disabled')
|
|
||||||
return
|
|
||||||
|
|
||||||
if sig:
|
|
||||||
head_items = [f'{item}: {headers[item]}' for item in sig.headers]
|
|
||||||
|
|
||||||
else:
|
|
||||||
head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
|
|
||||||
|
|
||||||
head_string = '\n'.join(head_items)
|
|
||||||
head_bytes = head_string.encode('UTF-8')
|
|
||||||
|
|
||||||
KEY = RSA.importKey(key)
|
|
||||||
pkcs = PKCS1_v1_5.new(KEY)
|
|
||||||
h = SHA256.new(head_bytes)
|
|
||||||
|
|
||||||
if sig:
|
|
||||||
return pkcs.verify(h, b64decode(sig.signature))
|
|
||||||
|
|
||||||
else:
|
|
||||||
return pkcs.sign(h)
|
|
||||||
|
|
||||||
|
|
||||||
def ParseSig(signature: str):
|
|
||||||
if not signature:
|
|
||||||
logging.verbose('Missing signature header')
|
|
||||||
return
|
|
||||||
|
|
||||||
split_sig = signature.split(',')
|
|
||||||
sig = DefaultDict({})
|
|
||||||
|
|
||||||
for part in split_sig:
|
|
||||||
key, value = part.split('=', 1)
|
|
||||||
sig[key.lower()] = value.replace('"', '')
|
|
||||||
|
|
||||||
if not sig.headers:
|
|
||||||
logging.verbose('Missing headers section in signature')
|
|
||||||
return
|
|
||||||
|
|
||||||
sig.headers = sig.headers.split()
|
|
||||||
|
|
||||||
return sig
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=512)
|
|
||||||
def FetchActor(keyid, client=None):
|
|
||||||
if not client:
|
|
||||||
client = Client if Client else HttpClient()
|
|
||||||
|
|
||||||
actor = Client.request(keyid).json()
|
|
||||||
actor.domain = urlparse(actor.id).netloc
|
|
||||||
actor.shared_inbox = actor.inbox
|
|
||||||
actor.pubkey = None
|
|
||||||
|
|
||||||
if actor.get('endpoints'):
|
|
||||||
actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
|
|
||||||
|
|
||||||
if actor.get('publicKey'):
|
|
||||||
actor.pubkey = actor.publicKey.get('publicKeyPem')
|
|
||||||
|
|
||||||
return actor
|
|
||||||
|
|
||||||
|
|
||||||
class HttpClient(object):
|
class HttpClient(object):
|
||||||
def __init__(self, headers={}, useragent='IzzyLib/0.3', proxy_type='https', proxy_host=None, proxy_port=None):
|
def __init__(self, headers={}, useragent=f'IzzyLib/{__version__}', appagent=None, proxy_type='https', proxy_host=None, proxy_port=None):
|
||||||
proxy_ports = {
|
proxy_ports = {
|
||||||
'http': 80,
|
'http': 80,
|
||||||
'https': 443
|
'https': 443
|
||||||
|
@ -202,7 +43,7 @@ class HttpClient(object):
|
||||||
raise ValueError(f'Not a valid proxy type: {proxy_type}')
|
raise ValueError(f'Not a valid proxy type: {proxy_type}')
|
||||||
|
|
||||||
self.headers=headers
|
self.headers=headers
|
||||||
self.agent=useragent
|
self.agent = f'{useragent} ({appagent})' if appagent else useragent
|
||||||
self.proxy = DotDict({
|
self.proxy = DotDict({
|
||||||
'enabled': True if proxy_host else False,
|
'enabled': True if proxy_host else False,
|
||||||
'ptype': proxy_type,
|
'ptype': proxy_type,
|
||||||
|
@ -210,6 +51,8 @@ class HttpClient(object):
|
||||||
'port': proxy_ports[proxy_type] if not proxy_port else proxy_port
|
'port': proxy_ports[proxy_type] if not proxy_port else proxy_port
|
||||||
})
|
})
|
||||||
|
|
||||||
|
self.SetGlobal = SetClient
|
||||||
|
|
||||||
|
|
||||||
def __sign_request(self, request, privkey, keyid):
|
def __sign_request(self, request, privkey, keyid):
|
||||||
if not crypto_enabled:
|
if not crypto_enabled:
|
||||||
|
@ -269,9 +112,14 @@ class HttpClient(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = urlopen(request)
|
response = urlopen(request)
|
||||||
|
|
||||||
except HTTPError as e:
|
except HTTPError as e:
|
||||||
response = e.fp
|
response = e.fp
|
||||||
|
|
||||||
|
except SSLCertVerificationError as e:
|
||||||
|
logging.error('HttpClient.request: Certificate error:', e)
|
||||||
|
return
|
||||||
|
|
||||||
return HttpResponse(response)
|
return HttpResponse(response)
|
||||||
|
|
||||||
|
|
||||||
|
@ -337,6 +185,211 @@ class HttpResponse(object):
|
||||||
return json.dumps(self.json().asDict(), indent=indent)
|
return json.dumps(self.json().asDict(), indent=indent)
|
||||||
|
|
||||||
|
|
||||||
def SetClient(client: HttpClient):
|
def VerifyRequest(request: SanicRequest, actor: dict):
|
||||||
|
'''Verify a header signature from a sanic request
|
||||||
|
|
||||||
|
request: The request with the headers to verify
|
||||||
|
actor: A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
||||||
|
'''
|
||||||
|
if not sanic_enabled:
|
||||||
|
logging.error('Sanic request verification disabled')
|
||||||
|
return
|
||||||
|
|
||||||
|
body = request.body if request.body else None
|
||||||
|
return VerifyHeaders(request.headers, request.method, request.path, body, actor)
|
||||||
|
|
||||||
|
|
||||||
|
def VerifyHeaders(headers: dict, method: str, path: str, actor: dict=None, body=None):
|
||||||
|
'''Verify a header signature
|
||||||
|
|
||||||
|
headers: A dictionary containing all the headers from a request
|
||||||
|
method: The HTTP method of the request
|
||||||
|
path: The path of the HTTP request
|
||||||
|
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
||||||
|
body (optional): The body of the request. Only needed if the signature includes the digest header
|
||||||
|
fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
|
||||||
|
'''
|
||||||
|
if not crypto_enabled:
|
||||||
|
logging.error('Crypto functions disabled')
|
||||||
|
return
|
||||||
|
|
||||||
|
headers = {k.lower(): v for k,v in headers.items()}
|
||||||
|
headers['(request-target)'] = f'{method.lower()} {path}'
|
||||||
|
signature = ParseSig(headers.get('signature'))
|
||||||
|
digest = ParseBodyDigest(headers.get('digest'))
|
||||||
|
missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None]
|
||||||
|
|
||||||
|
if not signature:
|
||||||
|
logging.verbose('Missing signature')
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not actor:
|
||||||
|
actor = FetchActor(signature.keyid)
|
||||||
|
|
||||||
|
## Add digest header to missing headers list if it doesn't exist
|
||||||
|
if method.lower() == 'post' and not digest:
|
||||||
|
missing_headers.append('digest')
|
||||||
|
|
||||||
|
## Fail if missing date, host or digest (if POST) headers
|
||||||
|
if missing_headers:
|
||||||
|
logging.verbose('Missing headers:', missing_headers)
|
||||||
|
return False
|
||||||
|
|
||||||
|
## Fail if body verification fails
|
||||||
|
if digest and not VerifyString(body, digest.sig, digest.alg):
|
||||||
|
logging.verbose('Failed body digest verification')
|
||||||
|
return False
|
||||||
|
|
||||||
|
pubkey = actor.publicKey['publicKeyPem']
|
||||||
|
|
||||||
|
if PkcsHeaders(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature):
|
||||||
|
return True
|
||||||
|
|
||||||
|
logging.verbose('Failed header verification')
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def ParseBodyDigest(digest):
|
||||||
|
if not digest:
|
||||||
|
return
|
||||||
|
|
||||||
|
parsed = DotDict()
|
||||||
|
parts = digest.split('=', 1)
|
||||||
|
|
||||||
|
if len(parts) != 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
parsed.sig = parts[1]
|
||||||
|
parsed.alg = parts[0].replace('-', '')
|
||||||
|
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def VerifyString(string, enc_string, alg='SHA256', fail=False):
|
||||||
|
if not crypto_enabled:
|
||||||
|
logging.error('Crypto functions disabled')
|
||||||
|
return
|
||||||
|
|
||||||
|
if type(string) != bytes:
|
||||||
|
string = string.encode('UTF-8')
|
||||||
|
|
||||||
|
body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
|
||||||
|
|
||||||
|
if body_hash == enc_string:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if fail:
|
||||||
|
raise error.VerificationError()
|
||||||
|
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def PkcsHeaders(key: str, headers: dict, sig=None):
|
||||||
|
if not crypto_enabled:
|
||||||
|
logging.error('Crypto functions disabled')
|
||||||
|
return
|
||||||
|
|
||||||
|
if sig:
|
||||||
|
head_items = [f'{item}: {headers[item]}' for item in sig.headers]
|
||||||
|
|
||||||
|
else:
|
||||||
|
head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
|
||||||
|
|
||||||
|
head_string = '\n'.join(head_items)
|
||||||
|
head_bytes = head_string.encode('UTF-8')
|
||||||
|
|
||||||
|
KEY = RSA.importKey(key)
|
||||||
|
pkcs = PKCS1_v1_5.new(KEY)
|
||||||
|
h = SHA256.new(head_bytes)
|
||||||
|
|
||||||
|
if sig:
|
||||||
|
return pkcs.verify(h, b64decode(sig.signature))
|
||||||
|
|
||||||
|
else:
|
||||||
|
return pkcs.sign(h)
|
||||||
|
|
||||||
|
|
||||||
|
def ParseSig(signature: str):
|
||||||
|
if not signature:
|
||||||
|
logging.verbose('Missing signature header')
|
||||||
|
return
|
||||||
|
|
||||||
|
split_sig = signature.split(',')
|
||||||
|
sig = DefaultDict({})
|
||||||
|
|
||||||
|
for part in split_sig:
|
||||||
|
key, value = part.split('=', 1)
|
||||||
|
sig[key.lower()] = value.replace('"', '')
|
||||||
|
|
||||||
|
if not sig.headers:
|
||||||
|
logging.verbose('Missing headers section in signature')
|
||||||
|
return
|
||||||
|
|
||||||
|
sig.headers = sig.headers.split()
|
||||||
|
|
||||||
|
return sig
|
||||||
|
|
||||||
|
|
||||||
|
def FetchActor(url):
|
||||||
|
if not Client:
|
||||||
|
logging.error('IzzyLib.http: Please set global client with "SetClient(client)"')
|
||||||
|
return {}
|
||||||
|
|
||||||
|
url = url.split('#')[0]
|
||||||
|
headers = {'Accept': 'application/activity+json'}
|
||||||
|
resp = Client.request(url, headers=headers)
|
||||||
|
|
||||||
|
if not resp.json():
|
||||||
|
logging.verbose('functions.FetchActor: Failed to fetch actor:', url)
|
||||||
|
logging.debug(f'Error {resp.status}: {resp.body}')
|
||||||
|
return {}
|
||||||
|
|
||||||
|
actor = resp.json()
|
||||||
|
actor.web_domain = urlparse(url).netloc
|
||||||
|
actor.shared_inbox = actor.inbox
|
||||||
|
actor.pubkey = None
|
||||||
|
actor.handle = actor.preferredUsername
|
||||||
|
|
||||||
|
if actor.get('endpoints'):
|
||||||
|
actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
|
||||||
|
|
||||||
|
if actor.get('publicKey'):
|
||||||
|
actor.pubkey = actor.publicKey.get('publicKeyPem')
|
||||||
|
|
||||||
|
return actor
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=512)
|
||||||
|
def FetchWebfingerAcct(handle, domain):
|
||||||
|
if not Client:
|
||||||
|
logging.error('IzzyLib.http: Please set global client with "SetClient(client)"')
|
||||||
|
return {}
|
||||||
|
|
||||||
|
data = DefaultDict()
|
||||||
|
webfinger = Client.request(f'https://{domain}/.well-known/webfinger?resource=acct:{handle}@{domain}')
|
||||||
|
|
||||||
|
if not webfinger.body:
|
||||||
|
return
|
||||||
|
|
||||||
|
data.handle, data.domain = webfinger.json().subject.replace('acct:', '').split('@')
|
||||||
|
|
||||||
|
for link in webfinger.json().links:
|
||||||
|
if link['rel'] == 'self' and link['type'] == 'application/activity+json':
|
||||||
|
data.actor = link['href']
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def SetClient(client=None):
|
||||||
global Client
|
global Client
|
||||||
Client = client
|
Client = client or HttpClient()
|
||||||
|
|
||||||
|
|
||||||
|
def GenRsaKey():
|
||||||
|
privkey = RSA.generate(2048)
|
||||||
|
|
||||||
|
key = DotDict({'PRIVKEY': privkey, 'PUBKEY': privkey.publickey()})
|
||||||
|
key.update({'privkey': key.PRIVKEY.export_key().decode(), 'pubkey': key.PUBKEY.export_key().decode()})
|
||||||
|
|
||||||
|
return key
|
||||||
|
|
326
IzzyLib/http_server.py
Normal file
326
IzzyLib/http_server.py
Normal file
|
@ -0,0 +1,326 @@
|
||||||
|
import multiprocessing, sanic, signal, traceback
|
||||||
|
import logging as pylog
|
||||||
|
|
||||||
|
from jinja2.exceptions import TemplateNotFound
|
||||||
|
from multidict import CIMultiDict
|
||||||
|
from multiprocessing import cpu_count, current_process
|
||||||
|
from urllib.parse import parse_qsl, urlparse
|
||||||
|
|
||||||
|
from . import http, logging
|
||||||
|
from .misc import DotDict, DefaultDict, LowerDotDict
|
||||||
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
|
log_path_ignore = [
|
||||||
|
'/media',
|
||||||
|
'/static'
|
||||||
|
]
|
||||||
|
|
||||||
|
log_ext_ignore = [
|
||||||
|
'js', 'ttf', 'woff2',
|
||||||
|
'ac3', 'aiff', 'flac', 'm4a', 'mp3', 'ogg', 'wav', 'wma',
|
||||||
|
'apng', 'ico', 'jpeg', 'jpg', 'png', 'svg',
|
||||||
|
'divx', 'mov', 'mp4', 'webm', 'wmv'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class HttpServer(sanic.Sanic):
|
||||||
|
def __init__(self, name='sanic', host='0.0.0.0', port='4080', **kwargs):
|
||||||
|
self.host = host
|
||||||
|
self.port = int(port)
|
||||||
|
self.workers = int(kwargs.get('workers', cpu_count()))
|
||||||
|
self.sig_handler = kwargs.get('sig_handler')
|
||||||
|
self.ctx = DotDict()
|
||||||
|
|
||||||
|
super().__init__(name, request_class=kwargs.get('request_class', HttpRequest))
|
||||||
|
|
||||||
|
#for log in ['sanic.root', 'sanic.access']:
|
||||||
|
#pylog.getLogger(log).setLevel(pylog.CRITICAL)
|
||||||
|
|
||||||
|
self.template = Template(
|
||||||
|
kwargs.get('tpl_search', []),
|
||||||
|
kwargs.get('tpl_globals', {}),
|
||||||
|
kwargs.get('tpl_context'),
|
||||||
|
kwargs.get('tpl_autoescape', True)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.template.addEnv('app', self)
|
||||||
|
|
||||||
|
self.error_handler.add(TemplateNotFound, NoTemplateError)
|
||||||
|
self.error_handler.add(Exception, kwargs.get('error_handler', GenericError))
|
||||||
|
self.register_middleware(MiddlewareAccessLog, attach_to='response')
|
||||||
|
|
||||||
|
signal.signal(signal.SIGHUP, self.finish)
|
||||||
|
signal.signal(signal.SIGINT, self.finish)
|
||||||
|
signal.signal(signal.SIGQUIT, self.finish)
|
||||||
|
signal.signal(signal.SIGTERM, self.finish)
|
||||||
|
|
||||||
|
|
||||||
|
def add_method_route(self, method, *routes):
|
||||||
|
for route in routes:
|
||||||
|
self.add_route(method.as_view(), route)
|
||||||
|
|
||||||
|
|
||||||
|
def add_method_routes(self, routes: list):
|
||||||
|
for route in routes:
|
||||||
|
self.add_method_route(*route)
|
||||||
|
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
options = {
|
||||||
|
'host': self.host,
|
||||||
|
'port': self.port,
|
||||||
|
'workers': self.workers,
|
||||||
|
'access_log': False,
|
||||||
|
'debug': False
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = f'Starting {self.name} at {self.host}:{self.port}'
|
||||||
|
|
||||||
|
if self.workers > 1:
|
||||||
|
msg += f' with {self.workers} workers'
|
||||||
|
|
||||||
|
logging.info(msg)
|
||||||
|
self.run(**options)
|
||||||
|
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
if self.sig_handler:
|
||||||
|
self.sig_handler()
|
||||||
|
|
||||||
|
self.stop()
|
||||||
|
logging.info('Bye! :3')
|
||||||
|
|
||||||
|
|
||||||
|
class HttpRequest(sanic.request.Request):
|
||||||
|
def __init__(self, url_bytes, headers, version, method, transport, app):
|
||||||
|
super().__init__(url_bytes, headers, version, method, transport, app)
|
||||||
|
|
||||||
|
self.Headers = Headers(headers)
|
||||||
|
self.Data = Data(self)
|
||||||
|
self.template = self.app.template
|
||||||
|
self.__setup_defaults()
|
||||||
|
self.__parse_path()
|
||||||
|
|
||||||
|
#if self.paths.media:
|
||||||
|
#return
|
||||||
|
|
||||||
|
self.__parse_signature()
|
||||||
|
self.Run()
|
||||||
|
|
||||||
|
|
||||||
|
def Run(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def response(self, tpl, *args, **kwargs):
|
||||||
|
return self.template.response(self, tpl, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def alldata(self):
|
||||||
|
return self.__combine_dicts(self.content.json, self.data.query, self.data.form)
|
||||||
|
|
||||||
|
|
||||||
|
def verify(self, actor=None):
|
||||||
|
self.ap.valid = http.VerifyHeaders(self.headers, self.method, self.path, actor, self.body)
|
||||||
|
return self.ap.valid
|
||||||
|
|
||||||
|
|
||||||
|
def __combine_dicts(self, *dicts):
|
||||||
|
data = DotDict()
|
||||||
|
|
||||||
|
for item in dicts:
|
||||||
|
data.update(item)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def __setup_defaults(self):
|
||||||
|
self.paths = DotDict({'media': False, 'json': False, 'ap': False, 'cookie': False})
|
||||||
|
self.ap = DotDict({'valid': False, 'signature': {}, 'actor': None, 'inbox': None, 'domain': None})
|
||||||
|
|
||||||
|
|
||||||
|
def __parse_path(self):
|
||||||
|
self.paths.media = any(map(self.path.startswith, log_path_ignore)) or any(map(self.path.startswith, log_ext_ignore))
|
||||||
|
self.paths.json = self.__json_check()
|
||||||
|
|
||||||
|
|
||||||
|
def __parse_signature(self):
|
||||||
|
sig = self.headers.getone('signature', None)
|
||||||
|
|
||||||
|
if sig:
|
||||||
|
self.ap.signature = http.ParseSig(sig)
|
||||||
|
|
||||||
|
if self.ap.signature:
|
||||||
|
self.ap.actor = self.ap.signature.get('keyid', '').split('#', 1)[0]
|
||||||
|
self.ap.domain = urlparse(self.ap.actor).netloc
|
||||||
|
|
||||||
|
|
||||||
|
def __json_check(self):
|
||||||
|
if self.path.endswith('.json'):
|
||||||
|
return True
|
||||||
|
|
||||||
|
accept = self.headers.getone('Accept', None)
|
||||||
|
|
||||||
|
if accept:
|
||||||
|
mimes = [v.strip() for v in accept.split(',')]
|
||||||
|
|
||||||
|
if any(mime in ['application/json', 'application/activity+json'] for mime in mimes):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Headers(LowerDotDict):
|
||||||
|
def __init__(self, headers):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
for k,v in headers.items():
|
||||||
|
if not self.get(k):
|
||||||
|
self[k] = []
|
||||||
|
|
||||||
|
self[k].append(v)
|
||||||
|
|
||||||
|
|
||||||
|
def getone(self, key, default=None):
|
||||||
|
value = self.get(key)
|
||||||
|
|
||||||
|
if not value:
|
||||||
|
return default
|
||||||
|
|
||||||
|
return value[0]
|
||||||
|
|
||||||
|
|
||||||
|
def getall(self, key, default=[]):
|
||||||
|
return self.get(key.lower(), default)
|
||||||
|
|
||||||
|
|
||||||
|
class Data(object):
|
||||||
|
def __init__(self, request):
|
||||||
|
self.request = request
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def combined(self):
|
||||||
|
return DotDict(**self.form.asDict(), **self.query.asDict(), **self.json.asDict())
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query(self):
|
||||||
|
data = {k: v for k,v in parse_qsl(self.request.query_string)}
|
||||||
|
return DotDict(data)
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def form(self):
|
||||||
|
data = {k: v[0] for k,v in self.request.form.items()}
|
||||||
|
return DotDict(data)
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def files(self):
|
||||||
|
return DotDict({k:v[0] for k,v in self.request.files.items()})
|
||||||
|
|
||||||
|
|
||||||
|
### body functions
|
||||||
|
@property
|
||||||
|
def raw(self):
|
||||||
|
try:
|
||||||
|
return self.request.body
|
||||||
|
except Exception as e:
|
||||||
|
logging.verbose('IzzyLib.http_server.Data.raw: failed to get body')
|
||||||
|
logging.debug(f'{e.__class__.__name__}: {e}')
|
||||||
|
return b''
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self):
|
||||||
|
try:
|
||||||
|
return self.raw.decode()
|
||||||
|
except Exception as e:
|
||||||
|
logging.verbose('IzzyLib.http_server.Data.text: failed to get body')
|
||||||
|
logging.debug(f'{e.__class__.__name__}: {e}')
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def json(self):
|
||||||
|
try:
|
||||||
|
return DotDict(self.text)
|
||||||
|
except Exception as e:
|
||||||
|
logging.verbose('IzzyLib.http_server.Data.json: failed to get body')
|
||||||
|
logging.debug(f'{e.__class__.__name__}: {e}')
|
||||||
|
data = '{}'
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
async def MiddlewareAccessLog(request, response):
|
||||||
|
if request.paths.media:
|
||||||
|
return
|
||||||
|
|
||||||
|
uagent = request.headers.get('user-agent')
|
||||||
|
address = request.headers.get('x-real-ip', request.forwarded.get('for', request.remote_addr))
|
||||||
|
|
||||||
|
logging.info(f'({multiprocessing.current_process().name}) {address} {request.method} {request.path} {response.status} "{uagent}"')
|
||||||
|
|
||||||
|
|
||||||
|
def GenericError(request, exception):
|
||||||
|
try:
|
||||||
|
status = exception.status_code
|
||||||
|
except:
|
||||||
|
status = 500
|
||||||
|
|
||||||
|
if status not in range(200, 499):
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
msg = f'{exception.__class__.__name__}: {str(exception)}'
|
||||||
|
|
||||||
|
if request.paths.json:
|
||||||
|
return sanic.response.json({'error': {'status': status, 'message': msg}})
|
||||||
|
|
||||||
|
try:
|
||||||
|
return request.response('server_error.haml', status=status, context={'status': str(status), 'error': msg})
|
||||||
|
|
||||||
|
except TemplateNotFound:
|
||||||
|
return sanic.response.text(f'Error {status}: {msg}')
|
||||||
|
|
||||||
|
|
||||||
|
def NoTemplateError(request, exception):
|
||||||
|
logging.error('TEMPLATE_ERROR:', f'{exception.__class__.__name__}: {str(exception)}')
|
||||||
|
return sanic.response.html('I\'m a dumbass and forgot to create a template for this page', 500)
|
||||||
|
|
||||||
|
|
||||||
|
def ReplaceHeader(headers, key, value):
|
||||||
|
for k,v in headers.items():
|
||||||
|
if k.lower() == header.lower():
|
||||||
|
del headers[k]
|
||||||
|
|
||||||
|
|
||||||
|
class Response:
|
||||||
|
Text = sanic.response.text
|
||||||
|
Html = sanic.response.html
|
||||||
|
Json = sanic.response.json
|
||||||
|
Redir = sanic.response.redirect
|
||||||
|
|
||||||
|
|
||||||
|
def Css(*args, headers={}, **kwargs):
|
||||||
|
ReplaceHeader(headers, 'content-type', 'text/css')
|
||||||
|
return sanic.response.text(*args, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def Js(*args, headers={}, **kwargs):
|
||||||
|
ReplaceHeader(headers, 'content-type', 'application/javascript')
|
||||||
|
return sanic.response.text(*args, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def Ap(*args, headers={}, **kwargs):
|
||||||
|
ReplaceHeader(headers, 'content-type', 'application/activity+json')
|
||||||
|
return sanic.response.json(*args, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def Jrd(*args, headers={}, **kwargs):
|
||||||
|
ReplaceHeader(headers, 'content-type', 'application/jrd+json')
|
||||||
|
return sanic.response.json(*args, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
Resp = Response
|
199
IzzyLib/misc.py
199
IzzyLib/misc.py
|
@ -1,5 +1,5 @@
|
||||||
'''Miscellaneous functions'''
|
'''Miscellaneous functions'''
|
||||||
import random, string, sys, os, json, socket
|
import hashlib, random, string, sys, os, json, socket, time
|
||||||
|
|
||||||
from os import environ as env
|
from os import environ as env
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -8,6 +8,11 @@ from pathlib import Path as Pathlib
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
|
|
||||||
|
try:
|
||||||
|
from passlib.hash import argon2
|
||||||
|
except ImportError:
|
||||||
|
argon2 = None
|
||||||
|
|
||||||
|
|
||||||
def Boolean(v, return_value=False):
|
def Boolean(v, return_value=False):
|
||||||
if type(v) not in [str, bool, int, type(None)]:
|
if type(v) not in [str, bool, int, type(None)]:
|
||||||
|
@ -43,6 +48,20 @@ def RandomGen(length=20, chars=None):
|
||||||
return ''.join(random.choices(characters, k=length))
|
return ''.join(random.choices(characters, k=length))
|
||||||
|
|
||||||
|
|
||||||
|
def HashString(string, alg='blake2s'):
|
||||||
|
if alg not in hashlib.__always_supported:
|
||||||
|
logging.error('Unsupported hash algorithm:', alg)
|
||||||
|
logging.error('Supported algs:', ', '.join(hashlib.__always_supported))
|
||||||
|
return
|
||||||
|
|
||||||
|
string = string.encode('UTF-8') if type(string) != bytes else string
|
||||||
|
salt = salt.encode('UTF-8') if type(salt) != bytes else salt
|
||||||
|
|
||||||
|
newhash = hashlib.new(alg)
|
||||||
|
newhash.update(string)
|
||||||
|
return newhash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def Timestamp(dtobj=None, utc=False):
|
def Timestamp(dtobj=None, utc=False):
|
||||||
dtime = dtobj if dtobj else datetime
|
dtime = dtobj if dtobj else datetime
|
||||||
date = dtime.utcnow() if utc else dtime.now()
|
date = dtime.utcnow() if utc else dtime.now()
|
||||||
|
@ -50,6 +69,11 @@ def Timestamp(dtobj=None, utc=False):
|
||||||
return date.timestamp()
|
return date.timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
def GetVarName(*kwargs, single=True):
|
||||||
|
keys = list(kwargs.keys())
|
||||||
|
return key[0] if single else keys
|
||||||
|
|
||||||
|
|
||||||
def ApDate(date=None, alt=False):
|
def ApDate(date=None, alt=False):
|
||||||
if not date:
|
if not date:
|
||||||
date = datetime.utcnow()
|
date = datetime.utcnow()
|
||||||
|
@ -93,14 +117,13 @@ def Input(prompt, default=None, valtype=str, options=[], password=False):
|
||||||
prompt += f'[{opt}]'
|
prompt += f'[{opt}]'
|
||||||
|
|
||||||
prompt += ': '
|
prompt += ': '
|
||||||
|
|
||||||
value = input_func(prompt)
|
value = input_func(prompt)
|
||||||
|
|
||||||
while value and options and value not in options:
|
while value and len(options) > 0 and value not in options:
|
||||||
input_func('Invalid value:', value)
|
input_func('Invalid value:', value)
|
||||||
value = input(prompt)
|
value = input(prompt)
|
||||||
|
|
||||||
if not value:
|
if not value or value == '':
|
||||||
return default
|
return default
|
||||||
|
|
||||||
ret = valtype(value)
|
ret = valtype(value)
|
||||||
|
@ -112,14 +135,38 @@ def Input(prompt, default=None, valtype=str, options=[], password=False):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def NfsCheck(path):
|
||||||
|
proc = Path('/proc/mounts')
|
||||||
|
path = Path(path).resolve()
|
||||||
|
|
||||||
|
if not proc.exists():
|
||||||
|
return True
|
||||||
|
|
||||||
|
with proc.open() as fd:
|
||||||
|
for line in fd:
|
||||||
|
line = line.split()
|
||||||
|
|
||||||
|
if line[2] == 'nfs' and line[1] in path.str():
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class DotDict(dict):
|
class DotDict(dict):
|
||||||
def __init__(self, value=None, **kwargs):
|
def __init__(self, value=None, **kwargs):
|
||||||
super().__init__()
|
'''Python dictionary, but variables can be set/get via attributes
|
||||||
|
|
||||||
if type(value) in [str, bytes]:
|
value [str, bytes, dict]: JSON or dict of values to init with
|
||||||
|
case_insensitive [bool]: Wether keys should be case sensitive or not
|
||||||
|
kwargs: key/value pairs to set on init. Overrides identical keys set by 'value'
|
||||||
|
'''
|
||||||
|
super().__init__()
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
if isinstance(value, (str, bytes)):
|
||||||
self.fromJson(value)
|
self.fromJson(value)
|
||||||
|
|
||||||
elif type(value) in [dict, DotDict, DefaultDict]:
|
elif isinstance(value, dict):
|
||||||
self.update(value)
|
self.update(value)
|
||||||
|
|
||||||
elif value:
|
elif value:
|
||||||
|
@ -134,11 +181,15 @@ class DotDict(dict):
|
||||||
val = super().__getattribute__(key)
|
val = super().__getattribute__(key)
|
||||||
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
val = self.get(key, InvalidKey())
|
val = self.get(key, KeyError)
|
||||||
|
|
||||||
if type(val) == InvalidKey:
|
try:
|
||||||
|
if val == KeyError:
|
||||||
raise KeyError(f'Invalid key: {key}')
|
raise KeyError(f'Invalid key: {key}')
|
||||||
|
|
||||||
|
except AttributeError:
|
||||||
|
'PyCryptodome.PublicKey.RSA.RsaKey.__eq__ does not seem to play nicely'
|
||||||
|
|
||||||
return DotDict(val) if type(val) == dict else val
|
return DotDict(val) if type(val) == dict else val
|
||||||
|
|
||||||
|
|
||||||
|
@ -149,11 +200,6 @@ class DotDict(dict):
|
||||||
super().__delattr__(key)
|
super().__delattr__(key)
|
||||||
|
|
||||||
|
|
||||||
#def __delitem__(self, key):
|
|
||||||
#print('delitem', key)
|
|
||||||
#self.__delattr__(key)
|
|
||||||
|
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
if key.startswith('_'):
|
if key.startswith('_'):
|
||||||
super().__setattr__(key, value)
|
super().__setattr__(key, value)
|
||||||
|
@ -162,6 +208,10 @@ class DotDict(dict):
|
||||||
super().__setitem__(key, value)
|
super().__setitem__(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.toJson()
|
||||||
|
|
||||||
|
|
||||||
def __parse_item__(self, k, v):
|
def __parse_item__(self, k, v):
|
||||||
if type(v) == dict:
|
if type(v) == dict:
|
||||||
v = DotDict(v)
|
v = DotDict(v)
|
||||||
|
@ -170,8 +220,12 @@ class DotDict(dict):
|
||||||
return (k, v)
|
return (k, v)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, data):
|
||||||
|
super().update(data)
|
||||||
|
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
value = super().get(key, default)
|
value = dict.get(self, key, default)
|
||||||
return DotDict(value) if type(value) == dict else value
|
return DotDict(value) if type(value) == dict else value
|
||||||
|
|
||||||
|
|
||||||
|
@ -200,13 +254,18 @@ class DotDict(dict):
|
||||||
|
|
||||||
|
|
||||||
def toJson(self, indent=None, **kwargs):
|
def toJson(self, indent=None, **kwargs):
|
||||||
|
kwargs.pop('cls', None)
|
||||||
|
return json.dumps(dict(self), indent=indent, cls=DotDictEncoder, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def toJson2(self, indent=None, **kwargs):
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
for k,v in self.items():
|
for k,v in self.items():
|
||||||
if type(k) in [DotDict, Path, Pathlib]:
|
if k and not type(k) in [str, int, float, dict]:
|
||||||
k = str(k)
|
k = str(k)
|
||||||
|
|
||||||
if type(v) in [DotDict, Path, Pathlib]:
|
if v and not type(k) in [str, int, float, dict]:
|
||||||
v = str(v)
|
v = str(v)
|
||||||
|
|
||||||
data[k] = v
|
data[k] = v
|
||||||
|
@ -230,9 +289,41 @@ class DefaultDict(DotDict):
|
||||||
return DotDict(val) if type(val) == dict else val
|
return DotDict(val) if type(val) == dict else val
|
||||||
|
|
||||||
|
|
||||||
|
class LowerDotDict(DotDict):
|
||||||
|
def __getattr__(self, key):
|
||||||
|
key = key.lower()
|
||||||
|
|
||||||
|
try:
|
||||||
|
val = super().__getattribute__(key)
|
||||||
|
|
||||||
|
except AttributeError:
|
||||||
|
val = self.get(key, KeyError)
|
||||||
|
|
||||||
|
if val == KeyError:
|
||||||
|
raise KeyError(f'Invalid key: {key}')
|
||||||
|
|
||||||
|
return DotDict(val) if type(val) == dict else val
|
||||||
|
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
key = key.lower()
|
||||||
|
|
||||||
|
if key.startswith('_'):
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
|
else:
|
||||||
|
super().__setitem__(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, data):
|
||||||
|
data = {k.lower(): v for k,v in self.items()}
|
||||||
|
|
||||||
|
super().update(data)
|
||||||
|
|
||||||
|
|
||||||
class Path(object):
|
class Path(object):
|
||||||
def __init__(self, path, exist=True, missing=True, parents=True):
|
def __init__(self, path, exist=True, missing=True, parents=True):
|
||||||
self.__path = Pathlib(str(path)).resolve()
|
self.__path = Pathlib(str(path))
|
||||||
self.json = DotDict({})
|
self.json = DotDict({})
|
||||||
self.exist = exist
|
self.exist = exist
|
||||||
self.missing = missing
|
self.missing = missing
|
||||||
|
@ -240,20 +331,14 @@ class Path(object):
|
||||||
self.name = self.__path.name
|
self.name = self.__path.name
|
||||||
|
|
||||||
|
|
||||||
#def __getattr__(self, key):
|
|
||||||
#try:
|
|
||||||
#attr = getattr(self.__path, key)
|
|
||||||
|
|
||||||
#except AttributeError:
|
|
||||||
#attr = getattr(self, key)
|
|
||||||
|
|
||||||
#return attr
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.__path)
|
return str(self.__path)
|
||||||
|
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'Path({str(self.__path)})'
|
||||||
|
|
||||||
|
|
||||||
def str(self):
|
def str(self):
|
||||||
return self.__str__()
|
return self.__str__()
|
||||||
|
|
||||||
|
@ -307,7 +392,7 @@ class Path(object):
|
||||||
|
|
||||||
|
|
||||||
def join(self, path, new=True):
|
def join(self, path, new=True):
|
||||||
new_path = self.__path.joinpath(path).resolve()
|
new_path = self.__path.joinpath(path)
|
||||||
|
|
||||||
if new:
|
if new:
|
||||||
return Path(new_path)
|
return Path(new_path)
|
||||||
|
@ -422,25 +507,53 @@ class Path(object):
|
||||||
return self.open().readlines()
|
return self.open().readlines()
|
||||||
|
|
||||||
|
|
||||||
## def rmdir():
|
class DotDictEncoder(json.JSONEncoder):
|
||||||
|
def default(self, obj):
|
||||||
|
if type(obj) not in [str, int, float, dict]:
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
return json.JSONEncoder.default(self, obj)
|
||||||
|
|
||||||
|
|
||||||
def NfsCheck(path):
|
class PasswordHash(object):
|
||||||
proc = Path('/proc/mounts')
|
def __init__(self, salt=None, rounds=8, bsize=50, threads=os.cpu_count(), length=64):
|
||||||
path = Path(path).resolve()
|
if type(salt) == Path:
|
||||||
|
if salt.exists():
|
||||||
|
with salt.open() as fd:
|
||||||
|
self.salt = fd.read()
|
||||||
|
|
||||||
if not proc.exists():
|
else:
|
||||||
return True
|
newsalt = RandomGen(40)
|
||||||
|
|
||||||
with proc.open() as fd:
|
with salt.open('w') as fd:
|
||||||
for line in fd:
|
fd.write(newsalt)
|
||||||
line = line.split()
|
|
||||||
|
|
||||||
if line[2] == 'nfs' and line[1] in path.str():
|
self.salt = newsalt
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
else:
|
||||||
|
self.salt = salt or RandomGen(40)
|
||||||
|
|
||||||
|
self.rounds = rounds
|
||||||
|
self.bsize = bsize * 1024
|
||||||
|
self.threads = threads
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
|
||||||
class InvalidKey(object):
|
def hash(self, password):
|
||||||
pass
|
return argon2.using(
|
||||||
|
salt = self.salt.encode('UTF-8'),
|
||||||
|
rounds = self.rounds,
|
||||||
|
memory_cost = self.bsize,
|
||||||
|
max_threads = self.threads,
|
||||||
|
digest_size = self.length
|
||||||
|
).hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def verify(self, password, passhash):
|
||||||
|
return argon2.using(
|
||||||
|
salt = self.salt.encode('UTF-8'),
|
||||||
|
rounds = self.rounds,
|
||||||
|
memory_cost = self.bsize,
|
||||||
|
max_threads = self.threads,
|
||||||
|
digest_size = self.length
|
||||||
|
).verify(password, passhash)
|
||||||
|
|
|
@ -6,9 +6,6 @@ from os.path import isfile, isdir, getmtime, abspath
|
||||||
|
|
||||||
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape, Markup
|
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape, Markup
|
||||||
from hamlish_jinja import HamlishExtension
|
from hamlish_jinja import HamlishExtension
|
||||||
from markdown import markdown
|
|
||||||
from watchdog.observers import Observer
|
|
||||||
from watchdog.events import FileSystemEventHandler
|
|
||||||
from xml.dom import minidom
|
from xml.dom import minidom
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -23,11 +20,10 @@ from .misc import Path, DotDict
|
||||||
|
|
||||||
|
|
||||||
class Template(Environment):
|
class Template(Environment):
|
||||||
def __init__(self, search=[], global_vars={}, autoescape=True):
|
def __init__(self, search=[], global_vars={}, context=None, autoescape=True):
|
||||||
self.autoescape = autoescape
|
self.autoescape = autoescape
|
||||||
self.watcher = None
|
|
||||||
self.search = []
|
self.search = []
|
||||||
self.func_context = None
|
self.func_context = context
|
||||||
|
|
||||||
for path in search:
|
for path in search:
|
||||||
self.__add_search_path(path)
|
self.__add_search_path(path)
|
||||||
|
@ -45,7 +41,6 @@ class Template(Environment):
|
||||||
self.hamlish_mode = 'indented'
|
self.hamlish_mode = 'indented'
|
||||||
|
|
||||||
self.globals.update({
|
self.globals.update({
|
||||||
'markdown': markdown,
|
|
||||||
'markup': Markup,
|
'markup': Markup,
|
||||||
'cleanhtml': lambda text: ''.join(xml.etree.ElementTree.fromstring(text).itertext()),
|
'cleanhtml': lambda text: ''.join(xml.etree.ElementTree.fromstring(text).itertext()),
|
||||||
'lighten': lighten,
|
'lighten': lighten,
|
||||||
|
@ -123,9 +118,9 @@ class Template(Environment):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def response(self, *args, ctype='text/html', status=200, **kwargs):
|
def response(self, request, tpl, ctype='text/html', status=200, **kwargs):
|
||||||
if not Response:
|
if not Response:
|
||||||
raise ModuleNotFoundError('Sanic is not installed')
|
raise ModuleNotFoundError('Sanic is not installed')
|
||||||
|
|
||||||
html = self.render(*args, **kwargs)
|
html = self.render(tpl, request=request, **kwargs)
|
||||||
return Response.HTTPResponse(body=html, status=status, content_type=ctype, headers=kwargs.get('headers', {}))
|
return Response.HTTPResponse(body=html, status=status, content_type=ctype, headers=kwargs.get('headers', {}))
|
||||||
|
|
|
@ -4,6 +4,7 @@ Hamlish-Jinja==0.3.3
|
||||||
Jinja2>=2.10.1
|
Jinja2>=2.10.1
|
||||||
jinja2-markdown>=0.0.3
|
jinja2-markdown>=0.0.3
|
||||||
Mastodon.py>=1.5.0
|
Mastodon.py>=1.5.0
|
||||||
|
multidict>=5.1.0
|
||||||
pycryptodome>=3.9.1
|
pycryptodome>=3.9.1
|
||||||
python-magic>=0.4.18
|
python-magic>=0.4.18
|
||||||
sanic>=19.12.2
|
sanic>=19.12.2
|
||||||
|
|
Loading…
Reference in a new issue