Skip to content

Commit

Permalink
add: Joint Inference paradigm and cloud-edge collaborative inference …
Browse files Browse the repository at this point in the history
…example

Signed-off-by: Yu Fan <[email protected]>

add: sedna 0.6.0.1 and move 0.4.1 to third_party-bk

Signed-off-by: Yu Fan <[email protected]>
  • Loading branch information
FuryMartin committed Oct 31, 2024
1 parent 3fa3879 commit fb95a9c
Show file tree
Hide file tree
Showing 43 changed files with 2,520 additions and 31 deletions.
9 changes: 9 additions & 0 deletions core/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class DatasetFormat(Enum):
TXT = "txt"
JSON = "json"
JSONL = "jsonl"
JSONFORLLM = "jsonforllm"


class ParadigmType(Enum):
Expand All @@ -40,6 +41,7 @@ class ParadigmType(Enum):
LIFELONG_LEARNING = "lifelonglearning"
FEDERATED_LEARNING = "federatedlearning"
FEDERATED_CLASS_INCREMENTAL_LEARNING = "federatedclassincrementallearning"
JOINT_INFERENCE = "jointinference"


class ModuleType(Enum):
Expand All @@ -49,6 +51,13 @@ class ModuleType(Enum):

BASEMODEL = "basemodel"

# JOINT INFERENCE
EDGEMODEL = "edgemodel"
CLOUDMODEL = "cloudmodel"

# Dataset Preprocessor
DATA_PROCESSOR = "dataset_processor"

# HEM
HARD_EXAMPLE_MINING = "hard_example_mining"

Expand Down
6 changes: 5 additions & 1 deletion core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ def is_local_dir(url):

def get_file_format(url):
"""Get file format of the url."""
return os.path.splitext(url)[-1][1:]
# Check if the url
if os.path.basename(url) == "metadata.json":
return "jsonforllm"

# Check if the url
return os.path.splitext(url)[-1][1:]

def parse_kwargs(func, **kwargs):
"""Get valid parameters of the func in kwargs."""
Expand Down
3 changes: 1 addition & 2 deletions core/storymanager/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
def print_table(rank_file):
""" print rank of the test"""
with open(rank_file, "r", encoding="utf-8") as file:
table = from_csv(file)
table = from_csv(file, delimiter=",")
print(table)

def draw_heatmap_picture(output, title, matrix):
Expand All @@ -40,7 +40,6 @@ def draw_heatmap_picture(output, title, matrix):
plt.title(title, fontsize=15)
plt.colorbar(format='%.2f')
output_dir = os.path.join(output, f"output/{title}-heatmap.png")
#print(output_dir)
plt.savefig(output_dir)
plt.show()

Expand Down
6 changes: 5 additions & 1 deletion core/testcasecontroller/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
MultiedgeInference,
LifelongLearning,
FederatedLearning,
FederatedClassIncrementalLearning
FederatedClassIncrementalLearning,
JointInference
)
from core.testcasecontroller.generation_assistant import get_full_combinations

Expand Down Expand Up @@ -123,6 +124,9 @@ def paradigm(self, workspace: str, **kwargs):
if self.paradigm_type == ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value:
return FederatedClassIncrementalLearning(workspace, **config)

if self.paradigm_type == ParadigmType.JOINT_INFERENCE.value:
return JointInference(workspace, **config)

return None

def _check_fields(self):
Expand Down
17 changes: 11 additions & 6 deletions core/testcasecontroller/algorithm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _check_fields(self):
if not isinstance(self.url, str):
raise ValueError(f"module url({self.url}) must be string type.")

