Skip to content

feat: shard_key method implementation #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
56 changes: 55 additions & 1 deletion src/lighthouseweb3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import io
from typing import List, Dict, Any
from .functions import (
upload as d,
deal_status,
Expand All @@ -16,7 +17,11 @@
remove_ipns_record as removeIpnsRecord,
create_wallet as createWallet
)

from .functions.kavach import (
generate,
recover_key as recoverKey,
shard_key as shardKey
)

class Lighthouse:
def __init__(self, token: str = ""):
Expand Down Expand Up @@ -224,3 +229,52 @@ def getTagged(self, tag: str):
except Exception as e:
raise e

class Kavach:
"""
Kavach is a simple library for generating and managing secrets.

It uses Shamir's Secret Sharing algorithm to split a secret into multiple shares.
"""

@staticmethod
def generate(threshold: int, keyCount: int) -> List[Dict[str, Any]]:
"""
Generates a set of key shards with a given threshold and key count.

:param threshold: int, The minimum number of shards required to recover the key.
:param keyCount: int, The number of shards to generate.
:return: List[Dict[str, Any]], A list of key shards.
"""
try:
return generate.generate(threshold, keyCount)
except Exception as e:
raise e


@staticmethod
def recoverKey(keyShards: List[Dict[str, Any]]) -> int:
"""
Recovers a key from a set of key shards.

:param keyShards: List[Dict[str, Any]], A list of key shards.
:return: int, The recovered key.
"""
try:
return recoverKey.recover_key(keyShards)
except Exception as e:
raise e

@staticmethod
def shardKey(masterKey: int, threshold: int, keyCount: int) -> List[Dict[str, Any]]:
"""
Splits a master key into multiple shards.

:param masterKey: int, The master key to be split.
:param threshold: int, The minimum number of shards required to recover the key.
:param keyCount: int, The number of shards to generate.
:return: List[Dict[str, Any]], A list of key shards.
"""
try:
return shardKey.shard_key(masterKey, threshold, keyCount)
except Exception as e:
raise e
Empty file.
2 changes: 2 additions & 0 deletions src/lighthouseweb3/functions/kavach/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#A 257-bit prime to accommodate 256-bit secrets
PRIME = 2**256 + 297
45 changes: 45 additions & 0 deletions src/lighthouseweb3/functions/kavach/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import secrets
import logging
from typing import Dict, List, Any
from .shard_key import shard_key

logger = logging.getLogger(__name__)

async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]:
"""
Generate threshold cryptography key shards using Shamir's Secret Sharing

Args:
threshold: Minimum number of shards needed to reconstruct the secret
key_count: Total number of key shards to generate

Returns:
{
"masterKey": "<master private key hex string>",
"keyShards": [
{
"key": "<shard value hex string>",
"index": "<shard index hex string>"
}
]
}
"""
logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}")

try:
random_int = secrets.randbits(256)
master_key = f"0x{random_int:064x}"

result = await shard_key(master_key, threshold, key_count)

if not result['isShardable']:
raise ValueError(result['error'])

return {
"masterKey": master_key,
"keyShards": result['keyShards']
}

except Exception as e:
logger.error(f"Error during key generation: {str(e)}")
raise e
178 changes: 178 additions & 0 deletions src/lighthouseweb3/functions/kavach/recover_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from typing import List, Dict, Any
import logging
from .config import PRIME

logger = logging.getLogger(__name__)

from typing import Tuple

