Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set system prompt at the beginning #625

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 16 additions & 16 deletions packages/backend/src/managers/playgroundV2Manager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ test('submit should throw an error if the server is stopped', async () => {
} as unknown as InferenceServer,
]);

await expect(manager.submit('0', 'dummyUserInput', '')).rejects.toThrowError('Inference server is not running.');
await expect(manager.submit('0', 'dummyUserInput', false)).rejects.toThrowError('Inference server is not running.');
});

test('submit should throw an error if the server is unhealthy', async () => {
Expand All @@ -109,7 +109,7 @@ test('submit should throw an error if the server is unhealthy', async () => {
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, 'tracking-1');
const playgroundId = manager.getPlaygrounds()[0].id;
await expect(manager.submit(playgroundId, 'dummyUserInput', '')).rejects.toThrowError(
await expect(manager.submit(playgroundId, 'dummyUserInput', false)).rejects.toThrowError(
'Inference server is not healthy, currently status: unhealthy.',
);
});
Expand Down Expand Up @@ -175,7 +175,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
vi.setSystemTime(date);

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', '');
await manager.submit(playgrounds[0].id, 'dummyUserInput', false);

// Wait for assistant message to be completed
await vi.waitFor(() => {
Expand Down Expand Up @@ -208,9 +208,9 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
});
});

