forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
110 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import logging | ||
|
||
|
||
def make_logger(name: str) -> logging.Logger: | ||
"""Create a base logger""" | ||
|
||
logger = logging.getLogger(name) | ||
logger.setLevel(logging.DEBUG) | ||
stream_handler = logging.StreamHandler() | ||
formatter = logging.Formatter( | ||
"%(asctime)s - %(levelname)s - %(name)s - %(message)s", | ||
datefmt="%Y-%m-%d %H:%M:%S", | ||
) | ||
stream_handler.setFormatter(formatter) | ||
logger.addHandler(stream_handler) | ||
return logger | ||
|
||
|
||
def log_banner(logger: logging.Logger, | ||
label: str, | ||
body: str, | ||
level: int = logging.INFO): | ||
""" | ||
Log a message in the "banner"-style format. | ||
:param logger: Instance of "logging.Logger" to use | ||
:param label: Label for the top of the banner | ||
:param body: Body content inside the banner | ||
:param level: Logging level to use (default: INFO) | ||
""" | ||
|
||
banner = f"==== {label} ====\n{body}\n====" | ||
logger.log(level, "\n%s", banner) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import logging | ||
import shlex | ||
from typing import Any, Dict, List | ||
|
||
import ray | ||
|
||
from tests.entrypoints.test_openai_server import ServerRunner | ||
from tests.utils.logging import log_banner | ||
|
||
|
||
class ServerContext: | ||
""" | ||
Context manager for the lifecycle of a vLLM server, wrapping `ServerRunner`. | ||
""" | ||
|
||
def __init__(self, args: Dict[str, str], *, | ||
logger: logging.Logger) -> None: | ||
"""Initialize a vLLM server | ||
:param args: dictionary of flags/values to pass to the server command | ||
:param logger: logging.Logger instance to use for logging | ||
:param port: port the server is running on | ||
""" | ||
self._args = self._args_to_list(args) | ||
self._logger = logger | ||
self.server_runner = None | ||
|
||
def __enter__(self): | ||
"""Executes the server process and waits for it to become ready.""" | ||
log_banner( | ||
self._logger, | ||
"server startup command args", | ||
shlex.join(self._args), | ||
logging.DEBUG, | ||
) | ||
|
||
ray.init(ignore_reinit_error=True) | ||
self.server_runner = ServerRunner.remote(self._args) | ||
ray.get(self.server_runner.ready.remote()) | ||
return self.server_runner | ||
|
||
def __exit__(self, exc_type, exc_value, exc_traceback): | ||
""" | ||
Stops the server if it's still running and captures/logs its output. | ||
""" | ||
if self.server_runner is not None: | ||
del self.server_runner | ||
ray.shutdown() | ||
|
||
def _args_to_list(self, args: Dict[str, Any]) -> List[str]: | ||
""" | ||
Convert a dict mapping of CLI args to a list. All values must be | ||
string-able. | ||
:param args: `dict` containing CLI flags and their values | ||
:return: flattened list to pass to a CLI | ||
""" | ||
|
||
arg_list: List[str] = [] | ||
for flag, value in args.items(): | ||
# minimal error-checking: flag names must be strings | ||
if not isinstance(flag, str): | ||
error = f"all flags must be strings, got {type(flag)} ({flag})" | ||
raise ValueError(error) | ||
|
||
arg_list.append(flag) | ||
if value is not None: | ||
arg_list.append(str(value)) | ||
|
||
return arg_list |