-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
16 changed files
with
369 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,4 +14,5 @@ experiments/ | |
*__pycache__ | ||
.DS_Store | ||
.vscode | ||
logs | ||
logs | ||
data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from pathlib import Path | ||
|
||
from tap import Tap | ||
|
||
|
||
class ArgumentParser(Tap): | ||
config: Path |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from pathlib import Path | ||
|
||
from report.argument_parser import ArgumentParser | ||
|
||
|
||
class ChartArgumentParser(ArgumentParser): | ||
save_path: Path = Path('./data/charts') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from argparse import ArgumentParser | ||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
from lkmeans.examples.data.experiment_data import get_experiment_data | ||
from lkmeans.examples.data.points_generator import generate_mix_distribution | ||
|
||
parser = ArgumentParser() | ||
|
||
parser.add_argument( | ||
'--path', | ||
type=Path, | ||
default=Path('images'), | ||
help='Path to save results' | ||
) | ||
|
||
|
||
def main(): | ||
args = parser.parse_args() | ||
args.path.mkdir(exist_ok=True) | ||
|
||
dimension = 20 | ||
n_points = 100 | ||
|
||
n_clusters, prob, mu_list, cov_matrices = get_experiment_data(num_clusters=2, dimension=dimension) | ||
|
||
for t in [0.2, 0.4, 0.9]: | ||
filename = args.path / f'{n_clusters}_cluster_hist_t_{t}.png' | ||
clusters, _, _ = generate_mix_distribution( | ||
probability=prob, | ||
mu_list=mu_list, | ||
cov_matrices=cov_matrices, | ||
n_samples=n_points, | ||
t=t | ||
) | ||
|
||
fig, ax = plt.subplots(figsize=(5, 3)) | ||
ax.hist(clusters[:, 0], bins=15) | ||
ax.grid(True, color='gray', linestyle='--', linewidth=0.5) | ||
fig.savefig(str(filename), dpi=300, bbox_inches='tight') | ||
plt.close(fig) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import json | ||
from collections import defaultdict | ||
from typing import Dict, List | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
|
||
from report.charts.chart_argument_parser import ChartArgumentParser | ||
from report.log_parser import LogParser | ||
from report.metric_name_processor import process_metric_name | ||
|
||
|
||
def select_metric(all_data: Dict[str, List[Dict[str, float]]], metric: str) -> Dict[str, List[float]]: | ||
data = {} | ||
for block_name, block_metrics in all_data.items(): | ||
data[block_name] = [metrics[metric] for metrics in block_metrics] | ||
return data | ||
|
||
|
||
def main() -> None: | ||
args = ChartArgumentParser(underscores_to_dashes=True).parse_args() | ||
with args.config.open() as file: | ||
json_data = json.load(file) | ||
|
||
args.save_path.mkdir(parents=True, exist_ok=True) | ||
|
||
parser = LogParser() | ||
data = defaultdict(list) | ||
for block_name, logs_block in json_data['logs'].items(): | ||
if block_name == 'LKMeans': | ||
continue | ||
for log_path in logs_block.values(): | ||
if len(log_path.split(' ')) > 1: | ||
log_data_dict = json.loads(log_path.replace('\'', '"')) | ||
else: | ||
log_data_dict = parser.parse(log_path) | ||
data[block_name].append(log_data_dict) | ||
|
||
for metric in json_data['plot_metrics']: | ||
config_name = json_data['name'] | ||
chart_name = args.save_path / f'{config_name}_{metric}.png' | ||
prepared_data = select_metric(data, metric) | ||
figure = plt.figure(figsize=(4,4), dpi=800) | ||
axes = figure.gca() | ||
for line_name, values in prepared_data.items(): | ||
axes.plot(values, label=line_name) | ||
axes.legend() | ||
axes.set_xticks(ticks=np.linspace(0, 3, 4)) | ||
axes.set_xticklabels(['0', '0.1', '0.15', '0.2']) | ||
axes.set_title(process_metric_name(metric)) | ||
figure.tight_layout() | ||
figure.savefig(chart_name) | ||
plt.close(figure) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from argparse import ArgumentParser | ||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from lkmeans.clustering.utils import assign_to_cluster | ||
from lkmeans.distance import DistanceCalculator | ||
from lkmeans.examples.data.experiment_data import get_experiment_data | ||
from lkmeans.examples.data.points_generator import generate_mix_distribution | ||
|
||
parser = ArgumentParser() | ||
|
||
parser.add_argument( | ||
'--path', | ||
type=Path, | ||
default=Path('images'), | ||
help='Path to save results' | ||
) | ||
|
||
parser.add_argument( | ||
'--p', | ||
type=float, | ||
default=2, | ||
help='Minkowski parameter' | ||
) | ||
|
||
parser.add_argument( | ||
'--t', | ||
type=float, | ||
default=0., | ||
help='T parameter of distribution' | ||
) | ||
|
||
|
||
# pylint: disable=too-many-locals | ||
def main(): | ||
args = parser.parse_args() | ||
args.path.mkdir(exist_ok=True) | ||
p = int(args.p) if (args.p).is_integer() else args.p | ||
|
||
dimension = 20 | ||
n_points = 10 | ||
n_observation = 10000 | ||
|
||
distance_calculator = DistanceCalculator(p) | ||
|
||
n_clusters, prob, mu_list, cov_matrices = get_experiment_data(num_clusters=2, dimension=dimension) | ||
|
||
filename = args.path / f'plot_minkowski_function_with_p_{p}.png' | ||
samples, _, centroids = generate_mix_distribution( | ||
probability=prob, | ||
mu_list=mu_list, | ||
cov_matrices=cov_matrices, | ||
n_samples=n_points, | ||
t=0.1 | ||
) | ||
|
||
dim = 0 | ||
|
||
clusters, _ = assign_to_cluster(samples, centroids, n_clusters, p) | ||
cluster = np.array(clusters[0]) | ||
dimension_data = cluster[:,dim] | ||
|
||
points = np.linspace(min(dimension_data), max(dimension_data), n_observation) | ||
minkowski_values = distance_calculator.get_pairwise_distance( | ||
point_a = dimension_data, | ||
points=points, | ||
) | ||
|
||
fig, ax = plt.subplots(figsize=(5, 3)) | ||
ax.scatter(points, minkowski_values) | ||
ax.axis('off') | ||
fig.savefig(str(filename), dpi=300, bbox_inches='tight') | ||
plt.close(fig) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import json | ||
from pathlib import Path | ||
from typing import Dict | ||
|
||
|
||
class LogParser: | ||
def parse(self, log_path: Path) -> Dict[str, float]: | ||
with Path(log_path).open(encoding='utf-8') as log_buff: | ||
log_data = log_buff.read() | ||
log_data = log_data.replace('\'', '"') | ||
return json.loads(log_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from typing import Dict | ||
|
||
|
||
def process_metric_name(metric: str) -> str: | ||
name_map: Dict[str, str] = { | ||
'ari': 'ARI', | ||
'ami': 'AMI', | ||
'nmi': 'NMI', | ||
'v_measure': 'V-measure' | ||
} | ||
return name_map[metric] if metric in name_map else metric.capitalize() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from enum import Enum | ||
from typing import Dict | ||
|
||
|
||
class HighlightRule(Enum): | ||
MAX = 'max' | ||
MIN = 'min' | ||
NONE = 'none' | ||
|
||
|
||
def get_highlight_rules() -> Dict[str, HighlightRule]: | ||
return { | ||
'log_name': HighlightRule.NONE, | ||
'ari': HighlightRule.MAX, | ||
'ami': HighlightRule.MAX, | ||
'completeness': HighlightRule.MAX, | ||
'homogeneity': HighlightRule.MAX, | ||
'nmi': HighlightRule.MAX, | ||
'v_measure': HighlightRule.MAX, | ||
'accuracy': HighlightRule.MAX, | ||
'time': HighlightRule.MIN, | ||
'inertia': HighlightRule.MIN | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import json | ||
|
||
import pandas as pd | ||
|
||
from report.log_parser import LogParser | ||
from report.tables.highlight_rule import get_highlight_rules | ||
from report.tables.saver import LatexSaver | ||
from report.tables.styler import TableStyler | ||
from report.tables.table_argument_parser import TableArgumentParser | ||
|
||
|
||
def main() -> None: | ||
args = TableArgumentParser(underscores_to_dashes=True).parse_args() | ||
with args.config.open() as file: | ||
json_data = json.load(file) | ||
|
||
args.save_path.mkdir(parents=True, exist_ok=True) | ||
saver = LatexSaver(args.save_path / json_data['name']) | ||
|
||
parser = LogParser() | ||
data = [] | ||
for logs_block in json_data['logs'].values(): | ||
for log_name, log_path in logs_block.items(): | ||
if len(log_path.split(' ')) > 1: | ||
log_data_dict = json.loads(log_path.replace('\'', '"')) | ||
else: | ||
log_data_dict = parser.parse(log_path) | ||
log_data_dict = {'log_name': log_name, **log_data_dict} | ||
data.append(log_data_dict) | ||
data_frame = pd.DataFrame(data) | ||
|
||
rules = get_highlight_rules() | ||
styler = TableStyler(data_frame, json_data['columns'], rules).style() | ||
saver.save(styler) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
|
||
from pandas.io.formats.style import Styler | ||
|
||
|
||
class Saver(ABC): | ||
def __init__(self, file_name: Path) -> None: | ||
self._file_name = file_name | ||
self._convert_path() | ||
|
||
@abstractmethod | ||
def _convert_path(self) -> None: | ||
... | ||
|
||
@abstractmethod | ||
def save(self, styler: Styler) -> None: | ||
... | ||
|
||
|
||
class LatexSaver(Saver): | ||
|
||
def _convert_path(self) -> None: | ||
self._file_name = self._file_name.with_suffix('.tex') | ||
|
||
def save(self, styler: Styler) -> None: | ||
with self._file_name.open('w') as file: | ||
styler.format(escape='latex', precision=2).to_latex(file, convert_css=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Any, Dict, List | ||
|
||
import pandas as pd | ||
from pandas.io.formats.style import Styler | ||
|
||
from report.tables.highlight_rule import HighlightRule | ||
|
||
FORMAT_BOLD ='font-weight:bold;' | ||
|
||
|
||
class TableStyler: | ||
def __init__(self, data_frame: pd.DataFrame, columns: List[str], rules: Dict[str, Any]) -> None: | ||
self._data_frame = data_frame | ||
self._columns = columns | ||
self._rules = rules | ||
|
||
def _highlight_max(self, styler: Styler) -> Styler: | ||
columns_for_highlight = [name for name, highlight_rule in self._rules.items() \ | ||
if highlight_rule is HighlightRule.MAX] | ||
columns = list(set(self._columns).intersection(set(columns_for_highlight))) | ||
return styler.highlight_max(subset=columns, props=FORMAT_BOLD) | ||
|
||
def _highlight_min(self, styler: Styler) -> Styler: | ||
columns_for_highlight = [name for name, highlight_rule in self._rules.items() \ | ||
if highlight_rule is HighlightRule.MIN] | ||
columns = list(set(self._columns).intersection(set(columns_for_highlight))) | ||
return styler.highlight_min(subset=columns, props=FORMAT_BOLD) | ||
|
||
def _round_values(self, styler: Styler) -> Styler: | ||
columns_for_rounding = [name for name, highlight_rule in self._rules.items() \ | ||
if highlight_rule is not HighlightRule.NONE] | ||
columns = list(set(self._columns).intersection(set(columns_for_rounding))) | ||
return styler.format(lambda value: f'{value:.2f}', na_rep='N/A', subset=columns) | ||
|
||
def _hide_index(self, styler: Styler) -> Styler: | ||
return styler.hide() | ||
|
||
def _highlight_index(self, styler: Styler) -> Styler: | ||
return styler.applymap_index(lambda _: FORMAT_BOLD, axis='columns') | ||
|
||
def style(self) -> Styler: | ||
frame = self._data_frame[self._columns] | ||
styler = frame.style | ||
|
||
styler = self._highlight_max(styler) | ||
styler = self._highlight_min(styler) | ||
styler = self._hide_index(styler) | ||
styler = self._highlight_index(styler) | ||
styler = self._round_values(styler) | ||
return styler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from pathlib import Path | ||
|
||
from report.argument_parser import ArgumentParser | ||
|
||
|
||
class TableArgumentParser(ArgumentParser): | ||
save_path: Path = Path('./data/tables') |