Skip to content

Commit

Permalink
computing accuracy for node classification
Browse files Browse the repository at this point in the history
  • Loading branch information
WwZzz committed May 8, 2023
1 parent bc8c5ed commit 9fbe3b1
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,19 @@ def test(self, model, dataset, batch_size=64, num_workers=0, pin_memory=False):
dataset.change_mask_for_test()
loader = self.DataLoader([dataset.data], batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
total_loss = 0
total_correct = 0
total_num_samples = 0
for batch in loader:
tdata = self.data_to_device(batch)
outputs = model(tdata)
loss = self.criterion(outputs[tdata.test_mask], tdata.y[tdata.test_mask])
num_samples = len(tdata.x)
total_loss += num_samples * loss
total_correct += outputs[tdata.test_mask].max(1)[1].eq(tdata.y[tdata.test_mask]).sum().item()
total_num_samples += num_samples
total_loss = total_loss.item()
dataset.restore_mask()
return {'loss': total_loss / total_num_samples}
return {'loss': total_loss / total_num_samples, 'accuracy':1.0*total_correct/total_num_samples}

def data_to_device(self, data):
return data.to(self.device)
Expand Down

0 comments on commit 9fbe3b1

Please sign in to comment.