Commit
fix errors with certain image dimensions
avoonix committed Aug 1, 2024
1 parent b2b383c commit 9605609
Showing 3 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion inference/Dockerfile
@@ -7,4 +7,4 @@ RUN pip install --no-cache-dir -r requirements.txt
 
 COPY . .
 
-CMD [ "python", "src/main.py" ]
+CMD [ "python", "-u", "src/main.py" ]
2 changes: 1 addition & 1 deletion inference/requirements.txt
@@ -2,4 +2,4 @@ grpcio==1.64.1
 Pillow==10.3.0
 protobuf==5.27.1
 torch==2.3.1
-transformers==4.41.2
+transformers==4.43.3
8 changes: 7 additions & 1 deletion inference/src/main.py
@@ -8,6 +8,8 @@
 from io import BytesIO
 import os
 
+print("starting setup")
+
 model_name = os.environ.get("FUZZLE_INFERENCE_MODEL_NAME")
 port = os.environ.get("FUZZLE_INFERENCE_PORT")
 if model_name is None:
@@ -18,8 +20,11 @@
     exit(1)
 
 processor = AutoProcessor.from_pretrained(model_name)
+print("processor initialized")
 model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)
+print("model initialized")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
+print("tokenizer initialized")
 
 
 # TODO: batch requests?
@@ -34,13 +39,14 @@ def TextEmbedding(self, request: inference_pb2.TextEmbeddingRequest, context):
     def ImageEmbedding(self, request: inference_pb2.ImageEmbeddingRequest, context):
         assert request.model == inference_pb2.ImageModel.CLIP_IMAGE
         image = Image.open(BytesIO(request.image))
-        inputs = processor(images=[image], return_tensors="pt")
+        inputs = processor(images=[image], return_tensors="pt", input_data_format="channels_last")
         with torch.no_grad():
             image_features = model.get_image_features(**inputs)
         return inference_pb2.EmbeddingResponse(embedding=image_features[0])
 
 
 def serve():
+    print("starting server")
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
     inference_pb2_grpc.add_GenerateServicer_to_server(Generator(), server)
     server.add_insecure_port("0.0.0.0:" + port)
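
The dimension errors named in the commit title come from channel-axis inference: when the layout is not given, the Hugging Face image processor guesses which axis holds the colour channels from the array shape, and for images whose height or width is 1 or 3 pixels that guess can land on a spatial axis. Images decoded with PIL are always (height, width, channels), so passing input_data_format="channels_last" states the layout explicitly and removes the ambiguity. A small illustrative sketch; the checkpoint name is assumed here, since the real service loads whatever FUZZLE_INFERENCE_MODEL_NAME points at:

from PIL import Image
from transformers import AutoProcessor

# Assumed CLIP checkpoint, used only to make the sketch self-contained.
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

# A 3-pixel-tall RGB image becomes an array of shape (3, 512, 3), a shape where
# automatic inference can mistake the height axis for the channel axis.
image = Image.new("RGB", (512, 3))

# Declaring the layout avoids the wrong guess for such awkward shapes.
inputs = processor(images=[image], return_tensors="pt", input_data_format="channels_last")
print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 3, 224, 224]) for this checkpoint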
