Skip to content

Commit

Permalink
Fixed code formatting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Sep 27, 2024
1 parent 7ddd117 commit 5dfb07b
Showing 1 changed file with 37 additions and 37 deletions.
74 changes: 37 additions & 37 deletions tests/ignite/metrics/test_hsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,54 +120,54 @@ def test_accumulator_detached():

@pytest.mark.usefixtures("distributed")
class TestDistributed:
@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
def test_integration(self, sigma_x: float, sigma_y: float):
tol = 2e-5
n_iters = 100
batch_size = 20
n_dims_x = 100
n_dims_y = 50
@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
def test_integration(self, sigma_x: float, sigma_y: float):
tol = 2e-5
n_iters = 100
batch_size = 20
n_dims_x = 100
n_dims_y = 50

rank = idist.get_rank()
torch.manual_seed(12 + rank)

rank = idist.get_rank()
torch.manual_seed(12 + rank)

device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)
device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)

for metric_device in metric_devices:
x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device)
for metric_device in metric_devices:
x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device)

lin = nn.Linear(n_dims_x, n_dims_y).to(device)
y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4
lin = nn.Linear(n_dims_x, n_dims_y).to(device)
y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4

def data_loader(i, input_x, input_y):
return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size]
def data_loader(i, input_x, input_y):
return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size]

engine = Engine(lambda e, i: data_loader(i, x, y))
engine = Engine(lambda e, i: data_loader(i, x, y))

m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device)
m.attach(engine, "hsic")
m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device)
m.attach(engine, "hsic")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)
data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

assert "hsic" in engine.state.metrics
res = engine.state.metrics["hsic"]
assert "hsic" in engine.state.metrics
res = engine.state.metrics["hsic"]

x = idist.all_gather(x)
y = idist.all_gather(y)
total_n_iters = idist.all_reduce(n_iters)
x = idist.all_gather(x)
y = idist.all_gather(y)
total_n_iters = idist.all_reduce(n_iters)

np_res = 0.0
for i in range(total_n_iters):
x_batch, y_batch = data_loader(i, x, y)
np_res += np_hsic(x_batch, y_batch, sigma_x, sigma_y)
np_res = 0.0
for i in range(total_n_iters):
x_batch, y_batch = data_loader(i, x, y)
np_res += np_hsic(x_batch, y_batch, sigma_x, sigma_y)

expected_hsic = np_res / total_n_iters
assert pytest.approx(expected_hsic, abs=tol) == res
expected_hsic = np_res / total_n_iters
assert pytest.approx(expected_hsic, abs=tol) == res

def test_accumulator_device(self):
device = idist.device()
Expand Down

0 comments on commit 5dfb07b

Please sign in to comment.