Skip to content

Commit

Permalink
Release 0.6.4 (#76)
Browse files Browse the repository at this point in the history
* update contact email

* filter values for flesch and sentence length

* Release 0.6.3 (#73) (#74)

* update contact email

* use alternative version of textstat for flesch scores

* filter values for flesch and sentence length

* add argument to return single prediction for multiclass probability estimate (#75)

* Update version 0.6.4
  • Loading branch information
jnelson16 committed Dec 4, 2020
1 parent 4cbef96 commit f0e7023
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 12 deletions.
2 changes: 1 addition & 1 deletion quantgov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from . import corpus, nlp, ml, utils
from .utils import load_driver

__version__ = '0.6.3'
__version__ = '0.6.4'
6 changes: 5 additions & 1 deletion quantgov/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def parse_args():
estimate.add_argument(
'--precision', default=4, type=int,
help='number of decimal places to round the probabilities')
estimate.add_argument(
'--oneclass', action='store_true',
help='only return predicted class for multiclass probabilty estimates')
estimate.add_argument(
'-o', '--outfile',
type=lambda x: open(x, 'w', newline='', encoding='utf-8'),
Expand Down Expand Up @@ -223,7 +226,8 @@ def run_estimator(args):
args.estimator,
args.corpus,
args.probability,
args.precision)
args.precision,
args.oneclass)
)


Expand Down
34 changes: 24 additions & 10 deletions quantgov/ml/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Functionality for making predictions with an estimator
"""
import logging
import numpy as np

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,7 +107,7 @@ def estimate_probability_multilabel(estimator, streamer, precision):
)


def estimate_probability_multiclass(estimator, streamer, precision):
def estimate_probability_multiclass(estimator, streamer, precision, oneclass):
"""
Generate probabilities for a one-label, multiclass estimator
Expand All @@ -119,12 +120,24 @@ def estimate_probability_multiclass(estimator, streamer, precision):
"""
texts = (doc.text for doc in streamer)
probs = estimator.pipeline.predict_proba(texts).round(precision)
yield from (
(docidx, (class_, probability))
for docidx, doc_probs in zip(streamer.index, probs)
for class_, probability in zip(estimator.pipeline.classes_, doc_probs)
)
probs = estimator.pipeline.predict_proba(texts)
# If oneclass flag is true, only returns the predicted class
if oneclass:
class_indices = list(i[-1] for i in np.argsort(probs, axis=1))
yield from (
(docidx, (estimator.pipeline.classes_[class_index],
doc_probs[class_index].round(precision)))
for docidx, doc_probs, class_index in zip(
streamer.index, probs, class_indices)
)
# Else returns probabilty values for all classes
else:
yield from (
(docidx, (class_, probability.round(precision)))
for docidx, doc_probs in zip(streamer.index, probs)
for class_, probability in zip(
estimator.pipeline.classes_, doc_probs)
)


def estimate_probability_multilabel_multiclass(estimator, streamer, precision):
Expand All @@ -140,7 +153,7 @@ def estimate_probability_multilabel_multiclass(estimator, streamer, precision):
"""
texts = (doc.text for doc in streamer)
probs = estimator.pipeline.predict_proba(texts)
probs = estimator.pipeline.predict_proba(texts).round(precision)
yield from (
(docidx, (label_name, class_, prob))
for label_name, label_probs in zip(estimator.label_names, probs)
Expand All @@ -149,7 +162,8 @@ def estimate_probability_multilabel_multiclass(estimator, streamer, precision):
)


def estimate(estimator, corpus, probability, precision=4, *args, **kwargs):
def estimate(estimator, corpus, probability, precision=4, oneclass=False,
*args, **kwargs):
"""
Estimate label values for documents in corpus
Expand All @@ -171,7 +185,7 @@ def estimate(estimator, corpus, probability, precision=4, *args, **kwargs):
estimator, streamer, precision)
elif estimator.multiclass: # Multiclass probability
yield from estimate_probability_multiclass(
estimator, streamer, precision)
estimator, streamer, precision, oneclass)
else: # Simple probability
yield from estimate_probability(
estimator, streamer, precision)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,14 @@ def test_multiclass_probability_estimator():
'moby,money,0.1536\n'
'moby,science-and-technology,0.1671\n'
'moby,world,0.141\n')


def test_multiclass_probability_oneclass_estimator():
output = check_output(
['quantgov', 'ml', 'estimate',
str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'multiclass.qge')),
str(PSEUDO_CORPUS_PATH), '--probability', '--oneclass']
)
assert output == ('file,class,probability\n'
'cfr,world,0.1997\n'
'moby,health-and-public-welfare,0.205\n')

0 comments on commit f0e7023

Please sign in to comment.