Skip to content

Commit

Permalink
Increased flexibility of number of graphs that can be plotted together.
Browse files Browse the repository at this point in the history
  • Loading branch information
pineapple-cat committed Dec 14, 2023
1 parent b723c42 commit b17338f
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions src/hivpy/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,31 @@ def graph_output(output_dir, output_stats, graph_outputs):
plt.close()


def compare_output(output_dir, df1, df2, graph_outputs, label1="HIVpy", label2="SAS"):
def compare_output(output_dir, output_stats, graph_outputs, label1="HIVpy", label2="SAS"):

for out in graph_outputs:
if out in df1.columns and out in df2.columns:

_, ax = plt.subplots()
plt.plot(df1["Date"], df1[out], label=label1)
plt.plot(df2["Date"], df2[out], label=label2)
_, ax = plt.subplots()

title_out = titlecase(out)
plt.xlabel("Date")
plt.ylabel(title_out)
plt.title("Comparison of {0} Over Time".format(title_out))
ax.legend()
plt.savefig(os.path.join(output_dir, "Comparison of {0} Over Time".format(title_out)), bbox_inches='tight')
plt.close()
for i in range(len(output_stats)):
df = output_stats[i]
if out in df.columns:

if i == 0:
plt.plot(df["Date"], df[out], label=label1)
elif i == 1:
plt.plot(df["Date"], df[out], label=label2)
else:
# plot additional files without label for now
plt.plot(df["Date"], df[out])

title_out = titlecase(out)
plt.xlabel("Date")
plt.ylabel(title_out)
plt.title("Comparison of {0} Over Time".format(title_out))
ax.legend()
plt.savefig(os.path.join(output_dir, "Comparison of {0} Over Time".format(title_out)), bbox_inches='tight')
plt.close()


def run_post():
Expand All @@ -63,9 +72,9 @@ def run_post():
if len(input) == 1:
print("graphing outputs")
graph_output(args.output_dir, input[0], graph_outputs)
if len(input) == 2:
if len(input) > 1:
print("comparing outputs")
compare_output(args.output_dir, input[0], input[1], graph_outputs)
compare_output(args.output_dir, input, graph_outputs)

except yaml.YAMLError as err:
print("Error parsing yaml file {}".format(err))
Expand Down

0 comments on commit b17338f

Please sign in to comment.