Skip to content

Commit

Permalink
Merge pull request #482 from singnet/development
Browse files Browse the repository at this point in the history
Patch fixes and training
  • Loading branch information
AlbinaPomogalova authored Mar 15, 2024
2 parents b5a728b + ba84fe8 commit 52d2052
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ venv/
__pycache__
blockchain/node_modules
snet.egg-info/
*.pyi
16 changes: 11 additions & 5 deletions packages/sdk/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import pkg_resources
import importlib.metadata
from setuptools import setup, find_namespace_packages
from os import path

Expand All @@ -10,10 +10,16 @@


def is_package_installed(package_name):
installed_modules = [p.project_name for p in pkg_resources.working_set]
print("Installed modules:")
print(installed_modules)
return package_name in installed_modules
try:
package = importlib.metadata.metadata(package_name)
name, version = package.json["name"], package.json["version"]
print(f"Installed {name} {version}")
return True
except importlib.metadata.PackageNotFoundError:
print(f"Package {package_name} is not installed")
return False




dependencies = []
Expand Down
5 changes: 4 additions & 1 deletion packages/sdk/snet/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def __init__(

# Instantiate Ethereum client
eth_rpc_endpoint = self._config.get("eth_rpc_endpoint", "https://mainnet.infura.io/v3/e7732e1f679e461b9bb4da5653ac3fc2")
provider = web3.HTTPProvider(eth_rpc_endpoint)
eth_rpc_request_kwargs = self._config.get("eth_rpc_request_kwargs")

provider = web3.HTTPProvider(endpoint_uri=eth_rpc_endpoint, request_kwargs=eth_rpc_request_kwargs)

self.web3 = web3.Web3(provider)

# Get MPE contract address from config if specified; mostly for local testing
Expand Down
20 changes: 16 additions & 4 deletions packages/sdk/snet/sdk/service_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import importlib

import grpc
import snet.sdk.generic_client_interceptor as generic_client_interceptor
import web3
from eth_account.messages import defunct_hash_message
from rfc3986 import urlparse

import snet.sdk.generic_client_interceptor as generic_client_interceptor
from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider
from snet.snet_cli.utils.utils import RESOURCES_PATH, add_to_path
from snet.sdk.root_certificate import root_certificate
from snet.snet_cli.utils.utils import RESOURCES_PATH, add_to_path


class _ClientCallDetails(
collections.namedtuple(
Expand Down Expand Up @@ -69,7 +72,8 @@ def _get_grpc_channel(self):
if endpoint_object.scheme == "http":
return grpc.insecure_channel(channel_endpoint)
elif endpoint_object.scheme == "https":
return grpc.secure_channel(channel_endpoint, grpc.ssl_channel_credentials(root_certificates=root_certificate))
return grpc.secure_channel(channel_endpoint,
grpc.ssl_channel_credentials(root_certificates=root_certificate))
else:
raise ValueError('Unsupported scheme in service metadata ("{}")'.format(endpoint_object.scheme))

Expand Down Expand Up @@ -154,13 +158,21 @@ def generate_signature(self, message):

return signature

def generate_training_signature(self, text: str, address, block_number):
message = web3.Web3.solidity_keccak(
["string", "address", "uint256"],
[text, address, block_number]
)
return self.sdk_web3.eth.account.signHash(defunct_hash_message(message),
self.account.signer_private_key).signature

def get_free_call_config(self):
return self.options['email'], self.options['free_call_auth_token-bin'], self.options[
'free-call-token-expiry-block']

def get_service_details(self):
return self.org_id, self.service_id, self.group["group_id"], \
self.service_metadata.get_all_endpoints_for_group(self.group["group_name"])[0]
self.service_metadata.get_all_endpoints_for_group(self.group["group_name"])[0]

def get_concurrency_flag(self):
return self.options.get('concurrency', True)
Expand Down
150 changes: 150 additions & 0 deletions packages/sdk/snet/sdk/training/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import enum
import importlib
from urllib.parse import urlparse

import grpc
import web3

from snet.sdk.root_certificate import root_certificate
from snet.snet_cli.utils.utils import RESOURCES_PATH, add_to_path


# for local debug
# from snet.snet_cli.resources.proto import training_pb2_grpc
# from snet.snet_cli.resources.proto import training_pb2


# from daemon code
class ModelMethodMessage(enum.Enum):
CreateModel = "__CreateModel"
GetModelStatus = "__GetModelStatus"
UpdateModelAccess = "__UpdateModelAccess"
GetAllModels = "__UpdateModelAccess"
DeleteModel = "__GetModelStatus"


class TrainingModel:

def __init__(self):
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
self.training_pb2 = importlib.import_module("training_pb2")

with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
self.training_pb2_grpc = importlib.import_module("training_pb2_grpc")

def _invoke_model(self, service_client, msg: ModelMethodMessage):
org_id, service_id, group_id, daemon_endpoint = service_client.get_service_details()

endpoint_object = urlparse(daemon_endpoint)
if endpoint_object.port is not None:
channel_endpoint = endpoint_object.hostname + ":" + str(endpoint_object.port)
else:
channel_endpoint = endpoint_object.hostname

if endpoint_object.scheme == "http":
print("creating http channel: ", channel_endpoint)
channel = grpc.insecure_channel(channel_endpoint)
elif endpoint_object.scheme == "https":
channel = grpc.secure_channel(channel_endpoint,
grpc.ssl_channel_credentials(root_certificates=root_certificate))
else:
raise ValueError('Unsupported scheme in service metadata ("{}")'.format(endpoint_object.scheme))

current_block_number = service_client.get_current_block_number()
signature = service_client.generate_training_signature(msg.value, web3.Web3.to_checksum_address(
service_client.account.address), current_block_number)
auth_req = self.training_pb2.AuthorizationDetails(signature=bytes(signature),
current_block=current_block_number,
signer_address=service_client.account.address,
message=msg.value)
return auth_req, channel

# params from AI-service: status, model_id
# params pass to daemon: grpc_service_name, grpc_method_name, address_list,
# description, model_name, training_data_link, is_public_accessible
def create_model(self, service_client, grpc_method_name: str, model_name: str,
description: str = '',
training_data_link: str = '', grpc_service_name='service',
is_publicly_accessible=False, address_list: list[str] = None):
if address_list is None:
address_list = []
try:
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.CreateModel)
model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, description=description,
training_data_link=training_data_link,
grpc_service_name=grpc_service_name,
model_name=model_name, address_list=address_list,
is_publicly_accessible=is_publicly_accessible)
stub = self.training_pb2_grpc.ModelStub(channel)
response = stub.create_model(
self.training_pb2.CreateModelRequest(authorization=auth_req, model_details=model_details))
return response
except Exception as e:
print("Exception: ", e)
return e

# params from AI-service: status
# params to daemon: grpc_service_name, grpc_method_name, model_id
def get_model_status(self, service_client, model_id: str, grpc_method_name: str, grpc_service_name='service'):
try:
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.GetModelStatus)
model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name,
grpc_service_name=grpc_service_name, model_id=str(model_id))
stub = self.training_pb2_grpc.ModelStub(channel)
response = stub.get_model_status(
self.training_pb2.ModelDetailsRequest(authorization=auth_req, model_details=model_details))
return response
except Exception as e:
print("Exception: ", e)
return e

