diff --git a/pkg/mdb2/cache.go b/pkg/mdb2/cache.go index b0678ef..a79af8c 100644 --- a/pkg/mdb2/cache.go +++ b/pkg/mdb2/cache.go @@ -2,6 +2,7 @@ package mdb2 import ( "fmt" + "go.mongodb.org/mongo-driver/bson/bsoncodec" "io/ioutil" "os" "runtime" @@ -19,7 +20,9 @@ type cache struct { sync.Mutex } -func newCache(mdb *Mdb) (*cache, error) { +var codecRegistry *bsoncodec.Registry + +func newCache(mdb *Mdb, registry *bsoncodec.Registry) (*cache, error) { if err := os.MkdirAll(mdb.cacheDir, os.ModePerm); err != nil { return nil, err } @@ -27,6 +30,8 @@ func newCache(mdb *Mdb) (*cache, error) { m: make(map[string]*cacheItem), mdb: mdb, } + + codecRegistry = registry c.init() return c, nil } @@ -56,7 +61,7 @@ func (c *cache) init() { } // deserialize to get Id in appropriate type o := &obj{} - if err := bson.Unmarshal(raw, o); err == nil { + if err := bson.UnmarshalWithRegistry(codecRegistry, raw, o); err == nil { id = o.Id } @@ -126,7 +131,7 @@ func (c *cache) purge() { delete(c.m, k) c.Unlock() err := c.mdb.saveId(i.col, "saveId", i.id, i.o()) - if err != nil { + if err != nil && !IsUnackWrite(err) { log.S("col", i.col).S("id", fmt.Sprintf("%v", i.id)).Error(err) } c.Lock() @@ -134,7 +139,7 @@ func (c *cache) purge() { c.Unlock() continue } - if err == nil { + if err == nil || IsUnackWrite(err) { // remove from disk err2 := os.Remove(i.fn) if err2 != nil { @@ -175,5 +180,5 @@ func (i *cacheItem) o() *bson.Raw { } func (i *cacheItem) unmarshal(o interface{}) error { - return bson.Unmarshal(i.raw, o) + return bson.UnmarshalWithRegistry(codecRegistry, i.raw, o) } diff --git a/pkg/mdb2/mdb.go b/pkg/mdb2/mdb.go index 0605ab6..6933a2c 100644 --- a/pkg/mdb2/mdb.go +++ b/pkg/mdb2/mdb.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "go.mongodb.org/mongo-driver/x/mongo/driver" "reflect" "strings" "text/template" @@ -219,7 +220,7 @@ func (mdb *Mdb) Init(connStr string, opts ...func(db *Mdb)) error { } if mdb.cacheDir != "" { - mdb.cache, err = newCache(mdb) + mdb.cache, err = newCache(mdb, reg) if err != nil { return err } @@ -548,3 +549,14 @@ func IsDup(err error) bool { } return false } + +// IsUnackWrite checks if error is unacknowledged write error +func IsUnackWrite(err error) bool { + if err == nil { + return false + } + + return errors.Is(err, mongo.ErrUnacknowledgedWrite) || + errors.Is(err, driver.ErrUnacknowledgedWrite) || + err.Error() == mongo.ErrUnacknowledgedWrite.Error() +} diff --git a/pkg/mdb2/mdb_test.go b/pkg/mdb2/mdb_test.go index 3bf8d57..a727ae2 100644 --- a/pkg/mdb2/mdb_test.go +++ b/pkg/mdb2/mdb_test.go @@ -5,6 +5,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "go.mongodb.org/mongo-driver/bson" "io/ioutil" "os" "reflect" @@ -57,7 +58,7 @@ var testCacheDir = "./tmp/cacheDir" func TestCacheAdd(t *testing.T) { db := &Mdb{name: "dbName", cacheDir: testCacheDir} - c, err := newCache(db) + c, err := newCache(db, bson.DefaultRegistry) assert.Nil(t, err) err = c.add("obj", 1, &obj{Id: 1}) @@ -71,7 +72,7 @@ func TestCacheAdd(t *testing.T) { assert.Equal(t, []string{"obj.1", "obj.2", "obj.3"}, ls) t.Logf("%v", ls) - c2, err := newCache(db) + c2, err := newCache(db, bson.DefaultRegistry) assert.Nil(t, err) assert.Len(t, c2.m, 3) @@ -166,3 +167,32 @@ func TestMongoSerde(t *testing.T) { assert.Equal(t, int(m["b"].(map[string]interface{})["i32"].(int32)), res["b"].(map[string]interface{})["i32"]) assert.Equal(t, int64(m["c"].(int)), res["c"]) } + +func TestMongoCacheSerde(t *testing.T) { + db := &db2{} + if err := db.Init(DefaultConnStr(), Name(testDbName), CacheRoot("./tmp/tests/cache")); err != nil { + log.Fatalf("failed to open connection: %s", err) + return + } + m := map[string]interface{}{ + "a": 12345, + "b": map[string]interface{}{ + "i": 12345678901235, + "i64": int64(12345678901235), + "i32": int32(1234567), + }, + "c": 12345678901234, + } + db.SaveId(testCollectionName, 2513, m) + + var res map[string]interface{} + db.ReadId(testCollectionName, 2513, &res) + + assert.Equal(t, m["a"], res["a"]) + assert.Equal(t, int64(m["b"].(map[string]interface{})["i"].(int)), res["b"].(map[string]interface{})["i"]) + assert.Equal(t, m["b"].(map[string]interface{})["i64"].(int64), res["b"].(map[string]interface{})["i64"]) + assert.Equal(t, int(m["b"].(map[string]interface{})["i32"].(int32)), res["b"].(map[string]interface{})["i32"]) + assert.Equal(t, int64(m["c"].(int)), res["c"]) + + db.cache.purge() +}