Skip to content

Commit

Permalink
feat: modify embedding method to use embedding class
Browse files Browse the repository at this point in the history
  • Loading branch information
yrobla committed Nov 27, 2024
1 parent 7649886 commit cf089e1
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 113 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ build-backend = "poetry.core.masonry.api"
codegate = "codegate.cli:main"

[tool.black]
line-length = 88
line-length = 127
target-version = ["py310"]

[tool.ruff]
line-length = 88
line-length = 127
target-version = "py310"
fix = true

Expand Down
210 changes: 99 additions & 111 deletions scripts/import_packages.py
Original file line number Diff line number Diff line change
@@ -1,129 +1,117 @@
import json

import weaviate
import asyncio

from weaviate.classes.config import DataType, Property
from weaviate.embedded import EmbeddedOptions

from utils.embedding_util import generate_embeddings

json_files = [
"data/archived.jsonl",
"data/deprecated.jsonl",
"data/malicious.jsonl",
]
from weaviate.util import generate_uuid5
from codegate.inference.inference_engine import LlamaCppInferenceEngine


def setup_schema(client):
if not client.collections.exists("Package"):
client.collections.create(
"Package",
properties=[
Property(name="name", data_type=DataType.TEXT),
Property(name="type", data_type=DataType.TEXT),
Property(name="status", data_type=DataType.TEXT),
Property(name="description", data_type=DataType.TEXT),
],
class PackageImporter:
def __init__(self):
self.client = weaviate.WeaviateClient(
embedded_options=EmbeddedOptions(
persistence_data_path="./weaviate_data",
grpc_port=50052
)
)


def generate_vector_string(package):
vector_str = f"{package['name']}"
# add description
package_url = ""
if package["type"] == "pypi":
vector_str += " is a Python package available on PyPI"
package_url = f"https://trustypkg.dev/pypi/{package['name']}"
elif package["type"] == "npm":
vector_str += " is a JavaScript package available on NPM"
package_url = f"https://trustypkg.dev/npm/{package['name']}"
elif package["type"] == "go":
vector_str += " is a Go package. "
package_url = f"https://trustypkg.dev/go/{package['name']}"
elif package["type"] == "crates":
vector_str += " is a Rust package available on Crates. "
package_url = f"https://trustypkg.dev/crates/{package['name']}"
elif package["type"] == "java":
vector_str += " is a Java package. "
package_url = f"https://trustypkg.dev/java/{package['name']}"

# add extra status
if package["status"] == "archived":
vector_str += f". However, this package is found to be archived and no longer \
maintained. For additional information refer to {package_url}"
elif package["status"] == "deprecated":
vector_str += f". However, this package is found to be deprecated and no \
longer recommended for use. For additional information refer to {package_url}"
elif package["status"] == "malicious":
vector_str += f". However, this package is found to be malicious. For \
additional information refer to {package_url}"
return vector_str


def add_data(client):
collection = client.collections.get("Package")

# read all the data from db, we will only add if there is no data, or is different
existing_packages = list(collection.iterator())
packages_dict = {}
for package in existing_packages:
key = package.properties["name"] + "/" + package.properties["type"]
value = {
"status": package.properties["status"],
"description": package.properties["description"],
self.json_files = [
"data/archived.jsonl",
"data/deprecated.jsonl",
"data/malicious.jsonl",
]
self.client.connect()
self.inference_engine = LlamaCppInferenceEngine()
self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"

def setup_schema(self):
if not self.client.collections.exists("Package"):
self.client.collections.create(
"Package",
properties=[
Property(name="name", data_type=DataType.TEXT),
Property(name="type", data_type=DataType.TEXT),
Property(name="status", data_type=DataType.TEXT),
Property(name="description", data_type=DataType.TEXT),
],
)

def generate_vector_string(self, package):
vector_str = f"{package['name']}"
package_url = ""
type_map = {
"pypi": "Python package available on PyPI",
"npm": "JavaScript package available on NPM",
"go": "Go package",
"crates": "Rust package available on Crates",
"java": "Java package"
}
status_messages = {
"archived": "However, this package is found to be archived and no longer maintained.",
"deprecated": "However, this package is found to be deprecated and no longer recommended for use.",
"malicious": "However, this package is found to be malicious."
}
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} "
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}"

# Add extra status
status_suffix = status_messages.get(package["status"], "")
if status_suffix:
vector_str += f"{status_suffix} For additional information refer to {package_url}"
return vector_str

async def process_package(self, batch, package):
vector_str = self.generate_vector_string(package)
vector = await self.inference_engine.embed(self.model_path, [vector_str])
# This is where the synchronous call is made
batch.add_object(properties=package, vector=vector[0])

async def add_data(self):
collection = self.client.collections.get("Package")
existing_packages = list(collection.iterator())
packages_dict = {
f"{package.properties['name']}/{package.properties['type']}": {
"status": package.properties["status"],
"description": package.properties["description"]
} for package in existing_packages
}
packages_dict[key] = value

for json_file in json_files:
with open(json_file, "r") as f:
print("Adding data from", json_file)

# temporary, just for testing
with collection.batch.dynamic() as batch:
for json_file in self.json_files:
with open(json_file, "r") as f:
print("Adding data from", json_file)
packages_to_insert = []
for line in f:
package = json.loads(line)
package["status"] = json_file.split('/')[-1].split('.')[0]
key = f"{package['name']}/{package['type']}"

# now add the status column
if "archived" in json_file:
package["status"] = "archived"
elif "deprecated" in json_file:
package["status"] = "deprecated"
elif "malicious" in json_file:
package["status"] = "malicious"
else:
package["status"] = "unknown"

# check for the existing package and only add if different
key = package["name"] + "/" + package["type"]
if key in packages_dict:
if (
packages_dict[key]["status"] == package["status"]
and packages_dict[key]["description"]
== package["description"]
):
print("Package already exists", key)
continue

# prepare the object for embedding
print("Generating data for", key)
vector_str = generate_vector_string(package)
vector = generate_embeddings(vector_str)

batch.add_object(properties=package, vector=vector)
if key in packages_dict and packages_dict[key] == {
"status": package["status"],
"description": package["description"]
}:
print("Package already exists", key)
continue

vector_str = self.generate_vector_string(package)
vector = await self.inference_engine.embed(self.model_path, [vector_str])
packages_to_insert.append((package, vector[0]))

def run_import():
client = weaviate.WeaviateClient(
embedded_options=EmbeddedOptions(
persistence_data_path="./weaviate_data", grpc_port=50052
),
)
with client:
client.connect()
print("is_ready:", client.is_ready())
# Synchronous batch insert after preparing all data
with collection.batch.dynamic() as batch:
for package, vector in packages_to_insert:
batch.add_object(properties=package, vector=vector, uuid=generate_uuid5(package))

setup_schema(client)
add_data(client)
async def run_import(self):
self.setup_schema()
await self.add_data()


if __name__ == "__main__":
run_import()
importer = PackageImporter()
asyncio.run(importer.run_import())
try:
assert importer.client.is_live()
pass
finally:
importer.client.close()

0 comments on commit cf089e1

Please sign in to comment.