Skip to content

Commit

Permalink
Merge pull request yandexdataschool#94 from poedator/patch_device_in_…
Browse files Browse the repository at this point in the history
…test

added `device` param in `print_metrics()`
  • Loading branch information
justheuristic authored Oct 28, 2022
2 parents 664803e + ee3e379 commit 1339f19
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions week02_classification/seminar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,11 @@
"metadata": {},
"outputs": [],
"source": [
"def print_metrics(model, data, batch_size=BATCH_SIZE, name=\"\", **kw):\n",
"def print_metrics(model, data, batch_size=BATCH_SIZE, name=\"\", device=torch.device('cpu'), **kw):\n",
" squared_error = abs_error = num_samples = 0.0\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for batch in iterate_minibatches(data, batch_size=batch_size, shuffle=False, **kw):\n",
" for batch in iterate_minibatches(data, batch_size=batch_size, shuffle=False, device=device, **kw):\n",
" batch_pred = model(batch)\n",
" squared_error += torch.sum(torch.square(batch_pred - batch[TARGET_COLUMN]))\n",
" abs_error += torch.sum(torch.abs(batch_pred - batch[TARGET_COLUMN]))\n",
Expand Down

0 comments on commit 1339f19

Please sign in to comment.