diff --git a/rofunc/__init__.py b/rofunc/__init__.py index eb30dabb4..a22a61fdd 100644 --- a/rofunc/__init__.py +++ b/rofunc/__init__.py @@ -19,11 +19,11 @@ from .devices import zed, xsens, optitrack, mmodal, emg from . import simulator as sim -# from .learning import ml +from .learning import ml from .learning import RofuncIL, RofuncRL -# from .planning_control import lqt, lqr +from .planning_control import lqt, lqr from .utils import visualab, robolab, logger, oslab from .utils.datalab import primitive, data_generator from . import config -# from .learning.ml import tpgmm, gmr, tpgmr +from .learning.ml import tpgmm, gmr, tpgmr diff --git a/rofunc/learning/ml/tpgmr.py b/rofunc/learning/ml/tpgmr.py index ec466b135..d1e1237bf 100644 --- a/rofunc/learning/ml/tpgmr.py +++ b/rofunc/learning/ml/tpgmr.py @@ -35,7 +35,7 @@ def __init__(self, demos_x, task_params, nb_states: int = 4, reg: float = 1e-3, :param plot: whether to plot the result """ super().__init__(demos_x, task_params, nb_states=nb_states, reg=reg, plot=plot) - self.gmr = rf.learning.gmr.GMR(self.demos_x, self.demos_dx, self.demos_xdx, nb_states=nb_states, reg=reg, + self.gmr = rf.gmr.GMR(self.demos_x, self.demos_dx, self.demos_xdx, nb_states=nb_states, reg=reg, plot=False) def gmm_learning(self):