diff --git a/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml b/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml index 275d1ada76..e69de29bb2 100644 --- a/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml +++ b/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml @@ -1,3 +0,0 @@ -[spotify] -client_id = "" -client_secret = "" \ No newline at end of file diff --git a/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py index 977c332a7b..6f70824ce3 100644 --- a/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py +++ b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py @@ -20,10 +20,10 @@ import os from dataclasses import dataclass, fields from pathlib import Path -from typing import Dict, Any, Optional +from typing import Optional, Dict, Any import lancedb # type: ignore -from lancedb.embeddings import get_registry, OpenAIEmbeddings # type: ignore +from lancedb.embeddings.registry import EmbeddingFunctionRegistry # type: ignore from lancedb.pydantic import LanceModel, Vector # type: ignore import dlt @@ -33,16 +33,19 @@ BASE_SPOTIFY_URL = "https://api.spotify.com/v1" -os.environ["SPOTIFY__CLIENT_ID"] = "" -os.environ["SPOTIFY__CLIENT_SECRET"] = "" -os.environ["OPENAI_API_KEY"] = "" +# Spotify client ID and secret. Get these from https://developer.spotify.com/. +os.environ["CLIENT_ID"] = "" +os.environ["CLIENT_SECRET"] = "" + +os.environ["COHERE_API_KEY"] = "" + +# Where would you like to store your embeddings? DB_PATH = "spotify.db" # LanceDB global registry keeps track of text embedding callables implicitly. -openai = get_registry().get("openai") - -embedding_model = openai.create() +cohere = EmbeddingFunctionRegistry +func = EmbeddingFunctionRegistry.get_instance().get("cohere").create(max_retries=1) db_path = Path(DB_PATH) @@ -50,8 +53,8 @@ class EpisodeSchema(LanceModel): id: str # noqa: A003 name: str - description: str = embedding_model.SourceField() - vector: Vector(embedding_model.ndims()) = embedding_model.VectorField() # type: ignore[valid-type] + description: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() # type: ignore[valid-type] release_date: datetime.date href: str @@ -98,7 +101,9 @@ def fetch_show_episode_data( @dlt.source -def spotify_shows(client_id: str = dlt.secrets.value, client_secret: str = dlt.secrets.value): +def spotify_shows( + client_id: str = dlt.secrets.value, client_secret: str = dlt.secrets.value +): access_token: str = get_spotify_access_token(client_id, client_secret) params: Dict[str, Any] = {"limit": 50} for show in fields(Shows): @@ -108,8 +113,7 @@ def spotify_shows(client_id: str = dlt.secrets.value, client_secret: str = dlt.s fetch_show_episode_data(show_id, access_token, params), name=show_name, write_disposition="merge", - primary_key="id", - parallelized=True, + primary_key="id", # parallelized=True, max_table_nesting=0, ) @@ -121,7 +125,6 @@ def lancedb_destination(items: TDataItems, table: TTableSchema) -> None: tbl = db.open_table(table["name"]) except FileNotFoundError: tbl = db.create_table(table["name"], schema=EpisodeSchema) - tbl.checkout_latest() tbl.add(items) @@ -139,7 +142,7 @@ def lancedb_destination(items: TDataItems, table: TTableSchema) -> None: ) load_info = pipeline.run( - spotify_shows(client_id=dlt.secrets.value, client_secret=dlt.secrets.value) + spotify_shows(client_id=dlt.secrets.value, client_secret=dlt.secrets.value), ) row_counts = pipeline.last_trace.last_normalize_info @@ -161,7 +164,6 @@ def lancedb_destination(items: TDataItems, table: TTableSchema) -> None: print(f"Querying table: {table_to_query}") tbl = db.open_table(table_to_query) - tbl.checkout_latest() results = tbl.search(query=query).to_list() - print(results) + assert results diff --git a/poetry.lock b/poetry.lock index 737bfce606..975b386124 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1883,6 +1883,22 @@ lz4 = ["clickhouse-cityhash (>=1.0.2.1)", "lz4", "lz4 (<=3.0.1)"] numpy = ["numpy (>=1.12.0)", "pandas (>=0.24.0)"] zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"] +[[package]] +name = "cohere" +version = "5.1.5" +description = "" +optional = true +python-versions = "<4.0,>=3.8" +files = [ + {file = "cohere-5.1.5-py3-none-any.whl", hash = "sha256:10a2bc0aab8b3a03d5da84412a6309aeb2b267d852b969d8b6acd0a19176b5a9"}, + {file = "cohere-5.1.5.tar.gz", hash = "sha256:2ed04ac3fb1b4e3a1e243cc94898158b3be41cd64165dd0910952d73207298fa"}, +] + +[package.dependencies] +httpx = ">=0.21.2" +pydantic = ">=1.9.2" +typing_extensions = ">=4.0.0" + [[package]] name = "colorama" version = "0.4.6" @@ -2465,17 +2481,6 @@ files = [ [package.extras] graph = ["objgraph (>=1.7.2)"] -[[package]] -name = "distro" -version = "1.9.0" -description = "Distro - an OS platform information API" -optional = true -python-versions = ">=3.6" -files = [ - {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, - {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, -] - [[package]] name = "dnspython" version = "2.4.2" @@ -4199,24 +4204,24 @@ files = [ [[package]] name = "httpcore" -version = "0.17.3" +version = "1.0.5" description = "A minimal low-level HTTP client." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "httpcore-0.17.3-py3-none-any.whl", hash = "sha256:c2789b767ddddfa2a5782e3199b2b7f6894540b17b16ec26b2c4d8e103510b87"}, - {file = "httpcore-0.17.3.tar.gz", hash = "sha256:a6f30213335e34c1ade7be6ec7c47f19f50c56db36abef1a9dfa3815b1cb3888"}, + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, ] [package.dependencies] -anyio = ">=3.0,<5.0" certifi = "*" h11 = ">=0.13,<0.15" -sniffio = "==1.*" [package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] [[package]] name = "httplib2" @@ -4234,19 +4239,20 @@ pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0 [[package]] name = "httpx" -version = "0.24.1" +version = "0.27.0" description = "The next generation HTTP client." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "httpx-0.24.1-py3-none-any.whl", hash = "sha256:06781eb9ac53cde990577af654bd990a4949de37a28bdb4a230d434f3a30b9bd"}, - {file = "httpx-0.24.1.tar.gz", hash = "sha256:5853a43053df830c20f8110c5e69fe44d035d850b2dfe795e196f00fdb774bdd"}, + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, ] [package.dependencies] +anyio = "*" certifi = "*" h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} -httpcore = ">=0.15.0,<0.18.0" +httpcore = "==1.*" idna = "*" sniffio = "*" @@ -5835,29 +5841,6 @@ packaging = "*" protobuf = "*" sympy = "*" -[[package]] -name = "openai" -version = "1.28.1" -description = "The official Python library for the openai API" -optional = true -python-versions = ">=3.7.1" -files = [ - {file = "openai-1.28.1-py3-none-any.whl", hash = "sha256:943e0d0d587b9a62f99bd3acbaf479ae5362986e5fff013f57b5b7bde85cce93"}, - {file = "openai-1.28.1.tar.gz", hash = "sha256:8a3adbba16882434768d76fd3129fcc9b40ace98f8d55a6ddacfc05c4096ac30"}, -] - -[package.dependencies] -anyio = ">=3.5.0,<5" -distro = ">=1.7.0,<2" -httpx = ">=0.23.0,<1" -pydantic = ">=1.9.0,<3" -sniffio = "*" -tqdm = ">4" -typing-extensions = ">=4.7,<5" - -[package.extras] -datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] - [[package]] name = "openpyxl" version = "3.1.2" @@ -9571,7 +9554,7 @@ duckdb = ["duckdb", "duckdb"] filesystem = ["botocore", "s3fs"] gcp = ["gcsfs", "google-cloud-bigquery", "grpcio"] gs = ["gcsfs"] -lancedb = ["lancedb", "openai"] +lancedb = ["cohere", "lancedb"] motherduck = ["duckdb", "duckdb", "pyarrow"] mssql = ["pyodbc"] parquet = ["pyarrow"] @@ -9586,4 +9569,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "7524509cb6e931436cdcb8a5adf37fda7ca56a723fd9e0dbc91d3c43e223b026" +content-hash = "f06280cc33ebd7bc9e2478493d576d7acde770665149f2ea8d95a86b98c0ed20" diff --git a/pyproject.toml b/pyproject.toml index 105e2ca14f..ab29e26594 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ dbt-databricks = {version = ">=1.7.3", optional = true} clickhouse-driver = { version = ">=0.2.7", optional = true } clickhouse-connect = { version = ">=0.7.7", optional = true } lancedb = { version = ">=0.6.13", optional = true } -openai = { version = ">=1.28.1", optional = true } +cohere = { version = ">=3.0.0", optional = true } [tool.poetry.extras] dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"] @@ -109,7 +109,7 @@ qdrant = ["qdrant-client"] databricks = ["databricks-sql-connector"] clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"] dremio = ["pyarrow"] -lancedb = ["lancedb", "openai"] +lancedb = ["lancedb", "cohere"] [tool.poetry.scripts] dlt = "dlt.cli._dlt:_main"