
Commit c3bf061

yunfengzhou-hub authored and Sxnan committed
[FLINK-38581][model] Support surfacing error message
This closes #27163
1 parent 7ab3e03 commit c3bf061

6 files changed (+257 -24 lines)


docs/content.zh/docs/connectors/models/openai.md

Lines changed: 10 additions & 0 deletions
@@ -115,3 +115,13 @@ FROM ML_PREDICT(
 </tr>
 </tbody>
 </table>
+
+### Available Metadata
+
+When `error-handling-strategy` is configured as `ignore`, you can optionally specify the following metadata columns to surface failure information into your output stream.
+
+* error-string (STRING): A message associated with the error
+* http-status-code (INT): The HTTP status code
+* http-headers-map (MAP<STRING, ARRAY<STRING>>): The headers returned with the response
+
+If you define these metadata columns in the output schema but the call does not fail, these columns will be filled with null values.
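For orientation, here is a minimal Flink SQL sketch of how the metadata columns described in this doc change could be declared on a model's output schema. The model name, endpoint, and option values are made up, and the use of the `METADATA` keyword inside the `OUTPUT` clause is an assumption modeled on Flink's table DDL; only the metadata column names and types come from the docs above.

```sql
-- Sketch only: model name, endpoint, and option values are hypothetical.
-- The METADATA keyword in the OUTPUT clause is assumed to mirror Flink's table DDL;
-- the metadata column names must match the keys listed in the docs above.
CREATE MODEL my_chat_model
INPUT (`input` STRING)
OUTPUT (
    response STRING,
    `error-string` STRING METADATA,
    `http-status-code` INT METADATA,
    `http-headers-map` MAP<STRING, ARRAY<STRING>> METADATA
)
WITH (
    'provider' = 'openai',
    'endpoint' = 'https://api.openai.com/v1/chat/completions',
    'model' = 'gpt-4o-mini',
    'api-key' = '...',
    'error-handling-strategy' = 'ignore'
);
```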

docs/content/docs/connectors/models/openai.md

Lines changed: 14 additions & 0 deletions
@@ -94,6 +94,8 @@ FROM ML_PREDICT(
 
 ## Schema Requirement
 
+The following table lists the schema requirement for each task.
+
 <table class="table table-bordered">
 <thead>
 <tr>
@@ -115,3 +117,15 @@ FROM ML_PREDICT(
 </tr>
 </tbody>
 </table>
+
+### Available Metadata
+
+When configuring `error-handling-strategy` as `ignore`, you can choose to additionally specify the
+following metadata columns to surface information about failures into your stream.
+
+* error-string (STRING): A message associated with the error
+* http-status-code (INT): The HTTP status code
+* http-headers-map (MAP<STRING, ARRAY<STRING>>): The headers returned with the response
+
+If you defined these metadata columns in the output schema but the call did not fail, the columns
+will be filled with null values.
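As an illustration of the behavior described above, the query sketch below selects the metadata columns alongside the prediction. The table and model names are hypothetical, and it assumes a model whose output schema declares these metadata columns and whose options set `error-handling-strategy` to `ignore`.

```sql
-- Sketch only: input_table, my_chat_model, and the column names are hypothetical.
-- Failed calls yield a row whose metadata columns carry the error details and whose
-- physical output column is NULL; successful calls carry the prediction and NULL metadata.
SELECT
    `input`,
    response,
    `error-string`,
    `http-status-code`,
    `http-headers-map`
FROM ML_PREDICT(
    TABLE input_table,
    MODEL my_chat_model,
    DESCRIPTOR(`input`)
);
```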

flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java

Lines changed: 144 additions & 14 deletions
@@ -21,25 +21,41 @@
 import org.apache.flink.configuration.DescribedEnum;
 import org.apache.flink.configuration.ReadableConfig;
 import org.apache.flink.configuration.description.InlineElement;
+import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.catalog.Column;
 import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.GenericArrayData;
+import org.apache.flink.table.data.GenericMapData;
+import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.binary.BinaryStringData;
 import org.apache.flink.table.factories.ModelProviderFactory;
 import org.apache.flink.table.functions.AsyncPredictFunction;
 import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.VarCharType;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.Preconditions;
 
 import com.openai.client.OpenAIClientAsync;
+import com.openai.core.http.Headers;
+import com.openai.errors.OpenAIServiceException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static org.apache.flink.configuration.description.TextElement.text;
@@ -58,6 +74,7 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
     private final String model;
     @Nullable private final Integer maxContextSize;
     private final ContextOverflowAction contextOverflowAction;
+    protected final List<String> outputColumnNames;
 
     public AbstractOpenAIModelFunction(
             ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -79,6 +96,9 @@ public AbstractOpenAIModelFunction(
                 factoryContext.getCatalogModel().getResolvedInputSchema(),
                 new VarCharType(VarCharType.MAX_LENGTH),
                 "input");
+
+        this.outputColumnNames =
+                factoryContext.getCatalogModel().getResolvedOutputSchema().getColumnNames();
     }
 
     @Override
@@ -123,23 +143,19 @@ public void close() throws Exception {
     protected void validateSingleColumnSchema(
             ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) {
         List<Column> columns = schema.getColumns();
-        if (columns.size() != 1) {
-            throw new IllegalArgumentException(
-                    String.format(
-                            "Model should have exactly one %s column, but actually has %s columns: %s",
-                            inputOrOutput,
-                            columns.size(),
-                            columns.stream().map(Column::getName).collect(Collectors.toList())));
-        }
-
-        Column column = columns.get(0);
-        if (!column.isPhysical()) {
+        List<String> physicalColumnNames =
+                columns.stream()
+                        .filter(Column::isPhysical)
+                        .map(Column::getName)
+                        .collect(Collectors.toList());
+        if (physicalColumnNames.size() != 1) {
             throw new IllegalArgumentException(
                     String.format(
-                            "%s column %s should be a physical column, but is a %s.",
-                            inputOrOutput, column.getName(), column.getClass()));
+                            "Model should have exactly one %s physical column, but actually has %s physical columns: %s",
+                            inputOrOutput, physicalColumnNames.size(), physicalColumnNames));
         }
 
+        Column column = schema.getColumn(physicalColumnNames.get(0)).get();
         if (!expectedType.equals(column.getDataType().getLogicalType())) {
             throw new IllegalArgumentException(
                     String.format(
@@ -149,6 +165,33 @@ protected void validateSingleColumnSchema(
                             expectedType,
                             column.getDataType().getLogicalType()));
         }
+
+        List<Column> metadataColumns =
+                columns.stream()
+                        .filter(x -> x instanceof Column.MetadataColumn)
+                        .collect(Collectors.toList());
+        if (!metadataColumns.isEmpty()) {
+            Preconditions.checkArgument(
+                    "output".equals(inputOrOutput), "Only output schema supports metadata column");
+
+            for (Column metadataColumn : metadataColumns) {
+                ErrorMessageMetadata errorMessageMetadata =
+                        ErrorMessageMetadata.get(metadataColumn.getName());
+                Preconditions.checkNotNull(
+                        errorMessageMetadata,
+                        String.format(
+                                "Unexpected metadata column %s. Supported metadata columns:\n%s",
+                                metadataColumn.getName(),
+                                ErrorMessageMetadata.getAllKeysAndDescriptions()));
+                Preconditions.checkArgument(
+                        errorMessageMetadata.dataType.equals(metadataColumn.getDataType()),
+                        String.format(
+                                "Expected metadata column %s to be of type %s, but is of type %s",
+                                metadataColumn.getName(),
+                                errorMessageMetadata.dataType,
+                                metadataColumn.getDataType()));
+            }
+        }
     }
 
     protected Collection<RowData> handleErrorsAndRespond(Throwable t) {
@@ -160,7 +203,20 @@ protected Collection<RowData> handleErrorsAndRespond(Throwable t) {
         if (finalErrorHandlingStrategy == ErrorHandlingStrategy.FAILOVER) {
             throw new RuntimeException(t);
         } else if (finalErrorHandlingStrategy == ErrorHandlingStrategy.IGNORE) {
-            return Collections.emptyList();
+            LOG.warn(
+                    "The input row data failed to acquire a valid response. Ignoring the input.",
+                    t);
+            GenericRowData rowData = new GenericRowData(this.outputColumnNames.size());
+            boolean isMetadataSet = false;
+            for (int i = 0; i < this.outputColumnNames.size(); i++) {
+                String columnName = this.outputColumnNames.get(i);
+                ErrorMessageMetadata errorMessageMetadata = ErrorMessageMetadata.get(columnName);
+                if (errorMessageMetadata != null) {
+                    rowData.setField(i, errorMessageMetadata.converter.apply(t));
+                    isMetadataSet = true;
+                }
+            }
+            return isMetadataSet ? Collections.singletonList(rowData) : Collections.emptyList();
         } else {
             throw new UnsupportedOperationException(
                     "Unsupported error handling strategy: " + finalErrorHandlingStrategy);
@@ -204,4 +260,78 @@ public InlineElement getDescription() {
             return text(strategy.description);
         }
     }
+
+    /**
+     * Metadata that can be read from the output row about error messages. Referenced from Flink
+     * HTTP Connector's ReadableMetadata.
+     */
+    protected enum ErrorMessageMetadata {
+        ERROR_STRING(
+                "error-string",
+                DataTypes.STRING(),
+                x -> BinaryStringData.fromString(x.getMessage()),
+                "A message associated with the error"),
+        HTTP_STATUS_CODE(
+                "http-status-code",
+                DataTypes.INT(),
+                e ->
+                        ExceptionUtils.findThrowable(e, OpenAIServiceException.class)
+                                .map(OpenAIServiceException::statusCode)
+                                .orElse(null),
+                "The HTTP status code"),
+        HTTP_HEADERS_MAP(
+                "http-headers-map",
+                DataTypes.MAP(DataTypes.STRING(), DataTypes.ARRAY(DataTypes.STRING())),
+                e ->
+                        ExceptionUtils.findThrowable(e, OpenAIServiceException.class)
+                                .map(
+                                        e1 -> {
+                                            Map<StringData, ArrayData> map = new HashMap<>();
+                                            Headers headers = e1.headers();
+                                            for (String name : headers.names()) {
+                                                map.put(
+                                                        BinaryStringData.fromString(name),
+                                                        new GenericArrayData(
+                                                                headers.values(name).stream()
+                                                                        .map(
+                                                                                BinaryStringData
+                                                                                        ::fromString)
+                                                                        .toArray()));
+                                            }
+                                            return new GenericMapData(map);
+                                        })
+                                .orElse(null),
+                "The headers returned with the response");
+
+        final String key;
+        final DataType dataType;
+        final Function<Throwable, Object> converter;
+        final String description;
+
+        ErrorMessageMetadata(
+                String key,
+                DataType dataType,
+                Function<Throwable, Object> converter,
+                String description) {
+            this.key = key;
+            this.dataType = dataType;
+            this.converter = converter;
+            this.description = description;
+        }
+
+        static @Nullable ErrorMessageMetadata get(String key) {
+            for (ErrorMessageMetadata value : values()) {
+                if (value.key.equals(key)) {
+                    return value;
+                }
+            }
+            return null;
+        }
+
+        static String getAllKeysAndDescriptions() {
+            return Arrays.stream(values())
+                    .map(value -> value.key + ":\t" + value.description)
+                    .collect(Collectors.joining("\n"));
+        }
+    }
 }

flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java

Lines changed: 25 additions & 4 deletions
@@ -48,6 +48,7 @@ public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction {
     private final String model;
     private final String systemPrompt;
     private final Configuration config;
+    private final int outputColumnIndex;
 
     public OpenAIChatModelFunction(
             ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -59,6 +60,21 @@ public OpenAIChatModelFunction(
                 factoryContext.getCatalogModel().getResolvedOutputSchema(),
                 new VarCharType(VarCharType.MAX_LENGTH),
                 "output");
+        this.outputColumnIndex = getOutputColumnIndex();
+    }
+
+    private int getOutputColumnIndex() {
+        for (int i = 0; i < this.outputColumnNames.size(); i++) {
+            String columnName = this.outputColumnNames.get(i);
+            if (ErrorMessageMetadata.get(columnName) == null) {
+                // Prior checks have guaranteed that there is one and only one physical output
+                // column.
+                return i;
+            }
+        }
+        throw new IllegalArgumentException(
+                "There should be one and only one physical output column. Actual columns: "
+                        + this.outputColumnNames);
     }
 
     @Override
@@ -97,10 +113,15 @@ private Collection<RowData> convertToRowData(
 
         return chatCompletion.choices().stream()
                 .map(
-                        choice ->
-                                GenericRowData.of(
-                                        BinaryStringData.fromString(
-                                                choice.message().content().orElse(""))))
+                        choice -> {
+                            GenericRowData rowData =
+                                    new GenericRowData(this.outputColumnNames.size());
+                            rowData.setField(
+                                    this.outputColumnIndex,
+                                    BinaryStringData.fromString(
+                                            choice.message().content().orElse("")));
+                            return rowData;
+                        })
                 .collect(Collectors.toList());
     }
 
flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java

Lines changed: 27 additions & 6 deletions
@@ -44,6 +44,7 @@ public class OpenAIEmbeddingModelFunction extends AbstractOpenAIModelFunction {
 
     private final String model;
     @Nullable private final Long dimensions;
+    private final int outputColumnIndex;
 
     public OpenAIEmbeddingModelFunction(
             ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -55,6 +56,21 @@ public OpenAIEmbeddingModelFunction(
                 factoryContext.getCatalogModel().getResolvedOutputSchema(),
                 new ArrayType(new FloatType()),
                 "output");
+        this.outputColumnIndex = getOutputColumnIndex();
+    }
+
+    private int getOutputColumnIndex() {
+        for (int i = 0; i < this.outputColumnNames.size(); i++) {
+            String columnName = this.outputColumnNames.get(i);
+            if (ErrorMessageMetadata.get(columnName) == null) {
+                // Prior checks have guaranteed that there is one and only one physical output
+                // column.
+                return i;
+            }
+        }
+        throw new IllegalArgumentException(
+                "There should be one and only one physical output column. Actual columns: "
+                        + this.outputColumnNames);
     }
 
     @Override
@@ -83,12 +99,17 @@ private Collection<RowData> convertToRowData(
 
         return response.data().stream()
                 .map(
-                        embedding ->
-                                GenericRowData.of(
-                                        new GenericArrayData(
-                                                embedding.embedding().stream()
-                                                        .map(Double::floatValue)
-                                                        .toArray(Float[]::new))))
+                        embedding -> {
+                            GenericRowData rowData =
+                                    new GenericRowData(this.outputColumnNames.size());
+                            rowData.setField(
+                                    outputColumnIndex,
+                                    new GenericArrayData(
+                                            embedding.embedding().stream()
+                                                    .map(Double::floatValue)
+                                                    .toArray(Float[]::new)));
+                            return rowData;
+                        })
                 .collect(Collectors.toList());
     }
 }
