Skip to content

Commit

Permalink
apacheGH-38017: [Go][FlightSQL] Increment types handled by internal c…
Browse files Browse the repository at this point in the history
…onverter (apache#38028)

### Rationale for this change
This PR targets flightsql.Driver, that complies with sql.Driver interface, from the stdlib. <br>

The driver has a resultset iterator: the type [Rows](https://github.com/apache/arrow/blob/7d834d65c37c17d1c19bfb497eadb983893c9ea0/go/arrow/flight/flightsql/driver/driver.go#L39). <br>

The method [Rows.Next()](https://github.com/apache/arrow/blob/7d834d65c37c17d1c19bfb497eadb983893c9ea0/go/arrow/flight/flightsql/driver/driver.go#L81), which populates the next row of data, lacks handling/treatment of some common types, as described in the [issue](apache#38017). <br>

### What changes are included in this PR?
This PR includes the missing common basic data types for Rows.Next() method. <br>

### Are these changes tested?
The driver is already tested, but I also included unitary tests that cover not only the new types, but all the implemented types. <br>

### Are there any user-facing changes?
All the contracts and signatures, and the current types were preserved. <br>

Closes apache#38017 

Authored-by: miguel pragier <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
  • Loading branch information
miguelpragier authored and Jeremy Aguilon committed Oct 23, 2023
1 parent b3c4b0d commit c440484
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 20 deletions.
50 changes: 30 additions & 20 deletions go/arrow/flight/flightsql/driver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,46 +58,56 @@ func (g grpcCredentials) RequireTransportSecurity() bool {

// *** Type conversions ***
func fromArrowType(arr arrow.Array, idx int) (interface{}, error) {
if arr.IsNull(idx) {
return nil, nil
}

switch c := arr.(type) {
case *array.Boolean:
return c.Value(idx), nil
case *array.Float16:
return float64(c.Value(idx).Float32()), nil
return c.Value(idx), nil
case *array.Float32:
return float64(c.Value(idx)), nil
return c.Value(idx), nil
case *array.Float64:
return c.Value(idx), nil
case *array.Decimal128:
v := arr.DataType().(*arrow.Decimal128Type)
return c.Value(idx).ToFloat64(v.Scale), nil
case *array.Decimal256:
v := arr.DataType().(*arrow.Decimal256Type)
return c.Value(idx).ToFloat64(v.Scale), nil
case *array.Int8:
return int64(c.Value(idx)), nil
return c.Value(idx), nil
case *array.Int16:
return int64(c.Value(idx)), nil
return c.Value(idx), nil
case *array.Int32:
return int64(c.Value(idx)), nil
return c.Value(idx), nil
case *array.Int64:
return c.Value(idx), nil
case *array.Binary:
return c.Value(idx), nil
case *array.String:
return c.Value(idx), nil
case *array.Time32:
dt, ok := arr.DataType().(*arrow.Time32Type)
if !ok {
return nil, fmt.Errorf("datatype %T not matching time32", arr.DataType())
}
d32 := arr.DataType().(*arrow.Time32Type)
v := c.Value(idx)
return v.ToTime(dt.TimeUnit()), nil
return v.ToTime(d32.TimeUnit()), nil
case *array.Time64:
dt, ok := arr.DataType().(*arrow.Time64Type)
if !ok {
return nil, fmt.Errorf("datatype %T not matching time64", arr.DataType())
}
d64 := arr.DataType().(*arrow.Time64Type)
v := c.Value(idx)
return v.ToTime(dt.TimeUnit()), nil
return v.ToTime(d64.TimeUnit()), nil
case *array.Timestamp:
dt, ok := arr.DataType().(*arrow.TimestampType)
if !ok {
return nil, fmt.Errorf("datatype %T not matching timestamp", arr.DataType())
}
ts := arr.DataType().(*arrow.TimestampType)
v := c.Value(idx)
return v.ToTime(dt.TimeUnit()), nil
return v.ToTime(ts.TimeUnit()), nil
case *array.Date64:
return c.Value(idx).ToTime(), nil
case *array.DayTimeInterval:
durationDays := time.Duration(c.Value(idx).Days*24) * time.Hour
duration := time.Duration(c.Value(idx).Milliseconds) * time.Millisecond

return durationDays + duration, nil
}

return nil, fmt.Errorf("type %T: %w", arr, ErrNotSupported)
Expand Down
126 changes: 126 additions & 0 deletions go/arrow/flight/flightsql/driver/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package driver

import (
"fmt"
"math/big"
"reflect"
"testing"
"time"

"github.com/apache/arrow/go/v14/arrow"
"github.com/apache/arrow/go/v14/arrow/array"
"github.com/apache/arrow/go/v14/arrow/decimal128"
"github.com/apache/arrow/go/v14/arrow/decimal256"
"github.com/apache/arrow/go/v14/arrow/float16"
"github.com/apache/arrow/go/v14/arrow/memory"
"github.com/stretchr/testify/require"
)

func Test_fromArrowType(t *testing.T) {
fields := []arrow.Field{
{Name: "f1-bool", Type: arrow.FixedWidthTypes.Boolean},
{Name: "f2-f16", Type: arrow.FixedWidthTypes.Float16},
{Name: "f3-f32", Type: arrow.PrimitiveTypes.Float32},
{Name: "f4-f64", Type: arrow.PrimitiveTypes.Float64},
{Name: "f5-d128", Type: &arrow.Decimal128Type{}},
{Name: "f6-d256", Type: &arrow.Decimal256Type{}},
{Name: "f7-i8", Type: arrow.PrimitiveTypes.Int8},
{Name: "f8-i16", Type: arrow.PrimitiveTypes.Int16},
{Name: "f9-i32", Type: arrow.PrimitiveTypes.Int32},
{Name: "f10-i64", Type: arrow.PrimitiveTypes.Int64},
{Name: "f11-binary", Type: arrow.BinaryTypes.Binary},
{Name: "f12-string", Type: arrow.BinaryTypes.String},
{Name: "f13-t32s", Type: arrow.FixedWidthTypes.Time32s},
{Name: "f14-t64us", Type: arrow.FixedWidthTypes.Time64us},
{Name: "f15-ts_us", Type: arrow.FixedWidthTypes.Timestamp_ns},
{Name: "f16-d64", Type: arrow.FixedWidthTypes.Date64},
{Name: "f17-dti", Type: arrow.FixedWidthTypes.DayTimeInterval},
}

schema := arrow.NewSchema(fields, nil)
pool := memory.NewGoAllocator()
b := array.NewRecordBuilder(pool, schema)
defer b.Release()

b.Field(0).(*array.BooleanBuilder).Append(true)
b.Field(1).(*array.Float16Builder).Append(float16.New(1))
b.Field(2).(*array.Float32Builder).Append(1)
b.Field(3).(*array.Float64Builder).Append(1)
b.Field(4).(*array.Decimal128Builder).Append(decimal128.FromBigInt(big.NewInt(1)))
b.Field(5).(*array.Decimal256Builder).Append(decimal256.FromBigInt(big.NewInt(1)))
b.Field(6).(*array.Int8Builder).Append(1)
b.Field(7).(*array.Int16Builder).Append(1)
b.Field(8).(*array.Int32Builder).Append(1)
b.Field(9).(*array.Int64Builder).Append(1)
b.Field(10).(*array.BinaryBuilder).Append([]byte("a"))
b.Field(11).(*array.StringBuilder).Append("a")

t32, err := arrow.Time32FromString("12:30:00", arrow.Second)
require.NoError(t, err)

b.Field(12).(*array.Time32Builder).Append(t32)

t64, err := arrow.Time64FromString("12:00:00", arrow.Microsecond)
require.NoError(t, err)

b.Field(13).(*array.Time64Builder).Append(t64)

ts, err := arrow.TimestampFromString("1970-01-01T12:00:00", arrow.Nanosecond)
require.NoError(t, err)

fmt.Println(ts.ToTime(arrow.Nanosecond))

b.Field(14).(*array.TimestampBuilder).Append(ts)

testTime := time.Now()
b.Field(15).(*array.Date64Builder).Append(arrow.Date64FromTime(testTime))
b.Field(16).(*array.DayTimeIntervalBuilder).Append(arrow.DayTimeInterval{Days: 1, Milliseconds: 1000})

rec := b.NewRecord()
defer rec.Release()

tf := func(t *testing.T, idx int, want any) {
t.Run(fmt.Sprintf("fromArrowType %v %s", fields[idx].Type, fields[idx].Name), func(t *testing.T) {
v, err := fromArrowType(rec.Column(idx), 0)
if err != nil {
t.Fatalf("err when converting from arrow: %s", err)
}
if !reflect.DeepEqual(v, want) {
t.Fatalf("test failed, wanted %T %v got %T %v", want, want, v, v)
}
})
}

tf(t, 0, true) // "f1-bool"
tf(t, 1, float16.New(1)) // "f2-f16"
tf(t, 2, float32(1)) // "f3-f32"
tf(t, 3, float64(1)) // "f4-f64"
tf(t, 4, float64(1)) // "f5-d128"
tf(t, 5, float64(1)) // "f6-d256"
tf(t, 6, int8(1)) // "f7-i8"
tf(t, 7, int16(1)) // "f8-i16"
tf(t, 8, int32(1)) // "f9-i32"
tf(t, 9, int64(1)) // "f10-i64"
tf(t, 10, []byte("a")) // "f11-binary"
tf(t, 11, "a") // "f12-string"
tf(t, 12, time.Date(1970, 1, 1, 12, 30, 0, 0, time.UTC)) // "f13-t32s"
tf(t, 13, time.Date(1970, 1, 1, 12, 0, 0, 0, time.UTC)) // "f14-t64us"
tf(t, 14, time.Date(1970, 1, 1, 12, 0, 0, 0, time.UTC)) // "f15-ts_us"
tf(t, 15, testTime.In(time.UTC).Truncate(24*time.Hour)) // "f16-d64"
tf(t, 16, time.Duration(24*time.Hour+time.Second)) // "f17-dti"
}

0 comments on commit c440484

Please sign in to comment.