Skip to content

Commit

Permalink
Change transform schema to move error column to the back
Browse files Browse the repository at this point in the history
  • Loading branch information
sss04 committed Jan 7, 2025
1 parent 07b7a9a commit 3484bb0
Showing 1 changed file with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types.{DataType, StructField, StructType}

import scala.collection.JavaConverters._

Expand Down Expand Up @@ -265,7 +265,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

override def transformSchema(schema: StructType): StructType = {
openAICompletion match {
val transformedSchema = openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion
.transformSchema(schema.add(getMessagesCol, StructType(Seq())))
Expand All @@ -275,6 +275,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
}

// Move error column to back
val errorFieldOpt: Option[StructField] = transformedSchema.fields.find(_.name == getErrorCol)
val fieldsWithoutError: Array[StructField] = transformedSchema.fields.filterNot(_.name == getErrorCol)
val reorderedFields = Array.concat(fieldsWithoutError, errorFieldOpt.toArray)
StructType(reorderedFields)
}
}

Expand Down

0 comments on commit 3484bb0

Please sign in to comment.