The goal of this exercise is to implement a multilayer dense neural network using JAX and Flax.
Type `pip install -r requirements.txt` into the terminal to install the required software.
JAX takes care of our autograd needs; the documentation is available at https://jax.readthedocs.io/en/latest/index.html. Flax is a high-level neural network library, and https://flax.readthedocs.io/en/latest/ hosts its documentation.
To get a feeling for how a dense network learns a function from given data, we will first have a look at the example from the lecture. In the following task you will implement gradient-descent training of a dense neural network using JAX and use it to learn a function, e.g. a cosine.
- As a first step, create a cosine function in JAX and add some noise with `jax.random.normal`. Use, for example, a signal length of $n = 200$ samples and a period of your choosing. This will be the noisy signal from which the model is supposed to learn the underlying cosine.
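For orientation, this is one way the noisy signal could be set up; the seed, period, and noise amplitude below are arbitrary example choices:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)                 # seed value chosen arbitrarily
n = 200                                      # signal length
x = jnp.linspace(0.0, 2.0 * jnp.pi, n)       # one full period (example choice)
y = jnp.cos(x)                               # clean cosine
y_noisy = y + 0.1 * jax.random.normal(key, (n,))  # noisy training signal
```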
- Recall the definition of the sigmoid function $\sigma(x) = \frac{1}{1 + e^{-x}}$.
- Implement the `sigmoid` function in `src/denoise_cosine.py`.
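A minimal sketch of what the `sigmoid` function could look like:

```python
import jax.numpy as jnp

def sigmoid(x):
    """Element-wise sigmoid activation."""
    return 1.0 / (1.0 + jnp.exp(-x))
```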
- Implement a dense layer in the `net` function of `src/denoise_cosine.py`. The function should return $\mathbf{o} = \mathbf{W}_2 \, \sigma(\mathbf{W}_1 \mathbf{x} + \mathbf{b})$, where $\mathbf{W}_1$, $\mathbf{W}_2$ are the weight matrices and $\mathbf{b}$ is the bias. Use numpy's `@` notation for the matrix product.
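Building on the sigmoid above, the `net` function could be sketched as follows; storing the weights in a dictionary with keys `W1`, `W2`, and `b` is an assumption, your script may use a different container:

```python
def net(params, x):
    """Dense network o = W2 @ sigmoid(W1 @ x + b)."""
    hidden = sigmoid(params["W1"] @ x + params["b"])   # shape [hidden_neurons]
    return params["W2"] @ hidden                       # shape [n]
```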
- Use `jax.random.uniform` to initialize your weights. For a signal length of $200$, the $\mathbf{W}_2$ matrix should e.g. have the shape [200, hidden_neurons] and $\mathbf{W}_1$ a shape of [hidden_neurons, 200]. Start with $\mathcal{U}[-0.1, 0.1]$ for example. `jax.random.PRNGKey` allows you to create a seed for the random number generator.
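A possible initialization helper; the function name `init_params` and the number of hidden neurons are illustrative assumptions:

```python
import jax

def init_params(key, n=200, hidden_neurons=10):
    """Draw W1, W2 and b from U[-0.1, 0.1]."""
    k1, k2, k3 = jax.random.split(key, 3)
    uni = lambda k, shape: jax.random.uniform(k, shape, minval=-0.1, maxval=0.1)
    return {
        "W1": uni(k1, (hidden_neurons, n)),
        "W2": uni(k2, (n, hidden_neurons)),
        "b": uni(k3, (hidden_neurons,)),
    }
```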
- Implement and test a squared error cost $C_{\text{se}}(\mathbf{y}, \mathbf{o}) = \frac{1}{2}\sum_{k=1}^{n}(y_k - o_k)^2$. In Python `**` computes powers (and therefore squares), and `jnp.sum` allows you to sum up all terms.
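A squared-error cost along these lines; the factor of one half is a common convention and an assumption here:

```python
import jax.numpy as jnp

def cost(y, o):
    """Squared error between target y and network output o."""
    return 0.5 * jnp.sum((y - o) ** 2)
```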
- Define the forward pass in the `net_cost` function. The forward pass evaluates the network and the cost function.
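The forward pass can then simply chain the two previous sketches; the argument order is an assumption:

```python
def net_cost(params, x, y):
    """Forward pass: evaluate the network on x and the cost w.r.t. y."""
    o = net(params, x)
    return cost(y, o)
```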
- Train your network to denoise a cosine. To do so, implement gradient descent on the noisy input signal and use e.g. `jax.value_and_grad` to compute cost and gradient at the same time. Remember the gradient descent update rule $\mathbf{W}_{\tau+1} = \mathbf{W}_{\tau} - \epsilon \cdot \delta\mathbf{W}$, where $\mathbf{W}$ stands for the weight matrices and biases, $\epsilon$ denotes the step size, and $\delta$ the gradient operation with respect to the following weight. Use a loop to repeat the weight update multiple times. Try to train for one hundred updates.
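Putting the pieces together, the training loop might look roughly like this; the step size and the choice of feeding the noisy signal as both input and target are assumptions:

```python
import jax

step_size = 0.01                               # epsilon, an example value
loss_grad_fn = jax.value_and_grad(net_cost)    # gradient w.r.t. the first argument

params = init_params(jax.random.PRNGKey(0))
for step in range(100):                        # one hundred updates
    loss, grads = loss_grad_fn(params, y_noisy, y_noisy)
    # gradient descent update W <- W - epsilon * dW for every entry
    params = {name: w - step_size * grads[name] for name, w in params.items()}
    if step % 10 == 0:
        print(f"step {step:3d}, cost {loss:.4f}")
```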
- At last, compute the network output `y_hat` with the final weight values to see whether the network has learned the underlying cosine function. Use `matplotlib.pyplot.plot` to plot the noisy signal and the network output $\mathbf{o}$.
- Test your code with `nox -r -s test` and run the script with `python ./src/denoise_cosine.py` or by pressing `Ctrl + F5` in VS Code.
In this task we will go one step further. Instead of a cosine function, our neural network will learn how to identify handwritten digits from the MNIST dataset. For that, we will be using the Linen API of the flax module. Firstly, make yourself familiar with the Linen API to get started with training a fully connected network in `src/mnist.py`. In this script, some functions are already implemented and can be reused. Use `jax.numpy.array_split` to create a list of batches from your training set; broadcasting is an elegant way to deal with data batches. This task aims to compute gradients and update steps for all batches in the list.

If you are coding on bender, `matplotlib.pyplot.show` does not work unless you are connected to the X server of bender. Use e.g. `plt.savefig` to save the figure and view it in VS Code.
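As a rough illustration of the batching hint, assuming the training images and labels are stored in arrays named `train_imgs` and `train_lbls` (placeholder names):

```python
import jax.numpy as jnp

batch_size = 50                                           # example batch size
num_batches = train_imgs.shape[0] // batch_size
img_batches = jnp.array_split(train_imgs, num_batches)    # list of image batches
label_batches = jnp.array_split(train_lbls, num_batches)  # matching label batches
```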
- Implement the `normalize` function to ensure approximately standard-normal inputs. Make use of handy numpy methods that you already know. Normalization requires subtraction of the mean and division by the standard deviation. With $i = 1, \dots, w$ and $j = 1, \dots, h$, where $w$ is the image width, $h$ the image height, and $k$ runs through the batch dimension:
  $\tilde{x}_{k,i,j} = \frac{x_{k,i,j} - \mu}{\sigma},$
  with $\mu$ and $\sigma$ the mean and standard deviation of the input images.
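A sketch of how `normalize` could be written; computing one scalar mean and standard deviation over the whole input array is an assumption:

```python
import jax.numpy as jnp

def normalize(imgs, mean=None, std=None):
    """Shift and scale images to roughly zero mean and unit variance."""
    if mean is None:
        mean = jnp.mean(imgs)       # statistics over all pixels of the input
    if std is None:
        std = jnp.std(imgs)
    return (imgs - mean) / std, mean, std
```

Reusing the training-set statistics when normalizing the validation and test sets is a common choice.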
- The forward step requires the `Net` object from its class. It is your fully connected neural network model. Applying weights to a `flax.linen.Module` is comparable to calculating the forward pass of the network in task 1. Implement a dense network of your choosing in `Net`, using a combination of `flax.linen.Dense` and `flax.linen.activation.relu` or `flax.linen.sigmoid`.
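One possible `Net` definition; the hidden-layer width of 128 is an example choice, not a requirement:

```python
import flax.linen as nn

class Net(nn.Module):
    """Fully connected classifier for 28x28 MNIST digits."""

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))    # flatten each image into a vector
        x = nn.Dense(features=128)(x)      # hidden width is an example choice
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)       # one output neuron per digit
        return x
```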
- The forward pass ends with the evaluation of a cost function. Write a `cross_entropy` cost function, with $n_o$ the number of labels and $n_b$ the batch size in the batched case, using
  $C_{\text{ce}}(\mathbf{y}, \mathbf{o}) = -\frac{1}{n_b} \sum_{i=1}^{n_b} \sum_{k=1}^{n_o} y_{i,k} \ln(o_{i,k}).$
- If you have chosen to work with ten output neurons, use `jax.nn.one_hot` to encode the labels.
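A `cross_entropy` sketch that assumes the network returns raw scores, which are normalized with a softmax, and that the labels arrive as integer class ids:

```python
import jax
import jax.numpy as jnp

def cross_entropy(labels, out):
    """Mean cross entropy over the batch; labels are integer class ids."""
    one_hot = jax.nn.one_hot(labels, num_classes=10)     # [n_b, n_o]
    log_probs = jax.nn.log_softmax(out, axis=-1)         # normalize the raw scores
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
```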
- Now implement the `forward_step` function. Calculate the network output first, then compute the loss. It should return a scalar cost term you can use to compute gradients. Make use of the cross entropy.
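The `forward_step` could then look as follows; it assumes a module-level `model = Net()` instance, whereas the actual script may pass the model explicitly:

```python
def forward_step(params, imgs, labels):
    """Network output followed by the loss; returns a scalar cost."""
    out = model.apply(params, imgs)    # `model = Net()` created in the main procedure
    return cross_entropy(labels, out)
```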
- Next we want to be able to do an optimization step with stochastic gradient descent (SGD). Implement `sgd_step` and use the gradients to update the weights. Consider `jax.tree_util.tree_map` for this task; tree maps work best with a lambda expression.
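One way to phrase `sgd_step` with a tree map; the default learning rate is an arbitrary example value:

```python
import jax

def sgd_step(params, grads, learning_rate=0.1):
    """Apply w <- w - lr * g to every leaf of the parameter pytree."""
    return jax.tree_util.tree_map(lambda w, g: w - learning_rate * g, params, grads)
```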
- To evaluate the network, we calculate the accuracy of the network output. Implement `get_acc` to calculate the accuracy given a batch of images and the corresponding labels for these images.
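A `get_acc` sketch under the same `model` assumption as above:

```python
import jax.numpy as jnp

def get_acc(params, imgs, labels):
    """Fraction of correctly classified images in the given batch."""
    out = model.apply(params, imgs)
    return jnp.mean(jnp.argmax(out, axis=-1) == labels)
```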
- Now it is time to move back to the main procedure. First, the training data is fetched via the function `get_mnist_train_data`. To be able to evaluate the network while it is being trained, we use a validation set: the train set is split into two disjoint sets, the training and the validation set. Both sets must be normalized.
- Define your loss and gradient function with jax (see task 1). Next, initialize the network with the `Net` object (see the flax documentation for help).
- Train your network for a fixed number of epochs over the entire dataset.
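A rough outline of the outer training loop, reusing the placeholder names from the sketches above (`img_batches`, `label_batches`, `val_imgs`, `val_lbls`); the epoch count is an example value:

```python
import jax

model = Net()
params = model.init(jax.random.PRNGKey(0), img_batches[0])  # init with a dummy batch
loss_grad_fn = jax.value_and_grad(forward_step)

epochs = 10                                    # example number of epochs
for epoch in range(epochs):
    for imgs, labels in zip(img_batches, label_batches):
        loss, grads = loss_grad_fn(params, imgs, labels)
        params = sgd_step(params, grads)
    val_acc = get_acc(params, val_imgs, val_lbls)   # validation split from above
    print(f"epoch {epoch}, loss {loss:.4f}, val acc {val_acc:.3f}")
```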
- Last, load the test data with `get_mnist_test_data` and calculate the test accuracy. Save it to a list.
- Optional: Plot the training and validation accuracies and add the test accuracy at the end.