a bunch of changes
This commit is contained in:
parent
fda6e3160b
commit
62a2ab7115
|
@ -117,7 +117,7 @@ class BaseCache(OrderedDict):
|
|||
self[key]['timestamp'] = timestamp + self.ttl
|
||||
|
||||
self.move_to_end(key)
|
||||
return item.data
|
||||
return item
|
||||
|
||||
|
||||
## This doesn't work for some reason
|
||||
|
|
|
@ -4,7 +4,7 @@ from . import Path
|
|||
|
||||
|
||||
class DotDict(dict):
|
||||
dict_ignore_types = ['basecache', 'lrucache', 'ttlcache']
|
||||
non_dict_vars = []
|
||||
|
||||
|
||||
def __init__(self, value=None, **kwargs):
|
||||
|
@ -15,16 +15,11 @@ class DotDict(dict):
|
|||
'''
|
||||
|
||||
super().__init__()
|
||||
self.__setattr__ = self.__setitem__
|
||||
|
||||
## compatibility
|
||||
self.toJson = self.to_json
|
||||
self.fromJson = self.from_json
|
||||
|
||||
if isinstance(value, (str, bytes)):
|
||||
self.from_json(value)
|
||||
|
||||
elif value.__class__.__name__.lower() not in self.dict_ignore_types and isinstance(value, dict):
|
||||
elif isinstance(value, dict):
|
||||
self.update(value)
|
||||
|
||||
elif value:
|
||||
|
@ -42,8 +37,16 @@ class DotDict(dict):
|
|||
raise AttributeError(f'{self.__class__.__name__} object has no attribute {k}') from None
|
||||
|
||||
|
||||
def __setattr__(self, k, v):
|
||||
if k in self.non_dict_vars or k.startswith('_'):
|
||||
super().__setattr__(k, v)
|
||||
|
||||
else:
|
||||
self.__setitem__(k, v)
|
||||
|
||||
|
||||
def __setitem__(self, k, v):
|
||||
if v.__class__.__name__.lower() not in self.dict_ignore_types and isinstance(v, dict):
|
||||
if type(v) == dict:
|
||||
v = DotDict(v)
|
||||
|
||||
super().__setitem__(k, v)
|
||||
|
|
|
@ -348,8 +348,12 @@ def random_gen(length=20, letters=True, numbers=True, extra=None):
|
|||
return ''.join(random.choices(characters, k=length))
|
||||
|
||||
|
||||
def signal_handler(func, *args, **kwargs):
|
||||
handler = lambda signum, frame: func(signum, frame, *args, **kwargs)
|
||||
def signal_handler(func, *args, original_args=True, **kwargs):
|
||||
if original_args:
|
||||
handler = lambda signum, frame: func(signum, frame, *args, **kwargs)
|
||||
|
||||
else:
|
||||
handler = lambda *_: func(*args, **kwargs)
|
||||
|
||||
signal.signal(signal.SIGHUP, handler)
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
|
|
|
@ -68,10 +68,14 @@ class Path(str):
|
|||
|
||||
self.__check_dir(path)
|
||||
|
||||
if target.exists and overwrite:
|
||||
target.delete()
|
||||
if overwrite:
|
||||
try:
|
||||
target.delete()
|
||||
|
||||
shutil.copyfile(self, target)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
shutil.copy2(self, target)
|
||||
|
||||
|
||||
def delete(self):
|
||||
|
|
|
@ -1,8 +1,4 @@
|
|||
from .client import (
|
||||
HttpRequestsClient,
|
||||
HttpRequestsRequest,
|
||||
HttpRequestsResponse,
|
||||
SigningError,
|
||||
from .signature import (
|
||||
verify_request,
|
||||
verify_headers,
|
||||
parse_signature,
|
||||
|
@ -10,22 +6,21 @@ from .client import (
|
|||
fetch_instance,
|
||||
fetch_nodeinfo,
|
||||
fetch_webfinger_account,
|
||||
set_requests_client,
|
||||
generate_rsa_key
|
||||
)
|
||||
|
||||
## These usually only get called by the above functions, but importing anyway
|
||||
|
||||
from .client import (
|
||||
parse_body_digest,
|
||||
verify_string,
|
||||
sign_pkcs_headers
|
||||
HttpRequestsClient,
|
||||
HttpRequestsRequest,
|
||||
HttpRequestsResponse,
|
||||
set_requests_client
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'HttpRequestsClient',
|
||||
'HttpRequestsRequest',
|
||||
'HttpRequestsResponse',
|
||||
'SigningError',
|
||||
'fetch_actor',
|
||||
'fetch_instance',
|
||||
'fetch_nodeinfo',
|
||||
|
|
|
@ -1,20 +1,18 @@
|
|||
import json, requests, sys
|
||||
|
||||
from Crypto.Hash import SHA256
|
||||
from Crypto.PublicKey import RSA
|
||||
from Crypto.Signature import PKCS1_v1_5
|
||||
from PIL import Image
|
||||
|
||||
from base64 import b64decode, b64encode
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from functools import cached_property, lru_cache
|
||||
from functools import cached_property
|
||||
from io import BytesIO
|
||||
from izzylib import DefaultDotDict, DotDict, Path, izzylog as logging, __version__
|
||||
from izzylib.exceptions import HttpFileDownloadedError
|
||||
from ssl import SSLCertVerificationError
|
||||
from tldextract import extract
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .signature import sign_request
|
||||
|
||||
|
||||
Client = None
|
||||
methods = ['connect', 'delete', 'get', 'head', 'options', 'patch', 'post', 'put', 'trace']
|
||||
|
@ -41,48 +39,28 @@ class HttpRequestsClient(object):
|
|||
|
||||
|
||||
def set_global(self):
|
||||
global Client
|
||||
Client = self
|
||||
set_requests_client(self)
|
||||
|
||||
|
||||
def __sign_request(self, request, privkey, keyid):
|
||||
request.add_header('(request-target)', f'{request.method.lower()} {request.path}')
|
||||
request.add_header('host', request.host)
|
||||
request.add_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
|
||||
|
||||
if request.body:
|
||||
body_hash = b64encode(SHA256.new(request.body).digest()).decode("UTF-8")
|
||||
request.add_header('digest', f'SHA-256={body_hash}')
|
||||
request.add_header('content-length', len(request.body))
|
||||
|
||||
sig = {
|
||||
'keyId': keyid,
|
||||
'algorithm': 'rsa-sha256',
|
||||
'headers': ' '.join([k.lower() for k in request.headers.keys()]),
|
||||
'signature': b64encode(sign_pkcs_headers(privkey, request.headers)).decode('UTF-8')
|
||||
}
|
||||
|
||||
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
|
||||
sig_string = ','.join(sig_items)
|
||||
|
||||
request.add_header('signature', sig_string)
|
||||
|
||||
request.remove_header('(request-target)')
|
||||
request.remove_header('host')
|
||||
|
||||
|
||||
def request(self, *args, method='get', **kwargs):
|
||||
def build_request(self, *args, method='get', privkey=None, keyid=None, **kwargs):
|
||||
if method.lower() not in methods:
|
||||
raise ValueError(f'Invalid method: {method}')
|
||||
|
||||
request = HttpRequestsRequest(self, *args, method=method.lower(), **kwargs)
|
||||
|
||||
if privkey and keyid:
|
||||
request.sign(privkey, keyid)
|
||||
|
||||
return request
|
||||
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
request = self.build_request(*args, **kwargs)
|
||||
return HttpRequestsResponse(request.send())
|
||||
|
||||
|
||||
def signed_request(self, privkey, keyid, *args, **kwargs):
|
||||
request = HttpRequestsRequest(self, *args, **kwargs)
|
||||
self.__sign_request(request, privkey, keyid)
|
||||
return HttpRequestsResponse(request.send())
|
||||
return self.request(*args, privkey=privkey, keyid=keyid, **kwargs)
|
||||
|
||||
|
||||
def download(self, url, filepath, *args, filename=None, **kwargs):
|
||||
|
@ -185,6 +163,10 @@ class HttpRequestsRequest(object):
|
|||
return func(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
def sign(self, privkey, keyid):
|
||||
sign_request(self, privkey, keyid)
|
||||
|
||||
|
||||
class HttpRequestsResponse(object):
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
@ -239,233 +221,6 @@ class HttpRequestsResponse(object):
|
|||
fd.write(chunk)
|
||||
|
||||
|
||||
async def verify_request(request, actor: dict=None):
|
||||
'''Verify a header signature from a SimpleASGI request
|
||||
|
||||
request: The request with the headers to verify
|
||||
actor: A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
||||
'''
|
||||
|
||||
body = (await request.body) if request.body else None
|
||||
headers = {k.lower(): v[0] for k,v in request.headers.items()}
|
||||
return verify_headers(headers, request.method, request.path, actor, body)
|
||||
|
||||
|
||||
def verify_headers(headers: dict, method: str, path: str, actor: dict=None, body=None):
|
||||
'''Verify a header signature
|
||||
|
||||
headers: A dictionary containing all the headers from a request
|
||||
method: The HTTP method of the request
|
||||
path: The path of the HTTP request
|
||||
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
||||
body (optional): The body of the request. Only needed if the signature includes the digest header
|
||||
fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
|
||||
'''
|
||||
|
||||
headers = {k.lower(): v for k,v in headers.items()}
|
||||
headers['(request-target)'] = f'{method.lower()} {path}'
|
||||
signature = parse_signature(headers.get('signature'))
|
||||
digest = headers.get('digest')
|
||||
missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None]
|
||||
|
||||
if not signature:
|
||||
raise AssertionError('Missing signature')
|
||||
|
||||
if not actor:
|
||||
actor = fetch_actor(signature.keyid)
|
||||
print(actor)
|
||||
|
||||
## Add digest header to missing headers list if it doesn't exist
|
||||
if method.lower() == 'post' and not digest:
|
||||
missing_headers.append('digest')
|
||||
|
||||
## Fail if missing date, host or digest (if POST) headers
|
||||
if missing_headers:
|
||||
raise AssertionError(f'Missing headers: {missing_headers}')
|
||||
|
||||
## Fail if body verification fails
|
||||
if digest:
|
||||
digest_hash = parse_body_digest(headers.get('digest'))
|
||||
|
||||
if not verify_string(body, digest_hash.sig, digest_hash.alg):
|
||||
raise AssertionError('Failed body digest verification')
|
||||
|
||||
pubkey = actor.publicKey['publicKeyPem']
|
||||
|
||||
return sign_pkcs_headers(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature)
|
||||
|
||||
|
||||
def parse_body_digest(digest):
|
||||
if not digest:
|
||||
raise AssertionError('Empty digest')
|
||||
|
||||
parsed = DotDict()
|
||||
alg, sig = digest.split('=', 1)
|
||||
|
||||
parsed.sig = sig
|
||||
parsed.alg = alg.replace('-', '')
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def verify_string(string, enc_string, alg='SHA256', fail=False):
|
||||
if type(string) != bytes:
|
||||
string = string.encode('UTF-8')
|
||||
|
||||
body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
|
||||
|
||||
if body_hash == enc_string:
|
||||
return True
|
||||
|
||||
if fail:
|
||||
raise AssertionError('String failed validation')
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def sign_pkcs_headers(key: str, headers: dict, sig=None):
|
||||
if sig:
|
||||
head_items = [f'{item}: {headers[item]}' for item in sig.headers]
|
||||
|
||||
else:
|
||||
head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
|
||||
|
||||
head_string = '\n'.join(head_items)
|
||||
head_bytes = head_string.encode('UTF-8')
|
||||
|
||||
KEY = RSA.importKey(key)
|
||||
pkcs = PKCS1_v1_5.new(KEY)
|
||||
h = SHA256.new(head_bytes)
|
||||
|
||||
if sig:
|
||||
return pkcs.verify(h, b64decode(sig.signature))
|
||||
|
||||
else:
|
||||
return pkcs.sign(h)
|
||||
|
||||
|
||||
def parse_signature(signature: str):
|
||||
if not signature:
|
||||
return
|
||||
raise AssertionError('Missing signature header')
|
||||
|
||||
split_sig = signature.split(',')
|
||||
sig = DefaultDotDict()
|
||||
|
||||
for part in split_sig:
|
||||
key, value = part.split('=', 1)
|
||||
sig[key.lower()] = value.replace('"', '')
|
||||
|
||||
sig.headers = sig.headers.split()
|
||||
sig.domain = urlparse(sig.keyid).netloc
|
||||
sig.top_domain = '.'.join(extract(sig.domain)[1:])
|
||||
|
||||
return sig
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def fetch_actor(url):
|
||||
if not Client:
|
||||
raise ValueError('Please set global client with "SetRequestsClient(client)"')
|
||||
|
||||
url = url.split('#')[0]
|
||||
headers = {'Accept': 'application/activity+json'}
|
||||
resp = Client.request(url, headers=headers)
|
||||
|
||||
try:
|
||||
actor = resp.json
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
izzylog.debug(f'HTTP {resp.status}: {resp.body}')
|
||||
raise e from None
|
||||
|
||||
actor.web_domain = urlparse(url).netloc
|
||||
actor.shared_inbox = actor.inbox
|
||||
actor.pubkey = None
|
||||
actor.handle = actor.preferredUsername
|
||||
|
||||
if actor.get('endpoints'):
|
||||
actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
|
||||
|
||||
if actor.get('publicKey'):
|
||||
actor.pubkey = actor.publicKey.get('publicKeyPem')
|
||||
|
||||
return actor
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def fetch_instance(domain):
|
||||
if not Client:
|
||||
raise ValueError('Please set global client with "SetRequestsClient(client)"')
|
||||
|
||||
headers = {'Accept': 'application/json'}
|
||||
resp = Client.request(f'https://{domain}/api/v1/instance', headers=headers)
|
||||
|
||||
try:
|
||||
return resp.json
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
izzylog.debug(f'HTTP {resp.status}: {resp.body}')
|
||||
raise e from None
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def fetch_webfinger_account(handle, domain):
|
||||
if not Client:
|
||||
raise ValueError('Please set global client with HttpRequestsClient.set_global()')
|
||||
|
||||
data = DefaultDotDict()
|
||||
webfinger = Client.request(f'https://{domain}/.well-known/webfinger?resource=acct:{handle}@{domain}')
|
||||
|
||||
if not webfinger.body:
|
||||
raise ValueError('Webfinger body empty')
|
||||
|
||||
data.handle, data.domain = webfinger.json.subject.replace('acct:', '').split('@')
|
||||
|
||||
for link in webfinger.json.links:
|
||||
if link['rel'] == 'self' and link['type'] == 'application/activity+json':
|
||||
data.actor = link['href']
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def fetch_nodeinfo(domain):
|
||||
if not Client:
|
||||
raise ValueError('Please set global client with HttpRequestsClient.set_global()')
|
||||
|
||||
webfinger = Client.request(f'https://{domain}/.well-known/nodeinfo')
|
||||
webfinger_data = DotDict(webfinger.body)
|
||||
|
||||
for link in webfinger.json.links:
|
||||
if link['rel'] == 'http://nodeinfo.diaspora.software/ns/schema/2.0':
|
||||
nodeinfo_url = link['href']
|
||||
break
|
||||
|
||||
nodeinfo = Client.request(nodeinfo_url)
|
||||
return nodeinfo.json
|
||||
|
||||
|
||||
def set_requests_client(client=None):
|
||||
global Client
|
||||
Client = client or RequestsClient()
|
||||
|
||||
|
||||
def generate_rsa_key():
|
||||
privkey = RSA.generate(2048)
|
||||
|
||||
key = DotDict({'PRIVKEY': privkey, 'PUBKEY': privkey.publickey()})
|
||||
key.update({'privkey': key.PRIVKEY.export_key().decode(), 'pubkey': key.PUBKEY.export_key().decode()})
|
||||
|
||||
return key
|
||||
|
||||
|
||||
class SigningError(Exception):
|
||||
pass
|
||||
|
|
264
requests_client/izzylib/http_requests_client/signature.py
Normal file
264
requests_client/izzylib/http_requests_client/signature.py
Normal file
|
@ -0,0 +1,264 @@
|
|||
import json, requests, sys
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from Crypto.Hash import SHA256
|
||||
from Crypto.PublicKey import RSA
|
||||
from Crypto.Signature import PKCS1_v1_5
|
||||
from base64 import b64decode, b64encode
|
||||
from functools import lru_cache
|
||||
from izzylib import DefaultDotDict, DotDict
|
||||
from izzylib import izzylog as logging
|
||||
from tldextract import extract
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def fetch_actor(url):
|
||||
if not Client:
|
||||
raise ValueError('Please set global client with "SetRequestsClient(client)"')
|
||||
|
||||
url = url.split('#')[0]
|
||||
headers = {'Accept': 'application/activity+json'}
|
||||
resp = Client.request(url, headers=headers)
|
||||
|
||||
try:
|
||||
actor = resp.json
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
izzylog.debug(f'HTTP {resp.status}: {resp.body}')
|
||||
raise e from None
|
||||
|
||||
actor.web_domain = urlparse(url).netloc
|
||||
actor.shared_inbox = actor.inbox
|
||||
actor.pubkey = None
|
||||
actor.handle = actor.preferredUsername
|
||||
|
||||
if actor.get('endpoints'):
|
||||
actor.shared_inbox = actor.endpoints.get('sharedInbox', actor.inbox)
|
||||
|
||||
if actor.get('publicKey'):
|
||||
actor.pubkey = actor.publicKey.get('publicKeyPem')
|
||||
|
||||
return actor
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def fetch_instance(domain):
|
||||
if not Client:
|
||||
raise ValueError('Please set global client with "SetRequestsClient(client)"')
|
||||
|
||||
headers = {'Accept': 'application/json'}
|
||||
resp = Client.request(f'https://{domain}/api/v1/instance', headers=headers)
|
||||
|
||||
try:
|
||||
return resp.json
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
izzylog.debug(f'HTTP {resp.status}: {resp.body}')
|
||||
raise e from None
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def fetch_nodeinfo(domain):
|
||||
if not Client:
|
||||
raise ValueError('Please set global client with HttpRequestsClient.set_global()')
|
||||
|
||||
webfinger = Client.request(f'https://{domain}/.well-known/nodeinfo')
|
||||
webfinger_data = DotDict(webfinger.body)
|
||||
|
||||
for link in webfinger.json.links:
|
||||
if link['rel'] == 'http://nodeinfo.diaspora.software/ns/schema/2.0':
|
||||
nodeinfo_url = link['href']
|
||||
break
|
||||
|
||||
nodeinfo = Client.request(nodeinfo_url)
|
||||
return nodeinfo.json
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def fetch_webfinger_account(handle, domain):
|
||||
if not Client:
|
||||
raise ValueError('Please set global client with HttpRequestsClient.set_global()')
|
||||
|
||||
data = DefaultDotDict()
|
||||
webfinger = Client.request(f'https://{domain}/.well-known/webfinger?resource=acct:{handle}@{domain}')
|
||||
|
||||
if not webfinger.body:
|
||||
raise ValueError('Webfinger body empty')
|
||||
|
||||
data.handle, data.domain = webfinger.json.subject.replace('acct:', '').split('@')
|
||||
|
||||
for link in webfinger.json.links:
|
||||
if link['rel'] == 'self' and link['type'] == 'application/activity+json':
|
||||
data.actor = link['href']
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def generate_rsa_key():
|
||||
privkey = RSA.generate(2048)
|
||||
|
||||
key = DotDict({'PRIVKEY': privkey, 'PUBKEY': privkey.publickey()})
|
||||
key.update({'privkey': key.PRIVKEY.export_key().decode(), 'pubkey': key.PUBKEY.export_key().decode()})
|
||||
|
||||
return key
|
||||
|
||||
|
||||
def parse_signature(signature: str):
|
||||
if not signature:
|
||||
return
|
||||
raise AssertionError('Missing signature header')
|
||||
|
||||
split_sig = signature.split(',')
|
||||
sig = DefaultDotDict()
|
||||
|
||||
for part in split_sig:
|
||||
key, value = part.split('=', 1)
|
||||
sig[key.lower()] = value.replace('"', '')
|
||||
|
||||
sig.headers = sig.headers.split()
|
||||
sig.domain = urlparse(sig.keyid).netloc
|
||||
sig.top_domain = '.'.join(extract(sig.domain)[1:])
|
||||
|
||||
return sig
|
||||
|
||||
|
||||
def verify_headers(headers: dict, method: str, path: str, actor: dict=None, body=None):
|
||||
'''Verify a header signature
|
||||
|
||||
headers: A dictionary containing all the headers from a request
|
||||
method: The HTTP method of the request
|
||||
path: The path of the HTTP request
|
||||
actor (optional): A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
||||
body (optional): The body of the request. Only needed if the signature includes the digest header
|
||||
fail (optional): If set to True, raise an error instead of returning False if any step of the process fails
|
||||
'''
|
||||
|
||||
headers = {k.lower(): v for k,v in headers.items()}
|
||||
headers['(request-target)'] = f'{method.lower()} {path}'
|
||||
signature = parse_signature(headers.get('signature'))
|
||||
digest = headers.get('digest')
|
||||
missing_headers = [k for k in headers if k in ['date', 'host'] if headers.get(k) == None]
|
||||
|
||||
if not signature:
|
||||
raise AssertionError('Missing signature')
|
||||
|
||||
if not actor:
|
||||
actor = fetch_actor(signature.keyid)
|
||||
print(actor)
|
||||
|
||||
## Add digest header to missing headers list if it doesn't exist
|
||||
if method.lower() == 'post' and not digest:
|
||||
missing_headers.append('digest')
|
||||
|
||||
## Fail if missing date, host or digest (if POST) headers
|
||||
if missing_headers:
|
||||
raise AssertionError(f'Missing headers: {missing_headers}')
|
||||
|
||||
## Fail if body verification fails
|
||||
if digest:
|
||||
digest_hash = parse_body_digest(headers.get('digest'))
|
||||
|
||||
if not verify_string(body, digest_hash.sig, digest_hash.alg):
|
||||
raise AssertionError('Failed body digest verification')
|
||||
|
||||
pubkey = actor.publicKey['publicKeyPem']
|
||||
|
||||
return sign_pkcs_headers(pubkey, {k:v for k,v in headers.items() if k in signature.headers}, sig=signature)
|
||||
|
||||
|
||||
async def verify_request(request, actor: dict=None):
|
||||
'''Verify a header signature from a SimpleASGI request
|
||||
|
||||
request: The request with the headers to verify
|
||||
actor: A dictionary containing the activitypub actor and the link to the pubkey used for verification
|
||||
'''
|
||||
|
||||
body = (await request.body) if request.body else None
|
||||
headers = {k.lower(): v[0] for k,v in request.headers.items()}
|
||||
return verify_headers(headers, request.method, request.path, actor, body)
|
||||
|
||||
|
||||
|
||||
### Helper functions that shouldn't be used directly ###
|
||||
def parse_body_digest(digest):
|
||||
if not digest:
|
||||
raise AssertionError('Empty digest')
|
||||
|
||||
parsed = DotDict()
|
||||
alg, sig = digest.split('=', 1)
|
||||
|
||||
parsed.sig = sig
|
||||
parsed.alg = alg.replace('-', '')
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def sign_pkcs_headers(key: str, headers: dict, sig=None):
|
||||
if sig:
|
||||
head_items = [f'{item}: {headers[item]}' for item in sig.headers]
|
||||
|
||||
else:
|
||||
head_items = [f'{k.lower()}: {v}' for k,v in headers.items()]
|
||||
|
||||
head_string = '\n'.join(head_items)
|
||||
head_bytes = head_string.encode('UTF-8')
|
||||
|
||||
KEY = RSA.importKey(key)
|
||||
pkcs = PKCS1_v1_5.new(KEY)
|
||||
h = SHA256.new(head_bytes)
|
||||
|
||||
if sig:
|
||||
return pkcs.verify(h, b64decode(sig.signature))
|
||||
|
||||
else:
|
||||
return pkcs.sign(h)
|
||||
|
||||
|
||||
def sign_request(request, privkey, keyid):
|
||||
request.add_header('(request-target)', f'{request.method.lower()} {request.path}')
|
||||
request.add_header('host', request.host)
|
||||
request.add_header('date', datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'))
|
||||
|
||||
if request.body:
|
||||
body_hash = b64encode(SHA256.new(request.body).digest()).decode("UTF-8")
|
||||
request.add_header('digest', f'SHA-256={body_hash}')
|
||||
request.add_header('content-length', len(request.body))
|
||||
|
||||
sig = {
|
||||
'keyId': keyid,
|
||||
'algorithm': 'rsa-sha256',
|
||||
'headers': ' '.join([k.lower() for k in request.headers.keys()]),
|
||||
'signature': b64encode(sign_pkcs_headers(privkey, request.headers)).decode('UTF-8')
|
||||
}
|
||||
|
||||
sig_items = [f'{k}="{v}"' for k,v in sig.items()]
|
||||
sig_string = ','.join(sig_items)
|
||||
|
||||
request.add_header('signature', sig_string)
|
||||
|
||||
request.remove_header('(request-target)')
|
||||
request.remove_header('host')
|
||||
|
||||
|
||||
def verify_string(string, enc_string, alg='SHA256', fail=False):
|
||||
if type(string) != bytes:
|
||||
string = string.encode('UTF-8')
|
||||
|
||||
body_hash = b64encode(SHA256.new(string).digest()).decode('UTF-8')
|
||||
|
||||
if body_hash == enc_string:
|
||||
return True
|
||||
|
||||
if fail:
|
||||
raise AssertionError('String failed validation')
|
||||
|
||||
else:
|
||||
return False
|
|
@ -40,13 +40,16 @@ class SqlDatabase:
|
|||
engine_args = []
|
||||
engine_kwargs = {}
|
||||
|
||||
if not kwargs.get('database'):
|
||||
raise KeyError('Database argument is not set')
|
||||
if not kwargs.get('name'):
|
||||
raise KeyError('Database "name" is not set')
|
||||
|
||||
engine_string = dbtype + '://'
|
||||
|
||||
if dbtype == 'sqlite':
|
||||
database = kwargs.get('database')
|
||||
try:
|
||||
database = kwargs['name']
|
||||
except KeyError:
|
||||
database = kwargs['database']
|
||||
|
||||
if nfs_check(database):
|
||||
izzylog.error('Database file is on an NFS share which does not support locking. Any writes to the database will fail')
|
||||
|
@ -240,6 +243,9 @@ class SqlSession(object):
|
|||
|
||||
def update(self, table=None, rowid=None, row=None, return_row=False, **data):
|
||||
if row:
|
||||
if not getattr(row, '_table_name', None):
|
||||
print(row)
|
||||
print(dir(row))
|
||||
rowid = row.id
|
||||
table = row._table_name
|
||||
|
||||
|
@ -344,15 +350,13 @@ class CustomRows(object):
|
|||
|
||||
|
||||
def __init__(self, table, row, session):
|
||||
if not row:
|
||||
return
|
||||
|
||||
super().__init__()
|
||||
|
||||
try:
|
||||
self._update(row._asdict())
|
||||
except:
|
||||
self._update(row)
|
||||
if row:
|
||||
try:
|
||||
self._update(row._asdict())
|
||||
except:
|
||||
self._update(row)
|
||||
|
||||
self._db = session.db
|
||||
self._table_name = table
|
||||
|
|
Loading…
Reference in a new issue