Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/add ai.chat.allowed models feature limit #552

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
- Webhooks always act like static insts.
- This means that any changes made to bots in the webhook are erased after the webhook finishes.
- Added the ability to request consent again so that a parent can adjust the privacy features for their child.
- Added `ai.chat.allowedModels` feature to enforce model usage limits, restricting access to specific models based on configuration.

### :bug: Bug Fixes

Expand Down
210 changes: 210 additions & 0 deletions src/aux-records/AIController.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,111 @@ describe('AIController', () => {
expect(chatInterface.chat).not.toBeCalled();
});

it('should return success when allowedModels includes the model', async () => {
chatInterface.chat.mockReturnValueOnce(
Promise.resolve({
choices: [
{
role: 'user',
content: 'test',
finishReason: 'stop',
},
],
totalTokens: 1,
})
);

const result = await controller.chat({
model: 'test-model1',
messages: [
{
role: 'user',
content: 'test',
},
],
temperature: 0.5,
userId,
userSubscriptionTier,
});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TroyceGowdy This test should be inside the "subscriptions" group and it should setup a subscription configuration that specifies the allowed models for the subscription tier that the user has. See lines 626-649 for an example of setting up the subscription configuration.

The way this test is currently written, it doesn't actually test the new subscription feature limit and only tests the server config.


expect(result).toEqual({
success: true,
choices: [
{
role: 'user',
content: 'test',
finishReason: 'stop',
},
],
});
expect(chatInterface.chat).toBeCalledWith({
model: 'test-model1',
messages: [
{
role: 'user',
content: 'test',
},
],
temperature: 0.5,
userId: 'test-user',
});
});

it('should return not_authorized when allowedModels does not include the model', async () => {
controller = new AIController({
chat: {
interfaces: {
provider1: chatInterface,
},
options: {
defaultModel: 'default-model',
defaultModelProvider: 'provider1',
allowedChatModels: [
{
provider: 'provider1',
model: 'modelA',
},
{
provider: 'provider1',
model: 'modelB',
},
],
allowedChatSubscriptionTiers: ['test-tier'],
},
},
generateSkybox: null,
images: null,
metrics: store,
config: store,
hume: null,
sloyd: null,
policies: null,
policyController: policies,
records: store,
});
Comment on lines +372 to +403
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TroyceGowdy Same with this test. You shouldn't have to create a new AIController, instead you need to setup a subscription config.


const result = await controller.chat({
model: 'modelC',
messages: [
{
role: 'user',
content: 'test',
},
],
temperature: 0.5,
userId,
userSubscriptionTier,
});

expect(result).toEqual({
success: false,
errorCode: 'not_authorized',
errorMessage:
'The subscription does not permit the given model for AI Chat features.',
});
expect(chatInterface.chat).not.toBeCalled();
});

it('should return an not_logged_in result if the given a null userId', async () => {
const result = await controller.chat({
model: 'test-model1',
Expand Down Expand Up @@ -1502,6 +1607,111 @@ describe('AIController', () => {
});
});

it('should return success when allowedModels includes the model', async () => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TroyceGowdy Same here.

chatInterface.chat.mockReturnValueOnce(
Promise.resolve({
choices: [
{
role: 'user',
content: 'test',
finishReason: 'stop',
},
],
totalTokens: 1,
})
);

const result = await controller.chat({
model: 'test-model1',
messages: [
{
role: 'user',
content: 'test',
},
],
temperature: 0.5,
userId,
userSubscriptionTier,
});

expect(result).toEqual({
success: true,
choices: [
{
role: 'user',
content: 'test',
finishReason: 'stop',
},
],
});
expect(chatInterface.chat).toBeCalledWith({
model: 'test-model1',
messages: [
{
role: 'user',
content: 'test',
},
],
temperature: 0.5,
userId: 'test-user',
});
});

it('should return not_authorized error when allowedModels does not include the model', async () => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TroyceGowdy And here.

controller = new AIController({
chat: {
interfaces: {
provider1: chatInterface,
},
options: {
defaultModel: 'default-model',
defaultModelProvider: 'provider1',
allowedChatModels: [
{
provider: 'provider1',
model: 'modelA',
},
{
provider: 'provider1',
model: 'modelB',
},
],
allowedChatSubscriptionTiers: ['test-tier'],
},
},
generateSkybox: null,
images: null,
metrics: store,
config: store,
hume: null,
sloyd: null,
policies: null,
policyController: policies,
records: store,
});

const result = await controller.chat({
model: 'modelC',
messages: [
{
role: 'user',
content: 'test',
},
],
temperature: 0.5,
userId,
userSubscriptionTier,
});

expect(result).toEqual({
success: false,
errorCode: 'not_authorized',
errorMessage:
'The subscription does not permit the given model for AI Chat features.',
});
expect(chatInterface.chat).not.toBeCalled();
});

describe('subscriptions', () => {
beforeEach(async () => {
store.subscriptionConfiguration = buildSubscriptionConfig(
Expand Down
39 changes: 39 additions & 0 deletions src/aux-records/AIController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,26 @@ export class AIController {
maxTokens,
});

if (allowedFeatures.ai.chat.allowedModels) {
const allowedModels = allowedFeatures.ai.chat.allowedModels;
if (
!allowedModels ||
allowedModels.length === 0 ||
allowedModels.includes(model)
) {
return {
success: true,
choices: result.choices,
};
}
return {
success: false,
errorCode: 'not_authorized',
errorMessage:
'The subscription does not permit the given model for AI Chat features.',
};
}

if (result.totalTokens > 0) {
await this._metrics.recordChatMetrics({
userId: request.userId,
Expand Down Expand Up @@ -598,6 +618,25 @@ export class AIController {
};
}

if (allowedFeatures.ai.chat.allowedModels) {
const allowedModels = allowedFeatures.ai.chat.allowedModels;
if (
!allowedModels ||
allowedModels.length === 0 ||
allowedModels.includes(model)
) {
return {
success: true,
};
}
return {
success: false,
errorCode: 'not_authorized',
errorMessage:
'The subscription does not permit the given model for AI Chat features.',
};
}

let maxTokens: number = undefined;
if (allowedFeatures.ai.chat.maxTokensPerPeriod) {
maxTokens =
Expand Down
12 changes: 12 additions & 0 deletions src/aux-records/SubscriptionConfiguration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ export const subscriptionFeaturesSchema = z.object({
.int()
.positive()
.optional(),
allowedModels: z
.array(z.string())
.describe(
'The list of model IDs that are allowed for the subscription. If omitted, then all models are allowed.'
)
.optional(),
}),
images: z.object({
allowed: z
Expand Down Expand Up @@ -938,6 +944,12 @@ export interface AIChatFeaturesConfiguration {
* If not specified, then there is no limit.
*/
maxTokensPerPeriod?: number;

/**
* The list of model IDs that are allowed for the subscription.
* If omitted, then all models are allowed.
*/
allowedModels?: string[];
}

export interface AIImageFeaturesConfiguration {
Expand Down
Loading