-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[operator] fix libentry to support triton 2.3 #89
Conversation
StrongSpoon
commented
Jun 28, 2024
- modified libentry arguments collection, cache key generation and constexpr collection
- specify dns arguments for operators
- works for triton 2.2 and 2.3
@@ -82,7 +82,7 @@ def dropout_forward_kernel( | |||
"N", | |||
], | |||
) | |||
@triton.jit | |||
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does Triton specialize on floats?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
floats not marked as dns are specialized as False
src/flag_gems/utils/libentry.py
Outdated
@@ -14,47 +16,110 @@ def __init__( | |||
while not isinstance(fn, triton.runtime.JITFunction): | |||
fn = fn.fn | |||
self.jit_function: triton.runtime.JITFunction = fn | |||
self.kernel_arg_indices = [] | |||
self.spec_indices = [] | |||
self.dns_indices = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's be more precise? do_not_specialize_indices
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/flag_gems/utils/libentry.py
Outdated
for p in self.jit_function.params: | ||
if not p.is_constexpr: | ||
self.kernel_arg_indices.append(p.num) | ||
if p.do_not_specialize: | ||
self.dns_indices.append(p.num) | ||
else: | ||
self.spec_indices.append(p.num) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't a bit of list comprehension is more favorable than loops?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/flag_gems/utils/libentry.py
Outdated
for arg in spec_args: | ||
if hasattr(arg, "data_ptr"): | ||
entry_key.append(str(arg.dtype)) | ||
entry_key.append(arg.data_ptr() % self.divisibility == 0) | ||
else: | ||
entry_key.append(type(arg)) | ||
entry_key.append(arg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
List comprehension is an immutable fashion for coding loops. It's normally more efficient than the mutable counterpart.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done