Skip to content

Commit

Permalink
Feat: Add AI analyzing of comments (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
RikudouSage committed Apr 11, 2024
1 parent d614dad commit c762c0f
Show file tree
Hide file tree
Showing 17 changed files with 477 additions and 17 deletions.
1 change: 1 addition & 0 deletions .env
Expand Up @@ -46,6 +46,7 @@ MATRIX_INSTANCE= # for example lemmings.world, can be left empty if it's same as
#### api keys
SLACK_BOT_TOKEN=
MATRIX_API_TOKEN=
AI_HORDE_API_KEY=

#### other settings
USE_LEMMYVERSE_LINK_SLACK=0
Expand Down
10 changes: 7 additions & 3 deletions config/services.yaml
Expand Up @@ -27,6 +27,7 @@ parameters:

app.fediseer.api: '%env(FEDISEER_API_URL)%'
app.fediseer.key: '%env(FEDISEER_API_KEY)%'
app.ai_horde.api_key: '%env(AI_HORDE_API_KEY)%'

app.image_check.regex: '%env(IMAGE_CHECK_REGEX)%'

Expand Down Expand Up @@ -88,7 +89,10 @@ services:
arguments:
$removalLogValidity: '@app.log_validity'

App\Service\ExpressionLanguage:
App\Service\Expression\ExpressionLanguage:
calls:
- registerProvider: ['@App\Service\ExpressionLanguageFunctions']
Symfony\Component\ExpressionLanguage\ExpressionLanguage: '@App\Service\ExpressionLanguage'
- registerProvider: ['@App\Service\Expression\ExpressionLanguageFunctions']
- registerProvider: ['@App\Service\Expression\ExpressionLanguageAiFunctions']
- registerProvider: ['@App\Service\Expression\ExpressionLanguageStringFunctions']
- registerProvider: ['@App\Service\Expression\ExpressionLanguageLemmyFunctions']
Symfony\Component\ExpressionLanguage\ExpressionLanguage: '@App\Service\Expression\ExpressionLanguage'
2 changes: 1 addition & 1 deletion src/Automod/ModAction/ComplexRuleAction.php
Expand Up @@ -9,7 +9,7 @@
use App\Enum\FurtherAction;
use App\Enum\RunConfiguration;
use App\Repository\ComplexRuleRepository;
use App\Service\ExpressionLanguage;
use App\Service\Expression\ExpressionLanguage;
use LogicException;
use Rikudou\LemmyApi\LemmyApi;
use Rikudou\LemmyApi\Response\Model\Person;
Expand Down
10 changes: 10 additions & 0 deletions src/Enum/AiActor.php
@@ -0,0 +1,10 @@
<?php

namespace App\Enum;

enum AiActor: string
{
case System = 'system';
case User = 'user';
case Assistant = 'assistant';
}
8 changes: 8 additions & 0 deletions src/Enum/AiModel.php
@@ -0,0 +1,8 @@
<?php

namespace App\Enum;

enum AiModel: string
{
case Mistral7BOpenHermes = 'OpenHermes-2.5-Mistral-7B';
}
21 changes: 21 additions & 0 deletions src/MessageHandler/RunExpressionAsyncHandler.php
@@ -0,0 +1,21 @@
<?php

namespace App\MessageHandler;

use App\Message\RunExpressionAsyncMessage;
use App\Service\Expression\ExpressionLanguage;
use Symfony\Component\Messenger\Attribute\AsMessageHandler;

#[AsMessageHandler]
final readonly class RunExpressionAsyncHandler
{
public function __construct(
private ExpressionLanguage $expressionLanguage,
) {
}

public function __invoke(RunExpressionAsyncMessage $message): void
{
$this->expressionLanguage->evaluate($message->expression, $message->context);
}
}
134 changes: 134 additions & 0 deletions src/Service/AiHorde/AiHorde.php
@@ -0,0 +1,134 @@
<?php

namespace App\Service\AiHorde;

use App\Enum\AiActor;
use App\Enum\AiModel;
use App\Service\AiHorde\Message\Message;
use App\Service\AiHorde\Message\MessageHistory;
use App\Service\AiHorde\MessageFormatter\MessageFormatter;
use LogicException;
use Symfony\Component\DependencyInjection\Attribute\Autowire;
use Symfony\Component\DependencyInjection\Attribute\TaggedIterator;
use Symfony\Component\HttpFoundation\Request;
use Symfony\Contracts\HttpClient\HttpClientInterface;

final readonly class AiHorde
{
public function __construct(
private HttpClientInterface $httpClient,
#[TaggedIterator('app.message_formatter')]
private iterable $formatters,
#[Autowire('%app.ai_horde.api_key%')]
private string $apiKey,
) {
}

public function getResponse(
string $message,
AiModel $model,
MessageHistory $history = new MessageHistory(),
): string {
if (!$this->apiKey) {
throw new LogicException('There is no api key set, cannot use AI actions');
}

$models = $this->findModels($model);
if (!count($models)) {
throw new LogicException('There was an error while looking for available models - no model able to handle your message seems to be online. Please try again later.');
}
$formatter = $this->findFormatter($model) ?? throw new LogicException("Could not find formatter for {$model->value}");
[$maxLength, $maxContextLength] = $this->getMaxLength($model);

$response = $this->httpClient->request(Request::METHOD_POST, 'https://aihorde.net/api/v2/generate/text/async', [
'json' => [
'prompt' => $formatter->getPrompt(new MessageHistory(
...[...$history, new Message(role: AiActor::User, content: $message)],
)),
'params' => [
'max_length' => $maxLength,
'max_context_length' => $maxContextLength,
],
'models' => $models,
],
'headers' => [
'apikey' => $this->apiKey,
],
]);
$json = json_decode($response->getContent(), true, flags: JSON_THROW_ON_ERROR);
$jobId = $json['id'];

do {
$response = $this->httpClient->request(Request::METHOD_GET, "https://aihorde.net/api/v2/generate/text/status/{$jobId}", [
'headers' => [
'apikey' => $this->apiKey,
],
]);
$json = json_decode($response->getContent(), true, flags: JSON_THROW_ON_ERROR);
if (!$json['done']) {
sleep(1);
}
} while (!$json['done']);

if (!isset($json['generations'][0])) {
throw new LogicException('Missing generations output');
}

$output = $formatter->formatOutput($json['generations'][0]['text']);

return $output->content;
}

/**
* @return array<string>
*/
public function findModels(AiModel $model): array
{
$response = $this->httpClient->request(Request::METHOD_GET, 'https://aihorde.net/api/v2/status/models?type=text');
$json = json_decode($response->getContent(), true, flags: JSON_THROW_ON_ERROR);

return array_values(array_map(
fn (array $modelData) => $modelData['name'],
array_filter($json, fn (array $modelData) => fnmatch("*/{$model->value}", $modelData['name'])),
));
}

private function findFormatter(AiModel $model): ?MessageFormatter
{
foreach ($this->formatters as $formatter) {
if ($formatter->supports($model)) {
return $formatter;
}
}

return null;
}

private function getMaxLength(AiModel $model): array
{
$response = $this->httpClient->request(Request::METHOD_GET, 'https://aihorde.net/api/v2/workers?type=text');
$json = json_decode($response->getContent(), true, flags: JSON_THROW_ON_ERROR);
$workers = array_filter(
$json,
fn (array $worker) => count(array_filter(
$worker['models'],
fn (string $modelName) => fnmatch("*/{$model->value}", $modelName),
)) > 0,
);
$targetLength = 1024;
$targetContext = 2048;

if (!count(array_filter($workers, fn(array $worker) => $worker['max_length'] >= $targetLength))) {
$targetLength = max(array_map(fn (array $worker) => $worker['max_length'], $workers));
}
if (!count(array_filter($workers, fn(array $worker) => $worker['max_context_length'] >= $targetContext))) {
$targetContext = max(array_map(fn (array $worker) => $worker['max_context_length'], $workers));
}

if ($targetLength > $targetContext / 2) {
$targetLength = $targetContext / 2;
}

return [$targetLength, $targetContext];
}
}
25 changes: 25 additions & 0 deletions src/Service/AiHorde/Message/Message.php
@@ -0,0 +1,25 @@
<?php

namespace App\Service\AiHorde\Message;

use App\Enum\AiActor;

final class Message implements \JsonSerializable
{
public function __construct(
public AiActor $role,
public string $content,
) {
}

/**
* @return array{role: string, content: string}
*/
public function jsonSerialize(): array
{
return [
'role' => $this->role->value,
'content' => $this->content,
];
}
}
73 changes: 73 additions & 0 deletions src/Service/AiHorde/Message/MessageHistory.php
@@ -0,0 +1,73 @@
<?php

namespace App\Service\AiHorde\Message;

use ArrayAccess;
use ArrayIterator;
use Countable;
use InvalidArgumentException;
use IteratorAggregate;
use JsonSerializable;
use Traversable;

/**
* @implements IteratorAggregate<int, Message>
* @implements ArrayAccess<int, Message>
*/
final class MessageHistory implements IteratorAggregate, ArrayAccess, Countable, JsonSerializable
{
/**
* @var array<Message>
*/
private array $messages;

public function __construct(Message ...$messages)
{
$this->messages = $messages;
}

public function getIterator(): Traversable
{
return new ArrayIterator($this->messages);
}

public function offsetExists(mixed $offset): bool
{
return isset($this->messages[$offset]);
}

public function offsetGet(mixed $offset): Message
{
return $this->messages[$offset];
}

public function offsetSet(mixed $offset, mixed $value): void
{
if (!$value instanceof Message) {
throw new InvalidArgumentException('Only instances of ' . Message::class . ' are supported');
}
if ($offset !== null) {
$this->messages[$offset] = $value;
} else {
$this->messages[] = $value;
}
}

public function offsetUnset(mixed $offset): void
{
unset($this->messages[$offset]);
}

public function count(): int
{
return count($this->messages);
}

/**
* @return array<array{role: string, content: string}>
*/
public function jsonSerialize(): array
{
return array_map(fn (Message $message) => $message->jsonSerialize(), $this->messages);
}
}
43 changes: 43 additions & 0 deletions src/Service/AiHorde/MessageFormatter/ChatMLPromptFormat.php
@@ -0,0 +1,43 @@
<?php

namespace App\Service\AiHorde\MessageFormatter;

use App\Enum\AiActor;
use App\Enum\AiModel;
use App\Service\AiHorde\Message\Message;
use App\Service\AiHorde\Message\MessageHistory;

final readonly class ChatMLPromptFormat implements MessageFormatter
{
public function getPrompt(MessageHistory $messages): string
{
return trim(implode("\n", array_map(function (Message $message) {
return "<|im_start|>{$message->role->value}\n{$message->content}<|im_end|>";
}, [...$messages])));
}

public function formatOutput(string $message): Message
{
$role = 'assistant';
$message = trim($message);

if (str_starts_with($message, '<|im_start|>')) {
$message = substr($message, strlen('<|im_start|>'));
$parts = explode("\n", $message, 2);
$message = $parts[1];
$role = $parts[0];
}
if (str_ends_with($message, '<|im_end|>')) {
$message = substr($message, 0, -strlen('<|im_end|>'));
}

$role = AiActor::tryFrom($role) ?? AiActor::Assistant;

return new Message(role: $role, content: $message);
}

public function supports(AiModel $model): bool
{
return in_array($model, [AiModel::Mistral7BOpenHermes], true);
}
}
17 changes: 17 additions & 0 deletions src/Service/AiHorde/MessageFormatter/MessageFormatter.php
@@ -0,0 +1,17 @@
<?php

namespace App\Service\AiHorde\MessageFormatter;

use App\Enum\AiModel;
use App\Service\AiHorde\Message\Message;
use App\Service\AiHorde\Message\MessageHistory;
use Symfony\Component\DependencyInjection\Attribute\AutoconfigureTag;

#[AutoconfigureTag('app.message_formatter')]
interface MessageFormatter
{
public function getPrompt(MessageHistory $messages): string;

public function formatOutput(string $message): Message;
public function supports(AiModel $model): bool;
}
@@ -0,0 +1,15 @@
<?php

namespace App\Service\Expression;

use Closure;
use LogicException;
use Symfony\Component\ExpressionLanguage\ExpressionFunctionProviderInterface;

abstract readonly class AbstractExpressionLanguageFunctionProvider implements ExpressionFunctionProviderInterface
{
protected function uncompilableFunction(): Closure
{
return fn () => throw new LogicException('This function cannot be compiled');
}
}

0 comments on commit c762c0f

Please sign in to comment.