diff --git a/finisher_api.go b/finisher_api.go index 6802945cc..793143d47 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -19,10 +19,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { if db.CreateBatchSize > 0 { return db.CreateInBatches(value, db.CreateBatchSize) } - - tx = db.getInstance() - tx.Statement.Dest = value - return tx.callbacks.Create().Execute(tx) + return db.create(value) } // CreateInBatches inserts value in batches of batchSize @@ -63,12 +60,15 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { tx.RowsAffected = rowsAffected default: - tx = db.getInstance() - tx.Statement.Dest = value - tx = tx.callbacks.Create().Execute(tx) + db.create(value) } return } +func (db *DB) create(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = value + return tx.callbacks.Create().Execute(tx) +} // Save updates value in database. If value doesn't contain a matching primary key, value is inserted. func (db *DB) Save(value interface{}) (tx *DB) { diff --git a/scan.go b/scan.go index d852c2c9f..ec78c1d67 100644 --- a/scan.go +++ b/scan.go @@ -202,6 +202,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValueType.Kind() { case reflect.Array, reflect.Slice: reflectValueType = reflectValueType.Elem() + if reflectValueType.Kind() == reflect.Interface && reflectValue.Len() > 0 { + reflectValueType = reflect.Indirect(reflectValue.Index(0)).Elem().Type() + } } isPtr := reflectValueType.Kind() == reflect.Ptr if isPtr { @@ -318,7 +321,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } else { elem = reflect.New(reflectValueType) } - + if elem.Type().Kind() == reflect.Interface { + elem = elem.Elem() + } db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { diff --git a/schema/schema.go b/schema/schema.go index db2367975..17fdc317f 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -136,8 +136,10 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() + if modelType.Kind() == reflect.Interface && value.Len() > 0 { + modelType = reflect.Indirect(value.Index(0)).Elem().Type() + } } - if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)