diff --git a/src/brevitas_examples/imagenet_classification/a2q/resnet.py b/src/brevitas_examples/imagenet_classification/a2q/resnet.py index df99f8c40..1ef7d40fc 100644 --- a/src/brevitas_examples/imagenet_classification/a2q/resnet.py +++ b/src/brevitas_examples/imagenet_classification/a2q/resnet.py @@ -1,5 +1,7 @@ -# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024, Advanced Micro Devices, Inc. +# Copyright (c) 2017, liukuang +# All rights reserved. +# SPDX-License-Identifier: MIT import torch.nn as nn import torch.nn.functional as F @@ -21,6 +23,8 @@ def weight_init(layer): class BasicBlock(nn.Module): + """Basic block architecture modified for CIFAR10. + Adapted from https://github.com/kuangliu/pytorch-cifar""" expansion = 1 def __init__(self, in_planes: int, planes: int, stride: int = 1): @@ -55,7 +59,7 @@ def forward(self, x): class ResNet(nn.Module): """ ResNet architecture modified for CIFAR10. - Based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py""" + Adapted from https://github.com/kuangliu/pytorch-cifar""" def __init__(self, block_impl, num_blocks, num_classes: int = 10): super(ResNet, self).__init__()