Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ThrottlingMiddleware example to docs #1413

Open
makarworld opened this issue Feb 12, 2024 · 1 comment
Open

Add ThrottlingMiddleware example to docs #1413

makarworld opened this issue Feb 12, 2024 · 1 comment
Labels
enhancement Make it better!

Comments

@makarworld
Copy link

makarworld commented Feb 12, 2024

aiogram version

3.x

Problem

I wanted to use a ThrottlingMiddleware as in aiogram 2.x, but there is no example in the docs.

Possible solution

I wrote a ThrottlingMiddleware for aiogram 3.x. It was first posted on Stack Overflow.

Alternatives

No response

Code example

from __future__ import annotations
from typing import *
from aiogram import BaseMiddleware
from aiogram.types import Message
import redis.asyncio.client 
import time

def rate_limit(limit: int, key=None):
    """Attach per-handler throttling settings to a handler function.

    The decorated function object itself is returned. It gains a
    ``throttling_rate_limit`` attribute and, when ``key`` is truthy, a
    ``throttling_key`` attribute. ThrottlingMiddleware reads both attributes
    from the handler callback and falls back to its own defaults when they
    are absent.

    :param limit: minimum interval, in seconds, between two calls of this handler.
    :param key: optional throttle key; if falsy, no key attribute is set and the
        middleware's default key is used.
    :return: a decorator that annotates and returns the same function.
    """

    def mark(func):
        func.throttling_rate_limit = limit
        # Only override the middleware's default key when one was actually given.
        if key:
            func.throttling_key = key
        return func

    return mark

class ThrottlingMiddleware(BaseMiddleware):
    """Redis-backed anti-flood middleware for aiogram 3.x message handlers.

    Before each handler runs, the call is checked against a per-(key, user, chat)
    bucket via ThrottleManager. A handler decorated with ``rate_limit`` can
    override the default interval and key.

    :param redis: asyncio Redis client used to store the throttle buckets.
    :param limit: default minimum interval, in seconds, between handled calls.
    :param key_prefix: prefix of the default throttle key (``f"{key_prefix}_message"``).
    """

    def __init__(self, redis: redis.asyncio.client.Redis, limit = .5, key_prefix = 'antiflood_'):
        super().__init__()
        self.rate_limit = limit
        self.prefix = key_prefix
        self.throttle_manager = ThrottleManager(redis = redis)

    async def __call__(
        self,
        handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]],
        event: Message,
        data: Dict[str, Any]
    ) -> Any:
        """Run the throttle check, then call ``handler`` unless the event was throttled.

        :return: the handler's result, or None when the event was throttled.
        Exceptions raised by ``handler`` propagate unchanged.
        """
        try:
            await self.on_process_event(event, data)
        except CancelHandler:
            # Throttled (the user may already have been notified): skip the handler.
            return None

        # Do not swallow handler errors. Let them propagate so the dispatcher and
        # any registered error handlers can see them.
        return await handler(event, data)

    async def on_process_event(
        self,
        event: Message,
        data: Dict[str, Any],
    ) -> None:
        """Check the rate limit for ``event``.

        Per-handler overrides set by ``rate_limit`` are read from the handler
        callback; otherwise ``self.rate_limit`` and the prefixed default key are used.

        :raises CancelHandler: when the call is throttled, after ``event_throttled`` ran.
        """
        # ``data["handler"]`` may be missing (presumably when this is registered as an
        # outer middleware — verify). getattr on None then yields the defaults below.
        callback = getattr(data.get("handler"), "callback", None)
        limit = getattr(callback, "throttling_rate_limit", self.rate_limit)
        key = getattr(callback, "throttling_key", f"{self.prefix}_message")

        # ``from_user`` can be None (e.g. channel posts). In that case throttle per chat.
        user_id = event.from_user.id if event.from_user is not None else None

        try:
            await self.throttle_manager.throttle(key, rate = limit, user_id = user_id, chat_id = event.chat.id)
        except Throttled as t:
            # Notify the user (rate-limited inside event_throttled), then cancel.
            await self.event_throttled(event, t)
            raise CancelHandler() from t

    async def event_throttled(self, event: Message, throttled: Throttled):
        """Tell the user they are sending too fast, at most once per burst.

        ThrottleManager sets ``exceeded_count`` to 1 on an allowed call and
        increments it on each throttled one. ``<= 2`` therefore replies only to
        the first throttled message of a burst.
        """
        # Seconds left until the rate-limit window ends.
        delta = throttled.rate - throttled.delta

        if throttled.exceeded_count <= 2:
            await event.answer(f'Too many events.\nTry again in {delta:.2f} seconds.')


