Skip to content

Commit

Permalink
Add core sources extras to requirements in dlt init (#2028)
Browse files Browse the repository at this point in the history
  • Loading branch information
burnash authored Nov 13, 2024
1 parent 675b309 commit 0ea9de7
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 4 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test_local_destinations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ jobs:
DESTINATION__FILESYSTEM__CREDENTIALS__SFTP_USERNAME: foo
DESTINATION__FILESYSTEM__CREDENTIALS__SFTP_PASSWORD: pass


- name: Stop weaviate
if: always()
run: docker compose -f "tests/load/weaviate/docker-compose.yml" down -v
Expand Down
10 changes: 10 additions & 0 deletions dlt/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,16 @@ def init_command(
source_configuration = files_ops.get_core_source_configuration(
core_sources_storage, source_name
)
from importlib.metadata import Distribution

dist = Distribution.from_name(DLT_PKG_NAME)
extras = dist.metadata.get_all("Provides-Extra") or []

# Match the extra name to the source name
canonical_source_name = source_name.replace("_", "-").lower()

if canonical_source_name in extras:
source_configuration.requirements.update_dlt_extras(canonical_source_name)
else:
if not is_valid_schema_name(source_name):
raise InvalidSchemaName(source_name)
Expand Down
42 changes: 39 additions & 3 deletions tests/cli/test_init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import contextlib
from subprocess import CalledProcessError
from typing import Any, List, Tuple, Optional
from typing import List, Tuple, Optional
from hexbytes import HexBytes
import pytest
from unittest import mock
Expand Down Expand Up @@ -55,7 +55,12 @@

# we hardcode the core sources here so we can check that the init script picks
# up the right source
CORE_SOURCES = ["filesystem", "rest_api", "sql_database"]
CORE_SOURCES_CONFIG = {
"rest_api": {"requires_extra": False},
"sql_database": {"requires_extra": True},
"filesystem": {"requires_extra": True},
}
CORE_SOURCES = list(CORE_SOURCES_CONFIG.keys())

# we also hardcode all the templates here for testing
TEMPLATES = ["debug", "default", "arrow", "requests", "dataframe", "fruitshop", "github_api"]
Expand Down Expand Up @@ -167,6 +172,37 @@ def test_init_list_sources(repo_dir: str) -> None:
assert source in _out


@pytest.mark.parametrize(
"source_name",
[name for name in CORE_SOURCES_CONFIG if CORE_SOURCES_CONFIG[name]["requires_extra"]],
)
def test_init_command_core_source_requirements_with_extras(
source_name: str, repo_dir: str, project_files: FileStorage
) -> None:
init_command.init_command(source_name, "duckdb", repo_dir)
source_requirements = SourceRequirements.from_string(
project_files.load(cli_utils.REQUIREMENTS_TXT)
)
canonical_name = source_name.replace("_", "-")
assert canonical_name in source_requirements.dlt_requirement.extras


@pytest.mark.parametrize(
"source_name",
[name for name in CORE_SOURCES_CONFIG if not CORE_SOURCES_CONFIG[name]["requires_extra"]],
)
def test_init_command_core_source_requirements_without_extras(
source_name: str, repo_dir: str, project_files: FileStorage
) -> None:
init_command.init_command(source_name, "duckdb", repo_dir)
source_requirements = SourceRequirements.from_string(
project_files.load(cli_utils.REQUIREMENTS_TXT)
)
assert source_requirements.dlt_requirement.extras == {
"duckdb"
}, "Only duckdb should be in extras"


def test_init_list_sources_update_warning(repo_dir: str, project_files: FileStorage) -> None:
"""Sources listed include a warning if a different dlt version is required"""
with mock.patch.object(SourceRequirements, "current_dlt_version", return_value="0.0.1"):
Expand Down Expand Up @@ -571,7 +607,7 @@ def assert_requirements_txt(project_files: FileStorage, destination_name: str) -
project_files.load(cli_utils.REQUIREMENTS_TXT)
)
assert destination_name in source_requirements.dlt_requirement.extras
# Check that atleast some version range is specified
# Check that at least some version range is specified
assert len(source_requirements.dlt_requirement.specifier) >= 1


Expand Down

0 comments on commit 0ea9de7

Please sign in to comment.