diff --git a/encoder.go b/encoder.go index dbde3bc..3ceba28 100644 --- a/encoder.go +++ b/encoder.go @@ -58,6 +58,34 @@ func (e *Encoder) AddLE(src interface{}) error { return binary.Write(e.w, binary.LittleEndian, src) } +// insertLE serializes the passed value using little endian at at specified +// offset within the writer and restores the seek position without affecting +// WrittenBytes. +func (e *Encoder) insertLE(src interface{}, offset int64) (out error) { + if !e.wroteHeader { + return fmt.Errorf("cannot insert before header has been written") + } + originalOffset, err := e.w.Seek(0, io.SeekCurrent) + if err != nil { + return fmt.Errorf("failed to query current seek offset: %w", err) + } + _, err = e.w.Seek(offset, io.SeekStart) + if err != nil { + return fmt.Errorf("failed to seek to offset %d: %w", offset, err) + } + + // always restore the seek offset to the original position + defer func() { + _, deferErr := e.w.Seek(originalOffset, io.SeekStart) + if out == nil { + out = deferErr + } + }() + + out = binary.Write(e.w, binary.LittleEndian, src) + return out +} + // AddBE serializes and adds the passed value using big endian func (e *Encoder) AddBE(src interface{}) error { e.WrittenBytes += binary.Size(src) @@ -233,6 +261,36 @@ func (e *Encoder) writeMetadata() error { return e.AddBE(chunkData) } +// WriteCurrentSize updates the file and chunk headers with the length of the +// data written to the file so far to enable consumption of partially written +// audio mid-stream. However it is important to consume the data after calling +// WriteCurrentSize but before any further audio data is written. +// Additionally, metadata will not be present in the partial data. +func (e *Encoder) WriteCurrentSize() error { + // total size immediately follows the "RIFF" header, i.e. offset 4. + totalSizeOffset := binary.Size(riff.RiffID) + + // total size is the bytes written, less the "RIFF" header and the size + // value itself, i.e. minus 8. + var totalSizeValue uint32 + totalSizeValue = uint32(e.WrittenBytes - totalSizeOffset - binary.Size(totalSizeValue)) + + // go back and write total size in header + if err := e.insertLE(totalSizeValue, int64(totalSizeOffset)); err != nil { + return fmt.Errorf("%w when writing the total written bytes", err) + } + + // rewrite the audio chunk length header + if e.pcmChunkSizePos > 0 { + chunksize := uint32((int(e.BitDepth) / 8) * int(e.NumChans) * e.frames) + if err := e.insertLE(uint32(chunksize), int64(e.pcmChunkSizePos)); err != nil { + return fmt.Errorf("%w when writing wav data chunk size header", err) + } + } + + return nil +} + // Close flushes the content to disk, make sure the headers are up to date // Note that the underlying writer is NOT being closed. func (e *Encoder) Close() error { @@ -248,27 +306,12 @@ func (e *Encoder) Close() error { } } - // go back and write total size in header - if _, err := e.w.Seek(4, 0); err != nil { + if err := e.WriteCurrentSize(); err != nil { return err } - if err := e.AddLE(uint32(e.WrittenBytes) - 8); err != nil { - return fmt.Errorf("%w when writing the total written bytes", err) - } - - // rewrite the audio chunk length header - if e.pcmChunkSizePos > 0 { - if _, err := e.w.Seek(int64(e.pcmChunkSizePos), 0); err != nil { - return err - } - chunksize := uint32((int(e.BitDepth) / 8) * int(e.NumChans) * e.frames) - if err := e.AddLE(uint32(chunksize)); err != nil { - return fmt.Errorf("%w when writing wav data chunk size header", err) - } - } // jump back to the end of the file. - if _, err := e.w.Seek(0, 2); err != nil { + if _, err := e.w.Seek(0, io.SeekEnd); err != nil { return err } switch e.w.(type) { diff --git a/encoder_test.go b/encoder_test.go index d02d268..c68be33 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -1,9 +1,12 @@ package wav import ( + "math" "os" "path" + "path/filepath" "testing" + "time" ) func TestEncoderRoundTrip(t *testing.T) { @@ -156,3 +159,89 @@ func TestEncoderRoundTrip(t *testing.T) { }) } } + +func TestEncoder_UseBeforeClose(t *testing.T) { + testpath := filepath.Join(t.TempDir(), "test.wav") + f, err := os.Create(testpath) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + const sampleRateHz = 8000 + const bitDepth = 16 + const bitsPerByte = 8 + const bytesPerFrame = bitDepth / bitsPerByte + enc := NewEncoder(f, sampleRateHz, bitDepth, 1, 1) + + writeOneSecond := func() { + frameBuf := make([]byte, bytesPerFrame) + for i := 0; i < sampleRateHz; i++ { + enc.WriteFrame(frameBuf) + } + } + + sync := func() { + err = f.Sync() + if err != nil { + t.Fatal(err) + } + } + + getDuration := func() time.Duration { + r, err := os.Open(testpath) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + dec := NewDecoder(r) + dur, err := dec.Duration() + if err != nil { + t.Fatal(err) + } + return dur + } + + assertNotWithin := func(msg string, got, want, within float64) { + isWithin := math.Abs(got-want) <= within + if isWithin { + t.Errorf("%s: got %f, wanted not %f (within %f)", msg, got, want, within) + } + } + + assertWithin := func(msg string, got, want, within float64) { + isWithin := math.Abs(got-want) <= within + if !isWithin { + t.Errorf("%s: got %f, wanted %f (within %f)", msg, got, want, within) + } + } + + // duration should be undefined before explicit call to WriteCurrentSize or Close + writeOneSecond() + sync() + dur := getDuration() + assertNotWithin("before WriteCurrentSize()", dur.Seconds(), 1.0, 0.5) + + // duration should be correct after WriteCurrentSize + err = enc.WriteCurrentSize() + if err != nil { + t.Fatal(err) + } + dur = getDuration() + assertWithin("after WriteCurrentSize()", dur.Seconds(), 1.0, 0.5) + + // duration should be outdated after writing more audio data + writeOneSecond() + sync() + dur = getDuration() + assertWithin("after another second", dur.Seconds(), 1.0, 0.5) + + // duration should be correct after Close + err = enc.Close() + if err != nil { + t.Fatalf("Encoder.Close() = %v", err) + } + dur = getDuration() + assertWithin("after Encoder.Close()", dur.Seconds(), 2.0, 0.5) +}