Skip to content

Commit

Permalink
feat: system prompt for playground page
Browse files Browse the repository at this point in the history
Signed-off-by: Philippe Martin <phmartin@redhat.com>
  • Loading branch information
feloy committed Mar 20, 2024
1 parent 0bab1bf commit ecd1e9d
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 32 deletions.
6 changes: 3 additions & 3 deletions packages/backend/src/managers/playgroundV2Manager.spec.ts
Expand Up @@ -80,7 +80,7 @@ test('submit should throw an error if the server is stopped', async () => {
} as unknown as InferenceServer,
]);

await expect(manager.submit('0', 'dummyUserInput')).rejects.toThrowError('Inference server is not running.');
await expect(manager.submit('0', 'dummyUserInput', '')).rejects.toThrowError('Inference server is not running.');
});

test('submit should throw an error if the server is unhealthy', async () => {
Expand All @@ -100,7 +100,7 @@ test('submit should throw an error if the server is unhealthy', async () => {
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo);
const playgroundId = manager.getPlaygrounds()[0].id;
await expect(manager.submit(playgroundId, 'dummyUserInput')).rejects.toThrowError(
await expect(manager.submit(playgroundId, 'dummyUserInput', '')).rejects.toThrowError(
'Inference server is not healthy, currently status: unhealthy.',
);
});
Expand Down Expand Up @@ -166,7 +166,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
vi.setSystemTime(date);

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput');
await manager.submit(playgrounds[0].id, 'dummyUserInput', '');

// Wait for assistant message to be completed
await vi.waitFor(() => {
Expand Down
8 changes: 6 additions & 2 deletions packages/backend/src/managers/playgroundV2Manager.ts
Expand Up @@ -87,7 +87,7 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
* @param userInput the user input
* @param options the model configuration
*/
async submit(playgroundId: string, userInput: string, options?: ModelOptions): Promise<void> {
async submit(playgroundId: string, userInput: string, systemPrompt: string, options?: ModelOptions): Promise<void> {
const playground = this.#playgrounds.get(playgroundId);
if (playground === undefined) throw new Error('Playground not found.');

Expand Down Expand Up @@ -122,8 +122,12 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
apiKey: 'dummy',
});

const messages = this.getFormattedMessages(playground.id);
if (systemPrompt) {
messages.push({ role: 'system', content: systemPrompt });
}
const response = await client.chat.completions.create({
messages: this.getFormattedMessages(playground.id),
messages,
stream: true,
model: modelInfo.file.file,
...options,
Expand Down
9 changes: 7 additions & 2 deletions packages/backend/src/studio-api-impl.ts
Expand Up @@ -72,8 +72,13 @@ export class StudioApiImpl implements StudioAPI {
return this.playgroundV2.getPlaygrounds();
}

submitPlaygroundMessage(containerId: string, userInput: string, options?: ModelOptions): Promise<void> {
return this.playgroundV2.submit(containerId, userInput, options);
submitPlaygroundMessage(
containerId: string,
userInput: string,
systemPrompt: string,
options?: ModelOptions,
): Promise<void> {
return this.playgroundV2.submit(containerId, userInput, systemPrompt, options);
}

async getPlaygroundConversations(): Promise<Conversation[]> {
Expand Down
78 changes: 54 additions & 24 deletions packages/frontend/src/pages/Playground.svelte
Expand Up @@ -14,6 +14,7 @@ import { playgrounds } from '../stores/playgrounds-v2';
import { catalog } from '../stores/catalog';
import Button from '../lib/button/Button.svelte';
import { afterUpdate } from 'svelte';
import ContentDetailsLayout from '../lib/ContentDetailsLayout.svelte';
export let playgroundId: string;
let prompt: string;
Expand All @@ -22,6 +23,9 @@ let scrollable: Element;
let lastIsUserMessage = false;
let errorMsg = '';
// settings
let systemPrompt: string = '';
$: conversation = $conversations.find(conversation => conversation.id === playgroundId);
$: playground = $playgrounds.find(playground => playground.id === playgroundId);
$: model = $catalog.models.find(model => model.id === playground?.modelId);
Expand Down Expand Up @@ -63,7 +67,7 @@ function getMessageParagraphs(message: ChatMessage): string[] {
function askPlayground() {
errorMsg = '';
sendEnabled = false;
studioClient.submitPlaygroundMessage(playgroundId, prompt).catch((err: unknown) => {
studioClient.submitPlaygroundMessage(playgroundId, prompt, systemPrompt, {}).catch((err: unknown) => {
errorMsg = String(err);
sendEnabled = true;
});
Expand Down Expand Up @@ -101,30 +105,56 @@ function elapsedTime(msg: AssistantChat): string {
<svelte:fragment slot="subtitle">{model?.name}</svelte:fragment>
<svelte:fragment slot="content">
<div class="flex flex-col w-full h-full">
<div bind:this="{scrollable}" aria-label="conversation" class="w-full h-full overflow-auto">
{#if conversation?.messages}
<ul class="p-4">
{#each conversation?.messages as message}
<li class="m-4">
<div class="text-lg" class:text-right="{isAssistantChat(message)}">{roleNames[message.role]}</div>
<div
class="p-4 rounded-md"
class:bg-charcoal-400="{isUserChat(message)}"
class:bg-charcoal-900="{isAssistantChat(message)}"
class:ml-8="{isAssistantChat(message)}"
class:mr-8="{isUserChat(message)}">
{#each getMessageParagraphs(message) as paragraph}
<p>{paragraph}</p>
{/each}
</div>
{#if isAssistantChat(message)}
<div class="text-sm text-gray-400 text-right" aria-label="elapsed">{elapsedTime(message)} s</div>
<div class="h-full overflow-auto" bind:this="{scrollable}">
<ContentDetailsLayout detailsTitle="Settings" detailsLabel="settings">
<svelte:fragment slot="content">
<div class="flex flex-col w-full h-full">
<div aria-label="conversation" class="w-full h-full">
{#if conversation?.messages}
<ul class="p-4">
{#each conversation?.messages as message}
<li class="m-4">
<div class="text-lg" class:text-right="{isAssistantChat(message)}">
{roleNames[message.role]}
</div>
<div
class="p-4 rounded-md"
class:bg-charcoal-400="{isUserChat(message)}"
class:bg-charcoal-900="{isAssistantChat(message)}"
class:ml-8="{isAssistantChat(message)}"
class:mr-8="{isUserChat(message)}">
{#each getMessageParagraphs(message) as paragraph}
<p>{paragraph}</p>
{/each}
</div>
{#if isAssistantChat(message)}
<div class="text-sm text-gray-400 text-right" aria-label="elapsed">
{elapsedTime(message)} s
</div>
{/if}
<div></div>
</li>
{/each}
</ul>
{/if}
<div></div>
</li>
{/each}
</ul>
{/if}
</div>
</div>
</svelte:fragment>
<svelte:fragment slot="details">
<div class="text-gray-800 text-xs">Next prompt will use these settings</div>
<div class="bg-charcoal-600 w-full rounded-md text-xs p-4">
<div class="mb-4">System Prompt</div>
<div class="w-full">
<textarea
bind:value="{systemPrompt}"
class="w-full outline-none bg-charcoal-500 rounded-sm text-gray-700 placeholder-gray-700"
rows="4"
placeholder="Provide system prompt to define general context, instructions or guidelines to be used with each query"
></textarea>
</div>
</div>
</svelte:fragment>
</ContentDetailsLayout>
</div>
{#if errorMsg}
<div class="text-red-500 text-sm p-2">{errorMsg}</div>
Expand Down
7 changes: 6 additions & 1 deletion packages/shared/src/StudioAPI.ts
Expand Up @@ -137,7 +137,12 @@ export abstract class StudioAPI {
* @param userInput the user input, e.g. 'What is the capital of France ?'
* @param options the options for the model, e.g. temperature
*/
abstract submitPlaygroundMessage(containerId: string, userInput: string, options?: ModelOptions): Promise<void>;
abstract submitPlaygroundMessage(
containerId: string,
userInput: string,
systemPrompt: string,
options?: ModelOptions,
): Promise<void>;

/**
* Return the conversations
Expand Down

0 comments on commit ecd1e9d

Please sign in to comment.