feat(providers): add support for custom vendors (#74)
* feat(providers): add support for custom vendors

Signed-off-by: Aaron Pham <[email protected]>

* fix: override configuration not setup

Signed-off-by: Aaron Pham <[email protected]>

---------

Signed-off-by: Aaron Pham <[email protected]>
aarnphm authored Aug 19, 2024
1 parent 5fa4f70 commit 2700cad
Showing 5 changed files with 260 additions and 76 deletions.
93 changes: 93 additions & 0 deletions README.md
@@ -258,6 +258,99 @@ lua_ls = {

Then you can set `dev = true` in your `lazy` config for development.

## Custom Providers

To add support for custom providers, add an `AvanteProvider` spec to `opts.vendors`:

```lua
{
provider = "my-custom-provider", -- the key of the vendor to use, as defined under vendors below
vendors = {
["my-custom-provider"] = {...}
},
windows = {
wrap_line = true,
width = 30, -- default % based on available width
},
--- @class AvanteConflictUserConfig
diff = {
debug = false,
autojump = true,
---@type string | fun(): any
list_opener = "copen",
},
}

```
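For context, the `opts` table above is what you would hand to your plugin manager. A hedged sketch of the surrounding lazy.nvim spec (the wrapper structure is an assumption, not part of this commit):

```lua
-- Hypothetical lazy.nvim plugin spec wrapping the opts table above.
{
  "yetone/avante.nvim",
  opts = {
    provider = "my-custom-provider",
    vendors = {
      -- your AvanteProvider spec for "my-custom-provider" goes here
    },
  },
}
```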

A custom provider should conform to the following spec:

```lua
---@type AvanteProvider
{
endpoint = "https://api.openai.com/v1/chat/completions", -- The full endpoint of the provider
model = "gpt-4o", -- The model name to use with this provider
api_key_name = "OPENAI_API_KEY", -- The name of the environment variable that contains the API key
  --- This function is used to build the cURL arguments for the request.
  --- It takes the provider options as the first argument, followed by code_opts retrieved from the current buffer.
  --- code_opts contains:
  --- - question: the user's input
  --- - code_lang: the language of the current code buffer
  --- - code_content: the content of the code buffer
  --- - selected_code_content: (optional) the code selected in visual mode, if any
  ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
  parse_curl_args = function(opts, code_opts) end,
  --- This function is used to parse the incoming SSE stream.
  --- It takes the data stream as the first argument, followed by opts retrieved from the current buffer.
  --- opts contains:
  --- - on_chunk: (fun(chunk: string): any) invoked on each correctly parsed delta chunk
  --- - on_complete: (fun(err: string|nil): any) invoked when the stream completes or an error chunk arrives
  --- - event_state: the current SSE event state
  ---@type fun(data_stream: string, opts: ResponseParser): nil
  parse_response_data = function(data_stream, opts) end
}
```
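If your vendor speaks an OpenAI-compatible SSE dialect but you want to handle parsing yourself rather than delegate to a built-in helper, a minimal hand-rolled `parse_response_data` might look like the sketch below. The JSON field names (`choices[1].delta.content`) and the assumption that the `data:` prefix has already been stripped are both guesses, not part of the documented API:

```lua
--- Hypothetical parser for an OpenAI-style SSE stream (a sketch, not the
--- plugin's actual implementation).
parse_response_data = function(data_stream, opts)
  if data_stream == "[DONE]" then
    opts.on_complete(nil) -- stream finished cleanly
    return
  end
  local ok, json = pcall(vim.json.decode, data_stream)
  if not ok then
    opts.on_complete("failed to decode chunk: " .. data_stream)
    return
  end
  -- Assumed OpenAI-style chunk shape: { choices = { { delta = { content = "..." } } } }
  local choice = json.choices and json.choices[1]
  local content = choice and choice.delta and choice.delta.content
  if content and content ~= vim.NIL then
    opts.on_chunk(content)
  end
end,
```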

<details>
<summary>Full working example of perplexity</summary>

```lua
vendors = {
---@type AvanteProvider
perplexity = {
endpoint = "https://api.perplexity.ai/chat/completions",
model = "llama-3.1-sonar-large-128k-online",
api_key_name = "PPLX_API_KEY",
    --- builds the cURL arguments for the request
parse_curl_args = function(opts, code_opts)
local Llm = require "avante.llm"
return {
url = opts.endpoint,
headers = {
["Accept"] = "application/json",
["Content-Type"] = "application/json",
["Authorization"] = "Bearer " .. os.getenv(opts.api_key_name),
},
body = {
model = opts.model,
          messages = Llm.make_openai_message(code_opts), -- you can build your own messages table; this helper emits the OpenAI chat format
          temperature = 0,
          max_tokens = 8192,
          stream = true, -- set by default
},
}
end,
    -- Only needed if the vendor follows an SSE spec that differs from Claude's or OpenAI's.
parse_response_data = function(data_stream, opts)
local Llm = require "avante.llm"
Llm.parse_openai_response(data_stream, opts)
end,
},
},
```

</details>

## License

avante.nvim is licensed under the Apache License. For more details, please refer to the [LICENSE](./LICENSE) file.
11 changes: 10 additions & 1 deletion lua/avante/config.lua
@@ -6,7 +6,7 @@ local M = {}

---@class avante.Config
M.defaults = {
-  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
+  ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | [string]
provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
openai = {
endpoint = "https://api.openai.com",
@@ -39,6 +39,10 @@ M.defaults = {
temperature = 0,
max_tokens = 4096,
},
  --- To add support for custom providers, follow the format below
--- See https:/yetone/avante.nvim/README.md#custom-providers for more details
---@type table<string, AvanteProvider>
vendors = {},
behaviour = {
auto_apply_diff_after_generation = false, -- Whether to automatically apply diff after LLM response.
},
@@ -100,6 +104,11 @@ function M.setup(opts)
)
end

---@param opts? avante.Config
function M.override(opts)
M.options = vim.tbl_deep_extend("force", M.options, opts or {})
end
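Since `vim.tbl_deep_extend("force", ...)` gives the right-most table precedence, `M.override` merges a partial table over the current options, letting callers patch a single nested key at runtime. A hypothetical call (the key shown is only an example):

```lua
-- Hypothetical runtime override: only windows.width changes;
-- all other options keep their current values.
require("avante.config").override({
  windows = { width = 40 },
})
```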

M = setmetatable(M, {
__index = function(_, k)
if M.options[k] then
2 changes: 1 addition & 1 deletion lua/avante/init.lua
@@ -201,7 +201,7 @@ function M.setup(opts)
end

require("avante.diff").setup()
-require("avante.ai_bot").setup()
+require("avante.llm").setup()

-- setup helpers
H.autocmds()