From 14b8cf867cb2416f6b7e4a392c120491c8158e13 Mon Sep 17 00:00:00 2001 From: Christopher Barber Date: Wed, 13 Sep 2023 09:19:15 -0400 Subject: [PATCH] Improve Requires-Dist handling code --- src/whl2conda/__init__.py | 2 + src/whl2conda/api/converter.py | 115 ++++++++++++++++++++++----------- test/api/test_converter.py | 61 +++++++++++++++++ 3 files changed, 142 insertions(+), 36 deletions(-) diff --git a/src/whl2conda/__init__.py b/src/whl2conda/__init__.py index 1c045f8..34f8d0a 100644 --- a/src/whl2conda/__init__.py +++ b/src/whl2conda/__init__.py @@ -16,3 +16,5 @@ """ from .__about__ import __version__ + +__all__ = ["__version__"] diff --git a/src/whl2conda/api/converter.py b/src/whl2conda/api/converter.py index cb02925..0d5c786 100644 --- a/src/whl2conda/api/converter.py +++ b/src/whl2conda/api/converter.py @@ -20,7 +20,9 @@ # standard import configparser +import dataclasses import email +import io import json import logging import re @@ -51,6 +53,7 @@ def __compile_requires_dist_re() -> re.Pattern: + # NOTE: these are currently fairly forgiving and will accept bad syntax name_re = r"(?P[a-zA-Z0-9_.-]+)" extra_re = r"(?:\[(?P.+?)\])?" version_re = r"(?:\(?(?P.*?)\)?)?" 
@@ -63,6 +66,11 @@ def __compile_requires_dist_re() -> re.Pattern: _requires_dist_re = __compile_requires_dist_re() +_extra_marker_re = [ + re.compile(r"""\bextra\s*==\s*(['"])(?P<name>\w+)\1"""), + re.compile(r"""\b(['"])(?P<name>\w+)\1\s*==\s*extra"""), +] + @dataclass class RequiresDistEntry: @@ -77,6 +85,12 @@ class RequiresDistEntry: version: str = "" marker: str = "" + extra_marker_name: str = "" + """Name from extra expression in marker, if any""" + + generic: bool = True + """True if marker is empty or only contains an extra expression""" + @classmethod def parse(cls, raw: str) -> RequiresDistEntry: """ @@ -95,8 +109,26 @@ def parse(cls, raw: str) -> RequiresDistEntry: entry.version = version if marker := m.group("marker"): entry.marker = marker + entry.generic = False + for pat in _extra_marker_re: + if m := pat.search(marker): + entry.extra_marker_name = m.group("name") + if m.string == marker: + entry.generic = True + break return entry + def __str__(self) -> str: + with io.StringIO() as buf: + buf.write(self.name) + if self.extras: + buf.write(f" [{','.join(self.extras)}]") + if self.version: + buf.write(f" {self.version}") + if self.marker: + buf.write(f" ; {self.marker}") + return buf.getvalue() + class Wheel2CondaError(RuntimeError): """Errors from Wheel2CondaConverter""" @@ -127,7 +159,7 @@ class MetadataFromWheel: version: str wheel_build_number: str license: Optional[str] - dependencies: list[str] + dependencies: list[RequiresDistEntry] wheel_info_dir: Path @@ -366,7 +398,6 @@ def _write_link_file(self, conda_info_dir: Path, wheel_info_dir: Path) -> None: if section_name in wheel_entry_points: if section := wheel_entry_points[section_name]: console_scripts.extend(f"{k}={v}" for k, v in section.items()) - # TODO - check correct setting for gui scripts (#20) conda_link_file.write_text( json.dumps( dict( @@ -473,27 +504,26 @@ def _write_about(self, conda_info_dir: Path, md: dict[str, Any]) -> None: ) # pylint: disable=too-many-locals - def 
_compute_conda_dependencies(self, dependencies: Sequence[str]) -> list[str]: + def _compute_conda_dependencies( + self, + dependencies: Sequence[RequiresDistEntry], + ) -> list[str]: conda_dependencies: list[str] = [] - for dep in dependencies: - try: - entry = RequiresDistEntry.parse(dep) - except SyntaxError as err: - self._warn(str(err)) - continue + # TODO - instead RequiresDistEntrys should be passed as an argument - if marker := entry.marker: - if "extra" in marker: - self._debug("Skipping extra dependency: %s", dep) - else: - # TODO - support inclusion in OS-specific package - self._warn("Skipping dependency with environment marker: %s", dep) + for entry in dependencies: + if entry.extra_marker_name: + self._debug("Skipping extra dependency: %s", entry) + continue + if not entry.generic: + # TODO - support non-generic packages + self._warn("Skipping dependency with environment marker: %s", entry) continue conda_name = pip_name = entry.name version = entry.version - # TODO - do something with extras + # TODO - do something with extras (#36) # download target pip package and its extra dependencies # check manual renames first renamed = False @@ -509,10 +539,10 @@ def _compute_conda_dependencies(self, dependencies: Sequence[str]) -> list[str]: if conda_name == pip_name: self._debug("Dependency copied: '%s'", conda_dep) else: - self._debug("Dependency renamed: '%s' -> '%s'", dep, conda_dep) + self._debug("Dependency renamed: '%s' -> '%s'", entry, conda_dep) conda_dependencies.append(conda_dep) else: - self._debug("Dependency dropped: %s", dep) + self._debug("Dependency dropped: %s", entry) for dep in self.extra_dependencies: self._debug("Dependency added: '%s'", dep) conda_dependencies.append(dep) @@ -545,21 +575,25 @@ def _copy_licenses(self, conda_info_dir: Path, wheel_md: MetadataFromWheel) -> N shutil.copyfile(from_file, to_file) break - # pylint: disable=too-many-locals + # pylint: disable=too-many-locals, too-many-statements def _parse_wheel_metadata(self, 
wheel_dir: Path) -> MetadataFromWheel: wheel_info_dir = next(wheel_dir.glob("*.dist-info")) WHEEL_file = wheel_info_dir.joinpath("WHEEL") WHEEL_msg = email.message_from_string(WHEEL_file.read_text("utf8")) # https://peps.python.org/pep-0427/#what-s-the-deal-with-purelib-vs-platlib + is_pure_lib = WHEEL_msg.get("Root-Is-Purelib", "").lower() == "true" wheel_build_number = WHEEL_msg.get("Build", "") wheel_version = WHEEL_msg.get("Wheel-Version") + if wheel_version not in self.SUPPORTED_WHEEL_VERSIONS: raise Wheel2CondaError( f"Wheel {self.wheel_path} has unsupported wheel version {wheel_version}" ) + if not is_pure_lib: raise Wheel2CondaError(f"Wheel {self.wheel_path} is not pure python") + wheel_md_file = wheel_info_dir.joinpath("METADATA") md: dict[str, list[Any]] = {} # Metdata spec: https://packaging.python.org/en/latest/specifications/core-metadata/ @@ -580,38 +614,47 @@ def _parse_wheel_metadata(self, wheel_dir: Path) -> MetadataFromWheel: md[mdkey.lower()] = mdval if mdkey in {"requires-dist", "requires"}: continue + + requires: list[RequiresDistEntry] = [] + raw_requires_entries = md.get("requires-dist", md.get("requires", ())) + for raw_entry in raw_requires_entries: + try: + entry = RequiresDistEntry.parse(raw_entry) + requires.append(entry) + except SyntaxError as err: + # TODO: error in strict mode? 
+ self._warn(str(err)) + if not self.keep_pip_dependencies: + # Turn requirements into optional extra requirements del md_msg["Requires"] del md_msg["Requires-Dist"] - requires = md_msg.get_all("Requires-Dist", ()) - # Turn requirements into optional extra requirements - if requires: - for require in requires: - if ';' not in require: - md_msg.add_header( - "Requires-Dist", f"{require}; extra == 'original" - ) + for entry in requires: + if not entry.extra_marker_name: + marker = entry.marker + extra_clause = "extra == 'original'" + if marker: + marker = f"({entry.marker}) and {extra_clause}" else: - # FIXME: check for extra vs environment marker - md_msg.add_header("Requires-Dist", require) - md_msg.add_header("Provides-Extra", "original") + marker = extra_clause + entry = dataclasses.replace(entry, marker=marker) + md_msg.add_header("Requires-Dist", str(entry)) + md_msg.add_header("Provides-Extra", "original") wheel_md_file.write_text(md_msg.as_string()) package_name = self.package_name or str(md.get("name")) self.package_name = package_name version = md.get("version") - dependencies: list[str] = [] - python_version = md.get("requires-python") + + python_version: str = str(md.get("requires-python", "")) if python_version: - dependencies.append(f"python {python_version}") - # Use Requires-Dist if present, otherwise deprecated Requires keyword - dependencies.extend(md.get("requires-dist", md.get("requires", []))) + requires.append(RequiresDistEntry("python", version=python_version)) self.wheel_md = MetadataFromWheel( md=md, package_name=package_name, version=str(version), wheel_build_number=wheel_build_number, license=md.get("license-expression") or md.get("license"), # type: ignore - dependencies=dependencies, + dependencies=requires, wheel_info_dir=wheel_info_dir, ) return self.wheel_md diff --git a/test/api/test_converter.py b/test/api/test_converter.py index 7870316..13c8eb5 100644 --- a/test/api/test_converter.py +++ b/test/api/test_converter.py @@ -33,6 +33,7 
@@ Wheel2CondaError, CondaPackageFormat, DependencyRename, + RequiresDistEntry, ) from whl2conda.cli.convert import do_build_wheel from whl2conda.cli.install import install_main @@ -220,6 +221,66 @@ def test_case( # pylint: disable=redefined-outer-name +# +# RequiresDistEntry test cases +# + + +def check_dist_entry(entry: RequiresDistEntry) -> None: + """Check invariants on RequiresDistEntry""" + if not entry.marker: + assert entry.generic + if entry.extra_marker_name: + assert 'extra' in entry.marker + assert entry.extra_marker_name in entry.marker + else: + # technically, there COULD be an extra in another environment + # expression, but it wouldn't make much sense + assert 'extra' not in entry.marker + if entry.marker: + assert not entry.generic + + raw = str(entry) + entry2 = RequiresDistEntry.parse(raw) + assert entry == entry2 + + +def test_requires_dist_entry() -> None: + """Test RequiresDistEntry data structure""" + entry = RequiresDistEntry.parse("foo") + assert entry.name == "foo" + assert not entry.extras + assert not entry.version + assert not entry.marker + check_dist_entry(entry) + + entry2 = RequiresDistEntry.parse("foo >=1.2") + assert entry != entry2 + assert entry2.name == "foo" + assert entry2.version == ">=1.2" + assert not entry2.extras + assert not entry2.marker + check_dist_entry(entry2) + + entry3 = RequiresDistEntry.parse("foo-bar [baz,blah]") + assert entry3.name == "foo-bar" + assert entry3.extras == ("baz", "blah") + assert not entry3.version + assert not entry3.marker + check_dist_entry(entry3) + + entry4 = RequiresDistEntry.parse("frodo ; extra=='LOTR'") + assert entry4.name == "frodo" + assert entry4.extra_marker_name == "LOTR" + assert entry4.marker == "extra=='LOTR'" + assert not entry4.version + assert not entry4.extras + check_dist_entry(entry4) + + with pytest.raises(SyntaxError): + RequiresDistEntry.parse("=123 : bad") + + # # DependencyRename test cases #