
cleanup
NicoNeureiter committed Jul 19, 2022
1 parent 72cda2e commit 6fffd9a
Showing 6 changed files with 57 additions and 38 deletions.
13 changes: 6 additions & 7 deletions sbayes/config/config.py
@@ -7,11 +7,10 @@
 import io
 import ruamel.yaml
 
-
 from pydantic import BaseModel, Extra, Field
 from pydantic import validator, root_validator, ValidationError
 from pydantic import FilePath, DirectoryPath
-from pydantic import PositiveInt, PositiveFloat, confloat
+from pydantic import PositiveInt, PositiveFloat, confloat, NonNegativeFloat, NonNegativeInt
 
 from sbayes.util import fix_relative_path, decompose_config_path, PathLike
 from sbayes.util import update_recursive
@@ -238,19 +237,19 @@ class OperatorsConfig(BaseConfig):
 
     """The frequency of each MCMC operator. Will be normalized to 1.0 at runtime."""
 
-    clusters: PositiveFloat = 45.0
+    clusters: NonNegativeFloat = 45.0
     """Frequency at which the assignment of objects to clusters is changed."""
 
-    weights: PositiveFloat = 15.0
+    weights: NonNegativeFloat = 15.0
     """Frequency at which mixture weights are changed."""
 
-    cluster_effect: PositiveFloat = 5.0
+    cluster_effect: NonNegativeFloat = 5.0
     """Frequency at which cluster effect parameters are changed."""
 
-    confounding_effects: PositiveFloat = 15.0
+    confounding_effects: NonNegativeFloat = 15.0
    """Frequency at which confounding effects parameters are changed."""
 
-    source: PositiveFloat = 10.0
+    source: NonNegativeFloat = 10.0
    """Frequency at which the assignments of observations to mixture components are changed."""
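Why the switch from PositiveFloat to NonNegativeFloat matters: a frequency of 0.0 now passes validation, so individual operators can be disabled outright. A minimal sketch (illustrative model, not the sBayes config itself):

    from pydantic import BaseModel, NonNegativeFloat, ValidationError

    class Operators(BaseModel):
        clusters: NonNegativeFloat = 45.0
        weights: NonNegativeFloat = 15.0

    Operators(clusters=0.0)       # accepted: the operator is simply never proposed
    try:
        Operators(clusters=-1.0)  # still rejected at validation time
    except ValidationError as err:
        print(err)
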
28 changes: 15 additions & 13 deletions sbayes/load_data.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
 """ Imports the real world data """
 from __future__ import annotations
 import pandas as pd
 import pyproj
 from dataclasses import dataclass
 from logging import Logger
 from pathlib import Path
-from typing import List, Tuple, Dict, Optional, Union, Sequence
+from collections import OrderedDict
+from typing import List, Dict, Optional, Union, Sequence
 
 
 try:
@@ -68,7 +68,7 @@ def from_dataframe(cls, data: pd.DataFrame) -> "Objects":
         return cls(**objects_dict)
 
 
-@dataclass
+@dataclass(frozen=True)
 class Features:
 
     values: NDArray[bool]  # shape: (n_sites, n_features, n_states)
@@ -104,7 +104,7 @@ def from_dataframes(
         return cls(**features_dict, na_number=na_number)
 
 
-@dataclass
+@dataclass(frozen=True)
 class Confounder:
     name: str
     group_assignment: NDArray[bool]  # shape: (n_groups, n_sites)
@@ -117,6 +117,9 @@ def __getitem__(self, key):
             return self.group_assignment
         return getattr(self, key)
 
+    def any_group(self) -> NDArray[bool]:  # shape: (n_sites,)
+        return np.any(self.group_assignment, axis=0)
+
     @property
     def n_groups(self):
         return len(self.group_names)
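What frozen=True buys, in a minimal sketch (illustrative class, not the real Confounder): attribute rebinding is blocked after construction, so shared data objects cannot be swapped out mid-run. Note that the NumPy arrays inside remain mutable; only the bindings are frozen. The new any_group() collapses the group axis, answering "does this object belong to any group of this confounder?" per site:

    from dataclasses import dataclass, FrozenInstanceError
    import numpy as np

    @dataclass(frozen=True)
    class Groups:
        name: str
        group_assignment: np.ndarray  # shape: (n_groups, n_sites)

        def any_group(self) -> np.ndarray:  # shape: (n_sites,)
            return np.any(self.group_assignment, axis=0)

    g = Groups("family", np.array([[True, False, False], [False, True, False]]))
    print(g.any_group())   # [ True  True False]
    try:
        g.name = "area"    # raises: frozen instances reject attribute assignment
    except FrozenInstanceError as err:
        print(err)
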
@@ -168,7 +171,7 @@ class Data:
 
     objects: Objects
     features: Features
-    confounders: Dict[str, Confounder]
+    confounders: OrderedDict[str, Confounder]
     prior_confounders: ...
     crs: Optional[pyproj.CRS]
     geo_cost_matrix: Optional[NDArray[float]]
@@ -179,7 +182,7 @@ def __init__(
         self,
         objects: Objects,
         features: Features,
-        confounders: Dict[str, Confounder],
+        confounders: OrderedDict[str, Confounder],
         projection: Optional[str] = "epsg:4326",
         geo_costs: Union[Literal["from_data"], PathLike] = "from_data",
         logger: Logger = None,
@@ -245,7 +248,7 @@ def log_loading(logger):
         logger.info("##########################################")
 
 
-@dataclass
+@dataclass(frozen=True)
 class Prior:
     counts: NDArray[int]
     states: List
@@ -259,7 +262,7 @@ def read_features_from_csv(
     feature_states_path: PathLike,
     groups_by_confounder: Dict[str, list],
     logger: Optional[Logger] = None,
-) -> (Objects, Features, Dict[str, Confounder]):
+) -> (Objects, Features, OrderedDict[str, Confounder]):
     """This is a helper function to import data (objects, features, confounders) from a csv file
     Args:
         data_path: path to the data csv file.
@@ -276,10 +279,9 @@
 
     features = Features.from_dataframes(data, feature_states)
     objects = Objects.from_dataframe(data)
-    confounders = {
-        c: Confounder.from_dataframe(confounder_name=c, group_names=groups, data=data)
-        for c, groups in groups_by_confounder.items()
-    }
+    confounders = OrderedDict()
+    for c, groups in groups_by_confounder.items():
+        confounders[c] = Confounder.from_dataframe(confounder_name=c, group_names=groups, data=data)
 
     if logger:
         logger.info(
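On the Dict → OrderedDict change: plain dicts have preserved insertion order since Python 3.7, so the practical effect here is mainly documentation — the type hint now states that confounder order is meaningful. One behavioral difference worth knowing (a quick illustration, not sBayes code):

    from collections import OrderedDict

    d1 = OrderedDict([("family", 1), ("area", 2)])
    d2 = OrderedDict([("area", 2), ("family", 1)])
    print(d1 == d2)              # False: OrderedDict equality is order-sensitive
    print(dict(d1) == dict(d2))  # True: plain dict equality ignores order
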
3 changes: 1 addition & 2 deletions sbayes/plot.py
@@ -628,8 +628,7 @@ def add_legend_lines(cfg_graphic, cfg_legend, ax):
         leg_line_width.append(line)
 
         # Add legend text
-        prop_l = int(k * 100)
-        line_width_label.append(f'{prop_l}%')
+        line_width_label.append(f'{k:.0%}')
 
     # Adds everything to the legend
     legend_line_width = ax.legend(leg_line_width, line_width_label, title_fontsize=18,
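The one-line format spec replaces the manual scaling, and it rounds where the old int() cast truncated (a quick check in plain Python):

    k = 0.756
    print(f'{int(k * 100)}%')  # old style: 75% (truncates)
    print(f'{k:.0%}')          # new style: 76% (rounds)
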
9 changes: 7 additions & 2 deletions sbayes/postprocessing.py
@@ -77,6 +77,7 @@ def match_clusters(samples):
     return samples
 
 
+# TODO: update to generalized sBayes
 def contribution_per_cluster(mcmc_sampler):
     """Evaluate the contribution of each cluster to the lh and the posterior in each sample.
     Args:
@@ -107,8 +108,12 @@ def contribution_per_cluster(mcmc_sampler):
         cluster = stats['sample_clusters'][s][np.newaxis, z]
         cluster_effect = stats['sample_cluster_effect'][s][np.newaxis, z]
 
-        single_cluster = Sample(clusters=cluster, weights=weights,
-                                p_global=p_global, cluster_effect=cluster_effect, p_families=p_families)
+        single_cluster = Sample.from_numpy_arrays(
+            clusters=cluster,
+            weights=weights,
+            p_global=p_global,
+            cluster_effect=cluster_effect, p_families=p_families
+        )
 
         lh = mcmc_sampler.likelihood(single_cluster, 0)
         prior = mcmc_sampler.prior(single_cluster, 0)
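The direct Sample(...) construction gives way to an explicit named constructor. A hedged sketch of the pattern (the real Sample.from_numpy_arrays lives in sBayes and its signature may well differ):

    from dataclasses import dataclass
    import numpy as np

    @dataclass
    class State:
        clusters: np.ndarray
        weights: np.ndarray

        @classmethod
        def from_numpy_arrays(cls, clusters, weights) -> "State":
            # Convert eagerly so shape/type errors surface at construction time.
            return cls(clusters=np.asarray(clusters), weights=np.asarray(weights))

    s = State.from_numpy_arrays(clusters=[[True, False]], weights=[0.6, 0.4])
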
4 changes: 2 additions & 2 deletions sbayes/preprocessing.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import typing as typ
+from __future__ import annotations
 try:
     from typing import Literal
 except ImportError:
@@ -9,7 +9,7 @@
 import csv
 import sys
 import random
-from typing import Sequence, Dict, Union
+from typing import Sequence
 
 import numpy as np
 from numpy.typing import NDArray
38 changes: 26 additions & 12 deletions sbayes/util.py
@@ -1127,8 +1127,8 @@ def log_multinom(n, ks):
     """
     ks = np.asarray(ks)
-    assert np.all(ks >= 0)
-    assert np.sum(ks) <= n
+    # assert np.all(ks >= 0)
+    # assert np.sum(ks) <= n
 
     # Simple special case
     if np.sum(ks) == 0:
@@ -1149,11 +1149,11 @@ def log_multinom(n, ks):
     # If there is a remainder in the population that was not assigned to any of the
     # samples, subtract all permutations of the remainder population.
     rest = n - np.sum(ks)
-    assert rest >= 0
+    # assert rest >= 0
     if rest > 0:
         m -= log_i_cumsum[rest-1]
 
-    assert m >= 0, m
+    # assert m >= 0, m
     return m
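The assertions were commented out, presumably to save time on a hot path. For reference, the quantity log_multinom appears to compute — the log of the multinomial coefficient n! / (k1! · ... · km! · (n - sum k)!) — can be cross-checked against scipy (a sketch, not the sBayes implementation):

    import numpy as np
    from scipy.special import gammaln

    def log_multinom_ref(n, ks):
        ks = np.asarray(ks)
        rest = n - ks.sum()
        return gammaln(n + 1) - gammaln(ks + 1).sum() - gammaln(rest + 1)

    print(log_multinom_ref(10, [3, 2]))  # log(10! / (3! * 2! * 5!)) ≈ 7.83
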


@@ -1182,20 +1182,34 @@ def fix_relative_path(path: PathLike, base_directory: PathLike) -> Path:
     return base_directory / path
 
 
-def timeit(func):
+def timeit(units='s'):
+    SECONDS_PER_UNIT = {
+        'h': 3600.,
+        'm': 60.,
+        's': 1.,
+        'ms': 1E-3,
+        'µs': 1E-6,
+        'ns': 1E-9
+    }
+    unit_scaler = SECONDS_PER_UNIT[units]
 
+    def timeit_decorator(func):
 
-    def timed_func(*args, **kwargs):
+        def timed_func(*args, **kwargs):
 
-        start = time.time()
-        result = func(*args, **kwargs)
-        end = time.time()
+            start = time.time()
+            result = func(*args, **kwargs)
+            end = time.time()
+            passed = (end - start) / unit_scaler
 
-        print('Runtime %s: %.4fs' % (func.__name__, (end - start)))
+            print(f'Runtime {func.__name__}: {passed:.2f}{units}')
 
-        return result
+            return result
 
-    return timed_func
+        return timed_func
+
+    return timeit_decorator
 
 @lru_cache(maxsize=128)
 def get_permutations(n: int) -> List[Tuple[int]]:
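Usage of the parameterized decorator — note the call parentheses, which the old bare @timeit form did not need, so existing call sites must be updated:

    @timeit(units='ms')
    def slow_sum(n):
        return sum(range(n))

    slow_sum(10_000_000)  # prints something like: Runtime slow_sum: 312.45ms
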
