From 05dbb39b090f3e959134263b58ccc93343041f65 Mon Sep 17 00:00:00 2001 From: Philippe Martin Date: Wed, 20 Mar 2024 15:07:57 +0100 Subject: [PATCH] feat: system prompt for playground page (#591) * feat: system prompt for playground page Signed-off-by: Philippe Martin * tests: add backend unit tests Signed-off-by: Philippe Martin * run completions.create async Signed-off-by: Philippe Martin --------- Signed-off-by: Philippe Martin --- .../src/managers/playgroundV2Manager.spec.ts | 69 ++++++++++++++-- .../src/managers/playgroundV2Manager.ts | 32 +++++--- packages/backend/src/studio-api-impl.ts | 9 ++- packages/frontend/src/pages/Playground.svelte | 78 +++++++++++++------ packages/shared/src/StudioAPI.ts | 7 +- 5 files changed, 152 insertions(+), 43 deletions(-) diff --git a/packages/backend/src/managers/playgroundV2Manager.spec.ts b/packages/backend/src/managers/playgroundV2Manager.spec.ts index 70773928e..a823c70c8 100644 --- a/packages/backend/src/managers/playgroundV2Manager.spec.ts +++ b/packages/backend/src/managers/playgroundV2Manager.spec.ts @@ -81,7 +81,7 @@ test('submit should throw an error if the server is stopped', async () => { } as unknown as InferenceServer, ]); - await expect(manager.submit('0', 'dummyUserInput')).rejects.toThrowError('Inference server is not running.'); + await expect(manager.submit('0', 'dummyUserInput', '')).rejects.toThrowError('Inference server is not running.'); }); test('submit should throw an error if the server is unhealthy', async () => { @@ -101,7 +101,7 @@ test('submit should throw an error if the server is unhealthy', async () => { const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock); await manager.createPlayground('p1', { id: 'model1' } as ModelInfo); const playgroundId = manager.getPlaygrounds()[0].id; - await expect(manager.submit(playgroundId, 'dummyUserInput')).rejects.toThrowError( + await expect(manager.submit(playgroundId, 'dummyUserInput', '')).rejects.toThrowError( 'Inference server is not healthy, currently status: unhealthy.', ); }); @@ -167,7 +167,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy vi.setSystemTime(date); const playgrounds = manager.getPlaygrounds(); - await manager.submit(playgrounds[0].id, 'dummyUserInput'); + await manager.submit(playgrounds[0].id, 'dummyUserInput', ''); // Wait for assistant message to be completed await vi.waitFor(() => { @@ -183,7 +183,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy id: expect.anything(), options: undefined, role: 'user', - timestamp: date.getTime(), + timestamp: expect.any(Number), }); expect(conversations[0].messages[1]).toStrictEqual({ choices: undefined, @@ -191,7 +191,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy content: '', id: expect.anything(), role: 'assistant', - timestamp: date.getTime(), + timestamp: expect.any(Number), }); expect(webviewMock.postMessage).toHaveBeenLastCalledWith({ @@ -200,6 +200,65 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy }); }); +test.each(['', 'my system prompt'])( + 'valid submit should send a message with system prompt if non empty, system prompt is "%s"}', + async (systemPrompt: string) => { + vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ + { + status: 'running', + health: { + Status: 'healthy', + }, + models: [ + { + id: 'dummyModelId', + file: { + file: 'dummyModelFile', + }, + }, + ], + connection: { + port: 8888, + }, + } as unknown as InferenceServer, + ]); + const createMock = vi.fn().mockResolvedValue([]); + vi.mocked(OpenAI).mockReturnValue({ + chat: { + completions: { + create: createMock, + }, + }, + } as unknown as OpenAI); + + const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock); + await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo); + + const playgrounds = manager.getPlaygrounds(); + await manager.submit(playgrounds[0].id, 'dummyUserInput', systemPrompt); + + const messages: unknown[] = [ + { + content: 'dummyUserInput', + id: expect.any(String), + role: 'user', + timestamp: expect.any(Number), + }, + ]; + if (systemPrompt) { + messages.push({ + content: 'my system prompt', + role: 'system', + }); + } + expect(createMock).toHaveBeenCalledWith({ + messages, + model: 'dummyModelFile', + stream: true, + }); + }, +); + test('creating a new playground should send new playground to frontend', async () => { vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock); diff --git a/packages/backend/src/managers/playgroundV2Manager.ts b/packages/backend/src/managers/playgroundV2Manager.ts index fb31b2e7b..967212a82 100644 --- a/packages/backend/src/managers/playgroundV2Manager.ts +++ b/packages/backend/src/managers/playgroundV2Manager.ts @@ -88,7 +88,7 @@ export class PlaygroundV2Manager extends Publisher implements Di * @param userInput the user input * @param options the model configuration */ - async submit(playgroundId: string, userInput: string, options?: ModelOptions): Promise { + async submit(playgroundId: string, userInput: string, systemPrompt: string, options?: ModelOptions): Promise { const playground = this.#playgrounds.get(playgroundId); if (playground === undefined) throw new Error('Playground not found.'); @@ -123,16 +123,26 @@ export class PlaygroundV2Manager extends Publisher implements Di apiKey: 'dummy', }); - const response = await client.chat.completions.create({ - messages: this.getFormattedMessages(playground.id), - stream: true, - model: modelInfo.file.file, - ...options, - }); - // process stream async - this.processStream(playground.id, response).catch((err: unknown) => { - console.error('Something went wrong while processing stream', err); - }); + const messages = this.getFormattedMessages(playground.id); + if (systemPrompt) { + messages.push({ role: 'system', content: systemPrompt }); + } + client.chat.completions + .create({ + messages, + stream: true, + model: modelInfo.file.file, + ...options, + }) + .then(response => { + // process stream async + this.processStream(playground.id, response).catch((err: unknown) => { + console.error('Something went wrong while processing stream', err); + }); + }) + .catch((err: unknown) => { + console.error('Something went wrong while creating model reponse', err); + }); } /** diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 2ff2d82ac..23b9665db 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -68,8 +68,13 @@ export class StudioApiImpl implements StudioAPI { return this.playgroundV2.getPlaygrounds(); } - submitPlaygroundMessage(containerId: string, userInput: string, options?: ModelOptions): Promise { - return this.playgroundV2.submit(containerId, userInput, options); + submitPlaygroundMessage( + containerId: string, + userInput: string, + systemPrompt: string, + options?: ModelOptions, + ): Promise { + return this.playgroundV2.submit(containerId, userInput, systemPrompt, options); } async getPlaygroundConversations(): Promise { diff --git a/packages/frontend/src/pages/Playground.svelte b/packages/frontend/src/pages/Playground.svelte index 3575bb787..ba597137a 100644 --- a/packages/frontend/src/pages/Playground.svelte +++ b/packages/frontend/src/pages/Playground.svelte @@ -14,6 +14,7 @@ import { playgrounds } from '../stores/playgrounds-v2'; import { catalog } from '../stores/catalog'; import Button from '../lib/button/Button.svelte'; import { afterUpdate } from 'svelte'; +import ContentDetailsLayout from '../lib/ContentDetailsLayout.svelte'; export let playgroundId: string; let prompt: string; @@ -22,6 +23,9 @@ let scrollable: Element; let lastIsUserMessage = false; let errorMsg = ''; +// settings +let systemPrompt: string = ''; + $: conversation = $conversations.find(conversation => conversation.id === playgroundId); $: playground = $playgrounds.find(playground => playground.id === playgroundId); $: model = $catalog.models.find(model => model.id === playground?.modelId); @@ -63,7 +67,7 @@ function getMessageParagraphs(message: ChatMessage): string[] { function askPlayground() { errorMsg = ''; sendEnabled = false; - studioClient.submitPlaygroundMessage(playgroundId, prompt).catch((err: unknown) => { + studioClient.submitPlaygroundMessage(playgroundId, prompt, systemPrompt, {}).catch((err: unknown) => { errorMsg = String(err); sendEnabled = true; }); @@ -101,30 +105,56 @@ function elapsedTime(msg: AssistantChat): string { {model?.name}
-
- {#if conversation?.messages} -
    - {#each conversation?.messages as message} -
  • -
    {roleNames[message.role]}
    -
    - {#each getMessageParagraphs(message) as paragraph} -

    {paragraph}

    - {/each} -
    - {#if isAssistantChat(message)} -
    {elapsedTime(message)} s
    +
    + + +
    +
    + {#if conversation?.messages} +
      + {#each conversation?.messages as message} +
    • +
      + {roleNames[message.role]} +
      +
      + {#each getMessageParagraphs(message) as paragraph} +

      {paragraph}

      + {/each} +
      + {#if isAssistantChat(message)} +
      + {elapsedTime(message)} s +
      + {/if} +
      +
    • + {/each} +
    {/if} -
    -
  • - {/each} -
- {/if} +
+
+
+ +
Next prompt will use these settings
+
+
System Prompt
+
+ +
+
+
+ {#if errorMsg}
{errorMsg}
diff --git a/packages/shared/src/StudioAPI.ts b/packages/shared/src/StudioAPI.ts index 3088936d0..f90dea2a9 100644 --- a/packages/shared/src/StudioAPI.ts +++ b/packages/shared/src/StudioAPI.ts @@ -118,7 +118,12 @@ export abstract class StudioAPI { * @param userInput the user input, e.g. 'What is the capital of France ?' * @param options the options for the model, e.g. temperature */ - abstract submitPlaygroundMessage(containerId: string, userInput: string, options?: ModelOptions): Promise; + abstract submitPlaygroundMessage( + containerId: string, + userInput: string, + systemPrompt: string, + options?: ModelOptions, + ): Promise; /** * Return the conversations