Skip to content

Commit

Permalink
Code review fixes 2
Browse files Browse the repository at this point in the history
  • Loading branch information
nicl-nno committed Feb 26, 2024
1 parent 93606a3 commit dcc87a3
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def run_image_classification_automl(train_dataset: tuple,
labels=y_test,
task=task)

dataset_to_train = dataset_to_train.subset_range(0, 100)
dataset_to_train = dataset_to_train.subset_range(0, min(100, dataset_to_train.features.shape[0]))

initial_pipeline = cnn_composite_pipeline()
initial_pipeline.show()
Expand All @@ -106,7 +106,7 @@ def run_image_classification_automl(train_dataset: tuple,
composer_requirements = PipelineComposerRequirements(
primary=get_operations_for_task(task=task, mode='all'),
timeout=datetime.timedelta(minutes=3),
num_of_generations=20, n_jobs=1
num_of_generations=20, n_jobs=1, cv_folds=None
)

pop_size = 5
Expand Down
30 changes: 3 additions & 27 deletions examples/advanced/customization/strategies/image_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from typing import Optional

from examples.advanced.customization.implementations.cnn_impls import MyCNNImplementation
from fedot.core.data.data import InputData, OutputData
from fedot.core.operations.evaluation.evaluation_interfaces import EvaluationStrategy
from fedot.core.data.data import InputData
from fedot.core.operations.evaluation.classification import FedotClassificationStrategy
from fedot.core.operations.operation_parameters import OperationParameters
from fedot.utilities.random import ImplementationRandomStateHandler

warnings.filterwarnings("ignore", category=UserWarning)


class ImageClassificationStrategy(EvaluationStrategy):
class ImageClassificationStrategy(FedotClassificationStrategy):
_operations_by_types = {
'cnn_1': MyCNNImplementation
}
Expand All @@ -33,27 +33,3 @@ def fit(self, train_data: InputData):
with ImplementationRandomStateHandler(implementation=operation_implementation):
operation_implementation.fit(train_data)
return operation_implementation

def predict(self, trained_operation, predict_data: InputData) -> OutputData:
"""
Predict method for classification task for predict stage
:param trained_operation: model object
:param predict_data: data used for prediction
:return: prediction target
"""
n_classes = len(trained_operation.classes_)
if self.output_mode == 'labels':
prediction = trained_operation.predict(predict_data)
elif self.output_mode in ['probs', 'full_probs', 'default']:
prediction = trained_operation.predict_proba(predict_data)
if n_classes < 2:
raise ValueError('Data set contain only 1 target class. Please reformat your data.')
elif n_classes == 2 and self.output_mode != 'full_probs' and len(prediction.shape) > 1:
prediction = prediction[:, 1]
else:
raise ValueError(f'Output model {self.output_mode} is not supported')

# Convert prediction to output (if it is required)
converted = self._convert_to_output(prediction, predict_data)
return converted
2 changes: 1 addition & 1 deletion fedot/core/operations/evaluation/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def predict(self, trained_operation, predict_data: InputData) -> OutputData:
if n_classes < 2:
raise ValueError('Data set contain only 1 target class. Please reformat your data.')
elif n_classes == 2 and self.output_mode != 'full_probs' and len(prediction.shape) > 1:
prediction = prediction[:, 1]
prediction = prediction[:, prediction.shape[1] - 1]
else:
raise ValueError(f'Output model {self.output_mode} is not supported')

Expand Down
Binary file modified test/data/test_labels.npy
Binary file not shown.
Binary file modified test/data/training_labels.npy
Binary file not shown.

0 comments on commit dcc87a3

Please sign in to comment.