-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrained_policy.txt
31 lines (23 loc) · 983 Bytes
/
trained_policy.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
def solve_cartpole(cart_position: float, cart_velocity: float, pole_angle: float, pole_angular_velocity: float) -> int:
    """Pick a discrete action (0 or 1) from two PD control terms.

    A PD term is computed for the pole (from its angle and angular velocity)
    and another for the cart (from its position and velocity). Both treat the
    raw state value as the error, i.e. the implied set-point is zero.

    Args:
        cart_position: Cart position; used as the cart's proportional error.
        cart_velocity: Cart velocity; used as the cart's derivative error.
        pole_angle: Pole angle; used as the pole's proportional error.
        pole_angular_velocity: Pole angular velocity; used as the pole's
            derivative error.

    Returns:
        0 when the pole term is strictly less than the negated cart term,
        otherwise 1.
        NOTE(review): presumably 0 = push left and 1 = push right, as in the
        Gym CartPole environments -- confirm against the caller.
    """
    # PD gains: proportional (kp) and derivative (kd) for each subsystem.
    kp_pole = 0.5
    kd_pole = 0.1
    kp_cart = 0.2
    kd_cart = 0.05

    pole_control_signal = kp_pole * pole_angle + kd_pole * pole_angular_velocity
    cart_control_signal = kp_cart * cart_position + kd_cart * cart_velocity

    # Earlier revisions had separate `elif pole > cart` and `else` arms that
    # both produced 1, so only this single comparison ever decided the result.
    # Ties, and comparisons involving NaN (always false), fall through to 1.
    if pole_control_signal < -cart_control_signal:
        return 0
    return 1