Skip to content

Commit

Permalink
Fix serialization and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Dec 15, 2023
1 parent 42a0d6d commit 32cccef
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 57 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from typing import Any, Dict, List, Optional

from haystack.core.component import component
Expand Down Expand Up @@ -35,18 +36,29 @@ def __init__(

self._model_name = model
self._project_id = project_id
self._api_key = api_key
self._location = location
self._kwargs = kwargs

self._model = ImageTextModel.from_pretrained(self._model_name)

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs
data = default_to_dict(
self,
model=self._model_name,
project_id=self._project_id,
api_key=self._api_key,
location=self._location,
**self._kwargs,
)
if data["init_parameters"].get("api_key"):
data["init_parameters"]["api_key"] = "GOOGLE_API_KEY"
return data

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageCaptioner":
if (api_key := data["init_parameters"].get("api_key")) in os.environ:
data["init_parameters"]["api_key"] = os.environ[api_key]
return default_from_dict(cls, data)

@component.output_types(captions=List[str])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from typing import Any, Dict, List, Optional

from haystack.core.component import component
Expand Down Expand Up @@ -35,18 +36,29 @@ def __init__(

self._model_name = model
self._project_id = project_id
self._api_key = api_key
self._location = location
self._kwargs = kwargs

self._model = CodeGenerationModel.from_pretrained(self._model_name)

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs
data = default_to_dict(
self,
model=self._model_name,
project_id=self._project_id,
api_key=self._api_key,
location=self._location,
**self._kwargs,
)
if data["init_parameters"].get("api_key"):
data["init_parameters"]["api_key"] = "GOOGLE_API_KEY"
return data

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "VertexAICodeGenerator":
if (api_key := data["init_parameters"].get("api_key")) in os.environ:
data["init_parameters"]["api_key"] = os.environ[api_key]
return default_from_dict(cls, data)

@component.output_types(answers=List[str])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from typing import Any, Dict, List, Optional

from haystack.core.component import component
Expand Down Expand Up @@ -36,17 +37,28 @@ def __init__(
self._model_name = model
self._project_id = project_id
self._location = location
self._api_key = api_key
self._kwargs = kwargs

self._model = ImageGenerationModel.from_pretrained(self._model_name)

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs
data = default_to_dict(
self,
model=self._model_name,
project_id=self._project_id,
api_key=self._api_key,
location=self._location,
**self._kwargs,
)
if data["init_parameters"].get("api_key"):
data["init_parameters"]["api_key"] = "GOOGLE_API_KEY"
return data

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageGenerator":
if (api_key := data["init_parameters"].get("api_key")) in os.environ:
data["init_parameters"]["api_key"] = os.environ[api_key]
return default_from_dict(cls, data)

@component.output_types(images=List[ByteStream])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from typing import Any, Dict, List, Optional

from haystack.core.component import component
Expand Down Expand Up @@ -35,18 +36,29 @@ def __init__(

self._model_name = model
self._project_id = project_id
self._api_key = api_key
self._location = location
self._kwargs = kwargs

self._model = ImageTextModel.from_pretrained(self._model_name)

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs
data = default_to_dict(
self,
model=self._model_name,
project_id=self._project_id,
api_key=self._api_key,
location=self._location,
**self._kwargs,
)
if data["init_parameters"].get("api_key"):
data["init_parameters"]["api_key"] = "GOOGLE_API_KEY"
return data

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageQA":
if (api_key := data["init_parameters"].get("api_key")) in os.environ:
data["init_parameters"]["api_key"] = os.environ[api_key]
return default_from_dict(cls, data)

@component.output_types(answers=List[str])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
import logging
import os
from dataclasses import fields
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -36,14 +37,20 @@ def __init__(

self._model_name = model
self._project_id = project_id
self._api_key = api_key
self._location = location
self._kwargs = kwargs

self._model = TextGenerationModel.from_pretrained(self._model_name)

def to_dict(self) -> Dict[str, Any]:
data = default_to_dict(
self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs
self,
model=self._model_name,
project_id=self._project_id,
api_key=self._api_key,
location=self._location,
**self._kwargs,
)

if (grounding_source := data["init_parameters"].get("grounding_source")) is not None:
Expand All @@ -54,6 +61,8 @@ def to_dict(self) -> Dict[str, Any]:
"type": class_type,
"init_parameters": init_fields,
}
if data["init_parameters"].get("api_key"):
data["init_parameters"]["api_key"] = "GOOGLE_API_KEY"

return data

Expand All @@ -65,6 +74,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator":
data["init_parameters"]["grounding_source"] = getattr(module, class_name)(
**grounding_source["init_parameters"]
)
if (api_key := data["init_parameters"].get("api_key")) in os.environ:
data["init_parameters"]["api_key"] = os.environ[api_key]
return default_from_dict(cls, data)

@component.output_types(answers=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]])
Expand Down
25 changes: 15 additions & 10 deletions integrations/google-vertex/tests/test_captioner.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
from unittest.mock import Mock, patch

import pytest
from haystack.dataclasses.byte_stream import ByteStream

from google_vertex_haystack.generators.captioner import VertexAIImageCaptioner


@patch("google_vertex_haystack.generators.captioner.vertexai")
@patch("google_vertex_haystack.generators.captioner.authenticate")
@patch("google_vertex_haystack.generators.captioner.ImageTextModel")
def test_init(mock_model_class, mock_vertexai):
def test_init(mock_model_class, mock_authenticate):
captioner = VertexAIImageCaptioner(
model="imagetext", project_id="myproject-123456", number_of_results=1, language="it"
model="imagetext", project_id="myproject-123456", api_key="my_api_key", number_of_results=1, language="it"
)
mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None)
mock_authenticate.assert_called_once_with(project_id="myproject-123456", api_key="my_api_key", location=None)
mock_model_class.from_pretrained.assert_called_once_with("imagetext")
assert captioner._model_name == "imagetext"
assert captioner._project_id == "myproject-123456"
assert captioner._api_key == "my_api_key"
assert captioner._location is None
assert captioner._kwargs == {"number_of_results": 1, "language": "it"}


@patch("google_vertex_haystack.generators.captioner.vertexai")
@patch("google_vertex_haystack.generators.captioner.authenticate")
@patch("google_vertex_haystack.generators.captioner.ImageTextModel")
def test_to_dict(_mock_model_class, _mock_vertexai):
def test_to_dict(_mock_model_class, _mock_authenticate):
captioner = VertexAIImageCaptioner(
model="imagetext", project_id="myproject-123456", number_of_results=1, language="it"
)
Expand All @@ -30,37 +32,40 @@ def test_to_dict(_mock_model_class, _mock_vertexai):
"init_parameters": {
"model": "imagetext",
"project_id": "myproject-123456",
"api_key": "",
"location": None,
"number_of_results": 1,
"language": "it",
},
}


