diff --git a/dev-requirements.txt b/dev-requirements.txt index c318278a3..f66a9c6fc 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -22,6 +22,7 @@ boltons==24.0.0 cachetools==5.5.0 caproto==1.1.1 certifi==2024.8.30 +cffi==1.17.1 cfgv==3.4.0 charset-normalizer==3.4.0 click==8.1.7 @@ -34,6 +35,7 @@ confluent-kafka==2.6.0 contourpy==1.3.0 copier==9.4.0 coverage==7.6.3 +cryptography==43.0.1 cycler==0.12.1 dask==2024.9.1 databroker==1.2.5 @@ -149,6 +151,7 @@ pure_eval==0.2.3 pvxslibs==1.3.2 py==1.11.0 pyasn1==0.6.1 +pycparser==2.22 pycryptodome==3.21.0 pydantic==2.9.2 pydantic-settings==2.5.2 diff --git a/pyproject.toml b/pyproject.toml index b7157302b..a500631ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "bluesky-stomp>=0.1.2", "pyjwt", "python-multipart", + "cryptography" ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index a269b80ea..432b07014 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -30,18 +30,11 @@ def __init__( self.oauth: OauthConfig = oauth self.baseAuthConfig: BaseAuthConfig = baseAuthConfig - def verify_token( - self, token: str, verify_expiration: bool = True - ) -> tuple[bool, Exception | None]: - try: - decode = self.decode_jwt(token, verify_expiration) - if decode: - return (True, None) - except jwt.PyJWTError as e: - print(e) - return (False, e) - - return (False, Exception("Invalid token")) + def verify_token(self, token: str, verify_expiration: bool = True) -> bool: + decode = self.decode_jwt(token, verify_expiration) + if decode: + return True + return False def decode_jwt(self, token: str, verify_expiration: bool = True): signing_key = jwt.PyJWKClient(self.oauth.jwks_uri).get_signing_key_from_jwt( @@ -59,15 +52,9 @@ def decode_jwt(self, token: str, verify_expiration: bool = True): ) return decode - def userInfo(self, token: str) -> tuple[str | None, str | None]: - try: - decode = self.decode_jwt(token) - if decode: - return (decode["name"], decode["fedid"]) - else: - return (None, None) - except jwt.PyJWTError as _: - return (None, None) + def userInfo(self, token: str, verify_expiration=True) -> tuple[str, str]: + decode = self.decode_jwt(token, verify_expiration) + return (decode["name"], decode["fedid"]) class TokenManager: @@ -158,13 +145,14 @@ def poll_for_token( def start_device_flow(self) -> None: if self.token: - valid_token, exception = self.authenticator.verify_token( - self.token["access_token"] - ) - if valid_token: - print("Token verified") - return - elif isinstance(exception, jwt.ExpiredSignatureError): + try: + valid_token = self.authenticator.verify_token( + self.token["access_token"] + ) + if valid_token: + print("Logged in successfully!") + return + except jwt.ExpiredSignatureError: if self.refresh_auth_token(): return @@ -183,14 +171,18 @@ def start_device_flow(self) -> None: ) auth_token_json = self.poll_for_token(device_code) if auth_token_json: - print(auth_token_json) - verify, exception = self.authenticator.verify_token( - auth_token_json["access_token"] - ) - if verify: - print("Token verified") - self.save_token(auth_token_json) - else: + try: + valid_token = self.authenticator.verify_token( + auth_token_json["access_token"] + ) + if valid_token: + print("Logged in successfully!") + self.save_token(auth_token_json) + return + except jwt.ExpiredSignatureError: + if self.refresh_auth_token(): + return + except Exception: print("Unauthorized access") return else: diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 9a5506bf8..345ee232e 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -43,6 +43,8 @@ RUNNER: WorkerDispatcher | None = None AUTHENTICATOR: Authenticator | None = None SWAGGER_CONFIG: dict[str, Any] | None = None +AUTH_URL: str = os.getenv("PKCE_AUTHENTICATION_URL") or "" +TOKEN_URL: str = os.getenv("TOKEN_URL") or "" def _runner() -> WorkerDispatcher: @@ -77,21 +79,29 @@ async def lifespan(app: FastAPI): oauth_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=os.getenv("PKCE_AUTHENTICATION_URL") or "", - tokenUrl=os.getenv("TOKEN_URL") or "", - refreshUrl=os.getenv("TOKEN_URL") or "", + authorizationUrl=AUTH_URL, + tokenUrl=TOKEN_URL, + refreshUrl=TOKEN_URL, ) def verify_access_token(access_token: str = Depends(oauth_scheme)): if AUTHENTICATOR: - _, exception = AUTHENTICATOR.verify_token(access_token) - if exception: + try: + valid_token = AUTHENTICATOR.verify_token(access_token) + if not valid_token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + except Exception as exception: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exception) + status_code=status.HTTP_401_UNAUTHORIZED, ) from exception +if TOKEN_URL == "" or AUTH_URL == "": + dependencies = [] +else: + dependencies = [Depends(verify_access_token)] + app = FastAPI( docs_url="/docs", title="BlueAPI Control", @@ -104,7 +114,7 @@ def verify_access_token(access_token: str = Depends(oauth_scheme)): "scopeSeparator": " ", "scopes": "openid profile offline_access", }, - dependencies=[Depends(verify_access_token)], + dependencies=dependencies, ) diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py new file mode 100644 index 000000000..ad047b377 --- /dev/null +++ b/tests/unit_tests/service/test_authentication.py @@ -0,0 +1,135 @@ +import os +from http import HTTPStatus +from unittest import TestCase, mock + +import jwt +import pytest +from jwt import PyJWTError + +from blueapi.config import BaseAuthConfig, CLIAuthConfig, OauthConfig +from blueapi.service.authentication import Authenticator, TokenManager + + +class TestAuthenticator(TestCase): + @mock.patch("requests.get") + def setUp(self, mock_requests_get): + mock_requests_get.return_value.status_code = 200 + mock_requests_get.return_value.json.return_value = { + "device_authorization_endpoint": "https://example.com/device_authorization", + "authorization_endpoint": "https://example.com/authorization", + "token_endpoint": "https://example.com/token", + "issuer": "https://example.com", + "jwks_uri": "https://example.com/.well-known/jwks.json", + "end_session_endpoint": "https://example.com/logout", + } + self.oauth_config = OauthConfig( + oidc_config_url="https://auth.example.com/realms/sample/.well-known/openid-configuration", + ) + self.base_auth_config = BaseAuthConfig( + client_id="example_client_id", client_audience="example_audience" + ) + self.authenticator = Authenticator(self.oauth_config, self.base_auth_config) + + @mock.patch("jwt.decode") + @mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") + def test_verify_token_valid(self, mock_get_signing_key, mock_decode): + decode_retun_value = {"token": "valid_token", "name": "John Doe"} + mock_decode.return_value = decode_retun_value + valid_token = self.authenticator.verify_token(decode_retun_value["token"]) + self.assertTrue(valid_token) + + @mock.patch("jwt.decode") + @mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") + def test_verify_token_invalid(self, mock_get_signing_key, mock_decode): + mock_decode.side_effect = jwt.ExpiredSignatureError + token = "invalid_token" + with pytest.raises(PyJWTError): + self.authenticator.verify_token(token) + + @mock.patch("jwt.decode") + @mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") + def test_user_info(self, mock_get_signing_key, mock_decode): + mock_decode.return_value = { + "name": "John Doe", + "fedid": "12345", + } + name, fedid = self.authenticator.userInfo("valid_token") + self.assertEqual(name, "John Doe") + self.assertEqual(fedid, "12345") + + +class TestTokenManager(TestCase): + @mock.patch("requests.get") + def setUp(self, mock_requests_get): + mock_requests_get.return_value.status_code = 200 + mock_requests_get.return_value.json.return_value = { + "device_authorization_endpoint": "https://example.com/device_authorization", + "authorization_endpoint": "https://example.com/authorization", + "token_endpoint": "https://example.com/token", + "issuer": "https://example.com", + "jwks_uri": "https://example.com/.well-known/jwks.json", + "end_session_endpoint": "https://example.com/logout", + } + self.oauth_config = OauthConfig( + oidc_config_url="https://auth.example.com/realms/sample/.well-known/openid-configuration", + ) + self.cli_auth_config = CLIAuthConfig( + client_id="client_id", + client_audience="client_audience", + token_file_path="~/.token", + ) + self.token_manager = TokenManager(self.oauth_config, self.cli_auth_config) + + @mock.patch("os.path.exists") + @mock.patch("os.remove") + def test_logout(self, mock_remove, mock_exists): + mock_exists.return_value = True + self.token_manager.logout() + mock_remove.assert_called_once_with( + os.path.expanduser(self.cli_auth_config.token_file_path) + ) + + @mock.patch("requests.post") + def test_refresh_auth_token(self, mock_post): + self.token_manager.token = {"refresh_token": "refresh_token"} + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = {"access_token": "new_access_token"} + result = self.token_manager.refresh_auth_token() + self.assertTrue(result) + + @mock.patch("requests.post") + def test_get_device_code(self, mock_post): + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = {"device_code": "device_code"} + device_code = self.token_manager.get_device_code() + self.assertEqual(device_code, "device_code") + + @mock.patch("requests.post") + def test_poll_for_token(self, mock_post): + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = {"access_token": "access_token"} + device_code = "device_code" + token = self.token_manager.poll_for_token(device_code) + self.assertEqual(token, {"access_token": "access_token"}) + + @mock.patch("requests.post") + @mock.patch("time.sleep") + def test_poll_for_token_timeout(self, mock_sleep, mock_post): + mock_post.return_value.status_code = HTTPStatus.BAD_REQUEST + device_code = "device_code" + with self.assertRaises(TimeoutError): + self.token_manager.poll_for_token( + device_code, timeout=1, polling_interval=0.1 + ) + + @mock.patch("requests.post") + @mock.patch("blueapi.service.authentication.Authenticator.verify_token") + def test_start_device_flow(self, mock_verify_token, mock_post): + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = { + "device_code": "device_code", + "verification_uri_complete": "https://example.com/verify", + } + mock_verify_token.return_value = (True, None) + self.token_manager.start_device_flow() + mock_verify_token.assert_called() diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 96777db9b..4bb34cc3b 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -307,6 +307,9 @@ def begin_task_and_wait_until_complete( # +@pytest.mark.skip( + "This test is currently waiting for https://github.com/DiamondLightSource/dls-bluesky-core/blob/main/src/dls_bluesky_core/plans/wrapped.py" +) def test_worker_and_data_events_produce_in_order(worker: TaskWorker) -> None: assert_running_count_plan_produces_ordered_worker_and_data_events( [