a bunch of changes

This commit is contained in:
Izalia Mae 2021-08-16 10:55:59 -04:00
parent fda6e3160b
commit 62a2ab7115
8 changed files with 329 additions and 300 deletions

View file

@ -117,7 +117,7 @@ class BaseCache(OrderedDict):
self[key]['timestamp'] = timestamp + self.ttl self[key]['timestamp'] = timestamp + self.ttl
self.move_to_end(key) self.move_to_end(key)
return item.data return item
## This doesn't work for some reason ## This doesn't work for some reason

View file

@ -4,7 +4,7 @@ from . import Path
class DotDict(dict): class DotDict(dict):
dict_ignore_types = ['basecache', 'lrucache', 'ttlcache'] non_dict_vars = []
def __init__(self, value=None, **kwargs): def __init__(self, value=None, **kwargs):
@ -15,16 +15,11 @@ class DotDict(dict):
''' '''
super().__init__() super().__init__()
self.__setattr__ = self.__setitem__
## compatibility
self.toJson = self.to_json
self.fromJson = self.from_json
if isinstance(value, (str, bytes)): if isinstance(value, (str, bytes)):
self.from_json(value) self.from_json(value)
elif value.__class__.__name__.lower() not in self.dict_ignore_types and isinstance(value, dict): elif isinstance(value, dict):
self.update(value) self.update(value)
elif value: elif value:
@ -42,8 +37,16 @@ class DotDict(dict):
raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None
def __setattr__(self, k, v):
if k in self.non_dict_vars or k.startswith('_'):
super().__setattr__(k, v)
else:
self.__setitem__(k, v)
def __setitem__(self, k, v): def __setitem__(self, k, v):
if v.__class__.__name__.lower() not in self.dict_ignore_types and isinstance(v, dict): if type(v) == dict:
v = DotDict(v) v = DotDict(v)
super().__setitem__(k, v) super().__setitem__(k, v)

View file

@ -348,8 +348,12 @@ def random_gen(length=20, letters=True, numbers=True, extra=None):
return ''.join(random.choices(characters, k=length)) return ''.join(random.choices(characters, k=length))
def signal_handler(func, *args, **kwargs): def signal_handler(func, *args, original_args=True, **kwargs):
handler = lambda signum, frame: func(signum, frame, *args, **kwargs) if original_args:
handler = lambda signum, frame: func(signum, frame, *args, **kwargs)
else:
handler = lambda *_: func(*args, **kwargs)
signal.signal(signal.SIGHUP, handler) signal.signal(signal.SIGHUP, handler)
signal.signal(signal.SIGINT, handler) signal.signal(signal.SIGINT, handler)

View file

@ -68,10 +68,14 @@ class Path(str):
self.__check_dir(path) self.__check_dir(path)
if target.exists and overwrite: if overwrite:
target.delete() try:
target.delete()
shutil.copyfile(self, target) except FileNotFoundError:
pass
shutil.copy2(self, target)
def delete(self): def delete(self):

View file

