Skip to content

Commit

Permalink
refactor(utils.py): flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
chanshing committed Nov 19, 2024
1 parent a445f6d commit 8993562
Showing 1 changed file with 40 additions and 38 deletions.
78 changes: 40 additions & 38 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from datetime import datetime, timedelta, time
from pandas.plotting import register_matplotlib_converters
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display, HTML
import numpy as np
import sklearn.metrics as metrics
Expand Down Expand Up @@ -78,7 +85,7 @@ def encode_one_hot(y):
3 -> 0,0,0,1,0
4 -> 0,0,0,0,1
'''
return (y.reshape(-1,1) == np.arange(NUM_CLASSES)).astype(int)
return (y.reshape(-1, 1) == np.arange(NUM_CLASSES)).astype(int)


def train_hmm(Y_pred, y_true):
Expand All @@ -87,12 +94,12 @@ def train_hmm(Y_pred, y_true):
if Y_pred.ndim == 1 or Y_pred.shape[1] == 1:
Y_pred = encode_one_hot(Y_pred)

prior = np.mean(y_true.reshape(-1,1) == np.arange(NUM_CLASSES), axis=0)
prior = np.mean(y_true.reshape(-1, 1) == np.arange(NUM_CLASSES), axis=0)
emission = np.vstack(
[np.mean(Y_pred[y_true==i], axis=0) for i in range(NUM_CLASSES)]
[np.mean(Y_pred[y_true == i], axis=0) for i in range(NUM_CLASSES)]
)
transition = np.vstack(
[np.mean(y_true[1:][(y_true==i)[:-1]].reshape(-1,1) == np.arange(NUM_CLASSES), axis=0)
[np.mean(y_true[1:][(y_true == i)[:-1]].reshape(-1, 1) == np.arange(NUM_CLASSES), axis=0)
for i in range(NUM_CLASSES)]
)
return prior, emission, transition
Expand All @@ -107,19 +114,19 @@ def log(x):

num_obs = len(y_pred)
probs = np.zeros((num_obs, NUM_CLASSES))
probs[0,:] = log(prior) + log(emission[:, y_pred[0]])
probs[0, :] = log(prior) + log(emission[:, y_pred[0]])
for j in range(1, num_obs):
for i in range(NUM_CLASSES):
probs[j,i] = np.max(
log(emission[i, y_pred[j]]) + \
log(transition[:, i]) + \
probs[j-1,:]) # probs already in log scale
probs[j, i] = np.max(
log(emission[i, y_pred[j]]) +
log(transition[:, i]) +
probs[j - 1, :]) # probs already in log scale
viterbi_path = np.zeros_like(y_pred)
viterbi_path[-1] = np.argmax(probs[-1,:])
for j in reversed(range(num_obs-1)):
viterbi_path[-1] = np.argmax(probs[-1, :])
for j in reversed(range(num_obs - 1)):
viterbi_path[j] = np.argmax(
log(transition[:, viterbi_path[j+1]]) + \
probs[j,:]) # probs already in log scale
log(transition[:, viterbi_path[j + 1]]) +
probs[j, :]) # probs already in log scale

return viterbi_path

Expand All @@ -132,11 +139,11 @@ def compute_scores(y_true, y_pred):
balanced_acuracy = metrics.balanced_accuracy_score(y_true, y_pred)
kappa = metrics.cohen_kappa_score(y_true, y_pred)
return {
'confusion':confusion,
'per_class_recall':per_class_recall,
'confusion': confusion,
'per_class_recall': per_class_recall,
'accuracy': accuracy,
'balanced_accuracy': balanced_acuracy,
'kappa':kappa,
'kappa': kappa,
}


Expand All @@ -158,17 +165,12 @@ def print_scores(scores):
# ----------------------------------------
# Function to plot activity timeseries
# ----------------------------------------
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()
from datetime import datetime, timedelta, time


def plot_activity(x, y, t):
''' Plot activity timeseries '''
BACKGROUND_COLOR = '#d3d3d3' # lightgray
BACKGROUND_COLOR = '#d3d3d3' # lightgray

def split_by_timegap(group, seconds=30):
subgroupIDs = (group.index.to_series().diff() > timedelta(seconds=seconds)).cumsum()
Expand All @@ -177,9 +179,9 @@ def split_by_timegap(group, seconds=30):

convert_date = np.vectorize(
lambda day, x: matplotlib.dates.date2num(datetime.combine(day, x)))
timeseries = pd.DataFrame(data={'x':x, 'y':y, 't':t})
timeseries = pd.DataFrame(data={'x': x, 'y': y, 't': t})
timeseries.set_index('t', inplace=True)
timeseries['x'] = timeseries['x'].rolling(window=12, min_periods=1).mean() #! inplace?
timeseries['x'] = timeseries['x'].rolling(window=12, min_periods=1).mean() # ! inplace?
ylim_min, ylim_max = np.min(x), np.max(x)
groups = timeseries.groupby(timeseries.index.date)
fig, axs = plt.subplots(nrows=len(groups) + 1)
Expand All @@ -192,16 +194,16 @@ def split_by_timegap(group, seconds=30):

ax.get_xaxis().grid(True, which='major', color='grey', alpha=0.5)
ax.get_xaxis().grid(True, which='minor', color='grey', alpha=0.25)
ax.set_xlim((datetime.combine(day,time(0, 0, 0, 0)),
datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0))))
ax.set_xticks(pd.date_range(start=datetime.combine(day,time(0, 0, 0, 0)),
end=datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0)),
freq='4H'))
ax.set_xticks(pd.date_range(start=datetime.combine(day,time(0, 0, 0, 0)),
end=datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0)),
freq='1H'), minor=True)
ax.set_xlim((datetime.combine(day, time(0, 0, 0, 0)),
datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0))))
ax.set_xticks(pd.date_range(start=datetime.combine(day, time(0, 0, 0, 0)),
end=datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0)),
freq='4H'))
ax.set_xticks(pd.date_range(start=datetime.combine(day, time(0, 0, 0, 0)),
end=datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0)),
freq='1H'), minor=True)
ax.set_ylim((ylim_min, ylim_max))
ax.get_yaxis().set_ticks([]) # hide y-axis labels
ax.get_yaxis().set_ticks([]) # hide y-axis labels
ax.spines['top'].set_color(BACKGROUND_COLOR)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
Expand All @@ -225,16 +227,16 @@ def split_by_timegap(group, seconds=30):
for color, label in zip(COLORS, CLASSES):
legend_patches.append(mpatches.Patch(facecolor=color, label=label, alpha=0.5))
axs[-1].legend(handles=legend_patches, bbox_to_anchor=(0., 0., 1., 1.),
loc='center', ncol=min(3,len(legend_patches)), mode="best",
borderaxespad=0, framealpha=0.6, frameon=True, fancybox=True)
loc='center', ncol=min(3, len(legend_patches)), mode="best",
borderaxespad=0, framealpha=0.6, frameon=True, fancybox=True)
axs[-1].spines['left'].set_visible(False)
axs[-1].spines['right'].set_visible(False)
axs[-1].spines['top'].set_visible(False)
axs[-1].spines['bottom'].set_visible(False)

# format x-axis to show hours
fig.autofmt_xdate()
hours = [(str(hr) + 'am') if hr<=12 else (str(hr-12) + 'pm') for hr in range(0,24,4)]
hours = [(str(hr) + 'am') if hr <= 12 else (str(hr - 12) + 'pm') for hr in range(0, 24, 4)]
axs[0].set_xticklabels(hours)
axs[0].tick_params(labelbottom=False, labeltop=True, labelleft=False)

Expand Down

0 comments on commit 8993562

Please sign in to comment.