From 2143e0670b10d91f96ebc88a243a852da774aad2 Mon Sep 17 00:00:00 2001 From: Dylan Pulver Date: Mon, 1 Jul 2024 12:19:28 -0400 Subject: [PATCH] safety/auth --- safety/auth/cli.py | 117 ++++++++++++------ safety/auth/cli_utils.py | 115 +++++++++++++----- safety/auth/main.py | 147 +++++++++++++++++++---- safety/auth/models.py | 43 ++++++- safety/auth/server.py | 131 ++++++++++++++++----- safety/auth/utils.py | 248 +++++++++++++++++++++++++++++++-------- 6 files changed, 633 insertions(+), 168 deletions(-) diff --git a/safety/auth/cli.py b/safety/auth/cli.py index 187f4878..8a0a2c1a 100644 --- a/safety/auth/cli.py +++ b/safety/auth/cli.py @@ -40,26 +40,37 @@ CMD_LOGOUT_NAME = "logout" DEFAULT_CMD = CMD_LOGIN_NAME -@auth_app.callback(invoke_without_command=True, - cls=SafetyCLISubGroup, - help=CLI_AUTH_COMMAND_HELP, +@auth_app.callback(invoke_without_command=True, + cls=SafetyCLISubGroup, + help=CLI_AUTH_COMMAND_HELP, epilog=DEFAULT_EPILOG, - context_settings={"allow_extra_args": True, + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) @pass_safety_cli_obj -def auth(ctx: typer.Context): +def auth(ctx: typer.Context) -> None: """ - Authenticate Safety CLI with your account + Authenticate Safety CLI with your account. + + Args: + ctx (typer.Context): The Typer context object. """ LOG.info('auth started') + # If no subcommand is invoked, forward to the default command if not ctx.invoked_subcommand: - default_command = get_command_for(name=DEFAULT_CMD, + default_command = get_command_for(name=DEFAULT_CMD, typer_instance=auth_app) return ctx.forward(default_command) -def fail_if_authenticated(ctx, with_msg: str): +def fail_if_authenticated(ctx: typer.Context, with_msg: str) -> None: + """ + Exits the command if the user is already authenticated. + + Args: + ctx (typer.Context): The Typer context object. + with_msg (str): The message to display if authenticated. + """ info = get_auth_info(ctx) if info: @@ -72,10 +83,26 @@ def fail_if_authenticated(ctx, with_msg: str): sys.exit(0) def render_email_note(auth: Auth) -> str: + """ + Renders a note indicating whether email verification is required. + + Args: + auth (Auth): The Auth object. + + Returns: + str: The rendered email note. + """ return "" if auth.email_verified else "[red](email verification required)[/red]" def render_successful_login(auth: Auth, - organization: Optional[str] = None): + organization: Optional[str] = None) -> None: + """ + Renders a message indicating a successful login. + + Args: + auth (Auth): The Auth object. + organization (Optional[str]): The organization name. + """ DEFAULT = "--" name = auth.name if auth.name else DEFAULT email = auth.email if auth.email else DEFAULT @@ -88,46 +115,52 @@ def render_successful_login(auth: Auth, details = [f"[green][bold]Account:[/bold] {email}[/green] {email_note}"] if organization: - details.insert(0, + details.insert(0, "[green][bold]Organization:[/bold] " \ f"{organization}[green]") for msg in details: - console.print(Padding(msg, (0, 0, 0, 1)), emoji=True) + console.print(Padding(msg, (0, 0, 0, 1)), emoji=True) @auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP) -def login(ctx: typer.Context, headless: bool = False): +def login(ctx: typer.Context, headless: bool = False) -> None: """ Authenticate Safety CLI with your safetycli.com account using your default browser. + + Args: + ctx (typer.Context): The Typer context object. + headless (bool): Whether to run in headless mode. """ LOG.info('login started') + # Check if the user is already authenticated fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED) console.print() - + info = None - + brief_msg: str = "Redirecting your browser to log in; once authenticated, " \ "return here to start using Safety" - - if ctx.obj.auth.org: + + if ctx.obj.auth.org: console.print(f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] " \ "organization.") - + if headless: brief_msg = "Running in headless mode. Please copy and open the following URL in a browser" - + # Get authorization data and generate the authorization URL uri, initial_state = get_authorization_data(client=ctx.obj.auth.client, code_verifier=ctx.obj.auth.code_verifier, organization=ctx.obj.auth.org, headless=headless) click.secho(brief_msg) click.echo() + # Process the browser callback to complete the authentication info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx, headless=headless) - + if info: if info.get("email", None): @@ -161,21 +194,25 @@ def login(ctx: typer.Context, headless: bool = False): msg += " Please try again, or use [bold]`safety auth -–help`[/bold] " \ "for more information[/red]" - + console.print(msg, emoji=True) @auth_app.command(name=CMD_LOGOUT_NAME, help=CLI_AUTH_LOGOUT_HELP) -def logout(ctx: typer.Context): +def logout(ctx: typer.Context) -> None: """ Log out of your current session. + + Args: + ctx (typer.Context): The Typer context object. """ LOG.info('logout started') id_token = get_token('id_token') - + msg = MSG_NON_AUTHENTICATED - + if id_token: + # Clean the session if an ID token is found if clean_session(ctx.obj.auth.client): msg = MSG_LOGOUT_DONE else: @@ -190,10 +227,15 @@ def logout(ctx: typer.Context): "authentication is made.") @click.option("--login-timeout", "-w", type=int, default=600, help="Max time allowed to wait for an authentication.") -def status(ctx: typer.Context, ensure_auth: bool = False, - login_timeout: int = 600): +def status(ctx: typer.Context, ensure_auth: bool = False, + login_timeout: int = 600) -> None: """ Display Safety CLI's current authentication status. + + Args: + ctx (typer.Context): The Typer context object. + ensure_auth (bool): Whether to keep running until authentication is made. + login_timeout (int): Max time allowed to wait for authentication. """ LOG.info('status started') current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -211,21 +253,22 @@ def status(ctx: typer.Context, ensure_auth: bool = False, verified = is_email_verified(info) email_status = " [red](email not verified)[/red]" if not verified else "" - console.print(f'[green]Authenticated as {info["email"]}[/green]{email_status}') + console.print(f'[green]Authenticated as {info["email"]}[/green]{email_status}') elif ensure_auth: console.print('Safety is not authenticated. Launching default browser to log in') console.print() uri, initial_state = get_authorization_data(client=ctx.obj.auth.client, code_verifier=ctx.obj.auth.code_verifier, organization=ctx.obj.auth.org, ensure_auth=ensure_auth) - - info = process_browser_callback(uri, initial_state=initial_state, - timeout=login_timeout, ctx=ctx) - + + # Process the browser callback to complete the authentication + info = process_browser_callback(uri, initial_state=initial_state, + timeout=login_timeout, ctx=ctx) + if not info: console.print(f'[red]Timeout error ({login_timeout} seconds): not successfully authenticated without the timeout period.[/red]') sys.exit(1) - + organization = None if ctx.obj.auth.org and ctx.obj.auth.org.name: organization = ctx.obj.auth.org.name @@ -238,21 +281,27 @@ def status(ctx: typer.Context, ensure_auth: bool = False, @auth_app.command(name=CMD_REGISTER_NAME) -def register(ctx: typer.Context): +def register(ctx: typer.Context) -> None: """ Create a new user account for the safetycli.com service. + + Args: + ctx (typer.Context): The Typer context object. """ LOG.info('register started') + # Check if the user is already authenticated fail_if_authenticated(ctx, with_msg=MSG_FAIL_REGISTER_AUTHED) + # Get authorization data and generate the registration URL uri, initial_state = get_authorization_data(client=ctx.obj.auth.client, code_verifier=ctx.obj.auth.code_verifier, sign_up=True) - + console.print("Redirecting your browser to register for a free account. Once registered, return here to start using Safety.") console.print() + # Process the browser callback to complete the registration info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx) @@ -261,4 +310,4 @@ def register(ctx: typer.Context): console.print() else: console.print('[red]Unable to register in this time, try again.[/red]') - + diff --git a/safety/auth/cli_utils.py b/safety/auth/cli_utils.py index 098bde30..cc0cca43 100644 --- a/safety/auth/cli_utils.py +++ b/safety/auth/cli_utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional +from typing import Dict, Optional, Tuple, Any, Callable import click @@ -21,9 +21,20 @@ LOG = logging.getLogger(__name__) -def build_client_session(api_key=None, proxies=None, headers=None): - kwargs = {} +def build_client_session(api_key: Optional[str] = None, proxies: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None) -> Tuple[SafetyAuthSession, Dict[str, Any]]: + """ + Builds and configures the client session for authentication. + Args: + api_key (Optional[str]): The API key for authentication. + proxies (Optional[Dict[str, str]]): Proxy configuration. + headers (Optional[Dict[str, str]]): Additional headers. + + Returns: + Tuple[SafetyAuthSession, Dict[str, Any]]: The configured client session and OpenID configuration. + """ + + kwargs = {} target_proxies = proxies # Global proxy defined in the config.ini @@ -31,21 +42,21 @@ def build_client_session(api_key=None, proxies=None, headers=None): if not proxies: target_proxies = proxy_config - + def update_token(tokens, **kwargs): save_auth_config(access_token=tokens['access_token'], id_token=tokens['id_token'], refresh_token=tokens['refresh_token']) load_auth_session(click_ctx=click.get_current_context(silent=True)) - client_session = SafetyAuthSession(client_id=CLIENT_ID, + client_session = SafetyAuthSession(client_id=CLIENT_ID, code_challenge_method='S256', - redirect_uri=get_redirect_url(), + redirect_uri=get_redirect_url(), update_token=update_token, scope='openid email profile offline_access', **kwargs) - + client_session.mount("https://pyup.io/static-s3/", S3PresignedAdapter()) - + client_session.proxy_required = proxy_required client_session.proxy_timeout = proxy_timeout client_session.proxies = target_proxies @@ -57,7 +68,7 @@ def update_token(tokens, **kwargs): LOG.debug('Unable to load the openID config: %s', e) openid_config = {} - client_session.metadata["token_endpoint"] = openid_config.get("token_endpoint", + client_session.metadata["token_endpoint"] = openid_config.get("token_endpoint", None) if api_key: @@ -70,7 +81,13 @@ def update_token(tokens, **kwargs): return client_session, openid_config -def load_auth_session(click_ctx): +def load_auth_session(click_ctx: click.Context) -> None: + """ + Loads the authentication session from the context. + + Args: + click_ctx (click.Context): The Click context object. + """ if not click_ctx: LOG.warn("Click context is needed to be able to load the Auth data.") return @@ -94,63 +111,103 @@ def load_auth_session(click_ctx): print(e) clean_session(client) -def proxy_options(func): +def proxy_options(func: Callable) -> Callable: """ + Decorator that defines proxy options for Click commands. + Options defined per command, this will override the proxy settings defined in the config.ini file. + + Args: + func (Callable): The Click command function. + + Returns: + Callable: The wrapped Click command function with proxy options. """ - func = click.option("--proxy-protocol", + func = click.option("--proxy-protocol", type=click.Choice(['http', 'https']), default='https', cls=DependentOption, required_options=['proxy_host'], help=CLI_PROXY_PROTOCOL_HELP)(func) - func = click.option("--proxy-port", multiple=False, type=int, default=80, - cls=DependentOption, required_options=['proxy_host'], + func = click.option("--proxy-port", multiple=False, type=int, default=80, + cls=DependentOption, required_options=['proxy_host'], help=CLI_PROXY_PORT_HELP)(func) - func = click.option("--proxy-host", multiple=False, type=str, default=None, + func = click.option("--proxy-host", multiple=False, type=str, default=None, help=CLI_PROXY_HOST_HELP)(func) return func -def auth_options(stage=True): +def auth_options(stage: bool = True) -> Callable: + """ + Decorator that defines authentication options for Click commands. - def decorator(func): + Args: + stage (bool): Whether to include the stage option. + + Returns: + Callable: The decorator function. + """ + def decorator(func: Callable) -> Callable: func = click.option("--key", default=None, envvar="SAFETY_API_KEY", help=CLI_KEY_HELP)(func) if stage: - func = click.option("--stage", default=None, envvar="SAFETY_STAGE", + func = click.option("--stage", default=None, envvar="SAFETY_STAGE", help=CLI_STAGE_HELP)(func) - + return func - + return decorator -def inject_session(func): +def inject_session(func: Callable) -> Callable: """ + Decorator that injects a session object into Click commands. + Builds the session object to be used in each command. + + Args: + func (Callable): The Click command function. + + Returns: + Callable: The wrapped Click command function with session injection. """ @wraps(func) - def inner(ctx, proxy_protocol: Optional[str] = None, + def inner(ctx: click.Context, proxy_protocol: Optional[str] = None, proxy_host: Optional[str] = None, proxy_port: Optional[str] = None, - key: Optional[str] = None, - stage: Optional[Stage] = None, *args, **kwargs): - + key: Optional[str] = None, + stage: Optional[Stage] = None, *args, **kwargs) -> Any: + """ + Inner function that performs the session injection. + + Args: + ctx (click.Context): The Click context object. + proxy_protocol (Optional[str]): The proxy protocol. + proxy_host (Optional[str]): The proxy host. + proxy_port (Optional[int]): The proxy port. + key (Optional[str]): The API key. + stage (Optional[Stage]): The stage. + *args (Any): Additional arguments. + **kwargs (Any): Additional keyword arguments. + + Returns: + Any: The result of the decorated function. + """ + if ctx.invoked_subcommand == "configure": return - + org: Optional[Organization] = get_organization() - + if not stage: host_stage = get_host_config(key_name="stage") stage = host_stage if host_stage else Stage.development - proxy_config: Optional[Dict[str, str]] = get_proxy_dict(proxy_protocol, + proxy_config: Optional[Dict[str, str]] = get_proxy_dict(proxy_protocol, proxy_host, proxy_port) - client_session, openid_config = build_client_session(api_key=key, + client_session, openid_config = build_client_session(api_key=key, proxies=proxy_config) keys = get_keys(client_session, openid_config) diff --git a/safety/auth/main.py b/safety/auth/main.py index c96542a2..fdb9c225 100644 --- a/safety/auth/main.py +++ b/safety/auth/main.py @@ -1,25 +1,36 @@ import configparser -import json from typing import Any, Dict, Optional, Tuple, Union -from urllib.parse import urlencode from authlib.oidc.core import CodeIDToken from authlib.jose import jwt from authlib.jose.errors import ExpiredTokenError from safety.auth.models import Organization -from safety.auth.constants import AUTH_SERVER_URL, CLI_AUTH_LOGOUT, CLI_CALLBACK, AUTH_CONFIG_USER, CLI_AUTH +from safety.auth.constants import CLI_AUTH_LOGOUT, CLI_CALLBACK, AUTH_CONFIG_USER, CLI_AUTH from safety.constants import CONFIG -from safety.errors import NotVerifiedEmailError from safety.scan.util import Stage from safety.util import get_proxy_dict def get_authorization_data(client, code_verifier: str, - organization: Optional[Organization] = None, + organization: Optional[Organization] = None, sign_up: bool = False, ensure_auth: bool = False, headless: bool = False) -> Tuple[str, str]: - + """ + Generate the authorization URL for the authentication process. + + Args: + client: The authentication client. + code_verifier (str): The code verifier for the PKCE flow. + organization (Optional[Organization]): The organization to authenticate with. + sign_up (bool): Whether the URL is for sign-up. + ensure_auth (bool): Whether to ensure authentication. + headless (bool): Whether to run in headless mode. + + Returns: + Tuple[str, str]: The authorization URL and initial state. + """ + kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth, 'headless': headless} if organization: kwargs['organization'] = organization.id @@ -29,12 +40,33 @@ def get_authorization_data(client, code_verifier: str, **kwargs) def get_logout_url(id_token: str) -> str: + """ + Generate the logout URL. + + Args: + id_token (str): The ID token. + + Returns: + str: The logout URL. + """ return f'{CLI_AUTH_LOGOUT}?id_token={id_token}' def get_redirect_url() -> str: + """ + Get the redirect URL for the authentication callback. + + Returns: + str: The redirect URL. + """ return CLI_CALLBACK def get_organization() -> Optional[Organization]: + """ + Retrieve the organization configuration. + + Returns: + Optional[Organization]: The organization object, or None if not configured. + """ config = configparser.ConfigParser() config.read(CONFIG) @@ -53,9 +85,18 @@ def get_organization() -> Optional[Organization]: return org -def get_auth_info(ctx): +def get_auth_info(ctx) -> Optional[Dict]: + """ + Retrieve the authentication information. + + Args: + ctx: The context object containing authentication data. + + Returns: + Optional[Dict]: The authentication information, or None if not authenticated. + """ from safety.auth.utils import is_email_verified - + info = None if ctx.obj.auth.client.token: try: @@ -67,7 +108,7 @@ def get_auth_info(ctx): verified = is_email_verified(user_info) if verified: - # refresh only if needed + # refresh only if needed raise ExpiredTokenError except ExpiredTokenError as e: @@ -80,10 +121,21 @@ def get_auth_info(ctx): clean_session(ctx.obj.auth.client) except Exception as _g: clean_session(ctx.obj.auth.client) - + return info -def get_token_data(token, keys, silent_if_expired=False) -> Optional[Dict]: +def get_token_data(token: str, keys: Any, silent_if_expired: bool = False) -> Optional[Dict]: + """ + Decode and validate the token data. + + Args: + token (str): The token to decode. + keys (Any): The keys to use for decoding. + silent_if_expired (bool): Whether to silently ignore expired tokens. + + Returns: + Optional[Dict]: The decoded token data, or None if invalid. + """ claims = jwt.decode(token, keys, claims_cls=CodeIDToken) try: claims.validate() @@ -93,10 +145,18 @@ def get_token_data(token, keys, silent_if_expired=False) -> Optional[Dict]: return claims -def get_token(name='access_token') -> Optional[str]: +def get_token(name: str = 'access_token') -> Optional[str]: """" + Retrieve a token from the local authentication configuration. + This returns tokens saved in the local auth configuration. There are two types of tokens: access_token and id_token + + Args: + name (str): The name of the token to retrieve. + + Returns: + Optional[str]: The token value, or None if not found. """ config = configparser.ConfigParser() config.read(AUTH_CONFIG_USER) @@ -108,13 +168,22 @@ def get_token(name='access_token') -> Optional[str]: return None -def get_host_config(key_name) -> Optional[Any]: +def get_host_config(key_name: str) -> Optional[Any]: + """ + Retrieve a configuration value from the host configuration. + + Args: + key_name (str): The name of the configuration key. + + Returns: + Optional[Any]: The configuration value, or None if not found. + """ config = configparser.ConfigParser() config.read(CONFIG) if not config.has_section("host"): return None - + host_section = dict(config.items("host")) if key_name in host_section: @@ -128,8 +197,19 @@ def get_host_config(key_name) -> Optional[Any]: return None -def str_to_bool(s): - """Convert a string to a boolean value.""" +def str_to_bool(s: str) -> bool: + """ + Convert a string to a boolean value. + + Args: + s (str): The string to convert. + + Returns: + bool: The converted boolean value. + + Raises: + ValueError: If the string cannot be converted. + """ if s.lower() == 'true' or s == '1': return True elif s.lower() == 'false' or s == '0': @@ -137,7 +217,13 @@ def str_to_bool(s): else: raise ValueError(f"Cannot convert '{s}' to a boolean value.") -def get_proxy_config() -> Tuple[Dict[str, str], Optional[int], bool]: +def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]: + """ + Retrieve the proxy configuration. + + Returns: + Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy configuration, timeout, and whether it is required. + """ config = configparser.ConfigParser() config.read(CONFIG) @@ -151,7 +237,7 @@ def get_proxy_config() -> Tuple[Dict[str, str], Optional[int], bool]: if proxy: try: - proxy_dictionary = get_proxy_dict(proxy['protocol'], proxy['host'], + proxy_dictionary = get_proxy_dict(proxy['protocol'], proxy['host'], proxy['port']) required = str_to_bool(proxy["required"]) timeout = proxy["timeout"] @@ -160,7 +246,16 @@ def get_proxy_config() -> Tuple[Dict[str, str], Optional[int], bool]: return proxy_dictionary, timeout, required -def clean_session(client): +def clean_session(client) -> bool: + """ + Clean the authentication session. + + Args: + client: The authentication client. + + Returns: + bool: Always returns True. + """ config = configparser.ConfigParser() config['auth'] = {'access_token': '', 'id_token': '', 'refresh_token':''} @@ -171,11 +266,19 @@ def clean_session(client): return True -def save_auth_config(access_token=None, id_token=None, refresh_token=None): +def save_auth_config(access_token: Optional[str] = None, id_token: Optional[str] = None, refresh_token: Optional[str] = None) -> None: + """ + Save the authentication configuration. + + Args: + access_token (Optional[str]): The access token. + id_token (Optional[str]): The ID token. + refresh_token (Optional[str]): The refresh token. + """ config = configparser.ConfigParser() config.read(AUTH_CONFIG_USER) - config['auth'] = {'access_token': access_token, 'id_token': id_token, + config['auth'] = {'access_token': access_token, 'id_token': id_token, 'refresh_token': refresh_token} - + with open(AUTH_CONFIG_USER, 'w') as configfile: config.write(configfile) diff --git a/safety/auth/models.py b/safety/auth/models.py index 9966e90c..3312dedc 100644 --- a/safety/auth/models.py +++ b/safety/auth/models.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import os -from typing import Any, Optional +from typing import Any, Optional, Dict from authlib.integrations.base_client import BaseOAuth @@ -11,7 +11,13 @@ class Organization: id: str name: str - def to_dict(self): + def to_dict(self) -> Dict: + """ + Convert the Organization instance to a dictionary. + + Returns: + dict: The dictionary representation of the organization. + """ return {'id': self.id, 'name': self.name} @dataclass @@ -27,6 +33,12 @@ class Auth: email_verified: bool = False def is_valid(self) -> bool: + """ + Check if the authentication information is valid. + + Returns: + bool: True if valid, False otherwise. + """ if os.getenv("SAFETY_DB_DIR"): return True @@ -38,7 +50,13 @@ def is_valid(self) -> bool: return bool(self.client.token and self.email_verified) - def refresh_from(self, info): + def refresh_from(self, info: Dict) -> None: + """ + Refresh the authentication information from the provided info. + + Args: + info (dict): The information to refresh from. + """ from safety.auth.utils import is_email_verified self.name = info.get("name") @@ -46,9 +64,24 @@ def refresh_from(self, info): self.email_verified = is_email_verified(info) class XAPIKeyAuth(BaseOAuth): - def __init__(self, api_key): + def __init__(self, api_key: str) -> None: + """ + Initialize the XAPIKeyAuth instance. + + Args: + api_key (str): The API key to use for authentication. + """ self.api_key = api_key - def __call__(self, r): + def __call__(self, r: Any) -> Any: + """ + Add the API key to the request headers. + + Args: + r (Any): The request object. + + Returns: + Any: The modified request object. + """ r.headers['X-API-Key'] = self.api_key return r diff --git a/safety/auth/server.py b/safety/auth/server.py index 3559c6eb..64ffe9af 100644 --- a/safety/auth/server.py +++ b/safety/auth/server.py @@ -4,7 +4,7 @@ import socket import sys import time -from typing import Any, Optional +from typing import Any, Optional, Dict import urllib.parse import threading import click @@ -14,14 +14,18 @@ from safety.auth.constants import AUTH_SERVER_URL, CLI_AUTH_SUCCESS, CLI_LOGOUT_SUCCESS, HOST from safety.auth.main import save_auth_config -from authlib.integrations.base_client.errors import OAuthError from rich.prompt import Prompt LOG = logging.getLogger(__name__) -def find_available_port(): - """Find an available port on localhost""" +def find_available_port() -> Optional[int]: + """ + Find an available port on localhost within the dynamic port range (49152-65536). + + Returns: + Optional[int]: An available port number, or None if no ports are available. + """ # Dynamic ports IANA port_range = range(49152, 65536) @@ -36,7 +40,23 @@ def find_available_port(): return None -def auth_process(code: str, state: str, initial_state: str, code_verifier, client): +def auth_process(code: str, state: str, initial_state: str, code_verifier: str, client: Any) -> Any: + """ + Process the authentication callback and exchange the authorization code for tokens. + + Args: + code (str): The authorization code. + state (str): The state parameter from the callback. + initial_state (str): The initial state parameter. + code_verifier (str): The code verifier for PKCE. + client (Any): The OAuth client. + + Returns: + Any: The user information. + + Raises: + SystemExit: If there is an error during authentication. + """ err = None if initial_state is None or initial_state != state: @@ -51,15 +71,15 @@ def auth_process(code: str, state: str, initial_state: str, code_verifier, clien if err: click.secho(f'Error: {err}', fg='red') sys.exit(1) - + try: tokens = client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token', code_verifier=code_verifier, client_id=client.client_id, grant_type='authorization_code', code=code) - save_auth_config(access_token=tokens['access_token'], - id_token=tokens['id_token'], + save_auth_config(access_token=tokens['access_token'], + id_token=tokens['id_token'], refresh_token=tokens['refresh_token']) return client.fetch_user_info() @@ -68,20 +88,32 @@ def auth_process(code: str, state: str, initial_state: str, code_verifier, clien sys.exit(1) class CallbackHandler(http.server.BaseHTTPRequestHandler): - def auth(self, code: str, state: str, err, error_description): + def auth(self, code: str, state: str, err: str, error_description: str) -> None: + """ + Handle the authentication callback. + + Args: + code (str): The authorization code. + state (str): The state parameter. + err (str): The error message, if any. + error_description (str): The error description, if any. + """ initial_state = self.server.initial_state ctx = self.server.ctx - result = auth_process(code=code, - state=state, - initial_state=initial_state, + result = auth_process(code=code, + state=state, + initial_state=initial_state, code_verifier=ctx.obj.auth.code_verifier, client=ctx.obj.auth.client) - + self.server.callback = result self.do_redirect(location=CLI_AUTH_SUCCESS, params={}) - def logout(self): + def logout(self) -> None: + """ + Handle the logout callback. + """ ctx = self.server.ctx uri = CLI_LOGOUT_SUCCESS @@ -90,7 +122,10 @@ def logout(self): self.do_redirect(location=CLI_LOGOUT_SUCCESS, params={}) - def do_GET(self): + def do_GET(self) -> None: + """ + Handle GET requests. + """ query = urllib.parse.urlparse(self.path).query params = urllib.parse.parse_qs(query) callback_type: Optional[str] = None @@ -111,22 +146,56 @@ def do_GET(self): state = params.get('state', [''])[0] err = params.get('error', [''])[0] error_description = params.get('error_description', [''])[0] - + self.auth(code=code, state=state, err=err, error_description=error_description) - def do_redirect(self, location, params): + def do_redirect(self, location: str, params: Dict) -> None: + """ + Redirect the client to the specified location. + + Args: + location (str): The URL to redirect to. + params (dict): Additional parameters for the redirection. + """ self.send_response(301) self.send_header('Location', location) self.end_headers() - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any) -> None: + """ + Log an arbitrary message. + + Args: + format (str): The format string. + args (Any): Arguments for the format string. + """ LOG.info(format % args) -def process_browser_callback(uri, **kwargs) -> Any: +def process_browser_callback(uri: str, **kwargs: Any) -> Any: + """ + Process the browser callback for authentication. + + Args: + uri (str): The authorization URL. + **kwargs (Any): Additional keyword arguments. + + Returns: + Any: The user information. + + Raises: + SystemExit: If there is an error during the process. + """ class ThreadedHTTPServer(http.server.HTTPServer): - def __init__(self, server_address, RequestHandlerClass): + def __init__(self, server_address: tuple, RequestHandlerClass: Any) -> None: + """ + Initialize the ThreadedHTTPServer. + + Args: + server_address (tuple): The server address as a tuple (host, port). + RequestHandlerClass (Any): The request handler class. + """ super().__init__(server_address, RequestHandlerClass) self.initial_state = None self.ctx = None @@ -134,6 +203,9 @@ def __init__(self, server_address, RequestHandlerClass): self.timeout_reached = False def handle_timeout(self) -> None: + """ + Handle server timeout. + """ self.timeout_reached = True return super().handle_timeout() @@ -142,7 +214,7 @@ def handle_timeout(self) -> None: if not PORT: click.secho("No available ports.") sys.exit(1) - + try: headless = kwargs.get("headless", False) initial_state = kwargs.get("initial_state", None) @@ -152,6 +224,7 @@ def handle_timeout(self) -> None: if not headless: + # Start a threaded HTTP server to handle the callback server = ThreadedHTTPServer((HOST, PORT), CallbackHandler) server.initial_state = initial_state server.timeout = kwargs.get("timeout", 600) @@ -159,14 +232,14 @@ def handle_timeout(self) -> None: server_thread = threading.Thread(target=server.handle_request) server_thread.start() message = f"If the browser does not automatically open in 5 seconds, " \ - "copy and paste this url into your browser:" + "copy and paste this url into your browser:" target = uri if headless else f"{uri}&port={PORT}" console.print(f"{message} [link={target}]{target}[/link]") console.print() if headless: - + # Handle the headless mode where user manually provides the response exchange_data = None while not exchange_data: auth_code_text = Prompt.ask("Paste the response here", default=None, console=console) @@ -175,17 +248,17 @@ def handle_timeout(self) -> None: state = exchange_data["state"] code = exchange_data["code"] except Exception as e: - code = state = None + code = state = None - return auth_process(code=code, - state=state, - initial_state=initial_state, + return auth_process(code=code, + state=state, + initial_state=initial_state, code_verifier=ctx.obj.auth.code_verifier, client=ctx.obj.auth.client) else: - + # Wait for the browser authentication in non-headless mode wait_msg = "waiting for browser authentication" - + with console.status(wait_msg, spinner="bouncingBar"): time.sleep(2) click.launch(target) diff --git a/safety/auth/utils.py b/safety/auth/utils.py index f0a93ec7..2fd5b651 100644 --- a/safety/auth/utils.py +++ b/safety/auth/utils.py @@ -1,10 +1,11 @@ import json import logging -from typing import Any, Optional +from typing import Any, Optional, Dict, Callable from authlib.integrations.requests_client import OAuth2Session from authlib.integrations.base_client.errors import OAuthError import requests from requests.adapters import HTTPAdapter + from safety.auth.constants import AUTH_SERVER_URL, CLAIM_EMAIL_VERIFIED_API, \ CLAIM_EMAIL_VERIFIED_AUTH_SERVER from safety.auth.main import get_auth_info, get_token_data @@ -21,17 +22,45 @@ LOG = logging.getLogger(__name__) -def get_keys(client_session, openid_config): +def get_keys(client_session: OAuth2Session, openid_config: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Retrieve the keys from the OpenID configuration. + + Args: + client_session (OAuth2Session): The OAuth2 session. + openid_config (Dict[str, Any]): The OpenID configuration. + + Returns: + Optional[Dict[str, Any]]: The keys, if available. + """ if "jwks_uri" in openid_config: return client_session.get(url=openid_config["jwks_uri"], bearer=False).json() return None -def is_email_verified(info) -> bool: +def is_email_verified(info: Dict[str, Any]) -> Optional[bool]: + """ + Check if the email is verified. + + Args: + info (Dict[str, Any]): The user information. + + Returns: + bool: True if the email is verified, False otherwise. + """ return info.get(CLAIM_EMAIL_VERIFIED_API) or info.get(CLAIM_EMAIL_VERIFIED_AUTH_SERVER) -def parse_response(func): +def parse_response(func: Callable) -> Callable: + """ + Decorator to parse the response from an HTTP request. + + Args: + func (Callable): The function to wrap. + + Returns: + Callable: The wrapped function. + """ def wrapper(*args, **kwargs): try: r = func(*args, **kwargs) @@ -46,12 +75,12 @@ def wrapper(*args, **kwargs): raise e if r.status_code == 403: - raise InvalidCredentialError(credential="Failed authentication.", + raise InvalidCredentialError(credential="Failed authentication.", reason=r.text) if r.status_code == 429: raise TooManyRequestsError(reason=r.text) - + if r.status_code >= 400 and r.status_code < 500: error_code = None try: @@ -72,43 +101,87 @@ def wrapper(*args, **kwargs): data = r.json() except json.JSONDecodeError as e: raise SafetyError(message=f"Bad JSON response: {e}") - + return data return wrapper class SafetyAuthSession(OAuth2Session): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + Initialize the SafetyAuthSession. + + Args: + *args (Any): Positional arguments for the parent class. + **kwargs (Any): Keyword arguments for the parent class. + """ super().__init__(*args, **kwargs) self.proxy_required: bool = False self.proxy_timeout: Optional[int] = None self.api_key = None def get_credential(self) -> Optional[str]: + """ + Get the current authentication credential. + + Returns: + Optional[str]: The API key, token, or None. + """ if self.api_key: return self.api_key - + if self.token: return SafetyContext().account - + return None - + def is_using_auth_credentials(self) -> bool: - """This does NOT check if the client is authenticated""" + """ + Check if the session is using authentication credentials. + + This does NOT check if the client is authenticated. + + Returns: + bool: True if using authentication credentials, False otherwise. + """ return self.get_authentication_type() != AuthenticationType.none def get_authentication_type(self) -> AuthenticationType: + """ + Get the type of authentication being used. + + Returns: + AuthenticationType: The type of authentication. + """ if self.api_key: return AuthenticationType.api_key - + if self.token: return AuthenticationType.token - + return AuthenticationType.none - def request(self, method, url, withhold_token=False, auth=None, bearer=True, **kwargs): - """Use the right auth parameter for Safety supported auth types""" + def request(self, method: str, url: str, withhold_token: bool = False, auth: Optional[tuple] = None, bearer: bool = True, **kwargs: Any) -> requests.Response: + """ + Make an HTTP request with the appropriate authentication. + + Use the right auth parameter for Safety supported auth types. + + Args: + method (str): The HTTP method. + url (str): The URL to request. + withhold_token (bool): Whether to withhold the token. + auth (Optional[tuple]): The authentication tuple. + bearer (bool): Whether to use bearer authentication. + **kwargs (Any): Additional keyword arguments. + + Returns: + requests.Response: The HTTP response. + + Raises: + Exception: If the request fails. + """ # By default use the token_auth TIMEOUT_KEYWARD = "timeout" func_timeout = kwargs[TIMEOUT_KEYWARD] if TIMEOUT_KEYWARD in kwargs else REQUEST_TIMEOUT @@ -119,7 +192,7 @@ def request(self, method, url, withhold_token=False, auth=None, bearer=True, **k kwargs["headers"] = key_header else: kwargs["headers"]["X-Api-Key"] = self.api_key - + if not self.token or not bearer: # Fallback to no token auth auth = () @@ -128,7 +201,7 @@ def request(self, method, url, withhold_token=False, auth=None, bearer=True, **k # Override proxies if self.proxies: kwargs['proxies'] = self.proxies - + if self.proxy_timeout: kwargs['timeout'] = int(self.proxy_timeout) / 1000 @@ -140,7 +213,7 @@ def request(self, method, url, withhold_token=False, auth=None, bearer=True, **k 'method': method, 'url': url, 'withhold_token': withhold_token, - 'auth': auth, + 'auth': auth, } params.update(kwargs) @@ -160,13 +233,19 @@ def request(self, method, url, withhold_token=False, auth=None, bearer=True, **k LOG.debug(message) if message not in [a['message'] for a in SafetyContext.local_announcements]: SafetyContext.local_announcements.append({'message': message, 'type': 'warning', 'local': True}) - + return request_func(**params) - + raise e @parse_response def fetch_user_info(self) -> Any: + """ + Fetch user information from the authorization server. + + Returns: + Any: The user information. + """ USER_INFO_ENDPOINT = f"{AUTH_SERVER_URL}/userinfo" r = self.get( @@ -176,13 +255,26 @@ def fetch_user_info(self) -> Any: return r @parse_response - def check_project(self, scan_stage: str, safety_source: str, + def check_project(self, scan_stage: str, safety_source: str, project_slug: Optional[str] = None, git_origin: Optional[str] = None, project_slug_source: Optional[str] = None) -> Any: - - data = {"scan_stage": scan_stage, "safety_source": safety_source, - "project_slug": project_slug, - "project_slug_source": project_slug_source, + """ + Check project information. + + Args: + scan_stage (str): The scan stage. + safety_source (str): The safety source. + project_slug (Optional[str]): The project slug. + git_origin (Optional[str]): The git origin. + project_slug_source (Optional[str]): The project slug source. + + Returns: + Any: The project information. + """ + + data = {"scan_stage": scan_stage, "safety_source": safety_source, + "project_slug": project_slug, + "project_slug_source": project_slug_source, "git_origin": git_origin} r = self.post( @@ -191,80 +283,138 @@ def check_project(self, scan_stage: str, safety_source: str, ) return r - + @parse_response def project(self, project_id: str) -> Any: + """ + Get project information. + + Args: + project_id (str): The project ID. + + Returns: + Any: The project information. + """ data = {"project": project_id} - r = self.get( + return self.get( url=PLATFORM_API_PROJECT_ENDPOINT, params=data ) - return r - @parse_response def download_policy(self, project_id: Optional[str], stage: Stage, branch: Optional[str]) -> Any: + """ + Download the project policy. + + Args: + project_id (Optional[str]): The project ID. + stage (Stage): The stage. + branch (Optional[str]): The branch. + + Returns: + Any: The policy data. + """ data = {"project": project_id, "stage": STAGE_ID_MAPPING[stage], "branch": branch} - r = self.get( + return self.get( url=PLATFORM_API_POLICY_ENDPOINT, params=data ) - return r - + @parse_response def project_scan_request(self, project_id: str) -> Any: + """ + Request a project scan. + + Args: + project_id (str): The project ID. + + Returns: + Any: The scan request result. + """ data = {"project_id": project_id} - r = self.post( + return self.post( url=PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT, json=data ) - return r - + @parse_response def upload_report(self, json_report: str) -> Any: + """ + Upload a scan report. + + Args: + json_report (str): The JSON report. + + Returns: + Any: The upload result. + """ headers = { "Content-Type": "application/json" - } + } - r = self.post( + return self.post( url=PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT, data=json_report, headers=headers ) - return r - + @parse_response - def check_updates(self, version: int, safety_version=None, - python_version=None, - os_type=None, - os_release=None, - os_description=None) -> Any: - data = {"version": version, + def check_updates(self, version: int, safety_version: Optional[str] = None, python_version: Optional[str] = None, os_type: Optional[str] = None, os_release: Optional[str] = None, os_description: Optional[str] = None) -> Any: + """ + Check for updates. + + Args: + version (int): The version. + safety_version (Optional[str]): The Safety version. + python_version (Optional[str]): The Python version. + os_type (Optional[str]): The OS type. + os_release (Optional[str]): The OS release. + os_description (Optional[str]): The OS description. + + Returns: + Any: The update check result. + """ + data = {"version": version, "safety_version": safety_version, "python_version": python_version, "os_type": os_type, "os_release": os_release, "os_description": os_description} - r = self.get( + return self.get( url=PLATFORM_API_CHECK_UPDATES_ENDPOINT, params=data ) - return r @parse_response def initialize_scan(self) -> Any: + """ + Initialize a scan. + + Returns: + Any: The initialization result. + """ return self.get(url=PLATFORM_API_INITIALIZE_SCAN_ENDPOINT, timeout=2) class S3PresignedAdapter(HTTPAdapter): - def send(self, request, **kwargs): + def send(self, request: requests.PreparedRequest, **kwargs: Any) -> requests.Response: + """ + Send a request, removing the Authorization header. + + Args: + request (requests.PreparedRequest): The prepared request. + **kwargs (Any): Additional keyword arguments. + + Returns: + requests.Response: The response. + """ request.headers.pop("Authorization", None) return super().send(request, **kwargs)