diff --git a/src/python/pose_format/utils/generic.py b/src/python/pose_format/utils/generic.py index 5bbd7c0..4d79e6d 100644 --- a/src/python/pose_format/utils/generic.py +++ b/src/python/pose_format/utils/generic.py @@ -13,7 +13,7 @@ # from pose_format.utils.holistic import holistic_components # The import above creates an error: ImportError: Please install mediapipe with: pip install mediapipe -SupportedPoseFormat = Literal["holistic", "openpose", "openpose_135"] +KnownPoseFormat = Literal["holistic", "openpose", "openpose_135"] def get_component_names( @@ -35,7 +35,7 @@ def get_component_names( raise ValueError(f"Could not get component_names from {pose_or_header_or_components}") -def detect_known_pose_format(component_names: List[str]) -> SupportedPoseFormat: +def detect_known_pose_format(component_names: List[str]) -> KnownPoseFormat: # would be better to import from pose_format.utils.holistic but that creates a dep on mediapipe mediapipe_components = [ @@ -89,7 +89,7 @@ def pose_hide_legs(pose: Pose): ] pose.body.data[:, :, points, :] = 0 pose.body.confidence[:, :, points] = 0 - elif known_pose_format in get_args(SupportedPoseFormat): + elif known_pose_format in get_args(KnownPoseFormat): raise NotImplementedError( f"Unsupported pose header schema {known_pose_format} for {pose_hide_legs.__name__}: {pose.header}" ) @@ -109,7 +109,7 @@ def pose_shoulders(pose_header: PoseHeader): if known_pose_format == "openpose": return ("pose_keypoints_2d", "RShoulder"), ("pose_keypoints_2d", "LShoulder") - if known_pose_format in get_args(SupportedPoseFormat): + if known_pose_format in get_args(KnownPoseFormat): raise NotImplementedError( f"Unsupported pose header schema {known_pose_format} for {pose_shoulders.__name__}: {pose_header}" ) @@ -129,7 +129,7 @@ def hands_indexes(pose_header: PoseHeader): pose_header._get_point_index("hand_left_keypoints_2d", "M_CMC"), pose_header._get_point_index("hand_right_keypoints_2d", "M_CMC"), ] - if known_pose_format in get_args(SupportedPoseFormat): + if known_pose_format in get_args(KnownPoseFormat): raise NotImplementedError( f"Unsupported pose header schema {known_pose_format} for {hands_indexes.__name__}: {pose_header}" ) @@ -153,7 +153,7 @@ def hands_components(pose_header: PoseHeader): if known_pose_format == "openpose": return ("hand_left_keypoints_2d", "hand_right_keypoints_2d"), ("BASE", "P_CMC", "I_CMC"), ("BASE", "M_CMC") - if known_pose_format in get_args(SupportedPoseFormat): + if known_pose_format in get_args(KnownPoseFormat): raise NotImplementedError( f"Unsupported pose header schema '{known_pose_format}' for {hands_components.__name__}: {pose_header}" ) @@ -183,7 +183,7 @@ def normalize_hands_3d(pose: Pose, left_hand=True, right_hand=True): normalize_component_3d(pose, right_hand_component, plane, line) -def get_standard_components_for_known_format(known_pose_format: SupportedPoseFormat) -> List[PoseHeaderComponent]: +def get_standard_components_for_known_format(known_pose_format: KnownPoseFormat) -> List[PoseHeaderComponent]: if known_pose_format == "holistic": try: import pose_format.utils.holistic as holistic_utils @@ -198,7 +198,7 @@ def get_standard_components_for_known_format(known_pose_format: SupportedPoseFor return OpenPose_Components if known_pose_format == "openpose_135": return OpenPose135_Components - if known_pose_format in get_args(SupportedPoseFormat): + if known_pose_format in get_args(KnownPoseFormat): raise NotImplementedError(f"Unsupported pose header schema {known_pose_format}") raise ValueError(f"Unknown pose format {known_pose_format}, cannot get standard components") @@ -225,7 +225,7 @@ def get_hand_wrist_index(pose: Pose, hand: str): return pose.header._get_point_index(f"{hand.upper()}_HAND_LANDMARKS", "WRIST") if known_pose_format == "openpose": return pose.header._get_point_index(f"hand_{hand.lower()}_keypoints_2d", "BASE") - if known_pose_format in get_args(SupportedPoseFormat): + if known_pose_format in get_args(KnownPoseFormat): raise NotImplementedError( f"{known_pose_format} pose header schema unsupported for get_hand_wrist_index: {pose.header}" ) @@ -238,11 +238,10 @@ def get_body_hand_wrist_index(pose: Pose, hand: str): return pose.header._get_point_index("POSE_LANDMARKS", f"{hand.upper()}_WRIST") if known_pose_format == "openpose": return pose.header._get_point_index("pose_keypoints_2d", f"{hand.upper()[0]}Wrist") - if known_pose_format in get_args(SupportedPoseFormat): + if known_pose_format in get_args(KnownPoseFormat): raise NotImplementedError( f"{known_pose_format} pose header schema unsupported for {get_body_hand_wrist_index.__name__}" ) - raise ValueError(f"Unknown pose header schema for {get_body_hand_wrist_index.__name__} {pose.header}") def correct_wrist(pose: Pose, hand: str) -> Pose: diff --git a/src/python/pose_format/utils/generic_test.py b/src/python/pose_format/utils/generic_test.py index 349f4c7..09f3c99 100644 --- a/src/python/pose_format/utils/generic_test.py +++ b/src/python/pose_format/utils/generic_test.py @@ -7,7 +7,7 @@ detect_known_pose_format, get_component_names, get_standard_components_for_known_format, - SupportedPoseFormat, + KnownPoseFormat, pose_hide_legs, pose_shoulders, hands_indexes, @@ -22,7 +22,7 @@ @pytest.mark.parametrize( - "fake_poses, expected_type", [(fmt, fmt) for fmt in get_args(SupportedPoseFormat)], indirect=["fake_poses"] + "fake_poses, expected_type", [(fmt, fmt) for fmt in get_args(KnownPoseFormat)], indirect=["fake_poses"] ) def test_detect_format(fake_poses, expected_type): for pose in fake_poses: @@ -36,9 +36,9 @@ def test_detect_format(fake_poses, expected_type): @pytest.mark.parametrize( - "fake_poses, format", [(fmt, fmt) for fmt in get_args(SupportedPoseFormat)], indirect=["fake_poses"] + "fake_poses, known_pose_format", [(fmt, fmt) for fmt in get_args(KnownPoseFormat)], indirect=["fake_poses"] ) -def test_get_component_names(fake_poses: List[Pose], known_pose_format: SupportedPoseFormat): +def test_get_component_names(fake_poses: List[Pose], known_pose_format: KnownPoseFormat): standard_components_for_format = get_standard_components_for_known_format(known_pose_format) names_for_standard_components_for_format = sorted([c.name for c in standard_components_for_format]) @@ -56,7 +56,7 @@ def test_get_component_names(fake_poses: List[Pose], known_pose_format: Supporte get_component_names(pose.body) # type: ignore -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_pose_hide_legs(fake_poses: List[Pose]): for pose in fake_poses: orig_nonzeros_count = np.count_nonzero(pose.body.data) @@ -67,7 +67,7 @@ def test_pose_hide_legs(fake_poses: List[Pose]): assert orig_nonzeros_count > new_nonzeros_count -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_pose_shoulders(fake_poses: List[Pose]): for pose in fake_poses: shoulders = pose_shoulders(pose.header) @@ -76,21 +76,21 @@ def test_pose_shoulders(fake_poses: List[Pose]): assert "LEFT" in shoulders[1][1] or shoulders[1][1][0] == "L" -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_hands_indexes(fake_poses: List[Pose]): for pose in fake_poses: indices = hands_indexes(pose.header) assert len(indices) > 0 -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_normalize_pose_size(fake_poses: List[Pose]): for pose in fake_poses: normalize_pose_size(pose) # TODO: more tests, compare with test data -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_pose_normalization_info(fake_poses: List[Pose]): for pose in fake_poses: info = pose_normalization_info(pose.header) @@ -101,7 +101,7 @@ def test_pose_normalization_info(fake_poses: List[Pose]): # TODO: more tests, compare with test data -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_get_hand_wrist_index(fake_poses: List[Pose]): for pose in fake_poses: for hand in ["LEFT", "RIGHT"]: @@ -110,7 +110,7 @@ def test_get_hand_wrist_index(fake_poses: List[Pose]): # TODO: what are the expected values? -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_get_body_hand_wrist_index(fake_poses: List[Pose]): for pose in fake_poses: for hand in ["LEFT", "RIGHT"]: @@ -118,7 +118,7 @@ def test_get_body_hand_wrist_index(fake_poses: List[Pose]): # TODO: what are the expected values? -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_correct_wrists(fake_poses: List[Pose]): for pose in fake_poses: corrected_pose = correct_wrists(pose) @@ -126,7 +126,7 @@ def test_correct_wrists(fake_poses: List[Pose]): assert corrected_pose != pose -@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"]) +@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_hands_components(fake_poses: List[Pose]): for pose in fake_poses: hands_components_returned = hands_components(pose.header)