Skip to content

Commit

Permalink
[Epic] AI Insights + Assistant - Add "Other" option to the existing O…
Browse files Browse the repository at this point in the history
…penAI Connector dropdown list (elastic#8936)
  • Loading branch information
e40pud committed Oct 3, 2024
1 parent 6827ba4 commit 0dcf7e5
Show file tree
Hide file tree
Showing 26 changed files with 826 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export const Reader = z.object({}).catchall(z.unknown());
* Provider
*/
export type Provider = z.infer<typeof Provider>;
export const Provider = z.enum(['OpenAI', 'Azure OpenAI']);
export const Provider = z.enum(['OpenAI', 'Azure OpenAI', 'Other']);
export type ProviderEnum = typeof Provider.enum;
export const ProviderEnum = Provider.enum;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ components:
enum:
- OpenAI
- Azure OpenAI
- Other

MessageRole:
type: string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { PRECONFIGURED_CONNECTOR } from './translations';
enum OpenAiProviderType {
OpenAi = 'OpenAI',
AzureAi = 'Azure OpenAI',
Other = 'Other',
}

interface GenAiConfig {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1025,15 +1025,17 @@ describe('actions telemetry', () => {
'.d3security': 2,
'.gen-ai__Azure OpenAI': 3,
'.gen-ai__OpenAI': 1,
'.gen-ai__Other': 1,
};
const { countByType, countGenAiProviderTypes } = getCounts(aggs);
expect(countByType).toEqual({
__d3security: 2,
'__gen-ai': 4,
'__gen-ai': 5,
});
expect(countGenAiProviderTypes).toEqual({
'Azure OpenAI': 3,
OpenAI: 1,
Other: 1,
});
});
});
1 change: 1 addition & 0 deletions x-pack/plugins/actions/server/usage/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export const byGenAiProviderTypeSchema: MakeSchemaFrom<ActionsUsage>['count_by_t
// Known providers:
['Azure OpenAI']: { type: 'long' },
['OpenAI']: { type: 'long' },
['Other']: { type: 'long' },
};

export const byServiceProviderTypeSchema: MakeSchemaFrom<ActionsUsage>['count_active_email_connectors_by_service_type'] =
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/search_playground/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export enum APIRoutes {
export enum LLMs {
openai = 'openai',
openai_azure = 'openai_azure',
openai_other = 'openai_other',
bedrock = 'bedrock',
gemini = 'gemini',
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const mockConnectors = [
{ id: 'connectorId1', title: 'OpenAI Connector', type: LLMs.openai },
{ id: 'connectorId2', title: 'OpenAI Azure Connector', type: LLMs.openai_azure },
{ id: 'connectorId2', title: 'Bedrock Connector', type: LLMs.bedrock },
{ id: 'connectorId3', title: 'OpenAI OSS Model Connector', type: LLMs.openai_other },
];
const mockUseLoadConnectors = (data: any) => {
(useLoadConnectors as jest.Mock).mockReturnValue({ data });
Expand Down Expand Up @@ -106,6 +107,18 @@ describe('useLLMsModels Hook', () => {
value: 'anthropic.claude-3-5-sonnet-20240620-v1:0',
promptTokenLimit: 200000,
},
{
connectorId: 'connectorId3',
connectorName: undefined,
connectorType: LLMs.openai_other,
disabled: false,
icon: expect.any(Function),
id: 'connectorId3Other (OpenAI Compatible Service) ',
name: 'Other (OpenAI Compatible Service) ',
showConnectorName: false,
value: undefined,
promptTokenLimit: undefined,
},
]);
});

