diff --git a/server/storage_handler.go b/server/storage_handler.go index 0ee969cdd..7e2ace2c9 100644 --- a/server/storage_handler.go +++ b/server/storage_handler.go @@ -409,7 +409,6 @@ func (s *storageWriteServer) CreateWriteStream(ctx context.Context, req *storage TableSchema: schema, WriteMode: storagepb.WriteStream_INSERT, } - s.mu.Lock() s.streamMap[streamName] = &writeStreamStatus{ streamType: streamType, @@ -525,6 +524,7 @@ func (s *storageWriteServer) appendRows(req *storagepb.AppendRowsRequest, msgDes status.rows = append(status.rows, data...) } return s.sendResult(stream, streamName, offset+int64(len(rows))) + } func (s *storageWriteServer) sendResult(stream storagepb.BigQueryWrite_AppendRowsServer, streamName string, offset int64) error { @@ -677,10 +677,14 @@ func (s *storageWriteServer) insertTableData(ctx context.Context, tx *connection func (s *storageWriteServer) GetWriteStream(ctx context.Context, req *storagepb.GetWriteStreamRequest) (*storagepb.WriteStream, error) { s.mu.RLock() - defer s.mu.RUnlock() status, exists := s.streamMap[req.Name] + s.mu.RUnlock() if !exists { - return nil, fmt.Errorf("failed to find stream from %s", req.Name) + stream, err := s.createDefaultStream(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to find stream from %s", req.Name) + } + return stream, err } return status.stream, nil } @@ -775,6 +779,58 @@ func (s *storageWriteServer) FlushRows(ctx context.Context, req *storagepb.Flush }, nil } +/* +* +According to google documentation (https://pkg.go.dev/cloud.google.com/go/bigquery/storage/apiv1#BigQueryWriteClient.GetWriteStream) +every table has a special stream named ‘_default’ to which data can be written. 
This stream doesn't need to be created using CreateWriteStream + +Here we create the default stream and add it to the map in case it does not exist yet, the GetWriteStreamRequest given as second +argument should have Name in this format: projects/{project}/datasets/{dataset}/tables/{table}/streams/_default +*/ +func (s *storageWriteServer) createDefaultStream(ctx context.Context, req *storagepb.GetWriteStreamRequest) (*storagepb.WriteStream, error) { + streamId := req.Name + suffix := "_default" + streams := "/streams/" + if !strings.HasSuffix(streamId, suffix) { + return nil, fmt.Errorf("unexpected stream id: %s, expected '%s' suffix", streamId, suffix) + } + index := strings.LastIndex(streamId, streams) + if index == -1 { + return nil, fmt.Errorf("unexpected stream id: %s, expected containg '%s'", streamId, streams) + } + streamPart := streamId[:index] + writeStreamReq := &storagepb.CreateWriteStreamRequest{ + Parent: streamPart, + WriteStream: &storagepb.WriteStream{ + Type: storagepb.WriteStream_COMMITTED, + }, + } + stream, err := s.CreateWriteStream(ctx, writeStreamReq) + if err != nil { + return nil, err + } + projectID, datasetID, tableID, err := getIDsFromPath(streamPart) + if err != nil { + return nil, err + } + tableMetadata, err := getTableMetadata(ctx, s.server, projectID, datasetID, tableID) + if err != nil { + return nil, err + } + streamStatus := &writeStreamStatus{ + streamType: storagepb.WriteStream_COMMITTED, + stream: stream, + projectID: projectID, + datasetID: datasetID, + tableID: tableID, + tableMetadata: tableMetadata, + } + s.mu.Lock() + defer s.mu.Unlock() + s.streamMap[streamId] = streamStatus + return stream, nil +} + func getIDsFromPath(path string) (string, string, string, error) { paths := strings.Split(path, "/") if len(paths)%2 != 0 { diff --git a/server/storage_test.go b/server/storage_test.go index a58c4ea14..fa80bdf14 100644 --- a/server/storage_test.go +++ b/server/storage_test.go @@ -395,6 +395,7 @@ func TestStorageWrite(t *testing.T) { for _, test := range 
[]struct { name string streamType storagepb.WriteStream_Type + isDefaultStream bool expectedRowsAfterFirstWrite int expectedRowsAfterSecondWrite int expectedRowsAfterThirdWrite int @@ -416,6 +417,15 @@ func TestStorageWrite(t *testing.T) { expectedRowsAfterThirdWrite: 6, expectedRowsAfterExplicitCommit: 6, }, + { + name: "default", + streamType: storagepb.WriteStream_COMMITTED, + isDefaultStream: true, + expectedRowsAfterFirstWrite: 1, + expectedRowsAfterSecondWrite: 4, + expectedRowsAfterThirdWrite: 6, + expectedRowsAfterExplicitCommit: 6, + }, } { const ( projectID = "test" @@ -490,24 +500,36 @@ func TestStorageWrite(t *testing.T) { } defer client.Close() t.Run(test.name, func(t *testing.T) { - writeStream, err := client.CreateWriteStream(ctx, &storagepb.CreateWriteStreamRequest{ - Parent: fmt.Sprintf("projects/%s/datasets/%s/tables/%s", projectID, datasetID, tableID), - WriteStream: &storagepb.WriteStream{ - Type: test.streamType, - }, - }) - if err != nil { - t.Fatalf("CreateWriteStream: %v", err) + var writeStreamName string + fullTableName := fmt.Sprintf("projects/%s/datasets/%s/tables/%s", projectID, datasetID, tableID) + if !test.isDefaultStream { + writeStream, err := client.CreateWriteStream(ctx, &storagepb.CreateWriteStreamRequest{ + Parent: fullTableName, + WriteStream: &storagepb.WriteStream{ + Type: test.streamType, + }, + }) + if err != nil { + t.Fatalf("CreateWriteStream: %v", err) + } + writeStreamName = writeStream.GetName() } m := &exampleproto.SampleData{} descriptorProto, err := adapt.NormalizeDescriptor(m.ProtoReflect().Descriptor()) if err != nil { t.Fatalf("NormalizeDescriptor: %v", err) } + var writerOptions []managedwriter.WriterOption + if test.isDefaultStream { + writerOptions = append(writerOptions, managedwriter.WithType(managedwriter.DefaultStream)) + writerOptions = append(writerOptions, managedwriter.WithDestinationTable(fullTableName)) + } else { + writerOptions = append(writerOptions, managedwriter.WithStreamName(writeStreamName)) 
+ } + writerOptions = append(writerOptions, managedwriter.WithSchemaDescriptor(descriptorProto)) managedStream, err := client.NewManagedStream( ctx, - managedwriter.WithStreamName(writeStream.GetName()), - managedwriter.WithSchemaDescriptor(descriptorProto), + writerOptions..., ) if err != nil { t.Fatalf("NewManagedStream: %v", err)