feature: local llm support #16

Merged · 20 commits · Oct 9, 2023

Commits
f75a11e
refactor: creating an LLM module in libmemex and moving embedding mod…
a5huynh Sep 20, 2023
86b9a7c
adding a basic run prompt function for llm testing inside memex
a5huynh Sep 21, 2023
14baa52
wip: sherpa - guiding llms using logit biasing, templates, etc.
a5huynh Sep 21, 2023
ce15cd4
wip: tweaking sampler to output only what we want
a5huynh Sep 22, 2023
dc05694
Merge branch 'main' into feature/embedded-llm-support
a5huynh Oct 3, 2023
0bfc014
Merge branch 'main' into feature/embedded-llm-support
a5huynh Oct 4, 2023
224538a
updating .env.template file
a5huynh Oct 4, 2023
d6f6d65
removing sherpa stuff for now
a5huynh Oct 5, 2023
066baf3
Creating LLM trait and sharing structs between local LLM & OpenAI impls
a5huynh Oct 5, 2023
33fe800
Using LLM trait in API to switch between local/OpenAI when configured
a5huynh Oct 5, 2023
ec244d9
load llm client based on whether `OPENAI_API_KEY` or `LOCAL_LLM_CONFIG`
a5huynh Oct 6, 2023
3910cf1
update README to point that out
a5huynh Oct 6, 2023
6819059
Create samplers from config and pass into `LocalLLM` struct
a5huynh Oct 6, 2023
b1d4848
add basic support for other architectures, but focus on llama for now
a5huynh Oct 6, 2023
96130d1
add LLM ask example to README
a5huynh Oct 6, 2023
a1e7c6c
output debug messages for local llm responses
a5huynh Oct 9, 2023
e368e87
correctly capture server errors & send back the error messages
a5huynh Oct 9, 2023
3af2396
removing old dep
a5huynh Oct 9, 2023
d07b219
ignore model test
a5huynh Oct 9, 2023
13e2bb0
cargo fmt
a5huynh Oct 9, 2023
10 changes: 9 additions & 1 deletion .env.template
@@ -4,4 +4,12 @@ PORT=8181
# Use postgres for "production"
DATABASE_CONNECTION=sqlite://data/sqlite.db
# Use qdrant/etc. for "production"
VECTOR_CONNECTION=hnsw://data/vdb
VECTOR_CONNECTION=hnsw://data/vdb
# When using OpenSearch as the vector backend
# VECTOR_CONNECTION=opensearch+https://admin:admin@localhost:9200

# If using OpenAI, set up your API key here
OPENAI_API_KEY=
# Or point to a local LLM configuration file. By default, memex will use
# llama2
LOCAL_LLM_CONFIG=resources/config.llama2.toml
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

35 changes: 33 additions & 2 deletions README.md
@@ -23,7 +23,22 @@ since Linux ARM builds are very finicky.
2023-06-13T05:04:21.518732Z INFO memex: starting server with roles: [Api, Worker]
```

## Add a document
## Using an LLM
You can use either OpenAI or a local LLM for LLM-based functionality (such as the
summarization or extraction APIs).

Set `OPENAI_API_KEY` to your API key in the `.env` file, or set `LOCAL_LLM_CONFIG` to
an LLM configuration file. See `resources/config.llama2.toml` for an example. By
default, a base memex will use the llama-2 configuration file.

### Supported local models

The following models are currently supported (and have been tested):
- Llama-based models (Llama 1 & 2, Mistral, etc.) - *recommended*
- GPT-J (e.g. GPT4All)


## Adding a document

NOTE: If the `test` collection does not initially exist, it'll be created.

@@ -80,7 +95,7 @@ Or if it's finished, something like so:
Once the task is shown as "Completed", you can run a query against the doc(s)
you've just added.

## Run a query
## Run a search query

``` bash
> curl http://localhost:8181/api/collections/test/search \
@@ -100,6 +115,22 @@ you've just added.
}
```

## Ask a question
```bash
> curl http://localhost:8181/api/action/ask \
-H "Content-Type: application/json" \
-X POST \
-d "{\"text\": \"<context if any>\", \"query\": \"What is the airspeed velocity of an unladen swallow?\", "json_schema": { .. }}"
{
"time": 1.234,
"status": "ok",
"result": {
"answer": "The airspeed velocity of an unladen swallow is..."
}
}

```

## Env variables

- `HOST`: Defaults to `127.0.0.1`
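
A quick orientation for readers of this diff: the "Using an LLM" section above only documents the two environment variables, while the backend selection itself lands in the `lib/api/src/lib.rs` changes further down. Condensed into a standalone sketch (the surrounding server setup is omitted, and `load_from_cfg`'s exact return type is inferred from the diff), it looks roughly like this:

```rust
use std::sync::Arc;

use libmemex::llm::{local::load_from_cfg, openai::OpenAIClient, LLM};

// Condensed from the lib/api/src/lib.rs changes in this PR; error handling
// and the rest of ApiConfig are omitted for brevity.
async fn pick_llm(
    open_ai_key: Option<String>,
    local_llm_config: Option<String>,
) -> Arc<Box<dyn LLM>> {
    if let Some(key) = open_ai_key {
        // Hosted backend, keyed by OPENAI_API_KEY.
        Arc::new(Box::new(OpenAIClient::new(&key)))
    } else if let Some(path) = local_llm_config {
        // Local backend: the model and sampler settings come from the TOML config.
        let llm = load_from_cfg(path.into(), true)
            .await
            .expect("Unable to load local LLM");
        Arc::new(llm)
    } else {
        panic!("Please set OPENAI_API_KEY or LOCAL_LLM_CONFIG");
    }
}
```
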
19 changes: 18 additions & 1 deletion bin/memex/src/main.rs
@@ -1,3 +1,4 @@
use api::ApiConfig;
use clap::{Parser, Subcommand};
use futures::future::join_all;
use std::{net::Ipv4Addr, process::ExitCode};
@@ -25,6 +26,10 @@ pub struct Args {
database_connection: Option<String>,
#[clap(long, value_parser, value_name = "VECTOR_CONNECTION", env)]
vector_connection: Option<String>,
#[clap(long, value_parser, value_name = "OPENAI_API_KEY", env)]
openai_api_key: Option<String>,
#[clap(long, value_parser, value_name = "LOCAL_LLM_CONFIG", env)]
local_llm_config: Option<String>,
}

#[derive(Debug, Display, Clone, PartialEq, EnumString)]
@@ -100,9 +105,21 @@ async fn main() -> ExitCode {

let _vector_store_uri = args.vector_connection.expect("VECTOR_CONNECTION not set");

if args.openai_api_key.is_none() && args.local_llm_config.is_none() {
log::error!("Must set either OPENAI_API_KEY or LOCAL_LLM_CONFIG");
return ExitCode::FAILURE;
}

if roles.contains(&Roles::Api) {
let db_uri = db_uri.clone();
handles.push(tokio::spawn(api::start(host, port, db_uri)));
let cfg = ApiConfig {
host,
port,
db_uri,
open_ai_key: args.openai_api_key,
local_llm_config: args.local_llm_config,
};
handles.push(tokio::spawn(api::start(cfg)));
}

if roles.contains(&Roles::Worker) {
8 changes: 5 additions & 3 deletions lib/api/src/endpoints/actions/filters.rs
@@ -1,5 +1,7 @@
use std::sync::Arc;

use crate::{endpoints::json_body, with_db, with_llm};
use libmemex::llm::openai::OpenAIClient;
use libmemex::llm::LLM;
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -24,7 +26,7 @@ pub struct SummarizeRequest {
}

fn extract(
llm: &OpenAIClient,
llm: &Arc<Box<dyn LLM>>,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
warp::path!("action" / "ask")
.and(warp::post())
@@ -44,7 +46,7 @@ fn summarize(
}

pub fn build(
llm: &OpenAIClient,
llm: &Arc<Box<dyn LLM>>,
db: &DatabaseConnection,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
extract(llm).or(summarize(db))
14 changes: 7 additions & 7 deletions lib/api/src/endpoints/actions/handlers.rs
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::{
schema::{ApiResponse, TaskResult},
ServerError,
@@ -9,19 +11,16 @@ use warp::reject::Rejection;
use super::filters;
use libmemex::{
db::queue,
llm::{
openai::{truncate_text, OpenAIClient},
prompter,
},
llm::{prompter, LLM},
};

pub async fn handle_extract(
llm: OpenAIClient,
llm: Arc<Box<dyn LLM>>,
request: filters::AskRequest,
) -> Result<impl warp::Reply, Rejection> {
let time = std::time::Instant::now();

let (content, model) = truncate_text(&request.text);
let (content, model) = llm.truncate_text(&request.text);

// Build prompt
let prompt = if let Some(schema) = &request.json_schema {
@@ -34,10 +33,11 @@ pub async fn handle_extract(
};

let response = llm
.chat_completion(&model, &prompt)
.chat_completion(model.as_ref(), &prompt)
.await
.map_err(|err| ServerError::Other(err.to_string()))?;

log::debug!("llm response: {response}");
let val = serde_json::from_str::<serde_json::Value>(&response)
.map_err(|err| ServerError::Other(err.to_string()))?;

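
The handlers above call `llm.truncate_text(...)` and `llm.chat_completion(...).await` through the new `Arc<Box<dyn LLM>>` handle, but the trait definition itself (in `libmemex::llm`) is not part of the rendered diff. Reconstructed purely from those call sites, its shape is roughly the sketch below; the exact signatures and the use of `async-trait` are assumptions, not the project's actual code.

```rust
use anyhow::Result;
use async_trait::async_trait;

/// Assumed shape of `libmemex::llm::LLM`, inferred from the call sites in
/// lib/api/src/endpoints/actions/handlers.rs above.
#[async_trait]
pub trait LLM: Send + Sync {
    /// Trim the input to the backing model's context window and report which
    /// model was selected (the handler calls `.as_ref()` on the second
    /// element before passing it on, so a String is assumed here).
    fn truncate_text(&self, text: &str) -> (String, String);

    /// Run a chat-style completion and return the raw response text.
    async fn chat_completion(&self, model: &str, prompt: &str) -> Result<String>;
}
```

Per the commit messages, both `OpenAIClient` and the new local LLM type implement this trait, which is what lets the API pick either backend behind the same `Arc<Box<dyn LLM>>` handle.
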
2 changes: 1 addition & 1 deletion lib/api/src/endpoints/collections/handlers.rs
@@ -4,7 +4,7 @@ use crate::{
};
use libmemex::{
db::{embedding, queue},
embedding::{ModelConfig, SentenceEmbedder},
llm::embedding::{ModelConfig, SentenceEmbedder},
storage::get_vector_storage,
};
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter};
6 changes: 4 additions & 2 deletions lib/api/src/endpoints/mod.rs
@@ -1,4 +1,6 @@
use libmemex::llm::openai::OpenAIClient;
use std::sync::Arc;

use libmemex::llm::LLM;
use sea_orm::DatabaseConnection;
use serde::de::DeserializeOwned;
use warp::Filter;
@@ -24,7 +26,7 @@

pub fn build(
db: &DatabaseConnection,
llm: &OpenAIClient,
llm: &Arc<Box<dyn LLM>>,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
actions::filters::build(llm, db)
.or(collections::filters::build(db))
48 changes: 37 additions & 11 deletions lib/api/src/lib.rs
@@ -1,8 +1,11 @@
use dotenv_codegen::dotenv;
use libmemex::{db::create_connection_by_uri, llm::openai::OpenAIClient};
use libmemex::{
db::create_connection_by_uri,
llm::{local::load_from_cfg, openai::OpenAIClient, LLM},
};
use sea_orm::DatabaseConnection;
use serde_json::json;
use std::{convert::Infallible, net::Ipv4Addr, path::PathBuf};
use std::{convert::Infallible, net::Ipv4Addr, path::PathBuf, sync::Arc};
use thiserror::Error;
use warp::{hyper::StatusCode, reject::Reject, Filter, Rejection, Reply};

@@ -22,6 +25,14 @@ pub enum ServerError {

impl Reject for ServerError {}

pub struct ApiConfig {
pub host: Ipv4Addr,
pub port: u16,
pub db_uri: String,
pub open_ai_key: Option<String>,
pub local_llm_config: Option<String>,
}

// Handle custom errors/rejections
async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> {
let code;
@@ -35,6 +46,12 @@ async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> {
// and render it however we want
code = StatusCode::METHOD_NOT_ALLOWED;
message = "METHOD_NOT_ALLOWED".into();
} else if let Some(err) = err.find::<ServerError>() {
(code, message) = match err {
ServerError::ClientRequestError(err) => (StatusCode::BAD_REQUEST, err.to_string()),
ServerError::DatabaseError(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
ServerError::Other(err) => (StatusCode::BAD_REQUEST, err.to_string()),
};
} else {
// We should have expected this... Just log and say its a 500
eprintln!("unhandled rejection: {:?}", err);
@@ -59,8 +76,8 @@ pub fn health_check() -> impl Filter<Extract = (impl warp::Reply,), Error = warp
.map(move || warp::reply::json(&json!({ "version": version })))
}

pub async fn start(host: Ipv4Addr, port: u16, db_uri: String) {
log::info!("starting api server @ {}:{}", host, port);
pub async fn start(config: ApiConfig) {
log::info!("starting api server @ {}:{}", config.host, config.port);

log::info!("checking for upload directory...");
let data_dir_path: PathBuf = endpoints::UPLOAD_DATA_DIR.into();
@@ -70,12 +87,21 @@ pub async fn start(host: Ipv4Addr, port: u16, db_uri: String) {
}

// Attempt to connect to db
let db_connection = create_connection_by_uri(&db_uri, true)
let db_connection = create_connection_by_uri(&config.db_uri, true)
.await
.unwrap_or_else(|err| panic!("Unable to connect to database: {} - {err}", db_uri));
.unwrap_or_else(|err| panic!("Unable to connect to database: {} - {err}", config.db_uri));

let llm_client: Arc<Box<dyn LLM>> = if let Some(openai_key) = config.open_ai_key {
Arc::new(Box::new(OpenAIClient::new(&openai_key)))
} else if let Some(llm_config_path) = config.local_llm_config {
let llm = load_from_cfg(llm_config_path.into(), true)
.await
.expect("Unable to load local LLM");
Arc::new(llm)
} else {
panic!("Please setup OPENAI_API_KEY or LOCAL_LLM_CONFIG");
};

let llm_client =
OpenAIClient::new(&std::env::var("OPENAI_API_KEY").expect("OpenAI API key not set"));
let cors = warp::cors()
.allow_any_origin()
.allow_methods(vec!["GET", "POST", "PUT", "PATCH", "DELETE"])
@@ -88,7 +114,7 @@ pub async fn start(host: Ipv4Addr, port: u16, db_uri: String) {
let filters = health_check().or(api).with(cors).recover(handle_rejection);

let (_addr, handle) =
warp::serve(filters).bind_with_graceful_shutdown((host, port), async move {
warp::serve(filters).bind_with_graceful_shutdown((config.host, config.port), async move {
tokio::signal::ctrl_c()
.await
.expect("failed to listen to shutdown signal");
@@ -105,7 +131,7 @@ pub fn with_db(
}

pub fn with_llm(
llm: OpenAIClient,
) -> impl Filter<Extract = (OpenAIClient,), Error = std::convert::Infallible> + Clone {
llm: Arc<Box<dyn LLM>>,
) -> impl Filter<Extract = (Arc<Box<dyn LLM>>,), Error = std::convert::Infallible> + Clone {
warp::any().map(move || llm.clone())
}
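
As a usage note, the `with_llm` helper added above follows warp's standard pattern of cloning shared state into each request. The sketch below shows how such a filter typically feeds a handler; it is illustrative wiring only (the PR's real routes live in `lib/api/src/endpoints/` and are partially collapsed in this diff), and `handle_answer` plus the single-method stand-in trait are hypothetical.

```rust
use std::{convert::Infallible, sync::Arc};
use warp::Filter;

// Stand-in for libmemex::llm::LLM so this sketch compiles on its own.
#[async_trait::async_trait]
pub trait LLM: Send + Sync {
    async fn chat_completion(&self, model: &str, prompt: &str) -> anyhow::Result<String>;
}

// Same shape as the `with_llm` added in this file: clone the Arc per request
// while sharing the boxed trait object underneath.
fn with_llm(
    llm: Arc<Box<dyn LLM>>,
) -> impl Filter<Extract = (Arc<Box<dyn LLM>>,), Error = Infallible> + Clone {
    warp::any().map(move || llm.clone())
}

// Hypothetical handler; the real ones are in endpoints/actions/handlers.rs.
async fn handle_answer(
    llm: Arc<Box<dyn LLM>>,
    prompt: String,
) -> Result<impl warp::Reply, warp::Rejection> {
    match llm.chat_completion("llama-2", &prompt).await {
        Ok(answer) => Ok(warp::reply::json(&serde_json::json!({ "answer": answer }))),
        // The real handlers map errors into ServerError; a bare rejection keeps this short.
        Err(_) => Err(warp::reject::reject()),
    }
}

fn ask_route(
    llm: Arc<Box<dyn LLM>>,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
    warp::path!("action" / "ask")
        .and(warp::post())
        .and(with_llm(llm))
        .and(warp::body::json::<String>())
        .and_then(handle_answer)
}
```
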
6 changes: 5 additions & 1 deletion lib/libmemex/Cargo.toml
@@ -12,20 +12,24 @@ chrono = { workspace = true }
dotenv = { workspace = true }
handlebars = "4.4.0"
hnsw_rs = { git = "https:/jean-pierreBoth/hnswlib-rs", rev = "52a7f9174e002820d168fa65ca7303364ee3ac33" }
llm = { git = "https:/rustformers/llm.git", rev = "84800b02a7a96f62c0c9c03a38c36cb23bf4b2ec" }
log = { workspace = true }
migration ={ path = "../../migration" }
opensearch = "2.1.0"
qdrant-client = "1.2.0"
reqwest = { version = "0.11.16", features = ["stream" ] }
rand = "0.8.5"
rust-bert = { version = "0.21.0", features= ["download-libtorch"] }
sea-orm = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
strum = "0.25"
strum_macros = "0.25"
tera = "1.19.0"
thiserror = "1.0"
tiktoken-rs = "0.5.4"
tokenizers = { version = "0.14", features = ["http"] }
tokio = { workspace = true }
toml = "0.7.4"
url = "2.4.0"
uuid = { workspace = true }
uuid = { workspace = true }
1 change: 0 additions & 1 deletion lib/libmemex/src/lib.rs
@@ -1,5 +1,4 @@
pub mod db;
pub mod embedding;
pub mod llm;
pub mod storage;

File renamed without changes.