Skip to content

Commit

Permalink
Merge pull request #24 from cpunion/thread-safe
Browse files Browse the repository at this point in the history
Thread safe dict, call non-borrowing api
  • Loading branch information
cpunion authored Nov 9, 2024
2 parents 397cf15 + c582b30 commit f2c4b10
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 36 deletions.
7 changes: 3 additions & 4 deletions convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,15 @@ func ToValue(from Object, to reflect.Value) bool {
t := to.Type()
to.Set(reflect.MakeMap(t))
dict := cast[Dict](from)
iter := dict.Iter()
for iter.HasNext() {
key, value := iter.Next()
dict.Items()(func(key, value Object) bool {
vk := reflect.New(t.Key()).Elem()
vv := reflect.New(t.Elem()).Elem()
if !ToValue(key, vk) || !ToValue(value, vv) {
return false
}
to.SetMapIndex(vk, vv)
}
return true
})
return true
} else {
return false
Expand Down
54 changes: 33 additions & 21 deletions dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,26 @@ package gp

/*
#include <Python.h>
typedef struct pyCriticalSection {
uintptr_t _cs_prev;
void *_cs_mutex;
} pyCriticalSection;
static inline void pyCriticalSection_Begin(pyCriticalSection *pcs, PyObject *op) {
#if PY_VERSION_HEX >= 0x030D0000
PyCriticalSection_Begin((PyCriticalSection*)pcs, op);
#else
PyGILState_STATE gstate = PyGILState_Ensure();
pcs->_cs_prev = (uintptr_t)gstate;
#endif
}
static inline void pyCriticalSection_End(pyCriticalSection *pcs) {
#if PY_VERSION_HEX >= 0x030D0000
PyCriticalSection_End((PyCriticalSection*)pcs);
#else
PyGILState_Release((PyGILState_STATE)pcs->_cs_prev);
#endif
}
*/
import "C"
import (
Expand Down Expand Up @@ -75,26 +95,18 @@ func (d Dict) Del(key Objecter) {
C.PyDict_DelItem(d.obj, key.cpyObj())
}

func (d Dict) Iter() *DictIter {
return &DictIter{dict: d, pos: 0}
}

type DictIter struct {
dict Dict
pos C.long
}

func (d *DictIter) HasNext() bool {
pos := d.pos
return C.PyDict_Next(d.dict.obj, &pos, nil, nil) != 0
}

func (d *DictIter) Next() (Object, Object) {
var key, value *C.PyObject
if C.PyDict_Next(d.dict.obj, &d.pos, &key, &value) == 0 {
return Nil(), Nil()
func (d Dict) Items() func(func(Object, Object) bool) {
obj := d.cpyObj()
var cs C.pyCriticalSection
C.pyCriticalSection_Begin(&cs, obj)
return func(fn func(Object, Object) bool) {
defer C.pyCriticalSection_End(&cs)
var pos C.long
var key, value *C.PyObject
for C.PyDict_Next(obj, &pos, &key, &value) == 1 {
if !fn(newObject(key), newObject(value)) {
return
}
}
}
C.Py_IncRef(key)
C.Py_IncRef(value)
return newObject(key), newObject(value)
}
7 changes: 3 additions & 4 deletions dict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,15 @@ func TestDictForEach(t *testing.T) {
"key3": "value3",
}

iter := dict.Iter()
for iter.HasNext() {
key, value := iter.Next()
dict.Items()(func(key, value Object) bool {
count++
k := key.String()
v := value.String()
if expectedVal, ok := expectedPairs[k]; !ok || expectedVal != v {
t.Errorf("ForEach() unexpected pair: %v: %v", k, v)
}
}
return true
})

if count != len(expectedPairs) {
t.Errorf("ForEach() visited %d pairs, want %d", count, len(expectedPairs))
Expand Down
2 changes: 1 addition & 1 deletion extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func wrapperMethod_(typeMeta *typeMeta, methodMeta *slotMeta, self, args *C.PyOb
}

for i := 0; i < int(argc); i++ {
arg := C.PyTuple_GetItem(args, C.Py_ssize_t(i))
arg := C.PySequence_GetItem(args, C.Py_ssize_t(i))
argType := methodType.In(i + argIndex)
argPy := FromPy(arg)
goValue := reflect.New(argType).Elem()
Expand Down
4 changes: 1 addition & 3 deletions list.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ func MakeList(args ...any) List {
}

func (l List) GetItem(index int) Object {
v := C.PyList_GetItem(l.obj, C.Py_ssize_t(index))
C.Py_IncRef(v)
return newObject(v)
return newObject(C.PySequence_GetItem(l.obj, C.Py_ssize_t(index)))
}

func (l List) SetItem(index int, item Objecter) {
Expand Down
4 changes: 1 addition & 3 deletions tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ func MakeTuple(args ...any) Tuple {
}

func (t Tuple) Get(index int) Object {
v := C.PyTuple_GetItem(t.obj, C.Py_ssize_t(index))
C.Py_IncRef(v)
return newObject(v)
return newObject(C.PySequence_GetItem(t.obj, C.Py_ssize_t(index)))
}

func (t Tuple) Set(index int, obj Objecter) {
Expand Down

0 comments on commit f2c4b10

Please sign in to comment.