#pylint: disable=too-many-branches
def get_module_instance(self, module_type):
"""
get function of algorithm module by using module type
Expand All @@ -86,7 +87,6 @@ def get_module_instance(self, module_type):
function
"""
print(f'hyperparameters_list: {self.hyperparameters_list}')
class_factory_type = ClassType.GENERAL
if module_type in [ModuleType.HARD_EXAMPLE_MINING.value]:
class_factory_type = ClassType.HEM
Expand All @@ -110,13 +110,11 @@ def get_module_instance(self, module_type):
elif module_type in [ModuleType.AGGREGATION.value]:
class_factory_type = ClassType.FL_AGG
agg = None
print(self.url)
if self.url :
try:
utils.load_module(self.url)
agg = ClassFactory.get_cls(
type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters)
print(agg)
except Exception as err:
raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) "
f"failed, error: {err}.") from err
Expand All @@ -125,10 +123,17 @@ def get_module_instance(self, module_type):
if self.url:
try:
utils.load_module(self.url)
# pylint: disable=E1134
func = ClassFactory.get_cls(
type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters)

if class_factory_type == ClassType.HEM:
func = {"method": self.name, "param":self.hyperparameters}
else:
func = ClassFactory.get_cls(
type_name=class_factory_type,
t_cls_name=self.name
)(**self.hyperparameters)

return func

except Exception as err:
raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) "
f"failed, error: {err}.") from err
Expand Down
1 change: 1 addition & 0 deletions core/testcasecontroller/algorithm/paradigm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from .multiedge_inference import MultiedgeInference
from .lifelong_learning import LifelongLearning
from .federated_learning import FederatedLearning, FederatedClassIncrementalLearning
from .joint_inference import JointInference
18 changes: 17 additions & 1 deletion core/testcasecontroller/algorithm/paradigm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from sedna.core.incremental_learning import IncrementalLearning
from sedna.core.lifelong_learning import LifelongLearning
from sedna.core.joint_inference import JointInference
from core.common.constant import ModuleType, ParadigmType
from .sedna_federated_learning import FederatedLearning

Expand Down Expand Up @@ -76,6 +77,7 @@ def _get_module_instances(self):
module_instances.update({module_type: func})
return module_instances

# pylint: disable=too-many-return-statements
def build_paradigm_job(self, paradigm_type):
"""
build paradigm job instance according to paradigm type.
Expand Down Expand Up @@ -103,7 +105,10 @@ def build_paradigm_job(self, paradigm_type):

if paradigm_type == ParadigmType.LIFELONG_LEARNING.value:
return LifelongLearning(
estimator=self.module_instances.get(ModuleType.BASEMODEL.value),
seen_estimator=self.module_instances.get(
ModuleType.BASEMODEL.value
),
unseen_estimator=None,
task_definition=self.module_instances.get(
ModuleType.TASK_DEFINITION.value
),
Expand Down Expand Up @@ -144,4 +149,15 @@ def build_paradigm_job(self, paradigm_type):
estimator=self.module_instances.get(ModuleType.BASEMODEL.value)
)

if paradigm_type == ParadigmType.JOINT_INFERENCE.value:
return JointInference(
estimator=self.module_instances.get(
ModuleType.EDGEMODEL.value),
cloud=self.module_instances.get(
ModuleType.CLOUDMODEL.value),
hard_example_mining=self.module_instances.get(
ModuleType.HARD_EXAMPLE_MINING.value),
LCReporter_enable=False
)

return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=missing-module-docstring
from .joint_inference import JointInference
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright 2024 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Cloud-Edge Joint Inference"""

import os
from tqdm import tqdm

from core.common.log import LOGGER
from core.common.constant import ParadigmType
from core.testcasecontroller.algorithm.paradigm.base import ParadigmBase

class JointInference(ParadigmBase):
"""
Cloud-Edge-JointInference:
provide the flow of multi-edge inference paradigm.
Notes:
1. Ianvs serves as testing tools for test objects, e.g., algorithms.
2. Ianvs does NOT include code directly on test object.
3. Algorithms serve as typical test objects in Ianvs
and detailed algorithms are thus NOT included in this Ianvs python file.
4. As for the details of example test objects, e.g., algorithms,
please refer to third party packages in Ianvs example.
For example, AI workflow and interface pls refer to sedna
(sedna docs: https://sedna.readthedocs.io/en/latest/api/lib/index.html),
and module implementation pls refer to `examples' test algorithms`,
e.g., basemodel.py, hard_example_mining.py.
Parameters
---------
workspace: string
the output required for multi-edge inference paradigm.
kwargs: dict
config required for the test process of joint inference paradigm,
e.g.: hard_example_mining_mode
"""

