Skip to content

Commit

Permalink
updating to latest (post v1.1.0) API, bumping rlt
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschmann committed Jun 28, 2024
1 parent a1ccfe4 commit c4793ae
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
2 changes: 1 addition & 1 deletion external/rl_tools
Submodule rl_tools updated 164 files
8 changes: 8 additions & 0 deletions include/my_pendulum/my_pendulum.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,20 @@ struct MyPendulumState{
T theta_dot;
};

// Observation tag type for the pendulum: carries no data, only the
// compile-time length of the observation vector it stands for.
//
// TI: integer type used for the dimension constant.
template <typename TI>
struct MyPendulumFourierObservation
{
    // Three entries, in this order: cos(theta), sin(theta), theta_dot.
    static constexpr TI DIM = 3;
};

// Compile-time description of the pendulum environment.
//
// T_SPEC must provide:
//   - T          : floating-point type used for the state
//   - TI         : integer/index type used for the dimension constants
//   - PARAMETERS : type exposed to callers as `Parameters`
//
// The struct holds no data members; it only publishes the type aliases and
// dimension constants that the free functions operating on the environment
// refer to.
template <typename T_SPEC>
struct MyPendulum: rl_tools::rl::environments::Environment{
    using SPEC = T_SPEC;
    using T = typename SPEC::T;
    using TI = typename SPEC::TI;
    using Parameters = typename SPEC::PARAMETERS;
    using State = MyPendulumState<T, TI>;
    // Observation layout: cos(theta), sin(theta), theta_dot.
    using Observation = MyPendulumFourierObservation<TI>;
    // The privileged observation is the same type as the regular one.
    using ObservationPrivileged = Observation;
    // Derived from the observation type instead of repeating the literal 3,
    // so the two constants cannot drift apart if the observation changes.
    static constexpr TI OBSERVATION_DIM = Observation::DIM;
    // A single scalar action.
    static constexpr TI ACTION_DIM = 1;
};
24 changes: 14 additions & 10 deletions include/my_pendulum/operations_generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@ T angle_normalize(const DEVICE& dev, T x){


namespace rl_tools{
// Default-parameter hook for MyPendulum.
//
// The body is intentionally empty: `parameters` is left exactly as the
// caller passed it, and neither `device` nor `env` is touched.
template <typename DEVICE, typename SPEC>
void initial_parameters(
    DEVICE& device,
    const MyPendulum<SPEC>& env,
    typename MyPendulum<SPEC>::Parameters& parameters)
{
    // Nothing to initialise.
}
// Randomised start state, four-argument form (no `parameters` argument).
//
// Writes two fields of `state`, each drawn through
// random::uniform_real_distribution with `rng`:
//   - theta     : bounds SPEC::PARAMETERS::INITIAL_STATE_MIN_ANGLE / _MAX_ANGLE
//   - theta_dot : bounds SPEC::PARAMETERS::INITIAL_STATE_MIN_SPEED / _MAX_SPEED
// `device` and `env` are not used in the body.
//
// NOTE(review): a second sample_initial_state that additionally takes
// `parameters` appears further down with an identical body; this page is a
// rendered commit diff, so this `static` form is presumably the removed
// (pre-update) version — confirm against the real header.
template<typename DEVICE, typename SPEC, typename RNG>
static void sample_initial_state(DEVICE& device, const MyPendulum<SPEC>& env, typename MyPendulum<SPEC>::State& state, RNG& rng){
state.theta = random::uniform_real_distribution(typename DEVICE::SPEC::RANDOM(), SPEC::PARAMETERS::INITIAL_STATE_MIN_ANGLE, SPEC::PARAMETERS::INITIAL_STATE_MAX_ANGLE, rng);
state.theta_dot = random::uniform_real_distribution(typename DEVICE::SPEC::RANDOM(), SPEC::PARAMETERS::INITIAL_STATE_MIN_SPEED, SPEC::PARAMETERS::INITIAL_STATE_MAX_SPEED, rng);
}
// Randomised-parameter hook; the body is empty, so `parameters` is left
// unchanged and `rng` is not consumed.
// NOTE(review): this definition uses DEVICE, SPEC and RNG but no
// `template<...>` line directly precedes it in this view; presumably it
// shares the three-parameter template header shown above — confirm.
void sample_initial_parameters(DEVICE& device, const MyPendulum<SPEC>& env, typename MyPendulum<SPEC>::Parameters& parameters, RNG& rng){ }
// Deterministic start state: sets theta to -PI (in the environment's scalar
// type SPEC::T) and theta_dot to 0. No other argument is read or written.
//
// NOTE(review): two signature lines precede the single body below — a
// `static` one without `parameters` and a non-static one that takes
// `const Parameters& parameters`. This page is a rendered commit diff, so
// these are presumably the removed and the added line respectively; only one
// of them can be present in the real header — confirm which.
template<typename DEVICE, typename SPEC>
static void initial_state(DEVICE& device, const MyPendulum<SPEC>& env, typename MyPendulum<SPEC>::State& state){
void initial_state(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::Parameters& parameters, typename MyPendulum<SPEC>::State& state){
state.theta = -rl_tools::math::PI<typename SPEC::T>;
state.theta_dot = 0;
}
// Randomised start state.
//
// Overwrites both fields of `state` with values drawn through
// random::uniform_real_distribution using `rng`; the bounds are static
// members of SPEC::PARAMETERS. `device`, `env` and `parameters` are accepted
// but not read.
template <typename DEVICE, typename SPEC, typename RNG>
void sample_initial_state(
    DEVICE& device,
    const MyPendulum<SPEC>& env,
    const typename MyPendulum<SPEC>::Parameters& parameters,
    typename MyPendulum<SPEC>::State& state,
    RNG& rng)
{
    using RANDOM_DEVICE = typename DEVICE::SPEC::RANDOM;
    using LIMITS = typename SPEC::PARAMETERS;

    // Angle first, then angular velocity: the draw order determines how the
    // RNG stream is consumed, so it is kept as is.
    state.theta = random::uniform_real_distribution(
        RANDOM_DEVICE(),
        LIMITS::INITIAL_STATE_MIN_ANGLE,
        LIMITS::INITIAL_STATE_MAX_ANGLE,
        rng);
    state.theta_dot = random::uniform_real_distribution(
        RANDOM_DEVICE(),
        LIMITS::INITIAL_STATE_MIN_SPEED,
        LIMITS::INITIAL_STATE_MAX_SPEED,
        rng);
}
template<typename DEVICE, typename SPEC, typename ACTION_SPEC, typename RNG>
typename SPEC::T step(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::State& state, const Matrix<ACTION_SPEC>& action, typename MyPendulum<SPEC>::State& next_state, RNG& rng) {
typename SPEC::T step(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::Parameters& parameters, const typename MyPendulum<SPEC>::State& state, const Matrix<ACTION_SPEC>& action, typename MyPendulum<SPEC>::State& next_state, RNG& rng) {
static_assert(ACTION_SPEC::ROWS == 1);
static_assert(ACTION_SPEC::COLS == 1);
using T = typename SPEC::T;
Expand All @@ -51,7 +55,7 @@ namespace rl_tools{
return SPEC::PARAMETERS::DT;
}
template<typename DEVICE, typename SPEC, typename ACTION_SPEC, typename RNG>
static typename SPEC::T reward(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::State& state, const Matrix<ACTION_SPEC>& action, const typename MyPendulum<SPEC>::State& next_state, RNG& rng){
typename SPEC::T reward(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::Parameters& parameters, const typename MyPendulum<SPEC>::State& state, const Matrix<ACTION_SPEC>& action, const typename MyPendulum<SPEC>::State& next_state, RNG& rng){
using T = typename SPEC::T;
T angle_norm = angle_normalize(device.math, state.theta);
T u_normalised = get(action, 0, 0);
Expand All @@ -60,8 +64,8 @@ namespace rl_tools{
return -costs;
}

template<typename DEVICE, typename SPEC, typename OBS_SPEC, typename RNG>
static void observe(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::State& state, Matrix<OBS_SPEC>& observation, RNG& rng){
template<typename DEVICE, typename SPEC, typename OBS_TYPE_SPEC, typename OBS_SPEC, typename RNG>
void observe(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::Parameters& parameters, const typename MyPendulum<SPEC>::State& state, const MyPendulumFourierObservation<OBS_TYPE_SPEC>&, Matrix<OBS_SPEC>& observation, RNG& rng){
static_assert(OBS_SPEC::ROWS == 1);
static_assert(OBS_SPEC::COLS == 3);
using T = typename SPEC::T;
Expand All @@ -70,7 +74,7 @@ namespace rl_tools{
set(observation, 0, 2, state.theta_dot);
}
// Termination predicate: unconditionally returns false, so from this
// function's point of view an episode never terminates. None of the
// arguments are read.
//
// NOTE(review): two signature lines precede the single body below — a
// `static` one without `parameters` and a non-static one that takes
// `const Parameters& parameters`. This page is a rendered commit diff, so
// these are presumably the removed and the added line respectively — confirm
// which one the real header keeps.
template<typename DEVICE, typename SPEC, typename RNG>
// `state` is declared `const ... State` without `&`, i.e. it is passed by value.
static bool terminated(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::State state, RNG& rng){
bool terminated(DEVICE& device, const MyPendulum<SPEC>& env, const typename MyPendulum<SPEC>::Parameters& parameters, const typename MyPendulum<SPEC>::State state, RNG& rng){
// This alias is not used anywhere in the body.
using T = typename SPEC::T;
return false;
}
Expand Down

0 comments on commit c4793ae

Please sign in to comment.