diff --git a/src/flag_gems/utils/libentry.py b/src/flag_gems/utils/libentry.py index 490b9d52..8fddd9d6 100644 --- a/src/flag_gems/utils/libentry.py +++ b/src/flag_gems/utils/libentry.py @@ -38,6 +38,13 @@ def key(self, spec_args, dns_args, const_args): for arg in dns_args: if hasattr(arg, "data_ptr"): entry_key.append(str(arg.dtype)) + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + entry_key.append("i32") + elif 2**63 <= arg and arg <= 2**64 - 1: + entry_key.append("u64") + else: + entry_key.append("i64") else: entry_key.append(type(arg)) # const args passed by position