Skip to content

Commit

Permalink
added code review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Oct 16, 2024
1 parent 8bac945 commit 7ba72d7
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 43 deletions.
3 changes: 3 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ boltons==24.0.0
cachetools==5.5.0
caproto==1.1.1
certifi==2024.8.30
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.0
click==8.1.7
Expand All @@ -34,6 +35,7 @@ confluent-kafka==2.6.0
contourpy==1.3.0
copier==9.4.0
coverage==7.6.3
cryptography==43.0.1
cycler==0.12.1
dask==2024.9.1
databroker==1.2.5
Expand Down Expand Up @@ -149,6 +151,7 @@ pure_eval==0.2.3
pvxslibs==1.3.2
py==1.11.0
pyasn1==0.6.1
pycparser==2.22
pycryptodome==3.21.0
pydantic==2.9.2
pydantic-settings==2.5.2
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"bluesky-stomp>=0.1.2",
"pyjwt",
"python-multipart",
"cryptography"
]
dynamic = ["version"]
license.file = "LICENSE"
Expand Down
64 changes: 28 additions & 36 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,11 @@ def __init__(
self.oauth: OauthConfig = oauth
self.baseAuthConfig: BaseAuthConfig = baseAuthConfig

def verify_token(
self, token: str, verify_expiration: bool = True
) -> tuple[bool, Exception | None]:
try:
decode = self.decode_jwt(token, verify_expiration)
if decode:
return (True, None)
except jwt.PyJWTError as e:
print(e)
return (False, e)

return (False, Exception("Invalid token"))
def verify_token(self, token: str, verify_expiration: bool = True) -> bool:
decode = self.decode_jwt(token, verify_expiration)
if decode:
return True
return False

Check warning on line 37 in src/blueapi/service/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/authentication.py#L37

Added line #L37 was not covered by tests

def decode_jwt(self, token: str, verify_expiration: bool = True):
signing_key = jwt.PyJWKClient(self.oauth.jwks_uri).get_signing_key_from_jwt(
Expand All @@ -59,15 +52,9 @@ def decode_jwt(self, token: str, verify_expiration: bool = True):
)
return decode

def userInfo(self, token: str) -> tuple[str | None, str | None]:
try:
decode = self.decode_jwt(token)
if decode:
return (decode["name"], decode["fedid"])
else:
return (None, None)
except jwt.PyJWTError as _:
return (None, None)
def userInfo(self, token: str, verify_expiration=True) -> tuple[str, str]:
decode = self.decode_jwt(token, verify_expiration)
return (decode["name"], decode["fedid"])


class TokenManager:
Expand Down Expand Up @@ -158,13 +145,14 @@ def poll_for_token(

def start_device_flow(self) -> None:
if self.token:
valid_token, exception = self.authenticator.verify_token(
self.token["access_token"]
)
if valid_token:
print("Token verified")
return
elif isinstance(exception, jwt.ExpiredSignatureError):
try:
valid_token = self.authenticator.verify_token(
self.token["access_token"]
)
if valid_token:
print("Logged in successfully!")
return
except jwt.ExpiredSignatureError:
if self.refresh_auth_token():
return

Check warning on line 157 in src/blueapi/service/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/authentication.py#L155-L157

Added lines #L155 - L157 were not covered by tests

Expand All @@ -183,14 +171,18 @@ def start_device_flow(self) -> None:
)
auth_token_json = self.poll_for_token(device_code)
if auth_token_json:
print(auth_token_json)
verify, exception = self.authenticator.verify_token(
auth_token_json["access_token"]
)
if verify:
print("Token verified")
self.save_token(auth_token_json)
else:
try:
valid_token = self.authenticator.verify_token(

Check warning on line 175 in src/blueapi/service/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/authentication.py#L172-L175

Added lines #L172 - L175 were not covered by tests
auth_token_json["access_token"]
)
if valid_token:
print("Logged in successfully!")
self.save_token(auth_token_json)
return
except jwt.ExpiredSignatureError:
if self.refresh_auth_token():
return
except Exception:
print("Unauthorized access")
return

Check warning on line 187 in src/blueapi/service/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/authentication.py#L178-L187

Added lines #L178 - L187 were not covered by tests
else:
Expand Down
24 changes: 17 additions & 7 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
RUNNER: WorkerDispatcher | None = None
AUTHENTICATOR: Authenticator | None = None
SWAGGER_CONFIG: dict[str, Any] | None = None
AUTH_URL: str = os.getenv("PKCE_AUTHENTICATION_URL") or ""
TOKEN_URL: str = os.getenv("TOKEN_URL") or ""


def _runner() -> WorkerDispatcher:
Expand Down Expand Up @@ -77,21 +79,29 @@ async def lifespan(app: FastAPI):


oauth_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=os.getenv("PKCE_AUTHENTICATION_URL") or "",
tokenUrl=os.getenv("TOKEN_URL") or "",
refreshUrl=os.getenv("TOKEN_URL") or "",
authorizationUrl=AUTH_URL,
tokenUrl=TOKEN_URL,
refreshUrl=TOKEN_URL,
)


