Skip to content

Commit c0f01cd

Browse files
authored
feat: add get_claimable_reward method to Group and GroupingModuleClient (#136)
1 parent 4cb6fdb commit c0f01cd

File tree

5 files changed

+225
-1
lines changed

5 files changed

+225
-1
lines changed

src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,6 @@ def build_registerGroup_transaction(self, groupPool, tx_params):
6666
return self.contract.functions.registerGroup(groupPool).build_transaction(
6767
tx_params
6868
)
69+
70+
def getClaimableReward(self, groupId, token, ipIds):
71+
return self.contract.functions.getClaimableReward(groupId, token, ipIds).call()

src/story_protocol_python_sdk/resources/Group.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,40 @@ def collect_royalties(
647647
except Exception as e:
648648
raise ValueError(f"Failed to collect royalties: {str(e)}")
649649

650+
def get_claimable_reward(
651+
self,
652+
group_ip_id: Address,
653+
currency_token: Address,
654+
member_ip_ids: list[Address],
655+
) -> list[int]:
656+
"""
657+
Returns the available reward for each IP in the group.
658+
659+
:param group_ip_id Address: The ID of the group IP.
660+
:param currency_token Address: The address of the currency (revenue) token to check.
661+
:param member_ip_ids list[Address]: The IDs of the member IPs to check claimable rewards for.
662+
:return list[int]: A list of claimable reward amounts corresponding to each member IP ID.
663+
"""
664+
try:
665+
if not self.web3.is_address(group_ip_id):
666+
raise ValueError(f"Invalid group IP ID: {group_ip_id}")
667+
if not self.web3.is_address(currency_token):
668+
raise ValueError(f"Invalid currency token: {currency_token}")
669+
for ip_id in member_ip_ids:
670+
if not self.web3.is_address(ip_id):
671+
raise ValueError(f"Invalid member IP ID: {ip_id}")
672+
673+
claimable_rewards = self.grouping_module_client.getClaimableReward(
674+
groupId=group_ip_id,
675+
token=currency_token,
676+
ipIds=member_ip_ids,
677+
)
678+
679+
return claimable_rewards
680+
681+
except Exception as e:
682+
raise ValueError(f"Failed to get claimable rewards: {str(e)}")
683+
650684
def _get_license_data(self, license_data: list) -> list:
651685
"""
652686
Process license data into the format expected by the contracts.

src/story_protocol_python_sdk/scripts/config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@
184184
"addIp",
185185
"IPGroupRegistered",
186186
"claimReward",
187-
"collectRoyalties"
187+
"collectRoyalties",
188+
"getClaimableReward"
188189
]
189190
},
190191
{

tests/integration/test_integration_group.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,56 @@ def test_collect_and_distribute_group_royalties(
307307
assert len(response["royalties_distributed"]) == 2
308308
assert response["royalties_distributed"][0]["amount"] == 10
309309
assert response["royalties_distributed"][1]["amount"] == 10
310+
311+
def test_get_claimable_reward(
312+
self, story_client: StoryClient, nft_collection: Address
313+
):
314+
"""Test getting claimable rewards for group members."""
315+
# Register IP id
316+
result1 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms(
317+
story_client, nft_collection
318+
)
319+
ip_id1 = result1["ip_id"]
320+
license_terms_id1 = result1["license_terms_id"]
321+
result2 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms(
322+
story_client, nft_collection
323+
)
324+
ip_id2 = result2["ip_id"]
325+
license_terms_id2 = result2["license_terms_id"]
326+
327+
# Register group id
328+
group_ip_id = GroupTestHelper.register_group_and_attach_license(
329+
story_client, license_terms_id1, [ip_id1, ip_id2]
330+
)
331+
# Create a derivative IP and pay royalties
332+
child_ip_id = GroupTestHelper.mint_and_register_ip_and_make_derivative(
333+
story_client, nft_collection, group_ip_id, license_terms_id1
334+
)
335+
child_ip_id2 = GroupTestHelper.mint_and_register_ip_and_make_derivative(
336+
story_client, nft_collection, group_ip_id, license_terms_id2
337+
)
338+
339+
# Pay royalties from group IP id to child IP id
340+
GroupTestHelper.pay_royalty_and_transfer_to_vault(
341+
story_client, child_ip_id, group_ip_id, MockERC20, 100
342+
)
343+
GroupTestHelper.pay_royalty_and_transfer_to_vault(
344+
story_client, child_ip_id2, group_ip_id, MockERC20, 100
345+
)
346+
347+
# Collect royalties
348+
story_client.Group.collect_royalties(
349+
group_ip_id=group_ip_id,
350+
currency_token=MockERC20,
351+
)
352+
# Get claimable rewards after royalties are collected
353+
claimable_rewards = story_client.Group.get_claimable_reward(
354+
group_ip_id=group_ip_id,
355+
currency_token=MockERC20,
356+
member_ip_ids=[ip_id1, ip_id2],
357+
)
358+
359+
assert isinstance(claimable_rewards, list)
360+
assert len(claimable_rewards) == 2
361+
assert claimable_rewards[0] == 10
362+
assert claimable_rewards[1] == 10

tests/unit/resources/test_group.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,136 @@ def test_claim_rewards_transaction_build_failure(
422422
currency_token=ADDRESS,
423423
member_ip_ids=[IP_ID],
424424
)
425+
426+
427+
class TestGroupGetClaimableReward:
428+
"""Test class for Group.get_claimable_reward method"""
429+
430+
def test_get_claimable_reward_invalid_group_ip_id(
431+
self, group: Group, mock_web3_is_address
432+
):
433+
"""Test get_claimable_reward with invalid group IP ID."""
434+
invalid_group_ip_id = "invalid_group_ip_id"
435+
with mock_web3_is_address(False):
436+
with pytest.raises(
437+
ValueError,
438+
match=f"Failed to get claimable rewards: Invalid group IP ID: {invalid_group_ip_id}",
439+
):
440+
group.get_claimable_reward(
441+
group_ip_id=invalid_group_ip_id,
442+
currency_token=ADDRESS,
443+
member_ip_ids=[IP_ID],
444+
)
445+
446+
def test_get_claimable_reward_invalid_currency_token(self, group: Group, mock_web3):
447+
"""Test get_claimable_reward with invalid currency token."""
448+
invalid_currency_token = "invalid_currency_token"
449+
with patch.object(mock_web3, "is_address") as mock_is_address:
450+
# group_ip_id=True, currency_token=False, member_ip_ids=True
451+
mock_is_address.side_effect = [True, False]
452+
with pytest.raises(
453+
ValueError,
454+
match=f"Failed to get claimable rewards: Invalid currency token: {invalid_currency_token}",
455+
):
456+
group.get_claimable_reward(
457+
group_ip_id=IP_ID,
458+
currency_token=invalid_currency_token,
459+
member_ip_ids=[IP_ID],
460+
)
461+
462+
def test_get_claimable_reward_invalid_member_ip_id(self, group: Group, mock_web3):
463+
"""Test get_claimable_reward with invalid member IP ID."""
464+
invalid_member_ip_id = "invalid_member_ip_id"
465+
with patch.object(mock_web3, "is_address") as mock_is_address:
466+
# group_ip_id=True, currency_token=True, first_member=True, second_member=False
467+
mock_is_address.side_effect = [True, True, True, False]
468+
with pytest.raises(
469+
ValueError,
470+
match=f"Failed to get claimable rewards: Invalid member IP ID: {invalid_member_ip_id}",
471+
):
472+
group.get_claimable_reward(
473+
group_ip_id=IP_ID,
474+
currency_token=ADDRESS,
475+
member_ip_ids=[ADDRESS, invalid_member_ip_id],
476+
)
477+
478+
def test_get_claimable_reward_success(
479+
self,
480+
group: Group,
481+
mock_web3_is_address,
482+
):
483+
"""Test successful get_claimable_reward operation."""
484+
expected_claimable_rewards = [100, 200, 300]
485+
member_ip_ids = [IP_ID, ADDRESS, ADDRESS]
486+
487+
with mock_web3_is_address():
488+
with patch.object(
489+
group.grouping_module_client,
490+
"getClaimableReward",
491+
return_value=expected_claimable_rewards,
492+
) as mock_get_claimable_reward:
493+
result = group.get_claimable_reward(
494+
group_ip_id=IP_ID,
495+
currency_token=ADDRESS,
496+
member_ip_ids=member_ip_ids,
497+
)
498+
499+
# Verify the result
500+
assert result == expected_claimable_rewards
501+
assert len(result) == len(member_ip_ids)
502+
mock_get_claimable_reward.assert_called_once_with(
503+
groupId=IP_ID,
504+
token=ADDRESS,
505+
ipIds=member_ip_ids,
506+
)
507+
508+
def test_get_claimable_reward_empty_member_ip_ids(
509+
self,
510+
group: Group,
511+
mock_web3_is_address,
512+
):
513+
"""Test get_claimable_reward with empty member IP IDs list."""
514+
expected_claimable_rewards: list[int] = []
515+
516+
with mock_web3_is_address():
517+
with patch.object(
518+
group.grouping_module_client,
519+
"getClaimableReward",
520+
return_value=expected_claimable_rewards,
521+
) as mock_get_claimable_reward:
522+
result = group.get_claimable_reward(
523+
group_ip_id=IP_ID,
524+
currency_token=ADDRESS,
525+
member_ip_ids=[],
526+
)
527+
528+
# Verify the result
529+
assert result == expected_claimable_rewards
530+
assert len(result) == 0
531+
532+
# Verify the client method was called with correct parameters
533+
mock_get_claimable_reward.assert_called_once_with(
534+
groupId=IP_ID,
535+
token=ADDRESS,
536+
ipIds=[],
537+
)
538+
539+
def test_get_claimable_reward_client_call_failure(
540+
self, group: Group, mock_web3_is_address
541+
):
542+
"""Test get_claimable_reward when client call fails."""
543+
with mock_web3_is_address():
544+
with patch.object(
545+
group.grouping_module_client,
546+
"getClaimableReward",
547+
side_effect=Exception("Client call failed"),
548+
):
549+
with pytest.raises(
550+
ValueError,
551+
match="Failed to get claimable rewards: Client call failed",
552+
):
553+
group.get_claimable_reward(
554+
group_ip_id=IP_ID,
555+
currency_token=ADDRESS,
556+
member_ip_ids=[IP_ID],
557+
)

0 commit comments

Comments
 (0)