Skip to content

Commit

Permalink
Add test framework for server
Browse files Browse the repository at this point in the history
  • Loading branch information
dbarbuzzi committed Apr 22, 2024
1 parent e8e00d2 commit 952a0db
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import gc
import logging
import os
from typing import List, Optional, Tuple

Expand All @@ -9,6 +10,7 @@
from transformers import (AutoModelForCausalLM, AutoProcessor,
LlavaForConditionalGeneration)

from tests.utils.logging import make_logger
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
Expand Down Expand Up @@ -547,3 +549,8 @@ def get_tokenizer_pool_config(tokenizer_group_type):
pool_type="ray",
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


@pytest.fixture(scope="session")
def logger() -> logging.Logger:
return make_logger("vllm_test")
33 changes: 33 additions & 0 deletions tests/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging


def make_logger(name: str) -> logging.Logger:
"""Create a base logger"""

logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
return logger


def log_banner(logger: logging.Logger,
label: str,
body: str,
level: int = logging.INFO):
"""
Log a message in the "banner"-style format.
:param logger: Instance of "logging.Logger" to use
:param label: Label for the top of the banner
:param body: Body content inside the banner
:param level: Logging level to use (default: INFO)
"""

banner = f"==== {label} ====\n{body}\n===="
logger.log(level, "\n%s", banner)
70 changes: 70 additions & 0 deletions tests/utils/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
import shlex
from typing import Any, Dict, List

import ray

from tests.entrypoints.test_openai_server import ServerRunner
from tests.utils.logging import log_banner


class ServerContext:
"""
Context manager for the lifecycle of a vLLM server, wrapping `ServerRunner`.
"""

def __init__(self, args: Dict[str, str], *,
logger: logging.Logger) -> None:
"""Initialize a vLLM server
:param args: dictionary of flags/values to pass to the server command
:param logger: logging.Logger instance to use for logging
:param port: port the server is running on
"""
self._args = self._args_to_list(args)
self._logger = logger
self.server_runner = None

def __enter__(self):
"""Executes the server process and waits for it to become ready."""
log_banner(
self._logger,
"server startup command args",
shlex.join(self._args),
logging.DEBUG,
)

ray.init(ignore_reinit_error=True)
self.server_runner = ServerRunner.remote(self._args)
ray.get(self.server_runner.ready.remote())
return self.server_runner

def __exit__(self, exc_type, exc_value, exc_traceback):
"""
Stops the server if it's still running and captures/logs its output.
"""
if self.server_runner is not None:
del self.server_runner
ray.shutdown()

def _args_to_list(self, args: Dict[str, Any]) -> List[str]:
"""
Convert a dict mapping of CLI args to a list. All values must be
string-able.
:param args: `dict` containing CLI flags and their values
:return: flattened list to pass to a CLI
"""

arg_list: List[str] = []
for flag, value in args.items():
# minimal error-checking: flag names must be strings
if not isinstance(flag, str):
error = f"all flags must be strings, got {type(flag)} ({flag})"
raise ValueError(error)

arg_list.append(flag)
if value is not None:
arg_list.append(str(value))

return arg_list

0 comments on commit 952a0db

Please sign in to comment.