Compression module #1455

Open · wants to merge 12 commits into main
1,694 changes: 862 additions & 832 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@
ipython = "^8.17.2"
rich = "^13.6.0"
jupyter = "^1.0.0"
mike = "^2.0.0"
zstandard = "^0.22.0"

[tool.poetry.group.compat]
optional = true
11 changes: 11 additions & 0 deletions river/compression/__init__.py
@@ -0,0 +1,11 @@
"""Compression-based text classification. This module contains wrappers that enable
compression-based text classification with compression algorithms such as Zstandard."""
from __future__ import annotations

from .text_compression_classifier import (
TextCompressionClassifier,
)

__all__ = [
"TextCompressionClassifier",
]
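
As context for reviewers, a minimal usage sketch of the new public API added by this PR (the texts and labels below are invented for illustration; the chaining style works because learn_one returns self in this implementation):

from river import compression

model = compression.TextCompressionClassifier(compression_level=3, k=150)

# Store each text under its label
model = model.learn_one("the rocket reached orbit", "sci.space")
model = model.learn_one("the sermon discussed scripture", "talk.religion.misc")

# Predicts the label whose stored corpus compresses the query most cheaply
print(model.predict_one("a new satellite launch into orbit"))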
61 changes: 61 additions & 0 deletions river/compression/test-compression.ipynb
@@ -0,0 +1,61 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from river import compression\n",
"from sklearn.metrics import classification_report\n",
"\n",
"# prepare train dataset\n",
"categories = [\n",
" 'alt.atheism',\n",
" 'talk.religion.misc',\n",
" 'comp.graphics',\n",
" 'sci.space'\n",
"]\n",
"\n",
"train = fetch_20newsgroups(subset='train', categories=categories)\n",
"test = fetch_20newsgroups(subset='test', categories=categories)\n",
"\n",
"model = compression.TextCompressionClassifier()\n",
"\n",
"for i, text in enumerate(train['data']):\n",
" x = text \n",
" y = train['target_names'][train['target'][i]]\n",
" model = model.learn_one(x, y)\n",
"\n",
"predictions = []\n",
"\n",
"for text in test['data']:\n",
" y_pred = model.predict_one(text)\n",
" predictions.append(y_pred)\n",
"\n",
"test_labels = [\n",
" test['target_names'][label]\n",
" for label in test['target']\n",
"]\n",
"\n",
"print(classification_report(\n",
" test_labels,\n",
" predictions,\n",
" digits=3\n",
"))"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
54 changes: 54 additions & 0 deletions river/compression/teste.ipynb
@@ -0,0 +1,54 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n",
"I am in the class !! hello world\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from river import compression\n",
"model = compression.TextCompressionClassifier()\n",
"\n",
"model.test(\"hello world\")"
]
}
],
"metadata": {
"interpreter": {
"hash": "1c919c39556be274582a1f16bb0a84a92481bcfbe155bed1bcf4102511c6f533"
},
"kernelspec": {
"display_name": "Python 3.8.6rc1 ('.venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6rc1"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
135 changes: 135 additions & 0 deletions river/compression/text_compression_classifier.py
@@ -0,0 +1,135 @@
from __future__ import annotations
from collections import deque

import zstandard # type: ignore

from river import base


class TextCompressionClassifier(base.Classifier):
    """Classifier based on text compression techniques.

    This classifier uses the `zstandard` compression library to classify text by measuring
    how well a new piece of data compresses together with the stored data of each class.

    Parameters
    ----------
    compression_level
        The Zstandard compression level. Defaults to 3.
    k
        The maximum number of documents stored per label. Defaults to 150.

    Attributes
    ----------
    label_documents : dict
        The most recent documents kept for each label.
    compression_contexts : dict
        A Zstandard compressor for each label.

    """

    def __init__(self, compression_level=3, k=150):
self.compression_level = compression_level
self.k = k
        self.label_documents = {}  # Per-label deques holding the k most recent documents
        self.compression_contexts = {}  # Per-label Zstandard compressors

    def __getstate__(self):
        return {
            "compression_level": self.compression_level,
            "k": self.k,
            "label_documents": self.label_documents,
        }

    def __setstate__(self, state):
        self.compression_level = state["compression_level"]
        self.k = state["k"]
        self.label_documents = state["label_documents"]
        self.compression_contexts = self.recreate_contexts(self.label_documents)

def recreate_contexts(self, label_documents):
new_compression_contexts = {}
for label, documents in label_documents.items():
concatenated_documents = " ".join(documents).encode("utf-8")
compression_dict = zstandard.ZstdCompressionDict(concatenated_documents)

zstd_compressor = zstandard.ZstdCompressor(
level=self.compression_level, dict_data=compression_dict
)

new_compression_contexts[label] = zstd_compressor

return new_compression_contexts

def _unit_test_skips(self):
return {"check_shuffle_features_no_impact"}

    def learn_one(self, x, y):
        """Update the classifier with a new sample.

        Parameters
        ----------
        x
            The input sample. Any type that can be converted to a string.
        y
            The label of the input sample.

        Returns
        -------
        TextCompressionClassifier
            The classifier itself, for chaining.

        """
        # Convert the input to a string representation (it may be a dict of features)
        x_str = str(x)

# Initialize if label is new
if y not in self.label_documents:
self.label_documents[y] = deque(maxlen=self.k)

        # Append the new document; the deque automatically drops the oldest once it holds k items
        self.label_documents[y].append(x_str)

# Concatenate documents in the deque into a single string
concatenated_documents = " ".join(self.label_documents[y])

# Create a dictionary with encoded concatenated text
compression_dict = zstandard.ZstdCompressionDict(concatenated_documents.encode("utf-8"))

# Create a Zstandard compression context for this label using the dictionary
zstd_compressor = zstandard.ZstdCompressor(
level=self.compression_level, dict_data=compression_dict
)

# Update the compression context for this label with the new compressor
self.compression_contexts[y] = zstd_compressor

return self

def predict_one(self, x):
"""Predict the output label for the input `x`.

Parameters
----------
x : any type that can be converted to string
The input to be classified.

Returns
-------
        best_label : str
            The label corresponding to the smallest increase in compressed size after
            adding the new data `x`.
"""
min_size_increase = float("inf")
best_label = None

        # Convert the input to a string representation
        x_str = str(x)

for label, compressor in self.compression_contexts.items():
concatenated_doc = (" ".join(self.label_documents[label]) + " " + x_str).encode("utf-8")
compressed_size = len(compressor.compress(concatenated_doc))

previous_size = len(
compressor.compress(" ".join(self.label_documents[label]).encode("utf-8"))
)

size_increase = compressed_size - previous_size

if size_increase < min_size_increase:
min_size_increase = size_increase
best_label = label

return best_label
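
To make the decision rule concrete, here is a minimal sketch (not part of the PR) of the size-increase criterion the classifier relies on, using only the plain zstandard compressor on invented toy corpora. The PR additionally trains a ZstdCompressionDict on the stored documents per label; this sketch omits that step and keeps only the core idea:

import zstandard

corpus = {
    "space": "orbit rocket launch satellite telescope",
    "religion": "faith church scripture belief prayer",
}
query = "the satellite reached orbit after launch"

def size_increase(reference: str, text: str, level: int = 3) -> int:
    # Extra compressed bytes needed once `text` is appended to `reference`
    cctx = zstandard.ZstdCompressor(level=level)
    base = len(cctx.compress(reference.encode("utf-8")))
    combined = len(cctx.compress((reference + " " + text).encode("utf-8")))
    return combined - base

# The label whose corpus absorbs the query most cheaply is predicted
print(min(corpus, key=lambda label: size_increase(corpus[label], query)))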
2 changes: 2 additions & 0 deletions river/test_estimators.py
@@ -12,6 +12,7 @@
checks,
compat,
compose,
compression,
facto,
feature_extraction,
feature_selection,
@@ -105,6 +106,7 @@ def can_be_tested(estimator):
)
for estimator in list(iter_estimators_which_can_be_tested())
+ [
compression.TextCompressionClassifier(),
preprocessing.StandardScaler() | linear_model.LinearRegression(),
preprocessing.StandardScaler() | linear_model.PAClassifier(),
(