Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom/provisioned models in Bedrock #922

Merged
merged 6 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,23 +153,24 @@ Jupyter AI supports a wide range of model providers and models. To use Jupyter A

Jupyter AI supports the following model providers:

| Provider | Provider ID | Environment variable(s) | Python package(s) |
|---------------------|----------------------|----------------------------|---------------------------------|
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `langchain-anthropic` |
| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `langchain-anthropic` |
| Bedrock | `bedrock` | N/A | `langchain-aws` |
| Bedrock (chat) | `bedrock-chat` | N/A | `langchain-aws` |
| Cohere | `cohere` | `COHERE_API_KEY` | `langchain_cohere` |
| ERNIE-Bot | `qianfan` | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan` |
| Gemini | `gemini` | `GOOGLE_API_KEY` | `langchain-google-genai` |
| GPT4All | `gpt4all` | N/A | `gpt4all` |
| Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
| MistralAI | `mistralai` | `MISTRAL_API_KEY` | `langchain-mistralai` |
| NVIDIA | `nvidia-chat` | `NVIDIA_API_KEY` | `langchain_nvidia_ai_endpoints` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `langchain-openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `langchain-openai` |
| SageMaker | `sagemaker-endpoint` | N/A | `langchain-aws` |
| Provider | Provider ID | Environment variable(s) | Python package(s) |
|------------------------------|----------------------|----------------------------|-------------------------------------------|
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `langchain-anthropic` |
| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `langchain-anthropic` |
| Bedrock | `bedrock` | N/A | `langchain-aws` |
| Bedrock (chat) | `bedrock-chat` | N/A | `langchain-aws` |
| Bedrock (custom/provisioned) | `bedrock-custom` | N/A | `langchain-aws` |
| Cohere | `cohere` | `COHERE_API_KEY` | `langchain-cohere` |
| ERNIE-Bot | `qianfan` | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan` |
| Gemini | `gemini` | `GOOGLE_API_KEY` | `langchain-google-genai` |
| GPT4All | `gpt4all` | N/A | `gpt4all` |
| Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
| MistralAI | `mistralai` | `MISTRAL_API_KEY` | `langchain-mistralai` |
| NVIDIA | `nvidia-chat` | `NVIDIA_API_KEY` | `langchain_nvidia_ai_endpoints` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `langchain-openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `langchain-openai` |
| SageMaker endpoint | `sagemaker-endpoint` | N/A | `langchain-aws` |

The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
If multiple variables are listed for a provider, **all** must be specified.
Expand Down Expand Up @@ -615,6 +616,7 @@ We currently support the following language model providers:
- `anthropic-chat`
- `bedrock`
- `bedrock-chat`
- `bedrock-custom`
- `cohere`
- `huggingface_hub`
- `nvidia-chat`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,30 @@ def allows_concurrency(self):
return not "anthropic" in self.model_id


class BedrockCustomProvider(BaseProvider, ChatBedrock):
id = "bedrock-custom"
name = "Amazon Bedrock (custom/provisioned)"
models = ["*"]
model_id_key = "model_id"
model_id_label = "Model ID"
pypi_package_deps = ["langchain-aws"]
auth_strategy = AwsAuthStrategy()
fields = [
TextField(key="provider", label="Provider (required)", format="text"),
TextField(key="region_name", label="Region name (optional)", format="text"),
TextField(
key="credentials_profile_name",
label="AWS profile (optional)",
format="text",
),
]
help = (
"Specify the ARN (Amazon Resource Name) of the custom/provisioned model as the model ID. For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n"
"The model provider must also be specified below. This is the provider of your foundation model *in lowercase*, e.g. `amazon`, `anthropic`, `meta`, or `mistral`."
)
registry = True


# See model ID list here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class BedrockEmbeddingsProvider(BaseEmbeddingsProvider, BedrockEmbeddings):
id = "bedrock"
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIP
sagemaker-endpoint = "jupyter_ai_magics.partner_providers.aws:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockProvider"
amazon-bedrock-chat = "jupyter_ai_magics.partner_providers.aws:BedrockChatProvider"
amazon-bedrock-custom = "jupyter_ai_magics.partner_providers.aws:BedrockCustomProvider"
qianfan = "jupyter_ai_magics:QianfanProvider"
nvidia-chat = "jupyter_ai_magics.partner_providers.nvidia:ChatNVIDIAProvider"
together-ai = "jupyter_ai_magics:TogetherAIProvider"
Expand Down
Loading