def verify_access_token(access_token: str = Depends(oauth_scheme)):
if AUTHENTICATOR:
_, exception = AUTHENTICATOR.verify_token(access_token)
if exception:
try:
valid_token = AUTHENTICATOR.verify_token(access_token)
if not valid_token:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
except Exception as exception:
raise HTTPException(

Check warning on line 95 in src/blueapi/service/main.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/main.py#L89-L95

Added lines #L89 - L95 were not covered by tests
status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exception)
status_code=status.HTTP_401_UNAUTHORIZED,
) from exception


if TOKEN_URL == "" or AUTH_URL == "":
dependencies = []
else:
dependencies = [Depends(verify_access_token)]

Check warning on line 103 in src/blueapi/service/main.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/main.py#L103

Added line #L103 was not covered by tests

app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
Expand All @@ -104,7 +114,7 @@ def verify_access_token(access_token: str = Depends(oauth_scheme)):
"scopeSeparator": " ",
"scopes": "openid profile offline_access",
},
dependencies=[Depends(verify_access_token)],
dependencies=dependencies,
)


Expand Down
135 changes: 135 additions & 0 deletions tests/unit_tests/service/test_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import os
from http import HTTPStatus
from unittest import TestCase, mock

import jwt
import pytest
from jwt import PyJWTError

from blueapi.config import BaseAuthConfig, CLIAuthConfig, OauthConfig
from blueapi.service.authentication import Authenticator, TokenManager


class TestAuthenticator(TestCase):
@mock.patch("requests.get")
def setUp(self, mock_requests_get):
mock_requests_get.return_value.status_code = 200
mock_requests_get.return_value.json.return_value = {
"device_authorization_endpoint": "https://example.com/device_authorization",
"authorization_endpoint": "https://example.com/authorization",
"token_endpoint": "https://example.com/token",
"issuer": "https://example.com",
"jwks_uri": "https://example.com/.well-known/jwks.json",
"end_session_endpoint": "https://example.com/logout",
}
self.oauth_config = OauthConfig(
oidc_config_url="https://auth.example.com/realms/sample/.well-known/openid-configuration",
)
self.base_auth_config = BaseAuthConfig(
client_id="example_client_id", client_audience="example_audience"
)
self.authenticator = Authenticator(self.oauth_config, self.base_auth_config)

@mock.patch("jwt.decode")
@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt")
def test_verify_token_valid(self, mock_get_signing_key, mock_decode):
decode_retun_value = {"token": "valid_token", "name": "John Doe"}
mock_decode.return_value = decode_retun_value
valid_token = self.authenticator.verify_token(decode_retun_value["token"])
self.assertTrue(valid_token)

@mock.patch("jwt.decode")
@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt")
def test_verify_token_invalid(self, mock_get_signing_key, mock_decode):
mock_decode.side_effect = jwt.ExpiredSignatureError
token = "invalid_token"
with pytest.raises(PyJWTError):
self.authenticator.verify_token(token)

