major update
This commit is contained in:
parent
221beb7670
commit
0e59542626
|
@ -8,4 +8,4 @@ import sys
|
||||||
assert sys.version_info >= (3, 6)
|
assert sys.version_info >= (3, 6)
|
||||||
|
|
||||||
|
|
||||||
__version__ = (0, 3, 1)
|
__version__ = (0, 4, 0)
|
||||||
|
|
|
@ -3,19 +3,14 @@ import re
|
||||||
|
|
||||||
from colour import Color
|
from colour import Color
|
||||||
|
|
||||||
from . import logging
|
|
||||||
|
|
||||||
|
|
||||||
check = lambda color: Color(f'#{str(color)}' if re.search(r'^(?:[0-9a-fA-F]{3}){1,2}$', color) else color)
|
check = lambda color: Color(f'#{str(color)}' if re.search(r'^(?:[0-9a-fA-F]{3}){1,2}$', color) else color)
|
||||||
|
|
||||||
def _multi(multiplier):
|
def _multi(multiplier):
|
||||||
if multiplier > 100:
|
if multiplier >= 1:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
elif multiplier > 1:
|
elif multiplier <= 0:
|
||||||
return multiplier/100
|
|
||||||
|
|
||||||
elif multiplier < 0:
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
return multiplier
|
return multiplier
|
||||||
|
|
536
IzzyLib/database.py
Normal file
536
IzzyLib/database.py
Normal file
|
@ -0,0 +1,536 @@
|
||||||
|
## Probably gonna replace all of this with a custom sqlalchemy setup tbh
|
||||||
|
## It'll look like the db classes and functions in https://git.barkshark.xyz/izaliamae/social
|
||||||
|
import shutil, traceback, importlib, sqlite3, sys, json
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from .cache import LRUCache
|
||||||
|
from .misc import Boolean, DotDict, Path
|
||||||
|
from . import logging, sql
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dbutils.pooled_db import PooledDB
|
||||||
|
except ImportError:
|
||||||
|
from DBUtils.PooledDB import PooledDB
|
||||||
|
|
||||||
|
|
||||||
|
## Only sqlite3 has been tested
|
||||||
|
## Use other modules at your own risk.
|
||||||
|
class DB():
|
||||||
|
def __init__(self, tables, dbmodule='sqlite', cursor=None, **kwargs):
|
||||||
|
cursor = Cursor if not cursor else cursor
|
||||||
|
|
||||||
|
if dbmodule in ['sqlite', 'sqlite3']:
|
||||||
|
self.dbmodule = sqlite3
|
||||||
|
self.dbtype = 'sqlite'
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.dbmodule = None
|
||||||
|
self.__setup_module(dbmodule)
|
||||||
|
|
||||||
|
self.db = None
|
||||||
|
self.cursor = lambda : cursor(self).begin()
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.tables = tables
|
||||||
|
self.cache = DotDict()
|
||||||
|
self.__setup_database()
|
||||||
|
|
||||||
|
for table in tables.keys():
|
||||||
|
self.__setup_cache(table)
|
||||||
|
|
||||||
|
|
||||||
|
def query(self, *args, **kwargs):
|
||||||
|
return self.__cursor_cmd('query', *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch(self, *args, **kwargs):
|
||||||
|
return self.__cursor_cmd('fetch', *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def insert(self, *args, **kwargs):
|
||||||
|
return self.__cursor_cmd('insert', *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, *args, **kwargs):
|
||||||
|
return self.__cursor_cmd('update', *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def remove(self, *args, **kwargs):
|
||||||
|
return self.__cursor_cmd('remove', *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def __cursor_cmd(self, name, *args, **kwargs):
|
||||||
|
with self.cursor() as cur:
|
||||||
|
method = getattr(cur, name)
|
||||||
|
|
||||||
|
if not method:
|
||||||
|
raise KeyError('Not a valid cursor method:', name)
|
||||||
|
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def __setup_module(self, dbtype):
|
||||||
|
modules = []
|
||||||
|
modtypes = {
|
||||||
|
'sqlite': ['sqlite3'],
|
||||||
|
'postgresql': ['psycopg3', 'pgdb', 'psycopg2', 'pg8000'],
|
||||||
|
'mysql': ['mysqldb', 'trio_mysql'],
|
||||||
|
'mssql': ['pymssql', 'adodbapi']
|
||||||
|
}
|
||||||
|
|
||||||
|
for dbmod, mods in modtypes.items():
|
||||||
|
if dbtype == dbmod:
|
||||||
|
self.dbtype = dbmod
|
||||||
|
modules = mods
|
||||||
|
break
|
||||||
|
|
||||||
|
elif dbtype in mods:
|
||||||
|
self.dbtype = dbmod
|
||||||
|
modules = [dbtype]
|
||||||
|
break
|
||||||
|
|
||||||
|
if not modules:
|
||||||
|
logging.verbose('Not a database type. Checking if it is a module...')
|
||||||
|
|
||||||
|
for mod in modules:
|
||||||
|
try:
|
||||||
|
self.dbmodule = importlib.import_module(mod)
|
||||||
|
except ImportError as e:
|
||||||
|
logging.verbose('Module not installed:', mod)
|
||||||
|
|
||||||
|
if self.dbmodule:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not self.dbmodule:
|
||||||
|
if modtypes.get(dbtype):
|
||||||
|
logging.error('Failed to find module for database type:', dbtype)
|
||||||
|
logging.error(f'Please install one of these modules to use a {dbtype} database:', ', '.join(modules))
|
||||||
|
|
||||||
|
else:
|
||||||
|
logging.error('Failed to import module:', dbtype)
|
||||||
|
logging.error('Install one of the following modules:')
|
||||||
|
|
||||||
|
for key, modules in modtypes.items():
|
||||||
|
logging.error(f'{key}:')
|
||||||
|
for mod in modules:
|
||||||
|
logging.error(f'\t{mod}')
|
||||||
|
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
|
||||||
|
def __setup_database(self):
|
||||||
|
if self.dbtype == 'sqlite':
|
||||||
|
if not self.kwargs.get('database'):
|
||||||
|
dbfile = ':memory:'
|
||||||
|
|
||||||
|
dbfile = Path(self.kwargs['database'])
|
||||||
|
dbfile.parent().mkdir()
|
||||||
|
|
||||||
|
if not dbfile.parent().isdir():
|
||||||
|
logging.error('Invalid path to database file:', dbfile.parent().str())
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
self.kwargs['database'] = dbfile.str()
|
||||||
|
self.kwargs['check_same_thread'] = False
|
||||||
|
|
||||||
|
else:
|
||||||
|
if not self.kwargs.get('password'):
|
||||||
|
self.kwargs.pop('password', None)
|
||||||
|
|
||||||
|
self.db = PooledDB(self.dbmodule, **self.kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def connection(self):
|
||||||
|
if self.dbtype == 'sqlite':
|
||||||
|
return self.db.connection()
|
||||||
|
|
||||||
|
|
||||||
|
def __setup_cache(self, table):
|
||||||
|
self.cache[table] = LRUCache(128)
|
||||||
|
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def count(self, table):
|
||||||
|
tabledict = tables.get(table)
|
||||||
|
|
||||||
|
if not tabledict:
|
||||||
|
logging.debug('Table does not exist:', table)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
data = self.query(f'SELECT COUNT(*) FROM {table}')
|
||||||
|
return data[0][0]
|
||||||
|
|
||||||
|
|
||||||
|
#def query(self, string, values=[], cursor=None):
|
||||||
|
#if not string.endswith(';'):
|
||||||
|
#string += ';'
|
||||||
|
|
||||||
|
#if not cursor:
|
||||||
|
#with self.Cursor() as cursor:
|
||||||
|
#cursor.execute(string, values)
|
||||||
|
#return cursor.fetchall()
|
||||||
|
|
||||||
|
#else:
|
||||||
|
#cursor.execute(string,value)
|
||||||
|
#return cursor.fetchall()
|
||||||
|
|
||||||
|
#return False
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def Cursor(self):
|
||||||
|
conn = self.db.connection()
|
||||||
|
conn.begin()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield cursor
|
||||||
|
except self.dbmodule.OperationalError:
|
||||||
|
traceback.print_exc()
|
||||||
|
conn.rollback()
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def CreateTable(self, table):
|
||||||
|
layout = DotDict(self.tables.get(table))
|
||||||
|
|
||||||
|
if not layout:
|
||||||
|
logging.error('Table config doesn\'t exist:', table)
|
||||||
|
return
|
||||||
|
|
||||||
|
cmd = f'CREATE TABLE IF NOT EXISTS {table}('
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for k, v in layout.items():
|
||||||
|
options = ' '.join(v.get('options', []))
|
||||||
|
default = v.get('default')
|
||||||
|
item = f'{k} {v["type"].upper()} {options}'
|
||||||
|
|
||||||
|
if default:
|
||||||
|
item += f'DEFAULT {default}'
|
||||||
|
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
cmd += ', '.join(items) + ')'
|
||||||
|
|
||||||
|
return True if self.query(cmd) != False else False
|
||||||
|
|
||||||
|
|
||||||
|
def RenameTable(self, table, newname):
|
||||||
|
self.query(f'ALTER TABLE {table} RENAME TO {newname}')
|
||||||
|
|
||||||
|
|
||||||
|
def DropTable(self, table):
|
||||||
|
self.query(f'DROP TABLE {table}')
|
||||||
|
|
||||||
|
|
||||||
|
def AddColumn(self, table, name, datatype, default=None, options=None):
|
||||||
|
query = f'ALTER TABLE {table} ADD COLUMN {name} {datatype.upper()}'
|
||||||
|
|
||||||
|
if default:
|
||||||
|
query += f' DEFAULT {default}'
|
||||||
|
|
||||||
|
if options:
|
||||||
|
query += f' {options}'
|
||||||
|
|
||||||
|
self.query(query)
|
||||||
|
|
||||||
|
|
||||||
|
def CheckDatabase(self, database):
|
||||||
|
if self.dbtype == 'postgresql':
|
||||||
|
tables = self.query('SELECT datname FROM pg_database')
|
||||||
|
|
||||||
|
else:
|
||||||
|
tables = []
|
||||||
|
|
||||||
|
tables = [table[0] for table in tables]
|
||||||
|
print(database in tables, database, tables)
|
||||||
|
|
||||||
|
return database in tables
|
||||||
|
|
||||||
|
|
||||||
|
def CreateTables(self):
|
||||||
|
dbname = self.kwargs['database']
|
||||||
|
|
||||||
|
for name, table in self.tables.items():
|
||||||
|
if len(self.query(sql.CheckTable(self.dbtype, name))) < 1:
|
||||||
|
logging.info('Creating table:', name)
|
||||||
|
self.query(table.sql())
|
||||||
|
|
||||||
|
|
||||||
|
class Table(dict):
|
||||||
|
def __init__(self, name, dbtype='sqlite'):
|
||||||
|
super().__init__({})
|
||||||
|
|
||||||
|
self.sqlstr = 'CREATE TABLE IF NOT EXISTS {} ({});'
|
||||||
|
self.dbtype = dbtype
|
||||||
|
self.name = name
|
||||||
|
self.columns = {}
|
||||||
|
self.fkeys = {}
|
||||||
|
|
||||||
|
|
||||||
|
def addColumn(self, name, datatype=None, null=True, unique=None, primary=None, fkey=None):
|
||||||
|
if name == 'id':
|
||||||
|
if self.dbtype == 'sqlite':
|
||||||
|
datatype = 'integer'
|
||||||
|
primary = True if primary == None else primary
|
||||||
|
unique = True
|
||||||
|
null = False
|
||||||
|
|
||||||
|
else:
|
||||||
|
datatype = 'serial'
|
||||||
|
primary = 'true'
|
||||||
|
|
||||||
|
if name == 'timestamp':
|
||||||
|
datatype = 'float'
|
||||||
|
|
||||||
|
elif not datatype:
|
||||||
|
raise MissingTypeError(f'Missing a data type for column: {name}')
|
||||||
|
|
||||||
|
colsql = f'{name} {datatype.upper()}'
|
||||||
|
|
||||||
|
if unique:
|
||||||
|
colsql += ' UNIQUE'
|
||||||
|
|
||||||
|
if not null:
|
||||||
|
colsql += ' NOT NULL'
|
||||||
|
|
||||||
|
if primary:
|
||||||
|
self.primary = name
|
||||||
|
colsql += ' PRIMARY KEY'
|
||||||
|
|
||||||
|
if name == 'id' and self.dbtype == 'sqlite':
|
||||||
|
colsql += ' AUTOINCREMENT'
|
||||||
|
|
||||||
|
if fkey:
|
||||||
|
if self.dbtype == 'postgresql':
|
||||||
|
colsql += f' REFERENCES {fkey[0]}({fkey[1]})'
|
||||||
|
|
||||||
|
elif self.dbtype == 'sqlite':
|
||||||
|
self.fkeys[name] += f'FOREIGN KEY ({name}) REFERENCES {fkey[0]} ({fkey[1]})'
|
||||||
|
|
||||||
|
self.columns.update({name: colsql})
|
||||||
|
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if not self.primary:
|
||||||
|
logging.error('Please specify a primary column')
|
||||||
|
return
|
||||||
|
|
||||||
|
data = ', '.join(list(self.columns.values()))
|
||||||
|
|
||||||
|
if self.fkeys:
|
||||||
|
data += ', '
|
||||||
|
data += ', '.join(list(self.fkeys.values()))
|
||||||
|
|
||||||
|
sqldata = self.sqlstr.format(self.name, data)
|
||||||
|
print(sqldata)
|
||||||
|
return sqldata
|
||||||
|
|
||||||
|
|
||||||
|
class Cursor(object):
|
||||||
|
def __init__(self, db):
|
||||||
|
self.main = db
|
||||||
|
self.db = db.db
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def begin(self):
|
||||||
|
self.conn = self.db.connection()
|
||||||
|
self.conn.begin()
|
||||||
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self
|
||||||
|
|
||||||
|
except self.main.dbmodule.OperationalError:
|
||||||
|
self.conn.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.cursor.close()
|
||||||
|
self.conn.commit()
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def query(self, string, values=[], cursor=None):
|
||||||
|
#if not string.endswith(';'):
|
||||||
|
#string += ';'
|
||||||
|
|
||||||
|
self.cursor.execute(string, values)
|
||||||
|
data = self.cursor.fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
def fetch(self, table, single=True, sort=None, **kwargs):
|
||||||
|
rowid = kwargs.get('id')
|
||||||
|
querysort = f'ORDER BY {sort}'
|
||||||
|
|
||||||
|
resultOpts = [self, table, self.cursor]
|
||||||
|
|
||||||
|
if rowid:
|
||||||
|
cursor.execute(f"SELECT * FROM {table} WHERE id = ?", [rowid])
|
||||||
|
|
||||||
|
elif kwargs:
|
||||||
|
placeholders = [f'{k} = ?' for k in kwargs.keys()]
|
||||||
|
values = kwargs.values()
|
||||||
|
|
||||||
|
where = ' and '.join(placeholders)
|
||||||
|
query = f"SELECT * FROM {table} WHERE {where} {querysort if sort else ''}"
|
||||||
|
self.cursor.execute(query, list(values))
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.cursor.execute(f'SELECT * FROM {table} {querysort if sort else ""}')
|
||||||
|
|
||||||
|
rows = self.cursor.fetchall() if not single else self.cursor.fetchone()
|
||||||
|
|
||||||
|
if rows:
|
||||||
|
if single:
|
||||||
|
return DBResult(rows, *resultOpts)
|
||||||
|
|
||||||
|
return [DBResult(row, *resultOpts) for row in rows]
|
||||||
|
|
||||||
|
return None if single else []
|
||||||
|
|
||||||
|
|
||||||
|
def insert(self, table, data={}, **kwargs):
|
||||||
|
data.update(kwargs)
|
||||||
|
placeholders = ",".join(['?' for _ in data.keys()])
|
||||||
|
values = tuple(data.values())
|
||||||
|
keys = ','.join(data.keys())
|
||||||
|
|
||||||
|
if 'timestamp' in self.main.tables[table].keys() and 'timestamp' not in keys:
|
||||||
|
data['timestamp'] = datetime.now()
|
||||||
|
|
||||||
|
query = f'INSERT INTO {table} ({keys}) VALUES ({placeholders})'
|
||||||
|
self.query(query, values)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def remove(self, table, **kwargs):
|
||||||
|
keys = []
|
||||||
|
values = []
|
||||||
|
|
||||||
|
for k,v in kwargs.items():
|
||||||
|
keys.append(k)
|
||||||
|
values.append(v)
|
||||||
|
|
||||||
|
keydata = ','.join([f'{k} = ?' for k in keys])
|
||||||
|
query = f'DELETE FROM {table} WHERE {keydata}'
|
||||||
|
|
||||||
|
self.query(query, values)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, table, rowid, data={}, **kwargs):
|
||||||
|
data.update(kwargs)
|
||||||
|
newdata = {k: v for k, v in data.items() if k in self.main.tables[table].keys()}
|
||||||
|
keys = list(newdata.keys())
|
||||||
|
values = list(newdata.values())
|
||||||
|
|
||||||
|
if len(newdata) < 1:
|
||||||
|
logging.debug('No data provided to update row')
|
||||||
|
return False
|
||||||
|
|
||||||
|
query_data = ', '.join(f'{k} = ?' for k in keys)
|
||||||
|
query = f'UPDATE {table} SET {query_data} WHERE id = {rowid}'
|
||||||
|
|
||||||
|
self.query(query, values)
|
||||||
|
|
||||||
|
|
||||||
|
class DBResult(DotDict):
|
||||||
|
def __init__(self, row, db, table, cursor):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.db = db
|
||||||
|
self.table = table
|
||||||
|
|
||||||
|
for idx, col in enumerate(cursor.description):
|
||||||
|
self[col[0]] = row[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
if name not in ['db', 'table']:
|
||||||
|
return self.__setitem__(name, value)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return super().__setattr__(name, value)
|
||||||
|
|
||||||
|
|
||||||
|
def __delattr__(self, name):
|
||||||
|
if name not in ['db', 'table']:
|
||||||
|
return self.__delitem__(name)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return super().__delattr__(name)
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(self, value, default=None):
|
||||||
|
options = [value]
|
||||||
|
|
||||||
|
if default:
|
||||||
|
options.append(default)
|
||||||
|
|
||||||
|
if value in self.keys():
|
||||||
|
val = super().__getitem__(*options)
|
||||||
|
return DotDict(val) if isinstance(val, dict) else val
|
||||||
|
|
||||||
|
else:
|
||||||
|
return dict.__getattr__(*options)
|
||||||
|
|
||||||
|
|
||||||
|
# Kept for backwards compatibility. Delete later.
|
||||||
|
def asdict(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def Update(self, data={}):
|
||||||
|
with self.db.Cursor().begin() as cursor:
|
||||||
|
self.update(data)
|
||||||
|
cursor.update(self.table, self.id, self.AsDict())
|
||||||
|
|
||||||
|
|
||||||
|
def Remove(self):
|
||||||
|
with self.db.Cursor().begin() as cursor:
|
||||||
|
cursor.remove(self.table, id=self.id)
|
||||||
|
|
||||||
|
|
||||||
|
def ParseData(table, row):
|
||||||
|
tbdata = tables.get(table)
|
||||||
|
types = []
|
||||||
|
result = []
|
||||||
|
|
||||||
|
if not tbdata:
|
||||||
|
logging.error('Invalid table:', table)
|
||||||
|
return
|
||||||
|
|
||||||
|
for v in tbdata.values():
|
||||||
|
dtype = v.split()[0].upper()
|
||||||
|
|
||||||
|
if dtype == 'BOOLEAN':
|
||||||
|
types.append(boolean)
|
||||||
|
|
||||||
|
elif dtype in ['INTEGER', 'INT']:
|
||||||
|
types.append(int)
|
||||||
|
|
||||||
|
else:
|
||||||
|
types.append(str)
|
||||||
|
|
||||||
|
for idx, v in enumerate(row):
|
||||||
|
result.append(types[idx](v) if v else None)
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
db.insert('config', {'key': 'version', 'value': dbversion})
|
||||||
|
|
||||||
|
|
||||||
|
class MissingTypeError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TooManyConnectionsError(Exception):
|
||||||
|
pass
|
23
IzzyLib/error.py
Normal file
23
IzzyLib/error.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
class MissingHeadersError(Exception):
|
||||||
|
def __init__(self, headers: list):
|
||||||
|
self.headers = ', '.join(headers)
|
||||||
|
self.message = f'Missing required headers for verificaton: {self.headers}'
|
||||||
|
super().init(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationError(Exception):
|
||||||
|
def __init__(self, string=None):
|
||||||
|
self.message = f'Failed to verify hash'
|
||||||
|
|
||||||
|
if string:
|
||||||
|
self.message += ' for ' + string
|
||||||
|
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.message
|
445
IzzyLib/http.py
445
IzzyLib/http.py
|
@ -1,207 +1,342 @@
|
||||||
import traceback, urllib3, json
|
import functools, json, sys
|
||||||
|
|
||||||
|
from IzzyLib import logging
|
||||||
|
from IzzyLib.misc import DefaultDict, DotDict
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode, b64encode
|
||||||
from urllib.parse import urlparse
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
import httpsig
|
from . import error
|
||||||
|
|
||||||
from Crypto.PublicKey import RSA
|
try:
|
||||||
from Crypto.Hash import SHA, SHA256, SHA384, SHA512
|
from Crypto.Hash import SHA256
|
||||||
from Crypto.Signature import PKCS1_v1_5
|
from Crypto.PublicKey import RSA
|
||||||
|
from Crypto.Signature import PKCS1_v1_5
|
||||||
|
crypto_enabled = True
|
||||||
|
except ImportError:
|
||||||
|
logging.verbose('Pycryptodome module not found. HTTP header signing and verifying is disabled')
|
||||||
|
crypto_enabled = False
|
||||||
|
|
||||||
from . import logging, __version__
|
try:
|
||||||
from .cache import TTLCache, LRUCache
|
from sanic.request import Request as SanicRequest
|
||||||
from .misc import formatUTC
|
sanic_enabled = True
|
||||||
|
except ImportError:
|
||||||
|
logging.verbose('Sanic module not found. Request verification is disabled')
|
||||||
|
sanic_enabled = False
|
||||||
|
|
||||||
|
|
||||||
version = '.'.join([str(num) for num in __version__])
|
Client = None
|
||||||
|
|
||||||
|
|
||||||
class httpClient:
|
def VerifyRequest(request: SanicRequest, actor: dict=None):
|
||||||
def __init__(self, pool=100, timeout=30, headers={}, agent=None):
|
'''Verify a header signature from a sanic request
|
||||||
self.cache = LRUCache()
|
|
||||||
self.pool = pool
|
|
||||||
self.timeout = timeout
|
|
||||||
self.agent = agent if agent else f'IzzyLib/{version}'
|
|
||||||
self.headers = headers
|
|
||||||
|
|
||||||
self.client = urllib3.PoolManager(num_pools=self.pool, timeout=self.timeout)
|
request: The request with the headers to verify
|
||||||
self.headers['User-Agent'] = self.agent
|
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
||||||
|
'''
|
||||||
|
if not sanic_enabled:
|
||||||
|
logging.error('Sanic request verification disabled')
|
||||||
|
return
|
||||||
|
|
||||||
|
if not actor:
|
||||||
|
actor = request.ctx.actor
|
||||||
|
|
||||||
|
body = request.body if request.body else None
|
||||||
|
return VerifyHeaders(request.headers, request.method, request.path, body, actor, False)
|
||||||
|
|
||||||
|
|
||||||
def _fetch(self, url, headers={}, method='GET', data=None, cached=True):
|
def VerifyHeaders(headers: dict, method: str, path: str, actor: dict=None, body=None, fail: bool=False):
|
||||||
cached_data = self.cache.fetch(url)
|
'''Verify a header signature
|
||||||
#url = url.split('#')[0]
|
|
||||||
|
|
||||||
if cached and cached_data:
|
headers: A dictionary containing all the headers from a request
|
||||||
logging.debug(f'Returning cached data for {url}')
|
method: The HTTP method of the request
|
||||||
return cached_data
|
path: The path of the HTTP request
|
||||||
|
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
||||||
|
body (optional): The body of the request. Only needed if the signature includes the digest header
|
||||||
|
fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
|
||||||
|
'''
|
||||||
|
if not crypto_enabled:
|
||||||
|
logging.error('Crypto functions disabled')
|
||||||
|
return
|
||||||
|
|
||||||
if not headers.get('User-Agent'):
|
headers = {k.lower(): v for k,v in headers.items()}
|
||||||
headers.update({'User-Agent': self.agent})
|
headers['(request-target)'] = f'{method.lower()} {path}'
|
||||||
|
signature = ParseSig(headers.get('signature'))
|
||||||
|
digest = headers.get('digest')
|
||||||
|
missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None]
|
||||||
|
|
||||||
logging.debug(f'Fetching new data for {url}')
|
if not signature:
|
||||||
|
if fail:
|
||||||
|
raise MissingSignatureError()
|
||||||
|
|
||||||
try:
|
return False
|
||||||
if data:
|
|
||||||
if isinstance(data, dict):
|
|
||||||
data = json.dumps(data)
|
|
||||||
|
|
||||||
resp = self.client.request(method, url, headers=headers, body=data)
|
if not actor:
|
||||||
|
actor = FetchActor(signature.keyid)
|
||||||
|
|
||||||
else:
|
## Add digest header to missing headers list if it doesn't exist
|
||||||
resp = self.client.request(method, url, headers=headers)
|
if method.lower() == 'post' and not headers.get('digest'):
|
||||||
|
missing_headers.append('digest')
|
||||||
|
|
||||||
except Exception as e:
|
## Fail if missing date, host or digest (if POST) headers
|
||||||
logging.debug(f'Failed to fetch url: {e}')
|
if missing_headers:
|
||||||
return
|
if fail:
|
||||||
|
raise error.MissingHeadersError(missing_headers)
|
||||||
|
|
||||||
if cached:
|
return False
|
||||||
logging.debug(f'Caching {url}')
|
|
||||||
self.cache.store(url, resp)
|
|
||||||
|
|
||||||
return resp
|
## Fail if body verification fails
|
||||||
|
if digest and not VerifyString(body, digest):
|
||||||
|
if fail:
|
||||||
|
raise error.VerificationError('digest header')
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
pubkey = actor.publicKey['publicKeyPem']
|
||||||
|
|
||||||
|
if PkcsHeaders(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if fail:
|
||||||
|
raise error.VerificationError('headers')
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def raw(self, *args, **kwargs):
|
@functools.lru_cache(maxsize=512)
|
||||||
'''
|
def VerifyString(string, enc_string, alg='SHA256', fail=False):
|
||||||
Return a response object
|
if not crypto_enabled:
|
||||||
'''
|
logging.error('Crypto functions disabled')
|
||||||
return self._fetch(*args, **kwargs)
|
return
|
||||||
|
|
||||||
|
if type(string) != bytes:
|
||||||
|
string = string.encode('UTF-8')
|
||||||
|
|
||||||
|
body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
|
||||||
|
|
||||||
|
if body_hash == enc_string:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if fail:
|
||||||
|
raise error.VerificationError()
|
||||||
|
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def text(self, *args, **kwargs):
|
def PkcsHeaders(key: str, headers: dict, sig=None):
|
||||||
'''
|
if not crypto_enabled:
|
||||||
Return the body as text
|
logging.error('Crypto functions disabled')
|
||||||
'''
|
return
|
||||||
resp = self._fetch(*args, **kwargs)
|
|
||||||
|
|
||||||
return resp.data.decode() if resp else None
|
if sig:
|
||||||
|
head_items = [f'{item}: {headers[item]}' for item in sig.headers]
|
||||||
|
|
||||||
|
else:
|
||||||
|
head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
|
||||||
|
|
||||||
|
head_string = '\n'.join(head_items)
|
||||||
|
head_bytes = head_string.encode('UTF-8')
|
||||||
|
|
||||||
|
KEY = RSA.importKey(key)
|
||||||
|
pkcs = PKCS1_v1_5.new(KEY)
|
||||||
|
h = SHA256.new(head_bytes)
|
||||||
|
|
||||||
|
if sig:
|
||||||
|
return pkcs.verify(h, b64decode(sig.signature))
|
||||||
|
|
||||||
|
else:
|
||||||
|
return pkcs.sign(h)
|
||||||
|
|
||||||
|
|
||||||
def json(self, *args, **kwargs):
|
def ParseSig(signature: str):
|
||||||
'''
|
if not signature:
|
||||||
Return the body as a dict if it's json
|
|
||||||
'''
|
|
||||||
|
|
||||||
headers = kwargs.get('headers')
|
|
||||||
|
|
||||||
if not headers:
|
|
||||||
kwargs['headers'] = {}
|
|
||||||
|
|
||||||
kwargs['headers'].update({'Accept': 'application/json'})
|
|
||||||
resp = self._fetch(*args, **kwargs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = json.loads(resp.data.decode())
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.debug(f'Failed to load json: {e}')
|
|
||||||
return
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def ParseSig(headers):
|
|
||||||
sig_header = headers.get('signature')
|
|
||||||
|
|
||||||
if not sig_header:
|
|
||||||
logging.verbose('Missing signature header')
|
logging.verbose('Missing signature header')
|
||||||
return
|
return
|
||||||
|
|
||||||
split_sig = sig_header.split(',')
|
split_sig = signature.split(',')
|
||||||
signature = {}
|
sig = DefaultDict({})
|
||||||
|
|
||||||
for part in split_sig:
|
for part in split_sig:
|
||||||
key, value = part.split('=', 1)
|
key, value = part.split('=', 1)
|
||||||
signature[key.lower()] = value.replace('"', '')
|
sig[key.lower()] = value.replace('"', '')
|
||||||
|
|
||||||
if not signature.get('headers'):
|
if not sig.headers:
|
||||||
logging.verbose('Missing headers section in signature')
|
logging.verbose('Missing headers section in signature')
|
||||||
return
|
return
|
||||||
|
|
||||||
signature['headers'] = signature['headers'].split()
|
sig.headers = sig.headers.split()
|
||||||
|
|
||||||
return signature
|
return sig
|
||||||
|
|
||||||
|
|
||||||
def SignHeaders(headers, keyid, privkey, url, method='GET'):
|
@functools.lru_cache(maxsize=512)
|
||||||
'''
|
def FetchActor(keyid, client=None):
|
||||||
Signs headers and returns them with a signature header
|
if not client:
|
||||||
|
client = Client if Client else HttpClient()
|
||||||
|
|
||||||
headers (dict): Headers to be signed
|
actor = Client.request(keyid).json()
|
||||||
keyid (str): Url to the public key used to verify the signature
|
actor.domain = urlparse(actor.id).netloc
|
||||||
privkey (str): Private key used to sign the headers
|
actor.shared_inbox = actor.inbox
|
||||||
url (str): Url of the request for the signed headers
|
actor.pubkey = None
|
||||||
method (str): Http method of the request for the signed headers
|
|
||||||
'''
|
|
||||||
|
|
||||||
RSAkey = RSA.import_key(privkey)
|
if actor.get('endpoints'):
|
||||||
key_size = int(RSAkey.size_in_bytes()/2)
|
actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
|
||||||
logging.debug('Signing key size:', key_size)
|
|
||||||
|
|
||||||
parsed_url = urlparse(url)
|
if actor.get('publicKey'):
|
||||||
logging.debug(parsed_url)
|
actor.pubkey = actor.publicKey.get('publicKeyPem')
|
||||||
|
|
||||||
raw_headers = {'date': formatUTC(), 'host': parsed_url.netloc, '(request-target)': ' '.join([method, parsed_url.path])}
|
return actor
|
||||||
raw_headers.update(dict(headers))
|
|
||||||
header_keys = raw_headers.keys()
|
|
||||||
|
|
||||||
signer = httpsig.HeaderSigner(keyid, privkey, f'rsa-sha{key_size}', headers=header_keys, sign_header='signature')
|
|
||||||
new_headers = signer.sign(raw_headers, parsed_url.netloc, method, parsed_url.path)
|
|
||||||
logging.debug('Signed headers:', new_headers)
|
|
||||||
|
|
||||||
del new_headers['(request-target)']
|
|
||||||
|
|
||||||
return new_headers
|
|
||||||
|
|
||||||
|
|
||||||
def ValidateSignature(headers, method, path, client=None, agent=None):
|
class HttpClient(object):
|
||||||
'''
|
def __init__(self, headers={}, useragent='IzzyLib/0.3', proxy_type='https', proxy_host=None, proxy_port=None):
|
||||||
Validates the signature header.
|
proxy_ports = {
|
||||||
|
'http': 80,
|
||||||
|
'https': 443
|
||||||
|
}
|
||||||
|
|
||||||
headers (dict): All of the headers to be used to check a signature. The signature header must be included too
|
if proxy_type not in ['http', 'https']:
|
||||||
method (str): The http method used in relation to the headers
|
raise ValueError(f'Not a valid proxy type: {proxy_type}')
|
||||||
path (str): The path of the request in relation to the headers
|
|
||||||
client (pool object): Specify a httpClient to use for fetching the actor. optional
|
|
||||||
agent (str): User agent used for fetching actor data. optional
|
|
||||||
'''
|
|
||||||
|
|
||||||
client = httpClient(agent=agent) if not client else client
|
self.headers=headers
|
||||||
headers = {k.lower(): v for k,v in headers.items()}
|
self.agent=useragent
|
||||||
|
self.proxy = DotDict({
|
||||||
signature = ParseSig(headers)
|
'enabled': True if proxy_host else False,
|
||||||
|
'ptype': proxy_type,
|
||||||
actor_data = client.json(signature['keyid'])
|
'host': proxy_host,
|
||||||
logging.debug(actor_data)
|
'port': proxy_ports[proxy_type] if not proxy_port else proxy_port
|
||||||
|
})
|
||||||
try:
|
|
||||||
pubkey = actor_data['publicKey']['publicKeyPem']
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.verbose(f'Failed to get public key for actor {signature["keyid"]}')
|
|
||||||
return
|
|
||||||
|
|
||||||
valid = httpsig.HeaderVerifier(headers, pubkey, signature['headers'], method, path, sign_header='signature').verify()
|
|
||||||
|
|
||||||
if not valid:
|
|
||||||
if not isinstance(valid, tuple):
|
|
||||||
logging.verbose('Signature validation failed for unknown actor')
|
|
||||||
logging.verbose(valid)
|
|
||||||
|
|
||||||
else:
|
|
||||||
logging.verbose(f'Signature validation failed for actor: {valid[1]}')
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def ValidateRequest(request, client=None, agent=None):
|
def __sign_request(self, request, privkey, keyid):
|
||||||
'''
|
if not crypto_enabled:
|
||||||
Validates the headers in a Sanic or Aiohttp request (other frameworks may be supported)
|
logging.error('Crypto functions disabled')
|
||||||
See ValidateSignature for 'client' and 'agent' usage
|
return
|
||||||
'''
|
|
||||||
return ValidateSignature(request.headers, request.method, request.path, client, agent)
|
request.add_header('(request-target)', f'{request.method.lower()} {request.path}')
|
||||||
|
request.add_header('host', request.host)
|
||||||
|
request.add_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
|
||||||
|
|
||||||
|
if request.body:
|
||||||
|
body_hash = b64encode(SHA256.new(request.body).digest()).decode("UTF-8")
|
||||||
|
request.add_header('digest', f'SHA-256={body_hash}')
|
||||||
|
request.add_header('content-length', len(request.body))
|
||||||
|
|
||||||
|
sig = {
|
||||||
|
'keyId': keyid,
|
||||||
|
'algorithm': 'rsa-sha256',
|
||||||
|
'headers': ' '.join([k.lower() for k in request.headers.keys()]),
|
||||||
|
'signature': b64encode(PkcsHeaders(privkey, request.headers)).decode('UTF-8')
|
||||||
|
}
|
||||||
|
|
||||||
|
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
|
||||||
|
sig_string = ','.join(sig_items)
|
||||||
|
|
||||||
|
request.add_header('signature', sig_string)
|
||||||
|
|
||||||
|
request.remove_header('(request-target)')
|
||||||
|
request.remove_header('host')
|
||||||
|
|
||||||
|
|
||||||
|
def __build_request(self, url, data=None, headers={}, method='GET'):
|
||||||
|
new_headers = self.headers.copy()
|
||||||
|
new_headers.update(headers)
|
||||||
|
|
||||||
|
parsed_headers = {k.lower(): v.lower() for k,v in new_headers.items()}
|
||||||
|
|
||||||
|
if not parsed_headers.get('user-agent'):
|
||||||
|
parsed_headers['user-agent'] = self.agent
|
||||||
|
|
||||||
|
if isinstance(data, dict):
|
||||||
|
data = json.dumps(data)
|
||||||
|
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = data.encode('UTF-8')
|
||||||
|
|
||||||
|
request = HttpRequest(url, data=data, headers=parsed_headers, method=method)
|
||||||
|
|
||||||
|
if self.proxy.enabled:
|
||||||
|
request.set_proxy(f'{self.proxy.host}:{self.proxy.host}', self.proxy.ptype)
|
||||||
|
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
def request(self, *args, **kwargs):
|
||||||
|
request = self.__build_request(*args, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = urlopen(request)
|
||||||
|
except HTTPError as e:
|
||||||
|
response = e.fp
|
||||||
|
|
||||||
|
return HttpResponse(response)
|
||||||
|
|
||||||
|
|
||||||
|
def json(self, *args, headers={}, activity=True, **kwargs):
|
||||||
|
json_type = 'activity+json' if activity else 'json'
|
||||||
|
headers.update({
|
||||||
|
'accept': f'application/{json_type}'
|
||||||
|
})
|
||||||
|
return self.request(*args, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def signed_request(self, privkey, keyid, *args, **kwargs):
|
||||||
|
request = self.__build_request(*args, **kwargs)
|
||||||
|
|
||||||
|
self.__sign_request(request, privkey, keyid)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = urlopen(request)
|
||||||
|
except HTTPError as e:
|
||||||
|
response = e
|
||||||
|
|
||||||
|
return HttpResponse(response)
|
||||||
|
|
||||||
|
|
||||||
|
class HttpRequest(Request):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
parsed = urlparse(self.full_url)
|
||||||
|
|
||||||
|
self.scheme = parsed.scheme
|
||||||
|
self.host = parsed.netloc
|
||||||
|
self.domain = parsed.hostname
|
||||||
|
self.port = parsed.port
|
||||||
|
self.path = parsed.path
|
||||||
|
self.query = parsed.query
|
||||||
|
self.body = self.data if self.data else b''
|
||||||
|
|
||||||
|
|
||||||
|
class HttpResponse(object):
|
||||||
|
def __init__(self, response):
|
||||||
|
self.body = response.read()
|
||||||
|
self.headers = DefaultDict({k.lower(): v.lower() for k,v in response.headers.items()})
|
||||||
|
self.status = response.status
|
||||||
|
self.url = response.url
|
||||||
|
|
||||||
|
|
||||||
|
def text(self):
|
||||||
|
return self.body.decode('UTF-8')
|
||||||
|
|
||||||
|
|
||||||
|
def json(self, fail=False):
|
||||||
|
try:
|
||||||
|
return DotDict(self.text())
|
||||||
|
except Exception as e:
|
||||||
|
if fail:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
else:
|
||||||
|
return DotDict()
|
||||||
|
|
||||||
|
|
||||||
|
def json_pretty(self, indent=4):
|
||||||
|
return json.dumps(self.json().asDict(), indent=indent)
|
||||||
|
|
||||||
|
|
||||||
|
def SetClient(client: HttpClient):
|
||||||
|
global Client
|
||||||
|
Client = client
|
||||||
|
|
|
@ -25,7 +25,17 @@ class Log():
|
||||||
'MERP': 0
|
'MERP': 0
|
||||||
}
|
}
|
||||||
|
|
||||||
self.config = dict()
|
self.long_levels = {
|
||||||
|
'CRITICAL': 'CRIT',
|
||||||
|
'ERROR': 'ERROR',
|
||||||
|
'WARNING': 'WARN',
|
||||||
|
'INFO': 'INFO',
|
||||||
|
'VERBOSE': 'VERB',
|
||||||
|
'DEBUG': 'DEBUG',
|
||||||
|
'MERP': 'MERP'
|
||||||
|
}
|
||||||
|
|
||||||
|
self.config = {'windows': sys.executable.endswith('pythonw.exe')}
|
||||||
self.setConfig(self._parseConfig(config))
|
self.setConfig(self._parseConfig(config))
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,10 +45,11 @@ class Log():
|
||||||
value = int(level)
|
value = int(level)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
value = self.levels.get(level.upper())
|
level = self.long_levels.get(level.upper(), level)
|
||||||
|
value = self.levels.get(level)
|
||||||
|
|
||||||
if value not in self.levels.values():
|
if value not in self.levels.values():
|
||||||
raise error.InvalidLevel(f'Invalid logging level: {level}')
|
raise InvalidLevel(f'Invalid logging level: {level}')
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -48,13 +59,14 @@ class Log():
|
||||||
if level == num:
|
if level == num:
|
||||||
return name
|
return name
|
||||||
|
|
||||||
raise error.InvalidLevel(f'Invalid logging level: {level}')
|
raise InvalidLevel(f'Invalid logging level: {level}')
|
||||||
|
|
||||||
|
|
||||||
def _parseConfig(self, config):
|
def _parseConfig(self, config):
|
||||||
'''parse the new config and update the old values'''
|
'''parse the new config and update the old values'''
|
||||||
date = config.get('date', self.config.get('date',True))
|
date = config.get('date', self.config.get('date',True))
|
||||||
systemd = config.get('systemd', self.config.get('systemd,', True))
|
systemd = config.get('systemd', self.config.get('systemd,', True))
|
||||||
|
windows = config.get('windows', self.config.get('windows', False))
|
||||||
|
|
||||||
if not isinstance(date, bool):
|
if not isinstance(date, bool):
|
||||||
raise TypeError(f'value for "date" is not a boolean: {date}')
|
raise TypeError(f'value for "date" is not a boolean: {date}')
|
||||||
|
@ -64,14 +76,18 @@ class Log():
|
||||||
|
|
||||||
level_num = self._lvlCheck(config.get('level', self.config.get('level', 'INFO')))
|
level_num = self._lvlCheck(config.get('level', self.config.get('level', 'INFO')))
|
||||||
|
|
||||||
return {
|
newconfig = {
|
||||||
'level': self._getLevelName(level_num),
|
'level': self._getLevelName(level_num),
|
||||||
'levelnum': level_num,
|
'levelnum': level_num,
|
||||||
'datefmt': config.get('datefmt', self.config.get('datefmt', '%Y-%m-%d %H:%M:%S')),
|
'datefmt': config.get('datefmt', self.config.get('datefmt', '%Y-%m-%d %H:%M:%S')),
|
||||||
'date': date,
|
'date': date,
|
||||||
'systemd': systemd
|
'systemd': systemd,
|
||||||
|
'windows': windows,
|
||||||
|
'systemnotif': config.get('systemnotif', None)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return newconfig
|
||||||
|
|
||||||
|
|
||||||
def setConfig(self, config):
|
def setConfig(self, config):
|
||||||
'''set the config'''
|
'''set the config'''
|
||||||
|
@ -81,8 +97,8 @@ class Log():
|
||||||
def getConfig(self, key=None):
|
def getConfig(self, key=None):
|
||||||
'''return the current config'''
|
'''return the current config'''
|
||||||
if key:
|
if key:
|
||||||
if self.get(key):
|
if self.config.get(key):
|
||||||
return self.get(key)
|
return self.config.get(key)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Invalid config option: {key}')
|
raise ValueError(f'Invalid config option: {key}')
|
||||||
return self.config
|
return self.config
|
||||||
|
@ -95,7 +111,14 @@ class Log():
|
||||||
stdout.flush()
|
stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def setLevel(self, level):
|
||||||
|
self.minimum = self._lvlCheck(level)
|
||||||
|
|
||||||
|
|
||||||
def log(self, level, *msg):
|
def log(self, level, *msg):
|
||||||
|
if self.config['windows']:
|
||||||
|
return
|
||||||
|
|
||||||
'''log to the console'''
|
'''log to the console'''
|
||||||
levelNum = self._lvlCheck(level)
|
levelNum = self._lvlCheck(level)
|
||||||
|
|
||||||
|
@ -108,6 +131,9 @@ class Log():
|
||||||
message = ' '.join([str(message) for message in msg])
|
message = ' '.join([str(message) for message in msg])
|
||||||
output = f'{level}: {message}\n'
|
output = f'{level}: {message}\n'
|
||||||
|
|
||||||
|
if self.config['systemnotif']:
|
||||||
|
self.config['systemnotif'].New(level, message)
|
||||||
|
|
||||||
if self.config['date'] and (self.config['systemd'] and not env.get('INVOCATION_ID')):
|
if self.config['date'] and (self.config['systemd'] and not env.get('INVOCATION_ID')):
|
||||||
'''only show date when not running in systemd and date var is True'''
|
'''only show date when not running in systemd and date var is True'''
|
||||||
date = datetime.now().strftime(self.config['datefmt'])
|
date = datetime.now().strftime(self.config['datefmt'])
|
||||||
|
@ -148,18 +174,15 @@ def getLogger(loginst, config=None):
|
||||||
logger[loginst] = Log(config)
|
logger[loginst] = Log(config)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise error.InvalidLogger(f'logger "{loginst}" doesn\'t exist')
|
raise InvalidLogger(f'logger "{loginst}" doesn\'t exist')
|
||||||
|
|
||||||
return logger[loginst]
|
return logger[loginst]
|
||||||
|
|
||||||
class error:
|
class InvalidLevel(Exception):
|
||||||
'''base class for all errors'''
|
'''Raise when an invalid logging level was specified'''
|
||||||
|
|
||||||
class InvalidLevel(Exception):
|
class InvalidLogger(Exception):
|
||||||
'''Raise when an invalid logging level was specified'''
|
'''Raise when the specified logger doesn't exist'''
|
||||||
|
|
||||||
class InvalidLogger(Exception):
|
|
||||||
'''Raise when the specified logger doesn't exist'''
|
|
||||||
|
|
||||||
|
|
||||||
'''create a default logger'''
|
'''create a default logger'''
|
||||||
|
@ -182,4 +205,5 @@ merp = DefaultLog.merp
|
||||||
'''aliases for the default logger's config functions'''
|
'''aliases for the default logger's config functions'''
|
||||||
setConfig = DefaultLog.setConfig
|
setConfig = DefaultLog.setConfig
|
||||||
getConfig = DefaultLog.getConfig
|
getConfig = DefaultLog.getConfig
|
||||||
|
setLevel = DefaultLog.setLevel
|
||||||
printConfig = DefaultLog.printConfig
|
printConfig = DefaultLog.printConfig
|
||||||
|
|
480
IzzyLib/misc.py
480
IzzyLib/misc.py
|
@ -1,16 +1,16 @@
|
||||||
'''Miscellaneous functions'''
|
'''Miscellaneous functions'''
|
||||||
import random, string, sys, os, shlex, subprocess, socket, traceback
|
import random, string, sys, os, json, socket
|
||||||
|
|
||||||
from os.path import abspath, dirname, basename, isdir, isfile
|
|
||||||
from os import environ as env
|
from os import environ as env
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import namedtuple
|
from getpass import getpass
|
||||||
|
from pathlib import Path as Pathlib
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
|
|
||||||
|
|
||||||
def boolean(v, fail=True):
|
def Boolean(v, return_value=False):
|
||||||
if type(v) in [dict, list, tuple]:
|
if type(v) not in [str, bool, int, type(None)]:
|
||||||
raise ValueError(f'Value is not a string, boolean, int, or nonetype: {value}')
|
raise ValueError(f'Value is not a string, boolean, int, or nonetype: {value}')
|
||||||
|
|
||||||
'''make the value lowercase if it's a string'''
|
'''make the value lowercase if it's a string'''
|
||||||
|
@ -20,101 +20,50 @@ def boolean(v, fail=True):
|
||||||
'''convert string to True'''
|
'''convert string to True'''
|
||||||
return True
|
return True
|
||||||
|
|
||||||
elif value in [0, False, None, 'off', 'n', 'no', 'false', 'disable', 'none', '']:
|
if value in [0, False, None, 'off', 'n', 'no', 'false', 'disable', '']:
|
||||||
'''convert string to False'''
|
'''convert string to False'''
|
||||||
return False
|
return False
|
||||||
|
|
||||||
elif not fail:
|
if return_value:
|
||||||
'''just return the value'''
|
'''just return the value'''
|
||||||
return value
|
return v
|
||||||
|
|
||||||
else:
|
return True
|
||||||
raise ValueError(f'Value cannot be converted to a boolean: {value}')
|
|
||||||
|
|
||||||
|
|
||||||
def randomgen(chars=20):
|
def RandomGen(length=20, chars=None):
|
||||||
if not isinstance(chars, int):
|
if not isinstance(length, int):
|
||||||
raise TypeError(f'Character length must be an integer, not a {type(char)}')
|
raise TypeError(f'Character length must be an integer, not {type(length)}')
|
||||||
|
|
||||||
return ''.join(random.choices(string.ascii_letters + string.digits, k=chars))
|
characters = string.ascii_letters + string.digits
|
||||||
|
|
||||||
|
if chars:
|
||||||
|
characters += chars
|
||||||
|
|
||||||
|
return ''.join(random.choices(characters, k=length))
|
||||||
|
|
||||||
|
|
||||||
def formatUTC(timestamp=None, ap=False):
|
def Timestamp(dtobj=None, utc=False):
|
||||||
date = datetime.fromtimestamp(timestamp) if timestamp else datetime.utcnow()
|
dtime = dtobj if dtobj else datetime
|
||||||
|
date = dtime.utcnow() if utc else dtime.now()
|
||||||
|
|
||||||
if ap:
|
return date.timestamp()
|
||||||
return date.strftime('%Y-%m-%dT%H:%M:%SZ')
|
|
||||||
|
|
||||||
return date.strftime('%a, %d %b %Y %H:%M:%S GMT')
|
|
||||||
|
|
||||||
|
|
||||||
def config_dir(modpath=None):
|
def ApDate(date=None, alt=False):
|
||||||
if env.get('CONFDIR'):
|
if not date:
|
||||||
'''set the storage path to the environment variable if it exists'''
|
date = datetime.utcnow()
|
||||||
stor_path = abspath(env['CONFDIR'])
|
|
||||||
|
|
||||||
else:
|
elif type(date) == int:
|
||||||
stor_path = f'{os.getcwd()}'
|
date = datetime.fromtimestamp(date)
|
||||||
|
|
||||||
if modpath and not env.get('CONFDIR'):
|
elif type(date) != datetime:
|
||||||
modname = basename(dirname(modpath))
|
raise TypeError(f'Unsupported object type for ApDate: {type(date)}')
|
||||||
|
|
||||||
if isdir(f'{stor_path}/{modname}'):
|
return date.strftime('%a, %d %b %Y %H:%M:%S GMT' if alt else '%Y-%m-%dT%H:%M:%SZ')
|
||||||
'''set the storage path to CWD/data if the module or script is in the working dir'''
|
|
||||||
stor_path += '/data'
|
|
||||||
|
|
||||||
if not isdir (stor_path):
|
|
||||||
os.makedirs(stor_path, exist_ok=True)
|
|
||||||
|
|
||||||
return stor_path
|
|
||||||
|
|
||||||
|
|
||||||
def getBin(filename):
|
def GetIp():
|
||||||
for pathdir in env['PATH'].split(':'):
|
|
||||||
fullpath = os.path.join(pathdir, filename)
|
|
||||||
|
|
||||||
if os.path.isfile(fullpath):
|
|
||||||
return fullpath
|
|
||||||
|
|
||||||
raise FileNotFoundError(f'Cannot find {filename} in path.')
|
|
||||||
|
|
||||||
|
|
||||||
def Try(funct, *args, **kwargs):
|
|
||||||
Result = namedtuple('Result', 'result exception')
|
|
||||||
out = None
|
|
||||||
exc = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
out = funct(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
exc = e
|
|
||||||
|
|
||||||
return Result(out, exc)
|
|
||||||
|
|
||||||
|
|
||||||
def sudo(cmd, password, user=None):
|
|
||||||
### Please don't pur your password in plain text in a script
|
|
||||||
### Use a module like 'getpass' to get the password instead
|
|
||||||
|
|
||||||
if isinstance(cmd, list):
|
|
||||||
cmd = ' '.join(cmd)
|
|
||||||
|
|
||||||
elif not isinstance(cmd, str):
|
|
||||||
raise ValueError('Command is not a list or string')
|
|
||||||
|
|
||||||
euser = os.environ.get('USER')
|
|
||||||
|
|
||||||
cmd = ' '.join(['sudo', '-u', user, cmd]) if user and euser != user else 'sudo ' + cmd
|
|
||||||
sudocmd = ' '.join(['echo', f'{password}', '|', cmd])
|
|
||||||
proc = subprocess.Popen(['/usr/bin/env', 'bash', '-c', sudocmd])
|
|
||||||
|
|
||||||
return proc
|
|
||||||
|
|
||||||
|
|
||||||
def getip():
|
|
||||||
# Get the main IP address of the machine
|
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -131,6 +80,367 @@ def getip():
|
||||||
return ip
|
return ip
|
||||||
|
|
||||||
|
|
||||||
def merp():
|
def Input(prompt, default=None, valtype=str, options=[], password=False):
|
||||||
log = logging.getLogger('merp-heck', {'level': 'merp', 'date': False})
|
input_func = getpass if password else input
|
||||||
log.merp('heck')
|
|
||||||
|
if default != None:
|
||||||
|
prompt += ' [-redacted-]' if password else f' [{default}]'
|
||||||
|
|
||||||
|
prompt += '\n'
|
||||||
|
|
||||||
|
if options:
|
||||||
|
opt = '/'.join(options)
|
||||||
|
prompt += f'[{opt}]'
|
||||||
|
|
||||||
|
prompt += ': '
|
||||||
|
|
||||||
|
value = input_func(prompt)
|
||||||
|
|
||||||
|
while value and options and value not in options:
|
||||||
|
input_func('Invalid value:', value)
|
||||||
|
value = input(prompt)
|
||||||
|
|
||||||
|
if not value:
|
||||||
|
return default
|
||||||
|
|
||||||
|
ret = valtype(value)
|
||||||
|
|
||||||
|
while valtype == Path and not ret.parent().exists():
|
||||||
|
input_func('Parent directory doesn\'t exist')
|
||||||
|
ret = Path(input(prompt))
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class DotDict(dict):
|
||||||
|
def __init__(self, value=None, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if type(value) in [str, bytes]:
|
||||||
|
self.fromJson(value)
|
||||||
|
|
||||||
|
elif type(value) in [dict, DotDict, DefaultDict]:
|
||||||
|
self.update(value)
|
||||||
|
|
||||||
|
elif value:
|
||||||
|
raise TypeError('The value must be a JSON string, dict, or another DotDict object, not', value.__class__)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
self.update(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
try:
|
||||||
|
val = super().__getattribute__(key)
|
||||||
|
|
||||||
|
except AttributeError:
|
||||||
|
val = self.get(key, InvalidKey())
|
||||||
|
|
||||||
|
if type(val) == InvalidKey:
|
||||||
|
raise KeyError(f'Invalid key: {key}')
|
||||||
|
|
||||||
|
return DotDict(val) if type(val) == dict else val
|
||||||
|
|
||||||
|
|
||||||
|
def __delattr__(self, key):
|
||||||
|
if self.get(key):
|
||||||
|
del self[key]
|
||||||
|
|
||||||
|
super().__delattr__(key)
|
||||||
|
|
||||||
|
|
||||||
|
#def __delitem__(self, key):
|
||||||
|
#print('delitem', key)
|
||||||
|
#self.__delattr__(key)
|
||||||
|
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if key.startswith('_'):
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
|
else:
|
||||||
|
super().__setitem__(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def __parse_item__(self, k, v):
|
||||||
|
if type(v) == dict:
|
||||||
|
v = DotDict(v)
|
||||||
|
|
||||||
|
if not k.startswith('_'):
|
||||||
|
return (k, v)
|
||||||
|
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
value = super().get(key, default)
|
||||||
|
return DotDict(value) if type(value) == dict else value
|
||||||
|
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
data = []
|
||||||
|
|
||||||
|
for k, v in super().items():
|
||||||
|
new = self.__parse_item__(k, v)
|
||||||
|
|
||||||
|
if new:
|
||||||
|
data.append(new)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return list(super().values())
|
||||||
|
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return list(super().keys())
|
||||||
|
|
||||||
|
|
||||||
|
def asDict(self):
|
||||||
|
return dict(self)
|
||||||
|
|
||||||
|
|
||||||
|
def toJson(self, indent=None, **kwargs):
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
for k,v in self.items():
|
||||||
|
if type(k) in [DotDict, Path, Pathlib]:
|
||||||
|
k = str(k)
|
||||||
|
|
||||||
|
if type(v) in [DotDict, Path, Pathlib]:
|
||||||
|
v = str(v)
|
||||||
|
|
||||||
|
data[k] = v
|
||||||
|
|
||||||
|
return json.dumps(data, indent=indent, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def fromJson(self, string):
|
||||||
|
data = json.loads(string)
|
||||||
|
self.update(data)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultDict(DotDict):
|
||||||
|
def __getattr__(self, key):
|
||||||
|
try:
|
||||||
|
val = super().__getattribute__(key)
|
||||||
|
|
||||||
|
except AttributeError:
|
||||||
|
val = self.get(key, DefaultDict())
|
||||||
|
|
||||||
|
return DotDict(val) if type(val) == dict else val
|
||||||
|
|
||||||
|
|
||||||
|
class Path(object):
|
||||||
|
def __init__(self, path, exist=True, missing=True, parents=True):
|
||||||
|
self.__path = Pathlib(str(path)).resolve()
|
||||||
|
self.json = DotDict({})
|
||||||
|
self.exist = exist
|
||||||
|
self.missing = missing
|
||||||
|
self.parents = parents
|
||||||
|
self.name = self.__path.name
|
||||||
|
|
||||||
|
|
||||||
|
#def __getattr__(self, key):
|
||||||
|
#try:
|
||||||
|
#attr = getattr(self.__path, key)
|
||||||
|
|
||||||
|
#except AttributeError:
|
||||||
|
#attr = getattr(self, key)
|
||||||
|
|
||||||
|
#return attr
|
||||||
|
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.__path)
|
||||||
|
|
||||||
|
|
||||||
|
def str(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
|
def __check_dir(self, path=None):
|
||||||
|
target = self if not path else Path(path)
|
||||||
|
|
||||||
|
if not self.parents and not target.parent().exists():
|
||||||
|
raise FileNotFoundError('Parent directories do not exist:', target.str())
|
||||||
|
|
||||||
|
if not self.exist and target.exists():
|
||||||
|
raise FileExistsError('File or directory already exists:', target.str())
|
||||||
|
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
return self.__path.stat().st_size
|
||||||
|
|
||||||
|
|
||||||
|
def mtime(self):
|
||||||
|
return self.__path.stat().st_mtime
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir(self, mode=0o755):
|
||||||
|
self.__path.mkdir(mode, self.parents, self.exist)
|
||||||
|
|
||||||
|
return True if self.__path.exists() else False
|
||||||
|
|
||||||
|
|
||||||
|
def new(self):
|
||||||
|
return Path(self.__path)
|
||||||
|
|
||||||
|
|
||||||
|
def parent(self, new=True):
|
||||||
|
path = Pathlib(self.__path).parent
|
||||||
|
|
||||||
|
if new:
|
||||||
|
return Path(path)
|
||||||
|
|
||||||
|
self.__path = path
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def move(self, path):
|
||||||
|
target = Path(path)
|
||||||
|
|
||||||
|
self.__check_dir(path)
|
||||||
|
|
||||||
|
if target.exists() and not target.isdir():
|
||||||
|
target.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def join(self, path, new=True):
|
||||||
|
new_path = self.__path.joinpath(path).resolve()
|
||||||
|
|
||||||
|
if new:
|
||||||
|
return Path(new_path)
|
||||||
|
|
||||||
|
self.__path = new_path
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def home(self, path=None, new=True):
|
||||||
|
new_path = Pathlib.home()
|
||||||
|
|
||||||
|
if path:
|
||||||
|
new_path = new_path.joinpath(path)
|
||||||
|
|
||||||
|
if new:
|
||||||
|
return Path(new_path)
|
||||||
|
|
||||||
|
self.__path = new_path
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def isdir(self):
|
||||||
|
return self.__path.is_dir()
|
||||||
|
|
||||||
|
|
||||||
|
def isfile(self):
|
||||||
|
return self.__path.is_file()
|
||||||
|
|
||||||
|
|
||||||
|
def islink(self):
|
||||||
|
return self.__path.is_symlink()
|
||||||
|
|
||||||
|
|
||||||
|
def listdir(self):
|
||||||
|
return [Path(path) for path in self.__path.iterdir()]
|
||||||
|
|
||||||
|
|
||||||
|
def exists(self):
|
||||||
|
return self.__path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def mtime(self):
|
||||||
|
return os.path.getmtime(self.str())
|
||||||
|
|
||||||
|
|
||||||
|
def link(self, path):
|
||||||
|
target = Path(path)
|
||||||
|
|
||||||
|
self.__check_dir(path)
|
||||||
|
|
||||||
|
if target.exists():
|
||||||
|
target.delete()
|
||||||
|
|
||||||
|
self.__path.symlink_to(path, target.isdir())
|
||||||
|
|
||||||
|
|
||||||
|
def resolve(self, new=True):
|
||||||
|
path = self.__path.resolve()
|
||||||
|
|
||||||
|
if new:
|
||||||
|
return Path(path)
|
||||||
|
|
||||||
|
self.__path = path
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def touch(self, mode=0o666):
|
||||||
|
self.__path.touch(mode, self.exist)
|
||||||
|
|
||||||
|
return True if self.__path.exists() else False
|
||||||
|
|
||||||
|
|
||||||
|
def loadJson(self):
|
||||||
|
self.json = DotDict(self.read())
|
||||||
|
|
||||||
|
return self.json
|
||||||
|
|
||||||
|
|
||||||
|
def updateJson(self, data={}):
|
||||||
|
if type(data) == str:
|
||||||
|
data = json.loads(data)
|
||||||
|
|
||||||
|
self.json.update(data)
|
||||||
|
|
||||||
|
|
||||||
|
def storeJson(self, indent=None):
|
||||||
|
with self.__path.open('w') as fp:
|
||||||
|
fp.write(json.dumps(self.json.asDict(), indent=indent))
|
||||||
|
|
||||||
|
|
||||||
|
# This needs to be extended to handle dirs with files/sub-dirs
|
||||||
|
def delete(self):
|
||||||
|
if self.isdir():
|
||||||
|
self.__path.rmdir(self.missing)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.__path.unlink(self.missing)
|
||||||
|
|
||||||
|
return not self.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def open(self, *args):
|
||||||
|
return self.__path.open(*args)
|
||||||
|
|
||||||
|
|
||||||
|
def read(self, *args):
|
||||||
|
return self.open().read(*args)
|
||||||
|
|
||||||
|
|
||||||
|
def readlines(self):
|
||||||
|
return self.open().readlines()
|
||||||
|
|
||||||
|
|
||||||
|
## def rmdir():
|
||||||
|
|
||||||
|
|
||||||
|
def NfsCheck(path):
|
||||||
|
proc = Path('/proc/mounts')
|
||||||
|
path = Path(path).resolve()
|
||||||
|
|
||||||
|
if not proc.exists():
|
||||||
|
return True
|
||||||
|
|
||||||
|
with proc.open() as fd:
|
||||||
|
for line in fd:
|
||||||
|
line = line.split()
|
||||||
|
|
||||||
|
if line[2] == 'nfs' and line[1] in path.str():
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidKey(object):
|
||||||
|
pass
|
||||||
|
|
|
@ -1,233 +1,131 @@
|
||||||
'''functions for web template management and rendering'''
|
'''functions for web template management and rendering'''
|
||||||
import codecs, traceback, os, json
|
import codecs, traceback, os, json, xml
|
||||||
|
|
||||||
from os import listdir, makedirs
|
from os import listdir, makedirs
|
||||||
from os.path import isfile, isdir, getmtime, abspath
|
from os.path import isfile, isdir, getmtime, abspath
|
||||||
|
|
||||||
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape
|
from jinja2 import Environment, FileSystemLoader, ChoiceLoader, select_autoescape, Markup
|
||||||
from hamlpy.hamlpy import Compiler
|
from hamlish_jinja import HamlishExtension
|
||||||
from markdown import markdown
|
from markdown import markdown
|
||||||
from watchdog.observers import Observer
|
from watchdog.observers import Observer
|
||||||
from watchdog.events import FileSystemEventHandler
|
from watchdog.events import FileSystemEventHandler
|
||||||
|
from xml.dom import minidom
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sanic import response as Response
|
||||||
|
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
Response = None
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
from .color import *
|
from .color import *
|
||||||
|
from .misc import Path, DotDict
|
||||||
framework = 'sanic'
|
|
||||||
|
|
||||||
try:
|
|
||||||
import sanic
|
|
||||||
except:
|
|
||||||
logging.debug('Cannot find Sanic')
|
|
||||||
|
|
||||||
try:
|
|
||||||
import aiohttp
|
|
||||||
except:
|
|
||||||
logging.debug('Cannot find aioHTTP')
|
|
||||||
|
|
||||||
|
|
||||||
env = None
|
class Template(Environment):
|
||||||
|
def __init__(self, search=[], global_vars={}, autoescape=True):
|
||||||
|
self.autoescape = autoescape
|
||||||
|
self.watcher = None
|
||||||
|
self.search = []
|
||||||
|
self.func_context = None
|
||||||
|
|
||||||
global_variables = {
|
for path in search:
|
||||||
'markdown': markdown,
|
self.__add_search_path(path)
|
||||||
'lighten': lighten,
|
|
||||||
'darken': darken,
|
|
||||||
'saturate': saturate,
|
|
||||||
'desaturate': desaturate,
|
|
||||||
'rgba': rgba
|
|
||||||
}
|
|
||||||
|
|
||||||
search_path = list()
|
super().__init__(
|
||||||
build_path_pairs = dict()
|
loader=ChoiceLoader([FileSystemLoader(path) for path in self.search]),
|
||||||
|
extensions=[HamlishExtension],
|
||||||
|
autoescape=self.autoescape,
|
||||||
|
lstrip_blocks=True,
|
||||||
|
trim_blocks=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hamlish_file_extensions=('.haml',)
|
||||||
|
self.hamlish_enable_div_shortcut=True
|
||||||
|
self.hamlish_mode = 'indented'
|
||||||
|
|
||||||
|
self.globals.update({
|
||||||
|
'markdown': markdown,
|
||||||
|
'markup': Markup,
|
||||||
|
'cleanhtml': lambda text: ''.join(xml.etree.ElementTree.fromstring(text).itertext()),
|
||||||
|
'lighten': lighten,
|
||||||
|
'darken': darken,
|
||||||
|
'saturate': saturate,
|
||||||
|
'desaturate': desaturate,
|
||||||
|
'rgba': rgba
|
||||||
|
})
|
||||||
|
|
||||||
|
self.globals.update(global_vars)
|
||||||
|
|
||||||
|
|
||||||
def addSearchPath(path):
|
def __add_search_path(self, path):
|
||||||
tplPath = abspath(path)
|
tpl_path = Path(path)
|
||||||
|
|
||||||
if not isdir(tplPath):
|
if not tpl_path.exists():
|
||||||
raise FileNotFoundError(f'Cannot find template directory: {tplPath}')
|
raise FileNotFoundError('Cannot find search path:', tpl_path.str())
|
||||||
|
|
||||||
if tplPath not in search_path:
|
if tpl_path.str() not in self.search:
|
||||||
search_path.append(tplPath)
|
self.search.append(tpl_path.str())
|
||||||
|
|
||||||
|
|
||||||
def delSearchPath(path):
|
def addEnv(self, k, v):
|
||||||
tplPath = abspath(path)
|
self.globals[k] = v
|
||||||
|
|
||||||
if tplPath in search_path:
|
|
||||||
search_path.remove(tplPath)
|
|
||||||
|
|
||||||
|
|
||||||
def addBuildPath(name, source, destination):
|
def delEnv(self, var):
|
||||||
src = abspath(source)
|
if not self.globals.get(var):
|
||||||
dest = abspath(destination)
|
raise ValueError(f'"{var}" not in global variables')
|
||||||
|
|
||||||
if not isdir(src):
|
del self.var[var]
|
||||||
raise FileNotFoundError(f'Source path doesn\'t exist: {src}')
|
|
||||||
|
|
||||||
build_path_pairs.update({
|
|
||||||
name: {
|
|
||||||
'source': src,
|
|
||||||
'destination': dest
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
addSearchPath(dest)
|
|
||||||
|
|
||||||
|
|
||||||
def delBuildPath(name):
|
def updateEnv(self, data):
|
||||||
if not build_path_pairs.get(name):
|
if not isinstance(data, dict):
|
||||||
raise ValueError(f'"{name}" not in build paths')
|
raise ValueError(f'Environment data not a dict')
|
||||||
|
|
||||||
del build_path_pairs[src]
|
self.globals.update(data)
|
||||||
|
|
||||||
|
|
||||||
def getBuildPath(name=None):
|
def addFilter(self, funct, name=None):
|
||||||
template = build_path_pairs.get(name)
|
name = funct.__name__ if not name else name
|
||||||
|
self.filters[name] = funct
|
||||||
if name:
|
|
||||||
if template:
|
|
||||||
return template
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f'"{name}" not in build paths')
|
|
||||||
|
|
||||||
return build_path_pairs
|
|
||||||
|
|
||||||
|
|
||||||
def addEnv(data):
|
def delFilter(self, name):
|
||||||
if not isinstance(data, dict):
|
if not self.filters.get(name):
|
||||||
raise TypeError(f'environment data is not a dict')
|
raise valueError(f'"{name}" not in global filters')
|
||||||
|
|
||||||
global_variables.update(data)
|
del self.filters[name]
|
||||||
|
|
||||||
|
|
||||||
def delEnv(var):
|
def updateFilter(self, data):
|
||||||
if not global_variables.get(var):
|
if not isinstance(context, dict):
|
||||||
raise ValueError(f'"{var}" not in global variables')
|
raise ValueError(f'Filter data not a dict')
|
||||||
|
|
||||||
del global_variables[var]
|
self.filters.update(data)
|
||||||
|
|
||||||
|
|
||||||
def setup(fwork='sanic'):
|
def render(self, tplfile, request=None, context={}, headers={}, cookies={}, pprint=False, **kwargs):
|
||||||
global env
|
if not isinstance(context, dict):
|
||||||
global framework
|
raise TypeError(f'context for {tplfile} not a dict: {type(context)} {context}')
|
||||||
|
|
||||||
framework = fwork
|
context['request'] = request if request else {'headers': headers, 'cookies': cookies}
|
||||||
env = Environment(
|
|
||||||
loader=ChoiceLoader([FileSystemLoader(path) for path in search_path]),
|
|
||||||
autoescape=select_autoescape(['html', 'css']),
|
|
||||||
lstrip_blocks=True,
|
|
||||||
trim_blocks=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if self.func_context:
|
||||||
|
context.update(self.func_context(DotDict(context), DotDict(self.globals)))
|
||||||
|
|
||||||
def renderTemplate(tplfile, context={}, request=None, headers=dict(), cookies=dict(), **kwargs):
|
result = self.get_template(tplfile).render(context)
|
||||||
if not isinstance(context, dict):
|
|
||||||
raise TypeError(f'context for {tplfile} not a dict')
|
|
||||||
|
|
||||||
data = global_variables.copy()
|
if pprint and any(map(tplfile.endswith, ['haml', 'html', 'xml'])):
|
||||||
data['request'] = request if request else {'headers': headers, 'cookies': cookies}
|
return minidom.parseString(result).toprettyxml(indent=" ")
|
||||||
data.update(context)
|
|
||||||
|
|
||||||
return env.get_template(tplfile).render(data)
|
|
||||||
|
|
||||||
|
|
||||||
def sendResponse(template, request, context=dict(), status=200, ctype='text/html', headers=dict(), **kwargs):
|
|
||||||
context['request'] = request
|
|
||||||
html = renderTemplate(template, context, **kwargs)
|
|
||||||
|
|
||||||
if framework == 'sanic':
|
|
||||||
return sanic.response.text(html, status=status, headers=headers, content_type=ctype)
|
|
||||||
|
|
||||||
elif framework == 'aiohttp':
|
|
||||||
return aiohttp.web.Response(body=html, status=status, headers=headers, content_type=ctype)
|
|
||||||
|
|
||||||
else:
|
|
||||||
logging.error('Please install aiohttp or sanic. Response not sent.')
|
|
||||||
|
|
||||||
|
|
||||||
# delete me later
|
|
||||||
aiohttpTemplate = sendResponse
|
|
||||||
|
|
||||||
|
|
||||||
def buildTemplates(name=None):
|
|
||||||
paths = getBuildPath(name)
|
|
||||||
|
|
||||||
for k, tplPaths in paths.items():
|
|
||||||
src = tplPaths['source']
|
|
||||||
dest = tplPaths['destination']
|
|
||||||
|
|
||||||
timefile = f'{dest}/times.json'
|
|
||||||
updated = False
|
|
||||||
|
|
||||||
if not isdir(f'{dest}'):
|
|
||||||
makedirs(f'{dest}')
|
|
||||||
|
|
||||||
if isfile(timefile):
|
|
||||||
try:
|
|
||||||
times = json.load(open(timefile))
|
|
||||||
|
|
||||||
except:
|
|
||||||
times = {}
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
times = {}
|
return result
|
||||||
|
|
||||||
for filename in listdir(src):
|
|
||||||
fullPath = f'{src}/{filename}'
|
|
||||||
modtime = getmtime(fullPath)
|
|
||||||
base, ext = filename.split('.', 1)
|
|
||||||
|
|
||||||
if ext != 'haml':
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif base not in times or times.get(base) != modtime:
|
|
||||||
updated = True
|
|
||||||
logging.verbose(f"Template '{filename}' was changed. Building...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
destination = f'{dest}/{base}.html'
|
|
||||||
haml_lines = codecs.open(fullPath, 'r', encoding='utf-8').read().splitlines()
|
|
||||||
|
|
||||||
compiler = Compiler()
|
|
||||||
output = compiler.process_lines(haml_lines)
|
|
||||||
outfile = codecs.open(destination, 'w', encoding='utf-8')
|
|
||||||
outfile.write(output)
|
|
||||||
|
|
||||||
logging.info(f"Template '{filename}' has been built")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
'''I'm actually not sure what sort of errors can happen here, so generic catch-all for now'''
|
|
||||||
traceback.print_exc()
|
|
||||||
logging.error(f'Failed to build {filename}: {e}')
|
|
||||||
|
|
||||||
times[base] = modtime
|
|
||||||
|
|
||||||
if updated:
|
|
||||||
with open(timefile, 'w') as filename:
|
|
||||||
filename.write(json.dumps(times))
|
|
||||||
|
|
||||||
|
|
||||||
def templateWatcher():
|
def response(self, *args, ctype='text/html', status=200, **kwargs):
|
||||||
watchPaths = [path['source'] for k, path in build_path_pairs.items()]
|
if not Response:
|
||||||
logging.info('Starting template watcher')
|
raise ModuleNotFoundError('Sanic is not installed')
|
||||||
observer = Observer()
|
|
||||||
|
|
||||||
for tplpath in watchPaths:
|
html = self.render(*args, **kwargs)
|
||||||
logging.debug(f'Watching template dir for changes: {tplpath}')
|
return Response.HTTPResponse(body=html, status=status, content_type=ctype, headers=kwargs.get('headers', {}))
|
||||||
observer.schedule(templateWatchHandler(), tplpath, recursive=True)
|
|
||||||
|
|
||||||
return observer
|
|
||||||
|
|
||||||
|
|
||||||
class templateWatchHandler(FileSystemEventHandler):
|
|
||||||
def on_any_event(self, event):
|
|
||||||
filename, ext = os.path.splitext(os.path.relpath(event.src_path))
|
|
||||||
|
|
||||||
if event.event_type in ['modified', 'created'] and ext[1:] == 'haml':
|
|
||||||
logging.info('Rebuilding templates')
|
|
||||||
buildTemplates()
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['addSearchPath', 'delSearchPath', 'addBuildPath', 'delSearchPath', 'addEnv', 'delEnv', 'setup', 'renderTemplate', 'sendResponse', 'buildTemplates', 'templateWatcher']
|
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
pycryptodome>=3.9.1
|
|
||||||
colour>=0.1.5
|
colour>=0.1.5
|
||||||
aiohttp>=3.6.2
|
|
||||||
envbash>=1.2.0
|
envbash>=1.2.0
|
||||||
hamlpy3>=0.84.0
|
Hamlish-Jinja==0.3.3
|
||||||
Jinja2>=2.10.1
|
Jinja2>=2.10.1
|
||||||
jinja2-markdown>=0.0.3
|
jinja2-markdown>=0.0.3
|
||||||
Mastodon.py>=1.5.0
|
Mastodon.py>=1.5.0
|
||||||
|
pycryptodome>=3.9.1
|
||||||
|
python-magic>=0.4.18
|
||||||
sanic>=19.12.2
|
sanic>=19.12.2
|
||||||
urllib3>=1.25.7
|
|
||||||
watchdog>=0.8.3
|
watchdog>=0.8.3
|
||||||
httpsig>=1.3.0
|
|
||||||
|
|
Loading…
Reference in a new issue