simplify private key check and enable validation exception test
This commit is contained in:
parent
e1bf6d7256
commit
0dbc97a98c
|
@ -23,21 +23,21 @@ except ImportError:
|
|||
BaseRequest = None
|
||||
|
||||
|
||||
def _check_private(private: bool) -> Callable:
|
||||
name = "private" if private else "public"
|
||||
def _check_private(func: Callable) -> Callable:
|
||||
"Checks if the key is a private key before running a method"
|
||||
|
||||
def outer(func: Callable):
|
||||
@wraps(func) # noqa: ANN201
|
||||
def wrapper(key: str, *args: Any, **kwargs: Any) -> Callable:
|
||||
if key.is_private != private:
|
||||
raise TypeError(f"Cannot use method '{func.__name__}' on Signer with {name} key")
|
||||
@wraps(func) # noqa: ANN201
|
||||
def wrapper(key: Signer, *args: Any, **kwargs: Any) -> Any:
|
||||
if not key.is_private:
|
||||
raise TypeError(f"Cannot use method '{func.__name__}' on Signer with public key")
|
||||
|
||||
return func(key, *args, **kwargs)
|
||||
return wrapper
|
||||
return outer
|
||||
return func(key, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def _check_key_type(*types: KeyType) -> Callable:
|
||||
"Checks if the key is the correct type before running the method"
|
||||
|
||||
types = [KeyType.parse(type) for type in types]
|
||||
|
||||
def outer(func: Callable):
|
||||
|
@ -215,7 +215,7 @@ class Signer:
|
|||
|
||||
|
||||
@property
|
||||
@_check_private(True)
|
||||
@_check_private
|
||||
def pubkey(self) -> str:
|
||||
"Export the public key to a str"
|
||||
key = self.key.public_key().export_key(format="PEM")
|
||||
|
@ -244,7 +244,7 @@ class Signer:
|
|||
return key
|
||||
|
||||
|
||||
@_check_private(True)
|
||||
@_check_private
|
||||
def sign_headers(self, method: str, url: str, data: dict[str, Any] = None,
|
||||
headers: dict[str, str] = None, sign_all: bool = False,
|
||||
algorithm: AlgorithmType = AlgorithmType.HS2019) -> dict[str, str]:
|
||||
|
@ -326,7 +326,7 @@ class Signer:
|
|||
return headers
|
||||
|
||||
|
||||
@_check_private(True)
|
||||
@_check_private
|
||||
def sign_request(self, request: Request, algorithm: AlgorithmType = AlgorithmType.HS2019) -> Any:
|
||||
"""
|
||||
Convenience function to sign a request. Support for more Request classes planned.
|
||||
|
@ -356,7 +356,6 @@ class Signer:
|
|||
return request
|
||||
|
||||
|
||||
@_check_private(False)
|
||||
def validate_signature(self, method: str, path: str, headers: dict[str, Any],
|
||||
body: Optional[bytes | str] = None) -> bool:
|
||||
"""
|
||||
|
@ -417,7 +416,6 @@ class Signer:
|
|||
|
||||
|
||||
if BaseRequest is not None:
|
||||
@_check_private(False)
|
||||
async def validate_aiohttp_request(self, request: BaseRequest) -> bool:
|
||||
"""
|
||||
Validate the signature header of an AIOHTTP server request object
|
||||
|
|
|
@ -36,12 +36,12 @@ class SignerTest(unittest.TestCase):
|
|||
|
||||
|
||||
# fails because of _check_private
|
||||
# def test_sign_exception(self):
|
||||
# output = signer.sign_headers("GET", url, headers={'date': date})
|
||||
# data = {
|
||||
# 'date': 'Fri, 25 Nov 2023 06:09:42 GMT',
|
||||
# 'signature': "keyId=\"https://social.example.com/users/merpinator#main-key\",algorithm=\"hs2019\",headers=\"(request-target) host date (created) (expires)\",created=\"1669374582\",expires=\"1669396182\",signature=\"EpEvJ3N1mVvQYAfM05UxLuAF5TGCv59dYxcI34TFTEHqOr/2wDnQLQTgE1+SpFi8k8zUCPChvT/M4OneSZTQjV8JAa0YGA1g70kq6miQj0vXTMHZOGnBSgUko/B0Io72lLP+Kj+LMLUnQ9WpYvbmlwoilptAD52hLitYqQoqyeEdn8Zm7IVBGc46VVBO0iZgLAofp0WFvpyzLqCxByvGBv2IJ083QBlgN88k5GFkQYyRK9iOKT2+00rKlKCzGjrvjxrFWQ8ZLdZJJxA9BmbRYq4avwyd6kOTk/hWzRSm60doCsR+fXTNiRNteNhhOUJMpugzak/7HFus1vVzvNO91w==\""
|
||||
# }
|
||||
def test_verify_exception(self):
|
||||
output = signer.sign_headers("GET", url, headers={'date': date})
|
||||
data = {
|
||||
'date': 'Fri, 25 Nov 2023 06:09:42 GMT',
|
||||
'signature': "keyId=\"https://social.example.com/users/merpinator#main-key\",algorithm=\"hs2019\",headers=\"(request-target) host date (created) (expires)\",created=\"1669374582\",expires=\"1669396182\",signature=\"EpEvJ3N1mVvQYAfM05UxLuAF5TGCv59dYxcI34TFTEHqOr/2wDnQLQTgE1+SpFi8k8zUCPChvT/M4OneSZTQjV8JAa0YGA1g70kq6miQj0vXTMHZOGnBSgUko/B0Io72lLP+Kj+LMLUnQ9WpYvbmlwoilptAD52hLitYqQoqyeEdn8Zm7IVBGc46VVBO0iZgLAofp0WFvpyzLqCxByvGBv2IJ083QBlgN88k5GFkQYyRK9iOKT2+00rKlKCzGjrvjxrFWQ8ZLdZJJxA9BmbRYq4avwyd6kOTk/hWzRSm60doCsR+fXTNiRNteNhhOUJMpugzak/7HFus1vVzvNO91w==\""
|
||||
}
|
||||
|
||||
# with self.assertRaises(aputils.SignatureFailureError):
|
||||
# signer.validate_signature("GET", "/actor", {"date": date})
|
||||
with self.assertRaises(aputils.SignatureFailureError):
|
||||
signer.validate_signature("GET", "/actor", {"date": date})
|
||||
|
|
Loading…
Reference in a new issue