diff --git a/notebooks/tvm-ops.ipynb b/notebooks/tvm-ops.ipynb index 258dc62..1ba3a34 100644 --- a/notebooks/tvm-ops.ipynb +++ b/notebooks/tvm-ops.ipynb @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "74ce509b", "metadata": {}, "outputs": [], @@ -69,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "7fa148b7", "metadata": {}, "outputs": [], @@ -84,6 +84,15 @@ "schedule = te.create_schedule([C.op])" ] }, + { + "cell_type": "markdown", + "id": "5ba90d93-9db9-4b74-a562-114fc9b26d3c", + "metadata": {}, + "source": [ + "* `C.op` is an operation for which we define the schedule.\n", + "* `schedule` defines what operations need to be computed - it will be subjected to further optimizations" + ] + }, { "cell_type": "markdown", "id": "beec19b7", @@ -96,10 +105,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "d5017665", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " for i, j in T.grid(m, n):\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.strides[0] + j * A_1.strides[1]] * B_2[i * B_1.strides[0] + j * B_1.strides[1]]\n" + ] + } + ], "source": [ "base_function = str(tvm.lower(schedule, [A, B, C], simple_mode=True))\n", "\n", @@ -115,15 +148,44 @@ "\n", "https://tvm.apache.org/docs/reference/api/python/te.html#tvm.te.Stage.split\n", "\n", + "#### Split\n", + "\n", "`split` splits a given axis by `factor` into outer and inner axis (inner axis has `factor` length), where inner axis has a `factor` length" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "fac40433", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " for i_outer, i_inner in T.grid((m + 31) // 32, 32):\n", + " if T.likely(i_outer * 32 + i_inner < m):\n", + " for j in range(n):\n", + " cse_var_1: T.int32 = i_outer * 32 + i_inner\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[cse_var_1 * C_1.strides[0] + j * C_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + j * A_1.strides[1]] * B_2[cse_var_1 * B_1.strides[0] + j * B_1.strides[1]]\n" + ] + } + ], "source": [ "n = te.var('n')\n", "m = te.var('m')\n", @@ -143,15 +205,56 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "cea38df2", "metadata": { - "scrolled": false, "slideshow": { "slide_type": "-" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " # from tvm.script import ir as I\n", + " # from tvm.script import tir as T\n", + " \n", + " @I.ir_module\n", + " class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + "- for i, j in T.grid(m, n):\n", + "+ for i_outer, i_inner in T.grid((m + 31) // 32, 32):\n", + "+ if T.likely(i_outer * 32 + i_inner < m):\n", + "+ for j in range(n):\n", + "+ cse_var_1: T.int32 = i_outer * 32 + i_inner\n", + "- C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + "+ C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + "? ++++++++\n", + "\n", + "- A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + "+ A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + "? ++++++++\n", + "\n", + "- B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + "+ B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + "? ++++++++\n", + "\n", + "- C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.strides[0] + j * A_1.strides[1]] * B_2[i * B_1.strides[0] + j * B_1.strides[1]]\n", + "? ^ ^ ^\n", + "\n", + "+ C_2[cse_var_1 * C_1.strides[0] + j * C_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + j * A_1.strides[1]] * B_2[cse_var_1 * B_1.strides[0] + j * B_1.strides[1]]\n", + "? ++++++++ ^^^^^^^^^ ^^^^^^^^^ ^^^^^^^^^\n", + "\n" + ] + } + ], "source": [ "compute_diff(base_function, split_function)" ] @@ -161,6 +264,8 @@ "id": "4018e6c4", "metadata": {}, "source": [ + "#### Tile\n", + "\n", "https://tvm.apache.org/docs/reference/api/python/te.html#tvm.te.Stage.tile\n", "\n", "Same as split, but in 2D - tiles the computations along given axes" @@ -168,10 +273,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "ec0b348b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):\n", + " if T.likely(i_outer * 16 + i_inner < m):\n", + " for j_inner in range(8):\n", + " if T.likely(j_outer * 8 + j_inner < n):\n", + " cse_var_2: T.int32 = j_outer * 8 + j_inner\n", + " cse_var_1: T.int32 = i_outer * 16 + i_inner\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[cse_var_1 * C_1.strides[0] + cse_var_2 * C_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]] * B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]]\n" + ] + } + ], "source": [ "n = te.var('n')\n", "m = te.var('m')\n", @@ -186,15 +320,57 @@ "\n", "tile_function = str(tvm.lower(schedule, [A, B, C], simple_mode=True))\n", "\n", - "print(split_function)" + "print(tile_function)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "46f83dd3", - "metadata": {}, - "outputs": [], + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " # from tvm.script import ir as I\n", + " # from tvm.script import tir as T\n", + " \n", + " @I.ir_module\n", + " class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + "- for i, j in T.grid(m, n):\n", + "+ for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):\n", + "+ if T.likely(i_outer * 16 + i_inner < m):\n", + "+ for j_inner in range(8):\n", + "+ if T.likely(j_outer * 8 + j_inner < n):\n", + "+ cse_var_2: T.int32 = j_outer * 8 + j_inner\n", + "+ cse_var_1: T.int32 = i_outer * 16 + i_inner\n", + "- C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + "+ C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + "? ++++++++++++\n", + "\n", + "- A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + "+ A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + "? ++++++++++++\n", + "\n", + "- B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + "+ B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + "? ++++++++++++\n", + "\n", + "- C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.strides[0] + j * A_1.strides[1]] * B_2[i * B_1.strides[0] + j * B_1.strides[1]]\n", + "+ C_2[cse_var_1 * C_1.strides[0] + cse_var_2 * C_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]] * B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]]\n" + ] + } + ], "source": [ "compute_diff(base_function, tile_function)" ] @@ -213,10 +389,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "acbf6204", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " for i_j_fused in range(m * n):\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[i_j_fused // n * C_1.strides[0] + i_j_fused % n * C_1.strides[1]] = A_2[i_j_fused // n * A_1.strides[0] + i_j_fused % n * A_1.strides[1]] * B_2[i_j_fused // n * B_1.strides[0] + i_j_fused % n * B_1.strides[1]]\n" + ] + } + ], "source": [ "n = te.var('n')\n", "m = te.var('m')\n", @@ -236,10 +436,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "bce56379", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " # from tvm.script import ir as I\n", + " # from tvm.script import tir as T\n", + " \n", + " @I.ir_module\n", + " class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + "- for i, j in T.grid(m, n):\n", + "+ for i_j_fused in range(m * n):\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + "- C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.strides[0] + j * A_1.strides[1]] * B_2[i * B_1.strides[0] + j * B_1.strides[1]]\n", + "+ C_2[i_j_fused // n * C_1.strides[0] + i_j_fused % n * C_1.strides[1]] = A_2[i_j_fused // n * A_1.strides[0] + i_j_fused % n * A_1.strides[1]] * B_2[i_j_fused // n * B_1.strides[0] + i_j_fused % n * B_1.strides[1]]\n" + ] + } + ], "source": [ "compute_diff(base_function, fuse_function)" ] @@ -256,10 +482,37 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "da3c67ce", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " blockIdx_x = T.launch_thread(\"blockIdx.x\", (m + 63) // 64)\n", + " threadIdx_x = T.launch_thread(\"threadIdx.x\", 64)\n", + " for j in range(n):\n", + " if T.likely(blockIdx_x * 64 + threadIdx_x < m):\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[(blockIdx_x * 64 + threadIdx_x) * C_1.strides[0] + j * C_1.strides[1]] = A_2[(blockIdx_x * 64 + threadIdx_x) * A_1.strides[0] + j * A_1.strides[1]] * B_2[(blockIdx_x * 64 + threadIdx_x) * B_1.strides[0] + j * B_1.strides[1]]\n" + ] + } + ], "source": [ "n = te.var('n')\n", "m = te.var('m')\n", @@ -282,10 +535,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "d86fa267", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " # from tvm.script import ir as I\n", + " # from tvm.script import tir as T\n", + " \n", + " @I.ir_module\n", + " class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + "- for i, j in T.grid(m, n):\n", + "+ blockIdx_x = T.launch_thread(\"blockIdx.x\", (m + 63) // 64)\n", + "+ threadIdx_x = T.launch_thread(\"threadIdx.x\", 64)\n", + "+ for j in range(n):\n", + "+ if T.likely(blockIdx_x * 64 + threadIdx_x < m):\n", + "- C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + "+ C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + "? ++++\n", + "\n", + "- A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + "+ A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + "? ++++\n", + "\n", + "- B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + "+ B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + "? ++++\n", + "\n", + "- C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.strides[0] + j * A_1.strides[1]] * B_2[i * B_1.strides[0] + j * B_1.strides[1]]\n", + "+ C_2[(blockIdx_x * 64 + threadIdx_x) * C_1.strides[0] + j * C_1.strides[1]] = A_2[(blockIdx_x * 64 + threadIdx_x) * A_1.strides[0] + j * A_1.strides[1]] * B_2[(blockIdx_x * 64 + threadIdx_x) * B_1.strides[0] + j * B_1.strides[1]]\n" + ] + } + ], "source": [ "compute_diff(base_function, bind_function)" ] @@ -304,10 +595,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "b213cbd6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " for j_outer, j_inner in T.grid((n + 7) // 8, 8):\n", + " if T.likely(j_outer * 8 + j_inner < n):\n", + " for i_outer, i_inner in T.grid((m + 15) // 16, 16):\n", + " if T.likely(i_outer * 16 + i_inner < m):\n", + " cse_var_2: T.int32 = j_outer * 8 + j_inner\n", + " cse_var_1: T.int32 = i_outer * 16 + i_inner\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[cse_var_1 * C_1.strides[0] + cse_var_2 * C_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]] * B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]]\n" + ] + } + ], "source": [ "n = te.var('n')\n", "m = te.var('m')\n", @@ -318,23 +638,165 @@ "\n", "schedule = te.create_schedule([C.op])\n", "\n", - "xo, xi, yo, yi = schedule[C].tile(C.op.axis[0], C.op.axis[1], x_factor=16, y_factor=8)\n", + "xo, yo, xi, yi = schedule[C].tile(C.op.axis[0], C.op.axis[1], x_factor=16, y_factor=8)\n", + "\n", + "schedule[C].reorder(yo, yi, xo, xi)\n", + "\n", + "reordered_function_1 = str(tvm.lower(schedule, [A, B, C], simple_mode=True))\n", + "\n", + "print(reordered_function_1)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "be316654-c4d9-4ddc-bdbe-0a5fb40b1d27", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " # from tvm.script import ir as I\n", + " # from tvm.script import tir as T\n", + " \n", + " @I.ir_module\n", + " class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + "+ for j_outer, j_inner in T.grid((n + 7) // 8, 8):\n", + "+ if T.likely(j_outer * 8 + j_inner < n):\n", + "- for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):\n", + "? --------- --------------\n", + "\n", + "+ for i_outer, i_inner in T.grid((m + 15) // 16, 16):\n", + "? ++++++++\n", + "\n", + "- if T.likely(i_outer * 16 + i_inner < m):\n", + "+ if T.likely(i_outer * 16 + i_inner < m):\n", + "? ++++++++\n", + "\n", + "- for j_inner in range(8):\n", + "- if T.likely(j_outer * 8 + j_inner < n):\n", + " cse_var_2: T.int32 = j_outer * 8 + j_inner\n", + " cse_var_1: T.int32 = i_outer * 16 + i_inner\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[cse_var_1 * C_1.strides[0] + cse_var_2 * C_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]] * B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]]\n" + ] + } + ], + "source": [ + "compute_diff(tile_function, reordered_function_1)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "14967874-4ee7-44ae-80a4-44dae6f02a88", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):\n", + " if T.likely(i_outer * 16 + i_inner < m):\n", + " for j_inner in range(8):\n", + " if T.likely(j_outer * 8 + j_inner < n):\n", + " cse_var_2: T.int32 = j_outer * 8 + j_inner\n", + " cse_var_1: T.int32 = i_outer * 16 + i_inner\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[cse_var_1 * C_1.strides[0] + cse_var_2 * C_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]] * B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]]\n" + ] + } + ], + "source": [ + "n = te.var('n')\n", + "m = te.var('m')\n", + "\n", + "A = te.placeholder((m, n), name='A')\n", + "B = te.placeholder((m, n), name='B')\n", + "C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name='C')\n", + "\n", + "schedule = te.create_schedule([C.op])\n", + "\n", + "xo, yo, xi, yi = schedule[C].tile(C.op.axis[0], C.op.axis[1], x_factor=16, y_factor=8)\n", "\n", "schedule[C].reorder(xo, yo, xi, yi)\n", "\n", - "reordered_function = str(tvm.lower(schedule, [A, B, C], simple_mode=True))\n", + "reordered_function_2 = str(tvm.lower(schedule, [A, B, C], simple_mode=True))\n", "\n", - "print(reordered_function)" + "print(reordered_function_2)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "46e7f114", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " # from tvm.script import ir as I\n", + " # from tvm.script import tir as T\n", + " \n", + " @I.ir_module\n", + " class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m, n = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m, n), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + "- for j_outer, j_inner in T.grid((n + 7) // 8, 8):\n", + "- if T.likely(j_outer * 8 + j_inner < n):\n", + "- for i_outer, i_inner in T.grid((m + 15) // 16, 16):\n", + "? --------\n", + "\n", + "+ for i_outer, j_outer, i_inner in T.grid((m + 15) // 16, (n + 7) // 8, 16):\n", + "? +++++++++ ++++++++++++++\n", + "\n", + "- if T.likely(i_outer * 16 + i_inner < m):\n", + "? --------\n", + "\n", + "+ if T.likely(i_outer * 16 + i_inner < m):\n", + "+ for j_inner in range(8):\n", + "+ if T.likely(j_outer * 8 + j_inner < n):\n", + " cse_var_2: T.int32 = j_outer * 8 + j_inner\n", + " cse_var_1: T.int32 = i_outer * 16 + i_inner\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " C_2[cse_var_1 * C_1.strides[0] + cse_var_2 * C_1.strides[1]] = A_2[cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1]] * B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]]\n" + ] + } + ], "source": [ - "compute_diff(tile_function, reordered_function)" + "compute_diff(reordered_function_1, reordered_function_2)" ] }, { @@ -349,10 +811,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "42283503", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m = T.int32()\n", + " A_1 = T.match_buffer(A, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " for i in range(m):\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)\n", + " for i in range(m):\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " C_2[i * C_1.strides[0]] = B_2[i * B_1.strides[0]] * T.float32(2)\n" + ] + } + ], "source": [ "m = te.var('m')\n", "\n", @@ -379,10 +867,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "id": "8ba8567f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m = T.int32()\n", + " A_1 = T.match_buffer(A, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " for i in range(m):\n", + " B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " C_2[i * C_1.strides[0]] = B_2[i * B_1.strides[0]] * T.float32(2)\n" + ] + } + ], "source": [ "m = te.var('m')\n", "\n", @@ -400,10 +913,37 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "dc3f78cd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " # from tvm.script import ir as I\n", + " # from tvm.script import tir as T\n", + " \n", + " @I.ir_module\n", + " class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle, C: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " m = T.int32()\n", + " A_1 = T.match_buffer(A, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " C_1 = T.match_buffer(C, (m,), strides=(\"stride\",), buffer_type=\"auto\")\n", + "- B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " for i in range(m):\n", + "+ B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2[i * B_1.strides[0]] = A_2[i * A_1.strides[0]] + T.float32(1)\n", + "- for i in range(m):\n", + " C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type=\"auto\")\n", + " C_2[i * C_1.strides[0]] = B_2[i * B_1.strides[0]] * T.float32(2)\n" + ] + } + ], "source": [ "compute_diff(base_op_chain, computeshift_op_chain)" ] @@ -422,10 +962,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "173fa420", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " n, m = T.int32(), T.int32()\n", + " A_1 = T.match_buffer(A, (n, m), strides=(\"stride\", \"stride\"), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (n,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " for i in range(n):\n", + " B_2 = T.Buffer((B_1.strides[0] * n,), data=B_1.data, buffer_type=\"auto\")\n", + " B_2[i * B_1.strides[0]] = T.float32(0)\n", + " for k in range(m):\n", + " A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2[i * B_1.strides[0]] = B_2[i * B_1.strides[0]] + A_2[i * A_1.strides[0] + k * A_1.strides[1]]\n" + ] + } + ], "source": [ "n = te.var(\"n\")\n", "m = te.var(\"m\")\n", @@ -472,10 +1036,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "id": "d539b840", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.handle, B: T.handle):\n", + " T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " n = T.int32()\n", + " A_1 = T.match_buffer(A, (n,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " B_1 = T.match_buffer(B, (n,), strides=(\"stride\",), buffer_type=\"auto\")\n", + " blockIdx_x = T.launch_thread(\"blockIdx.x\", (n + 63) // 64)\n", + " threadIdx_x = T.launch_thread(\"threadIdx.x\", 64)\n", + " if T.likely(blockIdx_x * 64 + threadIdx_x < n):\n", + " B_2 = T.Buffer((B_1.strides[0] * n,), data=B_1.data, buffer_type=\"auto\")\n", + " A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type=\"auto\")\n", + " B_2[(blockIdx_x * 64 + threadIdx_x) * B_1.strides[0]] = T.sigmoid(A_2[(blockIdx_x * 64 + threadIdx_x) * A_1.strides[0]])\n" + ] + } + ], "source": [ "n = te.var(\"n\")\n", "A = te.placeholder((n,), name=\"A\")\n", @@ -499,10 +1087,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "1d7dd63c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "// Function: mysigm_kernel\n", + "__kernel void mysigm_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1);\n", + "__kernel void mysigm_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1) {\n", + " if ((convert_int(get_group_id(0))) < (n >> 6)) {\n", + " B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = (1.000000e+00f / (1.000000e+00f + exp((0.000000e+00f - A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]))));\n", + " } else {\n", + " if ((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) < n) {\n", + " B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = (1.000000e+00f / (1.000000e+00f + exp((0.000000e+00f - A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]))));\n", + " }\n", + " }\n", + "}\n", + "\n", + "\n" + ] + } + ], "source": [ "fopencl = tvm.build(schedule, [A, B], \"opencl\", name=\"mysigm\")\n", "print(fopencl.imported_modules[0].get_source())" @@ -526,10 +1134,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "id": "45c12a1d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "def mylog(x):\n", " \"\"\"customized log intrinsic function\"\"\"\n", @@ -550,7 +1169,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "id": "894ccf90", "metadata": {}, "outputs": [], @@ -567,10 +1186,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "id": "8ee98094", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "// Function: mykernel_kernel\n", + "__kernel void mykernel_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1);\n", + "__kernel void mykernel_kernel(__global float* restrict A, __global float* restrict B, int n, int stride, int stride_1) {\n", + " if ((convert_int(get_group_id(0))) < (n >> 6)) {\n", + " B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = log(A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]);\n", + " } else {\n", + " if ((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) < n) {\n", + " B[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride)] = log(A[((((convert_int(get_group_id(0))) * 64) + (convert_int(get_local_id(0)))) * stride_1)]);\n", + " }\n", + " }\n", + "}\n", + "\n", + "\n" + ] + } + ], "source": [ "fopencl = tvm.build(schedule, [A, B], \"opencl\", name=\"mykernel\")\n", "print(fopencl.imported_modules[0].get_source())" @@ -588,12 +1227,230 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "id": "18da8a5b", - "metadata": { - "scrolled": false - }, - "outputs": [], + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "// tvm target: c -keys=cpu \n", + "#define TVM_EXPORTS\n", + "#include \"tvm/runtime/c_runtime_api.h\"\n", + "#include \"tvm/runtime/c_backend_api.h\"\n", + "#include \n", + "#include \n", + "#ifdef __cplusplus\n", + "extern \"C\"\n", + "#endif\n", + "TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);\n", + "#ifdef __cplusplus\n", + "extern \"C\"\n", + "#endif\n", + "TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);\n", + "#ifdef __cplusplus\n", + "extern \"C\"\n", + "#endif\n", + "TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);\n", + "#ifdef __cplusplus\n", + "extern \"C\"\n", + "#endif\n", + "TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);\n", + "#ifdef __cplusplus\n", + "extern \"C\"\n", + "#endif\n", + "TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle) {\n", + " int32_t p0_code = arg_type_ids[0];\n", + " int32_t p1_code = arg_type_ids[1];\n", + " int32_t p2_code = arg_type_ids[2];\n", + " int32_t T_relu_code = arg_type_ids[3];\n", + " void* p0 = (((TVMValue*)args)[0].v_handle);\n", + " void* p1 = (((TVMValue*)args)[1].v_handle);\n", + " void* p2 = (((TVMValue*)args)[2].v_handle);\n", + " void* T_relu = (((TVMValue*)args)[3].v_handle);\n", + " void* p0_1 = (((DLTensor*)p0)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p0_shape = (((DLTensor*)p0)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p0_strides = (((DLTensor*)p0)[0].strides);\n", + " int32_t dev_id = (((DLTensor*)p0)[0].device.device_id);\n", + " void* p1_1 = (((DLTensor*)p1)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p1_shape = (((DLTensor*)p1)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p1_strides = (((DLTensor*)p1)[0].strides);\n", + " void* p2_1 = (((DLTensor*)p2)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p2_shape = (((DLTensor*)p2)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p2_strides = (((DLTensor*)p2)[0].strides);\n", + " void* T_relu_1 = (((DLTensor*)T_relu)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_T_relu_shape = (((DLTensor*)T_relu)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_T_relu_strides = (((DLTensor*)T_relu)[0].strides);\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p0_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p1_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_p2_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_T_relu_strides == NULL)) {\n", + " }\n", + " float3 compute_global[1];\n", + " compute_global[0] = ((float3)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f));\n", + " for (int32_t k_outer = 0; k_outer < 4; ++k_outer) {\n", + " compute_global[0] = (compute_global[0] + (((float3)(((float*)p0_1)[k_outer], ((float*)p0_1)[k_outer], ((float*)p0_1)[k_outer])) * *(float3*)(((float*)p1_1) + (k_outer * 3))));\n", + " }\n", + " float3 v_ = compute_global[0] + *(float3*)(((float*)p2_1) + 0);\n", + " float3 v__1 = (float3)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f);\n", + " *(float3*)(((float*)T_relu_1) + 0) = ((v_) > (v__1) ? (v_) : (v__1));\n", + " return 0;\n", + "}\n", + "\n", + "#ifdef __cplusplus\n", + "extern \"C\"\n", + "#endif\n", + "TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle) {\n", + " int32_t p0_code = arg_type_ids[0];\n", + " int32_t p1_code = arg_type_ids[1];\n", + " int32_t p2_code = arg_type_ids[2];\n", + " int32_t T_relu_code = arg_type_ids[3];\n", + " void* p0 = (((TVMValue*)args)[0].v_handle);\n", + " void* p1 = (((TVMValue*)args)[1].v_handle);\n", + " void* p2 = (((TVMValue*)args)[2].v_handle);\n", + " void* T_relu = (((TVMValue*)args)[3].v_handle);\n", + " void* p0_1 = (((DLTensor*)p0)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p0_shape = (((DLTensor*)p0)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p0_strides = (((DLTensor*)p0)[0].strides);\n", + " int32_t dev_id = (((DLTensor*)p0)[0].device.device_id);\n", + " void* p1_1 = (((DLTensor*)p1)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p1_shape = (((DLTensor*)p1)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p1_strides = (((DLTensor*)p1)[0].strides);\n", + " void* p2_1 = (((DLTensor*)p2)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p2_shape = (((DLTensor*)p2)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p2_strides = (((DLTensor*)p2)[0].strides);\n", + " void* T_relu_1 = (((DLTensor*)T_relu)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_T_relu_shape = (((DLTensor*)T_relu)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_T_relu_strides = (((DLTensor*)T_relu)[0].strides);\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p0_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p1_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_p2_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_1_T_relu_strides == NULL)) {\n", + " }\n", + " float2 compute_global[1];\n", + " compute_global[0] = ((float2)(0.000000e+00f, 0.000000e+00f));\n", + " for (int32_t k_outer = 0; k_outer < 3; ++k_outer) {\n", + " compute_global[0] = (compute_global[0] + (((float2)(((float*)p0_1)[k_outer], ((float*)p0_1)[k_outer])) * *(float2*)(((float*)p1_1) + (k_outer * 2))));\n", + " }\n", + " float2 v_ = compute_global[0] + *(float2*)(((float*)p2_1) + 0);\n", + " float2 v__1 = (float2)(0.000000e+00f, 0.000000e+00f);\n", + " *(float2*)(((float*)T_relu_1) + 0) = ((v_) > (v__1) ? (v_) : (v__1));\n", + " return 0;\n", + "}\n", + "\n", + "#ifdef __cplusplus\n", + "extern \"C\"\n", + "#endif\n", + "TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle) {\n", + " int32_t p0_code = arg_type_ids[0];\n", + " int32_t p1_code = arg_type_ids[1];\n", + " int32_t p2_code = arg_type_ids[2];\n", + " int32_t T_relu_code = arg_type_ids[3];\n", + " void* p0 = (((TVMValue*)args)[0].v_handle);\n", + " void* p1 = (((TVMValue*)args)[1].v_handle);\n", + " void* p2 = (((TVMValue*)args)[2].v_handle);\n", + " void* T_relu = (((TVMValue*)args)[3].v_handle);\n", + " void* p0_1 = (((DLTensor*)p0)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p0_shape = (((DLTensor*)p0)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p0_strides = (((DLTensor*)p0)[0].strides);\n", + " int32_t dev_id = (((DLTensor*)p0)[0].device.device_id);\n", + " void* p1_1 = (((DLTensor*)p1)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p1_shape = (((DLTensor*)p1)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p1_strides = (((DLTensor*)p1)[0].strides);\n", + " void* p2_1 = (((DLTensor*)p2)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p2_shape = (((DLTensor*)p2)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p2_strides = (((DLTensor*)p2)[0].strides);\n", + " void* T_relu_1 = (((DLTensor*)T_relu)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_T_relu_shape = (((DLTensor*)T_relu)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_T_relu_strides = (((DLTensor*)T_relu)[0].strides);\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p0_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p1_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_p2_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_2_T_relu_strides == NULL)) {\n", + " }\n", + " float compute_global[1];\n", + " compute_global[0] = 0.000000e+00f;\n", + " for (int32_t k_outer = 0; k_outer < 2; ++k_outer) {\n", + " compute_global[0] = (compute_global[0] + (((float*)p0_1)[k_outer] * ((float*)p1_1)[k_outer]));\n", + " }\n", + " float v_ = compute_global[0] + ((float*)p2_1)[0];\n", + " ((float*)T_relu_1)[0] = ((v_) > (0.000000e+00f) ? (v_) : (0.000000e+00f));\n", + " return 0;\n", + "}\n", + "\n", + "#ifdef __cplusplus\n", + "extern \"C\"\n", + "#endif\n", + "TVM_DLL int32_t tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle) {\n", + " int32_t p0_code = arg_type_ids[0];\n", + " int32_t p1_code = arg_type_ids[1];\n", + " int32_t p2_code = arg_type_ids[2];\n", + " int32_t T_add_code = arg_type_ids[3];\n", + " void* p0 = (((TVMValue*)args)[0].v_handle);\n", + " void* p1 = (((TVMValue*)args)[1].v_handle);\n", + " void* p2 = (((TVMValue*)args)[2].v_handle);\n", + " void* T_add = (((TVMValue*)args)[3].v_handle);\n", + " void* p0_1 = (((DLTensor*)p0)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p0_shape = (((DLTensor*)p0)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p0_strides = (((DLTensor*)p0)[0].strides);\n", + " int32_t dev_id = (((DLTensor*)p0)[0].device.device_id);\n", + " void* p1_1 = (((DLTensor*)p1)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p1_shape = (((DLTensor*)p1)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p1_strides = (((DLTensor*)p1)[0].strides);\n", + " void* p2_1 = (((DLTensor*)p2)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p2_shape = (((DLTensor*)p2)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p2_strides = (((DLTensor*)p2)[0].strides);\n", + " void* T_add_1 = (((DLTensor*)T_add)[0].data);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_T_add_shape = (((DLTensor*)T_add)[0].shape);\n", + " void* tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_T_add_strides = (((DLTensor*)T_add)[0].strides);\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p0_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p1_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_p2_strides == NULL)) {\n", + " }\n", + " if (!(tvmgen_default_fused_nn_contrib_dense_pack_add_nn_relu_add_T_add_strides == NULL)) {\n", + " }\n", + " float3 compute_global[1];\n", + " compute_global[0] = ((float3)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f));\n", + " for (int32_t k_outer = 0; k_outer < 3; ++k_outer) {\n", + " compute_global[0] = (compute_global[0] + (((float3)(((float*)p0_1)[k_outer], ((float*)p0_1)[k_outer], ((float*)p0_1)[k_outer])) * *(float3*)(((float*)p1_1) + (k_outer * 3))));\n", + " }\n", + " float3 v_ = compute_global[0] + *(float3*)(((float*)p2_1) + 0);\n", + " float3 v__1 = (float3)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f);\n", + " *(float3*)(((float*)T_add_1) + 0) = (*(float3*)(((float*)p0_1) + 0) + ((v_) > (v__1) ? (v_) : (v__1)));\n", + " return 0;\n", + "}\n", + "\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_87/442107819.py:12: DeprecationWarning: legacy graph executor behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_executor.GraphModule for the new recommended usage.\n", + " graph, lib, params = relay.build(\n" + ] + } + ], "source": [ "import onnx\n", "import tvm.relay as relay\n", @@ -659,7 +1516,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.11.2" } }, "nbformat": 4,