Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security solution] AI Assistant, replace LLM with SimpleChatModel + Bedrock streaming #182041

Merged
merged 44 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
4c82b22
wip
stephmilovic Apr 29, 2024
08dfd30
esql tool
stephmilovic Apr 29, 2024
e823543
alert tools
stephmilovic Apr 30, 2024
3908f40
fix merge
stephmilovic May 9, 2024
699abbb
rm structured
stephmilovic May 9, 2024
dc3d12c
stream wip
stephmilovic May 9, 2024
3e37927
Merge branch 'main' into simple_chat_model
stephmilovic May 14, 2024
385bf10
wip
stephmilovic May 14, 2024
4830879
zomg it streamed
stephmilovic May 15, 2024
ada553e
Merge branch 'main' into simple_chat_model
stephmilovic May 16, 2024
13cbfe8
this is awesome
stephmilovic May 16, 2024
b7c6af7
rm silly
stephmilovic May 16, 2024
af61cac
moar
stephmilovic May 16, 2024
cb83d04
merge in main
stephmilovic May 16, 2024
e0556f4
support non-streaming for ChatOpenAI
stephmilovic May 16, 2024
a3c5394
wip openai
stephmilovic May 17, 2024
109dfa0
openai really works
stephmilovic May 17, 2024
df3732e
fix executor tests
stephmilovic May 17, 2024
b123495
more tests
stephmilovic May 17, 2024
93fedb7
more
stephmilovic May 17, 2024
b53c2c5
more tests
stephmilovic May 20, 2024
14779d2
cleanup
stephmilovic May 20, 2024
082c5ba
Merge branch 'main' into simple_chat_model
stephmilovic May 20, 2024
4cdc699
[CI] Auto-commit changed files from 'node scripts/lint_ts_projects --…
kibanamachine May 20, 2024
88d759e
test fix
stephmilovic May 20, 2024
6d6a3c9
i18n fix
stephmilovic May 20, 2024
43cef41
Merge branch 'simple_chat_model' of github.com:stephmilovic/kibana in…
stephmilovic May 20, 2024
265afe2
rm disable
stephmilovic May 20, 2024
85f7d2d
server dir
stephmilovic May 21, 2024
8d7d540
move langchain code to langchain package
stephmilovic May 21, 2024
4025acd
fix types
stephmilovic May 21, 2024
47ff787
fix merge
stephmilovic May 21, 2024
5b0b77b
fix type
stephmilovic May 21, 2024
c463169
fix import
stephmilovic May 21, 2024
3a8c017
[CI] Auto-commit changed files from 'node scripts/lint_ts_projects --…
kibanamachine May 21, 2024
68b0b7b
[CI] Auto-commit changed files from 'node scripts/generate codeowners'
kibanamachine May 21, 2024
fd4e3c3
fix lint
stephmilovic May 21, 2024
6f11ca6
Merge branch 'simple_chat_model' of github.com:stephmilovic/kibana in…
stephmilovic May 21, 2024
726dcbd
[CI] Auto-commit changed files from 'node scripts/eslint --no-cache -…
kibanamachine May 21, 2024
afb7e49
more import fixing
stephmilovic May 21, 2024
b81bdf0
Merge branch 'simple_chat_model' of github.com:stephmilovic/kibana in…
stephmilovic May 21, 2024
2f66531
add readme comment, fix imports better
stephmilovic May 21, 2024
021de2b
fix!
stephmilovic May 21, 2024
f2d398d
Merge branch 'main' into simple_chat_model
stephmilovic May 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@

export { ActionsClientChatOpenAI } from './chat_openai';
export { ActionsClientLlm } from './llm';
export { ActionsClientSimpleChatModel } from './simple_chat_model';
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import {
SimpleChatModel,
type BaseChatModelParams,
} from '@langchain/core/language_models/chat_models';
import { type BaseMessage } from '@langchain/core/messages';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { Logger } from '@kbn/logging';
import { KibanaRequest } from '@kbn/core-http-server';
import { v4 as uuidv4 } from 'uuid';
import { get } from 'lodash/fp';
import { getDefaultArguments } from './constants';
import { ExecuteConnectorRequestBody } from '../..';

export const getMessageContentAndRole = (prompt: string, role = 'user') => ({
content: prompt,
role: role === 'human' ? 'user' : role,
});

