datashader speedup and bugfixes #309
base: main
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #309 +/- ##
==========================================
+ Coverage 83.76% 84.53% +0.76%
==========================================
Files 8 8
Lines 1540 1629 +89
==========================================
+ Hits 1290 1377 +87
- Misses 250 252 +2
src/spatialdata_plot/pl/render.py
Outdated
trans = mtransforms.Affine2D(matrix=affine_trans)
trans_data = trans + ax.transData

rgba_image = np.transpose(rgba_image.data.compute(), (1, 2, 0)) # type: ignore[attr-defined]
Here (and in render_shapes), we access the image as a NumPy array from the SpatialImage. mypy doesn't believe that compute() exists, hence the type: ignore[attr-defined].
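For readers unfamiliar with the axis shuffle in the snippet above, here is a minimal sketch using plain NumPy. The array `cyx` is a made-up stand-in for the computed image data, assuming it is stored channel-first as (c, y, x); the real code works on the array returned by compute().

```python
import numpy as np

# Hypothetical stand-in for the computed image: 4 channels (RGBA), 2 rows, 3 columns.
cyx = np.arange(4 * 2 * 3).reshape(4, 2, 3)

# Move the channel axis to the end, giving the (rows, columns, channels)
# layout that matplotlib's imshow accepts for RGBA data.
yxc = np.transpose(cyx, (1, 2, 0))

# An identity permutation such as (0, 1, 2) leaves the layout unchanged.
same = np.transpose(cyx, (0, 1, 2))

print(cyx.shape, yxc.shape, same.shape)
```

The permutation (1, 2, 0) reads as "new axis 0 is old axis 1, new axis 1 is old axis 2, new axis 2 is old axis 0", so the element at `cyx[c, y, x]` ends up at `yxc[y, x, c]`.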
src/spatialdata_plot/pl/render.py
Outdated

# compute canvas size in pixels close to the actual image size to speed up computation
plot_width = x_ext[1] - x_ext[0]
plot_height = y_ext[1] - y_ext[0]
plot_width_px = int(round(fig_params.fig.get_size_inches()[0] * fig_params.fig.dpi))
plot_height_px = int(round(fig_params.fig.get_size_inches()[1] * fig_params.fig.dpi))
factor = np.min([plot_width / plot_width_px, plot_height / plot_height_px])
plot_width = int(np.round(plot_width / factor))
plot_height = int(np.round(plot_height / factor))
I'd consider bundling this code into a private function, since it duplicates the code above.
src/spatialdata_plot/pl/render.py
Outdated
@@ -456,8 +495,25 @@ def _render_points(
cmap=render_params.cmap_params.cmap,
)

rbga_image = np.transpose(ds_result.to_numpy().base, (0, 1, 2))
cax = ax.imshow(rbga_image, zorder=render_params.zorder, alpha=render_params.alpha)
# create SpatialImage to get it back to original size
This code is also a duplicate; I'd consider refactoring it into a private function.
…f eq_hist for ds.shade()
Thanks @Sonja-Stockhaus! I have tried the PR on the Visium HD data and the speedup is significant! I switched back to using
Sonja, I found a bug with this PR when the rasterized object needs to be aligned with the rest (#291). The solution seems straightforward: I think one just needs to initialize the image element from datashader using the old coordinate transformations.
Will look into it!
@timtreis ready for review
Hi @timtreis, checking the status of this PR. There are still some open tasks on this PR, right? Can you list them here as tickable tasks, please?
Todo:
Thanks for the update!
This task here seems tricky. I would consider raising a warning saying that the outline is currently not supported and working on this in a separate PR (with low priority).
If you use
Ah ok, then it should be a fast addition, thanks for the update.
General problem with datashader so far: when you render elements and color them by a value and a colormap, you wouldn't see a single element if all of them have the same value. Fix in this PR: internally, not the whole colormap is passed to datashader, but just the color given by
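To make the idea concrete, here is a toy sketch in plain Python. It is not the PR's implementation: the function name `colors_for_shading` is hypothetical, and since the comment above is cut off before saying how the single color is chosen, the midpoint of the colormap is used purely as a placeholder assumption.

```python
def colors_for_shading(values, cmap_colors):
    """Decide which colors to hand to the shader.

    Hypothetical helper: if every element has the same value, return a single
    color instead of the full colormap; otherwise return the colormap as-is.
    """
    if len(set(values)) == 1:
        # Placeholder assumption: pick the middle color of the colormap.
        # The actual rule used in the PR is not stated in the text above.
        return [cmap_colors[len(cmap_colors) // 2]]
    return list(cmap_colors)


cmap = ["#440154", "#21918c", "#fde725"]  # three sample hex colors
constant = colors_for_shading([3.0, 3.0, 3.0], cmap)
varied = colors_for_shading([1.0, 2.0, 5.0], cmap)
print(constant, varied)
```

The point of the branch is only that the constant-value case is detected up front and handled with one fixed color, so the elements stay visible.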
Thanks for the explanation, I think the approach that you implemented is a good one!
This speeds everything up for me locally, but for the benchmark (#296), I see an effect (aka datashader faster than matplotlib) for e.g. 10k points/shapes etc.