Skip to content

Commit

Permalink
Cloud load model (#31)
Browse files · Browse the repository at this point in the history
* added load model to api

* removed hydra from dockerfile

* fix

* fixes

* path fixes

---------

Co-authored-by: Konstantina <[email protected]>
  • Loading branch information
artemdou and ntina10 authored Jan 24, 2025
1 parent f80f093 commit dc357b2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
18 changes: 14 additions & 4 deletions dockerfiles/api.dockerfile
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
# Use Python 3.11-slim as the base image
FROM python:3.11-slim AS base

# Install build tools needed to compile any C-extension wheels,
# then clean the apt cache to keep the image small
RUN apt update && \
    apt install --no-install-recommends -y build-essential gcc && \
    apt clean && rm -rf /var/lib/apt/lists/*

# Set the working directory; all subsequent COPY destinations are
# relative to /app
WORKDIR /app

# Copy model files
COPY models/ models/

# Copy the source code
COPY src/ src/

# Copy the other necessary files (pyproject.toml/README.md are needed
# by the `pip install .` step below)
COPY requirements.txt requirements.txt
COPY requirements_dev.txt requirements_dev.txt
COPY README.md README.md
COPY pyproject.toml pyproject.toml

# Install required Python packages, then the project itself
# (--no-deps: dependencies were already satisfied from requirements.txt)
RUN pip install -r requirements.txt --no-cache-dir --verbose
RUN pip install . --no-deps --no-cache-dir --verbose

# Set the entry point for the container: serve the FastAPI app on all
# interfaces, port 8000
ENTRYPOINT ["uvicorn", "src.final_project.api:app", "--host", "0.0.0.0", "--port", "8000"]
15 changes: 12 additions & 3 deletions src/final_project/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from pydantic import BaseModel
from transformers import AutoTokenizer
from omegaconf import OmegaConf
from google.cloud import storage

from src.final_project.model import AwesomeModel
from final_project import AwesomeModel

# Select the best available accelerator: CUDA first, then Apple-silicon
# MPS, falling back to CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
DEVICE = torch.device(_device_name)
Expand All @@ -28,14 +29,22 @@ async def lifespan(app: FastAPI):

print("Loading model")

config_path = ".hydra/config.yaml"
config_path = "src/final_project/config/model.yaml"
cfg = OmegaConf.load(config_path)

# ---------------------------------------------------------
# Instantiate AwesomeModel and load its weights
# ---------------------------------------------------------
model = AwesomeModel(cfg)
model.load_state_dict(torch.load("models/model.pth", map_location=DEVICE))
client = storage.Client()
bucket_name, blob_name = "gs://mlops-bucket-1999/models/model.pth".replace("gs://", "").split("/", 1)
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name)

with blob.open("rb") as f:
state_dict = torch.load(f, map_location=DEVICE)

model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

Expand Down

0 comments on commit dc357b2

Please sign in to comment.