expand PasswordHasher and minor DotDict fixes

This commit is contained in:
Izalia Mae 2021-06-03 15:40:30 -04:00
parent 08895f778e
commit 63a6819e41
2 changed files with 133 additions and 45 deletions

View file

@ -4,6 +4,7 @@ import logging as pylog
from jinja2.exceptions import TemplateNotFound from jinja2.exceptions import TemplateNotFound
from multidict import CIMultiDict from multidict import CIMultiDict
from multiprocessing import cpu_count, current_process from multiprocessing import cpu_count, current_process
from sanic.views import HTTPMethodView
from urllib.parse import parse_qsl, urlparse from urllib.parse import parse_qsl, urlparse
from . import http, logging from . import http, logging
@ -30,12 +31,11 @@ class HttpServer(sanic.Sanic):
self.port = int(port) self.port = int(port)
self.workers = int(kwargs.get('workers', cpu_count())) self.workers = int(kwargs.get('workers', cpu_count()))
self.sig_handler = kwargs.get('sig_handler') self.sig_handler = kwargs.get('sig_handler')
self.ctx = DotDict()
super().__init__(name, request_class=kwargs.get('request_class', HttpRequest)) super().__init__(name, request_class=kwargs.get('request_class', HttpRequest))
#for log in ['sanic.root', 'sanic.access']: for log in ['sanic.root', 'sanic.access']:
#pylog.getLogger(log).setLevel(pylog.CRITICAL) pylog.getLogger(log).setLevel(pylog.ERROR)
self.template = Template( self.template = Template(
kwargs.get('tpl_search', []), kwargs.get('tpl_search', []),
@ -56,6 +56,11 @@ class HttpServer(sanic.Sanic):
signal.signal(signal.SIGTERM, self.finish) signal.signal(signal.SIGTERM, self.finish)
## Sanic spits out a warning, so this is the workaround to stop it
def __setattr__(self, key, value):
object.__setattr__(self, key, value)
def add_method_route(self, method, *routes): def add_method_route(self, method, *routes):
for route in routes: for route in routes:
self.add_route(method.as_view(), route) self.add_route(method.as_view(), route)
@ -88,8 +93,10 @@ class HttpServer(sanic.Sanic):
if self.sig_handler: if self.sig_handler:
self.sig_handler() self.sig_handler()
print('stopping.....')
self.stop() self.stop()
logging.info('Bye! :3') logging.info('Bye! :3')
sys.exit()
class HttpRequest(sanic.request.Request): class HttpRequest(sanic.request.Request):

View file

@ -1,5 +1,5 @@
'''Miscellaneous functions''' '''Miscellaneous functions'''
import hashlib, random, string, sys, os, json, socket, time import hashlib, random, string, sys, os, json, statistics, socket, time, timeit
from os import environ as env from os import environ as env
from datetime import datetime from datetime import datetime
@ -11,8 +11,9 @@ from shutil import copyfile, rmtree
from . import logging from . import logging
try: try:
from passlib.hash import argon2 import argon2
except ImportError: except ImportError:
logging.verbose('argon2-cffi not installed. PasswordHasher class disabled')
argon2 = None argon2 = None
@ -191,6 +192,47 @@ def PrintMethods(object, include_underscore=False):
print(line) print(line)
def TimeFunction(func, *args, passes=1, use_gc=True, **kwargs):
options = [
lambda: func(*args, **kwargs)
]
if use_gc:
options.append('gc.enable()')
timer = timeit.Timer(*options)
if passes > 1:
return timer.repeat(passes, 1)
return timer.timeit(1)
def TimeFunctionPPrint(func, *args, passes=5, use_gc=True, floatlen=3, **kwargs):
parse_time = lambda num: f'{round(num, floatlen)}s'
times = []
for idx in range(0, passes):
passtime = TimeFunction(func, *args, **kwargs, passes=1, use_gc=use_gc)
times.append(passtime)
print(f'Pass {idx+1}: {parse_time(passtime)}')
average = statistics.fmean(times)
print('-----------------')
print(f'Average: {parse_time(average)}')
print(f'Total: {parse_time(sum(times))}')
def TimePassHasher(string='hecking heck', passes=3, iterations=[2,4,8,16,32,64,96]):
for iteration in iterations:
print('\nTesting hash iterations:', iteration)
hasher = PasswordHasher(iterations=iteration)
strhash = hasher.hash(string)
TimeFunctionPPrint(hasher.verify, strhash, string, passes=passes)
class Connection(socket.socket): class Connection(socket.socket):
def __init__(self, address='127.0.0.1', port=8080, tcp=True): def __init__(self, address='127.0.0.1', port=8080, tcp=True):
super().__init__(socket.AF_INET, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM) super().__init__(socket.AF_INET, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
@ -216,6 +258,9 @@ class Connection(socket.socket):
class DotDict(dict): class DotDict(dict):
dict_ignore_types = ['basecache', 'lrucache', 'ttlcache']
def __init__(self, value=None, **kwargs): def __init__(self, value=None, **kwargs):
'''Python dictionary, but variables can be set/get via attributes '''Python dictionary, but variables can be set/get via attributes
@ -233,7 +278,7 @@ class DotDict(dict):
if isinstance(value, (str, bytes)): if isinstance(value, (str, bytes)):
self.from_json(value) self.from_json(value)
elif isinstance(value, dict) or isinstance(value, list): elif value.__class__.__name__.lower() not in self.dict_ignore_types and isinstance(value, dict):
self.update(value) self.update(value)
elif value: elif value:
@ -252,7 +297,7 @@ class DotDict(dict):
def __setitem__(self, k, v): def __setitem__(self, k, v):
if isinstance(v, dict): if v.__class__.__name__.lower() not in self.dict_ignore_types and isinstance(v, dict):
v = DotDict(v) v = DotDict(v)
super().__setitem__(k, v) super().__setitem__(k, v)
@ -309,7 +354,7 @@ class LowerDotDict(DotDict):
def __setattr__(self, key, value): def __setattr__(self, key, value):
return super().__setattr(self, key.lower(), value) return super().__setattr__(key.lower(), value)
def update(self, data): def update(self, data):
@ -359,6 +404,17 @@ class Path(object):
return mode if type(mode) == oct else eval(f'0o{mode}') return mode if type(mode) == oct else eval(f'0o{mode}')
def append(self, text, new=True):
path = str(self.__path) + text
if new:
return Path(path)
self.__path = Pathlib(path)
return self
def size(self): def size(self):
return self.__path.stat().st_size return self.__path.stat().st_size
@ -549,45 +605,70 @@ class JsonEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, obj) return json.JSONEncoder.default(self, obj)
class PasswordHash(object): class PasswordHasher(DotDict):
def __init__(self, salt=None, rounds=8, bsize=50, threads=os.cpu_count(), length=64): ## The defaults can usually be used, except for `iterations`. That should be tweaked on each machine
if type(salt) == Path: ## You can use the TimeFunctionPPrint command above to test this
if salt.exists():
with salt.open() as fd:
self.salt = fd.read()
else: aliases = {
newsalt = RandomGen(40) 'iterations': 'time_cost',
'memory': 'memory_cost',
with salt.open('w') as fd: 'threads': 'parallelism'
fd.write(newsalt) }
self.salt = newsalt
else:
self.salt = salt or RandomGen(40)
self.rounds = rounds
self.bsize = bsize * 1024
self.threads = threads
self.length = length
def hash(self, password): def __init__(self, **kwargs):
return argon2.using( if not argon2:
salt = self.salt.encode('UTF-8'), raise ValueError('password hashing disabled')
rounds = self.rounds,
memory_cost = self.bsize, super().__init__({
max_threads = self.threads, 'time_cost': 16,
digest_size = self.length 'memory_cost': 100 * 1024,
).hash(password) 'parallelism': os.cpu_count(),
'encoding': 'utf-8',
'type': argon2.Type.ID,
})
self.hasher = None
self.update(kwargs)
self.setup()
def verify(self, password, passhash): def get_config(self, key):
return argon2.using( key = self.aliases.get(key, key)
salt = self.salt.encode('UTF-8'),
rounds = self.rounds, self[key]
memory_cost = self.bsize, return self.get(key) / 1024 if key == 'memory_cost' else self.get(key)
max_threads = self.threads,
digest_size = self.length
).verify(password, passhash) def set_config(self, key, value):
key = self.aliases.get(key, key)
self[key] = value * 1024 if key == 'memory_cost' else value
self.setup()
def setup(self):
self.hasher = argon2.PasswordHasher(**self)
def hash(self, password: str):
return self.hasher.hash(password)
def verify(self, passhash: str, password: str):
try:
return self.hasher.verify(passhash, password)
except argon2.exceptions.VerifyMismatchError:
return False
def iteration_test(self, string='hecking heck', passes=3, iterations=[8,16,24,32,40,48,56,64]):
original_iter = self.get_config('iterations')
for iteration in iterations:
self.set_config('iterations', iteration)
print('\nTesting hash iterations:', iteration)
TimeFunctionPPrint(self.verify, self.hash(string), string, passes=passes)
self.set_config('iterations', original_iter)