2121import org .apache .flink .configuration .DescribedEnum ;
2222import org .apache .flink .configuration .ReadableConfig ;
2323import org .apache .flink .configuration .description .InlineElement ;
24+ import org .apache .flink .table .api .DataTypes ;
2425import org .apache .flink .table .catalog .Column ;
2526import org .apache .flink .table .catalog .ResolvedSchema ;
27+ import org .apache .flink .table .data .ArrayData ;
28+ import org .apache .flink .table .data .GenericArrayData ;
29+ import org .apache .flink .table .data .GenericMapData ;
30+ import org .apache .flink .table .data .GenericRowData ;
2631import org .apache .flink .table .data .RowData ;
32+ import org .apache .flink .table .data .StringData ;
33+ import org .apache .flink .table .data .binary .BinaryStringData ;
2734import org .apache .flink .table .factories .ModelProviderFactory ;
2835import org .apache .flink .table .functions .AsyncPredictFunction ;
2936import org .apache .flink .table .functions .FunctionContext ;
37+ import org .apache .flink .table .types .DataType ;
3038import org .apache .flink .table .types .logical .LogicalType ;
3139import org .apache .flink .table .types .logical .VarCharType ;
40+ import org .apache .flink .util .ExceptionUtils ;
41+ import org .apache .flink .util .Preconditions ;
3242
3343import com .openai .client .OpenAIClientAsync ;
44+ import com .openai .core .http .Headers ;
45+ import com .openai .errors .OpenAIServiceException ;
3446import org .slf4j .Logger ;
3547import org .slf4j .LoggerFactory ;
3648
3749import javax .annotation .Nullable ;
3850
51+ import java .util .Arrays ;
3952import java .util .Collection ;
4053import java .util .Collections ;
54+ import java .util .HashMap ;
4155import java .util .List ;
56+ import java .util .Map ;
4257import java .util .concurrent .CompletableFuture ;
58+ import java .util .function .Function ;
4359import java .util .stream .Collectors ;
4460
4561import static org .apache .flink .configuration .description .TextElement .text ;
@@ -58,6 +74,7 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
5874 private final String model ;
5975 @ Nullable private final Integer maxContextSize ;
6076 private final ContextOverflowAction contextOverflowAction ;
77+ protected final List <String > outputColumnNames ;
6178
6279 public AbstractOpenAIModelFunction (
6380 ModelProviderFactory .Context factoryContext , ReadableConfig config ) {
@@ -79,6 +96,9 @@ public AbstractOpenAIModelFunction(
7996 factoryContext .getCatalogModel ().getResolvedInputSchema (),
8097 new VarCharType (VarCharType .MAX_LENGTH ),
8198 "input" );
99+
100+ this .outputColumnNames =
101+ factoryContext .getCatalogModel ().getResolvedOutputSchema ().getColumnNames ();
82102 }
83103
84104 @ Override
@@ -123,23 +143,19 @@ public void close() throws Exception {
123143 protected void validateSingleColumnSchema (
124144 ResolvedSchema schema , LogicalType expectedType , String inputOrOutput ) {
125145 List <Column > columns = schema .getColumns ();
126- if (columns .size () != 1 ) {
127- throw new IllegalArgumentException (
128- String .format (
129- "Model should have exactly one %s column, but actually has %s columns: %s" ,
130- inputOrOutput ,
131- columns .size (),
132- columns .stream ().map (Column ::getName ).collect (Collectors .toList ())));
133- }
134-
135- Column column = columns .get (0 );
136- if (!column .isPhysical ()) {
146+ List <String > physicalColumnNames =
147+ columns .stream ()
148+ .filter (Column ::isPhysical )
149+ .map (Column ::getName )
150+ .collect (Collectors .toList ());
151+ if (physicalColumnNames .size () != 1 ) {
137152 throw new IllegalArgumentException (
138153 String .format (
139- "%s column %s should be a physical column, but is a %s. " ,
140- inputOrOutput , column . getName (), column . getClass () ));
154+ "Model should have exactly one %s physical column, but actually has %s physical columns: %s " ,
155+ inputOrOutput , physicalColumnNames . size (), physicalColumnNames ));
141156 }
142157
158+ Column column = schema .getColumn (physicalColumnNames .get (0 )).get ();
143159 if (!expectedType .equals (column .getDataType ().getLogicalType ())) {
144160 throw new IllegalArgumentException (
145161 String .format (
@@ -149,6 +165,33 @@ protected void validateSingleColumnSchema(
149165 expectedType ,
150166 column .getDataType ().getLogicalType ()));
151167 }
168+
169+ List <Column > metadataColumns =
170+ columns .stream ()
171+ .filter (x -> x instanceof Column .MetadataColumn )
172+ .collect (Collectors .toList ());
173+ if (!metadataColumns .isEmpty ()) {
174+ Preconditions .checkArgument (
175+ "output" .equals (inputOrOutput ), "Only output schema supports metadata column" );
176+
177+ for (Column metadataColumn : metadataColumns ) {
178+ ErrorMessageMetadata errorMessageMetadata =
179+ ErrorMessageMetadata .get (metadataColumn .getName ());
180+ Preconditions .checkNotNull (
181+ errorMessageMetadata ,
182+ String .format (
183+ "Unexpected metadata column %s. Supported metadata columns:\n %s" ,
184+ metadataColumn .getName (),
185+ ErrorMessageMetadata .getAllKeysAndDescriptions ()));
186+ Preconditions .checkArgument (
187+ errorMessageMetadata .dataType .equals (metadataColumn .getDataType ()),
188+ String .format (
189+ "Expected metadata column %s to be of type %s, but is of type %s" ,
190+ metadataColumn .getName (),
191+ errorMessageMetadata .dataType ,
192+ metadataColumn .getDataType ()));
193+ }
194+ }
152195 }
153196
154197 protected Collection <RowData > handleErrorsAndRespond (Throwable t ) {
@@ -160,7 +203,20 @@ protected Collection<RowData> handleErrorsAndRespond(Throwable t) {
160203 if (finalErrorHandlingStrategy == ErrorHandlingStrategy .FAILOVER ) {
161204 throw new RuntimeException (t );
162205 } else if (finalErrorHandlingStrategy == ErrorHandlingStrategy .IGNORE ) {
163- return Collections .emptyList ();
206+ LOG .warn (
207+ "The input row data failed to acquire a valid response. Ignoring the input." ,
208+ t );
209+ GenericRowData rowData = new GenericRowData (this .outputColumnNames .size ());
210+ boolean isMetadataSet = false ;
211+ for (int i = 0 ; i < this .outputColumnNames .size (); i ++) {
212+ String columnName = this .outputColumnNames .get (i );
213+ ErrorMessageMetadata errorMessageMetadata = ErrorMessageMetadata .get (columnName );
214+ if (errorMessageMetadata != null ) {
215+ rowData .setField (i , errorMessageMetadata .converter .apply (t ));
216+ isMetadataSet = true ;
217+ }
218+ }
219+ return isMetadataSet ? Collections .singletonList (rowData ) : Collections .emptyList ();
164220 } else {
165221 throw new UnsupportedOperationException (
166222 "Unsupported error handling strategy: " + finalErrorHandlingStrategy );
@@ -204,4 +260,78 @@ public InlineElement getDescription() {
204260 return text (strategy .description );
205261 }
206262 }
263+
264+ /**
265+ * Metadata that can be read from the output row about error messages. Referenced from Flink
266+ * HTTP Connector's ReadableMetadata.
267+ */
268+ protected enum ErrorMessageMetadata {
269+ ERROR_STRING (
270+ "error-string" ,
271+ DataTypes .STRING (),
272+ x -> BinaryStringData .fromString (x .getMessage ()),
273+ "A message associated with the error" ),
274+ HTTP_STATUS_CODE (
275+ "http-status-code" ,
276+ DataTypes .INT (),
277+ e ->
278+ ExceptionUtils .findThrowable (e , OpenAIServiceException .class )
279+ .map (OpenAIServiceException ::statusCode )
280+ .orElse (null ),
281+ "The HTTP status code" ),
282+ HTTP_HEADERS_MAP (
283+ "http-headers-map" ,
284+ DataTypes .MAP (DataTypes .STRING (), DataTypes .ARRAY (DataTypes .STRING ())),
285+ e ->
286+ ExceptionUtils .findThrowable (e , OpenAIServiceException .class )
287+ .map (
288+ e1 -> {
289+ Map <StringData , ArrayData > map = new HashMap <>();
290+ Headers headers = e1 .headers ();
291+ for (String name : headers .names ()) {
292+ map .put (
293+ BinaryStringData .fromString (name ),
294+ new GenericArrayData (
295+ headers .values (name ).stream ()
296+ .map (
297+ BinaryStringData
298+ ::fromString )
299+ .toArray ()));
300+ }
301+ return new GenericMapData (map );
302+ })
303+ .orElse (null ),
304+ "The headers returned with the response" );
305+
306+ final String key ;
307+ final DataType dataType ;
308+ final Function <Throwable , Object > converter ;
309+ final String description ;
310+
311+ ErrorMessageMetadata (
312+ String key ,
313+ DataType dataType ,
314+ Function <Throwable , Object > converter ,
315+ String description ) {
316+ this .key = key ;
317+ this .dataType = dataType ;
318+ this .converter = converter ;
319+ this .description = description ;
320+ }
321+
322+ static @ Nullable ErrorMessageMetadata get (String key ) {
323+ for (ErrorMessageMetadata value : values ()) {
324+ if (value .key .equals (key )) {
325+ return value ;
326+ }
327+ }
328+ return null ;
329+ }
330+
331+ static String getAllKeysAndDescriptions () {
332+ return Arrays .stream (values ())
333+ .map (value -> value .key + ":\t " + value .description )
334+ .collect (Collectors .joining ("\n " ));
335+ }
336+ }
207337}
0 commit comments