From 1ecbe8aed7d100a47877b167e52a3fe5d164af6a Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Fri, 4 Oct 2024 19:48:50 +0200 Subject: [PATCH] mod: Move signature checking for all chain in a function --- .../vm/orchestrator/views/authentication.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/aleph/vm/orchestrator/views/authentication.py b/src/aleph/vm/orchestrator/views/authentication.py index 97c54655..41966207 100644 --- a/src/aleph/vm/orchestrator/views/authentication.py +++ b/src/aleph/vm/orchestrator/views/authentication.py @@ -11,7 +11,7 @@ import json import logging from collections.abc import Awaitable, Callable, Coroutine -from typing import Any, Literal, Union +from typing import Any, Literal import cryptography.exceptions import pydantic @@ -22,7 +22,7 @@ from jwcrypto import jwk from jwcrypto.jwa import JWA from nacl.exceptions import BadSignatureError -from pydantic import BaseModel, Field, ValidationError, root_validator, validator +from pydantic import BaseModel, ValidationError, root_validator, validator from solathon.utils import verify_signature from aleph.vm.conf import settings @@ -40,7 +40,7 @@ def is_token_still_valid(datestr: str): return expiry_datetime > current_datetime -def verify_wallet_signature(signature, message, address): +def verify_eth_wallet_signature(signature, message, address): """ Verifies a signature issued by a wallet """ @@ -49,6 +49,21 @@ def verify_wallet_signature(signature, message, address): return computed_address.lower() == address.lower() +def check_wallet_signature_or_raise(address, chain, payload, signature): + if chain == Chain.SOL: + try: + verify_signature(address, signature, payload.hex()) + except BadSignatureError: + msg = "Invalid signature" + raise ValueError(msg) + elif chain == "ETH": + if not verify_eth_wallet_signature(signature, payload.hex(), address): + msg = "Invalid signature" + raise ValueError(msg) + else: + raise ValueError("Unsupported chain") + + class SignedPubKeyPayload(BaseModel): """This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf.""" @@ -101,20 +116,7 @@ def check_signature(cls, values) -> dict[str, bytes]: signature: list = values["signature"] payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) - - if content.chain == Chain.SOL: - - try: - verify_signature(content.address, signature, payload.hex()) - except BadSignatureError: - msg = "Invalid signature" - raise ValueError(msg) - elif content.chain == "ETH": - if not verify_wallet_signature(signature, payload.hex(), content.address): - msg = "Invalid signature" - raise ValueError(msg) - else: - raise ValueError("Unsupported chain") + check_wallet_signature_or_raise(content.address, content.chain, payload, signature) return values @property