\u001b[0m in \u001b[0;36moperator_mul\u001b[0;34m(self, rhs)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# define forward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mrhs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{r.name} = {self.name} * {rhs.name}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'value'"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pa5P78Y4PHfT"
- },
- "source": [
- "This doesn't work because `t` is being captured and used in propagate, but propgate expects to compute on Variables. Becuase `t` was extracted from autograd, it can no longer directly participate in the `propagate` call. One way to fix this is to recompute `t`"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "gY9_KfgZPgql",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 1000
- },
- "outputId": "f64d673a-7e95-42d6-b605-9ceec8efd6ae"
- },
- "source": [
- "def simple_recompute(a, b):\n",
- " t = (a.value + b.value)\n",
- " r = Variable(t * b.value)\n",
- " def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
- " dL_dr, = dL_doutputs\n",
- " dr_dt = b # partial from: r = t * b\n",
- " t = a + b # RECOMPUTE!\n",
- " dr_db = t # partial from: r = t * b\n",
- " dL_dt = dL_dr*dr_dt # chain rule\n",
- " dt_da = 1.0 # partial from t = a + b\n",
- " dt_db = 1.0 # partial from t = a + b\n",
- " dL_da = dL_dt * dt_da # chain rule\n",
- " dL_db = dL_dt * dt_db + dL_dr * dr_db # chain rule\n",
- " return [dL_da, dL_db]\n",
- " gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=propagate))\n",
- " return r\n",
- "\n",
- "da, db = run_gradients(simple_recompute)\n",
- "da_ref, db_ref = run_gradients(simple)\n",
- "print(\"da\", da, da_ref)\n",
- "print(\"db\", db, db_ref)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
- "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
- "L0 = v0.sum()\n",
- "dL0 ------------------------\n",
- "v2 = v1.expand(4)\n",
- "v3 = a + b\n",
- "v4 = v2 * b\n",
- "v5 = v2 * v3\n",
- "v6 = v4 + v5\n",
- "dL0_dL0 = v1\n",
- "dL0_dv0 = v2\n",
- "dL0_da = v4\n",
- "dL0_db = v6\n",
- "------------------------\n",
- "v7 = v4 * v4\n",
- "v8 = v6 * v6\n",
- "v9 = v7 + v8\n",
- "L1 = v9.sum()\n",
- "dL1 ------------------------\n",
- "v11 = v10.expand(4)\n",
- "v12 = v11 * v6\n",
- "v13 = v11 * v6\n",
- "v14 = v12 + v13\n",
- "v15 = v11 * v4\n",
- "v16 = v11 * v4\n",
- "v17 = v15 + v16\n",
- "v18 = v17 + v14\n",
- "v19 = v14 * v3\n",
- "v20 = v14 * v2\n",
- "v21 = v18 * b\n",
- "v22 = v18 * v2\n",
- "v23 = v19 + v21\n",
- "v24 = v22 + v20\n",
- "v25 = v23.sum()\n",
- "dL1_dL1 = v10\n",
- "dL1_dv9 = v11\n",
- "dL1_dv7 = v11\n",
- "dL1_dv8 = v11\n",
- "dL1_dv6 = v14\n",
- "dL1_dv4 = v18\n",
- "dL1_dv5 = v14\n",
- "dL1_dv2 = v23\n",
- "dL1_dv3 = v20\n",
- "dL1_db = v24\n",
- "dL1_da = v20\n",
- "dL1_dv1 = v25\n",
- "------------------------\n",
- "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
- "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
- "v0 = a + b\n",
- "v1 = v0 * b\n",
- "L0 = v1.sum()\n",
- "dL0 ------------------------\n",
- "v3 = v2.expand(4)\n",
- "v4 = v3 * b\n",
- "v5 = v3 * v0\n",
- "v6 = v5 + v4\n",
- "dL0_dL0 = v2\n",
- "dL0_dv1 = v3\n",
- "dL0_dv0 = v4\n",
- "dL0_db = v6\n",
- "dL0_da = v4\n",
- "------------------------\n",
- "v7 = v4 * v4\n",
- "v8 = v6 * v6\n",
- "v9 = v7 + v8\n",
- "L1 = v9.sum()\n",
- "dL1 ------------------------\n",
- "v11 = v10.expand(4)\n",
- "v12 = v11 * v6\n",
- "v13 = v11 * v6\n",
- "v14 = v12 + v13\n",
- "v15 = v11 * v4\n",
- "v16 = v11 * v4\n",
- "v17 = v15 + v16\n",
- "v18 = v17 + v14\n",
- "v19 = v14 * v0\n",
- "v20 = v14 * v3\n",
- "v21 = v18 * b\n",
- "v22 = v18 * v3\n",
- "v23 = v19 + v21\n",
- "v24 = v23.sum()\n",
- "v25 = v22 + v20\n",
- "dL1_dL1 = v10\n",
- "dL1_dv9 = v11\n",
- "dL1_dv7 = v11\n",
- "dL1_dv8 = v11\n",
- "dL1_dv6 = v14\n",
- "dL1_dv4 = v18\n",
- "dL1_dv5 = v14\n",
- "dL1_dv3 = v23\n",
- "dL1_dv0 = v20\n",
- "dL1_db = v25\n",
- "dL1_dv2 = v24\n",
- "dL1_da = v20\n",
- "------------------------\n",
- "da tensor([1.2611, 2.1304, 5.8394, 3.3869]) tensor([1.2611, 2.1304, 5.8394, 3.3869])\n",
- "db tensor([ 2.6923, 4.9201, 13.6563, 8.0727]) tensor([ 2.6923, 4.9201, 13.6563, 8.0727])\n"
- ],
- "name": "stdout"
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HLUaN1UeP4C_"
- },
- "source": [
- "This recompute works but it is not ideal. First, the original compute may have been expensive (think a bunch of convolutions and multiplies), so redoing it in the backward pass may take significant time. Second, we need to save `a` and `b` to recompute `t`. Previously we only had to save `b`. What if `a` was a _huge_ matrix but `t` was small? Then we are using _more total memory_ by doing this recompute as well. In general, we want to avoid recomputing things unless we know it won't be expensive in time or space.\n",
- "\n",
- "Let's consider another approach. What happens if we just make `t` into a Variable?"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "ZIysmiVCQaw5",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 799
- },
- "outputId": "7b1a0d2a-a546-436f-af00-9b5d5cc53986"
- },
- "source": [
- "def simple_variable_wrong(a, b):\n",
- " t = (a.value + b.value)\n",
- " t_v = Variable(t, name='t') # named for debugging\n",
- " r = Variable(t * b.value)\n",
- " def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
- " dL_dr, = dL_doutputs\n",
- " dr_dt = b # partial from: r = t * b\n",
- " dr_db = t_v # partial from: r = t * b\n",
- " dL_dt = dL_dr*dr_dt # chain rule\n",
- " dt_da = 1.0 # partial from t = a + b\n",
- " dt_db = 1.0 # partial from t = a + b\n",
- " dL_da = dL_dt * dt_da # chain rule\n",
- " dL_db = dL_dt * dt_db + dL_dr * dr_db # chain rule\n",
- " return [dL_da, dL_db]\n",
- " gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=propagate))\n",
- " return r\n",
- "\n",
- "da, db = run_gradients(simple_variable_wrong)\n",
- "print(\"da\", da) # ERROR: da is None!!!????\n",
- "print(\"db\", db)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
- "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
- "L0 = v0.sum()\n",
- "dL0 ------------------------\n",
- "v2 = v1.expand(4)\n",
- "v3 = v2 * b\n",
- "v4 = v2 * t\n",
- "v5 = v3 + v4\n",
- "dL0_dL0 = v1\n",
- "dL0_dv0 = v2\n",
- "dL0_da = v3\n",
- "dL0_db = v5\n",
- "------------------------\n",
- "v6 = v3 * v3\n",
- "v7 = v5 * v5\n",
- "v8 = v6 + v7\n",
- "L1 = v8.sum()\n",
- "dL1 ------------------------\n",
- "v10 = v9.expand(4)\n",
- "v11 = v10 * v5\n",
- "v12 = v10 * v5\n",
- "v13 = v11 + v12\n",
- "v14 = v10 * v3\n",
- "v15 = v10 * v3\n",
- "v16 = v14 + v15\n",
- "v17 = v16 + v13\n",
- "v18 = v13 * t\n",
- "v19 = v13 * v2\n",
- "v20 = v17 * b\n",
- "v21 = v17 * v2\n",
- "v22 = v18 + v20\n",
- "v23 = v22.sum()\n",
- "dL1_dL1 = v9\n",
- "dL1_dv8 = v10\n",
- "dL1_dv6 = v10\n",
- "dL1_dv7 = v10\n",
- "dL1_dv5 = v13\n",
- "dL1_dv3 = v17\n",
- "dL1_dv4 = v13\n",
- "dL1_dv2 = v22\n",
- "dL1_dt = v19\n",
- "dL1_db = v21\n",
- "dL1_dv1 = v23\n",
- "------------------------\n",
- "da None\n",
- "db tensor([1.4312, 2.7896, 7.8169, 4.6857])\n"
- ],
- "name": "stdout"
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "v6D1xMlfQ4xE"
- },
- "source": [
- "While we do not get an error, something is clearly wrong. `dL1/da` is None, but we _know_ that the value of `a` affects the norm of the gradients of the original loss so this value should not be None. We are not propagating a gradient somewhere!\n",
- "\n",
- "Let's see what happens when we run just the first gradient.\n"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "KsCpcoCYRVgR",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 544
- },
- "outputId": "c67a729e-0dfe-4e0e-b267-50d7ee05e307"
- },
- "source": [
- "da, db = run_gradients(simple_variable_wrong, second_loss=False)\n",
- "da_ref, db_ref = run_gradients(simple, second_loss=False)\n",
- "print(\"da\", da, da_ref) \n",
- "print(\"db\", db, db_ref)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
- "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
- "L0 = v0.sum()\n",
- "dL0 ------------------------\n",
- "v2 = v1.expand(4)\n",
- "v3 = v2 * b\n",
- "v4 = v2 * t\n",
- "v5 = v3 + v4\n",
- "dL0_dL0 = v1\n",
- "dL0_dv0 = v2\n",
- "dL0_da = v3\n",
- "dL0_db = v5\n",
- "------------------------\n",
- "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
- "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
- "v0 = a + b\n",
- "v1 = v0 * b\n",
- "L0 = v1.sum()\n",
- "dL0 ------------------------\n",
- "v3 = v2.expand(4)\n",
- "v4 = v3 * b\n",
- "v5 = v3 * v0\n",
- "v6 = v5 + v4\n",
- "dL0_dL0 = v2\n",
- "dL0_dv1 = v3\n",
- "dL0_dv0 = v4\n",
- "dL0_db = v6\n",
- "dL0_da = v4\n",
- "------------------------\n",
- "da tensor([0.0850, 0.3296, 0.9888, 0.6494]) tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
- "db tensor([0.6306, 1.0652, 2.9197, 1.6935]) tensor([0.6306, 1.0652, 2.9197, 1.6935])\n"
- ],
- "name": "stdout"
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-ck3a9nQRil7"
- },
- "source": [
- "In the single-backward case, we get the right answer! This illustrates a key part of autograd: it is _very easy_ to make it appear to work for a single backward pass but have the code be broken when trying higher order gradients. \n",
- "\n",
- "So what is going wrong? Look at the debug trace from the first time we ran `simple_variable_wrong`. Inside the compute of `dL0` (the first backward), you can see a line: `v4 = v2 * t`. The first backward is using the value of `t`. But if a computation _uses_ `t` then the gradient of that computation will have a non-zero gradient `dL1/dt` for any future loss (`L1`) that uses the results of that computation. But this future use of `t` is not accounted for in `simple_variable_wrong`! We consider the effect of `r` on `t` as `dL_dt = dL_dr*dr_dt`, but do not consider uses of `t` outside the local aggregate. This is because the way `t` can be used in the future is subtle: it escapes from our compute _only_ through its use as a closed over variable in `propagate`. So this gradient pathway can only be non-zero for higher-order gradients where we are differentiating through this use.\n",
- "\n",
- "The problem originates because `t` was not declared as an output of the original computation, even though it was defined by the computation and used by later computations. We can fix this by defining it as an output in the gradient tape and then using the derivative contribution that comes from it."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "2cbRekShbmxX",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 986
- },
- "outputId": "950d5943-f63f-4ff3-e370-72e9316d340b"
- },
- "source": [
- "def simple_variable_almost(a, b):\n",
- " t = (a.value + b.value)\n",
- " t_v = Variable(t, name='t_v')\n",
- " r = Variable(t * b.value)\n",
- " def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
- " # t is considered an output, so we now get dL_dt0 as an input.\n",
- " dL_dr, dL_dt0 = dL_doutputs\n",
- " ###### new gradient contribution\n",
- "\n",
- " # Handle cases where one incoming gradient is zero (None)\n",
- " if dL_dr is None:\n",
- " dL_dr = Variable.constant(torch.zeros(()))\n",
- " if dL_dt0 is None:\n",
- " dL_dt0 = Variable.constant(torch.zeros(()))\n",
- " \n",
- "\n",
- " dr_dt = b \n",
- " dr_db = t_v \n",
- " # we combine this with the contribution from r to calculate \n",
- " # all gradient paths to dL_dt\n",
- " dL_dt = dL_dt0 + dL_dr*dr_dt # chain rule\n",
- " ######\n",
- "\n",
- " dt_da = 1.0 \n",
- " dt_db = 1.0 \n",
- " dL_db = dL_dr * dr_db + dL_dt * dt_db \n",
- " dL_da = dL_dt * dt_da\n",
- " return [dL_da, dL_db]\n",
- "\n",
- " # note: t_v is now considered an output in the tape\n",
- " gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name, t_v.name], propagate=propagate))\n",
- " ######### new output\n",
- " return r\n",
- "da, db = run_gradients(simple_variable_almost)\n",
- "print(\"da\", da) \n",
- "print(\"db\", db)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "a = tensor([0.0171, 0.1633, 0.5833, 0.3794])\n",
- "b = tensor([0.3774, 0.6308, 0.5239, 0.1387])\n",
- "L0 = v0.sum()\n",
- "dL0 ------------------------\n",
- "v2 = v1.expand(4)\n",
- "v3 = 0.0\n",
- "v4 = v2 * b\n",
- "v5 = v3 + v4\n",
- "v6 = v2 * t_v\n",
- "v7 = v6 + v5\n",
- "dL0_dL0 = v1\n",
- "dL0_dv0 = v2\n",
- "dL0_da = v5\n",
- "dL0_db = v7\n",
- "------------------------\n",
- "v8 = v5 * v5\n",
- "v9 = v7 * v7\n",
- "v10 = v8 + v9\n",
- "L1 = v10.sum()\n",
- "dL1 ------------------------\n",
- "v12 = v11.expand(4)\n",
- "v13 = v12 * v7\n",
- "v14 = v12 * v7\n",
- "v15 = v13 + v14\n",
- "v16 = v12 * v5\n",
- "v17 = v12 * v5\n",
- "v18 = v16 + v17\n",
- "v19 = v18 + v15\n",
- "v20 = v15 * t_v\n",
- "v21 = v15 * v2\n",
- "v22 = v19 * b\n",
- "v23 = v19 * v2\n",
- "v24 = v20 + v22\n",
- "v25 = v24.sum()\n",
- "v26 = 0.0\n",
- "v27 = v26 * b\n",
- "v28 = v21 + v27\n",
- "v29 = v26 * t_v\n",
- "v30 = v29 + v28\n",
- "v31 = v23 + v30\n",
- "dL1_dL1 = v11\n",
- "dL1_dv10 = v12\n",
- "dL1_dv8 = v12\n",
- "dL1_dv9 = v12\n",
- "dL1_dv7 = v15\n",
- "dL1_dv5 = v19\n",
- "dL1_dv6 = v15\n",
- "dL1_dv2 = v24\n",
- "dL1_dt_v = v21\n",
- "dL1_dv3 = v19\n",
- "dL1_dv4 = v19\n",
- "dL1_db = v31\n",
- "dL1_dv1 = v25\n",
- "dL1_da = v28\n",
- "------------------------\n",
- "da tensor([1.5438, 2.8499, 3.2622, 1.3134])\n",
- "db tensor([3.8424, 6.9614, 7.5721, 2.9042])\n"
- ],
- "name": "stdout"
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "GjIEL9nncdmO"
- },
- "source": [
- "This code is now correct! However, it has some non-optimal behavior. Notice how at the beginning of `propagate` we need to handle the cases where the gradients coming in are `None`. Recall that when a pathway has no gradient we give it the value `None`. The first time through `propagate` `dL_dt0` will be `None` since `t` is not used outside of the propagate function itself on the first backward. The _second_ time through `propgate`, `dL_dt0` will have a value but `dL_dr` will be `None`. Excercise: convince yourself why `dL_dr` is `None` the second time through. When we fix this by changing the `None` into zeros, we get the right answer but at the cost of always doing more compute. For instance in this case, it adds an additional pointwise addition of a zero tensor to every single-backward call to handle `dL_dt0` input which will be zero.\n",
- "\n",
- " It makes sense to use a constant-time check for zero to eliminate a tensor-sized amount of work. So we optimize this code by replicating some of the `None` handling logic in `grad` directly into the aggregate op. It is a little messy but it handles the cases where inputs might be `None` with a minimal amount of compute."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "fshJIV4xcJKW",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 833
- },
- "outputId": "90eee555-de63-4fa5-fdf1-313d9fae8116"
- },
- "source": [
- "def add_optional(a: Optional['Variable'], b: Optional['Variable']):\n",
- " if a is None:\n",
- " return b\n",
- " if b is None:\n",
- " return a\n",
- " return a + b\n",
- "\n",
- "def simple_variable(a, b):\n",
- " t = (a.value + b.value)\n",
- " t_v = Variable(t, name='t_v')\n",
- " r = Variable(t * b.value)\n",
- " def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
- " dL_dr, dL_dt0 = dL_doutputs\n",
- " dr_dt = b # partial from: r = t * b\n",
- " dr_db = t_v # partial from: r = t * b\n",
- " dL_dt = dL_dt0\n",
- " if dL_dr is not None:\n",
- " dL_dt = add_optional(dL_dt, dL_dr*dr_dt) # chain rule\n",
- "\n",
- " dt_da = 1.0 # partial from t = a + b\n",
- " dt_db = 1.0 # partial from t = a + b\n",
- " if dL_dr is not None:\n",
- " dL_db = dL_dr * dr_db # chain rule\n",
- " else:\n",
- " dL_db = None\n",
- "\n",
- " if dL_dt is not None:\n",
- " dL_da = dL_dt * dt_da # chain rule\n",
- " dL_db = add_optional(dL_db, dL_dt * dt_db)\n",
- " else:\n",
- " dL_da = None\n",
- "\n",
- " return [dL_da, dL_db]\n",
- "\n",
- " gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name, t_v.name], propagate=propagate))\n",
- " return r\n",
- "da, db = run_gradients(simple_variable)\n",
- "print(\"da\", da) \n",
- "print(\"db\", db)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "a = tensor([0.0171, 0.1633, 0.5833, 0.3794])\n",
- "b = tensor([0.3774, 0.6308, 0.5239, 0.1387])\n",
- "L0 = v0.sum()\n",
- "dL0 ------------------------\n",
- "v2 = v1.expand(4)\n",
- "v3 = v2 * b\n",
- "v4 = v2 * t_v\n",
- "v5 = v4 + v3\n",
- "dL0_dL0 = v1\n",
- "dL0_dv0 = v2\n",
- "dL0_da = v3\n",
- "dL0_db = v5\n",
- "------------------------\n",
- "v6 = v3 * v3\n",
- "v7 = v5 * v5\n",
- "v8 = v6 + v7\n",
- "L1 = v8.sum()\n",
- "dL1 ------------------------\n",
- "v10 = v9.expand(4)\n",
- "v11 = v10 * v5\n",
- "v12 = v10 * v5\n",
- "v13 = v11 + v12\n",
- "v14 = v10 * v3\n",
- "v15 = v10 * v3\n",
- "v16 = v14 + v15\n",
- "v17 = v16 + v13\n",
- "v18 = v13 * t_v\n",
- "v19 = v13 * v2\n",
- "v20 = v17 * b\n",
- "v21 = v17 * v2\n",
- "v22 = v18 + v20\n",
- "v23 = v22.sum()\n",
- "v24 = v21 + v19\n",
- "dL1_dL1 = v9\n",
- "dL1_dv8 = v10\n",
- "dL1_dv6 = v10\n",
- "dL1_dv7 = v10\n",
- "dL1_dv5 = v13\n",
- "dL1_dv3 = v17\n",
- "dL1_dv4 = v13\n",
- "dL1_dv2 = v22\n",
- "dL1_dt_v = v19\n",
- "dL1_db = v24\n",
- "dL1_dv1 = v23\n",
- "dL1_da = v19\n",
- "------------------------\n",
- "da tensor([1.5438, 2.8499, 3.2622, 1.3134])\n",
- "db tensor([3.8424, 6.9614, 7.5721, 2.9042])\n"
- ],
- "name": "stdout"
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "r77nkn7Y5NjU"
- },
- "source": [
- "**Excercise** modify `run_gradients` such that the second call to `grad` produces non-zero values for both `dL_dr` and `dL_dt`. Hint: it can be done with the addition of 2 characters."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "lZvYIrBS5fXr"
- },
- "source": [
- "In PyTorch's symbolic autodiff implementation, the handling of zero tensors is done with undefined tensors in the place of `None` values, but the logic in TorchScript is very similar. The function `any_defined(...)` is used to check if any inputs are non-zero and guards the calculation of unused parts of the autograd. The `AutogradAdd(a, b)` operator adds two tensors, handling the case where either is undefined, similar to `add_optional`. \n",
- "\n",
- "The backward pass is very messy as-is with all of this conditional logic. Furthermore, as you have seen in these examples, in many cases the logic will branch in the same direction. This is especially true for single-backward where gradients from captured temporaries will always be zero. It is profitable to try to specialize this code for particular patterns of non-zeros since it allows more aggresive fusion of the backward pass."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "sAApdEb27KTA"
- },
- "source": [
- "# PyTorch vs Simple Grad\n",
- "\n",
- "Simple Grad gives a good overview of how PyTorch's autograd works. TorchScript's symbolic gradient pass can generate aggregate operators from subsets of the IR by automating the process we went through to define `simple` as an aggregate operator.\n",
- "\n",
- "The real PyTorch autograd has some features that go beyond this example related to mutable tensors. Simple Grad assumes that tensors are immutable, so saving a Tensor for use in `propagate` is as simple as saving a reference to it. In PyTorch, the gradient formulas need to explicity mark a Tensor as needing to be saved so we can track future potential mutations. The idea is to be able to track if a user mutated a tensor that is needed by backward and report an error on use. Mutable ops themselves also affect how the `propagate` functions get recorded. If a tensor is mutated, uses of the tensor _before_ the mutation need to propagate gradient to the original value, while uses _after_ propagate gradient to the new mutated value. Since tensors can be views of other mutable tensors, PyTorch needs bookkeeping to make sure any time a tensor is updated all views of the tensor now propagate gradient to the new value and not the old one. \n",
- "\n",
- "# Where to go from here\n",
- "\n",
- "If you still have questions about how this process works, I encourage you to edit this notebook with additional debug information and play around with compute. You can try:\n",
- "* Adding a new operator with `propagate` formula (use torch.grad to verify correctness)\n",
- "* Modify `run_gradient` to calculate weirder higher order gradients and see if it behaves as you expect.\n",
- "* Remove `None` and implement gradients using Tensor zeros.\n",
- "* Try to manually define an another aggregate operator for something similar to `simple`\n",
- "* Write a 'compiler' that can take a small expression similar to `simple` and transform it automatically into a forward and `propagate`, similar to autodiff.cpp\n",
- "* Rewrite `simple_variable` so all the branching for `None` checks is at the top of `propagate`. Can you generalize this such that a compiler can generate specializations for the seen non-zero patterns?\n",
- "* Read `autodiff.cpp` and add a description to this document about how concenpts in here directly relate to that code."
- ]
- }
- ]
-}
\ No newline at end of file
diff --git a/S24/document/recitation/Recitation4/Paper_Writing_Workshop.pdf b/S24/document/recitation/Recitation4/Paper_Writing_Workshop.pdf
deleted file mode 100644
index ed786322..00000000
Binary files a/S24/document/recitation/Recitation4/Paper_Writing_Workshop.pdf and /dev/null differ
diff --git a/S24/document/recitation/Recitation4/idl_recitation4_F22.pdf b/S24/document/recitation/Recitation4/idl_recitation4_F22.pdf
deleted file mode 100644
index cdfb47a2..00000000
Binary files a/S24/document/recitation/Recitation4/idl_recitation4_F22.pdf and /dev/null differ
diff --git a/S24/document/recitation/Recitation5/S23_IDL_ Recitation 5.pdf b/S24/document/recitation/Recitation5/S23_IDL_ Recitation 5.pdf
deleted file mode 100644
index 06356317..00000000
Binary files a/S24/document/recitation/Recitation5/S23_IDL_ Recitation 5.pdf and /dev/null differ
diff --git a/S24/document/recitation/Recitation6/IDL S'23 Recitation 6_CNN Classification.pdf b/S24/document/recitation/Recitation6/IDL S'23 Recitation 6_CNN Classification.pdf
deleted file mode 100644
index bd347a52..00000000
Binary files a/S24/document/recitation/Recitation6/IDL S'23 Recitation 6_CNN Classification.pdf and /dev/null differ
diff --git a/S24/document/recitation/Recitation7/IDL_S23_Recitation_7_CNN_Verification.pdf b/S24/document/recitation/Recitation7/IDL_S23_Recitation_7_CNN_Verification.pdf
deleted file mode 100644
index 9d92bb6a..00000000
Binary files a/S24/document/recitation/Recitation7/IDL_S23_Recitation_7_CNN_Verification.pdf and /dev/null differ
diff --git a/S24/document/recitation/Recitation8/IDL_S23_Recitation_8__RNN_Basics.pdf b/S24/document/recitation/Recitation8/IDL_S23_Recitation_8__RNN_Basics.pdf
deleted file mode 100644
index 693fcf40..00000000
Binary files a/S24/document/recitation/Recitation8/IDL_S23_Recitation_8__RNN_Basics.pdf and /dev/null differ
diff --git a/S24/document/recitation/Recitation8/Recitation 8_RNN Basics.pptx b/S24/document/recitation/Recitation8/Recitation 8_RNN Basics.pptx
deleted file mode 100644
index 648dbca0..00000000
Binary files a/S24/document/recitation/Recitation8/Recitation 8_RNN Basics.pptx and /dev/null differ
diff --git a/S24/document/recitation/Recitation8/language_model-2.ipynb b/S24/document/recitation/Recitation8/language_model-2.ipynb
deleted file mode 100644
index 2d347deb..00000000
--- a/S24/document/recitation/Recitation8/language_model-2.ipynb
+++ /dev/null
@@ -1,2112 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pFvgJbAu50m8"
- },
- "source": [
- "# Shakespeare Character Language Model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 36
- },
- "id": "mcIAFm9g50m9",
- "outputId": "3a72c4cb-aae6-4078-c143-d2cd1545851a"
- },
- "outputs": [
- {
- "data": {
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "string"
- },
- "text/plain": [
- "'cuda'"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import torch\n",
- "import torch.nn as nn\n",
- "import torch.nn.utils.rnn as rnn\n",
- "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
- "import numpy as np\n",
- "import time\n",
- "\n",
- "import shakespeare_data as sh\n",
- "\n",
- "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "DEVICE"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gN0cVBCS50nB"
- },
- "source": [
- "## Fixed length input"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "uFhKFJEN50nB",
- "outputId": "04863cf3-2215-4539-86dc-bf56ad09b78a",
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "First 203 characters...Last 50 characters\n",
- "1609\n",
- " THE SONNETS\n",
- " by William Shakespeare\n",
- " 1\n",
- " From fairest creatures we desire increase,\n",
- " That thereby beauty's rose might never die,\n",
- " But as the riper should by time decease,\n",
- "...,\n",
- " And new pervert a reconciled maid.'\n",
- " THE END\n",
- "\n",
- "Total character count: 5551930\n",
- "Unique character count: 84\n",
- "\n",
- "shakespeare_array.shape: (5551930,)\n",
- "\n",
- "First 17 characters as indices [12 17 11 20 0 1 45 33 30 1 44 40 39 39 30 45 44]\n",
- "First 17 characters as characters: ['1', '6', '0', '9', '\\n', ' ', 'T', 'H', 'E', ' ', 'S', 'O', 'N', 'N', 'E', 'T', 'S']\n",
- "First 17 character indices as text:\n",
- " 1609\n",
- " THE SONNETS\n"
- ]
- }
- ],
- "source": [
- "# Data - refer to shakespeare_data.py for details\n",
- "corpus = sh.read_corpus()\n",
- "print(\"First 203 characters...Last 50 characters\")\n",
- "print(\"{}...{}\".format(corpus[:203], corpus[-50:]))\n",
- "print(\"Total character count: {}\".format(len(corpus)))\n",
- "chars, charmap = sh.get_charmap(corpus)\n",
- "charcount = len(chars)\n",
- "print(\"Unique character count: {}\\n\".format(len(chars)))\n",
- "shakespeare_array = sh.map_corpus(corpus, charmap)\n",
- "print(\"shakespeare_array.shape: {}\\n\".format(shakespeare_array.shape))\n",
- "small_example = shakespeare_array[:17]\n",
- "print(\"First 17 characters as indices\", small_example)\n",
- "print(\"First 17 characters as characters:\", [chars[c] for c in small_example])\n",
- "print(\"First 17 character indices as text:\\n\", sh.to_text(small_example,chars))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "DBcpz6iD50nD"
- },
- "outputs": [],
- "source": [
- "# Dataset class. Transform raw text into a set of sequences of fixed length, and extracts inputs and targets\n",
- "class TextDataset(Dataset):\n",
- " \n",
- " def __init__(self,text, seq_len = 200):\n",
- " n_seq = len(text) // seq_len\n",
- " text = text[:n_seq * seq_len]\n",
- " self.data = torch.tensor(text).view(-1,seq_len)\n",
- " \n",
- " def __getitem__(self,i):\n",
- " txt = self.data[i]\n",
- " \n",
- " # labels are the input sequence shifted by 1\n",
- " return txt[:-1],txt[1:]\n",
- " \n",
- " def __len__(self):\n",
- " return self.data.size(0)\n",
- "\n",
- "# Collate function. Transform a list of sequences into a batch. Passed as an argument to the DataLoader.\n",
- "# Returns data of the format seq_len x batch_size\n",
- "def collate(seq_list):\n",
- " inputs = torch.cat([s[0].unsqueeze(1) for s in seq_list],dim=1)\n",
- " targets = torch.cat([s[1].unsqueeze(1) for s in seq_list],dim=1)\n",
- " return inputs,targets\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "iHb5PHQs50nF"
- },
- "outputs": [],
- "source": [
- "# Model\n",
- "class CharLanguageModel(nn.Module):\n",
- "\n",
- " def __init__(self,vocab_size,embed_size,hidden_size, nlayers):\n",
- " super(CharLanguageModel,self).__init__()\n",
- " self.vocab_size=vocab_size\n",
- " self.embed_size = embed_size\n",
- " self.hidden_size = hidden_size\n",
- " self.nlayers=nlayers\n",
- " self.embedding = nn.Embedding(vocab_size,embed_size) # Embedding layer\n",
- " self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # Recurrent network\n",
- " # You can also try GRUs instead of LSTMs.\n",
- " \n",
- " self.scoring = nn.Linear(hidden_size,vocab_size) # Projection layer\n",
- " \n",
- " def forward(self,seq_batch): #L x N\n",
- " # returns 3D logits\n",
- " batch_size = seq_batch.size(1)\n",
- " embed = self.embedding(seq_batch) #L x N x E\n",
- " hidden = None\n",
- " output_lstm,hidden = self.rnn(embed,hidden) #L x N x H\n",
- " output_lstm_flatten = output_lstm.view(-1,self.hidden_size) #(L*N) x H\n",
- " output_flatten = self.scoring(output_lstm_flatten) #(L*N) x V\n",
- " return output_flatten.view(-1,batch_size,self.vocab_size)\n",
- " \n",
- " def generate(self,seq, n_chars): # L x V\n",
- " # performs greedy search to extract and return words (one sequence).\n",
- " generated_chars = []\n",
- " embed = self.embedding(seq).unsqueeze(1) # L x 1 x E\n",
- " hidden = None\n",
- " output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H\n",
- " output = output_lstm[-1] # 1 x H\n",
- " scores = self.scoring(output) # 1 x V\n",
- " _,current_char = torch.max(scores,dim=1) # 1 x 1\n",
- " generated_chars.append(current_char)\n",
- " if n_chars > 1:\n",
- " for i in range(n_chars-1):\n",
- " embed = self.embedding(current_char).unsqueeze(0) # 1 x 1 x E\n",
- " output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H\n",
- " output = output_lstm[0] # 1 x H\n",
- " scores = self.scoring(output) # V\n",
- " _,current_char = torch.max(scores,dim=1) # 1\n",
- " generated_chars.append(current_char)\n",
- " return torch.cat(generated_chars,dim=0)\n",
- " \n",
- " "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "QRxGHF6E50nH"
- },
- "outputs": [],
- "source": [
- "def train_epoch(model, optimizer, train_loader, val_loader):\n",
- " criterion = nn.CrossEntropyLoss()\n",
- " criterion = criterion.to(DEVICE)\n",
- " before = time.time()\n",
- " print(\"training\", len(train_loader), \"number of batches\")\n",
- " for batch_idx, (inputs,targets) in enumerate(train_loader):\n",
- " if batch_idx == 0:\n",
- " first_time = time.time()\n",
- " inputs = inputs.to(DEVICE)\n",
- " targets = targets.to(DEVICE)\n",
- " outputs = model(inputs) # 3D\n",
- " loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1)) # Loss of the flattened outputs\n",
- " optimizer.zero_grad()\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " \n",
- " if batch_idx == 0:\n",
- " print(\"Time elapsed\", time.time() - first_time)\n",
- " \n",
- " if batch_idx % 100 == 0 and batch_idx != 0:\n",
- " after = time.time()\n",
- " print(\"Time: \", after - before)\n",
- " print(\"Loss per word: \", loss.item() / batch_idx)\n",
- " print(\"Perplexity: \", np.exp(loss.item() / batch_idx))\n",
- " after = before\n",
- " \n",
- " val_loss = 0\n",
- " batch_id=0\n",
- " for inputs,targets in val_loader:\n",
- " batch_id+=1\n",
- " inputs = inputs.to(DEVICE)\n",
- " targets = targets.to(DEVICE)\n",
- " outputs = model(inputs)\n",
- " loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1))\n",
- " val_loss+=loss.item()\n",
- " val_lpw = val_loss / batch_id\n",
- " print(\"\\nValidation loss per word:\",val_lpw)\n",
- " print(\"Validation perplexity :\",np.exp(val_lpw),\"\\n\")\n",
- " return val_lpw\n",
- " "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "RNHa3FAU50nI"
- },
- "outputs": [],
- "source": [
- "model = CharLanguageModel(charcount,256,256,3)\n",
- "model = model.to(DEVICE)\n",
- "optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)\n",
- "split = 5000000\n",
- "train_dataset = TextDataset(shakespeare_array[:split])\n",
- "val_dataset = TextDataset(shakespeare_array[split:])\n",
- "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate)\n",
- "val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate, drop_last=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "VVROzTRT50nK",
- "outputId": "2cf1af8e-19ac-4e60-a6b4-8180e41df5d1",
- "scrolled": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "training 391 number of batches\n",
- "Time elapsed 0.07767295837402344\n",
- "Time: 5.596414089202881\n",
- "Loss per word: 0.01185429334640503\n",
- "Perplexity: 1.0119248339425135\n",
- "Time: 11.145010232925415\n",
- "Loss per word: 0.005843929648399353\n",
- "Perplexity: 1.0058610387170948\n",
- "Time: 16.780640840530396\n",
- "Loss per word: 0.003932354052861532\n",
- "Perplexity: 1.0039400959016305\n",
- "\n",
- "Validation loss per word: 1.3321008599081705\n",
- "Validation perplexity : 3.7889951799089627 \n",
- "\n",
- "training 391 number of batches\n",
- "Time elapsed 0.04892563819885254\n",
- "Time: 5.784719467163086\n",
- "Loss per word: 0.011950627565383912\n",
- "Perplexity: 1.0120223216266813\n",
- "Time: 11.487154722213745\n",
- "Loss per word: 0.005958112478256225\n",
- "Perplexity: 1.0059758973342545\n",
- "Time: 17.114842653274536\n",
- "Loss per word: 0.003922495444615682\n",
- "Perplexity: 1.0039301984983102\n",
- "\n",
- "Validation loss per word: 1.3247673289720403\n",
- "Validation perplexity : 3.761310105292561 \n",
- "\n",
- "training 391 number of batches\n",
- "Time elapsed 0.052561044692993164\n",
- "Time: 5.5854432582855225\n",
- "Loss per word: 0.011668695211410523\n",
- "Perplexity: 1.0117370400082202\n",
- "Time: 11.07470154762268\n",
- "Loss per word: 0.005841174721717834\n",
- "Perplexity: 1.0058582676474983\n",
- "Time: 16.553176164627075\n",
- "Loss per word: 0.003922032912572225\n",
- "Perplexity: 1.0039297341485314\n",
- "\n",
- "Validation loss per word: 1.3197767152342685\n",
- "Validation perplexity : 3.742585621604832 \n",
- "\n",
- "training 391 number of batches\n",
- "Time elapsed 0.049115657806396484\n",
- "Time: 5.514904737472534\n",
- "Loss per word: 0.011583248376846314\n",
- "Perplexity: 1.0116505939740628\n",
- "Time: 10.967044353485107\n",
- "Loss per word: 0.005864846110343933\n",
- "Perplexity: 1.0058820779912654\n",
- "Time: 16.433537483215332\n",
- "Loss per word: 0.004020507335662842\n",
- "Perplexity: 1.0040286004177446\n",
- "\n",
- "Validation loss per word: 1.313875392425892\n",
- "Validation perplexity : 3.720564456623678 \n",
- "\n",
- "training 391 number of batches\n",
- "Time elapsed 0.049449920654296875\n",
- "Time: 5.5508527755737305\n",
- "Loss per word: 0.011253877878189086\n",
- "Perplexity: 1.011317440981854\n",
- "Time: 11.068422079086304\n",
- "Loss per word: 0.005857662558555603\n",
- "Perplexity: 1.0058748522112186\n",
- "Time: 16.625342845916748\n",
- "Loss per word: 0.003796435991923014\n",
- "Perplexity: 1.0038036515833308\n",
- "\n",
- "Validation loss per word: 1.30533528050711\n",
- "Validation perplexity : 3.6889257112697154 \n",
- "\n"
- ]
- }
- ],
- "source": [
- "for i in range(5):\n",
- " train_epoch(model, optimizer, train_loader, val_loader)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Xd-bklsK50nM"
- },
- "outputs": [],
- "source": [
- "def generate(model, seed,nchars):\n",
- " seq = sh.map_corpus(seed, charmap)\n",
- " seq = torch.tensor(seq).to(DEVICE)\n",
- " out = model.generate(seq,nchars)\n",
- " return sh.to_text(out.cpu().detach().numpy(),chars)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "V-Sp34eF50nN",
- "outputId": "85a633de-af9d-405b-b2e5-4ad1a8e41998"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "uestion\n",
- "\n"
- ]
- }
- ],
- "source": [
- "print(generate(model, \"To be, or not to be, that is the q\",8))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "3WPIhJur50nP",
- "outputId": "ac980873-12e5-48d3-8a3e-866f04d09768"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "and the King of the compolutes\n",
- " \n"
- ]
- }
- ],
- "source": [
- "print(generate(model, \"Richard \", 1000))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "OITobxJ_50nS"
- },
- "source": [
- "## Packed sequences"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "fZshan9w50nS",
- "outputId": "92b2741d-e08d-4ac7-ff44-6cbea43420b2"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "1609\n",
- "\n",
- " THE SONNETS\n",
- "\n",
- " by William Shakespeare\n",
- "\n",
- " 1\n",
- "\n",
- " From fairest creatures we desire increase,\n",
- "\n",
- " That thereby beauty's rose might never die,\n",
- "\n",
- " But as the riper should by time decease,\n",
- "\n",
- " His tender heir might bear his memory:\n",
- "\n",
- " But thou contracted to thine own bright eyes,\n",
- "\n",
- " Feed'st thy light's flame with self-substantial fuel,\n",
- "\n",
- "114638\n"
- ]
- }
- ],
- "source": [
- "stop_character = charmap['\\n']\n",
- "space_character = charmap[\" \"]\n",
- "lines = np.split(shakespeare_array, np.where(shakespeare_array == stop_character)[0]+1) # split the data in lines\n",
- "shakespeare_lines = []\n",
- "for s in lines:\n",
- " s_trimmed = np.trim_zeros(s-space_character)+space_character # remove space-only lines\n",
- " if len(s_trimmed)>1:\n",
- " shakespeare_lines.append(s)\n",
- "for i in range(10):\n",
- " print(sh.to_text(shakespeare_lines[i],chars))\n",
- "print(len(shakespeare_lines))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "J1UckTXZ50nU"
- },
- "outputs": [],
- "source": [
- "class LinesDataset(Dataset):\n",
- " def __init__(self,lines):\n",
- " self.lines=[torch.tensor(l) for l in lines]\n",
- " def __getitem__(self,i):\n",
- " line = self.lines[i]\n",
- " return line[:-1].to(DEVICE),line[1:].to(DEVICE)\n",
- " def __len__(self):\n",
- " return len(self.lines)\n",
- "\n",
- "# collate fn lets you control the return value of each batch\n",
- "# for packed_seqs, you want to return your data sorted by length\n",
- "def collate_lines(seq_list):\n",
- " inputs,targets = zip(*seq_list)\n",
- " lens = [len(seq) for seq in inputs]\n",
- " seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)\n",
- " inputs = [inputs[i] for i in seq_order]\n",
- " targets = [targets[i] for i in seq_order]\n",
- " return inputs,targets"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "cg8ln5cG50nW"
- },
- "outputs": [],
- "source": [
- "# Model that takes packed sequences in training\n",
- "class PackedLanguageModel(nn.Module):\n",
- " \n",
- " def __init__(self,vocab_size,embed_size,hidden_size, nlayers, stop):\n",
- " super(PackedLanguageModel,self).__init__()\n",
- " self.vocab_size=vocab_size\n",
- " self.embed_size = embed_size\n",
- " self.hidden_size = hidden_size\n",
- " self.nlayers=nlayers\n",
- " self.embedding = nn.Embedding(vocab_size,embed_size)\n",
- " self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # 1 layer, batch_size = False\n",
- " self.scoring = nn.Linear(hidden_size,vocab_size)\n",
- " self.stop = stop # stop line character (\\n)\n",
- " \n",
- " def forward(self,seq_list): # list\n",
- " batch_size = len(seq_list)\n",
- " lens = [len(s) for s in seq_list] # lens of all lines (already sorted)\n",
- " bounds = [0]\n",
- " for l in lens:\n",
- " bounds.append(bounds[-1]+l) # bounds of all lines in the concatenated sequence. Indexing into the list to \n",
- " # see where the sequence occurs. Need this at line marked **\n",
- " seq_concat = torch.cat(seq_list) # concatenated sequence\n",
- " embed_concat = self.embedding(seq_concat) # concatenated embeddings\n",
- " embed_list = [embed_concat[bounds[i]:bounds[i+1]] for i in range(batch_size)] # embeddings per line **\n",
- " packed_input = rnn.pack_sequence(embed_list) # packed version\n",
- " \n",
- " # alternatively, you could use rnn.pad_sequence, followed by rnn.pack_padded_sequence\n",
- " \n",
- " \n",
- " \n",
- " hidden = None\n",
- " output_packed,hidden = self.rnn(packed_input,hidden)\n",
- " output_padded, _ = rnn.pad_packed_sequence(output_packed) # unpacked output (padded). Also gives you the lengths\n",
- " output_flatten = torch.cat([output_padded[:lens[i],i] for i in range(batch_size)]) # concatenated output\n",
- " scores_flatten = self.scoring(output_flatten) # concatenated logits\n",
- " return scores_flatten # return concatenated logits\n",
- " \n",
- " def generate(self,seq, n_words): # L x V\n",
- " generated_words = []\n",
- " embed = self.embedding(seq).unsqueeze(1) # L x 1 x E\n",
- " hidden = None\n",
- " output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H\n",
- " output = output_lstm[-1] # 1 x H\n",
- " scores = self.scoring(output) # 1 x V\n",
- " _,current_word = torch.max(scores,dim=1) # 1 x 1\n",
- " generated_words.append(current_word)\n",
- " if n_words > 1:\n",
- " for i in range(n_words-1):\n",
- " embed = self.embedding(current_word).unsqueeze(0) # 1 x 1 x E\n",
- " output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H\n",
- " output = output_lstm[0] # 1 x H\n",
- " scores = self.scoring(output) # V\n",
- " _,current_word = torch.max(scores,dim=1) # 1\n",
- " generated_words.append(current_word)\n",
- " if current_word[0].item()==self.stop: # If end of line\n",
- " break\n",
- " return torch.cat(generated_words,dim=0)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "6vx3G8mc50nY"
- },
- "outputs": [],
- "source": [
- "def train_epoch_packed(model, optimizer, train_loader, val_loader):\n",
- " criterion = nn.CrossEntropyLoss(reduction=\"sum\") # sum instead of averaging, to take into account the different lengths\n",
- " criterion = criterion.to(DEVICE)\n",
- " batch_id=0\n",
- " before = time.time()\n",
- " print(\"Training\", len(train_loader), \"number of batches\")\n",
- " for inputs,targets in train_loader: # lists, presorted, preloaded on GPU\n",
- " batch_id+=1\n",
- " outputs = model(inputs)\n",
- " loss = criterion(outputs,torch.cat(targets)) # criterion of the concatenated output\n",
- " optimizer.zero_grad()\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " if batch_id % 100 == 0:\n",
- " after = time.time()\n",
- " nwords = np.sum(np.array([len(l) for l in inputs]))\n",
- " lpw = loss.item() / nwords\n",
- " print(\"Time elapsed: \", after - before)\n",
- " print(\"At batch\",batch_id)\n",
- " print(\"Training loss per word:\",lpw)\n",
- " print(\"Training perplexity :\",np.exp(lpw))\n",
- " before = after\n",
- " \n",
- " val_loss = 0\n",
- " batch_id=0\n",
- " nwords = 0\n",
- " for inputs,targets in val_loader:\n",
- " nwords += np.sum(np.array([len(l) for l in inputs]))\n",
- " batch_id+=1\n",
- " outputs = model(inputs)\n",
- " loss = criterion(outputs,torch.cat(targets))\n",
- " val_loss+=loss.item()\n",
- " val_lpw = val_loss / nwords\n",
- " print(\"\\nValidation loss per word:\",val_lpw)\n",
- " print(\"Validation perplexity :\",np.exp(val_lpw),\"\\n\")\n",
- " return val_lpw"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "FFvvuete50na"
- },
- "outputs": [],
- "source": [
- "model = PackedLanguageModel(charcount,256,256,3, stop=stop_character)\n",
- "model = model.to(DEVICE)\n",
- "optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)\n",
- "split = 100000\n",
- "train_dataset = LinesDataset(shakespeare_lines[:split])\n",
- "val_dataset = LinesDataset(shakespeare_lines[split:])\n",
- "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate_lines)\n",
- "val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate_lines, drop_last=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "4SFfNTCL50nb",
- "outputId": "dceb7e6a-3f15-46d9-b2e7-82041b108222",
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7703583240509033\n",
- "At batch 100\n",
- "Training loss per word: 2.7013041346193503\n",
- "Training perplexity : 14.899149557123419\n",
- "Time elapsed: 3.691027879714966\n",
- "At batch 200\n",
- "Training loss per word: 2.197311673753789\n",
- "Training perplexity : 9.000783901895156\n",
- "Time elapsed: 3.6885509490966797\n",
- "At batch 300\n",
- "Training loss per word: 1.9548370171777951\n",
- "Training perplexity : 7.062767820058658\n",
- "Time elapsed: 3.6700551509857178\n",
- "At batch 400\n",
- "Training loss per word: 1.7970468966562716\n",
- "Training perplexity : 6.03180858325135\n",
- "Time elapsed: 3.7132349014282227\n",
- "At batch 500\n",
- "Training loss per word: 1.8254557599342311\n",
- "Training perplexity : 6.205622648866053\n",
- "Time elapsed: 3.7091219425201416\n",
- "At batch 600\n",
- "Training loss per word: 1.8306482488458806\n",
- "Training perplexity : 6.237929078461925\n",
- "Time elapsed: 3.708371639251709\n",
- "At batch 700\n",
- "Training loss per word: 1.7147244232747705\n",
- "Training perplexity : 5.555144432895182\n",
- "Time elapsed: 3.713968515396118\n",
- "At batch 800\n",
- "Training loss per word: 1.6335569952082643\n",
- "Training perplexity : 5.12206150243047\n",
- "Time elapsed: 3.7396087646484375\n",
- "At batch 900\n",
- "Training loss per word: 1.5792983890503876\n",
- "Training perplexity : 4.851550715746207\n",
- "Time elapsed: 3.7228660583496094\n",
- "At batch 1000\n",
- "Training loss per word: 1.621026875629406\n",
- "Training perplexity : 5.058281876950357\n",
- "Time elapsed: 3.7168807983398438\n",
- "At batch 1100\n",
- "Training loss per word: 1.5696750812099223\n",
- "Training perplexity : 4.805086677156511\n",
- "Time elapsed: 3.7677266597747803\n",
- "At batch 1200\n",
- "Training loss per word: 1.5301319406920428\n",
- "Training perplexity : 4.6187861879448695\n",
- "Time elapsed: 3.72387433052063\n",
- "At batch 1300\n",
- "Training loss per word: 1.5210934303396073\n",
- "Training perplexity : 4.577227339139673\n",
- "Time elapsed: 3.749553918838501\n",
- "At batch 1400\n",
- "Training loss per word: 1.4983956473214286\n",
- "Training perplexity : 4.474504625206471\n",
- "Time elapsed: 3.761760711669922\n",
- "At batch 1500\n",
- "Training loss per word: 1.4545306382094139\n",
- "Training perplexity : 4.282472964634041\n",
- "\n",
- "Validation loss per word: 1.5616251907288874\n",
- "Validation perplexity : 4.766561525318013 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7287967205047607\n",
- "At batch 100\n",
- "Training loss per word: 1.4484571357280522\n",
- "Training perplexity : 4.256542179510825\n",
- "Time elapsed: 3.7227351665496826\n",
- "At batch 200\n",
- "Training loss per word: 1.4475157597876576\n",
- "Training perplexity : 4.252537058571778\n",
- "Time elapsed: 3.7012224197387695\n",
- "At batch 300\n",
- "Training loss per word: 1.4798200545277405\n",
- "Training perplexity : 4.392155261351937\n",
- "Time elapsed: 3.7456235885620117\n",
- "At batch 400\n",
- "Training loss per word: 1.4387576705451068\n",
- "Training perplexity : 4.215455577988563\n",
- "Time elapsed: 3.728778123855591\n",
- "At batch 500\n",
- "Training loss per word: 1.3787083915855327\n",
- "Training perplexity : 3.9697709252483686\n",
- "Time elapsed: 3.708531141281128\n",
- "At batch 600\n",
- "Training loss per word: 1.4090683109504132\n",
- "Training perplexity : 4.092141024457387\n",
- "Time elapsed: 3.714569091796875\n",
- "At batch 700\n",
- "Training loss per word: 1.4055531548691633\n",
- "Training perplexity : 4.0777817623593124\n",
- "Time elapsed: 3.7208497524261475\n",
- "At batch 800\n",
- "Training loss per word: 1.2926560013746777\n",
- "Training perplexity : 3.6424480666522805\n",
- "Time elapsed: 3.730631113052368\n",
- "At batch 900\n",
- "Training loss per word: 1.3833357829224782\n",
- "Training perplexity : 3.988183176328358\n",
- "Time elapsed: 3.714390277862549\n",
- "At batch 1000\n",
- "Training loss per word: 1.331831748307841\n",
- "Training perplexity : 3.787975654541661\n",
- "Time elapsed: 3.7247533798217773\n",
- "At batch 1100\n",
- "Training loss per word: 1.292771070749634\n",
- "Training perplexity : 3.6428672249903027\n",
- "Time elapsed: 3.727574348449707\n",
- "At batch 1200\n",
- "Training loss per word: 1.3829026442307693\n",
- "Training perplexity : 3.9864561139408403\n",
- "Time elapsed: 3.6857106685638428\n",
- "At batch 1300\n",
- "Training loss per word: 1.405520340086133\n",
- "Training perplexity : 4.0776479530310095\n",
- "Time elapsed: 3.7007181644439697\n",
- "At batch 1400\n",
- "Training loss per word: 1.2932184916919702\n",
- "Training perplexity : 3.6444974847558975\n",
- "Time elapsed: 3.7181015014648438\n",
- "At batch 1500\n",
- "Training loss per word: 1.3205488685079587\n",
- "Training perplexity : 3.7454765873353084\n",
- "\n",
- "Validation loss per word: 1.4662550231819402\n",
- "Validation perplexity : 4.3329778169347035 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7183938026428223\n",
- "At batch 100\n",
- "Training loss per word: 1.407801974281519\n",
- "Training perplexity : 4.086962275935464\n",
- "Time elapsed: 3.736786127090454\n",
- "At batch 200\n",
- "Training loss per word: 1.270954735225231\n",
- "Training perplexity : 3.564253857138264\n",
- "Time elapsed: 3.7378339767456055\n",
- "At batch 300\n",
- "Training loss per word: 1.3011009506032436\n",
- "Training perplexity : 3.67333860657824\n",
- "Time elapsed: 3.6993186473846436\n",
- "At batch 400\n",
- "Training loss per word: 1.3311979819283615\n",
- "Training perplexity : 3.785575723503658\n",
- "Time elapsed: 3.7243549823760986\n",
- "At batch 500\n",
- "Training loss per word: 1.3467497694500334\n",
- "Training perplexity : 3.844908361255419\n",
- "Time elapsed: 3.6973633766174316\n",
- "At batch 600\n",
- "Training loss per word: 1.245755025093129\n",
- "Training perplexity : 3.475557942401932\n",
- "Time elapsed: 3.7047317028045654\n",
- "At batch 700\n",
- "Training loss per word: 1.302772730860903\n",
- "Training perplexity : 3.6794847576159344\n",
- "Time elapsed: 3.7606568336486816\n",
- "At batch 800\n",
- "Training loss per word: 1.3673185022865855\n",
- "Training perplexity : 3.9248121973730536\n",
- "Time elapsed: 3.712702751159668\n",
- "At batch 900\n",
- "Training loss per word: 1.2946469603466386\n",
- "Training perplexity : 3.6497072552859513\n",
- "Time elapsed: 3.7113733291625977\n",
- "At batch 1000\n",
- "Training loss per word: 1.3207825568993778\n",
- "Training perplexity : 3.746351964012801\n",
- "Time elapsed: 3.7720985412597656\n",
- "At batch 1100\n",
- "Training loss per word: 1.2613232793426998\n",
- "Training perplexity : 3.5300896927826253\n",
- "Time elapsed: 3.7074599266052246\n",
- "At batch 1200\n",
- "Training loss per word: 1.2277225235133495\n",
- "Training perplexity : 3.413446632519601\n",
- "Time elapsed: 3.7118020057678223\n",
- "At batch 1300\n",
- "Training loss per word: 1.348728454331585\n",
- "Training perplexity : 3.852523755048424\n",
- "Time elapsed: 3.7075886726379395\n",
- "At batch 1400\n",
- "Training loss per word: 1.2741113349374347\n",
- "Training perplexity : 3.5755225558666193\n",
- "Time elapsed: 3.7233262062072754\n",
- "At batch 1500\n",
- "Training loss per word: 1.2676521634823907\n",
- "Training perplexity : 3.552502069307952\n",
- "\n",
- "Validation loss per word: 1.437432799953033\n",
- "Validation perplexity : 4.209874342884608 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7521398067474365\n",
- "At batch 100\n",
- "Training loss per word: 1.2984101040072107\n",
- "Training perplexity : 3.6634675026619434\n",
- "Time elapsed: 3.6541831493377686\n",
- "At batch 200\n",
- "Training loss per word: 1.2519192490281101\n",
- "Training perplexity : 3.4970482272737584\n",
- "Time elapsed: 3.718026638031006\n",
- "At batch 300\n",
- "Training loss per word: 1.3276140372983871\n",
- "Training perplexity : 3.77203271291436\n",
- "Time elapsed: 3.684814453125\n",
- "At batch 400\n",
- "Training loss per word: 1.2439926560950925\n",
- "Training perplexity : 3.469438121109008\n",
- "Time elapsed: 3.7279927730560303\n",
- "At batch 500\n",
- "Training loss per word: 1.2611345896068025\n",
- "Training perplexity : 3.5294236639291805\n",
- "Time elapsed: 3.7200450897216797\n",
- "At batch 600\n",
- "Training loss per word: 1.2275730937179705\n",
- "Training perplexity : 3.4129365999957435\n",
- "Time elapsed: 3.719654083251953\n",
- "At batch 700\n",
- "Training loss per word: 1.2286639347254673\n",
- "Training perplexity : 3.4166616025183822\n",
- "Time elapsed: 3.7141995429992676\n",
- "At batch 800\n",
- "Training loss per word: 1.293564885779272\n",
- "Training perplexity : 3.6457601358106078\n",
- "Time elapsed: 3.7546281814575195\n",
- "At batch 900\n",
- "Training loss per word: 1.4087229681558935\n",
- "Training perplexity : 4.090728077030081\n",
- "Time elapsed: 3.700044870376587\n",
- "At batch 1000\n",
- "Training loss per word: 1.2707874804559762\n",
- "Training perplexity : 3.563657768532543\n",
- "Time elapsed: 3.717473030090332\n",
- "At batch 1100\n",
- "Training loss per word: 1.2463651149392985\n",
- "Training perplexity : 3.477678991961975\n",
- "Time elapsed: 3.7167797088623047\n",
- "At batch 1200\n",
- "Training loss per word: 1.3100708621231156\n",
- "Training perplexity : 3.7064363488534666\n",
- "Time elapsed: 3.720799446105957\n",
- "At batch 1300\n",
- "Training loss per word: 1.2738981620998784\n",
- "Training perplexity : 3.574760432812492\n",
- "Time elapsed: 3.6796560287475586\n",
- "At batch 1400\n",
- "Training loss per word: 1.3129865208128482\n",
- "Training perplexity : 3.717258821853787\n",
- "Time elapsed: 3.6994059085845947\n",
- "At batch 1500\n",
- "Training loss per word: 1.2379438652061638\n",
- "Training perplexity : 3.448515557311611\n",
- "\n",
- "Validation loss per word: 1.401114204142619\n",
- "Validation perplexity : 4.059720805547787 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7109155654907227\n",
- "At batch 100\n",
- "Training loss per word: 1.2433489549446974\n",
- "Training perplexity : 3.4672055584278976\n",
- "Time elapsed: 3.6977715492248535\n",
- "At batch 200\n",
- "Training loss per word: 1.2903910941613195\n",
- "Training perplexity : 3.6342075952260915\n",
- "Time elapsed: 3.7268831729888916\n",
- "At batch 300\n",
- "Training loss per word: 1.1544621405909903\n",
- "Training perplexity : 3.1723166981967363\n",
- "Time elapsed: 3.6845014095306396\n",
- "At batch 400\n",
- "Training loss per word: 1.2515151961482813\n",
- "Training perplexity : 3.4956355202900102\n",
- "Time elapsed: 3.7195098400115967\n",
- "At batch 500\n",
- "Training loss per word: 1.2464558344057315\n",
- "Training perplexity : 3.4779944994556704\n",
- "Time elapsed: 3.6965761184692383\n",
- "At batch 600\n",
- "Training loss per word: 1.2064722318284709\n",
- "Training perplexity : 3.3416751789182\n",
- "Time elapsed: 3.7104105949401855\n",
- "At batch 700\n",
- "Training loss per word: 1.2265445738779892\n",
- "Training perplexity : 3.409428131564032\n",
- "Time elapsed: 3.7035112380981445\n",
- "At batch 800\n",
- "Training loss per word: 1.1967898446772212\n",
- "Training perplexity : 3.3094759204978104\n",
- "Time elapsed: 3.7424962520599365\n",
- "At batch 900\n",
- "Training loss per word: 1.234192673040896\n",
- "Training perplexity : 3.4356037452559796\n",
- "Time elapsed: 3.68930983543396\n",
- "At batch 1000\n",
- "Training loss per word: 1.2543550637093812\n",
- "Training perplexity : 3.5055767714466266\n",
- "Time elapsed: 3.7369351387023926\n",
- "At batch 1100\n",
- "Training loss per word: 1.1588790893554688\n",
- "Training perplexity : 3.1863596491840465\n",
- "Time elapsed: 3.718149423599243\n",
- "At batch 1200\n",
- "Training loss per word: 1.2224001718996451\n",
- "Training perplexity : 3.395327330746977\n",
- "Time elapsed: 3.704207420349121\n",
- "At batch 1300\n",
- "Training loss per word: 1.2001046291157973\n",
- "Training perplexity : 3.320464321808229\n",
- "Time elapsed: 3.7240679264068604\n",
- "At batch 1400\n",
- "Training loss per word: 1.2382448136659627\n",
- "Training perplexity : 3.4495535389388285\n",
- "Time elapsed: 3.694542646408081\n",
- "At batch 1500\n",
- "Training loss per word: 1.1240893522690039\n",
- "Training perplexity : 3.0774131332260097\n",
- "\n",
- "Validation loss per word: 1.3893149912296887\n",
- "Validation perplexity : 4.012100787239498 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.760380983352661\n",
- "At batch 100\n",
- "Training loss per word: 1.2399032188772912\n",
- "Training perplexity : 3.455279042795036\n",
- "Time elapsed: 3.7267673015594482\n",
- "At batch 200\n",
- "Training loss per word: 1.19580353990113\n",
- "Training perplexity : 3.3062133777862766\n",
- "Time elapsed: 3.710188627243042\n",
- "At batch 300\n",
- "Training loss per word: 1.2362357496620215\n",
- "Training perplexity : 3.4426301222165097\n",
- "Time elapsed: 3.715798854827881\n",
- "At batch 400\n",
- "Training loss per word: 1.2031334475091569\n",
- "Training perplexity : 3.330536651148724\n",
- "Time elapsed: 3.7028536796569824\n",
- "At batch 500\n",
- "Training loss per word: 1.2586697915761813\n",
- "Training perplexity : 3.5207350596592186\n",
- "Time elapsed: 3.7149481773376465\n",
- "At batch 600\n",
- "Training loss per word: 1.2212349467534946\n",
- "Training perplexity : 3.3913733140689897\n",
- "Time elapsed: 3.6803011894226074\n",
- "At batch 700\n",
- "Training loss per word: 1.2101083691578483\n",
- "Training perplexity : 3.35384808654884\n",
- "Time elapsed: 3.6926045417785645\n",
- "At batch 800\n",
- "Training loss per word: 1.159741473454301\n",
- "Training perplexity : 3.1891087002772642\n",
- "Time elapsed: 3.7095448970794678\n",
- "At batch 900\n",
- "Training loss per word: 1.1972919941590259\n",
- "Training perplexity : 3.3111381894351473\n",
- "Time elapsed: 3.7010698318481445\n",
- "At batch 1000\n",
- "Training loss per word: 1.19060720627464\n",
- "Training perplexity : 3.2890777498126944\n",
- "Time elapsed: 3.7239255905151367\n",
- "At batch 1100\n",
- "Training loss per word: 1.2420086092351343\n",
- "Training perplexity : 3.462561417406009\n",
- "Time elapsed: 3.7008190155029297\n",
- "At batch 1200\n",
- "Training loss per word: 1.217229673189615\n",
- "Training perplexity : 3.3778171024794905\n",
- "Time elapsed: 3.6846933364868164\n",
- "At batch 1300\n",
- "Training loss per word: 1.140925168829449\n",
- "Training perplexity : 3.1296624923862377\n",
- "Time elapsed: 3.701658010482788\n",
- "At batch 1400\n",
- "Training loss per word: 1.1341194704659505\n",
- "Training perplexity : 3.1084352684577303\n",
- "Time elapsed: 3.7287278175354004\n",
- "At batch 1500\n",
- "Training loss per word: 1.2465296672952586\n",
- "Training perplexity : 3.478251299319346\n",
- "\n",
- "Validation loss per word: 1.3809387088436718\n",
- "Validation perplexity : 3.9786346546438813 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7069437503814697\n",
- "At batch 100\n",
- "Training loss per word: 1.1683425449489837\n",
- "Training perplexity : 3.2166567537715474\n",
- "Time elapsed: 3.7282252311706543\n",
- "At batch 200\n",
- "Training loss per word: 1.1665503661533765\n",
- "Training perplexity : 3.210897092457753\n",
- "Time elapsed: 3.734584331512451\n",
- "At batch 300\n",
- "Training loss per word: 1.178677349525876\n",
- "Training perplexity : 3.2500726486607086\n",
- "Time elapsed: 3.6963186264038086\n",
- "At batch 400\n",
- "Training loss per word: 1.2070136789678938\n",
- "Training perplexity : 3.3434850093042336\n",
- "Time elapsed: 3.705839157104492\n",
- "At batch 500\n",
- "Training loss per word: 1.1855977651235219\n",
- "Training perplexity : 3.2726425084402857\n",
- "Time elapsed: 3.7076480388641357\n",
- "At batch 600\n",
- "Training loss per word: 1.2006107478946835\n",
- "Training perplexity : 3.322145296506666\n",
- "Time elapsed: 3.6715080738067627\n",
- "At batch 700\n",
- "Training loss per word: 1.1356211530321882\n",
- "Training perplexity : 3.1131066581029954\n",
- "Time elapsed: 3.723267078399658\n",
- "At batch 800\n",
- "Training loss per word: 1.167945887890892\n",
- "Training perplexity : 3.215381097182526\n",
- "Time elapsed: 3.7093513011932373\n",
- "At batch 900\n",
- "Training loss per word: 1.1677971639537472\n",
- "Training perplexity : 3.2149029286047703\n",
- "Time elapsed: 3.6834521293640137\n",
- "At batch 1000\n",
- "Training loss per word: 1.1820589142832822\n",
- "Training perplexity : 3.261081583010059\n",
- "Time elapsed: 3.7082276344299316\n",
- "At batch 1100\n",
- "Training loss per word: 1.2310498327591308\n",
- "Training perplexity : 3.4248231411453696\n",
- "Time elapsed: 3.691101551055908\n",
- "At batch 1200\n",
- "Training loss per word: 1.227515040792063\n",
- "Training perplexity : 3.4127384747911065\n",
- "Time elapsed: 3.712756872177124\n",
- "At batch 1300\n",
- "Training loss per word: 1.1833673983071586\n",
- "Training perplexity : 3.2653514490785387\n",
- "Time elapsed: 3.6986610889434814\n",
- "At batch 1400\n",
- "Training loss per word: 1.180272545291736\n",
- "Training perplexity : 3.255261288136178\n",
- "Time elapsed: 3.7075552940368652\n",
- "At batch 1500\n",
- "Training loss per word: 1.1456190812371883\n",
- "Training perplexity : 3.1443873856349542\n",
- "\n",
- "Validation loss per word: 1.3714664597770394\n",
- "Validation perplexity : 3.941125962537374 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.714073657989502\n",
- "At batch 100\n",
- "Training loss per word: 1.2042411727831368\n",
- "Training perplexity : 3.334228014904123\n",
- "Time elapsed: 3.700242757797241\n",
- "At batch 200\n",
- "Training loss per word: 1.080187253760745\n",
- "Training perplexity : 2.9452310050149517\n",
- "Time elapsed: 3.6805882453918457\n",
- "At batch 300\n",
- "Training loss per word: 1.1463370953786645\n",
- "Training perplexity : 3.1466459109736546\n",
- "Time elapsed: 3.697406768798828\n",
- "At batch 400\n",
- "Training loss per word: 1.1752805264313133\n",
- "Training perplexity : 3.23905145594695\n",
- "Time elapsed: 3.6938788890838623\n",
- "At batch 500\n",
- "Training loss per word: 1.1683631187550978\n",
- "Training perplexity : 3.2167229333247156\n",
- "Time elapsed: 3.713092088699341\n",
- "At batch 600\n",
- "Training loss per word: 1.214711290724734\n",
- "Training perplexity : 3.369321169613484\n",
- "Time elapsed: 3.696408271789551\n",
- "At batch 700\n",
- "Training loss per word: 1.148598030821918\n",
- "Training perplexity : 3.15376832286384\n",
- "Time elapsed: 3.692023754119873\n",
- "At batch 800\n",
- "Training loss per word: 1.0949797689332248\n",
- "Training perplexity : 2.9891222096506187\n",
- "Time elapsed: 3.699280023574829\n",
- "At batch 900\n",
- "Training loss per word: 1.1116656863747953\n",
- "Training perplexity : 3.039416895638226\n",
- "Time elapsed: 3.7105231285095215\n",
- "At batch 1000\n",
- "Training loss per word: 1.1340194229200653\n",
- "Training perplexity : 3.1081242926940185\n",
- "Time elapsed: 3.709425687789917\n",
- "At batch 1100\n",
- "Training loss per word: 1.1836204710337355\n",
- "Training perplexity : 3.266177925047841\n",
- "Time elapsed: 3.7070138454437256\n",
- "At batch 1200\n",
- "Training loss per word: 1.1820404978742913\n",
- "Training perplexity : 3.261021526150891\n",
- "Time elapsed: 3.7249083518981934\n",
- "At batch 1300\n",
- "Training loss per word: 1.1580513136380806\n",
- "Training perplexity : 3.183723149405381\n",
- "Time elapsed: 3.6846048831939697\n",
- "At batch 1400\n",
- "Training loss per word: 1.1968900566087843\n",
- "Training perplexity : 3.3098075860904124\n",
- "Time elapsed: 3.7383875846862793\n",
- "At batch 1500\n",
- "Training loss per word: 1.085747366480498\n",
- "Training perplexity : 2.9616524315744126\n",
- "\n",
- "Validation loss per word: 1.3639650511099437\n",
- "Validation perplexity : 3.9116725751460977 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.731407642364502\n",
- "At batch 100\n",
- "Training loss per word: 1.138607757260101\n",
- "Training perplexity : 3.122418173594389\n",
- "Time elapsed: 3.737220287322998\n",
- "At batch 200\n",
- "Training loss per word: 1.0827404090447155\n",
- "Training perplexity : 2.952760244686936\n",
- "Time elapsed: 3.7554142475128174\n",
- "At batch 300\n",
- "Training loss per word: 1.139403432521323\n",
- "Training perplexity : 3.1249035931526095\n",
- "Time elapsed: 3.7403788566589355\n",
- "At batch 400\n",
- "Training loss per word: 1.2261186223713056\n",
- "Training perplexity : 3.4079761897648138\n",
- "Time elapsed: 3.7107059955596924\n",
- "At batch 500\n",
- "Training loss per word: 1.1567305672268908\n",
- "Training perplexity : 3.1795210340568025\n",
- "Time elapsed: 3.722501039505005\n",
- "At batch 600\n",
- "Training loss per word: 1.1871671037373888\n",
- "Training perplexity : 3.277782424777861\n",
- "Time elapsed: 3.7460691928863525\n",
- "At batch 700\n",
- "Training loss per word: 1.1577818131630053\n",
- "Training perplexity : 3.1828652501114343\n",
- "Time elapsed: 3.716709613800049\n",
- "At batch 800\n",
- "Training loss per word: 1.2216747341654488\n",
- "Training perplexity : 3.392865125377627\n",
- "Time elapsed: 3.731794834136963\n",
- "At batch 900\n",
- "Training loss per word: 1.1167742776497225\n",
- "Training perplexity : 3.0549837627980807\n",
- "Time elapsed: 3.727544069290161\n",
- "At batch 1000\n",
- "Training loss per word: 1.2042394139196992\n",
- "Training perplexity : 3.3342221504575322\n",
- "Time elapsed: 3.7003278732299805\n",
- "At batch 1100\n",
- "Training loss per word: 1.1784710476808886\n",
- "Training perplexity : 3.2494022218344703\n",
- "Time elapsed: 3.720731019973755\n",
- "At batch 1200\n",
- "Training loss per word: 1.1495366143521943\n",
- "Training perplexity : 3.156729787443522\n",
- "Time elapsed: 3.852750539779663\n",
- "At batch 1300\n",
- "Training loss per word: 1.1285352658002805\n",
- "Training perplexity : 3.091125505339989\n",
- "Time elapsed: 3.7642767429351807\n",
- "At batch 1400\n",
- "Training loss per word: 1.1840574717695236\n",
- "Training perplexity : 3.267605559120152\n",
- "Time elapsed: 3.772087335586548\n",
- "At batch 1500\n",
- "Training loss per word: 1.1927117184879235\n",
- "Training perplexity : 3.2960069428358536\n",
- "\n",
- "Validation loss per word: 1.355904031907872\n",
- "Validation perplexity : 3.8802672569020276 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7840490341186523\n",
- "At batch 100\n",
- "Training loss per word: 1.1124401816217846\n",
- "Training perplexity : 3.0417718213992053\n",
- "Time elapsed: 3.8998072147369385\n",
- "At batch 200\n",
- "Training loss per word: 1.1762464859835287\n",
- "Training perplexity : 3.242181760271219\n",
- "Time elapsed: 3.8313801288604736\n",
- "At batch 300\n",
- "Training loss per word: 1.141115539648126\n",
- "Training perplexity : 3.13025834551182\n",
- "Time elapsed: 3.8294456005096436\n",
- "At batch 400\n",
- "Training loss per word: 1.1098259095620955\n",
- "Training perplexity : 3.0338301876332396\n",
- "Time elapsed: 3.8397085666656494\n",
- "At batch 500\n",
- "Training loss per word: 1.1858041221217106\n",
- "Training perplexity : 3.273317910809078\n",
- "Time elapsed: 3.8198180198669434\n",
- "At batch 600\n",
- "Training loss per word: 1.1565715791078035\n",
- "Training perplexity : 3.179015568170599\n",
- "Time elapsed: 3.7804319858551025\n",
- "At batch 700\n",
- "Training loss per word: 1.1901554751047891\n",
- "Training perplexity : 3.287592306408844\n",
- "Time elapsed: 3.831249475479126\n",
- "At batch 800\n",
- "Training loss per word: 1.0787425135484199\n",
- "Training perplexity : 2.9409789836201603\n",
- "Time elapsed: 3.842750072479248\n",
- "At batch 900\n",
- "Training loss per word: 1.1146143615015685\n",
- "Training perplexity : 3.048392375021446\n",
- "Time elapsed: 3.820983409881592\n",
- "At batch 1000\n",
- "Training loss per word: 1.1737537865402736\n",
- "Training perplexity : 3.234110039968315\n",
- "Time elapsed: 3.8320257663726807\n",
- "At batch 1100\n",
- "Training loss per word: 1.1355072420852617\n",
- "Training perplexity : 3.112752061372295\n",
- "Time elapsed: 3.8270204067230225\n",
- "At batch 1200\n",
- "Training loss per word: 1.2335316416539759\n",
- "Training perplexity : 3.43333345379697\n",
- "Time elapsed: 3.795576572418213\n",
- "At batch 1300\n",
- "Training loss per word: 1.1775943048286124\n",
- "Training perplexity : 3.246554580169454\n",
- "Time elapsed: 3.7663586139678955\n",
- "At batch 1400\n",
- "Training loss per word: 1.1165851997106095\n",
- "Training perplexity : 3.054406187369294\n",
- "Time elapsed: 3.792233467102051\n",
- "At batch 1500\n",
- "Training loss per word: 1.0653275908562894\n",
- "Training perplexity : 2.901789428056261\n",
- "\n",
- "Validation loss per word: 1.349567706146099\n",
- "Validation perplexity : 3.8557583497301797 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.801424264907837\n",
- "At batch 100\n",
- "Training loss per word: 1.1798410001008406\n",
- "Training perplexity : 3.2538567988542693\n",
- "Time elapsed: 3.7891664505004883\n",
- "At batch 200\n",
- "Training loss per word: 1.2098545373174743\n",
- "Training perplexity : 3.352996881152783\n",
- "Time elapsed: 3.813218832015991\n",
- "At batch 300\n",
- "Training loss per word: 1.0777284977000972\n",
- "Training perplexity : 2.9379982958089275\n",
- "Time elapsed: 3.8325226306915283\n",
- "At batch 400\n",
- "Training loss per word: 1.1992148277407786\n",
- "Training perplexity : 3.317511082182175\n",
- "Time elapsed: 3.787041664123535\n",
- "At batch 500\n",
- "Training loss per word: 1.2374958790787536\n",
- "Training perplexity : 3.44697101617411\n",
- "Time elapsed: 3.8416755199432373\n",
- "At batch 600\n",
- "Training loss per word: 1.18956913508422\n",
- "Training perplexity : 3.2856652244861184\n",
- "Time elapsed: 3.7902870178222656\n",
- "At batch 700\n",
- "Training loss per word: 1.1219792623153455\n",
- "Training perplexity : 3.070926360933708\n",
- "Time elapsed: 3.7493112087249756\n",
- "At batch 800\n",
- "Training loss per word: 1.1325905098820364\n",
- "Training perplexity : 3.1036862249299535\n",
- "Time elapsed: 3.7742223739624023\n",
- "At batch 900\n",
- "Training loss per word: 1.11555278468373\n",
- "Training perplexity : 3.0512543997796473\n",
- "Time elapsed: 3.787297248840332\n",
- "At batch 1000\n",
- "Training loss per word: 1.1625877173310282\n",
- "Training perplexity : 3.1981986113029186\n",
- "Time elapsed: 3.8134567737579346\n",
- "At batch 1100\n",
- "Training loss per word: 1.1657647944307146\n",
- "Training perplexity : 3.2083756929972673\n",
- "Time elapsed: 3.766493797302246\n",
- "At batch 1200\n",
- "Training loss per word: 1.1748896632588532\n",
- "Training perplexity : 3.2377856774083393\n",
- "Time elapsed: 3.8155629634857178\n",
- "At batch 1300\n",
- "Training loss per word: 1.1103692056259904\n",
- "Training perplexity : 3.0354789034625624\n",
- "Time elapsed: 3.779963254928589\n",
- "At batch 1400\n",
- "Training loss per word: 1.14418991708818\n",
- "Training perplexity : 3.1398967496051684\n",
- "Time elapsed: 3.763909339904785\n",
- "At batch 1500\n",
- "Training loss per word: 1.1915889870472838\n",
- "Training perplexity : 3.2923084887863863\n",
- "\n",
- "Validation loss per word: 1.3544946988514919\n",
- "Validation perplexity : 3.8748025197111846 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7811734676361084\n",
- "At batch 100\n",
- "Training loss per word: 1.0663566247449183\n",
- "Training perplexity : 2.9047770046105246\n",
- "Time elapsed: 3.787771463394165\n",
- "At batch 200\n",
- "Training loss per word: 1.1603800186258277\n",
- "Training perplexity : 3.1911457405411974\n",
- "Time elapsed: 3.8014190196990967\n",
- "At batch 300\n",
- "Training loss per word: 1.1064504523026315\n",
- "Training perplexity : 3.02360688736574\n",
- "Time elapsed: 3.7817535400390625\n",
- "At batch 400\n",
- "Training loss per word: 1.1599953924517536\n",
- "Training perplexity : 3.1899185783785726\n",
- "Time elapsed: 3.7866930961608887\n",
- "At batch 500\n",
- "Training loss per word: 1.1528212436409884\n",
- "Training perplexity : 3.167115521866446\n",
- "Time elapsed: 3.805013656616211\n",
- "At batch 600\n",
- "Training loss per word: 1.0859856002916397\n",
- "Training perplexity : 2.962358081371947\n",
- "Time elapsed: 3.8244552612304688\n",
- "At batch 700\n",
- "Training loss per word: 1.1673833375336022\n",
- "Training perplexity : 3.2135727920765134\n",
- "Time elapsed: 3.844132661819458\n",
- "At batch 800\n",
- "Training loss per word: 1.1472129162707325\n",
- "Training perplexity : 3.14940301639145\n",
- "Time elapsed: 3.7915048599243164\n",
- "At batch 900\n",
- "Training loss per word: 1.1193068434255191\n",
- "Training perplexity : 3.0627305155612508\n",
- "Time elapsed: 3.854457139968872\n",
- "At batch 1000\n",
- "Training loss per word: 1.1248606687898088\n",
- "Training perplexity : 3.079787708473843\n",
- "Time elapsed: 3.794856071472168\n",
- "At batch 1100\n",
- "Training loss per word: 1.104549096162336\n",
- "Training perplexity : 3.0178633957863994\n",
- "Time elapsed: 3.775325298309326\n",
- "At batch 1200\n",
- "Training loss per word: 1.132256503959687\n",
- "Training perplexity : 3.1026497484539894\n",
- "Time elapsed: 3.863532304763794\n",
- "At batch 1300\n",
- "Training loss per word: 1.1573611899207061\n",
- "Training perplexity : 3.1815267445331434\n",
- "Time elapsed: 3.7871999740600586\n",
- "At batch 1400\n",
- "Training loss per word: 1.1226295386014753\n",
- "Training perplexity : 3.0729239609482413\n",
- "Time elapsed: 3.803307294845581\n",
- "At batch 1500\n",
- "Training loss per word: 1.0962326895043732\n",
- "Training perplexity : 2.992869689513836\n",
- "\n",
- "Validation loss per word: 1.3441824345825584\n",
- "Validation perplexity : 3.8350498544164515 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.823367118835449\n",
- "At batch 100\n",
- "Training loss per word: 1.098994354081284\n",
- "Training perplexity : 3.0011464152283818\n",
- "Time elapsed: 3.828970432281494\n",
- "At batch 200\n",
- "Training loss per word: 1.099687687303654\n",
- "Training perplexity : 3.003227931251978\n",
- "Time elapsed: 3.8045783042907715\n",
- "At batch 300\n",
- "Training loss per word: 1.1488120911449526\n",
- "Training perplexity : 3.15444349179067\n",
- "Time elapsed: 3.807218074798584\n",
- "At batch 400\n",
- "Training loss per word: 1.137648790458962\n",
- "Training perplexity : 3.11942531348274\n",
- "Time elapsed: 3.826814651489258\n",
- "At batch 500\n",
- "Training loss per word: 1.1339701912040903\n",
- "Training perplexity : 3.10797127816824\n",
- "Time elapsed: 3.8112239837646484\n",
- "At batch 600\n",
- "Training loss per word: 1.1199135349428997\n",
- "Training perplexity : 3.0645892119557208\n",
- "Time elapsed: 3.841059923171997\n",
- "At batch 700\n",
- "Training loss per word: 1.1574921703895604\n",
- "Training perplexity : 3.1819434896899574\n",
- "Time elapsed: 3.8483245372772217\n",
- "At batch 800\n",
- "Training loss per word: 1.098882556849023\n",
- "Training perplexity : 3.000810914119946\n",
- "Time elapsed: 3.808338165283203\n",
- "At batch 900\n",
- "Training loss per word: 1.1703846118244168\n",
- "Training perplexity : 3.22323209333467\n",
- "Time elapsed: 3.8445372581481934\n",
- "At batch 1000\n",
- "Training loss per word: 1.1218593718899283\n",
- "Training perplexity : 3.0705582083352976\n",
- "Time elapsed: 3.8156325817108154\n",
- "At batch 1100\n",
- "Training loss per word: 1.081787839121139\n",
- "Training perplexity : 2.949948873312807\n",
- "Time elapsed: 3.8010196685791016\n",
- "At batch 1200\n",
- "Training loss per word: 1.1653183690759519\n",
- "Training perplexity : 3.2069437124003146\n",
- "Time elapsed: 3.8367254734039307\n",
- "At batch 1300\n",
- "Training loss per word: 1.0637366903981855\n",
- "Training perplexity : 2.8971766401419488\n",
- "Time elapsed: 3.816741943359375\n",
- "At batch 1400\n",
- "Training loss per word: 1.1049788031586625\n",
- "Training perplexity : 3.019160471462814\n",
- "Time elapsed: 3.8468456268310547\n",
- "At batch 1500\n",
- "Training loss per word: 1.1472445435550493\n",
- "Training perplexity : 3.149502625031245\n",
- "\n",
- "Validation loss per word: 1.3474884972846892\n",
- "Validation perplexity : 3.8477497514613797 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.8349106311798096\n",
- "At batch 100\n",
- "Training loss per word: 1.1429718584217508\n",
- "Training perplexity : 3.136074499492635\n",
- "Time elapsed: 3.818223237991333\n",
- "At batch 200\n",
- "Training loss per word: 1.114212308431097\n",
- "Training perplexity : 3.0471670058552913\n",
- "Time elapsed: 3.8348867893218994\n",
- "At batch 300\n",
- "Training loss per word: 1.1334894353693181\n",
- "Training perplexity : 3.106477461951159\n",
- "Time elapsed: 3.8033835887908936\n",
- "At batch 400\n",
- "Training loss per word: 1.1438675551094517\n",
- "Training perplexity : 3.1388847294031446\n",
- "Time elapsed: 3.8263168334960938\n",
- "At batch 500\n",
- "Training loss per word: 1.0865191452618703\n",
- "Training perplexity : 2.96393905434887\n",
- "Time elapsed: 3.7579545974731445\n",
- "At batch 600\n",
- "Training loss per word: 1.1192416697600813\n",
- "Training perplexity : 3.0625309126917997\n",
- "Time elapsed: 3.789614677429199\n",
- "At batch 700\n",
- "Training loss per word: 1.0518411875891958\n",
- "Training perplexity : 2.862917436488333\n",
- "Time elapsed: 3.776650905609131\n",
- "At batch 800\n",
- "Training loss per word: 1.1072874813988096\n",
- "Training perplexity : 3.0261387937977213\n",
- "Time elapsed: 3.813297986984253\n",
- "At batch 900\n",
- "Training loss per word: 1.1313975205155835\n",
- "Training perplexity : 3.0999857680085667\n",
- "Time elapsed: 3.8124847412109375\n",
- "At batch 1000\n",
- "Training loss per word: 1.0942230369700872\n",
- "Training perplexity : 2.9868611009673587\n",
- "Time elapsed: 3.7665154933929443\n",
- "At batch 1100\n",
- "Training loss per word: 1.1217584830326635\n",
- "Training perplexity : 3.0702484388529006\n",
- "Time elapsed: 3.763503313064575\n",
- "At batch 1200\n",
- "Training loss per word: 1.1460056161145935\n",
- "Training perplexity : 3.145603035958064\n",
- "Time elapsed: 3.7855074405670166\n",
- "At batch 1300\n",
- "Training loss per word: 1.1267460799469964\n",
- "Training perplexity : 3.0855998519995556\n",
- "Time elapsed: 3.791804790496826\n",
- "At batch 1400\n",
- "Training loss per word: 1.1263326390108068\n",
- "Training perplexity : 3.0843244023877476\n",
- "Time elapsed: 3.7853331565856934\n",
- "At batch 1500\n",
- "Training loss per word: 1.0750992197923617\n",
- "Training perplexity : 2.9302836282436466\n",
- "\n",
- "Validation loss per word: 1.3388573407200655\n",
- "Validation perplexity : 3.8146821321208044 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.8155081272125244\n",
- "At batch 100\n",
- "Training loss per word: 1.1465816047345114\n",
- "Training perplexity : 3.14741538940693\n",
- "Time elapsed: 3.7689809799194336\n",
- "At batch 200\n",
- "Training loss per word: 1.1306585109405718\n",
- "Training perplexity : 3.097695695140721\n",
- "Time elapsed: 3.7883920669555664\n",
- "At batch 300\n",
- "Training loss per word: 1.1397602998421168\n",
- "Training perplexity : 3.1260189681342405\n",
- "Time elapsed: 3.790637731552124\n",
- "At batch 400\n",
- "Training loss per word: 1.0735297877126964\n",
- "Training perplexity : 2.9256883540435603\n",
- "Time elapsed: 3.767266273498535\n",
- "At batch 500\n",
- "Training loss per word: 1.1048049315623998\n",
- "Training perplexity : 3.0186355708462442\n",
- "Time elapsed: 3.763728618621826\n",
- "At batch 600\n",
- "Training loss per word: 1.1420912249116058\n",
- "Training perplexity : 3.1333139828781906\n",
- "Time elapsed: 3.782888889312744\n",
- "At batch 700\n",
- "Training loss per word: 1.0711027190919264\n",
- "Training perplexity : 2.918596117798826\n",
- "Time elapsed: 3.7938215732574463\n",
- "At batch 800\n",
- "Training loss per word: 1.0894365344454509\n",
- "Training perplexity : 2.972598643640959\n",
- "Time elapsed: 3.792896032333374\n",
- "At batch 900\n",
- "Training loss per word: 1.1475062995788354\n",
- "Training perplexity : 3.1503271342206927\n",
- "Time elapsed: 3.8132219314575195\n",
- "At batch 1000\n",
- "Training loss per word: 1.1115614922776647\n",
- "Training perplexity : 3.0391002228369857\n",
- "Time elapsed: 3.8269190788269043\n",
- "At batch 1100\n",
- "Training loss per word: 1.143299027185239\n",
- "Training perplexity : 3.137100692968913\n",
- "Time elapsed: 3.7997896671295166\n",
- "At batch 1200\n",
- "Training loss per word: 1.15957128420398\n",
- "Training perplexity : 3.188565994441032\n",
- "Time elapsed: 3.776357889175415\n",
- "At batch 1300\n",
- "Training loss per word: 1.1464660242845082\n",
- "Training perplexity : 3.147051630741975\n",
- "Time elapsed: 3.7989509105682373\n",
- "At batch 1400\n",
- "Training loss per word: 1.107467296511628\n",
- "Training perplexity : 3.0266829882120554\n",
- "Time elapsed: 3.7822105884552\n",
- "At batch 1500\n",
- "Training loss per word: 1.1874470166449504\n",
- "Training perplexity : 3.2787000468078693\n",
- "\n",
- "Validation loss per word: 1.341572059561056\n",
- "Validation perplexity : 3.8250519908341163 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.8168201446533203\n",
- "At batch 100\n",
- "Training loss per word: 1.0604724310631402\n",
- "Training perplexity : 2.887734922739937\n",
- "Time elapsed: 3.810335636138916\n",
- "At batch 200\n",
- "Training loss per word: 1.1267382973118831\n",
- "Training perplexity : 3.0855758379952483\n",
- "Time elapsed: 3.7992618083953857\n",
- "At batch 300\n",
- "Training loss per word: 1.094500573594331\n",
- "Training perplexity : 2.9876901793588932\n",
- "Time elapsed: 3.7486989498138428\n",
- "At batch 400\n",
- "Training loss per word: 1.0873632493361929\n",
- "Training perplexity : 2.9664419835984646\n",
- "Time elapsed: 3.8080248832702637\n",
- "At batch 500\n",
- "Training loss per word: 1.1584351419440564\n",
- "Training perplexity : 3.184945387018204\n",
- "Time elapsed: 3.8131954669952393\n",
- "At batch 600\n",
- "Training loss per word: 1.089704428939916\n",
- "Training perplexity : 2.973395093129298\n",
- "Time elapsed: 3.7971811294555664\n",
- "At batch 700\n",
- "Training loss per word: 1.1562952828996909\n",
- "Training perplexity : 3.178137339554788\n",
- "Time elapsed: 3.785998821258545\n",
- "At batch 800\n",
- "Training loss per word: 1.1333110326328326\n",
- "Training perplexity : 3.105923307303939\n",
- "Time elapsed: 3.8070566654205322\n",
- "At batch 900\n",
- "Training loss per word: 1.0939211415146624\n",
- "Training perplexity : 2.985959517273849\n",
- "Time elapsed: 3.7978057861328125\n",
- "At batch 1000\n",
- "Training loss per word: 1.0147604237432066\n",
- "Training perplexity : 2.7587023986053727\n",
- "Time elapsed: 3.7495784759521484\n",
- "At batch 1100\n",
- "Training loss per word: 1.1340868284557972\n",
- "Training perplexity : 3.1083338045381383\n",
- "Time elapsed: 3.769446849822998\n",
- "At batch 1200\n",
- "Training loss per word: 1.1045949921275442\n",
- "Training perplexity : 3.0180019067183372\n",
- "Time elapsed: 3.7816579341888428\n",
- "At batch 1300\n",
- "Training loss per word: 1.1174229738780979\n",
- "Training perplexity : 3.0569661621607334\n",
- "Time elapsed: 3.7789602279663086\n",
- "At batch 1400\n",
- "Training loss per word: 1.1201495769350878\n",
- "Training perplexity : 3.0653126690783203\n",
- "Time elapsed: 3.824803590774536\n",
- "At batch 1500\n",
- "Training loss per word: 1.1051722046675956\n",
- "Training perplexity : 3.0197444381219025\n",
- "\n",
- "Validation loss per word: 1.3464858253222147\n",
- "Validation perplexity : 3.8438936541903415 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7695260047912598\n",
- "At batch 100\n",
- "Training loss per word: 1.0290114195507079\n",
- "Training perplexity : 2.798298124055436\n",
- "Time elapsed: 3.7838001251220703\n",
- "At batch 200\n",
- "Training loss per word: 1.0758974751576786\n",
- "Training perplexity : 2.9326236767256617\n",
- "Time elapsed: 3.7888290882110596\n",
- "At batch 300\n",
- "Training loss per word: 1.0689325360445519\n",
- "Training perplexity : 2.912269097860745\n",
- "Time elapsed: 3.7760934829711914\n",
- "At batch 400\n",
- "Training loss per word: 1.098075355425545\n",
- "Training perplexity : 2.9983896326408783\n",
- "Time elapsed: 3.7993335723876953\n",
- "At batch 500\n",
- "Training loss per word: 1.1470130774175893\n",
- "Training perplexity : 3.1487737061869776\n",
- "Time elapsed: 3.8072171211242676\n",
- "At batch 600\n",
- "Training loss per word: 1.1184220890153036\n",
- "Training perplexity : 3.060021949614826\n",
- "Time elapsed: 3.7615554332733154\n",
- "At batch 700\n",
- "Training loss per word: 1.1594816615030106\n",
- "Training perplexity : 3.188280239349653\n",
- "Time elapsed: 3.8286609649658203\n",
- "At batch 800\n",
- "Training loss per word: 1.147852127724748\n",
- "Training perplexity : 3.1514167944192613\n",
- "Time elapsed: 3.8103957176208496\n",
- "At batch 900\n",
- "Training loss per word: 1.0395428649965506\n",
- "Training perplexity : 2.8279239757910557\n",
- "Time elapsed: 3.783261775970459\n",
- "At batch 1000\n",
- "Training loss per word: 1.1195837507706534\n",
- "Training perplexity : 3.0635787255695477\n",
- "Time elapsed: 3.7467522621154785\n",
- "At batch 1100\n",
- "Training loss per word: 1.0825985467027806\n",
- "Training perplexity : 2.9523413889140855\n",
- "Time elapsed: 3.7833640575408936\n",
- "At batch 1200\n",
- "Training loss per word: 1.1820229240085767\n",
- "Training perplexity : 3.260964217900063\n",
- "Time elapsed: 3.736557722091675\n",
- "At batch 1300\n",
- "Training loss per word: 1.0592126011532246\n",
- "Training perplexity : 2.884099158615438\n",
- "Time elapsed: 3.799887180328369\n",
- "At batch 1400\n",
- "Training loss per word: 1.0864442113967636\n",
- "Training perplexity : 2.963716963260763\n",
- "Time elapsed: 3.7434606552124023\n",
- "At batch 1500\n",
- "Training loss per word: 1.0987684037419823\n",
- "Training perplexity : 3.0004683817813946\n",
- "\n",
- "Validation loss per word: 1.3423204088491216\n",
- "Validation perplexity : 3.8279155371010374 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.748110771179199\n",
- "At batch 100\n",
- "Training loss per word: 1.0774480682967873\n",
- "Training perplexity : 2.937174510212159\n",
- "Time elapsed: 3.769700527191162\n",
- "At batch 200\n",
- "Training loss per word: 1.0767136277387757\n",
- "Training perplexity : 2.935018122092278\n",
- "Time elapsed: 3.7462425231933594\n",
- "At batch 300\n",
- "Training loss per word: 1.0570872274623304\n",
- "Training perplexity : 2.877975879574887\n",
- "Time elapsed: 3.740286350250244\n",
- "At batch 400\n",
- "Training loss per word: 1.0765112739919693\n",
- "Training perplexity : 2.934424270264426\n",
- "Time elapsed: 3.7884538173675537\n",
- "At batch 500\n",
- "Training loss per word: 1.0194008484792285\n",
- "Training perplexity : 2.7715336977689327\n",
- "Time elapsed: 3.793185234069824\n",
- "At batch 600\n",
- "Training loss per word: 1.158294820626045\n",
- "Training perplexity : 3.1844985026381396\n",
- "Time elapsed: 3.77243709564209\n",
- "At batch 700\n",
- "Training loss per word: 1.070255908016044\n",
- "Training perplexity : 2.9161256644315117\n",
- "Time elapsed: 3.7856192588806152\n",
- "At batch 800\n",
- "Training loss per word: 1.1220820908874045\n",
- "Training perplexity : 3.071242156142408\n",
- "Time elapsed: 3.809751510620117\n",
- "At batch 900\n",
- "Training loss per word: 1.1019160212302694\n",
- "Training perplexity : 3.0099275877041842\n",
- "Time elapsed: 3.7891132831573486\n",
- "At batch 1000\n",
- "Training loss per word: 1.1133945395142806\n",
- "Training perplexity : 3.044676146006247\n",
- "Time elapsed: 3.8100802898406982\n",
- "At batch 1100\n",
- "Training loss per word: 1.147145927510494\n",
- "Training perplexity : 3.1491920488541965\n",
- "Time elapsed: 3.7659027576446533\n",
- "At batch 1200\n",
- "Training loss per word: 1.1528299265894397\n",
- "Training perplexity : 3.1671430218866523\n",
- "Time elapsed: 3.7785401344299316\n",
- "At batch 1300\n",
- "Training loss per word: 1.0877712259030914\n",
- "Training perplexity : 2.967652469322952\n",
- "Time elapsed: 3.8046278953552246\n",
- "At batch 1400\n",
- "Training loss per word: 1.082173963983727\n",
- "Training perplexity : 2.951088141851955\n",
- "Time elapsed: 3.7735743522644043\n",
- "At batch 1500\n",
- "Training loss per word: 1.0206523487773487\n",
- "Training perplexity : 2.775004444385148\n",
- "\n",
- "Validation loss per word: 1.3401957653391814\n",
- "Validation perplexity : 3.8197912148989297 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7973008155822754\n",
- "At batch 100\n",
- "Training loss per word: 1.0854998406316232\n",
- "Training perplexity : 2.960919436763539\n",
- "Time elapsed: 3.788548231124878\n",
- "At batch 200\n",
- "Training loss per word: 1.0343923061708862\n",
- "Training perplexity : 2.813396032628909\n",
- "Time elapsed: 3.77480149269104\n",
- "At batch 300\n",
- "Training loss per word: 1.0724025374116142\n",
- "Training perplexity : 2.9223922290935436\n",
- "Time elapsed: 3.765207290649414\n",
- "At batch 400\n",
- "Training loss per word: 1.091713715105202\n",
- "Training perplexity : 2.979375500917144\n",
- "Time elapsed: 3.768878698348999\n",
- "At batch 500\n",
- "Training loss per word: 1.0816575361220218\n",
- "Training perplexity : 2.9495645111696853\n",
- "Time elapsed: 3.7661168575286865\n",
- "At batch 600\n",
- "Training loss per word: 1.09487015015023\n",
- "Training perplexity : 2.988794563670179\n",
- "Time elapsed: 3.7885003089904785\n",
- "At batch 700\n",
- "Training loss per word: 1.0401752044299777\n",
- "Training perplexity : 2.829712749132039\n",
- "Time elapsed: 3.7507691383361816\n",
- "At batch 800\n",
- "Training loss per word: 1.0489363170557358\n",
- "Training perplexity : 2.854613099339013\n",
- "Time elapsed: 3.7792656421661377\n",
- "At batch 900\n",
- "Training loss per word: 1.1101965929045265\n",
- "Training perplexity : 3.034954986406765\n",
- "Time elapsed: 3.805776834487915\n",
- "At batch 1000\n",
- "Training loss per word: 1.0929490400282718\n",
- "Training perplexity : 2.9830582719697096\n",
- "Time elapsed: 3.796440839767456\n",
- "At batch 1100\n",
- "Training loss per word: 1.0269196875\n",
- "Training perplexity : 2.7924509516728016\n",
- "Time elapsed: 3.771205186843872\n",
- "At batch 1200\n",
- "Training loss per word: 1.0856025053662126\n",
- "Training perplexity : 2.961223434376305\n",
- "Time elapsed: 3.817021369934082\n",
- "At batch 1300\n",
- "Training loss per word: 1.19846151513991\n",
- "Training perplexity : 3.31501290035451\n",
- "Time elapsed: 3.7447714805603027\n",
- "At batch 1400\n",
- "Training loss per word: 1.087195527072303\n",
- "Training perplexity : 2.9659444869550753\n",
- "Time elapsed: 3.757854461669922\n",
- "At batch 1500\n",
- "Training loss per word: 1.0551075411295794\n",
- "Training perplexity : 2.872284025961448\n",
- "\n",
- "Validation loss per word: 1.3427743964693049\n",
- "Validation perplexity : 3.829653757901496 \n",
- "\n",
- "Training 1563 number of batches\n",
- "Time elapsed: 3.7451558113098145\n",
- "At batch 100\n",
- "Training loss per word: 1.118965058696411\n",
- "Training perplexity : 3.0616838999104705\n",
- "Time elapsed: 3.7456469535827637\n",
- "At batch 200\n",
- "Training loss per word: 1.084817152756911\n",
- "Training perplexity : 2.958898742796177\n",
- "Time elapsed: 3.739215850830078\n",
- "At batch 300\n",
- "Training loss per word: 1.0808059325332904\n",
- "Training perplexity : 2.947053720697608\n",
- "Time elapsed: 3.784787654876709\n",
- "At batch 400\n",
- "Training loss per word: 1.0143558166579538\n",
- "Training perplexity : 2.7575864338476412\n",
- "Time elapsed: 3.757758378982544\n",
- "At batch 500\n",
- "Training loss per word: 1.0636638919890873\n",
- "Training perplexity : 2.896965737968434\n",
- "Time elapsed: 3.761496067047119\n",
- "At batch 600\n",
- "Training loss per word: 1.0679087537400267\n",
- "Training perplexity : 2.9092890939902407\n",
- "Time elapsed: 3.7373595237731934\n",
- "At batch 700\n",
- "Training loss per word: 1.0820845007856728\n",
- "Training perplexity : 2.950824139878452\n",
- "Time elapsed: 3.7674787044525146\n",
- "At batch 800\n",
- "Training loss per word: 1.110800911860211\n",
- "Training perplexity : 3.0367896215312564\n",
- "Time elapsed: 3.762507677078247\n",
- "At batch 900\n",
- "Training loss per word: 1.0788491338809407\n",
- "Training perplexity : 2.9412925684942954\n",
- "Time elapsed: 3.7523109912872314\n",
- "At batch 1000\n",
- "Training loss per word: 1.1006824035840281\n",
- "Training perplexity : 3.0062167772493096\n",
- "Time elapsed: 3.719322443008423\n",
- "At batch 1100\n",
- "Training loss per word: 1.092326754734693\n",
- "Training perplexity : 2.9812025361356183\n",
- "Time elapsed: 3.751136064529419\n",
- "At batch 1200\n",
- "Training loss per word: 1.0834127004087597\n",
- "Training perplexity : 2.9547460273370016\n",
- "Time elapsed: 3.7584688663482666\n",
- "At batch 1300\n",
- "Training loss per word: 1.1341377882408181\n",
- "Training perplexity : 3.108492208596675\n",
- "Time elapsed: 3.7554566860198975\n",
- "At batch 1400\n",
- "Training loss per word: 1.0805350895705585\n",
- "Training perplexity : 2.9462556400187108\n",
- "Time elapsed: 3.770933151245117\n",
- "At batch 1500\n",
- "Training loss per word: 1.0837757330277602\n",
- "Training perplexity : 2.9558188912563033\n",
- "\n",
- "Validation loss per word: 1.3419258234432407\n",
- "Validation perplexity : 3.8264053954546737 \n",
- "\n"
- ]
- }
- ],
- "source": [
- "for i in range(20):\n",
- " train_epoch_packed(model, optimizer, train_loader, val_loader)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Oz9Kg1p650nd"
- },
- "outputs": [],
- "source": [
- "torch.save(model, \"trained_model.pt\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "VHp8x3l650ng",
- "outputId": "159b75d2-be4a-4401-e1db-303f5ea64638"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "uarrel\n",
- "\n"
- ]
- }
- ],
- "source": [
- "print(generate(model, \"To be, or not to be, that is the q\",20))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "Ig5Y50kJ50ni",
- "outputId": "e5e9266a-76e6-4d6f-9b24-0f624ad5982f"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Scotland\n",
- "\n"
- ]
- }
- ],
- "source": [
- "print(generate(model, \"Richard \", 1000))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "8mMJwLSd50nm",
- "outputId": "3194b7e5-c1ef-4074-d90c-a0e152546f28"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "wear\n",
- "\n"
- ]
- }
- ],
- "source": [
- "print(generate(model, \"Hello\", 1000))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8woC85Ud50np"
- },
- "source": [
- "### Reminders\n",
- "\n",
- "By default, for all rnn modules (rnn, GRU, LSTM) batch_first = False\n",
- "To use packed sequences, your inputs first need to be sorted in descending order of length (longest to shortest)\n",
- "Batches need to have inputs of the same length "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "u6sGgg7K50nq"
- },
- "source": []
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
\ No newline at end of file
diff --git a/S24/document/recitation/Recitation8/language_model.ipynb b/S24/document/recitation/Recitation8/language_model.ipynb
deleted file mode 100644
index cbb4a844..00000000
--- a/S24/document/recitation/Recitation8/language_model.ipynb
+++ /dev/null
@@ -1 +0,0 @@
-{"cells":[{"cell_type":"markdown","metadata":{"id":"pFvgJbAu50m8"},"source":["# Shakespeare Character Language Model"]},{"cell_type":"code","execution_count":9,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":36},"executionInfo":{"elapsed":564,"status":"ok","timestamp":1666533005617,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"},"user_tz":240},"id":"mcIAFm9g50m9","outputId":"3a72c4cb-aae6-4078-c143-d2cd1545851a"},"outputs":[{"output_type":"execute_result","data":{"text/plain":["'cuda'"],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":9}],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.utils.rnn as rnn\n","from torch.utils.data import Dataset, DataLoader, TensorDataset\n","import numpy as np\n","import time\n","\n","import shakespeare_data as sh\n","\n","DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n","DEVICE"]},{"cell_type":"markdown","metadata":{"id":"gN0cVBCS50nB"},"source":["## Fixed length input"]},{"cell_type":"code","execution_count":31,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":858,"status":"ok","timestamp":1666534317715,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"},"user_tz":240},"id":"uFhKFJEN50nB","outputId":"04863cf3-2215-4539-86dc-bf56ad09b78a","scrolled":true},"outputs":[{"output_type":"stream","name":"stdout","text":["First 203 characters...Last 50 characters\n","1609\n"," THE SONNETS\n"," by William Shakespeare\n"," 1\n"," From fairest creatures we desire increase,\n"," That thereby beauty's rose might never die,\n"," But as the riper should by time decease,\n","...,\n"," And new pervert a reconciled maid.'\n"," THE END\n","\n","Total character count: 5551930\n","Unique character count: 84\n","\n","shakespeare_array.shape: (5551930,)\n","\n","First 17 characters as indices [12 17 11 20 0 1 45 33 30 1 44 40 39 39 30 45 44]\n","First 17 characters as characters: ['1', '6', '0', '9', '\\n', ' ', 'T', 'H', 'E', ' ', 'S', 'O', 'N', 'N', 'E', 'T', 'S']\n","First 17 character indices as text:\n"," 1609\n"," THE SONNETS\n"]}],"source":["# Data - refer to shakespeare_data.py for details\n","corpus = sh.read_corpus()\n","print(\"First 203 characters...Last 50 characters\")\n","print(\"{}...{}\".format(corpus[:203], corpus[-50:]))\n","print(\"Total character count: {}\".format(len(corpus)))\n","chars, charmap = sh.get_charmap(corpus)\n","charcount = len(chars)\n","print(\"Unique character count: {}\\n\".format(len(chars)))\n","shakespeare_array = sh.map_corpus(corpus, charmap)\n","print(\"shakespeare_array.shape: {}\\n\".format(shakespeare_array.shape))\n","small_example = shakespeare_array[:17]\n","print(\"First 17 characters as indices\", small_example)\n","print(\"First 17 characters as characters:\", [chars[c] for c in small_example])\n","print(\"First 17 character indices as text:\\n\", sh.to_text(small_example,chars))"]},{"cell_type":"code","execution_count":32,"metadata":{"id":"DBcpz6iD50nD","executionInfo":{"status":"ok","timestamp":1666534320257,"user_tz":240,"elapsed":213,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["# Dataset class. 
Transform raw text into a set of sequences of fixed length, and extracts inputs and targets\n","class TextDataset(Dataset):\n"," \n"," def __init__(self,text, seq_len = 200):\n"," n_seq = len(text) // seq_len\n"," text = text[:n_seq * seq_len]\n"," self.data = torch.tensor(text).view(-1,seq_len)\n"," \n"," def __getitem__(self,i):\n"," txt = self.data[i]\n"," \n"," # labels are the input sequence shifted by 1\n"," return txt[:-1],txt[1:]\n"," \n"," def __len__(self):\n"," return self.data.size(0)\n","\n","# Collate function. Transform a list of sequences into a batch. Passed as an argument to the DataLoader.\n","# Returns data of the format seq_len x batch_size\n","def collate(seq_list):\n"," inputs = torch.cat([s[0].unsqueeze(1) for s in seq_list],dim=1)\n"," targets = torch.cat([s[1].unsqueeze(1) for s in seq_list],dim=1)\n"," return inputs,targets\n"]},{"cell_type":"code","execution_count":33,"metadata":{"id":"iHb5PHQs50nF","executionInfo":{"status":"ok","timestamp":1666534323333,"user_tz":240,"elapsed":639,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["# Model\n","class CharLanguageModel(nn.Module):\n","\n"," def __init__(self,vocab_size,embed_size,hidden_size, nlayers):\n"," super(CharLanguageModel,self).__init__()\n"," self.vocab_size=vocab_size\n"," self.embed_size = embed_size\n"," self.hidden_size = hidden_size\n"," self.nlayers=nlayers\n"," self.embedding = nn.Embedding(vocab_size,embed_size) # Embedding layer\n"," self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # Recurrent network\n"," # You can also try GRUs instead of LSTMs.\n"," \n"," self.scoring = nn.Linear(hidden_size,vocab_size) # Projection layer\n"," \n"," def forward(self,seq_batch): #L x N\n"," # returns 3D logits\n"," batch_size = seq_batch.size(1)\n"," embed = self.embedding(seq_batch) #L x N x E\n"," hidden = None\n"," output_lstm,hidden = self.rnn(embed,hidden) #L x N x H\n"," output_lstm_flatten = output_lstm.view(-1,self.hidden_size) #(L*N) x H\n"," output_flatten = self.scoring(output_lstm_flatten) #(L*N) x V\n"," return output_flatten.view(-1,batch_size,self.vocab_size)\n"," \n"," def generate(self,seq, n_chars): # L x V\n"," # performs greedy search to extract and return words (one sequence).\n"," generated_chars = []\n"," embed = self.embedding(seq).unsqueeze(1) # L x 1 x E\n"," hidden = None\n"," output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H\n"," output = output_lstm[-1] # 1 x H\n"," scores = self.scoring(output) # 1 x V\n"," _,current_char = torch.max(scores,dim=1) # 1 x 1\n"," generated_chars.append(current_char)\n"," if n_chars > 1:\n"," for i in range(n_chars-1):\n"," embed = self.embedding(current_char).unsqueeze(0) # 1 x 1 x E\n"," output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H\n"," output = output_lstm[0] # 1 x H\n"," scores = self.scoring(output) # V\n"," _,current_char = torch.max(scores,dim=1) # 1\n"," generated_chars.append(current_char)\n"," return torch.cat(generated_chars,dim=0)\n"," \n"," "]},{"cell_type":"code","execution_count":34,"metadata":{"id":"QRxGHF6E50nH","executionInfo":{"status":"ok","timestamp":1666534326315,"user_tz":240,"elapsed":205,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["def train_epoch(model, optimizer, train_loader, val_loader):\n"," criterion = nn.CrossEntropyLoss()\n"," criterion = criterion.to(DEVICE)\n"," before = time.time()\n"," print(\"training\", len(train_loader), \"number of batches\")\n"," for 
batch_idx, (inputs,targets) in enumerate(train_loader):\n"," if batch_idx == 0:\n"," first_time = time.time()\n"," inputs = inputs.to(DEVICE)\n"," targets = targets.to(DEVICE)\n"," outputs = model(inputs) # 3D\n"," loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1)) # Loss of the flattened outputs\n"," optimizer.zero_grad()\n"," loss.backward()\n"," optimizer.step()\n"," \n"," if batch_idx == 0:\n"," print(\"Time elapsed\", time.time() - first_time)\n"," \n"," if batch_idx % 100 == 0 and batch_idx != 0:\n"," after = time.time()\n"," print(\"Time: \", after - before)\n"," print(\"Loss per word: \", loss.item() / batch_idx)\n"," print(\"Perplexity: \", np.exp(loss.item() / batch_idx))\n"," after = before\n"," \n"," val_loss = 0\n"," batch_id=0\n"," for inputs,targets in val_loader:\n"," batch_id+=1\n"," inputs = inputs.to(DEVICE)\n"," targets = targets.to(DEVICE)\n"," outputs = model(inputs)\n"," loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1))\n"," val_loss+=loss.item()\n"," val_lpw = val_loss / batch_id\n"," print(\"\\nValidation loss per word:\",val_lpw)\n"," print(\"Validation perplexity :\",np.exp(val_lpw),\"\\n\")\n"," return val_lpw\n"," "]},{"cell_type":"code","execution_count":35,"metadata":{"id":"RNHa3FAU50nI","executionInfo":{"status":"ok","timestamp":1666534330083,"user_tz":240,"elapsed":331,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["model = CharLanguageModel(charcount,256,256,3)\n","model = model.to(DEVICE)\n","optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)\n","split = 5000000\n","train_dataset = TextDataset(shakespeare_array[:split])\n","val_dataset = TextDataset(shakespeare_array[split:])\n","train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate)\n","val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate, drop_last=True)"]},{"cell_type":"code","execution_count":57,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"VVROzTRT50nK","outputId":"2cf1af8e-19ac-4e60-a6b4-8180e41df5d1","scrolled":false,"executionInfo":{"status":"ok","timestamp":1666534917486,"user_tz":240,"elapsed":113354,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["training 391 number of batches\n","Time elapsed 0.07767295837402344\n","Time: 5.596414089202881\n","Loss per word: 0.01185429334640503\n","Perplexity: 1.0119248339425135\n","Time: 11.145010232925415\n","Loss per word: 0.005843929648399353\n","Perplexity: 1.0058610387170948\n","Time: 16.780640840530396\n","Loss per word: 0.003932354052861532\n","Perplexity: 1.0039400959016305\n","\n","Validation loss per word: 1.3321008599081705\n","Validation perplexity : 3.7889951799089627 \n","\n","training 391 number of batches\n","Time elapsed 0.04892563819885254\n","Time: 5.784719467163086\n","Loss per word: 0.011950627565383912\n","Perplexity: 1.0120223216266813\n","Time: 11.487154722213745\n","Loss per word: 0.005958112478256225\n","Perplexity: 1.0059758973342545\n","Time: 17.114842653274536\n","Loss per word: 0.003922495444615682\n","Perplexity: 1.0039301984983102\n","\n","Validation loss per word: 1.3247673289720403\n","Validation perplexity : 3.761310105292561 \n","\n","training 391 number of batches\n","Time elapsed 0.052561044692993164\n","Time: 5.5854432582855225\n","Loss per word: 0.011668695211410523\n","Perplexity: 1.0117370400082202\n","Time: 11.07470154762268\n","Loss 
per word: 0.005841174721717834\n","Perplexity: 1.0058582676474983\n","Time: 16.553176164627075\n","Loss per word: 0.003922032912572225\n","Perplexity: 1.0039297341485314\n","\n","Validation loss per word: 1.3197767152342685\n","Validation perplexity : 3.742585621604832 \n","\n","training 391 number of batches\n","Time elapsed 0.049115657806396484\n","Time: 5.514904737472534\n","Loss per word: 0.011583248376846314\n","Perplexity: 1.0116505939740628\n","Time: 10.967044353485107\n","Loss per word: 0.005864846110343933\n","Perplexity: 1.0058820779912654\n","Time: 16.433537483215332\n","Loss per word: 0.004020507335662842\n","Perplexity: 1.0040286004177446\n","\n","Validation loss per word: 1.313875392425892\n","Validation perplexity : 3.720564456623678 \n","\n","training 391 number of batches\n","Time elapsed 0.049449920654296875\n","Time: 5.5508527755737305\n","Loss per word: 0.011253877878189086\n","Perplexity: 1.011317440981854\n","Time: 11.068422079086304\n","Loss per word: 0.005857662558555603\n","Perplexity: 1.0058748522112186\n","Time: 16.625342845916748\n","Loss per word: 0.003796435991923014\n","Perplexity: 1.0038036515833308\n","\n","Validation loss per word: 1.30533528050711\n","Validation perplexity : 3.6889257112697154 \n","\n"]}],"source":["for i in range(5):\n"," train_epoch(model, optimizer, train_loader, val_loader)"]},{"cell_type":"code","execution_count":55,"metadata":{"id":"Xd-bklsK50nM","executionInfo":{"status":"ok","timestamp":1666534773997,"user_tz":240,"elapsed":640,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["def generate(model, seed,nchars):\n"," seq = sh.map_corpus(seed, charmap)\n"," seq = torch.tensor(seq).to(DEVICE)\n"," out = model.generate(seq,nchars)\n"," return sh.to_text(out.cpu().detach().numpy(),chars)"]},{"cell_type":"code","execution_count":53,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"V-Sp34eF50nN","executionInfo":{"status":"ok","timestamp":1666534725139,"user_tz":240,"elapsed":211,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}},"outputId":"85a633de-af9d-405b-b2e5-4ad1a8e41998"},"outputs":[{"output_type":"stream","name":"stdout","text":["uestion\n","\n"]}],"source":["print(generate(model, \"To be, or not to be, that is the q\",8))"]},{"cell_type":"code","execution_count":56,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"3WPIhJur50nP","executionInfo":{"status":"ok","timestamp":1666534776721,"user_tz":240,"elapsed":860,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}},"outputId":"ac980873-12e5-48d3-8a3e-866f04d09768"},"outputs":[{"output_type":"stream","name":"stdout","text":["and the King of the compolutes\n"," \n"]}],"source":["print(generate(model, \"Richard \", 1000))"]},{"cell_type":"markdown","metadata":{"id":"OITobxJ_50nS"},"source":["## Packed sequences"]},{"cell_type":"code","execution_count":19,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fZshan9w50nS","executionInfo":{"status":"ok","timestamp":1666533077684,"user_tz":240,"elapsed":1427,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}},"outputId":"92b2741d-e08d-4ac7-ff44-6cbea43420b2"},"outputs":[{"output_type":"stream","name":"stdout","text":["1609\n","\n"," THE SONNETS\n","\n"," by William Shakespeare\n","\n"," 1\n","\n"," From fairest creatures we desire increase,\n","\n"," That thereby beauty's rose might never die,\n","\n"," But as the riper should by time decease,\n","\n"," His tender heir might bear his 
memory:\n","\n"," But thou contracted to thine own bright eyes,\n","\n"," Feed'st thy light's flame with self-substantial fuel,\n","\n","114638\n"]}],"source":["stop_character = charmap['\\n']\n","space_character = charmap[\" \"]\n","lines = np.split(shakespeare_array, np.where(shakespeare_array == stop_character)[0]+1) # split the data in lines\n","shakespeare_lines = []\n","for s in lines:\n"," s_trimmed = np.trim_zeros(s-space_character)+space_character # remove space-only lines\n"," if len(s_trimmed)>1:\n"," shakespeare_lines.append(s)\n","for i in range(10):\n"," print(sh.to_text(shakespeare_lines[i],chars))\n","print(len(shakespeare_lines))"]},{"cell_type":"code","execution_count":20,"metadata":{"id":"J1UckTXZ50nU","executionInfo":{"status":"ok","timestamp":1666533077687,"user_tz":240,"elapsed":15,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["class LinesDataset(Dataset):\n"," def __init__(self,lines):\n"," self.lines=[torch.tensor(l) for l in lines]\n"," def __getitem__(self,i):\n"," line = self.lines[i]\n"," return line[:-1].to(DEVICE),line[1:].to(DEVICE)\n"," def __len__(self):\n"," return len(self.lines)\n","\n","# collate fn lets you control the return value of each batch\n","# for packed_seqs, you want to return your data sorted by length\n","def collate_lines(seq_list):\n"," inputs,targets = zip(*seq_list)\n"," lens = [len(seq) for seq in inputs]\n"," seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)\n"," inputs = [inputs[i] for i in seq_order]\n"," targets = [targets[i] for i in seq_order]\n"," return inputs,targets"]},{"cell_type":"code","execution_count":21,"metadata":{"id":"cg8ln5cG50nW","executionInfo":{"status":"ok","timestamp":1666533077690,"user_tz":240,"elapsed":17,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["# Model that takes packed sequences in training\n","class PackedLanguageModel(nn.Module):\n"," \n"," def __init__(self,vocab_size,embed_size,hidden_size, nlayers, stop):\n"," super(PackedLanguageModel,self).__init__()\n"," self.vocab_size=vocab_size\n"," self.embed_size = embed_size\n"," self.hidden_size = hidden_size\n"," self.nlayers=nlayers\n"," self.embedding = nn.Embedding(vocab_size,embed_size)\n"," self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # 1 layer, batch_size = False\n"," self.scoring = nn.Linear(hidden_size,vocab_size)\n"," self.stop = stop # stop line character (\\n)\n"," \n"," def forward(self,seq_list): # list\n"," batch_size = len(seq_list)\n"," lens = [len(s) for s in seq_list] # lens of all lines (already sorted)\n"," bounds = [0]\n"," for l in lens:\n"," bounds.append(bounds[-1]+l) # bounds of all lines in the concatenated sequence. Indexing into the list to \n"," # see where the sequence occurs. Need this at line marked **\n"," seq_concat = torch.cat(seq_list) # concatenated sequence\n"," embed_concat = self.embedding(seq_concat) # concatenated embeddings\n"," embed_list = [embed_concat[bounds[i]:bounds[i+1]] for i in range(batch_size)] # embeddings per line **\n"," packed_input = rnn.pack_sequence(embed_list) # packed version\n"," \n"," # alternatively, you could use rnn.pad_sequence, followed by rnn.pack_padded_sequence\n"," \n"," \n"," \n"," hidden = None\n"," output_packed,hidden = self.rnn(packed_input,hidden)\n"," output_padded, _ = rnn.pad_packed_sequence(output_packed) # unpacked output (padded). 
Also gives you the lengths\n"," output_flatten = torch.cat([output_padded[:lens[i],i] for i in range(batch_size)]) # concatenated output\n"," scores_flatten = self.scoring(output_flatten) # concatenated logits\n"," return scores_flatten # return concatenated logits\n"," \n"," def generate(self,seq, n_words): # L x V\n"," generated_words = []\n"," embed = self.embedding(seq).unsqueeze(1) # L x 1 x E\n"," hidden = None\n"," output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H\n"," output = output_lstm[-1] # 1 x H\n"," scores = self.scoring(output) # 1 x V\n"," _,current_word = torch.max(scores,dim=1) # 1 x 1\n"," generated_words.append(current_word)\n"," if n_words > 1:\n"," for i in range(n_words-1):\n"," embed = self.embedding(current_word).unsqueeze(0) # 1 x 1 x E\n"," output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H\n"," output = output_lstm[0] # 1 x H\n"," scores = self.scoring(output) # V\n"," _,current_word = torch.max(scores,dim=1) # 1\n"," generated_words.append(current_word)\n"," if current_word[0].item()==self.stop: # If end of line\n"," break\n"," return torch.cat(generated_words,dim=0)"]},{"cell_type":"code","execution_count":22,"metadata":{"id":"6vx3G8mc50nY","executionInfo":{"status":"ok","timestamp":1666533077692,"user_tz":240,"elapsed":17,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["def train_epoch_packed(model, optimizer, train_loader, val_loader):\n"," criterion = nn.CrossEntropyLoss(reduction=\"sum\") # sum instead of averaging, to take into account the different lengths\n"," criterion = criterion.to(DEVICE)\n"," batch_id=0\n"," before = time.time()\n"," print(\"Training\", len(train_loader), \"number of batches\")\n"," for inputs,targets in train_loader: # lists, presorted, preloaded on GPU\n"," batch_id+=1\n"," outputs = model(inputs)\n"," loss = criterion(outputs,torch.cat(targets)) # criterion of the concatenated output\n"," optimizer.zero_grad()\n"," loss.backward()\n"," optimizer.step()\n"," if batch_id % 100 == 0:\n"," after = time.time()\n"," nwords = np.sum(np.array([len(l) for l in inputs]))\n"," lpw = loss.item() / nwords\n"," print(\"Time elapsed: \", after - before)\n"," print(\"At batch\",batch_id)\n"," print(\"Training loss per word:\",lpw)\n"," print(\"Training perplexity :\",np.exp(lpw))\n"," before = after\n"," \n"," val_loss = 0\n"," batch_id=0\n"," nwords = 0\n"," for inputs,targets in val_loader:\n"," nwords += np.sum(np.array([len(l) for l in inputs]))\n"," batch_id+=1\n"," outputs = model(inputs)\n"," loss = criterion(outputs,torch.cat(targets))\n"," val_loss+=loss.item()\n"," val_lpw = val_loss / nwords\n"," print(\"\\nValidation loss per word:\",val_lpw)\n"," print(\"Validation perplexity :\",np.exp(val_lpw),\"\\n\")\n"," return val_lpw"]},{"cell_type":"code","execution_count":23,"metadata":{"id":"FFvvuete50na","executionInfo":{"status":"ok","timestamp":1666533078971,"user_tz":240,"elapsed":1295,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["model = PackedLanguageModel(charcount,256,256,3, stop=stop_character)\n","model = model.to(DEVICE)\n","optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)\n","split = 100000\n","train_dataset = LinesDataset(shakespeare_lines[:split])\n","val_dataset = LinesDataset(shakespeare_lines[split:])\n","train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate_lines)\n","val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = 
collate_lines, drop_last=True)"]},{"cell_type":"code","execution_count":24,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"4SFfNTCL50nb","scrolled":true,"executionInfo":{"status":"ok","timestamp":1666534295811,"user_tz":240,"elapsed":915140,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}},"outputId":"dceb7e6a-3f15-46d9-b2e7-82041b108222"},"outputs":[{"output_type":"stream","name":"stdout","text":["Training 1563 number of batches\n","Time elapsed: 3.7703583240509033\n","At batch 100\n","Training loss per word: 2.7013041346193503\n","Training perplexity : 14.899149557123419\n","Time elapsed: 3.691027879714966\n","At batch 200\n","Training loss per word: 2.197311673753789\n","Training perplexity : 9.000783901895156\n","Time elapsed: 3.6885509490966797\n","At batch 300\n","Training loss per word: 1.9548370171777951\n","Training perplexity : 7.062767820058658\n","Time elapsed: 3.6700551509857178\n","At batch 400\n","Training loss per word: 1.7970468966562716\n","Training perplexity : 6.03180858325135\n","Time elapsed: 3.7132349014282227\n","At batch 500\n","Training loss per word: 1.8254557599342311\n","Training perplexity : 6.205622648866053\n","Time elapsed: 3.7091219425201416\n","At batch 600\n","Training loss per word: 1.8306482488458806\n","Training perplexity : 6.237929078461925\n","Time elapsed: 3.708371639251709\n","At batch 700\n","Training loss per word: 1.7147244232747705\n","Training perplexity : 5.555144432895182\n","Time elapsed: 3.713968515396118\n","At batch 800\n","Training loss per word: 1.6335569952082643\n","Training perplexity : 5.12206150243047\n","Time elapsed: 3.7396087646484375\n","At batch 900\n","Training loss per word: 1.5792983890503876\n","Training perplexity : 4.851550715746207\n","Time elapsed: 3.7228660583496094\n","At batch 1000\n","Training loss per word: 1.621026875629406\n","Training perplexity : 5.058281876950357\n","Time elapsed: 3.7168807983398438\n","At batch 1100\n","Training loss per word: 1.5696750812099223\n","Training perplexity : 4.805086677156511\n","Time elapsed: 3.7677266597747803\n","At batch 1200\n","Training loss per word: 1.5301319406920428\n","Training perplexity : 4.6187861879448695\n","Time elapsed: 3.72387433052063\n","At batch 1300\n","Training loss per word: 1.5210934303396073\n","Training perplexity : 4.577227339139673\n","Time elapsed: 3.749553918838501\n","At batch 1400\n","Training loss per word: 1.4983956473214286\n","Training perplexity : 4.474504625206471\n","Time elapsed: 3.761760711669922\n","At batch 1500\n","Training loss per word: 1.4545306382094139\n","Training perplexity : 4.282472964634041\n","\n","Validation loss per word: 1.5616251907288874\n","Validation perplexity : 4.766561525318013 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7287967205047607\n","At batch 100\n","Training loss per word: 1.4484571357280522\n","Training perplexity : 4.256542179510825\n","Time elapsed: 3.7227351665496826\n","At batch 200\n","Training loss per word: 1.4475157597876576\n","Training perplexity : 4.252537058571778\n","Time elapsed: 3.7012224197387695\n","At batch 300\n","Training loss per word: 1.4798200545277405\n","Training perplexity : 4.392155261351937\n","Time elapsed: 3.7456235885620117\n","At batch 400\n","Training loss per word: 1.4387576705451068\n","Training perplexity : 4.215455577988563\n","Time elapsed: 3.728778123855591\n","At batch 500\n","Training loss per word: 1.3787083915855327\n","Training perplexity : 3.9697709252483686\n","Time elapsed: 3.708531141281128\n","At 
batch 600\n","Training loss per word: 1.4090683109504132\n","Training perplexity : 4.092141024457387\n","Time elapsed: 3.714569091796875\n","At batch 700\n","Training loss per word: 1.4055531548691633\n","Training perplexity : 4.0777817623593124\n","Time elapsed: 3.7208497524261475\n","At batch 800\n","Training loss per word: 1.2926560013746777\n","Training perplexity : 3.6424480666522805\n","Time elapsed: 3.730631113052368\n","At batch 900\n","Training loss per word: 1.3833357829224782\n","Training perplexity : 3.988183176328358\n","Time elapsed: 3.714390277862549\n","At batch 1000\n","Training loss per word: 1.331831748307841\n","Training perplexity : 3.787975654541661\n","Time elapsed: 3.7247533798217773\n","At batch 1100\n","Training loss per word: 1.292771070749634\n","Training perplexity : 3.6428672249903027\n","Time elapsed: 3.727574348449707\n","At batch 1200\n","Training loss per word: 1.3829026442307693\n","Training perplexity : 3.9864561139408403\n","Time elapsed: 3.6857106685638428\n","At batch 1300\n","Training loss per word: 1.405520340086133\n","Training perplexity : 4.0776479530310095\n","Time elapsed: 3.7007181644439697\n","At batch 1400\n","Training loss per word: 1.2932184916919702\n","Training perplexity : 3.6444974847558975\n","Time elapsed: 3.7181015014648438\n","At batch 1500\n","Training loss per word: 1.3205488685079587\n","Training perplexity : 3.7454765873353084\n","\n","Validation loss per word: 1.4662550231819402\n","Validation perplexity : 4.3329778169347035 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7183938026428223\n","At batch 100\n","Training loss per word: 1.407801974281519\n","Training perplexity : 4.086962275935464\n","Time elapsed: 3.736786127090454\n","At batch 200\n","Training loss per word: 1.270954735225231\n","Training perplexity : 3.564253857138264\n","Time elapsed: 3.7378339767456055\n","At batch 300\n","Training loss per word: 1.3011009506032436\n","Training perplexity : 3.67333860657824\n","Time elapsed: 3.6993186473846436\n","At batch 400\n","Training loss per word: 1.3311979819283615\n","Training perplexity : 3.785575723503658\n","Time elapsed: 3.7243549823760986\n","At batch 500\n","Training loss per word: 1.3467497694500334\n","Training perplexity : 3.844908361255419\n","Time elapsed: 3.6973633766174316\n","At batch 600\n","Training loss per word: 1.245755025093129\n","Training perplexity : 3.475557942401932\n","Time elapsed: 3.7047317028045654\n","At batch 700\n","Training loss per word: 1.302772730860903\n","Training perplexity : 3.6794847576159344\n","Time elapsed: 3.7606568336486816\n","At batch 800\n","Training loss per word: 1.3673185022865855\n","Training perplexity : 3.9248121973730536\n","Time elapsed: 3.712702751159668\n","At batch 900\n","Training loss per word: 1.2946469603466386\n","Training perplexity : 3.6497072552859513\n","Time elapsed: 3.7113733291625977\n","At batch 1000\n","Training loss per word: 1.3207825568993778\n","Training perplexity : 3.746351964012801\n","Time elapsed: 3.7720985412597656\n","At batch 1100\n","Training loss per word: 1.2613232793426998\n","Training perplexity : 3.5300896927826253\n","Time elapsed: 3.7074599266052246\n","At batch 1200\n","Training loss per word: 1.2277225235133495\n","Training perplexity : 3.413446632519601\n","Time elapsed: 3.7118020057678223\n","At batch 1300\n","Training loss per word: 1.348728454331585\n","Training perplexity : 3.852523755048424\n","Time elapsed: 3.7075886726379395\n","At batch 1400\n","Training loss per word: 1.2741113349374347\n","Training 
perplexity : 3.5755225558666193\n","Time elapsed: 3.7233262062072754\n","At batch 1500\n","Training loss per word: 1.2676521634823907\n","Training perplexity : 3.552502069307952\n","\n","Validation loss per word: 1.437432799953033\n","Validation perplexity : 4.209874342884608 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7521398067474365\n","At batch 100\n","Training loss per word: 1.2984101040072107\n","Training perplexity : 3.6634675026619434\n","Time elapsed: 3.6541831493377686\n","At batch 200\n","Training loss per word: 1.2519192490281101\n","Training perplexity : 3.4970482272737584\n","Time elapsed: 3.718026638031006\n","At batch 300\n","Training loss per word: 1.3276140372983871\n","Training perplexity : 3.77203271291436\n","Time elapsed: 3.684814453125\n","At batch 400\n","Training loss per word: 1.2439926560950925\n","Training perplexity : 3.469438121109008\n","Time elapsed: 3.7279927730560303\n","At batch 500\n","Training loss per word: 1.2611345896068025\n","Training perplexity : 3.5294236639291805\n","Time elapsed: 3.7200450897216797\n","At batch 600\n","Training loss per word: 1.2275730937179705\n","Training perplexity : 3.4129365999957435\n","Time elapsed: 3.719654083251953\n","At batch 700\n","Training loss per word: 1.2286639347254673\n","Training perplexity : 3.4166616025183822\n","Time elapsed: 3.7141995429992676\n","At batch 800\n","Training loss per word: 1.293564885779272\n","Training perplexity : 3.6457601358106078\n","Time elapsed: 3.7546281814575195\n","At batch 900\n","Training loss per word: 1.4087229681558935\n","Training perplexity : 4.090728077030081\n","Time elapsed: 3.700044870376587\n","At batch 1000\n","Training loss per word: 1.2707874804559762\n","Training perplexity : 3.563657768532543\n","Time elapsed: 3.717473030090332\n","At batch 1100\n","Training loss per word: 1.2463651149392985\n","Training perplexity : 3.477678991961975\n","Time elapsed: 3.7167797088623047\n","At batch 1200\n","Training loss per word: 1.3100708621231156\n","Training perplexity : 3.7064363488534666\n","Time elapsed: 3.720799446105957\n","At batch 1300\n","Training loss per word: 1.2738981620998784\n","Training perplexity : 3.574760432812492\n","Time elapsed: 3.6796560287475586\n","At batch 1400\n","Training loss per word: 1.3129865208128482\n","Training perplexity : 3.717258821853787\n","Time elapsed: 3.6994059085845947\n","At batch 1500\n","Training loss per word: 1.2379438652061638\n","Training perplexity : 3.448515557311611\n","\n","Validation loss per word: 1.401114204142619\n","Validation perplexity : 4.059720805547787 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7109155654907227\n","At batch 100\n","Training loss per word: 1.2433489549446974\n","Training perplexity : 3.4672055584278976\n","Time elapsed: 3.6977715492248535\n","At batch 200\n","Training loss per word: 1.2903910941613195\n","Training perplexity : 3.6342075952260915\n","Time elapsed: 3.7268831729888916\n","At batch 300\n","Training loss per word: 1.1544621405909903\n","Training perplexity : 3.1723166981967363\n","Time elapsed: 3.6845014095306396\n","At batch 400\n","Training loss per word: 1.2515151961482813\n","Training perplexity : 3.4956355202900102\n","Time elapsed: 3.7195098400115967\n","At batch 500\n","Training loss per word: 1.2464558344057315\n","Training perplexity : 3.4779944994556704\n","Time elapsed: 3.6965761184692383\n","At batch 600\n","Training loss per word: 1.2064722318284709\n","Training perplexity : 3.3416751789182\n","Time elapsed: 3.7104105949401855\n","At 
batch 700\n","Training loss per word: 1.2265445738779892\n","Training perplexity : 3.409428131564032\n","Time elapsed: 3.7035112380981445\n","At batch 800\n","Training loss per word: 1.1967898446772212\n","Training perplexity : 3.3094759204978104\n","Time elapsed: 3.7424962520599365\n","At batch 900\n","Training loss per word: 1.234192673040896\n","Training perplexity : 3.4356037452559796\n","Time elapsed: 3.68930983543396\n","At batch 1000\n","Training loss per word: 1.2543550637093812\n","Training perplexity : 3.5055767714466266\n","Time elapsed: 3.7369351387023926\n","At batch 1100\n","Training loss per word: 1.1588790893554688\n","Training perplexity : 3.1863596491840465\n","Time elapsed: 3.718149423599243\n","At batch 1200\n","Training loss per word: 1.2224001718996451\n","Training perplexity : 3.395327330746977\n","Time elapsed: 3.704207420349121\n","At batch 1300\n","Training loss per word: 1.2001046291157973\n","Training perplexity : 3.320464321808229\n","Time elapsed: 3.7240679264068604\n","At batch 1400\n","Training loss per word: 1.2382448136659627\n","Training perplexity : 3.4495535389388285\n","Time elapsed: 3.694542646408081\n","At batch 1500\n","Training loss per word: 1.1240893522690039\n","Training perplexity : 3.0774131332260097\n","\n","Validation loss per word: 1.3893149912296887\n","Validation perplexity : 4.012100787239498 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.760380983352661\n","At batch 100\n","Training loss per word: 1.2399032188772912\n","Training perplexity : 3.455279042795036\n","Time elapsed: 3.7267673015594482\n","At batch 200\n","Training loss per word: 1.19580353990113\n","Training perplexity : 3.3062133777862766\n","Time elapsed: 3.710188627243042\n","At batch 300\n","Training loss per word: 1.2362357496620215\n","Training perplexity : 3.4426301222165097\n","Time elapsed: 3.715798854827881\n","At batch 400\n","Training loss per word: 1.2031334475091569\n","Training perplexity : 3.330536651148724\n","Time elapsed: 3.7028536796569824\n","At batch 500\n","Training loss per word: 1.2586697915761813\n","Training perplexity : 3.5207350596592186\n","Time elapsed: 3.7149481773376465\n","At batch 600\n","Training loss per word: 1.2212349467534946\n","Training perplexity : 3.3913733140689897\n","Time elapsed: 3.6803011894226074\n","At batch 700\n","Training loss per word: 1.2101083691578483\n","Training perplexity : 3.35384808654884\n","Time elapsed: 3.6926045417785645\n","At batch 800\n","Training loss per word: 1.159741473454301\n","Training perplexity : 3.1891087002772642\n","Time elapsed: 3.7095448970794678\n","At batch 900\n","Training loss per word: 1.1972919941590259\n","Training perplexity : 3.3111381894351473\n","Time elapsed: 3.7010698318481445\n","At batch 1000\n","Training loss per word: 1.19060720627464\n","Training perplexity : 3.2890777498126944\n","Time elapsed: 3.7239255905151367\n","At batch 1100\n","Training loss per word: 1.2420086092351343\n","Training perplexity : 3.462561417406009\n","Time elapsed: 3.7008190155029297\n","At batch 1200\n","Training loss per word: 1.217229673189615\n","Training perplexity : 3.3778171024794905\n","Time elapsed: 3.6846933364868164\n","At batch 1300\n","Training loss per word: 1.140925168829449\n","Training perplexity : 3.1296624923862377\n","Time elapsed: 3.701658010482788\n","At batch 1400\n","Training loss per word: 1.1341194704659505\n","Training perplexity : 3.1084352684577303\n","Time elapsed: 3.7287278175354004\n","At batch 1500\n","Training loss per word: 
1.2465296672952586\n","Training perplexity : 3.478251299319346\n","\n","Validation loss per word: 1.3809387088436718\n","Validation perplexity : 3.9786346546438813 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7069437503814697\n","At batch 100\n","Training loss per word: 1.1683425449489837\n","Training perplexity : 3.2166567537715474\n","Time elapsed: 3.7282252311706543\n","At batch 200\n","Training loss per word: 1.1665503661533765\n","Training perplexity : 3.210897092457753\n","Time elapsed: 3.734584331512451\n","At batch 300\n","Training loss per word: 1.178677349525876\n","Training perplexity : 3.2500726486607086\n","Time elapsed: 3.6963186264038086\n","At batch 400\n","Training loss per word: 1.2070136789678938\n","Training perplexity : 3.3434850093042336\n","Time elapsed: 3.705839157104492\n","At batch 500\n","Training loss per word: 1.1855977651235219\n","Training perplexity : 3.2726425084402857\n","Time elapsed: 3.7076480388641357\n","At batch 600\n","Training loss per word: 1.2006107478946835\n","Training perplexity : 3.322145296506666\n","Time elapsed: 3.6715080738067627\n","At batch 700\n","Training loss per word: 1.1356211530321882\n","Training perplexity : 3.1131066581029954\n","Time elapsed: 3.723267078399658\n","At batch 800\n","Training loss per word: 1.167945887890892\n","Training perplexity : 3.215381097182526\n","Time elapsed: 3.7093513011932373\n","At batch 900\n","Training loss per word: 1.1677971639537472\n","Training perplexity : 3.2149029286047703\n","Time elapsed: 3.6834521293640137\n","At batch 1000\n","Training loss per word: 1.1820589142832822\n","Training perplexity : 3.261081583010059\n","Time elapsed: 3.7082276344299316\n","At batch 1100\n","Training loss per word: 1.2310498327591308\n","Training perplexity : 3.4248231411453696\n","Time elapsed: 3.691101551055908\n","At batch 1200\n","Training loss per word: 1.227515040792063\n","Training perplexity : 3.4127384747911065\n","Time elapsed: 3.712756872177124\n","At batch 1300\n","Training loss per word: 1.1833673983071586\n","Training perplexity : 3.2653514490785387\n","Time elapsed: 3.6986610889434814\n","At batch 1400\n","Training loss per word: 1.180272545291736\n","Training perplexity : 3.255261288136178\n","Time elapsed: 3.7075552940368652\n","At batch 1500\n","Training loss per word: 1.1456190812371883\n","Training perplexity : 3.1443873856349542\n","\n","Validation loss per word: 1.3714664597770394\n","Validation perplexity : 3.941125962537374 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.714073657989502\n","At batch 100\n","Training loss per word: 1.2042411727831368\n","Training perplexity : 3.334228014904123\n","Time elapsed: 3.700242757797241\n","At batch 200\n","Training loss per word: 1.080187253760745\n","Training perplexity : 2.9452310050149517\n","Time elapsed: 3.6805882453918457\n","At batch 300\n","Training loss per word: 1.1463370953786645\n","Training perplexity : 3.1466459109736546\n","Time elapsed: 3.697406768798828\n","At batch 400\n","Training loss per word: 1.1752805264313133\n","Training perplexity : 3.23905145594695\n","Time elapsed: 3.6938788890838623\n","At batch 500\n","Training loss per word: 1.1683631187550978\n","Training perplexity : 3.2167229333247156\n","Time elapsed: 3.713092088699341\n","At batch 600\n","Training loss per word: 1.214711290724734\n","Training perplexity : 3.369321169613484\n","Time elapsed: 3.696408271789551\n","At batch 700\n","Training loss per word: 1.148598030821918\n","Training perplexity : 3.15376832286384\n","Time 
elapsed: 3.692023754119873\n","At batch 800\n","Training loss per word: 1.0949797689332248\n","Training perplexity : 2.9891222096506187\n","Time elapsed: 3.699280023574829\n","At batch 900\n","Training loss per word: 1.1116656863747953\n","Training perplexity : 3.039416895638226\n","Time elapsed: 3.7105231285095215\n","At batch 1000\n","Training loss per word: 1.1340194229200653\n","Training perplexity : 3.1081242926940185\n","Time elapsed: 3.709425687789917\n","At batch 1100\n","Training loss per word: 1.1836204710337355\n","Training perplexity : 3.266177925047841\n","Time elapsed: 3.7070138454437256\n","At batch 1200\n","Training loss per word: 1.1820404978742913\n","Training perplexity : 3.261021526150891\n","Time elapsed: 3.7249083518981934\n","At batch 1300\n","Training loss per word: 1.1580513136380806\n","Training perplexity : 3.183723149405381\n","Time elapsed: 3.6846048831939697\n","At batch 1400\n","Training loss per word: 1.1968900566087843\n","Training perplexity : 3.3098075860904124\n","Time elapsed: 3.7383875846862793\n","At batch 1500\n","Training loss per word: 1.085747366480498\n","Training perplexity : 2.9616524315744126\n","\n","Validation loss per word: 1.3639650511099437\n","Validation perplexity : 3.9116725751460977 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.731407642364502\n","At batch 100\n","Training loss per word: 1.138607757260101\n","Training perplexity : 3.122418173594389\n","Time elapsed: 3.737220287322998\n","At batch 200\n","Training loss per word: 1.0827404090447155\n","Training perplexity : 2.952760244686936\n","Time elapsed: 3.7554142475128174\n","At batch 300\n","Training loss per word: 1.139403432521323\n","Training perplexity : 3.1249035931526095\n","Time elapsed: 3.7403788566589355\n","At batch 400\n","Training loss per word: 1.2261186223713056\n","Training perplexity : 3.4079761897648138\n","Time elapsed: 3.7107059955596924\n","At batch 500\n","Training loss per word: 1.1567305672268908\n","Training perplexity : 3.1795210340568025\n","Time elapsed: 3.722501039505005\n","At batch 600\n","Training loss per word: 1.1871671037373888\n","Training perplexity : 3.277782424777861\n","Time elapsed: 3.7460691928863525\n","At batch 700\n","Training loss per word: 1.1577818131630053\n","Training perplexity : 3.1828652501114343\n","Time elapsed: 3.716709613800049\n","At batch 800\n","Training loss per word: 1.2216747341654488\n","Training perplexity : 3.392865125377627\n","Time elapsed: 3.731794834136963\n","At batch 900\n","Training loss per word: 1.1167742776497225\n","Training perplexity : 3.0549837627980807\n","Time elapsed: 3.727544069290161\n","At batch 1000\n","Training loss per word: 1.2042394139196992\n","Training perplexity : 3.3342221504575322\n","Time elapsed: 3.7003278732299805\n","At batch 1100\n","Training loss per word: 1.1784710476808886\n","Training perplexity : 3.2494022218344703\n","Time elapsed: 3.720731019973755\n","At batch 1200\n","Training loss per word: 1.1495366143521943\n","Training perplexity : 3.156729787443522\n","Time elapsed: 3.852750539779663\n","At batch 1300\n","Training loss per word: 1.1285352658002805\n","Training perplexity : 3.091125505339989\n","Time elapsed: 3.7642767429351807\n","At batch 1400\n","Training loss per word: 1.1840574717695236\n","Training perplexity : 3.267605559120152\n","Time elapsed: 3.772087335586548\n","At batch 1500\n","Training loss per word: 1.1927117184879235\n","Training perplexity : 3.2960069428358536\n","\n","Validation loss per word: 1.355904031907872\n","Validation 
perplexity : 3.8802672569020276 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7840490341186523\n","At batch 100\n","Training loss per word: 1.1124401816217846\n","Training perplexity : 3.0417718213992053\n","Time elapsed: 3.8998072147369385\n","At batch 200\n","Training loss per word: 1.1762464859835287\n","Training perplexity : 3.242181760271219\n","Time elapsed: 3.8313801288604736\n","At batch 300\n","Training loss per word: 1.141115539648126\n","Training perplexity : 3.13025834551182\n","Time elapsed: 3.8294456005096436\n","At batch 400\n","Training loss per word: 1.1098259095620955\n","Training perplexity : 3.0338301876332396\n","Time elapsed: 3.8397085666656494\n","At batch 500\n","Training loss per word: 1.1858041221217106\n","Training perplexity : 3.273317910809078\n","Time elapsed: 3.8198180198669434\n","At batch 600\n","Training loss per word: 1.1565715791078035\n","Training perplexity : 3.179015568170599\n","Time elapsed: 3.7804319858551025\n","At batch 700\n","Training loss per word: 1.1901554751047891\n","Training perplexity : 3.287592306408844\n","Time elapsed: 3.831249475479126\n","At batch 800\n","Training loss per word: 1.0787425135484199\n","Training perplexity : 2.9409789836201603\n","Time elapsed: 3.842750072479248\n","At batch 900\n","Training loss per word: 1.1146143615015685\n","Training perplexity : 3.048392375021446\n","Time elapsed: 3.820983409881592\n","At batch 1000\n","Training loss per word: 1.1737537865402736\n","Training perplexity : 3.234110039968315\n","Time elapsed: 3.8320257663726807\n","At batch 1100\n","Training loss per word: 1.1355072420852617\n","Training perplexity : 3.112752061372295\n","Time elapsed: 3.8270204067230225\n","At batch 1200\n","Training loss per word: 1.2335316416539759\n","Training perplexity : 3.43333345379697\n","Time elapsed: 3.795576572418213\n","At batch 1300\n","Training loss per word: 1.1775943048286124\n","Training perplexity : 3.246554580169454\n","Time elapsed: 3.7663586139678955\n","At batch 1400\n","Training loss per word: 1.1165851997106095\n","Training perplexity : 3.054406187369294\n","Time elapsed: 3.792233467102051\n","At batch 1500\n","Training loss per word: 1.0653275908562894\n","Training perplexity : 2.901789428056261\n","\n","Validation loss per word: 1.349567706146099\n","Validation perplexity : 3.8557583497301797 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.801424264907837\n","At batch 100\n","Training loss per word: 1.1798410001008406\n","Training perplexity : 3.2538567988542693\n","Time elapsed: 3.7891664505004883\n","At batch 200\n","Training loss per word: 1.2098545373174743\n","Training perplexity : 3.352996881152783\n","Time elapsed: 3.813218832015991\n","At batch 300\n","Training loss per word: 1.0777284977000972\n","Training perplexity : 2.9379982958089275\n","Time elapsed: 3.8325226306915283\n","At batch 400\n","Training loss per word: 1.1992148277407786\n","Training perplexity : 3.317511082182175\n","Time elapsed: 3.787041664123535\n","At batch 500\n","Training loss per word: 1.2374958790787536\n","Training perplexity : 3.44697101617411\n","Time elapsed: 3.8416755199432373\n","At batch 600\n","Training loss per word: 1.18956913508422\n","Training perplexity : 3.2856652244861184\n","Time elapsed: 3.7902870178222656\n","At batch 700\n","Training loss per word: 1.1219792623153455\n","Training perplexity : 3.070926360933708\n","Time elapsed: 3.7493112087249756\n","At batch 800\n","Training loss per word: 1.1325905098820364\n","Training perplexity : 
3.1036862249299535\n","Time elapsed: 3.7742223739624023\n","At batch 900\n","Training loss per word: 1.11555278468373\n","Training perplexity : 3.0512543997796473\n","Time elapsed: 3.787297248840332\n","At batch 1000\n","Training loss per word: 1.1625877173310282\n","Training perplexity : 3.1981986113029186\n","Time elapsed: 3.8134567737579346\n","At batch 1100\n","Training loss per word: 1.1657647944307146\n","Training perplexity : 3.2083756929972673\n","Time elapsed: 3.766493797302246\n","At batch 1200\n","Training loss per word: 1.1748896632588532\n","Training perplexity : 3.2377856774083393\n","Time elapsed: 3.8155629634857178\n","At batch 1300\n","Training loss per word: 1.1103692056259904\n","Training perplexity : 3.0354789034625624\n","Time elapsed: 3.779963254928589\n","At batch 1400\n","Training loss per word: 1.14418991708818\n","Training perplexity : 3.1398967496051684\n","Time elapsed: 3.763909339904785\n","At batch 1500\n","Training loss per word: 1.1915889870472838\n","Training perplexity : 3.2923084887863863\n","\n","Validation loss per word: 1.3544946988514919\n","Validation perplexity : 3.8748025197111846 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7811734676361084\n","At batch 100\n","Training loss per word: 1.0663566247449183\n","Training perplexity : 2.9047770046105246\n","Time elapsed: 3.787771463394165\n","At batch 200\n","Training loss per word: 1.1603800186258277\n","Training perplexity : 3.1911457405411974\n","Time elapsed: 3.8014190196990967\n","At batch 300\n","Training loss per word: 1.1064504523026315\n","Training perplexity : 3.02360688736574\n","Time elapsed: 3.7817535400390625\n","At batch 400\n","Training loss per word: 1.1599953924517536\n","Training perplexity : 3.1899185783785726\n","Time elapsed: 3.7866930961608887\n","At batch 500\n","Training loss per word: 1.1528212436409884\n","Training perplexity : 3.167115521866446\n","Time elapsed: 3.805013656616211\n","At batch 600\n","Training loss per word: 1.0859856002916397\n","Training perplexity : 2.962358081371947\n","Time elapsed: 3.8244552612304688\n","At batch 700\n","Training loss per word: 1.1673833375336022\n","Training perplexity : 3.2135727920765134\n","Time elapsed: 3.844132661819458\n","At batch 800\n","Training loss per word: 1.1472129162707325\n","Training perplexity : 3.14940301639145\n","Time elapsed: 3.7915048599243164\n","At batch 900\n","Training loss per word: 1.1193068434255191\n","Training perplexity : 3.0627305155612508\n","Time elapsed: 3.854457139968872\n","At batch 1000\n","Training loss per word: 1.1248606687898088\n","Training perplexity : 3.079787708473843\n","Time elapsed: 3.794856071472168\n","At batch 1100\n","Training loss per word: 1.104549096162336\n","Training perplexity : 3.0178633957863994\n","Time elapsed: 3.775325298309326\n","At batch 1200\n","Training loss per word: 1.132256503959687\n","Training perplexity : 3.1026497484539894\n","Time elapsed: 3.863532304763794\n","At batch 1300\n","Training loss per word: 1.1573611899207061\n","Training perplexity : 3.1815267445331434\n","Time elapsed: 3.7871999740600586\n","At batch 1400\n","Training loss per word: 1.1226295386014753\n","Training perplexity : 3.0729239609482413\n","Time elapsed: 3.803307294845581\n","At batch 1500\n","Training loss per word: 1.0962326895043732\n","Training perplexity : 2.992869689513836\n","\n","Validation loss per word: 1.3441824345825584\n","Validation perplexity : 3.8350498544164515 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.823367118835449\n","At batch 
100\n","Training loss per word: 1.098994354081284\n","Training perplexity : 3.0011464152283818\n","Time elapsed: 3.828970432281494\n","At batch 200\n","Training loss per word: 1.099687687303654\n","Training perplexity : 3.003227931251978\n","Time elapsed: 3.8045783042907715\n","At batch 300\n","Training loss per word: 1.1488120911449526\n","Training perplexity : 3.15444349179067\n","Time elapsed: 3.807218074798584\n","At batch 400\n","Training loss per word: 1.137648790458962\n","Training perplexity : 3.11942531348274\n","Time elapsed: 3.826814651489258\n","At batch 500\n","Training loss per word: 1.1339701912040903\n","Training perplexity : 3.10797127816824\n","Time elapsed: 3.8112239837646484\n","At batch 600\n","Training loss per word: 1.1199135349428997\n","Training perplexity : 3.0645892119557208\n","Time elapsed: 3.841059923171997\n","At batch 700\n","Training loss per word: 1.1574921703895604\n","Training perplexity : 3.1819434896899574\n","Time elapsed: 3.8483245372772217\n","At batch 800\n","Training loss per word: 1.098882556849023\n","Training perplexity : 3.000810914119946\n","Time elapsed: 3.808338165283203\n","At batch 900\n","Training loss per word: 1.1703846118244168\n","Training perplexity : 3.22323209333467\n","Time elapsed: 3.8445372581481934\n","At batch 1000\n","Training loss per word: 1.1218593718899283\n","Training perplexity : 3.0705582083352976\n","Time elapsed: 3.8156325817108154\n","At batch 1100\n","Training loss per word: 1.081787839121139\n","Training perplexity : 2.949948873312807\n","Time elapsed: 3.8010196685791016\n","At batch 1200\n","Training loss per word: 1.1653183690759519\n","Training perplexity : 3.2069437124003146\n","Time elapsed: 3.8367254734039307\n","At batch 1300\n","Training loss per word: 1.0637366903981855\n","Training perplexity : 2.8971766401419488\n","Time elapsed: 3.816741943359375\n","At batch 1400\n","Training loss per word: 1.1049788031586625\n","Training perplexity : 3.019160471462814\n","Time elapsed: 3.8468456268310547\n","At batch 1500\n","Training loss per word: 1.1472445435550493\n","Training perplexity : 3.149502625031245\n","\n","Validation loss per word: 1.3474884972846892\n","Validation perplexity : 3.8477497514613797 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.8349106311798096\n","At batch 100\n","Training loss per word: 1.1429718584217508\n","Training perplexity : 3.136074499492635\n","Time elapsed: 3.818223237991333\n","At batch 200\n","Training loss per word: 1.114212308431097\n","Training perplexity : 3.0471670058552913\n","Time elapsed: 3.8348867893218994\n","At batch 300\n","Training loss per word: 1.1334894353693181\n","Training perplexity : 3.106477461951159\n","Time elapsed: 3.8033835887908936\n","At batch 400\n","Training loss per word: 1.1438675551094517\n","Training perplexity : 3.1388847294031446\n","Time elapsed: 3.8263168334960938\n","At batch 500\n","Training loss per word: 1.0865191452618703\n","Training perplexity : 2.96393905434887\n","Time elapsed: 3.7579545974731445\n","At batch 600\n","Training loss per word: 1.1192416697600813\n","Training perplexity : 3.0625309126917997\n","Time elapsed: 3.789614677429199\n","At batch 700\n","Training loss per word: 1.0518411875891958\n","Training perplexity : 2.862917436488333\n","Time elapsed: 3.776650905609131\n","At batch 800\n","Training loss per word: 1.1072874813988096\n","Training perplexity : 3.0261387937977213\n","Time elapsed: 3.813297986984253\n","At batch 900\n","Training loss per word: 1.1313975205155835\n","Training perplexity : 
3.0999857680085667\n","Time elapsed: 3.8124847412109375\n","At batch 1000\n","Training loss per word: 1.0942230369700872\n","Training perplexity : 2.9868611009673587\n","Time elapsed: 3.7665154933929443\n","At batch 1100\n","Training loss per word: 1.1217584830326635\n","Training perplexity : 3.0702484388529006\n","Time elapsed: 3.763503313064575\n","At batch 1200\n","Training loss per word: 1.1460056161145935\n","Training perplexity : 3.145603035958064\n","Time elapsed: 3.7855074405670166\n","At batch 1300\n","Training loss per word: 1.1267460799469964\n","Training perplexity : 3.0855998519995556\n","Time elapsed: 3.791804790496826\n","At batch 1400\n","Training loss per word: 1.1263326390108068\n","Training perplexity : 3.0843244023877476\n","Time elapsed: 3.7853331565856934\n","At batch 1500\n","Training loss per word: 1.0750992197923617\n","Training perplexity : 2.9302836282436466\n","\n","Validation loss per word: 1.3388573407200655\n","Validation perplexity : 3.8146821321208044 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.8155081272125244\n","At batch 100\n","Training loss per word: 1.1465816047345114\n","Training perplexity : 3.14741538940693\n","Time elapsed: 3.7689809799194336\n","At batch 200\n","Training loss per word: 1.1306585109405718\n","Training perplexity : 3.097695695140721\n","Time elapsed: 3.7883920669555664\n","At batch 300\n","Training loss per word: 1.1397602998421168\n","Training perplexity : 3.1260189681342405\n","Time elapsed: 3.790637731552124\n","At batch 400\n","Training loss per word: 1.0735297877126964\n","Training perplexity : 2.9256883540435603\n","Time elapsed: 3.767266273498535\n","At batch 500\n","Training loss per word: 1.1048049315623998\n","Training perplexity : 3.0186355708462442\n","Time elapsed: 3.763728618621826\n","At batch 600\n","Training loss per word: 1.1420912249116058\n","Training perplexity : 3.1333139828781906\n","Time elapsed: 3.782888889312744\n","At batch 700\n","Training loss per word: 1.0711027190919264\n","Training perplexity : 2.918596117798826\n","Time elapsed: 3.7938215732574463\n","At batch 800\n","Training loss per word: 1.0894365344454509\n","Training perplexity : 2.972598643640959\n","Time elapsed: 3.792896032333374\n","At batch 900\n","Training loss per word: 1.1475062995788354\n","Training perplexity : 3.1503271342206927\n","Time elapsed: 3.8132219314575195\n","At batch 1000\n","Training loss per word: 1.1115614922776647\n","Training perplexity : 3.0391002228369857\n","Time elapsed: 3.8269190788269043\n","At batch 1100\n","Training loss per word: 1.143299027185239\n","Training perplexity : 3.137100692968913\n","Time elapsed: 3.7997896671295166\n","At batch 1200\n","Training loss per word: 1.15957128420398\n","Training perplexity : 3.188565994441032\n","Time elapsed: 3.776357889175415\n","At batch 1300\n","Training loss per word: 1.1464660242845082\n","Training perplexity : 3.147051630741975\n","Time elapsed: 3.7989509105682373\n","At batch 1400\n","Training loss per word: 1.107467296511628\n","Training perplexity : 3.0266829882120554\n","Time elapsed: 3.7822105884552\n","At batch 1500\n","Training loss per word: 1.1874470166449504\n","Training perplexity : 3.2787000468078693\n","\n","Validation loss per word: 1.341572059561056\n","Validation perplexity : 3.8250519908341163 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.8168201446533203\n","At batch 100\n","Training loss per word: 1.0604724310631402\n","Training perplexity : 2.887734922739937\n","Time elapsed: 3.810335636138916\n","At batch 
200\n","Training loss per word: 1.1267382973118831\n","Training perplexity : 3.0855758379952483\n","Time elapsed: 3.7992618083953857\n","At batch 300\n","Training loss per word: 1.094500573594331\n","Training perplexity : 2.9876901793588932\n","Time elapsed: 3.7486989498138428\n","At batch 400\n","Training loss per word: 1.0873632493361929\n","Training perplexity : 2.9664419835984646\n","Time elapsed: 3.8080248832702637\n","At batch 500\n","Training loss per word: 1.1584351419440564\n","Training perplexity : 3.184945387018204\n","Time elapsed: 3.8131954669952393\n","At batch 600\n","Training loss per word: 1.089704428939916\n","Training perplexity : 2.973395093129298\n","Time elapsed: 3.7971811294555664\n","At batch 700\n","Training loss per word: 1.1562952828996909\n","Training perplexity : 3.178137339554788\n","Time elapsed: 3.785998821258545\n","At batch 800\n","Training loss per word: 1.1333110326328326\n","Training perplexity : 3.105923307303939\n","Time elapsed: 3.8070566654205322\n","At batch 900\n","Training loss per word: 1.0939211415146624\n","Training perplexity : 2.985959517273849\n","Time elapsed: 3.7978057861328125\n","At batch 1000\n","Training loss per word: 1.0147604237432066\n","Training perplexity : 2.7587023986053727\n","Time elapsed: 3.7495784759521484\n","At batch 1100\n","Training loss per word: 1.1340868284557972\n","Training perplexity : 3.1083338045381383\n","Time elapsed: 3.769446849822998\n","At batch 1200\n","Training loss per word: 1.1045949921275442\n","Training perplexity : 3.0180019067183372\n","Time elapsed: 3.7816579341888428\n","At batch 1300\n","Training loss per word: 1.1174229738780979\n","Training perplexity : 3.0569661621607334\n","Time elapsed: 3.7789602279663086\n","At batch 1400\n","Training loss per word: 1.1201495769350878\n","Training perplexity : 3.0653126690783203\n","Time elapsed: 3.824803590774536\n","At batch 1500\n","Training loss per word: 1.1051722046675956\n","Training perplexity : 3.0197444381219025\n","\n","Validation loss per word: 1.3464858253222147\n","Validation perplexity : 3.8438936541903415 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7695260047912598\n","At batch 100\n","Training loss per word: 1.0290114195507079\n","Training perplexity : 2.798298124055436\n","Time elapsed: 3.7838001251220703\n","At batch 200\n","Training loss per word: 1.0758974751576786\n","Training perplexity : 2.9326236767256617\n","Time elapsed: 3.7888290882110596\n","At batch 300\n","Training loss per word: 1.0689325360445519\n","Training perplexity : 2.912269097860745\n","Time elapsed: 3.7760934829711914\n","At batch 400\n","Training loss per word: 1.098075355425545\n","Training perplexity : 2.9983896326408783\n","Time elapsed: 3.7993335723876953\n","At batch 500\n","Training loss per word: 1.1470130774175893\n","Training perplexity : 3.1487737061869776\n","Time elapsed: 3.8072171211242676\n","At batch 600\n","Training loss per word: 1.1184220890153036\n","Training perplexity : 3.060021949614826\n","Time elapsed: 3.7615554332733154\n","At batch 700\n","Training loss per word: 1.1594816615030106\n","Training perplexity : 3.188280239349653\n","Time elapsed: 3.8286609649658203\n","At batch 800\n","Training loss per word: 1.147852127724748\n","Training perplexity : 3.1514167944192613\n","Time elapsed: 3.8103957176208496\n","At batch 900\n","Training loss per word: 1.0395428649965506\n","Training perplexity : 2.8279239757910557\n","Time elapsed: 3.783261775970459\n","At batch 1000\n","Training loss per word: 1.1195837507706534\n","Training 
perplexity : 3.0635787255695477\n","Time elapsed: 3.7467522621154785\n","At batch 1100\n","Training loss per word: 1.0825985467027806\n","Training perplexity : 2.9523413889140855\n","Time elapsed: 3.7833640575408936\n","At batch 1200\n","Training loss per word: 1.1820229240085767\n","Training perplexity : 3.260964217900063\n","Time elapsed: 3.736557722091675\n","At batch 1300\n","Training loss per word: 1.0592126011532246\n","Training perplexity : 2.884099158615438\n","Time elapsed: 3.799887180328369\n","At batch 1400\n","Training loss per word: 1.0864442113967636\n","Training perplexity : 2.963716963260763\n","Time elapsed: 3.7434606552124023\n","At batch 1500\n","Training loss per word: 1.0987684037419823\n","Training perplexity : 3.0004683817813946\n","\n","Validation loss per word: 1.3423204088491216\n","Validation perplexity : 3.8279155371010374 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.748110771179199\n","At batch 100\n","Training loss per word: 1.0774480682967873\n","Training perplexity : 2.937174510212159\n","Time elapsed: 3.769700527191162\n","At batch 200\n","Training loss per word: 1.0767136277387757\n","Training perplexity : 2.935018122092278\n","Time elapsed: 3.7462425231933594\n","At batch 300\n","Training loss per word: 1.0570872274623304\n","Training perplexity : 2.877975879574887\n","Time elapsed: 3.740286350250244\n","At batch 400\n","Training loss per word: 1.0765112739919693\n","Training perplexity : 2.934424270264426\n","Time elapsed: 3.7884538173675537\n","At batch 500\n","Training loss per word: 1.0194008484792285\n","Training perplexity : 2.7715336977689327\n","Time elapsed: 3.793185234069824\n","At batch 600\n","Training loss per word: 1.158294820626045\n","Training perplexity : 3.1844985026381396\n","Time elapsed: 3.77243709564209\n","At batch 700\n","Training loss per word: 1.070255908016044\n","Training perplexity : 2.9161256644315117\n","Time elapsed: 3.7856192588806152\n","At batch 800\n","Training loss per word: 1.1220820908874045\n","Training perplexity : 3.071242156142408\n","Time elapsed: 3.809751510620117\n","At batch 900\n","Training loss per word: 1.1019160212302694\n","Training perplexity : 3.0099275877041842\n","Time elapsed: 3.7891132831573486\n","At batch 1000\n","Training loss per word: 1.1133945395142806\n","Training perplexity : 3.044676146006247\n","Time elapsed: 3.8100802898406982\n","At batch 1100\n","Training loss per word: 1.147145927510494\n","Training perplexity : 3.1491920488541965\n","Time elapsed: 3.7659027576446533\n","At batch 1200\n","Training loss per word: 1.1528299265894397\n","Training perplexity : 3.1671430218866523\n","Time elapsed: 3.7785401344299316\n","At batch 1300\n","Training loss per word: 1.0877712259030914\n","Training perplexity : 2.967652469322952\n","Time elapsed: 3.8046278953552246\n","At batch 1400\n","Training loss per word: 1.082173963983727\n","Training perplexity : 2.951088141851955\n","Time elapsed: 3.7735743522644043\n","At batch 1500\n","Training loss per word: 1.0206523487773487\n","Training perplexity : 2.775004444385148\n","\n","Validation loss per word: 1.3401957653391814\n","Validation perplexity : 3.8197912148989297 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7973008155822754\n","At batch 100\n","Training loss per word: 1.0854998406316232\n","Training perplexity : 2.960919436763539\n","Time elapsed: 3.788548231124878\n","At batch 200\n","Training loss per word: 1.0343923061708862\n","Training perplexity : 2.813396032628909\n","Time elapsed: 3.77480149269104\n","At 
batch 300\n","Training loss per word: 1.0724025374116142\n","Training perplexity : 2.9223922290935436\n","Time elapsed: 3.765207290649414\n","At batch 400\n","Training loss per word: 1.091713715105202\n","Training perplexity : 2.979375500917144\n","Time elapsed: 3.768878698348999\n","At batch 500\n","Training loss per word: 1.0816575361220218\n","Training perplexity : 2.9495645111696853\n","Time elapsed: 3.7661168575286865\n","At batch 600\n","Training loss per word: 1.09487015015023\n","Training perplexity : 2.988794563670179\n","Time elapsed: 3.7885003089904785\n","At batch 700\n","Training loss per word: 1.0401752044299777\n","Training perplexity : 2.829712749132039\n","Time elapsed: 3.7507691383361816\n","At batch 800\n","Training loss per word: 1.0489363170557358\n","Training perplexity : 2.854613099339013\n","Time elapsed: 3.7792656421661377\n","At batch 900\n","Training loss per word: 1.1101965929045265\n","Training perplexity : 3.034954986406765\n","Time elapsed: 3.805776834487915\n","At batch 1000\n","Training loss per word: 1.0929490400282718\n","Training perplexity : 2.9830582719697096\n","Time elapsed: 3.796440839767456\n","At batch 1100\n","Training loss per word: 1.0269196875\n","Training perplexity : 2.7924509516728016\n","Time elapsed: 3.771205186843872\n","At batch 1200\n","Training loss per word: 1.0856025053662126\n","Training perplexity : 2.961223434376305\n","Time elapsed: 3.817021369934082\n","At batch 1300\n","Training loss per word: 1.19846151513991\n","Training perplexity : 3.31501290035451\n","Time elapsed: 3.7447714805603027\n","At batch 1400\n","Training loss per word: 1.087195527072303\n","Training perplexity : 2.9659444869550753\n","Time elapsed: 3.757854461669922\n","At batch 1500\n","Training loss per word: 1.0551075411295794\n","Training perplexity : 2.872284025961448\n","\n","Validation loss per word: 1.3427743964693049\n","Validation perplexity : 3.829653757901496 \n","\n","Training 1563 number of batches\n","Time elapsed: 3.7451558113098145\n","At batch 100\n","Training loss per word: 1.118965058696411\n","Training perplexity : 3.0616838999104705\n","Time elapsed: 3.7456469535827637\n","At batch 200\n","Training loss per word: 1.084817152756911\n","Training perplexity : 2.958898742796177\n","Time elapsed: 3.739215850830078\n","At batch 300\n","Training loss per word: 1.0808059325332904\n","Training perplexity : 2.947053720697608\n","Time elapsed: 3.784787654876709\n","At batch 400\n","Training loss per word: 1.0143558166579538\n","Training perplexity : 2.7575864338476412\n","Time elapsed: 3.757758378982544\n","At batch 500\n","Training loss per word: 1.0636638919890873\n","Training perplexity : 2.896965737968434\n","Time elapsed: 3.761496067047119\n","At batch 600\n","Training loss per word: 1.0679087537400267\n","Training perplexity : 2.9092890939902407\n","Time elapsed: 3.7373595237731934\n","At batch 700\n","Training loss per word: 1.0820845007856728\n","Training perplexity : 2.950824139878452\n","Time elapsed: 3.7674787044525146\n","At batch 800\n","Training loss per word: 1.110800911860211\n","Training perplexity : 3.0367896215312564\n","Time elapsed: 3.762507677078247\n","At batch 900\n","Training loss per word: 1.0788491338809407\n","Training perplexity : 2.9412925684942954\n","Time elapsed: 3.7523109912872314\n","At batch 1000\n","Training loss per word: 1.1006824035840281\n","Training perplexity : 3.0062167772493096\n","Time elapsed: 3.719322443008423\n","At batch 1100\n","Training loss per word: 1.092326754734693\n","Training perplexity : 
2.9812025361356183\n","Time elapsed: 3.751136064529419\n","At batch 1200\n","Training loss per word: 1.0834127004087597\n","Training perplexity : 2.9547460273370016\n","Time elapsed: 3.7584688663482666\n","At batch 1300\n","Training loss per word: 1.1341377882408181\n","Training perplexity : 3.108492208596675\n","Time elapsed: 3.7554566860198975\n","At batch 1400\n","Training loss per word: 1.0805350895705585\n","Training perplexity : 2.9462556400187108\n","Time elapsed: 3.770933151245117\n","At batch 1500\n","Training loss per word: 1.0837757330277602\n","Training perplexity : 2.9558188912563033\n","\n","Validation loss per word: 1.3419258234432407\n","Validation perplexity : 3.8264053954546737 \n","\n"]}],"source":["for i in range(20):\n"," train_epoch_packed(model, optimizer, train_loader, val_loader)"]},{"cell_type":"code","execution_count":25,"metadata":{"id":"Oz9Kg1p650nd","executionInfo":{"status":"ok","timestamp":1666534295812,"user_tz":240,"elapsed":14,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}}},"outputs":[],"source":["torch.save(model, \"trained_model.pt\")"]},{"cell_type":"code","execution_count":30,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"VHp8x3l650ng","executionInfo":{"status":"ok","timestamp":1666534306544,"user_tz":240,"elapsed":214,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}},"outputId":"159b75d2-be4a-4401-e1db-303f5ea64638"},"outputs":[{"output_type":"stream","name":"stdout","text":["uarrel\n","\n"]}],"source":["print(generate(model, \"To be, or not to be, that is the q\",20))"]},{"cell_type":"code","execution_count":27,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Ig5Y50kJ50ni","executionInfo":{"status":"ok","timestamp":1666534295814,"user_tz":240,"elapsed":9,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}},"outputId":"e5e9266a-76e6-4d6f-9b24-0f624ad5982f"},"outputs":[{"output_type":"stream","name":"stdout","text":["Scotland\n","\n"]}],"source":["print(generate(model, \"Richard \", 1000))"]},{"cell_type":"code","execution_count":28,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"8mMJwLSd50nm","executionInfo":{"status":"ok","timestamp":1666534295814,"user_tz":240,"elapsed":7,"user":{"displayName":"Soumya Empran","userId":"05313861740772333512"}},"outputId":"3194b7e5-c1ef-4074-d90c-a0e152546f28"},"outputs":[{"output_type":"stream","name":"stdout","text":["wear\n","\n"]}],"source":["print(generate(model, \"Hello\", 1000))"]},{"cell_type":"markdown","metadata":{"id":"8woC85Ud50np"},"source":["### Reminders\n","\n","By default, for all rnn modules (rnn, GRU, LSTM) batch_first = False\n","To use packed sequences, your inputs first need to be sorted in descending order of length (longest to shortest)\n","Batches need to have inputs of the same length "]},{"cell_type":"markdown","metadata":{"id":"u6sGgg7K50nq"},"source":[]}],"metadata":{"colab":{"provenance":[]},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.4"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0}
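A minimal sketch of the packed-sequence round trip that `PackedLanguageModel.forward` and the "Reminders" cell above rely on (the toy sizes and random tensors here are illustrative assumptions, not values from the notebook): the collate function returns each batch sorted longest-to-shortest, the embedded sequences are packed with `rnn.pack_sequence`, run through the LSTM with no compute spent on padding, and unpacked with `rnn.pad_packed_sequence`, which also returns the original lengths.

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

embed_size, hidden_size = 8, 16  # toy sizes, assumptions for illustration only

# Three variable-length sequences of embeddings, already sorted by length
# (descending), as pack_sequence expects with enforce_sorted=True (the default).
seqs = [torch.randn(5, embed_size), torch.randn(3, embed_size), torch.randn(2, embed_size)]

packed = rnn_utils.pack_sequence(seqs)                            # PackedSequence: no padding fed to the LSTM
lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size)    # batch_first=False by default
packed_out, _ = lstm(packed)

padded_out, lens = rnn_utils.pad_packed_sequence(packed_out)      # padded output plus original lengths
print(padded_out.shape)   # torch.Size([5, 3, 16]) -> (max_len, batch, hidden)
print(lens)               # tensor([5, 3, 2])
```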
\ No newline at end of file
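Similarly, a small sketch of how the "loss per word" and "perplexity" numbers printed by `train_epoch_packed` are derived (the logits and targets below are made-up stand-ins): because the criterion uses `reduction="sum"`, it returns the total loss over all target characters in the batch, so dividing by the character count gives the per-word loss, and exponentiating that gives the perplexity.

```python
import numpy as np
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction="sum")  # sum, not mean, to account for different line lengths

# Hypothetical concatenated logits/targets for a batch of two lines of lengths 4 and 3.
vocab_size = 30
logits = torch.randn(4 + 3, vocab_size)
targets = torch.randint(0, vocab_size, (4 + 3,))

total_loss = criterion(logits, targets)        # summed over all 7 target characters
loss_per_word = total_loss.item() / 7          # average negative log-likelihood per character
perplexity = np.exp(loss_per_word)             # what the training loop reports as perplexity
print(loss_per_word, perplexity)
```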
diff --git a/S24/document/recitation/Recitation9/Recitation 9 (CTC Decoding and Beam Search).pdf b/S24/document/recitation/Recitation9/Recitation 9 (CTC Decoding and Beam Search).pdf
deleted file mode 100644
index e72d85e6..00000000
Binary files a/S24/document/recitation/Recitation9/Recitation 9 (CTC Decoding and Beam Search).pdf and /dev/null differ
diff --git a/S24/document/recitation/S24_IDL__Recitation_5.pdf b/S24/document/recitation/S24_IDL__Recitation_5.pdf
new file mode 100644
index 00000000..0991de1a
Binary files /dev/null and b/S24/document/recitation/S24_IDL__Recitation_5.pdf differ
diff --git a/S24/index.html b/S24/index.html
index ee6c71e0..bfc24c46 100644
--- a/S24/index.html
+++ b/S24/index.html
@@ -2231,7 +2231,11 @@ Recitations and Bootcamps
HW2P1, HW2P2 |
- TBA
+ Slides:
+
+ HW2P1,
+
+ HW2P2
|
@@ -2241,7 +2245,6 @@ Recitations and Bootcamps
|
Chetan Chilkunda, Heena Chandak, Ishan Mamadapur, Kateryna Shapovalenko, Syed Abdul Hannan
|
-
@@ -2254,6 +2257,29 @@ Recitations and Bootcamps
CNN: Basics and Backprop
+ |
+ Slides (PDF)
+ |
+
+
+ Link
+ |
+
+
+ Denis Musinguzi, Syed Abdul Hannan, Miya Sylvester
+ |
+
+
+
+
+ Lab 6 |
+
+ Friday, Feb. 23rd |
+
+ |
+
+ CNN: Classification and Verification
+
|
TBA
|
- Denis Musinguzi, Syed Abdul Hannan, Harshit Mehrotra
+ Ishan Mamadapur, Shreya Ajay Kale, Aarya Makwana, Sarthak Bisht
|
-
+
- Lab 6 |
-
- Friday, Feb. 23rd |
+ Lab 7 |
+
+ Friday, Feb. 30th |
|
- CNN: Classification and Verification
+ | RNN Basics
|
TBA
@@ -2300,7 +2326,7 @@ Recitations and Bootcamps
|
- Ishan Mamadapur, Shreya Ajay Kale, Aarya Makwana, Sarthak Bisht
+ Aarya Makwana, Alexander Moker, Harshit Mehrotra
|