-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Aba
committed
Oct 26, 2023
1 parent
da04103
commit 383f235
Showing
1 changed file
with
355 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,355 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(16384, 10240, 26624)" | ||
] | ||
}, | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import numpy as np\n", | ||
"from collections import namedtuple\n", | ||
"\n", | ||
"ib = 3\n", | ||
"ROWS = 8\n", | ||
"X_PAD = 5\n", | ||
"KH_MAX = 11\n", | ||
"text = '''{\n", | ||
" {.n=8, .l=2, .kw=11, .coe=2, .coe_tl=2, .r_ll=2, .h=10, .w=8, .ci=3, .co=16, .w_kw2=3, .t=8, .p=3, .cm=1, .cm_p0=1, .w_bpt=272, .w_bpt_p0=272, .x_bpt=840, .x_bpt_p0=840, .is_bias=1, .b_offset=0, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=12, .ca_pl_scale=0, .x_header=414341061322735616, .x_header_p0=414341061322735616, .w_header=414587437826703360, .w_header_p0=414341061322735616 },\n", | ||
" {.n=8, .l=2, .kw=1, .coe=24, .coe_tl=0, .r_ll=2, .h=10, .w=8, .ci=16, .co=16, .w_kw2=8, .t=1, .p=1, .cm=20, .cm_p0=16, .w_bpt=392, .w_bpt_p0=392, .x_bpt=13320, .x_bpt_p0=13320, .is_bias=0, .b_offset=16, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=7, .ca_pl_scale=0, .x_header=8700964375684448256, .x_header_p0=8700964375684448256, .w_header=8701210795138088960, .w_header_p0=8700964375684448256 },\n", | ||
" {.n=8, .l=2, .kw=7, .coe=3, .coe_tl=4, .r_ll=2, .h=10, .w=8, .ci=16, .co=16, .w_kw2=5, .t=6, .p=8, .cm=2, .cm_p0=2, .w_bpt=344, .w_bpt_p0=344, .x_bpt=1672, .x_bpt_p0=1672, .is_bias=1, .b_offset=16, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=12, .ca_pl_scale=0, .x_header=846686625550303232, .x_header_p0=846686625550303232, .w_header=846933027824074752, .w_header_p0=846686625550303232 },\n", | ||
" \n", | ||
" {.n=8, .l=2, .kw=5, .coe=4, .coe_tl=4, .r_ll=2, .h=10, .w=8, .ci=16, .co=16, .w_kw2=6, .t=4, .p=4, .cm=4, .cm_p0=4, .w_bpt=488, .w_bpt_p0=488, .x_bpt=3336, .x_bpt_p0=3336, .is_bias=0, .b_offset=34, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=10, .ca_pl_scale=3, .x_header=1927550536119222272, .x_header_p0=1927550536119222272, .w_header=1927796989932601344, .w_header_p0=1927550536119222272 },\n", | ||
" \n", | ||
" {.n=8, .l=2, .kw=3, .coe=8, .coe_tl=8, .r_ll=2, .h=10, .w=8, .ci=16, .co=24, .w_kw2=7, .t=3, .p=3, .cm=6, .cm_p0=4, .w_bpt=440, .w_bpt_p0=296, .x_bpt=5000, .x_bpt_p0=3336, .is_bias=1, .b_offset=34, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=0, .ca_shift=12, .ca_pl_scale=0, .x_header=3008414446688141312, .x_header_p0=1855492942081294336, .w_header=3008660883321651200, .w_header_p0=1855492942081294336 },\n", | ||
" {.n=8, .l=2, .kw=1, .coe=24, .coe_tl=2, .r_ll=2, .h=10, .w=8, .ci=24, .co=50, .w_kw2=8, .t=3, .p=2, .cm=20, .cm_p0=4, .w_bpt=488, .w_bpt_p0=104, .x_bpt=16648, .x_bpt_p0=3336, .is_bias=0, .b_offset=58, .b_val_shift=0, .b_bias_shift=0, .ca_nzero=1, .ca_shift=10, .ca_pl_scale=3, .x_header=11006807384898142208, .x_header_p0=1783435348043366400, .w_header=11007053838711521280, .w_header_p0=1783435348043366400 },\n", | ||
" {.n=1, .l=1, .kw=1, .coe=24, .coe_tl=0, .r_ll=8, .h=8, .w=1, .ci=4000, .co=10, .w_kw2=1, .t=1, .p=200, .cm=20, .cm_p0=20, .w_bpt=488, .w_bpt_p0=488, .x_bpt=138, .x_bpt_p0=138, .is_bias=1, .b_offset=58, .b_val_shift=5, .b_bias_shift=0, .ca_nzero=1, .ca_shift=15, .ca_pl_scale=3, .x_header=10952754293765046272, .x_header_p0=10952754293765046272, .w_header=10952754456973803520, .w_header_p0=10952754293765046272 }\n", | ||
"};\n", | ||
"'''\n", | ||
"\n", | ||
"'''\n", | ||
"PARSE BUNDLES\n", | ||
"'''\n", | ||
"text = text.replace('\\n', '')\n", | ||
"text = text.replace(' ', '')\n", | ||
"text = text.replace(';', '')\n", | ||
"text = text.replace('.', '')\n", | ||
"text = text[2:-2] # remove brackets\n", | ||
"\n", | ||
"b_text_l = text.split('},{')\n", | ||
"bundles = []\n", | ||
"for b_text in b_text_l:\n", | ||
" b_params_l = b_text.split(',')\n", | ||
" b_params_d = {}\n", | ||
" for item in b_params_l:\n", | ||
" key, value = item.split('=')\n", | ||
" b_params_d[key] = int(value)\n", | ||
" bundles += [namedtuple('C_Bundle', b_params_d)(**b_params_d)]\n", | ||
"\n", | ||
"'''\n", | ||
"OTHER PARAMS\n", | ||
"'''\n", | ||
"ye = np.loadtxt(f\"D:/dnn-engine/test/vectors/{ib}_y_exp.txt\", dtype=np.int64)\n", | ||
"yq = np.loadtxt(f\"D:/dnn-engine/test/vectors/{ib}_y_hwc.txt\", dtype=np.int64)\n", | ||
"b = bundles[ib]\n", | ||
"\n", | ||
"if ib == len(bundles)-1:\n", | ||
" xe = yq\n", | ||
" bo = b\n", | ||
"else:\n", | ||
" xe = np.loadtxt(f\"D:/dnn-engine/test/vectors/{ib+1}_xe.txt\", dtype=np.int64)\n", | ||
" bo = bundles[ib+1]\n", | ||
"\n", | ||
"xe_arr = []\n", | ||
"xe_copy = np.copy(xe)\n", | ||
"for ixp in range(bo.p):\n", | ||
" xcm = bo.cm_p0 if ixp==0 else bo.cm\n", | ||
" size = (ROWS+X_PAD)*xcm*bo.w*bo.l*bo.n\n", | ||
" xe_sub_arr = xe_copy[0:size].reshape(bo.n,bo.l,bo.w,xcm,ROWS+X_PAD)\n", | ||
" xe_copy = xe_copy[size:]\n", | ||
" xe_arr += [xe_sub_arr]\n", | ||
"\n", | ||
"ye.size, yq.size, xe.size" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(C_Bundle(n=8, l=2, kw=5, coe=4, coe_tl=4, r_ll=2, h=10, w=8, ci=16, co=16, w_kw2=6, t=4, p=4, cm=4, cm_p0=4, w_bpt=488, w_bpt_p0=488, x_bpt=3336, x_bpt_p0=3336, is_bias=0, b_offset=34, b_val_shift=0, b_bias_shift=0, ca_nzero=1, ca_shift=10, ca_pl_scale=3, x_header=1927550536119222272, x_header_p0=1927550536119222272, w_header=1927796989932601344, w_header_p0=1927550536119222272),\n", | ||
" C_Bundle(n=8, l=2, kw=3, coe=8, coe_tl=8, r_ll=2, h=10, w=8, ci=16, co=24, w_kw2=7, t=3, p=3, cm=6, cm_p0=4, w_bpt=440, w_bpt_p0=296, x_bpt=5000, x_bpt_p0=3336, is_bias=1, b_offset=34, b_val_shift=5, b_bias_shift=0, ca_nzero=0, ca_shift=12, ca_pl_scale=0, x_header=3008414446688141312, x_header_p0=1855492942081294336, w_header=3008660883321651200, w_header_p0=1855492942081294336))" | ||
] | ||
}, | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"b, bo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"0" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"'''\n", | ||
"Python Reshape: y_engine -> y_hwc\n", | ||
"'''\n", | ||
"\n", | ||
"y1 = np.copy(ye).reshape(b.t, b.n, b.l, b.w*b.coe, ROWS)\n", | ||
"\n", | ||
"y_w_last = y1[:,:,:,-(b.kw//2+1)*b.coe:,:]\n", | ||
"y_w_last = y_w_last.reshape(b.t,b.n,b.l,b.coe,(b.kw//2+1),ROWS)\n", | ||
"y_w_last = y_w_last.transpose(0,1,2,4,3,5) #(t,l,n,(kw//2+1),coe,ROWS)\n", | ||
"y_w_last = y_w_last.reshape(b.t,b.n,b.l,(b.kw//2+1),b.coe,ROWS)\n", | ||
"y_w_last = y_w_last.reshape(b.t,b.n,b.l,(b.kw//2+1)*b.coe,ROWS)\n", | ||
"\n", | ||
"y1[:,:,:,-(b.kw//2+1)*b.coe:,:] = y_w_last\n", | ||
"\n", | ||
"y1 = y1.reshape(b.t,b.n,b.l,b.w,b.coe,ROWS)\n", | ||
"y1 = y1.transpose(1,2,5,3,0,4)\n", | ||
"y1 = y1.reshape((b.n, b.l*ROWS, b.w, b.coe*b.t))\n", | ||
"y1 = y1[:,:b.h,:,:b.co]\n", | ||
"\n", | ||
"np.sum(np.abs(y1 - yq.reshape(y1.shape)))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"0" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"'''\n", | ||
"Python Reshape: y_hwc -> x_engine (bo)\n", | ||
"'''\n", | ||
"\n", | ||
"x1 = np.copy(yq).reshape(bo.n, bo.h, bo.w, bo.ci)\n", | ||
"x1 = np.pad(x1, ((0,0),(0,ROWS*bo.l-bo.h),(0,0),(0,0))) # (XN, L*HL , XW, CI)\n", | ||
"x1 = x1.reshape (bo.n, bo.l, ROWS, bo.w, bo.ci) # (XN, L, HL, XW, CI)\n", | ||
"\n", | ||
"zeros = np.zeros((bo.n, bo.l, ROWS+X_PAD, bo.w, bo.ci),x1.dtype) # (XN,L,ROWS+X_PAD,XW,CI)\n", | ||
"zeros[:,:,:ROWS,:,:] = x1\n", | ||
"\n", | ||
"''' Fill bot rows from next '''\n", | ||
"for l in range(bo.l):\n", | ||
" if l == bo.l-1:\n", | ||
" zeros[:,l, ROWS: ,:,:] = np.zeros((bo.n,X_PAD,bo.w,bo.ci),x1.dtype)\n", | ||
" else:\n", | ||
" zeros[:,l, ROWS: ,:,:] = x1[:,l+1,:X_PAD,:,:]\n", | ||
"\n", | ||
"x1 = zeros # (XN,L,ROWS+X_PAD,XW,CI)\n", | ||
"x1 = x1.transpose(0,1,3,4,2) # (XN,L,XW,CI,ROWS+X_PAD)\n", | ||
"x1 = x1.reshape((bo.n, bo.l, bo.w, bo.ci, (ROWS+X_PAD)))\n", | ||
"\n", | ||
"x_list = []\n", | ||
"ic_left = ic_right = 0\n", | ||
"for ip in range(bo.p):\n", | ||
" CM_p = bo.cm_p0 if ip==0 else bo.cm\n", | ||
" ic_right += CM_p\n", | ||
"\n", | ||
" xp = x1[:,:,:, ic_left:ic_right, :] #(XN, L, XW, CM, (ROWS+bo.x_pad))\n", | ||
" assert xp.shape == (bo.n, bo.l, bo.w, CM_p, (ROWS+X_PAD))\n", | ||
" x_list += [xp.flatten()]\n", | ||
"\n", | ||
" ic_left = ic_right\n", | ||
"\n", | ||
"x1 = np.concatenate(x_list)\n", | ||
"\n", | ||
"np.sum(np.abs(x1 - xe))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(0, 0)" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"yq_exp = np.zeros((b.n, b.h, b.w, b.co), dtype=np.int64)\n", | ||
"ye_flat = ye.flatten()\n", | ||
"xe_gen = np.zeros(xe.size, dtype=np.int64)\n", | ||
"\n", | ||
"def write_xe_gen(val, ixp, ixn, ixl, ixw, ixcm, ir, bo, X_CMP):\n", | ||
" \n", | ||
" exp_val = xe_arr[ixp][ixn,ixl,ixw,ixcm,ir]\n", | ||
" assert val == exp_val, f\"{(val, ixp, ixn, ixl, ixw, ixcm, ir, X_CMP)=}\"\n", | ||
"\n", | ||
" pp_n2r = ixn * ( bo.l * bo.w * X_CMP * (ROWS+X_PAD)) \\\n", | ||
" + ixl * ( bo.w * X_CMP * (ROWS+X_PAD)) \\\n", | ||
" + ixw * ( X_CMP * (ROWS+X_PAD)) \\\n", | ||
" + ixcm * ( (ROWS+X_PAD)) \\\n", | ||
" + ir\n", | ||
"\n", | ||
" if ixp == 0:\n", | ||
" pp = pp_n2r\n", | ||
" else:\n", | ||
" pp = bo.n * bo.l * bo.w * bo.cm_p0 * (ROWS+X_PAD) \\\n", | ||
" +(ixp-1) * (bo.n * bo.l * bo.w * bo.cm * (ROWS+X_PAD)) \\\n", | ||
" + pp_n2r\n", | ||
" \n", | ||
" xe_gen[pp] = val\n", | ||
" \n", | ||
" assert ir < ROWS+X_PAD, f\"{ir=} >= {ROWS+X_PAD=}\"\n", | ||
" assert ixcm < X_CMP , f\"{ixcm=} >= {X_CMP=}\"\n", | ||
" assert ixw < bo.w , f\"{ixw=} >= {bo.w=}\"\n", | ||
" assert ixl < bo.l , f\"{ixl=} >= {bo.l=}\"\n", | ||
" assert ixn < bo.n , f\"{ixn=} >= {bo.n=}\"\n", | ||
" assert ixp < bo.p , f\"{ixp=} >= {bo.p=}\"\n", | ||
"\n", | ||
" assert pp < xe_gen.size, f\"{pp=} >= {xe_gen.size=}; {ir=}/{(ROWS+X_PAD)=}, {ixcm=}/{X_CMP=}, {ixw=}/{bo.w=}, {ixl=}/{bo.l=}, {ixn=}/{bo.n=}, {ixp=}/{bo.p=}; {(ROWS+X_PAD)*bo.w*bo.l*bo.n*(bo.cm_p0+(bo.p-1)*bo.cm)=}, {exp_val=}, {val=}\"\n", | ||
" return pp\n", | ||
"\n", | ||
"y_ptr = 0\n", | ||
"i_xcm = 0\n", | ||
"i_xp = 0\n", | ||
"X_CMP = bo.cm_p0 # since ixp=0\n", | ||
"\n", | ||
"for i_t in range(b.t):\n", | ||
" for i_n in range(b.n):\n", | ||
" for i_l in range(b.l):\n", | ||
" for i_w_kw2 in range(b.w_kw2):\n", | ||
"\n", | ||
" w_last = b.kw//2+1 if i_w_kw2 == b.w_kw2-1 else 1\n", | ||
"\n", | ||
" for i_coe in range (b.coe):\n", | ||
" for iw_last in range(w_last):\n", | ||
" for i_r in range(ROWS):\n", | ||
"\n", | ||
" val = ye_flat[y_ptr]\n", | ||
" y_ptr +=1\n", | ||
"\n", | ||
" i_yn = i_n\n", | ||
" i_yh = ROWS*i_l + i_r\n", | ||
" i_yw = i_w_kw2 + iw_last\n", | ||
" i_yc = b.coe*i_t + i_coe\n", | ||
"\n", | ||
" if i_yh >= b.h or i_yc >= b.co:\n", | ||
" continue\n", | ||
" \n", | ||
" yq_exp[i_yn, i_yh, i_yw, i_yc] = val\n", | ||
" \n", | ||
" '''\n", | ||
" Calc x coordinates: [p, n, l, w,cmp, r+pad]\n", | ||
" '''\n", | ||
"\n", | ||
" i_xn = i_n\n", | ||
" i_xw = i_yw\n", | ||
" i_xh = i_yh\n", | ||
" i_xr = i_xh % ROWS\n", | ||
" i_xl = i_xh // ROWS\n", | ||
"\n", | ||
" if i_yc < bo.cm_p0:\n", | ||
" i_xp = 0\n", | ||
" i_xcm = i_yc\n", | ||
" X_CMP = bo.cm_p0\n", | ||
" else:\n", | ||
" i_xp = (i_yc - bo.cm_p0) // bo.cm + 1\n", | ||
" i_xcm = (i_yc - bo.cm_p0) % bo.cm\n", | ||
" X_CMP = bo.cm\n", | ||
"\n", | ||
" ''' Write Val '''\n", | ||
" write_xe_gen(val, i_xp, i_xn, i_xl, i_xw, i_xcm, i_xr, bo, X_CMP)\n", | ||
"\n", | ||
"\n", | ||
" ''' Padding the [bottom X_PAD rows of previous block (l-1)] with [first X_PAD rows of this block (l)]'''\n", | ||
" if i_xr < X_PAD:\n", | ||
" if i_xl == 0:\n", | ||
" write_xe_gen(0, i_xp, i_xn, bo.l-1, i_xw, i_xcm, i_xr+ROWS, bo, X_CMP)\n", | ||
" # print(xp, xe_gen[xp], 'pad zero')\n", | ||
" else:\n", | ||
" write_xe_gen(val, i_xp, i_xn, i_xl-1, i_xw, i_xcm, i_xr+ROWS, bo, X_CMP)\n", | ||
" # print(xp, xe_gen[xp], 'pad val')\n", | ||
" \n", | ||
" # if (i_l == bo.l-1) and (i_xr == bo.r_ll-1):\n", | ||
" # '''Last row of last block in y, but i_xr is not complete (need zero padding for block)'''\n", | ||
" # write_xe_gen(0,i_xp, i_xn, bo.l-1, i_xw, i_xcm, i_xr+ROWS, bo, X_CMP)\n", | ||
" # i_xr += 1\n", | ||
" # else:\n", | ||
" # break\n", | ||
"\n", | ||
" \n", | ||
"\n", | ||
"np.sum(np.abs(yq_exp.flatten()-yq)), np.sum(np.abs(xe_gen - xe))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "torch", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.10" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |