Skip to content

Commit

Permalink
setup.py:
Browse files Browse the repository at this point in the history
  - Add --dev flag for dev tool install
  - Group pytorch items together for install
  • Loading branch information
torzdf committed Apr 16, 2024
1 parent de09ec9 commit 61bd910
Showing 1 changed file with 36 additions and 4 deletions.
40 changes: 36 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
"zlib-wapi": ("zlib-wapi", ("conda-forge", )),
"xorg-libxft": ("xorg-libxft", ("conda-forge", ))}

_GROUPS = [["pytorch*", "torch*"]]
"""list[list[str]]: Packages that should be installed collectively at the same time """

_DEV_TOOLS = ["flake8", "mypy", "pylint", "types-setuptools", "types-PyYAML"]

# Force output to utf-8
sys.stdout.reconfigure(encoding="utf-8", errors="replace") # type:ignore[attr-defined]

Expand All @@ -77,6 +82,7 @@ def __init__(self, updater: bool = False) -> None:
self.updater = updater
# Flag that setup is being run by installer so steps can be skipped
self.is_installer: bool = False
self.include_dev_tools: bool = False
self.backend: backend_type | None = None
self.enable_docker: bool = False
self.cuda_cudnn = ["", ""]
Expand Down Expand Up @@ -153,6 +159,10 @@ def _process_arguments(self) -> None:
for arg in args:
if arg == "--installer":
self.is_installer = True
continue
if arg == "--dev":
self.include_dev_tools = True
continue
if not self.backend and (arg.startswith("--") and
arg.replace("--", "") in self._backends):
self.backend = arg.replace("--", "").lower() # type:ignore
Expand Down Expand Up @@ -461,6 +471,9 @@ def get_required_packages(self) -> None:
if package and (not package.startswith(("#", "-r"))):
requirements.append(package)

if self._env.include_dev_tools:
requirements.extend(_DEV_TOOLS)

self._required_packages = self._format_requirements(requirements)
logger.debug(self._required_packages)

Expand Down Expand Up @@ -510,7 +523,7 @@ def check_missing_dependencies(self) -> None:
self._check_conda_missing_dependencies()


class Checks():
class Checks(): # pylint:disable=too-few-public-methods
""" Pre-installation checks
Parameters
Expand Down Expand Up @@ -746,7 +759,7 @@ def _rocm_check(self) -> None:
return


class CudaCheck():
class CudaCheck(): # pylint:disable=too-few-public-methods
""" Find the location of system installed Cuda and cuDNN on Windows and Linux. """

def __init__(self) -> None:
Expand Down Expand Up @@ -898,7 +911,7 @@ def _get_checkfiles_windows(self) -> list[str]:
return cudnn_checkfiles


class Install():
class Install(): # pylint:disable=too-few-public-methods
""" Handles installation of Faceswap requirements
Parameters
Expand Down Expand Up @@ -999,6 +1012,24 @@ def _install_setup_packages(self) -> None:
logger.error("Unable to install package: %s. Process aborted", clean_pkg)
sys.exit(1)

def _install_grouped_packages(self) -> None:
""" Install packages that should be installed collectively as a group """
if not self._env.is_conda:
return

packages = []
channels: set[str] = set()
for group in _GROUPS:
for item in group:
for idx, pkg in reversed(list(enumerate(self._packages.to_install))):
if item == pkg[0] or (item.endswith("*") and pkg[0].startswith(item[:-1])):
i_pkg = self._packages.to_install.pop(idx)
packages.append(self._format_package(*i_pkg))
channels.update(c for c in _CONDA_MAPPING.get(i_pkg[0],
(i_pkg[0],
("defaults", )))[-1])
self._from_conda(packages, tuple(channels), conda_only=True)

def _install_conda_packages(self) -> None:
""" Install required conda packages """
logger.info("Installing Required Conda Packages. This may take some time...")
Expand All @@ -1021,7 +1052,8 @@ def _install_python_packages(self) -> None:

def _install_missing_dep(self) -> None:
""" Install missing dependencies """
self._install_conda_packages() # Install conda packages first
self._install_conda_packages() # Install required conda packages first
self._install_grouped_packages() # Then install grouped packages
self._install_python_packages()

def _from_conda(self,
Expand Down

0 comments on commit 61bd910

Please sign in to comment.