-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathglobal_model_chaincode.go
215 lines (183 loc) · 7 KB
/
global_model_chaincode.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
package chaincode
import (
"encoding/json"
"fmt"
"strings"
"github.com/hyperledger/fabric-contract-api-go/v2/contractapi"
)
// GlobalModelSmartContract is the contract type whose methods manage global ML
// models on the ledger. It embeds contractapi.Contract and carries no state of
// its own.
type GlobalModelSmartContract struct {
	contractapi.Contract
}
// GlobalModelAsset is the public record of a global model, i.e. the part that
// is visible to all organizations. CreateGlobalModel stores it in the world
// state as JSON under the GlobalModelHash key.
type GlobalModelAsset struct {
	// Type distinguishes different kinds of objects in the state database.
	Type string `json:"objectType"`
	// GlobalModelHash identifies this global model.
	GlobalModelHash string `json:"global_model_hash"`
	// PreviousGlobalModelHash is the hash of the global model this one follows.
	PreviousGlobalModelHash string `json:"previous_global_model_hash"`
	// LocalModelHashes lists the local model hashes used in the aggregation.
	LocalModelHashes []string `json:"local_model_hashes"`
	// ZkpHash is the hash of the ZKP (in the simplified case it can be just a
	// path to the ZKP file).
	ZkpHash string `json:"zkp_hash"`
	// RunID identifies the run this model belongs to.
	RunID string `json:"run_id"`
	// RoundID identifies the round within the run.
	RoundID uint64 `json:"round_id"`
}
// GlobalModelPrivateDetails is the private part of a global model.
// CreateGlobalModel stores it as JSON in the submitting organization's private
// data collection, keyed by GlobalModelHash.
type GlobalModelPrivateDetails struct {
	// GlobalModelHash ties the private record to its public GlobalModelAsset.
	GlobalModelHash string `json:"global_model_hash"`
	// AggregatedWeights holds the serialized model weights.
	AggregatedWeights []byte `json:"aggregated_weights"`
}
// LocalModel describes a local model contribution. VerifyLocalModels decodes
// the payload returned by the local model chaincode into this type.
type LocalModel struct {
	// Type is serialized as "objectType".
	Type string `json:"objectType"`
	// LocalModelHash identifies the local model.
	LocalModelHash string `json:"local_model_hash"`
	// NumExamples is presumably the number of training examples behind this
	// contribution; it is not read in this file. TODO confirm.
	NumExamples uint64 `json:"num_examples"`
	// RootGlobalModelHash is the hash of the global model this local model was
	// derived from; VerifyLocalModels compares it with the previous global
	// model hash.
	RootGlobalModelHash string `json:"root_global_model_hash"`
	// RunID identifies the run this local model belongs to.
	RunID string `json:"run_id"`
	// RoundID identifies the round within the run.
	RoundID uint64 `json:"round_id"`
}
// ReadGlobalModel returns the public global model asset stored in the world
// state under globalModelHash.
//
// It returns an error when the state read fails, when nothing is stored under
// the given hash, or when the stored bytes cannot be decoded into a
// GlobalModelAsset.
func (s *GlobalModelSmartContract) ReadGlobalModel(ctx contractapi.TransactionContextInterface, globalModelHash string) (*GlobalModelAsset, error) {
	modelAsBytes, err := ctx.GetStub().GetState(globalModelHash)
	if err != nil {
		// The message previously said "local model", which pointed readers of
		// the error at the wrong chaincode.
		return nil, fmt.Errorf("failed to read global model: %w", err)
	}
	if modelAsBytes == nil {
		return nil, fmt.Errorf("global model %s does not exist", globalModelHash)
	}

	var model GlobalModelAsset
	if err := json.Unmarshal(modelAsBytes, &model); err != nil {
		// Add the key so a corrupt entry can be located from the error alone.
		return nil, fmt.Errorf("failed to unmarshal global model %s: %w", globalModelHash, err)
	}
	return &model, nil
}
// CreateGlobalModel creates a new global model. The public asset details are
// written to the world state under globalModelHash, and the model weights are
// written to the private data collection named "<client MSP ID>PrivateCollection".
//
// Parameters:
//   - globalModelHash: key of the new asset; must be non-empty and not yet
//     present in the world state.
//   - previousGlobalModelHash: hash that every referenced local model must name
//     as its root global model.
//   - localModelHashes: ";"-separated list of local model hashes; an empty
//     string means no local models are referenced.
//   - zkpHash: stored on the asset unchanged.
//   - runID: must be non-empty.
//   - roundID: must be greater than 0.
//
// The weights are not an argument: they are read from the transient map under
// the key "weights" and must be non-empty.
//
// An error is returned when an argument is invalid, the weights are missing or
// empty, the global model already exists, local model verification fails, or
// any ledger read/write fails.
func (s *GlobalModelSmartContract) CreateGlobalModel(
	ctx contractapi.TransactionContextInterface,
	globalModelHash string,
	previousGlobalModelHash string,
	localModelHashes string,
	zkpHash string,
	runID string,
	roundID uint64,
) error {
	// Validate the plain arguments first so a malformed request is rejected
	// before any identity lookup or ledger access (in particular, before a
	// state read with an empty key).
	if len(globalModelHash) == 0 {
		return fmt.Errorf("global model hash field must be a non-empty string")
	}
	if len(runID) == 0 {
		return fmt.Errorf("run ID field must be a non-empty string")
	}
	if roundID == 0 {
		return fmt.Errorf("round ID must be greater than 0")
	}

	// Get the client org id; it selects the private collection used below.
	clientMSPID, err := ctx.GetClientIdentity().GetMSPID()
	if err != nil {
		return fmt.Errorf("failed getting client's orgID: %w", err)
	}

	// Get only the weights from the transient map.
	transientMap, err := ctx.GetStub().GetTransient()
	if err != nil {
		return fmt.Errorf("error getting transient: %w", err)
	}
	weights, ok := transientMap["weights"]
	if !ok {
		return fmt.Errorf("weights not found in transient map")
	}
	if len(weights) == 0 {
		return fmt.Errorf("aggregated weights must not be empty")
	}

	// Refuse to overwrite an existing global model.
	exists, err := s.GlobalModelExists(ctx, globalModelHash)
	if err != nil {
		return err
	}
	if exists {
		return fmt.Errorf("the global model %s already exists", globalModelHash)
	}

	// Keep the slice non-nil so the asset serializes "local_model_hashes" as
	// [] rather than null when no local models are referenced.
	localModelHashesSlice := []string{}
	if localModelHashes != "" {
		localModelHashesSlice = strings.Split(localModelHashes, ";")
	}

	// Verify all referenced local models exist and are valid.
	if err := s.VerifyLocalModels(ctx, localModelHashesSlice, previousGlobalModelHash, runID, roundID); err != nil {
		return fmt.Errorf("local model verification failed: %w", err)
	}

	// Build and store the public asset.
	asset := GlobalModelAsset{
		Type:                    "globalModel",
		GlobalModelHash:         globalModelHash,
		PreviousGlobalModelHash: previousGlobalModelHash,
		LocalModelHashes:        localModelHashesSlice,
		ZkpHash:                 zkpHash,
		RunID:                   runID,
		RoundID:                 roundID,
	}
	assetJSON, err := json.Marshal(asset)
	if err != nil {
		return fmt.Errorf("failed to marshal asset into JSON: %w", err)
	}
	if err := ctx.GetStub().PutState(globalModelHash, assetJSON); err != nil {
		return fmt.Errorf("failed to put asset in public state: %w", err)
	}

	// Build and store the private data (weights).
	privateDetails := GlobalModelPrivateDetails{
		GlobalModelHash:   globalModelHash,
		AggregatedWeights: weights,
	}
	privateDetailsJSON, err := json.Marshal(privateDetails)
	if err != nil {
		return fmt.Errorf("failed to marshal private details: %w", err)
	}

	// The collection name is derived from the submitting client's MSP ID.
	orgCollection := fmt.Sprintf("%sPrivateCollection", clientMSPID)
	if err := ctx.GetStub().PutPrivateData(orgCollection, globalModelHash, privateDetailsJSON); err != nil {
		return fmt.Errorf("failed to put private details: %w", err)
	}
	return nil
}
// VerifyLocalModels checks each hash in localModelHashes against the local
// model chaincode. For every hash it invokes "ReadLocalModel" on chaincode
// "local_model_chaincode" over channel "mychannel", decodes the response
// payload as a LocalModel, and requires that the model's run ID equals runID,
// its round ID equals roundID, and its root global model hash equals
// previousGlobalModelHash.
//
// The first failing hash aborts the loop and its error is returned; nil is
// returned when every hash passes, including when the list is empty.
//
// NOTE(review): this method is exported, so it is presumably also callable as
// a transaction through contractapi — confirm that is intended.
func (s *GlobalModelSmartContract) VerifyLocalModels(ctx contractapi.TransactionContextInterface,
	localModelHashes []string,
	previousGlobalModelHash string,
	runID string,
	roundID uint64) error {
	const (
		localModelChaincode = "local_model_chaincode"
		localModelChannel   = "mychannel"
		readLocalModelFn    = "ReadLocalModel"
		statusOK            = 200
	)

	for i := range localModelHashes {
		hash := localModelHashes[i]

		// Ask the local model chaincode for the model behind this hash.
		args := [][]byte{[]byte(readLocalModelFn), []byte(hash)}
		resp := ctx.GetStub().InvokeChaincode(localModelChaincode, args, localModelChannel)
		if resp.Status != statusOK {
			return fmt.Errorf("local model %s verification failed: %s", hash, resp.Message)
		}

		var fetched LocalModel
		if decodeErr := json.Unmarshal(resp.Payload, &fetched); decodeErr != nil {
			return fmt.Errorf("failed to unmarshal local model %s: %v", hash, decodeErr)
		}

		// Cases are evaluated top to bottom, so the run check wins over the
		// round check, which wins over the lineage check.
		switch {
		case fetched.RunID != runID:
			return fmt.Errorf("local model %s belongs to different run", hash)
		case fetched.RoundID != roundID:
			return fmt.Errorf("local model %s belongs to different round", hash)
		case fetched.RootGlobalModelHash != previousGlobalModelHash:
			return fmt.Errorf("local model %s was not derived from the previous global model", hash)
		}
	}
	return nil
}
// GlobalModelExists reports whether the world state holds a value under
// globalModelHash. A failed state read yields false together with a wrapped
// error message.
func (s *GlobalModelSmartContract) GlobalModelExists(ctx contractapi.TransactionContextInterface, globalModelHash string) (bool, error) {
	stored, readErr := ctx.GetStub().GetState(globalModelHash)
	if readErr != nil {
		return false, fmt.Errorf("failed to read from world state: %v", readErr)
	}
	if stored == nil {
		// Nothing has been written under this key.
		return false, nil
	}
	return true, nil
}