Skip to content

Commit

Permalink
feat(proofs): New endpoint to extract price data from images (with Ge…
Browse files Browse the repository at this point in the history
…mini) (#557)
  • Loading branch information
TTalex authored Nov 13, 2024
1 parent 1337015 commit fa54a36
Show file tree
Hide file tree
Showing 8 changed files with 504 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/container-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ jobs:
echo "POSTGRES_PASSWORD=${{ secrets.POSTGRES_PASSWORD }}" >> .env
echo "ENVIRONMENT=${{ env.ENVIRONMENT }}" >> .env
echo "GOOGLE_CLOUD_VISION_API_KEY=${{ secrets.GOOGLE_CLOUD_VISION_API_KEY }}" >> .env
echo "GOOGLE_GEMINI_API_KEY=${{ secrets.GOOGLE_GEMINI_API_KEY }}" >> .env
- name: Create Docker volumes
uses: appleboy/ssh-action@master
Expand Down
5 changes: 5 additions & 0 deletions config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,8 @@
# ------------------------------------------------------------------------------

GOOGLE_CLOUD_VISION_API_KEY = os.getenv("GOOGLE_CLOUD_VISION_API_KEY")

# Google Gemini API
# ------------------------------------------------------------------------------

GOOGLE_GEMINI_API_KEY = os.getenv("GOOGLE_GEMINI_API_KEY")
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ x-api-common: &api-common
- SENTRY_DSN
- LOG_LEVEL
- GOOGLE_CLOUD_VISION_API_KEY
- GOOGLE_GEMINI_API_KEY
networks:
- default

Expand Down
9 changes: 9 additions & 0 deletions open_prices/api/proofs/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ class ProofUpdateSerializer(serializers.ModelSerializer):
class Meta:
model = Proof
fields = Proof.UPDATE_FIELDS


class ProofProcessWithGeminiSerializer(serializers.Serializer):
files = serializers.ListField(
child=serializers.FileField(required=True, use_url=False)
)
mode = (
serializers.CharField()
) # TODO: this mode param should be used to select the prompt to execute, unimplemented for now
16 changes: 16 additions & 0 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import PIL.Image
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.utils import extend_schema
from rest_framework import filters, mixins, status, viewsets
Expand All @@ -11,11 +12,13 @@
from open_prices.api.proofs.serializers import (
ProofCreateSerializer,
ProofFullSerializer,
ProofProcessWithGeminiSerializer,
ProofUpdateSerializer,
ProofUploadSerializer,
)
from open_prices.api.utils import get_source_from_request
from open_prices.common.authentication import CustomAuthentication
from open_prices.common.gemini import handle_bulk_labels
from open_prices.proofs.models import Proof
from open_prices.proofs.utils import store_file

Expand Down Expand Up @@ -94,3 +97,16 @@ def upload(self, request: Request) -> Response:
proof = serializer.save(owner=self.request.user.user_id, source=source)
# return full proof
return Response(ProofFullSerializer(proof).data, status=status.HTTP_201_CREATED)

@extend_schema(request=ProofProcessWithGeminiSerializer)
@action(
detail=False,
methods=["POST"],
url_path="process_with_gemini",
parser_classes=[MultiPartParser],
)
def process_with_gemini(self, request: Request) -> Response:
files = request.FILES.getlist("files")
sample_files = [PIL.Image.open(file.file) for file in files]
res = handle_bulk_labels(sample_files)
return Response(res, status=status.HTTP_200_OK)
141 changes: 141 additions & 0 deletions open_prices/common/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import enum
import json

import google.generativeai as genai
import typing_extensions as typing
from django.conf import settings

genai.configure(api_key=settings.GOOGLE_GEMINI_API_KEY)
model = genai.GenerativeModel(model_name="gemini-1.5-flash")


# TODO: what about orther categories ?
class Products(enum.Enum):
OTHER = "other"
APPLES = "en:apples"
APRICOTS = "en:apricots"
ARTICHOKES = "en:artichokes"
ASPARAGUS = "en:asparagus"
AUBERGINES = "en:aubergines"
AVOCADOS = "en:avocados"
BANANAS = "en:bananas"
BEET = "en:beet"
BERRIES = "en:berries"
BLACKBERRIES = "en:blackberries"
BLUEBERRIES = "en:blueberries"
BOK_CHOY = "en:bok-choy"
BROCCOLI = "en:broccoli"
CABBAGES = "en:cabbages"
CARROTS = "en:carrots"
CAULIFLOWERS = "en:cauliflowers"
CELERY = "en:celery"
CELERY_STALK = "en:celery-stalk"
CEP_MUSHROOMS = "en:cep-mushrooms"
CHANTERELLES = "en:chanterelles"
CHERRIES = "en:cherries"
CHERRY_TOMATOES = "en:cherry-tomatoes"
CHICKPEAS = "en:chickpeas"
CHIVES = "en:chives"
CLEMENTINES = "en:clementines"
COCONUTS = "en:coconuts"
CRANBERRIES = "en:cranberries"
CUCUMBERS = "en:cucumbers"
DATES = "en:dates"
ENDIVES = "en:endives"
FIGS = "en:figs"
GARLIC = "en:garlic"
GINGER = "en:ginger"
GRAPEFRUITS = "en:grapefruits"
GRAPES = "en:grapes"
GREEN_BEANS = "en:green-beans"
KIWIS = "en:kiwis"
KAKIS = "en:kakis"
LEEKS = "en:leeks"
LEMONS = "en:lemons"
LETTUCES = "en:lettuces"
LIMES = "en:limes"
LYCHEES = "en:lychees"
MANDARIN_ORANGES = "en:mandarin-oranges"
MANGOES = "en:mangoes"
MELONS = "en:melons"
MUSHROOMS = "en:mushrooms"
NECTARINES = "en:nectarines"
ONIONS = "en:onions"
ORANGES = "en:oranges"
PAPAYAS = "en:papayas"
PASSION_FRUITS = "en:passion-fruits"
PEACHES = "en:peaches"
PEARS = "en:pears"
PEAS = "en:peas"
PEPPERS = "en:peppers"
PINEAPPLE = "en:pineapple"
PLUMS = "en:plums"
POMEGRANATES = "en:pomegranates"
POMELOS = "en:pomelos"
POTATOES = "en:potatoes"
PUMPKINS = "en:pumpkins"
RADISHES = "en:radishes"
RASPBERRIES = "en:raspberries"
RHUBARBS = "en:rhubarbs"
SCALLIONS = "en:scallions"
SHALLOTS = "en:shallots"
SPINACHS = "en:spinachs"
SPROUTS = "en:sprouts"
STRAWBERRIES = "en:strawberries"
TOMATOES = "en:tomatoes"
TURNIP = "en:turnip"
WATERMELONS = "en:watermelons"
WALNUTS = "en:walnuts"
ZUCCHINI = "en:zucchini"


# TODO: what about other origins ?
class Origin(enum.Enum):
FRANCE = "en:france"
ITALY = "en:italy"
SPAIN = "en:spain"
POLAND = "en:poland"
CHINA = "en:china"
BELGIUM = "en:belgium"
MOROCCO = "en:morocco"
PERU = "en:peru"
PORTUGAL = "en:portugal"
MEXICO = "en:mexico"
OTHER = "other"
UNKNOWN = "unknown"


class Unit(enum.Enum):
KILOGRAM = "KILOGRAM"
UNIT = "UNIT"


class Label(typing.TypedDict):
product: Products
price: float
origin: Origin
unit: Unit
organic: bool
barcode: str


class Labels(typing.TypedDict):
labels: list[Label]


def handle_bulk_labels(images):
response = model.generate_content(
[
"Here are "
+ str(len(images))
+ " pictures containing a label. For each picture of a label, please extract all the following attributes: the product category matching product name, the origin category matching country of origin, the price, is the product organic, the unit (per KILOGRAM or per UNIT) and the barcode. I expect a list of "
+ str(len(images))
+ " labels in your reply, no more, no less. If you cannot decode an attribute, set it to an empty string"
]
+ images,
generation_config=genai.GenerationConfig(
response_mime_type="application/json", response_schema=Labels
),
)
vals = json.loads(response.text)
return vals
Loading

0 comments on commit fa54a36

Please sign in to comment.