Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
maxdeichmann committed Oct 6, 2023
1 parent 4d04daf commit 4501e25
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 126 deletions.
4 changes: 2 additions & 2 deletions langfuse/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
DatasetRunItem,
DatasetStatus,
Error,
Generations,
LlmUsage,
MapValue,
MethodNotAllowedError,
NotFoundError,
Observation,
ObservationLevel,
Observations,
Score,
Scores,
Trace,
Expand Down Expand Up @@ -62,13 +62,13 @@
"DatasetRunItem",
"DatasetStatus",
"Error",
"Generations",
"LlmUsage",
"MapValue",
"MethodNotAllowedError",
"NotFoundError",
"Observation",
"ObservationLevel",
"Observations",
"Score",
"Scores",
"Trace",
Expand Down
5 changes: 3 additions & 2 deletions langfuse/api/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
from .dataset_items import CreateDatasetItemRequest
from .dataset_run_items import CreateDatasetRunItemRequest
from .datasets import CreateDatasetRequest
from .generations import Generations, UpdateGenerationRequest
from .generations import UpdateGenerationRequest
from .observations import Observations
from .score import CreateScoreRequest, Scores
from .span import UpdateSpanRequest
from .trace import CreateTraceRequest, Traces
Expand All @@ -61,13 +62,13 @@
"DatasetRunItem",
"DatasetStatus",
"Error",
"Generations",
"LlmUsage",
"MapValue",
"MethodNotAllowedError",
"NotFoundError",
"Observation",
"ObservationLevel",
"Observations",
"Score",
"Scores",
"Trace",
Expand Down
4 changes: 2 additions & 2 deletions langfuse/api/resources/generations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was auto-generated by Fern from our API Definition.

from .types import Generations, UpdateGenerationRequest
from .types import UpdateGenerationRequest

__all__ = ["Generations", "UpdateGenerationRequest"]
__all__ = ["UpdateGenerationRequest"]
86 changes: 0 additions & 86 deletions langfuse/api/resources/generations/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from ..commons.errors.unauthorized_error import UnauthorizedError
from ..commons.types.create_generation_request import CreateGenerationRequest
from ..commons.types.observation import Observation
from .types.generations import Generations
from .types.update_generation_request import UpdateGenerationRequest


Expand Down Expand Up @@ -109,48 +108,6 @@ def update(self, *, request: UpdateGenerationRequest) -> Observation:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)

def get(
self,
*,
page: typing.Optional[int] = None,
limit: typing.Optional[int] = None,
name: typing.Optional[str] = None,
user_id: typing.Optional[str] = None,
) -> Generations:
_response = httpx.request(
"GET",
urllib.parse.urljoin(f"{self._environment}/", "api/public/generations"),
params={"page": page, "limit": limit, "name": name, "userId": user_id},
headers=remove_none_from_headers(
{
"X-Langfuse-Sdk-Name": self.x_langfuse_sdk_name,
"X-Langfuse-Sdk-Version": self.x_langfuse_sdk_version,
"X-Langfuse-Public-Key": self.x_langfuse_public_key,
}
),
auth=(self._username, self._password)
if self._username is not None and self._password is not None
else None,
timeout=60,
)
if 200 <= _response.status_code < 300:
return pydantic.parse_obj_as(Generations, _response.json()) # type: ignore
if _response.status_code == 400:
raise Error(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 401:
raise UnauthorizedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 403:
raise AccessDeniedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 405:
raise MethodNotAllowedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 404:
raise NotFoundError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
try:
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)


class AsyncGenerationsClient:
def __init__(
Expand Down Expand Up @@ -241,46 +198,3 @@ async def update(self, *, request: UpdateGenerationRequest) -> Observation:
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)

async def get(
self,
*,
page: typing.Optional[int] = None,
limit: typing.Optional[int] = None,
name: typing.Optional[str] = None,
user_id: typing.Optional[str] = None,
) -> Generations:
async with httpx.AsyncClient() as _client:
_response = await _client.request(
"GET",
urllib.parse.urljoin(f"{self._environment}/", "api/public/generations"),
params={"page": page, "limit": limit, "name": name, "userId": user_id},
headers=remove_none_from_headers(
{
"X-Langfuse-Sdk-Name": self.x_langfuse_sdk_name,
"X-Langfuse-Sdk-Version": self.x_langfuse_sdk_version,
"X-Langfuse-Public-Key": self.x_langfuse_public_key,
}
),
auth=(self._username, self._password)
if self._username is not None and self._password is not None
else None,
timeout=60,
)
if 200 <= _response.status_code < 300:
return pydantic.parse_obj_as(Generations, _response.json()) # type: ignore
if _response.status_code == 400:
raise Error(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 401:
raise UnauthorizedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 403:
raise AccessDeniedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 405:
raise MethodNotAllowedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 404:
raise NotFoundError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
try:
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)
3 changes: 1 addition & 2 deletions langfuse/api/resources/generations/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# This file was auto-generated by Fern from our API Definition.

from .generations import Generations
from .update_generation_request import UpdateGenerationRequest

__all__ = ["Generations", "UpdateGenerationRequest"]
__all__ = ["UpdateGenerationRequest"]
3 changes: 3 additions & 0 deletions langfuse/api/resources/observations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
# This file was auto-generated by Fern from our API Definition.

from .types import Observations

__all__ = ["Observations"]
88 changes: 88 additions & 0 deletions langfuse/api/resources/observations/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..commons.errors.not_found_error import NotFoundError
from ..commons.errors.unauthorized_error import UnauthorizedError
from ..commons.types.observation import Observation
from .types.observations import Observations


class ObservationsClient:
Expand Down Expand Up @@ -69,6 +70,49 @@ def get(self, observation_id: str) -> Observation:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)

