Skip to content

Commit

Permalink
Optimize memory hog at combination creation + fix autopep8 issue on >…
Browse files Browse the repository at this point in the history
…= py3.10 (#54)

* Optimize combinations creation
* fix autopep8@py10
* Fix failing test due to unstable sorting algo
* Bump semantic version
  • Loading branch information
miha-jenko authored Oct 24, 2023
1 parent c3ab440 commit 78c205b
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 30 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ repos:
rev: v2.0.4
hooks:
- id: autopep8
args: ["--global-config pyproject.toml"]
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
hooks:
Expand Down
60 changes: 32 additions & 28 deletions outrank/core_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
IGNORED_VALUES = set()
HYPERLL_ERROR_BOUND = 0.02


def prior_combinations_sample(combinations: list[tuple[Any, ...]], args: Any) -> list[tuple[Any, ...]]:
"""Make sure only relevant subspace of combinations is selected based on prior counts"""

Expand All @@ -59,6 +60,36 @@ def prior_combinations_sample(combinations: list[tuple[Any, ...]], args: Any) ->
return tmp


def get_combinations_from_columns(all_columns: pd.Index, args: Any) -> list[tuple[Any, ...]]:
"""Return feature-feature & feature-label combinations, depending on the heuristic and ranking scope"""

if '3mr' in args.heuristic:
rel_columns = [column for column in all_columns if ' AND_REL ' in column]
non_rel_columns = sorted(set(all_columns) - set(rel_columns))

combinations = list(
itertools.combinations_with_replacement(non_rel_columns, 2),
)
combinations += [(column, args.label_column) for column in rel_columns]
else:
_combinations = itertools.combinations_with_replacement(all_columns, 2)

# Some applications do not require the full feature-feature triangular matrix
if args.target_ranking_only == 'True':
combinations = [x for x in _combinations if args.label_column in x]
else:
combinations = list(_combinations)

if args.target_ranking_only != 'True':
# Diagonal elements (non-label)
combinations += [
(individual_column, individual_column)
for individual_column in all_columns
if individual_column != args.label_column
]
return combinations


def mixed_rank_graph(
input_dataframe: pd.DataFrame, args: Any, cpu_pool: Any, pbar: Any,
) -> BatchRankingSummary:
Expand All @@ -78,34 +109,7 @@ def mixed_rank_graph(
end_enc_timer = timer()
out_time_struct['encoding_columns'] = end_enc_timer - start_enc_timer

# Helper method for parallel estimation
combinations = list(
itertools.combinations_with_replacement(all_columns, 2),
)

if '3mr' in args.heuristic:
rel_columns = [
column for column in all_columns if ' AND_REL ' in column
]
non_rel_columns = list(set(all_columns) - set(rel_columns))
combinations = list(
itertools.combinations_with_replacement(non_rel_columns, 2),
)
combinations += [(column, args.label_column) for column in rel_columns]
else:
combinations = list(
itertools.combinations_with_replacement(all_columns, 2),
)

# Diagonal elements
for individual_column in all_columns:
if individual_column != args.label_column:
combinations += [(individual_column, individual_column)]

# Some applications do not require the full feature-feature triangular matrix
if (args.target_ranking_only == 'True') and ('3mr' not in args.heuristic):
combinations = [x for x in combinations if args.label_column in x]

combinations = get_combinations_from_columns(all_columns, args)
combinations = prior_combinations_sample(combinations, args)
random.shuffle(combinations)

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[tool.autopep8]
in-place = true
list-fixes = true
ignore = "W690"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _read_description():
packages = [x for x in setuptools.find_packages() if x != 'test']
setuptools.setup(
name='outrank',
version='0.95.1',
version='0.95.2',
description='OutRank: Feature ranking for massive sparse data sets.',
long_description=_read_description(),
long_description_content_type='text/markdown',
Expand Down
35 changes: 34 additions & 1 deletion tests/ranking_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathos.multiprocessing import ProcessingPool as Pool

from outrank.core_ranking import compute_combined_features
from outrank.core_ranking import get_combinations_from_columns
from outrank.core_ranking import mixed_rank_graph
from outrank.feature_transformations.feature_transformer_vault import (
default_transformers,
Expand All @@ -29,7 +30,7 @@
class args:
label_column: str = 'label'
heuristic: str = 'surrogate-LR'
target_ranking_only: bool = True
target_ranking_only: str = 'True'
interaction_order: int = 3
combination_number_upper_bound: int = 1024

Expand Down Expand Up @@ -91,6 +92,38 @@ def test_compute_combinations(self):
)
self.assertEqual(transformed_df.shape[1], 6)

def test_get_combinations_from_columns_target_ranking_only(self):
all_columns = pd.Index(['a', 'b', 'label'])
args.heuristic = 'MI-numba-randomized'
args.target_ranking_only = 'True'
combinations = get_combinations_from_columns(all_columns, args)

self.assertSetEqual(
set(combinations),
{('a', 'label'), ('b', 'label'), ('label', 'label')},
)

def test_get_combinations_from_columns(self):
all_columns = pd.Index(['a', 'b', 'label'])
args.heuristic = 'MI-numba-randomized'
args.target_ranking_only = 'False'
combinations = get_combinations_from_columns(all_columns, args)

self.assertSetEqual(
set(combinations),
{('a', 'a'), ('b', 'b'), ('label', 'label'), ('a', 'b'), ('a', 'label'), ('b', 'label')},
)

def test_get_combinations_from_columns_3mr(self):
all_columns = pd.Index(['a', 'b', 'label'])
args.heuristic = 'MI-numba-3mr'
combinations = get_combinations_from_columns(all_columns, args)

self.assertSetEqual(
set(combinations),
{('a', 'a'), ('b', 'b'), ('label', 'label'), ('a', 'b'), ('a', 'label'), ('b', 'label')},
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 78c205b

Please sign in to comment.