Skip to content

Commit

Permalink
added minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Oct 17, 2024
1 parent fba96e6 commit d03b1ec
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 28 deletions.
3 changes: 1 addition & 2 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ def _request_and_deserialize(
headers["Authorization"] = f"Bearer {token['access_token']}"
except jwt.ExpiredSignatureError:
if token := self._session_manager.refresh_auth_token():
if token := self._session_manager.get_token():
headers["Authorization"] = f"Bearer {token['access_token']}"
headers["Authorization"] = f"Bearer {token['access_token']}"
except Exception:
pass
if data:
Expand Down
15 changes: 10 additions & 5 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,19 @@ class CliTokenManager(TokenManager):
def __init__(self, token_file_path: Path) -> None:
self._token_file_path: Path = token_file_path

def _file_path(self) -> str:
return os.path.expanduser(self._token_file_path)

def save_token(self, token: dict[str, Any]) -> None:
token_json: str = json.dumps(token)
token_bytes: bytes = token_json.encode("utf-8")
token_base64: bytes = base64.b64encode(token_bytes)
with open(os.path.expanduser(self._token_file_path), "wb") as token_file:
with open(self._file_path(), "wb") as token_file:
token_file.write(token_base64)

def load_token(self) -> dict[str, Any] | None:
file_path = os.path.expanduser(self._token_file_path)
if not os.path.exists(file_path):
file_path = self._file_path()
if not os.path.exists(self._file_path()):
return None
with open(file_path, "rb") as token_file:
token_base64: bytes = token_file.read()
Expand All @@ -90,8 +93,8 @@ def load_token(self) -> dict[str, Any] | None:
return json.loads(token_json)

def delete_token(self) -> None:
if os.path.exists(os.path.expanduser(self._token_file_path)):
os.remove(os.path.expanduser(self._token_file_path))
if os.path.exists(self._file_path()):
os.remove(self._file_path())


class SessionManager:
Expand Down Expand Up @@ -215,3 +218,5 @@ def start_device_flow(self) -> None:
if valid_token:
self._token_manager.save_token(auth_token_json)
self.authenticator.print_user_info(auth_token_json["access_token"])
else:
print("Failed to login")
19 changes: 11 additions & 8 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@
WorkerTask,
)

load_dotenv()
REST_API_VERSION = "0.0.5"

RUNNER: WorkerDispatcher | None = None
AUTHENTICATOR: Authenticator | None = None
SWAGGER_CONFIG: dict[str, Any] | None = None
AUTH_URL: str = os.getenv("PKCE_AUTHENTICATION_URL") or ""
TOKEN_URL: str = os.getenv("TOKEN_URL") or ""
_PKCE_AUTHENTICATION_URL: str = "PKCE_AUTHENTICATION_URL"
_TOKEN_URL: str = "TOKEN_URL"
_PKCE_CLIENT_ID: str = "PKCE_CLIENT_ID"
_PKCE_CLIENT_SECRET: str = "PKCE_CLIENT_SECRET"
AUTH_URL: str = os.getenv(_PKCE_AUTHENTICATION_URL, "")
TOKEN_URL: str = os.getenv(_TOKEN_URL, "")


def _runner() -> WorkerDispatcher:
Expand Down Expand Up @@ -97,19 +100,19 @@ def verify_access_token(access_token: str = Depends(oauth_scheme)):
) from exception


if TOKEN_URL == "" or AUTH_URL == "":
dependencies = []
else:
if TOKEN_URL and AUTH_URL:
dependencies = [Depends(verify_access_token)]
else:
dependencies = []

app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
lifespan=lifespan,
version=REST_API_VERSION,
swagger_ui_init_oauth={
"clientId": os.getenv("PKCE_CLIENT_ID"),
"clientSecret": os.getenv("PKCE_CLIENT_SECRET"),
"clientId": os.getenv(_PKCE_CLIENT_ID),
"clientSecret": os.getenv(_PKCE_CLIENT_SECRET),
"usePkceWithAuthorizationCodeGrant": True,
"scopeSeparator": " ",
"scopes": "openid profile offline_access",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_refresh_if_signature_expired(rest: BlueapiRestClient):
) as mock_refresh_token,
):
mock_verify_token.side_effect = jwt.ExpiredSignatureError
mock_refresh_token.return_value = True
mock_refresh_token.return_value = {"access_token": "new_token"}
result = rest.get_plans()
assert result == PlanResponse(plans=[PlanModel.from_plan(plan)])

Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
def client() -> Iterator[TestClient]:
with patch("blueapi.service.interface.worker"):
main.setup_runner(use_subprocess=False)
main.app.dependency_overrides[main.verify_access_token] = lambda: None
yield TestClient(main.app)
main.teardown_runner()

Expand Down
14 changes: 3 additions & 11 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
from collections.abc import Mapping
from dataclasses import dataclass
from enum import Enum
from io import StringIO
from pathlib import Path
from textwrap import dedent
Expand Down Expand Up @@ -73,11 +72,6 @@ def test_connection_error_caught_by_wrapper_func(
assert result.stdout == "Failed to establish connection to FastAPI server.\n"


class OidcResponse(Enum):
VERIFICATION_URL = ""
DEVICE_AUTHORIZATION = ""


class MyModel(BaseModel):
id: str

Expand All @@ -88,9 +82,7 @@ class MyDevice:


@responses.activate
def test_get_plans(
runner: CliRunner,
):
def test_get_plans(runner: CliRunner):
plan = Plan(name="my-plan", model=MyModel)

response = responses.add(
Expand All @@ -101,8 +93,8 @@ def test_get_plans(
)

plans = runner.invoke(main, ["controller", "plans"])
assert plans.output == "my-plan\n Args\n id=string (Required)\n"
assert response.call_count == 1
assert plans.output == "my-plan\n Args\n id=string (Required)\n"


@responses.activate
Expand Down Expand Up @@ -825,7 +817,7 @@ def test_login_edge_cases(runner: CliRunner, valid_auth_config: str, tmp_path: P
):
mock_decode.side_effect = jwt.ExpiredSignatureError
result = runner.invoke(main, ["-c", valid_auth_config, "login"])
assert "Logging in\n" == result.output
assert "Logging in\nFailed to login\n" == result.output
assert result.exit_code == 0


Expand Down

0 comments on commit d03b1ec

Please sign in to comment.