Add send to process and debug on Ask #130

Merged · 2 commits · Dec 10, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -3,7 +3,7 @@
 ## 4.3.9 (unreleased)
 
 
-- Nothing changed yet.
+- Debug on ask and send to process
 
 
 ## 4.3.8 (2024-12-09)
38 changes: 38 additions & 0 deletions nuclia/sdk/resource.py
@@ -251,6 +251,25 @@ def update(
         else:
             raise ValueError("Either rid or slug must be provided")
 
+    @kb
+    def send_to_process(
+        self, *, rid: Optional[str] = None, slug: Optional[str] = None, **kwargs
+    ):
+        ndb = kwargs["ndb"]
+        kw = {
+            "kbid": ndb.kbid,
+        }
+        if rid:
+            kw["rid"] = rid
+            if slug:
+                kw["slug"] = slug
+            ndb.ndb.reprocess_resource(**kw)
+        elif slug:
+            kw["rslug"] = slug
+            ndb.ndb.reprocess_resource_by_slug(**kw)
+        else:
+            raise ValueError("Either rid or slug must be provided")
+
     @kb
     def delete(
         self, *, rid: Optional[str] = None, slug: Optional[str] = None, **kwargs
@@ -373,6 +392,25 @@ async def download_file(
                 if chunk:
                     f.write(chunk)
 
+    @kb
+    async def send_to_process(
+        self, *, rid: Optional[str] = None, slug: Optional[str] = None, **kwargs
+    ):
+        ndb = kwargs["ndb"]
+        kw = {
+            "kbid": ndb.kbid,
+        }
+        if rid:
+            kw["rid"] = rid
+            if slug:
+                kw["slug"] = slug
+            await ndb.ndb.reprocess_resource(**kw)
+        elif slug:
+            kw["rslug"] = slug
+            await ndb.ndb.reprocess_resource_by_slug(**kw)
+        else:
+            raise ValueError("Either rid or slug must be provided")
+
     @kb
     async def update(
         self, *, rid: Optional[str] = None, slug: Optional[str] = None, **kwargs
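Taken together, the two hunks above give both the sync and async resource clients a send_to_process verb that re-queues an already-stored resource for processing. A minimal usage sketch, assuming a default knowledge box is configured and that NucliaResource is the exported wrapper these methods live on (the rid and slug values are illustrative):

from nuclia.sdk import NucliaResource

# Re-queue an existing resource for processing; exactly one of rid or
# slug must be provided, mirroring the update()/delete() methods.
res = NucliaResource()
res.send_to_process(rid="f1d2d2f924e986ac86fdf7b36c94bcdf")  # by resource id
res.send_to_process(slug="my-document")                      # by slug
# res.send_to_process()  # would raise ValueError: neither rid nor slug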
58 changes: 54 additions & 4 deletions nuclia/sdk/search.py
@@ -17,6 +17,7 @@
     SearchOptions,
     SearchRequest,
     SyncAskResponse,
+    ChatModel,
 )
 from pydantic import ValidationError
 
@@ -40,6 +41,11 @@ class AskAnswer:
     timings: Optional[Dict[str, float]]
     tokens: Optional[Dict[str, int]]
     retrieval_best_matches: Optional[List[str]]
+    status: Optional[str]
+    prompt_context: Optional[list[str]]
+    relations: Optional[Relations]
+    predict_request: Optional[ChatModel]
+    error_details: Optional[str]
 
     def __str__(self):
         if self.answer:
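The new AskAnswer fields surface data the SDK previously discarded: the terminal status of the ask, error details, the relations graph, and, for debug requests, the prompt context and the ChatModel payload sent to the predict service. A sketch of how a caller might inspect them, assuming a configured knowledge box (the query is illustrative):

from nuclia.sdk import NucliaSearch

answer = NucliaSearch().ask(query="What is Nuclia?")
print(answer.status)          # terminal status reported by the backend
if answer.error_details:      # populated only when the ask failed
    print(answer.error_details)
if answer.prompt_context:     # debug-only: context chunks sent to the LLM
    print("\n".join(answer.prompt_context))
if answer.predict_request:    # debug-only: the validated ChatModel
    print(answer.predict_request.model_dump())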
@@ -189,12 +195,23 @@ def ask(
             timings=None,
             tokens=None,
             object=ask_response.answer_json,
+            status=ask_response.status,
+            error_details=ask_response.error_details,
+            predict_request=ChatModel.model_validate(ask_response.predict_request)
+            if ask_response.predict_request is not None
+            else None,
+            relations=ask_response.relations,
+            prompt_context=ask_response.prompt_context,
         )
+
+        if ask_response.prompt_context:
+            result.prompt_context = ask_response.prompt_context
         if ask_response.metadata is not None:
             if ask_response.metadata.timings is not None:
                 result.timings = ask_response.metadata.timings.model_dump()
             if ask_response.metadata.tokens is not None:
                 result.tokens = ask_response.metadata.tokens.model_dump()
 
         return result
 
     @kb
@@ -264,6 +281,13 @@ def ask_json(
             timings=None,
             tokens=None,
             object=ask_response.answer_json,
+            status=ask_response.status,
+            error_details=ask_response.error_details,
+            predict_request=ChatModel.model_validate(ask_response.predict_request)
+            if ask_response.predict_request is not None
+            else None,
+            relations=ask_response.relations,
+            prompt_context=ask_response.prompt_context,
         )
         if ask_response.metadata is not None:
             if ask_response.metadata.timings is not None:
@@ -399,6 +423,11 @@ async def ask(
             timings=None,
             tokens=None,
             object=None,
+            status=None,
+            error_details=None,
+            predict_request=None,
+            relations=None,
+            prompt_context=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -422,8 +451,16 @@
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
             elif ask_response_item.type == "status":
-                # Status is ignored
-                pass
+                result.status = ask_response_item.status
+            elif ask_response_item.type == "prequeries":
+                result.prequeries = ask_response_item.results
+            elif ask_response_item.type == "error":
+                result.error_details = ask_response_item.error
+            elif ask_response_item.type == "debug":
+                result.prompt_context = ask_response_item.metadata.get("prompt_context")
+                result.predict_request = ask_response_item.metadata.get(
+                    "predict_request"
+                )
             else:  # pragma: no cover
                 warnings.warn(f"Unknown ask stream item type: {ask_response_item.type}")
         return result
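With status, prequery, error, and debug items now recorded instead of dropped, the streaming async client returns the same surface as the sync one. A minimal sketch, assuming AsyncNucliaSearch is the async entry point exported by the SDK:

import asyncio

from nuclia.sdk import AsyncNucliaSearch

async def main():
    # ask() consumes the NDJSON stream item by item; the branches above
    # fold status/error/debug items into the returned AskAnswer.
    answer = await AsyncNucliaSearch().ask(query="What is Nuclia?")
    print(answer.status, answer.error_details)

asyncio.run(main())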
@@ -512,6 +549,11 @@ async def ask_json(
             timings=None,
             tokens=None,
             object=None,
+            status=None,
+            error_details=None,
+            predict_request=None,
+            relations=None,
+            prompt_context=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -535,8 +577,16 @@
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
             elif ask_response_item.type == "status":
-                # Status is ignored
-                pass
+                result.status = ask_response_item.status
+            elif ask_response_item.type == "prequeries":
+                result.prequeries = ask_response_item.results
+            elif ask_response_item.type == "error":
+                result.error_details = ask_response_item.error
+            elif ask_response_item.type == "debug":
+                result.prompt_context = ask_response_item.metadata.get("prompt_context")
+                result.predict_request = ask_response_item.metadata.get(
+                    "predict_request"
+                )
             else:  # pragma: no cover
                 warnings.warn(f"Unknown ask stream item type: {ask_response_item.type}")
         return result