Skip to content

Commit

Permalink
feat: convert project_id to string, no matter what the service returns
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasKoehneckeAA committed Jan 24, 2025
1 parent 2fbeee1 commit 7b9a9ba
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
30 changes: 17 additions & 13 deletions src/intelligence_layer/connectors/studio/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from uuid import uuid4

import requests
from pydantic import BaseModel, Field, RootModel
from pydantic import BaseModel, Field, RootModel, field_validator
from requests.exceptions import ConnectionError, MissingSchema

from intelligence_layer.connectors import JsonSerializable
Expand Down Expand Up @@ -96,7 +96,7 @@ class PostBenchmarkRequest(BaseModel):

class GetBenchmarkResponse(BaseModel):
id: str
project_id: int
project_id: str
dataset_id: str
name: str
description: str | None
Expand All @@ -109,6 +109,10 @@ class GetBenchmarkResponse(BaseModel):
created_by: str | None
updated_by: str | None

@field_validator("project_id", mode="before")
def transform_id_to_str(cls, value) -> str:
return str(value)


class PostBenchmarkExecution(BaseModel):
name: str
Expand Down Expand Up @@ -226,7 +230,7 @@ def __init__(
self.url = StudioClient.get_url(studio_url)
self._check_connection()
self._project_name = project
self._project_id: int | None = None
self._project_id: str | None = None

if create_project:
project_id = self._get_project(self._project_name)
Expand Down Expand Up @@ -256,7 +260,7 @@ def _check_connection(self) -> None:
) from None

@property
def project_id(self) -> int:
def project_id(self) -> str:
if self._project_id is None:
project_id = self._get_project(self._project_name)
if project_id is None:
Expand All @@ -266,7 +270,7 @@ def project_id(self) -> int:
self._project_id = project_id
return self._project_id

def _get_project(self, project: str) -> int | None:
def _get_project(self, project_name: str) -> str | None:
url = urljoin(self.url, "/api/projects")
response = requests.get(
url,
Expand All @@ -276,24 +280,24 @@ def _get_project(self, project: str) -> int | None:
all_projects = response.json()
try:
project_of_interest = next(
proj for proj in all_projects if proj["name"] == project
proj for proj in all_projects if proj["name"] == project_name
)
return int(project_of_interest["id"])
return str(project_of_interest["id"])
except StopIteration:
return None

def create_project(
self,
project: str,
project_name: str,
description: Optional[str] = None,
reuse_existing: bool = False,
) -> int:
) -> str:
"""Creates a project in Studio.
Projects are uniquely identified by the user provided name.
Args:
project: User provided name of the project.
project_name: User provided name of the project.
description: Description explaining the usage of the project. Defaults to None.
reuse_existing: Reuse project with specified name if already existing. Defaults to False.
Expand All @@ -302,7 +306,7 @@ def create_project(
The ID of the newly created project.
"""
url = urljoin(self.url, "/api/projects")
data = StudioProject(name=project, description=description)
data = StudioProject(name=project_name, description=description)
response = requests.post(
url,
data=data.model_dump_json(),
Expand All @@ -311,15 +315,15 @@ def create_project(
match response.status_code:
case 409:
if reuse_existing:
fetched_project = self._get_project(project)
fetched_project = self._get_project(project_name)
assert (
fetched_project is not None
), "Project already exists but not allowed to be used."
return fetched_project
raise ValueError("Project already exists")
case _:
response.raise_for_status()
return int(response.text)
return response.text

def submit_trace(self, data: Sequence[ExportedSpan]) -> str:
"""Sends the provided spans to Studio as a singular trace.
Expand Down
2 changes: 1 addition & 1 deletion tests/evaluation/benchmark/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_benchmark_response(datatset_id: str) -> GetBenchmarkResponse:
aggregation_identifier = create_aggregation_logic_identifier(aggregation_logic)
return GetBenchmarkResponse(
id="id",
project_id=0,
project_id=str(uuid4()),
dataset_id=datatset_id,
name="name",
description="description",
Expand Down

0 comments on commit 7b9a9ba

Please sign in to comment.