diff --git a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc index 27d63d4d7012..aedd32563d60 100644 --- a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc +++ b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc @@ -108,15 +108,12 @@ void AIChatUIPageHandler::GetModels(GetModelsCallback callback) { return; } - std::vector models(kAllModelKeysDisplayOrder.size()); + auto all_models = GetAllModels(); + std::vector models(all_models.size()); // Ensure we return only in intended display order - std::transform(kAllModelKeysDisplayOrder.cbegin(), - kAllModelKeysDisplayOrder.cend(), models.begin(), - [](auto& model_key) { - auto model_match = kAllModels.find(model_key); - DCHECK(model_match != kAllModels.end()); - return model_match->second.Clone(); - }); + std::transform(all_models.cbegin(), all_models.cend(), models.begin(), + [](auto& model) { return model.Clone(); }); + std::move(callback).Run(std::move(models), active_chat_tab_helper_->GetCurrentModel().key); } diff --git a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc index fa1725c5f9e7..e3f95a4f771d 100644 --- a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc +++ b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc @@ -147,26 +147,18 @@ void BraveLeoAssistantHandler::HandleResetLeoData( } void BraveLeoAssistantHandler::HandleGetModels(const base::Value::List& args) { - std::vector models( - ai_chat::kAllModelKeysDisplayOrder.size()); - // Ensure we return only in intended display order - std::transform(ai_chat::kAllModelKeysDisplayOrder.cbegin(), - ai_chat::kAllModelKeysDisplayOrder.cend(), models.begin(), - [](auto& model_key) { - auto model_match = ai_chat::kAllModels.find(model_key); - DCHECK(model_match != ai_chat::kAllModels.end()); - return model_match->second.Clone(); - }); + auto& models = ai_chat::GetAllModels(); 
base::Value::List models_list; for (auto& model : models) { base::Value::Dict dict; - dict.Set("key", model->key); - dict.Set("name", model->name); - dict.Set("display_name", model->display_name); - dict.Set("display_maker", model->display_maker); - dict.Set("engine_type", static_cast(model->engine_type)); - dict.Set("category", static_cast(model->category)); - dict.Set("is_premium", model->is_premium); + dict.Set("key", model.key); + dict.Set("name", model.name); + dict.Set("display_name", model.display_name); + dict.Set("display_maker", model.display_maker); + dict.Set("engine_type", static_cast(model.engine_type)); + dict.Set("category", static_cast(model.category)); + dict.Set("is_premium", + model.access == ai_chat::mojom::ModelAccess::PREMIUM); models_list.Append(std::move(dict)); } diff --git a/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc b/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc index e6c7d47b1ac5..48e7714be22e 100644 --- a/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc +++ b/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc @@ -394,7 +394,7 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source, {"braveLeoAssistantModelSelectionLabel", IDS_SETTINGS_LEO_ASSISTANT_MODEL_SELECTION_LABEL}, {"braveLeoModelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT}, - {"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE}, + {"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE}, {"braveLeoModelSubtitle-chat-leo-expanded", IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE}, {"braveLeoModelSubtitle-chat-claude-instant", diff --git a/components/ai_chat/core/browser/constants.cc b/components/ai_chat/core/browser/constants.cc index 6530173542d2..d0f897fae03e 100644 --- a/components/ai_chat/core/browser/constants.cc +++ b/components/ai_chat/core/browser/constants.cc @@ -23,12 +23,15 @@ base::span GetLocalizedStrings() { {"errorNetworkLabel", 
IDS_CHAT_UI_ERROR_NETWORK}, {"errorRateLimit", IDS_CHAT_UI_ERROR_RATE_LIMIT}, {"retryButtonLabel", IDS_CHAT_UI_RETRY_BUTTON_LABEL}, - {"introMessage-chat-default", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_DEFAULT}, + {"introMessage-chat-basic", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_BASIC}, {"introMessage-chat-leo-expanded", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_EXPANDED}, {"introMessage-chat-claude-instant", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_CLAUDE_INSTANT}, {"modelNameSyntax", IDS_CHAT_UI_MODEL_NAME_SYNTAX}, + {"modelFreemiumLabelNonPremium", + IDS_CHAT_UI_MODEL_FREEMIUM_LABEL_NON_PREMIUM}, + {"modelFreemiumLabelPremium", IDS_CHAT_UI_MODEL_FREEMIUM_LABEL_PREMIUM}, {"modelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT}, {"menuNewChat", IDS_CHAT_UI_MENU_NEW_CHAT}, {"menuGoPremium", IDS_CHAT_UI_MENU_GO_PREMIUM}, @@ -44,8 +47,8 @@ base::span GetLocalizedStrings() { {"premiumFeature_2", IDS_CHAT_UI_PREMIUM_FEATURE_2}, {"premiumLabel", IDS_CHAT_UI_PREMIUM_LABEL}, {"premiumPricing", IDS_CHAT_UI_PREMIUM_PRICING}, - {"switchToDefaultModelButtonLabel", - IDS_CHAT_UI_SWITCH_TO_DEFAULT_MODEL_BUTTON_LABEL}, + {"switchToBasicModelButtonLabel", + IDS_CHAT_UI_SWITCH_TO_BASIC_MODEL_BUTTON_LABEL}, {"dismissButtonLabel", IDS_CHAT_UI_DISMISS_BUTTON_LABEL}, {"unlockPremiumTitle", IDS_CHAT_UI_UNLOCK_PREMIUM_TITLE}, {"premiumFeature_1_desc", IDS_CHAT_UI_PREMIUM_FEATURE_1_DESC}, @@ -74,7 +77,7 @@ base::span GetLocalizedStrings() { {"optionOther", IDS_CHAT_UI_OPTION_OTHER}, {"feedbackError", IDS_CHAT_UI_FEEDBACK_SUBMIT_ERROR}, {"ratingError", IDS_CHAT_UI_RATING_ERROR}, - {"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE}, + {"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE}, {"braveLeoModelSubtitle-chat-leo-expanded", IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE}, {"braveLeoModelSubtitle-chat-claude-instant", diff --git a/components/ai_chat/core/browser/conversation_driver.cc b/components/ai_chat/core/browser/conversation_driver.cc index bdb3fbe01d4b..e51f6793f9f0 100644 
--- a/components/ai_chat/core/browser/conversation_driver.cc +++ b/components/ai_chat/core/browser/conversation_driver.cc @@ -59,15 +59,18 @@ ConversationDriver::ConversationDriver(raw_ptr pref_service, base::BindRepeating(&ConversationDriver::OnUserOptedIn, weak_ptr_factory_.GetWeakPtr())); - // Engines and model names are selectable per conversation, not static. - // Start with default from pref value but only if user set. We can't rely on - // actual default pref value since we should vary if user is premium or not. + // Model choice is selectable per conversation, not global. + // Start with default from pref value. If user is premium and premium model + // is different from non-premium default, and user hasn't customized the model + // pref, then switch the user to the premium default. // TODO(petemill): When we have an event for premium status changed, and a // profile service for AIChat, then we can call // |pref_service_->SetDefaultPrefValue| when the user becomes premium. With // that, we'll be able to simply call GetString(prefs::kDefaultModelKey) and - // not vary on premium status. + // not have to fetch premium status.
+ if (!pref_service_->GetUserPrefValue(prefs::kDefaultModelKey) && + features::kAIModelsPremiumDefaultKey.Get() != + features::kAIModelsDefaultKey.Get()) { credential_manager_->GetPremiumStatus(base::BindOnce( [](ConversationDriver* instance, mojom::PremiumStatus status) { instance->last_premium_status_ = status; @@ -76,7 +79,7 @@ ConversationDriver::ConversationDriver(raw_ptr pref_service, return; } // Use default premium model for this instance - instance->ChangeModel(kModelsPremiumDefaultKey); + instance->ChangeModel(features::kAIModelsPremiumDefaultKey.Get()); // Make sure default model reflects premium status const auto* current_default = instance->pref_service_ @@ -84,10 +87,10 @@ ConversationDriver::ConversationDriver(raw_ptr pref_service, ->GetIfString(); if (current_default && - *current_default != kModelsPremiumDefaultKey) { + *current_default != features::kAIModelsPremiumDefaultKey.Get()) { instance->pref_service_->SetDefaultPrefValue( prefs::kDefaultModelKey, - base::Value(kModelsPremiumDefaultKey)); + base::Value(features::kAIModelsPremiumDefaultKey.Get())); } }, // Unretained is ok as credential manager is owned by this class, @@ -118,16 +121,19 @@ ConversationDriver::~ConversationDriver() = default; void ConversationDriver::ChangeModel(const std::string& model_key) { DCHECK(!model_key.empty()); // Check that the key exists - if (kAllModels.find(model_key) == kAllModels.end()) { + auto* new_model = GetModel(model_key); + if (!new_model) { NOTREACHED() << "No matching model found for key: " << model_key; return; } - model_key_ = model_key; + model_key_ = new_model->key; InitEngine(); } const mojom::Model& ConversationDriver::GetCurrentModel() { - return kAllModels.find(model_key_)->second; + auto* model = GetModel(model_key_); + DCHECK(model); + return *model; } const std::vector& ConversationDriver::GetConversationHistory() { @@ -145,33 +151,32 @@ void ConversationDriver::OnConversationActiveChanged(bool is_conversation_active void 
ConversationDriver::InitEngine() { DCHECK(!model_key_.empty()); - auto model_match = kAllModels.find(model_key_); + auto* model = GetModel(model_key_); // Make sure we get a valid model, defaulting to static default or first. - if (model_match == kAllModels.end()) { + if (!model) { NOTREACHED() << "Model was not part of static model list"; // Use default - model_match = kAllModels.find(kModelsDefaultKey); - const auto is_found = model_match != kAllModels.end(); - DCHECK(is_found); - if (!is_found) { - model_match = kAllModels.begin(); + model = GetModel(features::kAIModelsDefaultKey.Get()); + DCHECK(model); + if (!model) { + // Use first if given bad default value + model = &GetAllModels().at(0); } } - auto model = model_match->second; // Model's key might not be the same as what we asked for (e.g. if the model // no longer exists). - model_key_ = model.key; + model_key_ = model->key; // Engine enum on model to decide which one - if (model.engine_type == mojom::ModelEngineType::LLAMA_REMOTE) { + if (model->engine_type == mojom::ModelEngineType::LLAMA_REMOTE) { VLOG(1) << "Started AI engine: llama"; engine_ = std::make_unique( - model, url_loader_factory_, credential_manager_.get()); + *model, url_loader_factory_, credential_manager_.get()); } else { VLOG(1) << "Started AI engine: claude"; engine_ = std::make_unique( - model, url_loader_factory_, credential_manager_.get()); + *model, url_loader_factory_, credential_manager_.get()); } // Pending requests have been deleted along with the model engine @@ -719,7 +724,7 @@ void ConversationDriver::OnPremiumStatusReceived( if (last_premium_status_ != premium_status && premium_status == mojom::PremiumStatus::Active) { // Change model if we haven't already - ChangeModel(kModelsPremiumDefaultKey); + ChangeModel(features::kAIModelsPremiumDefaultKey.Get()); } last_premium_status_ = premium_status; std::move(parent_callback).Run(premium_status); diff --git a/components/ai_chat/core/browser/engine/engine_consumer_claude.cc 
b/components/ai_chat/core/browser/engine/engine_consumer_claude.cc index 6b92ed5f8a7c..99faf2fe5fdd 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_claude.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_claude.cc @@ -125,18 +125,11 @@ EngineConsumerClaudeRemote::EngineConsumerClaudeRemote( const mojom::Model& model, scoped_refptr url_loader_factory, AIChatCredentialManager* credential_manager) { - // Allow specific model name to be overriden by feature flag - // TODO(petemill): verify premium status, or ensure server will verify even - // when given a model name override via cli flag param. - std::string model_name = ai_chat::features::kAIModelName.Get(); - if (model_name.empty()) { - model_name = model.name; - } - DCHECK(!model_name.empty()); + DCHECK(!model.name.empty()); base::flat_set stop_sequences(kStopSequences.begin(), kStopSequences.end()); api_ = std::make_unique( - model_name, stop_sequences, url_loader_factory, credential_manager); + model.name, stop_sequences, url_loader_factory, credential_manager); max_page_content_length_ = model.max_page_content_length; } diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc index d489c801b8c4..d1805cf5ff9e 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc @@ -243,18 +243,11 @@ EngineConsumerLlamaRemote::EngineConsumerLlamaRemote( const mojom::Model& model, scoped_refptr url_loader_factory, AIChatCredentialManager* credential_manager) { - // Allow specific model name to be overriden by feature flag - // TODO(petemill): verify premium status, or ensure server will verify even - // when given a model name override via cli flag param. 
- std::string model_name = ai_chat::features::kAIModelName.Get(); - if (model_name.empty()) { - model_name = model.name; - } - DCHECK(!model_name.empty()); + DCHECK(!model.name.empty()); base::flat_set stop_sequences(kStopSequences.begin(), kStopSequences.end()); api_ = std::make_unique( - model_name, stop_sequences, url_loader_factory, credential_manager); + model.name, stop_sequences, url_loader_factory, credential_manager); max_page_content_length_ = model.max_page_content_length; } diff --git a/components/ai_chat/core/browser/models.cc b/components/ai_chat/core/browser/models.cc index 0de03d5460fd..fa84fb272c14 100644 --- a/components/ai_chat/core/browser/models.cc +++ b/components/ai_chat/core/browser/models.cc @@ -7,6 +7,8 @@ #include "brave/components/ai_chat/core/browser/models.h" +#include "base/no_destructor.h" +#include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" namespace ai_chat { @@ -14,7 +16,8 @@ namespace ai_chat { // When adding new models, especially for display, make sure to add the UI // strings to ai_chat_ui_strings.grdp and ai_chat/core/constants.cc. // This also applies for modifying keys, since some of the strings are based -// on the model key. +// on the model key. Also be sure to migrate prefs if changing or removing +// keys. 
// Llama2 Token Allocation: // - Llama2 has a context limit: tokens + max_new_tokens <= 4096 @@ -32,25 +35,33 @@ namespace ai_chat { // - Reserverd for page content: 100k / 2 = 50k tokens // - Long conversation warning threshold: 100k * 0.80 = 80k tokens -const base::flat_map kAllModels = { - {"chat-default", - {"chat-default", "llama-2-13b-chat", "llama2 13b", "Meta", - mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, false, - 9000, 9700}}, - {"chat-leo-expanded", - {"chat-leo-expanded", "llama-2-70b-chat", "llama2 70b", "Meta", - mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, true, - 9000, 9700}}, - {"chat-claude-instant", - {"chat-claude-instant", "claude-instant-v1", "Claude Instant", "Anthropic", - mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT, true, - 200000, 320000}}, -}; - -const std::vector kAllModelKeysDisplayOrder = { - "chat-default", - "chat-leo-expanded", - "chat-claude-instant", -}; +const std::vector& GetAllModels() { + static const auto kFreemiumAccess = + features::kFreemiumAvailable.Get() ? 
mojom::ModelAccess::BASIC_AND_PREMIUM + : mojom::ModelAccess::PREMIUM; + static const base::NoDestructor> kModels({ + {"chat-leo-expanded", "mixtral-8x7b-instruct", "Mixtral", "Mistral AI", + mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, + kFreemiumAccess, 9000, 9700}, + {"chat-claude-instant", "claude-instant-v1", "Claude Instant", + "Anthropic", mojom::ModelEngineType::CLAUDE_REMOTE, + mojom::ModelCategory::CHAT, kFreemiumAccess, 180000, 320000}, + {"chat-basic", "llama-2-13b-chat", "llama2 13b", "Meta", + mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, + mojom::ModelAccess::BASIC, 9000, 9700}, + }); + return *kModels; +} + +const ai_chat::mojom::Model* GetModel(std::string_view key) { + auto& models = GetAllModels(); + auto match = std::find_if( + models.cbegin(), models.cend(), + [&key](const mojom::Model& item) { return item.key == key; }); + if (match != models.cend()) { + return &*match; + } + return nullptr; +} } // namespace ai_chat diff --git a/components/ai_chat/core/browser/models.h b/components/ai_chat/core/browser/models.h index 9492e6dfb6d7..5fd32e10704c 100644 --- a/components/ai_chat/core/browser/models.h +++ b/components/ai_chat/core/browser/models.h @@ -6,20 +6,22 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MODELS_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MODELS_H_ +#include #include #include "base/containers/flat_map.h" +#include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" namespace ai_chat { -inline constexpr char kModelsDefaultKey[] = "chat-default"; -inline constexpr char kModelsPremiumDefaultKey[] = "chat-claude-instant"; +// All models that the user can choose for chat conversations, in UI display +// order. +extern const std::vector& GetAllModels(); -// All models that the user can choose for chat conversations. 
-extern const base::flat_map kAllModels; -// UI display order for models -extern const std::vector kAllModelKeysDisplayOrder; +// Get model by key. If there is no matching model for the key, NULL is +// returned. +extern const ai_chat::mojom::Model* GetModel(std::string_view key); } // namespace ai_chat diff --git a/components/ai_chat/core/common/BUILD.gn b/components/ai_chat/core/common/BUILD.gn index e1080ea7705a..1a3d36a5cb29 100644 --- a/components/ai_chat/core/common/BUILD.gn +++ b/components/ai_chat/core/common/BUILD.gn @@ -21,3 +21,18 @@ static_library("common") { "//components/prefs", ] } + +if (!is_ios) { + source_set("unit_tests") { + testonly = true + sources = [ "pref_names_unittest.cc" ] + + deps = [ + "//base/test:test_support", + "//brave/components/ai_chat/core/common", + "//components/prefs:test_support", + "//content/test:test_support", + "//testing/gtest:gtest", + ] + } +} diff --git a/components/ai_chat/core/common/features.cc b/components/ai_chat/core/common/features.cc index 251762a41635..4ef9a4f8a275 100644 --- a/components/ai_chat/core/common/features.cc +++ b/components/ai_chat/core/common/features.cc @@ -21,8 +21,13 @@ BASE_FEATURE(kAIChat, base::FEATURE_DISABLED_BY_DEFAULT #endif ); -const base::FeatureParam kAIModelName{&kAIChat, "ai_model_name", - ""}; +const base::FeatureParam kAIModelsDefaultKey{ + &kAIChat, "default_model", "chat-leo-expanded"}; +const base::FeatureParam kAIModelsPremiumDefaultKey{ + &kAIChat, "default_premium_model", "chat-leo-expanded"}; +const base::FeatureParam kFreemiumAvailable(&kAIChat, + "is_freemium_available", + true); const base::FeatureParam kAIChatSSE{&kAIChat, "ai_chat_sse", true}; const base::FeatureParam kAITemperature{&kAIChat, "temperature", 0.2}; diff --git a/components/ai_chat/core/common/features.h b/components/ai_chat/core/common/features.h index 29bf90b9b8d3..3a03dc3e725c 100644 --- a/components/ai_chat/core/common/features.h +++ b/components/ai_chat/core/common/features.h @@ -14,7 +14,13 @@ 
namespace ai_chat::features { BASE_DECLARE_FEATURE(kAIChat); -extern const base::FeatureParam kAIModelName; +extern const base::FeatureParam kAIModelsDefaultKey; +extern const base::FeatureParam kAIModelsPremiumDefaultKey; + +// If true, certain freemium models are available to non-premium users. If +// false, those models are premium-only. +extern const base::FeatureParam kFreemiumAvailable; + extern const base::FeatureParam kAIChatSSE; extern const base::FeatureParam kAITemperature; diff --git a/components/ai_chat/core/common/mojom/ai_chat.mojom b/components/ai_chat/core/common/mojom/ai_chat.mojom index cbeb9a96e651..52eb5147e22e 100644 --- a/components/ai_chat/core/common/mojom/ai_chat.mojom +++ b/components/ai_chat/core/common/mojom/ai_chat.mojom @@ -35,6 +35,15 @@ enum ModelCategory { CHAT, }; +enum ModelAccess { + // The model only has a single basic tier, accessible by any level + BASIC, + // The model has a basic tier and a more capable premium tier (a.k.a freemium) + BASIC_AND_PREMIUM, + // The model only has a premium tier + PREMIUM, +}; + enum PremiumStatus { Inactive, Active, @@ -86,7 +95,8 @@ struct Model { ModelEngineType engine_type; // user-facing category ModelCategory category; - bool is_premium; + // Which access level grants permission to use the model + ModelAccess access; // max limit to truncate page contents (measured in chars, not tokens) uint32 max_page_content_length; // max limit for overall conversation (measured in chars, not tokens) diff --git a/components/ai_chat/core/common/pref_names.cc b/components/ai_chat/core/common/pref_names.cc index e67b052dc98c..847ed825881a 100644 --- a/components/ai_chat/core/common/pref_names.cc +++ b/components/ai_chat/core/common/pref_names.cc @@ -5,6 +5,10 @@ #include "brave/components/ai_chat/core/common/pref_names.h" +#include + +#include "base/strings/string_util.h" +#include "brave/components/ai_chat/core/common/features.h" #include "components/prefs/pref_registry_simple.h" #include 
"components/prefs/pref_service.h" @@ -21,7 +25,8 @@ void RegisterProfilePrefs(PrefRegistrySimple* registry) { registry->RegisterTimePref(kLastAcceptedDisclaimer, {}); registry->RegisterBooleanPref(kBraveChatAutocompleteProviderEnabled, true); registry->RegisterBooleanPref(kUserDismissedPremiumPrompt, false); - registry->RegisterStringPref(kDefaultModelKey, "chat-default"); + registry->RegisterStringPref(kDefaultModelKey, + features::kAIModelsDefaultKey.Get()); } void RegisterProfilePrefsForMigration(PrefRegistrySimple* registry) { @@ -30,6 +35,16 @@ void RegisterProfilePrefsForMigration(PrefRegistrySimple* registry) { void MigrateProfilePrefs(PrefService* profile_prefs) { profile_prefs->ClearPref(kObseleteBraveChatAutoGenerateQuestions); + // migrate model key from "chat-default" to "chat-basic" + static const std::string kDefaultModelBasicFrom = "chat-default"; + static const std::string kDefaultModelBasicTo = "chat-basic"; + if (auto* default_model_value = + profile_prefs->GetUserPrefValue(kDefaultModelKey)) { + if (base::EqualsCaseInsensitiveASCII(default_model_value->GetString(), + kDefaultModelBasicFrom)) { + profile_prefs->SetString(kDefaultModelKey, kDefaultModelBasicTo); + } + } } void RegisterLocalStatePrefs(PrefRegistrySimple* registry) { diff --git a/components/ai_chat/core/common/pref_names_unittest.cc b/components/ai_chat/core/common/pref_names_unittest.cc new file mode 100644 index 000000000000..bcdcc236210e --- /dev/null +++ b/components/ai_chat/core/common/pref_names_unittest.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. 
+ +#include "brave/components/ai_chat/core/common/pref_names.h" + +#include + +#include "components/prefs/testing_pref_service.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace ai_chat::prefs { + +class AIChatPrefMigrationTest : public ::testing::Test { + public: + void SetUp() override { + RegisterProfilePrefs(pref_service_.registry()); + RegisterProfilePrefsForMigration(pref_service_.registry()); + } + + TestingPrefServiceSimple pref_service_; +}; + +TEST_F(AIChatPrefMigrationTest, ChangeOldDefaultKey) { + pref_service_.SetString(kDefaultModelKey, "chat-default"); + MigrateProfilePrefs(&pref_service_); + + EXPECT_EQ(pref_service_.GetUserPrefValue(kDefaultModelKey)->GetString(), + "chat-basic"); +} + +} // namespace ai_chat::prefs diff --git a/components/ai_chat/resources/page/components/alerts/error_rate_limit.tsx b/components/ai_chat/resources/page/components/alerts/error_rate_limit.tsx index af2b83631d23..0071a25bd5c3 100644 --- a/components/ai_chat/resources/page/components/alerts/error_rate_limit.tsx +++ b/components/ai_chat/resources/page/components/alerts/error_rate_limit.tsx @@ -7,24 +7,37 @@ import * as React from 'react' import { getLocale } from '$web-common/locale' import Alert from '@brave/leo/react/alert' import Button from '@brave/leo/react/button' +import getPageHandlerInstance, * as mojom from '../../api/page_handler' import DataContext from '../../state/context' import PremiumSuggestion from '../premium_suggestion' import styles from './alerts.module.scss' -interface ErrorRateLimit { - onRetry?: () => void -} -function ErrorRateLimit(props: ErrorRateLimit) { - const { isPremiumUser } = React.useContext(DataContext) +function ErrorRateLimit() { + const context = React.useContext(DataContext) - if (!isPremiumUser) { + if (!context.isPremiumUser) { + // Freemium model with non-premium user has stricter rate limits. Secondary + // action is to switch to completely free model. 
+ if (context.currentModel?.access === mojom.ModelAccess.BASIC_AND_PREMIUM) { + return ( + + {getLocale('switchToBasicModelButtonLabel')} + + } + /> + ) + } return ( + } @@ -42,7 +55,7 @@ function ErrorRateLimit(props: ErrorRateLimit) { diff --git a/components/ai_chat/resources/page/components/feature_button_menu/index.tsx b/components/ai_chat/resources/page/components/feature_button_menu/index.tsx index 71adf4c8c649..71ee22c688cc 100644 --- a/components/ai_chat/resources/page/components/feature_button_menu/index.tsx +++ b/components/ai_chat/resources/page/components/feature_button_menu/index.tsx @@ -7,8 +7,9 @@ import * as React from 'react' import ButtonMenu from '@brave/leo/react/buttonMenu' import Button from '@brave/leo/react/button' import Icon from '@brave/leo/react/icon' +import Label from '@brave/leo/react/label' import { getLocale } from '$web-common/locale' -import getPageHandlerInstance from '../../api/page_handler' +import getPageHandlerInstance, * as mojom from '../../api/page_handler' import DataContext from '../../state/context' import styles from './style.module.scss' import classnames from '$web-common/classnames' @@ -50,12 +51,12 @@ export default function FeatureMenu() { >
-
{model.name}
+
{model.displayName}

{getLocale(`braveLeoModelSubtitle-${model.key}`)}

- {model.isPremium && ( + {model.access === mojom.ModelAccess.PREMIUM && ( )} + {model.access === mojom.ModelAccess.BASIC_AND_PREMIUM && ( + + )}
))} diff --git a/components/ai_chat/resources/page/components/feature_button_menu/style.module.scss b/components/ai_chat/resources/page/components/feature_button_menu/style.module.scss index a195e28ca62e..fce8ca0c73b0 100644 --- a/components/ai_chat/resources/page/components/feature_button_menu/style.module.scss +++ b/components/ai_chat/resources/page/components/feature_button_menu/style.module.scss @@ -58,6 +58,10 @@ margin: 0; } +.modelFreemiumLabel { + align-self: flex-start; +} + .lockOpen { --leo-icon-color: var(--leo-color-systemfeedback-success-icon); } diff --git a/components/ai_chat/resources/page/components/main/index.tsx b/components/ai_chat/resources/page/components/main/index.tsx index 1fb60230ebff..53bc1edc642e 100644 --- a/components/ai_chat/resources/page/components/main/index.tsx +++ b/components/ai_chat/resources/page/components/main/index.tsx @@ -45,7 +45,7 @@ function Main() { hasAcceptedAgreement && !context.isPremiumStatusFetching && // Avoid flash of content !context.isPremiumUser && - context.currentModel?.isPremium + context.currentModel?.access === mojom.ModelAccess.PREMIUM const shouldShowPremiumSuggestionStandalone = hasAcceptedAgreement && @@ -72,9 +72,7 @@ function Main() { if (apiHasError && currentError === mojom.APIError.RateLimitReached) { currentErrorElement = ( - getPageHandlerInstance().pageHandler.retryAPIRequest()} - /> + ) } @@ -133,9 +131,9 @@ function Main() { secondaryActionButton={ } /> diff --git a/components/ai_chat/resources/page/state/context.ts b/components/ai_chat/resources/page/state/context.ts index 44652824cdb5..bb2b95bc2987 100644 --- a/components/ai_chat/resources/page/state/context.ts +++ b/components/ai_chat/resources/page/state/context.ts @@ -28,7 +28,7 @@ export interface AIChatContext { showAgreementModal: boolean shouldSendPageContents: boolean setCurrentModel: (model: mojom.Model) => void, - switchToDefaultModel: () => void, + switchToBasicModel: () => void, generateSuggestedQuestions: () => void 
goPremium: () => void managePremium: () => void @@ -62,7 +62,7 @@ export const defaultContext: AIChatContext = { showAgreementModal: false, shouldSendPageContents: true, setCurrentModel: () => {}, - switchToDefaultModel: () => {}, + switchToBasicModel: () => {}, generateSuggestedQuestions: () => {}, goPremium: () => {}, managePremium: () => {}, diff --git a/components/ai_chat/resources/page/state/data-context-provider.tsx b/components/ai_chat/resources/page/state/data-context-provider.tsx index ab538f789371..a81d63e8baf3 100644 --- a/components/ai_chat/resources/page/state/data-context-provider.tsx +++ b/components/ai_chat/resources/page/state/data-context-provider.tsx @@ -70,7 +70,7 @@ function DataContextProvider (props: DataContextProviderProps) { const isPremiumUser = premiumStatus !== undefined && premiumStatus !== mojom.PremiumStatus.Inactive const apiHasError = (currentError !== mojom.APIError.None) - const shouldDisableUserInput = !!(apiHasError || isGenerating || (!isPremiumUser && currentModel?.isPremium)) + const shouldDisableUserInput = !!(apiHasError || isGenerating || (!isPremiumUser && currentModel?.access === mojom.ModelAccess.PREMIUM)) const getConversationHistory = () => { getPageHandlerInstance() @@ -128,9 +128,9 @@ function DataContextProvider (props: DataContextProviderProps) { setCanShowPremiumPrompt(false) } - const switchToDefaultModel = () => { + const switchToBasicModel = () => { // Select the first non-premium model - const nonPremium = allModels.find(m => !m.isPremium) + const nonPremium = allModels.find(m => [mojom.ModelAccess.BASIC, mojom.ModelAccess.BASIC_AND_PREMIUM].includes(m.access)) if (!nonPremium) { console.error('Could not find a non-premium model!') return @@ -297,7 +297,7 @@ function DataContextProvider (props: DataContextProviderProps) { showAgreementModal, shouldSendPageContents: shouldSendPageContents && siteInfo?.isContentAssociationPossible, setCurrentModel, - switchToDefaultModel, + switchToBasicModel, goPremium, 
managePremium, generateSuggestedQuestions, diff --git a/components/ai_chat/resources/page/stories/components_panel.tsx b/components/ai_chat/resources/page/stories/components_panel.tsx index ab86ec144452..b44695eed853 100644 --- a/components/ai_chat/resources/page/stories/components_panel.tsx +++ b/components/ai_chat/resources/page/stories/components_panel.tsx @@ -49,7 +49,7 @@ const MODELS: mojom.Model[] = [ displayMaker: 'Company', engineType: mojom.ModelEngineType.LLAMA_REMOTE, category: mojom.ModelCategory.CHAT, - isPremium: false, + access: mojom.ModelAccess.BASIC, maxPageContentLength: 10000, longConversationWarningCharacterLimit: 9700 }, @@ -60,7 +60,18 @@ const MODELS: mojom.Model[] = [ displayMaker: 'Company', engineType: mojom.ModelEngineType.LLAMA_REMOTE, category: mojom.ModelCategory.CHAT, - isPremium: true, + access: mojom.ModelAccess.PREMIUM, + maxPageContentLength: 10000, + longConversationWarningCharacterLimit: 9700 + }, + { + key: '3', + name: 'model-three-freemium', + displayName: 'Model Three', + displayMaker: 'Company', + engineType: mojom.ModelEngineType.LLAMA_REMOTE, + category: mojom.ModelCategory.CHAT, + access: mojom.ModelAccess.BASIC_AND_PREMIUM, maxPageContentLength: 10000, longConversationWarningCharacterLimit: 9700 } diff --git a/components/ai_chat/resources/page/stories/locale.ts b/components/ai_chat/resources/page/stories/locale.ts index 22ee3d0b0ea4..9f34969dfdf0 100644 --- a/components/ai_chat/resources/page/stories/locale.ts +++ b/components/ai_chat/resources/page/stories/locale.ts @@ -21,6 +21,8 @@ provideStrings({ 'introMessage-0': `I'm here to help. 
What can I assist you with today?`, 'introMessage-1': 'I have a vast base of knowledge and a large memory able to help with more complex challenges.', modelNameSyntax: '$1 by $2', + modelFreemiumLabelNonPremium: 'Limited', + modelFreemiumLabelPremium: 'Unlimited', 'modelCategory-chat': 'Chat', menuNewChat: 'New chat', menuSettings: 'Settings', @@ -59,6 +61,7 @@ provideStrings({ feedbackError: 'Your feedback could not be submitted, please check network connection and try again', premiumRefreshWarningDescription: 'Your Brave account session has expired. Please visit your account page to refresh, then come back to use premium features.', premiumRefreshWarningAction: 'Refresh', + switchToBasicModelButtonLabel: 'Switch to the free model', clearChatButtonLabel: 'Clear chat', errorContextLimitReaching: 'This conversation is long and Leo may start forgetting things soon. You can get longer conversations by switching to a premium model, or you can clear the chat to reset it', gotItButtonLabel: 'Got it', diff --git a/components/resources/ai_chat_ui_strings.grdp b/components/resources/ai_chat_ui_strings.grdp index 37f17ed391bd..d62ba7b7b904 100644 --- a/components/resources/ai_chat_ui_strings.grdp +++ b/components/resources/ai_chat_ui_strings.grdp @@ -51,11 +51,11 @@ Retry - + Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Llama 13B, a model created by Meta to be performant and applicable to many use cases. - Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Llama 70B, a model created by Meta to handle more advanced tasks than 13B. + Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Mixtral 8x7B, a model created by Mistral AI to handle advanced tasks. Hi, I'm Leo. I'm proxied by Brave and powered by Claude Instant, a model created by Anthropic to power conversational and text processing tasks.
@@ -63,6 +63,12 @@ $1llama2-13b by $2Meta + + Limited + + + Unlimited + Chat @@ -106,8 +112,8 @@ Dismiss - - Switch to the default model + + Switch to the free model Unlock the full potential of Leo @@ -202,7 +208,7 @@ This conversation is too long and cannot continue. There may be other models available with which Leo is capable of maintaining accuracy for longer conversations. - + General purpose chat diff --git a/test/BUILD.gn b/test/BUILD.gn index 4f1df43ac7e2..9a3cc1a6b19b 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -360,6 +360,7 @@ test("brave_unit_tests") { deps += [ "//brave/browser/ai_chat:unit_tests", "//brave/components/ai_chat/core/browser:unit_tests", + "//brave/components/ai_chat/core/common:unit_tests", ] }