Expand Down
11 changes: 11 additions & 0 deletions x-pack/plugins/search_playground/public/hooks/use_llms_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ const mapLlmToModels: Record<
},
],
},
[LLMs.openai_other]: {
icon: OpenAILogo,
getModels: (connectorName, includeName) => [
{
label: i18n.translate('xpack.searchPlayground.otherOpenAIModel', {
defaultMessage: 'Other (OpenAI Compatible Service) {name}',
values: { name: includeName ? `(${connectorName})` : '' },
}),
},
],
},
[LLMs.bedrock]: {
icon: BedrockLogo,
getModels: () =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ describe('useLoadConnectors', () => {
actionTypeId: '.bedrock',
isMissingSecrets: false,
},
{
id: '5',
actionTypeId: '.gen-ai',
isMissingSecrets: false,
config: { apiProvider: OpenAiProviderType.Other },
},
];
mockedLoadConnectors.mockResolvedValue(connectors);

Expand Down Expand Up @@ -106,6 +112,16 @@ describe('useLoadConnectors', () => {
title: 'Bedrock',
type: 'bedrock',
},
{
actionTypeId: '.gen-ai',
config: {
apiProvider: 'Other',
},
id: '5',
isMissingSecrets: false,
title: 'OpenAI Other',
type: 'openai_other',
},
]);
});
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ const connectorTypeToLLM: Array<{
type: LLMs.openai,
}),
},
{
actionId: OPENAI_CONNECTOR_ID,
actionProvider: OpenAiProviderType.Other,
match: (connector) =>
connector.actionTypeId === OPENAI_CONNECTOR_ID &&
(connector as OpenAIConnector)?.config?.apiProvider === OpenAiProviderType.Other,
transform: (connector) => ({
...connector,
title: i18n.translate('xpack.searchPlayground.openAIOtherConnectorTitle', {
defaultMessage: 'OpenAI Other',
}),
type: LLMs.openai_other,
}),
},
{
actionId: BEDROCK_CONNECTOR_ID,
match: (connector) => connector.actionTypeId === BEDROCK_CONNECTOR_ID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { isEmpty } from 'lodash/fp';
enum OpenAiProviderType {
OpenAi = 'OpenAI',
AzureAi = 'Azure OpenAI',
Other = 'Other',
}

interface GenAiConfig {
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/stack_connectors/common/openai/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export enum SUB_ACTION {
export enum OpenAiProviderType {
OpenAi = 'OpenAI',
AzureAi = 'Azure OpenAI',
Other = 'Other',
}

export const DEFAULT_TIMEOUT_MS = 120000;
Expand Down
6 changes: 6 additions & 0 deletions x-pack/plugins/stack_connectors/common/openai/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ export const ConfigSchema = schema.oneOf([
defaultModel: schema.string({ defaultValue: DEFAULT_OPENAI_MODEL }),
headers: schema.maybe(schema.recordOf(schema.string(), schema.string())),
}),
schema.object({
apiProvider: schema.oneOf([schema.literal(OpenAiProviderType.Other)]),
apiUrl: schema.string(),
defaultModel: schema.string(),
headers: schema.maybe(schema.recordOf(schema.string(), schema.string())),
}),
]);

export const SecretsSchema = schema.object({ apiKey: schema.string() });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ describe('useGetDashboard', () => {
it.each([
['Azure OpenAI', 'openai'],
['OpenAI', 'openai'],
['Other', 'openai'],
['Bedrock', 'bedrock'],
])(
'fetches the %p dashboard and sets the dashboard URL with %p',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ const azureConnector = {
apiKey: 'thats-a-nice-looking-key',
},
};
const otherOpenAiConnector = {
...openAiConnector,
config: {
apiUrl: 'https://localhost/oss-llm',
apiProvider: OpenAiProviderType.Other,
defaultModel: 'local-model',
},
secrets: {
apiKey: 'thats-a-nice-looking-key',
},
};

const navigateToUrl = jest.fn();

Expand Down Expand Up @@ -93,6 +104,24 @@ describe('ConnectorFields renders', () => {
expect(getAllByTestId('azure-ai-api-keys-doc')[0]).toBeInTheDocument();
});

test('other open ai connector fields are rendered', async () => {
const { getAllByTestId } = render(
<ConnectorFormTestProvider connector={otherOpenAiConnector}>
<ConnectorFields readOnly={false} isEdit={false} registerPreSubmitValidator={() => {}} />
</ConnectorFormTestProvider>
);
expect(getAllByTestId('config.apiUrl-input')[0]).toBeInTheDocument();
expect(getAllByTestId('config.apiUrl-input')[0]).toHaveValue(
otherOpenAiConnector.config.apiUrl
);
expect(getAllByTestId('config.apiProvider-select')[0]).toBeInTheDocument();
expect(getAllByTestId('config.apiProvider-select')[0]).toHaveValue(
otherOpenAiConnector.config.apiProvider
);
expect(getAllByTestId('other-ai-api-doc')[0]).toBeInTheDocument();
expect(getAllByTestId('other-ai-api-keys-doc')[0]).toBeInTheDocument();
});

describe('Dashboard link', () => {
it('Does not render if isEdit is false and dashboardUrl is defined', async () => {
const { queryByTestId } = render(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import * as i18n from './translations';
import {
azureAiConfig,
azureAiSecrets,
otherAiConfig,
otherAiSecrets,
openAiConfig,
openAiSecrets,
providerOptions,
Expand Down Expand Up @@ -85,6 +87,14 @@ const ConnectorFields: React.FC<ActionConnectorFieldsProps> = ({ readOnly, isEdi
secretsFormSchema={azureAiSecrets}
/>
)}
{config != null && config.apiProvider === OpenAiProviderType.Other && (
<SimpleConnectorForm
isEdit={isEdit}
readOnly={readOnly}
configFormSchema={otherAiConfig}
secretsFormSchema={otherAiSecrets}
/>
)}
{isEdit && (
<DashboardLink
connectorId={id}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,41 @@ export const azureAiConfig: ConfigFieldSchema[] = [
},
];

export const otherAiConfig: ConfigFieldSchema[] = [
{
id: 'apiUrl',
label: i18n.API_URL_LABEL,
isUrlField: true,
helpText: (
<FormattedMessage
defaultMessage="The Other (OpenAI Compatible Service) endpoint URL. For more information on the URL, refer to the {genAiAPIUrlDocs}."
id="xpack.stackConnectors.components.genAi.otherAiDocumentation"
values={{
genAiAPIUrlDocs: (
<EuiLink
data-test-subj="other-ai-api-doc"
href="https://www.elastic.co/guide/en/security/current/connect-to-byo-llm.html"
target="_blank"
>
{`${i18n.OTHER_AI} ${i18n.DOCUMENTATION}`}
</EuiLink>
),
}}
/>
),
},
{
id: 'defaultModel',
label: i18n.DEFAULT_MODEL_LABEL,
helpText: (
<FormattedMessage
defaultMessage="If a request does not include a model, it uses the default."
id="xpack.stackConnectors.components.genAi.otherAiDocumentationModel"
/>
),
},
];

export const openAiSecrets: SecretsFieldSchema[] = [
{
id: 'apiKey',
Expand Down Expand Up @@ -142,6 +177,31 @@ export const azureAiSecrets: SecretsFieldSchema[] = [
},
];

export const otherAiSecrets: SecretsFieldSchema[] = [
{
id: 'apiKey',
label: i18n.API_KEY_LABEL,
isPasswordField: true,
helpText: (
<FormattedMessage
defaultMessage="The Other (OpenAI Compatible Service) API key for HTTP Basic authentication. For more details about generating Other model API keys, refer to the {genAiAPIKeyDocs}."
id="xpack.stackConnectors.components.genAi.otherAiApiKeyDocumentation"
values={{
genAiAPIKeyDocs: (
<EuiLink
data-test-subj="other-ai-api-keys-doc"
href="https://www.elastic.co/guide/en/security/current/connect-to-byo-llm.html"
target="_blank"
>
{`${i18n.OTHER_AI} ${i18n.DOCUMENTATION}`}
</EuiLink>
),
}}
/>
),
},
];

export const providerOptions = [
{
value: OpenAiProviderType.OpenAi,
Expand All @@ -153,4 +213,9 @@ export const providerOptions = [
text: i18n.AZURE_AI,
label: i18n.AZURE_AI,
},
{
value: OpenAiProviderType.Other,
text: i18n.OTHER_AI,
label: i18n.OTHER_AI,
},
];
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ describe('Gen AI Params Fields renders', () => {
expect(getByTestId('bodyJsonEditor')).toHaveProperty('value', '{"message": "test"}');
expect(getByTestId('bodyAddVariableButton')).toBeInTheDocument();
});
test.each([OpenAiProviderType.OpenAi, OpenAiProviderType.AzureAi])(
test.each([OpenAiProviderType.OpenAi, OpenAiProviderType.AzureAi, OpenAiProviderType.Other])(
'useEffect handles the case when subAction and subActionParams are undefined and apiProvider is %p',
(apiProvider) => {
const actionParams = {
Expand Down Expand Up @@ -79,6 +79,9 @@ describe('Gen AI Params Fields renders', () => {
if (apiProvider === OpenAiProviderType.AzureAi) {
expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY_AZURE }, 0);
}
if (apiProvider === OpenAiProviderType.Other) {
expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY }, 0);
}
}
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ export const AZURE_AI = i18n.translate('xpack.stackConnectors.components.genAi.a
defaultMessage: 'Azure OpenAI',
});

export const OTHER_AI = i18n.translate('xpack.stackConnectors.components.genAi.otherAi', {
defaultMessage: 'Other (OpenAI Compatible Service)',
});

export const DOCUMENTATION = i18n.translate(
'xpack.stackConnectors.components.genAi.documentation',
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ export const configValidator = (configObject: Config, validatorServices: Validat

const { apiProvider } = configObject;

if (apiProvider !== OpenAiProviderType.OpenAi && apiProvider !== OpenAiProviderType.AzureAi) {
if (
apiProvider !== OpenAiProviderType.OpenAi &&
apiProvider !== OpenAiProviderType.AzureAi &&
apiProvider !== OpenAiProviderType.Other
) {
throw new Error(
`API Provider is not supported${
apiProvider && (apiProvider as OpenAiProviderType).length ? `: ${apiProvider}` : ``
Expand Down
Loading

0 comments on commit 0dcf7e5

Please sign in to comment.