def extended_gcd(a: int, b: int) -> Tuple[int, int, int]:
"""Extended Euclidean algorithm to find modular inverse.

Args:
a: First integer
b: Second integer

Returns:
A tuple (g, x, y) such that a*x + b*y = g = gcd(a, b)
"""
if a == 0:
return b, 0, 1
else:
g, y, x = extended_gcd(b % a, a)
return g, x - (b // a) * y, y

def modinv(a: int, m: int) -> int:
"""Find the modular inverse of a mod m."""
g, x, y = extended_gcd(a, m)
if g != 1:
raise ValueError('Modular inverse does not exist')
else:
return x % m

def lagrange_interpolation(shares: List[Dict[str, str]], prime: int) -> int:
"""
Reconstruct the secret using Lagrange interpolation.

Args:
shares: List of dictionaries with 'key' and 'index' fields
prime: The prime number used in the finite field

Returns:
The reconstructed secret as integer

Raises:
ValueError: If there are duplicate indices
"""

points = []
seen_indices = set()

for i, share in enumerate(shares):
try:
key_str, index_str = validate_share(share, i)
x = int(index_str, 16)

if x in seen_indices:
raise ValueError(f"Duplicate share index found: 0x{x:x}")
seen_indices.add(x)

y = int(key_str, 16)
points.append((x, y))
except ValueError as e:
raise ValueError(f"Invalid share at position {i}: {e}")


secret = 0

for i, (x_i, y_i) in enumerate(points):
# Calculate the Lagrange basis polynomial L_i(0)
# Evaluate at x=0 to get the constant term
numerator = 1
denominator = 1

for j, (x_j, _) in enumerate(points):
if i != j:
numerator = (numerator * (-x_j)) % prime
denominator = (denominator * (x_i - x_j)) % prime

try:
inv_denominator = modinv(denominator, prime)
except ValueError as e:
raise ValueError(f"Error in modular inverse calculation: {e}")

term = (y_i * numerator * inv_denominator) % prime
secret = (secret + term) % prime

return secret

def validate_share(share: Dict[str, str], index: int) -> Tuple[str, str]:
"""Validate and normalize a single share.

Args:
share: Dictionary containing 'key' and 'index' fields
index: Position of the share in the input list (for error messages)

Returns:
Tuple of (normalized_key, normalized_index) as strings without '0x' prefix

Raises:
ValueError: If the share is invalid
"""
if not isinstance(share, dict):
raise ValueError(f"Share at index {index} must be a dictionary")

if 'key' not in share or 'index' not in share:
raise ValueError(f"Share at index {index} is missing required fields 'key' or 'index'")

key_str = str(share['key']).strip().lower()
index_str = str(share['index']).strip().lower()

if key_str.startswith('0x'):
key_str = key_str[2:]
if index_str.startswith('0x'):
index_str = index_str[2:]

if not key_str:
raise ValueError(f"Empty key in share at index {index}")
if not index_str:
raise ValueError(f"Empty index in share at index {index}")

if len(key_str) % 2 != 0:
key_str = '0' + key_str

if len(index_str) % 2 != 0:
index_str = '0' + index_str

try:
bytes.fromhex(key_str)
except ValueError:
raise ValueError(f"Invalid key format in share at index {index}: must be a valid hex string")

try:
bytes.fromhex(index_str)
except ValueError:
raise ValueError(f"Invalid index format in share at index {index}: must be a valid hex string")

index_int = int(index_str, 16)
if not (0 <= index_int <= 0xFFFFFFFF):
raise ValueError(f"Index out of range in share at index {index}: must be between 0 and 2^32-1")

return key_str, index_str

async def recover_key(keyShards: List[Dict[str, str]]) -> Dict[str, Any]:
"""
Recover the master key from a subset of key shares using Lagrange interpolation.

Args:
keyShards: List of dictionaries containing 'key' and 'index' fields

Returns:
{
"masterKey": "<recovered master key hex string>",
"error": "<error message if any>"
}
"""
logger.info(f"Attempting to recover master key from {len(keyShards)} shares")

try:
for i, share in enumerate(keyShards):
validate_share(share, i)
secret = lagrange_interpolation(keyShards, PRIME)
master_key = f"0x{secret:064x}"
return {
"masterKey": master_key,
"error": None
}
except ValueError as e:
logger.error(f"Validation error during key recovery: {str(e)}")
return {
"masterKey": None,
"error": f"Validation error: {str(e)}"
}
except Exception as e:
logger.error(f"Error during key recovery: {str(e)}")
return {
"masterKey": None,
"error": f"Recovery error: {str(e)}"
}
Loading