JAX scan loop not compiling when having a single tile_map call #14
Labels: bug (Something isn't working)
There is a weird bug: when a JAX loop contains a single tile_map call, the call appears to be forwarded to the CPU XLA backend instead of being compiled for the intended device. Minimal reproducer: