Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions hugr-py/src/hugr/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""HUGR model data structures."""

from collections.abc import Sequence
from abc import ABC
from collections.abc import Generator, Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import Protocol
from typing import Optional

from semver import Version

Expand All @@ -21,7 +22,7 @@ def _current_version() -> Version:
CURRENT_VERSION: Version = _current_version()


class Term(Protocol):
class Term(ABC):
"""A model term for static data such as types, constants and metadata."""

def __str__(self) -> str:
Expand All @@ -33,6 +34,26 @@ def from_str(s: str) -> "Term":
"""Read the term from its string representation."""
return rust.string_to_term(s)

def to_list_parts(self) -> Generator["SeqPart"]:
if isinstance(self, List):
for part in self.parts:
if isinstance(part, Splice):
yield from part.seq.to_list_parts()
else:
yield part
else:
yield Splice(self)

def to_tuple_parts(self) -> Generator["SeqPart"]:
if isinstance(self, Tuple):
for part in self.parts:
if isinstance(part, Splice):
yield from part.seq.to_tuple_parts()
else:
yield part
else:
yield Splice(self)


@dataclass(frozen=True)
class Wildcard(Term):
Expand Down Expand Up @@ -129,9 +150,13 @@ def from_str(s: str) -> "Symbol":
return rust.string_to_symbol(s)


class Op(Protocol):
class Op(ABC):
"""The operation of a node."""

def symbol_name(self) -> str | None:
"""Returns name of the symbol introduced by this node, if any."""
return None


@dataclass(frozen=True)
class InvalidOp(Op):
Expand Down Expand Up @@ -159,13 +184,19 @@ class DefineFunc(Op):

symbol: Symbol

def symbol_name(self) -> str | None:
return self.symbol.name


@dataclass(frozen=True)
class DeclareFunc(Op):
"""Function declaration."""

symbol: Symbol

def symbol_name(self) -> str | None:
return self.symbol.name


@dataclass(frozen=True)
class CustomOp(Op):
Expand All @@ -181,13 +212,19 @@ class DefineAlias(Op):
symbol: Symbol
value: Term

def symbol_name(self) -> str | None:
return self.symbol.name


@dataclass(frozen=True)
class DeclareAlias(Op):
"""Alias declaration."""

symbol: Symbol

def symbol_name(self) -> str | None:
return self.symbol.name


@dataclass(frozen=True)
class TailLoop(Op):
Expand All @@ -205,20 +242,29 @@ class DeclareConstructor(Op):

symbol: Symbol

def symbol_name(self) -> str | None:
return self.symbol.name


@dataclass(frozen=True)
class DeclareOperation(Op):
"""Operation declaration."""

symbol: Symbol

def symbol_name(self) -> str | None:
return self.symbol.name


@dataclass(frozen=True)
class Import(Op):
"""Import operation."""

name: str

def symbol_name(self) -> str | None:
return self.name


@dataclass
class Node:
Expand Down
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/model/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Generic, TypeVar, cast

import hugr.model as model
from hugr.hugr.base import Hugr, Node
from hugr.hugr.node_port import InPort, OutPort
from hugr.hugr.base import Hugr
from hugr.hugr.node_port import InPort, Node, OutPort
from hugr.ops import (
CFG,
DFG,
Expand Down
Loading
Loading