Skip to content

Commit

Permalink
Backport PRs 344, 407, 415 (#417)
Browse files Browse the repository at this point in the history
* Fixed conflict.

* Updates for more stable generate feature

(cherry picked from commit d9b26f5)

* Refactored generate for better stability with all providers/models.

(cherry picked from commit 608ed25)

* Upgraded LangChain to 0.0.318

(cherry picked from commit 0d247fd)

* Updated to use memory instead of chat history, fix for Bedrock Anthropic

(cherry picked from commit 7c09863)

* Allow to define block and allow lists for providers (#415)

* Allow to block or allow-list providers by id

* Add tests for block/allow-lists

* Fix "No language model is associated with" issue

This was appearing because the models which are blocked were not
returned (correctly!) but the previous validation logic did not
know that sometimes models may be missing for a valid reason even
if there are existing settings for these.

* Add docs for allow listing and block listing providers

* Updated docs

* Added an intro block to docs

* Updated the docs

---------

Co-authored-by: Piyush Jain <[email protected]>
(cherry picked from commit 92dab10)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michał Krassowski <[email protected]>
  • Loading branch information
3 people authored Oct 23, 2023
1 parent 54c357b commit 52c9339
Show file tree
Hide file tree
Showing 13 changed files with 283 additions and 72 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https:/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: check-case-conflict
Expand All @@ -18,7 +18,7 @@ repos:
- id: trailing-whitespace

- repo: https:/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black

Expand All @@ -30,7 +30,7 @@ repos:
files: \.py$

- repo: https:/asottile/pyupgrade
rev: v3.9.0
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py37-plus]
Expand All @@ -48,7 +48,7 @@ repos:
stages: [manual]

- repo: https:/sirosen/check-jsonschema
rev: 0.23.3
rev: 0.27.0
hooks:
- id: check-jsonschema
name: "Check GitHub Workflows"
Expand Down
31 changes: 31 additions & 0 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -705,3 +705,34 @@ The `--region-name` parameter is set to the [AWS region code](https://docs.aws.a
The `--request-schema` parameter is the JSON object the endpoint expects as input, with the prompt being substituted into any value that matches the string literal `"<prompt>"`. For example, the request schema `{"text_inputs":"<prompt>"}` will submit a JSON object with the prompt stored under the `text_inputs` key.

The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonPath/index.html) string that retrieves the language model's output from the endpoint's JSON response. For example, if your endpoint returns an object with the schema `{"generated_texts":["<output>"]}`, its response path is `generated_texts.[0]`.


## Configuration

You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers.

### Blocklisting providers
This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section.

```
jupyter lab --AiExtension.blocked_providers=openai
```

To block more than one provider in the block-list, repeat the runtime configuration.

```
jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21
```

### Allowlisting providers
This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers.

```
jupyter lab --AiExtension.allowed_providers=openai
```

To allow more than one provider in the allowlist, repeat the runtime configuration.

```
jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21
```
16 changes: 16 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def get_prompt_template(self, format) -> PromptTemplate:
def is_chat_provider(self):
return isinstance(self, BaseChatModel)

@property
def allows_concurrency(self):
return True


class AI21Provider(BaseProvider, AI21):
id = "ai21"
Expand Down Expand Up @@ -267,6 +271,10 @@ class AnthropicProvider(BaseProvider, Anthropic):
pypi_package_deps = ["anthropic"]
auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")

@property
def allows_concurrency(self):
return False


class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
id = "anthropic-chat"
Expand All @@ -285,6 +293,10 @@ class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
pypi_package_deps = ["anthropic"]
auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")

@property
def allows_concurrency(self):
return False


class CohereProvider(BaseProvider, Cohere):
id = "cohere"
Expand Down Expand Up @@ -665,3 +677,7 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:

async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
return await self._generate_in_executor(*args, **kwargs)

@property
def allows_concurrency(self):
return not "anthropic" in self.model_id
34 changes: 34 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import pytest
from jupyter_ai_magics.utils import get_lm_providers

KNOWN_LM_A = "openai"
KNOWN_LM_B = "huggingface_hub"


@pytest.mark.parametrize(
"restrictions",
[
{"allowed_providers": None, "blocked_providers": None},
{"allowed_providers": [], "blocked_providers": []},
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]},
{"allowed_providers": [KNOWN_LM_A], "blocked_providers": []},
],
)
def test_get_lm_providers_not_restricted(restrictions):
a_not_restricted = get_lm_providers(None, restrictions)
assert KNOWN_LM_A in a_not_restricted


