diff --git a/CHANGELOG.md b/CHANGELOG.md index 696573c41..6e172b685 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ - Webhooks always act like static insts. - This means that any changes made to bots in the webhook are erased after the webhook finishes. - Added the ability to request consent again so that a parent can adjust the privacy features for their child. +- Added `ai.chat.allowedModels` feature to enforce model usage limits, restricting access to specific models based on configuration. ### :bug: Bug Fixes diff --git a/src/aux-records/AIController.spec.ts b/src/aux-records/AIController.spec.ts index 67aa8471e..bfdedf7b9 100644 --- a/src/aux-records/AIController.spec.ts +++ b/src/aux-records/AIController.spec.ts @@ -369,6 +369,61 @@ describe('AIController', () => { expect(chatInterface.chat).not.toBeCalled(); }); + it('should return not_authorized when allowedModels does not include the model', async () => { + controller = new AIController({ + chat: { + interfaces: { + provider1: chatInterface, + }, + options: { + defaultModel: 'default-model', + defaultModelProvider: 'provider1', + allowedChatModels: [ + { + provider: 'provider1', + model: 'modelA', + }, + { + provider: 'provider1', + model: 'modelB', + }, + ], + allowedChatSubscriptionTiers: ['test-tier'], + }, + }, + generateSkybox: null, + images: null, + metrics: store, + config: store, + hume: null, + sloyd: null, + policies: null, + policyController: policies, + records: store, + }); + + const result = await controller.chat({ + model: 'modelC', + messages: [ + { + role: 'user', + content: 'test', + }, + ], + temperature: 0.5, + userId, + userSubscriptionTier, + }); + + expect(result).toEqual({ + success: false, + errorCode: 'not_authorized', + errorMessage: + 'The subscription does not permit the given model for AI Chat features.', + }); + expect(chatInterface.chat).not.toBeCalled(); + }); + it('should return an not_logged_in result if the given a null userId', async () => { const result = await controller.chat({ model: 'test-model1', @@ -649,6 +704,56 @@ describe('AIController', () => { }); }); + it('should return success when allowedModels includes the model', async () => { + chatInterface.chat.mockReturnValueOnce( + Promise.resolve({ + choices: [ + { + role: 'user', + content: 'test', + finishReason: 'stop', + }, + ], + totalTokens: 1, + }) + ); + + const result = await controller.chat({ + model: 'test-model1', + messages: [ + { + role: 'user', + content: 'test', + }, + ], + temperature: 0.5, + userId, + userSubscriptionTier, + }); + + expect(result).toEqual({ + success: true, + choices: [ + { + role: 'user', + content: 'test', + finishReason: 'stop', + }, + ], + }); + expect(chatInterface.chat).toBeCalledWith({ + model: 'test-model1', + messages: [ + { + role: 'user', + content: 'test', + }, + ], + temperature: 0.5, + userId: 'test-user', + }); + }); + it('should specify the maximum number of tokens allowed based on how many tokens the subscription has left in the period', async () => { chatInterface.chat.mockReturnValueOnce( Promise.resolve({ @@ -1502,6 +1607,61 @@ describe('AIController', () => { }); }); + it('should return not_authorized error when allowedModels does not include the model', async () => { + controller = new AIController({ + chat: { + interfaces: { + provider1: chatInterface, + }, + options: { + defaultModel: 'default-model', + defaultModelProvider: 'provider1', + allowedChatModels: [ + { + provider: 'provider1', + model: 'modelA', + }, + { + provider: 'provider1', + model: 'modelB', + }, + ], + allowedChatSubscriptionTiers: ['test-tier'], + }, + }, + generateSkybox: null, + images: null, + metrics: store, + config: store, + hume: null, + sloyd: null, + policies: null, + policyController: policies, + records: store, + }); + + const result = await controller.chat({ + model: 'modelC', + messages: [ + { + role: 'user', + content: 'test', + }, + ], + temperature: 0.5, + userId, + userSubscriptionTier, + }); + + expect(result).toEqual({ + success: false, + errorCode: 'not_authorized', + errorMessage: + 'The subscription does not permit the given model for AI Chat features.', + }); + expect(chatInterface.chat).not.toBeCalled(); + }); + describe('subscriptions', () => { beforeEach(async () => { store.subscriptionConfiguration = buildSubscriptionConfig( @@ -1530,6 +1690,56 @@ describe('AIController', () => { }); }); + it('should return success when allowedModels includes the model', async () => { + chatInterface.chat.mockReturnValueOnce( + Promise.resolve({ + choices: [ + { + role: 'user', + content: 'test', + finishReason: 'stop', + }, + ], + totalTokens: 1, + }) + ); + + const result = await controller.chat({ + model: 'test-model1', + messages: [ + { + role: 'user', + content: 'test', + }, + ], + temperature: 0.5, + userId, + userSubscriptionTier, + }); + + expect(result).toEqual({ + success: true, + choices: [ + { + role: 'user', + content: 'test', + finishReason: 'stop', + }, + ], + }); + expect(chatInterface.chat).toBeCalledWith({ + model: 'test-model1', + messages: [ + { + role: 'user', + content: 'test', + }, + ], + temperature: 0.5, + userId: 'test-user', + }); + }); + it('should specify the maximum number of tokens allowed based on how many tokens the subscription has left in the period', async () => { chatInterface.chatStream.mockReturnValueOnce( asyncIterable([ diff --git a/src/aux-records/AIController.ts b/src/aux-records/AIController.ts index b64dd906e..36ca4329f 100644 --- a/src/aux-records/AIController.ts +++ b/src/aux-records/AIController.ts @@ -450,6 +450,26 @@ export class AIController { maxTokens, }); + if (allowedFeatures.ai.chat.allowedModels) { + const allowedModels = allowedFeatures.ai.chat.allowedModels; + if ( + !allowedModels || + allowedModels.length === 0 || + allowedModels.includes(model) + ) { + return { + success: true, + choices: result.choices, + }; + } + return { + success: false, + errorCode: 'not_authorized', + errorMessage: + 'The subscription does not permit the given model for AI Chat features.', + }; + } + if (result.totalTokens > 0) { await this._metrics.recordChatMetrics({ userId: request.userId, @@ -598,6 +618,25 @@ export class AIController { }; } + if (allowedFeatures.ai.chat.allowedModels) { + const allowedModels = allowedFeatures.ai.chat.allowedModels; + if ( + !allowedModels || + allowedModels.length === 0 || + allowedModels.includes(model) + ) { + return { + success: true, + }; + } + return { + success: false, + errorCode: 'not_authorized', + errorMessage: + 'The subscription does not permit the given model for AI Chat features.', + }; + } + let maxTokens: number = undefined; if (allowedFeatures.ai.chat.maxTokensPerPeriod) { maxTokens = diff --git a/src/aux-records/SubscriptionConfiguration.ts b/src/aux-records/SubscriptionConfiguration.ts index 36507956a..ec3c15982 100644 --- a/src/aux-records/SubscriptionConfiguration.ts +++ b/src/aux-records/SubscriptionConfiguration.ts @@ -141,6 +141,12 @@ export const subscriptionFeaturesSchema = z.object({ .int() .positive() .optional(), + allowedModels: z + .array(z.string()) + .describe( + 'The list of model IDs that are allowed for the subscription. If omitted, then all models are allowed.' + ) + .optional(), }), images: z.object({ allowed: z @@ -938,6 +944,12 @@ export interface AIChatFeaturesConfiguration { * If not specified, then there is no limit. */ maxTokensPerPeriod?: number; + + /** + * The list of model IDs that are allowed for the subscription. + * If omitted, then all models are allowed. + */ + allowedModels?: string[]; } export interface AIImageFeaturesConfiguration {