-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #482 from singnet/development
Patch fixes and training
- Loading branch information
Showing
11 changed files
with
302 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ venv/ | |
__pycache__ | ||
blockchain/node_modules | ||
snet.egg-info/ | ||
*.pyi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import enum | ||
import importlib | ||
from urllib.parse import urlparse | ||
|
||
import grpc | ||
import web3 | ||
|
||
from snet.sdk.root_certificate import root_certificate | ||
from snet.snet_cli.utils.utils import RESOURCES_PATH, add_to_path | ||
|
||
|
||
# for local debug | ||
# from snet.snet_cli.resources.proto import training_pb2_grpc | ||
# from snet.snet_cli.resources.proto import training_pb2 | ||
|
||
|
||
# from daemon code | ||
class ModelMethodMessage(enum.Enum): | ||
CreateModel = "__CreateModel" | ||
GetModelStatus = "__GetModelStatus" | ||
UpdateModelAccess = "__UpdateModelAccess" | ||
GetAllModels = "__UpdateModelAccess" | ||
DeleteModel = "__GetModelStatus" | ||
|
||
|
||
class TrainingModel: | ||
|
||
def __init__(self): | ||
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))): | ||
self.training_pb2 = importlib.import_module("training_pb2") | ||
|
||
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))): | ||
self.training_pb2_grpc = importlib.import_module("training_pb2_grpc") | ||
|
||
def _invoke_model(self, service_client, msg: ModelMethodMessage): | ||
org_id, service_id, group_id, daemon_endpoint = service_client.get_service_details() | ||
|
||
endpoint_object = urlparse(daemon_endpoint) | ||
if endpoint_object.port is not None: | ||
channel_endpoint = endpoint_object.hostname + ":" + str(endpoint_object.port) | ||
else: | ||
channel_endpoint = endpoint_object.hostname | ||
|
||
if endpoint_object.scheme == "http": | ||
print("creating http channel: ", channel_endpoint) | ||
channel = grpc.insecure_channel(channel_endpoint) | ||
elif endpoint_object.scheme == "https": | ||
channel = grpc.secure_channel(channel_endpoint, | ||
grpc.ssl_channel_credentials(root_certificates=root_certificate)) | ||
else: | ||
raise ValueError('Unsupported scheme in service metadata ("{}")'.format(endpoint_object.scheme)) | ||
|
||
current_block_number = service_client.get_current_block_number() | ||
signature = service_client.generate_training_signature(msg.value, web3.Web3.to_checksum_address( | ||
service_client.account.address), current_block_number) | ||
auth_req = self.training_pb2.AuthorizationDetails(signature=bytes(signature), | ||
current_block=current_block_number, | ||
signer_address=service_client.account.address, | ||
message=msg.value) | ||
return auth_req, channel | ||
|
||
# params from AI-service: status, model_id | ||
# params pass to daemon: grpc_service_name, grpc_method_name, address_list, | ||
# description, model_name, training_data_link, is_public_accessible | ||
def create_model(self, service_client, grpc_method_name: str, model_name: str, | ||
description: str = '', | ||
training_data_link: str = '', grpc_service_name='service', | ||
is_publicly_accessible=False, address_list: list[str] = None): | ||
if address_list is None: | ||
address_list = [] | ||
try: | ||
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.CreateModel) | ||
model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, description=description, | ||
training_data_link=training_data_link, | ||
grpc_service_name=grpc_service_name, | ||
model_name=model_name, address_list=address_list, | ||
is_publicly_accessible=is_publicly_accessible) | ||
stub = self.training_pb2_grpc.ModelStub(channel) | ||
response = stub.create_model( | ||
self.training_pb2.CreateModelRequest(authorization=auth_req, model_details=model_details)) | ||
return response | ||
except Exception as e: | ||
print("Exception: ", e) | ||
return e | ||
|
||
# params from AI-service: status | ||
# params to daemon: grpc_service_name, grpc_method_name, model_id | ||
def get_model_status(self, service_client, model_id: str, grpc_method_name: str, grpc_service_name='service'): | ||
try: | ||
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.GetModelStatus) | ||
model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, | ||
grpc_service_name=grpc_service_name, model_id=str(model_id)) | ||
stub = self.training_pb2_grpc.ModelStub(channel) | ||
response = stub.get_model_status( | ||
self.training_pb2.ModelDetailsRequest(authorization=auth_req, model_details=model_details)) | ||
return response | ||
except Exception as e: | ||
print("Exception: ", e) | ||
return e | ||
|
||
# params from AI-service: status | ||
# params to daemon: grpc_service_name, grpc_method_name, model_id | ||
def delete_model(self, service_client, model_id: str, grpc_method_name: str, | ||
grpc_service_name='service'): | ||
try: | ||
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.DeleteModel) | ||
model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, | ||
grpc_service_name=grpc_service_name, model_id=str(model_id)) | ||
stub = self.training_pb2_grpc.ModelStub(channel) | ||
response = stub.delete_model( | ||
self.training_pb2.UpdateModelRequest(authorization=auth_req, update_model_details=model_details)) | ||
return response | ||
except Exception as e: | ||
print("Exception: ", e) | ||
return e | ||
|
||
# params from AI-service: None | ||
# params to daemon: grpc_service_name, grpc_method_name, model_id, address_list, is_public, model_name, desc | ||
# all params required | ||
def update_model_access(self, service_client, model_id: str, grpc_method_name: str, | ||
model_name: str, is_public: bool, | ||
description: str, grpc_service_name: str = 'service', address_list: list[str] = None): | ||
try: | ||
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.UpdateModelAccess) | ||
model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, description=description, | ||
grpc_service_name=grpc_service_name, | ||
address_list=address_list, | ||
is_publicly_accessible=is_public, model_name=model_name, | ||
model_id=str(model_id)) | ||
stub = self.training_pb2_grpc.ModelStub(channel) | ||
response = stub.update_model_access( | ||
self.training_pb2.UpdateModelRequest(authorization=auth_req, update_model_details=model_details)) | ||
return response | ||
except Exception as e: | ||
print("Exception: ", e) | ||
return e | ||
|
||
# params from AI-service: None | ||
# params to daemon: grpc_service_name, grpc_method_name | ||
def get_all_models(self, service_client, grpc_method_name: str, grpc_service_name='service'): | ||
try: | ||
auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.GetAllModels) | ||
stub = self.training_pb2_grpc.ModelStub(channel) | ||
response = stub.get_all_models( | ||
self.training_pb2.AccessibleModelsRequest(authorization=auth_req, grpc_service_name=grpc_service_name, | ||
grpc_method_name=grpc_method_name)) | ||
return response | ||
except Exception as e: | ||
print("Exception: ", e) | ||
return e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "v3.1.0" | ||
__version__ = "v3.1.1" |
112 changes: 112 additions & 0 deletions
112
packages/snet_cli/snet/snet_cli/resources/proto/training.proto
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
syntax = "proto3"; | ||
import "google/protobuf/descriptor.proto"; | ||
package training; | ||
option go_package = "../training"; | ||
//Please note that the AI developers need to provide a server implementation of the gprc server of this proto. | ||
message ModelDetails { | ||
//This Id will be generated when you invoke the create_model method and hence doesnt need to be filled when you | ||
//invoke the create model | ||
string model_id = 1; | ||
//define the training method name | ||
string grpc_method_name = 2; | ||
//define the grpc service name , under which the method is defined | ||
string grpc_service_name = 3; | ||
string description = 4; | ||
|
||
string status = 6; | ||
string updated_date = 7; | ||
//List of all the addresses that will have access to this model | ||
repeated string address_list = 8; | ||
// this is optional | ||
string training_data_link = 9; | ||
string model_name = 10; | ||
|
||
|
||
string organization_id = 11; | ||
string service_id = 12 ; | ||
string group_id = 13; | ||
|
||
//set this to true if you want your model to be used by other AI consumers | ||
bool is_publicly_accessible = 14; | ||
|
||
} | ||
|
||
message AuthorizationDetails { | ||
uint64 current_block = 1; | ||
//Signer can fill in any message here | ||
string message = 2; | ||
//signature of the following message: | ||
//("user specified message", user_address, current_block_number) | ||
bytes signature = 3; | ||
string signer_address = 4; | ||
|
||
} | ||
|
||
enum Status { | ||
CREATED = 0; | ||
IN_PROGRESS = 1; | ||
ERRORED = 2; | ||
COMPLETED = 3; | ||
DELETED = 4; | ||
} | ||
|
||
message CreateModelRequest { | ||
AuthorizationDetails authorization = 1; | ||
ModelDetails model_details = 2; | ||
} | ||
|
||
//the signer address will get to know all the models associated with this address. | ||
message AccessibleModelsRequest { | ||
string grpc_method_name = 1; | ||
string grpc_service_name = 2; | ||
AuthorizationDetails authorization = 3; | ||
} | ||
|
||
message AccessibleModelsResponse { | ||
repeated ModelDetails list_of_models = 1; | ||
} | ||
|
||
message ModelDetailsRequest { | ||
ModelDetails model_details = 1 ; | ||
AuthorizationDetails authorization = 2; | ||
} | ||
|
||
//helps determine which service end point to call for model training | ||
//format is of type "packageName/serviceName/MethodName", Example :"/example_service.Calculator/estimate_add" | ||
//Daemon will invoke the model training end point , when the below method option is specified | ||
message TrainingMethodOption { | ||
string trainingMethodIndicator = 1; | ||
} | ||
|
||
extend google.protobuf.MethodOptions { | ||
TrainingMethodOption my_method_option = 9999197; | ||
} | ||
|
||
message UpdateModelRequest { | ||
ModelDetails update_model_details = 1 ; | ||
AuthorizationDetails authorization = 2; | ||
} | ||
|
||
|
||
message ModelDetailsResponse { | ||
Status status = 1; | ||
ModelDetails model_details = 2; | ||
|
||
} | ||
|
||
service Model { | ||
|
||
// The AI developer needs to Implement this service and Daemon will call these | ||
// There will be no cost borne by the consumer in calling these methods, | ||
// Pricing will apply when you actually call the training methods defined. | ||
// AI consumer will call all these methods | ||
rpc create_model(CreateModelRequest) returns (ModelDetailsResponse) {} | ||
rpc delete_model(UpdateModelRequest) returns (ModelDetailsResponse) {} | ||
rpc get_model_status(ModelDetailsRequest) returns (ModelDetailsResponse) {} | ||
|
||
// Daemon will implement , however the AI developer should skip implementing these and just provide dummy code. | ||
rpc update_model_access(UpdateModelRequest) returns (ModelDetailsResponse) {} | ||
rpc get_all_models(AccessibleModelsRequest) returns (AccessibleModelsResponse) {} | ||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "v2.1.0" | ||
__version__ = "v2.1.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.