Skip to content

Commit

Permalink
refactor: extract ApproachBase class
Browse files Browse the repository at this point in the history
  • Loading branch information
sinedied committed Sep 1, 2023
1 parent 4983a2c commit 910b314
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 130 deletions.
126 changes: 126 additions & 0 deletions packages/api/src/lib/approaches/approach-base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import { SearchClient } from '@azure/search-documents';
import { OpenAiClients } from '../../plugins/openai.js';
import { removeNewlines } from '../util/index.js';

/**
 * Outcome of a document search, as returned by `ApproachBase.searchDocuments`.
 */
export interface SearchDocumentsResult {
  /** Text query that was sent to the search index (empty when the text query was dropped). */
  query: string;
  /** One formatted `<source page>: <text>` entry per search hit. */
  results: Array<string>;
  /** All entries of `results` joined with newline characters. */
  content: string;
}

/**
 * Base class holding the search and OpenAI clients together with the
 * document-retrieval helpers shared by the concrete approaches.
 */
export class ApproachBase {
  /**
   * @param search Search client used to query the document index.
   * @param openai Provider of the OpenAI clients (used here for embeddings).
   * @param chatGptModel Name of the chat model; stored for subclasses, not used in this class.
   * @param sourcePageField Name of the document field holding the source page reference.
   * @param contentField Name of the document field holding the text content.
   */
  constructor(
    protected search: SearchClient<any>,
    protected openai: OpenAiClients,
    protected chatGptModel: string,
    protected sourcePageField: string,
    protected contentField: string,
  ) {}

  /**
   * Searches the index using text, vectors or both, and formats every hit as
   * `<source page>: <text>`.
   *
   * Recognised `overrides` keys:
   * - `retrieval_mode`: `'text'`, `'vectors'` or `'hybrid'`; when unset both text and vectors are used.
   * - `semantic_ranker`: enables the semantic query type (only when text is used).
   * - `use_semantic_caption`: formats hits from their captions instead of their content (only when text is used).
   * - `top`: number of hits to request, defaults to 3.
   * - `exclude_category`: category value excluded through the search filter.
   *
   * @param query Search query text.
   * @param overrides Optional search settings, see above.
   * @returns The text query sent, the formatted hits, and the hits joined with newlines.
   */
  protected async searchDocuments(query?: string, overrides: Record<string, any> = {}): Promise<SearchDocumentsResult> {
    const hasText = ['text', 'hybrid', undefined].includes(overrides?.retrieval_mode);
    const hasVectors = ['vectors', 'hybrid', undefined].includes(overrides?.retrieval_mode);
    const useSemanticCaption = Boolean(overrides?.use_semantic_caption) && hasText;
    const top = overrides?.top ? Number(overrides?.top) : 3;
    const excludeCategory: string | undefined = overrides?.exclude_category;
    // Every single quote must be doubled: a string pattern passed to replace()
    // only touches the first match, which left later quotes unescaped.
    const filter = excludeCategory ? `category ne '${excludeCategory.replace(/'/g, "''")}'` : undefined;

    // If retrieval mode includes vectors, compute an embedding for the query
    let queryVector;
    if (hasVectors) {
      const openAiEmbeddings = await this.openai.getEmbeddings();
      const result = await openAiEmbeddings.create({
        model: 'text-embedding-ada-002',
        // NOTE(review): `query` is optional in the signature but is required here; confirm callers always pass it.
        input: query!,
      });
      queryVector = result.data[0].embedding;
    }

    // Only keep the text query if the retrieval mode uses text, otherwise drop it
    const queryText = hasText ? query : '';

    // Use semantic L2 reranker if requested and if retrieval mode is text or hybrid (vectors + text)
    let searchResults;
    if (overrides?.semantic_ranker && hasText) {
      searchResults = await this.search.search(queryText, {
        filter,
        queryType: 'semantic',
        queryLanguage: 'en-us',
        speller: 'lexicon',
        semanticConfiguration: 'default',
        top,
        captions: useSemanticCaption ? 'extractive|highlight-false' : undefined,
        vectors: [
          {
            value: queryVector,
            kNearestNeighborsCount: queryVector ? 50 : undefined,
            fields: queryVector ? ['embedding'] : undefined,
          },
        ],
      });
    } else {
      searchResults = await this.search.search(queryText, {
        filter,
        top,
        vectors: [
          {
            value: queryVector,
            kNearestNeighborsCount: queryVector ? 50 : undefined,
            fields: queryVector ? ['embedding'] : undefined,
          },
        ],
      });
    }

    const results: string[] = [];
    for await (const result of searchResults.results) {
      // TODO: ensure typings
      const hit = result as any;
      // Document fields live under `document`; fall back to the hit itself so a
      // flat result shape keeps working.
      const doc = hit.document ?? hit;
      let text = doc[this.contentField];
      if (useSemanticCaption) {
        // Accept captions under either `captions` or the raw `@search.captions` key.
        // Captions are only requested together with the semantic ranker, so they
        // may be missing: in that case keep the document content instead of throwing.
        const captions = hit.captions ?? hit['@search.captions'];
        if (Array.isArray(captions) && captions.length > 0) {
          text = captions.map((c: any) => c.text).join(' . ');
        }
      }
      results.push(`${doc[this.sourcePageField]}: ${removeNewlines(text)}`);
    }

    const content = results.join('\n');
    return {
      query: queryText ?? '',
      results,
      content,
    };
  }

