Skip to content

Commit

Permalink
Fix all remaining psalm issues
Browse files Browse the repository at this point in the history
  • Loading branch information
st3iny committed Oct 22, 2024
1 parent d0221d1 commit ebbed72
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 139 deletions.
57 changes: 8 additions & 49 deletions lib/Command/RunMetaEstimator.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
namespace OCA\Mail\Command;

use OCA\Mail\Service\AccountService;
use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
use OCA\Mail\Service\Classification\ImportanceClassifier;
use OCA\Mail\Support\ConsoleLoggerDecorator;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\IConfig;
use Psr\Container\ContainerInterface;
use Psr\Log\LoggerInterface;
use Rubix\ML\Backends\Amp;
use Rubix\ML\Classifiers\KNearestNeighbors;
Expand All @@ -28,33 +26,28 @@
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
use Symfony\Component\Console\Input\InputInterface;
use Symfony\Component\Console\Input\InputOption;
use Symfony\Component\Console\Output\OutputInterface;

class RunMetaEstimator extends Command {
public const ARGUMENT_ACCOUNT_ID = 'account-id';
public const ARGUMENT_SHUFFLE = 'shuffle';
public const ARGUMENT_LOAD_DATA = 'load-data';

private AccountService $accountService;
private LoggerInterface $logger;
private ImportanceClassifier $classifier;
private ContainerInterface $container;
private IConfig $config;

public function __construct(
AccountService $accountService,
LoggerInterface $logger,
ImportanceClassifier $classifier,
ContainerInterface $container,
IConfig $config,
) {
parent::__construct();

$this->accountService = $accountService;
$this->logger = $logger;
$this->classifier = $classifier;
$this->container = $container;
$this->config = $config;
}

Expand All @@ -63,12 +56,6 @@ protected function configure(): void {
$this->setDescription('Run the meta estimator for an account');
$this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED);
$this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training');
$this->addOption(
self::ARGUMENT_LOAD_DATA,
null,
InputOption::VALUE_REQUIRED,
'Load training data set from a JSON file'
);
}

public function isEnabled(): bool {
Expand All @@ -86,26 +73,11 @@ protected function execute(InputInterface $input, OutputInterface $output): int
return 1;
}

/** @var CompositeExtractor $extractor */
$extractor = $this->container->get(CompositeExtractor::class);
$consoleLogger = new ConsoleLoggerDecorator(
$this->logger,
$output
);

if ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) {
$json = file_get_contents($loadDataPath);
$dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR);
} else {
$dataSet = $this->classifier->buildDataSet(
$account,
$extractor,
$consoleLogger,
null,
$shuffle,
);
}