@pytest.mark.parametrize(
"restrictions",
[
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]},
{"allowed_providers": [KNOWN_LM_B], "blocked_providers": []},
],
)
def test_get_lm_providers_restricted(restrictions):
a_not_restricted = get_lm_providers(None, restrictions)
assert KNOWN_LM_A not in a_not_restricted
32 changes: 28 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional, Tuple, Type, Union
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

from importlib_metadata import entry_points
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
Expand All @@ -11,13 +11,19 @@
EmProvidersDict = Dict[str, BaseEmbeddingsProvider]
AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider]
ProviderDict = Dict[str, AnyProvider]
ProviderRestrictions = Dict[
Literal["allowed_providers", "blocked_providers"], Optional[List[str]]
]


def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict:
def get_lm_providers(
log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
) -> LmProvidersDict:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())

if not restrictions:
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
Expand All @@ -29,18 +35,23 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict:
f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
)
continue
if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
providers[provider.id] = provider
log.info(f"Registered model provider `{provider.id}`.")

return providers


def get_em_providers(
log: Optional[Logger] = None,
log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
) -> EmProvidersDict:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
if not restrictions:
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers")
Expand All @@ -52,6 +63,9 @@ def get_em_providers(
f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
)
continue
if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
providers[provider.id] = provider
log.info(f"Registered embeddings model provider `{provider.id}`.")

Expand Down Expand Up @@ -97,6 +111,16 @@ def get_em_provider(
return _get_provider(model_id, em_providers)


def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool:
allowed = restrictions["allowed_providers"]
blocked = restrictions["blocked_providers"]
if blocked and provider_id in blocked:
return False
if allowed and provider_id not in allowed:
return False
return True


def _get_provider(model_id: str, providers: ProviderDict):
provider_id, local_model_id = decompose_model_id(model_id, providers)
provider = providers.get(provider_id, None)
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"ipython",
"pydantic~=1.0",
"importlib_metadata>=5.2.0",
"langchain==0.0.308",
"langchain==0.0.318",
"typing_extensions>=4.5.0",
"click~=8.0",
"jsonpath-ng>=1.5.3,<2",
Expand Down
28 changes: 19 additions & 9 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@
from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate

from .base import BaseChatHandler

PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)


class AskChatHandler(BaseChatHandler):
"""Processes messages prefixed with /ask. This actor will
Expand All @@ -27,9 +37,15 @@ def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
self.llm = provider(**provider_params)
self.chat_history = []
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
self.llm_chain = ConversationalRetrievalChain.from_llm(
self.llm, self._retriever, verbose=True
self.llm,
self._retriever,
memory=memory,
condense_question_prompt=CONDENSE_PROMPT,
verbose=False,
)

async def _process_message(self, message: HumanChatMessage):
Expand All @@ -44,14 +60,8 @@ async def _process_message(self, message: HumanChatMessage):
self.get_llm_chain()

try:
# limit chat history to last 2 exchanges
self.chat_history = self.chat_history[-2:]

result = await self.llm_chain.acall(
{"question": query, "chat_history": self.chat_history}
)
result = await self.llm_chain.acall({"question": query})
response = result["answer"]
self.chat_history.append((query, response))
self.reply(response, message)
except AssertionError as e:
self.log.error(e)
Expand Down
Loading

0 comments on commit 52c9339

Please sign in to comment.