Skip to content

Commit

Permalink
Init does not use __new__ arguments (#173)
Browse files Browse the repository at this point in the history
`__new__` is only used when constructing the class from Python.

To add factory methods in Zig, they should be wrapped in `py.zig`
  • Loading branch information
gatesn authored Oct 4, 2023
1 parent e356037 commit cea2c23
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 26 deletions.
6 changes: 4 additions & 2 deletions example/iterators.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ class Range:
...

class RangeIterator:
def __init__(next, stop, step, /):
pass
"""
Range iterator
"""

def __next__(self, /):
"""
Implement next(self).
Expand Down
4 changes: 0 additions & 4 deletions example/iterators.zig
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ pub const RangeIterator = py.class(struct {
stop: i64,
step: i64,

pub fn __new__(args: struct { next: i64, stop: i64, step: i64 }) Self {
return .{ .next = args.next, .stop = args.stop, .step = args.step };
}

pub fn __next__(self: *Self) ?i64 {
if (self.next >= self.stop) {
return null;
Expand Down
22 changes: 2 additions & 20 deletions pydust/src/pydust.zig
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub fn finalize() void {
}

/// Instantiate a class defined in Pydust.
pub fn init(comptime Cls: type, args: NewArgs(Cls)) PyError!*Cls {
pub fn init(comptime Cls: type, args: Cls) PyError!*Cls {
const moduleDefinition = State.getContaining(Cls, .module);
const imported = try types.PyModule.import(State.getIdentifier(moduleDefinition).name);
const pytype = try imported.obj.get(State.getIdentifier(Cls).name);
Expand All @@ -49,29 +49,11 @@ pub fn init(comptime Cls: type, args: NewArgs(Cls)) PyError!*Cls {
// NOTE(ngates): we currently don't allow users to override tp_alloc, therefore we can shortcut
// using ffi.PyType_GetSlot(tp_alloc) since we know it will always return ffi.PyType_GenericAlloc
const pyobj: *pytypes.PyTypeStruct(Cls) = @alignCast(@ptrCast(ffi.PyType_GenericAlloc(@ptrCast(pytype.py), 0) orelse return PyError.PyRaised));

if (@hasDecl(Cls, "__new__")) {
pyobj.state = try tramp.coerceError(Cls.__new__(args));
} else if (@typeInfo(Cls).Struct.fields.len > 0) {
pyobj.state = args;
}
pyobj.state = args;

return &pyobj.state;
}

/// Find the type of the positional args for a class
inline fn NewArgs(comptime Cls: type) type {
if (!@hasDecl(Cls, "__new__")) {
// Default construct args are the struct fields themselves.
return Cls;
}

const func = @field(Cls, "__new__");
const typeInfo = @typeInfo(@TypeOf(func));
const sig = funcs.parseSignature("__new__", typeInfo.Fn, &.{});
return sig.argsParam orelse struct {};
}

/// Register the root Pydust module
pub fn rootmodule(comptime definition: type) void {
if (!State.isEmpty()) {
Expand Down

0 comments on commit cea2c23

Please sign in to comment.