Skip to content

Commit

Permalink
fix: differentiate between scalar and array reward in plot utility
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanElsner committed Jun 21, 2024
1 parent 4ad77d1 commit 27710fd
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/dm_robotics/panda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ def __init__(self,
self.maxlines = None

def _init_buffer(self):
self.maxlines = self._rt._time_step.reward.shape[0]
if isinstance(self._rt._time_step.reward, np.ndarray):
self.maxlines = self._rt._time_step.reward.shape[0]

Check warning on line 263 in src/dm_robotics/panda/utils.py

View check run for this annotation

Codecov / codecov/patch

src/dm_robotics/panda/utils.py#L263

Added line #L263 was not covered by tests
else:
self.maxlines = 1
for _1 in range(self.maxlines):
self.y.append(deque(maxlen=self.maxlen))
self.reset_data()
Expand All @@ -269,12 +272,19 @@ def render(self, context, viewport):
return
if self.maxlines is None:
self._init_buffer()
for i, r in enumerate(self._rt._time_step.reward):
self.fig.linepnt[i] = self.maxlen
self.y[i].append(r)
self.fig.linedata[i][:self.maxlen * 2] = np.array([self.x,
self.y[i]]).T.reshape(
(-1,))
if self.maxlines > 1:
for i, r in enumerate(self._rt._time_step.reward):
self.fig.linepnt[i] = self.maxlen
self.y[i].append(r)
self.fig.linedata[i][:self.maxlen * 2] = np.array([self.x, self.y[i]

Check warning on line 279 in src/dm_robotics/panda/utils.py

View check run for this annotation

Codecov / codecov/patch

src/dm_robotics/panda/utils.py#L276-L279

Added lines #L276 - L279 were not covered by tests
]).T.reshape((-1,))
else:
r = self._rt._time_step.reward
self.fig.linepnt[0] = self.maxlen
self.y[0].append(r)
self.fig.linedata[0][:self.maxlen * 2] = np.array([self.x,
self.y[0]]).T.reshape(
(-1,))
pos = mujoco.MjrRect(2 * 300 + 5, viewport.height - 200 - 5, 300, 200)
mujoco.mjr_figure(pos, self.fig, context.ptr)

Expand Down

0 comments on commit 27710fd

Please sign in to comment.