@patch("google_vertex_haystack.generators.captioner.vertexai")
@patch("google_vertex_haystack.generators.captioner.authenticate")
@patch("google_vertex_haystack.generators.captioner.ImageTextModel")
def test_from_dict(_mock_model_class, _mock_vertexai):
def test_from_dict(_mock_model_class, _mock_authenticate):
captioner = VertexAIImageCaptioner.from_dict(
{
"type": "google_vertex_haystack.generators.captioner.VertexAIImageCaptioner",
"init_parameters": {
"model": "imagetext",
"project_id": "myproject-123456",
"api_key": "",
"number_of_results": 1,
"language": "it",
},
}
)
assert captioner._model_name == "imagetext"
assert captioner._project_id == "myproject-123456"
assert captioner._api_key == ""
assert captioner._location is None
assert captioner._kwargs == {"number_of_results": 1, "language": "it"}
assert captioner._model is not None


@patch("google_vertex_haystack.generators.captioner.vertexai")
@patch("google_vertex_haystack.generators.captioner.authenticate")
@patch("google_vertex_haystack.generators.captioner.ImageTextModel")
def test_run_calls_get_captions(mock_model_class, _mock_vertexai):
def test_run_calls_get_captions(mock_model_class, _mock_authenticate):
mock_model = Mock()
mock_model_class.from_pretrained.return_value = mock_model
captioner = VertexAIImageCaptioner(
Expand Down
24 changes: 14 additions & 10 deletions integrations/google-vertex/tests/test_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@
from google_vertex_haystack.generators.code_generator import VertexAICodeGenerator


@patch("google_vertex_haystack.generators.code_generator.vertexai")
@patch("google_vertex_haystack.generators.code_generator.authenticate")
@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel")
def test_init(mock_model_class, mock_vertexai):
def test_init(mock_model_class, mock_authenticate):
generator = VertexAICodeGenerator(
model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5
model="code-bison", project_id="myproject-123456", api_key="my_api_key", candidate_count=3, temperature=0.5
)
mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None)
mock_authenticate.assert_called_once_with(project_id="myproject-123456", api_key="my_api_key", location=None)
mock_model_class.from_pretrained.assert_called_once_with("code-bison")
assert generator._model_name == "code-bison"
assert generator._project_id == "myproject-123456"
assert generator._api_key == "my_api_key"
assert generator._location is None
assert generator._kwargs == {"candidate_count": 3, "temperature": 0.5}


@patch("google_vertex_haystack.generators.code_generator.vertexai")
@patch("google_vertex_haystack.generators.code_generator.authenticate")
@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel")
def test_to_dict(_mock_model_class, _mock_vertexai):
def test_to_dict(_mock_model_class, _mock_authenticate):
generator = VertexAICodeGenerator(
model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5
)
Expand All @@ -31,36 +32,39 @@ def test_to_dict(_mock_model_class, _mock_vertexai):
"model": "code-bison",
"project_id": "myproject-123456",
"location": None,
"api_key": "",
"candidate_count": 3,
"temperature": 0.5,
},
}


@patch("google_vertex_haystack.generators.code_generator.vertexai")
@patch("google_vertex_haystack.generators.code_generator.authenticate")
@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel")
def test_from_dict(_mock_model_class, _mock_vertexai):
def test_from_dict(_mock_model_class, _mock_authenticate):
generator = VertexAICodeGenerator.from_dict(
{
"type": "google_vertex_haystack.generators.code_generator.VertexAICodeGenerator",
"init_parameters": {
"model": "code-bison",
"project_id": "myproject-123456",
"api_key": "",
"candidate_count": 2,
"temperature": 0.5,
},
}
)
assert generator._model_name == "code-bison"
assert generator._project_id == "myproject-123456"
assert generator._api_key == ""
assert generator._location is None
assert generator._kwargs == {"candidate_count": 2, "temperature": 0.5}
assert generator._model is not None


@patch("google_vertex_haystack.generators.code_generator.vertexai")
@patch("google_vertex_haystack.generators.code_generator.authenticate")
@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel")
def test_run_calls_predict(mock_model_class, _mock_vertexai):
def test_run_calls_predict(mock_model_class, _mock_authenticate):
mock_model = Mock()
mock_model.predict.return_value = TextGenerationResponse("answer", None)
mock_model_class.from_pretrained.return_value = mock_model
Expand Down
Loading

0 comments on commit 32cccef

Please sign in to comment.