Skip to content

Commit

Permalink
Merge pull request #33 from andrewtavis/update-gensim
Browse files Browse the repository at this point in the history
#21 support for gensim 4.x
  • Loading branch information
andrewtavis authored Apr 29, 2021
2 parents 9826b4c + 0c5a802 commit c26c93f
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 60 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
author = "kwx developers"

# The full version, including alpha/beta/rc tags
release = "0.1.7.5"
release = "0.1.8"


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
name="kwx",
packages=find_packages(where="src"),
package_dir={"": "src"},
version="0.1.7.5",
version="0.1.8",
author="Andrew Tavis McAllister",
author_email="[email protected]",
classifiers=[
Expand Down
41 changes: 29 additions & 12 deletions src/kwx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from multiprocessing import Pool

import emoji
import gensim
import pandas as pd
import spacy
from googletrans import Translator
Expand Down Expand Up @@ -336,18 +337,34 @@ def clean(
gc.collect()
pbar.update()

bigrams = Phrases(
sentences=tokenized_texts,
min_count=min_ngram_count,
threshold=5.0,
common_terms=stop_words,
) # half the normal threshold
trigrams = Phrases(
sentences=bigrams[tokenized_texts],
min_count=min_ngram_count,
threshold=5.0,
common_terms=stop_words,
)
# Add bigrams and trigrams
# Half the normal threshold
if gensim.__version__[0] == "4":
bigrams = Phrases(
sentences=tokenized_texts,
min_count=min_ngram_count,
threshold=5.0,
connector_words=stop_words,
)
trigrams = Phrases(
sentences=bigrams[tokenized_texts],
min_count=min_ngram_count,
threshold=5.0,
connector_words=stop_words,
)
else:
bigrams = Phrases( # pylint: disable=unexpected-keyword-arg
sentences=tokenized_texts,
min_count=min_ngram_count,
threshold=5.0,
common_terms=stop_words,
)
trigrams = Phrases( # pylint: disable=unexpected-keyword-arg
sentences=bigrams[tokenized_texts],
min_count=min_ngram_count,
threshold=5.0,
common_terms=stop_words,
)

tokens_with_ngrams = []
for text in tqdm(
Expand Down
145 changes: 99 additions & 46 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from io import StringIO

import gensim
import numpy as np
from kwx import model

Expand Down Expand Up @@ -80,52 +81,104 @@ def test_extract_TFIDF_kws(long_text_corpus):
]


def test_extract_LDA_kws(long_text_corpus):
    """LDA keyword extraction over the fixture corpus yields the pinned keyword list."""
    expected_keywords = [
        "virginamerica",
        "customer",
        "flight",
        "tco",
        "airline",
        "trip",
        "fly",
        "carrieunderwood",
        "bag",
        "week",
    ]
    extracted = model.extract_kws(
        method="lda",
        text_corpus=long_text_corpus,
        input_language="english",
        num_keywords=10,
        num_topics=10,
        prompt_remove_words=False,
    )
    assert extracted == expected_keywords


def test_extract_kws_remove_words(monkeypatch, long_text_corpus):
    """Prompted removal of 'virginamerica' produces the pinned replacement keywords."""
    # Feed the interactive prompt: "y" to remove words, the word itself,
    # then "n" to stop removing further words.
    simulated_input = StringIO("y\nvirginamerica\nn\n")
    monkeypatch.setattr("sys.stdin", simulated_input)

    expected_keywords = [
        "flight",
        "time",
        "seat",
        "traveler",
        "lax",
        "fly",
        "ladygaga",
        "virginamerica_ladygaga",
        "carrieunderwood",
        "virginamerica_ladygaga_carrieunderwood",
    ]
    extracted = model.extract_kws(
        method="lda",
        text_corpus=long_text_corpus,
        input_language="english",
        num_keywords=10,
        num_topics=10,
        prompt_remove_words=True,
    )
    assert extracted == expected_keywords
# gensim 4.x changed the Phrases/LDA internals enough that the extracted
# keywords differ between major versions, so the expected values below are
# pinned per major version.
# Parse the major version as an int: inspecting only the first character of
# the version string would misclassify any major version >= 10 (and routes
# 5.x+ onto the legacy branch, which uses the removed pre-4 API).
if int(gensim.__version__.split(".")[0]) >= 4:

    def test_extract_LDA_kws(long_text_corpus):
        """LDA keyword extraction returns the pinned keywords (gensim >= 4)."""
        kws = model.extract_kws(
            method="lda",
            text_corpus=long_text_corpus,
            input_language="english",
            num_keywords=10,
            num_topics=10,
            prompt_remove_words=False,
        )
        assert kws == [
            "virginamerica",
            "tco",
            "love",
            "flight",
            "airline",
            "trip",
            "fly",
            "change",
            "carrieunderwood",
            "week",
        ]

    def test_extract_kws_remove_words(monkeypatch, long_text_corpus):
        """Prompted removal of 'virginamerica' yields the pinned keywords (gensim >= 4)."""
        # Simulate the interactive prompt: "y" to remove words, the word
        # itself, then "n" to stop removing further words.
        monkeypatch.setattr("sys.stdin", StringIO("y\nvirginamerica\nn\n"))

        kws = model.extract_kws(
            method="lda",
            text_corpus=long_text_corpus,
            input_language="english",
            num_keywords=10,
            num_topics=10,
            prompt_remove_words=True,
        )
        assert kws == [
            "flight",
            "reservation",
            "online",
            "change",
            "check",
            "lax",
            "ladygaga",
            "virginamerica_ladygaga",
            "carrieunderwood",
            "virginamerica_ladygaga_carrieunderwood",
        ]


else:

    def test_extract_LDA_kws(long_text_corpus):
        """LDA keyword extraction returns the pinned keywords (gensim < 4)."""
        kws = model.extract_kws(
            method="lda",
            text_corpus=long_text_corpus,
            input_language="english",
            num_keywords=10,
            num_topics=10,
            prompt_remove_words=False,
        )
        assert kws == [
            "virginamerica",
            "customer",
            "flight",
            "tco",
            "airline",
            "trip",
            "fly",
            "carrieunderwood",
            "bag",
            "week",
        ]

    def test_extract_kws_remove_words(monkeypatch, long_text_corpus):
        """Prompted removal of 'virginamerica' yields the pinned keywords (gensim < 4)."""
        # Simulate the interactive prompt: "y" to remove words, the word
        # itself, then "n" to stop removing further words.
        monkeypatch.setattr("sys.stdin", StringIO("y\nvirginamerica\nn\n"))

        kws = model.extract_kws(
            method="lda",
            text_corpus=long_text_corpus,
            input_language="english",
            num_keywords=10,
            num_topics=10,
            prompt_remove_words=True,
        )
        assert kws == [
            "flight",
            "time",
            "seat",
            "traveler",
            "lax",
            "fly",
            "ladygaga",
            "virginamerica_ladygaga",
            "carrieunderwood",
            "virginamerica_ladygaga_carrieunderwood",
        ]


def test_extract_BERT_kws(long_text_corpus):
Expand Down

0 comments on commit c26c93f

Please sign in to comment.