Skip to content

Commit

Permalink
docs, notes, names
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed May 20, 2024
1 parent 85112ed commit 29692d2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
7 changes: 6 additions & 1 deletion dominoes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@
"\n",
"# TSP Distance Traveled:\n",
"# - Explicitly measure the distance traveled by the agent in the TSP task\n",
"# - Compare to Held-Karp Solution\n"
"# - Compare to Held-Karp Solution\n",
"\n",
"\n",
"# TODO ASAP!!!!!!\n",
"# - Get test result plots in there for good plotting, then start running experiments with different parameters\n",
"# so you can save / see the results!!!\n"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions dominoes/experiments/ptr_arch_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..networks import get_pointer_network, get_pointer_methods, get_pointer_arguments
from .base import Experiment
from . import arglib
from ..plotting import plot_basic_results
from ..plotting import plot_train_results


class PointerArchitectureComparison(Experiment):
Expand Down Expand Up @@ -117,4 +117,4 @@ def plot(self, results):
main plotting loop
"""
pointer_methods = self.pointer_methods()
plot_basic_results(results["train_results"], pointer_methods, train=True)
plot_train_results(self, results["train_results"], pointer_methods, name="training")
9 changes: 9 additions & 0 deletions dominoes/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@


def plot_train_results(exp, results, labels, name="train"):
"""
Simplest plot method for plotting loss/reward (or something else eventually) across epochs
Assumes that the loss/rewards are divided into len(labels) types across the first dimension
(see below for variable names), will make a plot of the mean across epochs for each type and
label it accordingly.
The experiment object passed in as the first argument determines if the plot is saved or shown.
"""
num_types = len(labels)

if "loss" in results and results["loss"] is not None:
Expand Down

0 comments on commit 29692d2

Please sign in to comment.