From 9fbe3b1a4169bf36e4b4e2fa70223a7ecf2c97b6 Mon Sep 17 00:00:00 2001 From: zzz Date: Tue, 9 May 2023 00:04:06 +0800 Subject: [PATCH] computing accuracy for node classification --- flgo/benchmark/toolkits/graph/node_classification/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flgo/benchmark/toolkits/graph/node_classification/__init__.py b/flgo/benchmark/toolkits/graph/node_classification/__init__.py index ffeb748c..3d52b08b 100644 --- a/flgo/benchmark/toolkits/graph/node_classification/__init__.py +++ b/flgo/benchmark/toolkits/graph/node_classification/__init__.py @@ -192,6 +192,7 @@ def test(self, model, dataset, batch_size=64, num_workers=0, pin_memory=False): dataset.change_mask_for_test() loader = self.DataLoader([dataset.data], batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory) total_loss = 0 + total_correct = 0 total_num_samples = 0 for batch in loader: tdata = self.data_to_device(batch) @@ -199,10 +200,11 @@ def test(self, model, dataset, batch_size=64, num_workers=0, pin_memory=False): loss = self.criterion(outputs[tdata.test_mask], tdata.y[tdata.test_mask]) num_samples = len(tdata.x) total_loss += num_samples * loss + total_correct += outputs[tdata.test_mask].max(1)[1].eq(tdata.y[tdata.test_mask]).sum().item() total_num_samples += num_samples total_loss = total_loss.item() dataset.restore_mask() - return {'loss': total_loss / total_num_samples} + return {'loss': total_loss / total_num_samples, 'accuracy':1.0*total_correct/total_num_samples} def data_to_device(self, data): return data.to(self.device)