Skip to content

Commit

Permalink
[UPDATE] Enhance SDK Error Handling and Test Coverage (#34)
Browse files Browse the repository at this point in the history
* new psdk version release

* Updated integration test setup to reduce redunant code across tests

* Added unit tests for license module

* Updated checking if none to be faster

* Added checks for env vars

* added comment to clarify royalty policy lap
  • Loading branch information
aandrewchung authored Jan 30, 2025
1 parent c1c9d87 commit 243f447
Show file tree
Hide file tree
Showing 6 changed files with 532 additions and 153 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='story_protocol_python_sdk',
version='0.3.4',
version='0.3.5',
packages=find_packages(where='src', exclude=["tests"]),
package_dir={'': 'src'},
install_requires=[
Expand Down
5 changes: 3 additions & 2 deletions src/story_protocol_python_sdk/utils/license_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client import RoyaltyModuleClient

ZERO_ADDRESS = "0x0000000000000000000000000000000000000000"
ROYALTY_POLICY = "0xBe54FB168b3c982b7AaE60dB6CF75Bd8447b390E"

class LicenseTerms:
def __init__(self, web3: Web3):
Expand Down Expand Up @@ -45,8 +46,8 @@ def get_license_term_by_type(self, type, term=None):
if not term or 'defaultMintingFee' not in term or 'currency' not in term:
raise ValueError("DefaultMintingFee, currency are required for commercial use PIL.")

if 'royaltyPolicyAddress' not in term:
raise ValueError("royaltyPolicyAddress is required")
if term['royaltyPolicyAddress'] is None:
term['royaltyPolicyAddress'] = ROYALTY_POLICY

license_terms.update({
'defaultMintingFee': int(term['defaultMintingFee']),
Expand Down
67 changes: 67 additions & 0 deletions tests/integration/setup_for_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import sys
import pytest
from dotenv import load_dotenv
from web3 import Web3

# Ensure the src directory is in the Python path
current_dir = os.path.dirname(__file__)
src_path = os.path.abspath(os.path.join(current_dir, '..', '..'))
if src_path not in sys.path:
sys.path.append(src_path)

# Import everything from utils
from utils import (
get_story_client_in_devnet,
get_token_id,
mint_tokens,
approve,
getBlockTimestamp,
check_event_in_tx,
MockERC721,
MockERC20,
ZERO_ADDRESS,
ROYALTY_POLICY,
ROYALTY_MODULE,
PIL_LICENSE_TEMPLATE
)

# Load environment variables
load_dotenv(override=True)
private_key = os.getenv('WALLET_PRIVATE_KEY')
rpc_url = os.getenv('RPC_PROVIDER_URL')

if not private_key:
raise ValueError("WALLET_PRIVATE_KEY environment variable is not set")
if not rpc_url:
raise ValueError("RPC_PROVIDER_URL environment variable is not set")

# Initialize Web3
web3 = Web3(Web3.HTTPProvider(rpc_url))
if not web3.is_connected():
raise Exception("Failed to connect to Web3 provider")

# Set up the account with the private key
account = web3.eth.account.from_key(private_key)

@pytest.fixture(scope="session")
def story_client():
return get_story_client_in_devnet(web3, account)

# Export everything needed by test files
__all__ = [
'web3',
'account',
'story_client',
'get_token_id',
'mint_tokens',
'approve',
'getBlockTimestamp',
'check_event_in_tx',
'MockERC721',
'MockERC20',
'ZERO_ADDRESS',
'ROYALTY_POLICY',
'ROYALTY_MODULE',
'PIL_LICENSE_TEMPLATE'
]
105 changes: 53 additions & 52 deletions tests/integration/test_integration_license.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,28 @@
from dotenv import load_dotenv
from web3 import Web3

# Ensure the src directory is in the Python path
current_dir = os.path.dirname(__file__)
src_path = os.path.abspath(os.path.join(current_dir, '..', '..'))
if src_path not in sys.path:
sys.path.append(src_path)

from utils import get_story_client_in_devnet, MockERC20, MockERC721, get_token_id, approve, mint_tokens

load_dotenv(override=True)
private_key = os.getenv('WALLET_PRIVATE_KEY')
rpc_url = os.getenv('RPC_PROVIDER_URL')

# Initialize Web3
web3 = Web3(Web3.HTTPProvider(rpc_url))
if not web3.is_connected():
raise Exception("Failed to connect to Web3 provider")

# Set up the account with the private key
account = web3.eth.account.from_key(private_key)

royalty_policy="0xBe54FB168b3c982b7AaE60dB6CF75Bd8447b390E"
royalty_module="0xD2f60c40fEbccf6311f8B47c4f2Ec6b040400086"
license_template="0x2E896b0b2Fdb7457499B56AAaA4AE55BCB4Cd316"

@pytest.fixture(scope="module")
def story_client():
return get_story_client_in_devnet(web3, account)
from setup_for_integration import (
web3,
account,
story_client,
get_token_id,
mint_tokens,
approve,
getBlockTimestamp,
check_event_in_tx,
MockERC721,
MockERC20,
ZERO_ADDRESS,
ROYALTY_POLICY,
ROYALTY_MODULE,
PIL_LICENSE_TEMPLATE
)

def test_registerPILTerms(story_client):
response = story_client.License.registerPILTerms(
transferable=False,
royalty_policy=story_client.web3.to_checksum_address("0x0000000000000000000000000000000000000000"),
default_minting_fee=92,
default_minting_fee=0,
expiration=0,
commercial_use=False,
commercial_attribution=False,
Expand Down Expand Up @@ -67,32 +56,44 @@ def test_registerNonComSocialRemixingPIL(story_client):
assert 'licenseTermsId' in response
assert response['licenseTermsId'] is not None
assert isinstance(response['licenseTermsId'], int)

def test_registerCommercialUsePIL(story_client):

@pytest.fixture(scope="module")
def registerCommercialUsePIL(story_client):
response = story_client.License.registerCommercialUsePIL(
default_minting_fee=11,
currency=MockERC20,
royalty_policy=royalty_policy
royalty_policy=ROYALTY_POLICY
)

assert response is not None, "Response is None, indicating the contract interaction failed."
assert 'licenseTermsId' in response, "Response does not contain 'licenseTermsId'."
assert response['licenseTermsId'] is not None, "'licenseTermsId' is None."
assert isinstance(response['licenseTermsId'], int), "'licenseTermsId' is not an integer."

def test_registerCommercialRemixPIL(story_client):
return response['licenseTermsId']

def test_registerCommercialUsePIL(story_client, registerCommercialUsePIL):
assert registerCommercialUsePIL is not None

@pytest.fixture(scope="module")
def registerCommercialRemixPIL(story_client):
response = story_client.License.registerCommercialRemixPIL(
default_minting_fee=1,
currency=MockERC20,
commercial_rev_share=100,
royalty_policy=royalty_policy
royalty_policy=ROYALTY_POLICY
)

assert response is not None, "Response is None, indicating the contract interaction failed."
assert 'licenseTermsId' in response, "Response does not contain 'licenseTermsId'."
assert response['licenseTermsId'] is not None, "'licenseTermsId' is None."
assert isinstance(response['licenseTermsId'], int), "'licenseTermsId' is not an integer."

return response['licenseTermsId']

def test_registerCommercialRemixPIL(story_client, registerCommercialRemixPIL):
assert registerCommercialRemixPIL is not None

@pytest.fixture(scope="module")
def ip_id(story_client):
token_id = get_token_id(MockERC721, story_client.web3, story_client.account)
Expand All @@ -114,7 +115,7 @@ def ip_id(story_client):
erc20_contract_address=MockERC20,
web3=web3,
account=account,
spender_address=royalty_module,
spender_address=ROYALTY_MODULE,
amount=100000 * 10 ** 6)

assert response is not None
Expand All @@ -123,22 +124,22 @@ def ip_id(story_client):

return response['ipId']

def test_attachLicenseTerms(story_client, ip_id):
license_terms_id = 5
def test_attachLicenseTerms(story_client, ip_id, registerCommercialUsePIL):
license_terms_id = registerCommercialUsePIL

response = story_client.License.attachLicenseTerms(ip_id, license_template, license_terms_id)
response = story_client.License.attachLicenseTerms(ip_id, PIL_LICENSE_TEMPLATE, license_terms_id)

assert response is not None, "Response is None, indicating the contract interaction failed."
assert 'txHash' in response, "Response does not contain 'txHash'."
assert response['txHash'] is not None, "'txHash' is None."
assert isinstance(response['txHash'], str), "'txHash' is not a string."
assert len(response['txHash']) > 0, "'txHash' is empty."

def test_mintLicenseTokens(story_client, ip_id):
def test_mintLicenseTokens(story_client, ip_id, registerCommercialUsePIL):
response = story_client.License.mintLicenseTokens(
licensor_ip_id=ip_id,
license_template=license_template,
license_terms_id=5,
license_template=PIL_LICENSE_TEMPLATE,
license_terms_id=registerCommercialUsePIL,
amount=1,
receiver=account.address
)
Expand All @@ -160,10 +161,10 @@ def test_getLicenseTerms(story_client):

assert response is not None, "Response is None, indicating the call failed."

def test_predictMintingLicenseFee(story_client, ip_id):
def test_predictMintingLicenseFee(story_client, ip_id, registerCommercialUsePIL):
response = story_client.License.predictMintingLicenseFee(
licensor_ip_id=ip_id,
license_terms_id=5,
license_terms_id=registerCommercialUsePIL,
amount=1
)

Expand All @@ -176,7 +177,7 @@ def test_predictMintingLicenseFee(story_client, ip_id):
assert response['amount'] is not None, "'amount' is None."
assert isinstance(response['amount'], int), "'amount' is not an integer."

def test_setLicensingConfig(story_client, ip_id):
def test_setLicensingConfig(story_client, ip_id, registerCommercialRemixPIL):
licensing_config = {
'mintingFee': 1,
'isSet': True,
Expand All @@ -190,9 +191,9 @@ def test_setLicensingConfig(story_client, ip_id):

response = story_client.License.setLicensingConfig(
ip_id=ip_id,
license_terms_id=0,
license_terms_id=registerCommercialRemixPIL,
licensing_config=licensing_config,
license_template=None # Will default to zero address
license_template=PIL_LICENSE_TEMPLATE
)

assert response is not None, "Response is None, indicating the contract interaction failed."
Expand Down Expand Up @@ -249,14 +250,14 @@ def setup_license_terms(story_client, ip_id):
default_minting_fee=1,
currency=MockERC20,
commercial_rev_share=100,
royalty_policy=royalty_policy
royalty_policy=ROYALTY_POLICY
)
license_id = response['licenseTermsId']

# Attach the license terms
story_client.License.attachLicenseTerms(
ip_id=ip_id,
license_template=license_template,
license_template=PIL_LICENSE_TEMPLATE,
license_terms_id=license_id
)

Expand All @@ -266,7 +267,7 @@ def test_multi_token_minting(story_client, ip_id, setup_license_terms):
"""Test minting multiple license tokens at once."""
response = story_client.License.mintLicenseTokens(
licensor_ip_id=ip_id,
license_template=license_template,
license_template=PIL_LICENSE_TEMPLATE,
license_terms_id=setup_license_terms,
amount=3, # Mint multiple tokens
receiver=account.address
Expand All @@ -281,24 +282,24 @@ def test_multi_token_minting(story_client, ip_id, setup_license_terms):
assert isinstance(response['licenseTokenIds'], list)
assert len(response['licenseTokenIds']) > 0

def test_set_licensing_config_with_hooks(story_client, ip_id):
def test_set_licensing_config_with_hooks(story_client, ip_id, registerCommercialRemixPIL):
"""Test setting licensing configuration with hooks enabled."""
licensing_config = {
'mintingFee': 100,
'isSet': True,
'licensingHook': "0x0000000000000000000000000000000000000000",
'hookData': "0x1234567890", # Different hook data
'commercialRevShare': 50, # 50% revenue share
'commercialRevShare': 100, # 50% revenue share
'disabled': False,
'expectMinimumGroupRewardShare': 10, # 10% minimum group reward
'expectGroupRewardPool': "0x0000000000000000000000000000000000000000"
}

response = story_client.License.setLicensingConfig(
ip_id=ip_id,
license_terms_id=0,
license_terms_id=registerCommercialRemixPIL,
licensing_config=licensing_config,
license_template=license_template
license_template=PIL_LICENSE_TEMPLATE
)

assert response is not None
Expand Down
7 changes: 4 additions & 3 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
load_dotenv()

# Mock ERC721 contract address
MockERC721 = "0xA3999e5ef20874478f1DD7a0534D05F766034478"
MockERC721 = "0xa1119092ea911202E0a65B743a13AE28C5CF2f21"

# Mock ERC20 contract address (same as used in TypeScript tests)
MockERC20 = "0x688abA77b2daA886c0aF029961Dc5fd219cEc3f6"
MockERC20 = "0xF2104833d386a2734a4eB3B8ad6FC6812F29E38E"

ZERO_ADDRESS = "0x0000000000000000000000000000000000000000"
ROYALTY_POLICY="0xBe54FB168b3c982b7AaE60dB6CF75Bd8447b390E"
ROYALTY_POLICY="0xBe54FB168b3c982b7AaE60dB6CF75Bd8447b390E" #Royalty Policy LAP
ROYALTY_MODULE="0xD2f60c40fEbccf6311f8B47c4f2Ec6b040400086"
PIL_LICENSE_TEMPLATE="0x2E896b0b2Fdb7457499B56AAaA4AE55BCB4Cd316"

def get_story_client_in_sepolia(web3: Web3, account) -> StoryClient:
chain_id = 11155111 # Sepolia chain ID
Expand Down
Loading

0 comments on commit 243f447

Please sign in to comment.