From fb78267fa942f925ae2059e0b3cd1be473d12d2d Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 20 Sep 2023 14:13:51 +0100 Subject: [PATCH] Add __init__ slot (#118) --- docs/guide/classes.md | 17 +++++++++-------- pydust/src/builtins.zig | 10 ++++++++++ pydust/src/pytypes.zig | 38 ++++++++++++++++++++++++++++++++++---- pydust/src/trampoline.zig | 15 ++++++++++++++- pydust/src/types/obj.zig | 6 ++++++ 5 files changed, 73 insertions(+), 13 deletions(-) diff --git a/docs/guide/classes.md b/docs/guide/classes.md index db79d90d..1f352436 100644 --- a/docs/guide/classes.md +++ b/docs/guide/classes.md @@ -150,14 +150,15 @@ const binaryfunc = fn(*Self, other: object) !object; ### Type Methods -| Method | Signature | -|:-----------|:-------------------------------| -| `__new__` | `#!zig fn(CallArgs) !Self` | -| `__del__` | `#!zig fn(*Self) void` | -| `__repr__` | `#!zig fn(*Self) !py.PyString` | -| `__str__` | `#!zig fn(*Self) !py.PyString` | -| `__iter__` | `#!zig fn(*Self) !object` | -| `__next__` | `#!zig fn(*Self) !?object` | +| Method | Signature | +|:-----------|:---------------------------------| +| `__new__` | `#!zig fn(CallArgs) !Self` | +| `__init__` | `#!zig fn(*Self, CallArgs) !void`| +| `__del__` | `#!zig fn(*Self) void` | +| `__repr__` | `#!zig fn(*Self) !py.PyString` | +| `__str__` | `#!zig fn(*Self) !py.PyString` | +| `__iter__` | `#!zig fn(*Self) !object` | +| `__next__` | `#!zig fn(*Self) !?object` | ### Sequence Methods diff --git a/pydust/src/builtins.zig b/pydust/src/builtins.zig index 51bd82f4..f14d8220 100644 --- a/pydust/src/builtins.zig +++ b/pydust/src/builtins.zig @@ -84,6 +84,16 @@ pub fn importFrom(module_name: [:0]const u8, attr: [:0]const u8) !py.PyObject { return try mod.get(attr); } +/// Check if object is an instance of cls. +pub fn isinstance(object: anytype, cls: anytype) !bool { + const pyobj = py.object(object); + const pycls = py.object(cls); + + const result = ffi.PyObject_IsInstance(pyobj.py, pycls.py); + if (result < 0) return PyError.Propagate; + return result == 1; +} + /// Return the reference count of the object. pub fn refcnt(object: anytype) isize { const pyobj = py.object(object); diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 8605805a..2966694b 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -131,6 +131,13 @@ fn Slots(comptime definition: type) type { }}; } + if (@hasDecl(definition, "__init__")) { + slots_ = slots_ ++ .{ffi.PyType_Slot{ + .slot = ffi.Py_tp_init, + .pfunc = @ptrCast(@constCast(&tp_init)), + }}; + } + if (@hasDecl(definition, "__del__")) { slots_ = slots_ ++ .{ffi.PyType_Slot{ .slot = ffi.Py_tp_finalize, @@ -225,6 +232,7 @@ fn Slots(comptime definition: type) type { // Allow the definition to initialize the state field. self.state = tp_new_internal( + .{ .py = @alignCast(@ptrCast(subtype)) }, if (pyargs) |pa| py.PyTuple.unchecked(.{ .py = pa }) else null, if (pykwargs) |pk| py.PyDict.unchecked(.{ .py = pk }) else null, ) catch return null; @@ -232,11 +240,17 @@ fn Slots(comptime definition: type) type { return pyself; } - fn tp_new_internal(pyargs: ?py.PyTuple, pykwargs: ?py.PyDict) !definition { - const sig = funcs.parseSignature("__new__", @typeInfo(@TypeOf(definition.__new__)).Fn, &.{}); - if (sig.selfParam) |_| @compileError("__new__ must not take a self parameter"); + fn tp_new_internal(subtype: py.PyObject, pyargs: ?py.PyTuple, pykwargs: ?py.PyDict) !definition { + const sig = funcs.parseSignature("__new__", @typeInfo(@TypeOf(definition.__new__)).Fn, &.{py.PyObject}); - if (sig.argsParam) |Args| { + if (sig.selfParam) |_| { + if (sig.argsParam) |Args| { + const args = try tramp.Trampoline(Args).unwrapCallArgs(.{ .args = pyargs, .kwargs = pykwargs }); + return try definition.__new__(subtype, args); + } else { + return try definition.__new__(subtype); + } + } else if (sig.argsParam) |Args| { const args = try tramp.Trampoline(Args).unwrapCallArgs(.{ .args = pyargs, .kwargs = pykwargs }); return try definition.__new__(args); } else { @@ -252,6 +266,22 @@ fn Slots(comptime definition: type) type { return null; } + fn tp_init(pyself: *ffi.PyObject, pyargs: [*c]ffi.PyObject, pykwargs: [*c]ffi.PyObject) callconv(.C) c_int { + const sig = funcs.parseSignature("__init__", @typeInfo(@TypeOf(definition.__init__)).Fn, &.{ *definition, *const definition, py.PyObject }); + + if (sig.selfParam == null or sig.argsParam == null) { + @compileError("__init__ must take both a self argument and an args struct"); + } + + const args = if (pyargs) |pa| py.PyTuple.unchecked(.{ .py = pa }) else null; + const kwargs = if (pykwargs) |pk| py.PyDict.unchecked(.{ .py = pk }) else null; + + const self = tramp.Trampoline(sig.selfParam.?).unwrap(py.PyObject{ .py = pyself }) catch return -1; + const init_args = tramp.Trampoline(sig.argsParam.?).unwrapCallArgs(.{ .args = args, .kwargs = kwargs }) catch return -1; + definition.__init__(self, init_args) catch return -1; + return 0; + } + /// Wrapper for the user's __del__ function. /// Note: tp_del is deprecated in favour of tp_finalize. /// diff --git a/pydust/src/trampoline.zig b/pydust/src/trampoline.zig index 2b8cc1ba..d2c2c49e 100644 --- a/pydust/src/trampoline.zig +++ b/pydust/src/trampoline.zig @@ -238,7 +238,20 @@ pub fn Trampoline(comptime T: type) type { // If the pointer is for a Pydust class if (def.type == .class) { - // TODO(ngates): check the PyType? + // TODO(ngates): What's the easiest / cheapest way to do this? + // For now, we just check the name + const clsName = State.getIdentifier(p.child).name; + const mod = State.getContaining(p.child, .module); + const modName = State.getIdentifier(mod).name; + + const Cls = try py.importFrom(modName, clsName); + if (!try py.isinstance(obj, Cls)) { + return py.TypeError.raiseFmt( + "Expected {s}.{s} but found {s}", + .{ modName, clsName, try obj.getTypeName() }, + ); + } + const PyType = pytypes.PyTypeStruct(p.child); const pyobject = @as(*PyType, @ptrCast(obj.py)); return @constCast(&pyobject.state); diff --git a/pydust/src/types/obj.zig b/pydust/src/types/obj.zig index adeedba7..89031ed5 100644 --- a/pydust/src/types/obj.zig +++ b/pydust/src/types/obj.zig @@ -36,6 +36,12 @@ pub const PyObject = extern struct { ffi.Py_DECREF(self.py); } + pub fn getTypeName(self: PyObject) ![:0]const u8 { + const pytype: *ffi.PyObject = ffi.PyObject_Type(self.py) orelse return PyError.Propagate; + const name = py.PyString.unchecked(.{ .py = ffi.PyType_GetName(@ptrCast(pytype)) orelse return PyError.Propagate }); + return name.asSlice(); + } + /// Call this object without any arguments. pub fn call0(self: PyObject) !PyObject { return .{ .py = ffi.PyObject_CallNoArgs(self.py) orelse return PyError.Propagate };