Multiple signer changes
* Fix several issues with signing and verifying signatures * Rename `AlgorithmType.ORIGINAL` to `AlgorithmType.RSASHA256` * Deprecate `Digest.new_from_digest` and add a `Digest.parse` class method * Swap the positions of the `algorithm` and `sign_all` parameters in `Signer.sign_headers`
This commit is contained in:
parent
98bcdb73eb
commit
36147cccb5
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import typing
|
||||
|
||||
from Crypto import Hash
|
||||
|
@ -11,10 +12,13 @@ from datetime import datetime, timedelta
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from .enums import AlgorithmType
|
||||
from .errors import SignatureFailureError
|
||||
from .misc import Digest, HttpDate, Signature
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from .message import Message
|
||||
from .signer import Signer
|
||||
|
||||
class DataHash(typing.Protocol):
|
||||
|
@ -60,9 +64,10 @@ class Algorithm(ABC):
|
|||
@abstractmethod
|
||||
def process_headers(
|
||||
method: str,
|
||||
url: str,
|
||||
host: str | None,
|
||||
path: str,
|
||||
headers: dict[str, str],
|
||||
body: bytes | None = None) -> dict[str, str]:
|
||||
body: Message | dict[str, Any] | bytes | str | None = None) -> dict[str, str]:
|
||||
...
|
||||
|
||||
|
||||
|
@ -89,12 +94,10 @@ class Algorithm(ABC):
|
|||
url: str,
|
||||
headers: dict[str, str],
|
||||
used_headers: Sequence | None = None,
|
||||
body: bytes | str | None = None) -> dict[str, str]:
|
||||
body: Message | dict[str, Any] | bytes | str | None = None) -> dict[str, str]:
|
||||
|
||||
if body is not None and not isinstance(body, bytes):
|
||||
body = bytes(body, "utf-8")
|
||||
|
||||
headers = type(self).process_headers(method, url, headers, body)
|
||||
uri = urlparse(url)
|
||||
headers = type(self).process_headers(method, uri.hostname, uri.path, headers, body)
|
||||
used_headers = tuple([*headers.keys(), *(used_headers or [])])
|
||||
hash_bytes, used_headers = type(self).build_signing_string(headers, used_headers)
|
||||
data_hash = self.hash_data(hash_bytes)
|
||||
|
@ -126,10 +129,18 @@ class Algorithm(ABC):
|
|||
method: str,
|
||||
path: str,
|
||||
headers: dict[str, str],
|
||||
signature: Signature) -> bool:
|
||||
signature: Signature,
|
||||
body: Message | dict[str, Any] | bytes | str | None = None) -> bool:
|
||||
|
||||
headers = type(self).process_headers(method, None, path, headers, body)
|
||||
|
||||
if (digest := Digest.parse(headers["digest"])):
|
||||
if body is None:
|
||||
raise SignatureFailureError("A digest was added with an empty body")
|
||||
|
||||
if not digest.validate(body):
|
||||
raise SignatureFailureError("Body digest does not match")
|
||||
|
||||
url = f"https://{headers['host']}/{path}"
|
||||
headers = type(self).process_headers(method, url, headers)
|
||||
sig_hash, _ = type(self).build_signing_string(headers, signature.headers)
|
||||
|
||||
return self.verify_data(
|
||||
|
@ -168,14 +179,27 @@ class HS2019(Algorithm):
|
|||
@staticmethod
|
||||
def process_headers(
|
||||
method: str,
|
||||
url: str,
|
||||
host: str | None,
|
||||
path: str,
|
||||
headers: dict[str, str],
|
||||
body: bytes | None = None) -> dict[str, str]:
|
||||
body: Message | dict[str, Any] | bytes | str | None = None) -> dict[str, str]:
|
||||
|
||||
parsed = urlparse(url)
|
||||
from .message import Message
|
||||
|
||||
if body is not None:
|
||||
if isinstance(body, Message):
|
||||
body = body.to_json()
|
||||
|
||||
elif isinstance(body, dict):
|
||||
body = json.dumps(body)
|
||||
|
||||
if not isinstance(body, bytes):
|
||||
body = bytes(body, "utf-8")
|
||||
|
||||
headers = {key.lower(): value for key, value in headers.items()}
|
||||
headers["host"] = parsed.netloc
|
||||
|
||||
if host:
|
||||
headers["host"] = host
|
||||
|
||||
date: HttpDate | datetime | str = headers.get("date", HttpDate.new_utc())
|
||||
|
||||
|
@ -190,11 +214,13 @@ class HS2019(Algorithm):
|
|||
|
||||
headers.update({
|
||||
"date": date.to_string(),
|
||||
"(request-target)": f"{method.lower()} {parsed.path}",
|
||||
"(request-target)": f"{method.lower()} {path}",
|
||||
"(created)": str(date.timestamp()),
|
||||
"(expires)": str((date + timedelta(hours=6)).timestamp())
|
||||
})
|
||||
|
||||
print(method, path, repr(body))
|
||||
|
||||
if body is not None:
|
||||
headers.update({
|
||||
"digest": Digest.new(body).compile(),
|
||||
|
@ -206,17 +232,18 @@ class HS2019(Algorithm):
|
|||
|
||||
@register
|
||||
class RsaSha256(HS2019):
|
||||
algo_type: AlgorithmType = AlgorithmType.ORIGINAL
|
||||
algo_type: AlgorithmType = AlgorithmType.RSASHA256
|
||||
|
||||
|
||||
@staticmethod
|
||||
def process_headers(
|
||||
method: str,
|
||||
url: str,
|
||||
host: str | None,
|
||||
path: str,
|
||||
headers: dict[str, str],
|
||||
body: bytes | None = None) -> dict[str, str]:
|
||||
body: Message | dict[str, Any] | bytes | str | None = None) -> dict[str, str]:
|
||||
|
||||
headers = HS2019.process_headers(method, url, headers)
|
||||
headers = HS2019.process_headers(method, host, path, headers, body)
|
||||
|
||||
del headers["(created)"]
|
||||
del headers["(expires)"]
|
||||
|
@ -229,15 +256,16 @@ class RsaSha256(HS2019):
|
|||
url: str,
|
||||
headers: dict[str, str],
|
||||
used_headers: Sequence | None = None,
|
||||
body: bytes | str | None = None) -> dict[str, str]:
|
||||
body: Message | dict[str, Any] | bytes | str | None = None) -> dict[str, str]:
|
||||
|
||||
return HS2019.sign_headers(self, method, url, headers, used_headers, body)
|
||||
|
||||
|
||||
def verify_headers(self,
|
||||
method: str,
|
||||
url: str,
|
||||
path: str,
|
||||
headers: dict[str, str],
|
||||
signature: Signature) -> bool:
|
||||
signature: Signature,
|
||||
body: Message | dict[str, Any] | bytes | str | None = None) -> bool:
|
||||
|
||||
return HS2019.verify_headers(self, method, url, headers, signature)
|
||||
return HS2019.verify_headers(self, method, path, headers, signature, body)
|
||||
|
|
|
@ -64,12 +64,15 @@ class StrEnum(str, Enum):
|
|||
class AlgorithmType(StrEnum):
|
||||
"Algorithm type"
|
||||
|
||||
ORIGINAL = "rsa-sha256"
|
||||
RSASHA256 = "rsa-sha256"
|
||||
"Old deprecated signing standard (still in use in the fediverse)"
|
||||
|
||||
HS2019 = "hs2019"
|
||||
"Current signing standard"
|
||||
|
||||
ORIGINAL = RSASHA256
|
||||
"Alias for RSASHA256 (will be removed in 0.3.0)"
|
||||
|
||||
|
||||
class KeyType(StrEnum):
|
||||
"Type of private or public key"
|
||||
|
|
|
@ -115,17 +115,30 @@ class Digest:
|
|||
|
||||
|
||||
@classmethod
|
||||
def parse(cls: type[Digest], digest: str | None) -> Digest | None:
|
||||
"""
|
||||
Create a new digest from a digest header
|
||||
|
||||
:param digest: Digest header
|
||||
"""
|
||||
|
||||
if not digest:
|
||||
return None
|
||||
|
||||
alg, digest = digest.split("=", 1)
|
||||
return cls(digest, alg)
|
||||
|
||||
|
||||
@classmethod
|
||||
@deprecated("Digest.parse", "0.1.9", "0.3.0")
|
||||
def new_from_digest(cls: type[Digest], digest: str | None) -> Digest | None:
|
||||
"""
|
||||
Create a new digest from a digest header
|
||||
|
||||
:param digest: Digest header
|
||||
"""
|
||||
if not digest:
|
||||
return None
|
||||
|
||||
alg, digest = digest.split("=", 1)
|
||||
return cls(digest, alg)
|
||||
return cls.parse(digest)
|
||||
|
||||
|
||||
@property
|
||||
|
@ -139,13 +152,14 @@ class Digest:
|
|||
return "=".join([self.algorithm, self.digest])
|
||||
|
||||
|
||||
def validate(self, body: dict | str | bytes, hash_size: int = 256) -> bool:
|
||||
def validate(self, body: dict[str, Any] | str | bytes, hash_size: int = 256) -> bool:
|
||||
"""
|
||||
Check if the body digest matches the body
|
||||
|
||||
:param body: Message body to verify
|
||||
:param hash_size: Size of the hashing algorithm
|
||||
"""
|
||||
|
||||
body_digest = Digest.new(body, hash_size)
|
||||
return self.digest == body_digest.digest
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
|
||||
from functools import wraps
|
||||
|
@ -205,8 +204,8 @@ class Signer:
|
|||
url: str,
|
||||
data: dict[str, Any] | bytes | str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
sign_all: bool = False,
|
||||
algorithm: AlgorithmType = AlgorithmType.HS2019) -> dict[str, Any]:
|
||||
algorithm: AlgorithmType = AlgorithmType.HS2019,
|
||||
sign_all: bool = False) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a signature and return the headers with a "Signature" key
|
||||
|
||||
|
@ -217,18 +216,14 @@ class Signer:
|
|||
:param url: URL of the request
|
||||
:param data: ActivityPub message for a POST request
|
||||
:param headers: Request headers
|
||||
:param sign_all: If ``True``, sign all headers instead of just the required ones
|
||||
:param algorithm: Type of algorithm to use for hashing the headers. HS2019 is the only
|
||||
non-deprecated algorithm.
|
||||
:param sign_all: If ``True``, sign all headers instead of just the required ones
|
||||
"""
|
||||
|
||||
algo = algorithms.get(algorithm)(self)
|
||||
headers = headers or {}
|
||||
used_headers = tuple([]) if not sign_all else tuple(headers)
|
||||
|
||||
if data is not None and isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
|
||||
return algo.sign_headers(method, url, headers, used_headers, data)
|
||||
|
||||
|
||||
|
@ -289,7 +284,7 @@ class Signer:
|
|||
|
||||
algo = algorithms.get(signature.algorithm)(self)
|
||||
|
||||
if not algo.verify_headers(method, path, headers, signature):
|
||||
if not algo.verify_headers(method, path, headers, signature, body):
|
||||
raise SignatureFailureError("Failed to verify signature")
|
||||
|
||||
return True
|
||||
|
@ -302,6 +297,7 @@ class Signer:
|
|||
|
||||
:param request: AioHttp server request to validate
|
||||
"""
|
||||
|
||||
return self.validate_signature(
|
||||
request.method,
|
||||
request.path,
|
||||
|
|
Loading…
Reference in a new issue