Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test framework for server #200

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import gc
import logging
import os
from typing import List, Optional, Tuple

Expand All @@ -9,6 +10,7 @@
from transformers import (AutoModelForCausalLM, AutoProcessor,
LlavaForConditionalGeneration)

from tests.utils.logging import make_logger
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
Expand Down Expand Up @@ -555,3 +557,8 @@ def get_tokenizer_pool_config(tokenizer_group_type):
pool_type="ray",
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


@pytest.fixture(scope="session")
def logger() -> logging.Logger:
    """Session-scoped logger shared by every test in the run."""
    session_logger = make_logger("vllm_test")
    return session_logger
Empty file added tests/utils/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions tests/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging


def make_logger(name: str) -> logging.Logger:
    """Create (or fetch) a DEBUG-level logger with a stream handler.

    ``logging.getLogger`` returns the *same* logger object for a given
    name, so attach the handler only when none is present — otherwise
    repeated calls (e.g. one per test) would stack duplicate handlers
    and every log line would be emitted multiple times.

    :param name: name of the logger to create/configure
    :return: the configured ``logging.Logger`` instance
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        stream_handler = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
    return logger


def log_banner(logger: logging.Logger,
               label: str,
               body: str,
               level: int = logging.INFO):
    """
    Emit *body* framed in a "banner"-style block.

    :param logger: Instance of "logging.Logger" to use
    :param label: Label for the top of the banner
    :param body: Body content inside the banner
    :param level: Logging level to use (default: INFO)
    """

    framed = "\n".join([f"==== {label} ====", body, "===="])
    logger.log(level, "\n%s", framed)
133 changes: 133 additions & 0 deletions tests/utils/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import logging
import os
import shlex
import subprocess
import sys
import time
from typing import Any, Dict, List, Optional

import ray
import requests
import torch

from tests.utils.logging import log_banner

MAX_SERVER_START_WAIT = 600 # time (seconds) to wait for server to start


@ray.remote(num_gpus=torch.cuda.device_count())
class ServerRunner:
    """Ray actor that launches a vLLM OpenAI-compatible API server as a
    subprocess and blocks until it reports healthy."""

    def __init__(self,
                 args: List[str],
                 *,
                 logger: Optional[logging.Logger] = None):
        """Launch the server subprocess.

        :param args: CLI arguments forwarded to
            ``vllm.entrypoints.openai.api_server``
        :param logger: optional logger used to echo the startup command
        """
        env = os.environ.copy()
        # Unbuffered child output so server logs interleave promptly
        # with test output.
        env["PYTHONUNBUFFERED"] = "1"
        self.startup_command = [
            sys.executable,
            "-m",
            "vllm.entrypoints.openai.api_server",
            *args,
        ]

        if logger:
            log_banner(
                logger,
                "server startup command",
                shlex.join(self.startup_command),
                logging.DEBUG,
            )

        # Reuse the command list built above instead of rebuilding it,
        # so the logged command and the executed command cannot drift.
        self.proc = subprocess.Popen(
            self.startup_command,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        """Health probe; callers block on ``ray.get(actor.ready.remote())``."""
        return True

    def _wait_for_server(self):
        """Poll the /health endpoint until the server answers 200, the
        process dies, or ``MAX_SERVER_START_WAIT`` seconds elapse.

        :raises RuntimeError: if the server exits or fails to start in time
        """
        start = time.time()
        # Track the most recent connection error so the timeout raise has a
        # valid cause; previously `from err` could reference an unbound name
        # when the endpoint kept returning non-200 without raising.
        last_err: Optional[Exception] = None
        while True:
            try:
                if requests.get(
                        "http://localhost:8000/health").status_code == 200:
                    break
            except Exception as err:
                last_err = err
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err

            time.sleep(0.5)
            if time.time() - start > MAX_SERVER_START_WAIT:
                raise RuntimeError(
                    "Server failed to start in time.") from last_err

    def __del__(self):
        # Best-effort cleanup; guard in case __init__ raised before Popen.
        if hasattr(self, "proc"):
            self.proc.terminate()


class ServerContext:
    """
    Context manager for the lifecycle of a vLLM server, wrapping `ServerRunner`.
    """

    def __init__(self, args: Dict[str, str], *,
                 logger: logging.Logger) -> None:
        """Initialize a vLLM server

        :param args: dictionary of flags/values to pass to the server command
        :param logger: logging.Logger instance to use for logging
        """
        self._args = self._args_to_list(args)
        self._logger = logger
        self.server_runner = None

    def __enter__(self):
        """Executes the server process and waits for it to become ready."""
        ray.init(ignore_reinit_error=True)
        log_banner(self._logger, "server startup command args",
                   shlex.join(self._args))
        self.server_runner = ServerRunner.remote(self._args,
                                                 logger=self._logger)
        ray.get(self.server_runner.ready.remote())
        return self.server_runner

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Stops the server if it's still running.
        """
        if self.server_runner is not None:
            # Drop our reference (rather than `del`-ing the attribute) so
            # the actor can be reclaimed while the attribute stays defined
            # for any later check or re-entry of the context.
            self.server_runner = None
        ray.shutdown()

    def _args_to_list(self, args: Dict[str, Any]) -> List[str]:
        """
        Convert a dict mapping of CLI args to a list. All values must be
        string-able.

        :param args: `dict` containing CLI flags and their values
        :return: flattened list to pass to a CLI
        :raises ValueError: if any flag name is not a string
        """

        arg_list: List[str] = []
        for flag, value in args.items():
            # minimal error-checking: flag names must be strings
            if not isinstance(flag, str):
                error = f"all flags must be strings, got {type(flag)} ({flag})"
                raise ValueError(error)

            arg_list.append(flag)
            # A `None` value marks a boolean flag with no argument.
            if value is not None:
                arg_list.append(str(value))

        return arg_list