  /**
   * Runs a semantic search requesting a single hit and one extractive answer.
   *
   * @param query Search query text.
   * @returns The text of the first answer when there is one; otherwise the
   * content field of the returned documents joined with newlines; `undefined`
   * when the search reports no match.
   */
  protected async lookupDocument(query: string): Promise<any> {
    const searchResults = await this.search.search(query, {
      top: 1,
      includeTotalCount: true,
      queryType: 'semantic',
      queryLanguage: 'en-us',
      speller: 'lexicon',
      semanticConfiguration: 'default',
      answers: 'extractive|count-1',
      captions: 'extractive|highlight-false',
    });

    const answers = await searchResults.answers;
    if (answers && answers.length > 0) {
      return answers[0].text;
    }
    // Parenthesised on purpose: `count ?? 0 > 0` parses as `count ?? (0 > 0)`.
    if ((searchResults.count ?? 0) > 0) {
      const results = [];
      for await (const result of searchResults.results) {
        // TODO: ensure typings
        const doc = result.document as any;
        results.push(doc[this.contentField]);
      }
      return results.join('\n');
    }
    return undefined;
  }
}
126 changes: 0 additions & 126 deletions packages/api/src/lib/approaches/approach.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import { SearchClient } from '@azure/search-documents';
import { OpenAiClients } from '../../plugins/openai.js';
import { removeNewlines } from '../util/index.js';
import { HistoryMessage } from '../message';

export interface ApproachResponse {
Expand All @@ -9,133 +6,10 @@ export interface ApproachResponse {
thoughts: string;
}

/**
 * Outcome of a document search, as returned by `ApproachBase.searchDocuments`.
 */
export interface SearchDocumentsResult {
  /** Text query that was sent to the search index (empty when the text query was dropped). */
  query: string;
  /** One formatted `<source page>: <text>` entry per search hit. */
  results: Array<string>;
  /** All entries of `results` joined with newline characters. */
  content: string;
}

/**
 * Contract for an approach driven by a conversation history.
 */
export interface ChatApproach {
/**
 * @param history Messages of the conversation so far.
 * @param overrides Free-form settings passed to the approach.
 * @returns A promise resolving to the approach response.
 */
run(history: HistoryMessage[], overrides: Record<string, any>): Promise<ApproachResponse>;
}

/**
 * Contract for an approach driven by a single query string.
 */
export interface AskApproach {
/**
 * @param query Query text to process.
 * @param overrides Free-form settings passed to the approach.
 * @returns A promise resolving to the approach response.
 */
run(query: string, overrides: Record<string, any>): Promise<ApproachResponse>;
}

/**
 * Base class holding the search and OpenAI clients together with the
 * document-retrieval helpers shared by the concrete approaches.
 */
export class ApproachBase {
  /**
   * @param search Search client used to query the document index.
   * @param openai Provider of the OpenAI clients (used here for embeddings).
   * @param chatGptModel Name of the chat model; stored for subclasses, not used in this class.
   * @param sourcePageField Name of the document field holding the source page reference.
   * @param contentField Name of the document field holding the text content.
   */
  constructor(
    protected search: SearchClient<any>,
    protected openai: OpenAiClients,
    protected chatGptModel: string,
    protected sourcePageField: string,
    protected contentField: string,
  ) {}

