-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: as_offset implementation in embedded #1397
base: main
Are you sure you want to change the base?
Conversation
For later reference: |
cscs-ci run |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, still a few comments.
@@ -198,8 +216,11 @@ def remap( | |||
# then compute the index array | |||
xp = self.array_ns | |||
new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start | |||
# finally, take the new array | |||
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx) | |||
if self._ndarray.ndim > 1 and restricted_connectivity_domain == new_domain: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the second part of this condition? restricted_connectivity_domain == new_domain
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to avoid entering this condition in cases like:
@gtx.field_operator
def testee(a: gtx.Field[[Vertex, KDim], float]) -> gtx.Field[[Edge, KDim], float]:
return a(E2V[0])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you explain this if else branch and are you sure all cases are handled? I am confused...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When using FieldOffsets, only the specific dimensions related to the offset are taken into account.
Say I have this field_operator:
@gtx.field_operator
def testee(a: gtx.Field[[Edge, KDim], int]) -> gtx.Field[[Vertex, KDim], int]:
tmp = neighbor_sum(a(V2E), axis=V2EDim)
return tmp
Here the restricted_connectivity_domain
will be over [Edge, V2E]
and will exclude KDim
. In this case using the regular xp.take
works.
When using as_offset
, xp.take
is also ok to use if the offset_field
contains only one dimension.
However, when restricted_connectivity_domain
contains multiple dimensions that are exactly the same as in new_domain
, we have seen that xp.take
does not work and hence had to create _take_mdim
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but what about restricted_connectivity_domain.dims == new_domain.dims
, but ranges are different?
tests/next_tests/integration_tests/feature_tests/ffront_tests/test_as_offset.py
Outdated
Show resolved
Hide resolved
cscs-ci run |
cscs-ci run |
tests/next_tests/integration_tests/feature_tests/ffront_tests/test_as_offset.py
Outdated
Show resolved
Hide resolved
…test_as_offset.py Co-authored-by: Hannes Vogt <[email protected]>
…offset_embedded
cscs-ci run |
offset_provider={"Ioff": IDim, "Koff": KDim}, | ||
ref=a[2:], | ||
comparison=lambda out, ref: np.all(out == ref), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed we might be missing another test case: let's say we have a 3D field but the offset field is only 2D. I think the expected semantic is probably as if the offset field would be broadcasted first. This might be related to my comment about
if self._ndarray.ndim > 1 and restricted_connectivity_domain == new_domain:
implementation of
as_offset
in embedded