Skip to content

Commit

Permalink
Minor code improvement, including adding more imports to __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
ruancomelli committed Jun 11, 2021
1 parent 849550c commit ccdbd20
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 271 deletions.
14 changes: 7 additions & 7 deletions boiling_learning/datasets/creators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import boiling_learning.utils as bl_utils
from boiling_learning.datasets.datasets import (
DatasetSplits,
tf_concatenate,
tf_train_val_test_split,
concatenate,
train_val_test_split,
)
from boiling_learning.io.io import DatasetTriplet
from boiling_learning.management.Manager import Manager
Expand Down Expand Up @@ -41,7 +41,7 @@ def experiment_video_dataset_creator(
t.as_tf_py_function(pack_tuple=True), num_parallel_calls=AUTOTUNE
)

ds_train, ds_val, ds_test = tf_train_val_test_split(ds, splits)
ds_train, ds_val, ds_test = train_val_test_split(ds, splits)

if dataset_size is not None:
ds_train = ds_train.take(dataset_size)
Expand All @@ -53,7 +53,7 @@ def experiment_video_dataset_creator(
snapshot_path = bl_utils.ensure_dir(snapshot_path)

if dataset_size is not None:
num_shards = min([dataset_size, num_shards])
num_shards = min(dataset_size, num_shards)

ds_train = ds_train.apply(
bl_preprocessing.snapshotter(
Expand Down Expand Up @@ -175,12 +175,12 @@ def dataset_creator(
tuple, mit.unzip(ds_dict.values())
)

ds_train = tf_concatenate(datasets_train)
ds_train = concatenate(datasets_train)
if None in datasets_val:
ds_val = None
else:
ds_val = tf_concatenate(datasets_val)
ds_test = tf_concatenate(datasets_test)
ds_val = concatenate(datasets_val)
ds_test = concatenate(datasets_test)

if dataset_size is not None:
ds_train = ds_train.take(dataset_size)
Expand Down
29 changes: 14 additions & 15 deletions boiling_learning/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def to_str(self):
)


def tf_concatenate(datasets: Iterable[tf.data.Dataset]) -> tf.data.Dataset:
def concatenate(datasets: Iterable[tf.data.Dataset]) -> tf.data.Dataset:
datasets = deque(datasets)

if not datasets:
Expand All @@ -128,7 +128,14 @@ def tf_concatenate(datasets: Iterable[tf.data.Dataset]) -> tf.data.Dataset:
return ds


def tf_train_val_test_split(
def flatten_zip(*ds: tf.data.Dataset) -> tf.data.Dataset:
if len(ds) == 1:
return ds[0]
else:
return tf.data.Dataset.zip(ds)


def train_val_test_split(
ds: tf.data.Dataset, splits: DatasetSplits, shuffle: bool = False
) -> DatasetTriplet:
"""#TODO(docstring): describe here
Expand All @@ -139,12 +146,6 @@ def tf_train_val_test_split(
Consequently it is safe to be used early in the pipeline to avoid consuming test data.
"""

def flatten_zip(*ds: tf.data.Dataset) -> tf.data.Dataset:
if len(ds) == 1:
return ds[0]
else:
return tf.data.Dataset.zip(ds)

if splits.val == 0:
split_train, split_test = mathutils.proportional_ints(
splits.train, splits.test
Expand Down Expand Up @@ -181,7 +182,7 @@ def flatten_zip(*ds: tf.data.Dataset) -> tf.data.Dataset:
return ds_train, ds_val, ds_test


def tf_train_val_test_split_concat(
def train_val_test_split_concat(
datasets: Iterable[tf.data.Dataset], splits: DatasetSplits
) -> DatasetTriplet:
"""#TODO(docstring): describe here
Expand All @@ -199,16 +200,16 @@ def tf_train_val_test_split_concat(
raise ValueError('argument *datasets* must be a non-empty iterable.')

ds = datasets.popleft()
ds_train, ds_val, ds_test = tf_train_val_test_split(ds, splits)
ds_train, ds_val, ds_test = train_val_test_split(ds, splits)

if ds_val is None:
for ds in datasets:
ds_train_, _, ds_test_ = tf_train_val_test_split(ds, splits)
ds_train_, _, ds_test_ = train_val_test_split(ds, splits)
ds_train = ds_train.concatenate(ds_train_)
ds_test = ds_test.concatenate(ds_test_)
else:
for ds in datasets:
ds_train_, ds_val_, ds_test_ = tf_train_val_test_split(ds, splits)
ds_train_, ds_val_, ds_test_ = train_val_test_split(ds, splits)
ds_train = ds_train.concatenate(ds_train_)
ds_val = ds_val.concatenate(ds_val_)
ds_test = ds_test.concatenate(ds_test_)
Expand Down Expand Up @@ -266,9 +267,7 @@ def take(
if isinstance(count, int):
return ds.take(count)
elif isinstance(count, Fraction):
return tf_train_val_test_split(ds, splits=DatasetSplits(train=count))[
0
]
return train_val_test_split(ds, splits=DatasetSplits(train=count))[0]
else:
raise TypeError(
f'*count* must be either *int* or *Fraction*, got {type(count)}.'
Expand Down
1 change: 1 addition & 0 deletions boiling_learning/preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from boiling_learning.preprocessing.Case import *
from boiling_learning.preprocessing.ExperimentalData import *
from boiling_learning.preprocessing.ExperimentVideo import *
from boiling_learning.preprocessing.ImageDataset import *
from boiling_learning.preprocessing.preprocessing import *
6 changes: 5 additions & 1 deletion boiling_learning/preprocessing/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from typing import Any, Iterable, Optional, Tuple, TypeVar, Union

import numpy as np
Expand Down Expand Up @@ -159,7 +160,10 @@ def grayscale(image: ImageType) -> tf.Tensor:
def downscale(
image: ImageType, factors: Tuple[int, int], antialias: bool = False
) -> tf.Tensor:
sizes = (image.shape[0] // factors[0], image.shape[1] // factors[1])
sizes = (
math.ceil(image.shape[0] / factors[0]),
math.ceil(image.shape[1] / factors[1]),
)
return tf.image.resize(
image, sizes, method='bilinear', antialias=antialias
)
Expand Down
1 change: 1 addition & 0 deletions boiling_learning/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from boiling_learning.utils.FrozenDict import *
from boiling_learning.utils.Parameters import *
from boiling_learning.utils.utils import *
3 changes: 2 additions & 1 deletion image_generators/steady_state.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import modin.pandas as pd
import seaborn as sns
from pathlib2 import Path

STEADY_STATE_DIR = Path(__file__).parent.parent / 'Selected Experiments'
STEADY_STATE_PATH = STEADY_STATE_DIR / 'SteadyState.csv'
Expand Down
23 changes: 4 additions & 19 deletions run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from pathlib import Path

import matplotlib as mpl
import modin.pandas as pd
import nidaqmx
import numpy as np
Expand Down Expand Up @@ -66,7 +67,6 @@ def surface_area(self):
)


#%%
# -------------------------------------------------------
# Settings
# -------------------------------------------------------
Expand Down Expand Up @@ -224,7 +224,6 @@ def correct_wire_temperature(reference_file):

# For timestamps: <https://knowledge.ni.com/KnowledgeArticleDetails?id=kA00Z000000kJy2SAE&l=pt-BR>

#%%
"""
Support definitions -------------------------------------------------------
"""
Expand Down Expand Up @@ -283,7 +282,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):
bl.utils.print_verbose(cond, *args, **kwargs)


#%%
# -------------------------------------------------------
# Channel definitions
# -------------------------------------------------------
Expand Down Expand Up @@ -339,7 +337,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):
ChannelType.INPUT,
)

#%%
"""
Print system information -------------------------------------------------------
"""
Expand All @@ -361,7 +358,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):
('anything', 'info'), f'> Types in ChannelType: {[x for x in ChannelType]}'
)

#%%
"""
Load calibration polynomial -------------------------------------------------------
"""
Expand All @@ -374,7 +370,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):
f'> Calibrated polynomial:\n{calibrated_polynomial}',
)

#%%
# -------------------------------------------------------
# Initialize
# -------------------------------------------------------
Expand All @@ -397,7 +392,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):

output_writer = csv.writer(output_file)

#%%
# -------------------------------------------------------
# Setup channels
# -------------------------------------------------------
Expand Down Expand Up @@ -458,7 +452,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):
f'experiment samp_clk_rate: {experiment.timing.samp_clk_rate}',
)

#%%
# -------------------------------------------------------
# Run experiment
# -------------------------------------------------------
Expand All @@ -472,20 +465,17 @@ def print_if_must(keys, *args, conds=None, **kwargs):
# Header
# -------------------------------------------------------
print_if_must(('anything', 'info'), f'> Iteration {iter_count}')
#%%
# -------------------------------------------------------
# Time measurement
# -------------------------------------------------------
elapsed_time = np.array([time.time() - start])
#%%
# -------------------------------------------------------
# Read data
# -------------------------------------------------------
readings[experiment.name] = experiment.read(
number_of_samples_per_channel=nidaqmx.constants.READ_ALL_AVAILABLE
)

#%%
# -------------------------------------------------------
# Process data
# -------------------------------------------------------
Expand Down Expand Up @@ -524,7 +514,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):
experiment, readings, dtype=np.array
)

#%%
# -------------------------------------------------------
# Saving
# -------------------------------------------------------
Expand All @@ -542,7 +531,7 @@ def print_if_must(keys, *args, conds=None, **kwargs):
# 'Temperature from Resistance [deg C]': wire_temperature_from_resistance,
# 'Wire Temperature (corrected) [deg C]': wire_temperature_corrected,
}
n_values = min([local.size for local in local_data.values()])
n_values = min(local.size for local in local_data.values())

# Time measurement
if n_values > 1:
Expand All @@ -569,6 +558,7 @@ def print_if_must(keys, *args, conds=None, **kwargs):

# TODO: here
if measure_loop_time:
previous_elapsed_time = None
if first:
loop_time = np.zeros(n_values)
else:
Expand Down Expand Up @@ -608,7 +598,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):
time.sleep(sleeping_time)
continue

#%%
# -------------------------------------------------------
# Writing to file
# -------------------------------------------------------
Expand All @@ -626,7 +615,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):

print_if_must(('anything', 'writing'), '>> Done')

#%%
"""
Printing -------------------------------------------------------
"""
Expand Down Expand Up @@ -670,7 +658,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):
# print_if_must(('anything', 'temperature from resistance'), f'>> Temperature from Resistance [deg C]: {wire_temperature_from_resistance}', conds=[wire_temperature_from_resistance.size > 0])
# print_if_must(('anything', 'wire temperature corrected'), f'>> Wire Temperature (corrected) [deg C]: {wire_temperature_corrected}', conds=[wire_temperature_corrected.size > 0])

#%%
# -------------------------------------------------------
# Plotting
# -------------------------------------------------------
Expand Down Expand Up @@ -748,7 +735,6 @@ def print_if_must(keys, *args, conds=None, **kwargs):

QtGui.QApplication.processEvents()

#%%
# -------------------------------------------------------
# Finish iteration
# -------------------------------------------------------
Expand All @@ -764,11 +750,10 @@ def print_if_must(keys, *args, conds=None, **kwargs):
##################


#%%
# -------------------------------------------------------
# Plot Results
# -------------------------------------------------------
matplotlib.use('Agg')
mpl.use('Agg')

# datatype = [
# ('index', np.float32),
Expand Down
Loading

0 comments on commit ccdbd20

Please sign in to comment.