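Open the src/input_opt.py file. The file ./data/weights.pkl contains network weights pre-trained on MNIST. Turn the network optimization problem around and find an input that makes a particular output neuron extremely happy. In other words, keep the weights $\theta$ fixed and maximize the activation of output neuron $i$ with respect to the input $\mathbf{x}$:

$$\mathbf{x}^* = \arg\max_{\mathbf{x}} f(\mathbf{x}, \theta)_i$$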
Initialize the network input with jax.random.uniform as an array of shape [1, 28, 28, 1]. Use jax.value_and_grad to compute the gradient of the neuron's activation with respect to this input, and iteratively optimize the input.
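A minimal sketch of the gradient-ascent loop follows. It assumes a hypothetical single-dense-layer stand-in for the pre-trained network; the real exercise loads its weights from ./data/weights.pkl and uses its own network code. The names `init_params`, `net`, `neuron_activation`, and `optimize_input`, the learning rate, and the clipping of pixels to [0, 1] are illustrative assumptions, not part of src/input_opt.py.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the pre-trained MNIST network: one dense layer
# mapping a [1, 28, 28, 1] image to 10 logits. The exercise instead loads
# its weights from ./data/weights.pkl.
def init_params(key):
    w_key, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (28 * 28, 10)) * 0.01,
        "b": jnp.zeros((10,)),
    }

def net(params, x):
    # Flatten [batch, 28, 28, 1] -> [batch, 784] and apply the dense layer.
    flat = x.reshape((x.shape[0], -1))
    return flat @ params["w"] + params["b"]

def neuron_activation(x, params, neuron):
    # Scalar objective: the output of the chosen neuron for the single image.
    return net(params, x)[0, neuron]

def optimize_input(params, neuron, key, steps=100, lr=1.0):
    # Start from uniform noise of shape [1, 28, 28, 1].
    x = jax.random.uniform(key, (1, 28, 28, 1))
    # value_and_grad differentiates w.r.t. argument 0, i.e. the input image.
    value_and_grad_fn = jax.jit(jax.value_and_grad(neuron_activation))
    history = []
    for _ in range(steps):
        act, grad = value_and_grad_fn(x, params, neuron)
        history.append(float(act))
        # Gradient ascent: step uphill on the activation, then clip to the
        # assumed valid pixel range [0, 1].
        x = jnp.clip(x + lr * grad, 0.0, 1.0)
    return x, history
```

The clipping step is a design choice: without it (or some other constraint or regularizer), the activation of a network like this stand-in can grow without bound as the pixel values grow.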
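Reuse your MNIST digit recognition code. Implement integrated gradients (IG) as discussed in the lecture. Recall the Riemann-sum approximation of the IG attribution for input component $i$:

$$\text{IG}_i(x) \approx (x_i - x'_i) \cdot \frac{1}{m} \sum_{k=1}^{m} \frac{\partial F\left(x' + \frac{k}{m}(x - x')\right)}{\partial x_i}$$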
Here $\frac{\partial F}{\partial x_i}$ denotes the gradient of the network output $F$ with respect to input color channel $i$, $x'$ denotes a black baseline image, and $x$ is the input we are interested in. Finally, $m$ is the number of summation steps from the black baseline image to the input of interest.
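The Riemann sum can be sketched in JAX as below. This is a minimal illustration, not the reference solution for ./src/mnist_integrated: the function name `integrated_gradients` and its parameters are hypothetical, `f` is assumed to be any scalar-valued function of the input (for MNIST, e.g., one class logit of your network), and the default baseline is an all-zero (black) image.

```python
import jax
import jax.numpy as jnp

def integrated_gradients(f, x, baseline=None, m=50):
    """Riemann-sum approximation of integrated gradients (hypothetical helper).

    f: scalar-valued function of an input array, e.g. one class logit.
    x: input of interest; baseline defaults to an all-black (zero) image.
    m: number of summation steps from the baseline to x.
    """
    if baseline is None:
        baseline = jnp.zeros_like(x)
    grad_f = jax.grad(f)
    # Interpolation coefficients k/m for k = 1, ..., m.
    alphas = jnp.arange(1, m + 1) / m
    # Gradient of f at each point on the straight path from baseline to x;
    # result has shape (m, *x.shape).
    path_grads = jax.vmap(lambda a: grad_f(baseline + a * (x - baseline)))(alphas)
    # (1/m) * sum over k of the path gradients, scaled by (x - x').
    return (x - baseline) * path_grads.mean(axis=0)
```

For a trained classifier one would call it with something like `f = lambda img: net(params, img)[0, label]`, where `net`, `params`, and `label` stand for your own network code. A quick sanity check is the completeness property: the attributions should sum approximately to $F(x) - F(x')$, and the approximation improves as $m$ grows.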
Follow the TODOs in ./src/mnist_integrated.