@ -1,8 +1,4 @@
from .client import ( from .signature import (
HttpRequestsClient,
HttpRequestsRequest,
HttpRequestsResponse,
SigningError,
verify_request, verify_request,
verify_headers, verify_headers,
parse_signature, parse_signature,
@ -10,22 +6,21 @@ from .client import (
fetch_instance, fetch_instance,
fetch_nodeinfo, fetch_nodeinfo,
fetch_webfinger_account, fetch_webfinger_account,
set_requests_client,
generate_rsa_key generate_rsa_key
) )
## These usually only get called by the above functions, but importing anyway
from .client import ( from .client import (
parse_body_digest, HttpRequestsClient,
verify_string, HttpRequestsRequest,
sign_pkcs_headers HttpRequestsResponse,
set_requests_client
) )
__all__ = [ __all__ = [
'HttpRequestsClient', 'HttpRequestsClient',
'HttpRequestsRequest', 'HttpRequestsRequest',
'HttpRequestsResponse', 'HttpRequestsResponse',
'SigningError',
'fetch_actor', 'fetch_actor',
'fetch_instance', 'fetch_instance',
'fetch_nodeinfo', 'fetch_nodeinfo',

View file

@ -1,20 +1,18 @@
import json, requests, sys import json, requests, sys
from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
from PIL import Image from PIL import Image
from base64 import b64decode, b64encode from base64 import b64encode
from datetime import datetime from datetime import datetime
from functools import cached_property, lru_cache from functools import cached_property
from io import BytesIO from io import BytesIO
from izzylib import DefaultDotDict, DotDict, Path, izzylog as logging, __version__ from izzylib import DefaultDotDict, DotDict, Path, izzylog as logging, __version__
from izzylib.exceptions import HttpFileDownloadedError from izzylib.exceptions import HttpFileDownloadedError
from ssl import SSLCertVerificationError from ssl import SSLCertVerificationError
from tldextract import extract
from urllib.parse import urlparse from urllib.parse import urlparse
from .signature import sign_request
Client = None Client = None
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace'] methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
@ -41,48 +39,28 @@ class HttpRequestsClient(object):
def set_global(self): def set_global(self):
global Client set_requests_client(self)
Client = self
def __sign_request(self, request, privkey, keyid): def build_request(self, *args, method='get', privkey=None, keyid=None, **kwargs):
request.add_header('(request-target)', f'{request.method.lower()} {request.path}')
request.add_header('host', request.host)
request.add_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
if request.body:
body_hash = b64encode(SHA256.new(request.body).digest()).decode("UTF-8")
request.add_header('digest', f'SHA-256={body_hash}')
request.add_header('content-length', len(request.body))
sig = {
'keyId': keyid,
'algorithm': 'rsa-sha256',
'headers': ' '.join([k.lower() for k in request.headers.keys()]),
'signature': b64encode(sign_pkcs_headers(privkey, request.headers)).decode('UTF-8')
}
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
sig_string = ','.join(sig_items)
request.add_header('signature', sig_string)
request.remove_header('(request-target)')
request.remove_header('host')
def request(self, *args, method='get', **kwargs):
if method.lower() not in methods: if method.lower() not in methods:
raise ValueError(f'Invalid method: {method}') raise ValueError(f'Invalid method: {method}')
request = HttpRequestsRequest(self, *args, method=method.lower(), **kwargs) request = HttpRequestsRequest(self, *args, method=method.lower(), **kwargs)
if privkey and keyid:
request.sign(privkey, keyid)
return request
def request(self, *args, **kwargs):
request = self.build_request(*args, **kwargs)
return HttpRequestsResponse(request.send()) return HttpRequestsResponse(request.send())
def signed_request(self, privkey, keyid, *args, **kwargs): def signed_request(self, privkey, keyid, *args, **kwargs):
request = HttpRequestsRequest(self, *args, **kwargs) return self.request(*args, privkey=privkey, keyid=keyid, **kwargs)
self.__sign_request(request, privkey, keyid)
return HttpRequestsResponse(request.send())
def download(self, url, filepath, *args, filename=None, **kwargs): def download(self, url, filepath, *args, filename=None, **kwargs):
@ -185,6 +163,10 @@ class HttpRequestsRequest(object):
return func(*self.args, **self.kwargs) return func(*self.args, **self.kwargs)
def sign(self, privkey, keyid):
sign_request(self, privkey, keyid)
class HttpRequestsResponse(object): class HttpRequestsResponse(object):
def __init__(self, response): def __init__(self, response):
self.response = response self.response = response
@ -239,233 +221,6 @@ class HttpRequestsResponse(object):
fd.write(chunk) fd.write(chunk)
async def verify_request(request, actor: dict=None):
'''Verify a header signature from a SimpleASGI request
request: The request with the headers to verify
actor: A dictionary containing the activitypub actor and the link to the pubkey used for verification
'''
body = (await request.body) if request.body else None
headers = {k.lower(): v[0] for k,v in request.headers.items()}
return verify_headers(headers, request.method, request.path, actor, body)
def verify_headers(headers: dict, method: str, path: str, actor: dict=None, body=None):
'''Verify a header signature
headers: A dictionary containing all the headers from a request
method: The HTTP method of the request
path: The path of the HTTP request
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
body (optional): The body of the request. Only needed if the signature includes the digest header
fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
'''
headers = {k.lower(): v for k,v in headers.items()}
headers['(request-target)'] = f'{method.lower()} {path}'
signature = parse_signature(headers.get('signature'))
digest = headers.get('digest')
missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None]
if not signature:
raise AssertionError('Missing signature')
if not actor:
actor = fetch_actor(signature.keyid)
print(actor)
## Add digest header to missing headers list if it doesn't exist
if method.lower() == 'post' and not digest:
missing_headers.append('digest')
## Fail if missing date, host or digest (if POST) headers
if missing_headers:
raise AssertionError(f'Missing headers: {missing_headers}')
## Fail if body verification fails
if digest:
digest_hash = parse_body_digest(headers.get('digest'))
if not verify_string(body, digest_hash.sig, digest_hash.alg):
raise AssertionError('Failed body digest verification')
pubkey = actor.publicKey['publicKeyPem']
return sign_pkcs_headers(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature)
def parse_body_digest(digest):
if not digest:
raise AssertionError('Empty digest')
parsed = DotDict()
alg, sig = digest.split('=', 1)
parsed.sig = sig
parsed.alg = alg.replace('-', '')
return parsed
def verify_string(string, enc_string, alg='SHA256', fail=False):
if type(string) != bytes:
string = string.encode('UTF-8')
body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
if body_hash == enc_string:
return True
if fail:
raise AssertionError('String failed validation')
else:
return False
def sign_pkcs_headers(key: str, headers: dict, sig=None):
if sig:
head_items = [f'{item}: {headers[item]}' for item in sig.headers]
else:
head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
head_string = '\n'.join(head_items)
head_bytes = head_string.encode('UTF-8')
KEY = RSA.importKey(key)
pkcs = PKCS1_v1_5.new(KEY)
h = SHA256.new(head_bytes)
if sig:
return pkcs.verify(h, b64decode(sig.signature))
else:
return pkcs.sign(h)
def parse_signature(signature: str):
if not signature:
return
raise AssertionError('Missing signature header')
split_sig = signature.split(',')
sig = DefaultDotDict()
for part in split_sig:
key, value = part.split('=', 1)
sig[key.lower()] = value.replace('"', '')
sig.headers = sig.headers.split()
sig.domain = urlparse(sig.keyid).netloc
sig.top_domain = '.'.join(extract(sig.domain)[1:])
return sig
@lru_cache(maxsize=512)
def fetch_actor(url):
if not Client:
raise ValueError('Please set global client with "SetRequestsClient(client)"')
url = url.split('#')[0]
headers = {'Accept': 'application/activity+json'}
resp = Client.request(url, headers=headers)
try:
actor = resp.json
except json.decoder.JSONDecodeError:
return
except Exception as e:
izzylog.debug(f'HTTP {resp.status}: {resp.body}')
raise e from None
actor.web_domain = urlparse(url).netloc
actor.shared_inbox = actor.inbox
actor.pubkey = None
actor.handle = actor.preferredUsername
if actor.get('endpoints'):
actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
if actor.get('publicKey'):
actor.pubkey = actor.publicKey.get('publicKeyPem')
return actor
@lru_cache(maxsize=512)
def fetch_instance(domain):
if not Client:
raise ValueError('Please set global client with "SetRequestsClient(client)"')
headers = {'Accept': 'application/json'}
resp = Client.request(f'https://{domain}/api/v1/instance', headers=headers)
try:
return resp.json
except json.decoder.JSONDecodeError:
return
except Exception as e:
izzylog.debug(f'HTTP {resp.status}: {resp.body}')
raise e from None
@lru_cache(maxsize=512)
def fetch_webfinger_account(handle, domain):
if not Client:
raise ValueError('Please set global client with HttpRequestsClient.set_global()')
data = DefaultDotDict()
webfinger = Client.request(f'https://{domain}/.well-known/webfinger?resource=acct:{handle}@{domain}')
if not webfinger.body:
raise ValueError('Webfinger body empty')
data.handle, data.domain = webfinger.json.subject.replace('acct:', '').split('@')
for link in webfinger.json.links:
if link['rel'] == 'self' and link['type'] == 'application/activity+json':
data.actor = link['href']
return data
@lru_cache(maxsize=512)
def fetch_nodeinfo(domain):
if not Client:
raise ValueError('Please set global client with HttpRequestsClient.set_global()')
webfinger = Client.request(f'https://{domain}/.well-known/nodeinfo')
webfinger_data = DotDict(webfinger.body)
for link in webfinger.json.links:
if link['rel'] == 'http://nodeinfo.diaspora.software/ns/schema/2.0':
nodeinfo_url = link['href']
break
nodeinfo = Client.request(nodeinfo_url)
return nodeinfo.json
def set_requests_client(client=None): def set_requests_client(client=None):
global Client global Client
Client = client or RequestsClient() Client = client or RequestsClient()
def generate_rsa_key():
privkey = RSA.generate(2048)
key = DotDict({'PRIVKEY': privkey, 'PUBKEY': privkey.publickey()})
key.update({'privkey': key.PRIVKEY.export_key().decode(), 'pubkey': key.PUBKEY.export_key().decode()})
return key
class SigningError(Exception):
pass

View file

@ -0,0 +1,264 @@
import json, requests, sys
from PIL import Image
from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
from base64 import b64decode, b64encode
from functools import lru_cache
from izzylib import DefaultDotDict, DotDict
from izzylib import izzylog as logging
from tldextract import extract
from urllib.parse import urlparse
@lru_cache(maxsize=512)
def fetch_actor(url):
if not Client:
raise ValueError('Please set global client with "SetRequestsClient(client)"')
url = url.split('#')[0]
headers = {'Accept': 'application/activity+json'}
resp = Client.request(url, headers=headers)
try:
actor = resp.json
except json.decoder.JSONDecodeError:
return
except Exception as e:
izzylog.debug(f'HTTP {resp.status}: {resp.body}')
raise e from None
actor.web_domain = urlparse(url).netloc
actor.shared_inbox = actor.inbox
actor.pubkey = None
actor.handle = actor.preferredUsername
if actor.get('endpoints'):
actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
if actor.get('publicKey'):
actor.pubkey = actor.publicKey.get('publicKeyPem')
return actor
@lru_cache(maxsize=512)
def fetch_instance(domain):
if not Client:
raise ValueError('Please set global client with "SetRequestsClient(client)"')
headers = {'Accept': 'application/json'}
resp = Client.request(f'https://{domain}/api/v1/instance', headers=headers)
try:
return resp.json
except json.decoder.JSONDecodeError:
return
except Exception as e:
izzylog.debug(f'HTTP {resp.status}: {resp.body}')
raise e from None
@lru_cache(maxsize=512)
def fetch_nodeinfo(domain):
if not Client:
raise ValueError('Please set global client with HttpRequestsClient.set_global()')
webfinger = Client.request(f'https://{domain}/.well-known/nodeinfo')
webfinger_data = DotDict(webfinger.body)
for link in webfinger.json.links:
if link['rel'] == 'http://nodeinfo.diaspora.software/ns/schema/2.0':
nodeinfo_url = link['href']
break
nodeinfo = Client.request(nodeinfo_url)
return nodeinfo.json
@lru_cache(maxsize=512)
def fetch_webfinger_account(handle, domain):
if not Client:
raise ValueError('Please set global client with HttpRequestsClient.set_global()')
data = DefaultDotDict()
webfinger = Client.request(f'https://{domain}/.well-known/webfinger?resource=acct:{handle}@{domain}')
if not webfinger.body:
raise ValueError('Webfinger body empty')
data.handle, data.domain = webfinger.json.subject.replace('acct:', '').split('@')
for link in webfinger.json.links:
if link['rel'] == 'self' and link['type'] == 'application/activity+json':
data.actor = link['href']
return data
def generate_rsa_key():
privkey = RSA.generate(2048)
key = DotDict({'PRIVKEY': privkey, 'PUBKEY': privkey.publickey()})
key.update({'privkey': key.PRIVKEY.export_key().decode(), 'pubkey': key.PUBKEY.export_key().decode()})
return key
def parse_signature(signature: str):
if not signature:
return
raise AssertionError('Missing signature header')
split_sig = signature.split(',')
sig = DefaultDotDict()
for part in split_sig:
key, value = part.split('=', 1)
sig[key.lower()] = value.replace('"', '')
sig.headers = sig.headers.split()
sig.domain = urlparse(sig.keyid).netloc
sig.top_domain = '.'.join(extract(sig.domain)[1:])
return sig
def verify_headers(headers: dict, method: str, path: str, actor: dict=None, body=None):
'''Verify a header signature
headers: A dictionary containing all the headers from a request
method: The HTTP method of the request
path: The path of the HTTP request
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
body (optional): The body of the request. Only needed if the signature includes the digest header
fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
'''
headers = {k.lower(): v for k,v in headers.items()}
headers['(request-target)'] = f'{method.lower()} {path}'
signature = parse_signature(headers.get('signature'))
digest = headers.get('digest')
missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None]
if not signature:
raise AssertionError('Missing signature')
if not actor:
actor = fetch_actor(signature.keyid)
print(actor)
## Add digest header to missing headers list if it doesn't exist
if method.lower() == 'post' and not digest:
missing_headers.append('digest')
## Fail if missing date, host or digest (if POST) headers
if missing_headers:
raise AssertionError(f'Missing headers: {missing_headers}')
## Fail if body verification fails
if digest:
digest_hash = parse_body_digest(headers.get('digest'))
if not verify_string(body, digest_hash.sig, digest_hash.alg):
raise AssertionError('Failed body digest verification')
pubkey = actor.publicKey['publicKeyPem']
return sign_pkcs_headers(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature)
async def verify_request(request, actor: dict=None):
'''Verify a header signature from a SimpleASGI request
request: The request with the headers to verify
actor: A dictionary containing the activitypub actor and the link to the pubkey used for verification
'''
body = (await request.body) if request.body else None
headers = {k.lower(): v[0] for k,v in request.headers.items()}
return verify_headers(headers, request.method, request.path, actor, body)
### Helper functions that shouldn't be used directly ###
def parse_body_digest(digest):
if not digest:
raise AssertionError('Empty digest')
parsed = DotDict()
alg, sig = digest.split('=', 1)
parsed.sig = sig
parsed.alg = alg.replace('-', '')
return parsed
def sign_pkcs_headers(key: str, headers: dict, sig=None):
if sig:
head_items = [f'{item}: {headers[item]}' for item in sig.headers]
else:
head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
head_string = '\n'.join(head_items)
head_bytes = head_string.encode('UTF-8')
KEY = RSA.importKey(key)
pkcs = PKCS1_v1_5.new(KEY)
h = SHA256.new(head_bytes)
if sig:
return pkcs.verify(h, b64decode(sig.signature))
else:
return pkcs.sign(h)
def sign_request(request, privkey, keyid):
request.add_header('(request-target)', f'{request.method.lower()} {request.path}')
request.add_header('host', request.host)
request.add_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
if request.body:
body_hash = b64encode(SHA256.new(request.body).digest()).decode("UTF-8")
request.add_header('digest', f'SHA-256={body_hash}')
request.add_header('content-length', len(request.body))
sig = {
'keyId': keyid,
'algorithm': 'rsa-sha256',
'headers': ' '.join([k.lower() for k in request.headers.keys()]),
'signature': b64encode(sign_pkcs_headers(privkey, request.headers)).decode('UTF-8')
}
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
sig_string = ','.join(sig_items)
request.add_header('signature', sig_string)
request.remove_header('(request-target)')
request.remove_header('host')
def verify_string(string, enc_string, alg='SHA256', fail=False):
if type(string) != bytes:
string = string.encode('UTF-8')
body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
if body_hash == enc_string:
return True
if fail:
raise AssertionError('String failed validation')
else:
return False

View file

@ -40,13 +40,16 @@ class SqlDatabase:
engine_args = [] engine_args = []
engine_kwargs = {} engine_kwargs = {}
if not kwargs.get('database'): if not kwargs.get('name'):
raise KeyError('Database argument is not set') raise KeyError('Database "name" is not set')
engine_string = dbtype + '://' engine_string = dbtype + '://'
if dbtype == 'sqlite': if dbtype == 'sqlite':
database = kwargs.get('database') try:
database = kwargs['name']
except KeyError:
database = kwargs['database']
if nfs_check(database): if nfs_check(database):
izzylog.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail') izzylog.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
@ -240,6 +243,9 @@ class SqlSession(object):
def update(self, table=None, rowid=None, row=None, return_row=False, **data): def update(self, table=None, rowid=None, row=None, return_row=False, **data):
if row: if row:
if not getattr(row, '_table_name', None):
print(row)
print(dir(row))
rowid = row.id rowid = row.id
table = row._table_name table = row._table_name
@ -344,15 +350,13 @@ class CustomRows(object):
def __init__(self, table, row, session): def __init__(self, table, row, session):
if not row:
return
super().__init__() super().__init__()
try: if row:
self._update(row._asdict()) try:
except: self._update(row._asdict())
self._update(row) except:
self._update(row)
self._db = session.db self._db = session.db
self._table_name = table self._table_name = table