@mock.patch("jwt.decode")
@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt")
def test_user_info(self, mock_get_signing_key, mock_decode):
mock_decode.return_value = {
"name": "John Doe",
"fedid": "12345",
}
name, fedid = self.authenticator.userInfo("valid_token")
self.assertEqual(name, "John Doe")
self.assertEqual(fedid, "12345")


class TestTokenManager(TestCase):
@mock.patch("requests.get")
def setUp(self, mock_requests_get):
mock_requests_get.return_value.status_code = 200
mock_requests_get.return_value.json.return_value = {
"device_authorization_endpoint": "https://example.com/device_authorization",
"authorization_endpoint": "https://example.com/authorization",
"token_endpoint": "https://example.com/token",
"issuer": "https://example.com",
"jwks_uri": "https://example.com/.well-known/jwks.json",
"end_session_endpoint": "https://example.com/logout",
}
self.oauth_config = OauthConfig(
oidc_config_url="https://auth.example.com/realms/sample/.well-known/openid-configuration",
)
self.cli_auth_config = CLIAuthConfig(
client_id="client_id",
client_audience="client_audience",
token_file_path="~/.token",
)
self.token_manager = TokenManager(self.oauth_config, self.cli_auth_config)

@mock.patch("os.path.exists")
@mock.patch("os.remove")
def test_logout(self, mock_remove, mock_exists):
mock_exists.return_value = True
self.token_manager.logout()
mock_remove.assert_called_once_with(
os.path.expanduser(self.cli_auth_config.token_file_path)
)

@mock.patch("requests.post")
def test_refresh_auth_token(self, mock_post):
self.token_manager.token = {"refresh_token": "refresh_token"}
mock_post.return_value.status_code = HTTPStatus.OK
mock_post.return_value.json.return_value = {"access_token": "new_access_token"}
result = self.token_manager.refresh_auth_token()
self.assertTrue(result)

@mock.patch("requests.post")
def test_get_device_code(self, mock_post):
mock_post.return_value.status_code = HTTPStatus.OK
mock_post.return_value.json.return_value = {"device_code": "device_code"}
device_code = self.token_manager.get_device_code()
self.assertEqual(device_code, "device_code")

@mock.patch("requests.post")
def test_poll_for_token(self, mock_post):
mock_post.return_value.status_code = HTTPStatus.OK
mock_post.return_value.json.return_value = {"access_token": "access_token"}
device_code = "device_code"
token = self.token_manager.poll_for_token(device_code)
self.assertEqual(token, {"access_token": "access_token"})

@mock.patch("requests.post")
@mock.patch("time.sleep")
def test_poll_for_token_timeout(self, mock_sleep, mock_post):
mock_post.return_value.status_code = HTTPStatus.BAD_REQUEST
device_code = "device_code"
with self.assertRaises(TimeoutError):
self.token_manager.poll_for_token(
device_code, timeout=1, polling_interval=0.1
)

@mock.patch("requests.post")
@mock.patch("blueapi.service.authentication.Authenticator.verify_token")
def test_start_device_flow(self, mock_verify_token, mock_post):
mock_post.return_value.status_code = HTTPStatus.OK
mock_post.return_value.json.return_value = {
"device_code": "device_code",
"verification_uri_complete": "https://example.com/verify",
}
mock_verify_token.return_value = (True, None)
self.token_manager.start_device_flow()
mock_verify_token.assert_called()
3 changes: 3 additions & 0 deletions tests/unit_tests/worker/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ def begin_task_and_wait_until_complete(
#


@pytest.mark.skip(
"This test is currently waiting for https:/DiamondLightSource/dls-bluesky-core/blob/main/src/dls_bluesky_core/plans/wrapped.py"
)
def test_worker_and_data_events_produce_in_order(worker: TaskWorker) -> None:
assert_running_count_plan_produces_ordered_worker_and_data_events(
[
Expand Down

0 comments on commit 7ba72d7

Please sign in to comment.