In this first task, use cross-correlation to find Waldo in the image below:

[ Image source: https://rare-gallery.com ]
Recall that cross-correlation, which the machine learning world often refers to as convolution, is defined as:

$$S(i,j) = \sum_{m}\sum_{n} I(i+m,\, j+n)\, K(m,n)$$

for an image matrix $I$ and a kernel matrix $K$.
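A minimal sketch of the definition above, assuming NumPy arrays and "valid" output (no padding, stride 1). The name `cross_correlate_direct` is hypothetical and is not meant to mirror the `my_conv_direct` signature in the repository. Note that the kernel is not flipped. Flipping it would give convolution in the signal-processing sense.

```python
import numpy as np


def cross_correlate_direct(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Valid cross-correlation S(i, j) = sum_m sum_n I(i + m, j + n) K(m, n)."""
    kh, kw = kernel.shape
    # Output size per axis: (input size - kernel size) + 1 (no padding, stride 1).
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    result = np.zeros((out_h, out_w), dtype=np.result_type(image, kernel))
    for i in range(out_h):
        for j in range(out_w):
            # Multiply the kernel element-wise with the patch whose top-left corner is (i, j), then sum.
            result[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return result


image = np.arange(9.0).reshape(3, 3)          # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
kernel = np.array([[1.0, 2.0], [3.0, 4.0]])
print(cross_correlate_direct(image, kernel))  # [[27. 37.] [57. 67.]]
```

The two nested loops visit every output position once. This is easy to check against the formula but slow on large images.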
Navigate to the `src/custom_conv.py` module.
- Start in `my_conv_direct` and implement the convolution following the equation above. Test your function with the VS Code tests or `nox -s test`.
- Go to `src/waldo.py` and make sure that `my_conv_direct` is used for convolution. This script finds Waldo in the image using your convolution function. Execute it with `python ./src/waldo.py` in your terminal.
- If your code passes the pytest tests but is too slow to find Waldo, feel free to use `jax.scipy.signal.correlate2d` in `src/waldo.py` instead of your convolution function.
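If a loop-based version is too slow, one option is to compute all patch products at once. The sketch below uses NumPy's `sliding_window_view` and `einsum`. It then takes the position of the strongest response as the match. The helper name `correlate_valid` and the toy scene are illustrative assumptions. This is only one way a correlation map can be turned into a location, and `src/waldo.py` may pre- or post-process the response differently. `jax.scipy.signal.correlate2d(scene, template, mode="valid")` should produce the same response map.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def correlate_valid(image: np.ndarray, template: np.ndarray) -> np.ndarray:
    """Valid cross-correlation computed on all patches at once."""
    # windows[i, j] is the (kh, kw) image patch whose top-left corner is (i, j).
    windows = sliding_window_view(image, template.shape)
    # Multiply every patch element-wise with the template and sum over the patch axes.
    return np.einsum("ijkl,kl->ij", windows, template)


# Toy scene: a 2x2 "Waldo" pasted into an otherwise empty 5x6 image at row 2, column 3.
template = np.array([[1.0, 2.0], [3.0, 4.0]])
scene = np.zeros((5, 6))
scene[2:4, 3:5] = template

response = correlate_valid(scene, template)
# The strongest response marks the top-left corner of the best match.
row, col = (int(v) for v in np.unravel_index(np.argmax(response), response.shape))
print(response.shape, (row, col))  # (4, 5) (2, 3)
```

On real photographs, raw correlation also rewards bright regions. The highest peak is therefore not guaranteed to be the target unless the image and template are normalized in some way.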
Navigate to the `src/custom_conv.py` module.
The function `my_conv` implements a fast version of the convolution operation above using a flattened kernel. We learned about this fast version in the lecture. Have a look at the slides again and then implement `get_indices` to make `my_conv` work. It should return:
- A matrix of indices following the flattened convolution rule from the lecture, e.g. for a $(2\times 2)$ kernel and a $(3\times 3)$ image it should return the index transformation.
- The number of rows and columns in the result, following $$o=(i-k)+1$$ per axis, where $i$ denotes the input size and $k$ the kernel size.
You can test the function with `nox -s test` by importing `my_conv` in `tests/test_conv.py` and changing `my_conv_direct` to `my_conv` in the test function. Make sure that `src/waldo.py` now uses `my_conv` for convolution and run the script again.
Open `src/mnist.py` and implement MNIST digit recognition with a CNN in JAX, using Flax to help you.
- Reuse your code from yesterday.
- Reuse yesterday's `Net` class and add convolutional layers and pooling. `flax.linen.Conv` and `flax.linen.max_pool` will help you.

