Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 19, 2023
1 parent dae87d3 commit 100e7fb
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 69 deletions.
1 change: 1 addition & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

pytest_plugins = ("jupyter_server.pytest_plugin",)


@pytest.fixture
def jp_server_config(jp_server_config):
return {"ServerApp": {"jpserver_extensions": {"jupyter_ai": True}}}
109 changes: 70 additions & 39 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import json
import logging
import os
from typing import Union, Optional
import shutil
import time
from typing import Optional, Union

from jsonschema import Draft202012Validator as Validator
from deepmerge import always_merger as Merger
from traitlets import Unicode, Integer
from traitlets.config import Configurable

from jupyter_ai.models import GlobalConfig, DescribeConfigResponse, UpdateConfigRequest
from jsonschema import Draft202012Validator as Validator
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
from jupyter_ai_magics.utils import (
AnyProvider,
EmProvidersDict,
Expand All @@ -19,41 +16,52 @@
get_lm_provider,
)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
from traitlets.config import Configurable

Logger = Union[logging.Logger, logging.LoggerAdapter]

# default path to config
DEFAULT_CONFIG_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "config.json")
DEFAULT_CONFIG_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "config.json")

# default path to config JSON Schema
DEFAULT_SCHEMA_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "config_schema.json")
DEFAULT_SCHEMA_PATH = os.path.join(
jupyter_data_dir(), "jupyter_ai", "config_schema.json"
)

# default no. of spaces to use when formatting config
DEFAULT_INDENTATION_DEPTH = 4

# path to the default schema defined in this project
# if a file does not exist at SCHEMA_PATH, this file is used as a default.
OUR_SCHEMA_PATH = os.path.join(os.path.dirname(__file__), "config", "config_schema.json")
OUR_SCHEMA_PATH = os.path.join(
os.path.dirname(__file__), "config", "config_schema.json"
)


class AuthError(Exception):
pass


class WriteConflictError(Exception):
pass


class KeyInUseError(Exception):
pass


def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
# TODO: handle non-env auth strategies
if not provider.auth_strategy or provider.auth_strategy.type != "env":
return

if provider.auth_strategy.name not in config.api_keys:
raise AuthError(
f"Missing API key for '{provider.auth_strategy.name}' in the config."
)



class ConfigManager(Configurable):
"""Provides model and embedding provider id along
with the credentials to authenticate providers.
Expand All @@ -63,25 +71,30 @@ class ConfigManager(Configurable):
default_value=DEFAULT_CONFIG_PATH,
help="Path to the configuration file.",
allow_none=False,
config=True
config=True,
)

schema_path = Unicode(
default_value=DEFAULT_SCHEMA_PATH,
help="Path to the configuration's corresponding JSON Schema file.",
allow_none=False,
config=True
config=True,
)

indentation_depth = Integer(
default_value=DEFAULT_INDENTATION_DEPTH,
help="Indentation depth, in number of spaces per level.",
allow_none=False,
config=True
config=True,
)

def __init__(
self, log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict, *args, **kwargs
self,
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.log = log
Expand Down Expand Up @@ -117,18 +130,18 @@ def _init_validator(self) -> Validator:
def _init_config(self):
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = (GlobalConfig(**json.loads(f.read())))
config = GlobalConfig(**json.loads(f.read()))
# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
return
properties = self.validator.schema.get('properties', {})

properties = self.validator.schema.get("properties", {})
field_list = GlobalConfig.__fields__.keys()
field_dict = { field: properties.get(field).get('default') for field in field_list }
default_config = GlobalConfig(
**field_dict
)
field_dict = {
field: properties.get(field).get("default") for field in field_list
}
default_config = GlobalConfig(**field_dict)
self._write_config(default_config)

def _read_config(self) -> GlobalConfig:
Expand All @@ -146,7 +159,7 @@ def _read_config(self) -> GlobalConfig:
config = GlobalConfig(**raw_config)
self._validate_config(config)
return config

def _validate_config(self, config: GlobalConfig):
"""Method used to validate the configuration. This is called after every
read and before every write to the config file. Guarantees that the
Expand All @@ -156,18 +169,26 @@ def _validate_config(self, config: GlobalConfig):

# validate language model config
if config.model_provider_id:
_, lm_provider = get_lm_provider(config.model_provider_id, self._lm_providers)
_, lm_provider = get_lm_provider(
config.model_provider_id, self._lm_providers
)
if not lm_provider:
raise ValueError(f"No language model is associated with '{config.model_provider_id}'.")
raise ValueError(
f"No language model is associated with '{config.model_provider_id}'."
)
_validate_provider_authn(config, lm_provider)

