diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index a4548f6529..61bfe1551a 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -119,7 +119,6 @@ jobs: DESTINATION__FILESYSTEM__CREDENTIALS__SFTP_USERNAME: foo DESTINATION__FILESYSTEM__CREDENTIALS__SFTP_PASSWORD: pass - - name: Stop weaviate if: always() run: docker compose -f "tests/load/weaviate/docker-compose.yml" down -v diff --git a/dlt/cli/init_command.py b/dlt/cli/init_command.py index a4a123359a..ac8adcc588 100644 --- a/dlt/cli/init_command.py +++ b/dlt/cli/init_command.py @@ -382,6 +382,16 @@ def init_command( source_configuration = files_ops.get_core_source_configuration( core_sources_storage, source_name ) + from importlib.metadata import Distribution + + dist = Distribution.from_name(DLT_PKG_NAME) + extras = dist.metadata.get_all("Provides-Extra") or [] + + # Match the extra name to the source name + canonical_source_name = source_name.replace("_", "-").lower() + + if canonical_source_name in extras: + source_configuration.requirements.update_dlt_extras(canonical_source_name) else: if not is_valid_schema_name(source_name): raise InvalidSchemaName(source_name) diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index d4ee1844d7..8e1affd164 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -4,7 +4,7 @@ import os import contextlib from subprocess import CalledProcessError -from typing import Any, List, Tuple, Optional +from typing import List, Tuple, Optional from hexbytes import HexBytes import pytest from unittest import mock @@ -55,7 +55,12 @@ # we hardcode the core sources here so we can check that the init script picks # up the right source -CORE_SOURCES = ["filesystem", "rest_api", "sql_database"] +CORE_SOURCES_CONFIG = { + "rest_api": {"requires_extra": False}, + "sql_database": {"requires_extra": True}, + "filesystem": {"requires_extra": True}, +} +CORE_SOURCES = list(CORE_SOURCES_CONFIG.keys()) # we also hardcode all the templates here for testing TEMPLATES = ["debug", "default", "arrow", "requests", "dataframe", "fruitshop", "github_api"] @@ -167,6 +172,37 @@ def test_init_list_sources(repo_dir: str) -> None: assert source in _out +@pytest.mark.parametrize( + "source_name", + [name for name in CORE_SOURCES_CONFIG if CORE_SOURCES_CONFIG[name]["requires_extra"]], +) +def test_init_command_core_source_requirements_with_extras( + source_name: str, repo_dir: str, project_files: FileStorage +) -> None: + init_command.init_command(source_name, "duckdb", repo_dir) + source_requirements = SourceRequirements.from_string( + project_files.load(cli_utils.REQUIREMENTS_TXT) + ) + canonical_name = source_name.replace("_", "-") + assert canonical_name in source_requirements.dlt_requirement.extras + + +@pytest.mark.parametrize( + "source_name", + [name for name in CORE_SOURCES_CONFIG if not CORE_SOURCES_CONFIG[name]["requires_extra"]], +) +def test_init_command_core_source_requirements_without_extras( + source_name: str, repo_dir: str, project_files: FileStorage +) -> None: + init_command.init_command(source_name, "duckdb", repo_dir) + source_requirements = SourceRequirements.from_string( + project_files.load(cli_utils.REQUIREMENTS_TXT) + ) + assert source_requirements.dlt_requirement.extras == { + "duckdb" + }, "Only duckdb should be in extras" + + def test_init_list_sources_update_warning(repo_dir: str, project_files: FileStorage) -> None: """Sources listed include a warning if a different dlt version is required""" with mock.patch.object(SourceRequirements, "current_dlt_version", return_value="0.0.1"): @@ -571,7 +607,7 @@ def assert_requirements_txt(project_files: FileStorage, destination_name: str) - project_files.load(cli_utils.REQUIREMENTS_TXT) ) assert destination_name in source_requirements.dlt_requirement.extras - # Check that atleast some version range is specified + # Check that at least some version range is specified assert len(source_requirements.dlt_requirement.specifier) >= 1