diff --git a/iree/turbine/aot/support/ir_utils.py b/iree/turbine/aot/support/ir_utils.py index 348cf292..0722ef71 100644 --- a/iree/turbine/aot/support/ir_utils.py +++ b/iree/turbine/aot/support/ir_utils.py @@ -290,9 +290,16 @@ def create_tensor_global( if attrs.mutable: ir_attrs["is_mutable"] = UnitAttr.get() if device: - ir_attrs["iree.abi.affinity"] = Attribute.parse( - f"#hal.device.promise<@__device_{device.ordinal}>" - ) + if device.queues is None: + ir_attrs["stream.affinity"] = Attribute.parse( + f"#hal.device.promise<@__device_{device.ordinal}>" + ) + else: + queues = ", ".join(device.queues) + ir_attrs["stream.affinity"] = Attribute.parse( + f"#hal.device.promise<@__device_{device.ordinal}, [{queues}]>" + ) + if external: # Emit named external reference. external_scope_attr = StringAttr.get(external_scope or "model") diff --git a/iree/turbine/aot/tensor_traits.py b/iree/turbine/aot/tensor_traits.py index c6b85514..8c9e8a67 100644 --- a/iree/turbine/aot/tensor_traits.py +++ b/iree/turbine/aot/tensor_traits.py @@ -20,16 +20,19 @@ class DeviceAffinity: """This is used to provide device affinities to exported function arguments.""" - def __init__(self, ordinal: int): + def __init__(self, ordinal: int, queues: Optional[list] = None): self.ordinal = ordinal + self.queues = queues def __eq__(self, other) -> bool: if not isinstance(other, DeviceAffinity): return False - return self.ordinal == other.ordinal + return self.ordinal == other.ordinal and self.queues == other.queues def __repr__(self) -> str: - return f"DeviceAffinity({self.ordinal})" + if self.queues is None: + return f"DeviceAffinity({self.ordinal})" + return f"DeviceAffinity({self.ordinal}, [{', '.join(self.queues)}])" @dataclass @@ -39,6 +42,7 @@ class DeviceTensorTrait: """ ordinal: int + queues: Optional[list] = None @staticmethod def get(from_tensor: torch.Tensor) -> Optional["DeviceTensorTrait"]: