Skip to content

Commit

Permalink
add force dsp mult
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte committed May 14, 2021
1 parent 2f5fac5 commit 7078e94
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
18 changes: 9 additions & 9 deletions hls4ml/model/hls_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def config_cpp(self):
params['n_out'] = self.get_output_variable().size_cpp()
params['nzeros'] = self.get_weights('weight').nzeros
params['nonzeros'] = self.get_weights('weight').nonzeros
params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision)
params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision, force_dsp=self.get_attr('force_dsp', False))
params['strategy'] = self.get_attr('strategy')

return self._config_template.format(**params)
Expand Down Expand Up @@ -694,7 +694,7 @@ def config_cpp(self):
mult_params = self._default_config_params()
mult_params['n_in'] = self.get_attr('n_chan') * self.get_attr('filt_width')
mult_params['n_out'] = self.get_attr('n_filt')
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision)
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision, force_dsp=self.get_attr('force_dsp', False))
mult_config = self._config_template[1].format(**mult_params)

return mult_config + '\n' + conv_config
Expand Down Expand Up @@ -786,7 +786,7 @@ def config_cpp(self):
mult_params['n_in'] = self.get_attr('n_chan') * self.get_attr('filt_width')
mult_params['n_out'] = self.get_attr('n_chan')
mult_params['weight_t'] = self.get_weights('depthwise').type.name
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('depthwise').type.precision)
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('depthwise').type.precision, force_dsp=self.get_attr('force_dsp', False))
depthwise_mult_config = self._config_template[3].format(**mult_params)

# Pointwise config
Expand Down Expand Up @@ -818,7 +818,7 @@ def config_cpp(self):
mult_params['n_in'] = self.get_attr('n_chan')
mult_params['n_out'] = self.get_attr('n_filt')
mult_params['weight_t'] = self.get_weights('pointwise').type.name
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('pointwise').type.precision)
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('pointwise').type.precision, force_dsp=self.get_attr('force_dsp', False))
pointwise_mult_config = self._config_template[4].format(**mult_params)

return depthwise_mult_config + '\n' + depthwise_config + '\n' + pointwise_mult_config + '\n' + pointwise_config + '\n' + sep_config
Expand Down Expand Up @@ -891,7 +891,7 @@ def config_cpp(self):
mult_params = self._default_config_params()
mult_params['n_in'] = self.get_attr('n_chan') * self.get_attr('filt_height') * self.get_attr('filt_width')
mult_params['n_out'] = self.get_attr('n_filt')
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision)
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision, force_dsp=self.get_attr('force_dsp', False))
mult_config = self._config_template[1].format(**mult_params)

return mult_config + '\n' + conv_config
Expand Down Expand Up @@ -1042,7 +1042,7 @@ def config_cpp(self):
mult_params['n_in'] = self.get_attr('n_chan') * self.get_attr('filt_height') * self.get_attr('filt_width')
mult_params['n_out'] = self.get_attr('n_chan')
mult_params['weight_t'] = self.get_weights('depthwise').type.name
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('depthwise').type.precision)
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('depthwise').type.precision, force_dsp=self.get_attr('force_dsp', False))
depthwise_mult_config = self._config_template[3].format(**mult_params)

# Pointwise config
Expand Down Expand Up @@ -1080,7 +1080,7 @@ def config_cpp(self):
mult_params['n_in'] = self.get_attr('n_chan')
mult_params['n_out'] = self.get_attr('n_filt')
mult_params['weight_t'] = self.get_weights('pointwise').type.name
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('pointwise').type.precision)
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('pointwise').type.precision, force_dsp=self.get_attr('force_dsp', False))
pointwise_mult_config = self._config_template[4].format(**mult_params)

return depthwise_mult_config + '\n' + depthwise_config + '\n' + pointwise_mult_config + '\n' + pointwise_config + '\n' + sep_config
Expand Down Expand Up @@ -1371,7 +1371,7 @@ def function_cpp(self):
def config_cpp(self):
params = self._default_config_params()
params['n_in'] = self.get_input_variable().size_cpp()
params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('scale').type.precision)
params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('scale').type.precision, force_dsp=self.get_attr('force_dsp', False))

return self._config_template.format(**params)

Expand Down Expand Up @@ -1429,7 +1429,7 @@ def config_cpp(self):
params = self._default_config_params()
params['n_out'] = 1
params['n_in'] = inp1.shape[0]
params['product_type'] = self.model.config.backend.product_type(inp1.type.precision, inp2.type.precision)
params['product_type'] = self.model.config.backend.product_type(inp1.type.precision, inp2.type.precision, force_dsp=self.get_attr('force_dsp', False))
return self._config_template.format(**params)

class Concatenate(Merge):
Expand Down
46 changes: 40 additions & 6 deletions hls4ml/templates/vivado/nnet_utils/nnet_mult.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,52 @@ class weight_ternary : public Product<x_T, w_T, y_T>{
}
};

template<class x_T, class w_T, class y_T>
class mult_dsp : public Product<x_T, w_T, y_T>{
public:
template<typename T> constexpr
static T const& max(T const& a, T const& b) {
return a > b ? a : b;
}
static y_T product(x_T a, w_T w){
y_T res;
#pragma HLS INLINE
if(a.width < 18 && w.width < 18)
{
constexpr int bias = max<int>(18 - a.width, 18 - w.width);
ap_fixed<bias + a.width, bias + a.iwidth> a_ext = a;
ap_fixed<bias + w.width, bias + w.iwidth> w_ext = w;
res = a_ext * w_ext;
}
else if (a.width < 18)
{
constexpr int bias = 18 - a.width;
ap_fixed<bias + a.width, bias + a.iwidth> a_ext = a;
res = a_ext * w;
}
else if (w.width < 18)
{
constexpr int bias = 18 - w.width;
ap_fixed<bias + w.width, bias + w.iwidth> w_ext = w;
res = a * w_ext;
}
else
res = a * w;
return res;
}
static void limit(unsigned multiplier_limit){
#pragma HLS INLINE
#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation
}
};

template<class x_T, class w_T, class y_T>
class mult : public Product<x_T, w_T, y_T>{
public:
static y_T product(x_T a, w_T w){
// 'Normal' product
#pragma HLS INLINE
//return a * w;
ap_fixed<16,6> tmp1, tmp2, tmp3;
tmp1 = a;
tmp2 = w;
tmp3 = tmp1*tmp2;
return tmp3;
return a * w;
}
static void limit(unsigned multiplier_limit){
#pragma HLS INLINE
Expand Down
4 changes: 3 additions & 1 deletion hls4ml/templates/vivado_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def convert_precision_string(self, precision):
elif 'int' in precision:
return IntegerPrecisionType(W, signed)

def product_type(self, data_T, weight_T):
def product_type(self, data_T, weight_T, force_dsp=False):
'''
Helper function to determine which product implementation to use during inference
'''
Expand All @@ -516,6 +516,8 @@ def product_type(self, data_T, weight_T):
product = 'weight_binary'
elif isinstance(weight_T, IntegerPrecisionType) and weight_T.width == 2 and weight_T.signed:
product = 'weight_ternary'
elif force_dsp:
product = 'mult_dsp'
else:
product = 'mult'
return product
Expand Down

0 comments on commit 7078e94

Please sign in to comment.