export interface CustomChatModelInput extends BaseChatModelParams {
actions: ActionsPluginStart;
connectorId: string;
logger: Logger;
llmType?: string;
model?: string;
temperature?: number;
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
}

export class ActionsClientSimpleChatModel extends SimpleChatModel {
#actions: ActionsPluginStart;
#connectorId: string;
#logger: Logger;
#request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
#traceId: string;
llmType: string;
model?: string;
temperature?: number;

constructor({
actions,
connectorId,
llmType,
logger,
model,
request,
temperature,
}: CustomChatModelInput) {
super({});

this.#actions = actions;
this.#connectorId = connectorId;
this.#traceId = uuidv4();
this.#logger = logger;
this.#request = request;
this.llmType = llmType ?? 'openai';
this.model = model;
this.temperature = temperature;
}

_llmType() {
return this.llmType;
}

// Model type needs to be `base_chat_model` to work with LangChain OpenAI Tools
// We may want to make this configurable (ala _llmType) if different agents end up requiring different model types
// See: https://github.com/langchain-ai/langchainjs/blob/fb699647a310c620140842776f4a7432c53e02fa/langchain/src/agents/openai/index.ts#L185
_modelType() {
return 'base_chat_model';
}

async _call(messages: BaseMessage[], options: this['ParsedCallOptions']): Promise<string> {
if (!messages.length) {
throw new Error('No messages provided.');
}
const formattedMessages = [];
if (messages.length === 2) {
messages.forEach((message, i) => {
if (typeof message.content !== 'string') {
throw new Error('Multimodal messages are not supported.');
}
formattedMessages.push(getMessageContentAndRole(message.content, message._getType()));
});
} else {
if (typeof messages[0].content !== 'string') {
throw new Error('Multimodal messages are not supported.');
}
formattedMessages.push(getMessageContentAndRole(messages[0].content));
}
this.#logger.debug(
`ActionsClientSimpleChatModel#_call\ntraceId: ${
this.#traceId
}\nassistantMessage:\n${JSON.stringify(formattedMessages)} `
);
// create a new connector request body with the assistant message:
const requestBody = {
actionId: this.#connectorId,
params: {
// hard code to non-streaming subaction as this class only supports non-streaming
subAction: 'invokeAI',
subActionParams: {
model: this.#request.body.model,
messages: formattedMessages,
...getDefaultArguments(this.llmType, this.temperature, options.stop),
},
},
};

// create an actions client from the authenticated request context:
const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request);

const actionResult = await actionsClient.execute(requestBody);