def __init__(self, workspace, **kwargs):
ParadigmBase.__init__(self, workspace, **kwargs)
self.inference_dataset = None
self.kwargs = kwargs
self.hard_example_mining_mode = kwargs.get(
"hard_example_mining_mode",
"mining-then-inference"
)

def set_config(self):
""" Set the configuration for the joint inference paradigm.
Raises
------
KeyError
If required modules are not provided.
"""


inference_output_dir = os.path.dirname(self.workspace)
os.environ["RESULT_SAVED_URL"] = inference_output_dir
os.makedirs(inference_output_dir, exist_ok=True)

LOGGER.info("Loading dataset")

self.inference_dataset = self.dataset.load_data(
self.dataset.test_data_info,
"inference"
)

dataset_processor = self.module_instances.get("dataset_processor", None)
if callable(dataset_processor):
self.inference_dataset = dataset_processor(self.inference_dataset)

# validate module instances
required_modules = {"edgemodel", "cloudmodel", "hard_example_mining"}

if not required_modules.issubset(set(self.module_instances.keys())):
raise KeyError(
f"Required modules: {required_modules}, "
f"but got: {self.module_instances.keys()}"
)

# if hard example mining is OracleRouter,
# add the edgemodel and cloudmodel object to its kwargs so that it can use them.
mining = self.module_instances["hard_example_mining"]
param = mining.get("param")
if mining.get("method", None) == "OracleRouter":
param["edgemodel"] = self.module_instances["edgemodel"]
param["cloudmodel"] = self.module_instances["cloudmodel"]

def run(self):
"""
run the test flow of joint inference paradigm.
Returns
------
inference_result: list
system_metric_info: dict
information needed to compute system metrics.
"""
self.set_config()

job = self.build_paradigm_job(ParadigmType.JOINT_INFERENCE.value)

inference_result = self._inference(job)

self._cleanup(job)

return inference_result, self.system_metric_info

def _cleanup(self, job):
"""Call module's cleanup method to release resources
Parameters
----------
job : Sedna JointInference
Sedna JointInference API
"""

LOGGER.info("Release models")
# release module resources
for module in self.module_instances.values():
if hasattr(module, "cleanup"):
module.cleanup()

# Special call is required for hard example mining module
# since it is instantiated within the job.
mining_instance = job.hard_example_mining_algorithm
if hasattr(mining_instance, "cleanup"):
mining_instance.cleanup()

del job

def _inference(self, job):
"""Inference each data in Inference Dataset
Parameters
----------
job : Sedna JointInference
Sedna JointInference API
Returns
-------
tuple
Inference Result with the format of `(is_hard_example, res, edge_result, cloud_result)`
"""
results = []

cloud_count, edge_count = 0,0

LOGGER.info("Inference Start")

pbar = tqdm(
self.inference_dataset.x,
total=len(self.inference_dataset.x),
ncols=100
)

for data in pbar:
# inference via sedna JointInference API
infer_res = job.inference(
data,
mining_mode=self.hard_example_mining_mode
)

if infer_res[2]:
edge_count += 1
elif infer_res[3]:
cloud_count += 1

pbar.set_postfix({"Edge": edge_count, "Cloud": cloud_count})

results.append(infer_res)

LOGGER.info("Inference Finished")

return results
1 change: 0 additions & 1 deletion core/testcasecontroller/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def compute(key, matrix):
"""
Compute BWT and FWT scores for a given matrix.
"""
print(f"compute function: key={key}, matrix={matrix}, type(matrix)={type(matrix)}")

length = len(matrix)
accuracy = 0.0
Expand Down
Loading

0 comments on commit fb95a9c

Please sign in to comment.