Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ask the system prompt from the Playground creation form #643

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
56 changes: 46 additions & 10 deletions packages/backend/src/managers/playgroundV2Manager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ test('submit should throw an error if the server is stopped', async () => {
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, '', 'tracking-1');

vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
Expand Down Expand Up @@ -107,7 +107,7 @@ test('submit should throw an error if the server is unhealthy', async () => {
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, '', 'tracking-1');
const playgroundId = manager.getPlaygrounds()[0].id;
await expect(manager.submit(playgroundId, 'dummyUserInput', '')).rejects.toThrowError(
'Inference server is not healthy, currently status: unhealthy.',
Expand All @@ -133,12 +133,42 @@ test('create playground should create conversation.', async () => {
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
expect(manager.getConversations().length).toBe(0);
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, '', 'tracking-1');

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
});

test('create playground called with a system prompt should create conversation with a system message.', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
expect(manager.getConversations().length).toBe(0);
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'a system prompt', 'tracking-1');

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
const conversation = conversations[0];
expect(conversation.messages).toHaveLength(1);
const systemMessage = conversation.messages[0];
expect(systemMessage.role).toEqual('system');
expect(systemMessage.content).toEqual('a system prompt');
});

test('valid submit should create IPlaygroundMessage and notify the webview', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
Expand Down Expand Up @@ -169,7 +199,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, '', 'tracking-1');

const date = new Date(2000, 1, 1, 13);
vi.setSystemTime(date);
Expand Down Expand Up @@ -240,7 +270,7 @@ test.each(['', 'my system prompt'])(
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, '', 'tracking-1');

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', systemPrompt);
Expand Down Expand Up @@ -297,7 +327,7 @@ test('submit should send options', async () => {
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, '', 'tracking-1');

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', '', { temperature: 0.123, max_tokens: 45, top_p: 0.345 });
Expand Down Expand Up @@ -334,6 +364,7 @@ test('creating a new playground should send new playground to frontend', async (
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(webviewMock.postMessage).toHaveBeenCalledWith({
Expand All @@ -357,6 +388,7 @@ test('creating a new playground with no name should send new playground to front
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(webviewMock.postMessage).toHaveBeenCalledWith({
Expand All @@ -381,6 +413,7 @@ test('creating a new playground with no model served should start an inference s
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).toHaveBeenCalledWith(
Expand Down Expand Up @@ -417,6 +450,7 @@ test('creating a new playground with the model already served should not start a
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).not.toHaveBeenCalled();
Expand Down Expand Up @@ -445,6 +479,7 @@ test('creating a new playground with the model server stopped should start the i
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).not.toHaveBeenCalled();
Expand All @@ -462,6 +497,7 @@ test('delete conversation should delete the conversation', async () => {
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);

Expand All @@ -484,9 +520,9 @@ test('requestCreatePlayground should call createPlayground and createTask, then
});
const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockResolvedValue('playground-1');

const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo);
const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo, '');

expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String));
expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, '', expect.any(String));
expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', {
trackingId: id,
});
Expand All @@ -513,9 +549,9 @@ test('requestCreatePlayground should call createPlayground and createTask, then
});
const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockRejectedValue(new Error('an error'));

const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo);
const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo, '');

expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String));
expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, '', expect.any(String));
expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', {
trackingId: id,
});
Expand Down
17 changes: 13 additions & 4 deletions packages/backend/src/managers/playgroundV2Manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import type { ChatCompletionChunk, ChatCompletionMessageParam } from 'openai/src
import type { ModelOptions } from '@shared/src/models/IModelOptions';
import type { Stream } from 'openai/streaming';
import { ConversationRegistry } from '../registries/conversationRegistry';
import type { Conversation, PendingChat, UserChat } from '@shared/src/models/IPlaygroundMessage';
import type { Conversation, PendingChat, SystemPrompt, UserChat } from '@shared/src/models/IPlaygroundMessage';
import type { PlaygroundV2 } from '@shared/src/models/IPlaygroundV2';
import { Publisher } from '../utils/Publisher';
import { Messages } from '@shared/Messages';
Expand Down Expand Up @@ -54,13 +54,13 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
this.notify();
}