# params from AI-service: status
# params to daemon: grpc_service_name, grpc_method_name, model_id
def delete_model(self, service_client, model_id: str, grpc_method_name: str,
grpc_service_name='service'):
try:
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.DeleteModel)
model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name,
grpc_service_name=grpc_service_name, model_id=str(model_id))
stub = self.training_pb2_grpc.ModelStub(channel)
response = stub.delete_model(
self.training_pb2.UpdateModelRequest(authorization=auth_req, update_model_details=model_details))
return response
except Exception as e:
print("Exception: ", e)
return e

# params from AI-service: None
# params to daemon: grpc_service_name, grpc_method_name, model_id, address_list, is_public, model_name, desc
# all params required
def update_model_access(self, service_client, model_id: str, grpc_method_name: str,
model_name: str, is_public: bool,
description: str, grpc_service_name: str = 'service', address_list: list[str] = None):
try:
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.UpdateModelAccess)
model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, description=description,
grpc_service_name=grpc_service_name,
address_list=address_list,
is_publicly_accessible=is_public, model_name=model_name,
model_id=str(model_id))
stub = self.training_pb2_grpc.ModelStub(channel)
response = stub.update_model_access(
self.training_pb2.UpdateModelRequest(authorization=auth_req, update_model_details=model_details))
return response
except Exception as e:
print("Exception: ", e)
return e

