
Commit

Update architect.py

Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>
sifa1024 authored and Chen Pin-Han committed Mar 6, 2024
1 parent 4f3ef04 commit bbb9f3e
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py
@@ -20,11 +20,12 @@ class Architect():
     """ Architect controls architecture of cell by computing gradients of alphas
     """
 
-    def __init__(self, model, w_momentum, w_weight_decay):
+    def __init__(self, model, w_momentum, w_weight_decay, device):
         self.model = model
         self.v_model = copy.deepcopy(model)
         self.w_momentum = w_momentum
         self.w_weight_decay = w_weight_decay
+        self.device = device
 
     def virtual_step(self, train_x, train_y, xi, w_optim):
         """
@@ -43,14 +44,9 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
         # Forward and calculate loss
         # Loss for train with w. L_train(w)
         loss = self.model.loss(train_x, train_y)
 
         # Compute gradient
         gradients = torch.autograd.grad(loss, self.model.getWeights())
 
-        # Check device use cuda or cpu
-        use_cuda = list(range(torch.cuda.device_count()))
-        if use_cuda:
-            print("Using CUDA")
-        device = torch.device("cuda" if use_cuda else "cpu")
 
         # Do virtual step (Update gradient)
         # Below operations do not need gradient tracking
@@ -59,9 +55,9 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
             # be iterated also.
             for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
                 m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
-                if(device == 'cuda'):
+                if(self.device == 'cuda'):
                     vw.copy_(w - torch.cuda.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
-                elif(device == 'cpu'):
+                elif(self.device == 'cpu'):
                     vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
 
         # Sync alphas
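
The virtual step applies one SGD-with-momentum update to the copied model: w' = w - xi * (m + dL/dw + weight_decay * w). Because xi is a plain Python float, the loop body can also be written with no device branch at all, letting scalar broadcasting handle both CPU and GPU tensors. A device-agnostic sketch of the same update (an alternative, not the committed code):

    # Sketch: scalar xi broadcasts on any device, so no branch is needed.
    with torch.no_grad():
        for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
            m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
            vw.copy_(w - xi * (m + g + self.w_weight_decay * w))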
@@ -80,12 +76,6 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
         # Calculate unrolled loss
         # Loss for validation with w'. L_valid(w')
         loss = self.v_model.loss(valid_x, valid_y)
-
-        # Check device use cuda or cpu
-        use_cuda = list(range(torch.cuda.device_count()))
-        if use_cuda:
-            print("Using CUDA")
-        device = torch.device("cuda" if use_cuda else "cpu")
 
         # Calculate gradient
         v_alphas = tuple(self.v_model.getAlphas())
@@ -100,9 +90,9 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
         # Update final gradient = dalpha - xi * hessian
         with torch.no_grad():
             for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
-                if(device == 'cuda'):
+                if(self.device == 'cuda'):
                     alpha.grad = da - torch.cuda.FloatTensor(xi) * h
-                elif(device == 'cpu'):
+                elif(self.device == 'cpu'):
                     alpha.grad = da - torch.FloatTensor(xi) * h
 
     def compute_hessian(self, dws, train_x, train_y):
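unrolled_backward writes the final architecture gradient, dalpha - xi * hessian, into alpha.grad; this is the second-order DARTS approximation of the validation-loss gradient after one virtual weight step. As in virtual_step, the device branch is avoidable because xi is a scalar and dalpha and hessian already live on the model's device. A device-agnostic sketch of the same update (an alternative, not the committed code):

    # Sketch: alpha, da, and h share the model's device, so multiplying by
    # the Python float xi works without any FloatTensor wrappers.
    with torch.no_grad():
        for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
            alpha.grad = da - xi * h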
