Skip to content

Commit

Permalink
ask the system prompt from the Playground creation form (#643)
Browse files Browse the repository at this point in the history
* ask the system prompt from the Playground creation form

Signed-off-by: Philippe Martin <phmartin@redhat.com>

* pass empty string systemPrompt as undefined

Signed-off-by: Philippe Martin <phmartin@redhat.com>

* Update packages/backend/src/managers/playgroundV2Manager.ts

Co-authored-by: Jeff MAURY <jmaury@redhat.com>
Signed-off-by: Philippe Martin <feloy1@gmail.com>

---------

Signed-off-by: Philippe Martin <phmartin@redhat.com>
Signed-off-by: Philippe Martin <feloy1@gmail.com>
Co-authored-by: Jeff MAURY <jmaury@redhat.com>
  • Loading branch information
feloy and jeffmaury committed Mar 25, 2024
1 parent f300c31 commit 1613e13
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 91 deletions.
113 changes: 45 additions & 68 deletions packages/backend/src/managers/playgroundV2Manager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ test('submit should throw an error if the server is stopped', async () => {
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, '', 'tracking-1');

vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
Expand Down Expand Up @@ -107,7 +107,7 @@ test('submit should throw an error if the server is unhealthy', async () => {
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, '', 'tracking-1');
const playgroundId = manager.getPlaygrounds()[0].id;
await expect(manager.submit(playgroundId, 'dummyUserInput', '')).rejects.toThrowError(
'Inference server is not healthy, currently status: unhealthy.',
Expand All @@ -133,12 +133,42 @@ test('create playground should create conversation.', async () => {
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
expect(manager.getConversations().length).toBe(0);
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, '', 'tracking-1');

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
});

test('create playground called with a system prompt should create conversation with a system message.', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
expect(manager.getConversations().length).toBe(0);
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'a system prompt', 'tracking-1');

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
const conversation = conversations[0];
expect(conversation.messages).toHaveLength(1);
const systemMessage = conversation.messages[0];
expect(systemMessage.role).toEqual('system');
expect(systemMessage.content).toEqual('a system prompt');
});

test('valid submit should create IPlaygroundMessage and notify the webview', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
Expand Down Expand Up @@ -169,7 +199,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, undefined, 'tracking-1');

const date = new Date(2000, 1, 1, 13);
vi.setSystemTime(date);
Expand Down Expand Up @@ -208,65 +238,6 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
});
});

test.each(['', 'my system prompt'])(
'valid submit should send a message with system prompt if non empty, system prompt is "%s"}',
async (systemPrompt: string) => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
connection: {
port: 8888,
},
} as unknown as InferenceServer,
]);
const createMock = vi.fn().mockResolvedValue([]);
vi.mocked(OpenAI).mockReturnValue({
chat: {
completions: {
create: createMock,
},
},
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', systemPrompt);

const messages: unknown[] = [
{
content: 'dummyUserInput',
id: expect.any(String),
role: 'user',
timestamp: expect.any(Number),
},
];
if (systemPrompt) {
messages.push({
content: 'my system prompt',
role: 'system',
});
}
expect(createMock).toHaveBeenCalledWith({
messages,
model: 'dummyModelFile',
stream: true,
});
},
);

