Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support faster JLAP by calling two-file fetch, api.Repo() #379

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 42 additions & 19 deletions conda_libmamba_solver/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
We maintain a map of subdir-specific URLs to `conda.model.channel.Channel`
and `libmamba.Repo` objects.
"""
from __future__ import annotations

import logging
import os
from dataclasses import dataclass
Expand Down Expand Up @@ -109,7 +111,7 @@ class LibMambaIndexHelper(IndexHelper):
def __init__(
self,
installed_records: Iterable[PackageRecord] = (),
channels: Iterable[Union[Channel, str]] = None,
channels: Iterable[Channel | str] = None,
subdirs: Iterable[str] = None,
repodata_fn: str = REPODATA_FN,
query_format=api.QueryFormat.JSON,
Expand Down Expand Up @@ -217,7 +219,7 @@ def _repo_from_records(
finally:
os.unlink(f.name)

def _fetch_channel(self, url: str) -> Tuple[str, os.PathLike]:
def _fetch_channel(self, url: str) -> tuple[str, Path, Path | None]:
channel = Channel.from_url(url)
if not channel.subdir:
raise ValueError(f"Channel URLs must specify a subdir! Provided: {url}")
Expand All @@ -230,11 +232,19 @@ def _fetch_channel(self, url: str) -> Tuple[str, os.PathLike]:
del SubdirData._cache_[(url, self._repodata_fn)]
# /Workaround

log.debug("Fetching %s with SubdirData.repo_fetch", channel)
subdir_data = SubdirData(channel, repodata_fn=self._repodata_fn)
json_path, _ = subdir_data.repo_fetch.fetch_latest_path()
# repo_fetch is created on each property access
repo_fetch = SubdirData(channel, repodata_fn=self._repodata_fn).repo_fetch
overlay_path = None
if hasattr(repo_fetch, "fetch_latest_path_and_overlay"):
log.debug(
"Fetching %s with SubdirData.repo_fetch.fetch_latest_path_and_overlay", channel
)
json_path, overlay_path, _ = repo_fetch.fetch_latest_path_and_overlay()
else:
log.debug("Fetching %s with SubdirData.repo_fetch", channel)
json_path, _ = repo_fetch.fetch_latest_path()

return url, json_path
return url, json_path, overlay_path

def _json_path_to_repo_info(
self, url: str, json_path: str, try_solv: bool = False
Expand Down Expand Up @@ -271,15 +281,27 @@ def _json_path_to_repo_info(
else:
path_to_use = json_path

repo = api.Repo(self._pool, noauth_url, str(path_to_use), escape_channel_url(noauth_url))
if overlay_path:
# from https://github.com/mamba-org/mamba/pull/2969
repo = api.Repo(
self._pool,
noauth_url,
str(path_to_use),
str(overlay_path),
escape_channel_url(noauth_url),
)
else:
repo = api.Repo(
self._pool, noauth_url, str(path_to_use), escape_channel_url(noauth_url)
)
return _ChannelRepoInfo(
repo=repo,
channel=channel,
full_url=url,
noauth_url=noauth_url,
)

def _load_channels(self) -> Dict[str, _ChannelRepoInfo]:
def _load_channels(self) -> dict[str, _ChannelRepoInfo]:
# 1. Obtain and deduplicate URLs from channels
urls = []
seen_noauth = set()
Expand Down Expand Up @@ -310,12 +332,15 @@ def _load_channels(self) -> Dict[str, _ChannelRepoInfo]:
else partial(ThreadLimitedThreadPoolExecutor, max_workers=context.repodata_threads)
)
with Executor() as executor:
jsons = {url: str(path) for (url, path) in executor.map(self._fetch_channel, urls)}
jsons = {
url: (path, overlay)
for (url, path, overlay) in executor.map(self._fetch_channel, urls)
}

# 3. Create repos in same order as `urls`
index = {}
for url in urls:
info = self._json_path_to_repo_info(url, jsons[url])
info = self._json_path_to_repo_info(url, *jsons[url])
if info is not None:
index[info.noauth_url] = info

Expand All @@ -330,24 +355,22 @@ def _load_installed(self, records: Iterable[PackageRecord]) -> api.Repo:
return repo

def whoneeds(
self, query: Union[str, MatchSpec], records=True
) -> Union[Iterable[PackageRecord], dict, str]:
self, query: str | MatchSpec, records=True
) -> Iterable[PackageRecord] | dict | str:
result_str = self._query.whoneeds(self._prepare_query(query), self._format)
if self._format == api.QueryFormat.JSON:
return self._process_query_result(result_str, records=records)
return result_str

def depends(
self, query: Union[str, MatchSpec], records=True
) -> Union[Iterable[PackageRecord], dict, str]:
self, query: str | MatchSpec, records=True
) -> Iterable[PackageRecord] | dict | str:
result_str = self._query.depends(self._prepare_query(query), self._format)
if self._format == api.QueryFormat.JSON:
return self._process_query_result(result_str, records=records)
return result_str

def search(
self, query: Union[str, MatchSpec], records=True
) -> Union[Iterable[PackageRecord], dict, str]:
def search(self, query: str | MatchSpec, records=True) -> Iterable[PackageRecord] | dict | str:
result_str = self._query.find(self._prepare_query(query), self._format)
if self._format == api.QueryFormat.JSON:
return self._process_query_result(result_str, records=records)
Expand All @@ -364,7 +387,7 @@ def explicit_pool(self, specs: Iterable[MatchSpec]) -> Iterable[str]:
explicit_pool.add(record.name)
return tuple(explicit_pool)

def _prepare_query(self, query: Union[str, MatchSpec]) -> str:
def _prepare_query(self, query: str | MatchSpec) -> str:
if isinstance(query, str):
if "[" not in query:
return query
Expand All @@ -391,7 +414,7 @@ def _process_query_result(
self,
result_str,
records=True,
) -> Union[Iterable[PackageRecord], dict]:
) -> Iterable[PackageRecord] | dict:
result = json_load(result_str)
if result.get("result", {}).get("status") != "OK":
query_type = result.get("query", {}).get("type", "<Unknown>")
Expand Down