simplify private key check and enable validation exception test

This commit is contained in:
Izalia Mae 2023-12-24 05:01:45 -05:00
parent e1bf6d7256
commit 0dbc97a98c
2 changed files with 21 additions and 23 deletions

View file

@ -23,21 +23,21 @@ except ImportError:
BaseRequest = None
def _check_private(private: bool) -> Callable:
name = "private" if private else "public"
def _check_private(func: Callable) -> Callable:
"Checks if the key is a private key before running a method"
def outer(func: Callable):
@wraps(func) # noqa: ANN201
def wrapper(key: str, *args: Any, **kwargs: Any) -> Callable:
if key.is_private != private:
raise TypeError(f"Cannot use method '{func.__name__}' on Signer with {name} key")
@wraps(func) # noqa: ANN201
def wrapper(key: Signer, *args: Any, **kwargs: Any) -> Any:
if not key.is_private:
raise TypeError(f"Cannot use method '{func.__name__}' on Signer with public key")
return func(key, *args, **kwargs)
return wrapper
return outer
return func(key, *args, **kwargs)
return wrapper
def _check_key_type(*types: KeyType) -> Callable:
"Checks if the key is the correct type before running the method"
types = [KeyType.parse(type) for type in types]
def outer(func: Callable):
@ -215,7 +215,7 @@ class Signer:
@property
@_check_private(True)
@_check_private
def pubkey(self) -> str:
"Export the public key to a str"
key = self.key.public_key().export_key(format="PEM")
@ -244,7 +244,7 @@ class Signer:
return key
@_check_private(True)
@_check_private
def sign_headers(self, method: str, url: str, data: dict[str, Any] = None,
headers: dict[str, str] = None, sign_all: bool = False,
algorithm: AlgorithmType = AlgorithmType.HS2019) -> dict[str, str]:
@ -326,7 +326,7 @@ class Signer:
return headers
@_check_private(True)
@_check_private
def sign_request(self, request: Request, algorithm: AlgorithmType = AlgorithmType.HS2019) -> Any:
"""
Convenience function to sign a request. Support for more Request classes planned.
@ -356,7 +356,6 @@ class Signer:
return request
@_check_private(False)
def validate_signature(self, method: str, path: str, headers: dict[str, Any],
body: Optional[bytes | str] = None) -> bool:
"""
@ -417,7 +416,6 @@ class Signer:
if BaseRequest is not None:
@_check_private(False)
async def validate_aiohttp_request(self, request: BaseRequest) -> bool:
"""
Validate the signature header of an AIOHTTP server request object

View file

@ -36,12 +36,12 @@ class SignerTest(unittest.TestCase):
# fails because of _check_private
# def test_sign_exception(self):
# output = signer.sign_headers("GET", url, headers={'date': date})
# data = {
# 'date': 'Fri, 25 Nov 2023 06:09:42 GMT',
# 'signature': "keyId=\"https://social.example.com/users/merpinator#main-key\",algorithm=\"hs2019\",headers=\"(request-target) host date (created) (expires)\",created=\"1669374582\",expires=\"1669396182\",signature=\"EpEvJ3N1mVvQYAfM05UxLuAF5TGCv59dYxcI34TFTEHqOr/2wDnQLQTgE1+SpFi8k8zUCPChvT/M4OneSZTQjV8JAa0YGA1g70kq6miQj0vXTMHZOGnBSgUko/B0Io72lLP+Kj+LMLUnQ9WpYvbmlwoilptAD52hLitYqQoqyeEdn8Zm7IVBGc46VVBO0iZgLAofp0WFvpyzLqCxByvGBv2IJ083QBlgN88k5GFkQYyRK9iOKT2+00rKlKCzGjrvjxrFWQ8ZLdZJJxA9BmbRYq4avwyd6kOTk/hWzRSm60doCsR+fXTNiRNteNhhOUJMpugzak/7HFus1vVzvNO91w==\""
# }
def test_verify_exception(self):
output = signer.sign_headers("GET", url, headers={'date': date})
data = {
'date': 'Fri, 25 Nov 2023 06:09:42 GMT',
'signature': "keyId=\"https://social.example.com/users/merpinator#main-key\",algorithm=\"hs2019\",headers=\"(request-target) host date (created) (expires)\",created=\"1669374582\",expires=\"1669396182\",signature=\"EpEvJ3N1mVvQYAfM05UxLuAF5TGCv59dYxcI34TFTEHqOr/2wDnQLQTgE1+SpFi8k8zUCPChvT/M4OneSZTQjV8JAa0YGA1g70kq6miQj0vXTMHZOGnBSgUko/B0Io72lLP+Kj+LMLUnQ9WpYvbmlwoilptAD52hLitYqQoqyeEdn8Zm7IVBGc46VVBO0iZgLAofp0WFvpyzLqCxByvGBv2IJ083QBlgN88k5GFkQYyRK9iOKT2+00rKlKCzGjrvjxrFWQ8ZLdZJJxA9BmbRYq4avwyd6kOTk/hWzRSm60doCsR+fXTNiRNteNhhOUJMpugzak/7HFus1vVzvNO91w==\""
}
# with self.assertRaises(aputils.SignatureFailureError):
# signer.validate_signature("GET", "/actor", {"date": date})
with self.assertRaises(aputils.SignatureFailureError):
signer.validate_signature("GET", "/actor", {"date": date})