From bca2af74704159b9faead614f432cfe435b79b82 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Wed, 4 Jan 2023 17:25:36 +0100 Subject: [PATCH 01/37] Classify emails based on subjects --- lib/Db/StatisticsDao.php | 23 ++++ .../FeatureExtraction/SubjectExtractor.php | 105 ++++++++++++++++++ .../Classification/ImportanceClassifier.php | 86 +++++++++++++- 3 files changed, 208 insertions(+), 6 deletions(-) create mode 100644 lib/Service/Classification/FeatureExtraction/SubjectExtractor.php diff --git a/lib/Db/StatisticsDao.php b/lib/Db/StatisticsDao.php index 84c92af738..b1f3a1e2b4 100644 --- a/lib/Db/StatisticsDao.php +++ b/lib/Db/StatisticsDao.php @@ -144,6 +144,29 @@ public function getNumberOfMessagesGrouped(array $mailboxes, array $emails): arr return $data; } + public function getSubjectsGrouped(array $mailboxes, array $emails): array { + $qb = $this->db->getQueryBuilder(); + + $mailboxIds = array_map(function (Mailbox $mb) { + return $mb->getId(); + }, $mailboxes); + $select = $qb->selectAlias('r.email', 'email') + ->selectAlias('m.subject', 'subject') + ->from('mail_recipients', 'r') + ->join('r', 'mail_messages', 'm', $qb->expr()->eq('m.id', 'r.message_id', IQueryBuilder::PARAM_INT)) + ->where($qb->expr()->eq('r.type', $qb->createNamedParameter(Address::TYPE_FROM, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT)) + ->andWhere($qb->expr()->in('m.mailbox_id', $qb->createNamedParameter($mailboxIds, IQueryBuilder::PARAM_INT_ARRAY))) + ->andWhere($qb->expr()->in('r.email', $qb->createNamedParameter($emails, IQueryBuilder::PARAM_STR_ARRAY), IQueryBuilder::PARAM_STR_ARRAY)); + $result = $select->execute(); + $rows = $result->fetchAll(); + $result->closeCursor(); + $data = []; + foreach ($rows as $row) { + $data[$row['email']][] = $row['subject']; + } + return $data; + } + public function getNrOfReadMessages(Mailbox $mb, string $email): int { $qb = $this->db->getQueryBuilder(); diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php 
b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php new file mode 100644 index 0000000000..db6921edf8 --- /dev/null +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -0,0 +1,105 @@ + + * + * @author Richard Steinmetz + * + * @license AGPL-3.0-or-later + * + * This code is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License, version 3, + * along with this program. If not, see + * + */ + +namespace OCA\Mail\Service\Classification\FeatureExtraction; + +use OCA\Mail\Account; +use OCA\Mail\Db\Message; +use OCA\Mail\Db\StatisticsDao; +use Rubix\ML\Datasets\Dataset; +use Rubix\ML\Datasets\Labeled; +use Rubix\ML\Datasets\Unlabeled; +use Rubix\ML\Transformers\TextNormalizer; +use Rubix\ML\Transformers\TfIdfTransformer; +use Rubix\ML\Transformers\WordCountVectorizer; +use function OCA\Mail\array_flat_map; + +class SubjectExtractor { + /** @var StatisticsDao */ + private $statisticsDao; + + /** @var string[][] */ + private $subjects; + + public function __construct(StatisticsDao $statisticsDao) { + $this->statisticsDao = $statisticsDao; + } + + /** + * @inheritDoc + */ + public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { + /** @var string[] $senders */ + $senders = array_unique(array_map(function (Message $message) { + return $message->getFrom()->first()->getEmail(); + }, array_filter($messages, function (Message $message) { + return $message->getFrom()->first() !== null && 
$message->getFrom()->first()->getEmail() !== null; + }))); + + $this->subjects = $this->statisticsDao->getSubjectsGrouped($incomingMailboxes, $senders); + } + + /** + * @inheritDoc + */ + public function extract(string $email): array { + $concatSubjects = []; + foreach ($this->subjects as $sender => $subjects) { + if ($sender !== $email) { + continue; + } + + $concatSubjects[] = $subjects; + } + + $subject = implode(' ', array_merge(...$concatSubjects)); + $subjects = array_unique(array_merge(...$concatSubjects)); + + //$data = new Labeled([$subject], [$email]); + //$data = new Unlabeled($subjects); + $data = Unlabeled::build($subjects) + ->apply(new TextNormalizer()) + ->apply(new WordCountVectorizer(20)); + //->apply(new TfIdfTransformer()); + return $data->samples(); + } + + public function getSubjects(): array { + return array_merge(...array_values($this->subjects)); + } + + public function getSubjectsOfSender(string $email): array { + $concatSubjects = []; + foreach ($this->subjects as $sender => $subjects) { + if ($sender !== $email) { + continue; + } + + $concatSubjects[] = $subjects; + } + + return array_merge(...$concatSubjects); + } +} diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 434d0c5bcc..09f5fefd61 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -19,6 +19,7 @@ use OCA\Mail\Exception\ClassifierTrainingException; use OCA\Mail\Exception\ServiceException; use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor; use OCA\Mail\Support\PerformanceLogger; use OCP\AppFramework\Db\DoesNotExistException; use Psr\Log\LoggerInterface; @@ -27,6 +28,9 @@ use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; +use Rubix\ML\Transformers\TextNormalizer; +use 
Rubix\ML\Transformers\TfIdfTransformer; +use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function array_column; use function array_combine; @@ -81,7 +85,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 1000; + private const MAX_TRAINING_SET_SIZE = 10000; /** @var MailboxMapper */ private $mailboxMapper; @@ -92,6 +96,9 @@ class ImportanceClassifier { /** @var CompositeExtractor */ private $extractor; + /** @var SubjectExtractor */ + private $subjectExtractor; + /** @var PersistenceService */ private $persistenceService; @@ -109,7 +116,8 @@ public function __construct(MailboxMapper $mailboxMapper, PersistenceService $persistenceService, PerformanceLogger $performanceLogger, ImportanceRulesClassifier $rulesClassifier, - LoggerInterface $logger) { + LoggerInterface $logger, + SubjectExtractor $subjectExtractor) { $this->mailboxMapper = $mailboxMapper; $this->messageMapper = $messageMapper; $this->extractor = $extractor; @@ -117,6 +125,7 @@ public function __construct(MailboxMapper $mailboxMapper, $this->performanceLogger = $performanceLogger; $this->rulesClassifier = $rulesClassifier; $this->logger = $logger; + $this->subjectExtractor = $subjectExtractor; } private function filterMessageHasSenderEmail(Message $message): bool { @@ -168,16 +177,33 @@ public function train(Account $account, LoggerInterface $logger): void { $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages); $perf->step('extract features from messages'); + shuffle($dataSet); + /** * How many of the most recent messages are excluded from training? */ $validationThreshold = max( 5, - (int)(count($dataSet) * 0.1) + (int)(count($dataSet) * 0.25) ); $validationSet = array_slice($dataSet, 0, $validationThreshold); $trainingSet = array_slice($dataSet, $validationThreshold); - $logger->debug('data set split into ' . count($trainingSet) . 
' training and ' . count($validationSet) . ' validation sets with ' . count($trainingSet[0]['features'] ?? []) . ' dimensions'); + + $validationSetImportantCount = 0; + $trainingSetImportantCount = 0; + foreach ($validationSet as $data) { + if ($data['label'] === self::LABEL_IMPORTANT) { + $validationSetImportantCount++; + } + } + foreach ($trainingSet as $data) { + if ($data['label'] === self::LABEL_IMPORTANT) { + $trainingSetImportantCount++; + } + } + + $logger->debug('data set split into ' . count($trainingSet) . ' (' . $trainingSetImportantCount . ') training and ' . count($validationSet) . ' (' . $validationSetImportantCount . ') validation sets with ' . count($trainingSet[0]['features'] ?? []) . ' dimensions'); + if ($validationSet === [] || $trainingSet === []) { $logger->info('not enough messages to train a classifier'); $perf->end(); @@ -261,15 +287,63 @@ private function getFeaturesAndImportance(Account $account, array $outgoingMailboxes, array $messages): array { $this->extractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages); + $this->subjectExtractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages); + + $allSubjects = $this->subjectExtractor->getSubjects(); - return array_map(function (Message $message) { + $max = 1000; + $wcv = new WordCountVectorizer($max); + $tfidf = new TfIdfTransformer(); + + $accData = Unlabeled::build($allSubjects) + ->apply(new TextNormalizer()) + ->apply($wcv) + ->apply($tfidf); + + $vocab = $wcv->vocabularies()[0]; + //$vocab = array_slice($vocab, 0, $max); + //var_dump($vocab); + + return array_map(function (Message $message) use ($max, $wcv, $tfidf) { $sender = $message->getFrom()->first(); if ($sender === null) { throw new RuntimeException('This should not happen'); } + $subjects = $this->subjectExtractor->getSubjectsOfSender($sender->getEmail()); + + $data = []; + if ($message->getSubject() !== null) { + //$fdata = Unlabeled::build([$message->getSubject()]) + $fdata = 
Unlabeled::build($subjects) + ->apply($wcv); + //->apply($tfidf); + if ($fdata->numColumns() === 0) { + $data = array_fill(0, $max, 0); + } else { + $data = $fdata->sample(0); + //$data = array_slice($data, 0, $max); + } + } + if (count($data) > $max) { + //$data = array_slice($data, 0, $max); + } else { + while (count($data) < $max) { + //$data[] = 0; + } + } + + /* + $features = array_merge( + $this->extractor->extract($sender->getEmail()), + $data, + //$this->subjectExtractor->extract($sender->getEmail()), + ); + */ + $features = $data; + return [ - 'features' => $this->extractor->extract($message), + 'features' => $features, 'label' => $message->getFlagImportant() ? self::LABEL_IMPORTANT : self::LABEL_NOT_IMPORTANT, 'sender' => $sender->getEmail(), ]; From dc47dc1dff1acea9485a10fc71d44f620361bd5d Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Mon, 9 Jan 2023 11:15:05 +0100 Subject: [PATCH 02/37] fixup! Classify emails based on subjects --- .../Classification/ImportanceClassifier.php | 38 ++++++++++++++----- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 09f5fefd61..11b979c51b 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -23,14 +23,19 @@ use OCA\Mail\Support\PerformanceLogger; use OCP\AppFramework\Db\DoesNotExistException; use Psr\Log\LoggerInterface; +use Rubix\ML\Classifiers\ClassificationTree; use Rubix\ML\Classifiers\GaussianNB; +use Rubix\ML\Classifiers\NaiveBayes; +use Rubix\ML\Classifiers\RandomForest; use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; +use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\TextNormalizer; use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\WordCountVectorizer; +use 
Rubix\ML\Transformers\ZScaleStandardizer; use RuntimeException; use function array_column; use function array_combine; @@ -85,7 +90,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 10000; + private const MAX_TRAINING_SET_SIZE = 1000; /** @var MailboxMapper */ private $mailboxMapper; @@ -177,7 +182,7 @@ public function train(Account $account, LoggerInterface $logger): void { $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages); $perf->step('extract features from messages'); - shuffle($dataSet); + //shuffle($dataSet); /** * How many of the most recent messages are excluded from training? @@ -292,37 +297,42 @@ private function getFeaturesAndImportance(Account $account, $allSubjects = $this->subjectExtractor->getSubjects(); $max = 1000; - $wcv = new WordCountVectorizer($max); + $wcv = new WordCountVectorizer($max, 1); $tfidf = new TfIdfTransformer(); + $zscale = new ZScaleStandardizer(); $accData = Unlabeled::build($allSubjects) - ->apply(new TextNormalizer()) + ->apply(new MultibyteTextNormalizer()) ->apply($wcv) - ->apply($tfidf); + ->apply($tfidf) + ->apply($zscale); $vocab = $wcv->vocabularies()[0]; //$vocab = array_slice($vocab, 0, $max); //var_dump($vocab); + $max = count($vocab); - return array_map(function (Message $message) use ($max, $wcv, $tfidf) { + return array_map(function (Message $message) use ($max, $wcv, $tfidf, $zscale) { $sender = $message->getFrom()->first(); if ($sender === null) { throw new RuntimeException('This should not happen'); } $subjects = $this->subjectExtractor->getSubjectsOfSender($sender->getEmail()); + //$subjects = [$message->getSubject()]; $data = []; if ($message->getSubject() !== null) { //$fdata = Unlabeled::build([$message->getSubject()]) $fdata = Unlabeled::build($subjects) - ->apply($wcv); - //->apply($tfidf); + ->apply(new MultibyteTextNormalizer()) + ->apply($wcv) + 
->apply($tfidf) + ->apply($zscale); if ($fdata->numColumns() === 0) { $data = array_fill(0, $max, 0); } else { $data = $fdata->sample(0); - //$data = array_slice($data, 0, $max); } } if (count($data) > $max) { @@ -404,8 +414,16 @@ public function classifyImportance(Account $account, array $messages): array { ); } - private function trainClassifier(array $trainingSet): GaussianNB { + private function trainClassifier(array $trainingSet): Estimator { $classifier = new GaussianNB(); + /* + $classifier = new RandomForest( + new ClassificationTree(10, 1), + 10, + 0.2, + true, + ); + */ $classifier->train(Labeled::build( array_column($trainingSet, 'features'), array_column($trainingSet, 'label') From 056f67f5bffa437adb604811e474f1e77d88300b Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Wed, 18 Jan 2023 15:03:16 +0100 Subject: [PATCH 03/37] fixup! Classify emails based on subjects --- lib/Command/TrainAccount.php | 25 ++- lib/Db/StatisticsDao.php | 26 ++- .../FeatureExtraction/CompositeExtractor.php | 24 ++- .../NewCompositeExtractor.php | 33 ++++ .../SubjectAndPreviewTextExtractor.php | 150 ++++++++++++++++++ .../FeatureExtraction/SubjectExtractor.php | 105 ------------ .../VanillaCompositeExtractor.php | 40 +++++ .../Classification/ImportanceClassifier.php | 93 ++--------- 8 files changed, 295 insertions(+), 201 deletions(-) create mode 100644 lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php create mode 100644 lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php delete mode 100644 lib/Service/Classification/FeatureExtraction/SubjectExtractor.php create mode 100644 lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index a2d2b09cc4..80b1870081 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -11,9 +11,12 @@ use OCA\Mail\Service\AccountService; use 
OCA\Mail\Service\Classification\ClassificationSettingsService; +use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; +use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; @@ -23,21 +26,26 @@ class TrainAccount extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; + public const ARGUMENT_NEW = 'new'; + public const ARGUMENT_SHUFFLE = 'shuffle'; private AccountService $accountService; private ImportanceClassifier $classifier; private LoggerInterface $logger; + private ContainerInterface $container; private ClassificationSettingsService $classificationSettingsService; public function __construct(AccountService $service, ImportanceClassifier $classifier, ClassificationSettingsService $classificationSettingsService, - LoggerInterface $logger) { + LoggerInterface $logger, + ContainerInterface $container) { parent::__construct(); $this->accountService = $service; $this->classifier = $classifier; $this->logger = $logger; + $this->container = $container; $this->classificationSettingsService = $classificationSettingsService; } @@ -48,6 +56,8 @@ protected function configure() { $this->setName('mail:account:train'); $this->setDescription('Train the classifier of new messages'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); + $this->addOption(self::ARGUMENT_NEW, null, null, 'Enable new composite extractor using text based features'); + $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); } /** @@ -62,10 +72,19 @@ protected function execute(InputInterface $input, OutputInterface $output): int $output->writeln("account $accountId does not 
exist"); return 1; } + + /* if (!$this->classificationSettingsService->isClassificationEnabled($account->getUserId())) { $output->writeln("classification is turned off for account $accountId"); return 2; } + */ + + if ($input->getOption(self::ARGUMENT_NEW)) { + $extractor = $this->container->get(NewCompositeExtractor::class); + } else { + $extractor = $this->container->get(VanillaCompositeExtractor::class); + } $consoleLogger = new ConsoleLoggerDecorator( $this->logger, @@ -73,7 +92,9 @@ protected function execute(InputInterface $input, OutputInterface $output): int ); $this->classifier->train( $account, - $consoleLogger + $consoleLogger, + $extractor, + (bool)$input->getOption(self::ARGUMENT_SHUFFLE), ); $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); diff --git a/lib/Db/StatisticsDao.php b/lib/Db/StatisticsDao.php index b1f3a1e2b4..9d3ded4ea6 100644 --- a/lib/Db/StatisticsDao.php +++ b/lib/Db/StatisticsDao.php @@ -144,7 +144,7 @@ public function getNumberOfMessagesGrouped(array $mailboxes, array $emails): arr return $data; } - public function getSubjectsGrouped(array $mailboxes, array $emails): array { + public function getSubjects(array $mailboxes, array $emails): array { $qb = $this->db->getQueryBuilder(); $mailboxIds = array_map(function (Mailbox $mb) { @@ -167,6 +167,30 @@ public function getSubjectsGrouped(array $mailboxes, array $emails): array { return $data; } + public function getPreviewTexts(array $mailboxes, array $emails): array { + $qb = $this->db->getQueryBuilder(); + + $mailboxIds = array_map(function (Mailbox $mb) { + return $mb->getId(); + }, $mailboxes); + $select = $qb->selectAlias('r.email', 'email') + ->selectAlias('m.preview_text', 'preview_text') + ->from('mail_recipients', 'r') + ->join('r', 'mail_messages', 'm', $qb->expr()->eq('m.id', 'r.message_id', IQueryBuilder::PARAM_INT)) + ->where($qb->expr()->eq('r.type', $qb->createNamedParameter(Address::TYPE_FROM, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT)) + 
->andWhere($qb->expr()->in('m.mailbox_id', $qb->createNamedParameter($mailboxIds, IQueryBuilder::PARAM_INT_ARRAY))) + ->andWhere($qb->expr()->in('r.email', $qb->createNamedParameter($emails, IQueryBuilder::PARAM_STR_ARRAY), IQueryBuilder::PARAM_STR_ARRAY)) + ->andWhere($qb->expr()->isNotNull('m.preview_text')); + $result = $select->execute(); + $rows = $result->fetchAll(); + $result->closeCursor(); + $data = []; + foreach ($rows as $row) { + $data[$row['email']][] = $row['preview_text']; + } + return $data; + } + public function getNrOfReadMessages(Mailbox $mb, string $email): int { $qb = $this->db->getQueryBuilder(); diff --git a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php index eefadd92ba..197d0f2eb4 100644 --- a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php @@ -16,20 +16,15 @@ /** * Combines a set of DI'ed extractors so they can be used as one class */ -class CompositeExtractor implements IExtractor { +abstract class CompositeExtractor implements IExtractor { /** @var IExtractor[] */ - private $extractors; - - public function __construct(ImportantMessagesExtractor $ex1, - ReadMessagesExtractor $ex2, - RepliedMessagesExtractor $ex3, - SentMessagesExtractor $ex4) { - $this->extractors = [ - $ex1, - $ex2, - $ex3, - $ex4, - ]; + protected array $extractors; + + /** + * @param IExtractor[] $extractors + */ + public function __construct(array $extractors) { + $this->extractors = $extractors; } public function prepare(Account $account, @@ -41,6 +36,9 @@ public function prepare(Account $account, } } + /** + * @inheritDoc + */ public function extract(Message $message): array { return array_flat_map(static function (IExtractor $extractor) use ($message) { return $extractor->extract($message); diff --git a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php 
b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php new file mode 100644 index 0000000000..df58a93fe2 --- /dev/null +++ b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php @@ -0,0 +1,33 @@ + + * + * @author Richard Steinmetz + * + * @license AGPL-3.0-or-later + * + * This code is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License, version 3, + * along with this program. If not, see + * + */ + +namespace OCA\Mail\Service\Classification\FeatureExtraction; + +class NewCompositeExtractor extends CompositeExtractor { + public function __construct(VanillaCompositeExtractor $ex1, + SubjectAndPreviewTextExtractor $ex2) { + parent::__construct([$ex1, $ex2]); + } +} diff --git a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php new file mode 100644 index 0000000000..6eb3710c0d --- /dev/null +++ b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php @@ -0,0 +1,150 @@ + + * + * @author Richard Steinmetz + * + * @license AGPL-3.0-or-later + * + * This code is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License, version 3, + * along with this program. If not, see + * + */ + +namespace OCA\Mail\Service\Classification\FeatureExtraction; + +use OCA\Mail\Account; +use OCA\Mail\Db\Message; +use OCA\Mail\Db\StatisticsDao; +use Rubix\ML\Datasets\Unlabeled; +use Rubix\ML\Transformers\MultibyteTextNormalizer; +use Rubix\ML\Transformers\TfIdfTransformer; +use Rubix\ML\Transformers\WordCountVectorizer; +use RuntimeException; + +class SubjectAndPreviewTextExtractor implements IExtractor { + private StatisticsDao $statisticsDao; + private WordCountVectorizer $wordCountVectorizer; + private TfIdfTransformer $tfIdfTransformer; + private int $max = -1; + + /** @var string[][] */ + private array $subjects; + + /** @var string[][] */ + private array $previewTexts; + + public function __construct(StatisticsDao $statisticsDao) { + $this->statisticsDao = $statisticsDao; + // Limit vocabulary to limit ram usage. It takes about 5 GB of ram if an unbounded + // vocabulary is used (and a lot more time to compute). 
+ $this->wordCountVectorizer = new WordCountVectorizer(1000); + $this->tfIdfTransformer = new TfIdfTransformer(); + } + + /** + * @inheritDoc + */ + public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { + /** @var string[] $senders */ + $senders = array_unique(array_map(function (Message $message) { + return $message->getFrom()->first()->getEmail(); + }, array_filter($messages, function (Message $message) { + return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null; + }))); + + $this->subjects = $this->statisticsDao->getSubjects($incomingMailboxes, $senders); + $this->previewTexts = $this->statisticsDao->getPreviewTexts($incomingMailboxes, $senders); + + // Fit transformers + $fitText = implode(' ', [...$this->getSubjects(), ...$this->getPreviewTexts()]); + Unlabeled::build([$fitText]) + ->apply(new MultibyteTextNormalizer()) + ->apply($this->wordCountVectorizer) + ->apply($this->tfIdfTransformer); + + // Limit feature vector length to actual vocabulary size + $vocab = $this->wordCountVectorizer->vocabularies()[0]; + $this->max = count($vocab); + } + + /** + * @inheritDoc + */ + public function extract(Message $message): array { + $sender = $message->getFrom()->first(); + if ($sender === null) { + throw new RuntimeException("This should not happen"); + } + $email = $sender->getEmail(); + + // Build training data set + $subjects = $this->getSubjectsOfSender($email); + $previewTexts = $this->getPreviewTextsOfSender($email); + $trainText = implode(' ', [...$subjects, ...$previewTexts]); + + $textFeatures = []; + if ($message->getSubject() !== null) { + $trainDataSet = Unlabeled::build([$trainText]) + ->apply(new MultibyteTextNormalizer()) + ->apply($this->wordCountVectorizer) + ->apply($this->tfIdfTransformer); + + // Use zeroed vector if no features could be extracted + if ($trainDataSet->numColumns() === 0) { + $textFeatures = array_fill(0, $this->max, 0); + } 
else { + $textFeatures = $trainDataSet->sample(0); + } + } + assert(count($textFeatures) === $this->max); + + return $textFeatures; + } + + private function getSubjects(): array { + return array_merge(...array_values($this->subjects)); + } + + private function getPreviewTexts(): array { + return array_merge(...array_values($this->previewTexts)); + } + + private function getSubjectsOfSender(string $email): array { + $concatSubjects = []; + foreach ($this->subjects as $sender => $subjects) { + if ($sender !== $email) { + continue; + } + + $concatSubjects[] = $subjects; + } + + return array_merge(...$concatSubjects); + } + + private function getPreviewTextsOfSender(string $email): array { + $concatPreviewTexts = []; + foreach ($this->previewTexts as $sender => $previewTexts) { + if ($sender !== $email) { + continue; + } + + $concatPreviewTexts[] = $previewTexts; + } + + return array_merge(...$concatPreviewTexts); + } +} diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php deleted file mode 100644 index db6921edf8..0000000000 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ /dev/null @@ -1,105 +0,0 @@ - - * - * @author Richard Steinmetz - * - * @license AGPL-3.0-or-later - * - * This code is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License, version 3, - * as published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License, version 3, - * along with this program. 
If not, see - * - */ - -namespace OCA\Mail\Service\Classification\FeatureExtraction; - -use OCA\Mail\Account; -use OCA\Mail\Db\Message; -use OCA\Mail\Db\StatisticsDao; -use Rubix\ML\Datasets\Dataset; -use Rubix\ML\Datasets\Labeled; -use Rubix\ML\Datasets\Unlabeled; -use Rubix\ML\Transformers\TextNormalizer; -use Rubix\ML\Transformers\TfIdfTransformer; -use Rubix\ML\Transformers\WordCountVectorizer; -use function OCA\Mail\array_flat_map; - -class SubjectExtractor { - /** @var StatisticsDao */ - private $statisticsDao; - - /** @var string[][] */ - private $subjects; - - public function __construct(StatisticsDao $statisticsDao) { - $this->statisticsDao = $statisticsDao; - } - - /** - * @inheritDoc - */ - public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { - /** @var string[] $senders */ - $senders = array_unique(array_map(function (Message $message) { - return $message->getFrom()->first()->getEmail(); - }, array_filter($messages, function (Message $message) { - return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null; - }))); - - $this->subjects = $this->statisticsDao->getSubjectsGrouped($incomingMailboxes, $senders); - } - - /** - * @inheritDoc - */ - public function extract(string $email): array { - $concatSubjects = []; - foreach ($this->subjects as $sender => $subjects) { - if ($sender !== $email) { - continue; - } - - $concatSubjects[] = $subjects; - } - - $subject = implode(' ', array_merge(...$concatSubjects)); - $subjects = array_unique(array_merge(...$concatSubjects)); - - //$data = new Labeled([$subject], [$email]); - //$data = new Unlabeled($subjects); - $data = Unlabeled::build($subjects) - ->apply(new TextNormalizer()) - ->apply(new WordCountVectorizer(20)); - //->apply(new TfIdfTransformer()); - return $data->samples(); - } - - public function getSubjects(): array { - return array_merge(...array_values($this->subjects)); - } - - public function 
getSubjectsOfSender(string $email): array { - $concatSubjects = []; - foreach ($this->subjects as $sender => $subjects) { - if ($sender !== $email) { - continue; - } - - $concatSubjects[] = $subjects; - } - - return array_merge(...$concatSubjects); - } -} diff --git a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php new file mode 100644 index 0000000000..0d907ea594 --- /dev/null +++ b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php @@ -0,0 +1,40 @@ + + * + * @author Richard Steinmetz + * + * @license AGPL-3.0-or-later + * + * This code is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License, version 3, + * along with this program. 
If not, see + * + */ + +namespace OCA\Mail\Service\Classification\FeatureExtraction; + +class VanillaCompositeExtractor extends CompositeExtractor { + public function __construct(ImportantMessagesExtractor $ex1, + ReadMessagesExtractor $ex2, + RepliedMessagesExtractor $ex3, + SentMessagesExtractor $ex4) { + parent::__construct([ + $ex1, + $ex2, + $ex3, + $ex4, + ]); + } +} diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 11b979c51b..2f3513c5ed 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -18,24 +18,17 @@ use OCA\Mail\Db\MessageMapper; use OCA\Mail\Exception\ClassifierTrainingException; use OCA\Mail\Exception\ServiceException; -use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use OCA\Mail\Support\PerformanceLogger; use OCP\AppFramework\Db\DoesNotExistException; use Psr\Log\LoggerInterface; use Rubix\ML\Classifiers\ClassificationTree; use Rubix\ML\Classifiers\GaussianNB; -use Rubix\ML\Classifiers\NaiveBayes; use Rubix\ML\Classifiers\RandomForest; use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; -use Rubix\ML\Transformers\MultibyteTextNormalizer; -use Rubix\ML\Transformers\TextNormalizer; -use Rubix\ML\Transformers\TfIdfTransformer; -use Rubix\ML\Transformers\WordCountVectorizer; -use Rubix\ML\Transformers\ZScaleStandardizer; use RuntimeException; use function array_column; use function array_combine; @@ -90,7 +83,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 1000; + private const MAX_TRAINING_SET_SIZE = 4000; /** @var MailboxMapper */ private $mailboxMapper; 
@@ -98,12 +91,6 @@ class ImportanceClassifier { /** @var MessageMapper */ private $messageMapper; - /** @var CompositeExtractor */ - private $extractor; - - /** @var SubjectExtractor */ - private $subjectExtractor; - /** @var PersistenceService */ private $persistenceService; @@ -125,12 +112,9 @@ public function __construct(MailboxMapper $mailboxMapper, SubjectExtractor $subjectExtractor) { $this->mailboxMapper = $mailboxMapper; $this->messageMapper = $messageMapper; - $this->extractor = $extractor; $this->persistenceService = $persistenceService; $this->performanceLogger = $performanceLogger; $this->rulesClassifier = $rulesClassifier; - $this->logger = $logger; - $this->subjectExtractor = $subjectExtractor; } private function filterMessageHasSenderEmail(Message $message): bool { @@ -152,7 +136,7 @@ private function filterMessageHasSenderEmail(Message $message): bool { * * @param Account $account */ - public function train(Account $account, LoggerInterface $logger): void { + public function train(Account $account, LoggerInterface $logger, IExtractor $extractor, bool $shuffleDataSet = false): void { $perf = $this->performanceLogger->start('importance classifier training'); $incomingMailboxes = $this->getIncomingMailboxes($account); $logger->debug('found ' . count($incomingMailboxes) . ' incoming mailbox(es)'); @@ -179,11 +163,12 @@ public function train(Account $account, LoggerInterface $logger): void { } $perf->step('find latest ' . self::MAX_TRAINING_SET_SIZE . ' messages'); - $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages); + $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages, $extractor); + if ($shuffleDataSet) { + shuffle($dataSet); + } $perf->step('extract features from messages'); - //shuffle($dataSet); - /** * How many of the most recent messages are excluded from training? 
*/ @@ -207,7 +192,7 @@ public function train(Account $account, LoggerInterface $logger): void { } } - $logger->debug('data set split into ' . count($trainingSet) . ' (' . $trainingSetImportantCount . ') training and ' . count($validationSet) . ' (' . $validationSetImportantCount . ') validation sets with ' . count($trainingSet[0]['features'] ?? []) . ' dimensions'); + $logger->debug('data set split into ' . count($trainingSet) . ' (' . self::LABEL_IMPORTANT . ': ' . $trainingSetImportantCount . ') training and ' . count($validationSet) . ' (' . self::LABEL_IMPORTANT . ': ' . $validationSetImportantCount . ') validation sets with ' . count($trainingSet[0]['features'] ?? []) . ' dimensions'); if ($validationSet === [] || $trainingSet === []) { $logger->info('not enough messages to train a classifier'); @@ -290,70 +275,18 @@ private function getOutgoingMailboxes(Account $account): array { private function getFeaturesAndImportance(Account $account, array $incomingMailboxes, array $outgoingMailboxes, - array $messages): array { - $this->extractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages); - $this->subjectExtractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages); - - $allSubjects = $this->subjectExtractor->getSubjects(); - - $max = 1000; - $wcv = new WordCountVectorizer($max, 1); - $tfidf = new TfIdfTransformer(); - $zscale = new ZScaleStandardizer(); - - $accData = Unlabeled::build($allSubjects) - ->apply(new MultibyteTextNormalizer()) - ->apply($wcv) - ->apply($tfidf) - ->apply($zscale); - - $vocab = $wcv->vocabularies()[0]; - //$vocab = array_slice($vocab, 0, $max); - //var_dump($vocab); - $max = count($vocab); + array $messages, + IExtractor $extractor): array { + $extractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages); - return array_map(function (Message $message) use ($max, $wcv, $tfidf, $zscale) { + return array_map(function (Message $message) use ($extractor) { $sender = 
$message->getFrom()->first(); if ($sender === null) { throw new RuntimeException('This should not happen'); } - $subjects = $this->subjectExtractor->getSubjectsOfSender($sender->getEmail()); - //$subjects = [$message->getSubject()]; - - $data = []; - if ($message->getSubject() !== null) { - //$fdata = Unlabeled::build([$message->getSubject()]) - $fdata = Unlabeled::build($subjects) - ->apply(new MultibyteTextNormalizer()) - ->apply($wcv) - ->apply($tfidf) - ->apply($zscale); - if ($fdata->numColumns() === 0) { - $data = array_fill(0, $max, 0); - } else { - $data = $fdata->sample(0); - } - } - if (count($data) > $max) { - //$data = array_slice($data, 0, $max); - } else { - while (count($data) < $max) { - //$data[] = 0; - } - } - - /* - $features = array_merge( - $this->extractor->extract($sender->getEmail()), - $data, - //$this->subjectExtractor->extract($sender->getEmail()), - ); - */ - $features = $data; - return [ - 'features' => $features, + 'features' => $extractor->extract($message), 'label' => $message->getFlagImportant() ? 
self::LABEL_IMPORTANT : self::LABEL_NOT_IMPORTANT, 'sender' => $sender->getEmail(), ]; From 0be1a218500cdff6c88a1bcaae0357db66806f33 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 24 Jan 2023 09:47:00 +0100 Subject: [PATCH 04/37] Cache features per sender --- .../FeatureExtraction/SubjectAndPreviewTextExtractor.php | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php index 6eb3710c0d..675d6c21a5 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php @@ -39,6 +39,7 @@ class SubjectAndPreviewTextExtractor implements IExtractor { private WordCountVectorizer $wordCountVectorizer; private TfIdfTransformer $tfIdfTransformer; private int $max = -1; + private array $senderCache = []; /** @var string[][] */ private array $subjects; @@ -51,7 +52,7 @@ public function __construct(StatisticsDao $statisticsDao) { // Limit vocabulary to limit ram usage. It takes about 5 GB of ram if an unbounded // vocabulary is used (and a lot more time to compute). 
$this->wordCountVectorizer = new WordCountVectorizer(1000); - $this->tfIdfTransformer = new TfIdfTransformer(); + $this->tfIdfTransformer = new TfIdfTransformer(0.1); } /** @@ -90,6 +91,10 @@ public function extract(Message $message): array { } $email = $sender->getEmail(); + if (isset($this->senderCache[$email])) { + return $this->senderCache[$email]; + } + // Build training data set $subjects = $this->getSubjectsOfSender($email); $previewTexts = $this->getPreviewTextsOfSender($email); @@ -111,6 +116,8 @@ public function extract(Message $message): array { } assert(count($textFeatures) === $this->max); + $this->senderCache[$email] = $textFeatures; + return $textFeatures; } From 99dfa0a3128f40d7c01aee0bbfa4bcc2e55eb902 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 24 Jan 2023 09:49:12 +0100 Subject: [PATCH 05/37] Implement preprocess command --- appinfo/info.xml | 1 + lib/Command/PreprocessAccount.php | 81 +++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 lib/Command/PreprocessAccount.php diff --git a/appinfo/info.xml b/appinfo/info.xml index 7d2bcbbbd9..f2beb804a9 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -90,6 +90,7 @@ Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud OCA\Mail\Command\TrainAccount OCA\Mail\Command\UpdateAccount OCA\Mail\Command\UpdateSystemAutoresponders + OCA\Mail\Command\PreprocessAccount OCA\Mail\Settings\AdminSettings diff --git a/lib/Command/PreprocessAccount.php b/lib/Command/PreprocessAccount.php new file mode 100644 index 0000000000..07caa588e6 --- /dev/null +++ b/lib/Command/PreprocessAccount.php @@ -0,0 +1,81 @@ + + * + * @author Richard Steinmetz + * + * @license AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +namespace OCA\Mail\Command; + +use OCA\Mail\Service\AccountService; +use OCA\Mail\Service\PreprocessingService; +use OCP\AppFramework\Db\DoesNotExistException; +use Psr\Log\LoggerInterface; +use Symfony\Component\Console\Command\Command; +use Symfony\Component\Console\Input\InputArgument; +use Symfony\Component\Console\Input\InputInterface; +use Symfony\Component\Console\Output\OutputInterface; +use function memory_get_peak_usage; + +class PreprocessAccount extends Command { + public const ARGUMENT_ACCOUNT_ID = 'account-id'; + + private AccountService $accountService; + private PreprocessingService $preprocessingService; + private LoggerInterface $logger; + + public function __construct(AccountService $service, + PreprocessingService $preprocessingService, + LoggerInterface $logger) { + parent::__construct(); + + $this->accountService = $service; + $this->preprocessingService = $preprocessingService; + $this->logger = $logger; + } + + /** + * @return void + */ + protected function configure() { + $this->setName('mail:account:preprocess'); + $this->setDescription('Preprocess all mailboxes of an IMAP account'); + $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); + } + + protected function execute(InputInterface $input, OutputInterface $output): int { + $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); + + try { + $account = $this->accountService->findById($accountId); + } catch (DoesNotExistException $e) { + $output->writeln("Account $accountId does not exist"); + return 1; + } + + $this->preprocessingService->process(4294967296, $account); + + 
$mbs = (int)(memory_get_peak_usage() / 1024 / 1024); + $output->writeln('' . $mbs . 'MB of memory used'); + + return 0; + } +} From 28762e6aa2b7ef371e73f92f82e9eb6151dce479 Mon Sep 17 00:00:00 2001 From: Christoph Wurst Date: Thu, 26 Jan 2023 16:59:14 +0100 Subject: [PATCH 06/37] feat(importance-classifier): Reduce text feature vector Signed-off-by: Christoph Wurst --- .../SubjectAndPreviewTextExtractor.php | 70 +++++++++---------- .../Classification/ImportanceClassifier.php | 21 ++++-- 2 files changed, 49 insertions(+), 42 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php index 675d6c21a5..1bf7dcbfb2 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php @@ -28,53 +28,52 @@ use OCA\Mail\Account; use OCA\Mail\Db\Message; use OCA\Mail\Db\StatisticsDao; +use OCA\Mail\Service\Classification\ImportanceClassifier; +use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; +use Rubix\ML\Transformers\LinearDiscriminantAnalysis; use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; +use function array_column; +use function array_map; class SubjectAndPreviewTextExtractor implements IExtractor { private StatisticsDao $statisticsDao; private WordCountVectorizer $wordCountVectorizer; - private TfIdfTransformer $tfIdfTransformer; + private LinearDiscriminantAnalysis $ldaTransformer; private int $max = -1; - private array $senderCache = []; - - /** @var string[][] */ - private array $subjects; - /** @var string[][] */ - private array $previewTexts; + private array $senderCache = []; public function __construct(StatisticsDao $statisticsDao) { $this->statisticsDao = $statisticsDao; // Limit vocabulary to limit 
ram usage. It takes about 5 GB of ram if an unbounded // vocabulary is used (and a lot more time to compute). $this->wordCountVectorizer = new WordCountVectorizer(1000); - $this->tfIdfTransformer = new TfIdfTransformer(0.1); + $this->ldaTransformer = new LinearDiscriminantAnalysis(20); } /** * @inheritDoc */ public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { - /** @var string[] $senders */ - $senders = array_unique(array_map(function (Message $message) { - return $message->getFrom()->first()->getEmail(); - }, array_filter($messages, function (Message $message) { - return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null; - }))); - - $this->subjects = $this->statisticsDao->getSubjects($incomingMailboxes, $senders); - $this->previewTexts = $this->statisticsDao->getPreviewTexts($incomingMailboxes, $senders); + $data = array_map(function(Message $message) { + return [ + 'text' => ($message->getSubject() ?? '') . ' ' . ($message->getPreviewText() ?? ''), + 'label' => $message->getFlagImportant() ? 
'i' : 'ni', + ]; + }, $messages); // Fit transformers - $fitText = implode(' ', [...$this->getSubjects(), ...$this->getPreviewTexts()]); - Unlabeled::build([$fitText]) + Labeled::build( + array_column($data, 'text'), + array_column($data, 'label'), + ) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->tfIdfTransformer); + ->apply($this->ldaTransformer); // Limit feature vector length to actual vocabulary size $vocab = $this->wordCountVectorizer->vocabularies()[0]; @@ -96,25 +95,20 @@ public function extract(Message $message): array { } // Build training data set - $subjects = $this->getSubjectsOfSender($email); - $previewTexts = $this->getPreviewTextsOfSender($email); - $trainText = implode(' ', [...$subjects, ...$previewTexts]); - - $textFeatures = []; - if ($message->getSubject() !== null) { - $trainDataSet = Unlabeled::build([$trainText]) - ->apply(new MultibyteTextNormalizer()) - ->apply($this->wordCountVectorizer) - ->apply($this->tfIdfTransformer); - - // Use zeroed vector if no features could be extracted - if ($trainDataSet->numColumns() === 0) { - $textFeatures = array_fill(0, $this->max, 0); - } else { - $textFeatures = $trainDataSet->sample(0); - } + $trainText = ($message->getSubject() ?? '') . ' ' . ($message->getPreviewText() ?? 
''); + + $trainDataSet = Unlabeled::build([$trainText]) + ->apply(new MultibyteTextNormalizer()) + ->apply($this->wordCountVectorizer) + ->apply($this->ldaTransformer); + + // Use zeroed vector if no features could be extracted + if ($trainDataSet->numColumns() === 0) { + $textFeatures = array_fill(0, $this->max, 0); + } else { + $textFeatures = $trainDataSet->sample(0); } - assert(count($textFeatures) === $this->max); + assert(count($textFeatures) === 20); $this->senderCache[$email] = $textFeatures; diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 2f3513c5ed..bb071bfe9b 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -24,11 +24,16 @@ use Psr\Log\LoggerInterface; use Rubix\ML\Classifiers\ClassificationTree; use Rubix\ML\Classifiers\GaussianNB; +use Rubix\ML\Classifiers\MultilayerPerceptron; use Rubix\ML\Classifiers\RandomForest; use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; +use Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid; +use Rubix\ML\NeuralNet\Layers\Activation; +use Rubix\ML\NeuralNet\Layers\Dense; +use Rubix\ML\NeuralNet\Optimizers\Adam; use RuntimeException; use function array_column; use function array_combine; @@ -83,7 +88,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 4000; + private const MAX_TRAINING_SET_SIZE = 1000; /** @var MailboxMapper */ private $mailboxMapper; @@ -348,15 +353,23 @@ public function classifyImportance(Account $account, array $messages): array { } private function trainClassifier(array $trainingSet): Estimator { - $classifier = new GaussianNB(); - /* + //$classifier = new GaussianNB(); $classifier = new RandomForest( new ClassificationTree(10, 1), 10, 0.2, true, ); - */ + /*$classifier = 
new MultilayerPerceptron( + [ + new Dense(1004), + new Activation(new Sigmoid()) + ], + 32, + null, + 1e-4, + 10, + );*/ $classifier->train(Labeled::build( array_column($trainingSet, 'features'), array_column($trainingSet, 'label') From af70adf343a333784f9179e6fd72e8650553b30c Mon Sep 17 00:00:00 2001 From: Christoph Wurst Date: Thu, 26 Jan 2023 17:48:37 +0100 Subject: [PATCH 07/37] fixup! feat(importance-classifier): Reduce text feature vector Signed-off-by: Christoph Wurst --- .../SubjectAndPreviewTextExtractor.php | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php index 1bf7dcbfb2..f71583b196 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php @@ -33,7 +33,9 @@ use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Transformers\LinearDiscriminantAnalysis; use Rubix\ML\Transformers\MultibyteTextNormalizer; +use Rubix\ML\Transformers\PrincipalComponentAnalysis; use Rubix\ML\Transformers\TfIdfTransformer; +use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function array_column; @@ -42,7 +44,7 @@ class SubjectAndPreviewTextExtractor implements IExtractor { private StatisticsDao $statisticsDao; private WordCountVectorizer $wordCountVectorizer; - private LinearDiscriminantAnalysis $ldaTransformer; + private Transformer $dimensionalReductionTransformer; private int $max = -1; private array $senderCache = []; @@ -51,8 +53,8 @@ public function __construct(StatisticsDao $statisticsDao) { $this->statisticsDao = $statisticsDao; // Limit vocabulary to limit ram usage. It takes about 5 GB of ram if an unbounded // vocabulary is used (and a lot more time to compute). 
- $this->wordCountVectorizer = new WordCountVectorizer(1000); - $this->ldaTransformer = new LinearDiscriminantAnalysis(20); + $this->wordCountVectorizer = new WordCountVectorizer(100); + $this->dimensionalReductionTransformer = new PrincipalComponentAnalysis(15); } /** @@ -73,7 +75,7 @@ public function prepare(Account $account, array $incomingMailboxes, array $outgo ) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->ldaTransformer); + ->apply($this->dimensionalReductionTransformer); // Limit feature vector length to actual vocabulary size $vocab = $this->wordCountVectorizer->vocabularies()[0]; @@ -100,7 +102,7 @@ public function extract(Message $message): array { $trainDataSet = Unlabeled::build([$trainText]) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->ldaTransformer); + ->apply($this->dimensionalReductionTransformer); // Use zeroed vector if no features could be extracted if ($trainDataSet->numColumns() === 0) { @@ -108,7 +110,6 @@ public function extract(Message $message): array { } else { $textFeatures = $trainDataSet->sample(0); } - assert(count($textFeatures) === 20); $this->senderCache[$email] = $textFeatures; From 08b1e1b8ad5b79af0936ad1b6192f4e16d7363a1 Mon Sep 17 00:00:00 2001 From: Christoph Wurst Date: Fri, 27 Jan 2023 10:16:25 +0100 Subject: [PATCH 08/37] fixup! 
feat(importance-classifier): Reduce text feature vector Signed-off-by: Christoph Wurst --- lib/Service/Classification/ImportanceClassifier.php | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index bb071bfe9b..aed5838813 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -24,12 +24,14 @@ use Psr\Log\LoggerInterface; use Rubix\ML\Classifiers\ClassificationTree; use Rubix\ML\Classifiers\GaussianNB; +use Rubix\ML\Classifiers\KNearestNeighbors; use Rubix\ML\Classifiers\MultilayerPerceptron; use Rubix\ML\Classifiers\RandomForest; use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; +use Rubix\ML\Kernels\Distance\Manhattan; use Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid; use Rubix\ML\NeuralNet\Layers\Activation; use Rubix\ML\NeuralNet\Layers\Dense; @@ -88,7 +90,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 1000; + private const MAX_TRAINING_SET_SIZE = 2000; /** @var MailboxMapper */ private $mailboxMapper; @@ -354,12 +356,13 @@ public function classifyImportance(Account $account, array $messages): array { private function trainClassifier(array $trainingSet): Estimator { //$classifier = new GaussianNB(); - $classifier = new RandomForest( + /*$classifier = new RandomForest( new ClassificationTree(10, 1), 10, 0.2, true, - ); + );*/ + $classifier = new KNearestNeighbors(3, false, new Manhattan()); /*$classifier = new MultilayerPerceptron( [ new Dense(1004), From d7cca9c0d60088765653563bc248395e3422a880 Mon Sep 17 00:00:00 2001 From: Christoph Wurst Date: Mon, 30 Jan 2023 13:16:44 +0100 Subject: [PATCH 09/37] fixup! 
feat(importance-classifier): Reduce text feature vector Signed-off-by: Christoph Wurst --- lib/Service/Classification/ImportanceClassifier.php | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index aed5838813..e6c6f9c173 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -31,6 +31,7 @@ use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; +use Rubix\ML\Kernels\Distance\Jaccard; use Rubix\ML\Kernels\Distance\Manhattan; use Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid; use Rubix\ML\NeuralNet\Layers\Activation; @@ -362,7 +363,7 @@ private function trainClassifier(array $trainingSet): Estimator { 0.2, true, );*/ - $classifier = new KNearestNeighbors(3, false, new Manhattan()); + $classifier = new KNearestNeighbors(3, true, new Jaccard()); /*$classifier = new MultilayerPerceptron( [ new Dense(1004), From a9f7399d2bbcf6f24c79fde934819ed5fe770397 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 31 Jan 2023 16:50:24 +0100 Subject: [PATCH 10/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- .../SubjectAndPreviewTextExtractor.php | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php index f71583b196..859a92253e 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php @@ -31,9 +31,12 @@ use OCA\Mail\Service\Classification\ImportanceClassifier; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; +use Rubix\ML\Transformers\TSNE; +use Rubix\ML\Transformers\GaussianRandomProjector; use Rubix\ML\Transformers\LinearDiscriminantAnalysis; use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\PrincipalComponentAnalysis; +use Rubix\ML\Transformers\SparseRandomProjector; use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; @@ -51,17 +54,20 @@ class SubjectAndPreviewTextExtractor implements IExtractor { public function __construct(StatisticsDao $statisticsDao) { $this->statisticsDao = $statisticsDao; + // Limit vocabulary to limit ram usage. It takes about 5 GB of ram if an unbounded // vocabulary is used (and a lot more time to compute). 
- $this->wordCountVectorizer = new WordCountVectorizer(100); - $this->dimensionalReductionTransformer = new PrincipalComponentAnalysis(15); + $vocabSize = 100; + $this->wordCountVectorizer = new WordCountVectorizer($vocabSize); + + $this->dimensionalReductionTransformer = new TSNE((int)($vocabSize * 0.1)); } /** * @inheritDoc */ public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { - $data = array_map(function(Message $message) { + $data = array_map(static function(Message $message) { return [ 'text' => ($message->getSubject() ?? '') . ' ' . ($message->getPreviewText() ?? ''), 'label' => $message->getFlagImportant() ? 'i' : 'ni', @@ -93,7 +99,7 @@ public function extract(Message $message): array { $email = $sender->getEmail(); if (isset($this->senderCache[$email])) { - return $this->senderCache[$email]; + //return $this->senderCache[$email]; } // Build training data set @@ -105,13 +111,13 @@ public function extract(Message $message): array { ->apply($this->dimensionalReductionTransformer); // Use zeroed vector if no features could be extracted - if ($trainDataSet->numColumns() === 0) { + if ($trainDataSet->numFeatures() === 0) { $textFeatures = array_fill(0, $this->max, 0); } else { $textFeatures = $trainDataSet->sample(0); } - $this->senderCache[$email] = $textFeatures; + //$this->senderCache[$email] = $textFeatures; return $textFeatures; } From cee58bfd743a456191d12e06998480df69010591 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 31 Jan 2023 16:59:30 +0100 Subject: [PATCH 11/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- lib/Service/Classification/ImportanceClassifier.php | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index e6c6f9c173..83b35448fc 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -363,7 +363,7 @@ private function trainClassifier(array $trainingSet): Estimator { 0.2, true, );*/ - $classifier = new KNearestNeighbors(3, true, new Jaccard()); + $classifier = new KNearestNeighbors(5, true, new Jaccard()); /*$classifier = new MultilayerPerceptron( [ new Dense(1004), @@ -403,9 +403,14 @@ private function validateClassifier(Estimator $estimator, $predictedValidationLabel, array_column($validationSet, 'label') ); + /* $recallImportant = $report['classes'][self::LABEL_IMPORTANT]['recall'] ?? 0; $precisionImportant = $report['classes'][self::LABEL_IMPORTANT]['precision'] ?? 0; $f1ScoreImportant = $report['classes'][self::LABEL_IMPORTANT]['f1 score'] ?? 0; + */ + $recallImportant = $report['overall']['recall'] ?? 0; + $precisionImportant = $report['overall']['precision'] ?? 0; + $f1ScoreImportant = $report['overall']['f1 score'] ?? 0; /** * What we care most is the percentage of messages classified as important in relation to the truly important messages From 0e82c52aa34e5033e0aa1b6861d790e126c5bfb2 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 31 Jan 2023 17:00:19 +0100 Subject: [PATCH 12/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- lib/Service/Classification/ImportanceClassifier.php | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 83b35448fc..b6e581f683 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -403,14 +403,12 @@ private function validateClassifier(Estimator $estimator, $predictedValidationLabel, array_column($validationSet, 'label') ); - /* $recallImportant = $report['classes'][self::LABEL_IMPORTANT]['recall'] ?? 0; $precisionImportant = $report['classes'][self::LABEL_IMPORTANT]['precision'] ?? 0; $f1ScoreImportant = $report['classes'][self::LABEL_IMPORTANT]['f1 score'] ?? 0; - */ - $recallImportant = $report['overall']['recall'] ?? 0; - $precisionImportant = $report['overall']['precision'] ?? 0; - $f1ScoreImportant = $report['overall']['f1 score'] ?? 0; + //$recallImportant = $report['overall']['recall'] ?? 0; + //$precisionImportant = $report['overall']['precision'] ?? 0; + //$f1ScoreImportant = $report['overall']['f1 score'] ?? 0; /** * What we care most is the percentage of messages classified as important in relation to the truly important messages From bd82bea0660198b224d6955d42eeae1919604df2 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 31 Jan 2023 17:47:52 +0100 Subject: [PATCH 13/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- lib/Service/Classification/ImportanceClassifier.php | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index b6e581f683..2e96204ae2 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -239,6 +239,8 @@ public function train(Account $account, LoggerInterface $logger, IExtractor $ext * @return Mailbox[] */ private function getIncomingMailboxes(Account $account): array { + return [$this->mailboxMapper->find($account, 'INBOX')]; + /* return array_filter($this->mailboxMapper->findAll($account), static function (Mailbox $mailbox) { foreach (self::EXEMPT_FROM_TRAINING as $excluded) { if ($mailbox->isSpecialUse($excluded)) { @@ -247,6 +249,7 @@ private function getIncomingMailboxes(Account $account): array { } return true; }); + */ } /** From 6c6e2cad2c2ad0c0df07ca8605532007c9ce6db6 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Thu, 2 Mar 2023 17:53:02 +0100 Subject: [PATCH 14/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- appinfo/info.xml | 1 + .../TrainImportanceClassifierJob.php | 2 +- lib/Command/RunMetaEstimator.php | 135 ++++++++++++ lib/Command/TrainAccount.php | 67 +++++- .../Classification/ImportanceClassifier.php | 194 +++++++++++++----- 5 files changed, 346 insertions(+), 53 deletions(-) create mode 100644 lib/Command/RunMetaEstimator.php diff --git a/appinfo/info.xml b/appinfo/info.xml index f2beb804a9..b0e0c1e72a 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -91,6 +91,7 @@ Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud OCA\Mail\Command\UpdateAccount OCA\Mail\Command\UpdateSystemAutoresponders OCA\Mail\Command\PreprocessAccount + OCA\Mail\Command\RunMetaEstimator OCA\Mail\Settings\AdminSettings diff --git a/lib/BackgroundJob/TrainImportanceClassifierJob.php b/lib/BackgroundJob/TrainImportanceClassifierJob.php index d6482cbe62..e3bbc4c44b 100644 --- a/lib/BackgroundJob/TrainImportanceClassifierJob.php +++ b/lib/BackgroundJob/TrainImportanceClassifierJob.php @@ -69,7 +69,7 @@ protected function run($argument) { } try { - $this->classifier->train( + $this->classifier->trainWithDefaultEstimator( $account, $this->logger ); diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php new file mode 100644 index 0000000000..387d8cd6e6 --- /dev/null +++ b/lib/Command/RunMetaEstimator.php @@ -0,0 +1,135 @@ + + * + * @author Richard Steinmetz + * + * @license AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +namespace OCA\Mail\Command; + +use OCA\Mail\Service\AccountService; +use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; +use OCA\Mail\Service\Classification\ImportanceClassifier; +use OCA\Mail\Support\ConsoleLoggerDecorator; +use OCP\AppFramework\Db\DoesNotExistException; +use Psr\Container\ContainerInterface; +use Psr\Log\LoggerInterface; +use Rubix\ML\Backends\Amp; +use Rubix\ML\Classifiers\KNearestNeighbors; +use Rubix\ML\CrossValidation\KFold; +use Rubix\ML\CrossValidation\Metrics\FBeta; +use Rubix\ML\GridSearch; +use Rubix\ML\Kernels\Distance\Euclidean; +use Rubix\ML\Kernels\Distance\Jaccard; +use Rubix\ML\Kernels\Distance\Manhattan; +use Symfony\Component\Console\Command\Command; +use Symfony\Component\Console\Input\InputArgument; +use Symfony\Component\Console\Input\InputInterface; +use Symfony\Component\Console\Input\InputOption; +use Symfony\Component\Console\Output\OutputInterface; + +class RunMetaEstimator extends Command { + public const ARGUMENT_ACCOUNT_ID = 'account-id'; + public const ARGUMENT_LOAD_DATA = 'load-data'; + + private AccountService $accountService; + private LoggerInterface $logger; + private ImportanceClassifier $classifier; + private ContainerInterface $container; + + public function __construct( + AccountService $accountService, + LoggerInterface $logger, + ImportanceClassifier $classifier, + ContainerInterface $container, + ) { + parent::__construct(); + + $this->accountService = $accountService; + $this->logger = $logger; + $this->classifier = $classifier; + $this->container = $container; + } + + protected function configure(): void { + $this->setName('mail:account:run-meta-estimator'); + $this->setDescription('Run the meta estimator for an account'); + $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); 
+ $this->addOption( + self::ARGUMENT_LOAD_DATA, + null, + InputOption::VALUE_REQUIRED, + 'Load training data set from a JSON file' + ); + } + + protected function execute(InputInterface $input, OutputInterface $output): int { + $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); + + try { + $account = $this->accountService->findById($accountId); + } catch (DoesNotExistException $e) { + $output->writeln("Account $accountId does not exist"); + return 1; + } + + $extractor = $this->container->get(NewCompositeExtractor::class); + $consoleLogger = new ConsoleLoggerDecorator( + $this->logger, + $output + ); + + if ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { + $json = file_get_contents($loadDataPath); + $dataSet = json_decode($json, true); + } else { + $dataSet = $this->classifier->buildDataSet($account, $extractor, $consoleLogger); + } + + $params = [ + [1, 3, 5, 10], // Neighbors + [true, false], // Weighted? + [new Euclidean(), new Manhattan(), new Jaccard()], // Kernel + ]; + + $this->classifier->trainWithCustomDataSet( + $account, + $consoleLogger, + $dataSet, + static function () use ($params, $consoleLogger) { + $estimator = new GridSearch( + KNearestNeighbors::class, + $params, + new FBeta(), + new KFold(3) + ); + $estimator->setLogger($consoleLogger); + $estimator->setBackend(new Amp()); + return $estimator; + }, + null, + false, + ); + + return 0; + } +} diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index 80b1870081..f1a69fc314 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -21,6 +21,7 @@ use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; +use Symfony\Component\Console\Input\InputOption; use Symfony\Component\Console\Output\OutputInterface; use function memory_get_peak_usage; @@ -28,6 +29,9 @@ class TrainAccount extends Command { public const ARGUMENT_ACCOUNT_ID = 
'account-id'; public const ARGUMENT_NEW = 'new'; public const ARGUMENT_SHUFFLE = 'shuffle'; + public const ARGUMENT_SAVE_DATA = 'save-data'; + public const ARGUMENT_LOAD_DATA = 'load-data'; + public const ARGUMENT_DRY_RUN = 'dry-run'; private AccountService $accountService; private ImportanceClassifier $classifier; @@ -58,6 +62,24 @@ protected function configure() { $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); $this->addOption(self::ARGUMENT_NEW, null, null, 'Enable new composite extractor using text based features'); $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); + $this->addOption( + self::ARGUMENT_DRY_RUN, + null, + null, + 'Don\'t persist classifier after training' + ); + $this->addOption( + self::ARGUMENT_SAVE_DATA, + null, + InputOption::VALUE_REQUIRED, + 'Save training data set to a JSON file' + ); + $this->addOption( + self::ARGUMENT_LOAD_DATA, + null, + InputOption::VALUE_REQUIRED, + 'Load training data set from a JSON file' + ); } /** @@ -65,6 +87,8 @@ protected function configure() { */ protected function execute(InputInterface $input, OutputInterface $output): int { $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); + $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); + $dryRun = (bool)$input->getOption(self::ARGUMENT_DRY_RUN); try { $account = $this->accountService->findById($accountId); @@ -90,12 +114,43 @@ protected function execute(InputInterface $input, OutputInterface $output): int $this->logger, $output ); - $this->classifier->train( - $account, - $consoleLogger, - $extractor, - (bool)$input->getOption(self::ARGUMENT_SHUFFLE), - ); + + $dataSet = null; + if ($saveDataPath = $input->getOption(self::ARGUMENT_SAVE_DATA)) { + $dataSet = $this->classifier->buildDataSet( + $account, + $extractor, + $consoleLogger, + null, + $shuffle, + ); + $json = json_encode($dataSet); + file_put_contents($saveDataPath, $json); + } else if ($loadDataPath = 
$input->getOption(self::ARGUMENT_LOAD_DATA)) { + $json = file_get_contents($loadDataPath); + $dataSet = json_decode($json, true); + } + + if ($dataSet) { + $this->classifier->trainWithCustomDataSet( + $account, + $consoleLogger, + $dataSet, + null, + null, + !$dryRun + ); + } else { + $this->classifier->train( + $account, + $consoleLogger, + $extractor, + null, + $shuffle, + !$dryRun + ); + + } $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); $output->writeln('' . $mbs . 'MB of memory used'); diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 2e96204ae2..03bbc6aa00 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -9,6 +9,7 @@ namespace OCA\Mail\Service\Classification; +use Closure; use Horde_Imap_Client; use OCA\Mail\Account; use OCA\Mail\Db\Classifier; @@ -20,7 +21,9 @@ use OCA\Mail\Exception\ServiceException; use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use OCA\Mail\Support\PerformanceLogger; +use OCA\Mail\Support\PerformanceLoggerTask; use OCP\AppFramework\Db\DoesNotExistException; +use PHPUnit\Framework\Constraint\Callback; use Psr\Log\LoggerInterface; use Rubix\ML\Classifiers\ClassificationTree; use Rubix\ML\Classifiers\GaussianNB; @@ -33,10 +36,12 @@ use Rubix\ML\Estimator; use Rubix\ML\Kernels\Distance\Jaccard; use Rubix\ML\Kernels\Distance\Manhattan; +use Rubix\ML\Learner; use Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid; use Rubix\ML\NeuralNet\Layers\Activation; use Rubix\ML\NeuralNet\Layers\Dense; use Rubix\ML\NeuralNet\Optimizers\Adam; +use Rubix\ML\Persistable; use RuntimeException; use function array_column; use function array_combine; @@ -91,7 +96,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 2000; + private const MAX_TRAINING_SET_SIZE = 200; /** @var MailboxMapper */ private 
$mailboxMapper; @@ -130,22 +135,55 @@ private function filterMessageHasSenderEmail(Message $message): bool { } /** - * Train an account's classifier of important messages - * - * Train a classifier based on a user's existing messages to be able to derive - * importance markers for new incoming messages. - * - * To factor in (server-side) filtering into multiple mailboxes, the algorithm - * will not only look for messages in the inbox but also other non-special - * mailboxes. - * - * To prevent memory exhaustion, the process will only load a fixed maximum - * number of messages per account. + * @return Estimator&Learner&Persistable + */ + private function getDefaultEstimator(): Estimator { + //return new GaussianNB(); + + /* + return new RandomForest( + new ClassificationTree(10, 1), + 10, + 0.2, + true, + ); + */ + + /* + return new MultilayerPerceptron( + [ + new Dense(1004), + new Activation(new Sigmoid()) + ], + 32, + null, + 1e-4, + 10, + ); + */ + + return new KNearestNeighbors(5, true, new Jaccard()); + } + + /** + * Build a data set for training an importance classifier. * * @param Account $account + * @param IExtractor $extractor + * @param LoggerInterface $logger + * @param PerformanceLoggerTask|null $perf + * @param bool $shuffle + * @return array|null Returns null if there are not enough messages to train */ - public function train(Account $account, LoggerInterface $logger, IExtractor $extractor, bool $shuffleDataSet = false): void { - $perf = $this->performanceLogger->start('importance classifier training'); + public function buildDataSet( + Account $account, + IExtractor $extractor, + LoggerInterface $logger, + ?PerformanceLoggerTask $perf = null, + bool $shuffle = false, + ): ?array { + $perf ??= $this->performanceLogger->start('build data set for importance classifier training'); + $incomingMailboxes = $this->getIncomingMailboxes($account); $logger->debug('found ' . count($incomingMailboxes) . 
' incoming mailbox(es)'); $perf->step('find incoming mailboxes'); @@ -166,23 +204,98 @@ public function train(Account $account, LoggerInterface $logger, IExtractor $ext $logger->debug('found ' . count($messages) . ' messages of which ' . count($importantMessages) . ' are important'); if (count($importantMessages) < self::COLD_START_THRESHOLD) { $logger->info('not enough messages to train a classifier'); - $perf->end(); - return; + return null; } $perf->step('find latest ' . self::MAX_TRAINING_SET_SIZE . ' messages'); $dataSet = $this->getFeaturesAndImportance($account, $incomingMailboxes, $outgoingMailboxes, $messages, $extractor); - if ($shuffleDataSet) { + if ($shuffle) { shuffle($dataSet); } - $perf->step('extract features from messages'); + + return $dataSet; + } + + /** + * Train an account's classifier of important messages + * + * Train a classifier based on a user's existing messages to be able to derive + * importance markers for new incoming messages. + * + * To factor in (server-side) filtering into multiple mailboxes, the algorithm + * will not only look for messages in the inbox but also other non-special + * mailboxes. + * + * To prevent memory exhaustion, the process will only load a fixed maximum + * number of messages per account. + * + * @param Account $account + * @param LoggerInterface $logger + * @param IExtractor $extractor + * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. 
+ * @param bool $shuffleDataSet Shuffle the data set before training + * @param bool $persist Persist the trained classifier to use it for message classification + * + * @throws ServiceException + */ + public function train( + Account $account, + LoggerInterface $logger, + IExtractor $extractor, + ?Closure $estimator = null, + bool $shuffleDataSet = false, + bool $persist = true, + ): void { + $perf = $this->performanceLogger->start('importance classifier training'); + + $dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet); + if ($dataSet === null) { + return; + } + + $this->trainWithCustomDataSet( + $account, + $logger, + $dataSet, + $estimator, + $perf, + $persist, + ); + } + + /** + * Train a classifier using a custom data set. + * + * @param Account $account + * @param LoggerInterface $logger + * @param array $dataSet Training data set built by buildDataSet() + * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. + * @param bool $persist Persist the trained classifier to use it for message classification + * + * @throws ServiceException + */ + public function trainWithCustomDataSet( + Account $account, + LoggerInterface $logger, + array $dataSet, + ?Closure $estimator = null, + ?PerformanceLoggerTask $perf = null, + bool $persist = true, + ): void { + if ($estimator === null) { + $estimator = function () { + return $this->getDefaultEstimator(); + }; + } + + $perf ??= $this->performanceLogger->start('importance classifier training'); /** * How many of the most recent messages are excluded from training? 
*/ $validationThreshold = max( 5, - (int)(count($dataSet) * 0.25) + (int)(count($dataSet) * 0.2) ); $validationSet = array_slice($dataSet, 0, $validationThreshold); $trainingSet = array_slice($dataSet, $validationThreshold); @@ -207,7 +320,10 @@ public function train(Account $account, LoggerInterface $logger, IExtractor $ext $perf->end(); return; } - $validationEstimator = $this->trainClassifier($trainingSet); + + /** @var Learner&Estimator&Persistable $validationEstimator */ + $validationEstimator = $estimator(); + $this->trainClassifier($validationEstimator, $validationSet); try { $classifier = $this->validateClassifier( $validationEstimator, @@ -224,15 +340,20 @@ public function train(Account $account, LoggerInterface $logger, IExtractor $ext } $perf->step('train and validate classifier with training and validation sets'); - $estimator = $this->trainClassifier($dataSet); - $perf->step('train classifier with full data set'); + if ($persist) { + /** @var Learner&Estimator&Persistable $persistedEstimator */ + $persistedEstimator = $estimator(); + $this->trainClassifier($persistedEstimator, $dataSet); + $perf->step("train classifier with full data set"); - $classifier->setAccountId($account->getId()); - $classifier->setDuration($perf->end()); - $this->persistenceService->persist($classifier, $estimator); - $logger->debug("classifier {$classifier->getId()} persisted"); + $classifier->setAccountId($account->getId()); + $classifier->setDuration($perf->end()); + $this->persistenceService->persist($classifier, $persistedEstimator); + $logger->debug("classifier {$classifier->getId()} persisted"); + } } + /** * @param Account $account * @@ -358,30 +479,11 @@ public function classifyImportance(Account $account, array $messages): array { ); } - private function trainClassifier(array $trainingSet): Estimator { - //$classifier = new GaussianNB(); - /*$classifier = new RandomForest( - new ClassificationTree(10, 1), - 10, - 0.2, - true, - );*/ - $classifier = new 
KNearestNeighbors(5, true, new Jaccard()); - /*$classifier = new MultilayerPerceptron( - [ - new Dense(1004), - new Activation(new Sigmoid()) - ], - 32, - null, - 1e-4, - 10, - );*/ + private function trainClassifier(Learner $classifier, array $trainingSet): void { $classifier->train(Labeled::build( array_column($trainingSet, 'features'), array_column($trainingSet, 'label') )); - return $classifier; } /** From 18767c7e34d1164cf25afe2a5801e3310348ac72 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Fri, 3 Mar 2023 14:08:27 +0100 Subject: [PATCH 15/37] fixup! feat(importance-classifier): Reduce text feature vector --- lib/Command/RunMetaEstimator.php | 68 +++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index 387d8cd6e6..d3711a3dbc 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -49,6 +49,7 @@ class RunMetaEstimator extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; + public const ARGUMENT_SHUFFLE = 'shuffle'; public const ARGUMENT_LOAD_DATA = 'load-data'; private AccountService $accountService; @@ -74,6 +75,7 @@ protected function configure(): void { $this->setName('mail:account:run-meta-estimator'); $this->setDescription('Run the meta estimator for an account'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); + $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); $this->addOption( self::ARGUMENT_LOAD_DATA, null, @@ -84,6 +86,7 @@ protected function configure(): void { protected function execute(InputInterface $input, OutputInterface $output): int { $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); + $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); try { $account = $this->accountService->findById($accountId); @@ -102,7 +105,13 @@ protected function execute(InputInterface $input, OutputInterface $output): 
int $json = file_get_contents($loadDataPath); $dataSet = json_decode($json, true); } else { - $dataSet = $this->classifier->buildDataSet($account, $extractor, $consoleLogger); + $dataSet = $this->classifier->buildDataSet( + $account, + $extractor, + $consoleLogger, + null, + $shuffle, + ); } $params = [ @@ -111,24 +120,45 @@ protected function execute(InputInterface $input, OutputInterface $output): int [new Euclidean(), new Manhattan(), new Jaccard()], // Kernel ]; - $this->classifier->trainWithCustomDataSet( - $account, - $consoleLogger, - $dataSet, - static function () use ($params, $consoleLogger) { - $estimator = new GridSearch( - KNearestNeighbors::class, - $params, - new FBeta(), - new KFold(3) - ); - $estimator->setLogger($consoleLogger); - $estimator->setBackend(new Amp()); - return $estimator; - }, - null, - false, - ); + if ($dataSet) { + $this->classifier->trainWithCustomDataSet( + $account, + $consoleLogger, + $dataSet, + static function () use ($params, $consoleLogger) { + $estimator = new GridSearch( + KNearestNeighbors::class, + $params, + new FBeta(), + new KFold(3) + ); + $estimator->setLogger($consoleLogger); + $estimator->setBackend(new Amp()); + return $estimator; + }, + null, + false, + ); + } else { + $this->classifier->train( + $account, + $consoleLogger, + $extractor, + static function () use ($params, $consoleLogger) { + $estimator = new GridSearch( + KNearestNeighbors::class, + $params, + new FBeta(), + new KFold(3) + ); + $estimator->setLogger($consoleLogger); + $estimator->setBackend(new Amp()); + return $estimator; + }, + $shuffle, + false, + ); + } return 0; } From c764944345ab2c14b0c32dd55ad37fe187f623f3 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 21 Mar 2023 18:51:47 +0100 Subject: [PATCH 16/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- lib/Command/RunMetaEstimator.php | 46 ++++++++----------- lib/Command/TrainAccount.php | 43 +++++++++++++++-- .../NewCompositeExtractor.php | 2 +- .../Classification/ImportanceClassifier.php | 7 +-- 4 files changed, 63 insertions(+), 35 deletions(-) diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index d3711a3dbc..559d737af5 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -114,28 +114,30 @@ protected function execute(InputInterface $input, OutputInterface $output): int ); } - $params = [ - [1, 3, 5, 10], // Neighbors - [true, false], // Weighted? - [new Euclidean(), new Manhattan(), new Jaccard()], // Kernel - ]; + $estimator = static function () use ($consoleLogger) { + $params = [ + [5, 10, 15, 20, 25, 30], // Neighbors + [true, false], // Weighted? + [new Euclidean(), new Manhattan(), new Jaccard()], // Kernel + ]; + + $estimator = new GridSearch( + KNearestNeighbors::class, + $params, + new FBeta(), + new KFold(3) + ); + $estimator->setLogger($consoleLogger); + $estimator->setBackend(new Amp()); + return $estimator; + }; if ($dataSet) { $this->classifier->trainWithCustomDataSet( $account, $consoleLogger, $dataSet, - static function () use ($params, $consoleLogger) { - $estimator = new GridSearch( - KNearestNeighbors::class, - $params, - new FBeta(), - new KFold(3) - ); - $estimator->setLogger($consoleLogger); - $estimator->setBackend(new Amp()); - return $estimator; - }, + $estimator, null, false, ); @@ -144,17 +146,7 @@ static function () use ($params, $consoleLogger) { $account, $consoleLogger, $extractor, - static function () use ($params, $consoleLogger) { - $estimator = new GridSearch( - KNearestNeighbors::class, - $params, - new FBeta(), - new KFold(3) - ); - $estimator->setLogger($consoleLogger); - $estimator->setBackend(new Amp()); - return $estimator; - }, + $estimator, $shuffle, false, ); diff --git 
a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index f1a69fc314..6440fec4cc 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -18,6 +18,9 @@ use OCP\AppFramework\Db\DoesNotExistException; use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; +use Rubix\ML\Classifiers\GaussianNB; +use Rubix\ML\Classifiers\KNearestNeighbors; +use Rubix\ML\Kernels\Distance\Manhattan; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; @@ -28,6 +31,8 @@ class TrainAccount extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; public const ARGUMENT_NEW = 'new'; + public const ARGUMENT_NEW_ESTIMATOR = 'new-estimator'; + public const ARGUMENT_NEW_EXTRACTOR = 'new-extractor'; public const ARGUMENT_SHUFFLE = 'shuffle'; public const ARGUMENT_SAVE_DATA = 'save-data'; public const ARGUMENT_LOAD_DATA = 'load-data'; @@ -60,7 +65,19 @@ protected function configure() { $this->setName('mail:account:train'); $this->setDescription('Train the classifier of new messages'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); - $this->addOption(self::ARGUMENT_NEW, null, null, 'Enable new composite extractor using text based features'); + $this->addOption( + self::ARGUMENT_NEW, + null, + null, + 'Enable new composite extractor and KNN estimator' + ); + $this->addOption( + self::ARGUMENT_NEW_EXTRACTOR, + null, + null, + 'Enable new composite extractor using text based features' + ); + $this->addOption(self::ARGUMENT_NEW_ESTIMATOR, null, null, 'Enable new KNN estimator'); $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); $this->addOption( self::ARGUMENT_DRY_RUN, @@ -89,6 +106,9 @@ protected function execute(InputInterface $input, OutputInterface $output): int $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); 
$dryRun = (bool)$input->getOption(self::ARGUMENT_DRY_RUN); + $new = (bool)$input->getOption(self::ARGUMENT_NEW); + $newEstimator = $new || (bool)$input->getOption(self::ARGUMENT_NEW_ESTIMATOR); + $newExtractor = $new || (bool)$input->getOption(self::ARGUMENT_NEW_EXTRACTOR); try { $account = $this->accountService->findById($accountId); @@ -104,7 +124,22 @@ protected function execute(InputInterface $input, OutputInterface $output): int } */ - if ($input->getOption(self::ARGUMENT_NEW)) { + if ($newEstimator) { + $estimator = static function () { + // A meta estimator was trained on the same data multiple times to average out the + // variance of the trained model. + // Parameters were chosen from the best configuration across 100 runs. + // Both variance (spread) and f1 score were considered. + // Note: Lower k values generally yield higher f1 scores but show higher variances. + return new KNearestNeighbors(15, true, new Manhattan()); + }; + } else { + $estimator = static function() { + return new GaussianNB(); + }; + } + + if ($newExtractor) { $extractor = $this->container->get(NewCompositeExtractor::class); } else { $extractor = $this->container->get(VanillaCompositeExtractor::class); @@ -136,7 +171,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $account, $consoleLogger, $dataSet, - null, + $estimator, null, !$dryRun ); @@ -145,7 +180,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $account, $consoleLogger, $extractor, - null, + $estimator, $shuffle, !$dryRun ); diff --git a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php index df58a93fe2..e44cfd7b69 100644 --- a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php @@ -26,7 +26,7 @@ namespace OCA\Mail\Service\Classification\FeatureExtraction; class 
NewCompositeExtractor extends CompositeExtractor { - public function __construct(VanillaCompositeExtractor $ex1, + public function __construct(VanillaCompositeExtractor $ex1, SubjectAndPreviewTextExtractor $ex2) { parent::__construct([$ex1, $ex2]); } diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 03bbc6aa00..50fe20a52c 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -96,7 +96,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 200; + private const MAX_TRAINING_SET_SIZE = 350; /** @var MailboxMapper */ private $mailboxMapper; @@ -138,7 +138,7 @@ private function filterMessageHasSenderEmail(Message $message): bool { * @return Estimator&Learner&Persistable */ private function getDefaultEstimator(): Estimator { - //return new GaussianNB(); + return new GaussianNB(); /* return new RandomForest( @@ -162,7 +162,8 @@ private function getDefaultEstimator(): Estimator { ); */ - return new KNearestNeighbors(5, true, new Jaccard()); + //return new KNearestNeighbors(5, true, new Jaccard()); + //return new KNearestNeighbors(2, false, new Manhattan()); } /** From 974ee46558d4e2bb6901ae929fc58f4366ecc481 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Fri, 24 Mar 2023 14:27:22 +0100 Subject: [PATCH 17/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- .../TrainImportanceClassifierJob.php | 5 +- lib/Command/PredictImportance.php | 16 +- lib/Command/RunMetaEstimator.php | 6 +- lib/Command/TrainAccount.php | 71 ++++----- lib/Db/Classifier.php | 14 ++ .../NewMessageClassificationListener.php | 4 +- .../Version3100Date20230324113141.php | 57 +++++++ lib/Model/ClassifierPipeline.php | 57 +++++++ .../NewCompositeExtractor.php | 9 +- ...TextExtractor.php => SubjectExtractor.php} | 82 +++------- .../Classification/ImportanceClassifier.php | 148 +++++++++-------- .../Classification/PersistenceService.php | 149 +++++++++++++++--- lib/Service/Sync/ImapToDbSynchronizer.php | 13 +- 13 files changed, 433 insertions(+), 198 deletions(-) create mode 100644 lib/Migration/Version3100Date20230324113141.php create mode 100644 lib/Model/ClassifierPipeline.php rename lib/Service/Classification/FeatureExtraction/{SubjectAndPreviewTextExtractor.php => SubjectExtractor.php} (61%) diff --git a/lib/BackgroundJob/TrainImportanceClassifierJob.php b/lib/BackgroundJob/TrainImportanceClassifierJob.php index e3bbc4c44b..0ec467a1a1 100644 --- a/lib/BackgroundJob/TrainImportanceClassifierJob.php +++ b/lib/BackgroundJob/TrainImportanceClassifierJob.php @@ -69,10 +69,7 @@ protected function run($argument) { } try { - $this->classifier->trainWithDefaultEstimator( - $account, - $this->logger - ); + $this->classifier->train($account, $this->logger); } catch (Throwable $e) { $this->logger->error('Cron importance classifier training failed: ' . 
$e->getMessage(), [ 'exception' => $e, diff --git a/lib/Command/PredictImportance.php b/lib/Command/PredictImportance.php index 386dc413ea..f737f1d4a1 100644 --- a/lib/Command/PredictImportance.php +++ b/lib/Command/PredictImportance.php @@ -12,9 +12,12 @@ use OCA\Mail\AddressList; use OCA\Mail\Db\Message; use OCA\Mail\Service\AccountService; +use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; +use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; use OCP\IConfig; +use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; @@ -30,17 +33,20 @@ class PredictImportance extends Command { private ImportanceClassifier $classifier; private IConfig $config; private LoggerInterface $logger; + private ContainerInterface $container; public function __construct(AccountService $service, ImportanceClassifier $classifier, IConfig $config, - LoggerInterface $logger) { + LoggerInterface $logger, + ContainerInterface $container) { parent::__construct(); $this->accountService = $service; $this->classifier = $classifier; $this->logger = $logger; $this->config = $config; + $this->container = $container; } /** @@ -64,6 +70,11 @@ protected function execute(InputInterface $input, OutputInterface $output): int $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); $sender = $input->getArgument(self::ARGUMENT_SENDER); + $consoleLogger = new ConsoleLoggerDecorator( + $this->logger, + $output + ); + try { $account = $this->accountService->findById($accountId); } catch (DoesNotExistException $e) { @@ -75,7 +86,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int $fakeMessage->setFrom(AddressList::parse("Name <$sender>")); [$prediction] = $this->classifier->classifyImportance( $account, - [$fakeMessage] + [$fakeMessage], + 
$consoleLogger ); if ($prediction) { $output->writeln('Message is important'); diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index 559d737af5..ea7de222f6 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -95,6 +95,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int return 1; } + /** @var NewCompositeExtractor $extractor */ $extractor = $this->container->get(NewCompositeExtractor::class); $consoleLogger = new ConsoleLoggerDecorator( $this->logger, @@ -103,7 +104,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int if ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { $json = file_get_contents($loadDataPath); - $dataSet = json_decode($json, true); + $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR); } else { $dataSet = $this->classifier->buildDataSet( $account, @@ -137,6 +138,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $account, $consoleLogger, $dataSet, + $extractor, $estimator, null, false, @@ -152,6 +154,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int ); } + $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); + $output->writeln('' . $mbs . 
'MB of memory used'); return 0; } } diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index 6440fec4cc..2729d598b8 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -11,6 +11,7 @@ use OCA\Mail\Service\AccountService; use OCA\Mail\Service\Classification\ClassificationSettingsService; +use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; @@ -19,8 +20,6 @@ use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Rubix\ML\Classifiers\GaussianNB; -use Rubix\ML\Classifiers\KNearestNeighbors; -use Rubix\ML\Kernels\Distance\Manhattan; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; @@ -30,13 +29,14 @@ class TrainAccount extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; - public const ARGUMENT_NEW = 'new'; - public const ARGUMENT_NEW_ESTIMATOR = 'new-estimator'; - public const ARGUMENT_NEW_EXTRACTOR = 'new-extractor'; + public const ARGUMENT_OLD = 'old'; + public const ARGUMENT_OLD_ESTIMATOR = 'old-estimator'; + public const ARGUMENT_OLD_EXTRACTOR = 'old-extractor'; public const ARGUMENT_SHUFFLE = 'shuffle'; public const ARGUMENT_SAVE_DATA = 'save-data'; public const ARGUMENT_LOAD_DATA = 'load-data'; public const ARGUMENT_DRY_RUN = 'dry-run'; + public const ARGUMENT_FORCE = 'force'; private AccountService $accountService; private ImportanceClassifier $classifier; @@ -66,18 +66,18 @@ protected function configure() { $this->setDescription('Train the classifier of new messages'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); $this->addOption( - self::ARGUMENT_NEW, + self::ARGUMENT_OLD, null, null, - 'Enable new composite extractor and KNN 
estimator' + 'Use old vanilla composite extractor and GaussianNB estimator (implies --old-extractor and --old-estimator)' ); $this->addOption( - self::ARGUMENT_NEW_EXTRACTOR, + self::ARGUMENT_OLD_EXTRACTOR, null, null, - 'Enable new composite extractor using text based features' + 'Use old vanilla composite extractor without text based features' ); - $this->addOption(self::ARGUMENT_NEW_ESTIMATOR, null, null, 'Enable new KNN estimator'); + $this->addOption(self::ARGUMENT_OLD_ESTIMATOR, null, null, 'Use old GaussianNB estimator'); $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); $this->addOption( self::ARGUMENT_DRY_RUN, @@ -85,6 +85,12 @@ protected function configure() { null, 'Don\'t persist classifier after training' ); + $this->addOption( + self::ARGUMENT_FORCE, + null, + null, + 'Train an estimator even if the classification is disabled by the user' + ); $this->addOption( self::ARGUMENT_SAVE_DATA, null, @@ -99,16 +105,14 @@ protected function configure() { ); } - /** - * @return int - */ protected function execute(InputInterface $input, OutputInterface $output): int { $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); $dryRun = (bool)$input->getOption(self::ARGUMENT_DRY_RUN); - $new = (bool)$input->getOption(self::ARGUMENT_NEW); - $newEstimator = $new || (bool)$input->getOption(self::ARGUMENT_NEW_ESTIMATOR); - $newExtractor = $new || (bool)$input->getOption(self::ARGUMENT_NEW_EXTRACTOR); + $force = (bool)$input->getOption(self::ARGUMENT_FORCE); + $old = (bool)$input->getOption(self::ARGUMENT_OLD); + $oldEstimator = $old || $input->getOption(self::ARGUMENT_OLD_ESTIMATOR); + $oldExtractor = $old || $input->getOption(self::ARGUMENT_OLD_EXTRACTOR); try { $account = $this->accountService->findById($accountId); @@ -117,32 +121,23 @@ protected function execute(InputInterface $input, OutputInterface $output): int return 1; } - /* - if 
(!$this->classificationSettingsService->isClassificationEnabled($account->getUserId())) { + if (!$force && !$this->classificationSettingsService->isClassificationEnabled($account->getUserId())) { $output->writeln("classification is turned off for account $accountId"); return 2; } - */ - if ($newEstimator) { - $estimator = static function () { - // A meta estimator was trained on the same data multiple times to average out the - // variance of the trained model. - // Parameters were chosen from the best configuration across 100 runs. - // Both variance (spread) and f1 score were considered. - // Note: Lower k values generally yield higher f1 scores but show higher variances. - return new KNearestNeighbors(15, true, new Manhattan()); - }; + /** @var IExtractor $extractor */ + if ($oldExtractor) { + $extractor = $this->container->get(VanillaCompositeExtractor::class); } else { - $estimator = static function() { - return new GaussianNB(); - }; + $extractor = $this->container->get(NewCompositeExtractor::class); } - if ($newExtractor) { - $extractor = $this->container->get(NewCompositeExtractor::class); - } else { - $extractor = $this->container->get(VanillaCompositeExtractor::class); + $estimator = null; + if ($oldEstimator) { + $estimator = static function () { + return new GaussianNB(); + }; } $consoleLogger = new ConsoleLoggerDecorator( @@ -159,11 +154,11 @@ protected function execute(InputInterface $input, OutputInterface $output): int null, $shuffle, ); - $json = json_encode($dataSet); + $json = json_encode($dataSet, JSON_THROW_ON_ERROR); file_put_contents($saveDataPath, $json); } else if ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { $json = file_get_contents($loadDataPath); - $dataSet = json_decode($json, true); + $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR); } if ($dataSet) { @@ -171,6 +166,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $account, $consoleLogger, $dataSet, + $extractor, 
$estimator, null, !$dryRun @@ -184,7 +180,6 @@ protected function execute(InputInterface $input, OutputInterface $output): int $shuffle, !$dryRun ); - } $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); diff --git a/lib/Db/Classifier.php b/lib/Db/Classifier.php index dadbfaa4f6..9083c5e335 100644 --- a/lib/Db/Classifier.php +++ b/lib/Db/Classifier.php @@ -47,6 +47,12 @@ * * @method int getCreatedAt() * @method void setCreatedAt(int $createdAt) + * + * @method int getTransformerCount() + * @method void setTransformerCount(int $transformerCount) + * + * @method string|null getTransformers() + * @method void setTransformers(string|null $transformers) */ class Classifier extends Entity { public const TYPE_IMPORTANCE = 'importance'; @@ -87,6 +93,12 @@ class Classifier extends Entity { /** @var int */ protected $createdAt; + /** @var int */ + protected $transformerCount; + + /** @var string */ + protected $transformers; + public function __construct() { $this->addType('accountId', 'integer'); $this->addType('type', 'string'); @@ -99,5 +111,7 @@ public function __construct() { $this->addType('duration', 'integer'); $this->addType('active', 'boolean'); $this->addType('createdAt', 'integer'); + $this->addType('transformerCount', 'integer'); + $this->addType('transformers', 'string'); } } diff --git a/lib/Listener/NewMessageClassificationListener.php b/lib/Listener/NewMessageClassificationListener.php index 01c0ca5f4f..51d269c958 100644 --- a/lib/Listener/NewMessageClassificationListener.php +++ b/lib/Listener/NewMessageClassificationListener.php @@ -101,10 +101,12 @@ public function handle(Event $event): void { try { $predictions = $this->classifier->classifyImportance( $event->getAccount(), - $messages + $messages, + $this->logger ); foreach ($event->getMessages() as $message) { + $this->logger->info("Message {$message->getUid()} ({$message->getPreviewText()}) is " . ($predictions[$message->getUid()] ? 
'important' : 'not important')); if ($predictions[$message->getUid()] ?? false) { $this->mailManager->flagMessage($event->getAccount(), $event->getMailbox()->getName(), $message->getUid(), Tag::LABEL_IMPORTANT, true); $this->mailManager->tagMessage($event->getAccount(), $event->getMailbox()->getName(), $message, $important, true); diff --git a/lib/Migration/Version3100Date20230324113141.php b/lib/Migration/Version3100Date20230324113141.php new file mode 100644 index 0000000000..81e438f364 --- /dev/null +++ b/lib/Migration/Version3100Date20230324113141.php @@ -0,0 +1,57 @@ + + * + * @author Richard Steinmetz + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ * + */ + +namespace OCA\Mail\Migration; + +use Closure; +use OCP\DB\ISchemaWrapper; +use OCP\DB\Types; +use OCP\Migration\IOutput; +use OCP\Migration\SimpleMigrationStep; + +class Version3100Date20230324113141 extends SimpleMigrationStep { + + /** + * @param IOutput $output + * @param Closure(): ISchemaWrapper $schemaClosure + * @param array $options + * @return null|ISchemaWrapper + */ + public function changeSchema(IOutput $output, Closure $schemaClosure, array $options): ?ISchemaWrapper { + /** @var ISchemaWrapper $schema */ + $schema = $schemaClosure(); + + $classifierTable = $schema->getTable('mail_classifiers'); + if (!$classifierTable->hasColumn('transformer_count')) { + $classifierTable->addColumn('transformer_count', Types::INTEGER, [ + 'notnull' => true, + 'default' => 0, + ]); + } + + return $schema; + } +} diff --git a/lib/Model/ClassifierPipeline.php b/lib/Model/ClassifierPipeline.php new file mode 100644 index 0000000000..39f7f7c259 --- /dev/null +++ b/lib/Model/ClassifierPipeline.php @@ -0,0 +1,57 @@ + + * + * @author Richard Steinmetz + * + * @license AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ * + */ + +namespace OCA\Mail\Model; + +use Rubix\ML\Estimator; +use Rubix\ML\Transformers\Transformer; + +class ClassifierPipeline { + private Estimator $estimator; + + /** @var Transformer[] */ + private array $transformers; + + /** + * @param Estimator $estimator + * @param Transformer[] $transformers + */ + public function __construct(Estimator $estimator, array $transformers) { + $this->estimator = $estimator; + $this->transformers = $transformers; + } + + public function getEstimator(): Estimator { + return $this->estimator; + } + + /** + * @return Transformer[] + */ + public function getTransformers(): array { + return $this->transformers; + } +} diff --git a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php index e44cfd7b69..46d48cb187 100644 --- a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php @@ -26,8 +26,15 @@ namespace OCA\Mail\Service\Classification\FeatureExtraction; class NewCompositeExtractor extends CompositeExtractor { + private SubjectExtractor $subjectExtractor; + public function __construct(VanillaCompositeExtractor $ex1, - SubjectAndPreviewTextExtractor $ex2) { + SubjectExtractor $ex2) { parent::__construct([$ex1, $ex2]); + $this->subjectExtractor = $ex2; + } + + public function getSubjectExtractor(): SubjectExtractor { + return $this->subjectExtractor; } } diff --git a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php similarity index 61% rename from lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php rename to lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index 859a92253e..9ebeb3f56a 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectAndPreviewTextExtractor.php +++ 
b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -27,49 +27,47 @@ use OCA\Mail\Account; use OCA\Mail\Db\Message; -use OCA\Mail\Db\StatisticsDao; -use OCA\Mail\Service\Classification\ImportanceClassifier; +use OCA\Mail\Exception\ServiceException; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Transformers\TSNE; -use Rubix\ML\Transformers\GaussianRandomProjector; -use Rubix\ML\Transformers\LinearDiscriminantAnalysis; use Rubix\ML\Transformers\MultibyteTextNormalizer; -use Rubix\ML\Transformers\PrincipalComponentAnalysis; -use Rubix\ML\Transformers\SparseRandomProjector; -use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function array_column; use function array_map; -class SubjectAndPreviewTextExtractor implements IExtractor { - private StatisticsDao $statisticsDao; +class SubjectExtractor implements IExtractor { private WordCountVectorizer $wordCountVectorizer; private Transformer $dimensionalReductionTransformer; private int $max = -1; - private array $senderCache = []; - - public function __construct(StatisticsDao $statisticsDao) { - $this->statisticsDao = $statisticsDao; - - // Limit vocabulary to limit ram usage. It takes about 5 GB of ram if an unbounded - // vocabulary is used (and a lot more time to compute). 
+ public function __construct() { + // Limit vocabulary to limit memory usage $vocabSize = 100; $this->wordCountVectorizer = new WordCountVectorizer($vocabSize); $this->dimensionalReductionTransformer = new TSNE((int)($vocabSize * 0.1)); } + public function getWordCountVectorizer(): WordCountVectorizer { + return $this->wordCountVectorizer; + } + + public function setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer): void { + $this->wordCountVectorizer = $wordCountVectorizer; + $this->limitFeatureSize(); + } + /** * @inheritDoc */ public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { + /** @var array> $data */ $data = array_map(static function(Message $message) { return [ - 'text' => ($message->getSubject() ?? '') . ' ' . ($message->getPreviewText() ?? ''), + 'text' => $message->getSubject() ?? '', 'label' => $message->getFlagImportant() ? 'i' : 'ni', ]; }, $messages); @@ -83,9 +81,7 @@ public function prepare(Account $account, array $incomingMailboxes, array $outgo ->apply($this->wordCountVectorizer) ->apply($this->dimensionalReductionTransformer); - // Limit feature vector length to actual vocabulary size - $vocab = $this->wordCountVectorizer->vocabularies()[0]; - $this->max = count($vocab); + $this->limitFeatureSize(); } /** @@ -98,14 +94,10 @@ public function extract(Message $message): array { } $email = $sender->getEmail(); - if (isset($this->senderCache[$email])) { - //return $this->senderCache[$email]; - } - // Build training data set $trainText = ($message->getSubject() ?? '') . ' ' . ($message->getPreviewText() ?? 
''); - $trainDataSet = Unlabeled::build([$trainText]) + $trainDataSet = Unlabeled::build([[$trainText]]) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) ->apply($this->dimensionalReductionTransformer); @@ -117,42 +109,14 @@ public function extract(Message $message): array { $textFeatures = $trainDataSet->sample(0); } - //$this->senderCache[$email] = $textFeatures; - return $textFeatures; } - private function getSubjects(): array { - return array_merge(...array_values($this->subjects)); - } - - private function getPreviewTexts(): array { - return array_merge(...array_values($this->previewTexts)); - } - - private function getSubjectsOfSender(string $email): array { - $concatSubjects = []; - foreach ($this->subjects as $sender => $subjects) { - if ($sender !== $email) { - continue; - } - - $concatSubjects[] = $subjects; - } - - return array_merge(...$concatSubjects); - } - - private function getPreviewTextsOfSender(string $email): array { - $concatPreviewTexts = []; - foreach ($this->previewTexts as $sender => $previewTexts) { - if ($sender !== $email) { - continue; - } - - $concatPreviewTexts[] = $previewTexts; - } - - return array_merge(...$concatPreviewTexts); + /** + * Limit feature vector length to actual vocabulary size. 
+ */ + private function limitFeatureSize(): void { + $vocab = $this->wordCountVectorizer->vocabularies()[0]; + $this->max = count($vocab); } } diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 50fe20a52c..49d0add0cb 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -20,28 +20,25 @@ use OCA\Mail\Exception\ClassifierTrainingException; use OCA\Mail\Exception\ServiceException; use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; use OCA\Mail\Support\PerformanceLogger; use OCA\Mail\Support\PerformanceLoggerTask; use OCP\AppFramework\Db\DoesNotExistException; -use PHPUnit\Framework\Constraint\Callback; +use Psr\Container\ContainerExceptionInterface; +use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; -use Rubix\ML\Classifiers\ClassificationTree; -use Rubix\ML\Classifiers\GaussianNB; use Rubix\ML\Classifiers\KNearestNeighbors; -use Rubix\ML\Classifiers\MultilayerPerceptron; -use Rubix\ML\Classifiers\RandomForest; use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; -use Rubix\ML\Kernels\Distance\Jaccard; use Rubix\ML\Kernels\Distance\Manhattan; use Rubix\ML\Learner; -use Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid; -use Rubix\ML\NeuralNet\Layers\Activation; -use Rubix\ML\NeuralNet\Layers\Dense; -use Rubix\ML\NeuralNet\Optimizers\Adam; use Rubix\ML\Persistable; +use Rubix\ML\Transformers\Transformer; +use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function array_column; use function array_combine; @@ -113,57 +110,36 @@ class ImportanceClassifier { 
/** @var ImportanceRulesClassifier */ private $rulesClassifier; - private LoggerInterface $logger; + private VanillaCompositeExtractor $vanillaExtractor; + private ContainerInterface $container; public function __construct(MailboxMapper $mailboxMapper, MessageMapper $messageMapper, - CompositeExtractor $extractor, PersistenceService $persistenceService, PerformanceLogger $performanceLogger, ImportanceRulesClassifier $rulesClassifier, - LoggerInterface $logger, - SubjectExtractor $subjectExtractor) { + VanillaCompositeExtractor $vanillaExtractor, + ContainerInterface $container) { $this->mailboxMapper = $mailboxMapper; $this->messageMapper = $messageMapper; $this->persistenceService = $persistenceService; $this->performanceLogger = $performanceLogger; $this->rulesClassifier = $rulesClassifier; + $this->vanillaExtractor = $vanillaExtractor; + $this->container = $container; } - private function filterMessageHasSenderEmail(Message $message): bool { - return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null; + private static function createDefaultEstimator(): KNearestNeighbors { + // A meta estimator was trained on the same data multiple times to average out the + // variance of the trained model. + // Parameters were chosen from the best configuration across 100 runs. + // Both variance (spread) and f1 score were considered. + // Note: Lower k values yield slightly higher f1 scores but show higher variances. 
+ return new KNearestNeighbors(15, true, new Manhattan()); } - /** - * @return Estimator&Learner&Persistable - */ - private function getDefaultEstimator(): Estimator { - return new GaussianNB(); - - /* - return new RandomForest( - new ClassificationTree(10, 1), - 10, - 0.2, - true, - ); - */ - - /* - return new MultilayerPerceptron( - [ - new Dense(1004), - new Activation(new Sigmoid()) - ], - 32, - null, - 1e-4, - 10, - ); - */ - - //return new KNearestNeighbors(5, true, new Jaccard()); - //return new KNearestNeighbors(2, false, new Manhattan()); + private function filterMessageHasSenderEmail(Message $message): bool { + return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null; } /** @@ -232,7 +208,7 @@ public function buildDataSet( * * @param Account $account * @param LoggerInterface $logger - * @param IExtractor $extractor + * @param ?IExtractor $extractor The extractor to use for feature extraction. If null, the default extractor will be used. * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. 
* @param bool $shuffleDataSet Shuffle the data set before training * @param bool $persist Persist the trained classifier to use it for message classification @@ -242,13 +218,21 @@ public function buildDataSet( public function train( Account $account, LoggerInterface $logger, - IExtractor $extractor, + ?IExtractor $extractor = null, ?Closure $estimator = null, bool $shuffleDataSet = false, bool $persist = true, ): void { $perf = $this->performanceLogger->start('importance classifier training'); + if ($extractor === null) { + try { + $extractor = $this->container->get(NewCompositeExtractor::class); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException('Default extractor is not available', 0, $e); + } + } + $dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet); if ($dataSet === null) { return; @@ -258,6 +242,7 @@ public function train( $account, $logger, $dataSet, + $extractor, $estimator, $perf, $persist, @@ -270,7 +255,9 @@ public function train( * @param Account $account * @param LoggerInterface $logger * @param array $dataSet Training data set built by buildDataSet() + * @param IExtractor $extractor Extractor used to extract the given data set * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. 
+ * @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task * @param bool $persist Persist the trained classifier to use it for message classification * * @throws ServiceException @@ -279,18 +266,19 @@ public function trainWithCustomDataSet( Account $account, LoggerInterface $logger, array $dataSet, - ?Closure $estimator = null, + IExtractor $extractor, + ?Closure $estimator, ?PerformanceLoggerTask $perf = null, bool $persist = true, ): void { + $perf ??= $this->performanceLogger->start('importance classifier training'); + if ($estimator === null) { - $estimator = function () { - return $this->getDefaultEstimator(); + $estimator = static function() { + return self::createDefaultEstimator(); }; } - $perf ??= $this->performanceLogger->start('importance classifier training'); - /** * How many of the most recent messages are excluded from training? */ @@ -347,9 +335,17 @@ public function trainWithCustomDataSet( $this->trainClassifier($persistedEstimator, $dataSet); $perf->step("train classifier with full data set"); + // Extract persisted transformers of the subject extractor. + // Is a bit hacky but a full abstraction would be overkill. 
+ /** @var (Transformer&Persistable)[] $transformers */ + $transformers = []; + if ($extractor instanceof NewCompositeExtractor) { + $transformers[] = $extractor->getSubjectExtractor()->getWordCountVectorizer(); + } + $classifier->setAccountId($account->getId()); $classifier->setDuration($perf->end()); - $this->persistenceService->persist($classifier, $persistedEstimator); + $this->persistenceService->persist($classifier, $persistedEstimator, $transformers); $logger->debug("classifier {$classifier->getId()} persisted"); } } @@ -361,8 +357,6 @@ public function trainWithCustomDataSet( * @return Mailbox[] */ private function getIncomingMailboxes(Account $account): array { - return [$this->mailboxMapper->find($account, 'INBOX')]; - /* return array_filter($this->mailboxMapper->findAll($account), static function (Mailbox $mailbox) { foreach (self::EXEMPT_FROM_TRAINING as $excluded) { if ($mailbox->isSpecialUse($excluded)) { @@ -371,7 +365,6 @@ private function getIncomingMailboxes(Account $account): array { } return true; }); - */ } /** @@ -412,7 +405,7 @@ private function getFeaturesAndImportance(Account $account, IExtractor $extractor): array { $extractor->prepare($account, $incomingMailboxes, $outgoingMailboxes, $messages); - return array_map(function (Message $message) use ($extractor) { + return array_map(static function (Message $message) use ($extractor) { $sender = $message->getFrom()->first(); if ($sender === null) { throw new RuntimeException('This should not happen'); @@ -429,21 +422,25 @@ private function getFeaturesAndImportance(Account $account, /** * @param Account $account * @param Message[] $messages + * @param LoggerInterface $logger * * @return bool[] + * * @throws ServiceException */ - public function classifyImportance(Account $account, array $messages): array { - $estimator = null; + public function classifyImportance(Account $account, + array $messages, + LoggerInterface $logger): array { + $pipeline = null; try { - $estimator = 
$this->persistenceService->loadLatest($account); + $pipeline = $this->persistenceService->loadLatest($account); } catch (ServiceException $e) { - $this->logger->warning('Failed to load importance classifier: ' . $e->getMessage(), [ + $logger->warning('Failed to load importance classifier: ' . $e->getMessage(), [ 'exception' => $e, ]); } - if ($estimator === null) { + if ($pipeline === null) { $predictions = $this->rulesClassifier->classifyImportance( $account, $this->getIncomingMailboxes($account), @@ -461,13 +458,29 @@ public function classifyImportance(Account $account, array $messages): array { } $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']); + // Load persisted transformers of the subject extractor. + // Is a bit hacky but a full abstraction would be overkill. + $transformers = $pipeline->getTransformers(); + $wordCountVectorizer = $transformers[0]; + if (!($wordCountVectorizer instanceof WordCountVectorizer)) { + throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . 
$wordCountVectorizer::class); + } + + $subjectExtractor = new SubjectExtractor(); + $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); + $extractor = new NewCompositeExtractor( + $this->vanillaExtractor, + $subjectExtractor, + ); + $features = $this->getFeaturesAndImportance( $account, $this->getIncomingMailboxes($account), $this->getOutgoingMailboxes($account), - $messagesWithSender + $messagesWithSender, + $extractor ); - $predictions = $estimator->predict( + $predictions = $pipeline->getEstimator()->predict( Unlabeled::build(array_column($features, 'features')) ); return array_combine( @@ -491,9 +504,9 @@ private function trainClassifier(Learner $classifier, array $trainingSet): void * @param Estimator $estimator * @param array $trainingSet * @param array $validationSet + * @param LoggerInterface $logger * * @return Classifier - * @throws ClassifierTrainingException */ private function validateClassifier(Estimator $estimator, array $trainingSet, @@ -512,9 +525,6 @@ private function validateClassifier(Estimator $estimator, $recallImportant = $report['classes'][self::LABEL_IMPORTANT]['recall'] ?? 0; $precisionImportant = $report['classes'][self::LABEL_IMPORTANT]['precision'] ?? 0; $f1ScoreImportant = $report['classes'][self::LABEL_IMPORTANT]['f1 score'] ?? 0; - //$recallImportant = $report['overall']['recall'] ?? 0; - //$precisionImportant = $report['overall']['precision'] ?? 0; - //$f1ScoreImportant = $report['overall']['f1 score'] ?? 
0; /** * What we care most is the percentage of messages classified as important in relation to the truly important messages diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index 5cd73c09c6..5417a16ba2 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -9,26 +9,33 @@ namespace OCA\Mail\Service\Classification; +use OCA\DAV\Connector\Sabre\File; use OCA\Mail\Account; use OCA\Mail\AppInfo\Application; use OCA\Mail\Db\Classifier; use OCA\Mail\Db\ClassifierMapper; use OCA\Mail\Db\MailAccountMapper; use OCA\Mail\Exception\ServiceException; +use OCA\Mail\Model\ClassifierPipeline; use OCP\App\IAppManager; use OCP\AppFramework\Db\DoesNotExistException; use OCP\AppFramework\Utility\ITimeFactory; +use OCP\DB\Exception; +use OCP\Files; use OCP\Files\IAppData; use OCP\Files\NotFoundException; use OCP\Files\NotPermittedException; use OCP\ICacheFactory; use OCP\ITempManager; use Psr\Log\LoggerInterface; +use Rubix\ML\Encoding; use Rubix\ML\Estimator; use Rubix\ML\Learner; use Rubix\ML\Persistable; use Rubix\ML\PersistentModel; use Rubix\ML\Persisters\Filesystem; +use Rubix\ML\Serializers\RBX; +use Rubix\ML\Transformers\Transformer; use RuntimeException; use function file_get_contents; use function file_put_contents; @@ -81,14 +88,17 @@ public function __construct(ClassifierMapper $mapper, } /** - * Persist the classifier data to the database and the estimator to storage + * Persist the classifier data to the database, the estimator and its transformers to storage * * @param Classifier $classifier * @param Learner&Persistable $estimator + * @param (Transformer&Persistable)[] $transformers * * @throws ServiceException */ - public function persist(Classifier $classifier, Learner $estimator): void { + public function persist(Classifier $classifier, + Learner $estimator, + array $transformers): void { /* * First we have to insert the row to get the 
unique ID, but disable * it until the model is persisted as well. Otherwise another process @@ -128,10 +138,50 @@ public function persist(Classifier $classifier, Learner $estimator): void { $file = $folder->newFile((string)$classifier->getId()); $file->putContent($serializedClassifier); $this->logger->debug('Serialized classifier written to app data'); - } catch (NotPermittedException $e) { + } catch (NotPermittedException | NotFoundException $e) { throw new ServiceException('Could not create classifiers directory: ' . $e->getMessage(), 0, $e); } + /* + * Then we serialize the transformer pipeline to temporary files + */ + $transformerIndex = 0; + $serializer = new RBX(); + foreach ($transformers as $transformer) { + $tmpPath = $this->tempManager->getTemporaryFile(); + try { + /** + * This is how to serialize a transformer according to the official docs. + * PersistentModel can only be used on Learners which transformers don't implement. + * + * Ref https://docs.rubixml.com/2.0/model-persistence.html#persisting-transformers + * + * @psalm-suppress InternalMethod + */ + $serializer->serialize($transformer)->saveTo(new Filesystem($tmpPath)); + $serializedTransformer = file_get_contents($tmpPath); + $this->logger->debug('Serialized transformer written to tmp file (' . strlen($serializedTransformer) . 'B'); + } catch (RuntimeException $e) { + throw new ServiceException("Could not serialize transformer: " . $e->getMessage(), 0, $e); + } + + try { + $file = $folder->newFile("{$classifier->getId()}_t$transformerIndex"); + $file->putContent($serializedTransformer); + $this->logger->debug("Serialized transformer $transformerIndex written to app data"); + } catch (NotPermittedException | NotFoundException $e) { + throw new ServiceException( + "Failed to persist transformer $transformerIndex: " . 
$e->getMessage(), + 0, + $e + ); + } + + $transformerIndex++; + } + + $classifier->setTransformerCount($transformerIndex); + /* * Now we set the model active so it can be used by the next request */ @@ -142,29 +192,34 @@ public function persist(Classifier $classifier, Learner $estimator): void { /** * @param Account $account * - * @return Estimator|null + * @return ?ClassifierPipeline + * * @throws ServiceException */ - public function loadLatest(Account $account): ?Estimator { + public function loadLatest(Account $account): ?ClassifierPipeline { try { $latestModel = $this->mapper->findLatest($account->getId()); } catch (DoesNotExistException $e) { return null; } - return $this->load($latestModel->getId()); + return $this->load($latestModel); } /** - * @param int $id + * Load an estimator and its transformers of a classifier from storage + * + * @param Classifier $classifier + * @return ClassifierPipeline * - * @return Estimator * @throws ServiceException */ - public function load(int $id): Estimator { - $cached = $this->getCached($id); + public function load(Classifier $classifier): ClassifierPipeline { + $id = $classifier->getId(); + $cached = $this->getCached($classifier->getId(), $classifier->getTransformerCount()); if ($cached !== null) { $this->logger->debug("Using cached serialized classifier $id"); - $serialized = $cached; + $serialized = $cached[0]; + $serializedTransformers = array_slice($cached, 1); } else { $this->logger->debug('Loading serialized classifier from app data'); try { @@ -184,19 +239,56 @@ public function load(int $id): Estimator { $size = strlen($serialized); $this->logger->debug("Serialized classifier loaded (size=$size)"); - $this->cache($id, $serialized); + $serializedTransformers = []; + for ($i = 0; $i < $classifier->getTransformerCount(); $i++) { + try { + $transformerFile = $modelsFolder->getFile("{$id}_t$i"); + } catch (NotFoundException $e) { + $this->logger->debug("Could not load transformer $i of classifier $id: " . 
$e->getMessage()); + throw new ServiceException("Could not load transformer $i of classifier $id: " . $e->getMessage(), 0, $e); + } + + try { + $serializedTransformer = $transformerFile->getContent(); + } catch (NotFoundException | NotPermittedException $e) { + $this->logger->debug("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage()); + throw new ServiceException("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage(), 0, $e); + } + $size = strlen($serializedTransformer); + $this->logger->debug("Serialized transformer $i loaded (size=$size)"); + $serializedTransformers[] = $serializedTransformer; + } + + $this->cache($id, $serialized, $serializedTransformers); } $tmpPath = $this->tempManager->getTemporaryFile(); file_put_contents($tmpPath, $serialized); - try { $estimator = PersistentModel::load(new Filesystem($tmpPath)); } catch (RuntimeException $e) { throw new ServiceException("Could not deserialize persisted classifier $id: " . $e->getMessage(), 0, $e); } - return $estimator; + $transformers = array_map(function(string $serializedTransformer) use ($id) { + $serializer = new RBX(); + $tmpPath = $this->tempManager->getTemporaryFile(); + file_put_contents($tmpPath, $serializedTransformer); + try { + $persister = new Filesystem($tmpPath); + $transformer = $persister->load()->deserializeWith($serializer); + } catch (RuntimeException $e) { + throw new ServiceException("Could not deserialize persisted transformer of classifier $id: " . $e->getMessage(), 0, $e); + } + + if (!($transformer instanceof Transformer)) { + throw new ServiceException("Transformer of classifier $id is not a transformer: Got " . 
$transformer::class); + } + + return $transformer; + }, $serializedTransformers); + + return new ClassifierPipeline($estimator, $transformers); } public function cleanUp(): void { @@ -239,22 +331,41 @@ private function getCacheKey(int $id): string { return "mail_classifier_$id"; } - private function getCached(int $id): ?string { + private function getTransformerCacheKey(int $id, int $index): string { + return $this->getCacheKey($id) . "_transformer_$index"; + } + + /** + * @param int $id + * @param int $transformerCount + * + * @return string[]|null Array of serialized classifier and transformers + */ + private function getCached(int $id, int $transformerCount): ?array { if (!$this->cacheFactory->isLocalCacheAvailable()) { return null; } $cache = $this->cacheFactory->createLocal(); - return $cache->get( - $this->getCacheKey($id) - ); + $values = []; + $values[] = $cache->get($this->getCacheKey($id)); + for ($i = 0; $i < $transformerCount; $i++) { + $values[] = $cache->get($this->getTransformerCacheKey($id, $i)); + } + return $values; } - private function cache(int $id, string $serialized): void { + private function cache(int $id, string $serialized, array $serializedTransformers): void { if (!$this->cacheFactory->isLocalCacheAvailable()) { return; } $cache = $this->cacheFactory->createLocal(); $cache->set($this->getCacheKey($id), $serialized); + + $transformerIndex = 0; + foreach ($serializedTransformers as $transformer) { + $cache->set($this->getTransformerCacheKey($id, $transformerIndex), $transformer); + $transformerIndex++; + } } } diff --git a/lib/Service/Sync/ImapToDbSynchronizer.php b/lib/Service/Sync/ImapToDbSynchronizer.php index 76d68c15db..0054bc54b2 100644 --- a/lib/Service/Sync/ImapToDbSynchronizer.php +++ b/lib/Service/Sync/ImapToDbSynchronizer.php @@ -29,6 +29,7 @@ use OCA\Mail\Exception\UidValidityChangedException; use OCA\Mail\IMAP\IMAPClientFactory; use OCA\Mail\IMAP\MessageMapper as ImapMessageMapper; +use OCA\Mail\IMAP\PreviewEnhancer; use 
OCA\Mail\IMAP\Sync\Request; use OCA\Mail\IMAP\Sync\Synchronizer; use OCA\Mail\Model\IMAPMessage; @@ -72,6 +73,8 @@ class ImapToDbSynchronizer { /** @var IMailManager */ private $mailManager; + private PreviewEnhancer $previewEnhancer; + public function __construct(DatabaseMessageMapper $dbMapper, IMAPClientFactory $clientFactory, ImapMessageMapper $imapMapper, @@ -80,7 +83,8 @@ public function __construct(DatabaseMessageMapper $dbMapper, IEventDispatcher $dispatcher, PerformanceLogger $performanceLogger, LoggerInterface $logger, - IMailManager $mailManager) { + IMailManager $mailManager, + PreviewEnhancer $previewEnhancer) { $this->dbMapper = $dbMapper; $this->clientFactory = $clientFactory; $this->imapMapper = $imapMapper; @@ -90,6 +94,7 @@ public function __construct(DatabaseMessageMapper $dbMapper, $this->performanceLogger = $performanceLogger; $this->logger = $logger; $this->mailManager = $mailManager; + $this->previewEnhancer = $previewEnhancer; } /** @@ -105,9 +110,9 @@ public function syncAccount(Account $account, $snoozeMailboxId = $account->getMailAccount()->getSnoozeMailboxId(); $sentMailboxId = $account->getMailAccount()->getSentMailboxId(); $trashRetentionDays = $account->getMailAccount()->getTrashRetentionDays(); - + $client = $this->clientFactory->getClient($account); - + foreach ($this->mailboxMapper->findAll($account) as $mailbox) { $syncTrash = $trashMailboxId === $mailbox->getId() && $trashRetentionDays !== null; $syncSnooze = $snoozeMailboxId === $mailbox->getId(); @@ -131,7 +136,7 @@ public function syncAccount(Account $account, $rebuildThreads = true; } } - + $client->logout(); $this->dispatcher->dispatchTyped( From f68501bc947376ab5fa8b0a8c64c5447aaa24cc2 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Fri, 24 Mar 2023 14:27:48 +0100 Subject: [PATCH 18/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- lib/Command/PredictImportance.php | 1 - lib/Command/TrainAccount.php | 2 +- lib/Migration/Version3100Date20230324113141.php | 1 - .../Classification/FeatureExtraction/SubjectExtractor.php | 3 +-- lib/Service/Classification/ImportanceClassifier.php | 2 +- lib/Service/Classification/PersistenceService.php | 6 ++---- 6 files changed, 5 insertions(+), 10 deletions(-) diff --git a/lib/Command/PredictImportance.php b/lib/Command/PredictImportance.php index f737f1d4a1..eb9d49e802 100644 --- a/lib/Command/PredictImportance.php +++ b/lib/Command/PredictImportance.php @@ -12,7 +12,6 @@ use OCA\Mail\AddressList; use OCA\Mail\Db\Message; use OCA\Mail\Service\AccountService; -use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index 2729d598b8..33ce44371d 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -156,7 +156,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int ); $json = json_encode($dataSet, JSON_THROW_ON_ERROR); file_put_contents($saveDataPath, $json); - } else if ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { + } elseif ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { $json = file_get_contents($loadDataPath); $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR); } diff --git a/lib/Migration/Version3100Date20230324113141.php b/lib/Migration/Version3100Date20230324113141.php index 81e438f364..2181e24a73 100644 --- a/lib/Migration/Version3100Date20230324113141.php +++ b/lib/Migration/Version3100Date20230324113141.php @@ -33,7 +33,6 @@ use OCP\Migration\SimpleMigrationStep; class Version3100Date20230324113141 extends SimpleMigrationStep { - /** * @param IOutput 
$output * @param Closure(): ISchemaWrapper $schemaClosure diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index 9ebeb3f56a..1ca3fa74ae 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -27,7 +27,6 @@ use OCA\Mail\Account; use OCA\Mail\Db\Message; -use OCA\Mail\Exception\ServiceException; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Transformers\TSNE; @@ -65,7 +64,7 @@ public function setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer) */ public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { /** @var array> $data */ - $data = array_map(static function(Message $message) { + $data = array_map(static function (Message $message) { return [ 'text' => $message->getSubject() ?? '', 'label' => $message->getFlagImportant() ? 
'i' : 'ni', diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 49d0add0cb..122c39827b 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -274,7 +274,7 @@ public function trainWithCustomDataSet( $perf ??= $this->performanceLogger->start('importance classifier training'); if ($estimator === null) { - $estimator = static function() { + $estimator = static function () { return self::createDefaultEstimator(); }; } diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index 5417a16ba2..2815e05ad8 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -20,7 +20,6 @@ use OCP\App\IAppManager; use OCP\AppFramework\Db\DoesNotExistException; use OCP\AppFramework\Utility\ITimeFactory; -use OCP\DB\Exception; use OCP\Files; use OCP\Files\IAppData; use OCP\Files\NotFoundException; @@ -28,7 +27,6 @@ use OCP\ICacheFactory; use OCP\ITempManager; use Psr\Log\LoggerInterface; -use Rubix\ML\Encoding; use Rubix\ML\Estimator; use Rubix\ML\Learner; use Rubix\ML\Persistable; @@ -270,13 +268,13 @@ public function load(Classifier $classifier): ClassifierPipeline { throw new ServiceException("Could not deserialize persisted classifier $id: " . 
$e->getMessage(), 0, $e); } - $transformers = array_map(function(string $serializedTransformer) use ($id) { + $transformers = array_map(function (string $serializedTransformer) use ($id) { $serializer = new RBX(); $tmpPath = $this->tempManager->getTemporaryFile(); file_put_contents($tmpPath, $serializedTransformer); try { $persister = new Filesystem($tmpPath); - $transformer = $persister->load()->deserializeWith($serializer); + $transformer = $persister->load()->deserializeWith($serializer); } catch (RuntimeException $e) { throw new ServiceException("Could not deserialize persisted transformer of classifier $id: " . $e->getMessage(), 0, $e); } From 7127cf7cc3928f3580a56dbc973bd6398f358090 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Fri, 24 Mar 2023 15:24:22 +0100 Subject: [PATCH 19/37] fixup! feat(importance-classifier): Reduce text feature vector --- lib/Command/PredictImportance.php | 20 +++++++------------ lib/Command/RunMetaEstimator.php | 8 ++++++++ .../Classification/PersistenceService.php | 8 +++++++- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/lib/Command/PredictImportance.php b/lib/Command/PredictImportance.php index eb9d49e802..fcaa458ec2 100644 --- a/lib/Command/PredictImportance.php +++ b/lib/Command/PredictImportance.php @@ -16,7 +16,6 @@ use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; use OCP\IConfig; -use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; @@ -27,47 +26,41 @@ class PredictImportance extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; public const ARGUMENT_SENDER = 'sender'; + public const ARGUMENT_SUBJECT = 'subject'; private AccountService $accountService; private ImportanceClassifier $classifier; private IConfig $config; private LoggerInterface $logger; - private ContainerInterface $container; public function __construct(AccountService 
$service, ImportanceClassifier $classifier, IConfig $config, - LoggerInterface $logger, - ContainerInterface $container) { + LoggerInterface $logger) { parent::__construct(); $this->accountService = $service; $this->classifier = $classifier; $this->logger = $logger; $this->config = $config; - $this->container = $container; } - /** - * @return void - */ - protected function configure() { + protected function configure(): void { $this->setName('mail:predict-importance'); $this->setDescription('Predict importance of an incoming message'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); $this->addArgument(self::ARGUMENT_SENDER, InputArgument::REQUIRED); + $this->addArgument(self::ARGUMENT_SUBJECT, InputArgument::OPTIONAL); } - public function isEnabled() { + public function isEnabled(): bool { return $this->config->getSystemValueBool('debug'); } - /** - * @return int - */ protected function execute(InputInterface $input, OutputInterface $output): int { $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); $sender = $input->getArgument(self::ARGUMENT_SENDER); + $subject = $input->getArgument(self::ARGUMENT_SUBJECT) ?? 
''; $consoleLogger = new ConsoleLoggerDecorator( $this->logger, @@ -83,6 +76,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $fakeMessage = new Message(); $fakeMessage->setUid(0); $fakeMessage->setFrom(AddressList::parse("Name <$sender>")); + $fakeMessage->setSubject($subject); [$prediction] = $this->classifier->classifyImportance( $account, [$fakeMessage], diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index ea7de222f6..2cfbc02702 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -31,6 +31,7 @@ use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; +use OCP\IConfig; use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Rubix\ML\Backends\Amp; @@ -56,12 +57,14 @@ class RunMetaEstimator extends Command { private LoggerInterface $logger; private ImportanceClassifier $classifier; private ContainerInterface $container; + private IConfig $config; public function __construct( AccountService $accountService, LoggerInterface $logger, ImportanceClassifier $classifier, ContainerInterface $container, + IConfig $config, ) { parent::__construct(); @@ -69,6 +72,7 @@ public function __construct( $this->logger = $logger; $this->classifier = $classifier; $this->container = $container; + $this->config = $config; } protected function configure(): void { @@ -84,6 +88,10 @@ protected function configure(): void { ); } + public function isEnabled(): bool { + return $this->config->getSystemValueBool('debug'); + } + protected function execute(InputInterface $input, OutputInterface $output): int { $accountId = (int)$input->getArgument(self::ARGUMENT_ACCOUNT_ID); $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index 2815e05ad8..dbf659ea03 100644 --- 
a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -337,7 +337,7 @@ private function getTransformerCacheKey(int $id, int $index): string { * @param int $id * @param int $transformerCount * - * @return string[]|null Array of serialized classifier and transformers + * @return (?string)[]|null Array of serialized classifier and transformers */ private function getCached(int $id, int $transformerCount): ?array { if (!$this->cacheFactory->isLocalCacheAvailable()) { @@ -350,6 +350,12 @@ private function getCached(int $id, int $transformerCount): ?array { for ($i = 0; $i < $transformerCount; $i++) { $values[] = $cache->get($this->getTransformerCacheKey($id, $i)); } + + // Only return cached values if estimator and all transformers are available + if (in_array(null, $values, true)) { + return null; + } + return $values; } From c8c214c15e5470f6c7a9d4f5caadc08dff53019d Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 28 Mar 2023 13:54:27 +0200 Subject: [PATCH 20/37] fixup! 
feat(importance-classifier): Reduce text feature vector --- .../FeatureExtraction/SubjectExtractor.php | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index 1ca3fa74ae..15c00bb3fe 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -29,6 +29,7 @@ use OCA\Mail\Db\Message; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; +use Rubix\ML\Transformers\MinMaxNormalizer; use Rubix\ML\Transformers\TSNE; use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\Transformer; @@ -40,6 +41,7 @@ class SubjectExtractor implements IExtractor { private WordCountVectorizer $wordCountVectorizer; private Transformer $dimensionalReductionTransformer; + private Transformer $normalizer; private int $max = -1; public function __construct() { @@ -48,6 +50,7 @@ public function __construct() { $this->wordCountVectorizer = new WordCountVectorizer($vocabSize); $this->dimensionalReductionTransformer = new TSNE((int)($vocabSize * 0.1)); + $this->normalizer = new MinMaxNormalizer(); } public function getWordCountVectorizer(): WordCountVectorizer { @@ -78,7 +81,8 @@ public function prepare(Account $account, array $incomingMailboxes, array $outgo ) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->dimensionalReductionTransformer); + ->apply($this->dimensionalReductionTransformer) + ->apply($this->normalizer); $this->limitFeatureSize(); } @@ -99,7 +103,8 @@ public function extract(Message $message): array { $trainDataSet = Unlabeled::build([[$trainText]]) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->dimensionalReductionTransformer); + ->apply($this->dimensionalReductionTransformer) + ->apply($this->normalizer); // Use 
zeroed vector if no features could be extracted if ($trainDataSet->numFeatures() === 0) { @@ -108,6 +113,8 @@ public function extract(Message $message): array { $textFeatures = $trainDataSet->sample(0); } + var_dump($textFeatures); + return $textFeatures; } From 51f31bf5dc692e5ed7e6abd98d8c7dc417552971 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Thu, 30 Mar 2023 09:01:42 +0200 Subject: [PATCH 21/37] fixup! fixup! feat(importance-classifier): Reduce text feature vector --- .../Classification/FeatureExtraction/SubjectExtractor.php | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index 15c00bb3fe..df22f4038a 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -95,10 +95,9 @@ public function extract(Message $message): array { if ($sender === null) { throw new RuntimeException("This should not happen"); } - $email = $sender->getEmail(); // Build training data set - $trainText = ($message->getSubject() ?? '') . ' ' . ($message->getPreviewText() ?? ''); + $trainText = $message->getSubject() ?? 
''; $trainDataSet = Unlabeled::build([[$trainText]]) ->apply(new MultibyteTextNormalizer()) From 3bc398b7f1ac1d2fba21db51b9c42b73eee77181 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Thu, 30 Mar 2023 10:08:31 +0200 Subject: [PATCH 22/37] Try wcv -> tfidf pipeline --- .../FeatureExtraction/SubjectExtractor.php | 21 ++++++++++++------- .../Classification/ImportanceClassifier.php | 7 +++++-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index df22f4038a..a5320255b1 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -30,6 +30,8 @@ use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Transformers\MinMaxNormalizer; +use Rubix\ML\Transformers\PrincipalComponentAnalysis; +use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\TSNE; use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\Transformer; @@ -42,15 +44,17 @@ class SubjectExtractor implements IExtractor { private WordCountVectorizer $wordCountVectorizer; private Transformer $dimensionalReductionTransformer; private Transformer $normalizer; + private Transformer $tfidf; private int $max = -1; public function __construct() { // Limit vocabulary to limit memory usage - $vocabSize = 100; + $vocabSize = 500; $this->wordCountVectorizer = new WordCountVectorizer($vocabSize); - $this->dimensionalReductionTransformer = new TSNE((int)($vocabSize * 0.1)); - $this->normalizer = new MinMaxNormalizer(); + $this->tfidf = new TfIdfTransformer(); + //$this->dimensionalReductionTransformer = new PrincipalComponentAnalysis((int)($vocabSize * 0.1)); + //$this->normalizer = new MinMaxNormalizer(); } public function getWordCountVectorizer(): WordCountVectorizer { @@ -81,8 +85,8 @@ public function prepare(Account 
$account, array $incomingMailboxes, array $outgo ) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->dimensionalReductionTransformer) - ->apply($this->normalizer); + ->apply($this->tfidf) + ;//->apply($this->dimensionalReductionTransformer); $this->limitFeatureSize(); } @@ -102,8 +106,8 @@ public function extract(Message $message): array { $trainDataSet = Unlabeled::build([[$trainText]]) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->dimensionalReductionTransformer) - ->apply($this->normalizer); + ->apply($this->tfidf) + ;//->apply($this->dimensionalReductionTransformer); // Use zeroed vector if no features could be extracted if ($trainDataSet->numFeatures() === 0) { @@ -112,7 +116,7 @@ public function extract(Message $message): array { $textFeatures = $trainDataSet->sample(0); } - var_dump($textFeatures); + //var_dump($textFeatures); return $textFeatures; } @@ -123,5 +127,6 @@ public function extract(Message $message): array { private function limitFeatureSize(): void { $vocab = $this->wordCountVectorizer->vocabularies()[0]; $this->max = count($vocab); + echo("WCF vocab size: {$this->max}\n"); } } diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 122c39827b..8aeed8d932 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -93,7 +93,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private const MAX_TRAINING_SET_SIZE = 350; + private const MAX_TRAINING_SET_SIZE = 333; /** @var MailboxMapper */ private $mailboxMapper; @@ -411,8 +411,11 @@ private function getFeaturesAndImportance(Account $account, throw new RuntimeException('This should not happen'); } + $features = $extractor->extract($message); + //var_dump($features); + return [ - 'features' => $extractor->extract($message), + 
'features' => $features, 'label' => $message->getFlagImportant() ? self::LABEL_IMPORTANT : self::LABEL_NOT_IMPORTANT, 'sender' => $sender->getEmail(), ]; From fed201151c4580fc033bd984ae4a16111362a624 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Mon, 15 May 2023 16:36:49 +0200 Subject: [PATCH 23/37] Fix transformer persistence --- .../FeatureExtraction/SubjectExtractor.php | 8 +++++ .../Classification/ImportanceClassifier.php | 36 +++++++++++++------ 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index a5320255b1..c5f299d475 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -66,6 +66,14 @@ public function setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer) $this->limitFeatureSize(); } + public function getTfidf(): Transformer { + return $this->tfidf; + } + + public function setTfidf(TfIdfTransformer $tfidf): void { + $this->tfidf = $tfidf; + } + /** * @inheritDoc */ diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 8aeed8d932..090ccfed18 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -37,6 +37,7 @@ use Rubix\ML\Kernels\Distance\Manhattan; use Rubix\ML\Learner; use Rubix\ML\Persistable; +use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; @@ -341,6 +342,7 @@ public function trainWithCustomDataSet( $transformers = []; if ($extractor instanceof NewCompositeExtractor) { $transformers[] = $extractor->getSubjectExtractor()->getWordCountVectorizer(); + $transformers[] = $extractor->getSubjectExtractor()->getTfidf(); } 
$classifier->setAccountId($account->getId()); @@ -358,12 +360,16 @@ public function trainWithCustomDataSet( */ private function getIncomingMailboxes(Account $account): array { return array_filter($this->mailboxMapper->findAll($account), static function (Mailbox $mailbox) { + return $mailbox->isInbox(); + + /* foreach (self::EXEMPT_FROM_TRAINING as $excluded) { if ($mailbox->isSpecialUse($excluded)) { return false; } } return true; + */ }); } @@ -464,17 +470,27 @@ public function classifyImportance(Account $account, // Load persisted transformers of the subject extractor. // Is a bit hacky but a full abstraction would be overkill. $transformers = $pipeline->getTransformers(); - $wordCountVectorizer = $transformers[0]; - if (!($wordCountVectorizer instanceof WordCountVectorizer)) { - throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . $wordCountVectorizer::class); - } + if (count($transformers) === 2) { + $wordCountVectorizer = $transformers[0]; + if (!($wordCountVectorizer instanceof WordCountVectorizer)) { + throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . $wordCountVectorizer::class); + } + $tfidfTransformer = $transformers[1]; + if (!($tfidfTransformer instanceof TfIdfTransformer)) { + throw new ServiceException("Failed to load persisted transformer: Expected " . TfIdfTransformer::class . ", got" . 
$tfidfTransformer::class); + } - $subjectExtractor = new SubjectExtractor(); - $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); - $extractor = new NewCompositeExtractor( - $this->vanillaExtractor, - $subjectExtractor, - ); + $subjectExtractor = new SubjectExtractor(); + $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); + $subjectExtractor->setTfidf($tfidfTransformer); + $extractor = new NewCompositeExtractor( + $this->vanillaExtractor, + $subjectExtractor, + ); + } else { + $logger->warning('Falling back to vanilla feature extractor'); + $extractor = $this->vanillaExtractor; + } $features = $this->getFeaturesAndImportance( $account, From e2c057c08da188e5470be2cd1cafebd01de3a4cf Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Mon, 15 May 2023 16:37:55 +0200 Subject: [PATCH 24/37] Refactor classifcation of new messages --- lib/AppInfo/Application.php | 2 +- .../NewMessageClassificationListener.php | 121 ----------------- .../Classification/NewMessagesClassifier.php | 125 ++++++++++++++++++ lib/Service/Sync/ImapToDbSynchronizer.php | 32 ++++- 4 files changed, 156 insertions(+), 124 deletions(-) delete mode 100644 lib/Listener/NewMessageClassificationListener.php create mode 100644 lib/Service/Classification/NewMessagesClassifier.php diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index 48c85854f2..b4499fec62 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -29,7 +29,6 @@ use OCA\Mail\Events\MessageDeletedEvent; use OCA\Mail\Events\MessageFlaggedEvent; use OCA\Mail\Events\MessageSentEvent; -use OCA\Mail\Events\NewMessagesSynchronized; use OCA\Mail\Events\OutboxMessageCreatedEvent; use OCA\Mail\Events\SynchronizationEvent; use OCA\Mail\HordeTranslationHandler; @@ -127,6 +126,7 @@ public function register(IRegistrationContext $context): void { $context->registerEventListener(NewMessagesSynchronized::class, NewMessageClassificationListener::class); 
$context->registerEventListener(NewMessagesSynchronized::class, MessageKnownSinceListener::class); $context->registerEventListener(NewMessagesSynchronized::class, NewMessagesNotifier::class); + $context->registerEventListener(MessageSentEvent::class, SaveSentMessageListener::class); $context->registerEventListener(SynchronizationEvent::class, AccountSynchronizedThreadUpdaterListener::class); $context->registerEventListener(UserDeletedEvent::class, UserDeletedListener::class); $context->registerEventListener(NewMessagesSynchronized::class, FollowUpClassifierListener::class); diff --git a/lib/Listener/NewMessageClassificationListener.php b/lib/Listener/NewMessageClassificationListener.php deleted file mode 100644 index 51d269c958..0000000000 --- a/lib/Listener/NewMessageClassificationListener.php +++ /dev/null @@ -1,121 +0,0 @@ - - */ -class NewMessageClassificationListener implements IEventListener { - private const EXEMPT_FROM_CLASSIFICATION = [ - Horde_Imap_Client::SPECIALUSE_ARCHIVE, - Horde_Imap_Client::SPECIALUSE_DRAFTS, - Horde_Imap_Client::SPECIALUSE_JUNK, - Horde_Imap_Client::SPECIALUSE_SENT, - Horde_Imap_Client::SPECIALUSE_TRASH, - ]; - - /** @var ImportanceClassifier */ - private $classifier; - - /** @var TagMapper */ - private $tagMapper; - - /** @var LoggerInterface */ - private $logger; - - /** @var IMailManager */ - private $mailManager; - - private ClassificationSettingsService $classificationSettingsService; - - public function __construct(ImportanceClassifier $classifier, - TagMapper $tagMapper, - LoggerInterface $logger, - IMailManager $mailManager, - ClassificationSettingsService $classificationSettingsService) { - $this->classifier = $classifier; - $this->logger = $logger; - $this->tagMapper = $tagMapper; - $this->mailManager = $mailManager; - $this->classificationSettingsService = $classificationSettingsService; - } - - public function handle(Event $event): void { - if (!($event instanceof NewMessagesSynchronized)) { - return; - } - - if 
(!$this->classificationSettingsService->isClassificationEnabled($event->getAccount()->getUserId())) { - return; - } - - foreach (self::EXEMPT_FROM_CLASSIFICATION as $specialUse) { - if ($event->getMailbox()->isSpecialUse($specialUse)) { - // Nothing to do then - return; - } - } - - $messages = $event->getMessages(); - - // if this is a message that's been flagged / tagged as important before, we don't want to reclassify it again. - $doNotReclassify = $this->tagMapper->getTaggedMessageIdsForMessages( - $event->getMessages(), - $event->getAccount()->getUserId(), - Tag::LABEL_IMPORTANT - ); - $messages = array_filter($messages, static function ($message) use ($doNotReclassify) { - return ($message->getFlagImportant() === false || in_array($message->getMessageId(), $doNotReclassify, true)); - }); - - try { - $important = $this->tagMapper->getTagByImapLabel(Tag::LABEL_IMPORTANT, $event->getAccount()->getUserId()); - } catch (DoesNotExistException $e) { - // just in case - if we get here, the tag is missing - $this->logger->error('Could not find important tag for ' . $event->getAccount()->getUserId() . ' ' . $e->getMessage(), [ - 'exception' => $e, - ]); - return; - } - - try { - $predictions = $this->classifier->classifyImportance( - $event->getAccount(), - $messages, - $this->logger - ); - - foreach ($event->getMessages() as $message) { - $this->logger->info("Message {$message->getUid()} ({$message->getPreviewText()}) is " . ($predictions[$message->getUid()] ? 'important' : 'not important')); - if ($predictions[$message->getUid()] ?? false) { - $this->mailManager->flagMessage($event->getAccount(), $event->getMailbox()->getName(), $message->getUid(), Tag::LABEL_IMPORTANT, true); - $this->mailManager->tagMessage($event->getAccount(), $event->getMailbox()->getName(), $message, $important, true); - } - } - } catch (ServiceException $e) { - $this->logger->error('Could not classify incoming message importance: ' . 
$e->getMessage(), [ - 'exception' => $e, - ]); - } - } -} diff --git a/lib/Service/Classification/NewMessagesClassifier.php b/lib/Service/Classification/NewMessagesClassifier.php new file mode 100644 index 0000000000..3b2fe4f6bb --- /dev/null +++ b/lib/Service/Classification/NewMessagesClassifier.php @@ -0,0 +1,125 @@ + + * + * @author Richard Steinmetz + * + * @license AGPL-3.0-or-later + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +namespace OCA\Mail\Service\Classification; + +use Horde_Imap_Client; +use OCA\Mail\Account; +use OCA\Mail\Contracts\IMailManager; +use OCA\Mail\Contracts\IUserPreferences; +use OCA\Mail\Db\Mailbox; +use OCA\Mail\Db\Message; +use OCA\Mail\Db\Tag; +use OCA\Mail\Db\TagMapper; +use OCA\Mail\Exception\ClientException; +use OCA\Mail\Exception\ServiceException; +use Psr\Log\LoggerInterface; + +class NewMessagesClassifier { + private const EXEMPT_FROM_CLASSIFICATION = [ + Horde_Imap_Client::SPECIALUSE_ARCHIVE, + Horde_Imap_Client::SPECIALUSE_DRAFTS, + Horde_Imap_Client::SPECIALUSE_JUNK, + Horde_Imap_Client::SPECIALUSE_SENT, + Horde_Imap_Client::SPECIALUSE_TRASH, + ]; + + public function __construct( + private ImportanceClassifier $classifier, + private TagMapper $tagMapper, + private LoggerInterface $logger, + private IMailManager $mailManager, + private IUserPreferences $preferences) { + } + + /** + * Classify a batch on freshly synced messages. 
+ * Objects in the incoming $messages array are mutated in place. + * + * The importance tag will be propagated to IMAP and its mapping will be persisted to the db. + * However, changes to db message objects themselves won't be persisted. + * This is up to the caller (e.g. MessageMapper->insertBulk()). + * + * @param Message[] $messages + * @param Mailbox $mailbox + * @param Account $account + * @param Tag $importantTag + * @return void + */ + public function classifyNewMessages( + array $messages, + Mailbox $mailbox, + Account $account, + Tag $importantTag, + ): void { + $allowTagging = $this->preferences->getPreference($account->getUserId(), 'tag-classified-messages'); + if ($allowTagging === 'false') { + return; + } + + foreach (self::EXEMPT_FROM_CLASSIFICATION as $specialUse) { + if ($mailbox->isSpecialUse($specialUse)) { + // Nothing to do then + return; + } + } + + // if this is a message that's been flagged / tagged as important before, we don't want to reclassify it again. + $doNotReclassify = $this->tagMapper->getTaggedMessageIdsForMessages( + $messages, + $account->getUserId(), + Tag::LABEL_IMPORTANT + ); + $messages = array_filter($messages, static function ($message) use ($doNotReclassify) { + return ($message->getFlagImportant() === false || in_array($message->getMessageId(), $doNotReclassify, true)); + }); + + try { + $predictions = $this->classifier->classifyImportance( + $account, + $messages, + $this->logger + ); + + foreach ($messages as $message) { + $this->logger->info("Message {$message->getUid()} ({$message->getPreviewText()}) is " . ($predictions[$message->getUid()] ? 'important' : 'not important')); + if ($predictions[$message->getUid()] ?? 
false) { + $message->setFlagImportant(true); + $this->mailManager->flagMessage($account, $mailbox->getName(), $message->getUid(), Tag::LABEL_IMPORTANT, true); + $this->mailManager->tagMessage($account, $mailbox->getName(), $message, $importantTag, true); + } + } + } catch (ServiceException $e) { + $this->logger->error('Could not classify incoming message importance: ' . $e->getMessage(), [ + 'exception' => $e, + ]); + } catch (ClientException $e) { + $this->logger->error('Could not persist incoming message importance to IMAP: ' . $e->getMessage(), [ + 'exception' => $e, + ]); + } + } +} diff --git a/lib/Service/Sync/ImapToDbSynchronizer.php b/lib/Service/Sync/ImapToDbSynchronizer.php index 0054bc54b2..65be294b6f 100644 --- a/lib/Service/Sync/ImapToDbSynchronizer.php +++ b/lib/Service/Sync/ImapToDbSynchronizer.php @@ -18,7 +18,8 @@ use OCA\Mail\Db\Mailbox; use OCA\Mail\Db\MailboxMapper; use OCA\Mail\Db\MessageMapper as DatabaseMessageMapper; -use OCA\Mail\Events\NewMessagesSynchronized; +use OCA\Mail\Db\Tag; +use OCA\Mail\Db\TagMapper; use OCA\Mail\Events\SynchronizationEvent; use OCA\Mail\Exception\ClientException; use OCA\Mail\Exception\IncompleteSyncException; @@ -33,7 +34,9 @@ use OCA\Mail\IMAP\Sync\Request; use OCA\Mail\IMAP\Sync\Synchronizer; use OCA\Mail\Model\IMAPMessage; +use OCA\Mail\Service\Classification\NewMessagesClassifier; use OCA\Mail\Support\PerformanceLogger; +use OCP\AppFramework\Db\DoesNotExistException; use OCP\EventDispatcher\IEventDispatcher; use Psr\Log\LoggerInterface; use Throwable; @@ -74,17 +77,22 @@ class ImapToDbSynchronizer { private $mailManager; private PreviewEnhancer $previewEnhancer; + private TagMapper $tagMapper; + private NewMessagesClassifier $newMessagesClassifier; public function __construct(DatabaseMessageMapper $dbMapper, IMAPClientFactory $clientFactory, ImapMessageMapper $imapMapper, MailboxMapper $mailboxMapper, + DatabaseMessageMapper $messageMapper, Synchronizer $synchronizer, IEventDispatcher $dispatcher, 
PerformanceLogger $performanceLogger, LoggerInterface $logger, IMailManager $mailManager, - PreviewEnhancer $previewEnhancer) { + PreviewEnhancer $previewEnhancer, + TagMapper $tagMapper, + NewMessagesClassifier $newMessagesClassifier) { $this->dbMapper = $dbMapper; $this->clientFactory = $clientFactory; $this->imapMapper = $imapMapper; @@ -95,6 +103,8 @@ public function __construct(DatabaseMessageMapper $dbMapper, $this->logger = $logger; $this->mailManager = $mailManager; $this->previewEnhancer = $previewEnhancer; + $this->tagMapper = $tagMapper; + $this->newMessagesClassifier = $newMessagesClassifier; } /** @@ -418,6 +428,15 @@ private function runPartialSync( }); } + $importantTag = null; + try { + $importantTag = $this->tagMapper->getTagByImapLabel(Tag::LABEL_IMPORTANT, $account->getUserId()); + } catch (DoesNotExistException $e) { + $this->logger->error('Could not find important tag for ' . $account->getUserId(). ' ' . $e->getMessage(), [ + 'exception' => $e, + ]); + } + foreach (array_chunk($newMessages, 500) as $chunk) { $dbMessages = array_map(static function (IMAPMessage $imapMessage) use ($mailbox, $account) { return $imapMessage->toDbMessage($mailbox->getId(), $account->getMailAccount()); @@ -425,6 +444,15 @@ private function runPartialSync( $this->dbMapper->insertBulk($account, ...$dbMessages); + if ($importantTag) { + $this->newMessagesClassifier->classifyNewMessages( + $dbMessages, + $mailbox, + $account, + $importantTag, + ); + } + $this->dispatcher->dispatch( NewMessagesSynchronized::class, new NewMessagesSynchronized($account, $mailbox, $dbMessages) From bb9056dd0a2a671b95755ae0a24bc6b330ce067c Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Wed, 17 May 2023 12:35:59 +0200 Subject: [PATCH 25/37] Refactor persistence --- lib/Db/Classifier.php | 9 -- .../Version3100Date20230324113141.php | 56 -------- .../Classification/ImportanceClassifier.php | 34 +---- .../Classification/NewMessagesClassifier.php | 3 +-
.../Classification/PersistenceService.php | 124 ++++++++++++++++-- 5 files changed, 122 insertions(+), 104 deletions(-) delete mode 100644 lib/Migration/Version3100Date20230324113141.php diff --git a/lib/Db/Classifier.php b/lib/Db/Classifier.php index 9083c5e335..51f9b25a93 100644 --- a/lib/Db/Classifier.php +++ b/lib/Db/Classifier.php @@ -48,9 +48,6 @@ * @method int getCreatedAt() * @method void setCreatedAt(int $createdAt) * - * @method int getTransformerCount() - * @method void setTransformerCount(int $transformerCount) - * * @method string|null getTransformers() * @method void setTransformers(string|null $transformers) */ @@ -93,12 +90,6 @@ class Classifier extends Entity { /** @var int */ protected $createdAt; - /** @var int */ - protected $transformerCount; - - /** @var string */ - protected $transformers; - public function __construct() { $this->addType('accountId', 'integer'); $this->addType('type', 'string'); diff --git a/lib/Migration/Version3100Date20230324113141.php b/lib/Migration/Version3100Date20230324113141.php deleted file mode 100644 index 2181e24a73..0000000000 --- a/lib/Migration/Version3100Date20230324113141.php +++ /dev/null @@ -1,56 +0,0 @@ - - * - * @author Richard Steinmetz - * - * @license GNU AGPL version 3 or any later version - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of the - * License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . 
- * - */ - -namespace OCA\Mail\Migration; - -use Closure; -use OCP\DB\ISchemaWrapper; -use OCP\DB\Types; -use OCP\Migration\IOutput; -use OCP\Migration\SimpleMigrationStep; - -class Version3100Date20230324113141 extends SimpleMigrationStep { - /** - * @param IOutput $output - * @param Closure(): ISchemaWrapper $schemaClosure - * @param array $options - * @return null|ISchemaWrapper - */ - public function changeSchema(IOutput $output, Closure $schemaClosure, array $options): ?ISchemaWrapper { - /** @var ISchemaWrapper $schema */ - $schema = $schemaClosure(); - - $classifierTable = $schema->getTable('mail_classifiers'); - if (!$classifierTable->hasColumn('transformer_count')) { - $classifierTable->addColumn('transformer_count', Types::INTEGER, [ - 'notnull' => true, - 'default' => 0, - ]); - } - - return $schema; - } -} diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 090ccfed18..882572185b 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -444,7 +444,7 @@ public function classifyImportance(Account $account, try { $pipeline = $this->persistenceService->loadLatest($account); } catch (ServiceException $e) { - $logger->warning('Failed to load importance classifier: ' . $e->getMessage(), [ + $logger->warning('Failed to load persisted estimator and extractor: ' . $e->getMessage(), [ 'exception' => $e, ]); } @@ -465,41 +465,17 @@ public function classifyImportance(Account $account, }, $messages) ); } - $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']); - - // Load persisted transformers of the subject extractor. - // Is a bit hacky but a full abstraction would be overkill. 
- $transformers = $pipeline->getTransformers(); - if (count($transformers) === 2) { - $wordCountVectorizer = $transformers[0]; - if (!($wordCountVectorizer instanceof WordCountVectorizer)) { - throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . $wordCountVectorizer::class); - } - $tfidfTransformer = $transformers[1]; - if (!($tfidfTransformer instanceof TfIdfTransformer)) { - throw new ServiceException("Failed to load persisted transformer: Expected " . TfIdfTransformer::class . ", got" . $tfidfTransformer::class); - } - - $subjectExtractor = new SubjectExtractor(); - $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); - $subjectExtractor->setTfidf($tfidfTransformer); - $extractor = new NewCompositeExtractor( - $this->vanillaExtractor, - $subjectExtractor, - ); - } else { - $logger->warning('Falling back to vanilla feature extractor'); - $extractor = $this->vanillaExtractor; - } + [$estimator, $extractor] = $pipeline; + $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']); $features = $this->getFeaturesAndImportance( $account, $this->getIncomingMailboxes($account), $this->getOutgoingMailboxes($account), $messagesWithSender, - $extractor + $extractor, ); - $predictions = $pipeline->getEstimator()->predict( + $predictions = $estimator->predict( Unlabeled::build(array_column($features, 'features')) ); return array_combine( diff --git a/lib/Service/Classification/NewMessagesClassifier.php b/lib/Service/Classification/NewMessagesClassifier.php index 3b2fe4f6bb..ad9029c03a 100644 --- a/lib/Service/Classification/NewMessagesClassifier.php +++ b/lib/Service/Classification/NewMessagesClassifier.php @@ -52,7 +52,8 @@ public function __construct( private TagMapper $tagMapper, private LoggerInterface $logger, private IMailManager $mailManager, - private IUserPreferences $preferences) { + private IUserPreferences $preferences, + ) { } /** diff --git 
a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index dbf659ea03..f6b9fcdd7a 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -17,6 +17,10 @@ use OCA\Mail\Db\MailAccountMapper; use OCA\Mail\Exception\ServiceException; use OCA\Mail\Model\ClassifierPipeline; +use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; use OCP\App\IAppManager; use OCP\AppFramework\Db\DoesNotExistException; use OCP\AppFramework\Utility\ITimeFactory; @@ -26,14 +30,17 @@ use OCP\Files\NotPermittedException; use OCP\ICacheFactory; use OCP\ITempManager; +use Psr\Container\ContainerExceptionInterface; +use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; -use Rubix\ML\Estimator; use Rubix\ML\Learner; use Rubix\ML\Persistable; use Rubix\ML\PersistentModel; use Rubix\ML\Persisters\Filesystem; use Rubix\ML\Serializers\RBX; +use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; +use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function file_get_contents; use function file_put_contents; @@ -67,6 +74,8 @@ class PersistenceService { /** @var MailAccountMapper */ private $accountMapper; + private ContainerInterface $container; + public function __construct(ClassifierMapper $mapper, IAppData $appData, ITempManager $tempManager, @@ -74,7 +83,8 @@ public function __construct(ClassifierMapper $mapper, IAppManager $appManager, ICacheFactory $cacheFactory, LoggerInterface $logger, - MailAccountMapper $accountMapper) { + MailAccountMapper $accountMapper, + ContainerInterface $container) { $this->mapper = $mapper; $this->appData = $appData; $this->tempManager = $tempManager; @@ 
-83,6 +93,7 @@ public function __construct(ClassifierMapper $mapper, $this->cacheFactory = $cacheFactory; $this->logger = $logger; $this->accountMapper = $accountMapper; + $this->container = $container; } /** @@ -178,8 +189,6 @@ public function persist(Classifier $classifier, $transformerIndex++; } - $classifier->setTransformerCount($transformerIndex); - /* * Now we set the model active so it can be used by the next request */ @@ -190,17 +199,29 @@ public function persist(Classifier $classifier, /** * @param Account $account * - * @return ?ClassifierPipeline + * @return ?array [Estimator, IExtractor] * * @throws ServiceException */ - public function loadLatest(Account $account): ?ClassifierPipeline { + public function loadLatest(Account $account): ?array { try { $latestModel = $this->mapper->findLatest($account->getId()); } catch (DoesNotExistException $e) { return null; } - return $this->load($latestModel); + + $pipeline = $this->load($latestModel); + try { + $extractor = $this->loadExtractor($latestModel, $pipeline); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException( + "Failed to load extractor: {$e->getMessage()}", + 0, + $e, + ); + } + + return [$pipeline->getEstimator(), $extractor]; } /** @@ -212,8 +233,14 @@ public function loadLatest(Account $account): ?ClassifierPipeline { * @throws ServiceException */ public function load(Classifier $classifier): ClassifierPipeline { + $transformerCount = 0; + $appVersion = $this->parseAppVersion($classifier->getAppVersion()); + if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { + $transformerCount = 2; + } + $id = $classifier->getId(); - $cached = $this->getCached($classifier->getId(), $classifier->getTransformerCount()); + $cached = $this->getCached($classifier->getId(), $transformerCount); if ($cached !== null) { $this->logger->debug("Using cached serialized classifier $id"); $serialized = $cached[0]; @@ -238,7 +265,7 @@ public function load(Classifier $classifier): ClassifierPipeline { 
$this->logger->debug("Serialized classifier loaded (size=$size)"); $serializedTransformers = []; - for ($i = 0; $i < $classifier->getTransformerCount(); $i++) { + for ($i = 0; $i < $transformerCount; $i++) { try { $transformerFile = $modelsFolder->getFile("{$id}_t$i"); } catch (NotFoundException $e) { @@ -323,6 +350,60 @@ private function deleteModel(int $id): void { 'exception' => $e, ]); } + + /** + * Load and instantiate extractor based on a classifier's app version. + * + * @param Classifier $classifier + * @param ClassifierPipeline $pipeline + * @return IExtractor + * + * @throws ContainerExceptionInterface + * @throws ServiceException + */ + private function loadExtractor(Classifier $classifier, + ClassifierPipeline $pipeline): IExtractor { + $appVersion = $this->parseAppVersion($classifier->getAppVersion()); + if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { + return $this->loadExtractorV2($pipeline->getTransformers()); + } + + return $this->loadExtractorV1($pipeline->getTransformers()); + } + + /** + * @return VanillaCompositeExtractor + * + * @throws ContainerExceptionInterface + */ + private function loadExtractorV1(): VanillaCompositeExtractor { + return $this->container->get(VanillaCompositeExtractor::class); + } + + /** + * @param Transformer[] $transformers + * @return NewCompositeExtractor + * + * @throws ContainerExceptionInterface + * @throws ServiceException + */ + private function loadExtractorV2(array $transformers): NewCompositeExtractor { + $wordCountVectorizer = $transformers[0]; + if (!($wordCountVectorizer instanceof WordCountVectorizer)) { + throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . $wordCountVectorizer::class); + } + $tfidfTransformer = $transformers[1]; + if (!($tfidfTransformer instanceof TfIdfTransformer)) { + throw new ServiceException("Failed to load persisted transformer: Expected " . TfIdfTransformer::class . ", got" . 
$tfidfTransformer::class); + } + + $subjectExtractor = new SubjectExtractor(); + $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); + $subjectExtractor->setTfidf($tfidfTransformer); + return new NewCompositeExtractor( + $this->container->get(VanillaCompositeExtractor::class), + $subjectExtractor, + ); } private function getCacheKey(int $id): string { @@ -340,6 +421,9 @@ private function getTransformerCacheKey(int $id, int $index): string { * @return (?string)[]|null Array of serialized classifier and transformers */ private function getCached(int $id, int $transformerCount): ?array { + // FIXME: Will always return null as the cached, serialized data is always an empty string. + // See my note in self::cache() for further elaboration. + if (!$this->cacheFactory->isLocalCacheAvailable()) { return null; } @@ -360,6 +444,14 @@ private function getCached(int $id, int $transformerCount): ?array { } private function cache(int $id, string $serialized, array $serializedTransformers): void { + // FIXME: This is broken as some cache implementations will run the provided value through + // json_encode which drops non-utf8 strings. The serialized string contains binary + // data so an empty string will be saved instead (tested on Redis). + // Note: JSON requires strings to be valid utf8 (as per its spec). + + // IDEA: Implement a method ICache::setRaw() that forwards a raw/binary string as is to the + // underlying cache backend. + if (!$this->cacheFactory->isLocalCacheAvailable()) { return; } @@ -372,4 +464,18 @@ private function cache(int $id, string $serialized, array $serializedTransformer $transformerIndex++; } } + + /** + * Parse minor and major part of the given semver string. 
+ * + * @return int[] + */ + private function parseAppVersion(string $version): array { + $parts = explode('.', $version); + if (count($parts) < 2) { + return [0, 0]; + } + + return [(int)$parts[0], (int)$parts[1]]; + } } From 71b1b5fef778c2ce5ab77ae90b1e3a91cd8d13bc Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Wed, 14 Jun 2023 10:53:38 +0200 Subject: [PATCH 26/37] Adjust meta estimator params --- lib/Command/RunMetaEstimator.php | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index 2cfbc02702..bf47ce1862 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -125,7 +125,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $estimator = static function () use ($consoleLogger) { $params = [ - [5, 10, 15, 20, 25, 30], // Neighbors + [5, 10, 15, 20, 25, 30, 35, 40], // Neighbors [true, false], // Weighted? [new Euclidean(), new Manhattan(), new Jaccard()], // Kernel ]; @@ -134,7 +134,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int KNearestNeighbors::class, $params, new FBeta(), - new KFold(3) + new KFold(5) ); $estimator->setLogger($consoleLogger); $estimator->setBackend(new Amp()); From 877be83bca01b2b4c7f304ae25de4e28c9dbc3e5 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Wed, 14 Jun 2023 10:54:16 +0200 Subject: [PATCH 27/37] Change training sample size to 300 --- lib/Service/Classification/ImportanceClassifier.php | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 882572185b..ec25a690b1 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -94,7 +94,7 @@ class ImportanceClassifier { /** * The maximum number of data sets to train the classifier with */ - private 
const MAX_TRAINING_SET_SIZE = 333; + private const MAX_TRAINING_SET_SIZE = 300; /** @var MailboxMapper */ private $mailboxMapper; From 7c740b13f0695154777b5b4b1c5950f3a7e7cfbf Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Wed, 14 Jun 2023 10:54:44 +0200 Subject: [PATCH 28/37] Adjust tuned knn params --- lib/Service/Classification/ImportanceClassifier.php | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index ec25a690b1..46960843f8 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -35,6 +35,7 @@ use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; use Rubix\ML\Kernels\Distance\Manhattan; +use Rubix\ML\Kernels\Distance\Jaccard; use Rubix\ML\Learner; use Rubix\ML\Persistable; use Rubix\ML\Transformers\TfIdfTransformer; @@ -133,10 +134,8 @@ public function __construct(MailboxMapper $mailboxMapper, private static function createDefaultEstimator(): KNearestNeighbors { // A meta estimator was trained on the same data multiple times to average out the // variance of the trained model. - // Parameters were chosen from the best configuration across 100 runs. - // Both variance (spread) and f1 score were considered. - // Note: Lower k values yield slightly higher f1 scores but show higher variances. - return new KNearestNeighbors(15, true, new Manhattan()); + // Parameters were chosen from the best configuration across 20 runs. 
+ return new KNearestNeighbors(35, true, new Jaccard()); } private function filterMessageHasSenderEmail(Message $message): bool { From 1a5c5b5e17ec491f0e756573b38e19b9c98af8c3 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Thu, 17 Oct 2024 20:47:16 +0200 Subject: [PATCH 29/37] Fix reuse compliance --- lib/Command/PreprocessAccount.php | 20 ++---------------- lib/Command/RunMetaEstimator.php | 21 ++----------------- lib/Model/ClassifierPipeline.php | 21 ++----------------- .../NewCompositeExtractor.php | 20 ++---------------- .../FeatureExtraction/SubjectExtractor.php | 20 ++---------------- .../VanillaCompositeExtractor.php | 20 ++---------------- .../Classification/NewMessagesClassifier.php | 21 ++----------------- 7 files changed, 14 insertions(+), 129 deletions(-) diff --git a/lib/Command/PreprocessAccount.php b/lib/Command/PreprocessAccount.php index 07caa588e6..dba57f696b 100644 --- a/lib/Command/PreprocessAccount.php +++ b/lib/Command/PreprocessAccount.php @@ -3,24 +3,8 @@ declare(strict_types=1); /** - * @copyright Copyright (c) 2023 Richard Steinmetz - * - * @author Richard Steinmetz - * - * @license AGPL-3.0-or-later - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of the - * License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . 
+ * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Command; diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index bf47ce1862..8a00c1bafe 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -3,25 +3,8 @@ declare(strict_types=1); /** - * @copyright Copyright (c) 2023 Richard Steinmetz - * - * @author Richard Steinmetz - * - * @license AGPL-3.0-or-later - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * + * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Command; diff --git a/lib/Model/ClassifierPipeline.php b/lib/Model/ClassifierPipeline.php index 39f7f7c259..eca0098195 100644 --- a/lib/Model/ClassifierPipeline.php +++ b/lib/Model/ClassifierPipeline.php @@ -3,25 +3,8 @@ declare(strict_types=1); /** - * @copyright Copyright (c) 2023 Richard Steinmetz - * - * @author Richard Steinmetz - * - * @license AGPL-3.0-or-later - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. 
- * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * + * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Model; diff --git a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php index 46d48cb187..c07f56fe42 100644 --- a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php @@ -3,24 +3,8 @@ declare(strict_types=1); /** - * @copyright Copyright (c) 2023 Richard Steinmetz - * - * @author Richard Steinmetz - * - * @license AGPL-3.0-or-later - * - * This code is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License, version 3, - * as published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License, version 3, - * along with this program. 
If not, see - * + * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Service\Classification\FeatureExtraction; diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index c5f299d475..e73aba277a 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -3,24 +3,8 @@ declare(strict_types=1); /** - * @copyright Copyright (c) 2023 Richard Steinmetz - * - * @author Richard Steinmetz - * - * @license AGPL-3.0-or-later - * - * This code is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License, version 3, - * as published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License, version 3, - * along with this program. 
If not, see - * + * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Service\Classification\FeatureExtraction; diff --git a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php index 0d907ea594..69dfdfca51 100644 --- a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php @@ -3,24 +3,8 @@ declare(strict_types=1); /** - * @copyright Copyright (c) 2023 Richard Steinmetz - * - * @author Richard Steinmetz - * - * @license AGPL-3.0-or-later - * - * This code is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License, version 3, - * as published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License, version 3, - * along with this program. 
If not, see - * + * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Service\Classification\FeatureExtraction; diff --git a/lib/Service/Classification/NewMessagesClassifier.php b/lib/Service/Classification/NewMessagesClassifier.php index ad9029c03a..64602f55b4 100644 --- a/lib/Service/Classification/NewMessagesClassifier.php +++ b/lib/Service/Classification/NewMessagesClassifier.php @@ -3,25 +3,8 @@ declare(strict_types=1); /** - * @copyright Copyright (c) 2023 Richard Steinmetz - * - * @author Richard Steinmetz - * - * @license AGPL-3.0-or-later - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . 
- * + * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Service\Classification; From a7510ca1df6b08670f66b001ccbc8777bbbf1e99 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Thu, 17 Oct 2024 20:50:29 +0200 Subject: [PATCH 30/37] Run composer cs:fix --- .../NewCompositeExtractor.php | 2 +- .../FeatureExtraction/SubjectExtractor.php | 9 ++++----- .../VanillaCompositeExtractor.php | 6 +++--- .../Classification/ImportanceClassifier.php | 10 +++------- .../Classification/PersistenceService.php | 19 ++++++++++--------- lib/Service/Sync/ImapToDbSynchronizer.php | 2 +- 6 files changed, 22 insertions(+), 26 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php index c07f56fe42..b7d86622ca 100644 --- a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php @@ -13,7 +13,7 @@ class NewCompositeExtractor extends CompositeExtractor { private SubjectExtractor $subjectExtractor; public function __construct(VanillaCompositeExtractor $ex1, - SubjectExtractor $ex2) { + SubjectExtractor $ex2) { parent::__construct([$ex1, $ex2]); $this->subjectExtractor = $ex2; } diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index e73aba277a..20dcbeae34 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -14,10 +14,9 @@ use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Transformers\MinMaxNormalizer; +use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\PrincipalComponentAnalysis; use Rubix\ML\Transformers\TfIdfTransformer; -use 
Rubix\ML\Transformers\TSNE; -use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; @@ -78,7 +77,7 @@ public function prepare(Account $account, array $incomingMailboxes, array $outgo ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) ->apply($this->tfidf) - ;//->apply($this->dimensionalReductionTransformer); + ;//->apply($this->dimensionalReductionTransformer); $this->limitFeatureSize(); } @@ -89,7 +88,7 @@ public function prepare(Account $account, array $incomingMailboxes, array $outgo public function extract(Message $message): array { $sender = $message->getFrom()->first(); if ($sender === null) { - throw new RuntimeException("This should not happen"); + throw new RuntimeException('This should not happen'); } // Build training data set @@ -99,7 +98,7 @@ public function extract(Message $message): array { ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) ->apply($this->tfidf) - ;//->apply($this->dimensionalReductionTransformer); + ;//->apply($this->dimensionalReductionTransformer); // Use zeroed vector if no features could be extracted if ($trainDataSet->numFeatures() === 0) { diff --git a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php index 69dfdfca51..09351f33f6 100644 --- a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php @@ -11,9 +11,9 @@ class VanillaCompositeExtractor extends CompositeExtractor { public function __construct(ImportantMessagesExtractor $ex1, - ReadMessagesExtractor $ex2, - RepliedMessagesExtractor $ex3, - SentMessagesExtractor $ex4) { + ReadMessagesExtractor $ex2, + RepliedMessagesExtractor $ex3, + SentMessagesExtractor $ex4) { parent::__construct([ $ex1, $ex2, diff --git 
a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 46960843f8..72ad186546 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -21,7 +21,6 @@ use OCA\Mail\Exception\ServiceException; use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor; use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; use OCA\Mail\Support\PerformanceLogger; use OCA\Mail\Support\PerformanceLoggerTask; @@ -34,13 +33,10 @@ use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; -use Rubix\ML\Kernels\Distance\Manhattan; use Rubix\ML\Kernels\Distance\Jaccard; use Rubix\ML\Learner; use Rubix\ML\Persistable; -use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; -use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function array_column; use function array_combine; @@ -333,7 +329,7 @@ public function trainWithCustomDataSet( /** @var Learner&Estimator&Persistable $persistedEstimator */ $persistedEstimator = $estimator(); $this->trainClassifier($persistedEstimator, $dataSet); - $perf->step("train classifier with full data set"); + $perf->step('train classifier with full data set'); // Extract persisted transformers of the subject extractor. // Is a bit hacky but a full abstraction would be overkill. 
@@ -437,8 +433,8 @@ private function getFeaturesAndImportance(Account $account, * @throws ServiceException */ public function classifyImportance(Account $account, - array $messages, - LoggerInterface $logger): array { + array $messages, + LoggerInterface $logger): array { $pipeline = null; try { $pipeline = $this->persistenceService->loadLatest($account); diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index f6b9fcdd7a..4b4eea5345 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -106,8 +106,8 @@ public function __construct(ClassifierMapper $mapper, * @throws ServiceException */ public function persist(Classifier $classifier, - Learner $estimator, - array $transformers): void { + Learner $estimator, + array $transformers): void { /* * First we have to insert the row to get the unique ID, but disable * it until the model is persisted as well. Otherwise another process @@ -147,7 +147,7 @@ public function persist(Classifier $classifier, $file = $folder->newFile((string)$classifier->getId()); $file->putContent($serializedClassifier); $this->logger->debug('Serialized classifier written to app data'); - } catch (NotPermittedException | NotFoundException $e) { + } catch (NotPermittedException|NotFoundException $e) { throw new ServiceException('Could not create classifiers directory: ' . $e->getMessage(), 0, $e); } @@ -171,14 +171,14 @@ public function persist(Classifier $classifier, $serializedTransformer = file_get_contents($tmpPath); $this->logger->debug('Serialized transformer written to tmp file (' . strlen($serializedTransformer) . 'B'); } catch (RuntimeException $e) { - throw new ServiceException("Could not serialize transformer: " . $e->getMessage(), 0, $e); + throw new ServiceException('Could not serialize transformer: ' . 
$e->getMessage(), 0, $e); } try { $file = $folder->newFile("{$classifier->getId()}_t$transformerIndex"); $file->putContent($serializedTransformer); $this->logger->debug("Serialized transformer $transformerIndex written to app data"); - } catch (NotPermittedException | NotFoundException $e) { + } catch (NotPermittedException|NotFoundException $e) { throw new ServiceException( "Failed to persist transformer $transformerIndex: " . $e->getMessage(), 0, @@ -275,7 +275,7 @@ public function load(Classifier $classifier): ClassifierPipeline { try { $serializedTransformer = $transformerFile->getContent(); - } catch (NotFoundException | NotPermittedException $e) { + } catch (NotFoundException|NotPermittedException $e) { $this->logger->debug("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage()); throw new ServiceException("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage(), 0, $e); } @@ -350,6 +350,7 @@ private function deleteModel(int $id): void { 'exception' => $e, ]); } + } /** * Load and instantiate extractor based on a classifier's app version. @@ -362,7 +363,7 @@ private function deleteModel(int $id): void { * @throws ServiceException */ private function loadExtractor(Classifier $classifier, - ClassifierPipeline $pipeline): IExtractor { + ClassifierPipeline $pipeline): IExtractor { $appVersion = $this->parseAppVersion($classifier->getAppVersion()); if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { return $this->loadExtractorV2($pipeline->getTransformers()); @@ -390,11 +391,11 @@ private function loadExtractorV1(): VanillaCompositeExtractor { private function loadExtractorV2(array $transformers): NewCompositeExtractor { $wordCountVectorizer = $transformers[0]; if (!($wordCountVectorizer instanceof WordCountVectorizer)) { - throw new ServiceException("Failed to load persisted transformer: Expected " . WordCountVectorizer::class . ", got" . 
$wordCountVectorizer::class); + throw new ServiceException('Failed to load persisted transformer: Expected ' . WordCountVectorizer::class . ', got' . $wordCountVectorizer::class); } $tfidfTransformer = $transformers[1]; if (!($tfidfTransformer instanceof TfIdfTransformer)) { - throw new ServiceException("Failed to load persisted transformer: Expected " . TfIdfTransformer::class . ", got" . $tfidfTransformer::class); + throw new ServiceException('Failed to load persisted transformer: Expected ' . TfIdfTransformer::class . ', got' . $tfidfTransformer::class); } $subjectExtractor = new SubjectExtractor(); diff --git a/lib/Service/Sync/ImapToDbSynchronizer.php b/lib/Service/Sync/ImapToDbSynchronizer.php index 65be294b6f..71fec38a48 100644 --- a/lib/Service/Sync/ImapToDbSynchronizer.php +++ b/lib/Service/Sync/ImapToDbSynchronizer.php @@ -432,7 +432,7 @@ private function runPartialSync( try { $importantTag = $this->tagMapper->getTagByImapLabel(Tag::LABEL_IMPORTANT, $account->getUserId()); } catch (DoesNotExistException $e) { - $this->logger->error('Could not find important tag for ' . $account->getUserId(). ' ' . $e->getMessage(), [ + $this->logger->error('Could not find important tag for ' . $account->getUserId() . ' ' . 
$e->getMessage(), [ 'exception' => $e, ]); } From dade1df3555f5c4615be56dbb7f180b669fdb361 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Thu, 17 Oct 2024 21:01:59 +0200 Subject: [PATCH 31/37] Fix most psalm issues --- lib/AppInfo/Application.php | 4 +--- lib/Db/StatisticsDao.php | 4 ++-- .../FeatureExtraction/SubjectExtractor.php | 9 +++++++-- lib/Service/Sync/ImapToDbSynchronizer.php | 5 +---- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index b4499fec62..1709bd5018 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -29,6 +29,7 @@ use OCA\Mail\Events\MessageDeletedEvent; use OCA\Mail\Events\MessageFlaggedEvent; use OCA\Mail\Events\MessageSentEvent; +use OCA\Mail\Events\NewMessagesSynchronized; use OCA\Mail\Events\OutboxMessageCreatedEvent; use OCA\Mail\Events\SynchronizationEvent; use OCA\Mail\HordeTranslationHandler; @@ -44,7 +45,6 @@ use OCA\Mail\Listener\MessageCacheUpdaterListener; use OCA\Mail\Listener\MessageKnownSinceListener; use OCA\Mail\Listener\MoveJunkListener; -use OCA\Mail\Listener\NewMessageClassificationListener; use OCA\Mail\Listener\NewMessagesNotifier; use OCA\Mail\Listener\OauthTokenRefreshListener; use OCA\Mail\Listener\OptionalIndicesListener; @@ -123,10 +123,8 @@ public function register(IRegistrationContext $context): void { $context->registerEventListener(MessageDeletedEvent::class, MessageCacheUpdaterListener::class); $context->registerEventListener(MessageSentEvent::class, AddressCollectionListener::class); $context->registerEventListener(MessageSentEvent::class, InteractionListener::class); - $context->registerEventListener(NewMessagesSynchronized::class, NewMessageClassificationListener::class); $context->registerEventListener(NewMessagesSynchronized::class, MessageKnownSinceListener::class); $context->registerEventListener(NewMessagesSynchronized::class, NewMessagesNotifier::class); - 
$context->registerEventListener(MessageSentEvent::class, SaveSentMessageListener::class); $context->registerEventListener(SynchronizationEvent::class, AccountSynchronizedThreadUpdaterListener::class); $context->registerEventListener(UserDeletedEvent::class, UserDeletedListener::class); $context->registerEventListener(NewMessagesSynchronized::class, FollowUpClassifierListener::class); diff --git a/lib/Db/StatisticsDao.php b/lib/Db/StatisticsDao.php index 9d3ded4ea6..0b23e0b759 100644 --- a/lib/Db/StatisticsDao.php +++ b/lib/Db/StatisticsDao.php @@ -157,7 +157,7 @@ public function getSubjects(array $mailboxes, array $emails): array { ->where($qb->expr()->eq('r.type', $qb->createNamedParameter(Address::TYPE_FROM, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT)) ->andWhere($qb->expr()->in('m.mailbox_id', $qb->createNamedParameter($mailboxIds, IQueryBuilder::PARAM_INT_ARRAY))) ->andWhere($qb->expr()->in('r.email', $qb->createNamedParameter($emails, IQueryBuilder::PARAM_STR_ARRAY), IQueryBuilder::PARAM_STR_ARRAY)); - $result = $select->execute(); + $result = $select->executeQuery(); $rows = $result->fetchAll(); $result->closeCursor(); $data = []; @@ -181,7 +181,7 @@ public function getPreviewTexts(array $mailboxes, array $emails): array { ->andWhere($qb->expr()->in('m.mailbox_id', $qb->createNamedParameter($mailboxIds, IQueryBuilder::PARAM_INT_ARRAY))) ->andWhere($qb->expr()->in('r.email', $qb->createNamedParameter($emails, IQueryBuilder::PARAM_STR_ARRAY), IQueryBuilder::PARAM_STR_ARRAY)) ->andWhere($qb->expr()->isNotNull('m.preview_text')); - $result = $select->execute(); + $result = $select->executeQuery(); $rows = $result->fetchAll(); $result->closeCursor(); $data = []; diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index 20dcbeae34..b768938139 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ 
b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -116,8 +116,13 @@ public function extract(Message $message): array { * Limit feature vector length to actual vocabulary size. */ private function limitFeatureSize(): void { - $vocab = $this->wordCountVectorizer->vocabularies()[0]; - $this->max = count($vocab); + $vocabularies = $this->wordCountVectorizer->vocabularies(); + if (!isset($vocabularies[0])) { + // Should not happen but better safe than sorry + return; + } + + $this->max = count($vocabularies[0]); echo("WCF vocab size: {$this->max}\n"); } } diff --git a/lib/Service/Sync/ImapToDbSynchronizer.php b/lib/Service/Sync/ImapToDbSynchronizer.php index 71fec38a48..e4d0e87474 100644 --- a/lib/Service/Sync/ImapToDbSynchronizer.php +++ b/lib/Service/Sync/ImapToDbSynchronizer.php @@ -20,6 +20,7 @@ use OCA\Mail\Db\MessageMapper as DatabaseMessageMapper; use OCA\Mail\Db\Tag; use OCA\Mail\Db\TagMapper; +use OCA\Mail\Events\NewMessagesSynchronized; use OCA\Mail\Events\SynchronizationEvent; use OCA\Mail\Exception\ClientException; use OCA\Mail\Exception\IncompleteSyncException; @@ -30,7 +31,6 @@ use OCA\Mail\Exception\UidValidityChangedException; use OCA\Mail\IMAP\IMAPClientFactory; use OCA\Mail\IMAP\MessageMapper as ImapMessageMapper; -use OCA\Mail\IMAP\PreviewEnhancer; use OCA\Mail\IMAP\Sync\Request; use OCA\Mail\IMAP\Sync\Synchronizer; use OCA\Mail\Model\IMAPMessage; @@ -76,7 +76,6 @@ class ImapToDbSynchronizer { /** @var IMailManager */ private $mailManager; - private PreviewEnhancer $previewEnhancer; private TagMapper $tagMapper; private NewMessagesClassifier $newMessagesClassifier; @@ -90,7 +89,6 @@ public function __construct(DatabaseMessageMapper $dbMapper, PerformanceLogger $performanceLogger, LoggerInterface $logger, IMailManager $mailManager, - PreviewEnhancer $previewEnhancer, TagMapper $tagMapper, NewMessagesClassifier $newMessagesClassifier) { $this->dbMapper = $dbMapper; @@ -102,7 +100,6 @@ public function 
__construct(DatabaseMessageMapper $dbMapper, $this->performanceLogger = $performanceLogger; $this->logger = $logger; $this->mailManager = $mailManager; - $this->previewEnhancer = $previewEnhancer; $this->tagMapper = $tagMapper; $this->newMessagesClassifier = $newMessagesClassifier; } From a808b2e295ccc867cc3313e0982733a444cce4a8 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Mon, 21 Oct 2024 12:15:29 +0200 Subject: [PATCH 32/37] Persist classifiers in memory cache only --- appinfo/info.xml | 2 +- lib/Command/RunMetaEstimator.php | 15 +- lib/Command/TrainAccount.php | 26 +- lib/Db/Classifier.php | 108 ---- lib/Db/ClassifierMapper.php | 57 --- .../Version4100Date20241021091352.php | 29 ++ lib/Model/Classifier.php | 134 +++++ lib/Model/ClassifierPipeline.php | 19 +- .../FeatureExtraction/CompositeExtractor.php | 40 +- .../NewCompositeExtractor.php | 24 - .../FeatureExtraction/SubjectExtractor.php | 2 +- .../VanillaCompositeExtractor.php | 24 - .../Classification/ImportanceClassifier.php | 74 ++- .../Classification/PersistenceService.php | 463 ++++-------------- .../Classification/RubixMemoryPersister.php | 41 ++ lib/Service/CleanupService.php | 6 - 16 files changed, 396 insertions(+), 668 deletions(-) delete mode 100644 lib/Db/Classifier.php delete mode 100644 lib/Db/ClassifierMapper.php create mode 100644 lib/Migration/Version4100Date20241021091352.php create mode 100644 lib/Model/Classifier.php delete mode 100644 lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php delete mode 100644 lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php create mode 100644 lib/Service/Classification/RubixMemoryPersister.php diff --git a/appinfo/info.xml b/appinfo/info.xml index b0e0c1e72a..7008982213 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -34,7 +34,7 @@ The rating depends on the installed text processing backend. 
See [the rating ove Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud.com/blog/nextcloud-ethical-ai-rating/). ]]> - 4.1.0-alpha.2 + 4.1.0-alpha.3 agpl Christoph Wurst GretaD diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index 8a00c1bafe..32d1187d37 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -10,7 +10,7 @@ namespace OCA\Mail\Command; use OCA\Mail\Service\AccountService; -use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; @@ -86,8 +86,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int return 1; } - /** @var NewCompositeExtractor $extractor */ - $extractor = $this->container->get(NewCompositeExtractor::class); + /** @var CompositeExtractor $extractor */ + $extractor = $this->container->get(CompositeExtractor::class); $consoleLogger = new ConsoleLoggerDecorator( $this->logger, $output @@ -124,8 +124,9 @@ protected function execute(InputInterface $input, OutputInterface $output): int return $estimator; }; + /** @var GridSearch $metaEstimator */ if ($dataSet) { - $this->classifier->trainWithCustomDataSet( + $metaEstimator = $this->classifier->trainWithCustomDataSet( $account, $consoleLogger, $dataSet, @@ -135,7 +136,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int false, ); } else { - $this->classifier->train( + $metaEstimator = $this->classifier->train( $account, $consoleLogger, $extractor, @@ -145,6 +146,10 @@ protected function execute(InputInterface $input, OutputInterface $output): int ); } + if ($metaEstimator) { + $output->writeln("Best estimator: {$metaEstimator->base()}"); + } + $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); 
$output->writeln('' . $mbs . 'MB of memory used'); return 0; diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index 33ce44371d..12866de7f3 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -3,7 +3,7 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2019 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2019-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ @@ -11,15 +11,13 @@ use OCA\Mail\Service\AccountService; use OCA\Mail\Service\Classification\ClassificationSettingsService; +use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; -use Rubix\ML\Classifiers\GaussianNB; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; @@ -110,9 +108,6 @@ protected function execute(InputInterface $input, OutputInterface $output): int $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); $dryRun = (bool)$input->getOption(self::ARGUMENT_DRY_RUN); $force = (bool)$input->getOption(self::ARGUMENT_FORCE); - $old = (bool)$input->getOption(self::ARGUMENT_OLD); - $oldEstimator = $old || $input->getOption(self::ARGUMENT_OLD_ESTIMATOR); - $oldExtractor = $old || $input->getOption(self::ARGUMENT_OLD_EXTRACTOR); try { $account = $this->accountService->findById($accountId); @@ -127,18 +122,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int } /** @var IExtractor $extractor */ - 
if ($oldExtractor) { - $extractor = $this->container->get(VanillaCompositeExtractor::class); - } else { - $extractor = $this->container->get(NewCompositeExtractor::class); - } - - $estimator = null; - if ($oldEstimator) { - $estimator = static function () { - return new GaussianNB(); - }; - } + $extractor = $this->container->get(CompositeExtractor::class); $consoleLogger = new ConsoleLoggerDecorator( $this->logger, @@ -167,7 +151,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $consoleLogger, $dataSet, $extractor, - $estimator, + null, null, !$dryRun ); @@ -176,7 +160,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $account, $consoleLogger, $extractor, - $estimator, + null, $shuffle, !$dryRun ); diff --git a/lib/Db/Classifier.php b/lib/Db/Classifier.php deleted file mode 100644 index 51f9b25a93..0000000000 --- a/lib/Db/Classifier.php +++ /dev/null @@ -1,108 +0,0 @@ -addType('accountId', 'integer'); - $this->addType('type', 'string'); - $this->addType('appVersion', 'string'); - $this->addType('trainingSetSize', 'integer'); - $this->addType('validationSetSize', 'integer'); - $this->addType('recallImportant', 'float'); - $this->addType('precisionImportant', 'float'); - $this->addType('f1ScoreImportant', 'float'); - $this->addType('duration', 'integer'); - $this->addType('active', 'boolean'); - $this->addType('createdAt', 'integer'); - $this->addType('transformerCount', 'integer'); - $this->addType('transformers', 'string'); - } -} diff --git a/lib/Db/ClassifierMapper.php b/lib/Db/ClassifierMapper.php deleted file mode 100644 index 946b70f8b0..0000000000 --- a/lib/Db/ClassifierMapper.php +++ /dev/null @@ -1,57 +0,0 @@ - - */ -class ClassifierMapper extends QBMapper { - public function __construct(IDBConnection $db) { - parent::__construct($db, 'mail_classifiers'); - } - - /** - * @param int $id - * - * @return Classifier - * @throws DoesNotExistException - */ - public function findLatest(int 
$id): Classifier { - $qb = $this->db->getQueryBuilder(); - - $select = $qb->select('*') - ->from($this->getTableName()) - ->where( - $qb->expr()->eq('account_id', $qb->createNamedParameter($id, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT), - $qb->expr()->eq('active', $qb->createNamedParameter(true, IQueryBuilder::PARAM_BOOL), IQueryBuilder::PARAM_BOOL) - ) - ->orderBy('created_at', 'desc') - ->setMaxResults(1); - - return $this->findEntity($select); - } - - public function findHistoric(int $threshold, int $limit) { - $qb = $this->db->getQueryBuilder(); - $select = $qb->select('*') - ->from($this->getTableName()) - ->where( - $qb->expr()->lte('created_at', $qb->createNamedParameter($threshold, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT), - ) - ->orderBy('created_at', 'asc') - ->setMaxResults($limit); - return $this->findEntities($select); - } -} diff --git a/lib/Migration/Version4100Date20241021091352.php b/lib/Migration/Version4100Date20241021091352.php new file mode 100644 index 0000000000..ccdf8be264 --- /dev/null +++ b/lib/Migration/Version4100Date20241021091352.php @@ -0,0 +1,29 @@ +dropTable('mail_classifiers'); + return $schema; + } +} diff --git a/lib/Model/Classifier.php b/lib/Model/Classifier.php new file mode 100644 index 0000000000..df4d21eeb1 --- /dev/null +++ b/lib/Model/Classifier.php @@ -0,0 +1,134 @@ +accountId; + } + + public function setAccountId(int $accountId): void { + $this->accountId = $accountId; + } + + public function getType(): string { + return $this->type; + } + + public function setType(string $type): void { + $this->type = $type; + } + + public function getEstimator(): string { + return $this->estimator; + } + + public function setEstimator(string $estimator): void { + $this->estimator = $estimator; + } + + public function getPersistenceVersion(): int { + return $this->persistenceVersion; + } + + public function setPersistenceVersion(int $persistenceVersion): void { + $this->persistenceVersion = $persistenceVersion; + 
} + + public function getTrainingSetSize(): int { + return $this->trainingSetSize; + } + + public function setTrainingSetSize(int $trainingSetSize): void { + $this->trainingSetSize = $trainingSetSize; + } + + public function getValidationSetSize(): int { + return $this->validationSetSize; + } + + public function setValidationSetSize(int $validationSetSize): void { + $this->validationSetSize = $validationSetSize; + } + + public function getRecallImportant(): float { + return $this->recallImportant; + } + + public function setRecallImportant(float $recallImportant): void { + $this->recallImportant = $recallImportant; + } + + public function getPrecisionImportant(): float { + return $this->precisionImportant; + } + + public function setPrecisionImportant(float $precisionImportant): void { + $this->precisionImportant = $precisionImportant; + } + + public function getF1ScoreImportant(): float { + return $this->f1ScoreImportant; + } + + public function setF1ScoreImportant(float $f1ScoreImportant): void { + $this->f1ScoreImportant = $f1ScoreImportant; + } + + public function getDuration(): int { + return $this->duration; + } + + public function setDuration(int $duration): void { + $this->duration = $duration; + } + + public function getCreatedAt(): int { + return $this->createdAt; + } + + public function setCreatedAt(int $createdAt): void { + $this->createdAt = $createdAt; + } + + #[ReturnTypeWillChange] + public function jsonSerialize() { + return [ + 'accountId' => $this->accountId, + 'type' => $this->type, + 'estimator' => $this->estimator, + 'persistenceVersion' => $this->persistenceVersion, + 'trainingSetSize' => $this->trainingSetSize, + 'validationSetSize' => $this->validationSetSize, + 'recallImportant' => $this->recallImportant, + 'precisionImportant' => $this->precisionImportant, + 'f1ScoreImportant' => $this->f1ScoreImportant, + 'duration' => $this->duration, + 'createdAt' => $this->createdAt, + ]; + } +} diff --git a/lib/Model/ClassifierPipeline.php 
b/lib/Model/ClassifierPipeline.php index eca0098195..bcc1be893e 100644 --- a/lib/Model/ClassifierPipeline.php +++ b/lib/Model/ClassifierPipeline.php @@ -9,28 +9,29 @@ namespace OCA\Mail\Model; +use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use Rubix\ML\Estimator; use Rubix\ML\Transformers\Transformer; class ClassifierPipeline { - private Estimator $estimator; - - /** @var Transformer[] */ - private array $transformers; - /** - * @param Estimator $estimator * @param Transformer[] $transformers */ - public function __construct(Estimator $estimator, array $transformers) { - $this->estimator = $estimator; - $this->transformers = $transformers; + public function __construct( + private Estimator $estimator, + private IExtractor $extractor, + private array $transformers, + ) { } public function getEstimator(): Estimator { return $this->estimator; } + public function getExtractor(): IExtractor { + return $this->extractor; + } + /** * @return Transformer[] */ diff --git a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php index 197d0f2eb4..aaa30dcbfe 100644 --- a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php @@ -3,7 +3,7 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ @@ -11,20 +11,34 @@ use OCA\Mail\Account; use OCA\Mail\Db\Message; +use Rubix\ML\Transformers\TfIdfTransformer; +use Rubix\ML\Transformers\WordCountVectorizer; use function OCA\Mail\array_flat_map; /** * Combines a set of DI'ed extractors so they can be used as one class */ -abstract class CompositeExtractor implements IExtractor { - /** @var IExtractor[] */ - protected array $extractors; +class CompositeExtractor implements 
IExtractor { + private readonly SubjectExtractor $subjectExtractor; - /** - * @param IExtractor[] $extractors - */ - public function __construct(array $extractors) { - $this->extractors = $extractors; + /** @var IExtractor[] */ + private readonly array $extractors; + + public function __construct( + ImportantMessagesExtractor $ex1, + ReadMessagesExtractor $ex2, + RepliedMessagesExtractor $ex3, + SentMessagesExtractor $ex4, + SubjectExtractor $ex5, + ) { + $this->subjectExtractor = $ex5; + $this->extractors = [ + $ex1, + $ex2, + $ex3, + $ex4, + $ex5, + ]; } public function prepare(Account $account, @@ -36,12 +50,14 @@ public function prepare(Account $account, } } - /** - * @inheritDoc - */ public function extract(Message $message): array { return array_flat_map(static function (IExtractor $extractor) use ($message) { return $extractor->extract($message); }, $this->extractors); } + + public function getSubjectExtractor(): SubjectExtractor { + return $this->subjectExtractor; + } } + diff --git a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php deleted file mode 100644 index b7d86622ca..0000000000 --- a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php +++ /dev/null @@ -1,24 +0,0 @@ -subjectExtractor = $ex2; - } - - public function getSubjectExtractor(): SubjectExtractor { - return $this->subjectExtractor; - } -} diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index b768938139..c5ecf6681f 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -49,7 +49,7 @@ public function setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer) $this->limitFeatureSize(); } - public function getTfidf(): Transformer { + public function getTfidf(): TfIdfTransformer { 
return $this->tfidf; } diff --git a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php deleted file mode 100644 index 09351f33f6..0000000000 --- a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php +++ /dev/null @@ -1,24 +0,0 @@ -mailboxMapper = $mailboxMapper; $this->messageMapper = $messageMapper; $this->persistenceService = $persistenceService; $this->performanceLogger = $performanceLogger; $this->rulesClassifier = $rulesClassifier; - $this->vanillaExtractor = $vanillaExtractor; $this->container = $container; } @@ -209,6 +204,8 @@ public function buildDataSet( * @param bool $shuffleDataSet Shuffle the data set before training * @param bool $persist Persist the trained classifier to use it for message classification * + * @return Estimator|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained + * * @throws ServiceException */ public function train( @@ -218,12 +215,12 @@ public function train( ?Closure $estimator = null, bool $shuffleDataSet = false, bool $persist = true, - ): void { + ): ?Estimator { $perf = $this->performanceLogger->start('importance classifier training'); if ($extractor === null) { try { - $extractor = $this->container->get(NewCompositeExtractor::class); + $extractor = $this->container->get(CompositeExtractor::class); } catch (ContainerExceptionInterface $e) { throw new ServiceException('Default extractor is not available', 0, $e); } @@ -231,10 +228,10 @@ public function train( $dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet); if ($dataSet === null) { - return; + return null; } - $this->trainWithCustomDataSet( + return $this->trainWithCustomDataSet( $account, $logger, $dataSet, @@ -256,24 +253,21 @@ public function train( * @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task * @param bool $persist 
Persist the trained classifier to use it for message classification * + * @return Estimator|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained + * * @throws ServiceException */ public function trainWithCustomDataSet( Account $account, LoggerInterface $logger, array $dataSet, - IExtractor $extractor, + CompositeExtractor $extractor, ?Closure $estimator, ?PerformanceLoggerTask $perf = null, bool $persist = true, - ): void { + ): ?Estimator { $perf ??= $this->performanceLogger->start('importance classifier training'); - - if ($estimator === null) { - $estimator = static function () { - return self::createDefaultEstimator(); - }; - } + $estimator ??= self::createDefaultEstimator(...); /** * How many of the most recent messages are excluded from training? @@ -303,7 +297,7 @@ public function trainWithCustomDataSet( if ($validationSet === [] || $trainingSet === []) { $logger->info('not enough messages to train a classifier'); $perf->end(); - return; + return null; } /** @var Learner&Estimator&Persistable $validationEstimator */ @@ -321,30 +315,28 @@ public function trainWithCustomDataSet( 'exception' => $e, ]); $perf->end(); - return; + return null; } $perf->step('train and validate classifier with training and validation sets'); - if ($persist) { - /** @var Learner&Estimator&Persistable $persistedEstimator */ - $persistedEstimator = $estimator(); - $this->trainClassifier($persistedEstimator, $dataSet); - $perf->step('train classifier with full data set'); - - // Extract persisted transformers of the subject extractor. - // Is a bit hacky but a full abstraction would be overkill. 
- /** @var (Transformer&Persistable)[] $transformers */ - $transformers = []; - if ($extractor instanceof NewCompositeExtractor) { - $transformers[] = $extractor->getSubjectExtractor()->getWordCountVectorizer(); - $transformers[] = $extractor->getSubjectExtractor()->getTfidf(); - } - - $classifier->setAccountId($account->getId()); - $classifier->setDuration($perf->end()); - $this->persistenceService->persist($classifier, $persistedEstimator, $transformers); - $logger->debug("classifier {$classifier->getId()} persisted"); + if (!$persist) { + return $validationEstimator; } + + /** @var Learner&Estimator&Persistable $persistedEstimator */ + $persistedEstimator = $estimator(); + $this->trainClassifier($persistedEstimator, $dataSet); + $perf->step('train classifier with full data set'); + $classifier->setDuration($perf->end()); + $classifier->setAccountId($account->getId()); + $classifier->setEstimator(get_class($persistedEstimator)); + $classifier->setPersistenceVersion(PersistenceService::VERSION); + + $this->persistenceService->persist($account, $persistedEstimator, $extractor); + $logger->debug("Classifier for account {$account->getId()} persisted", [ + 'classifier' => $classifier, + ]); + return $persistedEstimator; } diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index 4b4eea5345..120136b27e 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -3,162 +3,78 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Service\Classification; -use OCA\DAV\Connector\Sabre\File; use OCA\Mail\Account; -use OCA\Mail\AppInfo\Application; -use OCA\Mail\Db\Classifier; -use OCA\Mail\Db\ClassifierMapper; -use OCA\Mail\Db\MailAccountMapper; use 
OCA\Mail\Exception\ServiceException; use OCA\Mail\Model\ClassifierPipeline; +use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; -use OCP\App\IAppManager; -use OCP\AppFramework\Db\DoesNotExistException; -use OCP\AppFramework\Utility\ITimeFactory; -use OCP\Files; -use OCP\Files\IAppData; -use OCP\Files\NotFoundException; -use OCP\Files\NotPermittedException; +use OCP\ICache; use OCP\ICacheFactory; -use OCP\ITempManager; use Psr\Container\ContainerExceptionInterface; use Psr\Container\ContainerInterface; -use Psr\Log\LoggerInterface; use Rubix\ML\Learner; use Rubix\ML\Persistable; use Rubix\ML\PersistentModel; -use Rubix\ML\Persisters\Filesystem; use Rubix\ML\Serializers\RBX; use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; -use function file_get_contents; -use function file_put_contents; use function get_class; -use function strlen; class PersistenceService { - private const ADD_DATA_FOLDER = 'classifiers'; + // Increment the version when changing the classifier or transformer pipeline + public const VERSION = 1; - /** @var ClassifierMapper */ - private $mapper; - - /** @var IAppData */ - private $appData; - - /** @var ITempManager */ - private $tempManager; - - /** @var ITimeFactory */ - private $timeFactory; - - /** @var IAppManager */ - private $appManager; - - /** @var ICacheFactory */ - private $cacheFactory; - - /** @var LoggerInterface */ - private $logger; - - /** @var MailAccountMapper */ - private $accountMapper; - - private ContainerInterface $container; - - public function __construct(ClassifierMapper $mapper, - IAppData $appData, - 
ITempManager $tempManager, - ITimeFactory $timeFactory, - IAppManager $appManager, - ICacheFactory $cacheFactory, - LoggerInterface $logger, - MailAccountMapper $accountMapper, - ContainerInterface $container) { - $this->mapper = $mapper; - $this->appData = $appData; - $this->tempManager = $tempManager; - $this->timeFactory = $timeFactory; - $this->appManager = $appManager; - $this->cacheFactory = $cacheFactory; - $this->logger = $logger; - $this->accountMapper = $accountMapper; - $this->container = $container; + public function __construct( + private readonly ICacheFactory $cacheFactory, + private readonly ContainerInterface $container, + ) { } /** - * Persist the classifier data to the database, the estimator and its transformers to storage + * Persist classifier, estimator and its transformers to the memory cache. * - * @param Classifier $classifier * @param Learner&Persistable $estimator - * @param (Transformer&Persistable)[] $transformers * - * @throws ServiceException + * @throws ServiceException If any serialization fails */ - public function persist(Classifier $classifier, + public function persist( + Account $account, Learner $estimator, - array $transformers): void { - /* - * First we have to insert the row to get the unique ID, but disable - * it until the model is persisted as well. Otherwise another process - * might try to load the model in the meantime and run into an error - * due to the missing data in app data. 
- */ - $classifier->setAppVersion($this->appManager->getAppVersion(Application::APP_ID)); - $classifier->setEstimator(get_class($estimator)); - $classifier->setActive(false); - $classifier->setCreatedAt($this->timeFactory->getTime()); - $this->mapper->insert($classifier); + CompositeExtractor $extractor, + ): void { + $serializedData = []; /* - * Then we serialize the estimator into a temporary file + * First we serialize the estimator */ - $tmpPath = $this->tempManager->getTemporaryFile(); try { - $model = new PersistentModel($estimator, new Filesystem($tmpPath)); + $persister = new RubixMemoryPersister(); + $model = new PersistentModel($estimator, $persister); $model->save(); - $serializedClassifier = file_get_contents($tmpPath); - $this->logger->debug('Serialized classifier written to tmp file (' . strlen($serializedClassifier) . 'B'); + $serializedData[] = $persister->getData(); } catch (RuntimeException $e) { throw new ServiceException('Could not serialize classifier: ' . $e->getMessage(), 0, $e); } /* - * Then we store the serialized model to app data - */ - try { - try { - $folder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $this->logger->debug('Using existing folder for the serialized classifier'); - } catch (NotFoundException $e) { - $folder = $this->appData->newFolder(self::ADD_DATA_FOLDER); - $this->logger->debug('New folder created for serialized classifiers'); - } - $file = $folder->newFile((string)$classifier->getId()); - $file->putContent($serializedClassifier); - $this->logger->debug('Serialized classifier written to app data'); - } catch (NotPermittedException|NotFoundException $e) { - throw new ServiceException('Could not create classifiers directory: ' . 
$e->getMessage(), 0, $e); - } - - /* - * Then we serialize the transformer pipeline to temporary files + * Then we serialize the transformer pipeline */ - $transformerIndex = 0; + $transfomers = [ + $extractor->getSubjectExtractor()->getWordCountVectorizer(), + $extractor->getSubjectExtractor()->getTfIdf(), + ]; $serializer = new RBX(); - foreach ($transformers as $transformer) { - $tmpPath = $this->tempManager->getTemporaryFile(); + foreach ($transfomers as $transformer) { try { + $persister = new RubixMemoryPersister(); /** * This is how to serialize a transformer according to the official docs. * PersistentModel can only be used on Learners which transformers don't implement. @@ -167,316 +83,145 @@ public function persist(Classifier $classifier, * * @psalm-suppress InternalMethod */ - $serializer->serialize($transformer)->saveTo(new Filesystem($tmpPath)); - $serializedTransformer = file_get_contents($tmpPath); - $this->logger->debug('Serialized transformer written to tmp file (' . strlen($serializedTransformer) . 'B'); + $serializer->serialize($transformer)->saveTo($persister); + $serializedData[] = $persister->getData(); } catch (RuntimeException $e) { throw new ServiceException('Could not serialize transformer: ' . $e->getMessage(), 0, $e); } - - try { - $file = $folder->newFile("{$classifier->getId()}_t$transformerIndex"); - $file->putContent($serializedTransformer); - $this->logger->debug("Serialized transformer $transformerIndex written to app data"); - } catch (NotPermittedException|NotFoundException $e) { - throw new ServiceException( - "Failed to persist transformer $transformerIndex: " . 
$e->getMessage(), - 0, - $e - ); - } - - $transformerIndex++; } - /* - * Now we set the model active so it can be used by the next request - */ - $classifier->setActive(true); - $this->mapper->update($classifier); + $this->setCached((string)$account->getId(), $serializedData); } /** - * @param Account $account - * - * @return ?array [Estimator, IExtractor] + * Load the latest estimator and its transformers. * - * @throws ServiceException + * @throws ServiceException If any deserialization fails */ - public function loadLatest(Account $account): ?array { - try { - $latestModel = $this->mapper->findLatest($account->getId()); - } catch (DoesNotExistException $e) { + public function loadLatest(Account $account): ?ClassifierPipeline { + $cached = $this->getCached((string)$account->getId()); + if ($cached == null) { return null; } - $pipeline = $this->load($latestModel); + $serializedModel = $cached[0]; + $serializedTransformers = array_slice($cached, 1); try { - $extractor = $this->loadExtractor($latestModel, $pipeline); - } catch (ContainerExceptionInterface $e) { + $estimator = PersistentModel::load(new RubixMemoryPersister($serializedModel)); + } catch (RuntimeException $e) { throw new ServiceException( - "Failed to load extractor: {$e->getMessage()}", + 'Could not deserialize persisted classifier: ' . 
$e->getMessage(), 0, $e, ); } - return [$pipeline->getEstimator(), $extractor]; - } - - /** - * Load an estimator and its transformers of a classifier from storage - * - * @param Classifier $classifier - * @return ClassifierPipeline - * - * @throws ServiceException - */ - public function load(Classifier $classifier): ClassifierPipeline { - $transformerCount = 0; - $appVersion = $this->parseAppVersion($classifier->getAppVersion()); - if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { - $transformerCount = 2; - } - - $id = $classifier->getId(); - $cached = $this->getCached($classifier->getId(), $transformerCount); - if ($cached !== null) { - $this->logger->debug("Using cached serialized classifier $id"); - $serialized = $cached[0]; - $serializedTransformers = array_slice($cached, 1); - } else { - $this->logger->debug('Loading serialized classifier from app data'); - try { - $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $modelFile = $modelsFolder->getFile((string)$id); - } catch (NotFoundException $e) { - $this->logger->debug("Could not load classifier $id: " . $e->getMessage()); - throw new ServiceException("Could not load classifier $id: " . $e->getMessage(), 0, $e); - } - - try { - $serialized = $modelFile->getContent(); - } catch (NotFoundException|NotPermittedException $e) { - $this->logger->debug("Could not load content for model file with classifier id $id: " . $e->getMessage()); - throw new ServiceException("Could not load content for model file with classifier id $id: " . $e->getMessage(), 0, $e); - } - $size = strlen($serialized); - $this->logger->debug("Serialized classifier loaded (size=$size)"); - - $serializedTransformers = []; - for ($i = 0; $i < $transformerCount; $i++) { - try { - $transformerFile = $modelsFolder->getFile("{$id}_t$i"); - } catch (NotFoundException $e) { - $this->logger->debug("Could not load transformer $i of classifier $id: " . 
$e->getMessage()); - throw new ServiceException("Could not load transformer $i of classifier $id: " . $e->getMessage(), 0, $e); - } - - try { - $serializedTransformer = $transformerFile->getContent(); - } catch (NotFoundException|NotPermittedException $e) { - $this->logger->debug("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage()); - throw new ServiceException("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage(), 0, $e); - } - $size = strlen($serializedTransformer); - $this->logger->debug("Serialized transformer $i loaded (size=$size)"); - $serializedTransformers[] = $serializedTransformer; - } - - $this->cache($id, $serialized, $serializedTransformers); - } - - $tmpPath = $this->tempManager->getTemporaryFile(); - file_put_contents($tmpPath, $serialized); - try { - $estimator = PersistentModel::load(new Filesystem($tmpPath)); - } catch (RuntimeException $e) { - throw new ServiceException("Could not deserialize persisted classifier $id: " . $e->getMessage(), 0, $e); - } - - $transformers = array_map(function (string $serializedTransformer) use ($id) { - $serializer = new RBX(); - $tmpPath = $this->tempManager->getTemporaryFile(); - file_put_contents($tmpPath, $serializedTransformer); + $serializer = new RBX(); + $transformers = array_map(function (string $serializedTransformer) use ($serializer) { try { - $persister = new Filesystem($tmpPath); + $persister = new RubixMemoryPersister($serializedTransformer); $transformer = $persister->load()->deserializeWith($serializer); } catch (RuntimeException $e) { - throw new ServiceException("Could not deserialize persisted transformer of classifier $id: " . $e->getMessage(), 0, $e); + throw new ServiceException( + 'Could not deserialize persisted transformer of classifier: ' . $e->getMessage(), + 0, + $e, + ); } if (!($transformer instanceof Transformer)) { - throw new ServiceException("Transformer of classifier $id is not a transformer: Got " . 
$transformer::class); + throw new ServiceException(sprintf( + 'Transformer is not an instance of %s: Got %s', + Transformer::class, + get_class($transformer), + )); } return $transformer; }, $serializedTransformers); - return new ClassifierPipeline($estimator, $transformers); - } - - public function cleanUp(): void { - $threshold = $this->timeFactory->getTime() - 30 * 24 * 60 * 60; - $totalAccounts = $this->accountMapper->getTotal(); - $classifiers = $this->mapper->findHistoric($threshold, $totalAccounts * 10); - foreach ($classifiers as $classifier) { - try { - $this->deleteModel($classifier->getId()); - $this->mapper->delete($classifier); - } catch (NotPermittedException $e) { - // Log and continue. This is not critical - $this->logger->warning('Could not clean-up old classifier', [ - 'id' => $classifier->getId(), - 'exception' => $e, - ]); - } - } - } - - /** - * @throws NotPermittedException - */ - private function deleteModel(int $id): void { - $this->logger->debug('Deleting serialized classifier from app data', [ - 'id' => $id, - ]); - try { - $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $modelFile = $modelsFolder->getFile((string)$id); - $modelFile->delete(); - } catch (NotFoundException $e) { - $this->logger->debug("Classifier model $id does not exist", [ - 'exception' => $e, - ]); - } - } - - /** - * Load and instantiate extractor based on a classifier's app version. 
- * - * @param Classifier $classifier - * @param ClassifierPipeline $pipeline - * @return IExtractor - * - * @throws ContainerExceptionInterface - * @throws ServiceException - */ - private function loadExtractor(Classifier $classifier, - ClassifierPipeline $pipeline): IExtractor { - $appVersion = $this->parseAppVersion($classifier->getAppVersion()); - if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { - return $this->loadExtractorV2($pipeline->getTransformers()); - } - - return $this->loadExtractorV1($pipeline->getTransformers()); - } + $extractor = $this->loadExtractor($transformers); - /** - * @return VanillaCompositeExtractor - * - * @throws ContainerExceptionInterface - */ - private function loadExtractorV1(): VanillaCompositeExtractor { - return $this->container->get(VanillaCompositeExtractor::class); + return new ClassifierPipeline($estimator, $extractor, $transformers); } /** - * @param Transformer[] $transformers - * @return NewCompositeExtractor + * Load and instantiate extractor based on the given transformers. * - * @throws ContainerExceptionInterface - * @throws ServiceException + * @throws ServiceException If the transformers array contains unexpected instances or the composite extractor can't be instantiated */ - private function loadExtractorV2(array $transformers): NewCompositeExtractor { + private function loadExtractor(array $transformers): IExtractor { $wordCountVectorizer = $transformers[0]; if (!($wordCountVectorizer instanceof WordCountVectorizer)) { - throw new ServiceException('Failed to load persisted transformer: Expected ' . WordCountVectorizer::class . ', got' . $wordCountVectorizer::class); + throw new ServiceException(sprintf( + 'Failed to load persisted transformer: Expected %s, got %s', + WordCountVectorizer::class, + get_class($wordCountVectorizer), + )); } + $tfidfTransformer = $transformers[1]; if (!($tfidfTransformer instanceof TfIdfTransformer)) { - throw new ServiceException('Failed to load persisted transformer: Expected ' . 
TfIdfTransformer::class . ', got' . $tfidfTransformer::class); + throw new ServiceException(sprintf( + 'Failed to load persisted transformer: Expected %s, got %s', + TfIdfTransformer::class, + get_class($tfidfTransformer), + )); } - $subjectExtractor = new SubjectExtractor(); - $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); - $subjectExtractor->setTfidf($tfidfTransformer); - return new NewCompositeExtractor( - $this->container->get(VanillaCompositeExtractor::class), - $subjectExtractor, - ); - } + try { + /** @var CompositeExtractor $extractor */ + $extractor = $this->container->get(CompositeExtractor::class); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException('Failed to instantiate the composite extractor', 0, $e); + } - private function getCacheKey(int $id): string { - return "mail_classifier_$id"; + $extractor->getSubjectExtractor()->setWordCountVectorizer($wordCountVectorizer); + $extractor->getSubjectExtractor()->setTfidf($tfidfTransformer); + return $extractor; } - private function getTransformerCacheKey(int $id, int $index): string { - return $this->getCacheKey($id) . "_transformer_$index"; + private function getCacheInstance(): ?ICache { + if (!$this->cacheFactory->isAvailable()) { + return null; + } + + $version = self::VERSION; + return $this->cacheFactory->createDistributed("mail-classifier/v$version/"); } /** - * @param int $id - * @param int $transformerCount - * - * @return (?string)[]|null Array of serialized classifier and transformers + * @return string[]|null Array of serialized classifier and transformers */ - private function getCached(int $id, int $transformerCount): ?array { - // FIXME: Will always return null as the cached, serialized data is always an empty string. - // See my note in self::cache() for further elaboration. 
- - if (!$this->cacheFactory->isLocalCacheAvailable()) { + private function getCached(string $id): ?array { + $cache = $this->getCacheInstance(); + if ($cache === null) { return null; } - $cache = $this->cacheFactory->createLocal(); - $values = []; - $values[] = $cache->get($this->getCacheKey($id)); - for ($i = 0; $i < $transformerCount; $i++) { - $values[] = $cache->get($this->getTransformerCacheKey($id, $i)); - } - - // Only return cached values if estimator and all transformers are available - if (in_array(null, $values, true)) { + $json = $cache->get($id); + if (!is_string($json)) { return null; } - return $values; - } - - private function cache(int $id, string $serialized, array $serializedTransformers): void { - // FIXME: This is broken as some cache implementations will run the provided value through - // json_encode which drops non-utf8 strings. The serialized string contains binary - // data so an empty string will be saved instead (tested on Redis). - // Note: JSON requires strings to be valid utf8 (as per its spec). - - // IDEA: Implement a method ICache::setRaw() that forwards a raw/binary string as is to the - // underlying cache backend. - - if (!$this->cacheFactory->isLocalCacheAvailable()) { - return; - } - $cache = $this->cacheFactory->createLocal(); - $cache->set($this->getCacheKey($id), $serialized); - - $transformerIndex = 0; - foreach ($serializedTransformers as $transformer) { - $cache->set($this->getTransformerCacheKey($id, $transformerIndex), $transformer); - $transformerIndex++; - } + $serializedData = json_decode($json); + return array_map(base64_decode(...), $serializedData); } /** - * Parse minor and major part of the given semver string. 
- * - * @return int[] + * @param string[] $serializedData Array of serialized classifier and transformers */ - private function parseAppVersion(string $version): array { - $parts = explode('.', $version); - if (count($parts) < 2) { - return [0, 0]; + private function setCached(string $id, array $serializedData): void { + $cache = $this->getCacheInstance(); + if ($cache === null) { + return; } - return [(int)$parts[0], (int)$parts[1]]; + // Serialized data contains binary, non-utf8 data so we encode it as base64 first + $encodedData = array_map(base64_encode(...), $serializedData); + $json = json_encode($encodedData, JSON_THROW_ON_ERROR); + + // Set a ttl of a week because a new model will be generated daily + $cache->set($id, $json, 3600 * 24 * 7); } } diff --git a/lib/Service/Classification/RubixMemoryPersister.php b/lib/Service/Classification/RubixMemoryPersister.php new file mode 100644 index 0000000000..2b170b38b5 --- /dev/null +++ b/lib/Service/Classification/RubixMemoryPersister.php @@ -0,0 +1,41 @@ +data; + } + + public function save(Encoding $encoding): void { + $this->data = $encoding->data(); + } + + public function load(): Encoding { + if ($this->data === null) { + throw new ValueError('Trying to load encoding when no data is available'); + } + + return new Encoding($this->data); + } + + public function __toString() { + return (string)self::class; + } +} diff --git a/lib/Service/CleanupService.php b/lib/Service/CleanupService.php index 65fa421200..74f5c0070a 100644 --- a/lib/Service/CleanupService.php +++ b/lib/Service/CleanupService.php @@ -17,7 +17,6 @@ use OCA\Mail\Db\MessageRetentionMapper; use OCA\Mail\Db\MessageSnoozeMapper; use OCA\Mail\Db\TagMapper; -use OCA\Mail\Service\Classification\PersistenceService; use OCA\Mail\Support\PerformanceLogger; use OCP\AppFramework\Utility\ITimeFactory; use Psr\Log\LoggerInterface; @@ -44,7 +43,6 @@ class CleanupService { private MessageSnoozeMapper $messageSnoozeMapper; - private PersistenceService 
$classifierPersistenceService; private ITimeFactory $timeFactory; public function __construct(MailAccountMapper $mailAccountMapper, @@ -55,7 +53,6 @@ public function __construct(MailAccountMapper $mailAccountMapper, TagMapper $tagMapper, MessageRetentionMapper $messageRetentionMapper, MessageSnoozeMapper $messageSnoozeMapper, - PersistenceService $classifierPersistenceService, ITimeFactory $timeFactory) { $this->aliasMapper = $aliasMapper; $this->mailboxMapper = $mailboxMapper; @@ -64,7 +61,6 @@ public function __construct(MailAccountMapper $mailAccountMapper, $this->tagMapper = $tagMapper; $this->messageRetentionMapper = $messageRetentionMapper; $this->messageSnoozeMapper = $messageSnoozeMapper; - $this->classifierPersistenceService = $classifierPersistenceService; $this->mailAccountMapper = $mailAccountMapper; $this->timeFactory = $timeFactory; } @@ -92,8 +88,6 @@ public function cleanUp(LoggerInterface $logger): void { $task->step('delete expired messages'); $this->messageSnoozeMapper->deleteOrphans(); $task->step('delete orphan snoozes'); - $this->classifierPersistenceService->cleanUp(); - $task->step('delete orphan classifiers'); $task->end(); } } From 379484ae2ad3a0db6cf568312c18c5ffc7ecf6b6 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 22 Oct 2024 10:01:47 +0200 Subject: [PATCH 33/37] Revert "Adjust tuned knn params" This reverts commit fb2475f9e8303af2a8b2e7485ea21b35ec4c0bcf. 
--- lib/Service/Classification/ImportanceClassifier.php | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 7cc499c782..0f8ed16983 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -32,7 +32,7 @@ use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Estimator; -use Rubix\ML\Kernels\Distance\Jaccard; +use Rubix\ML\Kernels\Distance\Manhattan; use Rubix\ML\Learner; use Rubix\ML\Persistable; use RuntimeException; @@ -125,8 +125,10 @@ public function __construct(MailboxMapper $mailboxMapper, private static function createDefaultEstimator(): KNearestNeighbors { // A meta estimator was trained on the same data multiple times to average out the // variance of the trained model. - // Parameters were chosen from the best configuration across 20 runs. - return new KNearestNeighbors(35, true, new Jaccard()); + // Parameters were chosen from the best configuration across 100 runs. + // Both variance (spread) and f1 score were considered. + // Note: Lower k values yield slightly higher f1 scores but show higher variances. 
+ return new KNearestNeighbors(15, true, new Manhattan()); } private function filterMessageHasSenderEmail(Message $message): bool { From 7da784ebd210528652868f8739401b094734ecf2 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 22 Oct 2024 10:38:50 +0200 Subject: [PATCH 34/37] Finalize code changes --- lib/Command/TrainAccount.php | 13 --------- lib/Model/ClassifierPipeline.php | 16 ++--------- .../FeatureExtraction/SubjectExtractor.php | 28 ++++--------------- .../Classification/ImportanceClassifier.php | 10 ++----- .../Classification/PersistenceService.php | 4 +-- 5 files changed, 11 insertions(+), 60 deletions(-) diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index 12866de7f3..28963f496b 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -63,19 +63,6 @@ protected function configure() { $this->setName('mail:account:train'); $this->setDescription('Train the classifier of new messages'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); - $this->addOption( - self::ARGUMENT_OLD, - null, - null, - 'Use old vanilla composite extractor and GaussianNB estimator (implies --old-extractor and --old-estimator)' - ); - $this->addOption( - self::ARGUMENT_OLD_EXTRACTOR, - null, - null, - 'Use old vanilla composite extractor without text based features' - ); - $this->addOption(self::ARGUMENT_OLD_ESTIMATOR, null, null, 'Use old GaussianNB estimator'); $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); $this->addOption( self::ARGUMENT_DRY_RUN, diff --git a/lib/Model/ClassifierPipeline.php b/lib/Model/ClassifierPipeline.php index bcc1be893e..f58d40ba89 100644 --- a/lib/Model/ClassifierPipeline.php +++ b/lib/Model/ClassifierPipeline.php @@ -11,16 +11,11 @@ use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use Rubix\ML\Estimator; -use Rubix\ML\Transformers\Transformer; class ClassifierPipeline { - /** - * @param Transformer[] $transformers - 
*/ public function __construct( - private Estimator $estimator, - private IExtractor $extractor, - private array $transformers, + private readonly Estimator $estimator, + private readonly IExtractor $extractor, ) { } @@ -31,11 +26,4 @@ public function getEstimator(): Estimator { public function getExtractor(): IExtractor { return $this->extractor; } - - /** - * @return Transformer[] - */ - public function getTransformers(): array { - return $this->transformers; - } } diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index c5ecf6681f..374a45a689 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -13,9 +13,7 @@ use OCA\Mail\Db\Message; use Rubix\ML\Datasets\Labeled; use Rubix\ML\Datasets\Unlabeled; -use Rubix\ML\Transformers\MinMaxNormalizer; use Rubix\ML\Transformers\MultibyteTextNormalizer; -use Rubix\ML\Transformers\PrincipalComponentAnalysis; use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; @@ -25,19 +23,13 @@ class SubjectExtractor implements IExtractor { private WordCountVectorizer $wordCountVectorizer; - private Transformer $dimensionalReductionTransformer; - private Transformer $normalizer; private Transformer $tfidf; private int $max = -1; public function __construct() { // Limit vocabulary to limit memory usage - $vocabSize = 500; - $this->wordCountVectorizer = new WordCountVectorizer($vocabSize); - + $this->wordCountVectorizer = new WordCountVectorizer(500); $this->tfidf = new TfIdfTransformer(); - //$this->dimensionalReductionTransformer = new PrincipalComponentAnalysis((int)($vocabSize * 0.1)); - //$this->normalizer = new MinMaxNormalizer(); } public function getWordCountVectorizer(): WordCountVectorizer { @@ -49,7 +41,7 @@ public function 
setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer) $this->limitFeatureSize(); } - public function getTfidf(): TfIdfTransformer { + public function getTfIdf(): TfIdfTransformer { return $this->tfidf; } @@ -57,9 +49,6 @@ public function setTfidf(TfIdfTransformer $tfidf): void { $this->tfidf = $tfidf; } - /** - * @inheritDoc - */ public function prepare(Account $account, array $incomingMailboxes, array $outgoingMailboxes, array $messages): void { /** @var array> $data */ $data = array_map(static function (Message $message) { @@ -76,15 +65,11 @@ public function prepare(Account $account, array $incomingMailboxes, array $outgo ) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->tfidf) - ;//->apply($this->dimensionalReductionTransformer); + ->apply($this->tfidf); $this->limitFeatureSize(); } - /** - * @inheritDoc - */ public function extract(Message $message): array { $sender = $message->getFrom()->first(); if ($sender === null) { @@ -97,8 +82,7 @@ public function extract(Message $message): array { $trainDataSet = Unlabeled::build([[$trainText]]) ->apply(new MultibyteTextNormalizer()) ->apply($this->wordCountVectorizer) - ->apply($this->tfidf) - ;//->apply($this->dimensionalReductionTransformer); + ->apply($this->tfidf); // Use zeroed vector if no features could be extracted if ($trainDataSet->numFeatures() === 0) { @@ -107,13 +91,11 @@ public function extract(Message $message): array { $textFeatures = $trainDataSet->sample(0); } - //var_dump($textFeatures); - return $textFeatures; } /** - * Limit feature vector length to actual vocabulary size. + * Limit feature vector length to actual size of vocabulary. 
*/ private function limitFeatureSize(): void { $vocabularies = $this->wordCountVectorizer->vocabularies(); diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index 0f8ed16983..df05a355b2 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -349,16 +349,12 @@ public function trainWithCustomDataSet( */ private function getIncomingMailboxes(Account $account): array { return array_filter($this->mailboxMapper->findAll($account), static function (Mailbox $mailbox) { - return $mailbox->isInbox(); - - /* foreach (self::EXEMPT_FROM_TRAINING as $excluded) { if ($mailbox->isSpecialUse($excluded)) { return false; } } return true; - */ }); } @@ -407,7 +403,6 @@ private function getFeaturesAndImportance(Account $account, } $features = $extractor->extract($message); - //var_dump($features); return [ 'features' => $features, @@ -455,16 +450,15 @@ public function classifyImportance(Account $account, ); } - [$estimator, $extractor] = $pipeline; $messagesWithSender = array_filter($messages, [$this, 'filterMessageHasSenderEmail']); $features = $this->getFeaturesAndImportance( $account, $this->getIncomingMailboxes($account), $this->getOutgoingMailboxes($account), $messagesWithSender, - $extractor, + $pipeline->getExtractor(), ); - $predictions = $estimator->predict( + $predictions = $pipeline->getEstimator()->predict( Unlabeled::build(array_column($features, 'features')) ); return array_combine( diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index 120136b27e..d25c40fb53 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -67,12 +67,12 @@ public function persist( /* * Then we serialize the transformer pipeline */ - $transfomers = [ + $transformers = [ $extractor->getSubjectExtractor()->getWordCountVectorizer(), 
$extractor->getSubjectExtractor()->getTfIdf(), ]; $serializer = new RBX(); - foreach ($transfomers as $transformer) { + foreach ($transformers as $transformer) { try { $persister = new RubixMemoryPersister(); /** From 7b62b394c00c91bea285a12eee3c371f3fa91e64 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 22 Oct 2024 10:39:11 +0200 Subject: [PATCH 35/37] Run composer cs:fix --- .../Classification/FeatureExtraction/CompositeExtractor.php | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php index aaa30dcbfe..54469bc220 100644 --- a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php @@ -11,8 +11,6 @@ use OCA\Mail\Account; use OCA\Mail\Db\Message; -use Rubix\ML\Transformers\TfIdfTransformer; -use Rubix\ML\Transformers\WordCountVectorizer; use function OCA\Mail\array_flat_map; /** @@ -60,4 +58,3 @@ public function getSubjectExtractor(): SubjectExtractor { return $this->subjectExtractor; } } - From 21d45eb99d4c0073fb9400b9144d9f14a7208d40 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 22 Oct 2024 12:08:31 +0200 Subject: [PATCH 36/37] Fix all remaining psalm issues --- lib/Command/RunMetaEstimator.php | 57 ++------------ lib/Command/TrainAccount.php | 76 +++---------------- .../FeatureExtraction/SubjectExtractor.php | 3 +- .../Classification/ImportanceClassifier.php | 26 ++++--- .../Classification/PersistenceService.php | 2 +- .../Classification/RubixMemoryPersister.php | 11 +-- 6 files changed, 36 insertions(+), 139 deletions(-) diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index 32d1187d37..8519ed84f4 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -10,12 +10,10 @@ namespace OCA\Mail\Command; use OCA\Mail\Service\AccountService; -use
OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; use OCP\IConfig; -use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Rubix\ML\Backends\Amp; use Rubix\ML\Classifiers\KNearestNeighbors; @@ -28,25 +26,21 @@ use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; -use Symfony\Component\Console\Input\InputOption; use Symfony\Component\Console\Output\OutputInterface; class RunMetaEstimator extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; public const ARGUMENT_SHUFFLE = 'shuffle'; - public const ARGUMENT_LOAD_DATA = 'load-data'; private AccountService $accountService; private LoggerInterface $logger; private ImportanceClassifier $classifier; - private ContainerInterface $container; private IConfig $config; public function __construct( AccountService $accountService, LoggerInterface $logger, ImportanceClassifier $classifier, - ContainerInterface $container, IConfig $config, ) { parent::__construct(); @@ -54,7 +48,6 @@ public function __construct( $this->accountService = $accountService; $this->logger = $logger; $this->classifier = $classifier; - $this->container = $container; $this->config = $config; } @@ -63,12 +56,6 @@ protected function configure(): void { $this->setDescription('Run the meta estimator for an account'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); $this->addOption(self::ARGUMENT_SHUFFLE, null, null, 'Shuffle data set before training'); - $this->addOption( - self::ARGUMENT_LOAD_DATA, - null, - InputOption::VALUE_REQUIRED, - 'Load training data set from a JSON file' - ); } public function isEnabled(): bool { @@ -86,26 +73,11 @@ protected function execute(InputInterface $input, OutputInterface $output): int return 1; } - /** @var 
CompositeExtractor $extractor */ - $extractor = $this->container->get(CompositeExtractor::class); $consoleLogger = new ConsoleLoggerDecorator( $this->logger, $output ); - if ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { - $json = file_get_contents($loadDataPath); - $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR); - } else { - $dataSet = $this->classifier->buildDataSet( - $account, - $extractor, - $consoleLogger, - null, - $shuffle, - ); - } - $estimator = static function () use ($consoleLogger) { $params = [ [5, 10, 15, 20, 25, 30, 35, 40], // Neighbors @@ -125,28 +97,15 @@ protected function execute(InputInterface $input, OutputInterface $output): int }; /** @var GridSearch $metaEstimator */ - if ($dataSet) { - $metaEstimator = $this->classifier->trainWithCustomDataSet( - $account, - $consoleLogger, - $dataSet, - $extractor, - $estimator, - null, - false, - ); - } else { - $metaEstimator = $this->classifier->train( - $account, - $consoleLogger, - $extractor, - $estimator, - $shuffle, - false, - ); - } + $metaEstimator = $this->classifier->train( + $account, + $consoleLogger, + $estimator, + $shuffle, + false, + ); - if ($metaEstimator) { + if ($metaEstimator !== null) { $output->writeln("Best estimator: {$metaEstimator->base()}"); } diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index 28963f496b..a3eaf94aa9 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -11,12 +11,9 @@ use OCA\Mail\Service\AccountService; use OCA\Mail\Service\Classification\ClassificationSettingsService; -use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; -use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; use Symfony\Component\Console\Command\Command; use 
Symfony\Component\Console\Input\InputArgument; @@ -27,39 +24,28 @@ class TrainAccount extends Command { public const ARGUMENT_ACCOUNT_ID = 'account-id'; - public const ARGUMENT_OLD = 'old'; - public const ARGUMENT_OLD_ESTIMATOR = 'old-estimator'; - public const ARGUMENT_OLD_EXTRACTOR = 'old-extractor'; public const ARGUMENT_SHUFFLE = 'shuffle'; - public const ARGUMENT_SAVE_DATA = 'save-data'; - public const ARGUMENT_LOAD_DATA = 'load-data'; public const ARGUMENT_DRY_RUN = 'dry-run'; public const ARGUMENT_FORCE = 'force'; private AccountService $accountService; private ImportanceClassifier $classifier; private LoggerInterface $logger; - private ContainerInterface $container; private ClassificationSettingsService $classificationSettingsService; public function __construct(AccountService $service, ImportanceClassifier $classifier, ClassificationSettingsService $classificationSettingsService, - LoggerInterface $logger, - ContainerInterface $container) { + LoggerInterface $logger) { parent::__construct(); $this->accountService = $service; $this->classifier = $classifier; $this->logger = $logger; - $this->container = $container; $this->classificationSettingsService = $classificationSettingsService; } - /** - * @return void - */ - protected function configure() { + protected function configure(): void { $this->setName('mail:account:train'); $this->setDescription('Train the classifier of new messages'); $this->addArgument(self::ARGUMENT_ACCOUNT_ID, InputArgument::REQUIRED); @@ -76,18 +62,6 @@ protected function configure() { null, 'Train an estimator even if the classification is disabled by the user' ); - $this->addOption( - self::ARGUMENT_SAVE_DATA, - null, - InputOption::VALUE_REQUIRED, - 'Save training data set to a JSON file' - ); - $this->addOption( - self::ARGUMENT_LOAD_DATA, - null, - InputOption::VALUE_REQUIRED, - 'Load training data set from a JSON file' - ); } protected function execute(InputInterface $input, OutputInterface $output): int { @@ -108,50 +82,18 @@ 
protected function execute(InputInterface $input, OutputInterface $output): int return 2; } - /** @var IExtractor $extractor */ - $extractor = $this->container->get(CompositeExtractor::class); - $consoleLogger = new ConsoleLoggerDecorator( $this->logger, $output ); - $dataSet = null; - if ($saveDataPath = $input->getOption(self::ARGUMENT_SAVE_DATA)) { - $dataSet = $this->classifier->buildDataSet( - $account, - $extractor, - $consoleLogger, - null, - $shuffle, - ); - $json = json_encode($dataSet, JSON_THROW_ON_ERROR); - file_put_contents($saveDataPath, $json); - } elseif ($loadDataPath = $input->getOption(self::ARGUMENT_LOAD_DATA)) { - $json = file_get_contents($loadDataPath); - $dataSet = json_decode($json, true, 512, JSON_THROW_ON_ERROR); - } - - if ($dataSet) { - $this->classifier->trainWithCustomDataSet( - $account, - $consoleLogger, - $dataSet, - $extractor, - null, - null, - !$dryRun - ); - } else { - $this->classifier->train( - $account, - $consoleLogger, - $extractor, - null, - $shuffle, - !$dryRun - ); - } + $this->classifier->train( + $account, + $consoleLogger, + null, + $shuffle, + !$dryRun + ); $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); $output->writeln('' . $mbs . 
'MB of memory used'); diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index 374a45a689..44bbca9e6f 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -15,7 +15,6 @@ use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Transformers\MultibyteTextNormalizer; use Rubix\ML\Transformers\TfIdfTransformer; -use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; use function array_column; @@ -23,7 +22,7 @@ class SubjectExtractor implements IExtractor { private WordCountVectorizer $wordCountVectorizer; - private Transformer $tfidf; + private TfIdfTransformer $tfidf; private int $max = -1; public function __construct() { diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index df05a355b2..fea58e4556 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -131,6 +131,17 @@ private static function createDefaultEstimator(): KNearestNeighbors { return new KNearestNeighbors(15, true, new Manhattan()); } + /** + * @throws ServiceException If the extractor is not available + */ + private function createExtractor(): CompositeExtractor { + try { + return $this->container->get(CompositeExtractor::class); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException('Default extractor is not available', 0, $e); + } + } + private function filterMessageHasSenderEmail(Message $message): bool { return $message->getFrom()->first() !== null && $message->getFrom()->first()->getEmail() !== null; } @@ -201,7 +212,6 @@ public function buildDataSet( * * @param Account $account * @param LoggerInterface $logger - * @param ?IExtractor $extractor The extractor to use for feature extraction. 
If null, the default extractor will be used. * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. * @param bool $shuffleDataSet Shuffle the data set before training * @param bool $persist Persist the trained classifier to use it for message classification @@ -213,21 +223,13 @@ public function buildDataSet( public function train( Account $account, LoggerInterface $logger, - ?IExtractor $extractor = null, ?Closure $estimator = null, bool $shuffleDataSet = false, bool $persist = true, ): ?Estimator { $perf = $this->performanceLogger->start('importance classifier training'); - if ($extractor === null) { - try { - $extractor = $this->container->get(CompositeExtractor::class); - } catch (ContainerExceptionInterface $e) { - throw new ServiceException('Default extractor is not available', 0, $e); - } - } - + $extractor = $this->createExtractor(); $dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet); if ($dataSet === null) { return null; @@ -250,7 +252,7 @@ public function train( * @param Account $account * @param LoggerInterface $logger * @param array $dataSet Training data set built by buildDataSet() - * @param IExtractor $extractor Extractor used to extract the given data set + * @param CompositeExtractor $extractor Extractor used to extract the given data set * @param ?Closure $estimator Returned instance should at least implement Learner, Estimator and Persistable. If null, the default estimator will be used. 
* @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task * @param bool $persist Persist the trained classifier to use it for message classification @@ -259,7 +261,7 @@ public function train( * * @throws ServiceException */ - public function trainWithCustomDataSet( + private function trainWithCustomDataSet( Account $account, LoggerInterface $logger, array $dataSet, diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index d25c40fb53..fd29e67bb3 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -142,7 +142,7 @@ public function loadLatest(Account $account): ?ClassifierPipeline { $extractor = $this->loadExtractor($transformers); - return new ClassifierPipeline($estimator, $extractor, $transformers); + return new ClassifierPipeline($estimator, $extractor); } /** diff --git a/lib/Service/Classification/RubixMemoryPersister.php b/lib/Service/Classification/RubixMemoryPersister.php index 2b170b38b5..c3abff2463 100644 --- a/lib/Service/Classification/RubixMemoryPersister.php +++ b/lib/Service/Classification/RubixMemoryPersister.php @@ -11,15 +11,14 @@ use Rubix\ML\Encoding; use Rubix\ML\Persisters\Persister; -use ValueError; class RubixMemoryPersister implements Persister { public function __construct( - private ?string $data = null, + private string $data = '', ) { } - public function getData(): ?string { + public function getData(): string { return $this->data; } @@ -28,14 +27,10 @@ public function save(Encoding $encoding): void { } public function load(): Encoding { - if ($this->data === null) { - throw new ValueError('Trying to load encoding when no data is available'); - } - return new Encoding($this->data); } public function __toString() { - return (string)self::class; + return self::class; } } From a7ea9c05ad3d29b896bf442c8b275a7adbdd1281 Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Tue, 22 Oct 
2024 13:58:35 +0200 Subject: [PATCH 37/37] Run composer cs:fix --- lib/Command/TrainAccount.php | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index a3eaf94aa9..19fef4ad0d 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -18,7 +18,6 @@ use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; -use Symfony\Component\Console\Input\InputOption; use Symfony\Component\Console\Output\OutputInterface; use function memory_get_peak_usage;