test.each(['', 'my system prompt'])(
test.each([true, false])(
'valid submit should send a message with role system when the system-prompt flag is set, flag is "%s"',
async (systemPrompt: string) => {
async (systemPrompt: boolean) => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
Expand Down Expand Up @@ -249,21 +249,17 @@ test.each(['', 'my system prompt'])(
{
content: 'dummyUserInput',
id: expect.any(String),
role: 'user',
role: systemPrompt ? 'system' : 'user',
timestamp: expect.any(Number),
},
];
if (systemPrompt) {
messages.push({
content: 'my system prompt',
role: 'system',
if (!systemPrompt) {
expect(createMock).toHaveBeenCalledWith({
messages,
model: 'dummyModelFile',
stream: true,
});
}
expect(createMock).toHaveBeenCalledWith({
messages,
model: 'dummyModelFile',
stream: true,
});
},
);

Expand Down Expand Up @@ -300,7 +296,11 @@ test('submit should send options', async () => {
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', '', { temperature: 0.123, max_tokens: 45, top_p: 0.345 });
await manager.submit(playgrounds[0].id, 'dummyUserInput', false, {
temperature: 0.123,
max_tokens: 45,
top_p: 0.345,
});

const messages: unknown[] = [
{
Expand Down
11 changes: 6 additions & 5 deletions packages/backend/src/managers/playgroundV2Manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
 * @param userInput the user input
 * @param systemPrompt when true, the user input is recorded as a system message and no completion is requested
 * @param options the model configuration
*/
async submit(playgroundId: string, userInput: string, systemPrompt: string, options?: ModelOptions): Promise<void> {
async submit(playgroundId: string, userInput: string, systemPrompt: boolean, options?: ModelOptions): Promise<void> {
const playground = this.#playgrounds.get(playgroundId);
if (playground === undefined) throw new Error('Playground not found.');

Expand All @@ -165,20 +165,21 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
this.#conversationRegistry.submit(conversation.id, {
content: userInput,
options: options,
role: 'user',
role: systemPrompt ? 'system' : 'user',
id: this.getUniqueId(),
timestamp: Date.now(),
} as UserChat);

if (systemPrompt) {
return;
}

const client = new OpenAI({
baseURL: `http://localhost:${server.connection.port}/v1`,
apiKey: 'dummy',
});

const messages = this.getFormattedMessages(playground.id);
if (systemPrompt) {
messages.push({ role: 'system', content: systemPrompt });
}
client.chat.completions
.create({
messages,
Expand Down
2 changes: 1 addition & 1 deletion packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ export class StudioApiImpl implements StudioAPI {
submitPlaygroundMessage(
containerId: string,
userInput: string,
systemPrompt: string,
systemPrompt: boolean,
options?: ModelOptions,
): Promise<void> {
return this.playgroundV2.submit(containerId, userInput, systemPrompt, options);
Expand Down
122 changes: 120 additions & 2 deletions packages/frontend/src/pages/Playground.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@

import '@testing-library/jest-dom/vitest';
import { render, screen, waitFor, within } from '@testing-library/svelte';
import { expect, test, vi } from 'vitest';
import { beforeEach, describe, expect, test, vi } from 'vitest';
import Playground from './Playground.svelte';
import { studioClient } from '../utils/client';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { fireEvent } from '@testing-library/dom';
import type { AssistantChat, Conversation, PendingChat, UserChat } from '@shared/src/models/IPlaygroundMessage';
import type {
AssistantChat,
ChatMessage,
Conversation,
PendingChat,
UserChat,
} from '@shared/src/models/IPlaygroundMessage';
import * as conversationsStore from '/@/stores/conversations';
import { writable } from 'svelte/store';
import userEvent from '@testing-library/user-event';
Expand Down Expand Up @@ -51,6 +57,10 @@ vi.mock('/@/stores/conversations', async () => {
};
});

beforeEach(() => {
vi.resetAllMocks();
});

test('should display playground and model names in header', async () => {
vi.mocked(studioClient.getCatalog).mockResolvedValue({
models: [
Expand Down Expand Up @@ -172,6 +182,7 @@ test('receiving complete message should enable the send button', async () => {
},
]);
const customConversations = writable<Conversation[]>([]);
vi.mocked(studioClient.submitPlaygroundMessage).mockResolvedValue();
vi.mocked(conversationsStore).conversations = customConversations;
render(Playground, {
playgroundId: 'playground-1',
Expand Down Expand Up @@ -233,6 +244,7 @@ test('sending prompt should display the prompt and the response', async () => {
},
]);
const customConversations = writable<Conversation[]>([]);
vi.mocked(studioClient.submitPlaygroundMessage).mockResolvedValue();
vi.mocked(conversationsStore).conversations = customConversations;
render(Playground, {
playgroundId: 'playground-1',
Expand Down Expand Up @@ -299,3 +311,109 @@ test('sending prompt should display the prompt and the response', async () => {
within(conversation).getByText('a response from the assistant');
});
});

describe('system prompt', () => {
test('system prompt textarea should be displayed before sending the first prompt and hidden after', async () => {
vi.mocked(studioClient.getCatalog).mockResolvedValue({
models: [
{
id: 'model-1',
name: 'Model 1',
},
] as ModelInfo[],
recipes: [],
categories: [],
});
vi.mocked(studioClient.getPlaygroundsV2).mockResolvedValue([
{
id: 'playground-1',
name: 'Playground 1',
modelId: 'model-1',
},
]);
const customConversations = writable<Conversation[]>([]);
vi.mocked(studioClient.submitPlaygroundMessage).mockResolvedValue();
vi.mocked(conversationsStore).conversations = customConversations;

render(Playground, {
playgroundId: 'playground-1',
});

const systemPromptTextarea = screen.getByLabelText('system-prompt-textarea');
expect(systemPromptTextarea).toBeInTheDocument();
await userEvent.type(systemPromptTextarea, 'a system prompt');

const setSystemPrompt = screen.getByLabelText('set-system-prompt');
expect(setSystemPrompt).toBeInTheDocument();
fireEvent.click(setSystemPrompt);

customConversations.set([
{
id: 'playground-1',
messages: [
{
role: 'system',
content: 'a system prompt',
} as ChatMessage,
],
},
]);

await waitFor(() => {
const textarea = screen.queryByLabelText('system-prompt-textarea');
expect(textarea).not.toBeInTheDocument();
});

await waitFor(() => {
const conversation = screen.getByLabelText('conversation');
within(conversation).getByText('a system prompt');
});
});

test('system prompt textarea should not be displayed when coming to the page with a started conversation', async () => {
vi.mocked(studioClient.getCatalog).mockResolvedValue({
models: [
{
id: 'model-1',
name: 'Model 1',
},
] as ModelInfo[],
recipes: [],
categories: [],
});
vi.mocked(studioClient.getPlaygroundsV2).mockResolvedValue([
{
id: 'playground-1',
name: 'Playground 1',
modelId: 'model-1',
},
]);
const customConversations = writable<Conversation[]>([
{
id: 'playground-1',
messages: [
{
role: 'system',
content: 'a system prompt',
} as ChatMessage,
],
},
]);
vi.mocked(studioClient.submitPlaygroundMessage).mockResolvedValue();
vi.mocked(conversationsStore).conversations = customConversations;

render(Playground, {
playgroundId: 'playground-1',
});

await waitFor(() => {
const textarea = screen.queryByLabelText('system-prompt-textarea');
expect(textarea).not.toBeInTheDocument();
});

await waitFor(() => {
const conversation = screen.getByLabelText('conversation');
within(conversation).getByText('a system prompt');
});
});
});