Skip to content

Commit

Permalink
Update reports.py
Browse files Browse the repository at this point in the history
small fixes and added docstrings
#104 , #79
  • Loading branch information
cb-Hades committed Dec 5, 2023
1 parent 231e232 commit b6b28f1
Showing 1 changed file with 76 additions and 13 deletions.
89 changes: 76 additions & 13 deletions refinegems/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,70 @@ def add_sim_results(self, new_rep: SingleGrowthSimulationReport):
self.models.add(new_rep.model_name)
self.media.add(new_rep.medium_name)

def to_table(self):
def to_table(self) -> pd.DataFrame:
"""Return a table of the contents of the report.
Returns:
pd.DataFrame: The table containing the information in the report.
"""

l = []
for report in self.reports:
l.append(report.to_dict())
return pd.DataFrame(l)


def plot_growth(self, unit:Literal['h','dt']='dt', color_palette:str='YlGn') -> matplotlib.figure.Figure:
"""Visualise the contents of the report.
def plot_growth(self, unit:Literal['h','dt']='dt'):
Args:
unit (Literal['h','dt'], optional): Set the unit to plot.
Can be doubling time in minutes ('dt') or growth rates in mmol/gDWh ('h').
Defaults to 'dt'.
color_palette (str, optional): A colour gradient from the matplotlib library.
If the name does not exist, uses the default.
Defaults to 'YlGn'.
Returns:
matplotlib.figure.Figure: The plotted figure.
"""

#@ TODO add options / params or so to change the plot e.g. colour or size
def plot_growth_bar(xdata, xlab, ydata, ylab, title):
def plot_growth_bar(xdata:list[str], xlab:str, ydata:list[float],
ylab:str, title:str, color_palette:str='YlGn') -> matplotlib.figure.Figure:
"""Helper function to plot the bar plot for the growth visualisation.
Args:
xdata (list[str]): List of the x-axis data (medium or model names).
xlab (str): The x-axis label.
ydata (list[float]): List of thr y-axis data (the values).
ylab (str): The y-axis label.
title (str): The title of the plot.
color_palette (str, optional): A colour gradient from the matplotlib library.
If the name does not exist, uses the default.
Defaults to 'YlGn'.
Returns:
matplotlib.figure.Figure: The plotted figure.
"""

# create colour gradient
cmap = sns.cubehelix_palette(start=0.5, rot=-.75, gamma=0.75, dark=0.25, light=0.6, reverse=True, as_cmap=True)
try:
cmap = matplotlib.cm.get_cmap(color_palette).copy()
except ValueError:
warnings.warn('Unknown color palette, setting it to "YlGn"')
cmap = matplotlib.cm.get_cmap('YlGn').copy()

# set up the figure
fig = plt.figure()
ax = fig.add_axes([0,0,1,1])

# construct the plot
max_ydata = max(ydata)
if max_ydata <= 0:
warnings.warn('Model is not able to grow on every medium. Returning empty figure.')
return fig
cont = ax.bar(list(xdata), ydata, color=cmap([_/max_ydata for _ in ydata]))

# set labels and others
ax.bar_label(cont, fmt='%.2f', color='black', padding=1.0)
ax.set_ylabel(ylab, labelpad=12)
Expand All @@ -116,8 +156,22 @@ def plot_growth_bar(xdata, xlab, ydata, ylab, title):

return fig

def plot_growth_heatmap(data):
print(data)

def plot_growth_heatmap(data: pd.DataFrame, color_palette:str='YlGn') -> matplotlib.figure.Figure:
"""Helper function to plot the heatmap for the growth visualisation.
Args:
data (pd.DataFrame): The table containing the data to be plotted.
Needs to have the columns 'medium', 'model' and one for the growth values.
color_palette (str, optional): A colour gradient from the matplotlib library.
If the name does not exist, uses the default.
Defaults to 'YlGn'.
Returns:
matplotlib.figure.Figure: The plotted figure.
"""

# clean up + transform data
growth=data.set_index(['medium', 'model']).sort_index().T.stack()
growth.columns.name=None
growth.index.names = (None,None)
Expand All @@ -135,9 +189,17 @@ def plot_growth_heatmap(data):
annot = annot.round().astype(int)
annot[annot < 1e-5] = ''
annot.replace(over_growth.round().astype(int), 'No data', inplace=True)
cmap=matplotlib.cm.get_cmap('YlGn').copy()
cmap.set_under('black')
cmap.set_over('white')

# setting the colours
try:
cmap = matplotlib.cm.get_cmap(color_palette).copy()
except ValueError:
warnings.warn('Unknown color palette, setting it to "YlGn"')
cmap = matplotlib.cm.get_cmap('YlGn').copy()
cmap.set_under('black') # too low / no growth
cmap.set_over('white') # no data

# plot the heatmap
fig, ax = plt.subplots(figsize=(10,8))
sns.heatmap(growth.T,
annot=annot.T,
Expand All @@ -153,6 +215,7 @@ def plot_growth_heatmap(data):
rotation = 40 if len(growth.index) > 3 else 0
plt.tick_params(rotation=0, bottom=False, top=False, left=False, right=False)
ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation, ha="right")

return fig


Expand All @@ -176,7 +239,7 @@ def plot_growth_heatmap(data):
title = f'Growth simulation on {next(iter(self.media))} for different models'

# plot
return plot_growth_bar(xdata,xlab,ydata,ylab,title)
return plot_growth_bar(xdata,xlab,ydata,ylab,title, color_palette)

# one model vs mutiple media
elif len(self.models) == 1 and len(self.media) > 1:
Expand All @@ -189,14 +252,14 @@ def plot_growth_heatmap(data):
title = f'Growth simulation for {next(iter(self.models))} on different media'

# plot
return plot_growth_bar(xdata,xlab,ydata,ylab,title)
return plot_growth_bar(xdata,xlab,ydata,ylab,title, color_palette)

# multiple vs multiple
elif len(self.models) > 1 and len(self.media) > 1:

data = pd.DataFrame({'model':[_.model_name for _ in self.reports], 'medium':[_.medium_name for _ in self.reports], 'growth':[_.growth_value if unit=='h' else _.doubling_time for _ in self.reports]})

return plot_growth_heatmap(data)
return plot_growth_heatmap(data, color_palette)

# problematic case
else:
Expand Down

0 comments on commit b6b28f1

Please sign in to comment.