Skip to content

Commit

Permalink
Update architect.py
Browse files Browse the repository at this point in the history
72907153+sifa1024@users.noreply.github.com

Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>
  • Loading branch information
sifa1024 authored and Chen Pin-Han committed Mar 5, 2024
1 parent 03a4001 commit 4d3ea0c
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,23 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
# Compute gradient
gradients = torch.autograd.grad(loss, self.model.getWeights())

# Check device use cuda or cpu
use_cuda = list(range(torch.cuda.device_count()))
if use_cuda:
print("Using CUDA")
device = torch.device("cuda" if use_cuda else "cpu")

# Do virtual step (Update gradient)
# Below operations do not need gradient tracking
with torch.no_grad():
# dict key is not the value, but the pointer. So original network weight have to
# be iterated also.
for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
if(device == 'cuda'):
vw.copy_(w - torch.cuda.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
elif(device == 'cpu'):
vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))

# Sync alphas
for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
Expand All @@ -72,6 +81,12 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
# Loss for validation with w'. L_valid(w')
loss = self.v_model.loss(valid_x, valid_y)

# Check device use cuda or cpu
use_cuda = list(range(torch.cuda.device_count()))
if use_cuda:
print("Using CUDA")
device = torch.device("cuda" if use_cuda else "cpu")

# Calculate gradient
v_alphas = tuple(self.v_model.getAlphas())
v_weights = tuple(self.v_model.getWeights())
Expand All @@ -85,7 +100,10 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
# Update final gradient = dalpha - xi * hessian
with torch.no_grad():
for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
alpha.grad = da - torch.FloatTensor(xi) * h
if(device == 'cuda'):
alpha.grad = da - torch.cuda.FloatTensor(xi) * h
elif(device == 'cpu'):
alpha.grad = da - torch.cpu.FloatTensor(xi) * h

def compute_hessian(self, dws, train_x, train_y):
"""
Expand Down

0 comments on commit 4d3ea0c

Please sign in to comment.