From f75a11e4bd7b0c4d0bc441cfb9bba3d1d6bba2c0 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Wed, 20 Sep 2023 14:04:34 +0200 Subject: [PATCH 01/18] refactor: creating an LLM module in libmemex and moving embedding module inside --- Cargo.lock | 3 +++ lib/api/src/endpoints/collections/handlers.rs | 2 +- lib/libmemex/Cargo.toml | 3 +++ lib/libmemex/src/lib.rs | 1 - lib/libmemex/src/{ => llm}/embedding.rs | 0 lib/libmemex/src/llm/mod.rs | 1 + lib/worker/src/lib.rs | 2 +- 7 files changed, 9 insertions(+), 3 deletions(-) rename lib/libmemex/src/{ => llm}/embedding.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 00f762f..609cca3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2151,6 +2151,7 @@ dependencies = [ "dotenv", "handlebars", "hnsw_rs", + "llm", "log", "migration", "opensearch", @@ -2162,9 +2163,11 @@ dependencies = [ "serde_json", "strum", "strum_macros", + "tera", "thiserror", "tokenizers 0.14.0", "tokio", + "toml", "url", "uuid", ] diff --git a/lib/api/src/endpoints/collections/handlers.rs b/lib/api/src/endpoints/collections/handlers.rs index b973ee6..4690b40 100644 --- a/lib/api/src/endpoints/collections/handlers.rs +++ b/lib/api/src/endpoints/collections/handlers.rs @@ -4,7 +4,7 @@ use crate::{ }; use libmemex::{ db::{embedding, queue}, - embedding::{ModelConfig, SentenceEmbedder}, + llm::embedding::{ModelConfig, SentenceEmbedder}, storage::get_vector_storage, }; use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter}; diff --git a/lib/libmemex/Cargo.toml b/lib/libmemex/Cargo.toml index f08d7cb..70849c7 100644 --- a/lib/libmemex/Cargo.toml +++ b/lib/libmemex/Cargo.toml @@ -12,6 +12,7 @@ chrono = { workspace = true } dotenv = { workspace = true } handlebars = "4.4.0" hnsw_rs = { git = "https://github.com/jean-pierreBoth/hnswlib-rs", rev = "52a7f9174e002820d168fa65ca7303364ee3ac33" } +llm = { git = "https://github.com/rustformers/llm.git", rev = "84800b02a7a96f62c0c9c03a38c36cb23bf4b2ec" } log = { workspace = true } migration ={ path = "../../migration" } opensearch = "2.1.0" @@ -23,8 +24,10 @@ serde = { workspace = true } serde_json = { workspace = true } strum = "0.25" strum_macros = "0.25" +tera = "1.19.0" thiserror = "1.0" tokenizers = { version = "0.14", features = ["http"] } tokio = { workspace = true } +toml = "0.7.4" url = "2.4.0" uuid = { workspace = true } diff --git a/lib/libmemex/src/lib.rs b/lib/libmemex/src/lib.rs index 301a345..7cf7c03 100644 --- a/lib/libmemex/src/lib.rs +++ b/lib/libmemex/src/lib.rs @@ -1,5 +1,4 @@ pub mod db; -pub mod embedding; pub mod llm; pub mod storage; diff --git a/lib/libmemex/src/embedding.rs b/lib/libmemex/src/llm/embedding.rs similarity index 100% rename from lib/libmemex/src/embedding.rs rename to lib/libmemex/src/llm/embedding.rs diff --git a/lib/libmemex/src/llm/mod.rs b/lib/libmemex/src/llm/mod.rs index eb9526a..eabd3f1 100644 --- a/lib/libmemex/src/llm/mod.rs +++ b/lib/libmemex/src/llm/mod.rs @@ -1,2 +1,3 @@ +pub mod embedding; pub mod openai; pub mod prompter; diff --git a/lib/worker/src/lib.rs b/lib/worker/src/lib.rs index 6c52fd2..fcb700c 100644 --- a/lib/worker/src/lib.rs +++ b/lib/worker/src/lib.rs @@ -1,7 +1,7 @@ use libmemex::db::create_connection_by_uri; use libmemex::db::queue::{self, check_for_jobs, Job, TaskType}; use libmemex::db::{document, embedding}; -use libmemex::embedding::{ModelConfig, SentenceEmbedder}; +use libmemex::llm::embedding::{ModelConfig, SentenceEmbedder}; use libmemex::storage::{get_vector_storage, VectorData, VectorStorage}; use libmemex::NAMESPACE; use sea_orm::prelude::*; From 
86b9a7c88e8ecc18fc676a88cf440c31cbf3ed97 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Thu, 21 Sep 2023 16:40:53 +0200 Subject: [PATCH 02/18] adding a basic run prompt function for llm testing inside memex --- Cargo.lock | 1 + lib/libmemex/Cargo.toml | 1 + lib/libmemex/src/llm/mod.rs | 70 +++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 609cca3..ca807c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2156,6 +2156,7 @@ dependencies = [ "migration", "opensearch", "qdrant-client", + "rand 0.8.5", "reqwest", "rust-bert", "sea-orm", diff --git a/lib/libmemex/Cargo.toml b/lib/libmemex/Cargo.toml index 70849c7..3465387 100644 --- a/lib/libmemex/Cargo.toml +++ b/lib/libmemex/Cargo.toml @@ -18,6 +18,7 @@ migration ={ path = "../../migration" } opensearch = "2.1.0" qdrant-client = "1.2.0" reqwest = { version = "0.11.16", features = ["stream" ] } +rand = "0.8.5" rust-bert = { version = "0.21.0", features= ["download-libtorch"] } sea-orm = { workspace = true } serde = { workspace = true } diff --git a/lib/libmemex/src/llm/mod.rs b/lib/libmemex/src/llm/mod.rs index eabd3f1..f502f20 100644 --- a/lib/libmemex/src/llm/mod.rs +++ b/lib/libmemex/src/llm/mod.rs @@ -1,3 +1,73 @@ +use std::path::PathBuf; + +use llm::{self, InferenceSessionConfig, KnownModel, LoadProgress}; +use tokio::sync::mpsc; + pub mod embedding; pub mod openai; pub mod prompter; + +#[derive(Debug)] +pub enum LlmEvent { + ModelLoadProgress(LoadProgress), + TokenReceived(String), + InferenceDone, +} + +pub fn run_prompt(prompt: &str) -> anyhow::Result<()> { + println!("Prompting..."); + let model_params = llm::ModelParameters::default(); + let infer_params = llm::InferenceParameters::default(); + + let prompt_request = llm::InferenceRequest { + prompt: llm::Prompt::Text(prompt), + maximum_token_count: None, + parameters: &infer_params, + play_back_previous_tokens: false, + }; + + let model_path: PathBuf = + "../../resources/models/LLaMa2/llama-2-7b-chat.ggmlv3.q4_1.bin".into(); + let model = match llm::load::( + &model_path, + llm::TokenizerSource::Embedded, + model_params, + move |_| {}, + ) { + Ok(model) => model, + Err(err) => return Err(anyhow::anyhow!("Unable to load model: {err}")), + }; + + let config = InferenceSessionConfig::default(); + let mut session = model.start_session(config); + + let (sender, _receiver) = mpsc::unbounded_channel::(); + + print!("Prompt: {}", prompt); + let _res = session.infer::( + &model, + &mut rand::thread_rng(), + &prompt_request, + &mut Default::default(), + move |t| { + match t { + llm::InferenceResponse::InferredToken(token) => { + print!("{}", token); + if sender.send(LlmEvent::TokenReceived(token)).is_err() { + return Ok(llm::InferenceFeedback::Halt); + } + } + llm::InferenceResponse::EotToken => { + if sender.send(LlmEvent::InferenceDone).is_err() { + return Ok(llm::InferenceFeedback::Halt); + } + } + _ => {} + } + + Ok(llm::InferenceFeedback::Continue) + }, + )?; + + Ok(()) +} From 14baa525dfb3f9c84f56e6db7a4b9675866f695a Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Thu, 21 Sep 2023 16:41:15 +0200 Subject: [PATCH 03/18] wip: sherpa - guiding llms using logit biasing, templates, etc. 
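Rough sketch of the idea (not the final sherpa API): before each sampling step,
apply a flat logit bias so tokens we never want to emit cannot be picked. The
sampler wiring below mirrors the `llm` crate types used in the following
patches (SampleFlatBias, SamplerChain, ConfiguredSamplers); the `flat_bias`
helper and the `allowed` predicate are purely illustrative.

    use llm::samplers::llm_samplers::samplers::SampleFlatBias;
    use llm::samplers::llm_samplers::types::SamplerChain;
    use llm::samplers::ConfiguredSamplers;

    // Illustrative helper: push every disallowed token id to -inf,
    // leave allowed ids untouched (bias 0.0 adds nothing to the logit).
    fn flat_bias(vocab_len: usize, allowed: impl Fn(u32) -> bool) -> Vec<(u32, f32)> {
        (0..vocab_len as u32)
            .map(|id| (id, if allowed(id) { 0.0 } else { f32::NEG_INFINITY }))
            .collect()
    }

    fn build_samplers(vocab_len: usize) -> SamplerChain {
        // The bias mask goes first so the default samplers only ever
        // see tokens we are willing to emit.
        let mut chain = SamplerChain::new();
        chain += SampleFlatBias::new(flat_bias(vocab_len, |_id| true));

        let mut defaults = ConfiguredSamplers::default();
        defaults.ensure_default_slots();
        chain += defaults.builder.into_chain();
        chain
    }

The chain is then handed to InferenceParameters as its sampler, as done in the
diff below.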
--- Cargo.lock | 52 ++++++++++++++++++++++++++++++++++++++++--- Cargo.toml | 1 + lib/sherpa/Cargo.toml | 23 +++++++++++++++++++ lib/sherpa/src/lib.rs | 11 +++++++++ 4 files changed, 84 insertions(+), 3 deletions(-) create mode 100644 lib/sherpa/Cargo.toml create mode 100644 lib/sherpa/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index ca807c1..0f26c14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -694,7 +694,7 @@ dependencies = [ "tera", "tokenizers 0.14.0", "tokio", - "toml", + "toml 0.7.6", ] [[package]] @@ -2168,7 +2168,7 @@ dependencies = [ "thiserror", "tokenizers 0.14.0", "tokio", - "toml", + "toml 0.7.6", "url", "uuid", ] @@ -3966,6 +3966,27 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "sherpa" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "llm", + "log", + "rand 0.8.5", + "rust-bert", + "serde", + "serde_json", + "strum", + "strum_macros", + "tera", + "thiserror", + "tokenizers 0.14.0", + "tokio", + "toml 0.8.0", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -4714,7 +4735,19 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit", + "toml_edit 0.19.14", +] + +[[package]] +name = "toml" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c226a7bba6d859b63c92c4b4fe69c5b6b72d0cb897dbc8e6012298e6154cb56e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit 0.20.0", ] [[package]] @@ -4739,6 +4772,19 @@ dependencies = [ "winnow", ] +[[package]] +name = "toml_edit" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ff63e60a958cefbb518ae1fd6566af80d9d4be430a33f3723dfc47d1d411d95" +dependencies = [ + "indexmap 2.0.0", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tonic" version = "0.9.2" diff --git a/Cargo.toml b/Cargo.toml index af394d6..79cb56f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "lib/api", "lib/libmemex", + "lib/sherpa", "lib/worker", "examples/clippy" diff --git a/lib/sherpa/Cargo.toml b/lib/sherpa/Cargo.toml new file mode 100644 index 0000000..a296f7f --- /dev/null +++ b/lib/sherpa/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "sherpa" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0" +async-trait = { workspace = true } +llm = { git = "https://github.com/rustformers/llm.git", rev = "84800b02a7a96f62c0c9c03a38c36cb23bf4b2ec" } +log = { workspace = true } +rand = "0.8.5" +rust-bert = { version = "0.21.0", features= ["download-libtorch"] } +serde = { workspace = true } +serde_json = { workspace = true } +strum = "0.25" +strum_macros = "0.25" +tera = "1.19.0" +thiserror = "1.0" +tokenizers = { version = "0.14", features = ["http"] } +tokio = { workspace = true } +toml = "0.8.0" diff --git a/lib/sherpa/src/lib.rs b/lib/sherpa/src/lib.rs new file mode 100644 index 0000000..7b5d6fc --- /dev/null +++ b/lib/sherpa/src/lib.rs @@ -0,0 +1,11 @@ +pub async fn create_mask() -> anyhow::Result<()> { + Ok(()) +} + +#[cfg(test)] +mod test { + #[tokio::test] + async fn test_logit_biasing() { + super::create_mask().await.expect("Unable to create mask"); + } +} From ce15cd48e45e0ac5a563dda28e88ba2c9026737a Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Fri, 22 Sep 2023 14:37:46 +0200 Subject: [PATCH 04/18] wip: tweaking sampler to output only what we want --- Cargo.lock | 2 + lib/libmemex/Cargo.toml | 3 +- 
lib/libmemex/src/llm/mod.rs | 186 ++++++++++++++++++++++++++++-------- lib/sherpa/Cargo.toml | 1 + lib/sherpa/src/lib.rs | 51 +++++++++- 5 files changed, 198 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f26c14..85f7994 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2162,6 +2162,7 @@ dependencies = [ "sea-orm", "serde", "serde_json", + "sherpa", "strum", "strum_macros", "tera", @@ -3973,6 +3974,7 @@ dependencies = [ "anyhow", "async-trait", "llm", + "llm-base", "log", "rand 0.8.5", "rust-bert", diff --git a/lib/libmemex/Cargo.toml b/lib/libmemex/Cargo.toml index 3465387..5f8e9fb 100644 --- a/lib/libmemex/Cargo.toml +++ b/lib/libmemex/Cargo.toml @@ -23,6 +23,7 @@ rust-bert = { version = "0.21.0", features= ["download-libtorch"] } sea-orm = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +sherpa = { path = "../sherpa" } strum = "0.25" strum_macros = "0.25" tera = "1.19.0" @@ -31,4 +32,4 @@ tokenizers = { version = "0.14", features = ["http"] } tokio = { workspace = true } toml = "0.7.4" url = "2.4.0" -uuid = { workspace = true } +uuid = { workspace = true } \ No newline at end of file diff --git a/lib/libmemex/src/llm/mod.rs b/lib/libmemex/src/llm/mod.rs index f502f20..887a77c 100644 --- a/lib/libmemex/src/llm/mod.rs +++ b/lib/libmemex/src/llm/mod.rs @@ -1,6 +1,9 @@ +use llm::samplers::llm_samplers::samplers::SampleFlatBias; +use llm::samplers::llm_samplers::types::SamplerChain; +use llm::{self, samplers::ConfiguredSamplers, InferenceSessionConfig, LoadProgress}; +use std::io::Write; use std::path::PathBuf; - -use llm::{self, InferenceSessionConfig, KnownModel, LoadProgress}; +use std::sync::{Arc, Mutex}; use tokio::sync::mpsc; pub mod embedding; @@ -14,20 +17,135 @@ pub enum LlmEvent { InferenceDone, } -pub fn run_prompt(prompt: &str) -> anyhow::Result<()> { - println!("Prompting..."); - let model_params = llm::ModelParameters::default(); - let infer_params = llm::InferenceParameters::default(); +// Run model and apply rules to generation. +async fn run_guidance( + model: &T, + default_samplers: SamplerChain, + prompt: &str, +) -> anyhow::Result<()> +where + T: llm::KnownModel, +{ + // Create a mask based on the current tokens generated and + let token_bias = sherpa::create_mask(model.tokenizer(), false).expect("Unable to create mask"); + let mut bias_sampler = SampleFlatBias::new(token_bias); + + // Create sampler chain + let mut samplers = SamplerChain::new(); + samplers += bias_sampler.clone(); + samplers += default_samplers; + + let infer_params = llm::InferenceParameters { + sampler: Arc::new(Mutex::new(samplers)), // sampler: llm::samplers::default_samplers(), + }; + + // Add our biasing mask to the sampler chain. 
+ let mut num_tokens = 0; + let buffer = Arc::new(Mutex::new(prompt.to_string())); + let tokens = Arc::new(Mutex::new(Vec::new())); + + let (sender, mut receiver) = mpsc::unbounded_channel::(); - let prompt_request = llm::InferenceRequest { - prompt: llm::Prompt::Text(prompt), - maximum_token_count: None, - parameters: &infer_params, - play_back_previous_tokens: false, + let writer_handle = { + let buffer = buffer.clone(); + let tokens = tokens.clone(); + tokio::spawn(async move { + loop { + if let Some(event) = receiver.recv().await { + match &event { + LlmEvent::TokenReceived(token) => { + if let Ok(mut buff) = buffer.lock() { + *buff += token; + } + + if let Ok(mut tokens) = tokens.lock() { + tokens.push(token.to_string()); + } + } + LlmEvent::InferenceDone => { + std::io::stdout().flush().unwrap(); + return; + } + _ => {} + } + } + } + }) }; + { + let config = InferenceSessionConfig::default(); + let buffer = buffer.clone(); + loop { + let prompt = if let Ok(buff) = buffer.lock() { + buff.to_string() + } else { + continue; + }; + println!("processing prompt: {}", prompt); + let infer_params = infer_params.clone(); + let sender = sender.clone(); + + let prompt_request = llm::InferenceRequest { + prompt: llm::Prompt::Text(&prompt), + maximum_token_count: Some(1), + parameters: &infer_params, + play_back_previous_tokens: false, + }; + + let mut session = model.start_session(config); + let channel = sender.clone(); + let _res = session + .infer::( + model, + &mut rand::thread_rng(), + &prompt_request, + &mut Default::default(), + move |t| { + match t { + llm::InferenceResponse::InferredToken(token) => { + if channel.send(LlmEvent::TokenReceived(token)).is_err() { + return Ok(llm::InferenceFeedback::Halt); + } + } + llm::InferenceResponse::EotToken => { + if channel.send(LlmEvent::InferenceDone).is_err() { + return Ok(llm::InferenceFeedback::Halt); + } + } + _ => {} + } + + Ok(llm::InferenceFeedback::Continue) + }, + ) + .expect("Unable"); + + let _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + num_tokens += 1; + *bias_sampler = + sherpa::create_mask(model.tokenizer(), false).expect("Unable to create mask"); + if num_tokens >= 3 { + let _ = sender.send(LlmEvent::InferenceDone); + break; + } + } + } + + let _ = writer_handle.await; + println!("result: {}", buffer.lock().unwrap()); + println!("tokens generated: {:?}", tokens.lock().unwrap()); + println!("🛑"); + Ok(()) +} + +pub async fn run_prompt(prompt: &str) -> anyhow::Result<()> { + println!("Prompting..."); + + // Load model let model_path: PathBuf = "../../resources/models/LLaMa2/llama-2-7b-chat.ggmlv3.q4_1.bin".into(); + let model_params = llm::ModelParameters::default(); let model = match llm::load::( &model_path, llm::TokenizerSource::Embedded, @@ -38,36 +156,22 @@ pub fn run_prompt(prompt: &str) -> anyhow::Result<()> { Err(err) => return Err(anyhow::anyhow!("Unable to load model: {err}")), }; - let config = InferenceSessionConfig::default(); - let mut session = model.start_session(config); - - let (sender, _receiver) = mpsc::unbounded_channel::(); - - print!("Prompt: {}", prompt); - let _res = session.infer::( - &model, - &mut rand::thread_rng(), - &prompt_request, - &mut Default::default(), - move |t| { - match t { - llm::InferenceResponse::InferredToken(token) => { - print!("{}", token); - if sender.send(LlmEvent::TokenReceived(token)).is_err() { - return Ok(llm::InferenceFeedback::Halt); - } - } - llm::InferenceResponse::EotToken => { - if sender.send(LlmEvent::InferenceDone).is_err() { - return 
Ok(llm::InferenceFeedback::Halt); - } - } - _ => {} - } + // Configure samplers + let mut samplers = ConfiguredSamplers::default(); + samplers.ensure_default_slots(); + + run_guidance(&model, samplers.builder.into_chain(), prompt).await +} - Ok(llm::InferenceFeedback::Continue) - }, - )?; +#[cfg(test)] +mod test { + use super::run_prompt; - Ok(()) + #[tokio::test] + async fn test_structured_prompt() { + run_prompt("i see london, i see france, i see ") + .await + .expect("Unable to prompt"); + println!(""); + } } diff --git a/lib/sherpa/Cargo.toml b/lib/sherpa/Cargo.toml index a296f7f..3a8e19a 100644 --- a/lib/sherpa/Cargo.toml +++ b/lib/sherpa/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" anyhow = "1.0" async-trait = { workspace = true } llm = { git = "https://github.com/rustformers/llm.git", rev = "84800b02a7a96f62c0c9c03a38c36cb23bf4b2ec" } +llm-base = { git = "https://github.com/rustformers/llm.git", rev = "84800b02a7a96f62c0c9c03a38c36cb23bf4b2ec" } log = { workspace = true } rand = "0.8.5" rust-bert = { version = "0.21.0", features= ["download-libtorch"] } diff --git a/lib/sherpa/src/lib.rs b/lib/sherpa/src/lib.rs index 7b5d6fc..efb56e9 100644 --- a/lib/sherpa/src/lib.rs +++ b/lib/sherpa/src/lib.rs @@ -1,11 +1,56 @@ -pub async fn create_mask() -> anyhow::Result<()> { - Ok(()) +use llm_base::{TokenId, Tokenizer}; + +pub type TokenMask = Vec<(TokenId, f32)>; + +pub fn create_mask(tokens: &Tokenizer, only_numbers: bool) -> anyhow::Result { + let num_tokens = tokens.len(); + + let mut mask: Vec<(TokenId, f32)> = Vec::new(); + + for idx in 0..num_tokens { + let token = tokens.token(idx); + let token_str = String::from_utf8_lossy(&token); + + let bias = if only_numbers { + if token_str.is_ascii() && token_str.parse::().is_ok() { + 1.0 + } else { + f32::NEG_INFINITY + } + } else if token_str.is_ascii() + && !token_str.starts_with(' ') + && token_str.parse::().is_err() + { + 1.0 + } else { + f32::NEG_INFINITY + }; + + mask.push((idx as u32, bias)); + } + + Ok(mask) } #[cfg(test)] mod test { + use std::path::PathBuf; + + use llm_base::KnownModel; + #[tokio::test] async fn test_logit_biasing() { - super::create_mask().await.expect("Unable to create mask"); + let model_path: PathBuf = + "../../resources/models/LLaMa2/llama-2-7b-chat.ggmlv3.q4_1.bin".into(); + + let model = llm::load::( + &model_path, + llm::TokenizerSource::Embedded, + llm::ModelParameters::default(), + move |_| {}, + ) + .expect("Unable to load model"); + + super::create_mask(&model.tokenizer(), false).expect("Unable to create mask"); } } From 224538a07df31c09f0829e1ebe2cef1fc7405533 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Wed, 4 Oct 2023 16:03:10 -0400 Subject: [PATCH 05/18] updating .env.template file --- .env.template | 10 +++++++++- resources/config.llama2.toml | 3 +-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.env.template b/.env.template index 5fd8be2..d9ecada 100644 --- a/.env.template +++ b/.env.template @@ -4,4 +4,12 @@ PORT=8181 # Use postgres for "production" DATABASE_CONNECTION=sqlite://data/sqlite.db # Use qdrant/etc. for "production" -VECTOR_CONNECTION=hnsw://data/vdb \ No newline at end of file +VECTOR_CONNECTION=hnsw://data/vdb +# When using OpenSearch as the vector backend +# VECTOR_CONNECTION=opensearch+https://admin:admin@localhost:9200 + +# If using OpenAPI, setup your API key here +OPENAI_API_KEY= +# Or point to local LLM configuration file. 
By default, memex wil use +# llama2 +LOCAL_LLM_CONFIG=resources/config.llama2.toml \ No newline at end of file diff --git a/resources/config.llama2.toml b/resources/config.llama2.toml index f638779..5570975 100644 --- a/resources/config.llama2.toml +++ b/resources/config.llama2.toml @@ -1,8 +1,7 @@ prompt_template = "resources/templates/clippy_prompt.txt" [model] -# Download from https://huggingface.co/TheBloke/Wizard-Vicuna-7B-Uncensored-GGML/tree/main -# And place in the same directory as this config file. +# Download from https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML path = "resources/models/LLaMa2/llama-2-7b-chat.ggmlv3.q4_1.bin" model_type = "Llama" prefer_mmap = false From d6f6d6510a92fca191c4ea55e48d530bbe2a50ff Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Thu, 5 Oct 2023 14:57:55 -0400 Subject: [PATCH 06/18] removing sherpa stuff for now --- lib/sherpa/Cargo.toml | 24 ------------------- lib/sherpa/src/lib.rs | 56 ------------------------------------------- 2 files changed, 80 deletions(-) delete mode 100644 lib/sherpa/Cargo.toml delete mode 100644 lib/sherpa/src/lib.rs diff --git a/lib/sherpa/Cargo.toml b/lib/sherpa/Cargo.toml deleted file mode 100644 index 3a8e19a..0000000 --- a/lib/sherpa/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "sherpa" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -anyhow = "1.0" -async-trait = { workspace = true } -llm = { git = "https://github.com/rustformers/llm.git", rev = "84800b02a7a96f62c0c9c03a38c36cb23bf4b2ec" } -llm-base = { git = "https://github.com/rustformers/llm.git", rev = "84800b02a7a96f62c0c9c03a38c36cb23bf4b2ec" } -log = { workspace = true } -rand = "0.8.5" -rust-bert = { version = "0.21.0", features= ["download-libtorch"] } -serde = { workspace = true } -serde_json = { workspace = true } -strum = "0.25" -strum_macros = "0.25" -tera = "1.19.0" -thiserror = "1.0" -tokenizers = { version = "0.14", features = ["http"] } -tokio = { workspace = true } -toml = "0.8.0" diff --git a/lib/sherpa/src/lib.rs b/lib/sherpa/src/lib.rs deleted file mode 100644 index efb56e9..0000000 --- a/lib/sherpa/src/lib.rs +++ /dev/null @@ -1,56 +0,0 @@ -use llm_base::{TokenId, Tokenizer}; - -pub type TokenMask = Vec<(TokenId, f32)>; - -pub fn create_mask(tokens: &Tokenizer, only_numbers: bool) -> anyhow::Result { - let num_tokens = tokens.len(); - - let mut mask: Vec<(TokenId, f32)> = Vec::new(); - - for idx in 0..num_tokens { - let token = tokens.token(idx); - let token_str = String::from_utf8_lossy(&token); - - let bias = if only_numbers { - if token_str.is_ascii() && token_str.parse::().is_ok() { - 1.0 - } else { - f32::NEG_INFINITY - } - } else if token_str.is_ascii() - && !token_str.starts_with(' ') - && token_str.parse::().is_err() - { - 1.0 - } else { - f32::NEG_INFINITY - }; - - mask.push((idx as u32, bias)); - } - - Ok(mask) -} - -#[cfg(test)] -mod test { - use std::path::PathBuf; - - use llm_base::KnownModel; - - #[tokio::test] - async fn test_logit_biasing() { - let model_path: PathBuf = - "../../resources/models/LLaMa2/llama-2-7b-chat.ggmlv3.q4_1.bin".into(); - - let model = llm::load::( - &model_path, - llm::TokenizerSource::Embedded, - llm::ModelParameters::default(), - move |_| {}, - ) - .expect("Unable to load model"); - - super::create_mask(&model.tokenizer(), false).expect("Unable to create mask"); - } -} From 066baf3855d9ad40f45ef3c6c85ab42ecfc0a56d Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Thu, 5 Oct 
2023 17:09:22 -0400 Subject: [PATCH 07/18] Creating LLM trait and sharing structs between local LLM & OpenAI impls --- lib/api/src/endpoints/actions/handlers.rs | 4 +- lib/libmemex/src/llm/local/mod.rs | 207 +++++++++++++++++++++ lib/libmemex/src/llm/local/schema.rs | 57 ++++++ lib/libmemex/src/llm/mod.rs | 211 ++++++---------------- lib/libmemex/src/llm/openai/mod.rs | 113 +++++------- lib/libmemex/src/llm/prompter.rs | 22 +-- lib/worker/src/tasks.rs | 4 +- 7 files changed, 372 insertions(+), 246 deletions(-) create mode 100644 lib/libmemex/src/llm/local/mod.rs create mode 100644 lib/libmemex/src/llm/local/schema.rs diff --git a/lib/api/src/endpoints/actions/handlers.rs b/lib/api/src/endpoints/actions/handlers.rs index c5b617a..a6a2665 100644 --- a/lib/api/src/endpoints/actions/handlers.rs +++ b/lib/api/src/endpoints/actions/handlers.rs @@ -11,7 +11,7 @@ use libmemex::{ db::queue, llm::{ openai::{truncate_text, OpenAIClient}, - prompter, + prompter, LLM, }, }; @@ -34,7 +34,7 @@ pub async fn handle_extract( }; let response = llm - .chat_completion(&model, &prompt) + .chat_completion(model, &prompt) .await .map_err(|err| ServerError::Other(err.to_string()))?; diff --git a/lib/libmemex/src/llm/local/mod.rs b/lib/libmemex/src/llm/local/mod.rs new file mode 100644 index 0000000..65b9dd1 --- /dev/null +++ b/lib/libmemex/src/llm/local/mod.rs @@ -0,0 +1,207 @@ +use llm::samplers::llm_samplers::samplers::SampleFlatBias; +use llm::samplers::llm_samplers::types::SamplerChain; +use llm::InferenceParameters; +use llm::{self, samplers::ConfiguredSamplers, InferenceSessionConfig}; +use std::io::Write; +use std::sync::{Arc, Mutex}; +use tokio::sync::mpsc; + +use crate::llm::ChatRole; + +use super::{ChatMessage, LLMError, LLM}; +mod schema; +use schema::LlmEvent; + +pub struct LocalLLM +where + T: llm::KnownModel, +{ + model: T, + infer_params: InferenceParameters, + /// At the moment does nothing but will eventually be used by our internal + /// sampler to only output JSON/etc. 
+ bias_sampler: Arc>, +} + +impl LocalLLM +where + T: llm::KnownModel, +{ + fn new(model: T) -> Self { + let bias_sampler = SampleFlatBias::default(); + // Create sampler chain + let mut samplers = SamplerChain::new(); + samplers += bias_sampler.clone(); + + let mut default_samplers = ConfiguredSamplers::default(); + default_samplers.ensure_default_slots(); + samplers += default_samplers.builder.into_chain(); + + let infer_params = llm::InferenceParameters { + sampler: Arc::new(Mutex::new(samplers)), // sampler: llm::samplers::default_samplers(), + }; + + Self { + model, + infer_params, + bias_sampler: Arc::new(Mutex::new(bias_sampler)), + } + } + + async fn run_model(&self, prompt: &str) -> anyhow::Result { + log::info!("running model w/ prompt: {prompt}"); + + let buffer = Arc::new(Mutex::new(String::new())); + let tokens = Arc::new(Mutex::new(Vec::new())); + + let (sender, mut receiver) = mpsc::unbounded_channel::(); + + let writer_handle = { + let buffer = buffer.clone(); + let tokens = tokens.clone(); + tokio::spawn(async move { + loop { + if let Some(event) = receiver.recv().await { + match &event { + LlmEvent::TokenReceived(token) => { + if let Ok(mut buff) = buffer.lock() { + *buff += token; + } + + if let Ok(mut tokens) = tokens.lock() { + tokens.push(token.to_string()); + } + } + LlmEvent::InferenceDone => { + std::io::stdout().flush().unwrap(); + return; + } + _ => {} + } + } + } + }) + }; + + let config = InferenceSessionConfig::default(); + let infer_params = self.infer_params.clone(); + let sender = sender.clone(); + + let prompt_request = llm::InferenceRequest { + prompt: llm::Prompt::Text(prompt), + maximum_token_count: None, + parameters: &infer_params, + play_back_previous_tokens: false, + }; + + let channel = sender.clone(); + let mut session = self.model.start_session(config); + let _stats = session + .infer::( + &self.model, + &mut rand::thread_rng(), + &prompt_request, + &mut Default::default(), + move |t| { + match t { + llm::InferenceResponse::InferredToken(token) => { + if channel.send(LlmEvent::TokenReceived(token)).is_err() { + return Ok(llm::InferenceFeedback::Halt); + } + } + llm::InferenceResponse::EotToken => { + if channel.send(LlmEvent::InferenceDone).is_err() { + return Ok(llm::InferenceFeedback::Halt); + } + } + _ => {} + } + + Ok(llm::InferenceFeedback::Continue) + }, + ) + .map_err(|err| LLMError::InferenceError(err.to_string()))?; + let _ = sender.send(LlmEvent::InferenceDone); + // Wait for buffer to finish writing + let _ = writer_handle.await; + // Retrieve buffer and clean up any trailing/leading spaces. + let buffer = buffer + .lock() + .expect("Unable to grab buffer") + .trim() + .to_string(); + Ok(buffer) + } +} + +#[async_trait::async_trait] +impl LLM for LocalLLM +where + T: llm::KnownModel, +{ + async fn chat_completion( + &self, + _: schema::LocalLLMSize, + msgs: &[ChatMessage], + ) -> anyhow::Result { + log::info!("LocalLLM running chat_completion"); + + let system_msg = msgs + .iter() + .find(|x| x.role == ChatRole::System) + .map(|x| x.content.clone()) + .unwrap_or(String::from("You're a helpful assistant")); + + // Currently the prompt assumes a llama based model, pull this out into the + // config file. 
+ let mut prompt = format!("[INST] <>\n{system_msg}\n<>\n\n"); + for msg in msgs { + if msg.role == ChatRole::System { + continue; + } + + prompt.push_str(&format!("{}\n", msg.content)); + } + prompt.push_str("[/INST]"); + self.run_model(&prompt).await + } +} + +#[cfg(test)] +mod test { + use crate::llm::{ChatMessage, LLM}; + + use super::schema::LocalLLMConfig; + use super::LocalLLM; + use std::path::PathBuf; + + #[ignore] + #[tokio::test] + async fn test_prompting() { + let base_dir: PathBuf = "../..".into(); + let model_config: PathBuf = base_dir.join("resources/config.llama2.toml"); + + let config = std::fs::read_to_string(model_config).expect("Unable to read cfg"); + let config: LocalLLMConfig = toml::from_str(&config).expect("Unable to parse cfg"); + let model_path: PathBuf = base_dir.join(config.model.path); + + let model_params = llm::ModelParameters::default(); + let model = llm::load::( + &model_path, + llm::TokenizerSource::Embedded, + model_params, + move |_| {}, + ) + .expect("Unable to load model"); + + let llm = LocalLLM::new(model); + let msgs = vec![ + ChatMessage::system("You're a helpful assistant. Answer questions as accurately and concisely as possible."), + ChatMessage::user("Who won the world series in 2020?"), + ]; + + let result = llm.chat_completion(Default::default(), &msgs).await; + assert!(result.is_ok()); + dbg!(result.unwrap()); + } +} diff --git a/lib/libmemex/src/llm/local/schema.rs b/lib/libmemex/src/llm/local/schema.rs new file mode 100644 index 0000000..c440593 --- /dev/null +++ b/lib/libmemex/src/llm/local/schema.rs @@ -0,0 +1,57 @@ +use llm::{LoadProgress, ModelArchitecture}; +use serde::Deserialize; +use std::path::PathBuf; + +#[derive(Debug)] +pub enum LlmEvent { + ModelLoadProgress(LoadProgress), + TokenReceived(String), + InferenceDone, +} + +#[derive(Deserialize)] +pub struct LocalLLMConfig { + pub prompt_template: PathBuf, + pub model: ModelConfig, +} + +#[derive(Clone, Deserialize)] +pub struct ModelConfig { + pub path: PathBuf, + pub model_type: ModelArch, + pub prefer_mmap: bool, + pub top_k: usize, + pub top_p: f32, + pub repeat_penalty: f32, + pub temperature: f32, + pub repetition_penalty_last_n: usize, +} + +#[derive(Clone, Deserialize)] +pub enum ModelArch { + Bloom, + Gpt2, + GptJ, + GptNeoX, + Llama, + Mpt, +} + +impl From for ModelArch { + fn from(value: ModelArchitecture) -> Self { + match value { + ModelArchitecture::Bloom => Self::Bloom, + ModelArchitecture::Gpt2 => Self::Gpt2, + ModelArchitecture::GptJ => Self::GptJ, + ModelArchitecture::GptNeoX => Self::GptNeoX, + ModelArchitecture::Llama => Self::Llama, + ModelArchitecture::Mpt => Self::Mpt, + } + } +} + +#[derive(Default, Debug, Clone)] +pub enum LocalLLMSize { + #[default] + Base, +} diff --git a/lib/libmemex/src/llm/mod.rs b/lib/libmemex/src/llm/mod.rs index 887a77c..4d8aa65 100644 --- a/lib/libmemex/src/llm/mod.rs +++ b/lib/libmemex/src/llm/mod.rs @@ -1,177 +1,70 @@ -use llm::samplers::llm_samplers::samplers::SampleFlatBias; -use llm::samplers::llm_samplers::types::SamplerChain; -use llm::{self, samplers::ConfiguredSamplers, InferenceSessionConfig, LoadProgress}; -use std::io::Write; -use std::path::PathBuf; -use std::sync::{Arc, Mutex}; -use tokio::sync::mpsc; +use serde::Serialize; +use strum_macros::Display; +use thiserror::Error; pub mod embedding; +pub mod local; pub mod openai; pub mod prompter; -#[derive(Debug)] -pub enum LlmEvent { - ModelLoadProgress(LoadProgress), - TokenReceived(String), - InferenceDone, +#[derive(Clone, Debug, Serialize, Display, Eq, PartialEq)] 
+pub enum ChatRole { + #[strum(serialize = "system")] + System, + #[strum(serialize = "user")] + User, + #[strum(serialize = "assistant")] + Assistant, } -// Run model and apply rules to generation. -async fn run_guidance( - model: &T, - default_samplers: SamplerChain, - prompt: &str, -) -> anyhow::Result<()> -where - T: llm::KnownModel, -{ - // Create a mask based on the current tokens generated and - let token_bias = sherpa::create_mask(model.tokenizer(), false).expect("Unable to create mask"); - let mut bias_sampler = SampleFlatBias::new(token_bias); - - // Create sampler chain - let mut samplers = SamplerChain::new(); - samplers += bias_sampler.clone(); - samplers += default_samplers; - - let infer_params = llm::InferenceParameters { - sampler: Arc::new(Mutex::new(samplers)), // sampler: llm::samplers::default_samplers(), - }; - - // Add our biasing mask to the sampler chain. - let mut num_tokens = 0; - let buffer = Arc::new(Mutex::new(prompt.to_string())); - let tokens = Arc::new(Mutex::new(Vec::new())); - - let (sender, mut receiver) = mpsc::unbounded_channel::(); - - let writer_handle = { - let buffer = buffer.clone(); - let tokens = tokens.clone(); - tokio::spawn(async move { - loop { - if let Some(event) = receiver.recv().await { - match &event { - LlmEvent::TokenReceived(token) => { - if let Ok(mut buff) = buffer.lock() { - *buff += token; - } - - if let Ok(mut tokens) = tokens.lock() { - tokens.push(token.to_string()); - } - } - LlmEvent::InferenceDone => { - std::io::stdout().flush().unwrap(); - return; - } - _ => {} - } - } - } - }) - }; - - { - let config = InferenceSessionConfig::default(); - let buffer = buffer.clone(); - loop { - let prompt = if let Ok(buff) = buffer.lock() { - buff.to_string() - } else { - continue; - }; - println!("processing prompt: {}", prompt); - let infer_params = infer_params.clone(); - let sender = sender.clone(); +#[derive(Serialize, Debug, Clone)] +pub struct ChatMessage { + role: ChatRole, + content: String, +} - let prompt_request = llm::InferenceRequest { - prompt: llm::Prompt::Text(&prompt), - maximum_token_count: Some(1), - parameters: &infer_params, - play_back_previous_tokens: false, - }; +impl ChatMessage { + pub fn assistant(content: &str) -> Self { + Self::new(ChatRole::Assistant, content) + } - let mut session = model.start_session(config); - let channel = sender.clone(); - let _res = session - .infer::( - model, - &mut rand::thread_rng(), - &prompt_request, - &mut Default::default(), - move |t| { - match t { - llm::InferenceResponse::InferredToken(token) => { - if channel.send(LlmEvent::TokenReceived(token)).is_err() { - return Ok(llm::InferenceFeedback::Halt); - } - } - llm::InferenceResponse::EotToken => { - if channel.send(LlmEvent::InferenceDone).is_err() { - return Ok(llm::InferenceFeedback::Halt); - } - } - _ => {} - } + pub fn user(content: &str) -> Self { + Self::new(ChatRole::User, content) + } - Ok(llm::InferenceFeedback::Continue) - }, - ) - .expect("Unable"); + pub fn system(content: &str) -> Self { + Self::new(ChatRole::System, content) + } - let _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - num_tokens += 1; - *bias_sampler = - sherpa::create_mask(model.tokenizer(), false).expect("Unable to create mask"); - if num_tokens >= 3 { - let _ = sender.send(LlmEvent::InferenceDone); - break; - } + pub fn new(role: ChatRole, content: &str) -> Self { + Self { + role, + content: content.to_string(), } } - - let _ = writer_handle.await; - println!("result: {}", buffer.lock().unwrap()); - println!("tokens 
generated: {:?}", tokens.lock().unwrap()); - println!("🛑"); - Ok(()) } -pub async fn run_prompt(prompt: &str) -> anyhow::Result<()> { - println!("Prompting..."); - - // Load model - let model_path: PathBuf = - "../../resources/models/LLaMa2/llama-2-7b-chat.ggmlv3.q4_1.bin".into(); - let model_params = llm::ModelParameters::default(); - let model = match llm::load::( - &model_path, - llm::TokenizerSource::Embedded, - model_params, - move |_| {}, - ) { - Ok(model) => model, - Err(err) => return Err(anyhow::anyhow!("Unable to load model: {err}")), - }; - - // Configure samplers - let mut samplers = ConfiguredSamplers::default(); - samplers.ensure_default_slots(); - - run_guidance(&model, samplers.builder.into_chain(), prompt).await +#[derive(Debug, Error)] +pub enum LLMError { + #[error("Context length exceeded: {0}")] + ContextLengthExceeded(String), + #[error("No response received")] + NoResponse, + #[error("Inference Error: {0}")] + InferenceError(String), + #[error("Request Error: {0}")] + RequestError(#[from] reqwest::Error), + #[error("Unable to deserialize: {0}")] + SerdeError(#[from] serde_json::Error), + #[error("Invalid Request: {0}")] + Other(String), } -#[cfg(test)] -mod test { - use super::run_prompt; - - #[tokio::test] - async fn test_structured_prompt() { - run_prompt("i see london, i see france, i see ") - .await - .expect("Unable to prompt"); - println!(""); - } +#[async_trait::async_trait] +pub trait LLM { + async fn chat_completion( + &self, + model: T, + msgs: &[ChatMessage], + ) -> anyhow::Result; } diff --git a/lib/libmemex/src/llm/openai/mod.rs b/lib/libmemex/src/llm/openai/mod.rs index ae25b6a..be91135 100644 --- a/lib/libmemex/src/llm/openai/mod.rs +++ b/lib/libmemex/src/llm/openai/mod.rs @@ -1,10 +1,10 @@ use reqwest::{header, Response, StatusCode}; use serde::Serialize; use strum_macros::{AsRefStr, Display}; -use thiserror::Error; use tiktoken_rs::cl100k_base; use self::schema::ErrorResponse; +use super::{ChatMessage, LLMError, LLM}; mod schema; @@ -32,41 +32,12 @@ pub enum OpenAIModel { GPT4_8K, } -#[derive(Debug, Error)] -pub enum OpenAIError { - #[error("Context length exceeded: {0}")] - ContextLengthExceeded(String), - #[error("No response received")] - NoResponse, - #[error("Request Error: {0}")] - RequestError(#[from] reqwest::Error), - #[error("Unable to deserialize: {0}")] - SerdeError(#[from] serde_json::Error), - #[error("Invalid Request: {0}")] - Other(String), -} - -impl From for OpenAIError { +impl From for LLMError { fn from(value: ErrorResponse) -> Self { if value.error.code == CONTEXT_LENGTH_ERROR { - OpenAIError::ContextLengthExceeded(value.error.message) + LLMError::ContextLengthExceeded(value.error.message) } else { - OpenAIError::Other(value.error.message) - } - } -} - -#[derive(Serialize, Debug, Clone)] -pub struct ChatMessage { - role: String, - content: String, -} - -impl ChatMessage { - pub fn new(role: &str, content: &str) -> Self { - Self { - role: role.to_string(), - content: content.to_string(), + LLMError::Other(value.error.message) } } } @@ -103,16 +74,16 @@ impl CompletionRequest { } /// Helper function to parse error messages from the OpenAI API response. 
-async fn check_api_error(response: Response) -> OpenAIError { +async fn check_api_error(response: Response) -> LLMError { // Grab the raw response body let raw_body = match response.text().await { Ok(raw) => raw, - Err(err) => return OpenAIError::Other(format!("Invalid response: {err}")), + Err(err) => return LLMError::Other(format!("Invalid response: {err}")), }; // Attempt to parse into an error object, otherwise return the raw message. match serde_json::from_str::(&raw_body) { Ok(error) => error.into(), - Err(err) => OpenAIError::Other(format!("Error: {err}, raw response: {raw_body}")), + Err(err) => LLMError::Other(format!("Error: {err}, raw response: {raw_body}")), } } @@ -121,38 +92,20 @@ pub struct OpenAIClient { client: reqwest::Client, } -impl OpenAIClient { - pub fn new(api_key: &str) -> Self { - let mut headers = header::HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ); - headers.insert( - header::AUTHORIZATION, - header::HeaderValue::from_str(&format!("Bearer {api_key}")).expect("Invalid api_key"), - ); - - let client = reqwest::Client::builder() - .default_headers(headers) - .build() - .expect("Unable to build HTTP client"); - - Self { client } - } - - pub async fn chat_completion( +#[async_trait::async_trait] +impl LLM for OpenAIClient { + async fn chat_completion( &self, - model: &OpenAIModel, + model: OpenAIModel, msgs: &[ChatMessage], - ) -> anyhow::Result { + ) -> anyhow::Result { log::debug!( "[OpenAI] chat completion w/ {} | {} messages", model, msgs.len() ); - let request_body = CompletionRequest::new(model, msgs); + let request_body = CompletionRequest::new(&model, msgs); let response = self .client .post(&"https://api.openai.com/v1/chat/completions".to_string()) @@ -165,22 +118,43 @@ impl OpenAIClient { let completion = response .json::() .await - .map_err(OpenAIError::RequestError)?; + .map_err(LLMError::RequestError)?; match completion.response() { Some(msg) => Ok(msg), - None => Err(OpenAIError::NoResponse), + None => Err(LLMError::NoResponse), } } else if StatusCode::is_client_error(status) || StatusCode::is_server_error(status) { Err(check_api_error(response).await) } else { let warning = format!("OpenAI response not currently supported {:?}", response); log::warn!("{}", &warning); - Err(OpenAIError::Other(warning)) + Err(LLMError::Other(warning)) } } } +impl OpenAIClient { + pub fn new(api_key: &str) -> Self { + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + headers.insert( + header::AUTHORIZATION, + header::HeaderValue::from_str(&format!("Bearer {api_key}")).expect("Invalid api_key"), + ); + + let client = reqwest::Client::builder() + .default_headers(headers) + .build() + .expect("Unable to build HTTP client"); + + Self { client } + } +} + pub fn segment(content: &str) -> (Vec, OpenAIModel) { let cl = cl100k_base().unwrap(); let size = cl.encode_with_special_tokens(content).len(); @@ -269,21 +243,20 @@ pub fn split_text(text: &str, max_tokens: usize) -> Vec { #[cfg(test)] mod test { + use super::{ChatMessage, OpenAIClient, OpenAIModel, LLM}; use crate::llm::prompter::{json_schema_extraction, summarize}; - use super::{ChatMessage, OpenAIClient, OpenAIModel}; - #[ignore] #[tokio::test] pub async fn test_completion_api() { dotenv::dotenv().ok(); let client = OpenAIClient::new(&std::env::var("OPENAI_API_KEY").unwrap()); let msgs = vec![ - ChatMessage::new("system", "You are a helpful 
assistant"), - ChatMessage::new("user", "Who won the world series in 2020?"), + ChatMessage::system("You are a helpful assistant"), + ChatMessage::user("Who won the world series in 2020?"), ]; - let resp = client.chat_completion(&OpenAIModel::GPT35, &msgs).await; + let resp = client.chat_completion(OpenAIModel::GPT35, &msgs).await; // dbg!(&resp); assert!(resp.is_ok()); } @@ -301,7 +274,7 @@ mod test { ); let resp = client - .chat_completion(&OpenAIModel::GPT35, &msgs) + .chat_completion(OpenAIModel::GPT35, &msgs) .await .unwrap(); dbg!(&resp); @@ -318,7 +291,7 @@ mod test { "../../../../../fixtures/sample_yelp_review.txt" )); let resp = client - .chat_completion(&OpenAIModel::GPT35, &msgs) + .chat_completion(OpenAIModel::GPT35, &msgs) .await .unwrap(); diff --git a/lib/libmemex/src/llm/prompter.rs b/lib/libmemex/src/llm/prompter.rs index dce000c..9201788 100644 --- a/lib/libmemex/src/llm/prompter.rs +++ b/lib/libmemex/src/llm/prompter.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use handlebars::RenderError; use serde::Serialize; -use super::openai::ChatMessage; +use super::ChatMessage; pub fn build_prompt(template: &str, data: &T) -> Result where @@ -16,16 +16,16 @@ where pub fn quick_question(user_request: &str) -> Vec { vec![ - ChatMessage::new("system", "You are a helpful assistant"), - ChatMessage::new("user", user_request), + ChatMessage::system("You are a helpful assistant"), + ChatMessage::user(user_request), ] } pub fn summarize(input_data: &str) -> Vec { vec![ - ChatMessage::new("system", include_str!("../../prompts/summarize/system.txt")), - ChatMessage::new("user", input_data), - ChatMessage::new("user", include_str!("../../prompts/summarize/prompt.txt")), + ChatMessage::system(include_str!("../../prompts/summarize/system.txt")), + ChatMessage::user(input_data), + ChatMessage::user(include_str!("../../prompts/summarize/prompt.txt")), ] } @@ -39,13 +39,9 @@ pub fn json_schema_extraction( data.insert("json_schema".to_string(), output_schema.to_string()); vec![ - ChatMessage::new( - "system", - include_str!("../../prompts/json_schema/system.txt"), - ), - ChatMessage::new("user", input_data), - ChatMessage::new( - "user", + ChatMessage::system(include_str!("../../prompts/json_schema/system.txt")), + ChatMessage::user(input_data), + ChatMessage::user( &build_prompt(include_str!("../../prompts/json_schema/prompt.txt"), &data).unwrap(), ), ] diff --git a/lib/worker/src/tasks.rs b/lib/worker/src/tasks.rs index 04d0a78..a1f3300 100644 --- a/lib/worker/src/tasks.rs +++ b/lib/worker/src/tasks.rs @@ -1,7 +1,7 @@ use libmemex::db::{document, embedding, queue}; use libmemex::llm::embedding::{ModelConfig, SentenceEmbedder}; use libmemex::llm::openai::{segment, OpenAIClient}; -use libmemex::llm::prompter; +use libmemex::llm::{prompter, LLM}; use libmemex::storage::{VectorData, VectorStorage}; use libmemex::NAMESPACE; use sea_orm::{prelude::*, Set, TransactionTrait}; @@ -73,7 +73,7 @@ pub async fn generate_summary(client: &OpenAIClient, payload: &str) -> anyhow::R let time = std::time::Instant::now(); let request = prompter::summarize(segment); - if let Ok(content) = client.chat_completion(&model, &request).await { + if let Ok(content) = client.chat_completion(model.clone(), &request).await { buffer.push_str(&content); } From 33fe800e2c2e45de63ae3f43bbb2d02c6b477297 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Thu, 5 Oct 2023 17:52:59 -0400 Subject: [PATCH 08/18] Using LLM trait in API to switch between local/OpenAI when configured --- lib/api/src/endpoints/actions/filters.rs | 8 +- 
lib/api/src/endpoints/actions/handlers.rs | 13 +- lib/api/src/endpoints/mod.rs | 6 +- lib/api/src/lib.rs | 13 +- lib/libmemex/src/llm/local/mod.rs | 49 ++++++- lib/libmemex/src/llm/local/schema.rs | 6 - lib/libmemex/src/llm/mod.rs | 51 ++++++- lib/libmemex/src/llm/openai/mod.rs | 154 +++++++++------------- lib/worker/src/tasks.rs | 6 +- 9 files changed, 181 insertions(+), 125 deletions(-) diff --git a/lib/api/src/endpoints/actions/filters.rs b/lib/api/src/endpoints/actions/filters.rs index ad8a8b1..9ce6bb4 100644 --- a/lib/api/src/endpoints/actions/filters.rs +++ b/lib/api/src/endpoints/actions/filters.rs @@ -1,5 +1,7 @@ +use std::sync::Arc; + use crate::{endpoints::json_body, with_db, with_llm}; -use libmemex::llm::openai::OpenAIClient; +use libmemex::llm::LLM; use sea_orm::DatabaseConnection; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -24,7 +26,7 @@ pub struct SummarizeRequest { } fn extract( - llm: &OpenAIClient, + llm: &Arc>, ) -> impl Filter + Clone { warp::path!("action" / "ask") .and(warp::post()) @@ -44,7 +46,7 @@ fn summarize( } pub fn build( - llm: &OpenAIClient, + llm: &Arc>, db: &DatabaseConnection, ) -> impl Filter + Clone { extract(llm).or(summarize(db)) diff --git a/lib/api/src/endpoints/actions/handlers.rs b/lib/api/src/endpoints/actions/handlers.rs index a6a2665..21d11aa 100644 --- a/lib/api/src/endpoints/actions/handlers.rs +++ b/lib/api/src/endpoints/actions/handlers.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::{ schema::{ApiResponse, TaskResult}, ServerError, @@ -9,19 +11,16 @@ use warp::reject::Rejection; use super::filters; use libmemex::{ db::queue, - llm::{ - openai::{truncate_text, OpenAIClient}, - prompter, LLM, - }, + llm::{prompter, LLM}, }; pub async fn handle_extract( - llm: OpenAIClient, + llm: Arc>, request: filters::AskRequest, ) -> Result { let time = std::time::Instant::now(); - let (content, model) = truncate_text(&request.text); + let (content, model) = llm.truncate_text(&request.text); // Build prompt let prompt = if let Some(schema) = &request.json_schema { @@ -34,7 +33,7 @@ pub async fn handle_extract( }; let response = llm - .chat_completion(model, &prompt) + .chat_completion(model.as_ref(), &prompt) .await .map_err(|err| ServerError::Other(err.to_string()))?; diff --git a/lib/api/src/endpoints/mod.rs b/lib/api/src/endpoints/mod.rs index d9f150c..9e893f8 100644 --- a/lib/api/src/endpoints/mod.rs +++ b/lib/api/src/endpoints/mod.rs @@ -1,4 +1,6 @@ -use libmemex::llm::openai::OpenAIClient; +use std::sync::Arc; + +use libmemex::llm::LLM; use sea_orm::DatabaseConnection; use serde::de::DeserializeOwned; use warp::Filter; @@ -24,7 +26,7 @@ pub fn json_body( pub fn build( db: &DatabaseConnection, - llm: &OpenAIClient, + llm: &Arc>, ) -> impl Filter + Clone { actions::filters::build(llm, db) .or(collections::filters::build(db)) diff --git a/lib/api/src/lib.rs b/lib/api/src/lib.rs index f8d75ff..0ce3ac4 100644 --- a/lib/api/src/lib.rs +++ b/lib/api/src/lib.rs @@ -1,8 +1,11 @@ use dotenv_codegen::dotenv; -use libmemex::{db::create_connection_by_uri, llm::openai::OpenAIClient}; +use libmemex::{ + db::create_connection_by_uri, + llm::{openai::OpenAIClient, LLM}, +}; use sea_orm::DatabaseConnection; use serde_json::json; -use std::{convert::Infallible, net::Ipv4Addr, path::PathBuf}; +use std::{convert::Infallible, net::Ipv4Addr, path::PathBuf, sync::Arc}; use thiserror::Error; use warp::{hyper::StatusCode, reject::Reject, Filter, Rejection, Reply}; @@ -76,11 +79,13 @@ pub async fn start(host: Ipv4Addr, port: u16, db_uri: String) { let 
llm_client = OpenAIClient::new(&std::env::var("OPENAI_API_KEY").expect("OpenAI API key not set")); + let cors = warp::cors() .allow_any_origin() .allow_methods(vec!["GET", "POST", "PUT", "PATCH", "DELETE"]) .allow_headers(["Authorization", "Content-Type"]); + let llm_client: Arc> = Arc::new(Box::new(llm_client)); let api = warp::path("api") .and(endpoints::build(&db_connection, &llm_client)) .with(warp::trace::request()); @@ -105,7 +110,7 @@ pub fn with_db( } pub fn with_llm( - llm: OpenAIClient, -) -> impl Filter + Clone { + llm: Arc>, +) -> impl Filter>,), Error = std::convert::Infallible> + Clone { warp::any().map(move || llm.clone()) } diff --git a/lib/libmemex/src/llm/local/mod.rs b/lib/libmemex/src/llm/local/mod.rs index 65b9dd1..1b93a69 100644 --- a/lib/libmemex/src/llm/local/mod.rs +++ b/lib/libmemex/src/llm/local/mod.rs @@ -4,14 +4,18 @@ use llm::InferenceParameters; use llm::{self, samplers::ConfiguredSamplers, InferenceSessionConfig}; use std::io::Write; use std::sync::{Arc, Mutex}; +use tiktoken_rs::cl100k_base; use tokio::sync::mpsc; -use crate::llm::ChatRole; +use crate::llm::{split_text, ChatRole}; use super::{ChatMessage, LLMError, LLM}; mod schema; use schema::LlmEvent; +pub const MAX_TOKENS: usize = 2_048 - 512 - 100; + +#[derive(Clone)] pub struct LocalLLM where T: llm::KnownModel, @@ -20,7 +24,7 @@ where infer_params: InferenceParameters, /// At the moment does nothing but will eventually be used by our internal /// sampler to only output JSON/etc. - bias_sampler: Arc>, + _bias_sampler: Arc>, } impl LocalLLM @@ -44,7 +48,7 @@ where Self { model, infer_params, - bias_sampler: Arc::new(Mutex::new(bias_sampler)), + _bias_sampler: Arc::new(Mutex::new(bias_sampler)), } } @@ -135,13 +139,13 @@ where } #[async_trait::async_trait] -impl LLM for LocalLLM +impl LLM for LocalLLM where T: llm::KnownModel, { async fn chat_completion( &self, - _: schema::LocalLLMSize, + _: &str, msgs: &[ChatMessage], ) -> anyhow::Result { log::info!("LocalLLM running chat_completion"); @@ -165,6 +169,41 @@ where prompt.push_str("[/INST]"); self.run_model(&prompt).await } + + fn segment_text(&self, text: &str) -> (Vec, String) { + let cl = cl100k_base().unwrap(); + let size = cl.encode_with_special_tokens(text).len(); + log::debug!("context size: {size}"); + + if size <= MAX_TOKENS { + (vec![text.to_string()], Default::default()) + } else { + let splits = split_text(text, MAX_TOKENS); + (splits, Default::default()) + } + } + + fn truncate_text(&self, text: &str) -> (String, String) { + let cl = cl100k_base().unwrap(); + let total_tokens: usize = cl.encode_with_special_tokens(text).len(); + + if total_tokens <= MAX_TOKENS { + (text.to_string(), Default::default()) + } else { + let mut buffer = String::new(); + for txt in text.split(' ') { + let with_txt = buffer.clone() + txt; + let current_size = cl.encode_with_special_tokens(&with_txt).len(); + if current_size > MAX_TOKENS { + break; + } else { + buffer.push_str(txt); + } + } + + (buffer, Default::default()) + } + } } #[cfg(test)] diff --git a/lib/libmemex/src/llm/local/schema.rs b/lib/libmemex/src/llm/local/schema.rs index c440593..cf1381c 100644 --- a/lib/libmemex/src/llm/local/schema.rs +++ b/lib/libmemex/src/llm/local/schema.rs @@ -49,9 +49,3 @@ impl From for ModelArch { } } } - -#[derive(Default, Debug, Clone)] -pub enum LocalLLMSize { - #[default] - Base, -} diff --git a/lib/libmemex/src/llm/mod.rs b/lib/libmemex/src/llm/mod.rs index 4d8aa65..191bc14 100644 --- a/lib/libmemex/src/llm/mod.rs +++ b/lib/libmemex/src/llm/mod.rs @@ -1,6 +1,7 @@ use 
serde::Serialize; use strum_macros::Display; use thiserror::Error; +use tiktoken_rs::cl100k_base; pub mod embedding; pub mod local; @@ -61,10 +62,56 @@ pub enum LLMError { } #[async_trait::async_trait] -pub trait LLM { +pub trait LLM: Send + Sync { async fn chat_completion( &self, - model: T, + model: &str, msgs: &[ChatMessage], ) -> anyhow::Result; + + fn segment_text(&self, text: &str) -> (Vec, String); + fn truncate_text(&self, text: &str) -> (String, String); +} + +pub fn split_text(text: &str, max_tokens: usize) -> Vec { + let cl = cl100k_base().unwrap(); + + let total_tokens: usize = cl.encode_with_special_tokens(text).len(); + let mut doc_parts = Vec::new(); + if total_tokens <= max_tokens { + doc_parts.push(text.into()); + } else { + let split_count = total_tokens + .checked_div(max_tokens) + .map(|val| val + 2) + .unwrap_or(1); + let split_size = text.len().checked_div(split_count).unwrap_or(text.len()); + if split_size == text.len() { + doc_parts.push(text.into()); + } else { + let mut part = Vec::new(); + let mut size = 0; + for txt in text.split(' ') { + if (size + txt.len()) > split_size { + doc_parts.push(part.join(" ")); + let mut end = part.len(); + if part.len() > 10 { + end = part.len() - 10; + } + part.drain(0..end); + size = part.join(" ").len(); + } + size += txt.len() + 1; + part.push(txt); + } + if !part.is_empty() { + doc_parts.push(part.join(" ")); + } + } + } + + doc_parts + .iter() + .map(|pt| pt.to_string()) + .collect::>() } diff --git a/lib/libmemex/src/llm/openai/mod.rs b/lib/libmemex/src/llm/openai/mod.rs index be91135..f1988e4 100644 --- a/lib/libmemex/src/llm/openai/mod.rs +++ b/lib/libmemex/src/llm/openai/mod.rs @@ -1,8 +1,12 @@ +use std::str::FromStr; + use reqwest::{header, Response, StatusCode}; use serde::Serialize; -use strum_macros::{AsRefStr, Display}; +use strum_macros::{AsRefStr, Display, EnumString}; use tiktoken_rs::cl100k_base; +use crate::llm::split_text; + use self::schema::ErrorResponse; use super::{ChatMessage, LLMError, LLM}; @@ -13,7 +17,7 @@ const CONTEXT_LENGTH_ERROR: &str = "context_length_exceeded"; pub const MAX_TOKENS: usize = 4_097 - 1_024 - 100; pub const MAX_16K_TOKENS: usize = 16_384 - 2_048 - 100; -#[derive(AsRefStr, Display, Clone)] +#[derive(AsRefStr, Display, Clone, EnumString)] pub enum OpenAIModel { // Most capable GPT-3.5 model and optimized for chat at 1/10th the cost of text-davinci-003. // Will be updated with our latest model iteration 2 weeks after it is released. 
@@ -93,10 +97,10 @@ pub struct OpenAIClient { } #[async_trait::async_trait] -impl LLM for OpenAIClient { +impl LLM for OpenAIClient { async fn chat_completion( &self, - model: OpenAIModel, + model: &str, msgs: &[ChatMessage], ) -> anyhow::Result { log::debug!( @@ -105,6 +109,9 @@ impl LLM for OpenAIClient { msgs.len() ); + let model: OpenAIModel = OpenAIModel::from_str(model) + .map_err(|err| LLMError::Other(format!("Invalid model: {err}")))?; + let request_body = CompletionRequest::new(&model, msgs); let response = self .client @@ -132,6 +139,51 @@ impl LLM for OpenAIClient { Err(LLMError::Other(warning)) } } + + fn segment_text(&self, content: &str) -> (Vec, String) { + let cl = cl100k_base().unwrap(); + let size = cl.encode_with_special_tokens(content).len(); + + log::debug!("Context Size {:?}", size); + if size <= MAX_TOKENS { + log::debug!("Using standard model"); + (vec![content.to_string()], OpenAIModel::GPT35.to_string()) + } else if size <= MAX_16K_TOKENS { + log::debug!("Using 16k model"); + ( + vec![content.to_string()], + OpenAIModel::GPT35_16K.to_string(), + ) + } else { + let splits = split_text(content, MAX_16K_TOKENS); + log::debug!("Spliting with 16K model splits {:?}", splits.len()); + (splits, OpenAIModel::GPT35_16K.to_string()) + } + } + + fn truncate_text(&self, text: &str) -> (String, String) { + let cl = cl100k_base().unwrap(); + let total_tokens: usize = cl.encode_with_special_tokens(text).len(); + + if total_tokens <= MAX_TOKENS { + (text.to_string(), OpenAIModel::GPT35.to_string()) + } else if total_tokens <= MAX_16K_TOKENS { + (text.to_string(), OpenAIModel::GPT35_16K.to_string()) + } else { + let mut buffer = String::new(); + for txt in text.split(' ') { + let with_txt = buffer.clone() + txt; + let current_size = cl.encode_with_special_tokens(&with_txt).len(); + if current_size > MAX_16K_TOKENS { + break; + } else { + buffer.push_str(txt); + } + } + + (buffer, OpenAIModel::GPT35_16K.to_string()) + } + } } impl OpenAIClient { @@ -155,92 +207,6 @@ impl OpenAIClient { } } -pub fn segment(content: &str) -> (Vec, OpenAIModel) { - let cl = cl100k_base().unwrap(); - let size = cl.encode_with_special_tokens(content).len(); - - log::debug!("Context Size {:?}", size); - if size <= MAX_TOKENS { - log::debug!("Using standard model"); - (vec![content.to_string()], OpenAIModel::GPT35) - } else if size <= MAX_16K_TOKENS { - log::debug!("Using 16k model"); - (vec![content.to_string()], OpenAIModel::GPT35_16K) - } else { - let splits = split_text(content, MAX_16K_TOKENS); - log::debug!("Spliting with 16K model splits {:?}", splits.len()); - (splits, OpenAIModel::GPT35_16K) - } -} - -/// Truncates a blob of text to the max token size -pub fn truncate_text(text: &str) -> (String, OpenAIModel) { - let cl = cl100k_base().unwrap(); - let total_tokens: usize = cl.encode_with_special_tokens(text).len(); - - if total_tokens <= MAX_TOKENS { - (text.to_string(), OpenAIModel::GPT35) - } else if total_tokens <= MAX_16K_TOKENS { - (text.to_string(), OpenAIModel::GPT35_16K) - } else { - let mut buffer = String::new(); - for txt in text.split(' ') { - let with_txt = buffer.clone() + txt; - let current_size = cl.encode_with_special_tokens(&with_txt).len(); - if current_size > MAX_16K_TOKENS { - break; - } else { - buffer.push_str(txt); - } - } - - (buffer, OpenAIModel::GPT35_16K) - } -} - -pub fn split_text(text: &str, max_tokens: usize) -> Vec { - let cl = cl100k_base().unwrap(); - - let total_tokens: usize = cl.encode_with_special_tokens(text).len(); - let mut doc_parts = Vec::new(); - if 
total_tokens <= max_tokens { - doc_parts.push(text.into()); - } else { - let split_count = total_tokens - .checked_div(max_tokens) - .map(|val| val + 2) - .unwrap_or(1); - let split_size = text.len().checked_div(split_count).unwrap_or(text.len()); - if split_size == text.len() { - doc_parts.push(text.into()); - } else { - let mut part = Vec::new(); - let mut size = 0; - for txt in text.split(' ') { - if (size + txt.len()) > split_size { - doc_parts.push(part.join(" ")); - let mut end = part.len(); - if part.len() > 10 { - end = part.len() - 10; - } - part.drain(0..end); - size = part.join(" ").len(); - } - size += txt.len() + 1; - part.push(txt); - } - if !part.is_empty() { - doc_parts.push(part.join(" ")); - } - } - } - - doc_parts - .iter() - .map(|pt| pt.to_string()) - .collect::>() -} - #[cfg(test)] mod test { use super::{ChatMessage, OpenAIClient, OpenAIModel, LLM}; @@ -256,7 +222,9 @@ mod test { ChatMessage::user("Who won the world series in 2020?"), ]; - let resp = client.chat_completion(OpenAIModel::GPT35, &msgs).await; + let resp = client + .chat_completion(OpenAIModel::GPT35.as_ref(), &msgs) + .await; // dbg!(&resp); assert!(resp.is_ok()); } @@ -274,7 +242,7 @@ mod test { ); let resp = client - .chat_completion(OpenAIModel::GPT35, &msgs) + .chat_completion(OpenAIModel::GPT35.as_ref(), &msgs) .await .unwrap(); dbg!(&resp); @@ -291,7 +259,7 @@ mod test { "../../../../../fixtures/sample_yelp_review.txt" )); let resp = client - .chat_completion(OpenAIModel::GPT35, &msgs) + .chat_completion(OpenAIModel::GPT35.as_ref(), &msgs) .await .unwrap(); diff --git a/lib/worker/src/tasks.rs b/lib/worker/src/tasks.rs index a1f3300..342be51 100644 --- a/lib/worker/src/tasks.rs +++ b/lib/worker/src/tasks.rs @@ -1,6 +1,6 @@ use libmemex::db::{document, embedding, queue}; use libmemex::llm::embedding::{ModelConfig, SentenceEmbedder}; -use libmemex::llm::openai::{segment, OpenAIClient}; +use libmemex::llm::openai::OpenAIClient; use libmemex::llm::{prompter, LLM}; use libmemex::storage::{VectorData, VectorStorage}; use libmemex::NAMESPACE; @@ -67,13 +67,13 @@ pub async fn process_embeddings( pub async fn generate_summary(client: &OpenAIClient, payload: &str) -> anyhow::Result { // Break task content into segments - let (splits, model) = segment(payload); + let (splits, model) = client.segment_text(payload); let mut buffer = String::new(); for (idx, segment) in splits.iter().enumerate() { let time = std::time::Instant::now(); let request = prompter::summarize(segment); - if let Ok(content) = client.chat_completion(model.clone(), &request).await { + if let Ok(content) = client.chat_completion(model.as_ref(), &request).await { buffer.push_str(&content); } From ec244d91c27e4d7dbeb579261abd6775f458e47b Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Fri, 6 Oct 2023 16:21:42 -0400 Subject: [PATCH 09/18] load llm client based on whether `OPENAI_API_KEY` or `LOCAL_LLM_CONFIG` is set --- bin/memex/src/main.rs | 19 ++++++- lib/api/src/lib.rs | 35 +++++++++---- lib/libmemex/src/llm/local/mod.rs | 64 ++++++++++++++++-------- lib/libmemex/src/llm/local/schema.rs | 75 ++++++++++++++++++++++++++-- resources/config.gpt4all.toml | 4 +- resources/config.llama2.toml | 5 +- resources/config.vicuna.toml | 4 +- 7 files changed, 166 insertions(+), 40 deletions(-) diff --git a/bin/memex/src/main.rs b/bin/memex/src/main.rs index ead9709..56a8e5c 100644 --- a/bin/memex/src/main.rs +++ b/bin/memex/src/main.rs @@ -1,3 +1,4 @@ +use api::ApiConfig; use clap::{Parser, Subcommand}; use futures::future::join_all; use 
std::{net::Ipv4Addr, process::ExitCode}; @@ -25,6 +26,10 @@ pub struct Args { database_connection: Option, #[clap(long, value_parser, value_name = "VECTOR_CONNECTION", env)] vector_connection: Option, + #[clap(long, value_parser, value_name = "OPENAI_API_KEY", env)] + openai_api_key: Option, + #[clap(long, value_parser, value_name = "LOCAL_LLM_CONFIG", env)] + local_llm_config: Option, } #[derive(Debug, Display, Clone, PartialEq, EnumString)] @@ -100,9 +105,21 @@ async fn main() -> ExitCode { let _vector_store_uri = args.vector_connection.expect("VECTOR_CONNECTION not set"); + if args.openai_api_key.is_none() && args.local_llm_config.is_none() { + log::error!("Must set either OPENAI_API_KEY or LOCAL_LLM_CONFIG"); + return ExitCode::FAILURE; + } + if roles.contains(&Roles::Api) { let db_uri = db_uri.clone(); - handles.push(tokio::spawn(api::start(host, port, db_uri))); + let cfg = ApiConfig { + host, + port, + db_uri, + open_ai_key: args.openai_api_key, + local_llm_config: args.local_llm_config, + }; + handles.push(tokio::spawn(api::start(cfg))); } if roles.contains(&Roles::Worker) { diff --git a/lib/api/src/lib.rs b/lib/api/src/lib.rs index 0ce3ac4..37b2a78 100644 --- a/lib/api/src/lib.rs +++ b/lib/api/src/lib.rs @@ -1,7 +1,7 @@ use dotenv_codegen::dotenv; use libmemex::{ db::create_connection_by_uri, - llm::{openai::OpenAIClient, LLM}, + llm::{local::load_from_cfg, openai::OpenAIClient, LLM}, }; use sea_orm::DatabaseConnection; use serde_json::json; @@ -25,6 +25,14 @@ pub enum ServerError { impl Reject for ServerError {} +pub struct ApiConfig { + pub host: Ipv4Addr, + pub port: u16, + pub db_uri: String, + pub open_ai_key: Option, + pub local_llm_config: Option, +} + // Handle custom errors/rejections async fn handle_rejection(err: Rejection) -> Result { let code; @@ -62,8 +70,8 @@ pub fn health_check() -> impl Filter> = if let Some(openai_key) = config.open_ai_key { + Arc::new(Box::new(OpenAIClient::new(&openai_key))) + } else if let Some(llm_config_path) = config.local_llm_config { + let llm = load_from_cfg(llm_config_path.into(), true) + .await + .expect("Unable to load local LLM"); + Arc::new(llm) + } else { + panic!("Please setup OPENAI_API_KEY or LOCAL_LLM_CONFIG"); + }; let cors = warp::cors() .allow_any_origin() .allow_methods(vec!["GET", "POST", "PUT", "PATCH", "DELETE"]) .allow_headers(["Authorization", "Content-Type"]); - let llm_client: Arc> = Arc::new(Box::new(llm_client)); let api = warp::path("api") .and(endpoints::build(&db_connection, &llm_client)) .with(warp::trace::request()); @@ -93,7 +108,7 @@ pub async fn start(host: Ipv4Addr, port: u16, db_uri: String) { let filters = health_check().or(api).with(cors).recover(handle_rejection); let (_addr, handle) = - warp::serve(filters).bind_with_graceful_shutdown((host, port), async move { + warp::serve(filters).bind_with_graceful_shutdown((config.host, config.port), async move { tokio::signal::ctrl_c() .await .expect("failed to listen to shutdown signal"); diff --git a/lib/libmemex/src/llm/local/mod.rs b/lib/libmemex/src/llm/local/mod.rs index 1b93a69..8436397 100644 --- a/lib/libmemex/src/llm/local/mod.rs +++ b/lib/libmemex/src/llm/local/mod.rs @@ -1,14 +1,17 @@ use llm::samplers::llm_samplers::samplers::SampleFlatBias; use llm::samplers::llm_samplers::types::SamplerChain; -use llm::InferenceParameters; use llm::{self, samplers::ConfiguredSamplers, InferenceSessionConfig}; +use llm::{InferenceParameters, LoadProgress}; use std::io::Write; +use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tiktoken_rs::cl100k_base; use 
tokio::sync::mpsc; use crate::llm::{split_text, ChatRole}; +use self::schema::LocalLLMConfig; + use super::{ChatMessage, LLMError, LLM}; mod schema; use schema::LlmEvent; @@ -80,7 +83,6 @@ where std::io::stdout().flush().unwrap(); return; } - _ => {} } } } @@ -206,34 +208,56 @@ where } } +pub async fn load_from_cfg( + llm_config: PathBuf, + report_progress: bool, +) -> anyhow::Result> { + let config = std::fs::read_to_string(llm_config.clone())?; + let config: LocalLLMConfig = toml::from_str(&config)?; + + let parent_dir = llm_config.parent().unwrap(); + let model_path: PathBuf = parent_dir.join(config.model.path.clone()); + + let model_params = config.to_model_params(); + let model = llm::load::( + &model_path, + llm::TokenizerSource::Embedded, + model_params, + move |event| { + if report_progress { + match &event { + LoadProgress::TensorLoaded { + current_tensor, + tensor_count, + } => { + log::info!("Loaded {}/{} tensors", current_tensor, tensor_count); + } + LoadProgress::Loaded { .. } => { + log::info!("Model finished loading"); + } + _ => {} + } + } + }, + )?; + + let llm = LocalLLM::new(model); + Ok(Box::new(llm)) +} + #[cfg(test)] mod test { - use crate::llm::{ChatMessage, LLM}; - - use super::schema::LocalLLMConfig; - use super::LocalLLM; + use crate::llm::ChatMessage; use std::path::PathBuf; - #[ignore] #[tokio::test] async fn test_prompting() { let base_dir: PathBuf = "../..".into(); let model_config: PathBuf = base_dir.join("resources/config.llama2.toml"); - let config = std::fs::read_to_string(model_config).expect("Unable to read cfg"); - let config: LocalLLMConfig = toml::from_str(&config).expect("Unable to parse cfg"); - let model_path: PathBuf = base_dir.join(config.model.path); - - let model_params = llm::ModelParameters::default(); - let model = llm::load::( - &model_path, - llm::TokenizerSource::Embedded, - model_params, - move |_| {}, - ) - .expect("Unable to load model"); + let llm = super::load_from_cfg(model_config, true).await + .expect("Unable to load model"); - let llm = LocalLLM::new(model); let msgs = vec![ ChatMessage::system("You're a helpful assistant. 
Answer questions as accurately and concisely as possible."), ChatMessage::user("Who won the world series in 2020?"), diff --git a/lib/libmemex/src/llm/local/schema.rs b/lib/libmemex/src/llm/local/schema.rs index cf1381c..b4eb1ba 100644 --- a/lib/libmemex/src/llm/local/schema.rs +++ b/lib/libmemex/src/llm/local/schema.rs @@ -1,10 +1,18 @@ -use llm::{LoadProgress, ModelArchitecture}; +use llm::{ + samplers::{ + llm_samplers::{ + configure::{SamplerChainBuilder, SamplerSlot}, + samplers::{SampleRepetition, SampleTemperature, SampleTopK, SampleTopP}, + }, + ConfiguredSamplers, + }, + ModelArchitecture, +}; use serde::Deserialize; -use std::path::PathBuf; +use std::{path::PathBuf, sync::Arc, sync::Mutex}; #[derive(Debug)] pub enum LlmEvent { - ModelLoadProgress(LoadProgress), TokenReceived(String), InferenceDone, } @@ -15,6 +23,67 @@ pub struct LocalLLMConfig { pub model: ModelConfig, } +impl LocalLLMConfig { + pub fn to_model_params(&self) -> llm::ModelParameters { + llm::ModelParameters { + prefer_mmap: false, + context_size: 2048, + lora_adapters: None, + ..Default::default() + } + } + + pub fn to_inference_params(&self) -> llm::InferenceParameters { + let model = self.model.clone(); + let sampler_builder: SamplerChainBuilder = SamplerChainBuilder::from([ + ( + "repetition", + SamplerSlot::new_chain( + move || { + Box::new( + SampleRepetition::default() + .penalty(model.repeat_penalty) + .last_n(model.repetition_penalty_last_n), + ) + }, + [], + ), + ), + ( + "topk", + SamplerSlot::new_single( + move || Box::new(SampleTopK::default().k(model.top_k)), + Option::::None, + ), + ), + ( + "topp", + SamplerSlot::new_single( + move || Box::new(SampleTopP::default().p(model.top_p)), + Option::::None, + ), + ), + ( + "temperature", + SamplerSlot::new_single( + move || Box::new(SampleTemperature::default().temperature(model.temperature)), + Option::::None, + ), + ), + ]); + + let mut sampler = ConfiguredSamplers { + builder: sampler_builder, + ..Default::default() + }; + sampler.ensure_default_slots(); + + llm::InferenceParameters { + sampler: Arc::new(Mutex::new(sampler.builder.into_chain())), + } + } +} + #[derive(Clone, Deserialize)] pub struct ModelConfig { pub path: PathBuf, diff --git a/resources/config.gpt4all.toml b/resources/config.gpt4all.toml index fe905c2..0a20d6c 100644 --- a/resources/config.gpt4all.toml +++ b/resources/config.gpt4all.toml @@ -1,7 +1,7 @@ -prompt_template = "resources/templates/clippy_prompt.txt" +prompt_template = "templates/clippy_prompt.txt" [model] -path = "resources/models/gpt4all/gpt4all-j-q4_0-ggjt.bin" +path = "models/gpt4all/gpt4all-j-q4_0-ggjt.bin" model_type = "Gptj" prefer_mmap = false # The top K words by score are kept during sampling. diff --git a/resources/config.llama2.toml b/resources/config.llama2.toml index 5570975..7b00ecd 100644 --- a/resources/config.llama2.toml +++ b/resources/config.llama2.toml @@ -1,8 +1,9 @@ -prompt_template = "resources/templates/clippy_prompt.txt" +prompt_template = "templates/clippy_prompt.txt" [model] # Download from https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML -path = "resources/models/LLaMa2/llama-2-7b-chat.ggmlv3.q4_1.bin" +# NOTE: Path is relative from this file. +path = "models/LLaMa2/llama-2-7b-chat.ggmlv3.q4_1.bin" model_type = "Llama" prefer_mmap = false # The top K words by score are kept during sampling. 
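Usage sketch (not part of the patch series): the pieces introduced above, `load_from_cfg`, the TOML configs whose paths are now resolved relative to the config file, and the trait object it returns, can be wired together roughly as follows. The `"local"` model-name argument and the exact return types (`Box<dyn LLM>`, `anyhow::Result<String>`) are assumptions inferred from this series, not confirmed by it.

```rust
// Sketch only: assumes `load_from_cfg(PathBuf, bool)` returns `anyhow::Result<Box<dyn LLM>>`
// as suggested by lib/libmemex/src/llm/local/mod.rs, and that the local backend ignores the
// model-name string passed to `chat_completion` (an assumption, not shown in the patch).
use std::path::PathBuf;

use libmemex::llm::{local::load_from_cfg, ChatMessage, LLM};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Model/template paths inside the TOML are resolved relative to the config file itself.
    let config: PathBuf = "resources/config.llama2.toml".into();
    let llm: Box<dyn LLM> = load_from_cfg(config, /* report_progress */ true).await?;

    let msgs = vec![
        ChatMessage::system("You're a helpful assistant."),
        ChatMessage::user("Who won the world series in 2020?"),
    ];

    // The model identifier is assumed to be unused by the local implementation.
    let answer = llm.chat_completion("local", &msgs).await?;
    println!("{answer}");
    Ok(())
}
```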
diff --git a/resources/config.vicuna.toml b/resources/config.vicuna.toml index 37ab6b7..be7b0fa 100644 --- a/resources/config.vicuna.toml +++ b/resources/config.vicuna.toml @@ -1,9 +1,9 @@ -prompt_template = "resources/templates/clippy_prompt.txt" +prompt_template = "templates/clippy_prompt.txt" [model] # Download from https://huggingface.co/TheBloke/Wizard-Vicuna-7B-Uncensored-GGML/tree/main # And place in the same directory as this config file. -path = "resources/models/Wziard-Vicuna/Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_0.bin" +path = "models/Wziard-Vicuna/Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_0.bin" model_type = "Llama" prefer_mmap = false # The top K words by score are kept during sampling. From 3910cf1804debf46522e84d6436a32b0c90ca76e Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Fri, 6 Oct 2023 16:21:52 -0400 Subject: [PATCH 10/18] update README to point that out --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index caf319b..ee106a1 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,15 @@ since Linux ARM builds are very finicky. 2023-06-13T05:04:21.518732Z INFO memex: starting server with roles: [Api, Worker] ``` +## Using a LLM +You can use either OpenAI or a local LLM for LLM based functionality (such as the +summarization or extraction APIs). + +Set `OPENAI_API_KEY` to your API key in the `.env` file or set `LOCAL_LLM_CONFIG` to +a LLM configuration file. See `resources/config.llama2.toml` for an example. By +default, a base memex will use the llama-2 configuration file. + + ## Add a document NOTE: If the `test` collection does not initially exist, it'll be created. From 6819059cee0bf4fcad37e28e57ec14d1c1727b09 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Fri, 6 Oct 2023 16:27:08 -0400 Subject: [PATCH 11/18] Create samplers from config and pass into `LocalLLM` struct --- lib/libmemex/src/llm/local/mod.rs | 13 +++++-------- lib/libmemex/src/llm/local/schema.rs | 9 +++------ 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/lib/libmemex/src/llm/local/mod.rs b/lib/libmemex/src/llm/local/mod.rs index 8436397..aaf5768 100644 --- a/lib/libmemex/src/llm/local/mod.rs +++ b/lib/libmemex/src/llm/local/mod.rs @@ -34,15 +34,12 @@ impl LocalLLM where T: llm::KnownModel, { - fn new(model: T) -> Self { + fn new(model: T, base_samplers: ConfiguredSamplers) -> Self { let bias_sampler = SampleFlatBias::default(); // Create sampler chain let mut samplers = SamplerChain::new(); samplers += bias_sampler.clone(); - - let mut default_samplers = ConfiguredSamplers::default(); - default_samplers.ensure_default_slots(); - samplers += default_samplers.builder.into_chain(); + samplers += base_samplers.builder.into_chain(); let infer_params = llm::InferenceParameters { sampler: Arc::new(Mutex::new(samplers)), // sampler: llm::samplers::default_samplers(), @@ -241,8 +238,7 @@ pub async fn load_from_cfg( }, )?; - let llm = LocalLLM::new(model); - Ok(Box::new(llm)) + Ok(Box::new(LocalLLM::new(model, config.base_samplers()))) } #[cfg(test)] @@ -255,7 +251,8 @@ mod test { let base_dir: PathBuf = "../..".into(); let model_config: PathBuf = base_dir.join("resources/config.llama2.toml"); - let llm = super::load_from_cfg(model_config, true).await + let llm = super::load_from_cfg(model_config, true) + .await .expect("Unable to load model"); let msgs = vec![ diff --git a/lib/libmemex/src/llm/local/schema.rs b/lib/libmemex/src/llm/local/schema.rs index b4eb1ba..fc5907e 100644 --- a/lib/libmemex/src/llm/local/schema.rs +++ 
b/lib/libmemex/src/llm/local/schema.rs @@ -9,7 +9,7 @@ use llm::{ ModelArchitecture, }; use serde::Deserialize; -use std::{path::PathBuf, sync::Arc, sync::Mutex}; +use std::path::PathBuf; #[derive(Debug)] pub enum LlmEvent { @@ -33,7 +33,7 @@ impl LocalLLMConfig { } } - pub fn to_inference_params(&self) -> llm::InferenceParameters { + pub fn base_samplers(&self) -> ConfiguredSamplers { let model = self.model.clone(); let sampler_builder: SamplerChainBuilder = SamplerChainBuilder::from([ ( @@ -77,10 +77,7 @@ impl LocalLLMConfig { ..Default::default() }; sampler.ensure_default_slots(); - - llm::InferenceParameters { - sampler: Arc::new(Mutex::new(sampler.builder.into_chain())), - } + sampler } } From b1d4848f59e0b5bb9004f0cdb9684fe58aa2a78a Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Fri, 6 Oct 2023 16:36:25 -0400 Subject: [PATCH 12/18] add basic support for other architectures, but focus on llama for now --- lib/libmemex/src/llm/local/mod.rs | 59 +++++++++++++++++----------- lib/libmemex/src/llm/local/schema.rs | 20 ++++++---- 2 files changed, 49 insertions(+), 30 deletions(-) diff --git a/lib/libmemex/src/llm/local/mod.rs b/lib/libmemex/src/llm/local/mod.rs index aaf5768..af7582d 100644 --- a/lib/libmemex/src/llm/local/mod.rs +++ b/lib/libmemex/src/llm/local/mod.rs @@ -10,7 +10,7 @@ use tokio::sync::mpsc; use crate::llm::{split_text, ChatRole}; -use self::schema::LocalLLMConfig; +use self::schema::{LocalLLMConfig, ModelArch}; use super::{ChatMessage, LLMError, LLM}; mod schema; @@ -216,29 +216,44 @@ pub async fn load_from_cfg( let model_path: PathBuf = parent_dir.join(config.model.path.clone()); let model_params = config.to_model_params(); - let model = llm::load::( - &model_path, - llm::TokenizerSource::Embedded, - model_params, - move |event| { - if report_progress { - match &event { - LoadProgress::TensorLoaded { - current_tensor, - tensor_count, - } => { - log::info!("Loaded {}/{} tensors", current_tensor, tensor_count); - } - LoadProgress::Loaded { .. } => { - log::info!("Model finished loading"); - } - _ => {} + + let progress_cb = move |event| { + if report_progress { + match &event { + LoadProgress::TensorLoaded { + current_tensor, + tensor_count, + } => { + log::info!("Loaded {}/{} tensors", current_tensor, tensor_count); + } + LoadProgress::Loaded { .. 
} => { + log::info!("Model finished loading"); } + _ => {} } - }, - )?; - - Ok(Box::new(LocalLLM::new(model, config.base_samplers()))) + } + }; + + match config.model.model_type { + ModelArch::GptJ => Ok(Box::new(LocalLLM::new( + llm::load::( + &model_path, + llm::TokenizerSource::Embedded, + model_params, + progress_cb, + )?, + config.base_samplers(), + ))), + ModelArch::Llama => Ok(Box::new(LocalLLM::new( + llm::load::( + &model_path, + llm::TokenizerSource::Embedded, + model_params, + progress_cb, + )?, + config.base_samplers(), + ))), + } } #[cfg(test)] diff --git a/lib/libmemex/src/llm/local/schema.rs b/lib/libmemex/src/llm/local/schema.rs index fc5907e..3cfbaca 100644 --- a/lib/libmemex/src/llm/local/schema.rs +++ b/lib/libmemex/src/llm/local/schema.rs @@ -93,25 +93,29 @@ pub struct ModelConfig { pub repetition_penalty_last_n: usize, } +/// TODO: Test on more architectures #[derive(Clone, Deserialize)] pub enum ModelArch { - Bloom, - Gpt2, + // Bloom, + // Gpt2, GptJ, - GptNeoX, + // GptNeoX, Llama, - Mpt, + // Mpt, } impl From for ModelArch { fn from(value: ModelArchitecture) -> Self { match value { - ModelArchitecture::Bloom => Self::Bloom, - ModelArchitecture::Gpt2 => Self::Gpt2, + // ModelArchitecture::Bloom => Self::Bloom, + // ModelArchitecture::Gpt2 => Self::Gpt2, ModelArchitecture::GptJ => Self::GptJ, - ModelArchitecture::GptNeoX => Self::GptNeoX, + // ModelArchitecture::GptNeoX => Self::GptNeoX, ModelArchitecture::Llama => Self::Llama, - ModelArchitecture::Mpt => Self::Mpt, + // ModelArchitecture::Mpt => Self::Mpt, + _ => { + panic!("Model architecture not yet supported"); + } } } } From 96130d123a4c20b0e742a7804a707c76be75ea13 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Fri, 6 Oct 2023 16:44:18 -0400 Subject: [PATCH 13/18] add LLM ask example to README --- README.md | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ee106a1..899d5fa 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,14 @@ Set `OPENAI_API_KEY` to your API key in the `.env` file or set `LOCAL_LLM_CONFIG a LLM configuration file. See `resources/config.llama2.toml` for an example. By default, a base memex will use the llama-2 configuration file. +### Supported local models -## Add a document +Currently we have supported (and have tested) the following models: +- Llama based models (llama 1 & 2, Mistral, etc.) - *recommended* +- Gptj (e.g. GPT4All) + + +## Adding a document NOTE: If the `test` collection does not initially exist, it'll be created. @@ -89,7 +95,7 @@ Or if it's finished, something like so: One the task is shown as "Completed", you can now run a query against the doc(s) you've just added. -## Run a query +## Run a search query ``` bash > curl http://localhost:8181/api/collections/test/search \ @@ -109,6 +115,22 @@ you've just added. } ``` +## Ask a question +```bash +> curl http://localhost:8181/api/action/ask \ + -H "Content-Type: application/json" \ + -X POST \ + -d "{\"text\": \"\", \"query\": \"What is the airspeed velocity of an unladen swallow?\", "json_schema": { .. }}" +{ + "time": 1.234, + "status": "ok", + "result": { + "answer": "The airspeed velocity of an unladen swallow is..." 
+ } +} + +``` + ## Env variables - `HOST`: Defaults to `127.0.0.1` From a1e7c6c554d18f6d68b70a8876b064dfb6576eac Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Mon, 9 Oct 2023 11:56:07 -0700 Subject: [PATCH 14/18] output debug messages for local llm responses --- lib/api/src/endpoints/actions/handlers.rs | 1 + lib/libmemex/src/llm/local/mod.rs | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/api/src/endpoints/actions/handlers.rs b/lib/api/src/endpoints/actions/handlers.rs index 21d11aa..726654b 100644 --- a/lib/api/src/endpoints/actions/handlers.rs +++ b/lib/api/src/endpoints/actions/handlers.rs @@ -37,6 +37,7 @@ pub async fn handle_extract( .await .map_err(|err| ServerError::Other(err.to_string()))?; + log::debug!("llm response: {response}"); let val = serde_json::from_str::(&response) .map_err(|err| ServerError::Other(err.to_string()))?; diff --git a/lib/libmemex/src/llm/local/mod.rs b/lib/libmemex/src/llm/local/mod.rs index af7582d..abb3a5d 100644 --- a/lib/libmemex/src/llm/local/mod.rs +++ b/lib/libmemex/src/llm/local/mod.rs @@ -224,7 +224,8 @@ pub async fn load_from_cfg( current_tensor, tensor_count, } => { - log::info!("Loaded {}/{} tensors", current_tensor, tensor_count); + let percent = current_tensor * 100 / tensor_count; + log::info!("Loaded tensors: {percent} ({current_tensor}/{tensor_count})"); } LoadProgress::Loaded { .. } => { log::info!("Model finished loading"); From e368e87037efe8276c116c103f584441436862b9 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Mon, 9 Oct 2023 12:07:05 -0700 Subject: [PATCH 15/18] correctly capture server errors & send back the error messages --- lib/api/src/lib.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lib/api/src/lib.rs b/lib/api/src/lib.rs index 37b2a78..d42cb9d 100644 --- a/lib/api/src/lib.rs +++ b/lib/api/src/lib.rs @@ -46,6 +46,18 @@ async fn handle_rejection(err: Rejection) -> Result { // and render it however we want code = StatusCode::METHOD_NOT_ALLOWED; message = "METHOD_NOT_ALLOWED".into(); + } else if let Some(err) = err.find::() { + (code, message) = match err { + ServerError::ClientRequestError(err) => { + (StatusCode::BAD_REQUEST, err.to_string()) + }, + ServerError::DatabaseError(err) => { + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) + }, + ServerError::Other(err) => { + (StatusCode::BAD_REQUEST, err.to_string()) + } + }; } else { // We should have expected this... 
Just log and say its a 500 eprintln!("unhandled rejection: {:?}", err); From 3af2396202ec29e517c566a02cb71c136cd0c865 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Mon, 9 Oct 2023 12:38:47 -0700 Subject: [PATCH 16/18] removing old dep --- Cargo.lock | 54 +++-------------------------------------- Cargo.toml | 1 - lib/libmemex/Cargo.toml | 1 - 3 files changed, 3 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 66c9471..03cf5a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -697,7 +697,7 @@ dependencies = [ "tera", "tokenizers 0.14.0", "tokio", - "toml 0.7.6", + "toml", ] [[package]] @@ -2165,7 +2165,6 @@ dependencies = [ "sea-orm", "serde", "serde_json", - "sherpa", "strum", "strum_macros", "tera", @@ -2173,7 +2172,7 @@ dependencies = [ "tiktoken-rs", "tokenizers 0.14.0", "tokio", - "toml 0.7.6", + "toml", "url", "uuid", ] @@ -3971,28 +3970,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "sherpa" -version = "0.1.0" -dependencies = [ - "anyhow", - "async-trait", - "llm", - "llm-base", - "log", - "rand 0.8.5", - "rust-bert", - "serde", - "serde_json", - "strum", - "strum_macros", - "tera", - "thiserror", - "tokenizers 0.14.0", - "tokio", - "toml 0.8.0", -] - [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -4741,19 +4718,7 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit 0.19.14", -] - -[[package]] -name = "toml" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c226a7bba6d859b63c92c4b4fe69c5b6b72d0cb897dbc8e6012298e6154cb56e" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit 0.20.0", + "toml_edit", ] [[package]] @@ -4778,19 +4743,6 @@ dependencies = [ "winnow", ] -[[package]] -name = "toml_edit" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ff63e60a958cefbb518ae1fd6566af80d9d4be430a33f3723dfc47d1d411d95" -dependencies = [ - "indexmap 2.0.0", - "serde", - "serde_spanned", - "toml_datetime", - "winnow", -] - [[package]] name = "tonic" version = "0.9.2" diff --git a/Cargo.toml b/Cargo.toml index 79cb56f..af394d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "lib/api", "lib/libmemex", - "lib/sherpa", "lib/worker", "examples/clippy" diff --git a/lib/libmemex/Cargo.toml b/lib/libmemex/Cargo.toml index ceb1812..f272bd8 100644 --- a/lib/libmemex/Cargo.toml +++ b/lib/libmemex/Cargo.toml @@ -23,7 +23,6 @@ rust-bert = { version = "0.21.0", features= ["download-libtorch"] } sea-orm = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -sherpa = { path = "../sherpa" } strum = "0.25" strum_macros = "0.25" tera = "1.19.0" From d07b219d590220aeb6cc0ab34cb0281ff3c63f92 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Mon, 9 Oct 2023 12:50:24 -0700 Subject: [PATCH 17/18] ignore model test --- lib/libmemex/src/llm/local/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/libmemex/src/llm/local/mod.rs b/lib/libmemex/src/llm/local/mod.rs index abb3a5d..8575917 100644 --- a/lib/libmemex/src/llm/local/mod.rs +++ b/lib/libmemex/src/llm/local/mod.rs @@ -262,6 +262,8 @@ mod test { use crate::llm::ChatMessage; use std::path::PathBuf; + // ignoring this for now since we don't want to continually download models + #[ignore] #[tokio::test] async fn test_prompting() { let base_dir: PathBuf = "../..".into(); From 13e2bb0fec0f423a129577108322903cf00b4e92 Mon Sep 17 00:00:00 2001 From: Andrew Huynh Date: Mon, 9 Oct 2023 13:27:45 -0700 Subject: [PATCH 
18/18] cargo fmt --- lib/api/src/lib.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/lib/api/src/lib.rs b/lib/api/src/lib.rs index d42cb9d..9c6736e 100644 --- a/lib/api/src/lib.rs +++ b/lib/api/src/lib.rs @@ -48,15 +48,9 @@ async fn handle_rejection(err: Rejection) -> Result { message = "METHOD_NOT_ALLOWED".into(); } else if let Some(err) = err.find::<ServerError>() { (code, message) = match err { - ServerError::ClientRequestError(err) => { - (StatusCode::BAD_REQUEST, err.to_string()) - }, - ServerError::DatabaseError(err) => { - (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) - }, - ServerError::Other(err) => { - (StatusCode::BAD_REQUEST, err.to_string()) - } + ServerError::ClientRequestError(err) => (StatusCode::BAD_REQUEST, err.to_string()), + ServerError::DatabaseError(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()), + ServerError::Other(err) => (StatusCode::BAD_REQUEST, err.to_string()), }; } else { // We should have expected this... Just log and say its a 500
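Closing usage sketch (not part of the patch series): the `/api/action/ask` endpoint documented in the README hunks above can be exercised from Rust as well as curl. The URL and the request/response field names (`text`, `query`, `json_schema`, `result.answer`) follow the README example; the schema body and sample text are illustrative assumptions.

```rust
// Sketch only: mirrors the `/api/action/ask` curl example from the README diff above.
// Field names and URL come from that example; the schema contents and sample text are made up.
use serde_json::json;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let body = json!({
        "text": "An unladen European swallow was clocked at roughly 11 metres per second.",
        "query": "What is the airspeed velocity of an unladen swallow?",
        // Hypothetical schema describing the shape we want extracted.
        "json_schema": {
            "type": "object",
            "properties": { "answer": { "type": "string" } }
        }
    });

    let resp = reqwest::Client::new()
        .post("http://localhost:8181/api/action/ask")
        .header("Content-Type", "application/json")
        .body(body.to_string())
        .send()
        .await?
        .text()
        .await?;

    // Expected shape per the README: { "time": ..., "status": "ok", "result": { "answer": ... } }
    let parsed: serde_json::Value = serde_json::from_str(&resp)?;
    println!("{}", parsed["result"]["answer"]);
    Ok(())
}
```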