Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Commit

Permalink
Support more index types besides ivfflat (#224)
Browse files Browse the repository at this point in the history
Previously, we only support indexing embeddings using the `ivfflat`
access method in [pgvector](https://github.com/pgvector/pgvector).

Recently, a new access method `hnsw` has been added to pgvector.
`hnsw` is believed to be more performant and accurate than `ivfflat`.
To allow for more flexibility, we add a new parameter `method` to
allow user to choose which access method to use when creating index.

Also, a new parameter `embedding_dimension` is added to support more
models, since dimension is required for pgvector to create index.

A new test case for embeddings is added in `tests/test_embedding.py`.

To support `set allow_system_table_mods = on;`, Postgres is upgraded
from 12 to 13 on CI.
  • Loading branch information
xuebinsu authored Nov 23, 2023
1 parent 417dcc8 commit 4f52aae
Show file tree
Hide file tree
Showing 15 changed files with 131 additions and 85 deletions.
18 changes: 13 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ jobs:
fail-fast: true
matrix:
python-version: ["3.9", "3.11"]
server: ["postgres12-python39", "postgres12-python311"]
server: ["postgres13-python39", "postgres13-python311"]
include:
- server: "postgres12-python39"
- server: "postgres13-python39"
server-python-version: "3.9"
- server: "postgres12-python311"
- server: "postgres13-python311"
server-python-version: "3.11"

steps:
Expand All @@ -30,11 +30,18 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: pip
- name: Build server image
run: |
docker build \
-t greenplumpython-server:${{ matrix.server }} \
-f server/${{ matrix.server }}.Dockerfile \
server/
- name: Run tests without pickler
run: |
python3 -m pip install tox~=4.11 tox-docker~=4.1 && \
tox \
--override=docker:server.dockerfile=server/${{ matrix.server }}.Dockerfile \
--override=docker:server.image=greenplumpython-server:${{ matrix.server }} \
--override=docker:server.dockerfile='' \
-e test-container \
-- \
--override-ini=server_use_pickler=false \
Expand All @@ -43,7 +50,8 @@ jobs:
if: ${{ matrix.python-version == matrix.server-python-version }}
run: |
tox \
--override=docker:server.dockerfile=server/${{ matrix.server }}.Dockerfile \
--override=docker:server.image=greenplumpython-server:${{ matrix.server }} \
--override=docker:server.dockerfile='' \
-e test-container \
-- \
--override-ini=server_use_pickler=true \
54 changes: 0 additions & 54 deletions concourse/test.sh

This file was deleted.

3 changes: 2 additions & 1 deletion doc/source/notebooks/embedding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
" distribution_key={\"id\"},\n",
" distribution_type=\"hash\",\n",
" drop_if_exists=True,\n",
" drop_cascade=True,\n",
" )\n",
" .check_unique(columns={\"id\"})\n",
")"
Expand Down Expand Up @@ -128,7 +129,7 @@
"source": [
"import greenplumpython.experimental.embedding\n",
"\n",
"t = t.embedding().create_index(column=\"content\", model=\"all-MiniLM-L6-v2\")\n",
"t = t.embedding().create_index(column=\"content\", model_name=\"all-MiniLM-L6-v2\")\n",
"t"
]
},
Expand Down
2 changes: 1 addition & 1 deletion greenplumpython/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ def save_as(
if distribution_type == "replicated"
else "RANDOMLY"}
"""
if distribution_type is not None
if self._db._is_variant("greenplum") and distribution_type is not None
else ""
)
if drop_cascade:
Expand Down
40 changes: 26 additions & 14 deletions greenplumpython/experimental/embedding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, cast
from typing import Any, Callable, Literal, Optional, cast
from uuid import uuid4

import greenplumpython as gp
Expand Down Expand Up @@ -75,7 +75,13 @@ class Embedding:
def __init__(self, dataframe: gp.DataFrame) -> None:
self._dataframe = dataframe

def create_index(self, column: str, model_name: str) -> gp.DataFrame:
def create_index(
self,
column: str,
model_name: str,
embedding_dimension: Optional[int] = None,
method: Optional[Literal["ivfflat", "hnsw"]] = "hnsw",
) -> gp.DataFrame:
"""
Generate embeddings and create index for a column of unstructured data.
Expand All @@ -96,6 +102,8 @@ def create_index(self, column: str, model_name: str) -> gp.DataFrame:
Args:
column: name of column to create index on.
model_name: name of model to generate embedding.
embedding_dimension: dimension of the embedding.
method: name of the index access method (i.e. index type) in `pgvector <https://github.com/pgvector/pgvector>`_.
Returns:
Dataframe with target column indexed based on embeddings.
Expand All @@ -105,17 +113,17 @@ def create_index(self, column: str, model_name: str) -> gp.DataFrame:
"""

import sentence_transformers # type: ignore reportMissingImports

model = sentence_transformers.SentenceTransformer(model_name) # type: ignore reportUnknownVariableType

assert self._dataframe.unique_key is not None, "Unique key is required to create index."
try:
word_embedding_dimension: int = model[1].word_embedding_dimension # From models.Pooling
except:
raise NotImplementedError(
"Model '{model_name}' doesn't provide embedding dimension information"
)
if embedding_dimension is None:
try:
import sentence_transformers # type: ignore reportMissingImports

model = sentence_transformers.SentenceTransformer(model_name) # type: ignore reportUnknownVariableType
embedding_dimension: int = model[1].word_embedding_dimension # From models.Pooling
except:
raise NotImplementedError(
"Model '{model_name}' doesn't provide embedding dimension information"
)

embedding_col_name = "_emb_" + uuid4().hex
embedding_df_cols = list(self._dataframe.unique_key) + [embedding_col_name]
Expand All @@ -126,7 +134,7 @@ def create_index(self, column: str, model_name: str) -> gp.DataFrame:
Callable[[gp.DataFrame], TypeCast],
# FIXME: Modifier must be adapted to all types of model.
# Can this be done with transformers.AutoConfig?
lambda t: gp.type_("vector", modifier=word_embedding_dimension)(_generate_embedding(t[column], model_name)), # type: ignore reportUnknownLambdaType
lambda t: gp.type_("vector", modifier=embedding_dimension)(_generate_embedding(t[column], model_name)), # type: ignore reportUnknownLambdaType
)
},
)[embedding_df_cols]
Expand All @@ -136,8 +144,12 @@ def create_index(self, column: str, model_name: str) -> gp.DataFrame:
distribution_type="hash",
)
.check_unique(self._dataframe.unique_key)
.create_index(columns={embedding_col_name}, method="ivfflat")
)
if method is not None:
assert method in ["ivfflat", "hnsw"]
embedding_df = embedding_df.create_index(
columns={embedding_col_name: "vector_l2_ops"}, method=method
)
assert self._dataframe._db is not None
_record_dependency._create_in_db(self._dataframe._db)
sql_add_relationship = f"""
Expand Down
16 changes: 14 additions & 2 deletions server/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,22 @@ apt-get install --no-install-recommends -y \
python3-venv
apt-get autoclean

POSTGRES_USER_SITE=$(su --login postgres --session-command "python3 -m site --user-site")
POSTGRES_USER_BASE=$(su --login postgres --session-command "python3 -m site --user-base")
POSTGRES_USER_SITE=$(su postgres --session-command "python3 -m site --user-site")
POSTGRES_USER_BASE=$(su postgres --session-command "python3 -m site --user-base")
mkdir --parents "$POSTGRES_USER_SITE"
chown --recursive postgres "$POSTGRES_USER_BASE"

cp /tmp/initdb.sh /docker-entrypoint-initdb.d
chown postgres /docker-entrypoint-initdb.d/*

setup_venv() {
python3 -m venv "$HOME"/venv
# shellcheck source=/dev/null
source "$HOME"/venv/bin/activate

# shellcheck source=/dev/null
source /tmp/requirements.sh
}

export -f setup_venv
su postgres --session-command 'bash -c setup_venv'
2 changes: 1 addition & 1 deletion server/initdb.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ set -o nounset -o xtrace -o errexit -o pipefail
echo "log_destination = 'csvlog'"
} >>"$PGDATA"/postgresql.conf

python3 -m venv "$HOME"/venv
# shellcheck source=/dev/null
source "$HOME"/venv/bin/activate
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM postgres:12-bookworm
FROM postgres:13-bookworm

COPY build.sh initdb.sh /tmp/
COPY build.sh initdb.sh requirements.sh /tmp/
RUN bash /tmp/build.sh

HEALTHCHECK --interval=1s --timeout=1s --start-period=1s --retries=30 CMD psql \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM postgres:12-bullseye
FROM postgres:13-bullseye

COPY build.sh initdb.sh /tmp/
COPY build.sh initdb.sh requirements.sh /tmp/
RUN bash /tmp/build.sh

HEALTHCHECK --interval=1s --timeout=1s --start-period=1s --retries=30 CMD psql \
Expand Down
6 changes: 6 additions & 0 deletions server/requirements.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash

set -o errexit -o nounset -o pipefail -o xtrace

python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
python3 -m pip install sentence-transformers
6 changes: 3 additions & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def pip_install(requirements: str) -> str:


@pytest.fixture(scope="session")
def db(server_use_pickler: bool):
def db(server_use_pickler: bool, server_has_pgvector: bool):
# for the connection both work for GitHub Actions and concourse
db = gp.database(
params={
Expand All @@ -41,10 +41,10 @@ def db(server_use_pickler: bool):
db._execute(
"""
CREATE EXTENSION IF NOT EXISTS plpython3u;
CREATE EXTENSION IF NOT EXISTS vector;
DROP SCHEMA IF EXISTS test CASCADE;
CREATE SCHEMA test;
""",
"""
+ ("CREATE EXTENSION IF NOT EXISTS vector;" if server_has_pgvector else ""),
has_results=False,
)
if server_use_pickler:
Expand Down
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,35 @@ def pytest_addoption(parser: pytest.Parser):
default=True,
help="Use pickler to deserialize UDFs on server.",
)
parser.addini(
"server_has_pgvector",
type="bool",
default=True,
help="pgvector is available on server.",
)


@pytest.fixture(scope="session")
def server_use_pickler(pytestconfig: pytest.Config) -> bool:
val: bool = pytestconfig.getini("server_use_pickler")
return val


@pytest.fixture(scope="session")
def server_has_pgvector(pytestconfig: pytest.Config) -> bool:
val: bool = pytestconfig.getini("server_has_pgvector")
return val


def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]):
server_has_pgvector: bool = config.getini("server_has_pgvector")
server_use_pickler: bool = config.getini("server_use_pickler")
xfail_requires_pgvector = pytest.mark.xfail(reason="requires pgvector on server to run")
xfail_requires_pickler_on_server = pytest.mark.xfail(
reason="requires pickler (e.g. dill) on server to run"
)
for item in items:
if "requires_pgvector" in item.keywords and not server_has_pgvector:
item.add_marker(xfail_requires_pgvector)
if "requires_pickler_on_server" in item.keywords and not server_use_pickler:
item.add_marker(xfail_requires_pickler_on_server)
26 changes: 26 additions & 0 deletions tests/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest

import greenplumpython as gp
from tests import db


@pytest.mark.requires_pgvector
def test_embedding_query_string(db: gp.Database):
content = ["I have a dog.", "I like eating apples."]
t = (
db.create_dataframe(columns={"id": range(len(content)), "content": content})
.save_as(
temp=True,
column_names=["id", "content"],
distribution_key={"id"},
distribution_type="hash",
drop_if_exists=True,
drop_cascade=True,
)
.check_unique(columns={"id"})
)
t = t.embedding().create_index(column="content", model_name="all-MiniLM-L6-v2")
df = t.embedding().search(column="content", query="apple", top_k=1)
assert len(list(df)) == 1
for row in df:
assert row["content"] == "I like eating apples."
1 change: 1 addition & 0 deletions tests/test_use_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def test_pickler_option(server_use_pickler: bool):
from dataclasses import dataclass


@pytest.mark.requires_pickler_on_server
def test_pickler_outside_class(db: gp.Database):
@dataclass
class Int:
Expand Down
8 changes: 8 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ host = localhost
port = 65432

[docker:server]
image =
dockerfile = server/postgres12-python39.Dockerfile
environment =
POSTGRES_USER={[server]user}
Expand All @@ -28,7 +29,9 @@ ports =
[testenv:test-container]
docker = server
deps = -r requirements-dev.txt
allowlist_externals = bash
commands =
bash -c 'source .tox/test-container/bin/activate && source server/requirements.sh'
pytest --exitfirst {posargs}
setenv =
PGHOST={[server]host}
Expand All @@ -38,7 +41,9 @@ setenv =

[testenv:test]
deps = -r requirements-dev.txt
allowlist_externals = bash
commands =
bash -c 'source .tox/test/bin/activate && source server/requirements.sh'
pytest --exitfirst {posargs}
setenv =
PGHOST={env:PGHOST:localhost}
Expand Down Expand Up @@ -74,6 +79,9 @@ doctest_optionflags = NORMALIZE_WHITESPACE
testpaths =
tests
greenplumpython
markers =
requires_pickler_on_server
requires_pgvector

[pydocstyle]
# TODO: Enable docstyle check for all files
Expand Down

0 comments on commit 4f52aae

Please sign in to comment.