diff --git a/conda_libmamba_solver/index.py b/conda_libmamba_solver/index.py
index edcc6824..f5e64e78 100644
--- a/conda_libmamba_solver/index.py
+++ b/conda_libmamba_solver/index.py
@@ -70,6 +70,8 @@
 We maintain a map of subdir-specific URLs to `conda.model.channel.Channel` and
 `libmamba.Repo` objects.
 """
+from __future__ import annotations
+
 import logging
 import os
 from dataclasses import dataclass
@@ -109,7 +111,7 @@ class LibMambaIndexHelper(IndexHelper):
     def __init__(
         self,
         installed_records: Iterable[PackageRecord] = (),
-        channels: Iterable[Union[Channel, str]] = None,
+        channels: Iterable[Channel | str] = None,
         subdirs: Iterable[str] = None,
         repodata_fn: str = REPODATA_FN,
         query_format=api.QueryFormat.JSON,
@@ -217,7 +219,7 @@ def _repo_from_records(
         finally:
             os.unlink(f.name)
 
-    def _fetch_channel(self, url: str) -> Tuple[str, os.PathLike]:
+    def _fetch_channel(self, url: str) -> tuple[str, Path, Path | None]:
         channel = Channel.from_url(url)
         if not channel.subdir:
             raise ValueError(f"Channel URLs must specify a subdir! Provided: {url}")
@@ -230,11 +232,19 @@ def _fetch_channel(self, url: str) -> Tuple[str, os.PathLike]:
             del SubdirData._cache_[(url, self._repodata_fn)]
             # /Workaround
 
-        log.debug("Fetching %s with SubdirData.repo_fetch", channel)
-        subdir_data = SubdirData(channel, repodata_fn=self._repodata_fn)
-        json_path, _ = subdir_data.repo_fetch.fetch_latest_path()
+        # repo_fetch is created on each property access
+        repo_fetch = SubdirData(channel, repodata_fn=self._repodata_fn).repo_fetch
+        overlay_path = None
+        if hasattr(repo_fetch, "fetch_latest_path_and_overlay"):
+            log.debug(
+                "Fetching %s with SubdirData.repo_fetch.fetch_latest_path_and_overlay", channel
+            )
+            json_path, overlay_path, _ = repo_fetch.fetch_latest_path_and_overlay()
+        else:
+            log.debug("Fetching %s with SubdirData.repo_fetch", channel)
+            json_path, _ = repo_fetch.fetch_latest_path()
 
-        return url, json_path
+        return url, json_path, overlay_path
 
     def _json_path_to_repo_info(
-        self, url: str, json_path: str, try_solv: bool = False
+        self, url: str, json_path: str, overlay_path: str | None = None, try_solv: bool = False
@@ -271,7 +281,19 @@ def _json_path_to_repo_info(
         else:
             path_to_use = json_path
 
-        repo = api.Repo(self._pool, noauth_url, str(path_to_use), escape_channel_url(noauth_url))
+        if overlay_path:
+            # from https://github.com/mamba-org/mamba/pull/2969
+            repo = api.Repo(
+                self._pool,
+                noauth_url,
+                str(path_to_use),
+                str(overlay_path),
+                escape_channel_url(noauth_url),
+            )
+        else:
+            repo = api.Repo(
+                self._pool, noauth_url, str(path_to_use), escape_channel_url(noauth_url)
+            )
         return _ChannelRepoInfo(
             repo=repo,
             channel=channel,
@@ -279,7 +301,7 @@ def _json_path_to_repo_info(
             noauth_url=noauth_url,
         )
 
-    def _load_channels(self) -> Dict[str, _ChannelRepoInfo]:
+    def _load_channels(self) -> dict[str, _ChannelRepoInfo]:
         # 1. Obtain and deduplicate URLs from channels
         urls = []
         seen_noauth = set()
@@ -310,12 +332,15 @@ def _load_channels(self) -> Dict[str, _ChannelRepoInfo]:
             else partial(ThreadLimitedThreadPoolExecutor, max_workers=context.repodata_threads)
         )
         with Executor() as executor:
-            jsons = {url: str(path) for (url, path) in executor.map(self._fetch_channel, urls)}
+            jsons = {
+                url: (path, overlay)
+                for (url, path, overlay) in executor.map(self._fetch_channel, urls)
+            }
 
         # 3. Create repos in same order as `urls`
         index = {}
         for url in urls:
-            info = self._json_path_to_repo_info(url, jsons[url])
+            info = self._json_path_to_repo_info(url, *jsons[url])
             if info is not None:
                 index[info.noauth_url] = info
 
@@ -330,24 +355,22 @@ def _load_installed(self, records: Iterable[PackageRecord]) -> api.Repo:
         return repo
 
     def whoneeds(
-        self, query: Union[str, MatchSpec], records=True
-    ) -> Union[Iterable[PackageRecord], dict, str]:
+        self, query: str | MatchSpec, records=True
+    ) -> Iterable[PackageRecord] | dict | str:
         result_str = self._query.whoneeds(self._prepare_query(query), self._format)
         if self._format == api.QueryFormat.JSON:
            return self._process_query_result(result_str, records=records)
         return result_str
 
     def depends(
-        self, query: Union[str, MatchSpec], records=True
-    ) -> Union[Iterable[PackageRecord], dict, str]:
+        self, query: str | MatchSpec, records=True
+    ) -> Iterable[PackageRecord] | dict | str:
         result_str = self._query.depends(self._prepare_query(query), self._format)
         if self._format == api.QueryFormat.JSON:
             return self._process_query_result(result_str, records=records)
         return result_str
 
-    def search(
-        self, query: Union[str, MatchSpec], records=True
-    ) -> Union[Iterable[PackageRecord], dict, str]:
+    def search(self, query: str | MatchSpec, records=True) -> Iterable[PackageRecord] | dict | str:
         result_str = self._query.find(self._prepare_query(query), self._format)
         if self._format == api.QueryFormat.JSON:
             return self._process_query_result(result_str, records=records)
@@ -364,7 +387,7 @@ def explicit_pool(self, specs: Iterable[MatchSpec]) -> Iterable[str]:
             explicit_pool.add(record.name)
         return tuple(explicit_pool)
 
-    def _prepare_query(self, query: Union[str, MatchSpec]) -> str:
+    def _prepare_query(self, query: str | MatchSpec) -> str:
         if isinstance(query, str):
             if "[" not in query:
                 return query
@@ -391,7 +414,7 @@ def _process_query_result(
         self,
         result_str,
         records=True,
-    ) -> Union[Iterable[PackageRecord], dict]:
+    ) -> Iterable[PackageRecord] | dict:
         result = json_load(result_str)
         if result.get("result", {}).get("status") != "OK":
             query_type = result.get("query", {}).get("type", "")
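
Note on the pattern used in `_fetch_channel` above: the patch feature-detects the newer repodata interface with `hasattr`, so the helper keeps working against conda versions whose `repo_fetch` object only exposes `fetch_latest_path()`. A minimal, self-contained sketch of that fallback is shown below; the `DummyFetch*` classes are illustrative stand-ins, not conda's real objects, and only the tuple shapes mirror the calls made in the patch.

from pathlib import Path


class DummyFetchOld:
    # Stand-in for an older repo_fetch: no overlay support.
    def fetch_latest_path(self):
        return Path("/tmp/noarch/repodata.json"), None


class DummyFetchNew(DummyFetchOld):
    # Stand-in for a newer repo_fetch that also returns an overlay path.
    def fetch_latest_path_and_overlay(self):
        return Path("/tmp/noarch/repodata.json"), Path("/tmp/noarch/overlay.json"), None


def fetch_paths(repo_fetch):
    # Prefer the overlay-aware call when available, otherwise fall back.
    overlay_path = None
    if hasattr(repo_fetch, "fetch_latest_path_and_overlay"):
        json_path, overlay_path, _ = repo_fetch.fetch_latest_path_and_overlay()
    else:
        json_path, _ = repo_fetch.fetch_latest_path()
    return json_path, overlay_path


print(fetch_paths(DummyFetchOld()))  # -> (json_path, None)
print(fetch_paths(DummyFetchNew()))  # -> (json_path, overlay_path)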