diff --git a/pyproject.toml b/pyproject.toml
index 8f55e302..1d76921d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,11 +40,11 @@ build-backend = "poetry.core.masonry.api"
 codegate = "codegate.cli:main"
 
 [tool.black]
-line-length = 88
+line-length = 127
 target-version = ["py310"]
 
 [tool.ruff]
-line-length = 88
+line-length = 127
 target-version = "py310"
 fix = true
 
diff --git a/scripts/import_packages.py b/scripts/import_packages.py
index b2ad3e37..4c0ef6a8 100644
--- a/scripts/import_packages.py
+++ b/scripts/import_packages.py
@@ -1,129 +1,126 @@
 import json
-
 import weaviate
+import asyncio
 
 from weaviate.classes.config import DataType, Property
 from weaviate.embedded import EmbeddedOptions
-
-from utils.embedding_util import generate_embeddings
-
-json_files = [
-    "data/archived.jsonl",
-    "data/deprecated.jsonl",
-    "data/malicious.jsonl",
-]
+from weaviate.util import generate_uuid5
+from codegate.inference.inference_engine import LlamaCppInferenceEngine
 
 
-def setup_schema(client):
-    if not client.collections.exists("Package"):
-        client.collections.create(
-            "Package",
-            properties=[
-                Property(name="name", data_type=DataType.TEXT),
-                Property(name="type", data_type=DataType.TEXT),
-                Property(name="status", data_type=DataType.TEXT),
-                Property(name="description", data_type=DataType.TEXT),
-            ],
+class PackageImporter:
+    """Load package records from JSONL files into an embedded Weaviate "Package" collection."""
+
+    def __init__(self):
+        self.client = weaviate.WeaviateClient(
+            embedded_options=EmbeddedOptions(
+                persistence_data_path="./weaviate_data",
+                grpc_port=50052
+            )
         )
-
-
-def generate_vector_string(package):
-    vector_str = f"{package['name']}"
-    # add description
-    package_url = ""
-    if package["type"] == "pypi":
-        vector_str += " is a Python package available on PyPI"
-        package_url = f"https://trustypkg.dev/pypi/{package['name']}"
-    elif package["type"] == "npm":
-        vector_str += " is a JavaScript package available on NPM"
-        package_url = f"https://trustypkg.dev/npm/{package['name']}"
-    elif package["type"] == "go":
-        vector_str += " is a Go package. "
-        package_url = f"https://trustypkg.dev/go/{package['name']}"
-    elif package["type"] == "crates":
-        vector_str += " is a Rust package available on Crates. "
-        package_url = f"https://trustypkg.dev/crates/{package['name']}"
-    elif package["type"] == "java":
-        vector_str += " is a Java package. "
-        package_url = f"https://trustypkg.dev/java/{package['name']}"
-
-    # add extra status
-    if package["status"] == "archived":
-        vector_str += f". However, this package is found to be archived and no longer \
-maintained. For additional information refer to {package_url}"
-    elif package["status"] == "deprecated":
-        vector_str += f". However, this package is found to be deprecated and no \
-longer recommended for use. For additional information refer to {package_url}"
-    elif package["status"] == "malicious":
-        vector_str += f". However, this package is found to be malicious. For \
-additional information refer to {package_url}"
-    return vector_str
-
-
-def add_data(client):
-    collection = client.collections.get("Package")
-
-    # read all the data from db, we will only add if there is no data, or is different
-    existing_packages = list(collection.iterator())
-    packages_dict = {}
-    for package in existing_packages:
-        key = package.properties["name"] + "/" + package.properties["type"]
-        value = {
-            "status": package.properties["status"],
-            "description": package.properties["description"],
+        self.json_files = [
+            "data/archived.jsonl",
+            "data/deprecated.jsonl",
+            "data/malicious.jsonl",
+        ]
+        self.client.connect()
+        self.inference_engine = LlamaCppInferenceEngine()
+        self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
+
+    def setup_schema(self):
+        """Create the "Package" collection if it does not exist yet."""
+        if not self.client.collections.exists("Package"):
+            self.client.collections.create(
+                "Package",
+                properties=[
+                    Property(name="name", data_type=DataType.TEXT),
+                    Property(name="type", data_type=DataType.TEXT),
+                    Property(name="status", data_type=DataType.TEXT),
+                    Property(name="description", data_type=DataType.TEXT),
+                ],
+            )
+
+    def generate_vector_string(self, package):
+        """Build the sentence that is embedded for *package* (name, ecosystem and status)."""
+        vector_str = f"{package['name']}"
+        package_url = ""
+        type_map = {
+            "pypi": "Python package available on PyPI",
+            "npm": "JavaScript package available on NPM",
+            "go": "Go package",
+            "crates": "Rust package available on Crates",
+            "java": "Java package"
+        }
+        status_messages = {
+            "archived": "However, this package is found to be archived and no longer maintained.",
+            "deprecated": "However, this package is found to be deprecated and no longer recommended for use.",
+            "malicious": "However, this package is found to be malicious."
+        }
+        # End the sentence with a period so the status sentence that follows reads correctly.
+        vector_str += f" is a {type_map.get(package['type'], 'unknown type')}. "
+        package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}"
+
+        # Add extra status
+        status_suffix = status_messages.get(package["status"], "")
+        if status_suffix:
+            vector_str += f"{status_suffix} For additional information refer to {package_url}"
+        return vector_str
+
+    async def process_package(self, batch, package):
+        """Embed a single package and add it to an open Weaviate batch."""
+        vector_str = self.generate_vector_string(package)
+        vector = await self.inference_engine.embed(self.model_path, [vector_str])
+        # This is where the synchronous call is made
+        batch.add_object(properties=package, vector=vector[0])
+
+    async def add_data(self):
+        """Embed and insert every package that is new or whose status/description changed."""
+        collection = self.client.collections.get("Package")
+        existing_packages = list(collection.iterator())
+        packages_dict = {
+            f"{package.properties['name']}/{package.properties['type']}": {
+                "status": package.properties["status"],
+                "description": package.properties["description"]
+            } for package in existing_packages
         }
-        packages_dict[key] = value
-
-    for json_file in json_files:
-        with open(json_file, "r") as f:
-            print("Adding data from", json_file)
-            # temporary, just for testing
-            with collection.batch.dynamic() as batch:
+        for json_file in self.json_files:
+            with open(json_file, "r") as f:
+                print("Adding data from", json_file)
+                packages_to_insert = []
                 for line in f:
                     package = json.loads(line)
+                    package["status"] = json_file.split('/')[-1].split('.')[0]
+                    key = f"{package['name']}/{package['type']}"
 
-                    # now add the status column
-                    if "archived" in json_file:
-                        package["status"] = "archived"
-                    elif "deprecated" in json_file:
-                        package["status"] = "deprecated"
-                    elif "malicious" in json_file:
-                        package["status"] = "malicious"
-                    else:
-                        package["status"] = "unknown"
-
-                    # check for the existing package and only add if different
-                    key = package["name"] + "/" + package["type"]
-                    if key in packages_dict:
-                        if (
-                            packages_dict[key]["status"] == package["status"]
-                            and packages_dict[key]["description"]
-                            == package["description"]
-                        ):
-                            print("Package already exists", key)
-                            continue
-
-                    # prepare the object for embedding
-                    print("Generating data for", key)
-                    vector_str = generate_vector_string(package)
-                    vector = generate_embeddings(vector_str)
-
-                    batch.add_object(properties=package, vector=vector)
+                    if key in packages_dict and packages_dict[key] == {
+                        "status": package["status"],
+                        "description": package["description"]
+                    }:
+                        print("Package already exists", key)
+                        continue
+                    vector_str = self.generate_vector_string(package)
+                    vector = await self.inference_engine.embed(self.model_path, [vector_str])
+                    packages_to_insert.append((package, vector[0]))
 
 
-def run_import():
-    client = weaviate.WeaviateClient(
-        embedded_options=EmbeddedOptions(
-            persistence_data_path="./weaviate_data", grpc_port=50052
-        ),
-    )
-    with client:
-        client.connect()
-        print("is_ready:", client.is_ready())
+                # Synchronous batch insert after preparing all data
+                with collection.batch.dynamic() as batch:
+                    for package, vector in packages_to_insert:
+                        batch.add_object(properties=package, vector=vector, uuid=generate_uuid5(package))
 
-        setup_schema(client)
-        add_data(client)
+    async def run_import(self):
+        """Ensure the schema exists, then import all package data."""
+        self.setup_schema()
+        await self.add_data()
 
 
 if __name__ == "__main__":
-    run_import()
+    importer = PackageImporter()
+    try:
+        asyncio.run(importer.run_import())
+        # Raise explicitly: a bare assert is stripped under "python -O".
+        if not importer.client.is_live():
+            raise RuntimeError("Weaviate client is not live after import")
+    finally:
+        # Always release the embedded Weaviate instance, even when the import fails.
+        importer.client.close()