Skip to content

Commit

Permalink
WIP: Translate filterpy approach in rust
Browse files Browse the repository at this point in the history
  • Loading branch information
Edoalto-metis authored and Edoalto-metis committed May 22, 2024
1 parent 4354141 commit fd6449d
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 42 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ edition = "2021"

[dependencies]
rumqttc = "0.24"
eskf = "0.2"
nalgebra = "0.25"
nalgebra = "=0.32.5"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
rand = "0.8.5"
kalmanfilt = "0.2.4"
133 changes: 93 additions & 40 deletions src/bin/sailtrack-kalman.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use eskf::ESKF;
use kalmanfilt::kalman::kalman_filter::KalmanFilter as Kalman;
use rand::Rng;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

use nalgebra::{Matrix3, Vector3};
use nalgebra::{OMatrix, OVector, U3, U6};
use rumqttc::{Client, Event, Incoming, MqttOptions, QoS};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -76,15 +76,14 @@ struct Boat {

#[derive(Debug, Clone, Copy)]
struct Measure {
vel: Vector3<f32>,
vel_variance: Matrix3<f32>,
meas: OVector<f32, U6>,
meas_variance: OMatrix<f32, U6, U6>,
new_measure: bool,
}

#[derive(Debug, Clone, Copy)]
struct Input {
acceleration: Vector3<f32>,
rotation: Vector3<f32>,
acceleration: OVector<f32, U3>,
new_input: bool,
}

Expand All @@ -101,21 +100,43 @@ fn acquire_lock<T>(mutex: &Arc<Mutex<T>>) -> std::sync::MutexGuard<T> {
}

// Function to compute the measure for the Kalman filter from the raw GPS data
fn get_measure_forom_gps(measure: &Gps) -> Measure {
let vel = Vector3::new(
measure.vel_n * f32::powf(10.0, -3.0),
measure.vel_e * f32::powf(10.0, -3.0),
measure.vel_d * f32::powf(10.0, -3.0),
);
fn get_measure_forom_gps(gps_data: &Gps, reference: &Gps) -> Measure {
let meas_vec = vec![
(gps_data.lat * f32::powf(10.0, -7.0) - reference.lat * f32::powf(10.0, -7.0))
* EARTH_CIRCUMFERENCE_METERS
/ 360.0,
(gps_data.lon * f32::powf(10.0, -7.0) - reference.lon * f32::powf(10.0, -7.0))
* EARTH_CIRCUMFERENCE_METERS
* LAT_FACTOR
/ 360.0,
gps_data.h_msl * f32::powf(10.0, -3.0) - reference.h_msl * f32::powf(10.0, -3.0),
gps_data.vel_n * f32::powf(10.0, -3.0),
gps_data.vel_e * f32::powf(10.0, -3.0),
gps_data.vel_d * f32::powf(10.0, -3.0),
];

let meas:OVector<f32, U6> = OVector::<f32, U6>::from_iterator(meas_vec.into_iter());


let accuracy_penality_factor = 100.0;
let mut vel_variance =
Matrix3::identity() * 0.25 * (measure.s_acc * f32::powf(10.0, -3.0)).powf(2.0);
if measure.fix_type != 3 {
vel_variance *= accuracy_penality_factor;
let mut meas_variance: OMatrix<f32, U6, U6> = OMatrix::zeros_generic(U6, U6);
let acc_scaling = f32::powf(10.0, -3.0);

for i in 0..5 {
let acc_value = match i {
0 | 1 => gps_data.h_acc,
2 => gps_data.v_acc,
_ => gps_data.s_acc,
};
meas_variance[(i, i)] = 0.25 * acc_value.powf(2.0) * acc_scaling;
if gps_data.fix_type != 3 {
meas_variance[(i, i)] *= accuracy_penality_factor;
}
}

Measure {
vel,
vel_variance,
meas,
meas_variance,
new_measure: true,
}
}
Expand All @@ -140,16 +161,15 @@ fn wait_for_fix_tipe(gps_ref_arc: &Arc<Mutex<Gps>>) -> bool {
}

fn on_message_imu(message: Imu, input: &Arc<Mutex<Input>>) {
let accel = Vector3::new(
let accel_vec = vec![
message.linear_accel.x,
message.linear_accel.y,
message.linear_accel.z,
);
let orientation = Vector3::new(message.euler.x, -message.euler.y, 360.0 - message.euler.z);
];
let accel = OVector::<f32, U3>::from_iterator(accel_vec.into_iter());
let mut input_lock = acquire_lock(input);
input_lock.new_input = true;
input_lock.acceleration = accel;
input_lock.rotation = orientation;
drop(input_lock);
}

Expand All @@ -160,22 +180,30 @@ fn on_message_gps(message: Gps, gps_ref_arc: &Arc<Mutex<Gps>>, measure_arc: &Arc
if gps_ref_lock.fix_type != 3 {
*gps_ref_lock = message;
}
let measure = get_measure_forom_gps(&message);
let measure = get_measure_forom_gps(&message, &gps_ref_lock);
*measure_lock = measure;
drop(measure_lock);
drop(gps_ref_lock);
}

// Kalman predict function on new input
fn filter_predict(kalman: &mut ESKF, input: &mut Input, filter_ts: Duration) {
fn filter_predict(kalman: &mut Kalman<f32, U6, U6, U3>, input: &mut Input, filter_ts: Duration) {
kalman.predict(input.acceleration, input.rotation, filter_ts);
input.new_input = false;
}

// Kalman update function on new measure
fn filter_update(kalman: &mut ESKF, measure: &mut Measure) {
kalman
.observe_velocity(measure.vel, measure.vel_variance)
.observe_position_velocity2d(
measure.position,
measure.pos_variance,
measure.velocity_xy,
measure.vel_variance,
)
.unwrap();
kalman
.observe_height(measure.velocity_z, measure.vel_z_variance)
.unwrap();
measure.new_measure = false;
}
Expand All @@ -198,8 +226,7 @@ fn main() {
let filter_ts = Duration::from_millis(KALMAN_SAMPLE_TIME_MS);

let input = Input {
acceleration: Vector3::new(0.0, 0.0, 0.0),
rotation: Vector3::new(0.0, 0.0, 0.0),
acceleration: OVector::<f32, U3>::zeros(),
new_input: false,
};

Expand All @@ -221,16 +248,42 @@ fn main() {
};

let measure = Measure {
vel: Vector3::zeros(),
vel_variance: Matrix3::zeros(),
meas: OVector::<f32, U6>::zeros(),
meas_variance: OMatrix::<f32, U6, U6>::identity(),
new_measure: false,
};

// Creating ESKF object
let filter = eskf::Builder::new()
.acceleration_variance(0.01) // FIXME
.rotation_variance(0.01) // FIXME
.build();
let w_std = 0.01;
let sample_time = filter_ts.as_secs_f32();
let transition_mtx = OMatrix::<f32, U6,
U6>::from_column_slice(&[
1.0, 0.0, 0.0, sample_time, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0, sample_time, 0.0,
0.0, 0.0, 1.0, 0.0, 0.0, sample_time,
0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
]);
let input_mtx = OMatrix::<f32, U6, U3>::from_row_slice(&[
sample_time.powi(2) / 2.0, 0.0, 0.0,
0.0, sample_time.powi(2) / 2.0, 0.0,
0.0, 0.0, sample_time.powi(2) / 2.0,
sample_time, 0.0, 0.0,
0.0, sample_time, 0.0,
0.0, 0.0, sample_time
]);
let output_mtx = OMatrix::<f32, U6, U6>::identity();
let noise_state_cov = input_mtx*input_mtx.transpose()*w_std;
let noise_meas_cov = OMatrix::<f32, U6, U6>::identity();

let mut filter = Kalman::<f32, U6, U6, U3>::default();

filter.F = transition_mtx;
filter.H = output_mtx;
filter.B = Some(input_mtx);
filter.Q = noise_state_cov;
filter.R = noise_meas_cov;

// Defining Mutex for thread share
let gps_ref_mutex = Arc::new(Mutex::new(gps_ref));
Expand All @@ -239,10 +292,8 @@ fn main() {
let filter_mutex = Arc::new(Mutex::new(filter));

// TODO: Add username and password authentication
// let mut mqqt_opts = MqttOptions::new("sailtrack-kalman", "192.168.42.1", 1883);
// mqqt_opts.set_credentials("mosquitto", "sailtrack");

let mqqt_opts = MqttOptions::new("sailtrack-kalman", "localhost", 1883);
let mut mqqt_opts = MqttOptions::new("sailtrack-kalman", "192.168.42.1", 1883);
mqqt_opts.set_credentials("mosquitto", "sailtrack");

let (client, mut connection) = Client::new(mqqt_opts, 10);
client.subscribe("sensor/gps0", QoS::AtMostOnce).unwrap();
Expand Down Expand Up @@ -315,12 +366,13 @@ fn main() {
thread::sleep(filter_ts - elapsed);
}
});

//MQTT publish loop
let gps_ref_clone = Arc::clone(&gps_ref_mutex);
let filter_clone = Arc::clone(&filter_mutex);
loop {
// Check if the GPS fix has been obtained
// wait_for_fix_tipe(&gps_ref_clone);
wait_for_fix_tipe(&gps_ref_clone);
let filter_lock = acquire_lock(&filter_clone);
let position = filter_lock.position;
let velocity = filter_lock.velocity;
Expand All @@ -332,7 +384,7 @@ fn main() {
let pitch = euler_orient.1;
let heading = euler_orient.2;

let sog = (velocity.x.powi(2) + velocity.y.powi(2)).sqrt() * MPS_TO_KNTS_MULTIPLIER;
let sog = velocity.norm() * MPS_TO_KNTS_MULTIPLIER;
let mut cog = heading;
let mut drift = 0.0;
if sog > 1.0 {
Expand All @@ -355,6 +407,7 @@ fn main() {
+ gps_ref_lock.lon * f32::powf(10.0, -7.0);
let altitude = position.z + gps_ref_lock.h_msl * f32::powf(10.0, -3.0);
drop(gps_ref_lock);

let message = Boat {
lon,
lat,
Expand All @@ -377,4 +430,4 @@ fn main() {
.unwrap();
thread::sleep(Duration::from_millis(1000 / MQTT_PUBLISH_FREQ_HZ));
}
}
}

0 comments on commit fd6449d

Please sign in to comment.