From b41e61907d801b3b367f0ae8a32bd21af1c24076 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Fri, 23 Feb 2024 20:07:29 +0100 Subject: [PATCH 01/19] Require Mopidy 4.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 32d30a6..9e23219 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Topic :: Multimedia :: Sound/Audio :: Players", ] -dependencies = ["mopidy >= 3.3.0", "pykka >= 4.0", "setuptools >= 66"] +dependencies = ["mopidy >= 4.0.0a1", "pykka >= 4.0", "setuptools >= 66"] [project.optional-dependencies] lint = ["ruff"] From d69fbfcc9c76fa7d96622dc0d716c30b828b4517 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Fri, 23 Feb 2024 14:45:55 +0100 Subject: [PATCH 02/19] Make ruff enforce the presence of type hints --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9e23219..24a7e02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,9 @@ select = [ "W", # pycodestyle ] ignore = [ - "ANN", # flake8-annotations + "ANN101", # missing-type-self + "ANN102", # missing-type-cls + "ANN401", # any-type "D", # pydocstyle "ISC001", # single-line-implicit-string-concatenation "TRY003", # raise-vanilla-args From 0a57222deece4fb96ce81f5b52696497f5e9369c Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Fri, 23 Feb 2024 16:13:18 +0100 Subject: [PATCH 03/19] Add type hints to everything --- src/mopidy_mpd/__init__.py | 6 +- src/mopidy_mpd/actor.py | 32 ++- src/mopidy_mpd/dispatcher.py | 261 ++++++++++++++------ src/mopidy_mpd/exceptions.py | 39 ++- src/mopidy_mpd/formatting.py | 12 +- src/mopidy_mpd/network.py | 151 ++++++----- src/mopidy_mpd/protocol/__init__.py | 79 ++++-- src/mopidy_mpd/protocol/audio_output.py | 15 +- src/mopidy_mpd/protocol/channels.py | 17 +- src/mopidy_mpd/protocol/command_list.py | 13 +- src/mopidy_mpd/protocol/connection.py | 19 +- src/mopidy_mpd/protocol/current_playlist.py | 68 ++--- src/mopidy_mpd/protocol/mount.py | 17 +- src/mopidy_mpd/protocol/music_db.py | 110 +++++---- src/mopidy_mpd/protocol/playback.py | 126 +++++----- src/mopidy_mpd/protocol/reflection.py | 17 +- src/mopidy_mpd/protocol/status.py | 207 ++++++---------- src/mopidy_mpd/protocol/stickers.py | 18 +- src/mopidy_mpd/protocol/stored_playlists.py | 103 +++++--- src/mopidy_mpd/session.py | 35 ++- src/mopidy_mpd/tokenize.py | 6 +- src/mopidy_mpd/translator.py | 100 +++++--- src/mopidy_mpd/types.py | 17 ++ src/mopidy_mpd/uri_mapper.py | 38 +-- tests/network/test_lineprotocol.py | 9 - tests/network/test_server.py | 9 +- tests/protocol/test_music_db.py | 17 +- tests/test_commands.py | 90 +++---- tests/test_dispatcher.py | 41 +-- tests/test_translator.py | 19 +- 30 files changed, 1021 insertions(+), 670 deletions(-) create mode 100644 src/mopidy_mpd/types.py diff --git a/src/mopidy_mpd/__init__.py b/src/mopidy_mpd/__init__.py index 6647f15..3cb7ad4 100644 --- a/src/mopidy_mpd/__init__.py +++ b/src/mopidy_mpd/__init__.py @@ -11,10 +11,10 @@ class Extension(ext.Extension): ext_name = "mpd" version = __version__ - def get_default_config(self): + def get_default_config(self) -> str: return config.read(pathlib.Path(__file__).parent / "ext.conf") - def get_config_schema(self): + def get_config_schema(self) -> config.ConfigSchema: schema = super().get_config_schema() schema["hostname"] = config.Hostname() schema["port"] = config.Port(optional=True) @@ -26,7 +26,7 @@ def get_config_schema(self): schema["default_playlist_scheme"] = config.String() return schema - def setup(self, registry): + def setup(self, registry: ext.Registry) -> None: from .actor import MpdFrontend registry.add("frontend", MpdFrontend) diff --git a/src/mopidy_mpd/actor.py b/src/mopidy_mpd/actor.py index 25e6cd8..5795811 100644 --- a/src/mopidy_mpd/actor.py +++ b/src/mopidy_mpd/actor.py @@ -1,10 +1,12 @@ import logging +from typing import Any, cast import pykka from mopidy import exceptions, listener, zeroconf -from mopidy.core import CoreListener +from mopidy.config import Config +from mopidy.core import CoreListener, CoreProxy -from mopidy_mpd import network, session, uri_mapper +from mopidy_mpd import network, session, types, uri_mapper logger = logging.getLogger(__name__) @@ -27,19 +29,23 @@ class MpdFrontend(pykka.ThreadingActor, CoreListener): - def __init__(self, config, core): + def __init__(self, config: Config, core: CoreProxy) -> None: super().__init__() - self.hostname = network.format_hostname(config["mpd"]["hostname"]) - self.port = config["mpd"]["port"] + mpd_config = cast(types.MpdConfig, config.get("mpd", {})) + + self.hostname = network.format_hostname(mpd_config["hostname"]) + self.port = mpd_config["port"] self.uri_map = uri_mapper.MpdUriMapper(core) - self.zeroconf_name = config["mpd"]["zeroconf"] + self.zeroconf_name = mpd_config["zeroconf"] self.zeroconf_service = None self.server = self._setup_server(config, core) - def _setup_server(self, config, core): + def _setup_server(self, config: Config, core: CoreProxy) -> network.Server: + mpd_config = cast(types.MpdConfig, config.get("mpd", {})) + try: server = network.Server( self.hostname, @@ -50,8 +56,8 @@ def _setup_server(self, config, core): "core": core, "uri_map": self.uri_map, }, - max_connections=config["mpd"]["max_connections"], - timeout=config["mpd"]["connection_timeout"], + max_connections=mpd_config["max_connections"], + timeout=mpd_config["connection_timeout"], ) except OSError as exc: raise exceptions.FrontendError(f"MPD server startup failed: {exc}") from exc @@ -60,14 +66,14 @@ def _setup_server(self, config, core): return server - def on_start(self): + def on_start(self) -> None: if self.zeroconf_name and not network.is_unix_socket(self.server.server_socket): self.zeroconf_service = zeroconf.Zeroconf( name=self.zeroconf_name, stype="_mpd._tcp", port=self.port ) self.zeroconf_service.publish() - def on_stop(self): + def on_stop(self) -> None: if self.zeroconf_service: self.zeroconf_service.unpublish() @@ -77,12 +83,12 @@ def on_stop(self): self.server.stop() - def on_event(self, event, **kwargs): + def on_event(self, event: str, **kwargs: Any) -> None: if event not in _CORE_EVENTS_TO_IDLE_SUBSYSTEMS: logger.warning("Got unexpected event: %s(%s)", event, ", ".join(kwargs)) else: self.send_idle(_CORE_EVENTS_TO_IDLE_SUBSYSTEMS[event]) - def send_idle(self, subsystem): + def send_idle(self, subsystem: str | None) -> None: if subsystem: listener.send(session.MpdSession, subsystem) diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index 52bbea3..cf71a7f 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -1,41 +1,83 @@ +from __future__ import annotations + import logging import re +from collections.abc import Callable, Generator +from typing import ( + TYPE_CHECKING, + Any, + Literal, + NewType, + TypeAlias, + TypeVar, + cast, + overload, +) import pykka -from mopidy_mpd import exceptions, protocol, tokenize +from mopidy_mpd import exceptions, protocol, tokenize, types + +if TYPE_CHECKING: + from mopidy.core import CoreProxy + from mopidy.ext import Config + from mopidy.models import Ref, Track + from mopidy.types import Uri + + from mopidy_mpd.session import MpdSession + from mopidy_mpd.uri_mapper import MpdUriMapper + logger = logging.getLogger(__name__) protocol.load_protocol_modules() +T = TypeVar("T") +Request: TypeAlias = str +Response = NewType("Response", list[str]) +Filter: TypeAlias = Callable[[Request, Response, list["Filter"]], Response] -class MpdDispatcher: +class MpdDispatcher: """ The MPD session feeds the MPD dispatcher with requests. The dispatcher - finds the correct handler, processes the request and sends the response + finds the correct handler, processes the request, and sends the response back to the MPD session. """ _noidle = re.compile(r"^noidle$") - def __init__(self, session=None, config=None, core=None, uri_map=None): + def __init__( + self, + config: Config, + core: CoreProxy, + session: MpdSession, + uri_map: MpdUriMapper, + ) -> None: self.config = config + self.mpd_config = cast(types.MpdConfig, config.get("mpd", {}) if config else {}) self.authenticated = False self.command_list_receiving = False self.command_list_ok = False self.command_list = [] self.command_list_index = None self.context = MpdContext( - self, session=session, config=config, core=core, uri_map=uri_map + core=core, + dispatcher=self, + session=session, + config=config, + uri_map=uri_map, ) - def handle_request(self, request, current_command_list_index=None): + def handle_request( + self, + request: Request, + current_command_list_index: int | None = None, + ) -> Response: """Dispatch incoming requests to the correct handler.""" self.command_list_index = current_command_list_index - response = [] - filter_chain = [ + response: Response = [] + filter_chain: list[Filter] = [ self._catch_mpd_ack_errors_filter, self._authenticate_filter, self._command_list_filter, @@ -45,7 +87,7 @@ def handle_request(self, request, current_command_list_index=None): ] return self._call_next_filter(request, response, filter_chain) - def handle_idle(self, subsystem): + def handle_idle(self, subsystem: str) -> None: # TODO: validate against mopidy_mpd/protocol/status.SUBSYSTEMS self.context.events.add(subsystem) @@ -53,7 +95,7 @@ def handle_idle(self, subsystem): if not subsystems: return - response = [] + response: list[str] = [] for subsystem in subsystems: response.append(f"changed: {subsystem}") response.append("OK") @@ -61,7 +103,9 @@ def handle_idle(self, subsystem): self.context.events = set() self.context.session.send_lines(response) - def _call_next_filter(self, request, response, filter_chain): + def _call_next_filter( + self, request: str, response: Response, filter_chain: list[Filter] + ) -> Response: if filter_chain: next_filter = filter_chain.pop(0) return next_filter(request, response, filter_chain) @@ -69,17 +113,27 @@ def _call_next_filter(self, request, response, filter_chain): # --- Filter: catch MPD ACK errors - def _catch_mpd_ack_errors_filter(self, request, response, filter_chain): + def _catch_mpd_ack_errors_filter( + self, + request: Request, + response: Response, + filter_chain: list[Filter], + ) -> Response: try: return self._call_next_filter(request, response, filter_chain) except exceptions.MpdAckError as mpd_ack_error: if self.command_list_index is not None: mpd_ack_error.index = self.command_list_index - return [mpd_ack_error.get_mpd_ack()] + return Response([mpd_ack_error.get_mpd_ack()]) # --- Filter: authenticate - def _authenticate_filter(self, request, response, filter_chain): + def _authenticate_filter( + self, + request: Request, + response: Response, + filter_chain: list[Filter], + ) -> Response: if self.authenticated: return self._call_next_filter(request, response, filter_chain) @@ -97,7 +151,12 @@ def _authenticate_filter(self, request, response, filter_chain): # --- Filter: command list - def _command_list_filter(self, request, response, filter_chain): + def _command_list_filter( + self, + request: Request, + response: Response, + filter_chain: list[Filter], + ) -> Response: if self._is_receiving_command_list(request): self.command_list.append(request) return [] @@ -111,18 +170,23 @@ def _command_list_filter(self, request, response, filter_chain): and response and response[-1] == "OK" ): - response = response[:-1] + response = Response(response[:-1]) return response - def _is_receiving_command_list(self, request): + def _is_receiving_command_list(self, request: str) -> bool: return self.command_list_receiving and request != "command_list_end" - def _is_processing_command_list(self, request): + def _is_processing_command_list(self, request: str) -> bool: return self.command_list_index is not None and request != "command_list_end" # --- Filter: idle - def _idle_filter(self, request, response, filter_chain): + def _idle_filter( + self, + request: Request, + response: Response, + filter_chain: list[Filter], + ) -> Response: if self._is_currently_idle() and not self._noidle.match(request): logger.debug( "Client sent us %s, only %s is allowed while in " "the idle state", @@ -142,127 +206,168 @@ def _idle_filter(self, request, response, filter_chain): return response - def _is_currently_idle(self): + def _is_currently_idle(self) -> bool: return bool(self.context.subscriptions) # --- Filter: add OK - def _add_ok_filter(self, request, response, filter_chain): + def _add_ok_filter( + self, + request: Request, + response: Response, + filter_chain: list[Filter], + ) -> Response: response = self._call_next_filter(request, response, filter_chain) if not self._has_error(response): response.append("OK") return response - def _has_error(self, response): - return response and response[-1].startswith("ACK") + def _has_error(self, response: Response) -> bool: + return bool(response) and response[-1].startswith("ACK") # --- Filter: call handler - def _call_handler_filter(self, request, response, filter_chain): + def _call_handler_filter( + self, + request: Request, + response: Response, + filter_chain: list[Filter], + ) -> Response: try: - response = self._format_response(self._call_handler(request)) + result = self._call_handler(request) + response = self._format_response(result) return self._call_next_filter(request, response, filter_chain) except pykka.ActorDeadError as exc: logger.warning("Tried to communicate with dead actor.") - raise exceptions.MpdSystemError(exc) from exc + raise exceptions.MpdSystemError(str(exc)) from exc - def _call_handler(self, request): + def _call_handler(self, request: str) -> protocol.Result: tokens = tokenize.split(request) # TODO: check that blacklist items are valid commands? - blacklist = self.config["mpd"].get("command_blacklist", []) + blacklist = self.mpd_config.get("command_blacklist", []) if tokens and tokens[0] in blacklist: logger.warning("MPD client used blacklisted command: %s", tokens[0]) raise exceptions.MpdDisabledError(command=tokens[0]) try: - return protocol.commands.call(tokens, context=self.context) + return protocol.commands.call( + context=self.context, + tokens=tokens, + ) except exceptions.MpdAckError as exc: if exc.command is None: exc.command = tokens[0] raise - def _format_response(self, response): - formatted_response = [] - for element in self._listify_result(response): - formatted_response.extend(self._format_lines(element)) - return formatted_response + def _format_response(self, result: protocol.Result) -> Response: + response = Response([]) + for element in self._listify_result(result): + response.extend(self._format_lines(element)) + return response - def _listify_result(self, result): - if result is None: - return [] - if isinstance(result, set): - return self._flatten(list(result)) - if not isinstance(result, list): - return [result] - return self._flatten(result) - - def _flatten(self, the_list): - result = [] - for element in the_list: + def _listify_result(self, result: protocol.Result) -> protocol.ResultList: + match result: + case None: + return [] + case list(): + return self._flatten(result) + case _: + return [result] + + def _flatten(self, lst: protocol.ResultList) -> protocol.ResultList: + result: protocol.ResultList = [] + for element in lst: if isinstance(element, list): result.extend(self._flatten(element)) else: result.append(element) return result - def _format_lines(self, line): - if isinstance(line, dict): - return [f"{key}: {value}" for (key, value) in line.items()] - if isinstance(line, tuple): - (key, value) = line - return [f"{key}: {value}"] - return [line] + def _format_lines( + self, element: protocol.ResultDict | protocol.ResultTuple | str + ) -> Response: + if isinstance(element, dict): + return Response([f"{key}: {value}" for (key, value) in element.items()]) + if isinstance(element, tuple): + (key, value) = element + return Response([f"{key}: {value}"]) + return Response([element]) class MpdContext: - """ This object is passed as the first argument to all MPD command handlers to give the command handlers access to important parts of Mopidy. """ - #: The current :class:`MpdDispatcher`. - dispatcher = None - - #: The current :class:`mopidy_mpd.MpdSession`. - session = None + #: The Mopidy core API. + core: CoreProxy - #: The MPD password - password = None + _uri_map: MpdUriMapper - #: The Mopidy core API. An instance of :class:`mopidy.core.Core`. - core = None + #: The current dispatcher instance. + dispatcher: MpdDispatcher - #: The active subsystems that have pending events. - events = None + #: The current session instance. + session: MpdSession - #: The subsytems that we want to be notified about in idle mode. - subscriptions = None + #: The MPD password. + password: str | None = None - _uri_map = None - - def __init__(self, dispatcher, session=None, config=None, core=None, uri_map=None): # noqa: PLR0913 + #: The active subsystems that have pending events. + events: set[str] + + #: The subsystems that we want to be notified about in idle mode. + subscriptions: set[str] + + def __init__( # noqa: PLR0913 + self, + config: Config, + core: CoreProxy, + uri_map: MpdUriMapper, + dispatcher: MpdDispatcher, + session: MpdSession, + ) -> None: + self.core = core + self._uri_map = uri_map self.dispatcher = dispatcher self.session = session if config is not None: - self.password = config["mpd"]["password"] - self.core = core + mpd_config = cast(types.MpdConfig, config["mpd"]) + self.password = mpd_config["password"] self.events = set() self.subscriptions = set() - self._uri_map = uri_map - def lookup_playlist_uri_from_name(self, name): + def lookup_playlist_uri_from_name(self, name: str) -> Uri | None: """ Helper function to retrieve a playlist from its unique MPD name. """ return self._uri_map.playlist_uri_from_name(name) - def lookup_playlist_name_from_uri(self, uri): + def lookup_playlist_name_from_uri(self, uri: Uri) -> str | None: """ Helper function to retrieve the unique MPD playlist name from its uri. """ return self._uri_map.playlist_name_from_uri(uri) - def browse(self, path, *, recursive=True, lookup=True): # noqa: C901, PLR0912 + @overload + def browse( + self, path: str | None, *, recursive: bool, lookup: Literal[True] + ) -> Generator[tuple[str, pykka.Future[dict[Uri, list[Track]]] | None], Any, None]: + ... + + @overload + def browse( + self, path: str | None, *, recursive: bool, lookup: Literal[False] + ) -> Generator[tuple[str, Ref | None], Any, None]: + ... + + def browse( # noqa: C901, PLR0912 + self, + path: str | None, + *, + recursive: bool = True, + lookup: bool = True, + ) -> Generator[Any, Any, None]: """ Browse the contents of a given directory path. @@ -281,8 +386,8 @@ def browse(self, path, *, recursive=True, lookup=True): # noqa: C901, PLR0912 :class:`None`. """ - path_parts = re.findall(r"[^/]+", path or "") - root_path = "/".join(["", *path_parts]) + path_parts: list[str] = re.findall(r"[^/]+", path or "") + root_path: str = "/".join(["", *path_parts]) uri = self._uri_map.uri_from_name(root_path) if uri is None: diff --git a/src/mopidy_mpd/exceptions.py b/src/mopidy_mpd/exceptions.py index 10eac86..39675a5 100644 --- a/src/mopidy_mpd/exceptions.py +++ b/src/mopidy_mpd/exceptions.py @@ -1,4 +1,7 @@ +from typing import Any + from mopidy.exceptions import MopidyException +from mopidy.types import UriScheme class MpdAckError(MopidyException): @@ -20,13 +23,18 @@ class MpdAckError(MopidyException): error_code = 0 - def __init__(self, message="", index=0, command=None): + def __init__( + self, + message: str = "", + index: int = 0, + command: str | None = None, + ) -> None: super().__init__(message, index, command) self.message = message self.index = index self.command = command - def get_mpd_ack(self): + def get_mpd_ack(self) -> str: """ MPD error code format:: @@ -49,7 +57,7 @@ class MpdPasswordError(MpdAckError): class MpdPermissionError(MpdAckError): error_code = MpdAckError.ACK_ERROR_PERMISSION - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) assert self.command is not None, "command must be given explicitly" self.message = f'you don\'t have permission for "{self.command}"' @@ -60,7 +68,7 @@ class MpdUnknownError(MpdAckError): class MpdUnknownCommandError(MpdUnknownError): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) assert self.command is not None, "command must be given explicitly" self.message = f'unknown command "{self.command}"' @@ -68,7 +76,7 @@ def __init__(self, *args, **kwargs): class MpdNoCommandError(MpdUnknownCommandError): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["command"] = "" super().__init__(*args, **kwargs) self.message = "No command given" @@ -89,7 +97,7 @@ class MpdSystemError(MpdAckError): class MpdInvalidPlaylistNameError(MpdAckError): error_code = MpdAckError.ACK_ERROR_ARG - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.message = ( "playlist name is invalid: playlist names may not " @@ -100,7 +108,7 @@ def __init__(self, *args, **kwargs): class MpdNotImplementedError(MpdAckError): error_code = 0 - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.message = "Not implemented" @@ -109,7 +117,13 @@ class MpdInvalidTrackForPlaylistError(MpdAckError): # NOTE: This is a custom error for Mopidy that does not exist in MPD. error_code = 0 - def __init__(self, playlist_scheme, track_scheme, *args, **kwargs): + def __init__( + self, + playlist_scheme: UriScheme, + track_scheme: UriScheme, + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.message = ( f'Playlist with scheme "{playlist_scheme}" ' @@ -121,7 +135,12 @@ class MpdFailedToSavePlaylistError(MpdAckError): # NOTE: This is a custom error for Mopidy that does not exist in MPD. error_code = 0 - def __init__(self, backend_scheme, *args, **kwargs): + def __init__( + self, + backend_scheme: UriScheme, + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.message = f'Backend with scheme "{backend_scheme}" failed to save playlist' @@ -130,7 +149,7 @@ class MpdDisabledError(MpdAckError): # NOTE: This is a custom error for Mopidy that does not exist in MPD. error_code = 0 - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) assert self.command is not None, "command must be given explicitly" self.message = f'"{self.command}" has been disabled in the server' diff --git a/src/mopidy_mpd/formatting.py b/src/mopidy_mpd/formatting.py index bd0cdd5..57d5696 100644 --- a/src/mopidy_mpd/formatting.py +++ b/src/mopidy_mpd/formatting.py @@ -1,7 +1,13 @@ -def indent(string, *, places=4, linebreak="\n", singles=False): - lines = string.split(linebreak) +def indent( + value: str, + *, + places: int = 4, + linebreak: str = "\n", + singles: bool = False, +) -> str: + lines = value.split(linebreak) if not singles and len(lines) == 1: - return string + return value for i, line in enumerate(lines): lines[i] = " " * places + line result = linebreak.join(lines) diff --git a/src/mopidy_mpd/network.py b/src/mopidy_mpd/network.py index 54afacb..fe4bf98 100644 --- a/src/mopidy_mpd/network.py +++ b/src/mopidy_mpd/network.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import errno import logging @@ -6,17 +8,24 @@ import socket import sys import threading +from typing import TYPE_CHECKING, Any, NoReturn import pykka from gi.repository import GLib logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from collections.abc import Generator + from types import TracebackType + + from mopidy_mpd.session import MpdSession + from mopidy_mpd.types import SocketAddress CONTROL_CHARS = dict.fromkeys(range(32)) -def get_systemd_socket(): +def get_systemd_socket() -> socket.socket | None: """Attempt to get a socket from systemd.""" fdnames = os.environ.get("LISTEN_FDNAMES", "").split(":") if "mpd" not in fdnames: @@ -25,21 +34,21 @@ def get_systemd_socket(): return socket.socket(fileno=fd) -def get_unix_socket_path(socket_path): +def get_unix_socket_path(socket_path: str) -> str | None: match = re.search("^unix:(.*)", socket_path) if not match: return None return match.group(1) -def is_unix_socket(sock): +def is_unix_socket(sock: socket.socket) -> bool: """Check if the provided socket is a Unix domain socket""" if hasattr(socket, "AF_UNIX"): return sock.family == socket.AF_UNIX return False -def get_socket_address(host, port): +def get_socket_address(host: str, port: int) -> tuple[str, int | None]: unix_socket_path = get_unix_socket_path(host) if unix_socket_path is not None: return (unix_socket_path, None) @@ -51,7 +60,7 @@ class ShouldRetrySocketCallError(Exception): """Indicate that attempted socket call should be retried""" -def try_ipv6_socket(): +def try_ipv6_socket() -> bool: """Determine if system really supports IPv6""" if not socket.has_ipv6: return False @@ -70,7 +79,7 @@ def try_ipv6_socket(): has_ipv6 = try_ipv6_socket() -def create_tcp_socket(): +def create_tcp_socket() -> socket.socket: """Create a TCP socket with or without IPv6 depending on system support""" if has_ipv6: sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) @@ -87,12 +96,12 @@ def create_tcp_socket(): return sock -def create_unix_socket(): +def create_unix_socket() -> socket.socket: """Create a Unix domain socket""" return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) -def format_address(address): +def format_address(address: SocketAddress) -> str: """Format socket address for display.""" host, port = address[:2] if port is not None: @@ -100,7 +109,7 @@ def format_address(address): return f"[{host}]" -def format_hostname(hostname): +def format_hostname(hostname: str) -> str: """Format hostname for display.""" if has_ipv6 and re.match(r"\d+.\d+.\d+.\d+", hostname) is not None: hostname = f"::ffff:{hostname}" @@ -113,13 +122,13 @@ class Server: def __init__( # noqa: PLR0913 self, - host, - port, - protocol, - protocol_kwargs=None, - max_connections=5, - timeout=30, - ): + host: str, + port: int, + protocol: type[MpdSession], + protocol_kwargs: dict[str, Any] | None = None, + max_connections: int = 5, + timeout: int = 30, + ) -> None: self.protocol = protocol self.protocol_kwargs = protocol_kwargs or {} self.max_connections = max_connections @@ -129,7 +138,7 @@ def __init__( # noqa: PLR0913 self.watcher = self.register_server_socket(self.server_socket.fileno()) - def create_server_socket(self, host, port): + def create_server_socket(self, host: str, port: int) -> socket.socket: sock = get_systemd_socket() if sock is not None: return sock @@ -149,7 +158,7 @@ def create_server_socket(self, host, port): sock.listen(1) return sock - def stop(self): + def stop(self) -> None: GLib.source_remove(self.watcher) if is_unix_socket(self.server_socket): unix_socket_path = self.server_socket.getsockname() @@ -163,10 +172,10 @@ def stop(self): if unix_socket_path is not None: os.unlink(unix_socket_path) # noqa: PTH108 - def register_server_socket(self, fileno): + def register_server_socket(self, fileno: int) -> Any: return GLib.io_add_watch(fileno, GLib.IO_IN, self.handle_connection) - def handle_connection(self, _fd, _flags): + def handle_connection(self, _fd: int, _flags: int) -> bool: try: sock, addr = self.accept_connection() except ShouldRetrySocketCallError: @@ -178,34 +187,37 @@ def handle_connection(self, _fd, _flags): self.init_connection(sock, addr) return True - def accept_connection(self): + def accept_connection(self) -> tuple[socket.socket, SocketAddress]: try: sock, addr = self.server_socket.accept() if is_unix_socket(sock): addr = (sock.getsockname(), None) - except OSError as e: - if e.errno in (errno.EAGAIN, errno.EINTR): + except OSError as exc: + if exc.errno in (errno.EAGAIN, errno.EINTR): raise ShouldRetrySocketCallError from None raise else: - return sock, addr + return ( + sock, + addr[:2], # addr is a two-tuple for IPv4 and four-tuple for IPv6 + ) - def maximum_connections_exceeded(self): + def maximum_connections_exceeded(self) -> bool: return ( self.max_connections is not None and self.number_of_connections() >= self.max_connections ) - def number_of_connections(self): + def number_of_connections(self) -> int: return len(pykka.ActorRegistry.get_by_class(self.protocol)) - def reject_connection(self, sock, addr): + def reject_connection(self, sock: socket.socket, addr: SocketAddress) -> None: # FIXME provide more context in logging? logger.warning("Rejected connection from %s", format_address(addr)) with contextlib.suppress(OSError): sock.close() - def init_connection(self, sock, addr): + def init_connection(self, sock: socket.socket, addr: SocketAddress) -> None: Connection(self.protocol, self.protocol_kwargs, sock, addr, self.timeout) @@ -218,10 +230,20 @@ class Connection: # false return value would only tell us that what we thought was registered # is already gone, there is really nothing more we can do. - def __init__(self, protocol, protocol_kwargs, sock, addr, timeout): # noqa: PLR0913 + host: str + port: int | None + + def __init__( # noqa: PLR0913 + self, + protocol: type[MpdSession], + protocol_kwargs: dict[str, Any], + sock: socket.socket, + addr: SocketAddress, + timeout: int, + ) -> None: sock.setblocking(False) # noqa: FBT003 - self.host, self.port = addr[:2] # IPv6 has larger addr + self.host, self.port = addr[:2] self._sock = sock self.protocol = protocol @@ -242,7 +264,7 @@ def __init__(self, protocol, protocol_kwargs, sock, addr, timeout): # noqa: PLR self.enable_recv() self.enable_timeout() - def stop(self, reason, level=logging.DEBUG): + def stop(self, reason: str, level: int = logging.DEBUG) -> None: if self.stopping: logger.log(level, f"Already stopping: {reason}") return @@ -261,7 +283,7 @@ def stop(self, reason, level=logging.DEBUG): with contextlib.suppress(OSError): self._sock.close() - def queue_send(self, data): + def queue_send(self, data: bytes) -> None: """Try to send data to client exactly as is and queue rest.""" self.send_lock.acquire(blocking=True) self.send_buffer = self.send(self.send_buffer + data) @@ -269,7 +291,7 @@ def queue_send(self, data): if self.send_buffer: self.enable_send() - def send(self, data): + def send(self, data: bytes) -> bytes: """Send data to client, return any unsent data.""" try: sent = self._sock.send(data) @@ -280,7 +302,7 @@ def send(self, data): self.stop(f"Unexpected client error: {exc}") return b"" - def enable_timeout(self): + def enable_timeout(self) -> None: """Reactivate timeout mechanism.""" if self.timeout is None or self.timeout <= 0: return @@ -288,14 +310,14 @@ def enable_timeout(self): self.disable_timeout() self.timeout_id = GLib.timeout_add_seconds(self.timeout, self.timeout_callback) - def disable_timeout(self): + def disable_timeout(self) -> None: """Deactivate timeout mechanism.""" if self.timeout_id is None: return GLib.source_remove(self.timeout_id) self.timeout_id = None - def enable_recv(self): + def enable_recv(self) -> None: if self.recv_id is not None: return @@ -308,13 +330,13 @@ def enable_recv(self): except OSError as exc: self.stop(f"Problem with connection: {exc}") - def disable_recv(self): + def disable_recv(self) -> None: if self.recv_id is None: return GLib.source_remove(self.recv_id) self.recv_id = None - def enable_send(self): + def enable_send(self) -> None: if self.send_id is not None: return @@ -327,14 +349,14 @@ def enable_send(self): except OSError as exc: self.stop(f"Problem with connection: {exc}") - def disable_send(self): + def disable_send(self) -> None: if self.send_id is None: return GLib.source_remove(self.send_id) self.send_id = None - def recv_callback(self, fd, flags): # noqa: ARG002 + def recv_callback(self, fd: int, flags: int) -> bool: # noqa: ARG002 if flags & (GLib.IO_ERR | GLib.IO_HUP): self.stop(f"Bad client flags: {flags}") return True @@ -358,7 +380,7 @@ def recv_callback(self, fd, flags): # noqa: ARG002 return True - def send_callback(self, fd, flags): # noqa: ARG002 + def send_callback(self, fd: int, flags: int) -> bool: # noqa: ARG002 if flags & (GLib.IO_ERR | GLib.IO_HUP): self.stop(f"Bad client flags: {flags}") return True @@ -377,11 +399,11 @@ def send_callback(self, fd, flags): # noqa: ARG002 return True - def timeout_callback(self): + def timeout_callback(self) -> bool: self.stop(f"Client inactive for {self.timeout:d}s; closing connection") return False - def __str__(self): + def __str__(self) -> str: return format_address((self.host, self.port)) @@ -397,33 +419,27 @@ class LineProtocol(pykka.ThreadingActor): #: Line terminator to use for outputed lines. terminator = b"\n" - #: Regex to use for spliting lines, will be set compiled version of its - #: own value, or to ``terminator``s value if it is not set itself. - delimiter = None + #: Regex to use for spliting lines. + delimiter = re.compile(rb"\r?\n") - #: What encoding to expect incoming data to be in, can be :class:`None`. + #: What encoding to expect incoming data to be in. encoding = "utf-8" - def __init__(self, connection): + def __init__(self, connection: Connection) -> None: super().__init__() self.connection = connection self.prevent_timeout = False self.recv_buffer = b"" - if self.delimiter: - self.delimiter = re.compile(self.delimiter) - else: - self.delimiter = re.compile(self.terminator) - @property - def host(self): + def host(self) -> str: return self.connection.host @property - def port(self): + def port(self) -> int | None: return self.connection.port - def on_line_received(self, line): + def on_line_received(self, line: str) -> None: """ Called whenever a new line is found. @@ -431,7 +447,7 @@ def on_line_received(self, line): """ raise NotImplementedError - def on_receive(self, message): + def on_receive(self, message: dict[str, Any]) -> None: """Handle messages with new data from server.""" if "close" in message: self.connection.stop("Client most likely disconnected.") @@ -451,21 +467,26 @@ def on_receive(self, message): if not self.prevent_timeout: self.connection.enable_timeout() - def on_failure(self, exception_type, exception_value, traceback): # noqa: ARG002 + def on_failure( + self, + exception_type: type[BaseException] | None, # noqa: ARG002 + exception_value: BaseException | None, # noqa: ARG002 + traceback: TracebackType | None, # noqa: ARG002 + ) -> None: """Clean up connection resouces when actor fails.""" self.connection.stop("Actor failed.") - def on_stop(self): + def on_stop(self) -> None: """Clean up connection resouces when actor stops.""" self.connection.stop("Actor is shutting down.") - def parse_lines(self): + def parse_lines(self) -> Generator[bytes, Any, None]: """Consume new data and yield any lines found.""" while re.search(self.terminator, self.recv_buffer): line, self.recv_buffer = self.delimiter.split(self.recv_buffer, 1) yield line - def encode(self, line): + def encode(self, line: str) -> bytes: """ Handle encoding of line. @@ -480,8 +501,9 @@ def encode(self, line): self.encoding, ) self.stop() + return NoReturn - def decode(self, line): + def decode(self, line: bytes) -> str: """ Handle decoding of line. @@ -496,14 +518,15 @@ def decode(self, line): self.encoding, ) self.stop() + return NoReturn - def join_lines(self, lines): + def join_lines(self, lines: list[str]) -> str: if not lines: return "" line_terminator = self.decode(self.terminator) return line_terminator.join(lines) + line_terminator - def send_lines(self, lines): + def send_lines(self, lines: list[str]) -> None: """ Send array of lines to client via connection. diff --git a/src/mopidy_mpd/protocol/__init__.py b/src/mopidy_mpd/protocol/__init__.py index 683a854..8edf58c 100644 --- a/src/mopidy_mpd/protocol/__init__.py +++ b/src/mopidy_mpd/protocol/__init__.py @@ -10,12 +10,20 @@ `MPD clients `_. """ +from __future__ import annotations + +import functools import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, TypeAlias from mopidy_mpd import exceptions +if TYPE_CHECKING: + from mopidy_mpd.dispatcher import MpdContext + #: The MPD protocol uses UTF-8 for encoding all data. -ENCODING = "UTF-8" +ENCODING = "utf-8" #: The MPD protocol uses ``\n`` as line terminator. LINE_TERMINATOR = b"\n" @@ -24,7 +32,15 @@ VERSION = "0.19.0" -def load_protocol_modules(): +ResultValue: TypeAlias = str | int +ResultDict: TypeAlias = dict[str, ResultValue] +ResultTuple: TypeAlias = tuple[str, ResultValue] +ResultList: TypeAlias = list[ResultTuple | ResultDict] +Result: TypeAlias = None | ResultDict | ResultTuple | ResultList +Handler: TypeAlias = Callable[..., Result] + + +def load_protocol_modules() -> None: """ The protocol modules must be imported to get them registered in :attr:`commands`. @@ -45,7 +61,7 @@ def load_protocol_modules(): ) -def INT(value): # noqa: N802 +def INT(value: str) -> int: # noqa: N802 r"""Converts a value that matches [+-]?\d+ into an integer.""" if value is None: raise ValueError("None is not a valid integer") @@ -53,7 +69,7 @@ def INT(value): # noqa: N802 return int(value) -def UINT(value): # noqa: N802 +def UINT(value: str) -> int: # noqa: N802 r"""Converts a value that matches \d+ into an integer.""" if value is None: raise ValueError("None is not a valid integer") @@ -62,31 +78,31 @@ def UINT(value): # noqa: N802 return int(value) -def FLOAT(value): # noqa: N802 +def FLOAT(value: str) -> float: # noqa: N802 r"""Converts a value that matches [+-]\d+(.\d+)? into a float.""" if value is None: raise ValueError("None is not a valid float") return float(value) -def UFLOAT(value): # noqa: N802 +def UFLOAT(value: str) -> float: # noqa: N802 r"""Converts a value that matches \d+(.\d+)? into a float.""" if value is None: raise ValueError("None is not a valid float") - value = float(value) - if value < 0: + result = float(value) + if result < 0: raise ValueError("Only positive numbers are allowed") - return value + return result -def BOOL(value): # noqa: N802 +def BOOL(value: str) -> bool: # noqa: N802 """Convert the values 0 and 1 into booleans.""" if value in ("1", "0"): return bool(int(value)) raise ValueError(f"{value!r} is not 0 or 1") -def RANGE(value): # noqa: N802 +def RANGE(value: str) -> slice: # noqa: N802 """Convert a single integer or range spec into a slice ``n`` should become ``slice(n, n+1)`` @@ -116,12 +132,19 @@ class Commands: installed into. """ - def __init__(self): + def __init__(self) -> None: self.handlers = {} # TODO: consider removing auth_required and list_command in favour of # additional command instances to register in? - def add(self, name, *, auth_required=True, list_command=True, **validators): # noqa: C901 + def add( # noqa: C901 + self, + name: str, + *, + auth_required: bool = True, + list_command: bool = True, + **validators: Callable[[str], Any], + ) -> Callable[[Handler], Handler]: """Create a decorator that registers a handler and validation rules. Additional keyword arguments are treated as converters/validators to @@ -137,12 +160,12 @@ def add(self, name, *, auth_required=True, list_command=True, **validators): # Decorator returns the unwrapped function so that tests etc can use the functions with values with correct python types instead of strings. - :param string name: Name of the command being registered. - :param bool auth_required: If authorization is required. - :param bool list_command: If command should be listed in reflection. + :param name: Name of the command being registered. + :param auth_required: If authorization is required. + :param list_command: If command should be listed in reflection. """ - def wrapper(func): # noqa: C901 + def wrapper(func: Handler) -> Handler: # noqa: C901 if name in self.handlers: raise ValueError(f"{name} already registered") @@ -167,7 +190,8 @@ def wrapper(func): # noqa: C901 if spec.varkw or spec.kwonlyargs: raise TypeError("Keyword arguments are not permitted") - def validate(*args, **kwargs): + @functools.wraps(func) + def validate(*args: Any, **kwargs: Any) -> Result: if spec.varargs: return func(*args, **kwargs) @@ -197,21 +221,26 @@ def validate(*args, **kwargs): return wrapper - def call(self, tokens, context=None): + def call( + self, + *, + context: MpdContext, + tokens: list[str], + ) -> Result: """Find and run the handler registered for the given command. If the handler was registered with any converters/validators they will be run before calling the real handler. - :param list tokens: List of tokens to process - :param context: MPD context. - :type context: :class:`~mopidy_mpd.dispatcher.MpdContext` + :param context: MPD context + :param tokens: List of tokens to process """ if not tokens: raise exceptions.MpdNoCommandError - if tokens[0] not in self.handlers: - raise exceptions.MpdUnknownCommandError(command=tokens[0]) - return self.handlers[tokens[0]](context, *tokens[1:]) + command, tokens = tokens[0], tokens[1:] + if command not in self.handlers: + raise exceptions.MpdUnknownCommandError(command=command) + return self.handlers[command](context, *tokens) #: Global instance to install commands into diff --git a/src/mopidy_mpd/protocol/audio_output.py b/src/mopidy_mpd/protocol/audio_output.py index b07bef8..fe84909 100644 --- a/src/mopidy_mpd/protocol/audio_output.py +++ b/src/mopidy_mpd/protocol/audio_output.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from mopidy_mpd import exceptions, protocol +if TYPE_CHECKING: + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("disableoutput", outputid=protocol.UINT) -def disableoutput(context, outputid): +def disableoutput(context: MpdContext, outputid: int) -> None: """ *musicpd.org, audio output section:* @@ -19,7 +26,7 @@ def disableoutput(context, outputid): @protocol.commands.add("enableoutput", outputid=protocol.UINT) -def enableoutput(context, outputid): +def enableoutput(context: MpdContext, outputid: int) -> None: """ *musicpd.org, audio output section:* @@ -36,7 +43,7 @@ def enableoutput(context, outputid): @protocol.commands.add("toggleoutput", outputid=protocol.UINT) -def toggleoutput(context, outputid): +def toggleoutput(context: MpdContext, outputid: int) -> None: """ *musicpd.org, audio output section:* @@ -54,7 +61,7 @@ def toggleoutput(context, outputid): @protocol.commands.add("outputs") -def outputs(context): +def outputs(context: MpdContext) -> protocol.Result: """ *musicpd.org, audio output section:* diff --git a/src/mopidy_mpd/protocol/channels.py b/src/mopidy_mpd/protocol/channels.py index fe65675..020011f 100644 --- a/src/mopidy_mpd/protocol/channels.py +++ b/src/mopidy_mpd/protocol/channels.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Never + from mopidy_mpd import exceptions, protocol +if TYPE_CHECKING: + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("subscribe") -def subscribe(context, channel): +def subscribe(context: MpdContext, channel: str) -> Never: """ *musicpd.org, client to client section:* @@ -17,7 +24,7 @@ def subscribe(context, channel): @protocol.commands.add("unsubscribe") -def unsubscribe(context, channel): +def unsubscribe(context: MpdContext, channel: str) -> Never: """ *musicpd.org, client to client section:* @@ -30,7 +37,7 @@ def unsubscribe(context, channel): @protocol.commands.add("channels") -def channels(context): +def channels(context: MpdContext) -> Never: """ *musicpd.org, client to client section:* @@ -43,7 +50,7 @@ def channels(context): @protocol.commands.add("readmessages") -def readmessages(context): +def readmessages(context: MpdContext) -> Never: """ *musicpd.org, client to client section:* @@ -56,7 +63,7 @@ def readmessages(context): @protocol.commands.add("sendmessage") -def sendmessage(context, channel, text): +def sendmessage(context: MpdContext, channel: str, text: str) -> Never: """ *musicpd.org, client to client section:* diff --git a/src/mopidy_mpd/protocol/command_list.py b/src/mopidy_mpd/protocol/command_list.py index 854a5d9..bdff82a 100644 --- a/src/mopidy_mpd/protocol/command_list.py +++ b/src/mopidy_mpd/protocol/command_list.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from mopidy_mpd import exceptions, protocol +if TYPE_CHECKING: + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("command_list_begin", list_command=False) -def command_list_begin(context): +def command_list_begin(context: MpdContext) -> None: """ *musicpd.org, command list section:* @@ -24,7 +31,7 @@ def command_list_begin(context): @protocol.commands.add("command_list_end", list_command=False) -def command_list_end(context): +def command_list_end(context: MpdContext) -> protocol.Result: """See :meth:`command_list_begin()`.""" # TODO: batch consecutive add commands if not context.dispatcher.command_list_receiving: @@ -52,7 +59,7 @@ def command_list_end(context): @protocol.commands.add("command_list_ok_begin", list_command=False) -def command_list_ok_begin(context): +def command_list_ok_begin(context: MpdContext) -> None: """See :meth:`command_list_begin()`.""" context.dispatcher.command_list_receiving = True context.dispatcher.command_list_ok = True diff --git a/src/mopidy_mpd/protocol/connection.py b/src/mopidy_mpd/protocol/connection.py index 73b3838..5def70e 100644 --- a/src/mopidy_mpd/protocol/connection.py +++ b/src/mopidy_mpd/protocol/connection.py @@ -1,9 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Never + from mopidy_mpd import exceptions, protocol from mopidy_mpd.protocol import tagtype_list +if TYPE_CHECKING: + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("close", auth_required=False) -def close(context): +def close(context: MpdContext) -> None: """ *musicpd.org, connection section:* @@ -15,7 +22,7 @@ def close(context): @protocol.commands.add("kill", list_command=False) -def kill(context): +def kill(context: MpdContext) -> Never: """ *musicpd.org, connection section:* @@ -27,7 +34,7 @@ def kill(context): @protocol.commands.add("password", auth_required=False) -def password(context, password): +def password(context: MpdContext, password: str) -> None: """ *musicpd.org, connection section:* @@ -43,7 +50,7 @@ def password(context, password): @protocol.commands.add("ping", auth_required=False) -def ping(context): +def ping(context: MpdContext) -> None: """ *musicpd.org, connection section:* @@ -54,7 +61,7 @@ def ping(context): @protocol.commands.add("tagtypes") -def tagtypes(context, *parameters): +def tagtypes(context: MpdContext, *parameters: list[str]) -> protocol.Result: """ *mpd.readthedocs.io, connection settings section:* @@ -98,7 +105,7 @@ def tagtypes(context, *parameters): return [("tagtype", tagtype) for tagtype in context.session.tagtypes] -def _validate_tagtypes(parameters): +def _validate_tagtypes(parameters: list[str]) -> None: param_set = set(parameters) if not param_set: raise exceptions.MpdArgError("Not enough arguments") diff --git a/src/mopidy_mpd/protocol/current_playlist.py b/src/mopidy_mpd/protocol/current_playlist.py index a3ebad7..8ec43ce 100644 --- a/src/mopidy_mpd/protocol/current_playlist.py +++ b/src/mopidy_mpd/protocol/current_playlist.py @@ -1,10 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING from urllib.parse import urlparse from mopidy_mpd import exceptions, protocol, translator +if TYPE_CHECKING: + from mopidy.types import Uri + + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("add") -def add(context, uri): +def add(context: MpdContext, uri: Uri) -> None: """ *musicpd.org, current playlist section:* @@ -40,7 +48,7 @@ def add(context, uri): @protocol.commands.add("addid", songpos=protocol.UINT) -def addid(context, uri, songpos=None): +def addid(context: MpdContext, uri: Uri, songpos: int | None = None) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -73,7 +81,7 @@ def addid(context, uri, songpos=None): @protocol.commands.add("delete", songrange=protocol.RANGE) -def delete(context, songrange): +def delete(context: MpdContext, songrange: slice) -> None: """ *musicpd.org, current playlist section:* @@ -88,12 +96,11 @@ def delete(context, songrange): tl_tracks = context.core.tracklist.slice(start, end).get() if not tl_tracks: raise exceptions.MpdArgError("Bad song index", command="delete") - for tlid, _ in tl_tracks: - context.core.tracklist.remove({"tlid": [tlid]}) + context.core.tracklist.remove({"tlid": [tl_track.tlid for tl_track in tl_tracks]}) @protocol.commands.add("deleteid", tlid=protocol.UINT) -def deleteid(context, tlid): +def deleteid(context: MpdContext, tlid: int) -> None: """ *musicpd.org, current playlist section:* @@ -107,7 +114,7 @@ def deleteid(context, tlid): @protocol.commands.add("clear") -def clear(context): +def clear(context: MpdContext) -> None: """ *musicpd.org, current playlist section:* @@ -119,7 +126,7 @@ def clear(context): @protocol.commands.add("move", songrange=protocol.RANGE, to=protocol.UINT) -def move_range(context, songrange, to): +def move_range(context: MpdContext, songrange: slice, to: int) -> None: """ *musicpd.org, current playlist section:* @@ -136,7 +143,7 @@ def move_range(context, songrange, to): @protocol.commands.add("moveid", tlid=protocol.UINT, to=protocol.UINT) -def moveid(context, tlid, to): +def moveid(context: MpdContext, tlid: int, to: int) -> None: """ *musicpd.org, current playlist section:* @@ -146,15 +153,14 @@ def moveid(context, tlid, to): the playlist. If ``TO`` is negative, it is relative to the current song in the playlist (if there is one). """ - tl_tracks = context.core.tracklist.filter({"tlid": [tlid]}).get() - if not tl_tracks: + position = context.core.tracklist.index(tlid=tlid).get() + if position is None: raise exceptions.MpdNoExistError("No such song") - position = context.core.tracklist.index(tl_tracks[0]).get() context.core.tracklist.move(position, position + 1, to) @protocol.commands.add("playlist") -def playlist(context): +def playlist(context: MpdContext) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -170,7 +176,7 @@ def playlist(context): @protocol.commands.add("playlistfind") -def playlistfind(context, tag, needle): +def playlistfind(context: MpdContext, tag: str, needle: str) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -190,7 +196,7 @@ def playlistfind(context, tag, needle): @protocol.commands.add("playlistid", tlid=protocol.UINT) -def playlistid(context, tlid=None): +def playlistid(context: MpdContext, tlid: int | None = None) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -215,7 +221,7 @@ def playlistid(context, tlid=None): @protocol.commands.add("playlistinfo") -def playlistinfo(context, parameter=None): +def playlistinfo(context: MpdContext, parameter: str | None = None) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -247,7 +253,7 @@ def playlistinfo(context, parameter=None): @protocol.commands.add("playlistsearch") -def playlistsearch(context, tag, needle): +def playlistsearch(context: MpdContext, tag: str, needle: str) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -264,7 +270,7 @@ def playlistsearch(context, tag, needle): @protocol.commands.add("plchanges", version=protocol.INT) -def plchanges(context, version): +def plchanges(context: MpdContext, version: int) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -308,7 +314,7 @@ def plchanges(context, version): @protocol.commands.add("plchangesposid", version=protocol.INT) -def plchangesposid(context, version): +def plchangesposid(context: MpdContext, version: int) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -334,7 +340,7 @@ def plchangesposid(context, version): @protocol.commands.add("prio", priority=protocol.UINT, position=protocol.RANGE) -def prio(context, priority, position): +def prio(context: MpdContext, priority: int, position: int) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -350,7 +356,7 @@ def prio(context, priority, position): @protocol.commands.add("prioid") -def prioid(context, *args): +def prioid(context: MpdContext, *args: str) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -362,7 +368,7 @@ def prioid(context, *args): @protocol.commands.add("rangeid", tlid=protocol.UINT, songrange=protocol.RANGE) -def rangeid(context, tlid, songrange): +def rangeid(context: MpdContext, tlid: int, songrange: slice) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -381,7 +387,7 @@ def rangeid(context, tlid, songrange): @protocol.commands.add("shuffle", songrange=protocol.RANGE) -def shuffle(context, songrange=None): +def shuffle(context: MpdContext, songrange: slice | None = None) -> None: """ *musicpd.org, current playlist section:* @@ -398,7 +404,7 @@ def shuffle(context, songrange=None): @protocol.commands.add("swap", songpos1=protocol.UINT, songpos2=protocol.UINT) -def swap(context, songpos1, songpos2): +def swap(context: MpdContext, songpos1: int, songpos2: int) -> None: """ *musicpd.org, current playlist section:* @@ -413,7 +419,7 @@ def swap(context, songpos1, songpos2): @protocol.commands.add("swapid", tlid1=protocol.UINT, tlid2=protocol.UINT) -def swapid(context, tlid1, tlid2): +def swapid(context: MpdContext, tlid1: int, tlid2: int) -> None: """ *musicpd.org, current playlist section:* @@ -421,17 +427,15 @@ def swapid(context, tlid1, tlid2): Swaps the positions of ``SONG1`` and ``SONG2`` (both song ids). """ - tl_tracks1 = context.core.tracklist.filter({"tlid": [tlid1]}).get() - tl_tracks2 = context.core.tracklist.filter({"tlid": [tlid2]}).get() - if not tl_tracks1 or not tl_tracks2: + position1 = context.core.tracklist.index(tlid=tlid1).get() + position2 = context.core.tracklist.index(tlid=tlid2).get() + if position1 is None or position2 is None: raise exceptions.MpdNoExistError("No such song") - position1 = context.core.tracklist.index(tl_tracks1[0]).get() - position2 = context.core.tracklist.index(tl_tracks2[0]).get() swap(context, position1, position2) @protocol.commands.add("addtagid", tlid=protocol.UINT) -def addtagid(context, tlid, tag, value): +def addtagid(context: MpdContext, tlid: int, tag: str, value: str) -> protocol.Result: """ *musicpd.org, current playlist section:* @@ -449,7 +453,7 @@ def addtagid(context, tlid, tag, value): @protocol.commands.add("cleartagid", tlid=protocol.UINT) -def cleartagid(context, tlid, tag): +def cleartagid(context: MpdContext, tlid: int, tag: str) -> protocol.Result: """ *musicpd.org, current playlist section:* diff --git a/src/mopidy_mpd/protocol/mount.py b/src/mopidy_mpd/protocol/mount.py index f75eb99..b8280bc 100644 --- a/src/mopidy_mpd/protocol/mount.py +++ b/src/mopidy_mpd/protocol/mount.py @@ -1,8 +1,17 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Never + from mopidy_mpd import exceptions, protocol +if TYPE_CHECKING: + from mopidy.types import Uri + + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("mount") -def mount(context, path, uri): +def mount(context: MpdContext, path: str, uri: Uri) -> Never: """ *musicpd.org, mounts and neighbors section:* @@ -19,7 +28,7 @@ def mount(context, path, uri): @protocol.commands.add("unmount") -def unmount(context, path): +def unmount(context: MpdContext, path: str) -> Never: """ *musicpd.org, mounts and neighbors section:* @@ -36,7 +45,7 @@ def unmount(context, path): @protocol.commands.add("listmounts") -def listmounts(context): +def listmounts(context: MpdContext) -> Never: """ *musicpd.org, mounts and neighbors section:* @@ -59,7 +68,7 @@ def listmounts(context): @protocol.commands.add("listneighbors") -def listneighbors(context): +def listneighbors(context: MpdContext) -> Never: """ *musicpd.org, mounts and neighbors section:* diff --git a/src/mopidy_mpd/protocol/music_db.py b/src/mopidy_mpd/protocol/music_db.py index 6abd539..0a15349 100644 --- a/src/mopidy_mpd/protocol/music_db.py +++ b/src/mopidy_mpd/protocol/music_db.py @@ -1,11 +1,20 @@ -import functools +from __future__ import annotations + import itertools +from typing import TYPE_CHECKING, Never, cast -from mopidy.models import Track +from mopidy.models import Album, Artist, SearchResult, Track +from mopidy.types import DistinctField, Query, SearchField, Uri from mopidy_mpd import exceptions, protocol, translator -_LIST_MAPPING = { +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from mopidy_mpd.dispatcher import MpdContext + + +_LIST_MAPPING: dict[str, DistinctField] = { "album": "album", "albumartist": "albumartist", "artist": "artist", @@ -45,12 +54,12 @@ _SEARCH_MAPPING = dict(_LIST_MAPPING, any="any") -def _query_from_mpd_search_parameters(parameters, mapping): - query = {} +def _query_for_search(parameters: Sequence[str]) -> Query[SearchField]: + query: dict[str, list[str]] = {} parameters = list(parameters) while parameters: # TODO: does it matter that this is now case insensitive - field = mapping.get(parameters.pop(0).lower()) + field = _SEARCH_MAPPING.get(parameters.pop(0).lower()) if not field: raise exceptions.MpdArgError("incorrect arguments") if not parameters: @@ -58,19 +67,22 @@ def _query_from_mpd_search_parameters(parameters, mapping): value = parameters.pop(0) if value.strip(): query.setdefault(field, []).append(value) - return query + return cast(Query[SearchField], query) -def _get_field(field, search_results): - return list(itertools.chain(*[getattr(r, field) for r in search_results])) +def _get_albums(search_results: Iterable[SearchResult]) -> list[Album]: + return list(itertools.chain(*[r.albums for r in search_results])) -_get_albums = functools.partial(_get_field, "albums") -_get_artists = functools.partial(_get_field, "artists") -_get_tracks = functools.partial(_get_field, "tracks") +def _get_artists(search_results: Iterable[SearchResult]) -> list[Artist]: + return list(itertools.chain(*[r.artists for r in search_results])) -def _album_as_track(album): +def _get_tracks(search_results: Iterable[SearchResult]) -> list[Track]: + return list(itertools.chain(*[r.tracks for r in search_results])) + + +def _album_as_track(album: Album) -> Track: return Track( uri=album.uri, name="Album: " + album.name, @@ -80,12 +92,16 @@ def _album_as_track(album): ) -def _artist_as_track(artist): - return Track(uri=artist.uri, name="Artist: " + artist.name, artists=[artist]) +def _artist_as_track(artist: Artist) -> Track: + return Track( + uri=artist.uri, + name="Artist: " + artist.name, + artists=[artist], + ) @protocol.commands.add("count") -def count(context, *args): +def count(context: MpdContext, *args: str) -> protocol.Result: """ *musicpd.org, music database section:* @@ -99,9 +115,10 @@ def count(context, *args): - use multiple tag-needle pairs to make more specific searches. """ try: - query = _query_from_mpd_search_parameters(args, _SEARCH_MAPPING) + query = _query_for_search(args) except ValueError as exc: raise exceptions.MpdArgError("incorrect arguments") from exc + results = context.core.library.search(query=query, exact=True).get() result_tracks = _get_tracks(results) total_length = sum(t.length for t in result_tracks if t.length) @@ -112,7 +129,7 @@ def count(context, *args): @protocol.commands.add("find") -def find(context, *args): +def find(context: MpdContext, *args: str) -> protocol.Result: """ *musicpd.org, music database section:* @@ -138,12 +155,12 @@ def find(context, *args): - uses "file" instead of "filename". """ try: - query = _query_from_mpd_search_parameters(args, _SEARCH_MAPPING) + query = _query_for_search(args) except ValueError: return None results = context.core.library.search(query=query, exact=True).get() - result_tracks = [] + result_tracks: list[Track] = [] if ( "artist" not in query and "albumartist" not in query @@ -158,7 +175,7 @@ def find(context, *args): @protocol.commands.add("findadd") -def findadd(context, *args): +def findadd(context: MpdContext, *args: str) -> None: """ *musicpd.org, music database section:* @@ -168,17 +185,17 @@ def findadd(context, *args): current playlist. Parameters have the same meaning as for ``find``. """ try: - query = _query_from_mpd_search_parameters(args, _SEARCH_MAPPING) + query = _query_for_search(args) except ValueError: return results = context.core.library.search(query=query, exact=True).get() - - context.core.tracklist.add(uris=[track.uri for track in _get_tracks(results)]).get() + uris = [track.uri for track in _get_tracks(results)] + context.core.tracklist.add(uris=uris).get() @protocol.commands.add("list") -def list_(context, *args): +def list_(context: MpdContext, *args: str) -> protocol.Result: """ *musicpd.org, music database section:* @@ -264,7 +281,7 @@ def list_(context, *args): if field is None: raise exceptions.MpdArgError(f"Unknown tag type: {field_arg}") - query = None + query: Query[SearchField] | None = None if len(params) == 1: if field != "album": raise exceptions.MpdArgError('should be "Album" for 3 arguments') @@ -272,7 +289,7 @@ def list_(context, *args): query = {"artist": params} else: try: - query = _query_from_mpd_search_parameters(params, _SEARCH_MAPPING) + query = _query_for_search(params) except exceptions.MpdArgError as exc: exc.message = "Unknown filter type" # B306: Our own exception raise @@ -285,7 +302,7 @@ def list_(context, *args): @protocol.commands.add("listall") -def listall(context, uri=None): +def listall(context: MpdContext, uri: str | None = None) -> protocol.Result: """ *musicpd.org, music database section:* @@ -297,7 +314,6 @@ def listall(context, uri=None): database. That is fragile and adds huge overhead. It will break with large databases. Instead, query MPD whenever you need something. - .. warning:: This command is disabled by default in Mopidy installs. """ result = [] @@ -313,7 +329,7 @@ def listall(context, uri=None): @protocol.commands.add("listallinfo") -def listallinfo(context, uri=None): +def listallinfo(context: MpdContext, uri: str | None = None) -> protocol.Result: """ *musicpd.org, music database section:* @@ -329,8 +345,8 @@ def listallinfo(context, uri=None): .. warning:: This command is disabled by default in Mopidy installs. """ - result = [] - for path, lookup_future in context.browse(uri): + result: protocol.ResultList = [] + for path, lookup_future in context.browse(uri, lookup=True): if not lookup_future: result.append(("directory", path.lstrip("/"))) else: @@ -343,7 +359,7 @@ def listallinfo(context, uri=None): @protocol.commands.add("listfiles") -def listfiles(context, uri=None): +def listfiles(context: MpdContext, uri: str | None = None) -> Never: """ *musicpd.org, music database section:* @@ -367,7 +383,7 @@ def listfiles(context, uri=None): @protocol.commands.add("lsinfo") -def lsinfo(context, uri=None): +def lsinfo(context: MpdContext, uri: str | None = None) -> protocol.Result: """ *musicpd.org, music database section:* @@ -384,7 +400,7 @@ def lsinfo(context, uri=None): ""``, and ``lsinfo "/"``. """ result = [] - for path, lookup_future in context.browse(uri, recursive=False): + for path, lookup_future in context.browse(uri, recursive=False, lookup=True): if not lookup_future: result.append(("directory", path.lstrip("/"))) else: @@ -403,7 +419,7 @@ def lsinfo(context, uri=None): @protocol.commands.add("rescan") -def rescan(context, uri=None): +def rescan(context: MpdContext, uri: str | None = None) -> protocol.Result: """ *musicpd.org, music database section:* @@ -415,7 +431,7 @@ def rescan(context, uri=None): @protocol.commands.add("search") -def search(context, *args): +def search(context: MpdContext, *args: str) -> protocol.Result: """ *musicpd.org, music database section:* @@ -441,7 +457,7 @@ def search(context, *args): - uses "file" instead of "filename". """ try: - query = _query_from_mpd_search_parameters(args, _SEARCH_MAPPING) + query = _query_for_search(args) except ValueError: return None results = context.core.library.search(query).get() @@ -454,7 +470,7 @@ def search(context, *args): @protocol.commands.add("searchadd") -def searchadd(context, *args): +def searchadd(context: MpdContext, *args: str) -> None: """ *musicpd.org, music database section:* @@ -467,7 +483,7 @@ def searchadd(context, *args): not case sensitive. """ try: - query = _query_from_mpd_search_parameters(args, _SEARCH_MAPPING) + query = _query_for_search(args) except ValueError: return @@ -477,7 +493,7 @@ def searchadd(context, *args): @protocol.commands.add("searchaddpl") -def searchaddpl(context, *args): +def searchaddpl(context: MpdContext, *args: str) -> None: """ *musicpd.org, music database section:* @@ -496,22 +512,26 @@ def searchaddpl(context, *args): raise exceptions.MpdArgError("incorrect arguments") playlist_name = parameters.pop(0) try: - query = _query_from_mpd_search_parameters(parameters, _SEARCH_MAPPING) + query = _query_for_search(parameters) except ValueError: return results = context.core.library.search(query).get() uri = context.lookup_playlist_uri_from_name(playlist_name) - playlist = uri is not None and context.core.playlists.lookup(uri).get() + if uri is None: + return # TODO: Raise error? + playlist = context.core.playlists.lookup(uri).get() if not playlist: playlist = context.core.playlists.create(playlist_name).get() + if not playlist: + return # TODO: Raise error about failed playlist creation? tracks = list(playlist.tracks) + _get_tracks(results) playlist = playlist.replace(tracks=tracks) context.core.playlists.save(playlist) @protocol.commands.add("update") -def update(context, uri=None): +def update(context: MpdContext, uri: Uri | None = None) -> protocol.Result: """ *musicpd.org, music database section:* @@ -532,7 +552,7 @@ def update(context, uri=None): # TODO: add at least reflection tests before adding NotImplemented version # @protocol.commands.add('readcomments') -def readcomments(context, uri): +def readcomments(context: MpdContext, uri: Uri | None = None) -> None: """ *musicpd.org, music database section:* diff --git a/src/mopidy_mpd/protocol/playback.py b/src/mopidy_mpd/protocol/playback.py index 98676c6..22ce9fd 100644 --- a/src/mopidy_mpd/protocol/playback.py +++ b/src/mopidy_mpd/protocol/playback.py @@ -1,10 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Never + from mopidy.core import PlaybackState +from mopidy.types import DurationMs, Percentage from mopidy_mpd import exceptions, protocol +if TYPE_CHECKING: + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("consume", state=protocol.BOOL) -def consume(context, state): +def consume(context: MpdContext, state: bool) -> None: # noqa: FBT001 """ *musicpd.org, playback section:* @@ -18,7 +26,7 @@ def consume(context, state): @protocol.commands.add("crossfade", seconds=protocol.UINT) -def crossfade(context, seconds): +def crossfade(context: MpdContext, seconds: int) -> Never: """ *musicpd.org, playback section:* @@ -30,7 +38,7 @@ def crossfade(context, seconds): @protocol.commands.add("mixrampdb") -def mixrampdb(context, decibels): +def mixrampdb(context: MpdContext, decibels: str) -> Never: """ *musicpd.org, playback section:* @@ -47,7 +55,7 @@ def mixrampdb(context, decibels): @protocol.commands.add("mixrampdelay", seconds=protocol.UINT) -def mixrampdelay(context, seconds): +def mixrampdelay(context: MpdContext, seconds: int) -> Never: """ *musicpd.org, playback section:* @@ -61,7 +69,7 @@ def mixrampdelay(context, seconds): @protocol.commands.add("next") -def next_(context): +def next_(context: MpdContext) -> None: """ *musicpd.org, playback section:* @@ -115,11 +123,11 @@ def next_(context): order as the first time. """ - return context.core.playback.next().get() + context.core.playback.next().get() @protocol.commands.add("pause", state=protocol.BOOL) -def pause(context, state=None): +def pause(context: MpdContext, state: bool | None = None) -> None: """ *musicpd.org, playback section:* @@ -131,21 +139,22 @@ def pause(context, state=None): - Calls ``pause`` without any arguments to toogle pause. """ - if state is None: - # Deprecated: Calling `pause` without any arguments - playback_state = context.core.playback.get_state().get() - if playback_state == PlaybackState.PLAYING: + match state: + case None: + # Deprecated: Calling `pause` without any arguments + playback_state = context.core.playback.get_state().get() + if playback_state == PlaybackState.PLAYING: + context.core.playback.pause().get() + elif playback_state == PlaybackState.PAUSED: + context.core.playback.resume().get() + case True: context.core.playback.pause().get() - elif playback_state == PlaybackState.PAUSED: + case False: context.core.playback.resume().get() - elif state: - context.core.playback.pause().get() - else: - context.core.playback.resume().get() @protocol.commands.add("play", songpos=protocol.INT) -def play(context, songpos=None): +def play(context: MpdContext, songpos: int | None = None) -> None: """ *musicpd.org, playback section:* @@ -170,38 +179,42 @@ def play(context, songpos=None): - issues ``play 6`` without quotes around the argument. """ if songpos is None: - return context.core.playback.play().get() + context.core.playback.play().get() + return if songpos == -1: - return _play_minus_one(context) + _play_minus_one(context) + return try: tl_track = context.core.tracklist.slice(songpos, songpos + 1).get()[0] - return context.core.playback.play(tlid=tl_track.tlid).get() + context.core.playback.play(tlid=tl_track.tlid).get() except IndexError as exc: raise exceptions.MpdArgError("Bad song index") from exc -def _play_minus_one(context): - playback_state = context.core.playback.get_state().get() - if playback_state == PlaybackState.PLAYING: - return None # Nothing to do - if playback_state == PlaybackState.PAUSED: - return context.core.playback.resume().get() - - current_tl_track = context.core.playback.get_current_tl_track().get() - if current_tl_track is not None: - return context.core.playback.play(tlid=current_tl_track.tlid).get() +def _play_minus_one(context: MpdContext) -> None: + match context.core.playback.get_state().get(): + case PlaybackState.PLAYING: + pass # Nothing to do + case PlaybackState.PAUSED: + context.core.playback.resume().get() + case PlaybackState.STOPPED: + current_tlid = context.core.playback.get_current_tlid().get() + if current_tlid is not None: + context.core.playback.play(tlid=current_tlid).get() + return - tl_tracks = context.core.tracklist.slice(0, 1).get() - if tl_tracks: - return context.core.playback.play(tlid=tl_tracks[0].tlid).get() + tl_tracks = context.core.tracklist.slice(0, 1).get() + if tl_tracks: + context.core.playback.play(tlid=tl_tracks[0].tlid).get() + return - return None # Fail silently + # No current track, empty tracklist: nothing to do @protocol.commands.add("playid", tlid=protocol.INT) -def playid(context, tlid): +def playid(context: MpdContext, tlid: int) -> None: """ *musicpd.org, playback section:* @@ -220,6 +233,7 @@ def playid(context, tlid): """ if tlid == -1: return _play_minus_one(context) + tl_tracks = context.core.tracklist.filter({"tlid": [tlid]}).get() if not tl_tracks: raise exceptions.MpdNoExistError("No such song") @@ -227,7 +241,7 @@ def playid(context, tlid): @protocol.commands.add("previous") -def previous(context): +def previous(context: MpdContext) -> None: """ *musicpd.org, playback section:* @@ -270,11 +284,11 @@ def previous(context): ``previous`` should do a seek to time position 0. """ - return context.core.playback.previous().get() + context.core.playback.previous().get() @protocol.commands.add("random", state=protocol.BOOL) -def random(context, state): +def random(context: MpdContext, state: bool) -> None: # noqa: FBT001 """ *musicpd.org, playback section:* @@ -286,7 +300,7 @@ def random(context, state): @protocol.commands.add("repeat", state=protocol.BOOL) -def repeat(context, state): +def repeat(context: MpdContext, state: bool) -> None: # noqa: FBT001 """ *musicpd.org, playback section:* @@ -298,7 +312,7 @@ def repeat(context, state): @protocol.commands.add("replay_gain_mode") -def replay_gain_mode(context, mode): +def replay_gain_mode(context: MpdContext, mode: str) -> Never: """ *musicpd.org, playback section:* @@ -315,7 +329,7 @@ def replay_gain_mode(context, mode): @protocol.commands.add("replay_gain_status") -def replay_gain_status(context): +def replay_gain_status(context: MpdContext) -> str: """ *musicpd.org, playback section:* @@ -328,7 +342,7 @@ def replay_gain_status(context): @protocol.commands.add("seek", songpos=protocol.UINT, seconds=protocol.UFLOAT) -def seek(context, songpos, seconds): +def seek(context: MpdContext, songpos: int, seconds: float) -> None: """ *musicpd.org, playback section:* @@ -343,12 +357,12 @@ def seek(context, songpos, seconds): """ tl_track = context.core.playback.get_current_tl_track().get() if context.core.tracklist.index(tl_track).get() != songpos: - play(context, songpos) - context.core.playback.seek(int(seconds * 1000)).get() + play(context, songpos=songpos) + context.core.playback.seek(DurationMs(int(seconds * 1000))).get() @protocol.commands.add("seekid", tlid=protocol.UINT, seconds=protocol.UFLOAT) -def seekid(context, tlid, seconds): +def seekid(context: MpdContext, tlid: int, seconds: float) -> None: """ *musicpd.org, playback section:* @@ -358,12 +372,12 @@ def seekid(context, tlid, seconds): """ tl_track = context.core.playback.get_current_tl_track().get() if not tl_track or tl_track.tlid != tlid: - playid(context, tlid) - context.core.playback.seek(int(seconds * 1000)).get() + playid(context, tlid=tlid) + context.core.playback.seek(DurationMs(int(seconds * 1000))).get() @protocol.commands.add("seekcur") -def seekcur(context, time): +def seekcur(context: MpdContext, time: str) -> None: """ *musicpd.org, playback section:* @@ -374,15 +388,15 @@ def seekcur(context, time): """ if time.startswith(("+", "-")): position = context.core.playback.get_time_position().get() - position += int(protocol.FLOAT(time) * 1000) + position = DurationMs(position + int(protocol.FLOAT(time) * 1000)) context.core.playback.seek(position).get() else: - position = int(protocol.UFLOAT(time) * 1000) + position = DurationMs(int(protocol.UFLOAT(time) * 1000)) context.core.playback.seek(position).get() @protocol.commands.add("setvol", volume=protocol.INT) -def setvol(context, volume): +def setvol(context: MpdContext, volume: int) -> None: """ *musicpd.org, playback section:* @@ -395,14 +409,14 @@ def setvol(context, volume): - issues ``setvol 50`` without quotes around the argument. """ # NOTE: we use INT as clients can pass in +N etc. - value = min(max(0, volume), 100) + value = Percentage(min(max(0, volume), 100)) success = context.core.mixer.set_volume(value).get() if not success: raise exceptions.MpdSystemError("problems setting volume") @protocol.commands.add("single", state=protocol.BOOL) -def single(context, state): +def single(context: MpdContext, state: bool) -> None: # noqa: FBT001 """ *musicpd.org, playback section:* @@ -416,7 +430,7 @@ def single(context, state): @protocol.commands.add("stop") -def stop(context): +def stop(context: MpdContext) -> None: """ *musicpd.org, playback section:* @@ -428,7 +442,7 @@ def stop(context): @protocol.commands.add("volume", change=protocol.INT) -def volume(context, change): +def volume(context: MpdContext, change: int) -> None: """ *musicpd.org, playback section:* @@ -450,7 +464,7 @@ def volume(context, change): if old_volume is None: raise exceptions.MpdSystemError("problems setting volume") - new_volume = min(max(min_volume, old_volume + change), max_volume) + new_volume = Percentage(min(max(min_volume, old_volume + change), max_volume)) success = context.core.mixer.set_volume(new_volume).get() if not success: raise exceptions.MpdSystemError("problems setting volume") diff --git a/src/mopidy_mpd/protocol/reflection.py b/src/mopidy_mpd/protocol/reflection.py index 96a370c..d73e2cd 100644 --- a/src/mopidy_mpd/protocol/reflection.py +++ b/src/mopidy_mpd/protocol/reflection.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Never + from mopidy_mpd import exceptions, protocol +if TYPE_CHECKING: + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("config", list_command=False) -def config(context): +def config(context: MpdContext) -> Never: """ *musicpd.org, reflection section:* @@ -16,7 +23,7 @@ def config(context): @protocol.commands.add("commands", auth_required=False) -def commands(context): +def commands(context: MpdContext) -> protocol.Result: """ *musicpd.org, reflection section:* @@ -35,7 +42,7 @@ def commands(context): @protocol.commands.add("decoders") -def decoders(context): +def decoders(context: MpdContext) -> None: """ *musicpd.org, reflection section:* @@ -62,7 +69,7 @@ def decoders(context): @protocol.commands.add("notcommands", auth_required=False) -def notcommands(context): +def notcommands(context: MpdContext) -> protocol.Result: """ *musicpd.org, reflection section:* @@ -81,7 +88,7 @@ def notcommands(context): @protocol.commands.add("urlhandlers") -def urlhandlers(context): +def urlhandlers(context: MpdContext) -> protocol.Result: """ *musicpd.org, reflection section:* diff --git a/src/mopidy_mpd/protocol/status.py b/src/mopidy_mpd/protocol/status.py index c88f949..df6f565 100644 --- a/src/mopidy_mpd/protocol/status.py +++ b/src/mopidy_mpd/protocol/status.py @@ -1,8 +1,18 @@ -import pykka +from __future__ import annotations + +from typing import TYPE_CHECKING, Never + from mopidy.core import PlaybackState from mopidy_mpd import exceptions, protocol, translator +if TYPE_CHECKING: + from mopidy.models import Track + from mopidy.types import DurationMs + + from mopidy_mpd.dispatcher import MpdContext + + #: Subsystems that can be registered with idle command. SUBSYSTEMS = [ "database", @@ -17,7 +27,7 @@ @protocol.commands.add("clearerror") -def clearerror(context): +def clearerror(context: MpdContext) -> Never: """ *musicpd.org, status section:* @@ -30,7 +40,7 @@ def clearerror(context): @protocol.commands.add("currentsong") -def currentsong(context): +def currentsong(context: MpdContext) -> protocol.Result: """ *musicpd.org, status section:* @@ -53,7 +63,7 @@ def currentsong(context): @protocol.commands.add("idle") -def idle(context, *subsystems): +def idle(context: MpdContext, *subsystems: list[str]) -> protocol.Result: """ *musicpd.org, status section:* @@ -109,7 +119,7 @@ def idle(context, *subsystems): @protocol.commands.add("noidle", list_command=False) -def noidle(context): +def noidle(context: MpdContext) -> None: """See :meth:`_status_idle`.""" if not context.subscriptions: return @@ -119,7 +129,7 @@ def noidle(context): @protocol.commands.add("stats") -def stats(context): +def stats(context: MpdContext) -> protocol.Result: """ *musicpd.org, status section:* @@ -146,7 +156,7 @@ def stats(context): @protocol.commands.add("status") -def status(context): +def status(context: MpdContext) -> protocol.Result: """ *musicpd.org, status section:* @@ -182,140 +192,85 @@ def status(context): - ``elapsed``: Higher resolution means time in seconds with three decimal places for millisecond precision. """ - tl_track = context.core.playback.get_current_tl_track() - next_tlid = context.core.tracklist.get_next_tlid() - - futures = { - "tracklist.length": context.core.tracklist.get_length(), - "tracklist.version": context.core.tracklist.get_version(), - "mixer.volume": context.core.mixer.get_volume(), - "tracklist.consume": context.core.tracklist.get_consume(), - "tracklist.random": context.core.tracklist.get_random(), - "tracklist.repeat": context.core.tracklist.get_repeat(), - "tracklist.single": context.core.tracklist.get_single(), - "playback.state": context.core.playback.get_state(), - "playback.current_tl_track": tl_track, - "tracklist.index": context.core.tracklist.index(tl_track.get()), - "tracklist.next_tlid": next_tlid, - "tracklist.next_index": context.core.tracklist.index(tlid=next_tlid.get()), - "playback.time_position": context.core.playback.get_time_position(), - } - pykka.get_all(futures.values()) + # Fire these off first, as other futures depends on them + f_current_tl_track = context.core.playback.get_current_tl_track() + f_next_tlid = context.core.tracklist.get_next_tlid() + + # ...and wait for them to complete + current_tl_track = f_current_tl_track.get() + current_tlid = current_tl_track.tlid if current_tl_track else None + current_track = current_tl_track.track if current_tl_track else None + next_tlid = f_next_tlid.get() + + # Then fire off the rest... + f_current_index = context.core.tracklist.index(tlid=current_tlid) + f_mixer_volume = context.core.mixer.get_volume() + f_next_index = context.core.tracklist.index(tlid=next_tlid) + f_playback_state = context.core.playback.get_state() + f_playback_time_position = context.core.playback.get_time_position() + f_tracklist_consume = context.core.tracklist.get_consume() + f_tracklist_length = context.core.tracklist.get_length() + f_tracklist_random = context.core.tracklist.get_random() + f_tracklist_repeat = context.core.tracklist.get_repeat() + f_tracklist_single = context.core.tracklist.get_single() + f_tracklist_version = context.core.tracklist.get_version() + + # ...and wait for them to complete + current_index = f_current_index.get() + mixer_volume = f_mixer_volume.get() + next_index = f_next_index.get() + playback_state = f_playback_state.get() + playback_time_position = f_playback_time_position.get() + tracklist_consume = f_tracklist_consume.get() + tracklist_length = f_tracklist_length.get() + tracklist_random = f_tracklist_random.get() + tracklist_repeat = f_tracklist_repeat.get() + tracklist_single = f_tracklist_single.get() + tracklist_version = f_tracklist_version.get() + result = [ - ("volume", _status_volume(futures)), - ("repeat", _status_repeat(futures)), - ("random", _status_random(futures)), - ("single", _status_single(futures)), - ("consume", _status_consume(futures)), - ("playlist", _status_playlist_version(futures)), - ("playlistlength", _status_playlist_length(futures)), - ("xfade", _status_xfade(futures)), - ("state", _status_state(futures)), + ("volume", mixer_volume if mixer_volume is not None else -1), + ("repeat", int(tracklist_repeat)), + ("random", int(tracklist_random)), + ("single", int(tracklist_single)), + ("consume", int(tracklist_consume)), + ("playlist", tracklist_version), + ("playlistlength", tracklist_length), + ("xfade", 0), # Not supported + ("state", _status_state(playback_state)), ] - if futures["playback.current_tl_track"].get() is not None: - result.append(("song", _status_songpos(futures))) - result.append(("songid", _status_songid(futures))) - if futures["tracklist.next_tlid"].get() is not None: - result.append(("nextsong", _status_nextsongpos(futures))) - result.append(("nextsongid", _status_nextsongid(futures))) - if futures["playback.state"].get() in ( - PlaybackState.PLAYING, - PlaybackState.PAUSED, + if current_tlid is not None and current_index is not None: + result.append(("song", current_index)) + result.append(("songid", current_tlid)) + if next_tlid is not None and next_index is not None: + result.append(("nextsong", next_index)) + result.append(("nextsongid", next_tlid)) + if ( + playback_state in (PlaybackState.PLAYING, PlaybackState.PAUSED) + and current_track is not None ): - result.append(("time", _status_time(futures))) - result.append(("elapsed", _status_time_elapsed(futures))) - result.append(("bitrate", _status_bitrate(futures))) + result.append(("time", _status_time(playback_time_position, current_track))) + result.append(("elapsed", _status_time_elapsed(playback_time_position))) + result.append(("bitrate", current_track.bitrate or 0)) return result -def _status_bitrate(futures): - current_tl_track = futures["playback.current_tl_track"].get() - if current_tl_track is None: - return 0 - if current_tl_track.track.bitrate is None: - return 0 - return current_tl_track.track.bitrate - - -def _status_consume(futures): - return int(futures["tracklist.consume"].get()) - - -def _status_playlist_length(futures): - return futures["tracklist.length"].get() - - -def _status_playlist_version(futures): - return futures["tracklist.version"].get() - - -def _status_random(futures): - return int(futures["tracklist.random"].get()) - - -def _status_repeat(futures): - return int(futures["tracklist.repeat"].get()) - - -def _status_single(futures): - return int(futures["tracklist.single"].get()) - - -def _status_songid(futures): - current_tl_track = futures["playback.current_tl_track"].get() - if current_tl_track is not None: - return current_tl_track.tlid - return _status_songpos(futures) - - -def _status_songpos(futures): - return futures["tracklist.index"].get() - - -def _status_nextsongid(futures): - return futures["tracklist.next_tlid"].get() - - -def _status_nextsongpos(futures): - return futures["tracklist.next_index"].get() - - -def _status_state(futures): - match futures["playback.state"].get(): +def _status_state(playback_state: PlaybackState) -> str: + match playback_state: case PlaybackState.PLAYING: return "play" case PlaybackState.STOPPED: return "stop" case PlaybackState.PAUSED: return "pause" - case _: - return None -def _status_time(futures): - position = futures["playback.time_position"].get() // 1000 - total = _status_time_total(futures) // 1000 +def _status_time(playback_time_position: DurationMs, current_track: Track) -> str: + position = playback_time_position // 1000 + total = (current_track.length or 0) // 1000 return f"{position:d}:{total:d}" -def _status_time_elapsed(futures): - elapsed = futures["playback.time_position"].get() / 1000.0 +def _status_time_elapsed(playback_time_position: DurationMs) -> str: + elapsed = playback_time_position / 1000.0 return f"{elapsed:.3f}" - - -def _status_time_total(futures): - current_tl_track = futures["playback.current_tl_track"].get() - if current_tl_track is None or current_tl_track.track.length is None: - return 0 - return current_tl_track.track.length - - -def _status_volume(futures): - volume = futures["mixer.volume"].get() - if volume is None: - return -1 - return volume - - -def _status_xfade(futures): - return 0 # Not supported diff --git a/src/mopidy_mpd/protocol/stickers.py b/src/mopidy_mpd/protocol/stickers.py index b67ecf6..6808d51 100644 --- a/src/mopidy_mpd/protocol/stickers.py +++ b/src/mopidy_mpd/protocol/stickers.py @@ -1,8 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Never + from mopidy_mpd import exceptions, protocol +if TYPE_CHECKING: + from mopidy.types import Uri + + from mopidy_mpd.dispatcher import MpdContext + @protocol.commands.add("sticker", list_command=False) -def sticker(context, action, field, uri, name=None, value=None): # noqa: PLR0913 +def sticker( # noqa: PLR0913 + context: MpdContext, + action: str, + field: str, + uri: Uri, + name: str | None = None, + value: str | None = None, +) -> Never: """ *musicpd.org, sticker section:* diff --git a/src/mopidy_mpd/protocol/stored_playlists.py b/src/mopidy_mpd/protocol/stored_playlists.py index 009f78d..c89f4e2 100644 --- a/src/mopidy_mpd/protocol/stored_playlists.py +++ b/src/mopidy_mpd/protocol/stored_playlists.py @@ -1,19 +1,47 @@ +from __future__ import annotations + import datetime import logging import re -import urllib +from typing import TYPE_CHECKING, Literal, overload +from urllib.parse import urlparse + +from mopidy.types import Uri, UriScheme from mopidy_mpd import exceptions, protocol, translator +if TYPE_CHECKING: + from collections.abc import Iterable + + from mopidy.models import Playlist, Track + + from mopidy_mpd.dispatcher import MpdContext + logger = logging.getLogger(__name__) -def _check_playlist_name(name): +def _check_playlist_name(name: str) -> None: if re.search("[/\n\r]", name): raise exceptions.MpdInvalidPlaylistNameError -def _get_playlist(context, name, *, must_exist=True): +@overload +def _get_playlist( + context: MpdContext, name: str, *, must_exist: Literal[True] +) -> Playlist: + ... + + +@overload +def _get_playlist( + context: MpdContext, name: str, *, must_exist: Literal[False] +) -> Playlist | None: + ... + + +def _get_playlist( + context: MpdContext, name: str, *, must_exist: bool +) -> Playlist | None: playlist = None uri = context.lookup_playlist_uri_from_name(name) if uri: @@ -24,7 +52,7 @@ def _get_playlist(context, name, *, must_exist=True): @protocol.commands.add("listplaylist") -def listplaylist(context, name): +def listplaylist(context: MpdContext, name: str) -> protocol.Result: """ *musicpd.org, stored playlists section:* @@ -38,12 +66,12 @@ def listplaylist(context, name): file: relative/path/to/file2.ogg file: relative/path/to/file3.mp3 """ - playlist = _get_playlist(context, name) - return [f"file: {track.uri}" for track in playlist.tracks] + playlist = _get_playlist(context, name, must_exist=True) + return [("file", track.uri) for track in playlist.tracks] @protocol.commands.add("listplaylistinfo") -def listplaylistinfo(context, name): +def listplaylistinfo(context: MpdContext, name: str) -> protocol.Result: """ *musicpd.org, stored playlists section:* @@ -56,7 +84,7 @@ def listplaylistinfo(context, name): Standard track listing, with fields: file, Time, Title, Date, Album, Artist, Track """ - playlist = _get_playlist(context, name) + playlist = _get_playlist(context, name, must_exist=True) track_uris = [track.uri for track in playlist.tracks] tracks_map = context.core.library.lookup(uris=track_uris).get() tracks = [] @@ -67,7 +95,7 @@ def listplaylistinfo(context, name): @protocol.commands.add("listplaylists") -def listplaylists(context): +def listplaylists(context: MpdContext) -> protocol.Result: """ *musicpd.org, stored playlists section:* @@ -104,7 +132,7 @@ def listplaylists(context): # TODO: move to translators? -def _get_last_modified(last_modified=None): +def _get_last_modified(last_modified: int | None = None) -> str: """Formats last modified timestamp of a playlist for MPD. Time in UTC with second precision, formatted in the ISO 8601 format, with @@ -122,7 +150,11 @@ def _get_last_modified(last_modified=None): @protocol.commands.add("load", playlist_slice=protocol.RANGE) -def load(context, name, playlist_slice=DEFAULT_PLAYLIST_SLICE): +def load( + context: MpdContext, + name: str, + playlist_slice: slice = DEFAULT_PLAYLIST_SLICE, +) -> None: """ *musicpd.org, stored playlists section:* @@ -143,13 +175,13 @@ def load(context, name, playlist_slice=DEFAULT_PLAYLIST_SLICE): - MPD 0.17.1 does not fail if the specified range is outside the playlist, in either or both ends. """ - playlist = _get_playlist(context, name) + playlist = _get_playlist(context, name, must_exist=True) track_uris = [track.uri for track in playlist.tracks[playlist_slice]] context.core.tracklist.add(uris=track_uris).get() @protocol.commands.add("playlistadd") -def playlistadd(context, name, track_uri): +def playlistadd(context: MpdContext, name: str, track_uri: Uri) -> None: """ *musicpd.org, stored playlists section:* @@ -177,18 +209,18 @@ def playlistadd(context, name, track_uri): ) saved_playlist = context.core.playlists.save(new_playlist).get() if saved_playlist is None: - playlist_scheme = urllib.parse.urlparse(old_playlist.uri).scheme - uri_scheme = urllib.parse.urlparse(track_uri).scheme + playlist_scheme = UriScheme(urlparse(old_playlist.uri).scheme) + uri_scheme = UriScheme(urlparse(track_uri).scheme) raise exceptions.MpdInvalidTrackForPlaylistError( playlist_scheme, uri_scheme ) -def _create_playlist(context, name, tracks): +def _create_playlist(context: MpdContext, name: str, tracks: Iterable[Track]) -> None: """ Creates new playlist using backend appropriate for the given tracks """ - uri_schemes = {urllib.parse.urlparse(t.uri).scheme for t in tracks} + uri_schemes = {urlparse(t.uri).scheme for t in tracks} for scheme in uri_schemes: new_playlist = context.core.playlists.create(name, scheme).get() if new_playlist is None: @@ -209,12 +241,12 @@ def _create_playlist(context, name, tracks): new_playlist = new_playlist.replace(tracks=tracks) saved_playlist = context.core.playlists.save(new_playlist).get() if saved_playlist is None: - uri_scheme = urllib.parse.urlparse(new_playlist.uri).scheme + uri_scheme = UriScheme(urlparse(new_playlist.uri).scheme) raise exceptions.MpdFailedToSavePlaylistError(uri_scheme) @protocol.commands.add("playlistclear") -def playlistclear(context, name): +def playlistclear(context: MpdContext, name: str) -> None: """ *musicpd.org, stored playlists section:* @@ -229,16 +261,19 @@ def playlistclear(context, name): if not playlist: playlist = context.core.playlists.create(name).get() + # TODO(type): Handle the failure to create a playlist + assert playlist + # Just replace tracks with empty list and save playlist = playlist.replace(tracks=[]) if context.core.playlists.save(playlist).get() is None: raise exceptions.MpdFailedToSavePlaylistError( - urllib.parse.urlparse(playlist.uri).scheme + UriScheme(urlparse(playlist.uri).scheme) ) @protocol.commands.add("playlistdelete", songpos=protocol.UINT) -def playlistdelete(context, name, songpos): +def playlistdelete(context: MpdContext, name: str, songpos: int) -> None: """ *musicpd.org, stored playlists section:* @@ -247,7 +282,7 @@ def playlistdelete(context, name, songpos): Deletes ``SONGPOS`` from the playlist ``NAME.m3u``. """ _check_playlist_name(name) - playlist = _get_playlist(context, name) + playlist = _get_playlist(context, name, must_exist=True) try: # Convert tracks to list and remove requested @@ -260,13 +295,11 @@ def playlistdelete(context, name, songpos): playlist = playlist.replace(tracks=tracks) saved_playlist = context.core.playlists.save(playlist).get() if saved_playlist is None: - raise exceptions.MpdFailedToSavePlaylistError( - urllib.parse.urlparse(playlist.uri).scheme - ) + raise exceptions.MpdFailedToSavePlaylistError(urlparse(playlist.uri).scheme) @protocol.commands.add("playlistmove", from_pos=protocol.UINT, to_pos=protocol.UINT) -def playlistmove(context, name, from_pos, to_pos): +def playlistmove(context: MpdContext, name: str, from_pos: int, to_pos: int) -> None: """ *musicpd.org, stored playlists section:* @@ -285,7 +318,7 @@ def playlistmove(context, name, from_pos, to_pos): return _check_playlist_name(name) - playlist = _get_playlist(context, name) + playlist = _get_playlist(context, name, must_exist=True) if from_pos == to_pos: return # Nothing to do @@ -301,13 +334,11 @@ def playlistmove(context, name, from_pos, to_pos): playlist = playlist.replace(tracks=tracks) saved_playlist = context.core.playlists.save(playlist).get() if saved_playlist is None: - raise exceptions.MpdFailedToSavePlaylistError( - urllib.parse.urlparse(playlist.uri).scheme - ) + raise exceptions.MpdFailedToSavePlaylistError(urlparse(playlist.uri).scheme) @protocol.commands.add("rename") -def rename(context, old_name, new_name): +def rename(context: MpdContext, old_name: str, new_name: str) -> None: """ *musicpd.org, stored playlists section:* @@ -318,14 +349,14 @@ def rename(context, old_name, new_name): _check_playlist_name(old_name) _check_playlist_name(new_name) - old_playlist = _get_playlist(context, old_name) + old_playlist = _get_playlist(context, old_name, must_exist=True) if _get_playlist(context, new_name, must_exist=False): raise exceptions.MpdExistError("Playlist already exists") # TODO: should we purge the mapping in an else? # Create copy of the playlist and remove original - uri_scheme = urllib.parse.urlparse(old_playlist.uri).scheme + uri_scheme = UriScheme(urlparse(old_playlist.uri).scheme) new_playlist = context.core.playlists.create(new_name, uri_scheme).get() new_playlist = new_playlist.replace(tracks=old_playlist.tracks) saved_playlist = context.core.playlists.save(new_playlist).get() @@ -336,7 +367,7 @@ def rename(context, old_name, new_name): @protocol.commands.add("rm") -def rm(context, name): +def rm(context: MpdContext, name: str) -> None: """ *musicpd.org, stored playlists section:* @@ -352,7 +383,7 @@ def rm(context, name): @protocol.commands.add("save") -def save(context, name): +def save(context: MpdContext, name: str) -> None: """ *musicpd.org, stored playlists section:* @@ -373,5 +404,5 @@ def save(context, name): saved_playlist = context.core.playlists.save(new_playlist).get() if saved_playlist is None: raise exceptions.MpdFailedToSavePlaylistError( - urllib.parse.urlparse(playlist.uri).scheme + UriScheme(urlparse(playlist.uri).scheme) ) diff --git a/src/mopidy_mpd/session.py b/src/mopidy_mpd/session.py index c3f0082..9c70d5c 100644 --- a/src/mopidy_mpd/session.py +++ b/src/mopidy_mpd/session.py @@ -1,13 +1,21 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING, NoReturn from mopidy_mpd import dispatcher, formatting, network, protocol from mopidy_mpd.protocol import tagtype_list +if TYPE_CHECKING: + from mopidy.core import CoreProxy + from mopidy.ext import Config + + from mopidy_mpd.uri_mapper import MpdUriMapper + logger = logging.getLogger(__name__) class MpdSession(network.LineProtocol): - """ The MPD client session. Keeps track of a single client session. Any requests from the client is passed on to the MPD request dispatcher. @@ -15,20 +23,28 @@ class MpdSession(network.LineProtocol): terminator = protocol.LINE_TERMINATOR encoding = protocol.ENCODING - delimiter = rb"\r?\n" - def __init__(self, connection, config=None, core=None, uri_map=None): + def __init__( + self, + connection: network.Connection, + config: Config | None = None, + core: CoreProxy | None = None, + uri_map: MpdUriMapper | None = None, + ) -> None: super().__init__(connection) self.dispatcher = dispatcher.MpdDispatcher( - session=self, config=config, core=core, uri_map=uri_map + session=self, + config=config, + core=core, + uri_map=uri_map, ) self.tagtypes = tagtype_list.TAGTYPE_LIST.copy() - def on_start(self): + def on_start(self) -> None: logger.info("New MPD connection from %s", self.connection) self.send_lines([f"OK MPD {protocol.VERSION}"]) - def on_line_received(self, line): + def on_line_received(self, line: str) -> None: logger.debug("Request from %s: %s", self.connection, line) # All mpd commands start with a lowercase alphabetic character @@ -50,10 +66,10 @@ def on_line_received(self, line): self.send_lines(response) - def on_event(self, subsystem): + def on_event(self, subsystem: str) -> None: self.dispatcher.handle_idle(subsystem) - def decode(self, line): + def decode(self, line: bytes) -> str: try: return super().decode(line) except ValueError: @@ -62,6 +78,7 @@ def decode(self, line): "supplied by client was not valid." ) self.stop() + return NoReturn - def close(self): + def close(self) -> None: self.stop() diff --git a/src/mopidy_mpd/tokenize.py b/src/mopidy_mpd/tokenize.py index 84fe81d..9e000b4 100644 --- a/src/mopidy_mpd/tokenize.py +++ b/src/mopidy_mpd/tokenize.py @@ -44,7 +44,7 @@ UNESCAPE_RE = re.compile(r"\\(.)") # Backslash escapes any following char. -def split(line): +def split(line: str) -> list[str]: """Splits a line into tokens using same rules as MPD. - Lines may not start with whitespace @@ -71,7 +71,7 @@ def split(line): if whitespace: raise exceptions.MpdUnknownError("Letter expected") - result = [command] + result: list[str] = [command] while remainder: match = PARAM_RE.match(remainder) if not match: @@ -82,7 +82,7 @@ def split(line): return result -def _determine_error_message(remainder): +def _determine_error_message(remainder: str) -> str: """Helper to emulate MPD errors.""" # Following checks are simply to match MPD error messages: match = BAD_QUOTED_PARAM_RE.match(remainder) diff --git a/src/mopidy_mpd/translator.py b/src/mopidy_mpd/translator.py index 52aeece..46b4899 100644 --- a/src/mopidy_mpd/translator.py +++ b/src/mopidy_mpd/translator.py @@ -1,35 +1,47 @@ +from __future__ import annotations + import datetime import logging +from typing import TYPE_CHECKING -from mopidy.models import TlTrack +from mopidy.models import Album, Artist, Playlist, TlTrack, Track from mopidy_mpd.protocol import tagtype_list +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from mopidy_mpd import protocol + logger = logging.getLogger(__name__) -def track_to_mpd_format(track, tagtypes, *, position=None, stream_title=None): # noqa: C901, PLR0912 +def track_to_mpd_format( # noqa: C901, PLR0912 + obj: Track | TlTrack, + tagtypes: set[str], + *, + position: int | None = None, + stream_title: str | None = None, +) -> protocol.ResultList: """ Format track for output to MPD client. :param track: the track - :type track: :class:`mopidy.models.Track` or :class:`mopidy.models.TlTrack` :param position: track's position in playlist - :type position: integer :param stream_title: the current streams title - :type position: string - :rtype: list of two-tuples """ - if isinstance(track, TlTrack): - (tlid, track) = track - else: - (tlid, track) = (None, track) # noqa: PLW0127 + match obj: + case TlTrack() as tl_track: + tlid = tl_track.tlid + track = tl_track.track + case Track() as track: + tlid = None if not track.uri: logger.warning("Ignoring track without uri") return [] - result = [ + result: protocol.Result = [ ("file", track.uri), ("Time", track.length and (track.length // 1000) or 0), *multi_tag_list(track.artists, "name", "Artist"), @@ -95,17 +107,18 @@ def track_to_mpd_format(track, tagtypes, *, position=None, stream_title=None): return [element for element in result if _has_value(tagtypes, *element)] -def _has_value(tagtypes, tagtype, value): +def _has_value( + tagtypes: set[str], + tagtype: str, + value: str | int, +) -> bool: """ Determine whether to add the tagtype to the output or not. The tagtype must be in the list of tagtypes configured for the client. :param tagtypes: the MPD tagtypes configured for the client - :type tagtypes: set of strings :param tagtype: the MPD tagtype - :type tagtype: string :param value: the tag value - :rtype: bool """ if tagtype in tagtype_list.TAGTYPE_LIST: if tagtype not in tagtypes: @@ -114,16 +127,15 @@ def _has_value(tagtypes, tagtype, value): return True -def concat_multi_values(models, attribute): +def concat_multi_values( + models: Iterable[Artist | Album | Track], + attribute: str, +) -> str: """ Format Mopidy model values for output to MPD client. :param models: the models - :type models: array of :class:`mopidy.models.Artist`, - :class:`mopidy.models.Album` or :class:`mopidy.models.Track` :param attribute: the attribute to use - :type attribute: string - :rtype: string """ # Don't sort the values. MPD doesn't appear to (or if it does it's not # strict alphabetical). If we just use them in the order in which they come @@ -133,60 +145,68 @@ def concat_multi_values(models, attribute): ) -def multi_tag_list(objects, attribute, tag): +def multi_tag_list( + models: Iterable[Artist | Album | Track], + attribute: str, + tag: str, +) -> protocol.ResultList: """ Format multiple objects for output to MPD client in a list with one tag per value. - :param objects: the model objects - :type objects: array of :class:`mopidy.models.Artist`, - :class:`mopidy.models.Album`, or :class:`mopidy.models.Track` + :param models: the model objects :param attribute: the attribute to use - :type attribute: string :param tag: the name of the tag - :type tag: string - :rtype: list of tuples of string and attribute value """ return [ (tag, getattr(obj, attribute)) - for obj in objects + for obj in models if getattr(obj, attribute, None) is not None ] -def tracks_to_mpd_format(tracks, tagtypes, *, start=0, end=None): +def tracks_to_mpd_format( + tracks: Sequence[Track | TlTrack], + tagtypes: set[str], + *, + start: int = 0, + end: int | None = None, +) -> protocol.ResultList: """ Format list of tracks for output to MPD client. Optionally limit output to the slice ``[start:end]`` of the list. :param tracks: the tracks - :type tracks: list of :class:`mopidy.models.Track` or - :class:`mopidy.models.TlTrack` :param start: position of first track to include in output - :type start: int (positive or negative) :param end: position after last track to include in output - :type end: int (positive or negative) or :class:`None` for end of list - :rtype: list of lists of two-tuples """ if end is None: end = len(tracks) tracks = tracks[start:end] positions = range(start, end) assert len(tracks) == len(positions) - result = [] - for track, position in zip(tracks, positions, strict=False): + result: protocol.ResultList = [] + for track, position in zip(tracks, positions, strict=True): formatted_track = track_to_mpd_format(track, tagtypes, position=position) if formatted_track: - result.append(formatted_track) + result.extend(formatted_track) return result -def playlist_to_mpd_format(playlist, tagtypes, *, start=0, end=None): +def playlist_to_mpd_format( + playlist: Playlist, + tagtypes: set[str], + *, + start: int = 0, + end: int | None = None, +) -> protocol.ResultList: """ Format playlist for output to MPD client. - Arguments as for :func:`tracks_to_mpd_format`, except the first one. + :param playlist: the playlist + :param start: position of first track to include in output + :param end: position after last track to include in output """ - return tracks_to_mpd_format(playlist.tracks, tagtypes, start=start, end=end) + return tracks_to_mpd_format(list(playlist.tracks), tagtypes, start=start, end=end) diff --git a/src/mopidy_mpd/types.py b/src/mopidy_mpd/types.py new file mode 100644 index 0000000..6929b47 --- /dev/null +++ b/src/mopidy_mpd/types.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import TypeAlias, TypedDict + + +class MpdConfig(TypedDict): + hostname: str + port: int + password: str | None + max_connections: int + connection_timeout: int + zeroconf: str + command_blacklist: list[str] + default_playlist_scheme: str + + +SocketAddress: TypeAlias = tuple[str, int | None] diff --git a/src/mopidy_mpd/uri_mapper.py b/src/mopidy_mpd/uri_mapper.py index 407cdc9..20e8f2f 100644 --- a/src/mopidy_mpd/uri_mapper.py +++ b/src/mopidy_mpd/uri_mapper.py @@ -1,28 +1,34 @@ +from __future__ import annotations + import re +from typing import TYPE_CHECKING -# TOOD: refactor this into a generic mapper that does not know about browse -# or playlists and then use one instance for each case? +if TYPE_CHECKING: + from mopidy.core import CoreProxy + from mopidy.types import Uri class MpdUriMapper: - """ Maintains the mappings between uniquified MPD names and URIs. """ - #: The Mopidy core API. An instance of :class:`mopidy.core.Core`. - core = None + # TODO: refactor this into a generic mapper that does not know about browse + # or playlists and then use one instance for each case? + + #: The Mopidy core API. + core: CoreProxy _invalid_browse_chars = re.compile(r"[\n\r]") _invalid_playlist_chars = re.compile(r"[/]") - def __init__(self, core=None): + def __init__(self, core: CoreProxy) -> None: self.core = core - self._uri_from_name = {} - self._browse_name_from_uri = {} - self._playlist_name_from_uri = {} + self._uri_from_name: dict[str, Uri | None] = {} + self._browse_name_from_uri: dict[Uri | None, str] = {} + self._playlist_name_from_uri: dict[Uri | None, str] = {} - def _create_unique_name(self, name, uri): + def _create_unique_name(self, name: str, uri: Uri | None) -> str: stripped_name = self._invalid_browse_chars.sub(" ", name) name = stripped_name i = 2 @@ -33,7 +39,7 @@ def _create_unique_name(self, name, uri): i += 1 return name - def insert(self, name, uri, *, playlist=False): + def insert(self, name: str, uri: Uri | None, *, playlist: bool = False) -> str: """ Create a unique and MPD compatible name that maps to the given URI. """ @@ -45,13 +51,13 @@ def insert(self, name, uri, *, playlist=False): self._browse_name_from_uri[uri] = name return name - def uri_from_name(self, name): + def uri_from_name(self, name: str) -> Uri | None: """ - Return the uri for the given MPD name. + Return the URI for the given MPD name. """ return self._uri_from_name.get(name) - def refresh_playlists_mapping(self): + def refresh_playlists_mapping(self) -> None: """ Maintain map between playlists and unique playlist names to be used by MPD. @@ -65,7 +71,7 @@ def refresh_playlists_mapping(self): name = self._invalid_playlist_chars.sub("|", playlist_ref.name) self.insert(name, playlist_ref.uri, playlist=True) - def playlist_uri_from_name(self, name): + def playlist_uri_from_name(self, name: str) -> Uri | None: """ Helper function to retrieve a playlist URI from its unique MPD name. """ @@ -73,7 +79,7 @@ def playlist_uri_from_name(self, name): self.refresh_playlists_mapping() return self._uri_from_name.get(name) - def playlist_name_from_uri(self, uri): + def playlist_name_from_uri(self, uri: Uri) -> str: """ Helper function to retrieve the unique MPD playlist name from its URI. """ diff --git a/tests/network/test_lineprotocol.py b/tests/network/test_lineprotocol.py index eb87d2c..ca41e5b 100644 --- a/tests/network/test_lineprotocol.py +++ b/tests/network/test_lineprotocol.py @@ -22,20 +22,11 @@ def prepare_on_receive_test(self, return_value=None): self.mock.parse_lines.return_value = return_value or [] def test_init_stores_values_in_attributes(self): - delimiter = re.compile(network.LineProtocol.terminator) network.LineProtocol.__init__(self.mock, sentinel.connection) assert sentinel.connection == self.mock.connection assert self.mock.recv_buffer == b"" - assert delimiter == self.mock.delimiter assert not self.mock.prevent_timeout - def test_init_compiles_delimiter(self): - self.mock.delimiter = "\r?\n" - delimiter = re.compile("\r?\n") - - network.LineProtocol.__init__(self.mock, sentinel.connection) - assert delimiter == self.mock.delimiter - def test_on_receive_close_calls_stop(self): self.prepare_on_receive_test() diff --git a/tests/network/test_server.py b/tests/network/test_server.py index 5f90365..68e2b7c 100644 --- a/tests/network/test_server.py +++ b/tests/network/test_server.py @@ -204,12 +204,15 @@ def test_handle_connection_exceeded_connections(self): def test_accept_connection(self): sock = Mock(spec=socket.socket) connected_sock = Mock(spec=socket.socket) - sock.accept.return_value = (connected_sock, sentinel.addr) + sock.accept.return_value = ( + connected_sock, + (sentinel.host, sentinel.port, sentinel.flow, sentinel.scope), + ) self.mock.server_socket = sock sock, addr = network.Server.accept_connection(self.mock) - assert connected_sock == sock - assert sentinel.addr == addr + assert sock == connected_sock + assert addr == (sentinel.host, sentinel.port) def test_accept_connection_unix(self): sock = Mock(spec=socket.socket) diff --git a/tests/protocol/test_music_db.py b/tests/protocol/test_music_db.py index b8ef56b..be45e0d 100644 --- a/tests/protocol/test_music_db.py +++ b/tests/protocol/test_music_db.py @@ -9,24 +9,17 @@ # TODO: split into more modules for faster parallel tests? -class QueryFromMpdSearchFormatTest(unittest.TestCase): +class QueryForSearchTest(unittest.TestCase): def test_dates_are_extracted(self): - result = music_db._query_from_mpd_search_parameters( - ["Date", "1974-01-02", "Date", "1975"], music_db._SEARCH_MAPPING - ) - assert result["date"][0] == "1974-01-02" - assert result["date"][1] == "1975" + result = music_db._query_for_search(["Date", "1974-01-02", "Date", "1975"]) + assert result["date"] == ["1974-01-02", "1975"] def test_empty_value_is_ignored(self): - result = music_db._query_from_mpd_search_parameters( - ["Date", ""], music_db._SEARCH_MAPPING - ) + result = music_db._query_for_search(["Date", ""]) assert result == {} def test_whitespace_value_is_ignored(self): - result = music_db._query_from_mpd_search_parameters( - ["Date", " "], music_db._SEARCH_MAPPING - ) + result = music_db._query_for_search(["Date", " "]) assert result == {} # TODO Test more mappings diff --git a/tests/test_commands.py b/tests/test_commands.py index e9a7cdb..80d42c1 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,6 +1,7 @@ import unittest from mopidy_mpd import exceptions, protocol +from mopidy_mpd.dispatcher import MpdContext class TestConverts(unittest.TestCase): @@ -56,11 +57,11 @@ def setUp(self): def test_add_as_a_decorator(self): @self.commands.add("test") - def test(context): + def test(context: MpdContext): pass def test_register_second_command_to_same_name_fails(self): - def func(context): + def func(context: MpdContext): pass self.commands.add("foo")(func) @@ -70,40 +71,42 @@ def func(context): def test_function_only_takes_context_succeeds(self): sentinel = object() self.commands.add("bar")(lambda context: sentinel) - assert sentinel == self.commands.call(["bar"]) + assert sentinel == self.commands.call(context=None, tokens=["bar"]) def test_function_has_required_arg_succeeds(self): sentinel = object() self.commands.add("bar")(lambda context, required: sentinel) - assert sentinel == self.commands.call(["bar", "arg"]) + assert sentinel == self.commands.call(context=None, tokens=["bar", "arg"]) def test_function_has_optional_args_succeeds(self): sentinel = object() self.commands.add("bar")(lambda context, optional=None: sentinel) - assert sentinel == self.commands.call(["bar"]) - assert sentinel == self.commands.call(["bar", "arg"]) + assert sentinel == self.commands.call(context=None, tokens=["bar"]) + assert sentinel == self.commands.call(context=None, tokens=["bar", "arg"]) def test_function_has_required_and_optional_args_succeeds(self): sentinel = object() - def func(context, required, optional=None): + def func(context: MpdContext, required, optional=None): return sentinel self.commands.add("bar")(func) - assert sentinel == self.commands.call(["bar", "arg"]) - assert sentinel == self.commands.call(["bar", "arg", "arg"]) + assert sentinel == self.commands.call(context=None, tokens=["bar", "arg"]) + assert sentinel == self.commands.call( + context=None, tokens=["bar", "arg", "arg"] + ) def test_function_has_varargs_succeeds(self): sentinel, args = object(), [] self.commands.add("bar")(lambda context, *args: sentinel) for _ in range(10): - assert sentinel == self.commands.call(["bar", *args]) + assert sentinel == self.commands.call(context=None, tokens=["bar", *args]) args.append("test") def test_function_has_only_varags_succeeds(self): sentinel = object() self.commands.add("baz")(lambda *args: sentinel) - assert sentinel == self.commands.call(["baz"]) + assert sentinel == self.commands.call(context=None, tokens=["baz"]) def test_function_has_no_arguments_fails(self): with self.assertRaises(TypeError): @@ -112,7 +115,7 @@ def test_function_has_no_arguments_fails(self): def test_function_has_required_and_varargs_fails(self): with self.assertRaises(TypeError): - def func(context, required, *args): + def func(context: MpdContext, required, *args): pass self.commands.add("test")(func) @@ -120,7 +123,7 @@ def func(context, required, *args): def test_function_has_optional_and_varargs_fails(self): with self.assertRaises(TypeError): - def func(context, optional=None, *args): + def func(context: MpdContext, optional=None, *args): pass self.commands.add("test")(func) @@ -135,40 +138,43 @@ def test_call_chooses_correct_handler(self): self.commands.add("bar")(lambda context: sentinel2) self.commands.add("baz")(lambda context: sentinel3) - assert sentinel1 == self.commands.call(["foo"]) - assert sentinel2 == self.commands.call(["bar"]) - assert sentinel3 == self.commands.call(["baz"]) + assert sentinel1 == self.commands.call(context=None, tokens=["foo"]) + assert sentinel2 == self.commands.call(context=None, tokens=["bar"]) + assert sentinel3 == self.commands.call(context=None, tokens=["baz"]) def test_call_with_nonexistent_handler(self): with self.assertRaises(exceptions.MpdUnknownCommandError): - self.commands.call(["bar"]) + self.commands.call(context=None, tokens=["bar"]) def test_call_passes_context(self): sentinel = object() self.commands.add("foo")(lambda context: context) - assert sentinel == self.commands.call(["foo"], context=sentinel) + assert sentinel == self.commands.call(context=sentinel, tokens=["foo"]) def test_call_without_args_fails(self): with self.assertRaises(exceptions.MpdNoCommandError): - self.commands.call([]) + self.commands.call(context=None, tokens=[]) def test_call_passes_required_argument(self): self.commands.add("foo")(lambda context, required: required) - assert self.commands.call(["foo", "test123"]) == "test123" + assert self.commands.call(context=None, tokens=["foo", "test123"]) == "test123" def test_call_passes_optional_argument(self): sentinel = object() self.commands.add("foo")(lambda context, optional=sentinel: optional) - assert sentinel == self.commands.call(["foo"]) - assert self.commands.call(["foo", "test"]) == "test" + assert sentinel == self.commands.call(context=None, tokens=["foo"]) + assert self.commands.call(context=None, tokens=["foo", "test"]) == "test" def test_call_passes_required_and_optional_argument(self): - def func(context, required, optional=None): + def func(context: MpdContext, required, optional=None): return (required, optional) self.commands.add("foo")(func) - assert self.commands.call(["foo", "arg"]) == ("arg", None) - assert self.commands.call(["foo", "arg", "kwarg"]) == ("arg", "kwarg") + assert self.commands.call(context=None, tokens=["foo", "arg"]) == ("arg", None) + assert self.commands.call(context=None, tokens=["foo", "arg", "kwarg"]) == ( + "arg", + "kwarg", + ) def test_call_passes_varargs(self): self.commands.add("foo")(lambda context, *args: args) @@ -176,50 +182,50 @@ def test_call_passes_varargs(self): def test_call_incorrect_args(self): self.commands.add("foo")(lambda context: context) with self.assertRaises(exceptions.MpdArgError): - self.commands.call(["foo", "bar"]) + self.commands.call(context=None, tokens=["foo", "bar"]) self.commands.add("bar")(lambda context, required: context) with self.assertRaises(exceptions.MpdArgError): - self.commands.call(["bar", "bar", "baz"]) + self.commands.call(context=None, tokens=["bar", "bar", "baz"]) self.commands.add("baz")(lambda context, optional=None: context) with self.assertRaises(exceptions.MpdArgError): - self.commands.call(["baz", "bar", "baz"]) + self.commands.call(context=None, tokens=["baz", "bar", "baz"]) def test_validator_gets_applied_to_required_arg(self): sentinel = object() - def func(context, required): + def func(context: MpdContext, required): return required self.commands.add("test", required=lambda v: sentinel)(func) - assert sentinel == self.commands.call(["test", "foo"]) + assert sentinel == self.commands.call(context=None, tokens=["test", "foo"]) def test_validator_gets_applied_to_optional_arg(self): sentinel = object() - def func(context, optional=None): + def func(context: MpdContext, optional=None): return optional self.commands.add("foo", optional=lambda v: sentinel)(func) - assert sentinel == self.commands.call(["foo", "123"]) + assert sentinel == self.commands.call(context=None, tokens=["foo", "123"]) def test_validator_skips_optional_default(self): sentinel = object() - def func(context, optional=sentinel): + def func(context: MpdContext, optional=sentinel): return optional self.commands.add("foo", optional=lambda v: None)(func) - assert sentinel == self.commands.call(["foo"]) + assert sentinel == self.commands.call(context=None, tokens=["foo"]) def test_validator_applied_to_non_existent_arg_fails(self): self.commands.add("foo")(lambda context, arg: arg) with self.assertRaises(TypeError): - def func(context, wrong_arg): + def func(context: MpdContext, wrong_arg): return wrong_arg self.commands.add("bar", arg=lambda v: v)(func) @@ -228,7 +234,7 @@ def test_validator_called_context_fails(self): return # TODO: how to handle this with self.assertRaises(TypeError): - def func(context): + def func(context: MpdContext): pass self.commands.add("bar", context=lambda v: v)(func) @@ -237,19 +243,19 @@ def test_validator_value_error_is_converted(self): def validdate(value): raise ValueError - def func(context, arg): + def func(context: MpdContext, arg): pass self.commands.add("bar", arg=validdate)(func) with self.assertRaises(exceptions.MpdArgError): - self.commands.call(["bar", "test"]) + self.commands.call(context=None, tokens=["bar", "test"]) def test_auth_required_gets_stored(self): - def func1(context): + def func1(context: MpdContext): pass - def func2(context): + def func2(context: MpdContext): pass self.commands.add("foo")(func1) @@ -259,10 +265,10 @@ def func2(context): assert not self.commands.handlers["bar"].auth_required def test_list_command_gets_stored(self): - def func1(context): + def func1(context: MpdContext): pass - def func2(context): + def func2(context: MpdContext): pass self.commands.add("foo")(func1) diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 1145601..c3b542c 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -1,8 +1,10 @@ import unittest +from typing import cast import pykka import pytest -from mopidy import core +from mopidy.backend import BackendProxy +from mopidy.core import Core, CoreProxy from mopidy.models import Ref from mopidy_mpd.dispatcher import MpdContext, MpdDispatcher from mopidy_mpd.exceptions import MpdAckError @@ -15,9 +17,10 @@ class MpdDispatcherTest(unittest.TestCase): def setUp(self): config = {"mpd": {"password": None, "command_blacklist": ["disabled"]}} self.backend = dummy_backend.create_proxy() - self.dispatcher = MpdDispatcher(config=config) - - self.core = core.Core.start(config=None, backends=[self.backend]).proxy() + self.core = cast( + CoreProxy, Core.start(config=None, backends=[self.backend]).proxy() + ) + self.dispatcher = MpdDispatcher(config=config, core=self.core) def tearDown(self): pykka.ActorRegistry.stop_all() @@ -44,30 +47,40 @@ def test_handling_blacklisted_command(self): @pytest.fixture() -def a_track(): +def a_track() -> Ref: return Ref.track(uri="dummy:/a", name="a") @pytest.fixture() -def b_track(): +def b_track() -> Ref: return Ref.track(uri="dummy:/foo/b", name="b") @pytest.fixture() -def backend_to_browse(a_track, b_track): - backend = dummy_backend.create_proxy() +def backend_to_browse(a_track: Ref, b_track: Ref) -> BackendProxy: + backend = cast(BackendProxy, dummy_backend.create_proxy()) backend.library.dummy_browse_result = { - "dummy:/": [a_track, Ref.directory(uri="dummy:/foo", name="foo")], - "dummy:/foo": [b_track], + "dummy:/": [ + a_track, + Ref.directory(uri="dummy:/foo", name="foo"), + ], + "dummy:/foo": [ + b_track, + ], } return backend @pytest.fixture() -def mpd_context(backend_to_browse): - mopidy_core = core.Core.start(config=None, backends=[backend_to_browse]).proxy() - uri_map = MpdUriMapper(mopidy_core) - return MpdContext(None, core=mopidy_core, uri_map=uri_map) +def mpd_context(backend_to_browse: BackendProxy) -> MpdContext: + core = cast( + CoreProxy, + Core.start(config=None, backends=[backend_to_browse]).proxy(), + ) + return MpdContext( + core=core, + uri_map=MpdUriMapper(core), + ) class TestMpdContext: diff --git a/tests/test_translator.py b/tests/test_translator.py index fd99a21..6828c38 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -207,8 +207,20 @@ def test_mpd_format(self): Track(uri="baz", track_no=3), ] ) + result = translator.playlist_to_mpd_format(playlist, tagtype_list.TAGTYPE_LIST) - assert len(result) == 3 + + assert result == [ + ("file", "foo"), + ("Time", 0), + ("Track", 1), + ("file", "bàr"), + ("Time", 0), + ("Track", 2), + ("file", "baz"), + ("Time", 0), + ("Track", 3), + ] def test_mpd_format_with_range(self): playlist = Playlist( @@ -218,8 +230,9 @@ def test_mpd_format_with_range(self): Track(uri="baz", track_no=3), ] ) + result = translator.playlist_to_mpd_format( playlist, tagtype_list.TAGTYPE_LIST, start=1, end=2 ) - assert len(result) == 1 - assert dict(result[0])["Track"] == 2 + + assert result == [("file", "bàr"), ("Time", 0), ("Track", 2)] From adee48b0afc3041e64f96801452ae8427d079474 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Fri, 23 Feb 2024 14:43:32 +0100 Subject: [PATCH 04/19] Set up pyright in basic checking mode --- .github/workflows/ci.yml | 3 +++ pyproject.toml | 10 ++++++++++ tox.ini | 5 ++++- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1a5e5e3..444f36e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,6 +19,9 @@ jobs: python: "3.12" tox: py312 coverage: true + - name: "Lint: pyright" + python: "3.12" + tox: pyright - name: "Lint: ruff lint" python: "3.12" tox: ruff-lint diff --git a/pyproject.toml b/pyproject.toml index 24a7e02..9315e79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,16 @@ Issues = "https://github.com/mopidy/mopidy-mpd/issues" mpd = "mopidy_mpd:Extension" +[tool.pyright] +pythonVersion = "3.11" +# Use venv from parent directory, to share it with any extensions: +venvPath = "../" +venv = ".venv" +typeCheckingMode = "basic" +# Already covered by flake8-self: +reportPrivateImportUsage = false + + [tool.ruff] target-version = "py311" diff --git a/tox.ini b/tox.ini index f3e9390..eaa8633 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py311, py312, check-manifest, flake8 +envlist = py311, py312, typing, ruff-lint, ruff-format [testenv] sitepackages = true @@ -10,6 +10,9 @@ commands = --cov=mopidy_mpd --cov-report=term-missing \ {posargs} +[testenv:pyright] +deps = .[typing] +commands = python -m pyright src [testenv:ruff-lint] deps = .[lint] From 40d02d4385d6c06ba452dc39eaa3534185078e1a Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sat, 24 Feb 2024 02:20:33 +0100 Subject: [PATCH 05/19] Fix pyright's basic warnings --- pyproject.toml | 7 ++++- src/mopidy_mpd/dispatcher.py | 23 +++++++--------- src/mopidy_mpd/network.py | 2 +- src/mopidy_mpd/protocol/connection.py | 4 +-- src/mopidy_mpd/protocol/current_playlist.py | 15 ++++++----- src/mopidy_mpd/protocol/music_db.py | 9 ++++--- src/mopidy_mpd/protocol/playback.py | 4 +-- src/mopidy_mpd/protocol/status.py | 5 ++-- src/mopidy_mpd/protocol/stored_playlists.py | 25 +++++++++++------- src/mopidy_mpd/session.py | 7 ++--- src/mopidy_mpd/translator.py | 12 ++++++--- tests/protocol/__init__.py | 22 +++++++++------- tests/test_dispatcher.py | 11 +++++--- tests/test_session.py | 12 +++++++-- tests/test_status.py | 29 ++++++++++++++------- 15 files changed, 112 insertions(+), 75 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9315e79..2002f57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,12 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Topic :: Multimedia :: Sound/Audio :: Players", ] -dependencies = ["mopidy >= 4.0.0a1", "pykka >= 4.0", "setuptools >= 66"] +dependencies = [ + "mopidy >= 4.0.0a1", + "pygobject >= 3.42", + "pykka >= 4.0", + "setuptools >= 66", +] [project.optional-dependencies] lint = ["ruff"] diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index cf71a7f..5c73590 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -17,6 +17,7 @@ import pykka from mopidy_mpd import exceptions, protocol, tokenize, types +from mopidy_mpd.uri_mapper import MpdUriMapper if TYPE_CHECKING: from mopidy.core import CoreProxy @@ -25,7 +26,6 @@ from mopidy.types import Uri from mopidy_mpd.session import MpdSession - from mopidy_mpd.uri_mapper import MpdUriMapper logger = logging.getLogger(__name__) @@ -52,7 +52,6 @@ def __init__( config: Config, core: CoreProxy, session: MpdSession, - uri_map: MpdUriMapper, ) -> None: self.config = config self.mpd_config = cast(types.MpdConfig, config.get("mpd", {}) if config else {}) @@ -66,7 +65,6 @@ def __init__( dispatcher=self, session=session, config=config, - uri_map=uri_map, ) def handle_request( @@ -76,7 +74,7 @@ def handle_request( ) -> Response: """Dispatch incoming requests to the correct handler.""" self.command_list_index = current_command_list_index - response: Response = [] + response: Response = Response([]) filter_chain: list[Filter] = [ self._catch_mpd_ack_errors_filter, self._authenticate_filter, @@ -159,7 +157,7 @@ def _command_list_filter( ) -> Response: if self._is_receiving_command_list(request): self.command_list.append(request) - return [] + return Response([]) response = self._call_next_filter(request, response, filter_chain) if ( @@ -194,15 +192,15 @@ def _idle_filter( repr("noidle"), ) self.context.session.close() - return [] + return Response([]) if not self._is_currently_idle() and self._noidle.match(request): - return [] # noidle was called before idle + return Response([]) # noidle was called before idle response = self._call_next_filter(request, response, filter_chain) if self._is_currently_idle(): - return [] + return Response([]) return response @@ -302,8 +300,6 @@ class MpdContext: #: The Mopidy core API. core: CoreProxy - _uri_map: MpdUriMapper - #: The current dispatcher instance. dispatcher: MpdDispatcher @@ -319,16 +315,16 @@ class MpdContext: #: The subsystems that we want to be notified about in idle mode. subscriptions: set[str] - def __init__( # noqa: PLR0913 + _uri_map: MpdUriMapper + + def __init__( self, config: Config, core: CoreProxy, - uri_map: MpdUriMapper, dispatcher: MpdDispatcher, session: MpdSession, ) -> None: self.core = core - self._uri_map = uri_map self.dispatcher = dispatcher self.session = session if config is not None: @@ -336,6 +332,7 @@ def __init__( # noqa: PLR0913 self.password = mpd_config["password"] self.events = set() self.subscriptions = set() + self._uri_map = MpdUriMapper(core) def lookup_playlist_uri_from_name(self, name: str) -> Uri | None: """ diff --git a/src/mopidy_mpd/network.py b/src/mopidy_mpd/network.py index fe4bf98..3fc2872 100644 --- a/src/mopidy_mpd/network.py +++ b/src/mopidy_mpd/network.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, NoReturn import pykka -from gi.repository import GLib +from gi.repository import GLib # pyright: ignore[reportMissingModuleSource] logger = logging.getLogger(__name__) diff --git a/src/mopidy_mpd/protocol/connection.py b/src/mopidy_mpd/protocol/connection.py index 5def70e..0ce1460 100644 --- a/src/mopidy_mpd/protocol/connection.py +++ b/src/mopidy_mpd/protocol/connection.py @@ -61,7 +61,7 @@ def ping(context: MpdContext) -> None: @protocol.commands.add("tagtypes") -def tagtypes(context: MpdContext, *parameters: list[str]) -> protocol.Result: +def tagtypes(context: MpdContext, *args: str) -> protocol.Result: """ *mpd.readthedocs.io, connection settings section:* @@ -85,7 +85,7 @@ def tagtypes(context: MpdContext, *parameters: list[str]) -> protocol.Result: Announce that this client is interested in all tag types. """ - parameters = list(parameters) + parameters = list(args) if parameters: subcommand = parameters.pop(0).lower() match subcommand: diff --git a/src/mopidy_mpd/protocol/current_playlist.py b/src/mopidy_mpd/protocol/current_playlist.py index 8ec43ce..b82a1df 100644 --- a/src/mopidy_mpd/protocol/current_playlist.py +++ b/src/mopidy_mpd/protocol/current_playlist.py @@ -35,7 +35,7 @@ def add(context: MpdContext, uri: Uri) -> None: try: uris = [] - for _path, ref in context.browse(uri, lookup=False): + for _path, ref in context.browse(uri, recursive=True, lookup=False): if ref: uris.append(ref.uri) except exceptions.MpdNoExistError as exc: @@ -303,12 +303,13 @@ def plchanges(context: MpdContext, version: int) -> protocol.Result: tl_track = context.core.playback.get_current_tl_track().get() position = context.core.tracklist.index(tl_track).get() - return translator.track_to_mpd_format( - tl_track, - context.session.tagtypes, - position=position, - stream_title=stream_title, - ) + if tl_track is not None and position is not None: + return translator.track_to_mpd_format( + tl_track, + context.session.tagtypes, + position=position, + stream_title=stream_title, + ) return None diff --git a/src/mopidy_mpd/protocol/music_db.py b/src/mopidy_mpd/protocol/music_db.py index 0a15349..b80e181 100644 --- a/src/mopidy_mpd/protocol/music_db.py +++ b/src/mopidy_mpd/protocol/music_db.py @@ -7,6 +7,7 @@ from mopidy.types import DistinctField, Query, SearchField, Uri from mopidy_mpd import exceptions, protocol, translator +from mopidy_mpd.protocol import stored_playlists if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -317,7 +318,7 @@ def listall(context: MpdContext, uri: str | None = None) -> protocol.Result: .. warning:: This command is disabled by default in Mopidy installs. """ result = [] - for path, track_ref in context.browse(uri, lookup=False): + for path, track_ref in context.browse(uri, recursive=True, lookup=False): if not track_ref: result.append(("directory", path.lstrip("/"))) else: @@ -346,7 +347,7 @@ def listallinfo(context: MpdContext, uri: str | None = None) -> protocol.Result: .. warning:: This command is disabled by default in Mopidy installs. """ result: protocol.ResultList = [] - for path, lookup_future in context.browse(uri, lookup=True): + for path, lookup_future in context.browse(uri, recursive=True, lookup=True): if not lookup_future: result.append(("directory", path.lstrip("/"))) else: @@ -399,7 +400,7 @@ def lsinfo(context: MpdContext, uri: str | None = None) -> protocol.Result: directories located at the root level, for both ``lsinfo``, ``lsinfo ""``, and ``lsinfo "/"``. """ - result = [] + result: protocol.ResultList = [] for path, lookup_future in context.browse(uri, recursive=False, lookup=True): if not lookup_future: result.append(("directory", path.lstrip("/"))) @@ -413,7 +414,7 @@ def lsinfo(context: MpdContext, uri: str | None = None) -> protocol.Result: ) if uri in (None, "", "/"): - result.extend(protocol.stored_playlists.listplaylists(context)) + result.extend(stored_playlists.listplaylists(context)) return result diff --git a/src/mopidy_mpd/protocol/playback.py b/src/mopidy_mpd/protocol/playback.py index 22ce9fd..4cc62c6 100644 --- a/src/mopidy_mpd/protocol/playback.py +++ b/src/mopidy_mpd/protocol/playback.py @@ -329,7 +329,7 @@ def replay_gain_mode(context: MpdContext, mode: str) -> Never: @protocol.commands.add("replay_gain_status") -def replay_gain_status(context: MpdContext) -> str: +def replay_gain_status(context: MpdContext) -> protocol.Result: """ *musicpd.org, playback section:* @@ -338,7 +338,7 @@ def replay_gain_status(context: MpdContext) -> str: Prints replay gain options. Currently, only the variable ``replay_gain_mode`` is returned. """ - return "replay_gain_mode: off" # TODO + return ("replay_gain_mode", "off") # TODO @protocol.commands.add("seek", songpos=protocol.UINT, seconds=protocol.UFLOAT) diff --git a/src/mopidy_mpd/protocol/status.py b/src/mopidy_mpd/protocol/status.py index df6f565..bb29d36 100644 --- a/src/mopidy_mpd/protocol/status.py +++ b/src/mopidy_mpd/protocol/status.py @@ -63,7 +63,7 @@ def currentsong(context: MpdContext) -> protocol.Result: @protocol.commands.add("idle") -def idle(context: MpdContext, *subsystems: list[str]) -> protocol.Result: +def idle(context: MpdContext, *args: str) -> protocol.Result: """ *musicpd.org, status section:* @@ -98,8 +98,7 @@ def idle(context: MpdContext, *subsystems: list[str]) -> protocol.Result: """ # TODO: test against valid subsystems - if not subsystems: - subsystems = SUBSYSTEMS + subsystems = list(args) if args else SUBSYSTEMS for subsystem in subsystems: context.subscriptions.add(subsystem) diff --git a/src/mopidy_mpd/protocol/stored_playlists.py b/src/mopidy_mpd/protocol/stored_playlists.py index c89f4e2..d1dc2e3 100644 --- a/src/mopidy_mpd/protocol/stored_playlists.py +++ b/src/mopidy_mpd/protocol/stored_playlists.py @@ -95,7 +95,7 @@ def listplaylistinfo(context: MpdContext, name: str) -> protocol.Result: @protocol.commands.add("listplaylists") -def listplaylists(context: MpdContext) -> protocol.Result: +def listplaylists(context: MpdContext) -> protocol.ResultList: """ *musicpd.org, stored playlists section:* @@ -121,11 +121,13 @@ def listplaylists(context: MpdContext) -> protocol.Result: ignore playlists without names, which isn't very useful anyway. """ last_modified = _get_last_modified() - result = [] + result: protocol.ResultList = [] for playlist_ref in context.core.playlists.as_list().get(): if not playlist_ref.name: continue name = context.lookup_playlist_name_from_uri(playlist_ref.uri) + if name is None: + continue result.append(("playlist", name)) result.append(("Last-Modified", last_modified)) return result @@ -220,7 +222,7 @@ def _create_playlist(context: MpdContext, name: str, tracks: Iterable[Track]) -> """ Creates new playlist using backend appropriate for the given tracks """ - uri_schemes = {urlparse(t.uri).scheme for t in tracks} + uri_schemes = {UriScheme(urlparse(t.uri).scheme) for t in tracks} for scheme in uri_schemes: new_playlist = context.core.playlists.create(name, scheme).get() if new_playlist is None: @@ -295,7 +297,9 @@ def playlistdelete(context: MpdContext, name: str, songpos: int) -> None: playlist = playlist.replace(tracks=tracks) saved_playlist = context.core.playlists.save(playlist).get() if saved_playlist is None: - raise exceptions.MpdFailedToSavePlaylistError(urlparse(playlist.uri).scheme) + raise exceptions.MpdFailedToSavePlaylistError( + UriScheme(urlparse(playlist.uri).scheme) + ) @protocol.commands.add("playlistmove", from_pos=protocol.UINT, to_pos=protocol.UINT) @@ -334,7 +338,9 @@ def playlistmove(context: MpdContext, name: str, from_pos: int, to_pos: int) -> playlist = playlist.replace(tracks=tracks) saved_playlist = context.core.playlists.save(playlist).get() if saved_playlist is None: - raise exceptions.MpdFailedToSavePlaylistError(urlparse(playlist.uri).scheme) + raise exceptions.MpdFailedToSavePlaylistError( + UriScheme(urlparse(playlist.uri).scheme) + ) @protocol.commands.add("rename") @@ -357,10 +363,11 @@ def rename(context: MpdContext, old_name: str, new_name: str) -> None: # Create copy of the playlist and remove original uri_scheme = UriScheme(urlparse(old_playlist.uri).scheme) - new_playlist = context.core.playlists.create(new_name, uri_scheme).get() - new_playlist = new_playlist.replace(tracks=old_playlist.tracks) - saved_playlist = context.core.playlists.save(new_playlist).get() - + empty_playlist = context.core.playlists.create(new_name, uri_scheme).get() + if empty_playlist is None: + raise exceptions.MpdFailedToSavePlaylistError(uri_scheme) + filled_playlist = empty_playlist.replace(tracks=old_playlist.tracks) + saved_playlist = context.core.playlists.save(filled_playlist).get() if saved_playlist is None: raise exceptions.MpdFailedToSavePlaylistError(uri_scheme) context.core.playlists.delete(old_playlist.uri).get() diff --git a/src/mopidy_mpd/session.py b/src/mopidy_mpd/session.py index 9c70d5c..11e22d2 100644 --- a/src/mopidy_mpd/session.py +++ b/src/mopidy_mpd/session.py @@ -10,7 +10,6 @@ from mopidy.core import CoreProxy from mopidy.ext import Config - from mopidy_mpd.uri_mapper import MpdUriMapper logger = logging.getLogger(__name__) @@ -26,17 +25,15 @@ class MpdSession(network.LineProtocol): def __init__( self, + config: Config, + core: CoreProxy, connection: network.Connection, - config: Config | None = None, - core: CoreProxy | None = None, - uri_map: MpdUriMapper | None = None, ) -> None: super().__init__(connection) self.dispatcher = dispatcher.MpdDispatcher( session=self, config=config, core=core, - uri_map=uri_map, ) self.tagtypes = tagtype_list.TAGTYPE_LIST.copy() diff --git a/src/mopidy_mpd/translator.py b/src/mopidy_mpd/translator.py index 46b4899..54ea6bd 100644 --- a/src/mopidy_mpd/translator.py +++ b/src/mopidy_mpd/translator.py @@ -41,7 +41,7 @@ def track_to_mpd_format( # noqa: C901, PLR0912 logger.warning("Ignoring track without uri") return [] - result: protocol.Result = [ + result: list[protocol.ResultTuple] = [ ("file", track.uri), ("Time", track.length and (track.length // 1000) or 0), *multi_tag_list(track.artists, "name", "Artist"), @@ -104,13 +104,17 @@ def track_to_mpd_format( # noqa: C901, PLR0912 if track.album and track.album.uri: result.append(("X-AlbumUri", track.album.uri)) - return [element for element in result if _has_value(tagtypes, *element)] + return [ + (tagtype, value) + for (tagtype, value) in result + if _has_value(tagtypes, tagtype, value) + ] def _has_value( tagtypes: set[str], tagtype: str, - value: str | int, + value: protocol.ResultValue, ) -> bool: """ Determine whether to add the tagtype to the output or not. The tagtype must @@ -149,7 +153,7 @@ def multi_tag_list( models: Iterable[Artist | Album | Track], attribute: str, tag: str, -) -> protocol.ResultList: +) -> list[protocol.ResultTuple]: """ Format multiple objects for output to MPD client in a list with one tag per value. diff --git a/tests/protocol/__init__.py b/tests/protocol/__init__.py index de63994..beab64d 100644 --- a/tests/protocol/__init__.py +++ b/tests/protocol/__init__.py @@ -1,9 +1,10 @@ import unittest +from typing import cast from unittest import mock import pykka from mopidy import core -from mopidy_mpd import session, uri_mapper +from mopidy_mpd import session from tests import dummy_audio, dummy_backend, dummy_mixer @@ -38,20 +39,21 @@ def setUp(self): self.audio = dummy_audio.create_proxy() self.backend = dummy_backend.create_proxy(audio=self.audio) - self.core = core.Core.start( - self.get_config(), - audio=self.audio, - mixer=self.mixer, - backends=[self.backend], - ).proxy() + self.core = cast( + core.CoreProxy, + core.Core.start( + self.get_config(), + audio=self.audio, + mixer=self.mixer, + backends=[self.backend], + ).proxy(), + ) - self.uri_map = uri_mapper.MpdUriMapper(self.core) self.connection = MockConnection() self.session = session.MpdSession( - self.connection, config=self.get_config(), core=self.core, - uri_map=self.uri_map, + connection=self.connection, ) self.dispatcher = self.session.dispatcher self.context = self.dispatcher.context diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index c3b542c..2155531 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -8,7 +8,6 @@ from mopidy.models import Ref from mopidy_mpd.dispatcher import MpdContext, MpdDispatcher from mopidy_mpd.exceptions import MpdAckError -from mopidy_mpd.uri_mapper import MpdUriMapper from tests import dummy_backend @@ -20,7 +19,11 @@ def setUp(self): self.core = cast( CoreProxy, Core.start(config=None, backends=[self.backend]).proxy() ) - self.dispatcher = MpdDispatcher(config=config, core=self.core) + self.dispatcher = MpdDispatcher( + config=config, + core=self.core, + session=None, + ) def tearDown(self): pykka.ActorRegistry.stop_all() @@ -78,8 +81,10 @@ def mpd_context(backend_to_browse: BackendProxy) -> MpdContext: Core.start(config=None, backends=[backend_to_browse]).proxy(), ) return MpdContext( + config=None, core=core, - uri_map=MpdUriMapper(core), + dispatcher=None, + session=None, ) diff --git a/tests/test_session.py b/tests/test_session.py index 2e76672..591a8d4 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -8,7 +8,11 @@ def test_on_start_logged(caplog): caplog.set_level(logging.INFO) connection = Mock(spec=network.Connection) - session.MpdSession(connection).on_start() + session.MpdSession( + config=None, + core=None, + connection=connection, + ).on_start() assert f"New MPD connection from {connection}" in caplog.text @@ -16,7 +20,11 @@ def test_on_start_logged(caplog): def test_on_line_received_logged(caplog): caplog.set_level(logging.DEBUG) connection = Mock(spec=network.Connection) - mpd_session = session.MpdSession(connection) + mpd_session = session.MpdSession( + config=None, + core=None, + connection=connection, + ) mpd_session.dispatcher = Mock(spec=dispatcher.MpdDispatcher) mpd_session.dispatcher.handle_request.return_value = [str(sentinel.resp)] diff --git a/tests/test_status.py b/tests/test_status.py index 962b6b4..bcb5385 100644 --- a/tests/test_status.py +++ b/tests/test_status.py @@ -1,4 +1,5 @@ import unittest +from typing import cast import pykka from mopidy import core @@ -19,20 +20,30 @@ class StatusHandlerTest(unittest.TestCase): def setUp(self): - config = {"core": {"max_tracklist_length": 10000}} + config = { + "core": {"max_tracklist_length": 10000}, + "mpd": {"password": None}, + } self.audio = dummy_audio.create_proxy() self.mixer = dummy_mixer.create_proxy() self.backend = dummy_backend.create_proxy(audio=self.audio) - self.core = core.Core.start( - config, - audio=self.audio, - mixer=self.mixer, - backends=[self.backend], - ).proxy() - - self.dispatcher = dispatcher.MpdDispatcher(core=self.core) + self.core = cast( + core.CoreProxy, + core.Core.start( + config, + audio=self.audio, + mixer=self.mixer, + backends=[self.backend], + ).proxy(), + ) + + self.dispatcher = dispatcher.MpdDispatcher( + config=config, + core=self.core, + session=None, + ) self.context = self.dispatcher.context def tearDown(self): From b76959a64dab27b914693a50f680db5b769ff10d Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sat, 24 Feb 2024 02:45:24 +0100 Subject: [PATCH 06/19] Move MpdContext out of dispatcher module --- src/mopidy_mpd/context.py | 163 ++++++++++++++++++++ src/mopidy_mpd/dispatcher.py | 149 +----------------- src/mopidy_mpd/protocol/__init__.py | 2 +- src/mopidy_mpd/protocol/audio_output.py | 2 +- src/mopidy_mpd/protocol/channels.py | 2 +- src/mopidy_mpd/protocol/command_list.py | 2 +- src/mopidy_mpd/protocol/connection.py | 2 +- src/mopidy_mpd/protocol/current_playlist.py | 2 +- src/mopidy_mpd/protocol/mount.py | 2 +- src/mopidy_mpd/protocol/music_db.py | 2 +- src/mopidy_mpd/protocol/playback.py | 2 +- src/mopidy_mpd/protocol/reflection.py | 2 +- src/mopidy_mpd/protocol/status.py | 2 +- src/mopidy_mpd/protocol/stickers.py | 2 +- src/mopidy_mpd/protocol/stored_playlists.py | 2 +- src/mopidy_mpd/session.py | 2 +- tests/test_commands.py | 7 +- tests/test_context.py | 90 +++++++++++ tests/test_dispatcher.py | 85 +--------- 19 files changed, 279 insertions(+), 243 deletions(-) create mode 100644 src/mopidy_mpd/context.py create mode 100644 tests/test_context.py diff --git a/src/mopidy_mpd/context.py b/src/mopidy_mpd/context.py new file mode 100644 index 0000000..aea9341 --- /dev/null +++ b/src/mopidy_mpd/context.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import logging +import re +from typing import ( + TYPE_CHECKING, + Any, + Literal, + cast, + overload, +) + +from mopidy_mpd import exceptions, types +from mopidy_mpd.uri_mapper import MpdUriMapper + +if TYPE_CHECKING: + from collections.abc import Generator + + import pykka + from mopidy.core import CoreProxy + from mopidy.ext import Config + from mopidy.models import Ref, Track + from mopidy.types import Uri + + from mopidy_mpd.dispatcher import MpdDispatcher + from mopidy_mpd.session import MpdSession + + +logger = logging.getLogger(__name__) + + +class MpdContext: + """ + This object is passed as the first argument to all MPD command handlers to + give the command handlers access to important parts of Mopidy. + """ + + #: The Mopidy core API. + core: CoreProxy + + #: The current session instance. + session: MpdSession + + #: The current dispatcher instance. + dispatcher: MpdDispatcher + + #: The MPD password. + password: str | None = None + + #: The active subsystems that have pending events. + events: set[str] + + #: The subsystems that we want to be notified about in idle mode. + subscriptions: set[str] + + _uri_map: MpdUriMapper + + def __init__( + self, + config: Config, + core: CoreProxy, + session: MpdSession, + dispatcher: MpdDispatcher, + ) -> None: + self.core = core + self.session = session + self.dispatcher = dispatcher + + if config is not None: + mpd_config = cast(types.MpdConfig, config["mpd"]) + self.password = mpd_config["password"] + self.events = set() + self.subscriptions = set() + self._uri_map = MpdUriMapper(core) + + def lookup_playlist_uri_from_name(self, name: str) -> Uri | None: + """ + Helper function to retrieve a playlist from its unique MPD name. + """ + return self._uri_map.playlist_uri_from_name(name) + + def lookup_playlist_name_from_uri(self, uri: Uri) -> str | None: + """ + Helper function to retrieve the unique MPD playlist name from its uri. + """ + return self._uri_map.playlist_name_from_uri(uri) + + @overload + def browse( + self, path: str | None, *, recursive: bool, lookup: Literal[True] + ) -> Generator[tuple[str, pykka.Future[dict[Uri, list[Track]]] | None], Any, None]: + ... + + @overload + def browse( + self, path: str | None, *, recursive: bool, lookup: Literal[False] + ) -> Generator[tuple[str, Ref | None], Any, None]: + ... + + def browse( # noqa: C901, PLR0912 + self, + path: str | None, + *, + recursive: bool = True, + lookup: bool = True, + ) -> Generator[Any, Any, None]: + """ + Browse the contents of a given directory path. + + Returns a sequence of two-tuples ``(path, data)``. + + If ``recursive`` is true, it returns results for all entries in the + given path. + + If ``lookup`` is true and the ``path`` is to a track, the returned + ``data`` is a future which will contain the results from looking up + the URI with :meth:`mopidy.core.LibraryController.lookup`. If + ``lookup`` is false and the ``path`` is to a track, the returned + ``data`` will be a :class:`mopidy.models.Ref` for the track. + + For all entries that are not tracks, the returned ``data`` will be + :class:`None`. + """ + + path_parts: list[str] = re.findall(r"[^/]+", path or "") + root_path: str = "/".join(["", *path_parts]) + + uri = self._uri_map.uri_from_name(root_path) + if uri is None: + for part in path_parts: + for ref in self.core.library.browse(uri).get(): + if ref.type != ref.TRACK and ref.name == part: + uri = ref.uri + break + else: + raise exceptions.MpdNoExistError("Not found") + root_path = self._uri_map.insert(root_path, uri) + + if recursive: + yield (root_path, None) + + path_and_futures = [(root_path, self.core.library.browse(uri))] + while path_and_futures: + base_path, future = path_and_futures.pop() + for ref in future.get(): + if ref.name is None or ref.uri is None: + continue + + path = "/".join([base_path, ref.name.replace("/", "")]) + path = self._uri_map.insert(path, ref.uri) + + if ref.type == ref.TRACK: + if lookup: + # TODO: can we lookup all the refs at once now? + yield (path, self.core.library.lookup(uris=[ref.uri])) + else: + yield (path, ref) + else: + yield (path, None) + if recursive: + path_and_futures.append( + (path, self.core.library.browse(ref.uri)) + ) diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index 5c73590..1de4e09 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -2,28 +2,22 @@ import logging import re -from collections.abc import Callable, Generator +from collections.abc import Callable from typing import ( TYPE_CHECKING, - Any, - Literal, NewType, TypeAlias, TypeVar, cast, - overload, ) import pykka -from mopidy_mpd import exceptions, protocol, tokenize, types -from mopidy_mpd.uri_mapper import MpdUriMapper +from mopidy_mpd import context, exceptions, protocol, tokenize, types if TYPE_CHECKING: from mopidy.core import CoreProxy from mopidy.ext import Config - from mopidy.models import Ref, Track - from mopidy.types import Uri from mopidy_mpd.session import MpdSession @@ -60,11 +54,11 @@ def __init__( self.command_list_ok = False self.command_list = [] self.command_list_index = None - self.context = MpdContext( + self.context = context.MpdContext( + config=config, core=core, - dispatcher=self, session=session, - config=config, + dispatcher=self, ) def handle_request( @@ -289,136 +283,3 @@ def _format_lines( (key, value) = element return Response([f"{key}: {value}"]) return Response([element]) - - -class MpdContext: - """ - This object is passed as the first argument to all MPD command handlers to - give the command handlers access to important parts of Mopidy. - """ - - #: The Mopidy core API. - core: CoreProxy - - #: The current dispatcher instance. - dispatcher: MpdDispatcher - - #: The current session instance. - session: MpdSession - - #: The MPD password. - password: str | None = None - - #: The active subsystems that have pending events. - events: set[str] - - #: The subsystems that we want to be notified about in idle mode. - subscriptions: set[str] - - _uri_map: MpdUriMapper - - def __init__( - self, - config: Config, - core: CoreProxy, - dispatcher: MpdDispatcher, - session: MpdSession, - ) -> None: - self.core = core - self.dispatcher = dispatcher - self.session = session - if config is not None: - mpd_config = cast(types.MpdConfig, config["mpd"]) - self.password = mpd_config["password"] - self.events = set() - self.subscriptions = set() - self._uri_map = MpdUriMapper(core) - - def lookup_playlist_uri_from_name(self, name: str) -> Uri | None: - """ - Helper function to retrieve a playlist from its unique MPD name. - """ - return self._uri_map.playlist_uri_from_name(name) - - def lookup_playlist_name_from_uri(self, uri: Uri) -> str | None: - """ - Helper function to retrieve the unique MPD playlist name from its uri. - """ - return self._uri_map.playlist_name_from_uri(uri) - - @overload - def browse( - self, path: str | None, *, recursive: bool, lookup: Literal[True] - ) -> Generator[tuple[str, pykka.Future[dict[Uri, list[Track]]] | None], Any, None]: - ... - - @overload - def browse( - self, path: str | None, *, recursive: bool, lookup: Literal[False] - ) -> Generator[tuple[str, Ref | None], Any, None]: - ... - - def browse( # noqa: C901, PLR0912 - self, - path: str | None, - *, - recursive: bool = True, - lookup: bool = True, - ) -> Generator[Any, Any, None]: - """ - Browse the contents of a given directory path. - - Returns a sequence of two-tuples ``(path, data)``. - - If ``recursive`` is true, it returns results for all entries in the - given path. - - If ``lookup`` is true and the ``path`` is to a track, the returned - ``data`` is a future which will contain the results from looking up - the URI with :meth:`mopidy.core.LibraryController.lookup`. If - ``lookup`` is false and the ``path`` is to a track, the returned - ``data`` will be a :class:`mopidy.models.Ref` for the track. - - For all entries that are not tracks, the returned ``data`` will be - :class:`None`. - """ - - path_parts: list[str] = re.findall(r"[^/]+", path or "") - root_path: str = "/".join(["", *path_parts]) - - uri = self._uri_map.uri_from_name(root_path) - if uri is None: - for part in path_parts: - for ref in self.core.library.browse(uri).get(): - if ref.type != ref.TRACK and ref.name == part: - uri = ref.uri - break - else: - raise exceptions.MpdNoExistError("Not found") - root_path = self._uri_map.insert(root_path, uri) - - if recursive: - yield (root_path, None) - - path_and_futures = [(root_path, self.core.library.browse(uri))] - while path_and_futures: - base_path, future = path_and_futures.pop() - for ref in future.get(): - if ref.name is None or ref.uri is None: - continue - - path = "/".join([base_path, ref.name.replace("/", "")]) - path = self._uri_map.insert(path, ref.uri) - - if ref.type == ref.TRACK: - if lookup: - # TODO: can we lookup all the refs at once now? - yield (path, self.core.library.lookup(uris=[ref.uri])) - else: - yield (path, ref) - else: - yield (path, None) - if recursive: - path_and_futures.append( - (path, self.core.library.browse(ref.uri)) - ) diff --git a/src/mopidy_mpd/protocol/__init__.py b/src/mopidy_mpd/protocol/__init__.py index 8edf58c..e4d0a92 100644 --- a/src/mopidy_mpd/protocol/__init__.py +++ b/src/mopidy_mpd/protocol/__init__.py @@ -20,7 +20,7 @@ from mopidy_mpd import exceptions if TYPE_CHECKING: - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext #: The MPD protocol uses UTF-8 for encoding all data. ENCODING = "utf-8" diff --git a/src/mopidy_mpd/protocol/audio_output.py b/src/mopidy_mpd/protocol/audio_output.py index fe84909..16cbcf0 100644 --- a/src/mopidy_mpd/protocol/audio_output.py +++ b/src/mopidy_mpd/protocol/audio_output.py @@ -5,7 +5,7 @@ from mopidy_mpd import exceptions, protocol if TYPE_CHECKING: - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("disableoutput", outputid=protocol.UINT) diff --git a/src/mopidy_mpd/protocol/channels.py b/src/mopidy_mpd/protocol/channels.py index 020011f..4bfe38c 100644 --- a/src/mopidy_mpd/protocol/channels.py +++ b/src/mopidy_mpd/protocol/channels.py @@ -5,7 +5,7 @@ from mopidy_mpd import exceptions, protocol if TYPE_CHECKING: - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("subscribe") diff --git a/src/mopidy_mpd/protocol/command_list.py b/src/mopidy_mpd/protocol/command_list.py index bdff82a..bdc4436 100644 --- a/src/mopidy_mpd/protocol/command_list.py +++ b/src/mopidy_mpd/protocol/command_list.py @@ -5,7 +5,7 @@ from mopidy_mpd import exceptions, protocol if TYPE_CHECKING: - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("command_list_begin", list_command=False) diff --git a/src/mopidy_mpd/protocol/connection.py b/src/mopidy_mpd/protocol/connection.py index 0ce1460..6d686cb 100644 --- a/src/mopidy_mpd/protocol/connection.py +++ b/src/mopidy_mpd/protocol/connection.py @@ -6,7 +6,7 @@ from mopidy_mpd.protocol import tagtype_list if TYPE_CHECKING: - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("close", auth_required=False) diff --git a/src/mopidy_mpd/protocol/current_playlist.py b/src/mopidy_mpd/protocol/current_playlist.py index b82a1df..3f4dac9 100644 --- a/src/mopidy_mpd/protocol/current_playlist.py +++ b/src/mopidy_mpd/protocol/current_playlist.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from mopidy.types import Uri - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("add") diff --git a/src/mopidy_mpd/protocol/mount.py b/src/mopidy_mpd/protocol/mount.py index b8280bc..7dc2710 100644 --- a/src/mopidy_mpd/protocol/mount.py +++ b/src/mopidy_mpd/protocol/mount.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from mopidy.types import Uri - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("mount") diff --git a/src/mopidy_mpd/protocol/music_db.py b/src/mopidy_mpd/protocol/music_db.py index b80e181..6023a66 100644 --- a/src/mopidy_mpd/protocol/music_db.py +++ b/src/mopidy_mpd/protocol/music_db.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext _LIST_MAPPING: dict[str, DistinctField] = { diff --git a/src/mopidy_mpd/protocol/playback.py b/src/mopidy_mpd/protocol/playback.py index 4cc62c6..c7cacca 100644 --- a/src/mopidy_mpd/protocol/playback.py +++ b/src/mopidy_mpd/protocol/playback.py @@ -8,7 +8,7 @@ from mopidy_mpd import exceptions, protocol if TYPE_CHECKING: - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("consume", state=protocol.BOOL) diff --git a/src/mopidy_mpd/protocol/reflection.py b/src/mopidy_mpd/protocol/reflection.py index d73e2cd..20feb3b 100644 --- a/src/mopidy_mpd/protocol/reflection.py +++ b/src/mopidy_mpd/protocol/reflection.py @@ -5,7 +5,7 @@ from mopidy_mpd import exceptions, protocol if TYPE_CHECKING: - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("config", list_command=False) diff --git a/src/mopidy_mpd/protocol/status.py b/src/mopidy_mpd/protocol/status.py index bb29d36..f5f13c3 100644 --- a/src/mopidy_mpd/protocol/status.py +++ b/src/mopidy_mpd/protocol/status.py @@ -10,7 +10,7 @@ from mopidy.models import Track from mopidy.types import DurationMs - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext #: Subsystems that can be registered with idle command. diff --git a/src/mopidy_mpd/protocol/stickers.py b/src/mopidy_mpd/protocol/stickers.py index 6808d51..34ca2d5 100644 --- a/src/mopidy_mpd/protocol/stickers.py +++ b/src/mopidy_mpd/protocol/stickers.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from mopidy.types import Uri - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext @protocol.commands.add("sticker", list_command=False) diff --git a/src/mopidy_mpd/protocol/stored_playlists.py b/src/mopidy_mpd/protocol/stored_playlists.py index d1dc2e3..17a7077 100644 --- a/src/mopidy_mpd/protocol/stored_playlists.py +++ b/src/mopidy_mpd/protocol/stored_playlists.py @@ -15,7 +15,7 @@ from mopidy.models import Playlist, Track - from mopidy_mpd.dispatcher import MpdContext + from mopidy_mpd.context import MpdContext logger = logging.getLogger(__name__) diff --git a/src/mopidy_mpd/session.py b/src/mopidy_mpd/session.py index 11e22d2..3e0ddf7 100644 --- a/src/mopidy_mpd/session.py +++ b/src/mopidy_mpd/session.py @@ -31,9 +31,9 @@ def __init__( ) -> None: super().__init__(connection) self.dispatcher = dispatcher.MpdDispatcher( - session=self, config=config, core=core, + session=self, ) self.tagtypes = tagtype_list.TAGTYPE_LIST.copy() diff --git a/tests/test_commands.py b/tests/test_commands.py index 80d42c1..1390ca2 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,7 +1,12 @@ +from __future__ import annotations + import unittest +from typing import TYPE_CHECKING from mopidy_mpd import exceptions, protocol -from mopidy_mpd.dispatcher import MpdContext + +if TYPE_CHECKING: + from mopidy_mpd.context import MpdContext class TestConverts(unittest.TestCase): diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..fdd1b1b --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,90 @@ +from typing import cast + +import pykka +import pytest +from mopidy.backend import BackendProxy +from mopidy.core import Core, CoreProxy +from mopidy.models import Ref +from mopidy_mpd.context import MpdContext + +from tests import dummy_backend + + +@pytest.fixture() +def a_track() -> Ref: + return Ref.track(uri="dummy:/a", name="a") + + +@pytest.fixture() +def b_track() -> Ref: + return Ref.track(uri="dummy:/foo/b", name="b") + + +@pytest.fixture() +def backend_to_browse(a_track: Ref, b_track: Ref) -> BackendProxy: + backend = cast(BackendProxy, dummy_backend.create_proxy()) + backend.library.dummy_browse_result = { + "dummy:/": [ + a_track, + Ref.directory(uri="dummy:/foo", name="foo"), + ], + "dummy:/foo": [ + b_track, + ], + } + return backend + + +@pytest.fixture() +def mpd_context(backend_to_browse: BackendProxy) -> MpdContext: + core = cast( + CoreProxy, + Core.start(config=None, backends=[backend_to_browse]).proxy(), + ) + return MpdContext( + config=None, + core=core, + dispatcher=None, + session=None, + ) + + +class TestMpdContext: + @classmethod + def teardown_class(cls): + pykka.ActorRegistry.stop_all() + + def test_browse_root(self, mpd_context, a_track): + results = mpd_context.browse("dummy", recursive=False, lookup=False) + + assert [("/dummy/a", a_track), ("/dummy/foo", None)] == list(results) + + def test_browse_root_recursive(self, mpd_context, a_track, b_track): + results = mpd_context.browse("dummy", recursive=True, lookup=False) + + assert [ + ("/dummy", None), + ("/dummy/a", a_track), + ("/dummy/foo", None), + ("/dummy/foo/b", b_track), + ] == list(results) + + @pytest.mark.parametrize( + "bad_ref", + [ + Ref.track(uri="dummy:/x"), + Ref.track(name="x"), + Ref.directory(uri="dummy:/y"), + Ref.directory(name="y"), + ], + ) + def test_browse_skips_bad_refs( + self, backend_to_browse, a_track, bad_ref, mpd_context + ): + backend_to_browse.library.dummy_browse_result = { + "dummy:/": [bad_ref, a_track], + } + + results = mpd_context.browse("dummy", recursive=False, lookup=False) + + assert [("/dummy/a", a_track)] == list(results) diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 2155531..a458b4b 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -2,11 +2,8 @@ from typing import cast import pykka -import pytest -from mopidy.backend import BackendProxy from mopidy.core import Core, CoreProxy -from mopidy.models import Ref -from mopidy_mpd.dispatcher import MpdContext, MpdDispatcher +from mopidy_mpd.dispatcher import MpdDispatcher from mopidy_mpd.exceptions import MpdAckError from tests import dummy_backend @@ -47,83 +44,3 @@ def test_handling_blacklisted_command(self): result[0] == 'ACK [0@0] {disabled} "disabled" has been disabled in the server' ) - - -@pytest.fixture() -def a_track() -> Ref: - return Ref.track(uri="dummy:/a", name="a") - - -@pytest.fixture() -def b_track() -> Ref: - return Ref.track(uri="dummy:/foo/b", name="b") - - -@pytest.fixture() -def backend_to_browse(a_track: Ref, b_track: Ref) -> BackendProxy: - backend = cast(BackendProxy, dummy_backend.create_proxy()) - backend.library.dummy_browse_result = { - "dummy:/": [ - a_track, - Ref.directory(uri="dummy:/foo", name="foo"), - ], - "dummy:/foo": [ - b_track, - ], - } - return backend - - -@pytest.fixture() -def mpd_context(backend_to_browse: BackendProxy) -> MpdContext: - core = cast( - CoreProxy, - Core.start(config=None, backends=[backend_to_browse]).proxy(), - ) - return MpdContext( - config=None, - core=core, - dispatcher=None, - session=None, - ) - - -class TestMpdContext: - @classmethod - def teardown_class(cls): - pykka.ActorRegistry.stop_all() - - def test_browse_root(self, mpd_context, a_track): - results = mpd_context.browse("dummy", recursive=False, lookup=False) - - assert [("/dummy/a", a_track), ("/dummy/foo", None)] == list(results) - - def test_browse_root_recursive(self, mpd_context, a_track, b_track): - results = mpd_context.browse("dummy", recursive=True, lookup=False) - - assert [ - ("/dummy", None), - ("/dummy/a", a_track), - ("/dummy/foo", None), - ("/dummy/foo/b", b_track), - ] == list(results) - - @pytest.mark.parametrize( - "bad_ref", - [ - Ref.track(uri="dummy:/x"), - Ref.track(name="x"), - Ref.directory(uri="dummy:/y"), - Ref.directory(name="y"), - ], - ) - def test_browse_skips_bad_refs( - self, backend_to_browse, a_track, bad_ref, mpd_context - ): - backend_to_browse.library.dummy_browse_result = { - "dummy:/": [bad_ref, a_track], - } - - results = mpd_context.browse("dummy", recursive=False, lookup=False) - - assert [("/dummy/a", a_track)] == list(results) From 8eb9efc6c259bcf82a530a8aef2b72dfbf85151d Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sat, 24 Feb 2024 02:50:14 +0100 Subject: [PATCH 07/19] Move util functions out of MpdDispatcher class --- src/mopidy_mpd/dispatcher.py | 70 +++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index 1de4e09..bc6eff8 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -227,7 +227,7 @@ def _call_handler_filter( ) -> Response: try: result = self._call_handler(request) - response = self._format_response(result) + response = _format_response(result) return self._call_next_filter(request, response, filter_chain) except pykka.ActorDeadError as exc: logger.warning("Tried to communicate with dead actor.") @@ -250,36 +250,40 @@ def _call_handler(self, request: str) -> protocol.Result: exc.command = tokens[0] raise - def _format_response(self, result: protocol.Result) -> Response: - response = Response([]) - for element in self._listify_result(result): - response.extend(self._format_lines(element)) - return response - def _listify_result(self, result: protocol.Result) -> protocol.ResultList: - match result: - case None: - return [] - case list(): - return self._flatten(result) - case _: - return [result] - - def _flatten(self, lst: protocol.ResultList) -> protocol.ResultList: - result: protocol.ResultList = [] - for element in lst: - if isinstance(element, list): - result.extend(self._flatten(element)) - else: - result.append(element) - return result - - def _format_lines( - self, element: protocol.ResultDict | protocol.ResultTuple | str - ) -> Response: - if isinstance(element, dict): - return Response([f"{key}: {value}" for (key, value) in element.items()]) - if isinstance(element, tuple): - (key, value) = element - return Response([f"{key}: {value}"]) - return Response([element]) +def _format_response(result: protocol.Result) -> Response: + response = Response([]) + for element in _listify_result(result): + response.extend(_format_lines(element)) + return response + + +def _listify_result(result: protocol.Result) -> protocol.ResultList: + match result: + case None: + return [] + case list(): + return _flatten(result) + case _: + return [result] + + +def _flatten(lst: protocol.ResultList) -> protocol.ResultList: + result: protocol.ResultList = [] + for element in lst: + if isinstance(element, list): + result.extend(_flatten(element)) + else: + result.append(element) + return result + + +def _format_lines( + element: protocol.ResultDict | protocol.ResultTuple | str, +) -> Response: + if isinstance(element, dict): + return Response([f"{key}: {value}" for (key, value) in element.items()]) + if isinstance(element, tuple): + (key, value) = element + return Response([f"{key}: {value}"]) + return Response([element]) From b45e33a846bb4ce14f78fe962ab7c83558ef6808 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sat, 24 Feb 2024 03:01:46 +0100 Subject: [PATCH 08/19] Expose MpdUriMapper directly to protocol implementation --- src/mopidy_mpd/context.py | 23 ++++++--------------- src/mopidy_mpd/protocol/music_db.py | 18 ++++++++-------- src/mopidy_mpd/protocol/stored_playlists.py | 6 +++--- 3 files changed, 18 insertions(+), 29 deletions(-) diff --git a/src/mopidy_mpd/context.py b/src/mopidy_mpd/context.py index aea9341..9089996 100644 --- a/src/mopidy_mpd/context.py +++ b/src/mopidy_mpd/context.py @@ -53,7 +53,8 @@ class MpdContext: #: The subsystems that we want to be notified about in idle mode. subscriptions: set[str] - _uri_map: MpdUriMapper + #: Mapping of URIs to MPD names. + uri_map: MpdUriMapper def __init__( self, @@ -71,19 +72,7 @@ def __init__( self.password = mpd_config["password"] self.events = set() self.subscriptions = set() - self._uri_map = MpdUriMapper(core) - - def lookup_playlist_uri_from_name(self, name: str) -> Uri | None: - """ - Helper function to retrieve a playlist from its unique MPD name. - """ - return self._uri_map.playlist_uri_from_name(name) - - def lookup_playlist_name_from_uri(self, uri: Uri) -> str | None: - """ - Helper function to retrieve the unique MPD playlist name from its uri. - """ - return self._uri_map.playlist_name_from_uri(uri) + self.uri_map = MpdUriMapper(core) @overload def browse( @@ -125,7 +114,7 @@ def browse( # noqa: C901, PLR0912 path_parts: list[str] = re.findall(r"[^/]+", path or "") root_path: str = "/".join(["", *path_parts]) - uri = self._uri_map.uri_from_name(root_path) + uri = self.uri_map.uri_from_name(root_path) if uri is None: for part in path_parts: for ref in self.core.library.browse(uri).get(): @@ -134,7 +123,7 @@ def browse( # noqa: C901, PLR0912 break else: raise exceptions.MpdNoExistError("Not found") - root_path = self._uri_map.insert(root_path, uri) + root_path = self.uri_map.insert(root_path, uri) if recursive: yield (root_path, None) @@ -147,7 +136,7 @@ def browse( # noqa: C901, PLR0912 continue path = "/".join([base_path, ref.name.replace("/", "")]) - path = self._uri_map.insert(path, ref.uri) + path = self.uri_map.insert(path, ref.uri) if ref.type == ref.TRACK: if lookup: diff --git a/src/mopidy_mpd/protocol/music_db.py b/src/mopidy_mpd/protocol/music_db.py index 6023a66..1e0422e 100644 --- a/src/mopidy_mpd/protocol/music_db.py +++ b/src/mopidy_mpd/protocol/music_db.py @@ -511,21 +511,21 @@ def searchaddpl(context: MpdContext, *args: str) -> None: parameters = list(args) if not parameters: raise exceptions.MpdArgError("incorrect arguments") + playlist_name = parameters.pop(0) + uri = context.uri_map.playlist_uri_from_name(playlist_name) + if uri: + playlist = context.core.playlists.lookup(uri).get() + else: + playlist = context.core.playlists.create(playlist_name).get() + if not playlist: + return # TODO: Raise error about failed playlist creation? + try: query = _query_for_search(parameters) except ValueError: return results = context.core.library.search(query).get() - - uri = context.lookup_playlist_uri_from_name(playlist_name) - if uri is None: - return # TODO: Raise error? - playlist = context.core.playlists.lookup(uri).get() - if not playlist: - playlist = context.core.playlists.create(playlist_name).get() - if not playlist: - return # TODO: Raise error about failed playlist creation? tracks = list(playlist.tracks) + _get_tracks(results) playlist = playlist.replace(tracks=tracks) context.core.playlists.save(playlist) diff --git a/src/mopidy_mpd/protocol/stored_playlists.py b/src/mopidy_mpd/protocol/stored_playlists.py index 17a7077..0cb3e9b 100644 --- a/src/mopidy_mpd/protocol/stored_playlists.py +++ b/src/mopidy_mpd/protocol/stored_playlists.py @@ -43,7 +43,7 @@ def _get_playlist( context: MpdContext, name: str, *, must_exist: bool ) -> Playlist | None: playlist = None - uri = context.lookup_playlist_uri_from_name(name) + uri = context.uri_map.playlist_uri_from_name(name) if uri: playlist = context.core.playlists.lookup(uri).get() if must_exist and playlist is None: @@ -125,7 +125,7 @@ def listplaylists(context: MpdContext) -> protocol.ResultList: for playlist_ref in context.core.playlists.as_list().get(): if not playlist_ref.name: continue - name = context.lookup_playlist_name_from_uri(playlist_ref.uri) + name = context.uri_map.playlist_name_from_uri(playlist_ref.uri) if name is None: continue result.append(("playlist", name)) @@ -383,7 +383,7 @@ def rm(context: MpdContext, name: str) -> None: Removes the playlist ``NAME.m3u`` from the playlist directory. """ _check_playlist_name(name) - uri = context.lookup_playlist_uri_from_name(name) + uri = context.uri_map.playlist_uri_from_name(name) if not uri: raise exceptions.MpdNoExistError("No such playlist") context.core.playlists.delete(uri).get() From b2f0b26141747d9fa8bfd32e9ad39f577beabb05 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sat, 24 Feb 2024 03:13:25 +0100 Subject: [PATCH 09/19] Move subsystem events/subscriptions to MpdDispatcher --- src/mopidy_mpd/context.py | 9 +-------- src/mopidy_mpd/dispatcher.py | 24 ++++++++++++++++++------ src/mopidy_mpd/protocol/status.py | 16 +++++++++------- tests/protocol/test_idle.py | 4 ++-- 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/mopidy_mpd/context.py b/src/mopidy_mpd/context.py index 9089996..c5be15d 100644 --- a/src/mopidy_mpd/context.py +++ b/src/mopidy_mpd/context.py @@ -47,12 +47,6 @@ class MpdContext: #: The MPD password. password: str | None = None - #: The active subsystems that have pending events. - events: set[str] - - #: The subsystems that we want to be notified about in idle mode. - subscriptions: set[str] - #: Mapping of URIs to MPD names. uri_map: MpdUriMapper @@ -70,8 +64,7 @@ def __init__( if config is not None: mpd_config = cast(types.MpdConfig, config["mpd"]) self.password = mpd_config["password"] - self.events = set() - self.subscriptions = set() + self.uri_map = MpdUriMapper(core) @overload diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index bc6eff8..ec11312 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -41,6 +41,12 @@ class MpdDispatcher: _noidle = re.compile(r"^noidle$") + #: The active subsystems that have pending events. + subsystem_events: set[str] + + #: The subsystems that we want to be notified about in idle mode. + subsystem_subscriptions: set[str] + def __init__( self, config: Config, @@ -49,11 +55,17 @@ def __init__( ) -> None: self.config = config self.mpd_config = cast(types.MpdConfig, config.get("mpd", {}) if config else {}) + self.authenticated = False + self.command_list_receiving = False self.command_list_ok = False self.command_list = [] self.command_list_index = None + + self.subsystem_events = set() + self.subsystem_subscriptions = set() + self.context = context.MpdContext( config=config, core=core, @@ -80,10 +92,10 @@ def handle_request( return self._call_next_filter(request, response, filter_chain) def handle_idle(self, subsystem: str) -> None: - # TODO: validate against mopidy_mpd/protocol/status.SUBSYSTEMS - self.context.events.add(subsystem) + # TODO: validate against mopidy_mpd.protocol.status.SUBSYSTEMS + self.subsystem_events.add(subsystem) - subsystems = self.context.subscriptions.intersection(self.context.events) + subsystems = self.subsystem_subscriptions.intersection(self.subsystem_events) if not subsystems: return @@ -91,8 +103,8 @@ def handle_idle(self, subsystem: str) -> None: for subsystem in subsystems: response.append(f"changed: {subsystem}") response.append("OK") - self.context.subscriptions = set() - self.context.events = set() + self.subsystem_events = set() + self.subsystem_subscriptions = set() self.context.session.send_lines(response) def _call_next_filter( @@ -199,7 +211,7 @@ def _idle_filter( return response def _is_currently_idle(self) -> bool: - return bool(self.context.subscriptions) + return bool(self.subsystem_subscriptions) # --- Filter: add OK diff --git a/src/mopidy_mpd/protocol/status.py b/src/mopidy_mpd/protocol/status.py index f5f13c3..1fd061d 100644 --- a/src/mopidy_mpd/protocol/status.py +++ b/src/mopidy_mpd/protocol/status.py @@ -101,16 +101,18 @@ def idle(context: MpdContext, *args: str) -> protocol.Result: subsystems = list(args) if args else SUBSYSTEMS for subsystem in subsystems: - context.subscriptions.add(subsystem) + context.dispatcher.subsystem_subscriptions.add(subsystem) - active = context.subscriptions.intersection(context.events) + active = context.dispatcher.subsystem_subscriptions.intersection( + context.dispatcher.subsystem_events + ) if not active: context.session.prevent_timeout = True return None response = [] - context.events = set() - context.subscriptions = set() + context.dispatcher.subsystem_events = set() + context.dispatcher.subsystem_subscriptions = set() for subsystem in active: response.append(f"changed: {subsystem}") @@ -120,10 +122,10 @@ def idle(context: MpdContext, *args: str) -> protocol.Result: @protocol.commands.add("noidle", list_command=False) def noidle(context: MpdContext) -> None: """See :meth:`_status_idle`.""" - if not context.subscriptions: + if not context.dispatcher.subsystem_subscriptions: return - context.subscriptions = set() - context.events = set() + context.dispatcher.subsystem_subscriptions = set() + context.dispatcher.subsystem_events = set() context.session.prevent_timeout = False diff --git a/tests/protocol/test_idle.py b/tests/protocol/test_idle.py index 89d7a0a..da607f6 100644 --- a/tests/protocol/test_idle.py +++ b/tests/protocol/test_idle.py @@ -10,10 +10,10 @@ def idle_event(self, subsystem): self.session.on_event(subsystem) def assertEqualEvents(self, events): # noqa: N802 - assert set(events) == self.context.events + assert self.dispatcher.subsystem_events == set(events) def assertEqualSubscriptions(self, events): # noqa: N802 - assert set(events) == self.context.subscriptions + assert self.dispatcher.subsystem_subscriptions == set(events) def assertNoEvents(self): # noqa: N802 self.assertEqualEvents([]) From 60c75f9b777884761d4c13c169171acb97e54a9e Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sat, 24 Feb 2024 03:16:34 +0100 Subject: [PATCH 10/19] Make MpdDispatcher use MpdSession directly --- src/mopidy_mpd/dispatcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index ec11312..6ee1304 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -55,6 +55,7 @@ def __init__( ) -> None: self.config = config self.mpd_config = cast(types.MpdConfig, config.get("mpd", {}) if config else {}) + self.session = session self.authenticated = False @@ -105,7 +106,7 @@ def handle_idle(self, subsystem: str) -> None: response.append("OK") self.subsystem_events = set() self.subsystem_subscriptions = set() - self.context.session.send_lines(response) + self.session.send_lines(response) def _call_next_filter( self, request: str, response: Response, filter_chain: list[Filter] @@ -197,7 +198,7 @@ def _idle_filter( repr(request), repr("noidle"), ) - self.context.session.close() + self.session.close() return Response([]) if not self._is_currently_idle() and self._noidle.match(request): From f0cc69cc59ddb6abd01072da5afbaf07c3e47d3d Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sat, 24 Feb 2024 03:21:37 +0100 Subject: [PATCH 11/19] Replace regexp pattern with full string matching --- src/mopidy_mpd/dispatcher.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index 6ee1304..aeab4e0 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import re from collections.abc import Callable from typing import ( TYPE_CHECKING, @@ -39,8 +38,6 @@ class MpdDispatcher: back to the MPD session. """ - _noidle = re.compile(r"^noidle$") - #: The active subsystems that have pending events. subsystem_events: set[str] @@ -192,16 +189,16 @@ def _idle_filter( response: Response, filter_chain: list[Filter], ) -> Response: - if self._is_currently_idle() and not self._noidle.match(request): + if self._is_currently_idle() and request != "noidle": logger.debug( - "Client sent us %s, only %s is allowed while in " "the idle state", + "Client sent us %s, only %s is allowed while in the idle state", repr(request), repr("noidle"), ) self.session.close() return Response([]) - if not self._is_currently_idle() and self._noidle.match(request): + if not self._is_currently_idle() and request == "noidle": return Response([]) # noidle was called before idle response = self._call_next_filter(request, response, filter_chain) From 175f93d23cc3f49dd7a0bf262a2be896a831d945 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sat, 24 Feb 2024 03:57:43 +0100 Subject: [PATCH 12/19] Replace protocol_kwargs with something more type safe --- src/mopidy_mpd/actor.py | 16 ++++----- src/mopidy_mpd/context.py | 12 +++---- src/mopidy_mpd/dispatcher.py | 5 ++- src/mopidy_mpd/network.py | 43 +++++++++++++++++----- src/mopidy_mpd/session.py | 16 +++++++-- tests/network/test_connection.py | 61 ++++++++++++++++++++------------ tests/network/test_server.py | 59 ++++++++++++++++++------------ tests/protocol/__init__.py | 9 +++-- tests/test_context.py | 2 ++ tests/test_dispatcher.py | 2 ++ tests/test_session.py | 2 ++ tests/test_status.py | 3 +- 12 files changed, 154 insertions(+), 76 deletions(-) diff --git a/src/mopidy_mpd/actor.py b/src/mopidy_mpd/actor.py index 5795811..ab7809d 100644 --- a/src/mopidy_mpd/actor.py +++ b/src/mopidy_mpd/actor.py @@ -33,14 +33,12 @@ def __init__(self, config: Config, core: CoreProxy) -> None: super().__init__() mpd_config = cast(types.MpdConfig, config.get("mpd", {})) - self.hostname = network.format_hostname(mpd_config["hostname"]) self.port = mpd_config["port"] - self.uri_map = uri_mapper.MpdUriMapper(core) - self.zeroconf_name = mpd_config["zeroconf"] self.zeroconf_service = None + self.uri_map = uri_mapper.MpdUriMapper(core) self.server = self._setup_server(config, core) def _setup_server(self, config: Config, core: CoreProxy) -> network.Server: @@ -48,14 +46,12 @@ def _setup_server(self, config: Config, core: CoreProxy) -> network.Server: try: server = network.Server( - self.hostname, - self.port, + config=config, + core=core, + uri_map=self.uri_map, protocol=session.MpdSession, - protocol_kwargs={ - "config": config, - "core": core, - "uri_map": self.uri_map, - }, + host=self.hostname, + port=self.port, max_connections=mpd_config["max_connections"], timeout=mpd_config["connection_timeout"], ) diff --git a/src/mopidy_mpd/context.py b/src/mopidy_mpd/context.py index c5be15d..f47a2f9 100644 --- a/src/mopidy_mpd/context.py +++ b/src/mopidy_mpd/context.py @@ -11,19 +11,19 @@ ) from mopidy_mpd import exceptions, types -from mopidy_mpd.uri_mapper import MpdUriMapper if TYPE_CHECKING: from collections.abc import Generator import pykka + from mopidy.config import Config from mopidy.core import CoreProxy - from mopidy.ext import Config from mopidy.models import Ref, Track from mopidy.types import Uri from mopidy_mpd.dispatcher import MpdDispatcher from mopidy_mpd.session import MpdSession + from mopidy_mpd.uri_mapper import MpdUriMapper logger = logging.getLogger(__name__) @@ -50,23 +50,23 @@ class MpdContext: #: Mapping of URIs to MPD names. uri_map: MpdUriMapper - def __init__( + def __init__( # noqa: PLR0913 self, config: Config, core: CoreProxy, + uri_map: MpdUriMapper, session: MpdSession, dispatcher: MpdDispatcher, ) -> None: self.core = core + self.uri_map = uri_map self.session = session self.dispatcher = dispatcher if config is not None: - mpd_config = cast(types.MpdConfig, config["mpd"]) + mpd_config = cast(types.MpdConfig, config.get("mpd", {})) self.password = mpd_config["password"] - self.uri_map = MpdUriMapper(core) - @overload def browse( self, path: str | None, *, recursive: bool, lookup: Literal[True] diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index aeab4e0..de2ccd3 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -15,10 +15,11 @@ from mopidy_mpd import context, exceptions, protocol, tokenize, types if TYPE_CHECKING: + from mopidy.config import Config from mopidy.core import CoreProxy - from mopidy.ext import Config from mopidy_mpd.session import MpdSession + from mopidy_mpd.uri_mapper import MpdUriMapper logger = logging.getLogger(__name__) @@ -48,6 +49,7 @@ def __init__( self, config: Config, core: CoreProxy, + uri_map: MpdUriMapper, session: MpdSession, ) -> None: self.config = config @@ -67,6 +69,7 @@ def __init__( self.context = context.MpdContext( config=config, core=core, + uri_map=uri_map, session=session, dispatcher=self, ) diff --git a/src/mopidy_mpd/network.py b/src/mopidy_mpd/network.py index 3fc2872..617f4ce 100644 --- a/src/mopidy_mpd/network.py +++ b/src/mopidy_mpd/network.py @@ -19,8 +19,12 @@ from collections.abc import Generator from types import TracebackType - from mopidy_mpd.session import MpdSession + from mopidy.config import Config + from mopidy.core import CoreProxy + + from mopidy_mpd.session import MpdSession, MpdSessionKwargs from mopidy_mpd.types import SocketAddress + from mopidy_mpd.uri_mapper import MpdUriMapper CONTROL_CHARS = dict.fromkeys(range(32)) @@ -117,20 +121,25 @@ def format_hostname(hostname: str) -> str: class Server: - """Setup listener and register it with GLib's event loop.""" def __init__( # noqa: PLR0913 self, + *, + config: Config, + core: CoreProxy, + uri_map: MpdUriMapper, + protocol: type[MpdSession], host: str, port: int, - protocol: type[MpdSession], - protocol_kwargs: dict[str, Any] | None = None, max_connections: int = 5, timeout: int = 30, ) -> None: + self.config = config + self.core = core + self.uri_map = uri_map self.protocol = protocol - self.protocol_kwargs = protocol_kwargs or {} + self.max_connections = max_connections self.timeout = timeout self.server_socket = self.create_server_socket(host, port) @@ -218,7 +227,15 @@ def reject_connection(self, sock: socket.socket, addr: SocketAddress) -> None: sock.close() def init_connection(self, sock: socket.socket, addr: SocketAddress) -> None: - Connection(self.protocol, self.protocol_kwargs, sock, addr, self.timeout) + Connection( + config=self.config, + core=self.core, + uri_map=self.uri_map, + protocol=self.protocol, + sock=sock, + addr=addr, + timeout=self.timeout, + ) class Connection: @@ -235,8 +252,11 @@ class Connection: def __init__( # noqa: PLR0913 self, + *, + config: Config, + core: CoreProxy, + uri_map: MpdUriMapper, protocol: type[MpdSession], - protocol_kwargs: dict[str, Any], sock: socket.socket, addr: SocketAddress, timeout: int, @@ -247,7 +267,6 @@ def __init__( # noqa: PLR0913 self._sock = sock self.protocol = protocol - self.protocol_kwargs = protocol_kwargs self.timeout = timeout self.send_lock = threading.Lock() @@ -259,7 +278,13 @@ def __init__( # noqa: PLR0913 self.send_id = None self.timeout_id = None - self.actor_ref = self.protocol.start(self, **self.protocol_kwargs) + protocol_kwargs: MpdSessionKwargs = { + "config": config, + "core": core, + "uri_map": uri_map, + "connection": self, + } + self.actor_ref = self.protocol.start(**protocol_kwargs) self.enable_recv() self.enable_timeout() diff --git a/src/mopidy_mpd/session.py b/src/mopidy_mpd/session.py index 3e0ddf7..63e3424 100644 --- a/src/mopidy_mpd/session.py +++ b/src/mopidy_mpd/session.py @@ -1,19 +1,28 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, NoReturn, TypedDict from mopidy_mpd import dispatcher, formatting, network, protocol from mopidy_mpd.protocol import tagtype_list if TYPE_CHECKING: + from mopidy.config import Config from mopidy.core import CoreProxy - from mopidy.ext import Config + + from mopidy_mpd.uri_mapper import MpdUriMapper logger = logging.getLogger(__name__) +class MpdSessionKwargs(TypedDict): + config: Config + core: CoreProxy + uri_map: MpdUriMapper + connection: network.Connection + + class MpdSession(network.LineProtocol): """ The MPD client session. Keeps track of a single client session. Any @@ -25,14 +34,17 @@ class MpdSession(network.LineProtocol): def __init__( self, + *, config: Config, core: CoreProxy, + uri_map: MpdUriMapper, connection: network.Connection, ) -> None: super().__init__(connection) self.dispatcher = dispatcher.MpdDispatcher( config=config, core=core, + uri_map=uri_map, session=self, ) self.tagtypes = tagtype_list.TAGTYPE_LIST.copy() diff --git a/tests/network/test_connection.py b/tests/network/test_connection.py index 1c6cb1d..7fc7de5 100644 --- a/tests/network/test_connection.py +++ b/tests/network/test_connection.py @@ -6,7 +6,7 @@ import pykka from gi.repository import GLib -from mopidy_mpd import network +from mopidy_mpd import network, uri_mapper from tests import any_int, any_unicode @@ -20,11 +20,13 @@ def test_init_ensure_nonblocking_io(self): network.Connection.__init__( self.mock, - Mock(), - {}, - sock, - (sentinel.host, sentinel.port), - sentinel.timeout, + config={}, + core=Mock(), + uri_map=Mock(spec=uri_mapper.MpdUriMapper), + protocol=Mock(spec=network.LineProtocol), + sock=sock, + addr=(sentinel.host, sentinel.port), + timeout=sentinel.timeout, ) sock.setblocking.assert_called_once_with(False) @@ -33,22 +35,26 @@ def test_init_starts_actor(self): network.Connection.__init__( self.mock, - protocol, - {}, - Mock(), - (sentinel.host, sentinel.port), - sentinel.timeout, + config={}, + core=Mock(), + uri_map=Mock(spec=uri_mapper.MpdUriMapper), + protocol=protocol, + sock=Mock(spec=socket.SocketType), + addr=(sentinel.host, sentinel.port), + timeout=sentinel.timeout, ) - protocol.start.assert_called_once_with(self.mock) + protocol.start.assert_called_once() def test_init_enables_recv_and_timeout(self): network.Connection.__init__( self.mock, - Mock(), - {}, - Mock(), - (sentinel.host, sentinel.port), - sentinel.timeout, + config={}, + core=Mock(), + uri_map=Mock(spec=uri_mapper.MpdUriMapper), + protocol=Mock(spec=network.LineProtocol), + sock=Mock(spec=socket.SocketType), + addr=(sentinel.host, sentinel.port), + timeout=sentinel.timeout, ) self.mock.enable_recv.assert_called_once_with() self.mock.enable_timeout.assert_called_once_with() @@ -56,15 +62,20 @@ def test_init_enables_recv_and_timeout(self): def test_init_stores_values_in_attributes(self): addr = (sentinel.host, sentinel.port) protocol = Mock(spec=network.LineProtocol) - protocol_kwargs = {} sock = Mock(spec=socket.SocketType) network.Connection.__init__( - self.mock, protocol, protocol_kwargs, sock, addr, sentinel.timeout + self.mock, + config={}, + core=Mock(), + uri_map=Mock(spec=uri_mapper.MpdUriMapper), + protocol=protocol, + sock=sock, + addr=addr, + timeout=sentinel.timeout, ) assert sock == self.mock._sock assert protocol == self.mock.protocol - assert protocol_kwargs == self.mock.protocol_kwargs assert sentinel.timeout == self.mock.timeout assert sentinel.host == self.mock.host assert sentinel.port == self.mock.port @@ -77,11 +88,17 @@ def test_init_handles_ipv6_addr(self): sentinel.scopeid, ) protocol = Mock(spec=network.LineProtocol) - protocol_kwargs = {} sock = Mock(spec=socket.SocketType) network.Connection.__init__( - self.mock, protocol, protocol_kwargs, sock, addr, sentinel.timeout + self.mock, + config={}, + core=Mock(), + uri_map=Mock(spec=uri_mapper.MpdUriMapper), + protocol=protocol, + sock=sock, + addr=addr, + timeout=sentinel.timeout, ) assert sentinel.host == self.mock.host assert sentinel.port == self.mock.port diff --git a/tests/network/test_server.py b/tests/network/test_server.py index 68e2b7c..d838082 100644 --- a/tests/network/test_server.py +++ b/tests/network/test_server.py @@ -5,7 +5,8 @@ from unittest.mock import Mock, patch, sentinel from gi.repository import GLib -from mopidy_mpd import network +from mopidy.core import CoreProxy +from mopidy_mpd import network, uri_mapper from tests import any_int @@ -17,7 +18,13 @@ def setUp(self): @patch.object(network, "get_socket_address", new=Mock()) def test_init_calls_create_server_socket(self): network.Server.__init__( - self.mock, sentinel.host, sentinel.port, sentinel.protocol + self.mock, + config={}, + core=Mock(spec=CoreProxy), + uri_map=Mock(uri_mapper.MpdUriMapper), + protocol=sentinel.protocol, + host=sentinel.host, + port=sentinel.port, ) self.mock.create_server_socket.assert_called_once_with( sentinel.host, sentinel.port @@ -27,7 +34,13 @@ def test_init_calls_create_server_socket(self): @patch.object(network, "get_socket_address", new=Mock()) def test_init_calls_get_socket_address(self): network.Server.__init__( - self.mock, sentinel.host, sentinel.port, sentinel.protocol + self.mock, + config={}, + core=Mock(spec=CoreProxy), + uri_map=Mock(uri_mapper.MpdUriMapper), + protocol=sentinel.protocol, + host=sentinel.host, + port=sentinel.port, ) self.mock.create_server_socket.return_value = None network.get_socket_address.assert_called_once_with(sentinel.host, sentinel.port) @@ -40,7 +53,13 @@ def test_init_calls_register_server(self): self.mock.create_server_socket.return_value = sock network.Server.__init__( - self.mock, sentinel.host, sentinel.port, sentinel.protocol + self.mock, + config={}, + core=Mock(spec=CoreProxy), + uri_map=Mock(uri_mapper.MpdUriMapper), + protocol=sentinel.protocol, + host=sentinel.host, + port=sentinel.port, ) self.mock.register_server_socket.assert_called_once_with(sentinel.fileno) @@ -52,7 +71,13 @@ def test_init_fails_on_fileno_call(self): with self.assertRaises(socket.error): network.Server.__init__( - self.mock, sentinel.host, sentinel.port, sentinel.protocol + self.mock, + config={}, + core=Mock(spec=CoreProxy), + uri_map=Mock(uri_mapper.MpdUriMapper), + protocol=sentinel.protocol, + host=sentinel.host, + port=sentinel.port, ) def test_init_stores_values_in_attributes(self): @@ -62,9 +87,12 @@ def test_init_stores_values_in_attributes(self): network.Server.__init__( self.mock, - str(sentinel.host), - sentinel.port, - sentinel.protocol, + config={}, + core=Mock(spec=CoreProxy), + uri_map=Mock(uri_mapper.MpdUriMapper), + protocol=sentinel.protocol, + host=str(sentinel.host), + port=sentinel.port, max_connections=sentinel.max_connections, timeout=sentinel.timeout, ) @@ -265,21 +293,6 @@ def test_number_of_connections(self, get_by_class): get_by_class.return_value = [] assert network.Server.number_of_connections(self.mock) == 0 - @patch.object(network, "Connection", new=Mock()) - def test_init_connection(self): - self.mock.protocol = sentinel.protocol - self.mock.protocol_kwargs = {} - self.mock.timeout = sentinel.timeout - - network.Server.init_connection(self.mock, sentinel.sock, sentinel.addr) - network.Connection.assert_called_once_with( - sentinel.protocol, - {}, - sentinel.sock, - sentinel.addr, - sentinel.timeout, - ) - def test_reject_connection(self): sock = Mock(spec=socket.socket) diff --git a/tests/protocol/__init__.py b/tests/protocol/__init__.py index beab64d..d3c0d20 100644 --- a/tests/protocol/__init__.py +++ b/tests/protocol/__init__.py @@ -4,7 +4,7 @@ import pykka from mopidy import core -from mopidy_mpd import session +from mopidy_mpd import session, uri_mapper from tests import dummy_audio, dummy_backend, dummy_mixer @@ -28,7 +28,11 @@ class BaseTestCase(unittest.TestCase): def get_config(self): return { "core": {"max_tracklist_length": 10000}, - "mpd": {"password": None, "default_playlist_scheme": "dummy"}, + "mpd": { + "command_blacklist": [], + "default_playlist_scheme": "dummy", + "password": None, + }, } def setUp(self): @@ -53,6 +57,7 @@ def setUp(self): self.session = session.MpdSession( config=self.get_config(), core=self.core, + uri_map=uri_mapper.MpdUriMapper(self.core), connection=self.connection, ) self.dispatcher = self.session.dispatcher diff --git a/tests/test_context.py b/tests/test_context.py index fdd1b1b..6d4c9ec 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -5,6 +5,7 @@ from mopidy.backend import BackendProxy from mopidy.core import Core, CoreProxy from mopidy.models import Ref +from mopidy_mpd import uri_mapper from mopidy_mpd.context import MpdContext from tests import dummy_backend @@ -44,6 +45,7 @@ def mpd_context(backend_to_browse: BackendProxy) -> MpdContext: return MpdContext( config=None, core=core, + uri_map=uri_mapper.MpdUriMapper(core), dispatcher=None, session=None, ) diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index a458b4b..4b7d591 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -3,6 +3,7 @@ import pykka from mopidy.core import Core, CoreProxy +from mopidy_mpd import uri_mapper from mopidy_mpd.dispatcher import MpdDispatcher from mopidy_mpd.exceptions import MpdAckError @@ -19,6 +20,7 @@ def setUp(self): self.dispatcher = MpdDispatcher( config=config, core=self.core, + uri_map=uri_mapper.MpdUriMapper(self.core), session=None, ) diff --git a/tests/test_session.py b/tests/test_session.py index 591a8d4..55512e9 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -11,6 +11,7 @@ def test_on_start_logged(caplog): session.MpdSession( config=None, core=None, + uri_map=None, connection=connection, ).on_start() @@ -23,6 +24,7 @@ def test_on_line_received_logged(caplog): mpd_session = session.MpdSession( config=None, core=None, + uri_map=None, connection=connection, ) mpd_session.dispatcher = Mock(spec=dispatcher.MpdDispatcher) diff --git a/tests/test_status.py b/tests/test_status.py index bcb5385..9c132aa 100644 --- a/tests/test_status.py +++ b/tests/test_status.py @@ -5,7 +5,7 @@ from mopidy import core from mopidy.core import PlaybackState from mopidy.models import Track -from mopidy_mpd import dispatcher +from mopidy_mpd import dispatcher, uri_mapper from mopidy_mpd.protocol import status from tests import dummy_audio, dummy_backend, dummy_mixer @@ -42,6 +42,7 @@ def setUp(self): self.dispatcher = dispatcher.MpdDispatcher( config=config, core=self.core, + uri_map=uri_mapper.MpdUriMapper(self.core), session=None, ) self.context = self.dispatcher.context From 9ae366044740577092ca38e20ad0ec80fa92fc63 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Sun, 25 Feb 2024 23:25:48 +0100 Subject: [PATCH 13/19] Properly type config object with MPD config --- src/mopidy_mpd/actor.py | 20 ++++++++------------ src/mopidy_mpd/context.py | 15 +++++---------- src/mopidy_mpd/dispatcher.py | 7 ++----- src/mopidy_mpd/network.py | 6 +++--- src/mopidy_mpd/protocol/connection.py | 2 +- src/mopidy_mpd/protocol/stored_playlists.py | 2 +- src/mopidy_mpd/session.py | 7 +++---- src/mopidy_mpd/types.py | 13 +++++++++++-- 8 files changed, 34 insertions(+), 38 deletions(-) diff --git a/src/mopidy_mpd/actor.py b/src/mopidy_mpd/actor.py index ab7809d..66c9799 100644 --- a/src/mopidy_mpd/actor.py +++ b/src/mopidy_mpd/actor.py @@ -1,9 +1,8 @@ import logging -from typing import Any, cast +from typing import Any import pykka from mopidy import exceptions, listener, zeroconf -from mopidy.config import Config from mopidy.core import CoreListener, CoreProxy from mopidy_mpd import network, session, types, uri_mapper @@ -29,21 +28,18 @@ class MpdFrontend(pykka.ThreadingActor, CoreListener): - def __init__(self, config: Config, core: CoreProxy) -> None: + def __init__(self, config: types.Config, core: CoreProxy) -> None: super().__init__() - mpd_config = cast(types.MpdConfig, config.get("mpd", {})) - self.hostname = network.format_hostname(mpd_config["hostname"]) - self.port = mpd_config["port"] - self.zeroconf_name = mpd_config["zeroconf"] + self.hostname = network.format_hostname(config["mpd"]["hostname"]) + self.port = config["mpd"]["port"] + self.zeroconf_name = config["mpd"]["zeroconf"] self.zeroconf_service = None self.uri_map = uri_mapper.MpdUriMapper(core) self.server = self._setup_server(config, core) - def _setup_server(self, config: Config, core: CoreProxy) -> network.Server: - mpd_config = cast(types.MpdConfig, config.get("mpd", {})) - + def _setup_server(self, config: types.Config, core: CoreProxy) -> network.Server: try: server = network.Server( config=config, @@ -52,8 +48,8 @@ def _setup_server(self, config: Config, core: CoreProxy) -> network.Server: protocol=session.MpdSession, host=self.hostname, port=self.port, - max_connections=mpd_config["max_connections"], - timeout=mpd_config["connection_timeout"], + max_connections=config["mpd"]["max_connections"], + timeout=config["mpd"]["connection_timeout"], ) except OSError as exc: raise exceptions.FrontendError(f"MPD server startup failed: {exc}") from exc diff --git a/src/mopidy_mpd/context.py b/src/mopidy_mpd/context.py index f47a2f9..107bcbe 100644 --- a/src/mopidy_mpd/context.py +++ b/src/mopidy_mpd/context.py @@ -6,7 +6,6 @@ TYPE_CHECKING, Any, Literal, - cast, overload, ) @@ -16,7 +15,6 @@ from collections.abc import Generator import pykka - from mopidy.config import Config from mopidy.core import CoreProxy from mopidy.models import Ref, Track from mopidy.types import Uri @@ -35,6 +33,9 @@ class MpdContext: give the command handlers access to important parts of Mopidy. """ + #: The Mopidy config. + config: types.Config + #: The Mopidy core API. core: CoreProxy @@ -44,29 +45,23 @@ class MpdContext: #: The current dispatcher instance. dispatcher: MpdDispatcher - #: The MPD password. - password: str | None = None - #: Mapping of URIs to MPD names. uri_map: MpdUriMapper def __init__( # noqa: PLR0913 self, - config: Config, + config: types.Config, core: CoreProxy, uri_map: MpdUriMapper, session: MpdSession, dispatcher: MpdDispatcher, ) -> None: + self.config = config self.core = core self.uri_map = uri_map self.session = session self.dispatcher = dispatcher - if config is not None: - mpd_config = cast(types.MpdConfig, config.get("mpd", {})) - self.password = mpd_config["password"] - @overload def browse( self, path: str | None, *, recursive: bool, lookup: Literal[True] diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index de2ccd3..72a7cf2 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -7,7 +7,6 @@ NewType, TypeAlias, TypeVar, - cast, ) import pykka @@ -15,7 +14,6 @@ from mopidy_mpd import context, exceptions, protocol, tokenize, types if TYPE_CHECKING: - from mopidy.config import Config from mopidy.core import CoreProxy from mopidy_mpd.session import MpdSession @@ -47,13 +45,12 @@ class MpdDispatcher: def __init__( self, - config: Config, + config: types.Config, core: CoreProxy, uri_map: MpdUriMapper, session: MpdSession, ) -> None: self.config = config - self.mpd_config = cast(types.MpdConfig, config.get("mpd", {}) if config else {}) self.session = session self.authenticated = False @@ -249,7 +246,7 @@ def _call_handler_filter( def _call_handler(self, request: str) -> protocol.Result: tokens = tokenize.split(request) # TODO: check that blacklist items are valid commands? - blacklist = self.mpd_config.get("command_blacklist", []) + blacklist = self.config["mpd"]["command_blacklist"] if tokens and tokens[0] in blacklist: logger.warning("MPD client used blacklisted command: %s", tokens[0]) raise exceptions.MpdDisabledError(command=tokens[0]) diff --git a/src/mopidy_mpd/network.py b/src/mopidy_mpd/network.py index 617f4ce..f16465f 100644 --- a/src/mopidy_mpd/network.py +++ b/src/mopidy_mpd/network.py @@ -19,9 +19,9 @@ from collections.abc import Generator from types import TracebackType - from mopidy.config import Config from mopidy.core import CoreProxy + from mopidy_mpd import types from mopidy_mpd.session import MpdSession, MpdSessionKwargs from mopidy_mpd.types import SocketAddress from mopidy_mpd.uri_mapper import MpdUriMapper @@ -126,7 +126,7 @@ class Server: def __init__( # noqa: PLR0913 self, *, - config: Config, + config: types.Config, core: CoreProxy, uri_map: MpdUriMapper, protocol: type[MpdSession], @@ -253,7 +253,7 @@ class Connection: def __init__( # noqa: PLR0913 self, *, - config: Config, + config: types.Config, core: CoreProxy, uri_map: MpdUriMapper, protocol: type[MpdSession], diff --git a/src/mopidy_mpd/protocol/connection.py b/src/mopidy_mpd/protocol/connection.py index 6d686cb..b5e8813 100644 --- a/src/mopidy_mpd/protocol/connection.py +++ b/src/mopidy_mpd/protocol/connection.py @@ -43,7 +43,7 @@ def password(context: MpdContext, password: str) -> None: This is used for authentication with the server. ``PASSWORD`` is simply the plaintext password. """ - if password == context.password: + if password == context.config["mpd"]["password"]: context.dispatcher.authenticated = True else: raise exceptions.MpdPasswordError("incorrect password") diff --git a/src/mopidy_mpd/protocol/stored_playlists.py b/src/mopidy_mpd/protocol/stored_playlists.py index 0cb3e9b..9b4cc45 100644 --- a/src/mopidy_mpd/protocol/stored_playlists.py +++ b/src/mopidy_mpd/protocol/stored_playlists.py @@ -234,7 +234,7 @@ def _create_playlist(context: MpdContext, name: str, tracks: Iterable[Track]) -> return # Created and saved continue # Failed to save using this backend # Can't use backend appropriate for passed URI schemes, use default one - default_scheme = context.dispatcher.config["mpd"]["default_playlist_scheme"] + default_scheme = context.config["mpd"]["default_playlist_scheme"] new_playlist = context.core.playlists.create(name, default_scheme).get() if new_playlist is None: # If even MPD's default backend can't save playlist, everything is lost diff --git a/src/mopidy_mpd/session.py b/src/mopidy_mpd/session.py index 63e3424..05dc88c 100644 --- a/src/mopidy_mpd/session.py +++ b/src/mopidy_mpd/session.py @@ -3,11 +3,10 @@ import logging from typing import TYPE_CHECKING, NoReturn, TypedDict -from mopidy_mpd import dispatcher, formatting, network, protocol +from mopidy_mpd import dispatcher, formatting, network, protocol, types from mopidy_mpd.protocol import tagtype_list if TYPE_CHECKING: - from mopidy.config import Config from mopidy.core import CoreProxy from mopidy_mpd.uri_mapper import MpdUriMapper @@ -17,7 +16,7 @@ class MpdSessionKwargs(TypedDict): - config: Config + config: types.Config core: CoreProxy uri_map: MpdUriMapper connection: network.Connection @@ -35,7 +34,7 @@ class MpdSession(network.LineProtocol): def __init__( self, *, - config: Config, + config: types.Config, core: CoreProxy, uri_map: MpdUriMapper, connection: network.Connection, diff --git a/src/mopidy_mpd/types.py b/src/mopidy_mpd/types.py index 6929b47..95c969f 100644 --- a/src/mopidy_mpd/types.py +++ b/src/mopidy_mpd/types.py @@ -1,6 +1,15 @@ from __future__ import annotations -from typing import TypeAlias, TypedDict +from typing import TYPE_CHECKING, TypeAlias, TypedDict + +from mopidy.config import Config as MopidyConfig + +if TYPE_CHECKING: + from mopidy.types import UriScheme + + +class Config(MopidyConfig): + mpd: MpdConfig class MpdConfig(TypedDict): @@ -11,7 +20,7 @@ class MpdConfig(TypedDict): connection_timeout: int zeroconf: str command_blacklist: list[str] - default_playlist_scheme: str + default_playlist_scheme: UriScheme SocketAddress: TypeAlias = tuple[str, int | None] From 7652add60246d56bf280541bc05f053094e07dff Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Mon, 26 Feb 2024 00:17:00 +0100 Subject: [PATCH 14/19] Refactor command handler registration ...to use a Handler class instead of nested functions. --- src/mopidy_mpd/protocol/__init__.py | 129 ++++++++++++++++------------ 1 file changed, 74 insertions(+), 55 deletions(-) diff --git a/src/mopidy_mpd/protocol/__init__.py b/src/mopidy_mpd/protocol/__init__.py index e4d0a92..05842f0 100644 --- a/src/mopidy_mpd/protocol/__init__.py +++ b/src/mopidy_mpd/protocol/__init__.py @@ -12,7 +12,6 @@ from __future__ import annotations -import functools import inspect from collections.abc import Callable from typing import TYPE_CHECKING, Any, TypeAlias @@ -37,7 +36,7 @@ ResultTuple: TypeAlias = tuple[str, ResultValue] ResultList: TypeAlias = list[ResultTuple | ResultDict] Result: TypeAlias = None | ResultDict | ResultTuple | ResultList -Handler: TypeAlias = Callable[..., Result] +HandlerFunc: TypeAlias = Callable[..., Result] def load_protocol_modules() -> None: @@ -125,7 +124,6 @@ def RANGE(value: str) -> slice: # noqa: N802 class Commands: - """Collection of MPD commands to expose to users. Normally used through the global instance which command handlers have been @@ -133,18 +131,18 @@ class Commands: """ def __init__(self) -> None: - self.handlers = {} + self.handlers: dict[str, Handler] = {} # TODO: consider removing auth_required and list_command in favour of # additional command instances to register in? - def add( # noqa: C901 + def add( self, name: str, *, auth_required: bool = True, list_command: bool = True, **validators: Callable[[str], Any], - ) -> Callable[[Handler], Handler]: + ) -> Callable[[HandlerFunc], HandlerFunc]: """Create a decorator that registers a handler and validation rules. Additional keyword arguments are treated as converters/validators to @@ -165,58 +163,16 @@ def add( # noqa: C901 :param list_command: If command should be listed in reflection. """ - def wrapper(func: Handler) -> Handler: # noqa: C901 + def wrapper(func: HandlerFunc) -> HandlerFunc: if name in self.handlers: raise ValueError(f"{name} already registered") - - spec = inspect.getfullargspec(func) - defaults = dict( - zip( - spec.args[-len(spec.defaults or []) :], - spec.defaults or [], - strict=False, - ) + self.handlers[name] = Handler( + name=name, + func=func, + auth_required=auth_required, + list_command=list_command, + validators=validators, ) - - if not spec.args and not spec.varargs: - raise TypeError("Handler must accept at least one argument.") - - if len(spec.args) > 1 and spec.varargs: - raise TypeError("*args may not be combined with regular arguments") - - if not set(validators.keys()).issubset(spec.args): - raise TypeError("Validator for non-existent arg passed") - - if spec.varkw or spec.kwonlyargs: - raise TypeError("Keyword arguments are not permitted") - - @functools.wraps(func) - def validate(*args: Any, **kwargs: Any) -> Result: - if spec.varargs: - return func(*args, **kwargs) - - try: - ba = inspect.signature(func).bind(*args, **kwargs) - ba.apply_defaults() - callargs = ba.arguments - except TypeError as exc: - raise exceptions.MpdArgError( - f'wrong number of arguments for "{name}"' - ) from exc - - for key, value in callargs.items(): - default = defaults.get(key, object()) - if key in validators and value != default: - try: - callargs[key] = validators[key](value) - except ValueError as exc: - raise exceptions.MpdArgError("incorrect arguments") from exc - - return func(**callargs) - - validate.auth_required = auth_required - validate.list_command = list_command - self.handlers[name] = validate return func return wrapper @@ -245,3 +201,66 @@ def call( #: Global instance to install commands into commands = Commands() + + +class Handler: + def __init__( # noqa: PLR0913 + self, + *, + name: str, + func: HandlerFunc, + auth_required: bool, + list_command: bool, + validators: dict[str, Callable[[str], Any]], + ) -> None: + self.name = name + self.func = func + self.auth_required = auth_required + self.list_command = list_command + self.validators = validators + + self.spec = inspect.getfullargspec(func) + + if not self.spec.args and not self.spec.varargs: + raise TypeError("Handler must accept at least one argument.") + + if len(self.spec.args) > 1 and self.spec.varargs: + raise TypeError("*args may not be combined with regular arguments") + + if not set(self.validators.keys()).issubset(self.spec.args): + raise TypeError("Validator for non-existent arg passed") + + if self.spec.varkw or self.spec.kwonlyargs: + raise TypeError("Keyword arguments are not permitted") + + self.defaults = dict( + zip( + self.spec.args[-len(self.spec.defaults or []) :], + self.spec.defaults or [], + strict=False, + ) + ) + + def __call__(self, *args: Any, **kwargs: Any) -> Result: + if self.spec.varargs: + return self.func(*args, **kwargs) + + try: + ba = inspect.signature(self.func).bind(*args, **kwargs) + ba.apply_defaults() + callargs = ba.arguments + except TypeError as exc: + raise exceptions.MpdArgError( + f'wrong number of arguments for "{self.name}"' + ) from exc + + for key, value in callargs.items(): + if value == self.defaults.get(key, object()): + continue + if validator := self.validators.get(key): + try: + callargs[key] = validator(value) + except ValueError as exc: + raise exceptions.MpdArgError("incorrect arguments") from exc + + return self.func(**callargs) From 3198cfb707545dfd6eddb0556775db9fccdf08b6 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Mon, 26 Feb 2024 00:19:53 +0100 Subject: [PATCH 15/19] Work around too flexible typing of collections in models --- src/mopidy_mpd/protocol/stored_playlists.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/mopidy_mpd/protocol/stored_playlists.py b/src/mopidy_mpd/protocol/stored_playlists.py index 9b4cc45..8bf43c3 100644 --- a/src/mopidy_mpd/protocol/stored_playlists.py +++ b/src/mopidy_mpd/protocol/stored_playlists.py @@ -3,9 +3,10 @@ import datetime import logging import re -from typing import TYPE_CHECKING, Literal, overload +from typing import TYPE_CHECKING, Literal, cast, overload from urllib.parse import urlparse +from mopidy.models import Track from mopidy.types import Uri, UriScheme from mopidy_mpd import exceptions, protocol, translator @@ -13,7 +14,7 @@ if TYPE_CHECKING: from collections.abc import Iterable - from mopidy.models import Playlist, Track + from mopidy.models import Playlist from mopidy_mpd.context import MpdContext @@ -178,7 +179,11 @@ def load( in either or both ends. """ playlist = _get_playlist(context, name, must_exist=True) - track_uris = [track.uri for track in playlist.tracks[playlist_slice]] + tracks = cast( # TODO(type): Improve typing of models to avoid cast. + tuple[Track], + playlist.tracks[playlist_slice], # pyright: ignore[reportIndexIssue] + ) + track_uris = [track.uri for track in tracks] context.core.tracklist.add(uris=track_uris).get() From c3be0209989164ea494d23ea63a680e051af687d Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Mon, 26 Feb 2024 00:28:43 +0100 Subject: [PATCH 16/19] Work around lost type information --- src/mopidy_mpd/protocol/music_db.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/mopidy_mpd/protocol/music_db.py b/src/mopidy_mpd/protocol/music_db.py index 1e0422e..4831d0d 100644 --- a/src/mopidy_mpd/protocol/music_db.py +++ b/src/mopidy_mpd/protocol/music_db.py @@ -414,7 +414,15 @@ def lsinfo(context: MpdContext, uri: str | None = None) -> protocol.Result: ) if uri in (None, "", "/"): - result.extend(stored_playlists.listplaylists(context)) + result.extend( + # We know that `listplaylists`` returns this specific variant of + # `protocol.Result``, but this information disappears because of the + # typing of the `protocol.commands.add()`` decorator. + cast( + protocol.ResultList, + stored_playlists.listplaylists(context), + ) + ) return result From 9d76306126b382f951ff47efd1414cf963d3961d Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Mon, 26 Feb 2024 00:31:00 +0100 Subject: [PATCH 17/19] Increase pyright's type checking mode to standard --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2002f57..67eddf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ pythonVersion = "3.11" # Use venv from parent directory, to share it with any extensions: venvPath = "../" venv = ".venv" -typeCheckingMode = "basic" +typeCheckingMode = "standard" # Already covered by flake8-self: reportPrivateImportUsage = false From 2d5e88180953e55c49aaa0a3aaac6f3a938702cb Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Mon, 26 Feb 2024 01:08:37 +0100 Subject: [PATCH 18/19] Add pygobject-stubs for a bit better type coverage --- pyproject.toml | 2 +- src/mopidy_mpd/network.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 67eddf8..758c06c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ [project.optional-dependencies] lint = ["ruff"] test = ["pytest", "pytest-cov"] -typing = ["pyright"] +typing = ["pygobject-stubs", "pyright"] dev = ["mopidy-mpd[lint,test,typing]", "tox"] [project.urls] diff --git a/src/mopidy_mpd/network.py b/src/mopidy_mpd/network.py index f16465f..0debdca 100644 --- a/src/mopidy_mpd/network.py +++ b/src/mopidy_mpd/network.py @@ -181,7 +181,7 @@ def stop(self) -> None: if unix_socket_path is not None: os.unlink(unix_socket_path) # noqa: PTH108 - def register_server_socket(self, fileno: int) -> Any: + def register_server_socket(self, fileno: int) -> int: return GLib.io_add_watch(fileno, GLib.IO_IN, self.handle_connection) def handle_connection(self, _fd: int, _flags: int) -> bool: @@ -274,9 +274,9 @@ def __init__( # noqa: PLR0913 self.stopping = False - self.recv_id = None - self.send_id = None - self.timeout_id = None + self.recv_id: int | None = None + self.send_id: int | None = None + self.timeout_id: int | None = None protocol_kwargs: MpdSessionKwargs = { "config": config, From d524bbbae4e21f86d370b6b849473fe08c4e668a Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Thu, 29 Feb 2024 23:04:25 +0100 Subject: [PATCH 19/19] Review feedback --- src/mopidy_mpd/dispatcher.py | 8 ++++---- src/mopidy_mpd/network.py | 2 +- src/mopidy_mpd/protocol/current_playlist.py | 2 +- src/mopidy_mpd/protocol/music_db.py | 15 ++++++++------- src/mopidy_mpd/protocol/playback.py | 5 +++-- src/mopidy_mpd/protocol/status.py | 4 ++-- src/mopidy_mpd/translator.py | 10 +++++++--- tests/network/test_connection.py | 8 -------- 8 files changed, 26 insertions(+), 28 deletions(-) diff --git a/src/mopidy_mpd/dispatcher.py b/src/mopidy_mpd/dispatcher.py index 72a7cf2..795e5a3 100644 --- a/src/mopidy_mpd/dispatcher.py +++ b/src/mopidy_mpd/dispatcher.py @@ -106,7 +106,7 @@ def handle_idle(self, subsystem: str) -> None: self.session.send_lines(response) def _call_next_filter( - self, request: str, response: Response, filter_chain: list[Filter] + self, request: Request, response: Response, filter_chain: list[Filter] ) -> Response: if filter_chain: next_filter = filter_chain.pop(0) @@ -175,10 +175,10 @@ def _command_list_filter( response = Response(response[:-1]) return response - def _is_receiving_command_list(self, request: str) -> bool: + def _is_receiving_command_list(self, request: Request) -> bool: return self.command_list_receiving and request != "command_list_end" - def _is_processing_command_list(self, request: str) -> bool: + def _is_processing_command_list(self, request: Request) -> bool: return self.command_list_index is not None and request != "command_list_end" # --- Filter: idle @@ -243,7 +243,7 @@ def _call_handler_filter( logger.warning("Tried to communicate with dead actor.") raise exceptions.MpdSystemError(str(exc)) from exc - def _call_handler(self, request: str) -> protocol.Result: + def _call_handler(self, request: Request) -> protocol.Result: tokens = tokenize.split(request) # TODO: check that blacklist items are valid commands? blacklist = self.config["mpd"]["command_blacklist"] diff --git a/src/mopidy_mpd/network.py b/src/mopidy_mpd/network.py index 0debdca..d69041c 100644 --- a/src/mopidy_mpd/network.py +++ b/src/mopidy_mpd/network.py @@ -329,7 +329,7 @@ def send(self, data: bytes) -> bytes: def enable_timeout(self) -> None: """Reactivate timeout mechanism.""" - if self.timeout is None or self.timeout <= 0: + if self.timeout <= 0: return self.disable_timeout() diff --git a/src/mopidy_mpd/protocol/current_playlist.py b/src/mopidy_mpd/protocol/current_playlist.py index 3f4dac9..76c3558 100644 --- a/src/mopidy_mpd/protocol/current_playlist.py +++ b/src/mopidy_mpd/protocol/current_playlist.py @@ -341,7 +341,7 @@ def plchangesposid(context: MpdContext, version: int) -> protocol.Result: @protocol.commands.add("prio", priority=protocol.UINT, position=protocol.RANGE) -def prio(context: MpdContext, priority: int, position: int) -> protocol.Result: +def prio(context: MpdContext, priority: int, position: slice) -> protocol.Result: """ *musicpd.org, current playlist section:* diff --git a/src/mopidy_mpd/protocol/music_db.py b/src/mopidy_mpd/protocol/music_db.py index 4831d0d..0210da3 100644 --- a/src/mopidy_mpd/protocol/music_db.py +++ b/src/mopidy_mpd/protocol/music_db.py @@ -1,7 +1,7 @@ from __future__ import annotations import itertools -from typing import TYPE_CHECKING, Never, cast +from typing import TYPE_CHECKING, cast from mopidy.models import Album, Artist, SearchResult, Track from mopidy.types import DistinctField, Query, SearchField, Uri @@ -360,7 +360,7 @@ def listallinfo(context: MpdContext, uri: str | None = None) -> protocol.Result: @protocol.commands.add("listfiles") -def listfiles(context: MpdContext, uri: str | None = None) -> Never: +def listfiles(context: MpdContext, uri: str | None = None) -> protocol.Result: """ *musicpd.org, music database section:* @@ -521,6 +521,11 @@ def searchaddpl(context: MpdContext, *args: str) -> None: raise exceptions.MpdArgError("incorrect arguments") playlist_name = parameters.pop(0) + try: + query = _query_for_search(parameters) + except ValueError: + return + uri = context.uri_map.playlist_uri_from_name(playlist_name) if uri: playlist = context.core.playlists.lookup(uri).get() @@ -529,10 +534,6 @@ def searchaddpl(context: MpdContext, *args: str) -> None: if not playlist: return # TODO: Raise error about failed playlist creation? - try: - query = _query_for_search(parameters) - except ValueError: - return results = context.core.library.search(query).get() tracks = list(playlist.tracks) + _get_tracks(results) playlist = playlist.replace(tracks=tracks) @@ -561,7 +562,7 @@ def update(context: MpdContext, uri: Uri | None = None) -> protocol.Result: # TODO: add at least reflection tests before adding NotImplemented version # @protocol.commands.add('readcomments') -def readcomments(context: MpdContext, uri: Uri | None = None) -> None: +def readcomments(context: MpdContext, uri: Uri) -> None: """ *musicpd.org, music database section:* diff --git a/src/mopidy_mpd/protocol/playback.py b/src/mopidy_mpd/protocol/playback.py index c7cacca..06bb71d 100644 --- a/src/mopidy_mpd/protocol/playback.py +++ b/src/mopidy_mpd/protocol/playback.py @@ -232,12 +232,13 @@ def playid(context: MpdContext, tlid: int) -> None: replacement, starts playback at the first track. """ if tlid == -1: - return _play_minus_one(context) + _play_minus_one(context) + return tl_tracks = context.core.tracklist.filter({"tlid": [tlid]}).get() if not tl_tracks: raise exceptions.MpdNoExistError("No such song") - return context.core.playback.play(tlid=tl_tracks[0].tlid).get() + context.core.playback.play(tlid=tl_tracks[0].tlid).get() @protocol.commands.add("previous") diff --git a/src/mopidy_mpd/protocol/status.py b/src/mopidy_mpd/protocol/status.py index 1fd061d..7d059b6 100644 --- a/src/mopidy_mpd/protocol/status.py +++ b/src/mopidy_mpd/protocol/status.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Never +from typing import TYPE_CHECKING from mopidy.core import PlaybackState @@ -27,7 +27,7 @@ @protocol.commands.add("clearerror") -def clearerror(context: MpdContext) -> Never: +def clearerror(context: MpdContext) -> protocol.Result: """ *musicpd.org, status section:* diff --git a/src/mopidy_mpd/translator.py b/src/mopidy_mpd/translator.py index 54ea6bd..4457d37 100644 --- a/src/mopidy_mpd/translator.py +++ b/src/mopidy_mpd/translator.py @@ -26,7 +26,8 @@ def track_to_mpd_format( # noqa: C901, PLR0912 """ Format track for output to MPD client. - :param track: the track + :param obj: the track + :param tagtypes: the MPD tagtypes enabled by the client :param position: track's position in playlist :param stream_title: the current streams title """ @@ -120,7 +121,7 @@ def _has_value( Determine whether to add the tagtype to the output or not. The tagtype must be in the list of tagtypes configured for the client. - :param tagtypes: the MPD tagtypes configured for the client + :param tagtypes: the MPD tagtypes enabled by the client :param tagtype: the MPD tagtype :param value: the tag value """ @@ -183,8 +184,10 @@ def tracks_to_mpd_format( Optionally limit output to the slice ``[start:end]`` of the list. :param tracks: the tracks + :param tagtypes: the MPD tagtypes enabled by the client :param start: position of first track to include in output - :param end: position after last track to include in output + :param end: position after last track to include in output, or ``None`` for + end of list """ if end is None: end = len(tracks) @@ -210,6 +213,7 @@ def playlist_to_mpd_format( Format playlist for output to MPD client. :param playlist: the playlist + :param tagtypes: the MPD tagtypes enabled by the client :param start: position of first track to include in output :param end: position after last track to include in output """ diff --git a/tests/network/test_connection.py b/tests/network/test_connection.py index 7fc7de5..426a7fd 100644 --- a/tests/network/test_connection.py +++ b/tests/network/test_connection.py @@ -326,10 +326,6 @@ def test_enable_timeout_does_not_add_timeout(self): network.Connection.enable_timeout(self.mock) assert GLib.timeout_add_seconds.call_count == 0 - self.mock.timeout = None - network.Connection.enable_timeout(self.mock) - assert GLib.timeout_add_seconds.call_count == 0 - def test_enable_timeout_does_not_call_disable_for_invalid_timeout(self): self.mock.timeout = 0 network.Connection.enable_timeout(self.mock) @@ -339,10 +335,6 @@ def test_enable_timeout_does_not_call_disable_for_invalid_timeout(self): network.Connection.enable_timeout(self.mock) assert self.mock.disable_timeout.call_count == 0 - self.mock.timeout = None - network.Connection.enable_timeout(self.mock) - assert self.mock.disable_timeout.call_count == 0 - @patch.object(GLib, "source_remove", new=Mock()) def test_disable_timeout_deregisters(self): self.mock.timeout_id = sentinel.tag