def get_many(
self,
*,
page: typing.Optional[int] = None,
limit: typing.Optional[int] = None,
name: typing.Optional[str] = None,
user_id: typing.Optional[str] = None,
type: typing.Optional[str] = None,
) -> Observations:
_response = httpx.request(
"GET",
urllib.parse.urljoin(f"{self._environment}/", "api/public/observations"),
params={"page": page, "limit": limit, "name": name, "userId": user_id, "type": type},
headers=remove_none_from_headers(
{
"X-Langfuse-Sdk-Name": self.x_langfuse_sdk_name,
"X-Langfuse-Sdk-Version": self.x_langfuse_sdk_version,
"X-Langfuse-Public-Key": self.x_langfuse_public_key,
}
),
auth=(self._username, self._password)
if self._username is not None and self._password is not None
else None,
timeout=60,
)
if 200 <= _response.status_code < 300:
return pydantic.parse_obj_as(Observations, _response.json()) # type: ignore
if _response.status_code == 400:
raise Error(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 401:
raise UnauthorizedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 403:
raise AccessDeniedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 405:
raise MethodNotAllowedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 404:
raise NotFoundError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
try:
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)


class AsyncObservationsClient:
def __init__(
Expand Down Expand Up @@ -122,3 +166,47 @@ async def get(self, observation_id: str) -> Observation:
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)

async def get_many(
self,
*,
page: typing.Optional[int] = None,
limit: typing.Optional[int] = None,
name: typing.Optional[str] = None,
user_id: typing.Optional[str] = None,
type: typing.Optional[str] = None,
) -> Observations:
async with httpx.AsyncClient() as _client:
_response = await _client.request(
"GET",
urllib.parse.urljoin(f"{self._environment}/", "api/public/observations"),
params={"page": page, "limit": limit, "name": name, "userId": user_id, "type": type},
headers=remove_none_from_headers(
{
"X-Langfuse-Sdk-Name": self.x_langfuse_sdk_name,
"X-Langfuse-Sdk-Version": self.x_langfuse_sdk_version,
"X-Langfuse-Public-Key": self.x_langfuse_public_key,
}
),
auth=(self._username, self._password)
if self._username is not None and self._password is not None
else None,
timeout=60,
)
if 200 <= _response.status_code < 300:
return pydantic.parse_obj_as(Observations, _response.json()) # type: ignore
if _response.status_code == 400:
raise Error(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 401:
raise UnauthorizedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 403:
raise AccessDeniedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 405:
raise MethodNotAllowedError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
if _response.status_code == 404:
raise NotFoundError(pydantic.parse_obj_as(str, _response.json())) # type: ignore
try:
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)
5 changes: 5 additions & 0 deletions langfuse/api/resources/observations/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# This file was auto-generated by Fern from our API Definition.

from .observations import Observations

__all__ = ["Observations"]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ...utils.resources.pagination.types.meta_response import MetaResponse


class Generations(pydantic.BaseModel):
class Observations(pydantic.BaseModel):
data: typing.List[Observation]
meta: MetaResponse

Expand Down
4 changes: 3 additions & 1 deletion langfuse/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def get_generations(
):
try:
self.log.debug(f"Getting generations... {page}, {limit}, {name}, {user_id}")
return self.client.generations.get(page=page, limit=limit, name=name, user_id=user_id)
return self.client.observations.get_many(
page=page, limit=limit, name=name, user_id=user_id, type="GENERATION"
)
except Exception as e:
self.log.exception(e)
raise e
Expand Down
12 changes: 5 additions & 7 deletions tests/test_core_sdk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
import logging

from langfuse import Langfuse
from langfuse.model import (
Expand Down Expand Up @@ -533,12 +534,10 @@ def test_get_generations():

langfuse.flush()
generations = langfuse.get_generations(name=generation_name, limit=10, page=1)

assert len(generations.data) == 1
assert generations.data[0].name == generation_name
# TODO: add back in
# assert generations.data[0].input == "great-prompt"
# assert generations.data[0].completion == "great-completion"
assert generations.data[0].input == "great-prompt"
assert generations.data[0].output == {"completion": "great-completion"}


def test_get_generations_by_user():
Expand Down Expand Up @@ -573,6 +572,5 @@ def test_get_generations_by_user():
print(generations)
assert len(generations.data) == 1
assert generations.data[0].name == generation_name
# TODO: add back in
# assert generations.data[0].input == "great-prompt"
# assert generations.data[0].output == "great-completion"
assert generations.data[0].input == "great-prompt"
assert generations.data[0].output == {"completion": "great-completion"}
16 changes: 14 additions & 2 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_linking_observation():
assert run.dataset_run_items[0].observation_id == generation_id


@pytest.mark.skip(reason="inference cost")
# @pytest.mark.skip(reason="inference cost")
def test_langchain_dataset():
langfuse = Langfuse(debug=True)
dataset_name = create_uuid()
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_langchain_dataset():
assert run.dataset_run_items[0].dataset_run_id == run.id

api = get_api()
api.generations.get()

trace = api.trace.get(handler.get_trace_id())

assert len(trace.observations) == 3
Expand All @@ -122,6 +122,18 @@ def test_langchain_dataset():
"dataset_id": dataset.id,
}

generations = list(filter(lambda obs: obs.type == "GENERATION", sorted_observations))
print(generations)
assert len(generations) > 0
for generation in generations:
assert generation.input is not None
assert generation.output is not None
assert generation.input != ""
assert generation.output != ""
assert generation.total_tokens is not None
assert generation.prompt_tokens is not None
assert generation.completion_tokens is not None


def sorted_dependencies(
observations: List[Observation],
Expand Down
Loading

0 comments on commit 4501e25

Please sign in to comment.