Skip to content

Commit

Permalink
presigned url and predict entrypoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Jun 11, 2024
1 parent f532281 commit e7fa8f4
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 9 deletions.
4 changes: 3 additions & 1 deletion examples/mnist-pytorch/client/fedn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ entry_points:
train:
command: python train.py
validate:
command: python validate.py
command: python validate.py
predict:
command: python predict.py
39 changes: 39 additions & 0 deletions examples/mnist-pytorch/client/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
import sys

import torch
from data import load_data
from model import load_parameters

from fedn.utils.helpers.helpers import save_metrics

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


def predict(in_model_path, out_artifact_path, data_path=None):
"""Validate model.
:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_artifact_path: The path to save the predict output to.
:type out_artifact_path: str
:param data_path: The path to the data file.
:type data_path: str
"""
# Load data
x_test, y_test = load_data(data_path, is_train=False)

# Load model
model = load_parameters(in_model_path)
model.eval()

# Predict
with torch.no_grad():
y_pred = model(x_test)
# Save prediction to file/artifact, the artifact will be uploaded to the object store by the client
torch.save(y_pred, out_artifact_path)


if __name__ == "__main__":
predict(sys.argv[1], sys.argv[2])
76 changes: 71 additions & 5 deletions fedn/network/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from shutil import copytree

import grpc
import requests
from cryptography.hazmat.primitives.serialization import Encoding
from google.protobuf.json_format import MessageToJson
from OpenSSL import SSL
Expand All @@ -22,13 +23,11 @@
import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_PACKAGE_EXTRACT_DIR
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)
from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream
from fedn.network.clients.connect import ConnectorClient, Status
from fedn.network.clients.package import PackageRuntime
from fedn.network.clients.state import ClientState, ClientStateToString
from fedn.network.combiner.modelservice import (get_tmp_path,
upload_request_generator)
from fedn.network.combiner.modelservice import get_tmp_path, upload_request_generator
from fedn.utils.dispatcher import Dispatcher
from fedn.utils.helpers.helpers import get_helper

Expand Down Expand Up @@ -438,12 +437,18 @@ def _listen_to_task_stream(self):
request=request,
sesssion_id=request.session_id,
)
logger.info("Received model update request of type {} for model_id {}".format(request.type, request.model_id))
logger.info("Received task request of type {} for model_id {}".format(request.type, request.model_id))

if request.type == fedn.StatusType.MODEL_UPDATE and self.config["trainer"]:
self.inbox.put(("train", request))
elif request.type == fedn.StatusType.MODEL_VALIDATION and self.config["validator"]:
self.inbox.put(("validate", request))
elif request.type == fedn.StatusType.INFERENCE and self.config["validator"]:
logger.info("Received inference request for model_id {}".format(request.model_id))
presined_url = json.loads(request.data)
presined_url = presined_url["presigned_url"]
logger.info("Inference presigned URL: {}".format(presined_url))
self.inbox.put(("infer", request))
else:
logger.error("Unknown request type: {}".format(request.type))

