Skip to content

Commit

Permalink
SupportedPoseFormat->KnownPoseFormat
Browse files Browse the repository at this point in the history
  • Loading branch information
cleong110 committed Jan 10, 2025
1 parent 5f86f29 commit 9996a43
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
21 changes: 10 additions & 11 deletions src/python/pose_format/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# from pose_format.utils.holistic import holistic_components
# The import above creates an error: ImportError: Please install mediapipe with: pip install mediapipe

SupportedPoseFormat = Literal["holistic", "openpose", "openpose_135"]
KnownPoseFormat = Literal["holistic", "openpose", "openpose_135"]


def get_component_names(
Expand All @@ -35,7 +35,7 @@ def get_component_names(
raise ValueError(f"Could not get component_names from {pose_or_header_or_components}")


def detect_known_pose_format(component_names: List[str]) -> SupportedPoseFormat:
def detect_known_pose_format(component_names: List[str]) -> KnownPoseFormat:

# would be better to import from pose_format.utils.holistic but that creates a dep on mediapipe
mediapipe_components = [
Expand Down Expand Up @@ -89,7 +89,7 @@ def pose_hide_legs(pose: Pose):
]
pose.body.data[:, :, points, :] = 0
pose.body.confidence[:, :, points] = 0
elif known_pose_format in get_args(SupportedPoseFormat):
elif known_pose_format in get_args(KnownPoseFormat):
raise NotImplementedError(
f"Unsupported pose header schema {known_pose_format} for {pose_hide_legs.__name__}: {pose.header}"
)
Expand All @@ -109,7 +109,7 @@ def pose_shoulders(pose_header: PoseHeader):
if known_pose_format == "openpose":
return ("pose_keypoints_2d", "RShoulder"), ("pose_keypoints_2d", "LShoulder")

if known_pose_format in get_args(SupportedPoseFormat):
if known_pose_format in get_args(KnownPoseFormat):
raise NotImplementedError(
f"Unsupported pose header schema {known_pose_format} for {pose_shoulders.__name__}: {pose_header}"
)
Expand All @@ -129,7 +129,7 @@ def hands_indexes(pose_header: PoseHeader):
pose_header._get_point_index("hand_left_keypoints_2d", "M_CMC"),
pose_header._get_point_index("hand_right_keypoints_2d", "M_CMC"),
]
if known_pose_format in get_args(SupportedPoseFormat):
if known_pose_format in get_args(KnownPoseFormat):
raise NotImplementedError(
f"Unsupported pose header schema {known_pose_format} for {hands_indexes.__name__}: {pose_header}"
)
Expand All @@ -153,7 +153,7 @@ def hands_components(pose_header: PoseHeader):
if known_pose_format == "openpose":
return ("hand_left_keypoints_2d", "hand_right_keypoints_2d"), ("BASE", "P_CMC", "I_CMC"), ("BASE", "M_CMC")

if known_pose_format in get_args(SupportedPoseFormat):
if known_pose_format in get_args(KnownPoseFormat):
raise NotImplementedError(
f"Unsupported pose header schema '{known_pose_format}' for {hands_components.__name__}: {pose_header}"
)
Expand Down Expand Up @@ -183,7 +183,7 @@ def normalize_hands_3d(pose: Pose, left_hand=True, right_hand=True):
normalize_component_3d(pose, right_hand_component, plane, line)


def get_standard_components_for_known_format(known_pose_format: SupportedPoseFormat) -> List[PoseHeaderComponent]:
def get_standard_components_for_known_format(known_pose_format: KnownPoseFormat) -> List[PoseHeaderComponent]:
if known_pose_format == "holistic":
try:
import pose_format.utils.holistic as holistic_utils
Expand All @@ -198,7 +198,7 @@ def get_standard_components_for_known_format(known_pose_format: SupportedPoseFor
return OpenPose_Components
if known_pose_format == "openpose_135":
return OpenPose135_Components
if known_pose_format in get_args(SupportedPoseFormat):
if known_pose_format in get_args(KnownPoseFormat):
raise NotImplementedError(f"Unsupported pose header schema {known_pose_format}")
raise ValueError(f"Unknown pose format {known_pose_format}, cannot get standard components")

Expand All @@ -225,7 +225,7 @@ def get_hand_wrist_index(pose: Pose, hand: str):
return pose.header._get_point_index(f"{hand.upper()}_HAND_LANDMARKS", "WRIST")
if known_pose_format == "openpose":
return pose.header._get_point_index(f"hand_{hand.lower()}_keypoints_2d", "BASE")
if known_pose_format in get_args(SupportedPoseFormat):
if known_pose_format in get_args(KnownPoseFormat):
raise NotImplementedError(
f"{known_pose_format} pose header schema unsupported for get_hand_wrist_index: {pose.header}"
)
Expand All @@ -238,11 +238,10 @@ def get_body_hand_wrist_index(pose: Pose, hand: str):
return pose.header._get_point_index("POSE_LANDMARKS", f"{hand.upper()}_WRIST")
if known_pose_format == "openpose":
return pose.header._get_point_index("pose_keypoints_2d", f"{hand.upper()[0]}Wrist")
if known_pose_format in get_args(SupportedPoseFormat):
if known_pose_format in get_args(KnownPoseFormat):
raise NotImplementedError(
f"{known_pose_format} pose header schema unsupported for {get_body_hand_wrist_index.__name__}"
)
raise ValueError(f"Unknown pose header schema for {get_body_hand_wrist_index.__name__} {pose.header}")


def correct_wrist(pose: Pose, hand: str) -> Pose:
Expand Down
26 changes: 13 additions & 13 deletions src/python/pose_format/utils/generic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
detect_known_pose_format,
get_component_names,
get_standard_components_for_known_format,
SupportedPoseFormat,
KnownPoseFormat,
pose_hide_legs,
pose_shoulders,
hands_indexes,
Expand All @@ -22,7 +22,7 @@


@pytest.mark.parametrize(
"fake_poses, expected_type", [(fmt, fmt) for fmt in get_args(SupportedPoseFormat)], indirect=["fake_poses"]
"fake_poses, expected_type", [(fmt, fmt) for fmt in get_args(KnownPoseFormat)], indirect=["fake_poses"]
)
def test_detect_format(fake_poses, expected_type):
for pose in fake_poses:
Expand All @@ -36,9 +36,9 @@ def test_detect_format(fake_poses, expected_type):


@pytest.mark.parametrize(
"fake_poses, format", [(fmt, fmt) for fmt in get_args(SupportedPoseFormat)], indirect=["fake_poses"]
"fake_poses, known_pose_format", [(fmt, fmt) for fmt in get_args(KnownPoseFormat)], indirect=["fake_poses"]
)
def test_get_component_names(fake_poses: List[Pose], known_pose_format: SupportedPoseFormat):
def test_get_component_names(fake_poses: List[Pose], known_pose_format: KnownPoseFormat):

standard_components_for_format = get_standard_components_for_known_format(known_pose_format)
names_for_standard_components_for_format = sorted([c.name for c in standard_components_for_format])
Expand All @@ -56,7 +56,7 @@ def test_get_component_names(fake_poses: List[Pose], known_pose_format: Supporte
get_component_names(pose.body) # type: ignore


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_pose_hide_legs(fake_poses: List[Pose]):
for pose in fake_poses:
orig_nonzeros_count = np.count_nonzero(pose.body.data)
Expand All @@ -67,7 +67,7 @@ def test_pose_hide_legs(fake_poses: List[Pose]):
assert orig_nonzeros_count > new_nonzeros_count


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_pose_shoulders(fake_poses: List[Pose]):
for pose in fake_poses:
shoulders = pose_shoulders(pose.header)
Expand All @@ -76,21 +76,21 @@ def test_pose_shoulders(fake_poses: List[Pose]):
assert "LEFT" in shoulders[1][1] or shoulders[1][1][0] == "L"


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_hands_indexes(fake_poses: List[Pose]):
for pose in fake_poses:
indices = hands_indexes(pose.header)
assert len(indices) > 0


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_normalize_pose_size(fake_poses: List[Pose]):
for pose in fake_poses:
normalize_pose_size(pose)
# TODO: more tests, compare with test data


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_pose_normalization_info(fake_poses: List[Pose]):
for pose in fake_poses:
info = pose_normalization_info(pose.header)
Expand All @@ -101,7 +101,7 @@ def test_pose_normalization_info(fake_poses: List[Pose]):
# TODO: more tests, compare with test data


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_get_hand_wrist_index(fake_poses: List[Pose]):
for pose in fake_poses:
for hand in ["LEFT", "RIGHT"]:
Expand All @@ -110,23 +110,23 @@ def test_get_hand_wrist_index(fake_poses: List[Pose]):
# TODO: what are the expected values?


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_get_body_hand_wrist_index(fake_poses: List[Pose]):
for pose in fake_poses:
for hand in ["LEFT", "RIGHT"]:
index = get_body_hand_wrist_index(pose, hand)
# TODO: what are the expected values?


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_correct_wrists(fake_poses: List[Pose]):
for pose in fake_poses:
corrected_pose = correct_wrists(pose)
assert np.array_equal(corrected_pose.body.data, pose.body.data) is False
assert corrected_pose != pose


@pytest.mark.parametrize("fake_poses", list(get_args(SupportedPoseFormat)), indirect=["fake_poses"])
@pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"])
def test_hands_components(fake_poses: List[Pose]):
for pose in fake_poses:
hands_components_returned = hands_components(pose.header)
Expand Down

0 comments on commit 9996a43

Please sign in to comment.