Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Uplift 1.62.x] #21398 AI Chat: introduce freemium model concept #21594

Merged
merged 3 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,12 @@ void AIChatUIPageHandler::GetModels(GetModelsCallback callback) {
return;
}

std::vector<mojom::ModelPtr> models(kAllModelKeysDisplayOrder.size());
auto all_models = GetAllModels();
std::vector<mojom::ModelPtr> models(all_models.size());
// Ensure we return only in intended display order
std::transform(kAllModelKeysDisplayOrder.cbegin(),
kAllModelKeysDisplayOrder.cend(), models.begin(),
[](auto& model_key) {
auto model_match = kAllModels.find(model_key);
DCHECK(model_match != kAllModels.end());
return model_match->second.Clone();
});
std::transform(all_models.cbegin(), all_models.cend(), models.begin(),
[](auto& model) { return model.Clone(); });

std::move(callback).Run(std::move(models),
active_chat_tab_helper_->GetCurrentModel().key);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,26 +147,18 @@ void BraveLeoAssistantHandler::HandleResetLeoData(
}

void BraveLeoAssistantHandler::HandleGetModels(const base::Value::List& args) {
std::vector<ai_chat::mojom::ModelPtr> models(
ai_chat::kAllModelKeysDisplayOrder.size());
// Ensure we return only in intended display order
std::transform(ai_chat::kAllModelKeysDisplayOrder.cbegin(),
ai_chat::kAllModelKeysDisplayOrder.cend(), models.begin(),
[](auto& model_key) {
auto model_match = ai_chat::kAllModels.find(model_key);
DCHECK(model_match != ai_chat::kAllModels.end());
return model_match->second.Clone();
});
auto& models = ai_chat::GetAllModels();
base::Value::List models_list;
for (auto& model : models) {
base::Value::Dict dict;
dict.Set("key", model->key);
dict.Set("name", model->name);
dict.Set("display_name", model->display_name);
dict.Set("display_maker", model->display_maker);
dict.Set("engine_type", static_cast<int>(model->engine_type));
dict.Set("category", static_cast<int>(model->category));
dict.Set("is_premium", model->is_premium);
dict.Set("key", model.key);
dict.Set("name", model.name);
dict.Set("display_name", model.display_name);
dict.Set("display_maker", model.display_maker);
dict.Set("engine_type", static_cast<int>(model.engine_type));
dict.Set("category", static_cast<int>(model.category));
dict.Set("is_premium",
model.access == ai_chat::mojom::ModelAccess::PREMIUM);
models_list.Append(std::move(dict));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source,
{"braveLeoAssistantModelSelectionLabel",
IDS_SETTINGS_LEO_ASSISTANT_MODEL_SELECTION_LABEL},
{"braveLeoModelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT},
{"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE},
{"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE},
{"braveLeoModelSubtitle-chat-leo-expanded",
IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE},
{"braveLeoModelSubtitle-chat-claude-instant",
Expand Down
11 changes: 7 additions & 4 deletions components/ai_chat/core/browser/constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"errorNetworkLabel", IDS_CHAT_UI_ERROR_NETWORK},
{"errorRateLimit", IDS_CHAT_UI_ERROR_RATE_LIMIT},
{"retryButtonLabel", IDS_CHAT_UI_RETRY_BUTTON_LABEL},
{"introMessage-chat-default", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_DEFAULT},
{"introMessage-chat-basic", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_BASIC},
{"introMessage-chat-leo-expanded",
IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_EXPANDED},
{"introMessage-chat-claude-instant",
IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_CLAUDE_INSTANT},
{"modelNameSyntax", IDS_CHAT_UI_MODEL_NAME_SYNTAX},
{"modelFreemiumLabelNonPremium",
IDS_CHAT_UI_MODEL_FREEMIUM_LABEL_NON_PREMIUM},
{"modelFreemiumLabelPremium", IDS_CHAT_UI_MODEL_FREEMIUM_LABEL_PREMIUM},
{"modelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT},
{"menuNewChat", IDS_CHAT_UI_MENU_NEW_CHAT},
{"menuGoPremium", IDS_CHAT_UI_MENU_GO_PREMIUM},
Expand All @@ -44,8 +47,8 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"premiumFeature_2", IDS_CHAT_UI_PREMIUM_FEATURE_2},
{"premiumLabel", IDS_CHAT_UI_PREMIUM_LABEL},
{"premiumPricing", IDS_CHAT_UI_PREMIUM_PRICING},
{"switchToDefaultModelButtonLabel",
IDS_CHAT_UI_SWITCH_TO_DEFAULT_MODEL_BUTTON_LABEL},
{"switchToBasicModelButtonLabel",
IDS_CHAT_UI_SWITCH_TO_BASIC_MODEL_BUTTON_LABEL},
{"dismissButtonLabel", IDS_CHAT_UI_DISMISS_BUTTON_LABEL},
{"unlockPremiumTitle", IDS_CHAT_UI_UNLOCK_PREMIUM_TITLE},
{"premiumFeature_1_desc", IDS_CHAT_UI_PREMIUM_FEATURE_1_DESC},
Expand Down Expand Up @@ -74,7 +77,7 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"optionOther", IDS_CHAT_UI_OPTION_OTHER},
{"feedbackError", IDS_CHAT_UI_FEEDBACK_SUBMIT_ERROR},
{"ratingError", IDS_CHAT_UI_RATING_ERROR},
{"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE},
{"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE},
{"braveLeoModelSubtitle-chat-leo-expanded",
IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE},
{"braveLeoModelSubtitle-chat-claude-instant",
Expand Down
53 changes: 29 additions & 24 deletions components/ai_chat/core/browser/conversation_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,18 @@ ConversationDriver::ConversationDriver(raw_ptr<PrefService> pref_service,
base::BindRepeating(&ConversationDriver::OnUserOptedIn,
weak_ptr_factory_.GetWeakPtr()));

// Engines and model names are selectable per conversation, not static.
// Start with default from pref value but only if user set. We can't rely on
// actual default pref value since we should vary if user is premium or not.
// Model choice is selectable per conversation, not global.
// Start with the default from the pref value. If the user is premium, the
// premium default model differs from the non-premium default, and the user
// hasn't customized the model pref, then switch the user to the premium
// default.
// TODO(petemill): When we have an event for premium status changed, and a
// profile service for AIChat, then we can call
// |pref_service_->SetDefaultPrefValue| when the user becomes premium. With
// that, we'll be able to simply call GetString(prefs::kDefaultModelKey) and
// not vary on premium status.
if (!pref_service_->GetUserPrefValue(prefs::kDefaultModelKey)) {
// not have to fetch premium status.
if (!pref_service_->GetUserPrefValue(prefs::kDefaultModelKey) &&
features::kAIModelsPremiumDefaultKey.Get() !=
features::kAIModelsDefaultKey.Get()) {
credential_manager_->GetPremiumStatus(base::BindOnce(
[](ConversationDriver* instance, mojom::PremiumStatus status) {
instance->last_premium_status_ = status;
Expand All @@ -76,18 +79,18 @@ ConversationDriver::ConversationDriver(raw_ptr<PrefService> pref_service,
return;
}
// Use default premium model for this instance
instance->ChangeModel(kModelsPremiumDefaultKey);
instance->ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
// Make sure default model reflects premium status
const auto* current_default =
instance->pref_service_
->GetDefaultPrefValue(prefs::kDefaultModelKey)
->GetIfString();

if (current_default &&
*current_default != kModelsPremiumDefaultKey) {
*current_default != features::kAIModelsPremiumDefaultKey.Get()) {
instance->pref_service_->SetDefaultPrefValue(
prefs::kDefaultModelKey,
base::Value(kModelsPremiumDefaultKey));
base::Value(features::kAIModelsPremiumDefaultKey.Get()));
}
},
// Unretained is ok as credential manager is owned by this class,
Expand Down Expand Up @@ -118,16 +121,19 @@ ConversationDriver::~ConversationDriver() = default;
void ConversationDriver::ChangeModel(const std::string& model_key) {
DCHECK(!model_key.empty());
// Check that the key exists
if (kAllModels.find(model_key) == kAllModels.end()) {
auto* new_model = GetModel(model_key);
if (!new_model) {
NOTREACHED() << "No matching model found for key: " << model_key;
return;
}
model_key_ = model_key;
model_key_ = new_model->key;
InitEngine();
}

const mojom::Model& ConversationDriver::GetCurrentModel() {
return kAllModels.find(model_key_)->second;
auto* model = GetModel(model_key_);
DCHECK(model);
return *model;
}

const std::vector<ConversationTurn>& ConversationDriver::GetConversationHistory() {
Expand All @@ -145,33 +151,32 @@ void ConversationDriver::OnConversationActiveChanged(bool is_conversation_active

void ConversationDriver::InitEngine() {
DCHECK(!model_key_.empty());
auto model_match = kAllModels.find(model_key_);
auto* model = GetModel(model_key_);
// Make sure we get a valid model, defaulting to static default or first.
if (model_match == kAllModels.end()) {
if (!model) {
NOTREACHED() << "Model was not part of static model list";
// Use default
model_match = kAllModels.find(kModelsDefaultKey);
const auto is_found = model_match != kAllModels.end();
DCHECK(is_found);
if (!is_found) {
model_match = kAllModels.begin();
model = GetModel(features::kAIModelsDefaultKey.Get());
DCHECK(model);
if (!model) {
// Use first if given bad default value
model = &GetAllModels().at(0);
}
}

auto model = model_match->second;
// Model's key might not be the same as what we asked for (e.g. if the model
// no longer exists).
model_key_ = model.key;
model_key_ = model->key;

// Engine enum on model to decide which one
if (model.engine_type == mojom::ModelEngineType::LLAMA_REMOTE) {
if (model->engine_type == mojom::ModelEngineType::LLAMA_REMOTE) {
VLOG(1) << "Started AI engine: llama";
engine_ = std::make_unique<EngineConsumerLlamaRemote>(
model, url_loader_factory_, credential_manager_.get());
*model, url_loader_factory_, credential_manager_.get());
} else {
VLOG(1) << "Started AI engine: claude";
engine_ = std::make_unique<EngineConsumerClaudeRemote>(
model, url_loader_factory_, credential_manager_.get());
*model, url_loader_factory_, credential_manager_.get());
}

// Pending requests have been deleted along with the model engine
Expand Down Expand Up @@ -719,7 +724,7 @@ void ConversationDriver::OnPremiumStatusReceived(
if (last_premium_status_ != premium_status &&
premium_status == mojom::PremiumStatus::Active) {
// Change model if we haven't already
ChangeModel(kModelsPremiumDefaultKey);
ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
}
last_premium_status_ = premium_status;
std::move(parent_callback).Run(premium_status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,11 @@ EngineConsumerClaudeRemote::EngineConsumerClaudeRemote(
const mojom::Model& model,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
AIChatCredentialManager* credential_manager) {
// Allow specific model name to be overriden by feature flag
// TODO(petemill): verify premium status, or ensure server will verify even
// when given a model name override via cli flag param.
std::string model_name = ai_chat::features::kAIModelName.Get();
if (model_name.empty()) {
model_name = model.name;
}
DCHECK(!model_name.empty());
DCHECK(!model.name.empty());
base::flat_set<std::string_view> stop_sequences(kStopSequences.begin(),
kStopSequences.end());
api_ = std::make_unique<RemoteCompletionClient>(
model_name, stop_sequences, url_loader_factory, credential_manager);
model.name, stop_sequences, url_loader_factory, credential_manager);

max_page_content_length_ = model.max_page_content_length;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,11 @@ EngineConsumerLlamaRemote::EngineConsumerLlamaRemote(
const mojom::Model& model,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
AIChatCredentialManager* credential_manager) {
// Allow specific model name to be overriden by feature flag
// TODO(petemill): verify premium status, or ensure server will verify even
// when given a model name override via cli flag param.
std::string model_name = ai_chat::features::kAIModelName.Get();
if (model_name.empty()) {
model_name = model.name;
}
DCHECK(!model_name.empty());
DCHECK(!model.name.empty());
base::flat_set<std::string_view> stop_sequences(kStopSequences.begin(),
kStopSequences.end());
api_ = std::make_unique<RemoteCompletionClient>(
model_name, stop_sequences, url_loader_factory, credential_manager);
model.name, stop_sequences, url_loader_factory, credential_manager);

max_page_content_length_ = model.max_page_content_length;
}
Expand Down
53 changes: 32 additions & 21 deletions components/ai_chat/core/browser/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@

#include "brave/components/ai_chat/core/browser/models.h"

#include "base/no_destructor.h"
#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"

namespace ai_chat {

// When adding new models, especially for display, make sure to add the UI
// strings to ai_chat_ui_strings.grdp and ai_chat/core/constants.cc.
// This also applies for modifying keys, since some of the strings are based
// on the model key.
// on the model key. Also be sure to migrate prefs if changing or removing
// keys.

// Llama2 Token Allocation:
// - Llama2 has a context limit: tokens + max_new_tokens <= 4096
Expand All @@ -32,25 +35,33 @@ namespace ai_chat {
// - Reserved for page content: 100k / 2 = 50k tokens
// - Long conversation warning threshold: 100k * 0.80 = 80k tokens

const base::flat_map<std::string_view, mojom::Model> kAllModels = {
{"chat-default",
{"chat-default", "llama-2-13b-chat", "llama2 13b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, false,
9000, 9700}},
{"chat-leo-expanded",
{"chat-leo-expanded", "llama-2-70b-chat", "llama2 70b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, true,
9000, 9700}},
{"chat-claude-instant",
{"chat-claude-instant", "claude-instant-v1", "Claude Instant", "Anthropic",
mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT, true,
200000, 320000}},
};

const std::vector<std::string_view> kAllModelKeysDisplayOrder = {
"chat-default",
"chat-leo-expanded",
"chat-claude-instant",
};
const std::vector<ai_chat::mojom::Model>& GetAllModels() {
static const auto kFreemiumAccess =
features::kFreemiumAvailable.Get() ? mojom::ModelAccess::BASIC_AND_PREMIUM
: mojom::ModelAccess::PREMIUM;
static const base::NoDestructor<std::vector<mojom::Model>> kModels({
{"chat-leo-expanded", "mixtral-8x7b-instruct", "Mixtral", "Mistral AI",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
kFreemiumAccess, 9000, 9700},
{"chat-claude-instant", "claude-instant-v1", "Claude Instant",
"Anthropic", mojom::ModelEngineType::CLAUDE_REMOTE,
mojom::ModelCategory::CHAT, kFreemiumAccess, 180000, 320000},
{"chat-basic", "llama-2-13b-chat", "llama2 13b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
mojom::ModelAccess::BASIC, 9000, 9700},
});
return *kModels;
}

const ai_chat::mojom::Model* GetModel(std::string_view key) {
auto& models = GetAllModels();
auto match = std::find_if(
models.cbegin(), models.cend(),
[&key](const mojom::Model& item) { return item.key == key; });
if (match != models.cend()) {
return &*match;
}
return nullptr;
}

} // namespace ai_chat
14 changes: 8 additions & 6 deletions components/ai_chat/core/browser/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@
#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MODELS_H_
#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MODELS_H_

#include <string_view>
#include <vector>

#include "base/containers/flat_map.h"
#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"

namespace ai_chat {

inline constexpr char kModelsDefaultKey[] = "chat-default";
inline constexpr char kModelsPremiumDefaultKey[] = "chat-claude-instant";
// All models that the user can choose for chat conversations, in UI display
// order.
extern const std::vector<ai_chat::mojom::Model>& GetAllModels();

// All models that the user can choose for chat conversations.
extern const base::flat_map<std::string_view, mojom::Model> kAllModels;
// UI display order for models
extern const std::vector<std::string_view> kAllModelKeysDisplayOrder;
// Get model by key. If there is no matching model for the key, NULL is
// returned.
extern const ai_chat::mojom::Model* GetModel(std::string_view key);

} // namespace ai_chat

Expand Down
15 changes: 15 additions & 0 deletions components/ai_chat/core/common/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,18 @@ static_library("common") {
"//components/prefs",
]
}

if (!is_ios) {
source_set("unit_tests") {
testonly = true
sources = [ "pref_names_unittest.cc" ]

deps = [
"//base/test:test_support",
"//brave/components/ai_chat/core/common",
"//components/prefs:test_support",
"//content/test:test_support",
"//testing/gtest:gtest",
]
}
}
Loading