$estimator = static function () use ($consoleLogger) {
$params = [
[5, 10, 15, 20, 25, 30, 35, 40], // Neighbors
Expand All @@ -125,28 +97,15 @@ protected function execute(InputInterface $input, OutputInterface $output): int
};

/** @var GridSearch $metaEstimator */
if ($dataSet) {
$metaEstimator = $this->classifier->trainWithCustomDataSet(
$account,
$consoleLogger,
$dataSet,
$extractor,
$estimator,
null,
false,
);
} else {
$metaEstimator = $this->classifier->train(
$account,
$consoleLogger,
$extractor,
$estimator,
$shuffle,
false,
);
}
$metaEstimator = $this->classifier->train(
$account,
$consoleLogger,
$estimator,
$shuffle,
false,
);

if ($metaEstimator) {
if ($metaEstimator !== null) {
$output->writeln("<info>Best estimator: {$metaEstimator->base()}</info>");
}

Expand Down
76 changes: 9 additions & 67 deletions lib/Command/TrainAccount.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@

use OCA\Mail\Service\AccountService;
use OCA\Mail\Service\Classification\ClassificationSettingsService;
use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor;
use OCA\Mail\Service\Classification\ImportanceClassifier;
use OCA\Mail\Support\ConsoleLoggerDecorator;
use OCP\AppFramework\Db\DoesNotExistException;
use Psr\Container\ContainerInterface;
use Psr\Log\LoggerInterface;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
Expand All @@ -27,39 +24,28 @@

class TrainAccount extends Command {
public const ARGUMENT_ACCOUNT_ID = 'account-id';
public const ARGUMENT_OLD = 'old';
public const ARGUMENT_OLD_ESTIMATOR = 'old-estimator';
public const ARGUMENT_OLD_EXTRACTOR = 'old-extractor';
public const ARGUMENT_SHUFFLE = 'shuffle';
public const ARGUMENT_SAVE_DATA = 'save-data';
public const ARGUMENT_LOAD_DATA = 'load-data';
public const ARGUMENT_DRY_RUN = 'dry-run';
public const ARGUMENT_FORCE = 'force';

private AccountService $accountService;
private ImportanceClassifier $classifier;
private LoggerInterface $logger;
private ContainerInterface $container;
private ClassificationSettingsService $classificationSettingsService;

public function __construct(AccountService $service,
ImportanceClassifier $classifier,
ClassificationSettingsService $classificationSettingsService,
LoggerInterface $logger,
ContainerInterface $container) {
LoggerInterface $logger) {
parent::__construct();

$this->accountService = $service;
$this->classifier = $classifier;
$this->logger = $logger;
$this->container = $container;
$this->classificationSettingsService = $classificationSettingsService;
}

/**
* @return void
*/
protected function configure() {
protected function configure(): void {
$this->setName('mail:account:train');
$this->setDescription('Train the classifier of new messages');
$this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED);
Expand All @@ -76,18 +62,6 @@ protected function configure() {
null,
'Train an estimator even if the classification is disabled by the user'
);
$this->addOption(
self::ARGUMENT_SAVE_DATA,
null,
InputOption::VALUE_REQUIRED,
'Save training data set to a JSON file'
);
$this->addOption(
self::ARGUMENT_LOAD_DATA,
null,
InputOption::VALUE_REQUIRED,
'Load training data set from a JSON file'
);
}

protected function execute(InputInterface $input, OutputInterface $output): int {
Expand All @@ -108,50 +82,18 @@ protected function execute(InputInterface $input, OutputInterface $output): int
return 2;
}

/** @var IExtractor $extractor */
$extractor = $this->container->get(CompositeExtractor::class);

$consoleLogger = new ConsoleLoggerDecorator(
$this->logger,
$output
);

$dataSet = null;
if ($saveDataPath = $input->getOption(self::ARGUMENT_SAVE_DATA)) {
$dataSet = $this->classifier->buildDataSet(
$account,
$extractor,
$consoleLogger,
null,
$shuffle,
);
$json = json_encode($dataSet, JSON_THROW_ON_ERROR);
file_put_contents($saveDataPath, $json);
} elseif ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) {
$json = file_get_contents($loadDataPath);
$dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR);
}

if ($dataSet) {
$this->classifier->trainWithCustomDataSet(
$account,
$consoleLogger,
$dataSet,
$extractor,
null,
null,
!$dryRun
);
} else {
$this->classifier->train(
$account,
$consoleLogger,
$extractor,
null,
$shuffle,
!$dryRun
);
}
$this->classifier->train(
$account,
$consoleLogger,
null,
$shuffle,
!$dryRun
);

$mbs = (int)(memory_get_peak_usage() / 1024 / 1024);
$output->writeln('<info>' . $mbs . 'MB of memory used</info>');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Transformers\MultibyteTextNormalizer;
use Rubix\ML\Transformers\TfIdfTransformer;
use Rubix\ML\Transformers\Transformer;
use Rubix\ML\Transformers\WordCountVectorizer;
use RuntimeException;
use function array_column;
use function array_map;

class SubjectExtractor implements IExtractor {
private WordCountVectorizer $wordCountVectorizer;
private Transformer $tfidf;
private TfIdfTransformer $tfidf;
private int $max = -1;

public function __construct() {
Expand Down
26 changes: 14 additions & 12 deletions lib/Service/Classification/ImportanceClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,17 @@ private static function createDefaultEstimator(): KNearestNeighbors {
return new KNearestNeighbors(15, true, new Manhattan());
}

/**
* @throws ServiceException If the extractor is not available
*/
private function createExtractor(): CompositeExtractor {
try {
return $this->container->get(CompositeExtractor::class);
} catch (ContainerExceptionInterface $e) {
throw new ServiceException('Default extractor is not available', 0, $e);
}
}

private function filterMessageHasSenderEmail(Message $message): bool {
return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null;
}
Expand Down Expand Up @@ -201,7 +212,6 @@ public function buildDataSet(
*
* @param Account $account
* @param LoggerInterface $logger
* @param ?IExtractor $extractor The extractor to use for feature extraction. If null, the default extractor will be used.
* @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used.
* @param bool $shuffleDataSet Shuffle the data set before training
* @param bool $persist Persist the trained classifier to use it for message classification
Expand All @@ -213,21 +223,13 @@ public function buildDataSet(
public function train(
Account $account,
LoggerInterface $logger,
?IExtractor $extractor = null,
?Closure $estimator = null,
bool $shuffleDataSet = false,
bool $persist = true,
): ?Estimator {
$perf = $this->performanceLogger->start('importance classifier training');

if ($extractor === null) {
try {
$extractor = $this->container->get(CompositeExtractor::class);
} catch (ContainerExceptionInterface $e) {
throw new ServiceException('Default extractor is not available', 0, $e);
}
}

$extractor = $this->createExtractor();
$dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet);
if ($dataSet === null) {
return null;
Expand All @@ -250,7 +252,7 @@ public function train(
* @param Account $account
* @param LoggerInterface $logger
* @param array $dataSet Training data set built by buildDataSet()
* @param IExtractor $extractor Extractor used to extract the given data set
* @param CompositeExtractor $extractor Extractor used to extract the given data set
* @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used.
* @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task
* @param bool $persist Persist the trained classifier to use it for message classification
Expand All @@ -259,7 +261,7 @@ public function train(
*
* @throws ServiceException
*/
public function trainWithCustomDataSet(
private function trainWithCustomDataSet(
Account $account,
LoggerInterface $logger,
array $dataSet,
Expand Down
2 changes: 1 addition & 1 deletion lib/Service/Classification/PersistenceService.php
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public function loadLatest(Account $account): ?ClassifierPipeline {

$extractor = $this->loadExtractor($transformers);

return new ClassifierPipeline($estimator, $extractor, $transformers);
return new ClassifierPipeline($estimator, $extractor);
}

/**
Expand Down
11 changes: 3 additions & 8 deletions lib/Service/Classification/RubixMemoryPersister.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@

use Rubix\ML\Encoding;
use Rubix\ML\Persisters\Persister;
use ValueError;

class RubixMemoryPersister implements Persister {
public function __construct(
private ?string $data = null,
private string $data = '',
) {
}

public function getData(): ?string {
public function getData(): string {
return $this->data;
}

Expand All @@ -28,14 +27,10 @@ public function save(Encoding $encoding): void {
}

public function load(): Encoding {
if ($this->data === null) {
throw new ValueError('Trying to load encoding when no data is available');
}

return new Encoding($this->data);
}

public function __toString() {
return (string)self::class;
return self::class;
}
}

0 comments on commit ebbed72

Please sign in to comment.