From ebbed7202538a98cf93214ab5d2184a2851099e7 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 22 Oct 2024 12:08:31 +0200 Subject: [PATCH] Fix all remaining psalm issues --- lib/Command/RunMetaEstimator.php | 57 ++------------ lib/Command/TrainAccount.php | 76 +++---------------- .../FeatureExtraction/SubjectExtractor.php | 3 +- .../Classification/ImportanceClassifier.php | 26 ++++--- .../Classification/PersistenceService.php | 2 +- .../Classification/RubixMemoryPersister.php | 11 +-- 6 files changed, 36 insertions(+), 139 deletions(-) diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index 32d1187d37..8519ed84f4 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -10,12 +10,10 @@ namespace OCA\Mail\Command; use OCA\Mail\Service\AccountService; -use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; use OCP\IConfig; -use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Rubix\ML\Backends\Amp; use Rubix\ML\Classifiers\KNearestNeighbors; @@ -28,25 +26,21 @@ use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; -use Symfony\Component\Console\Input\InputOption; use Symfony\Component\Console\Output\OutputInterface; class RunMetaEstimator extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; public const ARGUMENT_SHUFFLE = 'shuffle'; - public const ARGUMENT_LOAD_DATA = 'load-data'; private AccountService $accountService; private LoggerInterface $logger; private ImportanceClassifier $classifier; - private ContainerInterface $container; private IConfig $config; public function __construct( AccountService $accountService, LoggerInterface $logger, ImportanceClassifier $classifier, - ContainerInterface $container, IConfig $config, ) { parent::__construct(); @@ -54,7 +48,6 @@ public function __construct( $this->accountService = $accountService; $this->logger = $logger; $this->classifier = $classifier; - $this->container = $container; $this->config = $config; } @@ -63,12 +56,6 @@ protected function configure(): void { $this->setDescription('Run the meta estimator for an account'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); - $this->addOption( - self::ARGUMENT_LOAD_DATA, - null, - InputOption::VALUE_REQUIRED, - 'Load training data set from a JSON file' - ); } public function isEnabled(): bool { @@ -86,26 +73,11 @@ protected function execute(InputInterface $input, OutputInterface $output): int return 1; } - /** @var CompositeExtractor $extractor */ - $extractor = $this->container->get(CompositeExtractor::class); $consoleLogger = new ConsoleLoggerDecorator( $this->logger, $output ); - if ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { - $json = file_get_contents($loadDataPath); - $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR); - } else { - $dataSet = $this->classifier->buildDataSet( - $account, - $extractor, - $consoleLogger, - null, - $shuffle, - ); - } - $estimator = static function () use ($consoleLogger) { $params = [ [5, 10, 15, 20, 25, 30, 35, 40], // Neighbors @@ -125,28 +97,15 @@ protected function execute(InputInterface $input, OutputInterface $output): int }; /** @var GridSearch $metaEstimator */ - if ($dataSet) { - $metaEstimator = $this->classifier->trainWithCustomDataSet( - $account, - $consoleLogger, - $dataSet, - $extractor, - $estimator, - null, - false, - ); - } else { - $metaEstimator = $this->classifier->train( - $account, - $consoleLogger, - $extractor, - $estimator, - $shuffle, - false, - ); - } + $metaEstimator = $this->classifier->train( + $account, + $consoleLogger, + $estimator, + $shuffle, + false, + ); - if ($metaEstimator) { + if ($metaEstimator !== null) { $output->writeln("Best estimator: {$metaEstimator->base()}"); } diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index 28963f496b..a3eaf94aa9 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -11,12 +11,9 @@ use OCA\Mail\Service\AccountService; use OCA\Mail\Service\Classification\ClassificationSettingsService; -use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; -use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; @@ -27,39 +24,28 @@ class TrainAccount extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; - public const ARGUMENT_OLD = 'old'; - public const ARGUMENT_OLD_ESTIMATOR = 'old-estimator'; - public const ARGUMENT_OLD_EXTRACTOR = 'old-extractor'; public const ARGUMENT_SHUFFLE = 'shuffle'; - public const ARGUMENT_SAVE_DATA = 'save-data'; - public const ARGUMENT_LOAD_DATA = 'load-data'; public const ARGUMENT_DRY_RUN = 'dry-run'; public const ARGUMENT_FORCE = 'force'; private AccountService $accountService; private ImportanceClassifier $classifier; private LoggerInterface $logger; - private ContainerInterface $container; private ClassificationSettingsService $classificationSettingsService; public function __construct(AccountService $service, ImportanceClassifier $classifier, ClassificationSettingsService $classificationSettingsService, - LoggerInterface $logger, - ContainerInterface $container) { + LoggerInterface $logger) { parent::__construct(); $this->accountService = $service; $this->classifier = $classifier; $this->logger = $logger; - $this->container = $container; $this->classificationSettingsService = $classificationSettingsService; } - /** - * @return void - */ - protected function configure() { + protected function configure(): void { $this->setName('mail:account:train'); $this->setDescription('Train the classifier of new messages'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); @@ -76,18 +62,6 @@ protected function configure() { null, 'Train an estimator even if the classification is disabled by the user' ); - $this->addOption( - self::ARGUMENT_SAVE_DATA, - null, - InputOption::VALUE_REQUIRED, - 'Save training data set to a JSON file' - ); - $this->addOption( - self::ARGUMENT_LOAD_DATA, - null, - InputOption::VALUE_REQUIRED, - 'Load training data set from a JSON file' - ); } protected function execute(InputInterface $input, OutputInterface $output): int { @@ -108,50 +82,18 @@ protected function execute(InputInterface $input, OutputInterface $output): int return 2; } - /** @var IExtractor $extractor */ - $extractor = $this->container->get(CompositeExtractor::class); - $consoleLogger = new ConsoleLoggerDecorator( $this->logger, $output ); - $dataSet = null; - if ($saveDataPath = $input->getOption(self::ARGUMENT_SAVE_DATA)) { - $dataSet = $this->classifier->buildDataSet( - $account, - $extractor, - $consoleLogger, - null, - $shuffle, - ); - $json = json_encode($dataSet, JSON_THROW_ON_ERROR); - file_put_contents($saveDataPath, $json); - } elseif ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { - $json = file_get_contents($loadDataPath); - $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR); - } - - if ($dataSet) { - $this->classifier->trainWithCustomDataSet( - $account, - $consoleLogger, - $dataSet, - $extractor, - null, - null, - !$dryRun - ); - } else { - $this->classifier->train( - $account, - $consoleLogger, - $extractor, - null, - $shuffle, - !$dryRun - ); - } + $this->classifier->train( + $account, + $consoleLogger, + null, + $shuffle, + !$dryRun + ); $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); $output->writeln('' . $mbs . 'MB of memory used'); diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index 374a45a689..44bbca9e6f 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -15,7 +15,6 @@ use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\TfIdfTransformer; -use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function array_column; @@ -23,7 +22,7 @@ class SubjectExtractor implements IExtractor { private WordCountVectorizer $wordCountVectorizer; - private Transformer $tfidf; + private TfIdfTransformer $tfidf; private int $max = -1; public function __construct() { diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index df05a355b2..fea58e4556 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -131,6 +131,17 @@ private static function createDefaultEstimator(): KNearestNeighbors { return new KNearestNeighbors(15, true, new Manhattan()); } + /** + * @throws ServiceException If the extractor is not available + */ + private function createExtractor(): CompositeExtractor { + try { + return $this->container->get(CompositeExtractor::class); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException('Default extractor is not available', 0, $e); + } + } + private function filterMessageHasSenderEmail(Message $message): bool { return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null; } @@ -201,7 +212,6 @@ public function buildDataSet( * * @param Account $account * @param LoggerInterface $logger - * @param ?IExtractor $extractor The extractor to use for feature extraction. If null, the default extractor will be used. * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. * @param bool $shuffleDataSet Shuffle the data set before training * @param bool $persist Persist the trained classifier to use it for message classification @@ -213,21 +223,13 @@ public function buildDataSet( public function train( Account $account, LoggerInterface $logger, - ?IExtractor $extractor = null, ?Closure $estimator = null, bool $shuffleDataSet = false, bool $persist = true, ): ?Estimator { $perf = $this->performanceLogger->start('importance classifier training'); - if ($extractor === null) { - try { - $extractor = $this->container->get(CompositeExtractor::class); - } catch (ContainerExceptionInterface $e) { - throw new ServiceException('Default extractor is not available', 0, $e); - } - } - + $extractor = $this->createExtractor(); $dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet); if ($dataSet === null) { return null; @@ -250,7 +252,7 @@ public function train( * @param Account $account * @param LoggerInterface $logger * @param array $dataSet Training data set built by buildDataSet() - * @param IExtractor $extractor Extractor used to extract the given data set + * @param CompositeExtractor $extractor Extractor used to extract the given data set * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. * @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task * @param bool $persist Persist the trained classifier to use it for message classification @@ -259,7 +261,7 @@ public function train( * * @throws ServiceException */ - public function trainWithCustomDataSet( + private function trainWithCustomDataSet( Account $account, LoggerInterface $logger, array $dataSet, diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index d25c40fb53..fd29e67bb3 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -142,7 +142,7 @@ public function loadLatest(Account $account): ?ClassifierPipeline { $extractor = $this->loadExtractor($transformers); - return new ClassifierPipeline($estimator, $extractor, $transformers); + return new ClassifierPipeline($estimator, $extractor); } /** diff --git a/lib/Service/Classification/RubixMemoryPersister.php b/lib/Service/Classification/RubixMemoryPersister.php index 2b170b38b5..c3abff2463 100644 --- a/lib/Service/Classification/RubixMemoryPersister.php +++ b/lib/Service/Classification/RubixMemoryPersister.php @@ -11,15 +11,14 @@ use Rubix\ML\Encoding; use Rubix\ML\Persisters\Persister; -use ValueError; class RubixMemoryPersister implements Persister { public function __construct( - private ?string $data = null, + private string $data = '', ) { } - public function getData(): ?string { + public function getData(): string { return $this->data; } @@ -28,14 +27,10 @@ public function save(Encoding $encoding): void { } public function load(): Encoding { - if ($this->data === null) { - throw new ValueError('Trying to load encoding when no data is available'); - } - return new Encoding($this->data); } public function __toString() { - return (string)self::class; + return self::class; } }