Skip to content

Commit

Permalink
code review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sblackburn-mila committed Mar 28, 2024
1 parent de29ae6 commit 428109c
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions crystal_diffusion/analysis/dataset_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ def compute_metrics_for_a_run(run_path: str) -> Dict[str, pd.Series]:
y0s = df['y'][0]
z0s = df['z'][0]

displacement = df.apply(lambda row: [(x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2 for x, y, z, x0, y0, z0 in
zip(row['x'], row['y'], row['z'], x0s, y0s, z0s)], axis=1)
# compute the square displacement: d^2 = (x - x0)^2

metrics['mean_square_displacement'] = displacement.apply(lambda row: np.sqrt(np.mean(row)))
metrics['std_square_displacement'] = displacement.apply(lambda row: np.sqrt(np.std(row)))
square_displacement = df.apply(lambda row: [(x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2 for x, y, z, x0, y0, z0 in
zip(row['x'], row['y'], row['z'], x0s, y0s, z0s)], axis=1)

metrics['root_mean_square_displacement'] = square_displacement.apply(lambda row: np.sqrt(np.mean(row)))

metrics['std_displacement'] = np.sqrt(sum([np.var(df[x]) for x in ['x', 'y', 'z']]))

return metrics

Expand Down Expand Up @@ -88,8 +91,8 @@ def plot_metrics_runs(dataset_name: str, mode: str = 'train'):
for k, m in metrics.items():
axs[0].plot(m['energy'], '-', lw=2)
axs[1].plot(m['force_norm_average'], ':', lw=2)
axs[2].plot(m['mean_square_displacement'], lw=2)
axs[3].plot(m['std_square_displacement'], lw=2)
axs[2].plot(m['root_mean_square_displacement'], lw=2)
axs[3].plot(m['std_displacement'], lw=2)
legend.append(k)

for ax in axs:
Expand Down

0 comments on commit 428109c

Please sign in to comment.