Skip to content

Commit

Permalink
Add field frame estimator
Browse files Browse the repository at this point in the history
Estimates the transform between the odometry frame and the field
frame, given drive poses in the odometry frame and vision poses
in the field frame.

Topic: field-frame-estimator
Reviewers: saraansh, rithvik
  • Loading branch information
plusparth authored and sswadkar committed Feb 17, 2024
1 parent 8fc3a80 commit 2ac28f7
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/main/kotlin/com/team4099/lib/math/GeomUtil.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import org.team4099.lib.geometry.Pose2d
import org.team4099.lib.geometry.Transform2d
import org.team4099.lib.geometry.Translation2d
import org.team4099.lib.geometry.Twist2d
import org.team4099.lib.units.base.meters
import org.team4099.lib.units.derived.degrees
import org.team4099.lib.units.derived.radians

/**
* Multiplies a twist by a scaling factor
Expand All @@ -27,3 +29,13 @@ fun multiplyTwist(twist: Twist2d, factor: Double): Twist2d {
fun Pose2d.purelyTranslateBy(translation2d: Translation2d): Pose2d {
return this.transformBy(Transform2d(translation2d.rotateBy(-this.rotation), 0.0.degrees))
}

/**
* Returns the transform between the frame origin of the pose and the current pose state -- for
* example, if the pose describes the pose of the robot in the odometry frame, the returned
* transform will be the transform between the odometry frame and the robot frame.
* @return
*/
fun Pose2d.asTransform2d(): Transform2d {
return Transform2d(Pose2d(0.meters, 0.meters, 0.radians), this)
}
223 changes: 223 additions & 0 deletions src/main/kotlin/com/team4099/robot2023/util/FieldFrameEstimator.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
package com.team4099.robot2023.util

import com.team4099.lib.hal.Clock
import com.team4099.lib.math.asTransform2d
import edu.wpi.first.math.Matrix
import edu.wpi.first.math.Nat
import edu.wpi.first.math.VecBuilder
import edu.wpi.first.math.numbers.N1
import edu.wpi.first.math.numbers.N3
import org.littletonrobotics.junction.Logger
import org.team4099.lib.geometry.Pose2d
import org.team4099.lib.geometry.Pose2dWPILIB
import org.team4099.lib.geometry.Transform2d
import org.team4099.lib.geometry.Translation2d
import org.team4099.lib.geometry.Twist2d
import org.team4099.lib.units.base.Time
import org.team4099.lib.units.base.inMeters
import org.team4099.lib.units.base.inSeconds
import org.team4099.lib.units.base.meters
import org.team4099.lib.units.base.seconds
import org.team4099.lib.units.derived.inRadians
import org.team4099.lib.units.derived.radians
import java.util.NavigableMap
import java.util.TreeMap
import kotlin.Comparator
import kotlin.collections.ArrayList

