-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Topic modeling #37
base: dev
Are you sure you want to change the base?
Topic modeling #37
Changes from 12 commits
b07786a
257131d
265cc83
bb5f036
2a15b32
c3faebf
c6b764d
a17ed34
d5592f0
58513f7
aac7d98
ef9853f
c97c6df
aaffd94
58db43a
30af4c2
31a35ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,34 @@ | |
""" | ||
import collections | ||
import joblib as jl | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from six import iteritems | ||
from decorator import decorator | ||
import re | ||
|
||
# Optional dependencies: topic modeling needs spacy (for stop words) and
# gensim (for LDA). The module must still import cleanly without them, so
# a failed import leaves ``None`` sentinels that check_spacy/check_gensim
# test at call time.
try:
    from spacy.lang.en.stop_words import STOP_WORDS
    from gensim.corpora import Dictionary
    from gensim import sklearn_api
    import gensim
    spacy = True
except ImportError:
    spacy = None
    gensim = None
    # BUGFIX: without this fallback, STOP_WORDS is undefined and the
    # default argument ``stop_words=STOP_WORDS`` raises NameError at
    # class-definition time, making the whole module unimportable —
    # exactly what this try/except is meant to prevent.
    STOP_WORDS = frozenset()
|
||
|
||
@decorator
def check_spacy(func, *args, **kwargs):
    """Call *func*, raising RuntimeError when spacy is not installed."""
    if spacy is not None:
        return func(*args, **kwargs)
    raise RuntimeError('Must install spacy to use {}'.format(func))
|
||
|
||
@decorator
def check_gensim(func, *args, **kwargs):
    """Call *func*, raising RuntimeError when gensim is not installed."""
    if gensim is not None:
        return func(*args, **kwargs)
    raise RuntimeError('Must install gensim to use {}'.format(func))
|
||
|
||
class _PersistanceMixin(object): | ||
|
@@ -85,3 +113,41 @@ class CandidateModel( | |
parameter values to test as values | ||
""" | ||
pass | ||
|
||
|
||
class QGLdaModel(BaseEstimator, TransformerMixin):
    """Gensim LDA topic model wrapped in a scikit-learn estimator interface.

    Documents from a quantgov driver are tokenized with ``word_regex``,
    collected into a gensim ``Dictionary`` (with stop words and
    single-occurrence tokens removed), and used to fit a gensim
    ``LdaTransformer``.

    Parameters
    ----------
    word_regex : str
        Pattern identifying one word token. The default matches runs of
        two or more ASCII letters. (BUGFIX: the previous ``[A-z]`` class
        also matched ``[ \\ ] ^ _`` and backtick, which sit between ``Z``
        and ``a`` in ASCII; ``[A-Za-z]`` matches letters only.)
    stop_words : iterable of str
        Words to drop from the dictionary before fitting. Defaults to
        spacy's English stop words.
    """

    @check_gensim
    @check_spacy
    def __init__(self, word_regex=r'\b[A-Za-z]{2,}\b', stop_words=STOP_WORDS):
        self.stop_words = stop_words
        self.word_regex = re.compile(word_regex)

    def _tokenize(self, doc):
        # Lower-cased regex tokens for a single driver document.
        return [match.group(0).lower()
                for match in self.word_regex.finditer(doc.text)]

    def create_corpus(self, driver):
        """Return the driver's documents as bag-of-words vectors."""
        return [self.dictionary.doc2bow(self._tokenize(doc))
                for doc in driver.stream()]

    def transform(self, driver):
        """Return the fitted model's topic weights for each document."""
        self.test_corpus = self.create_corpus(driver)
        return self.model.transform(self.test_corpus)

    def fit(self, driver, alpha=None, eta=None, num_topics=1, passes=1):
        """Build the dictionary and fit the LDA model; return self."""
        self.dictionary = Dictionary(
            [self._tokenize(doc) for doc in driver.stream()])
        # Remove stop words that actually occur in the corpus ...
        token2id = self.dictionary.token2id
        stop_ids = [token2id[word] for word in self.stop_words
                    if word in token2id]
        # ... and tokens appearing in only one document, per the gensim
        # corpus-preparation recipe (dfs maps token id -> document freq).
        once_ids = [tokenid for tokenid, docfreq
                    in iteritems(self.dictionary.dfs) if docfreq == 1]
        self.dictionary.filter_tokens(stop_ids + once_ids)
        self.corpus = self.create_corpus(driver)
        self.model = sklearn_api.ldamodel.LdaTransformer(
            alpha=alpha,
            eta=eta,
            num_topics=num_topics,
            passes=passes,
            id2word=self.dictionary
        )
        self.model.fit(self.corpus)
        return self
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,6 +65,10 @@ def find_version(*file_paths): | |
'nlp': [ | ||
'textblob', | ||
'nltk', | ||
], | ||
'topic_modeling': [ | ||
'gensim', | ||
'spacy' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we not need a spacy corpus as well? |
||
] | ||
}, | ||
entry_points={ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# import pytest | ||
import subprocess | ||
import quantgov.estimator | ||
import quantgov | ||
|
||
from pathlib import Path | ||
|
||
PSEUDO_CORPUS_PATH = Path(__file__).resolve().parent.joinpath('pseudo_corpus') | ||
driver = quantgov.load_driver(PSEUDO_CORPUS_PATH) | ||
|
||
|
||
def test_topic_model():
    """Smoke test: fit a two-topic LDA model on the pseudo corpus, then transform it."""
    model = quantgov.estimator.structures.QGLdaModel()
    fitted = model.fit(driver, num_topics=2)
    fitted.transform(driver)
|
||
|
||
def check_output(cmd):
    """Run *cmd* and return its stdout with blank lines collapsed."""
    raw = subprocess.check_output(cmd, universal_newlines=True)
    return raw.replace('\n\n', '\n')
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're literally only using spacy here for the stop words, can't we use the stop words that sklearn's
CountVectorizer
uses instead? They are importable as `sklearn.feature_extraction.text.ENGLISH_STOP_WORDS`.