async requestCreatePlayground(name: string, model: ModelInfo): Promise<string> {
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt: string): Promise<string> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt: string): Promise<string> {
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string> {

const trackingId: string = getRandomString();
const task = this.taskRegistry.createTask('Creating Playground environment', 'loading', {
trackingId: trackingId,
});

this.createPlayground(name, model, trackingId)
this.createPlayground(name, model, systemPrompt, trackingId)
.then((playgroundId: string) => {
this.taskRegistry.updateTask({
...task,
Expand Down Expand Up @@ -94,7 +94,7 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
return trackingId;
}

async createPlayground(name: string, model: ModelInfo, trackingId: string): Promise<string> {
async createPlayground(name: string, model: ModelInfo, systemPrompt: string, trackingId: string): Promise<string> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
async createPlayground(name: string, model: ModelInfo, systemPrompt: string, trackingId: string): Promise<string> {
async createPlayground(name: string, model: ModelInfo, systemPrompt: string | undefined, trackingId: string): Promise<string> {

const id = `${this.#playgroundCounter++}`;

if (!name) {
Expand All @@ -103,6 +103,15 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di

this.#conversationRegistry.createConversation(id);

if (systemPrompt) {
this.#conversationRegistry.submit(id, {
content: systemPrompt,
role: 'system',
id: this.getUniqueId(),
timestamp: Date.now(),
} as SystemPrompt);
}

// create/start inference server if necessary
const servers = this.inferenceManager.getServers();
const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id));
Expand Down
4 changes: 2 additions & 2 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ export class StudioApiImpl implements StudioAPI {
});
}

async requestCreatePlayground(name: string, model: ModelInfo): Promise<string> {
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt: string): Promise<string> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt: string): Promise<string> {
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string> {

try {
return this.playgroundV2.requestCreatePlayground(name, model);
return this.playgroundV2.requestCreatePlayground(name, model, systemPrompt);
} catch (err: unknown) {
console.error('Something went wrong while trying to create playground environment', err);
throw err;
Expand Down
20 changes: 5 additions & 15 deletions packages/frontend/src/pages/Playground.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
isPendingChat,
isUserChat,
type AssistantChat,
isSystemPrompt,
} from '@shared/src/models/IPlaygroundMessage';
import NavPage from '../lib/NavPage.svelte';
import { playgrounds } from '../stores/playgrounds-v2';
Expand Down Expand Up @@ -46,7 +47,7 @@ $: {
}

const roleNames = {
system: 'System',
system: 'System prompt',
user: 'User',
assistant: 'Assistant',
};
Expand All @@ -61,9 +62,8 @@ function getMessageParagraphs(message: ChatMessage): string[] {
.join('')
.split('\n');
}
} else if (isUserChat(message)) {
const msg = message as UserChat;
return msg.content?.split('\n') ?? [];
} else if (isUserChat(message) || isSystemPrompt(message)) {
return message.content?.split('\n') ?? [];
}
return [];
}
Expand Down Expand Up @@ -130,6 +130,7 @@ function elapsedTime(msg: AssistantChat): string {
<div
class="p-4 rounded-md"
class:bg-charcoal-400="{isUserChat(message)}"
class:bg-charcoal-800="{isSystemPrompt(message)}"
class:bg-charcoal-900="{isAssistantChat(message)}"
class:ml-8="{isAssistantChat(message)}"
class:mr-8="{isUserChat(message)}">
Expand All @@ -152,17 +153,6 @@ function elapsedTime(msg: AssistantChat): string {
</svelte:fragment>
<svelte:fragment slot="details">
<div class="text-gray-800 text-xs">Next prompt will use these settings</div>
<div class="bg-charcoal-600 w-full rounded-md text-xs p-4">
<div class="mb-4">System Prompt</div>
<div class="w-full">
<textarea
bind:value="{systemPrompt}"
class="p-2 w-full outline-none bg-charcoal-500 rounded-sm text-gray-700 placeholder-gray-700"
rows="4"
placeholder="Provide system prompt to define general context, instructions or guidelines to be used with each query"
></textarea>
</div>
</div>
<div class="bg-charcoal-600 w-full rounded-md text-xs p-4">
<div class="mb-4 flex flex-col">Model Parameters</div>
<div class="flex flex-col space-y-4">
Expand Down
12 changes: 11 additions & 1 deletion packages/frontend/src/pages/PlaygroundCreate.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ let localModels: ModelInfo[];
$: localModels = $modelsInfo.filter(model => model.file);
$: availModels = $modelsInfo.filter(model => !model.file);
let modelId: string | undefined = undefined;
let systemPrompt: string | undefined = undefined;
let submitted: boolean = false;
let playgroundName: string;

Expand Down Expand Up @@ -55,7 +56,7 @@ async function submit() {
// disable submit button
submitted = true;
try {
trackingId = await studioClient.requestCreatePlayground(playgroundName, model);
trackingId = await studioClient.requestCreatePlayground(playgroundName, model, systemPrompt ?? '');
Copy link
Contributor

@axel7083 axel7083 Mar 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
trackingId = await studioClient.requestCreatePlayground(playgroundName, model, systemPrompt ?? '');
trackingId = await studioClient.requestCreatePlayground(playgroundName, model, systemPrompt);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my first attempt, but the problem I can see is that systemPrompt can be either undefined or an empty string, depending on if the user lets the textarea untouched, or edits its content and removes it. I wanted to have only one value (the empty string) to indicate the systemPrompt is empty. But as the two values are handled correctly on the other side, we can do like this if you prefer

Copy link
Contributor

@axel7083 axel7083 Mar 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the empty value should not be allowed, and be replaced with undefined, as it should be fully optional from the frontend POV

} catch (err: unknown) {
trackingId = undefined;
console.error('Something wrong while trying to create the playground.', err);
Expand Down Expand Up @@ -161,6 +162,15 @@ onDestroy(() => {
</div>
</div>
{/if}

<label for="model" class="pt-4 block mb-2 text-sm font-bold text-gray-400">System prompt</label>
<textarea
aria-label="system-prompt-textarea"
bind:value="{systemPrompt}"
class="w-full p-2 outline-none text-sm bg-charcoal-600 rounded-sm text-gray-700 placeholder-gray-700"
rows="4"
placeholder="Optionally provide system prompt to define general context, instructions or guidelines to be used with each query"
></textarea>
</div>
<footer>
<div class="w-full flex flex-col">
Expand Down
2 changes: 1 addition & 1 deletion packages/shared/src/StudioAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ export abstract class StudioAPI {
*/
abstract createSnippet(options: RequestOptions, language: string, variant: string): Promise<string>;

abstract requestCreatePlayground(name: string, model: ModelInfo): Promise<string>;
abstract requestCreatePlayground(name: string, model: ModelInfo, systemPrompt: string): Promise<string>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
abstract requestCreatePlayground(name: string, model: ModelInfo, systemPrompt: string): Promise<string>;
abstract requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string>;


abstract getPlaygroundsV2(): Promise<PlaygroundV2[]>;

Expand Down
9 changes: 9 additions & 0 deletions packages/shared/src/models/IPlaygroundMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ export interface AssistantChat extends ChatMessage {
completed?: number;
}

export interface SystemPrompt extends ChatMessage {
role: 'system';
content: string;
}

export interface PendingChat extends AssistantChat {
completed: undefined;
choices: Choice[];
Expand Down Expand Up @@ -60,3 +65,7 @@ export function isUserChat(msg: ChatMessage): msg is UserChat {
export function isPendingChat(msg: ChatMessage): msg is PendingChat {
return isAssistantChat(msg) && !msg.completed;
}

export function isSystemPrompt(msg: ChatMessage): msg is SystemPrompt {
return msg.role === 'system';
}