# validate embedding model config
if config.embeddings_provider_id:
_, em_provider = get_em_provider(config.embeddings_provider_id, self._em_providers)
_, em_provider = get_em_provider(
config.embeddings_provider_id, self._em_providers
)
if not em_provider:
raise ValueError(f"No embedding model is associated with '{config.embeddings_provider_id}'.")
raise ValueError(
f"No embedding model is associated with '{config.embeddings_provider_id}'."
)
_validate_provider_authn(config, em_provider)

def _write_config(self, new_config: GlobalConfig):
"""Updates configuration and persists it to disk. This accepts a
complete `GlobalConfig` object, and should not be called publicly."""
Expand All @@ -180,21 +201,33 @@ def delete_api_key(self, key_name: str):
lm_provider = self.lm_provider
em_provider = self.em_provider
required_keys = []
if lm_provider and lm_provider.auth_strategy and lm_provider.auth_strategy.type == "env":
if (
lm_provider
and lm_provider.auth_strategy
and lm_provider.auth_strategy.type == "env"
):
required_keys.append(lm_provider.auth_strategy.name)
if em_provider and em_provider.auth_strategy and em_provider.auth_strategy.type == "env":
if (
em_provider
and em_provider.auth_strategy
and em_provider.auth_strategy.type == "env"
):
required_keys.append(self.em_provider.auth_strategy.name)

if (key_name in required_keys):
raise KeyInUseError("This API key is currently in use by the language or embedding model. Please change the model before deleting the corresponding API key.")

if key_name in required_keys:
raise KeyInUseError(
"This API key is currently in use by the language or embedding model. Please change the model before deleting the corresponding API key."
)

config_dict["api_keys"].pop(key_name, None)
self._write_config(GlobalConfig(**config_dict))

def update_config(self, config_update: UpdateConfigRequest):
last_write = os.stat(self.config_path).st_mtime_ns
if config_update.last_read and config_update.last_read < last_write:
raise WriteConflictError("Configuration was modified after it was read from disk.")
raise WriteConflictError(
"Configuration was modified after it was read from disk."
)

config_dict = self._read_config().dict()
Merger.merge(config_dict, config_update.dict(exclude_unset=True))
Expand All @@ -207,16 +240,14 @@ def get_config(self):
config_dict = config.dict(exclude_unset=True)
api_key_names = list(config_dict.pop("api_keys").keys())
return DescribeConfigResponse(
**config_dict,
api_keys=api_key_names,
last_read=self._last_read
**config_dict, api_keys=api_key_names, last_read=self._last_read
)

@property
def lm_gid(self):
config = self._read_config()
return config.model_provider_id

@property
def em_gid(self):
config = self._read_config()
Expand Down Expand Up @@ -265,7 +296,7 @@ def lm_provider_params(self):
**fields,
**authn_fields,
}

@property
def em_provider_params(self):
# get generic fields
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ListProvidersEntry,
ListProvidersResponse,
Message,
UpdateConfigRequest
UpdateConfigRequest,
)


Expand Down
14 changes: 10 additions & 4 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class IndexedDir(BaseModel):
class IndexMetadata(BaseModel):
dirs: List[IndexedDir]


class DescribeConfigResponse(BaseModel):
model_provider_id: Optional[str]
embeddings_provider_id: Optional[str]
Expand All @@ -108,10 +109,12 @@ class DescribeConfigResponse(BaseModel):
# passed to the subsequent UpdateConfig request.
last_read: int


def forbid_none(cls, v):
assert v is not None, 'size may not be None'
assert v is not None, "size may not be None"
return v


class UpdateConfigRequest(BaseModel):
model_provider_id: Optional[str]
embeddings_provider_id: Optional[str]
Expand All @@ -122,14 +125,17 @@ class UpdateConfigRequest(BaseModel):
# time specified by `last_read` to prevent write-write conflicts.
last_read: Optional[int]

_validate_send_wse = validator('send_with_shift_enter', allow_reuse=True)(forbid_none)
_validate_api_keys = validator('api_keys', allow_reuse=True)(forbid_none)
_validate_fields = validator('fields', allow_reuse=True)(forbid_none)
_validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)(
forbid_none
)
_validate_api_keys = validator("api_keys", allow_reuse=True)(forbid_none)
_validate_fields = validator("fields", allow_reuse=True)(forbid_none)


class GlobalConfig(BaseModel):
"""Model used to represent the config by ConfigManager. This is exclusive to
the backend and should never be sent to the client."""

model_provider_id: Optional[str]
embeddings_provider_id: Optional[str]
send_with_shift_enter: bool
Expand Down
Loading

0 comments on commit 100e7fb

Please sign in to comment.