test('submit should send options', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
Expand Down Expand Up @@ -297,7 +268,7 @@ test('submit should send options', async () => {
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1');
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, undefined, 'tracking-1');

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', '', { temperature: 0.123, max_tokens: 45, top_p: 0.345 });
Expand Down Expand Up @@ -334,6 +305,7 @@ test('creating a new playground should send new playground to frontend', async (
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(webviewMock.postMessage).toHaveBeenCalledWith({
Expand All @@ -357,6 +329,7 @@ test('creating a new playground with no name should send new playground to front
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(webviewMock.postMessage).toHaveBeenCalledWith({
Expand All @@ -381,6 +354,7 @@ test('creating a new playground with no model served should start an inference s
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).toHaveBeenCalledWith(
Expand Down Expand Up @@ -417,6 +391,7 @@ test('creating a new playground with the model already served should not start a
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).not.toHaveBeenCalled();
Expand Down Expand Up @@ -445,6 +420,7 @@ test('creating a new playground with the model server stopped should start the i
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);
expect(createInferenceServerMock).not.toHaveBeenCalled();
Expand All @@ -462,6 +438,7 @@ test('delete conversation should delete the conversation', async () => {
id: 'model-1',
name: 'Model 1',
} as unknown as ModelInfo,
'',
'tracking-1',
);

Expand All @@ -484,9 +461,9 @@ test('requestCreatePlayground should call createPlayground and createTask, then
});
const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockResolvedValue('playground-1');

const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo);
const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo, '');

expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String));
expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, '', expect.any(String));
expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', {
trackingId: id,
});
Expand All @@ -513,9 +490,9 @@ test('requestCreatePlayground should call createPlayground and createTask, then
});
const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockRejectedValue(new Error('an error'));

const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo);
const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo, '');

expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String));
expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, '', expect.any(String));
expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', {
trackingId: id,
});
Expand Down
22 changes: 18 additions & 4 deletions packages/backend/src/managers/playgroundV2Manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import type { ChatCompletionChunk, ChatCompletionMessageParam } from 'openai/src
import type { ModelOptions } from '@shared/src/models/IModelOptions';
import type { Stream } from 'openai/streaming';
import { ConversationRegistry } from '../registries/conversationRegistry';
import type { Conversation, PendingChat, UserChat } from '@shared/src/models/IPlaygroundMessage';
import type { Conversation, PendingChat, SystemPrompt, UserChat } from '@shared/src/models/IPlaygroundMessage';
import type { PlaygroundV2 } from '@shared/src/models/IPlaygroundV2';
import { Publisher } from '../utils/Publisher';
import { Messages } from '@shared/Messages';
Expand Down Expand Up @@ -54,13 +54,13 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
this.notify();
}

