diff --git a/result_set.go b/result_set.go index 88334985..9a4040e8 100644 --- a/result_set.go +++ b/result_set.go @@ -335,13 +335,19 @@ func (res ResultSet) Scan(v interface{}) error { } // Scan scans the rows into the given value. -func (res ResultSet) scanRow(row *nebula.Row, colNames []string, t reflect.Type) (reflect.Value, error) { +func (res ResultSet) scanRow(row *nebula.Row, colNames []string, rowType reflect.Type) (reflect.Value, error) { rowVals := row.GetValues() - val := reflect.New(t).Elem() + var result reflect.Value + if rowType.Kind() == reflect.Ptr { + result = reflect.New(rowType.Elem()) + } else { + result = reflect.New(rowType).Elem() + } + structVal := reflect.Indirect(result) - for fIdx := 0; fIdx < t.NumField(); fIdx++ { - f := t.Field(fIdx) + for fIdx := 0; fIdx < structVal.Type().NumField(); fIdx++ { + f := structVal.Type().Field(fIdx) tag := f.Tag.Get("nebula") if tag == "" { @@ -356,31 +362,147 @@ func (res ResultSet) scanRow(row *nebula.Row, colNames []string, t reflect.Type) rowVal := rowVals[cIdx] - switch f.Type.Kind() { - case reflect.Bool: - val.Field(fIdx).SetBool(rowVal.GetBVal()) - case reflect.Int: - val.Field(fIdx).SetInt(rowVal.GetIVal()) - case reflect.Int8: - val.Field(fIdx).SetInt(rowVal.GetIVal()) - case reflect.Int16: - val.Field(fIdx).SetInt(rowVal.GetIVal()) - case reflect.Int32: - val.Field(fIdx).SetInt(rowVal.GetIVal()) - case reflect.Int64: - val.Field(fIdx).SetInt(rowVal.GetIVal()) - case reflect.Float32: - val.Field(fIdx).SetFloat(rowVal.GetFVal()) - case reflect.Float64: - val.Field(fIdx).SetFloat(rowVal.GetFVal()) - case reflect.String: - val.Field(fIdx).SetString(string(rowVal.GetSVal())) - default: - return val, errors.New("scan: not support type") + if f.Type.Kind() == reflect.Slice { + list := rowVal.GetLVal() + err := scanListCol(list.Values, structVal.Field(fIdx), f.Type) + if err != nil { + return result, err + } + } else { + err := scanPrimitiveCol(rowVal, structVal.Field(fIdx), f.Type.Kind()) + if err != nil { + return result, err + } + } + } + + return result, nil +} + +func scanListCol(vals []*nebula.Value, listVal reflect.Value, sliceType reflect.Type) error { + switch sliceType.Elem().Kind() { + case reflect.Struct: + var listCol = reflect.MakeSlice(sliceType, 0, len(vals)) + for _, val := range vals { + ele := reflect.New(sliceType.Elem()).Elem() + err := scanStructField(val, ele, sliceType.Elem()) + if err != nil { + return err + } + listCol = reflect.Append(listCol, ele) + } + listVal.Set(listCol) + case reflect.Ptr: + var listCol = reflect.MakeSlice(sliceType, 0, len(vals)) + for _, val := range vals { + ele := reflect.New(sliceType.Elem().Elem()) + err := scanStructField(val, reflect.Indirect(ele), sliceType.Elem().Elem()) + if err != nil { + return err + } + listCol = reflect.Append(listCol, ele) } + listVal.Set(listCol) + default: + return errors.New("scan: not support list type") } - return val, nil + return nil +} + +func scanStructField(val *nebula.Value, eleVal reflect.Value, eleType reflect.Type) error { + vertex := val.GetVVal() + if vertex != nil { + tags := vertex.GetTags() + vid := vertex.GetVid() + + if len(tags) != 0 { + tag := tags[0] + + props := tag.GetProps() + props["_vid"] = vid + tagName := tag.GetName() + props["_tag_name"] = &nebula.Value{SVal: tagName} + + err := scanValFromProps(props, eleVal, eleType) + if err != nil { + return err + } + return nil + } + // no tags, continue + } + + edge := val.GetEVal() + if edge != nil { + props := edge.GetProps() + + src := edge.GetSrc() + dst := edge.GetDst() + name := edge.GetName() + props["_src"] = src + props["_dst"] = dst + props["_name"] = &nebula.Value{SVal: name} + + err := scanValFromProps(props, eleVal, eleType) + if err != nil { + return err + } + return nil + } + + return nil +} + +func scanValFromProps(props map[string]*nebula.Value, val reflect.Value, tpe reflect.Type) error { + for fIdx := 0; fIdx < tpe.NumField(); fIdx++ { + f := tpe.Field(fIdx) + n := f.Tag.Get("nebula") + v, ok := props[n] + if !ok { + continue + } + err := scanPrimitiveCol(v, val.Field(fIdx), f.Type.Kind()) + if err != nil { + return err + } + } + + return nil +} + +func scanPrimitiveCol(rowVal *nebula.Value, val reflect.Value, kind reflect.Kind) error { + w := ValueWrapper{value: rowVal} + if w.IsNull() || w.IsEmpty() { + // SetZero is introduced in go 1.20 + // val.SetZero() + return nil + } + + switch kind { + case reflect.Bool: + val.SetBool(rowVal.GetBVal()) + case reflect.Int: + val.SetInt(rowVal.GetIVal()) + case reflect.Int8: + val.SetInt(rowVal.GetIVal()) + case reflect.Int16: + val.SetInt(rowVal.GetIVal()) + case reflect.Int32: + val.SetInt(rowVal.GetIVal()) + case reflect.Int64: + val.SetInt(rowVal.GetIVal()) + case reflect.Float32: + val.SetFloat(rowVal.GetFVal()) + case reflect.Float64: + val.SetFloat(rowVal.GetFVal()) + case reflect.String: + val.SetString(string(rowVal.GetSVal())) + default: + return errors.New("scan: not support primitive type") + } + + return nil } // Returns the number of total rows diff --git a/result_set_test.go b/result_set_test.go index 870ae28b..420cce73 100644 --- a/result_set_test.go +++ b/result_set_test.go @@ -835,6 +835,203 @@ func TestScan(t *testing.T) { assert.Equal(t, true, testStructList[1].Col3) } +func TestScanPtr(t *testing.T) { + resp := &graph.ExecutionResponse{ + ErrorCode: nebula.ErrorCode_SUCCEEDED, + LatencyInUs: 1000, + Data: getDateset2(), + SpaceName: []byte("test_space"), + ErrorMsg: []byte("test"), + PlanDesc: graph.NewPlanDescription(), + Comment: []byte("test_comment")} + resultSet, err := genResultSet(resp, testTimezone) + if err != nil { + t.Error(err) + } + + type testStruct struct { + Col0 int64 `nebula:"col0_int64"` + Col1 float64 `nebula:"col1_float64"` + Col2 string `nebula:"col2_string"` + Col3 bool `nebula:"col3_bool"` + } + + var testStructList []*testStruct + err = resultSet.Scan(&testStructList) + if err != nil { + t.Error(err) + } + assert.Equal(t, 1, len(testStructList)) + assert.Equal(t, int64(1), testStructList[0].Col0) + assert.Equal(t, float64(2.0), testStructList[0].Col1) + assert.Equal(t, "string", testStructList[0].Col2) + assert.Equal(t, true, testStructList[0].Col3) + + // Scan again should work + err = resultSet.Scan(&testStructList) + if err != nil { + t.Error(err) + } + assert.Equal(t, 2, len(testStructList)) + assert.Equal(t, int64(1), testStructList[0].Col0) + assert.Equal(t, float64(2.0), testStructList[0].Col1) + assert.Equal(t, "string", testStructList[0].Col2) + assert.Equal(t, true, testStructList[0].Col3) + assert.Equal(t, int64(1), testStructList[1].Col0) + assert.Equal(t, float64(2.0), testStructList[1].Col1) + assert.Equal(t, "string", testStructList[1].Col2) + assert.Equal(t, true, testStructList[1].Col3) +} + +func TestScanWithNestStruct(t *testing.T) { + resp := &graph.ExecutionResponse{ + ErrorCode: nebula.ErrorCode_SUCCEEDED, + LatencyInUs: 1000, + Data: getNestDateset(), + SpaceName: []byte("test_space"), + ErrorMsg: []byte("test"), + PlanDesc: graph.NewPlanDescription(), + Comment: []byte("test_comment")} + resultSet, err := genResultSet(resp, testTimezone) + if err != nil { + t.Error(err) + } + + type Person struct { + Vid string `nebula:"_vid"` + Name string `nebula:"name"` + City string `nebula:"city"` + } + type Friend struct { + Src string `nebula:"_src"` + Dst string `nebula:"_dst"` + EdgeName string `nebula:"_name"` + CreatedAt string `nebula:"created_at"` + } + type Result struct { + Nodes []Person `nebula:"nodes"` + Edges []Friend `nebula:"relationships"` + } + + var results []Result + err = resultSet.Scan(&results) + if err != nil { + t.Error(err) + } + assert.Equal(t, 1, len(results)) + assert.NotEmpty(t, results[0].Nodes[0].Vid) + assert.Equal(t, "Tom", results[0].Nodes[0].Name) + assert.Equal(t, "Shanghai", results[0].Nodes[0].City) + assert.Equal(t, "Bob", results[0].Nodes[1].Name) + assert.Equal(t, "Hangzhou", results[0].Nodes[1].City) + assert.Equal(t, "2024-07-07", results[0].Edges[0].CreatedAt) + assert.Equal(t, "2024-07-07", results[0].Edges[1].CreatedAt) + assert.NotEmpty(t, results[0].Edges[0].Src) + assert.NotEmpty(t, results[0].Edges[0].Dst) + assert.Equal(t, "friend", results[0].Edges[0].EdgeName) + + // Scan again should work + err = resultSet.Scan(&results) + if err != nil { + t.Error(err) + } + assert.Equal(t, 2, len(results)) +} + +func TestScanWithNestStructPtr(t *testing.T) { + resp := &graph.ExecutionResponse{ + ErrorCode: nebula.ErrorCode_SUCCEEDED, + LatencyInUs: 1000, + Data: getNestDateset(), + SpaceName: []byte("test_space"), + ErrorMsg: []byte("test"), + PlanDesc: graph.NewPlanDescription(), + Comment: []byte("test_comment")} + resultSet, err := genResultSet(resp, testTimezone) + if err != nil { + t.Error(err) + } + + type Person struct { + Name string `nebula:"name"` + City string `nebula:"city"` + } + type Friend struct { + CreatedAt string `nebula:"created_at"` + } + type Result struct { + Nodes []*Person `nebula:"nodes"` + Edges []*Friend `nebula:"relationships"` + } + + var results []Result + err = resultSet.Scan(&results) + if err != nil { + t.Error(err) + } + assert.Equal(t, 1, len(results)) + assert.Equal(t, "Tom", results[0].Nodes[0].Name) + assert.Equal(t, "Shanghai", results[0].Nodes[0].City) + assert.Equal(t, "Bob", results[0].Nodes[1].Name) + assert.Equal(t, "Hangzhou", results[0].Nodes[1].City) + assert.Equal(t, "2024-07-07", results[0].Edges[0].CreatedAt) + assert.Equal(t, "2024-07-07", results[0].Edges[1].CreatedAt) + + // Scan again should work + err = resultSet.Scan(&results) + if err != nil { + t.Error(err) + } + assert.Equal(t, 2, len(results)) +} + +func TestScanWithStructPtr(t *testing.T) { + resp := &graph.ExecutionResponse{ + ErrorCode: nebula.ErrorCode_SUCCEEDED, + LatencyInUs: 1000, + Data: getNestDateset(), + SpaceName: []byte("test_space"), + ErrorMsg: []byte("test"), + PlanDesc: graph.NewPlanDescription(), + Comment: []byte("test_comment")} + resultSet, err := genResultSet(resp, testTimezone) + if err != nil { + t.Error(err) + } + + type Person struct { + Name string `nebula:"name"` + City string `nebula:"city"` + } + type Friend struct { + CreatedAt string `nebula:"created_at"` + } + type Result struct { + Nodes []*Person `nebula:"nodes"` + Edges []*Friend `nebula:"relationships"` + } + + var results []*Result + err = resultSet.Scan(&results) + if err != nil { + t.Error(err) + } + assert.Equal(t, 1, len(results)) + assert.Equal(t, "Tom", results[0].Nodes[0].Name) + assert.Equal(t, "Shanghai", results[0].Nodes[0].City) + assert.Equal(t, "Bob", results[0].Nodes[1].Name) + assert.Equal(t, "Hangzhou", results[0].Nodes[1].City) + assert.Equal(t, "2024-07-07", results[0].Edges[0].CreatedAt) + assert.Equal(t, "2024-07-07", results[0].Edges[1].CreatedAt) + + // Scan again should work + err = resultSet.Scan(&results) + if err != nil { + t.Error(err) + } + assert.Equal(t, 2, len(results)) +} + func TestIntVid(t *testing.T) { vertex := getVertexInt(101, 3, 5) node, err := genNode(vertex, testTimezone) @@ -1032,6 +1229,83 @@ func getDateset2() *nebula.DataSet { } } +func getNestDateset() *nebula.DataSet { + colNames := [][]byte{ + []byte("nodes"), + []byte("relationships"), + } + var list1 = nebula.NewValue() + list1.SetLVal(&nebula.NList{ + Values: []*nebula.Value{ + { + VVal: &nebula.Vertex{ + Vid: &nebula.Value{SVal: []byte("person_id_0")}, + Tags: []*nebula.Tag{ + { + Name: []byte("person"), + Props: map[string]*nebula.Value{ + "name": {SVal: []byte("Tom")}, + "city": {SVal: []byte("Shanghai")}, + }, + }, + }, + }, + }, + { + VVal: &nebula.Vertex{ + Vid: &nebula.Value{SVal: []byte("person_id_1")}, + Tags: []*nebula.Tag{ + { + Name: []byte("person"), + Props: map[string]*nebula.Value{ + "name": {SVal: []byte("Bob")}, + "city": {SVal: []byte("Hangzhou")}, + }, + }, + }, + }, + }, + }, + }) + + var list2 = nebula.NewValue() + list2.SetLVal(&nebula.NList{ + Values: []*nebula.Value{ + { + EVal: &nebula.Edge{ + Src: &nebula.Value{SVal: []byte("person_id_0")}, + Dst: &nebula.Value{SVal: []byte("person_id_1")}, + Name: []byte("friend"), + Props: map[string]*nebula.Value{ + "created_at": {SVal: []byte("2024-07-07")}, + }, + }, + }, + { + EVal: &nebula.Edge{ + Src: &nebula.Value{SVal: []byte("person_id_1")}, + Dst: &nebula.Value{SVal: []byte("person_id_0")}, + Name: []byte("friend"), + Props: map[string]*nebula.Value{ + "created_at": {SVal: []byte("2024-07-07")}, + }, + }, + }, + }, + }) + + valueList := []*nebula.Value{list1, list2} + var rows []*nebula.Row + row := &nebula.Row{ + Values: valueList, + } + rows = append(rows, row) + return &nebula.DataSet{ + ColumnNames: colNames, + Rows: rows, + } +} + func setIVal(ival int) *nebula.Value { var value = nebula.NewValue() newNum := new(int64)