Add softmax

Aba committed Nov 20, 2023
1 parent 2944622 commit d304ec8
Showing 7 changed files with 82 additions and 31 deletions.
README.md: 2 changes (1 addition & 1 deletion)
@@ -68,7 +68,7 @@ x = Bundle( core= {'type':'conv' , 'filters':8 , 'kernel_size':( 7, 7)
x = Bundle( core= {'type':'conv' , 'filters':8 , 'kernel_size':( 5, 5), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, add = {'act_str':f'quantized_bits({hw.X_BITS},0,False,True,1)'})(x, x_skip1)
x = Bundle( core= {'type':'conv' , 'filters':24, 'kernel_size':( 3, 3), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q1},)(x)
x = Bundle( core= {'type':'conv' , 'filters':10, 'kernel_size':( 1, 1), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, flatten= True)(x)
x = Bundle( core= {'type':'dense', 'units' :10, 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4})(x)
x = Bundle( core= {'type':'dense', 'units' :10, 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, softmax= True)(x)

model = QModel(inputs=x_in.raw, outputs=x)
model.compile()
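With `softmax=True` on the final dense Bundle (mirrored in run/example.py and run/param_test.py below), the exported runtime converts the last bundle's fixed-point outputs x_i into probabilities

    y_i = exp(x_i / 2^f - m) / sum_j exp(x_j / 2^f - m)

where f is the output's fractional bit count (softmax_frac) and m is a precomputed maximum (softmax_max_f) subtracted only for numerical stability. The runtime.h change below implements this and restricts the flag to the last bundle.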
deepsocflow/c/runtime.h: 43 changes (35 additions & 8 deletions)
@@ -3,6 +3,7 @@
#include <stdlib.h>
#include <limits.h>
#include <stdint.h>
#include <math.h>

#ifdef VERILATOR
#define EXT_C "C"
@@ -14,9 +15,10 @@ typedef const struct {
const int32_t n, l, kw, coe, coe_tl, r_ll, h, w, ci, co, w_kw2, t, p, cm, cm_p0, xp_words;
const int32_t w_bpt, w_bpt_p0, x_bpt, x_bpt_p0, o_words, o_bytes; // bytes per transfer
const int8_t out_buffer_idx, add_out_buffer_idx, add_in_buffer_idx;
const int8_t is_bias, is_pool, is_flatten;
const int8_t is_bias, is_pool, is_flatten, is_softmax;
const int32_t b_offset, b_val_shift, b_bias_shift;
const int8_t ca_nzero, ca_shift, ca_pl_scale, add_act_shift, pool_act_shift;
const int8_t ca_nzero, ca_shift, ca_pl_scale, add_act_shift, pool_act_shift, softmax_frac;
const float softmax_max_f;
const int32_t csh, ch, csh_shift, pkh, psh, ph, psh_shift, csw, cw, csw_shift, pkw, psw, pw, psw_shift, pool, on, oh, ow, oc;
const uint64_t x_header, x_header_p0, w_header, w_header_p0; // 64 bits (at least)
const int32_t debug_nhwc_words;
Expand All @@ -33,7 +35,7 @@ typedef struct {
int8_t w [W_BYTES ];
B_TYPE b [B_WORDS ]; // keep next to w. weights are loaded to w_ptr
int8_t x [X_BYTES_ALL ];
int32_t y [O_WORDS ];
O_TYPE y [O_WORDS ];
int32_t nhwc [NHWC_WORDS ];
int8_t debug_tiled [O_WORDS_MAX ];
int32_t debug_nhwc [NHWC_WORDS ];
@@ -275,10 +277,6 @@ extern EXT_C void load_y (uint8_t *p_done, uint8_t *pt_done_proc, const uint32_
// ------ CORE ACT ------
out_val = quant_lrelu(out_val, pb->ca_nzero, pb->ca_shift, pb->ca_pl_scale);


// ------ SOFTMAX ------


// ------ RESIDUAL ADD ---

if (pb->add_in_buffer_idx != -1) {
@@ -288,6 +286,32 @@
out_val = clip(out_val, -(1<<(X_BITS-1)), (1<<(X_BITS-1))-1);
}

// ------ SOFTMAX ------

if (pb->is_softmax) {
assert_printf (ib , !=, N_BUNDLES, "Softmax is only allowed for the last bundle.", DEBUG_INFO);

float val = (float)out_val;
val = val / (float)(1 << pb->softmax_frac);
val = val - pb->softmax_max_f;
val = (float)exp(val);

mem.y[iy_nhwc] = val;

if (i_yc == pb->co-1) {
float sum = 0;
int32_t iy_nhwc;
for (int i=0; i<pb->co; i++){
iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i, yn,yh,yw,yc, "Before softmax sum", DEBUG_INFO);
sum += mem.y[iy_nhwc];
}
for (int i=0; i<pb->co; i++){
iy_nhwc = flatten_nhwc(i_yn,i_yh,i_yw,i, yn,yh,yw,yc, "After softmax sum", DEBUG_INFO);
mem.y[iy_nhwc] = mem.y[iy_nhwc] / sum;
}
}
goto PROCESS_AND_STORE_DONE;
}

// ------ MAX/AVG POOL ---

@@ -388,7 +412,10 @@ extern EXT_C void load_y (uint8_t *p_done, uint8_t *pt_done_proc, const uint32_
sprintf(f_path_tiled, "%s/%0d_y_tiled_sim.txt", DATA_DIR, ib);
FILE *fp_tiled = fopen(f_path_tiled, "w");
for (int32_t i=0; i<pb->o_words; i++)
fprintf(fp_tiled,"%d\n", ib == N_BUNDLES-1 ? mem.y[i] : mem.debug_tiled[i]);
if (ib == N_BUNDLES-1)
if (pb->is_softmax) fprintf(fp_tiled,"%f\n", mem.y[i]);
else fprintf(fp_tiled,"%d\n", mem.y[i]);
else fprintf(fp_tiled,"%d\n", mem.debug_tiled[i]);
fclose(fp_tiled);

if (ib != N_BUNDLES-1){
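The new SOFTMAX block runs once per output pixel: each channel value is dequantized and exponentiated as it arrives, and when the last channel (i_yc == co-1) lands, the accumulated row is normalized in place. Below is a minimal numpy sketch of the same per-pixel computation, assuming `out_vals` holds one pixel's int32 accumulator values (the names here are illustrative, not from the repository):

```python
import numpy as np

def runtime_softmax(out_vals: np.ndarray, softmax_frac: int, softmax_max_f: float) -> np.ndarray:
    """Per-pixel softmax mirroring the new runtime.h hunk (sketch, not the repo code)."""
    val = out_vals.astype(np.float64) / (1 << softmax_frac)  # dequantize: int -> float
    e = np.exp(val - softmax_max_f)                          # shift by the exported max, then exp
    return e / e.sum()                                       # normalize over the co channels

# Illustrative values: 10 logits with 4 fractional bits, max_f = 60/16
probs = runtime_softmax(np.array([48, -16, 32, 0, 8, -4, 12, 60, 20, -32]), 4, 60 / 16)
assert np.isclose(probs.sum(), 1.0)
```

Since softmax is invariant to a constant shift, using the single softmax_max_f exported by bundle.py (rather than a per-pixel maximum) still yields the exact softmax; the subtraction only keeps exp() from overflowing.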
deepsocflow/py/bundle.py: 12 changes (10 additions & 2 deletions)
@@ -432,12 +432,20 @@ def apply_act(act_dict):


if self.softmax:
self.proc['int'] = self.proc['int'] / 2**self.proc['frac']
exp = np.exp(self.proc['int'] - self.proc['int'].max())
self.before_softmax = np.copy(self.proc['int'])
self.softmax_frac = self.proc['frac']
self.proc['int'] = self.proc['int'] / 2**self.softmax_frac

self.softmax_max_f = self.proc['int'].max()
exp = np.exp(self.proc['int'] - self.softmax_max_f)
self.proc['int'] = exp/np.sum(exp, axis=1)[0]

assert np.all(np.argmax(self.out['int'], axis=-1) == np.argmax(self.proc['int'], axis=-1))
else:
self.softmax_frac = 0
self.softmax_max_f = 0
assert np.all(self.proc['int'] == self.out['int']), f"Overall output of bundle {self.idx} is not a fixed point"
self.o_exp = self.proc['int']

@staticmethod
def get_compile_params(bundles, ROWS, COLS):
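bundle.py now keeps the pre-softmax integers (`before_softmax`) and records `softmax_frac` / `softmax_max_f`, so model.py can emit them into the C config and print them when verification fails. The final assert leans on softmax being order-preserving, so the quantized classifier's argmax is unchanged; a small standalone check of that property (the array values are made up for illustration):

```python
import numpy as np

logits = np.array([[3.0, -1.0, 0.5, 2.75]])   # dequantized bundle output (illustrative)
e = np.exp(logits - logits.max())             # same max-shift stabilization as bundle.py
probs = e / np.sum(e, axis=-1, keepdims=True)

# exp() is strictly increasing and the per-row denominator is shared,
# so the predicted class index cannot change.
assert np.all(np.argmax(logits, axis=-1) == np.argmax(probs, axis=-1))
```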
deepsocflow/py/model.py: 37 changes (26 additions & 11 deletions)
@@ -144,23 +144,24 @@ def export_inference(self, x, hw):
pool_type = 'POOL_AVG'
pool_act_shift = b.pool['act']['shift_bits'] if b.pool is not None else 0

out_type = 'float' if (ib == len(bundles)-1 and b.softmax) else 'int32_t'

out_buffer_idx = 1*(not out_buffer_idx) if ib != len(bundles)-1 else -1 # alternate between 0 and 1

ch.write(f" {{.n={b.r.XN:<3}, .l={b.r.XL:<3}, .kw={b.r.KW:<3}, .coe={y_coe:<3}, .coe_tl={y_coe_tl:<3}, .r_ll={y_r_ll:<3}, .h={b.r.XH:<3}, .w={b.r.XW:<3}, .ci={b.r.CI:<4}, .co={b.r.CO:<3}, .w_kw2={b.r.XW-b.r.KW//2:<3}, .t={b.r.IT:<3}, .p={b.r.CP:<3}, .cm={b.r.CM:<3}, .cm_p0={b.r.CM_0:<3}, .xp_words={xp_words:<3}, ")
ch.write( f".w_bpt={w_bpt:<5}, .w_bpt_p0={w_bpt_p0:<5}, .x_bpt={x_bpt:<5}, .x_bpt_p0={x_bpt_p0:<5}, .o_words={o_words_b:<5}, .o_bytes={o_bytes_b:<5}, ")
ch.write( f".out_buffer_idx={out_buffer_idx:<2}, .add_out_buffer_idx={add_out_buffer_idx:<2}, .add_in_buffer_idx={add_in_buffer_idx:<2}, ")
ch.write( f".is_bias={1*(b.b is not None):<3}, .is_flatten={1*b.flatten:<3}, ")
ch.write( f".is_bias={1*(b.b is not None):<3}, .is_flatten={1*b.flatten:<3}, .is_softmax={1*b.softmax:<3}, ")
ch.write( f".b_offset={b_words:<3}, .b_val_shift={b.bias_val_shift:<3}, .b_bias_shift={b.bias_b_shift:<3}, ")
ch.write( f".ca_nzero={ca_nzero:<3}, .ca_shift={ca_shift:<3}, .ca_pl_scale={ca_pl_scale:<3}, .add_act_shift={add_act_shift:<3}, .pool_act_shift={pool_act_shift:<3}, ")
ch.write( f".ca_nzero={ca_nzero:<3}, .ca_shift={ca_shift:<3}, .ca_pl_scale={ca_pl_scale:<3}, .add_act_shift={add_act_shift:<3}, .pool_act_shift={pool_act_shift:<3}, .softmax_frac={b.softmax_frac:<3}, ")
ch.write( f".softmax_max_f={b.softmax_max_f:<15}, ")
ch.write( f".csh={b.r.CSH:<3}, .ch={b.r.CYH:<3}, .csh_shift={b.r.CSH_SHIFT:<3}, .pkh={b.r.PKH:<3}, .psh={b.r.PSH:<3}, .ph={b.r.PYH:<3}, .psh_shift={b.r.PSH_SHIFT:<3}, .csw={b.r.CSW:<3}, .cw={b.r.CYW:<3}, .csw_shift={b.r.CSW_SHIFT:<3}, .pkw={b.r.PKW:<3}, .psw={b.r.PSW:<3}, .pw={b.r.PYW:<3}, .psw_shift={b.r.PSW_SHIFT:<3}, .pool={pool_type:<10}, .on={b.r.ON:<3}, .oh={b.r.OH:<3}, .ow={b.r.OW:<3}, .oc={b.r.OC:<3}, ")
ch.write( f".x_header={b.r.x_header_le_p[-1][0]:>23}u, .x_header_p0={b.r.x_header_le_p[0][0]:>23}u, .w_header={b.r.w_header_le_p[-1][0]:>23}u, .w_header_p0={b.r.x_header_le_p[0][0]:>25}u , ")
ch.write( f".debug_nhwc_words={b.oe_exp_nhwc.size:<5} }}")

b_words += b.be.size if b.b else 0
if b.idx != len(bundles)-1:
ch.write(',\n')

''' Bit masks for X_BITS '''


ch.write(f"\n}};\n\n")
@@ -181,6 +182,7 @@
ch.write(f"#define X_BYTES_ALL {x_bytes_all}\n")
ch.write(f"#define NHWC_WORDS {nhwc_words_max}\n")
ch.write(f"#define B_TYPE int{hw.B_BITS}_t\n")
ch.write(f"#define O_TYPE {out_type}\n")
ch.write(f"#define B_WORDS {b_words}\n")
ch.write(f'#define DATA_DIR "../{hw.DATA_DIR}"\n\n')

@@ -285,15 +287,28 @@ def verify_inference(self, SIM, SIM_PATH):
assert error == 0, f"Error={error}, for y_sum_sim at {b.idx=}"

''' Verify processed output HWC'''
y_nhwc_sim = np.loadtxt(f"{hw.DATA_DIR}/{b.idx}_y_nhwc_sim.txt",np.int32).reshape(b.oe_exp_nhwc.shape)
error = np.sum(np.abs(y_nhwc_sim - b.oe_exp_nhwc))
assert error == 0, f"sim:\n{y_nhwc_sim[0,:,:,0]}\n exp:\n{b.oe_exp_nhwc[0,:,:,0]}\n input:\n{b.before_pool[0,:,:,0] if b.pool else None}"
if not (ib == len(bundles)-1 and b.softmax):
y_nhwc_sim = np.loadtxt(f"{hw.DATA_DIR}/{b.idx}_y_nhwc_sim.txt",np.int32).reshape(b.oe_exp_nhwc.shape)
error = np.sum(np.abs(y_nhwc_sim - b.oe_exp_nhwc))
assert error == 0, f"sim:\n{y_nhwc_sim[0,:,:,0]}\n exp:\n{b.oe_exp_nhwc[0,:,:,0]}\n input:\n{b.before_pool[0,:,:,0] if b.pool else None}"


''' Verify tiled output'''
y_tiled_exp = b.o_int if ib == len(bundles)-1 else np.concatenate([a.flatten() for a in bundles[ib+1].xe])
y_tiled_sim = np.loadtxt(f"{hw.DATA_DIR}/{b.idx}_y_tiled_sim.txt", np.int32).reshape(y_tiled_exp.shape)
error = np.sum(np.abs(y_tiled_sim-y_tiled_exp))
assert error == 0, f"Error={error}, for y_tiled_sim at {b.idx=}"
if (ib == len(bundles)-1):
y_tiled_exp = b.o_int
if b.softmax:
y_tiled_sim = np.loadtxt(f"{hw.DATA_DIR}/{b.idx}_y_tiled_sim.txt", np.float32).reshape(y_tiled_exp.shape)
error = np.sum(np.abs(y_tiled_sim-y_tiled_exp))
assert np.allclose(y_tiled_sim, y_tiled_exp, 1e-3), f"Error={error}, for y_tiled_sim at {b.idx=}. \n y_tiled_sim=\n{y_tiled_sim} \n y_tiled_exp=\n{y_tiled_exp}\n \nbefore_softmax=\n{b.before_softmax}"
else:
y_tiled_sim = np.loadtxt(f"{hw.DATA_DIR}/{b.idx}_y_tiled_sim.txt", np.float32).reshape(y_tiled_exp.shape)
error = np.sum(np.abs(y_tiled_sim-y_tiled_exp))
assert error == 0, f"Error={error}, for y_tiled_sim at {b.idx=}"
else:
y_tiled_exp = np.concatenate([a.flatten() for a in bundles[ib+1].xe])
y_tiled_sim = np.loadtxt(f"{hw.DATA_DIR}/{b.idx}_y_tiled_sim.txt", np.float32).reshape(y_tiled_exp.shape)
error = np.sum(np.abs(y_tiled_sim-y_tiled_exp))
assert error == 0, f"Error={error}, for y_tiled_sim at {b.idx=}"

''' Verify packed output'''
if ib != len(bundles)-1:
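With a float final bundle, the tiled output comparison switches from exact integer equality to np.allclose with a 1e-3 relative tolerance, since the values go through exp/printf/parse round trips. A short sketch of the relaxed check (the arrays are illustrative, not real outputs):

```python
import numpy as np

y_tiled_sim = np.array([0.7000, 0.1000, 0.2000], dtype=np.float32)  # parsed from *_y_tiled_sim.txt
y_tiled_exp = np.array([0.70005, 0.09996, 0.19999])                 # numpy reference from the bundle

# Exact equality no longer holds for float outputs, so the verifier
# accepts small relative differences (third positional arg is rtol).
assert np.allclose(y_tiled_sim, y_tiled_exp, 1e-3)
```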
run/example.py: 2 changes (1 addition & 1 deletion)
@@ -48,7 +48,7 @@
x = Bundle( core= {'type':'conv' , 'filters':8 , 'kernel_size':( 5, 5), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, add = {'act_str':f'quantized_bits({hw.X_BITS},0,False,True,1)'})(x, x_skip1)
x = Bundle( core= {'type':'conv' , 'filters':24, 'kernel_size':( 3, 3), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q1},)(x)
x = Bundle( core= {'type':'conv' , 'filters':10, 'kernel_size':( 1, 1), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, flatten= True)(x)
x = Bundle( core= {'type':'dense', 'units' :10, 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4})(x)
x = Bundle( core= {'type':'dense', 'units' :10, 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, softmax= True)(x)

model = QModel(inputs=x_in.raw, outputs=x)
model.compile()
run/param_test.py: 2 changes (1 addition & 1 deletion)
@@ -63,7 +63,7 @@ def test_dnn_engine(PARAMS):
x = Bundle( core= {'type':'conv' , 'filters':8 , 'kernel_size':( 5, 5), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, add = {'act_str':f'quantized_bits({hw.X_BITS},0,False,True,1)'})(x, x_skip1)
x = Bundle( core= {'type':'conv' , 'filters':24, 'kernel_size':( 3, 3), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q1},)(x)
x = Bundle( core= {'type':'conv' , 'filters':10, 'kernel_size':( 1, 1), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, flatten= True)(x)
x = Bundle( core= {'type':'dense', 'units' :10, 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4})(x)
x = Bundle( core= {'type':'dense', 'units' :10, 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':q4}, softmax= True)(x)

model = QModel(inputs=x_in.raw, outputs=x)
model.compile()
(diff for the seventh changed file was not loaded)
