Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add QuantConv3d and QuantConv3dTranspose #805

Merged
merged 42 commits into from
Mar 7, 2024
Merged
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
5979acb
first pass attempt at implementing QuantConv3D
costigt-dev Jan 19, 2024
9ecea98
placeholder implementation for QuantConvTranspose3d
costigt-dev Jan 19, 2024
091d70d
first implementation of QuantConvTranspose3d
costigt-dev Jan 19, 2024
44dc026
added new conv3d classes to the __init__.py
costigt-dev Jan 22, 2024
9c5fff5
added space to QuantConv3d to be close to other classes in file
costigt-dev Jan 22, 2024
6121a10
adapted conv2d to conv3d
costigt-dev Jan 22, 2024
c568895
removed is_same_padded_strided and its accompanying function as it is…
costigt-dev Jan 23, 2024
cf86c8b
formatting fixes
costigt-dev Jan 24, 2024
c62e07b
Revert "removed is_same_padded_strided and its accompanying function …
costigt-dev Jan 29, 2024
ff281b8
updated references to QuantConv and QuantConvTranpose thoughout code …
costigt-dev Jan 29, 2024
c749625
Merge branch 'dev' into feat/conv3d
costigt-dev Jan 29, 2024
a02aa23
Revert "updated references to QuantConv and QuantConvTranpose thougho…
costigt-dev Jan 30, 2024
10b331e
pre-commit hook changes
costigt-dev Jan 30, 2024
57d8d8b
updated references to QuantConv and QuantConvTranpose thoughout code …
costigt-dev Jan 29, 2024
d71762e
removing unused import
costigt-dev Jan 30, 2024
21d2d08
Merge branch 'dev' of github.com:Xilinx/brevitas into feat/conv3d
costigt-dev Jan 30, 2024
f6559cc
added condition for quantconv2d and made default case conv3d
costigt-dev Jan 31, 2024
bab6042
disable QuantConvTranspose3d in tests
costigt-dev Jan 31, 2024
22fe5c2
restored function necessary for QuantConv3d to be tested
costigt-dev Jan 31, 2024
ee01b84
restored QuantConvTranspose3d to tests
costigt-dev Jan 31, 2024
04ca06c
pre-commit changes
costigt-dev Jan 31, 2024
5ad14d4
Merge branch 'dev' of github.com:Xilinx/brevitas into feat/conv3d
costigt-dev Feb 1, 2024
382c50b
fixed missing parts in test code causing test failures
costigt-dev Feb 1, 2024
a068004
pre-commit hook changes
costigt-dev Feb 1, 2024
f30e36f
fixed typo - should be conv3d
costigt-dev Feb 7, 2024
f35cf4b
removed quantconv3d from flexml.py as it is unnecessary
costigt-dev Feb 8, 2024
8e7b647
made check for 3d version explicit instead of default case
costigt-dev Feb 8, 2024
479eb9f
added tests for conv1d,2d,3d merge batch norm
costigt-dev Feb 8, 2024
51771bc
added tests to check if avgpool is replace with quantconvs and that m…
costigt-dev Feb 8, 2024
4828856
reordered items in SUPPORTED_CONV_OP
costigt-dev Feb 12, 2024
82ea0da
collapsed into isInstance
costigt-dev Feb 12, 2024
9c40dc0
Merge branch 'master' of github.com:Xilinx/brevitas into feat/conv3d
costigt-dev Feb 20, 2024
b827129
resolved merge conflict
costigt-dev Feb 20, 2024
ea729c8
Revert "removed quantconv3d from flexml.py as it is unnecessary"
costigt-dev Feb 22, 2024
a8ad695
correct incorrect value in Kernel3dApplHandlerMixin from 4 to 3
costigt-dev Feb 27, 2024
4c76e4f
updated comments for tensor shapes
costigt-dev Mar 6, 2024
5b94089
updated convtranspose method based on PR suggestion
costigt-dev Mar 6, 2024
b607c8f
added max(...,1) to patch_size calculation
costigt-dev Mar 6, 2024
6271f40
added some basic tests for convtranspose
costigt-dev Mar 6, 2024
6a651c3
updated copyright year
costigt-dev Mar 6, 2024
f4f92e6
changed rounding from floor equivalent to ceil
costigt-dev Mar 6, 2024
3e5db0a
switched torch.ceil to math.ceil
costigt-dev Mar 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added max(...,1) to patch_size calculation
  • Loading branch information
costigt-dev committed Mar 6, 2024
commit b607c8fe8114b7004af212abdc933e557540af1c
11 changes: 6 additions & 5 deletions src/brevitas/nn/quant_convtranspose.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
patch_size = (self.kernel_size[0] // self.stride[0])
patch_size = max(self.kernel_size[0] // self.stride[0], 1)
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
Expand Down Expand Up @@ -215,8 +215,8 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
patch_size = (self.kernel_size[0] //
self.stride[0]) * (self.kernel_size[1] // self.stride[1])
patch_size = max(self.kernel_size[0] // self.stride[0], 1) * max(
self.kernel_size[1] // self.stride[1], 1)
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
Expand Down Expand Up @@ -313,8 +313,9 @@ def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
patch_size = (self.kernel_size[0] // self.stride[0]) * (
self.kernel_size[1] // self.stride[1]) * (self.kernel_size[2] // self.stride[2])
patch_size = max(self.kernel_size[0] // self.stride[0], 1) * max(
self.kernel_size[1] // self.stride[1], 1) * max(
self.kernel_size[2] // self.stride[2], 1)
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
Loading