async requestCreatePlayground(name: string, model: ModelInfo): Promise<string> {
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string> {
const trackingId: string = getRandomString();
const task = this.taskRegistry.createTask('Creating Playground environment', 'loading', {
trackingId: trackingId,
});

this.createPlayground(name, model, trackingId)
this.createPlayground(name, model, systemPrompt, trackingId)
.then((playgroundId: string) => {
this.taskRegistry.updateTask({
...task,
Expand Down Expand Up @@ -94,7 +94,12 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
return trackingId;
}

async createPlayground(name: string, model: ModelInfo, trackingId: string): Promise<string> {
async createPlayground(
name: string,
model: ModelInfo,
systemPrompt: string | undefined,
trackingId: string,
): Promise<string> {
const id = `${this.#playgroundCounter++}`;

if (!name) {
Expand All @@ -103,6 +108,15 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di

this.#conversationRegistry.createConversation(id);

if (systemPrompt) {
this.#conversationRegistry.submit(id, {
content: systemPrompt,
role: 'system',
id: this.getUniqueId(),
timestamp: Date.now(),
} as SystemPrompt);
}

// create/start inference server if necessary
const servers = this.inferenceManager.getServers();
const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id));
Expand Down
4 changes: 2 additions & 2 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ export class StudioApiImpl implements StudioAPI {
});
}

async requestCreatePlayground(name: string, model: ModelInfo): Promise<string> {
async requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string> {
try {
return this.playgroundV2.requestCreatePlayground(name, model);
return this.playgroundV2.requestCreatePlayground(name, model, systemPrompt);
} catch (err: unknown) {
console.error('Something went wrong while trying to create playground environment', err);
throw err;
Expand Down
20 changes: 5 additions & 15 deletions packages/frontend/src/pages/Playground.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
isPendingChat,
isUserChat,
type AssistantChat,
isSystemPrompt,
} from '@shared/src/models/IPlaygroundMessage';
import NavPage from '../lib/NavPage.svelte';
import { playgrounds } from '../stores/playgrounds-v2';
Expand Down Expand Up @@ -46,7 +47,7 @@ $: {
}
const roleNames = {
system: 'System',
system: 'System prompt',
user: 'User',
assistant: 'Assistant',
};
Expand All @@ -61,9 +62,8 @@ function getMessageParagraphs(message: ChatMessage): string[] {
.join('')
.split('\n');
}
} else if (isUserChat(message)) {
const msg = message as UserChat;
return msg.content?.split('\n') ?? [];
} else if (isUserChat(message) || isSystemPrompt(message)) {
return message.content?.split('\n') ?? [];
}
return [];
}
Expand Down Expand Up @@ -130,6 +130,7 @@ function elapsedTime(msg: AssistantChat): string {
<div
class="p-4 rounded-md"
class:bg-charcoal-400="{isUserChat(message)}"
class:bg-charcoal-800="{isSystemPrompt(message)}"
class:bg-charcoal-900="{isAssistantChat(message)}"
class:ml-8="{isAssistantChat(message)}"
class:mr-8="{isUserChat(message)}">
Expand All @@ -152,17 +153,6 @@ function elapsedTime(msg: AssistantChat): string {
</svelte:fragment>
<svelte:fragment slot="details">
<div class="text-gray-800 text-xs">Next prompt will use these settings</div>
<div class="bg-charcoal-600 w-full rounded-md text-xs p-4">
<div class="mb-4">System Prompt</div>
<div class="w-full">
<textarea
bind:value="{systemPrompt}"
class="p-2 w-full outline-none bg-charcoal-500 rounded-sm text-gray-700 placeholder-gray-700"
rows="4"
placeholder="Provide system prompt to define general context, instructions or guidelines to be used with each query"
></textarea>
</div>
</div>
<div class="bg-charcoal-600 w-full rounded-md text-xs p-4">
<div class="mb-4 flex flex-col">Model Parameters</div>
<div class="flex flex-col space-y-4">
Expand Down
13 changes: 12 additions & 1 deletion packages/frontend/src/pages/PlaygroundCreate.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ let localModels: ModelInfo[];
$: localModels = $modelsInfo.filter(model => model.file);
$: availModels = $modelsInfo.filter(model => !model.file);
let modelId: string | undefined = undefined;
let systemPrompt: string | undefined = undefined;
let submitted: boolean = false;
let playgroundName: string;
Expand Down Expand Up @@ -55,7 +56,8 @@ async function submit() {
// disable submit button
submitted = true;
try {
trackingId = await studioClient.requestCreatePlayground(playgroundName, model);
// Using || and not && as we want to have the empty string systemPrompt passed as undefined
trackingId = await studioClient.requestCreatePlayground(playgroundName, model, systemPrompt || undefined);
} catch (err: unknown) {
trackingId = undefined;
console.error('Something wrong while trying to create the playground.', err);
Expand Down Expand Up @@ -161,6 +163,15 @@ onDestroy(() => {
</div>
</div>
{/if}

<label for="model" class="pt-4 block mb-2 text-sm font-bold text-gray-400">System prompt</label>
<textarea
aria-label="system-prompt-textarea"
bind:value="{systemPrompt}"
class="w-full p-2 outline-none text-sm bg-charcoal-600 rounded-sm text-gray-700 placeholder-gray-700"
rows="4"
placeholder="Optionally provide system prompt to define general context, instructions or guidelines to be used with each query"
></textarea>
</div>
<footer>
<div class="w-full flex flex-col">
Expand Down
2 changes: 1 addition & 1 deletion packages/shared/src/StudioAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ export abstract class StudioAPI {
*/
abstract createSnippet(options: RequestOptions, language: string, variant: string): Promise<string>;

abstract requestCreatePlayground(name: string, model: ModelInfo): Promise<string>;
abstract requestCreatePlayground(name: string, model: ModelInfo, systemPrompt?: string): Promise<string>;

abstract getPlaygroundsV2(): Promise<PlaygroundV2[]>;

Expand Down

0 comments on commit 1613e13

Please sign in to comment.