forked from qcri/LLMeBench
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_base.py
138 lines (114 loc) · 4.35 KB
/
model_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import logging
import sys
import traceback
from abc import ABC, abstractmethod
from pathlib import Path
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
def log_retry(retry_state):
    """
    Hook run by tenacity before every attempt of a wrapped call.

    The very first attempt is the original request rather than a retry, so
    it is skipped silently. Every later attempt emits a warning through the
    root logger that includes the attempt number.

    Arguments
    ---------
    retry_state : object
        Retry-state object handed over by tenacity; only its
        ``attempt_number`` attribute is read.
    """
    attempt = retry_state.attempt_number
    if attempt != 1:
        logging.warning(f"Request failed, retry attempt {attempt}...")
class ModelBase(ABC):
    """
    Base class for models

    Implementations of this class need to define at least two mandatory methods:
    `prompt()` and `summarize_response()`. Implementations of this class should target
    a specific model inference API, such as a platform (Azure, OpenAI), custom
    hosted inference server (Petals, FastChat) or other model-specific APIs.

    The class derives from `abc.ABC`, so a subclass that does not implement
    both abstract methods cannot be instantiated (a `TypeError` is raised).

    Attributes
    ----------
    max_tries : int, defaults to 5
        Maximum number of attempts per sample in case of failure, counting
        the first call (so `max_tries - 1` retries at most). Failure is
        defined by `retry_exceptions`.
    retry_exceptions : tuple
        Tuple of exceptions on which the framework should retry the request
        for any given sample. Specific exceptions should be included by the
        implementing class, such as HTTP Request failures (in case of HTTP-
        based APIs).

    Methods
    -------
    prompt(processed_input):
        Method that takes inputs from an asset and makes the actual request
        to the underlying model inference API.
    summarize_response(response):
        Method that takes a model response and summarizes it into a simpler
        form
    run_model(processed_input):
        Wrapper that calls the `prompt` method and captures exceptions
    """

    def __init__(self, max_tries=5, retry_exceptions=(), **kwargs):
        """
        Store the retry configuration and wrap `prompt` with retry logic.

        Arguments
        ---------
        max_tries : int, defaults to 5
            Maximum number of attempts per sample, including the first call.
        retry_exceptions : tuple, defaults to ()
            Exception types that trigger another attempt. Exceptions of any
            other type propagate out of `prompt` immediately. With the default
            empty tuple no exception matches, so requests are never retried.
        **kwargs : dict
            Accepted and ignored here (presumably so subclasses can forward
            their whole configuration without filtering it).
        """
        self.max_tries = max_tries
        self.retry_exceptions = retry_exceptions

        # Wrap this instance's bound `prompt` with tenacity's retry decorator.
        # The wrapper is stored as an instance attribute, so it shadows the
        # class-level method for this object only.
        self.prompt = retry(
            # Randomized exponential backoff between attempts, each wait
            # capped at 60 seconds.
            wait=wait_random_exponential(multiplier=1, max=60),
            stop=stop_after_attempt(self.max_tries),
            retry=retry_if_exception_type(retry_exceptions),
            before=log_retry,
            # Once attempts are exhausted, re-raise the last underlying
            # exception rather than tenacity's RetryError, so the real error
            # reaches `run_model`.
            reraise=True,
        )(self.prompt)

    @abstractmethod
    def prompt(self, processed_input):
        """
        Method that implements communication to the underlying model

        Arguments
        ---------
        processed_input : dict
            Input from an asset. The structure of this will be dependent
            on a specific model implementation, and must be documented by
            the class implementation itself

        Returns
        -------
        response : mixed
            Response from the underlying model API

        Notes
        -----
        Ideally, this method will never be called directly, but through the
        `run_model` wrapper which takes care of returning the output in a
        consistent manner and also handles errors/exceptions.
        """

    @abstractmethod
    def summarize_response(self, response):
        """
        Method that summarizes/simplifies a model's response

        Arguments
        ---------
        response : mixed
            Response from `prompt()`

        Returns
        -------
        simplified_response : mixed
            Should ideally be a short string that summarizes the model's response
            (e.g. only the actual label instead of scores and other metadata). Will
            be saved in the summary file for quick debugging. If the response is not
            simplifiable, return the response object as is.
        """

    def run_model(self, processed_input):
        """
        Wrapper that calls the `prompt` method and captures exceptions

        Arguments
        ---------
        processed_input : dict
            Input from an asset. The structure of this will be dependent
            on a specific model implementation, and must be documented by
            the class implementation itself

        Returns
        -------
        response : dict
            Returns a dictionary with the key "response" holding the model's
            response, or "failure_exception" holding the formatted traceback
            (a string) of the error that occurred when using the model
        """
        try:
            response = self.prompt(processed_input)
        except Exception:
            # Deliberately broad: any failure escaping `prompt` (after its
            # retries) is captured and returned as text instead of raised.
            return {"failure_exception": traceback.format_exc()}
        return {"response": response}