-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
112 lines (94 loc) · 3.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import torch as th
import imageio.v3 as imageio
import os
from pathlib import Path
import math
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
def draw_patterns(
    positions: th.Tensor, na: float, indices: list[th.Tensor], root_path: Path
):
    """Save one PDF per pattern to ``<root_path>/patterns/NN.pdf``.

    Every figure shows a gray disc centred at the origin on a black
    background, with a white marker for each position selected by that
    pattern.

    Args:
        positions: Tensor indexed as ``positions[k, :2]`` (x, y of point ``k``)
            and ``positions[0, 2]`` (the height used for the disc radius).
            Coordinates are divided by 1000 before plotting (presumably a
            unit change — TODO confirm).
        na: Numerical aperture. It must satisfy ``0 <= na < 1``, otherwise the
            radius formula divides by zero or takes the root of a negative.
        indices: One index collection per pattern. ``indices[i]`` selects the
            rows of ``positions`` highlighted in file ``i``.
        root_path: Output root (``str`` or ``Path``). The ``patterns``
            sub-directory is created if missing.
    """
    out_dir = Path(root_path) / "patterns"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Disc radius |z0| * na / sqrt(1 - na^2), i.e. |z0| * tan(asin(na)).
    # It depends only on `positions` and `na`, so it is computed once.
    r = math.sqrt(na**2 * positions[0, 2] ** 2 / (1 - na**2)) / 1000
    xy = positions[:, :2].detach().cpu().numpy() / 1000

    for i, pat_indices in enumerate(indices):
        # The figure patch must stay visible. `set_axis_off()` stops the axes
        # patch from being drawn, so the figure facecolor is what provides the
        # black background. The previous `frameon="false"` was a truthy
        # string and therefore already behaved like True.
        fig, ax = plt.subplots(frameon=True, figsize=(5, 5))
        try:
            fig.patch.set_facecolor("k")
            ax.patch.set_facecolor("k")
            ax.axis("equal")
            ax.set_axis_off()
            ax.add_patch(
                Circle((0, 0), r, facecolor="gray", edgecolor="none", linewidth=None)
            )
            # Fully transparent markers at every position draw nothing, but
            # they still enter the data limits. Every figure therefore spans
            # all positions (presumably the intent — TODO confirm).
            ax.scatter(
                *xy.T, facecolors="none", edgecolors="darkgray", s=100, alpha=0
            )
            # All highlighted positions are drawn in a single scatter call.
            sel = th.as_tensor(pat_indices, dtype=th.long).cpu().numpy()
            ax.scatter(*xy[sel].T, facecolors="w", edgecolors="w", s=100)
            fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            fig.savefig(out_dir / f"{i:02d}.pdf")
        finally:
            # Close this specific figure, even on error, so figures do not
            # accumulate in pyplot's global registry.
            plt.close(fig)
