diff --git a/README.md b/README.md index def4b75..5318fb7 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ * [remove](#remove) Remove columns. * [rename](#rename) Rename columns. * [replace](#replace) Replace values. +* [sort](#sort) Sort rows. * [slice](#slice) Slice specified range of rows. * [transform](#transform) Transform format. * [unique](#unique) Extract unique rows. @@ -813,6 +814,89 @@ Please refer to the following for the syntax of regular expressions. * https://golang.org/pkg/regexp/syntax/ +## sort + +Creates a new CSV file from the input CSV file by sorting by the values in the specified columns. + +### Usage + +``` +csvt sort -i INPUT -c COLUMN1 ... [--desc] [--number] -o OUTPUT [--usingfile] +``` + +``` +Usage: + csvt sort [flags] + +Flags: + -i, --input string Input CSV file path. + -c, --column stringArray Name of the column to use for sorting. + --desc (optional) Sort in descending order. The default is ascending order. + --number (optional) Sorts as a number. The default is to sort as a string. + -o, --output string Output CSV file path. + --usingfile (optional) Use temporary files for sorting. Use this when sorting large files that will not fit in memory. + -h, --help help for sort +``` + +### Example + +The contents of `input.csv`. + +``` +col1,col2 +02,a +10,b +01,a +11,c +20,b +``` + +Sort by "col1". + +``` +$ csvt sort -i input.csv -c col1 -o output.csv +``` + +The contents of the created `output.csv`. + +``` +col1,col2 +01,a +02,a +10,b +11,c +20,b +``` + +By default, it is sorted as a string. +For example, the output could look like this: + +``` +col1 +1 +12 +123 +2 +21 +3 +``` + +If you want to sort as a number, specify `--number`. + +``` +$ csvt sort -i input.csv -c col1 --number -o output.csv +``` + +``` +col1 +1 +2 +3 +12 +21 +123 +``` + ## slice Create a new CSV file by slicing the specified range of rows from the input CSV file. 
diff --git a/cmd/root.go b/cmd/root.go index 30e7179..1e7aef9 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -41,6 +41,7 @@ func newRootCmd() *cobra.Command { rootCmd.AddCommand(newConcatCmd()) rootCmd.AddCommand(newSliceCmd()) rootCmd.AddCommand(newAddCmd()) + rootCmd.AddCommand(newSortCmd()) for _, c := range rootCmd.Commands() { // フラグ以外は受け付けないように diff --git a/cmd/sort.go b/cmd/sort.go new file mode 100644 index 0000000..9274f76 --- /dev/null +++ b/cmd/sort.go @@ -0,0 +1,123 @@ +package cmd + +import ( + "github.com/onozaty/csvt/csv" + "github.com/spf13/cobra" +) + +func newSortCmd() *cobra.Command { + + sortCmd := &cobra.Command{ + Use: "sort", + Short: "Sort rows", + RunE: func(cmd *cobra.Command, args []string) error { + + format, err := getFlagBaseCsvFormat(cmd.Flags()) + if err != nil { + return err + } + + inputPath, _ := cmd.Flags().GetString("input") + targetColumnNames, _ := cmd.Flags().GetStringArray("column") + sortDescending, _ := cmd.Flags().GetBool("desc") + asNumber, _ := cmd.Flags().GetBool("number") + useFileRows, _ := cmd.Flags().GetBool("usingfile") + outputPath, _ := cmd.Flags().GetString("output") + + // 引数の解析に成功した時点で、エラーが起きてもUsageは表示しない + cmd.SilenceUsage = true + + return runSort( + format, + inputPath, + targetColumnNames, + outputPath, + SortOptions{ + sortDescending: sortDescending, + asNumber: asNumber, + useFileRows: useFileRows, + }) + }, + } + + sortCmd.Flags().StringP("input", "i", "", "Input CSV file path.") + sortCmd.MarkFlagRequired("input") + sortCmd.Flags().StringArrayP("column", "c", []string{}, "Name of the column to use for sorting.") + sortCmd.MarkFlagRequired("column") + sortCmd.Flags().BoolP("desc", "", false, "(optional) Sort in descending order. The default is ascending order.") + sortCmd.Flags().BoolP("number", "", false, "(optional) Sorts as a number. 
The default is to sort as a string.") + sortCmd.Flags().StringP("output", "o", "", "Output CSV file path.") + sortCmd.MarkFlagRequired("output") + sortCmd.Flags().BoolP("usingfile", "", false, "(optional) Use temporary files for sorting. Use this when sorting large files that will not fit in memory.") + + return sortCmd +} + +type SortOptions struct { + sortDescending bool + asNumber bool + useFileRows bool +} + +func runSort(format csv.Format, inputPath string, targetColumnNames []string, outputPath string, options SortOptions) error { + + reader, writer, close, err := setupInputOutput(inputPath, outputPath, format) + if err != nil { + return err + } + defer close() + + err = sort(reader, targetColumnNames, writer, options) + + if err != nil { + return err + } + + return writer.Flush() +} + +func sort(reader csv.CsvReader, targetColumnNames []string, writer csv.CsvWriter, options SortOptions) error { + + var compare func(item1 string, item2 string) (int, error) + + if options.asNumber { + compare = csv.CompareNumber + } else { + compare = csv.CompareString + } + + if options.sortDescending { + compare = csv.Descending(compare) + } + + var sortedRows csv.CsvSortedRows + var err error + if options.useFileRows { + sortedRows, err = csv.LoadCsvFileSortedRows(reader, targetColumnNames, compare) + } else { + sortedRows, err = csv.LoadCsvMemorySortedRows(reader, targetColumnNames, compare) + } + if err != nil { + return err + } + + err = writer.Write(sortedRows.ColumnNames()) + if err != nil { + return err + } + + for i := 0; i < sortedRows.Count(); i++ { + + row, err := sortedRows.Row(i) + if err != nil { + return err + } + + err = writer.Write(row) + if err != nil { + return err + } + } + + return nil +} diff --git a/cmd/sort_test.go b/cmd/sort_test.go new file mode 100644 index 0000000..375d373 --- /dev/null +++ b/cmd/sort_test.go @@ -0,0 +1,486 @@ +package cmd + +import ( + "io" + "os" + "testing" +) + +func TestSortCmd(t *testing.T) { + + s := joinRows( + 
"col1,col2", + "2,a", + "1,b", + "4,c", + "3,d", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col1", + }) + + err := rootCmd.Execute() + if err != nil { + t.Fatal("failed test\n", err) + } + + result := readString(t, fo) + + expect := joinRows( + "col1,col2", + "1,b", + "2,a", + "3,d", + "4,c", + ) + + if result != expect { + t.Fatal("failed test\n", result) + } +} + +func TestSortCmd_format(t *testing.T) { + + s := joinRows( + "col1\tcol2", + "2\ta", + "1\tb", + "4\tc", + "3\td", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col1", + "--delim", `\t`, + }) + + err := rootCmd.Execute() + if err != nil { + t.Fatal("failed test\n", err) + } + + result := readString(t, fo) + + expect := joinRows( + "col1\tcol2", + "1\tb", + "2\ta", + "3\td", + "4\tc", + ) + + if result != expect { + t.Fatal("failed test\n", result) + } +} +func TestSortCmd_multiColumn(t *testing.T) { + + s := joinRows( + "col1,col2", + "2,b", + "1,b", + "2,a", + "3,a", + "1,c", + "1,a", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col1", + "-c", "col2", + }) + + err := rootCmd.Execute() + if err != nil { + t.Fatal("failed test\n", err) + } + + result := readString(t, fo) + + expect := joinRows( + "col1,col2", + "1,a", + "1,b", + "1,c", + "2,a", + "2,b", + "3,a", + ) + + if result != expect { + t.Fatal("failed test\n", result) + } +} + +func TestSortCmd_number(t *testing.T) { + + s := joinRows( + "col1,col2", + "100,b", + "1,b", + "11,a", + "2,a", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := 
createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col1", + "--number", + }) + + err := rootCmd.Execute() + if err != nil { + t.Fatal("failed test\n", err) + } + + result := readString(t, fo) + + expect := joinRows( + "col1,col2", + "1,b", + "2,a", + "11,a", + "100,b", + ) + + if result != expect { + t.Fatal("failed test\n", result) + } +} + +func TestSortCmd_desc(t *testing.T) { + + s := joinRows( + "col1,col2", + "2,a", + "1,b", + "4,c", + "3,d", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col1", + "--desc", + }) + + err := rootCmd.Execute() + if err != nil { + t.Fatal("failed test\n", err) + } + + result := readString(t, fo) + + expect := joinRows( + "col1,col2", + "4,c", + "3,d", + "2,a", + "1,b", + ) + + if result != expect { + t.Fatal("failed test\n", result) + } +} + +func TestSortCmd_multiColumn_number_desc(t *testing.T) { + + s := joinRows( + "col1,col2", + "1,100", + "2,10", + "11,10", + "4,2", + "5,1", + "10,10", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col2", + "-c", "col1", + "--number", + "--desc", + }) + + err := rootCmd.Execute() + if err != nil { + t.Fatal("failed test\n", err) + } + + result := readString(t, fo) + + expect := joinRows( + "col1,col2", + "1,100", + "11,10", + "10,10", + "2,10", + "4,2", + "5,1", + ) + + if result != expect { + t.Fatal("failed test\n", result) + } +} + +func TestSortCmd_usingfile(t *testing.T) { + + s := joinRows( + "col1,col2", + "2,a", + "1,b", + "4,c", + "3,d", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := 
newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col1", + "--usingfile", + }) + + err := rootCmd.Execute() + if err != nil { + t.Fatal("failed test\n", err) + } + + result := readString(t, fo) + + expect := joinRows( + "col1,col2", + "1,b", + "2,a", + "3,d", + "4,c", + ) + + if result != expect { + t.Fatal("failed test\n", result) + } +} + +func TestSortCmd_multiColumn_number_desc_usingfile(t *testing.T) { + + s := joinRows( + "col1,col2", + "1,100", + "2,10", + "11,10", + "4,2", + "5,1", + "10,10", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col2", + "-c", "col1", + "--number", + "--desc", + "--usingfile", + }) + + err := rootCmd.Execute() + if err != nil { + t.Fatal("failed test\n", err) + } + + result := readString(t, fo) + + expect := joinRows( + "col1,col2", + "1,100", + "11,10", + "10,10", + "2,10", + "4,2", + "5,1", + ) + + if result != expect { + t.Fatal("failed test\n", result) + } +} + +func TestSortCmd_invalidFormat(t *testing.T) { + + s := joinRows( + "col1,col2", + "1,1", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col1", + "--delim", "xx", + }) + + err := rootCmd.Execute() + if err == nil || err.Error() != "flag delim should be specified with a single character" { + t.Fatal("failed test\n", err) + } +} + +func TestSortCmd_columnNotFound(t *testing.T) { + + s := joinRows( + "col1,col2", + "1,1", + ) + + fi := createTempFile(t, s) + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col3", + }) + + err := rootCmd.Execute() + if err == nil || err.Error() != "col3 is not 
found" { + t.Fatal("failed test\n", err) + } +} + +func TestSortCmd_empty(t *testing.T) { + + fi := createTempFile(t, "") + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi, + "-o", fo, + "-c", "col1", + }) + + err := rootCmd.Execute() + if err != io.EOF { + t.Fatal("failed test\n", err) + } +} + +func TestSortCmd_inputFileNotFound(t *testing.T) { + + fi := createTempFile(t, "") + defer os.Remove(fi) + + fo := createTempFile(t, "") + defer os.Remove(fo) + + rootCmd := newRootCmd() + rootCmd.SetArgs([]string{ + "sort", + "-i", fi + "____", // 存在しないファイル + "-o", fo, + "-c", "col1", + }) + + err := rootCmd.Execute() + if err == nil { + t.Fatal("failed test\n", err) + } + + pathErr := err.(*os.PathError) + if pathErr.Path != fi+"____" || pathErr.Op != "open" { + t.Fatal("failed test\n", err) + } +} diff --git a/csv/rows.go b/csv/rows.go new file mode 100644 index 0000000..af655f4 --- /dev/null +++ b/csv/rows.go @@ -0,0 +1,344 @@ +package csv + +import ( + "encoding/json" + "fmt" + "io" + "os" + "sort" + "strconv" + + "github.com/boltdb/bolt" + "github.com/onozaty/csvt/util" +) + +type CsvSortedRows interface { + Count() int + ColumnNames() []string + Row(index int) ([]string, error) + Close() error +} + +type memorySortedRows struct { + rows [][]string + columnNames []string +} + +func (t *memorySortedRows) Count() int { + + return len(t.rows) +} + +func (t *memorySortedRows) Row(index int) ([]string, error) { + + return t.rows[index], nil +} + +func (t *memorySortedRows) ColumnNames() []string { + + return t.columnNames +} + +func (t *memorySortedRows) Close() error { + + // リソースは保持しないので何もしない + return nil +} + +func LoadCsvMemorySortedRows(reader CsvReader, useColumnNames []string, compare func(item1 string, item2 string) (int, error)) (CsvSortedRows, error) { + + allColumnNames, err := reader.Read() + if err != nil { + return nil, err + } + + useColumnIndexes := 
[]int{} + for _, useColumnName := range useColumnNames { + + useColumnIndex := util.IndexOf(allColumnNames, useColumnName) + if useColumnIndex == -1 { + return nil, fmt.Errorf("%s is not found", useColumnName) + } + + useColumnIndexes = append(useColumnIndexes, useColumnIndex) + } + + rows := [][]string{} + for { + row, err := reader.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + + rows = append(rows, row) + } + + var sortError error + // ソート + sort.SliceStable(rows, func(i, j int) bool { + + if sortError != nil { + // エラーが起きているときは以降の比較は行わない + return false + } + + n := 0 + + for _, useColumnIndex := range useColumnIndexes { + + n, sortError = compare(rows[i][useColumnIndex], rows[j][useColumnIndex]) + if sortError != nil { + return false + } + + if n != 0 { + break + } + } + + return n < 0 + }) + + if sortError != nil { + return nil, sortError + } + + return &memorySortedRows{ + rows: rows, + columnNames: allColumnNames, + }, nil +} + +type fileSortedRows struct { + sortedIndexies []int + columnNames []string + dbPath string + db *bolt.DB +} + +func (t *fileSortedRows) Count() int { + return len(t.sortedIndexies) +} + +func (t *fileSortedRows) Row(index int) ([]string, error) { + + // 既にDBを開いている場合は、使いまわす + // (CsvTableのClose時に閉じている) + if t.db == nil { + db, err := bolt.Open(t.dbPath, 0600, nil) + if err != nil { + return nil, err + } + t.db = db + } + + row := make([]string, 0) + + err := t.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("csvRows")) + + v := b.Get([]byte(strconv.Itoa(t.sortedIndexies[index]))) + if v != nil { + json.Unmarshal(v, &row) + } + + return nil + }) + + if err != nil { + return nil, err + } + + if len(row) == 0 { + return nil, nil + } + + return row, nil +} + +func (t *fileSortedRows) ColumnNames() []string { + + return t.columnNames +} + +func (t *fileSortedRows) Close() error { + + if t.db != nil { + err := t.db.Close() + if err != nil { + return err + } + } + + return os.Remove(t.dbPath) +} + 
+type SortSource struct { + index int + items []string +} + +func LoadCsvFileSortedRows(reader CsvReader, useColumnNames []string, compare func(item1 string, item2 string) (int, error)) (CsvSortedRows, error) { + + allColumnNames, err := reader.Read() + if err != nil { + return nil, err + } + + useColumnIndexes := []int{} + for _, useColumnName := range useColumnNames { + + useColumnIndex := util.IndexOf(allColumnNames, useColumnName) + if useColumnIndex == -1 { + return nil, fmt.Errorf("%s is not found", useColumnName) + } + + useColumnIndexes = append(useColumnIndexes, useColumnIndex) + } + + dbFile, err := os.CreateTemp("", "csvdb") + if err != nil { + return nil, err + } + defer dbFile.Close() + + db, err := bolt.Open(dbFile.Name(), 0600, nil) + if err != nil { + return nil, err + } + defer db.Close() + + sortSources := []SortSource{} + rowIndex := 0 + eof := false + + for !eof { + + err = db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("csvRows")) + if err != nil { + return err + } + + // 1トランザクションで大量の書き込みを行うと速度が落ちるため + // 分割してコミットを行う + for i := 0; i < 10000; i++ { + row, err := reader.Read() + if err == io.EOF { + eof = true + break + } + if err != nil { + return err + } + + items := []string{} + for _, useColumnIndex := range useColumnIndexes { + items = append(items, row[useColumnIndex]) + } + + sortSources = append(sortSources, SortSource{ + index: rowIndex, + items: items, + }) + + rowJson, err := json.Marshal(row) + if err != nil { + return err + } + + err = b.Put([]byte(strconv.Itoa(rowIndex)), []byte(rowJson)) + if err != nil { + return err + } + + rowIndex++ + } + + return nil + }) + + if err != nil { + return nil, err + } + } + + var sortError error + // ソート + sort.SliceStable(sortSources, func(i, j int) bool { + + if sortError != nil { + // エラーが起きているときは以降の比較は行わない + return false + } + + n := 0 + + for index := range useColumnIndexes { + + n, sortError = compare(sortSources[i].items[index], 
sortSources[j].items[index]) + if sortError != nil { + return false + } + + if n != 0 { + break + } + } + + return n < 0 + }) + + if sortError != nil { + return nil, sortError + } + + sortedIndexies := []int{} + for _, sortSource := range sortSources { + sortedIndexies = append(sortedIndexies, sortSource.index) + } + + return &fileSortedRows{ + sortedIndexies: sortedIndexies, + columnNames: allColumnNames, + dbPath: dbFile.Name(), + }, nil +} + +func CompareString(item1 string, item2 string) (int, error) { + if item1 == item2 { + return 0, nil + } + if item1 < item2 { + return -1, nil + } + return 1, nil +} + +func CompareNumber(item1 string, item2 string) (int, error) { + + num1, err := strconv.Atoi(item1) + if err != nil { + return 0, err + } + + num2, err := strconv.Atoi(item2) + if err != nil { + return 0, err + } + + return num1 - num2, nil +} + +func Descending(compare func(item1 string, item2 string) (int, error)) func(item1 string, item2 string) (int, error) { + + return func(item1 string, item2 string) (int, error) { + n, err := compare(item1, item2) + + // 結果を反転させる + return n * -1, err + } +} diff --git a/csv/rows_test.go b/csv/rows_test.go new file mode 100644 index 0000000..d2b8dc5 --- /dev/null +++ b/csv/rows_test.go @@ -0,0 +1,521 @@ +package csv + +import ( + "io" + "reflect" + "strconv" + "strings" + "testing" +) + +// Memory +func TestLoadCsvMemorySortedRows(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"2", "b"}, + []string{"5", "d"}, + []string{"1", "c"}, + []string{"3", "a"}, + []string{"4", "e"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + rows, err := LoadCsvMemorySortedRows(r, []string{"col1"}, CompareString) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != 5 { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1", "col2"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + assertRows(t, 
rows, + []string{"1", "c"}, + []string{"2", "b"}, + []string{"3", "a"}, + []string{"4", "e"}, + []string{"5", "d"}, + ) +} + +func TestLoadCsvMemorySortedRows_multiColumn(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"1", "c"}, + []string{"2", "a"}, + []string{"1", "a"}, + []string{"2", "b"}, + []string{"1", "b"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + rows, err := LoadCsvMemorySortedRows(r, []string{"col1", "col2"}, CompareString) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != 5 { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1", "col2"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + assertRows(t, rows, + []string{"1", "a"}, + []string{"1", "b"}, + []string{"1", "c"}, + []string{"2", "a"}, + []string{"2", "b"}, + ) +} + +func TestLoadCsvMemorySortedRows_num(t *testing.T) { + + s := joinRows( + []string{"col1"}, + []string{"10"}, + []string{"2"}, + []string{"9"}, + []string{"123"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + rows, err := LoadCsvMemorySortedRows(r, []string{"col1"}, CompareNumber) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != 4 { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + assertRows(t, rows, + []string{"2"}, + []string{"9"}, + []string{"10"}, + []string{"123"}, + ) +} + +func TestLoadCsvMemorySortedRows_same(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"1", "3"}, + []string{"2", "1"}, + []string{"1", "1"}, + []string{"2", "2"}, + []string{"1", "2"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + // col1だけ指定して同じ値がどうなるか確認 + rows, err := LoadCsvMemorySortedRows(r, []string{"col1"}, CompareString) + + if err != nil { + t.Fatal("failed test\n", err) + 
} + defer rows.Close() + + if rows.Count() != 5 { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1", "col2"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + assertRows(t, rows, + []string{"1", "3"}, + []string{"1", "1"}, + []string{"1", "2"}, + []string{"2", "1"}, + []string{"2", "2"}, + ) +} + +func TestLoadCsvMemorySortedRows_empty(t *testing.T) { + + r := NewCsvReader(strings.NewReader(""), Format{}) + + _, err := LoadCsvMemorySortedRows(r, []string{"col1"}, CompareString) + + if err != io.EOF { + t.Fatal("failed test\n", err) + } +} + +func TestLoadCsvMemorySortedRows_columnNotFound(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"1", "3"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + _, err := LoadCsvMemorySortedRows(r, []string{"col1", "col3"}, CompareString) + + if err == nil || err.Error() != "col3 is not found" { + t.Fatal("failed test\n", err) + } +} + +func TestLoadCsvMemorySortedRows_invalidNumber(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"1", "1"}, + []string{"a", "2"}, // 数字じゃない + []string{"3", "3"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + _, err := LoadCsvMemorySortedRows(r, []string{"col1"}, CompareNumber) + + if err == nil || err.Error() != `strconv.Atoi: parsing "a": invalid syntax` { + t.Fatal("failed test\n", err) + } +} + +func TestLoadCsvMemorySortedRows_big(t *testing.T) { + + const maxId = 100000 + + s := [maxId + 1]string{} + s[0] = "col1,col2" + for i := 1; i <= maxId; i++ { + s[i] = strconv.Itoa(i) + "," + strconv.Itoa(maxId-i) + } + + r := NewCsvReader(strings.NewReader(strings.Join(s[:], "\n")), Format{}) + + rows, err := LoadCsvMemorySortedRows(r, []string{"col2"}, CompareString) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != maxId { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), 
[]string{"col1", "col2"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + // 先頭と末尾を確認 + { + row, err := rows.Row(0) // 先頭 + if err != nil { + t.Fatal("failed test\n", err) + } + + if !reflect.DeepEqual(row, []string{strconv.Itoa(maxId), "0"}) { + t.Fatal("failed test\n", row) + } + } + { + row, err := rows.Row(maxId - 1) // 末尾 + if err != nil { + t.Fatal("failed test\n", err) + } + + if !reflect.DeepEqual(row, []string{"1", strconv.Itoa(maxId - 1)}) { + t.Fatal("failed test\n", row) + } + } +} + +// File +func TestLoadCsvFileSortedRows(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"2", "b"}, + []string{"5", "d"}, + []string{"1", "c"}, + []string{"3", "a"}, + []string{"4", "e"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + rows, err := LoadCsvFileSortedRows(r, []string{"col1"}, CompareString) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != 5 { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1", "col2"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + assertRows(t, rows, + []string{"1", "c"}, + []string{"2", "b"}, + []string{"3", "a"}, + []string{"4", "e"}, + []string{"5", "d"}, + ) +} + +func TestLoadCsvFileSortedRows_multiColumn(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"1", "c"}, + []string{"2", "a"}, + []string{"1", "a"}, + []string{"2", "b"}, + []string{"1", "b"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + rows, err := LoadCsvFileSortedRows(r, []string{"col1", "col2"}, CompareString) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != 5 { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1", "col2"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + assertRows(t, rows, + []string{"1", "a"}, + []string{"1", "b"}, + []string{"1", "c"}, + 
[]string{"2", "a"}, + []string{"2", "b"}, + ) +} + +func TestLoadCsvFileSortedRows_num(t *testing.T) { + + s := joinRows( + []string{"col1"}, + []string{"10"}, + []string{"2"}, + []string{"9"}, + []string{"123"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + rows, err := LoadCsvFileSortedRows(r, []string{"col1"}, CompareNumber) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != 4 { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + assertRows(t, rows, + []string{"2"}, + []string{"9"}, + []string{"10"}, + []string{"123"}, + ) +} + +func TestLoadCsvFileSortedRows_same(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"1", "3"}, + []string{"2", "1"}, + []string{"1", "1"}, + []string{"2", "2"}, + []string{"1", "2"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + // col1だけ指定して同じ値がどうなるか確認 + rows, err := LoadCsvFileSortedRows(r, []string{"col1"}, CompareString) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != 5 { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1", "col2"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + assertRows(t, rows, + []string{"1", "3"}, + []string{"1", "1"}, + []string{"1", "2"}, + []string{"2", "1"}, + []string{"2", "2"}, + ) +} + +func TestLoadCsvFileSortedRows_empty(t *testing.T) { + + r := NewCsvReader(strings.NewReader(""), Format{}) + + _, err := LoadCsvFileSortedRows(r, []string{"col1"}, CompareString) + + if err != io.EOF { + t.Fatal("failed test\n", err) + } +} + +func TestLoadCsvFileSortedRows_columnNotFound(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"1", "3"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + _, err := LoadCsvFileSortedRows(r, []string{"col1", 
"col3"}, CompareString) + + if err == nil || err.Error() != "col3 is not found" { + t.Fatal("failed test\n", err) + } +} + +func TestLoadCsvFileSortedRows_invalidNumber(t *testing.T) { + + s := joinRows( + []string{"col1", "col2"}, + []string{"1", "1"}, + []string{"a", "2"}, // 数字じゃない + []string{"3", "3"}, + ) + + r := NewCsvReader(strings.NewReader(s), Format{}) + + _, err := LoadCsvFileSortedRows(r, []string{"col1"}, CompareNumber) + + if err == nil || err.Error() != `strconv.Atoi: parsing "a": invalid syntax` { + t.Fatal("failed test\n", err) + } +} + +func TestLoadCsvFileSortedRows_big(t *testing.T) { + + const maxId = 100000 + + s := [maxId + 1]string{} + s[0] = "col1,col2" + for i := 1; i <= maxId; i++ { + s[i] = strconv.Itoa(i) + "," + strconv.Itoa(maxId-i) + } + + r := NewCsvReader(strings.NewReader(strings.Join(s[:], "\n")), Format{}) + + rows, err := LoadCsvFileSortedRows(r, []string{"col2"}, CompareString) + + if err != nil { + t.Fatal("failed test\n", err) + } + defer rows.Close() + + if rows.Count() != maxId { + t.Fatal("failed test\n", rows.Count()) + } + + if reflect.DeepEqual(rows.ColumnNames(), []string{"col1", "col2"}) { + t.Fatal("failed test\n", rows.ColumnNames()) + } + + // 先頭と末尾を確認 + { + row, err := rows.Row(0) // 先頭 + if err != nil { + t.Fatal("failed test\n", err) + } + + if !reflect.DeepEqual(row, []string{strconv.Itoa(maxId), "0"}) { + t.Fatal("failed test\n", row) + } + } + { + row, err := rows.Row(maxId - 1) // 末尾 + if err != nil { + t.Fatal("failed test\n", err) + } + + if !reflect.DeepEqual(row, []string{"1", strconv.Itoa(maxId - 1)}) { + t.Fatal("failed test\n", row) + } + } +} + +func assertRows(t *testing.T, rows CsvSortedRows, expecteds ...[]string) { + + for i, expected := range expecteds { + + row, err := rows.Row(i) + if err != nil { + t.Fatal("failed test\n", err) + } + + if !reflect.DeepEqual(row, expected) { + t.Fatal("failed test\n", i, row) + } + } +} + +func joinRows(rows ...[]string) string { + + csv := "" + + for _, row 
:= range rows { + csv += strings.Join(row, ",") + "\r\n" + } + + return csv +} diff --git a/csv/table.go b/csv/table.go index 9874878..ca3bcdb 100644 --- a/csv/table.go +++ b/csv/table.go @@ -12,18 +12,18 @@ import ( type CsvTable interface { Find(key string) (map[string]string, error) - JoinColumnName() string + KeyColumnName() string ColumnNames() []string Close() error } -type MemoryTable struct { - joinColumnName string - columnNames []string - rows map[string][]string +type memoryTable struct { + keyColumnName string + columnNames []string + rows map[string][]string } -func (t *MemoryTable) Find(key string) (map[string]string, error) { +func (t *memoryTable) Find(key string) (map[string]string, error) { row := t.rows[key] @@ -39,32 +39,32 @@ func (t *MemoryTable) Find(key string) (map[string]string, error) { return rowMap, nil } -func (t *MemoryTable) JoinColumnName() string { +func (t *memoryTable) KeyColumnName() string { - return t.joinColumnName + return t.keyColumnName } -func (t *MemoryTable) ColumnNames() []string { +func (t *memoryTable) ColumnNames() []string { return t.columnNames } -func (t *MemoryTable) Close() error { +func (t *memoryTable) Close() error { // リソースは保持しないので何もしない return nil } -func LoadCsvMemoryTable(reader CsvReader, joinColumnName string) (CsvTable, error) { +func LoadCsvMemoryTable(reader CsvReader, keyColumnName string) (CsvTable, error) { headers, err := reader.Read() if err != nil { return nil, err } - primaryColumnIndex := util.IndexOf(headers, joinColumnName) + primaryColumnIndex := util.IndexOf(headers, keyColumnName) if primaryColumnIndex == -1 { - return nil, fmt.Errorf("%s is not found", joinColumnName) + return nil, fmt.Errorf("%s is not found", keyColumnName) } rows := make(map[string][]string) @@ -81,27 +81,27 @@ func LoadCsvMemoryTable(reader CsvReader, joinColumnName string) (CsvTable, erro // -> 重複して存在した場合はエラーに _, has := rows[row[primaryColumnIndex]] if has { - return nil, fmt.Errorf("%s:%s is duplicated", 
joinColumnName, row[primaryColumnIndex]) + return nil, fmt.Errorf("%s:%s is duplicated", keyColumnName, row[primaryColumnIndex]) } rows[row[primaryColumnIndex]] = row } - return &MemoryTable{ - joinColumnName: joinColumnName, - columnNames: headers, - rows: rows, + return &memoryTable{ + keyColumnName: keyColumnName, + columnNames: headers, + rows: rows, }, nil } -type FileTable struct { - joinColumnName string - columnNames []string - dbPath string - db *bolt.DB +type fileTable struct { + keyColumnName string + columnNames []string + dbPath string + db *bolt.DB } -func (t *FileTable) Find(key string) (map[string]string, error) { +func (t *fileTable) Find(key string) (map[string]string, error) { // 既にDBを開いている場合は、使いまわす // (CsvTableのClose時に閉じている) @@ -142,17 +142,17 @@ func (t *FileTable) Find(key string) (map[string]string, error) { return rowMap, nil } -func (t *FileTable) JoinColumnName() string { +func (t *fileTable) KeyColumnName() string { - return t.joinColumnName + return t.keyColumnName } -func (t *FileTable) ColumnNames() []string { +func (t *fileTable) ColumnNames() []string { return t.columnNames } -func (t *FileTable) Close() error { +func (t *fileTable) Close() error { if t.db != nil { err := t.db.Close() @@ -164,16 +164,16 @@ func (t *FileTable) Close() error { return os.Remove(t.dbPath) } -func LoadCsvFileTable(reader CsvReader, joinColumnName string) (CsvTable, error) { +func LoadCsvFileTable(reader CsvReader, keyColumnName string) (CsvTable, error) { headers, err := reader.Read() if err != nil { return nil, err } - primaryColumnIndex := util.IndexOf(headers, joinColumnName) + primaryColumnIndex := util.IndexOf(headers, keyColumnName) if primaryColumnIndex == -1 { - return nil, fmt.Errorf("%s is not found", joinColumnName) + return nil, fmt.Errorf("%s is not found", keyColumnName) } dbFile, err := os.CreateTemp("", "csvdb") @@ -216,7 +216,7 @@ func LoadCsvFileTable(reader CsvReader, joinColumnName string) (CsvTable, error) // -> 重複して存在した場合はエラーに v := 
b.Get([]byte(key)) if v != nil { - return fmt.Errorf("%s:%s is duplicated", joinColumnName, key) + return fmt.Errorf("%s:%s is duplicated", keyColumnName, key) } rowJson, err := json.Marshal(row) @@ -238,9 +238,9 @@ func LoadCsvFileTable(reader CsvReader, joinColumnName string) (CsvTable, error) } } - return &FileTable{ - joinColumnName: joinColumnName, - columnNames: headers, - dbPath: dbFile.Name(), + return &fileTable{ + keyColumnName: keyColumnName, + columnNames: headers, + dbPath: dbFile.Name(), }, nil } diff --git a/csv/table_test.go b/csv/table_test.go index f7a6b5a..6740494 100644 --- a/csv/table_test.go +++ b/csv/table_test.go @@ -28,8 +28,8 @@ func TestLoadCsvMemoryTable(t *testing.T) { t.Fatal("failed test\n", table.ColumnNames()) } - if table.JoinColumnName() != "ID" { - t.Fatal("failed test\n", table.JoinColumnName()) + if table.KeyColumnName() != "ID" { + t.Fatal("failed test\n", table.KeyColumnName()) } result, err := table.Find("5") @@ -73,7 +73,7 @@ func TestLoadCsvMemoryTable_duplicateKey(t *testing.T) { } } -func TestLoadCsvMemoryTable_joinColumnNotFound(t *testing.T) { +func TestLoadCsvMemoryTable_keyColumnNotFound(t *testing.T) { s := `ID,Name,Height,Weight 1,Yamada,171,50 @@ -111,7 +111,7 @@ func TestLoadCsvMemoryTable_changeLineOnly(t *testing.T) { func TestLoadCsvMemoryTable_big(t *testing.T) { - const maxId = 1000000 + const maxId = 100000 s := [maxId]string{} s[0] = "ID,Name,Age" @@ -131,8 +131,8 @@ func TestLoadCsvMemoryTable_big(t *testing.T) { t.Fatal("failed test\n", table.ColumnNames()) } - if table.JoinColumnName() != "ID" { - t.Fatal("failed test\n", table.JoinColumnName()) + if table.KeyColumnName() != "ID" { + t.Fatal("failed test\n", table.KeyColumnName()) } for i := 1; i < maxId; i++ { @@ -177,8 +177,8 @@ func TestLoadCsvFileTable(t *testing.T) { t.Fatal("failed test\n", table.ColumnNames()) } - if table.JoinColumnName() != "ID" { - t.Fatal("failed test\n", table.JoinColumnName()) + if table.KeyColumnName() != "ID" { + 
t.Fatal("failed test\n", table.KeyColumnName()) } result, err := table.Find("5") @@ -222,7 +222,7 @@ func TestLoadCsvFileTable_duplicateKey(t *testing.T) { } } -func TestLoadCsvFileTable_joinColumnNotFound(t *testing.T) { +func TestLoadCsvFileTable_keyColumnNotFound(t *testing.T) { s := `ID,Name,Height,Weight 1,Yamada,171,50 @@ -260,7 +260,7 @@ func TestLoadCsvFileTable_changeLineOnly(t *testing.T) { func TestLoadCsvFileTable_big(t *testing.T) { - const maxId = 1000000 + const maxId = 100000 s := [maxId]string{} s[0] = "ID,Name,Age" @@ -280,8 +280,8 @@ func TestLoadCsvFileTable_big(t *testing.T) { t.Fatal("failed test\n", table.ColumnNames()) } - if table.JoinColumnName() != "ID" { - t.Fatal("failed test\n", table.JoinColumnName()) + if table.KeyColumnName() != "ID" { + t.Fatal("failed test\n", table.KeyColumnName()) } for i := 1; i < maxId; i++ {