class FieldFrameEstimator(stateStdDevs: Matrix<N3?, N1?>) {
// Maintains the state of the field frame transform for the period of time before the currently
// tracked history
private var baseOdometryTField: Transform2d =
Transform2d(Translation2d(0.meters, 0.meters), 0.radians)

// Maintains the latest state of the field frame transform, including the currently tracked
// history
private var odometryTField: Transform2d =
Transform2d(Translation2d(0.meters, 0.meters), 0.radians)

private val updates: NavigableMap<Time, PoseUpdate> = TreeMap()
private val q: Matrix<N3?, N1?> = Matrix(Nat.N3(), Nat.N1())

/** Returns the latest robot pose based on drive and vision data. */
fun getLatestOdometryTField(): Transform2d {
return odometryTField
}

/** Resets the field frame transform to a known pose. */
fun resetFieldFrameFilter(transform: Transform2d) {
baseOdometryTField = transform
updates.clear()
update()
}

/** Records a new drive movement. */
fun addDriveData(timestamp: Time, odomTRobot: Pose2d) {
updates[timestamp] = PoseUpdate(odomTRobot, ArrayList<VisionUpdate>())
update()
}

/** Records a new set of vision updates. */
fun addVisionData(visionData: List<TimestampedVisionUpdate>) {
for (timestampedVisionUpdate in visionData) {
val timestamp: Time = timestampedVisionUpdate.timestamp
val visionUpdate =
VisionUpdate(
timestampedVisionUpdate.fieldTRobot,
timestampedVisionUpdate.stdDevs,
timestampedVisionUpdate.fromVision
)
if (updates.containsKey(timestamp)) {
// There was already an update at this timestamp, add to it
val oldVisionUpdates: ArrayList<VisionUpdate> = updates[timestamp]!!.visionUpdates
oldVisionUpdates.add(visionUpdate)
oldVisionUpdates.sortWith(VisionUpdate.compareDescStdDev)
} else {
// Insert a new update
val prevUpdate = updates.floorEntry(timestamp)
val nextUpdate = updates.ceilingEntry(timestamp)
if (prevUpdate == null || nextUpdate == null) {
// Outside the range of existing data
return
}

// Create partial twists (prev -> vision, vision -> next)
val prevToVisionTwist =
multiplyTwist(
prevUpdate.value.odomTRobot.log(nextUpdate.value.odomTRobot),
(timestamp - prevUpdate.key) / (nextUpdate.key - prevUpdate.key)
)

// Add new pose updates
val newVisionUpdates = ArrayList<VisionUpdate>()
newVisionUpdates.add(visionUpdate)
newVisionUpdates.sortWith(VisionUpdate.compareDescStdDev)
updates[timestamp] =
PoseUpdate(prevUpdate.value.odomTRobot.exp(prevToVisionTwist), newVisionUpdates)
}
}

// Recalculate latest pose once
update()
}

/** Clears old data and calculates the latest pose. */
private fun update() {
// Clear old data and update base pose
// NOTE(parth): We need to maintain the history so that when vision updates come in, they have
// some buffer to interpolate within.
while (updates.size > 1 && updates.firstKey() < Clock.fpgaTime - HISTORY_LENGTH) {
val (_, value) = updates.pollFirstEntry()
baseOdometryTField = value.apply(baseOdometryTField, q)
}

// Update latest pose
odometryTField = baseOdometryTField
for (updateEntry in updates.entries) {
odometryTField = updateEntry.value.apply(odometryTField, q)
}

for (update in updates) {
if (update.value.visionUpdates.size > 0 && update.value.visionUpdates[0].fromVision) {
Logger.recordOutput("Vision/Buffer/Vision", update.key.inSeconds)

Logger.recordOutput(
"Vision/Buffer/VisionPose",
Pose2dWPILIB.struct,
update.value.visionUpdates[0].fieldTRobot.pose2d
)
} else {
Logger.recordOutput("Vision/Buffer/Drivetrain", update.key.inSeconds)
}
}
}

/**
* Represents a sequential update to a pose estimate, with a twist (drive movement) and list of
* vision updates.
*/
private class PoseUpdate(val odomTRobot: Pose2d, val visionUpdates: ArrayList<VisionUpdate>) {
fun apply(previousOdomTField: Transform2d, q: Matrix<N3?, N1?>): Transform2d {
var currentOdomTField = previousOdomTField

// Apply vision updates
for (visionUpdate in visionUpdates) {
// Calculate Kalman gains based on std devs
// (https://github.com/wpilibsuite/allwpilib/blob/main/wpimath/src/main/java/edu/wpi/first/math/estimator/)
val visionK: Matrix<N3, N3> = Matrix(Nat.N3(), Nat.N3())
val r = DoubleArray(3)
for (i in 0..2) {
r[i] = visionUpdate.stdDevs.get(i, 0) * visionUpdate.stdDevs.get(i, 0)
}
for (row in 0..2) {
if (q.get(row, 0) === 0.0) {
visionK.set(row, row, 0.0)
} else {
visionK.set(
row, row, q.get(row, 0) / (q.get(row, 0) + Math.sqrt(q.get(row, 0) * r[row]))
)
}
}

// Calculate odom_T_field from this update's vision pose
val odomTVisionField =
odomTRobot.asTransform2d() + visionUpdate.fieldTRobot.asTransform2d().inverse()

// Calculate twist between current field frame transform and latest vision update
val fieldTVisionField = currentOdomTField.inverse() + odomTVisionField
val visionTwist = fieldTVisionField.log()

// Multiply by Kalman gain matrix
val twistMatrix =
visionK.times(
VecBuilder.fill(
visionTwist.dx.inMeters, visionTwist.dy.inMeters, visionTwist.dtheta.inRadians
)
)

// Apply twist
currentOdomTField +=
Transform2d.exp(
Twist2d(
twistMatrix.get(0, 0).meters,
twistMatrix.get(1, 0).meters,
twistMatrix.get(2, 0).radians
)
)
}
return currentOdomTField
}
}

/** Represents a single vision pose with associated standard deviations. */
class VisionUpdate(
val fieldTRobot: Pose2d,
val stdDevs: Matrix<N3, N1>,
val fromVision: Boolean = false
) {
companion object {
val compareDescStdDev = Comparator { a: VisionUpdate, b: VisionUpdate ->
-(a.stdDevs.get(0, 0) + a.stdDevs.get(1, 0)).compareTo(
b.stdDevs.get(0, 0) + b.stdDevs.get(1, 0)
)
}
}
}

/** Represents a single vision pose with a timestamp and associated standard deviations. */
class TimestampedVisionUpdate(
val timestamp: Time,
val fieldTRobot: Pose2d,
val stdDevs: Matrix<N3, N1>,
val fromVision: Boolean = false
)
companion object {
private val HISTORY_LENGTH = 0.3.seconds
}

init {
for (i in 0..2) {
q.set(i, 0, stateStdDevs.get(i, 0) * stateStdDevs.get(i, 0))
}
}
}

0 comments on commit 2ac28f7

Please sign in to comment.