Skip to content

Commit

Permalink
Add __call__ (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn authored Oct 17, 2023
1 parent 73b927f commit 312173f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 9 deletions.
19 changes: 10 additions & 9 deletions docs/guide/classes.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,16 @@ const binaryfunc = fn(*Self, object) !object;

### Type Methods

| Method | Signature |
| :--------- | :-------------------------------- |
| `__new__` | `#!zig fn(CallArgs) !Self` |
| `__init__` | `#!zig fn(*Self, CallArgs) !void` |
| `__del__` | `#!zig fn(*Self) void` |
| `__repr__` | `#!zig fn(*Self) !py.PyString` |
| `__str__` | `#!zig fn(*Self) !py.PyString` |
| `__iter__` | `#!zig fn(*Self) !object` |
| `__next__` | `#!zig fn(*Self) !?object` |
| Method | Signature |
| :--------- | :--------------------------------------- |
| `__new__` | `#!zig fn(CallArgs) !Self` |
| `__init__` | `#!zig fn(*Self, CallArgs) !void` |
| `__del__` | `#!zig fn(*Self) void` |
| `__repr__` | `#!zig fn(*Self) !py.PyString` |
| `__str__` | `#!zig fn(*Self) !py.PyString` |
| `__call__` | `#!zig fn(*Self, CallArgs) !py.PyObject` |
| `__iter__` | `#!zig fn(*Self) !object` |
| `__next__` | `#!zig fn(*Self) !?object` |

### Sequence Methods

Expand Down
9 changes: 9 additions & 0 deletions example/classes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@ from __future__ import annotations
class Animal:
def species(self, /): ...

class Callable:
def __init__():
pass
def __call__(self, /, *args, **kwargs):
"""
Call self as a function.
"""
...

class ConstructableClass:
def __init__(count, /):
pass
Expand Down
13 changes: 13 additions & 0 deletions example/classes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ pub const Hash = py.class(struct {
}
});

pub const Callable = py.class(struct {
const Self = @This();

pub fn __new__() Self {
return .{};
}

pub fn __call__(self: *const Self, args: struct { i: u32 }) u32 {
_ = self;
return args.i;
}
});

comptime {
py.rootmodule(@This());
}
21 changes: 21 additions & 0 deletions pydust/src/pytypes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,13 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
}};
}

if (@hasDecl(definition, "__call__")) {
slots_ = slots_ ++ .{ffi.PyType_Slot{
.slot = ffi.Py_tp_call,
.pfunc = @ptrCast(@constCast(&tp_call)),
}};
}

if (richcmp.hasCompare) {
slots_ = slots_ ++ .{ffi.PyType_Slot{
.slot = ffi.Py_tp_richcompare,
Expand Down Expand Up @@ -371,6 +378,20 @@ fn Slots(comptime definition: type, comptime name: [:0]const u8) type {
const result = tramp.coerceError(definition.__hash__(&self.state)) catch return -1;
return @as(isize, @bitCast(result));
}

fn tp_call(pyself: *ffi.PyObject, pyargs: [*c]ffi.PyObject, pykwargs: [*c]ffi.PyObject) callconv(.C) ?*ffi.PyObject {
const sig = funcs.parseSignature("__call__", @typeInfo(@TypeOf(definition.__call__)).Fn, &.{ *definition, *const definition, py.PyObject });

const args = if (pyargs) |pa| py.PyTuple.unchecked(.{ .py = pa }) else null;
const kwargs = if (pykwargs) |pk| py.PyDict.unchecked(.{ .py = pk }) else null;

const self = tramp.Trampoline(sig.selfParam.?).unwrap(py.PyObject{ .py = pyself }) catch return null;
const call_args = tramp.Trampoline(sig.argsParam.?).unwrapCallArgs(args, kwargs) catch return null;
defer funcs.deinitArgs(sig.argsParam.?, call_args);

const result = tramp.coerceError(definition.__call__(self, call_args)) catch return null;
return (py.createOwned(result) catch return null).py;
}
};
}

Expand Down
5 changes: 5 additions & 0 deletions test/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def test_hash():
assert hash(h) == -7849439630130923510


def test_callable():
c = classes.Callable()
assert c(30) == 30


def test_refcnt():
# Verify that initializing a class does not leak a reference to the module.
rc = sys.getrefcount(classes)
Expand Down

0 comments on commit 312173f

Please sign in to comment.