Skip to content

Commit

Permalink
Ericbrehault/sc 7420/support api changes in the cli (#48)
Browse files Browse the repository at this point in the history
* support new regional endpoints

* support the new regional endpoints

* lint

* fix tests on stage

* fix account_id param

* fix delete kb with new endpoint

* mypy

* fix test
  • Loading branch information
ebrehault authored Dec 8, 2023
1 parent 23a0714 commit be97eb9
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 60 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/stage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: 3.9
cache: "pip"
cache: 'pip'

- name: Install package
run: make install
Expand All @@ -32,4 +32,4 @@ jobs:
run: make lint

- name: Test
run: BASE_NUCLIA_DOMAIN="stashify.cloud" GA_TESTING_SERVICE_TOKEN="${{ secrets.STAGE_TESTING_SERVICE_TOKEN }}" GA_TESTING_TOKEN="${{ secrets.STAGE_TESTING_TOKEN }}" GA_TESTING_NUA="${{ secrets.STAGE_TESTING_NUA }}" make test
run: USE_NEW_REGIONAL_ENDPOINTS="TRUE" BASE_NUCLIA_DOMAIN="stashify.cloud" GA_TESTING_SERVICE_TOKEN="${{ secrets.STAGE_TESTING_SERVICE_TOKEN }}" GA_TESTING_TOKEN="${{ secrets.STAGE_TESTING_TOKEN }}" GA_TESTING_NUA="${{ secrets.STAGE_TESTING_NUA }}" make test
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## 1.1.21 (unreleased)


- Nothing changed yet.
- Support the new regional endpoints.


## 1.1.20 (2023-12-05)
Expand Down
57 changes: 57 additions & 0 deletions docs/07-nua.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,38 @@ It can generate text from a prompt:
predict.generate(text="How to tell a good story?")
```

It can summarize a list of texts:

- CLI:

```bash
nuclia nua predict summarize --texts='["TEXT1", "TEXT2"]'
```

- SDK:

```python
from nuclia import sdk
predict = sdk.NucliaPredict()
predict.summarize(texts=["TEXT1", "TEXT2"])
```

It can generate a response to a question given a context:

- CLI:

```bash
nuclia nua predict rag --question="QUESTION" --context='["TEXT1", "TEXT2"]'
```

- SDK:

```python
from nuclia import sdk
predict = sdk.NucliaPredict()
predict.rag(question="QUESTION", context=["TEXT1", "TEXT2"])
```

### Agent

`agent` allows to generate LLM agents from an initial prompt:
Expand All @@ -110,3 +142,28 @@ It can generate text from a prompt:
```

(with the SDK, you will obtain an agent directly, you can call `ask` on it to generate answers)

### Process

`process` allows to process a file:

- CLI:

```bash
nuclia nua process file --path="path/to/file.txt"
```

And you can check the status with:

```bash
nuclia nua process status
```

- SDK:

```python
from nuclia import sdk
process = sdk.NucliaProcess()
process.file(path="path/to/file.txt")
print(process.status())
```
2 changes: 2 additions & 0 deletions nuclia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
REGIONAL = "https://{region}." + BASE_DOMAIN
CLOUD_ID = BASE.split("/")[-1]

USE_NEW_REGIONAL_ENDPOINTS = os.environ.get("USE_NEW_REGIONAL_ENDPOINTS", "") == "TRUE"


def get_global_url(path: str):
return BASE + path
Expand Down
23 changes: 22 additions & 1 deletion nuclia/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class KnowledgeBox(BaseModel):
account: Optional[str] = None

def __str__(self):
return f"{self.id:36} -> {'(' + self.account + ')' if self.account else ''} {self.title}"
return f"{self.id:36} -> {self.slug} {'(account: ' + self.account + ')' if self.account else ''}"


class NuaKey(BaseModel):
Expand Down Expand Up @@ -58,6 +58,7 @@ class Selection(BaseModel):
kbid: Optional[str] = None
account: Optional[str] = None
nucliadb: Optional[str] = None
zone: Optional[str] = None


class Config(BaseModel):
Expand Down Expand Up @@ -201,6 +202,17 @@ def set_default_account(self, account: str):
self.default.account = account
self.save()

def get_default_zone(self) -> Optional[str]:
if self.default is None or self.default.zone is None:
return None
return self.default.zone

def set_default_zone(self, zone: str):
if self.default is None:
self.default = Selection()
self.default.zone = zone
self.save()

def get_default_kb(self) -> str:
if self.default is None or self.default.kbid is None:
raise NotDefinedDefault()
Expand Down Expand Up @@ -266,6 +278,15 @@ def retrieve_nua(nuas: List[NuaKey], nua: str) -> Optional[NuaKey]:
return nua_obj


def retrieve_account(accounts: List[Account], account: str) -> Optional[Account]:
account_obj: Optional[Account] = None
try:
account_obj = next(filter(lambda x: x.slug == account, accounts))
except StopIteration:
pass
return account_obj


def set_config_file(path: str):
global CONFIG_PATH
CONFIG_PATH = path
Expand Down
32 changes: 26 additions & 6 deletions nuclia/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import yaml

from nuclia import BASE_DOMAIN
from nuclia import BASE_DOMAIN, USE_NEW_REGIONAL_ENDPOINTS
from nuclia.data import get_auth
from nuclia.exceptions import NotDefinedDefault
from nuclia.lib.kb import Environment, NucliaDBClient
Expand All @@ -24,7 +24,11 @@ def kbs(func):
def wrapper_checkout_kbs(*args, **kwargs):
if "account" in kwargs:
auth = get_auth()
auth.kbs(kwargs["account"])
if not USE_NEW_REGIONAL_ENDPOINTS:
auth.kbs(kwargs["account"])
else:
account_id = auth.get_account_id(kwargs["account"])
auth.kbs(account_id)
return func(*args, **kwargs)

return wrapper_checkout_kbs
Expand Down Expand Up @@ -131,13 +135,17 @@ def wrapper_checkout_nua(*args, **kwargs):
def account(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not kwargs.get("account"):
auth = get_auth()
account_slug = kwargs.get("account")
account_id = kwargs.get("account_id")
auth = get_auth()
if not account_id and not account_slug:
account_slug = auth._config.get_default_account()
if account_slug is None:
raise NotDefinedDefault()
else:
kwargs["account"] = account_slug
kwargs["account"] = account_slug
if not account_id:
account_id = auth.get_account_id(account_slug)
kwargs["account_id"] = account_id
return func(*args, **kwargs)

return wrapper
Expand All @@ -154,3 +162,15 @@ def wrapper(*args, **kwargs):
return result

return wrapper


def zone(func):
@wraps(func)
def wrapper_checkout_zone(*args, **kwargs):
zone = kwargs.get("zone")
if not zone:
auth = get_auth()
kwargs["zone"] = auth._config.get_default_zone()
return func(*args, **kwargs)

return wrapper_checkout_zone
2 changes: 1 addition & 1 deletion nuclia/lib/nua.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests

from nuclia import REGIONAL
from nuclia.exceptions import NuaAPIException, AlreadyConsumed
from nuclia.exceptions import AlreadyConsumed, NuaAPIException
from nuclia.lib.nua_responses import (
Author,
ChatModel,
Expand Down
92 changes: 68 additions & 24 deletions nuclia/sdk/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
import requests
from prompt_toolkit import prompt

from nuclia import BASE, BASE_DOMAIN, get_global_url
from nuclia import USE_NEW_REGIONAL_ENDPOINTS, get_global_url, get_regional_url
from nuclia.cli.utils import yes_no
from nuclia.config import Account, Config, KnowledgeBox, Zone
from nuclia.config import Account, Config, KnowledgeBox, Zone, retrieve_account
from nuclia.exceptions import NeedUserToken, UserTokenExpired

USER = f"{BASE}/api/v1/user/welcome"
MEMBER = f"{BASE}/api/v1/user"
ACCOUNTS = f"{BASE}/api/v1/accounts"
ZONES = f"{BASE}/api/v1/zones"
LIST_KBS = BASE + "/api/v1/account/{account}/kbs"
USER = "/api/v1/user/welcome"
MEMBER = "/api/v1/user"
ACCOUNTS = "/api/v1/accounts"
ZONES = "/api/v1/zones"
LIST_KBS = "/api/v1/account/{account}/kbs"
VERIFY_NUA = "/api/authorizer/info"


Expand Down Expand Up @@ -175,7 +175,7 @@ def validate_kb(
return None, None

def _show_user(self):
resp = self._request("GET", MEMBER)
resp = self._request("GET", get_global_url(MEMBER))
print(f"User: {resp.get('name')} <{resp.get('email')}>")
print(f"Type: {resp.get('type')}")

Expand Down Expand Up @@ -220,7 +220,7 @@ def _validate_user_token(self, code: Optional[str] = None) -> bool:
if code is None:
code = self._config.token
resp = requests.get(
USER,
get_global_url(USER),
headers={"Authorization": f"Bearer {code}"},
)
if resp.status_code == 200:
Expand Down Expand Up @@ -261,7 +261,7 @@ def _request(
raise Exception({"status": resp.status_code, "message": resp.text})

def accounts(self) -> List[Account]:
accounts = self._request("GET", ACCOUNTS)
accounts = self._request("GET", get_global_url(ACCOUNTS))
result = []
self._config.accounts = []
for account in accounts:
Expand All @@ -272,7 +272,7 @@ def accounts(self) -> List[Account]:
return result

def zones(self) -> List[Zone]:
zones = self._request("GET", ZONES)
zones = self._request("GET", get_global_url(ZONES))
if self._config.accounts is None:
self._config.accounts = []
self._config.zones = []
Expand All @@ -285,19 +285,63 @@ def zones(self) -> List[Zone]:
return result

def kbs(self, account: str):
path = LIST_KBS.format(account=account)
try:
kbs = self._request("GET", path)
except UserTokenExpired:
return []
result = []
zones = self.zones()
region = {zone.id: zone.slug for zone in zones}
for kb in kbs:
zone = region[kb["zone"]]
url = f"https://{zone}.{BASE_DOMAIN}/api/v1/kb/{kb['id']}"
kb_obj = KnowledgeBox(
url=url, id=kb["id"], title=kb["title"], account=account, region=zone
)
result.append(kb_obj)
if not USE_NEW_REGIONAL_ENDPOINTS:
path = get_global_url(LIST_KBS.format(account=account))
try:
kbs = self._request("GET", path)
except UserTokenExpired:
return []
region = {zone.id: zone.slug for zone in zones}
for kb in kbs:
zone = region[kb["zone"]]
if not zone:
continue
url = get_regional_url(zone, f"/api/v1/kb/{kb['id']}")
kb_obj = KnowledgeBox(
url=url,
id=kb["id"],
slug=kb["slug"],
title=kb["title"],
account=account,
region=zone,
)
result.append(kb_obj)
else:
for zoneObj in zones:
zoneSlug = zoneObj.slug
if not zoneSlug:
continue
path = get_regional_url(zoneSlug, LIST_KBS.format(account=account))
try:
kbs = self._request("GET", path)
except UserTokenExpired:
return []
except requests.exceptions.ConnectionError:
print(
f"Connection error to {get_regional_url(zoneSlug, '')}, skipping zone"
)
continue
for kb in kbs:
url = get_regional_url(zoneSlug, f"/api/v1/kb/{kb['id']}")
kb_obj = KnowledgeBox(
url=url,
id=kb["id"],
slug=kb["slug"],
title=kb["title"],
account=account,
region=zoneSlug,
)
result.append(kb_obj)
return result

def get_account_id(self, account_slug: str) -> str:
if not USE_NEW_REGIONAL_ENDPOINTS:
account_id = account_slug
else:
account_obj = retrieve_account(self._config.accounts or [], account_slug)
if not account_obj:
raise ValueError(f"Account {account_slug} not found")
account_id = account_obj.id
return account_id
Loading

0 comments on commit be97eb9

Please sign in to comment.