diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php
index 32d1187d37..8519ed84f4 100644
--- a/lib/Command/RunMetaEstimator.php
+++ b/lib/Command/RunMetaEstimator.php
@@ -10,12 +10,10 @@
namespace OCA\Mail\Command;
use OCA\Mail\Service\AccountService;
-use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
use OCA\Mail\Service\Classification\ImportanceClassifier;
use OCA\Mail\Support\ConsoleLoggerDecorator;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\IConfig;
-use Psr\Container\ContainerInterface;
use Psr\Log\LoggerInterface;
use Rubix\ML\Backends\Amp;
use Rubix\ML\Classifiers\KNearestNeighbors;
@@ -28,25 +26,21 @@
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
use Symfony\Component\Console\Input\InputInterface;
-use Symfony\Component\Console\Input\InputOption;
use Symfony\Component\Console\Output\OutputInterface;
class RunMetaEstimator extends Command {
public const ARGUMENT_ACCOUNT_ID = 'account-id';
public const ARGUMENT_SHUFFLE = 'shuffle';
- public const ARGUMENT_LOAD_DATA = 'load-data';
private AccountService $accountService;
private LoggerInterface $logger;
private ImportanceClassifier $classifier;
- private ContainerInterface $container;
private IConfig $config;
public function __construct(
AccountService $accountService,
LoggerInterface $logger,
ImportanceClassifier $classifier,
- ContainerInterface $container,
IConfig $config,
) {
parent::__construct();
@@ -54,7 +48,6 @@ public function __construct(
$this->accountService = $accountService;
$this->logger = $logger;
$this->classifier = $classifier;
- $this->container = $container;
$this->config = $config;
}
@@ -63,12 +56,6 @@ protected function configure(): void {
$this->setDescription('Run the meta estimator for an account');
$this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED);
$this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training');
- $this->addOption(
- self::ARGUMENT_LOAD_DATA,
- null,
- InputOption::VALUE_REQUIRED,
- 'Load training data set from a JSON file'
- );
}
public function isEnabled(): bool {
@@ -86,26 +73,11 @@ protected function execute(InputInterface $input, OutputInterface $output): int
return 1;
}
- /** @var CompositeExtractor $extractor */
- $extractor = $this->container->get(CompositeExtractor::class);
$consoleLogger = new ConsoleLoggerDecorator(
$this->logger,
$output
);
- if ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) {
- $json = file_get_contents($loadDataPath);
- $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR);
- } else {
- $dataSet = $this->classifier->buildDataSet(
- $account,
- $extractor,
- $consoleLogger,
- null,
- $shuffle,
- );
- }
-
$estimator = static function () use ($consoleLogger) {
$params = [
[5, 10, 15, 20, 25, 30, 35, 40], // Neighbors
@@ -125,28 +97,15 @@ protected function execute(InputInterface $input, OutputInterface $output): int
};
/** @var GridSearch $metaEstimator */
- if ($dataSet) {
- $metaEstimator = $this->classifier->trainWithCustomDataSet(
- $account,
- $consoleLogger,
- $dataSet,
- $extractor,
- $estimator,
- null,
- false,
- );
- } else {
- $metaEstimator = $this->classifier->train(
- $account,
- $consoleLogger,
- $extractor,
- $estimator,
- $shuffle,
- false,
- );
- }
+ $metaEstimator = $this->classifier->train(
+ $account,
+ $consoleLogger,
+ $estimator,
+ $shuffle,
+ false,
+ );
- if ($metaEstimator) {
+ if ($metaEstimator !== null) {
$output->writeln("Best estimator: {$metaEstimator->base()}");
}
diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php
index 28963f496b..a3eaf94aa9 100644
--- a/lib/Command/TrainAccount.php
+++ b/lib/Command/TrainAccount.php
@@ -11,12 +11,9 @@
use OCA\Mail\Service\AccountService;
use OCA\Mail\Service\Classification\ClassificationSettingsService;
-use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
-use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor;
use OCA\Mail\Service\Classification\ImportanceClassifier;
use OCA\Mail\Support\ConsoleLoggerDecorator;
use OCP\AppFramework\Db\DoesNotExistException;
-use Psr\Container\ContainerInterface;
use Psr\Log\LoggerInterface;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
@@ -27,39 +24,28 @@
class TrainAccount extends Command {
public const ARGUMENT_ACCOUNT_ID = 'account-id';
- public const ARGUMENT_OLD = 'old';
- public const ARGUMENT_OLD_ESTIMATOR = 'old-estimator';
- public const ARGUMENT_OLD_EXTRACTOR = 'old-extractor';
public const ARGUMENT_SHUFFLE = 'shuffle';
- public const ARGUMENT_SAVE_DATA = 'save-data';
- public const ARGUMENT_LOAD_DATA = 'load-data';
public const ARGUMENT_DRY_RUN = 'dry-run';
public const ARGUMENT_FORCE = 'force';
private AccountService $accountService;
private ImportanceClassifier $classifier;
private LoggerInterface $logger;
- private ContainerInterface $container;
private ClassificationSettingsService $classificationSettingsService;
public function __construct(AccountService $service,
ImportanceClassifier $classifier,
ClassificationSettingsService $classificationSettingsService,
- LoggerInterface $logger,
- ContainerInterface $container) {
+ LoggerInterface $logger) {
parent::__construct();
$this->accountService = $service;
$this->classifier = $classifier;
$this->logger = $logger;
- $this->container = $container;
$this->classificationSettingsService = $classificationSettingsService;
}
- /**
- * @return void
- */
- protected function configure() {
+ protected function configure(): void {
$this->setName('mail:account:train');
$this->setDescription('Train the classifier of new messages');
$this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED);
@@ -76,18 +62,6 @@ protected function configure() {
null,
'Train an estimator even if the classification is disabled by the user'
);
- $this->addOption(
- self::ARGUMENT_SAVE_DATA,
- null,
- InputOption::VALUE_REQUIRED,
- 'Save training data set to a JSON file'
- );
- $this->addOption(
- self::ARGUMENT_LOAD_DATA,
- null,
- InputOption::VALUE_REQUIRED,
- 'Load training data set from a JSON file'
- );
}
protected function execute(InputInterface $input, OutputInterface $output): int {
@@ -108,50 +82,18 @@ protected function execute(InputInterface $input, OutputInterface $output): int
return 2;
}
- /** @var IExtractor $extractor */
- $extractor = $this->container->get(CompositeExtractor::class);
-
$consoleLogger = new ConsoleLoggerDecorator(
$this->logger,
$output
);
- $dataSet = null;
- if ($saveDataPath = $input->getOption(self::ARGUMENT_SAVE_DATA)) {
- $dataSet = $this->classifier->buildDataSet(
- $account,
- $extractor,
- $consoleLogger,
- null,
- $shuffle,
- );
- $json = json_encode($dataSet, JSON_THROW_ON_ERROR);
- file_put_contents($saveDataPath, $json);
- } elseif ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) {
- $json = file_get_contents($loadDataPath);
- $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR);
- }
-
- if ($dataSet) {
- $this->classifier->trainWithCustomDataSet(
- $account,
- $consoleLogger,
- $dataSet,
- $extractor,
- null,
- null,
- !$dryRun
- );
- } else {
- $this->classifier->train(
- $account,
- $consoleLogger,
- $extractor,
- null,
- $shuffle,
- !$dryRun
- );
- }
+ $this->classifier->train(
+ $account,
+ $consoleLogger,
+ null,
+ $shuffle,
+ !$dryRun
+ );
$mbs = (int)(memory_get_peak_usage() / 1024 / 1024);
$output->writeln('' . $mbs . 'MB of memory used');
diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php
index 374a45a689..44bbca9e6f 100644
--- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php
+++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php
@@ -15,7 +15,6 @@
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Transformers\MultibyteTextNormalizer;
use Rubix\ML\Transformers\TfIdfTransformer;
-use Rubix\ML\Transformers\Transformer;
use Rubix\ML\Transformers\WordCountVectorizer;
use RuntimeException;
use function array_column;
@@ -23,7 +22,7 @@
class SubjectExtractor implements IExtractor {
private WordCountVectorizer $wordCountVectorizer;
- private Transformer $tfidf;
+ private TfIdfTransformer $tfidf;
private int $max = -1;
public function __construct() {
diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php
index df05a355b2..fea58e4556 100644
--- a/lib/Service/Classification/ImportanceClassifier.php
+++ b/lib/Service/Classification/ImportanceClassifier.php
@@ -131,6 +131,17 @@ private static function createDefaultEstimator(): KNearestNeighbors {
return new KNearestNeighbors(15, true, new Manhattan());
}
+ /**
+ * Resolve the CompositeExtractor used for feature extraction from the container.
+ *
+ * @throws ServiceException If the container lookup fails with a ContainerExceptionInterface
+ */
+ private function createExtractor(): CompositeExtractor {
+ try {
+ $extractor = $this->container->get(CompositeExtractor::class);
+ } catch (ContainerExceptionInterface $lookupError) {
+ // Wrap the container failure, keeping the original exception as the previous one.
+ throw new ServiceException('Default extractor is not available', 0, $lookupError);
+ }
+
+ return $extractor;
+ }
+
private function filterMessageHasSenderEmail(Message $message): bool {
return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null;
}
@@ -201,7 +212,6 @@ public function buildDataSet(
*
* @param Account $account
* @param LoggerInterface $logger
- * @param ?IExtractor $extractor The extractor to use for feature extraction. If null, the default extractor will be used.
* @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used.
* @param bool $shuffleDataSet Shuffle the data set before training
* @param bool $persist Persist the trained classifier to use it for message classification
@@ -213,21 +223,13 @@ public function buildDataSet(
public function train(
Account $account,
LoggerInterface $logger,
- ?IExtractor $extractor = null,
?Closure $estimator = null,
bool $shuffleDataSet = false,
bool $persist = true,
): ?Estimator {
$perf = $this->performanceLogger->start('importance classifier training');
- if ($extractor === null) {
- try {
- $extractor = $this->container->get(CompositeExtractor::class);
- } catch (ContainerExceptionInterface $e) {
- throw new ServiceException('Default extractor is not available', 0, $e);
- }
- }
-
+ $extractor = $this->createExtractor();
$dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet);
if ($dataSet === null) {
return null;
@@ -250,7 +252,7 @@ public function train(
* @param Account $account
* @param LoggerInterface $logger
* @param array $dataSet Training data set built by buildDataSet()
- * @param IExtractor $extractor Extractor used to extract the given data set
+ * @param CompositeExtractor $extractor Extractor used to extract the given data set
* @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used.
* @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task
* @param bool $persist Persist the trained classifier to use it for message classification
@@ -259,7 +261,7 @@ public function train(
*
* @throws ServiceException
*/
- public function trainWithCustomDataSet(
+ private function trainWithCustomDataSet(
Account $account,
LoggerInterface $logger,
array $dataSet,
diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php
index d25c40fb53..fd29e67bb3 100644
--- a/lib/Service/Classification/PersistenceService.php
+++ b/lib/Service/Classification/PersistenceService.php
@@ -142,7 +142,7 @@ public function loadLatest(Account $account): ?ClassifierPipeline {
$extractor = $this->loadExtractor($transformers);
- return new ClassifierPipeline($estimator, $extractor, $transformers);
+ return new ClassifierPipeline($estimator, $extractor);
}
/**
diff --git a/lib/Service/Classification/RubixMemoryPersister.php b/lib/Service/Classification/RubixMemoryPersister.php
index 2b170b38b5..c3abff2463 100644
--- a/lib/Service/Classification/RubixMemoryPersister.php
+++ b/lib/Service/Classification/RubixMemoryPersister.php
@@ -11,15 +11,14 @@
use Rubix\ML\Encoding;
use Rubix\ML\Persisters\Persister;
-use ValueError;
class RubixMemoryPersister implements Persister {
public function __construct(
- private ?string $data = null,
+ private string $data = '',
) {
}
- public function getData(): ?string {
+ public function getData(): string {
return $this->data;
}
@@ -28,14 +27,10 @@ public function save(Encoding $encoding): void {
}
public function load(): Encoding {
- if ($this->data === null) {
- throw new ValueError('Trying to load encoding when no data is available');
- }
-
return new Encoding($this->data);
}
public function __toString() {
- return (string)self::class;
+ return self::class;
}
}