class ThrottleManager:
    """Enforce a minimum interval between calls, with state kept in Redis hashes.

    Each bucket is a Redis hash named ``throttle_{key}_{user_id}_{chat_id}``.
    It holds the fields listed in ``bucket_keys``.
    """

    bucket_keys = [
        "RATE_LIMIT", "DELTA",
        "LAST_CALL", "EXCEEDED_COUNT",
    ]

    def __init__(self, redis: "redis.asyncio.client.Redis"):
        self.redis = redis

    async def throttle(self, key: str, rate: float, user_id: int, chat_id: int):
        """Record a call and reject it if it came too soon after the previous one.

        :param key: throttle key (per handler, or the middleware default).
        :param rate: minimum interval, in seconds, between calls in this bucket.
        :param user_id: id of the caller (part of the bucket name).
        :param chat_id: id of the chat (part of the bucket name).
        :return: True when the call is allowed.
        :raises Throttled: when fewer than ``rate`` seconds passed since the last
            call in the same bucket. The exception carries the updated bucket fields.
        """
        now = time.time()
        bucket_name = f'throttle_{key}_{user_id}_{chat_id}'

        raw = await self.redis.hmget(bucket_name, self.bucket_keys)
        # Redis returns bytes by default and str when the client uses
        # decode_responses=True. Both are converted to float before the arithmetic below.
        data = {
            k: float(v.decode() if isinstance(v, bytes) else v)
            for k, v in zip(self.bucket_keys, raw)
            if v is not None
        }

        # A missing LAST_CALL means a fresh bucket: delta == 0, which counts as allowed.
        called = data.get("LAST_CALL", now)
        delta = now - called
        # ``delta <= 0`` also admits the call when the clock appears to run backwards.
        result = delta >= rate or delta <= 0

        data["RATE_LIMIT"] = rate
        data["LAST_CALL"] = now
        data["DELTA"] = delta
        if result:
            data["EXCEEDED_COUNT"] = 1
        else:
            # .get() guards against a partially written hash missing this field.
            data["EXCEEDED_COUNT"] = data.get("EXCEEDED_COUNT", 1) + 1

        # hset(mapping=...) replaces the deprecated hmset.
        await self.redis.hset(bucket_name, mapping=data)
        # Once ``rate`` seconds have passed, the next call is allowed and the count
        # resets anyway. Expiring the bucket after that point therefore does not
        # change the outcome, and Redis stops accumulating one hash per
        # user/chat/key forever.
        await self.redis.expire(bucket_name, int(rate) + 1)

        if not result:
            raise Throttled(key=key, chat=chat_id, user=user_id, **data)

        return result

class Throttled(Exception):
    """Raised by ThrottleManager.throttle when a call arrives before the rate limit elapsed.

    All keyword arguments are optional, and unknown ones are ignored:
    ``key``, ``user``, ``chat``, and the bucket fields ``LAST_CALL``,
    ``RATE_LIMIT``, ``EXCEEDED_COUNT`` and ``DELTA``.
    """

    def __init__(self, **kwargs):
        get = kwargs.get
        self.key = get("key", '<None>')
        self.called_at = get("LAST_CALL", time.time())
        self.rate = get("RATE_LIMIT", None)
        self.exceeded_count = get("EXCEEDED_COUNT", 0)
        self.delta = get("DELTA", 0)
        self.user = get('user', None)
        self.chat = get('chat', None)

    def __str__(self):
        rounded_delta = round(self.delta, 3)
        return (
            f"Rate limit exceeded! (Limit: {self.rate} s, "
            f"exceeded: {self.exceeded_count}, "
            f"time delta: {rounded_delta} s)"
        )

class CancelHandler(Exception):
    """Signal raised by ThrottlingMiddleware.on_process_event to skip the current handler."""

Additional information

I hope you can improve my code so that it can be added to the docs. It may help other people who want to use a ThrottlingMiddleware.

@makarworld makarworld added the enhancement Make it better! label Feb 12, 2024
@makarworld
Copy link
Author

Added a @rate_limit decorator.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement Make it better!
Projects
None yet
Development

No branches or pull requests

1 participant