  /**
   * Searches the index using text, vectors or both, and formats every hit as
   * `<source page>: <text>`.
   *
   * Recognised `overrides` keys:
   * - `retrieval_mode`: `'text'`, `'vectors'` or `'hybrid'`; when unset both text and vectors are used.
   * - `semantic_ranker`: enables the semantic query type (only when text is used).
   * - `use_semantic_caption`: formats hits from their captions instead of their content (only when text is used).
   * - `top`: number of hits to request, defaults to 3.
   * - `exclude_category`: category value excluded through the search filter.
   *
   * @param query Search query text.
   * @param overrides Optional search settings, see above.
   * @returns The text query sent, the formatted hits, and the hits joined with newlines.
   */
  protected async searchDocuments(query?: string, overrides: Record<string, any> = {}): Promise<SearchDocumentsResult> {
    const hasText = ['text', 'hybrid', undefined].includes(overrides?.retrieval_mode);
    const hasVectors = ['vectors', 'hybrid', undefined].includes(overrides?.retrieval_mode);
    const useSemanticCaption = Boolean(overrides?.use_semantic_caption) && hasText;
    const top = overrides?.top ? Number(overrides?.top) : 3;
    const excludeCategory: string | undefined = overrides?.exclude_category;
    // Every single quote must be doubled: a string pattern passed to replace()
    // only touches the first match, which left later quotes unescaped.
    const filter = excludeCategory ? `category ne '${excludeCategory.replace(/'/g, "''")}'` : undefined;

    // If retrieval mode includes vectors, compute an embedding for the query
    let queryVector;
    if (hasVectors) {
      const openAiEmbeddings = await this.openai.getEmbeddings();
      const result = await openAiEmbeddings.create({
        model: 'text-embedding-ada-002',
        // NOTE(review): `query` is optional in the signature but is required here; confirm callers always pass it.
        input: query!,
      });
      queryVector = result.data[0].embedding;
    }

    // Only keep the text query if the retrieval mode uses text, otherwise drop it
    const queryText = hasText ? query : '';

    // Use semantic L2 reranker if requested and if retrieval mode is text or hybrid (vectors + text)
    let searchResults;
    if (overrides?.semantic_ranker && hasText) {
      searchResults = await this.search.search(queryText, {
        filter,
        queryType: 'semantic',
        queryLanguage: 'en-us',
        speller: 'lexicon',
        semanticConfiguration: 'default',
        top,
        captions: useSemanticCaption ? 'extractive|highlight-false' : undefined,
        vectors: [
          {
            value: queryVector,
            kNearestNeighborsCount: queryVector ? 50 : undefined,
            fields: queryVector ? ['embedding'] : undefined,
          },
        ],
      });
    } else {
      searchResults = await this.search.search(queryText, {
        filter,
        top,
        vectors: [
          {
            value: queryVector,
            kNearestNeighborsCount: queryVector ? 50 : undefined,
            fields: queryVector ? ['embedding'] : undefined,
          },
        ],
      });
    }

    const results: string[] = [];
    for await (const result of searchResults.results) {
      // TODO: ensure typings
      const hit = result as any;
      // Document fields live under `document`; fall back to the hit itself so a
      // flat result shape keeps working.
      const doc = hit.document ?? hit;
      let text = doc[this.contentField];
      if (useSemanticCaption) {
        // Accept captions under either `captions` or the raw `@search.captions` key.
        // Captions are only requested together with the semantic ranker, so they
        // may be missing: in that case keep the document content instead of throwing.
        const captions = hit.captions ?? hit['@search.captions'];
        if (Array.isArray(captions) && captions.length > 0) {
          text = captions.map((c: any) => c.text).join(' . ');
        }
      }
      results.push(`${doc[this.sourcePageField]}: ${removeNewlines(text)}`);
    }

    const content = results.join('\n');
    return {
      query: queryText ?? '',
      results,
      content,
    };
  }

  /**
   * Runs a semantic search requesting a single hit and one extractive answer.
   *
   * @param query Search query text.
   * @returns The text of the first answer when there is one; otherwise the
   * content field of the returned documents joined with newlines; `undefined`
   * when the search reports no match.
   */
  protected async lookupDocument(query: string): Promise<any> {
    const searchResults = await this.search.search(query, {
      top: 1,
      includeTotalCount: true,
      queryType: 'semantic',
      queryLanguage: 'en-us',
      speller: 'lexicon',
      semanticConfiguration: 'default',
      answers: 'extractive|count-1',
      captions: 'extractive|highlight-false',
    });

    const answers = await searchResults.answers;
    if (answers && answers.length > 0) {
      return answers[0].text;
    }
    // Parenthesised on purpose: `count ?? 0 > 0` parses as `count ?? (0 > 0)`.
    if ((searchResults.count ?? 0) > 0) {
      const results = [];
      for await (const result of searchResults.results) {
        // TODO: ensure typings
        const doc = result.document as any;
        results.push(doc[this.contentField]);
      }
      return results.join('\n');
    }
    return undefined;
  }
}
5 changes: 3 additions & 2 deletions packages/api/src/lib/approaches/ask-retrieve-then-read.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { SearchClient } from '@azure/search-documents';
import { OpenAiClients } from '../../plugins/openai.js';
import { MessageBuilder } from '../message-builder.js';
import { ApproachBase, AskApproach } from './approach.js';
import { messagesToString } from '../message.js';
import { MessageBuilder } from '../message-builder.js';
import { AskApproach } from './approach.js';
import { ApproachBase } from './approach-base.js';

const SYSTEM_CHAT_TEMPLATE = `You are an intelligent assistant helping Contoso Inc employees with their healthcare plan questions and employee handbook questions.
Use 'you' to refer to the individual asking the questions even if they ask with 'I'.
Expand Down
5 changes: 3 additions & 2 deletions packages/api/src/lib/approaches/chat-read-retrieve-read.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { SearchClient } from '@azure/search-documents';
import { ChatApproach, ApproachResponse, ApproachBase } from './approach.js';
import { OpenAiClients } from '../../plugins/openai.js';
import { ChatApproach, ApproachResponse } from './approach.js';
import { ApproachBase } from './approach-base.js';
import { HistoryMessage, Message, messagesToString } from '../message.js';
import { MessageBuilder } from '../message-builder.js';
import { getTokenLimit } from '../model-helpers.js';
import { HistoryMessage, Message, messagesToString } from '../message.js';

const SYSTEM_MESSAGE_CHAT_CONVERSATION = `Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.
Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
Expand Down

0 comments on commit 910b314

Please sign in to comment.