-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: modify embedding method to use embedding class
- Loading branch information
Showing
2 changed files
with
101 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,129 +1,117 @@ | ||
import json | ||
|
||
import weaviate | ||
import asyncio | ||
|
||
from weaviate.classes.config import DataType, Property | ||
from weaviate.embedded import EmbeddedOptions | ||
|
||
from utils.embedding_util import generate_embeddings | ||
|
||
json_files = [ | ||
"data/archived.jsonl", | ||
"data/deprecated.jsonl", | ||
"data/malicious.jsonl", | ||
] | ||
from weaviate.util import generate_uuid5 | ||
from codegate.inference.inference_engine import LlamaCppInferenceEngine | ||
|
||
|
||
def setup_schema(client): | ||
if not client.collections.exists("Package"): | ||
client.collections.create( | ||
"Package", | ||
properties=[ | ||
Property(name="name", data_type=DataType.TEXT), | ||
Property(name="type", data_type=DataType.TEXT), | ||
Property(name="status", data_type=DataType.TEXT), | ||
Property(name="description", data_type=DataType.TEXT), | ||
], | ||
class PackageImporter: | ||
def __init__(self): | ||
self.client = weaviate.WeaviateClient( | ||
embedded_options=EmbeddedOptions( | ||
persistence_data_path="./weaviate_data", | ||
grpc_port=50052 | ||
) | ||
) | ||
|
||
|
||
def generate_vector_string(package): | ||
vector_str = f"{package['name']}" | ||
# add description | ||
package_url = "" | ||
if package["type"] == "pypi": | ||
vector_str += " is a Python package available on PyPI" | ||
package_url = f"https://trustypkg.dev/pypi/{package['name']}" | ||
elif package["type"] == "npm": | ||
vector_str += " is a JavaScript package available on NPM" | ||
package_url = f"https://trustypkg.dev/npm/{package['name']}" | ||
elif package["type"] == "go": | ||
vector_str += " is a Go package. " | ||
package_url = f"https://trustypkg.dev/go/{package['name']}" | ||
elif package["type"] == "crates": | ||
vector_str += " is a Rust package available on Crates. " | ||
package_url = f"https://trustypkg.dev/crates/{package['name']}" | ||
elif package["type"] == "java": | ||
vector_str += " is a Java package. " | ||
package_url = f"https://trustypkg.dev/java/{package['name']}" | ||
|
||
# add extra status | ||
if package["status"] == "archived": | ||
vector_str += f". However, this package is found to be archived and no longer \ | ||
maintained. For additional information refer to {package_url}" | ||
elif package["status"] == "deprecated": | ||
vector_str += f". However, this package is found to be deprecated and no \ | ||
longer recommended for use. For additional information refer to {package_url}" | ||
elif package["status"] == "malicious": | ||
vector_str += f". However, this package is found to be malicious. For \ | ||
additional information refer to {package_url}" | ||
return vector_str | ||
|
||
|
||
def add_data(client): | ||
collection = client.collections.get("Package") | ||
|
||
# read all the data from db, we will only add if there is no data, or is different | ||
existing_packages = list(collection.iterator()) | ||
packages_dict = {} | ||
for package in existing_packages: | ||
key = package.properties["name"] + "/" + package.properties["type"] | ||
value = { | ||
"status": package.properties["status"], | ||
"description": package.properties["description"], | ||
self.json_files = [ | ||
"data/archived.jsonl", | ||
"data/deprecated.jsonl", | ||
"data/malicious.jsonl", | ||
] | ||
self.client.connect() | ||
self.inference_engine = LlamaCppInferenceEngine() | ||
self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf" | ||
|
||
def setup_schema(self): | ||
if not self.client.collections.exists("Package"): | ||
self.client.collections.create( | ||
"Package", | ||
properties=[ | ||
Property(name="name", data_type=DataType.TEXT), | ||
Property(name="type", data_type=DataType.TEXT), | ||
Property(name="status", data_type=DataType.TEXT), | ||
Property(name="description", data_type=DataType.TEXT), | ||
], | ||
) | ||
|
||
def generate_vector_string(self, package): | ||
vector_str = f"{package['name']}" | ||
package_url = "" | ||
type_map = { | ||
"pypi": "Python package available on PyPI", | ||
"npm": "JavaScript package available on NPM", | ||
"go": "Go package", | ||
"crates": "Rust package available on Crates", | ||
"java": "Java package" | ||
} | ||
status_messages = { | ||
"archived": "However, this package is found to be archived and no longer maintained.", | ||
"deprecated": "However, this package is found to be deprecated and no longer recommended for use.", | ||
"malicious": "However, this package is found to be malicious." | ||
} | ||
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} " | ||
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}" | ||
|
||
# Add extra status | ||
status_suffix = status_messages.get(package["status"], "") | ||
if status_suffix: | ||
vector_str += f"{status_suffix} For additional information refer to {package_url}" | ||
return vector_str | ||
|
||
async def process_package(self, batch, package): | ||
vector_str = self.generate_vector_string(package) | ||
vector = await self.inference_engine.embed(self.model_path, [vector_str]) | ||
# This is where the synchronous call is made | ||
batch.add_object(properties=package, vector=vector[0]) | ||
|
||
async def add_data(self): | ||
collection = self.client.collections.get("Package") | ||
existing_packages = list(collection.iterator()) | ||
packages_dict = { | ||
f"{package.properties['name']}/{package.properties['type']}": { | ||
"status": package.properties["status"], | ||
"description": package.properties["description"] | ||
} for package in existing_packages | ||
} | ||
packages_dict[key] = value | ||
|
||
for json_file in json_files: | ||
with open(json_file, "r") as f: | ||
print("Adding data from", json_file) | ||
|
||
# temporary, just for testing | ||
with collection.batch.dynamic() as batch: | ||
for json_file in self.json_files: | ||
with open(json_file, "r") as f: | ||
print("Adding data from", json_file) | ||
packages_to_insert = [] | ||
for line in f: | ||
package = json.loads(line) | ||
package["status"] = json_file.split('/')[-1].split('.')[0] | ||
key = f"{package['name']}/{package['type']}" | ||
|
||
# now add the status column | ||
if "archived" in json_file: | ||
package["status"] = "archived" | ||
elif "deprecated" in json_file: | ||
package["status"] = "deprecated" | ||
elif "malicious" in json_file: | ||
package["status"] = "malicious" | ||
else: | ||
package["status"] = "unknown" | ||
|
||
# check for the existing package and only add if different | ||
key = package["name"] + "/" + package["type"] | ||
if key in packages_dict: | ||
if ( | ||
packages_dict[key]["status"] == package["status"] | ||
and packages_dict[key]["description"] | ||
== package["description"] | ||
): | ||
print("Package already exists", key) | ||
continue | ||
|
||
# prepare the object for embedding | ||
print("Generating data for", key) | ||
vector_str = generate_vector_string(package) | ||
vector = generate_embeddings(vector_str) | ||
|
||
batch.add_object(properties=package, vector=vector) | ||
if key in packages_dict and packages_dict[key] == { | ||
"status": package["status"], | ||
"description": package["description"] | ||
}: | ||
print("Package already exists", key) | ||
continue | ||
|
||
vector_str = self.generate_vector_string(package) | ||
vector = await self.inference_engine.embed(self.model_path, [vector_str]) | ||
packages_to_insert.append((package, vector[0])) | ||
|
||
def run_import(): | ||
client = weaviate.WeaviateClient( | ||
embedded_options=EmbeddedOptions( | ||
persistence_data_path="./weaviate_data", grpc_port=50052 | ||
), | ||
) | ||
with client: | ||
client.connect() | ||
print("is_ready:", client.is_ready()) | ||
# Synchronous batch insert after preparing all data | ||
with collection.batch.dynamic() as batch: | ||
for package, vector in packages_to_insert: | ||
batch.add_object(properties=package, vector=vector, uuid=generate_uuid5(package)) | ||
|
||
setup_schema(client) | ||
add_data(client) | ||
async def run_import(self): | ||
self.setup_schema() | ||
await self.add_data() | ||
|
||
|
||
if __name__ == "__main__": | ||
run_import() | ||
importer = PackageImporter() | ||
asyncio.run(importer.run_import()) | ||
try: | ||
assert importer.client.is_live() | ||
pass | ||
finally: | ||
importer.client.close() |