diff --git a/gtsfm/averaging/rotation/shonan.py b/gtsfm/averaging/rotation/shonan.py index ad8643a43..e9108e80b 100644 --- a/gtsfm/averaging/rotation/shonan.py +++ b/gtsfm/averaging/rotation/shonan.py @@ -20,6 +20,8 @@ Rot3, ShonanAveraging3, ShonanAveragingParameters3, + BetweenFactorPose3, + Pose3, ) import gtsfm.utils.logger as logger_utils @@ -38,7 +40,10 @@ class ShonanRotationAveraging(RotationAveragingBase): """Performs Shonan rotation averaging.""" def __init__( - self, two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA, weight_by_inliers: bool = True + self, + two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA, + weight_by_inliers: bool = True, + use_chordal_init: bool = True, ) -> None: """Initializes module. @@ -54,12 +59,19 @@ def __init__( self._p_min = 3 self._p_max = 64 self._weight_by_inliers = weight_by_inliers + self._use_chordal_init = use_chordal_init def __get_shonan_params(self) -> ShonanAveragingParameters3: lm_params = LevenbergMarquardtParams.CeresDefaults() + # TODO(akshay-krishnan): These parameters speed up Shonan, but disabled now because accuracy dropped slightly. + # lm_params.setRelativeErrorTol(0.01) + # lm_params.setAbsoluteErrorTol(1) shonan_params = ShonanAveragingParameters3(lm_params) shonan_params.setUseHuber(False) shonan_params.setCertifyOptimality(True) + shonan_params.setGaugesWeight(0.0) + shonan_params.setKarcherWeight(1.0) + shonan_params.setAnchorWeight(0.0) return shonan_params def __measurements_from_2view_relative_rotations( @@ -132,9 +144,18 @@ def _run_with_consecutive_ordering( len(measurements), num_connected_nodes, ) - shonan = ShonanAveraging3(measurements, self.__get_shonan_params()) + shonan_params = self.__get_shonan_params() + if self._use_chordal_init: + shonan_params.setKarcherWeight(0.0) + shonan_params.setAnchorWeight(1.0) + shonan_params.setAnchor(measurements[0].key1(), Rot3()) + shonan = ShonanAveraging3(measurements, shonan_params) + + if self._use_chordal_init: + initial = self.chordal_initialize(measurements) + else: + initial = shonan.initializeRandomly() - initial = shonan.initializeRandomly() logger.info("Initial cost: %.5f", shonan.cost(initial)) result, _ = shonan.run(initial, self._p_min, self._p_max) logger.info("Final cost: %.5f", shonan.cost(result)) @@ -161,6 +182,27 @@ def _nodes_with_edges( return unique_nodes_with_edges + def chordal_initialize(self, measurements: gtsam.BinaryMeasurementsRot3) -> gtsam.Values: + """Initialize values using GTSAM's chordal init. + + Args: + measurements: BinaryMeasurementsRot3 object created before running Shonan. + + Returns: + Initial values as a gtsam.Values object. + """ + graph = gtsam.NonlinearFactorGraph() + anchor_key = None + noise_model = gtsam.noiseModel.Diagonal.Variances(np.array([1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4], dtype=float)) + for measurement in measurements: + if anchor_key is None: + anchor_key = measurement.key1() + pose_measurement = Pose3(measurement.measured(), np.array([0.0, 0.0, 0.0])) + graph.add(BetweenFactorPose3(measurement.key1(), measurement.key2(), pose_measurement, noise_model)) + + graph.addPriorPose3(anchor_key, Pose3(), noise_model) + return gtsam.InitializePose3.initializeOrientations(graph) + def run_rotation_averaging( self, num_images: int,