API model #34
Comments
Sorry, we do not support customized API models currently. You may initialize a new model from BaseAPIModel in lagent and write the code yourself. By the way, we will soon release a template file showing how to customize an API model class in lagent.
Here is an unfinished reference code:

```python
import json
from logging import getLogger
from typing import Dict, List, Optional, Union

import requests

from .base_api import BaseAPIModel


class CustomAPI(BaseAPIModel):
    """Model wrapper around custom API models.

    Args:
        model_type (str): The name of the requested API model.
        model_url (str): The URL of the requested API model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        key (str or List[str]): Key(s) for the API model. When set to
            'ENV', the key is read from an environment variable. If a list
            is given, the keys are used in round-robin manner.
            Defaults to 'ENV'.
        meta_template (List[Dict], optional): The model's meta prompt
            template if needed, in case any meta instructions must be
            injected or wrapped.
        gen_params: Default generation configuration, which can be
            overridden on the fly during generation.
    """

    def __init__(self,
                 model_type: str,
                 model_url: str,
                 query_per_second: int = 1,
                 retry: int = 2,
                 key: Union[str, List[str]] = 'ENV',
                 meta_template: Optional[List[Dict]] = None,
                 **gen_params):
        if meta_template is None:
            # Avoid a mutable default argument.
            meta_template = [
                dict(role='system', api_role='system'),
                dict(role='user', api_role='user'),
                dict(role='assistant', api_role='assistant'),
            ]
        self.url = model_url
        super().__init__(
            model_type=model_type,
            meta_template=meta_template,
            query_per_second=query_per_second,
            retry=retry,
            **gen_params)
        self.logger = getLogger(__name__)
        if key is None:
            self.keys = None
        elif isinstance(key, str):
            self.keys = [key]
        else:
            self.keys = key
        self._session_id = 0  # incremented per request in _generate

    def _generate(self,
                  inputs: Union[str, List],
                  max_out_len: int = None,
                  temperature: float = None) -> str:
        """Generate results given an input.

        Args:
            inputs (str or List): A string or a PromptDict organized in
                OpenCompass' API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): The sampling temperature to use, between
                0 and 2. Higher values like 0.8 make the output more
                random, while lower values like 0.2 make it more focused
                and deterministic.

        Returns:
            str: The generated string.
        """
        assert isinstance(inputs, str)
        max_num_retries = 0
        while max_num_retries < self.retry:
            header = {'content-type': 'application/json'}
            self._session_id = (self._session_id + 1) % 1000000
            try:
                data = dict(
                    # model_type is assumed to be stored by BaseAPIModel.
                    model=self.model_type,
                    session_id=self._session_id,
                    prompt=inputs,
                    sequence_start=True,
                    sequence_end=True,
                    max_tokens=max_out_len,
                )
                raw_response = requests.post(
                    self.url, headers=header, data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.warning('Got connection error, retrying...')
                max_num_retries += 1
                continue
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                self.logger.warning('JSONDecode error, got %s',
                                    str(raw_response.content))
                max_num_retries += 1
                continue
            try:
                if 'completion' in self.url:
                    return response['choices'][0]['text'].strip()
                return response['text'].strip()
            except KeyError:
                max_num_retries += 1
        raise RuntimeError('Calling API model failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')
```
Thanks! I will try.
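The response-parsing branch in the reference code above can also be exercised without a live server. This sketch extracts it into a pure function; the field names mirror the reference code, not any official API specification:

```python
def parse_response(url: str, response: dict) -> str:
    """Pick the text field based on the endpoint, as the snippet does."""
    if 'completion' in url:
        # OpenAI-style completion payload: text nested under 'choices'.
        return response['choices'][0]['text'].strip()
    # Otherwise assume a flat payload with a top-level 'text' field.
    return response['text'].strip()


# Made-up example payloads for the two branches.
completion_payload = {'choices': [{'text': ' hello '}]}
plain_payload = {'text': ' world '}
```

Keeping this logic in a separate function makes it easy to unit-test the two payload shapes before wiring the class to a real endpoint.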
Original issue: If I want to test the qwen model with the API, can I just use the GPTAPI class and replace the model URL with the qwen one?