Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion backends/cadence/utils/facto_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
# Avoid NaN/Inf values that expose clamp NaN handling bugs
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
]
)
case "rsqrt.default":
Expand Down Expand Up @@ -456,6 +459,7 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
| "mul.Scalar"
| "div.Scalar"
| "constant_pad_nd.default"
| "clamp.default"
):
return [ScalarDtype.int]
case "full.default":
Expand Down Expand Up @@ -483,7 +487,32 @@ def facto_testcase_gen( # noqa: C901
cp.Size.Le(lambda deps, r, d: 2**2),
]
)
if in_spec.name == "max_val": # hardtanh
# Special handling for clamp.default to ensure min < max with sufficient gap (at least 2) and never None
if op_name == "clamp.default":
if in_spec.name == "min":
# min must always be provided (not None) and bounded, leave room for max
spec.inspec[index].constraints.extend(
[
cp.Optional.Eq(lambda deps: False), # Never None
cp.Value.Ge(lambda deps, dtype: -(2**4)),
cp.Value.Le(
lambda deps, dtype: 2**4 - 2
), # Leave room for max (at least 2 units)
]
)
elif in_spec.name == "max":
# max must always be provided (not None), be >= min + 2 (sufficient gap), and bounded
spec.inspec[index].deps = [0, 1] # deps on input tensor and min
spec.inspec[index].constraints.extend(
[
cp.Optional.Eq(lambda deps: False), # Never None
cp.Value.Ge(
lambda deps, dtype: deps[1] + 2
), # max >= min + 2 (sufficient gap)
cp.Value.Le(lambda deps, dtype: 2**4),
]
)
elif in_spec.name == "max_val": # hardtanh
spec.inspec[index].deps = [0, 1]
spec.inspec[index].constraints.extend(
[cp.Value.Ge(lambda deps, _: deps[1])]
Expand Down
Loading