Patch summary (text before the first "diff --git" header is ignored by git apply / patch):
  * Makefile: run the test target with -trimpath and -race.
  * compress/lz4_raw.go: build the lz4.CompressorHC inside Compress instead of
    sharing one instance created in init() across all calls.
  * compress/lz4_raw_test.go: call Compress from 10 goroutines. Failures are
    reported with t.Errorf because t.Fatalf must only be called from the
    goroutine running the test function.
  * writer/writer.go: take the lock around writes to and reads of errs, and
    defer wg.Done() inside the recovery closure so a recovered panic is stored
    in errs[index] before wg.Wait() can return.
  * writer/writer_test.go: add TestWriteStopRaceConditionOnError.
NOTE(review): leading whitespace and blank context lines were reconstructed
from the hunk line counts; the index hashes are those of the original commit.

diff --git a/Makefile b/Makefile
index b0222544..8b77828b 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
 PACKAGES=`go list ./... | grep -v example`
 
 test:
-	go test --count 1 -v -cover ${PACKAGES}
+	go test -trimpath -race --count 1 -v -cover ${PACKAGES}
 
 format:
 	go fmt github.com/xitongsys/parquet-go/...
diff --git a/compress/lz4_raw.go b/compress/lz4_raw.go
index a6fd23cb..aae28e21 100644
--- a/compress/lz4_raw.go
+++ b/compress/lz4_raw.go
@@ -9,11 +9,11 @@ import (
 )
 
 func init() {
-	lz4hc := lz4.CompressorHC{
-		Level: lz4.CompressionLevel(9),
-	}
 	compressors[parquet.CompressionCodec_LZ4_RAW] = &Compressor{
 		Compress: func(buf []byte) []byte {
+			lz4hc := lz4.CompressorHC{
+				Level: lz4.CompressionLevel(9),
+			}
 			res := make([]byte, lz4.CompressBlockBound(len(buf)))
 			count, _ := lz4hc.CompressBlock(buf, res)
 			return res[:count]
diff --git a/compress/lz4_raw_test.go b/compress/lz4_raw_test.go
index 16725b2a..6a91e693 100644
--- a/compress/lz4_raw_test.go
+++ b/compress/lz4_raw_test.go
@@ -2,6 +2,7 @@ package compress
 
 import (
 	"bytes"
+	"sync"
 	"testing"
 
 	"github.com/xitongsys/parquet-go/parquet"
@@ -15,10 +16,18 @@ func TestLz4RawCompress(t *testing.T) {
 	}
 
 	// compression
-	output := lz4RawCompressor.Compress(input)
-	if !bytes.Equal(compressed, output) {
-		t.Fatalf("expected output %s but was %s", string(compressed), string(output))
+	var wg sync.WaitGroup
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			output := lz4RawCompressor.Compress(input)
+			if !bytes.Equal(compressed, output) {
+				t.Errorf("expected output %s but was %s", string(compressed), string(output))
+			}
+		}()
 	}
+	wg.Wait()
 
 	// uncompression
 	output, err := lz4RawCompressor.Uncompress(compressed)
diff --git a/writer/writer.go b/writer/writer.go
index 3efaf5e4..12ea242b 100644
--- a/writer/writer.go
+++ b/writer/writer.go
@@ -290,6 +290,7 @@ func (pw *ParquetWriter) flushObjs() error {
 			defer func() {
-				wg.Done()
+				defer wg.Done()
 				if r := recover(); r != nil {
+					lock.Lock()
 					switch x := r.(type) {
 					case string:
 						errs[index] = errors.New(x)
@@ -298,6 +299,7 @@ func (pw *ParquetWriter) flushObjs() error {
 					default:
 						errs[index] = errors.New("unknown error")
 					}
+					lock.Unlock()
 				}
 			}()
 
@@ -338,12 +340,14 @@ func (pw *ParquetWriter) flushObjs() error {
 
 	wg.Wait()
 
+	lock.Lock()
 	for _, err2 := range errs {
 		if err2 != nil {
 			err = err2
 			break
 		}
 	}
+	lock.Unlock()
 
 	for _, pagesMap := range pagesMapList {
 		for name, pages := range pagesMap {
diff --git a/writer/writer_test.go b/writer/writer_test.go
index cee981e3..e943f62d 100644
--- a/writer/writer_test.go
+++ b/writer/writer_test.go
@@ -237,3 +237,16 @@ func TestNewWriterWithInvaidFile(t *testing.T) {
 	assert.Nil(t, pw)
 	assert.ErrorIs(t, err, testWriteErr)
 }
+
+func TestWriteStopRaceConditionOnError(t *testing.T) {
+	var buf bytes.Buffer
+	fw := writerfile.NewWriterFile(&buf)
+	pw, err := NewJSONWriter(`{"Tag":"name=parquet-go-root","Fields":[{"Tag":"name=x, type=INT64"}]}`, fw, 4)
+	assert.NoError(t, err)
+
+	for i := 0; i < 10; i++ {
+		entry := fmt.Sprintf(`{"not-x":%d}`, i)
+		assert.NoError(t, pw.Write(entry))
+	}
+	assert.Error(t, pw.WriteStop())
+}