Commit

load_models now checks for the device before loading the BGE or NLI model into memory; it was previously defaulting to CPU. Also removes the load_sql gunk (#119)

Co-authored-by: Salman Paracha <[email protected]>
salmanap and Salman Paracha authored Oct 4, 2024
1 parent 093891b commit 7011874
Showing 1 changed file with 17 additions and 28 deletions.
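
For context on the fix, the device preference this commit introduces can be reproduced with a few lines of PyTorch. This is a stand-alone sketch of the selection order (CUDA, then Apple MPS, then CPU), not part of the diff below:

import torch

# Preference order added by this commit: CUDA if present, then Apple's MPS
# backend, otherwise fall back to CPU. The commit message notes the loaders
# were previously ending up on CPU.
if torch.cuda.is_available():
    selected = "cuda"
elif torch.backends.mps.is_available():
    selected = "mps"
else:
    selected = "cpu"

print(f"Would load models on: {selected}")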
45 changes: 17 additions & 28 deletions model_server/app/load_models.py
@@ -2,23 +2,28 @@
 import sentence_transformers
 from transformers import AutoTokenizer, pipeline
-import sqlite3
-from app.employee_data_generator import generate_employee_data
-from app.network_data_generator import (
-    generate_device_data,
-    generate_interface_stats_data,
-    generate_flow_data,
-)
+import torch

+def get_device():
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = "cpu"
+
+    return device

 def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
     transformers = {}
+    device = get_device()

+    print(f"Using device: {device}")
     for model in models.split(","):
-        transformers[model] = sentence_transformers.SentenceTransformer(model)
+        transformers[model] = sentence_transformers.SentenceTransformer(model, device=device)

     return transformers


 def load_guard_model(
     model_name,
     hardware_config="cpu",
@@ -52,27 +52,11 @@ def load_zero_shot_models(
     models=os.getenv("ZERO_SHOT_MODELS", "tasksource/deberta-base-long-nli")
 ):
     zero_shot_models = {}

+    device = get_device()
     for model in models.split(","):
-        zero_shot_models[model] = pipeline("zero-shot-classification", model=model)
+        zero_shot_models[model] = pipeline("zero-shot-classification", model=model, device=device)

     return zero_shot_models


-def load_sql():
-    # Example Usage
-    conn = sqlite3.connect(":memory:")
-
-    # create and load the employees table
-    generate_employee_data(conn)
-
-    # create and load the devices table
-    device_data = generate_device_data(conn)
-
-    # create and load the interface_stats table
-    generate_interface_stats_data(conn, device_data)
-
-    # create and load the flow table
-    generate_flow_data(conn, device_data)
-
-    return conn
+if __name__ =="__main__":
+    print(get_device())
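
A rough usage sketch of the updated loaders, assuming the module is importable as app.load_models (the path shown in this diff) and that the MODELS / ZERO_SHOT_MODELS environment variables keep the defaults above:

import os

# Defaults taken from the diff; override via the environment if desired.
os.environ.setdefault("MODELS", "BAAI/bge-large-en-v1.5")
os.environ.setdefault("ZERO_SHOT_MODELS", "tasksource/deberta-base-long-nli")

from app.load_models import get_device, load_transformers, load_zero_shot_models

print(get_device())  # "cuda", "mps", or "cpu"

# Both loaders now place their models on the detected device.
embedders = load_transformers()
zero_shot = load_zero_shot_models()

vec = embedders["BAAI/bge-large-en-v1.5"].encode("hello world")
print(vec.shape)

result = zero_shot["tasksource/deberta-base-long-nli"](
    "The request should be routed to the billing workflow",
    candidate_labels=["billing", "networking", "weather"],
)
print(result["labels"][0], result["scores"][0])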
