Skip to content

Commit

Permalink
fix: pr change requests
Browse files Browse the repository at this point in the history
Co-authored-by: Philipp Gillé <[email protected]>
  • Loading branch information
iwilltry42 and philippgille committed Jul 3, 2024
1 parent 49eb498 commit a4a5653
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
12 changes: 8 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
// - collections: Optional. If provided, only the collections with the given names
// are imported. If not provided, all collections are imported.
// are imported. Non-existing collections are ignored.
// If not provided, all collections are imported.
func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections ...string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
Expand Down Expand Up @@ -287,7 +288,8 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections
// - reader: An implementation of [io.ReadSeeker]
// - encryptionKey: Optional, must be 32 bytes long if provided
// - collections: Optional. If provided, only the collections with the given names
// are imported. If not provided, all collections are imported.
// are imported. Non-existing collections are ignored.
// If not provided, all collections are imported.
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string, collections ...string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
Expand Down Expand Up @@ -373,7 +375,8 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
// - collections: Optional. If provided, only the collections with the given names
// are exported. If not provided, all collections are exported.
// are exported. Non-existing collections are ignored.
// If not provided, all collections are exported.
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string, collections ...string) error {
if filePath == "" {
filePath = "./chromem-go.gob"
Expand Down Expand Up @@ -435,7 +438,8 @@ func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string,
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
// - collections: Optional. If provided, only the collections with the given names
// are exported. If not provided, all collections are exported.
// are exported. Non-existing collections are ignored.
// If not provided, all collections are exported.
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string, collections ...string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
Expand Down
25 changes: 13 additions & 12 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ func TestDB_ImportExport(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
// Create DB, can just be in-memory
orig := NewDB()
origDB := NewDB()

// Create collection
c, err := orig.CreateCollection(name, metadata, embeddingFunc)
c, err := origDB.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand All @@ -139,7 +139,7 @@ func TestDB_ImportExport(t *testing.T) {
}

// Export
err = orig.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
err = origDB.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand All @@ -156,8 +156,8 @@ func TestDB_ImportExport(t *testing.T) {
// We have to reset the embed function, but otherwise the DB objects
// should be deep equal.
c.embed = nil
if !reflect.DeepEqual(orig, newDB) {
t.Fatalf("expected DB %+v, got %+v", orig, newDB)
if !reflect.DeepEqual(origDB, newDB) {
t.Fatalf("expected DB %+v, got %+v", origDB, newDB)
}
})
}
Expand All @@ -180,15 +180,15 @@ func TestDB_ImportExportSpecificCollections(t *testing.T) {
}

// Create DB, can just be in-memory
orig := NewDB()
origDB := NewDB()

// Create collections
c, err := orig.CreateCollection(name, metadata, embeddingFunc)
c, err := origDB.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}

c2, err := orig.CreateCollection(name2, metadata, embeddingFunc)
c2, err := origDB.CreateCollection(name2, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -218,15 +218,16 @@ func TestDB_ImportExportSpecificCollections(t *testing.T) {
t.Fatal("expected no error, got", err)
}

// Export
err = orig.ExportToFile(filePath, false, "", name2)
// Export only one of the two collections
err = origDB.ExportToFile(filePath, false, "", name2)
if err != nil {
t.Fatal("expected no error, got", err)
}

dir := filepath.Join(path, randomString(r, 10))
defer os.RemoveAll(dir)

// Instead of importing to an in-memory DB we use a persistent one to cover the behavior of immediate persistent files being created for the imported data
newPDB, err := NewPersistentDB(dir, false)
if err != nil {
t.Fatal("expected no error, got", err)
Expand All @@ -252,8 +253,8 @@ func TestDB_ImportExportSpecificCollections(t *testing.T) {
}

// Now export both collections and import them into the same persistent DB (overwriting the one we just imported)
filePath2 := path + "2.gob"
err = orig.ExportToFile(filePath2, false, "")
filePath2 := filepath.Join(path, "2.gob")
err = origDB.ExportToFile(filePath2, false, "")
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down

0 comments on commit a4a5653

Please sign in to comment.