if (actionResult.status === 'error') {
throw new Error(
`ActionsClientSimpleChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
);
}

const content = get('data.message', actionResult);

if (typeof content !== 'string') {
throw new Error(
`ActionsClientSimpleChatModel: content should be a string, but it had an unexpected type: ${typeof content}`
);
}

return content; // per the contact of _call, return a string
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { transformError } from '@kbn/securitysolution-es-utils';
import { RetrievalQAChain } from 'langchain/chains';
import {
ActionsClientChatOpenAI,
ActionsClientLlm,
ActionsClientSimpleChatModel,
} from '@kbn/elastic-assistant-common/impl/language_models';
import { getDefaultArguments } from '@kbn/elastic-assistant-common/impl/language_models/constants';
import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store';
Expand All @@ -26,7 +26,7 @@ import { AssistantToolParams } from '../../../types';
export const DEFAULT_AGENT_EXECUTOR_ID = 'Elastic AI Assistant Agent Executor';

/**
* The default agent executor used by the Elastic AI Assistant. Main agent/chain that wraps the ActionsClientLlm,
* The default agent executor used by the Elastic AI Assistant. Main agent/chain that wraps the ActionsClientSimpleChatModel,
* sets up a conversation BufferMemory from chat history, and registers tools like the ESQLKnowledgeBaseTool.
*
*/
Expand Down Expand Up @@ -55,7 +55,7 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
}) => {
// TODO implement llmClass for bedrock streaming
// tracked here: https://github.com/elastic/security-team/issues/7363
const llmClass = isStream ? ActionsClientChatOpenAI : ActionsClientLlm;
const llmClass = isStream ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel;

const llm = new llmClass({
actions,
Expand Down Expand Up @@ -130,11 +130,16 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
verbose: false,
})
: await initializeAgentExecutorWithOptions(tools, llm, {
agentType: 'chat-conversational-react-description',
agentType: 'structured-chat-zero-shot-react-description',
memory,
verbose: false,
returnIntermediateSteps: false,
handleParsingErrors: 'Try again, paying close attention to the allowed tool input',
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handleParsingErrors can also be a function, so I put logs in here while developing. It was not hit once I added the agentArgs. However when I was hitting it, it did help to get the agent back on track so I think I should leave it. Here is an example of a run where it was hit: https://smith.langchain.com/public/910b7739-cd9a-401b-9db0-fe8438ff53e5/r

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agentArgs: {
// this is important to help LangChain correctly format tool input
humanMessageTemplate: `Question: {input}\n\n{agent_scratchpad}`,
},
});

// Sets up tracer for tracing executions to APM. See x-pack/plugins/elastic_assistant/server/lib/langchain/tracers/README.mdx
// If LangSmith env vars are set, executions will be traced there as well. See https://docs.smith.langchain.com/tracing
const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger);
Expand Down
6 changes: 4 additions & 2 deletions x-pack/plugins/elastic_assistant/server/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ import { LicensingApiRequestHandlerContext } from '@kbn/licensing-plugin/server'
import {
ActionsClientChatOpenAI,
ActionsClientLlm,
ActionsClientSimpleChatModel,
} from '@kbn/elastic-assistant-common/impl/language_models';

import { DynamicStructuredTool } from '@langchain/core/dist/tools';
import { AIAssistantConversationsDataClient } from './ai_assistant_data_clients/conversations';
import type { GetRegisteredFeatures, GetRegisteredTools } from './services/app_context';
import { AIAssistantDataClient } from './ai_assistant_data_clients';
Expand Down Expand Up @@ -202,7 +204,7 @@ export interface AssistantTool {
description: string;
sourceRegister: string;
isSupported: (params: AssistantToolParams) => boolean;
getTool: (params: AssistantToolParams) => Tool | null;
getTool: (params: AssistantToolParams) => Tool | DynamicStructuredTool | null;
}

export interface AssistantToolParams {
Expand All @@ -211,7 +213,7 @@ export interface AssistantToolParams {
isEnabledKnowledgeBase: boolean;
chain?: RetrievalQAChain;
esClient: ElasticsearchClient;
llm?: ActionsClientLlm | ActionsClientChatOpenAI;
llm?: ActionsClientLlm | ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
modelExists: boolean;
onNewReplacements?: (newReplacements: Replacements) => void;
replacements?: Replacements;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
* 2.0.
*/

import { DynamicTool } from '@langchain/core/tools';
import { DynamicStructuredTool } from '@langchain/core/tools';
import { z } from 'zod';
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
import { APP_UI_ID } from '../../../../common';

export type EsqlKnowledgeBaseToolParams = AssistantToolParams;

const toolDetails = {
description:
'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the query on a single line, with no other text. Only output valid ES|QL queries as described above. Do not add any additional text to describe your output.',
'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the query on a single line, with no other text. Only output valid ES|QL queries as described above. Format ESQL correctly by surrounding it with back ticks. Do not add any additional text to describe your output.',
id: 'esql-knowledge-base-tool',
name: 'ESQLKnowledgeBaseTool',
};
Expand All @@ -30,13 +31,16 @@ export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = {
const { chain } = params as EsqlKnowledgeBaseToolParams;
if (chain == null) return null;

return new DynamicTool({
return new DynamicStructuredTool({
name: toolDetails.name,
description: toolDetails.description,
schema: z.object({
question: z.string().describe(`The user's exact question about ESQL`),
}),
func: async (input, _, cbManager) => {
const result = await chain.invoke(
{
query: input,
query: input.question,
},
cbManager
);
Expand Down
15 changes: 9 additions & 6 deletions x-pack/plugins/stack_connectors/common/openai/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,15 @@ export const InvokeAIActionParamsSchema = schema.object({
{
name: schema.string(),
description: schema.string(),
parameters: schema.object({
type: schema.string(),
properties: schema.object({}, { unknowns: 'allow' }),
additionalProperties: schema.boolean(),
$schema: schema.string(),
}),
parameters: schema.object(
{
type: schema.string(),
properties: schema.object({}, { unknowns: 'allow' }),
additionalProperties: schema.boolean(),
$schema: schema.string(),
},
{ unknowns: 'allow' }
),
},
// Not sure if this will include other properties, we should pass them if it does
{ unknowns: 'allow' }
Expand Down