Add ResNet #22

Open · wants to merge 1 commit into main

127 changes: 127 additions & 0 deletions applications/resnet.py
@@ -0,0 +1,127 @@
import edugrad.nn as nn
from edugrad.tensor import Tensor

class BasicBlock:
expansion = 1

def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = []
if stride != 1 or in_planes != self.expansion*planes:
self.downsample = [
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
]

def __call__(self, x):
out = self.bn1(self.conv1(x)).relu()
out = self.bn2(self.conv2(out))
out = out + x.sequential(self.downsample)
out = out.relu()
return out


class Bottleneck:
    # NOTE: with the default stride_in_1x1=False this is the ResNet v1.5 variant (stride on the 3x3 conv)
expansion = 4

def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
width = int(planes * (base_width / 64.0)) * groups
        # NOTE: the original (v1) implementation places the stride at the first 1x1 convolution (self.conv1); controlled by stride_in_1x1
self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
self.downsample = []
if stride != 1 or in_planes != self.expansion*planes:
self.downsample = [
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
]

def __call__(self, x):
out = self.bn1(self.conv1(x)).relu()
out = self.bn2(self.conv2(out)).relu()
out = self.bn3(self.conv3(out))
out = out + x.sequential(self.downsample)
out = out.relu()
return out
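
# Channel bookkeeping (illustrative): Bottleneck(in_planes=64, planes=64) maps an
# (N, 64, H, W) input to (N, 256, H, W) because expansion = 4, so the 1x1 downsample
# projection above is created to bring the shortcut branch up to 256 channels.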

class ResNet:
def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
self.num = num
self.block = {
18: BasicBlock,
34: BasicBlock,
50: Bottleneck,
101: Bottleneck,
152: Bottleneck
}[num]

self.num_blocks = {
18: [2,2,2,2],
34: [3,4,6,3],
50: [3,4,6,3],
101: [3,4,23,3],
152: [3,8,36,3]
}[num]

self.in_planes = 64

self.groups = groups
self.base_width = width_per_group
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, bias=False, padding=3)  # NOTE: single-channel stem (e.g. grayscale inputs)
        #self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)  # standard 3-channel ImageNet stem
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None

def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
strides = [stride] + [1] * (num_blocks-1)
layers = []
for stride in strides:
if block == Bottleneck:
layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
else:
layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
self.in_planes = planes * block.expansion
return layers

def forward(self, x):
is_feature_only = self.fc is None
if is_feature_only: features = []
out = self.bn1(self.conv1(x)).relu()
out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
out = out.sequential(self.layer1)
if is_feature_only: features.append(out)
out = out.sequential(self.layer2)
if is_feature_only: features.append(out)
out = out.sequential(self.layer3)
if is_feature_only: features.append(out)
out = out.sequential(self.layer4)
if is_feature_only: features.append(out)
if not is_feature_only:
out = out.mean([2,3])
out = self.fc(out).log_softmax()
return out
return features

def __call__(self, x:Tensor) -> Tensor:
return self.forward(x)


ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
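
For context, a minimal usage sketch of the new module (illustrative only, not part of the diff). It assumes `Tensor.randn` is exposed with the usual shape signature and that `applications` is importable as a package; the single input channel matches the `nn.Conv2d(1, 64, ...)` stem above.

from edugrad.tensor import Tensor
from applications.resnet import ResNet, ResNet18

Tensor.training = False                  # eval mode: BatchNorm2d uses running stats
x = Tensor.randn(2, 1, 224, 224)         # the stem above expects a single input channel

logits = ResNet18(num_classes=10)(x)     # classifier head: log-probabilities of shape (2, 10)
features = ResNet(18)(x)                 # num_classes=None: returns the 4 intermediate feature maps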
7 changes: 7 additions & 0 deletions edugrad/_tensor/tensor_create.py
@@ -139,3 +139,10 @@ def scaled_uniform(*shape, **kwargs) -> Tensor:
from edugrad.tensor import Tensor

return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(shape) ** -0.5)


def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
from edugrad.tensor import Tensor

bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
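
As a quick check on the bound above: `bound = sqrt(3) * sqrt(2 / (1 + a**2)) / sqrt(fan_in)` with `fan_in = prod(shape[1:])`, which should match PyTorch's `kaiming_uniform_` in fan-in mode with a leaky-ReLU gain. A small standalone sketch (illustrative, not part of the diff; uses `math.prod`, Python 3.8+):

import math

def kaiming_bound(shape, a=0.01):
    fan_in = math.prod(shape[1:])                 # prod(shape[1:]) as above
    gain = math.sqrt(2.0 / (1 + a ** 2))          # leaky_relu gain
    return math.sqrt(3.0) * gain / math.sqrt(fan_in)

print(kaiming_bound((64, 128)))       # ~0.2165 for a Linear weight
print(kaiming_bound((64, 3, 7, 7)))   # ~0.2020 for a 7x7, 3-channel conv weight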
8 changes: 8 additions & 0 deletions edugrad/_tensor/tensor_nn.py
@@ -1,6 +1,7 @@
"""Contains typical neural network operations for processing tensors."""

from __future__ import annotations
from typing import Optional
import math

from edugrad.dtypes import dtypes
@@ -231,6 +232,13 @@ def linear(tensor: Tensor, weight: Tensor, bias: Tensor | None = None):
return x.add(bias) if bias is not None else x


def batchnorm(tensor: Tensor, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
x = (tensor - mean.reshape(shape=[1, -1, 1, 1]))
if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
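
# A worked example of the normalisation above (illustrative, not part of the diff): for a
# channel with mean=3.0, invstd=1.0, weight=2.0, bias=0.5, an input value 5.0 maps to
# ((5.0 - 3.0) * 2.0) * 1.0 + 0.5 = 4.5. Scaling by `weight` before `invstd` is equivalent
# to the usual (x - mean) * invstd * weight + bias, since both are per-channel scalings.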


def binary_crossentropy(tensor: Tensor, y: Tensor) -> Tensor:
"""Computes the binary cross-entropy loss between the predicted tensor and the target tensor.

69 changes: 69 additions & 0 deletions edugrad/nn/__init__.py
@@ -0,0 +1,69 @@
import math
from edugrad.tensor import Tensor
from edugrad.helpers import prod, all_int


class Linear:
def __init__(self, in_features, out_features, bias=True):
self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
# TODO: remove this once we can represent Tensor with int shape in typing
assert isinstance(self.weight.shape[1], int), "does not support symbolic shape"
bound = 1 / math.sqrt(self.weight.shape[1])
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

def __call__(self, x:Tensor):
return x.linear(self.weight.transpose(), self.bias)


class BatchNorm2d:
def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
else: self.weight, self.bias = None, None

self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)

def __call__(self, x:Tensor):
if Tensor.training:
# This requires two full memory accesses to x
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
batch_mean = x.mean(axis=(0,2,3))
y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
batch_var = (y*y).mean(axis=(0,2,3))
batch_invstd = batch_var.add(self.eps).pow(-0.5)

            # NOTE: most PyTorch models perform this running-stats update on every training step
if self.track_running_stats:
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() ) # noqa: E501
self.num_batches_tracked += 1
else:
batch_mean = self.running_mean
# NOTE: this can be precomputed for static inference. we expand it here so it fuses
batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()

return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
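
# NOTE: illustrative only -- a scalar sketch of the "online" (Welford) algorithm referenced
# above; it produces the batch mean/variance in a single pass over the data instead of the
# two passes used in __call__. Not used by this module.
def _welford(values):
    count, mean, m2 = 0, 0.0, 0.0
    for v in values:
        count += 1
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)
    return mean, m2 / max(count, 1)  # population variance, like batch_var above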


# TODO: these Conv lines are terrible
def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)


class Conv2d:
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
self.weight = self.initialize_weight(out_channels, in_channels, groups)
assert all_int(self.weight.shape), "does not support symbolic shape"
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None

def __call__(self, x:Tensor):
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)

def initialize_weight(self, out_channels, in_channels, groups):
return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
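
A quick shape sanity check for the layers above (illustrative; assumes the shapes print as plain tuples):

import edugrad.nn as nn

conv = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, groups=32, bias=False)
print(conv.weight.shape)                # (128, 2, 3, 3): in_channels // groups = 2
print(nn.Linear(512, 10).weight.shape)  # (10, 512); __call__ transposes it before x.linear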
23 changes: 18 additions & 5 deletions edugrad/tensor.py
@@ -11,7 +11,8 @@
from __future__ import annotations
import time
import math
from typing import ClassVar, Sequence, Any, Type
from functools import reduce
from typing import ClassVar, Sequence, Any, Type, Optional, List, Callable

import numpy as np

@@ -25,11 +26,11 @@

# fmt: off
from edugrad._tensor.tensor_create import _loadop, empty, manual_seed, rand
from edugrad._tensor.tensor_create import randn, randint, normal, uniform, scaled_uniform
from edugrad._tensor.tensor_create import randn, randint, normal, uniform, scaled_uniform, kaiming_uniform
from edugrad._tensor.tensor_create import full, zeros, ones, arange, eye, full_like, zeros_like, ones_like
from edugrad._tensor.tensor_combine_segment import cat, stack, repeat, chunk
from edugrad._tensor.tensor_reshape import reshape, expand, permute, flip, shrink, pad, pad2d, transpose, _flatten, squeeze, unsqueeze
from edugrad._tensor.tensor_nn import _pool, avg_pool2d, max_pool2d, conv2d, linear, binary_crossentropy, binary_crossentropy_logits, sparse_categorical_crossentropy
from edugrad._tensor.tensor_nn import _pool, avg_pool2d, max_pool2d, conv2d, linear, binary_crossentropy, binary_crossentropy_logits, sparse_categorical_crossentropy, batchnorm
from edugrad._tensor.tensor_index_slice import __getitem__, __setitem__, tslice, gather
from edugrad._tensor.tensor_broadcasted_binary_mlops import _broadcasted, _to_float, add, sub, mul, div, pow, matmul, maximum, minimum, where
from edugrad._tensor.tensor_reduce import _reduce, tsum, tmax, tmin, mean, std, _softmax, softmax, log_softmax, argmax, argmin
@@ -202,6 +203,11 @@ def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return scaled_uniform(*shape, **kwargs)

# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
return kaiming_uniform(*shape, a=a, **kwargs)

def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
@@ -296,6 +302,8 @@ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return max_poo
def conv2d(self, weight:Tensor, bias:Tensor | None=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
return conv2d(self, weight, bias, groups, stride, dilation, padding)

def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
return batchnorm(self, weight, bias, mean, invstd)
# ------------------------------------------------------------------------------------------------------------------

def dot(self, w:Tensor) -> Tensor:
@@ -328,12 +336,15 @@ def exp(self): return function.Exp.apply(self)
def relu(self): return function.Relu.apply(self)
def sigmoid(self): return function.Sigmoid.apply(self)
def sqrt(self): return function.Sqrt.apply(self)
def rsqrt(self): return self.reciprocal().sqrt()
def sin(self): return function.Sin.apply(self)
def cos(self): return ((math.pi/2)-self).sin()

# math functions (unary) skipped
# math functions (unary)
def reciprocal(self): return 1.0/self

# activation functions (unary) skipped

# activation functions (unary)
def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
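    # (illustrative note) elu(x) equals x for x > 0 and alpha*(exp(x) - 1) for x <= 0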
# ------------------------------------------------------------------------------------------------------------------
    # tensor_broadcasted_binary_mlops.py
@@ -392,6 +403,8 @@ def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore

def linear(self, weight:Tensor, bias:Tensor | None=None): return linear(self, weight, bias)

def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
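    # e.g. x.sequential([conv, bn, Tensor.relu]) computes Tensor.relu(bn(conv(x))); the ResNet blocks above pass their layer lists here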

def binary_crossentropy(self, y:Tensor) -> Tensor: return binary_crossentropy(self, y)

def binary_crossentropy_logits(self, y:Tensor) -> Tensor: return binary_crossentropy_logits(self, y)