Expand Down Expand Up @@ -586,6 +591,51 @@ def _process_validation_request(self, model_id: str, is_inference: bool, session
self.state = ClientState.idle
return validation

def _process_inference_request(self, model_id: str, session_id: str, presigned_url: str):
"""Process an inference request.
:param model_id: The model id of the model to be used for inference.
:type model_id: str
:param session_id: The id of the current session.
:type session_id: str
:param presigned_url: The presigned URL for the data to be used for inference.
:type presigned_url: str
:return: None
"""
self.send_status(f"Processing inference request for model_id {model_id}", sesssion_id=session_id)
try:
model = self.get_model_from_combiner(str(model_id))
if model is None:
logger.error("Could not retrieve model from combiner. Aborting inference request.")
return
inpath = self.helper.get_tmp_path()

with open(inpath, "wb") as fh:
fh.write(model.getbuffer())

outpath = get_tmp_path()
self.dispatcher.run_cmd(f"predict {inpath} {outpath}")

# Upload the inference result to the presigned URL
with open(outpath, "rb") as fh:
response = requests.put(presigned_url, data=fh.read())

os.unlink(inpath)
os.unlink(outpath)

if response.status_code != 200:
logger.warning("Inference upload failed with status code {}".format(response.status_code))
self.state = ClientState.idle
return

except Exception as e:
logger.warning("Inference failed with exception {}".format(e))
self.state = ClientState.idle
return

self.state = ClientState.idle
return

def process_request(self):
"""Process training and validation tasks."""
while True:
Expand Down Expand Up @@ -682,6 +732,22 @@ def process_request(self):

self.state = ClientState.idle
self.inbox.task_done()
elif task_type == "infer":
self.state = ClientState.inferencing
try:
presigned_url = json.loads(request.data)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode inference request data: {e}")
self.state = ClientState.idle
continue

if "presigned_url" not in presigned_url:
logger.error("Inference request missing presigned_url.")
self.state = ClientState.idle
continue
presigned_url = presigned_url["presigned_url"]
_ = self._process_inference_request(request.model_id, request.session_id, presigned_url)
self.state = ClientState.idle
except queue.Empty:
pass
except grpc.RpcError as e:
Expand Down
1 change: 1 addition & 0 deletions fedn/network/clients/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class ClientState(Enum):
idle = 1
training = 2
validating = 3
inferencing = 4


def ClientStateToString(state):
Expand Down
7 changes: 4 additions & 3 deletions fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,17 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, cl
if len(clients) == 0:
clients = self.get_active_validators()
elif request_type == fedn.StatusType.INFERENCE:
request.data = json.dumps(config)
if len(clients) == 0:
# TODO: add inference clients type
clients = self.get_active_validators()

# TODO: if inference, request.data should be user-defined data/parameters

for client in clients:
request.receiver.name = client
request.receiver.role = fedn.WORKER
if request_type == fedn.StatusType.INFERENCE:
presigned_url = self.repository.presigned_put_url(self.repository.inference_bucket, f"{client}/{session_id}")
# TODO: in inference, request.data should also contain user-defined data/parameters
request.data = json.dumps({"presigned_url": presigned_url})
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)

return request, clients
Expand Down
34 changes: 34 additions & 0 deletions fedn/network/storage/s3/repository.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import uuid

from fedn.common.log_config import logger
Expand All @@ -10,12 +11,17 @@ class Repository:
def __init__(self, config):
self.model_bucket = config["storage_bucket"]
self.context_bucket = config["context_bucket"]
try:
self.inference_bucket = config["inference_bucket"]
except KeyError:
self.inference_bucket = "fedn-inference"

# TODO: Make a plug-in solution
self.client = MINIORepository(config)

self.client.create_bucket(self.context_bucket)
self.client.create_bucket(self.model_bucket)
self.client.create_bucket(self.inference_bucket)

def get_model(self, model_id):
"""Retrieve a model with id model_id.
Expand Down Expand Up @@ -104,3 +110,31 @@ def delete_compute_package(self, compute_package):
except Exception:
logger.error("Failed to delete compute_package from repository.")
raise

def presigned_put_url(self, bucket: str, object_name: str, expires: datetime.timedelta = datetime.timedelta(hours=1)):
"""Generate a presigned URL for an upload object request.
:param bucket: The bucket name
:type bucket: str
:param object_name: The object name
:type object_name: str
:param expires: The time the URL is valid
:type expires: datetime.timedelta
:return: The URL
:rtype: str
"""
return self.client.client.presigned_put_object(bucket, object_name, expires)

def presigned_get_url(self, bucket: str, object_name: str, expires: datetime.timedelta = datetime.timedelta(hours=1)) -> str:
"""Generate a presigned URL for a download object request.
:param bucket: The bucket name
:type bucket: str
:param object_name: The object name
:type object_name: str
:param expires: The time the URL is valid
:type expires: datetime.timedelta
:return: The URL
:rtype: str
"""
return self.client.client.presigned_get_object(bucket, object_name, expires)

0 comments on commit e7fa8f4

Please sign in to comment.