# params from AI-service: None
# params to daemon: grpc_service_name, grpc_method_name
def get_all_models(self, service_client, grpc_method_name: str, grpc_service_name='service'):
try:
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.GetAllModels)
stub = self.training_pb2_grpc.ModelStub(channel)
response = stub.get_all_models(
self.training_pb2.AccessibleModelsRequest(authorization=auth_req, grpc_service_name=grpc_service_name,
grpc_method_name=grpc_method_name))
return response
except Exception as e:
print("Exception: ", e)
return e
2 changes: 1 addition & 1 deletion packages/sdk/snet/sdk/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "v3.1.0"
__version__ = "v3.1.1"
112 changes: 112 additions & 0 deletions packages/snet_cli/snet/snet_cli/resources/proto/training.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
syntax = "proto3";
import "google/protobuf/descriptor.proto";
package training;
option go_package = "../training";
//Please note that the AI developers need to provide a server implementation of the gprc server of this proto.
message ModelDetails {
//This Id will be generated when you invoke the create_model method and hence doesnt need to be filled when you
//invoke the create model
string model_id = 1;
//define the training method name
string grpc_method_name = 2;
//define the grpc service name , under which the method is defined
string grpc_service_name = 3;
string description = 4;

string status = 6;
string updated_date = 7;
//List of all the addresses that will have access to this model
repeated string address_list = 8;
// this is optional
string training_data_link = 9;
string model_name = 10;


string organization_id = 11;
string service_id = 12 ;
string group_id = 13;

//set this to true if you want your model to be used by other AI consumers
bool is_publicly_accessible = 14;

}

message AuthorizationDetails {
uint64 current_block = 1;
//Signer can fill in any message here
string message = 2;
//signature of the following message:
//("user specified message", user_address, current_block_number)
bytes signature = 3;
string signer_address = 4;

}

enum Status {
CREATED = 0;
IN_PROGRESS = 1;
ERRORED = 2;
COMPLETED = 3;
DELETED = 4;
}

message CreateModelRequest {
AuthorizationDetails authorization = 1;
ModelDetails model_details = 2;
}

//the signer address will get to know all the models associated with this address.
message AccessibleModelsRequest {
string grpc_method_name = 1;
string grpc_service_name = 2;
AuthorizationDetails authorization = 3;
}

message AccessibleModelsResponse {
repeated ModelDetails list_of_models = 1;
}

message ModelDetailsRequest {
ModelDetails model_details = 1 ;
AuthorizationDetails authorization = 2;
}

//helps determine which service end point to call for model training
//format is of type "packageName/serviceName/MethodName", Example :"/example_service.Calculator/estimate_add"
//Daemon will invoke the model training end point , when the below method option is specified
message TrainingMethodOption {
string trainingMethodIndicator = 1;
}

extend google.protobuf.MethodOptions {
TrainingMethodOption my_method_option = 9999197;
}

message UpdateModelRequest {
ModelDetails update_model_details = 1 ;
AuthorizationDetails authorization = 2;
}


message ModelDetailsResponse {
Status status = 1;
ModelDetails model_details = 2;

}

service Model {

// The AI developer needs to Implement this service and Daemon will call these
// There will be no cost borne by the consumer in calling these methods,
// Pricing will apply when you actually call the training methods defined.
// AI consumer will call all these methods
rpc create_model(CreateModelRequest) returns (ModelDetailsResponse) {}
rpc delete_model(UpdateModelRequest) returns (ModelDetailsResponse) {}
rpc get_model_status(ModelDetailsRequest) returns (ModelDetailsResponse) {}

// Daemon will implement , however the AI developer should skip implementing these and just provide dummy code.
rpc update_model_access(UpdateModelRequest) returns (ModelDetailsResponse) {}
rpc get_all_models(AccessibleModelsRequest) returns (AccessibleModelsResponse) {}


}
2 changes: 1 addition & 1 deletion packages/snet_cli/snet/snet_cli/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "v2.1.0"
__version__ = "v2.1.1"
6 changes: 0 additions & 6 deletions packages/snet_cli/snet_cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,12 +495,6 @@ def add_eth_call_arguments(parser):

def add_transaction_arguments(parser):
transaction_g = parser.add_argument_group(title="transaction arguments")
transaction_g.add_argument(
"--gas-price",
help="Ethereum gas price in Wei or time based gas price strategy "
"('fast' ~1min, 'medium' ~5min or 'slow' ~60min) (defaults to session.default_gas_price)"
)

transaction_g.add_argument(
"--wallet-index", type=int,
help="Wallet index of account to use for signing (defaults to session.identity.default_wallet_index)")
Expand Down
Loading

0 comments on commit 52d2052

Please sign in to comment.