diff --git a/examples/advanced/customization/image_classification_with_custom_models.py b/examples/advanced/customization/image_classification_with_custom_models.py index 03f0a5a82a..87f4008bcc 100644 --- a/examples/advanced/customization/image_classification_with_custom_models.py +++ b/examples/advanced/customization/image_classification_with_custom_models.py @@ -87,7 +87,7 @@ def run_image_classification_automl(train_dataset: tuple, labels=y_test, task=task) - dataset_to_train = dataset_to_train.subset_range(0, 100) + dataset_to_train = dataset_to_train.subset_range(0, min(100, dataset_to_train.features.shape[0])) initial_pipeline = cnn_composite_pipeline() initial_pipeline.show() @@ -106,7 +106,7 @@ def run_image_classification_automl(train_dataset: tuple, composer_requirements = PipelineComposerRequirements( primary=get_operations_for_task(task=task, mode='all'), timeout=datetime.timedelta(minutes=3), - num_of_generations=20, n_jobs=1 + num_of_generations=20, n_jobs=1, cv_folds=None ) pop_size = 5 diff --git a/examples/advanced/customization/strategies/image_class.py b/examples/advanced/customization/strategies/image_class.py index def060bec8..947915c3a4 100644 --- a/examples/advanced/customization/strategies/image_class.py +++ b/examples/advanced/customization/strategies/image_class.py @@ -2,15 +2,15 @@ from typing import Optional from examples.advanced.customization.implementations.cnn_impls import MyCNNImplementation -from fedot.core.data.data import InputData, OutputData -from fedot.core.operations.evaluation.evaluation_interfaces import EvaluationStrategy +from fedot.core.data.data import InputData +from fedot.core.operations.evaluation.classification import FedotClassificationStrategy from fedot.core.operations.operation_parameters import OperationParameters from fedot.utilities.random import ImplementationRandomStateHandler warnings.filterwarnings("ignore", category=UserWarning) -class ImageClassificationStrategy(EvaluationStrategy): +class ImageClassificationStrategy(FedotClassificationStrategy): _operations_by_types = { 'cnn_1': MyCNNImplementation } @@ -33,27 +33,3 @@ def fit(self, train_data: InputData): with ImplementationRandomStateHandler(implementation=operation_implementation): operation_implementation.fit(train_data) return operation_implementation - - def predict(self, trained_operation, predict_data: InputData) -> OutputData: - """ - Predict method for classification task for predict stage - - :param trained_operation: model object - :param predict_data: data used for prediction - :return: prediction target - """ - n_classes = len(trained_operation.classes_) - if self.output_mode == 'labels': - prediction = trained_operation.predict(predict_data) - elif self.output_mode in ['probs', 'full_probs', 'default']: - prediction = trained_operation.predict_proba(predict_data) - if n_classes < 2: - raise ValueError('Data set contain only 1 target class. Please reformat your data.') - elif n_classes == 2 and self.output_mode != 'full_probs' and len(prediction.shape) > 1: - prediction = prediction[:, 1] - else: - raise ValueError(f'Output model {self.output_mode} is not supported') - - # Convert prediction to output (if it is required) - converted = self._convert_to_output(prediction, predict_data) - return converted diff --git a/fedot/core/operations/evaluation/classification.py b/fedot/core/operations/evaluation/classification.py index 0b90b49648..5262e915bd 100644 --- a/fedot/core/operations/evaluation/classification.py +++ b/fedot/core/operations/evaluation/classification.py @@ -82,7 +82,7 @@ def predict(self, trained_operation, predict_data: InputData) -> OutputData: if n_classes < 2: raise ValueError('Data set contain only 1 target class. Please reformat your data.') elif n_classes == 2 and self.output_mode != 'full_probs' and len(prediction.shape) > 1: - prediction = prediction[:, 1] + prediction = prediction[:, prediction.shape[1] - 1] else: raise ValueError(f'Output model {self.output_mode} is not supported') diff --git a/test/data/test_labels.npy b/test/data/test_labels.npy index 815f538382..b3408a2e61 100644 Binary files a/test/data/test_labels.npy and b/test/data/test_labels.npy differ diff --git a/test/data/training_labels.npy b/test/data/training_labels.npy index a0454f741b..b3408a2e61 100644 Binary files a/test/data/training_labels.npy and b/test/data/training_labels.npy differ