Skip to content

Commit

Permalink
Replace OpenAI with Cohere in LanceDB custom destination example
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy committed May 13, 2024
1 parent 3ea0d2b commit bc0567d
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
[spotify]
client_id = ""
client_secret = ""
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import os
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Dict, Any, Optional
from typing import Optional, Dict, Any

import lancedb # type: ignore
from lancedb.embeddings import get_registry, OpenAIEmbeddings # type: ignore
from lancedb.embeddings.registry import EmbeddingFunctionRegistry # type: ignore
from lancedb.pydantic import LanceModel, Vector # type: ignore

import dlt
Expand All @@ -33,25 +33,28 @@


BASE_SPOTIFY_URL = "https://api.spotify.com/v1"
os.environ["SPOTIFY__CLIENT_ID"] = ""
os.environ["SPOTIFY__CLIENT_SECRET"] = ""
os.environ["OPENAI_API_KEY"] = ""

# Spotify client ID and secret. Get these from https://developer.spotify.com/.
os.environ["CLIENT_ID"] = ""
os.environ["CLIENT_SECRET"] = ""

os.environ["COHERE_API_KEY"] = ""

# Where would you like to store your embeddings?
DB_PATH = "spotify.db"

# LanceDB global registry keeps track of text embedding callables implicitly.
openai = get_registry().get("openai")

embedding_model = openai.create()
cohere = EmbeddingFunctionRegistry
func = EmbeddingFunctionRegistry.get_instance().get("cohere").create(max_retries=1)

db_path = Path(DB_PATH)


class EpisodeSchema(LanceModel):
id: str # noqa: A003
name: str
description: str = embedding_model.SourceField()
vector: Vector(embedding_model.ndims()) = embedding_model.VectorField() # type: ignore[valid-type]
description: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField() # type: ignore[valid-type]
release_date: datetime.date
href: str

Expand Down Expand Up @@ -98,7 +101,9 @@ def fetch_show_episode_data(


@dlt.source
def spotify_shows(client_id: str = dlt.secrets.value, client_secret: str = dlt.secrets.value):
def spotify_shows(
client_id: str = dlt.secrets.value, client_secret: str = dlt.secrets.value
):
access_token: str = get_spotify_access_token(client_id, client_secret)
params: Dict[str, Any] = {"limit": 50}
for show in fields(Shows):
Expand All @@ -108,8 +113,7 @@ def spotify_shows(client_id: str = dlt.secrets.value, client_secret: str = dlt.s
fetch_show_episode_data(show_id, access_token, params),
name=show_name,
write_disposition="merge",
primary_key="id",
parallelized=True,
primary_key="id", # parallelized=True,
max_table_nesting=0,
)

Expand All @@ -121,7 +125,6 @@ def lancedb_destination(items: TDataItems, table: TTableSchema) -> None:
tbl = db.open_table(table["name"])
except FileNotFoundError:
tbl = db.create_table(table["name"], schema=EpisodeSchema)
tbl.checkout_latest()
tbl.add(items)


Expand All @@ -139,7 +142,7 @@ def lancedb_destination(items: TDataItems, table: TTableSchema) -> None:
)

load_info = pipeline.run(
spotify_shows(client_id=dlt.secrets.value, client_secret=dlt.secrets.value)
spotify_shows(client_id=dlt.secrets.value, client_secret=dlt.secrets.value),
)

row_counts = pipeline.last_trace.last_normalize_info
Expand All @@ -161,7 +164,6 @@ def lancedb_destination(items: TDataItems, table: TTableSchema) -> None:
print(f"Querying table: {table_to_query}")

tbl = db.open_table(table_to_query)
tbl.checkout_latest()

results = tbl.search(query=query).to_list()
print(results)
assert results
77 changes: 30 additions & 47 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ dbt-databricks = {version = ">=1.7.3", optional = true}
clickhouse-driver = { version = ">=0.2.7", optional = true }
clickhouse-connect = { version = ">=0.7.7", optional = true }
lancedb = { version = ">=0.6.13", optional = true }
openai = { version = ">=1.28.1", optional = true }
cohere = { version = ">=3.0.0", optional = true }

[tool.poetry.extras]
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
Expand All @@ -109,7 +109,7 @@ qdrant = ["qdrant-client"]
databricks = ["databricks-sql-connector"]
clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"]
dremio = ["pyarrow"]
lancedb = ["lancedb", "openai"]
lancedb = ["lancedb", "cohere"]

[tool.poetry.scripts]
dlt = "dlt.cli._dlt:_main"
Expand Down

0 comments on commit bc0567d

Please sign in to comment.