Skip to content

Commit

Permalink
Add another AI model
Browse files Browse the repository at this point in the history
  • Loading branch information
RikudouSage committed Apr 11, 2024
1 parent 84a9499 commit 7caf441
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/Enum/AiModel.php
Expand Up @@ -4,5 +4,6 @@

enum AiModel: string
{
case Mistral7BOpenHermes = 'OpenHermes-2.5-Mistral-7B';
case OpenHermesMistral7B = 'OpenHermes-2.5-Mistral-7B';
case Fimbulvetr11Bv2 = 'Fimbulvetr-11B-v2';
}
Expand Up @@ -38,6 +38,6 @@ public function formatOutput(string $message): Message

public function supports(AiModel $model): bool
{
return in_array($model, [AiModel::Mistral7BOpenHermes], true);
return in_array($model, [AiModel::OpenHermesMistral7B], true);
}
}
35 changes: 35 additions & 0 deletions src/Service/AiHorde/MessageFormatter/VicunaMessageFormatter.php
@@ -0,0 +1,35 @@
<?php

namespace App\Service\AiHorde\MessageFormatter;

use App\Enum\AiActor;
use App\Enum\AiModel;
use App\Service\AiHorde\Message\Message;
use App\Service\AiHorde\Message\MessageHistory;

final readonly class VicunaMessageFormatter implements MessageFormatter
{
public function getPrompt(MessageHistory $messages): string
{
$result = '';
foreach ($messages as $message) {
$result .= "\n" . strtoupper($message->role->value) . ': ' . $message->content;
}
$result .= "\nASSISTANT:";

return $result;
}

public function formatOutput(string $message): Message
{
return new Message(
role: AiActor::Assistant,
content: trim($message),
);
}

public function supports(AiModel $model): bool
{
return in_array($model, [AiModel::Fimbulvetr11Bv2], true);
}
}
8 changes: 7 additions & 1 deletion src/Service/Expression/ExpressionLanguageAiFunctions.php
Expand Up @@ -7,6 +7,7 @@
use App\Service\AiHorde\AiHorde;
use App\Service\AiHorde\Message\Message;
use App\Service\AiHorde\Message\MessageHistory;
use RuntimeException;
use Symfony\Component\ExpressionLanguage\ExpressionFunction;

final readonly class ExpressionLanguageAiFunctions extends AbstractExpressionLanguageFunctionProvider
Expand All @@ -33,7 +34,12 @@ private function aiAnalyzeFunction(array $context, string $message, ?string $sys
if ($systemPrompt !== null) {
$history[] = new Message(role: AiActor::System, content: $systemPrompt);
}
$models = array_filter([AiModel::OpenHermesMistral7B, AiModel::Fimbulvetr11Bv2], fn (AiModel $model) => count($this->aiHorde->findModels($model)));
if (!count($models)) {
throw new RuntimeException('There are no models online available to service your request.');
}
$model = $models[array_rand($models)];

return $this->aiHorde->getResponse($message, AiModel::Mistral7BOpenHermes, $history);
return $this->aiHorde->getResponse($message, $model, $history);
}
}

0 comments on commit 7caf441

Please sign in to comment.