// ErrCountRows defines an error reported when the number of rows in a
// worksheet cannot be determined. It wraps the underlying cause, which is
// exposed through Unwrap so callers can inspect it with errors.Is/errors.As.
type ErrCountRows struct {
	err error
}

// Error implements the error interface. The message is the fixed prefix
// "wrong count rows: " followed by the text of the wrapped cause.
func (err ErrCountRows) Error() string {
	// %v calls Error() on a non-nil cause and prints "<nil>" for a nil one,
	// so a zero-value ErrCountRows no longer panics here.
	return fmt.Sprintf("wrong count rows: %v", err.err)
}

// Unwrap returns the wrapped cause, or nil when none was recorded.
func (err ErrCountRows) Unwrap() error {
	return err.err
}
return -1, ErrCountRows{fmt.Errorf("not found row number")} + } + + countStr := string(buff[:index]) + + return strconv.ParseInt(countStr, 10, 64) + } +} + // Rows returns a rows iterator, used for streaming reading data for a // worksheet with a large data. This function is concurrency safe. For // example: @@ -326,19 +409,38 @@ func (f *File) getFromStringItem(index int) string { return f.getFromStringItem(index) } -// xmlDecoder creates XML decoder by given path in the zip from memory data +type ReaderContent interface { + io.Reader + io.ReaderAt +} + +// contentReader returns reader by given path in the zip from memory data // or system temporary file. -func (f *File) xmlDecoder(name string) (bool, *xml.Decoder, *os.File, error) { +func (f *File) contentReader(name string) (bool, ReaderContent, *os.File, int64, error) { var ( content []byte err error tempFile *os.File ) if content = f.readXML(name); len(content) > 0 { - return false, f.xmlNewDecoder(bytes.NewReader(content)), tempFile, err + return false, bytes.NewReader(content), tempFile, int64(len(content)), err } + tempFile, err = f.readTemp(name) - return true, f.xmlNewDecoder(tempFile), tempFile, err + + fileStat, err := tempFile.Stat() + if err != nil { + return true, tempFile, tempFile, 0, fmt.Errorf("failed to get file stat: %w", err) + } + + return true, tempFile, tempFile, fileStat.Size(), err +} + +// xmlDecoder creates XML decoder by given path in the zip from memory data +// or system temporary file. +func (f *File) xmlDecoder(name string) (bool, *xml.Decoder, *os.File, error) { + needClose, reader, tempFile, _, err := f.contentReader(name) + return needClose, f.xmlNewDecoder(reader), tempFile, err } // SetRowHeight provides a function to set the height of a single row. 
For diff --git a/rows_test.go b/rows_test.go index 2e49c2877b..3a49c9c51e 100644 --- a/rows_test.go +++ b/rows_test.go @@ -1113,3 +1113,81 @@ func trimSliceSpace(s []string) []string { } return s } + +func TestFile_CountRows(t *testing.T) { + type fields struct { + filename string + } + tests := []struct { + name string + fields fields + want int64 + wantErr assert.ErrorAssertionFunc + }{{ + name: "BadWorkbook.xlsx", + fields: fields{filename: filepath.Join("test", "BadWorkbook.xlsx")}, + want: -1, + wantErr: func(t assert.TestingT, err error, _ ...interface{}) bool { + return assert.Error(t, err) + }, + }, { + name: "Book1.xlsx", + fields: fields{filename: filepath.Join("test", "Book1.xlsx")}, + want: 22, + wantErr: func(t assert.TestingT, err error, _ ...interface{}) bool { + return assert.NoError(t, err) + }, + }, { + name: "CalcChain.xlsx", + fields: fields{filename: filepath.Join("test", "CalcChain.xlsx")}, + want: 1, + wantErr: func(t assert.TestingT, err error, _ ...interface{}) bool { + return assert.NoError(t, err) + }, + }, { + name: "SharedStrings.xlsx", + fields: fields{filename: filepath.Join("test", "SharedStrings.xlsx")}, + want: 1, + wantErr: func(t assert.TestingT, err error, _ ...interface{}) bool { + return assert.NoError(t, err) + }, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := OpenFile(tt.fields.filename) + assert.NoError(t, err) + defer f.Close() + firstSheet := f.GetSheetName(0) + got, err := f.CountRows(firstSheet) + if !tt.wantErr(t, err, "CountRows") { + return + } + assert.Equal(t, tt.want, got, "CountRows") + }) + } +} + +func BenchmarkFile_GetRows_Old(b *testing.B) { + for i := 0; i < b.N; i++ { + f, _ := OpenFile(filepath.Join("test", "Book1.xlsx")) + defer f.Close() + + firstSheet := f.GetSheetName(0) + count := 0 + rows, _ := f.GetRows(firstSheet) + for range rows { + count++ + } + } +} + +func BenchmarkFile_GetRows_New(b *testing.B) { + for i := 0; i < b.N; i++ { + f, _ := 
OpenFile(filepath.Join("test", "Book1.xlsx")) + defer f.Close() + + firstSheet := f.GetSheetName(0) + _, err := f.CountRows(firstSheet) + assert.NoError(b, err) + } +}