def dump_experiments(
    x: th.Tensor,
    path: Path,
    crop: int,
    *,
    lineplot_j: int = 196,
    lineplot_istart: int = 115,
    lineplot_iend: int = 128,
):
    """Write a cropped, left-right flipped slice of ``x`` and a line profile.

    Files written to ``path``:
        ``vals.csv``: the ``x,y`` profile, i.e. column ``lineplot_j`` of the
            flipped image over rows ``[lineplot_istart, lineplot_iend)``.
        ``span.csv``: the min and max of the flipped image before rescaling.
        ``x_est.png``: the flipped image min-max rescaled to 8 bits
            (all zeros if the image is constant).

    Args:
        x: Tensor indexed as ``x[0, 0, rows, cols]``. Only that 2-D slice is
            used, and ``x`` itself is not modified.
        path: Output directory (``str`` or ``Path``), created if missing.
        crop: Number of pixels removed from each edge of both image axes.
            ``0`` keeps the full image.
        lineplot_j, lineplot_istart, lineplot_iend: Coordinates of the
            exported line profile. The defaults are the previously
            hard-coded values.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # `crop:-crop` would be the empty slice `0:0` when crop == 0.
    stop = -crop if crop > 0 else None
    # `np.array` copies the data. `.cpu().numpy()` shares memory with a CPU
    # tensor, so the in-place rescaling below would otherwise overwrite the
    # caller's data.
    img = np.array(np.fliplr(x[0, 0, crop:stop, crop:stop].detach().cpu().numpy()))

    vals = img[lineplot_istart:lineplot_iend, lineplot_j]
    profile = np.stack((np.arange(vals.shape[0]), vals)).T
    np.savetxt(
        path / "vals.csv", profile, delimiter=",", header="x,y", comments="", fmt="%.5f"
    )

    span = np.array([img.min().item(), img.max().item()])[None]
    np.savetxt(path / "span.csv", span, delimiter=",", fmt="%.1f")

    img -= img.min()
    peak = img.max()
    if peak > 0:  # a constant image would otherwise give 0/0 = NaN
        img /= peak
    img *= 255.0
    imageio.imwrite(path / "x_est.png", img.astype(np.uint8))
def snr(x, y):
    """Return the SNR in dB of estimate ``y`` against reference ``x``.

    The value is ``10 * log10(sum(x**2) / sum((x - y)**2))``, returned as a
    0-d tensor. It is ``+inf`` when ``y`` equals ``x`` exactly and ``x`` is
    not all zeros.
    """
    residual = x - y
    ratio = x.square().sum() / residual.square().sum()
    return th.log10(ratio).mul(10)
def rmse(x, y):
    """Return the root-mean-square difference between ``x`` and ``y`` as a 0-d tensor."""
    mse = th.mean(th.square(x - y))
    return th.sqrt(mse)
def _to_uint8_image(a: np.ndarray) -> np.ndarray:
    """Min-max rescale ``a`` to [0, 255] and truncate to uint8 (all zeros if ``a`` is constant)."""
    a = a - a.min()
    peak = a.max()
    if peak > 0:  # a constant array would otherwise give 0/0 = NaN
        a = a / peak
    return (a * 255.0).astype(np.uint8)


def dump_simulation(x, ref, path):
    """Write phase metrics and images comparing ``x`` against ``ref``.

    Both inputs are reduced to their phase with ``th.angle``. Files written
    to ``path``:
        ``metrics.csv``: the phase SNR in dB (``ref`` phase as the signal)
            and the phase RMSE.
        ``span.csv``: the min and max of the phase of ``x``.
        ``x_est.png``: the phase of ``x``, squeezed and min-max rescaled to
            8 bits.
        ``ft_error.png``: the relative spectral error
            ``|F(ref) - F(x)| / (|F(ref)| + 1e-6)`` over the last two axes,
            centred with ``fftshift``, clipped at 1, squeezed and rescaled
            to 8 bits.

    Args:
        x: Estimate tensor (presumably a complex field — TODO confirm).
        ref: Reference tensor, assumed to have the same shape as ``x``.
        path: Output directory (``str`` or ``Path``), created if missing.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Detach so that `.numpy()` below also works for tensors requiring grad.
    x_phase = th.angle(x.detach())
    ref_phase = th.angle(ref.detach())

    metrics = np.array(
        [
            snr(ref_phase, x_phase).cpu().numpy(),
            rmse(x_phase, ref_phase).cpu().numpy(),
        ]
    )[None]
    np.savetxt(path / "metrics.csv", metrics, delimiter=",", fmt="%.2f")

    ref_ft = th.fft.fft2(ref_phase)
    x_ft = th.fft.fft2(x_phase)
    # The epsilon goes outside the complex magnitude. With |F + eps|, the
    # denominator is exactly 0 wherever F == -eps, which produced NaNs that
    # poisoned the whole image. Only the two FFT axes are shifted.
    ft_error = th.fft.fftshift(
        th.abs(ref_ft - x_ft) / (th.abs(ref_ft) + 1e-6), dim=(-2, -1)
    ).clamp_max(1)

    span = np.array([x_phase.min().item(), x_phase.max().item()])[None]
    np.savetxt(path / "span.csv", span, delimiter=",", fmt="%.1f")

    imageio.imwrite(
        path / "x_est.png", _to_uint8_image(x_phase.cpu().numpy().squeeze())
    )
    imageio.imwrite(
        path / "ft_error.png", _to_uint8_image(ft_error.cpu().numpy().squeeze())
    )