Skip to content

Commit

Permalink
Rounding (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnelson16 authored and OliverSherouse committed Apr 13, 2018
1 parent ebcb7d2 commit 6686e0d
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 16 deletions.
5 changes: 4 additions & 1 deletion quantgov/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ def parse_args():
estimate.add_argument(
'--probability', action='store_true',
help='output probabilities instead of predictions')
estimate.add_argument(
'--precision', default=4, type=int,
help='number of decimal places to round the probabilities')
estimate.add_argument(
'-o', '--outfile',
type=lambda x: open(x, 'w', newline='', encoding='utf-8'),
Expand Down Expand Up @@ -187,7 +190,7 @@ def run_estimator(args):
elif args.subcommand == "estimate":
quantgov.estimator.estimate(
args.vectorizer, args.model, args.corpus, args.probability,
args.outfile
args.precision, args.outfile
)


Expand Down
37 changes: 22 additions & 15 deletions quantgov/estimator/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def estimate_simple(vectorizer, model, streamer):
yield from zip(streamer.index, pipeline.predict(texts))


def estimate_probability(vectorizer, model, streamer):
def estimate_probability(vectorizer, model, streamer, precision):
"""
Generate probabilities for a one-label estimator
Expand All @@ -61,11 +61,13 @@ def estimate_probability(vectorizer, model, streamer):
pipeline = get_pipeline(vectorizer, model)
texts = (doc.text for doc in streamer)
truecol = list(int(i) for i in model.model.classes_).index(1)
predicted = (i[truecol] for i in pipeline.predict_proba(texts))
predicted = (
i[truecol] for i in pipeline.predict_proba(texts).round(precision)
)
yield from zip(streamer.index, predicted)


def estimate_probability_multilabel(vectorizer, model, streamer):
def estimate_probability_multilabel(vectorizer, model, streamer, precision):
"""
Generate probabilities for a multilabel binary estimator
Expand Down Expand Up @@ -96,13 +98,13 @@ def estimate_probability_multilabel(vectorizer, model, streamer):
try:
for i, docidx in enumerate(streamer.index):
yield docidx, tuple(
label_predictions[i, truecols[j]]
label_predictions[i, truecols[j]].round(int(precision))
for j, label_predictions in enumerate(predicted))
except IndexError:
yield from zip(streamer.index, predicted)
yield from zip(streamer.index, predicted.round(int(precision)))


def estimate_probability_multiclass(vectorizer, model, streamer):
def estimate_probability_multiclass(vectorizer, model, streamer, precision):
"""
Generate probabilities for a one-label, multiclass estimator
Expand All @@ -117,10 +119,14 @@ def estimate_probability_multiclass(vectorizer, model, streamer):
"""
pipeline = get_pipeline(vectorizer, model)
texts = (doc.text for doc in streamer)
yield from zip(streamer.index, pipeline.predict_proba(texts))
yield from zip(
streamer.index,
(i for i in pipeline.predict_proba(texts).round(precision))
)


def estimate_probability_multilabel_multiclass(vectorizer, model, streamer):
def estimate_probability_multilabel_multiclass(
vectorizer, model, streamer, precision):
"""
Generate probabilities for a multilabel, multiclass estimator
Expand All @@ -137,8 +143,8 @@ def estimate_probability_multilabel_multiclass(vectorizer, model, streamer):
texts = (doc.text for doc in streamer)
predicted = pipeline.predict_proba(texts)
for i, docidx in enumerate(streamer.index):
yield docidx, tuple(label_predictions[i]
for label_predictions in predicted)
yield docidx, tuple(label_predictions[i] for label_predictions
in predicted.round(precision))


def is_multiclass(classes):
Expand All @@ -152,7 +158,7 @@ def is_multiclass(classes):
return True


def estimate(vectorizer, model, corpus, probability, outfile):
def estimate(vectorizer, model, corpus, probability, precision, outfile):
"""
Estimate label values for documents in corpus
Expand Down Expand Up @@ -184,7 +190,7 @@ def estimate(vectorizer, model, corpus, probability, outfile):
if multilabel:
if multiclass: # Multilabel-multiclass probability
results = estimate_probability_multilabel_multiclass(
vectorizer, model, streamer)
vectorizer, model, streamer, precision)
writer.writerow(corpus.index_labels +
('label', 'class', 'probability'))
writer.writerows(
Expand All @@ -198,7 +204,7 @@ def estimate(vectorizer, model, corpus, probability, outfile):
)
else: # Multilabel probability
results = estimate_probability_multilabel(
vectorizer, model, streamer)
vectorizer, model, streamer, precision)
writer.writerow(corpus.index_labels + ('label', 'probability'))
writer.writerows(
docidx + (label_name, prediction)
Expand All @@ -209,15 +215,16 @@ def estimate(vectorizer, model, corpus, probability, outfile):
elif multiclass: # Multiclass probability
writer.writerow(corpus.index_labels + ('class', 'probability'))
results = estimate_probability_multiclass(
vectorizer, model, streamer)
vectorizer, model, streamer, precision)
writer.writerows(
docidx + (class_name, prediction)
for docidx, predictions in results
for class_name, prediction in zip(
model.model.classes_, predictions)
)
else: # Simple probability
results = estimate_probability(vectorizer, model, streamer)
results = estimate_probability(
vectorizer, model, streamer, precision)
writer.writerow(
corpus.index_labels + (model.label_names[0] + '_prob',))
writer.writerows(
Expand Down
92 changes: 92 additions & 0 deletions tests/pseudo_estimator/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
.snakemake
notebooks/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject
Binary file added tests/pseudo_estimator/data/model.pickle
Binary file not shown.
Binary file added tests/pseudo_estimator/data/modelmulticlass.pickle
Binary file not shown.
Binary file added tests/pseudo_estimator/data/vectorizer.pickle
Binary file not shown.
71 changes: 71 additions & 0 deletions tests/test_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
import quantgov.estimator
import subprocess

from pathlib import Path


# Fixture directories, resolved relative to the directory holding this file.
PSEUDO_CORPUS_PATH = Path(__file__).resolve().parent / 'pseudo_corpus'
PSEUDO_ESTIMATOR_PATH = Path(__file__).resolve().parent / 'pseudo_estimator'


def check_output(cmd):
    """
    Run a command and return its captured standard output as text.

    The command is executed with ``subprocess.check_output`` in
    universal-newlines (text) mode, so a non-zero exit status raises
    ``subprocess.CalledProcessError``. Each pair of consecutive newlines in
    the captured text is then replaced by a single newline.

    Arguments:

        * **cmd**: the command to run, as a sequence of program arguments

    Returns: the captured output as a string
    """
    captured = subprocess.check_output(cmd, universal_newlines=True)
    return captured.replace('\n\n', '\n')


def test_simple_estimator():
    """Run ``quantgov estimator estimate`` without flags and check its CSV."""
    data_dir = PSEUDO_ESTIMATOR_PATH.joinpath('data')
    cmd = [
        'quantgov', 'estimator', 'estimate',
        str(data_dir.joinpath('vectorizer.pickle')),
        str(data_dir.joinpath('model.pickle')),
        str(PSEUDO_CORPUS_PATH),
    ]
    expected = 'file,is_world\ncfr,False\nmoby,False\n'
    assert check_output(cmd) == expected


def test_probability_estimator():
    """Run the estimate command with ``--probability`` and check its CSV.

    No ``--precision`` flag is passed, so the expected values are the ones
    produced with the command's default precision.
    """
    data_dir = PSEUDO_ESTIMATOR_PATH.joinpath('data')
    cmd = [
        'quantgov', 'estimator', 'estimate',
        str(data_dir.joinpath('vectorizer.pickle')),
        str(data_dir.joinpath('model.pickle')),
        str(PSEUDO_CORPUS_PATH),
        '--probability',
    ]
    expected = 'file,is_world_prob\ncfr,0.0899\nmoby,0.0216\n'
    assert check_output(cmd) == expected


def test_probability_estimator_6decimals():
    """Run the estimate command with ``--probability --precision 6``.

    The expected ``moby`` value shows only five decimals: the output is a
    rounded number, not a zero-padded string.
    """
    data_dir = PSEUDO_ESTIMATOR_PATH.joinpath('data')
    cmd = [
        'quantgov', 'estimator', 'estimate',
        str(data_dir.joinpath('vectorizer.pickle')),
        str(data_dir.joinpath('model.pickle')),
        str(PSEUDO_CORPUS_PATH),
        '--probability', '--precision', '6',
    ]
    expected = 'file,is_world_prob\ncfr,0.089898\nmoby,0.02162\n'
    assert check_output(cmd) == expected


def test_multiclass_probability_estimator():
    """Run the estimate command with ``--probability`` on the multiclass model.

    The expected CSV has one row per (document, class) pair, with the
    probabilities at the command's default precision.
    """
    data_dir = PSEUDO_ESTIMATOR_PATH.joinpath('data')
    cmd = [
        'quantgov', 'estimator', 'estimate',
        str(data_dir.joinpath('vectorizer.pickle')),
        str(data_dir.joinpath('modelmulticlass.pickle')),
        str(PSEUDO_CORPUS_PATH),
        '--probability',
    ]
    expected = (
        'file,class,probability\n'
        'cfr,business-and-industry,0.1765\n'
        'cfr,environment,0.1294\n'
        'cfr,health-and-public-welfare,0.1785\n'
        'cfr,money,0.169\n'
        'cfr,science-and-technology,0.147\n'
        'cfr,world,0.1997\n'
        'moby,business-and-industry,0.1804\n'
        'moby,environment,0.1529\n'
        'moby,health-and-public-welfare,0.205\n'
        'moby,money,0.1536\n'
        'moby,science-and-technology,0.1671\n'
        'moby,world,0.141\n'
    )
    assert check_output(cmd) == expected

0 comments on commit 6686e0d

Please sign in to comment.