diff --git a/docs/pytorch_lightning.html b/docs/pytorch_lightning.html deleted file mode 100644 index 5988176..0000000 --- a/docs/pytorch_lightning.html +++ /dev/null @@ -1,16092 +0,0 @@ ---- - -title: Title - - -keywords: fastai -sidebar: home_sidebar - - - -nb_path: "nbs/04d_pytorch_lightning.ipynb" ---- - - -
- - {% raw %} - -
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torchmetrics.functional import accuracy

# Seed the run (the notebook output reports "Global seed set to 7").
seed_everything(7)

# Dataset root; overridable through the PATH_DATASETS environment variable.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
# Use at most one GPU, and none when CUDA reports no devices.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Larger batch when a GPU is available, smaller one otherwise.
BATCH_SIZE = 256 if AVAIL_GPUS else 64
# os.cpu_count() returns None when the count cannot be determined; treat that
# as a single CPU instead of raising TypeError on the division.
NUM_WORKERS = (os.cpu_count() or 1) // 2
-
- -
-
-
- -
-
- -
- -
-
Global seed set to 7
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
# Evaluation pipeline: tensor conversion followed by CIFAR-10 normalization.
test_transforms = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), cifar10_normalization()]
)

# Training pipeline: random 32-pixel crop (padding=4) and random horizontal
# flip ahead of the same conversion + normalization steps.
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

# Validation reuses the evaluation pipeline.
cifar10_dm = CIFAR10DataModule(
    data_dir=PATH_DATASETS,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
    train_transforms=train_transforms,
    val_transforms=test_transforms,
    test_transforms=test_transforms,
)
-
- -
-
-
- -
-
- -
- -
-
/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py:74: LightningDeprecationWarning: DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
-  "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
-/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py:78: LightningDeprecationWarning: DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.
-  "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
-/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py:82: LightningDeprecationWarning: DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7.
-  "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
def create_model():
    """Build a torchvision ResNet-18 with 10 output classes and a modified stem.

    The stock first convolution is swapped for a 3x3, stride-1, padding-1
    convolution without bias, and the max-pool layer is replaced by an
    identity module.

    Returns:
        The modified (non-pretrained) ResNet-18 module.
    """
    net = torchvision.models.resnet18(pretrained=False, num_classes=10)
    net.maxpool = nn.Identity()
    net.conv1 = nn.Conv2d(
        3,
        64,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        bias=False,
    )
    return net
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class LitResnet(LightningModule):
    """LightningModule that trains the network returned by ``create_model``.

    The forward pass produces log-probabilities, the loss is NLL, and
    optimization uses SGD with a per-step OneCycleLR schedule.

    Args:
        lr: Learning rate handed to the SGD optimizer.
    """

    def __init__(self, lr=0.05):
        super().__init__()

        # Stores ``lr`` so it can be read back as ``self.hparams.lr``.
        self.save_hyperparameters()
        self.model = create_model()

    def forward(self, x):
        """Return log-softmax scores (over dim 1) of the wrapped model for ``x``."""
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the NLL loss for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for ``batch`` and log them when ``stage`` is set.

        Args:
            batch: An ``(inputs, targets)`` pair.
            stage: Prefix for the logged metric names (``"val"`` or ``"test"``
                in this class); nothing is logged when it is falsy.
        """
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` for one validation batch."""
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` for one test batch."""
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Return the SGD optimizer and a step-wise OneCycleLR scheduler.

        The schedule length is taken from the trainer's own estimate of the
        number of optimizer steps when the installed Lightning exposes
        ``estimated_stepping_batches``. Otherwise it falls back to the
        original estimate, which assumes 45000 training samples.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        total_steps = getattr(self.trainer, "estimated_stepping_batches", None)
        if total_steps is None:
            # Fallback for Lightning versions without the attribute: the
            # hard-coded 45000 is only right for a 45000-sample training split.
            total_steps = self.trainer.max_epochs * (45000 // BATCH_SIZE)
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                0.1,
                total_steps=total_steps,
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
#from fastai.vision.all import *
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
import fastai
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
# ``fastai.call`` is a truncated module path; the callback class defined in a
# later cell subclasses ``fastai.callback.all.Callback``, so import that module.
import fastai.callback.all
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCallback(fastai.callback.all.Callback):
    """fastai callback that prunes a model on a schedule while it trains.

    Sparsity moves from ``start_sparsity`` to ``end_sparsity`` between
    ``start_epoch`` and ``end_epoch``, following ``sched_func``. The pruning
    itself is delegated to a ``Sparsifier`` built in ``before_fit``.

    Args:
        end_sparsity: Target sparsity (printed as a percentage).
        granularity, method, criteria: Forwarded unchanged to ``Sparsifier``.
        sched_func: Callable invoked as ``sched_func(start=, end=, pos=)`` that
            returns the sparsity for the current position in the schedule.
        start_sparsity: Sparsity at the beginning of the schedule.
        start_epoch: First epoch in which pruning is applied.
        end_epoch: Epoch at which the schedule stops advancing; defaults to
            the total number of epochs.
        lth: When true, weights are reset to their saved values every time
            the sparsity level changes.
        rewind_epoch: Epoch at whose start the weights are saved.
        reset_end: When true, weights are reset once more after training.
        model: Optional model to prune instead of ``learn.model``.
        round_to: Forwarded to ``Sparsifier.prune_model``.
        layer_type: Forwarded to ``Sparsifier``.
    """

    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, layer_type=nn.Conv2d):
        self.end_sparsity = end_sparsity
        self.granularity, self.method, self.criteria, self.sched_func = granularity, method, criteria, sched_func
        self.start_sparsity, self.start_epoch, self.end_epoch = start_sparsity, start_epoch, end_epoch
        self.lth, self.rewind_epoch, self.reset_end = lth, rewind_epoch, reset_end
        self.model = model
        self.round_to = round_to
        self.layer_type = layer_type
        # ``train_iter`` is intentionally NOT assigned here. An instance
        # attribute would shadow the Learner's running iteration counter
        # (reached through the same attribute delegation this class already
        # relies on for ``epoch`` and ``n_epoch``) and, since nothing in this
        # class increments it, would pin the schedule position at its start.
        self.current_sparsity, self.previous_sparsity = 0, 0
        assert self.start_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
        print("Starting to init trainer!")

    def before_fit(self):
        """Resolve the pruning window, build the sparsifier and size the schedule."""
        print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
        self.end_epoch = self.n_epoch if self.end_epoch is None else self.end_epoch
        assert self.end_epoch <= self.n_epoch, 'Your end_epoch must be smaller than total number of epoch'

        model = self.learn.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
        self.sparsifier = Sparsifier(model, self.granularity, self.method, self.criteria, self.layer_type)
        # Number of full batches per epoch, used to express epochs in iterations.
        self.n_batches = math.floor(len(self.learn.dls.dataset)/self.learn.dls.bs)
        self.total_iters = self.end_epoch * self.n_batches
        self.start_iter = self.start_epoch * self.n_batches

    def before_epoch(self):
        """Save the weights at the start of the rewind epoch."""
        if self.epoch == self.rewind_epoch:
            print(f'Saving Weights at epoch {self.epoch}')
            self.sparsifier._save_weights()

    def before_batch(self):
        """Update the sparsity level (while inside the schedule) and prune."""
        if self.epoch>=self.start_epoch:
            # Past ``end_epoch`` the sparsity is frozen but pruning is still applied.
            if self.epoch < self.end_epoch: self._set_sparsity()
            self.sparsifier.prune_model(self.current_sparsity, self.round_to)

            if self.lth and self.current_sparsity!=self.previous_sparsity: # If sparsity has changed, the network has been pruned
                    print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
                    self.sparsifier._reset_weights()

            self.previous_sparsity = self.current_sparsity

    def before_step(self):
        """Call the sparsifier's gradient masking once pruning has started."""
        if self.epoch>=self.start_epoch:
            self.sparsifier._mask_grad()

    def after_epoch(self):
        """Report the sparsity reached at the end of the epoch."""
        print(f'Sparsity at the end of epoch {self.epoch}: {self.current_sparsity:.2f}%')

    def after_fit(self):
        """Report final sparsity, optionally reset weights, and clean up."""
        print(f'Final Sparsity: {self.current_sparsity:.2f}')
        if self.reset_end:
            self.sparsifier._reset_weights()
        self.sparsifier._clean_buffers() # Remove buffers at the end of training
        self.sparsifier.print_sparsity()

    def _set_sparsity(self):
        """Set ``current_sparsity`` from ``sched_func`` at the current schedule position."""
        self.current_sparsity = self.sched_func(start=self.start_sparsity, end=self.end_sparsity, pos=(self.train_iter-self.start_iter)/(self.total_iters-self.start_iter))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
# fastai flavour of the callback: 50% target sparsity at filter granularity.
sp_cb = SparsifyCallback(
    end_sparsity=50,
    granularity='filter',
    method='local',
    criteria=large_final,
    sched_func=sched_onecycle,
)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
# Lightning flavour of the callback: 50% target sparsity at filter granularity.
sp_cb = SparsifyCallbackFlash(
    end_sparsity=50,
    granularity='filter',
    method='local',
    criteria=large_final,
    sched_func=sched_onecycle,
)
-
- -
-
-
- -
-
- -
- -
-
Starting to init trainer!
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
from fasterai.sparse.all import *
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
from fastcore.basics import store_attr
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
store_attr??
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
import math

import pytorch_lightning
import torch.nn as nn


class SparsifyCallbackFlash(pytorch_lightning.callbacks.Callback):
    """PyTorch Lightning port of the scheduled-pruning callback.

    Sparsity moves from ``start_sparsity`` to ``end_sparsity`` between
    ``start_epoch`` and ``end_epoch``, following ``sched_func``. The pruning
    itself is delegated to a ``Sparsifier`` built in ``on_fit_start``.

    Args:
        end_sparsity: Target sparsity (printed as a percentage).
        granularity, method, criteria: Forwarded unchanged to ``Sparsifier``.
        sched_func: Callable invoked as ``sched_func(start=, end=, pos=)`` that
            returns the sparsity for the current position in the schedule.
        start_sparsity: Sparsity at the beginning of the schedule.
        start_epoch: First epoch in which pruning is applied.
        end_epoch: Epoch at which the schedule stops advancing; defaults to
            ``trainer.max_epochs``.
        lth: When true, weights are reset to their saved values every time
            the sparsity level changes.
        rewind_epoch: Epoch at whose start the weights are saved.
        reset_end: When true, weights are reset once more after training.
        model: Optional model to prune instead of ``trainer.model``.
        round_to: Forwarded to ``Sparsifier.prune_model``.
        layer_type: Forwarded to ``Sparsifier``.
    """

    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, layer_type=nn.Conv2d):
        store_attr('end_sparsity, granularity, method, criteria, sched_func, start_sparsity, start_epoch, end_epoch, lth, rewind_epoch, reset_end, model, round_to, layer_type')
        # Iteration counter maintained by this callback (see on_batch_start).
        self.train_iter = 0
        self.current_sparsity, self.previous_sparsity = 0, 0
        assert self.start_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
        print("Starting to init trainer!")

    def on_fit_start(self, trainer, pl_module):
        """Resolve the pruning window, build the sparsifier and size the schedule."""
        print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
        self.end_epoch = trainer.max_epochs if self.end_epoch is None else self.end_epoch
        assert self.end_epoch <= trainer.max_epochs, 'Your end_epoch must be smaller than total number of epoch'

        model = trainer.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
        self.sparsifier = Sparsifier(model, self.granularity, self.method, self.criteria, self.layer_type)
        # Number of full batches per epoch, used to express epochs in iterations.
        self.n_batches = math.floor(len(trainer.datamodule.dataset_train)/trainer.datamodule.batch_size)
        self.total_iters = self.end_epoch * self.n_batches
        self.start_iter = self.start_epoch * self.n_batches

    def on_fit_end(self, trainer, pl_module):
        """Report final sparsity, optionally reset weights, and clean up."""
        print(f'Final Sparsity: {self.current_sparsity:.2f}')
        if self.reset_end:
            self.sparsifier._reset_weights()
        self.sparsifier._clean_buffers() # Remove buffers at the end of training
        self.sparsifier.print_sparsity()

    def on_train_epoch_start(self, trainer, pl_module):
        """Save the weights at the start of the rewind epoch."""
        if trainer.current_epoch == self.rewind_epoch:
            print(f'Saving Weights at epoch {trainer.current_epoch}')
            self.sparsifier._save_weights()

    def on_train_epoch_end(self, trainer, pl_module):
        """Report the sparsity reached at the end of the epoch."""
        print(f'Sparsity at the end of epoch {trainer.current_epoch}: {self.current_sparsity:.2f}%')

    # NOTE(review): ``on_batch_start`` / ``on_batch_end`` are legacy Lightning
    # hook names; check they still exist in the Lightning version in use.
    def on_batch_start(self, trainer, pl_module):
        """Advance the iteration counter, update the sparsity level and prune."""
        self.train_iter+=1
        if trainer.current_epoch>=self.start_epoch:
            # Past ``end_epoch`` the sparsity is frozen but pruning is still applied.
            if trainer.current_epoch < self.end_epoch: self._set_sparsity()
            self.sparsifier.prune_model(self.current_sparsity, self.round_to)

            if self.lth and self.current_sparsity!=self.previous_sparsity: # If sparsity has changed, the network has been pruned
                    print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
                    self.sparsifier._reset_weights()

            self.previous_sparsity = self.current_sparsity

    def on_after_backward(self, trainer, pl_module):
        """Call the sparsifier's gradient masking once pruning has started."""
        if trainer.current_epoch>=self.start_epoch:
            self.sparsifier._mask_grad()

    # The three hooks below are debugging probes: they print per-filter weight
    # sums of the first convolution. They read the module through ``pl_module``
    # rather than a notebook-level global, and assume it exposes
    # ``.model.conv1`` (true for LitResnet; other modules would need a guard).
    def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """Debug print of the first-conv weight sums after the batch."""
        print('After Batch', pl_module.model.conv1.weight.sum(dim=(1,2,3)))

    def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx):
        """Debug print of the first-conv weight sums before the optimizer step."""
        print('Before Step', pl_module.model.conv1.weight.sum(dim=(1,2,3)))

    def on_before_zero_grad(self, trainer, pl_module, optimizer):
        """Debug print of the first-conv weight sums before gradients are zeroed."""
        print('After Step', pl_module.model.conv1.weight.sum(dim=(1,2,3)))

    def _set_sparsity(self):
        """Set ``current_sparsity`` from ``sched_func`` at the current schedule position."""
        self.current_sparsity = self.sched_func(start=self.start_sparsity, end=self.end_sparsity, pos=(self.train_iter-self.start_iter)/(self.total_iters-self.start_iter))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
# Re-create the Lightning callback after the class definition was re-run.
sp_cb = SparsifyCallbackFlash(
    end_sparsity=50,
    granularity='filter',
    method='local',
    criteria=large_final,
    sched_func=sched_onecycle,
)
-
- -
-
-
- -
-
- -
- -
-
Starting to init trainer!
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
# Module under training, with the CIFAR-10 datamodule attached to it.
model = LitResnet(lr=0.05)
model.datamodule = cifar10_dm

# Three-epoch run with TensorBoard logging, per-step learning-rate monitoring
# and the sparsifying callback.
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=3,
    progress_bar_refresh_rate=10,
    callbacks=[LearningRateMonitor(logging_interval="step"), sp_cb],
    logger=TensorBoardLogger("lightning_logs/", name="resnet"),
)

trainer.fit(model, cifar10_dm)
#trainer.test(model, datamodule=cifar10_dm)
-
- -
-
-
- -
-
- -
- -
-
GPU available: True, used: True
-TPU available: False, using: 0 TPU cores
-IPU available: False, using: 0 IPUs
-LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
-
-  | Name  | Type   | Params
----------------------------------
-0 | model | ResNet | 11.2 M
----------------------------------
-11.2 M    Trainable params
-0         Non-trainable params
-11.2 M    Total params
-44.696    Total estimated model params size (MB)
-
-
-
- -
- -
-
Pruning of filter until a sparsity of 50%
-
-
-
- -
- -
-
Global seed set to 7
-
-
-
- -
- -
-
Saving Weights at epoch 0
-After Step tensor([ 0.9081,  0.6772, -0.3952,  0.2494,  0.1716,  0.0284, -0.5729, -1.1187,
-        -0.6003, -0.5951, -0.2932, -0.2708,  0.0602,  0.3945, -0.0567, -0.4998,
-        -0.7118,  0.6990, -0.8625, -0.6608, -0.3355, -0.0327, -0.6007, -0.2280,
-        -0.0635,  0.3191,  0.1373,  0.1476,  0.1205,  1.4984, -0.9119, -0.0762,
-        -0.7782, -0.1663, -0.9118, -0.2271,  0.5274,  0.0000,  0.1409,  0.5117,
-         0.3650, -0.2874, -0.0246,  0.7153,  0.5553,  0.5945,  0.1937,  0.4063,
-         0.5245, -0.5398,  0.5487, -0.1477,  0.4366,  0.2581,  0.7108,  0.5323,
-        -0.1485,  1.1084, -0.1958, -0.6098,  1.0361, -0.1469, -0.2773, -0.2078],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9081,  0.6772, -0.3952,  0.2494,  0.1716,  0.0284, -0.5729, -1.1187,
-        -0.6003, -0.5951, -0.2932, -0.2708,  0.0602,  0.3945, -0.0567, -0.4998,
-        -0.7118,  0.6990, -0.8625, -0.6608, -0.3355, -0.0327, -0.6007, -0.2280,
-        -0.0635,  0.3191,  0.1373,  0.1476,  0.1205,  1.4984, -0.9119, -0.0762,
-        -0.7782, -0.1663, -0.9118, -0.2271,  0.5274,  0.0000,  0.1409,  0.5117,
-         0.3650, -0.2874, -0.0246,  0.7153,  0.5553,  0.5945,  0.1937,  0.4063,
-         0.5245, -0.5398,  0.5487, -0.1477,  0.4366,  0.2581,  0.7108,  0.5323,
-        -0.1485,  1.1084, -0.1958, -0.6098,  1.0361, -0.1469, -0.2773, -0.2078],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9077,  0.6766, -0.3968,  0.2510,  0.1723,  0.0253, -0.5725, -1.1187,
-        -0.5994, -0.5963, -0.2932, -0.2697,  0.0583,  0.3911, -0.0608, -0.5002,
-        -0.7128,  0.6987, -0.8623, -0.6603, -0.3353, -0.0314, -0.6016, -0.2284,
-        -0.0652,  0.3181,  0.1375,  0.1443,  0.1189,  1.4984, -0.9123, -0.0740,
-        -0.7780, -0.1686, -0.9115, -0.2282,  0.5280,  0.0000,  0.1393,  0.5112,
-         0.3673, -0.2886, -0.0255,  0.7156,  0.5556,  0.5940,  0.1920,  0.4071,
-         0.5256, -0.5393,  0.5502, -0.1461,  0.4368,  0.2590,  0.7105,  0.5324,
-        -0.1518,  1.1084, -0.1961, -0.6084,  1.0361, -0.1472, -0.2784, -0.2083],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9077,  0.6766, -0.3968,  0.2510,  0.1723,  0.0253, -0.5725, -1.1187,
-        -0.5994, -0.5963, -0.2932, -0.2697,  0.0583,  0.3911, -0.0608, -0.5002,
-        -0.7128,  0.6987, -0.8623, -0.6603, -0.3353, -0.0314, -0.6016, -0.2284,
-        -0.0652,  0.3181,  0.1375,  0.1443,  0.1189,  1.4984, -0.9123, -0.0740,
-        -0.7780, -0.1686, -0.9115, -0.2282,  0.5280,  0.0000,  0.1393,  0.5112,
-         0.3673, -0.2886, -0.0255,  0.7156,  0.5556,  0.5940,  0.1920,  0.4071,
-         0.5256, -0.5393,  0.5502, -0.1461,  0.4368,  0.2590,  0.7105,  0.5324,
-        -0.1518,  1.1084, -0.1961, -0.6084,  1.0361, -0.1472, -0.2784, -0.2083],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9077,  0.6766, -0.3968,  0.2510,  0.1723,  0.0253, -0.5725, -1.1187,
-        -0.5994, -0.5963, -0.2932, -0.2697,  0.0583,  0.3911, -0.0608, -0.5002,
-        -0.7128,  0.6987, -0.8623, -0.6603, -0.3353, -0.0314, -0.6016, -0.2284,
-        -0.0652,  0.3181,  0.1375,  0.1443,  0.1189,  1.4984, -0.9123, -0.0740,
-        -0.7780, -0.1686, -0.9115, -0.2282,  0.5280,  0.0000,  0.1393,  0.5112,
-         0.3673, -0.2886, -0.0255,  0.7156,  0.5556,  0.5940,  0.1920,  0.4071,
-         0.5256, -0.5393,  0.5502, -0.1461,  0.4368,  0.2590,  0.7105,  0.5324,
-        -0.1518,  1.1084, -0.1961, -0.6084,  1.0361, -0.1472, -0.2784, -0.2083],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9067,  0.6768, -0.3994,  0.2519,  0.1738,  0.0215, -0.5718, -1.1184,
-        -0.5988, -0.5976, -0.2888, -0.2670,  0.0592,  0.3892, -0.0651, -0.5004,
-        -0.7149,  0.6978, -0.8622, -0.6604, -0.3347, -0.0293, -0.6020, -0.2296,
-        -0.0675,  0.3181,  0.1361,  0.1391,  0.1179,  1.4982, -0.9125, -0.0722,
-        -0.7785, -0.1713, -0.9110, -0.2300,  0.5292,  0.0000,  0.1344,  0.5093,
-         0.3678, -0.2896, -0.0236,  0.7161,  0.5565,  0.5941,  0.1913,  0.4073,
-         0.5262, -0.5394,  0.5518, -0.1395,  0.4371,  0.2598,  0.7110,  0.5323,
-        -0.1547,  1.1088, -0.1971, -0.6074,  1.0365, -0.1461, -0.2777, -0.2076],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9067,  0.6768, -0.3994,  0.2519,  0.1738,  0.0215, -0.5718, -1.1184,
-        -0.5988, -0.5976, -0.2888, -0.2670,  0.0592,  0.3892, -0.0651, -0.5004,
-        -0.7149,  0.6978, -0.8622, -0.6604, -0.3347, -0.0293, -0.6020, -0.2296,
-        -0.0675,  0.3181,  0.1361,  0.1391,  0.1179,  1.4982, -0.9125, -0.0722,
-        -0.7785, -0.1713, -0.9110, -0.2300,  0.5292,  0.0000,  0.1344,  0.5093,
-         0.3678, -0.2896, -0.0236,  0.7161,  0.5565,  0.5941,  0.1913,  0.4073,
-         0.5262, -0.5394,  0.5518, -0.1395,  0.4371,  0.2598,  0.7110,  0.5323,
-        -0.1547,  1.1088, -0.1971, -0.6074,  1.0365, -0.1461, -0.2777, -0.2076],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9067,  0.6768, -0.3994,  0.2519,  0.1738,  0.0215, -0.5718, -1.1184,
-        -0.5988, -0.5976, -0.2888, -0.2670,  0.0592,  0.3892, -0.0651, -0.5004,
-        -0.7149,  0.6978, -0.8622, -0.6604, -0.3347, -0.0293, -0.6020, -0.2296,
-        -0.0675,  0.3181,  0.1361,  0.1391,  0.1179,  1.4982, -0.9125, -0.0722,
-        -0.7785, -0.1713, -0.9110, -0.2300,  0.5292,  0.0000,  0.1344,  0.5093,
-         0.3678, -0.2896, -0.0236,  0.7161,  0.5565,  0.5941,  0.1913,  0.4073,
-         0.5262, -0.5394,  0.5518, -0.1395,  0.4371,  0.2598,  0.7110,  0.5323,
-        -0.1547,  1.1088, -0.1971, -0.6074,  1.0365, -0.1461, -0.2777, -0.2076],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9054,  0.6777, -0.4027,  0.2506,  0.1768,  0.0169, -0.5707, -1.1182,
-        -0.5983, -0.5986, -0.2849, -0.2662,  0.0603,  0.3884, -0.0682, -0.5005,
-        -0.7170,  0.6972, -0.8620, -0.6612, -0.3324, -0.0273, -0.6027, -0.2344,
-        -0.0713,  0.3180,  0.1321,  0.1368,  0.1171,  1.4979, -0.9123, -0.0706,
-        -0.7794, -0.1755, -0.9104, -0.2303,  0.5306,  0.0000,  0.1321,  0.5081,
-         0.3678, -0.2889, -0.0219,  0.7161,  0.5574,  0.5945,  0.1923,  0.4087,
-         0.5273, -0.5389,  0.5536, -0.1342,  0.4376,  0.2593,  0.7122,  0.5322,
-        -0.1581,  1.1091, -0.1973, -0.6073,  1.0371, -0.1443, -0.2750, -0.2057],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9054,  0.6777, -0.4027,  0.2506,  0.1768,  0.0169, -0.5707, -1.1182,
-        -0.5983, -0.5986, -0.2849, -0.2662,  0.0603,  0.3884, -0.0682, -0.5005,
-        -0.7170,  0.6972, -0.8620, -0.6612, -0.3324, -0.0273, -0.6027, -0.2344,
-        -0.0713,  0.3180,  0.1321,  0.1368,  0.1171,  1.4979, -0.9123, -0.0706,
-        -0.7794, -0.1755, -0.9104, -0.2303,  0.5306,  0.0000,  0.1321,  0.5081,
-         0.3678, -0.2889, -0.0219,  0.7161,  0.5574,  0.5945,  0.1923,  0.4087,
-         0.5273, -0.5389,  0.5536, -0.1342,  0.4376,  0.2593,  0.7122,  0.5322,
-        -0.1581,  1.1091, -0.1973, -0.6073,  1.0371, -0.1443, -0.2750, -0.2057],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9054,  0.6777, -0.4027,  0.2506,  0.1768,  0.0169, -0.5707, -1.1182,
-        -0.5983, -0.5986, -0.2849, -0.2662,  0.0603,  0.3884, -0.0682, -0.5005,
-        -0.7170,  0.6972, -0.8620, -0.6612, -0.3324, -0.0273, -0.6027, -0.2344,
-        -0.0713,  0.3180,  0.1321,  0.1368,  0.1171,  1.4979, -0.9123, -0.0706,
-        -0.7794, -0.1755, -0.9104, -0.2303,  0.5306,  0.0000,  0.1321,  0.5081,
-         0.3678, -0.2889, -0.0219,  0.7161,  0.5574,  0.5945,  0.1923,  0.4087,
-         0.5273, -0.5389,  0.5536, -0.1342,  0.4376,  0.2593,  0.7122,  0.5322,
-        -0.1581,  1.1091, -0.1973, -0.6073,  1.0371, -0.1443, -0.2750, -0.2057],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9047,  0.6789, -0.4069,  0.2467,  0.1825,  0.0134, -0.5705, -1.1180,
-        -0.5983, -0.5991, -0.2800, -0.2661,  0.0600,  0.3891, -0.0648, -0.5013,
-        -0.7192,  0.6964, -0.8612, -0.6632, -0.3316, -0.0273, -0.6032, -0.2386,
-        -0.0747,  0.3168,  0.1275,  0.1369,  0.1147,  1.4978, -0.9122, -0.0674,
-        -0.7799, -0.1790, -0.9100, -0.2305,  0.5335,  0.0000,  0.1285,  0.5066,
-         0.3676, -0.2888, -0.0226,  0.7160,  0.5588,  0.5953,  0.1949,  0.4107,
-         0.5263, -0.5385,  0.5556, -0.1313,  0.4397,  0.2578,  0.7134,  0.5325,
-        -0.1588,  1.1093, -0.1960, -0.6065,  1.0374, -0.1432, -0.2723, -0.2028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9047,  0.6789, -0.4069,  0.2467,  0.1825,  0.0134, -0.5705, -1.1180,
-        -0.5983, -0.5991, -0.2800, -0.2661,  0.0600,  0.3891, -0.0648, -0.5013,
-        -0.7192,  0.6964, -0.8612, -0.6632, -0.3316, -0.0273, -0.6032, -0.2386,
-        -0.0747,  0.3168,  0.1275,  0.1369,  0.1147,  1.4978, -0.9122, -0.0674,
-        -0.7799, -0.1790, -0.9100, -0.2305,  0.5335,  0.0000,  0.1285,  0.5066,
-         0.3676, -0.2888, -0.0226,  0.7160,  0.5588,  0.5953,  0.1949,  0.4107,
-         0.5263, -0.5385,  0.5556, -0.1313,  0.4397,  0.2578,  0.7134,  0.5325,
-        -0.1588,  1.1093, -0.1960, -0.6065,  1.0374, -0.1432, -0.2723, -0.2028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9047,  0.6789, -0.4069,  0.2467,  0.1825,  0.0134, -0.5705, -1.1180,
-        -0.5983, -0.5991, -0.2800, -0.2661,  0.0600,  0.3891, -0.0648, -0.5013,
-        -0.7192,  0.6964, -0.8612, -0.6632, -0.3316, -0.0273, -0.6032, -0.2386,
-        -0.0747,  0.3168,  0.1275,  0.1369,  0.1147,  1.4978, -0.9122, -0.0674,
-        -0.7799, -0.1790, -0.9100, -0.2305,  0.5335,  0.0000,  0.1285,  0.5066,
-         0.3676, -0.2888, -0.0226,  0.7160,  0.5588,  0.5953,  0.1949,  0.4107,
-         0.5263, -0.5385,  0.5556, -0.1313,  0.4397,  0.2578,  0.7134,  0.5325,
-        -0.1588,  1.1093, -0.1960, -0.6065,  1.0374, -0.1432, -0.2723, -0.2028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9039,  0.6800, -0.4124,  0.2452,  0.1900,  0.0067, -0.5710, -1.1178,
-        -0.5980, -0.5994, -0.2733, -0.2691,  0.0571,  0.3893, -0.0627, -0.5025,
-        -0.7218,  0.6962, -0.8608, -0.6641, -0.3306, -0.0264, -0.6032, -0.2454,
-        -0.0787,  0.3152,  0.1191,  0.1348,  0.1121,  1.4975, -0.9119, -0.0649,
-        -0.7805, -0.1829, -0.9098, -0.2318,  0.5354,  0.0000,  0.1287,  0.5045,
-         0.3671, -0.2912, -0.0235,  0.7155,  0.5597,  0.5969,  0.1956,  0.4132,
-         0.5252, -0.5378,  0.5577, -0.1288,  0.4389,  0.2536,  0.7148,  0.5332,
-        -0.1587,  1.1094, -0.1920, -0.6054,  1.0375, -0.1405, -0.2689, -0.1993],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9039,  0.6800, -0.4124,  0.2452,  0.1900,  0.0067, -0.5710, -1.1178,
-        -0.5980, -0.5994, -0.2733, -0.2691,  0.0571,  0.3893, -0.0627, -0.5025,
-        -0.7218,  0.6962, -0.8608, -0.6641, -0.3306, -0.0264, -0.6032, -0.2454,
-        -0.0787,  0.3152,  0.1191,  0.1348,  0.1121,  1.4975, -0.9119, -0.0649,
-        -0.7805, -0.1829, -0.9098, -0.2318,  0.5354,  0.0000,  0.1287,  0.5045,
-         0.3671, -0.2912, -0.0235,  0.7155,  0.5597,  0.5969,  0.1956,  0.4132,
-         0.5252, -0.5378,  0.5577, -0.1288,  0.4389,  0.2536,  0.7148,  0.5332,
-        -0.1587,  1.1094, -0.1920, -0.6054,  1.0375, -0.1405, -0.2689, -0.1993],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9039,  0.6800, -0.4124,  0.2452,  0.1900,  0.0067, -0.5710, -1.1178,
-        -0.5980, -0.5994, -0.2733, -0.2691,  0.0571,  0.3893, -0.0627, -0.5025,
-        -0.7218,  0.6962, -0.8608, -0.6641, -0.3306, -0.0264, -0.6032, -0.2454,
-        -0.0787,  0.3152,  0.1191,  0.1348,  0.1121,  1.4975, -0.9119, -0.0649,
-        -0.7805, -0.1829, -0.9098, -0.2318,  0.5354,  0.0000,  0.1287,  0.5045,
-         0.3671, -0.2912, -0.0235,  0.7155,  0.5597,  0.5969,  0.1956,  0.4132,
-         0.5252, -0.5378,  0.5577, -0.1288,  0.4389,  0.2536,  0.7148,  0.5332,
-        -0.1587,  1.1094, -0.1920, -0.6054,  1.0375, -0.1405, -0.2689, -0.1993],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9035,  0.6803, -0.4170,  0.2432,  0.1989,  0.0034, -0.5718, -1.1175,
-        -0.5966, -0.5992, -0.2676, -0.2730,  0.0548,  0.3880, -0.0651, -0.5043,
-        -0.7242,  0.6950, -0.8603, -0.6645, -0.3307, -0.0262, -0.6033, -0.2486,
-        -0.0795,  0.3143,  0.1072,  0.1324,  0.1107,  1.4973, -0.9114, -0.0671,
-        -0.7813, -0.1838, -0.9096, -0.2323,  0.5361,  0.0000,  0.1296,  0.5025,
-         0.3656, -0.2931, -0.0241,  0.7144,  0.5599,  0.5982,  0.1948,  0.4162,
-         0.5250, -0.5372,  0.5586, -0.1258,  0.4399,  0.2505,  0.7168,  0.5335,
-        -0.1593,  1.1096, -0.1864, -0.6047,  1.0376, -0.1376, -0.2682, -0.1942],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9035,  0.6803, -0.4170,  0.2432,  0.1989,  0.0034, -0.5718, -1.1175,
-        -0.5966, -0.5992, -0.2676, -0.2730,  0.0548,  0.3880, -0.0651, -0.5043,
-        -0.7242,  0.6950, -0.8603, -0.6645, -0.3307, -0.0262, -0.6033, -0.2486,
-        -0.0795,  0.3143,  0.1072,  0.1324,  0.1107,  1.4973, -0.9114, -0.0671,
-        -0.7813, -0.1838, -0.9096, -0.2323,  0.5361,  0.0000,  0.1296,  0.5025,
-         0.3656, -0.2931, -0.0241,  0.7144,  0.5599,  0.5982,  0.1948,  0.4162,
-         0.5250, -0.5372,  0.5586, -0.1258,  0.4399,  0.2505,  0.7168,  0.5335,
-        -0.1593,  1.1096, -0.1864, -0.6047,  1.0376, -0.1376, -0.2682, -0.1942],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9035,  0.6803, -0.4170,  0.2432,  0.1989,  0.0034, -0.5718, -1.1175,
-        -0.5966, -0.5992, -0.2676, -0.2730,  0.0548,  0.3880, -0.0651, -0.5043,
-        -0.7242,  0.6950, -0.8603, -0.6645, -0.3307, -0.0262, -0.6033, -0.2486,
-        -0.0795,  0.3143,  0.1072,  0.1324,  0.1107,  1.4973, -0.9114, -0.0671,
-        -0.7813, -0.1838, -0.9096, -0.2323,  0.5361,  0.0000,  0.1296,  0.5025,
-         0.3656, -0.2931, -0.0241,  0.7144,  0.5599,  0.5982,  0.1948,  0.4162,
-         0.5250, -0.5372,  0.5586, -0.1258,  0.4399,  0.2505,  0.7168,  0.5335,
-        -0.1593,  1.1096, -0.1864, -0.6047,  1.0376, -0.1376, -0.2682, -0.1942],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0326e-01,  6.8058e-01, -4.2272e-01,  2.3786e-01,  2.0892e-01,
-        -6.7803e-04, -5.7346e-01, -1.1174e+00, -5.9490e-01, -5.9974e-01,
-        -2.6319e-01, -2.7499e-01,  5.0632e-02,  3.8441e-01, -6.4690e-02,
-        -5.0749e-01, -7.2754e-01,  6.9292e-01, -8.5993e-01, -6.6510e-01,
-        -3.3036e-01, -2.8271e-02, -6.0367e-01, -2.5165e-01, -8.1150e-02,
-         3.1468e-01,  9.5137e-02,  1.2634e-01,  1.1046e-01,  1.4974e+00,
-        -9.1108e-01, -6.9237e-02, -7.8212e-01, -1.8506e-01, -9.0969e-01,
-        -2.3252e-01,  5.3505e-01,  0.0000e+00,  1.3171e-01,  5.0027e-01,
-         3.6314e-01, -2.9196e-01, -2.6607e-02,  7.1387e-01,  5.6020e-01,
-         6.0096e-01,  2.0155e-01,  4.1924e-01,  5.2513e-01, -5.3732e-01,
-         5.5881e-01, -1.2220e-01,  4.4163e-01,  2.4590e-01,  7.1860e-01,
-         5.3458e-01, -1.5731e-01,  1.1098e+00, -1.8395e-01, -6.0554e-01,
-         1.0378e+00, -1.3325e-01, -2.6667e-01, -1.8969e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0326e-01,  6.8058e-01, -4.2272e-01,  2.3786e-01,  2.0892e-01,
-        -6.7803e-04, -5.7346e-01, -1.1174e+00, -5.9490e-01, -5.9974e-01,
-        -2.6319e-01, -2.7499e-01,  5.0632e-02,  3.8441e-01, -6.4690e-02,
-        -5.0749e-01, -7.2754e-01,  6.9292e-01, -8.5993e-01, -6.6510e-01,
-        -3.3036e-01, -2.8271e-02, -6.0367e-01, -2.5165e-01, -8.1150e-02,
-         3.1468e-01,  9.5137e-02,  1.2634e-01,  1.1046e-01,  1.4974e+00,
-        -9.1108e-01, -6.9237e-02, -7.8212e-01, -1.8506e-01, -9.0969e-01,
-        -2.3252e-01,  5.3505e-01,  0.0000e+00,  1.3171e-01,  5.0027e-01,
-         3.6314e-01, -2.9196e-01, -2.6607e-02,  7.1387e-01,  5.6020e-01,
-         6.0096e-01,  2.0155e-01,  4.1924e-01,  5.2513e-01, -5.3732e-01,
-         5.5881e-01, -1.2220e-01,  4.4163e-01,  2.4590e-01,  7.1860e-01,
-         5.3458e-01, -1.5731e-01,  1.1098e+00, -1.8395e-01, -6.0554e-01,
-         1.0378e+00, -1.3325e-01, -2.6667e-01, -1.8969e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0326e-01,  6.8058e-01, -4.2272e-01,  2.3786e-01,  2.0892e-01,
-        -6.7803e-04, -5.7346e-01, -1.1174e+00, -5.9490e-01, -5.9974e-01,
-        -2.6319e-01, -2.7499e-01,  5.0632e-02,  3.8441e-01, -6.4690e-02,
-        -5.0749e-01, -7.2754e-01,  6.9292e-01, -8.5993e-01, -6.6510e-01,
-        -3.3036e-01, -2.8271e-02, -6.0367e-01, -2.5165e-01, -8.1150e-02,
-         3.1468e-01,  9.5137e-02,  1.2634e-01,  1.1046e-01,  1.4974e+00,
-        -9.1108e-01, -6.9237e-02, -7.8212e-01, -1.8506e-01, -9.0969e-01,
-        -2.3252e-01,  5.3505e-01,  0.0000e+00,  1.3171e-01,  5.0027e-01,
-         3.6314e-01, -2.9196e-01, -2.6607e-02,  7.1387e-01,  5.6020e-01,
-         6.0096e-01,  2.0155e-01,  4.1924e-01,  5.2513e-01, -5.3732e-01,
-         5.5881e-01, -1.2220e-01,  4.4163e-01,  2.4590e-01,  7.1860e-01,
-         5.3458e-01, -1.5731e-01,  1.1098e+00, -1.8395e-01, -6.0554e-01,
-         1.0378e+00, -1.3325e-01, -2.6667e-01, -1.8969e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0331e-01,  6.8092e-01, -4.2676e-01,  2.3341e-01,  2.2045e-01,
-        -1.1952e-03, -5.7586e-01, -1.1174e+00, -5.9263e-01, -5.9908e-01,
-        -2.5908e-01, -2.7426e-01,  4.6200e-02,  3.7879e-01, -6.2461e-02,
-        -5.1004e-01, -7.3088e-01,  6.9179e-01, -8.5985e-01, -6.6591e-01,
-        -3.2887e-01, -3.4696e-02, -6.0402e-01, -2.5573e-01, -8.5363e-02,
-         3.1356e-01,  8.0069e-02,  1.1617e-01,  1.1121e-01,  1.4977e+00,
-        -9.1085e-01, -7.4302e-02, -7.8335e-01, -1.8587e-01, -9.0964e-01,
-        -2.3405e-01,  5.3210e-01,  0.0000e+00,  1.3593e-01,  4.9847e-01,
-         3.6028e-01, -2.9464e-01, -2.9597e-02,  7.1263e-01,  5.5947e-01,
-         6.0425e-01,  2.0614e-01,  4.2172e-01,  5.2841e-01, -5.3748e-01,
-         5.5806e-01, -1.1630e-01,  4.4254e-01,  2.4232e-01,  7.2057e-01,
-         5.3573e-01, -1.5028e-01,  1.1101e+00, -1.8262e-01, -6.0682e-01,
-         1.0382e+00, -1.2586e-01, -2.6385e-01, -1.8694e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0331e-01,  6.8092e-01, -4.2676e-01,  2.3341e-01,  2.2045e-01,
-        -1.1952e-03, -5.7586e-01, -1.1174e+00, -5.9263e-01, -5.9908e-01,
-        -2.5908e-01, -2.7426e-01,  4.6200e-02,  3.7879e-01, -6.2461e-02,
-        -5.1004e-01, -7.3088e-01,  6.9179e-01, -8.5985e-01, -6.6591e-01,
-        -3.2887e-01, -3.4696e-02, -6.0402e-01, -2.5573e-01, -8.5363e-02,
-         3.1356e-01,  8.0069e-02,  1.1617e-01,  1.1121e-01,  1.4977e+00,
-        -9.1085e-01, -7.4302e-02, -7.8335e-01, -1.8587e-01, -9.0964e-01,
-        -2.3405e-01,  5.3210e-01,  0.0000e+00,  1.3593e-01,  4.9847e-01,
-         3.6028e-01, -2.9464e-01, -2.9597e-02,  7.1263e-01,  5.5947e-01,
-         6.0425e-01,  2.0614e-01,  4.2172e-01,  5.2841e-01, -5.3748e-01,
-         5.5806e-01, -1.1630e-01,  4.4254e-01,  2.4232e-01,  7.2057e-01,
-         5.3573e-01, -1.5028e-01,  1.1101e+00, -1.8262e-01, -6.0682e-01,
-         1.0382e+00, -1.2586e-01, -2.6385e-01, -1.8694e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0331e-01,  6.8092e-01, -4.2676e-01,  2.3341e-01,  2.2045e-01,
-        -1.1952e-03, -5.7586e-01, -1.1174e+00, -5.9263e-01, -5.9908e-01,
-        -2.5908e-01, -2.7426e-01,  4.6200e-02,  3.7879e-01, -6.2461e-02,
-        -5.1004e-01, -7.3088e-01,  6.9179e-01, -8.5985e-01, -6.6591e-01,
-        -3.2887e-01, -3.4696e-02, -6.0402e-01, -2.5573e-01, -8.5363e-02,
-         3.1356e-01,  8.0069e-02,  1.1617e-01,  1.1121e-01,  1.4977e+00,
-        -9.1085e-01, -7.4302e-02, -7.8335e-01, -1.8587e-01, -9.0964e-01,
-        -2.3405e-01,  5.3210e-01,  0.0000e+00,  1.3593e-01,  4.9847e-01,
-         3.6028e-01, -2.9464e-01, -2.9597e-02,  7.1263e-01,  5.5947e-01,
-         6.0425e-01,  2.0614e-01,  4.2172e-01,  5.2841e-01, -5.3748e-01,
-         5.5806e-01, -1.1630e-01,  4.4254e-01,  2.4232e-01,  7.2057e-01,
-         5.3573e-01, -1.5028e-01,  1.1101e+00, -1.8262e-01, -6.0682e-01,
-         1.0382e+00, -1.2586e-01, -2.6385e-01, -1.8694e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0362e-01,  6.8231e-01, -4.3062e-01,  2.2729e-01,  2.3439e-01,
-        -4.5037e-04, -5.7814e-01, -1.1173e+00, -5.8928e-01, -5.9883e-01,
-        -2.5595e-01, -2.7521e-01,  4.2081e-02,  3.7035e-01, -6.8336e-02,
-        -5.1345e-01, -7.3419e-01,  6.9036e-01, -8.5969e-01, -6.6726e-01,
-        -3.2734e-01, -3.9906e-02, -6.0411e-01, -2.5952e-01, -8.9613e-02,
-         3.1316e-01,  6.8056e-02,  1.0623e-01,  1.1009e-01,  1.4980e+00,
-        -9.1055e-01, -7.3296e-02, -7.8475e-01, -1.8351e-01, -9.0979e-01,
-        -2.3460e-01,  5.2860e-01,  0.0000e+00,  1.4435e-01,  4.9718e-01,
-         3.5915e-01, -2.9532e-01, -2.8910e-02,  7.1120e-01,  5.5895e-01,
-         6.0758e-01,  2.1014e-01,  4.2308e-01,  5.3264e-01, -5.3784e-01,
-         5.5652e-01, -1.1029e-01,  4.4389e-01,  2.3881e-01,  7.2261e-01,
-         5.3638e-01, -1.4419e-01,  1.1106e+00, -1.8342e-01, -6.0779e-01,
-         1.0385e+00, -1.1912e-01, -2.6227e-01, -1.8456e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0362e-01,  6.8231e-01, -4.3062e-01,  2.2729e-01,  2.3439e-01,
-        -4.5037e-04, -5.7814e-01, -1.1173e+00, -5.8928e-01, -5.9883e-01,
-        -2.5595e-01, -2.7521e-01,  4.2081e-02,  3.7035e-01, -6.8336e-02,
-        -5.1345e-01, -7.3419e-01,  6.9036e-01, -8.5969e-01, -6.6726e-01,
-        -3.2734e-01, -3.9906e-02, -6.0411e-01, -2.5952e-01, -8.9613e-02,
-         3.1316e-01,  6.8056e-02,  1.0623e-01,  1.1009e-01,  1.4980e+00,
-        -9.1055e-01, -7.3296e-02, -7.8475e-01, -1.8351e-01, -9.0979e-01,
-        -2.3460e-01,  5.2860e-01,  0.0000e+00,  1.4435e-01,  4.9718e-01,
-         3.5915e-01, -2.9532e-01, -2.8910e-02,  7.1120e-01,  5.5895e-01,
-         6.0758e-01,  2.1014e-01,  4.2308e-01,  5.3264e-01, -5.3784e-01,
-         5.5652e-01, -1.1029e-01,  4.4389e-01,  2.3881e-01,  7.2261e-01,
-         5.3638e-01, -1.4419e-01,  1.1106e+00, -1.8342e-01, -6.0779e-01,
-         1.0385e+00, -1.1912e-01, -2.6227e-01, -1.8456e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0362e-01,  6.8231e-01, -4.3062e-01,  2.2729e-01,  2.3439e-01,
-        -4.5037e-04, -5.7814e-01, -1.1173e+00, -5.8928e-01, -5.9883e-01,
-        -2.5595e-01, -2.7521e-01,  4.2081e-02,  3.7035e-01, -6.8336e-02,
-        -5.1345e-01, -7.3419e-01,  6.9036e-01, -8.5969e-01, -6.6726e-01,
-        -3.2734e-01, -3.9906e-02, -6.0411e-01, -2.5952e-01, -8.9613e-02,
-         3.1316e-01,  6.8056e-02,  1.0623e-01,  1.1009e-01,  1.4980e+00,
-        -9.1055e-01, -7.3296e-02, -7.8475e-01, -1.8351e-01, -9.0979e-01,
-        -2.3460e-01,  5.2860e-01,  0.0000e+00,  1.4435e-01,  4.9718e-01,
-         3.5915e-01, -2.9532e-01, -2.8910e-02,  7.1120e-01,  5.5895e-01,
-         6.0758e-01,  2.1014e-01,  4.2308e-01,  5.3264e-01, -5.3784e-01,
-         5.5652e-01, -1.1029e-01,  4.4389e-01,  2.3881e-01,  7.2261e-01,
-         5.3638e-01, -1.4419e-01,  1.1106e+00, -1.8342e-01, -6.0779e-01,
-         1.0385e+00, -1.1912e-01, -2.6227e-01, -1.8456e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9046,  0.6840, -0.4340,  0.2203,  0.2470,  0.0019, -0.5806, -1.1174,
-        -0.5855, -0.5996, -0.2532, -0.2769,  0.0369,  0.3613, -0.0722, -0.5172,
-        -0.7381,  0.6895, -0.8589, -0.6684, -0.3252, -0.0450, -0.6036, -0.2670,
-        -0.0944,  0.3126,  0.0525,  0.0969,  0.1118,  1.4988, -0.9097, -0.0727,
-        -0.7858, -0.1827, -0.9101, -0.2380,  0.5255,  0.0000,  0.1522,  0.4964,
-         0.3578, -0.2947, -0.0253,  0.7100,  0.5594,  0.6102,  0.2206,  0.4255,
-         0.5360, -0.5375,  0.5551, -0.1036,  0.4456,  0.2389,  0.7249,  0.5359,
-        -0.1330,  1.1106, -0.1843, -0.6086,  1.0390, -0.1126, -0.2634, -0.1828],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9046,  0.6840, -0.4340,  0.2203,  0.2470,  0.0019, -0.5806, -1.1174,
-        -0.5855, -0.5996, -0.2532, -0.2769,  0.0369,  0.3613, -0.0722, -0.5172,
-        -0.7381,  0.6895, -0.8589, -0.6684, -0.3252, -0.0450, -0.6036, -0.2670,
-        -0.0944,  0.3126,  0.0525,  0.0969,  0.1118,  1.4988, -0.9097, -0.0727,
-        -0.7858, -0.1827, -0.9101, -0.2380,  0.5255,  0.0000,  0.1522,  0.4964,
-         0.3578, -0.2947, -0.0253,  0.7100,  0.5594,  0.6102,  0.2206,  0.4255,
-         0.5360, -0.5375,  0.5551, -0.1036,  0.4456,  0.2389,  0.7249,  0.5359,
-        -0.1330,  1.1106, -0.1843, -0.6086,  1.0390, -0.1126, -0.2634, -0.1828],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9046,  0.6840, -0.4340,  0.2203,  0.2470,  0.0019, -0.5806, -1.1174,
-        -0.5855, -0.5996, -0.2532, -0.2769,  0.0369,  0.3613, -0.0722, -0.5172,
-        -0.7381,  0.6895, -0.8589, -0.6684, -0.3252, -0.0450, -0.6036, -0.2670,
-        -0.0944,  0.3126,  0.0525,  0.0969,  0.1118,  1.4988, -0.9097, -0.0727,
-        -0.7858, -0.1827, -0.9101, -0.2380,  0.5255,  0.0000,  0.1522,  0.4964,
-         0.3578, -0.2947, -0.0253,  0.7100,  0.5594,  0.6102,  0.2206,  0.4255,
-         0.5360, -0.5375,  0.5551, -0.1036,  0.4456,  0.2389,  0.7249,  0.5359,
-        -0.1330,  1.1106, -0.1843, -0.6086,  1.0390, -0.1126, -0.2634, -0.1828],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9053,  0.6864, -0.4376,  0.2118,  0.2558,  0.0022, -0.5828, -1.1176,
-        -0.5831, -0.5999, -0.2494, -0.2795,  0.0338,  0.3495, -0.0730, -0.5217,
-        -0.7424,  0.6882, -0.8579, -0.6685, -0.3240, -0.0504, -0.6026, -0.2754,
-        -0.0995,  0.3148,  0.0358,  0.0861,  0.1081,  1.4994, -0.9089, -0.0729,
-        -0.7869, -0.1824, -0.9107, -0.2411,  0.5223,  0.0000,  0.1603,  0.4957,
-         0.3577, -0.2915, -0.0220,  0.7095,  0.5597,  0.6133,  0.2308,  0.4271,
-         0.5402, -0.5375,  0.5538, -0.0994,  0.4472,  0.2405,  0.7276,  0.5359,
-        -0.1212,  1.1109, -0.1841, -0.6085,  1.0396, -0.1053, -0.2622, -0.1839],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9053,  0.6864, -0.4376,  0.2118,  0.2558,  0.0022, -0.5828, -1.1176,
-        -0.5831, -0.5999, -0.2494, -0.2795,  0.0338,  0.3495, -0.0730, -0.5217,
-        -0.7424,  0.6882, -0.8579, -0.6685, -0.3240, -0.0504, -0.6026, -0.2754,
-        -0.0995,  0.3148,  0.0358,  0.0861,  0.1081,  1.4994, -0.9089, -0.0729,
-        -0.7869, -0.1824, -0.9107, -0.2411,  0.5223,  0.0000,  0.1603,  0.4957,
-         0.3577, -0.2915, -0.0220,  0.7095,  0.5597,  0.6133,  0.2308,  0.4271,
-         0.5402, -0.5375,  0.5538, -0.0994,  0.4472,  0.2405,  0.7276,  0.5359,
-        -0.1212,  1.1109, -0.1841, -0.6085,  1.0396, -0.1053, -0.2622, -0.1839],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9053,  0.6864, -0.4376,  0.2118,  0.2558,  0.0022, -0.5828, -1.1176,
-        -0.5831, -0.5999, -0.2494, -0.2795,  0.0338,  0.3495, -0.0730, -0.5217,
-        -0.7424,  0.6882, -0.8579, -0.6685, -0.3240, -0.0504, -0.6026, -0.2754,
-        -0.0995,  0.3148,  0.0358,  0.0861,  0.1081,  1.4994, -0.9089, -0.0729,
-        -0.7869, -0.1824, -0.9107, -0.2411,  0.5223,  0.0000,  0.1603,  0.4957,
-         0.3577, -0.2915, -0.0220,  0.7095,  0.5597,  0.6133,  0.2308,  0.4271,
-         0.5402, -0.5375,  0.5538, -0.0994,  0.4472,  0.2405,  0.7276,  0.5359,
-        -0.1212,  1.1109, -0.1841, -0.6085,  1.0396, -0.1053, -0.2622, -0.1839],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9058,  0.6879, -0.4411,  0.2004,  0.2613,  0.0032, -0.5840, -1.1177,
-        -0.5818, -0.6004, -0.2454, -0.2772,  0.0304,  0.3370, -0.0675, -0.5260,
-        -0.7459,  0.6865, -0.8573, -0.6697, -0.3234, -0.0565, -0.6018, -0.2808,
-        -0.0975,  0.3192,  0.0183,  0.0736,  0.1054,  1.5001, -0.9085, -0.0747,
-        -0.7877, -0.1845, -0.9111, -0.2469,  0.5200,  0.0000,  0.1673,  0.4956,
-         0.3596, -0.2896, -0.0168,  0.7094,  0.5603,  0.6160,  0.2407,  0.4282,
-         0.5458, -0.5378,  0.5521, -0.0961,  0.4486,  0.2419,  0.7297,  0.5338,
-        -0.1104,  1.1113, -0.1862, -0.6085,  1.0398, -0.0969, -0.2615, -0.1872],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9058,  0.6879, -0.4411,  0.2004,  0.2613,  0.0032, -0.5840, -1.1177,
-        -0.5818, -0.6004, -0.2454, -0.2772,  0.0304,  0.3370, -0.0675, -0.5260,
-        -0.7459,  0.6865, -0.8573, -0.6697, -0.3234, -0.0565, -0.6018, -0.2808,
-        -0.0975,  0.3192,  0.0183,  0.0736,  0.1054,  1.5001, -0.9085, -0.0747,
-        -0.7877, -0.1845, -0.9111, -0.2469,  0.5200,  0.0000,  0.1673,  0.4956,
-         0.3596, -0.2896, -0.0168,  0.7094,  0.5603,  0.6160,  0.2407,  0.4282,
-         0.5458, -0.5378,  0.5521, -0.0961,  0.4486,  0.2419,  0.7297,  0.5338,
-        -0.1104,  1.1113, -0.1862, -0.6085,  1.0398, -0.0969, -0.2615, -0.1872],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9058,  0.6879, -0.4411,  0.2004,  0.2613,  0.0032, -0.5840, -1.1177,
-        -0.5818, -0.6004, -0.2454, -0.2772,  0.0304,  0.3370, -0.0675, -0.5260,
-        -0.7459,  0.6865, -0.8573, -0.6697, -0.3234, -0.0565, -0.6018, -0.2808,
-        -0.0975,  0.3192,  0.0183,  0.0736,  0.1054,  1.5001, -0.9085, -0.0747,
-        -0.7877, -0.1845, -0.9111, -0.2469,  0.5200,  0.0000,  0.1673,  0.4956,
-         0.3596, -0.2896, -0.0168,  0.7094,  0.5603,  0.6160,  0.2407,  0.4282,
-         0.5458, -0.5378,  0.5521, -0.0961,  0.4486,  0.2419,  0.7297,  0.5338,
-        -0.1104,  1.1113, -0.1862, -0.6085,  1.0398, -0.0969, -0.2615, -0.1872],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9066,  0.6897, -0.4438,  0.1929,  0.2643,  0.0043, -0.5846, -1.1180,
-        -0.5805, -0.6003, -0.2405, -0.2728,  0.0349,  0.3279, -0.0627, -0.5295,
-        -0.7491,  0.6861, -0.8571, -0.6700, -0.3220, -0.0618, -0.6008, -0.2866,
-        -0.0927,  0.3199,  0.0120,  0.0559,  0.1042,  1.5007, -0.9084, -0.0724,
-        -0.7880, -0.1825, -0.9118, -0.2538,  0.5168,  0.0000,  0.1815,  0.4954,
-         0.3621, -0.2863, -0.0147,  0.7091,  0.5611,  0.6199,  0.2499,  0.4305,
-         0.5527, -0.5380,  0.5504, -0.0929,  0.4498,  0.2489,  0.7318,  0.5301,
-        -0.1047,  1.1114, -0.1914, -0.6082,  1.0401, -0.0886, -0.2607, -0.1922],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9066,  0.6897, -0.4438,  0.1929,  0.2643,  0.0043, -0.5846, -1.1180,
-        -0.5805, -0.6003, -0.2405, -0.2728,  0.0349,  0.3279, -0.0627, -0.5295,
-        -0.7491,  0.6861, -0.8571, -0.6700, -0.3220, -0.0618, -0.6008, -0.2866,
-        -0.0927,  0.3199,  0.0120,  0.0559,  0.1042,  1.5007, -0.9084, -0.0724,
-        -0.7880, -0.1825, -0.9118, -0.2538,  0.5168,  0.0000,  0.1815,  0.4954,
-         0.3621, -0.2863, -0.0147,  0.7091,  0.5611,  0.6199,  0.2499,  0.4305,
-         0.5527, -0.5380,  0.5504, -0.0929,  0.4498,  0.2489,  0.7318,  0.5301,
-        -0.1047,  1.1114, -0.1914, -0.6082,  1.0401, -0.0886, -0.2607, -0.1922],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9066,  0.6897, -0.4438,  0.1929,  0.2643,  0.0043, -0.5846, -1.1180,
-        -0.5805, -0.6003, -0.2405, -0.2728,  0.0349,  0.3279, -0.0627, -0.5295,
-        -0.7491,  0.6861, -0.8571, -0.6700, -0.3220, -0.0618, -0.6008, -0.2866,
-        -0.0927,  0.3199,  0.0120,  0.0559,  0.1042,  1.5007, -0.9084, -0.0724,
-        -0.7880, -0.1825, -0.9118, -0.2538,  0.5168,  0.0000,  0.1815,  0.4954,
-         0.3621, -0.2863, -0.0147,  0.7091,  0.5611,  0.6199,  0.2499,  0.4305,
-         0.5527, -0.5380,  0.5504, -0.0929,  0.4498,  0.2489,  0.7318,  0.5301,
-        -0.1047,  1.1114, -0.1914, -0.6082,  1.0401, -0.0886, -0.2607, -0.1922],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9067,  0.6917, -0.4460,  0.1880,  0.2689,  0.0043, -0.5849, -1.1180,
-        -0.5788, -0.6010, -0.2378, -0.2721,  0.0393,  0.3190, -0.0577, -0.5324,
-        -0.7523,  0.6853, -0.8570, -0.6697, -0.3200, -0.0677, -0.5992, -0.2895,
-        -0.0892,  0.3176,  0.0076,  0.0378,  0.1030,  1.5013, -0.9082, -0.0667,
-        -0.7890, -0.1853, -0.9124, -0.2588,  0.5127,  0.0000,  0.1970,  0.4963,
-         0.3651, -0.2819, -0.0120,  0.7088,  0.5620,  0.6233,  0.2587,  0.4317,
-         0.5611, -0.5381,  0.5482, -0.0874,  0.4539,  0.2583,  0.7338,  0.5264,
-        -0.0967,  1.1117, -0.1930, -0.6087,  1.0404, -0.0814, -0.2584, -0.1984],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9067,  0.6917, -0.4460,  0.1880,  0.2689,  0.0043, -0.5849, -1.1180,
-        -0.5788, -0.6010, -0.2378, -0.2721,  0.0393,  0.3190, -0.0577, -0.5324,
-        -0.7523,  0.6853, -0.8570, -0.6697, -0.3200, -0.0677, -0.5992, -0.2895,
-        -0.0892,  0.3176,  0.0076,  0.0378,  0.1030,  1.5013, -0.9082, -0.0667,
-        -0.7890, -0.1853, -0.9124, -0.2588,  0.5127,  0.0000,  0.1970,  0.4963,
-         0.3651, -0.2819, -0.0120,  0.7088,  0.5620,  0.6233,  0.2587,  0.4317,
-         0.5611, -0.5381,  0.5482, -0.0874,  0.4539,  0.2583,  0.7338,  0.5264,
-        -0.0967,  1.1117, -0.1930, -0.6087,  1.0404, -0.0814, -0.2584, -0.1984],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9067,  0.6917, -0.4460,  0.1880,  0.2689,  0.0043, -0.5849, -1.1180,
-        -0.5788, -0.6010, -0.2378, -0.2721,  0.0393,  0.3190, -0.0577, -0.5324,
-        -0.7523,  0.6853, -0.8570, -0.6697, -0.3200, -0.0677, -0.5992, -0.2895,
-        -0.0892,  0.3176,  0.0076,  0.0378,  0.1030,  1.5013, -0.9082, -0.0667,
-        -0.7890, -0.1853, -0.9124, -0.2588,  0.5127,  0.0000,  0.1970,  0.4963,
-         0.3651, -0.2819, -0.0120,  0.7088,  0.5620,  0.6233,  0.2587,  0.4317,
-         0.5611, -0.5381,  0.5482, -0.0874,  0.4539,  0.2583,  0.7338,  0.5264,
-        -0.0967,  1.1117, -0.1930, -0.6087,  1.0404, -0.0814, -0.2584, -0.1984],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9069,  0.6938, -0.4478,  0.1783,  0.2721,  0.0034, -0.5845, -1.1178,
-        -0.5775, -0.6023, -0.2337, -0.2710,  0.0433,  0.3107, -0.0527, -0.5356,
-        -0.7558,  0.6853, -0.8567, -0.6677, -0.3183, -0.0739, -0.5968, -0.2924,
-        -0.0841,  0.3135,  0.0068,  0.0194,  0.1056,  1.5023, -0.9080, -0.0594,
-        -0.7897, -0.1899, -0.9130, -0.2638,  0.5088,  0.0000,  0.2112,  0.4976,
-         0.3681, -0.2765, -0.0132,  0.7077,  0.5634,  0.6259,  0.2659,  0.4319,
-         0.5705, -0.5390,  0.5457, -0.0838,  0.4562,  0.2672,  0.7353,  0.5230,
-        -0.0871,  1.1115, -0.1971, -0.6104,  1.0405, -0.0757, -0.2551, -0.2043],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9069,  0.6938, -0.4478,  0.1783,  0.2721,  0.0034, -0.5845, -1.1178,
-        -0.5775, -0.6023, -0.2337, -0.2710,  0.0433,  0.3107, -0.0527, -0.5356,
-        -0.7558,  0.6853, -0.8567, -0.6677, -0.3183, -0.0739, -0.5968, -0.2924,
-        -0.0841,  0.3135,  0.0068,  0.0194,  0.1056,  1.5023, -0.9080, -0.0594,
-        -0.7897, -0.1899, -0.9130, -0.2638,  0.5088,  0.0000,  0.2112,  0.4976,
-         0.3681, -0.2765, -0.0132,  0.7077,  0.5634,  0.6259,  0.2659,  0.4319,
-         0.5705, -0.5390,  0.5457, -0.0838,  0.4562,  0.2672,  0.7353,  0.5230,
-        -0.0871,  1.1115, -0.1971, -0.6104,  1.0405, -0.0757, -0.2551, -0.2043],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9069,  0.6938, -0.4478,  0.1783,  0.2721,  0.0034, -0.5845, -1.1178,
-        -0.5775, -0.6023, -0.2337, -0.2710,  0.0433,  0.3107, -0.0527, -0.5356,
-        -0.7558,  0.6853, -0.8567, -0.6677, -0.3183, -0.0739, -0.5968, -0.2924,
-        -0.0841,  0.3135,  0.0068,  0.0194,  0.1056,  1.5023, -0.9080, -0.0594,
-        -0.7897, -0.1899, -0.9130, -0.2638,  0.5088,  0.0000,  0.2112,  0.4976,
-         0.3681, -0.2765, -0.0132,  0.7077,  0.5634,  0.6259,  0.2659,  0.4319,
-         0.5705, -0.5390,  0.5457, -0.0838,  0.4562,  0.2672,  0.7353,  0.5230,
-        -0.0871,  1.1115, -0.1971, -0.6104,  1.0405, -0.0757, -0.2551, -0.2043],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9077,  0.6957, -0.4499,  0.1688,  0.2772,  0.0024, -0.5842, -1.1173,
-        -0.5771, -0.6048, -0.2298, -0.2716,  0.0449,  0.3024, -0.0538, -0.5387,
-        -0.7594,  0.6847, -0.8564, -0.6661, -0.3166, -0.0810, -0.5943, -0.2956,
-        -0.0836,  0.3080,  0.0083, -0.0025,  0.1035,  1.5033, -0.9076, -0.0521,
-        -0.7905, -0.2004, -0.9134, -0.2695,  0.5074,  0.0000,  0.2219,  0.5001,
-         0.3719, -0.2699, -0.0144,  0.7056,  0.5650,  0.6304,  0.2773,  0.4329,
-         0.5810, -0.5393,  0.5426, -0.0798,  0.4616,  0.2751,  0.7372,  0.5189,
-        -0.0831,  1.1112, -0.1955, -0.6127,  1.0407, -0.0740, -0.2541, -0.2065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9077,  0.6957, -0.4499,  0.1688,  0.2772,  0.0024, -0.5842, -1.1173,
-        -0.5771, -0.6048, -0.2298, -0.2716,  0.0449,  0.3024, -0.0538, -0.5387,
-        -0.7594,  0.6847, -0.8564, -0.6661, -0.3166, -0.0810, -0.5943, -0.2956,
-        -0.0836,  0.3080,  0.0083, -0.0025,  0.1035,  1.5033, -0.9076, -0.0521,
-        -0.7905, -0.2004, -0.9134, -0.2695,  0.5074,  0.0000,  0.2219,  0.5001,
-         0.3719, -0.2699, -0.0144,  0.7056,  0.5650,  0.6304,  0.2773,  0.4329,
-         0.5810, -0.5393,  0.5426, -0.0798,  0.4616,  0.2751,  0.7372,  0.5189,
-        -0.0831,  1.1112, -0.1955, -0.6127,  1.0407, -0.0740, -0.2541, -0.2065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9077,  0.6957, -0.4499,  0.1688,  0.2772,  0.0024, -0.5842, -1.1173,
-        -0.5771, -0.6048, -0.2298, -0.2716,  0.0449,  0.3024, -0.0538, -0.5387,
-        -0.7594,  0.6847, -0.8564, -0.6661, -0.3166, -0.0810, -0.5943, -0.2956,
-        -0.0836,  0.3080,  0.0083, -0.0025,  0.1035,  1.5033, -0.9076, -0.0521,
-        -0.7905, -0.2004, -0.9134, -0.2695,  0.5074,  0.0000,  0.2219,  0.5001,
-         0.3719, -0.2699, -0.0144,  0.7056,  0.5650,  0.6304,  0.2773,  0.4329,
-         0.5810, -0.5393,  0.5426, -0.0798,  0.4616,  0.2751,  0.7372,  0.5189,
-        -0.0831,  1.1112, -0.1955, -0.6127,  1.0407, -0.0740, -0.2541, -0.2065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9081,  0.6975, -0.4515,  0.1596,  0.2813,  0.0029, -0.5841, -1.1171,
-        -0.5775, -0.6065, -0.2239, -0.2703,  0.0394,  0.2900, -0.0527, -0.5429,
-        -0.7630,  0.6845, -0.8573, -0.6653, -0.3175, -0.0901, -0.5918, -0.2999,
-        -0.0875,  0.3023,  0.0116, -0.0226,  0.1005,  1.5046, -0.9070, -0.0471,
-        -0.7927, -0.2117, -0.9136, -0.2742,  0.5035,  0.0000,  0.2337,  0.5030,
-         0.3754, -0.2605, -0.0142,  0.7043,  0.5671,  0.6361,  0.2890,  0.4311,
-         0.5928, -0.5399,  0.5401, -0.0779,  0.4663,  0.2834,  0.7389,  0.5157,
-        -0.0760,  1.1110, -0.1953, -0.6170,  1.0409, -0.0695, -0.2545, -0.2089],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9081,  0.6975, -0.4515,  0.1596,  0.2813,  0.0029, -0.5841, -1.1171,
-        -0.5775, -0.6065, -0.2239, -0.2703,  0.0394,  0.2900, -0.0527, -0.5429,
-        -0.7630,  0.6845, -0.8573, -0.6653, -0.3175, -0.0901, -0.5918, -0.2999,
-        -0.0875,  0.3023,  0.0116, -0.0226,  0.1005,  1.5046, -0.9070, -0.0471,
-        -0.7927, -0.2117, -0.9136, -0.2742,  0.5035,  0.0000,  0.2337,  0.5030,
-         0.3754, -0.2605, -0.0142,  0.7043,  0.5671,  0.6361,  0.2890,  0.4311,
-         0.5928, -0.5399,  0.5401, -0.0779,  0.4663,  0.2834,  0.7389,  0.5157,
-        -0.0760,  1.1110, -0.1953, -0.6170,  1.0409, -0.0695, -0.2545, -0.2089],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9081,  0.6975, -0.4515,  0.1596,  0.2813,  0.0029, -0.5841, -1.1171,
-        -0.5775, -0.6065, -0.2239, -0.2703,  0.0394,  0.2900, -0.0527, -0.5429,
-        -0.7630,  0.6845, -0.8573, -0.6653, -0.3175, -0.0901, -0.5918, -0.2999,
-        -0.0875,  0.3023,  0.0116, -0.0226,  0.1005,  1.5046, -0.9070, -0.0471,
-        -0.7927, -0.2117, -0.9136, -0.2742,  0.5035,  0.0000,  0.2337,  0.5030,
-         0.3754, -0.2605, -0.0142,  0.7043,  0.5671,  0.6361,  0.2890,  0.4311,
-         0.5928, -0.5399,  0.5401, -0.0779,  0.4663,  0.2834,  0.7389,  0.5157,
-        -0.0760,  1.1110, -0.1953, -0.6170,  1.0409, -0.0695, -0.2545, -0.2089],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0809e-01,  6.9898e-01, -4.5331e-01,  1.5216e-01,  2.8274e-01,
-         3.8534e-05, -5.8375e-01, -1.1168e+00, -5.7785e-01, -6.0879e-01,
-        -2.1652e-01, -2.7356e-01,  3.4214e-02,  2.7478e-01, -6.3727e-02,
-        -5.4732e-01, -7.6711e-01,  6.8492e-01, -8.5857e-01, -6.6502e-01,
-        -3.1773e-01, -9.8530e-02, -5.8796e-01, -3.0606e-01, -9.1878e-02,
-         2.9933e-01,  1.8453e-02, -3.8549e-02,  1.0132e-01,  1.5058e+00,
-        -9.0663e-01, -5.0223e-02, -7.9554e-01, -2.2241e-01, -9.1341e-01,
-        -2.8046e-01,  5.0073e-01,  0.0000e+00,  2.4553e-01,  5.0247e-01,
-         3.7934e-01, -2.4644e-01, -1.0619e-02,  7.0329e-01,  5.7052e-01,
-         6.4158e-01,  2.9682e-01,  4.2901e-01,  6.0502e-01, -5.4004e-01,
-         5.3717e-01, -7.9878e-02,  4.6994e-01,  2.9041e-01,  7.4008e-01,
-         5.1300e-01, -6.4303e-02,  1.1109e+00, -1.9746e-01, -6.1998e-01,
-         1.0418e+00, -6.0271e-02, -2.5628e-01, -2.0947e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0809e-01,  6.9898e-01, -4.5331e-01,  1.5216e-01,  2.8274e-01,
-         3.8534e-05, -5.8375e-01, -1.1168e+00, -5.7785e-01, -6.0879e-01,
-        -2.1652e-01, -2.7356e-01,  3.4214e-02,  2.7478e-01, -6.3727e-02,
-        -5.4732e-01, -7.6711e-01,  6.8492e-01, -8.5857e-01, -6.6502e-01,
-        -3.1773e-01, -9.8530e-02, -5.8796e-01, -3.0606e-01, -9.1878e-02,
-         2.9933e-01,  1.8453e-02, -3.8549e-02,  1.0132e-01,  1.5058e+00,
-        -9.0663e-01, -5.0223e-02, -7.9554e-01, -2.2241e-01, -9.1341e-01,
-        -2.8046e-01,  5.0073e-01,  0.0000e+00,  2.4553e-01,  5.0247e-01,
-         3.7934e-01, -2.4644e-01, -1.0619e-02,  7.0329e-01,  5.7052e-01,
-         6.4158e-01,  2.9682e-01,  4.2901e-01,  6.0502e-01, -5.4004e-01,
-         5.3717e-01, -7.9878e-02,  4.6994e-01,  2.9041e-01,  7.4008e-01,
-         5.1300e-01, -6.4303e-02,  1.1109e+00, -1.9746e-01, -6.1998e-01,
-         1.0418e+00, -6.0271e-02, -2.5628e-01, -2.0947e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0809e-01,  6.9898e-01, -4.5331e-01,  1.5216e-01,  2.8274e-01,
-         3.8534e-05, -5.8375e-01, -1.1168e+00, -5.7785e-01, -6.0879e-01,
-        -2.1652e-01, -2.7356e-01,  3.4214e-02,  2.7478e-01, -6.3727e-02,
-        -5.4732e-01, -7.6711e-01,  6.8492e-01, -8.5857e-01, -6.6502e-01,
-        -3.1773e-01, -9.8530e-02, -5.8796e-01, -3.0606e-01, -9.1878e-02,
-         2.9933e-01,  1.8453e-02, -3.8549e-02,  1.0132e-01,  1.5058e+00,
-        -9.0663e-01, -5.0223e-02, -7.9554e-01, -2.2241e-01, -9.1341e-01,
-        -2.8046e-01,  5.0073e-01,  0.0000e+00,  2.4553e-01,  5.0247e-01,
-         3.7934e-01, -2.4644e-01, -1.0619e-02,  7.0329e-01,  5.7052e-01,
-         6.4158e-01,  2.9682e-01,  4.2901e-01,  6.0502e-01, -5.4004e-01,
-         5.3717e-01, -7.9878e-02,  4.6994e-01,  2.9041e-01,  7.4008e-01,
-         5.1300e-01, -6.4303e-02,  1.1109e+00, -1.9746e-01, -6.1998e-01,
-         1.0418e+00, -6.0271e-02, -2.5628e-01, -2.0947e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9082,  0.7016, -0.4534,  0.1388,  0.2857, -0.0019, -0.5843, -1.1163,
-        -0.5804, -0.6104, -0.2084, -0.2757,  0.0238,  0.2603, -0.0733, -0.5544,
-        -0.7707,  0.6854, -0.8598, -0.6638, -0.3173, -0.1110, -0.5853, -0.3088,
-        -0.0968,  0.2967,  0.0226, -0.0596,  0.0984,  1.5074, -0.9062, -0.0448,
-        -0.7989, -0.2336, -0.9136, -0.2848,  0.4983,  0.0000,  0.2580,  0.5003,
-         0.3830, -0.2286, -0.0051,  0.7022,  0.5734,  0.6475,  0.3064,  0.4254,
-         0.6158, -0.5419,  0.5330, -0.0805,  0.4748,  0.2967,  0.7402,  0.5106,
-        -0.0460,  1.1109, -0.2011, -0.6216,  1.0421, -0.0527, -0.2584, -0.2088],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9082,  0.7016, -0.4534,  0.1388,  0.2857, -0.0019, -0.5843, -1.1163,
-        -0.5804, -0.6104, -0.2084, -0.2757,  0.0238,  0.2603, -0.0733, -0.5544,
-        -0.7707,  0.6854, -0.8598, -0.6638, -0.3173, -0.1110, -0.5853, -0.3088,
-        -0.0968,  0.2967,  0.0226, -0.0596,  0.0984,  1.5074, -0.9062, -0.0448,
-        -0.7989, -0.2336, -0.9136, -0.2848,  0.4983,  0.0000,  0.2580,  0.5003,
-         0.3830, -0.2286, -0.0051,  0.7022,  0.5734,  0.6475,  0.3064,  0.4254,
-         0.6158, -0.5419,  0.5330, -0.0805,  0.4748,  0.2967,  0.7402,  0.5106,
-        -0.0460,  1.1109, -0.2011, -0.6216,  1.0421, -0.0527, -0.2584, -0.2088],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9082,  0.7016, -0.4534,  0.1388,  0.2857, -0.0019, -0.5843, -1.1163,
-        -0.5804, -0.6104, -0.2084, -0.2757,  0.0238,  0.2603, -0.0733, -0.5544,
-        -0.7707,  0.6854, -0.8598, -0.6638, -0.3173, -0.1110, -0.5853, -0.3088,
-        -0.0968,  0.2967,  0.0226, -0.0596,  0.0984,  1.5074, -0.9062, -0.0448,
-        -0.7989, -0.2336, -0.9136, -0.2848,  0.4983,  0.0000,  0.2580,  0.5003,
-         0.3830, -0.2286, -0.0051,  0.7022,  0.5734,  0.6475,  0.3064,  0.4254,
-         0.6158, -0.5419,  0.5330, -0.0805,  0.4748,  0.2967,  0.7402,  0.5106,
-        -0.0460,  1.1109, -0.2011, -0.6216,  1.0421, -0.0527, -0.2584, -0.2088],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0735e-01,  7.0306e-01, -4.5280e-01,  1.2427e-01,  2.8572e-01,
-        -4.0952e-03, -5.8513e-01, -1.1155e+00, -5.8369e-01, -6.1238e-01,
-        -2.0482e-01, -2.7752e-01,  1.4553e-02,  2.5039e-01, -8.9902e-02,
-        -5.5956e-01, -7.7304e-01,  6.8676e-01, -8.6044e-01, -6.6298e-01,
-        -3.1876e-01, -1.2073e-01, -5.8386e-01, -3.1067e-01, -1.0285e-01,
-         2.9348e-01,  2.2847e-02, -8.3241e-02,  9.3427e-02,  1.5093e+00,
-        -9.0596e-01, -3.1366e-02, -8.0256e-01, -2.4658e-01, -9.1331e-01,
-        -2.8910e-01,  4.9493e-01,  0.0000e+00,  2.6856e-01,  4.9764e-01,
-         3.8877e-01, -2.0599e-01,  6.7335e-04,  7.0219e-01,  5.7672e-01,
-         6.5394e-01,  3.1667e-01,  4.2309e-01,  6.2805e-01, -5.4364e-01,
-         5.3175e-01, -8.0631e-02,  4.7960e-01,  3.0852e-01,  7.4047e-01,
-         5.0808e-01, -3.1520e-02,  1.1107e+00, -2.0246e-01, -6.2141e-01,
-         1.0426e+00, -4.7125e-02, -2.5999e-01, -2.1075e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0735e-01,  7.0306e-01, -4.5280e-01,  1.2427e-01,  2.8572e-01,
-        -4.0952e-03, -5.8513e-01, -1.1155e+00, -5.8369e-01, -6.1238e-01,
-        -2.0482e-01, -2.7752e-01,  1.4553e-02,  2.5039e-01, -8.9902e-02,
-        -5.5956e-01, -7.7304e-01,  6.8676e-01, -8.6044e-01, -6.6298e-01,
-        -3.1876e-01, -1.2073e-01, -5.8386e-01, -3.1067e-01, -1.0285e-01,
-         2.9348e-01,  2.2847e-02, -8.3241e-02,  9.3427e-02,  1.5093e+00,
-        -9.0596e-01, -3.1366e-02, -8.0256e-01, -2.4658e-01, -9.1331e-01,
-        -2.8910e-01,  4.9493e-01,  0.0000e+00,  2.6856e-01,  4.9764e-01,
-         3.8877e-01, -2.0599e-01,  6.7335e-04,  7.0219e-01,  5.7672e-01,
-         6.5394e-01,  3.1667e-01,  4.2309e-01,  6.2805e-01, -5.4364e-01,
-         5.3175e-01, -8.0631e-02,  4.7960e-01,  3.0852e-01,  7.4047e-01,
-         5.0808e-01, -3.1520e-02,  1.1107e+00, -2.0246e-01, -6.2141e-01,
-         1.0426e+00, -4.7125e-02, -2.5999e-01, -2.1075e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0735e-01,  7.0306e-01, -4.5280e-01,  1.2427e-01,  2.8572e-01,
-        -4.0952e-03, -5.8513e-01, -1.1155e+00, -5.8369e-01, -6.1238e-01,
-        -2.0482e-01, -2.7752e-01,  1.4553e-02,  2.5039e-01, -8.9902e-02,
-        -5.5956e-01, -7.7304e-01,  6.8676e-01, -8.6044e-01, -6.6298e-01,
-        -3.1876e-01, -1.2073e-01, -5.8386e-01, -3.1067e-01, -1.0285e-01,
-         2.9348e-01,  2.2847e-02, -8.3241e-02,  9.3427e-02,  1.5093e+00,
-        -9.0596e-01, -3.1366e-02, -8.0256e-01, -2.4658e-01, -9.1331e-01,
-        -2.8910e-01,  4.9493e-01,  0.0000e+00,  2.6856e-01,  4.9764e-01,
-         3.8877e-01, -2.0599e-01,  6.7335e-04,  7.0219e-01,  5.7672e-01,
-         6.5394e-01,  3.1667e-01,  4.2309e-01,  6.2805e-01, -5.4364e-01,
-         5.3175e-01, -8.0631e-02,  4.7960e-01,  3.0852e-01,  7.4047e-01,
-         5.0808e-01, -3.1520e-02,  1.1107e+00, -2.0246e-01, -6.2141e-01,
-         1.0426e+00, -4.7125e-02, -2.5999e-01, -2.1075e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9071,  0.7055, -0.4502,  0.1171,  0.2856, -0.0041, -0.5851, -1.1147,
-        -0.5870, -0.6136, -0.1971, -0.2846,  0.0083,  0.2404, -0.1021, -0.5651,
-        -0.7755,  0.6884, -0.8611, -0.6613, -0.3226, -0.1312, -0.5830, -0.3122,
-        -0.1121,  0.2915,  0.0292, -0.1095,  0.0877,  1.5109, -0.9060, -0.0140,
-        -0.8059, -0.2547, -0.9129, -0.2935,  0.4930,  0.0000,  0.2806,  0.4951,
-         0.3935, -0.1778,  0.0052,  0.7016,  0.5800,  0.6624,  0.3287,  0.4221,
-         0.6403, -0.5465,  0.5309, -0.0773,  0.4824,  0.3177,  0.7390,  0.5066,
-        -0.0134,  1.1100, -0.1999, -0.6191,  1.0424, -0.0449, -0.2615, -0.2139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9071,  0.7055, -0.4502,  0.1171,  0.2856, -0.0041, -0.5851, -1.1147,
-        -0.5870, -0.6136, -0.1971, -0.2846,  0.0083,  0.2404, -0.1021, -0.5651,
-        -0.7755,  0.6884, -0.8611, -0.6613, -0.3226, -0.1312, -0.5830, -0.3122,
-        -0.1121,  0.2915,  0.0292, -0.1095,  0.0877,  1.5109, -0.9060, -0.0140,
-        -0.8059, -0.2547, -0.9129, -0.2935,  0.4930,  0.0000,  0.2806,  0.4951,
-         0.3935, -0.1778,  0.0052,  0.7016,  0.5800,  0.6624,  0.3287,  0.4221,
-         0.6403, -0.5465,  0.5309, -0.0773,  0.4824,  0.3177,  0.7390,  0.5066,
-        -0.0134,  1.1100, -0.1999, -0.6191,  1.0424, -0.0449, -0.2615, -0.2139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9071,  0.7055, -0.4502,  0.1171,  0.2856, -0.0041, -0.5851, -1.1147,
-        -0.5870, -0.6136, -0.1971, -0.2846,  0.0083,  0.2404, -0.1021, -0.5651,
-        -0.7755,  0.6884, -0.8611, -0.6613, -0.3226, -0.1312, -0.5830, -0.3122,
-        -0.1121,  0.2915,  0.0292, -0.1095,  0.0877,  1.5109, -0.9060, -0.0140,
-        -0.8059, -0.2547, -0.9129, -0.2935,  0.4930,  0.0000,  0.2806,  0.4951,
-         0.3935, -0.1778,  0.0052,  0.7016,  0.5800,  0.6624,  0.3287,  0.4221,
-         0.6403, -0.5465,  0.5309, -0.0773,  0.4824,  0.3177,  0.7390,  0.5066,
-        -0.0134,  1.1100, -0.1999, -0.6191,  1.0424, -0.0449, -0.2615, -0.2139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0613e-01,  7.0808e-01, -4.4644e-01,  9.6410e-02,  2.8943e-01,
-        -2.7042e-03, -5.8521e-01, -1.1138e+00, -5.9066e-01, -6.1504e-01,
-        -1.9424e-01, -2.9477e-01,  2.2738e-03,  2.3198e-01, -1.0466e-01,
-        -5.7044e-01, -7.7745e-01,  6.9018e-01, -8.6190e-01, -6.6058e-01,
-        -3.2516e-01, -1.3661e-01, -5.8217e-01, -3.1276e-01, -1.2280e-01,
-         2.8869e-01,  3.7664e-02, -1.3780e-01,  8.6353e-02,  1.5123e+00,
-        -9.0595e-01,  3.7749e-03, -8.0916e-01, -2.5962e-01, -9.1271e-01,
-        -2.9870e-01,  4.9041e-01,  0.0000e+00,  2.9293e-01,  4.9208e-01,
-         3.9696e-01, -1.5024e-01,  1.0647e-02,  6.9866e-01,  5.8363e-01,
-         6.6923e-01,  3.4186e-01,  4.2001e-01,  6.5021e-01, -5.4899e-01,
-         5.3154e-01, -7.9079e-02,  4.8539e-01,  3.2534e-01,  7.3732e-01,
-         5.0484e-01,  9.3496e-04,  1.1090e+00, -1.9319e-01, -6.1569e-01,
-         1.0422e+00, -4.6211e-02, -2.6277e-01, -2.1266e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0613e-01,  7.0808e-01, -4.4644e-01,  9.6410e-02,  2.8943e-01,
-        -2.7042e-03, -5.8521e-01, -1.1138e+00, -5.9066e-01, -6.1504e-01,
-        -1.9424e-01, -2.9477e-01,  2.2738e-03,  2.3198e-01, -1.0466e-01,
-        -5.7044e-01, -7.7745e-01,  6.9018e-01, -8.6190e-01, -6.6058e-01,
-        -3.2516e-01, -1.3661e-01, -5.8217e-01, -3.1276e-01, -1.2280e-01,
-         2.8869e-01,  3.7664e-02, -1.3780e-01,  8.6353e-02,  1.5123e+00,
-        -9.0595e-01,  3.7749e-03, -8.0916e-01, -2.5962e-01, -9.1271e-01,
-        -2.9870e-01,  4.9041e-01,  0.0000e+00,  2.9293e-01,  4.9208e-01,
-         3.9696e-01, -1.5024e-01,  1.0647e-02,  6.9866e-01,  5.8363e-01,
-         6.6923e-01,  3.4186e-01,  4.2001e-01,  6.5021e-01, -5.4899e-01,
-         5.3154e-01, -7.9079e-02,  4.8539e-01,  3.2534e-01,  7.3732e-01,
-         5.0484e-01,  9.3496e-04,  1.1090e+00, -1.9319e-01, -6.1569e-01,
-         1.0422e+00, -4.6211e-02, -2.6277e-01, -2.1266e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0613e-01,  7.0808e-01, -4.4644e-01,  9.6410e-02,  2.8943e-01,
-        -2.7042e-03, -5.8521e-01, -1.1138e+00, -5.9066e-01, -6.1504e-01,
-        -1.9424e-01, -2.9477e-01,  2.2738e-03,  2.3198e-01, -1.0466e-01,
-        -5.7044e-01, -7.7745e-01,  6.9018e-01, -8.6190e-01, -6.6058e-01,
-        -3.2516e-01, -1.3661e-01, -5.8217e-01, -3.1276e-01, -1.2280e-01,
-         2.8869e-01,  3.7664e-02, -1.3780e-01,  8.6353e-02,  1.5123e+00,
-        -9.0595e-01,  3.7749e-03, -8.0916e-01, -2.5962e-01, -9.1271e-01,
-        -2.9870e-01,  4.9041e-01,  0.0000e+00,  2.9293e-01,  4.9208e-01,
-         3.9696e-01, -1.5024e-01,  1.0647e-02,  6.9866e-01,  5.8363e-01,
-         6.6923e-01,  3.4186e-01,  4.2001e-01,  6.5021e-01, -5.4899e-01,
-         5.3154e-01, -7.9079e-02,  4.8539e-01,  3.2534e-01,  7.3732e-01,
-         5.0484e-01,  9.3496e-04,  1.1090e+00, -1.9319e-01, -6.1569e-01,
-         1.0422e+00, -4.6211e-02, -2.6277e-01, -2.1266e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9050,  0.7090, -0.4428,  0.0621,  0.2917, -0.0021, -0.5851, -1.1131,
-        -0.5942, -0.6155, -0.2006, -0.3052, -0.0033,  0.2221, -0.0993, -0.5760,
-        -0.7811,  0.6914, -0.8627, -0.6572, -0.3274, -0.1419, -0.5826, -0.3159,
-        -0.1268,  0.2859,  0.0527, -0.1725,  0.0850,  1.5138, -0.9058,  0.0181,
-        -0.8133, -0.2710, -0.9126, -0.3047,  0.4894,  0.0000,  0.3094,  0.4910,
-         0.3997, -0.1157,  0.0187,  0.6959,  0.5887,  0.6769,  0.3532,  0.4233,
-         0.6589, -0.5515,  0.5308, -0.0902,  0.4869,  0.3354,  0.7359,  0.5037,
-         0.0155,  1.1071, -0.1871, -0.6131,  1.0427, -0.0476, -0.2588, -0.2121],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9050,  0.7090, -0.4428,  0.0621,  0.2917, -0.0021, -0.5851, -1.1131,
-        -0.5942, -0.6155, -0.2006, -0.3052, -0.0033,  0.2221, -0.0993, -0.5760,
-        -0.7811,  0.6914, -0.8627, -0.6572, -0.3274, -0.1419, -0.5826, -0.3159,
-        -0.1268,  0.2859,  0.0527, -0.1725,  0.0850,  1.5138, -0.9058,  0.0181,
-        -0.8133, -0.2710, -0.9126, -0.3047,  0.4894,  0.0000,  0.3094,  0.4910,
-         0.3997, -0.1157,  0.0187,  0.6959,  0.5887,  0.6769,  0.3532,  0.4233,
-         0.6589, -0.5515,  0.5308, -0.0902,  0.4869,  0.3354,  0.7359,  0.5037,
-         0.0155,  1.1071, -0.1871, -0.6131,  1.0427, -0.0476, -0.2588, -0.2121],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9050,  0.7090, -0.4428,  0.0621,  0.2917, -0.0021, -0.5851, -1.1131,
-        -0.5942, -0.6155, -0.2006, -0.3052, -0.0033,  0.2221, -0.0993, -0.5760,
-        -0.7811,  0.6914, -0.8627, -0.6572, -0.3274, -0.1419, -0.5826, -0.3159,
-        -0.1268,  0.2859,  0.0527, -0.1725,  0.0850,  1.5138, -0.9058,  0.0181,
-        -0.8133, -0.2710, -0.9126, -0.3047,  0.4894,  0.0000,  0.3094,  0.4910,
-         0.3997, -0.1157,  0.0187,  0.6959,  0.5887,  0.6769,  0.3532,  0.4233,
-         0.6589, -0.5515,  0.5308, -0.0902,  0.4869,  0.3354,  0.7359,  0.5037,
-         0.0155,  1.1071, -0.1871, -0.6131,  1.0427, -0.0476, -0.2588, -0.2121],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9024,  0.7099, -0.4382,  0.0253,  0.2915,  0.0029, -0.5854, -1.1123,
-        -0.5988, -0.6146, -0.2091, -0.3161, -0.0127,  0.2118, -0.0929, -0.5797,
-        -0.7847,  0.6928, -0.8638, -0.6533, -0.3271, -0.1491, -0.5842, -0.3145,
-        -0.1287,  0.2864,  0.0636, -0.2015,  0.0826,  1.5153, -0.9048,  0.0401,
-        -0.8174, -0.2842, -0.9127, -0.3088,  0.4886,  0.0000,  0.3236,  0.4893,
-         0.4019, -0.0817,  0.0194,  0.6937,  0.5930,  0.6844,  0.3661,  0.4270,
-         0.6675, -0.5538,  0.5277, -0.1057,  0.4880,  0.3447,  0.7345,  0.5035,
-         0.0239,  1.1051, -0.1804, -0.6096,  1.0432, -0.0457, -0.2525, -0.2147],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9024,  0.7099, -0.4382,  0.0253,  0.2915,  0.0029, -0.5854, -1.1123,
-        -0.5988, -0.6146, -0.2091, -0.3161, -0.0127,  0.2118, -0.0929, -0.5797,
-        -0.7847,  0.6928, -0.8638, -0.6533, -0.3271, -0.1491, -0.5842, -0.3145,
-        -0.1287,  0.2864,  0.0636, -0.2015,  0.0826,  1.5153, -0.9048,  0.0401,
-        -0.8174, -0.2842, -0.9127, -0.3088,  0.4886,  0.0000,  0.3236,  0.4893,
-         0.4019, -0.0817,  0.0194,  0.6937,  0.5930,  0.6844,  0.3661,  0.4270,
-         0.6675, -0.5538,  0.5277, -0.1057,  0.4880,  0.3447,  0.7345,  0.5035,
-         0.0239,  1.1051, -0.1804, -0.6096,  1.0432, -0.0457, -0.2525, -0.2147],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9024,  0.7099, -0.4382,  0.0253,  0.2915,  0.0029, -0.5854, -1.1123,
-        -0.5988, -0.6146, -0.2091, -0.3161, -0.0127,  0.2118, -0.0929, -0.5797,
-        -0.7847,  0.6928, -0.8638, -0.6533, -0.3271, -0.1491, -0.5842, -0.3145,
-        -0.1287,  0.2864,  0.0636, -0.2015,  0.0826,  1.5153, -0.9048,  0.0401,
-        -0.8174, -0.2842, -0.9127, -0.3088,  0.4886,  0.0000,  0.3236,  0.4893,
-         0.4019, -0.0817,  0.0194,  0.6937,  0.5930,  0.6844,  0.3661,  0.4270,
-         0.6675, -0.5538,  0.5277, -0.1057,  0.4880,  0.3447,  0.7345,  0.5035,
-         0.0239,  1.1051, -0.1804, -0.6096,  1.0432, -0.0457, -0.2525, -0.2147],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.9953e-01,  7.0912e-01, -4.3275e-01, -1.4428e-03,  2.9315e-01,
-         1.0506e-02, -5.8402e-01, -1.1115e+00, -6.0515e-01, -6.1241e-01,
-        -2.2131e-01, -3.2038e-01, -1.4874e-02,  1.9834e-01, -8.6796e-02,
-        -5.8208e-01, -7.8766e-01,  6.9730e-01, -8.6626e-01, -6.5007e-01,
-        -3.2225e-01, -1.5535e-01, -5.8389e-01, -3.1289e-01, -1.3179e-01,
-         2.8526e-01,  6.9646e-02, -2.2773e-01,  8.0369e-02,  1.5170e+00,
-        -9.0445e-01,  4.7837e-02, -8.2241e-01, -2.9930e-01, -9.1316e-01,
-        -3.1202e-01,  4.8882e-01,  0.0000e+00,  3.3642e-01,  4.8637e-01,
-         4.0246e-01, -4.2253e-02,  2.0668e-02,  6.9341e-01,  5.9713e-01,
-         6.9088e-01,  3.7560e-01,  4.3223e-01,  6.7582e-01, -5.5512e-01,
-         5.2483e-01, -1.2274e-01,  4.8765e-01,  3.4926e-01,  7.3391e-01,
-         5.0577e-01,  4.3959e-02,  1.1035e+00, -1.7291e-01, -6.0415e-01,
-         1.0449e+00, -4.4390e-02, -2.4873e-01, -2.1702e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.9953e-01,  7.0912e-01, -4.3275e-01, -1.4428e-03,  2.9315e-01,
-         1.0506e-02, -5.8402e-01, -1.1115e+00, -6.0515e-01, -6.1241e-01,
-        -2.2131e-01, -3.2038e-01, -1.4874e-02,  1.9834e-01, -8.6796e-02,
-        -5.8208e-01, -7.8766e-01,  6.9730e-01, -8.6626e-01, -6.5007e-01,
-        -3.2225e-01, -1.5535e-01, -5.8389e-01, -3.1289e-01, -1.3179e-01,
-         2.8526e-01,  6.9646e-02, -2.2773e-01,  8.0369e-02,  1.5170e+00,
-        -9.0445e-01,  4.7837e-02, -8.2241e-01, -2.9930e-01, -9.1316e-01,
-        -3.1202e-01,  4.8882e-01,  0.0000e+00,  3.3642e-01,  4.8637e-01,
-         4.0246e-01, -4.2253e-02,  2.0668e-02,  6.9341e-01,  5.9713e-01,
-         6.9088e-01,  3.7560e-01,  4.3223e-01,  6.7582e-01, -5.5512e-01,
-         5.2483e-01, -1.2274e-01,  4.8765e-01,  3.4926e-01,  7.3391e-01,
-         5.0577e-01,  4.3959e-02,  1.1035e+00, -1.7291e-01, -6.0415e-01,
-         1.0449e+00, -4.4390e-02, -2.4873e-01, -2.1702e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.9953e-01,  7.0912e-01, -4.3275e-01, -1.4428e-03,  2.9315e-01,
-         1.0506e-02, -5.8402e-01, -1.1115e+00, -6.0515e-01, -6.1241e-01,
-        -2.2131e-01, -3.2038e-01, -1.4874e-02,  1.9834e-01, -8.6796e-02,
-        -5.8208e-01, -7.8766e-01,  6.9730e-01, -8.6626e-01, -6.5007e-01,
-        -3.2225e-01, -1.5535e-01, -5.8389e-01, -3.1289e-01, -1.3179e-01,
-         2.8526e-01,  6.9646e-02, -2.2773e-01,  8.0369e-02,  1.5170e+00,
-        -9.0445e-01,  4.7837e-02, -8.2241e-01, -2.9930e-01, -9.1316e-01,
-        -3.1202e-01,  4.8882e-01,  0.0000e+00,  3.3642e-01,  4.8637e-01,
-         4.0246e-01, -4.2253e-02,  2.0668e-02,  6.9341e-01,  5.9713e-01,
-         6.9088e-01,  3.7560e-01,  4.3223e-01,  6.7582e-01, -5.5512e-01,
-         5.2483e-01, -1.2274e-01,  4.8765e-01,  3.4926e-01,  7.3391e-01,
-         5.0577e-01,  4.3959e-02,  1.1035e+00, -1.7291e-01, -6.0415e-01,
-         1.0449e+00, -4.4390e-02, -2.4873e-01, -2.1702e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8966,  0.7082, -0.4260, -0.0222,  0.2990,  0.0209, -0.5841, -1.1111,
-        -0.6092, -0.6098, -0.2316, -0.3238, -0.0132,  0.1833, -0.0753, -0.5849,
-        -0.7902,  0.7016, -0.8677, -0.6484, -0.3175, -0.1594, -0.5821, -0.3115,
-        -0.1319,  0.2791,  0.0780, -0.2524,  0.0773,  1.5186, -0.9038,  0.0445,
-        -0.8272, -0.3124, -0.9131, -0.3129,  0.4872,  0.0000,  0.3490,  0.4820,
-         0.4056,  0.0119,  0.0243,  0.6942,  0.6015,  0.6988,  0.3768,  0.4351,
-         0.6832, -0.5567,  0.5229, -0.1428,  0.4861,  0.3546,  0.7328,  0.5069,
-         0.0747,  1.1016, -0.1584, -0.5987,  1.0473, -0.0401, -0.2403, -0.2175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8966,  0.7082, -0.4260, -0.0222,  0.2990,  0.0209, -0.5841, -1.1111,
-        -0.6092, -0.6098, -0.2316, -0.3238, -0.0132,  0.1833, -0.0753, -0.5849,
-        -0.7902,  0.7016, -0.8677, -0.6484, -0.3175, -0.1594, -0.5821, -0.3115,
-        -0.1319,  0.2791,  0.0780, -0.2524,  0.0773,  1.5186, -0.9038,  0.0445,
-        -0.8272, -0.3124, -0.9131, -0.3129,  0.4872,  0.0000,  0.3490,  0.4820,
-         0.4056,  0.0119,  0.0243,  0.6942,  0.6015,  0.6988,  0.3768,  0.4351,
-         0.6832, -0.5567,  0.5229, -0.1428,  0.4861,  0.3546,  0.7328,  0.5069,
-         0.0747,  1.1016, -0.1584, -0.5987,  1.0473, -0.0401, -0.2403, -0.2175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8966,  0.7082, -0.4260, -0.0222,  0.2990,  0.0209, -0.5841, -1.1111,
-        -0.6092, -0.6098, -0.2316, -0.3238, -0.0132,  0.1833, -0.0753, -0.5849,
-        -0.7902,  0.7016, -0.8677, -0.6484, -0.3175, -0.1594, -0.5821, -0.3115,
-        -0.1319,  0.2791,  0.0780, -0.2524,  0.0773,  1.5186, -0.9038,  0.0445,
-        -0.8272, -0.3124, -0.9131, -0.3129,  0.4872,  0.0000,  0.3490,  0.4820,
-         0.4056,  0.0119,  0.0243,  0.6942,  0.6015,  0.6988,  0.3768,  0.4351,
-         0.6832, -0.5567,  0.5229, -0.1428,  0.4861,  0.3546,  0.7328,  0.5069,
-         0.0747,  1.1016, -0.1584, -0.5987,  1.0473, -0.0401, -0.2403, -0.2175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8921,  0.7066, -0.4191, -0.0540,  0.3021,  0.0305, -0.5838, -1.1109,
-        -0.6133, -0.6066, -0.2385, -0.3286, -0.0122,  0.1706, -0.0707, -0.5865,
-        -0.7919,  0.7057, -0.8701, -0.6469, -0.3092, -0.1595, -0.5791, -0.3110,
-        -0.1364,  0.2734,  0.0856, -0.2766,  0.0720,  1.5202, -0.9028,  0.0405,
-        -0.8321, -0.3168, -0.9132, -0.3138,  0.4866,  0.0000,  0.3643,  0.4784,
-         0.4109,  0.0609,  0.0203,  0.6947,  0.6057,  0.7060,  0.3785,  0.4384,
-         0.6923, -0.5580,  0.5187, -0.1607,  0.4832,  0.3625,  0.7316,  0.5056,
-         0.1058,  1.0999, -0.1462, -0.5906,  1.0499, -0.0356, -0.2278, -0.2146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8921,  0.7066, -0.4191, -0.0540,  0.3021,  0.0305, -0.5838, -1.1109,
-        -0.6133, -0.6066, -0.2385, -0.3286, -0.0122,  0.1706, -0.0707, -0.5865,
-        -0.7919,  0.7057, -0.8701, -0.6469, -0.3092, -0.1595, -0.5791, -0.3110,
-        -0.1364,  0.2734,  0.0856, -0.2766,  0.0720,  1.5202, -0.9028,  0.0405,
-        -0.8321, -0.3168, -0.9132, -0.3138,  0.4866,  0.0000,  0.3643,  0.4784,
-         0.4109,  0.0609,  0.0203,  0.6947,  0.6057,  0.7060,  0.3785,  0.4384,
-         0.6923, -0.5580,  0.5187, -0.1607,  0.4832,  0.3625,  0.7316,  0.5056,
-         0.1058,  1.0999, -0.1462, -0.5906,  1.0499, -0.0356, -0.2278, -0.2146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8921,  0.7066, -0.4191, -0.0540,  0.3021,  0.0305, -0.5838, -1.1109,
-        -0.6133, -0.6066, -0.2385, -0.3286, -0.0122,  0.1706, -0.0707, -0.5865,
-        -0.7919,  0.7057, -0.8701, -0.6469, -0.3092, -0.1595, -0.5791, -0.3110,
-        -0.1364,  0.2734,  0.0856, -0.2766,  0.0720,  1.5202, -0.9028,  0.0405,
-        -0.8321, -0.3168, -0.9132, -0.3138,  0.4866,  0.0000,  0.3643,  0.4784,
-         0.4109,  0.0609,  0.0203,  0.6947,  0.6057,  0.7060,  0.3785,  0.4384,
-         0.6923, -0.5580,  0.5187, -0.1607,  0.4832,  0.3625,  0.7316,  0.5056,
-         0.1058,  1.0999, -0.1462, -0.5906,  1.0499, -0.0356, -0.2278, -0.2146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.8733e-01,  7.0062e-01, -4.1233e-01, -7.7788e-02,  3.1067e-01,
-         4.2623e-02, -5.8435e-01, -1.1099e+00, -6.1735e-01, -6.0454e-01,
-        -2.4445e-01, -3.3174e-01,  1.1368e-03,  1.5474e-01, -6.9680e-02,
-        -5.9007e-01, -7.9353e-01,  7.1062e-01, -8.7218e-01, -6.4871e-01,
-        -2.9790e-01, -1.5554e-01, -5.7366e-01, -3.1225e-01, -1.3725e-01,
-         2.6970e-01,  8.2326e-02, -2.9924e-01,  5.9300e-02,  1.5219e+00,
-        -9.0248e-01,  1.8600e-02, -8.3580e-01, -3.2375e-01, -9.1369e-01,
-        -3.1447e-01,  4.8534e-01,  0.0000e+00,  3.8451e-01,  4.7385e-01,
-         4.1585e-01,  9.6695e-02,  1.9922e-02,  6.9430e-01,  6.0903e-01,
-         7.1289e-01,  3.7740e-01,  4.4350e-01,  7.0198e-01, -5.5883e-01,
-         5.1737e-01, -1.7847e-01,  4.7834e-01,  3.6765e-01,  7.2901e-01,
-         5.0385e-01,  1.3124e-01,  1.0983e+00, -1.3111e-01, -5.8349e-01,
-         1.0523e+00, -2.6574e-02, -2.0960e-01, -2.1107e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.8733e-01,  7.0062e-01, -4.1233e-01, -7.7788e-02,  3.1067e-01,
-         4.2623e-02, -5.8435e-01, -1.1099e+00, -6.1735e-01, -6.0454e-01,
-        -2.4445e-01, -3.3174e-01,  1.1368e-03,  1.5474e-01, -6.9680e-02,
-        -5.9007e-01, -7.9353e-01,  7.1062e-01, -8.7218e-01, -6.4871e-01,
-        -2.9790e-01, -1.5554e-01, -5.7366e-01, -3.1225e-01, -1.3725e-01,
-         2.6970e-01,  8.2326e-02, -2.9924e-01,  5.9300e-02,  1.5219e+00,
-        -9.0248e-01,  1.8600e-02, -8.3580e-01, -3.2375e-01, -9.1369e-01,
-        -3.1447e-01,  4.8534e-01,  0.0000e+00,  3.8451e-01,  4.7385e-01,
-         4.1585e-01,  9.6695e-02,  1.9922e-02,  6.9430e-01,  6.0903e-01,
-         7.1289e-01,  3.7740e-01,  4.4350e-01,  7.0198e-01, -5.5883e-01,
-         5.1737e-01, -1.7847e-01,  4.7834e-01,  3.6765e-01,  7.2901e-01,
-         5.0385e-01,  1.3124e-01,  1.0983e+00, -1.3111e-01, -5.8349e-01,
-         1.0523e+00, -2.6574e-02, -2.0960e-01, -2.1107e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.8733e-01,  7.0062e-01, -4.1233e-01, -7.7788e-02,  3.1067e-01,
-         4.2623e-02, -5.8435e-01, -1.1099e+00, -6.1735e-01, -6.0454e-01,
-        -2.4445e-01, -3.3174e-01,  1.1368e-03,  1.5474e-01, -6.9680e-02,
-        -5.9007e-01, -7.9353e-01,  7.1062e-01, -8.7218e-01, -6.4871e-01,
-        -2.9790e-01, -1.5554e-01, -5.7366e-01, -3.1225e-01, -1.3725e-01,
-         2.6970e-01,  8.2326e-02, -2.9924e-01,  5.9300e-02,  1.5219e+00,
-        -9.0248e-01,  1.8600e-02, -8.3580e-01, -3.2375e-01, -9.1369e-01,
-        -3.1447e-01,  4.8534e-01,  0.0000e+00,  3.8451e-01,  4.7385e-01,
-         4.1585e-01,  9.6695e-02,  1.9922e-02,  6.9430e-01,  6.0903e-01,
-         7.1289e-01,  3.7740e-01,  4.4350e-01,  7.0198e-01, -5.5883e-01,
-         5.1737e-01, -1.7847e-01,  4.7834e-01,  3.6765e-01,  7.2901e-01,
-         5.0385e-01,  1.3124e-01,  1.0983e+00, -1.3111e-01, -5.8349e-01,
-         1.0523e+00, -2.6574e-02, -2.0960e-01, -2.1107e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8829,  0.6941, -0.4042, -0.0971,  0.3114,  0.0577, -0.5844, -1.1088,
-        -0.6198, -0.5994, -0.2424, -0.3353,  0.0118,  0.1344, -0.0753, -0.5916,
-        -0.7958,  0.7153, -0.8761, -0.6528, -0.2810, -0.1473, -0.5676, -0.3071,
-        -0.1384,  0.2696,  0.0605, -0.3190,  0.0451,  1.5236, -0.9026, -0.0098,
-        -0.8394, -0.3253, -0.9138, -0.3149,  0.4830,  0.0000,  0.3991,  0.4678,
-         0.4197,  0.1129,  0.0153,  0.6925,  0.6120,  0.7189,  0.3778,  0.4490,
-         0.7110, -0.5598,  0.5173, -0.1967,  0.4752,  0.3731,  0.7269,  0.5036,
-         0.1593,  1.0955, -0.1055, -0.5749,  1.0551, -0.0143, -0.1960, -0.2098],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8829,  0.6941, -0.4042, -0.0971,  0.3114,  0.0577, -0.5844, -1.1088,
-        -0.6198, -0.5994, -0.2424, -0.3353,  0.0118,  0.1344, -0.0753, -0.5916,
-        -0.7958,  0.7153, -0.8761, -0.6528, -0.2810, -0.1473, -0.5676, -0.3071,
-        -0.1384,  0.2696,  0.0605, -0.3190,  0.0451,  1.5236, -0.9026, -0.0098,
-        -0.8394, -0.3253, -0.9138, -0.3149,  0.4830,  0.0000,  0.3991,  0.4678,
-         0.4197,  0.1129,  0.0153,  0.6925,  0.6120,  0.7189,  0.3778,  0.4490,
-         0.7110, -0.5598,  0.5173, -0.1967,  0.4752,  0.3731,  0.7269,  0.5036,
-         0.1593,  1.0955, -0.1055, -0.5749,  1.0551, -0.0143, -0.1960, -0.2098],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8829,  0.6941, -0.4042, -0.0971,  0.3114,  0.0577, -0.5844, -1.1088,
-        -0.6198, -0.5994, -0.2424, -0.3353,  0.0118,  0.1344, -0.0753, -0.5916,
-        -0.7958,  0.7153, -0.8761, -0.6528, -0.2810, -0.1473, -0.5676, -0.3071,
-        -0.1384,  0.2696,  0.0605, -0.3190,  0.0451,  1.5236, -0.9026, -0.0098,
-        -0.8394, -0.3253, -0.9138, -0.3149,  0.4830,  0.0000,  0.3991,  0.4678,
-         0.4197,  0.1129,  0.0153,  0.6925,  0.6120,  0.7189,  0.3778,  0.4490,
-         0.7110, -0.5598,  0.5173, -0.1967,  0.4752,  0.3731,  0.7269,  0.5036,
-         0.1593,  1.0955, -0.1055, -0.5749,  1.0551, -0.0143, -0.1960, -0.2098],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8779,  0.6892, -0.3961, -0.1161,  0.3117,  0.0732, -0.5847, -1.1081,
-        -0.6209, -0.5901, -0.2408, -0.3417,  0.0198,  0.1118, -0.1008, -0.5939,
-        -0.7971,  0.7210, -0.8806, -0.6565, -0.2615, -0.1348, -0.5635, -0.3020,
-        -0.1348,  0.2640,  0.0343, -0.3340,  0.0354,  1.5255, -0.9018, -0.0435,
-        -0.8422, -0.3188, -0.9144, -0.3172,  0.4829,  0.0000,  0.4122,  0.4616,
-         0.4241,  0.1365,  0.0117,  0.6883,  0.6151,  0.7234,  0.3807,  0.4526,
-         0.7203, -0.5618,  0.5181, -0.2124,  0.4740,  0.3762,  0.7257,  0.5043,
-         0.1846,  1.0937, -0.0772, -0.5641,  1.0571, -0.0026, -0.1845, -0.2138],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8779,  0.6892, -0.3961, -0.1161,  0.3117,  0.0732, -0.5847, -1.1081,
-        -0.6209, -0.5901, -0.2408, -0.3417,  0.0198,  0.1118, -0.1008, -0.5939,
-        -0.7971,  0.7210, -0.8806, -0.6565, -0.2615, -0.1348, -0.5635, -0.3020,
-        -0.1348,  0.2640,  0.0343, -0.3340,  0.0354,  1.5255, -0.9018, -0.0435,
-        -0.8422, -0.3188, -0.9144, -0.3172,  0.4829,  0.0000,  0.4122,  0.4616,
-         0.4241,  0.1365,  0.0117,  0.6883,  0.6151,  0.7234,  0.3807,  0.4526,
-         0.7203, -0.5618,  0.5181, -0.2124,  0.4740,  0.3762,  0.7257,  0.5043,
-         0.1846,  1.0937, -0.0772, -0.5641,  1.0571, -0.0026, -0.1845, -0.2138],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8779,  0.6892, -0.3961, -0.1161,  0.3117,  0.0732, -0.5847, -1.1081,
-        -0.6209, -0.5901, -0.2408, -0.3417,  0.0198,  0.1118, -0.1008, -0.5939,
-        -0.7971,  0.7210, -0.8806, -0.6565, -0.2615, -0.1348, -0.5635, -0.3020,
-        -0.1348,  0.2640,  0.0343, -0.3340,  0.0354,  1.5255, -0.9018, -0.0435,
-        -0.8422, -0.3188, -0.9144, -0.3172,  0.4829,  0.0000,  0.4122,  0.4616,
-         0.4241,  0.1365,  0.0117,  0.6883,  0.6151,  0.7234,  0.3807,  0.4526,
-         0.7203, -0.5618,  0.5181, -0.2124,  0.4740,  0.3762,  0.7257,  0.5043,
-         0.1846,  1.0937, -0.0772, -0.5641,  1.0571, -0.0026, -0.1845, -0.2138],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8739,  0.6856, -0.3867, -0.1387,  0.3066,  0.0839, -0.5852, -1.1074,
-        -0.6220, -0.5830, -0.2374, -0.3520,  0.0312,  0.0856, -0.1342, -0.5966,
-        -0.7993,  0.7242, -0.8850, -0.6611, -0.2406, -0.1247, -0.5604, -0.2984,
-        -0.1307,  0.2561,  0.0038, -0.3484,  0.0280,  1.5269, -0.8999, -0.0892,
-        -0.8443, -0.3069, -0.9151, -0.3222,  0.4856,  0.0000,  0.4274,  0.4577,
-         0.4249,  0.1660,  0.0061,  0.6843,  0.6194,  0.7249,  0.3858,  0.4515,
-         0.7284, -0.5639,  0.5183, -0.2260,  0.4725,  0.3767,  0.7249,  0.5065,
-         0.2092,  1.0912, -0.0403, -0.5509,  1.0583,  0.0183, -0.1769, -0.2185],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8739,  0.6856, -0.3867, -0.1387,  0.3066,  0.0839, -0.5852, -1.1074,
-        -0.6220, -0.5830, -0.2374, -0.3520,  0.0312,  0.0856, -0.1342, -0.5966,
-        -0.7993,  0.7242, -0.8850, -0.6611, -0.2406, -0.1247, -0.5604, -0.2984,
-        -0.1307,  0.2561,  0.0038, -0.3484,  0.0280,  1.5269, -0.8999, -0.0892,
-        -0.8443, -0.3069, -0.9151, -0.3222,  0.4856,  0.0000,  0.4274,  0.4577,
-         0.4249,  0.1660,  0.0061,  0.6843,  0.6194,  0.7249,  0.3858,  0.4515,
-         0.7284, -0.5639,  0.5183, -0.2260,  0.4725,  0.3767,  0.7249,  0.5065,
-         0.2092,  1.0912, -0.0403, -0.5509,  1.0583,  0.0183, -0.1769, -0.2185],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8739,  0.6856, -0.3867, -0.1387,  0.3066,  0.0839, -0.5852, -1.1074,
-        -0.6220, -0.5830, -0.2374, -0.3520,  0.0312,  0.0856, -0.1342, -0.5966,
-        -0.7993,  0.7242, -0.8850, -0.6611, -0.2406, -0.1247, -0.5604, -0.2984,
-        -0.1307,  0.2561,  0.0038, -0.3484,  0.0280,  1.5269, -0.8999, -0.0892,
-        -0.8443, -0.3069, -0.9151, -0.3222,  0.4856,  0.0000,  0.4274,  0.4577,
-         0.4249,  0.1660,  0.0061,  0.6843,  0.6194,  0.7249,  0.3858,  0.4515,
-         0.7284, -0.5639,  0.5183, -0.2260,  0.4725,  0.3767,  0.7249,  0.5065,
-         0.2092,  1.0912, -0.0403, -0.5509,  1.0583,  0.0183, -0.1769, -0.2185],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8680,  0.6834, -0.3778, -0.1633,  0.2926,  0.0991, -0.5862, -1.1067,
-        -0.6225, -0.5777, -0.2308, -0.3628,  0.0319,  0.0578, -0.1767, -0.6005,
-        -0.8013,  0.7279, -0.8890, -0.6664, -0.2191, -0.1067, -0.5558, -0.2917,
-        -0.1206,  0.2519, -0.0161, -0.3618,  0.0190,  1.5280, -0.8977, -0.1352,
-        -0.8464, -0.2889, -0.9153, -0.3272,  0.4932,  0.0000,  0.4445,  0.4547,
-         0.4261,  0.1892, -0.0083,  0.6790,  0.6249,  0.7243,  0.3969,  0.4589,
-         0.7368, -0.5658,  0.5168, -0.2377,  0.4727,  0.3698,  0.7230,  0.5116,
-         0.2350,  1.0883, -0.0025, -0.5301,  1.0601,  0.0364, -0.1706, -0.2170],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8680,  0.6834, -0.3778, -0.1633,  0.2926,  0.0991, -0.5862, -1.1067,
-        -0.6225, -0.5777, -0.2308, -0.3628,  0.0319,  0.0578, -0.1767, -0.6005,
-        -0.8013,  0.7279, -0.8890, -0.6664, -0.2191, -0.1067, -0.5558, -0.2917,
-        -0.1206,  0.2519, -0.0161, -0.3618,  0.0190,  1.5280, -0.8977, -0.1352,
-        -0.8464, -0.2889, -0.9153, -0.3272,  0.4932,  0.0000,  0.4445,  0.4547,
-         0.4261,  0.1892, -0.0083,  0.6790,  0.6249,  0.7243,  0.3969,  0.4589,
-         0.7368, -0.5658,  0.5168, -0.2377,  0.4727,  0.3698,  0.7230,  0.5116,
-         0.2350,  1.0883, -0.0025, -0.5301,  1.0601,  0.0364, -0.1706, -0.2170],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8680,  0.6834, -0.3778, -0.1633,  0.2926,  0.0991, -0.5862, -1.1067,
-        -0.6225, -0.5777, -0.2308, -0.3628,  0.0319,  0.0578, -0.1767, -0.6005,
-        -0.8013,  0.7279, -0.8890, -0.6664, -0.2191, -0.1067, -0.5558, -0.2917,
-        -0.1206,  0.2519, -0.0161, -0.3618,  0.0190,  1.5280, -0.8977, -0.1352,
-        -0.8464, -0.2889, -0.9153, -0.3272,  0.4932,  0.0000,  0.4445,  0.4547,
-         0.4261,  0.1892, -0.0083,  0.6790,  0.6249,  0.7243,  0.3969,  0.4589,
-         0.7368, -0.5658,  0.5168, -0.2377,  0.4727,  0.3698,  0.7230,  0.5116,
-         0.2350,  1.0883, -0.0025, -0.5301,  1.0601,  0.0364, -0.1706, -0.2170],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8616,  0.6784, -0.3682, -0.1885,  0.2735,  0.1160, -0.5898, -1.1059,
-        -0.6250, -0.5733, -0.2214, -0.3744,  0.0364,  0.0171, -0.2265, -0.6058,
-        -0.8023,  0.7294, -0.8921, -0.6732, -0.1979, -0.0874, -0.5504, -0.2777,
-        -0.1009,  0.2484, -0.0350, -0.3753, -0.0033,  1.5290, -0.8963, -0.1809,
-        -0.8487, -0.2754, -0.9149, -0.3340,  0.5018,  0.0000,  0.4606,  0.4533,
-         0.4260,  0.2003, -0.0252,  0.6753,  0.6305,  0.7240,  0.4099,  0.4680,
-         0.7452, -0.5672,  0.5151, -0.2510,  0.4737,  0.3600,  0.7202,  0.5161,
-         0.2677,  1.0844,  0.0226, -0.5053,  1.0608,  0.0481, -0.1564, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8616,  0.6784, -0.3682, -0.1885,  0.2735,  0.1160, -0.5898, -1.1059,
-        -0.6250, -0.5733, -0.2214, -0.3744,  0.0364,  0.0171, -0.2265, -0.6058,
-        -0.8023,  0.7294, -0.8921, -0.6732, -0.1979, -0.0874, -0.5504, -0.2777,
-        -0.1009,  0.2484, -0.0350, -0.3753, -0.0033,  1.5290, -0.8963, -0.1809,
-        -0.8487, -0.2754, -0.9149, -0.3340,  0.5018,  0.0000,  0.4606,  0.4533,
-         0.4260,  0.2003, -0.0252,  0.6753,  0.6305,  0.7240,  0.4099,  0.4680,
-         0.7452, -0.5672,  0.5151, -0.2510,  0.4737,  0.3600,  0.7202,  0.5161,
-         0.2677,  1.0844,  0.0226, -0.5053,  1.0608,  0.0481, -0.1564, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8616,  0.6784, -0.3682, -0.1885,  0.2735,  0.1160, -0.5898, -1.1059,
-        -0.6250, -0.5733, -0.2214, -0.3744,  0.0364,  0.0171, -0.2265, -0.6058,
-        -0.8023,  0.7294, -0.8921, -0.6732, -0.1979, -0.0874, -0.5504, -0.2777,
-        -0.1009,  0.2484, -0.0350, -0.3753, -0.0033,  1.5290, -0.8963, -0.1809,
-        -0.8487, -0.2754, -0.9149, -0.3340,  0.5018,  0.0000,  0.4606,  0.4533,
-         0.4260,  0.2003, -0.0252,  0.6753,  0.6305,  0.7240,  0.4099,  0.4680,
-         0.7452, -0.5672,  0.5151, -0.2510,  0.4737,  0.3600,  0.7202,  0.5161,
-         0.2677,  1.0844,  0.0226, -0.5053,  1.0608,  0.0481, -0.1564, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8549,  0.6751, -0.3571, -0.2131,  0.2501,  0.1341, -0.5920, -1.1055,
-        -0.6260, -0.5680, -0.2080, -0.3959,  0.0295, -0.0249, -0.2799, -0.6121,
-        -0.8053,  0.7285, -0.8933, -0.6800, -0.1791, -0.0660, -0.5444, -0.2574,
-        -0.0784,  0.2468, -0.0395, -0.3894, -0.0221,  1.5304, -0.8946, -0.2240,
-        -0.8512, -0.2574, -0.9139, -0.3364,  0.5116,  0.0000,  0.4743,  0.4500,
-         0.4231,  0.2076, -0.0407,  0.6681,  0.6368,  0.7255,  0.4250,  0.4769,
-         0.7513, -0.5687,  0.5142, -0.2660,  0.4720,  0.3462,  0.7168,  0.5217,
-         0.3055,  1.0805,  0.0521, -0.4805,  1.0611,  0.0531, -0.1400, -0.2094],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8549,  0.6751, -0.3571, -0.2131,  0.2501,  0.1341, -0.5920, -1.1055,
-        -0.6260, -0.5680, -0.2080, -0.3959,  0.0295, -0.0249, -0.2799, -0.6121,
-        -0.8053,  0.7285, -0.8933, -0.6800, -0.1791, -0.0660, -0.5444, -0.2574,
-        -0.0784,  0.2468, -0.0395, -0.3894, -0.0221,  1.5304, -0.8946, -0.2240,
-        -0.8512, -0.2574, -0.9139, -0.3364,  0.5116,  0.0000,  0.4743,  0.4500,
-         0.4231,  0.2076, -0.0407,  0.6681,  0.6368,  0.7255,  0.4250,  0.4769,
-         0.7513, -0.5687,  0.5142, -0.2660,  0.4720,  0.3462,  0.7168,  0.5217,
-         0.3055,  1.0805,  0.0521, -0.4805,  1.0611,  0.0531, -0.1400, -0.2094],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8549,  0.6751, -0.3571, -0.2131,  0.2501,  0.1341, -0.5920, -1.1055,
-        -0.6260, -0.5680, -0.2080, -0.3959,  0.0295, -0.0249, -0.2799, -0.6121,
-        -0.8053,  0.7285, -0.8933, -0.6800, -0.1791, -0.0660, -0.5444, -0.2574,
-        -0.0784,  0.2468, -0.0395, -0.3894, -0.0221,  1.5304, -0.8946, -0.2240,
-        -0.8512, -0.2574, -0.9139, -0.3364,  0.5116,  0.0000,  0.4743,  0.4500,
-         0.4231,  0.2076, -0.0407,  0.6681,  0.6368,  0.7255,  0.4250,  0.4769,
-         0.7513, -0.5687,  0.5142, -0.2660,  0.4720,  0.3462,  0.7168,  0.5217,
-         0.3055,  1.0805,  0.0521, -0.4805,  1.0611,  0.0531, -0.1400, -0.2094],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8488,  0.6713, -0.3420, -0.2311,  0.2361,  0.1492, -0.5959, -1.1047,
-        -0.6284, -0.5644, -0.2015, -0.4121,  0.0281, -0.0554, -0.3284, -0.6185,
-        -0.8091,  0.7275, -0.8953, -0.6835, -0.1505, -0.0481, -0.5416, -0.2526,
-        -0.0563,  0.2499, -0.0381, -0.4028, -0.0447,  1.5320, -0.8926, -0.2716,
-        -0.8546, -0.2521, -0.9124, -0.3346,  0.5216,  0.0000,  0.4868,  0.4423,
-         0.4171,  0.2135, -0.0556,  0.6617,  0.6424,  0.7270,  0.4396,  0.4858,
-         0.7557, -0.5690,  0.5154, -0.2789,  0.4696,  0.3254,  0.7104,  0.5248,
-         0.3405,  1.0751,  0.0709, -0.4601,  1.0613,  0.0571, -0.1228, -0.1892],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8488,  0.6713, -0.3420, -0.2311,  0.2361,  0.1492, -0.5959, -1.1047,
-        -0.6284, -0.5644, -0.2015, -0.4121,  0.0281, -0.0554, -0.3284, -0.6185,
-        -0.8091,  0.7275, -0.8953, -0.6835, -0.1505, -0.0481, -0.5416, -0.2526,
-        -0.0563,  0.2499, -0.0381, -0.4028, -0.0447,  1.5320, -0.8926, -0.2716,
-        -0.8546, -0.2521, -0.9124, -0.3346,  0.5216,  0.0000,  0.4868,  0.4423,
-         0.4171,  0.2135, -0.0556,  0.6617,  0.6424,  0.7270,  0.4396,  0.4858,
-         0.7557, -0.5690,  0.5154, -0.2789,  0.4696,  0.3254,  0.7104,  0.5248,
-         0.3405,  1.0751,  0.0709, -0.4601,  1.0613,  0.0571, -0.1228, -0.1892],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8488,  0.6713, -0.3420, -0.2311,  0.2361,  0.1492, -0.5959, -1.1047,
-        -0.6284, -0.5644, -0.2015, -0.4121,  0.0281, -0.0554, -0.3284, -0.6185,
-        -0.8091,  0.7275, -0.8953, -0.6835, -0.1505, -0.0481, -0.5416, -0.2526,
-        -0.0563,  0.2499, -0.0381, -0.4028, -0.0447,  1.5320, -0.8926, -0.2716,
-        -0.8546, -0.2521, -0.9124, -0.3346,  0.5216,  0.0000,  0.4868,  0.4423,
-         0.4171,  0.2135, -0.0556,  0.6617,  0.6424,  0.7270,  0.4396,  0.4858,
-         0.7557, -0.5690,  0.5154, -0.2789,  0.4696,  0.3254,  0.7104,  0.5248,
-         0.3405,  1.0751,  0.0709, -0.4601,  1.0613,  0.0571, -0.1228, -0.1892],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8435,  0.6673, -0.3222, -0.2449,  0.2230,  0.1679, -0.5980, -1.1045,
-        -0.6295, -0.5626, -0.1915, -0.4256,  0.0411, -0.0720, -0.3773, -0.6272,
-        -0.8136,  0.7240, -0.8970, -0.6882, -0.1252, -0.0254, -0.5410, -0.2429,
-        -0.0341,  0.2560, -0.0352, -0.4209, -0.0699,  1.5331, -0.8895, -0.3166,
-        -0.8590, -0.2454, -0.9102, -0.3352,  0.5339,  0.0000,  0.4966,  0.4346,
-         0.4086,  0.2417, -0.0645,  0.6585,  0.6485,  0.7280,  0.4546,  0.4933,
-         0.7576, -0.5691,  0.5169, -0.2875,  0.4734,  0.3007,  0.7026,  0.5277,
-         0.3785,  1.0689,  0.0888, -0.4440,  1.0613,  0.0757, -0.1090, -0.1641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8435,  0.6673, -0.3222, -0.2449,  0.2230,  0.1679, -0.5980, -1.1045,
-        -0.6295, -0.5626, -0.1915, -0.4256,  0.0411, -0.0720, -0.3773, -0.6272,
-        -0.8136,  0.7240, -0.8970, -0.6882, -0.1252, -0.0254, -0.5410, -0.2429,
-        -0.0341,  0.2560, -0.0352, -0.4209, -0.0699,  1.5331, -0.8895, -0.3166,
-        -0.8590, -0.2454, -0.9102, -0.3352,  0.5339,  0.0000,  0.4966,  0.4346,
-         0.4086,  0.2417, -0.0645,  0.6585,  0.6485,  0.7280,  0.4546,  0.4933,
-         0.7576, -0.5691,  0.5169, -0.2875,  0.4734,  0.3007,  0.7026,  0.5277,
-         0.3785,  1.0689,  0.0888, -0.4440,  1.0613,  0.0757, -0.1090, -0.1641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8435,  0.6673, -0.3222, -0.2449,  0.2230,  0.1679, -0.5980, -1.1045,
-        -0.6295, -0.5626, -0.1915, -0.4256,  0.0411, -0.0720, -0.3773, -0.6272,
-        -0.8136,  0.7240, -0.8970, -0.6882, -0.1252, -0.0254, -0.5410, -0.2429,
-        -0.0341,  0.2560, -0.0352, -0.4209, -0.0699,  1.5331, -0.8895, -0.3166,
-        -0.8590, -0.2454, -0.9102, -0.3352,  0.5339,  0.0000,  0.4966,  0.4346,
-         0.4086,  0.2417, -0.0645,  0.6585,  0.6485,  0.7280,  0.4546,  0.4933,
-         0.7576, -0.5691,  0.5169, -0.2875,  0.4734,  0.3007,  0.7026,  0.5277,
-         0.3785,  1.0689,  0.0888, -0.4440,  1.0613,  0.0757, -0.1090, -0.1641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8391,  0.6691, -0.3004, -0.2495,  0.2097,  0.1873, -0.6013, -1.1041,
-        -0.6269, -0.5639, -0.1777, -0.4454,  0.0375, -0.0682, -0.4275, -0.6342,
-        -0.8192,  0.7200, -0.8989, -0.6911, -0.1006, -0.0115, -0.5464, -0.2337,
-        -0.0160,  0.2618, -0.0305, -0.4366, -0.0957,  1.5344, -0.8857, -0.3558,
-        -0.8624, -0.2297, -0.9083, -0.3328,  0.5464,  0.0000,  0.5053,  0.4262,
-         0.3974,  0.2753, -0.0859,  0.6565,  0.6551,  0.7265,  0.4721,  0.5000,
-         0.7602, -0.5681,  0.5192, -0.2972,  0.4786,  0.2767,  0.6934,  0.5301,
-         0.4225,  1.0631,  0.1032, -0.4325,  1.0621,  0.1015, -0.1037, -0.1397],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8391,  0.6691, -0.3004, -0.2495,  0.2097,  0.1873, -0.6013, -1.1041,
-        -0.6269, -0.5639, -0.1777, -0.4454,  0.0375, -0.0682, -0.4275, -0.6342,
-        -0.8192,  0.7200, -0.8989, -0.6911, -0.1006, -0.0115, -0.5464, -0.2337,
-        -0.0160,  0.2618, -0.0305, -0.4366, -0.0957,  1.5344, -0.8857, -0.3558,
-        -0.8624, -0.2297, -0.9083, -0.3328,  0.5464,  0.0000,  0.5053,  0.4262,
-         0.3974,  0.2753, -0.0859,  0.6565,  0.6551,  0.7265,  0.4721,  0.5000,
-         0.7602, -0.5681,  0.5192, -0.2972,  0.4786,  0.2767,  0.6934,  0.5301,
-         0.4225,  1.0631,  0.1032, -0.4325,  1.0621,  0.1015, -0.1037, -0.1397],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8391,  0.6691, -0.3004, -0.2495,  0.2097,  0.1873, -0.6013, -1.1041,
-        -0.6269, -0.5639, -0.1777, -0.4454,  0.0375, -0.0682, -0.4275, -0.6342,
-        -0.8192,  0.7200, -0.8989, -0.6911, -0.1006, -0.0115, -0.5464, -0.2337,
-        -0.0160,  0.2618, -0.0305, -0.4366, -0.0957,  1.5344, -0.8857, -0.3558,
-        -0.8624, -0.2297, -0.9083, -0.3328,  0.5464,  0.0000,  0.5053,  0.4262,
-         0.3974,  0.2753, -0.0859,  0.6565,  0.6551,  0.7265,  0.4721,  0.5000,
-         0.7602, -0.5681,  0.5192, -0.2972,  0.4786,  0.2767,  0.6934,  0.5301,
-         0.4225,  1.0631,  0.1032, -0.4325,  1.0621,  0.1015, -0.1037, -0.1397],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8361,  0.6720, -0.2792, -0.2438,  0.2012,  0.2060, -0.6036, -1.1035,
-        -0.6218, -0.5627, -0.1545, -0.4718,  0.0380, -0.0572, -0.4797, -0.6399,
-        -0.8268,  0.7146, -0.9006, -0.6946, -0.0867,  0.0023, -0.5481, -0.2410,
-         0.0043,  0.2607, -0.0221, -0.4514, -0.1151,  1.5356, -0.8815, -0.3965,
-        -0.8659, -0.2050, -0.9076, -0.3322,  0.5594,  0.0000,  0.5113,  0.4191,
-         0.3891,  0.3217, -0.1103,  0.6559,  0.6601,  0.7266,  0.4872,  0.5030,
-         0.7630, -0.5683,  0.5195, -0.3081,  0.4850,  0.2481,  0.6843,  0.5312,
-         0.4574,  1.0570,  0.1117, -0.4299,  1.0638,  0.1134, -0.1064, -0.1164],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8361,  0.6720, -0.2792, -0.2438,  0.2012,  0.2060, -0.6036, -1.1035,
-        -0.6218, -0.5627, -0.1545, -0.4718,  0.0380, -0.0572, -0.4797, -0.6399,
-        -0.8268,  0.7146, -0.9006, -0.6946, -0.0867,  0.0023, -0.5481, -0.2410,
-         0.0043,  0.2607, -0.0221, -0.4514, -0.1151,  1.5356, -0.8815, -0.3965,
-        -0.8659, -0.2050, -0.9076, -0.3322,  0.5594,  0.0000,  0.5113,  0.4191,
-         0.3891,  0.3217, -0.1103,  0.6559,  0.6601,  0.7266,  0.4872,  0.5030,
-         0.7630, -0.5683,  0.5195, -0.3081,  0.4850,  0.2481,  0.6843,  0.5312,
-         0.4574,  1.0570,  0.1117, -0.4299,  1.0638,  0.1134, -0.1064, -0.1164],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8361,  0.6720, -0.2792, -0.2438,  0.2012,  0.2060, -0.6036, -1.1035,
-        -0.6218, -0.5627, -0.1545, -0.4718,  0.0380, -0.0572, -0.4797, -0.6399,
-        -0.8268,  0.7146, -0.9006, -0.6946, -0.0867,  0.0023, -0.5481, -0.2410,
-         0.0043,  0.2607, -0.0221, -0.4514, -0.1151,  1.5356, -0.8815, -0.3965,
-        -0.8659, -0.2050, -0.9076, -0.3322,  0.5594,  0.0000,  0.5113,  0.4191,
-         0.3891,  0.3217, -0.1103,  0.6559,  0.6601,  0.7266,  0.4872,  0.5030,
-         0.7630, -0.5683,  0.5195, -0.3081,  0.4850,  0.2481,  0.6843,  0.5312,
-         0.4574,  1.0570,  0.1117, -0.4299,  1.0638,  0.1134, -0.1064, -0.1164],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8342,  0.6717, -0.2580, -0.2372,  0.2124,  0.2299, -0.6076, -1.1030,
-        -0.6166, -0.5620, -0.1467, -0.4903,  0.0508, -0.0319, -0.5286, -0.6476,
-        -0.8355,  0.7117, -0.9024, -0.6978, -0.0690,  0.0206, -0.5494, -0.2538,
-         0.0121,  0.2630, -0.0340, -0.4613, -0.1419,  1.5365, -0.8791, -0.4363,
-        -0.8703, -0.1913, -0.9066, -0.3341,  0.5699,  0.0000,  0.5182,  0.4144,
-         0.3789,  0.3720, -0.1299,  0.6548,  0.6638,  0.7298,  0.5034,  0.5029,
-         0.7638, -0.5697,  0.5211, -0.3174,  0.4900,  0.2090,  0.6759,  0.5319,
-         0.4914,  1.0517,  0.1078, -0.4367,  1.0664,  0.1249, -0.0992, -0.0911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8342,  0.6717, -0.2580, -0.2372,  0.2124,  0.2299, -0.6076, -1.1030,
-        -0.6166, -0.5620, -0.1467, -0.4903,  0.0508, -0.0319, -0.5286, -0.6476,
-        -0.8355,  0.7117, -0.9024, -0.6978, -0.0690,  0.0206, -0.5494, -0.2538,
-         0.0121,  0.2630, -0.0340, -0.4613, -0.1419,  1.5365, -0.8791, -0.4363,
-        -0.8703, -0.1913, -0.9066, -0.3341,  0.5699,  0.0000,  0.5182,  0.4144,
-         0.3789,  0.3720, -0.1299,  0.6548,  0.6638,  0.7298,  0.5034,  0.5029,
-         0.7638, -0.5697,  0.5211, -0.3174,  0.4900,  0.2090,  0.6759,  0.5319,
-         0.4914,  1.0517,  0.1078, -0.4367,  1.0664,  0.1249, -0.0992, -0.0911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8342,  0.6717, -0.2580, -0.2372,  0.2124,  0.2299, -0.6076, -1.1030,
-        -0.6166, -0.5620, -0.1467, -0.4903,  0.0508, -0.0319, -0.5286, -0.6476,
-        -0.8355,  0.7117, -0.9024, -0.6978, -0.0690,  0.0206, -0.5494, -0.2538,
-         0.0121,  0.2630, -0.0340, -0.4613, -0.1419,  1.5365, -0.8791, -0.4363,
-        -0.8703, -0.1913, -0.9066, -0.3341,  0.5699,  0.0000,  0.5182,  0.4144,
-         0.3789,  0.3720, -0.1299,  0.6548,  0.6638,  0.7298,  0.5034,  0.5029,
-         0.7638, -0.5697,  0.5211, -0.3174,  0.4900,  0.2090,  0.6759,  0.5319,
-         0.4914,  1.0517,  0.1078, -0.4367,  1.0664,  0.1249, -0.0992, -0.0911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8320,  0.6782, -0.2257, -0.2280,  0.2331,  0.2634, -0.6142, -1.1017,
-        -0.6103, -0.5659, -0.1389, -0.5033,  0.0693, -0.0228, -0.5789, -0.6532,
-        -0.8451,  0.7116, -0.9060, -0.6979, -0.0542,  0.0530, -0.5582, -0.2583,
-         0.0170,  0.2651, -0.0416, -0.4759, -0.1735,  1.5370, -0.8771, -0.4708,
-        -0.8720, -0.1885, -0.9062, -0.3400,  0.5778,  0.0000,  0.5220,  0.4084,
-         0.3728,  0.4303, -0.1370,  0.6557,  0.6660,  0.7345,  0.5201,  0.4988,
-         0.7588, -0.5712,  0.5153, -0.3158,  0.5050,  0.1578,  0.6675,  0.5338,
-         0.5305,  1.0481,  0.1046, -0.4353,  1.0706,  0.1359, -0.0899, -0.0652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8320,  0.6782, -0.2257, -0.2280,  0.2331,  0.2634, -0.6142, -1.1017,
-        -0.6103, -0.5659, -0.1389, -0.5033,  0.0693, -0.0228, -0.5789, -0.6532,
-        -0.8451,  0.7116, -0.9060, -0.6979, -0.0542,  0.0530, -0.5582, -0.2583,
-         0.0170,  0.2651, -0.0416, -0.4759, -0.1735,  1.5370, -0.8771, -0.4708,
-        -0.8720, -0.1885, -0.9062, -0.3400,  0.5778,  0.0000,  0.5220,  0.4084,
-         0.3728,  0.4303, -0.1370,  0.6557,  0.6660,  0.7345,  0.5201,  0.4988,
-         0.7588, -0.5712,  0.5153, -0.3158,  0.5050,  0.1578,  0.6675,  0.5338,
-         0.5305,  1.0481,  0.1046, -0.4353,  1.0706,  0.1359, -0.0899, -0.0652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8320,  0.6782, -0.2257, -0.2280,  0.2331,  0.2634, -0.6142, -1.1017,
-        -0.6103, -0.5659, -0.1389, -0.5033,  0.0693, -0.0228, -0.5789, -0.6532,
-        -0.8451,  0.7116, -0.9060, -0.6979, -0.0542,  0.0530, -0.5582, -0.2583,
-         0.0170,  0.2651, -0.0416, -0.4759, -0.1735,  1.5370, -0.8771, -0.4708,
-        -0.8720, -0.1885, -0.9062, -0.3400,  0.5778,  0.0000,  0.5220,  0.4084,
-         0.3728,  0.4303, -0.1370,  0.6557,  0.6660,  0.7345,  0.5201,  0.4988,
-         0.7588, -0.5712,  0.5153, -0.3158,  0.5050,  0.1578,  0.6675,  0.5338,
-         0.5305,  1.0481,  0.1046, -0.4353,  1.0706,  0.1359, -0.0899, -0.0652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8328,  0.6827, -0.2034, -0.2186,  0.2607,  0.3025, -0.6211, -1.1005,
-        -0.6012, -0.5698, -0.1202, -0.5168,  0.0740, -0.0193, -0.6270, -0.6569,
-        -0.8544,  0.7149, -0.9099, -0.6954, -0.0441,  0.0788, -0.5598, -0.2495,
-         0.0209,  0.2696, -0.0472, -0.4878, -0.2115,  1.5370, -0.8757, -0.5042,
-        -0.8744, -0.1809, -0.9056, -0.3410,  0.5850,  0.0000,  0.5241,  0.3953,
-         0.3655,  0.4912, -0.1603,  0.6543,  0.6681,  0.7375,  0.5376,  0.4919,
-         0.7544, -0.5737,  0.5116, -0.3163,  0.5206,  0.1029,  0.6601,  0.5383,
-         0.5704,  1.0450,  0.0972, -0.4387,  1.0750,  0.1367, -0.0733, -0.0433],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8328,  0.6827, -0.2034, -0.2186,  0.2607,  0.3025, -0.6211, -1.1005,
-        -0.6012, -0.5698, -0.1202, -0.5168,  0.0740, -0.0193, -0.6270, -0.6569,
-        -0.8544,  0.7149, -0.9099, -0.6954, -0.0441,  0.0788, -0.5598, -0.2495,
-         0.0209,  0.2696, -0.0472, -0.4878, -0.2115,  1.5370, -0.8757, -0.5042,
-        -0.8744, -0.1809, -0.9056, -0.3410,  0.5850,  0.0000,  0.5241,  0.3953,
-         0.3655,  0.4912, -0.1603,  0.6543,  0.6681,  0.7375,  0.5376,  0.4919,
-         0.7544, -0.5737,  0.5116, -0.3163,  0.5206,  0.1029,  0.6601,  0.5383,
-         0.5704,  1.0450,  0.0972, -0.4387,  1.0750,  0.1367, -0.0733, -0.0433],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8328,  0.6827, -0.2034, -0.2186,  0.2607,  0.3025, -0.6211, -1.1005,
-        -0.6012, -0.5698, -0.1202, -0.5168,  0.0740, -0.0193, -0.6270, -0.6569,
-        -0.8544,  0.7149, -0.9099, -0.6954, -0.0441,  0.0788, -0.5598, -0.2495,
-         0.0209,  0.2696, -0.0472, -0.4878, -0.2115,  1.5370, -0.8757, -0.5042,
-        -0.8744, -0.1809, -0.9056, -0.3410,  0.5850,  0.0000,  0.5241,  0.3953,
-         0.3655,  0.4912, -0.1603,  0.6543,  0.6681,  0.7375,  0.5376,  0.4919,
-         0.7544, -0.5737,  0.5116, -0.3163,  0.5206,  0.1029,  0.6601,  0.5383,
-         0.5704,  1.0450,  0.0972, -0.4387,  1.0750,  0.1367, -0.0733, -0.0433],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8351,  0.6848, -0.1897, -0.2048,  0.2838,  0.3446, -0.6252, -1.0997,
-        -0.5887, -0.5760, -0.0915, -0.5324,  0.0763, -0.0127, -0.6672, -0.6593,
-        -0.8647,  0.7176, -0.9129, -0.6936, -0.0493,  0.0977, -0.5597, -0.2150,
-         0.0413,  0.2816, -0.0432, -0.5008, -0.2447,  1.5366, -0.8754, -0.5314,
-        -0.8746, -0.1532, -0.9049, -0.3450,  0.5902,  0.0000,  0.5304,  0.3827,
-         0.3599,  0.5477, -0.1978,  0.6502,  0.6690,  0.7393,  0.5509,  0.4826,
-         0.7515, -0.5756,  0.5046, -0.3172,  0.5358,  0.0467,  0.6544,  0.5398,
-         0.6094,  1.0428,  0.0791, -0.4364,  1.0782,  0.1284, -0.0567, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8351,  0.6848, -0.1897, -0.2048,  0.2838,  0.3446, -0.6252, -1.0997,
-        -0.5887, -0.5760, -0.0915, -0.5324,  0.0763, -0.0127, -0.6672, -0.6593,
-        -0.8647,  0.7176, -0.9129, -0.6936, -0.0493,  0.0977, -0.5597, -0.2150,
-         0.0413,  0.2816, -0.0432, -0.5008, -0.2447,  1.5366, -0.8754, -0.5314,
-        -0.8746, -0.1532, -0.9049, -0.3450,  0.5902,  0.0000,  0.5304,  0.3827,
-         0.3599,  0.5477, -0.1978,  0.6502,  0.6690,  0.7393,  0.5509,  0.4826,
-         0.7515, -0.5756,  0.5046, -0.3172,  0.5358,  0.0467,  0.6544,  0.5398,
-         0.6094,  1.0428,  0.0791, -0.4364,  1.0782,  0.1284, -0.0567, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8351,  0.6848, -0.1897, -0.2048,  0.2838,  0.3446, -0.6252, -1.0997,
-        -0.5887, -0.5760, -0.0915, -0.5324,  0.0763, -0.0127, -0.6672, -0.6593,
-        -0.8647,  0.7176, -0.9129, -0.6936, -0.0493,  0.0977, -0.5597, -0.2150,
-         0.0413,  0.2816, -0.0432, -0.5008, -0.2447,  1.5366, -0.8754, -0.5314,
-        -0.8746, -0.1532, -0.9049, -0.3450,  0.5902,  0.0000,  0.5304,  0.3827,
-         0.3599,  0.5477, -0.1978,  0.6502,  0.6690,  0.7393,  0.5509,  0.4826,
-         0.7515, -0.5756,  0.5046, -0.3172,  0.5358,  0.0467,  0.6544,  0.5398,
-         0.6094,  1.0428,  0.0791, -0.4364,  1.0782,  0.1284, -0.0567, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8425,  0.6875, -0.1784, -0.1889,  0.3025,  0.3909, -0.6248, -1.0981,
-        -0.5758, -0.5842, -0.0779, -0.5512,  0.0906,  0.0071, -0.7033, -0.6596,
-        -0.8756,  0.7186, -0.9150, -0.6917, -0.0557,  0.0995, -0.5576, -0.1973,
-         0.0740,  0.3121, -0.0289, -0.5090, -0.2735,  1.5361, -0.8731, -0.5574,
-        -0.8751, -0.1274, -0.9027, -0.3456,  0.5930,  0.0000,  0.5420,  0.3823,
-         0.3565,  0.6008, -0.2406,  0.6478,  0.6726,  0.7348,  0.5655,  0.4840,
-         0.7448, -0.5772,  0.4974, -0.3192,  0.5496,  0.0103,  0.6487,  0.5401,
-         0.6463,  1.0404,  0.0565, -0.4403,  1.0826,  0.0867, -0.0219, -0.0071],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8425,  0.6875, -0.1784, -0.1889,  0.3025,  0.3909, -0.6248, -1.0981,
-        -0.5758, -0.5842, -0.0779, -0.5512,  0.0906,  0.0071, -0.7033, -0.6596,
-        -0.8756,  0.7186, -0.9150, -0.6917, -0.0557,  0.0995, -0.5576, -0.1973,
-         0.0740,  0.3121, -0.0289, -0.5090, -0.2735,  1.5361, -0.8731, -0.5574,
-        -0.8751, -0.1274, -0.9027, -0.3456,  0.5930,  0.0000,  0.5420,  0.3823,
-         0.3565,  0.6008, -0.2406,  0.6478,  0.6726,  0.7348,  0.5655,  0.4840,
-         0.7448, -0.5772,  0.4974, -0.3192,  0.5496,  0.0103,  0.6487,  0.5401,
-         0.6463,  1.0404,  0.0565, -0.4403,  1.0826,  0.0867, -0.0219, -0.0071],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8425,  0.6875, -0.1784, -0.1889,  0.3025,  0.3909, -0.6248, -1.0981,
-        -0.5758, -0.5842, -0.0779, -0.5512,  0.0906,  0.0071, -0.7033, -0.6596,
-        -0.8756,  0.7186, -0.9150, -0.6917, -0.0557,  0.0995, -0.5576, -0.1973,
-         0.0740,  0.3121, -0.0289, -0.5090, -0.2735,  1.5361, -0.8731, -0.5574,
-        -0.8751, -0.1274, -0.9027, -0.3456,  0.5930,  0.0000,  0.5420,  0.3823,
-         0.3565,  0.6008, -0.2406,  0.6478,  0.6726,  0.7348,  0.5655,  0.4840,
-         0.7448, -0.5772,  0.4974, -0.3192,  0.5496,  0.0103,  0.6487,  0.5401,
-         0.6463,  1.0404,  0.0565, -0.4403,  1.0826,  0.0867, -0.0219, -0.0071],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8496,  0.6888, -0.1583, -0.1660,  0.3228,  0.4375, -0.6218, -1.0961,
-        -0.5621, -0.5938, -0.0463, -0.5668,  0.1274,  0.0353, -0.7413, -0.6554,
-        -0.8861,  0.7168, -0.9164, -0.6914, -0.0501,  0.0992, -0.5501, -0.1789,
-         0.1146,  0.3535, -0.0037, -0.5176, -0.2938,  1.5358, -0.8712, -0.5825,
-        -0.8759, -0.0872, -0.9012, -0.3440,  0.5928,  0.0000,  0.5583,  0.3760,
-         0.3504,  0.6540, -0.2843,  0.6521,  0.6763,  0.7279,  0.5770,  0.4851,
-         0.7340, -0.5807,  0.4902, -0.3197,  0.5594, -0.0185,  0.6470,  0.5375,
-         0.6816,  1.0392,  0.0452, -0.4440,  1.0880,  0.0386,  0.0155,  0.0200],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8496,  0.6888, -0.1583, -0.1660,  0.3228,  0.4375, -0.6218, -1.0961,
-        -0.5621, -0.5938, -0.0463, -0.5668,  0.1274,  0.0353, -0.7413, -0.6554,
-        -0.8861,  0.7168, -0.9164, -0.6914, -0.0501,  0.0992, -0.5501, -0.1789,
-         0.1146,  0.3535, -0.0037, -0.5176, -0.2938,  1.5358, -0.8712, -0.5825,
-        -0.8759, -0.0872, -0.9012, -0.3440,  0.5928,  0.0000,  0.5583,  0.3760,
-         0.3504,  0.6540, -0.2843,  0.6521,  0.6763,  0.7279,  0.5770,  0.4851,
-         0.7340, -0.5807,  0.4902, -0.3197,  0.5594, -0.0185,  0.6470,  0.5375,
-         0.6816,  1.0392,  0.0452, -0.4440,  1.0880,  0.0386,  0.0155,  0.0200],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8496,  0.6888, -0.1583, -0.1660,  0.3228,  0.4375, -0.6218, -1.0961,
-        -0.5621, -0.5938, -0.0463, -0.5668,  0.1274,  0.0353, -0.7413, -0.6554,
-        -0.8861,  0.7168, -0.9164, -0.6914, -0.0501,  0.0992, -0.5501, -0.1789,
-         0.1146,  0.3535, -0.0037, -0.5176, -0.2938,  1.5358, -0.8712, -0.5825,
-        -0.8759, -0.0872, -0.9012, -0.3440,  0.5928,  0.0000,  0.5583,  0.3760,
-         0.3504,  0.6540, -0.2843,  0.6521,  0.6763,  0.7279,  0.5770,  0.4851,
-         0.7340, -0.5807,  0.4902, -0.3197,  0.5594, -0.0185,  0.6470,  0.5375,
-         0.6816,  1.0392,  0.0452, -0.4440,  1.0880,  0.0386,  0.0155,  0.0200],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8540,  0.6826, -0.1418, -0.1293,  0.3483,  0.4703, -0.6192, -1.0949,
-        -0.5517, -0.5965, -0.0411, -0.5794,  0.1471,  0.0631, -0.7827, -0.6545,
-        -0.8879,  0.7126, -0.9167, -0.6927, -0.0452,  0.1059, -0.5423, -0.1787,
-         0.1443,  0.3862,  0.0142, -0.5258, -0.3166,  1.5357, -0.8689, -0.6053,
-        -0.8779, -0.0517, -0.9016, -0.3421,  0.5954,  0.0000,  0.5813,  0.3591,
-         0.3447,  0.7102, -0.3187,  0.6618,  0.6797,  0.7215,  0.5902,  0.4847,
-         0.7254, -0.5833,  0.4949, -0.3267,  0.5613, -0.0523,  0.6475,  0.5382,
-         0.7158,  1.0384,  0.0236, -0.4734,  1.0927, -0.0100,  0.0488,  0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8540,  0.6826, -0.1418, -0.1293,  0.3483,  0.4703, -0.6192, -1.0949,
-        -0.5517, -0.5965, -0.0411, -0.5794,  0.1471,  0.0631, -0.7827, -0.6545,
-        -0.8879,  0.7126, -0.9167, -0.6927, -0.0452,  0.1059, -0.5423, -0.1787,
-         0.1443,  0.3862,  0.0142, -0.5258, -0.3166,  1.5357, -0.8689, -0.6053,
-        -0.8779, -0.0517, -0.9016, -0.3421,  0.5954,  0.0000,  0.5813,  0.3591,
-         0.3447,  0.7102, -0.3187,  0.6618,  0.6797,  0.7215,  0.5902,  0.4847,
-         0.7254, -0.5833,  0.4949, -0.3267,  0.5613, -0.0523,  0.6475,  0.5382,
-         0.7158,  1.0384,  0.0236, -0.4734,  1.0927, -0.0100,  0.0488,  0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8540,  0.6826, -0.1418, -0.1293,  0.3483,  0.4703, -0.6192, -1.0949,
-        -0.5517, -0.5965, -0.0411, -0.5794,  0.1471,  0.0631, -0.7827, -0.6545,
-        -0.8879,  0.7126, -0.9167, -0.6927, -0.0452,  0.1059, -0.5423, -0.1787,
-         0.1443,  0.3862,  0.0142, -0.5258, -0.3166,  1.5357, -0.8689, -0.6053,
-        -0.8779, -0.0517, -0.9016, -0.3421,  0.5954,  0.0000,  0.5813,  0.3591,
-         0.3447,  0.7102, -0.3187,  0.6618,  0.6797,  0.7215,  0.5902,  0.4847,
-         0.7254, -0.5833,  0.4949, -0.3267,  0.5613, -0.0523,  0.6475,  0.5382,
-         0.7158,  1.0384,  0.0236, -0.4734,  1.0927, -0.0100,  0.0488,  0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8557,  0.6765, -0.1242, -0.0930,  0.3557,  0.4981, -0.6172, -1.0929,
-        -0.5442, -0.5942, -0.0296, -0.5883,  0.1710,  0.0779, -0.8210, -0.6484,
-        -0.8914,  0.7124, -0.9192, -0.6930, -0.0256,  0.1264, -0.5343, -0.1782,
-         0.1817,  0.3988,  0.0091, -0.5358, -0.3322,  1.5361, -0.8679, -0.6253,
-        -0.8805, -0.0115, -0.9016, -0.3420,  0.5969,  0.0000,  0.5990,  0.3315,
-         0.3420,  0.7627, -0.3453,  0.6644,  0.6847,  0.7174,  0.6001,  0.4729,
-         0.7136, -0.5859,  0.4988, -0.3285,  0.5610, -0.0685,  0.6426,  0.5384,
-         0.7475,  1.0379, -0.0088, -0.4867,  1.0961, -0.0454,  0.0844,  0.0482],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8557,  0.6765, -0.1242, -0.0930,  0.3557,  0.4981, -0.6172, -1.0929,
-        -0.5442, -0.5942, -0.0296, -0.5883,  0.1710,  0.0779, -0.8210, -0.6484,
-        -0.8914,  0.7124, -0.9192, -0.6930, -0.0256,  0.1264, -0.5343, -0.1782,
-         0.1817,  0.3988,  0.0091, -0.5358, -0.3322,  1.5361, -0.8679, -0.6253,
-        -0.8805, -0.0115, -0.9016, -0.3420,  0.5969,  0.0000,  0.5990,  0.3315,
-         0.3420,  0.7627, -0.3453,  0.6644,  0.6847,  0.7174,  0.6001,  0.4729,
-         0.7136, -0.5859,  0.4988, -0.3285,  0.5610, -0.0685,  0.6426,  0.5384,
-         0.7475,  1.0379, -0.0088, -0.4867,  1.0961, -0.0454,  0.0844,  0.0482],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8557,  0.6765, -0.1242, -0.0930,  0.3557,  0.4981, -0.6172, -1.0929,
-        -0.5442, -0.5942, -0.0296, -0.5883,  0.1710,  0.0779, -0.8210, -0.6484,
-        -0.8914,  0.7124, -0.9192, -0.6930, -0.0256,  0.1264, -0.5343, -0.1782,
-         0.1817,  0.3988,  0.0091, -0.5358, -0.3322,  1.5361, -0.8679, -0.6253,
-        -0.8805, -0.0115, -0.9016, -0.3420,  0.5969,  0.0000,  0.5990,  0.3315,
-         0.3420,  0.7627, -0.3453,  0.6644,  0.6847,  0.7174,  0.6001,  0.4729,
-         0.7136, -0.5859,  0.4988, -0.3285,  0.5610, -0.0685,  0.6426,  0.5384,
-         0.7475,  1.0379, -0.0088, -0.4867,  1.0961, -0.0454,  0.0844,  0.0482],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.5266e-01,  6.7580e-01, -9.5389e-02, -5.8647e-02,  3.4499e-01,
-         5.2972e-01, -6.1499e-01, -1.0908e+00, -5.3981e-01, -6.0411e-01,
-        -1.2679e-02, -5.8782e-01,  1.8367e-01,  1.0575e-01, -8.5817e-01,
-        -6.4225e-01, -9.0060e-01,  7.1460e-01, -9.2264e-01, -6.9305e-01,
-        -5.1725e-04,  1.5005e-01, -5.3724e-01, -1.5387e-01,  2.2525e-01,
-         4.1215e-01, -2.4788e-02, -5.4276e-01, -3.3418e-01,  1.5373e+00,
-        -8.6478e-01, -6.4354e-01, -8.8071e-01,  4.0049e-02, -9.0071e-01,
-        -3.4485e-01,  5.9956e-01,  0.0000e+00,  6.2284e-01,  3.1167e-01,
-         3.4139e-01,  8.0222e-01, -3.4693e-01,  6.6650e-01,  6.8761e-01,
-         7.1025e-01,  6.0540e-01,  4.5924e-01,  7.0868e-01, -5.8703e-01,
-         4.9695e-01, -3.2038e-01,  5.6034e-01, -8.1232e-02,  6.3611e-01,
-         5.4369e-01,  7.7647e-01,  1.0382e+00, -3.0172e-02, -4.6999e-01,
-         1.1006e+00, -6.4157e-02,  9.5164e-02,  3.3943e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.5266e-01,  6.7580e-01, -9.5389e-02, -5.8647e-02,  3.4499e-01,
-         5.2972e-01, -6.1499e-01, -1.0908e+00, -5.3981e-01, -6.0411e-01,
-        -1.2679e-02, -5.8782e-01,  1.8367e-01,  1.0575e-01, -8.5817e-01,
-        -6.4225e-01, -9.0060e-01,  7.1460e-01, -9.2264e-01, -6.9305e-01,
-        -5.1725e-04,  1.5005e-01, -5.3724e-01, -1.5387e-01,  2.2525e-01,
-         4.1215e-01, -2.4788e-02, -5.4276e-01, -3.3418e-01,  1.5373e+00,
-        -8.6478e-01, -6.4354e-01, -8.8071e-01,  4.0049e-02, -9.0071e-01,
-        -3.4485e-01,  5.9956e-01,  0.0000e+00,  6.2284e-01,  3.1167e-01,
-         3.4139e-01,  8.0222e-01, -3.4693e-01,  6.6650e-01,  6.8761e-01,
-         7.1025e-01,  6.0540e-01,  4.5924e-01,  7.0868e-01, -5.8703e-01,
-         4.9695e-01, -3.2038e-01,  5.6034e-01, -8.1232e-02,  6.3611e-01,
-         5.4369e-01,  7.7647e-01,  1.0382e+00, -3.0172e-02, -4.6999e-01,
-         1.1006e+00, -6.4157e-02,  9.5164e-02,  3.3943e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.5266e-01,  6.7580e-01, -9.5389e-02, -5.8647e-02,  3.4499e-01,
-         5.2972e-01, -6.1499e-01, -1.0908e+00, -5.3981e-01, -6.0411e-01,
-        -1.2679e-02, -5.8782e-01,  1.8367e-01,  1.0575e-01, -8.5817e-01,
-        -6.4225e-01, -9.0060e-01,  7.1460e-01, -9.2264e-01, -6.9305e-01,
-        -5.1725e-04,  1.5005e-01, -5.3724e-01, -1.5387e-01,  2.2525e-01,
-         4.1215e-01, -2.4788e-02, -5.4276e-01, -3.3418e-01,  1.5373e+00,
-        -8.6478e-01, -6.4354e-01, -8.8071e-01,  4.0049e-02, -9.0071e-01,
-        -3.4485e-01,  5.9956e-01,  0.0000e+00,  6.2284e-01,  3.1167e-01,
-         3.4139e-01,  8.0222e-01, -3.4693e-01,  6.6650e-01,  6.8761e-01,
-         7.1025e-01,  6.0540e-01,  4.5924e-01,  7.0868e-01, -5.8703e-01,
-         4.9695e-01, -3.2038e-01,  5.6034e-01, -8.1232e-02,  6.3611e-01,
-         5.4369e-01,  7.7647e-01,  1.0382e+00, -3.0172e-02, -4.6999e-01,
-         1.1006e+00, -6.4157e-02,  9.5164e-02,  3.3943e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8481,  0.6777, -0.0667, -0.0233,  0.3218,  0.5614, -0.6137, -1.0897,
-        -0.5371, -0.6205,  0.0116, -0.5840,  0.1812,  0.1305, -0.8841, -0.6349,
-        -0.9096,  0.7181, -0.9248, -0.6915,  0.0243,  0.1978, -0.5450, -0.1431,
-         0.2955,  0.4249, -0.0832, -0.5469, -0.3278,  1.5390, -0.8622, -0.6591,
-        -0.8821,  0.0986, -0.8981, -0.3505,  0.5981,  0.0000,  0.6303,  0.2840,
-         0.3475,  0.8376, -0.3497,  0.6621,  0.6887,  0.7069,  0.6090,  0.4262,
-         0.7096, -0.5890,  0.4958, -0.3156,  0.5619, -0.1015,  0.6207,  0.5497,
-         0.8033,  1.0353, -0.0656, -0.4447,  1.1025, -0.0874,  0.1014, -0.0045],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8481,  0.6777, -0.0667, -0.0233,  0.3218,  0.5614, -0.6137, -1.0897,
-        -0.5371, -0.6205,  0.0116, -0.5840,  0.1812,  0.1305, -0.8841, -0.6349,
-        -0.9096,  0.7181, -0.9248, -0.6915,  0.0243,  0.1978, -0.5450, -0.1431,
-         0.2955,  0.4249, -0.0832, -0.5469, -0.3278,  1.5390, -0.8622, -0.6591,
-        -0.8821,  0.0986, -0.8981, -0.3505,  0.5981,  0.0000,  0.6303,  0.2840,
-         0.3475,  0.8376, -0.3497,  0.6621,  0.6887,  0.7069,  0.6090,  0.4262,
-         0.7096, -0.5890,  0.4958, -0.3156,  0.5619, -0.1015,  0.6207,  0.5497,
-         0.8033,  1.0353, -0.0656, -0.4447,  1.1025, -0.0874,  0.1014, -0.0045],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8481,  0.6777, -0.0667, -0.0233,  0.3218,  0.5614, -0.6137, -1.0897,
-        -0.5371, -0.6205,  0.0116, -0.5840,  0.1812,  0.1305, -0.8841, -0.6349,
-        -0.9096,  0.7181, -0.9248, -0.6915,  0.0243,  0.1978, -0.5450, -0.1431,
-         0.2955,  0.4249, -0.0832, -0.5469, -0.3278,  1.5390, -0.8622, -0.6591,
-        -0.8821,  0.0986, -0.8981, -0.3505,  0.5981,  0.0000,  0.6303,  0.2840,
-         0.3475,  0.8376, -0.3497,  0.6621,  0.6887,  0.7069,  0.6090,  0.4262,
-         0.7096, -0.5890,  0.4958, -0.3156,  0.5619, -0.1015,  0.6207,  0.5497,
-         0.8033,  1.0353, -0.0656, -0.4447,  1.1025, -0.0874,  0.1014, -0.0045],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8453,  0.6709, -0.0316,  0.0031,  0.3022,  0.5885, -0.6127, -1.0888,
-        -0.5377, -0.6351,  0.0136, -0.5838,  0.1487,  0.1609, -0.9119, -0.6284,
-        -0.9182,  0.7209, -0.9274, -0.6892,  0.0227,  0.2179, -0.5582, -0.1325,
-         0.3503,  0.4398, -0.1383, -0.5490, -0.3169,  1.5414, -0.8586, -0.6740,
-        -0.8825,  0.1398, -0.8967, -0.3510,  0.5993,  0.0000,  0.6398,  0.2602,
-         0.3466,  0.8698, -0.3477,  0.6601,  0.6898,  0.7015,  0.6142,  0.4043,
-         0.7144, -0.5898,  0.4980, -0.3139,  0.5599, -0.1175,  0.6093,  0.5526,
-         0.8292,  1.0344, -0.0941, -0.4420,  1.1051, -0.1116,  0.0880, -0.0505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8453,  0.6709, -0.0316,  0.0031,  0.3022,  0.5885, -0.6127, -1.0888,
-        -0.5377, -0.6351,  0.0136, -0.5838,  0.1487,  0.1609, -0.9119, -0.6284,
-        -0.9182,  0.7209, -0.9274, -0.6892,  0.0227,  0.2179, -0.5582, -0.1325,
-         0.3503,  0.4398, -0.1383, -0.5490, -0.3169,  1.5414, -0.8586, -0.6740,
-        -0.8825,  0.1398, -0.8967, -0.3510,  0.5993,  0.0000,  0.6398,  0.2602,
-         0.3466,  0.8698, -0.3477,  0.6601,  0.6898,  0.7015,  0.6142,  0.4043,
-         0.7144, -0.5898,  0.4980, -0.3139,  0.5599, -0.1175,  0.6093,  0.5526,
-         0.8292,  1.0344, -0.0941, -0.4420,  1.1051, -0.1116,  0.0880, -0.0505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8453,  0.6709, -0.0316,  0.0031,  0.3022,  0.5885, -0.6127, -1.0888,
-        -0.5377, -0.6351,  0.0136, -0.5838,  0.1487,  0.1609, -0.9119, -0.6284,
-        -0.9182,  0.7209, -0.9274, -0.6892,  0.0227,  0.2179, -0.5582, -0.1325,
-         0.3503,  0.4398, -0.1383, -0.5490, -0.3169,  1.5414, -0.8586, -0.6740,
-        -0.8825,  0.1398, -0.8967, -0.3510,  0.5993,  0.0000,  0.6398,  0.2602,
-         0.3466,  0.8698, -0.3477,  0.6601,  0.6898,  0.7015,  0.6142,  0.4043,
-         0.7144, -0.5898,  0.4980, -0.3139,  0.5599, -0.1175,  0.6093,  0.5526,
-         0.8292,  1.0344, -0.0941, -0.4420,  1.1051, -0.1116,  0.0880, -0.0505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.4785e-01,  6.6183e-01,  1.8025e-03,  1.4265e-03,  3.0244e-01,
-         6.0873e-01, -6.0872e-01, -1.0883e+00, -5.4400e-01, -6.4973e-01,
-        -1.6679e-03, -5.7666e-01,  1.2513e-01,  2.0507e-01, -9.3640e-01,
-        -6.2306e-01, -9.2391e-01,  7.2304e-01, -9.2803e-01, -6.8419e-01,
-         9.0225e-03,  2.0134e-01, -5.7122e-01, -1.1008e-01,  3.9814e-01,
-         4.5700e-01, -1.6610e-01, -5.4718e-01, -3.0446e-01,  1.5438e+00,
-        -8.5262e-01, -6.8574e-01, -8.8078e-01,  1.5274e-01, -8.9266e-01,
-        -3.4508e-01,  6.0414e-01,  0.0000e+00,  6.4873e-01,  2.4528e-01,
-         3.4118e-01,  9.0364e-01, -3.5487e-01,  6.5901e-01,  6.8997e-01,
-         6.9338e-01,  6.2017e-01,  3.8932e-01,  7.1688e-01, -5.9111e-01,
-         4.9627e-01, -3.2175e-01,  5.6234e-01, -1.3454e-01,  5.9622e-01,
-         5.5684e-01,  8.5444e-01,  1.0331e+00, -1.2141e-01, -4.6627e-01,
-         1.1095e+00, -1.4739e-01,  7.7974e-02, -1.0855e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.4785e-01,  6.6183e-01,  1.8025e-03,  1.4265e-03,  3.0244e-01,
-         6.0873e-01, -6.0872e-01, -1.0883e+00, -5.4400e-01, -6.4973e-01,
-        -1.6679e-03, -5.7666e-01,  1.2513e-01,  2.0507e-01, -9.3640e-01,
-        -6.2306e-01, -9.2391e-01,  7.2304e-01, -9.2803e-01, -6.8419e-01,
-         9.0225e-03,  2.0134e-01, -5.7122e-01, -1.1008e-01,  3.9814e-01,
-         4.5700e-01, -1.6610e-01, -5.4718e-01, -3.0446e-01,  1.5438e+00,
-        -8.5262e-01, -6.8574e-01, -8.8078e-01,  1.5274e-01, -8.9266e-01,
-        -3.4508e-01,  6.0414e-01,  0.0000e+00,  6.4873e-01,  2.4528e-01,
-         3.4118e-01,  9.0364e-01, -3.5487e-01,  6.5901e-01,  6.8997e-01,
-         6.9338e-01,  6.2017e-01,  3.8932e-01,  7.1688e-01, -5.9111e-01,
-         4.9627e-01, -3.2175e-01,  5.6234e-01, -1.3454e-01,  5.9622e-01,
-         5.5684e-01,  8.5444e-01,  1.0331e+00, -1.2141e-01, -4.6627e-01,
-         1.1095e+00, -1.4739e-01,  7.7974e-02, -1.0855e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.4785e-01,  6.6183e-01,  1.8025e-03,  1.4265e-03,  3.0244e-01,
-         6.0873e-01, -6.0872e-01, -1.0883e+00, -5.4400e-01, -6.4973e-01,
-        -1.6679e-03, -5.7666e-01,  1.2513e-01,  2.0507e-01, -9.3640e-01,
-        -6.2306e-01, -9.2391e-01,  7.2304e-01, -9.2803e-01, -6.8419e-01,
-         9.0225e-03,  2.0134e-01, -5.7122e-01, -1.1008e-01,  3.9814e-01,
-         4.5700e-01, -1.6610e-01, -5.4718e-01, -3.0446e-01,  1.5438e+00,
-        -8.5262e-01, -6.8574e-01, -8.8078e-01,  1.5274e-01, -8.9266e-01,
-        -3.4508e-01,  6.0414e-01,  0.0000e+00,  6.4873e-01,  2.4528e-01,
-         3.4118e-01,  9.0364e-01, -3.5487e-01,  6.5901e-01,  6.8997e-01,
-         6.9338e-01,  6.2017e-01,  3.8932e-01,  7.1688e-01, -5.9111e-01,
-         4.9627e-01, -3.2175e-01,  5.6234e-01, -1.3454e-01,  5.9622e-01,
-         5.5684e-01,  8.5444e-01,  1.0331e+00, -1.2141e-01, -4.6627e-01,
-         1.1095e+00, -1.4739e-01,  7.7974e-02, -1.0855e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8480,  0.6467,  0.0505,  0.0129,  0.3521,  0.6176, -0.6035, -1.0859,
-        -0.5493, -0.6594, -0.0316, -0.5608,  0.1000,  0.2710, -0.9691, -0.6158,
-        -0.9280,  0.7162, -0.9277, -0.6774, -0.0052,  0.1552, -0.5815, -0.1002,
-         0.4326,  0.4630, -0.1855, -0.5458, -0.2998,  1.5459, -0.8451, -0.6956,
-        -0.8774,  0.1352, -0.8911, -0.3328,  0.6248,  0.0000,  0.6657,  0.2678,
-         0.3230,  0.9347, -0.3515,  0.6685,  0.6924,  0.6883,  0.6275,  0.3795,
-         0.7224, -0.5930,  0.4989, -0.3273,  0.5579, -0.1341,  0.5868,  0.5589,
-         0.8795,  1.0309, -0.1550, -0.5100,  1.1168, -0.1660,  0.0439, -0.1603],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8480,  0.6467,  0.0505,  0.0129,  0.3521,  0.6176, -0.6035, -1.0859,
-        -0.5493, -0.6594, -0.0316, -0.5608,  0.1000,  0.2710, -0.9691, -0.6158,
-        -0.9280,  0.7162, -0.9277, -0.6774, -0.0052,  0.1552, -0.5815, -0.1002,
-         0.4326,  0.4630, -0.1855, -0.5458, -0.2998,  1.5459, -0.8451, -0.6956,
-        -0.8774,  0.1352, -0.8911, -0.3328,  0.6248,  0.0000,  0.6657,  0.2678,
-         0.3230,  0.9347, -0.3515,  0.6685,  0.6924,  0.6883,  0.6275,  0.3795,
-         0.7224, -0.5930,  0.4989, -0.3273,  0.5579, -0.1341,  0.5868,  0.5589,
-         0.8795,  1.0309, -0.1550, -0.5100,  1.1168, -0.1660,  0.0439, -0.1603],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8480,  0.6467,  0.0505,  0.0129,  0.3521,  0.6176, -0.6035, -1.0859,
-        -0.5493, -0.6594, -0.0316, -0.5608,  0.1000,  0.2710, -0.9691, -0.6158,
-        -0.9280,  0.7162, -0.9277, -0.6774, -0.0052,  0.1552, -0.5815, -0.1002,
-         0.4326,  0.4630, -0.1855, -0.5458, -0.2998,  1.5459, -0.8451, -0.6956,
-        -0.8774,  0.1352, -0.8911, -0.3328,  0.6248,  0.0000,  0.6657,  0.2678,
-         0.3230,  0.9347, -0.3515,  0.6685,  0.6924,  0.6883,  0.6275,  0.3795,
-         0.7224, -0.5930,  0.4989, -0.3273,  0.5579, -0.1341,  0.5868,  0.5589,
-         0.8795,  1.0309, -0.1550, -0.5100,  1.1168, -0.1660,  0.0439, -0.1603],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8516,  0.6250,  0.1032,  0.0291,  0.4085,  0.6249, -0.5966, -1.0827,
-        -0.5561, -0.6681, -0.0772, -0.5355,  0.0922,  0.3283, -1.0032, -0.6110,
-        -0.9278,  0.7076, -0.9278, -0.6685, -0.0113,  0.0982, -0.5864, -0.0968,
-         0.4542,  0.4567, -0.2037, -0.5439, -0.2914,  1.5472, -0.8413, -0.7019,
-        -0.8749,  0.0927, -0.8897, -0.3183,  0.6463,  0.0000,  0.6891,  0.3010,
-         0.2956,  0.9626, -0.3477,  0.6820,  0.6967,  0.6807,  0.6386,  0.3650,
-         0.7230, -0.5933,  0.4977, -0.3283,  0.5453, -0.1070,  0.5833,  0.5609,
-         0.9040,  1.0290, -0.1999, -0.5570,  1.1253, -0.1857,  0.0091, -0.2012],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8516,  0.6250,  0.1032,  0.0291,  0.4085,  0.6249, -0.5966, -1.0827,
-        -0.5561, -0.6681, -0.0772, -0.5355,  0.0922,  0.3283, -1.0032, -0.6110,
-        -0.9278,  0.7076, -0.9278, -0.6685, -0.0113,  0.0982, -0.5864, -0.0968,
-         0.4542,  0.4567, -0.2037, -0.5439, -0.2914,  1.5472, -0.8413, -0.7019,
-        -0.8749,  0.0927, -0.8897, -0.3183,  0.6463,  0.0000,  0.6891,  0.3010,
-         0.2956,  0.9626, -0.3477,  0.6820,  0.6967,  0.6807,  0.6386,  0.3650,
-         0.7230, -0.5933,  0.4977, -0.3283,  0.5453, -0.1070,  0.5833,  0.5609,
-         0.9040,  1.0290, -0.1999, -0.5570,  1.1253, -0.1857,  0.0091, -0.2012],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8516,  0.6250,  0.1032,  0.0291,  0.4085,  0.6249, -0.5966, -1.0827,
-        -0.5561, -0.6681, -0.0772, -0.5355,  0.0922,  0.3283, -1.0032, -0.6110,
-        -0.9278,  0.7076, -0.9278, -0.6685, -0.0113,  0.0982, -0.5864, -0.0968,
-         0.4542,  0.4567, -0.2037, -0.5439, -0.2914,  1.5472, -0.8413, -0.7019,
-        -0.8749,  0.0927, -0.8897, -0.3183,  0.6463,  0.0000,  0.6891,  0.3010,
-         0.2956,  0.9626, -0.3477,  0.6820,  0.6967,  0.6807,  0.6386,  0.3650,
-         0.7230, -0.5933,  0.4977, -0.3283,  0.5453, -0.1070,  0.5833,  0.5609,
-         0.9040,  1.0290, -0.1999, -0.5570,  1.1253, -0.1857,  0.0091, -0.2012],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8531,  0.6081,  0.1566,  0.0508,  0.4595,  0.6326, -0.5882, -1.0798,
-        -0.5604, -0.6778, -0.1220, -0.5030,  0.0868,  0.3818, -1.0357, -0.6079,
-        -0.9257,  0.7018, -0.9292, -0.6589, -0.0235,  0.0500, -0.5933, -0.0724,
-         0.4748,  0.4460, -0.2289, -0.5414, -0.2876,  1.5483, -0.8386, -0.7071,
-        -0.8714,  0.0432, -0.8889, -0.2982,  0.6698,  0.0000,  0.7122,  0.3247,
-         0.2632,  0.9882, -0.3348,  0.6976,  0.7009,  0.6716,  0.6463,  0.3542,
-         0.7302, -0.5936,  0.4987, -0.3238,  0.5302, -0.0843,  0.5755,  0.5640,
-         0.9278,  1.0256, -0.2392, -0.6060,  1.1331, -0.2147, -0.0266, -0.2313],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8531,  0.6081,  0.1566,  0.0508,  0.4595,  0.6326, -0.5882, -1.0798,
-        -0.5604, -0.6778, -0.1220, -0.5030,  0.0868,  0.3818, -1.0357, -0.6079,
-        -0.9257,  0.7018, -0.9292, -0.6589, -0.0235,  0.0500, -0.5933, -0.0724,
-         0.4748,  0.4460, -0.2289, -0.5414, -0.2876,  1.5483, -0.8386, -0.7071,
-        -0.8714,  0.0432, -0.8889, -0.2982,  0.6698,  0.0000,  0.7122,  0.3247,
-         0.2632,  0.9882, -0.3348,  0.6976,  0.7009,  0.6716,  0.6463,  0.3542,
-         0.7302, -0.5936,  0.4987, -0.3238,  0.5302, -0.0843,  0.5755,  0.5640,
-         0.9278,  1.0256, -0.2392, -0.6060,  1.1331, -0.2147, -0.0266, -0.2313],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8531,  0.6081,  0.1566,  0.0508,  0.4595,  0.6326, -0.5882, -1.0798,
-        -0.5604, -0.6778, -0.1220, -0.5030,  0.0868,  0.3818, -1.0357, -0.6079,
-        -0.9257,  0.7018, -0.9292, -0.6589, -0.0235,  0.0500, -0.5933, -0.0724,
-         0.4748,  0.4460, -0.2289, -0.5414, -0.2876,  1.5483, -0.8386, -0.7071,
-        -0.8714,  0.0432, -0.8889, -0.2982,  0.6698,  0.0000,  0.7122,  0.3247,
-         0.2632,  0.9882, -0.3348,  0.6976,  0.7009,  0.6716,  0.6463,  0.3542,
-         0.7302, -0.5936,  0.4987, -0.3238,  0.5302, -0.0843,  0.5755,  0.5640,
-         0.9278,  1.0256, -0.2392, -0.6060,  1.1331, -0.2147, -0.0266, -0.2313],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8520,  0.6035,  0.1812,  0.0575,  0.4987,  0.6428, -0.5787, -1.0763,
-        -0.5627, -0.6876, -0.1697, -0.4773,  0.0595,  0.4090, -1.0658, -0.6038,
-        -0.9232,  0.6984, -0.9313, -0.6519, -0.0357,  0.0149, -0.5956, -0.0668,
-         0.4934,  0.4378, -0.2524, -0.5389, -0.2868,  1.5499, -0.8385, -0.7119,
-        -0.8683,  0.0256, -0.8870, -0.2823,  0.6978,  0.0000,  0.7296,  0.3361,
-         0.2282,  1.0128, -0.3342,  0.7094,  0.7043,  0.6589,  0.6542,  0.3241,
-         0.7465, -0.5933,  0.4978, -0.3146,  0.5213, -0.0610,  0.5602,  0.5637,
-         0.9503,  1.0209, -0.2637, -0.6393,  1.1405, -0.2535, -0.0682, -0.2747],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8520,  0.6035,  0.1812,  0.0575,  0.4987,  0.6428, -0.5787, -1.0763,
-        -0.5627, -0.6876, -0.1697, -0.4773,  0.0595,  0.4090, -1.0658, -0.6038,
-        -0.9232,  0.6984, -0.9313, -0.6519, -0.0357,  0.0149, -0.5956, -0.0668,
-         0.4934,  0.4378, -0.2524, -0.5389, -0.2868,  1.5499, -0.8385, -0.7119,
-        -0.8683,  0.0256, -0.8870, -0.2823,  0.6978,  0.0000,  0.7296,  0.3361,
-         0.2282,  1.0128, -0.3342,  0.7094,  0.7043,  0.6589,  0.6542,  0.3241,
-         0.7465, -0.5933,  0.4978, -0.3146,  0.5213, -0.0610,  0.5602,  0.5637,
-         0.9503,  1.0209, -0.2637, -0.6393,  1.1405, -0.2535, -0.0682, -0.2747],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8520,  0.6035,  0.1812,  0.0575,  0.4987,  0.6428, -0.5787, -1.0763,
-        -0.5627, -0.6876, -0.1697, -0.4773,  0.0595,  0.4090, -1.0658, -0.6038,
-        -0.9232,  0.6984, -0.9313, -0.6519, -0.0357,  0.0149, -0.5956, -0.0668,
-         0.4934,  0.4378, -0.2524, -0.5389, -0.2868,  1.5499, -0.8385, -0.7119,
-        -0.8683,  0.0256, -0.8870, -0.2823,  0.6978,  0.0000,  0.7296,  0.3361,
-         0.2282,  1.0128, -0.3342,  0.7094,  0.7043,  0.6589,  0.6542,  0.3241,
-         0.7465, -0.5933,  0.4978, -0.3146,  0.5213, -0.0610,  0.5602,  0.5637,
-         0.9503,  1.0209, -0.2637, -0.6393,  1.1405, -0.2535, -0.0682, -0.2747],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8515,  0.6070,  0.2001,  0.0773,  0.5316,  0.6571, -0.5678, -1.0729,
-        -0.5543, -0.6971, -0.1996, -0.4722,  0.0395,  0.4368, -1.0937, -0.5989,
-        -0.9209,  0.6960, -0.9357, -0.6486, -0.0348,  0.0029, -0.5945, -0.0713,
-         0.5130,  0.4299, -0.2678, -0.5333, -0.2836,  1.5520, -0.8362, -0.7143,
-        -0.8646,  0.0457, -0.8843, -0.2621,  0.7215,  0.0000,  0.7394,  0.3244,
-         0.2036,  1.0351, -0.3345,  0.7191,  0.7049,  0.6445,  0.6627,  0.3014,
-         0.7634, -0.5925,  0.4956, -0.2897,  0.5141, -0.0160,  0.5418,  0.5622,
-         0.9717,  1.0174, -0.2860, -0.6613,  1.1461, -0.2893, -0.1124, -0.2999],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8515,  0.6070,  0.2001,  0.0773,  0.5316,  0.6571, -0.5678, -1.0729,
-        -0.5543, -0.6971, -0.1996, -0.4722,  0.0395,  0.4368, -1.0937, -0.5989,
-        -0.9209,  0.6960, -0.9357, -0.6486, -0.0348,  0.0029, -0.5945, -0.0713,
-         0.5130,  0.4299, -0.2678, -0.5333, -0.2836,  1.5520, -0.8362, -0.7143,
-        -0.8646,  0.0457, -0.8843, -0.2621,  0.7215,  0.0000,  0.7394,  0.3244,
-         0.2036,  1.0351, -0.3345,  0.7191,  0.7049,  0.6445,  0.6627,  0.3014,
-         0.7634, -0.5925,  0.4956, -0.2897,  0.5141, -0.0160,  0.5418,  0.5622,
-         0.9717,  1.0174, -0.2860, -0.6613,  1.1461, -0.2893, -0.1124, -0.2999],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8515,  0.6070,  0.2001,  0.0773,  0.5316,  0.6571, -0.5678, -1.0729,
-        -0.5543, -0.6971, -0.1996, -0.4722,  0.0395,  0.4368, -1.0937, -0.5989,
-        -0.9209,  0.6960, -0.9357, -0.6486, -0.0348,  0.0029, -0.5945, -0.0713,
-         0.5130,  0.4299, -0.2678, -0.5333, -0.2836,  1.5520, -0.8362, -0.7143,
-        -0.8646,  0.0457, -0.8843, -0.2621,  0.7215,  0.0000,  0.7394,  0.3244,
-         0.2036,  1.0351, -0.3345,  0.7191,  0.7049,  0.6445,  0.6627,  0.3014,
-         0.7634, -0.5925,  0.4956, -0.2897,  0.5141, -0.0160,  0.5418,  0.5622,
-         0.9717,  1.0174, -0.2860, -0.6613,  1.1461, -0.2893, -0.1124, -0.2999],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8491,  0.6108,  0.2011,  0.0875,  0.5533,  0.6765, -0.5541, -1.0689,
-        -0.5547, -0.7052, -0.2310, -0.4666,  0.0126,  0.4657, -1.1206, -0.5940,
-        -0.9187,  0.6926, -0.9401, -0.6459, -0.0500, -0.0171, -0.5942, -0.0899,
-         0.5299,  0.4219, -0.2783, -0.5267, -0.2695,  1.5535, -0.8322, -0.7162,
-        -0.8616,  0.0367, -0.8810, -0.2461,  0.7414,  0.0000,  0.7476,  0.3008,
-         0.1767,  1.0551, -0.3252,  0.7272,  0.7003,  0.6273,  0.6711,  0.2753,
-         0.7719, -0.5910,  0.4902, -0.2761,  0.5119,  0.0158,  0.5210,  0.5580,
-         0.9936,  1.0151, -0.2927, -0.6728,  1.1517, -0.3129, -0.1469, -0.3052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8491,  0.6108,  0.2011,  0.0875,  0.5533,  0.6765, -0.5541, -1.0689,
-        -0.5547, -0.7052, -0.2310, -0.4666,  0.0126,  0.4657, -1.1206, -0.5940,
-        -0.9187,  0.6926, -0.9401, -0.6459, -0.0500, -0.0171, -0.5942, -0.0899,
-         0.5299,  0.4219, -0.2783, -0.5267, -0.2695,  1.5535, -0.8322, -0.7162,
-        -0.8616,  0.0367, -0.8810, -0.2461,  0.7414,  0.0000,  0.7476,  0.3008,
-         0.1767,  1.0551, -0.3252,  0.7272,  0.7003,  0.6273,  0.6711,  0.2753,
-         0.7719, -0.5910,  0.4902, -0.2761,  0.5119,  0.0158,  0.5210,  0.5580,
-         0.9936,  1.0151, -0.2927, -0.6728,  1.1517, -0.3129, -0.1469, -0.3052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8491,  0.6108,  0.2011,  0.0875,  0.5533,  0.6765, -0.5541, -1.0689,
-        -0.5547, -0.7052, -0.2310, -0.4666,  0.0126,  0.4657, -1.1206, -0.5940,
-        -0.9187,  0.6926, -0.9401, -0.6459, -0.0500, -0.0171, -0.5942, -0.0899,
-         0.5299,  0.4219, -0.2783, -0.5267, -0.2695,  1.5535, -0.8322, -0.7162,
-        -0.8616,  0.0367, -0.8810, -0.2461,  0.7414,  0.0000,  0.7476,  0.3008,
-         0.1767,  1.0551, -0.3252,  0.7272,  0.7003,  0.6273,  0.6711,  0.2753,
-         0.7719, -0.5910,  0.4902, -0.2761,  0.5119,  0.0158,  0.5210,  0.5580,
-         0.9936,  1.0151, -0.2927, -0.6728,  1.1517, -0.3129, -0.1469, -0.3052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8476,  0.6133,  0.1960,  0.0940,  0.5631,  0.6911, -0.5442, -1.0651,
-        -0.5504, -0.7142, -0.2784, -0.4500, -0.0025,  0.4767, -1.1409, -0.5943,
-        -0.9175,  0.6890, -0.9441, -0.6340, -0.0892, -0.0555, -0.6063, -0.1418,
-         0.5391,  0.4140, -0.2930, -0.5231, -0.2627,  1.5544, -0.8277, -0.7152,
-        -0.8604,  0.0117, -0.8763, -0.2268,  0.7590,  0.0000,  0.7540,  0.2741,
-         0.1441,  1.0712, -0.3007,  0.7313,  0.6955,  0.6136,  0.6799,  0.2500,
-         0.7742, -0.5901,  0.4855, -0.2693,  0.5099,  0.0352,  0.5001,  0.5517,
-         1.0148,  1.0122, -0.3105, -0.6823,  1.1570, -0.3296, -0.1802, -0.2976],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8476,  0.6133,  0.1960,  0.0940,  0.5631,  0.6911, -0.5442, -1.0651,
-        -0.5504, -0.7142, -0.2784, -0.4500, -0.0025,  0.4767, -1.1409, -0.5943,
-        -0.9175,  0.6890, -0.9441, -0.6340, -0.0892, -0.0555, -0.6063, -0.1418,
-         0.5391,  0.4140, -0.2930, -0.5231, -0.2627,  1.5544, -0.8277, -0.7152,
-        -0.8604,  0.0117, -0.8763, -0.2268,  0.7590,  0.0000,  0.7540,  0.2741,
-         0.1441,  1.0712, -0.3007,  0.7313,  0.6955,  0.6136,  0.6799,  0.2500,
-         0.7742, -0.5901,  0.4855, -0.2693,  0.5099,  0.0352,  0.5001,  0.5517,
-         1.0148,  1.0122, -0.3105, -0.6823,  1.1570, -0.3296, -0.1802, -0.2976],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8476,  0.6133,  0.1960,  0.0940,  0.5631,  0.6911, -0.5442, -1.0651,
-        -0.5504, -0.7142, -0.2784, -0.4500, -0.0025,  0.4767, -1.1409, -0.5943,
-        -0.9175,  0.6890, -0.9441, -0.6340, -0.0892, -0.0555, -0.6063, -0.1418,
-         0.5391,  0.4140, -0.2930, -0.5231, -0.2627,  1.5544, -0.8277, -0.7152,
-        -0.8604,  0.0117, -0.8763, -0.2268,  0.7590,  0.0000,  0.7540,  0.2741,
-         0.1441,  1.0712, -0.3007,  0.7313,  0.6955,  0.6136,  0.6799,  0.2500,
-         0.7742, -0.5901,  0.4855, -0.2693,  0.5099,  0.0352,  0.5001,  0.5517,
-         1.0148,  1.0122, -0.3105, -0.6823,  1.1570, -0.3296, -0.1802, -0.2976],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8467,  0.6123,  0.1899,  0.1074,  0.5671,  0.7027, -0.5378, -1.0605,
-        -0.5415, -0.7226, -0.3128, -0.4210,  0.0053,  0.4820, -1.1586, -0.5926,
-        -0.9217,  0.6874, -0.9492, -0.6188, -0.1328, -0.0850, -0.6290, -0.2181,
-         0.5469,  0.4040, -0.3106, -0.5161, -0.2652,  1.5559, -0.8244, -0.7136,
-        -0.8584, -0.0254, -0.8724, -0.2042,  0.7771,  0.0000,  0.7605,  0.2528,
-         0.1130,  1.0846, -0.2796,  0.7376,  0.6908,  0.6007,  0.6901,  0.2177,
-         0.7747, -0.5897,  0.4794, -0.2650,  0.5076,  0.0404,  0.4811,  0.5427,
-         1.0338,  1.0093, -0.3320, -0.6788,  1.1636, -0.3396, -0.2094, -0.2607],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8467,  0.6123,  0.1899,  0.1074,  0.5671,  0.7027, -0.5378, -1.0605,
-        -0.5415, -0.7226, -0.3128, -0.4210,  0.0053,  0.4820, -1.1586, -0.5926,
-        -0.9217,  0.6874, -0.9492, -0.6188, -0.1328, -0.0850, -0.6290, -0.2181,
-         0.5469,  0.4040, -0.3106, -0.5161, -0.2652,  1.5559, -0.8244, -0.7136,
-        -0.8584, -0.0254, -0.8724, -0.2042,  0.7771,  0.0000,  0.7605,  0.2528,
-         0.1130,  1.0846, -0.2796,  0.7376,  0.6908,  0.6007,  0.6901,  0.2177,
-         0.7747, -0.5897,  0.4794, -0.2650,  0.5076,  0.0404,  0.4811,  0.5427,
-         1.0338,  1.0093, -0.3320, -0.6788,  1.1636, -0.3396, -0.2094, -0.2607],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8467,  0.6123,  0.1899,  0.1074,  0.5671,  0.7027, -0.5378, -1.0605,
-        -0.5415, -0.7226, -0.3128, -0.4210,  0.0053,  0.4820, -1.1586, -0.5926,
-        -0.9217,  0.6874, -0.9492, -0.6188, -0.1328, -0.0850, -0.6290, -0.2181,
-         0.5469,  0.4040, -0.3106, -0.5161, -0.2652,  1.5559, -0.8244, -0.7136,
-        -0.8584, -0.0254, -0.8724, -0.2042,  0.7771,  0.0000,  0.7605,  0.2528,
-         0.1130,  1.0846, -0.2796,  0.7376,  0.6908,  0.6007,  0.6901,  0.2177,
-         0.7747, -0.5897,  0.4794, -0.2650,  0.5076,  0.0404,  0.4811,  0.5427,
-         1.0338,  1.0093, -0.3320, -0.6788,  1.1636, -0.3396, -0.2094, -0.2607],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8474,  0.6132,  0.1829,  0.1168,  0.5688,  0.7116, -0.5337, -1.0566,
-        -0.5277, -0.7316, -0.3507, -0.3867,  0.0144,  0.4779, -1.1747, -0.5925,
-        -0.9294,  0.6856, -0.9538, -0.6004, -0.1819, -0.1123, -0.6513, -0.2942,
-         0.5534,  0.3851, -0.3150, -0.5152, -0.2635,  1.5576, -0.8218, -0.7147,
-        -0.8588, -0.0687, -0.8700, -0.1843,  0.7952,  0.0000,  0.7638,  0.2224,
-         0.0771,  1.0977, -0.2675,  0.7410,  0.6850,  0.5844,  0.6953,  0.1815,
-         0.7815, -0.5911,  0.4724, -0.2703,  0.5090,  0.0521,  0.4607,  0.5361,
-         1.0518,  1.0065, -0.3401, -0.6790,  1.1713, -0.3394, -0.2196, -0.1918],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8474,  0.6132,  0.1829,  0.1168,  0.5688,  0.7116, -0.5337, -1.0566,
-        -0.5277, -0.7316, -0.3507, -0.3867,  0.0144,  0.4779, -1.1747, -0.5925,
-        -0.9294,  0.6856, -0.9538, -0.6004, -0.1819, -0.1123, -0.6513, -0.2942,
-         0.5534,  0.3851, -0.3150, -0.5152, -0.2635,  1.5576, -0.8218, -0.7147,
-        -0.8588, -0.0687, -0.8700, -0.1843,  0.7952,  0.0000,  0.7638,  0.2224,
-         0.0771,  1.0977, -0.2675,  0.7410,  0.6850,  0.5844,  0.6953,  0.1815,
-         0.7815, -0.5911,  0.4724, -0.2703,  0.5090,  0.0521,  0.4607,  0.5361,
-         1.0518,  1.0065, -0.3401, -0.6790,  1.1713, -0.3394, -0.2196, -0.1918],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8474,  0.6132,  0.1829,  0.1168,  0.5688,  0.7116, -0.5337, -1.0566,
-        -0.5277, -0.7316, -0.3507, -0.3867,  0.0144,  0.4779, -1.1747, -0.5925,
-        -0.9294,  0.6856, -0.9538, -0.6004, -0.1819, -0.1123, -0.6513, -0.2942,
-         0.5534,  0.3851, -0.3150, -0.5152, -0.2635,  1.5576, -0.8218, -0.7147,
-        -0.8588, -0.0687, -0.8700, -0.1843,  0.7952,  0.0000,  0.7638,  0.2224,
-         0.0771,  1.0977, -0.2675,  0.7410,  0.6850,  0.5844,  0.6953,  0.1815,
-         0.7815, -0.5911,  0.4724, -0.2703,  0.5090,  0.0521,  0.4607,  0.5361,
-         1.0518,  1.0065, -0.3401, -0.6790,  1.1713, -0.3394, -0.2196, -0.1918],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8490,  0.6039,  0.1740,  0.1185,  0.5610,  0.7151, -0.5294, -1.0527,
-        -0.5123, -0.7409, -0.3927, -0.3595,  0.0355,  0.4668, -1.1904, -0.5921,
-        -0.9386,  0.6779, -0.9575, -0.5849, -0.2253, -0.1374, -0.6737, -0.3559,
-         0.5566,  0.3647, -0.3065, -0.5126, -0.2671,  1.5598, -0.8186, -0.7171,
-        -0.8621, -0.1111, -0.8682, -0.1585,  0.8146,  0.0000,  0.7626,  0.2079,
-         0.0339,  1.1099, -0.2296,  0.7422,  0.6785,  0.5683,  0.6977,  0.1529,
-         0.7865, -0.5919,  0.4633, -0.2607,  0.5098,  0.0791,  0.4369,  0.5255,
-         1.0672,  1.0001, -0.3243, -0.6882,  1.1785, -0.3314, -0.2240, -0.1222],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8490,  0.6039,  0.1740,  0.1185,  0.5610,  0.7151, -0.5294, -1.0527,
-        -0.5123, -0.7409, -0.3927, -0.3595,  0.0355,  0.4668, -1.1904, -0.5921,
-        -0.9386,  0.6779, -0.9575, -0.5849, -0.2253, -0.1374, -0.6737, -0.3559,
-         0.5566,  0.3647, -0.3065, -0.5126, -0.2671,  1.5598, -0.8186, -0.7171,
-        -0.8621, -0.1111, -0.8682, -0.1585,  0.8146,  0.0000,  0.7626,  0.2079,
-         0.0339,  1.1099, -0.2296,  0.7422,  0.6785,  0.5683,  0.6977,  0.1529,
-         0.7865, -0.5919,  0.4633, -0.2607,  0.5098,  0.0791,  0.4369,  0.5255,
-         1.0672,  1.0001, -0.3243, -0.6882,  1.1785, -0.3314, -0.2240, -0.1222],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8490,  0.6039,  0.1740,  0.1185,  0.5610,  0.7151, -0.5294, -1.0527,
-        -0.5123, -0.7409, -0.3927, -0.3595,  0.0355,  0.4668, -1.1904, -0.5921,
-        -0.9386,  0.6779, -0.9575, -0.5849, -0.2253, -0.1374, -0.6737, -0.3559,
-         0.5566,  0.3647, -0.3065, -0.5126, -0.2671,  1.5598, -0.8186, -0.7171,
-        -0.8621, -0.1111, -0.8682, -0.1585,  0.8146,  0.0000,  0.7626,  0.2079,
-         0.0339,  1.1099, -0.2296,  0.7422,  0.6785,  0.5683,  0.6977,  0.1529,
-         0.7865, -0.5919,  0.4633, -0.2607,  0.5098,  0.0791,  0.4369,  0.5255,
-         1.0672,  1.0001, -0.3243, -0.6882,  1.1785, -0.3314, -0.2240, -0.1222],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8501,  0.5965,  0.1676,  0.1304,  0.5592,  0.7166, -0.5240, -1.0491,
-        -0.4926, -0.7490, -0.4272, -0.3435,  0.0586,  0.4549, -1.2062, -0.5850,
-        -0.9468,  0.6708, -0.9595, -0.5716, -0.2552, -0.1329, -0.6914, -0.4000,
-         0.5622,  0.3327, -0.2902, -0.5076, -0.2614,  1.5620, -0.8155, -0.7168,
-        -0.8646, -0.1306, -0.8657, -0.1393,  0.8346,  0.0000,  0.7606,  0.2072,
-        -0.0147,  1.1181, -0.1869,  0.7424,  0.6721,  0.5478,  0.6995,  0.1294,
-         0.7959, -0.5925,  0.4493, -0.2360,  0.5167,  0.1159,  0.4207,  0.5121,
-         1.0822,  0.9943, -0.2969, -0.6928,  1.1852, -0.3301, -0.2213, -0.0560],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8501,  0.5965,  0.1676,  0.1304,  0.5592,  0.7166, -0.5240, -1.0491,
-        -0.4926, -0.7490, -0.4272, -0.3435,  0.0586,  0.4549, -1.2062, -0.5850,
-        -0.9468,  0.6708, -0.9595, -0.5716, -0.2552, -0.1329, -0.6914, -0.4000,
-         0.5622,  0.3327, -0.2902, -0.5076, -0.2614,  1.5620, -0.8155, -0.7168,
-        -0.8646, -0.1306, -0.8657, -0.1393,  0.8346,  0.0000,  0.7606,  0.2072,
-        -0.0147,  1.1181, -0.1869,  0.7424,  0.6721,  0.5478,  0.6995,  0.1294,
-         0.7959, -0.5925,  0.4493, -0.2360,  0.5167,  0.1159,  0.4207,  0.5121,
-         1.0822,  0.9943, -0.2969, -0.6928,  1.1852, -0.3301, -0.2213, -0.0560],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8501,  0.5965,  0.1676,  0.1304,  0.5592,  0.7166, -0.5240, -1.0491,
-        -0.4926, -0.7490, -0.4272, -0.3435,  0.0586,  0.4549, -1.2062, -0.5850,
-        -0.9468,  0.6708, -0.9595, -0.5716, -0.2552, -0.1329, -0.6914, -0.4000,
-         0.5622,  0.3327, -0.2902, -0.5076, -0.2614,  1.5620, -0.8155, -0.7168,
-        -0.8646, -0.1306, -0.8657, -0.1393,  0.8346,  0.0000,  0.7606,  0.2072,
-        -0.0147,  1.1181, -0.1869,  0.7424,  0.6721,  0.5478,  0.6995,  0.1294,
-         0.7959, -0.5925,  0.4493, -0.2360,  0.5167,  0.1159,  0.4207,  0.5121,
-         1.0822,  0.9943, -0.2969, -0.6928,  1.1852, -0.3301, -0.2213, -0.0560],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8535,  0.5810,  0.1600,  0.1479,  0.5540,  0.7116, -0.5180, -1.0450,
-        -0.4726, -0.7553, -0.4579, -0.3374,  0.0603,  0.4231, -1.2211, -0.5807,
-        -0.9495,  0.6631, -0.9600, -0.5568, -0.2692, -0.1153, -0.7009, -0.4492,
-         0.5679,  0.3091, -0.2787, -0.5002, -0.2573,  1.5641, -0.8098, -0.7163,
-        -0.8668, -0.1216, -0.8616, -0.1226,  0.8540,  0.0000,  0.7552,  0.2185,
-        -0.0477,  1.1243, -0.1507,  0.7398,  0.6655,  0.5236,  0.7004,  0.1113,
-         0.8041, -0.5938,  0.4341, -0.1815,  0.5201,  0.1599,  0.4045,  0.5024,
-         1.0964,  0.9882, -0.2748, -0.7010,  1.1931, -0.3249, -0.2172, -0.0146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8535,  0.5810,  0.1600,  0.1479,  0.5540,  0.7116, -0.5180, -1.0450,
-        -0.4726, -0.7553, -0.4579, -0.3374,  0.0603,  0.4231, -1.2211, -0.5807,
-        -0.9495,  0.6631, -0.9600, -0.5568, -0.2692, -0.1153, -0.7009, -0.4492,
-         0.5679,  0.3091, -0.2787, -0.5002, -0.2573,  1.5641, -0.8098, -0.7163,
-        -0.8668, -0.1216, -0.8616, -0.1226,  0.8540,  0.0000,  0.7552,  0.2185,
-        -0.0477,  1.1243, -0.1507,  0.7398,  0.6655,  0.5236,  0.7004,  0.1113,
-         0.8041, -0.5938,  0.4341, -0.1815,  0.5201,  0.1599,  0.4045,  0.5024,
-         1.0964,  0.9882, -0.2748, -0.7010,  1.1931, -0.3249, -0.2172, -0.0146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8535,  0.5810,  0.1600,  0.1479,  0.5540,  0.7116, -0.5180, -1.0450,
-        -0.4726, -0.7553, -0.4579, -0.3374,  0.0603,  0.4231, -1.2211, -0.5807,
-        -0.9495,  0.6631, -0.9600, -0.5568, -0.2692, -0.1153, -0.7009, -0.4492,
-         0.5679,  0.3091, -0.2787, -0.5002, -0.2573,  1.5641, -0.8098, -0.7163,
-        -0.8668, -0.1216, -0.8616, -0.1226,  0.8540,  0.0000,  0.7552,  0.2185,
-        -0.0477,  1.1243, -0.1507,  0.7398,  0.6655,  0.5236,  0.7004,  0.1113,
-         0.8041, -0.5938,  0.4341, -0.1815,  0.5201,  0.1599,  0.4045,  0.5024,
-         1.0964,  0.9882, -0.2748, -0.7010,  1.1931, -0.3249, -0.2172, -0.0146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8580,  0.5682,  0.1519,  0.1486,  0.5568,  0.7077, -0.5037, -1.0415,
-        -0.4545, -0.7562, -0.4813, -0.3469,  0.0300,  0.3686, -1.2326, -0.5846,
-        -0.9479,  0.6553, -0.9596, -0.5484, -0.2972, -0.0854, -0.6926, -0.4905,
-         0.5666,  0.2776, -0.2707, -0.4858, -0.2472,  1.5675, -0.8091, -0.7165,
-        -0.8655, -0.0885, -0.8584, -0.1248,  0.8695,  0.0000,  0.7445,  0.1974,
-        -0.0750,  1.1339, -0.1468,  0.7349,  0.6593,  0.5000,  0.7032,  0.0968,
-         0.8104, -0.5956,  0.4206, -0.1424,  0.5201,  0.1960,  0.3719,  0.4885,
-         1.1095,  0.9795, -0.2756, -0.7006,  1.1990, -0.3140, -0.2239, -0.0107],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8580,  0.5682,  0.1519,  0.1486,  0.5568,  0.7077, -0.5037, -1.0415,
-        -0.4545, -0.7562, -0.4813, -0.3469,  0.0300,  0.3686, -1.2326, -0.5846,
-        -0.9479,  0.6553, -0.9596, -0.5484, -0.2972, -0.0854, -0.6926, -0.4905,
-         0.5666,  0.2776, -0.2707, -0.4858, -0.2472,  1.5675, -0.8091, -0.7165,
-        -0.8655, -0.0885, -0.8584, -0.1248,  0.8695,  0.0000,  0.7445,  0.1974,
-        -0.0750,  1.1339, -0.1468,  0.7349,  0.6593,  0.5000,  0.7032,  0.0968,
-         0.8104, -0.5956,  0.4206, -0.1424,  0.5201,  0.1960,  0.3719,  0.4885,
-         1.1095,  0.9795, -0.2756, -0.7006,  1.1990, -0.3140, -0.2239, -0.0107],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8580,  0.5682,  0.1519,  0.1486,  0.5568,  0.7077, -0.5037, -1.0415,
-        -0.4545, -0.7562, -0.4813, -0.3469,  0.0300,  0.3686, -1.2326, -0.5846,
-        -0.9479,  0.6553, -0.9596, -0.5484, -0.2972, -0.0854, -0.6926, -0.4905,
-         0.5666,  0.2776, -0.2707, -0.4858, -0.2472,  1.5675, -0.8091, -0.7165,
-        -0.8655, -0.0885, -0.8584, -0.1248,  0.8695,  0.0000,  0.7445,  0.1974,
-        -0.0750,  1.1339, -0.1468,  0.7349,  0.6593,  0.5000,  0.7032,  0.0968,
-         0.8104, -0.5956,  0.4206, -0.1424,  0.5201,  0.1960,  0.3719,  0.4885,
-         1.1095,  0.9795, -0.2756, -0.7006,  1.1990, -0.3140, -0.2239, -0.0107],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8615,  0.5498,  0.1426,  0.1437,  0.5609,  0.7063, -0.4853, -1.0392,
-        -0.4422, -0.7566, -0.5024, -0.3469, -0.0150,  0.3130, -1.2433, -0.5908,
-        -0.9448,  0.6501, -0.9575, -0.5335, -0.3347, -0.0239, -0.6772, -0.5361,
-         0.5668,  0.2451, -0.2727, -0.4706, -0.2271,  1.5718, -0.8091, -0.7202,
-        -0.8642, -0.0333, -0.8540, -0.1359,  0.8822,  0.0000,  0.7337,  0.1696,
-        -0.0924,  1.1415, -0.1539,  0.7312,  0.6514,  0.4783,  0.7057,  0.0821,
-         0.8094, -0.5977,  0.4134, -0.1033,  0.5177,  0.2258,  0.3406,  0.4769,
-         1.1206,  0.9733, -0.3000, -0.7001,  1.2049, -0.3125, -0.2168, -0.0075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8615,  0.5498,  0.1426,  0.1437,  0.5609,  0.7063, -0.4853, -1.0392,
-        -0.4422, -0.7566, -0.5024, -0.3469, -0.0150,  0.3130, -1.2433, -0.5908,
-        -0.9448,  0.6501, -0.9575, -0.5335, -0.3347, -0.0239, -0.6772, -0.5361,
-         0.5668,  0.2451, -0.2727, -0.4706, -0.2271,  1.5718, -0.8091, -0.7202,
-        -0.8642, -0.0333, -0.8540, -0.1359,  0.8822,  0.0000,  0.7337,  0.1696,
-        -0.0924,  1.1415, -0.1539,  0.7312,  0.6514,  0.4783,  0.7057,  0.0821,
-         0.8094, -0.5977,  0.4134, -0.1033,  0.5177,  0.2258,  0.3406,  0.4769,
-         1.1206,  0.9733, -0.3000, -0.7001,  1.2049, -0.3125, -0.2168, -0.0075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8615,  0.5498,  0.1426,  0.1437,  0.5609,  0.7063, -0.4853, -1.0392,
-        -0.4422, -0.7566, -0.5024, -0.3469, -0.0150,  0.3130, -1.2433, -0.5908,
-        -0.9448,  0.6501, -0.9575, -0.5335, -0.3347, -0.0239, -0.6772, -0.5361,
-         0.5668,  0.2451, -0.2727, -0.4706, -0.2271,  1.5718, -0.8091, -0.7202,
-        -0.8642, -0.0333, -0.8540, -0.1359,  0.8822,  0.0000,  0.7337,  0.1696,
-        -0.0924,  1.1415, -0.1539,  0.7312,  0.6514,  0.4783,  0.7057,  0.0821,
-         0.8094, -0.5977,  0.4134, -0.1033,  0.5177,  0.2258,  0.3406,  0.4769,
-         1.1206,  0.9733, -0.3000, -0.7001,  1.2049, -0.3125, -0.2168, -0.0075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8640,  0.5320,  0.1126,  0.1348,  0.5654,  0.7066, -0.4646, -1.0361,
-        -0.4352, -0.7581, -0.5252, -0.3422, -0.0782,  0.2571, -1.2543, -0.5938,
-        -0.9410,  0.6455, -0.9536, -0.5214, -0.3666,  0.0523, -0.6580, -0.5772,
-         0.5655,  0.2158, -0.2773, -0.4560, -0.2098,  1.5758, -0.8074, -0.7258,
-        -0.8647,  0.0232, -0.8497, -0.1368,  0.8938,  0.0000,  0.7224,  0.1350,
-        -0.0992,  1.1476, -0.1670,  0.7284,  0.6425,  0.4515,  0.7073,  0.0663,
-         0.8090, -0.5992,  0.4038, -0.0688,  0.5143,  0.2458,  0.3057,  0.4591,
-         1.1309,  0.9683, -0.3125, -0.6936,  1.2092, -0.2963, -0.2008, -0.0061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8640,  0.5320,  0.1126,  0.1348,  0.5654,  0.7066, -0.4646, -1.0361,
-        -0.4352, -0.7581, -0.5252, -0.3422, -0.0782,  0.2571, -1.2543, -0.5938,
-        -0.9410,  0.6455, -0.9536, -0.5214, -0.3666,  0.0523, -0.6580, -0.5772,
-         0.5655,  0.2158, -0.2773, -0.4560, -0.2098,  1.5758, -0.8074, -0.7258,
-        -0.8647,  0.0232, -0.8497, -0.1368,  0.8938,  0.0000,  0.7224,  0.1350,
-        -0.0992,  1.1476, -0.1670,  0.7284,  0.6425,  0.4515,  0.7073,  0.0663,
-         0.8090, -0.5992,  0.4038, -0.0688,  0.5143,  0.2458,  0.3057,  0.4591,
-         1.1309,  0.9683, -0.3125, -0.6936,  1.2092, -0.2963, -0.2008, -0.0061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8640,  0.5320,  0.1126,  0.1348,  0.5654,  0.7066, -0.4646, -1.0361,
-        -0.4352, -0.7581, -0.5252, -0.3422, -0.0782,  0.2571, -1.2543, -0.5938,
-        -0.9410,  0.6455, -0.9536, -0.5214, -0.3666,  0.0523, -0.6580, -0.5772,
-         0.5655,  0.2158, -0.2773, -0.4560, -0.2098,  1.5758, -0.8074, -0.7258,
-        -0.8647,  0.0232, -0.8497, -0.1368,  0.8938,  0.0000,  0.7224,  0.1350,
-        -0.0992,  1.1476, -0.1670,  0.7284,  0.6425,  0.4515,  0.7073,  0.0663,
-         0.8090, -0.5992,  0.4038, -0.0688,  0.5143,  0.2458,  0.3057,  0.4591,
-         1.1309,  0.9683, -0.3125, -0.6936,  1.2092, -0.2963, -0.2008, -0.0061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8656,  0.5115,  0.0939,  0.1335,  0.5718,  0.7014, -0.4412, -1.0328,
-        -0.4289, -0.7569, -0.5427, -0.3206, -0.1134,  0.2304, -1.2674, -0.5922,
-        -0.9348,  0.6377, -0.9505, -0.5183, -0.4073,  0.1053, -0.6441, -0.6184,
-         0.5584,  0.1972, -0.2746, -0.4430, -0.1966,  1.5780, -0.8044, -0.7284,
-        -0.8676,  0.0553, -0.8461, -0.1422,  0.9068,  0.0000,  0.7189,  0.1351,
-        -0.0901,  1.1528, -0.1569,  0.7295,  0.6330,  0.4215,  0.7081,  0.0564,
-         0.8018, -0.5989,  0.3900, -0.0458,  0.5043,  0.2656,  0.2728,  0.4339,
-         1.1404,  0.9667, -0.3008, -0.6888,  1.2145, -0.2633, -0.1884,  0.0172],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8656,  0.5115,  0.0939,  0.1335,  0.5718,  0.7014, -0.4412, -1.0328,
-        -0.4289, -0.7569, -0.5427, -0.3206, -0.1134,  0.2304, -1.2674, -0.5922,
-        -0.9348,  0.6377, -0.9505, -0.5183, -0.4073,  0.1053, -0.6441, -0.6184,
-         0.5584,  0.1972, -0.2746, -0.4430, -0.1966,  1.5780, -0.8044, -0.7284,
-        -0.8676,  0.0553, -0.8461, -0.1422,  0.9068,  0.0000,  0.7189,  0.1351,
-        -0.0901,  1.1528, -0.1569,  0.7295,  0.6330,  0.4215,  0.7081,  0.0564,
-         0.8018, -0.5989,  0.3900, -0.0458,  0.5043,  0.2656,  0.2728,  0.4339,
-         1.1404,  0.9667, -0.3008, -0.6888,  1.2145, -0.2633, -0.1884,  0.0172],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8656,  0.5115,  0.0939,  0.1335,  0.5718,  0.7014, -0.4412, -1.0328,
-        -0.4289, -0.7569, -0.5427, -0.3206, -0.1134,  0.2304, -1.2674, -0.5922,
-        -0.9348,  0.6377, -0.9505, -0.5183, -0.4073,  0.1053, -0.6441, -0.6184,
-         0.5584,  0.1972, -0.2746, -0.4430, -0.1966,  1.5780, -0.8044, -0.7284,
-        -0.8676,  0.0553, -0.8461, -0.1422,  0.9068,  0.0000,  0.7189,  0.1351,
-        -0.0901,  1.1528, -0.1569,  0.7295,  0.6330,  0.4215,  0.7081,  0.0564,
-         0.8018, -0.5989,  0.3900, -0.0458,  0.5043,  0.2656,  0.2728,  0.4339,
-         1.1404,  0.9667, -0.3008, -0.6888,  1.2145, -0.2633, -0.1884,  0.0172],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8669,  0.4829,  0.0972,  0.1472,  0.5625,  0.6824, -0.4260, -1.0295,
-        -0.4305, -0.7566, -0.5594, -0.2943, -0.0971,  0.2563, -1.2820, -0.5911,
-        -0.9294,  0.6229, -0.9481, -0.5163, -0.4501,  0.1345, -0.6443, -0.6559,
-         0.5519,  0.1859, -0.2537, -0.4325, -0.1762,  1.5793, -0.7995, -0.7330,
-        -0.8718,  0.0407, -0.8443, -0.1425,  0.9189,  0.0000,  0.7189,  0.1953,
-        -0.0826,  1.1557, -0.1064,  0.7344,  0.6248,  0.3809,  0.7055,  0.0529,
-         0.7995, -0.5960,  0.3688, -0.0358,  0.4932,  0.2845,  0.2478,  0.4123,
-         1.1489,  0.9678, -0.2617, -0.6868,  1.2202, -0.2200, -0.1738,  0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8669,  0.4829,  0.0972,  0.1472,  0.5625,  0.6824, -0.4260, -1.0295,
-        -0.4305, -0.7566, -0.5594, -0.2943, -0.0971,  0.2563, -1.2820, -0.5911,
-        -0.9294,  0.6229, -0.9481, -0.5163, -0.4501,  0.1345, -0.6443, -0.6559,
-         0.5519,  0.1859, -0.2537, -0.4325, -0.1762,  1.5793, -0.7995, -0.7330,
-        -0.8718,  0.0407, -0.8443, -0.1425,  0.9189,  0.0000,  0.7189,  0.1953,
-        -0.0826,  1.1557, -0.1064,  0.7344,  0.6248,  0.3809,  0.7055,  0.0529,
-         0.7995, -0.5960,  0.3688, -0.0358,  0.4932,  0.2845,  0.2478,  0.4123,
-         1.1489,  0.9678, -0.2617, -0.6868,  1.2202, -0.2200, -0.1738,  0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8669,  0.4829,  0.0972,  0.1472,  0.5625,  0.6824, -0.4260, -1.0295,
-        -0.4305, -0.7566, -0.5594, -0.2943, -0.0971,  0.2563, -1.2820, -0.5911,
-        -0.9294,  0.6229, -0.9481, -0.5163, -0.4501,  0.1345, -0.6443, -0.6559,
-         0.5519,  0.1859, -0.2537, -0.4325, -0.1762,  1.5793, -0.7995, -0.7330,
-        -0.8718,  0.0407, -0.8443, -0.1425,  0.9189,  0.0000,  0.7189,  0.1953,
-        -0.0826,  1.1557, -0.1064,  0.7344,  0.6248,  0.3809,  0.7055,  0.0529,
-         0.7995, -0.5960,  0.3688, -0.0358,  0.4932,  0.2845,  0.2478,  0.4123,
-         1.1489,  0.9678, -0.2617, -0.6868,  1.2202, -0.2200, -0.1738,  0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8688,  0.4379,  0.0994,  0.1414,  0.5560,  0.6651, -0.4047, -1.0253,
-        -0.4308, -0.7580, -0.5686, -0.2633, -0.0554,  0.2868, -1.2924, -0.5931,
-        -0.9239,  0.6103, -0.9454, -0.5114, -0.4967,  0.1284, -0.6470, -0.6966,
-         0.5330,  0.1831, -0.2245, -0.4145, -0.1705,  1.5809, -0.7938, -0.7410,
-        -0.8770, -0.0184, -0.8423, -0.1393,  0.9317,  0.0000,  0.7166,  0.2511,
-        -0.0946,  1.1601, -0.0619,  0.7387,  0.6193,  0.3414,  0.7009,  0.0547,
-         0.8046, -0.5957,  0.3517, -0.0678,  0.4722,  0.2912,  0.2280,  0.3978,
-         1.1567,  0.9689, -0.2237, -0.6884,  1.2253, -0.1650, -0.1686,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8688,  0.4379,  0.0994,  0.1414,  0.5560,  0.6651, -0.4047, -1.0253,
-        -0.4308, -0.7580, -0.5686, -0.2633, -0.0554,  0.2868, -1.2924, -0.5931,
-        -0.9239,  0.6103, -0.9454, -0.5114, -0.4967,  0.1284, -0.6470, -0.6966,
-         0.5330,  0.1831, -0.2245, -0.4145, -0.1705,  1.5809, -0.7938, -0.7410,
-        -0.8770, -0.0184, -0.8423, -0.1393,  0.9317,  0.0000,  0.7166,  0.2511,
-        -0.0946,  1.1601, -0.0619,  0.7387,  0.6193,  0.3414,  0.7009,  0.0547,
-         0.8046, -0.5957,  0.3517, -0.0678,  0.4722,  0.2912,  0.2280,  0.3978,
-         1.1567,  0.9689, -0.2237, -0.6884,  1.2253, -0.1650, -0.1686,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8688,  0.4379,  0.0994,  0.1414,  0.5560,  0.6651, -0.4047, -1.0253,
-        -0.4308, -0.7580, -0.5686, -0.2633, -0.0554,  0.2868, -1.2924, -0.5931,
-        -0.9239,  0.6103, -0.9454, -0.5114, -0.4967,  0.1284, -0.6470, -0.6966,
-         0.5330,  0.1831, -0.2245, -0.4145, -0.1705,  1.5809, -0.7938, -0.7410,
-        -0.8770, -0.0184, -0.8423, -0.1393,  0.9317,  0.0000,  0.7166,  0.2511,
-        -0.0946,  1.1601, -0.0619,  0.7387,  0.6193,  0.3414,  0.7009,  0.0547,
-         0.8046, -0.5957,  0.3517, -0.0678,  0.4722,  0.2912,  0.2280,  0.3978,
-         1.1567,  0.9689, -0.2237, -0.6884,  1.2253, -0.1650, -0.1686,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8696,  0.3970,  0.1143,  0.1278,  0.5514,  0.6516, -0.3824, -1.0206,
-        -0.4309, -0.7595, -0.5691, -0.2167,  0.0160,  0.3043, -1.3013, -0.5937,
-        -0.9225,  0.5987, -0.9449, -0.5065, -0.5425,  0.1188, -0.6465, -0.7227,
-         0.5062,  0.1883, -0.1891, -0.4015, -0.1601,  1.5817, -0.7894, -0.7479,
-        -0.8811, -0.0756, -0.8405, -0.1365,  0.9439,  0.0000,  0.7094,  0.3116,
-        -0.1168,  1.1625,  0.0049,  0.7410,  0.6145,  0.3042,  0.6945,  0.0510,
-         0.8084, -0.5940,  0.3379, -0.0902,  0.4497,  0.2889,  0.2170,  0.3915,
-         1.1637,  0.9719, -0.1641, -0.6805,  1.2303, -0.1053, -0.1621,  0.1267],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8696,  0.3970,  0.1143,  0.1278,  0.5514,  0.6516, -0.3824, -1.0206,
-        -0.4309, -0.7595, -0.5691, -0.2167,  0.0160,  0.3043, -1.3013, -0.5937,
-        -0.9225,  0.5987, -0.9449, -0.5065, -0.5425,  0.1188, -0.6465, -0.7227,
-         0.5062,  0.1883, -0.1891, -0.4015, -0.1601,  1.5817, -0.7894, -0.7479,
-        -0.8811, -0.0756, -0.8405, -0.1365,  0.9439,  0.0000,  0.7094,  0.3116,
-        -0.1168,  1.1625,  0.0049,  0.7410,  0.6145,  0.3042,  0.6945,  0.0510,
-         0.8084, -0.5940,  0.3379, -0.0902,  0.4497,  0.2889,  0.2170,  0.3915,
-         1.1637,  0.9719, -0.1641, -0.6805,  1.2303, -0.1053, -0.1621,  0.1267],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8696,  0.3970,  0.1143,  0.1278,  0.5514,  0.6516, -0.3824, -1.0206,
-        -0.4309, -0.7595, -0.5691, -0.2167,  0.0160,  0.3043, -1.3013, -0.5937,
-        -0.9225,  0.5987, -0.9449, -0.5065, -0.5425,  0.1188, -0.6465, -0.7227,
-         0.5062,  0.1883, -0.1891, -0.4015, -0.1601,  1.5817, -0.7894, -0.7479,
-        -0.8811, -0.0756, -0.8405, -0.1365,  0.9439,  0.0000,  0.7094,  0.3116,
-        -0.1168,  1.1625,  0.0049,  0.7410,  0.6145,  0.3042,  0.6945,  0.0510,
-         0.8084, -0.5940,  0.3379, -0.0902,  0.4497,  0.2889,  0.2170,  0.3915,
-         1.1637,  0.9719, -0.1641, -0.6805,  1.2303, -0.1053, -0.1621,  0.1267],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8698,  0.3611,  0.1099,  0.1090,  0.5275,  0.6427, -0.3652, -1.0166,
-        -0.4233, -0.7580, -0.5605, -0.1878,  0.0627,  0.3002, -1.3084, -0.5952,
-        -0.9189,  0.5894, -0.9454, -0.5083, -0.5770,  0.1144, -0.6397, -0.7255,
-         0.4741,  0.2004, -0.1680, -0.3891, -0.1354,  1.5824, -0.7827, -0.7523,
-        -0.8859, -0.0990, -0.8373, -0.1441,  0.9571,  0.0000,  0.7045,  0.3622,
-        -0.1115,  1.1636,  0.0488,  0.7433,  0.6091,  0.2751,  0.6872,  0.0442,
-         0.8123, -0.5903,  0.3236, -0.0826,  0.4268,  0.2764,  0.2027,  0.3790,
-         1.1707,  0.9743, -0.0914, -0.6543,  1.2349, -0.0513, -0.1588,  0.1596],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8698,  0.3611,  0.1099,  0.1090,  0.5275,  0.6427, -0.3652, -1.0166,
-        -0.4233, -0.7580, -0.5605, -0.1878,  0.0627,  0.3002, -1.3084, -0.5952,
-        -0.9189,  0.5894, -0.9454, -0.5083, -0.5770,  0.1144, -0.6397, -0.7255,
-         0.4741,  0.2004, -0.1680, -0.3891, -0.1354,  1.5824, -0.7827, -0.7523,
-        -0.8859, -0.0990, -0.8373, -0.1441,  0.9571,  0.0000,  0.7045,  0.3622,
-        -0.1115,  1.1636,  0.0488,  0.7433,  0.6091,  0.2751,  0.6872,  0.0442,
-         0.8123, -0.5903,  0.3236, -0.0826,  0.4268,  0.2764,  0.2027,  0.3790,
-         1.1707,  0.9743, -0.0914, -0.6543,  1.2349, -0.0513, -0.1588,  0.1596],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8698,  0.3611,  0.1099,  0.1090,  0.5275,  0.6427, -0.3652, -1.0166,
-        -0.4233, -0.7580, -0.5605, -0.1878,  0.0627,  0.3002, -1.3084, -0.5952,
-        -0.9189,  0.5894, -0.9454, -0.5083, -0.5770,  0.1144, -0.6397, -0.7255,
-         0.4741,  0.2004, -0.1680, -0.3891, -0.1354,  1.5824, -0.7827, -0.7523,
-        -0.8859, -0.0990, -0.8373, -0.1441,  0.9571,  0.0000,  0.7045,  0.3622,
-        -0.1115,  1.1636,  0.0488,  0.7433,  0.6091,  0.2751,  0.6872,  0.0442,
-         0.8123, -0.5903,  0.3236, -0.0826,  0.4268,  0.2764,  0.2027,  0.3790,
-         1.1707,  0.9743, -0.0914, -0.6543,  1.2349, -0.0513, -0.1588,  0.1596],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8704,  0.3286,  0.1077,  0.0956,  0.4878,  0.6359, -0.3478, -1.0133,
-        -0.4153, -0.7529, -0.5467, -0.1702,  0.1164,  0.3137, -1.3147, -0.5960,
-        -0.9108,  0.5797, -0.9463, -0.5160, -0.5981,  0.1140, -0.6267, -0.7357,
-         0.4419,  0.2234, -0.1568, -0.3778, -0.1140,  1.5825, -0.7775, -0.7555,
-        -0.8920, -0.0965, -0.8347, -0.1692,  0.9711,  0.0000,  0.7129,  0.4328,
-        -0.0760,  1.1650,  0.0857,  0.7479,  0.6016,  0.2381,  0.6803,  0.0424,
-         0.8087, -0.5866,  0.3149, -0.0707,  0.3964,  0.2709,  0.1915,  0.3780,
-         1.1764,  0.9833, -0.0193, -0.6249,  1.2403,  0.0029, -0.1500,  0.1741],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8704,  0.3286,  0.1077,  0.0956,  0.4878,  0.6359, -0.3478, -1.0133,
-        -0.4153, -0.7529, -0.5467, -0.1702,  0.1164,  0.3137, -1.3147, -0.5960,
-        -0.9108,  0.5797, -0.9463, -0.5160, -0.5981,  0.1140, -0.6267, -0.7357,
-         0.4419,  0.2234, -0.1568, -0.3778, -0.1140,  1.5825, -0.7775, -0.7555,
-        -0.8920, -0.0965, -0.8347, -0.1692,  0.9711,  0.0000,  0.7129,  0.4328,
-        -0.0760,  1.1650,  0.0857,  0.7479,  0.6016,  0.2381,  0.6803,  0.0424,
-         0.8087, -0.5866,  0.3149, -0.0707,  0.3964,  0.2709,  0.1915,  0.3780,
-         1.1764,  0.9833, -0.0193, -0.6249,  1.2403,  0.0029, -0.1500,  0.1741],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8704,  0.3286,  0.1077,  0.0956,  0.4878,  0.6359, -0.3478, -1.0133,
-        -0.4153, -0.7529, -0.5467, -0.1702,  0.1164,  0.3137, -1.3147, -0.5960,
-        -0.9108,  0.5797, -0.9463, -0.5160, -0.5981,  0.1140, -0.6267, -0.7357,
-         0.4419,  0.2234, -0.1568, -0.3778, -0.1140,  1.5825, -0.7775, -0.7555,
-        -0.8920, -0.0965, -0.8347, -0.1692,  0.9711,  0.0000,  0.7129,  0.4328,
-        -0.0760,  1.1650,  0.0857,  0.7479,  0.6016,  0.2381,  0.6803,  0.0424,
-         0.8087, -0.5866,  0.3149, -0.0707,  0.3964,  0.2709,  0.1915,  0.3780,
-         1.1764,  0.9833, -0.0193, -0.6249,  1.2403,  0.0029, -0.1500,  0.1741],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8679,  0.3139,  0.0929,  0.0541,  0.4245,  0.6272, -0.3308, -1.0107,
-        -0.4154, -0.7463, -0.5303, -0.1639,  0.1226,  0.3212, -1.3224, -0.5940,
-        -0.9024,  0.5757, -0.9478, -0.5246, -0.6207,  0.1331, -0.6115, -0.7372,
-         0.4204,  0.2472, -0.1458, -0.3620, -0.1023,  1.5820, -0.7730, -0.7567,
-        -0.8995, -0.1088, -0.8316, -0.1923,  0.9837,  0.0000,  0.7192,  0.4981,
-        -0.0179,  1.1649,  0.0936,  0.7506,  0.5945,  0.2056,  0.6726,  0.0391,
-         0.8051, -0.5816,  0.3122, -0.0646,  0.3859,  0.2674,  0.2138,  0.3733,
-         1.1825,  0.9923,  0.0642, -0.6004,  1.2448,  0.0722, -0.1372,  0.1711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8679,  0.3139,  0.0929,  0.0541,  0.4245,  0.6272, -0.3308, -1.0107,
-        -0.4154, -0.7463, -0.5303, -0.1639,  0.1226,  0.3212, -1.3224, -0.5940,
-        -0.9024,  0.5757, -0.9478, -0.5246, -0.6207,  0.1331, -0.6115, -0.7372,
-         0.4204,  0.2472, -0.1458, -0.3620, -0.1023,  1.5820, -0.7730, -0.7567,
-        -0.8995, -0.1088, -0.8316, -0.1923,  0.9837,  0.0000,  0.7192,  0.4981,
-        -0.0179,  1.1649,  0.0936,  0.7506,  0.5945,  0.2056,  0.6726,  0.0391,
-         0.8051, -0.5816,  0.3122, -0.0646,  0.3859,  0.2674,  0.2138,  0.3733,
-         1.1825,  0.9923,  0.0642, -0.6004,  1.2448,  0.0722, -0.1372,  0.1711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8679,  0.3139,  0.0929,  0.0541,  0.4245,  0.6272, -0.3308, -1.0107,
-        -0.4154, -0.7463, -0.5303, -0.1639,  0.1226,  0.3212, -1.3224, -0.5940,
-        -0.9024,  0.5757, -0.9478, -0.5246, -0.6207,  0.1331, -0.6115, -0.7372,
-         0.4204,  0.2472, -0.1458, -0.3620, -0.1023,  1.5820, -0.7730, -0.7567,
-        -0.8995, -0.1088, -0.8316, -0.1923,  0.9837,  0.0000,  0.7192,  0.4981,
-        -0.0179,  1.1649,  0.0936,  0.7506,  0.5945,  0.2056,  0.6726,  0.0391,
-         0.8051, -0.5816,  0.3122, -0.0646,  0.3859,  0.2674,  0.2138,  0.3733,
-         1.1825,  0.9923,  0.0642, -0.6004,  1.2448,  0.0722, -0.1372,  0.1711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8672,  0.3030,  0.0648,  0.0193,  0.3642,  0.6214, -0.3156, -1.0073,
-        -0.4101, -0.7380, -0.5107, -0.1470,  0.1068,  0.3307, -1.3287, -0.5852,
-        -0.8922,  0.5744, -0.9503, -0.5357, -0.6376,  0.1431, -0.5926, -0.7413,
-         0.3965,  0.2771, -0.1334, -0.3441, -0.1018,  1.5810, -0.7673, -0.7574,
-        -0.9065, -0.1061, -0.8263, -0.2126,  0.9943,  0.0000,  0.7328,  0.5667,
-         0.0477,  1.1632,  0.0676,  0.7547,  0.5879,  0.1770,  0.6652,  0.0417,
-         0.8048, -0.5754,  0.3123, -0.0399,  0.3696,  0.2673,  0.2639,  0.3743,
-         1.1871,  1.0038,  0.1352, -0.5730,  1.2491,  0.1186, -0.1312,  0.1505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8672,  0.3030,  0.0648,  0.0193,  0.3642,  0.6214, -0.3156, -1.0073,
-        -0.4101, -0.7380, -0.5107, -0.1470,  0.1068,  0.3307, -1.3287, -0.5852,
-        -0.8922,  0.5744, -0.9503, -0.5357, -0.6376,  0.1431, -0.5926, -0.7413,
-         0.3965,  0.2771, -0.1334, -0.3441, -0.1018,  1.5810, -0.7673, -0.7574,
-        -0.9065, -0.1061, -0.8263, -0.2126,  0.9943,  0.0000,  0.7328,  0.5667,
-         0.0477,  1.1632,  0.0676,  0.7547,  0.5879,  0.1770,  0.6652,  0.0417,
-         0.8048, -0.5754,  0.3123, -0.0399,  0.3696,  0.2673,  0.2639,  0.3743,
-         1.1871,  1.0038,  0.1352, -0.5730,  1.2491,  0.1186, -0.1312,  0.1505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8672,  0.3030,  0.0648,  0.0193,  0.3642,  0.6214, -0.3156, -1.0073,
-        -0.4101, -0.7380, -0.5107, -0.1470,  0.1068,  0.3307, -1.3287, -0.5852,
-        -0.8922,  0.5744, -0.9503, -0.5357, -0.6376,  0.1431, -0.5926, -0.7413,
-         0.3965,  0.2771, -0.1334, -0.3441, -0.1018,  1.5810, -0.7673, -0.7574,
-        -0.9065, -0.1061, -0.8263, -0.2126,  0.9943,  0.0000,  0.7328,  0.5667,
-         0.0477,  1.1632,  0.0676,  0.7547,  0.5879,  0.1770,  0.6652,  0.0417,
-         0.8048, -0.5754,  0.3123, -0.0399,  0.3696,  0.2673,  0.2639,  0.3743,
-         1.1871,  1.0038,  0.1352, -0.5730,  1.2491,  0.1186, -0.1312,  0.1505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8614,  0.3050,  0.0419, -0.0145,  0.3315,  0.6065, -0.3018, -1.0037,
-        -0.4066, -0.7240, -0.4917, -0.1261,  0.0862,  0.3117, -1.3317, -0.5752,
-        -0.8818,  0.5694, -0.9537, -0.5408, -0.6529,  0.1675, -0.5670, -0.7463,
-         0.3711,  0.2870, -0.1297, -0.3210, -0.1083,  1.5799, -0.7640, -0.7552,
-        -0.9134, -0.0812, -0.8187, -0.2431,  1.0043,  0.0000,  0.7388,  0.6328,
-         0.1135,  1.1619,  0.0513,  0.7566,  0.5826,  0.1513,  0.6596,  0.0444,
-         0.8066, -0.5697,  0.3179,  0.0136,  0.3618,  0.2705,  0.3170,  0.3688,
-         1.1913,  1.0124,  0.1925, -0.5293,  1.2533,  0.1649, -0.1249,  0.1178],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8614,  0.3050,  0.0419, -0.0145,  0.3315,  0.6065, -0.3018, -1.0037,
-        -0.4066, -0.7240, -0.4917, -0.1261,  0.0862,  0.3117, -1.3317, -0.5752,
-        -0.8818,  0.5694, -0.9537, -0.5408, -0.6529,  0.1675, -0.5670, -0.7463,
-         0.3711,  0.2870, -0.1297, -0.3210, -0.1083,  1.5799, -0.7640, -0.7552,
-        -0.9134, -0.0812, -0.8187, -0.2431,  1.0043,  0.0000,  0.7388,  0.6328,
-         0.1135,  1.1619,  0.0513,  0.7566,  0.5826,  0.1513,  0.6596,  0.0444,
-         0.8066, -0.5697,  0.3179,  0.0136,  0.3618,  0.2705,  0.3170,  0.3688,
-         1.1913,  1.0124,  0.1925, -0.5293,  1.2533,  0.1649, -0.1249,  0.1178],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8614,  0.3050,  0.0419, -0.0145,  0.3315,  0.6065, -0.3018, -1.0037,
-        -0.4066, -0.7240, -0.4917, -0.1261,  0.0862,  0.3117, -1.3317, -0.5752,
-        -0.8818,  0.5694, -0.9537, -0.5408, -0.6529,  0.1675, -0.5670, -0.7463,
-         0.3711,  0.2870, -0.1297, -0.3210, -0.1083,  1.5799, -0.7640, -0.7552,
-        -0.9134, -0.0812, -0.8187, -0.2431,  1.0043,  0.0000,  0.7388,  0.6328,
-         0.1135,  1.1619,  0.0513,  0.7566,  0.5826,  0.1513,  0.6596,  0.0444,
-         0.8066, -0.5697,  0.3179,  0.0136,  0.3618,  0.2705,  0.3170,  0.3688,
-         1.1913,  1.0124,  0.1925, -0.5293,  1.2533,  0.1649, -0.1249,  0.1178],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8566,  0.3113,  0.0420, -0.0425,  0.3402,  0.5804, -0.3005, -0.9985,
-        -0.4009, -0.7087, -0.4680, -0.0930,  0.0949,  0.2933, -1.3351, -0.5605,
-        -0.8733,  0.5662, -0.9572, -0.5380, -0.6701,  0.1948, -0.5460, -0.7614,
-         0.3433,  0.2954, -0.1110, -0.3044, -0.0967,  1.5793, -0.7585, -0.7508,
-        -0.9182, -0.0417, -0.8074, -0.2594,  1.0123,  0.0000,  0.7393,  0.6916,
-         0.1578,  1.1618,  0.0555,  0.7544,  0.5799,  0.1336,  0.6559,  0.0417,
-         0.8091, -0.5644,  0.3254,  0.1106,  0.3542,  0.2757,  0.3596,  0.3612,
-         1.1942,  1.0189,  0.2592, -0.4954,  1.2575,  0.2151, -0.1040,  0.1039],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8566,  0.3113,  0.0420, -0.0425,  0.3402,  0.5804, -0.3005, -0.9985,
-        -0.4009, -0.7087, -0.4680, -0.0930,  0.0949,  0.2933, -1.3351, -0.5605,
-        -0.8733,  0.5662, -0.9572, -0.5380, -0.6701,  0.1948, -0.5460, -0.7614,
-         0.3433,  0.2954, -0.1110, -0.3044, -0.0967,  1.5793, -0.7585, -0.7508,
-        -0.9182, -0.0417, -0.8074, -0.2594,  1.0123,  0.0000,  0.7393,  0.6916,
-         0.1578,  1.1618,  0.0555,  0.7544,  0.5799,  0.1336,  0.6559,  0.0417,
-         0.8091, -0.5644,  0.3254,  0.1106,  0.3542,  0.2757,  0.3596,  0.3612,
-         1.1942,  1.0189,  0.2592, -0.4954,  1.2575,  0.2151, -0.1040,  0.1039],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8566,  0.3113,  0.0420, -0.0425,  0.3402,  0.5804, -0.3005, -0.9985,
-        -0.4009, -0.7087, -0.4680, -0.0930,  0.0949,  0.2933, -1.3351, -0.5605,
-        -0.8733,  0.5662, -0.9572, -0.5380, -0.6701,  0.1948, -0.5460, -0.7614,
-         0.3433,  0.2954, -0.1110, -0.3044, -0.0967,  1.5793, -0.7585, -0.7508,
-        -0.9182, -0.0417, -0.8074, -0.2594,  1.0123,  0.0000,  0.7393,  0.6916,
-         0.1578,  1.1618,  0.0555,  0.7544,  0.5799,  0.1336,  0.6559,  0.0417,
-         0.8091, -0.5644,  0.3254,  0.1106,  0.3542,  0.2757,  0.3596,  0.3612,
-         1.1942,  1.0189,  0.2592, -0.4954,  1.2575,  0.2151, -0.1040,  0.1039],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8527,  0.3075,  0.0538, -0.0763,  0.3537,  0.5605, -0.2848, -0.9932,
-        -0.3962, -0.6895, -0.4484, -0.0556,  0.1198,  0.2800, -1.3415, -0.5452,
-        -0.8655,  0.5664, -0.9609, -0.5298, -0.6882,  0.2132, -0.5262, -0.7765,
-         0.3103,  0.2877, -0.0869, -0.2927, -0.0813,  1.5782, -0.7500, -0.7489,
-        -0.9225, -0.0154, -0.7969, -0.2784,  1.0154,  0.0000,  0.7373,  0.7465,
-         0.1914,  1.1638,  0.0705,  0.7518,  0.5802,  0.1237,  0.6503,  0.0413,
-         0.8054, -0.5587,  0.3282,  0.1911,  0.3487,  0.2673,  0.4117,  0.3498,
-         1.1965,  1.0253,  0.3289, -0.4747,  1.2622,  0.2682, -0.0792,  0.1003],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8527,  0.3075,  0.0538, -0.0763,  0.3537,  0.5605, -0.2848, -0.9932,
-        -0.3962, -0.6895, -0.4484, -0.0556,  0.1198,  0.2800, -1.3415, -0.5452,
-        -0.8655,  0.5664, -0.9609, -0.5298, -0.6882,  0.2132, -0.5262, -0.7765,
-         0.3103,  0.2877, -0.0869, -0.2927, -0.0813,  1.5782, -0.7500, -0.7489,
-        -0.9225, -0.0154, -0.7969, -0.2784,  1.0154,  0.0000,  0.7373,  0.7465,
-         0.1914,  1.1638,  0.0705,  0.7518,  0.5802,  0.1237,  0.6503,  0.0413,
-         0.8054, -0.5587,  0.3282,  0.1911,  0.3487,  0.2673,  0.4117,  0.3498,
-         1.1965,  1.0253,  0.3289, -0.4747,  1.2622,  0.2682, -0.0792,  0.1003],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8527,  0.3075,  0.0538, -0.0763,  0.3537,  0.5605, -0.2848, -0.9932,
-        -0.3962, -0.6895, -0.4484, -0.0556,  0.1198,  0.2800, -1.3415, -0.5452,
-        -0.8655,  0.5664, -0.9609, -0.5298, -0.6882,  0.2132, -0.5262, -0.7765,
-         0.3103,  0.2877, -0.0869, -0.2927, -0.0813,  1.5782, -0.7500, -0.7489,
-        -0.9225, -0.0154, -0.7969, -0.2784,  1.0154,  0.0000,  0.7373,  0.7465,
-         0.1914,  1.1638,  0.0705,  0.7518,  0.5802,  0.1237,  0.6503,  0.0413,
-         0.8054, -0.5587,  0.3282,  0.1911,  0.3487,  0.2673,  0.4117,  0.3498,
-         1.1965,  1.0253,  0.3289, -0.4747,  1.2622,  0.2682, -0.0792,  0.1003],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8489,  0.3025,  0.0578, -0.1292,  0.3443,  0.5440, -0.2672, -0.9873,
-        -0.4080, -0.6745, -0.4315,  0.0092,  0.1380,  0.2704, -1.3479, -0.5388,
-        -0.8563,  0.5678, -0.9633, -0.5123, -0.7147,  0.2302, -0.5135, -0.7945,
-         0.2699,  0.2711, -0.0704, -0.2985, -0.0735,  1.5769, -0.7365, -0.7470,
-        -0.9264, -0.0130, -0.7855, -0.2727,  1.0135,  0.0000,  0.7217,  0.7883,
-         0.2323,  1.1661,  0.0815,  0.7420,  0.5806,  0.1357,  0.6404,  0.0411,
-         0.8017, -0.5518,  0.3246,  0.2735,  0.3493,  0.2570,  0.4444,  0.3352,
-         1.1985,  1.0291,  0.3760, -0.4897,  1.2661,  0.3257, -0.0445,  0.1112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8489,  0.3025,  0.0578, -0.1292,  0.3443,  0.5440, -0.2672, -0.9873,
-        -0.4080, -0.6745, -0.4315,  0.0092,  0.1380,  0.2704, -1.3479, -0.5388,
-        -0.8563,  0.5678, -0.9633, -0.5123, -0.7147,  0.2302, -0.5135, -0.7945,
-         0.2699,  0.2711, -0.0704, -0.2985, -0.0735,  1.5769, -0.7365, -0.7470,
-        -0.9264, -0.0130, -0.7855, -0.2727,  1.0135,  0.0000,  0.7217,  0.7883,
-         0.2323,  1.1661,  0.0815,  0.7420,  0.5806,  0.1357,  0.6404,  0.0411,
-         0.8017, -0.5518,  0.3246,  0.2735,  0.3493,  0.2570,  0.4444,  0.3352,
-         1.1985,  1.0291,  0.3760, -0.4897,  1.2661,  0.3257, -0.0445,  0.1112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8489,  0.3025,  0.0578, -0.1292,  0.3443,  0.5440, -0.2672, -0.9873,
-        -0.4080, -0.6745, -0.4315,  0.0092,  0.1380,  0.2704, -1.3479, -0.5388,
-        -0.8563,  0.5678, -0.9633, -0.5123, -0.7147,  0.2302, -0.5135, -0.7945,
-         0.2699,  0.2711, -0.0704, -0.2985, -0.0735,  1.5769, -0.7365, -0.7470,
-        -0.9264, -0.0130, -0.7855, -0.2727,  1.0135,  0.0000,  0.7217,  0.7883,
-         0.2323,  1.1661,  0.0815,  0.7420,  0.5806,  0.1357,  0.6404,  0.0411,
-         0.8017, -0.5518,  0.3246,  0.2735,  0.3493,  0.2570,  0.4444,  0.3352,
-         1.1985,  1.0291,  0.3760, -0.4897,  1.2661,  0.3257, -0.0445,  0.1112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8425,  0.2976,  0.0388, -0.1585,  0.3094,  0.5290, -0.2437, -0.9823,
-        -0.4177, -0.6654, -0.4195,  0.0475,  0.1327,  0.2572, -1.3518, -0.5333,
-        -0.8482,  0.5650, -0.9676, -0.5022, -0.7319,  0.2373, -0.5030, -0.8043,
-         0.2128,  0.2431, -0.0533, -0.2969, -0.0725,  1.5760, -0.7242, -0.7457,
-        -0.9287,  0.0127, -0.7745, -0.2500,  1.0096,  0.0000,  0.6997,  0.8206,
-         0.2632,  1.1646,  0.0770,  0.7333,  0.5830,  0.1496,  0.6319,  0.0279,
-         0.8071, -0.5447,  0.3186,  0.2982,  0.3509,  0.2218,  0.4740,  0.3247,
-         1.2011,  1.0300,  0.4202, -0.4719,  1.2688,  0.3675, -0.0295,  0.1025],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8425,  0.2976,  0.0388, -0.1585,  0.3094,  0.5290, -0.2437, -0.9823,
-        -0.4177, -0.6654, -0.4195,  0.0475,  0.1327,  0.2572, -1.3518, -0.5333,
-        -0.8482,  0.5650, -0.9676, -0.5022, -0.7319,  0.2373, -0.5030, -0.8043,
-         0.2128,  0.2431, -0.0533, -0.2969, -0.0725,  1.5760, -0.7242, -0.7457,
-        -0.9287,  0.0127, -0.7745, -0.2500,  1.0096,  0.0000,  0.6997,  0.8206,
-         0.2632,  1.1646,  0.0770,  0.7333,  0.5830,  0.1496,  0.6319,  0.0279,
-         0.8071, -0.5447,  0.3186,  0.2982,  0.3509,  0.2218,  0.4740,  0.3247,
-         1.2011,  1.0300,  0.4202, -0.4719,  1.2688,  0.3675, -0.0295,  0.1025],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8425,  0.2976,  0.0388, -0.1585,  0.3094,  0.5290, -0.2437, -0.9823,
-        -0.4177, -0.6654, -0.4195,  0.0475,  0.1327,  0.2572, -1.3518, -0.5333,
-        -0.8482,  0.5650, -0.9676, -0.5022, -0.7319,  0.2373, -0.5030, -0.8043,
-         0.2128,  0.2431, -0.0533, -0.2969, -0.0725,  1.5760, -0.7242, -0.7457,
-        -0.9287,  0.0127, -0.7745, -0.2500,  1.0096,  0.0000,  0.6997,  0.8206,
-         0.2632,  1.1646,  0.0770,  0.7333,  0.5830,  0.1496,  0.6319,  0.0279,
-         0.8071, -0.5447,  0.3186,  0.2982,  0.3509,  0.2218,  0.4740,  0.3247,
-         1.2011,  1.0300,  0.4202, -0.4719,  1.2688,  0.3675, -0.0295,  0.1025],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8375,  0.2833, -0.0045, -0.1650,  0.2715,  0.5186, -0.2251, -0.9785,
-        -0.4229, -0.6607, -0.4137,  0.0199,  0.0827,  0.2623, -1.3560, -0.5316,
-        -0.8404,  0.5644, -0.9704, -0.4959, -0.7453,  0.2144, -0.4864, -0.8171,
-         0.1376,  0.2319, -0.0331, -0.2929, -0.0817,  1.5764, -0.7156, -0.7394,
-        -0.9297,  0.0513, -0.7637, -0.2271,  1.0110,  0.0000,  0.6878,  0.8470,
-         0.2675,  1.1607,  0.0411,  0.7291,  0.5805,  0.1486,  0.6226,  0.0385,
-         0.8169, -0.5387,  0.3069,  0.2822,  0.3429,  0.1732,  0.4923,  0.3309,
-         1.2040,  1.0328,  0.4562, -0.4657,  1.2727,  0.3888, -0.0144,  0.0939],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8375,  0.2833, -0.0045, -0.1650,  0.2715,  0.5186, -0.2251, -0.9785,
-        -0.4229, -0.6607, -0.4137,  0.0199,  0.0827,  0.2623, -1.3560, -0.5316,
-        -0.8404,  0.5644, -0.9704, -0.4959, -0.7453,  0.2144, -0.4864, -0.8171,
-         0.1376,  0.2319, -0.0331, -0.2929, -0.0817,  1.5764, -0.7156, -0.7394,
-        -0.9297,  0.0513, -0.7637, -0.2271,  1.0110,  0.0000,  0.6878,  0.8470,
-         0.2675,  1.1607,  0.0411,  0.7291,  0.5805,  0.1486,  0.6226,  0.0385,
-         0.8169, -0.5387,  0.3069,  0.2822,  0.3429,  0.1732,  0.4923,  0.3309,
-         1.2040,  1.0328,  0.4562, -0.4657,  1.2727,  0.3888, -0.0144,  0.0939],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8375,  0.2833, -0.0045, -0.1650,  0.2715,  0.5186, -0.2251, -0.9785,
-        -0.4229, -0.6607, -0.4137,  0.0199,  0.0827,  0.2623, -1.3560, -0.5316,
-        -0.8404,  0.5644, -0.9704, -0.4959, -0.7453,  0.2144, -0.4864, -0.8171,
-         0.1376,  0.2319, -0.0331, -0.2929, -0.0817,  1.5764, -0.7156, -0.7394,
-        -0.9297,  0.0513, -0.7637, -0.2271,  1.0110,  0.0000,  0.6878,  0.8470,
-         0.2675,  1.1607,  0.0411,  0.7291,  0.5805,  0.1486,  0.6226,  0.0385,
-         0.8169, -0.5387,  0.3069,  0.2822,  0.3429,  0.1732,  0.4923,  0.3309,
-         1.2040,  1.0328,  0.4562, -0.4657,  1.2727,  0.3888, -0.0144,  0.0939],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.3317e-01,  2.7207e-01, -5.6172e-02, -1.6518e-01,  2.6192e-01,
-         5.1705e-01, -2.3862e-01, -9.7664e-01, -4.2490e-01, -6.5969e-01,
-        -4.0968e-01, -4.4241e-03,  4.9412e-02,  2.7385e-01, -1.3576e+00,
-        -5.3574e-01, -8.3907e-01,  5.5970e-01, -9.7616e-01, -4.9480e-01,
-        -7.5499e-01,  1.6080e-01, -4.7468e-01, -8.2654e-01,  4.3112e-02,
-         2.3072e-01, -1.5595e-02, -2.9040e-01, -8.6214e-02,  1.5775e+00,
-        -7.1273e-01, -7.3383e-01, -9.3075e-01,  5.6695e-02, -7.5314e-01,
-        -2.0307e-01,  1.0139e+00,  0.0000e+00,  6.7313e-01,  8.7261e-01,
-         2.5058e-01,  1.1540e+00,  5.1844e-03,  7.2862e-01,  5.7412e-01,
-         1.4032e-01,  6.1620e-01,  4.9613e-02,  8.2662e-01, -5.3755e-01,
-         2.8893e-01,  2.3396e-01,  3.2864e-01,  1.0818e-01,  5.0813e-01,
-         3.3818e-01,  1.2068e+00,  1.0357e+00,  4.8912e-01, -4.2889e-01,
-         1.2759e+00,  3.9636e-01, -3.2507e-04,  7.9335e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.3317e-01,  2.7207e-01, -5.6172e-02, -1.6518e-01,  2.6192e-01,
-         5.1705e-01, -2.3862e-01, -9.7664e-01, -4.2490e-01, -6.5969e-01,
-        -4.0968e-01, -4.4241e-03,  4.9412e-02,  2.7385e-01, -1.3576e+00,
-        -5.3574e-01, -8.3907e-01,  5.5970e-01, -9.7616e-01, -4.9480e-01,
-        -7.5499e-01,  1.6080e-01, -4.7468e-01, -8.2654e-01,  4.3112e-02,
-         2.3072e-01, -1.5595e-02, -2.9040e-01, -8.6214e-02,  1.5775e+00,
-        -7.1273e-01, -7.3383e-01, -9.3075e-01,  5.6695e-02, -7.5314e-01,
-        -2.0307e-01,  1.0139e+00,  0.0000e+00,  6.7313e-01,  8.7261e-01,
-         2.5058e-01,  1.1540e+00,  5.1844e-03,  7.2862e-01,  5.7412e-01,
-         1.4032e-01,  6.1620e-01,  4.9613e-02,  8.2662e-01, -5.3755e-01,
-         2.8893e-01,  2.3396e-01,  3.2864e-01,  1.0818e-01,  5.0813e-01,
-         3.3818e-01,  1.2068e+00,  1.0357e+00,  4.8912e-01, -4.2889e-01,
-         1.2759e+00,  3.9636e-01, -3.2507e-04,  7.9335e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.3317e-01,  2.7207e-01, -5.6172e-02, -1.6518e-01,  2.6192e-01,
-         5.1705e-01, -2.3862e-01, -9.7664e-01, -4.2490e-01, -6.5969e-01,
-        -4.0968e-01, -4.4241e-03,  4.9412e-02,  2.7385e-01, -1.3576e+00,
-        -5.3574e-01, -8.3907e-01,  5.5970e-01, -9.7616e-01, -4.9480e-01,
-        -7.5499e-01,  1.6080e-01, -4.7468e-01, -8.2654e-01,  4.3112e-02,
-         2.3072e-01, -1.5595e-02, -2.9040e-01, -8.6214e-02,  1.5775e+00,
-        -7.1273e-01, -7.3383e-01, -9.3075e-01,  5.6695e-02, -7.5314e-01,
-        -2.0307e-01,  1.0139e+00,  0.0000e+00,  6.7313e-01,  8.7261e-01,
-         2.5058e-01,  1.1540e+00,  5.1844e-03,  7.2862e-01,  5.7412e-01,
-         1.4032e-01,  6.1620e-01,  4.9613e-02,  8.2662e-01, -5.3755e-01,
-         2.8893e-01,  2.3396e-01,  3.2864e-01,  1.0818e-01,  5.0813e-01,
-         3.3818e-01,  1.2068e+00,  1.0357e+00,  4.8912e-01, -4.2889e-01,
-         1.2759e+00,  3.9636e-01, -3.2507e-04,  7.9335e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.3011e-01,  2.6457e-01, -1.0603e-01, -1.8032e-01,  2.2555e-01,
-         5.2373e-01, -2.6862e-01, -9.7533e-01, -4.2992e-01, -6.5976e-01,
-        -4.0496e-01, -2.8469e-02,  3.5085e-02,  2.7971e-01, -1.3568e+00,
-        -5.4298e-01, -8.4320e-01,  5.5562e-01, -9.8359e-01, -4.8837e-01,
-        -7.5816e-01,  6.6743e-02, -4.5767e-01, -8.2435e-01, -5.9995e-02,
-         2.2646e-01,  1.3474e-02, -2.7961e-01, -8.0743e-02,  1.5793e+00,
-        -7.1144e-01, -7.2887e-01, -9.3148e-01,  4.1963e-02, -7.4577e-01,
-        -1.7196e-01,  1.0116e+00,  0.0000e+00,  6.3948e-01,  8.9107e-01,
-         2.2907e-01,  1.1469e+00, -3.3750e-02,  7.2373e-01,  5.7125e-01,
-         1.4722e-01,  6.1618e-01,  5.0592e-02,  8.3478e-01, -5.3777e-01,
-         2.7441e-01,  1.5093e-01,  3.1038e-01,  2.9085e-02,  5.2145e-01,
-         3.4857e-01,  1.2088e+00,  1.0340e+00,  5.4875e-01, -3.8305e-01,
-         1.2771e+00,  4.0332e-01,  4.0163e-04,  6.3698e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.3011e-01,  2.6457e-01, -1.0603e-01, -1.8032e-01,  2.2555e-01,
-         5.2373e-01, -2.6862e-01, -9.7533e-01, -4.2992e-01, -6.5976e-01,
-        -4.0496e-01, -2.8469e-02,  3.5085e-02,  2.7971e-01, -1.3568e+00,
-        -5.4298e-01, -8.4320e-01,  5.5562e-01, -9.8359e-01, -4.8837e-01,
-        -7.5816e-01,  6.6743e-02, -4.5767e-01, -8.2435e-01, -5.9995e-02,
-         2.2646e-01,  1.3474e-02, -2.7961e-01, -8.0743e-02,  1.5793e+00,
-        -7.1144e-01, -7.2887e-01, -9.3148e-01,  4.1963e-02, -7.4577e-01,
-        -1.7196e-01,  1.0116e+00,  0.0000e+00,  6.3948e-01,  8.9107e-01,
-         2.2907e-01,  1.1469e+00, -3.3750e-02,  7.2373e-01,  5.7125e-01,
-         1.4722e-01,  6.1618e-01,  5.0592e-02,  8.3478e-01, -5.3777e-01,
-         2.7441e-01,  1.5093e-01,  3.1038e-01,  2.9085e-02,  5.2145e-01,
-         3.4857e-01,  1.2088e+00,  1.0340e+00,  5.4875e-01, -3.8305e-01,
-         1.2771e+00,  4.0332e-01,  4.0163e-04,  6.3698e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.3011e-01,  2.6457e-01, -1.0603e-01, -1.8032e-01,  2.2555e-01,
-         5.2373e-01, -2.6862e-01, -9.7533e-01, -4.2992e-01, -6.5976e-01,
-        -4.0496e-01, -2.8469e-02,  3.5085e-02,  2.7971e-01, -1.3568e+00,
-        -5.4298e-01, -8.4320e-01,  5.5562e-01, -9.8359e-01, -4.8837e-01,
-        -7.5816e-01,  6.6743e-02, -4.5767e-01, -8.2435e-01, -5.9995e-02,
-         2.2646e-01,  1.3474e-02, -2.7961e-01, -8.0743e-02,  1.5793e+00,
-        -7.1144e-01, -7.2887e-01, -9.3148e-01,  4.1963e-02, -7.4577e-01,
-        -1.7196e-01,  1.0116e+00,  0.0000e+00,  6.3948e-01,  8.9107e-01,
-         2.2907e-01,  1.1469e+00, -3.3750e-02,  7.2373e-01,  5.7125e-01,
-         1.4722e-01,  6.1618e-01,  5.0592e-02,  8.3478e-01, -5.3777e-01,
-         2.7441e-01,  1.5093e-01,  3.1038e-01,  2.9085e-02,  5.2145e-01,
-         3.4857e-01,  1.2088e+00,  1.0340e+00,  5.4875e-01, -3.8305e-01,
-         1.2771e+00,  4.0332e-01,  4.0163e-04,  6.3698e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8308,  0.2676, -0.1373, -0.1864,  0.1869,  0.5278, -0.2908, -0.9722,
-        -0.4321, -0.6585, -0.3927, -0.0707,  0.0173,  0.3022, -1.3598, -0.5338,
-        -0.8415,  0.5534, -0.9877, -0.4840, -0.7626,  0.0057, -0.4433, -0.8223,
-        -0.1539,  0.2197,  0.0460, -0.2677, -0.0824,  1.5808, -0.7011, -0.7204,
-        -0.9312,  0.0565, -0.7350, -0.1329,  1.0096,  0.0000,  0.6162,  0.9091,
-         0.2304,  1.1432, -0.0595,  0.7203,  0.5743,  0.1464,  0.6192,  0.0437,
-         0.8385, -0.5341,  0.2543,  0.1349,  0.2806, -0.0346,  0.5379,  0.3585,
-         1.2096,  1.0335,  0.5980, -0.3524,  1.2792,  0.4118, -0.0103,  0.0577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8308,  0.2676, -0.1373, -0.1864,  0.1869,  0.5278, -0.2908, -0.9722,
-        -0.4321, -0.6585, -0.3927, -0.0707,  0.0173,  0.3022, -1.3598, -0.5338,
-        -0.8415,  0.5534, -0.9877, -0.4840, -0.7626,  0.0057, -0.4433, -0.8223,
-        -0.1539,  0.2197,  0.0460, -0.2677, -0.0824,  1.5808, -0.7011, -0.7204,
-        -0.9312,  0.0565, -0.7350, -0.1329,  1.0096,  0.0000,  0.6162,  0.9091,
-         0.2304,  1.1432, -0.0595,  0.7203,  0.5743,  0.1464,  0.6192,  0.0437,
-         0.8385, -0.5341,  0.2543,  0.1349,  0.2806, -0.0346,  0.5379,  0.3585,
-         1.2096,  1.0335,  0.5980, -0.3524,  1.2792,  0.4118, -0.0103,  0.0577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8308,  0.2676, -0.1373, -0.1864,  0.1869,  0.5278, -0.2908, -0.9722,
-        -0.4321, -0.6585, -0.3927, -0.0707,  0.0173,  0.3022, -1.3598, -0.5338,
-        -0.8415,  0.5534, -0.9877, -0.4840, -0.7626,  0.0057, -0.4433, -0.8223,
-        -0.1539,  0.2197,  0.0460, -0.2677, -0.0824,  1.5808, -0.7011, -0.7204,
-        -0.9312,  0.0565, -0.7350, -0.1329,  1.0096,  0.0000,  0.6162,  0.9091,
-         0.2304,  1.1432, -0.0595,  0.7203,  0.5743,  0.1464,  0.6192,  0.0437,
-         0.8385, -0.5341,  0.2543,  0.1349,  0.2806, -0.0346,  0.5379,  0.3585,
-         1.2096,  1.0335,  0.5980, -0.3524,  1.2792,  0.4118, -0.0103,  0.0577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8302,  0.2822, -0.1337, -0.1589,  0.1910,  0.5145, -0.3252, -0.9699,
-        -0.4267, -0.6557, -0.3788, -0.1377,  0.0518,  0.2831, -1.3629, -0.5251,
-        -0.8399,  0.5518, -0.9897, -0.4851, -0.7653, -0.0256, -0.4319, -0.8279,
-        -0.2263,  0.1962,  0.0757, -0.2480, -0.0960,  1.5799, -0.6887, -0.7079,
-        -0.9315,  0.1065, -0.7306, -0.0948,  1.0061,  0.0000,  0.5844,  0.9245,
-         0.2386,  1.1414, -0.0495,  0.7103,  0.5786,  0.1463,  0.6234,  0.0293,
-         0.8337, -0.5312,  0.2455,  0.1503,  0.2513, -0.0896,  0.5649,  0.3551,
-         1.2097,  1.0361,  0.6554, -0.3470,  1.2819,  0.4263, -0.0158,  0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8302,  0.2822, -0.1337, -0.1589,  0.1910,  0.5145, -0.3252, -0.9699,
-        -0.4267, -0.6557, -0.3788, -0.1377,  0.0518,  0.2831, -1.3629, -0.5251,
-        -0.8399,  0.5518, -0.9897, -0.4851, -0.7653, -0.0256, -0.4319, -0.8279,
-        -0.2263,  0.1962,  0.0757, -0.2480, -0.0960,  1.5799, -0.6887, -0.7079,
-        -0.9315,  0.1065, -0.7306, -0.0948,  1.0061,  0.0000,  0.5844,  0.9245,
-         0.2386,  1.1414, -0.0495,  0.7103,  0.5786,  0.1463,  0.6234,  0.0293,
-         0.8337, -0.5312,  0.2455,  0.1503,  0.2513, -0.0896,  0.5649,  0.3551,
-         1.2097,  1.0361,  0.6554, -0.3470,  1.2819,  0.4263, -0.0158,  0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8302,  0.2822, -0.1337, -0.1589,  0.1910,  0.5145, -0.3252, -0.9699,
-        -0.4267, -0.6557, -0.3788, -0.1377,  0.0518,  0.2831, -1.3629, -0.5251,
-        -0.8399,  0.5518, -0.9897, -0.4851, -0.7653, -0.0256, -0.4319, -0.8279,
-        -0.2263,  0.1962,  0.0757, -0.2480, -0.0960,  1.5799, -0.6887, -0.7079,
-        -0.9315,  0.1065, -0.7306, -0.0948,  1.0061,  0.0000,  0.5844,  0.9245,
-         0.2386,  1.1414, -0.0495,  0.7103,  0.5786,  0.1463,  0.6234,  0.0293,
-         0.8337, -0.5312,  0.2455,  0.1503,  0.2513, -0.0896,  0.5649,  0.3551,
-         1.2097,  1.0361,  0.6554, -0.3470,  1.2819,  0.4263, -0.0158,  0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8318,  0.2883, -0.1146, -0.1102,  0.1919,  0.5028, -0.3566, -0.9678,
-        -0.4070, -0.6508, -0.3568, -0.1717,  0.1102,  0.3183, -1.3649, -0.5066,
-        -0.8352,  0.5470, -0.9922, -0.4837, -0.7667, -0.0673, -0.4216, -0.8422,
-        -0.3096,  0.2279,  0.1206, -0.2180, -0.1016,  1.5786, -0.6769, -0.6946,
-        -0.9323,  0.1557, -0.7235, -0.0620,  1.0039,  0.0000,  0.5900,  0.9460,
-         0.2591,  1.1412, -0.0132,  0.7118,  0.5809,  0.1091,  0.6265,  0.0519,
-         0.8268, -0.5302,  0.2381,  0.1969,  0.1762, -0.1172,  0.6031,  0.3743,
-         1.2096,  1.0447,  0.7099, -0.3388,  1.2854,  0.4243, -0.0400,  0.0817],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8318,  0.2883, -0.1146, -0.1102,  0.1919,  0.5028, -0.3566, -0.9678,
-        -0.4070, -0.6508, -0.3568, -0.1717,  0.1102,  0.3183, -1.3649, -0.5066,
-        -0.8352,  0.5470, -0.9922, -0.4837, -0.7667, -0.0673, -0.4216, -0.8422,
-        -0.3096,  0.2279,  0.1206, -0.2180, -0.1016,  1.5786, -0.6769, -0.6946,
-        -0.9323,  0.1557, -0.7235, -0.0620,  1.0039,  0.0000,  0.5900,  0.9460,
-         0.2591,  1.1412, -0.0132,  0.7118,  0.5809,  0.1091,  0.6265,  0.0519,
-         0.8268, -0.5302,  0.2381,  0.1969,  0.1762, -0.1172,  0.6031,  0.3743,
-         1.2096,  1.0447,  0.7099, -0.3388,  1.2854,  0.4243, -0.0400,  0.0817],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8318,  0.2883, -0.1146, -0.1102,  0.1919,  0.5028, -0.3566, -0.9678,
-        -0.4070, -0.6508, -0.3568, -0.1717,  0.1102,  0.3183, -1.3649, -0.5066,
-        -0.8352,  0.5470, -0.9922, -0.4837, -0.7667, -0.0673, -0.4216, -0.8422,
-        -0.3096,  0.2279,  0.1206, -0.2180, -0.1016,  1.5786, -0.6769, -0.6946,
-        -0.9323,  0.1557, -0.7235, -0.0620,  1.0039,  0.0000,  0.5900,  0.9460,
-         0.2591,  1.1412, -0.0132,  0.7118,  0.5809,  0.1091,  0.6265,  0.0519,
-         0.8268, -0.5302,  0.2381,  0.1969,  0.1762, -0.1172,  0.6031,  0.3743,
-         1.2096,  1.0447,  0.7099, -0.3388,  1.2854,  0.4243, -0.0400,  0.0817],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8324,  0.3084, -0.1140, -0.0927,  0.2257,  0.4785, -0.3787, -0.9660,
-        -0.3880, -0.6482, -0.3367, -0.2236,  0.1311,  0.3343, -1.3668, -0.4901,
-        -0.8232,  0.5425, -0.9922, -0.4758, -0.7741, -0.0960, -0.4016, -0.8652,
-        -0.3790,  0.2669,  0.1363, -0.1703, -0.1168,  1.5772, -0.6629, -0.6820,
-        -0.9342,  0.2361, -0.7144, -0.0455,  1.0014,  0.0000,  0.5871,  0.9672,
-         0.2949,  1.1430, -0.0031,  0.7144,  0.5802,  0.0471,  0.6338,  0.0639,
-         0.8227, -0.5275,  0.2252,  0.2986,  0.0797, -0.1300,  0.6462,  0.3868,
-         1.2096,  1.0554,  0.7715, -0.3669,  1.2890,  0.4188, -0.0700,  0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8324,  0.3084, -0.1140, -0.0927,  0.2257,  0.4785, -0.3787, -0.9660,
-        -0.3880, -0.6482, -0.3367, -0.2236,  0.1311,  0.3343, -1.3668, -0.4901,
-        -0.8232,  0.5425, -0.9922, -0.4758, -0.7741, -0.0960, -0.4016, -0.8652,
-        -0.3790,  0.2669,  0.1363, -0.1703, -0.1168,  1.5772, -0.6629, -0.6820,
-        -0.9342,  0.2361, -0.7144, -0.0455,  1.0014,  0.0000,  0.5871,  0.9672,
-         0.2949,  1.1430, -0.0031,  0.7144,  0.5802,  0.0471,  0.6338,  0.0639,
-         0.8227, -0.5275,  0.2252,  0.2986,  0.0797, -0.1300,  0.6462,  0.3868,
-         1.2096,  1.0554,  0.7715, -0.3669,  1.2890,  0.4188, -0.0700,  0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8324,  0.3084, -0.1140, -0.0927,  0.2257,  0.4785, -0.3787, -0.9660,
-        -0.3880, -0.6482, -0.3367, -0.2236,  0.1311,  0.3343, -1.3668, -0.4901,
-        -0.8232,  0.5425, -0.9922, -0.4758, -0.7741, -0.0960, -0.4016, -0.8652,
-        -0.3790,  0.2669,  0.1363, -0.1703, -0.1168,  1.5772, -0.6629, -0.6820,
-        -0.9342,  0.2361, -0.7144, -0.0455,  1.0014,  0.0000,  0.5871,  0.9672,
-         0.2949,  1.1430, -0.0031,  0.7144,  0.5802,  0.0471,  0.6338,  0.0639,
-         0.8227, -0.5275,  0.2252,  0.2986,  0.0797, -0.1300,  0.6462,  0.3868,
-         1.2096,  1.0554,  0.7715, -0.3669,  1.2890,  0.4188, -0.0700,  0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8299,  0.3172, -0.1362, -0.0863,  0.2383,  0.4549, -0.3951, -0.9638,
-        -0.3753, -0.6448, -0.3211, -0.2770,  0.0905,  0.3530, -1.3695, -0.4662,
-        -0.7966,  0.5398, -0.9883, -0.4648, -0.7841, -0.1032, -0.3854, -0.8868,
-        -0.4519,  0.3328,  0.1162, -0.1212, -0.1492,  1.5746, -0.6425, -0.6668,
-        -0.9350,  0.2830, -0.7067, -0.0357,  1.0032,  0.0000,  0.6057,  0.9896,
-         0.3258,  1.1443, -0.0269,  0.7251,  0.5763, -0.0517,  0.6365,  0.0748,
-         0.8145, -0.5216,  0.2281,  0.3837, -0.0431, -0.1352,  0.6870,  0.3980,
-         1.2093,  1.0656,  0.8145, -0.4099,  1.2934,  0.3886, -0.1161, -0.0050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8299,  0.3172, -0.1362, -0.0863,  0.2383,  0.4549, -0.3951, -0.9638,
-        -0.3753, -0.6448, -0.3211, -0.2770,  0.0905,  0.3530, -1.3695, -0.4662,
-        -0.7966,  0.5398, -0.9883, -0.4648, -0.7841, -0.1032, -0.3854, -0.8868,
-        -0.4519,  0.3328,  0.1162, -0.1212, -0.1492,  1.5746, -0.6425, -0.6668,
-        -0.9350,  0.2830, -0.7067, -0.0357,  1.0032,  0.0000,  0.6057,  0.9896,
-         0.3258,  1.1443, -0.0269,  0.7251,  0.5763, -0.0517,  0.6365,  0.0748,
-         0.8145, -0.5216,  0.2281,  0.3837, -0.0431, -0.1352,  0.6870,  0.0000,
-         1.2093,  1.0656,  0.8145, -0.4099,  1.2934,  0.3886, -0.1161, -0.0050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8299,  0.3172, -0.1362, -0.0863,  0.2383,  0.4549, -0.3951, -0.9638,
-        -0.3753, -0.6448, -0.3211, -0.2770,  0.0905,  0.3530, -1.3695, -0.4662,
-        -0.7966,  0.5398, -0.9883, -0.4648, -0.7841, -0.1032, -0.3854, -0.8868,
-        -0.4519,  0.3328,  0.1162, -0.1212, -0.1492,  1.5746, -0.6425, -0.6668,
-        -0.9350,  0.2830, -0.7067, -0.0357,  1.0032,  0.0000,  0.6057,  0.9896,
-         0.3258,  1.1443, -0.0269,  0.7251,  0.5763, -0.0517,  0.6365,  0.0748,
-         0.8145, -0.5216,  0.2281,  0.3837, -0.0431, -0.1352,  0.6870,  0.0000,
-         1.2093,  1.0656,  0.8145, -0.4099,  1.2934,  0.3886, -0.1161, -0.0050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8277,  0.3282, -0.1847, -0.0867,  0.1869,  0.4438, -0.4151, -0.9648,
-        -0.3726, -0.6420, -0.3063, -0.3449,  0.0154,  0.3714, -1.3723, -0.4558,
-        -0.7803,  0.5386, -0.9848, -0.4636, -0.7790, -0.0577, -0.3554, -0.8901,
-        -0.4966,  0.4008,  0.0869, -0.0636, -0.1590,  1.5714, -0.6243, -0.6510,
-        -0.9407,  0.2869, -0.6977, -0.0151,  1.0039,  0.0000,  0.6233,  1.0065,
-         0.3538,  1.1359, -0.0781,  0.7325,  0.5699, -0.1461,  0.6402,  0.0927,
-         0.8081, -0.5141,  0.2201,  0.4008, -0.1555, -0.1140,  0.7227,  0.0101,
-         1.2087,  1.0748,  0.8352, -0.4293,  1.2969,  0.3564, -0.1577, -0.0931],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8277,  0.3282, -0.1847, -0.0867,  0.1869,  0.4438, -0.4151, -0.9648,
-        -0.3726, -0.6420, -0.3063, -0.3449,  0.0154,  0.3714, -1.3723, -0.4558,
-        -0.7803,  0.5386, -0.9848, -0.4636, -0.7790, -0.0577, -0.3554, -0.8901,
-        -0.4966,  0.4008,  0.0869, -0.0636, -0.1590,  1.5714, -0.6243, -0.6510,
-        -0.9407,  0.2869, -0.6977, -0.0151,  1.0039,  0.0000,  0.6233,  1.0065,
-         0.3538,  1.1359, -0.0781,  0.7325,  0.5699, -0.1461,  0.6402,  0.0927,
-         0.8081, -0.5141,  0.2201,  0.4008, -0.1555, -0.1140,  0.7227,  0.0000,
-         1.2087,  1.0748,  0.8352, -0.4293,  1.2969,  0.3564, -0.1577, -0.0931],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8277,  0.3282, -0.1847, -0.0867,  0.1869,  0.4438, -0.4151, -0.9648,
-        -0.3726, -0.6420, -0.3063, -0.3449,  0.0154,  0.3714, -1.3723, -0.4558,
-        -0.7803,  0.5386, -0.9848, -0.4636, -0.7790, -0.0577, -0.3554, -0.8901,
-        -0.4966,  0.4008,  0.0869, -0.0636, -0.1590,  1.5714, -0.6243, -0.6510,
-        -0.9407,  0.2869, -0.6977, -0.0151,  1.0039,  0.0000,  0.6233,  1.0065,
-         0.3538,  1.1359, -0.0781,  0.7325,  0.5699, -0.1461,  0.6402,  0.0927,
-         0.8081, -0.5141,  0.2201,  0.4008, -0.1555, -0.1140,  0.7227,  0.0000,
-         1.2087,  1.0748,  0.8352, -0.4293,  1.2969,  0.3564, -0.1577, -0.0931],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8163,  0.3677, -0.2186, -0.0644,  0.1750,  0.4349, -0.4269, -0.9644,
-        -0.3529, -0.6419, -0.2813, -0.4204, -0.0730,  0.3680, -1.3768, -0.4452,
-        -0.7744,  0.5421, -0.9793, -0.4663, -0.7561,  0.0263, -0.3356, -0.8687,
-        -0.5284,  0.4499,  0.0362, -0.0018, -0.1413,  1.5714, -0.6200, -0.6351,
-        -0.9481,  0.3071, -0.6917,  0.0193,  1.0055,  0.0000,  0.6108,  1.0186,
-         0.4011,  1.1223, -0.0880,  0.7320,  0.5602, -0.2445,  0.6359,  0.0903,
-         0.8175, -0.5061,  0.2037,  0.3850, -0.2454, -0.0705,  0.7483,  0.0092,
-         1.2080,  1.0775,  0.8402, -0.4152,  1.2979,  0.3260, -0.1909, -0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8163,  0.3677, -0.2186, -0.0644,  0.1750,  0.4349, -0.4269, -0.9644,
-        -0.3529, -0.6419, -0.2813, -0.4204, -0.0730,  0.3680, -1.3768, -0.4452,
-        -0.7744,  0.5421, -0.9793, -0.4663, -0.7561,  0.0263, -0.3356, -0.8687,
-        -0.5284,  0.4499,  0.0362, -0.0018, -0.1413,  1.5714, -0.6200, -0.6351,
-        -0.9481,  0.3071, -0.6917,  0.0193,  1.0055,  0.0000,  0.6108,  1.0186,
-         0.4011,  1.1223, -0.0880,  0.7320,  0.5602, -0.2445,  0.6359,  0.0903,
-         0.8175, -0.5061,  0.2037,  0.3850, -0.2454, -0.0705,  0.7483,  0.0000,
-         1.2080,  1.0775,  0.8402, -0.4152,  1.2979,  0.3260, -0.1909, -0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8163,  0.3677, -0.2186, -0.0644,  0.1750,  0.4349, -0.4269, -0.9644,
-        -0.3529, -0.6419, -0.2813, -0.4204, -0.0730,  0.3680, -1.3768, -0.4452,
-        -0.7744,  0.5421, -0.9793, -0.4663, -0.7561,  0.0263, -0.3356, -0.8687,
-        -0.5284,  0.4499,  0.0362, -0.0018, -0.1413,  1.5714, -0.6200, -0.6351,
-        -0.9481,  0.3071, -0.6917,  0.0193,  1.0055,  0.0000,  0.6108,  1.0186,
-         0.4011,  1.1223, -0.0880,  0.7320,  0.5602, -0.2445,  0.6359,  0.0903,
-         0.8175, -0.5061,  0.2037,  0.3850, -0.2454, -0.0705,  0.7483,  0.0000,
-         1.2080,  1.0775,  0.8402, -0.4152,  1.2979,  0.3260, -0.1909, -0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8025,  0.4118, -0.2302, -0.0094,  0.1471,  0.4366, -0.4306, -0.9641,
-        -0.3293, -0.6451, -0.2429, -0.5076, -0.1077,  0.3623, -1.3778, -0.4466,
-        -0.7796,  0.5443, -0.9769, -0.4724, -0.7279,  0.0813, -0.3262, -0.8360,
-        -0.5554,  0.4835,  0.0035,  0.0411, -0.1214,  1.5719, -0.6231, -0.6259,
-        -0.9584,  0.2911, -0.6889,  0.0633,  1.0046,  0.0000,  0.5714,  1.0237,
-         0.4378,  1.1088, -0.0708,  0.7295,  0.5524, -0.3352,  0.6294,  0.0943,
-         0.8291, -0.4971,  0.1899,  0.3341, -0.3264, -0.0299,  0.7654,  0.0083,
-         1.2076,  1.0754,  0.8448, -0.3788,  1.2982,  0.3052, -0.2177, -0.1983],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8025,  0.4118, -0.2302, -0.0094,  0.1471,  0.4366, -0.4306, -0.9641,
-        -0.3293, -0.6451, -0.2429, -0.5076, -0.1077,  0.3623, -1.3778, -0.4466,
-        -0.7796,  0.5443, -0.9769, -0.4724, -0.7279,  0.0813, -0.3262, -0.8360,
-        -0.5554,  0.4835,  0.0035,  0.0411, -0.1214,  1.5719, -0.6231, -0.6259,
-        -0.9584,  0.2911, -0.6889,  0.0633,  1.0046,  0.0000,  0.5714,  1.0237,
-         0.4378,  1.1088, -0.0708,  0.7295,  0.5524, -0.3352,  0.6294,  0.0943,
-         0.8291, -0.4971,  0.1899,  0.3341, -0.3264, -0.0299,  0.7654,  0.0000,
-         1.2076,  1.0754,  0.8448, -0.3788,  1.2982,  0.3052, -0.2177, -0.1983],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8025,  0.4118, -0.2302, -0.0094,  0.1471,  0.4366, -0.4306, -0.9641,
-        -0.3293, -0.6451, -0.2429, -0.5076, -0.1077,  0.3623, -1.3778, -0.4466,
-        -0.7796,  0.5443, -0.9769, -0.4724, -0.7279,  0.0813, -0.3262, -0.8360,
-        -0.5554,  0.4835,  0.0035,  0.0411, -0.1214,  1.5719, -0.6231, -0.6259,
-        -0.9584,  0.2911, -0.6889,  0.0633,  1.0046,  0.0000,  0.5714,  1.0237,
-         0.4378,  1.1088, -0.0708,  0.7295,  0.5524, -0.3352,  0.6294,  0.0943,
-         0.8291, -0.4971,  0.1899,  0.3341, -0.3264, -0.0299,  0.7654,  0.0000,
-         1.2076,  1.0754,  0.8448, -0.3788,  1.2982,  0.3052, -0.2177, -0.1983],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7887,  0.4436, -0.2378,  0.0274,  0.1600,  0.4436, -0.4368, -0.9652,
-        -0.3115, -0.6452, -0.2043, -0.5608, -0.1474,  0.3518, -1.3752, -0.4499,
-        -0.7877,  0.5459, -0.9744, -0.4727, -0.7155,  0.0808, -0.3216, -0.8139,
-        -0.5740,  0.5045, -0.0283,  0.0667, -0.1386,  1.5723, -0.6255, -0.6150,
-        -0.9667,  0.2640, -0.6815,  0.0953,  1.0031,  0.0000,  0.5104,  1.0241,
-         0.4688,  1.1001, -0.0565,  0.7172,  0.5430, -0.4056,  0.6237,  0.1092,
-         0.8417, -0.4868,  0.2018,  0.2831, -0.3978, -0.0181,  0.7699,  0.0075,
-         1.2074,  1.0659,  0.8463, -0.3743,  1.2988,  0.2909, -0.2348, -0.2306],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7887,  0.4436, -0.2378,  0.0274,  0.1600,  0.4436, -0.4368, -0.9652,
-        -0.3115, -0.6452, -0.2043, -0.5608, -0.1474,  0.3518, -1.3752, -0.4499,
-        -0.7877,  0.5459, -0.9744, -0.4727, -0.7155,  0.0808, -0.3216, -0.8139,
-        -0.5740,  0.5045, -0.0283,  0.0667, -0.1386,  1.5723, -0.6255, -0.6150,
-        -0.9667,  0.2640, -0.6815,  0.0953,  1.0031,  0.0000,  0.5104,  1.0241,
-         0.4688,  1.1001, -0.0565,  0.7172,  0.5430, -0.4056,  0.6237,  0.1092,
-         0.8417, -0.4868,  0.2018,  0.2831, -0.3978, -0.0181,  0.7699,  0.0000,
-         1.2074,  1.0659,  0.8463, -0.3743,  1.2988,  0.2909, -0.2348, -0.2306],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7887,  0.4436, -0.2378,  0.0274,  0.1600,  0.4436, -0.4368, -0.9652,
-        -0.3115, -0.6452, -0.2043, -0.5608, -0.1474,  0.3518, -1.3752, -0.4499,
-        -0.7877,  0.5459, -0.9744, -0.4727, -0.7155,  0.0808, -0.3216, -0.8139,
-        -0.5740,  0.5045, -0.0283,  0.0667, -0.1386,  1.5723, -0.6255, -0.6150,
-        -0.9667,  0.2640, -0.6815,  0.0953,  1.0031,  0.0000,  0.5104,  1.0241,
-         0.4688,  1.1001, -0.0565,  0.7172,  0.5430, -0.4056,  0.6237,  0.1092,
-         0.8417, -0.4868,  0.2018,  0.2831, -0.3978, -0.0181,  0.7699,  0.0000,
-         1.2074,  1.0659,  0.8463, -0.3743,  1.2988,  0.2909, -0.2348, -0.2306],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7767,  0.4693, -0.2234,  0.0584,  0.2462,  0.4369, -0.4370, -0.9640,
-        -0.2956, -0.6449, -0.1623, -0.5844, -0.1290,  0.3267, -1.3768, -0.4635,
-        -0.7871,  0.5390, -0.9756, -0.4746, -0.7173,  0.0412, -0.3197, -0.8180,
-        -0.5873,  0.5159, -0.0433,  0.0667, -0.1841,  1.5714, -0.6320, -0.6012,
-        -0.9700,  0.2367, -0.6762,  0.1011,  1.0000,  0.0000,  0.4544,  1.0202,
-         0.4831,  1.0983, -0.0366,  0.7058,  0.5382, -0.4648,  0.6133,  0.1355,
-         0.8423, -0.4782,  0.2149,  0.2499, -0.4574, -0.0104,  0.7702,  0.0068,
-         1.2072,  1.0570,  0.8440, -0.4041,  1.3003,  0.2828, -0.2440, -0.2283],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7767,  0.4693, -0.2234,  0.0584,  0.2462,  0.4369, -0.4370, -0.9640,
-        -0.2956, -0.6449, -0.1623, -0.5844, -0.1290,  0.3267, -1.3768, -0.4635,
-        -0.7871,  0.5390, -0.9756, -0.4746, -0.7173,  0.0412, -0.3197, -0.8180,
-        -0.5873,  0.5159, -0.0433,  0.0667, -0.1841,  1.5714, -0.6320, -0.6012,
-        -0.9700,  0.2367, -0.6762,  0.1011,  1.0000,  0.0000,  0.4544,  1.0202,
-         0.4831,  1.0983, -0.0366,  0.7058,  0.5382, -0.4648,  0.6133,  0.1355,
-         0.8423, -0.4782,  0.2149,  0.2499, -0.4574, -0.0104,  0.7702,  0.0000,
-         1.2072,  1.0570,  0.8440, -0.4041,  1.3003,  0.2828, -0.2440, -0.2283],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7767,  0.4693, -0.2234,  0.0584,  0.2462,  0.4369, -0.4370, -0.9640,
-        -0.2956, -0.6449, -0.1623, -0.5844, -0.1290,  0.3267, -1.3768, -0.4635,
-        -0.7871,  0.5390, -0.9756, -0.4746, -0.7173,  0.0412, -0.3197, -0.8180,
-        -0.5873,  0.5159, -0.0433,  0.0667, -0.1841,  1.5714, -0.6320, -0.6012,
-        -0.9700,  0.2367, -0.6762,  0.1011,  1.0000,  0.0000,  0.4544,  1.0202,
-         0.4831,  1.0983, -0.0366,  0.7058,  0.5382, -0.4648,  0.6133,  0.1355,
-         0.8423, -0.4782,  0.2149,  0.2499, -0.4574, -0.0104,  0.7702,  0.0000,
-         1.2072,  1.0570,  0.8440, -0.4041,  1.3003,  0.2828, -0.2440, -0.2283],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7623,  0.4791, -0.1915,  0.1059,  0.3411,  0.4290, -0.4388, -0.9620,
-        -0.2750, -0.6369, -0.1058, -0.5792, -0.0805,  0.3223, -1.3784, -0.4670,
-        -0.7863,  0.5326, -0.9764, -0.4681, -0.7218, -0.0032, -0.3203, -0.8200,
-        -0.6056,  0.5299, -0.0640,  0.0630, -0.2125,  1.5694, -0.6414, -0.5857,
-        -0.9748,  0.2057, -0.6683,  0.1073,  0.9950,  0.0000,  0.4127,  1.0137,
-         0.4832,  1.0998, -0.0139,  0.6940,  0.5331, -0.5273,  0.6015,  0.1758,
-         0.8410, -0.4738,  0.2250,  0.2178, -0.5148,  0.0082,  0.7708,  0.0061,
-         1.2073,  1.0476,  0.8404, -0.4347,  1.3025,  0.2677, -0.2484, -0.2204],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7623,  0.4791, -0.1915,  0.1059,  0.3411,  0.4290, -0.4388, -0.9620,
-        -0.2750, -0.6369, -0.1058, -0.5792, -0.0805,  0.3223, -1.3784, -0.4670,
-        -0.7863,  0.5326, -0.9764, -0.4681, -0.7218, -0.0032, -0.3203, -0.8200,
-        -0.6056,  0.5299, -0.0640,  0.0630, -0.2125,  1.5694, -0.6414, -0.5857,
-        -0.9748,  0.2057, -0.6683,  0.1073,  0.9950,  0.0000,  0.4127,  1.0137,
-         0.4832,  1.0998, -0.0139,  0.6940,  0.5331, -0.5273,  0.6015,  0.1758,
-         0.8410, -0.4738,  0.2250,  0.2178, -0.5148,  0.0082,  0.7708,  0.0000,
-         1.2073,  1.0476,  0.8404, -0.4347,  1.3025,  0.2677, -0.2484, -0.2204],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7623,  0.4791, -0.1915,  0.1059,  0.3411,  0.4290, -0.4388, -0.9620,
-        -0.2750, -0.6369, -0.1058, -0.5792, -0.0805,  0.3223, -1.3784, -0.4670,
-        -0.7863,  0.5326, -0.9764, -0.4681, -0.7218, -0.0032, -0.3203, -0.8200,
-        -0.6056,  0.5299, -0.0640,  0.0630, -0.2125,  1.5694, -0.6414, -0.5857,
-        -0.9748,  0.2057, -0.6683,  0.1073,  0.9950,  0.0000,  0.4127,  1.0137,
-         0.4832,  1.0998, -0.0139,  0.6940,  0.5331, -0.5273,  0.6015,  0.1758,
-         0.8410, -0.4738,  0.2250,  0.2178, -0.5148,  0.0082,  0.7708,  0.0000,
-         1.2073,  1.0476,  0.8404, -0.4347,  1.3025,  0.2677, -0.2484, -0.2204],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7489,  0.4920, -0.1744,  0.1343,  0.3842,  0.4058, -0.4467, -0.9578,
-        -0.2606, -0.6295, -0.0653, -0.5724, -0.0502,  0.3123, -1.3794, -0.4577,
-        -0.7817,  0.5279, -0.9776, -0.4546, -0.7167, -0.0193, -0.3095, -0.8039,
-        -0.6164,  0.5401, -0.1128,  0.0595, -0.2329,  1.5671, -0.6505, -0.5671,
-        -0.9782,  0.1780, -0.6553,  0.1061,  0.9941,  0.0000,  0.3796,  1.0085,
-         0.4856,  1.0998, -0.0187,  0.6779,  0.5283, -0.5838,  0.5920,  0.2017,
-         0.8391, -0.4696,  0.2370,  0.1961, -0.5607,  0.0315,  0.7685,  0.0055,
-         1.2078,  1.0433,  0.8305, -0.4443,  1.3048,  0.2536, -0.2400, -0.2525],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7489,  0.4920, -0.1744,  0.1343,  0.3842,  0.4058, -0.4467, -0.9578,
-        -0.2606, -0.6295, -0.0653, -0.5724, -0.0502,  0.3123, -1.3794, -0.4577,
-        -0.7817,  0.5279, -0.9776, -0.4546, -0.7167, -0.0193, -0.3095, -0.8039,
-        -0.6164,  0.5401, -0.1128,  0.0595, -0.2329,  1.5671, -0.6505, -0.5671,
-        -0.9782,  0.1780, -0.6553,  0.1061,  0.9941,  0.0000,  0.3796,  1.0085,
-         0.4856,  1.0998, -0.0187,  0.6779,  0.5283, -0.5838,  0.5920,  0.2017,
-         0.8391, -0.4696,  0.2370,  0.1961, -0.5607,  0.0315,  0.7685,  0.0000,
-         1.2078,  1.0433,  0.8305, -0.4443,  1.3048,  0.2536, -0.2400, -0.2525],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7489,  0.4920, -0.1744,  0.1343,  0.3842,  0.4058, -0.4467, -0.9578,
-        -0.2606, -0.6295, -0.0653, -0.5724, -0.0502,  0.3123, -1.3794, -0.4577,
-        -0.7817,  0.5279, -0.9776, -0.4546, -0.7167, -0.0193, -0.3095, -0.8039,
-        -0.6164,  0.5401, -0.1128,  0.0595, -0.2329,  1.5671, -0.6505, -0.5671,
-        -0.9782,  0.1780, -0.6553,  0.1061,  0.9941,  0.0000,  0.3796,  1.0085,
-         0.4856,  1.0998, -0.0187,  0.6779,  0.5283, -0.5838,  0.5920,  0.2017,
-         0.8391, -0.4696,  0.2370,  0.1961, -0.5607,  0.0315,  0.7685,  0.0000,
-         1.2078,  1.0433,  0.8305, -0.4443,  1.3048,  0.2536, -0.2400, -0.2525],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7205,  0.5222, -0.1530,  0.1550,  0.4126,  0.3960, -0.4590, -0.9528,
-        -0.2547, -0.6189, -0.0316, -0.5702, -0.0458,  0.2591, -1.3793, -0.4546,
-        -0.7812,  0.5213, -0.9767, -0.4498, -0.7058,  0.0328, -0.3017, -0.7928,
-        -0.6130,  0.5416, -0.1469,  0.0830, -0.2618,  1.5645, -0.6570, -0.5486,
-        -0.9839,  0.1876, -0.6433,  0.0850,  0.9905,  0.0000,  0.2882,  0.9983,
-         0.4979,  1.0988, -0.0507,  0.6555,  0.5255, -0.6291,  0.5893,  0.1969,
-         0.8536, -0.4656,  0.2262,  0.1734, -0.5982,  0.0413,  0.7642,  0.0049,
-         1.2074,  1.0368,  0.8223, -0.4417,  1.3040,  0.2426, -0.2408, -0.2619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7205,  0.5222, -0.1530,  0.1550,  0.4126,  0.3960, -0.4590, -0.9528,
-        -0.2547, -0.6189, -0.0316, -0.5702, -0.0458,  0.2591, -1.3793, -0.4546,
-        -0.7812,  0.5213, -0.9767, -0.4498, -0.7058,  0.0328, -0.3017, -0.7928,
-        -0.6130,  0.5416, -0.1469,  0.0830, -0.2618,  1.5645, -0.6570, -0.5486,
-        -0.9839,  0.1876, -0.6433,  0.0850,  0.9905,  0.0000,  0.2882,  0.9983,
-         0.4979,  1.0988, -0.0507,  0.6555,  0.5255, -0.6291,  0.5893,  0.1969,
-         0.8536, -0.4656,  0.2262,  0.1734, -0.5982,  0.0413,  0.7642,  0.0000,
-         1.2074,  1.0368,  0.8223, -0.4417,  1.3040,  0.2426, -0.2408, -0.2619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7205,  0.5222, -0.1530,  0.1550,  0.4126,  0.3960, -0.4590, -0.9528,
-        -0.2547, -0.6189, -0.0316, -0.5702, -0.0458,  0.2591, -1.3793, -0.4546,
-        -0.7812,  0.5213, -0.9767, -0.4498, -0.7058,  0.0328, -0.3017, -0.7928,
-        -0.6130,  0.5416, -0.1469,  0.0830, -0.2618,  1.5645, -0.6570, -0.5486,
-        -0.9839,  0.1876, -0.6433,  0.0850,  0.9905,  0.0000,  0.2882,  0.9983,
-         0.4979,  1.0988, -0.0507,  0.6555,  0.5255, -0.6291,  0.5893,  0.1969,
-         0.8536, -0.4656,  0.2262,  0.1734, -0.5982,  0.0413,  0.7642,  0.0000,
-         1.2074,  1.0368,  0.8223, -0.4417,  1.3040,  0.2426, -0.2408, -0.2619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.6848,  0.5551, -0.1236,  0.1538,  0.4349,  0.3923, -0.4691, -0.9467,
-        -0.2398, -0.6073, -0.0110, -0.5766, -0.0535,  0.1671, -1.3782, -0.4541,
-        -0.7812,  0.5111, -0.9752, -0.4510, -0.6966,  0.1244, -0.2817, -0.7816,
-        -0.5993,  0.5374, -0.1961,  0.1127, -0.2874,  1.5630, -0.6686, -0.5304,
-        -0.9887,  0.2162, -0.6334,  0.0405,  0.9871,  0.0000,  0.1592,  0.9894,
-         0.5114,  1.0993, -0.0869,  0.6345,  0.5232, -0.6698,  0.5903,  0.1972,
-         0.8739, -0.4563,  0.1931,  0.1418, -0.6319,  0.0392,  0.7599,  0.0044,
-         1.2064,  1.0363,  0.8116, -0.4307,  1.3036,  0.2288, -0.2439, -0.2580],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.6848,  0.5551, -0.1236,  0.1538,  0.4349,  0.3923, -0.4691, -0.9467,
-        -0.2398, -0.6073, -0.0110, -0.5766, -0.0535,  0.1671, -1.3782, -0.4541,
-        -0.7812,  0.5111, -0.9752, -0.4510, -0.6966,  0.1244, -0.2817, -0.7816,
-        -0.5993,  0.5374, -0.1961,  0.1127, -0.2874,  1.5630, -0.6686, -0.5304,
-        -0.9887,  0.2162, -0.6334,  0.0405,  0.9871,  0.0000,  0.1592,  0.9894,
-         0.5114,  1.0993, -0.0869,  0.6345,  0.5232, -0.6698,  0.5903,  0.1972,
-         0.8739, -0.4563,  0.1931,  0.1418, -0.6319,  0.0392,  0.7599,  0.0000,
-         1.2064,  1.0363,  0.8116, -0.4307,  1.3036,  0.2288, -0.2439, -0.2580],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6848,  0.5551, -0.1236,  0.1538,  0.4349,  0.3923, -0.4691, -0.9467,
-        -0.2398, -0.6073, -0.0110, -0.5766, -0.0535,  0.1671, -1.3782, -0.4541,
-        -0.7812,  0.5111, -0.9752, -0.4510, -0.6966,  0.1244, -0.2817, -0.7816,
-        -0.5993,  0.5374, -0.1961,  0.1127, -0.2874,  1.5630, -0.6686, -0.5304,
-        -0.9887,  0.2162, -0.6334,  0.0405,  0.9871,  0.0000,  0.1592,  0.9894,
-         0.5114,  1.0993, -0.0869,  0.6345,  0.5232, -0.6698,  0.5903,  0.1972,
-         0.8739, -0.4563,  0.1931,  0.1418, -0.6319,  0.0392,  0.7599,  0.0000,
-         1.2064,  1.0363,  0.8116, -0.4307,  1.3036,  0.2288, -0.2439, -0.2580],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5870e-01,  5.5715e-01, -6.9219e-02,  1.7682e-01,  4.3876e-01,
-         3.8071e-01, -4.7506e-01, -9.4113e-01, -2.1942e-01, -5.9914e-01,
-         1.1191e-03, -5.6578e-01, -1.4353e-02,  1.7588e-01, -1.3771e+00,
-        -4.6068e-01, -7.7570e-01,  5.0082e-01, -9.7435e-01, -4.5474e-01,
-        -6.8725e-01,  1.3975e-01, -2.7406e-01, -7.6694e-01, -5.9270e-01,
-         5.4300e-01, -2.0805e-01,  1.2052e-01, -3.0651e-01,  1.5608e+00,
-        -6.7597e-01, -5.0654e-01, -9.9108e-01,  2.1175e-01, -6.2769e-01,
-        -1.4722e-02,  9.9377e-01,  0.0000e+00,  1.4855e-01,  9.8261e-01,
-         5.0991e-01,  1.1008e+00, -1.0448e-01,  6.2125e-01,  5.2014e-01,
-        -6.9215e-01,  5.9101e-01,  2.3236e-01,  8.8575e-01, -4.4348e-01,
-         1.5843e-01,  1.0954e-01, -6.5193e-01,  3.5356e-02,  7.6581e-01,
-         3.9493e-03,  1.2065e+00,  1.0422e+00,  7.9736e-01, -4.1187e-01,
-         1.3048e+00,  2.0540e-01, -2.5482e-01, -2.2250e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 6.5870e-01,  5.5715e-01, -6.9219e-02,  1.7682e-01,  4.3876e-01,
-         3.8071e-01, -4.7506e-01, -9.4113e-01, -2.1942e-01, -5.9914e-01,
-         1.1191e-03, -5.6578e-01, -1.4353e-02,  1.7588e-01, -1.3771e+00,
-        -4.6068e-01, -7.7570e-01,  5.0082e-01, -9.7435e-01, -4.5474e-01,
-        -6.8725e-01,  1.3975e-01, -2.7406e-01, -7.6694e-01, -5.9270e-01,
-         5.4300e-01, -2.0805e-01,  1.2052e-01, -3.0651e-01,  1.5608e+00,
-        -6.7597e-01, -5.0654e-01, -9.9108e-01,  2.1175e-01, -6.2769e-01,
-        -1.4722e-02,  9.9377e-01,  0.0000e+00,  1.4855e-01,  9.8261e-01,
-         5.0991e-01,  1.1008e+00, -1.0448e-01,  6.2125e-01,  5.2014e-01,
-        -6.9215e-01,  5.9101e-01,  2.3236e-01,  8.8575e-01, -4.4348e-01,
-         1.5843e-01,  1.0954e-01, -6.5193e-01,  3.5356e-02,  7.6581e-01,
-         0.0000e+00,  1.2065e+00,  1.0422e+00,  7.9736e-01, -4.1187e-01,
-         1.3048e+00,  2.0540e-01, -2.5482e-01, -2.2250e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 6.5870e-01,  5.5715e-01, -6.9219e-02,  1.7682e-01,  4.3876e-01,
-         3.8071e-01, -4.7506e-01, -9.4113e-01, -2.1942e-01, -5.9914e-01,
-         1.1191e-03, -5.6578e-01, -1.4353e-02,  1.7588e-01, -1.3771e+00,
-        -4.6068e-01, -7.7570e-01,  5.0082e-01, -9.7435e-01, -4.5474e-01,
-        -6.8725e-01,  1.3975e-01, -2.7406e-01, -7.6694e-01, -5.9270e-01,
-         5.4300e-01, -2.0805e-01,  1.2052e-01, -3.0651e-01,  1.5608e+00,
-        -6.7597e-01, -5.0654e-01, -9.9108e-01,  2.1175e-01, -6.2769e-01,
-        -1.4722e-02,  9.9377e-01,  0.0000e+00,  1.4855e-01,  9.8261e-01,
-         5.0991e-01,  1.1008e+00, -1.0448e-01,  6.2125e-01,  5.2014e-01,
-        -6.9215e-01,  5.9101e-01,  2.3236e-01,  8.8575e-01, -4.4348e-01,
-         1.5843e-01,  1.0954e-01, -6.5193e-01,  3.5356e-02,  7.6581e-01,
-         0.0000e+00,  1.2065e+00,  1.0422e+00,  7.9736e-01, -4.1187e-01,
-         1.3048e+00,  2.0540e-01, -2.5482e-01, -2.2250e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.6382,  0.5408,  0.0427,  0.2101,  0.4241,  0.3461, -0.4705, -0.9347,
-        -0.1797, -0.5914,  0.0109, -0.5453,  0.1068,  0.2282, -1.3768, -0.4758,
-        -0.7751,  0.4818, -0.9741, -0.4543, -0.6869,  0.0983, -0.2858, -0.7501,
-        -0.5831,  0.5540, -0.1615,  0.1113, -0.3177,  1.5577, -0.6703, -0.4796,
-        -0.9940,  0.1809, -0.6268, -0.0727,  1.0035,  0.0000,  0.2040,  0.9783,
-         0.4860,  1.1042, -0.0764,  0.6149,  0.5125, -0.7006,  0.5841,  0.2719,
-         0.8924, -0.4291,  0.1192,  0.0809, -0.6587,  0.0225,  0.7770,  0.0035,
-         1.2057,  1.0478,  0.7713, -0.4022,  1.3072,  0.1673, -0.2472, -0.1263],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.6382,  0.5408,  0.0427,  0.2101,  0.4241,  0.3461, -0.4705, -0.9347,
-        -0.1797, -0.5914,  0.0109, -0.5453,  0.1068,  0.2282, -1.3768, -0.4758,
-        -0.7751,  0.4818, -0.9741, -0.4543, -0.6869,  0.0983, -0.2858, -0.7501,
-        -0.5831,  0.5540, -0.1615,  0.1113, -0.3177,  1.5577, -0.6703, -0.4796,
-        -0.9940,  0.1809, -0.6268, -0.0727,  1.0035,  0.0000,  0.2040,  0.9783,
-         0.4860,  1.1042, -0.0764,  0.6149,  0.5125, -0.7006,  0.5841,  0.2719,
-         0.8924, -0.4291,  0.1192,  0.0809, -0.6587,  0.0225,  0.7770,  0.0000,
-         1.2057,  1.0478,  0.7713, -0.4022,  1.3072,  0.1673, -0.2472, -0.1263],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6382,  0.5408,  0.0427,  0.2101,  0.4241,  0.3461, -0.4705, -0.9347,
-        -0.1797, -0.5914,  0.0109, -0.5453,  0.1068,  0.2282, -1.3768, -0.4758,
-        -0.7751,  0.4818, -0.9741, -0.4543, -0.6869,  0.0983, -0.2858, -0.7501,
-        -0.5831,  0.5540, -0.1615,  0.1113, -0.3177,  1.5577, -0.6703, -0.4796,
-        -0.9940,  0.1809, -0.6268, -0.0727,  1.0035,  0.0000,  0.2040,  0.9783,
-         0.4860,  1.1042, -0.0764,  0.6149,  0.5125, -0.7006,  0.5841,  0.2719,
-         0.8924, -0.4291,  0.1192,  0.0809, -0.6587,  0.0225,  0.7770,  0.0000,
-         1.2057,  1.0478,  0.7713, -0.4022,  1.3072,  0.1673, -0.2472, -0.1263],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.6299,  0.5402,  0.1409,  0.2143,  0.4311,  0.3010, -0.4659, -0.9314,
-        -0.1441, -0.5819,  0.0179, -0.5103,  0.2177,  0.2475, -1.3752, -0.4903,
-        -0.7693,  0.4631, -0.9757, -0.4491, -0.6893,  0.0584, -0.3036, -0.7383,
-        -0.5942,  0.5652, -0.1529,  0.0918, -0.3400,  1.5559, -0.6581, -0.4661,
-        -0.9935,  0.1553, -0.6217, -0.1191,  1.0134,  0.0000,  0.2537,  0.9748,
-         0.4589,  1.1063, -0.0611,  0.5946,  0.5056, -0.7167,  0.5784,  0.2913,
-         0.9033, -0.4131,  0.1245,  0.0359, -0.6750, -0.0107,  0.7910,  0.0032,
-         1.2050,  1.0504,  0.7484, -0.3984,  1.3102,  0.1164, -0.2541, -0.0773],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.6299,  0.5402,  0.1409,  0.2143,  0.4311,  0.3010, -0.4659, -0.9314,
-        -0.1441, -0.5819,  0.0179, -0.5103,  0.2177,  0.2475, -1.3752, -0.4903,
-        -0.7693,  0.4631, -0.9757, -0.4491, -0.6893,  0.0584, -0.3036, -0.7383,
-        -0.5942,  0.5652, -0.1529,  0.0918, -0.3400,  1.5559, -0.6581, -0.4661,
-        -0.9935,  0.1553, -0.6217, -0.1191,  1.0134,  0.0000,  0.2537,  0.9748,
-         0.4589,  1.1063, -0.0611,  0.5946,  0.5056, -0.7167,  0.5784,  0.2913,
-         0.9033, -0.4131,  0.1245,  0.0359, -0.6750, -0.0107,  0.7910,  0.0000,
-         1.2050,  1.0504,  0.7484, -0.3984,  1.3102,  0.1164, -0.2541, -0.0773],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6299,  0.5402,  0.1409,  0.2143,  0.4311,  0.3010, -0.4659, -0.9314,
-        -0.1441, -0.5819,  0.0179, -0.5103,  0.2177,  0.2475, -1.3752, -0.4903,
-        -0.7693,  0.4631, -0.9757, -0.4491, -0.6893,  0.0584, -0.3036, -0.7383,
-        -0.5942,  0.5652, -0.1529,  0.0918, -0.3400,  1.5559, -0.6581, -0.4661,
-        -0.9935,  0.1553, -0.6217, -0.1191,  1.0134,  0.0000,  0.2537,  0.9748,
-         0.4589,  1.1063, -0.0611,  0.5946,  0.5056, -0.7167,  0.5784,  0.2913,
-         0.9033, -0.4131,  0.1245,  0.0359, -0.6750, -0.0107,  0.7910,  0.0000,
-         1.2050,  1.0504,  0.7484, -0.3984,  1.3102,  0.1164, -0.2541, -0.0773],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.6209,  0.5493,  0.2001,  0.2187,  0.4439,  0.2660, -0.4628, -0.9284,
-        -0.1247, -0.5698,  0.0491, -0.4859,  0.2783,  0.2530, -1.3742, -0.5011,
-        -0.7558,  0.4420, -0.9762, -0.4517, -0.6882,  0.0881, -0.2835, -0.7221,
-        -0.5781,  0.5690, -0.1657,  0.0790, -0.3528,  1.5555, -0.6453, -0.4532,
-        -0.9911,  0.1473, -0.6176, -0.1869,  1.0270,  0.0000,  0.3172,  0.9688,
-         0.4431,  1.1128, -0.0593,  0.5775,  0.4988, -0.7158,  0.5743,  0.2983,
-         0.9107, -0.3982,  0.1645,  0.0152, -0.6827, -0.0529,  0.8020,  0.0028,
-         1.2040,  1.0507,  0.7201, -0.3739,  1.3138,  0.0781, -0.2506, -0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.6209,  0.5493,  0.2001,  0.2187,  0.4439,  0.2660, -0.4628, -0.9284,
-        -0.1247, -0.5698,  0.0491, -0.4859,  0.2783,  0.2530, -1.3742, -0.5011,
-        -0.7558,  0.4420, -0.9762, -0.4517, -0.6882,  0.0881, -0.2835, -0.7221,
-        -0.5781,  0.5690, -0.1657,  0.0790, -0.3528,  1.5555, -0.6453, -0.4532,
-        -0.9911,  0.1473, -0.6176, -0.1869,  1.0270,  0.0000,  0.3172,  0.9688,
-         0.4431,  1.1128, -0.0593,  0.5775,  0.4988, -0.7158,  0.5743,  0.2983,
-         0.9107, -0.3982,  0.1645,  0.0152, -0.6827, -0.0529,  0.8020,  0.0000,
-         1.2040,  1.0507,  0.7201, -0.3739,  1.3138,  0.0781, -0.2506, -0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6209,  0.5493,  0.2001,  0.2187,  0.4439,  0.2660, -0.4628, -0.9284,
-        -0.1247, -0.5698,  0.0491, -0.4859,  0.2783,  0.2530, -1.3742, -0.5011,
-        -0.7558,  0.4420, -0.9762, -0.4517, -0.6882,  0.0881, -0.2835, -0.7221,
-        -0.5781,  0.5690, -0.1657,  0.0790, -0.3528,  1.5555, -0.6453, -0.4532,
-        -0.9911,  0.1473, -0.6176, -0.1869,  1.0270,  0.0000,  0.3172,  0.9688,
-         0.4431,  1.1128, -0.0593,  0.5775,  0.4988, -0.7158,  0.5743,  0.2983,
-         0.9107, -0.3982,  0.1645,  0.0152, -0.6827, -0.0529,  0.8020,  0.0000,
-         1.2040,  1.0507,  0.7201, -0.3739,  1.3138,  0.0781, -0.2506, -0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5955,  0.5675,  0.2330,  0.2362,  0.4632,  0.2458, -0.4477, -0.9226,
-        -0.0987, -0.5497,  0.1231, -0.4790,  0.3334,  0.2342, -1.3730, -0.4931,
-        -0.7453,  0.4307, -0.9776, -0.4503, -0.6800,  0.1665, -0.2474, -0.7028,
-        -0.5253,  0.5699, -0.1896,  0.0797, -0.3371,  1.5565, -0.6445, -0.4397,
-        -0.9890,  0.1524, -0.6059, -0.2657,  1.0405,  0.0000,  0.3600,  0.9622,
-         0.4365,  1.1166, -0.0735,  0.5589,  0.4895, -0.7132,  0.5781,  0.3026,
-         0.9273, -0.3784,  0.2104,  0.0228, -0.6957, -0.0547,  0.8135,  0.0025,
-         1.2031,  1.0502,  0.6972, -0.3073,  1.3171,  0.0594, -0.2293, -0.0916],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5955,  0.5675,  0.2330,  0.2362,  0.4632,  0.2458, -0.4477, -0.9226,
-        -0.0987, -0.5497,  0.1231, -0.4790,  0.3334,  0.2342, -1.3730, -0.4931,
-        -0.7453,  0.4307, -0.9776, -0.4503, -0.6800,  0.1665, -0.2474, -0.7028,
-        -0.5253,  0.5699, -0.1896,  0.0797, -0.3371,  1.5565, -0.6445, -0.4397,
-        -0.9890,  0.1524, -0.6059, -0.2657,  1.0405,  0.0000,  0.3600,  0.9622,
-         0.4365,  1.1166, -0.0735,  0.5589,  0.4895, -0.7132,  0.5781,  0.3026,
-         0.9273, -0.3784,  0.2104,  0.0228, -0.6957, -0.0547,  0.8135,  0.0000,
-         1.2031,  1.0502,  0.6972, -0.3073,  1.3171,  0.0594, -0.2293, -0.0916],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5955,  0.5675,  0.2330,  0.2362,  0.4632,  0.2458, -0.4477, -0.9226,
-        -0.0987, -0.5497,  0.1231, -0.4790,  0.3334,  0.2342, -1.3730, -0.4931,
-        -0.7453,  0.4307, -0.9776, -0.4503, -0.6800,  0.1665, -0.2474, -0.7028,
-        -0.5253,  0.5699, -0.1896,  0.0797, -0.3371,  1.5565, -0.6445, -0.4397,
-        -0.9890,  0.1524, -0.6059, -0.2657,  1.0405,  0.0000,  0.3600,  0.9622,
-         0.4365,  1.1166, -0.0735,  0.5589,  0.4895, -0.7132,  0.5781,  0.3026,
-         0.9273, -0.3784,  0.2104,  0.0228, -0.6957, -0.0547,  0.8135,  0.0000,
-         1.2031,  1.0502,  0.6972, -0.3073,  1.3171,  0.0594, -0.2293, -0.0916],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5813,  0.5651,  0.2753,  0.2543,  0.4302,  0.2038, -0.4176, -0.9174,
-        -0.0976, -0.5289,  0.1641, -0.4786,  0.3949,  0.2580, -1.3714, -0.4751,
-        -0.7186,  0.4255, -0.9798, -0.4561, -0.6782,  0.1596, -0.2272, -0.6777,
-        -0.5192,  0.5775, -0.1882,  0.0613, -0.3343,  1.5559, -0.6304, -0.4257,
-        -0.9888,  0.1033, -0.5965, -0.3489,  1.0535,  0.0000,  0.4118,  0.9567,
-         0.4148,  1.1235, -0.0830,  0.5351,  0.4784, -0.7118,  0.5742,  0.3149,
-         0.9327, -0.3449,  0.2411, -0.0334, -0.7113, -0.0721,  0.8276,  0.0022,
-         1.2025,  1.0502,  0.6623, -0.2959,  1.3227,  0.0145, -0.2248, -0.0907],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5813,  0.5651,  0.2753,  0.2543,  0.4302,  0.2038, -0.4176, -0.9174,
-        -0.0976, -0.5289,  0.1641, -0.4786,  0.3949,  0.2580, -1.3714, -0.4751,
-        -0.7186,  0.4255, -0.9798, -0.4561, -0.6782,  0.1596, -0.2272, -0.6777,
-        -0.5192,  0.5775, -0.1882,  0.0613, -0.3343,  1.5559, -0.6304, -0.4257,
-        -0.9888,  0.1033, -0.5965, -0.3489,  1.0535,  0.0000,  0.4118,  0.9567,
-         0.4148,  1.1235, -0.0830,  0.5351,  0.4784, -0.7118,  0.5742,  0.3149,
-         0.9327, -0.3449,  0.2411, -0.0334, -0.7113, -0.0721,  0.8276,  0.0000,
-         1.2025,  1.0502,  0.6623, -0.2959,  1.3227,  0.0145, -0.2248, -0.0907],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5813,  0.5651,  0.2753,  0.2543,  0.4302,  0.2038, -0.4176, -0.9174,
-        -0.0976, -0.5289,  0.1641, -0.4786,  0.3949,  0.2580, -1.3714, -0.4751,
-        -0.7186,  0.4255, -0.9798, -0.4561, -0.6782,  0.1596, -0.2272, -0.6777,
-        -0.5192,  0.5775, -0.1882,  0.0613, -0.3343,  1.5559, -0.6304, -0.4257,
-        -0.9888,  0.1033, -0.5965, -0.3489,  1.0535,  0.0000,  0.4118,  0.9567,
-         0.4148,  1.1235, -0.0830,  0.5351,  0.4784, -0.7118,  0.5742,  0.3149,
-         0.9327, -0.3449,  0.2411, -0.0334, -0.7113, -0.0721,  0.8276,  0.0000,
-         1.2025,  1.0502,  0.6623, -0.2959,  1.3227,  0.0145, -0.2248, -0.0907],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5591,  0.5713,  0.3191,  0.2683,  0.4031,  0.1885, -0.4046, -0.9093,
-        -0.0924, -0.5008,  0.1842, -0.4614,  0.4461,  0.2507, -1.3643, -0.4583,
-        -0.6997,  0.4122, -0.9827, -0.4529, -0.6729,  0.1595, -0.2354, -0.6300,
-        -0.5200,  0.5701, -0.1655,  0.0521, -0.2947,  1.5550, -0.6139, -0.4070,
-        -0.9856,  0.0863, -0.5962, -0.3972,  1.0654,  0.0000,  0.4179,  0.9464,
-         0.3835,  1.1266, -0.0424,  0.5093,  0.4666, -0.6983,  0.5687,  0.3391,
-         0.9374, -0.3153,  0.2666, -0.0591, -0.7222, -0.0757,  0.8320,  0.0020,
-         1.2009,  1.0451,  0.6168, -0.2452,  1.3270, -0.0100, -0.1978, -0.0275],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5591,  0.5713,  0.3191,  0.2683,  0.4031,  0.1885, -0.4046, -0.9093,
-        -0.0924, -0.5008,  0.1842, -0.4614,  0.4461,  0.2507, -1.3643, -0.4583,
-        -0.6997,  0.4122, -0.9827, -0.4529, -0.6729,  0.1595, -0.2354, -0.6300,
-        -0.5200,  0.5701, -0.1655,  0.0521, -0.2947,  1.5550, -0.6139, -0.4070,
-        -0.9856,  0.0863, -0.5962, -0.3972,  1.0654,  0.0000,  0.4179,  0.9464,
-         0.3835,  1.1266, -0.0424,  0.5093,  0.4666, -0.6983,  0.5687,  0.3391,
-         0.9374, -0.3153,  0.2666, -0.0591, -0.7222, -0.0757,  0.8320,  0.0000,
-         1.2009,  1.0451,  0.6168, -0.2452,  1.3270, -0.0100, -0.1978, -0.0275],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5591,  0.5713,  0.3191,  0.2683,  0.4031,  0.1885, -0.4046, -0.9093,
-        -0.0924, -0.5008,  0.1842, -0.4614,  0.4461,  0.2507, -1.3643, -0.4583,
-        -0.6997,  0.4122, -0.9827, -0.4529, -0.6729,  0.1595, -0.2354, -0.6300,
-        -0.5200,  0.5701, -0.1655,  0.0521, -0.2947,  1.5550, -0.6139, -0.4070,
-        -0.9856,  0.0863, -0.5962, -0.3972,  1.0654,  0.0000,  0.4179,  0.9464,
-         0.3835,  1.1266, -0.0424,  0.5093,  0.4666, -0.6983,  0.5687,  0.3391,
-         0.9374, -0.3153,  0.2666, -0.0591, -0.7222, -0.0757,  0.8320,  0.0000,
-         1.2009,  1.0451,  0.6168, -0.2452,  1.3270, -0.0100, -0.1978, -0.0275],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5294,  0.5662,  0.3447,  0.2772,  0.4133,  0.1618, -0.4182, -0.9009,
-        -0.0963, -0.4656,  0.1341, -0.4548,  0.4800,  0.2360, -1.3569, -0.4797,
-        -0.7023,  0.4031, -0.9835, -0.4663, -0.6635,  0.1190, -0.2683, -0.5804,
-        -0.5047,  0.5566, -0.0453,  0.0423, -0.2403,  1.5543, -0.5938, -0.3708,
-        -0.9854,  0.0826, -0.6210, -0.4279,  1.0800,  0.0000,  0.4280,  0.9370,
-         0.3243,  1.1275,  0.0757,  0.4907,  0.4579, -0.6528,  0.5595,  0.3604,
-         0.9395, -0.2901,  0.2684, -0.0788, -0.7184, -0.0177,  0.8324,  0.0018,
-         1.1991,  1.0430,  0.5760, -0.1971,  1.3300, -0.0120, -0.1227,  0.1753],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5294,  0.5662,  0.3447,  0.2772,  0.4133,  0.1618, -0.4182, -0.9009,
-        -0.0963, -0.4656,  0.1341, -0.4548,  0.4800,  0.2360, -1.3569, -0.4797,
-        -0.7023,  0.4031, -0.9835, -0.4663, -0.6635,  0.1190, -0.2683, -0.5804,
-        -0.5047,  0.5566, -0.0453,  0.0423, -0.2403,  1.5543, -0.5938, -0.3708,
-        -0.9854,  0.0826, -0.6210, -0.4279,  1.0800,  0.0000,  0.4280,  0.9370,
-         0.3243,  1.1275,  0.0757,  0.4907,  0.4579, -0.6528,  0.5595,  0.3604,
-         0.9395, -0.2901,  0.2684, -0.0788, -0.7184, -0.0177,  0.8324,  0.0000,
-         1.1991,  1.0430,  0.5760, -0.1971,  1.3300, -0.0120, -0.1227,  0.1753],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5294,  0.5662,  0.3447,  0.2772,  0.4133,  0.1618, -0.4182, -0.9009,
-        -0.0963, -0.4656,  0.1341, -0.4548,  0.4800,  0.2360, -1.3569, -0.4797,
-        -0.7023,  0.4031, -0.9835, -0.4663, -0.6635,  0.1190, -0.2683, -0.5804,
-        -0.5047,  0.5566, -0.0453,  0.0423, -0.2403,  1.5543, -0.5938, -0.3708,
-        -0.9854,  0.0826, -0.6210, -0.4279,  1.0800,  0.0000,  0.4280,  0.9370,
-         0.3243,  1.1275,  0.0757,  0.4907,  0.4579, -0.6528,  0.5595,  0.3604,
-         0.9395, -0.2901,  0.2684, -0.0788, -0.7184, -0.0177,  0.8324,  0.0000,
-         1.1991,  1.0430,  0.5760, -0.1971,  1.3300, -0.0120, -0.1227,  0.1753],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5005,  0.5507,  0.3612,  0.2846,  0.4012,  0.1246, -0.4335, -0.8915,
-        -0.1007, -0.4370,  0.0806, -0.4378,  0.5013,  0.2206, -1.3530, -0.4994,
-        -0.6981,  0.3951, -0.9832, -0.4623, -0.6551,  0.0596, -0.2913, -0.5655,
-        -0.4720,  0.5376,  0.0360,  0.0245, -0.2119,  1.5503, -0.5772, -0.3315,
-        -0.9879,  0.0964, -0.6416, -0.4411,  1.0923,  0.0000,  0.4323,  0.9236,
-         0.2467,  1.1289,  0.1601,  0.4739,  0.4536, -0.6137,  0.5505,  0.3725,
-         0.9316, -0.2821,  0.2657, -0.0886, -0.7127,  0.0501,  0.8308,  0.0016,
-         1.1966,  1.0436,  0.5378, -0.1836,  1.3361, -0.0289, -0.0317,  0.3285],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5005,  0.5507,  0.3612,  0.2846,  0.4012,  0.1246, -0.4335, -0.8915,
-        -0.1007, -0.4370,  0.0806, -0.4378,  0.5013,  0.2206, -1.3530, -0.4994,
-        -0.6981,  0.3951, -0.9832, -0.4623, -0.6551,  0.0596, -0.2913, -0.5655,
-        -0.4720,  0.5376,  0.0360,  0.0245, -0.2119,  1.5503, -0.5772, -0.3315,
-        -0.9879,  0.0964, -0.6416, -0.4411,  1.0923,  0.0000,  0.4323,  0.9236,
-         0.2467,  1.1289,  0.1601,  0.4739,  0.4536, -0.6137,  0.5505,  0.3725,
-         0.9316, -0.2821,  0.2657, -0.0886, -0.7127,  0.0501,  0.8308,  0.0000,
-         1.1966,  1.0436,  0.5378, -0.1836,  1.3361, -0.0289, -0.0317,  0.3285],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5005,  0.5507,  0.3612,  0.2846,  0.4012,  0.1246, -0.4335, -0.8915,
-        -0.1007, -0.4370,  0.0806, -0.4378,  0.5013,  0.2206, -1.3530, -0.4994,
-        -0.6981,  0.3951, -0.9832, -0.4623, -0.6551,  0.0596, -0.2913, -0.5655,
-        -0.4720,  0.5376,  0.0360,  0.0245, -0.2119,  1.5503, -0.5772, -0.3315,
-        -0.9879,  0.0964, -0.6416, -0.4411,  1.0923,  0.0000,  0.4323,  0.9236,
-         0.2467,  1.1289,  0.1601,  0.4739,  0.4536, -0.6137,  0.5505,  0.3725,
-         0.9316, -0.2821,  0.2657, -0.0886, -0.7127,  0.0501,  0.8308,  0.0000,
-         1.1966,  1.0436,  0.5378, -0.1836,  1.3361, -0.0289, -0.0317,  0.3285],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.7434e-01,  5.3136e-01,  3.4447e-01,  2.6340e-01,  3.8525e-01,
-         1.7047e-02, -4.3931e-01, -8.8958e-01, -1.2484e-01, -4.2201e-01,
-         5.8387e-02, -4.2758e-01,  4.8165e-01,  2.4041e-01, -1.3549e+00,
-        -5.0273e-01, -6.8400e-01,  3.8064e-01, -9.7913e-01, -4.5901e-01,
-        -6.4912e-01, -2.2920e-02, -2.8090e-01, -5.8313e-01, -4.4491e-01,
-         5.3314e-01, -2.4882e-04, -3.0611e-02, -2.1588e-01,  1.5445e+00,
-        -5.4475e-01, -2.9967e-01, -9.9086e-01,  9.5095e-02, -6.5237e-01,
-        -4.6093e-01,  1.1036e+00,  0.0000e+00,  4.5985e-01,  9.1634e-01,
-         1.6312e-01,  1.1300e+00,  1.8601e-01,  4.7185e-01,  4.4466e-01,
-        -5.9200e-01,  5.4432e-01,  3.6106e-01,  9.1395e-01, -2.7625e-01,
-         2.7235e-01, -1.5361e-01, -7.0213e-01,  7.8308e-02,  8.3940e-01,
-         1.3793e-03,  1.1936e+00,  1.0505e+00,  5.0011e-01, -2.3361e-01,
-         1.3433e+00, -1.1271e-01,  6.2518e-02,  4.0680e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.7434e-01,  5.3136e-01,  3.4447e-01,  2.6340e-01,  3.8525e-01,
-         1.7047e-02, -4.3931e-01, -8.8958e-01, -1.2484e-01, -4.2201e-01,
-         5.8387e-02, -4.2758e-01,  4.8165e-01,  2.4041e-01, -1.3549e+00,
-        -5.0273e-01, -6.8400e-01,  3.8064e-01, -9.7913e-01, -4.5901e-01,
-        -6.4912e-01, -2.2920e-02, -2.8090e-01, -5.8313e-01, -4.4491e-01,
-         5.3314e-01, -2.4882e-04, -3.0611e-02, -2.1588e-01,  1.5445e+00,
-        -5.4475e-01, -2.9967e-01, -9.9086e-01,  9.5095e-02, -6.5237e-01,
-        -4.6093e-01,  1.1036e+00,  0.0000e+00,  4.5985e-01,  9.1634e-01,
-         1.6312e-01,  1.1300e+00,  1.8601e-01,  4.7185e-01,  4.4466e-01,
-        -5.9200e-01,  5.4432e-01,  3.6106e-01,  9.1395e-01, -2.7625e-01,
-         2.7235e-01, -1.5361e-01, -7.0213e-01,  7.8308e-02,  8.3940e-01,
-         0.0000e+00,  1.1936e+00,  1.0505e+00,  5.0011e-01, -2.3361e-01,
-         1.3433e+00, -1.1271e-01,  6.2518e-02,  4.0680e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.7434e-01,  5.3136e-01,  3.4447e-01,  2.6340e-01,  3.8525e-01,
-         1.7047e-02, -4.3931e-01, -8.8958e-01, -1.2484e-01, -4.2201e-01,
-         5.8387e-02, -4.2758e-01,  4.8165e-01,  2.4041e-01, -1.3549e+00,
-        -5.0273e-01, -6.8400e-01,  3.8064e-01, -9.7913e-01, -4.5901e-01,
-        -6.4912e-01, -2.2920e-02, -2.8090e-01, -5.8313e-01, -4.4491e-01,
-         5.3314e-01, -2.4882e-04, -3.0611e-02, -2.1588e-01,  1.5445e+00,
-        -5.4475e-01, -2.9967e-01, -9.9086e-01,  9.5095e-02, -6.5237e-01,
-        -4.6093e-01,  1.1036e+00,  0.0000e+00,  4.5985e-01,  9.1634e-01,
-         1.6312e-01,  1.1300e+00,  1.8601e-01,  4.7185e-01,  4.4466e-01,
-        -5.9200e-01,  5.4432e-01,  3.6106e-01,  9.1395e-01, -2.7625e-01,
-         2.7235e-01, -1.5361e-01, -7.0213e-01,  7.8308e-02,  8.3940e-01,
-         0.0000e+00,  1.1936e+00,  1.0505e+00,  5.0011e-01, -2.3361e-01,
-         1.3433e+00, -1.1271e-01,  6.2518e-02,  4.0680e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4686e-01,  4.9925e-01,  3.1454e-01,  2.2944e-01,  2.8628e-01,
-        -1.1367e-01, -4.4661e-01, -8.8967e-01, -1.7278e-01, -4.0789e-01,
-         5.8748e-02, -4.5186e-01,  4.7216e-01,  2.7655e-01, -1.3624e+00,
-        -5.0357e-01, -6.6253e-01,  3.6791e-01, -9.7459e-01, -4.7744e-01,
-        -6.2327e-01, -5.5031e-02, -2.3543e-01, -5.7939e-01, -3.9417e-01,
-         5.3328e-01, -7.1986e-02, -1.0139e-01, -2.0998e-01,  1.5374e+00,
-        -5.0120e-01, -2.8643e-01, -9.9618e-01,  6.2610e-02, -6.5971e-01,
-        -4.8190e-01,  1.1151e+00,  0.0000e+00,  5.0631e-01,  9.0601e-01,
-         7.9952e-02,  1.1289e+00,  1.5917e-01,  4.9404e-01,  4.2534e-01,
-        -5.6014e-01,  5.3539e-01,  3.5149e-01,  8.9095e-01, -2.6897e-01,
-         2.8291e-01, -2.1534e-01, -6.8663e-01,  9.5797e-02,  8.4580e-01,
-         1.2193e-03,  1.1923e+00,  1.0567e+00,  4.4160e-01, -2.8526e-01,
-         1.3486e+00, -2.1306e-01,  1.3604e-01,  4.8053e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4469,  0.4993,  0.3145,  0.2294,  0.2863, -0.1137, -0.4466, -0.8897,
-        -0.1728, -0.4079,  0.0587, -0.4519,  0.4722,  0.2765, -1.3624, -0.5036,
-        -0.6625,  0.3679, -0.9746, -0.4774, -0.6233, -0.0550, -0.2354, -0.5794,
-        -0.3942,  0.5333, -0.0720, -0.1014, -0.2100,  1.5374, -0.5012, -0.2864,
-        -0.9962,  0.0626, -0.6597, -0.4819,  1.1151,  0.0000,  0.5063,  0.9060,
-         0.0800,  1.1289,  0.1592,  0.4940,  0.4253, -0.5601,  0.5354,  0.3515,
-         0.8909, -0.2690,  0.2829, -0.2153, -0.6866,  0.0958,  0.8458,  0.0000,
-         1.1923,  1.0567,  0.4416, -0.2853,  1.3486, -0.2131,  0.1360,  0.4805],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4469,  0.4993,  0.3145,  0.2294,  0.2863, -0.1137, -0.4466, -0.8897,
-        -0.1728, -0.4079,  0.0587, -0.4519,  0.4722,  0.2765, -1.3624, -0.5036,
-        -0.6625,  0.3679, -0.9746, -0.4774, -0.6233, -0.0550, -0.2354, -0.5794,
-        -0.3942,  0.5333, -0.0720, -0.1014, -0.2100,  1.5374, -0.5012, -0.2864,
-        -0.9962,  0.0626, -0.6597, -0.4819,  1.1151,  0.0000,  0.5063,  0.9060,
-         0.0800,  1.1289,  0.1592,  0.4940,  0.4253, -0.5601,  0.5354,  0.3515,
-         0.8909, -0.2690,  0.2829, -0.2153, -0.6866,  0.0958,  0.8458,  0.0000,
-         1.1923,  1.0567,  0.4416, -0.2853,  1.3486, -0.2131,  0.1360,  0.4805],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1489e-01,  4.6735e-01,  2.4744e-01,  2.2695e-01,  1.4792e-01,
-        -2.3412e-01, -4.4970e-01, -8.9424e-01, -2.1878e-01, -3.8687e-01,
-        -4.1711e-03, -5.1513e-01,  4.3972e-01,  3.3494e-01, -1.3702e+00,
-        -4.9991e-01, -6.4546e-01,  3.4746e-01, -9.7031e-01, -5.1515e-01,
-        -6.0155e-01, -2.4765e-02, -1.5839e-01, -5.6952e-01, -3.2878e-01,
-         5.3044e-01, -1.3007e-01, -1.5043e-01, -2.1795e-01,  1.5329e+00,
-        -4.4278e-01, -2.7795e-01, -1.0083e+00,  4.5158e-02, -6.6397e-01,
-        -5.2204e-01,  1.1237e+00,  0.0000e+00,  5.4498e-01,  8.9738e-01,
-         1.9331e-02,  1.1237e+00,  9.0092e-02,  5.1898e-01,  4.0142e-01,
-        -5.2836e-01,  5.2314e-01,  3.3275e-01,  8.7549e-01, -2.5520e-01,
-         2.7865e-01, -2.6177e-01, -6.8787e-01,  8.1429e-02,  8.4915e-01,
-         1.0766e-03,  1.1914e+00,  1.0611e+00,  4.0930e-01, -3.1299e-01,
-         1.3517e+00, -2.8304e-01,  1.8263e-01,  5.2554e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4149,  0.4673,  0.2474,  0.2270,  0.1479, -0.2341, -0.4497, -0.8942,
-        -0.2188, -0.3869, -0.0042, -0.5151,  0.4397,  0.3349, -1.3702, -0.4999,
-        -0.6455,  0.3475, -0.9703, -0.5152, -0.6015, -0.0248, -0.1584, -0.5695,
-        -0.3288,  0.5304, -0.1301, -0.1504, -0.2179,  1.5329, -0.4428, -0.2779,
-        -1.0083,  0.0452, -0.6640, -0.5220,  1.1237,  0.0000,  0.5450,  0.8974,
-         0.0193,  1.1237,  0.0901,  0.5190,  0.4014, -0.5284,  0.5231,  0.3328,
-         0.8755, -0.2552,  0.2786, -0.2618, -0.6879,  0.0814,  0.8492,  0.0000,
-         1.1914,  1.0611,  0.4093, -0.3130,  1.3517, -0.2830,  0.1826,  0.5255],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4149,  0.4673,  0.2474,  0.2270,  0.1479, -0.2341, -0.4497, -0.8942,
-        -0.2188, -0.3869, -0.0042, -0.5151,  0.4397,  0.3349, -1.3702, -0.4999,
-        -0.6455,  0.3475, -0.9703, -0.5152, -0.6015, -0.0248, -0.1584, -0.5695,
-        -0.3288,  0.5304, -0.1301, -0.1504, -0.2179,  1.5329, -0.4428, -0.2779,
-        -1.0083,  0.0452, -0.6640, -0.5220,  1.1237,  0.0000,  0.5450,  0.8974,
-         0.0193,  1.1237,  0.0901,  0.5190,  0.4014, -0.5284,  0.5231,  0.3328,
-         0.8755, -0.2552,  0.2786, -0.2618, -0.6879,  0.0814,  0.8492,  0.0000,
-         1.1914,  1.0611,  0.4093, -0.3130,  1.3517, -0.2830,  0.1826,  0.5255],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.8658e-01,  4.3207e-01,  2.0286e-01,  2.3765e-01,  7.5984e-02,
-        -3.3305e-01, -4.5326e-01, -8.9755e-01, -2.5361e-01, -3.7180e-01,
-        -5.4420e-02, -5.6576e-01,  4.2756e-01,  3.6967e-01, -1.3699e+00,
-        -5.0082e-01, -6.3573e-01,  3.1747e-01, -9.7002e-01, -5.3799e-01,
-        -5.8383e-01,  1.8670e-02, -9.7602e-02, -5.7139e-01, -2.7084e-01,
-         5.1331e-01, -1.9138e-01, -1.9863e-01, -2.2566e-01,  1.5297e+00,
-        -4.0576e-01, -2.9125e-01, -1.0203e+00,  4.0524e-02, -6.6513e-01,
-        -5.5201e-01,  1.1272e+00,  0.0000e+00,  5.6757e-01,  8.8297e-01,
-        -2.8388e-02,  1.1182e+00,  5.9735e-02,  5.3286e-01,  3.7572e-01,
-        -5.0664e-01,  5.1590e-01,  3.2802e-01,  8.7402e-01, -2.5172e-01,
-         3.0577e-01, -2.8437e-01, -6.9558e-01,  1.0266e-01,  8.4576e-01,
-         9.4944e-04,  1.1903e+00,  1.0571e+00,  3.9179e-01, -3.5700e-01,
-         1.3521e+00, -3.3488e-01,  2.3143e-01,  5.7027e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3866,  0.4321,  0.2029,  0.2376,  0.0760, -0.3331, -0.4533, -0.8976,
-        -0.2536, -0.3718, -0.0544, -0.5658,  0.4276,  0.3697, -1.3699, -0.5008,
-        -0.6357,  0.3175, -0.9700, -0.5380, -0.5838,  0.0187, -0.0976, -0.5714,
-        -0.2708,  0.5133, -0.1914, -0.1986, -0.2257,  1.5297, -0.4058, -0.2913,
-        -1.0203,  0.0405, -0.6651, -0.5520,  1.1272,  0.0000,  0.5676,  0.8830,
-        -0.0284,  1.1182,  0.0597,  0.5329,  0.3757, -0.5066,  0.5159,  0.3280,
-         0.8740, -0.2517,  0.3058, -0.2844, -0.6956,  0.1027,  0.8458,  0.0000,
-         1.1903,  1.0571,  0.3918, -0.3570,  1.3521, -0.3349,  0.2314,  0.5703],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3866,  0.4321,  0.2029,  0.2376,  0.0760, -0.3331, -0.4533, -0.8976,
-        -0.2536, -0.3718, -0.0544, -0.5658,  0.4276,  0.3697, -1.3699, -0.5008,
-        -0.6357,  0.3175, -0.9700, -0.5380, -0.5838,  0.0187, -0.0976, -0.5714,
-        -0.2708,  0.5133, -0.1914, -0.1986, -0.2257,  1.5297, -0.4058, -0.2913,
-        -1.0203,  0.0405, -0.6651, -0.5520,  1.1272,  0.0000,  0.5676,  0.8830,
-        -0.0284,  1.1182,  0.0597,  0.5329,  0.3757, -0.5066,  0.5159,  0.3280,
-         0.8740, -0.2517,  0.3058, -0.2844, -0.6956,  0.1027,  0.8458,  0.0000,
-         1.1903,  1.0571,  0.3918, -0.3570,  1.3521, -0.3349,  0.2314,  0.5703],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5832e-01,  4.1469e-01,  1.6834e-01,  2.3787e-01,  6.2973e-02,
-        -4.0067e-01, -4.6883e-01, -8.9622e-01, -2.5055e-01, -3.6187e-01,
-        -6.0464e-02, -5.8041e-01,  4.0624e-01,  3.8544e-01, -1.3688e+00,
-        -5.1530e-01, -6.4254e-01,  2.9577e-01, -9.7209e-01, -5.3324e-01,
-        -5.5106e-01,  7.3299e-02, -6.7736e-02, -5.7136e-01, -2.1825e-01,
-         4.7969e-01, -2.5544e-01, -2.5601e-01, -1.9099e-01,  1.5274e+00,
-        -4.1321e-01, -2.9038e-01, -1.0316e+00,  7.1851e-02, -6.6025e-01,
-        -5.6114e-01,  1.1279e+00,  0.0000e+00,  5.6557e-01,  8.6631e-01,
-        -8.3516e-02,  1.1080e+00,  7.3494e-02,  5.3421e-01,  3.6870e-01,
-        -5.0147e-01,  5.1789e-01,  3.1655e-01,  8.7763e-01, -2.7093e-01,
-         3.4129e-01, -2.3053e-01, -7.0188e-01,  1.2035e-01,  8.4321e-01,
-         8.3629e-04,  1.1892e+00,  1.0512e+00,  3.9486e-01, -3.7781e-01,
-         1.3520e+00, -3.9737e-01,  2.6004e-01,  6.2106e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3583,  0.4147,  0.1683,  0.2379,  0.0630, -0.4007, -0.4688, -0.8962,
-        -0.2505, -0.3619, -0.0605, -0.5804,  0.4062,  0.3854, -1.3688, -0.5153,
-        -0.6425,  0.2958, -0.9721, -0.5332, -0.5511,  0.0733, -0.0677, -0.5714,
-        -0.2183,  0.4797, -0.2554, -0.2560, -0.1910,  1.5274, -0.4132, -0.2904,
-        -1.0316,  0.0719, -0.6603, -0.5611,  1.1279,  0.0000,  0.5656,  0.8663,
-        -0.0835,  1.1080,  0.0735,  0.5342,  0.3687, -0.5015,  0.5179,  0.3166,
-         0.8776, -0.2709,  0.3413, -0.2305, -0.7019,  0.1203,  0.8432,  0.0000,
-         1.1892,  1.0512,  0.3949, -0.3778,  1.3520, -0.3974,  0.2600,  0.6211],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3583,  0.4147,  0.1683,  0.2379,  0.0630, -0.4007, -0.4688, -0.8962,
-        -0.2505, -0.3619, -0.0605, -0.5804,  0.4062,  0.3854, -1.3688, -0.5153,
-        -0.6425,  0.2958, -0.9721, -0.5332, -0.5511,  0.0733, -0.0677, -0.5714,
-        -0.2183,  0.4797, -0.2554, -0.2560, -0.1910,  1.5274, -0.4132, -0.2904,
-        -1.0316,  0.0719, -0.6603, -0.5611,  1.1279,  0.0000,  0.5656,  0.8663,
-        -0.0835,  1.1080,  0.0735,  0.5342,  0.3687, -0.5015,  0.5179,  0.3166,
-         0.8776, -0.2709,  0.3413, -0.2305, -0.7019,  0.1203,  0.8432,  0.0000,
-         1.1892,  1.0512,  0.3949, -0.3778,  1.3520, -0.3974,  0.2600,  0.6211],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2526e-01,  4.0956e-01,  1.2180e-01,  1.8539e-01, -5.2661e-03,
-        -4.5072e-01, -4.7692e-01, -8.8702e-01, -2.0640e-01, -3.6031e-01,
-        -5.8457e-02, -5.6045e-01,  3.3965e-01,  3.4610e-01, -1.3687e+00,
-        -5.3847e-01, -6.5621e-01,  2.5606e-01, -9.7466e-01, -5.1314e-01,
-        -5.2988e-01,  9.2788e-02, -6.9188e-02, -5.6227e-01, -2.0315e-01,
-         4.3293e-01, -2.7233e-01, -3.0383e-01, -1.6673e-01,  1.5235e+00,
-        -4.3152e-01, -2.6211e-01, -1.0452e+00,  1.0224e-01, -6.5439e-01,
-        -5.4628e-01,  1.1253e+00,  0.0000e+00,  5.3577e-01,  8.4370e-01,
-        -1.7305e-01,  1.0959e+00,  9.2268e-02,  5.2217e-01,  3.6916e-01,
-        -4.9427e-01,  5.2949e-01,  3.0132e-01,  8.8039e-01, -2.8069e-01,
-         3.8209e-01, -1.7640e-01, -6.9605e-01,  1.3251e-01,  8.4039e-01,
-         7.3576e-04,  1.1878e+00,  1.0463e+00,  4.1179e-01, -4.3116e-01,
-         1.3512e+00, -4.4218e-01,  2.5608e-01,  6.5839e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3253,  0.4096,  0.1218,  0.1854, -0.0053, -0.4507, -0.4769, -0.8870,
-        -0.2064, -0.3603, -0.0585, -0.5605,  0.3397,  0.3461, -1.3687, -0.5385,
-        -0.6562,  0.2561, -0.9747, -0.5131, -0.5299,  0.0928, -0.0692, -0.5623,
-        -0.2032,  0.4329, -0.2723, -0.3038, -0.1667,  1.5235, -0.4315, -0.2621,
-         0.0000,  0.1022, -0.6544, -0.5463,  1.1253,  0.0000,  0.5358,  0.8437,
-        -0.1730,  1.0959,  0.0923,  0.5222,  0.3692, -0.4943,  0.5295,  0.3013,
-         0.8804, -0.2807,  0.3821, -0.1764, -0.6960,  0.1325,  0.8404,  0.0000,
-         1.1878,  1.0463,  0.4118, -0.4312,  1.3512, -0.4422,  0.2561,  0.6584],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3253,  0.4096,  0.1218,  0.1854, -0.0053, -0.4507, -0.4769, -0.8870,
-        -0.2064, -0.3603, -0.0585, -0.5605,  0.3397,  0.3461, -1.3687, -0.5385,
-        -0.6562,  0.2561, -0.9747, -0.5131, -0.5299,  0.0928, -0.0692, -0.5623,
-        -0.2032,  0.4329, -0.2723, -0.3038, -0.1667,  1.5235, -0.4315, -0.2621,
-         0.0000,  0.1022, -0.6544, -0.5463,  1.1253,  0.0000,  0.5358,  0.8437,
-        -0.1730,  1.0959,  0.0923,  0.5222,  0.3692, -0.4943,  0.5295,  0.3013,
-         0.8804, -0.2807,  0.3821, -0.1764, -0.6960,  0.1325,  0.8404,  0.0000,
-         1.1878,  1.0463,  0.4118, -0.4312,  1.3512, -0.4422,  0.2561,  0.6584],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8335e-01,  3.7645e-01,  7.2242e-02,  9.0517e-02, -9.5988e-02,
-        -4.8812e-01, -4.7566e-01, -8.7639e-01, -1.5875e-01, -3.6006e-01,
-        -4.5103e-02, -5.1877e-01,  2.3706e-01,  3.1634e-01, -1.3701e+00,
-        -5.6253e-01, -6.5865e-01,  1.8215e-01, -9.7692e-01, -4.7606e-01,
-        -5.1803e-01,  3.9137e-02, -7.8780e-02, -5.5568e-01, -2.0589e-01,
-         3.8830e-01, -2.7659e-01, -3.4333e-01, -1.3575e-01,  1.5201e+00,
-        -4.5991e-01, -2.4152e-01, -1.1976e-02,  8.1002e-02, -6.4323e-01,
-        -5.2810e-01,  1.1240e+00,  0.0000e+00,  5.3377e-01,  8.2152e-01,
-        -2.8555e-01,  1.0840e+00,  8.8123e-02,  5.1602e-01,  3.7535e-01,
-        -4.8682e-01,  5.4024e-01,  2.8516e-01,  8.8603e-01, -2.7155e-01,
-         3.9116e-01, -1.3488e-01, -6.8366e-01,  4.1572e-02,  8.3335e-01,
-         6.4655e-04,  1.1855e+00,  1.0445e+00,  4.4608e-01, -5.2146e-01,
-         1.3502e+00, -4.7163e-01,  2.6009e-01,  6.6427e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2833,  0.3764,  0.0722,  0.0905, -0.0960, -0.4881, -0.4757, -0.8764,
-        -0.1588, -0.3601, -0.0451, -0.5188,  0.2371,  0.3163, -1.3701, -0.5625,
-        -0.6586,  0.1821, -0.9769, -0.4761, -0.5180,  0.0391, -0.0788, -0.5557,
-        -0.2059,  0.3883, -0.2766, -0.3433, -0.1358,  1.5201, -0.4599, -0.2415,
-         0.0000,  0.0810, -0.6432, -0.5281,  1.1240,  0.0000,  0.5338,  0.8215,
-        -0.2855,  1.0840,  0.0881,  0.5160,  0.3754, -0.4868,  0.5402,  0.2852,
-         0.8860, -0.2715,  0.3912, -0.1349, -0.6837,  0.0416,  0.8334,  0.0000,
-         1.1855,  1.0445,  0.4461, -0.5215,  1.3502, -0.4716,  0.2601,  0.6643],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2833,  0.3764,  0.0722,  0.0905, -0.0960, -0.4881, -0.4757, -0.8764,
-        -0.1588, -0.3601, -0.0451, -0.5188,  0.2371,  0.3163, -1.3701, -0.5625,
-        -0.6586,  0.1821, -0.9769, -0.4761, -0.5180,  0.0391, -0.0788, -0.5557,
-        -0.2059,  0.3883, -0.2766, -0.3433, -0.1358,  1.5201, -0.4599, -0.2415,
-         0.0000,  0.0810, -0.6432, -0.5281,  1.1240,  0.0000,  0.5338,  0.8215,
-        -0.2855,  1.0840,  0.0881,  0.5160,  0.3754, -0.4868,  0.5402,  0.2852,
-         0.8860, -0.2715,  0.3912, -0.1349, -0.6837,  0.0416,  0.8334,  0.0000,
-         1.1855,  1.0445,  0.4461, -0.5215,  1.3502, -0.4716,  0.2601,  0.6643],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4234e-01,  3.3863e-01,  3.4895e-02,  2.5590e-02, -1.3559e-01,
-        -5.2509e-01, -4.6994e-01, -8.6440e-01, -1.1404e-01, -3.5915e-01,
-        -7.6912e-03, -4.9026e-01,  1.2877e-01,  3.0387e-01, -1.3708e+00,
-        -5.8762e-01, -6.5281e-01,  8.4058e-02, -9.7731e-01, -4.4821e-01,
-        -4.9260e-01, -6.7715e-02, -9.7038e-02, -5.2340e-01, -2.1808e-01,
-         3.5036e-01, -2.6479e-01, -3.6615e-01, -6.8418e-02,  1.5187e+00,
-        -4.6735e-01, -2.1986e-01, -1.0512e-02,  2.9859e-02, -6.3482e-01,
-        -5.1298e-01,  1.1219e+00,  0.0000e+00,  5.2084e-01,  7.9900e-01,
-        -3.8237e-01,  1.0663e+00,  7.0612e-02,  5.1228e-01,  3.7495e-01,
-        -4.7571e-01,  5.4339e-01,  2.8009e-01,  8.9329e-01, -2.4979e-01,
-         3.5966e-01, -1.0902e-01, -6.6832e-01, -6.5344e-02,  8.2540e-01,
-         5.6750e-04,  1.1828e+00,  1.0378e+00,  4.8432e-01, -5.6467e-01,
-         1.3480e+00, -4.7184e-01,  2.4507e-01,  6.6420e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2423,  0.3386,  0.0349,  0.0256, -0.1356, -0.5251, -0.4699, -0.8644,
-        -0.1140, -0.3591, -0.0077, -0.4903,  0.1288,  0.3039, -1.3708, -0.5876,
-        -0.6528,  0.0841, -0.9773, -0.4482, -0.4926, -0.0677, -0.0970, -0.5234,
-        -0.2181,  0.3504, -0.2648, -0.3661, -0.0684,  1.5187, -0.4673, -0.2199,
-         0.0000,  0.0299, -0.6348, -0.5130,  1.1219,  0.0000,  0.5208,  0.7990,
-        -0.3824,  1.0663,  0.0706,  0.5123,  0.3750, -0.4757,  0.5434,  0.2801,
-         0.8933, -0.2498,  0.3597, -0.1090, -0.6683, -0.0653,  0.8254,  0.0000,
-         1.1828,  1.0378,  0.4843, -0.5647,  1.3480, -0.4718,  0.2451,  0.6642],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2423,  0.3386,  0.0349,  0.0256, -0.1356, -0.5251, -0.4699, -0.8644,
-        -0.1140, -0.3591, -0.0077, -0.4903,  0.1288,  0.3039, -1.3708, -0.5876,
-        -0.6528,  0.0841, -0.9773, -0.4482, -0.4926, -0.0677, -0.0970, -0.5234,
-        -0.2181,  0.3504, -0.2648, -0.3661, -0.0684,  1.5187, -0.4673, -0.2199,
-         0.0000,  0.0299, -0.6348, -0.5130,  1.1219,  0.0000,  0.5208,  0.7990,
-        -0.3824,  1.0663,  0.0706,  0.5123,  0.3750, -0.4757,  0.5434,  0.2801,
-         0.8933, -0.2498,  0.3597, -0.1090, -0.6683, -0.0653,  0.8254,  0.0000,
-         1.1828,  1.0378,  0.4843, -0.5647,  1.3480, -0.4718,  0.2451,  0.6642],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.2510e-01,  3.0338e-01,  1.1492e-02, -1.2641e-02, -1.3784e-01,
-        -5.4989e-01, -4.5570e-01, -8.4936e-01, -8.4999e-02, -3.5397e-01,
-         5.5501e-02, -4.6702e-01,  1.1279e-02,  3.0916e-01, -1.3711e+00,
-        -6.0596e-01, -6.4047e-01, -4.1169e-02, -9.7885e-01, -4.2125e-01,
-        -4.5465e-01, -1.2476e-01, -1.0983e-01, -4.8359e-01, -1.8692e-01,
-         2.9702e-01, -2.3227e-01, -3.7419e-01,  3.2989e-02,  1.5181e+00,
-        -4.8642e-01, -2.0648e-01, -9.2159e-03, -2.4378e-02, -6.2438e-01,
-        -4.9296e-01,  1.1179e+00,  0.0000e+00,  5.2893e-01,  7.7672e-01,
-        -4.4759e-01,  1.0488e+00,  6.8724e-02,  5.2019e-01,  3.5787e-01,
-        -4.7100e-01,  5.5387e-01,  2.8443e-01,  9.1257e-01, -2.1086e-01,
-         3.3195e-01, -5.1742e-02, -6.5144e-01, -1.1947e-01,  8.0996e-01,
-         4.9755e-04,  1.1798e+00,  1.0248e+00,  5.3054e-01, -5.8665e-01,
-         1.3438e+00, -4.5631e-01,  2.5315e-01,  6.8225e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2251,  0.3034,  0.0115, -0.0126, -0.1378, -0.5499, -0.4557, -0.8494,
-        -0.0850, -0.3540,  0.0555, -0.4670,  0.0113,  0.3092, -1.3711, -0.6060,
-        -0.6405, -0.0412, -0.9789, -0.4212, -0.4547, -0.1248, -0.1098, -0.4836,
-        -0.1869,  0.2970, -0.2323, -0.3742,  0.0330,  1.5181, -0.4864, -0.2065,
-         0.0000, -0.0244, -0.6244, -0.4930,  1.1179,  0.0000,  0.5289,  0.7767,
-        -0.4476,  1.0488,  0.0687,  0.5202,  0.3579, -0.4710,  0.5539,  0.2844,
-         0.9126, -0.2109,  0.3319, -0.0517, -0.6514, -0.1195,  0.8100,  0.0000,
-         1.1798,  1.0248,  0.5305, -0.5867,  1.3438, -0.4563,  0.2532,  0.6823],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2251,  0.3034,  0.0115, -0.0126, -0.1378, -0.5499, -0.4557, -0.8494,
-        -0.0850, -0.3540,  0.0555, -0.4670,  0.0113,  0.3092, -1.3711, -0.6060,
-        -0.6405, -0.0412, -0.9789, -0.4212, -0.4547, -0.1248, -0.1098, -0.4836,
-        -0.1869,  0.2970, -0.2323, -0.3742,  0.0330,  1.5181, -0.4864, -0.2065,
-         0.0000, -0.0244, -0.6244, -0.4930,  1.1179,  0.0000,  0.5289,  0.7767,
-        -0.4476,  1.0488,  0.0687,  0.5202,  0.3579, -0.4710,  0.5539,  0.2844,
-         0.9126, -0.2109,  0.3319, -0.0517, -0.6514, -0.1195,  0.8100,  0.0000,
-         1.1798,  1.0248,  0.5305, -0.5867,  1.3438, -0.4563,  0.2532,  0.6823],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3037e-01,  2.9201e-01, -3.1206e-02, -4.7282e-02, -6.9462e-02,
-        -5.7195e-01, -4.2314e-01, -8.3676e-01, -4.6083e-02, -3.5037e-01,
-         1.1067e-01, -4.5478e-01, -1.2709e-01,  2.9887e-01, -1.3727e+00,
-        -6.0974e-01, -6.1831e-01, -1.3539e-01, -9.7930e-01, -4.0291e-01,
-        -4.0710e-01, -1.1956e-01, -1.1495e-01, -4.2807e-01, -1.0415e-01,
-         2.6321e-01, -1.8016e-01, -3.6382e-01,  1.2106e-01,  1.5182e+00,
-        -4.9538e-01, -2.0564e-01, -8.0709e-03, -1.8849e-02, -6.1295e-01,
-        -4.7795e-01,  1.1156e+00,  0.0000e+00,  5.4083e-01,  7.5445e-01,
-        -4.8830e-01,  1.0327e+00,  8.3477e-02,  5.3046e-01,  3.3649e-01,
-        -4.4017e-01,  5.6419e-01,  2.9533e-01,  9.3467e-01, -1.5506e-01,
-         3.1528e-01,  1.9320e-02, -6.2648e-01, -1.0191e-01,  7.9118e-01,
-         4.3574e-04,  1.1760e+00,  1.0123e+00,  5.8793e-01, -5.6897e-01,
-         1.3386e+00, -4.3941e-01,  2.6283e-01,  7.1293e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2304,  0.2920, -0.0312, -0.0473, -0.0695, -0.5719, -0.4231, -0.8368,
-        -0.0461, -0.3504,  0.1107, -0.4548, -0.1271,  0.2989, -1.3727, -0.6097,
-        -0.6183, -0.1354, -0.9793, -0.4029, -0.4071, -0.1196, -0.1149, -0.4281,
-        -0.1041,  0.2632, -0.1802, -0.3638,  0.1211,  1.5182, -0.4954, -0.2056,
-         0.0000, -0.0188, -0.6130, -0.4779,  1.1156,  0.0000,  0.5408,  0.7545,
-        -0.4883,  1.0327,  0.0835,  0.5305,  0.3365, -0.4402,  0.5642,  0.2953,
-         0.9347, -0.1551,  0.3153,  0.0193, -0.6265, -0.1019,  0.7912,  0.0000,
-         1.1760,  1.0123,  0.5879, -0.5690,  1.3386, -0.4394,  0.2628,  0.7129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2304,  0.2920, -0.0312, -0.0473, -0.0695, -0.5719, -0.4231, -0.8368,
-        -0.0461, -0.3504,  0.1107, -0.4548, -0.1271,  0.2989, -1.3727, -0.6097,
-        -0.6183, -0.1354, -0.9793, -0.4029, -0.4071, -0.1196, -0.1149, -0.4281,
-        -0.1041,  0.2632, -0.1802, -0.3638,  0.1211,  1.5182, -0.4954, -0.2056,
-         0.0000, -0.0188, -0.6130, -0.4779,  1.1156,  0.0000,  0.5408,  0.7545,
-        -0.4883,  1.0327,  0.0835,  0.5305,  0.3365, -0.4402,  0.5642,  0.2953,
-         0.9347, -0.1551,  0.3153,  0.0193, -0.6265, -0.1019,  0.7912,  0.0000,
-         1.1760,  1.0123,  0.5879, -0.5690,  1.3386, -0.4394,  0.2628,  0.7129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.5148e-01,  2.3096e-01, -7.0476e-02, -6.4816e-02, -9.3666e-03,
-        -5.8261e-01, -3.8003e-01, -8.2577e-01, -1.6536e-03, -3.2015e-01,
-         8.9324e-02, -4.4685e-01, -2.5598e-01,  3.6009e-01, -1.3754e+00,
-        -6.1005e-01, -6.0531e-01, -2.1327e-01, -9.7892e-01, -4.0289e-01,
-        -3.4515e-01, -1.9114e-02, -8.7746e-02, -4.1921e-01,  3.0078e-02,
-         2.6468e-01, -8.9257e-02, -3.2142e-01,  1.7207e-01,  1.5187e+00,
-        -5.0466e-01, -2.0565e-01, -7.0603e-03,  3.9014e-02, -6.0936e-01,
-        -4.9118e-01,  1.1190e+00,  0.0000e+00,  6.3535e-01,  7.5230e-01,
-        -4.8854e-01,  1.0188e+00,  7.2376e-02,  5.7060e-01,  3.0097e-01,
-        -3.7380e-01,  5.6350e-01,  3.4461e-01,  9.5013e-01, -1.0100e-01,
-         2.7893e-01,  1.1064e-01, -5.8045e-01,  7.0618e-03,  7.7870e-01,
-         3.8117e-04,  1.1715e+00,  1.0237e+00,  6.3909e-01, -5.3812e-01,
-         1.3357e+00, -3.9157e-01,  2.7676e-01,  7.2371e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2515,  0.2310, -0.0705, -0.0648, -0.0094, -0.5826, -0.3800, -0.8258,
-        -0.0017, -0.3201,  0.0893, -0.4469, -0.2560,  0.3601, -1.3754, -0.6100,
-        -0.6053, -0.2133, -0.9789, -0.4029, -0.3451, -0.0191, -0.0877, -0.4192,
-         0.0301,  0.2647, -0.0893, -0.3214,  0.1721,  1.5187, -0.5047, -0.2057,
-         0.0000,  0.0390, -0.6094, -0.4912,  1.1190,  0.0000,  0.6353,  0.7523,
-        -0.4885,  1.0188,  0.0724,  0.5706,  0.3010, -0.3738,  0.5635,  0.3446,
-         0.9501, -0.1010,  0.2789,  0.1106, -0.5805,  0.0071,  0.7787,  0.0000,
-         1.1715,  1.0237,  0.6391, -0.5381,  1.3357, -0.3916,  0.2768,  0.7237],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2515,  0.2310, -0.0705, -0.0648, -0.0094, -0.5826, -0.3800, -0.8258,
-        -0.0017, -0.3201,  0.0893, -0.4469, -0.2560,  0.3601, -1.3754, -0.6100,
-        -0.6053, -0.2133, -0.9789, -0.4029, -0.3451, -0.0191, -0.0877, -0.4192,
-         0.0301,  0.2647, -0.0893, -0.3214,  0.1721,  1.5187, -0.5047, -0.2057,
-         0.0000,  0.0390, -0.6094, -0.4912,  1.1190,  0.0000,  0.6353,  0.7523,
-        -0.4885,  1.0188,  0.0724,  0.5706,  0.3010, -0.3738,  0.5635,  0.3446,
-         0.9501, -0.1010,  0.2789,  0.1106, -0.5805,  0.0071,  0.7787,  0.0000,
-         1.1715,  1.0237,  0.6391, -0.5381,  1.3357, -0.3916,  0.2768,  0.7237],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8784e-01,  1.9606e-01, -3.0975e-02, -6.8456e-02,  7.4342e-02,
-        -5.8470e-01, -3.3447e-01, -8.1293e-01,  4.6377e-02, -3.0117e-01,
-         1.0643e-01, -4.1663e-01, -2.9100e-01,  4.1550e-01, -1.3795e+00,
-        -6.0941e-01, -5.9117e-01, -2.8134e-01, -9.7999e-01, -3.9381e-01,
-        -2.8723e-01,  1.5130e-02, -9.7310e-02, -3.9246e-01,  9.1345e-02,
-         2.7114e-01,  2.5741e-02, -2.9561e-01,  2.3816e-01,  1.5207e+00,
-        -5.1283e-01, -1.8798e-01, -6.1695e-03,  4.0847e-02, -6.0354e-01,
-        -5.0087e-01,  1.1245e+00,  0.0000e+00,  7.2577e-01,  7.5191e-01,
-        -4.8876e-01,  1.0086e+00,  9.4448e-02,  5.9994e-01,  2.7256e-01,
-        -3.2912e-01,  5.6066e-01,  3.9535e-01,  9.5921e-01, -5.0561e-03,
-         2.6296e-01,  1.6482e-01, -5.5183e-01,  1.1016e-01,  7.6922e-01,
-         3.3308e-04,  1.1670e+00,  1.0388e+00,  6.9050e-01, -5.1265e-01,
-         1.3343e+00, -3.3464e-01,  2.9953e-01,  7.2662e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2878,  0.1961, -0.0310, -0.0685,  0.0743, -0.5847, -0.3345, -0.8129,
-         0.0464, -0.3012,  0.1064, -0.4166, -0.2910,  0.4155, -1.3795, -0.6094,
-        -0.5912, -0.2813, -0.9800, -0.3938, -0.2872,  0.0151, -0.0973, -0.3925,
-         0.0913,  0.2711,  0.0257, -0.2956,  0.2382,  1.5207, -0.5128, -0.1880,
-         0.0000,  0.0408, -0.6035, -0.5009,  1.1245,  0.0000,  0.7258,  0.7519,
-        -0.4888,  1.0086,  0.0944,  0.5999,  0.2726, -0.3291,  0.5607,  0.3953,
-         0.9592, -0.0051,  0.2630,  0.1648, -0.5518,  0.1102,  0.7692,  0.0000,
-         1.1670,  1.0388,  0.6905, -0.5126,  1.3343, -0.3346,  0.2995,  0.7266],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2878,  0.1961, -0.0310, -0.0685,  0.0743, -0.5847, -0.3345, -0.8129,
-         0.0464, -0.3012,  0.1064, -0.4166, -0.2910,  0.4155, -1.3795, -0.6094,
-        -0.5912, -0.2813, -0.9800, -0.3938, -0.2872,  0.0151, -0.0973, -0.3925,
-         0.0913,  0.2711,  0.0257, -0.2956,  0.2382,  1.5207, -0.5128, -0.1880,
-         0.0000,  0.0408, -0.6035, -0.5009,  1.1245,  0.0000,  0.7258,  0.7519,
-        -0.4888,  1.0086,  0.0944,  0.5999,  0.2726, -0.3291,  0.5607,  0.3953,
-         0.9592, -0.0051,  0.2630,  0.1648, -0.5518,  0.1102,  0.7692,  0.0000,
-         1.1670,  1.0388,  0.6905, -0.5126,  1.3343, -0.3346,  0.2995,  0.7266],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2016e-01,  1.6902e-01,  6.5960e-02, -5.2243e-02,  5.4671e-02,
-        -5.8142e-01, -2.9816e-01, -7.9898e-01,  9.6247e-02, -2.7137e-01,
-         1.3916e-01, -3.9343e-01, -3.1070e-01,  4.5726e-01, -1.3844e+00,
-        -6.0476e-01, -5.8012e-01, -3.5720e-01, -9.8188e-01, -3.8313e-01,
-        -2.2709e-01,  4.4916e-02, -9.1694e-02, -4.2204e-01,  1.6743e-01,
-         2.7368e-01,  1.5190e-01, -2.5995e-01,  3.1660e-01,  1.5222e+00,
-        -5.1727e-01, -1.6358e-01, -5.3853e-03,  4.5528e-02, -5.9365e-01,
-        -5.0034e-01,  1.1287e+00,  0.0000e+00,  8.0447e-01,  7.4696e-01,
-        -4.7955e-01,  1.0071e+00,  1.3333e-01,  6.2559e-01,  2.7621e-01,
-        -2.8693e-01,  5.5761e-01,  4.3043e-01,  9.5421e-01,  1.2754e-01,
-         2.5680e-01,  2.3589e-01, -5.3667e-01,  1.7479e-01,  7.6241e-01,
-         2.9074e-04,  1.1624e+00,  1.0563e+00,  7.3782e-01, -5.1307e-01,
-         1.3352e+00, -2.6328e-01,  3.2791e-01,  7.1292e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3202,  0.1690,  0.0660, -0.0522,  0.0547, -0.5814, -0.2982, -0.7990,
-         0.0962, -0.2714,  0.1392, -0.3934, -0.3107,  0.4573, -1.3844, -0.6048,
-        -0.5801, -0.3572, -0.9819, -0.3831, -0.2271,  0.0449, -0.0917, -0.4220,
-         0.1674,  0.2737,  0.1519, -0.2600,  0.3166,  1.5222, -0.5173, -0.1636,
-         0.0000,  0.0455, -0.5936, -0.5003,  1.1287,  0.0000,  0.8045,  0.7470,
-        -0.4796,  1.0071,  0.1333,  0.6256,  0.2762, -0.2869,  0.5576,  0.4304,
-         0.9542,  0.1275,  0.2568,  0.2359, -0.5367,  0.1748,  0.7624,  0.0000,
-         1.1624,  1.0563,  0.7378, -0.5131,  1.3352, -0.2633,  0.3279,  0.7129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3202,  0.1690,  0.0660, -0.0522,  0.0547, -0.5814, -0.2982, -0.7990,
-         0.0962, -0.2714,  0.1392, -0.3934, -0.3107,  0.4573, -1.3844, -0.6048,
-        -0.5801, -0.3572, -0.9819, -0.3831, -0.2271,  0.0449, -0.0917, -0.4220,
-         0.1674,  0.2737,  0.1519, -0.2600,  0.3166,  1.5222, -0.5173, -0.1636,
-         0.0000,  0.0455, -0.5936, -0.5003,  1.1287,  0.0000,  0.8045,  0.7470,
-        -0.4796,  1.0071,  0.1333,  0.6256,  0.2762, -0.2869,  0.5576,  0.4304,
-         0.9542,  0.1275,  0.2568,  0.2359, -0.5367,  0.1748,  0.7624,  0.0000,
-         1.1624,  1.0563,  0.7378, -0.5131,  1.3352, -0.2633,  0.3279,  0.7129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3870e-01,  1.9632e-01,  1.2374e-01, -2.8232e-02,  1.9848e-02,
-        -5.7453e-01, -2.6936e-01, -7.8782e-01,  1.5332e-01, -2.4755e-01,
-         1.0010e-01, -3.4842e-01, -3.3096e-01,  4.5476e-01, -1.3890e+00,
-        -5.9188e-01, -5.6845e-01, -4.2543e-01, -9.8422e-01, -3.6692e-01,
-        -1.7175e-01,  6.1644e-02, -9.7776e-02, -4.2134e-01,  2.1071e-01,
-         2.5768e-01,  2.6931e-01, -2.3119e-01,  3.5724e-01,  1.5225e+00,
-        -4.9909e-01, -1.5195e-01, -4.6958e-03,  1.4643e-03, -5.7885e-01,
-        -4.9063e-01,  1.1356e+00,  0.0000e+00,  8.6379e-01,  7.3703e-01,
-        -4.7238e-01,  1.0119e+00,  1.5884e-01,  6.3738e-01,  2.8576e-01,
-        -2.4326e-01,  5.5101e-01,  4.3891e-01,  9.4219e-01,  2.1345e-01,
-         2.6593e-01,  2.6503e-01, -5.3963e-01,  2.4435e-01,  7.5653e-01,
-         2.5352e-04,  1.1580e+00,  1.0762e+00,  7.7085e-01, -5.1182e-01,
-         1.3385e+00, -1.9836e-01,  3.3044e-01,  7.1001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.3870e-01,  1.9632e-01,  1.2374e-01, -2.8232e-02,  1.9848e-02,
-        -5.7453e-01, -2.6936e-01, -7.8782e-01,  1.5332e-01, -2.4755e-01,
-         1.0010e-01, -3.4842e-01, -3.3096e-01,  4.5476e-01, -1.3890e+00,
-        -5.9188e-01, -5.6845e-01, -4.2543e-01, -9.8422e-01, -3.6692e-01,
-        -1.7175e-01,  6.1644e-02, -9.7776e-02, -4.2134e-01,  2.1071e-01,
-         2.5768e-01,  2.6931e-01, -2.3119e-01,  3.5724e-01,  1.5225e+00,
-        -4.9909e-01, -1.5195e-01,  0.0000e+00,  1.4643e-03, -5.7885e-01,
-        -4.9063e-01,  1.1356e+00,  0.0000e+00,  8.6379e-01,  7.3703e-01,
-        -4.7238e-01,  1.0119e+00,  1.5884e-01,  6.3738e-01,  2.8576e-01,
-        -2.4326e-01,  5.5101e-01,  4.3891e-01,  9.4219e-01,  2.1345e-01,
-         2.6593e-01,  2.6503e-01, -5.3963e-01,  2.4435e-01,  7.5653e-01,
-         0.0000e+00,  1.1580e+00,  1.0762e+00,  7.7085e-01, -5.1182e-01,
-         1.3385e+00, -1.9836e-01,  3.3044e-01,  7.1001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.3870e-01,  1.9632e-01,  1.2374e-01, -2.8232e-02,  1.9848e-02,
-        -5.7453e-01, -2.6936e-01, -7.8782e-01,  1.5332e-01, -2.4755e-01,
-         1.0010e-01, -3.4842e-01, -3.3096e-01,  4.5476e-01, -1.3890e+00,
-        -5.9188e-01, -5.6845e-01, -4.2543e-01, -9.8422e-01, -3.6692e-01,
-        -1.7175e-01,  6.1644e-02, -9.7776e-02, -4.2134e-01,  2.1071e-01,
-         2.5768e-01,  2.6931e-01, -2.3119e-01,  3.5724e-01,  1.5225e+00,
-        -4.9909e-01, -1.5195e-01,  0.0000e+00,  1.4643e-03, -5.7885e-01,
-        -4.9063e-01,  1.1356e+00,  0.0000e+00,  8.6379e-01,  7.3703e-01,
-        -4.7238e-01,  1.0119e+00,  1.5884e-01,  6.3738e-01,  2.8576e-01,
-        -2.4326e-01,  5.5101e-01,  4.3891e-01,  9.4219e-01,  2.1345e-01,
-         2.6593e-01,  2.6503e-01, -5.3963e-01,  2.4435e-01,  7.5653e-01,
-         0.0000e+00,  1.1580e+00,  1.0762e+00,  7.7085e-01, -5.1182e-01,
-         1.3385e+00, -1.9836e-01,  3.3044e-01,  7.1001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7427e-01,  2.3862e-01,  1.1569e-01, -1.1472e-04, -4.9408e-02,
-        -5.6334e-01, -2.7061e-01, -7.7616e-01,  1.8241e-01, -2.4802e-01,
-        -9.7585e-02, -2.3705e-01, -2.7083e-01,  4.4199e-01, -1.3913e+00,
-        -5.8010e-01, -5.5689e-01, -5.0147e-01, -9.8966e-01, -3.4045e-01,
-        -1.3759e-01,  6.7824e-02, -1.2590e-01, -3.5126e-01,  2.3472e-01,
-         2.4121e-01,  2.6258e-01, -2.5120e-01,  3.2060e-01,  1.5240e+00,
-        -4.6840e-01, -1.3575e-01, -4.0903e-03, -6.3776e-02, -5.6554e-01,
-        -4.8846e-01,  1.1431e+00,  0.0000e+00,  9.1539e-01,  7.3696e-01,
-        -5.0517e-01,  1.0125e+00,  1.0421e-01,  6.3970e-01,  2.9006e-01,
-        -2.6576e-01,  5.4095e-01,  4.0371e-01,  9.2974e-01,  2.4050e-01,
-         2.3361e-01,  2.7879e-01, -5.4637e-01,  2.5070e-01,  7.5700e-01,
-         2.2083e-04,  1.1535e+00,  1.1003e+00,  7.8545e-01, -5.0040e-01,
-         1.3426e+00, -1.6117e-01,  3.2458e-01,  6.5906e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.7427e-01,  2.3862e-01,  1.1569e-01, -1.1472e-04, -4.9408e-02,
-        -5.6334e-01, -2.7061e-01, -7.7616e-01,  1.8241e-01, -2.4802e-01,
-        -9.7585e-02, -2.3705e-01, -2.7083e-01,  4.4199e-01, -1.3913e+00,
-        -5.8010e-01, -5.5689e-01, -5.0147e-01, -9.8966e-01, -3.4045e-01,
-        -1.3759e-01,  6.7824e-02, -1.2590e-01, -3.5126e-01,  2.3472e-01,
-         2.4121e-01,  2.6258e-01, -2.5120e-01,  3.2060e-01,  1.5240e+00,
-        -4.6840e-01, -1.3575e-01,  0.0000e+00, -6.3776e-02, -5.6554e-01,
-        -4.8846e-01,  1.1431e+00,  0.0000e+00,  9.1539e-01,  7.3696e-01,
-        -5.0517e-01,  1.0125e+00,  1.0421e-01,  6.3970e-01,  2.9006e-01,
-        -2.6576e-01,  5.4095e-01,  4.0371e-01,  9.2974e-01,  2.4050e-01,
-         2.3361e-01,  2.7879e-01, -5.4637e-01,  2.5070e-01,  7.5700e-01,
-         0.0000e+00,  1.1535e+00,  1.1003e+00,  7.8545e-01, -5.0040e-01,
-         1.3426e+00, -1.6117e-01,  3.2458e-01,  6.5906e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.7427e-01,  2.3862e-01,  1.1569e-01, -1.1472e-04, -4.9408e-02,
-        -5.6334e-01, -2.7061e-01, -7.7616e-01,  1.8241e-01, -2.4802e-01,
-        -9.7585e-02, -2.3705e-01, -2.7083e-01,  4.4199e-01, -1.3913e+00,
-        -5.8010e-01, -5.5689e-01, -5.0147e-01, -9.8966e-01, -3.4045e-01,
-        -1.3759e-01,  6.7824e-02, -1.2590e-01, -3.5126e-01,  2.3472e-01,
-         2.4121e-01,  2.6258e-01, -2.5120e-01,  3.2060e-01,  1.5240e+00,
-        -4.6840e-01, -1.3575e-01,  0.0000e+00, -6.3776e-02, -5.6554e-01,
-        -4.8846e-01,  1.1431e+00,  0.0000e+00,  9.1539e-01,  7.3696e-01,
-        -5.0517e-01,  1.0125e+00,  1.0421e-01,  6.3970e-01,  2.9006e-01,
-        -2.6576e-01,  5.4095e-01,  4.0371e-01,  9.2974e-01,  2.4050e-01,
-         2.3361e-01,  2.7879e-01, -5.4637e-01,  2.5070e-01,  7.5700e-01,
-         0.0000e+00,  1.1535e+00,  1.1003e+00,  7.8545e-01, -5.0040e-01,
-         1.3426e+00, -1.6117e-01,  3.2458e-01,  6.5906e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0275e-01,  3.0295e-01,  8.7616e-02, -6.8473e-03, -2.3629e-03,
-        -5.5596e-01, -3.0572e-01, -7.6919e-01,  2.0723e-01, -2.5831e-01,
-        -2.8551e-01, -1.6403e-01, -2.0650e-01,  4.3972e-01, -1.3927e+00,
-        -5.8607e-01, -5.4892e-01, -5.6838e-01, -9.9597e-01, -3.0391e-01,
-        -1.1824e-01,  4.8019e-02, -1.5658e-01, -2.7805e-01,  2.4039e-01,
-         2.0225e-01,  2.0949e-01, -2.7274e-01,  2.8724e-01,  1.5262e+00,
-        -4.3072e-01, -9.9075e-02, -3.5592e-03, -1.1320e-01, -5.5375e-01,
-        -4.6086e-01,  1.1487e+00,  0.0000e+00,  9.4047e-01,  7.3973e-01,
-        -5.2195e-01,  1.0119e+00,  2.7601e-02,  6.4003e-01,  2.8693e-01,
-        -2.7280e-01,  5.3535e-01,  3.3758e-01,  9.2064e-01,  2.3349e-01,
-         1.7206e-01,  3.0498e-01, -5.5013e-01,  2.1792e-01,  7.5538e-01,
-         1.9216e-04,  1.1493e+00,  1.1191e+00,  7.8044e-01, -4.9833e-01,
-         1.3452e+00, -1.6168e-01,  3.0656e-01,  6.0137e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4027,  0.3029,  0.0876, -0.0068, -0.0024, -0.5560, -0.3057, -0.7692,
-         0.2072, -0.2583, -0.2855, -0.1640, -0.2065,  0.4397, -1.3927, -0.5861,
-        -0.5489, -0.5684, -0.9960, -0.3039, -0.1182,  0.0480, -0.1566, -0.2780,
-         0.2404,  0.2023,  0.2095, -0.2727,  0.2872,  1.5262, -0.4307, -0.0991,
-         0.0000, -0.1132, -0.5538, -0.4609,  1.1487,  0.0000,  0.9405,  0.7397,
-        -0.5219,  1.0119,  0.0276,  0.6400,  0.2869, -0.2728,  0.5353,  0.3376,
-         0.9206,  0.2335,  0.1721,  0.3050, -0.5501,  0.2179,  0.7554,  0.0000,
-         1.1493,  1.1191,  0.7804, -0.4983,  1.3452, -0.1617,  0.3066,  0.6014],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4027,  0.3029,  0.0876, -0.0068, -0.0024, -0.5560, -0.3057, -0.7692,
-         0.2072, -0.2583, -0.2855, -0.1640, -0.2065,  0.4397, -1.3927, -0.5861,
-        -0.5489, -0.5684, -0.9960, -0.3039, -0.1182,  0.0480, -0.1566, -0.2780,
-         0.2404,  0.2023,  0.2095, -0.2727,  0.2872,  1.5262, -0.4307, -0.0991,
-         0.0000, -0.1132, -0.5538, -0.4609,  1.1487,  0.0000,  0.9405,  0.7397,
-        -0.5219,  1.0119,  0.0276,  0.6400,  0.2869, -0.2728,  0.5353,  0.3376,
-         0.9206,  0.2335,  0.1721,  0.3050, -0.5501,  0.2179,  0.7554,  0.0000,
-         1.1493,  1.1191,  0.7804, -0.4983,  1.3452, -0.1617,  0.3066,  0.6014],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3692e-01,  3.4762e-01,  7.3149e-02, -2.2991e-02, -2.9774e-02,
-        -5.4568e-01, -3.6261e-01, -7.6319e-01,  2.4172e-01, -2.5998e-01,
-        -4.0015e-01, -1.8170e-01, -1.8017e-01,  4.4415e-01, -1.3868e+00,
-        -6.0537e-01, -5.6873e-01, -6.0433e-01, -1.0051e+00, -2.8043e-01,
-        -7.5100e-02,  5.6049e-02, -1.8048e-01, -2.6297e-01,  2.4726e-01,
-         1.7628e-01,  1.7304e-01, -2.7591e-01,  2.5943e-01,  1.5298e+00,
-        -4.2249e-01, -6.5333e-02, -3.0940e-03, -1.1292e-01, -5.4387e-01,
-        -4.3574e-01,  1.1507e+00,  0.0000e+00,  9.4352e-01,  7.3697e-01,
-        -4.9464e-01,  1.0097e+00,  7.3701e-04,  6.3928e-01,  2.5555e-01,
-        -2.0775e-01,  5.3094e-01,  2.6114e-01,  9.1628e-01,  2.4760e-01,
-         1.2261e-01,  3.1895e-01, -5.3048e-01,  1.8111e-01,  7.4288e-01,
-         1.6704e-04,  1.1459e+00,  1.1251e+00,  7.8041e-01, -4.8126e-01,
-         1.3467e+00, -1.4445e-01,  2.8659e-01,  5.4464e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.3692e-01,  3.4762e-01,  7.3149e-02, -2.2991e-02, -2.9774e-02,
-        -5.4568e-01, -3.6261e-01, -7.6319e-01,  2.4172e-01, -2.5998e-01,
-        -4.0015e-01, -1.8170e-01, -1.8017e-01,  4.4415e-01, -1.3868e+00,
-        -6.0537e-01, -5.6873e-01, -6.0433e-01, -1.0051e+00, -2.8043e-01,
-        -7.5100e-02,  5.6049e-02, -1.8048e-01, -2.6297e-01,  2.4726e-01,
-         1.7628e-01,  1.7304e-01, -2.7591e-01,  2.5943e-01,  1.5298e+00,
-        -4.2249e-01, -6.5333e-02,  0.0000e+00, -1.1292e-01, -5.4387e-01,
-        -4.3574e-01,  1.1507e+00,  0.0000e+00,  9.4352e-01,  7.3697e-01,
-        -4.9464e-01,  1.0097e+00,  7.3701e-04,  6.3928e-01,  2.5555e-01,
-        -2.0775e-01,  5.3094e-01,  2.6114e-01,  9.1628e-01,  2.4760e-01,
-         1.2261e-01,  3.1895e-01, -5.3048e-01,  1.8111e-01,  7.4288e-01,
-         0.0000e+00,  1.1459e+00,  1.1251e+00,  7.8041e-01, -4.8126e-01,
-         1.3467e+00, -1.4445e-01,  2.8659e-01,  5.4464e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.3692e-01,  3.4762e-01,  7.3149e-02, -2.2991e-02, -2.9774e-02,
-        -5.4568e-01, -3.6261e-01, -7.6319e-01,  2.4172e-01, -2.5998e-01,
-        -4.0015e-01, -1.8170e-01, -1.8017e-01,  4.4415e-01, -1.3868e+00,
-        -6.0537e-01, -5.6873e-01, -6.0433e-01, -1.0051e+00, -2.8043e-01,
-        -7.5100e-02,  5.6049e-02, -1.8048e-01, -2.6297e-01,  2.4726e-01,
-         1.7628e-01,  1.7304e-01, -2.7591e-01,  2.5943e-01,  1.5298e+00,
-        -4.2249e-01, -6.5333e-02,  0.0000e+00, -1.1292e-01, -5.4387e-01,
-        -4.3574e-01,  1.1507e+00,  0.0000e+00,  9.4352e-01,  7.3697e-01,
-        -4.9464e-01,  1.0097e+00,  7.3701e-04,  6.3928e-01,  2.5555e-01,
-        -2.0775e-01,  5.3094e-01,  2.6114e-01,  9.1628e-01,  2.4760e-01,
-         1.2261e-01,  3.1895e-01, -5.3048e-01,  1.8111e-01,  7.4288e-01,
-         0.0000e+00,  1.1459e+00,  1.1251e+00,  7.8041e-01, -4.8126e-01,
-         1.3467e+00, -1.4445e-01,  2.8659e-01,  5.4464e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5853e-01,  3.8525e-01,  3.5848e-02, -2.6446e-02, -9.0685e-02,
-        -5.3688e-01, -4.1273e-01, -7.5595e-01,  2.5259e-01, -2.6868e-01,
-        -5.0205e-01, -1.7252e-01, -1.9034e-01,  4.5152e-01, -1.3821e+00,
-        -6.2082e-01, -5.9211e-01, -6.3704e-01, -1.0111e+00, -2.4640e-01,
-        -5.6272e-02,  6.1906e-02, -2.0459e-01, -3.5353e-01,  2.4431e-01,
-         1.6795e-01,  1.5231e-01, -2.7926e-01,  2.7038e-01,  1.5328e+00,
-        -4.3813e-01, -5.7218e-02, -2.6870e-03, -1.3538e-01, -5.2806e-01,
-        -4.0913e-01,  1.1564e+00,  0.0000e+00,  9.3936e-01,  7.3211e-01,
-        -4.6745e-01,  1.0101e+00, -3.9769e-02,  6.3343e-01,  2.5223e-01,
-        -1.0816e-01,  5.3117e-01,  1.8499e-01,  9.1631e-01,  2.5180e-01,
-         6.4771e-02,  3.1120e-01, -5.0622e-01,  1.5747e-01,  7.2205e-01,
-         1.4506e-04,  1.1412e+00,  1.1265e+00,  7.8545e-01, -4.7025e-01,
-         1.3465e+00, -1.0869e-01,  2.5313e-01,  5.0347e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4585,  0.3853,  0.0358, -0.0264, -0.0907, -0.5369, -0.4127, -0.7559,
-         0.2526, -0.2687, -0.5021, -0.1725, -0.1903,  0.4515, -1.3821, -0.6208,
-        -0.5921, -0.6370, -1.0111, -0.2464, -0.0563,  0.0619, -0.2046, -0.3535,
-         0.2443,  0.1679,  0.1523, -0.2793,  0.2704,  1.5328, -0.4381, -0.0572,
-         0.0000, -0.1354, -0.5281, -0.4091,  1.1564,  0.0000,  0.9394,  0.7321,
-        -0.4675,  1.0101, -0.0398,  0.6334,  0.2522, -0.1082,  0.5312,  0.1850,
-         0.9163,  0.2518,  0.0648,  0.3112, -0.5062,  0.1575,  0.7220,  0.0000,
-         1.1412,  1.1265,  0.7854, -0.4703,  1.3465, -0.1087,  0.2531,  0.5035],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4585,  0.3853,  0.0358, -0.0264, -0.0907, -0.5369, -0.4127, -0.7559,
-         0.2526, -0.2687, -0.5021, -0.1725, -0.1903,  0.4515, -1.3821, -0.6208,
-        -0.5921, -0.6370, -1.0111, -0.2464, -0.0563,  0.0619, -0.2046, -0.3535,
-         0.2443,  0.1679,  0.1523, -0.2793,  0.2704,  1.5328, -0.4381, -0.0572,
-         0.0000, -0.1354, -0.5281, -0.4091,  1.1564,  0.0000,  0.9394,  0.7321,
-        -0.4675,  1.0101, -0.0398,  0.6334,  0.2522, -0.1082,  0.5312,  0.1850,
-         0.9163,  0.2518,  0.0648,  0.3112, -0.5062,  0.1575,  0.7220,  0.0000,
-         1.1412,  1.1265,  0.7854, -0.4703,  1.3465, -0.1087,  0.2531,  0.5035],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9695e-01,  3.8668e-01,  3.3073e-02, -4.3990e-02, -1.1958e-01,
-        -5.2729e-01, -4.6085e-01, -7.5641e-01,  2.9098e-01, -2.9789e-01,
-        -6.0404e-01, -1.9196e-01, -1.8827e-01,  4.2683e-01, -1.3776e+00,
-        -6.3927e-01, -6.0448e-01, -6.7030e-01, -1.0144e+00, -2.3420e-01,
-        -5.8769e-02, -1.3312e-02, -2.3290e-01, -4.5635e-01,  1.8024e-01,
-         1.3665e-01,  1.3583e-01, -2.8286e-01,  2.7064e-01,  1.5337e+00,
-        -4.5571e-01, -5.2556e-02, -2.3312e-03, -1.8505e-01, -5.2214e-01,
-        -3.9583e-01,  1.1628e+00,  0.0000e+00,  9.2557e-01,  7.2819e-01,
-        -4.4290e-01,  1.0157e+00, -5.9706e-02,  6.3392e-01,  2.4110e-01,
-         2.9521e-02,  5.3333e-01,  1.3844e-01,  9.0546e-01,  2.6556e-01,
-         1.5511e-02,  2.7189e-01, -4.5555e-01,  7.7601e-02,  6.9984e-01,
-         1.2586e-04,  1.1370e+00,  1.1301e+00,  7.9383e-01, -4.7887e-01,
-         1.3456e+00, -7.7985e-02,  2.1311e-01,  4.9110e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4970,  0.3867,  0.0331, -0.0440, -0.1196, -0.5273, -0.4609, -0.7564,
-         0.2910, -0.2979, -0.6040, -0.1920, -0.1883,  0.4268, -1.3776, -0.6393,
-        -0.6045, -0.6703, -1.0144, -0.2342, -0.0588, -0.0133, -0.2329, -0.4564,
-         0.1802,  0.1367,  0.1358, -0.2829,  0.2706,  1.5337, -0.4557, -0.0526,
-         0.0000, -0.1851, -0.5221, -0.3958,  1.1628,  0.0000,  0.9256,  0.7282,
-        -0.4429,  1.0157, -0.0597,  0.6339,  0.2411,  0.0295,  0.5333,  0.1384,
-         0.9055,  0.2656,  0.0155,  0.2719, -0.4555,  0.0776,  0.6998,  0.0000,
-         1.1370,  1.1301,  0.7938, -0.4789,  1.3456, -0.0780,  0.2131,  0.4911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4970,  0.3867,  0.0331, -0.0440, -0.1196, -0.5273, -0.4609, -0.7564,
-         0.2910, -0.2979, -0.6040, -0.1920, -0.1883,  0.4268, -1.3776, -0.6393,
-        -0.6045, -0.6703, -1.0144, -0.2342, -0.0588, -0.0133, -0.2329, -0.4564,
-         0.1802,  0.1367,  0.1358, -0.2829,  0.2706,  1.5337, -0.4557, -0.0526,
-         0.0000, -0.1851, -0.5221, -0.3958,  1.1628,  0.0000,  0.9256,  0.7282,
-        -0.4429,  1.0157, -0.0597,  0.6339,  0.2411,  0.0295,  0.5333,  0.1384,
-         0.9055,  0.2656,  0.0155,  0.2719, -0.4555,  0.0776,  0.6998,  0.0000,
-         1.1370,  1.1301,  0.7938, -0.4789,  1.3456, -0.0780,  0.2131,  0.4911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.3471e-01,  3.9068e-01,  5.0698e-02, -7.6962e-02, -1.3458e-01,
-        -5.2232e-01, -4.8205e-01, -7.5373e-01,  3.5691e-01, -3.1490e-01,
-        -6.9891e-01, -1.5452e-01, -1.6866e-01,  3.8841e-01, -1.3773e+00,
-        -6.3968e-01, -6.0274e-01, -6.9783e-01, -1.0155e+00, -1.8802e-01,
-        -4.3873e-02, -1.1034e-01, -2.6079e-01, -5.5225e-01,  9.6189e-02,
-         9.8750e-02,  1.6891e-01, -3.0405e-01,  2.7231e-01,  1.5338e+00,
-        -4.6014e-01, -3.3035e-02, -2.0206e-03, -2.6875e-01, -4.9604e-01,
-        -3.6592e-01,  1.1648e+00,  0.0000e+00,  9.0979e-01,  7.2409e-01,
-        -4.3818e-01,  1.0187e+00, -6.6932e-02,  6.2557e-01,  2.1692e-01,
-         1.4740e-01,  5.4063e-01,  9.4134e-02,  8.9168e-01,  2.8511e-01,
-        -3.7572e-02,  2.1620e-01, -4.0274e-01,  9.4125e-04,  6.9243e-01,
-         1.0909e-04,  1.1328e+00,  1.1293e+00,  8.0724e-01, -4.6717e-01,
-         1.3443e+00, -4.4625e-02,  1.3921e-01,  5.0773e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.3471e-01,  3.9068e-01,  5.0698e-02, -7.6962e-02, -1.3458e-01,
-        -5.2232e-01, -4.8205e-01, -7.5373e-01,  3.5691e-01, -3.1490e-01,
-        -6.9891e-01, -1.5452e-01, -1.6866e-01,  3.8841e-01, -1.3773e+00,
-        -6.3968e-01, -6.0274e-01, -6.9783e-01, -1.0155e+00, -1.8802e-01,
-        -4.3873e-02, -1.1034e-01, -2.6079e-01, -5.5225e-01,  9.6189e-02,
-         9.8750e-02,  1.6891e-01, -3.0405e-01,  2.7231e-01,  1.5338e+00,
-        -4.6014e-01, -3.3035e-02,  0.0000e+00, -2.6875e-01, -4.9604e-01,
-        -3.6592e-01,  1.1648e+00,  0.0000e+00,  9.0979e-01,  7.2409e-01,
-        -4.3818e-01,  1.0187e+00, -6.6932e-02,  6.2557e-01,  2.1692e-01,
-         1.4740e-01,  5.4063e-01,  9.4134e-02,  8.9168e-01,  2.8511e-01,
-        -3.7572e-02,  2.1620e-01, -4.0274e-01,  9.4125e-04,  6.9243e-01,
-         0.0000e+00,  1.1328e+00,  1.1293e+00,  8.0724e-01, -4.6717e-01,
-         1.3443e+00, -4.4625e-02,  1.3921e-01,  5.0773e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.3471e-01,  3.9068e-01,  5.0698e-02, -7.6962e-02, -1.3458e-01,
-        -5.2232e-01, -4.8205e-01, -7.5373e-01,  3.5691e-01, -3.1490e-01,
-        -6.9891e-01, -1.5452e-01, -1.6866e-01,  3.8841e-01, -1.3773e+00,
-        -6.3968e-01, -6.0274e-01, -6.9783e-01, -1.0155e+00, -1.8802e-01,
-        -4.3873e-02, -1.1034e-01, -2.6079e-01, -5.5225e-01,  9.6189e-02,
-         9.8750e-02,  1.6891e-01, -3.0405e-01,  2.7231e-01,  1.5338e+00,
-        -4.6014e-01, -3.3035e-02,  0.0000e+00, -2.6875e-01, -4.9604e-01,
-        -3.6592e-01,  1.1648e+00,  0.0000e+00,  9.0979e-01,  7.2409e-01,
-        -4.3818e-01,  1.0187e+00, -6.6932e-02,  6.2557e-01,  2.1692e-01,
-         1.4740e-01,  5.4063e-01,  9.4134e-02,  8.9168e-01,  2.8511e-01,
-        -3.7572e-02,  2.1620e-01, -4.0274e-01,  9.4125e-04,  6.9243e-01,
-         0.0000e+00,  1.1328e+00,  1.1293e+00,  8.0724e-01, -4.6717e-01,
-         1.3443e+00, -4.4625e-02,  1.3921e-01,  5.0773e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.6860e-01,  4.2000e-01,  3.7833e-03, -9.2718e-02, -1.2073e-01,
-        -5.2624e-01, -5.1463e-01, -7.4893e-01,  4.1616e-01, -3.2364e-01,
-        -7.6966e-01, -2.1930e-01, -2.1248e-01,  3.6566e-01, -1.3799e+00,
-        -6.4627e-01, -6.1324e-01, -7.1704e-01, -1.0153e+00, -1.9522e-01,
-        -2.1602e-02, -8.9539e-02, -2.6052e-01, -6.1614e-01,  7.5086e-02,
-         1.4300e-02,  1.4274e-01, -2.8074e-01,  2.4918e-01,  1.5312e+00,
-        -4.7389e-01, -4.7057e-02, -1.7498e-03, -2.4997e-01, -4.8610e-01,
-        -3.6019e-01,  1.1694e+00,  0.0000e+00,  9.1383e-01,  7.4334e-01,
-        -3.8021e-01,  1.0194e+00, -1.1672e-01,  6.2444e-01,  1.6246e-01,
-         1.7479e-01,  5.3856e-01,  2.3581e-02,  8.7764e-01,  2.8786e-01,
-        -6.7085e-02,  1.6407e-01, -3.3969e-01, -3.0659e-02,  7.1903e-01,
-         9.4467e-05,  1.1293e+00,  1.1426e+00,  8.0919e-01, -4.4248e-01,
-         1.3436e+00, -9.8711e-02,  5.9063e-02,  5.0502e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5686,  0.4200,  0.0038, -0.0927, -0.1207, -0.5262, -0.5146, -0.7489,
-         0.4162, -0.3236, -0.7697, -0.2193, -0.2125,  0.3657, -1.3799, -0.6463,
-        -0.6132, -0.7170, -1.0153, -0.1952, -0.0216, -0.0895, -0.2605, -0.6161,
-         0.0751,  0.0143,  0.1427, -0.2807,  0.2492,  1.5312, -0.4739, -0.0471,
-         0.0000, -0.2500, -0.4861, -0.3602,  1.1694,  0.0000,  0.9138,  0.7433,
-        -0.3802,  1.0194, -0.1167,  0.6244,  0.1625,  0.1748,  0.5386,  0.0236,
-         0.8776,  0.2879, -0.0671,  0.1641, -0.3397, -0.0307,  0.7190,  0.0000,
-         1.1293,  0.0000,  0.8092, -0.4425,  1.3436, -0.0987,  0.0591,  0.5050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5686,  0.4200,  0.0038, -0.0927, -0.1207, -0.5262, -0.5146, -0.7489,
-         0.4162, -0.3236, -0.7697, -0.2193, -0.2125,  0.3657, -1.3799, -0.6463,
-        -0.6132, -0.7170, -1.0153, -0.1952, -0.0216, -0.0895, -0.2605, -0.6161,
-         0.0751,  0.0143,  0.1427, -0.2807,  0.2492,  1.5312, -0.4739, -0.0471,
-         0.0000, -0.2500, -0.4861, -0.3602,  1.1694,  0.0000,  0.9138,  0.7433,
-        -0.3802,  1.0194, -0.1167,  0.6244,  0.1625,  0.1748,  0.5386,  0.0236,
-         0.8776,  0.2879, -0.0671,  0.1641, -0.3397, -0.0307,  0.7190,  0.0000,
-         1.1293,  0.0000,  0.8092, -0.4425,  1.3436, -0.0987,  0.0591,  0.5050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.8695e-01,  5.1189e-01, -1.0347e-01, -6.3838e-02, -2.4159e-01,
-        -5.1535e-01, -5.4151e-01, -7.4287e-01,  4.5508e-01, -3.0577e-01,
-        -8.0039e-01, -3.0833e-01, -2.9383e-01,  3.1051e-01, -1.3850e+00,
-        -6.4033e-01, -6.4625e-01, -7.1002e-01, -1.0137e+00, -2.1913e-01,
-         4.5566e-02,  1.1794e-01, -1.9909e-01, -6.6377e-01,  1.5014e-01,
-        -5.2223e-02,  8.8079e-02, -2.2341e-01,  2.5711e-01,  1.5286e+00,
-        -5.2051e-01, -3.6771e-02, -1.5138e-03, -6.0769e-02, -4.7096e-01,
-        -3.4831e-01,  1.1756e+00,  0.0000e+00,  9.2941e-01,  7.6742e-01,
-        -2.2049e-01,  1.0226e+00, -1.7491e-01,  6.1707e-01,  1.9571e-01,
-         2.0024e-01,  5.3333e-01, -4.6268e-02,  8.6270e-01,  2.6868e-01,
-        -4.8197e-02,  1.8411e-01, -2.6200e-01,  7.5846e-02,  7.4545e-01,
-         8.1730e-05,  1.1260e+00,  1.1552e-02,  8.0656e-01, -3.8243e-01,
-         1.3437e+00, -1.3187e-01,  2.8380e-02,  4.6475e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5870,  0.5119, -0.1035, -0.0638, -0.2416, -0.5153, -0.5415, -0.7429,
-         0.4551, -0.3058, -0.8004, -0.3083, -0.2938,  0.3105, -1.3850, -0.6403,
-        -0.6463, -0.7100, -1.0137, -0.2191,  0.0456,  0.1179, -0.1991, -0.6638,
-         0.1501, -0.0522,  0.0881, -0.2234,  0.2571,  1.5286, -0.5205, -0.0368,
-         0.0000, -0.0608, -0.4710, -0.3483,  1.1756,  0.0000,  0.9294,  0.7674,
-        -0.2205,  1.0226, -0.1749,  0.6171,  0.1957,  0.2002,  0.5333, -0.0463,
-         0.8627,  0.2687, -0.0482,  0.1841, -0.2620,  0.0758,  0.7455,  0.0000,
-         1.1260,  0.0000,  0.8066, -0.3824,  1.3437, -0.1319,  0.0284,  0.4648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5870,  0.5119, -0.1035, -0.0638, -0.2416, -0.5153, -0.5415, -0.7429,
-         0.4551, -0.3058, -0.8004, -0.3083, -0.2938,  0.3105, -1.3850, -0.6403,
-        -0.6463, -0.7100, -1.0137, -0.2191,  0.0456,  0.1179, -0.1991, -0.6638,
-         0.1501, -0.0522,  0.0881, -0.2234,  0.2571,  1.5286, -0.5205, -0.0368,
-         0.0000, -0.0608, -0.4710, -0.3483,  1.1756,  0.0000,  0.9294,  0.7674,
-        -0.2205,  1.0226, -0.1749,  0.6171,  0.1957,  0.2002,  0.5333, -0.0463,
-         0.8627,  0.2687, -0.0482,  0.1841, -0.2620,  0.0758,  0.7455,  0.0000,
-         1.1260,  0.0000,  0.8066, -0.3824,  1.3437, -0.1319,  0.0284,  0.4648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.2043e-01,  5.7769e-01, -8.7914e-02,  8.1186e-03, -2.9066e-01,
-        -4.7513e-01, -5.7123e-01, -7.3882e-01,  4.9853e-01, -3.1102e-01,
-        -8.1509e-01, -3.3172e-01, -2.2306e-01,  2.3820e-01, -1.3906e+00,
-        -6.3904e-01, -6.8925e-01, -6.8894e-01, -1.0131e+00, -2.2469e-01,
-         6.7976e-02,  2.4498e-01, -2.1487e-01, -6.9443e-01,  1.7633e-01,
-        -8.9693e-02,  7.7979e-02, -1.8881e-01,  2.7238e-01,  1.5232e+00,
-        -5.2430e-01,  2.4418e-02, -1.3086e-03,  3.6408e-02, -4.6946e-01,
-        -3.1815e-01,  1.1821e+00,  0.0000e+00,  9.5164e-01,  7.8662e-01,
-        -1.2208e-01,  1.0336e+00, -1.5476e-01,  6.1636e-01,  1.9148e-01,
-         1.9525e-01,  5.2666e-01, -3.8751e-02,  8.3619e-01,  2.3729e-01,
-        -3.1254e-02,  2.1893e-01, -2.2052e-01,  1.3929e-01,  7.7025e-01,
-         7.0648e-05,  1.1222e+00,  9.9856e-03,  8.0597e-01, -3.9623e-01,
-         1.3461e+00, -1.1712e-01,  1.8521e-02,  3.9878e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6204,  0.5777, -0.0879,  0.0081, -0.2907, -0.4751, -0.5712, -0.7388,
-         0.4985, -0.3110, -0.8151, -0.3317, -0.2231,  0.2382, -1.3906, -0.6390,
-        -0.6892, -0.6889, -1.0131, -0.2247,  0.0680,  0.2450, -0.2149, -0.6944,
-         0.1763, -0.0897,  0.0780, -0.1888,  0.2724,  1.5232, -0.5243,  0.0244,
-         0.0000,  0.0364, -0.4695, -0.3181,  1.1821,  0.0000,  0.9516,  0.7866,
-        -0.1221,  1.0336, -0.1548,  0.6164,  0.1915,  0.1953,  0.5267, -0.0388,
-         0.8362,  0.2373, -0.0313,  0.2189, -0.2205,  0.1393,  0.7703,  0.0000,
-         1.1222,  0.0000,  0.8060, -0.3962,  1.3461, -0.1171,  0.0185,  0.3988],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6204,  0.5777, -0.0879,  0.0081, -0.2907, -0.4751, -0.5712, -0.7388,
-         0.4985, -0.3110, -0.8151, -0.3317, -0.2231,  0.2382, -1.3906, -0.6390,
-        -0.6892, -0.6889, -1.0131, -0.2247,  0.0680,  0.2450, -0.2149, -0.6944,
-         0.1763, -0.0897,  0.0780, -0.1888,  0.2724,  1.5232, -0.5243,  0.0244,
-         0.0000,  0.0364, -0.4695, -0.3181,  1.1821,  0.0000,  0.9516,  0.7866,
-        -0.1221,  1.0336, -0.1548,  0.6164,  0.1915,  0.1953,  0.5267, -0.0388,
-         0.8362,  0.2373, -0.0313,  0.2189, -0.2205,  0.1393,  0.7703,  0.0000,
-         1.1222,  0.0000,  0.8060, -0.3962,  1.3461, -0.1171,  0.0185,  0.3988],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5761e-01,  5.8153e-01,  1.7425e-02,  1.3276e-01, -2.4271e-01,
-        -4.0866e-01, -5.9760e-01, -7.3592e-01,  5.0567e-01, -3.3411e-01,
-        -8.3008e-01, -2.9983e-01,  8.6295e-03,  2.9145e-01, -1.3984e+00,
-        -6.4478e-01, -7.2627e-01, -6.6784e-01, -1.0158e+00, -2.0468e-01,
-         1.1341e-02,  1.8765e-01, -2.9849e-01, -7.0205e-01,  7.2261e-02,
-        -7.8381e-02,  1.3910e-01, -1.8352e-01,  1.8774e-01,  1.5193e+00,
-        -4.8954e-01,  1.1920e-01, -1.1301e-03, -1.5725e-02, -4.6944e-01,
-        -3.0583e-01,  1.1898e+00,  0.0000e+00,  1.0006e+00,  8.0854e-01,
-        -8.0463e-02,  1.0350e+00, -9.9988e-02,  6.5440e-01,  1.2036e-01,
-         1.4133e-01,  5.1159e-01,  7.5100e-02,  8.1171e-01,  1.0939e-01,
-        -4.5023e-02,  1.1455e-01, -2.4772e-01,  2.3981e-02,  7.9062e-01,
-         6.1015e-05,  1.1189e+00,  8.6240e-03,  7.9341e-01, -4.5599e-01,
-         1.3490e+00, -6.9826e-02,  3.2040e-02,  3.5432e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6576,  0.5815,  0.0174,  0.1328, -0.2427, -0.4087, -0.5976, -0.7359,
-         0.5057, -0.3341, -0.8301, -0.2998,  0.0086,  0.2915, -1.3984, -0.6448,
-        -0.7263, -0.6678, -1.0158, -0.2047,  0.0113,  0.1877, -0.2985, -0.7020,
-         0.0723, -0.0784,  0.1391, -0.1835,  0.1877,  1.5193, -0.4895,  0.1192,
-         0.0000, -0.0157, -0.4694, -0.3058,  1.1898,  0.0000,  1.0006,  0.8085,
-        -0.0805,  1.0350, -0.1000,  0.6544,  0.1204,  0.1413,  0.5116,  0.0751,
-         0.8117,  0.1094, -0.0450,  0.1146, -0.2477,  0.0240,  0.7906,  0.0000,
-         1.1189,  0.0000,  0.7934, -0.4560,  1.3490, -0.0698,  0.0320,  0.3543],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6576,  0.5815,  0.0174,  0.1328, -0.2427, -0.4087, -0.5976, -0.7359,
-         0.5057, -0.3341, -0.8301, -0.2998,  0.0086,  0.2915, -1.3984, -0.6448,
-        -0.7263, -0.6678, -1.0158, -0.2047,  0.0113,  0.1877, -0.2985, -0.7020,
-         0.0723, -0.0784,  0.1391, -0.1835,  0.1877,  1.5193, -0.4895,  0.1192,
-         0.0000, -0.0157, -0.4694, -0.3058,  1.1898,  0.0000,  1.0006,  0.8085,
-        -0.0805,  1.0350, -0.1000,  0.6544,  0.1204,  0.1413,  0.5116,  0.0751,
-         0.8117,  0.1094, -0.0450,  0.1146, -0.2477,  0.0240,  0.7906,  0.0000,
-         1.1189,  0.0000,  0.7934, -0.4560,  1.3490, -0.0698,  0.0320,  0.3543],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5791e-01,  5.9043e-01, -2.2695e-02,  2.2607e-01, -6.7192e-02,
-        -3.6700e-01, -6.1101e-01, -7.3357e-01,  4.8594e-01, -3.4695e-01,
-        -8.5088e-01, -3.1682e-01,  9.3813e-02,  3.0281e-01, -1.4076e+00,
-        -6.4638e-01, -7.4318e-01, -6.4954e-01, -1.0198e+00, -2.2862e-01,
-        -6.6806e-02,  1.6071e-01, -3.5988e-01, -6.6233e-01, -3.4655e-02,
-        -1.0185e-01,  1.6433e-01, -1.5656e-01,  3.7063e-02,  1.5152e+00,
-        -4.3781e-01,  1.6938e-01, -9.7521e-04, -2.9853e-02, -4.8813e-01,
-        -3.4032e-01,  1.1980e+00,  0.0000e+00,  1.0326e+00,  8.1171e-01,
-        -2.6099e-03,  1.0324e+00, -8.1139e-02,  6.8462e-01,  1.1865e-02,
-         1.7951e-01,  4.9846e-01,  1.3710e-01,  7.9437e-01, -1.0549e-02,
-        -6.6282e-02, -8.2287e-02, -2.8151e-01, -4.7410e-02,  7.8354e-01,
-         5.2650e-05,  1.1149e+00,  7.4418e-03,  7.7509e-01, -4.7077e-01,
-         1.3503e+00, -9.8224e-03,  5.0041e-02,  2.5901e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6579,  0.5904, -0.0227,  0.2261, -0.0672, -0.3670, -0.6110, -0.7336,
-         0.4859, -0.3470, -0.8509, -0.3168,  0.0938,  0.3028, -1.4076, -0.6464,
-        -0.7432, -0.6495, -1.0198, -0.2286, -0.0668,  0.1607, -0.3599, -0.6623,
-        -0.0347, -0.1018,  0.1643, -0.1566,  0.0371,  1.5152, -0.4378,  0.1694,
-         0.0000, -0.0299, -0.4881, -0.3403,  1.1980,  0.0000,  1.0326,  0.8117,
-        -0.0026,  1.0324, -0.0811,  0.6846,  0.0119,  0.1795,  0.4985,  0.1371,
-         0.7944, -0.0105, -0.0663, -0.0823, -0.2815, -0.0474,  0.7835,  0.0000,
-         1.1149,  0.0000,  0.7751, -0.4708,  1.3503, -0.0098,  0.0500,  0.2590],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6579,  0.5904, -0.0227,  0.2261, -0.0672, -0.3670, -0.6110, -0.7336,
-         0.4859, -0.3470, -0.8509, -0.3168,  0.0938,  0.3028, -1.4076, -0.6464,
-        -0.7432, -0.6495, -1.0198, -0.2286, -0.0668,  0.1607, -0.3599, -0.6623,
-        -0.0347, -0.1018,  0.1643, -0.1566,  0.0371,  1.5152, -0.4378,  0.1694,
-         0.0000, -0.0299, -0.4881, -0.3403,  1.1980,  0.0000,  1.0326,  0.8117,
-        -0.0026,  1.0324, -0.0811,  0.6846,  0.0119,  0.1795,  0.4985,  0.1371,
-         0.7944, -0.0105, -0.0663, -0.0823, -0.2815, -0.0474,  0.7835,  0.0000,
-         1.1149,  0.0000,  0.7751, -0.4708,  1.3503, -0.0098,  0.0500,  0.2590],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5276e-01,  6.1241e-01, -1.1744e-01,  2.8536e-01,  7.7322e-02,
-        -3.4502e-01, -6.1922e-01, -7.2868e-01,  4.6212e-01, -3.6640e-01,
-        -8.7519e-01, -3.3555e-01,  1.2206e-01,  2.7276e-01, -1.4147e+00,
-        -6.3029e-01, -7.4412e-01, -6.4414e-01, -1.0199e+00, -2.4335e-01,
-        -1.3920e-01,  1.5055e-01, -3.9791e-01, -6.2108e-01, -1.2426e-01,
-        -1.3863e-01,  1.8760e-01, -1.3276e-01, -7.3575e-02,  1.5109e+00,
-        -3.9482e-01,  1.9435e-01, -8.4082e-04, -5.3609e-02, -4.8687e-01,
-        -3.5447e-01,  1.2015e+00,  0.0000e+00,  1.0460e+00,  8.0637e-01,
-         6.3426e-02,  1.0320e+00, -1.0275e-01,  7.0502e-01, -2.1281e-02,
-         2.2734e-01,  4.7988e-01,  1.5777e-01,  7.9043e-01, -9.4417e-02,
-        -8.0935e-02, -2.5989e-01, -2.9725e-01, -1.7429e-01,  7.7661e-01,
-         4.5395e-05,  1.1106e+00,  6.4162e-03,  7.5577e-01, -4.9298e-01,
-         1.3510e+00,  3.2247e-02,  5.6908e-02,  1.6271e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6528,  0.6124, -0.1174,  0.2854,  0.0773, -0.3450, -0.6192, -0.7287,
-         0.4621, -0.3664, -0.8752, -0.3356,  0.1221,  0.2728, -1.4147, -0.6303,
-        -0.7441, -0.6441, -1.0199, -0.2434, -0.1392,  0.1506, -0.3979, -0.6211,
-        -0.1243, -0.1386,  0.1876, -0.1328, -0.0736,  1.5109, -0.3948,  0.1944,
-         0.0000, -0.0536, -0.4869, -0.3545,  1.2015,  0.0000,  1.0460,  0.8064,
-         0.0634,  1.0320, -0.1027,  0.7050, -0.0213,  0.2273,  0.4799,  0.1578,
-         0.7904, -0.0944, -0.0809, -0.2599, -0.2972, -0.1743,  0.7766,  0.0000,
-         1.1106,  0.0000,  0.7558, -0.4930,  1.3510,  0.0322,  0.0569,  0.1627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6528,  0.6124, -0.1174,  0.2854,  0.0773, -0.3450, -0.6192, -0.7287,
-         0.4621, -0.3664, -0.8752, -0.3356,  0.1221,  0.2728, -1.4147, -0.6303,
-        -0.7441, -0.6441, -1.0199, -0.2434, -0.1392,  0.1506, -0.3979, -0.6211,
-        -0.1243, -0.1386,  0.1876, -0.1328, -0.0736,  1.5109, -0.3948,  0.1944,
-         0.0000, -0.0536, -0.4869, -0.3545,  1.2015,  0.0000,  1.0460,  0.8064,
-         0.0634,  1.0320, -0.1027,  0.7050, -0.0213,  0.2273,  0.4799,  0.1578,
-         0.7904, -0.0944, -0.0809, -0.2599, -0.2972, -0.1743,  0.7766,  0.0000,
-         1.1106,  0.0000,  0.7558, -0.4930,  1.3510,  0.0322,  0.0569,  0.1627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5549e-01,  6.4026e-01, -1.5770e-01,  3.4894e-01,  1.6809e-01,
-        -3.0708e-01, -6.1959e-01, -7.1941e-01,  4.2907e-01, -3.6012e-01,
-        -8.9371e-01, -3.5282e-01,  1.2659e-01,  2.1759e-01, -1.4199e+00,
-        -6.0804e-01, -7.4914e-01, -6.3329e-01, -1.0186e+00, -2.4863e-01,
-        -1.6980e-01,  1.9601e-01, -4.1304e-01, -6.0764e-01, -1.5422e-01,
-        -1.6428e-01,  2.2579e-01, -8.1978e-02, -1.7287e-01,  1.5077e+00,
-        -3.7021e-01,  2.2123e-01, -7.2436e-04, -1.9138e-02, -4.7626e-01,
-        -3.5783e-01,  1.1999e+00,  0.0000e+00,  1.0484e+00,  7.9066e-01,
-         1.5014e-01,  1.0316e+00, -1.0207e-01,  7.2148e-01, -4.9782e-02,
-         2.8153e-01,  4.5881e-01,  1.6924e-01,  7.9796e-01, -1.5615e-01,
-        -7.0846e-02, -3.7062e-01, -3.3120e-01, -2.0414e-01,  7.5795e-01,
-         3.9107e-05,  1.1057e+00,  5.5276e-03,  7.4733e-01, -4.9619e-01,
-         1.3497e+00,  9.5207e-02,  4.9101e-02,  5.3319e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6555,  0.6403, -0.1577,  0.3489,  0.1681, -0.3071, -0.6196, -0.7194,
-         0.4291, -0.3601, -0.8937, -0.3528,  0.1266,  0.2176, -1.4199, -0.6080,
-        -0.7491, -0.6333, -1.0186, -0.2486, -0.1698,  0.1960, -0.4130, -0.6076,
-        -0.1542, -0.1643,  0.2258, -0.0820, -0.1729,  1.5077, -0.3702,  0.2212,
-         0.0000, -0.0191, -0.4763, -0.3578,  1.1999,  0.0000,  1.0484,  0.7907,
-         0.1501,  1.0316, -0.1021,  0.7215, -0.0498,  0.2815,  0.4588,  0.1692,
-         0.7980, -0.1561, -0.0708, -0.3706, -0.3312, -0.2041,  0.7579,  0.0000,
-         1.1057,  0.0000,  0.7473, -0.4962,  1.3497,  0.0952,  0.0491,  0.0533],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6555,  0.6403, -0.1577,  0.3489,  0.1681, -0.3071, -0.6196, -0.7194,
-         0.4291, -0.3601, -0.8937, -0.3528,  0.1266,  0.2176, -1.4199, -0.6080,
-        -0.7491, -0.6333, -1.0186, -0.2486, -0.1698,  0.1960, -0.4130, -0.6076,
-        -0.1542, -0.1643,  0.2258, -0.0820, -0.1729,  1.5077, -0.3702,  0.2212,
-         0.0000, -0.0191, -0.4763, -0.3578,  1.1999,  0.0000,  1.0484,  0.7907,
-         0.1501,  1.0316, -0.1021,  0.7215, -0.0498,  0.2815,  0.4588,  0.1692,
-         0.7980, -0.1561, -0.0708, -0.3706, -0.3312, -0.2041,  0.7579,  0.0000,
-         1.1057,  0.0000,  0.7473, -0.4962,  1.3497,  0.0952,  0.0491,  0.0533],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5394e-01,  6.6149e-01, -1.6717e-01,  3.9510e-01,  3.1591e-01,
-        -2.7043e-01, -6.1437e-01, -7.0739e-01,  3.9745e-01, -3.3005e-01,
-        -8.9599e-01, -3.6284e-01,  1.5788e-01,  1.2871e-01, -1.4203e+00,
-        -5.8095e-01, -7.6507e-01, -6.0874e-01, -1.0172e+00, -2.4270e-01,
-        -1.5161e-01,  2.5262e-01, -3.9610e-01, -6.0651e-01, -1.0586e-01,
-        -1.6273e-01,  2.6542e-01, -1.1012e-02, -2.6522e-01,  1.5065e+00,
-        -4.0852e-01,  2.4853e-01, -6.2354e-04,  5.7846e-02, -4.4987e-01,
-        -3.4467e-01,  1.1922e+00,  0.0000e+00,  1.0315e+00,  7.6163e-01,
-         2.2823e-01,  1.0298e+00, -9.6764e-02,  7.2481e-01, -1.1064e-01,
-         2.9725e-01,  4.3652e-01,  1.6715e-01,  8.0285e-01, -2.5267e-01,
-        -1.2949e-02, -4.0625e-01, -3.8560e-01, -1.6056e-01,  7.2701e-01,
-         3.3664e-05,  1.0993e+00,  4.7582e-03,  7.4120e-01, -4.9245e-01,
-         1.3472e+00,  1.4747e-01,  6.7337e-02, -4.8052e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6539,  0.6615, -0.1672,  0.3951,  0.3159, -0.2704, -0.6144, -0.7074,
-         0.3974, -0.3300, -0.8960, -0.3628,  0.1579,  0.1287, -1.4203, -0.5809,
-        -0.7651, -0.6087, -1.0172, -0.2427, -0.1516,  0.2526, -0.3961, -0.6065,
-        -0.1059, -0.1627,  0.2654, -0.0110, -0.2652,  1.5065, -0.4085,  0.2485,
-         0.0000,  0.0578, -0.4499, -0.3447,  1.1922,  0.0000,  1.0315,  0.7616,
-         0.2282,  1.0298, -0.0968,  0.7248, -0.1106,  0.2972,  0.4365,  0.1671,
-         0.8029, -0.2527, -0.0129, -0.4063, -0.3856, -0.1606,  0.7270,  0.0000,
-         1.0993,  0.0000,  0.7412, -0.4925,  1.3472,  0.1475,  0.0673, -0.0481],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6539,  0.6615, -0.1672,  0.3951,  0.3159, -0.2704, -0.6144, -0.7074,
-         0.3974, -0.3300, -0.8960, -0.3628,  0.1579,  0.1287, -1.4203, -0.5809,
-        -0.7651, -0.6087, -1.0172, -0.2427, -0.1516,  0.2526, -0.3961, -0.6065,
-        -0.1059, -0.1627,  0.2654, -0.0110, -0.2652,  1.5065, -0.4085,  0.2485,
-         0.0000,  0.0578, -0.4499, -0.3447,  1.1922,  0.0000,  1.0315,  0.7616,
-         0.2282,  1.0298, -0.0968,  0.7248, -0.1106,  0.2972,  0.4365,  0.1671,
-         0.8029, -0.2527, -0.0129, -0.4063, -0.3856, -0.1606,  0.7270,  0.0000,
-         1.0993,  0.0000,  0.7412, -0.4925,  1.3472,  0.1475,  0.0673, -0.0481],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5679e-01,  6.6606e-01, -2.0509e-01,  4.5879e-01,  4.4541e-01,
-        -2.1834e-01, -6.0882e-01, -6.9886e-01,  3.8084e-01, -3.1058e-01,
-        -8.9759e-01, -3.4203e-01,  1.9194e-01,  6.1897e-02, -1.4182e+00,
-        -5.4835e-01, -7.8447e-01, -5.8274e-01, -1.0150e+00, -2.3620e-01,
-        -1.2336e-01,  2.7482e-01, -3.7972e-01, -6.1043e-01, -5.8284e-02,
-        -1.5507e-01,  3.2312e-01,  5.6568e-02, -3.4358e-01,  1.5032e+00,
-        -4.4974e-01,  2.6548e-01, -5.3634e-04,  9.2124e-02, -4.1867e-01,
-        -3.1881e-01,  1.1841e+00,  0.0000e+00,  1.0178e+00,  7.3072e-01,
-         2.7513e-01,  1.0291e+00, -1.1836e-01,  7.2450e-01, -1.6223e-01,
-         2.9663e-01,  4.1688e-01,  1.9078e-01,  8.0279e-01, -3.3273e-01,
-         5.7784e-02, -4.1537e-01, -4.2721e-01, -1.2661e-01,  6.9766e-01,
-         2.8956e-05,  1.0930e+00,  4.0928e-03,  7.2737e-01, -4.8216e-01,
-         1.3432e+00,  1.8723e-01,  9.9279e-02, -1.4648e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6568,  0.6661, -0.2051,  0.4588,  0.4454, -0.2183, -0.6088, -0.6989,
-         0.3808, -0.3106, -0.8976, -0.3420,  0.1919,  0.0619, -1.4182, -0.5484,
-        -0.7845, -0.5827, -1.0150, -0.2362, -0.1234,  0.2748, -0.3797, -0.6104,
-        -0.0583, -0.1551,  0.3231,  0.0566, -0.3436,  1.5032, -0.4497,  0.2655,
-         0.0000,  0.0921, -0.4187, -0.3188,  1.1841,  0.0000,  1.0178,  0.7307,
-         0.2751,  1.0291, -0.1184,  0.7245, -0.1622,  0.2966,  0.4169,  0.1908,
-         0.8028, -0.3327,  0.0578, -0.4154, -0.4272, -0.1266,  0.6977,  0.0000,
-         1.0930,  0.0000,  0.7274, -0.4822,  1.3432,  0.1872,  0.0993, -0.1465],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6568,  0.6661, -0.2051,  0.4588,  0.4454, -0.2183, -0.6088, -0.6989,
-         0.3808, -0.3106, -0.8976, -0.3420,  0.1919,  0.0619, -1.4182, -0.5484,
-        -0.7845, -0.5827, -1.0150, -0.2362, -0.1234,  0.2748, -0.3797, -0.6104,
-        -0.0583, -0.1551,  0.3231,  0.0566, -0.3436,  1.5032, -0.4497,  0.2655,
-         0.0000,  0.0921, -0.4187, -0.3188,  1.1841,  0.0000,  1.0178,  0.7307,
-         0.2751,  1.0291, -0.1184,  0.7245, -0.1622,  0.2966,  0.4169,  0.1908,
-         0.8028, -0.3327,  0.0578, -0.4154, -0.4272, -0.1266,  0.6977,  0.0000,
-         1.0930,  0.0000,  0.7274, -0.4822,  1.3432,  0.1872,  0.0993, -0.1465],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.6211e-01,  6.5763e-01, -1.9453e-01,  4.5795e-01,  5.3745e-01,
-        -1.6628e-01, -6.0263e-01, -6.8882e-01,  3.4290e-01, -2.9219e-01,
-        -8.9423e-01, -2.9939e-01,  2.7980e-01,  4.7185e-02, -1.4150e+00,
-        -5.2399e-01, -7.8968e-01, -5.6064e-01, -1.0131e+00, -2.2467e-01,
-        -1.1156e-01,  1.8769e-01, -3.8474e-01, -6.1338e-01, -9.8566e-02,
-        -1.5622e-01,  3.5904e-01,  5.8445e-02, -4.3411e-01,  1.5020e+00,
-        -4.6262e-01,  2.8525e-01, -4.6099e-04,  2.7604e-02, -3.9669e-01,
-        -2.8849e-01,  1.1730e+00,  0.0000e+00,  9.9839e-01,  7.1344e-01,
-         2.5722e-01,  1.0200e+00, -1.0944e-01,  7.1377e-01, -1.5735e-01,
-         2.7550e-01,  4.0078e-01,  2.2695e-01,  8.1072e-01, -3.6981e-01,
-         1.0578e-01, -4.6131e-01, -4.4949e-01, -1.5753e-01,  6.8657e-01,
-         2.4888e-05,  1.0876e+00,  3.5178e-03,  7.2118e-01, -4.8478e-01,
-         1.3372e+00,  1.9023e-01,  1.2497e-01, -2.4141e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6621,  0.6576, -0.1945,  0.4580,  0.5375, -0.1663, -0.6026, -0.6888,
-         0.3429, -0.2922, -0.8942, -0.2994,  0.2798,  0.0472, -1.4150, -0.5240,
-        -0.7897, -0.5606, -1.0131, -0.2247, -0.1116,  0.1877, -0.3847, -0.6134,
-        -0.0986, -0.1562,  0.3590,  0.0584, -0.4341,  1.5020, -0.4626,  0.2853,
-         0.0000,  0.0276, -0.3967, -0.2885,  1.1730,  0.0000,  0.9984,  0.7134,
-         0.2572,  1.0200, -0.1094,  0.7138, -0.1574,  0.2755,  0.4008,  0.2269,
-         0.8107, -0.3698,  0.1058, -0.4613, -0.4495, -0.1575,  0.6866,  0.0000,
-         1.0876,  0.0000,  0.7212, -0.4848,  1.3372,  0.1902,  0.1250, -0.2414],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6621,  0.6576, -0.1945,  0.4580,  0.5375, -0.1663, -0.6026, -0.6888,
-         0.3429, -0.2922, -0.8942, -0.2994,  0.2798,  0.0472, -1.4150, -0.5240,
-        -0.7897, -0.5606, -1.0131, -0.2247, -0.1116,  0.1877, -0.3847, -0.6134,
-        -0.0986, -0.1562,  0.3590,  0.0584, -0.4341,  1.5020, -0.4626,  0.2853,
-         0.0000,  0.0276, -0.3967, -0.2885,  1.1730,  0.0000,  0.9984,  0.7134,
-         0.2572,  1.0200, -0.1094,  0.7138, -0.1574,  0.2755,  0.4008,  0.2269,
-         0.8107, -0.3698,  0.1058, -0.4613, -0.4495, -0.1575,  0.6866,  0.0000,
-         1.0876,  0.0000,  0.7212, -0.4848,  1.3372,  0.1902,  0.1250, -0.2414],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.6349e-01,  6.4107e-01, -2.3389e-01,  4.5054e-01,  5.6821e-01,
-        -1.3208e-01, -5.9898e-01, -6.8672e-01,  2.9663e-01, -2.9051e-01,
-        -8.8753e-01, -2.6570e-01,  3.3972e-01,  4.8943e-02, -1.4117e+00,
-        -5.0666e-01, -7.8431e-01, -5.4074e-01, -1.0111e+00, -1.9586e-01,
-        -7.4469e-02,  8.8066e-02, -3.8997e-01, -5.8413e-01, -1.1882e-01,
-        -1.6900e-01,  3.7506e-01,  4.5228e-02, -4.7427e-01,  1.5011e+00,
-        -4.8375e-01,  3.0744e-01, -3.9594e-04, -3.4891e-02, -3.6805e-01,
-        -2.4369e-01,  1.1643e+00,  0.0000e+00,  9.8065e-01,  6.9848e-01,
-         2.3319e-01,  1.0101e+00, -1.5108e-01,  7.0082e-01, -1.2771e-01,
-         2.4815e-01,  3.7265e-01,  2.8436e-01,  8.2575e-01, -3.9981e-01,
-         1.5613e-01, -4.8922e-01, -4.7689e-01, -1.9987e-01,  6.7648e-01,
-         2.1376e-05,  1.0839e+00,  3.0214e-03,  7.0114e-01, -4.5708e-01,
-         1.3324e+00,  1.7674e-01,  1.3306e-01, -3.5182e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6635,  0.6411, -0.2339,  0.4505,  0.5682, -0.1321, -0.5990, -0.6867,
-         0.2966, -0.2905, -0.8875, -0.2657,  0.3397,  0.0489, -1.4117, -0.5067,
-        -0.7843, -0.5407, -1.0111, -0.1959, -0.0745,  0.0881, -0.3900, -0.5841,
-        -0.1188, -0.1690,  0.3751,  0.0452, -0.4743,  1.5011, -0.4837,  0.3074,
-         0.0000, -0.0349, -0.3681, -0.2437,  1.1643,  0.0000,  0.9807,  0.6985,
-         0.2332,  1.0101, -0.1511,  0.7008, -0.1277,  0.2481,  0.3727,  0.2844,
-         0.8257, -0.3998,  0.1561, -0.4892, -0.4769, -0.1999,  0.6765,  0.0000,
-         1.0839,  0.0000,  0.7011, -0.4571,  1.3324,  0.1767,  0.1331, -0.3518],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6635,  0.6411, -0.2339,  0.4505,  0.5682, -0.1321, -0.5990, -0.6867,
-         0.2966, -0.2905, -0.8875, -0.2657,  0.3397,  0.0489, -1.4117, -0.5067,
-        -0.7843, -0.5407, -1.0111, -0.1959, -0.0745,  0.0881, -0.3900, -0.5841,
-        -0.1188, -0.1690,  0.3751,  0.0452, -0.4743,  1.5011, -0.4837,  0.3074,
-         0.0000, -0.0349, -0.3681, -0.2437,  1.1643,  0.0000,  0.9807,  0.6985,
-         0.2332,  1.0101, -0.1511,  0.7008, -0.1277,  0.2481,  0.3727,  0.2844,
-         0.8257, -0.3998,  0.1561, -0.4892, -0.4769, -0.1999,  0.6765,  0.0000,
-         1.0839,  0.0000,  0.7011, -0.4571,  1.3324,  0.1767,  0.1331, -0.3518],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.4914e-01,  6.4136e-01, -2.8685e-01,  4.2981e-01,  5.6004e-01,
-        -9.2208e-02, -5.9756e-01, -6.7507e-01,  2.3357e-01, -2.6646e-01,
-        -8.7867e-01, -2.8872e-01,  3.0000e-01,  3.4929e-03, -1.4068e+00,
-        -4.8711e-01, -7.6802e-01, -5.1273e-01, -1.0099e+00, -2.0459e-01,
-        -3.0585e-02,  1.0186e-01, -3.6468e-01, -5.2879e-01, -8.4328e-02,
-        -1.8924e-01,  3.2117e-01,  5.8324e-02, -4.7095e-01,  1.5046e+00,
-        -4.9192e-01,  2.7044e-01, -3.3983e-04, -6.0518e-03, -3.7204e-01,
-        -2.1853e-01,  1.1504e+00,  0.0000e+00,  9.3357e-01,  6.5178e-01,
-         2.5520e-01,  9.9338e-01, -1.3638e-01,  6.7179e-01, -1.2109e-01,
-         2.3878e-01,  3.3854e-01,  2.7951e-01,  8.4611e-01, -4.0436e-01,
-         2.2040e-01, -4.8488e-01, -5.0362e-01, -1.1984e-01,  6.4257e-01,
-         1.8347e-05,  1.0797e+00,  2.5932e-03,  6.7367e-01, -4.1221e-01,
-         1.3279e+00,  1.0727e-01,  1.3902e-01, -4.3714e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6491,  0.6414, -0.2869,  0.4298,  0.5600, -0.0922,  0.0000, -0.6751,
-         0.2336, -0.2665, -0.8787, -0.2887,  0.3000,  0.0035, -1.4068, -0.4871,
-        -0.7680, -0.5127, -1.0099, -0.2046, -0.0306,  0.1019, -0.3647, -0.5288,
-        -0.0843, -0.1892,  0.3212,  0.0583, -0.4709,  1.5046, -0.4919,  0.2704,
-         0.0000, -0.0061, -0.3720, -0.2185,  1.1504,  0.0000,  0.9336,  0.6518,
-         0.2552,  0.9934, -0.1364,  0.6718, -0.1211,  0.2388,  0.3385,  0.2795,
-         0.8461, -0.4044,  0.2204, -0.4849, -0.5036, -0.1198,  0.6426,  0.0000,
-         1.0797,  0.0000,  0.6737, -0.4122,  1.3279,  0.1073,  0.1390, -0.4371],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6491,  0.6414, -0.2869,  0.4298,  0.5600, -0.0922,  0.0000, -0.6751,
-         0.2336, -0.2665, -0.8787, -0.2887,  0.3000,  0.0035, -1.4068, -0.4871,
-        -0.7680, -0.5127, -1.0099, -0.2046, -0.0306,  0.1019, -0.3647, -0.5288,
-        -0.0843, -0.1892,  0.3212,  0.0583, -0.4709,  1.5046, -0.4919,  0.2704,
-         0.0000, -0.0061, -0.3720, -0.2185,  1.1504,  0.0000,  0.9336,  0.6518,
-         0.2552,  0.9934, -0.1364,  0.6718, -0.1211,  0.2388,  0.3385,  0.2795,
-         0.8461, -0.4044,  0.2204, -0.4849, -0.5036, -0.1198,  0.6426,  0.0000,
-         1.0797,  0.0000,  0.6737, -0.4122,  1.3279,  0.1073,  0.1390, -0.4371],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.4252e-01,  6.3402e-01, -2.8374e-01,  3.7684e-01,  5.3533e-01,
-        -3.9781e-02,  1.2158e-03, -6.6409e-01,  1.6676e-01, -2.5957e-01,
-        -8.7363e-01, -2.6339e-01,  2.4215e-01, -4.3323e-03, -1.4038e+00,
-        -4.9160e-01, -7.5233e-01, -4.9835e-01, -1.0068e+00, -2.1774e-01,
-         2.1627e-02,  7.2175e-02, -3.4702e-01, -4.5177e-01, -6.4269e-02,
-        -1.9702e-01,  2.6134e-01,  4.6130e-02, -4.5929e-01,  1.5048e+00,
-        -4.9724e-01,  2.2649e-01, -2.9147e-04, -2.3297e-03, -4.0629e-01,
-        -1.8653e-01,  1.1439e+00,  0.0000e+00,  9.1770e-01,  6.2277e-01,
-         2.6115e-01,  9.8423e-01, -1.0076e-01,  6.5142e-01, -5.1071e-02,
-         2.1387e-01,  3.2306e-01,  2.7766e-01,  8.3834e-01, -4.2490e-01,
-         2.6851e-01, -4.9343e-01, -5.3250e-01, -2.4451e-02,  6.2822e-01,
-         1.5736e-05,  1.0755e+00,  2.2242e-03,  6.2390e-01, -4.3738e-01,
-         1.3263e+00,  2.5933e-02,  1.4347e-01, -4.3168e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6425,  0.6340, -0.2837,  0.3768,  0.5353, -0.0398,  0.0000, -0.6641,
-         0.1668, -0.2596, -0.8736, -0.2634,  0.2422, -0.0043, -1.4038, -0.4916,
-        -0.7523, -0.4983, -1.0068, -0.2177,  0.0216,  0.0722, -0.3470, -0.4518,
-        -0.0643, -0.1970,  0.2613,  0.0461, -0.4593,  1.5048, -0.4972,  0.2265,
-         0.0000, -0.0023, -0.4063, -0.1865,  1.1439,  0.0000,  0.9177,  0.6228,
-         0.2612,  0.9842, -0.1008,  0.6514, -0.0511,  0.2139,  0.3231,  0.2777,
-         0.8383, -0.4249,  0.2685, -0.4934, -0.5325, -0.0245,  0.6282,  0.0000,
-         1.0755,  0.0000,  0.6239, -0.4374,  1.3263,  0.0259,  0.1435, -0.4317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6425,  0.6340, -0.2837,  0.3768,  0.5353, -0.0398,  0.0000, -0.6641,
-         0.1668, -0.2596, -0.8736, -0.2634,  0.2422, -0.0043, -1.4038, -0.4916,
-        -0.7523, -0.4983, -1.0068, -0.2177,  0.0216,  0.0722, -0.3470, -0.4518,
-        -0.0643, -0.1970,  0.2613,  0.0461, -0.4593,  1.5048, -0.4972,  0.2265,
-         0.0000, -0.0023, -0.4063, -0.1865,  1.1439,  0.0000,  0.9177,  0.6228,
-         0.2612,  0.9842, -0.1008,  0.6514, -0.0511,  0.2139,  0.3231,  0.2777,
-         0.8383, -0.4249,  0.2685, -0.4934, -0.5325, -0.0245,  0.6282,  0.0000,
-         1.0755,  0.0000,  0.6239, -0.4374,  1.3263,  0.0259,  0.1435, -0.4317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.4036e-01,  6.3036e-01, -2.4734e-01,  3.6162e-01,  5.1393e-01,
-         3.8762e-02,  1.0421e-03, -6.7201e-01,  8.2848e-02, -2.2102e-01,
-        -8.6657e-01, -2.5956e-01,  1.7677e-01,  1.7074e-02, -1.4025e+00,
-        -5.1630e-01, -7.4886e-01, -4.8601e-01, -1.0030e+00, -2.8233e-01,
-         5.1506e-02,  4.2051e-02, -3.2945e-01, -3.8630e-01, -3.4360e-02,
-        -1.8149e-01,  2.2100e-01,  3.6813e-02, -4.4745e-01,  1.5070e+00,
-        -5.1985e-01,  2.1595e-01, -2.4983e-04, -3.9169e-03, -4.5877e-01,
-        -2.1876e-01,  1.1396e+00,  0.0000e+00,  9.1592e-01,  5.9060e-01,
-         2.6293e-01,  9.8552e-01, -9.9852e-03,  6.3599e-01,  2.3314e-02,
-         1.8204e-01,  3.1664e-01,  2.8439e-01,  8.3383e-01, -4.5513e-01,
-         2.9897e-01, -5.2158e-01, -5.4139e-01,  1.9295e-01,  6.0206e-01,
-         1.3488e-05,  1.0722e+00,  1.9064e-03,  5.7847e-01, -4.7437e-01,
-         1.3252e+00,  1.2519e-02,  1.2625e-01, -3.5588e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6404,  0.6304, -0.2473,  0.3616,  0.5139,  0.0388,  0.0000, -0.6720,
-         0.0828, -0.2210, -0.8666, -0.2596,  0.1768,  0.0171, -1.4025, -0.5163,
-        -0.7489, -0.4860, -1.0030, -0.2823,  0.0515,  0.0421, -0.3294, -0.3863,
-        -0.0344, -0.1815,  0.2210,  0.0368, -0.4474,  1.5070, -0.5199,  0.2159,
-         0.0000, -0.0039, -0.4588, -0.2188,  1.1396,  0.0000,  0.9159,  0.5906,
-         0.2629,  0.9855, -0.0100,  0.6360,  0.0233,  0.1820,  0.3166,  0.2844,
-         0.8338, -0.4551,  0.2990, -0.5216, -0.5414,  0.1930,  0.6021,  0.0000,
-         1.0722,  0.0000,  0.5785, -0.4744,  1.3252,  0.0125,  0.1262, -0.3559],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6404,  0.6304, -0.2473,  0.3616,  0.5139,  0.0388,  0.0000, -0.6720,
-         0.0828, -0.2210, -0.8666, -0.2596,  0.1768,  0.0171, -1.4025, -0.5163,
-        -0.7489, -0.4860, -1.0030, -0.2823,  0.0515,  0.0421, -0.3294, -0.3863,
-        -0.0344, -0.1815,  0.2210,  0.0368, -0.4474,  1.5070, -0.5199,  0.2159,
-         0.0000, -0.0039, -0.4588, -0.2188,  1.1396,  0.0000,  0.9159,  0.5906,
-         0.2629,  0.9855, -0.0100,  0.6360,  0.0233,  0.1820,  0.3166,  0.2844,
-         0.8338, -0.4551,  0.2990, -0.5216, -0.5414,  0.1930,  0.6021,  0.0000,
-         1.0722,  0.0000,  0.5785, -0.4744,  1.3252,  0.0125,  0.1262, -0.3559],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.3500e-01,  6.1131e-01, -1.7671e-01,  3.4817e-01,  4.0812e-01,
-         1.5163e-01,  8.9264e-04, -6.8258e-01,  3.3915e-02, -2.3383e-01,
-        -8.4980e-01, -2.3349e-01,  1.8038e-01,  8.7006e-02, -1.4035e+00,
-        -5.4962e-01, -7.4221e-01, -4.8086e-01, -1.0001e+00, -3.4251e-01,
-         6.5892e-02, -5.0297e-02, -3.4186e-01, -3.2445e-01, -3.4199e-02,
-        -1.4665e-01,  2.1284e-01,  1.6382e-02, -3.8811e-01,  1.5050e+00,
-        -5.4709e-01,  2.2840e-01, -2.1400e-04, -5.0070e-02, -4.9590e-01,
-        -2.6100e-01,  1.1433e+00,  0.0000e+00,  9.2475e-01,  5.7539e-01,
-         2.2568e-01,  9.8637e-01,  7.6764e-02,  6.3672e-01,  7.2475e-02,
-         1.3156e-01,  3.1048e-01,  3.1114e-01,  8.3001e-01, -4.9028e-01,
-         3.0501e-01, -5.3834e-01, -5.4012e-01,  3.3862e-01,  5.8820e-01,
-         1.1553e-05,  1.0710e+00,  1.6330e-03,  5.4950e-01, -4.8809e-01,
-         1.3290e+00,  1.7673e-02,  9.6163e-02, -2.8290e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6350,  0.6113, -0.1767,  0.3482,  0.4081,  0.1516,  0.0000, -0.6826,
-         0.0339, -0.2338, -0.8498, -0.2335,  0.1804,  0.0870, -1.4035, -0.5496,
-        -0.7422, -0.4809, -1.0001, -0.3425,  0.0659, -0.0503, -0.3419, -0.3244,
-        -0.0342, -0.1467,  0.2128,  0.0164, -0.3881,  1.5050, -0.5471,  0.2284,
-         0.0000, -0.0501, -0.4959, -0.2610,  1.1433,  0.0000,  0.9248,  0.5754,
-         0.2257,  0.9864,  0.0768,  0.6367,  0.0725,  0.1316,  0.3105,  0.3111,
-         0.8300, -0.4903,  0.3050, -0.5383, -0.5401,  0.3386,  0.5882,  0.0000,
-         1.0710,  0.0000,  0.5495, -0.4881,  1.3290,  0.0177,  0.0962, -0.2829],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6350,  0.6113, -0.1767,  0.3482,  0.4081,  0.1516,  0.0000, -0.6826,
-         0.0339, -0.2338, -0.8498, -0.2335,  0.1804,  0.0870, -1.4035, -0.5496,
-        -0.7422, -0.4809, -1.0001, -0.3425,  0.0659, -0.0503, -0.3419, -0.3244,
-        -0.0342, -0.1467,  0.2128,  0.0164, -0.3881,  1.5050, -0.5471,  0.2284,
-         0.0000, -0.0501, -0.4959, -0.2610,  1.1433,  0.0000,  0.9248,  0.5754,
-         0.2257,  0.9864,  0.0768,  0.6367,  0.0725,  0.1316,  0.3105,  0.3111,
-         0.8300, -0.4903,  0.3050, -0.5383, -0.5401,  0.3386,  0.5882,  0.0000,
-         1.0710,  0.0000,  0.5495, -0.4881,  1.3290,  0.0177,  0.0962, -0.2829],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.2145e-01,  5.9023e-01, -9.7670e-02,  3.1163e-01,  3.7370e-01,
-         2.5237e-01,  7.6415e-04, -6.8939e-01,  2.7698e-03, -3.0795e-01,
-        -8.3914e-01, -1.7875e-01,  1.9994e-01,  2.1118e-01, -1.4038e+00,
-        -5.6700e-01, -7.3355e-01, -4.6738e-01, -9.9813e-01, -3.8493e-01,
-         6.0937e-02, -1.4467e-01, -3.8431e-01, -2.7310e-01, -5.7281e-02,
-        -8.0635e-02,  1.7740e-01, -2.6739e-02, -3.5758e-01,  1.5028e+00,
-        -5.5109e-01,  2.4288e-01, -1.8320e-04, -8.5688e-02, -5.2144e-01,
-        -2.5870e-01,  1.1417e+00,  0.0000e+00,  9.1919e-01,  5.4971e-01,
-         1.7333e-01,  9.8054e-01,  1.5245e-01,  6.4214e-01,  7.2957e-02,
-         1.0104e-01,  2.7743e-01,  3.3944e-01,  8.3388e-01, -5.0729e-01,
-         2.7222e-01, -5.5482e-01, -5.0834e-01,  4.2531e-01,  5.6785e-01,
-         9.8905e-06,  1.0701e+00,  1.3980e-03,  5.3191e-01, -4.9820e-01,
-         1.3314e+00,  2.5425e-02,  6.9219e-02, -2.2708e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6215,  0.5902, -0.0977,  0.3116,  0.3737,  0.2524,  0.0000, -0.6894,
-         0.0028, -0.3080, -0.8391, -0.1788,  0.1999,  0.2112, -1.4038, -0.5670,
-        -0.7336, -0.4674, -0.9981, -0.3849,  0.0609, -0.1447, -0.3843, -0.2731,
-        -0.0573, -0.0806,  0.1774, -0.0267, -0.3576,  1.5028, -0.5511,  0.2429,
-         0.0000, -0.0857, -0.5214, -0.2587,  1.1417,  0.0000,  0.9192,  0.5497,
-         0.1733,  0.9805,  0.1525,  0.6421,  0.0730,  0.1010,  0.2774,  0.3394,
-         0.8339, -0.5073,  0.2722, -0.5548, -0.5083,  0.4253,  0.5678,  0.0000,
-         1.0701,  0.0000,  0.5319, -0.4982,  1.3314,  0.0254,  0.0692, -0.2271],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6215,  0.5902, -0.0977,  0.3116,  0.3737,  0.2524,  0.0000, -0.6894,
-         0.0028, -0.3080, -0.8391, -0.1788,  0.1999,  0.2112, -1.4038, -0.5670,
-        -0.7336, -0.4674, -0.9981, -0.3849,  0.0609, -0.1447, -0.3843, -0.2731,
-        -0.0573, -0.0806,  0.1774, -0.0267, -0.3576,  1.5028, -0.5511,  0.2429,
-         0.0000, -0.0857, -0.5214, -0.2587,  1.1417,  0.0000,  0.9192,  0.5497,
-         0.1733,  0.9805,  0.1525,  0.6421,  0.0730,  0.1010,  0.2774,  0.3394,
-         0.8339, -0.5073,  0.2722, -0.5548, -0.5083,  0.4253,  0.5678,  0.0000,
-         1.0701,  0.0000,  0.5319, -0.4982,  1.3314,  0.0254,  0.0692, -0.2271],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.1668e-01,  5.6268e-01, -5.7647e-02,  2.6522e-01,  3.4746e-01,
-         2.9227e-01,  6.5377e-04, -6.9172e-01, -2.5842e-02, -3.7660e-01,
-        -8.1673e-01, -1.2168e-01,  2.0623e-01,  3.1455e-01, -1.4028e+00,
-        -5.7764e-01, -7.2087e-01, -4.4396e-01, -9.9791e-01, -4.0508e-01,
-         9.6803e-02, -1.8390e-01, -4.0050e-01, -2.0532e-01, -1.4204e-02,
-         1.8732e-02,  9.4744e-02, -4.2403e-02, -3.0900e-01,  1.5004e+00,
-        -5.5411e-01,  2.1999e-01, -1.5673e-04, -5.1963e-02, -5.2070e-01,
-        -2.4102e-01,  1.1441e+00,  0.0000e+00,  9.2264e-01,  5.6226e-01,
-         1.2789e-01,  9.7059e-01,  1.6548e-01,  6.5120e-01,  6.5211e-02,
-         6.1442e-02,  2.4294e-01,  3.7029e-01,  8.3851e-01, -5.2723e-01,
-         2.5153e-01, -5.2954e-01, -4.7323e-01,  4.3459e-01,  5.9014e-01,
-         8.4618e-06,  1.0697e+00,  1.1960e-03,  5.0578e-01, -5.0195e-01,
-         1.3360e+00, -8.6477e-02,  6.4690e-02, -2.5386e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6167,  0.5627, -0.0576,  0.2652,  0.3475,  0.2923,  0.0000, -0.6917,
-        -0.0258, -0.3766, -0.8167, -0.1217,  0.2062,  0.3145, -1.4028, -0.5776,
-        -0.7209, -0.4440, -0.9979, -0.4051,  0.0968, -0.1839, -0.4005, -0.2053,
-        -0.0142,  0.0187,  0.0947, -0.0424, -0.3090,  1.5004, -0.5541,  0.2200,
-         0.0000, -0.0520, -0.5207, -0.2410,  1.1441,  0.0000,  0.9226,  0.5623,
-         0.1279,  0.9706,  0.1655,  0.6512,  0.0652,  0.0614,  0.2429,  0.3703,
-         0.8385, -0.5272,  0.2515, -0.5295, -0.4732,  0.4346,  0.5901,  0.0000,
-         1.0697,  0.0000,  0.5058, -0.5020,  1.3360, -0.0865,  0.0647, -0.2539],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6167,  0.5627, -0.0576,  0.2652,  0.3475,  0.2923,  0.0000, -0.6917,
-        -0.0258, -0.3766, -0.8167, -0.1217,  0.2062,  0.3145, -1.4028, -0.5776,
-        -0.7209, -0.4440, -0.9979, -0.4051,  0.0968, -0.1839, -0.4005, -0.2053,
-        -0.0142,  0.0187,  0.0947, -0.0424, -0.3090,  1.5004, -0.5541,  0.2200,
-         0.0000, -0.0520, -0.5207, -0.2410,  1.1441,  0.0000,  0.9226,  0.5623,
-         0.1279,  0.9706,  0.1655,  0.6512,  0.0652,  0.0614,  0.2429,  0.3703,
-         0.8385, -0.5272,  0.2515, -0.5295, -0.4732,  0.4346,  0.5901,  0.0000,
-         1.0697,  0.0000,  0.5058, -0.5020,  1.3360, -0.0865,  0.0647, -0.2539],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.0525e-01,  5.4585e-01, -5.4705e-02,  2.1522e-01,  2.6450e-01,
-         2.8843e-01,  5.5902e-04, -6.9545e-01, -5.7872e-02, -3.8318e-01,
-        -7.8899e-01, -1.9815e-01,  1.2692e-01,  3.3820e-01, -1.3989e+00,
-        -5.7985e-01, -7.1642e-01, -3.9627e-01, -9.9945e-01, -4.3921e-01,
-         1.0305e-01, -1.2271e-01, -3.8790e-01, -1.4685e-01,  9.0309e-02,
-         8.0191e-02, -2.8660e-02, -3.3127e-02, -2.3066e-01,  1.5036e+00,
-        -5.4713e-01,  1.4020e-01, -1.3402e-04,  8.0386e-02, -5.2785e-01,
-        -2.6344e-01,  1.1405e+00,  0.0000e+00,  9.0622e-01,  5.7846e-01,
-         1.0356e-01,  9.5395e-01,  1.7825e-01,  6.4828e-01,  8.1883e-02,
-         6.3103e-02,  2.0066e-01,  3.6093e-01,  8.6294e-01, -5.4557e-01,
-         2.7658e-01, -4.7054e-01, -4.5655e-01,  3.9055e-01,  6.1684e-01,
-         7.2354e-06,  1.0687e+00,  1.0227e-03,  4.9583e-01, -4.8175e-01,
-         1.3393e+00, -1.9640e-01,  8.7251e-02, -2.8885e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6052,  0.5458, -0.0547,  0.2152,  0.2645,  0.2884,  0.0000, -0.6955,
-        -0.0579, -0.3832, -0.7890, -0.1982,  0.1269,  0.3382, -1.3989, -0.5799,
-        -0.7164, -0.3963, -0.9995, -0.4392,  0.1030, -0.1227, -0.3879, -0.1468,
-         0.0903,  0.0802, -0.0287, -0.0331, -0.2307,  1.5036, -0.5471,  0.1402,
-         0.0000,  0.0804, -0.5278, -0.2634,  1.1405,  0.0000,  0.9062,  0.5785,
-         0.1036,  0.9540,  0.1783,  0.6483,  0.0819,  0.0631,  0.2007,  0.3609,
-         0.8629, -0.5456,  0.2766, -0.4705, -0.4566,  0.3905,  0.6168,  0.0000,
-         1.0687,  0.0000,  0.4958, -0.4817,  1.3393, -0.1964,  0.0873, -0.2889],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6052,  0.5458, -0.0547,  0.2152,  0.2645,  0.2884,  0.0000, -0.6955,
-        -0.0579, -0.3832, -0.7890, -0.1982,  0.1269,  0.3382, -1.3989, -0.5799,
-        -0.7164, -0.3963, -0.9995, -0.4392,  0.1030, -0.1227, -0.3879, -0.1468,
-         0.0903,  0.0802, -0.0287, -0.0331, -0.2307,  1.5036, -0.5471,  0.1402,
-         0.0000,  0.0804, -0.5278, -0.2634,  1.1405,  0.0000,  0.9062,  0.5785,
-         0.1036,  0.9540,  0.1783,  0.6483,  0.0819,  0.0631,  0.2007,  0.3609,
-         0.8629, -0.5456,  0.2766, -0.4705, -0.4566,  0.3905,  0.6168,  0.0000,
-         1.0687,  0.0000,  0.4958, -0.4817,  1.3393, -0.1964,  0.0873, -0.2889],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.7844e-01,  5.2429e-01, -6.7794e-02,  1.8611e-01,  2.4306e-01,
-         3.1547e-01,  4.7773e-04, -7.0102e-01, -9.0168e-02, -3.5620e-01,
-        -7.5371e-01, -3.1503e-01,  2.5413e-02,  2.8830e-01, -1.4024e+00,
-        -5.7404e-01, -7.0919e-01, -3.2607e-01, -1.0011e+00, -5.0673e-01,
-         5.1226e-02, -8.1347e-03, -3.6101e-01, -1.1391e-01,  1.9696e-01,
-         7.4553e-02, -9.9145e-02, -3.0621e-02, -1.7908e-01,  1.5074e+00,
-        -5.4155e-01,  2.3425e-02, -1.1453e-04,  1.7663e-01, -5.2747e-01,
-        -3.2004e-01,  1.1319e+00,  0.0000e+00,  8.5451e-01,  5.5273e-01,
-         8.4971e-02,  9.3705e-01,  1.9437e-01,  6.3211e-01, -2.8614e-02,
-         9.1126e-02,  1.6426e-01,  3.2153e-01,  8.9943e-01, -5.5338e-01,
-         2.9050e-01, -3.9890e-01, -4.4031e-01,  3.4651e-01,  6.1473e-01,
-         6.1833e-06,  1.0662e+00,  8.7397e-04,  5.0349e-01, -4.3674e-01,
-         1.3410e+00, -2.2385e-01,  6.1650e-02, -2.6741e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5784,  0.5243, -0.0678,  0.1861,  0.2431,  0.3155,  0.0000, -0.7010,
-        -0.0902, -0.3562, -0.7537, -0.3150,  0.0254,  0.2883, -1.4024, -0.5740,
-        -0.7092, -0.3261, -1.0011, -0.5067,  0.0512, -0.0081, -0.3610, -0.1139,
-         0.1970,  0.0746, -0.0991, -0.0306, -0.1791,  1.5074, -0.5416,  0.0234,
-         0.0000,  0.1766, -0.5275, -0.3200,  1.1319,  0.0000,  0.8545,  0.5527,
-         0.0850,  0.9370,  0.1944,  0.6321, -0.0286,  0.0911,  0.1643,  0.3215,
-         0.8994, -0.5534,  0.2905, -0.3989, -0.4403,  0.3465,  0.6147,  0.0000,
-         1.0662,  0.0000,  0.5035, -0.4367,  1.3410, -0.2239,  0.0617, -0.2674],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5784,  0.5243, -0.0678,  0.1861,  0.2431,  0.3155,  0.0000, -0.7010,
-        -0.0902, -0.3562, -0.7537, -0.3150,  0.0254,  0.2883, -1.4024, -0.5740,
-        -0.7092, -0.3261, -1.0011, -0.5067,  0.0512, -0.0081, -0.3610, -0.1139,
-         0.1970,  0.0746, -0.0991, -0.0306, -0.1791,  1.5074, -0.5416,  0.0234,
-         0.0000,  0.1766, -0.5275, -0.3200,  1.1319,  0.0000,  0.8545,  0.5527,
-         0.0850,  0.9370,  0.1944,  0.6321, -0.0286,  0.0911,  0.1643,  0.3215,
-         0.8994, -0.5534,  0.2905, -0.3989, -0.4403,  0.3465,  0.6147,  0.0000,
-         1.0662,  0.0000,  0.5035, -0.4367,  1.3410, -0.2239,  0.0617, -0.2674],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.6626e-01,  5.0818e-01, -4.1618e-02,  1.2371e-01,  3.6499e-01,
-         3.7498e-01,  4.0805e-04, -7.0309e-01, -1.2660e-01, -3.5420e-01,
-        -7.2753e-01, -3.6262e-01, -1.6776e-02,  2.6766e-01, -1.4090e+00,
-        -5.7065e-01, -6.9769e-01, -2.9078e-01, -1.0003e+00, -5.4069e-01,
-         1.6912e-02, -2.2876e-02, -3.5790e-01, -1.0318e-01,  2.0936e-01,
-         5.1014e-02, -7.7842e-02, -4.2371e-02, -1.6923e-01,  1.5040e+00,
-        -5.1757e-01, -2.1990e-02, -9.7826e-05,  1.5210e-01, -5.2883e-01,
-        -3.3943e-01,  1.1228e+00,  0.0000e+00,  7.9887e-01,  5.0730e-01,
-        -4.9438e-03,  9.2725e-01,  2.3362e-01,  6.1538e-01, -1.4305e-01,
-         1.0326e-01,  1.5295e-01,  2.8285e-01,  9.2284e-01, -5.5858e-01,
-         2.6296e-01, -3.6377e-01, -3.9113e-01,  2.9999e-01,  5.8458e-01,
-         5.2815e-06,  1.0659e+00,  7.4650e-04,  5.1879e-01, -4.1963e-01,
-         1.3430e+00, -2.1477e-01,  2.2848e-02, -2.4526e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5663,  0.5082, -0.0416,  0.1237,  0.3650,  0.3750,  0.0000, -0.7031,
-        -0.1266, -0.3542, -0.7275, -0.3626, -0.0168,  0.2677, -1.4090, -0.5707,
-        -0.6977, -0.2908, -1.0003, -0.5407,  0.0169, -0.0229, -0.3579, -0.1032,
-         0.2094,  0.0510, -0.0778, -0.0424, -0.1692,  1.5040, -0.5176, -0.0220,
-         0.0000,  0.1521, -0.5288, -0.3394,  1.1228,  0.0000,  0.7989,  0.5073,
-        -0.0049,  0.9273,  0.2336,  0.6154, -0.1430,  0.1033,  0.1529,  0.2829,
-         0.9228, -0.5586,  0.2630, -0.3638, -0.3911,  0.3000,  0.5846,  0.0000,
-         1.0659,  0.0000,  0.5188, -0.4196,  0.0000, -0.2148,  0.0228, -0.2453],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5663,  0.5082, -0.0416,  0.1237,  0.3650,  0.3750,  0.0000, -0.7031,
-        -0.1266, -0.3542, -0.7275, -0.3626, -0.0168,  0.2677, -1.4090, -0.5707,
-        -0.6977, -0.2908, -1.0003, -0.5407,  0.0169, -0.0229, -0.3579, -0.1032,
-         0.2094,  0.0510, -0.0778, -0.0424, -0.1692,  1.5040, -0.5176, -0.0220,
-         0.0000,  0.1521, -0.5288, -0.3394,  1.1228,  0.0000,  0.7989,  0.5073,
-        -0.0049,  0.9273,  0.2336,  0.6154, -0.1430,  0.1033,  0.1529,  0.2829,
-         0.9228, -0.5586,  0.2630, -0.3638, -0.3911,  0.3000,  0.5846,  0.0000,
-         1.0659,  0.0000,  0.5188, -0.4196,  0.0000, -0.2148,  0.0228, -0.2453],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.5728e-01,  4.9672e-01, -2.2435e-03,  1.6039e-02,  5.2010e-01,
-         3.5580e-01,  3.4836e-04, -6.8731e-01, -1.3660e-01, -3.6136e-01,
-        -6.9511e-01, -3.4699e-01, -1.4176e-02,  2.0941e-01, -1.4173e+00,
-        -5.5032e-01, -6.7155e-01, -2.6646e-01, -9.9879e-01, -5.5331e-01,
-         2.6910e-03, -4.0174e-02, -3.2042e-01, -8.5846e-02,  2.2005e-01,
-         3.1165e-02, -4.2465e-02, -5.8996e-02, -1.7584e-01,  1.4961e+00,
-        -5.2029e-01, -1.0416e-01, -8.3516e-05,  1.2926e-01, -5.1236e-01,
-        -3.2901e-01,  1.1158e+00,  0.0000e+00,  7.6130e-01,  4.8167e-01,
-        -1.0942e-01,  9.2122e-01,  1.8350e-01,  5.9760e-01, -3.1080e-01,
-         9.3059e-02,  2.0638e-01,  2.4425e-01,  9.3663e-01, -5.5527e-01,
-         2.4943e-01, -2.7359e-01, -3.6909e-01,  7.4561e-02,  5.6233e-01,
-         4.5089e-06,  1.0656e+00,  6.3730e-04,  5.4558e-01, -3.8663e-01,
-         1.7024e-03, -2.0978e-01,  3.4123e-02, -1.9346e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5573,  0.4967, -0.0022,  0.0160,  0.5201,  0.3558,  0.0000, -0.6873,
-        -0.1366, -0.3614, -0.6951, -0.3470, -0.0142,  0.2094, -1.4173, -0.5503,
-        -0.6716, -0.2665, -0.9988, -0.5533,  0.0027, -0.0402, -0.3204, -0.0858,
-         0.2200,  0.0312, -0.0425, -0.0590, -0.1758,  1.4961, -0.5203, -0.1042,
-         0.0000,  0.1293, -0.5124, -0.3290,  1.1158,  0.0000,  0.7613,  0.4817,
-        -0.1094,  0.9212,  0.1835,  0.5976, -0.3108,  0.0931,  0.2064,  0.2443,
-         0.9366, -0.5553,  0.2494, -0.2736, -0.3691,  0.0746,  0.5623,  0.0000,
-         1.0656,  0.0000,  0.5456, -0.3866,  0.0000, -0.2098,  0.0341, -0.1935],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5573,  0.4967, -0.0022,  0.0160,  0.5201,  0.3558,  0.0000, -0.6873,
-        -0.1366, -0.3614, -0.6951, -0.3470, -0.0142,  0.2094, -1.4173, -0.5503,
-        -0.6716, -0.2665, -0.9988, -0.5533,  0.0027, -0.0402, -0.3204, -0.0858,
-         0.2200,  0.0312, -0.0425, -0.0590, -0.1758,  1.4961, -0.5203, -0.1042,
-         0.0000,  0.1293, -0.5124, -0.3290,  1.1158,  0.0000,  0.7613,  0.4817,
-        -0.1094,  0.9212,  0.1835,  0.5976, -0.3108,  0.0931,  0.2064,  0.2443,
-         0.9366, -0.5553,  0.2494, -0.2736, -0.3691,  0.0746,  0.5623,  0.0000,
-         1.0656,  0.0000,  0.5456, -0.3866,  0.0000, -0.2098,  0.0341, -0.1935],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.5047e-01,  4.4973e-01,  6.2542e-02, -5.7362e-02,  6.3128e-01,
-         3.3603e-01,  2.9726e-04, -6.7263e-01, -1.4934e-01, -3.8050e-01,
-        -6.6589e-01, -3.0309e-01,  4.8251e-02,  1.9018e-01, -1.4266e+00,
-        -5.3952e-01, -6.3903e-01, -2.4409e-01, -9.9123e-01, -5.6829e-01,
-        -3.6255e-02, -6.8751e-02, -2.7236e-01, -6.8772e-02,  2.1357e-01,
-        -1.3793e-04,  9.2802e-03, -6.0042e-02, -1.9401e-01,  1.4867e+00,
-        -5.1789e-01, -1.3741e-01, -7.1265e-05,  1.1398e-01, -5.1325e-01,
-        -3.3397e-01,  1.1126e+00,  0.0000e+00,  7.5114e-01,  4.5934e-01,
-        -1.9083e-01,  9.1833e-01,  1.7224e-01,  5.9478e-01, -4.4776e-01,
-         4.4938e-02,  2.2342e-01,  2.1155e-01,  9.4658e-01, -5.5168e-01,
-         2.2345e-01, -1.6398e-01, -3.0004e-01, -8.0197e-02,  5.3234e-01,
-         3.8475e-06,  1.0655e+00,  5.4382e-04,  5.7665e-01, -3.4530e-01,
-         1.4527e-03, -1.6444e-01,  4.3380e-02, -1.3788e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.5047e-01,  4.4973e-01,  6.2542e-02, -5.7362e-02,  6.3128e-01,
-         3.3603e-01,  0.0000e+00, -6.7263e-01, -1.4934e-01, -3.8050e-01,
-        -6.6589e-01, -3.0309e-01,  4.8251e-02,  1.9018e-01, -1.4266e+00,
-        -5.3952e-01, -6.3903e-01, -2.4409e-01, -9.9123e-01, -5.6829e-01,
-        -3.6255e-02, -6.8751e-02, -2.7236e-01, -6.8772e-02,  2.1357e-01,
-        -1.3793e-04,  9.2802e-03, -6.0042e-02, -1.9401e-01,  1.4867e+00,
-        -5.1789e-01, -1.3741e-01,  0.0000e+00,  1.1398e-01, -5.1325e-01,
-        -3.3397e-01,  1.1126e+00,  0.0000e+00,  7.5114e-01,  4.5934e-01,
-        -1.9083e-01,  9.1833e-01,  1.7224e-01,  5.9478e-01, -4.4776e-01,
-         4.4938e-02,  2.2342e-01,  2.1155e-01,  9.4658e-01, -5.5168e-01,
-         2.2345e-01, -1.6398e-01, -3.0004e-01, -8.0197e-02,  5.3234e-01,
-         0.0000e+00,  1.0655e+00,  0.0000e+00,  5.7665e-01, -3.4530e-01,
-         0.0000e+00, -1.6444e-01,  4.3380e-02, -1.3788e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.5047e-01,  4.4973e-01,  6.2542e-02, -5.7362e-02,  6.3128e-01,
-         3.3603e-01,  0.0000e+00, -6.7263e-01, -1.4934e-01, -3.8050e-01,
-        -6.6589e-01, -3.0309e-01,  4.8251e-02,  1.9018e-01, -1.4266e+00,
-        -5.3952e-01, -6.3903e-01, -2.4409e-01, -9.9123e-01, -5.6829e-01,
-        -3.6255e-02, -6.8751e-02, -2.7236e-01, -6.8772e-02,  2.1357e-01,
-        -1.3793e-04,  9.2802e-03, -6.0042e-02, -1.9401e-01,  1.4867e+00,
-        -5.1789e-01, -1.3741e-01,  0.0000e+00,  1.1398e-01, -5.1325e-01,
-        -3.3397e-01,  1.1126e+00,  0.0000e+00,  7.5114e-01,  4.5934e-01,
-        -1.9083e-01,  9.1833e-01,  1.7224e-01,  5.9478e-01, -4.4776e-01,
-         4.4938e-02,  2.2342e-01,  2.1155e-01,  9.4658e-01, -5.5168e-01,
-         2.2345e-01, -1.6398e-01, -3.0004e-01, -8.0197e-02,  5.3234e-01,
-         0.0000e+00,  1.0655e+00,  0.0000e+00,  5.7665e-01, -3.4530e-01,
-         0.0000e+00, -1.6444e-01,  4.3380e-02, -1.3788e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.5140e-01,  3.8380e-01,  1.0333e-01, -1.0257e-01,  6.3202e-01,
-         3.5682e-01,  2.5354e-04, -6.6595e-01, -1.5545e-01, -4.0760e-01,
-        -6.3484e-01, -2.8992e-01,  9.1408e-02,  1.5846e-01, -1.4380e+00,
-        -5.1524e-01, -6.0738e-01, -2.2113e-01, -9.8586e-01, -5.9243e-01,
-        -5.8980e-02, -5.1422e-02, -2.0942e-01, -8.4843e-02,  2.4475e-01,
-        -2.8406e-02,  6.3014e-02, -2.6862e-02, -1.8045e-01,  1.4787e+00,
-        -5.3108e-01, -1.6195e-01, -6.0784e-05,  8.3846e-02, -5.1943e-01,
-        -3.4078e-01,  1.1146e+00,  0.0000e+00,  7.5436e-01,  4.5692e-01,
-        -2.0929e-01,  9.1523e-01,  1.5030e-01,  6.0322e-01, -5.6487e-01,
-         3.0204e-03,  2.3032e-01,  1.7768e-01,  9.4351e-01, -5.6330e-01,
-         2.2243e-01, -5.4021e-02, -2.2558e-01, -1.2689e-01,  5.0926e-01,
-         3.2816e-06,  1.0652e+00,  4.6384e-04,  5.7319e-01, -2.9187e-01,
-         1.2390e-03, -1.3754e-01,  4.7827e-02, -9.4046e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5514,  0.3838,  0.1033, -0.1026,  0.6320,  0.3568,  0.0000, -0.6660,
-        -0.1554, -0.4076, -0.6348, -0.2899,  0.0914,  0.1585, -1.4380, -0.5152,
-        -0.6074, -0.2211, -0.9859, -0.5924, -0.0590, -0.0514, -0.2094, -0.0848,
-         0.2448, -0.0284,  0.0630, -0.0269, -0.1805,  1.4787, -0.5311, -0.1619,
-         0.0000,  0.0838, -0.5194, -0.3408,  1.1146,  0.0000,  0.7544,  0.4569,
-        -0.2093,  0.9152,  0.1503,  0.6032, -0.5649,  0.0030,  0.2303,  0.1777,
-         0.9435, -0.5633,  0.2224, -0.0540, -0.2256, -0.1269,  0.5093,  0.0000,
-         1.0652,  0.0000,  0.5732, -0.2919,  0.0000, -0.1375,  0.0478, -0.0940],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5514,  0.3838,  0.1033, -0.1026,  0.6320,  0.3568,  0.0000, -0.6660,
-        -0.1554, -0.4076, -0.6348, -0.2899,  0.0914,  0.1585, -1.4380, -0.5152,
-        -0.6074, -0.2211, -0.9859, -0.5924, -0.0590, -0.0514, -0.2094, -0.0848,
-         0.2448, -0.0284,  0.0630, -0.0269, -0.1805,  1.4787, -0.5311, -0.1619,
-         0.0000,  0.0838, -0.5194, -0.3408,  1.1146,  0.0000,  0.7544,  0.4569,
-        -0.2093,  0.9152,  0.1503,  0.6032, -0.5649,  0.0030,  0.2303,  0.1777,
-         0.9435, -0.5633,  0.2224, -0.0540, -0.2256, -0.1269,  0.5093,  0.0000,
-         1.0652,  0.0000,  0.5732, -0.2919,  0.0000, -0.1375,  0.0478, -0.0940],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.3905e-01,  3.4065e-01,  1.1721e-01, -1.1936e-01,  6.3459e-01,
-         3.9768e-01,  2.1616e-04, -6.5812e-01, -1.9629e-01, -4.0594e-01,
-        -6.1425e-01, -2.8467e-01,  9.6524e-02,  1.0301e-01, -1.4489e+00,
-        -4.9337e-01, -5.8914e-01, -2.0104e-01, -9.8298e-01, -6.1570e-01,
-        -1.0453e-01, -1.7080e-02, -1.6030e-01, -1.6458e-01,  2.7094e-01,
-        -9.7508e-02,  1.4582e-01,  1.0226e-02, -1.9795e-01,  1.4738e+00,
-        -5.2536e-01, -1.7267e-01, -5.1822e-05,  8.6860e-02, -5.3917e-01,
-        -3.6035e-01,  1.1141e+00,  0.0000e+00,  7.4927e-01,  4.2505e-01,
-        -2.0808e-01,  9.1462e-01,  1.2448e-01,  6.0344e-01, -6.7064e-01,
-         1.2047e-02,  2.3145e-01,  1.4974e-01,  9.4994e-01, -5.6551e-01,
-         2.3764e-01,  1.4328e-02, -1.6489e-01, -1.0922e-01,  4.7046e-01,
-         2.7978e-06,  1.0651e+00,  3.9545e-04,  5.7799e-01, -2.6864e-01,
-         1.0563e-03, -6.4826e-02,  3.6531e-02, -1.0345e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5391,  0.3406,  0.1172, -0.1194,  0.6346,  0.3977,  0.0000, -0.6581,
-        -0.1963, -0.4059, -0.6143, -0.2847,  0.0965,  0.1030, -1.4489, -0.4934,
-        -0.5891, -0.2010, -0.9830, -0.6157, -0.1045, -0.0171, -0.1603, -0.1646,
-         0.2709, -0.0975,  0.1458,  0.0102, -0.1980,  1.4738, -0.5254, -0.1727,
-         0.0000,  0.0869, -0.5392, -0.3603,  1.1141,  0.0000,  0.7493,  0.4250,
-        -0.2081,  0.9146,  0.1245,  0.6034, -0.6706,  0.0120,  0.2314,  0.1497,
-         0.9499, -0.5655,  0.2376,  0.0143, -0.1649, -0.1092,  0.4705,  0.0000,
-         1.0651,  0.0000,  0.5780, -0.2686,  0.0000, -0.0648,  0.0365, -0.0103],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5391,  0.3406,  0.1172, -0.1194,  0.6346,  0.3977,  0.0000, -0.6581,
-        -0.1963, -0.4059, -0.6143, -0.2847,  0.0965,  0.1030, -1.4489, -0.4934,
-        -0.5891, -0.2010, -0.9830, -0.6157, -0.1045, -0.0171, -0.1603, -0.1646,
-         0.2709, -0.0975,  0.1458,  0.0102, -0.1980,  1.4738, -0.5254, -0.1727,
-         0.0000,  0.0869, -0.5392, -0.3603,  1.1141,  0.0000,  0.7493,  0.4250,
-        -0.2081,  0.9146,  0.1245,  0.6034, -0.6706,  0.0120,  0.2314,  0.1497,
-         0.9499, -0.5655,  0.2376,  0.0143, -0.1649, -0.1092,  0.4705,  0.0000,
-         1.0651,  0.0000,  0.5780, -0.2686,  0.0000, -0.0648,  0.0365, -0.0103],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2527e-01,  3.3718e-01,  9.2360e-02, -1.1930e-01,  6.4721e-01,
-         4.2618e-01,  1.8421e-04, -6.4675e-01, -1.9705e-01, -3.8463e-01,
-        -5.8859e-01, -3.3711e-01,  7.3762e-02, -3.5781e-02, -1.4572e+00,
-        -4.6714e-01, -5.5992e-01, -1.7509e-01, -9.8324e-01, -6.3746e-01,
-        -1.3227e-01,  5.4825e-02, -1.2348e-01, -2.2187e-01,  2.8305e-01,
-        -1.5737e-01,  2.2050e-01,  3.8638e-02, -2.0444e-01,  1.4683e+00,
-        -5.1914e-01, -1.7522e-01, -4.4163e-05,  9.8886e-02, -5.4426e-01,
-        -3.9320e-01,  1.1082e+00,  0.0000e+00,  7.2997e-01,  3.8877e-01,
-        -1.9213e-01,  9.0984e-01,  7.4562e-02,  5.9224e-01, -7.6301e-01,
-         2.6643e-02,  2.6192e-01,  9.4339e-02,  9.6233e-01, -5.5871e-01,
-         2.3592e-01,  3.9413e-02, -1.2800e-01, -5.8657e-02,  4.3572e-01,
-         2.3843e-06,  1.0661e+00,  3.3701e-04,  5.8931e-01, -2.4155e-01,
-         9.0022e-04, -2.3956e-02,  1.1601e-02,  3.6838e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5253,  0.3372,  0.0924, -0.1193,  0.6472,  0.4262,  0.0000, -0.6468,
-        -0.1971, -0.3846, -0.5886, -0.3371,  0.0738, -0.0358, -1.4572, -0.4671,
-        -0.5599, -0.1751, -0.9832, -0.6375, -0.1323,  0.0548, -0.1235, -0.2219,
-         0.2830, -0.1574,  0.2205,  0.0386, -0.2044,  1.4683, -0.5191, -0.1752,
-         0.0000,  0.0989, -0.5443, -0.3932,  1.1082,  0.0000,  0.7300,  0.3888,
-        -0.1921,  0.9098,  0.0746,  0.5922, -0.7630,  0.0266,  0.2619,  0.0943,
-         0.9623, -0.5587,  0.2359,  0.0394, -0.1280, -0.0587,  0.4357,  0.0000,
-         1.0661,  0.0000,  0.5893, -0.2415,  0.0000, -0.0240,  0.0116,  0.0368],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5253,  0.3372,  0.0924, -0.1193,  0.6472,  0.4262,  0.0000, -0.6468,
-        -0.1971, -0.3846, -0.5886, -0.3371,  0.0738, -0.0358, -1.4572, -0.4671,
-        -0.5599, -0.1751, -0.9832, -0.6375, -0.1323,  0.0548, -0.1235, -0.2219,
-         0.2830, -0.1574,  0.2205,  0.0386, -0.2044,  1.4683, -0.5191, -0.1752,
-         0.0000,  0.0989, -0.5443, -0.3932,  1.1082,  0.0000,  0.7300,  0.3888,
-        -0.1921,  0.9098,  0.0746,  0.5922, -0.7630,  0.0266,  0.2619,  0.0943,
-         0.9623, -0.5587,  0.2359,  0.0394, -0.1280, -0.0587,  0.4357,  0.0000,
-         1.0661,  0.0000,  0.5893, -0.2415,  0.0000, -0.0240,  0.0116,  0.0368],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1107e-01,  3.6013e-01,  1.0424e-01, -8.3970e-02,  6.8822e-01,
-         4.3232e-01,  1.5693e-04, -6.2801e-01, -1.9123e-01, -3.6385e-01,
-        -5.6183e-01, -3.6256e-01,  1.0312e-01, -1.5182e-01, -1.4616e+00,
-        -4.3924e-01, -5.4203e-01, -1.5586e-01, -9.8493e-01, -6.3676e-01,
-        -1.4500e-01,  1.0214e-01, -1.1260e-01, -2.2932e-01,  2.4679e-01,
-        -1.9054e-01,  2.7658e-01,  6.7748e-02, -2.2385e-01,  1.4634e+00,
-        -5.1848e-01, -1.5862e-01, -3.7622e-05,  1.1791e-01, -5.3438e-01,
-        -3.9995e-01,  1.0952e+00,  0.0000e+00,  7.0000e-01,  3.6749e-01,
-        -2.1312e-01,  9.0526e-01,  6.7397e-02,  5.6471e-01, -8.4994e-01,
-         5.0062e-02,  3.0251e-01,  2.6352e-02,  9.6509e-01, -5.3954e-01,
-         2.1624e-01,  7.4626e-02, -1.3976e-01, -2.8365e-03,  4.3100e-01,
-         2.0312e-06,  1.0666e+00,  2.8709e-04,  6.1937e-01, -2.5057e-01,
-         7.6688e-04, -9.4253e-03,  9.5602e-03,  6.4166e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5111,  0.3601,  0.1042, -0.0840,  0.6882,  0.4323,  0.0000, -0.6280,
-        -0.1912, -0.3638, -0.5618, -0.3626,  0.1031, -0.1518, -1.4616, -0.4392,
-        -0.5420, -0.1559, -0.9849, -0.6368, -0.1450,  0.1021, -0.1126, -0.2293,
-         0.2468, -0.1905,  0.2766,  0.0677, -0.2239,  1.4634, -0.5185, -0.1586,
-         0.0000,  0.1179, -0.5344, -0.3999,  1.0952,  0.0000,  0.7000,  0.3675,
-        -0.2131,  0.9053,  0.0674,  0.5647, -0.8499,  0.0501,  0.3025,  0.0264,
-         0.9651, -0.5395,  0.2162,  0.0746, -0.1398, -0.0028,  0.4310,  0.0000,
-         1.0666,  0.0000,  0.6194, -0.2506,  0.0000, -0.0094,  0.0096,  0.0642],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5111,  0.3601,  0.1042, -0.0840,  0.6882,  0.4323,  0.0000, -0.6280,
-        -0.1912, -0.3638, -0.5618, -0.3626,  0.1031, -0.1518, -1.4616, -0.4392,
-        -0.5420, -0.1559, -0.9849, -0.6368, -0.1450,  0.1021, -0.1126, -0.2293,
-         0.2468, -0.1905,  0.2766,  0.0677, -0.2239,  1.4634, -0.5185, -0.1586,
-         0.0000,  0.1179, -0.5344, -0.3999,  1.0952,  0.0000,  0.7000,  0.3675,
-        -0.2131,  0.9053,  0.0674,  0.5647, -0.8499,  0.0501,  0.3025,  0.0264,
-         0.9651, -0.5395,  0.2162,  0.0746, -0.1398, -0.0028,  0.4310,  0.0000,
-         1.0666,  0.0000,  0.6194, -0.2506,  0.0000, -0.0094,  0.0096,  0.0642],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9059e-01,  3.3983e-01,  1.3530e-01, -6.8357e-02,  6.9423e-01,
-         4.3732e-01,  1.3364e-04, -6.0342e-01, -1.7385e-01, -3.5521e-01,
-        -5.3148e-01, -3.7287e-01,  1.3630e-01, -2.9825e-01, -1.4656e+00,
-        -4.1021e-01, -5.1826e-01, -1.3938e-01, -9.8922e-01, -6.3435e-01,
-        -1.4256e-01,  6.8182e-02, -1.4576e-01, -2.5938e-01,  1.5539e-01,
-        -2.0452e-01,  3.4185e-01,  4.7995e-02, -2.2470e-01,  1.4552e+00,
-        -5.1370e-01, -1.6977e-01, -3.2038e-05,  4.6933e-02, -5.1883e-01,
-        -3.9791e-01,  1.0887e+00,  0.0000e+00,  7.0005e-01,  3.8581e-01,
-        -2.5009e-01,  8.9414e-01,  6.2919e-02,  5.6291e-01, -9.2316e-01,
-        -1.3646e-02,  3.2716e-01, -7.6990e-04,  9.7335e-01, -5.2262e-01,
-         1.3059e-01,  5.4989e-02, -1.3847e-01, -1.1867e-02,  4.2832e-01,
-         1.7297e-06,  1.0667e+00,  2.4448e-04,  6.4024e-01, -2.2205e-01,
-         6.5306e-04, -1.2438e-02, -1.1463e-02,  1.3325e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.9059e-01,  3.3983e-01,  1.3530e-01, -6.8357e-02,  6.9423e-01,
-         4.3732e-01,  0.0000e+00, -6.0342e-01, -1.7385e-01, -3.5521e-01,
-        -5.3148e-01, -3.7287e-01,  1.3630e-01, -2.9825e-01, -1.4656e+00,
-        -4.1021e-01, -5.1826e-01, -1.3938e-01, -9.8922e-01, -6.3435e-01,
-        -1.4256e-01,  6.8182e-02, -1.4576e-01, -2.5938e-01,  1.5539e-01,
-        -2.0452e-01,  3.4185e-01,  4.7995e-02, -2.2470e-01,  1.4552e+00,
-        -5.1370e-01, -1.6977e-01,  0.0000e+00,  4.6933e-02, -5.1883e-01,
-        -3.9791e-01,  1.0887e+00,  0.0000e+00,  7.0005e-01,  3.8581e-01,
-        -2.5009e-01,  8.9414e-01,  6.2919e-02,  5.6291e-01, -9.2316e-01,
-        -1.3646e-02,  3.2716e-01, -7.6990e-04,  9.7335e-01, -5.2262e-01,
-         1.3059e-01,  5.4989e-02, -1.3847e-01, -1.1867e-02,  4.2832e-01,
-         0.0000e+00,  1.0667e+00,  0.0000e+00,  6.4024e-01, -2.2205e-01,
-         0.0000e+00, -1.2438e-02, -1.1463e-02,  1.3325e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.9059e-01,  3.3983e-01,  1.3530e-01, -6.8357e-02,  6.9423e-01,
-         4.3732e-01,  0.0000e+00, -6.0342e-01, -1.7385e-01, -3.5521e-01,
-        -5.3148e-01, -3.7287e-01,  1.3630e-01, -2.9825e-01, -1.4656e+00,
-        -4.1021e-01, -5.1826e-01, -1.3938e-01, -9.8922e-01, -6.3435e-01,
-        -1.4256e-01,  6.8182e-02, -1.4576e-01, -2.5938e-01,  1.5539e-01,
-        -2.0452e-01,  3.4185e-01,  4.7995e-02, -2.2470e-01,  1.4552e+00,
-        -5.1370e-01, -1.6977e-01,  0.0000e+00,  4.6933e-02, -5.1883e-01,
-        -3.9791e-01,  1.0887e+00,  0.0000e+00,  7.0005e-01,  3.8581e-01,
-        -2.5009e-01,  8.9414e-01,  6.2919e-02,  5.6291e-01, -9.2316e-01,
-        -1.3646e-02,  3.2716e-01, -7.6990e-04,  9.7335e-01, -5.2262e-01,
-         1.3059e-01,  5.4989e-02, -1.3847e-01, -1.1867e-02,  4.2832e-01,
-         0.0000e+00,  1.0667e+00,  0.0000e+00,  6.4024e-01, -2.2205e-01,
-         0.0000e+00, -1.2438e-02, -1.1463e-02,  1.3325e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.7081e-01,  2.3883e-01,  1.8326e-01, -7.8665e-02,  6.9662e-01,
-         4.2747e-01,  1.1377e-04, -5.8964e-01, -1.7300e-01, -3.7769e-01,
-        -5.2078e-01, -3.4602e-01,  1.6947e-01, -4.6670e-01, -1.4683e+00,
-        -3.9910e-01, -4.8490e-01, -1.3589e-01, -9.9218e-01, -6.2765e-01,
-        -1.5998e-01, -7.2776e-02, -2.1902e-01, -3.5279e-01,  4.6658e-03,
-        -2.0815e-01,  3.9863e-01, -4.5269e-02, -2.0966e-01,  1.4476e+00,
-        -4.9207e-01, -1.8722e-01, -2.7274e-05, -1.0340e-01, -5.1482e-01,
-        -3.7614e-01,  1.0923e+00,  0.0000e+00,  7.7747e-01,  4.9859e-01,
-        -3.3901e-01,  8.8331e-01,  5.0554e-02,  6.0221e-01, -9.9149e-01,
-        -1.2255e-01,  3.2519e-01,  2.3015e-02,  9.8171e-01, -5.0829e-01,
-         2.4330e-03, -1.2207e-02, -1.3796e-01, -1.2129e-01,  4.5706e-01,
-         1.4725e-06,  1.0683e+00,  2.0812e-04,  6.4925e-01, -2.5422e-01,
-         5.5595e-04, -1.3995e-02, -3.8960e-02,  1.9881e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4708,  0.2388,  0.1833, -0.0787,  0.6966,  0.4275,  0.0000, -0.5896,
-        -0.1730, -0.3777, -0.5208, -0.3460,  0.1695, -0.4667, -1.4683, -0.3991,
-        -0.4849, -0.1359, -0.9922, -0.6277, -0.1600, -0.0728,  0.0000, -0.3528,
-         0.0047, -0.2082,  0.3986, -0.0453, -0.2097,  1.4476, -0.4921, -0.1872,
-         0.0000, -0.1034, -0.5148, -0.3761,  1.0923,  0.0000,  0.7775,  0.4986,
-        -0.3390,  0.8833,  0.0506,  0.6022, -0.9915, -0.1226,  0.3252,  0.0230,
-         0.9817, -0.5083,  0.0024, -0.0122, -0.1380, -0.1213,  0.4571,  0.0000,
-         1.0683,  0.0000,  0.6492, -0.2542,  0.0000, -0.0140, -0.0390,  0.1988],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4708,  0.2388,  0.1833, -0.0787,  0.6966,  0.4275,  0.0000, -0.5896,
-        -0.1730, -0.3777, -0.5208, -0.3460,  0.1695, -0.4667, -1.4683, -0.3991,
-        -0.4849, -0.1359, -0.9922, -0.6277, -0.1600, -0.0728,  0.0000, -0.3528,
-         0.0047, -0.2082,  0.3986, -0.0453, -0.2097,  1.4476, -0.4921, -0.1872,
-         0.0000, -0.1034, -0.5148, -0.3761,  1.0923,  0.0000,  0.7775,  0.4986,
-        -0.3390,  0.8833,  0.0506,  0.6022, -0.9915, -0.1226,  0.3252,  0.0230,
-         0.9817, -0.5083,  0.0024, -0.0122, -0.1380, -0.1213,  0.4571,  0.0000,
-         1.0683,  0.0000,  0.6492, -0.2542,  0.0000, -0.0140, -0.0390,  0.1988],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4753e-01,  1.7046e-01,  2.1648e-01, -1.0346e-01,  7.2184e-01,
-         4.1942e-01,  9.6817e-05, -5.6937e-01, -1.5476e-01, -3.8554e-01,
-        -4.9998e-01, -2.9748e-01,  1.6451e-01, -5.6919e-01, -1.4679e+00,
-        -3.8453e-01, -4.6549e-01, -1.0699e-01, -9.9176e-01, -6.0293e-01,
-        -1.5204e-01, -1.5665e-01, -6.2346e-02, -4.6023e-01, -1.0676e-01,
-        -2.0659e-01,  4.3018e-01, -8.0905e-02, -1.9533e-01,  1.4393e+00,
-        -4.6111e-01, -1.9065e-01, -2.3211e-05, -2.0491e-01, -5.0935e-01,
-        -3.4284e-01,  1.0918e+00,  0.0000e+00,  8.1896e-01,  5.8240e-01,
-        -3.9636e-01,  8.7870e-01,  1.4896e-02,  6.1872e-01, -1.0375e+00,
-        -1.6153e-01,  3.5192e-01,  1.2937e-02,  9.8651e-01, -5.0111e-01,
-        -9.9456e-02, -2.7428e-02, -9.6373e-02, -1.7430e-01,  4.7531e-01,
-         1.2531e-06,  1.0710e+00,  1.7712e-04,  6.6802e-01, -2.6326e-01,
-         4.7313e-04, -1.5476e-02, -4.5682e-02,  2.3322e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4475,  0.1705,  0.2165, -0.1035,  0.7218,  0.4194,  0.0000, -0.5694,
-        -0.1548, -0.3855, -0.5000, -0.2975,  0.1645, -0.5692, -1.4679, -0.3845,
-        -0.4655, -0.1070, -0.9918, -0.6029, -0.1520, -0.1566,  0.0000, -0.4602,
-        -0.1068, -0.2066,  0.4302, -0.0809, -0.1953,  1.4393, -0.4611, -0.1907,
-         0.0000, -0.2049, -0.5094, -0.3428,  1.0918,  0.0000,  0.8190,  0.5824,
-        -0.3964,  0.8787,  0.0149,  0.6187, -1.0375, -0.1615,  0.3519,  0.0129,
-         0.9865, -0.5011, -0.0995, -0.0274, -0.0964, -0.1743,  0.4753,  0.0000,
-         1.0710,  0.0000,  0.6680, -0.2633,  0.0000, -0.0155, -0.0457,  0.2332],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4475,  0.1705,  0.2165, -0.1035,  0.7218,  0.4194,  0.0000, -0.5694,
-        -0.1548, -0.3855, -0.5000, -0.2975,  0.1645, -0.5692, -1.4679, -0.3845,
-        -0.4655, -0.1070, -0.9918, -0.6029, -0.1520, -0.1566,  0.0000, -0.4602,
-        -0.1068, -0.2066,  0.4302, -0.0809, -0.1953,  1.4393, -0.4611, -0.1907,
-         0.0000, -0.2049, -0.5094, -0.3428,  1.0918,  0.0000,  0.8190,  0.5824,
-        -0.3964,  0.8787,  0.0149,  0.6187, -1.0375, -0.1615,  0.3519,  0.0129,
-         0.9865, -0.5011, -0.0995, -0.0274, -0.0964, -0.1743,  0.4753,  0.0000,
-         1.0710,  0.0000,  0.6680, -0.2633,  0.0000, -0.0155, -0.0457,  0.2332],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3259e-01,  1.9163e-01,  1.8969e-01, -1.6569e-01,  7.1753e-01,
-         3.8239e-01,  8.2370e-05, -5.6074e-01, -1.0451e-01, -3.8877e-01,
-        -4.4906e-01, -2.9815e-01,  9.2326e-02, -6.3893e-01, -1.4654e+00,
-        -3.5690e-01, -4.9529e-01, -2.1740e-02, -9.8927e-01, -5.9025e-01,
-        -1.5437e-01, -1.1209e-01, -5.3043e-02, -5.5274e-01, -1.1057e-01,
-        -2.0467e-01,  3.9509e-01, -7.6677e-02, -1.3375e-01,  1.4351e+00,
-        -4.3399e-01, -1.8138e-01, -1.9747e-05, -2.2238e-01, -5.0814e-01,
-        -3.1810e-01,  1.0938e+00,  0.0000e+00,  8.2395e-01,  6.1429e-01,
-        -3.5502e-01,  8.7460e-01, -9.1705e-02,  6.1877e-01, -1.0736e+00,
-        -1.6666e-01,  3.7361e-01, -4.3820e-02,  9.8897e-01, -5.2061e-01,
-        -1.2917e-01, -2.1119e-02, -3.5087e-02, -1.1197e-01,  4.6225e-01,
-         1.0661e-06,  1.0726e+00,  1.5069e-04,  6.7779e-01, -2.4294e-01,
-         4.0253e-04, -4.4036e-02, -3.6369e-02,  1.4107e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4326,  0.1916,  0.1897, -0.1657,  0.7175,  0.3824,  0.0000, -0.5607,
-        -0.1045, -0.3888, -0.4491, -0.2981,  0.0923, -0.6389, -1.4654, -0.3569,
-        -0.4953, -0.0217, -0.9893, -0.5903, -0.1544, -0.1121,  0.0000, -0.5527,
-        -0.1106, -0.2047,  0.3951, -0.0767, -0.1337,  1.4351, -0.4340, -0.1814,
-         0.0000, -0.2224, -0.5081, -0.3181,  1.0938,  0.0000,  0.8239,  0.6143,
-        -0.3550,  0.8746, -0.0917,  0.6188, -1.0736, -0.1667,  0.3736, -0.0438,
-         0.9890, -0.5206, -0.1292, -0.0211, -0.0351, -0.1120,  0.4622,  0.0000,
-         1.0726,  0.0000,  0.6778, -0.2429,  0.0000, -0.0440, -0.0364,  0.1411],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4326,  0.1916,  0.1897, -0.1657,  0.7175,  0.3824,  0.0000, -0.5607,
-        -0.1045, -0.3888, -0.4491, -0.2981,  0.0923, -0.6389, -1.4654, -0.3569,
-        -0.4953, -0.0217, -0.9893, -0.5903, -0.1544, -0.1121,  0.0000, -0.5527,
-        -0.1106, -0.2047,  0.3951, -0.0767, -0.1337,  1.4351, -0.4340, -0.1814,
-         0.0000, -0.2224, -0.5081, -0.3181,  1.0938,  0.0000,  0.8239,  0.6143,
-        -0.3550,  0.8746, -0.0917,  0.6188, -1.0736, -0.1667,  0.3736, -0.0438,
-         0.9890, -0.5206, -0.1292, -0.0211, -0.0351, -0.1120,  0.4622,  0.0000,
-         1.0726,  0.0000,  0.6778, -0.2429,  0.0000, -0.0440, -0.0364,  0.1411],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4532e-01,  2.3895e-01,  1.7110e-01, -2.3145e-01,  7.2178e-01,
-         3.8681e-01,  7.0061e-05, -5.6339e-01, -3.9849e-02, -3.8142e-01,
-        -3.6340e-01, -3.1950e-01,  6.3203e-02, -7.0342e-01, -1.4591e+00,
-        -3.1612e-01, -5.6265e-01,  1.1765e-01, -9.8531e-01, -5.8519e-01,
-        -1.8281e-01,  2.5840e-02, -4.5116e-02, -6.3061e-01, -5.1264e-02,
-        -2.2680e-01,  3.9858e-01, -4.5774e-02, -7.3069e-02,  1.4343e+00,
-        -4.2128e-01, -1.3238e-01, -1.6796e-05, -1.8789e-01, -5.2675e-01,
-        -3.0426e-01,  1.0835e+00,  0.0000e+00,  8.0481e-01,  6.0208e-01,
-        -2.5056e-01,  8.7287e-01, -1.5560e-01,  5.9806e-01, -1.0983e+00,
-        -7.5374e-02,  4.0696e-01, -1.3451e-01,  9.8887e-01, -5.4693e-01,
-        -8.9774e-02, -3.8470e-02,  1.2505e-01,  7.1373e-02,  4.1805e-01,
-         9.0680e-07,  1.0734e+00,  1.2817e-04,  6.8414e-01, -1.8182e-01,
-         3.4237e-04, -2.1057e-02,  1.8580e-03,  8.6332e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4453,  0.2389,  0.1711, -0.2314,  0.7218,  0.3868,  0.0000, -0.5634,
-        -0.0398, -0.3814, -0.3634, -0.3195,  0.0632, -0.7034, -1.4591, -0.3161,
-        -0.5627,  0.1176, -0.9853, -0.5852, -0.1828,  0.0258,  0.0000, -0.6306,
-        -0.0513, -0.2268,  0.3986, -0.0458, -0.0731,  1.4343, -0.4213, -0.1324,
-         0.0000, -0.1879, -0.5268, -0.3043,  1.0835,  0.0000,  0.8048,  0.6021,
-        -0.2506,  0.8729, -0.1556,  0.5981, -1.0983, -0.0754,  0.4070, -0.1345,
-         0.9889, -0.5469, -0.0898, -0.0385,  0.1250,  0.0714,  0.4181,  0.0000,
-         1.0734,  0.0000,  0.6841, -0.1818,  0.0000, -0.0211,  0.0019,  0.0863],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4453,  0.2389,  0.1711, -0.2314,  0.7218,  0.3868,  0.0000, -0.5634,
-        -0.0398, -0.3814, -0.3634, -0.3195,  0.0632, -0.7034, -1.4591, -0.3161,
-        -0.5627,  0.1176, -0.9853, -0.5852, -0.1828,  0.0258,  0.0000, -0.6306,
-        -0.0513, -0.2268,  0.3986, -0.0458, -0.0731,  1.4343, -0.4213, -0.1324,
-         0.0000, -0.1879, -0.5268, -0.3043,  1.0835,  0.0000,  0.8048,  0.6021,
-        -0.2506,  0.8729, -0.1556,  0.5981, -1.0983, -0.0754,  0.4070, -0.1345,
-         0.9889, -0.5469, -0.0898, -0.0385,  0.1250,  0.0714,  0.4181,  0.0000,
-         1.0734,  0.0000,  0.6841, -0.1818,  0.0000, -0.0211,  0.0019,  0.0863],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6785e-01,  2.4444e-01,  1.9509e-01, -2.8371e-01,  7.2764e-01,
-         4.1348e-01,  5.9576e-05, -5.7969e-01, -1.1963e-02, -3.9287e-01,
-        -3.1136e-01, -3.0052e-01,  6.9259e-02, -7.7039e-01, -1.4514e+00,
-        -3.0913e-01, -6.2357e-01,  2.5411e-01, -9.8252e-01, -5.8364e-01,
-        -2.1796e-01,  8.1617e-02, -3.8365e-02, -7.0534e-01, -5.0387e-02,
-        -2.4585e-01,  4.5036e-01, -4.2769e-02, -5.1513e-02,  1.4313e+00,
-        -4.0993e-01, -1.0435e-01, -1.4283e-05, -2.1790e-01, -5.4341e-01,
-        -2.7610e-01,  1.0780e+00,  0.0000e+00,  7.8962e-01,  5.9004e-01,
-        -1.9473e-01,  8.7561e-01, -2.5912e-01,  5.7900e-01, -1.1183e+00,
-        -1.9432e-02,  4.2397e-01, -1.7701e-01,  9.8784e-01, -5.7921e-01,
-        -5.3255e-02, -9.5059e-02,  2.2315e-01,  3.0519e-01,  3.7591e-01,
-         7.7110e-07,  1.0727e+00,  1.0899e-04,  6.8030e-01, -1.5751e-01,
-         2.9114e-04,  5.7634e-02,  1.5826e-02,  9.1401e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4679,  0.2444,  0.1951, -0.2837,  0.7276,  0.4135,  0.0000, -0.5797,
-        -0.0120, -0.3929, -0.3114, -0.3005,  0.0693, -0.7704, -1.4514, -0.3091,
-        -0.6236,  0.2541, -0.9825, -0.5836, -0.2180,  0.0816,  0.0000, -0.7053,
-        -0.0504, -0.2458,  0.4504, -0.0428, -0.0515,  1.4313, -0.4099, -0.1044,
-         0.0000, -0.2179, -0.5434, -0.2761,  1.0780,  0.0000,  0.7896,  0.5900,
-        -0.1947,  0.8756, -0.2591,  0.5790, -1.1183, -0.0194,  0.4240, -0.1770,
-         0.9878, -0.5792, -0.0533, -0.0951,  0.2231,  0.3052,  0.3759,  0.0000,
-         1.0727,  0.0000,  0.6803, -0.1575,  0.0000,  0.0576,  0.0158,  0.0914],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4679,  0.2444,  0.1951, -0.2837,  0.7276,  0.4135,  0.0000, -0.5797,
-        -0.0120, -0.3929, -0.3114, -0.3005,  0.0693, -0.7704, -1.4514, -0.3091,
-        -0.6236,  0.2541, -0.9825, -0.5836, -0.2180,  0.0816,  0.0000, -0.7053,
-        -0.0504, -0.2458,  0.4504, -0.0428, -0.0515,  1.4313, -0.4099, -0.1044,
-         0.0000, -0.2179, -0.5434, -0.2761,  1.0780,  0.0000,  0.7896,  0.5900,
-        -0.1947,  0.8756, -0.2591,  0.5790, -1.1183, -0.0194,  0.4240, -0.1770,
-         0.9878, -0.5792, -0.0533, -0.0951,  0.2231,  0.3052,  0.3759,  0.0000,
-         1.0727,  0.0000,  0.6803, -0.1575,  0.0000,  0.0576,  0.0158,  0.0914],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5840e-01,  2.9057e-01,  1.7091e-01, -3.8891e-01,  6.6411e-01,
-         3.9543e-01,  5.0650e-05, -5.9420e-01,  2.8812e-02, -4.0198e-01,
-        -2.8684e-01, -2.7969e-01,  5.0558e-02, -8.2597e-01, -1.4440e+00,
-        -2.5839e-01, -6.3633e-01,  3.8163e-01, -9.8341e-01, -5.8471e-01,
-        -2.2842e-01,  1.0183e-01, -3.2616e-02, -7.3844e-01, -1.1496e-01,
-        -2.5931e-01,  4.7596e-01, -9.4712e-02, -1.3716e-02,  1.4321e+00,
-        -3.8447e-01, -9.2772e-02, -1.2143e-05, -2.8732e-01, -5.3446e-01,
-        -2.4303e-01,  1.0679e+00,  0.0000e+00,  7.6617e-01,  5.9346e-01,
-        -2.1270e-01,  8.5549e-01, -3.9992e-01,  5.5886e-01, -1.1256e+00,
-         3.6758e-04,  4.3037e-01, -2.1781e-01,  9.9584e-01, -5.9193e-01,
-        -3.9790e-02, -1.7109e-01,  3.3329e-01,  5.6285e-01,  2.9734e-01,
-         6.5556e-07,  1.0739e+00,  9.2659e-05,  6.5510e-01, -5.8074e-03,
-         2.4751e-04,  6.2377e-02,  9.9500e-03,  2.4163e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Sparsity at the end of epoch 0: 10.68%
-After Step tensor([ 4.5840e-01,  2.9057e-01,  1.7091e-01, -3.8891e-01,  6.6411e-01,
-         3.9543e-01,  0.0000e+00, -5.9420e-01,  2.8812e-02, -4.0198e-01,
-        -2.8684e-01, -2.7969e-01,  5.0558e-02, -8.2597e-01, -1.4440e+00,
-        -2.5839e-01, -6.3633e-01,  3.8163e-01, -9.8341e-01, -5.8471e-01,
-        -2.2842e-01,  1.0183e-01,  0.0000e+00, -7.3844e-01, -1.1496e-01,
-        -2.5931e-01,  4.7596e-01, -9.4712e-02, -1.3716e-02,  1.4321e+00,
-        -3.8447e-01, -9.2772e-02,  0.0000e+00, -2.8732e-01, -5.3446e-01,
-        -2.4303e-01,  1.0679e+00,  0.0000e+00,  7.6617e-01,  5.9346e-01,
-        -2.1270e-01,  8.5549e-01, -3.9992e-01,  5.5886e-01, -1.1256e+00,
-         3.6758e-04,  4.3037e-01, -2.1781e-01,  9.9584e-01, -5.9193e-01,
-        -3.9790e-02, -1.7109e-01,  3.3329e-01,  5.6285e-01,  2.9734e-01,
-         0.0000e+00,  1.0739e+00,  0.0000e+00,  6.5510e-01, -5.8074e-03,
-         0.0000e+00,  6.2377e-02,  9.9500e-03,  2.4163e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.5840e-01,  2.9057e-01,  1.7091e-01, -3.8891e-01,  6.6411e-01,
-         3.9543e-01,  0.0000e+00, -5.9420e-01,  2.8812e-02, -4.0198e-01,
-        -2.8684e-01, -2.7969e-01,  5.0558e-02, -8.2597e-01, -1.4440e+00,
-        -2.5839e-01, -6.3633e-01,  3.8163e-01, -9.8341e-01, -5.8471e-01,
-        -2.2842e-01,  1.0183e-01,  0.0000e+00, -7.3844e-01, -1.1496e-01,
-        -2.5931e-01,  4.7596e-01, -9.4712e-02, -1.3716e-02,  1.4321e+00,
-        -3.8447e-01, -9.2772e-02,  0.0000e+00, -2.8732e-01, -5.3446e-01,
-        -2.4303e-01,  1.0679e+00,  0.0000e+00,  7.6617e-01,  5.9346e-01,
-        -2.1270e-01,  8.5549e-01, -3.9992e-01,  5.5886e-01, -1.1256e+00,
-         3.6758e-04,  4.3037e-01, -2.1781e-01,  9.9584e-01, -5.9193e-01,
-        -3.9790e-02, -1.7109e-01,  3.3329e-01,  5.6285e-01,  2.9734e-01,
-         0.0000e+00,  1.0739e+00,  0.0000e+00,  6.5510e-01, -5.8074e-03,
-         0.0000e+00,  6.2377e-02,  9.9500e-03,  2.4163e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3286e-01,  2.8000e-01,  1.6923e-01, -4.8403e-01,  5.7640e-01,
-         3.2788e-01,  4.3053e-05, -5.9901e-01,  3.5814e-02, -4.1091e-01,
-        -2.9351e-01, -2.3261e-01,  4.8610e-02, -8.9256e-01, -1.4412e+00,
-        -2.1577e-01, -6.2868e-01,  4.5008e-01, -9.8364e-01, -5.7130e-01,
-        -2.5340e-01,  3.5375e-02, -2.7724e-02, -7.5156e-01, -2.0321e-01,
-        -2.5749e-01,  4.3597e-01, -1.6754e-01,  1.3792e-02,  1.4283e+00,
-        -3.4021e-01, -8.7801e-02, -1.0321e-05, -3.5330e-01, -5.1770e-01,
-        -1.7677e-01,  1.0708e+00,  0.0000e+00,  7.8555e-01,  6.6360e-01,
-        -2.6467e-01,  8.3659e-01, -5.0640e-01,  5.7120e-01, -1.1257e+00,
-        -4.0641e-02,  4.0581e-01, -2.2135e-01,  9.9107e-01, -5.9461e-01,
-        -6.3748e-02, -2.2799e-01,  3.6500e-01,  7.3063e-01,  3.0955e-01,
-         5.5724e-07,  1.0745e+00,  7.8762e-05,  5.7942e-01,  1.0340e-01,
-         2.1039e-04, -4.0106e-02,  6.1098e-02, -5.6981e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4329,  0.2800,  0.1692, -0.4840,  0.5764,  0.3279,  0.0000, -0.5990,
-         0.0358, -0.4109, -0.2935, -0.2326,  0.0486, -0.8926, -1.4412, -0.2158,
-        -0.6287,  0.4501, -0.9836, -0.5713, -0.2534,  0.0354,  0.0000, -0.7516,
-        -0.2032, -0.2575,  0.4360, -0.1675,  0.0138,  1.4283, -0.3402, -0.0878,
-         0.0000, -0.3533, -0.5177, -0.1768,  1.0708,  0.0000,  0.7856,  0.6636,
-        -0.2647,  0.8366, -0.5064,  0.5712, -1.1257, -0.0406,  0.4058, -0.2214,
-         0.9911,  0.0000, -0.0637, -0.2280,  0.3650,  0.7306,  0.3096,  0.0000,
-         1.0745,  0.0000,  0.5794,  0.1034,  0.0000, -0.0401,  0.0611, -0.0570],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4329,  0.2800,  0.1692, -0.4840,  0.5764,  0.3279,  0.0000, -0.5990,
-         0.0358, -0.4109, -0.2935, -0.2326,  0.0486, -0.8926, -1.4412, -0.2158,
-        -0.6287,  0.4501, -0.9836, -0.5713, -0.2534,  0.0354,  0.0000, -0.7516,
-        -0.2032, -0.2575,  0.4360, -0.1675,  0.0138,  1.4283, -0.3402, -0.0878,
-         0.0000, -0.3533, -0.5177, -0.1768,  1.0708,  0.0000,  0.7856,  0.6636,
-        -0.2647,  0.8366, -0.5064,  0.5712, -1.1257, -0.0406,  0.4058, -0.2214,
-         0.9911,  0.0000, -0.0637, -0.2280,  0.3650,  0.7306,  0.3096,  0.0000,
-         1.0745,  0.0000,  0.5794,  0.1034,  0.0000, -0.0401,  0.0611, -0.0570],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6872e-01,  2.3193e-01,  1.3710e-01, -5.2275e-01,  5.0012e-01,
-         2.4770e-01,  3.6594e-05, -5.9221e-01,  7.0436e-02, -4.1994e-01,
-        -2.3946e-01, -1.5612e-01,  3.1458e-02, -9.5569e-01, -1.4419e+00,
-        -1.9207e-01, -6.3548e-01,  4.9663e-01, -9.8374e-01, -5.5298e-01,
-        -2.7169e-01, -7.0868e-03, -2.3565e-02, -8.0094e-01, -2.5023e-01,
-        -2.4441e-01,  3.4917e-01, -1.9131e-01, -5.3381e-03,  1.4293e+00,
-        -3.2000e-01, -6.2051e-02, -8.7730e-06, -4.4616e-01, -4.9482e-01,
-        -9.1986e-02,  1.0754e+00,  0.0000e+00,  8.0484e-01,  7.2812e-01,
-        -2.9912e-01,  8.3145e-01, -5.6387e-01,  5.9547e-01, -1.1286e+00,
-        -7.2772e-02,  3.7698e-01, -2.1293e-01,  9.9072e-01, -2.2758e-03,
-        -8.7015e-02, -2.7108e-01,  3.9484e-01,  8.3613e-01,  3.5962e-01,
-         4.7364e-07,  1.0747e+00,  6.6946e-05,  5.0648e-01,  1.8811e-01,
-         1.7883e-04, -1.3825e-01,  1.3570e-01, -4.0340e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3687,  0.2319,  0.1371, -0.5227,  0.5001,  0.2477,  0.0000, -0.5922,
-         0.0704, -0.4199, -0.2395, -0.1561,  0.0315, -0.9557, -1.4419, -0.1921,
-        -0.6355,  0.4966, -0.9837, -0.5530, -0.2717, -0.0071,  0.0000, -0.8009,
-        -0.2502, -0.2444,  0.3492, -0.1913, -0.0053,  1.4293, -0.3200, -0.0621,
-         0.0000, -0.4462, -0.4948, -0.0920,  1.0754,  0.0000,  0.8048,  0.7281,
-        -0.2991,  0.8315, -0.5639,  0.5955, -1.1286, -0.0728,  0.3770, -0.2129,
-         0.9907,  0.0000, -0.0870, -0.2711,  0.3948,  0.8361,  0.3596,  0.0000,
-         1.0747,  0.0000,  0.5065,  0.1881,  0.0000, -0.1382,  0.1357, -0.0403],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3687,  0.2319,  0.1371, -0.5227,  0.5001,  0.2477,  0.0000, -0.5922,
-         0.0704, -0.4199, -0.2395, -0.1561,  0.0315, -0.9557, -1.4419, -0.1921,
-        -0.6355,  0.4966, -0.9837, -0.5530, -0.2717, -0.0071,  0.0000, -0.8009,
-        -0.2502, -0.2444,  0.3492, -0.1913, -0.0053,  1.4293, -0.3200, -0.0621,
-         0.0000, -0.4462, -0.4948, -0.0920,  1.0754,  0.0000,  0.8048,  0.7281,
-        -0.2991,  0.8315, -0.5639,  0.5955, -1.1286, -0.0728,  0.3770, -0.2129,
-         0.9907,  0.0000, -0.0870, -0.2711,  0.3948,  0.8361,  0.3596,  0.0000,
-         1.0747,  0.0000,  0.5065,  0.1881,  0.0000, -0.1382,  0.1357, -0.0403],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0481e-01,  1.6211e-01,  1.7190e-01, -4.8729e-01,  4.4929e-01,
-         2.1156e-01,  3.1103e-05, -5.7031e-01,  1.0767e-01, -4.2435e-01,
-        -1.3468e-01, -1.0302e-01,  6.8509e-02, -9.9568e-01, -1.4393e+00,
-        -1.6964e-01, -6.6665e-01,  5.1785e-01, -9.8483e-01, -5.3094e-01,
-        -2.4917e-01, -6.3964e-02, -2.0029e-02, -8.5677e-01, -2.5904e-01,
-        -1.9273e-01,  3.1915e-01, -1.4614e-01, -3.0751e-02,  1.4261e+00,
-        -2.9671e-01,  4.8122e-02, -7.4566e-06, -4.9071e-01, -4.5982e-01,
-        -9.4314e-03,  1.0702e+00,  0.0000e+00,  8.1290e-01,  7.7479e-01,
-        -2.9592e-01,  8.2922e-01, -5.8070e-01,  6.3264e-01, -1.1387e+00,
-        -1.5261e-01,  3.2977e-01, -1.7383e-01,  9.7892e-01, -1.9343e-03,
-        -1.0327e-01, -3.3896e-01,  4.1536e-01,  9.2507e-01,  4.2543e-01,
-         4.0257e-07,  1.0728e+00,  5.6900e-05,  4.2439e-01,  3.4125e-01,
-         1.5199e-04, -1.4543e-01,  1.9537e-01,  7.0624e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3048,  0.1621,  0.1719, -0.4873,  0.4493,  0.2116,  0.0000, -0.5703,
-         0.1077, -0.4243, -0.1347, -0.1030,  0.0685, -0.9957, -1.4393, -0.1696,
-        -0.6667,  0.5178, -0.9848, -0.5309, -0.2492, -0.0640,  0.0000, -0.8568,
-        -0.2590, -0.1927,  0.3192, -0.1461, -0.0308,  1.4261, -0.2967,  0.0481,
-         0.0000, -0.4907, -0.4598, -0.0094,  1.0702,  0.0000,  0.8129,  0.7748,
-        -0.2959,  0.8292, -0.5807,  0.6326, -1.1387, -0.1526,  0.3298, -0.1738,
-         0.9789,  0.0000, -0.1033, -0.3390,  0.4154,  0.9251,  0.4254,  0.0000,
-         1.0728,  0.0000,  0.4244,  0.3413,  0.0000, -0.1454,  0.1954,  0.0706],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3048,  0.1621,  0.1719, -0.4873,  0.4493,  0.2116,  0.0000, -0.5703,
-         0.1077, -0.4243, -0.1347, -0.1030,  0.0685, -0.9957, -1.4393, -0.1696,
-        -0.6667,  0.5178, -0.9848, -0.5309, -0.2492, -0.0640,  0.0000, -0.8568,
-        -0.2590, -0.1927,  0.3192, -0.1461, -0.0308,  1.4261, -0.2967,  0.0481,
-         0.0000, -0.4907, -0.4598, -0.0094,  1.0702,  0.0000,  0.8129,  0.7748,
-        -0.2959,  0.8292, -0.5807,  0.6326, -1.1387, -0.1526,  0.3298, -0.1738,
-         0.9789,  0.0000, -0.1033, -0.3390,  0.4154,  0.9251,  0.4254,  0.0000,
-         1.0728,  0.0000,  0.4244,  0.3413,  0.0000, -0.1454,  0.1954,  0.0706],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6365e-01,  1.3955e-01,  1.8071e-01, -3.9785e-01,  3.7089e-01,
-         2.1298e-01,  2.6435e-05, -5.4954e-01,  1.8309e-01, -3.7393e-01,
-         1.3167e-02, -1.6227e-01,  6.0437e-02, -1.0358e+00, -1.4446e+00,
-        -1.1734e-01, -6.8033e-01,  5.3122e-01, -9.8540e-01, -5.4216e-01,
-        -2.2457e-01,  4.6959e-02, -1.7023e-02, -8.9690e-01, -1.1204e-01,
-        -1.8204e-01,  3.1812e-01, -2.0414e-02, -3.7023e-02,  1.4214e+00,
-        -2.4547e-01,  1.4408e-01, -6.3376e-06, -4.2121e-01, -4.6593e-01,
-         2.2624e-02,  1.0569e+00,  0.0000e+00,  7.9468e-01,  8.0103e-01,
-        -1.3462e-01,  8.3528e-01, -6.0019e-01,  6.6424e-01, -1.1410e+00,
-        -2.1555e-01,  3.0108e-01, -1.2877e-01,  9.5973e-01, -1.6441e-03,
-        -7.8448e-02, -3.3698e-01,  4.3697e-01,  9.9099e-01,  4.3861e-01,
-         3.4215e-07,  1.0697e+00,  4.8361e-05,  3.3012e-01,  5.1944e-01,
-         1.2918e-04, -1.2110e-01,  2.3327e-01,  1.6198e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2637,  0.1396,  0.1807, -0.3979,  0.3709,  0.2130,  0.0000, -0.5495,
-         0.1831, -0.3739,  0.0132, -0.1623,  0.0604, -1.0358, -1.4446, -0.1173,
-        -0.6803,  0.5312, -0.9854, -0.5422, -0.2246,  0.0470,  0.0000, -0.8969,
-        -0.1120, -0.1820,  0.3181, -0.0204, -0.0370,  1.4214, -0.2455,  0.1441,
-         0.0000, -0.4212, -0.4659,  0.0226,  1.0569,  0.0000,  0.7947,  0.8010,
-        -0.1346,  0.8353, -0.6002,  0.6642, -1.1410, -0.2155,  0.3011, -0.1288,
-         0.9597,  0.0000, -0.0784, -0.3370,  0.4370,  0.9910,  0.4386,  0.0000,
-         1.0697,  0.0000,  0.3301,  0.5194,  0.0000, -0.1211,  0.2333,  0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2637,  0.1396,  0.1807, -0.3979,  0.3709,  0.2130,  0.0000, -0.5495,
-         0.1831, -0.3739,  0.0132, -0.1623,  0.0604, -1.0358, -1.4446, -0.1173,
-        -0.6803,  0.5312, -0.9854, -0.5422, -0.2246,  0.0470,  0.0000, -0.8969,
-        -0.1120, -0.1820,  0.3181, -0.0204, -0.0370,  1.4214, -0.2455,  0.1441,
-         0.0000, -0.4212, -0.4659,  0.0226,  1.0569,  0.0000,  0.7947,  0.8010,
-        -0.1346,  0.8353, -0.6002,  0.6642, -1.1410, -0.2155,  0.3011, -0.1288,
-         0.9597,  0.0000, -0.0784, -0.3370,  0.4370,  0.9910,  0.4386,  0.0000,
-         1.0697,  0.0000,  0.3301,  0.5194,  0.0000, -0.1211,  0.2333,  0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.0751e-01,  1.5776e-01,  1.4082e-01, -3.2038e-01,  3.0326e-01,
-         1.8234e-01,  2.2468e-05, -5.2783e-01,  2.5449e-01, -3.0431e-01,
-         1.5517e-01, -2.5459e-01,  1.3105e-02, -1.0532e+00, -1.4452e+00,
-        -4.1518e-02, -6.7700e-01,  5.2327e-01, -9.8569e-01, -5.5120e-01,
-        -1.8370e-01,  1.5838e-01, -1.4468e-02, -9.2182e-01,  6.7250e-02,
-        -1.3455e-01,  2.8411e-01,  6.4854e-02, -6.5996e-03,  1.4225e+00,
-        -2.0421e-01,  1.8590e-01, -5.3864e-06, -3.2187e-01, -4.4783e-01,
-         2.3858e-02,  1.0340e+00,  0.0000e+00,  7.3643e-01,  8.0524e-01,
-         4.3743e-02,  8.3487e-01, -6.4782e-01,  6.6160e-01, -1.1440e+00,
-        -2.9977e-01,  3.1974e-01, -6.1461e-02,  9.5476e-01, -1.3973e-03,
-        -4.2827e-02, -3.0855e-01,  4.7512e-01,  1.0292e+00,  4.3427e-01,
-         2.9080e-07,  1.0655e+00,  4.1103e-05,  3.1277e-01,  6.6393e-01,
-         1.0980e-04, -1.4363e-01,  2.3713e-01,  1.6046e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2075,  0.1578,  0.1408, -0.3204,  0.3033,  0.1823,  0.0000, -0.5278,
-         0.2545, -0.3043,  0.1552, -0.2546,  0.0131, -1.0532, -1.4452, -0.0415,
-        -0.6770,  0.5233, -0.9857, -0.5512, -0.1837,  0.1584,  0.0000, -0.9218,
-         0.0673, -0.1345,  0.2841,  0.0649, -0.0066,  1.4225, -0.2042,  0.1859,
-         0.0000, -0.3219, -0.4478,  0.0239,  1.0340,  0.0000,  0.7364,  0.8052,
-         0.0437,  0.8349, -0.6478,  0.6616, -1.1440, -0.2998,  0.3197, -0.0615,
-         0.9548,  0.0000, -0.0428, -0.3085,  0.4751,  1.0292,  0.4343,  0.0000,
-         1.0655,  0.0000,  0.3128,  0.6639,  0.0000, -0.1436,  0.2371,  0.1605],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2075,  0.1578,  0.1408, -0.3204,  0.3033,  0.1823,  0.0000, -0.5278,
-         0.2545, -0.3043,  0.1552, -0.2546,  0.0131, -1.0532, -1.4452, -0.0415,
-        -0.6770,  0.5233, -0.9857, -0.5512, -0.1837,  0.1584,  0.0000, -0.9218,
-         0.0673, -0.1345,  0.2841,  0.0649, -0.0066,  1.4225, -0.2042,  0.1859,
-         0.0000, -0.3219, -0.4478,  0.0239,  1.0340,  0.0000,  0.7364,  0.8052,
-         0.0437,  0.8349, -0.6478,  0.6616, -1.1440, -0.2998,  0.3197, -0.0615,
-         0.9548,  0.0000, -0.0428, -0.3085,  0.4751,  1.0292,  0.4343,  0.0000,
-         1.0655,  0.0000,  0.3128,  0.6639,  0.0000, -0.1436,  0.2371,  0.1605],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4978e-01,  1.8740e-01,  7.5682e-03, -3.0770e-01,  3.4717e-01,
-         1.9704e-01,  1.9095e-05, -5.2046e-01,  3.0162e-01, -2.3308e-01,
-         1.7117e-01, -3.4409e-01, -1.5220e-01, -1.0243e+00, -1.4433e+00,
-        -1.1345e-01, -6.7411e-01,  4.9056e-01, -9.8131e-01, -5.8536e-01,
-        -2.0477e-01,  2.4996e-01, -1.2297e-02, -9.4856e-01,  1.8053e-01,
-        -3.3959e-02,  2.5085e-01,  1.2361e-01,  3.2509e-02,  1.4306e+00,
-        -1.1625e-01,  1.6701e-01, -4.5779e-06, -2.6260e-01, -4.5310e-01,
-        -4.5735e-02,  1.0152e+00,  0.0000e+00,  6.5909e-01,  7.7738e-01,
-         1.3558e-01,  8.3990e-01, -6.9321e-01,  6.5666e-01, -1.1399e+00,
-        -3.3061e-01,  3.5541e-01,  5.2146e-03,  9.6493e-01, -1.1876e-03,
-        -5.0998e-02, -3.0815e-01,  4.9805e-01,  1.0572e+00,  3.6031e-01,
-         2.4715e-07,  1.0605e+00,  3.4933e-05,  3.0198e-01,  7.9845e-01,
-         9.3315e-05, -1.8349e-01,  2.0132e-01,  7.9328e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1498,  0.1874,  0.0076, -0.3077,  0.3472,  0.1970,  0.0000, -0.5205,
-         0.3016, -0.2331,  0.1712, -0.3441, -0.1522, -1.0243, -1.4433, -0.1134,
-        -0.6741,  0.4906, -0.9813, -0.5854, -0.2048,  0.2500,  0.0000, -0.9486,
-         0.1805, -0.0340,  0.2509,  0.1236,  0.0325,  1.4306, -0.1162,  0.1670,
-         0.0000, -0.2626, -0.4531, -0.0457,  1.0152,  0.0000,  0.6591,  0.7774,
-         0.1356,  0.8399, -0.6932,  0.6567, -1.1399, -0.3306,  0.3554,  0.0052,
-         0.9649,  0.0000, -0.0510, -0.3082,  0.4980,  1.0572,  0.3603,  0.0000,
-         1.0605,  0.0000,  0.3020,  0.7984,  0.0000, -0.1835,  0.2013,  0.0793],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1498,  0.1874,  0.0076, -0.3077,  0.3472,  0.1970,  0.0000, -0.5205,
-         0.3016, -0.2331,  0.1712, -0.3441, -0.1522, -1.0243, -1.4433, -0.1134,
-        -0.6741,  0.4906, -0.9813, -0.5854, -0.2048,  0.2500,  0.0000, -0.9486,
-         0.1805, -0.0340,  0.2509,  0.1236,  0.0325,  1.4306, -0.1162,  0.1670,
-         0.0000, -0.2626, -0.4531, -0.0457,  1.0152,  0.0000,  0.6591,  0.7774,
-         0.1356,  0.8399, -0.6932,  0.6567, -1.1399, -0.3306,  0.3554,  0.0052,
-         0.9649,  0.0000, -0.0510, -0.3082,  0.4980,  1.0572,  0.3603,  0.0000,
-         1.0605,  0.0000,  0.3020,  0.7984,  0.0000, -0.1835,  0.2013,  0.0793],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.5810e-01,  1.7891e-01, -4.3123e-02, -2.5934e-01,  3.1817e-01,
-         2.1187e-01,  1.6229e-05, -5.2316e-01,  3.1411e-01, -1.9375e-01,
-         5.0494e-02, -3.7392e-01, -1.9656e-01, -1.0103e+00, -1.4390e+00,
-        -3.2594e-01, -6.9573e-01,  4.4236e-01, -9.7929e-01, -6.1097e-01,
-        -2.6155e-01,  1.9261e-01, -1.0451e-02, -9.8562e-01,  1.6377e-01,
-         1.8883e-01,  2.4317e-01,  9.5977e-02, -3.6470e-02,  1.4446e+00,
-        -4.3615e-02,  8.9426e-02, -3.8907e-06, -2.9347e-01, -4.8162e-01,
-        -1.1294e-01,  1.0058e+00,  0.0000e+00,  6.2938e-01,  7.7551e-01,
-         1.1438e-01,  8.2889e-01, -6.9170e-01,  6.7252e-01, -1.1390e+00,
-        -3.1564e-01,  3.7512e-01,  9.2685e-02,  9.7982e-01, -1.0093e-03,
-        -8.5232e-02, -4.0931e-01,  5.1520e-01,  1.0725e+00,  3.5285e-01,
-         2.1005e-07,  1.0562e+00,  2.9690e-05,  3.0220e-01,  8.9759e-01,
-         7.9308e-05, -1.7013e-01,  1.0625e-01,  7.2991e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1581,  0.1789, -0.0431, -0.2593,  0.3182,  0.2119,  0.0000, -0.5232,
-         0.3141, -0.1937,  0.0505, -0.3739, -0.1966, -1.0103, -1.4390, -0.3259,
-        -0.6957,  0.4424, -0.9793, -0.6110, -0.2615,  0.1926,  0.0000, -0.9856,
-         0.1638,  0.1888,  0.2432,  0.0960, -0.0365,  1.4446, -0.0436,  0.0894,
-         0.0000, -0.2935, -0.4816, -0.1129,  1.0058,  0.0000,  0.6294,  0.7755,
-         0.1144,  0.8289, -0.6917,  0.0000, -1.1390, -0.3156,  0.3751,  0.0927,
-         0.9798,  0.0000, -0.0852, -0.4093,  0.5152,  1.0725,  0.3528,  0.0000,
-         1.0562,  0.0000,  0.3022,  0.8976,  0.0000, -0.1701,  0.1062,  0.0730],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1581,  0.1789, -0.0431, -0.2593,  0.3182,  0.2119,  0.0000, -0.5232,
-         0.3141, -0.1937,  0.0505, -0.3739, -0.1966, -1.0103, -1.4390, -0.3259,
-        -0.6957,  0.4424, -0.9793, -0.6110, -0.2615,  0.1926,  0.0000, -0.9856,
-         0.1638,  0.1888,  0.2432,  0.0960, -0.0365,  1.4446, -0.0436,  0.0894,
-         0.0000, -0.2935, -0.4816, -0.1129,  1.0058,  0.0000,  0.6294,  0.7755,
-         0.1144,  0.8289, -0.6917,  0.0000, -1.1390, -0.3156,  0.3751,  0.0927,
-         0.9798,  0.0000, -0.0852, -0.4093,  0.5152,  1.0725,  0.3528,  0.0000,
-         1.0562,  0.0000,  0.3022,  0.8976,  0.0000, -0.1701,  0.1062,  0.0730],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3907e-01,  1.0897e-01, -2.4078e-03, -2.1051e-01,  2.7389e-01,
-         2.8915e-01,  1.3793e-05, -5.0355e-01,  3.0349e-01, -1.8780e-01,
-        -4.9766e-02, -3.2657e-01, -1.3447e-01, -9.9413e-01, -1.4344e+00,
-        -5.1470e-01, -7.2613e-01,  4.0641e-01, -9.7538e-01, -6.1477e-01,
-        -2.9949e-01,  8.1025e-02, -8.8820e-03, -1.0141e+00,  9.6718e-02,
-         3.8095e-01,  2.7685e-01,  1.1766e-01, -1.2483e-01,  1.4497e+00,
-        -3.9679e-02,  1.2039e-01, -3.3066e-06, -3.3132e-01, -4.9620e-01,
-        -1.2674e-01,  9.9396e-01,  0.0000e+00,  5.9157e-01,  7.7126e-01,
-         7.7502e-02,  8.0705e-01, -6.6338e-01,  1.3474e-02, -1.1341e+00,
-        -2.8312e-01,  3.4949e-01,  1.9550e-01,  9.7938e-01, -8.5779e-04,
-        -1.1370e-01, -4.8413e-01,  4.8924e-01,  1.0884e+00,  3.3630e-01,
-         1.7852e-07,  1.0544e+00,  2.5233e-05,  2.8714e-01,  9.7866e-01,
-         6.7402e-05, -1.6686e-01,  1.9690e-02,  1.5694e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1391,  0.1090, -0.0024, -0.2105,  0.2739,  0.2892,  0.0000, -0.5035,
-         0.3035, -0.1878, -0.0498, -0.3266, -0.1345, -0.9941, -1.4344, -0.5147,
-        -0.7261,  0.4064, -0.9754, -0.6148, -0.2995,  0.0810,  0.0000, -1.0141,
-         0.0967,  0.3810,  0.2769,  0.1177, -0.1248,  1.4497, -0.0397,  0.1204,
-         0.0000, -0.3313, -0.4962, -0.1267,  0.9940,  0.0000,  0.5916,  0.7713,
-         0.0775,  0.8071, -0.6634,  0.0000, -1.1341, -0.2831,  0.3495,  0.1955,
-         0.9794,  0.0000, -0.1137, -0.4841,  0.4892,  1.0884,  0.3363,  0.0000,
-         1.0544,  0.0000,  0.2871,  0.9787,  0.0000, -0.1669,  0.0197,  0.1569],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1391,  0.1090, -0.0024, -0.2105,  0.2739,  0.2892,  0.0000, -0.5035,
-         0.3035, -0.1878, -0.0498, -0.3266, -0.1345, -0.9941, -1.4344, -0.5147,
-        -0.7261,  0.4064, -0.9754, -0.6148, -0.2995,  0.0810,  0.0000, -1.0141,
-         0.0967,  0.3810,  0.2769,  0.1177, -0.1248,  1.4497, -0.0397,  0.1204,
-         0.0000, -0.3313, -0.4962, -0.1267,  0.9940,  0.0000,  0.5916,  0.7713,
-         0.0775,  0.8071, -0.6634,  0.0000, -1.1341, -0.2831,  0.3495,  0.1955,
-         0.9794,  0.0000, -0.1137, -0.4841,  0.4892,  1.0884,  0.3363,  0.0000,
-         1.0544,  0.0000,  0.2871,  0.9787,  0.0000, -0.1669,  0.0197,  0.1569],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4294e-01,  5.7035e-02,  3.6000e-02, -1.6203e-01,  2.0263e-01,
-         3.6383e-01,  1.1722e-05, -4.8752e-01,  2.8077e-01, -1.7213e-01,
-        -1.4671e-01, -2.8663e-01, -6.4207e-02, -9.8960e-01, -1.4323e+00,
-        -6.7548e-01, -7.4609e-01,  3.5597e-01, -9.7366e-01, -6.0121e-01,
-        -3.2768e-01, -2.4364e-02, -7.5487e-03, -1.0263e+00,  3.5703e-02,
-         5.4627e-01,  2.5879e-01,  1.0493e-01, -1.7433e-01,  1.4490e+00,
-        -2.1982e-02,  6.5968e-02, -2.8103e-06, -3.2079e-01, -4.8514e-01,
-        -1.5897e-01,  9.8545e-01,  0.0000e+00,  5.6668e-01,  7.7991e-01,
-         2.6700e-02,  7.9484e-01, -6.5106e-01,  1.1451e-02, -1.1302e+00,
-        -3.0412e-01,  3.3777e-01,  2.8504e-01,  9.7216e-01, -7.2902e-04,
-        -1.4906e-01, -4.6675e-01,  4.9053e-01,  1.0953e+00,  3.3842e-01,
-         1.5172e-07,  1.0527e+00,  2.1445e-05,  2.8793e-01,  1.0501e+00,
-         5.7284e-05, -1.8349e-01, -6.0638e-02,  1.8071e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1429,  0.0570,  0.0360, -0.1620,  0.2026,  0.3638,  0.0000, -0.4875,
-         0.2808, -0.1721, -0.1467, -0.2866, -0.0642, -0.9896, -1.4323, -0.6755,
-        -0.7461,  0.3560, -0.9737, -0.6012, -0.3277, -0.0244,  0.0000, -1.0263,
-         0.0357,  0.5463,  0.2588,  0.1049, -0.1743,  1.4490, -0.0220,  0.0660,
-         0.0000, -0.3208, -0.4851, -0.1590,  0.9854,  0.0000,  0.5667,  0.7799,
-         0.0267,  0.7948, -0.6511,  0.0000, -1.1302, -0.3041,  0.3378,  0.2850,
-         0.9722,  0.0000, -0.1491, -0.4668,  0.4905,  1.0953,  0.3384,  0.0000,
-         1.0527,  0.0000,  0.2879,  1.0501,  0.0000, -0.1835, -0.0606,  0.1807],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1429,  0.0570,  0.0360, -0.1620,  0.2026,  0.3638,  0.0000, -0.4875,
-         0.2808, -0.1721, -0.1467, -0.2866, -0.0642, -0.9896, -1.4323, -0.6755,
-        -0.7461,  0.3560, -0.9737, -0.6012, -0.3277, -0.0244,  0.0000, -1.0263,
-         0.0357,  0.5463,  0.2588,  0.1049, -0.1743,  1.4490, -0.0220,  0.0660,
-         0.0000, -0.3208, -0.4851, -0.1590,  0.9854,  0.0000,  0.5667,  0.7799,
-         0.0267,  0.7948, -0.6511,  0.0000, -1.1302, -0.3041,  0.3378,  0.2850,
-         0.9722,  0.0000, -0.1491, -0.4668,  0.4905,  1.0953,  0.3384,  0.0000,
-         1.0527,  0.0000,  0.2879,  1.0501,  0.0000, -0.1835, -0.0606,  0.1807],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3240e-01,  8.7308e-02, -2.7938e-02, -1.3434e-01,  5.9964e-02,
-         2.6087e-01,  9.9626e-06, -4.7517e-01,  1.6981e-01, -1.3111e-01,
-        -1.8134e-01, -2.7601e-01, -5.3990e-02, -9.9595e-01, -1.4359e+00,
-        -7.9622e-01, -7.5244e-01,  3.1461e-01, -9.7651e-01, -5.9997e-01,
-        -3.4420e-01, -6.1874e-02, -6.4155e-03, -1.0214e+00,  4.3662e-03,
-         6.8517e-01,  4.5901e-02,  2.7926e-02, -1.3818e-01,  1.4562e+00,
-         3.8682e-02, -2.3016e-01, -2.3884e-06, -2.7575e-01, -4.7148e-01,
-        -2.0629e-01,  1.0038e+00,  0.0000e+00,  5.4027e-01,  8.0768e-01,
-         1.4057e-02,  7.7089e-01, -6.2107e-01,  9.7324e-03, -1.1329e+00,
-        -3.2476e-01,  3.6568e-01,  3.0654e-01,  9.7098e-01, -6.1959e-04,
-        -1.5225e-01, -3.8552e-01,  5.1552e-01,  1.0618e+00,  4.4740e-01,
-         1.2895e-07,  1.0508e+00,  1.8226e-05,  3.3540e-01,  1.0993e+00,
-         4.8685e-05, -1.9179e-01, -7.7134e-02, -8.3810e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2324,  0.0873, -0.0279, -0.1343,  0.0600,  0.2609,  0.0000, -0.4752,
-         0.1698, -0.1311, -0.1813, -0.2760, -0.0540, -0.9959, -1.4359, -0.7962,
-        -0.7524,  0.3146, -0.9765, -0.6000, -0.3442, -0.0619,  0.0000, -1.0214,
-         0.0044,  0.6852,  0.0459,  0.0279, -0.1382,  1.4562,  0.0387, -0.2302,
-         0.0000, -0.2758, -0.4715, -0.2063,  1.0038,  0.0000,  0.5403,  0.8077,
-         0.0141,  0.7709, -0.6211,  0.0000, -1.1329, -0.3248,  0.3657,  0.3065,
-         0.9710,  0.0000, -0.1522, -0.3855,  0.5155,  1.0618,  0.4474,  0.0000,
-         1.0508,  0.0000,  0.3354,  1.0993,  0.0000, -0.1918, -0.0771, -0.0838],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2324,  0.0873, -0.0279, -0.1343,  0.0600,  0.2609,  0.0000, -0.4752,
-         0.1698, -0.1311, -0.1813, -0.2760, -0.0540, -0.9959, -1.4359, -0.7962,
-        -0.7524,  0.3146, -0.9765, -0.6000, -0.3442, -0.0619,  0.0000, -1.0214,
-         0.0044,  0.6852,  0.0459,  0.0279, -0.1382,  1.4562,  0.0387, -0.2302,
-         0.0000, -0.2758, -0.4715, -0.2063,  1.0038,  0.0000,  0.5403,  0.8077,
-         0.0141,  0.7709, -0.6211,  0.0000, -1.1329, -0.3248,  0.3657,  0.3065,
-         0.9710,  0.0000, -0.1522, -0.3855,  0.5155,  1.0618,  0.4474,  0.0000,
-         1.0508,  0.0000,  0.3354,  1.0993,  0.0000, -0.1918, -0.0771, -0.0838],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7245e-01,  1.4616e-01, -4.8475e-02, -5.8708e-02,  3.2515e-02,
-         2.3619e-01,  8.4671e-06, -4.5635e-01,  8.2299e-02, -1.6468e-01,
-        -1.3310e-01, -2.5594e-01, -2.0201e-02, -9.9734e-01, -1.4334e+00,
-        -8.8470e-01, -7.6120e-01,  2.3122e-01, -9.8036e-01, -5.8429e-01,
-        -3.1775e-01, -1.0592e-01, -5.4525e-03, -1.0157e+00, -5.7045e-03,
-         7.7844e-01, -1.2500e-01, -2.5594e-02, -9.9459e-02,  1.4618e+00,
-         1.1155e-01, -4.6705e-01, -2.0299e-06, -2.5208e-01, -4.4062e-01,
-        -2.3959e-01,  1.0081e+00,  0.0000e+00,  4.7566e-01,  8.2243e-01,
-        -1.2147e-04,  7.4697e-01, -5.9488e-01,  8.2715e-03, -1.1340e+00,
-        -3.6533e-01,  4.0094e-01,  3.3818e-01,  9.8065e-01, -5.2658e-04,
-        -1.2866e-01, -2.8538e-01,  5.3223e-01,  1.0291e+00,  5.2272e-01,
-         1.0959e-07,  1.0486e+00,  1.5490e-05,  4.0599e-01,  1.1393e+00,
-         4.1377e-05, -2.1226e-01, -6.8684e-02, -2.7417e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.7245e-01,  1.4616e-01, -4.8475e-02, -5.8708e-02,  3.2515e-02,
-         2.3619e-01,  0.0000e+00, -4.5635e-01,  8.2299e-02, -1.6468e-01,
-        -1.3310e-01, -2.5594e-01, -2.0201e-02, -9.9734e-01, -1.4334e+00,
-        -8.8470e-01, -7.6120e-01,  2.3122e-01, -9.8036e-01, -5.8429e-01,
-        -3.1775e-01, -1.0592e-01,  0.0000e+00, -1.0157e+00, -5.7045e-03,
-         7.7844e-01, -1.2500e-01, -2.5594e-02, -9.9459e-02,  1.4618e+00,
-         1.1155e-01, -4.6705e-01,  0.0000e+00, -2.5208e-01, -4.4062e-01,
-        -2.3959e-01,  1.0081e+00,  0.0000e+00,  4.7566e-01,  8.2243e-01,
-        -1.2147e-04,  7.4697e-01, -5.9488e-01,  0.0000e+00, -1.1340e+00,
-        -3.6533e-01,  4.0094e-01,  3.3818e-01,  9.8065e-01,  0.0000e+00,
-        -1.2866e-01, -2.8538e-01,  5.3223e-01,  1.0291e+00,  5.2272e-01,
-         0.0000e+00,  1.0486e+00,  0.0000e+00,  4.0599e-01,  1.1393e+00,
-         0.0000e+00, -2.1226e-01, -6.8684e-02, -2.7417e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.7245e-01,  1.4616e-01, -4.8475e-02, -5.8708e-02,  3.2515e-02,
-         2.3619e-01,  0.0000e+00, -4.5635e-01,  8.2299e-02, -1.6468e-01,
-        -1.3310e-01, -2.5594e-01, -2.0201e-02, -9.9734e-01, -1.4334e+00,
-        -8.8470e-01, -7.6120e-01,  2.3122e-01, -9.8036e-01, -5.8429e-01,
-        -3.1775e-01, -1.0592e-01,  0.0000e+00, -1.0157e+00, -5.7045e-03,
-         7.7844e-01, -1.2500e-01, -2.5594e-02, -9.9459e-02,  1.4618e+00,
-         1.1155e-01, -4.6705e-01,  0.0000e+00, -2.5208e-01, -4.4062e-01,
-        -2.3959e-01,  1.0081e+00,  0.0000e+00,  4.7566e-01,  8.2243e-01,
-        -1.2147e-04,  7.4697e-01, -5.9488e-01,  0.0000e+00, -1.1340e+00,
-        -3.6533e-01,  4.0094e-01,  3.3818e-01,  9.8065e-01,  0.0000e+00,
-        -1.2866e-01, -2.8538e-01,  5.3223e-01,  1.0291e+00,  5.2272e-01,
-         0.0000e+00,  1.0486e+00,  0.0000e+00,  4.0599e-01,  1.1393e+00,
-         0.0000e+00, -2.1226e-01, -6.8684e-02, -2.7417e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1723e-01,  1.8953e-01, -4.1519e-02,  6.6128e-02, -3.8059e-02,
-         2.2362e-01,  7.1962e-06, -4.2616e-01,  2.0236e-02, -1.5731e-01,
-        -5.0043e-02, -2.2810e-01,  6.0120e-02, -9.9804e-01, -1.4341e+00,
-        -9.4834e-01, -7.6613e-01,  1.8811e-01, -9.8254e-01, -5.7600e-01,
-        -2.7366e-01, -4.6454e-02, -4.6340e-03, -1.0066e+00,  7.4466e-02,
-         8.4078e-01, -2.1329e-01, -7.5612e-03, -8.7748e-02,  1.4672e+00,
-         1.9695e-01, -6.6891e-01, -1.7252e-06, -1.8003e-01, -4.1023e-01,
-        -2.6113e-01,  1.0044e+00,  0.0000e+00,  3.7794e-01,  8.1377e-01,
-         2.3878e-02,  7.3331e-01, -5.5837e-01,  7.0299e-03, -1.1330e+00,
-        -3.9198e-01,  4.1058e-01,  3.6122e-01,  9.8599e-01, -4.4754e-04,
-        -7.5417e-02, -1.5858e-01,  5.2204e-01,  1.0050e+00,  5.7755e-01,
-         9.3140e-08,  1.0461e+00,  1.3165e-05,  4.6594e-01,  1.1767e+00,
-         3.5166e-05, -1.5269e-01, -2.7973e-02, -4.0337e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3172,  0.1895, -0.0415,  0.0661, -0.0381,  0.2236,  0.0000, -0.4262,
-         0.0202, -0.1573, -0.0500, -0.2281,  0.0601, -0.9980, -1.4341, -0.9483,
-        -0.7661,  0.1881,  0.0000, -0.5760, -0.2737, -0.0465,  0.0000, -1.0066,
-         0.0745,  0.8408, -0.2133, -0.0076, -0.0877,  1.4672,  0.1970, -0.6689,
-         0.0000, -0.1800, -0.4102, -0.2611,  1.0044,  0.0000,  0.3779,  0.8138,
-         0.0239,  0.7333, -0.5584,  0.0000, -1.1330, -0.3920,  0.4106,  0.3612,
-         0.9860,  0.0000, -0.0754, -0.1586,  0.5220,  1.0050,  0.5775,  0.0000,
-         1.0461,  0.0000,  0.4659,  1.1767,  0.0000, -0.1527, -0.0280, -0.4034],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3172,  0.1895, -0.0415,  0.0661, -0.0381,  0.2236,  0.0000, -0.4262,
-         0.0202, -0.1573, -0.0500, -0.2281,  0.0601, -0.9980, -1.4341, -0.9483,
-        -0.7661,  0.1881,  0.0000, -0.5760, -0.2737, -0.0465,  0.0000, -1.0066,
-         0.0745,  0.8408, -0.2133, -0.0076, -0.0877,  1.4672,  0.1970, -0.6689,
-         0.0000, -0.1800, -0.4102, -0.2611,  1.0044,  0.0000,  0.3779,  0.8138,
-         0.0239,  0.7333, -0.5584,  0.0000, -1.1330, -0.3920,  0.4106,  0.3612,
-         0.9860,  0.0000, -0.0754, -0.1586,  0.5220,  1.0050,  0.5775,  0.0000,
-         1.0461,  0.0000,  0.4659,  1.1767,  0.0000, -0.1527, -0.0280, -0.4034],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7397e-01,  1.8845e-01, -6.0694e-03,  2.2073e-01, -9.1157e-02,
-         2.3191e-01,  6.1161e-06, -4.1566e-01, -2.2269e-02, -1.6865e-01,
-         2.3742e-02, -1.9176e-01,  1.7879e-01, -9.9931e-01, -1.4416e+00,
-        -1.0074e+00, -7.7152e-01,  1.8252e-01, -1.8546e-03, -5.6476e-01,
-        -2.3775e-01, -1.1553e-03, -3.9385e-03, -9.9364e-01,  1.3416e-01,
-         8.8688e-01, -2.8534e-01,  3.3260e-02, -6.8379e-02,  1.4686e+00,
-         2.3371e-01, -8.3692e-01, -1.4663e-06, -1.1921e-01, -3.8358e-01,
-        -2.8047e-01,  1.0040e+00,  0.0000e+00,  2.7674e-01,  8.0635e-01,
-         6.2674e-02,  7.1552e-01, -5.3071e-01,  5.9748e-03, -1.1290e+00,
-        -3.8029e-01,  4.0189e-01,  3.7826e-01,  9.6939e-01, -3.8037e-04,
-        -3.0626e-02, -4.4263e-02,  4.8054e-01,  9.8598e-01,  6.3348e-01,
-         7.9161e-08,  1.0437e+00,  1.1189e-05,  4.8936e-01,  1.2055e+00,
-         2.9888e-05, -8.3759e-02,  1.8840e-02, -5.0062e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.7397e-01,  1.8845e-01, -6.0694e-03,  2.2073e-01, -9.1157e-02,
-         2.3191e-01,  0.0000e+00, -4.1566e-01, -2.2269e-02, -1.6865e-01,
-         2.3742e-02, -1.9176e-01,  1.7879e-01, -9.9931e-01, -1.4416e+00,
-        -1.0074e+00, -7.7152e-01,  1.8252e-01,  0.0000e+00, -5.6476e-01,
-        -2.3775e-01, -1.1553e-03,  0.0000e+00, -9.9364e-01,  1.3416e-01,
-         8.8688e-01, -2.8534e-01,  3.3260e-02, -6.8379e-02,  1.4686e+00,
-         2.3371e-01, -8.3692e-01,  0.0000e+00, -1.1921e-01, -3.8358e-01,
-        -2.8047e-01,  1.0040e+00,  0.0000e+00,  2.7674e-01,  8.0635e-01,
-         6.2674e-02,  7.1552e-01, -5.3071e-01,  0.0000e+00, -1.1290e+00,
-        -3.8029e-01,  4.0189e-01,  3.7826e-01,  9.6939e-01,  0.0000e+00,
-        -3.0626e-02, -4.4263e-02,  4.8054e-01,  9.8598e-01,  6.3348e-01,
-         0.0000e+00,  1.0437e+00,  0.0000e+00,  4.8936e-01,  1.2055e+00,
-         0.0000e+00, -8.3759e-02,  1.8840e-02, -5.0062e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.7397e-01,  1.8845e-01, -6.0694e-03,  2.2073e-01, -9.1157e-02,
-         2.3191e-01,  0.0000e+00, -4.1566e-01, -2.2269e-02, -1.6865e-01,
-         2.3742e-02, -1.9176e-01,  1.7879e-01, -9.9931e-01, -1.4416e+00,
-        -1.0074e+00, -7.7152e-01,  1.8252e-01,  0.0000e+00, -5.6476e-01,
-        -2.3775e-01, -1.1553e-03,  0.0000e+00, -9.9364e-01,  1.3416e-01,
-         8.8688e-01, -2.8534e-01,  3.3260e-02, -6.8379e-02,  1.4686e+00,
-         2.3371e-01, -8.3692e-01,  0.0000e+00, -1.1921e-01, -3.8358e-01,
-        -2.8047e-01,  1.0040e+00,  0.0000e+00,  2.7674e-01,  8.0635e-01,
-         6.2674e-02,  7.1552e-01, -5.3071e-01,  0.0000e+00, -1.1290e+00,
-        -3.8029e-01,  4.0189e-01,  3.7826e-01,  9.6939e-01,  0.0000e+00,
-        -3.0626e-02, -4.4263e-02,  4.8054e-01,  9.8598e-01,  6.3348e-01,
-         0.0000e+00,  1.0437e+00,  0.0000e+00,  4.8936e-01,  1.2055e+00,
-         0.0000e+00, -8.3759e-02,  1.8840e-02, -5.0062e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0766e-01,  1.8012e-01,  4.8314e-03,  3.2518e-01, -3.5104e-02,
-         2.2786e-01,  5.1983e-06, -4.2209e-01, -8.0147e-02, -2.2458e-01,
-         7.4818e-02, -1.7145e-01,  2.3669e-01, -1.0024e+00, -1.4472e+00,
-        -1.0546e+00, -7.6293e-01,  1.3986e-01, -1.5763e-03, -5.4820e-01,
-        -1.8418e-01,  3.4674e-02, -3.3475e-03, -9.7653e-01,  1.8591e-01,
-         9.3103e-01, -3.5723e-01,  5.3907e-02,  3.0640e-03,  1.4679e+00,
-         2.2266e-01, -9.7920e-01, -1.2462e-06, -6.4136e-02, -3.3288e-01,
-        -2.9960e-01,  1.0235e+00,  0.0000e+00,  2.0724e-01,  8.2444e-01,
-         6.9693e-02,  7.0790e-01, -5.2680e-01,  5.0782e-03, -1.1269e+00,
-        -3.7162e-01,  4.0907e-01,  3.9645e-01,  9.4605e-01, -3.2329e-04,
-        -8.4121e-03,  5.1231e-02,  4.3888e-01,  9.5349e-01,  6.9666e-01,
-         6.7281e-08,  1.0416e+00,  9.5098e-06,  5.2024e-01,  1.2266e+00,
-         2.5403e-05, -8.5101e-02,  5.7277e-02, -5.6608e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4077,  0.1801,  0.0048,  0.3252, -0.0351,  0.2279,  0.0000, -0.4221,
-        -0.0801, -0.2246,  0.0748, -0.1715,  0.2367, -1.0024, -1.4472, -1.0546,
-        -0.7629,  0.1399,  0.0000, -0.5482, -0.1842,  0.0347,  0.0000, -0.9765,
-         0.1859,  0.9310, -0.3572,  0.0539,  0.0031,  1.4679,  0.2227, -0.9792,
-         0.0000, -0.0641, -0.3329, -0.2996,  1.0235,  0.0000,  0.2072,  0.8244,
-         0.0697,  0.7079, -0.5268,  0.0000, -1.1269, -0.3716,  0.4091,  0.3965,
-         0.9460,  0.0000, -0.0084,  0.0512,  0.4389,  0.9535,  0.6967,  0.0000,
-         1.0416,  0.0000,  0.5202,  1.2266,  0.0000, -0.0851,  0.0573, -0.5661],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4077,  0.1801,  0.0048,  0.3252, -0.0351,  0.2279,  0.0000, -0.4221,
-        -0.0801, -0.2246,  0.0748, -0.1715,  0.2367, -1.0024, -1.4472, -1.0546,
-        -0.7629,  0.1399,  0.0000, -0.5482, -0.1842,  0.0347,  0.0000, -0.9765,
-         0.1859,  0.9310, -0.3572,  0.0539,  0.0031,  1.4679,  0.2227, -0.9792,
-         0.0000, -0.0641, -0.3329, -0.2996,  1.0235,  0.0000,  0.2072,  0.8244,
-         0.0697,  0.7079, -0.5268,  0.0000, -1.1269, -0.3716,  0.4091,  0.3965,
-         0.9460,  0.0000, -0.0084,  0.0512,  0.4389,  0.9535,  0.6967,  0.0000,
-         1.0416,  0.0000,  0.5202,  1.2266,  0.0000, -0.0851,  0.0573, -0.5661],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.2423e-01,  1.5502e-01,  5.2634e-02,  4.5070e-01, -6.2584e-02,
-         2.2690e-01,  4.4183e-06, -4.1950e-01, -1.1429e-01, -2.7776e-01,
-         6.6411e-02, -1.4908e-01,  3.3049e-01, -9.9083e-01, -1.4555e+00,
-        -1.0897e+00, -7.5201e-01,  1.0624e-01, -1.3398e-03, -5.3843e-01,
-        -1.3009e-01,  5.5861e-03, -2.8452e-03, -9.4983e-01,  1.8258e-01,
-         9.7111e-01, -3.9392e-01,  8.0380e-02,  5.9078e-03,  1.4665e+00,
-         2.0823e-01, -1.0944e+00, -1.0592e-06, -2.5834e-02, -2.9611e-01,
-        -3.1231e-01,  1.0370e+00,  0.0000e+00,  1.4405e-01,  8.1867e-01,
-         6.9479e-02,  6.9087e-01, -5.2009e-01,  4.3162e-03, -1.1221e+00,
-        -3.9937e-01,  3.9411e-01,  3.9697e-01,  9.2194e-01, -2.7478e-04,
-        -1.5683e-02,  1.2875e-01,  4.0314e-01,  9.2555e-01,  7.3061e-01,
-         5.7186e-08,  1.0391e+00,  8.0828e-06,  5.3779e-01,  1.2377e+00,
-         2.1591e-05, -3.3102e-02,  7.2290e-02, -5.9700e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4242,  0.1550,  0.0526,  0.4507, -0.0626,  0.2269,  0.0000, -0.4195,
-        -0.1143, -0.2778,  0.0664, -0.1491,  0.3305, -0.9908, -1.4555, -1.0897,
-        -0.7520,  0.1062,  0.0000, -0.5384, -0.1301,  0.0056,  0.0000, -0.9498,
-         0.1826,  0.9711, -0.3939,  0.0804,  0.0059,  1.4665,  0.2082, -1.0944,
-         0.0000, -0.0258, -0.2961, -0.3123,  1.0370,  0.0000,  0.1440,  0.8187,
-         0.0695,  0.6909, -0.5201,  0.0000, -1.1221, -0.3994,  0.3941,  0.3970,
-         0.9219,  0.0000, -0.0157,  0.1288,  0.4031,  0.9256,  0.7306,  0.0000,
-         1.0391,  0.0000,  0.5378,  1.2377,  0.0000, -0.0331,  0.0723, -0.5970],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4242,  0.1550,  0.0526,  0.4507, -0.0626,  0.2269,  0.0000, -0.4195,
-        -0.1143, -0.2778,  0.0664, -0.1491,  0.3305, -0.9908, -1.4555, -1.0897,
-        -0.7520,  0.1062,  0.0000, -0.5384, -0.1301,  0.0056,  0.0000, -0.9498,
-         0.1826,  0.9711, -0.3939,  0.0804,  0.0059,  1.4665,  0.2082, -1.0944,
-         0.0000, -0.0258, -0.2961, -0.3123,  1.0370,  0.0000,  0.1440,  0.8187,
-         0.0695,  0.6909, -0.5201,  0.0000, -1.1221, -0.3994,  0.3941,  0.3970,
-         0.9219,  0.0000, -0.0157,  0.1288,  0.4031,  0.9256,  0.7306,  0.0000,
-         1.0391,  0.0000,  0.5378,  1.2377,  0.0000, -0.0331,  0.0723, -0.5970],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7248e-01,  1.7891e-01, -8.4729e-02,  4.3486e-01, -1.1811e-01,
-         2.6720e-01,  3.7554e-06, -4.3366e-01, -1.3607e-01, -2.8671e-01,
-        -2.9530e-02, -1.8005e-01,  2.4557e-01, -9.8832e-01, -1.4608e+00,
-        -1.1149e+00, -7.2996e-01, -7.9751e-03, -1.1388e-03, -5.5468e-01,
-        -1.1021e-01, -4.9983e-02, -2.4183e-03, -9.2346e-01,  1.1904e-01,
-         1.0111e+00, -4.2528e-01, -5.5788e-02,  4.3783e-02,  1.4679e+00,
-         1.7003e-01, -1.1849e+00, -9.0031e-07,  1.0205e-02, -2.8241e-01,
-        -3.4664e-01,  1.0456e+00,  0.0000e+00,  2.3582e-01,  8.1166e-01,
-        -6.0303e-03,  6.5958e-01, -4.7602e-01,  3.6686e-03, -1.1173e+00,
-        -3.8422e-01,  3.7250e-01,  4.9312e-01,  9.3944e-01, -2.3355e-04,
-        -2.6169e-02,  7.4678e-02,  3.8696e-01,  8.6395e-01,  7.3319e-01,
-         4.8606e-08,  1.0360e+00,  6.8702e-06,  5.7537e-01,  1.2453e+00,
-         1.8352e-05, -4.8375e-02,  1.1372e-01, -6.2985e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3725,  0.1789, -0.0847,  0.4349, -0.1181,  0.2672,  0.0000, -0.4337,
-        -0.1361, -0.2867, -0.0295, -0.1801,  0.2456, -0.9883, -1.4608, -1.1149,
-        -0.7300, -0.0080,  0.0000, -0.5547, -0.1102, -0.0500,  0.0000, -0.9235,
-         0.1190,  1.0111, -0.4253, -0.0558,  0.0438,  1.4679,  0.1700, -1.1849,
-         0.0000,  0.0102, -0.2824, -0.3466,  1.0456,  0.0000,  0.2358,  0.8117,
-        -0.0060,  0.6596, -0.4760,  0.0000, -1.1173, -0.3842,  0.3725,  0.4931,
-         0.9394,  0.0000, -0.0262,  0.0747,  0.3870,  0.8639,  0.7332,  0.0000,
-         1.0360,  0.0000,  0.5754,  1.2453,  0.0000, -0.0484,  0.1137, -0.6298],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3725,  0.1789, -0.0847,  0.4349, -0.1181,  0.2672,  0.0000, -0.4337,
-        -0.1361, -0.2867, -0.0295, -0.1801,  0.2456, -0.9883, -1.4608, -1.1149,
-        -0.7300, -0.0080,  0.0000, -0.5547, -0.1102, -0.0500,  0.0000, -0.9235,
-         0.1190,  1.0111, -0.4253, -0.0558,  0.0438,  1.4679,  0.1700, -1.1849,
-         0.0000,  0.0102, -0.2824, -0.3466,  1.0456,  0.0000,  0.2358,  0.8117,
-        -0.0060,  0.6596, -0.4760,  0.0000, -1.1173, -0.3842,  0.3725,  0.4931,
-         0.9394,  0.0000, -0.0262,  0.0747,  0.3870,  0.8639,  0.7332,  0.0000,
-         1.0360,  0.0000,  0.5754,  1.2453,  0.0000, -0.0484,  0.1137, -0.6298],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1747e-01,  2.1618e-01, -2.0078e-01,  4.2589e-01, -1.0982e-01,
-         3.0594e-01,  3.1921e-06, -4.3070e-01, -1.3890e-01, -3.0253e-01,
-        -8.5740e-02, -2.0321e-01,  1.6348e-01, -9.9061e-01, -1.4694e+00,
-        -1.1350e+00, -7.0274e-01, -9.8771e-02, -9.6795e-04, -5.6233e-01,
-        -8.6700e-02, -8.8027e-02, -2.0556e-03, -8.9777e-01,  6.3903e-02,
-         1.0498e+00, -4.2763e-01, -1.2969e-01,  7.6184e-02,  1.4682e+00,
-         1.7568e-01, -1.2637e+00, -7.6526e-07,  3.5928e-02, -2.6151e-01,
-        -3.8249e-01,  1.0583e+00,  0.0000e+00,  2.6328e-01,  8.0901e-01,
-        -5.8050e-02,  6.3151e-01, -4.4432e-01,  3.1183e-03, -1.1119e+00,
-        -3.9846e-01,  3.5773e-01,  5.5995e-01,  9.4448e-01, -1.9852e-04,
-        -6.4543e-02,  2.4372e-02,  3.7686e-01,  8.0046e-01,  7.4493e-01,
-         4.1315e-08,  1.0331e+00,  5.8396e-06,  5.9819e-01,  1.2520e+00,
-         1.5599e-05, -6.4074e-02,  1.5937e-01, -6.3664e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3175,  0.2162, -0.2008,  0.4259, -0.1098,  0.3059,  0.0000, -0.4307,
-        -0.1389, -0.3025, -0.0857, -0.2032,  0.1635, -0.9906, -1.4694, -1.1350,
-        -0.7027, -0.0988,  0.0000, -0.5623, -0.0867, -0.0880,  0.0000, -0.8978,
-         0.0639,  1.0498, -0.4276, -0.1297,  0.0762,  1.4682,  0.1757, -1.2637,
-         0.0000,  0.0359, -0.2615, -0.3825,  0.0000,  0.0000,  0.2633,  0.8090,
-        -0.0580,  0.6315, -0.4443,  0.0000, -1.1119, -0.3985,  0.3577,  0.5599,
-         0.9445,  0.0000, -0.0645,  0.0244,  0.3769,  0.8005,  0.7449,  0.0000,
-         1.0331,  0.0000,  0.5982,  1.2520,  0.0000, -0.0641,  0.1594, -0.6366],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3175,  0.2162, -0.2008,  0.4259, -0.1098,  0.3059,  0.0000, -0.4307,
-        -0.1389, -0.3025, -0.0857, -0.2032,  0.1635, -0.9906, -1.4694, -1.1350,
-        -0.7027, -0.0988,  0.0000, -0.5623, -0.0867, -0.0880,  0.0000, -0.8978,
-         0.0639,  1.0498, -0.4276, -0.1297,  0.0762,  1.4682,  0.1757, -1.2637,
-         0.0000,  0.0359, -0.2615, -0.3825,  0.0000,  0.0000,  0.2633,  0.8090,
-        -0.0580,  0.6315, -0.4443,  0.0000, -1.1119, -0.3985,  0.3577,  0.5599,
-         0.9445,  0.0000, -0.0645,  0.0244,  0.3769,  0.8005,  0.7449,  0.0000,
-         1.0331,  0.0000,  0.5982,  1.2520,  0.0000, -0.0641,  0.1594, -0.6366],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8663e-01,  2.4004e-01, -3.1423e-01,  4.0739e-01, -1.0538e-01,
-         2.8262e-01,  2.7134e-06, -4.1694e-01, -1.4115e-01, -3.2163e-01,
-        -1.3257e-01, -2.3433e-01,  6.1114e-02, -9.8192e-01, -1.4783e+00,
-        -1.1509e+00, -6.8160e-01, -1.7221e-01, -8.2279e-04, -5.7397e-01,
-        -6.2393e-02, -9.8204e-02, -1.7473e-03, -8.7714e-01,  3.5847e-02,
-         1.0862e+00, -3.9834e-01, -1.6417e-01,  1.4894e-01,  1.4694e+00,
-         1.6205e-01, -1.3323e+00, -6.5049e-07,  1.0482e-01, -2.6235e-01,
-        -4.0769e-01,  1.0770e-02,  0.0000e+00,  3.3322e-01,  8.1219e-01,
-        -7.2262e-02,  6.0511e-01, -3.9093e-01,  2.6507e-03, -1.1008e+00,
-        -3.9997e-01,  3.3870e-01,  6.1496e-01,  9.5528e-01, -1.6875e-04,
-        -1.0452e-01,  1.6829e-02,  3.4087e-01,  7.2939e-01,  7.7597e-01,
-         3.5119e-08,  1.0296e+00,  4.9639e-06,  6.1914e-01,  1.2555e+00,
-         1.3260e-05, -7.1478e-02,  2.2837e-01, -6.2838e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2866,  0.2400, -0.3142,  0.4074, -0.1054,  0.2826,  0.0000, -0.4169,
-        -0.1412, -0.3216, -0.1326, -0.2343,  0.0611, -0.9819, -1.4783, -1.1509,
-        -0.6816, -0.1722,  0.0000, -0.5740, -0.0624, -0.0982,  0.0000, -0.8771,
-         0.0358,  1.0862, -0.3983, -0.1642,  0.1489,  1.4694,  0.1620, -1.3323,
-         0.0000,  0.1048, -0.2623, -0.4077,  0.0000,  0.0000,  0.3332,  0.8122,
-        -0.0723,  0.6051, -0.3909,  0.0000, -1.1008, -0.4000,  0.3387,  0.6150,
-         0.9553,  0.0000, -0.1045,  0.0168,  0.3409,  0.7294,  0.7760,  0.0000,
-         1.0296,  0.0000,  0.6191,  1.2555,  0.0000, -0.0715,  0.2284, -0.6284],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2866,  0.2400, -0.3142,  0.4074, -0.1054,  0.2826,  0.0000, -0.4169,
-        -0.1412, -0.3216, -0.1326, -0.2343,  0.0611, -0.9819, -1.4783, -1.1509,
-        -0.6816, -0.1722,  0.0000, -0.5740, -0.0624, -0.0982,  0.0000, -0.8771,
-         0.0358,  1.0862, -0.3983, -0.1642,  0.1489,  1.4694,  0.1620, -1.3323,
-         0.0000,  0.1048, -0.2623, -0.4077,  0.0000,  0.0000,  0.3332,  0.8122,
-        -0.0723,  0.6051, -0.3909,  0.0000, -1.1008, -0.4000,  0.3387,  0.6150,
-         0.9553,  0.0000, -0.1045,  0.0168,  0.3409,  0.7294,  0.7760,  0.0000,
-         1.0296,  0.0000,  0.6191,  1.2555,  0.0000, -0.0715,  0.2284, -0.6284],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7710e-01,  2.5357e-01, -3.0476e-01,  4.0096e-01, -1.5157e-01,
-         2.5870e-01,  2.3065e-06, -4.0984e-01, -8.6176e-02, -2.9764e-01,
-        -1.2434e-01, -2.5842e-01,  4.9551e-02, -9.8766e-01, -1.4860e+00,
-        -1.1623e+00, -6.8572e-01, -1.1095e-01, -6.9942e-04, -6.0045e-01,
-        -4.9358e-02, -7.9095e-02, -1.4853e-03, -8.7888e-01, -1.0741e-02,
-         1.1150e+00, -2.8817e-01, -1.7277e-01,  1.9748e-01,  1.4679e+00,
-         1.3157e-01, -1.3900e+00, -5.5296e-07,  1.0230e-01, -3.0096e-01,
-        -4.2655e-01,  9.1555e-03,  0.0000e+00,  5.1114e-01,  8.1287e-01,
-        -5.2508e-02,  5.6405e-01, -2.6095e-01,  2.2532e-03, -1.0956e+00,
-        -3.8752e-01,  2.6601e-01,  6.8077e-01,  9.5167e-01, -1.4345e-04,
-        -1.0040e-01, -9.3248e-02,  2.8176e-01,  6.5283e-01,  7.9745e-01,
-         2.9854e-08,  1.0262e+00,  4.2196e-06,  6.0573e-01,  1.2600e+00,
-         1.1272e-05,  2.7964e-02,  2.5955e-01, -5.5223e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2771,  0.2536, -0.3048,  0.4010, -0.1516,  0.2587,  0.0000, -0.4098,
-        -0.0862, -0.2976, -0.1243, -0.2584,  0.0496, -0.9877, -1.4860, -1.1623,
-        -0.6857, -0.1110,  0.0000, -0.6004, -0.0494, -0.0791,  0.0000, -0.8789,
-        -0.0107,  1.1150, -0.2882, -0.1728,  0.1975,  1.4679,  0.1316, -1.3900,
-         0.0000,  0.1023, -0.3010, -0.4265,  0.0000,  0.0000,  0.5111,  0.8129,
-        -0.0525,  0.5641, -0.2610,  0.0000, -1.0956, -0.3875,  0.2660,  0.6808,
-         0.9517,  0.0000, -0.1004, -0.0932,  0.2818,  0.6528,  0.7974,  0.0000,
-         1.0262,  0.0000,  0.6057,  1.2600,  0.0000,  0.0280,  0.2595, -0.5522],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2771,  0.2536, -0.3048,  0.4010, -0.1516,  0.2587,  0.0000, -0.4098,
-        -0.0862, -0.2976, -0.1243, -0.2584,  0.0496, -0.9877, -1.4860, -1.1623,
-        -0.6857, -0.1110,  0.0000, -0.6004, -0.0494, -0.0791,  0.0000, -0.8789,
-        -0.0107,  1.1150, -0.2882, -0.1728,  0.1975,  1.4679,  0.1316, -1.3900,
-         0.0000,  0.1023, -0.3010, -0.4265,  0.0000,  0.0000,  0.5111,  0.8129,
-        -0.0525,  0.5641, -0.2610,  0.0000, -1.0956, -0.3875,  0.2660,  0.6808,
-         0.9517,  0.0000, -0.1004, -0.0932,  0.2818,  0.6528,  0.7974,  0.0000,
-         1.0262,  0.0000,  0.6057,  1.2600,  0.0000,  0.0280,  0.2595, -0.5522],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6906e-01,  2.7432e-01, -3.3735e-01,  3.6916e-01, -2.1648e-01,
-         2.0848e-01,  1.9608e-06, -4.1097e-01, -4.6847e-02, -2.5810e-01,
-        -1.1246e-01, -2.7069e-01,  2.5278e-02, -9.8426e-01, -1.4895e+00,
-        -1.1739e+00, -6.7094e-01, -7.6075e-02, -5.9458e-04, -6.1241e-01,
-        -4.8307e-02, -3.9564e-02, -1.2627e-03, -8.8377e-01, -2.6163e-02,
-         1.1347e+00, -1.9886e-01, -1.7442e-01,  2.1113e-01,  1.4682e+00,
-         5.7084e-02, -1.4393e+00, -4.7007e-07,  6.4219e-02, -3.0255e-01,
-        -4.4013e-01,  7.7831e-03,  0.0000e+00,  6.6712e-01,  8.1175e-01,
-         5.8324e-03,  5.2554e-01, -2.2149e-01,  1.9155e-03, -1.0879e+00,
-        -3.7975e-01,  1.9139e-01,  7.2331e-01,  9.4877e-01, -1.2194e-04,
-        -7.1009e-02, -1.5646e-01,  2.5408e-01,  5.7617e-01,  8.0616e-01,
-         2.5379e-08,  1.0227e+00,  3.5871e-06,  5.7252e-01,  1.2629e+00,
-         9.5819e-06,  7.6519e-02,  2.9040e-01, -5.2407e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2691,  0.2743, -0.3373,  0.3692, -0.2165,  0.2085,  0.0000, -0.4110,
-        -0.0468, -0.2581, -0.1125, -0.2707,  0.0253, -0.9843, -1.4895, -1.1739,
-        -0.6709, -0.0761,  0.0000, -0.6124, -0.0483, -0.0396,  0.0000, -0.8838,
-        -0.0262,  1.1347, -0.1989, -0.1744,  0.2111,  1.4682,  0.0571, -1.4393,
-         0.0000,  0.0642, -0.3026, -0.4401,  0.0000,  0.0000,  0.6671,  0.8117,
-         0.0058,  0.5255, -0.2215,  0.0000, -1.0879, -0.3797,  0.1914,  0.7233,
-         0.9488,  0.0000, -0.0710, -0.1565,  0.2541,  0.5762,  0.8062,  0.0000,
-         1.0227,  0.0000,  0.5725,  1.2629,  0.0000,  0.0765,  0.2904, -0.5241],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2691,  0.2743, -0.3373,  0.3692, -0.2165,  0.2085,  0.0000, -0.4110,
-        -0.0468, -0.2581, -0.1125, -0.2707,  0.0253, -0.9843, -1.4895, -1.1739,
-        -0.6709, -0.0761,  0.0000, -0.6124, -0.0483, -0.0396,  0.0000, -0.8838,
-        -0.0262,  1.1347, -0.1989, -0.1744,  0.2111,  1.4682,  0.0571, -1.4393,
-         0.0000,  0.0642, -0.3026, -0.4401,  0.0000,  0.0000,  0.6671,  0.8117,
-         0.0058,  0.5255, -0.2215,  0.0000, -1.0879, -0.3797,  0.1914,  0.7233,
-         0.9488,  0.0000, -0.0710, -0.1565,  0.2541,  0.5762,  0.8062,  0.0000,
-         1.0227,  0.0000,  0.5725,  1.2629,  0.0000,  0.0765,  0.2904, -0.5241],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4671e-01,  3.2366e-01, -3.8324e-01,  3.3630e-01, -2.8084e-01,
-         2.0885e-01,  1.6669e-06, -4.1262e-01, -1.8214e-02, -2.2662e-01,
-        -8.0455e-02, -2.6811e-01, -2.6612e-02, -9.5699e-01, -1.4952e+00,
-        -1.1825e+00, -6.6125e-01, -7.6611e-03, -5.0548e-04, -6.1966e-01,
-        -3.1876e-02,  2.6328e-02, -1.0734e-03, -8.8899e-01, -1.6099e-02,
-         1.1416e+00, -1.2119e-01, -1.6389e-01,  2.7078e-01,  1.4686e+00,
-         3.4764e-02, -1.4803e+00, -3.9963e-07,  1.0847e-01, -2.8863e-01,
-        -4.4095e-01,  6.6167e-03,  0.0000e+00,  7.5865e-01,  7.8764e-01,
-         6.7472e-02,  5.1305e-01, -2.0686e-01,  1.6284e-03, -1.0724e+00,
-        -3.4105e-01,  1.6787e-01,  7.3608e-01,  9.6035e-01, -1.0367e-04,
-        -4.1138e-02, -1.7857e-01,  2.2107e-01,  5.3210e-01,  7.9217e-01,
-         2.1575e-08,  1.0181e+00,  3.0495e-06,  5.4714e-01,  1.2688e+00,
-         8.1460e-06,  1.5342e-01,  3.3696e-01, -5.1437e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2467,  0.3237, -0.3832,  0.3363, -0.2808,  0.2088,  0.0000, -0.4126,
-        -0.0182, -0.2266, -0.0805, -0.2681, -0.0266, -0.9570, -1.4952, -1.1825,
-        -0.6612, -0.0077,  0.0000, -0.6197, -0.0319,  0.0263,  0.0000, -0.8890,
-        -0.0161,  1.1416, -0.1212, -0.1639,  0.2708,  1.4686,  0.0348, -1.4803,
-         0.0000,  0.1085, -0.2886, -0.4410,  0.0000,  0.0000,  0.7586,  0.7876,
-         0.0675,  0.5131, -0.2069,  0.0000, -1.0724, -0.3411,  0.1679,  0.7361,
-         0.9603,  0.0000, -0.0411, -0.1786,  0.2211,  0.5321,  0.7922,  0.0000,
-         1.0181,  0.0000,  0.5471,  1.2688,  0.0000,  0.1534,  0.3370, -0.5144],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2467,  0.3237, -0.3832,  0.3363, -0.2808,  0.2088,  0.0000, -0.4126,
-        -0.0182, -0.2266, -0.0805, -0.2681, -0.0266, -0.9570, -1.4952, -1.1825,
-        -0.6612, -0.0077,  0.0000, -0.6197, -0.0319,  0.0263,  0.0000, -0.8890,
-        -0.0161,  1.1416, -0.1212, -0.1639,  0.2708,  1.4686,  0.0348, -1.4803,
-         0.0000,  0.1085, -0.2886, -0.4410,  0.0000,  0.0000,  0.7586,  0.7876,
-         0.0675,  0.5131, -0.2069,  0.0000, -1.0724, -0.3411,  0.1679,  0.7361,
-         0.9603,  0.0000, -0.0411, -0.1786,  0.2211,  0.5321,  0.7922,  0.0000,
-         1.0181,  0.0000,  0.5471,  1.2688,  0.0000,  0.1534,  0.3370, -0.5144],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.0450e-01,  3.8653e-01, -4.3324e-01,  2.9120e-01, -3.0803e-01,
-         2.3172e-01,  1.4172e-06, -4.0774e-01, -1.5615e-02, -2.2557e-01,
-        -4.5453e-02, -2.6491e-01, -1.1066e-01, -9.2156e-01, -1.4978e+00,
-        -1.1907e+00, -6.4491e-01,  3.9629e-02, -4.2975e-04, -6.1957e-01,
-        -1.2950e-02,  6.7810e-02, -9.1263e-04, -8.6309e-01, -1.9376e-02,
-         1.1352e+00, -9.2835e-02, -1.6902e-01,  3.0166e-01,  1.4740e+00,
-         1.3071e-01, -1.5147e+00, -3.3976e-07,  1.5398e-01, -2.5772e-01,
-        -4.4679e-01,  5.6254e-03,  0.0000e+00,  8.3091e-01,  7.3524e-01,
-         9.6237e-02,  4.9842e-01, -2.1212e-01,  1.3845e-03, -1.0591e+00,
-        -2.6375e-01,  1.2135e-01,  7.3000e-01,  9.9782e-01, -8.8139e-05,
-         6.3435e-03, -1.9331e-01,  2.1788e-01,  5.0094e-01,  7.3956e-01,
-         1.8343e-08,  1.0124e+00,  2.5927e-06,  5.0806e-01,  1.2772e+00,
-         6.9256e-06,  2.5321e-01,  3.8438e-01, -5.6880e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2045,  0.3865, -0.4332,  0.2912, -0.3080,  0.2317,  0.0000, -0.4077,
-        -0.0156, -0.2256, -0.0455, -0.2649, -0.1107, -0.9216,  0.0000, -1.1907,
-        -0.6449,  0.0396,  0.0000, -0.6196, -0.0129,  0.0678,  0.0000, -0.8631,
-        -0.0194,  1.1352, -0.0928, -0.1690,  0.3017,  1.4740,  0.1307, -1.5147,
-         0.0000,  0.1540, -0.2577, -0.4468,  0.0000,  0.0000,  0.8309,  0.7352,
-         0.0962,  0.4984, -0.2121,  0.0000, -1.0591, -0.2638,  0.1214,  0.7300,
-         0.9978,  0.0000,  0.0063, -0.1933,  0.2179,  0.5009,  0.7396,  0.0000,
-         1.0124,  0.0000,  0.5081,  1.2772,  0.0000,  0.2532,  0.3844, -0.5688],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2045,  0.3865, -0.4332,  0.2912, -0.3080,  0.2317,  0.0000, -0.4077,
-        -0.0156, -0.2256, -0.0455, -0.2649, -0.1107, -0.9216,  0.0000, -1.1907,
-        -0.6449,  0.0396,  0.0000, -0.6196, -0.0129,  0.0678,  0.0000, -0.8631,
-        -0.0194,  1.1352, -0.0928, -0.1690,  0.3017,  1.4740,  0.1307, -1.5147,
-         0.0000,  0.1540, -0.2577, -0.4468,  0.0000,  0.0000,  0.8309,  0.7352,
-         0.0962,  0.4984, -0.2121,  0.0000, -1.0591, -0.2638,  0.1214,  0.7300,
-         0.9978,  0.0000,  0.0063, -0.1933,  0.2179,  0.5009,  0.7396,  0.0000,
-         1.0124,  0.0000,  0.5081,  1.2772,  0.0000,  0.2532,  0.3844, -0.5688],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.9911e-01,  4.1270e-01, -4.1082e-01,  2.4686e-01, -4.0191e-01,
-         2.6366e-01,  1.2050e-06, -4.0798e-01, -4.8034e-02, -1.8587e-01,
-         4.0765e-03, -2.8117e-01, -1.0909e-01, -8.8879e-01, -2.1608e-03,
-        -1.1955e+00, -6.3248e-01,  1.2646e-01, -3.6539e-04, -6.1041e-01,
-        -2.2560e-02,  5.8135e-02, -7.7595e-04, -8.4232e-01, -6.5855e-02,
-         1.1247e+00, -3.8617e-03, -1.8366e-01,  2.7254e-01,  1.4743e+00,
-         1.8549e-01, -1.5396e+00, -2.8888e-07,  1.0823e-01, -2.8904e-01,
-        -4.3374e-01,  4.7830e-03,  0.0000e+00,  8.8842e-01,  6.7009e-01,
-         1.1701e-01,  4.8218e-01, -1.7180e-01,  1.1771e-03, -1.0522e+00,
-        -1.7936e-01,  6.5480e-02,  7.0441e-01,  1.0151e+00, -7.4939e-05,
-         5.1340e-02, -3.5034e-01,  2.2090e-01,  4.6109e-01,  6.8294e-01,
-         1.5596e-08,  1.0067e+00,  2.2044e-06,  4.5685e-01,  1.2870e+00,
-         5.8884e-06,  3.3556e-01,  4.0601e-01, -5.4193e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1991,  0.4127, -0.4108,  0.2469, -0.4019,  0.2637,  0.0000, -0.4080,
-        -0.0480, -0.1859,  0.0041, -0.2812, -0.1091, -0.8888,  0.0000, -1.1955,
-        -0.6325,  0.1265,  0.0000, -0.6104, -0.0226,  0.0581,  0.0000, -0.8423,
-        -0.0659,  1.1247, -0.0039, -0.1837,  0.2725,  1.4743,  0.1855, -1.5396,
-         0.0000,  0.1082, -0.2890, -0.4337,  0.0000,  0.0000,  0.8884,  0.6701,
-         0.1170,  0.4822, -0.1718,  0.0000, -1.0522, -0.1794,  0.0655,  0.7044,
-         1.0151,  0.0000,  0.0513, -0.3503,  0.2209,  0.4611,  0.6829,  0.0000,
-         1.0067,  0.0000,  0.4569,  1.2870,  0.0000,  0.3356,  0.4060, -0.5419],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1991,  0.4127, -0.4108,  0.2469, -0.4019,  0.2637,  0.0000, -0.4080,
-        -0.0480, -0.1859,  0.0041, -0.2812, -0.1091, -0.8888,  0.0000, -1.1955,
-        -0.6325,  0.1265,  0.0000, -0.6104, -0.0226,  0.0581,  0.0000, -0.8423,
-        -0.0659,  1.1247, -0.0039, -0.1837,  0.2725,  1.4743,  0.1855, -1.5396,
-         0.0000,  0.1082, -0.2890, -0.4337,  0.0000,  0.0000,  0.8884,  0.6701,
-         0.1170,  0.4822, -0.1718,  0.0000, -1.0522, -0.1794,  0.0655,  0.7044,
-         1.0151,  0.0000,  0.0513, -0.3503,  0.2209,  0.4611,  0.6829,  0.0000,
-         1.0067,  0.0000,  0.4569,  1.2870,  0.0000,  0.3356,  0.4060, -0.5419],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1676e-01,  4.2598e-01, -4.2784e-01,  1.9675e-01, -4.1641e-01,
-         2.4639e-01,  1.0246e-06, -4.0349e-01, -5.3918e-02, -1.6192e-01,
-         5.5418e-02, -2.8205e-01, -1.3553e-01, -8.6746e-01, -1.8373e-03,
-        -1.1961e+00, -5.8981e-01,  2.1806e-01, -3.1069e-04, -5.9949e-01,
-        -7.7951e-03,  2.3842e-02, -6.5978e-04, -8.2589e-01, -7.6146e-02,
-         1.1198e+00,  6.4716e-02, -1.8078e-01,  2.3203e-01,  1.4768e+00,
-         1.9749e-01, -1.5601e+00, -2.4563e-07,  6.5391e-02, -2.8827e-01,
-        -4.2160e-01,  4.0669e-03,  0.0000e+00,  9.4164e-01,  6.1129e-01,
-         1.4747e-01,  4.8560e-01, -1.6879e-01,  1.0009e-03, -1.0415e+00,
-        -1.2621e-01,  3.3697e-02,  6.7923e-01,  1.0233e+00, -6.3719e-05,
-         9.2279e-02, -4.5275e-01,  2.0929e-01,  4.0579e-01,  6.5067e-01,
-         1.3261e-08,  1.0026e+00,  1.8744e-06,  4.2463e-01,  1.2920e+00,
-         5.0069e-06,  3.6161e-01,  4.2049e-01, -5.1519e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2168,  0.4260, -0.4278,  0.1967, -0.4164,  0.2464,  0.0000, -0.4035,
-        -0.0539, -0.1619,  0.0554, -0.2820, -0.1355, -0.8675,  0.0000, -1.1961,
-        -0.5898,  0.2181,  0.0000, -0.5995, -0.0078,  0.0238,  0.0000, -0.8259,
-        -0.0761,  1.1198,  0.0647, -0.1808,  0.2320,  1.4768,  0.1975, -1.5601,
-         0.0000,  0.0654, -0.2883, -0.4216,  0.0000,  0.0000,  0.9416,  0.6113,
-         0.1475,  0.4856, -0.1688,  0.0000, -1.0415, -0.1262,  0.0337,  0.6792,
-         1.0233,  0.0000,  0.0923, -0.4528,  0.2093,  0.4058,  0.6507,  0.0000,
-         1.0026,  0.0000,  0.4246,  1.2920,  0.0000,  0.3616,  0.4205, -0.5152],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2168,  0.4260, -0.4278,  0.1967, -0.4164,  0.2464,  0.0000, -0.4035,
-        -0.0539, -0.1619,  0.0554, -0.2820, -0.1355, -0.8675,  0.0000, -1.1961,
-        -0.5898,  0.2181,  0.0000, -0.5995, -0.0078,  0.0238,  0.0000, -0.8259,
-        -0.0761,  1.1198,  0.0647, -0.1808,  0.2320,  1.4768,  0.1975, -1.5601,
-         0.0000,  0.0654, -0.2883, -0.4216,  0.0000,  0.0000,  0.9416,  0.6113,
-         0.1475,  0.4856, -0.1688,  0.0000, -1.0415, -0.1262,  0.0337,  0.6792,
-         1.0233,  0.0000,  0.0923, -0.4528,  0.2093,  0.4058,  0.6507,  0.0000,
-         1.0026,  0.0000,  0.4246,  1.2920,  0.0000,  0.3616,  0.4205, -0.5152],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.2528e-01,  3.9173e-01, -4.6102e-01,  1.2023e-01, -3.8303e-01,
-         1.8642e-01,  8.7124e-07, -4.1169e-01, -2.9308e-02, -1.3733e-01,
-         1.2244e-01, -2.6160e-01, -2.0499e-01, -8.4228e-01, -1.5624e-03,
-        -1.1955e+00, -5.4102e-01,  3.0136e-01, -2.6419e-04, -5.7509e-01,
-         5.4192e-03, -2.1356e-02, -5.6104e-04, -8.1247e-01, -6.3046e-02,
-         1.1198e+00,  6.1938e-02, -1.7873e-01,  1.6628e-01,  1.4791e+00,
-         1.5640e-01, -1.5784e+00, -2.0887e-07,  2.0082e-02, -2.2989e-01,
-        -3.8808e-01,  3.4583e-03,  0.0000e+00,  9.7142e-01,  5.5755e-01,
-         1.7094e-01,  4.9515e-01, -2.1285e-01,  8.5111e-04, -1.0287e+00,
-        -9.0370e-02,  5.7504e-02,  6.6488e-01,  1.0310e+00, -5.4183e-05,
-         1.2804e-01, -5.0345e-01,  1.9570e-01,  2.8804e-01,  6.3300e-01,
-         1.1276e-08,  9.9878e-01,  1.5939e-06,  3.6431e-01,  1.2910e+00,
-         4.2576e-06,  3.1806e-01,  4.0787e-01, -5.0625e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2253,  0.3917, -0.4610,  0.1202, -0.3830,  0.1864,  0.0000, -0.4117,
-        -0.0293, -0.1373,  0.1224, -0.2616, -0.2050, -0.8423,  0.0000, -1.1955,
-        -0.5410,  0.3014,  0.0000, -0.5751,  0.0054, -0.0214,  0.0000, -0.8125,
-        -0.0630,  1.1198,  0.0619, -0.1787,  0.1663,  1.4791,  0.1564, -1.5784,
-         0.0000,  0.0201, -0.2299, -0.3881,  0.0000,  0.0000,  0.9714,  0.5576,
-         0.1709,  0.4952, -0.2128,  0.0000, -1.0287, -0.0904,  0.0575,  0.6649,
-         1.0310,  0.0000,  0.1280, -0.5035,  0.1957,  0.2880,  0.6330,  0.0000,
-         0.9988,  0.0000,  0.3643,  1.2910,  0.0000,  0.3181,  0.4079, -0.5062],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2253,  0.3917, -0.4610,  0.1202, -0.3830,  0.1864,  0.0000, -0.4117,
-        -0.0293, -0.1373,  0.1224, -0.2616, -0.2050, -0.8423,  0.0000, -1.1955,
-        -0.5410,  0.3014,  0.0000, -0.5751,  0.0054, -0.0214,  0.0000, -0.8125,
-        -0.0630,  1.1198,  0.0619, -0.1787,  0.1663,  1.4791,  0.1564, -1.5784,
-         0.0000,  0.0201, -0.2299, -0.3881,  0.0000,  0.0000,  0.9714,  0.5576,
-         0.1709,  0.4952, -0.2128,  0.0000, -1.0287, -0.0904,  0.0575,  0.6649,
-         1.0310,  0.0000,  0.1280, -0.5035,  0.1957,  0.2880,  0.6330,  0.0000,
-         0.9988,  0.0000,  0.3643,  1.2910,  0.0000,  0.3181,  0.4079, -0.5062],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1895e-01,  3.3101e-01, -4.8818e-01,  7.9390e-02, -1.9813e-01,
-         1.2785e-01,  7.4090e-07, -4.0593e-01,  1.9324e-02, -9.4112e-02,
-         1.5895e-01, -2.5351e-01, -2.7666e-01, -8.2803e-01, -1.3286e-03,
-        -1.1934e+00, -4.7233e-01,  3.6835e-01, -2.2467e-04, -5.3621e-01,
-         6.1393e-02, -2.0156e-02, -4.7711e-04, -7.6934e-01, -1.6347e-02,
-         1.1220e+00,  3.6155e-02, -1.6542e-01,  5.1708e-02,  1.4793e+00,
-         9.1605e-02, -1.5940e+00, -1.7762e-07, -2.0850e-03, -1.2948e-01,
-        -3.5155e-01,  2.9409e-03,  0.0000e+00,  1.0038e+00,  5.3806e-01,
-         1.9905e-01,  4.9157e-01, -2.3445e-01,  7.2379e-04, -1.0194e+00,
-        -6.7632e-02,  8.4455e-02,  6.4223e-01,  1.0221e+00, -4.6078e-05,
-         1.5364e-01, -5.0494e-01,  1.9634e-01,  1.3452e-01,  6.3572e-01,
-         9.5895e-09,  9.9585e-01,  1.3554e-06,  2.9400e-01,  1.2841e+00,
-         3.6206e-06,  2.4998e-01,  3.6583e-01, -4.9202e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2189,  0.3310, -0.4882,  0.0794, -0.1981,  0.1279,  0.0000, -0.4059,
-         0.0193, -0.0941,  0.1590, -0.2535, -0.2767, -0.8280,  0.0000, -1.1934,
-        -0.4723,  0.3684,  0.0000, -0.5362,  0.0614, -0.0202,  0.0000, -0.7693,
-        -0.0163,  1.1220,  0.0362, -0.1654,  0.0517,  1.4793,  0.0916, -1.5940,
-         0.0000, -0.0021, -0.1295, -0.3516,  0.0000,  0.0000,  1.0038,  0.5381,
-         0.1991,  0.4916, -0.2345,  0.0000, -1.0194, -0.0676,  0.0845,  0.6422,
-         1.0221,  0.0000,  0.1536, -0.5049,  0.1963,  0.1345,  0.6357,  0.0000,
-         0.9959,  0.0000,  0.2940,  1.2841,  0.0000,  0.2500,  0.3658, -0.4920],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2189,  0.3310, -0.4882,  0.0794, -0.1981,  0.1279,  0.0000, -0.4059,
-         0.0193, -0.0941,  0.1590, -0.2535, -0.2767, -0.8280,  0.0000, -1.1934,
-        -0.4723,  0.3684,  0.0000, -0.5362,  0.0614, -0.0202,  0.0000, -0.7693,
-        -0.0163,  1.1220,  0.0362, -0.1654,  0.0517,  1.4793,  0.0916, -1.5940,
-         0.0000, -0.0021, -0.1295, -0.3516,  0.0000,  0.0000,  1.0038,  0.5381,
-         0.1991,  0.4916, -0.2345,  0.0000, -1.0194, -0.0676,  0.0845,  0.6422,
-         1.0221,  0.0000,  0.1536, -0.5049,  0.1963,  0.1345,  0.6357,  0.0000,
-         0.9959,  0.0000,  0.2940,  1.2841,  0.0000,  0.2500,  0.3658, -0.4920],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6823e-01,  2.4932e-01, -5.0622e-01,  4.8536e-02,  2.7289e-02,
-         8.2699e-02,  6.3011e-07, -4.1112e-01,  2.3412e-02, -4.8491e-02,
-         2.1688e-01, -2.1784e-01, -3.1499e-01, -8.1232e-01, -1.1300e-03,
-        -1.1921e+00, -4.1126e-01,  4.2602e-01, -1.9107e-04, -4.9716e-01,
-         1.0292e-01, -4.4837e-02, -4.0577e-04, -7.3375e-01, -6.3962e-03,
-         1.1193e+00,  6.3727e-03, -1.7109e-01, -4.2706e-02,  1.4783e+00,
-         3.8724e-02, -1.6065e+00, -1.5106e-07, -1.5706e-02, -5.4500e-02,
-        -2.9778e-01,  2.5011e-03,  0.0000e+00,  1.0069e+00,  4.8246e-01,
-         2.1752e-01,  4.8798e-01, -2.1480e-01,  6.1555e-04, -1.0053e+00,
-        -5.4059e-02,  1.0883e-01,  6.0521e-01,  1.0072e+00, -3.9188e-05,
-         1.8768e-01, -5.0143e-01,  1.8498e-01, -3.9155e-03,  6.1525e-01,
-         8.1556e-09,  9.9224e-01,  1.1527e-06,  2.3943e-01,  1.2766e+00,
-         3.0792e-06,  2.2606e-01,  3.3816e-01, -4.9383e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1682,  0.2493, -0.5062,  0.0485,  0.0273,  0.0827,  0.0000, -0.4111,
-         0.0234, -0.0485,  0.2169, -0.2178, -0.3150, -0.8123,  0.0000, -1.1921,
-        -0.4113,  0.4260,  0.0000, -0.4972,  0.1029, -0.0448,  0.0000, -0.7338,
-        -0.0064,  1.1193,  0.0064, -0.1711, -0.0427,  1.4783,  0.0387, -1.6065,
-         0.0000, -0.0157,  0.0000, -0.2978,  0.0000,  0.0000,  1.0069,  0.4825,
-         0.2175,  0.4880, -0.2148,  0.0000, -1.0053, -0.0541,  0.1088,  0.6052,
-         1.0072,  0.0000,  0.1877, -0.5014,  0.1850, -0.0039,  0.6153,  0.0000,
-         0.9922,  0.0000,  0.2394,  1.2766,  0.0000,  0.2261,  0.3382, -0.4938],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1682,  0.2493, -0.5062,  0.0485,  0.0273,  0.0827,  0.0000, -0.4111,
-         0.0234, -0.0485,  0.2169, -0.2178, -0.3150, -0.8123,  0.0000, -1.1921,
-        -0.4113,  0.4260,  0.0000, -0.4972,  0.1029, -0.0448,  0.0000, -0.7338,
-        -0.0064,  1.1193,  0.0064, -0.1711, -0.0427,  1.4783,  0.0387, -1.6065,
-         0.0000, -0.0157,  0.0000, -0.2978,  0.0000,  0.0000,  1.0069,  0.4825,
-         0.2175,  0.4880, -0.2148,  0.0000, -1.0053, -0.0541,  0.1088,  0.6052,
-         1.0072,  0.0000,  0.1877, -0.5014,  0.1850, -0.0039,  0.6153,  0.0000,
-         0.9922,  0.0000,  0.2394,  1.2766,  0.0000,  0.2261,  0.3382, -0.4938],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3446e-01,  1.5256e-01, -4.5825e-01,  7.0556e-02,  2.4293e-01,
-         6.2367e-02,  5.3593e-07, -4.4126e-01, -2.5409e-02, -7.6537e-03,
-         2.5336e-01, -1.4038e-01, -2.4451e-01, -7.9405e-01, -9.6108e-04,
-        -1.1958e+00, -4.3450e-01,  4.7246e-01, -1.6251e-04, -4.6013e-01,
-         1.3634e-01, -8.8417e-02, -3.4512e-04, -7.3336e-01, -7.0718e-02,
-         1.1171e+00,  2.8329e-02, -2.3609e-01, -1.3459e-01,  1.4740e+00,
-        -3.7122e-02, -1.6164e+00, -1.2848e-07, -5.7969e-02,  6.3773e-02,
-        -2.1439e-01,  2.1273e-03,  0.0000e+00,  1.0164e+00,  4.4847e-01,
-         2.2650e-01,  4.7697e-01, -1.1380e-01,  5.2355e-04, -1.0026e+00,
-        -1.0419e-01,  1.5495e-01,  5.8041e-01,  9.9253e-01, -3.3330e-05,
-         1.9590e-01, -5.3405e-01,  1.3513e-01, -5.9950e-02,  5.9235e-01,
-         6.9366e-09,  9.8961e-01,  9.8044e-07,  2.1472e-01,  1.2694e+00,
-         2.6190e-06,  2.4879e-01,  2.9015e-01, -4.5651e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1345,  0.1526, -0.4582,  0.0706,  0.2429,  0.0624,  0.0000, -0.4413,
-        -0.0254, -0.0077,  0.2534, -0.1404, -0.2445, -0.7940,  0.0000, -1.1958,
-        -0.4345,  0.4725,  0.0000, -0.4601,  0.1363, -0.0884,  0.0000, -0.7334,
-        -0.0707,  1.1171,  0.0283, -0.2361, -0.1346,  1.4740, -0.0371, -1.6164,
-         0.0000, -0.0580,  0.0000, -0.2144,  0.0000,  0.0000,  1.0164,  0.4485,
-         0.2265,  0.4770, -0.1138,  0.0000, -1.0026, -0.1042,  0.1550,  0.5804,
-         0.9925,  0.0000,  0.1959, -0.5340,  0.1351, -0.0600,  0.5924,  0.0000,
-         0.9896,  0.0000,  0.2147,  1.2694,  0.0000,  0.2488,  0.2902, -0.4565],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1345,  0.1526, -0.4582,  0.0706,  0.2429,  0.0624,  0.0000, -0.4413,
-        -0.0254, -0.0077,  0.2534, -0.1404, -0.2445, -0.7940,  0.0000, -1.1958,
-        -0.4345,  0.4725,  0.0000, -0.4601,  0.1363, -0.0884,  0.0000, -0.7334,
-        -0.0707,  1.1171,  0.0283, -0.2361, -0.1346,  1.4740, -0.0371, -1.6164,
-         0.0000, -0.0580,  0.0000, -0.2144,  0.0000,  0.0000,  1.0164,  0.4485,
-         0.2265,  0.4770, -0.1138,  0.0000, -1.0026, -0.1042,  0.1550,  0.5804,
-         0.9925,  0.0000,  0.1959, -0.5340,  0.1351, -0.0600,  0.5924,  0.0000,
-         0.9896,  0.0000,  0.2147,  1.2694,  0.0000,  0.2488,  0.2902, -0.4565],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3631e-01,  1.3478e-01, -4.4397e-01,  1.0368e-01,  7.2166e-01,
-         6.8913e-02,  4.5586e-07, -4.4755e-01,  3.5794e-02, -1.1683e-03,
-         2.4836e-01, -1.5215e-01, -2.2407e-01, -7.6111e-01, -8.1749e-04,
-        -1.1898e+00, -4.2365e-01,  5.0827e-01, -1.3823e-04, -4.4726e-01,
-         1.8396e-01, -4.5938e-02, -2.9356e-04, -7.1478e-01, -1.8016e-02,
-         1.1059e+00,  5.7891e-02, -1.9684e-01, -2.2225e-01,  1.4731e+00,
-        -1.0174e-01, -1.6264e+00, -1.0929e-07, -1.2148e-02,  5.4245e-02,
-        -1.9544e-01,  1.8095e-03,  0.0000e+00,  9.9276e-01,  3.8274e-01,
-         2.5883e-01,  4.8607e-01, -4.4119e-02,  4.4533e-04, -9.9732e-01,
-        -1.4942e-01,  2.3184e-01,  5.1739e-01,  9.8125e-01, -2.8351e-05,
-         1.8413e-01, -5.3049e-01,  4.4841e-02,  3.6133e-02,  5.5851e-01,
-         5.9003e-09,  9.8711e-01,  8.3396e-07,  2.0635e-01,  1.2667e+00,
-         2.2277e-06,  2.6415e-01,  2.1800e-01, -4.2162e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.3631e-01,  1.3478e-01, -4.4397e-01,  1.0368e-01,  7.2166e-01,
-         6.8913e-02,  0.0000e+00, -4.4755e-01,  3.5794e-02, -1.1683e-03,
-         2.4836e-01, -1.5215e-01, -2.2407e-01, -7.6111e-01,  0.0000e+00,
-        -1.1898e+00, -4.2365e-01,  5.0827e-01,  0.0000e+00, -4.4726e-01,
-         1.8396e-01, -4.5938e-02,  0.0000e+00, -7.1478e-01, -1.8016e-02,
-         1.1059e+00,  5.7891e-02, -1.9684e-01, -2.2225e-01,  1.4731e+00,
-        -1.0174e-01, -1.6264e+00,  0.0000e+00, -1.2148e-02,  0.0000e+00,
-        -1.9544e-01,  0.0000e+00,  0.0000e+00,  9.9276e-01,  3.8274e-01,
-         2.5883e-01,  4.8607e-01, -4.4119e-02,  0.0000e+00, -9.9732e-01,
-        -1.4942e-01,  2.3184e-01,  5.1739e-01,  9.8125e-01,  0.0000e+00,
-         1.8413e-01, -5.3049e-01,  4.4841e-02,  3.6133e-02,  5.5851e-01,
-         0.0000e+00,  9.8711e-01,  0.0000e+00,  2.0635e-01,  1.2667e+00,
-         0.0000e+00,  2.6415e-01,  2.1800e-01, -4.2162e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.3631e-01,  1.3478e-01, -4.4397e-01,  1.0368e-01,  7.2166e-01,
-         6.8913e-02,  0.0000e+00, -4.4755e-01,  3.5794e-02, -1.1683e-03,
-         2.4836e-01, -1.5215e-01, -2.2407e-01, -7.6111e-01,  0.0000e+00,
-        -1.1898e+00, -4.2365e-01,  5.0827e-01,  0.0000e+00, -4.4726e-01,
-         1.8396e-01, -4.5938e-02,  0.0000e+00, -7.1478e-01, -1.8016e-02,
-         1.1059e+00,  5.7891e-02, -1.9684e-01, -2.2225e-01,  1.4731e+00,
-        -1.0174e-01, -1.6264e+00,  0.0000e+00, -1.2148e-02,  0.0000e+00,
-        -1.9544e-01,  0.0000e+00,  0.0000e+00,  9.9276e-01,  3.8274e-01,
-         2.5883e-01,  4.8607e-01, -4.4119e-02,  0.0000e+00, -9.9732e-01,
-        -1.4942e-01,  2.3184e-01,  5.1739e-01,  9.8125e-01,  0.0000e+00,
-         1.8413e-01, -5.3049e-01,  4.4841e-02,  3.6133e-02,  5.5851e-01,
-         0.0000e+00,  9.8711e-01,  0.0000e+00,  2.0635e-01,  1.2667e+00,
-         0.0000e+00,  2.6415e-01,  2.1800e-01, -4.2162e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6583e-01,  1.2397e-01, -4.4260e-01,  1.3279e-01,  1.0911e+00,
-         2.5085e-02,  3.8779e-07, -4.4176e-01,  1.1504e-01, -1.5695e-02,
-         2.0457e-01, -1.8718e-01, -2.2734e-01, -7.4127e-01, -6.9542e-04,
-        -1.1871e+00, -4.1626e-01,  5.2450e-01, -1.1759e-04, -4.1688e-01,
-         2.0424e-01,  7.2009e-02, -2.4972e-04, -7.1322e-01,  1.0383e-01,
-         1.1043e+00,  4.1592e-02, -8.3131e-02, -2.7766e-01,  1.4741e+00,
-        -1.5271e-01, -1.6341e+00, -9.2968e-08,  5.4633e-02,  4.6145e-02,
-        -1.9134e-01,  1.5393e-03,  0.0000e+00,  9.9243e-01,  3.7926e-01,
-         2.9846e-01,  4.7983e-01, -1.9601e-02,  3.7883e-04, -1.0101e+00,
-        -2.2535e-01,  2.7962e-01,  4.5148e-01,  9.7706e-01, -2.4117e-05,
-         1.9067e-01, -4.7775e-01, -6.9137e-02,  1.4932e-01,  5.3364e-01,
-         5.0192e-09,  9.8377e-01,  7.0943e-07,  2.0043e-01,  1.2634e+00,
-         1.8951e-06,  2.0444e-01,  1.6510e-01, -3.9053e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1658,  0.1240, -0.4426,  0.1328,  1.0911,  0.0251,  0.0000, -0.4418,
-         0.1150, -0.0157,  0.2046, -0.1872, -0.2273, -0.7413,  0.0000, -1.1871,
-        -0.4163,  0.5245,  0.0000, -0.4169,  0.2042,  0.0720,  0.0000, -0.7132,
-         0.1038,  1.1043,  0.0416, -0.0831, -0.2777,  1.4741, -0.1527, -1.6341,
-         0.0000,  0.0546,  0.0000, -0.1913,  0.0000,  0.0000,  0.9924,  0.3793,
-         0.2985,  0.4798, -0.0196,  0.0000, -1.0101, -0.2254,  0.2796,  0.4515,
-         0.9771,  0.0000,  0.1907, -0.4777, -0.0691,  0.1493,  0.5336,  0.0000,
-         0.9838,  0.0000,  0.2004,  1.2634,  0.0000,  0.2044,  0.1651, -0.3905],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1658,  0.1240, -0.4426,  0.1328,  1.0911,  0.0251,  0.0000, -0.4418,
-         0.1150, -0.0157,  0.2046, -0.1872, -0.2273, -0.7413,  0.0000, -1.1871,
-        -0.4163,  0.5245,  0.0000, -0.4169,  0.2042,  0.0720,  0.0000, -0.7132,
-         0.1038,  1.1043,  0.0416, -0.0831, -0.2777,  1.4741, -0.1527, -1.6341,
-         0.0000,  0.0546,  0.0000, -0.1913,  0.0000,  0.0000,  0.9924,  0.3793,
-         0.2985,  0.4798, -0.0196,  0.0000, -1.0101, -0.2254,  0.2796,  0.4515,
-         0.9771,  0.0000,  0.1907, -0.4777, -0.0691,  0.1493,  0.5336,  0.0000,
-         0.9838,  0.0000,  0.2004,  1.2634,  0.0000,  0.2044,  0.1651, -0.3905],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1900e-01,  5.2798e-02, -3.5580e-01,  1.8259e-01,  1.3894e+00,
-         2.4168e-02,  3.2991e-07, -4.7006e-01,  1.6102e-01, -4.2962e-02,
-         1.6505e-01, -2.4828e-01, -1.5037e-01, -7.3758e-01, -5.9163e-04,
-        -1.1924e+00, -4.8332e-01,  5.2812e-01, -1.0004e-04, -4.2177e-01,
-         1.8140e-01,  1.5832e-01, -2.1245e-04, -7.5679e-01,  1.8709e-01,
-         1.1161e+00,  4.6190e-02, -2.8562e-02, -2.9361e-01,  1.4744e+00,
-        -2.7806e-01, -1.6374e+00, -7.9093e-08,  3.5944e-02,  3.9258e-02,
-        -2.0798e-01,  1.3096e-03,  0.0000e+00,  1.0250e+00,  4.5422e-01,
-         3.1980e-01,  4.5691e-01,  1.3818e-02,  3.2229e-04, -1.0282e+00,
-        -3.2373e-01,  3.1616e-01,  4.0764e-01,  9.7834e-01, -2.0518e-05,
-         2.0336e-01, -4.7354e-01, -1.4257e-01,  1.9826e-01,  5.3389e-01,
-         4.2701e-09,  9.8012e-01,  6.0355e-07,  1.2736e-01,  1.2593e+00,
-         1.6122e-06,  1.1002e-01,  1.2529e-01, -3.3120e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2190,  0.0528, -0.3558,  0.1826,  1.3894,  0.0242,  0.0000, -0.4701,
-         0.1610, -0.0430,  0.1651, -0.2483, -0.1504, -0.7376,  0.0000, -1.1924,
-        -0.4833,  0.5281,  0.0000, -0.4218,  0.1814,  0.1583,  0.0000, -0.7568,
-         0.1871,  1.1161,  0.0462, -0.0286, -0.2936,  0.0000, -0.2781, -1.6374,
-         0.0000,  0.0359,  0.0000, -0.2080,  0.0000,  0.0000,  1.0250,  0.4542,
-         0.3198,  0.4569,  0.0138,  0.0000, -1.0282, -0.3237,  0.3162,  0.4076,
-         0.9783,  0.0000,  0.2034, -0.4735, -0.1426,  0.1983,  0.5339,  0.0000,
-         0.9801,  0.0000,  0.1274,  1.2593,  0.0000,  0.1100,  0.1253, -0.3312],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2190,  0.0528, -0.3558,  0.1826,  1.3894,  0.0242,  0.0000, -0.4701,
-         0.1610, -0.0430,  0.1651, -0.2483, -0.1504, -0.7376,  0.0000, -1.1924,
-        -0.4833,  0.5281,  0.0000, -0.4218,  0.1814,  0.1583,  0.0000, -0.7568,
-         0.1871,  1.1161,  0.0462, -0.0286, -0.2936,  0.0000, -0.2781, -1.6374,
-         0.0000,  0.0359,  0.0000, -0.2080,  0.0000,  0.0000,  1.0250,  0.4542,
-         0.3198,  0.4569,  0.0138,  0.0000, -1.0282, -0.3237,  0.3162,  0.4076,
-         0.9783,  0.0000,  0.2034, -0.4735, -0.1426,  0.1983,  0.5339,  0.0000,
-         0.9801,  0.0000,  0.1274,  1.2593,  0.0000,  0.1100,  0.1253, -0.3312],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0415e-01,  7.4420e-03, -3.2425e-01,  1.9624e-01,  1.6281e+00,
-         8.8227e-02,  2.8070e-07, -5.0675e-01,  1.1959e-01, -1.1284e-01,
-         8.7176e-02, -3.0875e-01, -1.4085e-01, -7.3765e-01, -5.0338e-04,
-        -1.1975e+00, -4.9939e-01,  5.1279e-01, -8.5119e-05, -4.7277e-01,
-         8.1324e-02,  1.1157e-01, -1.8076e-04, -7.4715e-01,  1.5967e-01,
-         1.1260e+00,  8.6875e-02, -5.5487e-02, -2.8028e-01,  2.6289e-04,
-        -3.9827e-01, -1.6367e+00, -6.7295e-08, -4.9916e-02,  3.3402e-02,
-        -2.5813e-01,  1.1142e-03,  0.0000e+00,  1.0368e+00,  4.6072e-01,
-         2.8860e-01,  4.1323e-01,  5.3528e-02,  2.7422e-04, -1.0534e+00,
-        -3.4397e-01,  3.0453e-01,  3.6023e-01,  9.5915e-01, -1.7457e-05,
-         1.9155e-01, -5.3709e-01, -1.4238e-01,  2.1263e-01,  5.0316e-01,
-         3.6331e-09,  9.7900e-01,  5.1352e-07,  6.9916e-02,  1.2494e+00,
-         1.3717e-06,  1.2472e-01,  6.7086e-02, -2.7898e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3042,  0.0074, -0.3242,  0.1962,  1.6281,  0.0882,  0.0000, -0.5067,
-         0.1196, -0.1128,  0.0872, -0.3088, -0.1409, -0.7376,  0.0000, -1.1975,
-        -0.4994,  0.5128,  0.0000, -0.4728,  0.0813,  0.1116,  0.0000, -0.7472,
-         0.1597,  1.1260,  0.0869, -0.0555, -0.2803,  0.0000, -0.3983, -1.6367,
-         0.0000, -0.0499,  0.0000, -0.2581,  0.0000,  0.0000,  1.0368,  0.4607,
-         0.2886,  0.4132,  0.0535,  0.0000, -1.0534, -0.3440,  0.3045,  0.3602,
-         0.9591,  0.0000,  0.1916, -0.5371, -0.1424,  0.2126,  0.5032,  0.0000,
-         0.9790,  0.0000,  0.0699,  1.2494,  0.0000,  0.1247,  0.0671, -0.2790],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3042,  0.0074, -0.3242,  0.1962,  1.6281,  0.0882,  0.0000, -0.5067,
-         0.1196, -0.1128,  0.0872, -0.3088, -0.1409, -0.7376,  0.0000, -1.1975,
-        -0.4994,  0.5128,  0.0000, -0.4728,  0.0813,  0.1116,  0.0000, -0.7472,
-         0.1597,  1.1260,  0.0869, -0.0555, -0.2803,  0.0000, -0.3983, -1.6367,
-         0.0000, -0.0499,  0.0000, -0.2581,  0.0000,  0.0000,  1.0368,  0.4607,
-         0.2886,  0.4132,  0.0535,  0.0000, -1.0534, -0.3440,  0.3045,  0.3602,
-         0.9591,  0.0000,  0.1916, -0.5371, -0.1424,  0.2126,  0.5032,  0.0000,
-         0.9790,  0.0000,  0.0699,  1.2494,  0.0000,  0.1247,  0.0671, -0.2790],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4086e-01,  5.4015e-03, -3.6387e-01,  1.0157e-01,  1.8300e+00,
-         1.2201e-01,  2.3885e-07, -5.3257e-01,  4.2014e-02, -2.0836e-01,
-         5.4233e-04, -3.5044e-01, -1.9528e-01, -7.3592e-01, -4.2833e-04,
-        -1.1988e+00, -4.8245e-01,  4.9522e-01, -7.2429e-05, -5.4065e-01,
-        -3.7614e-02,  1.1575e-03, -1.5381e-04, -7.3150e-01,  9.9349e-02,
-         1.1403e+00,  8.0624e-02, -1.1415e-01, -2.2133e-01,  2.2370e-04,
-        -4.7816e-01, -1.6330e+00, -5.7262e-08, -1.4431e-01,  2.8422e-02,
-        -3.0511e-01,  9.4810e-04,  0.0000e+00,  1.0416e+00,  4.8529e-01,
-         2.3109e-01,  3.5306e-01,  4.6353e-02,  2.3334e-04, -1.0707e+00,
-        -3.2892e-01,  3.0438e-01,  2.8920e-01,  9.4983e-01, -1.4855e-05,
-         1.5028e-01, -6.2934e-01, -7.0323e-02,  1.3862e-01,  4.9088e-01,
-         3.0915e-09,  9.7782e-01,  4.3696e-07,  7.4061e-02,  1.2442e+00,
-         1.1672e-06,  1.1241e-01, -3.9282e-02, -2.6930e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.4086e-01,  5.4015e-03, -3.6387e-01,  1.0157e-01,  1.8300e+00,
-         1.2201e-01,  0.0000e+00, -5.3257e-01,  4.2014e-02, -2.0836e-01,
-         5.4233e-04, -3.5044e-01, -1.9528e-01, -7.3592e-01,  0.0000e+00,
-        -1.1988e+00, -4.8245e-01,  4.9522e-01,  0.0000e+00, -5.4065e-01,
-        -3.7614e-02,  1.1575e-03,  0.0000e+00, -7.3150e-01,  9.9349e-02,
-         1.1403e+00,  8.0624e-02, -1.1415e-01, -2.2133e-01,  0.0000e+00,
-        -4.7816e-01, -1.6330e+00,  0.0000e+00, -1.4431e-01,  0.0000e+00,
-        -3.0511e-01,  0.0000e+00,  0.0000e+00,  1.0416e+00,  4.8529e-01,
-         2.3109e-01,  3.5306e-01,  4.6353e-02,  0.0000e+00, -1.0707e+00,
-        -3.2892e-01,  3.0438e-01,  2.8920e-01,  9.4983e-01,  0.0000e+00,
-         1.5028e-01, -6.2934e-01, -7.0323e-02,  1.3862e-01,  4.9088e-01,
-         0.0000e+00,  9.7782e-01,  0.0000e+00,  7.4061e-02,  1.2442e+00,
-         0.0000e+00,  1.1241e-01, -3.9282e-02, -2.6930e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.4086e-01,  5.4015e-03, -3.6387e-01,  1.0157e-01,  1.8300e+00,
-         1.2201e-01,  0.0000e+00, -5.3257e-01,  4.2014e-02, -2.0836e-01,
-         5.4233e-04, -3.5044e-01, -1.9528e-01, -7.3592e-01,  0.0000e+00,
-        -1.1988e+00, -4.8245e-01,  4.9522e-01,  0.0000e+00, -5.4065e-01,
-        -3.7614e-02,  1.1575e-03,  0.0000e+00, -7.3150e-01,  9.9349e-02,
-         1.1403e+00,  8.0624e-02, -1.1415e-01, -2.2133e-01,  0.0000e+00,
-        -4.7816e-01, -1.6330e+00,  0.0000e+00, -1.4431e-01,  0.0000e+00,
-        -3.0511e-01,  0.0000e+00,  0.0000e+00,  1.0416e+00,  4.8529e-01,
-         2.3109e-01,  3.5306e-01,  4.6353e-02,  0.0000e+00, -1.0707e+00,
-        -3.2892e-01,  3.0438e-01,  2.8920e-01,  9.4983e-01,  0.0000e+00,
-         1.5028e-01, -6.2934e-01, -7.0323e-02,  1.3862e-01,  4.9088e-01,
-         0.0000e+00,  9.7782e-01,  0.0000e+00,  7.4061e-02,  1.2442e+00,
-         0.0000e+00,  1.1241e-01, -3.9282e-02, -2.6930e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3909e-01,  4.2286e-02, -4.2047e-01,  1.1162e-02,  1.9997e+00,
-         9.6435e-02,  2.0327e-07, -5.4634e-01,  2.5081e-02, -2.7292e-01,
-        -6.4461e-02, -3.5200e-01, -2.7293e-01, -7.2464e-01, -3.6451e-04,
-        -1.1986e+00, -4.8132e-01,  5.0541e-01, -6.1637e-05, -5.8542e-01,
-        -9.9705e-02, -8.8148e-03, -1.3089e-04, -7.2482e-01,  1.3656e-01,
-         1.1437e+00,  1.3523e-02, -1.1533e-01, -1.7721e-01,  1.9037e-04,
-        -5.5571e-01, -1.6329e+00, -4.8730e-08, -1.5270e-01,  2.4187e-02,
-        -3.4839e-01,  8.0683e-04,  0.0000e+00,  1.0440e+00,  5.0701e-01,
-         2.2123e-01,  3.0777e-01,  7.9101e-03,  1.9857e-04, -1.0720e+00,
-        -2.6580e-01,  3.5074e-01,  2.0285e-01,  9.7042e-01, -1.2641e-05,
-         1.3243e-01, -6.7739e-01, -2.4416e-02,  2.0713e-02,  4.9084e-01,
-         2.6309e-09,  9.7478e-01,  3.7186e-07,  7.1613e-02,  1.2405e+00,
-         9.9331e-07, -7.8198e-04, -8.1368e-02, -3.3318e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.3909e-01,  4.2286e-02, -4.2047e-01,  1.1162e-02,  1.9997e+00,
-         9.6435e-02,  0.0000e+00, -5.4634e-01,  2.5081e-02, -2.7292e-01,
-        -6.4461e-02, -3.5200e-01, -2.7293e-01, -7.2464e-01,  0.0000e+00,
-        -1.1986e+00, -4.8132e-01,  5.0541e-01,  0.0000e+00, -5.8542e-01,
-        -9.9705e-02, -8.8148e-03,  0.0000e+00, -7.2482e-01,  1.3656e-01,
-         1.1437e+00,  1.3523e-02, -1.1533e-01, -1.7721e-01,  0.0000e+00,
-        -5.5571e-01, -1.6329e+00,  0.0000e+00, -1.5270e-01,  0.0000e+00,
-        -3.4839e-01,  0.0000e+00,  0.0000e+00,  1.0440e+00,  5.0701e-01,
-         2.2123e-01,  3.0777e-01,  7.9101e-03,  0.0000e+00, -1.0720e+00,
-        -2.6580e-01,  3.5074e-01,  2.0285e-01,  9.7042e-01,  0.0000e+00,
-         1.3243e-01, -6.7739e-01, -2.4416e-02,  2.0713e-02,  4.9084e-01,
-         0.0000e+00,  9.7478e-01,  0.0000e+00,  7.1613e-02,  1.2405e+00,
-         0.0000e+00, -7.8198e-04, -8.1368e-02, -3.3318e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.3909e-01,  4.2286e-02, -4.2047e-01,  1.1162e-02,  1.9997e+00,
-         9.6435e-02,  0.0000e+00, -5.4634e-01,  2.5081e-02, -2.7292e-01,
-        -6.4461e-02, -3.5200e-01, -2.7293e-01, -7.2464e-01,  0.0000e+00,
-        -1.1986e+00, -4.8132e-01,  5.0541e-01,  0.0000e+00, -5.8542e-01,
-        -9.9705e-02, -8.8148e-03,  0.0000e+00, -7.2482e-01,  1.3656e-01,
-         1.1437e+00,  1.3523e-02, -1.1533e-01, -1.7721e-01,  0.0000e+00,
-        -5.5571e-01, -1.6329e+00,  0.0000e+00, -1.5270e-01,  0.0000e+00,
-        -3.4839e-01,  0.0000e+00,  0.0000e+00,  1.0440e+00,  5.0701e-01,
-         2.2123e-01,  3.0777e-01,  7.9101e-03,  0.0000e+00, -1.0720e+00,
-        -2.6580e-01,  3.5074e-01,  2.0285e-01,  9.7042e-01,  0.0000e+00,
-         1.3243e-01, -6.7739e-01, -2.4416e-02,  2.0713e-02,  4.9084e-01,
-         0.0000e+00,  9.7478e-01,  0.0000e+00,  7.1613e-02,  1.2405e+00,
-         0.0000e+00, -7.8198e-04, -8.1368e-02, -3.3318e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1839e-01,  4.9494e-02, -4.3702e-01, -9.2403e-02,  2.1408e+00,
-         5.2306e-02,  1.7300e-07, -5.4616e-01,  1.7769e-02, -2.9559e-01,
-        -1.0003e-01, -3.3665e-01, -2.9175e-01, -7.0869e-01, -3.1023e-04,
-        -1.1973e+00, -4.7154e-01,  4.9720e-01, -5.2459e-05, -6.1184e-01,
-        -1.5541e-01, -2.9958e-02, -1.1140e-04, -7.0071e-01,  1.2534e-01,
-         1.1459e+00, -4.0897e-02, -1.3729e-01, -1.3579e-01,  1.6202e-04,
-        -6.1646e-01, -1.6324e+00, -4.1474e-08, -1.4651e-01,  2.0586e-02,
-        -3.7422e-01,  6.8669e-04,  0.0000e+00,  1.0397e+00,  5.0325e-01,
-         1.9470e-01,  2.6690e-01, -1.7530e-02,  1.6900e-04, -1.0752e+00,
-        -2.1830e-01,  3.9894e-01,  1.3599e-01,  9.8972e-01, -1.0759e-05,
-         1.1556e-01, -6.9798e-01,  1.6282e-02, -1.3618e-01,  4.7704e-01,
-         2.2391e-09,  9.7017e-01,  3.1648e-07,  1.2185e-01,  1.2381e+00,
-         8.4540e-07, -9.9388e-02, -1.2542e-01, -3.8339e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3184,  0.0495, -0.4370, -0.0924,  2.1408,  0.0523,  0.0000, -0.5462,
-         0.0178, -0.2956, -0.1000, -0.3367, -0.2917, -0.7087,  0.0000, -1.1973,
-        -0.4715,  0.4972,  0.0000, -0.6118, -0.1554, -0.0300,  0.0000, -0.7007,
-         0.1253,  1.1459, -0.0409, -0.1373, -0.1358,  0.0000, -0.6165, -1.6324,
-         0.0000, -0.1465,  0.0000, -0.3742,  0.0000,  0.0000,  1.0397,  0.5032,
-         0.1947,  0.2669, -0.0175,  0.0000, -1.0752, -0.2183,  0.3989,  0.1360,
-         0.9897,  0.0000,  0.1156, -0.6980,  0.0163, -0.1362,  0.4770,  0.0000,
-         0.9702,  0.0000,  0.1219,  1.2381,  0.0000, -0.0994, -0.1254, -0.3834],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3184,  0.0495, -0.4370, -0.0924,  2.1408,  0.0523,  0.0000, -0.5462,
-         0.0178, -0.2956, -0.1000, -0.3367, -0.2917, -0.7087,  0.0000, -1.1973,
-        -0.4715,  0.4972,  0.0000, -0.6118, -0.1554, -0.0300,  0.0000, -0.7007,
-         0.1253,  1.1459, -0.0409, -0.1373, -0.1358,  0.0000, -0.6165, -1.6324,
-         0.0000, -0.1465,  0.0000, -0.3742,  0.0000,  0.0000,  1.0397,  0.5032,
-         0.1947,  0.2669, -0.0175,  0.0000, -1.0752, -0.2183,  0.3989,  0.1360,
-         0.9897,  0.0000,  0.1156, -0.6980,  0.0163, -0.1362,  0.4770,  0.0000,
-         0.9702,  0.0000,  0.1219,  1.2381,  0.0000, -0.0994, -0.1254, -0.3834],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3846e-01,  8.5701e-02, -3.8076e-01, -1.5493e-01,  2.2583e+00,
-         2.9116e-02,  1.4725e-07, -4.9830e-01,  1.2159e-02, -2.5460e-01,
-        -6.4146e-02, -3.4491e-01, -2.0445e-01, -6.9929e-01, -2.6407e-04,
-        -1.2032e+00, -4.9085e-01,  4.5801e-01, -4.4652e-05, -6.4056e-01,
-        -1.9054e-01,  3.3600e-02, -9.4825e-05, -6.8600e-01,  1.3122e-01,
-         1.1501e+00, -3.9226e-02, -1.5371e-01, -1.2543e-01,  1.3791e-04,
-        -6.6067e-01, -1.6262e+00, -3.5302e-08, -1.4764e-01,  1.7522e-02,
-        -3.6687e-01,  5.8450e-04,  0.0000e+00,  1.0379e+00,  5.1756e-01,
-         1.9051e-01,  2.5289e-01,  1.0016e-02,  1.4385e-04, -1.0831e+00,
-        -1.6463e-01,  4.5604e-01,  1.2719e-01,  1.0012e+00, -9.1579e-06,
-         1.3146e-01, -7.1119e-01, -1.3704e-02, -9.0878e-02,  5.1530e-01,
-         1.9059e-09,  9.6498e-01,  2.6939e-07,  1.6945e-01,  1.2336e+00,
-         7.1960e-07, -1.0596e-01, -9.9868e-02, -3.4194e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3385,  0.0857, -0.3808, -0.1549,  2.2583,  0.0291,  0.0000, -0.4983,
-         0.0122, -0.2546, -0.0641, -0.3449, -0.2044, -0.6993,  0.0000, -1.2032,
-         0.0000,  0.4580,  0.0000, -0.6406, -0.1905,  0.0336,  0.0000, -0.6860,
-         0.1312,  1.1501, -0.0392, -0.1537, -0.1254,  0.0000, -0.6607, -1.6262,
-         0.0000, -0.1476,  0.0000, -0.3669,  0.0000,  0.0000,  1.0379,  0.5176,
-         0.1905,  0.2529,  0.0100,  0.0000, -1.0831, -0.1646,  0.4560,  0.1272,
-         1.0012,  0.0000,  0.1315, -0.7112, -0.0137, -0.0909,  0.5153,  0.0000,
-         0.9650,  0.0000,  0.1694,  1.2336,  0.0000, -0.1060, -0.0999, -0.3419],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3385,  0.0857, -0.3808, -0.1549,  2.2583,  0.0291,  0.0000, -0.4983,
-         0.0122, -0.2546, -0.0641, -0.3449, -0.2044, -0.6993,  0.0000, -1.2032,
-         0.0000,  0.4580,  0.0000, -0.6406, -0.1905,  0.0336,  0.0000, -0.6860,
-         0.1312,  1.1501, -0.0392, -0.1537, -0.1254,  0.0000, -0.6607, -1.6262,
-         0.0000, -0.1476,  0.0000, -0.3669,  0.0000,  0.0000,  1.0379,  0.5176,
-         0.1905,  0.2529,  0.0100,  0.0000, -1.0831, -0.1646,  0.4560,  0.1272,
-         1.0012,  0.0000,  0.1315, -0.7112, -0.0137, -0.0909,  0.5153,  0.0000,
-         0.9650,  0.0000,  0.1694,  1.2336,  0.0000, -0.1060, -0.0999, -0.3419],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6194e-01,  1.1277e-01, -3.0877e-01, -1.9019e-01,  2.3539e+00,
-         5.3199e-02,  1.2535e-07, -4.6421e-01, -2.6619e-02, -2.0506e-01,
-        -3.1760e-03, -3.1940e-01, -8.3008e-02, -7.0362e-01, -2.2480e-04,
-        -1.2106e+00, -1.6435e-02,  4.1444e-01, -3.8012e-05, -6.6979e-01,
-        -2.2993e-01,  1.0003e-01, -8.0723e-05, -7.1550e-01,  1.3287e-01,
-         1.1521e+00, -3.6087e-03, -1.7563e-01, -1.9636e-01,  1.1740e-04,
-        -6.6187e-01, -1.6196e+00, -3.0052e-08, -1.9940e-01,  1.4916e-02,
-        -3.3522e-01,  4.9758e-04,  0.0000e+00,  1.0336e+00,  5.2806e-01,
-         1.7454e-01,  2.4990e-01,  6.4877e-02,  1.2246e-04, -1.0895e+00,
-        -8.0136e-02,  5.0763e-01,  1.4117e-01,  9.9709e-01, -7.7960e-06,
-         1.5494e-01, -7.5460e-01, -2.3992e-02,  1.9358e-02,  5.4570e-01,
-         1.6225e-09,  9.6025e-01,  2.2933e-07,  2.1642e-01,  1.2384e+00,
-         6.1258e-07, -3.7178e-02, -3.9137e-02, -2.5710e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3619,  0.1128, -0.3088, -0.1902,  2.3539,  0.0532,  0.0000, -0.4642,
-        -0.0266, -0.2051, -0.0032, -0.3194, -0.0830, -0.7036,  0.0000, -1.2106,
-         0.0000,  0.4144,  0.0000, -0.6698, -0.2299,  0.1000,  0.0000, -0.7155,
-         0.1329,  1.1521, -0.0036, -0.1756, -0.1964,  0.0000, -0.6619, -1.6196,
-         0.0000, -0.1994,  0.0000, -0.3352,  0.0000,  0.0000,  1.0336,  0.5281,
-         0.1745,  0.2499,  0.0649,  0.0000, -1.0895, -0.0801,  0.5076,  0.1412,
-         0.9971,  0.0000,  0.1549, -0.7546, -0.0240,  0.0194,  0.5457,  0.0000,
-         0.9602,  0.0000,  0.2164,  1.2384,  0.0000, -0.0372, -0.0391, -0.2571],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3619,  0.1128, -0.3088, -0.1902,  2.3539,  0.0532,  0.0000, -0.4642,
-        -0.0266, -0.2051, -0.0032, -0.3194, -0.0830, -0.7036,  0.0000, -1.2106,
-         0.0000,  0.4144,  0.0000, -0.6698, -0.2299,  0.1000,  0.0000, -0.7155,
-         0.1329,  1.1521, -0.0036, -0.1756, -0.1964,  0.0000, -0.6619, -1.6196,
-         0.0000, -0.1994,  0.0000, -0.3352,  0.0000,  0.0000,  1.0336,  0.5281,
-         0.1745,  0.2499,  0.0649,  0.0000, -1.0895, -0.0801,  0.5076,  0.1412,
-         0.9971,  0.0000,  0.1549, -0.7546, -0.0240,  0.0194,  0.5457,  0.0000,
-         0.9602,  0.0000,  0.2164,  1.2384,  0.0000, -0.0372, -0.0391, -0.2571],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6533e-01,  1.6307e-01, -2.5847e-01, -2.2187e-01,  2.4372e+00,
-         4.2116e-02,  1.0673e-07, -4.0954e-01, -3.9911e-02, -2.0979e-01,
-         2.1614e-02, -2.6262e-01,  1.3342e-02, -6.7814e-01, -1.9139e-04,
-        -1.2060e+00, -1.3992e-02,  3.9265e-01, -3.2363e-05, -6.8025e-01,
-        -2.3821e-01,  1.1741e-01, -6.8727e-05, -7.3586e-01,  1.1249e-01,
-         1.1511e+00,  6.4921e-03, -1.7071e-01, -2.6219e-01,  9.9953e-05,
-        -6.6007e-01, -1.6137e+00, -2.5586e-08, -1.9484e-01,  1.2700e-02,
-        -2.9337e-01,  4.2363e-04,  0.0000e+00,  1.0189e+00,  5.2192e-01,
-         1.5995e-01,  2.4062e-01,  1.1596e-01,  1.0426e-04, -1.0960e+00,
-         9.5587e-03,  5.5293e-01,  1.2229e-01,  9.9635e-01, -6.6374e-06,
-         1.7653e-01, -7.8136e-01, -2.5412e-02,  1.0333e-01,  5.6480e-01,
-         1.3813e-09,  9.5769e-01,  1.9524e-07,  2.4322e-01,  1.2378e+00,
-         5.2154e-07,  1.2371e-02,  3.7374e-03, -1.9089e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3653,  0.1631, -0.2585, -0.2219,  2.4372,  0.0421,  0.0000, -0.4095,
-        -0.0399, -0.2098,  0.0216, -0.2626,  0.0133, -0.6781,  0.0000, -1.2060,
-         0.0000,  0.3926,  0.0000, -0.6803, -0.2382,  0.1174,  0.0000, -0.7359,
-         0.1125,  1.1511,  0.0065, -0.1707, -0.2622,  0.0000, -0.6601, -1.6137,
-         0.0000, -0.1948,  0.0000, -0.2934,  0.0000,  0.0000,  1.0189,  0.5219,
-         0.1600,  0.2406,  0.1160,  0.0000, -1.0960,  0.0096,  0.5529,  0.1223,
-         0.9963,  0.0000,  0.1765, -0.7814, -0.0254,  0.1033,  0.5648,  0.0000,
-         0.9577,  0.0000,  0.2432,  1.2378,  0.0000,  0.0124,  0.0037, -0.1909],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3653,  0.1631, -0.2585, -0.2219,  2.4372,  0.0421,  0.0000, -0.4095,
-        -0.0399, -0.2098,  0.0216, -0.2626,  0.0133, -0.6781,  0.0000, -1.2060,
-         0.0000,  0.3926,  0.0000, -0.6803, -0.2382,  0.1174,  0.0000, -0.7359,
-         0.1125,  1.1511,  0.0065, -0.1707, -0.2622,  0.0000, -0.6601, -1.6137,
-         0.0000, -0.1948,  0.0000, -0.2934,  0.0000,  0.0000,  1.0189,  0.5219,
-         0.1600,  0.2406,  0.1160,  0.0000, -1.0960,  0.0096,  0.5529,  0.1223,
-         0.9963,  0.0000,  0.1765, -0.7814, -0.0254,  0.1033,  0.5648,  0.0000,
-         0.9577,  0.0000,  0.2432,  1.2378,  0.0000,  0.0124,  0.0037, -0.1909],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2938e-01,  2.4265e-01, -2.8657e-01, -2.7689e-01,  2.5049e+00,
-        -4.0681e-03,  9.0875e-08, -3.5372e-01, -3.7643e-02, -2.5501e-01,
-         1.1189e-01, -2.3593e-01,  4.9567e-02, -6.4470e-01, -1.6297e-04,
-        -1.2033e+00, -1.1914e-02,  3.6899e-01, -2.7557e-05, -6.8378e-01,
-        -2.5067e-01,  1.6948e-01, -5.8520e-05, -7.2651e-01,  1.4675e-01,
-         1.1527e+00, -4.5062e-02, -1.0511e-01, -2.9978e-01,  8.5109e-05,
-        -6.8064e-01, -1.6080e+00, -2.1786e-08, -1.4734e-01,  1.0814e-02,
-        -2.7368e-01,  3.6072e-04,  0.0000e+00,  1.0006e+00,  4.9126e-01,
-         1.8493e-01,  2.5698e-01,  1.0678e-01,  8.8776e-05, -1.1010e+00,
-         7.0966e-02,  5.9990e-01,  7.6303e-02,  1.0003e+00, -5.6517e-06,
-         2.3084e-01, -7.5335e-01, -5.8585e-03,  1.1081e-01,  5.5174e-01,
-         1.1762e-09,  9.5664e-01,  1.6625e-07,  2.5351e-01,  1.2324e+00,
-         4.4409e-07, -5.3832e-02,  6.6364e-02, -1.8481e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3294,  0.2426, -0.2866, -0.2769,  2.5049, -0.0041,  0.0000, -0.3537,
-        -0.0376, -0.2550,  0.1119, -0.2359,  0.0496, -0.6447,  0.0000, -1.2033,
-         0.0000,  0.3690,  0.0000, -0.6838, -0.2507,  0.1695,  0.0000, -0.7265,
-         0.1467,  1.1527, -0.0451, -0.1051, -0.2998,  0.0000, -0.6806, -1.6080,
-         0.0000, -0.1473,  0.0000, -0.2737,  0.0000,  0.0000,  1.0006,  0.4913,
-         0.1849,  0.2570,  0.1068,  0.0000, -1.1010,  0.0710,  0.5999,  0.0763,
-         1.0003,  0.0000,  0.2308, -0.7533, -0.0059,  0.1108,  0.5517,  0.0000,
-         0.9566,  0.0000,  0.2535,  1.2324,  0.0000, -0.0538,  0.0664, -0.1848],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3294,  0.2426, -0.2866, -0.2769,  2.5049, -0.0041,  0.0000, -0.3537,
-        -0.0376, -0.2550,  0.1119, -0.2359,  0.0496, -0.6447,  0.0000, -1.2033,
-         0.0000,  0.3690,  0.0000, -0.6838, -0.2507,  0.1695,  0.0000, -0.7265,
-         0.1467,  1.1527, -0.0451, -0.1051, -0.2998,  0.0000, -0.6806, -1.6080,
-         0.0000, -0.1473,  0.0000, -0.2737,  0.0000,  0.0000,  1.0006,  0.4913,
-         0.1849,  0.2570,  0.1068,  0.0000, -1.1010,  0.0710,  0.5999,  0.0763,
-         1.0003,  0.0000,  0.2308, -0.7533, -0.0059,  0.1108,  0.5517,  0.0000,
-         0.9566,  0.0000,  0.2535,  1.2324,  0.0000, -0.0538,  0.0664, -0.1848],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8029e-01,  2.9624e-01, -3.3060e-01, -3.3703e-01,  2.5575e+00,
-        -1.9141e-02,  7.7389e-08, -3.5035e-01, -3.0959e-02, -2.8140e-01,
-         1.3617e-01, -2.2096e-01,  4.9927e-02, -6.2027e-01, -1.3878e-04,
-        -1.2006e+00, -1.0146e-02,  3.1638e-01, -2.3467e-05, -6.9843e-01,
-        -2.6589e-01,  1.4600e-01, -4.9836e-05, -6.9897e-01,  1.3408e-01,
-         1.1594e+00, -6.9759e-02, -7.6027e-02, -3.5340e-01,  7.2479e-05,
-        -6.9241e-01, -1.6017e+00, -1.8553e-08, -1.0877e-01,  9.2089e-03,
-        -2.6793e-01,  3.0719e-04,  0.0000e+00,  9.8183e-01,  4.5764e-01,
-         1.4231e-01,  2.5142e-01,  9.7194e-02,  7.5601e-05, -1.1013e+00,
-         1.9355e-01,  6.2879e-01,  3.6543e-02,  9.9136e-01, -4.8129e-06,
-         2.3380e-01, -7.3635e-01,  4.1772e-02,  1.0736e-01,  4.9048e-01,
-         1.0017e-09,  9.5624e-01,  1.4158e-07,  2.4830e-01,  1.2273e+00,
-         3.7819e-07, -8.2222e-02,  5.4294e-02, -2.0486e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2803,  0.2962, -0.3306, -0.3370,  2.5575, -0.0191,  0.0000, -0.3504,
-        -0.0310, -0.2814,  0.1362, -0.2210,  0.0499, -0.6203,  0.0000, -1.2006,
-         0.0000,  0.3164,  0.0000, -0.6984, -0.2659,  0.1460,  0.0000, -0.6990,
-         0.1341,  1.1594, -0.0698, -0.0760, -0.3534,  0.0000, -0.6924, -1.6017,
-         0.0000, -0.1088,  0.0000, -0.2679,  0.0000,  0.0000,  0.9818,  0.4576,
-         0.1423,  0.2514,  0.0972,  0.0000, -1.1013,  0.1936,  0.6288,  0.0365,
-         0.9914,  0.0000,  0.2338, -0.7364,  0.0418,  0.1074,  0.4905,  0.0000,
-         0.0000,  0.0000,  0.2483,  1.2273,  0.0000, -0.0822,  0.0543, -0.2049],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2803,  0.2962, -0.3306, -0.3370,  2.5575, -0.0191,  0.0000, -0.3504,
-        -0.0310, -0.2814,  0.1362, -0.2210,  0.0499, -0.6203,  0.0000, -1.2006,
-         0.0000,  0.3164,  0.0000, -0.6984, -0.2659,  0.1460,  0.0000, -0.6990,
-         0.1341,  1.1594, -0.0698, -0.0760, -0.3534,  0.0000, -0.6924, -1.6017,
-         0.0000, -0.1088,  0.0000, -0.2679,  0.0000,  0.0000,  0.9818,  0.4576,
-         0.1423,  0.2514,  0.0972,  0.0000, -1.1013,  0.1936,  0.6288,  0.0365,
-         0.9914,  0.0000,  0.2338, -0.7364,  0.0418,  0.1074,  0.4905,  0.0000,
-         0.0000,  0.0000,  0.2483,  1.2273,  0.0000, -0.0822,  0.0543, -0.2049],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6082e-01,  3.2753e-01, -3.5603e-01, -3.8333e-01,  2.5988e+00,
-        -9.7076e-03,  6.5913e-08, -3.5006e-01,  2.5504e-02, -2.6638e-01,
-         1.0657e-01, -2.2174e-01, -3.9076e-03, -6.0264e-01, -1.1820e-04,
-        -1.1991e+00, -8.6415e-03,  2.8947e-01, -1.9987e-05, -7.2047e-01,
-        -2.6836e-01,  1.1269e-01, -4.2445e-05, -6.6228e-01,  1.5729e-01,
-         1.1725e+00, -9.2198e-02, -3.4408e-02, -3.6270e-01,  6.1731e-05,
-        -6.8494e-01, -1.5955e+00, -1.5802e-08, -6.1525e-02,  7.8433e-03,
-        -2.8633e-01,  2.6163e-04,  0.0000e+00,  9.5653e-01,  3.9588e-01,
-         7.0901e-02,  2.2389e-01,  8.1896e-02,  6.4390e-05, -1.1007e+00,
-         3.1306e-01,  6.4397e-01, -1.6756e-02,  9.9268e-01, -4.0992e-06,
-         2.0297e-01, -7.0532e-01,  6.1119e-02,  1.6094e-01,  4.0214e-01,
-         8.5312e-10, -3.3529e-04,  1.2058e-07,  2.0217e-01,  1.2169e+00,
-         3.2210e-07, -1.1626e-01,  3.8073e-03, -2.1409e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2608,  0.3275, -0.3560, -0.3833,  2.5988, -0.0097,  0.0000, -0.3501,
-         0.0255, -0.2664,  0.1066, -0.2217, -0.0039, -0.6026,  0.0000, -1.1991,
-         0.0000,  0.2895,  0.0000, -0.7205, -0.2684,  0.1127,  0.0000, -0.6623,
-         0.1573,  1.1725, -0.0922, -0.0344, -0.3627,  0.0000, -0.6849, -1.5955,
-         0.0000, -0.0615,  0.0000, -0.2863,  0.0000,  0.0000,  0.9565,  0.3959,
-         0.0709,  0.2239,  0.0819,  0.0000, -1.1007,  0.3131,  0.6440, -0.0168,
-         0.9927,  0.0000,  0.2030, -0.7053,  0.0611,  0.1609,  0.4021,  0.0000,
-         0.0000,  0.0000,  0.2022,  1.2169,  0.0000, -0.1163,  0.0038, -0.2141],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2608,  0.3275, -0.3560, -0.3833,  2.5988, -0.0097,  0.0000, -0.3501,
-         0.0255, -0.2664,  0.1066, -0.2217, -0.0039, -0.6026,  0.0000, -1.1991,
-         0.0000,  0.2895,  0.0000, -0.7205, -0.2684,  0.1127,  0.0000, -0.6623,
-         0.1573,  1.1725, -0.0922, -0.0344, -0.3627,  0.0000, -0.6849, -1.5955,
-         0.0000, -0.0615,  0.0000, -0.2863,  0.0000,  0.0000,  0.9565,  0.3959,
-         0.0709,  0.2239,  0.0819,  0.0000, -1.1007,  0.3131,  0.6440, -0.0168,
-         0.9927,  0.0000,  0.2030, -0.7053,  0.0611,  0.1609,  0.4021,  0.0000,
-         0.0000,  0.0000,  0.2022,  1.2169,  0.0000, -0.1163,  0.0038, -0.2141],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8554e-01,  3.2674e-01, -3.5165e-01, -4.3247e-01,  2.6344e+00,
-         1.6420e-02,  5.6146e-08, -3.5790e-01,  1.2655e-03, -2.2194e-01,
-         8.4136e-02, -1.8196e-01,  1.9811e-02, -6.0217e-01, -1.0069e-04,
-        -1.2045e+00, -7.3610e-03,  2.5757e-01, -1.7026e-05, -7.2488e-01,
-        -2.6226e-01,  6.7459e-02, -3.6156e-05, -6.5102e-01,  1.6992e-01,
-         1.1790e+00, -8.0194e-02,  5.8209e-03, -3.8123e-01,  5.2584e-05,
-        -6.8024e-01, -1.5894e+00, -1.3460e-08, -9.2689e-02,  6.6811e-03,
-        -2.7350e-01,  2.2287e-04,  0.0000e+00,  9.5817e-01,  4.2480e-01,
-         2.3028e-03,  1.9750e-01,  3.5224e-02,  5.4849e-05, -1.0994e+00,
-         3.6369e-01,  6.5611e-01, -1.1833e-03,  9.9217e-01, -3.4918e-06,
-         1.5225e-01, -6.7281e-01,  3.8428e-02,  1.1926e-01,  3.8971e-01,
-         7.2670e-10, -2.8561e-04,  1.0271e-07,  1.6975e-01,  1.2041e+00,
-         2.7438e-07, -1.6735e-01, -1.6385e-02, -1.9701e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.8554e-01,  3.2674e-01, -3.5165e-01, -4.3247e-01,  2.6344e+00,
-         1.6420e-02,  0.0000e+00, -3.5790e-01,  1.2655e-03, -2.2194e-01,
-         8.4136e-02, -1.8196e-01,  1.9811e-02, -6.0217e-01,  0.0000e+00,
-        -1.2045e+00,  0.0000e+00,  2.5757e-01,  0.0000e+00, -7.2488e-01,
-        -2.6226e-01,  6.7459e-02,  0.0000e+00, -6.5102e-01,  1.6992e-01,
-         1.1790e+00, -8.0194e-02,  5.8209e-03, -3.8123e-01,  0.0000e+00,
-        -6.8024e-01, -1.5894e+00,  0.0000e+00, -9.2689e-02,  0.0000e+00,
-        -2.7350e-01,  0.0000e+00,  0.0000e+00,  9.5817e-01,  4.2480e-01,
-         2.3028e-03,  1.9750e-01,  3.5224e-02,  0.0000e+00, -1.0994e+00,
-         3.6369e-01,  6.5611e-01, -1.1833e-03,  9.9217e-01,  0.0000e+00,
-         1.5225e-01, -6.7281e-01,  3.8428e-02,  1.1926e-01,  3.8971e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6975e-01,  1.2041e+00,
-         0.0000e+00, -1.6735e-01, -1.6385e-02, -1.9701e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.8554e-01,  3.2674e-01, -3.5165e-01, -4.3247e-01,  2.6344e+00,
-         1.6420e-02,  0.0000e+00, -3.5790e-01,  1.2655e-03, -2.2194e-01,
-         8.4136e-02, -1.8196e-01,  1.9811e-02, -6.0217e-01,  0.0000e+00,
-        -1.2045e+00,  0.0000e+00,  2.5757e-01,  0.0000e+00, -7.2488e-01,
-        -2.6226e-01,  6.7459e-02,  0.0000e+00, -6.5102e-01,  1.6992e-01,
-         1.1790e+00, -8.0194e-02,  5.8209e-03, -3.8123e-01,  0.0000e+00,
-        -6.8024e-01, -1.5894e+00,  0.0000e+00, -9.2689e-02,  0.0000e+00,
-        -2.7350e-01,  0.0000e+00,  0.0000e+00,  9.5817e-01,  4.2480e-01,
-         2.3028e-03,  1.9750e-01,  3.5224e-02,  0.0000e+00, -1.0994e+00,
-         3.6369e-01,  6.5611e-01, -1.1833e-03,  9.9217e-01,  0.0000e+00,
-         1.5225e-01, -6.7281e-01,  3.8428e-02,  1.1926e-01,  3.8971e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6975e-01,  1.2041e+00,
-         0.0000e+00, -1.6735e-01, -1.6385e-02, -1.9701e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3586e-01,  2.8995e-01, -3.1029e-01, -4.6964e-01,  2.6653e+00,
-         6.6093e-02,  4.7833e-08, -3.4545e-01, -3.5952e-02, -1.4463e-01,
-         1.1457e-01, -8.3715e-02,  1.3084e-01, -5.9214e-01, -8.5779e-05,
-        -1.2053e+00, -6.2712e-03,  2.5915e-01, -1.4505e-05, -7.2159e-01,
-        -2.3406e-01, -8.8243e-03, -3.0803e-05, -6.4234e-01,  1.3680e-01,
-         1.1795e+00, -5.6621e-02,  1.6773e-02, -3.7962e-01,  4.4798e-05,
-        -6.8693e-01, -1.5827e+00, -1.1467e-08, -1.7331e-01,  5.6919e-03,
-        -2.3338e-01,  1.8987e-04,  0.0000e+00,  9.5467e-01,  4.6410e-01,
-        -5.6642e-02,  1.7913e-01,  1.0881e-02,  4.6728e-05, -1.0976e+00,
-         3.8881e-01,  6.7180e-01,  3.9993e-02,  9.9133e-01, -2.9748e-06,
-         1.2660e-01, -6.2717e-01,  2.4263e-02,  5.3831e-02,  3.9157e-01,
-         6.1911e-10, -2.4332e-04,  8.7507e-08,  1.0545e-01,  1.1941e+00,
-         2.3375e-07, -1.8611e-01, -1.6598e-02, -1.4129e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3359,  0.2900, -0.3103, -0.4696,  2.6653,  0.0661,  0.0000, -0.3455,
-        -0.0360, -0.1446,  0.1146, -0.0837,  0.1308, -0.5921,  0.0000, -1.2053,
-         0.0000,  0.2592,  0.0000, -0.7216, -0.2341, -0.0088,  0.0000, -0.6423,
-         0.1368,  1.1795, -0.0566,  0.0168, -0.3796,  0.0000, -0.6869, -1.5827,
-         0.0000, -0.1733,  0.0000, -0.2334,  0.0000,  0.0000,  0.9547,  0.4641,
-        -0.0566,  0.1791,  0.0109,  0.0000, -1.0976,  0.3888,  0.6718,  0.0400,
-         0.9913,  0.0000,  0.1266, -0.6272,  0.0243,  0.0538,  0.3916,  0.0000,
-         0.0000,  0.0000,  0.1055,  1.1941,  0.0000, -0.1861, -0.0166, -0.1413],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3359,  0.2900, -0.3103, -0.4696,  2.6653,  0.0661,  0.0000, -0.3455,
-        -0.0360, -0.1446,  0.1146, -0.0837,  0.1308, -0.5921,  0.0000, -1.2053,
-         0.0000,  0.2592,  0.0000, -0.7216, -0.2341, -0.0088,  0.0000, -0.6423,
-         0.1368,  1.1795, -0.0566,  0.0168, -0.3796,  0.0000, -0.6869, -1.5827,
-         0.0000, -0.1733,  0.0000, -0.2334,  0.0000,  0.0000,  0.9547,  0.4641,
-        -0.0566,  0.1791,  0.0109,  0.0000, -1.0976,  0.3888,  0.6718,  0.0400,
-         0.9913,  0.0000,  0.1266, -0.6272,  0.0243,  0.0538,  0.3916,  0.0000,
-         0.0000,  0.0000,  0.1055,  1.1941,  0.0000, -0.1861, -0.0166, -0.1413],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7629e-01,  2.5589e-01, -2.6619e-01, -5.1612e-01,  2.6927e+00,
-         1.1796e-01,  4.0757e-08, -3.0627e-01, -9.3462e-02, -5.6244e-02,
-         2.3380e-01,  3.7126e-02,  2.4282e-01, -5.9189e-01, -7.3089e-05,
-        -1.2102e+00, -5.3434e-03,  2.3564e-01, -1.2359e-05, -7.1261e-01,
-        -2.0341e-01, -9.1648e-02, -2.6246e-05, -6.2226e-01,  8.4062e-02,
-         1.1810e+00, -3.8947e-02, -7.0219e-03, -3.7447e-01,  3.8171e-05,
-        -6.9077e-01, -1.5769e+00, -9.7710e-09, -2.7326e-01,  4.8498e-03,
-        -1.6254e-01,  1.6178e-04,  0.0000e+00,  9.6262e-01,  5.4382e-01,
-        -1.0668e-01,  1.5753e-01, -1.6687e-02,  3.9815e-05, -1.0889e+00,
-         4.0298e-01,  6.8975e-01,  8.8506e-02,  9.8066e-01, -2.5347e-06,
-         1.4641e-01, -5.9084e-01,  4.2961e-02, -7.3256e-02,  4.1980e-01,
-         5.2752e-10, -2.0732e-04,  7.4561e-08,  1.0493e-02,  1.1820e+00,
-         1.9917e-07, -2.1601e-01,  1.3343e-03, -7.4563e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.7629e-01,  2.5589e-01, -2.6619e-01, -5.1612e-01,  2.6927e+00,
-         1.1796e-01,  0.0000e+00, -3.0627e-01, -9.3462e-02, -5.6244e-02,
-         2.3380e-01,  3.7126e-02,  2.4282e-01, -5.9189e-01,  0.0000e+00,
-        -1.2102e+00,  0.0000e+00,  2.3564e-01,  0.0000e+00, -7.1261e-01,
-        -2.0341e-01, -9.1648e-02,  0.0000e+00, -6.2226e-01,  8.4062e-02,
-         1.1810e+00, -3.8947e-02, -7.0219e-03, -3.7447e-01,  0.0000e+00,
-        -6.9077e-01, -1.5769e+00,  0.0000e+00, -2.7326e-01,  0.0000e+00,
-        -1.6254e-01,  0.0000e+00,  0.0000e+00,  9.6262e-01,  5.4382e-01,
-        -1.0668e-01,  0.0000e+00, -1.6687e-02,  0.0000e+00, -1.0889e+00,
-         4.0298e-01,  6.8975e-01,  8.8506e-02,  9.8066e-01,  0.0000e+00,
-         1.4641e-01, -5.9084e-01,  4.2961e-02, -7.3256e-02,  4.1980e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0493e-02,  1.1820e+00,
-         0.0000e+00, -2.1601e-01,  1.3343e-03, -7.4563e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.7629e-01,  2.5589e-01, -2.6619e-01, -5.1612e-01,  2.6927e+00,
-         1.1796e-01,  0.0000e+00, -3.0627e-01, -9.3462e-02, -5.6244e-02,
-         2.3380e-01,  3.7126e-02,  2.4282e-01, -5.9189e-01,  0.0000e+00,
-        -1.2102e+00,  0.0000e+00,  2.3564e-01,  0.0000e+00, -7.1261e-01,
-        -2.0341e-01, -9.1648e-02,  0.0000e+00, -6.2226e-01,  8.4062e-02,
-         1.1810e+00, -3.8947e-02, -7.0219e-03, -3.7447e-01,  0.0000e+00,
-        -6.9077e-01, -1.5769e+00,  0.0000e+00, -2.7326e-01,  0.0000e+00,
-        -1.6254e-01,  0.0000e+00,  0.0000e+00,  9.6262e-01,  5.4382e-01,
-        -1.0668e-01,  0.0000e+00, -1.6687e-02,  0.0000e+00, -1.0889e+00,
-         4.0298e-01,  6.8975e-01,  8.8506e-02,  9.8066e-01,  0.0000e+00,
-         1.4641e-01, -5.9084e-01,  4.2961e-02, -7.3256e-02,  4.1980e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0493e-02,  1.1820e+00,
-         0.0000e+00, -2.1601e-01,  1.3343e-03, -7.4563e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9597e-01,  2.9150e-01, -2.0785e-01, -5.1453e-01,  2.7153e+00,
-         2.6733e-01,  3.4733e-08, -2.4367e-01, -8.0091e-02,  5.8067e-02,
-         3.2736e-01,  8.9504e-02,  2.8826e-01, -5.8848e-01, -6.2285e-05,
-        -1.2186e+00, -4.5536e-03,  2.2786e-01, -1.0532e-05, -7.1105e-01,
-        -1.4868e-01,  2.3451e-02, -2.2366e-05, -5.9915e-01,  1.8202e-01,
-         1.1816e+00,  3.2583e-03,  5.1164e-02, -3.4627e-01,  3.2529e-05,
-        -7.0374e-01, -1.5713e+00, -8.3267e-09, -2.3613e-01,  4.1330e-03,
-        -1.5218e-01,  1.3787e-04,  0.0000e+00,  9.5472e-01,  5.7578e-01,
-        -2.2663e-02, -1.8414e-02,  4.4136e-03,  3.3930e-05, -1.0809e+00,
-         3.1290e-01,  7.0048e-01,  9.6657e-02,  9.8773e-01, -2.1601e-06,
-         1.9494e-01, -4.8882e-01,  2.4142e-02,  5.7301e-02,  4.1918e-01,
-         4.4955e-10, -1.7668e-04,  6.3540e-08, -5.7581e-04,  1.1687e+00,
-         1.6973e-07, -1.9104e-01,  5.2621e-02,  3.9439e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.9597e-01,  2.9150e-01, -2.0785e-01, -5.1453e-01,  2.7153e+00,
-         2.6733e-01,  0.0000e+00, -2.4367e-01, -8.0091e-02,  5.8067e-02,
-         3.2736e-01,  8.9504e-02,  2.8826e-01, -5.8848e-01,  0.0000e+00,
-        -1.2186e+00,  0.0000e+00,  2.2786e-01,  0.0000e+00, -7.1105e-01,
-        -1.4868e-01,  2.3451e-02,  0.0000e+00, -5.9915e-01,  1.8202e-01,
-         1.1816e+00,  3.2583e-03,  5.1164e-02, -3.4627e-01,  0.0000e+00,
-        -7.0374e-01, -1.5713e+00,  0.0000e+00, -2.3613e-01,  0.0000e+00,
-        -1.5218e-01,  0.0000e+00,  0.0000e+00,  9.5472e-01,  5.7578e-01,
-        -2.2663e-02,  0.0000e+00,  4.4136e-03,  0.0000e+00, -1.0809e+00,
-         3.1290e-01,  7.0048e-01,  9.6657e-02,  9.8773e-01,  0.0000e+00,
-         1.9494e-01, -4.8882e-01,  2.4142e-02,  5.7301e-02,  4.1918e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -5.7581e-04,  1.1687e+00,
-         0.0000e+00, -1.9104e-01,  5.2621e-02,  3.9439e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.9597e-01,  2.9150e-01, -2.0785e-01, -5.1453e-01,  2.7153e+00,
-         2.6733e-01,  0.0000e+00, -2.4367e-01, -8.0091e-02,  5.8067e-02,
-         3.2736e-01,  8.9504e-02,  2.8826e-01, -5.8848e-01,  0.0000e+00,
-        -1.2186e+00,  0.0000e+00,  2.2786e-01,  0.0000e+00, -7.1105e-01,
-        -1.4868e-01,  2.3451e-02,  0.0000e+00, -5.9915e-01,  1.8202e-01,
-         1.1816e+00,  3.2583e-03,  5.1164e-02, -3.4627e-01,  0.0000e+00,
-        -7.0374e-01, -1.5713e+00,  0.0000e+00, -2.3613e-01,  0.0000e+00,
-        -1.5218e-01,  0.0000e+00,  0.0000e+00,  9.5472e-01,  5.7578e-01,
-        -2.2663e-02,  0.0000e+00,  4.4136e-03,  0.0000e+00, -1.0809e+00,
-         3.1290e-01,  7.0048e-01,  9.6657e-02,  9.8773e-01,  0.0000e+00,
-         1.9494e-01, -4.8882e-01,  2.4142e-02,  5.7301e-02,  4.1918e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -5.7581e-04,  1.1687e+00,
-         0.0000e+00, -1.9104e-01,  5.2621e-02,  3.9439e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7466e-01,  3.6420e-01, -1.3762e-01, -4.8241e-01,  2.7323e+00,
-         4.3250e-01,  2.9603e-08, -1.8445e-01, -2.2609e-02,  1.4945e-01,
-         4.4482e-01,  9.2314e-02,  2.7332e-01, -5.6656e-01, -5.3087e-05,
-        -1.2279e+00, -3.8811e-03,  2.7367e-01, -8.9767e-06, -7.0644e-01,
-        -9.7683e-02,  2.0208e-01, -1.9063e-05, -5.6923e-01,  3.5015e-01,
-         1.1801e+00,  3.8817e-02,  1.3571e-01, -3.1424e-01,  2.7725e-05,
-        -6.9715e-01, -1.5675e+00, -7.0970e-09, -1.6069e-01,  3.5226e-03,
-        -1.4147e-01,  1.1751e-04,  0.0000e+00,  9.1792e-01,  5.3769e-01,
-         6.4764e-02, -1.5695e-02,  5.3009e-02,  2.8919e-05, -1.0720e+00,
-         1.7646e-01,  7.0780e-01,  5.5209e-02,  1.0160e+00, -1.8411e-06,
-         2.3260e-01, -3.8116e-01,  5.2245e-03,  2.8750e-01,  3.4297e-01,
-         3.8315e-10, -1.5059e-04,  5.4156e-08,  3.4867e-02,  1.1618e+00,
-         1.4466e-07, -8.3979e-02,  1.0529e-01,  1.3709e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3747,  0.3642, -0.1376, -0.4824,  2.7323,  0.4325,  0.0000, -0.1845,
-        -0.0226,  0.1494,  0.4448,  0.0923,  0.2733, -0.5666,  0.0000, -1.2279,
-         0.0000,  0.2737,  0.0000, -0.7064, -0.0977,  0.2021,  0.0000, -0.5692,
-         0.3501,  1.1801,  0.0388,  0.1357, -0.3142,  0.0000, -0.6972, -1.5675,
-         0.0000, -0.1607,  0.0000, -0.1415,  0.0000,  0.0000,  0.9179,  0.5377,
-         0.0648,  0.0000,  0.0530,  0.0000, -1.0720,  0.1765,  0.7078,  0.0552,
-         1.0160,  0.0000,  0.2326, -0.3812,  0.0052,  0.2875,  0.3430,  0.0000,
-         0.0000,  0.0000,  0.0349,  1.1618,  0.0000, -0.0840,  0.1053,  0.1371],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3747,  0.3642, -0.1376, -0.4824,  2.7323,  0.4325,  0.0000, -0.1845,
-        -0.0226,  0.1494,  0.4448,  0.0923,  0.2733, -0.5666,  0.0000, -1.2279,
-         0.0000,  0.2737,  0.0000, -0.7064, -0.0977,  0.2021,  0.0000, -0.5692,
-         0.3501,  1.1801,  0.0388,  0.1357, -0.3142,  0.0000, -0.6972, -1.5675,
-         0.0000, -0.1607,  0.0000, -0.1415,  0.0000,  0.0000,  0.9179,  0.5377,
-         0.0648,  0.0000,  0.0530,  0.0000, -1.0720,  0.1765,  0.7078,  0.0552,
-         1.0160,  0.0000,  0.2326, -0.3812,  0.0052,  0.2875,  0.3430,  0.0000,
-         0.0000,  0.0000,  0.0349,  1.1618,  0.0000, -0.0840,  0.1053,  0.1371],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4041e-01,  4.2409e-01, -1.1027e-01, -4.6686e-01,  2.7436e+00,
-         4.4869e-01,  2.5235e-08, -1.3507e-01,  2.9859e-02,  1.2966e-01,
-         3.9204e-01,  9.2712e-02,  1.2342e-01, -5.3730e-01, -4.5254e-05,
-        -1.2375e+00, -3.3084e-03,  1.3596e-01, -7.6521e-06, -7.0240e-01,
-        -3.1623e-02,  2.7886e-01, -1.6250e-05, -5.2852e-01,  4.6081e-01,
-         1.1898e+00, -1.1048e-01,  1.6364e-01, -2.1910e-01,  2.3634e-05,
-        -6.7570e-01, -1.5635e+00, -6.0498e-09, -1.0816e-01,  3.0028e-03,
-        -9.9207e-02,  1.0017e-04,  0.0000e+00,  9.2648e-01,  5.9719e-01,
-         1.6802e-01, -1.3379e-02, -7.2615e-02,  2.4652e-05, -1.0622e+00,
-        -6.6422e-02,  7.1615e-01,  3.3434e-02,  1.0383e+00, -1.5694e-06,
-         3.2607e-01, -2.6873e-01,  3.5969e-02,  2.0436e-01,  4.6561e-01,
-         3.2662e-10, -1.2837e-04,  4.6165e-08,  1.1141e-01,  1.1541e+00,
-         1.2332e-07, -1.3430e-01,  9.7005e-02, -6.4754e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3404,  0.4241, -0.1103, -0.4669,  2.7436,  0.4487,  0.0000, -0.1351,
-         0.0299,  0.1297,  0.3920,  0.0927,  0.1234, -0.5373,  0.0000, -1.2375,
-         0.0000,  0.1360,  0.0000, -0.7024, -0.0316,  0.2789,  0.0000, -0.5285,
-         0.4608,  1.1898, -0.1105,  0.1636, -0.2191,  0.0000, -0.6757, -1.5635,
-         0.0000, -0.1082,  0.0000, -0.0992,  0.0000,  0.0000,  0.9265,  0.5972,
-         0.1680,  0.0000, -0.0726,  0.0000, -1.0622, -0.0664,  0.7162,  0.0334,
-         1.0383,  0.0000,  0.3261, -0.2687,  0.0360,  0.2044,  0.4656,  0.0000,
-         0.0000,  0.0000,  0.1114,  1.1541,  0.0000, -0.1343,  0.0970, -0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3404,  0.4241, -0.1103, -0.4669,  2.7436,  0.4487,  0.0000, -0.1351,
-         0.0299,  0.1297,  0.3920,  0.0927,  0.1234, -0.5373,  0.0000, -1.2375,
-         0.0000,  0.1360,  0.0000, -0.7024, -0.0316,  0.2789,  0.0000, -0.5285,
-         0.4608,  1.1898, -0.1105,  0.1636, -0.2191,  0.0000, -0.6757, -1.5635,
-         0.0000, -0.1082,  0.0000, -0.0992,  0.0000,  0.0000,  0.9265,  0.5972,
-         0.1680,  0.0000, -0.0726,  0.0000, -1.0622, -0.0664,  0.7162,  0.0334,
-         1.0383,  0.0000,  0.3261, -0.2687,  0.0360,  0.2044,  0.4656,  0.0000,
-         0.0000,  0.0000,  0.1114,  1.1541,  0.0000, -0.1343,  0.0970, -0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6143e-01,  4.1953e-01, -6.0128e-02, -4.6700e-01,  2.7505e+00,
-         4.3393e-01,  2.1515e-08, -1.7290e-01,  1.7515e-02,  8.8390e-02,
-         2.6861e-01,  1.1855e-01,  3.5438e-02, -5.3719e-01, -3.8582e-05,
-        -1.2416e+00, -2.8207e-03, -4.3583e-02, -6.5241e-06, -6.9885e-01,
-         1.4810e-02,  1.7807e-01, -1.3855e-05, -5.3262e-01,  4.4661e-01,
-         1.1984e+00, -2.3884e-01,  8.4848e-02, -1.8175e-01,  2.0150e-05,
-        -6.5850e-01, -1.5584e+00, -5.1579e-09, -1.1465e-01,  2.5601e-03,
-        -3.2203e-02,  8.5400e-05,  0.0000e+00,  9.5603e-01,  7.1225e-01,
-         1.1748e-01, -1.1407e-02, -1.9632e-01,  2.1018e-05, -1.0570e+00,
-        -2.3788e-01,  7.1901e-01,  3.3129e-02,  1.0428e+00, -1.3380e-06,
-         3.5566e-01, -2.4345e-01,  2.0594e-02,  1.3216e-02,  6.0806e-01,
-         2.7847e-10, -1.0944e-04,  3.9360e-08,  1.6219e-01,  1.1468e+00,
-         1.0514e-07, -1.5184e-01,  3.2521e-02, -2.2179e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3614,  0.4195, -0.0601, -0.4670,  2.7505,  0.4339,  0.0000, -0.1729,
-         0.0175,  0.0884,  0.2686,  0.1185,  0.0354, -0.5372,  0.0000, -1.2416,
-         0.0000, -0.0436,  0.0000, -0.6989,  0.0148,  0.1781,  0.0000, -0.5326,
-         0.4466,  1.1984, -0.2388,  0.0848, -0.1818,  0.0000, -0.6585, -1.5584,
-         0.0000, -0.1147,  0.0000, -0.0322,  0.0000,  0.0000,  0.9560,  0.7123,
-         0.1175,  0.0000, -0.1963,  0.0000,  0.0000, -0.2379,  0.7190,  0.0331,
-         1.0428,  0.0000,  0.3557, -0.2435,  0.0206,  0.0132,  0.6081,  0.0000,
-         0.0000,  0.0000,  0.1622,  1.1468,  0.0000, -0.1518,  0.0325, -0.2218],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3614,  0.4195, -0.0601, -0.4670,  2.7505,  0.4339,  0.0000, -0.1729,
-         0.0175,  0.0884,  0.2686,  0.1185,  0.0354, -0.5372,  0.0000, -1.2416,
-         0.0000, -0.0436,  0.0000, -0.6989,  0.0148,  0.1781,  0.0000, -0.5326,
-         0.4466,  1.1984, -0.2388,  0.0848, -0.1818,  0.0000, -0.6585, -1.5584,
-         0.0000, -0.1147,  0.0000, -0.0322,  0.0000,  0.0000,  0.9560,  0.7123,
-         0.1175,  0.0000, -0.1963,  0.0000,  0.0000, -0.2379,  0.7190,  0.0331,
-         1.0428,  0.0000,  0.3557, -0.2435,  0.0206,  0.0132,  0.6081,  0.0000,
-         0.0000,  0.0000,  0.1622,  1.1468,  0.0000, -0.1518,  0.0325, -0.2218],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1297e-01,  3.9638e-01,  3.2786e-02, -4.9246e-01,  2.7602e+00,
-         4.0556e-01,  1.8346e-08, -1.9282e-01, -1.6123e-02,  6.9374e-02,
-         1.2359e-01,  1.4276e-01,  3.6331e-02, -5.2735e-01, -3.2900e-05,
-        -1.2450e+00, -2.4053e-03, -2.1134e-01, -5.5632e-06, -6.9071e-01,
-         1.7863e-02, -4.3029e-04, -1.1814e-05, -5.6250e-01,  3.5425e-01,
-         1.2000e+00, -3.2394e-01, -1.9351e-02, -2.3215e-01,  1.7182e-05,
-        -6.4611e-01, -1.5547e+00, -4.3982e-09, -1.6431e-01,  2.1831e-03,
-         2.1611e-02,  7.2822e-05,  0.0000e+00,  9.6652e-01,  8.0444e-01,
-        -2.8414e-03, -9.7267e-03, -2.9018e-01,  1.7922e-05,  4.3836e-03,
-        -3.9751e-01,  7.1886e-01,  4.5415e-02,  1.0505e+00, -1.1410e-06,
-         3.4028e-01, -2.6347e-01, -3.1480e-02, -1.6453e-01,  7.0453e-01,
-         2.3745e-10, -9.3324e-05,  3.3563e-08,  2.4443e-01,  1.1372e+00,
-         8.9654e-08, -1.2922e-01, -4.2801e-02, -3.5938e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.1297e-01,  3.9638e-01,  3.2786e-02, -4.9246e-01,  2.7602e+00,
-         4.0556e-01,  0.0000e+00, -1.9282e-01, -1.6123e-02,  6.9374e-02,
-         1.2359e-01,  1.4276e-01,  3.6331e-02, -5.2735e-01,  0.0000e+00,
-        -1.2450e+00,  0.0000e+00, -2.1134e-01,  0.0000e+00, -6.9071e-01,
-         1.7863e-02, -4.3029e-04,  0.0000e+00, -5.6250e-01,  3.5425e-01,
-         1.2000e+00, -3.2394e-01, -1.9351e-02, -2.3215e-01,  0.0000e+00,
-        -6.4611e-01, -1.5547e+00,  0.0000e+00, -1.6431e-01,  0.0000e+00,
-         2.1611e-02,  0.0000e+00,  0.0000e+00,  9.6652e-01,  8.0444e-01,
-        -2.8414e-03,  0.0000e+00, -2.9018e-01,  0.0000e+00,  0.0000e+00,
-        -3.9751e-01,  7.1886e-01,  4.5415e-02,  1.0505e+00,  0.0000e+00,
-         3.4028e-01, -2.6347e-01, -3.1480e-02, -1.6453e-01,  7.0453e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.4443e-01,  1.1372e+00,
-         0.0000e+00, -1.2922e-01, -4.2801e-02, -3.5938e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.1297e-01,  3.9638e-01,  3.2786e-02, -4.9246e-01,  2.7602e+00,
-         4.0556e-01,  0.0000e+00, -1.9282e-01, -1.6123e-02,  6.9374e-02,
-         1.2359e-01,  1.4276e-01,  3.6331e-02, -5.2735e-01,  0.0000e+00,
-        -1.2450e+00,  0.0000e+00, -2.1134e-01,  0.0000e+00, -6.9071e-01,
-         1.7863e-02, -4.3029e-04,  0.0000e+00, -5.6250e-01,  3.5425e-01,
-         1.2000e+00, -3.2394e-01, -1.9351e-02, -2.3215e-01,  0.0000e+00,
-        -6.4611e-01, -1.5547e+00,  0.0000e+00, -1.6431e-01,  0.0000e+00,
-         2.1611e-02,  0.0000e+00,  0.0000e+00,  9.6652e-01,  8.0444e-01,
-        -2.8414e-03,  0.0000e+00, -2.9018e-01,  0.0000e+00,  0.0000e+00,
-        -3.9751e-01,  7.1886e-01,  4.5415e-02,  1.0505e+00,  0.0000e+00,
-         3.4028e-01, -2.6347e-01, -3.1480e-02, -1.6453e-01,  7.0453e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.4443e-01,  1.1372e+00,
-         0.0000e+00, -1.2922e-01, -4.2801e-02, -3.5938e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6937e-01,  3.6499e-01,  1.3808e-01, -4.8514e-01,  2.7673e+00,
-         3.7000e-01,  1.5647e-08, -2.8914e-01, -1.2178e-01,  8.0286e-03,
-         8.1254e-03,  1.4301e-01,  8.6609e-02, -5.2454e-01, -2.8059e-05,
-        -1.2568e+00, -2.0513e-03, -2.4853e-01, -4.7446e-06, -7.1146e-01,
-        -4.2529e-02, -1.6778e-01, -1.0076e-05, -5.8546e-01,  2.4426e-01,
-         1.1884e+00, -3.4728e-01, -1.3313e-01, -3.1291e-01,  1.4654e-05,
-        -6.2628e-01, -1.5486e+00, -3.7511e-09, -1.7319e-01,  1.8619e-03,
-         2.0965e-02,  6.2107e-05,  0.0000e+00,  9.5427e-01,  8.6896e-01,
-        -8.2772e-02, -8.2955e-03, -2.7866e-01,  1.5285e-05,  3.7386e-03,
-        -4.4846e-01,  7.1678e-01,  2.7530e-02,  1.0587e+00, -9.7309e-07,
-         3.4633e-01, -3.3524e-01, -1.0180e-01, -6.8606e-02,  7.7251e-01,
-         2.0252e-10, -7.9592e-05,  2.8624e-08,  3.1954e-01,  1.1377e+00,
-         7.6462e-08,  2.3156e-02, -6.1328e-02, -4.5453e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4694,  0.3650,  0.1381, -0.4851,  2.7673,  0.3700,  0.0000, -0.2891,
-        -0.1218,  0.0080,  0.0081,  0.1430,  0.0866, -0.5245,  0.0000, -1.2568,
-         0.0000, -0.2485,  0.0000, -0.7115, -0.0425, -0.1678,  0.0000, -0.5855,
-         0.2443,  1.1884, -0.3473, -0.1331, -0.3129,  0.0000, -0.6263, -1.5486,
-         0.0000, -0.1732,  0.0000,  0.0210,  0.0000,  0.0000,  0.9543,  0.8690,
-        -0.0828,  0.0000, -0.2787,  0.0000,  0.0000, -0.4485,  0.7168,  0.0275,
-         1.0587,  0.0000,  0.3463, -0.3352, -0.1018, -0.0686,  0.7725,  0.0000,
-         0.0000,  0.0000,  0.3195,  1.1377,  0.0000,  0.0232, -0.0613, -0.4545],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4694,  0.3650,  0.1381, -0.4851,  2.7673,  0.3700,  0.0000, -0.2891,
-        -0.1218,  0.0080,  0.0081,  0.1430,  0.0866, -0.5245,  0.0000, -1.2568,
-         0.0000, -0.2485,  0.0000, -0.7115, -0.0425, -0.1678,  0.0000, -0.5855,
-         0.2443,  1.1884, -0.3473, -0.1331, -0.3129,  0.0000, -0.6263, -1.5486,
-         0.0000, -0.1732,  0.0000,  0.0210,  0.0000,  0.0000,  0.9543,  0.8690,
-        -0.0828,  0.0000, -0.2787,  0.0000,  0.0000, -0.4485,  0.7168,  0.0275,
-         1.0587,  0.0000,  0.3463, -0.3352, -0.1018, -0.0686,  0.7725,  0.0000,
-         0.0000,  0.0000,  0.3195,  1.1377,  0.0000,  0.0232, -0.0613, -0.4545],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9422e-01,  3.6734e-01,  2.0589e-01, -4.7460e-01,  2.7732e+00,
-         3.6803e-01,  1.3347e-08, -3.7806e-01, -1.8199e-01, -6.1560e-02,
-        -5.9970e-02,  8.5337e-02,  6.7126e-02, -4.9847e-01, -2.3934e-05,
-        -1.2665e+00, -1.7498e-03, -2.6879e-01, -4.0472e-06, -7.1967e-01,
-        -1.1370e-01, -2.6580e-01, -8.5947e-06, -5.9533e-01,  1.4969e-01,
-         1.1714e+00, -3.7923e-01, -1.8964e-01, -4.1811e-01,  1.2500e-05,
-        -5.8865e-01, -1.5401e+00, -3.1997e-09, -1.3318e-01,  1.5882e-03,
-        -2.8527e-02,  5.2978e-05,  0.0000e+00,  9.3028e-01,  9.1600e-01,
-        -1.0179e-01, -7.0761e-03, -2.9570e-01,  1.3038e-05,  3.1890e-03,
-        -4.5386e-01,  7.1190e-01,  2.3013e-02,  1.0699e+00, -8.3005e-07,
-         3.7022e-01, -3.3616e-01, -1.7351e-01,  4.5539e-02,  8.2437e-01,
-         1.7275e-10, -6.7892e-05,  2.4417e-08,  4.0194e-01,  1.1362e+00,
-         6.5222e-08,  1.4399e-01, -8.3551e-02, -5.4565e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4942,  0.3673,  0.2059, -0.4746,  2.7732,  0.3680,  0.0000, -0.3781,
-        -0.1820, -0.0616, -0.0600,  0.0853,  0.0671, -0.4985,  0.0000, -1.2665,
-         0.0000, -0.2688,  0.0000, -0.7197, -0.1137, -0.2658,  0.0000, -0.5953,
-         0.1497,  1.1714, -0.3792, -0.1896, -0.4181,  0.0000, -0.5887, -1.5401,
-         0.0000, -0.1332,  0.0000, -0.0285,  0.0000,  0.0000,  0.9303,  0.9160,
-        -0.1018,  0.0000, -0.2957,  0.0000,  0.0000, -0.4539,  0.7119,  0.0230,
-         1.0699,  0.0000,  0.3702, -0.3362, -0.1735,  0.0455,  0.8244,  0.0000,
-         0.0000,  0.0000,  0.4019,  1.1362,  0.0000,  0.1440, -0.0836, -0.5457],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4942,  0.3673,  0.2059, -0.4746,  2.7732,  0.3680,  0.0000, -0.3781,
-        -0.1820, -0.0616, -0.0600,  0.0853,  0.0671, -0.4985,  0.0000, -1.2665,
-         0.0000, -0.2688,  0.0000, -0.7197, -0.1137, -0.2658,  0.0000, -0.5953,
-         0.1497,  1.1714, -0.3792, -0.1896, -0.4181,  0.0000, -0.5887, -1.5401,
-         0.0000, -0.1332,  0.0000, -0.0285,  0.0000,  0.0000,  0.9303,  0.9160,
-        -0.1018,  0.0000, -0.2957,  0.0000,  0.0000, -0.4539,  0.7119,  0.0230,
-         1.0699,  0.0000,  0.3702, -0.3362, -0.1735,  0.0455,  0.8244,  0.0000,
-         0.0000,  0.0000,  0.4019,  1.1362,  0.0000,  0.1440, -0.0836, -0.5457],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9972e-01,  3.8914e-01,  2.3318e-01, -4.5085e-01,  2.7777e+00,
-         3.6829e-01,  1.1387e-08, -4.7032e-01, -1.6243e-01, -1.4417e-01,
-        -2.1745e-02,  8.2538e-05, -3.9310e-02, -4.8265e-01, -2.0420e-05,
-        -1.2738e+00, -1.4928e-03, -2.2417e-01, -3.4528e-06, -7.2172e-01,
-        -1.2450e-01, -2.4052e-01, -7.3326e-06, -6.0175e-01,  1.2311e-01,
-         1.1580e+00, -4.3335e-01, -1.5594e-01, -5.0574e-01,  1.0664e-05,
-        -5.5115e-01, -1.5326e+00, -2.7298e-09, -5.3620e-02,  1.3550e-03,
-        -7.3203e-02,  4.5198e-05,  0.0000e+00,  8.8960e-01,  9.3610e-01,
-        -8.5673e-03, -6.0369e-03, -3.5212e-01,  1.1124e-05,  2.7207e-03,
-        -4.9676e-01,  7.0752e-01,  1.7112e-02,  1.0574e+00, -7.0815e-07,
-         4.1044e-01, -2.5439e-01, -2.4381e-01,  9.3941e-02,  8.5840e-01,
-         1.4738e-10, -5.7922e-05,  2.0831e-08,  4.4358e-01,  1.1380e+00,
-         5.5644e-08,  1.9653e-01, -5.8764e-02, -6.1996e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.9972e-01,  3.8914e-01,  2.3318e-01, -4.5085e-01,  2.7777e+00,
-         3.6829e-01,  0.0000e+00, -4.7032e-01, -1.6243e-01, -1.4417e-01,
-        -2.1745e-02,  8.2538e-05, -3.9310e-02, -4.8265e-01,  0.0000e+00,
-        -1.2738e+00,  0.0000e+00, -2.2417e-01,  0.0000e+00, -7.2172e-01,
-        -1.2450e-01, -2.4052e-01,  0.0000e+00, -6.0175e-01,  1.2311e-01,
-         1.1580e+00, -4.3335e-01, -1.5594e-01, -5.0574e-01,  0.0000e+00,
-        -5.5115e-01, -1.5326e+00,  0.0000e+00, -5.3620e-02,  0.0000e+00,
-        -7.3203e-02,  0.0000e+00,  0.0000e+00,  8.8960e-01,  9.3610e-01,
-        -8.5673e-03,  0.0000e+00, -3.5212e-01,  0.0000e+00,  0.0000e+00,
-        -4.9676e-01,  7.0752e-01,  1.7112e-02,  1.0574e+00,  0.0000e+00,
-         4.1044e-01, -2.5439e-01, -2.4381e-01,  9.3941e-02,  8.5840e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.4358e-01,  1.1380e+00,
-         0.0000e+00,  1.9653e-01, -5.8764e-02, -6.1996e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.9972e-01,  3.8914e-01,  2.3318e-01, -4.5085e-01,  2.7777e+00,
-         3.6829e-01,  0.0000e+00, -4.7032e-01, -1.6243e-01, -1.4417e-01,
-        -2.1745e-02,  8.2538e-05, -3.9310e-02, -4.8265e-01,  0.0000e+00,
-        -1.2738e+00,  0.0000e+00, -2.2417e-01,  0.0000e+00, -7.2172e-01,
-        -1.2450e-01, -2.4052e-01,  0.0000e+00, -6.0175e-01,  1.2311e-01,
-         1.1580e+00, -4.3335e-01, -1.5594e-01, -5.0574e-01,  0.0000e+00,
-        -5.5115e-01, -1.5326e+00,  0.0000e+00, -5.3620e-02,  0.0000e+00,
-        -7.3203e-02,  0.0000e+00,  0.0000e+00,  8.8960e-01,  9.3610e-01,
-        -8.5673e-03,  0.0000e+00, -3.5212e-01,  0.0000e+00,  0.0000e+00,
-        -4.9676e-01,  7.0752e-01,  1.7112e-02,  1.0574e+00,  0.0000e+00,
-         4.1044e-01, -2.5439e-01, -2.4381e-01,  9.3941e-02,  8.5840e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.4358e-01,  1.1380e+00,
-         0.0000e+00,  1.9653e-01, -5.8764e-02, -6.1996e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0827e-01,  4.3162e-01,  2.6909e-01, -4.1004e-01,  2.7808e+00,
-         3.7558e-01,  9.7163e-09, -5.4704e-01, -1.1955e-01, -2.1143e-01,
-         5.1767e-02, -7.6659e-02, -1.0658e-01, -4.5141e-01, -1.7424e-05,
-        -1.2792e+00, -1.2738e-03, -1.7721e-01, -2.9463e-06, -7.2132e-01,
-        -1.2764e-01, -1.9195e-01, -6.2569e-06, -6.0222e-01,  1.1984e-01,
-         1.1428e+00, -4.5046e-01, -1.0375e-01, -5.7382e-01,  9.0998e-06,
-        -5.3920e-01, -1.5266e+00, -2.3294e-09,  7.6146e-03,  1.1562e-03,
-        -1.2388e-01,  3.8568e-05,  0.0000e+00,  8.2192e-01,  9.4845e-01,
-         1.1968e-01, -5.1513e-03, -4.1203e-01,  9.4918e-06,  2.3216e-03,
-        -5.4156e-01,  7.0101e-01,  2.3639e-02,  1.0465e+00, -6.0427e-07,
-         4.5336e-01, -1.5822e-01, -3.1862e-01,  2.0222e-01,  8.9239e-01,
-         1.2576e-10, -4.9425e-05,  1.7775e-08,  4.6229e-01,  1.1411e+00,
-         4.7482e-08,  2.8186e-01,  1.2398e-03, -6.8419e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.0827e-01,  4.3162e-01,  2.6909e-01, -4.1004e-01,  2.7808e+00,
-         3.7558e-01,  0.0000e+00,  0.0000e+00, -1.1955e-01, -2.1143e-01,
-         5.1767e-02, -7.6659e-02, -1.0658e-01, -4.5141e-01,  0.0000e+00,
-        -1.2792e+00,  0.0000e+00, -1.7721e-01,  0.0000e+00, -7.2132e-01,
-        -1.2764e-01, -1.9195e-01,  0.0000e+00, -6.0222e-01,  1.1984e-01,
-         1.1428e+00, -4.5046e-01, -1.0375e-01, -5.7382e-01,  0.0000e+00,
-        -5.3920e-01, -1.5266e+00,  0.0000e+00,  7.6146e-03,  0.0000e+00,
-        -1.2388e-01,  0.0000e+00,  0.0000e+00,  8.2192e-01,  9.4845e-01,
-         1.1968e-01,  0.0000e+00, -4.1203e-01,  0.0000e+00,  0.0000e+00,
-        -5.4156e-01,  7.0101e-01,  2.3639e-02,  1.0465e+00,  0.0000e+00,
-         4.5336e-01, -1.5822e-01, -3.1862e-01,  2.0222e-01,  8.9239e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6229e-01,  1.1411e+00,
-         0.0000e+00,  2.8186e-01,  1.2398e-03, -6.8419e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.0827e-01,  4.3162e-01,  2.6909e-01, -4.1004e-01,  2.7808e+00,
-         3.7558e-01,  0.0000e+00,  0.0000e+00, -1.1955e-01, -2.1143e-01,
-         5.1767e-02, -7.6659e-02, -1.0658e-01, -4.5141e-01,  0.0000e+00,
-        -1.2792e+00,  0.0000e+00, -1.7721e-01,  0.0000e+00, -7.2132e-01,
-        -1.2764e-01, -1.9195e-01,  0.0000e+00, -6.0222e-01,  1.1984e-01,
-         1.1428e+00, -4.5046e-01, -1.0375e-01, -5.7382e-01,  0.0000e+00,
-        -5.3920e-01, -1.5266e+00,  0.0000e+00,  7.6146e-03,  0.0000e+00,
-        -1.2388e-01,  0.0000e+00,  0.0000e+00,  8.2192e-01,  9.4845e-01,
-         1.1968e-01,  0.0000e+00, -4.1203e-01,  0.0000e+00,  0.0000e+00,
-        -5.4156e-01,  7.0101e-01,  2.3639e-02,  1.0465e+00,  0.0000e+00,
-         4.5336e-01, -1.5822e-01, -3.1862e-01,  2.0222e-01,  8.9239e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6229e-01,  1.1411e+00,
-         0.0000e+00,  2.8186e-01,  1.2398e-03, -6.8419e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1439e-01,  4.5427e-01,  3.0043e-01, -3.5376e-01,  2.7791e+00,
-         3.9175e-01,  8.2924e-09, -6.5478e-02, -6.6721e-02, -2.6445e-01,
-         8.2392e-02, -1.7072e-01, -1.6402e-01, -4.1670e-01, -1.4871e-05,
-        -1.2814e+00, -1.0872e-03, -1.2216e-01, -2.5146e-06, -7.1646e-01,
-        -1.2343e-01, -1.6177e-01, -5.3400e-06, -6.1710e-01,  1.0442e-01,
-         1.1263e+00, -4.4463e-01, -3.9453e-02, -6.0450e-01,  7.7662e-06,
-        -5.4570e-01, -1.5231e+00, -1.9880e-09,  3.4398e-02,  9.8675e-04,
-        -1.9172e-01,  3.2916e-05,  0.0000e+00,  7.5341e-01,  9.6419e-01,
-         2.4059e-01, -4.3964e-03, -4.6245e-01,  8.1008e-06,  1.9814e-03,
-        -5.9481e-01,  6.9491e-01,  2.9752e-02,  1.0399e+00, -5.1572e-07,
-         4.7976e-01, -7.3564e-02, -3.7715e-01,  2.7514e-01,  9.2753e-01,
-         1.0733e-10, -4.2182e-05,  1.5170e-08,  4.7595e-01,  1.1448e+00,
-         4.0523e-08,  3.5228e-01,  2.4484e-02, -7.1817e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5144,  0.4543,  0.3004, -0.3538,  2.7791,  0.3917,  0.0000,  0.0000,
-        -0.0667, -0.2645,  0.0824, -0.1707, -0.1640, -0.4167,  0.0000, -1.2814,
-         0.0000, -0.1222,  0.0000, -0.7165, -0.1234, -0.1618,  0.0000, -0.6171,
-         0.1044,  1.1263, -0.4446, -0.0395, -0.6045,  0.0000, -0.5457, -1.5231,
-         0.0000,  0.0344,  0.0000, -0.1917,  0.0000,  0.0000,  0.7534,  0.9642,
-         0.2406,  0.0000, -0.4624,  0.0000,  0.0000, -0.5948,  0.6949,  0.0298,
-         1.0399,  0.0000,  0.4798, -0.0736, -0.3772,  0.2751,  0.9275,  0.0000,
-         0.0000,  0.0000,  0.4760,  1.1448,  0.0000,  0.3523,  0.0245, -0.7182],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5144,  0.4543,  0.3004, -0.3538,  2.7791,  0.3917,  0.0000,  0.0000,
-        -0.0667, -0.2645,  0.0824, -0.1707, -0.1640, -0.4167,  0.0000, -1.2814,
-         0.0000, -0.1222,  0.0000, -0.7165, -0.1234, -0.1618,  0.0000, -0.6171,
-         0.1044,  1.1263, -0.4446, -0.0395, -0.6045,  0.0000, -0.5457, -1.5231,
-         0.0000,  0.0344,  0.0000, -0.1917,  0.0000,  0.0000,  0.7534,  0.9642,
-         0.2406,  0.0000, -0.4624,  0.0000,  0.0000, -0.5948,  0.6949,  0.0298,
-         1.0399,  0.0000,  0.4798, -0.0736, -0.3772,  0.2751,  0.9275,  0.0000,
-         0.0000,  0.0000,  0.4760,  1.1448,  0.0000,  0.3523,  0.0245, -0.7182],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2155e-01,  4.5685e-01,  3.6684e-01, -2.7621e-01,  2.7747e+00,
-         4.3421e-01,  7.0785e-09, -5.5893e-02, -1.1054e-02, -2.7233e-01,
-         6.5258e-02, -2.3059e-01, -1.6990e-01, -3.9763e-01, -1.2694e-05,
-        -1.2837e+00, -9.2803e-04, -1.3970e-02, -2.1465e-06, -7.0910e-01,
-        -1.2577e-01, -1.0847e-01, -4.5583e-06, -6.1807e-01,  6.8812e-02,
-         1.1185e+00, -3.9099e-01,  2.9979e-02, -5.8527e-01,  6.6294e-06,
-        -5.4355e-01, -1.5167e+00, -1.6970e-09,  6.6920e-02,  8.4230e-04,
-        -2.4991e-01,  2.8097e-05,  0.0000e+00,  6.9759e-01,  9.8444e-01,
-         3.2262e-01, -3.7528e-03, -4.8418e-01,  6.9150e-06,  1.6913e-03,
-        -6.5493e-01,  6.7905e-01,  3.9363e-02,  1.0226e+00, -4.4022e-07,
-         4.7303e-01, -2.8497e-02, -4.6316e-01,  3.5240e-01,  9.5883e-01,
-         9.1618e-11, -3.6007e-05,  1.2950e-08,  4.7477e-01,  1.1391e+00,
-         3.4591e-08,  4.2892e-01,  5.0684e-02, -7.1826e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5215,  0.4569,  0.3668, -0.2762,  2.7747,  0.4342,  0.0000,  0.0000,
-        -0.0111, -0.2723,  0.0653, -0.2306, -0.1699, -0.3976,  0.0000, -1.2837,
-         0.0000, -0.0140,  0.0000, -0.7091, -0.1258, -0.1085,  0.0000, -0.6181,
-         0.0688,  1.1185, -0.3910,  0.0300, -0.5853,  0.0000, -0.5435, -1.5167,
-         0.0000,  0.0669,  0.0000, -0.2499,  0.0000,  0.0000,  0.6976,  0.9844,
-         0.3226,  0.0000, -0.4842,  0.0000,  0.0000, -0.6549,  0.6791,  0.0394,
-         1.0226,  0.0000,  0.4730, -0.0285, -0.4632,  0.3524,  0.9588,  0.0000,
-         0.0000,  0.0000,  0.4748,  1.1391,  0.0000,  0.4289,  0.0507, -0.7183],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5215,  0.4569,  0.3668, -0.2762,  2.7747,  0.4342,  0.0000,  0.0000,
-        -0.0111, -0.2723,  0.0653, -0.2306, -0.1699, -0.3976,  0.0000, -1.2837,
-         0.0000, -0.0140,  0.0000, -0.7091, -0.1258, -0.1085,  0.0000, -0.6181,
-         0.0688,  1.1185, -0.3910,  0.0300, -0.5853,  0.0000, -0.5435, -1.5167,
-         0.0000,  0.0669,  0.0000, -0.2499,  0.0000,  0.0000,  0.6976,  0.9844,
-         0.3226,  0.0000, -0.4842,  0.0000,  0.0000, -0.6549,  0.6791,  0.0394,
-         1.0226,  0.0000,  0.4730, -0.0285, -0.4632,  0.3524,  0.9588,  0.0000,
-         0.0000,  0.0000,  0.4748,  1.1391,  0.0000,  0.4289,  0.0507, -0.7183],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2336e-01,  4.7795e-01,  3.9967e-01, -2.0265e-01,  2.7708e+00,
-         4.6473e-01,  6.0434e-09, -4.7720e-02,  5.4610e-02, -2.8867e-01,
-         6.8803e-02, -2.6840e-01, -2.1967e-01, -3.7408e-01, -1.0838e-05,
-        -1.2791e+00, -7.9232e-04,  7.5216e-02, -1.8326e-06, -6.9101e-01,
-        -1.0784e-01, -5.8703e-02, -3.8917e-06, -5.9960e-01,  4.5075e-02,
-         1.1187e+00, -3.5392e-01,  8.4930e-02, -5.2633e-01,  5.6600e-06,
-        -5.2566e-01, -1.5091e+00, -1.4488e-09,  6.7229e-02,  7.1913e-04,
-        -2.8939e-01,  2.3989e-05,  0.0000e+00,  6.8981e-01,  1.0045e+00,
-         3.8778e-01, -3.2041e-03, -5.2108e-01,  5.9038e-06,  1.4440e-03,
-        -6.8890e-01,  6.5725e-01,  6.4291e-02,  1.0101e+00, -3.7585e-07,
-         4.6655e-01,  1.2703e-02, -5.2192e-01,  3.9178e-01,  9.9423e-01,
-         7.8221e-11, -3.0742e-05,  1.1056e-08,  4.6088e-01,  1.1249e+00,
-         2.9533e-08,  4.3932e-01,  1.1424e-01, -7.1390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5234,  0.4779,  0.3997, -0.2027,  2.7708,  0.4647,  0.0000,  0.0000,
-         0.0546, -0.2887,  0.0688, -0.2684, -0.2197, -0.3741,  0.0000, -1.2791,
-         0.0000,  0.0752,  0.0000, -0.6910, -0.1078, -0.0587,  0.0000, -0.5996,
-         0.0451,  1.1187, -0.3539,  0.0849, -0.5263,  0.0000, -0.5257, -1.5091,
-         0.0000,  0.0672,  0.0000, -0.2894,  0.0000,  0.0000,  0.6898,  1.0045,
-         0.3878,  0.0000, -0.5211,  0.0000,  0.0000, -0.6889,  0.6573,  0.0643,
-         1.0101,  0.0000,  0.4666,  0.0127, -0.5219,  0.3918,  0.9942,  0.0000,
-         0.0000,  0.0000,  0.4609,  1.1249,  0.0000,  0.4393,  0.1142, -0.7139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5234,  0.4779,  0.3997, -0.2027,  2.7708,  0.4647,  0.0000,  0.0000,
-         0.0546, -0.2887,  0.0688, -0.2684, -0.2197, -0.3741,  0.0000, -1.2791,
-         0.0000,  0.0752,  0.0000, -0.6910, -0.1078, -0.0587,  0.0000, -0.5996,
-         0.0451,  1.1187, -0.3539,  0.0849, -0.5263,  0.0000, -0.5257, -1.5091,
-         0.0000,  0.0672,  0.0000, -0.2894,  0.0000,  0.0000,  0.6898,  1.0045,
-         0.3878,  0.0000, -0.5211,  0.0000,  0.0000, -0.6889,  0.6573,  0.0643,
-         1.0101,  0.0000,  0.4666,  0.0127, -0.5219,  0.3918,  0.9942,  0.0000,
-         0.0000,  0.0000,  0.4609,  1.1249,  0.0000,  0.4393,  0.1142, -0.7139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9066e-01,  4.6157e-01,  3.7517e-01, -1.6273e-01,  2.7631e+00,
-         4.4793e-01,  5.1607e-09, -4.0750e-02,  1.1650e-01, -3.2733e-01,
-         1.1794e-01, -2.9736e-01, -3.1262e-01, -3.4747e-01, -9.2546e-06,
-        -1.2700e+00, -6.7659e-04,  1.2106e-01, -1.5649e-06, -6.7220e-01,
-        -6.6253e-02, -2.5054e-02, -3.3233e-06, -5.5994e-01,  3.0166e-02,
-         1.1239e+00, -3.4683e-01,  1.0869e-01, -4.4463e-01,  4.8332e-06,
-        -5.0706e-01, -1.5024e+00, -1.2372e-09,  6.0613e-02,  6.1409e-04,
-        -3.0794e-01,  2.0485e-05,  0.0000e+00,  6.8260e-01,  1.0259e+00,
-         4.4690e-01, -2.7361e-03, -5.5278e-01,  5.0415e-06,  1.2331e-03,
-        -6.8503e-01,  6.3657e-01,  1.1944e-01,  9.9027e-01, -3.2095e-07,
-         4.9176e-01,  5.2283e-02, -5.4418e-01,  3.1771e-01,  1.0237e+00,
-         6.6795e-11, -2.6252e-05,  9.4411e-09,  4.4419e-01,  1.1176e+00,
-         2.5219e-08,  3.7193e-01,  2.0000e-01, -7.0621e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4907,  0.4616,  0.3752, -0.1627,  2.7631,  0.4479,  0.0000,  0.0000,
-         0.1165, -0.3273,  0.1179, -0.2974, -0.3126, -0.3475,  0.0000, -1.2700,
-         0.0000,  0.1211,  0.0000,  0.0000, -0.0663, -0.0251,  0.0000, -0.5599,
-         0.0302,  1.1239, -0.3468,  0.1087, -0.4446,  0.0000, -0.5071, -1.5024,
-         0.0000,  0.0606,  0.0000, -0.3079,  0.0000,  0.0000,  0.6826,  1.0259,
-         0.4469,  0.0000, -0.5528,  0.0000,  0.0000, -0.6850,  0.6366,  0.1194,
-         0.9903,  0.0000,  0.4918,  0.0523, -0.5442,  0.3177,  1.0237,  0.0000,
-         0.0000,  0.0000,  0.4442,  1.1176,  0.0000,  0.3719,  0.2000, -0.7062],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4907,  0.4616,  0.3752, -0.1627,  2.7631,  0.4479,  0.0000,  0.0000,
-         0.1165, -0.3273,  0.1179, -0.2974, -0.3126, -0.3475,  0.0000, -1.2700,
-         0.0000,  0.1211,  0.0000,  0.0000, -0.0663, -0.0251,  0.0000, -0.5599,
-         0.0302,  1.1239, -0.3468,  0.1087, -0.4446,  0.0000, -0.5071, -1.5024,
-         0.0000,  0.0606,  0.0000, -0.3079,  0.0000,  0.0000,  0.6826,  1.0259,
-         0.4469,  0.0000, -0.5528,  0.0000,  0.0000, -0.6850,  0.6366,  0.1194,
-         0.9903,  0.0000,  0.4918,  0.0523, -0.5442,  0.3177,  1.0237,  0.0000,
-         0.0000,  0.0000,  0.4442,  1.1176,  0.0000,  0.3719,  0.2000, -0.7062],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4564e-01,  4.3686e-01,  3.5586e-01, -1.0115e-01,  2.7514e+00,
-         4.0052e-01,  4.4078e-09, -3.4804e-02,  1.9066e-01, -3.4893e-01,
-         1.2744e-01, -3.3050e-01, -4.0061e-01, -2.8619e-01, -7.9044e-06,
-        -1.2569e+00, -5.7788e-04,  1.9167e-01, -1.3366e-06,  1.6070e-02,
-        -2.4684e-02, -4.7605e-02, -2.8384e-06, -4.9052e-01, -7.8256e-03,
-         1.1309e+00, -3.5191e-01,  1.5242e-01, -3.5952e-01,  4.1281e-06,
-        -4.4362e-01, -1.4992e+00, -1.0567e-09,  4.0074e-02,  5.2450e-04,
-        -3.2096e-01,  1.7496e-05,  0.0000e+00,  6.5954e-01,  1.0469e+00,
-         5.0606e-01, -2.3369e-03, -5.8621e-01,  4.3059e-06,  1.0532e-03,
-        -6.4734e-01,  6.2154e-01,  1.5981e-01,  9.6530e-01, -2.7413e-07,
-         5.0901e-01,  9.2816e-02, -5.4733e-01,  2.3990e-01,  1.0482e+00,
-         5.7050e-11, -2.2422e-05,  8.0637e-09,  4.0656e-01,  1.1102e+00,
-         2.1540e-08,  3.0995e-01,  2.0298e-01, -7.0515e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4456,  0.4369,  0.3559, -0.1012,  2.7514,  0.4005,  0.0000,  0.0000,
-         0.1907, -0.3489,  0.1274, -0.3305, -0.4006, -0.2862,  0.0000, -1.2569,
-         0.0000,  0.1917,  0.0000,  0.0000, -0.0247, -0.0476,  0.0000, -0.4905,
-        -0.0078,  1.1309, -0.3519,  0.1524, -0.3595,  0.0000, -0.4436, -1.4992,
-         0.0000,  0.0401,  0.0000, -0.3210,  0.0000,  0.0000,  0.6595,  1.0469,
-         0.5061,  0.0000, -0.5862,  0.0000,  0.0000, -0.6473,  0.6215,  0.1598,
-         0.9653,  0.0000,  0.5090,  0.0928, -0.5473,  0.2399,  1.0482,  0.0000,
-         0.0000,  0.0000,  0.4066,  1.1102,  0.0000,  0.3100,  0.2030, -0.7052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4456,  0.4369,  0.3559, -0.1012,  2.7514,  0.4005,  0.0000,  0.0000,
-         0.1907, -0.3489,  0.1274, -0.3305, -0.4006, -0.2862,  0.0000, -1.2569,
-         0.0000,  0.1917,  0.0000,  0.0000, -0.0247, -0.0476,  0.0000, -0.4905,
-        -0.0078,  1.1309, -0.3519,  0.1524, -0.3595,  0.0000, -0.4436, -1.4992,
-         0.0000,  0.0401,  0.0000, -0.3210,  0.0000,  0.0000,  0.6595,  1.0469,
-         0.5061,  0.0000, -0.5862,  0.0000,  0.0000, -0.6473,  0.6215,  0.1598,
-         0.9653,  0.0000,  0.5090,  0.0928, -0.5473,  0.2399,  1.0482,  0.0000,
-         0.0000,  0.0000,  0.4066,  1.1102,  0.0000,  0.3100,  0.2030, -0.7052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0675e-01,  4.3495e-01,  3.6935e-01, -4.9371e-02,  2.7408e+00,
-         4.2849e-01,  3.7654e-09, -2.9732e-02,  2.6867e-01, -3.7072e-01,
-         1.1136e-01, -3.1791e-01, -4.0417e-01, -2.2272e-01, -6.7525e-06,
-        -1.2451e+00, -4.9367e-04,  2.9456e-01, -1.1418e-06,  1.3728e-02,
-         8.9688e-03, -7.6323e-02, -2.4248e-06, -4.6161e-01, -6.2462e-02,
-         1.1251e+00, -3.1951e-01,  1.8377e-01, -3.0984e-01,  3.5265e-06,
-        -3.5314e-01, -1.4933e+00, -9.0272e-10,  1.3515e-04,  4.4807e-04,
-        -2.9705e-01,  1.4946e-05,  0.0000e+00,  6.0204e-01,  1.0398e+00,
-         5.3325e-01, -1.9963e-03, -5.4726e-01,  3.6785e-06,  8.9971e-04,
-        -5.8262e-01,  5.8942e-01,  1.8665e-01,  9.3911e-01, -2.3418e-07,
-         5.0487e-01,  2.1886e-02, -5.1786e-01,  2.2792e-01,  1.0523e+00,
-         4.8736e-11, -1.9154e-05,  6.8886e-09,  3.7068e-01,  1.1044e+00,
-         1.8401e-08,  3.3121e-01,  1.6405e-01, -6.8330e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.0675e-01,  4.3495e-01,  3.6935e-01, -4.9371e-02,  2.7408e+00,
-         4.2849e-01,  0.0000e+00,  0.0000e+00,  2.6867e-01, -3.7072e-01,
-         1.1136e-01, -3.1791e-01, -4.0417e-01, -2.2272e-01,  0.0000e+00,
-        -1.2451e+00,  0.0000e+00,  2.9456e-01,  0.0000e+00,  0.0000e+00,
-         8.9688e-03, -7.6323e-02,  0.0000e+00, -4.6161e-01, -6.2462e-02,
-         1.1251e+00, -3.1951e-01,  1.8377e-01, -3.0984e-01,  0.0000e+00,
-        -3.5314e-01, -1.4933e+00,  0.0000e+00,  1.3515e-04,  0.0000e+00,
-        -2.9705e-01,  0.0000e+00,  0.0000e+00,  6.0204e-01,  1.0398e+00,
-         5.3325e-01,  0.0000e+00, -5.4726e-01,  0.0000e+00,  0.0000e+00,
-        -5.8262e-01,  5.8942e-01,  1.8665e-01,  9.3911e-01,  0.0000e+00,
-         5.0487e-01,  2.1886e-02, -5.1786e-01,  2.2792e-01,  1.0523e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.7068e-01,  1.1044e+00,
-         0.0000e+00,  3.3121e-01,  1.6405e-01, -6.8330e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.0675e-01,  4.3495e-01,  3.6935e-01, -4.9371e-02,  2.7408e+00,
-         4.2849e-01,  0.0000e+00,  0.0000e+00,  2.6867e-01, -3.7072e-01,
-         1.1136e-01, -3.1791e-01, -4.0417e-01, -2.2272e-01,  0.0000e+00,
-        -1.2451e+00,  0.0000e+00,  2.9456e-01,  0.0000e+00,  0.0000e+00,
-         8.9688e-03, -7.6323e-02,  0.0000e+00, -4.6161e-01, -6.2462e-02,
-         1.1251e+00, -3.1951e-01,  1.8377e-01, -3.0984e-01,  0.0000e+00,
-        -3.5314e-01, -1.4933e+00,  0.0000e+00,  1.3515e-04,  0.0000e+00,
-        -2.9705e-01,  0.0000e+00,  0.0000e+00,  6.0204e-01,  1.0398e+00,
-         5.3325e-01,  0.0000e+00, -5.4726e-01,  0.0000e+00,  0.0000e+00,
-        -5.8262e-01,  5.8942e-01,  1.8665e-01,  9.3911e-01,  0.0000e+00,
-         5.0487e-01,  2.1886e-02, -5.1786e-01,  2.2792e-01,  1.0523e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.7068e-01,  1.1044e+00,
-         0.0000e+00,  3.3121e-01,  1.6405e-01, -6.8330e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3334e-01,  4.5853e-01,  3.5917e-01, -5.7213e-02,  2.7320e+00,
-         4.8615e-01,  3.2174e-09, -2.5405e-02,  3.4827e-01, -3.9466e-01,
-         9.9675e-02, -3.3736e-01, -4.1812e-01, -1.8233e-01, -5.7697e-06,
-        -1.2296e+00, -4.2181e-04,  3.2746e-01, -9.7562e-07,  1.1730e-02,
-         6.4807e-02, -1.1550e-01, -2.0719e-06, -4.4305e-01, -1.1371e-01,
-         1.1172e+00, -3.0572e-01,  2.2500e-01, -2.5798e-01,  3.0132e-06,
-        -2.4887e-01, -1.4862e+00, -7.7132e-10, -5.9355e-02,  3.8285e-04,
-        -2.7560e-01,  1.2771e-05,  0.0000e+00,  5.3156e-01,  1.0299e+00,
-         5.4878e-01, -1.7058e-03, -5.1723e-01,  3.1430e-06,  7.6875e-04,
-        -5.2316e-01,  5.6412e-01,  1.6917e-01,  9.1337e-01, -2.0009e-07,
-         4.9189e-01, -1.1376e-01, -4.6885e-01,  8.6124e-02,  1.0471e+00,
-         4.1643e-11, -1.6366e-05,  5.8859e-09,  3.4401e-01,  1.1021e+00,
-         1.5723e-08,  3.5308e-01,  9.0090e-02, -6.6843e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3333,  0.4585,  0.3592, -0.0572,  2.7320,  0.4862,  0.0000,  0.0000,
-         0.3483, -0.3947,  0.0997, -0.3374, -0.4181, -0.1823,  0.0000, -1.2296,
-         0.0000,  0.3275,  0.0000,  0.0000,  0.0648, -0.1155,  0.0000, -0.4430,
-        -0.1137,  1.1172, -0.3057,  0.2250, -0.2580,  0.0000, -0.2489, -1.4862,
-         0.0000, -0.0594,  0.0000, -0.2756,  0.0000,  0.0000,  0.5316,  1.0299,
-         0.5488,  0.0000, -0.5172,  0.0000,  0.0000, -0.5232,  0.5641,  0.1692,
-         0.9134,  0.0000,  0.4919, -0.1138, -0.4688,  0.0861,  1.0471,  0.0000,
-         0.0000,  0.0000,  0.3440,  1.1021,  0.0000,  0.3531,  0.0901, -0.6684],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3333,  0.4585,  0.3592, -0.0572,  2.7320,  0.4862,  0.0000,  0.0000,
-         0.3483, -0.3947,  0.0997, -0.3374, -0.4181, -0.1823,  0.0000, -1.2296,
-         0.0000,  0.3275,  0.0000,  0.0000,  0.0648, -0.1155,  0.0000, -0.4430,
-        -0.1137,  1.1172, -0.3057,  0.2250, -0.2580,  0.0000, -0.2489, -1.4862,
-         0.0000, -0.0594,  0.0000, -0.2756,  0.0000,  0.0000,  0.5316,  1.0299,
-         0.5488,  0.0000, -0.5172,  0.0000,  0.0000, -0.5232,  0.5641,  0.1692,
-         0.9134,  0.0000,  0.4919, -0.1138, -0.4688,  0.0861,  1.0471,  0.0000,
-         0.0000,  0.0000,  0.3440,  1.1021,  0.0000,  0.3531,  0.0901, -0.6684],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9627e-01,  4.5092e-01,  3.4652e-01, -4.6569e-02,  2.7260e+00,
-         5.2229e-01,  2.7496e-09, -2.1711e-02,  3.9071e-01, -4.2237e-01,
-         1.2604e-02, -3.3097e-01, -3.9669e-01, -1.1215e-01, -4.9309e-06,
-        -1.2185e+00, -3.6049e-04,  3.4274e-01, -8.3379e-07,  1.0025e-02,
-         1.0582e-01, -1.4972e-01, -1.7707e-06, -4.0933e-01, -1.6046e-01,
-         1.1069e+00, -2.8464e-01,  2.6046e-01, -2.0897e-01,  2.5752e-06,
-        -1.3702e-01, -1.4749e+00, -6.5919e-10, -7.0743e-02,  3.2719e-04,
-        -2.6231e-01,  1.0914e-05,  0.0000e+00,  4.5097e-01,  1.0114e+00,
-         5.6462e-01, -1.4578e-03, -4.4888e-01,  2.6861e-06,  6.5699e-04,
-        -4.4035e-01,  5.3199e-01,  1.3148e-01,  8.8902e-01, -1.7100e-07,
-         4.7687e-01, -2.0492e-01, -4.3648e-01, -5.7009e-02,  1.0389e+00,
-         3.5589e-11, -1.3987e-05,  5.0302e-09,  3.0275e-01,  1.1000e+00,
-         1.3437e-08,  3.6486e-01,  1.0215e-02, -6.5347e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2963,  0.4509,  0.3465, -0.0466,  2.7260,  0.5223,  0.0000,  0.0000,
-         0.3907, -0.4224,  0.0126, -0.3310, -0.3967, -0.1122,  0.0000, -1.2185,
-         0.0000,  0.3427,  0.0000,  0.0000,  0.1058, -0.1497,  0.0000, -0.4093,
-        -0.1605,  1.1069, -0.2846,  0.2605, -0.2090,  0.0000, -0.1370, -1.4749,
-         0.0000, -0.0707,  0.0000, -0.2623,  0.0000,  0.0000,  0.4510,  1.0114,
-         0.5646,  0.0000, -0.4489,  0.0000,  0.0000, -0.4404,  0.5320,  0.1315,
-         0.8890,  0.0000,  0.4769, -0.2049, -0.4365, -0.0570,  1.0389,  0.0000,
-         0.0000,  0.0000,  0.3028,  1.1000,  0.0000,  0.3649,  0.0102, -0.6535],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2963,  0.4509,  0.3465, -0.0466,  2.7260,  0.5223,  0.0000,  0.0000,
-         0.3907, -0.4224,  0.0126, -0.3310, -0.3967, -0.1122,  0.0000, -1.2185,
-         0.0000,  0.3427,  0.0000,  0.0000,  0.1058, -0.1497,  0.0000, -0.4093,
-        -0.1605,  1.1069, -0.2846,  0.2605, -0.2090,  0.0000, -0.1370, -1.4749,
-         0.0000, -0.0707,  0.0000, -0.2623,  0.0000,  0.0000,  0.4510,  1.0114,
-         0.5646,  0.0000, -0.4489,  0.0000,  0.0000, -0.4404,  0.5320,  0.1315,
-         0.8890,  0.0000,  0.4769, -0.2049, -0.4365, -0.0570,  1.0389,  0.0000,
-         0.0000,  0.0000,  0.3028,  1.1000,  0.0000,  0.3649,  0.0102, -0.6535],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4912e-01,  4.2760e-01,  3.2308e-01, -5.6127e-02,  2.7220e+00,
-         5.2645e-01,  2.3504e-09, -1.8559e-02,  4.1562e-01, -4.5729e-01,
-        -2.4381e-02, -3.1222e-01, -3.9895e-01, -8.5055e-02, -4.2149e-06,
-        -1.2031e+00, -3.0815e-04,  3.3275e-01, -7.1272e-07,  8.5693e-03,
-         1.1523e-01, -1.6045e-01, -1.5136e-06, -3.8545e-01, -1.6394e-01,
-         1.0957e+00, -2.8261e-01,  2.6544e-01, -1.3976e-01,  2.2012e-06,
-        -8.6239e-02, -1.4628e+00, -5.6348e-10, -7.9504e-02,  2.7968e-04,
-        -2.4722e-01,  9.3295e-06,  0.0000e+00,  3.9389e-01,  9.8736e-01,
-         5.6027e-01, -1.2461e-03, -4.0912e-01,  2.2961e-06,  5.6160e-04,
-        -3.6911e-01,  5.0858e-01,  1.2223e-01,  8.8733e-01, -1.4617e-07,
-         4.7767e-01, -2.7861e-01, -3.8287e-01, -2.7437e-01,  1.0284e+00,
-         3.0421e-11, -1.1956e-05,  4.2998e-09,  3.2058e-01,  1.0982e+00,
-         1.1486e-08,  3.3467e-01, -2.0896e-02, -6.3487e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2491,  0.4276,  0.3231, -0.0561,  2.7220,  0.5264,  0.0000,  0.0000,
-         0.4156, -0.4573, -0.0244, -0.3122, -0.3990, -0.0851,  0.0000, -1.2031,
-         0.0000,  0.3327,  0.0000,  0.0000,  0.1152, -0.1605,  0.0000, -0.3854,
-        -0.1639,  1.0957, -0.2826,  0.2654, -0.1398,  0.0000, -0.0862, -1.4628,
-         0.0000, -0.0795,  0.0000,  0.0000,  0.0000,  0.0000,  0.3939,  0.9874,
-         0.5603,  0.0000, -0.4091,  0.0000,  0.0000, -0.3691,  0.5086,  0.1222,
-         0.8873,  0.0000,  0.4777, -0.2786, -0.3829, -0.2744,  1.0284,  0.0000,
-         0.0000,  0.0000,  0.3206,  1.0982,  0.0000,  0.3347, -0.0209, -0.6349],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2491,  0.4276,  0.3231, -0.0561,  2.7220,  0.5264,  0.0000,  0.0000,
-         0.4156, -0.4573, -0.0244, -0.3122, -0.3990, -0.0851,  0.0000, -1.2031,
-         0.0000,  0.3327,  0.0000,  0.0000,  0.1152, -0.1605,  0.0000, -0.3854,
-        -0.1639,  1.0957, -0.2826,  0.2654, -0.1398,  0.0000, -0.0862, -1.4628,
-         0.0000, -0.0795,  0.0000,  0.0000,  0.0000,  0.0000,  0.3939,  0.9874,
-         0.5603,  0.0000, -0.4091,  0.0000,  0.0000, -0.3691,  0.5086,  0.1222,
-         0.8873,  0.0000,  0.4777, -0.2786, -0.3829, -0.2744,  1.0284,  0.0000,
-         0.0000,  0.0000,  0.3206,  1.0982,  0.0000,  0.3347, -0.0209, -0.6349],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.8975e-01,  3.8360e-01,  3.0019e-01, -7.0464e-03,  2.7199e+00,
-         4.8507e-01,  2.0095e-09, -1.5868e-02,  4.2737e-01, -4.2401e-01,
-        -8.3435e-02, -3.3731e-01, -3.5443e-01, -1.7457e-01, -3.6037e-06,
-        -1.1933e+00, -2.6346e-04,  3.1275e-01, -6.0936e-07,  7.3266e-03,
-         2.7356e-02,  3.3620e-02, -1.2941e-06, -4.1268e-01, -1.0802e-01,
-         1.0780e+00, -2.1754e-01,  2.5102e-01, -8.1246e-02,  1.8820e-06,
-        -5.7902e-03, -1.4553e+00, -4.8176e-10,  4.7904e-02,  2.3912e-04,
-         1.2902e-02,  7.9766e-06,  0.0000e+00,  3.6391e-01,  9.6330e-01,
-         5.2022e-01, -1.0654e-03, -2.8936e-01,  1.9631e-06,  4.8016e-04,
-        -3.0912e-01,  5.0856e-01,  9.7507e-02,  8.9250e-01, -1.2498e-07,
-         4.1062e-01, -3.9516e-01, -4.0363e-01,  5.3886e-02,  1.0152e+00,
-         2.6010e-11, -1.0222e-05,  3.6763e-09,  3.6290e-01,  1.1037e+00,
-         9.8202e-09,  2.3932e-01, -6.0232e-02, -5.8555e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1898,  0.3836,  0.3002, -0.0070,  2.7199,  0.4851,  0.0000,  0.0000,
-         0.4274, -0.4240, -0.0834, -0.3373, -0.3544, -0.1746,  0.0000, -1.1933,
-         0.0000,  0.3127,  0.0000,  0.0000,  0.0274,  0.0336,  0.0000, -0.4127,
-        -0.1080,  1.0780, -0.2175,  0.2510, -0.0812,  0.0000, -0.0058, -1.4553,
-         0.0000,  0.0479,  0.0000,  0.0000,  0.0000,  0.0000,  0.3639,  0.9633,
-         0.5202,  0.0000, -0.2894,  0.0000,  0.0000, -0.3091,  0.5086,  0.0975,
-         0.8925,  0.0000,  0.4106, -0.3952, -0.4036,  0.0539,  1.0152,  0.0000,
-         0.0000,  0.0000,  0.3629,  1.1037,  0.0000,  0.2393, -0.0602, -0.5855],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1898,  0.3836,  0.3002, -0.0070,  2.7199,  0.4851,  0.0000,  0.0000,
-         0.4274, -0.4240, -0.0834, -0.3373, -0.3544, -0.1746,  0.0000, -1.1933,
-         0.0000,  0.3127,  0.0000,  0.0000,  0.0274,  0.0336,  0.0000, -0.4127,
-        -0.1080,  1.0780, -0.2175,  0.2510, -0.0812,  0.0000, -0.0058, -1.4553,
-         0.0000,  0.0479,  0.0000,  0.0000,  0.0000,  0.0000,  0.3639,  0.9633,
-         0.5202,  0.0000, -0.2894,  0.0000,  0.0000, -0.3091,  0.5086,  0.0975,
-         0.8925,  0.0000,  0.4106, -0.3952, -0.4036,  0.0539,  1.0152,  0.0000,
-         0.0000,  0.0000,  0.3629,  1.1037,  0.0000,  0.2393, -0.0602, -0.5855],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.5906e-01,  3.3782e-01,  3.0480e-01, -6.2721e-02,  2.7197e+00,
-         4.9805e-01,  1.7185e-09, -1.3569e-02,  4.2031e-01, -3.6243e-01,
-         4.8423e-02, -3.3672e-01, -3.4525e-01, -2.7846e-01, -3.0817e-06,
-        -1.1947e+00, -2.2530e-04,  2.1786e-01, -5.2111e-07,  6.2654e-03,
-        -7.4090e-02,  2.2070e-01, -1.1066e-06, -4.2115e-01,  4.7580e-02,
-         1.0761e+00, -2.1757e-01,  2.1229e-01,  5.7261e-02,  1.6094e-06,
-        -2.0782e-03, -1.4484e+00, -4.1199e-10,  2.2591e-01,  2.0449e-04,
-         1.1034e-02,  6.8213e-06,  0.0000e+00,  4.2499e-01,  9.7126e-01,
-         4.7087e-01, -9.1110e-04, -2.4300e-01,  1.6788e-06,  4.1061e-04,
-        -3.2012e-01,  4.9045e-01,  1.4869e-01,  8.9514e-01, -1.0688e-07,
-         3.6222e-01, -4.3391e-01, -4.2441e-01,  1.8858e-01,  1.0138e+00,
-         2.2243e-11, -8.7417e-06,  3.1438e-09,  3.6721e-01,  1.1092e+00,
-         8.3979e-09,  4.4375e-02, -8.7367e-03, -5.7737e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.5906e-01,  3.3782e-01,  3.0480e-01, -6.2721e-02,  2.7197e+00,
-         4.9805e-01,  0.0000e+00,  0.0000e+00,  4.2031e-01, -3.6243e-01,
-         4.8423e-02, -3.3672e-01, -3.4525e-01, -2.7846e-01,  0.0000e+00,
-        -1.1947e+00,  0.0000e+00,  2.1786e-01,  0.0000e+00,  0.0000e+00,
-        -7.4090e-02,  2.2070e-01,  0.0000e+00, -4.2115e-01,  4.7580e-02,
-         1.0761e+00, -2.1757e-01,  2.1229e-01,  5.7261e-02,  0.0000e+00,
-        -2.0782e-03, -1.4484e+00,  0.0000e+00,  2.2591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2499e-01,  9.7126e-01,
-         4.7087e-01,  0.0000e+00, -2.4300e-01,  0.0000e+00,  0.0000e+00,
-        -3.2012e-01,  4.9045e-01,  1.4869e-01,  8.9514e-01,  0.0000e+00,
-         3.6222e-01, -4.3391e-01, -4.2441e-01,  1.8858e-01,  1.0138e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6721e-01,  1.1092e+00,
-         0.0000e+00,  4.4375e-02, -8.7367e-03, -5.7737e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.5906e-01,  3.3782e-01,  3.0480e-01, -6.2721e-02,  2.7197e+00,
-         4.9805e-01,  0.0000e+00,  0.0000e+00,  4.2031e-01, -3.6243e-01,
-         4.8423e-02, -3.3672e-01, -3.4525e-01, -2.7846e-01,  0.0000e+00,
-        -1.1947e+00,  0.0000e+00,  2.1786e-01,  0.0000e+00,  0.0000e+00,
-        -7.4090e-02,  2.2070e-01,  0.0000e+00, -4.2115e-01,  4.7580e-02,
-         1.0761e+00, -2.1757e-01,  2.1229e-01,  5.7261e-02,  0.0000e+00,
-        -2.0782e-03, -1.4484e+00,  0.0000e+00,  2.2591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2499e-01,  9.7126e-01,
-         4.7087e-01,  0.0000e+00, -2.4300e-01,  0.0000e+00,  0.0000e+00,
-        -3.2012e-01,  4.9045e-01,  1.4869e-01,  8.9514e-01,  0.0000e+00,
-         3.6222e-01, -4.3391e-01, -4.2441e-01,  1.8858e-01,  1.0138e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6721e-01,  1.1092e+00,
-         0.0000e+00,  4.4375e-02, -8.7367e-03, -5.7737e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4747e-01,  2.5880e-01,  2.7776e-01, -1.2920e-01,  2.7231e+00,
-         4.7706e-01,  1.4699e-09, -1.1607e-02,  4.0159e-01, -3.1927e-01,
-         6.1222e-02, -2.8541e-01, -3.1573e-01, -2.4676e-01, -2.6360e-06,
-        -1.1963e+00, -1.9271e-04,  1.1561e-01, -4.4573e-07,  5.3592e-03,
-        -2.5868e-01,  3.2565e-01, -9.4657e-07, -4.4370e-01,  1.2411e-01,
-         1.0794e+00, -2.5003e-01,  1.8258e-01,  4.7644e-02,  1.3767e-06,
-         2.3525e-02, -1.4454e+00, -3.5240e-10,  3.0549e-01,  1.7491e-04,
-         9.4376e-03,  5.8347e-06,  0.0000e+00,  4.8075e-01,  9.7776e-01,
-         4.1098e-01, -7.7932e-04, -2.2798e-01,  1.4360e-06,  3.5122e-04,
-        -2.7862e-01,  4.9180e-01,  1.8175e-01,  9.0644e-01, -9.1416e-08,
-         3.0758e-01, -4.2692e-01, -4.1507e-01,  2.2718e-01,  1.0155e+00,
-         1.9025e-11, -7.4773e-06,  2.6891e-09,  3.3442e-01,  1.1188e+00,
-         7.1832e-09, -1.6147e-01,  5.4861e-02, -5.6404e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1475,  0.2588,  0.2778, -0.1292,  2.7231,  0.4771,  0.0000,  0.0000,
-         0.4016, -0.3193,  0.0612, -0.2854, -0.3157, -0.2468,  0.0000, -1.1963,
-         0.0000,  0.1156,  0.0000,  0.0000, -0.2587,  0.3257,  0.0000, -0.4437,
-         0.1241,  1.0794, -0.2500,  0.1826,  0.0476,  0.0000,  0.0235, -1.4454,
-         0.0000,  0.3055,  0.0000,  0.0000,  0.0000,  0.0000,  0.4808,  0.9778,
-         0.4110,  0.0000, -0.2280,  0.0000,  0.0000, -0.2786,  0.4918,  0.1818,
-         0.9064,  0.0000,  0.3076, -0.4269, -0.4151,  0.2272,  1.0155,  0.0000,
-         0.0000,  0.0000,  0.3344,  1.1188,  0.0000, -0.1615,  0.0549, -0.5640],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1475,  0.2588,  0.2778, -0.1292,  2.7231,  0.4771,  0.0000,  0.0000,
-         0.4016, -0.3193,  0.0612, -0.2854, -0.3157, -0.2468,  0.0000, -1.1963,
-         0.0000,  0.1156,  0.0000,  0.0000, -0.2587,  0.3257,  0.0000, -0.4437,
-         0.1241,  1.0794, -0.2500,  0.1826,  0.0476,  0.0000,  0.0235, -1.4454,
-         0.0000,  0.3055,  0.0000,  0.0000,  0.0000,  0.0000,  0.4808,  0.9778,
-         0.4110,  0.0000, -0.2280,  0.0000,  0.0000, -0.2786,  0.4918,  0.1818,
-         0.9064,  0.0000,  0.3076, -0.4269, -0.4151,  0.2272,  1.0155,  0.0000,
-         0.0000,  0.0000,  0.3344,  1.1188,  0.0000, -0.1615,  0.0549, -0.5640],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3775e-01,  1.7042e-01,  2.4988e-01, -1.5140e-01,  2.7253e+00,
-         4.3696e-01,  1.2576e-09, -9.9301e-03,  3.6596e-01, -2.9710e-01,
-         7.8235e-02, -2.0984e-01, -2.5707e-01, -1.6060e-01, -2.2552e-06,
-        -1.1980e+00, -1.6488e-04, -2.4578e-02, -3.8134e-07,  4.5850e-03,
-        -3.9738e-01,  4.0200e-01, -8.0984e-07, -4.7869e-01,  1.6679e-01,
-         1.0756e+00, -2.4501e-01,  1.3527e-01,  1.1726e-02,  1.1778e-06,
-        -1.4915e-02, -1.4448e+00, -3.0149e-10,  3.4816e-01,  1.4965e-04,
-         8.0743e-03,  4.9918e-06,  0.0000e+00,  5.1952e-01,  9.8283e-01,
-         3.7718e-01, -6.6674e-04, -1.7695e-01,  1.2285e-06,  3.0049e-04,
-        -2.7716e-01,  4.9637e-01,  2.6612e-01,  9.0918e-01, -7.8211e-08,
-         2.9908e-01, -4.1023e-01, -3.9323e-01,  1.7808e-01,  1.0273e+00,
-         1.6277e-11, -6.3971e-06,  2.3006e-09,  3.0368e-01,  1.1245e+00,
-         6.1456e-09, -3.2575e-01,  1.8428e-01, -5.1446e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1378,  0.1704,  0.2499, -0.1514,  2.7253,  0.4370,  0.0000,  0.0000,
-         0.3660, -0.2971,  0.0782, -0.2098, -0.2571, -0.1606,  0.0000,  0.0000,
-         0.0000, -0.0246,  0.0000,  0.0000, -0.3974,  0.4020,  0.0000, -0.4787,
-         0.1668,  1.0756, -0.2450,  0.1353,  0.0117,  0.0000, -0.0149, -1.4448,
-         0.0000,  0.3482,  0.0000,  0.0000,  0.0000,  0.0000,  0.5195,  0.9828,
-         0.3772,  0.0000, -0.1769,  0.0000,  0.0000, -0.2772,  0.4964,  0.2661,
-         0.9092,  0.0000,  0.2991, -0.4102, -0.3932,  0.1781,  1.0273,  0.0000,
-         0.0000,  0.0000,  0.3037,  1.1245,  0.0000, -0.3257,  0.1843, -0.5145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1378,  0.1704,  0.2499, -0.1514,  2.7253,  0.4370,  0.0000,  0.0000,
-         0.3660, -0.2971,  0.0782, -0.2098, -0.2571, -0.1606,  0.0000,  0.0000,
-         0.0000, -0.0246,  0.0000,  0.0000, -0.3974,  0.4020,  0.0000, -0.4787,
-         0.1668,  1.0756, -0.2450,  0.1353,  0.0117,  0.0000, -0.0149, -1.4448,
-         0.0000,  0.3482,  0.0000,  0.0000,  0.0000,  0.0000,  0.5195,  0.9828,
-         0.3772,  0.0000, -0.1769,  0.0000,  0.0000, -0.2772,  0.4964,  0.2661,
-         0.9092,  0.0000,  0.2991, -0.4102, -0.3932,  0.1781,  1.0273,  0.0000,
-         0.0000,  0.0000,  0.3037,  1.1245,  0.0000, -0.3257,  0.1843, -0.5145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4441e-01,  5.3463e-02,  2.4305e-01, -9.3927e-02,  2.7255e+00,
-         4.2809e-01,  1.0762e-09, -8.4976e-03,  3.1551e-01, -2.9110e-01,
-        -4.5661e-02, -1.6388e-01, -2.2305e-01, -3.0147e-02, -1.9299e-06,
-        -1.4344e-03, -1.4109e-04, -1.5792e-01, -3.2633e-07,  3.9236e-03,
-        -5.1015e-01,  4.6752e-01, -6.9301e-07, -5.4081e-01,  1.7424e-01,
-         1.0796e+00, -1.9255e-01,  1.1037e-01, -9.3539e-02,  1.0079e-06,
-        -2.8732e-02, -1.4460e+00, -2.5800e-10,  3.4710e-01,  1.2806e-04,
-         6.9095e-03,  4.2717e-06,  0.0000e+00,  5.3135e-01,  9.8822e-01,
-         2.9936e-01, -5.7056e-04, -9.4871e-02,  1.0513e-06,  2.5714e-04,
-        -3.2398e-01,  5.0008e-01,  3.2340e-01,  9.1593e-01, -6.6928e-08,
-         2.4303e-01, -4.6766e-01, -4.1479e-01,  2.2508e-01,  1.0406e+00,
-         1.3929e-11, -5.4743e-06,  1.9688e-09,  2.6128e-01,  1.1341e+00,
-         5.2590e-09, -4.5035e-01,  2.7184e-01, -4.5637e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1444,  0.0535,  0.2431, -0.0939,  2.7255,  0.4281,  0.0000,  0.0000,
-         0.3155, -0.2911, -0.0457, -0.1639, -0.2230, -0.0301,  0.0000,  0.0000,
-         0.0000, -0.1579,  0.0000,  0.0000, -0.5102,  0.4675,  0.0000, -0.5408,
-         0.1742,  1.0796, -0.1926,  0.1104, -0.0935,  0.0000, -0.0287, -1.4460,
-         0.0000,  0.3471,  0.0000,  0.0000,  0.0000,  0.0000,  0.5314,  0.9882,
-         0.2994,  0.0000, -0.0949,  0.0000,  0.0000, -0.3240,  0.5001,  0.3234,
-         0.9159,  0.0000,  0.2430, -0.4677, -0.4148,  0.2251,  1.0406,  0.0000,
-         0.0000,  0.0000,  0.2613,  1.1341,  0.0000, -0.4503,  0.2718, -0.4564],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1444,  0.0535,  0.2431, -0.0939,  2.7255,  0.4281,  0.0000,  0.0000,
-         0.3155, -0.2911, -0.0457, -0.1639, -0.2230, -0.0301,  0.0000,  0.0000,
-         0.0000, -0.1579,  0.0000,  0.0000, -0.5102,  0.4675,  0.0000, -0.5408,
-         0.1742,  1.0796, -0.1926,  0.1104, -0.0935,  0.0000, -0.0287, -1.4460,
-         0.0000,  0.3471,  0.0000,  0.0000,  0.0000,  0.0000,  0.5314,  0.9882,
-         0.2994,  0.0000, -0.0949,  0.0000,  0.0000, -0.3240,  0.5001,  0.3234,
-         0.9159,  0.0000,  0.2430, -0.4677, -0.4148,  0.2251,  1.0406,  0.0000,
-         0.0000,  0.0000,  0.2613,  1.1341,  0.0000, -0.4503,  0.2718, -0.4564],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0892e-01, -5.0623e-02,  1.9014e-01, -6.4007e-02,  2.7256e+00,
-         4.0215e-01,  9.2113e-10, -7.2734e-03,  2.5854e-01, -2.2957e-01,
-        -2.0593e-01, -2.1027e-01, -2.2740e-01,  2.1232e-02, -1.6519e-06,
-        -1.2277e-03, -1.2076e-04, -2.4978e-01, -2.7932e-07,  3.3584e-03,
-        -6.0895e-01,  5.6245e-01, -5.9317e-07, -6.0964e-01,  2.0902e-01,
-         1.0860e+00, -1.6120e-01,  7.7727e-02, -2.2549e-01,  8.6269e-07,
-        -1.6265e-02, -1.4469e+00, -2.2083e-10,  3.4150e-01,  1.0961e-04,
-         5.9141e-03,  3.6563e-06,  0.0000e+00,  5.0018e-01,  9.8313e-01,
-         2.5468e-01, -4.8836e-04, -4.7183e-02,  8.9985e-07,  2.2009e-04,
-        -3.3648e-01,  5.0457e-01,  3.3175e-01,  9.2041e-01, -5.7287e-08,
-         2.0069e-01, -5.0893e-01, -3.9296e-01,  2.6548e-01,  1.0424e+00,
-         1.1922e-11, -4.6857e-06,  1.6851e-09,  2.3007e-01,  1.1444e+00,
-         4.5014e-09, -4.9671e-01,  3.2431e-01, -4.4017e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1089, -0.0506,  0.1901, -0.0640,  2.7256,  0.4022,  0.0000,  0.0000,
-         0.2585, -0.2296, -0.2059, -0.2103, -0.2274,  0.0212,  0.0000,  0.0000,
-         0.0000, -0.2498,  0.0000,  0.0000, -0.6090,  0.5625,  0.0000, -0.6096,
-         0.2090,  1.0860, -0.1612,  0.0777, -0.2255,  0.0000, -0.0163, -1.4469,
-         0.0000,  0.3415,  0.0000,  0.0000,  0.0000,  0.0000,  0.5002,  0.9831,
-         0.2547,  0.0000, -0.0472,  0.0000,  0.0000, -0.3365,  0.5046,  0.3318,
-         0.9204,  0.0000,  0.2007, -0.5089, -0.3930,  0.2655,  1.0424,  0.0000,
-         0.0000,  0.0000,  0.2301,  1.1444,  0.0000, -0.4967,  0.3243, -0.4402],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1089, -0.0506,  0.1901, -0.0640,  2.7256,  0.4022,  0.0000,  0.0000,
-         0.2585, -0.2296, -0.2059, -0.2103, -0.2274,  0.0212,  0.0000,  0.0000,
-         0.0000, -0.2498,  0.0000,  0.0000, -0.6090,  0.5625,  0.0000, -0.6096,
-         0.2090,  1.0860, -0.1612,  0.0777, -0.2255,  0.0000, -0.0163, -1.4469,
-         0.0000,  0.3415,  0.0000,  0.0000,  0.0000,  0.0000,  0.5002,  0.9831,
-         0.2547,  0.0000, -0.0472,  0.0000,  0.0000, -0.3365,  0.5046,  0.3318,
-         0.9204,  0.0000,  0.2007, -0.5089, -0.3930,  0.2655,  1.0424,  0.0000,
-         0.0000,  0.0000,  0.2301,  1.1444,  0.0000, -0.4967,  0.3243, -0.4402],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0077e-01, -9.7602e-02,  1.6039e-01, -5.8889e-02,  2.7230e+00,
-         3.7489e-01,  7.8862e-10, -6.2270e-03,  2.3456e-01, -2.4798e-01,
-        -5.4535e-02, -2.3648e-01, -2.1779e-01, -3.6407e-02, -1.4142e-06,
-        -1.0511e-03, -1.0339e-04, -1.7114e-01, -2.3914e-07,  2.8752e-03,
-        -6.8931e-01,  6.5400e-01, -5.0784e-07, -6.1865e-01,  2.6618e-01,
-         1.0904e+00, -1.9180e-01,  5.2599e-02, -2.5756e-01,  7.3858e-07,
-        -1.0909e-01, -1.4444e+00, -1.8906e-10,  3.1947e-01,  9.3841e-05,
-         5.0633e-03,  3.1303e-06,  0.0000e+00,  4.4377e-01,  9.6500e-01,
-         2.2279e-01, -4.1811e-04,  2.1944e-02,  7.7040e-07,  1.8843e-04,
-        -3.0929e-01,  5.0590e-01,  3.0948e-01,  9.2250e-01, -4.9045e-08,
-         2.2910e-01, -4.6234e-01, -3.3872e-01,  2.0925e-01,  1.0285e+00,
-         1.0207e-11, -4.0116e-06,  1.4427e-09,  2.0616e-01,  1.1426e+00,
-         3.8538e-09, -5.4856e-01,  3.7626e-01, -4.1290e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1008, -0.0976,  0.1604, -0.0589,  2.7230,  0.3749,  0.0000,  0.0000,
-         0.2346, -0.2480, -0.0545, -0.2365, -0.2178, -0.0364,  0.0000,  0.0000,
-         0.0000, -0.1711,  0.0000,  0.0000, -0.6893,  0.6540,  0.0000, -0.6187,
-         0.2662,  1.0904, -0.1918,  0.0526, -0.2576,  0.0000, -0.1091, -1.4444,
-         0.0000,  0.3195,  0.0000,  0.0000,  0.0000,  0.0000,  0.4438,  0.9650,
-         0.2228,  0.0000,  0.0219,  0.0000,  0.0000, -0.3093,  0.5059,  0.3095,
-         0.9225,  0.0000,  0.2291, -0.4623, -0.3387,  0.2092,  1.0285,  0.0000,
-         0.0000,  0.0000,  0.2062,  1.1426,  0.0000, -0.5486,  0.3763, -0.4129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1008, -0.0976,  0.1604, -0.0589,  2.7230,  0.3749,  0.0000,  0.0000,
-         0.2346, -0.2480, -0.0545, -0.2365, -0.2178, -0.0364,  0.0000,  0.0000,
-         0.0000, -0.1711,  0.0000,  0.0000, -0.6893,  0.6540,  0.0000, -0.6187,
-         0.2662,  1.0904, -0.1918,  0.0526, -0.2576,  0.0000, -0.1091, -1.4444,
-         0.0000,  0.3195,  0.0000,  0.0000,  0.0000,  0.0000,  0.4438,  0.9650,
-         0.2228,  0.0000,  0.0219,  0.0000,  0.0000, -0.3093,  0.5059,  0.3095,
-         0.9225,  0.0000,  0.2291, -0.4623, -0.3387,  0.2092,  1.0285,  0.0000,
-         0.0000,  0.0000,  0.2062,  1.1426,  0.0000, -0.5486,  0.3763, -0.4129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0843e-01, -6.9421e-02,  1.4900e-01, -1.0442e-01,  2.7227e+00,
-         3.5106e-01,  6.7533e-10, -5.3325e-03,  1.7863e-01, -2.2125e-01,
-         2.0709e-01, -2.9004e-01, -2.3644e-01, -1.9440e-01, -1.2111e-06,
-        -9.0010e-04, -8.8538e-05, -9.1089e-02, -2.0478e-07,  2.4622e-03,
-        -7.7383e-01,  7.2631e-01, -4.3488e-07, -6.4622e-01,  3.0499e-01,
-         1.0867e+00, -2.3092e-01, -1.9291e-02, -3.1787e-01,  6.3248e-07,
-        -2.5053e-01, -1.4455e+00, -1.6190e-10,  2.6691e-01,  8.0360e-05,
-         4.3359e-03,  2.6806e-06,  0.0000e+00,  3.5932e-01,  9.3889e-01,
-         1.8620e-01, -3.5804e-04,  9.0788e-02,  6.5972e-07,  1.6136e-04,
-        -2.7528e-01,  4.8699e-01,  2.9125e-01,  9.2230e-01, -4.1999e-08,
-         2.9150e-01, -4.2251e-01, -2.6801e-01,  2.2634e-01,  1.0095e+00,
-         8.7408e-12, -3.4353e-06,  1.2355e-09,  1.8177e-01,  1.1431e+00,
-         3.3002e-09, -5.6880e-01,  4.3173e-01, -4.1387e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1084, -0.0694,  0.1490, -0.1044,  2.7227,  0.3511,  0.0000,  0.0000,
-         0.1786, -0.2213,  0.2071, -0.2900, -0.2364, -0.1944,  0.0000,  0.0000,
-         0.0000, -0.0911,  0.0000,  0.0000, -0.7738,  0.7263,  0.0000, -0.6462,
-         0.3050,  1.0867, -0.2309, -0.0193, -0.3179,  0.0000, -0.2505, -1.4455,
-         0.0000,  0.2669,  0.0000,  0.0000,  0.0000,  0.0000,  0.3593,  0.9389,
-         0.1862,  0.0000,  0.0908,  0.0000,  0.0000, -0.2753,  0.4870,  0.2913,
-         0.9223,  0.0000,  0.2915, -0.4225, -0.2680,  0.2263,  1.0095,  0.0000,
-         0.0000,  0.0000,  0.1818,  1.1431,  0.0000, -0.5688,  0.4317, -0.4139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1084, -0.0694,  0.1490, -0.1044,  2.7227,  0.3511,  0.0000,  0.0000,
-         0.1786, -0.2213,  0.2071, -0.2900, -0.2364, -0.1944,  0.0000,  0.0000,
-         0.0000, -0.0911,  0.0000,  0.0000, -0.7738,  0.7263,  0.0000, -0.6462,
-         0.3050,  1.0867, -0.2309, -0.0193, -0.3179,  0.0000, -0.2505, -1.4455,
-         0.0000,  0.2669,  0.0000,  0.0000,  0.0000,  0.0000,  0.3593,  0.9389,
-         0.1862,  0.0000,  0.0908,  0.0000,  0.0000, -0.2753,  0.4870,  0.2913,
-         0.9223,  0.0000,  0.2915, -0.4225, -0.2680,  0.2263,  1.0095,  0.0000,
-         0.0000,  0.0000,  0.1818,  1.1431,  0.0000, -0.5688,  0.4317, -0.4139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.7547e-01, -4.1972e-02,  1.4288e-01, -1.0000e-01,  2.7193e+00,
-         3.3817e-01,  5.7845e-10, -4.5675e-03,  9.2131e-02, -9.5394e-02,
-         2.9351e-01, -3.9305e-01, -2.4236e-01, -3.2972e-01, -1.0373e-06,
-        -7.7098e-04, -7.5837e-05, -1.1179e-02, -1.7541e-07,  2.1090e-03,
-        -8.2351e-01,  7.9083e-01, -3.7250e-07, -6.2581e-01,  3.1280e-01,
-         1.0822e+00, -2.0750e-01, -7.4404e-02, -3.1729e-01,  5.4174e-07,
-        -2.2476e-01, -1.4466e+00, -1.3868e-10,  1.9755e-01,  6.8832e-05,
-         3.7139e-03,  2.2961e-06,  0.0000e+00,  2.6670e-01,  8.9971e-01,
-         1.0476e-01, -3.0668e-04,  1.5639e-01,  5.6509e-07,  1.3821e-04,
-        -3.3379e-01,  4.6749e-01,  2.2909e-01,  9.4535e-01, -3.5975e-08,
-         2.8409e-01, -3.6624e-01, -2.4897e-01,  3.5014e-01,  9.8552e-01,
-         7.4869e-12, -2.9425e-06,  1.0582e-09,  1.3642e-01,  1.1371e+00,
-         2.8268e-09, -5.3266e-01,  4.4751e-01, -4.2522e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1755, -0.0420,  0.1429, -0.1000,  2.7193,  0.3382,  0.0000,  0.0000,
-         0.0921, -0.0954,  0.2935, -0.3931, -0.2424, -0.3297,  0.0000,  0.0000,
-         0.0000, -0.0112,  0.0000,  0.0000, -0.8235,  0.7908,  0.0000, -0.6258,
-         0.3128,  1.0822, -0.2075, -0.0744, -0.3173,  0.0000, -0.2248, -1.4466,
-         0.0000,  0.1975,  0.0000,  0.0000,  0.0000,  0.0000,  0.2667,  0.8997,
-         0.1048,  0.0000,  0.1564,  0.0000,  0.0000, -0.3338,  0.0000,  0.2291,
-         0.9453,  0.0000,  0.2841, -0.3662, -0.2490,  0.3501,  0.9855,  0.0000,
-         0.0000,  0.0000,  0.1364,  1.1371,  0.0000, -0.5327,  0.4475, -0.4252],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1755, -0.0420,  0.1429, -0.1000,  2.7193,  0.3382,  0.0000,  0.0000,
-         0.0921, -0.0954,  0.2935, -0.3931, -0.2424, -0.3297,  0.0000,  0.0000,
-         0.0000, -0.0112,  0.0000,  0.0000, -0.8235,  0.7908,  0.0000, -0.6258,
-         0.3128,  1.0822, -0.2075, -0.0744, -0.3173,  0.0000, -0.2248, -1.4466,
-         0.0000,  0.1975,  0.0000,  0.0000,  0.0000,  0.0000,  0.2667,  0.8997,
-         0.1048,  0.0000,  0.1564,  0.0000,  0.0000, -0.3338,  0.0000,  0.2291,
-         0.9453,  0.0000,  0.2841, -0.3662, -0.2490,  0.3501,  0.9855,  0.0000,
-         0.0000,  0.0000,  0.1364,  1.1371,  0.0000, -0.5327,  0.4475, -0.4252],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1371e-01, -2.6869e-02,  7.7750e-02, -1.0547e-01,  2.7176e+00,
-         3.2486e-01,  4.9559e-10, -3.9132e-03, -4.2963e-02, -2.9258e-02,
-         2.9280e-01, -4.4309e-01, -2.6921e-01, -3.9869e-01, -8.8873e-07,
-        -6.6054e-04, -6.4974e-05,  1.3126e-02, -1.5028e-07,  1.8069e-03,
-        -8.5762e-01,  7.3957e-01, -3.1914e-07, -6.6854e-01,  1.8483e-01,
-         1.0770e+00, -1.8161e-01, -1.4929e-01, -3.3667e-01,  4.6414e-07,
-        -2.3617e-01, -1.4476e+00, -1.1881e-10, -8.7138e-03,  5.8972e-05,
-         3.1819e-03,  1.9672e-06,  0.0000e+00,  2.7619e-01,  8.6511e-01,
-        -1.6681e-02, -2.6275e-04,  1.9125e-01,  4.8414e-07,  1.1842e-04,
-        -2.6737e-01, -1.6708e-02,  3.0226e-01,  9.7715e-01, -3.0821e-08,
-         2.5838e-01, -3.5294e-01, -1.9338e-01,  2.2639e-01,  9.6249e-01,
-         6.4144e-12, -2.5210e-06,  9.0664e-10,  1.6967e-01,  1.1374e+00,
-         2.4218e-09, -4.7681e-01,  4.3146e-01, -4.2991e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2137, -0.0269,  0.0778, -0.1055,  2.7176,  0.3249,  0.0000,  0.0000,
-        -0.0430, -0.0293,  0.2928, -0.4431, -0.2692, -0.3987,  0.0000,  0.0000,
-         0.0000,  0.0131,  0.0000,  0.0000, -0.8576,  0.7396,  0.0000, -0.6685,
-         0.1848,  1.0770, -0.1816, -0.1493, -0.3367,  0.0000, -0.2362, -1.4476,
-         0.0000, -0.0087,  0.0000,  0.0000,  0.0000,  0.0000,  0.2762,  0.8651,
-        -0.0167,  0.0000,  0.1912,  0.0000,  0.0000, -0.2674,  0.0000,  0.3023,
-         0.9772,  0.0000,  0.2584, -0.3529, -0.1934,  0.2264,  0.9625,  0.0000,
-         0.0000,  0.0000,  0.1697,  1.1374,  0.0000, -0.4768,  0.4315, -0.4299],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2137, -0.0269,  0.0778, -0.1055,  2.7176,  0.3249,  0.0000,  0.0000,
-        -0.0430, -0.0293,  0.2928, -0.4431, -0.2692, -0.3987,  0.0000,  0.0000,
-         0.0000,  0.0131,  0.0000,  0.0000, -0.8576,  0.7396,  0.0000, -0.6685,
-         0.1848,  1.0770, -0.1816, -0.1493, -0.3367,  0.0000, -0.2362, -1.4476,
-         0.0000, -0.0087,  0.0000,  0.0000,  0.0000,  0.0000,  0.2762,  0.8651,
-        -0.0167,  0.0000,  0.1912,  0.0000,  0.0000, -0.2674,  0.0000,  0.3023,
-         0.9772,  0.0000,  0.2584, -0.3529, -0.1934,  0.2264,  0.9625,  0.0000,
-         0.0000,  0.0000,  0.1697,  1.1374,  0.0000, -0.4768,  0.4315, -0.4299],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4871e-01, -4.7986e-02,  6.1407e-02, -4.7337e-02,  2.7156e+00,
-         3.0784e-01,  4.2470e-10, -3.3535e-03, -1.2093e-01, -1.3794e-02,
-         2.0347e-01, -4.5561e-01, -2.6295e-01, -4.3799e-01, -7.6162e-07,
-        -5.6606e-04, -5.5681e-05,  1.0214e-02, -1.2879e-07,  1.5484e-03,
-        -8.7576e-01,  6.6823e-01, -2.7349e-07, -7.0658e-01,  4.2610e-02,
-         1.0705e+00, -1.3507e-01, -2.1437e-01, -3.4740e-01,  3.9776e-07,
-        -2.2285e-01, -1.4453e+00, -1.0182e-10, -1.5795e-01,  5.0537e-05,
-         2.7268e-03,  1.6858e-06,  0.0000e+00,  2.8467e-01,  8.4643e-01,
-        -1.7527e-01, -2.2517e-04,  2.2805e-01,  4.1489e-07,  1.0148e-04,
-        -1.7095e-01, -1.4318e-02,  3.6932e-01,  9.9764e-01, -2.6413e-08,
-         1.7752e-01, -3.1618e-01, -1.4959e-01,  1.0415e-01,  9.6082e-01,
-         5.4970e-12, -2.1604e-06,  7.7696e-10,  1.9751e-01,  1.1420e+00,
-         2.0754e-09, -3.6488e-01,  3.8071e-01, -3.8700e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2487, -0.0480,  0.0614, -0.0473,  2.7156,  0.3078,  0.0000,  0.0000,
-        -0.1209, -0.0138,  0.2035, -0.4556, -0.2629, -0.4380,  0.0000,  0.0000,
-         0.0000,  0.0102,  0.0000,  0.0000, -0.8758,  0.6682,  0.0000, -0.7066,
-         0.0426,  1.0705, -0.1351, -0.2144, -0.3474,  0.0000, -0.2229, -1.4453,
-         0.0000, -0.1579,  0.0000,  0.0000,  0.0000,  0.0000,  0.2847,  0.8464,
-        -0.1753,  0.0000,  0.2281,  0.0000,  0.0000, -0.1709,  0.0000,  0.3693,
-         0.9976,  0.0000,  0.1775, -0.3162, -0.1496,  0.1042,  0.9608,  0.0000,
-         0.0000,  0.0000,  0.1975,  1.1420,  0.0000, -0.3649,  0.3807, -0.3870],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2487, -0.0480,  0.0614, -0.0473,  2.7156,  0.3078,  0.0000,  0.0000,
-        -0.1209, -0.0138,  0.2035, -0.4556, -0.2629, -0.4380,  0.0000,  0.0000,
-         0.0000,  0.0102,  0.0000,  0.0000, -0.8758,  0.6682,  0.0000, -0.7066,
-         0.0426,  1.0705, -0.1351, -0.2144, -0.3474,  0.0000, -0.2229, -1.4453,
-         0.0000, -0.1579,  0.0000,  0.0000,  0.0000,  0.0000,  0.2847,  0.8464,
-        -0.1753,  0.0000,  0.2281,  0.0000,  0.0000, -0.1709,  0.0000,  0.3693,
-         0.9976,  0.0000,  0.1775, -0.3162, -0.1496,  0.1042,  0.9608,  0.0000,
-         0.0000,  0.0000,  0.1975,  1.1420,  0.0000, -0.3649,  0.3807, -0.3870],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.2867e-01, -1.5250e-02, -3.7444e-02, -4.9020e-02,  2.7123e+00,
-         2.4129e-01,  3.6405e-10, -2.8746e-03, -2.2854e-01,  1.5173e-02,
-         1.0161e-01, -4.9490e-01, -3.0574e-01, -4.6745e-01, -6.5284e-07,
-        -4.8522e-04, -4.7729e-05, -4.2990e-02, -1.1039e-07,  1.3273e-03,
-        -8.7459e-01,  6.4497e-01, -2.3443e-07, -7.5085e-01,  7.6758e-03,
-         1.0710e+00, -8.2192e-02, -2.4681e-01, -3.0131e-01,  3.4095e-07,
-        -1.3223e-01, -1.4415e+00, -8.7276e-11, -2.8222e-01,  4.3320e-05,
-         2.3374e-03,  1.4450e-06,  0.0000e+00,  3.0921e-01,  8.1886e-01,
-        -2.2082e-01, -1.9301e-04,  2.3073e-01,  3.5564e-07,  8.6985e-05,
-        -1.0220e-01, -1.2273e-02,  3.9944e-01,  1.0259e+00, -2.2641e-08,
-         9.8165e-02, -2.4832e-01, -9.6035e-02, -7.4285e-02,  9.5476e-01,
-         4.7119e-12, -1.8519e-06,  6.6600e-10,  2.0447e-01,  1.1440e+00,
-         1.7790e-09, -3.5145e-01,  3.6884e-01, -3.9327e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2287, -0.0153, -0.0374, -0.0490,  2.7123,  0.2413,  0.0000,  0.0000,
-        -0.2285,  0.0152,  0.1016, -0.4949, -0.3057, -0.4675,  0.0000,  0.0000,
-         0.0000, -0.0430,  0.0000,  0.0000, -0.8746,  0.6450,  0.0000, -0.7509,
-         0.0077,  1.0710, -0.0822, -0.2468, -0.3013,  0.0000, -0.1322, -1.4415,
-         0.0000, -0.2822,  0.0000,  0.0000,  0.0000,  0.0000,  0.3092,  0.8189,
-        -0.2208,  0.0000,  0.2307,  0.0000,  0.0000, -0.1022,  0.0000,  0.3994,
-         1.0259,  0.0000,  0.0982, -0.2483, -0.0960, -0.0743,  0.9548,  0.0000,
-         0.0000,  0.0000,  0.2045,  1.1440,  0.0000, -0.3514,  0.3688, -0.3933],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2287, -0.0153, -0.0374, -0.0490,  2.7123,  0.2413,  0.0000,  0.0000,
-        -0.2285,  0.0152,  0.1016, -0.4949, -0.3057, -0.4675,  0.0000,  0.0000,
-         0.0000, -0.0430,  0.0000,  0.0000, -0.8746,  0.6450,  0.0000, -0.7509,
-         0.0077,  1.0710, -0.0822, -0.2468, -0.3013,  0.0000, -0.1322, -1.4415,
-         0.0000, -0.2822,  0.0000,  0.0000,  0.0000,  0.0000,  0.3092,  0.8189,
-        -0.2208,  0.0000,  0.2307,  0.0000,  0.0000, -0.1022,  0.0000,  0.3994,
-         1.0259,  0.0000,  0.0982, -0.2483, -0.0960, -0.0743,  0.9548,  0.0000,
-         0.0000,  0.0000,  0.2045,  1.1440,  0.0000, -0.3514,  0.3688, -0.3933],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4716e-01, -1.0358e-03, -9.9892e-02, -1.9029e-05,  2.7086e+00,
-         2.1204e-01,  3.1214e-10, -2.4647e-03, -3.0539e-01,  3.6697e-02,
-         3.3641e-02, -5.1736e-01, -3.1680e-01, -4.8411e-01, -5.5975e-07,
-        -4.1602e-04, -4.0922e-05, -6.5540e-02, -9.4651e-08,  1.1380e-03,
-        -8.6322e-01,  6.2180e-01, -2.0100e-07, -7.7742e-01, -2.8044e-02,
-         1.0749e+00, -1.6849e-02, -2.5378e-01, -2.0055e-01,  2.9233e-07,
-        -5.5885e-02, -1.4341e+00, -7.4831e-11, -3.9203e-01,  3.7142e-05,
-         2.0041e-03,  1.2390e-06,  0.0000e+00,  3.3545e-01,  8.0387e-01,
-        -2.5088e-01, -1.6549e-04,  2.4577e-01,  3.0492e-07,  7.4581e-05,
-        -1.2534e-01, -1.0523e-02,  4.1276e-01,  1.0505e+00, -1.9412e-08,
-        -4.5639e-03, -1.5887e-01, -1.2574e-01, -2.1929e-01,  9.5060e-01,
-         4.0400e-12, -1.5878e-06,  5.7103e-10,  1.9028e-01,  1.1417e+00,
-         1.5253e-09, -3.4495e-01,  3.4884e-01, -4.0268e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.4716e-01, -1.0358e-03, -9.9892e-02, -1.9029e-05,  2.7086e+00,
-         2.1204e-01,  0.0000e+00,  0.0000e+00, -3.0539e-01,  3.6697e-02,
-         3.3641e-02, -5.1736e-01, -3.1680e-01, -4.8411e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.5540e-02,  0.0000e+00,  0.0000e+00,
-        -8.6322e-01,  6.2180e-01,  0.0000e+00, -7.7742e-01, -2.8044e-02,
-         1.0749e+00, -1.6849e-02, -2.5378e-01, -2.0055e-01,  0.0000e+00,
-        -5.5885e-02, -1.4341e+00,  0.0000e+00, -3.9203e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.3545e-01,  8.0387e-01,
-        -2.5088e-01,  0.0000e+00,  2.4577e-01,  0.0000e+00,  0.0000e+00,
-        -1.2534e-01,  0.0000e+00,  4.1276e-01,  1.0505e+00,  0.0000e+00,
-        -4.5639e-03, -1.5887e-01, -1.2574e-01, -2.1929e-01,  9.5060e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9028e-01,  1.1417e+00,
-         0.0000e+00, -3.4495e-01,  3.4884e-01, -4.0268e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.4716e-01, -1.0358e-03, -9.9892e-02, -1.9029e-05,  2.7086e+00,
-         2.1204e-01,  0.0000e+00,  0.0000e+00, -3.0539e-01,  3.6697e-02,
-         3.3641e-02, -5.1736e-01, -3.1680e-01, -4.8411e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.5540e-02,  0.0000e+00,  0.0000e+00,
-        -8.6322e-01,  6.2180e-01,  0.0000e+00, -7.7742e-01, -2.8044e-02,
-         1.0749e+00, -1.6849e-02, -2.5378e-01, -2.0055e-01,  0.0000e+00,
-        -5.5885e-02, -1.4341e+00,  0.0000e+00, -3.9203e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.3545e-01,  8.0387e-01,
-        -2.5088e-01,  0.0000e+00,  2.4577e-01,  0.0000e+00,  0.0000e+00,
-        -1.2534e-01,  0.0000e+00,  4.1276e-01,  1.0505e+00,  0.0000e+00,
-        -4.5639e-03, -1.5887e-01, -1.2574e-01, -2.1929e-01,  9.5060e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9028e-01,  1.1417e+00,
-         0.0000e+00, -3.4495e-01,  3.4884e-01, -4.0268e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9035e-01,  3.3389e-03, -8.9543e-02,  1.4478e-01,  2.7035e+00,
-         1.9249e-01,  2.6769e-10, -2.1137e-03, -3.4407e-01,  2.2397e-02,
-         2.5902e-02, -5.1962e-01, -2.5383e-01, -4.8357e-01, -4.8005e-07,
-        -3.5679e-04, -3.5096e-05, -1.2338e-02, -8.1174e-08,  9.7598e-04,
-        -8.4771e-01,  6.2627e-01, -1.7238e-07, -8.2172e-01, -2.3617e-02,
-         1.0763e+00,  4.3004e-02, -2.1612e-01, -1.2854e-01,  2.5071e-07,
-        -1.0376e-01, -1.4263e+00, -6.4176e-11, -4.5838e-01,  3.1854e-05,
-         1.7187e-03,  1.0626e-06,  0.0000e+00,  3.3162e-01,  7.8801e-01,
-        -2.3809e-01, -1.4192e-04,  2.5768e-01,  2.6151e-07,  6.3962e-05,
-        -1.5032e-01, -9.0249e-03,  3.9810e-01,  1.0645e+00, -1.6648e-08,
-        -7.8823e-02, -5.0953e-02, -2.3022e-01, -2.2067e-01,  9.4399e-01,
-         3.4648e-12, -1.3617e-06,  4.8972e-10,  1.0936e-01,  1.1370e+00,
-         1.3082e-09, -3.2614e-01,  3.0548e-01, -3.8484e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2903,  0.0033, -0.0895,  0.1448,  2.7035,  0.1925,  0.0000,  0.0000,
-        -0.3441,  0.0224,  0.0259, -0.5196, -0.2538, -0.4836,  0.0000,  0.0000,
-         0.0000, -0.0123,  0.0000,  0.0000, -0.8477,  0.6263,  0.0000, -0.8217,
-        -0.0236,  0.0000,  0.0430, -0.2161, -0.1285,  0.0000, -0.1038, -1.4263,
-         0.0000, -0.4584,  0.0000,  0.0000,  0.0000,  0.0000,  0.3316,  0.7880,
-        -0.2381,  0.0000,  0.2577,  0.0000,  0.0000, -0.1503,  0.0000,  0.3981,
-         1.0645,  0.0000, -0.0788, -0.0510, -0.2302, -0.2207,  0.9440,  0.0000,
-         0.0000,  0.0000,  0.1094,  1.1370,  0.0000, -0.3261,  0.3055, -0.3848],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2903,  0.0033, -0.0895,  0.1448,  2.7035,  0.1925,  0.0000,  0.0000,
-        -0.3441,  0.0224,  0.0259, -0.5196, -0.2538, -0.4836,  0.0000,  0.0000,
-         0.0000, -0.0123,  0.0000,  0.0000, -0.8477,  0.6263,  0.0000, -0.8217,
-        -0.0236,  0.0000,  0.0430, -0.2161, -0.1285,  0.0000, -0.1038, -1.4263,
-         0.0000, -0.4584,  0.0000,  0.0000,  0.0000,  0.0000,  0.3316,  0.7880,
-        -0.2381,  0.0000,  0.2577,  0.0000,  0.0000, -0.1503,  0.0000,  0.3981,
-         1.0645,  0.0000, -0.0788, -0.0510, -0.2302, -0.2207,  0.9440,  0.0000,
-         0.0000,  0.0000,  0.1094,  1.1370,  0.0000, -0.3261,  0.3055, -0.3848],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9569e-01,  3.4081e-03, -1.0599e-01,  2.0427e-01,  2.6993e+00,
-         1.4784e-01,  2.2964e-10, -1.8133e-03, -3.6219e-01, -2.4117e-02,
-         1.1007e-01, -4.9826e-01, -2.1290e-01, -4.9086e-01, -4.1181e-07,
-        -3.0607e-04, -3.0107e-05,  6.6141e-02, -6.9635e-08,  8.3724e-04,
-        -8.2476e-01,  6.1421e-01, -1.4788e-07, -8.5635e-01,  4.1950e-03,
-         1.2002e-03,  7.2389e-02, -1.5932e-01, -6.7514e-02,  2.1507e-07,
-        -1.7342e-01, -1.4171e+00, -5.5053e-11, -4.9000e-01,  2.7326e-05,
-         1.4744e-03,  9.1152e-07,  0.0000e+00,  3.0484e-01,  7.9335e-01,
-        -1.7200e-01, -1.2175e-04,  2.3440e-01,  2.2433e-07,  5.4870e-05,
-        -1.4847e-01, -7.7420e-03,  3.9458e-01,  1.0804e+00, -1.4282e-08,
-        -1.2079e-01,  7.9625e-02, -3.0224e-01, -2.2134e-01,  9.4535e-01,
-         2.9722e-12, -1.1681e-06,  4.2011e-10,  1.7522e-02,  1.1286e+00,
-         1.1222e-09, -3.1601e-01,  3.0646e-01, -3.8071e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2957,  0.0034, -0.1060,  0.2043,  2.6993,  0.1478,  0.0000,  0.0000,
-        -0.3622, -0.0241,  0.1101, -0.4983, -0.2129, -0.4909,  0.0000,  0.0000,
-         0.0000,  0.0661,  0.0000,  0.0000, -0.8248,  0.6142,  0.0000, -0.8563,
-         0.0042,  0.0000,  0.0724, -0.1593, -0.0675,  0.0000, -0.1734, -1.4171,
-         0.0000, -0.4900,  0.0000,  0.0000,  0.0000,  0.0000,  0.3048,  0.7933,
-        -0.1720,  0.0000,  0.2344,  0.0000,  0.0000, -0.1485,  0.0000,  0.3946,
-         1.0804,  0.0000, -0.1208,  0.0796, -0.3022, -0.2213,  0.9454,  0.0000,
-         0.0000,  0.0000,  0.0175,  1.1286,  0.0000, -0.3160,  0.3065, -0.3807],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2957,  0.0034, -0.1060,  0.2043,  2.6993,  0.1478,  0.0000,  0.0000,
-        -0.3622, -0.0241,  0.1101, -0.4983, -0.2129, -0.4909,  0.0000,  0.0000,
-         0.0000,  0.0661,  0.0000,  0.0000, -0.8248,  0.6142,  0.0000, -0.8563,
-         0.0042,  0.0000,  0.0724, -0.1593, -0.0675,  0.0000, -0.1734, -1.4171,
-         0.0000, -0.4900,  0.0000,  0.0000,  0.0000,  0.0000,  0.3048,  0.7933,
-        -0.1720,  0.0000,  0.2344,  0.0000,  0.0000, -0.1485,  0.0000,  0.3946,
-         1.0804,  0.0000, -0.1208,  0.0796, -0.3022, -0.2213,  0.9454,  0.0000,
-         0.0000,  0.0000,  0.0175,  1.1286,  0.0000, -0.3160,  0.3065, -0.3807],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8308e-01, -8.8828e-03, -1.1465e-01,  2.1035e-01,  2.6929e+00,
-         1.5552e-01,  1.9705e-10, -1.5559e-03, -3.7568e-01, -1.0373e-01,
-         2.8768e-01, -4.4622e-01, -1.7766e-01, -4.6545e-01, -3.5336e-07,
-        -2.6263e-04, -2.5834e-05,  8.4867e-02, -5.9751e-08,  7.1841e-04,
-        -7.8696e-01,  5.6132e-01, -1.2689e-07, -8.8006e-01, -4.8216e-04,
-         1.0298e-03,  1.0271e-01, -8.6071e-02, -5.2635e-03,  1.8454e-07,
-        -3.2834e-01, -1.4089e+00, -4.7239e-11, -5.2398e-01,  2.3447e-05,
-         1.2651e-03,  7.8215e-07,  0.0000e+00,  3.0027e-01,  7.9244e-01,
-        -6.1583e-02, -1.0447e-04,  2.0611e-01,  1.9249e-07,  4.7082e-05,
-        -1.8831e-01, -6.6431e-03,  4.0418e-01,  1.0928e+00, -1.2255e-08,
-        -1.1987e-01,  2.1263e-01, -3.4670e-01, -2.9571e-01,  9.3391e-01,
-         2.5504e-12, -1.0023e-06,  3.6048e-10, -1.0035e-01,  1.1228e+00,
-         9.6292e-10, -3.5352e-01,  2.8800e-01, -3.9390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.8308e-01, -8.8828e-03, -1.1465e-01,  2.1035e-01,  2.6929e+00,
-         1.5552e-01,  0.0000e+00,  0.0000e+00, -3.7568e-01, -1.0373e-01,
-         2.8768e-01, -4.4622e-01, -1.7766e-01, -4.6545e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  8.4867e-02,  0.0000e+00,  0.0000e+00,
-        -7.8696e-01,  5.6132e-01,  0.0000e+00, -8.8006e-01, -4.8216e-04,
-         0.0000e+00,  1.0271e-01, -8.6071e-02, -5.2635e-03,  0.0000e+00,
-        -3.2834e-01, -1.4089e+00,  0.0000e+00, -5.2398e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0027e-01,  7.9244e-01,
-        -6.1583e-02,  0.0000e+00,  2.0611e-01,  0.0000e+00,  0.0000e+00,
-        -1.8831e-01,  0.0000e+00,  4.0418e-01,  1.0928e+00,  0.0000e+00,
-        -1.1987e-01,  2.1263e-01, -3.4670e-01, -2.9571e-01,  9.3391e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0035e-01,  1.1228e+00,
-         0.0000e+00, -3.5352e-01,  2.8800e-01, -3.9390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.8308e-01, -8.8828e-03, -1.1465e-01,  2.1035e-01,  2.6929e+00,
-         1.5552e-01,  0.0000e+00,  0.0000e+00, -3.7568e-01, -1.0373e-01,
-         2.8768e-01, -4.4622e-01, -1.7766e-01, -4.6545e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  8.4867e-02,  0.0000e+00,  0.0000e+00,
-        -7.8696e-01,  5.6132e-01,  0.0000e+00, -8.8006e-01, -4.8216e-04,
-         0.0000e+00,  1.0271e-01, -8.6071e-02, -5.2635e-03,  0.0000e+00,
-        -3.2834e-01, -1.4089e+00,  0.0000e+00, -5.2398e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0027e-01,  7.9244e-01,
-        -6.1583e-02,  0.0000e+00,  2.0611e-01,  0.0000e+00,  0.0000e+00,
-        -1.8831e-01,  0.0000e+00,  4.0418e-01,  1.0928e+00,  0.0000e+00,
-        -1.1987e-01,  2.1263e-01, -3.4670e-01, -2.9571e-01,  9.3391e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0035e-01,  1.1228e+00,
-         0.0000e+00, -3.5352e-01,  2.8800e-01, -3.9390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7054e-01, -9.1906e-03, -1.0666e-01,  1.5802e-01,  2.6860e+00,
-         1.7085e-01,  1.6912e-10, -1.3354e-03, -3.7799e-01, -1.6671e-01,
-         4.6542e-01, -4.1445e-01, -1.6099e-01, -4.4486e-01, -3.0329e-07,
-        -2.2541e-04, -2.2173e-05,  4.0656e-02, -5.1284e-08,  6.1661e-04,
-        -7.5333e-01,  4.7912e-01, -1.0891e-07, -9.1430e-01, -1.6345e-02,
-         8.8392e-04,  1.6487e-01, -2.5692e-02,  1.7473e-02,  1.5839e-07,
-        -4.6980e-01, -1.4026e+00, -4.0545e-11, -5.5277e-01,  2.0125e-05,
-         1.0859e-03,  6.7132e-07,  0.0000e+00,  2.9724e-01,  7.8694e-01,
-         5.2851e-02, -8.9665e-05,  1.6731e-01,  1.6522e-07,  4.0410e-05,
-        -1.5864e-01, -5.7018e-03,  4.3923e-01,  1.1110e+00, -1.0518e-08,
-        -8.3497e-02,  2.7116e-01, -3.4538e-01, -3.0322e-01,  9.1715e-01,
-         2.1890e-12, -8.6031e-07,  3.0940e-10, -1.9250e-01,  1.1194e+00,
-         8.2648e-10, -3.4191e-01,  2.7031e-01, -3.8746e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2705, -0.0092, -0.1067,  0.1580,  2.6860,  0.1709,  0.0000,  0.0000,
-        -0.3780, -0.1667,  0.4654, -0.4144, -0.1610, -0.4449,  0.0000,  0.0000,
-         0.0000,  0.0407,  0.0000,  0.0000, -0.7533,  0.4791,  0.0000, -0.9143,
-        -0.0163,  0.0000,  0.1649, -0.0257,  0.0175,  0.0000, -0.4698, -1.4026,
-         0.0000, -0.5528,  0.0000,  0.0000,  0.0000,  0.0000,  0.2972,  0.7869,
-         0.0529,  0.0000,  0.1673,  0.0000,  0.0000, -0.1586,  0.0000,  0.4392,
-         1.1110,  0.0000, -0.0835,  0.2712, -0.3454, -0.3032,  0.9172,  0.0000,
-         0.0000,  0.0000, -0.1925,  1.1194,  0.0000, -0.3419,  0.2703, -0.3875],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2705, -0.0092, -0.1067,  0.1580,  2.6860,  0.1709,  0.0000,  0.0000,
-        -0.3780, -0.1667,  0.4654, -0.4144, -0.1610, -0.4449,  0.0000,  0.0000,
-         0.0000,  0.0407,  0.0000,  0.0000, -0.7533,  0.4791,  0.0000, -0.9143,
-        -0.0163,  0.0000,  0.1649, -0.0257,  0.0175,  0.0000, -0.4698, -1.4026,
-         0.0000, -0.5528,  0.0000,  0.0000,  0.0000,  0.0000,  0.2972,  0.7869,
-         0.0529,  0.0000,  0.1673,  0.0000,  0.0000, -0.1586,  0.0000,  0.4392,
-         1.1110,  0.0000, -0.0835,  0.2712, -0.3454, -0.3032,  0.9172,  0.0000,
-         0.0000,  0.0000, -0.1925,  1.1194,  0.0000, -0.3419,  0.2703, -0.3875],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8211e-01, -1.8207e-02, -5.0423e-02,  1.8016e-01,  2.6772e+00,
-         2.7024e-01,  1.4520e-10, -1.1465e-03, -3.8419e-01, -1.7157e-01,
-         5.9033e-01, -4.1022e-01, -1.3199e-01, -4.5059e-01, -2.6038e-07,
-        -1.9352e-04, -1.9036e-05,  1.7195e-02, -4.4029e-08,  5.2938e-04,
-        -7.3625e-01,  4.1476e-01, -9.3502e-08, -9.3738e-01, -3.7001e-02,
-         7.5887e-04,  2.6493e-01,  4.7494e-02, -6.4919e-02,  1.3598e-07,
-        -5.5556e-01, -1.4015e+00, -3.4809e-11, -5.7963e-01,  1.7278e-05,
-         9.3224e-04,  5.7634e-07,  0.0000e+00,  2.3386e-01,  7.6617e-01,
-         9.7468e-02, -7.6980e-05,  1.4542e-01,  1.4184e-07,  3.4693e-05,
-        -6.4191e-02, -4.8952e-03,  4.4443e-01,  1.1285e+00, -9.0301e-09,
-        -1.1233e-01,  2.4889e-01, -3.8594e-01, -1.1542e-01,  8.9278e-01,
-         1.8793e-12, -7.3860e-07,  2.6563e-10, -2.1836e-01,  1.1199e+00,
-         7.0955e-10, -2.2768e-01,  1.9248e-01, -3.4609e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2821, -0.0182, -0.0504,  0.1802,  2.6772,  0.2702,  0.0000,  0.0000,
-        -0.3842, -0.1716,  0.5903, -0.4102, -0.1320, -0.4506,  0.0000,  0.0000,
-         0.0000,  0.0172,  0.0000,  0.0000, -0.7363,  0.4148,  0.0000, -0.9374,
-        -0.0370,  0.0000,  0.2649,  0.0475, -0.0649,  0.0000, -0.5556, -1.4015,
-         0.0000, -0.5796,  0.0000,  0.0000,  0.0000,  0.0000,  0.2339,  0.7662,
-         0.0975,  0.0000,  0.1454,  0.0000,  0.0000, -0.0642,  0.0000,  0.4444,
-         1.1285,  0.0000, -0.1123,  0.2489, -0.3859, -0.1154,  0.8928,  0.0000,
-         0.0000,  0.0000, -0.2184,  1.1199,  0.0000, -0.2277,  0.1925, -0.3461],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2821, -0.0182, -0.0504,  0.1802,  2.6772,  0.2702,  0.0000,  0.0000,
-        -0.3842, -0.1716,  0.5903, -0.4102, -0.1320, -0.4506,  0.0000,  0.0000,
-         0.0000,  0.0172,  0.0000,  0.0000, -0.7363,  0.4148,  0.0000, -0.9374,
-        -0.0370,  0.0000,  0.2649,  0.0475, -0.0649,  0.0000, -0.5556, -1.4015,
-         0.0000, -0.5796,  0.0000,  0.0000,  0.0000,  0.0000,  0.2339,  0.7662,
-         0.0975,  0.0000,  0.1454,  0.0000,  0.0000, -0.0642,  0.0000,  0.4444,
-         1.1285,  0.0000, -0.1123,  0.2489, -0.3859, -0.1154,  0.8928,  0.0000,
-         0.0000,  0.0000, -0.2184,  1.1199,  0.0000, -0.2277,  0.1925, -0.3461],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1033e-01, -1.5015e-02,  6.6583e-03,  2.7459e-01,  2.6686e+00,
-         3.8367e-01,  1.2469e-10, -9.8457e-04, -3.6747e-01, -1.6769e-01,
-         6.4531e-01, -3.8695e-01, -1.0737e-01, -4.2403e-01, -2.2361e-07,
-        -1.6619e-04, -1.6347e-05,  4.2570e-02, -3.7811e-08,  4.5461e-04,
-        -7.2399e-01,  4.0067e-01, -8.0296e-08, -9.1474e-01,  2.0353e-03,
-         6.5169e-04,  3.3559e-01,  1.8437e-01, -1.4613e-01,  1.1678e-07,
-        -6.0815e-01, -1.4009e+00, -2.9893e-11, -5.7324e-01,  1.4837e-05,
-         8.0057e-04,  4.9494e-07,  0.0000e+00,  1.7548e-01,  7.3800e-01,
-         1.4706e-01, -6.6108e-05,  1.3747e-01,  1.2181e-07,  2.9793e-05,
-         1.1668e-02, -4.2038e-03,  4.1253e-01,  1.1332e+00, -7.7547e-09,
-        -1.8378e-01,  1.9610e-01, -4.3716e-01,  1.1799e-01,  8.7546e-01,
-         1.6139e-12, -6.3428e-07,  2.2811e-10, -2.2326e-01,  1.1162e+00,
-         6.0934e-10, -1.3604e-01,  8.3063e-02, -3.2780e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.1033e-01, -1.5015e-02,  6.6583e-03,  2.7459e-01,  2.6686e+00,
-         3.8367e-01,  0.0000e+00,  0.0000e+00, -3.6747e-01, -1.6769e-01,
-         6.4531e-01, -3.8695e-01, -1.0737e-01, -4.2403e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  4.2570e-02,  0.0000e+00,  0.0000e+00,
-        -7.2399e-01,  4.0067e-01,  0.0000e+00, -9.1474e-01,  2.0353e-03,
-         0.0000e+00,  3.3559e-01,  1.8437e-01, -1.4613e-01,  0.0000e+00,
-        -6.0815e-01, -1.4009e+00,  0.0000e+00, -5.7324e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.7548e-01,  7.3800e-01,
-         1.4706e-01,  0.0000e+00,  1.3747e-01,  0.0000e+00,  0.0000e+00,
-         1.1668e-02,  0.0000e+00,  4.1253e-01,  1.1332e+00,  0.0000e+00,
-        -1.8378e-01,  1.9610e-01, -4.3716e-01,  1.1799e-01,  8.7546e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -2.2326e-01,  1.1162e+00,
-         0.0000e+00, -1.3604e-01,  8.3063e-02, -3.2780e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.1033e-01, -1.5015e-02,  6.6583e-03,  2.7459e-01,  2.6686e+00,
-         3.8367e-01,  0.0000e+00,  0.0000e+00, -3.6747e-01, -1.6769e-01,
-         6.4531e-01, -3.8695e-01, -1.0737e-01, -4.2403e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  4.2570e-02,  0.0000e+00,  0.0000e+00,
-        -7.2399e-01,  4.0067e-01,  0.0000e+00, -9.1474e-01,  2.0353e-03,
-         0.0000e+00,  3.3559e-01,  1.8437e-01, -1.4613e-01,  0.0000e+00,
-        -6.0815e-01, -1.4009e+00,  0.0000e+00, -5.7324e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.7548e-01,  7.3800e-01,
-         1.4706e-01,  0.0000e+00,  1.3747e-01,  0.0000e+00,  0.0000e+00,
-         1.1668e-02,  0.0000e+00,  4.1253e-01,  1.1332e+00,  0.0000e+00,
-        -1.8378e-01,  1.9610e-01, -4.3716e-01,  1.1799e-01,  8.7546e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -2.2326e-01,  1.1162e+00,
-         0.0000e+00, -1.3604e-01,  8.3063e-02, -3.2780e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1219e-01,  2.9300e-02, -2.0959e-02,  2.6107e-01,  2.6620e+00,
-         4.6770e-01,  1.0711e-10, -8.4575e-04, -3.4298e-01, -1.6129e-01,
-         6.5810e-01, -3.6769e-01, -1.3176e-01, -4.1456e-01, -1.9208e-07,
-        -1.4276e-04, -1.4042e-05, -4.3844e-02, -3.2479e-08,  3.9051e-04,
-        -7.1437e-01,  3.7055e-01, -6.8974e-08, -8.8042e-01,  7.7549e-02,
-         5.5980e-04,  3.7640e-01,  2.5000e-01, -1.9239e-01,  1.0031e-07,
-        -6.5497e-01, -1.3965e+00, -2.5678e-11, -5.4421e-01,  1.2745e-05,
-         6.8769e-04,  4.2515e-07,  0.0000e+00,  2.1125e-01,  7.0329e-01,
-         2.0211e-01, -5.6787e-05,  6.2976e-02,  1.0463e-07,  2.5592e-05,
-        -2.6796e-02, -3.6110e-03,  4.1877e-01,  1.1257e+00, -6.6612e-09,
-        -1.7620e-01,  1.8543e-01, -4.1950e-01,  5.5684e-02,  8.5993e-01,
-         1.3863e-12, -5.4485e-07,  1.9595e-10, -1.7314e-01,  1.1132e+00,
-         5.2342e-10, -1.5250e-01,  5.6999e-02, -3.3169e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3122,  0.0293, -0.0210,  0.2611,  2.6620,  0.4677,  0.0000,  0.0000,
-        -0.3430, -0.1613,  0.6581,  0.0000, -0.1318, -0.4146,  0.0000,  0.0000,
-         0.0000, -0.0438,  0.0000,  0.0000, -0.7144,  0.3706,  0.0000, -0.8804,
-         0.0775,  0.0000,  0.3764,  0.2500, -0.1924,  0.0000, -0.6550, -1.3965,
-         0.0000, -0.5442,  0.0000,  0.0000,  0.0000,  0.0000,  0.2113,  0.7033,
-         0.2021,  0.0000,  0.0630,  0.0000,  0.0000, -0.0268,  0.0000,  0.4188,
-         1.1257,  0.0000, -0.1762,  0.1854, -0.4195,  0.0557,  0.8599,  0.0000,
-         0.0000,  0.0000, -0.1731,  1.1132,  0.0000, -0.1525,  0.0570, -0.3317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3122,  0.0293, -0.0210,  0.2611,  2.6620,  0.4677,  0.0000,  0.0000,
-        -0.3430, -0.1613,  0.6581,  0.0000, -0.1318, -0.4146,  0.0000,  0.0000,
-         0.0000, -0.0438,  0.0000,  0.0000, -0.7144,  0.3706,  0.0000, -0.8804,
-         0.0775,  0.0000,  0.3764,  0.2500, -0.1924,  0.0000, -0.6550, -1.3965,
-         0.0000, -0.5442,  0.0000,  0.0000,  0.0000,  0.0000,  0.2113,  0.7033,
-         0.2021,  0.0000,  0.0630,  0.0000,  0.0000, -0.0268,  0.0000,  0.4188,
-         1.1257,  0.0000, -0.1762,  0.1854, -0.4195,  0.0557,  0.8599,  0.0000,
-         0.0000,  0.0000, -0.1731,  1.1132,  0.0000, -0.1525,  0.0570, -0.3317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7915e-01,  9.1450e-02, -1.2974e-01,  8.5408e-02,  2.6555e+00,
-         4.3824e-01,  9.2032e-11, -7.2670e-04, -3.2821e-01, -1.3109e-01,
-         5.9990e-01,  1.6545e-02, -1.9055e-01, -4.1613e-01, -1.6504e-07,
-        -1.2266e-04, -1.2066e-05, -1.3184e-01, -2.7907e-08,  3.3554e-04,
-        -7.0439e-01,  3.3906e-01, -5.9265e-08, -8.2825e-01,  1.8413e-01,
-         4.8100e-04,  3.0777e-01,  2.2547e-01, -1.3378e-01,  8.6192e-08,
-        -6.7021e-01, -1.3889e+00, -2.2064e-11, -4.9412e-01,  1.0951e-05,
-         5.9089e-04,  3.6531e-07,  0.0000e+00,  3.2341e-01,  7.2898e-01,
-         2.6847e-01, -4.8793e-05, -8.9372e-02,  8.9906e-08,  2.1990e-05,
-        -2.1449e-01, -3.1027e-03,  4.1155e-01,  1.1156e+00, -5.7236e-09,
-        -1.1582e-01,  2.2942e-01, -4.0623e-01, -3.1000e-01,  8.7647e-01,
-         1.1912e-12, -4.6815e-07,  1.6836e-10, -1.2026e-01,  1.1004e+00,
-         4.4974e-10, -2.5656e-01,  1.2526e-01, -3.8036e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2791,  0.0915, -0.1297,  0.0854,  2.6555,  0.4382,  0.0000,  0.0000,
-        -0.3282, -0.1311,  0.5999,  0.0000, -0.1905, -0.4161,  0.0000,  0.0000,
-         0.0000, -0.1318,  0.0000,  0.0000, -0.7044,  0.3391,  0.0000, -0.8283,
-         0.1841,  0.0000,  0.3078,  0.2255, -0.1338,  0.0000, -0.6702, -1.3889,
-         0.0000, -0.4941,  0.0000,  0.0000,  0.0000,  0.0000,  0.3234,  0.7290,
-         0.2685,  0.0000, -0.0894,  0.0000,  0.0000, -0.2145,  0.0000,  0.4115,
-         1.1156,  0.0000, -0.1158,  0.2294, -0.4062, -0.3100,  0.8765,  0.0000,
-         0.0000,  0.0000, -0.1203,  1.1004,  0.0000, -0.2566,  0.1253, -0.3804],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2791,  0.0915, -0.1297,  0.0854,  2.6555,  0.4382,  0.0000,  0.0000,
-        -0.3282, -0.1311,  0.5999,  0.0000, -0.1905, -0.4161,  0.0000,  0.0000,
-         0.0000, -0.1318,  0.0000,  0.0000, -0.7044,  0.3391,  0.0000, -0.8283,
-         0.1841,  0.0000,  0.3078,  0.2255, -0.1338,  0.0000, -0.6702, -1.3889,
-         0.0000, -0.4941,  0.0000,  0.0000,  0.0000,  0.0000,  0.3234,  0.7290,
-         0.2685,  0.0000, -0.0894,  0.0000,  0.0000, -0.2145,  0.0000,  0.4115,
-         1.1156,  0.0000, -0.1158,  0.2294, -0.4062, -0.3100,  0.8765,  0.0000,
-         0.0000,  0.0000, -0.1203,  1.1004,  0.0000, -0.2566,  0.1253, -0.3804],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9986e-01,  1.7076e-01, -1.4639e-01, -1.7220e-02,  2.6504e+00,
-         4.0107e-01,  7.9100e-11, -6.2458e-04, -2.9674e-01, -2.0105e-02,
-         5.2917e-01,  1.4220e-02, -2.1059e-01, -4.1158e-01, -1.4185e-07,
-        -1.0543e-04, -1.0370e-05, -2.3769e-01, -2.3986e-08,  2.8839e-04,
-        -7.1288e-01,  2.9436e-01, -5.0937e-08, -8.0589e-01,  2.7268e-01,
-         4.1341e-04,  2.7770e-01,  1.8201e-01, -9.5636e-02,  7.4081e-08,
-        -6.5317e-01, -1.3796e+00, -1.8963e-11, -4.5579e-01,  9.4124e-06,
-         5.0786e-04,  3.1398e-07,  0.0000e+00,  5.1522e-01,  7.7031e-01,
-         3.0255e-01, -4.1937e-05, -2.0089e-01,  7.7272e-08,  1.8900e-05,
-        -3.7161e-01, -2.6667e-03,  4.5057e-01,  1.1019e+00, -4.9193e-09,
-        -8.0208e-02,  2.0031e-01, -4.0204e-01, -4.6599e-01,  8.9759e-01,
-         1.0238e-12, -4.0237e-07,  1.4471e-10, -9.4082e-02,  1.0884e+00,
-         3.8654e-10, -2.5072e-01,  1.3785e-01, -3.8612e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2999,  0.1708, -0.1464, -0.0172,  2.6504,  0.4011,  0.0000,  0.0000,
-        -0.2967, -0.0201,  0.5292,  0.0000, -0.2106, -0.4116,  0.0000,  0.0000,
-         0.0000, -0.2377,  0.0000,  0.0000, -0.7129,  0.2944,  0.0000, -0.8059,
-         0.2727,  0.0000,  0.2777,  0.1820, -0.0956,  0.0000, -0.6532, -1.3796,
-         0.0000, -0.4558,  0.0000,  0.0000,  0.0000,  0.0000,  0.5152,  0.7703,
-         0.3026,  0.0000, -0.2009,  0.0000,  0.0000, -0.3716,  0.0000,  0.4506,
-         1.1019,  0.0000, -0.0802,  0.2003, -0.4020, -0.4660,  0.8976,  0.0000,
-         0.0000,  0.0000, -0.0941,  1.0884,  0.0000, -0.2507,  0.1379, -0.3861],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2999,  0.1708, -0.1464, -0.0172,  2.6504,  0.4011,  0.0000,  0.0000,
-        -0.2967, -0.0201,  0.5292,  0.0000, -0.2106, -0.4116,  0.0000,  0.0000,
-         0.0000, -0.2377,  0.0000,  0.0000, -0.7129,  0.2944,  0.0000, -0.8059,
-         0.2727,  0.0000,  0.2777,  0.1820, -0.0956,  0.0000, -0.6532, -1.3796,
-         0.0000, -0.4558,  0.0000,  0.0000,  0.0000,  0.0000,  0.5152,  0.7703,
-         0.3026,  0.0000, -0.2009,  0.0000,  0.0000, -0.3716,  0.0000,  0.4506,
-         1.1019,  0.0000, -0.0802,  0.2003, -0.4020, -0.4660,  0.8976,  0.0000,
-         0.0000,  0.0000, -0.0941,  1.0884,  0.0000, -0.2507,  0.1379, -0.3861],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1589e-01,  2.4695e-01, -1.3026e-01, -4.1884e-02,  2.6430e+00,
-         3.5168e-01,  6.8004e-11, -5.3697e-04, -2.5217e-01,  1.0416e-01,
-         4.6424e-01,  1.2225e-02, -1.9891e-01, -4.3790e-01, -1.2195e-07,
-        -9.0638e-05, -8.9156e-06, -2.6892e-01, -2.0621e-08,  2.4794e-04,
-        -7.2984e-01,  2.6256e-01, -4.3792e-08, -8.0923e-01,  3.5193e-01,
-         3.5542e-04,  2.8743e-01,  1.5142e-01, -6.4233e-02,  6.3689e-08,
-        -6.3132e-01, -1.3727e+00, -1.6303e-11, -4.4096e-01,  8.0921e-06,
-         4.3662e-04,  2.6993e-07,  0.0000e+00,  5.9528e-01,  7.5479e-01,
-         3.3337e-01, -3.6054e-05, -2.7858e-01,  6.6433e-08,  1.6249e-05,
-        -4.3562e-01, -2.2927e-03,  4.5986e-01,  1.0795e+00, -4.2293e-09,
-        -4.6698e-02,  1.5005e-01, -4.2041e-01, -4.3854e-01,  8.9134e-01,
-         8.8018e-13, -3.4593e-07,  1.2441e-10,  4.6207e-02,  1.0784e+00,
-         3.3232e-10, -1.7998e-01,  7.0286e-02, -3.5704e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3159,  0.2469, -0.1303, -0.0419,  2.6430,  0.3517,  0.0000,  0.0000,
-        -0.2522,  0.1042,  0.4642,  0.0000, -0.1989, -0.4379,  0.0000,  0.0000,
-         0.0000, -0.2689,  0.0000,  0.0000, -0.7298,  0.2626,  0.0000, -0.8092,
-         0.3519,  0.0000,  0.2874,  0.1514, -0.0642,  0.0000, -0.6313, -1.3727,
-         0.0000, -0.4410,  0.0000,  0.0000,  0.0000,  0.0000,  0.5953,  0.7548,
-         0.3334,  0.0000, -0.2786,  0.0000,  0.0000, -0.4356,  0.0000,  0.4599,
-         1.0795,  0.0000, -0.0467,  0.1500, -0.4204, -0.4385,  0.8913,  0.0000,
-         0.0000,  0.0000,  0.0462,  1.0784,  0.0000, -0.1800,  0.0703, -0.3570],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3159,  0.2469, -0.1303, -0.0419,  2.6430,  0.3517,  0.0000,  0.0000,
-        -0.2522,  0.1042,  0.4642,  0.0000, -0.1989, -0.4379,  0.0000,  0.0000,
-         0.0000, -0.2689,  0.0000,  0.0000, -0.7298,  0.2626,  0.0000, -0.8092,
-         0.3519,  0.0000,  0.2874,  0.1514, -0.0642,  0.0000, -0.6313, -1.3727,
-         0.0000, -0.4410,  0.0000,  0.0000,  0.0000,  0.0000,  0.5953,  0.7548,
-         0.3334,  0.0000, -0.2786,  0.0000,  0.0000, -0.4356,  0.0000,  0.4599,
-         1.0795,  0.0000, -0.0467,  0.1500, -0.4204, -0.4385,  0.8913,  0.0000,
-         0.0000,  0.0000,  0.0462,  1.0784,  0.0000, -0.1800,  0.0703, -0.3570],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3446e-01,  3.0970e-01, -7.9786e-02, -3.2658e-02,  2.6354e+00,
-         3.0823e-01,  5.8482e-11, -4.6178e-04, -2.1346e-01,  1.9379e-01,
-         3.8426e-01,  1.0513e-02, -1.6644e-01, -4.6282e-01, -1.0487e-07,
-        -7.7946e-05, -7.6672e-06, -2.5127e-01, -1.7734e-08,  2.1322e-04,
-        -7.3610e-01,  2.1041e-01, -3.7660e-08, -8.2115e-01,  3.9415e-01,
-         3.0565e-04,  3.0349e-01,  1.0883e-01, -4.6496e-02,  5.4771e-08,
-        -6.1660e-01, -1.3634e+00, -1.4020e-11, -4.3904e-01,  6.9590e-06,
-         3.7548e-04,  2.3213e-07,  0.0000e+00,  6.6014e-01,  7.4616e-01,
-         3.5960e-01, -3.1005e-05, -3.1356e-01,  5.7130e-08,  1.3974e-05,
-        -4.0975e-01, -1.9716e-03,  4.5151e-01,  1.0613e+00, -3.6370e-09,
-        -3.5793e-03,  7.0003e-02, -4.1092e-01, -4.1323e-01,  8.8169e-01,
-         7.5693e-13, -2.9749e-07,  1.0699e-10,  2.1963e-01,  1.0669e+00,
-         2.8579e-10, -7.5800e-02,  3.7141e-02, -2.9723e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3345,  0.3097, -0.0798, -0.0327,  2.6354,  0.3082,  0.0000,  0.0000,
-        -0.2135,  0.1938,  0.3843,  0.0000, -0.1664, -0.4628,  0.0000,  0.0000,
-         0.0000, -0.2513,  0.0000,  0.0000, -0.7361,  0.2104,  0.0000, -0.8212,
-         0.3941,  0.0000,  0.3035,  0.1088, -0.0465,  0.0000, -0.6166, -1.3634,
-         0.0000, -0.4390,  0.0000,  0.0000,  0.0000,  0.0000,  0.6601,  0.7462,
-         0.3596,  0.0000, -0.3136,  0.0000,  0.0000, -0.4098,  0.0000,  0.4515,
-         1.0613,  0.0000, -0.0036,  0.0700, -0.4109, -0.4132,  0.8817,  0.0000,
-         0.0000,  0.0000,  0.2196,  1.0669,  0.0000, -0.0758,  0.0371, -0.2972],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3345,  0.3097, -0.0798, -0.0327,  2.6354,  0.3082,  0.0000,  0.0000,
-        -0.2135,  0.1938,  0.3843,  0.0000, -0.1664, -0.4628,  0.0000,  0.0000,
-         0.0000, -0.2513,  0.0000,  0.0000, -0.7361,  0.2104,  0.0000, -0.8212,
-         0.3941,  0.0000,  0.3035,  0.1088, -0.0465,  0.0000, -0.6166, -1.3634,
-         0.0000, -0.4390,  0.0000,  0.0000,  0.0000,  0.0000,  0.6601,  0.7462,
-         0.3596,  0.0000, -0.3136,  0.0000,  0.0000, -0.4098,  0.0000,  0.4515,
-         1.0613,  0.0000, -0.0036,  0.0700, -0.4109, -0.4132,  0.8817,  0.0000,
-         0.0000,  0.0000,  0.2196,  1.0669,  0.0000, -0.0758,  0.0371, -0.2972],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4276e-01,  3.1536e-01, -3.5144e-03,  6.4016e-03,  2.6285e+00,
-         2.7322e-01,  5.0307e-11, -3.9723e-04, -1.4724e-01,  1.9810e-01,
-         2.9762e-01,  9.0438e-03, -7.5566e-02, -4.9886e-01, -9.0215e-08,
-        -6.7051e-05, -6.5955e-06, -1.7927e-01, -1.5255e-08,  1.8341e-04,
-        -7.3683e-01,  1.3379e-01, -3.2396e-08, -8.3544e-01,  4.0887e-01,
-         2.6293e-04,  3.0234e-01,  7.3373e-02, -5.8009e-02,  4.7115e-08,
-        -6.0160e-01, -1.3567e+00, -1.2060e-11, -4.5221e-01,  5.9863e-06,
-         3.2300e-04,  1.9969e-07,  0.0000e+00,  6.9583e-01,  7.1754e-01,
-         3.5014e-01, -2.6672e-05, -3.1348e-01,  4.9145e-08,  1.2020e-05,
-        -3.0537e-01, -1.6960e-03,  4.0880e-01,  1.0504e+00, -3.1287e-09,
-        -2.2029e-03, -1.0204e-02, -3.6491e-01, -3.9096e-01,  8.6109e-01,
-         6.5113e-13, -2.5590e-07,  9.2033e-11,  3.9798e-01,  1.0620e+00,
-         2.4584e-10,  8.0830e-02, -9.2811e-03, -2.2349e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.4276e-01,  3.1536e-01, -3.5144e-03,  6.4016e-03,  2.6285e+00,
-         2.7322e-01,  0.0000e+00,  0.0000e+00, -1.4724e-01,  1.9810e-01,
-         2.9762e-01,  0.0000e+00, -7.5566e-02, -4.9886e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.7927e-01,  0.0000e+00,  0.0000e+00,
-        -7.3683e-01,  1.3379e-01,  0.0000e+00, -8.3544e-01,  4.0887e-01,
-         0.0000e+00,  3.0234e-01,  7.3373e-02, -5.8009e-02,  0.0000e+00,
-        -6.0160e-01, -1.3567e+00,  0.0000e+00, -4.5221e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.9583e-01,  7.1754e-01,
-         3.5014e-01,  0.0000e+00, -3.1348e-01,  0.0000e+00,  0.0000e+00,
-        -3.0537e-01,  0.0000e+00,  4.0880e-01,  1.0504e+00,  0.0000e+00,
-        -2.2029e-03, -1.0204e-02, -3.6491e-01, -3.9096e-01,  8.6109e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.9798e-01,  1.0620e+00,
-         0.0000e+00,  8.0830e-02, -9.2811e-03, -2.2349e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.4276e-01,  3.1536e-01, -3.5144e-03,  6.4016e-03,  2.6285e+00,
-         2.7322e-01,  0.0000e+00,  0.0000e+00, -1.4724e-01,  1.9810e-01,
-         2.9762e-01,  0.0000e+00, -7.5566e-02, -4.9886e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.7927e-01,  0.0000e+00,  0.0000e+00,
-        -7.3683e-01,  1.3379e-01,  0.0000e+00, -8.3544e-01,  4.0887e-01,
-         0.0000e+00,  3.0234e-01,  7.3373e-02, -5.8009e-02,  0.0000e+00,
-        -6.0160e-01, -1.3567e+00,  0.0000e+00, -4.5221e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.9583e-01,  7.1754e-01,
-         3.5014e-01,  0.0000e+00, -3.1348e-01,  0.0000e+00,  0.0000e+00,
-        -3.0537e-01,  0.0000e+00,  4.0880e-01,  1.0504e+00,  0.0000e+00,
-        -2.2029e-03, -1.0204e-02, -3.6491e-01, -3.9096e-01,  8.6109e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.9798e-01,  1.0620e+00,
-         0.0000e+00,  8.0830e-02, -9.2811e-03, -2.2349e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2637e-01,  3.0079e-01,  6.8790e-02,  7.7466e-02,  2.6213e+00,
-         1.9715e-01,  4.3288e-11, -3.4181e-04, -7.1124e-02,  2.1444e-01,
-         2.6414e-01,  7.7820e-03, -1.2519e-02, -5.3058e-01, -7.7628e-08,
-        -5.7696e-05, -5.6752e-06, -9.7288e-02, -1.3126e-08,  1.5782e-04,
-        -7.2624e-01,  8.9622e-02, -2.7876e-08, -8.3681e-01,  4.3087e-01,
-         2.2624e-04,  3.1630e-01,  2.8983e-02, -3.3566e-02,  4.0541e-08,
-        -5.8602e-01, -1.3515e+00, -1.0378e-11, -4.1797e-01,  5.1510e-06,
-         2.7793e-04,  1.7183e-07,  0.0000e+00,  7.0137e-01,  6.9736e-01,
-         3.4796e-01, -2.2950e-05, -2.8276e-01,  4.2288e-08,  1.0343e-05,
-        -2.2141e-01, -1.4594e-03,  3.5084e-01,  1.0358e+00, -2.6921e-09,
-         1.8981e-03, -2.4629e-02, -3.0311e-01, -3.9143e-01,  8.4325e-01,
-         5.6028e-13, -2.2020e-07,  7.9192e-11,  5.5549e-01,  1.0506e+00,
-         2.1154e-10,  2.1586e-01, -3.3342e-02, -1.3558e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.2637e-01,  3.0079e-01,  6.8790e-02,  7.7466e-02,  2.6213e+00,
-         1.9715e-01,  0.0000e+00,  0.0000e+00, -7.1124e-02,  2.1444e-01,
-         2.6414e-01,  0.0000e+00, -1.2519e-02, -5.3058e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -9.7288e-02,  0.0000e+00,  0.0000e+00,
-        -7.2624e-01,  8.9622e-02,  0.0000e+00, -8.3681e-01,  4.3087e-01,
-         0.0000e+00,  3.1630e-01,  2.8983e-02, -3.3566e-02,  0.0000e+00,
-        -5.8602e-01, -1.3515e+00,  0.0000e+00, -4.1797e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.0137e-01,  6.9736e-01,
-         3.4796e-01,  0.0000e+00, -2.8276e-01,  0.0000e+00,  0.0000e+00,
-        -2.2141e-01,  0.0000e+00,  3.5084e-01,  1.0358e+00,  0.0000e+00,
-         1.8981e-03, -2.4629e-02, -3.0311e-01, -3.9143e-01,  8.4325e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.5549e-01,  0.0000e+00,
-         0.0000e+00,  2.1586e-01, -3.3342e-02, -1.3558e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.2637e-01,  3.0079e-01,  6.8790e-02,  7.7466e-02,  2.6213e+00,
-         1.9715e-01,  0.0000e+00,  0.0000e+00, -7.1124e-02,  2.1444e-01,
-         2.6414e-01,  0.0000e+00, -1.2519e-02, -5.3058e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -9.7288e-02,  0.0000e+00,  0.0000e+00,
-        -7.2624e-01,  8.9622e-02,  0.0000e+00, -8.3681e-01,  4.3087e-01,
-         0.0000e+00,  3.1630e-01,  2.8983e-02, -3.3566e-02,  0.0000e+00,
-        -5.8602e-01, -1.3515e+00,  0.0000e+00, -4.1797e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.0137e-01,  6.9736e-01,
-         3.4796e-01,  0.0000e+00, -2.8276e-01,  0.0000e+00,  0.0000e+00,
-        -2.2141e-01,  0.0000e+00,  3.5084e-01,  1.0358e+00,  0.0000e+00,
-         1.8981e-03, -2.4629e-02, -3.0311e-01, -3.9143e-01,  8.4325e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.5549e-01,  0.0000e+00,
-         0.0000e+00,  2.1586e-01, -3.3342e-02, -1.3558e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7827e-01,  2.7671e-01,  7.0787e-02,  1.3648e-01,  2.6155e+00,
-         1.0596e-01,  3.7259e-11, -2.9420e-04,  1.4822e-02,  1.9266e-01,
-         2.4061e-01,  6.6982e-03,  8.6958e-03, -5.6034e-01, -6.6816e-08,
-        -4.9660e-05, -4.8848e-06, -2.3332e-02, -1.1298e-08,  1.3584e-04,
-        -7.0285e-01,  8.9440e-02, -2.3993e-08, -8.2644e-01,  4.6786e-01,
-         1.9473e-04,  2.8796e-01,  2.7942e-03,  3.0316e-03,  3.4895e-08,
-        -5.6220e-01, -1.3471e+00, -8.9324e-12, -4.0848e-01,  4.4336e-06,
-         2.3922e-04,  1.4790e-07,  0.0000e+00,  6.7348e-01,  6.6133e-01,
-         3.5311e-01, -1.9754e-05, -2.7728e-01,  3.6398e-08,  8.9027e-06,
-        -1.6202e-01, -1.2561e-03,  2.9813e-01,  1.0379e+00, -2.3172e-09,
-         1.4623e-02, -3.3228e-02, -2.5407e-01, -3.8555e-01,  8.3178e-01,
-         4.8225e-13, -1.8953e-07,  6.8162e-11,  6.6804e-01, -9.8441e-03,
-         1.8208e-10,  2.7831e-01, -1.8815e-02, -6.2175e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2783,  0.2767,  0.0708,  0.1365,  2.6155,  0.1060,  0.0000,  0.0000,
-         0.0148,  0.1927,  0.2406,  0.0000,  0.0087, -0.5603,  0.0000,  0.0000,
-         0.0000, -0.0233,  0.0000,  0.0000, -0.7028,  0.0894,  0.0000, -0.8264,
-         0.4679,  0.0000,  0.2880,  0.0028,  0.0030,  0.0000, -0.5622, -1.3471,
-         0.0000, -0.4085,  0.0000,  0.0000,  0.0000,  0.0000,  0.6735,  0.6613,
-         0.3531,  0.0000, -0.2773,  0.0000,  0.0000, -0.1620,  0.0000,  0.2981,
-         1.0379,  0.0000,  0.0146, -0.0332, -0.2541, -0.3856,  0.8318,  0.0000,
-         0.0000,  0.0000,  0.6680,  0.0000,  0.0000,  0.2783, -0.0188, -0.0622],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2783,  0.2767,  0.0708,  0.1365,  2.6155,  0.1060,  0.0000,  0.0000,
-         0.0148,  0.1927,  0.2406,  0.0000,  0.0087, -0.5603,  0.0000,  0.0000,
-         0.0000, -0.0233,  0.0000,  0.0000, -0.7028,  0.0894,  0.0000, -0.8264,
-         0.4679,  0.0000,  0.2880,  0.0028,  0.0030,  0.0000, -0.5622, -1.3471,
-         0.0000, -0.4085,  0.0000,  0.0000,  0.0000,  0.0000,  0.6735,  0.6613,
-         0.3531,  0.0000, -0.2773,  0.0000,  0.0000, -0.1620,  0.0000,  0.2981,
-         1.0379,  0.0000,  0.0146, -0.0332, -0.2541, -0.3856,  0.8318,  0.0000,
-         0.0000,  0.0000,  0.6680,  0.0000,  0.0000,  0.2783, -0.0188, -0.0622],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.7936e-01,  3.0453e-01,  1.5097e-02,  1.7338e-01,  2.6121e+00,
-         1.3216e-02,  3.2080e-11, -2.5330e-04,  1.8587e-02,  1.7673e-01,
-         2.1569e-01,  5.7670e-03, -6.0712e-02, -5.8702e-01, -5.7528e-08,
-        -4.2757e-05, -4.2058e-06, -4.6727e-02, -9.7277e-09,  1.1696e-04,
-        -6.8272e-01,  1.1866e-01, -2.0658e-08, -8.0133e-01,  5.6137e-01,
-         1.6766e-04,  2.3355e-01, -1.3338e-03,  9.5496e-02,  3.0044e-08,
-        -5.4001e-01, -1.3467e+00, -7.6907e-12, -3.7930e-01,  3.8173e-06,
-         2.0597e-04,  1.2734e-07,  0.0000e+00,  6.0747e-01,  5.9366e-01,
-         4.1607e-01, -1.7008e-05, -2.9327e-01,  3.1338e-08,  7.6651e-06,
-        -1.0513e-01, -1.0815e-03,  2.2642e-01,  1.0519e+00, -1.9951e-09,
-         8.3533e-02, -4.9148e-03, -2.6531e-01, -3.1990e-01,  8.0637e-01,
-         4.1521e-13, -1.6318e-07,  5.8687e-11,  7.4716e-01, -8.4757e-03,
-         1.5677e-10,  2.4216e-01,  4.6870e-04, -3.9941e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.7936e-01,  3.0453e-01,  1.5097e-02,  1.7338e-01,  2.6121e+00,
-         1.3216e-02,  0.0000e+00,  0.0000e+00,  1.8587e-02,  1.7673e-01,
-         2.1569e-01,  0.0000e+00, -6.0712e-02, -5.8702e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.6727e-02,  0.0000e+00,  0.0000e+00,
-        -6.8272e-01,  1.1866e-01,  0.0000e+00, -8.0133e-01,  5.6137e-01,
-         0.0000e+00,  2.3355e-01, -1.3338e-03,  9.5496e-02,  0.0000e+00,
-        -5.4001e-01, -1.3467e+00,  0.0000e+00, -3.7930e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0747e-01,  5.9366e-01,
-         4.1607e-01,  0.0000e+00, -2.9327e-01,  0.0000e+00,  0.0000e+00,
-        -1.0513e-01,  0.0000e+00,  2.2642e-01,  1.0519e+00,  0.0000e+00,
-         8.3533e-02, -4.9148e-03, -2.6531e-01, -3.1990e-01,  8.0637e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.4716e-01,  0.0000e+00,
-         0.0000e+00,  2.4216e-01,  4.6870e-04, -3.9941e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.7936e-01,  3.0453e-01,  1.5097e-02,  1.7338e-01,  2.6121e+00,
-         1.3216e-02,  0.0000e+00,  0.0000e+00,  1.8587e-02,  1.7673e-01,
-         2.1569e-01,  0.0000e+00, -6.0712e-02, -5.8702e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.6727e-02,  0.0000e+00,  0.0000e+00,
-        -6.8272e-01,  1.1866e-01,  0.0000e+00, -8.0133e-01,  5.6137e-01,
-         0.0000e+00,  2.3355e-01, -1.3338e-03,  9.5496e-02,  0.0000e+00,
-        -5.4001e-01, -1.3467e+00,  0.0000e+00, -3.7930e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0747e-01,  5.9366e-01,
-         4.1607e-01,  0.0000e+00, -2.9327e-01,  0.0000e+00,  0.0000e+00,
-        -1.0513e-01,  0.0000e+00,  2.2642e-01,  1.0519e+00,  0.0000e+00,
-         8.3533e-02, -4.9148e-03, -2.6531e-01, -3.1990e-01,  8.0637e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.4716e-01,  0.0000e+00,
-         0.0000e+00,  2.4216e-01,  4.6870e-04, -3.9941e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 7.0065e-02,  3.1371e-01, -1.9506e-02,  2.0379e-01,  2.6102e+00,
-        -1.9865e-02,  2.7628e-11, -2.1816e-04,  5.2595e-02,  1.5022e-01,
-         1.5863e-01,  4.9668e-03, -1.3202e-01, -5.9518e-01, -4.9546e-08,
-        -3.6824e-05, -3.6222e-06, -1.2782e-01, -8.3779e-09,  1.0073e-04,
-        -6.4142e-01,  8.8004e-02, -1.7792e-08, -7.9065e-01,  6.0451e-01,
-         1.4440e-04,  1.7498e-01, -1.3778e-02,  2.0686e-01,  2.5875e-08,
-        -5.4047e-01, -1.3431e+00, -6.6236e-12, -3.4356e-01,  3.2876e-06,
-         1.7739e-04,  1.0967e-07,  0.0000e+00,  5.7402e-01,  5.6533e-01,
-         4.3342e-01, -1.4648e-05, -3.1958e-01,  2.6990e-08,  6.6015e-06,
-        -8.7567e-02, -9.3146e-04,  1.6705e-01,  1.0717e+00, -1.7182e-09,
-         9.8817e-02,  5.5946e-02, -2.5770e-01, -3.1858e-01,  7.8982e-01,
-         3.5760e-13, -1.4054e-07,  5.0544e-11,  7.7452e-01, -7.2996e-03,
-         1.3501e-10,  2.0570e-01, -3.0742e-03, -2.8856e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0701,  0.3137, -0.0195,  0.2038,  2.6102, -0.0199,  0.0000,  0.0000,
-         0.0526,  0.1502,  0.1586,  0.0000, -0.1320, -0.5952,  0.0000,  0.0000,
-         0.0000, -0.1278,  0.0000,  0.0000, -0.6414,  0.0880,  0.0000, -0.7907,
-         0.6045,  0.0000,  0.1750, -0.0138,  0.2069,  0.0000, -0.5405, -1.3431,
-         0.0000, -0.3436,  0.0000,  0.0000,  0.0000,  0.0000,  0.5740,  0.5653,
-         0.4334,  0.0000, -0.3196,  0.0000,  0.0000, -0.0876,  0.0000,  0.1670,
-         1.0717,  0.0000,  0.0988,  0.0559, -0.2577, -0.3186,  0.7898,  0.0000,
-         0.0000,  0.0000,  0.7745,  0.0000,  0.0000,  0.2057, -0.0031, -0.0289],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0701,  0.3137, -0.0195,  0.2038,  2.6102, -0.0199,  0.0000,  0.0000,
-         0.0526,  0.1502,  0.1586,  0.0000, -0.1320, -0.5952,  0.0000,  0.0000,
-         0.0000, -0.1278,  0.0000,  0.0000, -0.6414,  0.0880,  0.0000, -0.7907,
-         0.6045,  0.0000,  0.1750, -0.0138,  0.2069,  0.0000, -0.5405, -1.3431,
-         0.0000, -0.3436,  0.0000,  0.0000,  0.0000,  0.0000,  0.5740,  0.5653,
-         0.4334,  0.0000, -0.3196,  0.0000,  0.0000, -0.0876,  0.0000,  0.1670,
-         1.0717,  0.0000,  0.0988,  0.0559, -0.2577, -0.3186,  0.7898,  0.0000,
-         0.0000,  0.0000,  0.7745,  0.0000,  0.0000,  0.2057, -0.0031, -0.0289],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.3838e-02,  2.8430e-01,  4.1292e-02,  2.1230e-01,  2.6069e+00,
-         9.0751e-02,  2.3802e-11, -1.8794e-04,  1.3358e-01,  6.4146e-02,
-        -2.6545e-02,  4.2790e-03, -1.1373e-01, -5.9493e-01, -4.2684e-08,
-        -3.1724e-05, -3.1206e-06, -1.0975e-01, -7.2176e-09,  8.6780e-05,
-        -6.1596e-01, -4.0593e-02, -1.5328e-08, -8.0359e-01,  5.1597e-01,
-         1.2440e-04,  1.2774e-01, -6.3832e-03,  3.6516e-01,  2.2292e-08,
-        -5.1830e-01, -1.3407e+00, -5.7063e-12, -3.0675e-01,  2.8323e-06,
-         1.5282e-04,  9.4479e-08,  0.0000e+00,  6.0493e-01,  5.7349e-01,
-         3.4717e-01, -1.2619e-05, -3.0151e-01,  2.3252e-08,  5.6872e-06,
-        -1.8385e-01, -8.0246e-04,  1.9560e-01,  1.0823e+00, -1.4803e-09,
-         1.8315e-02, -4.5105e-02, -3.5821e-01, -1.9945e-01,  7.8836e-01,
-         3.0807e-13, -1.2108e-07,  4.3544e-11,  7.3948e-01, -6.2887e-03,
-         1.1632e-10,  2.3044e-01, -6.8403e-02, -2.8809e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0938,  0.2843,  0.0413,  0.2123,  2.6069,  0.0908,  0.0000,  0.0000,
-         0.1336,  0.0641, -0.0265,  0.0000, -0.1137, -0.5949,  0.0000,  0.0000,
-         0.0000, -0.1098,  0.0000,  0.0000, -0.6160, -0.0406,  0.0000, -0.8036,
-         0.5160,  0.0000,  0.1277, -0.0064,  0.3652,  0.0000, -0.5183, -1.3407,
-         0.0000, -0.3067,  0.0000,  0.0000,  0.0000,  0.0000,  0.6049,  0.5735,
-         0.3472,  0.0000, -0.3015,  0.0000,  0.0000, -0.1839,  0.0000,  0.1956,
-         1.0823,  0.0000,  0.0183, -0.0451, -0.3582, -0.1994,  0.7884,  0.0000,
-         0.0000,  0.0000,  0.7395,  0.0000,  0.0000,  0.2304, -0.0684, -0.0288],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0938,  0.2843,  0.0413,  0.2123,  2.6069,  0.0908,  0.0000,  0.0000,
-         0.1336,  0.0641, -0.0265,  0.0000, -0.1137, -0.5949,  0.0000,  0.0000,
-         0.0000, -0.1098,  0.0000,  0.0000, -0.6160, -0.0406,  0.0000, -0.8036,
-         0.5160,  0.0000,  0.1277, -0.0064,  0.3652,  0.0000, -0.5183, -1.3407,
-         0.0000, -0.3067,  0.0000,  0.0000,  0.0000,  0.0000,  0.6049,  0.5735,
-         0.3472,  0.0000, -0.3015,  0.0000,  0.0000, -0.1839,  0.0000,  0.1956,
-         1.0823,  0.0000,  0.0183, -0.0451, -0.3582, -0.1994,  0.7884,  0.0000,
-         0.0000,  0.0000,  0.7395,  0.0000,  0.0000,  0.2304, -0.0684, -0.0288],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.3765e-02,  2.4966e-01,  2.6514e-02,  1.8825e-01,  2.6063e+00,
-         1.6574e-01,  2.0512e-11, -1.6197e-04,  2.1275e-01, -2.5420e-02,
-        -1.2117e-01,  3.6875e-03, -1.1914e-01, -5.6019e-01, -3.6784e-08,
-        -2.7339e-05, -2.6892e-06, -2.0505e-01, -6.2200e-09,  7.4785e-05,
-        -5.5346e-01, -1.6983e-01, -1.3209e-08, -8.2769e-01,  4.2376e-01,
-         1.0720e-04,  5.3785e-02, -4.1806e-03,  4.9748e-01,  1.9210e-08,
-        -4.7906e-01, -1.3382e+00, -4.9175e-12, -3.2176e-01,  2.4408e-06,
-         1.3170e-04,  8.1420e-08,  0.0000e+00,  6.0367e-01,  5.7964e-01,
-         2.7926e-01, -1.0875e-05, -3.5318e-01,  2.0038e-08,  4.9011e-06,
-        -2.1787e-01, -6.9153e-04,  2.0333e-01,  1.0902e+00, -1.2757e-09,
-        -7.4045e-02, -8.8392e-02, -3.5711e-01, -2.9035e-01,  7.7763e-01,
-         2.6549e-13, -1.0434e-07,  3.7525e-11,  6.9522e-01, -5.4194e-03,
-         1.0024e-10,  2.2369e-01, -1.1674e-01, -5.7528e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0838,  0.2497,  0.0265,  0.1883,  2.6063,  0.1657,  0.0000,  0.0000,
-         0.2127, -0.0254, -0.1212,  0.0000, -0.1191, -0.5602,  0.0000,  0.0000,
-         0.0000, -0.2051,  0.0000,  0.0000, -0.5535, -0.1698,  0.0000, -0.8277,
-         0.4238,  0.0000,  0.0538, -0.0042,  0.4975,  0.0000, -0.4791, -1.3382,
-         0.0000, -0.3218,  0.0000,  0.0000,  0.0000,  0.0000,  0.6037,  0.5796,
-         0.2793,  0.0000, -0.3532,  0.0000,  0.0000, -0.2179,  0.0000,  0.2033,
-         1.0902,  0.0000, -0.0740, -0.0884, -0.3571, -0.2904,  0.7776,  0.0000,
-         0.0000,  0.0000,  0.6952,  0.0000,  0.0000,  0.2237, -0.1167, -0.0575],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0838,  0.2497,  0.0265,  0.1883,  2.6063,  0.1657,  0.0000,  0.0000,
-         0.2127, -0.0254, -0.1212,  0.0000, -0.1191, -0.5602,  0.0000,  0.0000,
-         0.0000, -0.2051,  0.0000,  0.0000, -0.5535, -0.1698,  0.0000, -0.8277,
-         0.4238,  0.0000,  0.0538, -0.0042,  0.4975,  0.0000, -0.4791, -1.3382,
-         0.0000, -0.3218,  0.0000,  0.0000,  0.0000,  0.0000,  0.6037,  0.5796,
-         0.2793,  0.0000, -0.3532,  0.0000,  0.0000, -0.2179,  0.0000,  0.2033,
-         1.0902,  0.0000, -0.0740, -0.0884, -0.3571, -0.2904,  0.7776,  0.0000,
-         0.0000,  0.0000,  0.6952,  0.0000,  0.0000,  0.2237, -0.1167, -0.0575],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.1670e-02,  2.7265e-01, -4.7193e-02,  7.5304e-02,  2.6037e+00,
-         1.6776e-01,  1.7682e-11, -1.3962e-04,  2.5282e-01, -9.9736e-02,
-        -1.2384e-02,  3.1788e-03, -1.9863e-01, -5.6802e-01, -3.1709e-08,
-        -2.3567e-05, -2.3182e-06, -2.3874e-01, -5.3619e-09,  6.4467e-05,
-        -4.5752e-01, -1.7830e-01, -1.1387e-08, -7.8851e-01,  3.3238e-01,
-         9.2415e-05, -3.0496e-02, -3.2712e-02,  6.4108e-01,  1.6560e-08,
-        -4.6195e-01, -1.3296e+00, -4.2391e-12, -2.2793e-01,  2.1041e-06,
-         1.1353e-04,  7.0187e-08,  0.0000e+00,  5.6857e-01,  5.5359e-01,
-         2.3883e-01, -9.3746e-06, -4.3676e-01,  1.7274e-08,  4.2250e-06,
-        -2.7387e-01, -5.9613e-04,  2.2122e-01,  1.0894e+00, -1.0997e-09,
-        -2.8585e-02, -1.7422e-02, -3.4544e-01, -3.3349e-01,  7.4652e-01,
-         2.2886e-13, -8.9946e-08,  3.2348e-11,  6.3989e-01, -4.6717e-03,
-         8.6409e-11,  1.7249e-01, -1.4442e-01, -1.0016e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0117,  0.2726, -0.0472,  0.0753,  2.6037,  0.1678,  0.0000,  0.0000,
-         0.2528, -0.0997, -0.0124,  0.0000, -0.1986, -0.5680,  0.0000,  0.0000,
-         0.0000, -0.2387,  0.0000,  0.0000, -0.4575, -0.1783,  0.0000, -0.7885,
-         0.3324,  0.0000, -0.0305, -0.0327,  0.6411,  0.0000, -0.4619, -1.3296,
-         0.0000, -0.2279,  0.0000,  0.0000,  0.0000,  0.0000,  0.5686,  0.5536,
-         0.2388,  0.0000, -0.4368,  0.0000,  0.0000, -0.2739,  0.0000,  0.2212,
-         1.0894,  0.0000, -0.0286, -0.0174, -0.3454, -0.3335,  0.7465,  0.0000,
-         0.0000,  0.0000,  0.6399,  0.0000,  0.0000,  0.1725, -0.1444, -0.1002],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0117,  0.2726, -0.0472,  0.0753,  2.6037,  0.1678,  0.0000,  0.0000,
-         0.2528, -0.0997, -0.0124,  0.0000, -0.1986, -0.5680,  0.0000,  0.0000,
-         0.0000, -0.2387,  0.0000,  0.0000, -0.4575, -0.1783,  0.0000, -0.7885,
-         0.3324,  0.0000, -0.0305, -0.0327,  0.6411,  0.0000, -0.4619, -1.3296,
-         0.0000, -0.2279,  0.0000,  0.0000,  0.0000,  0.0000,  0.5686,  0.5536,
-         0.2388,  0.0000, -0.4368,  0.0000,  0.0000, -0.2739,  0.0000,  0.2212,
-         1.0894,  0.0000, -0.0286, -0.0174, -0.3454, -0.3335,  0.7465,  0.0000,
-         0.0000,  0.0000,  0.6399,  0.0000,  0.0000,  0.1725, -0.1444, -0.1002],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.7204e-01,  3.2178e-01, -1.4703e-01, -2.8551e-02,  2.5997e+00,
-         1.2472e-01,  1.5247e-11, -1.2040e-04,  2.3038e-01, -1.3700e-01,
-         9.3311e-02,  2.7411e-03, -3.2469e-01, -5.8426e-01, -2.7343e-08,
-        -2.0322e-05, -1.9990e-06, -2.1220e-01, -4.6236e-09,  5.5591e-05,
-        -3.7621e-01, -1.3197e-01, -9.8188e-09, -7.6244e-01,  2.8318e-01,
-         7.9690e-05, -8.8889e-02, -6.1810e-02,  7.5561e-01,  1.4280e-08,
-        -4.2168e-01, -1.3220e+00, -3.6554e-12, -1.5515e-01,  1.8144e-06,
-         9.7896e-05,  6.0523e-08,  0.0000e+00,  4.9392e-01,  5.0235e-01,
-         2.3983e-01, -8.0838e-06, -5.0152e-01,  1.4895e-08,  3.6432e-06,
-        -2.9534e-01, -5.1405e-04,  2.2077e-01,  1.0912e+00, -9.4826e-10,
-         5.0534e-02,  2.8403e-02, -3.5502e-01, -2.7906e-01,  7.0893e-01,
-         1.9735e-13, -7.7562e-08,  2.7894e-11,  5.7576e-01, -4.0285e-03,
-         7.4511e-11,  1.1739e-01, -1.6277e-01, -1.0950e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1720,  0.3218, -0.1470, -0.0286,  2.5997,  0.1247,  0.0000,  0.0000,
-         0.2304, -0.1370,  0.0933,  0.0000, -0.3247, -0.5843,  0.0000,  0.0000,
-         0.0000, -0.2122,  0.0000,  0.0000, -0.3762, -0.1320,  0.0000, -0.7624,
-         0.2832,  0.0000, -0.0889, -0.0618,  0.7556,  0.0000, -0.4217, -1.3220,
-         0.0000, -0.1552,  0.0000,  0.0000,  0.0000,  0.0000,  0.4939,  0.5023,
-         0.2398,  0.0000, -0.5015,  0.0000,  0.0000, -0.2953,  0.0000,  0.2208,
-         0.0000,  0.0000,  0.0505,  0.0284, -0.3550, -0.2791,  0.7089,  0.0000,
-         0.0000,  0.0000,  0.5758,  0.0000,  0.0000,  0.1174, -0.1628, -0.1095],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1720,  0.3218, -0.1470, -0.0286,  2.5997,  0.1247,  0.0000,  0.0000,
-         0.2304, -0.1370,  0.0933,  0.0000, -0.3247, -0.5843,  0.0000,  0.0000,
-         0.0000, -0.2122,  0.0000,  0.0000, -0.3762, -0.1320,  0.0000, -0.7624,
-         0.2832,  0.0000, -0.0889, -0.0618,  0.7556,  0.0000, -0.4217, -1.3220,
-         0.0000, -0.1552,  0.0000,  0.0000,  0.0000,  0.0000,  0.4939,  0.5023,
-         0.2398,  0.0000, -0.5015,  0.0000,  0.0000, -0.2953,  0.0000,  0.2208,
-         0.0000,  0.0000,  0.0505,  0.0284, -0.3550, -0.2791,  0.7089,  0.0000,
-         0.0000,  0.0000,  0.5758,  0.0000,  0.0000,  0.1174, -0.1628, -0.1095],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.5002e-01,  3.3807e-01, -1.8290e-01, -4.8839e-02,  2.5935e+00,
-         6.1193e-02,  1.3152e-11, -1.0385e-04,  2.1490e-01, -1.4038e-01,
-         6.9259e-02,  2.3644e-03, -3.9389e-01, -6.0241e-01, -2.3586e-08,
-        -1.7530e-05, -1.7243e-06, -8.1251e-02, -3.9882e-09,  4.7952e-05,
-        -3.8821e-01, -1.5784e-02, -8.4695e-09, -7.7198e-01,  2.7099e-01,
-         6.8739e-05, -1.1067e-01, -8.7536e-02,  8.3280e-01,  1.2318e-08,
-        -3.5070e-01, -1.3209e+00, -3.1531e-12, -9.8398e-02,  1.5650e-06,
-         8.4443e-05,  5.2206e-08,  0.0000e+00,  4.1829e-01,  4.3972e-01,
-         2.1830e-01, -6.9730e-06, -5.0779e-01,  1.2848e-08,  3.1426e-06,
-        -2.7465e-01, -4.4341e-04,  2.1188e-01,  1.4916e-03, -8.1795e-10,
-         6.9243e-02,  3.1586e-02, -4.2104e-01, -5.9834e-02,  6.8751e-01,
-         1.7023e-13, -6.6903e-08,  2.4061e-11,  5.3913e-01, -3.4749e-03,
-         6.4272e-11,  8.7096e-02, -1.6642e-01, -1.3174e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2500,  0.3381, -0.1829, -0.0488,  2.5935,  0.0612,  0.0000,  0.0000,
-         0.2149, -0.1404,  0.0693,  0.0000, -0.3939, -0.6024,  0.0000,  0.0000,
-         0.0000, -0.0813,  0.0000,  0.0000, -0.3882, -0.0158,  0.0000, -0.7720,
-         0.2710,  0.0000, -0.1107, -0.0875,  0.8328,  0.0000, -0.3507, -1.3209,
-         0.0000, -0.0984,  0.0000,  0.0000,  0.0000,  0.0000,  0.4183,  0.4397,
-         0.2183,  0.0000, -0.5078,  0.0000,  0.0000, -0.2747,  0.0000,  0.2119,
-         0.0000,  0.0000,  0.0692,  0.0316, -0.4210, -0.0598,  0.6875,  0.0000,
-         0.0000,  0.0000,  0.5391,  0.0000,  0.0000,  0.0871, -0.1664, -0.1317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2500,  0.3381, -0.1829, -0.0488,  2.5935,  0.0612,  0.0000,  0.0000,
-         0.2149, -0.1404,  0.0693,  0.0000, -0.3939, -0.6024,  0.0000,  0.0000,
-         0.0000, -0.0813,  0.0000,  0.0000, -0.3882, -0.0158,  0.0000, -0.7720,
-         0.2710,  0.0000, -0.1107, -0.0875,  0.8328,  0.0000, -0.3507, -1.3209,
-         0.0000, -0.0984,  0.0000,  0.0000,  0.0000,  0.0000,  0.4183,  0.4397,
-         0.2183,  0.0000, -0.5078,  0.0000,  0.0000, -0.2747,  0.0000,  0.2119,
-         0.0000,  0.0000,  0.0692,  0.0316, -0.4210, -0.0598,  0.6875,  0.0000,
-         0.0000,  0.0000,  0.5391,  0.0000,  0.0000,  0.0871, -0.1664, -0.1317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.4532e-01,  3.0665e-01, -1.3084e-01, -7.1319e-02,  2.5867e+00,
-         5.4689e-02,  1.1348e-11, -8.9609e-05,  2.1722e-01, -1.8086e-01,
-        -1.6117e-02,  2.0401e-03, -3.7858e-01, -6.1900e-01, -2.0351e-08,
-        -1.5126e-05, -1.4878e-06, -1.1687e-02, -3.4412e-09,  4.1375e-05,
-        -4.2142e-01,  5.0488e-02, -7.3079e-09, -7.9645e-01,  2.1084e-01,
-         5.9312e-05, -1.1058e-01, -1.1382e-01,  8.6231e-01,  1.0628e-08,
-        -2.9587e-01, -1.3244e+00, -2.7206e-12, -8.1925e-02,  1.3504e-06,
-         7.2862e-05,  4.5046e-08,  0.0000e+00,  4.5492e-01,  4.5656e-01,
-         1.5089e-01, -6.0167e-06, -4.6471e-01,  1.1086e-08,  2.7116e-06,
-        -2.0358e-01, -3.8260e-04,  2.3888e-01,  1.2871e-03, -7.0577e-10,
-         4.9486e-02, -5.5265e-03, -4.2209e-01,  1.4293e-01,  7.0266e-01,
-         1.4688e-13, -5.7728e-08,  2.0761e-11,  5.1543e-01, -2.9983e-03,
-         5.5457e-11,  9.3411e-02, -1.0523e-01, -1.6950e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2453,  0.3066, -0.1308, -0.0713,  2.5867,  0.0547,  0.0000,  0.0000,
-         0.2172, -0.1809, -0.0161,  0.0000, -0.3786, -0.6190,  0.0000,  0.0000,
-         0.0000, -0.0117,  0.0000,  0.0000, -0.4214,  0.0505,  0.0000, -0.7964,
-         0.2108,  0.0000, -0.1106, -0.1138,  0.8623,  0.0000, -0.2959, -1.3244,
-         0.0000, -0.0819,  0.0000,  0.0000,  0.0000,  0.0000,  0.4549,  0.4566,
-         0.1509,  0.0000, -0.4647,  0.0000,  0.0000, -0.2036,  0.0000,  0.2389,
-         0.0000,  0.0000,  0.0495, -0.0055, -0.4221,  0.1429,  0.7027,  0.0000,
-         0.0000,  0.0000,  0.5154,  0.0000,  0.0000,  0.0934, -0.1052, -0.1695],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2453,  0.3066, -0.1308, -0.0713,  2.5867,  0.0547,  0.0000,  0.0000,
-         0.2172, -0.1809, -0.0161,  0.0000, -0.3786, -0.6190,  0.0000,  0.0000,
-         0.0000, -0.0117,  0.0000,  0.0000, -0.4214,  0.0505,  0.0000, -0.7964,
-         0.2108,  0.0000, -0.1106, -0.1138,  0.8623,  0.0000, -0.2959, -1.3244,
-         0.0000, -0.0819,  0.0000,  0.0000,  0.0000,  0.0000,  0.4549,  0.4566,
-         0.1509,  0.0000, -0.4647,  0.0000,  0.0000, -0.2036,  0.0000,  0.2389,
-         0.0000,  0.0000,  0.0495, -0.0055, -0.4221,  0.1429,  0.7027,  0.0000,
-         0.0000,  0.0000,  0.5154,  0.0000,  0.0000,  0.0934, -0.1052, -0.1695],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.3053e-01,  2.7186e-01, -7.1899e-02, -1.1368e-01,  2.5834e+00,
-         8.6707e-02,  9.7952e-12, -7.7344e-05,  2.3148e-01, -2.2406e-01,
-        -7.3512e-02,  1.7609e-03, -3.6414e-01, -6.3759e-01, -1.7566e-08,
-        -1.3055e-05, -1.2842e-06, -2.2711e-02, -2.9702e-09,  3.5712e-05,
-        -4.7398e-01,  8.8438e-02, -6.3077e-09, -8.3519e-01,  1.5209e-01,
-         5.1194e-05, -1.0593e-01, -1.6086e-01,  8.9108e-01,  9.1737e-09,
-        -2.6953e-01, -1.3296e+00, -2.3483e-12, -7.8036e-02,  1.1656e-06,
-         6.2890e-05,  3.8881e-08,  0.0000e+00,  4.6880e-01,  4.5075e-01,
-         5.4863e-02, -5.1932e-06, -4.2603e-01,  9.5689e-09,  2.3405e-06,
-        -1.1213e-01, -3.3023e-04,  2.5037e-01,  1.1109e-03, -6.0918e-10,
-         1.7240e-02, -7.7790e-02, -3.7834e-01,  2.1872e-01,  6.9589e-01,
-         1.2678e-13, -4.9827e-08,  1.7919e-11,  5.1353e-01, -2.5880e-03,
-         4.7867e-11,  9.9574e-02, -3.6641e-02, -1.8956e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2305,  0.2719, -0.0719, -0.1137,  2.5834,  0.0867,  0.0000,  0.0000,
-         0.2315, -0.2241, -0.0735,  0.0000, -0.3641, -0.6376,  0.0000,  0.0000,
-         0.0000, -0.0227,  0.0000,  0.0000, -0.4740,  0.0884,  0.0000, -0.8352,
-         0.1521,  0.0000, -0.1059, -0.1609,  0.8911,  0.0000, -0.2695, -1.3296,
-         0.0000, -0.0780,  0.0000,  0.0000,  0.0000,  0.0000,  0.4688,  0.4508,
-         0.0549,  0.0000, -0.4260,  0.0000,  0.0000, -0.1121,  0.0000,  0.2504,
-         0.0000,  0.0000,  0.0172, -0.0778, -0.3783,  0.2187,  0.6959,  0.0000,
-         0.0000,  0.0000,  0.5135,  0.0000,  0.0000,  0.0996, -0.0366, -0.1896],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2305,  0.2719, -0.0719, -0.1137,  2.5834,  0.0867,  0.0000,  0.0000,
-         0.2315, -0.2241, -0.0735,  0.0000, -0.3641, -0.6376,  0.0000,  0.0000,
-         0.0000, -0.0227,  0.0000,  0.0000, -0.4740,  0.0884,  0.0000, -0.8352,
-         0.1521,  0.0000, -0.1059, -0.1609,  0.8911,  0.0000, -0.2695, -1.3296,
-         0.0000, -0.0780,  0.0000,  0.0000,  0.0000,  0.0000,  0.4688,  0.4508,
-         0.0549,  0.0000, -0.4260,  0.0000,  0.0000, -0.1121,  0.0000,  0.2504,
-         0.0000,  0.0000,  0.0172, -0.0778, -0.3783,  0.2187,  0.6959,  0.0000,
-         0.0000,  0.0000,  0.5135,  0.0000,  0.0000,  0.0996, -0.0366, -0.1896],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.3780e-01,  2.1329e-01,  1.9273e-02, -1.3147e-01,  2.5747e+00,
-         1.3677e-01,  8.4573e-12, -6.6780e-05,  2.6610e-01, -2.4826e-01,
-        -4.2789e-02,  1.5204e-03, -3.3520e-01, -6.1838e-01, -1.5166e-08,
-        -1.1272e-05, -1.1088e-06, -5.3885e-02, -2.5645e-09,  3.0834e-05,
-        -4.7221e-01,  1.0094e-01, -5.4461e-09, -8.4031e-01,  8.0027e-02,
-         4.4201e-05, -9.4712e-02, -1.8224e-01,  8.9433e-01,  7.9206e-09,
-        -2.9764e-01, -1.3322e+00, -2.0275e-12, -8.2277e-02,  1.0064e-06,
-         5.4300e-05,  3.3570e-08,  0.0000e+00,  4.9957e-01,  4.6972e-01,
-        -8.2387e-03, -4.4838e-06, -4.0756e-01,  8.2619e-09,  2.0208e-06,
-        -5.2432e-02, -2.8513e-04,  2.3684e-01,  9.5916e-04, -5.2597e-10,
-         3.2168e-02, -9.4077e-02, -3.3381e-01,  2.1378e-01,  6.9074e-01,
-         1.0946e-13, -4.3021e-08,  1.5472e-11,  5.2364e-01, -2.2345e-03,
-         4.1329e-11,  1.1983e-01,  3.7000e-02, -1.6981e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2378,  0.2133,  0.0193, -0.1315,  2.5747,  0.1368,  0.0000,  0.0000,
-         0.2661, -0.2483, -0.0428,  0.0000, -0.3352, -0.6184,  0.0000,  0.0000,
-         0.0000, -0.0539,  0.0000,  0.0000, -0.4722,  0.1009,  0.0000, -0.8403,
-         0.0800,  0.0000, -0.0947, -0.1822,  0.8943,  0.0000, -0.2976, -1.3322,
-         0.0000, -0.0823,  0.0000,  0.0000,  0.0000,  0.0000,  0.4996,  0.4697,
-        -0.0082,  0.0000, -0.4076,  0.0000,  0.0000, -0.0524,  0.0000,  0.2368,
-         0.0000,  0.0000,  0.0322, -0.0941, -0.3338,  0.2138,  0.6907,  0.0000,
-         0.0000,  0.0000,  0.5236,  0.0000,  0.0000,  0.1198,  0.0370, -0.1698],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2378,  0.2133,  0.0193, -0.1315,  2.5747,  0.1368,  0.0000,  0.0000,
-         0.2661, -0.2483, -0.0428,  0.0000, -0.3352, -0.6184,  0.0000,  0.0000,
-         0.0000, -0.0539,  0.0000,  0.0000, -0.4722,  0.1009,  0.0000, -0.8403,
-         0.0800,  0.0000, -0.0947, -0.1822,  0.8943,  0.0000, -0.2976, -1.3322,
-         0.0000, -0.0823,  0.0000,  0.0000,  0.0000,  0.0000,  0.4996,  0.4697,
-        -0.0082,  0.0000, -0.4076,  0.0000,  0.0000, -0.0524,  0.0000,  0.2368,
-         0.0000,  0.0000,  0.0322, -0.0941, -0.3338,  0.2138,  0.6907,  0.0000,
-         0.0000,  0.0000,  0.5236,  0.0000,  0.0000,  0.1198,  0.0370, -0.1698],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.4799e-01,  1.9348e-01,  7.2727e-02, -1.3215e-01,  2.5653e+00,
-         1.7513e-01,  7.3045e-12, -5.7677e-05,  2.9346e-01, -2.6757e-01,
-         5.5927e-02,  1.3131e-03, -3.2097e-01, -5.9302e-01, -1.3099e-08,
-        -9.7356e-06, -9.5765e-07, -4.7899e-02, -2.2150e-09,  2.6631e-05,
-        -4.2798e-01,  1.0752e-01, -4.7038e-09, -8.2876e-01,  2.9705e-02,
-         3.8176e-05, -8.1453e-02, -1.7743e-01,  8.8498e-01,  6.8410e-09,
-        -3.7272e-01, -1.3324e+00, -1.7512e-12, -9.7805e-02,  8.6919e-07,
-         4.6898e-05,  2.8994e-08,  0.0000e+00,  4.8838e-01,  4.7618e-01,
-        -5.6818e-03, -3.8726e-06, -3.9282e-01,  7.1357e-09,  1.7453e-06,
-        -4.8389e-02, -2.4626e-04,  2.4342e-01,  8.2842e-04, -4.5427e-10,
-         7.9226e-02, -5.3006e-02, -3.3095e-01,  1.4087e-01,  6.7124e-01,
-         9.4542e-14, -3.7157e-08,  1.3363e-11,  5.2888e-01, -1.9299e-03,
-         3.5695e-11,  1.1663e-01,  8.2148e-02, -1.1581e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2480,  0.1935,  0.0727, -0.1321,  2.5653,  0.1751,  0.0000,  0.0000,
-         0.2935, -0.2676,  0.0559,  0.0000, -0.3210, -0.5930,  0.0000,  0.0000,
-         0.0000, -0.0479,  0.0000,  0.0000, -0.4280,  0.1075,  0.0000, -0.8288,
-         0.0297,  0.0000, -0.0815, -0.1774,  0.8850,  0.0000, -0.3727, -1.3324,
-         0.0000, -0.0978,  0.0000,  0.0000,  0.0000,  0.0000,  0.4884,  0.4762,
-        -0.0057,  0.0000, -0.3928,  0.0000,  0.0000, -0.0484,  0.0000,  0.2434,
-         0.0000,  0.0000,  0.0792, -0.0530, -0.3310,  0.1409,  0.6712,  0.0000,
-         0.0000,  0.0000,  0.5289,  0.0000,  0.0000,  0.1166,  0.0821, -0.1158],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2480,  0.1935,  0.0727, -0.1321,  2.5653,  0.1751,  0.0000,  0.0000,
-         0.2935, -0.2676,  0.0559,  0.0000, -0.3210, -0.5930,  0.0000,  0.0000,
-         0.0000, -0.0479,  0.0000,  0.0000, -0.4280,  0.1075,  0.0000, -0.8288,
-         0.0297,  0.0000, -0.0815, -0.1774,  0.8850,  0.0000, -0.3727, -1.3324,
-         0.0000, -0.0978,  0.0000,  0.0000,  0.0000,  0.0000,  0.4884,  0.4762,
-        -0.0057,  0.0000, -0.3928,  0.0000,  0.0000, -0.0484,  0.0000,  0.2434,
-         0.0000,  0.0000,  0.0792, -0.0530, -0.3310,  0.1409,  0.6712,  0.0000,
-         0.0000,  0.0000,  0.5289,  0.0000,  0.0000,  0.1166,  0.0821, -0.1158],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.5398e-01,  1.8674e-01,  1.0285e-01, -1.1424e-01,  2.5552e+00,
-         1.5971e-01,  6.3109e-12, -4.9831e-05,  3.3553e-01, -3.2314e-01,
-         1.7727e-01,  1.1345e-03, -3.0748e-01, -5.6069e-01, -1.1317e-08,
-        -8.4113e-06, -8.2738e-07, -5.1191e-03, -1.9137e-09,  2.3009e-05,
-        -3.4238e-01,  1.3296e-01, -4.0639e-09, -7.9836e-01,  3.7046e-02,
-         3.2983e-05, -1.1039e-01, -1.3809e-01,  8.7258e-01,  5.9104e-09,
-        -4.6906e-01, -1.3287e+00, -1.5130e-12, -1.0270e-01,  7.5096e-07,
-         4.0519e-05,  2.5050e-08,  0.0000e+00,  4.6951e-01,  4.7802e-01,
-         3.4243e-02, -3.3459e-06, -3.9607e-01,  6.1651e-09,  1.5079e-06,
-        -6.8945e-02, -2.1276e-04,  2.3163e-01,  7.1573e-04, -3.9248e-10,
-         1.3234e-01,  7.2952e-02, -3.3770e-01,  2.7725e-02,  6.4730e-01,
-         8.1682e-14, -3.2102e-08,  1.1545e-11,  5.4423e-01, -1.6674e-03,
-         3.0840e-11,  8.0261e-02,  1.4813e-01, -8.1414e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2540,  0.1867,  0.1029, -0.1142,  2.5552,  0.1597,  0.0000,  0.0000,
-         0.3355, -0.3231,  0.1773,  0.0000, -0.3075, -0.5607,  0.0000,  0.0000,
-         0.0000, -0.0051,  0.0000,  0.0000, -0.3424,  0.1330,  0.0000, -0.7984,
-         0.0370,  0.0000, -0.1104, -0.1381,  0.8726,  0.0000, -0.4691, -1.3287,
-         0.0000, -0.1027,  0.0000,  0.0000,  0.0000,  0.0000,  0.4695,  0.4780,
-         0.0342,  0.0000, -0.3961,  0.0000,  0.0000, -0.0689,  0.0000,  0.2316,
-         0.0000,  0.0000,  0.1323,  0.0730, -0.3377,  0.0277,  0.6473,  0.0000,
-         0.0000,  0.0000,  0.5442,  0.0000,  0.0000,  0.0803,  0.1481, -0.0814],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2540,  0.1867,  0.1029, -0.1142,  2.5552,  0.1597,  0.0000,  0.0000,
-         0.3355, -0.3231,  0.1773,  0.0000, -0.3075, -0.5607,  0.0000,  0.0000,
-         0.0000, -0.0051,  0.0000,  0.0000, -0.3424,  0.1330,  0.0000, -0.7984,
-         0.0370,  0.0000, -0.1104, -0.1381,  0.8726,  0.0000, -0.4691, -1.3287,
-         0.0000, -0.1027,  0.0000,  0.0000,  0.0000,  0.0000,  0.4695,  0.4780,
-         0.0342,  0.0000, -0.3961,  0.0000,  0.0000, -0.0689,  0.0000,  0.2316,
-         0.0000,  0.0000,  0.1323,  0.0730, -0.3377,  0.0277,  0.6473,  0.0000,
-         0.0000,  0.0000,  0.5442,  0.0000,  0.0000,  0.0803,  0.1481, -0.0814],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.6107e-01,  1.7947e-01,  1.0251e-01, -8.1097e-02,  2.5502e+00,
-         1.6180e-01,  5.4542e-12, -4.3067e-05,  3.5964e-01, -3.7686e-01,
-         2.3703e-01,  9.8052e-04, -3.1621e-01, -4.9321e-01, -9.7810e-09,
-        -7.2696e-06, -7.1507e-07,  3.1041e-02, -1.6539e-09,  1.9886e-05,
-        -2.0964e-01,  1.6771e-01, -3.5123e-09, -7.3877e-01,  7.8275e-02,
-         2.8506e-05, -1.6951e-01, -6.0264e-02,  8.6436e-01,  5.1081e-09,
-        -4.9292e-01, -1.3182e+00, -1.3076e-12, -1.0996e-01,  6.4902e-07,
-         3.5019e-05,  2.1650e-08,  0.0000e+00,  4.0392e-01,  4.6579e-01,
-         5.1996e-02, -2.8917e-06, -4.2489e-01,  5.3282e-09,  1.3032e-06,
-        -1.6157e-01, -1.8388e-04,  1.9535e-01,  6.1858e-04, -3.3921e-10,
-         1.3066e-01,  2.2634e-01, -3.6806e-01, -1.1509e-01,  6.2738e-01,
-         7.0594e-14, -2.7745e-08,  9.9781e-12,  5.7131e-01, -1.4410e-03,
-         2.6654e-11, -2.5570e-02,  1.3253e-01, -4.7074e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2611,  0.1795,  0.1025, -0.0811,  2.5502,  0.1618,  0.0000,  0.0000,
-         0.3596, -0.3769,  0.2370,  0.0000, -0.3162, -0.4932,  0.0000,  0.0000,
-         0.0000,  0.0310,  0.0000,  0.0000, -0.2096,  0.1677,  0.0000, -0.7388,
-         0.0783,  0.0000, -0.1695, -0.0603,  0.8644,  0.0000, -0.4929, -1.3182,
-         0.0000, -0.1100,  0.0000,  0.0000,  0.0000,  0.0000,  0.4039,  0.4658,
-         0.0520,  0.0000, -0.4249,  0.0000,  0.0000, -0.1616,  0.0000,  0.1954,
-         0.0000,  0.0000,  0.1307,  0.2263, -0.3681, -0.1151,  0.6274,  0.0000,
-         0.0000,  0.0000,  0.5713,  0.0000,  0.0000, -0.0256,  0.1325, -0.0471],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2611,  0.1795,  0.1025, -0.0811,  2.5502,  0.1618,  0.0000,  0.0000,
-         0.3596, -0.3769,  0.2370,  0.0000, -0.3162, -0.4932,  0.0000,  0.0000,
-         0.0000,  0.0310,  0.0000,  0.0000, -0.2096,  0.1677,  0.0000, -0.7388,
-         0.0783,  0.0000, -0.1695, -0.0603,  0.8644,  0.0000, -0.4929, -1.3182,
-         0.0000, -0.1100,  0.0000,  0.0000,  0.0000,  0.0000,  0.4039,  0.4658,
-         0.0520,  0.0000, -0.4249,  0.0000,  0.0000, -0.1616,  0.0000,  0.1954,
-         0.0000,  0.0000,  0.1307,  0.2263, -0.3681, -0.1151,  0.6274,  0.0000,
-         0.0000,  0.0000,  0.5713,  0.0000,  0.0000, -0.0256,  0.1325, -0.0471],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.2958e-01,  1.7168e-01,  1.5807e-01, -4.9534e-02,  2.5436e+00,
-         2.0280e-01,  4.7154e-12, -3.7234e-05,  3.1880e-01, -3.4245e-01,
-         1.8493e-01,  8.4771e-04, -2.7036e-01, -4.5175e-01, -8.4561e-09,
-        -6.2849e-06, -6.1822e-07,  1.0559e-01, -1.4299e-09,  1.7192e-05,
-        -1.6724e-01,  1.9071e-01, -3.0366e-09, -7.2094e-01,  5.4915e-02,
-         2.4645e-05, -1.4277e-01, -4.0272e-02,  8.4157e-01,  4.4162e-09,
-        -4.2830e-01, -1.3113e+00, -1.1305e-12, -1.1241e-01,  5.6111e-07,
-         3.0275e-05,  1.8717e-08,  0.0000e+00,  3.0168e-01,  4.2977e-01,
-         1.9050e-02, -2.5000e-06, -3.9633e-01,  4.6065e-09,  1.1267e-06,
-        -1.6378e-01, -1.5897e-04,  1.8566e-01,  5.3479e-04, -2.9326e-10,
-         1.3747e-01,  2.4789e-01, -3.7385e-01, -6.7709e-02,  5.9899e-01,
-         6.1032e-14, -2.3987e-08,  8.6265e-12,  6.0711e-01, -1.2459e-03,
-         2.3043e-11, -4.4203e-02,  6.8054e-02, -2.9286e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2296,  0.0000,  0.1581, -0.0495,  2.5436,  0.2028,  0.0000,  0.0000,
-         0.3188, -0.3425,  0.1849,  0.0000, -0.2704, -0.4518,  0.0000,  0.0000,
-         0.0000,  0.1056,  0.0000,  0.0000, -0.1672,  0.1907,  0.0000, -0.7209,
-         0.0549,  0.0000, -0.1428, -0.0403,  0.8416,  0.0000, -0.4283, -1.3113,
-         0.0000, -0.1124,  0.0000,  0.0000,  0.0000,  0.0000,  0.3017,  0.4298,
-         0.0191,  0.0000, -0.3963,  0.0000,  0.0000, -0.1638,  0.0000,  0.1857,
-         0.0000,  0.0000,  0.1375,  0.2479, -0.3739, -0.0677,  0.5990,  0.0000,
-         0.0000,  0.0000,  0.6071,  0.0000,  0.0000, -0.0442,  0.0681, -0.0293],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2296,  0.0000,  0.1581, -0.0495,  2.5436,  0.2028,  0.0000,  0.0000,
-         0.3188, -0.3425,  0.1849,  0.0000, -0.2704, -0.4518,  0.0000,  0.0000,
-         0.0000,  0.1056,  0.0000,  0.0000, -0.1672,  0.1907,  0.0000, -0.7209,
-         0.0549,  0.0000, -0.1428, -0.0403,  0.8416,  0.0000, -0.4283, -1.3113,
-         0.0000, -0.1124,  0.0000,  0.0000,  0.0000,  0.0000,  0.3017,  0.4298,
-         0.0191,  0.0000, -0.3963,  0.0000,  0.0000, -0.1638,  0.0000,  0.1857,
-         0.0000,  0.0000,  0.1375,  0.2479, -0.3739, -0.0677,  0.5990,  0.0000,
-         0.0000,  0.0000,  0.6071,  0.0000,  0.0000, -0.0442,  0.0681, -0.0293],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.8218e-01, -6.7371e-03,  2.5157e-01,  2.5279e-02,  2.5382e+00,
-         2.7428e-01,  4.0781e-12, -3.2201e-05,  2.4645e-01, -3.0264e-01,
-         3.0325e-02,  7.3313e-04, -2.0910e-01, -4.0136e-01, -7.3132e-09,
-        -5.4354e-06, -5.3466e-07,  2.4469e-01, -1.2366e-09,  1.4868e-05,
-        -1.2418e-01,  2.1882e-01, -2.6261e-09, -6.8648e-01,  3.0166e-02,
-         2.1314e-05, -1.0788e-01, -4.2918e-03,  8.1985e-01,  3.8193e-09,
-        -3.4597e-01, -1.3038e+00, -9.7767e-13, -1.0764e-01,  4.8527e-07,
-         2.6183e-05,  1.6187e-08,  0.0000e+00,  1.8625e-01,  3.9266e-01,
-        -3.0853e-02, -2.1621e-06, -3.7359e-01,  3.9839e-09,  9.7441e-07,
-        -9.9827e-02, -1.3749e-04,  1.7511e-01,  4.6251e-04, -2.5362e-10,
-         1.6434e-01,  2.2562e-01, -4.0487e-01,  1.0930e-01,  5.8738e-01,
-         5.2783e-14, -2.0745e-08,  7.4605e-12,  6.4430e-01, -1.0775e-03,
-         1.9929e-11, -3.3465e-02, -4.2269e-03, -6.1928e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1822,  0.0000,  0.2516,  0.0253,  2.5382,  0.2743,  0.0000,  0.0000,
-         0.2465, -0.3026,  0.0303,  0.0000, -0.2091, -0.4014,  0.0000,  0.0000,
-         0.0000,  0.2447,  0.0000,  0.0000, -0.1242,  0.2188,  0.0000, -0.6865,
-         0.0302,  0.0000, -0.1079, -0.0043,  0.8198,  0.0000, -0.3460, -1.3038,
-         0.0000, -0.1076,  0.0000,  0.0000,  0.0000,  0.0000,  0.1863,  0.3927,
-        -0.0309,  0.0000, -0.3736,  0.0000,  0.0000, -0.0998,  0.0000,  0.1751,
-         0.0000,  0.0000,  0.1643,  0.2256, -0.4049,  0.1093,  0.5874,  0.0000,
-         0.0000,  0.0000,  0.6443,  0.0000,  0.0000, -0.0335, -0.0042, -0.0619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1822,  0.0000,  0.2516,  0.0253,  2.5382,  0.2743,  0.0000,  0.0000,
-         0.2465, -0.3026,  0.0303,  0.0000, -0.2091, -0.4014,  0.0000,  0.0000,
-         0.0000,  0.2447,  0.0000,  0.0000, -0.1242,  0.2188,  0.0000, -0.6865,
-         0.0302,  0.0000, -0.1079, -0.0043,  0.8198,  0.0000, -0.3460, -1.3038,
-         0.0000, -0.1076,  0.0000,  0.0000,  0.0000,  0.0000,  0.1863,  0.3927,
-        -0.0309,  0.0000, -0.3736,  0.0000,  0.0000, -0.0998,  0.0000,  0.1751,
-         0.0000,  0.0000,  0.1643,  0.2256, -0.4049,  0.1093,  0.5874,  0.0000,
-         0.0000,  0.0000,  0.6443,  0.0000,  0.0000, -0.0335, -0.0042, -0.0619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.6913e-01, -5.8285e-03,  2.5380e-01, -6.8480e-03,  2.5342e+00,
-         3.3634e-01,  3.5281e-12, -2.7858e-05,  1.6857e-01, -2.3391e-01,
-        -5.0993e-02,  6.3425e-04, -2.3246e-01, -4.0379e-01, -6.3268e-09,
-        -4.7023e-06, -4.6255e-07,  2.0262e-01, -1.0698e-09,  1.2863e-05,
-        -1.6629e-01,  1.9073e-01, -2.2719e-09, -6.7686e-01, -1.4812e-02,
-         1.8439e-05, -8.9147e-02, -7.6753e-02,  7.7041e-01,  3.3042e-09,
-        -1.8817e-01, -1.2990e+00, -8.4581e-13, -1.1002e-01,  4.1982e-07,
-         2.2652e-05,  1.4004e-08,  0.0000e+00,  6.6865e-02,  3.8757e-01,
-        -1.0797e-01, -1.8705e-06, -3.6869e-01,  3.4466e-09,  8.4299e-07,
-        -5.4075e-02, -1.1894e-04,  2.2282e-01,  4.0013e-04, -2.1942e-10,
-         1.9317e-01,  1.6503e-01, -4.0102e-01,  3.7400e-03,  5.6429e-01,
-         4.5664e-14, -1.7947e-08,  6.4543e-12,  6.8280e-01, -9.3214e-04,
-         1.7241e-11, -1.0368e-01, -5.5675e-02, -1.1798e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1691,  0.0000,  0.2538, -0.0068,  2.5342,  0.3363,  0.0000,  0.0000,
-         0.1686, -0.2339, -0.0510,  0.0000, -0.2325, -0.4038,  0.0000,  0.0000,
-         0.0000,  0.2026,  0.0000,  0.0000, -0.1663,  0.1907,  0.0000, -0.6769,
-        -0.0148,  0.0000, -0.0891, -0.0768,  0.7704,  0.0000, -0.1882, -1.2990,
-         0.0000, -0.1100,  0.0000,  0.0000,  0.0000,  0.0000,  0.0669,  0.3876,
-        -0.1080,  0.0000, -0.3687,  0.0000,  0.0000, -0.0541,  0.0000,  0.2228,
-         0.0000,  0.0000,  0.1932,  0.1650, -0.4010,  0.0037,  0.5643,  0.0000,
-         0.0000,  0.0000,  0.6828,  0.0000,  0.0000, -0.1037, -0.0557, -0.1180],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1691,  0.0000,  0.2538, -0.0068,  2.5342,  0.3363,  0.0000,  0.0000,
-         0.1686, -0.2339, -0.0510,  0.0000, -0.2325, -0.4038,  0.0000,  0.0000,
-         0.0000,  0.2026,  0.0000,  0.0000, -0.1663,  0.1907,  0.0000, -0.6769,
-        -0.0148,  0.0000, -0.0891, -0.0768,  0.7704,  0.0000, -0.1882, -1.2990,
-         0.0000, -0.1100,  0.0000,  0.0000,  0.0000,  0.0000,  0.0669,  0.3876,
-        -0.1080,  0.0000, -0.3687,  0.0000,  0.0000, -0.0541,  0.0000,  0.2228,
-         0.0000,  0.0000,  0.1932,  0.1650, -0.4010,  0.0037,  0.5643,  0.0000,
-         0.0000,  0.0000,  0.6828,  0.0000,  0.0000, -0.1037, -0.0557, -0.1180],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.4038e-01, -5.0441e-03,  3.0227e-01, -3.6768e-02,  2.5311e+00,
-         3.6597e-01,  3.0533e-12, -2.4109e-05,  8.3160e-02, -1.4387e-01,
-        -7.1530e-02,  5.4889e-04, -2.2538e-01, -4.4711e-01, -5.4754e-09,
-        -4.0695e-06, -4.0030e-07,  1.0543e-01, -9.2586e-10,  1.1132e-05,
-        -2.0940e-01,  1.6196e-01, -1.9662e-09, -6.8606e-01, -7.8738e-02,
-         1.5958e-05, -1.9378e-02, -1.4695e-01,  7.2591e-01,  2.8595e-09,
-        -6.1849e-02, -1.2948e+00, -7.3198e-13, -1.2903e-01,  3.6332e-07,
-         1.9603e-05,  1.2120e-08,  0.0000e+00,  1.8895e-02,  3.7956e-01,
-        -1.9649e-01, -1.6188e-06, -3.4145e-01,  2.9827e-09,  7.2955e-07,
-        -6.5218e-02, -1.0294e-04,  2.7597e-01,  3.4628e-04, -1.8989e-10,
-         2.4721e-01,  5.9790e-02, -3.6061e-01, -1.0040e-01,  5.3084e-01,
-         3.9519e-14, -1.5532e-08,  5.5857e-12,  7.1089e-01, -8.0670e-04,
-         1.4921e-11, -1.1839e-01, -6.9219e-02, -1.2872e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1404,  0.0000,  0.3023, -0.0368,  2.5311,  0.3660,  0.0000,  0.0000,
-         0.0832, -0.1439, -0.0715,  0.0000, -0.2254, -0.4471,  0.0000,  0.0000,
-         0.0000,  0.1054,  0.0000,  0.0000, -0.2094,  0.1620,  0.0000, -0.6861,
-        -0.0787,  0.0000, -0.0194, -0.1470,  0.7259,  0.0000, -0.0618, -1.2948,
-         0.0000, -0.1290,  0.0000,  0.0000,  0.0000,  0.0000,  0.0189,  0.3796,
-        -0.1965,  0.0000, -0.3414,  0.0000,  0.0000, -0.0652,  0.0000,  0.2760,
-         0.0000,  0.0000,  0.2472,  0.0598, -0.3606, -0.1004,  0.5308,  0.0000,
-         0.0000,  0.0000,  0.7109,  0.0000,  0.0000, -0.1184, -0.0692, -0.1287],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1404,  0.0000,  0.3023, -0.0368,  2.5311,  0.3660,  0.0000,  0.0000,
-         0.0832, -0.1439, -0.0715,  0.0000, -0.2254, -0.4471,  0.0000,  0.0000,
-         0.0000,  0.1054,  0.0000,  0.0000, -0.2094,  0.1620,  0.0000, -0.6861,
-        -0.0787,  0.0000, -0.0194, -0.1470,  0.7259,  0.0000, -0.0618, -1.2948,
-         0.0000, -0.1290,  0.0000,  0.0000,  0.0000,  0.0000,  0.0189,  0.3796,
-        -0.1965,  0.0000, -0.3414,  0.0000,  0.0000, -0.0652,  0.0000,  0.2760,
-         0.0000,  0.0000,  0.2472,  0.0598, -0.3606, -0.1004,  0.5308,  0.0000,
-         0.0000,  0.0000,  0.7109,  0.0000,  0.0000, -0.1184, -0.0692, -0.1287],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.4393e-01, -4.3668e-03,  2.3422e-01, -3.1640e-02,  2.5276e+00,
-         3.1684e-01,  2.6433e-12, -2.0872e-05,  9.3822e-02, -1.0375e-01,
-        -1.0227e-01,  4.7519e-04, -2.2226e-01, -4.4907e-01, -4.7401e-09,
-        -3.5230e-06, -3.4655e-07, -6.3109e-02, -8.0153e-10,  9.6371e-06,
-        -1.7706e-01,  1.3193e-01, -1.7022e-09, -6.7220e-01, -1.1599e-01,
-         1.3815e-05,  1.1126e-02, -1.6739e-01,  6.1117e-01,  2.4756e-09,
-        -3.8017e-02, -1.2894e+00, -6.3369e-13, -1.2322e-01,  3.1453e-07,
-         1.6971e-05,  1.0492e-08,  0.0000e+00,  8.2770e-02,  4.3232e-01,
-        -2.4077e-01, -1.4014e-06, -3.4612e-01,  2.5822e-09,  6.3158e-07,
-        -1.2113e-01, -8.9115e-05,  3.1910e-01,  2.9978e-04, -1.6439e-10,
-         2.7838e-01,  1.2167e-01, -3.1262e-01, -2.9200e-01,  5.2683e-01,
-         3.4212e-14, -1.3446e-08,  4.8357e-12,  7.5005e-01, -6.9837e-04,
-         1.2917e-11, -1.4300e-01, -7.6767e-02, -1.4968e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1439,  0.0000,  0.2342, -0.0316,  2.5276,  0.3168,  0.0000,  0.0000,
-         0.0938, -0.1037, -0.1023,  0.0000, -0.2223, -0.4491,  0.0000,  0.0000,
-         0.0000, -0.0631,  0.0000,  0.0000, -0.1771,  0.1319,  0.0000, -0.6722,
-        -0.1160,  0.0000,  0.0111, -0.1674,  0.6112,  0.0000, -0.0380, -1.2894,
-         0.0000, -0.1232,  0.0000,  0.0000,  0.0000,  0.0000,  0.0828,  0.4323,
-        -0.2408,  0.0000, -0.3461,  0.0000,  0.0000, -0.1211,  0.0000,  0.3191,
-         0.0000,  0.0000,  0.2784,  0.1217, -0.3126, -0.2920,  0.5268,  0.0000,
-         0.0000,  0.0000,  0.7501,  0.0000,  0.0000, -0.1430, -0.0768, -0.1497],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1439,  0.0000,  0.2342, -0.0316,  2.5276,  0.3168,  0.0000,  0.0000,
-         0.0938, -0.1037, -0.1023,  0.0000, -0.2223, -0.4491,  0.0000,  0.0000,
-         0.0000, -0.0631,  0.0000,  0.0000, -0.1771,  0.1319,  0.0000, -0.6722,
-        -0.1160,  0.0000,  0.0111, -0.1674,  0.6112,  0.0000, -0.0380, -1.2894,
-         0.0000, -0.1232,  0.0000,  0.0000,  0.0000,  0.0000,  0.0828,  0.4323,
-        -0.2408,  0.0000, -0.3461,  0.0000,  0.0000, -0.1211,  0.0000,  0.3191,
-         0.0000,  0.0000,  0.2784,  0.1217, -0.3126, -0.2920,  0.5268,  0.0000,
-         0.0000,  0.0000,  0.7501,  0.0000,  0.0000, -0.1430, -0.0768, -0.1497],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.1126e-01, -3.7817e-03,  1.4912e-01,  1.8317e-02,  2.5252e+00,
-         2.1958e-01,  2.2891e-12, -1.8075e-05,  1.1120e-01, -7.3115e-02,
-        -8.0167e-02,  4.1152e-04, -2.1007e-01, -4.3419e-01, -4.1050e-09,
-        -3.0510e-06, -3.0011e-07, -1.0469e-01, -6.9414e-10,  8.3459e-06,
-        -1.3957e-01,  1.3047e-01, -1.4741e-09, -6.5851e-01, -8.8649e-02,
-         1.1964e-05,  4.2691e-02, -1.4890e-01,  5.0792e-01,  2.1439e-09,
-        -3.7300e-02, -1.2826e+00, -5.4879e-13, -1.1072e-01,  2.7239e-07,
-         1.4697e-05,  9.0864e-09,  0.0000e+00,  2.1582e-01,  4.9127e-01,
-        -2.1746e-01, -1.2136e-06, -3.5114e-01,  2.2362e-09,  5.4696e-07,
-        -2.7280e-01, -7.7175e-05,  3.1763e-01,  2.5962e-04, -1.4236e-10,
-         3.1748e-01,  2.0836e-01, -3.1745e-01, -2.6720e-01,  5.4042e-01,
-         2.9628e-14, -1.1644e-08,  4.1878e-12,  7.9562e-01, -6.0480e-04,
-         1.1186e-11, -1.5003e-01, -7.8529e-02, -1.8572e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1113,  0.0000,  0.1491,  0.0183,  2.5252,  0.2196,  0.0000,  0.0000,
-         0.1112, -0.0731, -0.0802,  0.0000, -0.2101, -0.4342,  0.0000,  0.0000,
-         0.0000, -0.1047,  0.0000,  0.0000, -0.1396,  0.1305,  0.0000, -0.6585,
-        -0.0886,  0.0000,  0.0427, -0.1489,  0.5079,  0.0000, -0.0373, -1.2826,
-         0.0000, -0.1107,  0.0000,  0.0000,  0.0000,  0.0000,  0.2158,  0.4913,
-        -0.2175,  0.0000, -0.3511,  0.0000,  0.0000, -0.2728,  0.0000,  0.3176,
-         0.0000,  0.0000,  0.3175,  0.2084, -0.3175, -0.2672,  0.5404,  0.0000,
-         0.0000,  0.0000,  0.7956,  0.0000,  0.0000, -0.1500, -0.0785, -0.1857],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1113,  0.0000,  0.1491,  0.0183,  2.5252,  0.2196,  0.0000,  0.0000,
-         0.1112, -0.0731, -0.0802,  0.0000, -0.2101, -0.4342,  0.0000,  0.0000,
-         0.0000, -0.1047,  0.0000,  0.0000, -0.1396,  0.1305,  0.0000, -0.6585,
-        -0.0886,  0.0000,  0.0427, -0.1489,  0.5079,  0.0000, -0.0373, -1.2826,
-         0.0000, -0.1107,  0.0000,  0.0000,  0.0000,  0.0000,  0.2158,  0.4913,
-        -0.2175,  0.0000, -0.3511,  0.0000,  0.0000, -0.2728,  0.0000,  0.3176,
-         0.0000,  0.0000,  0.3175,  0.2084, -0.3175, -0.2672,  0.5404,  0.0000,
-         0.0000,  0.0000,  0.7956,  0.0000,  0.0000, -0.1500, -0.0785, -0.1857],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-7.4348e-02, -3.2761e-03,  8.9666e-02,  6.3554e-02,  2.5239e+00,
-         9.6962e-02,  1.9831e-12, -1.5659e-05,  1.3405e-01, -7.3333e-02,
-        -2.7665e-02,  3.5651e-04, -1.8784e-01, -3.9717e-01, -3.5563e-09,
-        -2.6431e-06, -2.5999e-07, -1.0759e-01, -6.0135e-10,  7.2302e-06,
-        -1.2357e-01,  1.4251e-01, -1.2770e-09, -6.6153e-01, -3.9639e-02,
-         1.0365e-05,  6.6478e-02, -1.1190e-01,  4.1908e-01,  1.8573e-09,
-        -8.0067e-02, -1.2772e+00, -4.7542e-13, -1.1524e-01,  2.3598e-07,
-         1.2732e-05,  7.8717e-09,  0.0000e+00,  3.4971e-01,  5.4227e-01,
-        -1.6532e-01, -1.0514e-06, -3.6391e-01,  1.9373e-09,  4.7384e-07,
-        -3.3040e-01, -6.6858e-05,  2.8158e-01,  2.2491e-04, -1.2333e-10,
-         3.3232e-01,  2.6588e-01, -2.6032e-01, -2.2444e-01,  5.5369e-01,
-         2.5667e-14, -1.0088e-08,  3.6279e-12,  8.1652e-01, -5.2395e-04,
-         9.6910e-12, -1.5454e-01, -3.2972e-02, -2.1313e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0743,  0.0000,  0.0897,  0.0636,  2.5239,  0.0970,  0.0000,  0.0000,
-         0.1341, -0.0733, -0.0277,  0.0000, -0.1878, -0.3972,  0.0000,  0.0000,
-         0.0000, -0.1076,  0.0000,  0.0000, -0.1236,  0.1425,  0.0000, -0.6615,
-        -0.0396,  0.0000,  0.0665, -0.1119,  0.4191,  0.0000, -0.0801, -1.2772,
-         0.0000, -0.1152,  0.0000,  0.0000,  0.0000,  0.0000,  0.3497,  0.5423,
-        -0.1653,  0.0000, -0.3639,  0.0000,  0.0000, -0.3304,  0.0000,  0.2816,
-         0.0000,  0.0000,  0.3323,  0.2659, -0.2603, -0.2244,  0.5537,  0.0000,
-         0.0000,  0.0000,  0.8165,  0.0000,  0.0000, -0.1545, -0.0330, -0.2131],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0743,  0.0000,  0.0897,  0.0636,  2.5239,  0.0970,  0.0000,  0.0000,
-         0.1341, -0.0733, -0.0277,  0.0000, -0.1878, -0.3972,  0.0000,  0.0000,
-         0.0000, -0.1076,  0.0000,  0.0000, -0.1236,  0.1425,  0.0000, -0.6615,
-        -0.0396,  0.0000,  0.0665, -0.1119,  0.4191,  0.0000, -0.0801, -1.2772,
-         0.0000, -0.1152,  0.0000,  0.0000,  0.0000,  0.0000,  0.3497,  0.5423,
-        -0.1653,  0.0000, -0.3639,  0.0000,  0.0000, -0.3304,  0.0000,  0.2816,
-         0.0000,  0.0000,  0.3323,  0.2659, -0.2603, -0.2244,  0.5537,  0.0000,
-         0.0000,  0.0000,  0.8165,  0.0000,  0.0000, -0.1545, -0.0330, -0.2131],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-4.1126e-02, -2.8392e-03,  6.3950e-02,  1.3227e-01,  2.5214e+00,
-        -2.4282e-02,  1.7186e-12, -1.3570e-05,  1.4902e-01, -6.1021e-02,
-         4.0137e-02,  3.0896e-04, -1.3395e-01, -3.7437e-01, -3.0819e-09,
-        -2.2906e-06, -2.2532e-07, -5.1089e-02, -5.2114e-10,  6.2658e-06,
-        -1.4539e-01,  1.8223e-01, -1.1067e-09, -6.5880e-01,  2.5406e-02,
-         8.9821e-06,  8.3458e-02, -6.7024e-02,  3.4539e-01,  1.6095e-09,
-        -1.2759e-01, -1.2689e+00, -4.1201e-13, -1.0228e-01,  2.0450e-07,
-         1.1034e-05,  6.8217e-09,  0.0000e+00,  4.3872e-01,  5.8801e-01,
-        -9.6998e-02, -9.1116e-07, -3.4837e-01,  1.6789e-09,  4.1064e-07,
-        -3.6726e-01, -5.7940e-05,  2.2368e-01,  1.9491e-04, -1.0688e-10,
-         3.4796e-01,  3.0719e-01, -2.0768e-01, -1.0475e-01,  5.5695e-01,
-         2.2244e-14, -8.7422e-09,  3.1440e-12,  8.2090e-01, -4.5406e-04,
-         8.3984e-12, -1.3146e-01,  8.6258e-03, -2.4756e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0411,  0.0000,  0.0639,  0.1323,  2.5214, -0.0243,  0.0000,  0.0000,
-         0.1490, -0.0610,  0.0401,  0.0000, -0.1339, -0.3744,  0.0000,  0.0000,
-         0.0000, -0.0511,  0.0000,  0.0000, -0.1454,  0.1822,  0.0000, -0.6588,
-         0.0254,  0.0000,  0.0835, -0.0670,  0.3454,  0.0000, -0.1276, -1.2689,
-         0.0000, -0.1023,  0.0000,  0.0000,  0.0000,  0.0000,  0.4387,  0.5880,
-        -0.0970,  0.0000, -0.3484,  0.0000,  0.0000, -0.3673,  0.0000,  0.2237,
-         0.0000,  0.0000,  0.3480,  0.3072, -0.2077, -0.1047,  0.5569,  0.0000,
-         0.0000,  0.0000,  0.8209,  0.0000,  0.0000, -0.1315,  0.0086, -0.2476],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0411,  0.0000,  0.0639,  0.1323,  2.5214, -0.0243,  0.0000,  0.0000,
-         0.1490, -0.0610,  0.0401,  0.0000, -0.1339, -0.3744,  0.0000,  0.0000,
-         0.0000, -0.0511,  0.0000,  0.0000, -0.1454,  0.1822,  0.0000, -0.6588,
-         0.0254,  0.0000,  0.0835, -0.0670,  0.3454,  0.0000, -0.1276, -1.2689,
-         0.0000, -0.1023,  0.0000,  0.0000,  0.0000,  0.0000,  0.4387,  0.5880,
-        -0.0970,  0.0000, -0.3484,  0.0000,  0.0000, -0.3673,  0.0000,  0.2237,
-         0.0000,  0.0000,  0.3480,  0.3072, -0.2077, -0.1047,  0.5569,  0.0000,
-         0.0000,  0.0000,  0.8209,  0.0000,  0.0000, -0.1315,  0.0086, -0.2476],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-9.2761e-03, -2.4613e-03,  6.6971e-02,  1.7295e-01,  2.5158e+00,
-        -1.1136e-01,  1.4899e-12, -1.1764e-05,  1.6696e-01, -2.7409e-02,
-         1.1097e-01,  2.6784e-04, -6.8867e-02, -3.3438e-01, -2.6718e-09,
-        -1.9858e-06, -1.9533e-07,  3.6382e-02, -4.5179e-10,  5.4320e-06,
-        -1.4220e-01,  1.9196e-01, -9.5943e-10, -6.2879e-01,  4.9738e-02,
-         7.7868e-06,  1.0693e-01, -2.1481e-02,  2.7979e-01,  1.3954e-09,
-        -1.3504e-01, -1.2557e+00, -3.5718e-13, -9.2548e-02,  1.7729e-07,
-         9.5658e-06,  5.9139e-09,  0.0000e+00,  4.6606e-01,  5.9581e-01,
-        -4.5729e-02, -7.8990e-07, -3.2793e-01,  1.4555e-09,  3.5599e-07,
-        -3.7052e-01, -5.0230e-05,  1.4700e-01,  1.6897e-04, -9.2658e-11,
-         3.6108e-01,  3.0518e-01, -1.5397e-01,  1.2649e-02,  5.3693e-01,
-         1.9284e-14, -7.5788e-09,  2.7256e-12,  8.2298e-01, -3.9364e-04,
-         7.2808e-12, -7.6354e-02,  4.0111e-02, -2.4973e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0093,  0.0000,  0.0670,  0.1729,  2.5158, -0.1114,  0.0000,  0.0000,
-         0.1670, -0.0274,  0.1110,  0.0000, -0.0689, -0.3344,  0.0000,  0.0000,
-         0.0000,  0.0364,  0.0000,  0.0000, -0.1422,  0.1920,  0.0000, -0.6288,
-         0.0497,  0.0000,  0.1069, -0.0215,  0.2798,  0.0000, -0.1350, -1.2557,
-         0.0000, -0.0925,  0.0000,  0.0000,  0.0000,  0.0000,  0.4661,  0.5958,
-        -0.0457,  0.0000, -0.3279,  0.0000,  0.0000, -0.3705,  0.0000,  0.1470,
-         0.0000,  0.0000,  0.3611,  0.3052, -0.1540,  0.0126,  0.5369,  0.0000,
-         0.0000,  0.0000,  0.8230,  0.0000,  0.0000, -0.0764,  0.0401, -0.2497],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0093,  0.0000,  0.0670,  0.1729,  2.5158, -0.1114,  0.0000,  0.0000,
-         0.1670, -0.0274,  0.1110,  0.0000, -0.0689, -0.3344,  0.0000,  0.0000,
-         0.0000,  0.0364,  0.0000,  0.0000, -0.1422,  0.1920,  0.0000, -0.6288,
-         0.0497,  0.0000,  0.1069, -0.0215,  0.2798,  0.0000, -0.1350, -1.2557,
-         0.0000, -0.0925,  0.0000,  0.0000,  0.0000,  0.0000,  0.4661,  0.5958,
-        -0.0457,  0.0000, -0.3279,  0.0000,  0.0000, -0.3705,  0.0000,  0.1470,
-         0.0000,  0.0000,  0.3611,  0.3052, -0.1540,  0.0126,  0.5369,  0.0000,
-         0.0000,  0.0000,  0.8230,  0.0000,  0.0000, -0.0764,  0.0401, -0.2497],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.0663e-02, -2.1345e-03,  8.7744e-02,  2.1527e-01,  2.5096e+00,
-        -1.5477e-01,  1.2921e-12, -1.0202e-05,  1.4808e-01,  2.9020e-02,
-         1.4692e-01,  2.3228e-04, -1.9470e-02, -3.1970e-01, -2.3171e-09,
-        -1.7221e-06, -1.6940e-07,  9.0340e-02, -3.9180e-10,  4.7108e-06,
-        -1.8600e-01,  1.7457e-01, -8.3205e-10, -6.1681e-01,  6.9971e-02,
-         6.7529e-06,  1.6285e-01, -6.9178e-03,  2.0175e-01,  1.2101e-09,
-        -1.4846e-01, -1.2470e+00, -3.0976e-13, -7.2128e-02,  1.5375e-07,
-         8.2957e-06,  5.1287e-09,  0.0000e+00,  4.4875e-01,  5.9658e-01,
-        -1.9150e-02, -6.8503e-07, -3.0062e-01,  1.2622e-09,  3.0873e-07,
-        -3.8446e-01, -4.3561e-05,  8.5129e-02,  1.4654e-04, -8.0356e-11,
-         3.7102e-01,  2.5898e-01, -1.4415e-01,  1.7834e-01,  5.0748e-01,
-         1.6723e-14, -6.5726e-09,  2.3637e-12,  8.2629e-01, -3.4138e-04,
-         6.3141e-12, -2.4695e-02,  6.2514e-02, -2.4754e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0207,  0.0000,  0.0877,  0.2153,  2.5096, -0.1548,  0.0000,  0.0000,
-         0.1481,  0.0290,  0.1469,  0.0000, -0.0195, -0.3197,  0.0000,  0.0000,
-         0.0000,  0.0903,  0.0000,  0.0000, -0.1860,  0.1746,  0.0000, -0.6168,
-         0.0700,  0.0000,  0.1629, -0.0069,  0.2018,  0.0000, -0.1485, -1.2470,
-         0.0000, -0.0721,  0.0000,  0.0000,  0.0000,  0.0000,  0.4488,  0.5966,
-        -0.0192,  0.0000, -0.3006,  0.0000,  0.0000, -0.3845,  0.0000,  0.0851,
-         0.0000,  0.0000,  0.3710,  0.2590, -0.1441,  0.1783,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8263,  0.0000,  0.0000, -0.0247,  0.0625, -0.2475],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0207,  0.0000,  0.0877,  0.2153,  2.5096, -0.1548,  0.0000,  0.0000,
-         0.1481,  0.0290,  0.1469,  0.0000, -0.0195, -0.3197,  0.0000,  0.0000,
-         0.0000,  0.0903,  0.0000,  0.0000, -0.1860,  0.1746,  0.0000, -0.6168,
-         0.0700,  0.0000,  0.1629, -0.0069,  0.2018,  0.0000, -0.1485, -1.2470,
-         0.0000, -0.0721,  0.0000,  0.0000,  0.0000,  0.0000,  0.4488,  0.5966,
-        -0.0192,  0.0000, -0.3006,  0.0000,  0.0000, -0.3845,  0.0000,  0.0851,
-         0.0000,  0.0000,  0.3710,  0.2590, -0.1441,  0.1783,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8263,  0.0000,  0.0000, -0.0247,  0.0625, -0.2475],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4064e-02, -1.8518e-03,  6.3442e-02,  2.4105e-01,  2.5040e+00,
-        -1.4085e-01,  1.1209e-12, -8.8510e-06,  9.5242e-02,  7.0510e-02,
-         1.4029e-01,  2.0151e-04, -1.3558e-02, -3.0734e-01, -2.0101e-09,
-        -1.4940e-06, -1.4696e-07,  9.3523e-03, -3.3990e-10,  4.0868e-06,
-        -2.4902e-01,  8.7881e-02, -7.2183e-10, -6.3134e-01,  4.5070e-02,
-         5.8584e-06,  2.3216e-01, -4.8558e-02,  1.0426e-01,  1.0498e-09,
-        -1.9902e-01, -1.2431e+00, -2.6873e-13, -6.1991e-02,  1.3338e-07,
-         7.1969e-06,  4.4494e-09,  0.0000e+00,  4.1948e-01,  5.7870e-01,
-        -8.8463e-02, -5.9429e-07, -2.6605e-01,  1.0950e-09,  2.6783e-07,
-        -3.3064e-01, -3.7790e-05,  8.3697e-02,  1.2713e-04, -6.9712e-11,
-         3.5882e-01,  1.8360e-01, -1.4795e-01,  1.2895e-01, -2.5546e-02,
-         1.4508e-14, -5.7020e-09,  2.0506e-12,  8.6247e-01, -2.9616e-04,
-         5.4777e-12, -3.2043e-03,  2.9837e-02, -2.2936e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0341,  0.0000,  0.0634,  0.2411,  2.5040, -0.1408,  0.0000,  0.0000,
-         0.0952,  0.0705,  0.1403,  0.0000, -0.0136, -0.3073,  0.0000,  0.0000,
-         0.0000,  0.0094,  0.0000,  0.0000, -0.2490,  0.0879,  0.0000, -0.6313,
-         0.0451,  0.0000,  0.2322, -0.0486,  0.1043,  0.0000, -0.1990, -1.2431,
-         0.0000, -0.0620,  0.0000,  0.0000,  0.0000,  0.0000,  0.4195,  0.5787,
-        -0.0885,  0.0000, -0.2660,  0.0000,  0.0000, -0.3306,  0.0000,  0.0837,
-         0.0000,  0.0000,  0.3588,  0.1836, -0.1480,  0.1290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8625,  0.0000,  0.0000, -0.0032,  0.0298, -0.2294],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0341,  0.0000,  0.0634,  0.2411,  2.5040, -0.1408,  0.0000,  0.0000,
-         0.0952,  0.0705,  0.1403,  0.0000, -0.0136, -0.3073,  0.0000,  0.0000,
-         0.0000,  0.0094,  0.0000,  0.0000, -0.2490,  0.0879,  0.0000, -0.6313,
-         0.0451,  0.0000,  0.2322, -0.0486,  0.1043,  0.0000, -0.1990, -1.2431,
-         0.0000, -0.0620,  0.0000,  0.0000,  0.0000,  0.0000,  0.4195,  0.5787,
-        -0.0885,  0.0000, -0.2660,  0.0000,  0.0000, -0.3306,  0.0000,  0.0837,
-         0.0000,  0.0000,  0.3588,  0.1836, -0.1480,  0.1290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8625,  0.0000,  0.0000, -0.0032,  0.0298, -0.2294],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7907e-02, -1.6071e-03,  1.5507e-02,  2.3172e-01,  2.5004e+00,
-        -8.6071e-02,  9.7279e-13, -7.6813e-06,  4.0515e-02,  6.5664e-02,
-         1.1892e-01,  1.7488e-04, -3.9326e-02, -2.5496e-01, -1.7445e-09,
-        -1.2966e-06, -1.2754e-07, -1.5872e-01, -2.9498e-10,  3.5467e-06,
-        -2.8134e-01, -3.5351e-02, -6.2644e-10, -6.2779e-01, -2.8667e-02,
-         5.0842e-06,  2.7250e-01, -8.4589e-02, -8.3811e-03,  9.1106e-10,
-        -2.4681e-01, -1.2432e+00, -2.3321e-13, -8.6443e-02,  1.1576e-07,
-         6.2458e-06,  3.8614e-09,  0.0000e+00,  3.3571e-01,  5.5891e-01,
-        -1.5798e-01, -5.1575e-07, -2.5925e-01,  9.5031e-10,  2.3244e-07,
-        -2.3895e-01, -3.2796e-05,  7.8146e-02,  1.1033e-04, -6.0499e-11,
-         3.3324e-01,  1.1871e-01, -1.2166e-01, -1.0950e-01, -2.2170e-02,
-         1.2591e-14, -4.9484e-09,  1.7796e-12,  8.8481e-01, -2.5702e-04,
-         4.7538e-12, -4.0737e-02,  2.1610e-03, -1.9896e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.7907e-02,  0.0000e+00,  1.5507e-02,  2.3172e-01,  2.5004e+00,
-        -8.6071e-02,  0.0000e+00,  0.0000e+00,  4.0515e-02,  6.5664e-02,
-         1.1892e-01,  0.0000e+00, -3.9326e-02, -2.5496e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.5872e-01,  0.0000e+00,  0.0000e+00,
-        -2.8134e-01, -3.5351e-02,  0.0000e+00, -6.2779e-01, -2.8667e-02,
-         0.0000e+00,  2.7250e-01, -8.4589e-02, -8.3811e-03,  0.0000e+00,
-        -2.4681e-01, -1.2432e+00,  0.0000e+00, -8.6443e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.3571e-01,  5.5891e-01,
-        -1.5798e-01,  0.0000e+00, -2.5925e-01,  0.0000e+00,  0.0000e+00,
-        -2.3895e-01,  0.0000e+00,  7.8146e-02,  0.0000e+00,  0.0000e+00,
-         3.3324e-01,  1.1871e-01, -1.2166e-01, -1.0950e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.8481e-01,  0.0000e+00,
-         0.0000e+00, -4.0737e-02,  2.1610e-03, -1.9896e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.7907e-02,  0.0000e+00,  1.5507e-02,  2.3172e-01,  2.5004e+00,
-        -8.6071e-02,  0.0000e+00,  0.0000e+00,  4.0515e-02,  6.5664e-02,
-         1.1892e-01,  0.0000e+00, -3.9326e-02, -2.5496e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.5872e-01,  0.0000e+00,  0.0000e+00,
-        -2.8134e-01, -3.5351e-02,  0.0000e+00, -6.2779e-01, -2.8667e-02,
-         0.0000e+00,  2.7250e-01, -8.4589e-02, -8.3811e-03,  0.0000e+00,
-        -2.4681e-01, -1.2432e+00,  0.0000e+00, -8.6443e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.3571e-01,  5.5891e-01,
-        -1.5798e-01,  0.0000e+00, -2.5925e-01,  0.0000e+00,  0.0000e+00,
-        -2.3895e-01,  0.0000e+00,  7.8146e-02,  0.0000e+00,  0.0000e+00,
-         3.3324e-01,  1.1871e-01, -1.2166e-01, -1.0950e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.8481e-01,  0.0000e+00,
-         0.0000e+00, -4.0737e-02,  2.1610e-03, -1.9896e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.0087e-02, -1.3952e-03,  5.3678e-03,  2.3880e-01,  2.4956e+00,
-        -4.2729e-02,  8.4453e-13, -6.6686e-06, -2.4407e-02,  5.9434e-02,
-         9.4507e-02,  1.5182e-04, -3.4909e-02, -1.9836e-01, -1.5145e-09,
-        -1.1256e-06, -1.1072e-07, -2.6628e-01, -2.5609e-10,  3.0791e-06,
-        -3.1666e-01, -8.9567e-02, -5.4385e-10, -6.3856e-01, -7.1245e-02,
-         4.4139e-06,  3.1759e-01, -8.4985e-02, -7.6956e-02,  7.9095e-10,
-        -3.0108e-01, -1.2420e+00, -2.0247e-13, -9.3832e-02,  1.0049e-07,
-         5.4223e-06,  3.3523e-09,  0.0000e+00,  3.0745e-01,  5.7250e-01,
-        -1.8593e-01, -4.4775e-07, -2.5160e-01,  8.2502e-10,  2.0179e-07,
-        -1.8650e-01, -2.8472e-05,  9.5961e-02,  9.5781e-05, -5.2523e-11,
-         3.3479e-01,  4.7218e-02, -1.4074e-01, -2.1112e-01, -1.9247e-02,
-         1.0931e-14, -4.2960e-09,  1.5450e-12,  8.9971e-01, -2.2313e-04,
-         4.1271e-12, -9.7421e-02, -2.3452e-03, -1.8672e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.0087e-02,  0.0000e+00,  5.3678e-03,  2.3880e-01,  2.4956e+00,
-        -4.2729e-02,  0.0000e+00,  0.0000e+00, -2.4407e-02,  5.9434e-02,
-         9.4507e-02,  0.0000e+00, -3.4909e-02, -1.9836e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6628e-01,  0.0000e+00,  0.0000e+00,
-        -3.1666e-01, -8.9567e-02,  0.0000e+00, -6.3856e-01, -7.1245e-02,
-         0.0000e+00,  3.1759e-01, -8.4985e-02, -7.6956e-02,  0.0000e+00,
-        -3.0108e-01, -1.2420e+00,  0.0000e+00, -9.3832e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0745e-01,  5.7250e-01,
-        -1.8593e-01,  0.0000e+00, -2.5160e-01,  0.0000e+00,  0.0000e+00,
-        -1.8650e-01,  0.0000e+00,  9.5961e-02,  0.0000e+00,  0.0000e+00,
-         3.3479e-01,  4.7218e-02, -1.4074e-01, -2.1112e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.9971e-01,  0.0000e+00,
-         0.0000e+00, -9.7421e-02, -2.3452e-03, -1.8672e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.0087e-02,  0.0000e+00,  5.3678e-03,  2.3880e-01,  2.4956e+00,
-        -4.2729e-02,  0.0000e+00,  0.0000e+00, -2.4407e-02,  5.9434e-02,
-         9.4507e-02,  0.0000e+00, -3.4909e-02, -1.9836e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6628e-01,  0.0000e+00,  0.0000e+00,
-        -3.1666e-01, -8.9567e-02,  0.0000e+00, -6.3856e-01, -7.1245e-02,
-         0.0000e+00,  3.1759e-01, -8.4985e-02, -7.6956e-02,  0.0000e+00,
-        -3.0108e-01, -1.2420e+00,  0.0000e+00, -9.3832e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0745e-01,  5.7250e-01,
-        -1.8593e-01,  0.0000e+00, -2.5160e-01,  0.0000e+00,  0.0000e+00,
-        -1.8650e-01,  0.0000e+00,  9.5961e-02,  0.0000e+00,  0.0000e+00,
-         3.3479e-01,  4.7218e-02, -1.4074e-01, -2.1112e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.9971e-01,  0.0000e+00,
-         0.0000e+00, -9.7421e-02, -2.3452e-03, -1.8672e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6361e-01, -1.2117e-03,  4.4280e-03,  2.6936e-01,  2.4925e+00,
-        -4.0941e-02,  7.3345e-13, -5.7915e-06, -1.2273e-01,  9.7050e-02,
-         7.5004e-02,  1.3185e-04, -5.2414e-02, -2.0891e-01, -1.3153e-09,
-        -9.7757e-07, -9.6159e-08, -1.8499e-01, -2.2241e-10,  2.6741e-06,
-        -3.5158e-01,  1.7810e-02, -4.7232e-10, -7.0881e-01,  6.8123e-03,
-         3.8334e-06,  3.4221e-01, -1.6662e-02, -7.0062e-02,  6.8692e-10,
-        -3.3963e-01, -1.2434e+00, -1.7584e-13, -5.0052e-02,  8.7277e-08,
-         4.7091e-06,  2.9114e-09,  0.0000e+00,  3.0802e-01,  6.0847e-01,
-        -7.5855e-02, -3.8886e-07, -2.8159e-01,  7.1651e-10,  1.7525e-07,
-        -2.3231e-01, -2.4727e-05,  7.3876e-02,  8.3183e-05, -4.5615e-11,
-         3.8240e-01,  2.3471e-02, -2.4097e-01,  1.4952e-02, -1.6716e-02,
-         9.4931e-15, -3.7310e-09,  1.3418e-12,  8.8883e-01, -1.9378e-04,
-         3.5842e-12, -1.5778e-01,  1.8561e-03, -2.1321e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.6361e-01,  0.0000e+00,  4.4280e-03,  2.6936e-01,  2.4925e+00,
-        -4.0941e-02,  0.0000e+00,  0.0000e+00, -1.2273e-01,  9.7050e-02,
-         7.5004e-02,  0.0000e+00, -5.2414e-02, -2.0891e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.8499e-01,  0.0000e+00,  0.0000e+00,
-        -3.5158e-01,  1.7810e-02,  0.0000e+00, -7.0881e-01,  6.8123e-03,
-         0.0000e+00,  3.4221e-01, -1.6662e-02, -7.0062e-02,  0.0000e+00,
-        -3.3963e-01, -1.2434e+00,  0.0000e+00, -5.0052e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0802e-01,  6.0847e-01,
-        -7.5855e-02,  0.0000e+00, -2.8159e-01,  0.0000e+00,  0.0000e+00,
-        -2.3231e-01,  0.0000e+00,  7.3876e-02,  0.0000e+00,  0.0000e+00,
-         3.8240e-01,  2.3471e-02, -2.4097e-01,  1.4952e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.8883e-01,  0.0000e+00,
-         0.0000e+00, -1.5778e-01,  1.8561e-03, -2.1321e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.6361e-01,  0.0000e+00,  4.4280e-03,  2.6936e-01,  2.4925e+00,
-        -4.0941e-02,  0.0000e+00,  0.0000e+00, -1.2273e-01,  9.7050e-02,
-         7.5004e-02,  0.0000e+00, -5.2414e-02, -2.0891e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.8499e-01,  0.0000e+00,  0.0000e+00,
-        -3.5158e-01,  1.7810e-02,  0.0000e+00, -7.0881e-01,  6.8123e-03,
-         0.0000e+00,  3.4221e-01, -1.6662e-02, -7.0062e-02,  0.0000e+00,
-        -3.3963e-01, -1.2434e+00,  0.0000e+00, -5.0052e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0802e-01,  6.0847e-01,
-        -7.5855e-02,  0.0000e+00, -2.8159e-01,  0.0000e+00,  0.0000e+00,
-        -2.3231e-01,  0.0000e+00,  7.3876e-02,  0.0000e+00,  0.0000e+00,
-         3.8240e-01,  2.3471e-02, -2.4097e-01,  1.4952e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.8883e-01,  0.0000e+00,
-         0.0000e+00, -1.5778e-01,  1.8561e-03, -2.1321e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4381e-01, -1.0527e-03, -2.8344e-02,  3.0493e-01,  2.4914e+00,
-        -7.1854e-02,  6.3721e-13, -5.0315e-06, -1.9979e-01,  1.6784e-01,
-         7.2574e-02,  1.1455e-04, -1.0102e-01, -1.8049e-01, -1.1427e-09,
-        -8.4930e-07, -8.3542e-08, -8.1820e-02, -1.9323e-10,  2.3232e-06,
-        -3.6335e-01,  1.4870e-01, -4.1034e-10, -7.4969e-01,  1.2738e-01,
-         3.3304e-06,  3.6242e-01,  7.8344e-02, -5.1309e-02,  5.9678e-10,
-        -4.0609e-01, -1.2366e+00, -1.5276e-13,  1.3363e-02,  7.5825e-08,
-         4.0912e-06,  2.5293e-09,  0.0000e+00,  2.9638e-01,  6.3573e-01,
-         6.0021e-02, -3.3784e-07, -3.2784e-01,  6.2249e-10,  1.5226e-07,
-        -2.7983e-01, -2.1483e-05,  2.5683e-02,  7.2268e-05, -3.9629e-11,
-         4.2440e-01,  5.5823e-02, -3.3879e-01,  1.5079e-01, -1.4522e-02,
-         8.2475e-15, -3.2414e-09,  1.1657e-12,  8.7495e-01, -1.6836e-04,
-         3.1139e-12, -2.6324e-01,  6.6831e-03, -2.2706e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2438,  0.0000, -0.0283,  0.3049,  2.4914, -0.0719,  0.0000,  0.0000,
-        -0.1998,  0.1678,  0.0726,  0.0000, -0.1010, -0.1805,  0.0000,  0.0000,
-         0.0000, -0.0818,  0.0000,  0.0000, -0.3633,  0.1487,  0.0000, -0.7497,
-         0.1274,  0.0000,  0.3624,  0.0783, -0.0513,  0.0000, -0.4061, -1.2366,
-         0.0000,  0.0134,  0.0000,  0.0000,  0.0000,  0.0000,  0.2964,  0.6357,
-         0.0600,  0.0000, -0.3278,  0.0000,  0.0000, -0.2798,  0.0000,  0.0257,
-         0.0000,  0.0000,  0.4244,  0.0558, -0.3388,  0.1508,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8749,  0.0000,  0.0000, -0.2632,  0.0067, -0.2271],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2438,  0.0000, -0.0283,  0.3049,  2.4914, -0.0719,  0.0000,  0.0000,
-        -0.1998,  0.1678,  0.0726,  0.0000, -0.1010, -0.1805,  0.0000,  0.0000,
-         0.0000, -0.0818,  0.0000,  0.0000, -0.3633,  0.1487,  0.0000, -0.7497,
-         0.1274,  0.0000,  0.3624,  0.0783, -0.0513,  0.0000, -0.4061, -1.2366,
-         0.0000,  0.0134,  0.0000,  0.0000,  0.0000,  0.0000,  0.2964,  0.6357,
-         0.0600,  0.0000, -0.3278,  0.0000,  0.0000, -0.2798,  0.0000,  0.0257,
-         0.0000,  0.0000,  0.4244,  0.0558, -0.3388,  0.1508,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8749,  0.0000,  0.0000, -0.2632,  0.0067, -0.2271],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0691e-01, -9.1490e-04, -1.1710e-02,  3.3378e-01,  2.4903e+00,
-        -8.3662e-02,  5.5381e-13, -4.3729e-06, -2.4844e-01,  2.1124e-01,
-         1.5990e-02,  9.9559e-05, -8.3334e-02, -1.2504e-01, -9.9313e-10,
-        -7.3813e-07, -7.2607e-08,  1.6695e-02, -1.6793e-10,  2.0191e-06,
-        -3.4264e-01,  1.9348e-01, -3.5663e-10, -7.3740e-01,  1.8205e-01,
-         2.8944e-06,  3.6632e-01,  1.5813e-01, -9.3960e-03,  5.1867e-10,
-        -5.0606e-01, -1.2223e+00, -1.3277e-13,  6.5983e-02,  6.5900e-08,
-         3.5557e-06,  2.1983e-09,  0.0000e+00,  3.6673e-01,  6.8220e-01,
-         7.8350e-02, -2.9361e-07, -3.3107e-01,  5.4101e-10,  1.3233e-07,
-        -2.9119e-01, -1.8671e-05,  3.8104e-02,  6.2809e-05, -3.4442e-11,
-         4.5595e-01,  1.1895e-01, -4.4283e-01,  2.2395e-01, -1.2621e-02,
-         7.1679e-15, -2.8171e-09,  1.0131e-12,  8.4282e-01, -1.4632e-04,
-         2.7063e-12, -3.7391e-01, -2.9510e-02, -2.2758e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3069,  0.0000, -0.0117,  0.3338,  2.4903, -0.0837,  0.0000,  0.0000,
-        -0.2484,  0.2112,  0.0160,  0.0000, -0.0833, -0.1250,  0.0000,  0.0000,
-         0.0000,  0.0167,  0.0000,  0.0000, -0.3426,  0.1935,  0.0000, -0.7374,
-         0.1820,  0.0000,  0.3663,  0.1581, -0.0094,  0.0000, -0.5061, -1.2223,
-         0.0000,  0.0660,  0.0000,  0.0000,  0.0000,  0.0000,  0.3667,  0.6822,
-         0.0784,  0.0000, -0.3311,  0.0000,  0.0000, -0.2912,  0.0000,  0.0381,
-         0.0000,  0.0000,  0.4559,  0.1190, -0.4428,  0.2239,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8428,  0.0000,  0.0000, -0.3739, -0.0295, -0.2276],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3069,  0.0000, -0.0117,  0.3338,  2.4903, -0.0837,  0.0000,  0.0000,
-        -0.2484,  0.2112,  0.0160,  0.0000, -0.0833, -0.1250,  0.0000,  0.0000,
-         0.0000,  0.0167,  0.0000,  0.0000, -0.3426,  0.1935,  0.0000, -0.7374,
-         0.1820,  0.0000,  0.3663,  0.1581, -0.0094,  0.0000, -0.5061, -1.2223,
-         0.0000,  0.0660,  0.0000,  0.0000,  0.0000,  0.0000,  0.3667,  0.6822,
-         0.0784,  0.0000, -0.3311,  0.0000,  0.0000, -0.2912,  0.0000,  0.0381,
-         0.0000,  0.0000,  0.4559,  0.1190, -0.4428,  0.2239,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8428,  0.0000,  0.0000, -0.3739, -0.0295, -0.2276],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6450e-01, -7.9544e-04,  2.3147e-02,  3.4645e-01,  2.4844e+00,
-        -7.6422e-02,  4.8149e-13, -3.8019e-06, -2.5304e-01,  2.7517e-01,
-        -3.6314e-02,  8.6559e-05, -2.4694e-02, -8.3301e-02, -8.6345e-10,
-        -6.4175e-07, -6.3126e-08,  1.1796e-01, -1.4601e-10,  1.7555e-06,
-        -3.0096e-01,  1.8460e-01, -3.1006e-10, -6.5703e-01,  1.6038e-01,
-         2.5165e-06,  3.6726e-01,  2.3145e-01,  1.0745e-02,  4.5094e-10,
-        -6.5267e-01, -1.2056e+00, -1.1543e-13,  8.6846e-02,  5.7295e-08,
-         3.0914e-06,  1.9112e-09,  0.0000e+00,  4.1689e-01,  7.3506e-01,
-         5.2292e-02, -2.5527e-07, -3.1394e-01,  4.7037e-10,  1.1505e-07,
-        -2.6739e-01, -1.6233e-05,  6.3986e-02,  5.4607e-05, -2.9945e-11,
-         4.7691e-01,  1.7959e-01, -5.0076e-01,  2.1365e-01, -1.0973e-02,
-         6.2320e-15, -2.4493e-09,  8.8085e-13,  7.9530e-01, -1.2721e-04,
-         2.3529e-12, -4.6626e-01, -8.2155e-02, -2.2701e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3645,  0.0000,  0.0231,  0.3465,  2.4844, -0.0764,  0.0000,  0.0000,
-        -0.2530,  0.2752, -0.0363,  0.0000, -0.0247, -0.0833,  0.0000,  0.0000,
-         0.0000,  0.1180,  0.0000,  0.0000, -0.3010,  0.1846,  0.0000, -0.6570,
-         0.1604,  0.0000,  0.3673,  0.2315,  0.0107,  0.0000, -0.6527, -1.2056,
-         0.0000,  0.0868,  0.0000,  0.0000,  0.0000,  0.0000,  0.4169,  0.7351,
-         0.0523,  0.0000, -0.3139,  0.0000,  0.0000, -0.2674,  0.0000,  0.0640,
-         0.0000,  0.0000,  0.4769,  0.1796, -0.5008,  0.2137,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7953,  0.0000,  0.0000, -0.4663, -0.0822, -0.2270],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3645,  0.0000,  0.0231,  0.3465,  2.4844, -0.0764,  0.0000,  0.0000,
-        -0.2530,  0.2752, -0.0363,  0.0000, -0.0247, -0.0833,  0.0000,  0.0000,
-         0.0000,  0.1180,  0.0000,  0.0000, -0.3010,  0.1846,  0.0000, -0.6570,
-         0.1604,  0.0000,  0.3673,  0.2315,  0.0107,  0.0000, -0.6527, -1.2056,
-         0.0000,  0.0868,  0.0000,  0.0000,  0.0000,  0.0000,  0.4169,  0.7351,
-         0.0523,  0.0000, -0.3139,  0.0000,  0.0000, -0.2674,  0.0000,  0.0640,
-         0.0000,  0.0000,  0.4769,  0.1796, -0.5008,  0.2137,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7953,  0.0000,  0.0000, -0.4663, -0.0822, -0.2270],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7323e-01, -6.9182e-04,  1.8806e-02,  3.3616e-01,  2.4777e+00,
-        -4.9277e-02,  4.1877e-13, -3.3067e-06, -2.5814e-01,  3.6309e-01,
-        -1.2003e-02,  7.5284e-05, -1.8831e-02, -6.4794e-02, -7.5098e-10,
-        -5.5815e-07, -5.4903e-08,  1.8021e-01, -1.2699e-10,  1.5268e-06,
-        -2.5448e-01,  1.6603e-01, -2.6967e-10, -5.9410e-01,  1.3207e-01,
-         2.1887e-06,  3.7357e-01,  2.8971e-01,  1.7733e-02,  3.9220e-10,
-        -7.7564e-01, -1.1958e+00, -1.0040e-13,  1.0530e-01,  4.9832e-08,
-         2.6887e-06,  1.6623e-09,  0.0000e+00,  4.4953e-01,  7.5786e-01,
-         4.1089e-02, -2.2202e-07, -3.1670e-01,  4.0910e-10,  1.0006e-07,
-        -2.0051e-01, -1.4118e-05,  9.4722e-02,  4.7494e-05, -2.6044e-11,
-         4.9095e-01,  1.9697e-01, -5.0386e-01,  1.3463e-01, -9.5439e-03,
-         5.4202e-15, -2.1302e-09,  7.6611e-13,  7.3903e-01, -1.1064e-04,
-         2.0465e-12, -5.3985e-01, -1.0053e-01, -2.0777e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3732,  0.0000,  0.0188,  0.3362,  2.4777, -0.0493,  0.0000,  0.0000,
-        -0.2581,  0.3631, -0.0120,  0.0000, -0.0188, -0.0648,  0.0000,  0.0000,
-         0.0000,  0.1802,  0.0000,  0.0000, -0.2545,  0.1660,  0.0000, -0.5941,
-         0.1321,  0.0000,  0.3736,  0.2897,  0.0177,  0.0000, -0.7756, -1.1958,
-         0.0000,  0.1053,  0.0000,  0.0000,  0.0000,  0.0000,  0.4495,  0.7579,
-         0.0411,  0.0000, -0.3167,  0.0000,  0.0000, -0.2005,  0.0000,  0.0947,
-         0.0000,  0.0000,  0.4910,  0.1970, -0.5039,  0.1346,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7390,  0.0000,  0.0000, -0.5398, -0.1005, -0.2078],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3732,  0.0000,  0.0188,  0.3362,  2.4777, -0.0493,  0.0000,  0.0000,
-        -0.2581,  0.3631, -0.0120,  0.0000, -0.0188, -0.0648,  0.0000,  0.0000,
-         0.0000,  0.1802,  0.0000,  0.0000, -0.2545,  0.1660,  0.0000, -0.5941,
-         0.1321,  0.0000,  0.3736,  0.2897,  0.0177,  0.0000, -0.7756, -1.1958,
-         0.0000,  0.1053,  0.0000,  0.0000,  0.0000,  0.0000,  0.4495,  0.7579,
-         0.0411,  0.0000, -0.3167,  0.0000,  0.0000, -0.2005,  0.0000,  0.0947,
-         0.0000,  0.0000,  0.4910,  0.1970, -0.5039,  0.1346,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7390,  0.0000,  0.0000, -0.5398, -0.1005, -0.2078],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7560e-01, -6.0193e-04,  6.4413e-02,  3.1489e-01,  2.4705e+00,
-         1.2096e-02,  3.6436e-13, -2.8770e-06, -2.7910e-01,  4.3189e-01,
-         5.3175e-02,  6.5502e-05,  1.1170e-02, -8.9337e-02, -6.5340e-10,
-        -4.8563e-07, -4.7769e-08,  1.9462e-01, -1.1049e-10,  1.3284e-06,
-        -2.4614e-01,  1.0148e-01, -2.3463e-10, -5.6785e-01,  5.8958e-02,
-         1.9043e-06,  3.8619e-01,  2.9406e-01, -1.0002e-02,  3.4124e-10,
-        -8.4260e-01, -1.1884e+00, -8.7351e-14,  1.0299e-01,  4.3357e-08,
-         2.3394e-06,  1.4463e-09,  0.0000e+00,  4.8192e-01,  7.7157e-01,
-        -3.3535e-02, -1.9317e-07, -2.9051e-01,  3.5594e-10,  8.7060e-08,
-        -8.0814e-02, -1.2284e-05,  1.5022e-01,  4.1323e-05, -2.2660e-11,
-         4.9823e-01,  1.1855e-01, -4.8418e-01,  2.3343e-02, -8.3038e-03,
-         4.7159e-15, -1.8534e-09,  6.6657e-13,  6.8350e-01, -9.6266e-05,
-         1.7806e-12, -5.7469e-01, -1.0293e-01, -1.5016e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3756,  0.0000,  0.0644,  0.3149,  2.4705,  0.0121,  0.0000,  0.0000,
-        -0.2791,  0.4319,  0.0532,  0.0000,  0.0112, -0.0893,  0.0000,  0.0000,
-         0.0000,  0.1946,  0.0000,  0.0000, -0.2461,  0.1015,  0.0000, -0.5679,
-         0.0590,  0.0000,  0.3862,  0.2941, -0.0100,  0.0000, -0.8426, -1.1884,
-         0.0000,  0.1030,  0.0000,  0.0000,  0.0000,  0.0000,  0.4819,  0.7716,
-        -0.0335,  0.0000, -0.2905,  0.0000,  0.0000, -0.0808,  0.0000,  0.1502,
-         0.0000,  0.0000,  0.4982,  0.1185, -0.4842,  0.0233,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6835,  0.0000,  0.0000, -0.5747, -0.1029, -0.1502],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3756,  0.0000,  0.0644,  0.3149,  2.4705,  0.0121,  0.0000,  0.0000,
-        -0.2791,  0.4319,  0.0532,  0.0000,  0.0112, -0.0893,  0.0000,  0.0000,
-         0.0000,  0.1946,  0.0000,  0.0000, -0.2461,  0.1015,  0.0000, -0.5679,
-         0.0590,  0.0000,  0.3862,  0.2941, -0.0100,  0.0000, -0.8426, -1.1884,
-         0.0000,  0.1030,  0.0000,  0.0000,  0.0000,  0.0000,  0.4819,  0.7716,
-        -0.0335,  0.0000, -0.2905,  0.0000,  0.0000, -0.0808,  0.0000,  0.1502,
-         0.0000,  0.0000,  0.4982,  0.1185, -0.4842,  0.0233,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6835,  0.0000,  0.0000, -0.5747, -0.1029, -0.1502],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5162e-01, -5.2391e-04,  7.1525e-02,  3.3173e-01,  2.4652e+00,
-         3.3652e-02,  3.1713e-13, -2.5041e-06, -3.1245e-01,  4.3749e-01,
-         6.1354e-02,  5.7012e-05, -3.0386e-02, -1.5073e-01, -5.6871e-10,
-        -4.2269e-07, -4.1578e-08,  1.7622e-01, -9.6166e-11,  1.1562e-06,
-        -2.3195e-01,  5.4875e-02, -2.0422e-10, -5.7149e-01,  1.5336e-02,
-         1.6575e-06,  3.7587e-01,  2.8823e-01, -4.3335e-02,  2.9701e-10,
-        -8.7417e-01, -1.1849e+00, -7.6029e-14,  1.2475e-01,  3.7737e-08,
-         2.0362e-06,  1.2588e-09,  0.0000e+00,  4.6752e-01,  7.6607e-01,
-        -9.5252e-02, -1.6814e-07, -2.6439e-01,  3.0981e-10,  7.5776e-08,
-         4.6638e-02, -1.0692e-05,  1.8988e-01,  3.5967e-05, -1.9723e-11,
-         4.5973e-01,  2.6329e-02, -4.6245e-01, -1.2783e-02, -7.2275e-03,
-         4.1047e-15, -1.6132e-09,  5.8017e-13,  6.5549e-01, -8.3789e-05,
-         1.5498e-12, -5.8915e-01, -1.2263e-01, -5.0584e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3516,  0.0000,  0.0715,  0.3317,  2.4652,  0.0337,  0.0000,  0.0000,
-        -0.3124,  0.4375,  0.0614,  0.0000, -0.0304, -0.1507,  0.0000,  0.0000,
-         0.0000,  0.1762,  0.0000,  0.0000, -0.2319,  0.0549,  0.0000, -0.5715,
-         0.0153,  0.0000,  0.3759,  0.2882, -0.0433,  0.0000, -0.8742, -1.1849,
-         0.0000,  0.1247,  0.0000,  0.0000,  0.0000,  0.0000,  0.4675,  0.7661,
-        -0.0953,  0.0000, -0.2644,  0.0000,  0.0000,  0.0466,  0.0000,  0.1899,
-         0.0000,  0.0000,  0.4597,  0.0263, -0.4624, -0.0128,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6555,  0.0000,  0.0000, -0.5891, -0.1226, -0.0506],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3516,  0.0000,  0.0715,  0.3317,  2.4652,  0.0337,  0.0000,  0.0000,
-        -0.3124,  0.4375,  0.0614,  0.0000, -0.0304, -0.1507,  0.0000,  0.0000,
-         0.0000,  0.1762,  0.0000,  0.0000, -0.2319,  0.0549,  0.0000, -0.5715,
-         0.0153,  0.0000,  0.3759,  0.2882, -0.0433,  0.0000, -0.8742, -1.1849,
-         0.0000,  0.1247,  0.0000,  0.0000,  0.0000,  0.0000,  0.4675,  0.7661,
-        -0.0953,  0.0000, -0.2644,  0.0000,  0.0000,  0.0466,  0.0000,  0.1899,
-         0.0000,  0.0000,  0.4597,  0.0263, -0.4624, -0.0128,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6555,  0.0000,  0.0000, -0.5891, -0.1226, -0.0506],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2621e-01, -4.5618e-04,  4.4235e-02,  3.6739e-01,  2.4605e+00,
-         4.0816e-02,  2.7613e-13, -2.1804e-06, -3.3749e-01,  4.0230e-01,
-        -4.3735e-02,  4.9641e-05, -1.0444e-01, -1.7251e-01, -4.9519e-10,
-        -3.6804e-07, -3.6202e-08,  1.7472e-01, -8.3733e-11,  1.0068e-06,
-        -1.9741e-01,  4.6291e-02, -1.7782e-10, -5.8698e-01,  1.6987e-02,
-         1.4432e-06,  3.4957e-01,  2.7991e-01, -6.7149e-02,  2.5861e-10,
-        -8.6671e-01, -1.1834e+00, -6.6199e-14,  1.4602e-01,  3.2858e-08,
-         1.7729e-06,  1.0961e-09,  0.0000e+00,  4.6088e-01,  7.6480e-01,
-        -1.4547e-01, -1.4640e-07, -2.6508e-01,  2.6975e-10,  6.5979e-08,
-         1.1161e-01, -9.3095e-06,  2.2269e-01,  3.1317e-05, -1.7173e-11,
-         4.0807e-01,  3.5387e-03, -4.4618e-01,  4.6432e-03, -6.2931e-03,
-         3.5740e-15, -1.4046e-09,  5.0516e-13,  6.2427e-01, -7.2956e-05,
-         1.3494e-12, -5.9692e-01, -1.1404e-01,  2.8934e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3262,  0.0000,  0.0442,  0.3674,  2.4605,  0.0408,  0.0000,  0.0000,
-        -0.3375,  0.4023, -0.0437,  0.0000, -0.1044, -0.1725,  0.0000,  0.0000,
-         0.0000,  0.1747,  0.0000,  0.0000, -0.1974,  0.0463,  0.0000, -0.5870,
-         0.0170,  0.0000,  0.3496,  0.2799, -0.0671,  0.0000, -0.8667, -1.1834,
-         0.0000,  0.1460,  0.0000,  0.0000,  0.0000,  0.0000,  0.4609,  0.7648,
-        -0.1455,  0.0000, -0.2651,  0.0000,  0.0000,  0.1116,  0.0000,  0.2227,
-         0.0000,  0.0000,  0.4081,  0.0035, -0.4462,  0.0046,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6243,  0.0000,  0.0000, -0.5969, -0.1140,  0.0289],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3262,  0.0000,  0.0442,  0.3674,  2.4605,  0.0408,  0.0000,  0.0000,
-        -0.3375,  0.4023, -0.0437,  0.0000, -0.1044, -0.1725,  0.0000,  0.0000,
-         0.0000,  0.1747,  0.0000,  0.0000, -0.1974,  0.0463,  0.0000, -0.5870,
-         0.0170,  0.0000,  0.3496,  0.2799, -0.0671,  0.0000, -0.8667, -1.1834,
-         0.0000,  0.1460,  0.0000,  0.0000,  0.0000,  0.0000,  0.4609,  0.7648,
-        -0.1455,  0.0000, -0.2651,  0.0000,  0.0000,  0.1116,  0.0000,  0.2227,
-         0.0000,  0.0000,  0.4081,  0.0035, -0.4462,  0.0046,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6243,  0.0000,  0.0000, -0.5969, -0.1140,  0.0289],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1206e-01, -3.9735e-04,  1.7211e-02,  4.2150e-01,  2.4545e+00,
-         2.7120e-02,  2.4052e-13, -1.8992e-06, -3.6543e-01,  3.5268e-01,
-        -1.3115e-01,  4.3239e-05, -1.4794e-01, -1.8585e-01, -4.3133e-10,
-        -3.2058e-07, -3.1534e-08,  1.9539e-01, -7.2935e-11,  8.7692e-07,
-        -1.4195e-01,  5.3537e-02, -1.5489e-10, -5.6023e-01,  7.0294e-02,
-         1.2571e-06,  2.9595e-01,  2.8069e-01, -7.8406e-02,  2.2526e-10,
-        -8.4769e-01, -1.1783e+00, -5.7662e-14,  1.6793e-01,  2.8621e-08,
-         1.5443e-06,  9.5472e-10,  0.0000e+00,  4.3764e-01,  7.5494e-01,
-        -1.7561e-01, -1.2752e-07, -2.6375e-01,  2.3497e-10,  5.7470e-08,
-         1.0425e-01, -8.1089e-06,  2.0564e-01,  2.7278e-05, -1.4958e-11,
-         3.2348e-01,  5.8941e-02, -4.5283e-01,  5.7191e-03, -5.4815e-03,
-         3.1131e-15, -1.2235e-09,  4.4002e-13,  5.8036e-01, -6.3548e-05,
-         1.1754e-12, -6.0392e-01, -1.5750e-01,  1.0564e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3121,  0.0000,  0.0172,  0.4215,  2.4545,  0.0271,  0.0000,  0.0000,
-        -0.3654,  0.3527, -0.1312,  0.0000, -0.1479, -0.1858,  0.0000,  0.0000,
-         0.0000,  0.1954,  0.0000,  0.0000, -0.1419,  0.0535,  0.0000, -0.5602,
-         0.0703,  0.0000,  0.2959,  0.2807, -0.0784,  0.0000, -0.8477, -1.1783,
-         0.0000,  0.1679,  0.0000,  0.0000,  0.0000,  0.0000,  0.4376,  0.7549,
-        -0.1756,  0.0000, -0.2637,  0.0000,  0.0000,  0.1042,  0.0000,  0.2056,
-         0.0000,  0.0000,  0.3235,  0.0589, -0.4528,  0.0057,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5804,  0.0000,  0.0000, -0.6039, -0.1575,  0.1056],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3121,  0.0000,  0.0172,  0.4215,  2.4545,  0.0271,  0.0000,  0.0000,
-        -0.3654,  0.3527, -0.1312,  0.0000, -0.1479, -0.1858,  0.0000,  0.0000,
-         0.0000,  0.1954,  0.0000,  0.0000, -0.1419,  0.0535,  0.0000, -0.5602,
-         0.0703,  0.0000,  0.2959,  0.2807, -0.0784,  0.0000, -0.8477, -1.1783,
-         0.0000,  0.1679,  0.0000,  0.0000,  0.0000,  0.0000,  0.4376,  0.7549,
-        -0.1756,  0.0000, -0.2637,  0.0000,  0.0000,  0.1042,  0.0000,  0.2056,
-         0.0000,  0.0000,  0.3235,  0.0589, -0.4528,  0.0057,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5804,  0.0000,  0.0000, -0.6039, -0.1575,  0.1056],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8278e-01, -3.4624e-04,  1.7099e-02,  4.3274e-01,  2.4493e+00,
-         1.3791e-02,  2.0958e-13, -1.6549e-06, -3.8574e-01,  3.0898e-01,
-        -1.3869e-01,  3.7677e-05, -1.6834e-01, -2.1448e-01, -3.7584e-10,
-        -2.7934e-07, -2.7477e-08,  1.7465e-01, -6.3553e-11,  7.6412e-07,
-        -6.4627e-02,  5.6005e-02, -1.3496e-10, -5.4024e-01,  9.7093e-02,
-         1.0954e-06,  2.5339e-01,  2.7647e-01, -6.6673e-02,  1.9629e-10,
-        -8.4090e-01, -1.1693e+00, -5.0245e-14,  1.6247e-01,  2.4939e-08,
-         1.3456e-06,  8.3192e-10,  0.0000e+00,  4.5215e-01,  7.6420e-01,
-        -1.6367e-01, -1.1112e-07, -2.5733e-01,  2.0474e-10,  5.0078e-08,
-         8.7597e-02, -7.0658e-06,  2.0308e-01,  2.3769e-05, -1.3034e-11,
-         2.8131e-01,  8.3393e-02, -4.2288e-01, -2.4764e-02, -4.7764e-03,
-         2.7127e-15, -1.0661e-09,  3.8342e-13,  5.3834e-01, -5.5374e-05,
-         1.0242e-12, -6.0539e-01, -1.6750e-01,  2.0464e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2828,  0.0000,  0.0171,  0.4327,  2.4493,  0.0138,  0.0000,  0.0000,
-        -0.3857,  0.3090, -0.1387,  0.0000, -0.1683, -0.2145,  0.0000,  0.0000,
-         0.0000,  0.1746,  0.0000,  0.0000, -0.0646,  0.0560,  0.0000, -0.5402,
-         0.0971,  0.0000,  0.2534,  0.2765, -0.0667,  0.0000, -0.8409, -1.1693,
-         0.0000,  0.1625,  0.0000,  0.0000,  0.0000,  0.0000,  0.4521,  0.7642,
-        -0.1637,  0.0000, -0.2573,  0.0000,  0.0000,  0.0876,  0.0000,  0.2031,
-         0.0000,  0.0000,  0.0000,  0.0834, -0.4229, -0.0248,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5383,  0.0000,  0.0000, -0.6054, -0.1675,  0.2046],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2828,  0.0000,  0.0171,  0.4327,  2.4493,  0.0138,  0.0000,  0.0000,
-        -0.3857,  0.3090, -0.1387,  0.0000, -0.1683, -0.2145,  0.0000,  0.0000,
-         0.0000,  0.1746,  0.0000,  0.0000, -0.0646,  0.0560,  0.0000, -0.5402,
-         0.0971,  0.0000,  0.2534,  0.2765, -0.0667,  0.0000, -0.8409, -1.1693,
-         0.0000,  0.1625,  0.0000,  0.0000,  0.0000,  0.0000,  0.4521,  0.7642,
-        -0.1637,  0.0000, -0.2573,  0.0000,  0.0000,  0.0876,  0.0000,  0.2031,
-         0.0000,  0.0000,  0.0000,  0.0834, -0.4229, -0.0248,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5383,  0.0000,  0.0000, -0.6054, -0.1675,  0.2046],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.5784e-01, -3.0182e-04, -5.0099e-02,  4.1206e-01,  2.4444e+00,
-        -1.0291e-03,  1.8269e-13, -1.4426e-06, -4.0500e-01,  2.5124e-01,
-        -1.0460e-01,  3.2843e-05, -2.2089e-01, -2.0591e-01, -3.2762e-10,
-        -2.4350e-07, -2.3952e-08,  1.4153e-01, -5.5399e-11,  6.6608e-07,
-         6.7749e-02,  5.6480e-02, -1.1765e-10, -4.7403e-01,  1.6743e-01,
-         9.5484e-07,  1.6749e-01,  3.1232e-01, -1.4279e-02,  1.7110e-10,
-        -8.3694e-01, -1.1571e+00, -4.3799e-14,  1.6326e-01,  2.1740e-08,
-         1.1730e-06,  7.2518e-10,  0.0000e+00,  4.2368e-01,  7.7885e-01,
-        -7.8586e-02, -9.6860e-08, -2.7378e-01,  1.7847e-10,  4.3653e-08,
-         5.4021e-02, -6.1593e-06,  1.5157e-01,  2.0720e-05, -1.1362e-11,
-        -3.6762e-02,  1.4779e-01, -3.6994e-01, -1.5105e-01, -4.1636e-03,
-         2.3646e-15, -9.2934e-10,  3.3422e-13,  5.1101e-01, -4.8269e-05,
-         8.9279e-13, -6.4367e-01, -1.7654e-01,  2.7615e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.5784e-01,  0.0000e+00, -5.0099e-02,  4.1206e-01,  2.4444e+00,
-        -1.0291e-03,  0.0000e+00,  0.0000e+00, -4.0500e-01,  2.5124e-01,
-        -1.0460e-01,  0.0000e+00, -2.2089e-01, -2.0591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  1.4153e-01,  0.0000e+00,  0.0000e+00,
-         6.7749e-02,  5.6480e-02,  0.0000e+00, -4.7403e-01,  1.6743e-01,
-         0.0000e+00,  1.6749e-01,  3.1232e-01, -1.4279e-02,  0.0000e+00,
-        -8.3694e-01, -1.1571e+00,  0.0000e+00,  1.6326e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2368e-01,  7.7885e-01,
-        -7.8586e-02,  0.0000e+00, -2.7378e-01,  0.0000e+00,  0.0000e+00,
-         5.4021e-02,  0.0000e+00,  1.5157e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.4779e-01, -3.6994e-01, -1.5105e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.1101e-01,  0.0000e+00,
-         0.0000e+00, -6.4367e-01, -1.7654e-01,  2.7615e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.5784e-01,  0.0000e+00, -5.0099e-02,  4.1206e-01,  2.4444e+00,
-        -1.0291e-03,  0.0000e+00,  0.0000e+00, -4.0500e-01,  2.5124e-01,
-        -1.0460e-01,  0.0000e+00, -2.2089e-01, -2.0591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  1.4153e-01,  0.0000e+00,  0.0000e+00,
-         6.7749e-02,  5.6480e-02,  0.0000e+00, -4.7403e-01,  1.6743e-01,
-         0.0000e+00,  1.6749e-01,  3.1232e-01, -1.4279e-02,  0.0000e+00,
-        -8.3694e-01, -1.1571e+00,  0.0000e+00,  1.6326e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2368e-01,  7.7885e-01,
-        -7.8586e-02,  0.0000e+00, -2.7378e-01,  0.0000e+00,  0.0000e+00,
-         5.4021e-02,  0.0000e+00,  1.5157e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.4779e-01, -3.6994e-01, -1.5105e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.1101e-01,  0.0000e+00,
-         0.0000e+00, -6.4367e-01, -1.7654e-01,  2.7615e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7038e-01, -2.6319e-04, -6.6353e-02,  3.7236e-01,  2.4423e+00,
-         1.6916e-02,  1.5931e-13, -1.2580e-06, -3.9687e-01,  1.7376e-01,
-        -9.1570e-02,  2.8640e-05, -1.8875e-01, -2.0240e-01, -2.8570e-10,
-        -2.1234e-07, -2.0887e-08,  9.8325e-02, -4.8310e-11,  5.8085e-07,
-         1.3039e-01,  6.1700e-02, -1.0259e-10, -5.2420e-01,  2.1224e-01,
-         8.3265e-07,  7.1878e-02,  3.0130e-01,  4.8567e-02,  1.4921e-10,
-        -8.4034e-01, -1.1401e+00, -3.8194e-14,  1.5669e-01,  1.8958e-08,
-         1.0229e-06,  6.3238e-10,  0.0000e+00,  4.3200e-01,  8.1780e-01,
-         1.8279e-02, -8.4465e-08, -2.5922e-01,  1.5563e-10,  3.8066e-08,
-        -4.3714e-02, -5.3711e-06,  6.8732e-02,  1.8068e-05, -9.9080e-12,
-        -3.2058e-02,  1.4941e-01, -3.6091e-01, -1.7313e-01, -3.6308e-03,
-         2.0620e-15, -8.1041e-10,  2.9145e-13,  4.8550e-01, -4.2092e-05,
-         7.7854e-13, -6.4208e-01, -1.6581e-01,  2.6847e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2704,  0.0000, -0.0664,  0.3724,  2.4423,  0.0169,  0.0000,  0.0000,
-        -0.3969,  0.1738, -0.0916,  0.0000, -0.1887, -0.2024,  0.0000,  0.0000,
-         0.0000,  0.0983,  0.0000,  0.0000,  0.1304,  0.0617,  0.0000, -0.5242,
-         0.2122,  0.0000,  0.0719,  0.3013,  0.0486,  0.0000, -0.8403, -1.1401,
-         0.0000,  0.1567,  0.0000,  0.0000,  0.0000,  0.0000,  0.4320,  0.8178,
-         0.0183,  0.0000, -0.2592,  0.0000,  0.0000, -0.0437,  0.0000,  0.0687,
-         0.0000,  0.0000,  0.0000,  0.1494, -0.3609, -0.1731,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4855,  0.0000,  0.0000, -0.6421, -0.1658,  0.2685],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2704,  0.0000, -0.0664,  0.3724,  2.4423,  0.0169,  0.0000,  0.0000,
-        -0.3969,  0.1738, -0.0916,  0.0000, -0.1887, -0.2024,  0.0000,  0.0000,
-         0.0000,  0.0983,  0.0000,  0.0000,  0.1304,  0.0617,  0.0000, -0.5242,
-         0.2122,  0.0000,  0.0719,  0.3013,  0.0486,  0.0000, -0.8403, -1.1401,
-         0.0000,  0.1567,  0.0000,  0.0000,  0.0000,  0.0000,  0.4320,  0.8178,
-         0.0183,  0.0000, -0.2592,  0.0000,  0.0000, -0.0437,  0.0000,  0.0687,
-         0.0000,  0.0000,  0.0000,  0.1494, -0.3609, -0.1731,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4855,  0.0000,  0.0000, -0.6421, -0.1658,  0.2685],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8168e-01, -2.2960e-04, -2.7975e-02,  3.4407e-01,  2.4416e+00,
-         8.7118e-02,  1.3898e-13, -1.0974e-06, -3.5528e-01,  7.1206e-02,
-        -1.5023e-01,  2.4985e-05, -1.2449e-01, -2.3496e-01, -2.4923e-10,
-        -1.8524e-07, -1.8221e-08,  4.2841e-02, -4.2144e-11,  5.0671e-07,
-         1.2747e-01,  3.0238e-02, -8.9498e-11, -5.9550e-01,  1.9769e-01,
-         7.2637e-07,  1.9658e-02,  2.5702e-01,  1.1870e-01,  1.3016e-10,
-        -8.3229e-01, -1.1321e+00, -3.3319e-14,  1.2431e-01,  1.6538e-08,
-         8.9232e-07,  5.5166e-10,  0.0000e+00,  4.5291e-01,  8.6585e-01,
-         4.8801e-02, -7.3684e-08, -2.1440e-01,  1.3577e-10,  3.3208e-08,
-        -8.8934e-02, -4.6855e-06,  1.2802e-02,  1.5762e-05, -8.6434e-12,
-        -2.7966e-02,  8.2602e-02, -3.8278e-01, -4.6579e-02, -3.1674e-03,
-         1.7988e-15, -7.0697e-10,  2.5425e-13,  4.8864e-01, -3.6720e-05,
-         6.7917e-13, -5.7040e-01, -1.9821e-01,  2.0057e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2817,  0.0000, -0.0280,  0.3441,  2.4416,  0.0871,  0.0000,  0.0000,
-        -0.3553,  0.0712, -0.1502,  0.0000, -0.1245, -0.2350,  0.0000,  0.0000,
-         0.0000,  0.0428,  0.0000,  0.0000,  0.1275,  0.0302,  0.0000, -0.5955,
-         0.1977,  0.0000,  0.0197,  0.2570,  0.1187,  0.0000, -0.8323, -1.1321,
-         0.0000,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000,  0.4529,  0.8658,
-         0.0488,  0.0000, -0.2144,  0.0000,  0.0000, -0.0889,  0.0000,  0.0128,
-         0.0000,  0.0000,  0.0000,  0.0826, -0.3828, -0.0466,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4886,  0.0000,  0.0000, -0.5704, -0.1982,  0.2006],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2817,  0.0000, -0.0280,  0.3441,  2.4416,  0.0871,  0.0000,  0.0000,
-        -0.3553,  0.0712, -0.1502,  0.0000, -0.1245, -0.2350,  0.0000,  0.0000,
-         0.0000,  0.0428,  0.0000,  0.0000,  0.1275,  0.0302,  0.0000, -0.5955,
-         0.1977,  0.0000,  0.0197,  0.2570,  0.1187,  0.0000, -0.8323, -1.1321,
-         0.0000,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000,  0.4529,  0.8658,
-         0.0488,  0.0000, -0.2144,  0.0000,  0.0000, -0.0889,  0.0000,  0.0128,
-         0.0000,  0.0000,  0.0000,  0.0826, -0.3828, -0.0466,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4886,  0.0000,  0.0000, -0.5704, -0.1982,  0.2006],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8532e-01, -2.0037e-04,  4.2665e-02,  3.2617e-01,  2.4423e+00,
-         2.0871e-01,  1.2129e-13, -9.5770e-07, -3.0621e-01,  4.2552e-03,
-        -1.7522e-01,  2.1804e-05, -2.4579e-02, -2.6876e-01, -2.1750e-10,
-        -1.6166e-07, -1.5901e-08, -3.0743e-02, -3.6779e-11,  4.4220e-07,
-         7.4929e-02, -1.2829e-02, -7.8104e-11, -7.1196e-01,  1.7774e-01,
-         6.3390e-07,  7.0560e-04,  1.9939e-01,  1.9148e-01,  1.1359e-10,
-        -8.2765e-01, -1.1334e+00, -2.9077e-14,  6.0190e-02,  1.4433e-08,
-         7.7872e-07,  4.8144e-10,  0.0000e+00,  4.6752e-01,  8.9653e-01,
-         4.6748e-02, -6.4304e-08, -1.6300e-01,  1.1849e-10,  2.8980e-08,
-        -1.4117e-01, -4.0890e-06,  2.3222e-03,  1.3756e-05, -7.5430e-12,
-        -2.4406e-02, -6.7295e-03, -4.0895e-01,  1.1349e-01, -2.7642e-03,
-         1.5698e-15, -6.1697e-10,  2.2189e-13,  5.0067e-01, -3.2045e-05,
-         5.9271e-13, -5.0443e-01, -2.3119e-01,  1.0409e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.8532e-01,  0.0000e+00,  4.2665e-02,  3.2617e-01,  2.4423e+00,
-         2.0871e-01,  0.0000e+00,  0.0000e+00, -3.0621e-01,  4.2552e-03,
-        -1.7522e-01,  0.0000e+00, -2.4579e-02, -2.6876e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.0743e-02,  0.0000e+00,  0.0000e+00,
-         7.4929e-02, -1.2829e-02,  0.0000e+00, -7.1196e-01,  1.7774e-01,
-         0.0000e+00,  7.0560e-04,  1.9939e-01,  1.9148e-01,  0.0000e+00,
-        -8.2765e-01, -1.1334e+00,  0.0000e+00,  6.0190e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6752e-01,  8.9653e-01,
-         4.6748e-02,  0.0000e+00, -1.6300e-01,  0.0000e+00,  0.0000e+00,
-        -1.4117e-01,  0.0000e+00,  2.3222e-03,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -6.7295e-03, -4.0895e-01,  1.1349e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.0067e-01,  0.0000e+00,
-         0.0000e+00, -5.0443e-01, -2.3119e-01,  1.0409e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.8532e-01,  0.0000e+00,  4.2665e-02,  3.2617e-01,  2.4423e+00,
-         2.0871e-01,  0.0000e+00,  0.0000e+00, -3.0621e-01,  4.2552e-03,
-        -1.7522e-01,  0.0000e+00, -2.4579e-02, -2.6876e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.0743e-02,  0.0000e+00,  0.0000e+00,
-         7.4929e-02, -1.2829e-02,  0.0000e+00, -7.1196e-01,  1.7774e-01,
-         0.0000e+00,  7.0560e-04,  1.9939e-01,  1.9148e-01,  0.0000e+00,
-        -8.2765e-01, -1.1334e+00,  0.0000e+00,  6.0190e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6752e-01,  8.9653e-01,
-         4.6748e-02,  0.0000e+00, -1.6300e-01,  0.0000e+00,  0.0000e+00,
-        -1.4117e-01,  0.0000e+00,  2.3222e-03,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -6.7295e-03, -4.0895e-01,  1.1349e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.0067e-01,  0.0000e+00,
-         0.0000e+00, -5.0443e-01, -2.3119e-01,  1.0409e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6792e-01, -1.7493e-04,  6.4759e-02,  2.7333e-01,  2.4425e+00,
-         3.5665e-01,  1.0589e-13, -8.3611e-07, -2.5456e-01, -7.8516e-03,
-        -1.3802e-01,  1.9036e-05,  1.8049e-02, -3.1434e-01, -1.8989e-10,
-        -1.4113e-07, -1.3882e-08, -1.0430e-01, -3.2109e-11,  3.8606e-07,
-        -2.8989e-02, -1.2378e-02, -6.8188e-11, -7.8751e-01,  2.0740e-01,
-         5.5342e-07, -2.1021e-02,  1.2559e-01,  3.0989e-01,  9.9169e-11,
-        -8.3519e-01, -1.1378e+00, -2.5385e-14,  1.7936e-02,  1.2600e-08,
-         6.7985e-07,  4.2031e-10,  0.0000e+00,  4.3104e-01,  8.9997e-01,
-         1.2016e-01, -5.6139e-08, -1.3737e-01,  1.0344e-10,  2.5301e-08,
-        -1.6142e-01, -3.5699e-06, -3.2400e-02,  1.2009e-05, -6.5853e-12,
-        -2.1307e-02, -8.1226e-02, -4.0657e-01,  2.3315e-01, -2.4132e-03,
-         1.3705e-15, -5.3864e-10,  1.9371e-13,  5.0346e-01, -2.7976e-05,
-         5.1745e-13, -4.6304e-01, -2.6264e-01,  5.2330e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2679,  0.0000,  0.0648,  0.2733,  2.4425,  0.3567,  0.0000,  0.0000,
-        -0.2546, -0.0079, -0.1380,  0.0000,  0.0180, -0.3143,  0.0000,  0.0000,
-         0.0000, -0.1043,  0.0000,  0.0000, -0.0290, -0.0124,  0.0000, -0.7875,
-         0.2074,  0.0000, -0.0210,  0.1256,  0.3099,  0.0000, -0.8352, -1.1378,
-         0.0000,  0.0179,  0.0000,  0.0000,  0.0000,  0.0000,  0.4310,  0.9000,
-         0.1202,  0.0000, -0.1374,  0.0000,  0.0000, -0.1614,  0.0000, -0.0324,
-         0.0000,  0.0000,  0.0000, -0.0812, -0.4066,  0.2331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5035,  0.0000,  0.0000, -0.4630, -0.2626,  0.0523],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2679,  0.0000,  0.0648,  0.2733,  2.4425,  0.3567,  0.0000,  0.0000,
-        -0.2546, -0.0079, -0.1380,  0.0000,  0.0180, -0.3143,  0.0000,  0.0000,
-         0.0000, -0.1043,  0.0000,  0.0000, -0.0290, -0.0124,  0.0000, -0.7875,
-         0.2074,  0.0000, -0.0210,  0.1256,  0.3099,  0.0000, -0.8352, -1.1378,
-         0.0000,  0.0179,  0.0000,  0.0000,  0.0000,  0.0000,  0.4310,  0.9000,
-         0.1202,  0.0000, -0.1374,  0.0000,  0.0000, -0.1614,  0.0000, -0.0324,
-         0.0000,  0.0000,  0.0000, -0.0812, -0.4066,  0.2331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5035,  0.0000,  0.0000, -0.4630, -0.2626,  0.0523],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.2234e-01, -1.5278e-04,  1.1193e-02,  1.9859e-01,  2.4409e+00,
-         4.7460e-01,  9.2479e-14, -7.3023e-07, -1.9530e-01, -2.6922e-02,
-        -7.2796e-02,  1.6625e-05, -5.9594e-03, -3.4675e-01, -1.6584e-10,
-        -1.2326e-07, -1.2124e-08, -1.3939e-01, -2.8043e-11,  3.3717e-07,
-        -1.4151e-01,  2.9938e-02, -5.9553e-11, -8.3891e-01,  2.4572e-01,
-         4.8334e-07, -7.2800e-02,  4.7188e-02,  3.9825e-01,  8.6611e-11,
-        -8.4847e-01, -1.1413e+00, -2.2171e-14, -1.5381e-03,  1.1005e-08,
-         5.9376e-07,  3.6708e-10,  0.0000e+00,  3.8435e-01,  8.9743e-01,
-         2.3247e-01, -4.9030e-08, -1.4596e-01,  9.0343e-11,  2.2097e-08,
-        -1.7051e-01, -3.1178e-06, -4.4113e-02,  1.0488e-05, -5.7514e-12,
-        -1.8609e-02, -1.0802e-01, -3.7511e-01,  2.6346e-01, -2.1076e-03,
-         1.1970e-15, -4.7043e-10,  1.6918e-13,  5.2875e-01, -2.4434e-05,
-         4.5193e-13, -4.6779e-01, -2.4735e-01, -4.7853e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.2234e-01,  0.0000e+00,  1.1193e-02,  1.9859e-01,  2.4409e+00,
-         4.7460e-01,  0.0000e+00,  0.0000e+00, -1.9530e-01, -2.6922e-02,
-        -7.2796e-02,  0.0000e+00, -5.9594e-03, -3.4675e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.3939e-01,  0.0000e+00,  0.0000e+00,
-        -1.4151e-01,  2.9938e-02,  0.0000e+00, -8.3891e-01,  2.4572e-01,
-         0.0000e+00, -7.2800e-02,  4.7188e-02,  3.9825e-01,  0.0000e+00,
-        -8.4847e-01, -1.1413e+00,  0.0000e+00, -1.5381e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.8435e-01,  8.9743e-01,
-         2.3247e-01,  0.0000e+00, -1.4596e-01,  0.0000e+00,  0.0000e+00,
-        -1.7051e-01,  0.0000e+00, -4.4113e-02,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -1.0802e-01, -3.7511e-01,  2.6346e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.2875e-01,  0.0000e+00,
-         0.0000e+00, -4.6779e-01, -2.4735e-01, -4.7853e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.2234e-01,  0.0000e+00,  1.1193e-02,  1.9859e-01,  2.4409e+00,
-         4.7460e-01,  0.0000e+00,  0.0000e+00, -1.9530e-01, -2.6922e-02,
-        -7.2796e-02,  0.0000e+00, -5.9594e-03, -3.4675e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.3939e-01,  0.0000e+00,  0.0000e+00,
-        -1.4151e-01,  2.9938e-02,  0.0000e+00, -8.3891e-01,  2.4572e-01,
-         0.0000e+00, -7.2800e-02,  4.7188e-02,  3.9825e-01,  0.0000e+00,
-        -8.4847e-01, -1.1413e+00,  0.0000e+00, -1.5381e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.8435e-01,  8.9743e-01,
-         2.3247e-01,  0.0000e+00, -1.4596e-01,  0.0000e+00,  0.0000e+00,
-        -1.7051e-01,  0.0000e+00, -4.4113e-02,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -1.0802e-01, -3.7511e-01,  2.6346e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.2875e-01,  0.0000e+00,
-         0.0000e+00, -4.6779e-01, -2.4735e-01, -4.7853e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.7595e-01, -1.3348e-04, -6.4304e-02,  1.0458e-01,  2.4381e+00,
-         5.6884e-01,  8.0800e-14, -6.3801e-07, -1.2509e-01, -6.2898e-02,
-        -7.8560e-04,  1.4526e-05, -5.7334e-02, -3.6389e-01, -1.4490e-10,
-        -1.0769e-07, -1.0593e-08, -1.6299e-01, -2.4501e-11,  2.9459e-07,
-        -2.4164e-01,  8.5751e-02, -5.2032e-11, -8.4927e-01,  2.7630e-01,
-         4.2230e-07, -1.4648e-01, -2.7657e-02,  4.6871e-01,  7.5673e-11,
-        -8.6652e-01, -1.1462e+00, -1.9371e-14, -1.0880e-02,  9.6147e-09,
-         5.1877e-07,  3.2072e-10,  0.0000e+00,  3.1876e-01,  8.9231e-01,
-         3.2730e-01, -4.2838e-08, -1.6573e-01,  7.8933e-11,  1.9306e-08,
-        -1.7528e-01, -2.7241e-06, -6.5885e-02,  9.1637e-06, -5.0251e-12,
-        -1.6259e-02, -8.6654e-02, -3.2876e-01,  2.3614e-01, -1.8414e-03,
-         1.0458e-15, -4.1102e-10,  1.4782e-13,  5.4292e-01, -2.1348e-05,
-         3.9485e-13, -4.8493e-01, -2.1685e-01, -8.1347e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.7595e-01,  0.0000e+00, -6.4304e-02,  1.0458e-01,  2.4381e+00,
-         5.6884e-01,  0.0000e+00,  0.0000e+00, -1.2509e-01, -6.2898e-02,
-        -7.8560e-04,  0.0000e+00, -5.7334e-02, -3.6389e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6299e-01,  0.0000e+00,  0.0000e+00,
-        -2.4164e-01,  8.5751e-02,  0.0000e+00, -8.4927e-01,  2.7630e-01,
-         0.0000e+00, -1.4648e-01, -2.7657e-02,  4.6871e-01,  0.0000e+00,
-        -8.6652e-01, -1.1462e+00,  0.0000e+00, -1.0880e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.1876e-01,  8.9231e-01,
-         3.2730e-01,  0.0000e+00, -1.6573e-01,  0.0000e+00,  0.0000e+00,
-        -1.7528e-01,  0.0000e+00, -6.5885e-02,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -8.6654e-02, -3.2876e-01,  2.3614e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4292e-01,  0.0000e+00,
-         0.0000e+00, -4.8493e-01, -2.1685e-01, -8.1347e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.7595e-01,  0.0000e+00, -6.4304e-02,  1.0458e-01,  2.4381e+00,
-         5.6884e-01,  0.0000e+00,  0.0000e+00, -1.2509e-01, -6.2898e-02,
-        -7.8560e-04,  0.0000e+00, -5.7334e-02, -3.6389e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6299e-01,  0.0000e+00,  0.0000e+00,
-        -2.4164e-01,  8.5751e-02,  0.0000e+00, -8.4927e-01,  2.7630e-01,
-         0.0000e+00, -1.4648e-01, -2.7657e-02,  4.6871e-01,  0.0000e+00,
-        -8.6652e-01, -1.1462e+00,  0.0000e+00, -1.0880e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.1876e-01,  8.9231e-01,
-         3.2730e-01,  0.0000e+00, -1.6573e-01,  0.0000e+00,  0.0000e+00,
-        -1.7528e-01,  0.0000e+00, -6.5885e-02,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -8.6654e-02, -3.2876e-01,  2.3614e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4292e-01,  0.0000e+00,
-         0.0000e+00, -4.8493e-01, -2.1685e-01, -8.1347e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4915e-01, -1.1667e-04, -1.1199e-01,  6.5085e-02,  2.4350e+00,
-         6.7004e-01,  7.0623e-14, -5.5765e-07, -1.4394e-02, -7.0757e-02,
-        -1.8389e-02,  1.2696e-05, -9.0452e-02, -3.5275e-01, -1.2665e-10,
-        -9.4128e-08, -9.2590e-09, -1.0583e-01, -2.1415e-11,  2.5748e-07,
-        -2.7518e-01,  1.5880e-01, -4.5478e-11, -8.5793e-01,  3.0170e-01,
-         3.6911e-07, -1.9879e-01, -4.8126e-02,  5.0637e-01,  6.6142e-11,
-        -8.8356e-01, -1.1490e+00, -1.6931e-14, -2.4941e-02,  8.4037e-09,
-         4.5343e-07,  2.8033e-10,  0.0000e+00,  2.7823e-01,  8.9262e-01,
-         3.9162e-01, -3.7443e-08, -1.8320e-01,  6.8991e-11,  1.6875e-08,
-        -2.0704e-01, -2.3810e-06, -6.5140e-02,  8.0095e-06, -4.3921e-12,
-        -1.4211e-02, -2.5641e-02, -2.9163e-01,  2.5251e-01, -1.6095e-03,
-         9.1408e-16, -3.5925e-10,  1.2920e-13,  5.4782e-01, -1.8659e-05,
-         3.4512e-13, -4.6805e-01, -1.9369e-01, -1.4413e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1491,  0.0000, -0.1120,  0.0651,  2.4350,  0.6700,  0.0000,  0.0000,
-        -0.0144, -0.0708, -0.0184,  0.0000, -0.0905, -0.3528,  0.0000,  0.0000,
-         0.0000, -0.1058,  0.0000,  0.0000, -0.2752,  0.1588,  0.0000, -0.8579,
-         0.3017,  0.0000, -0.1988, -0.0481,  0.5064,  0.0000, -0.8836, -1.1490,
-         0.0000, -0.0249,  0.0000,  0.0000,  0.0000,  0.0000,  0.2782,  0.8926,
-         0.3916,  0.0000, -0.1832,  0.0000,  0.0000, -0.2070,  0.0000, -0.0651,
-         0.0000,  0.0000,  0.0000, -0.0256, -0.2916,  0.2525,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5478,  0.0000,  0.0000, -0.4681, -0.1937, -0.1441],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1491,  0.0000, -0.1120,  0.0651,  2.4350,  0.6700,  0.0000,  0.0000,
-        -0.0144, -0.0708, -0.0184,  0.0000, -0.0905, -0.3528,  0.0000,  0.0000,
-         0.0000, -0.1058,  0.0000,  0.0000, -0.2752,  0.1588,  0.0000, -0.8579,
-         0.3017,  0.0000, -0.1988, -0.0481,  0.5064,  0.0000, -0.8836, -1.1490,
-         0.0000, -0.0249,  0.0000,  0.0000,  0.0000,  0.0000,  0.2782,  0.8926,
-         0.3916,  0.0000, -0.1832,  0.0000,  0.0000, -0.2070,  0.0000, -0.0651,
-         0.0000,  0.0000,  0.0000, -0.0256, -0.2916,  0.2525,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5478,  0.0000,  0.0000, -0.4681, -0.1937, -0.1441],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1592e-01, -1.0202e-04, -1.5458e-01,  5.5854e-02,  2.4321e+00,
-         7.7636e-01,  6.1752e-14, -4.8760e-07,  9.4582e-02, -7.3868e-02,
-        -7.9312e-02,  1.1101e-05, -1.0773e-01, -3.2577e-01, -1.1074e-10,
-        -8.2305e-08, -8.0960e-09, -3.4433e-02, -1.8725e-11,  2.2514e-07,
-        -2.7779e-01,  2.1047e-01, -3.9766e-11, -8.7433e-01,  2.8746e-01,
-         3.2274e-07, -2.3442e-01, -4.7876e-02,  5.2118e-01,  5.7834e-11,
-        -8.7955e-01, -1.1453e+00, -1.4804e-14, -5.4417e-02,  7.3481e-09,
-         3.9648e-07,  2.4512e-10,  0.0000e+00,  2.2905e-01,  8.8025e-01,
-         4.0828e-01, -3.2739e-08, -2.0316e-01,  6.0325e-11,  1.4755e-08,
-        -2.4311e-01, -2.0819e-06, -2.4420e-02,  7.0034e-06, -3.8404e-12,
-        -1.2426e-02,  1.9836e-02, -2.7734e-01,  2.6641e-01, -1.4073e-03,
-         7.9926e-16, -3.1412e-10,  1.1297e-13,  5.6321e-01, -1.6315e-05,
-         3.0177e-13, -4.2650e-01, -1.7179e-01, -1.8152e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1159,  0.0000, -0.1546,  0.0559,  2.4321,  0.7764,  0.0000,  0.0000,
-         0.0946, -0.0739, -0.0793,  0.0000, -0.1077, -0.3258,  0.0000,  0.0000,
-         0.0000, -0.0344,  0.0000,  0.0000, -0.2778,  0.2105,  0.0000, -0.8743,
-         0.2875,  0.0000, -0.2344, -0.0479,  0.5212,  0.0000, -0.8796, -1.1453,
-         0.0000, -0.0544,  0.0000,  0.0000,  0.0000,  0.0000,  0.2290,  0.8802,
-         0.4083,  0.0000, -0.2032,  0.0000,  0.0000, -0.2431,  0.0000, -0.0244,
-         0.0000,  0.0000,  0.0000,  0.0198, -0.2773,  0.2664,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5632,  0.0000,  0.0000, -0.4265, -0.1718, -0.1815],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1159,  0.0000, -0.1546,  0.0559,  2.4321,  0.7764,  0.0000,  0.0000,
-         0.0946, -0.0739, -0.0793,  0.0000, -0.1077, -0.3258,  0.0000,  0.0000,
-         0.0000, -0.0344,  0.0000,  0.0000, -0.2778,  0.2105,  0.0000, -0.8743,
-         0.2875,  0.0000, -0.2344, -0.0479,  0.5212,  0.0000, -0.8796, -1.1453,
-         0.0000, -0.0544,  0.0000,  0.0000,  0.0000,  0.0000,  0.2290,  0.8802,
-         0.4083,  0.0000, -0.2032,  0.0000,  0.0000, -0.2431,  0.0000, -0.0244,
-         0.0000,  0.0000,  0.0000,  0.0198, -0.2773,  0.2664,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5632,  0.0000,  0.0000, -0.4265, -0.1718, -0.1815],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3958e-02, -8.9236e-05, -2.0953e-01,  1.7173e-02,  2.4273e+00,
-         8.7973e-01,  5.4016e-14, -4.2652e-07,  1.8594e-01, -1.3627e-01,
-        -1.3098e-01,  9.7106e-06, -1.3139e-01, -2.9550e-01, -9.6866e-11,
-        -7.1995e-08, -7.0818e-09, -6.1339e-02, -1.6380e-11,  1.9694e-07,
-        -2.4098e-01,  2.2425e-01, -3.4784e-11, -8.7828e-01,  2.4837e-01,
-         2.8231e-07, -2.6941e-01, -6.0304e-02,  5.5472e-01,  5.0589e-11,
-        -9.0584e-01, -1.1381e+00, -1.2950e-14, -8.2552e-02,  6.4276e-09,
-         3.4681e-07,  2.1441e-10,  0.0000e+00,  1.3433e-01,  8.7020e-01,
-         4.1816e-01, -2.8638e-08, -2.2235e-01,  5.2768e-11,  1.2907e-08,
-        -2.6979e-01, -1.8211e-06,  4.4270e-02,  6.1261e-06, -3.3593e-12,
-        -1.0869e-02,  7.4551e-02, -2.4394e-01,  1.9431e-01, -1.2310e-03,
-         6.9914e-16, -2.7477e-10,  9.8818e-14,  5.7626e-01, -1.4271e-05,
-         2.6397e-13, -4.1195e-01, -1.5126e-01, -2.1864e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0440,  0.0000, -0.2095,  0.0172,  2.4273,  0.8797,  0.0000,  0.0000,
-         0.1859, -0.1363, -0.1310,  0.0000, -0.1314, -0.2955,  0.0000,  0.0000,
-         0.0000, -0.0613,  0.0000,  0.0000, -0.2410,  0.2242,  0.0000, -0.8783,
-         0.2484,  0.0000, -0.2694, -0.0603,  0.5547,  0.0000, -0.9058, -1.1381,
-         0.0000, -0.0826,  0.0000,  0.0000,  0.0000,  0.0000,  0.1343,  0.8702,
-         0.4182,  0.0000, -0.2224,  0.0000,  0.0000, -0.2698,  0.0000,  0.0443,
-         0.0000,  0.0000,  0.0000,  0.0746, -0.2439,  0.1943,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5763,  0.0000,  0.0000, -0.4119, -0.1513, -0.2186],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0440,  0.0000, -0.2095,  0.0172,  2.4273,  0.8797,  0.0000,  0.0000,
-         0.1859, -0.1363, -0.1310,  0.0000, -0.1314, -0.2955,  0.0000,  0.0000,
-         0.0000, -0.0613,  0.0000,  0.0000, -0.2410,  0.2242,  0.0000, -0.8783,
-         0.2484,  0.0000, -0.2694, -0.0603,  0.5547,  0.0000, -0.9058, -1.1381,
-         0.0000, -0.0826,  0.0000,  0.0000,  0.0000,  0.0000,  0.1343,  0.8702,
-         0.4182,  0.0000, -0.2224,  0.0000,  0.0000, -0.2698,  0.0000,  0.0443,
-         0.0000,  0.0000,  0.0000,  0.0746, -0.2439,  0.1943,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5763,  0.0000,  0.0000, -0.4119, -0.1513, -0.2186],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0467e-03, -7.8088e-05, -2.5551e-01, -3.5744e-02,  2.4193e+00,
-         9.4598e-01,  4.7268e-14, -3.7324e-07,  2.5536e-01, -1.7691e-01,
-        -1.5319e-01,  8.4975e-06, -1.4986e-01, -2.8131e-01, -8.4765e-11,
-        -6.3001e-08, -6.1971e-09, -9.3451e-02, -1.4333e-11,  1.7233e-07,
-        -2.1692e-01,  1.9491e-01, -3.0439e-11, -8.7789e-01,  1.6699e-01,
-         2.4704e-07, -2.5154e-01, -1.0055e-01,  5.5471e-01,  4.4269e-11,
-        -9.1178e-01, -1.1384e+00, -1.1332e-14, -1.3928e-01,  5.6246e-09,
-         3.0348e-07,  1.8762e-10,  0.0000e+00,  2.4316e-02,  8.5069e-01,
-         4.1049e-01, -2.5060e-08, -2.2885e-01,  4.6176e-11,  1.1294e-08,
-        -2.8354e-01, -1.5936e-06,  1.0864e-01,  5.3608e-06, -2.9397e-12,
-        -9.5114e-03,  6.5526e-02, -1.8425e-01,  1.0941e-01, -1.0772e-03,
-         6.1179e-16, -2.4045e-10,  8.6473e-14,  5.7382e-01, -1.2489e-05,
-         2.3099e-13, -3.6636e-01, -1.1059e-01, -2.3845e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0090,  0.0000, -0.2555, -0.0357,  2.4193,  0.9460,  0.0000,  0.0000,
-         0.2554, -0.1769, -0.1532,  0.0000, -0.1499, -0.2813,  0.0000,  0.0000,
-         0.0000, -0.0935,  0.0000,  0.0000, -0.2169,  0.1949,  0.0000, -0.8779,
-         0.1670,  0.0000, -0.2515, -0.1005,  0.5547,  0.0000, -0.9118, -1.1384,
-         0.0000, -0.1393,  0.0000,  0.0000,  0.0000,  0.0000,  0.0243,  0.8507,
-         0.4105,  0.0000, -0.2288,  0.0000,  0.0000, -0.2835,  0.0000,  0.1086,
-         0.0000,  0.0000,  0.0000,  0.0655, -0.1842,  0.1094,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5738,  0.0000,  0.0000, -0.3664, -0.1106, -0.2385],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0090,  0.0000, -0.2555, -0.0357,  2.4193,  0.9460,  0.0000,  0.0000,
-         0.2554, -0.1769, -0.1532,  0.0000, -0.1499, -0.2813,  0.0000,  0.0000,
-         0.0000, -0.0935,  0.0000,  0.0000, -0.2169,  0.1949,  0.0000, -0.8779,
-         0.1670,  0.0000, -0.2515, -0.1005,  0.5547,  0.0000, -0.9118, -1.1384,
-         0.0000, -0.1393,  0.0000,  0.0000,  0.0000,  0.0000,  0.0243,  0.8507,
-         0.4105,  0.0000, -0.2288,  0.0000,  0.0000, -0.2835,  0.0000,  0.1086,
-         0.0000,  0.0000,  0.0000,  0.0655, -0.1842,  0.1094,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5738,  0.0000,  0.0000, -0.3664, -0.1106, -0.2385],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.2516e-03, -6.8360e-05, -2.7619e-01, -7.0575e-02,  2.4119e+00,
-         9.9005e-01,  4.1379e-14, -3.2674e-07,  3.1606e-01, -1.7276e-01,
-        -1.5195e-01,  7.4389e-06, -1.7682e-01, -2.8361e-01, -7.4205e-11,
-        -5.5152e-08, -5.4250e-09, -1.0549e-01, -1.2548e-11,  1.5086e-07,
-        -2.1127e-01,  1.7789e-01, -2.6647e-11, -8.8727e-01,  8.8601e-02,
-         2.1627e-07, -2.1889e-01, -1.4150e-01,  5.4091e-01,  3.8754e-11,
-        -9.0785e-01, -1.1489e+00, -9.9202e-15, -1.7601e-01,  4.9239e-09,
-         2.6568e-07,  1.6425e-10,  0.0000e+00, -9.0658e-02,  8.2874e-01,
-         3.9359e-01, -2.1938e-08, -2.2207e-01,  4.0423e-11,  9.8871e-09,
-        -2.6641e-01, -1.3951e-06,  1.6908e-01,  4.6929e-06, -2.5734e-12,
-        -8.3265e-03,  4.4850e-02, -1.3156e-01,  8.6804e-02, -9.4304e-04,
-         5.3558e-16, -2.1049e-10,  7.5700e-14,  5.5913e-01, -1.0933e-05,
-         2.0221e-13, -3.0562e-01, -8.5440e-02, -2.6519e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0063,  0.0000, -0.2762, -0.0706,  2.4119,  0.9901,  0.0000,  0.0000,
-         0.3161, -0.1728, -0.1520,  0.0000, -0.1768, -0.2836,  0.0000,  0.0000,
-         0.0000, -0.1055,  0.0000,  0.0000, -0.2113,  0.1779,  0.0000, -0.8873,
-         0.0886,  0.0000, -0.2189, -0.1415,  0.5409,  0.0000, -0.9079, -1.1489,
-         0.0000, -0.1760,  0.0000,  0.0000,  0.0000,  0.0000, -0.0907,  0.8287,
-         0.3936,  0.0000, -0.2221,  0.0000,  0.0000, -0.2664,  0.0000,  0.1691,
-         0.0000,  0.0000,  0.0000,  0.0448, -0.1316,  0.0868,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5591,  0.0000,  0.0000, -0.3056, -0.0854, -0.2652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0063,  0.0000, -0.2762, -0.0706,  2.4119,  0.9901,  0.0000,  0.0000,
-         0.3161, -0.1728, -0.1520,  0.0000, -0.1768, -0.2836,  0.0000,  0.0000,
-         0.0000, -0.1055,  0.0000,  0.0000, -0.2113,  0.1779,  0.0000, -0.8873,
-         0.0886,  0.0000, -0.2189, -0.1415,  0.5409,  0.0000, -0.9079, -1.1489,
-         0.0000, -0.1760,  0.0000,  0.0000,  0.0000,  0.0000, -0.0907,  0.8287,
-         0.3936,  0.0000, -0.2221,  0.0000,  0.0000, -0.2664,  0.0000,  0.1691,
-         0.0000,  0.0000,  0.0000,  0.0448, -0.1316,  0.0868,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5591,  0.0000,  0.0000, -0.3056, -0.0854, -0.2652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3257e-02, -5.9867e-05, -2.3028e-01, -8.6201e-02,  2.4038e+00,
-         1.0251e+00,  3.6239e-14, -2.8614e-07,  3.5762e-01, -1.1913e-01,
-        -1.1953e-01,  6.5147e-06, -1.6636e-01, -3.0213e-01, -6.4986e-11,
-        -4.8300e-08, -4.7510e-09, -9.3260e-02, -1.0989e-11,  1.3212e-07,
-        -2.2403e-01,  1.6458e-01, -2.3336e-11, -9.0503e-01,  1.1615e-02,
-         1.8940e-07, -1.2897e-01, -2.0925e-01,  5.1959e-01,  3.3939e-11,
-        -8.9600e-01, -1.1643e+00, -8.6877e-15, -2.0290e-01,  4.3122e-09,
-         2.3267e-07,  1.4384e-10,  0.0000e+00, -1.6583e-01,  8.0938e-01,
-         3.6108e-01, -1.9213e-08, -1.7094e-01,  3.5401e-11,  8.6588e-09,
-        -2.4792e-01, -1.2217e-06,  2.4315e-01,  4.1099e-06, -2.2537e-12,
-        -7.2920e-03, -5.4695e-02, -1.2421e-01,  2.0123e-01, -8.2588e-04,
-         4.6904e-16, -1.8434e-10,  6.6295e-14,  5.3295e-01, -9.5745e-06,
-         1.7709e-13, -2.1142e-01, -5.8128e-02, -2.8156e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0233,  0.0000, -0.2303, -0.0862,  2.4038,  1.0251,  0.0000,  0.0000,
-         0.3576, -0.1191, -0.1195,  0.0000, -0.1664, -0.3021,  0.0000,  0.0000,
-         0.0000, -0.0933,  0.0000,  0.0000, -0.2240,  0.1646,  0.0000, -0.9050,
-         0.0116,  0.0000, -0.1290, -0.2093,  0.5196,  0.0000, -0.8960, -1.1643,
-         0.0000, -0.2029,  0.0000,  0.0000,  0.0000,  0.0000, -0.1658,  0.8094,
-         0.3611,  0.0000, -0.1709,  0.0000,  0.0000, -0.2479,  0.0000,  0.2432,
-         0.0000,  0.0000,  0.0000, -0.0547, -0.1242,  0.2012,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5330,  0.0000,  0.0000, -0.2114, -0.0581, -0.2816],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0233,  0.0000, -0.2303, -0.0862,  2.4038,  1.0251,  0.0000,  0.0000,
-         0.3576, -0.1191, -0.1195,  0.0000, -0.1664, -0.3021,  0.0000,  0.0000,
-         0.0000, -0.0933,  0.0000,  0.0000, -0.2240,  0.1646,  0.0000, -0.9050,
-         0.0116,  0.0000, -0.1290, -0.2093,  0.5196,  0.0000, -0.8960, -1.1643,
-         0.0000, -0.2029,  0.0000,  0.0000,  0.0000,  0.0000, -0.1658,  0.8094,
-         0.3611,  0.0000, -0.1709,  0.0000,  0.0000, -0.2479,  0.0000,  0.2432,
-         0.0000,  0.0000,  0.0000, -0.0547, -0.1242,  0.2012,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5330,  0.0000,  0.0000, -0.2114, -0.0581, -0.2816],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6425e-02, -5.2450e-05, -2.3463e-01, -9.8281e-02,  2.3968e+00,
-         1.0455e+00,  3.1749e-14, -2.5069e-07,  3.8281e-01, -1.4376e-02,
-        -4.7228e-02,  5.7076e-06, -1.9548e-01, -3.2274e-01, -5.6935e-11,
-        -4.2316e-08, -4.1624e-09,  1.3500e-02, -9.6274e-12,  1.1575e-07,
-        -2.1498e-01,  2.2398e-01, -2.0445e-11, -8.6986e-01,  4.0855e-02,
-         1.6593e-07, -1.0179e-01, -2.1226e-01,  5.0298e-01,  2.9734e-11,
-        -8.6452e-01, -1.1736e+00, -7.6114e-15, -1.7156e-01,  3.7779e-09,
-         2.0384e-07,  1.2602e-10,  0.0000e+00, -2.8165e-01,  7.9452e-01,
-         3.6620e-01, -1.6833e-08, -1.6410e-01,  3.1015e-11,  7.5861e-09,
-        -2.7531e-01, -1.0704e-06,  2.6693e-01,  3.6007e-06, -1.9745e-12,
-        -6.3886e-03, -2.4725e-02, -1.8576e-01,  3.4751e-01, -7.2356e-04,
-         4.1093e-16, -1.6150e-10,  5.8082e-14,  5.0384e-01, -8.3883e-06,
-         1.5515e-13, -1.2888e-01, -5.0114e-02, -3.0642e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0164,  0.0000, -0.2346, -0.0983,  2.3968,  1.0455,  0.0000,  0.0000,
-         0.3828, -0.0144, -0.0472,  0.0000, -0.1955, -0.3227,  0.0000,  0.0000,
-         0.0000,  0.0135,  0.0000,  0.0000, -0.2150,  0.2240,  0.0000, -0.8699,
-         0.0409,  0.0000, -0.1018, -0.2123,  0.5030,  0.0000, -0.8645, -1.1736,
-         0.0000, -0.1716,  0.0000,  0.0000,  0.0000,  0.0000, -0.2817,  0.7945,
-         0.3662,  0.0000, -0.1641,  0.0000,  0.0000, -0.2753,  0.0000,  0.2669,
-         0.0000,  0.0000,  0.0000, -0.0247, -0.1858,  0.3475,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5038,  0.0000,  0.0000, -0.1289, -0.0501, -0.3064],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0164,  0.0000, -0.2346, -0.0983,  2.3968,  1.0455,  0.0000,  0.0000,
-         0.3828, -0.0144, -0.0472,  0.0000, -0.1955, -0.3227,  0.0000,  0.0000,
-         0.0000,  0.0135,  0.0000,  0.0000, -0.2150,  0.2240,  0.0000, -0.8699,
-         0.0409,  0.0000, -0.1018, -0.2123,  0.5030,  0.0000, -0.8645, -1.1736,
-         0.0000, -0.1716,  0.0000,  0.0000,  0.0000,  0.0000, -0.2817,  0.7945,
-         0.3662,  0.0000, -0.1641,  0.0000,  0.0000, -0.2753,  0.0000,  0.2669,
-         0.0000,  0.0000,  0.0000, -0.0247, -0.1858,  0.3475,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5038,  0.0000,  0.0000, -0.1289, -0.0501, -0.3064],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.9391e-02, -4.5970e-05, -2.1353e-01, -9.6219e-02,  2.3897e+00,
-         1.0626e+00,  2.7827e-14, -2.1972e-07,  4.0798e-01,  4.8122e-02,
-        -2.8404e-02,  5.0025e-06, -1.9660e-01, -3.2711e-01, -4.9901e-11,
-        -3.7088e-08, -3.6482e-09,  1.3583e-01, -8.4380e-12,  1.0145e-07,
-        -1.7190e-01,  2.6453e-01, -1.7919e-11, -8.2562e-01,  5.3005e-02,
-         1.4543e-07, -6.0023e-02, -1.9627e-01,  4.7956e-01,  2.6061e-11,
-        -8.3877e-01, -1.1797e+00, -6.6711e-15, -1.3835e-01,  3.3112e-09,
-         1.7866e-07,  1.1045e-10,  0.0000e+00, -3.9511e-01,  7.8790e-01,
-         3.5450e-01, -1.4753e-08, -1.3269e-01,  2.7184e-11,  6.6489e-09,
-        -2.6405e-01, -9.3814e-07,  3.0655e-01,  3.1559e-06, -1.7306e-12,
-        -5.5993e-03,  7.5969e-03, -2.5795e-01,  4.7813e-01, -6.3417e-04,
-         3.6016e-16, -1.4155e-10,  5.0907e-14,  4.6094e-01, -7.3520e-06,
-         1.3598e-13, -5.3259e-02, -4.3792e-02, -3.4981e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0194,  0.0000, -0.2135, -0.0962,  2.3897,  1.0626,  0.0000,  0.0000,
-         0.4080,  0.0481, -0.0284,  0.0000, -0.1966, -0.3271,  0.0000,  0.0000,
-         0.0000,  0.1358,  0.0000,  0.0000, -0.1719,  0.2645,  0.0000, -0.8256,
-         0.0530,  0.0000, -0.0600, -0.1963,  0.4796,  0.0000, -0.8388, -1.1797,
-         0.0000, -0.1383,  0.0000,  0.0000,  0.0000,  0.0000, -0.3951,  0.7879,
-         0.3545,  0.0000, -0.1327,  0.0000,  0.0000, -0.2640,  0.0000,  0.3066,
-         0.0000,  0.0000,  0.0000,  0.0076, -0.2579,  0.4781,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4609,  0.0000,  0.0000, -0.0533, -0.0438, -0.3498],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0194,  0.0000, -0.2135, -0.0962,  2.3897,  1.0626,  0.0000,  0.0000,
-         0.4080,  0.0481, -0.0284,  0.0000, -0.1966, -0.3271,  0.0000,  0.0000,
-         0.0000,  0.1358,  0.0000,  0.0000, -0.1719,  0.2645,  0.0000, -0.8256,
-         0.0530,  0.0000, -0.0600, -0.1963,  0.4796,  0.0000, -0.8388, -1.1797,
-         0.0000, -0.1383,  0.0000,  0.0000,  0.0000,  0.0000, -0.3951,  0.7879,
-         0.3545,  0.0000, -0.1327,  0.0000,  0.0000, -0.2640,  0.0000,  0.3066,
-         0.0000,  0.0000,  0.0000,  0.0076, -0.2579,  0.4781,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4609,  0.0000,  0.0000, -0.0533, -0.0438, -0.3498],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.8239e-02, -4.0307e-05, -2.2231e-01, -1.4754e-01,  2.3806e+00,
-         1.0545e+00,  2.4399e-14, -1.9265e-07,  4.2140e-01,  2.1091e-02,
-         2.2928e-02,  4.3862e-06, -1.9737e-01, -3.3502e-01, -4.3754e-11,
-        -3.2519e-08, -3.1988e-09,  1.0856e-01, -7.3985e-12,  8.8955e-08,
-        -1.4906e-01,  2.6437e-01, -1.5712e-11, -7.7408e-01,  4.1049e-02,
-         1.2752e-07, -6.1224e-02, -1.9710e-01,  4.3032e-01,  2.2850e-11,
-        -8.1056e-01, -1.1840e+00, -5.8493e-15, -1.1118e-01,  2.9033e-09,
-         1.5665e-07,  9.6847e-11,  0.0000e+00, -4.5921e-01,  7.8824e-01,
-         3.1632e-01, -1.2936e-08, -1.0376e-01,  2.3835e-11,  5.8298e-09,
-        -2.4681e-01, -8.2257e-07,  3.6023e-01,  2.7671e-06, -1.5174e-12,
-        -4.9095e-03,  6.5767e-02, -2.7375e-01,  3.8492e-01, -5.5605e-04,
-         3.1579e-16, -1.2411e-10,  4.4635e-14,  4.7279e-01, -6.4463e-06,
-         1.1923e-13, -6.5143e-02,  2.6754e-02, -4.0376e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0382,  0.0000, -0.2223, -0.1475,  2.3806,  1.0545,  0.0000,  0.0000,
-         0.4214,  0.0211,  0.0229,  0.0000, -0.1974, -0.3350,  0.0000,  0.0000,
-         0.0000,  0.1086,  0.0000,  0.0000, -0.1491,  0.2644,  0.0000, -0.7741,
-         0.0410,  0.0000, -0.0612, -0.1971,  0.4303,  0.0000, -0.8106, -1.1840,
-         0.0000, -0.1112,  0.0000,  0.0000,  0.0000,  0.0000, -0.4592,  0.7882,
-         0.3163,  0.0000, -0.1038,  0.0000,  0.0000, -0.2468,  0.0000,  0.3602,
-         0.0000,  0.0000,  0.0000,  0.0658, -0.2737,  0.3849,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4728,  0.0000,  0.0000, -0.0651,  0.0268, -0.4038],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0382,  0.0000, -0.2223, -0.1475,  2.3806,  1.0545,  0.0000,  0.0000,
-         0.4214,  0.0211,  0.0229,  0.0000, -0.1974, -0.3350,  0.0000,  0.0000,
-         0.0000,  0.1086,  0.0000,  0.0000, -0.1491,  0.2644,  0.0000, -0.7741,
-         0.0410,  0.0000, -0.0612, -0.1971,  0.4303,  0.0000, -0.8106, -1.1840,
-         0.0000, -0.1112,  0.0000,  0.0000,  0.0000,  0.0000, -0.4592,  0.7882,
-         0.3163,  0.0000, -0.1038,  0.0000,  0.0000, -0.2468,  0.0000,  0.3602,
-         0.0000,  0.0000,  0.0000,  0.0658, -0.2737,  0.3849,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4728,  0.0000,  0.0000, -0.0651,  0.0268, -0.4038],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.4047e-02, -3.5356e-05, -2.2597e-01, -2.3628e-01,  2.3731e+00,
-         1.0383e+00,  2.1401e-14, -1.6899e-07,  4.3383e-01, -1.6966e-02,
-         1.0988e-01,  3.8474e-06, -1.7462e-01, -3.3765e-01, -3.8379e-11,
-        -2.8524e-08, -2.8058e-09, -1.5402e-02, -6.4897e-12,  7.8027e-08,
-        -1.5063e-01,  2.2487e-01, -1.3782e-11, -7.6497e-01,  2.8476e-03,
-         1.1185e-07, -4.4339e-02, -2.0760e-01,  3.7381e-01,  2.0043e-11,
-        -7.8712e-01, -1.1889e+00, -5.1307e-15, -1.2125e-01,  2.5466e-09,
-         1.3741e-07,  8.4950e-11,  0.0000e+00, -4.6232e-01,  7.7459e-01,
-         2.9288e-01, -1.1346e-08, -8.4920e-02,  2.0907e-11,  5.1136e-09,
-        -2.1839e-01, -7.2152e-07,  4.3129e-01,  2.4272e-06, -1.3310e-12,
-        -4.3064e-03,  8.7193e-02, -1.9976e-01,  2.0103e-01, -4.8774e-04,
-         2.7700e-16, -1.0887e-10,  3.9152e-14,  5.2590e-01, -5.6544e-06,
-         1.0458e-13, -7.6494e-02,  1.5081e-01, -4.2434e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0540,  0.0000, -0.2260, -0.2363,  2.3731,  1.0383,  0.0000,  0.0000,
-         0.4338, -0.0170,  0.1099,  0.0000, -0.1746, -0.3377,  0.0000,  0.0000,
-         0.0000, -0.0154,  0.0000,  0.0000, -0.1506,  0.2249,  0.0000, -0.7650,
-         0.0028,  0.0000, -0.0443, -0.2076,  0.3738,  0.0000, -0.7871, -1.1889,
-         0.0000, -0.1213,  0.0000,  0.0000,  0.0000,  0.0000, -0.4623,  0.7746,
-         0.2929,  0.0000, -0.0849,  0.0000,  0.0000, -0.2184,  0.0000,  0.4313,
-         0.0000,  0.0000,  0.0000,  0.0872, -0.1998,  0.2010,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5259,  0.0000,  0.0000, -0.0765,  0.1508, -0.4243],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0540,  0.0000, -0.2260, -0.2363,  2.3731,  1.0383,  0.0000,  0.0000,
-         0.4338, -0.0170,  0.1099,  0.0000, -0.1746, -0.3377,  0.0000,  0.0000,
-         0.0000, -0.0154,  0.0000,  0.0000, -0.1506,  0.2249,  0.0000, -0.7650,
-         0.0028,  0.0000, -0.0443, -0.2076,  0.3738,  0.0000, -0.7871, -1.1889,
-         0.0000, -0.1213,  0.0000,  0.0000,  0.0000,  0.0000, -0.4623,  0.7746,
-         0.2929,  0.0000, -0.0849,  0.0000,  0.0000, -0.2184,  0.0000,  0.4313,
-         0.0000,  0.0000,  0.0000,  0.0872, -0.1998,  0.2010,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5259,  0.0000,  0.0000, -0.0765,  0.1508, -0.4243],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 7.0772e-02, -3.1025e-05, -2.3516e-01, -3.0062e-01,  2.3682e+00,
-         1.0211e+00,  1.8780e-14, -1.4829e-07,  4.5154e-01,  1.8601e-02,
-         2.7509e-01,  3.3761e-06, -1.8352e-01, -3.6824e-01, -3.3678e-11,
-        -2.5030e-08, -2.4621e-09, -1.0720e-01, -5.6947e-12,  6.8469e-08,
-        -1.8629e-01,  2.1725e-01, -1.2093e-11, -7.6849e-01,  8.1073e-03,
-         9.8152e-08,  3.6126e-03, -2.1353e-01,  3.2327e-01,  1.7588e-11,
-        -7.6642e-01, -1.2001e+00, -4.5022e-15, -1.2712e-01,  2.2347e-09,
-         1.2058e-07,  7.4544e-11,  0.0000e+00, -4.2872e-01,  7.4901e-01,
-         2.8230e-01, -9.9566e-09, -1.0769e-01,  1.8346e-11,  4.4872e-09,
-        -2.0814e-01, -6.3314e-07,  4.8023e-01,  2.1299e-06, -1.1679e-12,
-        -3.7789e-03,  9.4982e-02, -1.3134e-01,  6.3441e-02, -4.2800e-04,
-         2.4307e-16, -9.5530e-11,  3.4356e-14,  5.6405e-01, -4.9618e-06,
-         9.1773e-14, -6.9067e-02,  2.2792e-01, -4.2764e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0708,  0.0000, -0.2352, -0.3006,  2.3682,  1.0211,  0.0000,  0.0000,
-         0.4515,  0.0186,  0.2751,  0.0000, -0.1835, -0.3682,  0.0000,  0.0000,
-         0.0000, -0.1072,  0.0000,  0.0000, -0.1863,  0.2172,  0.0000, -0.7685,
-         0.0081,  0.0000,  0.0036, -0.2135,  0.3233,  0.0000, -0.7664, -1.2001,
-         0.0000, -0.1271,  0.0000,  0.0000,  0.0000,  0.0000, -0.4287,  0.7490,
-         0.2823,  0.0000, -0.1077,  0.0000,  0.0000, -0.2081,  0.0000,  0.4802,
-         0.0000,  0.0000,  0.0000,  0.0950, -0.1313,  0.0634,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5641,  0.0000,  0.0000, -0.0691,  0.0000, -0.4276],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0708,  0.0000, -0.2352, -0.3006,  2.3682,  1.0211,  0.0000,  0.0000,
-         0.4515,  0.0186,  0.2751,  0.0000, -0.1835, -0.3682,  0.0000,  0.0000,
-         0.0000, -0.1072,  0.0000,  0.0000, -0.1863,  0.2172,  0.0000, -0.7685,
-         0.0081,  0.0000,  0.0036, -0.2135,  0.3233,  0.0000, -0.7664, -1.2001,
-         0.0000, -0.1271,  0.0000,  0.0000,  0.0000,  0.0000, -0.4287,  0.7490,
-         0.2823,  0.0000, -0.1077,  0.0000,  0.0000, -0.2081,  0.0000,  0.4802,
-         0.0000,  0.0000,  0.0000,  0.0950, -0.1313,  0.0634,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5641,  0.0000,  0.0000, -0.0691,  0.0000, -0.4276],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1413e-01, -2.7235e-05, -2.4560e-01, -3.1330e-01,  2.3641e+00,
-         1.0200e+00,  1.6486e-14, -1.3018e-07,  4.8288e-01,  4.9718e-02,
-         3.3757e-01,  2.9637e-06, -1.6411e-01, -4.1985e-01, -2.9564e-11,
-        -2.1973e-08, -2.1614e-09, -1.3354e-01, -4.9992e-12,  6.0107e-08,
-        -1.8563e-01,  2.1448e-01, -1.0616e-11, -7.3568e-01,  2.5259e-02,
-         8.6163e-08,  5.7455e-02, -1.9026e-01,  2.9704e-01,  1.5440e-11,
-        -7.7001e-01, -1.2091e+00, -3.9523e-15, -1.2450e-01,  1.9617e-09,
-         1.0585e-07,  6.5439e-11,  0.0000e+00, -3.8022e-01,  7.1287e-01,
-         2.6441e-01, -8.7405e-09, -1.3062e-01,  1.6105e-11,  3.9392e-09,
-        -2.1744e-01, -5.5581e-07,  5.1216e-01,  1.8697e-06, -1.0253e-12,
-        -3.3174e-03,  1.7549e-01, -9.3439e-02, -2.9112e-02, -3.7572e-04,
-         2.1338e-16, -8.3862e-11,  3.0160e-14,  6.0475e-01, -4.3557e-06,
-         8.0564e-14,  1.7507e-03,  6.7689e-02, -4.2621e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.1413e-01,  0.0000e+00, -2.4560e-01, -3.1330e-01,  2.3641e+00,
-         1.0200e+00,  0.0000e+00,  0.0000e+00,  4.8288e-01,  4.9718e-02,
-         3.3757e-01,  0.0000e+00, -1.6411e-01, -4.1985e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.3354e-01,  0.0000e+00,  0.0000e+00,
-        -1.8563e-01,  2.1448e-01,  0.0000e+00, -7.3568e-01,  2.5259e-02,
-         0.0000e+00,  5.7455e-02, -1.9026e-01,  2.9704e-01,  0.0000e+00,
-        -7.7001e-01, -1.2091e+00,  0.0000e+00, -1.2450e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -3.8022e-01,  7.1287e-01,
-         2.6441e-01,  0.0000e+00, -1.3062e-01,  0.0000e+00,  0.0000e+00,
-        -2.1744e-01,  0.0000e+00,  5.1216e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.7549e-01, -9.3439e-02, -2.9112e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0475e-01,  0.0000e+00,
-         0.0000e+00,  1.7507e-03,  0.0000e+00, -4.2621e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.1413e-01,  0.0000e+00, -2.4560e-01, -3.1330e-01,  2.3641e+00,
-         1.0200e+00,  0.0000e+00,  0.0000e+00,  4.8288e-01,  4.9718e-02,
-         3.3757e-01,  0.0000e+00, -1.6411e-01, -4.1985e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.3354e-01,  0.0000e+00,  0.0000e+00,
-        -1.8563e-01,  2.1448e-01,  0.0000e+00, -7.3568e-01,  2.5259e-02,
-         0.0000e+00,  5.7455e-02, -1.9026e-01,  2.9704e-01,  0.0000e+00,
-        -7.7001e-01, -1.2091e+00,  0.0000e+00, -1.2450e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -3.8022e-01,  7.1287e-01,
-         2.6441e-01,  0.0000e+00, -1.3062e-01,  0.0000e+00,  0.0000e+00,
-        -2.1744e-01,  0.0000e+00,  5.1216e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.7549e-01, -9.3439e-02, -2.9112e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0475e-01,  0.0000e+00,
-         0.0000e+00,  1.7507e-03,  0.0000e+00, -4.2621e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.8272e-01, -2.3918e-05, -2.3074e-01, -2.8774e-01,  2.3592e+00,
-         1.0047e+00,  1.4478e-14, -1.1432e-07,  5.3675e-01,  2.0737e-02,
-         3.1615e-01,  2.6028e-06, -8.9359e-02, -4.3104e-01, -2.5964e-11,
-        -1.9297e-08, -1.8982e-09, -1.0937e-01, -4.3903e-12,  5.2786e-08,
-        -1.1761e-01,  2.2608e-01, -9.3234e-12, -7.0709e-01,  4.4547e-02,
-         7.5670e-08,  9.0018e-02, -1.2772e-01,  3.0415e-01,  1.3560e-11,
-        -7.6986e-01, -1.2192e+00, -3.4710e-15, -1.0290e-01,  1.7228e-09,
-         9.2957e-08,  5.7470e-11,  0.0000e+00, -3.2945e-01,  6.9400e-01,
-         2.3770e-01, -7.6760e-09, -1.2133e-01,  1.4144e-11,  3.4594e-09,
-        -1.9025e-01, -4.8812e-07,  5.1451e-01,  1.6420e-06, -9.0042e-13,
-        -2.9134e-03,  3.2089e-01, -1.0767e-01, -9.7170e-02, -3.2996e-04,
-         1.8739e-16, -7.3649e-11,  2.6487e-14,  6.4373e-01, -3.8253e-06,
-         7.0752e-14,  1.1892e-01,  5.9446e-02, -4.2944e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1827,  0.0000, -0.2307, -0.2877,  2.3592,  1.0047,  0.0000,  0.0000,
-         0.5367,  0.0207,  0.3161,  0.0000, -0.0894, -0.4310,  0.0000,  0.0000,
-         0.0000, -0.1094,  0.0000,  0.0000, -0.1176,  0.2261,  0.0000, -0.7071,
-         0.0445,  0.0000,  0.0900, -0.1277,  0.3041,  0.0000, -0.7699, -1.2192,
-         0.0000, -0.1029,  0.0000,  0.0000,  0.0000,  0.0000, -0.3295,  0.6940,
-         0.2377,  0.0000, -0.1213,  0.0000,  0.0000, -0.1902,  0.0000,  0.5145,
-         0.0000,  0.0000,  0.0000,  0.3209, -0.1077, -0.0972,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6437,  0.0000,  0.0000,  0.1189,  0.0000, -0.4294],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1827,  0.0000, -0.2307, -0.2877,  2.3592,  1.0047,  0.0000,  0.0000,
-         0.5367,  0.0207,  0.3161,  0.0000, -0.0894, -0.4310,  0.0000,  0.0000,
-         0.0000, -0.1094,  0.0000,  0.0000, -0.1176,  0.2261,  0.0000, -0.7071,
-         0.0445,  0.0000,  0.0900, -0.1277,  0.3041,  0.0000, -0.7699, -1.2192,
-         0.0000, -0.1029,  0.0000,  0.0000,  0.0000,  0.0000, -0.3295,  0.6940,
-         0.2377,  0.0000, -0.1213,  0.0000,  0.0000, -0.1902,  0.0000,  0.5145,
-         0.0000,  0.0000,  0.0000,  0.3209, -0.1077, -0.0972,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6437,  0.0000,  0.0000,  0.1189,  0.0000, -0.4294],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.5192e-01, -2.1014e-05, -2.0999e-01, -2.1881e-01,  2.3565e+00,
-         1.0079e+00,  1.2720e-14, -1.0044e-07,  5.8410e-01, -2.8010e-03,
-         2.5244e-01,  2.2867e-06, -1.4856e-02, -4.3181e-01, -2.2811e-11,
-        -1.6954e-08, -1.6677e-09, -2.4564e-02, -3.8572e-12,  4.6376e-08,
-        -6.0285e-02,  2.2609e-01, -8.1912e-12, -7.4030e-01,  3.8425e-02,
-         6.6481e-08,  1.4288e-01, -4.3570e-02,  3.1212e-01,  1.1913e-11,
-        -7.5299e-01, -1.2319e+00, -3.0495e-15, -1.1795e-01,  1.5136e-09,
-         8.1669e-08,  5.0491e-11,  0.0000e+00, -2.6142e-01,  6.7584e-01,
-         1.8574e-01, -6.7439e-09, -1.1088e-01,  1.2426e-11,  3.0393e-09,
-        -1.5506e-01, -4.2884e-07,  5.3022e-01,  1.4426e-06, -7.9108e-13,
-        -2.5596e-03,  4.0743e-01, -1.1330e-01, -9.4854e-02, -2.8989e-04,
-         1.6464e-16, -6.4705e-11,  2.3270e-14,  6.7547e-01, -3.3607e-06,
-         6.2161e-14,  2.4534e-01,  5.2227e-02, -4.0746e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2519,  0.0000, -0.2100, -0.2188,  2.3565,  1.0079,  0.0000,  0.0000,
-         0.5841, -0.0028,  0.2524,  0.0000, -0.0149, -0.4318,  0.0000,  0.0000,
-         0.0000, -0.0246,  0.0000,  0.0000, -0.0603,  0.2261,  0.0000, -0.7403,
-         0.0384,  0.0000,  0.1429, -0.0436,  0.3121,  0.0000, -0.7530, -1.2319,
-         0.0000, -0.1179,  0.0000,  0.0000,  0.0000,  0.0000, -0.2614,  0.6758,
-         0.1857,  0.0000, -0.1109,  0.0000,  0.0000, -0.1551,  0.0000,  0.5302,
-         0.0000,  0.0000,  0.0000,  0.4074, -0.1133, -0.0949,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6755,  0.0000,  0.0000,  0.2453,  0.0000, -0.4075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2519,  0.0000, -0.2100, -0.2188,  2.3565,  1.0079,  0.0000,  0.0000,
-         0.5841, -0.0028,  0.2524,  0.0000, -0.0149, -0.4318,  0.0000,  0.0000,
-         0.0000, -0.0246,  0.0000,  0.0000, -0.0603,  0.2261,  0.0000, -0.7403,
-         0.0384,  0.0000,  0.1429, -0.0436,  0.3121,  0.0000, -0.7530, -1.2319,
-         0.0000, -0.1179,  0.0000,  0.0000,  0.0000,  0.0000, -0.2614,  0.6758,
-         0.1857,  0.0000, -0.1109,  0.0000,  0.0000, -0.1551,  0.0000,  0.5302,
-         0.0000,  0.0000,  0.0000,  0.4074, -0.1133, -0.0949,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6755,  0.0000,  0.0000,  0.2453,  0.0000, -0.4075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0708e-01, -1.8470e-05, -1.9715e-01, -1.6019e-01,  2.3552e+00,
-         1.0071e+00,  1.1180e-14, -8.8278e-08,  6.0517e-01,  1.5929e-02,
-         2.2271e-01,  2.0098e-06,  2.5622e-02, -4.4087e-01, -2.0049e-11,
-        -1.4901e-08, -1.4657e-09,  7.4419e-02, -3.3902e-12,  4.0761e-08,
-        -3.8309e-02,  2.4293e-01, -7.1994e-12, -7.6226e-01,  4.4067e-02,
-         5.8431e-08,  1.7541e-01,  1.9587e-02,  3.0306e-01,  1.0471e-11,
-        -7.1693e-01, -1.2452e+00, -2.6803e-15, -1.2406e-01,  1.3303e-09,
-         7.1781e-08,  4.4377e-11,  0.0000e+00, -2.1211e-01,  6.7650e-01,
-         1.6891e-01, -5.9273e-09, -1.0333e-01,  1.0922e-11,  2.6713e-09,
-        -1.3797e-01, -3.7692e-07,  5.3680e-01,  1.2679e-06, -6.9530e-13,
-        -2.2497e-03,  4.3727e-01, -1.1832e-01, -5.9126e-02, -2.5479e-04,
-         1.4470e-16, -5.6871e-11,  2.0453e-14,  6.8783e-01, -2.9538e-06,
-         5.4634e-14,  3.2991e-01,  4.5903e-02, -3.6408e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3071,  0.0000, -0.1972, -0.1602,  2.3552,  1.0071,  0.0000,  0.0000,
-         0.6052,  0.0159,  0.2227,  0.0000,  0.0256, -0.4409,  0.0000,  0.0000,
-         0.0000,  0.0744,  0.0000,  0.0000, -0.0383,  0.2429,  0.0000, -0.7623,
-         0.0441,  0.0000,  0.1754,  0.0196,  0.3031,  0.0000, -0.7169, -1.2452,
-         0.0000, -0.1241,  0.0000,  0.0000,  0.0000,  0.0000, -0.2121,  0.6765,
-         0.1689,  0.0000, -0.1033,  0.0000,  0.0000, -0.1380,  0.0000,  0.5368,
-         0.0000,  0.0000,  0.0000,  0.4373, -0.1183, -0.0591,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6878,  0.0000,  0.0000,  0.3299,  0.0000, -0.3641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3071,  0.0000, -0.1972, -0.1602,  2.3552,  1.0071,  0.0000,  0.0000,
-         0.6052,  0.0159,  0.2227,  0.0000,  0.0256, -0.4409,  0.0000,  0.0000,
-         0.0000,  0.0744,  0.0000,  0.0000, -0.0383,  0.2429,  0.0000, -0.7623,
-         0.0441,  0.0000,  0.1754,  0.0196,  0.3031,  0.0000, -0.7169, -1.2452,
-         0.0000, -0.1241,  0.0000,  0.0000,  0.0000,  0.0000, -0.2121,  0.6765,
-         0.1689,  0.0000, -0.1033,  0.0000,  0.0000, -0.1380,  0.0000,  0.5368,
-         0.0000,  0.0000,  0.0000,  0.4373, -0.1183, -0.0591,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6878,  0.0000,  0.0000,  0.3299,  0.0000, -0.3641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7227e-01, -1.6240e-05, -1.7951e-01, -1.0737e-01,  2.3520e+00,
-         1.0009e+00,  9.8302e-15, -7.7621e-08,  5.9401e-01,  8.7087e-02,
-         2.5269e-01,  1.7672e-06,  1.7841e-02, -4.5200e-01, -1.7628e-11,
-        -1.3102e-08, -1.2888e-09,  1.4896e-01, -2.9809e-12,  3.5840e-08,
-        -5.7395e-02,  2.7750e-01, -6.3303e-12, -7.8236e-01,  6.1801e-02,
-         5.1377e-08,  2.0872e-01,  5.3016e-02,  2.6825e-01,  9.2065e-12,
-        -6.8231e-01, -1.2583e+00, -2.3567e-15, -1.4287e-01,  1.1697e-09,
-         6.3115e-08,  3.9020e-11,  0.0000e+00, -1.3304e-01,  6.7933e-01,
-         1.6668e-01, -5.2118e-09, -9.0407e-02,  9.6031e-12,  2.3488e-09,
-        -1.2710e-01, -3.3141e-07,  5.2418e-01,  1.1149e-06, -6.1136e-13,
-        -1.9781e-03,  4.2775e-01, -9.6875e-02,  1.1704e-02, -2.2403e-04,
-         1.2723e-16, -5.0005e-11,  1.7984e-14,  6.9138e-01, -2.5972e-06,
-         4.8038e-14,  3.8778e-01,  4.0362e-02, -3.0653e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3723,  0.0000, -0.1795, -0.1074,  2.3520,  1.0009,  0.0000,  0.0000,
-         0.5940,  0.0871,  0.2527,  0.0000,  0.0178, -0.4520,  0.0000,  0.0000,
-         0.0000,  0.1490,  0.0000,  0.0000, -0.0574,  0.2775,  0.0000, -0.7824,
-         0.0618,  0.0000,  0.2087,  0.0530,  0.2682,  0.0000, -0.6823, -1.2583,
-         0.0000, -0.1429,  0.0000,  0.0000,  0.0000,  0.0000, -0.1330,  0.6793,
-         0.1667,  0.0000, -0.0904,  0.0000,  0.0000, -0.1271,  0.0000,  0.5242,
-         0.0000,  0.0000,  0.0000,  0.4278, -0.0969,  0.0117,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6914,  0.0000,  0.0000,  0.3878,  0.0000, -0.3065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3723,  0.0000, -0.1795, -0.1074,  2.3520,  1.0009,  0.0000,  0.0000,
-         0.5940,  0.0871,  0.2527,  0.0000,  0.0178, -0.4520,  0.0000,  0.0000,
-         0.0000,  0.1490,  0.0000,  0.0000, -0.0574,  0.2775,  0.0000, -0.7824,
-         0.0618,  0.0000,  0.2087,  0.0530,  0.2682,  0.0000, -0.6823, -1.2583,
-         0.0000, -0.1429,  0.0000,  0.0000,  0.0000,  0.0000, -0.1330,  0.6793,
-         0.1667,  0.0000, -0.0904,  0.0000,  0.0000, -0.1271,  0.0000,  0.5242,
-         0.0000,  0.0000,  0.0000,  0.4278, -0.0969,  0.0117,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6914,  0.0000,  0.0000,  0.3878,  0.0000, -0.3065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1533e-01, -1.4285e-05, -1.7200e-01, -1.0260e-01,  2.3492e+00,
-         9.8657e-01,  8.6470e-15, -6.8278e-08,  5.4818e-01,  1.5997e-01,
-         3.3378e-01,  1.5545e-06, -7.0387e-02, -4.7939e-01, -1.5506e-11,
-        -1.1525e-08, -1.1337e-09,  1.3386e-01, -2.6221e-12,  3.1526e-08,
-        -1.2474e-01,  2.8858e-01, -5.5683e-12, -7.8400e-01,  6.0521e-02,
-         4.5193e-08,  2.2652e-01,  1.2839e-02,  1.9591e-01,  8.0983e-12,
-        -6.3596e-01, -1.2680e+00, -2.0730e-15, -1.4949e-01,  1.0289e-09,
-         5.5518e-08,  3.4323e-11,  0.0000e+00, -7.3279e-02,  6.7040e-01,
-         1.7007e-01, -4.5844e-09, -7.4071e-02,  8.4472e-12,  2.0661e-09,
-        -1.4813e-01, -2.9152e-07,  4.9151e-01,  9.8068e-07, -5.3777e-13,
-        -1.7400e-03,  3.6991e-01, -6.5993e-02, -8.8378e-03, -1.9707e-04,
-         1.1192e-16, -4.3986e-11,  1.5819e-14,  6.9313e-01, -2.2846e-06,
-         4.2256e-14,  4.0053e-01,  3.5503e-02, -2.5000e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4153,  0.0000, -0.1720, -0.1026,  2.3492,  0.9866,  0.0000,  0.0000,
-         0.5482,  0.1600,  0.3338,  0.0000, -0.0704, -0.4794,  0.0000,  0.0000,
-         0.0000,  0.1339,  0.0000,  0.0000, -0.1247,  0.2886,  0.0000, -0.7840,
-         0.0605,  0.0000,  0.2265,  0.0128,  0.1959,  0.0000, -0.6360, -1.2680,
-         0.0000, -0.1495,  0.0000,  0.0000,  0.0000,  0.0000, -0.0733,  0.6704,
-         0.1701,  0.0000, -0.0741,  0.0000,  0.0000, -0.1481,  0.0000,  0.4915,
-         0.0000,  0.0000,  0.0000,  0.3699, -0.0660, -0.0088,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6931,  0.0000,  0.0000,  0.4005,  0.0000, -0.2500],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4153,  0.0000, -0.1720, -0.1026,  2.3492,  0.9866,  0.0000,  0.0000,
-         0.5482,  0.1600,  0.3338,  0.0000, -0.0704, -0.4794,  0.0000,  0.0000,
-         0.0000,  0.1339,  0.0000,  0.0000, -0.1247,  0.2886,  0.0000, -0.7840,
-         0.0605,  0.0000,  0.2265,  0.0128,  0.1959,  0.0000, -0.6360, -1.2680,
-         0.0000, -0.1495,  0.0000,  0.0000,  0.0000,  0.0000, -0.0733,  0.6704,
-         0.1701,  0.0000, -0.0741,  0.0000,  0.0000, -0.1481,  0.0000,  0.4915,
-         0.0000,  0.0000,  0.0000,  0.3699, -0.0660, -0.0088,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6931,  0.0000,  0.0000,  0.4005,  0.0000, -0.2500],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6677e-01, -1.2571e-05, -1.5594e-01, -1.0024e-01,  2.3504e+00,
-         9.7484e-01,  7.6092e-15, -6.0084e-08,  4.9806e-01,  1.6105e-01,
-         2.4589e-01,  1.3679e-06, -1.5337e-01, -4.7843e-01, -1.3646e-11,
-        -1.0142e-08, -9.9761e-10,  1.3810e-01, -2.3074e-12,  2.7742e-08,
-        -1.0715e-01,  2.5661e-01, -4.9000e-12, -7.8881e-01,  5.4050e-02,
-         3.9769e-08,  2.0259e-01, -2.7546e-02,  1.4280e-01,  7.1264e-12,
-        -5.8620e-01, -1.2694e+00, -1.8242e-15, -1.3681e-01,  9.0545e-10,
-         4.8855e-08,  3.0204e-11,  0.0000e+00, -1.5987e-02,  7.0753e-01,
-         1.6494e-01, -4.0342e-09, -5.1281e-02,  7.4334e-12,  1.8181e-09,
-        -1.7663e-01, -2.5653e-07,  4.9033e-01,  8.6298e-07, -4.7323e-13,
-        -1.5311e-03,  3.1824e-01, -1.3724e-01, -2.9192e-02, -1.7342e-04,
-         9.8487e-17, -3.8707e-11,  1.3920e-14,  6.9884e-01, -2.0104e-06,
-         3.7185e-14,  3.8833e-01,  3.1242e-02, -2.5436e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4668,  0.0000, -0.1559, -0.1002,  2.3504,  0.9748,  0.0000,  0.0000,
-         0.4981,  0.1610,  0.2459,  0.0000, -0.1534, -0.4784,  0.0000,  0.0000,
-         0.0000,  0.1381,  0.0000,  0.0000, -0.1071,  0.2566,  0.0000, -0.7888,
-         0.0541,  0.0000,  0.2026, -0.0275,  0.1428,  0.0000, -0.5862, -1.2694,
-         0.0000, -0.1368,  0.0000,  0.0000,  0.0000,  0.0000, -0.0160,  0.7075,
-         0.1649,  0.0000, -0.0513,  0.0000,  0.0000, -0.1766,  0.0000,  0.4903,
-         0.0000,  0.0000,  0.0000,  0.3182, -0.1372, -0.0292,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6988,  0.0000,  0.0000,  0.3883,  0.0000, -0.2544],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4668,  0.0000, -0.1559, -0.1002,  2.3504,  0.9748,  0.0000,  0.0000,
-         0.4981,  0.1610,  0.2459,  0.0000, -0.1534, -0.4784,  0.0000,  0.0000,
-         0.0000,  0.1381,  0.0000,  0.0000, -0.1071,  0.2566,  0.0000, -0.7888,
-         0.0541,  0.0000,  0.2026, -0.0275,  0.1428,  0.0000, -0.5862, -1.2694,
-         0.0000, -0.1368,  0.0000,  0.0000,  0.0000,  0.0000, -0.0160,  0.7075,
-         0.1649,  0.0000, -0.0513,  0.0000,  0.0000, -0.1766,  0.0000,  0.4903,
-         0.0000,  0.0000,  0.0000,  0.3182, -0.1372, -0.0292,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6988,  0.0000,  0.0000,  0.3883,  0.0000, -0.2544],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0352e-01, -1.1066e-05, -1.3806e-01, -1.2082e-01,  2.3510e+00,
-         9.6690e-01,  6.6987e-15, -5.2894e-08,  4.1665e-01,  1.8644e-01,
-         2.5002e-01,  1.2042e-06, -2.2726e-01, -4.8709e-01, -1.2013e-11,
-        -8.9283e-09, -8.7824e-10,  9.8670e-02, -2.0313e-12,  2.4423e-08,
-        -1.2151e-01,  2.0525e-01, -4.3137e-12, -7.7584e-01,  2.8900e-02,
-         3.5011e-08,  1.9878e-01, -9.7896e-02,  1.0834e-01,  6.2737e-12,
-        -5.5776e-01, -1.2660e+00, -1.6059e-15, -1.3382e-01,  7.9711e-10,
-         4.3009e-08,  2.6590e-11,  0.0000e+00,  3.2679e-02,  7.2355e-01,
-         1.5632e-01, -3.5515e-09, -2.1112e-02,  6.5440e-12,  1.6006e-09,
-        -2.0694e-01, -2.2584e-07,  4.7992e-01,  7.5972e-07, -4.1660e-13,
-        -1.3479e-03,  2.1605e-01, -1.7241e-01, -9.3092e-02, -1.5267e-04,
-         8.6702e-17, -3.4075e-11,  1.2255e-14,  7.0635e-01, -1.7699e-06,
-         3.2735e-14,  3.5260e-01,  2.7504e-02, -2.3832e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5035,  0.0000, -0.1381, -0.1208,  2.3510,  0.9669,  0.0000,  0.0000,
-         0.4167,  0.1864,  0.2500,  0.0000, -0.2273, -0.4871,  0.0000,  0.0000,
-         0.0000,  0.0987,  0.0000,  0.0000, -0.1215,  0.2052,  0.0000, -0.7758,
-         0.0289,  0.0000,  0.1988, -0.0979,  0.1083,  0.0000, -0.5578, -1.2660,
-         0.0000, -0.1338,  0.0000,  0.0000,  0.0000,  0.0000,  0.0327,  0.7235,
-         0.1563,  0.0000, -0.0211,  0.0000,  0.0000, -0.2069,  0.0000,  0.4799,
-         0.0000,  0.0000,  0.0000,  0.2161, -0.1724, -0.0931,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7063,  0.0000,  0.0000,  0.3526,  0.0000, -0.2383],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5035,  0.0000, -0.1381, -0.1208,  2.3510,  0.9669,  0.0000,  0.0000,
-         0.4167,  0.1864,  0.2500,  0.0000, -0.2273, -0.4871,  0.0000,  0.0000,
-         0.0000,  0.0987,  0.0000,  0.0000, -0.1215,  0.2052,  0.0000, -0.7758,
-         0.0289,  0.0000,  0.1988, -0.0979,  0.1083,  0.0000, -0.5578, -1.2660,
-         0.0000, -0.1338,  0.0000,  0.0000,  0.0000,  0.0000,  0.0327,  0.7235,
-         0.1563,  0.0000, -0.0211,  0.0000,  0.0000, -0.2069,  0.0000,  0.4799,
-         0.0000,  0.0000,  0.0000,  0.2161, -0.1724, -0.0931,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7063,  0.0000,  0.0000,  0.3526,  0.0000, -0.2383],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1706e-01, -9.7463e-06, -1.2696e-01, -1.4366e-01,  2.3517e+00,
-         9.5595e-01,  5.8996e-15, -4.6584e-08,  3.4555e-01,  2.1688e-01,
-         1.9804e-01,  1.0606e-06, -2.8907e-01, -4.8232e-01, -1.0580e-11,
-        -7.8632e-09, -7.7346e-10,  5.6382e-02, -1.7890e-12,  2.1509e-08,
-        -1.1559e-01,  1.5015e-01, -3.7991e-12, -7.5118e-01,  1.3157e-02,
-         3.0834e-08,  2.0103e-01, -1.5387e-01,  9.3221e-02,  5.5253e-12,
-        -5.3771e-01, -1.2635e+00, -1.4144e-15, -1.2594e-01,  7.0202e-10,
-         3.7878e-08,  2.3418e-11,  0.0000e+00,  1.3018e-01,  7.2430e-01,
-         1.4567e-01, -3.1278e-09,  8.7943e-03,  5.7633e-12,  1.4096e-09,
-        -2.2328e-01, -1.9890e-07,  4.9053e-01,  6.6909e-07, -3.6690e-13,
-        -1.1871e-03,  1.1435e-01, -2.0203e-01, -1.2635e-01, -1.3445e-04,
-         7.6359e-17, -3.0010e-11,  1.0793e-14,  6.9683e-01, -1.5587e-06,
-         2.8830e-14,  3.3215e-01,  2.4223e-02, -2.2867e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5171,  0.0000, -0.1270, -0.1437,  2.3517,  0.9559,  0.0000,  0.0000,
-         0.3455,  0.2169,  0.1980,  0.0000, -0.2891, -0.4823,  0.0000,  0.0000,
-         0.0000,  0.0564,  0.0000,  0.0000, -0.1156,  0.1502,  0.0000, -0.7512,
-         0.0132,  0.0000,  0.2010, -0.1539,  0.0932,  0.0000, -0.5377, -1.2635,
-         0.0000, -0.1259,  0.0000,  0.0000,  0.0000,  0.0000,  0.1302,  0.7243,
-         0.1457,  0.0000,  0.0088,  0.0000,  0.0000, -0.2233,  0.0000,  0.4905,
-         0.0000,  0.0000,  0.0000,  0.1143, -0.2020, -0.1264,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6968,  0.0000,  0.0000,  0.3322,  0.0000, -0.2287],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5171,  0.0000, -0.1270, -0.1437,  2.3517,  0.9559,  0.0000,  0.0000,
-         0.3455,  0.2169,  0.1980,  0.0000, -0.2891, -0.4823,  0.0000,  0.0000,
-         0.0000,  0.0564,  0.0000,  0.0000, -0.1156,  0.1502,  0.0000, -0.7512,
-         0.0132,  0.0000,  0.2010, -0.1539,  0.0932,  0.0000, -0.5377, -1.2635,
-         0.0000, -0.1259,  0.0000,  0.0000,  0.0000,  0.0000,  0.1302,  0.7243,
-         0.1457,  0.0000,  0.0088,  0.0000,  0.0000, -0.2233,  0.0000,  0.4905,
-         0.0000,  0.0000,  0.0000,  0.1143, -0.2020, -0.1264,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6968,  0.0000,  0.0000,  0.3322,  0.0000, -0.2287],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2252e-01, -8.5871e-06, -9.9585e-02, -1.2198e-01,  2.3513e+00,
-         9.4473e-01,  5.1979e-15, -4.1043e-08,  2.9694e-01,  2.5381e-01,
-         1.0537e-01,  9.3444e-07, -3.2595e-01, -4.5994e-01, -9.3213e-12,
-        -6.9279e-09, -6.8147e-10,  5.1780e-02, -1.5762e-12,  1.8951e-08,
-        -5.6011e-02,  1.3185e-01, -3.3472e-12, -7.4366e-01,  2.8747e-02,
-         2.7166e-08,  2.1261e-01, -1.4032e-01,  8.2189e-02,  4.8681e-12,
-        -5.4267e-01, -1.2573e+00, -1.2461e-15, -8.5196e-02,  6.1852e-10,
-         3.3373e-08,  2.0632e-11,  0.0000e+00,  2.4013e-01,  6.9613e-01,
-         1.3280e-01, -2.7558e-09,  3.7309e-02,  5.0778e-12,  1.2420e-09,
-        -2.5367e-01, -1.7524e-07,  4.9054e-01,  5.8951e-07, -3.2326e-13,
-        -1.0459e-03,  1.1721e-01, -2.5591e-01, -7.8590e-02, -1.1846e-04,
-         6.7277e-17, -2.6441e-11,  9.5091e-15,  7.0517e-01, -1.3733e-06,
-         2.5401e-14,  3.3370e-01,  2.1342e-02, -2.1430e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5225,  0.0000, -0.0996, -0.1220,  2.3513,  0.9447,  0.0000,  0.0000,
-         0.2969,  0.2538,  0.1054,  0.0000, -0.3259, -0.4599,  0.0000,  0.0000,
-         0.0000,  0.0518,  0.0000,  0.0000, -0.0560,  0.1319,  0.0000, -0.7437,
-         0.0287,  0.0000,  0.2126, -0.1403,  0.0822,  0.0000, -0.5427, -1.2573,
-         0.0000, -0.0852,  0.0000,  0.0000,  0.0000,  0.0000,  0.2401,  0.6961,
-         0.1328,  0.0000,  0.0373,  0.0000,  0.0000, -0.2537,  0.0000,  0.4905,
-         0.0000,  0.0000,  0.0000,  0.1172, -0.2559, -0.0786,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7052,  0.0000,  0.0000,  0.3337,  0.0000, -0.2143],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5225,  0.0000, -0.0996, -0.1220,  2.3513,  0.9447,  0.0000,  0.0000,
-         0.2969,  0.2538,  0.1054,  0.0000, -0.3259, -0.4599,  0.0000,  0.0000,
-         0.0000,  0.0518,  0.0000,  0.0000, -0.0560,  0.1319,  0.0000, -0.7437,
-         0.0287,  0.0000,  0.2126, -0.1403,  0.0822,  0.0000, -0.5427, -1.2573,
-         0.0000, -0.0852,  0.0000,  0.0000,  0.0000,  0.0000,  0.2401,  0.6961,
-         0.1328,  0.0000,  0.0373,  0.0000,  0.0000, -0.2537,  0.0000,  0.4905,
-         0.0000,  0.0000,  0.0000,  0.1172, -0.2559, -0.0786,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7052,  0.0000,  0.0000,  0.3337,  0.0000, -0.2143],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1172e-01, -7.5688e-06, -6.6259e-02, -8.1769e-02,  2.3482e+00,
-         9.2472e-01,  4.5815e-15, -3.6176e-08,  2.6120e-01,  2.6335e-01,
-         1.2724e-02,  8.2363e-07, -3.3312e-01, -4.4191e-01, -8.2160e-12,
-        -6.1064e-09, -6.0066e-10,  6.2712e-02, -1.3893e-12,  1.6704e-08,
-         6.1756e-02,  1.3277e-01, -2.9503e-12, -7.1028e-01,  4.9281e-02,
-         2.3945e-08,  1.9987e-01, -6.0068e-02,  1.0444e-01,  4.2908e-12,
-        -5.6810e-01, -1.2497e+00, -1.0984e-15, -5.9465e-02,  5.4518e-10,
-         2.9416e-08,  1.8186e-11,  0.0000e+00,  2.8695e-01,  6.5724e-01,
-         1.3701e-01, -2.4290e-09,  8.2626e-02,  4.4757e-12,  1.0947e-09,
-        -2.9378e-01, -1.5446e-07,  4.6593e-01,  5.1960e-07, -2.8493e-13,
-        -9.2191e-04,  1.8246e-01, -3.2119e-01, -7.2704e-02, -1.0441e-04,
-         5.9299e-17, -2.3306e-11,  8.3815e-15,  7.2027e-01, -1.2105e-06,
-         2.2389e-14,  3.3204e-01,  1.8811e-02, -1.9969e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5117,  0.0000, -0.0663, -0.0818,  2.3482,  0.9247,  0.0000,  0.0000,
-         0.2612,  0.2633,  0.0127,  0.0000, -0.3331, -0.4419,  0.0000,  0.0000,
-         0.0000,  0.0627,  0.0000,  0.0000,  0.0618,  0.1328,  0.0000, -0.7103,
-         0.0493,  0.0000,  0.1999, -0.0601,  0.1044,  0.0000, -0.5681, -1.2497,
-         0.0000, -0.0595,  0.0000,  0.0000,  0.0000,  0.0000,  0.2870,  0.6572,
-         0.1370,  0.0000,  0.0826,  0.0000,  0.0000, -0.2938,  0.0000,  0.4659,
-         0.0000,  0.0000,  0.0000,  0.1825, -0.3212, -0.0727,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7203,  0.0000,  0.0000,  0.3320,  0.0000, -0.1997],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5117,  0.0000, -0.0663, -0.0818,  2.3482,  0.9247,  0.0000,  0.0000,
-         0.2612,  0.2633,  0.0127,  0.0000, -0.3331, -0.4419,  0.0000,  0.0000,
-         0.0000,  0.0627,  0.0000,  0.0000,  0.0618,  0.1328,  0.0000, -0.7103,
-         0.0493,  0.0000,  0.1999, -0.0601,  0.1044,  0.0000, -0.5681, -1.2497,
-         0.0000, -0.0595,  0.0000,  0.0000,  0.0000,  0.0000,  0.2870,  0.6572,
-         0.1370,  0.0000,  0.0826,  0.0000,  0.0000, -0.2938,  0.0000,  0.4659,
-         0.0000,  0.0000,  0.0000,  0.1825, -0.3212, -0.0727,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7203,  0.0000,  0.0000,  0.3320,  0.0000, -0.1997],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0905e-01, -6.6740e-06, -4.1537e-02, -4.2190e-02,  2.3451e+00,
-         9.0445e-01,  4.0399e-15, -3.1900e-08,  2.3241e-01,  2.8265e-01,
-         4.6823e-03,  7.2626e-07, -3.3368e-01, -4.1274e-01, -7.2447e-12,
-        -5.3845e-09, -5.2965e-10,  7.0325e-02, -1.2250e-12,  1.4729e-08,
-         1.6560e-01,  1.3888e-01, -2.6015e-12, -6.5201e-01,  8.0529e-02,
-         2.1114e-08,  1.8959e-01,  3.2678e-02,  1.3194e-01,  3.7836e-12,
-        -6.1452e-01, -1.2450e+00, -9.6852e-16, -4.2310e-02,  4.8072e-10,
-         2.5938e-08,  1.6036e-11,  0.0000e+00,  3.1728e-01,  6.2498e-01,
-         1.6149e-01, -2.1419e-09,  1.1774e-01,  3.9466e-12,  9.6529e-10,
-        -3.2236e-01, -1.3620e-07,  4.4197e-01,  4.5818e-07, -2.5125e-13,
-        -8.1292e-04,  2.3801e-01, -3.5430e-01, -8.3684e-02, -9.2070e-05,
-         5.2289e-17, -2.0550e-11,  7.3907e-15,  7.3011e-01, -1.0674e-06,
-         1.9742e-14,  3.4774e-01,  1.6587e-02, -1.7754e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5090,  0.0000, -0.0415, -0.0422,  2.3451,  0.9045,  0.0000,  0.0000,
-         0.2324,  0.2826,  0.0047,  0.0000, -0.3337, -0.4127,  0.0000,  0.0000,
-         0.0000,  0.0703,  0.0000,  0.0000,  0.1656,  0.1389,  0.0000, -0.6520,
-         0.0805,  0.0000,  0.1896,  0.0327,  0.1319,  0.0000, -0.6145, -1.2450,
-         0.0000, -0.0423,  0.0000,  0.0000,  0.0000,  0.0000,  0.3173,  0.6250,
-         0.1615,  0.0000,  0.1177,  0.0000,  0.0000, -0.3224,  0.0000,  0.4420,
-         0.0000,  0.0000,  0.0000,  0.2380, -0.3543, -0.0837,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7301,  0.0000,  0.0000,  0.3477,  0.0000, -0.1775],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5090,  0.0000, -0.0415, -0.0422,  2.3451,  0.9045,  0.0000,  0.0000,
-         0.2324,  0.2826,  0.0047,  0.0000, -0.3337, -0.4127,  0.0000,  0.0000,
-         0.0000,  0.0703,  0.0000,  0.0000,  0.1656,  0.1389,  0.0000, -0.6520,
-         0.0805,  0.0000,  0.1896,  0.0327,  0.1319,  0.0000, -0.6145, -1.2450,
-         0.0000, -0.0423,  0.0000,  0.0000,  0.0000,  0.0000,  0.3173,  0.6250,
-         0.1615,  0.0000,  0.1177,  0.0000,  0.0000, -0.3224,  0.0000,  0.4420,
-         0.0000,  0.0000,  0.0000,  0.2380, -0.3543, -0.0837,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7301,  0.0000,  0.0000,  0.3477,  0.0000, -0.1775],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0374e-01, -5.8874e-06, -2.9388e-02, -9.9043e-03,  2.3427e+00,
-         8.8646e-01,  3.5638e-15, -2.8140e-08,  1.7854e-01,  2.9530e-01,
-         3.3002e-02,  6.4066e-07, -3.5830e-01, -3.8373e-01, -6.3908e-12,
-        -4.7499e-09, -4.6722e-10,  7.2009e-02, -1.0807e-12,  1.2993e-08,
-         2.2293e-01,  1.5603e-01, -2.2949e-12, -6.2750e-01,  1.1240e-01,
-         1.8626e-08,  2.0123e-01,  1.0313e-01,  1.4659e-01,  3.3376e-12,
-        -6.3825e-01, -1.2496e+00, -8.5437e-16, -2.9356e-02,  4.2407e-10,
-         2.2881e-08,  1.4146e-11,  0.0000e+00,  3.2662e-01,  5.7728e-01,
-         1.7883e-01, -1.8894e-09,  1.3003e-01,  3.4814e-12,  8.5152e-10,
-        -3.1574e-01, -1.2015e-07,  4.1138e-01,  4.0417e-07, -2.2163e-13,
-        -7.1711e-04,  2.5832e-01, -3.6877e-01, -5.9825e-02, -8.1218e-05,
-         4.6126e-17, -1.8128e-11,  6.5196e-15,  7.4360e-01, -9.4157e-07,
-         1.7415e-14,  3.5281e-01,  1.4632e-02, -1.3226e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5037,  0.0000, -0.0294, -0.0099,  2.3427,  0.8865,  0.0000,  0.0000,
-         0.1785,  0.2953,  0.0330,  0.0000, -0.3583, -0.3837,  0.0000,  0.0000,
-         0.0000,  0.0720,  0.0000,  0.0000,  0.2229,  0.1560,  0.0000, -0.6275,
-         0.1124,  0.0000,  0.2012,  0.1031,  0.1466,  0.0000, -0.6382, -1.2496,
-         0.0000, -0.0294,  0.0000,  0.0000,  0.0000,  0.0000,  0.3266,  0.5773,
-         0.1788,  0.0000,  0.1300,  0.0000,  0.0000, -0.3157,  0.0000,  0.4114,
-         0.0000,  0.0000,  0.0000,  0.2583, -0.3688, -0.0598,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7436,  0.0000,  0.0000,  0.3528,  0.0000, -0.1323],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5037,  0.0000, -0.0294, -0.0099,  2.3427,  0.8865,  0.0000,  0.0000,
-         0.1785,  0.2953,  0.0330,  0.0000, -0.3583, -0.3837,  0.0000,  0.0000,
-         0.0000,  0.0720,  0.0000,  0.0000,  0.2229,  0.1560,  0.0000, -0.6275,
-         0.1124,  0.0000,  0.2012,  0.1031,  0.1466,  0.0000, -0.6382, -1.2496,
-         0.0000, -0.0294,  0.0000,  0.0000,  0.0000,  0.0000,  0.3266,  0.5773,
-         0.1788,  0.0000,  0.1300,  0.0000,  0.0000, -0.3157,  0.0000,  0.4114,
-         0.0000,  0.0000,  0.0000,  0.2583, -0.3688, -0.0598,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7436,  0.0000,  0.0000,  0.3528,  0.0000, -0.1323],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0636e-01, -5.1956e-06,  2.3213e-02,  3.7369e-02,  2.3416e+00,
-         8.5972e-01,  3.1450e-15, -2.4833e-08,  1.2882e-01,  3.1971e-01,
-         1.6869e-01,  5.6539e-07, -3.7553e-01, -3.5003e-01, -5.6399e-12,
-        -4.1918e-09, -4.1233e-10,  4.0336e-02, -9.5368e-13,  1.1466e-08,
-         2.5255e-01,  1.3169e-01, -2.0253e-12, -6.5628e-01,  7.7795e-02,
-         1.6437e-08,  2.2244e-01,  1.5080e-01,  1.3718e-01,  2.9455e-12,
-        -6.9508e-01, -1.2479e+00, -7.5398e-16, -2.7531e-02,  3.7424e-10,
-         2.0192e-08,  1.2484e-11,  0.0000e+00,  3.4812e-01,  5.2528e-01,
-         1.7922e-01, -1.6674e-09,  1.7927e-01,  3.0724e-12,  7.5147e-10,
-        -3.6138e-01, -1.0603e-07,  3.6561e-01,  3.5668e-07, -1.9559e-13,
-        -6.3285e-04,  2.7168e-01, -3.5850e-01, -4.4886e-02, -7.1675e-05,
-         4.0706e-17, -1.5998e-11,  5.7535e-15,  7.3065e-01, -8.3093e-07,
-         1.5369e-14,  3.7831e-01,  1.2913e-02, -5.1602e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Sparsity at the end of epoch 1: 48.39%
-After Step tensor([ 0.5064,  0.0000,  0.0232,  0.0374,  2.3416,  0.8597,  0.0000,  0.0000,
-         0.1288,  0.3197,  0.1687,  0.0000, -0.3755, -0.3500,  0.0000,  0.0000,
-         0.0000,  0.0403,  0.0000,  0.0000,  0.2525,  0.1317,  0.0000, -0.6563,
-         0.0778,  0.0000,  0.2224,  0.1508,  0.1372,  0.0000, -0.6951, -1.2479,
-         0.0000, -0.0275,  0.0000,  0.0000,  0.0000,  0.0000,  0.3481,  0.5253,
-         0.1792,  0.0000,  0.1793,  0.0000,  0.0000, -0.3614,  0.0000,  0.3656,
-         0.0000,  0.0000,  0.0000,  0.2717, -0.3585, -0.0449,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7307,  0.0000,  0.0000,  0.3783,  0.0000, -0.0516],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5064,  0.0000,  0.0232,  0.0374,  2.3416,  0.8597,  0.0000,  0.0000,
-         0.1288,  0.3197,  0.1687,  0.0000, -0.3755, -0.3500,  0.0000,  0.0000,
-         0.0000,  0.0403,  0.0000,  0.0000,  0.2525,  0.1317,  0.0000, -0.6563,
-         0.0778,  0.0000,  0.2224,  0.1508,  0.1372,  0.0000, -0.6951, -1.2479,
-         0.0000, -0.0275,  0.0000,  0.0000,  0.0000,  0.0000,  0.3481,  0.5253,
-         0.1792,  0.0000,  0.1793,  0.0000,  0.0000, -0.3614,  0.0000,  0.3656,
-         0.0000,  0.0000,  0.0000,  0.2717, -0.3585, -0.0449,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7307,  0.0000,  0.0000,  0.3783,  0.0000, -0.0516],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1017e-01, -4.5870e-06,  5.5963e-02,  5.8613e-02,  2.3396e+00,
-         8.2662e-01,  2.7766e-15, -2.1925e-08,  8.4063e-02,  3.3392e-01,
-         2.4563e-01,  4.9916e-07, -4.0126e-01, -3.2466e-01, -4.9793e-12,
-        -3.7008e-09, -3.6403e-10, -3.2571e-02, -8.4197e-13,  1.0123e-08,
-         2.3642e-01,  7.5095e-02, -1.7880e-12, -6.9629e-01,  2.1137e-02,
-         1.4512e-08,  2.2716e-01,  1.6111e-01,  9.4385e-02,  2.6004e-12,
-        -7.3967e-01, -1.2517e+00, -6.6566e-16, -4.8594e-02,  3.3040e-10,
-         1.7827e-08,  1.1021e-11,  0.0000e+00,  3.6436e-01,  4.8578e-01,
-         1.7299e-01, -1.4721e-09,  1.9960e-01,  2.7125e-12,  6.6344e-10,
-        -3.6394e-01, -9.3610e-08,  3.3499e-01,  3.1490e-07, -1.7268e-13,
-        -5.5872e-04,  2.4021e-01, -3.1761e-01, -5.5102e-02, -6.3279e-05,
-         3.5938e-17, -1.4124e-11,  5.0796e-15,  7.2253e-01, -7.3360e-07,
-         1.3569e-14,  3.8404e-01,  1.1400e-02,  4.8220e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5102,  0.0000,  0.0560,  0.0586,  2.3396,  0.8266,  0.0000,  0.0000,
-         0.0841,  0.3339,  0.2456,  0.0000, -0.4013, -0.3247,  0.0000,  0.0000,
-         0.0000, -0.0326,  0.0000,  0.0000,  0.2364,  0.0751,  0.0000, -0.6963,
-         0.0211,  0.0000,  0.2272,  0.1611,  0.0944,  0.0000, -0.7397, -1.2517,
-         0.0000, -0.0486,  0.0000,  0.0000,  0.0000,  0.0000,  0.3644,  0.4858,
-         0.1730,  0.0000,  0.1996,  0.0000,  0.0000, -0.3639,  0.0000,  0.3350,
-         0.0000,  0.0000,  0.0000,  0.2402, -0.3176, -0.0551,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7225,  0.0000,  0.0000,  0.3840,  0.0000,  0.0048],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5102,  0.0000,  0.0560,  0.0586,  2.3396,  0.8266,  0.0000,  0.0000,
-         0.0841,  0.3339,  0.2456,  0.0000, -0.4013, -0.3247,  0.0000,  0.0000,
-         0.0000, -0.0326,  0.0000,  0.0000,  0.2364,  0.0751,  0.0000, -0.6963,
-         0.0211,  0.0000,  0.2272,  0.1611,  0.0944,  0.0000, -0.7397, -1.2517,
-         0.0000, -0.0486,  0.0000,  0.0000,  0.0000,  0.0000,  0.3644,  0.4858,
-         0.1730,  0.0000,  0.1996,  0.0000,  0.0000, -0.3639,  0.0000,  0.3350,
-         0.0000,  0.0000,  0.0000,  0.2402, -0.3176, -0.0551,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7225,  0.0000,  0.0000,  0.3840,  0.0000,  0.0048],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2471e-01, -4.0514e-06,  1.0786e-01,  1.0295e-01,  2.3384e+00,
-         7.9501e-01,  2.4524e-15, -1.9364e-08,  6.1956e-02,  3.4676e-01,
-         2.3148e-01,  4.4087e-07, -4.1695e-01, -3.2150e-01, -4.3978e-12,
-        -3.2686e-09, -3.2152e-10, -9.1195e-02, -7.4364e-13,  8.9411e-09,
-         2.8330e-01,  2.8586e-02, -1.5792e-12, -7.1171e-01, -4.3213e-02,
-         1.2817e-08,  2.3696e-01,  2.1385e-01,  8.7035e-02,  2.2968e-12,
-        -7.7860e-01, -1.2473e+00, -5.8792e-16, -5.4487e-02,  2.9182e-10,
-         1.5745e-08,  9.7343e-12,  0.0000e+00,  4.0454e-01,  4.6507e-01,
-         1.4993e-01, -1.3002e-09,  2.3597e-01,  2.3957e-12,  5.8597e-10,
-        -3.8866e-01, -8.2678e-08,  3.0897e-01,  2.7813e-07, -1.5252e-13,
-        -4.9347e-04,  2.3813e-01, -3.0145e-01, -3.9031e-02, -5.5890e-05,
-         3.1741e-17, -1.2475e-11,  4.4864e-15,  6.9425e-01, -6.4793e-07,
-         1.1984e-14,  4.1229e-01,  1.0069e-02,  2.1657e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5247,  0.0000,  0.1079,  0.1030,  2.3384,  0.7950,  0.0000,  0.0000,
-         0.0620,  0.3468,  0.2315,  0.0000, -0.4169, -0.3215,  0.0000,  0.0000,
-         0.0000, -0.0912,  0.0000,  0.0000,  0.2833,  0.0286,  0.0000, -0.7117,
-        -0.0432,  0.0000,  0.2370,  0.2138,  0.0870,  0.0000, -0.7786, -1.2473,
-         0.0000, -0.0545,  0.0000,  0.0000,  0.0000,  0.0000,  0.4045,  0.4651,
-         0.1499,  0.0000,  0.2360,  0.0000,  0.0000, -0.3887,  0.0000,  0.3090,
-         0.0000,  0.0000,  0.0000,  0.2381, -0.3014, -0.0390,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6943,  0.0000,  0.0000,  0.4123,  0.0000,  0.0217],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5247,  0.0000,  0.1079,  0.1030,  2.3384,  0.7950,  0.0000,  0.0000,
-         0.0620,  0.3468,  0.2315,  0.0000, -0.4169, -0.3215,  0.0000,  0.0000,
-         0.0000, -0.0912,  0.0000,  0.0000,  0.2833,  0.0286,  0.0000, -0.7117,
-        -0.0432,  0.0000,  0.2370,  0.2138,  0.0870,  0.0000, -0.7786, -1.2473,
-         0.0000, -0.0545,  0.0000,  0.0000,  0.0000,  0.0000,  0.4045,  0.4651,
-         0.1499,  0.0000,  0.2360,  0.0000,  0.0000, -0.3887,  0.0000,  0.3090,
-         0.0000,  0.0000,  0.0000,  0.2381, -0.3014, -0.0390,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6943,  0.0000,  0.0000,  0.4123,  0.0000,  0.0217],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1716e-01, -3.5797e-06,  1.4272e-01,  1.1016e-01,  2.3386e+00,
-         7.8149e-01,  2.1669e-15, -1.7110e-08,  4.2773e-02,  3.2777e-01,
-         1.4361e-01,  3.8954e-07, -4.3792e-01, -3.2725e-01, -3.8858e-12,
-        -2.8881e-09, -2.8409e-10, -1.7170e-01, -6.5707e-13,  7.9002e-09,
-         3.0059e-01, -4.3050e-02, -1.3954e-12, -7.3913e-01, -1.1889e-01,
-         1.1325e-08,  2.3617e-01,  2.2207e-01,  4.8044e-02,  2.0294e-12,
-        -8.0718e-01, -1.2483e+00, -5.1948e-16, -7.7130e-02,  2.5785e-10,
-         1.3912e-08,  8.6011e-12,  0.0000e+00,  4.1691e-01,  4.4858e-01,
-         1.0628e-01, -1.1488e-09,  2.8512e-01,  2.1168e-12,  5.1775e-10,
-        -3.9264e-01, -7.3053e-08,  3.0944e-01,  2.4575e-07, -1.3476e-13,
-        -4.3602e-04,  1.7708e-01, -2.8940e-01, -1.6737e-02, -4.9383e-05,
-         2.8046e-17, -1.1023e-11,  3.9641e-15,  6.7380e-01, -5.7250e-07,
-         1.0589e-14,  4.1227e-01,  8.8969e-03,  4.2595e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5172,  0.0000,  0.1427,  0.1102,  2.3386,  0.7815,  0.0000,  0.0000,
-         0.0428,  0.3278,  0.1436,  0.0000, -0.4379, -0.3273,  0.0000,  0.0000,
-         0.0000, -0.1717,  0.0000,  0.0000,  0.3006, -0.0431,  0.0000, -0.7391,
-        -0.1189,  0.0000,  0.2362,  0.2221,  0.0480,  0.0000, -0.8072, -1.2483,
-         0.0000, -0.0771,  0.0000,  0.0000,  0.0000,  0.0000,  0.4169,  0.4486,
-         0.1063,  0.0000,  0.2851,  0.0000,  0.0000, -0.3926,  0.0000,  0.3094,
-         0.0000,  0.0000,  0.0000,  0.1771, -0.2894, -0.0167,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6738,  0.0000,  0.0000,  0.4123,  0.0000,  0.0426],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5172,  0.0000,  0.1427,  0.1102,  2.3386,  0.7815,  0.0000,  0.0000,
-         0.0428,  0.3278,  0.1436,  0.0000, -0.4379, -0.3273,  0.0000,  0.0000,
-         0.0000, -0.1717,  0.0000,  0.0000,  0.3006, -0.0431,  0.0000, -0.7391,
-        -0.1189,  0.0000,  0.2362,  0.2221,  0.0480,  0.0000, -0.8072, -1.2483,
-         0.0000, -0.0771,  0.0000,  0.0000,  0.0000,  0.0000,  0.4169,  0.4486,
-         0.1063,  0.0000,  0.2851,  0.0000,  0.0000, -0.3926,  0.0000,  0.3094,
-         0.0000,  0.0000,  0.0000,  0.1771, -0.2894, -0.0167,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6738,  0.0000,  0.0000,  0.4123,  0.0000,  0.0426],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1651e-01, -3.1643e-06,  1.0399e-01,  1.3593e-01,  2.3378e+00,
-         7.5344e-01,  1.9154e-15, -1.5124e-08,  2.6139e-02,  3.2669e-01,
-         8.5215e-02,  3.4434e-07, -4.7656e-01, -3.2884e-01, -3.4349e-12,
-        -2.5529e-09, -2.5112e-10, -2.4629e-01, -5.8082e-13,  6.9834e-09,
-         2.6978e-01, -5.8811e-02, -1.2334e-12, -7.6683e-01, -1.2705e-01,
-         1.0011e-08,  1.8226e-01,  2.3235e-01,  1.9107e-02,  1.7939e-12,
-        -8.3220e-01, -1.2491e+00, -4.5919e-16, -1.0012e-01,  2.2792e-10,
-         1.2298e-08,  7.6029e-12,  0.0000e+00,  4.1873e-01,  4.4239e-01,
-         1.3370e-01, -1.0155e-09,  2.6697e-01,  1.8711e-12,  4.5766e-10,
-        -3.8335e-01, -6.4575e-08,  2.5043e-01,  2.1723e-07, -1.1912e-13,
-        -3.8542e-04,  1.6244e-01, -2.9563e-01,  2.2056e-02, -4.3652e-05,
-         2.4791e-17, -9.7434e-12,  3.5041e-15,  6.6253e-01, -5.0606e-07,
-         9.3602e-15,  3.7359e-01,  7.8644e-03,  6.1951e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5165,  0.0000,  0.1040,  0.1359,  2.3378,  0.7534,  0.0000,  0.0000,
-         0.0261,  0.3267,  0.0852,  0.0000, -0.4766, -0.3288,  0.0000,  0.0000,
-         0.0000, -0.2463,  0.0000,  0.0000,  0.2698, -0.0588,  0.0000, -0.7668,
-        -0.1270,  0.0000,  0.1823,  0.2324,  0.0191,  0.0000, -0.8322, -1.2491,
-         0.0000, -0.1001,  0.0000,  0.0000,  0.0000,  0.0000,  0.4187,  0.4424,
-         0.1337,  0.0000,  0.2670,  0.0000,  0.0000, -0.3833,  0.0000,  0.2504,
-         0.0000,  0.0000,  0.0000,  0.1624, -0.2956,  0.0221,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6625,  0.0000,  0.0000,  0.3736,  0.0000,  0.0620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5165,  0.0000,  0.1040,  0.1359,  2.3378,  0.7534,  0.0000,  0.0000,
-         0.0261,  0.3267,  0.0852,  0.0000, -0.4766, -0.3288,  0.0000,  0.0000,
-         0.0000, -0.2463,  0.0000,  0.0000,  0.2698, -0.0588,  0.0000, -0.7668,
-        -0.1270,  0.0000,  0.1823,  0.2324,  0.0191,  0.0000, -0.8322, -1.2491,
-         0.0000, -0.1001,  0.0000,  0.0000,  0.0000,  0.0000,  0.4187,  0.4424,
-         0.1337,  0.0000,  0.2670,  0.0000,  0.0000, -0.3833,  0.0000,  0.2504,
-         0.0000,  0.0000,  0.0000,  0.1624, -0.2956,  0.0221,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6625,  0.0000,  0.0000,  0.3736,  0.0000,  0.0620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2290e-01, -2.7982e-06,  1.6282e-02,  8.2067e-02,  2.3345e+00,
-         7.0914e-01,  1.6938e-15, -1.3375e-08,  2.5124e-02,  2.9947e-01,
-         1.0676e-03,  3.0450e-07, -5.2485e-01, -3.0306e-01, -3.0375e-12,
-        -2.2576e-09, -2.2207e-10, -3.2260e-01, -5.1362e-13,  6.1754e-09,
-         2.0268e-01, -2.1891e-02, -1.0907e-12, -7.7422e-01, -8.4374e-02,
-         8.8526e-09,  7.0355e-02,  2.0147e-01, -2.0748e-02,  1.5863e-12,
-        -8.3997e-01, -1.2487e+00, -4.0607e-16, -1.3191e-01,  2.0155e-10,
-         1.0875e-08,  6.7233e-12,  0.0000e+00,  4.0574e-01,  4.2783e-01,
-         1.7422e-01, -8.9801e-10,  2.1918e-01,  1.6547e-12,  4.0472e-10,
-        -3.6775e-01, -5.7104e-08,  1.9214e-01,  1.9210e-07, -1.0534e-13,
-        -3.4083e-04,  1.8853e-01, -2.9962e-01, -2.0618e-02, -3.8602e-05,
-         2.1923e-17, -8.6161e-12,  3.0987e-15,  6.5934e-01, -4.4752e-07,
-         8.2773e-15,  2.7873e-01,  6.9545e-03,  2.4204e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.2290e-01,  0.0000e+00,  1.6282e-02,  8.2067e-02,  2.3345e+00,
-         7.0914e-01,  0.0000e+00,  0.0000e+00,  2.5124e-02,  2.9947e-01,
-         1.0676e-03,  0.0000e+00, -5.2485e-01, -3.0306e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.2260e-01,  0.0000e+00,  0.0000e+00,
-         2.0268e-01, -2.1891e-02,  0.0000e+00, -7.7422e-01, -8.4374e-02,
-         0.0000e+00,  7.0355e-02,  2.0147e-01, -2.0748e-02,  0.0000e+00,
-        -8.3997e-01, -1.2487e+00,  0.0000e+00, -1.3191e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.0574e-01,  4.2783e-01,
-         1.7422e-01,  0.0000e+00,  2.1918e-01,  0.0000e+00,  0.0000e+00,
-        -3.6775e-01,  0.0000e+00,  1.9214e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.8853e-01, -2.9962e-01, -2.0618e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.5934e-01,  0.0000e+00,
-         0.0000e+00,  2.7873e-01,  0.0000e+00,  2.4204e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.2290e-01,  0.0000e+00,  1.6282e-02,  8.2067e-02,  2.3345e+00,
-         7.0914e-01,  0.0000e+00,  0.0000e+00,  2.5124e-02,  2.9947e-01,
-         1.0676e-03,  0.0000e+00, -5.2485e-01, -3.0306e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.2260e-01,  0.0000e+00,  0.0000e+00,
-         2.0268e-01, -2.1891e-02,  0.0000e+00, -7.7422e-01, -8.4374e-02,
-         0.0000e+00,  7.0355e-02,  2.0147e-01, -2.0748e-02,  0.0000e+00,
-        -8.3997e-01, -1.2487e+00,  0.0000e+00, -1.3191e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.0574e-01,  4.2783e-01,
-         1.7422e-01,  0.0000e+00,  2.1918e-01,  0.0000e+00,  0.0000e+00,
-        -3.6775e-01,  0.0000e+00,  1.9214e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.8853e-01, -2.9962e-01, -2.0618e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.5934e-01,  0.0000e+00,
-         0.0000e+00,  2.7873e-01,  0.0000e+00,  2.4204e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.3343e-01, -2.4755e-06, -6.2621e-02,  3.4592e-02,  2.3306e+00,
-         6.7341e-01,  1.4985e-15, -1.1832e-08,  1.4102e-02,  2.8403e-01,
-        -2.6111e-02,  2.6938e-07, -5.7926e-01, -2.7802e-01, -2.6872e-12,
-        -1.9972e-09, -1.9646e-10, -3.5141e-01, -4.5439e-13,  5.4632e-09,
-         1.3933e-01,  2.8377e-02, -9.6495e-13, -7.6704e-01, -3.4198e-02,
-         7.8316e-09, -3.4219e-02,  1.5725e-01, -6.1825e-02,  1.4034e-12,
-        -8.4573e-01, -1.2488e+00, -3.5924e-16, -1.5102e-01,  1.7831e-10,
-         9.6208e-09,  5.9480e-12,  0.0000e+00,  3.9646e-01,  4.0924e-01,
-         2.2378e-01, -7.9445e-10,  1.7596e-01,  1.4638e-12,  3.5804e-10,
-        -3.5254e-01, -5.0519e-08,  1.2542e-01,  1.6994e-07, -9.3191e-14,
-        -3.0152e-04,  1.8364e-01, -2.8855e-01, -2.5956e-02, -3.4150e-05,
-         1.9395e-17, -7.6225e-12,  2.7413e-15,  6.6134e-01, -3.9590e-07,
-         7.3227e-15,  1.8435e-01,  6.1525e-03,  3.0309e-05], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.3343e-01,  0.0000e+00, -6.2621e-02,  3.4592e-02,  2.3306e+00,
-         6.7341e-01,  0.0000e+00,  0.0000e+00,  1.4102e-02,  2.8403e-01,
-        -2.6111e-02,  0.0000e+00, -5.7926e-01, -2.7802e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.5141e-01,  0.0000e+00,  0.0000e+00,
-         1.3933e-01,  2.8377e-02,  0.0000e+00, -7.6704e-01, -3.4198e-02,
-         0.0000e+00, -3.4219e-02,  1.5725e-01, -6.1825e-02,  0.0000e+00,
-        -8.4573e-01, -1.2488e+00,  0.0000e+00, -1.5102e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.9646e-01,  4.0924e-01,
-         2.2378e-01,  0.0000e+00,  1.7596e-01,  0.0000e+00,  0.0000e+00,
-        -3.5254e-01,  0.0000e+00,  1.2542e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.8364e-01, -2.8855e-01, -2.5956e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.6134e-01,  0.0000e+00,
-         0.0000e+00,  1.8435e-01,  0.0000e+00,  3.0309e-05], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.3343e-01,  0.0000e+00, -6.2621e-02,  3.4592e-02,  2.3306e+00,
-         6.7341e-01,  0.0000e+00,  0.0000e+00,  1.4102e-02,  2.8403e-01,
-        -2.6111e-02,  0.0000e+00, -5.7926e-01, -2.7802e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.5141e-01,  0.0000e+00,  0.0000e+00,
-         1.3933e-01,  2.8377e-02,  0.0000e+00, -7.6704e-01, -3.4198e-02,
-         0.0000e+00, -3.4219e-02,  1.5725e-01, -6.1825e-02,  0.0000e+00,
-        -8.4573e-01, -1.2488e+00,  0.0000e+00, -1.5102e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.9646e-01,  4.0924e-01,
-         2.2378e-01,  0.0000e+00,  1.7596e-01,  0.0000e+00,  0.0000e+00,
-        -3.5254e-01,  0.0000e+00,  1.2542e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.8364e-01, -2.8855e-01, -2.5956e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.6134e-01,  0.0000e+00,
-         0.0000e+00,  1.8435e-01,  0.0000e+00,  3.0309e-05], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.3367e-01, -2.1909e-06, -1.2739e-01, -7.4849e-03,  2.3277e+00,
-         6.2178e-01,  1.3262e-15, -1.0472e-08,  1.0277e-02,  2.6290e-01,
-        -1.3892e-02,  2.3841e-07, -6.2328e-01, -2.5429e-01, -2.3782e-12,
-        -1.7676e-09, -1.7387e-10, -3.8543e-01, -4.0215e-13,  4.8352e-09,
-         9.2043e-02,  4.9267e-02, -8.5401e-13, -7.6019e-01, -1.3747e-02,
-         6.9313e-09, -1.3801e-01,  1.0381e-01, -7.9089e-02,  1.2420e-12,
-        -8.5951e-01, -1.2467e+00, -3.1794e-16, -1.2381e-01,  1.5781e-10,
-         8.5148e-09,  5.2641e-12,  0.0000e+00,  4.0329e-01,  3.8639e-01,
-         2.7855e-01, -7.0311e-10,  1.5345e-01,  1.2956e-12,  3.1688e-10,
-        -3.3909e-01, -4.4711e-08,  7.9086e-02,  1.5041e-07, -8.2478e-14,
-        -2.6686e-04,  1.9910e-01, -2.5912e-01, -8.0970e-02, -3.0224e-05,
-         1.7165e-17, -6.7461e-12,  2.4262e-15,  6.7391e-01, -3.5039e-07,
-         6.4808e-15,  8.6988e-02,  5.4452e-03, -2.6834e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5337,  0.0000, -0.1274, -0.0075,  2.3277,  0.6218,  0.0000,  0.0000,
-         0.0103,  0.2629, -0.0139,  0.0000, -0.6233, -0.2543,  0.0000,  0.0000,
-         0.0000, -0.3854,  0.0000,  0.0000,  0.0920,  0.0493,  0.0000, -0.7602,
-        -0.0137,  0.0000, -0.1380,  0.1038, -0.0791,  0.0000, -0.8595, -1.2467,
-         0.0000, -0.1238,  0.0000,  0.0000,  0.0000,  0.0000,  0.4033,  0.3864,
-         0.2785,  0.0000,  0.1534,  0.0000,  0.0000, -0.3391,  0.0000,  0.0791,
-         0.0000,  0.0000,  0.0000,  0.1991, -0.2591, -0.0810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6739,  0.0000,  0.0000,  0.0870,  0.0000, -0.0268],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5337,  0.0000, -0.1274, -0.0075,  2.3277,  0.6218,  0.0000,  0.0000,
-         0.0103,  0.2629, -0.0139,  0.0000, -0.6233, -0.2543,  0.0000,  0.0000,
-         0.0000, -0.3854,  0.0000,  0.0000,  0.0920,  0.0493,  0.0000, -0.7602,
-        -0.0137,  0.0000, -0.1380,  0.1038, -0.0791,  0.0000, -0.8595, -1.2467,
-         0.0000, -0.1238,  0.0000,  0.0000,  0.0000,  0.0000,  0.4033,  0.3864,
-         0.2785,  0.0000,  0.1534,  0.0000,  0.0000, -0.3391,  0.0000,  0.0791,
-         0.0000,  0.0000,  0.0000,  0.1991, -0.2591, -0.0810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6739,  0.0000,  0.0000,  0.0870,  0.0000, -0.0268],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2074e-01, -1.9398e-06, -1.6239e-01, -1.1877e-02,  2.3231e+00,
-         5.6153e-01,  1.1742e-15, -9.2717e-09, -2.0190e-02,  2.3701e-01,
-        -3.9129e-02,  2.1109e-07, -6.5988e-01, -2.5931e-01, -2.1057e-12,
-        -1.5650e-09, -1.5394e-10, -3.8703e-01, -3.5606e-13,  4.2810e-09,
-         6.3678e-02,  3.8120e-02, -7.5614e-13, -7.3322e-01, -2.6244e-02,
-         6.1369e-09, -1.8675e-01,  5.8830e-02, -9.3253e-02,  1.0997e-12,
-        -8.5180e-01, -1.2417e+00, -2.8150e-16, -7.6939e-02,  1.3972e-10,
-         7.5390e-09,  4.6609e-12,  0.0000e+00,  4.0118e-01,  3.6808e-01,
-         2.8113e-01, -6.2254e-10,  1.7496e-01,  1.1471e-12,  2.8056e-10,
-        -3.4745e-01, -3.9587e-08,  2.9692e-02,  1.3317e-07, -7.3026e-14,
-        -2.3628e-04,  1.7388e-01, -2.5522e-01, -1.2145e-02, -2.6760e-05,
-         1.5198e-17, -5.9730e-12,  2.1481e-15,  6.8095e-01, -3.1023e-07,
-         5.7381e-15,  4.5986e-02,  4.8211e-03, -5.1947e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5207,  0.0000, -0.1624, -0.0119,  2.3231,  0.5615,  0.0000,  0.0000,
-        -0.0202,  0.2370, -0.0391,  0.0000, -0.6599, -0.2593,  0.0000,  0.0000,
-         0.0000, -0.3870,  0.0000,  0.0000,  0.0637,  0.0381,  0.0000, -0.7332,
-        -0.0262,  0.0000, -0.1868,  0.0588, -0.0933,  0.0000, -0.8518, -1.2417,
-         0.0000, -0.0769,  0.0000,  0.0000,  0.0000,  0.0000,  0.4012,  0.3681,
-         0.2811,  0.0000,  0.1750,  0.0000,  0.0000, -0.3475,  0.0000,  0.0297,
-         0.0000,  0.0000,  0.0000,  0.1739, -0.2552, -0.0121,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6810,  0.0000,  0.0000,  0.0460,  0.0000, -0.0519],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5207,  0.0000, -0.1624, -0.0119,  2.3231,  0.5615,  0.0000,  0.0000,
-        -0.0202,  0.2370, -0.0391,  0.0000, -0.6599, -0.2593,  0.0000,  0.0000,
-         0.0000, -0.3870,  0.0000,  0.0000,  0.0637,  0.0381,  0.0000, -0.7332,
-        -0.0262,  0.0000, -0.1868,  0.0588, -0.0933,  0.0000, -0.8518, -1.2417,
-         0.0000, -0.0769,  0.0000,  0.0000,  0.0000,  0.0000,  0.4012,  0.3681,
-         0.2811,  0.0000,  0.1750,  0.0000,  0.0000, -0.3475,  0.0000,  0.0297,
-         0.0000,  0.0000,  0.0000,  0.1739, -0.2552, -0.0121,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6810,  0.0000,  0.0000,  0.0460,  0.0000, -0.0519],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1310e-01, -1.7182e-06, -1.6530e-01, -2.0789e-02,  2.3205e+00,
-         5.0595e-01,  1.0401e-15, -8.2125e-09, -5.3121e-02,  1.8237e-01,
-        -4.9279e-02,  1.8698e-07, -6.6630e-01, -2.7404e-01, -1.8651e-12,
-        -1.3862e-09, -1.3636e-10, -3.5952e-01, -3.1539e-13,  3.7920e-09,
-         5.1737e-02, -4.3761e-03, -6.6976e-13, -7.2262e-01, -7.2184e-02,
-         5.4358e-09, -2.1712e-01,  2.3085e-02, -1.0737e-01,  9.7407e-13,
-        -8.4424e-01, -1.2354e+00, -2.4934e-16, -4.2077e-02,  1.2376e-10,
-         6.6777e-09,  4.1284e-12,  0.0000e+00,  4.0212e-01,  3.4037e-01,
-         2.5400e-01, -5.5142e-10,  2.2478e-01,  1.0160e-12,  2.4851e-10,
-        -3.3536e-01, -3.5065e-08,  5.9840e-03,  1.1796e-07, -6.4683e-14,
-        -2.0929e-04,  1.1539e-01, -2.4890e-01,  1.1372e-01, -2.3703e-05,
-         1.3462e-17, -5.2907e-12,  1.9027e-15,  6.8002e-01, -2.7479e-07,
-         5.0826e-15,  3.4305e-02,  4.2704e-03, -6.5036e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5131,  0.0000, -0.1653, -0.0208,  2.3205,  0.5060,  0.0000,  0.0000,
-        -0.0531,  0.1824, -0.0493,  0.0000, -0.6663, -0.2740,  0.0000,  0.0000,
-         0.0000, -0.3595,  0.0000,  0.0000,  0.0517, -0.0044,  0.0000, -0.7226,
-        -0.0722,  0.0000, -0.2171,  0.0231, -0.1074,  0.0000, -0.8442, -1.2354,
-         0.0000, -0.0421,  0.0000,  0.0000,  0.0000,  0.0000,  0.4021,  0.3404,
-         0.2540,  0.0000,  0.2248,  0.0000,  0.0000, -0.3354,  0.0000,  0.0060,
-         0.0000,  0.0000,  0.0000,  0.1154, -0.2489,  0.1137,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6800,  0.0000,  0.0000,  0.0343,  0.0000, -0.0650],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5131,  0.0000, -0.1653, -0.0208,  2.3205,  0.5060,  0.0000,  0.0000,
-        -0.0531,  0.1824, -0.0493,  0.0000, -0.6663, -0.2740,  0.0000,  0.0000,
-         0.0000, -0.3595,  0.0000,  0.0000,  0.0517, -0.0044,  0.0000, -0.7226,
-        -0.0722,  0.0000, -0.2171,  0.0231, -0.1074,  0.0000, -0.8442, -1.2354,
-         0.0000, -0.0421,  0.0000,  0.0000,  0.0000,  0.0000,  0.4021,  0.3404,
-         0.2540,  0.0000,  0.2248,  0.0000,  0.0000, -0.3354,  0.0000,  0.0060,
-         0.0000,  0.0000,  0.0000,  0.1154, -0.2489,  0.1137,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6800,  0.0000,  0.0000,  0.0343,  0.0000, -0.0650],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9016e-01, -1.5226e-06, -1.2168e-01, -3.1924e-02,  2.3188e+00,
-         4.4942e-01,  9.2163e-16, -7.2773e-09, -1.0751e-01,  1.1453e-01,
-        -5.6487e-02,  1.6568e-07, -6.4972e-01, -3.1001e-01, -1.6527e-12,
-        -1.2284e-09, -1.2083e-10, -3.1265e-01, -2.7947e-13,  3.3602e-09,
-         3.0972e-02, -5.6796e-02, -5.9349e-13, -7.1835e-01, -1.4696e-01,
-         4.8168e-09, -2.3618e-01, -3.0939e-02, -1.1312e-01,  8.6315e-13,
-        -8.3655e-01, -1.2304e+00, -2.2095e-16,  5.8476e-03,  1.0967e-10,
-         5.9173e-09,  3.6583e-12,  0.0000e+00,  4.3381e-01,  3.2868e-01,
-         1.9088e-01, -4.8863e-10,  3.0511e-01,  9.0034e-13,  2.2021e-10,
-        -3.2036e-01, -3.1072e-08,  4.0201e-03,  1.0452e-07, -5.7317e-14,
-        -1.8545e-04,  7.1840e-02, -2.4482e-01,  2.4166e-01, -2.1004e-05,
-         1.1929e-17, -4.6882e-12,  1.6860e-15,  6.6414e-01, -2.4350e-07,
-         4.5038e-15,  4.6850e-02,  3.7841e-03, -8.8436e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4902,  0.0000, -0.1217, -0.0319,  2.3188,  0.4494,  0.0000,  0.0000,
-        -0.1075,  0.1145, -0.0565,  0.0000, -0.6497, -0.3100,  0.0000,  0.0000,
-         0.0000, -0.3126,  0.0000,  0.0000,  0.0310, -0.0568,  0.0000, -0.7183,
-        -0.1470,  0.0000, -0.2362, -0.0309, -0.1131,  0.0000, -0.8365, -1.2304,
-         0.0000,  0.0058,  0.0000,  0.0000,  0.0000,  0.0000,  0.4338,  0.3287,
-         0.1909,  0.0000,  0.3051,  0.0000,  0.0000, -0.3204,  0.0000,  0.0040,
-         0.0000,  0.0000,  0.0000,  0.0718, -0.2448,  0.2417,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6641,  0.0000,  0.0000,  0.0469,  0.0000, -0.0884],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4902,  0.0000, -0.1217, -0.0319,  2.3188,  0.4494,  0.0000,  0.0000,
-        -0.1075,  0.1145, -0.0565,  0.0000, -0.6497, -0.3100,  0.0000,  0.0000,
-         0.0000, -0.3126,  0.0000,  0.0000,  0.0310, -0.0568,  0.0000, -0.7183,
-        -0.1470,  0.0000, -0.2362, -0.0309, -0.1131,  0.0000, -0.8365, -1.2304,
-         0.0000,  0.0058,  0.0000,  0.0000,  0.0000,  0.0000,  0.4338,  0.3287,
-         0.1909,  0.0000,  0.3051,  0.0000,  0.0000, -0.3204,  0.0000,  0.0040,
-         0.0000,  0.0000,  0.0000,  0.0718, -0.2448,  0.2417,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6641,  0.0000,  0.0000,  0.0469,  0.0000, -0.0884],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6475e-01, -1.3497e-06, -8.7065e-02, -6.2907e-02,  2.3167e+00,
-         3.8527e-01,  8.1701e-16, -6.4513e-09, -1.2696e-01,  5.7173e-02,
-        -1.0061e-01,  1.4688e-07, -6.4049e-01, -3.2327e-01, -1.4651e-12,
-        -1.0889e-09, -1.0711e-10, -2.7764e-01, -2.4775e-13,  2.9788e-09,
-         2.5421e-02, -9.6904e-02, -5.2612e-13, -7.0980e-01, -2.1056e-01,
-         4.2701e-09, -2.5221e-01, -7.9663e-02, -1.0382e-01,  7.6517e-13,
-        -8.3130e-01, -1.2260e+00, -1.9587e-16,  5.4681e-02,  9.7220e-11,
-         5.2456e-09,  3.2430e-12,  0.0000e+00,  4.7579e-01,  3.5424e-01,
-         1.0865e-01, -4.3316e-10,  3.9423e-01,  7.9814e-13,  1.9522e-10,
-        -3.1116e-01, -2.7545e-08,  2.2827e-02,  9.2660e-08, -5.0811e-14,
-        -1.6440e-04,  5.9195e-02, -2.2801e-01,  3.1154e-01, -1.8620e-05,
-         1.0575e-17, -4.1560e-12,  1.4947e-15,  6.3417e-01, -2.1586e-07,
-         3.9926e-15,  5.8835e-02,  3.3545e-03, -1.0345e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4648,  0.0000, -0.0871, -0.0629,  2.3167,  0.3853,  0.0000,  0.0000,
-        -0.1270,  0.0572, -0.1006,  0.0000, -0.6405, -0.3233,  0.0000,  0.0000,
-         0.0000, -0.2776,  0.0000,  0.0000,  0.0254, -0.0969,  0.0000, -0.7098,
-        -0.2106,  0.0000, -0.2522, -0.0797, -0.1038,  0.0000, -0.8313, -1.2260,
-         0.0000,  0.0547,  0.0000,  0.0000,  0.0000,  0.0000,  0.4758,  0.3542,
-         0.1087,  0.0000,  0.3942,  0.0000,  0.0000, -0.3112,  0.0000,  0.0228,
-         0.0000,  0.0000,  0.0000,  0.0592, -0.2280,  0.3115,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6342,  0.0000,  0.0000,  0.0588,  0.0000, -0.1034],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4648,  0.0000, -0.0871, -0.0629,  2.3167,  0.3853,  0.0000,  0.0000,
-        -0.1270,  0.0572, -0.1006,  0.0000, -0.6405, -0.3233,  0.0000,  0.0000,
-         0.0000, -0.2776,  0.0000,  0.0000,  0.0254, -0.0969,  0.0000, -0.7098,
-        -0.2106,  0.0000, -0.2522, -0.0797, -0.1038,  0.0000, -0.8313, -1.2260,
-         0.0000,  0.0547,  0.0000,  0.0000,  0.0000,  0.0000,  0.4758,  0.3542,
-         0.1087,  0.0000,  0.3942,  0.0000,  0.0000, -0.3112,  0.0000,  0.0228,
-         0.0000,  0.0000,  0.0000,  0.0592, -0.2280,  0.3115,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6342,  0.0000,  0.0000,  0.0588,  0.0000, -0.1034],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3809e-01, -1.1970e-06, -9.1794e-02, -1.1416e-01,  2.3139e+00,
-         3.4374e-01,  7.2457e-16, -5.7213e-09, -1.5650e-01,  1.2470e-02,
-        -9.4811e-02,  1.3026e-07, -6.5185e-01, -3.4543e-01, -1.2994e-12,
-        -9.6573e-10, -9.4995e-11, -2.3618e-01, -2.1972e-13,  2.6417e-09,
-         3.3661e-02, -1.0059e-01, -4.6659e-13, -6.7618e-01, -2.3017e-01,
-         3.7869e-09, -2.8014e-01, -9.4572e-02, -6.9279e-02,  6.7860e-13,
-        -8.1366e-01, -1.2193e+00, -1.7371e-16,  1.1151e-01,  8.6220e-11,
-         4.6521e-09,  2.8761e-12,  0.0000e+00,  4.7927e-01,  3.5028e-01,
-         9.1960e-02, -3.8415e-10,  4.4188e-01,  7.0783e-13,  1.7313e-10,
-        -3.2384e-01, -2.4428e-08,  7.2347e-03,  8.2176e-08, -4.5062e-14,
-        -1.4580e-04,  6.6175e-02, -2.1190e-01,  3.2904e-01, -1.6513e-05,
-         9.3782e-18, -3.6858e-12,  1.3255e-15,  6.1267e-01, -1.9144e-07,
-         3.5408e-15,  1.5464e-02,  2.9750e-03, -1.1476e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4381,  0.0000, -0.0918, -0.1142,  2.3139,  0.3437,  0.0000,  0.0000,
-        -0.1565,  0.0125, -0.0948,  0.0000, -0.6519, -0.3454,  0.0000,  0.0000,
-         0.0000, -0.2362,  0.0000,  0.0000,  0.0337, -0.1006,  0.0000, -0.6762,
-        -0.2302,  0.0000, -0.2801, -0.0946, -0.0693,  0.0000, -0.8137, -1.2193,
-         0.0000,  0.1115,  0.0000,  0.0000,  0.0000,  0.0000,  0.4793,  0.3503,
-         0.0920,  0.0000,  0.4419,  0.0000,  0.0000, -0.3238,  0.0000,  0.0072,
-         0.0000,  0.0000,  0.0000,  0.0662, -0.2119,  0.3290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6127,  0.0000,  0.0000,  0.0155,  0.0000, -0.1148],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4381,  0.0000, -0.0918, -0.1142,  2.3139,  0.3437,  0.0000,  0.0000,
-        -0.1565,  0.0125, -0.0948,  0.0000, -0.6519, -0.3454,  0.0000,  0.0000,
-         0.0000, -0.2362,  0.0000,  0.0000,  0.0337, -0.1006,  0.0000, -0.6762,
-        -0.2302,  0.0000, -0.2801, -0.0946, -0.0693,  0.0000, -0.8137, -1.2193,
-         0.0000,  0.1115,  0.0000,  0.0000,  0.0000,  0.0000,  0.4793,  0.3503,
-         0.0920,  0.0000,  0.4419,  0.0000,  0.0000, -0.3238,  0.0000,  0.0072,
-         0.0000,  0.0000,  0.0000,  0.0662, -0.2119,  0.3290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6127,  0.0000,  0.0000,  0.0155,  0.0000, -0.1148],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9665e-01, -1.0620e-06, -1.1883e-01, -1.7697e-01,  2.3128e+00,
-         3.0663e-01,  6.4285e-16, -5.0760e-09, -1.8207e-01, -6.9289e-03,
-        -1.7013e-02,  1.1557e-07, -6.5844e-01, -3.5072e-01, -1.1528e-12,
-        -8.5681e-10, -8.4281e-11, -1.9958e-01, -1.9494e-13,  2.3438e-09,
-         5.7166e-03, -8.7546e-02, -4.1397e-13, -6.7957e-01, -2.3569e-01,
-         3.3598e-09, -3.2114e-01, -1.0877e-01, -4.7134e-02,  6.0206e-13,
-        -7.9370e-01, -1.2125e+00, -1.5412e-16,  1.1716e-01,  7.6496e-11,
-         4.1274e-09,  2.5517e-12,  0.0000e+00,  4.4195e-01,  2.9210e-01,
-         1.0382e-01, -3.4082e-10,  4.6339e-01,  6.2800e-13,  1.5360e-10,
-        -3.1956e-01, -2.1673e-08, -6.8540e-03,  7.2907e-08, -3.9980e-14,
-        -1.2936e-04,  4.3288e-02, -1.2778e-01,  2.9932e-01, -1.4651e-05,
-         8.3205e-18, -3.2701e-12,  1.1760e-15,  6.0421e-01, -1.6985e-07,
-         3.1415e-15, -5.1588e-02,  2.6394e-03, -9.5423e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3966,  0.0000, -0.1188, -0.1770,  2.3128,  0.3066,  0.0000,  0.0000,
-        -0.1821, -0.0069, -0.0170,  0.0000, -0.6584, -0.3507,  0.0000,  0.0000,
-         0.0000, -0.1996,  0.0000,  0.0000,  0.0057, -0.0875,  0.0000, -0.6796,
-        -0.2357,  0.0000, -0.3211, -0.1088, -0.0471,  0.0000, -0.7937, -1.2125,
-         0.0000,  0.1172,  0.0000,  0.0000,  0.0000,  0.0000,  0.4419,  0.2921,
-         0.1038,  0.0000,  0.4634,  0.0000,  0.0000, -0.3196,  0.0000, -0.0069,
-         0.0000,  0.0000,  0.0000,  0.0433, -0.1278,  0.2993,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6042,  0.0000,  0.0000, -0.0516,  0.0000, -0.0954],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3966,  0.0000, -0.1188, -0.1770,  2.3128,  0.3066,  0.0000,  0.0000,
-        -0.1821, -0.0069, -0.0170,  0.0000, -0.6584, -0.3507,  0.0000,  0.0000,
-         0.0000, -0.1996,  0.0000,  0.0000,  0.0057, -0.0875,  0.0000, -0.6796,
-        -0.2357,  0.0000, -0.3211, -0.1088, -0.0471,  0.0000, -0.7937, -1.2125,
-         0.0000,  0.1172,  0.0000,  0.0000,  0.0000,  0.0000,  0.4419,  0.2921,
-         0.1038,  0.0000,  0.4634,  0.0000,  0.0000, -0.3196,  0.0000, -0.0069,
-         0.0000,  0.0000,  0.0000,  0.0433, -0.1278,  0.2993,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6042,  0.0000,  0.0000, -0.0516,  0.0000, -0.0954],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6626e-01, -9.4261e-07, -1.3988e-01, -2.4234e-01,  2.3118e+00,
-         2.5761e-01,  5.7058e-16, -4.5054e-09, -1.9098e-01,  1.4487e-02,
-         9.5630e-02,  1.0257e-07, -6.5812e-01, -3.4965e-01, -1.0232e-12,
-        -7.6049e-10, -7.4806e-11, -1.6146e-01, -1.7302e-13,  2.0803e-09,
-        -5.9953e-02, -4.2160e-02, -3.6743e-13, -7.4778e-01, -1.9503e-01,
-         2.9821e-09, -3.5435e-01, -1.3060e-01, -6.9464e-02,  5.3438e-13,
-        -7.6912e-01, -1.2134e+00, -1.3679e-16,  9.4074e-02,  6.7896e-11,
-         3.6634e-09,  2.2648e-12,  0.0000e+00,  4.4734e-01,  2.8128e-01,
-         1.2618e-01, -3.0251e-10,  4.7720e-01,  5.5740e-13,  1.3633e-10,
-        -2.6979e-01, -1.9236e-08,  2.8863e-03,  6.4711e-08, -3.5485e-14,
-        -1.1481e-04,  7.1209e-03,  2.5226e-03,  3.3270e-01, -1.3004e-05,
-         7.3851e-18, -2.9025e-12,  1.0438e-15,  5.9260e-01, -1.5075e-07,
-         2.7883e-15, -8.8787e-02,  2.3427e-03, -5.9164e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3663,  0.0000, -0.1399, -0.2423,  2.3118,  0.2576,  0.0000,  0.0000,
-        -0.1910,  0.0145,  0.0956,  0.0000, -0.6581, -0.3496,  0.0000,  0.0000,
-         0.0000, -0.1615,  0.0000,  0.0000, -0.0600, -0.0422,  0.0000, -0.7478,
-        -0.1950,  0.0000, -0.3543, -0.1306, -0.0695,  0.0000, -0.7691, -1.2134,
-         0.0000,  0.0941,  0.0000,  0.0000,  0.0000,  0.0000,  0.4473,  0.2813,
-         0.1262,  0.0000,  0.4772,  0.0000,  0.0000, -0.2698,  0.0000,  0.0029,
-         0.0000,  0.0000,  0.0000,  0.0071,  0.0025,  0.3327,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5926,  0.0000,  0.0000, -0.0888,  0.0000, -0.0592],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3663,  0.0000, -0.1399, -0.2423,  2.3118,  0.2576,  0.0000,  0.0000,
-        -0.1910,  0.0145,  0.0956,  0.0000, -0.6581, -0.3496,  0.0000,  0.0000,
-         0.0000, -0.1615,  0.0000,  0.0000, -0.0600, -0.0422,  0.0000, -0.7478,
-        -0.1950,  0.0000, -0.3543, -0.1306, -0.0695,  0.0000, -0.7691, -1.2134,
-         0.0000,  0.0941,  0.0000,  0.0000,  0.0000,  0.0000,  0.4473,  0.2813,
-         0.1262,  0.0000,  0.4772,  0.0000,  0.0000, -0.2698,  0.0000,  0.0029,
-         0.0000,  0.0000,  0.0000,  0.0071,  0.0025,  0.3327,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5926,  0.0000,  0.0000, -0.0888,  0.0000, -0.0592],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5163e-01, -8.3699e-07, -1.4101e-01, -3.0766e-01,  2.3092e+00,
-         2.0814e-01,  5.0664e-16, -4.0005e-09, -1.7778e-01,  4.4130e-02,
-         1.4313e-01,  9.1080e-08, -6.4981e-01, -3.3032e-01, -9.0855e-13,
-        -6.7527e-10, -6.6423e-11, -1.2170e-01, -1.5363e-13,  1.8472e-09,
-        -1.4326e-01,  2.1316e-02, -3.2626e-13, -8.2848e-01, -1.4314e-01,
-         2.6479e-09, -3.8219e-01, -1.4869e-01, -1.0369e-01,  4.7450e-13,
-        -7.3434e-01, -1.2157e+00, -1.2146e-16,  1.4744e-02,  6.0288e-11,
-         3.2529e-09,  2.0111e-12,  0.0000e+00,  4.8548e-01,  3.0567e-01,
-         1.5153e-01, -2.6861e-10,  4.9691e-01,  4.9494e-13,  1.2106e-10,
-        -2.1026e-01, -1.7081e-08,  2.0250e-02,  5.7460e-08, -3.1509e-14,
-        -1.0195e-04, -1.9277e-02,  1.1692e-01,  3.3699e-01, -1.1546e-05,
-         6.5575e-18, -2.5772e-12,  9.2686e-16,  5.5769e-01, -1.3386e-07,
-         2.4759e-15, -1.1481e-01,  2.0802e-03, -2.4919e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3516,  0.0000, -0.1410, -0.3077,  2.3092,  0.2081,  0.0000,  0.0000,
-        -0.1778,  0.0441,  0.1431,  0.0000, -0.6498, -0.3303,  0.0000,  0.0000,
-         0.0000, -0.1217,  0.0000,  0.0000, -0.1433,  0.0213,  0.0000, -0.8285,
-        -0.1431,  0.0000, -0.3822, -0.1487, -0.1037,  0.0000, -0.7343, -1.2157,
-         0.0000,  0.0147,  0.0000,  0.0000,  0.0000,  0.0000,  0.4855,  0.3057,
-         0.1515,  0.0000,  0.4969,  0.0000,  0.0000, -0.2103,  0.0000,  0.0202,
-         0.0000,  0.0000,  0.0000, -0.0193,  0.1169,  0.3370,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5577,  0.0000,  0.0000, -0.1148,  0.0000, -0.0249],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3516,  0.0000, -0.1410, -0.3077,  2.3092,  0.2081,  0.0000,  0.0000,
-        -0.1778,  0.0441,  0.1431,  0.0000, -0.6498, -0.3303,  0.0000,  0.0000,
-         0.0000, -0.1217,  0.0000,  0.0000, -0.1433,  0.0213,  0.0000, -0.8285,
-        -0.1431,  0.0000, -0.3822, -0.1487, -0.1037,  0.0000, -0.7343, -1.2157,
-         0.0000,  0.0147,  0.0000,  0.0000,  0.0000,  0.0000,  0.4855,  0.3057,
-         0.1515,  0.0000,  0.4969,  0.0000,  0.0000, -0.2103,  0.0000,  0.0202,
-         0.0000,  0.0000,  0.0000, -0.0193,  0.1169,  0.3370,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5577,  0.0000,  0.0000, -0.1148,  0.0000, -0.0249],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2026e-01, -7.4350e-07, -1.2260e-01, -3.6153e-01,  2.3048e+00,
-         1.7831e-01,  4.5005e-16, -3.5537e-09, -1.4908e-01,  5.6622e-02,
-         1.7111e-02,  8.0907e-08, -6.1887e-01, -2.9864e-01, -8.0707e-13,
-        -5.9985e-10, -5.9004e-11, -8.5926e-02, -1.3647e-13,  1.6408e-09,
-        -2.0223e-01,  8.6843e-02, -2.8982e-13, -8.9501e-01, -1.1746e-01,
-         2.3522e-09, -4.0143e-01, -1.5125e-01, -1.3226e-01,  4.2150e-13,
-        -6.9859e-01, -1.2175e+00, -1.0789e-16, -6.5371e-02,  5.3554e-11,
-         2.8896e-09,  1.7864e-12,  0.0000e+00,  5.4746e-01,  3.3139e-01,
-         1.3536e-01, -2.3861e-10,  5.2889e-01,  4.3966e-13,  1.0754e-10,
-        -1.7312e-01, -1.5173e-08,  4.7647e-02,  5.1042e-08, -2.7989e-14,
-        -9.0561e-05, -2.2379e-02,  1.3828e-01,  3.0602e-01, -1.0257e-05,
-         5.8251e-18, -2.2894e-12,  8.2333e-16,  5.2799e-01, -1.1891e-07,
-         2.1993e-15, -1.2794e-01,  1.8479e-03, -4.7571e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3203,  0.0000, -0.1226, -0.3615,  2.3048,  0.1783,  0.0000,  0.0000,
-        -0.1491,  0.0566,  0.0171,  0.0000, -0.6189, -0.2986,  0.0000,  0.0000,
-         0.0000, -0.0859,  0.0000,  0.0000, -0.2022,  0.0868,  0.0000, -0.8950,
-        -0.1175,  0.0000, -0.4014, -0.1512, -0.1323,  0.0000, -0.6986, -1.2175,
-         0.0000, -0.0654,  0.0000,  0.0000,  0.0000,  0.0000,  0.5475,  0.3314,
-         0.1354,  0.0000,  0.5289,  0.0000,  0.0000, -0.1731,  0.0000,  0.0476,
-         0.0000,  0.0000,  0.0000, -0.0224,  0.1383,  0.3060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5280,  0.0000,  0.0000, -0.1279,  0.0000, -0.0476],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3203,  0.0000, -0.1226, -0.3615,  2.3048,  0.1783,  0.0000,  0.0000,
-        -0.1491,  0.0566,  0.0171,  0.0000, -0.6189, -0.2986,  0.0000,  0.0000,
-         0.0000, -0.0859,  0.0000,  0.0000, -0.2022,  0.0868,  0.0000, -0.8950,
-        -0.1175,  0.0000, -0.4014, -0.1512, -0.1323,  0.0000, -0.6986, -1.2175,
-         0.0000, -0.0654,  0.0000,  0.0000,  0.0000,  0.0000,  0.5475,  0.3314,
-         0.1354,  0.0000,  0.5289,  0.0000,  0.0000, -0.1731,  0.0000,  0.0476,
-         0.0000,  0.0000,  0.0000, -0.0224,  0.1383,  0.3060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5280,  0.0000,  0.0000, -0.1279,  0.0000, -0.0476],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9664e-01, -6.6072e-07, -1.0824e-01, -4.1468e-01,  2.3013e+00,
-         1.6635e-01,  3.9995e-16, -3.1580e-09, -1.0513e-01,  7.6155e-02,
-        -1.2161e-01,  7.1900e-08, -5.8379e-01, -2.5193e-01, -7.1722e-13,
-        -5.3306e-10, -5.2435e-11, -6.7573e-02, -1.2128e-13,  1.4582e-09,
-        -2.5201e-01,  1.2034e-01, -2.5755e-13, -9.5043e-01, -1.0458e-01,
-         2.0903e-09, -4.0914e-01, -1.6788e-01, -1.7086e-01,  3.7457e-13,
-        -6.5554e-01, -1.2178e+00, -9.5883e-17, -1.3219e-01,  4.7592e-11,
-         2.5679e-09,  1.5875e-12,  0.0000e+00,  5.9793e-01,  3.3468e-01,
-         8.9060e-02, -2.1204e-10,  5.5515e-01,  3.9071e-13,  9.5563e-11,
-        -1.1487e-01, -1.3484e-08,  7.8198e-02,  4.5359e-08, -2.4873e-14,
-        -8.0479e-05, -2.6247e-02,  1.6040e-01,  2.5762e-01, -9.1149e-06,
-         5.1766e-18, -2.0345e-12,  7.3167e-16,  5.0643e-01, -1.0567e-07,
-         1.9545e-15, -1.2414e-01,  1.6421e-03, -5.9310e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2966,  0.0000, -0.1082, -0.4147,  2.3013,  0.1664,  0.0000,  0.0000,
-        -0.1051,  0.0762, -0.1216,  0.0000, -0.5838, -0.2519,  0.0000,  0.0000,
-         0.0000, -0.0676,  0.0000,  0.0000, -0.2520,  0.1203,  0.0000, -0.9504,
-        -0.1046,  0.0000, -0.4091, -0.1679, -0.1709,  0.0000, -0.6555, -1.2178,
-         0.0000, -0.1322,  0.0000,  0.0000,  0.0000,  0.0000,  0.5979,  0.3347,
-         0.0891,  0.0000,  0.5552,  0.0000,  0.0000, -0.1149,  0.0000,  0.0782,
-         0.0000,  0.0000,  0.0000, -0.0262,  0.1604,  0.2576,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5064,  0.0000,  0.0000, -0.1241,  0.0000, -0.0593],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2966,  0.0000, -0.1082, -0.4147,  2.3013,  0.1664,  0.0000,  0.0000,
-        -0.1051,  0.0762, -0.1216,  0.0000, -0.5838, -0.2519,  0.0000,  0.0000,
-         0.0000, -0.0676,  0.0000,  0.0000, -0.2520,  0.1203,  0.0000, -0.9504,
-        -0.1046,  0.0000, -0.4091, -0.1679, -0.1709,  0.0000, -0.6555, -1.2178,
-         0.0000, -0.1322,  0.0000,  0.0000,  0.0000,  0.0000,  0.5979,  0.3347,
-         0.0891,  0.0000,  0.5552,  0.0000,  0.0000, -0.1149,  0.0000,  0.0782,
-         0.0000,  0.0000,  0.0000, -0.0262,  0.1604,  0.2576,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5064,  0.0000,  0.0000, -0.1241,  0.0000, -0.0593],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8280e-01, -5.8740e-07, -1.0122e-01, -4.4841e-01,  2.3000e+00,
-         1.7765e-01,  3.5557e-16, -2.8076e-09, -7.4145e-02,  8.3770e-02,
-        -2.5788e-01,  6.3921e-08, -5.4600e-01, -2.1871e-01, -6.3763e-13,
-        -4.7391e-10, -4.6616e-11, -4.8803e-02, -1.0782e-13,  1.2964e-09,
-        -2.6685e-01,  1.2018e-01, -2.2897e-13, -9.9958e-01, -1.1855e-01,
-         1.8583e-09, -4.0038e-01, -1.4934e-01, -1.9834e-01,  3.3301e-13,
-        -6.1313e-01, -1.2155e+00, -8.5243e-17, -1.8024e-01,  4.2310e-11,
-         2.2829e-09,  1.4114e-12,  0.0000e+00,  6.3831e-01,  3.5807e-01,
-         5.0984e-02, -1.8851e-10,  5.7329e-01,  3.4735e-13,  8.4959e-11,
-        -8.3609e-02, -1.1987e-08,  1.0715e-01,  4.0326e-08, -2.2113e-14,
-        -7.1548e-05, -6.2678e-02,  1.6188e-01,  1.6080e-01, -8.1034e-06,
-         4.6021e-18, -1.8087e-12,  6.5048e-16,  5.0180e-01, -9.3943e-08,
-         1.7376e-15, -1.1097e-01,  1.4599e-03, -5.3954e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2828,  0.0000, -0.1012, -0.4484,  2.3000,  0.1776,  0.0000,  0.0000,
-        -0.0741,  0.0838, -0.2579,  0.0000, -0.5460, -0.2187,  0.0000,  0.0000,
-         0.0000, -0.0488,  0.0000,  0.0000, -0.2668,  0.1202,  0.0000, -0.9996,
-        -0.1186,  0.0000, -0.4004, -0.1493, -0.1983,  0.0000, -0.6131, -1.2155,
-         0.0000, -0.1802,  0.0000,  0.0000,  0.0000,  0.0000,  0.6383,  0.3581,
-         0.0510,  0.0000,  0.5733,  0.0000,  0.0000, -0.0836,  0.0000,  0.1072,
-         0.0000,  0.0000,  0.0000, -0.0627,  0.1619,  0.1608,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5018,  0.0000,  0.0000, -0.1110,  0.0000, -0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2828,  0.0000, -0.1012, -0.4484,  2.3000,  0.1776,  0.0000,  0.0000,
-        -0.0741,  0.0838, -0.2579,  0.0000, -0.5460, -0.2187,  0.0000,  0.0000,
-         0.0000, -0.0488,  0.0000,  0.0000, -0.2668,  0.1202,  0.0000, -0.9996,
-        -0.1186,  0.0000, -0.4004, -0.1493, -0.1983,  0.0000, -0.6131, -1.2155,
-         0.0000, -0.1802,  0.0000,  0.0000,  0.0000,  0.0000,  0.6383,  0.3581,
-         0.0510,  0.0000,  0.5733,  0.0000,  0.0000, -0.0836,  0.0000,  0.1072,
-         0.0000,  0.0000,  0.0000, -0.0627,  0.1619,  0.1608,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5018,  0.0000,  0.0000, -0.1110,  0.0000, -0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8112e-01, -5.2243e-07, -1.2803e-01, -4.6797e-01,  2.2969e+00,
-         1.9022e-01,  3.1624e-16, -2.4971e-09, -5.8704e-02,  1.3999e-01,
-        -2.7864e-01,  5.6851e-08, -5.3901e-01, -2.0743e-01, -5.6711e-13,
-        -4.2149e-10, -4.1460e-11, -1.5949e-02, -9.5895e-14,  1.1530e-09,
-        -2.7831e-01,  1.5445e-01, -2.0364e-13, -1.0245e+00, -7.3648e-02,
-         1.6528e-09, -4.1163e-01, -9.1099e-02, -2.4096e-01,  2.9617e-13,
-        -5.7702e-01, -1.2132e+00, -7.5814e-17, -2.3411e-01,  3.7631e-11,
-         2.0304e-09,  1.2553e-12,  0.0000e+00,  6.5304e-01,  3.8866e-01,
-         8.6875e-02, -1.6766e-10,  5.7269e-01,  3.0893e-13,  7.5562e-11,
-        -7.6500e-02, -1.0662e-08,  1.1793e-01,  3.5865e-08, -1.9667e-14,
-        -6.3634e-05, -7.7651e-02,  1.7544e-01,  4.7074e-02, -7.2071e-06,
-         4.0931e-18, -1.6087e-12,  5.7853e-16,  4.7874e-01, -8.3552e-08,
-         1.5454e-15, -1.1302e-01,  1.2984e-03, -1.3100e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2811,  0.0000, -0.1280, -0.4680,  2.2969,  0.1902,  0.0000,  0.0000,
-        -0.0587,  0.1400, -0.2786,  0.0000, -0.5390, -0.2074,  0.0000,  0.0000,
-         0.0000, -0.0159,  0.0000,  0.0000, -0.2783,  0.1544,  0.0000, -1.0245,
-        -0.0736,  0.0000, -0.4116, -0.0911, -0.2410,  0.0000, -0.5770, -1.2132,
-         0.0000, -0.2341,  0.0000,  0.0000,  0.0000,  0.0000,  0.6530,  0.3887,
-         0.0869,  0.0000,  0.5727,  0.0000,  0.0000, -0.0765,  0.0000,  0.1179,
-         0.0000,  0.0000,  0.0000, -0.0777,  0.1754,  0.0471,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4787,  0.0000,  0.0000, -0.1130,  0.0000, -0.0131],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2811,  0.0000, -0.1280, -0.4680,  2.2969,  0.1902,  0.0000,  0.0000,
-        -0.0587,  0.1400, -0.2786,  0.0000, -0.5390, -0.2074,  0.0000,  0.0000,
-         0.0000, -0.0159,  0.0000,  0.0000, -0.2783,  0.1544,  0.0000, -1.0245,
-        -0.0736,  0.0000, -0.4116, -0.0911, -0.2410,  0.0000, -0.5770, -1.2132,
-         0.0000, -0.2341,  0.0000,  0.0000,  0.0000,  0.0000,  0.6530,  0.3887,
-         0.0869,  0.0000,  0.5727,  0.0000,  0.0000, -0.0765,  0.0000,  0.1179,
-         0.0000,  0.0000,  0.0000, -0.0777,  0.1754,  0.0471,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4787,  0.0000,  0.0000, -0.1130,  0.0000, -0.0131],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8642e-01, -4.6484e-07, -1.4974e-01, -4.7644e-01,  2.2911e+00,
-         2.0491e-01,  2.8138e-16, -2.2218e-09, -2.3243e-02,  1.8925e-01,
-        -2.9361e-01,  5.0583e-08, -5.3654e-01, -1.9141e-01, -5.0459e-13,
-        -3.7503e-10, -3.6890e-11,  3.6324e-02, -8.5323e-14,  1.0259e-09,
-        -2.7546e-01,  1.9098e-01, -1.8119e-13, -1.0431e+00, -9.8616e-03,
-         1.4706e-09, -4.1867e-01, -1.8085e-02, -2.7881e-01,  2.6352e-13,
-        -5.4406e-01, -1.2111e+00, -6.7456e-17, -2.5393e-01,  3.3482e-11,
-         1.8066e-09,  1.1169e-12,  0.0000e+00,  6.5451e-01,  4.0938e-01,
-         1.2548e-01, -1.4918e-10,  5.7783e-01,  2.7487e-13,  6.7232e-11,
-        -9.8518e-02, -9.4862e-09,  9.9108e-02,  3.1912e-08, -1.7499e-14,
-        -5.6619e-05, -6.0057e-02,  1.6498e-01, -4.3738e-03, -6.4126e-06,
-         3.6419e-18, -1.4313e-12,  5.1475e-16,  4.7400e-01, -7.4341e-08,
-         1.3750e-15, -9.8273e-02,  1.1553e-03,  1.3533e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2864,  0.0000, -0.1497, -0.4764,  2.2911,  0.2049,  0.0000,  0.0000,
-        -0.0232,  0.1893, -0.2936,  0.0000, -0.5365, -0.1914,  0.0000,  0.0000,
-         0.0000,  0.0363,  0.0000,  0.0000, -0.2755,  0.1910,  0.0000, -1.0431,
-        -0.0099,  0.0000, -0.4187, -0.0181, -0.2788,  0.0000, -0.5441, -1.2111,
-         0.0000, -0.2539,  0.0000,  0.0000,  0.0000,  0.0000,  0.6545,  0.4094,
-         0.1255,  0.0000,  0.5778,  0.0000,  0.0000, -0.0985,  0.0000,  0.0991,
-         0.0000,  0.0000,  0.0000, -0.0601,  0.1650, -0.0044,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4740,  0.0000,  0.0000, -0.0983,  0.0000,  0.0135],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2864,  0.0000, -0.1497, -0.4764,  2.2911,  0.2049,  0.0000,  0.0000,
-        -0.0232,  0.1893, -0.2936,  0.0000, -0.5365, -0.1914,  0.0000,  0.0000,
-         0.0000,  0.0363,  0.0000,  0.0000, -0.2755,  0.1910,  0.0000, -1.0431,
-        -0.0099,  0.0000, -0.4187, -0.0181, -0.2788,  0.0000, -0.5441, -1.2111,
-         0.0000, -0.2539,  0.0000,  0.0000,  0.0000,  0.0000,  0.6545,  0.4094,
-         0.1255,  0.0000,  0.5778,  0.0000,  0.0000, -0.0985,  0.0000,  0.0991,
-         0.0000,  0.0000,  0.0000, -0.0601,  0.1650, -0.0044,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4740,  0.0000,  0.0000, -0.0983,  0.0000,  0.0135],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7586e-01, -4.1376e-07, -1.7958e-01, -4.6902e-01,  2.2849e+00,
-         2.1220e-01,  2.5046e-16, -1.9776e-09, -9.0246e-03,  2.3860e-01,
-        -2.7725e-01,  4.5025e-08, -5.4477e-01, -1.7788e-01, -4.4914e-13,
-        -3.3382e-10, -3.2836e-11,  1.0564e-01, -7.5948e-14,  9.1314e-10,
-        -2.5442e-01,  2.3527e-01, -1.6128e-13, -1.0504e+00,  5.9368e-02,
-         1.3090e-09, -4.2495e-01,  6.1479e-02, -2.8450e-01,  2.3457e-13,
-        -5.2342e-01, -1.2064e+00, -6.0044e-17, -2.6093e-01,  2.9803e-11,
-         1.6081e-09,  9.9416e-13,  0.0000e+00,  6.2987e-01,  4.0831e-01,
-         1.8343e-01, -1.3279e-10,  5.7149e-01,  2.4467e-13,  5.9844e-11,
-        -1.2375e-01, -8.4438e-09,  7.4320e-02,  2.8405e-08, -1.5576e-14,
-        -5.0398e-05, -3.3927e-02,  1.8479e-01, -3.3131e-02, -5.7080e-06,
-         3.2417e-18, -1.2740e-12,  4.5819e-16,  4.7286e-01, -6.6173e-08,
-         1.2239e-15, -7.7419e-02,  1.0283e-03,  5.6428e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2759,  0.0000, -0.1796, -0.4690,  2.2849,  0.2122,  0.0000,  0.0000,
-        -0.0090,  0.2386, -0.2773,  0.0000, -0.5448, -0.1779,  0.0000,  0.0000,
-         0.0000,  0.1056,  0.0000,  0.0000, -0.2544,  0.2353,  0.0000, -1.0504,
-         0.0594,  0.0000, -0.4249,  0.0615, -0.2845,  0.0000, -0.5234, -1.2064,
-         0.0000, -0.2609,  0.0000,  0.0000,  0.0000,  0.0000,  0.6299,  0.4083,
-         0.1834,  0.0000,  0.5715,  0.0000,  0.0000, -0.1237,  0.0000,  0.0743,
-         0.0000,  0.0000,  0.0000, -0.0339,  0.1848, -0.0331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4729,  0.0000,  0.0000, -0.0774,  0.0000,  0.0564],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2759,  0.0000, -0.1796, -0.4690,  2.2849,  0.2122,  0.0000,  0.0000,
-        -0.0090,  0.2386, -0.2773,  0.0000, -0.5448, -0.1779,  0.0000,  0.0000,
-         0.0000,  0.1056,  0.0000,  0.0000, -0.2544,  0.2353,  0.0000, -1.0504,
-         0.0594,  0.0000, -0.4249,  0.0615, -0.2845,  0.0000, -0.5234, -1.2064,
-         0.0000, -0.2609,  0.0000,  0.0000,  0.0000,  0.0000,  0.6299,  0.4083,
-         0.1834,  0.0000,  0.5715,  0.0000,  0.0000, -0.1237,  0.0000,  0.0743,
-         0.0000,  0.0000,  0.0000, -0.0339,  0.1848, -0.0331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4729,  0.0000,  0.0000, -0.0774,  0.0000,  0.0564],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6224e-01, -3.6845e-07, -1.9650e-01, -4.5142e-01,  2.2790e+00,
-         2.2052e-01,  2.2303e-16, -1.7611e-09,  1.6175e-02,  2.6875e-01,
-        -2.6246e-01,  4.0094e-08, -5.4218e-01, -1.5701e-01, -3.9995e-13,
-        -2.9726e-10, -2.9240e-11,  1.2846e-01, -6.7630e-14,  8.1313e-10,
-        -2.3278e-01,  2.6426e-01, -1.4362e-13, -1.0495e+00,  1.1155e-01,
-         1.1656e-09, -4.2862e-01,  1.1538e-01, -2.9444e-01,  2.0888e-13,
-        -5.1047e-01, -1.1998e+00, -5.3468e-17, -2.5402e-01,  2.6539e-11,
-         1.4319e-09,  8.8528e-13,  0.0000e+00,  6.0419e-01,  4.1438e-01,
-         2.3250e-01, -1.1824e-10,  5.5972e-01,  2.1787e-13,  5.3290e-11,
-        -1.5970e-01, -7.5191e-09,  7.7168e-02,  2.5294e-08, -1.3870e-14,
-        -4.4878e-05,  2.6006e-02,  1.3175e-01, -3.6325e-02, -5.0828e-06,
-         2.8867e-18, -1.1345e-12,  4.0801e-16,  4.8681e-01, -5.8925e-08,
-         1.0899e-15, -6.1548e-02,  9.1572e-04,  4.8569e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2622,  0.0000, -0.1965, -0.4514,  2.2790,  0.2205,  0.0000,  0.0000,
-         0.0162,  0.2687, -0.2625,  0.0000, -0.5422, -0.1570,  0.0000,  0.0000,
-         0.0000,  0.1285,  0.0000,  0.0000, -0.2328,  0.2643,  0.0000, -1.0495,
-         0.1115,  0.0000, -0.4286,  0.1154, -0.2944,  0.0000, -0.5105, -1.1998,
-         0.0000, -0.2540,  0.0000,  0.0000,  0.0000,  0.0000,  0.6042,  0.4144,
-         0.2325,  0.0000,  0.5597,  0.0000,  0.0000, -0.1597,  0.0000,  0.0772,
-         0.0000,  0.0000,  0.0000,  0.0260,  0.1317, -0.0363,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4868,  0.0000,  0.0000, -0.0615,  0.0000,  0.0486],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2622,  0.0000, -0.1965, -0.4514,  2.2790,  0.2205,  0.0000,  0.0000,
-         0.0162,  0.2687, -0.2625,  0.0000, -0.5422, -0.1570,  0.0000,  0.0000,
-         0.0000,  0.1285,  0.0000,  0.0000, -0.2328,  0.2643,  0.0000, -1.0495,
-         0.1115,  0.0000, -0.4286,  0.1154, -0.2944,  0.0000, -0.5105, -1.1998,
-         0.0000, -0.2540,  0.0000,  0.0000,  0.0000,  0.0000,  0.6042,  0.4144,
-         0.2325,  0.0000,  0.5597,  0.0000,  0.0000, -0.1597,  0.0000,  0.0772,
-         0.0000,  0.0000,  0.0000,  0.0260,  0.1317, -0.0363,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4868,  0.0000,  0.0000, -0.0615,  0.0000,  0.0486],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4729e-01, -3.2823e-07, -1.7901e-01, -4.0970e-01,  2.2740e+00,
-         2.4470e-01,  1.9868e-16, -1.5688e-09,  3.2867e-02,  2.8244e-01,
-        -2.6277e-01,  3.5718e-08, -5.1528e-01, -1.2659e-01, -3.5629e-13,
-        -2.6481e-10, -2.6048e-11,  1.4875e-01, -6.0247e-14,  7.2437e-10,
-        -1.9626e-01,  2.4790e-01, -1.2794e-13, -1.0417e+00,  1.2096e-01,
-         1.0384e-09, -4.1846e-01,  1.5017e-01, -2.7403e-01,  1.8608e-13,
-        -5.2147e-01, -1.1907e+00, -4.7632e-17, -2.3139e-01,  2.3642e-11,
-         1.2756e-09,  7.8864e-13,  0.0000e+00,  5.6864e-01,  4.1168e-01,
-         2.1829e-01, -1.0534e-10,  5.5237e-01,  1.9409e-13,  4.7473e-11,
-        -2.0708e-01, -6.6983e-09,  9.8782e-02,  2.2533e-08, -1.2356e-14,
-        -3.9979e-05,  7.1602e-02,  5.5880e-02, -1.7080e-02, -4.5280e-06,
-         2.5716e-18, -1.0107e-12,  3.6347e-16,  5.0786e-01, -5.2493e-08,
-         9.7092e-16, -2.6973e-02,  8.1576e-04,  3.9182e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2473,  0.0000, -0.1790, -0.4097,  2.2740,  0.2447,  0.0000,  0.0000,
-         0.0329,  0.2824, -0.2628,  0.0000, -0.5153, -0.1266,  0.0000,  0.0000,
-         0.0000,  0.1488,  0.0000,  0.0000, -0.1963,  0.2479,  0.0000, -1.0417,
-         0.1210,  0.0000, -0.4185,  0.1502, -0.2740,  0.0000, -0.5215, -1.1907,
-         0.0000, -0.2314,  0.0000,  0.0000,  0.0000,  0.0000,  0.5686,  0.4117,
-         0.2183,  0.0000,  0.5524,  0.0000,  0.0000, -0.2071,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716,  0.0559, -0.0171,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5079,  0.0000,  0.0000, -0.0270,  0.0000,  0.0392],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2473,  0.0000, -0.1790, -0.4097,  2.2740,  0.2447,  0.0000,  0.0000,
-         0.0329,  0.2824, -0.2628,  0.0000, -0.5153, -0.1266,  0.0000,  0.0000,
-         0.0000,  0.1488,  0.0000,  0.0000, -0.1963,  0.2479,  0.0000, -1.0417,
-         0.1210,  0.0000, -0.4185,  0.1502, -0.2740,  0.0000, -0.5215, -1.1907,
-         0.0000, -0.2314,  0.0000,  0.0000,  0.0000,  0.0000,  0.5686,  0.4117,
-         0.2183,  0.0000,  0.5524,  0.0000,  0.0000, -0.2071,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716,  0.0559, -0.0171,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5079,  0.0000,  0.0000, -0.0270,  0.0000,  0.0392],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4479e-01, -2.9252e-07, -1.1882e-01, -3.8066e-01,  2.2670e+00,
-         2.6597e-01,  1.7707e-16, -1.3981e-09,  4.7311e-02,  2.6543e-01,
-        -2.8268e-01,  3.1832e-08, -4.7664e-01, -1.1953e-01, -3.1753e-13,
-        -2.3600e-10, -2.3214e-11,  1.1542e-01, -5.3693e-14,  6.4556e-10,
-        -1.6699e-01,  1.9691e-01, -1.1402e-13, -1.0153e+00,  9.2774e-02,
-         9.2542e-10, -3.9104e-01,  1.4197e-01, -2.2725e-01,  1.6583e-13,
-        -5.4870e-01, -1.1838e+00, -4.2449e-17, -1.7033e-01,  2.1070e-11,
-         1.1368e-09,  7.0284e-13,  0.0000e+00,  5.0972e-01,  4.0626e-01,
-         1.5430e-01, -9.3876e-11,  5.5455e-01,  1.7298e-13,  4.2308e-11,
-        -2.3094e-01, -5.9695e-09,  1.9262e-02,  2.0082e-08, -1.1012e-14,
-        -3.5630e-05,  1.1576e-01,  1.3987e-02, -3.6454e-02, -4.0354e-06,
-         2.2918e-18, -9.0071e-13,  3.2393e-16,  5.1283e-01, -4.6782e-08,
-         8.6529e-16,  2.2108e-02,  7.2701e-04,  3.3254e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2448,  0.0000, -0.1188, -0.3807,  2.2670,  0.2660,  0.0000,  0.0000,
-         0.0473,  0.2654, -0.2827,  0.0000, -0.4766, -0.1195,  0.0000,  0.0000,
-         0.0000,  0.1154,  0.0000,  0.0000, -0.1670,  0.1969,  0.0000, -1.0153,
-         0.0928,  0.0000, -0.3910,  0.1420, -0.2272,  0.0000, -0.5487, -1.1838,
-         0.0000, -0.1703,  0.0000,  0.0000,  0.0000,  0.0000,  0.5097,  0.4063,
-         0.1543,  0.0000,  0.5546,  0.0000,  0.0000, -0.2309,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1158,  0.0140, -0.0365,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5128,  0.0000,  0.0000,  0.0221,  0.0000,  0.0333],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2448,  0.0000, -0.1188, -0.3807,  2.2670,  0.2660,  0.0000,  0.0000,
-         0.0473,  0.2654, -0.2827,  0.0000, -0.4766, -0.1195,  0.0000,  0.0000,
-         0.0000,  0.1154,  0.0000,  0.0000, -0.1670,  0.1969,  0.0000, -1.0153,
-         0.0928,  0.0000, -0.3910,  0.1420, -0.2272,  0.0000, -0.5487, -1.1838,
-         0.0000, -0.1703,  0.0000,  0.0000,  0.0000,  0.0000,  0.5097,  0.4063,
-         0.1543,  0.0000,  0.5546,  0.0000,  0.0000, -0.2309,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1158,  0.0140, -0.0365,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5128,  0.0000,  0.0000,  0.0221,  0.0000,  0.0333],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4992e-01, -2.6080e-07, -3.5197e-02, -3.6707e-01,  2.2632e+00,
-         3.0179e-01,  1.5787e-16, -1.2465e-09,  1.8097e-02,  2.5864e-01,
-        -2.7572e-01,  2.8380e-08, -4.4054e-01, -1.3543e-01, -2.8310e-13,
-        -2.1041e-10, -2.0697e-11,  3.7034e-02, -4.7870e-14,  5.7556e-10,
-        -1.4987e-01,  1.2770e-01, -1.0166e-13, -9.7505e-01,  4.7262e-02,
-         8.2507e-10, -3.6630e-01,  1.1017e-01, -1.7136e-01,  1.4785e-13,
-        -5.7439e-01, -1.1744e+00, -3.7846e-17, -8.6629e-02,  1.8785e-11,
-         1.0136e-09,  6.2663e-13,  0.0000e+00,  4.3167e-01,  3.9812e-01,
-         6.1886e-02, -8.3696e-11,  5.6700e-01,  1.5422e-13,  3.7720e-11,
-        -2.7326e-01, -5.3222e-09,  1.7173e-02,  1.7904e-08, -9.8179e-15,
-        -3.1766e-05,  1.4870e-01, -3.1268e-02, -1.0266e-01, -3.5978e-06,
-         2.0433e-18, -8.0304e-13,  2.8880e-16,  5.0206e-01, -4.1709e-08,
-         7.7146e-16,  4.1251e-02,  6.4817e-04,  3.6656e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2499,  0.0000, -0.0352, -0.3671,  2.2632,  0.3018,  0.0000,  0.0000,
-         0.0181,  0.2586, -0.2757,  0.0000, -0.4405, -0.1354,  0.0000,  0.0000,
-         0.0000,  0.0370,  0.0000,  0.0000, -0.1499,  0.1277,  0.0000, -0.9751,
-         0.0473,  0.0000, -0.3663,  0.1102, -0.1714,  0.0000, -0.5744, -1.1744,
-         0.0000, -0.0866,  0.0000,  0.0000,  0.0000,  0.0000,  0.4317,  0.3981,
-         0.0619,  0.0000,  0.5670,  0.0000,  0.0000, -0.2733,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1487, -0.0313, -0.1027,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5021,  0.0000,  0.0000,  0.0413,  0.0000,  0.0367],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2499,  0.0000, -0.0352, -0.3671,  2.2632,  0.3018,  0.0000,  0.0000,
-         0.0181,  0.2586, -0.2757,  0.0000, -0.4405, -0.1354,  0.0000,  0.0000,
-         0.0000,  0.0370,  0.0000,  0.0000, -0.1499,  0.1277,  0.0000, -0.9751,
-         0.0473,  0.0000, -0.3663,  0.1102, -0.1714,  0.0000, -0.5744, -1.1744,
-         0.0000, -0.0866,  0.0000,  0.0000,  0.0000,  0.0000,  0.4317,  0.3981,
-         0.0619,  0.0000,  0.5670,  0.0000,  0.0000, -0.2733,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1487, -0.0313, -0.1027,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5021,  0.0000,  0.0000,  0.0413,  0.0000,  0.0367],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6803e-01, -2.3261e-07,  4.8139e-02, -3.5141e-01,  2.2595e+00,
-         3.1105e-01,  1.4080e-16, -1.1118e-09, -2.4896e-02,  2.4683e-01,
-        -2.2702e-01,  2.5313e-08, -3.9527e-01, -1.7536e-01, -2.5250e-13,
-        -1.8767e-10, -1.8460e-11, -3.8987e-02, -4.2697e-14,  5.1336e-10,
-        -1.3224e-01,  6.4320e-02, -9.0672e-14, -9.1843e-01,  1.0646e-02,
-         7.3590e-10, -3.4159e-01,  7.1687e-02, -1.1190e-01,  1.3187e-13,
-        -6.1279e-01, -1.1612e+00, -3.3756e-17,  2.5557e-03,  1.6755e-11,
-         9.0403e-10,  5.5890e-13,  0.0000e+00,  3.6115e-01,  4.0053e-01,
-        -4.1032e-04, -7.4651e-11,  5.7768e-01,  1.3755e-13,  3.3643e-11,
-        -3.3263e-01, -4.7470e-09,  1.5317e-02,  1.5969e-08, -8.7568e-15,
-        -2.8333e-05,  2.0060e-01, -5.1673e-02, -1.4239e-01, -3.2089e-06,
-         1.8224e-18, -7.1625e-13,  2.5759e-16,  4.5233e-01, -3.7201e-08,
-         6.8808e-16,  5.7055e-02,  5.7812e-04,  6.7388e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.6803e-01,  0.0000e+00,  4.8139e-02, -3.5141e-01,  2.2595e+00,
-         3.1105e-01,  0.0000e+00,  0.0000e+00, -2.4896e-02,  2.4683e-01,
-        -2.2702e-01,  0.0000e+00, -3.9527e-01, -1.7536e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.8987e-02,  0.0000e+00,  0.0000e+00,
-        -1.3224e-01,  6.4320e-02,  0.0000e+00, -9.1843e-01,  1.0646e-02,
-         0.0000e+00, -3.4159e-01,  7.1687e-02, -1.1190e-01,  0.0000e+00,
-        -6.1279e-01, -1.1612e+00,  0.0000e+00,  2.5557e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6115e-01,  4.0053e-01,
-        -4.1032e-04,  0.0000e+00,  5.7768e-01,  0.0000e+00,  0.0000e+00,
-        -3.3263e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.0060e-01, -5.1673e-02, -1.4239e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.5233e-01,  0.0000e+00,
-         0.0000e+00,  5.7055e-02,  0.0000e+00,  6.7388e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.6803e-01,  0.0000e+00,  4.8139e-02, -3.5141e-01,  2.2595e+00,
-         3.1105e-01,  0.0000e+00,  0.0000e+00, -2.4896e-02,  2.4683e-01,
-        -2.2702e-01,  0.0000e+00, -3.9527e-01, -1.7536e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.8987e-02,  0.0000e+00,  0.0000e+00,
-        -1.3224e-01,  6.4320e-02,  0.0000e+00, -9.1843e-01,  1.0646e-02,
-         0.0000e+00, -3.4159e-01,  7.1687e-02, -1.1190e-01,  0.0000e+00,
-        -6.1279e-01, -1.1612e+00,  0.0000e+00,  2.5557e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6115e-01,  4.0053e-01,
-        -4.1032e-04,  0.0000e+00,  5.7768e-01,  0.0000e+00,  0.0000e+00,
-        -3.3263e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.0060e-01, -5.1673e-02, -1.4239e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.5233e-01,  0.0000e+00,
-         0.0000e+00,  5.7055e-02,  0.0000e+00,  6.7388e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7356e-01, -2.0755e-07,  9.9293e-02, -3.2322e-01,  2.2570e+00,
-         3.3381e-01,  1.2564e-16, -9.9204e-10, -9.8128e-02,  2.4154e-01,
-        -1.8237e-01,  2.2586e-08, -4.0294e-01, -2.5021e-01, -2.2530e-13,
-        -1.6745e-10, -1.6472e-11, -5.8983e-02, -3.8097e-14,  4.5806e-10,
-        -1.3613e-01,  1.1942e-02, -8.0905e-14, -8.6811e-01, -5.4647e-03,
-         6.5663e-10, -2.9492e-01,  2.5949e-02, -4.6723e-02,  1.1766e-13,
-        -5.9608e-01, -1.1521e+00, -3.0120e-17,  8.9923e-02,  1.4950e-11,
-         8.0665e-10,  4.9870e-13,  0.0000e+00,  2.8769e-01,  4.0196e-01,
-        -5.7089e-02, -6.6609e-11,  5.7148e-01,  1.2273e-13,  3.0019e-11,
-        -3.9640e-01, -4.2357e-09,  1.3667e-02,  1.4249e-08, -7.8135e-15,
-        -2.5281e-05,  1.9112e-01, -9.4021e-02, -1.4426e-02, -2.8633e-06,
-         1.6261e-18, -6.3909e-13,  2.2984e-16,  4.0475e-01, -3.3194e-08,
-         6.1396e-16,  7.3316e-02,  5.1585e-04,  9.2102e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2736,  0.0000,  0.0993, -0.3232,  2.2570,  0.3338,  0.0000,  0.0000,
-        -0.0981,  0.2415, -0.1824,  0.0000, -0.4029, -0.2502,  0.0000,  0.0000,
-         0.0000, -0.0590,  0.0000,  0.0000, -0.1361,  0.0119,  0.0000, -0.8681,
-        -0.0055,  0.0000, -0.2949,  0.0259, -0.0467,  0.0000, -0.5961, -1.1521,
-         0.0000,  0.0899,  0.0000,  0.0000,  0.0000,  0.0000,  0.2877,  0.4020,
-        -0.0571,  0.0000,  0.5715,  0.0000,  0.0000, -0.3964,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1911, -0.0940, -0.0144,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4048,  0.0000,  0.0000,  0.0733,  0.0000,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2736,  0.0000,  0.0993, -0.3232,  2.2570,  0.3338,  0.0000,  0.0000,
-        -0.0981,  0.2415, -0.1824,  0.0000, -0.4029, -0.2502,  0.0000,  0.0000,
-         0.0000, -0.0590,  0.0000,  0.0000, -0.1361,  0.0119,  0.0000, -0.8681,
-        -0.0055,  0.0000, -0.2949,  0.0259, -0.0467,  0.0000, -0.5961, -1.1521,
-         0.0000,  0.0899,  0.0000,  0.0000,  0.0000,  0.0000,  0.2877,  0.4020,
-        -0.0571,  0.0000,  0.5715,  0.0000,  0.0000, -0.3964,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1911, -0.0940, -0.0144,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4048,  0.0000,  0.0000,  0.0733,  0.0000,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6572e-01, -1.8527e-07,  9.8344e-02, -2.9523e-01,  2.2545e+00,
-         3.1931e-01,  1.1215e-16, -8.8554e-10, -1.8339e-01,  2.3818e-01,
-        -1.2468e-01,  2.0161e-08, -4.5262e-01, -3.1846e-01, -2.0111e-13,
-        -1.4947e-10, -1.4703e-11, -5.5978e-02, -3.4007e-14,  4.0888e-10,
-        -1.2456e-01, -4.6265e-02, -7.2219e-14, -8.1916e-01,  2.3146e-02,
-         5.8614e-10, -2.5415e-01,  2.6548e-02,  4.6869e-02,  1.0503e-13,
-        -5.6240e-01, -1.1474e+00, -2.6886e-17,  1.5230e-01,  1.3345e-11,
-         7.2005e-10,  4.4516e-13,  0.0000e+00,  2.3427e-01,  4.1774e-01,
-        -4.9406e-02, -5.9458e-11,  5.4426e-01,  1.0956e-13,  2.6797e-11,
-        -4.4156e-01, -3.7809e-09,  1.2200e-02,  1.2719e-08, -6.9747e-15,
-        -2.2567e-05,  1.8730e-01, -1.2091e-01,  8.0518e-02, -2.5559e-06,
-         1.4515e-18, -5.7048e-13,  2.0517e-16,  3.5350e-01, -2.9630e-08,
-         5.4805e-16,  4.2613e-02,  4.6046e-04,  5.9167e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2657,  0.0000,  0.0983, -0.2952,  2.2545,  0.3193,  0.0000,  0.0000,
-        -0.1834,  0.2382, -0.1247,  0.0000, -0.4526, -0.3185,  0.0000,  0.0000,
-         0.0000, -0.0560,  0.0000,  0.0000, -0.1246, -0.0463,  0.0000, -0.8192,
-         0.0231,  0.0000, -0.2541,  0.0265,  0.0469,  0.0000, -0.5624, -1.1474,
-         0.0000,  0.1523,  0.0000,  0.0000,  0.0000,  0.0000,  0.2343,  0.4177,
-        -0.0494,  0.0000,  0.5443,  0.0000,  0.0000, -0.4416,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1873, -0.1209,  0.0805,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3535,  0.0000,  0.0000,  0.0426,  0.0000,  0.0592],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2657,  0.0000,  0.0983, -0.2952,  2.2545,  0.3193,  0.0000,  0.0000,
-        -0.1834,  0.2382, -0.1247,  0.0000, -0.4526, -0.3185,  0.0000,  0.0000,
-         0.0000, -0.0560,  0.0000,  0.0000, -0.1246, -0.0463,  0.0000, -0.8192,
-         0.0231,  0.0000, -0.2541,  0.0265,  0.0469,  0.0000, -0.5624, -1.1474,
-         0.0000,  0.1523,  0.0000,  0.0000,  0.0000,  0.0000,  0.2343,  0.4177,
-        -0.0494,  0.0000,  0.5443,  0.0000,  0.0000, -0.4416,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1873, -0.1209,  0.0805,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3535,  0.0000,  0.0000,  0.0426,  0.0000,  0.0592],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7571e-01, -1.6545e-07,  8.0222e-02, -2.6634e-01,  2.2515e+00,
-         2.8470e-01,  1.0015e-16, -7.9078e-10, -2.5568e-01,  2.0939e-01,
-        -6.5277e-02,  1.8004e-08, -4.8120e-01, -3.7952e-01, -1.7959e-13,
-        -1.3348e-10, -1.3130e-11, -8.5133e-02, -3.0369e-14,  3.6513e-10,
-        -1.0325e-01, -1.1270e-01, -6.4491e-14, -7.7804e-01,  4.6319e-02,
-         5.2342e-10, -2.1807e-01,  4.7850e-02,  1.4303e-01,  9.3794e-14,
-        -5.4595e-01, -1.1440e+00, -2.4009e-17,  1.9695e-01,  1.1917e-11,
-         6.4300e-10,  3.9753e-13,  0.0000e+00,  1.7851e-01,  4.4968e-01,
-        -1.4693e-02, -5.3096e-11,  5.0782e-01,  9.7835e-14,  2.3929e-11,
-        -4.7855e-01, -3.3764e-09,  1.0895e-02,  1.1358e-08, -6.2284e-15,
-        -2.0152e-05,  2.0034e-01, -1.3090e-01,  8.2256e-02, -2.2824e-06,
-         1.2962e-18, -5.0944e-13,  1.8321e-16,  2.9210e-01, -2.6460e-08,
-         4.8940e-16, -9.6377e-03,  4.1119e-04,  3.0739e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2757,  0.0000,  0.0802, -0.2663,  2.2515,  0.2847,  0.0000,  0.0000,
-        -0.2557,  0.2094, -0.0653,  0.0000, -0.4812, -0.3795,  0.0000,  0.0000,
-         0.0000, -0.0851,  0.0000,  0.0000, -0.1033, -0.1127,  0.0000, -0.7780,
-         0.0463,  0.0000, -0.2181,  0.0479,  0.1430,  0.0000, -0.5459, -1.1440,
-         0.0000,  0.1970,  0.0000,  0.0000,  0.0000,  0.0000,  0.1785,  0.4497,
-        -0.0147,  0.0000,  0.5078,  0.0000,  0.0000, -0.4785,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2003, -0.1309,  0.0823,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2921,  0.0000,  0.0000, -0.0096,  0.0000,  0.0031],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2757,  0.0000,  0.0802, -0.2663,  2.2515,  0.2847,  0.0000,  0.0000,
-        -0.2557,  0.2094, -0.0653,  0.0000, -0.4812, -0.3795,  0.0000,  0.0000,
-         0.0000, -0.0851,  0.0000,  0.0000, -0.1033, -0.1127,  0.0000, -0.7780,
-         0.0463,  0.0000, -0.2181,  0.0479,  0.1430,  0.0000, -0.5459, -1.1440,
-         0.0000,  0.1970,  0.0000,  0.0000,  0.0000,  0.0000,  0.1785,  0.4497,
-        -0.0147,  0.0000,  0.5078,  0.0000,  0.0000, -0.4785,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2003, -0.1309,  0.0823,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2921,  0.0000,  0.0000, -0.0096,  0.0000,  0.0031],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7288e-01, -1.4780e-07,  4.7549e-02, -2.5731e-01,  2.2471e+00,
-         2.3969e-01,  8.9468e-17, -7.0645e-10, -3.0470e-01,  1.5107e-01,
-         8.2497e-03,  1.6084e-08, -4.9956e-01, -4.3502e-01, -1.6044e-13,
-        -1.1925e-10, -1.1730e-11, -1.2998e-01, -2.7130e-14,  3.2619e-10,
-        -6.7593e-02, -1.6429e-01, -5.7614e-14, -7.3653e-01,  8.5602e-02,
-         4.6760e-10, -2.1261e-01,  7.5673e-02,  2.1109e-01,  8.3791e-14,
-        -5.4078e-01, -1.1329e+00, -2.1449e-17,  2.1160e-01,  1.0646e-11,
-         5.7443e-10,  3.5513e-13,  0.0000e+00,  1.1443e-01,  4.8148e-01,
-         3.8393e-02, -4.7434e-11,  4.6827e-01,  8.7401e-14,  2.1377e-11,
-        -5.0722e-01, -3.0163e-09,  9.7328e-03,  1.0147e-08, -5.5641e-15,
-        -1.8003e-05,  2.3194e-01, -1.3057e-01,  3.5615e-02, -2.0390e-06,
-         1.1580e-18, -4.5511e-13,  1.6367e-16,  2.3573e-01, -2.3638e-08,
-         4.3721e-16, -7.5281e-02,  3.6734e-04, -5.9831e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2729,  0.0000,  0.0475, -0.2573,  2.2471,  0.2397,  0.0000,  0.0000,
-        -0.3047,  0.1511,  0.0082,  0.0000, -0.4996, -0.4350,  0.0000,  0.0000,
-         0.0000, -0.1300,  0.0000,  0.0000, -0.0676, -0.1643,  0.0000, -0.7365,
-         0.0856,  0.0000, -0.2126,  0.0757,  0.2111,  0.0000, -0.5408, -1.1329,
-         0.0000,  0.2116,  0.0000,  0.0000,  0.0000,  0.0000,  0.1144,  0.4815,
-         0.0384,  0.0000,  0.4683,  0.0000,  0.0000, -0.5072,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2319, -0.1306,  0.0356,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2357,  0.0000,  0.0000, -0.0753,  0.0000, -0.0598],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2729,  0.0000,  0.0475, -0.2573,  2.2471,  0.2397,  0.0000,  0.0000,
-        -0.3047,  0.1511,  0.0082,  0.0000, -0.4996, -0.4350,  0.0000,  0.0000,
-         0.0000, -0.1300,  0.0000,  0.0000, -0.0676, -0.1643,  0.0000, -0.7365,
-         0.0856,  0.0000, -0.2126,  0.0757,  0.2111,  0.0000, -0.5408, -1.1329,
-         0.0000,  0.2116,  0.0000,  0.0000,  0.0000,  0.0000,  0.1144,  0.4815,
-         0.0384,  0.0000,  0.4683,  0.0000,  0.0000, -0.5072,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2319, -0.1306,  0.0356,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2357,  0.0000,  0.0000, -0.0753,  0.0000, -0.0598],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.5661e-01, -1.3209e-07,  3.8956e-02, -2.2120e-01,  2.2405e+00,
-         2.3968e-01,  7.9959e-17, -6.3137e-10, -3.4922e-01,  7.2618e-02,
-         2.0195e-02,  1.4374e-08, -5.1851e-01, -4.7075e-01, -1.4339e-13,
-        -1.0657e-10, -1.0483e-11, -1.9829e-01, -2.4246e-14,  2.9152e-10,
-        -3.9259e-02, -2.3194e-01, -5.1490e-14, -7.1924e-01,  6.7476e-02,
-         4.1790e-10, -1.7889e-01,  7.3577e-02,  2.3312e-01,  7.4885e-14,
-        -5.3612e-01, -1.1253e+00, -1.9169e-17,  1.7331e-01,  9.5146e-12,
-         5.1337e-10,  3.1739e-13,  0.0000e+00,  8.0527e-02,  5.0698e-01,
-         2.7682e-02, -4.2392e-11,  4.3361e-01,  7.8112e-14,  1.9105e-11,
-        -5.1550e-01, -2.6957e-09,  8.6983e-03,  9.0683e-09, -4.9727e-15,
-        -1.6090e-05,  2.0031e-01, -1.3471e-01,  1.5051e-02, -1.8223e-06,
-         1.0349e-18, -4.0674e-13,  1.4628e-16,  1.7542e-01, -2.1126e-08,
-         3.9074e-16, -1.2333e-01,  3.2830e-04, -9.3982e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2566,  0.0000,  0.0390, -0.2212,  2.2405,  0.2397,  0.0000,  0.0000,
-        -0.3492,  0.0726,  0.0202,  0.0000, -0.5185, -0.4708,  0.0000,  0.0000,
-         0.0000, -0.1983,  0.0000,  0.0000, -0.0393, -0.2319,  0.0000, -0.7192,
-         0.0675,  0.0000, -0.1789,  0.0736,  0.2331,  0.0000, -0.5361, -1.1253,
-         0.0000,  0.1733,  0.0000,  0.0000,  0.0000,  0.0000,  0.0805,  0.5070,
-         0.0277,  0.0000,  0.4336,  0.0000,  0.0000, -0.5155,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2003, -0.1347,  0.0151,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1754,  0.0000,  0.0000, -0.1233,  0.0000, -0.0940],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2566,  0.0000,  0.0390, -0.2212,  2.2405,  0.2397,  0.0000,  0.0000,
-        -0.3492,  0.0726,  0.0202,  0.0000, -0.5185, -0.4708,  0.0000,  0.0000,
-         0.0000, -0.1983,  0.0000,  0.0000, -0.0393, -0.2319,  0.0000, -0.7192,
-         0.0675,  0.0000, -0.1789,  0.0736,  0.2331,  0.0000, -0.5361, -1.1253,
-         0.0000,  0.1733,  0.0000,  0.0000,  0.0000,  0.0000,  0.0805,  0.5070,
-         0.0277,  0.0000,  0.4336,  0.0000,  0.0000, -0.5155,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2003, -0.1347,  0.0151,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1754,  0.0000,  0.0000, -0.1233,  0.0000, -0.0940],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1939e-01, -1.1810e-07,  2.7222e-02, -1.8902e-01,  2.2314e+00,
-         2.6210e-01,  7.1489e-17, -5.6448e-10, -3.6548e-01, -3.9029e-03,
-         2.2403e-02,  1.2852e-08, -5.2135e-01, -5.0087e-01, -1.2820e-13,
-        -9.5282e-11, -9.3725e-12, -2.3963e-01, -2.1678e-14,  2.6064e-10,
-        -1.1604e-02, -2.7514e-01, -4.6036e-14, -7.1530e-01,  4.8151e-02,
-         3.7363e-10, -1.4721e-01,  6.8678e-02,  2.1815e-01,  6.6953e-14,
-        -5.2438e-01, -1.1221e+00, -1.7139e-17,  9.7202e-02,  8.5067e-12,
-         4.5899e-10,  2.8376e-13,  0.0000e+00,  5.6664e-02,  5.0921e-01,
-        -4.7499e-04, -3.7902e-11,  3.8851e-01,  6.9837e-14,  1.7081e-11,
-        -4.7823e-01, -2.4101e-09,  7.7769e-03,  8.1077e-09, -4.4460e-15,
-        -1.4385e-05,  1.4134e-01, -1.3079e-01,  3.9752e-02, -1.6292e-06,
-         9.2528e-19, -3.6365e-13,  1.3078e-16,  1.3548e-01, -1.8888e-08,
-         3.4935e-16, -1.4356e-01,  2.9352e-04, -1.1890e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.1939e-01,  0.0000e+00,  2.7222e-02, -1.8902e-01,  2.2314e+00,
-         2.6210e-01,  0.0000e+00,  0.0000e+00, -3.6548e-01, -3.9029e-03,
-         2.2403e-02,  0.0000e+00, -5.2135e-01, -5.0087e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.3963e-01,  0.0000e+00,  0.0000e+00,
-        -1.1604e-02, -2.7514e-01,  0.0000e+00, -7.1530e-01,  4.8151e-02,
-         0.0000e+00, -1.4721e-01,  6.8678e-02,  2.1815e-01,  0.0000e+00,
-        -5.2438e-01, -1.1221e+00,  0.0000e+00,  9.7202e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.6664e-02,  5.0921e-01,
-        -4.7499e-04,  0.0000e+00,  3.8851e-01,  0.0000e+00,  0.0000e+00,
-        -4.7823e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.4134e-01, -1.3079e-01,  3.9752e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3548e-01,  0.0000e+00,
-         0.0000e+00, -1.4356e-01,  0.0000e+00, -1.1890e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.1939e-01,  0.0000e+00,  2.7222e-02, -1.8902e-01,  2.2314e+00,
-         2.6210e-01,  0.0000e+00,  0.0000e+00, -3.6548e-01, -3.9029e-03,
-         2.2403e-02,  0.0000e+00, -5.2135e-01, -5.0087e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.3963e-01,  0.0000e+00,  0.0000e+00,
-        -1.1604e-02, -2.7514e-01,  0.0000e+00, -7.1530e-01,  4.8151e-02,
-         0.0000e+00, -1.4721e-01,  6.8678e-02,  2.1815e-01,  0.0000e+00,
-        -5.2438e-01, -1.1221e+00,  0.0000e+00,  9.7202e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.6664e-02,  5.0921e-01,
-        -4.7499e-04,  0.0000e+00,  3.8851e-01,  0.0000e+00,  0.0000e+00,
-        -4.7823e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.4134e-01, -1.3079e-01,  3.9752e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3548e-01,  0.0000e+00,
-         0.0000e+00, -1.4356e-01,  0.0000e+00, -1.1890e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.8682e-01, -1.0563e-07,  3.0226e-02, -1.5479e-01,  2.2231e+00,
-         2.9966e-01,  6.3941e-17, -5.0489e-10, -3.6898e-01, -7.6030e-02,
-         2.2792e-02,  1.1495e-08, -4.9795e-01, -5.3146e-01, -1.1466e-13,
-        -8.5223e-11, -8.3830e-12, -2.4933e-01, -1.9389e-14,  2.3312e-10,
-         2.7749e-03, -3.1650e-01, -4.1176e-14, -7.3483e-01,  3.4095e-03,
-         3.3418e-10, -9.9256e-02,  4.8092e-02,  1.7809e-01,  5.9884e-14,
-        -5.2319e-01, -1.1200e+00, -1.5329e-17, -2.6731e-02,  7.6086e-12,
-         4.1053e-10,  2.5381e-13,  0.0000e+00,  2.8062e-02,  4.9504e-01,
-        -5.0031e-02, -3.3900e-11,  3.4239e-01,  6.2464e-14,  1.5278e-11,
-        -4.3122e-01, -2.1557e-09,  6.9558e-03,  7.2517e-09, -3.9766e-15,
-        -1.2866e-05,  7.0491e-02, -1.1964e-01,  1.6164e-01, -1.4572e-06,
-         8.2759e-19, -3.2526e-13,  1.1697e-16,  1.2076e-01, -1.6894e-08,
-         3.1247e-16, -1.4281e-01,  2.6253e-04, -1.2533e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1868,  0.0000,  0.0302, -0.1548,  2.2231,  0.2997,  0.0000,  0.0000,
-        -0.3690, -0.0760,  0.0228,  0.0000, -0.4979, -0.5315,  0.0000,  0.0000,
-         0.0000, -0.2493,  0.0000,  0.0000,  0.0028, -0.3165,  0.0000, -0.7348,
-         0.0034,  0.0000, -0.0993,  0.0481,  0.1781,  0.0000, -0.5232, -1.1200,
-         0.0000, -0.0267,  0.0000,  0.0000,  0.0000,  0.0000,  0.0281,  0.4950,
-        -0.0500,  0.0000,  0.3424,  0.0000,  0.0000, -0.4312,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0705, -0.1196,  0.1616,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1208,  0.0000,  0.0000, -0.1428,  0.0000, -0.1253],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1868,  0.0000,  0.0302, -0.1548,  2.2231,  0.2997,  0.0000,  0.0000,
-        -0.3690, -0.0760,  0.0228,  0.0000, -0.4979, -0.5315,  0.0000,  0.0000,
-         0.0000, -0.2493,  0.0000,  0.0000,  0.0028, -0.3165,  0.0000, -0.7348,
-         0.0034,  0.0000, -0.0993,  0.0481,  0.1781,  0.0000, -0.5232, -1.1200,
-         0.0000, -0.0267,  0.0000,  0.0000,  0.0000,  0.0000,  0.0281,  0.4950,
-        -0.0500,  0.0000,  0.3424,  0.0000,  0.0000, -0.4312,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0705, -0.1196,  0.1616,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1208,  0.0000,  0.0000, -0.1428,  0.0000, -0.1253],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6420e-01, -9.4518e-08,  5.2378e-03, -1.3065e-01,  2.2171e+00,
-         3.1887e-01,  5.7213e-17, -4.5176e-10, -3.4461e-01, -1.3023e-01,
-         9.9437e-02,  1.0285e-08, -4.7275e-01, -5.5808e-01, -1.0260e-13,
-        -7.6256e-11, -7.5009e-12, -2.3745e-01, -1.7349e-14,  2.0859e-10,
-         3.2309e-02, -3.1658e-01, -3.6843e-14, -7.3189e-01,  2.5932e-03,
-         2.9902e-10, -7.4248e-02,  5.3871e-02,  1.5104e-01,  5.3583e-14,
-        -5.3727e-01, -1.1164e+00, -1.3716e-17, -1.4623e-01,  6.8080e-12,
-         3.6734e-10,  2.2710e-13,  0.0000e+00, -2.0704e-02,  4.7388e-01,
-        -5.2288e-02, -3.0333e-11,  2.8249e-01,  5.5891e-14,  1.3670e-11,
-        -3.6492e-01, -1.9289e-09,  6.2239e-03,  6.4887e-09, -3.5582e-15,
-        -1.1513e-05,  4.7448e-02, -8.0751e-02,  2.6581e-01, -1.3039e-06,
-         7.4051e-19, -2.9103e-13,  1.0467e-16,  1.0764e-01, -1.5116e-08,
-         2.7959e-16, -1.4402e-01,  2.3491e-04, -1.0754e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1642,  0.0000,  0.0052, -0.1307,  2.2171,  0.3189,  0.0000,  0.0000,
-        -0.3446, -0.1302,  0.0994,  0.0000, -0.4728, -0.5581,  0.0000,  0.0000,
-         0.0000, -0.2375,  0.0000,  0.0000,  0.0323, -0.3166,  0.0000, -0.7319,
-         0.0026,  0.0000, -0.0742,  0.0539,  0.1510,  0.0000, -0.5373, -1.1164,
-         0.0000, -0.1462,  0.0000,  0.0000,  0.0000,  0.0000, -0.0207,  0.4739,
-        -0.0523,  0.0000,  0.2825,  0.0000,  0.0000, -0.3649,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0474, -0.0808,  0.2658,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1076,  0.0000,  0.0000, -0.1440,  0.0000, -0.1075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1642,  0.0000,  0.0052, -0.1307,  2.2171,  0.3189,  0.0000,  0.0000,
-        -0.3446, -0.1302,  0.0994,  0.0000, -0.4728, -0.5581,  0.0000,  0.0000,
-         0.0000, -0.2375,  0.0000,  0.0000,  0.0323, -0.3166,  0.0000, -0.7319,
-         0.0026,  0.0000, -0.0742,  0.0539,  0.1510,  0.0000, -0.5373, -1.1164,
-         0.0000, -0.1462,  0.0000,  0.0000,  0.0000,  0.0000, -0.0207,  0.4739,
-        -0.0523,  0.0000,  0.2825,  0.0000,  0.0000, -0.3649,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0474, -0.0808,  0.2658,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1076,  0.0000,  0.0000, -0.1440,  0.0000, -0.1075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3048e-01, -8.4606e-08, -4.5497e-02, -1.1563e-01,  2.2113e+00,
-         3.3436e-01,  5.1213e-17, -4.0439e-10, -3.2504e-01, -1.7729e-01,
-         1.9512e-01,  9.2067e-09, -4.5527e-01, -5.7912e-01, -9.1840e-14,
-        -6.8259e-11, -6.7143e-12, -1.9822e-01, -1.5530e-14,  1.8672e-10,
-         1.0140e-01, -2.8702e-01, -3.2979e-14, -7.0837e-01,  2.9384e-02,
-         2.6766e-10, -7.6125e-02,  9.8569e-02,  1.6135e-01,  4.7964e-14,
-        -5.5478e-01, -1.1122e+00, -1.2278e-17, -2.4555e-01,  6.0941e-12,
-         3.2881e-10,  2.0328e-13,  0.0000e+00, -7.3913e-02,  4.6084e-01,
-         1.6038e-03, -2.7152e-11,  2.0858e-01,  5.0030e-14,  1.2237e-11,
-        -3.2910e-01, -1.7266e-09,  5.5712e-03,  5.8082e-09, -3.1850e-15,
-        -1.0305e-05,  7.0279e-02, -6.0363e-02,  3.5066e-01, -1.1672e-06,
-         6.6286e-19, -2.6051e-13,  9.3690e-17,  1.0545e-01, -1.3531e-08,
-         2.5027e-16, -1.6202e-01,  2.1027e-04, -9.9008e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.3048e-01,  0.0000e+00, -4.5497e-02, -1.1563e-01,  2.2113e+00,
-         3.3436e-01,  0.0000e+00,  0.0000e+00, -3.2504e-01, -1.7729e-01,
-         1.9512e-01,  0.0000e+00, -4.5527e-01, -5.7912e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.9822e-01,  0.0000e+00,  0.0000e+00,
-         1.0140e-01, -2.8702e-01,  0.0000e+00, -7.0837e-01,  2.9384e-02,
-         0.0000e+00, -7.6125e-02,  9.8569e-02,  1.6135e-01,  0.0000e+00,
-        -5.5478e-01, -1.1122e+00,  0.0000e+00, -2.4555e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -7.3913e-02,  4.6084e-01,
-         1.6038e-03,  0.0000e+00,  2.0858e-01,  0.0000e+00,  0.0000e+00,
-        -3.2910e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.0279e-02, -6.0363e-02,  3.5066e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0545e-01,  0.0000e+00,
-         0.0000e+00, -1.6202e-01,  0.0000e+00, -9.9008e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.3048e-01,  0.0000e+00, -4.5497e-02, -1.1563e-01,  2.2113e+00,
-         3.3436e-01,  0.0000e+00,  0.0000e+00, -3.2504e-01, -1.7729e-01,
-         1.9512e-01,  0.0000e+00, -4.5527e-01, -5.7912e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.9822e-01,  0.0000e+00,  0.0000e+00,
-         1.0140e-01, -2.8702e-01,  0.0000e+00, -7.0837e-01,  2.9384e-02,
-         0.0000e+00, -7.6125e-02,  9.8569e-02,  1.6135e-01,  0.0000e+00,
-        -5.5478e-01, -1.1122e+00,  0.0000e+00, -2.4555e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -7.3913e-02,  4.6084e-01,
-         1.6038e-03,  0.0000e+00,  2.0858e-01,  0.0000e+00,  0.0000e+00,
-        -3.2910e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.0279e-02, -6.0363e-02,  3.5066e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0545e-01,  0.0000e+00,
-         0.0000e+00, -1.6202e-01,  0.0000e+00, -9.9008e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1721e-01, -7.5763e-08, -1.1815e-01, -1.0550e-01,  2.2069e+00,
-         3.5825e-01,  4.5861e-17, -3.6212e-10, -3.0533e-01, -2.0965e-01,
-         2.6315e-01,  8.2445e-09, -4.4596e-01, -5.9796e-01, -8.2241e-14,
-        -6.1125e-11, -6.0125e-12, -1.6885e-01, -1.3907e-14,  1.6720e-10,
-         1.6714e-01, -2.3609e-01, -2.9532e-14, -6.9449e-01,  8.7718e-02,
-         2.3969e-10, -8.6926e-02,  1.6074e-01,  1.6395e-01,  4.2951e-14,
-        -5.8677e-01, -1.1115e+00, -1.0995e-17, -3.5199e-01,  5.4572e-12,
-         2.9445e-10,  1.8204e-13,  0.0000e+00, -1.1096e-01,  4.3641e-01,
-         8.0470e-02, -2.4314e-11,  1.0848e-01,  4.4801e-14,  1.0958e-11,
-        -2.9341e-01, -1.5461e-09,  4.9889e-03,  5.2012e-09, -2.8521e-15,
-        -9.2282e-06,  9.5768e-02, -3.4995e-02,  4.4252e-01, -1.0452e-06,
-         5.9358e-19, -2.3329e-13,  8.3898e-17,  1.1127e-01, -1.2117e-08,
-         2.2411e-16, -1.8644e-01,  1.8830e-04, -8.7255e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1172,  0.0000, -0.1181, -0.1055,  2.2069,  0.3583,  0.0000,  0.0000,
-        -0.3053, -0.2097,  0.2631,  0.0000, -0.4460, -0.5980,  0.0000,  0.0000,
-         0.0000, -0.1689,  0.0000,  0.0000,  0.1671, -0.2361,  0.0000, -0.6945,
-         0.0877,  0.0000, -0.0869,  0.1607,  0.1639,  0.0000, -0.5868, -1.1115,
-         0.0000, -0.3520,  0.0000,  0.0000,  0.0000,  0.0000, -0.1110,  0.4364,
-         0.0805,  0.0000,  0.1085,  0.0000,  0.0000, -0.2934,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0958, -0.0350,  0.4425,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1113,  0.0000,  0.0000, -0.1864,  0.0000, -0.0873],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1172,  0.0000, -0.1181, -0.1055,  2.2069,  0.3583,  0.0000,  0.0000,
-        -0.3053, -0.2097,  0.2631,  0.0000, -0.4460, -0.5980,  0.0000,  0.0000,
-         0.0000, -0.1689,  0.0000,  0.0000,  0.1671, -0.2361,  0.0000, -0.6945,
-         0.0877,  0.0000, -0.0869,  0.1607,  0.1639,  0.0000, -0.5868, -1.1115,
-         0.0000, -0.3520,  0.0000,  0.0000,  0.0000,  0.0000, -0.1110,  0.4364,
-         0.0805,  0.0000,  0.1085,  0.0000,  0.0000, -0.2934,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0958, -0.0350,  0.4425,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1113,  0.0000,  0.0000, -0.1864,  0.0000, -0.0873],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0865e-01, -6.7871e-08, -1.8208e-01, -1.1466e-01,  2.2063e+00,
-         3.8849e-01,  4.1084e-17, -3.2440e-10, -2.9375e-01, -2.3723e-01,
-         2.4072e-01,  7.3857e-09, -4.3286e-01, -5.9550e-01, -7.3675e-14,
-        -5.4758e-11, -5.3863e-12, -1.6084e-01, -1.2458e-14,  1.4979e-10,
-         2.3365e-01, -2.0234e-01, -2.6456e-14, -6.8876e-01,  1.3362e-01,
-         2.1472e-10, -8.8727e-02,  2.0327e-01,  1.7660e-01,  3.8477e-14,
-        -6.2113e-01, -1.1059e+00, -9.8493e-18, -4.2517e-01,  4.8887e-12,
-         2.6378e-10,  1.6308e-13,  0.0000e+00, -7.7511e-02,  4.4660e-01,
-         1.2541e-01, -2.1782e-11,  1.4356e-03,  4.0134e-14,  9.8165e-12,
-        -2.4675e-01, -1.3851e-09,  4.4693e-03,  4.6594e-09, -2.5550e-15,
-        -8.2669e-06,  8.9244e-02, -2.6857e-02,  4.1353e-01, -9.3630e-07,
-         5.3175e-19, -2.0899e-13,  7.5159e-17,  1.5275e-01, -1.0855e-08,
-         2.0077e-16, -2.2142e-01,  1.6868e-04, -6.9252e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.0865e-01,  0.0000e+00, -1.8208e-01, -1.1466e-01,  2.2063e+00,
-         3.8849e-01,  0.0000e+00,  0.0000e+00, -2.9375e-01, -2.3723e-01,
-         2.4072e-01,  0.0000e+00, -4.3286e-01, -5.9550e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6084e-01,  0.0000e+00,  0.0000e+00,
-         2.3365e-01, -2.0234e-01,  0.0000e+00, -6.8876e-01,  1.3362e-01,
-         0.0000e+00, -8.8727e-02,  2.0327e-01,  1.7660e-01,  0.0000e+00,
-        -6.2113e-01, -1.1059e+00,  0.0000e+00, -4.2517e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -7.7511e-02,  4.4660e-01,
-         1.2541e-01,  0.0000e+00,  1.4356e-03,  0.0000e+00,  0.0000e+00,
-        -2.4675e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  8.9244e-02, -2.6857e-02,  4.1353e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5275e-01,  0.0000e+00,
-         0.0000e+00, -2.2142e-01,  0.0000e+00, -6.9252e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.0865e-01,  0.0000e+00, -1.8208e-01, -1.1466e-01,  2.2063e+00,
-         3.8849e-01,  0.0000e+00,  0.0000e+00, -2.9375e-01, -2.3723e-01,
-         2.4072e-01,  0.0000e+00, -4.3286e-01, -5.9550e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6084e-01,  0.0000e+00,  0.0000e+00,
-         2.3365e-01, -2.0234e-01,  0.0000e+00, -6.8876e-01,  1.3362e-01,
-         0.0000e+00, -8.8727e-02,  2.0327e-01,  1.7660e-01,  0.0000e+00,
-        -6.2113e-01, -1.1059e+00,  0.0000e+00, -4.2517e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -7.7511e-02,  4.4660e-01,
-         1.2541e-01,  0.0000e+00,  1.4356e-03,  0.0000e+00,  0.0000e+00,
-        -2.4675e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  8.9244e-02, -2.6857e-02,  4.1353e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5275e-01,  0.0000e+00,
-         0.0000e+00, -2.2142e-01,  0.0000e+00, -6.9252e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.7277e-02, -6.0825e-08, -2.3035e-01, -1.4992e-01,  2.2046e+00,
-         4.1104e-01,  3.6818e-17, -2.9072e-10, -2.9391e-01, -2.5233e-01,
-         1.6785e-01,  6.6189e-09, -4.3153e-01, -5.8040e-01, -6.6026e-14,
-        -4.9073e-11, -4.8271e-12, -1.8719e-01, -1.1165e-14,  1.3424e-10,
-         2.6702e-01, -1.4827e-01, -2.3710e-14, -6.6153e-01,  1.8758e-01,
-         1.9243e-10, -7.7243e-02,  1.9427e-01,  1.5553e-01,  3.4482e-14,
-        -6.3454e-01, -1.1021e+00, -8.8268e-18, -4.5784e-01,  4.3812e-12,
-         2.3639e-10,  1.4615e-13,  0.0000e+00, -6.2604e-02,  4.7636e-01,
-         1.1090e-01, -1.9520e-11, -9.6509e-02,  3.5968e-14,  8.7974e-12,
-        -2.0219e-01, -1.2413e-09,  4.0053e-03,  4.1757e-09, -2.2898e-15,
-        -7.4087e-06,  8.5606e-02, -5.7004e-02,  3.0213e-01, -8.3910e-07,
-         4.7654e-19, -1.8729e-13,  6.7356e-17,  2.0525e-01, -9.7277e-09,
-         1.7992e-16, -2.6373e-01,  1.5117e-04, -6.5413e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0873,  0.0000, -0.2304, -0.1499,  2.2046,  0.4110,  0.0000,  0.0000,
-        -0.2939, -0.2523,  0.1679,  0.0000, -0.4315, -0.5804,  0.0000,  0.0000,
-         0.0000, -0.1872,  0.0000,  0.0000,  0.2670, -0.1483,  0.0000, -0.6615,
-         0.1876,  0.0000, -0.0772,  0.1943,  0.1555,  0.0000, -0.6345, -1.1021,
-         0.0000, -0.4578,  0.0000,  0.0000,  0.0000,  0.0000, -0.0626,  0.4764,
-         0.1109,  0.0000, -0.0965,  0.0000,  0.0000, -0.2022,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0856, -0.0570,  0.3021,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2052,  0.0000,  0.0000, -0.2637,  0.0000, -0.0654],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0873,  0.0000, -0.2304, -0.1499,  2.2046,  0.4110,  0.0000,  0.0000,
-        -0.2939, -0.2523,  0.1679,  0.0000, -0.4315, -0.5804,  0.0000,  0.0000,
-         0.0000, -0.1872,  0.0000,  0.0000,  0.2670, -0.1483,  0.0000, -0.6615,
-         0.1876,  0.0000, -0.0772,  0.1943,  0.1555,  0.0000, -0.6345, -1.1021,
-         0.0000, -0.4578,  0.0000,  0.0000,  0.0000,  0.0000, -0.0626,  0.4764,
-         0.1109,  0.0000, -0.0965,  0.0000,  0.0000, -0.2022,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0856, -0.0570,  0.3021,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2052,  0.0000,  0.0000, -0.2637,  0.0000, -0.0654],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 7.7901e-02, -5.4532e-08, -2.3545e-01, -1.7069e-01,  2.2001e+00,
-         4.3889e-01,  3.3009e-17, -2.6064e-10, -3.1143e-01, -2.6091e-01,
-        -1.1294e-02,  5.9341e-09, -4.0632e-01, -5.6937e-01, -5.9195e-14,
-        -4.3996e-11, -4.3276e-12, -2.0525e-01, -1.0010e-14,  1.2035e-10,
-         3.0423e-01, -1.3081e-01, -2.1257e-14, -6.1010e-01,  1.5073e-01,
-         1.7252e-10, -2.9142e-02,  1.7123e-01,  1.5352e-01,  3.0915e-14,
-        -6.6092e-01, -1.0961e+00, -7.9135e-18, -4.7325e-01,  3.9279e-12,
-         2.1193e-10,  1.3103e-13,  0.0000e+00, -4.1359e-02,  4.8984e-01,
-         2.7288e-02, -1.7501e-11, -1.5565e-01,  3.2246e-14,  7.8872e-12,
-        -2.0647e-01, -1.1129e-09,  3.5909e-03,  3.7436e-09, -2.0529e-15,
-        -6.6422e-06,  5.4724e-02, -1.3199e-01,  2.1904e-01, -7.5228e-07,
-         4.2724e-19, -1.6791e-13,  6.0387e-17,  2.6036e-01, -8.7212e-09,
-         1.6131e-16, -3.0616e-01,  1.3553e-04, -4.8266e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0779,  0.0000, -0.2355, -0.1707,  2.2001,  0.4389,  0.0000,  0.0000,
-        -0.3114, -0.2609, -0.0113,  0.0000, -0.4063, -0.5694,  0.0000,  0.0000,
-         0.0000, -0.2052,  0.0000,  0.0000,  0.3042, -0.1308,  0.0000, -0.6101,
-         0.1507,  0.0000, -0.0291,  0.1712,  0.1535,  0.0000, -0.6609, -1.0961,
-         0.0000, -0.4733,  0.0000,  0.0000,  0.0000,  0.0000, -0.0414,  0.4898,
-         0.0273,  0.0000, -0.1556,  0.0000,  0.0000, -0.2065,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0547, -0.1320,  0.2190,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2604,  0.0000,  0.0000, -0.3062,  0.0000, -0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0779,  0.0000, -0.2355, -0.1707,  2.2001,  0.4389,  0.0000,  0.0000,
-        -0.3114, -0.2609, -0.0113,  0.0000, -0.4063, -0.5694,  0.0000,  0.0000,
-         0.0000, -0.2052,  0.0000,  0.0000,  0.3042, -0.1308,  0.0000, -0.6101,
-         0.1507,  0.0000, -0.0291,  0.1712,  0.1535,  0.0000, -0.6609, -1.0961,
-         0.0000, -0.4733,  0.0000,  0.0000,  0.0000,  0.0000, -0.0414,  0.4898,
-         0.0273,  0.0000, -0.1556,  0.0000,  0.0000, -0.2065,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0547, -0.1320,  0.2190,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2604,  0.0000,  0.0000, -0.3062,  0.0000, -0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.0849e-02, -4.8909e-08, -2.2952e-01, -1.9234e-01,  2.1929e+00,
-         4.6850e-01,  2.9605e-17, -2.3377e-10, -3.2204e-01, -2.6868e-01,
-        -1.4700e-01,  5.3222e-09, -3.8681e-01, -5.6844e-01, -5.3091e-14,
-        -3.9459e-11, -3.8814e-12, -1.9376e-01, -8.9774e-15,  1.0794e-10,
-         3.4703e-01, -1.1528e-01, -1.9065e-14, -5.4849e-01,  1.0807e-01,
-         1.5473e-10,  1.9534e-02,  1.4167e-01,  1.5303e-01,  2.7727e-14,
-        -6.7682e-01, -1.0872e+00, -7.0975e-18, -4.7177e-01,  3.5229e-12,
-         1.9008e-10,  1.1751e-13,  0.0000e+00, -2.2039e-02,  4.8278e-01,
-        -4.6581e-02, -1.5696e-11, -2.0071e-01,  2.8921e-14,  7.0739e-12,
-        -2.1677e-01, -9.9810e-10,  3.2206e-03,  3.3576e-09, -1.8412e-15,
-        -5.9573e-06,  2.7717e-02, -2.1946e-01,  1.5266e-01, -6.7471e-07,
-         3.8318e-19, -1.5060e-13,  5.4160e-17,  2.8602e-01, -7.8219e-09,
-         1.4468e-16, -3.4249e-01,  1.2156e-04, -2.9879e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0808,  0.0000, -0.2295, -0.1923,  2.1929,  0.4685,  0.0000,  0.0000,
-        -0.3220, -0.2687, -0.1470,  0.0000, -0.3868, -0.5684,  0.0000,  0.0000,
-         0.0000, -0.1938,  0.0000,  0.0000,  0.3470, -0.1153,  0.0000, -0.5485,
-         0.1081,  0.0000,  0.0195,  0.1417,  0.1530,  0.0000, -0.6768, -1.0872,
-         0.0000, -0.4718,  0.0000,  0.0000,  0.0000,  0.0000, -0.0220,  0.4828,
-        -0.0466,  0.0000, -0.2007,  0.0000,  0.0000, -0.2168,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0277, -0.2195,  0.1527,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2860,  0.0000,  0.0000, -0.3425,  0.0000, -0.0299],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0808,  0.0000, -0.2295, -0.1923,  2.1929,  0.4685,  0.0000,  0.0000,
-        -0.3220, -0.2687, -0.1470,  0.0000, -0.3868, -0.5684,  0.0000,  0.0000,
-         0.0000, -0.1938,  0.0000,  0.0000,  0.3470, -0.1153,  0.0000, -0.5485,
-         0.1081,  0.0000,  0.0195,  0.1417,  0.1530,  0.0000, -0.6768, -1.0872,
-         0.0000, -0.4718,  0.0000,  0.0000,  0.0000,  0.0000, -0.0220,  0.4828,
-        -0.0466,  0.0000, -0.2007,  0.0000,  0.0000, -0.2168,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0277, -0.2195,  0.1527,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2860,  0.0000,  0.0000, -0.3425,  0.0000, -0.0299],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.8158e-02, -4.3882e-08, -2.0568e-01, -2.0216e-01,  2.1856e+00,
-         4.9726e-01,  2.6563e-17, -2.0974e-10, -3.2508e-01, -2.5369e-01,
-        -2.3415e-01,  4.7753e-09, -3.6092e-01, -5.7014e-01, -4.7635e-14,
-        -3.5404e-11, -3.4825e-12, -1.5378e-01, -8.0548e-15,  9.6845e-11,
-         3.6530e-01, -9.9443e-02, -1.7105e-14, -5.2247e-01,  8.6062e-02,
-         1.3883e-10,  6.9003e-02,  1.1497e-01,  1.5472e-01,  2.4877e-14,
-        -6.7629e-01, -1.0806e+00, -6.3681e-18, -4.6750e-01,  3.1608e-12,
-         1.7055e-10,  1.0544e-13,  0.0000e+00, -2.5522e-02,  4.5958e-01,
-        -1.2076e-01, -1.4083e-11, -2.2024e-01,  2.5949e-14,  6.3469e-12,
-        -2.0388e-01, -8.9553e-10,  2.8896e-03,  3.0126e-09, -1.6520e-15,
-        -5.3450e-06, -1.1358e-02, -2.8737e-01,  1.5933e-01, -6.0537e-07,
-         3.4380e-19, -1.3512e-13,  4.8594e-17,  3.1114e-01, -7.0181e-09,
-         1.2981e-16, -3.6035e-01,  1.0906e-04, -4.3638e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0882,  0.0000, -0.2057, -0.2022,  2.1856,  0.4973,  0.0000,  0.0000,
-        -0.3251, -0.2537, -0.2342,  0.0000, -0.3609, -0.5701,  0.0000,  0.0000,
-         0.0000, -0.1538,  0.0000,  0.0000,  0.3653, -0.0994,  0.0000, -0.5225,
-         0.0861,  0.0000,  0.0690,  0.1150,  0.1547,  0.0000, -0.6763, -1.0806,
-         0.0000, -0.4675,  0.0000,  0.0000,  0.0000,  0.0000, -0.0255,  0.4596,
-        -0.1208,  0.0000, -0.2202,  0.0000,  0.0000, -0.2039,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0114, -0.2874,  0.1593,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3111,  0.0000,  0.0000, -0.3603,  0.0000, -0.0044],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0882,  0.0000, -0.2057, -0.2022,  2.1856,  0.4973,  0.0000,  0.0000,
-        -0.3251, -0.2537, -0.2342,  0.0000, -0.3609, -0.5701,  0.0000,  0.0000,
-         0.0000, -0.1538,  0.0000,  0.0000,  0.3653, -0.0994,  0.0000, -0.5225,
-         0.0861,  0.0000,  0.0690,  0.1150,  0.1547,  0.0000, -0.6763, -1.0806,
-         0.0000, -0.4675,  0.0000,  0.0000,  0.0000,  0.0000, -0.0255,  0.4596,
-        -0.1208,  0.0000, -0.2202,  0.0000,  0.0000, -0.2039,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0114, -0.2874,  0.1593,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3111,  0.0000,  0.0000, -0.3603,  0.0000, -0.0044],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1860e-01, -3.9388e-08, -1.9417e-01, -2.1915e-01,  2.1841e+00,
-         5.0431e-01,  2.3842e-17, -1.8826e-10, -3.3885e-01, -2.2301e-01,
-        -2.6224e-01,  4.2862e-09, -3.4094e-01, -5.6932e-01, -4.2756e-14,
-        -3.1778e-11, -3.1258e-12, -1.3496e-01, -7.2298e-15,  8.6926e-11,
-         4.0009e-01, -7.4254e-02, -1.5353e-14, -4.9488e-01,  7.4346e-02,
-         1.2461e-10,  8.6249e-02,  1.0824e-01,  1.8810e-01,  2.2329e-14,
-        -6.7527e-01, -1.0768e+00, -5.7158e-18, -4.6688e-01,  2.8371e-12,
-         1.5308e-10,  9.4638e-14,  0.0000e+00,  2.2788e-03,  4.6578e-01,
-        -1.4241e-01, -1.2641e-11, -2.3910e-01,  2.3291e-14,  5.6968e-12,
-        -1.8526e-01, -8.0381e-10,  2.5937e-03,  2.7040e-09, -1.4828e-15,
-        -4.7976e-06, -2.1790e-02, -3.3516e-01,  1.6476e-01, -5.4336e-07,
-         3.0859e-19, -1.2128e-13,  4.3617e-17,  3.2496e-01, -6.2993e-09,
-         1.1651e-16, -4.0126e-01,  9.7892e-05, -1.4540e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1186,  0.0000, -0.1942, -0.2192,  2.1841,  0.5043,  0.0000,  0.0000,
-        -0.3388, -0.2230, -0.2622,  0.0000, -0.3409, -0.5693,  0.0000,  0.0000,
-         0.0000, -0.1350,  0.0000,  0.0000,  0.4001, -0.0743,  0.0000, -0.4949,
-         0.0743,  0.0000,  0.0862,  0.1082,  0.1881,  0.0000, -0.6753, -1.0768,
-         0.0000, -0.4669,  0.0000,  0.0000,  0.0000,  0.0000,  0.0023,  0.4658,
-        -0.1424,  0.0000, -0.2391,  0.0000,  0.0000, -0.1853,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0218, -0.3352,  0.1648,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3250,  0.0000,  0.0000, -0.4013,  0.0000, -0.0145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1186,  0.0000, -0.1942, -0.2192,  2.1841,  0.5043,  0.0000,  0.0000,
-        -0.3388, -0.2230, -0.2622,  0.0000, -0.3409, -0.5693,  0.0000,  0.0000,
-         0.0000, -0.1350,  0.0000,  0.0000,  0.4001, -0.0743,  0.0000, -0.4949,
-         0.0743,  0.0000,  0.0862,  0.1082,  0.1881,  0.0000, -0.6753, -1.0768,
-         0.0000, -0.4669,  0.0000,  0.0000,  0.0000,  0.0000,  0.0023,  0.4658,
-        -0.1424,  0.0000, -0.2391,  0.0000,  0.0000, -0.1853,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0218, -0.3352,  0.1648,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3250,  0.0000,  0.0000, -0.4013,  0.0000, -0.0145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.5159e-01, -3.5367e-08, -1.7159e-01, -2.0020e-01,  2.1837e+00,
-         4.9485e-01,  2.1408e-17, -1.6904e-10, -3.5114e-01, -1.8128e-01,
-        -2.7092e-01,  3.8486e-09, -3.2484e-01, -5.6438e-01, -3.8391e-14,
-        -2.8534e-11, -2.8067e-12, -1.0087e-01, -6.4918e-15,  7.8053e-11,
-         4.1468e-01, -3.3873e-02, -1.3786e-14, -4.9376e-01,  8.2251e-02,
-         1.1189e-10,  9.2754e-02,  8.9762e-02,  2.1908e-01,  2.0050e-14,
-        -6.7266e-01, -1.0762e+00, -5.1324e-18, -4.5892e-01,  2.5475e-12,
-         1.3745e-10,  8.4978e-14,  0.0000e+00,  4.8161e-02,  4.8525e-01,
-        -1.6192e-01, -1.1350e-11, -2.4874e-01,  2.0914e-14,  5.1153e-12,
-        -1.5843e-01, -7.2175e-10,  2.3289e-03,  2.4280e-09, -1.3314e-15,
-        -4.3078e-06, -2.5467e-02, -3.8119e-01,  2.1734e-01, -4.8790e-07,
-         2.7709e-19, -1.0890e-13,  3.9165e-17,  3.4215e-01, -5.6562e-09,
-         1.0462e-16, -4.2493e-01,  8.7899e-05, -3.8984e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1516,  0.0000, -0.1716, -0.2002,  2.1837,  0.4948,  0.0000,  0.0000,
-        -0.3511, -0.1813, -0.2709,  0.0000, -0.3248, -0.5644,  0.0000,  0.0000,
-         0.0000, -0.1009,  0.0000,  0.0000,  0.4147, -0.0339,  0.0000, -0.4938,
-         0.0823,  0.0000,  0.0928,  0.0898,  0.2191,  0.0000, -0.6727, -1.0762,
-         0.0000, -0.4589,  0.0000,  0.0000,  0.0000,  0.0000,  0.0482,  0.4853,
-        -0.1619,  0.0000, -0.2487,  0.0000,  0.0000, -0.1584,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0255, -0.3812,  0.2173,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3421,  0.0000,  0.0000, -0.4249,  0.0000, -0.0390],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1516,  0.0000, -0.1716, -0.2002,  2.1837,  0.4948,  0.0000,  0.0000,
-        -0.3511, -0.1813, -0.2709,  0.0000, -0.3248, -0.5644,  0.0000,  0.0000,
-         0.0000, -0.1009,  0.0000,  0.0000,  0.4147, -0.0339,  0.0000, -0.4938,
-         0.0823,  0.0000,  0.0928,  0.0898,  0.2191,  0.0000, -0.6727, -1.0762,
-         0.0000, -0.4589,  0.0000,  0.0000,  0.0000,  0.0000,  0.0482,  0.4853,
-        -0.1619,  0.0000, -0.2487,  0.0000,  0.0000, -0.1584,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0255, -0.3812,  0.2173,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3421,  0.0000,  0.0000, -0.4249,  0.0000, -0.0390],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.0056e-01, -3.1769e-08, -1.5123e-01, -1.8306e-01,  2.1831e+00,
-         4.7450e-01,  1.9230e-17, -1.5185e-10, -3.6834e-01, -1.1915e-01,
-        -2.7160e-01,  3.4571e-09, -3.1276e-01, -5.6668e-01, -3.4485e-14,
-        -2.5631e-11, -2.5212e-12, -5.3643e-02, -5.8313e-15,  7.0112e-11,
-         4.1631e-01,  2.0690e-02, -1.2384e-14, -5.1648e-01,  8.8515e-02,
-         1.0051e-10,  9.6805e-02,  6.2519e-02,  2.3528e-01,  1.8010e-14,
-        -6.6496e-01, -1.0797e+00, -4.6102e-18, -4.6288e-01,  2.2883e-12,
-         1.2347e-10,  7.6332e-14,  0.0000e+00,  1.3235e-01,  5.0308e-01,
-        -1.5819e-01, -1.0195e-11, -2.5954e-01,  1.8786e-14,  4.5949e-12,
-        -1.4248e-01, -6.4833e-10,  2.0920e-03,  2.1810e-09, -1.1960e-15,
-        -3.8696e-06, -2.8477e-02, -4.0378e-01,  3.0796e-01, -4.3826e-07,
-         2.4890e-19, -9.7822e-14,  3.5180e-17,  3.5346e-01, -5.0808e-09,
-         9.3975e-17, -4.4582e-01,  7.8957e-05, -7.5409e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2006,  0.0000, -0.1512, -0.1831,  2.1831,  0.4745,  0.0000,  0.0000,
-        -0.3683, -0.1191, -0.2716,  0.0000, -0.3128, -0.5667,  0.0000,  0.0000,
-         0.0000, -0.0536,  0.0000,  0.0000,  0.4163,  0.0207,  0.0000, -0.5165,
-         0.0885,  0.0000,  0.0968,  0.0625,  0.2353,  0.0000, -0.6650, -1.0797,
-         0.0000, -0.4629,  0.0000,  0.0000,  0.0000,  0.0000,  0.1323,  0.5031,
-        -0.1582,  0.0000, -0.2595,  0.0000,  0.0000, -0.1425,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0285, -0.4038,  0.3080,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3535,  0.0000,  0.0000, -0.4458,  0.0000, -0.0754],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2006,  0.0000, -0.1512, -0.1831,  2.1831,  0.4745,  0.0000,  0.0000,
-        -0.3683, -0.1191, -0.2716,  0.0000, -0.3128, -0.5667,  0.0000,  0.0000,
-         0.0000, -0.0536,  0.0000,  0.0000,  0.4163,  0.0207,  0.0000, -0.5165,
-         0.0885,  0.0000,  0.0968,  0.0625,  0.2353,  0.0000, -0.6650, -1.0797,
-         0.0000, -0.4629,  0.0000,  0.0000,  0.0000,  0.0000,  0.1323,  0.5031,
-        -0.1582,  0.0000, -0.2595,  0.0000,  0.0000, -0.1425,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0285, -0.4038,  0.3080,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3535,  0.0000,  0.0000, -0.4458,  0.0000, -0.0754],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3413e-01, -2.8548e-08, -1.4251e-01, -1.7883e-01,  2.1824e+00,
-         4.4955e-01,  1.7280e-17, -1.3645e-10, -3.8090e-01, -6.3280e-02,
-        -2.6438e-01,  3.1066e-09, -3.1060e-01, -5.5767e-01, -3.0989e-14,
-        -2.3032e-11, -2.2656e-12, -9.7334e-03, -5.2401e-15,  6.3003e-11,
-         4.0821e-01,  9.2159e-02, -1.1128e-14, -5.2195e-01,  1.1423e-01,
-         9.0316e-11,  9.0726e-02,  3.6295e-02,  2.5391e-01,  1.6184e-14,
-        -6.5224e-01, -1.0842e+00, -4.1428e-18, -4.6633e-01,  2.0563e-12,
-         1.1095e-10,  6.8593e-14,  0.0000e+00,  1.6752e-01,  5.1606e-01,
-        -1.3779e-01, -9.1617e-12, -2.7304e-01,  1.6881e-14,  4.1290e-12,
-        -1.1879e-01, -5.8259e-10,  1.8799e-03,  1.9598e-09, -1.0747e-15,
-        -3.4772e-06, -1.5189e-02, -4.1812e-01,  3.8017e-01, -3.9383e-07,
-         2.2366e-19, -8.7903e-14,  3.1613e-17,  3.7162e-01, -4.5656e-09,
-         8.4446e-17, -4.7260e-01,  7.0951e-05, -1.1006e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2341,  0.0000, -0.1425, -0.1788,  2.1824,  0.4496,  0.0000,  0.0000,
-        -0.3809, -0.0633, -0.2644,  0.0000, -0.3106, -0.5577,  0.0000,  0.0000,
-         0.0000, -0.0097,  0.0000,  0.0000,  0.4082,  0.0922,  0.0000, -0.5220,
-         0.1142,  0.0000,  0.0907,  0.0363,  0.2539,  0.0000, -0.6522, -1.0842,
-         0.0000, -0.4663,  0.0000,  0.0000,  0.0000,  0.0000,  0.1675,  0.5161,
-        -0.1378,  0.0000, -0.2730,  0.0000,  0.0000, -0.1188,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0152, -0.4181,  0.3802,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3716,  0.0000,  0.0000, -0.4726,  0.0000, -0.1101],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2341,  0.0000, -0.1425, -0.1788,  2.1824,  0.4496,  0.0000,  0.0000,
-        -0.3809, -0.0633, -0.2644,  0.0000, -0.3106, -0.5577,  0.0000,  0.0000,
-         0.0000, -0.0097,  0.0000,  0.0000,  0.4082,  0.0922,  0.0000, -0.5220,
-         0.1142,  0.0000,  0.0907,  0.0363,  0.2539,  0.0000, -0.6522, -1.0842,
-         0.0000, -0.4663,  0.0000,  0.0000,  0.0000,  0.0000,  0.1675,  0.5161,
-        -0.1378,  0.0000, -0.2730,  0.0000,  0.0000, -0.1188,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0152, -0.4181,  0.3802,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3716,  0.0000,  0.0000, -0.4726,  0.0000, -0.1101],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7548e-01, -2.5663e-08, -1.3338e-01, -1.6024e-01,  2.1807e+00,
-         4.0528e-01,  1.5534e-17, -1.2266e-10, -3.7627e-01, -1.2850e-03,
-        -2.6194e-01,  2.7926e-09, -2.8646e-01, -5.5185e-01, -2.7857e-14,
-        -2.0705e-11, -2.0366e-12,  2.7113e-02, -4.7105e-15,  5.6636e-11,
-         4.0526e-01,  1.4718e-01, -1.0003e-14, -5.2873e-01,  1.2618e-01,
-         8.1189e-11,  8.1500e-02,  1.7100e-02,  2.6130e-01,  1.4549e-14,
-        -6.4595e-01, -1.0916e+00, -3.7241e-18, -4.8084e-01,  1.8485e-12,
-         9.9737e-11,  6.1661e-14,  0.0000e+00,  2.0802e-01,  5.3673e-01,
-        -1.1928e-01, -8.2359e-12, -2.7776e-01,  1.5175e-14,  3.7117e-12,
-        -8.9848e-02, -5.2372e-10,  1.6899e-03,  1.7618e-09, -9.6610e-16,
-        -3.1258e-06,  2.4375e-02, -4.1497e-01,  4.6433e-01, -3.5403e-07,
-         2.0106e-19, -7.9020e-14,  2.8419e-17,  3.8264e-01, -4.1043e-09,
-         7.5913e-17, -4.8093e-01,  6.3781e-05, -1.3466e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.7548e-01,  0.0000e+00, -1.3338e-01, -1.6024e-01,  2.1807e+00,
-         4.0528e-01,  0.0000e+00,  0.0000e+00, -3.7627e-01, -1.2850e-03,
-        -2.6194e-01,  0.0000e+00, -2.8646e-01, -5.5185e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  2.7113e-02,  0.0000e+00,  0.0000e+00,
-         4.0526e-01,  1.4718e-01,  0.0000e+00, -5.2873e-01,  1.2618e-01,
-         0.0000e+00,  8.1500e-02,  1.7100e-02,  2.6130e-01,  0.0000e+00,
-        -6.4595e-01, -1.0916e+00,  0.0000e+00, -4.8084e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0802e-01,  5.3673e-01,
-        -1.1928e-01,  0.0000e+00, -2.7776e-01,  0.0000e+00,  0.0000e+00,
-        -8.9848e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.4375e-02, -4.1497e-01,  4.6433e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.8264e-01,  0.0000e+00,
-         0.0000e+00, -4.8093e-01,  0.0000e+00, -1.3466e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.7548e-01,  0.0000e+00, -1.3338e-01, -1.6024e-01,  2.1807e+00,
-         4.0528e-01,  0.0000e+00,  0.0000e+00, -3.7627e-01, -1.2850e-03,
-        -2.6194e-01,  0.0000e+00, -2.8646e-01, -5.5185e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  2.7113e-02,  0.0000e+00,  0.0000e+00,
-         4.0526e-01,  1.4718e-01,  0.0000e+00, -5.2873e-01,  1.2618e-01,
-         0.0000e+00,  8.1500e-02,  1.7100e-02,  2.6130e-01,  0.0000e+00,
-        -6.4595e-01, -1.0916e+00,  0.0000e+00, -4.8084e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0802e-01,  5.3673e-01,
-        -1.1928e-01,  0.0000e+00, -2.7776e-01,  0.0000e+00,  0.0000e+00,
-        -8.9848e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.4375e-02, -4.1497e-01,  4.6433e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.8264e-01,  0.0000e+00,
-         0.0000e+00, -4.8093e-01,  0.0000e+00, -1.3466e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9162e-01, -2.3078e-08, -1.2654e-01, -1.5506e-01,  2.1791e+00,
-         3.6061e-01,  1.3970e-17, -1.1031e-10, -3.6315e-01,  4.2981e-02,
-        -2.4182e-01,  2.5114e-09, -2.6364e-01, -5.5068e-01, -2.5052e-14,
-        -1.8619e-11, -1.8315e-12,  5.2867e-02, -4.2361e-15,  5.0932e-11,
-         3.9400e-01,  2.0614e-01, -8.9960e-15, -5.3159e-01,  1.4412e-01,
-         7.3012e-11,  7.2858e-02, -1.1051e-02,  2.4810e-01,  1.3083e-14,
-        -6.4119e-01, -1.0992e+00, -3.3491e-18, -4.9478e-01,  1.6623e-12,
-         8.9692e-11,  5.5451e-14,  0.0000e+00,  2.2818e-01,  5.3236e-01,
-        -1.0247e-01, -7.4064e-12, -2.8088e-01,  1.3647e-14,  3.3379e-12,
-        -5.1483e-02, -4.7097e-10,  1.5197e-03,  1.5843e-09, -8.6880e-16,
-        -2.8110e-06,  5.6270e-02, -3.8638e-01,  4.9464e-01, -3.1837e-07,
-         1.8081e-19, -7.1062e-14,  2.5556e-17,  4.0585e-01, -3.6909e-09,
-         6.8267e-17, -4.8845e-01,  5.7358e-05, -1.4390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2916,  0.0000, -0.1265, -0.1551,  2.1791,  0.3606,  0.0000,  0.0000,
-        -0.3631,  0.0430, -0.2418,  0.0000, -0.2636, -0.5507,  0.0000,  0.0000,
-         0.0000,  0.0529,  0.0000,  0.0000,  0.3940,  0.2061,  0.0000, -0.5316,
-         0.1441,  0.0000,  0.0729, -0.0111,  0.2481,  0.0000, -0.6412, -1.0992,
-         0.0000, -0.4948,  0.0000,  0.0000,  0.0000,  0.0000,  0.2282,  0.5324,
-        -0.1025,  0.0000, -0.2809,  0.0000,  0.0000, -0.0515,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0563, -0.3864,  0.4946,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4058,  0.0000,  0.0000, -0.4884,  0.0000, -0.1439],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2916,  0.0000, -0.1265, -0.1551,  2.1791,  0.3606,  0.0000,  0.0000,
-        -0.3631,  0.0430, -0.2418,  0.0000, -0.2636, -0.5507,  0.0000,  0.0000,
-         0.0000,  0.0529,  0.0000,  0.0000,  0.3940,  0.2061,  0.0000, -0.5316,
-         0.1441,  0.0000,  0.0729, -0.0111,  0.2481,  0.0000, -0.6412, -1.0992,
-         0.0000, -0.4948,  0.0000,  0.0000,  0.0000,  0.0000,  0.2282,  0.5324,
-        -0.1025,  0.0000, -0.2809,  0.0000,  0.0000, -0.0515,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0563, -0.3864,  0.4946,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4058,  0.0000,  0.0000, -0.4884,  0.0000, -0.1439],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1972e-01, -2.0762e-08, -1.1903e-01, -1.4196e-01,  2.1776e+00,
-         3.1615e-01,  1.2568e-17, -9.9235e-11, -3.3834e-01,  7.3959e-02,
-        -1.9676e-01,  2.2593e-09, -2.3950e-01, -5.3189e-01, -2.2537e-14,
-        -1.6750e-11, -1.6477e-12,  3.3702e-02, -3.8109e-15,  4.5820e-11,
-         3.8152e-01,  2.5769e-01, -8.0930e-15, -5.5177e-01,  1.6626e-01,
-         6.5683e-11,  4.9992e-02, -3.4816e-02,  2.2942e-01,  1.1770e-14,
-        -6.4460e-01, -1.1089e+00, -3.0129e-18, -5.0445e-01,  1.4955e-12,
-         8.0689e-11,  4.9885e-14,  0.0000e+00,  2.3938e-01,  5.2122e-01,
-        -7.7574e-02, -6.6630e-12, -2.7592e-01,  1.2277e-14,  3.0029e-12,
-        -2.8329e-02, -4.2370e-10,  1.3672e-03,  1.4253e-09, -7.8159e-16,
-        -2.5289e-06,  8.6942e-02, -3.3848e-01,  4.5782e-01, -2.8642e-07,
-         1.6266e-19, -6.3929e-14,  2.2991e-17,  4.2619e-01, -3.3204e-09,
-         6.1415e-17, -4.9603e-01,  5.1600e-05, -1.2836e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3197,  0.0000, -0.1190, -0.1420,  2.1776,  0.3162,  0.0000,  0.0000,
-        -0.3383,  0.0740, -0.1968,  0.0000, -0.2395, -0.5319,  0.0000,  0.0000,
-         0.0000,  0.0337,  0.0000,  0.0000,  0.3815,  0.2577,  0.0000, -0.5518,
-         0.1663,  0.0000,  0.0500, -0.0348,  0.2294,  0.0000, -0.6446, -1.1089,
-         0.0000, -0.5044,  0.0000,  0.0000,  0.0000,  0.0000,  0.2394,  0.5212,
-        -0.0776,  0.0000, -0.2759,  0.0000,  0.0000, -0.0283,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0869, -0.3385,  0.4578,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4262,  0.0000,  0.0000, -0.4960,  0.0000, -0.1284],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3197,  0.0000, -0.1190, -0.1420,  2.1776,  0.3162,  0.0000,  0.0000,
-        -0.3383,  0.0740, -0.1968,  0.0000, -0.2395, -0.5319,  0.0000,  0.0000,
-         0.0000,  0.0337,  0.0000,  0.0000,  0.3815,  0.2577,  0.0000, -0.5518,
-         0.1663,  0.0000,  0.0500, -0.0348,  0.2294,  0.0000, -0.6446, -1.1089,
-         0.0000, -0.5044,  0.0000,  0.0000,  0.0000,  0.0000,  0.2394,  0.5212,
-        -0.0776,  0.0000, -0.2759,  0.0000,  0.0000, -0.0283,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0869, -0.3385,  0.4578,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4262,  0.0000,  0.0000, -0.4960,  0.0000, -0.1284],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3476e-01, -1.8685e-08, -1.2629e-01, -1.4151e-01,  2.1754e+00,
-         3.0189e-01,  1.1310e-17, -8.9308e-11, -2.9716e-01,  7.7018e-02,
-        -1.4756e-01,  2.0333e-09, -2.1819e-01, -5.1534e-01, -2.0283e-14,
-        -1.5075e-11, -1.4828e-12,  3.8296e-03, -3.4297e-15,  4.1236e-11,
-         3.7151e-01,  2.8307e-01, -7.2834e-15, -5.7182e-01,  1.6477e-01,
-         5.9112e-11,  4.0789e-02, -6.1798e-02,  2.1454e-01,  1.0593e-14,
-        -6.5109e-01, -1.1187e+00, -2.7115e-18, -5.0840e-01,  1.3459e-12,
-         7.2617e-11,  4.4895e-14,  0.0000e+00,  2.2067e-01,  4.9120e-01,
-        -6.0785e-02, -5.9964e-12, -2.8235e-01,  1.1049e-14,  2.7025e-12,
-         3.2307e-03, -3.8131e-10,  1.2304e-03,  1.2827e-09, -7.0340e-16,
-        -2.2759e-06,  9.6949e-02, -2.6468e-01,  3.9122e-01, -2.5776e-07,
-         1.4639e-19, -5.7534e-14,  2.0691e-17,  4.5740e-01, -2.9883e-09,
-         5.5271e-17, -4.9029e-01,  4.6438e-05, -9.0530e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3348,  0.0000, -0.1263, -0.1415,  2.1754,  0.3019,  0.0000,  0.0000,
-        -0.2972,  0.0770, -0.1476,  0.0000, -0.2182, -0.5153,  0.0000,  0.0000,
-         0.0000,  0.0038,  0.0000,  0.0000,  0.3715,  0.2831,  0.0000, -0.5718,
-         0.1648,  0.0000,  0.0408, -0.0618,  0.2145,  0.0000, -0.6511, -1.1187,
-         0.0000, -0.5084,  0.0000,  0.0000,  0.0000,  0.0000,  0.2207,  0.4912,
-        -0.0608,  0.0000, -0.2824,  0.0000,  0.0000,  0.0032,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0969, -0.2647,  0.3912,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4574,  0.0000,  0.0000, -0.4903,  0.0000, -0.0905],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3348,  0.0000, -0.1263, -0.1415,  2.1754,  0.3019,  0.0000,  0.0000,
-        -0.2972,  0.0770, -0.1476,  0.0000, -0.2182, -0.5153,  0.0000,  0.0000,
-         0.0000,  0.0038,  0.0000,  0.0000,  0.3715,  0.2831,  0.0000, -0.5718,
-         0.1648,  0.0000,  0.0408, -0.0618,  0.2145,  0.0000, -0.6511, -1.1187,
-         0.0000, -0.5084,  0.0000,  0.0000,  0.0000,  0.0000,  0.2207,  0.4912,
-        -0.0608,  0.0000, -0.2824,  0.0000,  0.0000,  0.0032,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0969, -0.2647,  0.3912,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4574,  0.0000,  0.0000, -0.4903,  0.0000, -0.0905],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4350e-01, -1.6822e-08, -1.2135e-01, -1.3644e-01,  2.1728e+00,
-         2.8648e-01,  1.0183e-17, -8.0403e-11, -2.5700e-01,  8.6198e-02,
-        -8.1634e-02,  1.8305e-09, -1.8948e-01, -5.0418e-01, -1.8260e-14,
-        -1.3572e-11, -1.3350e-12, -2.0093e-02, -3.0877e-15,  3.7125e-11,
-         3.5317e-01,  2.9839e-01, -6.5572e-15, -6.2921e-01,  1.6288e-01,
-         5.3219e-11,  5.0730e-02, -8.6931e-02,  1.9487e-01,  9.5365e-15,
-        -6.5378e-01, -1.1307e+00, -2.4411e-18, -5.2250e-01,  1.2117e-12,
-         6.5377e-11,  4.0418e-14,  0.0000e+00,  1.9503e-01,  4.5270e-01,
-        -5.7417e-02, -5.3986e-12, -2.7732e-01,  9.9473e-15,  2.4330e-12,
-         4.7138e-02, -3.4329e-10,  1.1077e-03,  1.1548e-09, -6.3327e-16,
-        -2.0490e-06,  9.9572e-02, -1.9544e-01,  2.9516e-01, -2.3206e-07,
-         1.3179e-19, -5.1797e-14,  1.8628e-17,  4.9191e-01, -2.6903e-09,
-         4.9760e-17, -4.7307e-01,  4.1808e-05, -2.5427e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3435,  0.0000, -0.1213, -0.1364,  2.1728,  0.2865,  0.0000,  0.0000,
-        -0.2570,  0.0862, -0.0816,  0.0000, -0.1895, -0.5042,  0.0000,  0.0000,
-         0.0000, -0.0201,  0.0000,  0.0000,  0.3532,  0.2984,  0.0000, -0.6292,
-         0.1629,  0.0000,  0.0507, -0.0869,  0.1949,  0.0000, -0.6538, -1.1307,
-         0.0000, -0.5225,  0.0000,  0.0000,  0.0000,  0.0000,  0.1950,  0.4527,
-        -0.0574,  0.0000, -0.2773,  0.0000,  0.0000,  0.0471,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0996, -0.1954,  0.2952,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4919,  0.0000,  0.0000, -0.4731,  0.0000, -0.0254],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3435,  0.0000, -0.1213, -0.1364,  2.1728,  0.2865,  0.0000,  0.0000,
-        -0.2570,  0.0862, -0.0816,  0.0000, -0.1895, -0.5042,  0.0000,  0.0000,
-         0.0000, -0.0201,  0.0000,  0.0000,  0.3532,  0.2984,  0.0000, -0.6292,
-         0.1629,  0.0000,  0.0507, -0.0869,  0.1949,  0.0000, -0.6538, -1.1307,
-         0.0000, -0.5225,  0.0000,  0.0000,  0.0000,  0.0000,  0.1950,  0.4527,
-        -0.0574,  0.0000, -0.2773,  0.0000,  0.0000,  0.0471,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0996, -0.1954,  0.2952,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4919,  0.0000,  0.0000, -0.4731,  0.0000, -0.0254],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4566e-01, -1.5150e-08, -1.2021e-01, -1.3811e-01,  2.1698e+00,
-         2.7488e-01,  9.1707e-18, -7.2413e-11, -2.1275e-01,  8.5270e-02,
-        -1.6017e-02,  1.6486e-09, -1.5622e-01, -4.8785e-01, -1.6446e-14,
-        -1.2223e-11, -1.2023e-12, -4.8906e-02, -2.7809e-15,  3.3436e-11,
-         3.3580e-01,  2.9030e-01, -5.9056e-15, -6.8274e-01,  1.4093e-01,
-         4.7930e-11,  6.8274e-02, -1.0490e-01,  1.7695e-01,  8.5888e-15,
-        -6.5756e-01, -1.1418e+00, -2.1986e-18, -5.3391e-01,  1.0913e-12,
-         5.8880e-11,  3.6402e-14,  0.0000e+00,  1.7121e-01,  4.1787e-01,
-        -6.2913e-02, -4.8621e-12, -2.8114e-01,  8.9589e-15,  2.1912e-12,
-         1.1269e-01, -3.0918e-10,  9.9764e-04,  1.0401e-09, -5.7034e-16,
-        -1.8454e-06,  9.1831e-02, -9.4393e-02,  1.9276e-01, -2.0900e-07,
-         1.1870e-19, -4.6650e-14,  1.6777e-17,  5.3321e-01, -2.4230e-09,
-         4.4815e-17, -4.4573e-01,  3.7654e-05,  4.7404e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3457,  0.0000, -0.1202, -0.1381,  2.1698,  0.2749,  0.0000,  0.0000,
-        -0.2127,  0.0853, -0.0160,  0.0000, -0.1562, -0.4879,  0.0000,  0.0000,
-         0.0000, -0.0489,  0.0000,  0.0000,  0.3358,  0.2903,  0.0000, -0.6827,
-         0.1409,  0.0000,  0.0683, -0.1049,  0.1770,  0.0000, -0.6576, -1.1418,
-         0.0000, -0.5339,  0.0000,  0.0000,  0.0000,  0.0000,  0.1712,  0.4179,
-        -0.0629,  0.0000, -0.2811,  0.0000,  0.0000,  0.1127,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0918, -0.0944,  0.1928,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5332,  0.0000,  0.0000, -0.4457,  0.0000,  0.0474],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3457,  0.0000, -0.1202, -0.1381,  2.1698,  0.2749,  0.0000,  0.0000,
-        -0.2127,  0.0853, -0.0160,  0.0000, -0.1562, -0.4879,  0.0000,  0.0000,
-         0.0000, -0.0489,  0.0000,  0.0000,  0.3358,  0.2903,  0.0000, -0.6827,
-         0.1409,  0.0000,  0.0683, -0.1049,  0.1770,  0.0000, -0.6576, -1.1418,
-         0.0000, -0.5339,  0.0000,  0.0000,  0.0000,  0.0000,  0.1712,  0.4179,
-        -0.0629,  0.0000, -0.2811,  0.0000,  0.0000,  0.1127,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0918, -0.0944,  0.1928,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5332,  0.0000,  0.0000, -0.4457,  0.0000,  0.0474],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6494e-01, -1.3650e-08, -1.1858e-01, -1.2496e-01,  2.1669e+00,
-         2.6179e-01,  8.2624e-18, -6.5241e-11, -1.7382e-01,  8.7583e-02,
-         1.1389e-02,  1.4854e-09, -1.2473e-01, -4.8490e-01, -1.4817e-14,
-        -1.1012e-11, -1.0832e-12, -3.3844e-02, -2.5055e-15,  3.0124e-11,
-         3.3300e-01,  2.7813e-01, -5.3207e-15, -7.3835e-01,  9.9127e-02,
-         4.3183e-11,  8.5187e-02, -9.3523e-02,  1.6559e-01,  7.7382e-15,
-        -6.4384e-01, -1.1541e+00, -1.9808e-18, -5.2862e-01,  9.8318e-13,
-         5.3049e-11,  3.2797e-14,  0.0000e+00,  1.8341e-01,  4.2557e-01,
-        -6.4252e-02, -4.3805e-12, -2.9425e-01,  8.0716e-15,  1.9742e-12,
-         1.3845e-01, -2.7856e-10,  8.9883e-04,  9.3707e-10, -5.1385e-16,
-        -1.6626e-06,  9.9805e-02, -3.5748e-02,  1.3988e-01, -1.8830e-07,
-         1.0694e-19, -4.2030e-14,  1.5115e-17,  5.6053e-01, -2.1830e-09,
-         4.0377e-17, -4.0925e-01,  3.3924e-05,  7.4484e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3649,  0.0000, -0.1186, -0.1250,  2.1669,  0.2618,  0.0000,  0.0000,
-        -0.1738,  0.0876,  0.0114,  0.0000, -0.1247, -0.4849,  0.0000,  0.0000,
-         0.0000, -0.0338,  0.0000,  0.0000,  0.3330,  0.2781,  0.0000, -0.7384,
-         0.0991,  0.0000,  0.0852, -0.0935,  0.1656,  0.0000, -0.6438, -1.1541,
-         0.0000, -0.5286,  0.0000,  0.0000,  0.0000,  0.0000,  0.1834,  0.4256,
-        -0.0643,  0.0000, -0.2943,  0.0000,  0.0000,  0.1385,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0998, -0.0357,  0.1399,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5605,  0.0000,  0.0000, -0.4092,  0.0000,  0.0745],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3649,  0.0000, -0.1186, -0.1250,  2.1669,  0.2618,  0.0000,  0.0000,
-        -0.1738,  0.0876,  0.0114,  0.0000, -0.1247, -0.4849,  0.0000,  0.0000,
-         0.0000, -0.0338,  0.0000,  0.0000,  0.3330,  0.2781,  0.0000, -0.7384,
-         0.0991,  0.0000,  0.0852, -0.0935,  0.1656,  0.0000, -0.6438, -1.1541,
-         0.0000, -0.5286,  0.0000,  0.0000,  0.0000,  0.0000,  0.1834,  0.4256,
-        -0.0643,  0.0000, -0.2943,  0.0000,  0.0000,  0.1385,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0998, -0.0357,  0.1399,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5605,  0.0000,  0.0000, -0.4092,  0.0000,  0.0745],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.8378e-01, -1.2302e-08, -1.1654e-01, -9.4522e-02,  2.1641e+00,
-         2.4571e-01,  7.4469e-18, -5.8801e-11, -1.1570e-01,  7.7782e-02,
-         2.7506e-02,  1.3387e-09, -9.3391e-02, -4.7191e-01, -1.3354e-14,
-        -9.9254e-12, -9.7632e-13, -3.5504e-04, -2.2582e-15,  2.7151e-11,
-         3.3659e-01,  2.8729e-01, -4.7955e-15, -7.7964e-01,  6.9043e-02,
-         3.8921e-11,  1.0306e-01, -5.6698e-02,  1.7068e-01,  6.9744e-15,
-        -6.4054e-01, -1.1627e+00, -1.7853e-18, -5.1431e-01,  8.8613e-13,
-         4.7812e-11,  2.9559e-14,  0.0000e+00,  1.9784e-01,  4.3911e-01,
-        -4.6658e-02, -3.9481e-12, -3.1377e-01,  7.2748e-15,  1.7793e-12,
-         1.1974e-01, -2.5106e-10,  8.1011e-04,  8.4457e-10, -4.6313e-16,
-        -1.4985e-06,  1.1641e-01, -4.5995e-03,  1.2354e-01, -1.6971e-07,
-         9.6385e-20, -3.7881e-14,  1.3623e-17,  5.8265e-01, -1.9675e-09,
-         3.6391e-17, -3.7103e-01,  3.0576e-05,  9.5392e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.8378e-01,  0.0000e+00, -1.1654e-01, -9.4522e-02,  2.1641e+00,
-         2.4571e-01,  0.0000e+00,  0.0000e+00, -1.1570e-01,  7.7782e-02,
-         2.7506e-02,  0.0000e+00, -9.3391e-02, -4.7191e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.5504e-04,  0.0000e+00,  0.0000e+00,
-         3.3659e-01,  2.8729e-01,  0.0000e+00, -7.7964e-01,  6.9043e-02,
-         0.0000e+00,  1.0306e-01, -5.6698e-02,  1.7068e-01,  0.0000e+00,
-        -6.4054e-01, -1.1627e+00,  0.0000e+00, -5.1431e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9784e-01,  4.3911e-01,
-        -4.6658e-02,  0.0000e+00, -3.1377e-01,  0.0000e+00,  0.0000e+00,
-         1.1974e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.1641e-01, -4.5995e-03,  1.2354e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.8265e-01,  0.0000e+00,
-         0.0000e+00, -3.7103e-01,  0.0000e+00,  9.5392e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.8378e-01,  0.0000e+00, -1.1654e-01, -9.4522e-02,  2.1641e+00,
-         2.4571e-01,  0.0000e+00,  0.0000e+00, -1.1570e-01,  7.7782e-02,
-         2.7506e-02,  0.0000e+00, -9.3391e-02, -4.7191e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.5504e-04,  0.0000e+00,  0.0000e+00,
-         3.3659e-01,  2.8729e-01,  0.0000e+00, -7.7964e-01,  6.9043e-02,
-         0.0000e+00,  1.0306e-01, -5.6698e-02,  1.7068e-01,  0.0000e+00,
-        -6.4054e-01, -1.1627e+00,  0.0000e+00, -5.1431e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9784e-01,  4.3911e-01,
-        -4.6658e-02,  0.0000e+00, -3.1377e-01,  0.0000e+00,  0.0000e+00,
-         1.1974e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.1641e-01, -4.5995e-03,  1.2354e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.8265e-01,  0.0000e+00,
-         0.0000e+00, -3.7103e-01,  0.0000e+00,  9.5392e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0115e-01, -1.1092e-08, -1.1136e-01, -5.3675e-02,  2.1611e+00,
-         2.2618e-01,  6.7142e-18, -5.3016e-11, -5.0749e-02,  7.4129e-02,
-        -1.1032e-02,  1.2070e-09, -6.2400e-02, -4.7595e-01, -1.2041e-14,
-        -8.9489e-12, -8.8027e-13,  5.8627e-02, -2.0360e-15,  2.4479e-11,
-         3.4226e-01,  2.9376e-01, -4.3237e-15, -8.1436e-01,  3.9984e-02,
-         3.5091e-11,  1.2339e-01, -1.7668e-02,  1.7923e-01,  6.2882e-15,
-        -6.1292e-01, -1.1682e+00, -1.6097e-18, -4.8654e-01,  7.9895e-13,
-         4.3108e-11,  2.6651e-14,  0.0000e+00,  2.0491e-01,  4.4606e-01,
-        -4.9989e-02, -3.5597e-12, -3.3376e-01,  6.5591e-15,  1.6043e-12,
-         6.8220e-02, -2.2636e-10,  7.3041e-04,  7.6148e-10, -4.1757e-16,
-        -1.3511e-06,  1.4861e-01, -1.7972e-02,  1.3884e-01, -1.5302e-07,
-         8.6903e-20, -3.4154e-14,  1.2283e-17,  5.9899e-01, -1.7739e-09,
-         3.2811e-17, -3.3345e-01,  2.7568e-05,  1.0280e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4011,  0.0000, -0.1114, -0.0537,  2.1611,  0.2262,  0.0000,  0.0000,
-        -0.0507,  0.0741, -0.0110,  0.0000, -0.0624, -0.4760,  0.0000,  0.0000,
-         0.0000,  0.0586,  0.0000,  0.0000,  0.3423,  0.2938,  0.0000, -0.8144,
-         0.0400,  0.0000,  0.1234, -0.0177,  0.1792,  0.0000, -0.6129, -1.1682,
-         0.0000, -0.4865,  0.0000,  0.0000,  0.0000,  0.0000,  0.2049,  0.4461,
-        -0.0500,  0.0000, -0.3338,  0.0000,  0.0000,  0.0682,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1486, -0.0180,  0.1388,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5990,  0.0000,  0.0000, -0.3334,  0.0000,  0.1028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4011,  0.0000, -0.1114, -0.0537,  2.1611,  0.2262,  0.0000,  0.0000,
-        -0.0507,  0.0741, -0.0110,  0.0000, -0.0624, -0.4760,  0.0000,  0.0000,
-         0.0000,  0.0586,  0.0000,  0.0000,  0.3423,  0.2938,  0.0000, -0.8144,
-         0.0400,  0.0000,  0.1234, -0.0177,  0.1792,  0.0000, -0.6129, -1.1682,
-         0.0000, -0.4865,  0.0000,  0.0000,  0.0000,  0.0000,  0.2049,  0.4461,
-        -0.0500,  0.0000, -0.3338,  0.0000,  0.0000,  0.0682,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1486, -0.0180,  0.1388,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5990,  0.0000,  0.0000, -0.3334,  0.0000,  0.1028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0911e-01, -1.0004e-08, -9.8797e-02, -7.8877e-03,  2.1569e+00,
-         2.0037e-01,  6.0558e-18, -4.7818e-11, -2.2492e-03,  6.3084e-02,
-        -7.9037e-02,  1.0887e-09, -4.0200e-02, -4.8248e-01, -1.0860e-14,
-        -8.0714e-12, -7.9395e-13,  9.6703e-02, -1.8363e-15,  2.2079e-11,
-         3.4783e-01,  2.9475e-01, -3.8997e-15, -8.4592e-01,  1.3096e-02,
-         3.1651e-11,  1.5022e-01,  1.6158e-02,  1.7999e-01,  5.6716e-15,
-        -5.7247e-01, -1.1728e+00, -1.4518e-18, -4.5991e-01,  7.2061e-13,
-         3.8881e-11,  2.4038e-14,  0.0000e+00,  2.1585e-01,  4.5347e-01,
-        -5.9902e-02, -3.2107e-12, -3.4228e-01,  5.9159e-15,  1.4470e-12,
-        -4.5267e-03, -2.0416e-10,  6.5878e-04,  6.8681e-10, -3.7662e-16,
-        -1.2186e-06,  1.7241e-01, -5.9256e-02,  1.6836e-01, -1.3801e-07,
-         7.8381e-20, -3.0805e-14,  1.1079e-17,  6.1069e-01, -1.6000e-09,
-         2.9594e-17, -2.9408e-01,  2.4864e-05,  8.2620e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4091,  0.0000, -0.0988, -0.0079,  2.1569,  0.2004,  0.0000,  0.0000,
-        -0.0022,  0.0631, -0.0790,  0.0000, -0.0402, -0.4825,  0.0000,  0.0000,
-         0.0000,  0.0967,  0.0000,  0.0000,  0.3478,  0.2947,  0.0000, -0.8459,
-         0.0131,  0.0000,  0.1502,  0.0162,  0.1800,  0.0000, -0.5725, -1.1728,
-         0.0000, -0.4599,  0.0000,  0.0000,  0.0000,  0.0000,  0.2159,  0.4535,
-        -0.0599,  0.0000, -0.3423,  0.0000,  0.0000, -0.0045,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1724, -0.0593,  0.1684,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6107,  0.0000,  0.0000, -0.2941,  0.0000,  0.0826],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4091,  0.0000, -0.0988, -0.0079,  2.1569,  0.2004,  0.0000,  0.0000,
-        -0.0022,  0.0631, -0.0790,  0.0000, -0.0402, -0.4825,  0.0000,  0.0000,
-         0.0000,  0.0967,  0.0000,  0.0000,  0.3478,  0.2947,  0.0000, -0.8459,
-         0.0131,  0.0000,  0.1502,  0.0162,  0.1800,  0.0000, -0.5725, -1.1728,
-         0.0000, -0.4599,  0.0000,  0.0000,  0.0000,  0.0000,  0.2159,  0.4535,
-        -0.0599,  0.0000, -0.3423,  0.0000,  0.0000, -0.0045,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1724, -0.0593,  0.1684,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6107,  0.0000,  0.0000, -0.2941,  0.0000,  0.0826],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0110e-01, -9.0267e-09, -8.1565e-02,  3.0888e-02,  2.1523e+00,
-         1.9345e-01,  5.4640e-18, -4.3145e-11,  4.2721e-02,  5.9515e-02,
-        -1.8127e-01,  9.8228e-10, -3.0601e-02, -4.9964e-01, -9.7985e-15,
-        -7.2826e-12, -7.1636e-13,  9.9762e-02, -1.6569e-15,  1.9921e-11,
-         3.4717e-01,  2.7908e-01, -3.5186e-15, -8.5526e-01, -2.2547e-02,
-         2.8557e-11,  1.7925e-01,  3.8190e-02,  1.8060e-01,  5.1173e-15,
-        -5.1324e-01, -1.1779e+00, -1.3099e-18, -4.1472e-01,  6.5018e-13,
-         3.5081e-11,  2.1689e-14,  0.0000e+00,  2.1105e-01,  4.5648e-01,
-        -8.6101e-02, -2.8969e-12, -3.6095e-01,  5.3378e-15,  1.3056e-12,
-        -8.3358e-02, -1.8421e-10,  5.9440e-04,  6.1969e-10, -3.3981e-16,
-        -1.0995e-06,  1.5729e-01, -1.1390e-01,  2.0171e-01, -1.2453e-07,
-         7.0721e-20, -2.7795e-14,  9.9959e-18,  6.0594e-01, -1.4436e-09,
-         2.6701e-17, -2.5595e-01,  2.2434e-05,  6.2650e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4011,  0.0000, -0.0816,  0.0309,  2.1523,  0.1935,  0.0000,  0.0000,
-         0.0427,  0.0595, -0.1813,  0.0000, -0.0306, -0.4996,  0.0000,  0.0000,
-         0.0000,  0.0998,  0.0000,  0.0000,  0.3472,  0.2791,  0.0000, -0.8553,
-        -0.0225,  0.0000,  0.1793,  0.0382,  0.1806,  0.0000, -0.5132, -1.1779,
-         0.0000, -0.4147,  0.0000,  0.0000,  0.0000,  0.0000,  0.2111,  0.4565,
-        -0.0861,  0.0000, -0.3610,  0.0000,  0.0000, -0.0834,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1573, -0.1139,  0.2017,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6059,  0.0000,  0.0000, -0.2559,  0.0000,  0.0627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4011,  0.0000, -0.0816,  0.0309,  2.1523,  0.1935,  0.0000,  0.0000,
-         0.0427,  0.0595, -0.1813,  0.0000, -0.0306, -0.4996,  0.0000,  0.0000,
-         0.0000,  0.0998,  0.0000,  0.0000,  0.3472,  0.2791,  0.0000, -0.8553,
-        -0.0225,  0.0000,  0.1793,  0.0382,  0.1806,  0.0000, -0.5132, -1.1779,
-         0.0000, -0.4147,  0.0000,  0.0000,  0.0000,  0.0000,  0.2111,  0.4565,
-        -0.0861,  0.0000, -0.3610,  0.0000,  0.0000, -0.0834,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1573, -0.1139,  0.2017,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6059,  0.0000,  0.0000, -0.2559,  0.0000,  0.0627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9704e-01, -8.1474e-09, -8.1265e-02,  5.1529e-02,  2.1474e+00,
-         1.8000e-01,  4.9318e-18, -3.8942e-11,  7.4672e-02,  6.8723e-02,
-        -2.3235e-01,  8.8659e-10, -5.4527e-02, -5.2342e-01, -8.8440e-15,
-        -6.5732e-12, -6.4658e-13,  9.6594e-02, -1.4955e-15,  1.7981e-11,
-         3.3653e-01,  2.6537e-01, -3.1759e-15, -8.7700e-01, -3.8003e-02,
-         2.5776e-11,  2.0347e-01,  4.9575e-02,  1.7915e-01,  4.6188e-15,
-        -4.5949e-01, -1.1865e+00, -1.1823e-18, -3.7282e-01,  5.8685e-13,
-         3.1664e-11,  1.9576e-14,  0.0000e+00,  1.8918e-01,  4.4122e-01,
-        -1.0750e-01, -2.6147e-12, -3.8015e-01,  4.8178e-15,  1.1784e-12,
-        -1.4649e-01, -1.6627e-10,  5.3650e-04,  5.5932e-10, -3.0671e-16,
-        -9.9238e-07,  1.1776e-01, -1.4347e-01,  2.1562e-01, -1.1240e-07,
-         6.3832e-20, -2.5087e-14,  9.0222e-18,  6.0578e-01, -1.3030e-09,
-         2.4100e-17, -2.2909e-01,  2.0249e-05,  6.2085e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3970,  0.0000, -0.0813,  0.0515,  2.1474,  0.1800,  0.0000,  0.0000,
-         0.0747,  0.0687, -0.2324,  0.0000, -0.0545, -0.5234,  0.0000,  0.0000,
-         0.0000,  0.0966,  0.0000,  0.0000,  0.3365,  0.2654,  0.0000, -0.8770,
-        -0.0380,  0.0000,  0.2035,  0.0496,  0.1791,  0.0000, -0.4595, -1.1865,
-         0.0000, -0.3728,  0.0000,  0.0000,  0.0000,  0.0000,  0.1892,  0.4412,
-        -0.1075,  0.0000, -0.3801,  0.0000,  0.0000, -0.1465,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1178, -0.1435,  0.2156,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6058,  0.0000,  0.0000, -0.2291,  0.0000,  0.0621],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3970,  0.0000, -0.0813,  0.0515,  2.1474,  0.1800,  0.0000,  0.0000,
-         0.0747,  0.0687, -0.2324,  0.0000, -0.0545, -0.5234,  0.0000,  0.0000,
-         0.0000,  0.0966,  0.0000,  0.0000,  0.3365,  0.2654,  0.0000, -0.8770,
-        -0.0380,  0.0000,  0.2035,  0.0496,  0.1791,  0.0000, -0.4595, -1.1865,
-         0.0000, -0.3728,  0.0000,  0.0000,  0.0000,  0.0000,  0.1892,  0.4412,
-        -0.1075,  0.0000, -0.3801,  0.0000,  0.0000, -0.1465,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1178, -0.1435,  0.2156,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6058,  0.0000,  0.0000, -0.2291,  0.0000,  0.0621],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0410e-01, -7.3564e-09, -7.1895e-02,  6.1716e-02,  2.1431e+00,
-         1.7677e-01,  4.4529e-18, -3.5161e-11,  9.2339e-02,  6.4859e-02,
-        -2.8683e-01,  8.0051e-10, -7.4801e-02, -5.3148e-01, -7.9854e-15,
-        -5.9350e-12, -5.8380e-13,  4.4310e-02, -1.3503e-15,  1.6235e-11,
-         3.1639e-01,  2.4145e-01, -2.8675e-15, -8.9891e-01, -4.9640e-02,
-         2.3273e-11,  2.1889e-01,  5.5079e-02,  1.7035e-01,  4.1704e-15,
-        -4.2391e-01, -1.1949e+00, -1.0675e-18, -3.2751e-01,  5.2987e-13,
-         2.8590e-11,  1.7675e-14,  0.0000e+00,  1.6976e-01,  4.4156e-01,
-        -1.2947e-01, -2.3608e-12, -3.9670e-01,  4.3501e-15,  1.0640e-12,
-        -2.0468e-01, -1.5012e-10,  4.8441e-04,  5.0502e-10, -2.7693e-16,
-        -8.9603e-07,  9.5233e-02, -1.9198e-01,  1.6677e-01, -1.0148e-07,
-         5.7635e-20, -2.2651e-14,  8.1463e-18,  6.0210e-01, -1.1765e-09,
-         2.1761e-17, -2.2801e-01,  1.8283e-05,  3.8462e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4041,  0.0000, -0.0719,  0.0617,  2.1431,  0.1768,  0.0000,  0.0000,
-         0.0923,  0.0649, -0.2868,  0.0000, -0.0748, -0.5315,  0.0000,  0.0000,
-         0.0000,  0.0443,  0.0000,  0.0000,  0.3164,  0.2414,  0.0000, -0.8989,
-        -0.0496,  0.0000,  0.2189,  0.0551,  0.1704,  0.0000, -0.4239, -1.1949,
-         0.0000, -0.3275,  0.0000,  0.0000,  0.0000,  0.0000,  0.1698,  0.4416,
-        -0.1295,  0.0000, -0.3967,  0.0000,  0.0000, -0.2047,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0952, -0.1920,  0.1668,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6021,  0.0000,  0.0000, -0.2280,  0.0000,  0.0385],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4041,  0.0000, -0.0719,  0.0617,  2.1431,  0.1768,  0.0000,  0.0000,
-         0.0923,  0.0649, -0.2868,  0.0000, -0.0748, -0.5315,  0.0000,  0.0000,
-         0.0000,  0.0443,  0.0000,  0.0000,  0.3164,  0.2414,  0.0000, -0.8989,
-        -0.0496,  0.0000,  0.2189,  0.0551,  0.1704,  0.0000, -0.4239, -1.1949,
-         0.0000, -0.3275,  0.0000,  0.0000,  0.0000,  0.0000,  0.1698,  0.4416,
-        -0.1295,  0.0000, -0.3967,  0.0000,  0.0000, -0.2047,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0952, -0.1920,  0.1668,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6021,  0.0000,  0.0000, -0.2280,  0.0000,  0.0385],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9583e-01, -6.6445e-09, -7.8200e-02,  5.3416e-02,  2.1392e+00,
-         1.7083e-01,  4.0220e-18, -3.1758e-11,  1.0095e-01,  5.8373e-02,
-        -3.2230e-01,  7.2305e-10, -1.1711e-01, -5.3943e-01, -7.2126e-15,
-        -5.3607e-12, -5.2731e-13, -1.1187e-02, -1.2196e-15,  1.4664e-11,
-         2.8215e-01,  2.1569e-01, -2.5900e-15, -9.1368e-01, -5.3962e-02,
-         2.1021e-11,  2.3058e-01,  5.4624e-02,  1.5303e-01,  3.7668e-15,
-        -3.9872e-01, -1.2010e+00, -9.6423e-19, -2.6792e-01,  4.7860e-13,
-         2.5823e-11,  1.5965e-14,  0.0000e+00,  1.3462e-01,  4.2412e-01,
-        -1.4048e-01, -2.1324e-12, -4.1408e-01,  3.9291e-15,  9.6102e-13,
-        -2.4794e-01, -1.3560e-10,  4.3753e-04,  4.5615e-10, -2.5013e-16,
-        -8.0932e-07,  7.1522e-02, -2.2935e-01,  1.0977e-01, -9.1662e-08,
-         5.2057e-20, -2.0459e-14,  7.3580e-18,  5.9550e-01, -1.0626e-09,
-         1.9655e-17, -2.3231e-01,  1.6514e-05,  2.0656e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3958,  0.0000, -0.0782,  0.0534,  2.1392,  0.1708,  0.0000,  0.0000,
-         0.1009,  0.0584, -0.3223,  0.0000, -0.1171, -0.5394,  0.0000,  0.0000,
-         0.0000, -0.0112,  0.0000,  0.0000,  0.2821,  0.2157,  0.0000, -0.9137,
-        -0.0540,  0.0000,  0.2306,  0.0546,  0.1530,  0.0000, -0.3987, -1.2010,
-         0.0000, -0.2679,  0.0000,  0.0000,  0.0000,  0.0000,  0.1346,  0.4241,
-        -0.1405,  0.0000, -0.4141,  0.0000,  0.0000, -0.2479,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0715, -0.2294,  0.1098,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5955,  0.0000,  0.0000, -0.2323,  0.0000,  0.0207],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3958,  0.0000, -0.0782,  0.0534,  2.1392,  0.1708,  0.0000,  0.0000,
-         0.1009,  0.0584, -0.3223,  0.0000, -0.1171, -0.5394,  0.0000,  0.0000,
-         0.0000, -0.0112,  0.0000,  0.0000,  0.2821,  0.2157,  0.0000, -0.9137,
-        -0.0540,  0.0000,  0.2306,  0.0546,  0.1530,  0.0000, -0.3987, -1.2010,
-         0.0000, -0.2679,  0.0000,  0.0000,  0.0000,  0.0000,  0.1346,  0.4241,
-        -0.1405,  0.0000, -0.4141,  0.0000,  0.0000, -0.2479,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0715, -0.2294,  0.1098,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5955,  0.0000,  0.0000, -0.2323,  0.0000,  0.0207],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9069e-01, -6.0036e-09, -8.5107e-02,  5.0641e-02,  2.1372e+00,
-         1.5501e-01,  3.6341e-18, -2.8695e-11,  9.3670e-02,  5.3434e-02,
-        -3.2931e-01,  6.5331e-10, -1.5960e-01, -5.5094e-01, -6.5169e-15,
-        -4.8436e-12, -4.7645e-13, -2.5864e-02, -1.1020e-15,  1.3249e-11,
-         2.3870e-01,  1.8193e-01, -2.3402e-15, -9.2260e-01, -5.7603e-02,
-         1.8993e-11,  2.3431e-01,  4.1574e-02,  1.4704e-01,  3.4035e-15,
-        -3.7446e-01, -1.2043e+00, -8.7123e-19, -2.1458e-01,  4.3243e-13,
-         2.3333e-11,  1.4425e-14,  0.0000e+00,  1.1699e-01,  4.2401e-01,
-        -1.3944e-01, -1.9267e-12, -4.1944e-01,  3.5501e-15,  8.6832e-13,
-        -2.7847e-01, -1.2252e-10,  3.9533e-04,  4.1215e-10, -2.2601e-16,
-        -7.3126e-07,  3.8147e-02, -2.3733e-01,  8.3300e-02, -8.2821e-08,
-         4.7036e-20, -1.8486e-14,  6.6482e-18,  5.8993e-01, -9.6015e-10,
-         1.7759e-17, -2.2976e-01,  1.4921e-05, -1.1387e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3907,  0.0000, -0.0851,  0.0506,  2.1372,  0.1550,  0.0000,  0.0000,
-         0.0937,  0.0534, -0.3293,  0.0000, -0.1596, -0.5509,  0.0000,  0.0000,
-         0.0000, -0.0259,  0.0000,  0.0000,  0.2387,  0.1819,  0.0000, -0.9226,
-        -0.0576,  0.0000,  0.2343,  0.0416,  0.1470,  0.0000, -0.3745, -1.2043,
-         0.0000, -0.2146,  0.0000,  0.0000,  0.0000,  0.0000,  0.1170,  0.4240,
-        -0.1394,  0.0000, -0.4194,  0.0000,  0.0000, -0.2785,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0381, -0.2373,  0.0833,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5899,  0.0000,  0.0000, -0.2298,  0.0000, -0.0114],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3907,  0.0000, -0.0851,  0.0506,  2.1372,  0.1550,  0.0000,  0.0000,
-         0.0937,  0.0534, -0.3293,  0.0000, -0.1596, -0.5509,  0.0000,  0.0000,
-         0.0000, -0.0259,  0.0000,  0.0000,  0.2387,  0.1819,  0.0000, -0.9226,
-        -0.0576,  0.0000,  0.2343,  0.0416,  0.1470,  0.0000, -0.3745, -1.2043,
-         0.0000, -0.2146,  0.0000,  0.0000,  0.0000,  0.0000,  0.1170,  0.4240,
-        -0.1394,  0.0000, -0.4194,  0.0000,  0.0000, -0.2785,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0381, -0.2373,  0.0833,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5899,  0.0000,  0.0000, -0.2298,  0.0000, -0.0114],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.8641e-01, -5.4264e-09, -8.5104e-02,  5.5911e-02,  2.1365e+00,
-         1.4341e-01,  3.2847e-18, -2.5936e-11,  7.8114e-02,  4.0593e-02,
-        -3.3944e-01,  5.9050e-10, -1.8963e-01, -5.5894e-01, -5.8904e-15,
-        -4.3780e-12, -4.3064e-13, -4.1990e-02, -9.9604e-16,  1.1976e-11,
-         2.0089e-01,  1.5649e-01, -2.1152e-15, -9.2898e-01, -4.3973e-02,
-         1.7167e-11,  2.5516e-01,  3.9154e-02,  1.5059e-01,  3.0763e-15,
-        -3.4980e-01, -1.2072e+00, -7.8747e-19, -1.7054e-01,  3.9086e-13,
-         2.1089e-11,  1.3038e-14,  0.0000e+00,  1.0565e-01,  4.1317e-01,
-        -1.2028e-01, -1.7415e-12, -4.2372e-01,  3.2088e-15,  7.8484e-13,
-        -3.1389e-01, -1.1074e-10,  3.5733e-04,  3.7253e-10, -2.0428e-16,
-        -6.6096e-07,  1.0251e-03, -2.5526e-01,  8.4742e-02, -7.4859e-08,
-         4.2514e-20, -1.6709e-14,  6.0091e-18,  5.7808e-01, -8.6784e-10,
-         1.6052e-17, -2.3319e-01,  1.3487e-05, -4.5115e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.8641e-01,  0.0000e+00, -8.5104e-02,  5.5911e-02,  2.1365e+00,
-         1.4341e-01,  0.0000e+00,  0.0000e+00,  7.8114e-02,  4.0593e-02,
-        -3.3944e-01,  0.0000e+00, -1.8963e-01, -5.5894e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.1990e-02,  0.0000e+00,  0.0000e+00,
-         2.0089e-01,  1.5649e-01,  0.0000e+00, -9.2898e-01, -4.3973e-02,
-         0.0000e+00,  2.5516e-01,  3.9154e-02,  1.5059e-01,  0.0000e+00,
-        -3.4980e-01, -1.2072e+00,  0.0000e+00, -1.7054e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0565e-01,  4.1317e-01,
-        -1.2028e-01,  0.0000e+00, -4.2372e-01,  0.0000e+00,  0.0000e+00,
-        -3.1389e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.0251e-03, -2.5526e-01,  8.4742e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.7808e-01,  0.0000e+00,
-         0.0000e+00, -2.3319e-01,  0.0000e+00, -4.5115e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.8641e-01,  0.0000e+00, -8.5104e-02,  5.5911e-02,  2.1365e+00,
-         1.4341e-01,  0.0000e+00,  0.0000e+00,  7.8114e-02,  4.0593e-02,
-        -3.3944e-01,  0.0000e+00, -1.8963e-01, -5.5894e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.1990e-02,  0.0000e+00,  0.0000e+00,
-         2.0089e-01,  1.5649e-01,  0.0000e+00, -9.2898e-01, -4.3973e-02,
-         0.0000e+00,  2.5516e-01,  3.9154e-02,  1.5059e-01,  0.0000e+00,
-        -3.4980e-01, -1.2072e+00,  0.0000e+00, -1.7054e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0565e-01,  4.1317e-01,
-        -1.2028e-01,  0.0000e+00, -4.2372e-01,  0.0000e+00,  0.0000e+00,
-        -3.1389e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.0251e-03, -2.5526e-01,  8.4742e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.7808e-01,  0.0000e+00,
-         0.0000e+00, -2.3319e-01,  0.0000e+00, -4.5115e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9714e-01, -4.9064e-09, -7.8195e-02,  6.7568e-02,  2.1378e+00,
-         1.3090e-01,  2.9699e-18, -2.3451e-11,  6.5038e-02,  1.5987e-02,
-        -3.2268e-01,  5.3391e-10, -2.0468e-01, -5.7340e-01, -5.3259e-15,
-        -3.9584e-12, -3.8937e-13, -3.0201e-02, -9.0059e-16,  1.0828e-11,
-         1.6396e-01,  1.2288e-01, -1.9125e-15, -9.2925e-01, -6.3366e-02,
-         1.5522e-11,  2.8130e-01,  4.2477e-02,  1.6243e-01,  2.7815e-15,
-        -3.4495e-01, -1.2086e+00, -7.1201e-19, -1.3689e-01,  3.5341e-13,
-         1.9068e-11,  1.1789e-14,  0.0000e+00,  1.0668e-01,  4.0531e-01,
-        -1.0487e-01, -1.5746e-12, -4.1730e-01,  2.9013e-15,  7.0963e-13,
-        -3.5162e-01, -1.0013e-10,  3.2308e-04,  3.3683e-10, -1.8470e-16,
-        -5.9762e-07, -5.3570e-02, -2.5745e-01,  1.0981e-01, -6.7685e-08,
-         3.8440e-20, -1.5108e-14,  5.4333e-18,  5.6280e-01, -7.8468e-10,
-         1.4513e-17, -2.2947e-01,  1.2194e-05, -6.2740e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3971,  0.0000, -0.0782,  0.0676,  2.1378,  0.1309,  0.0000,  0.0000,
-         0.0650,  0.0160, -0.3227,  0.0000, -0.2047, -0.5734,  0.0000,  0.0000,
-         0.0000, -0.0302,  0.0000,  0.0000,  0.1640,  0.1229,  0.0000, -0.9293,
-        -0.0634,  0.0000,  0.2813,  0.0425,  0.1624,  0.0000, -0.3449, -1.2086,
-         0.0000, -0.1369,  0.0000,  0.0000,  0.0000,  0.0000,  0.1067,  0.4053,
-        -0.1049,  0.0000, -0.4173,  0.0000,  0.0000, -0.3516,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0536, -0.2574,  0.1098,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5628,  0.0000,  0.0000, -0.2295,  0.0000, -0.0627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3971,  0.0000, -0.0782,  0.0676,  2.1378,  0.1309,  0.0000,  0.0000,
-         0.0650,  0.0160, -0.3227,  0.0000, -0.2047, -0.5734,  0.0000,  0.0000,
-         0.0000, -0.0302,  0.0000,  0.0000,  0.1640,  0.1229,  0.0000, -0.9293,
-        -0.0634,  0.0000,  0.2813,  0.0425,  0.1624,  0.0000, -0.3449, -1.2086,
-         0.0000, -0.1369,  0.0000,  0.0000,  0.0000,  0.0000,  0.1067,  0.4053,
-        -0.1049,  0.0000, -0.4173,  0.0000,  0.0000, -0.3516,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0536, -0.2574,  0.1098,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5628,  0.0000,  0.0000, -0.2295,  0.0000, -0.0627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0914e-01, -4.4378e-09, -7.3298e-02,  8.6441e-02,  2.1403e+00,
-         1.2633e-01,  2.6863e-18, -2.1211e-11,  6.3865e-02, -1.4888e-02,
-        -2.8912e-01,  4.8291e-10, -2.1821e-01, -5.8506e-01, -4.8172e-15,
-        -3.5803e-12, -3.5218e-13, -4.6212e-03, -8.1457e-16,  9.7938e-12,
-         1.3957e-01,  9.3924e-02, -1.7298e-15, -9.2136e-01, -8.5674e-02,
-         1.4040e-11,  3.0935e-01,  6.4437e-02,  1.7740e-01,  2.5158e-15,
-        -3.6428e-01, -1.2063e+00, -6.4400e-19, -9.1872e-02,  3.1965e-13,
-         1.7247e-11,  1.0663e-14,  0.0000e+00,  1.0485e-01,  3.9285e-01,
-        -7.8289e-02, -1.4242e-12, -4.1797e-01,  2.6242e-15,  6.4185e-13,
-        -4.1902e-01, -9.0564e-11,  2.9222e-04,  3.0466e-10, -1.6706e-16,
-        -5.4054e-07, -7.4744e-02, -2.6964e-01,  1.1383e-01, -6.1220e-08,
-         3.4768e-20, -1.3665e-14,  4.9143e-18,  5.4747e-01, -7.0973e-10,
-         1.3127e-17, -2.4801e-01,  1.1029e-05, -8.5391e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4091,  0.0000, -0.0733,  0.0864,  2.1403,  0.1263,  0.0000,  0.0000,
-         0.0639, -0.0149, -0.2891,  0.0000, -0.2182, -0.5851,  0.0000,  0.0000,
-         0.0000, -0.0046,  0.0000,  0.0000,  0.1396,  0.0939,  0.0000, -0.9214,
-        -0.0857,  0.0000,  0.3093,  0.0644,  0.1774,  0.0000, -0.3643, -1.2063,
-         0.0000, -0.0919,  0.0000,  0.0000,  0.0000,  0.0000,  0.1048,  0.3928,
-        -0.0783,  0.0000, -0.4180,  0.0000,  0.0000, -0.4190,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0747, -0.2696,  0.1138,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5475,  0.0000,  0.0000, -0.2480,  0.0000, -0.0854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4091,  0.0000, -0.0733,  0.0864,  2.1403,  0.1263,  0.0000,  0.0000,
-         0.0639, -0.0149, -0.2891,  0.0000, -0.2182, -0.5851,  0.0000,  0.0000,
-         0.0000, -0.0046,  0.0000,  0.0000,  0.1396,  0.0939,  0.0000, -0.9214,
-        -0.0857,  0.0000,  0.3093,  0.0644,  0.1774,  0.0000, -0.3643, -1.2063,
-         0.0000, -0.0919,  0.0000,  0.0000,  0.0000,  0.0000,  0.1048,  0.3928,
-        -0.0783,  0.0000, -0.4180,  0.0000,  0.0000, -0.4190,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0747, -0.2696,  0.1138,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5475,  0.0000,  0.0000, -0.2480,  0.0000, -0.0854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1951e-01, -4.0153e-09, -4.8519e-02,  1.2564e-01,  2.1434e+00,
-         1.4247e-01,  2.4305e-18, -1.9192e-11,  5.5981e-02, -1.8101e-02,
-        -2.4667e-01,  4.3694e-10, -2.2248e-01, -5.8814e-01, -4.3586e-15,
-        -3.2394e-12, -3.1865e-13, -5.4427e-02, -7.3701e-16,  8.8614e-12,
-         1.1645e-01,  6.5525e-02, -1.5651e-15, -9.1879e-01, -9.3407e-02,
-         1.2703e-11,  3.4264e-01,  8.6330e-02,  1.8878e-01,  2.2763e-15,
-        -3.7527e-01, -1.2056e+00, -5.8268e-19, -5.3023e-02,  2.8922e-13,
-         1.5605e-11,  9.6476e-15,  0.0000e+00,  1.4604e-01,  4.2025e-01,
-        -6.6499e-02, -1.2886e-12, -4.0602e-01,  2.3743e-15,  5.8074e-13,
-        -4.7286e-01, -8.1941e-11,  2.6440e-04,  2.7565e-10, -1.5116e-16,
-        -4.8907e-07, -8.0989e-02, -2.8446e-01,  1.3526e-01, -5.5391e-08,
-         3.1458e-20, -1.2364e-14,  4.4464e-18,  5.2404e-01, -6.4215e-10,
-         1.1877e-17, -2.6457e-01,  9.9793e-06, -1.2422e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4195,  0.0000, -0.0485,  0.1256,  2.1434,  0.1425,  0.0000,  0.0000,
-         0.0560, -0.0181, -0.2467,  0.0000, -0.2225, -0.5881,  0.0000,  0.0000,
-         0.0000, -0.0544,  0.0000,  0.0000,  0.1164,  0.0655,  0.0000, -0.9188,
-        -0.0934,  0.0000,  0.3426,  0.0863,  0.1888,  0.0000, -0.3753, -1.2056,
-         0.0000, -0.0530,  0.0000,  0.0000,  0.0000,  0.0000,  0.1460,  0.4202,
-        -0.0665,  0.0000, -0.4060,  0.0000,  0.0000, -0.4729,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0810, -0.2845,  0.1353,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5240,  0.0000,  0.0000, -0.2646,  0.0000, -0.1242],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4195,  0.0000, -0.0485,  0.1256,  2.1434,  0.1425,  0.0000,  0.0000,
-         0.0560, -0.0181, -0.2467,  0.0000, -0.2225, -0.5881,  0.0000,  0.0000,
-         0.0000, -0.0544,  0.0000,  0.0000,  0.1164,  0.0655,  0.0000, -0.9188,
-        -0.0934,  0.0000,  0.3426,  0.0863,  0.1888,  0.0000, -0.3753, -1.2056,
-         0.0000, -0.0530,  0.0000,  0.0000,  0.0000,  0.0000,  0.1460,  0.4202,
-        -0.0665,  0.0000, -0.4060,  0.0000,  0.0000, -0.4729,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0810, -0.2845,  0.1353,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5240,  0.0000,  0.0000, -0.2646,  0.0000, -0.1242],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3152e-01, -3.6342e-09, -3.6905e-02,  1.4011e-01,  2.1469e+00,
-         1.3631e-01,  2.1998e-18, -1.7370e-11,  5.4049e-02, -3.2134e-02,
-        -1.8243e-01,  3.9547e-10, -2.2764e-01, -5.8756e-01, -3.9449e-15,
-        -2.9320e-12, -2.8841e-13, -9.1807e-02, -6.6707e-16,  8.0204e-12,
-         9.0834e-02,  4.1555e-02, -1.4166e-15, -9.1402e-01, -9.7152e-02,
-         1.1497e-11,  3.4283e-01,  1.0495e-01,  2.0176e-01,  2.0603e-15,
-        -3.9133e-01, -1.2041e+00, -5.2738e-19, -1.2702e-02,  2.6177e-13,
-         1.4124e-11,  8.7320e-15,  0.0000e+00,  1.8983e-01,  4.5268e-01,
-        -1.7059e-02, -1.1663e-12, -3.9586e-01,  2.1490e-15,  5.2563e-13,
-        -5.1225e-01, -7.4165e-11,  2.3931e-04,  2.4949e-10, -1.3681e-16,
-        -4.4266e-07, -7.3561e-02, -2.7951e-01,  1.0603e-01, -5.0135e-08,
-         2.8473e-20, -1.1190e-14,  4.0244e-18,  4.9906e-01, -5.8121e-10,
-         1.0750e-17, -2.7100e-01,  9.0322e-06, -1.6562e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4315,  0.0000, -0.0369,  0.1401,  2.1469,  0.1363,  0.0000,  0.0000,
-         0.0540, -0.0321, -0.1824,  0.0000, -0.2276, -0.5876,  0.0000,  0.0000,
-         0.0000, -0.0918,  0.0000,  0.0000,  0.0908,  0.0416,  0.0000, -0.9140,
-        -0.0972,  0.0000,  0.3428,  0.1050,  0.2018,  0.0000, -0.3913, -1.2041,
-         0.0000, -0.0127,  0.0000,  0.0000,  0.0000,  0.0000,  0.1898,  0.4527,
-        -0.0171,  0.0000, -0.3959,  0.0000,  0.0000, -0.5122,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0736, -0.2795,  0.1060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4991,  0.0000,  0.0000, -0.2710,  0.0000, -0.1656],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4315,  0.0000, -0.0369,  0.1401,  2.1469,  0.1363,  0.0000,  0.0000,
-         0.0540, -0.0321, -0.1824,  0.0000, -0.2276, -0.5876,  0.0000,  0.0000,
-         0.0000, -0.0918,  0.0000,  0.0000,  0.0908,  0.0416,  0.0000, -0.9140,
-        -0.0972,  0.0000,  0.3428,  0.1050,  0.2018,  0.0000, -0.3913, -1.2041,
-         0.0000, -0.0127,  0.0000,  0.0000,  0.0000,  0.0000,  0.1898,  0.4527,
-        -0.0171,  0.0000, -0.3959,  0.0000,  0.0000, -0.5122,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0736, -0.2795,  0.1060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4991,  0.0000,  0.0000, -0.2710,  0.0000, -0.1656],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4302e-01, -3.2904e-09, -3.8418e-02,  1.3441e-01,  2.1490e+00,
-         1.1444e-01,  1.9917e-18, -1.5727e-11,  4.7702e-02, -4.8435e-02,
-        -1.1545e-01,  3.5806e-10, -2.4063e-01, -5.7813e-01, -3.5717e-15,
-        -2.6546e-12, -2.6113e-13, -1.1651e-01, -6.0396e-16,  7.2616e-12,
-         7.4061e-02,  2.7562e-02, -1.2826e-15, -9.0458e-01, -9.5855e-02,
-         1.0410e-11,  3.2987e-01,  1.1572e-01,  2.0123e-01,  1.8654e-15,
-        -3.7424e-01, -1.2018e+00, -4.7749e-19,  1.9392e-02,  2.3700e-13,
-         1.2788e-11,  7.9059e-15,  0.0000e+00,  2.1621e-01,  4.7830e-01,
-         3.9080e-02, -1.0560e-12, -3.8689e-01,  1.9457e-15,  4.7590e-13,
-        -5.3262e-01, -6.7149e-11,  2.1667e-04,  2.2589e-10, -1.2387e-16,
-        -4.0078e-07, -7.1431e-02, -2.6423e-01,  4.8083e-02, -4.5392e-08,
-         2.5779e-20, -1.0132e-14,  3.6437e-18,  4.7320e-01, -5.2623e-10,
-         9.7332e-18, -2.6505e-01,  8.1778e-06, -2.0608e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4430,  0.0000, -0.0384,  0.1344,  2.1490,  0.1144,  0.0000,  0.0000,
-         0.0477, -0.0484, -0.1155,  0.0000, -0.2406, -0.5781,  0.0000,  0.0000,
-         0.0000, -0.1165,  0.0000,  0.0000,  0.0741,  0.0276,  0.0000, -0.9046,
-        -0.0959,  0.0000,  0.3299,  0.1157,  0.2012,  0.0000, -0.3742, -1.2018,
-         0.0000,  0.0194,  0.0000,  0.0000,  0.0000,  0.0000,  0.2162,  0.4783,
-         0.0391,  0.0000, -0.3869,  0.0000,  0.0000, -0.5326,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0714, -0.2642,  0.0481,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4732,  0.0000,  0.0000, -0.2650,  0.0000, -0.2061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4430,  0.0000, -0.0384,  0.1344,  2.1490,  0.1144,  0.0000,  0.0000,
-         0.0477, -0.0484, -0.1155,  0.0000, -0.2406, -0.5781,  0.0000,  0.0000,
-         0.0000, -0.1165,  0.0000,  0.0000,  0.0741,  0.0276,  0.0000, -0.9046,
-        -0.0959,  0.0000,  0.3299,  0.1157,  0.2012,  0.0000, -0.3742, -1.2018,
-         0.0000,  0.0194,  0.0000,  0.0000,  0.0000,  0.0000,  0.2162,  0.4783,
-         0.0391,  0.0000, -0.3869,  0.0000,  0.0000, -0.5326,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0714, -0.2642,  0.0481,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4732,  0.0000,  0.0000, -0.2650,  0.0000, -0.2061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5069e-01, -2.9801e-09, -4.7820e-02,  1.1929e-01,  2.1510e+00,
-         8.7137e-02,  1.8039e-18, -1.4244e-11,  3.7938e-02, -5.9939e-02,
-        -6.7333e-02,  3.2429e-10, -2.7175e-01, -5.6222e-01, -3.2349e-15,
-        -2.4043e-12, -2.3650e-13, -1.2331e-01, -5.4701e-16,  6.5769e-12,
-         5.9394e-02,  2.5223e-02, -1.1616e-15, -8.9001e-01, -9.8985e-02,
-         9.4280e-12,  3.0495e-01,  1.2966e-01,  1.9925e-01,  1.6895e-15,
-        -3.3493e-01, -1.2005e+00, -4.3246e-19,  5.8012e-02,  2.1465e-13,
-         1.1582e-11,  7.1604e-15,  0.0000e+00,  2.1573e-01,  4.9980e-01,
-         1.0437e-01, -9.5639e-13, -3.8438e-01,  1.7622e-15,  4.3102e-13,
-        -5.5048e-01, -6.0816e-11,  1.9624e-04,  2.0459e-10, -1.1219e-16,
-        -3.6299e-07, -7.5312e-02, -2.5040e-01,  3.0210e-02, -4.1111e-08,
-         2.3348e-20, -9.1762e-15,  3.3001e-18,  4.5615e-01, -4.7661e-10,
-         8.8153e-18, -2.5110e-01,  7.4066e-06, -2.5420e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4507,  0.0000, -0.0478,  0.1193,  2.1510,  0.0871,  0.0000,  0.0000,
-         0.0379, -0.0599, -0.0673,  0.0000, -0.2718, -0.5622,  0.0000,  0.0000,
-         0.0000, -0.1233,  0.0000,  0.0000,  0.0594,  0.0252,  0.0000, -0.8900,
-        -0.0990,  0.0000,  0.3049,  0.1297,  0.1992,  0.0000, -0.3349, -1.2005,
-         0.0000,  0.0580,  0.0000,  0.0000,  0.0000,  0.0000,  0.2157,  0.4998,
-         0.1044,  0.0000, -0.3844,  0.0000,  0.0000, -0.5505,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0753, -0.2504,  0.0302,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4561,  0.0000,  0.0000, -0.2511,  0.0000, -0.2542],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4507,  0.0000, -0.0478,  0.1193,  2.1510,  0.0871,  0.0000,  0.0000,
-         0.0379, -0.0599, -0.0673,  0.0000, -0.2718, -0.5622,  0.0000,  0.0000,
-         0.0000, -0.1233,  0.0000,  0.0000,  0.0594,  0.0252,  0.0000, -0.8900,
-        -0.0990,  0.0000,  0.3049,  0.1297,  0.1992,  0.0000, -0.3349, -1.2005,
-         0.0000,  0.0580,  0.0000,  0.0000,  0.0000,  0.0000,  0.2157,  0.4998,
-         0.1044,  0.0000, -0.3844,  0.0000,  0.0000, -0.5505,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0753, -0.2504,  0.0302,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4561,  0.0000,  0.0000, -0.2511,  0.0000, -0.2542],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4885e-01, -2.7000e-09, -6.3930e-02,  1.0563e-01,  2.1523e+00,
-         7.1791e-02,  1.6343e-18, -1.2905e-11,  2.4376e-02, -5.7642e-02,
-        -4.2179e-02,  2.9381e-10, -3.2329e-01, -5.4201e-01, -2.9308e-15,
-        -2.1783e-12, -2.1427e-13, -9.7448e-02, -4.9559e-16,  5.9586e-12,
-         5.5199e-02,  3.6746e-02, -1.0524e-15, -8.6701e-01, -8.7867e-02,
-         8.5418e-12,  2.9679e-01,  1.5468e-01,  2.1523e-01,  1.5306e-15,
-        -2.6481e-01, -1.1969e+00, -3.9181e-19,  1.2046e-01,  1.9448e-13,
-         1.0493e-11,  6.4873e-15,  0.0000e+00,  1.7330e-01,  4.8980e-01,
-         1.4973e-01, -8.6649e-13, -3.9479e-01,  1.5966e-15,  3.9051e-13,
-        -5.7992e-01, -5.5100e-11,  1.7779e-04,  1.8535e-10, -1.0164e-16,
-        -3.2887e-07, -6.8382e-02, -2.5192e-01,  3.1946e-02, -3.7247e-08,
-         2.1153e-20, -8.3136e-15,  2.9899e-18,  4.4105e-01, -4.3180e-10,
-         7.9867e-18, -2.4006e-01,  6.7104e-06, -3.0129e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4489,  0.0000, -0.0639,  0.1056,  2.1523,  0.0718,  0.0000,  0.0000,
-         0.0244, -0.0576, -0.0422,  0.0000, -0.3233, -0.5420,  0.0000,  0.0000,
-         0.0000, -0.0974,  0.0000,  0.0000,  0.0552,  0.0367,  0.0000, -0.8670,
-        -0.0879,  0.0000,  0.2968,  0.1547,  0.2152,  0.0000, -0.2648, -1.1969,
-         0.0000,  0.1205,  0.0000,  0.0000,  0.0000,  0.0000,  0.1733,  0.4898,
-         0.1497,  0.0000, -0.3948,  0.0000,  0.0000, -0.5799,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0684, -0.2519,  0.0319,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4410,  0.0000,  0.0000, -0.2401,  0.0000, -0.3013],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4489,  0.0000, -0.0639,  0.1056,  2.1523,  0.0718,  0.0000,  0.0000,
-         0.0244, -0.0576, -0.0422,  0.0000, -0.3233, -0.5420,  0.0000,  0.0000,
-         0.0000, -0.0974,  0.0000,  0.0000,  0.0552,  0.0367,  0.0000, -0.8670,
-        -0.0879,  0.0000,  0.2968,  0.1547,  0.2152,  0.0000, -0.2648, -1.1969,
-         0.0000,  0.1205,  0.0000,  0.0000,  0.0000,  0.0000,  0.1733,  0.4898,
-         0.1497,  0.0000, -0.3948,  0.0000,  0.0000, -0.5799,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0684, -0.2519,  0.0319,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4410,  0.0000,  0.0000, -0.2401,  0.0000, -0.3013],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5226e-01, -2.4470e-09, -7.0564e-02,  9.7122e-02,  2.1540e+00,
-         6.0743e-02,  1.4812e-18, -1.1696e-11,  1.5772e-02, -4.9603e-02,
-        -3.8418e-03,  2.6628e-10, -3.5757e-01, -5.1838e-01, -2.6562e-15,
-        -1.9742e-12, -1.9419e-13, -8.3801e-02, -4.4915e-16,  5.4003e-12,
-         5.6428e-02,  4.6823e-02, -9.5383e-16, -8.5112e-01, -7.0659e-02,
-         7.7414e-12,  2.8190e-01,  1.7924e-01,  2.2288e-01,  1.3872e-15,
-        -1.9139e-01, -1.1950e+00, -3.5510e-19,  1.5786e-01,  1.7625e-13,
-         9.5099e-12,  5.8794e-15,  0.0000e+00,  1.2869e-01,  4.9997e-01,
-         1.8625e-01, -7.8529e-13, -3.9704e-01,  1.4470e-15,  3.5391e-13,
-        -5.7498e-01, -4.9936e-11,  1.6113e-04,  1.6799e-10, -9.2117e-17,
-        -2.9805e-07, -5.8489e-02, -2.5481e-01,  2.2397e-02, -3.3757e-08,
-         1.9171e-20, -7.5346e-15,  2.7097e-18,  4.2627e-01, -3.9134e-10,
-         7.2383e-18, -2.2028e-01,  6.0816e-06, -3.3484e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4523,  0.0000, -0.0706,  0.0971,  2.1540,  0.0607,  0.0000,  0.0000,
-         0.0158, -0.0496, -0.0038,  0.0000, -0.3576, -0.5184,  0.0000,  0.0000,
-         0.0000, -0.0838,  0.0000,  0.0000,  0.0564,  0.0468,  0.0000, -0.8511,
-        -0.0707,  0.0000,  0.2819,  0.1792,  0.2229,  0.0000, -0.1914, -1.1950,
-         0.0000,  0.1579,  0.0000,  0.0000,  0.0000,  0.0000,  0.1287,  0.5000,
-         0.1862,  0.0000, -0.3970,  0.0000,  0.0000, -0.5750,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0585, -0.2548,  0.0224,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4263,  0.0000,  0.0000, -0.2203,  0.0000, -0.3348],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4523,  0.0000, -0.0706,  0.0971,  2.1540,  0.0607,  0.0000,  0.0000,
-         0.0158, -0.0496, -0.0038,  0.0000, -0.3576, -0.5184,  0.0000,  0.0000,
-         0.0000, -0.0838,  0.0000,  0.0000,  0.0564,  0.0468,  0.0000, -0.8511,
-        -0.0707,  0.0000,  0.2819,  0.1792,  0.2229,  0.0000, -0.1914, -1.1950,
-         0.0000,  0.1579,  0.0000,  0.0000,  0.0000,  0.0000,  0.1287,  0.5000,
-         0.1862,  0.0000, -0.3970,  0.0000,  0.0000, -0.5750,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0585, -0.2548,  0.0224,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4263,  0.0000,  0.0000, -0.2203,  0.0000, -0.3348],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5117e-01, -2.2184e-09, -6.8056e-02,  9.7886e-02,  2.1546e+00,
-         6.9671e-02,  1.3428e-18, -1.0603e-11, -4.4980e-03, -3.1720e-02,
-         1.4815e-02,  2.4140e-10, -3.8746e-01, -5.0271e-01, -2.4081e-15,
-        -1.7898e-12, -1.7605e-13, -7.9299e-02, -4.0719e-16,  4.8958e-12,
-         7.3800e-02,  3.7792e-02, -8.6473e-16, -8.4892e-01, -5.4301e-02,
-         7.0182e-12,  2.8454e-01,  2.1942e-01,  2.3870e-01,  1.2576e-15,
-        -1.2650e-01, -1.1929e+00, -3.2193e-19,  1.9131e-01,  1.5979e-13,
-         8.6216e-12,  5.3302e-15,  0.0000e+00,  8.4757e-02,  5.1084e-01,
-         2.0495e-01, -7.1193e-13, -3.9683e-01,  1.3118e-15,  3.2085e-13,
-        -5.6878e-01, -4.5272e-11,  1.4608e-04,  1.5229e-10, -8.3512e-17,
-        -2.7021e-07, -4.3175e-02, -2.5561e-01,  1.0178e-02, -3.0603e-08,
-         1.7380e-20, -6.8308e-15,  2.4566e-18,  4.2125e-01, -3.5478e-10,
-         6.5621e-18, -1.9989e-01,  5.5135e-06, -3.5772e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4512,  0.0000, -0.0681,  0.0979,  2.1546,  0.0697,  0.0000,  0.0000,
-        -0.0045, -0.0317,  0.0148,  0.0000, -0.3875, -0.5027,  0.0000,  0.0000,
-         0.0000, -0.0793,  0.0000,  0.0000,  0.0738,  0.0378,  0.0000, -0.8489,
-        -0.0543,  0.0000,  0.2845,  0.2194,  0.2387,  0.0000, -0.1265, -1.1929,
-         0.0000,  0.1913,  0.0000,  0.0000,  0.0000,  0.0000,  0.0848,  0.5108,
-         0.2050,  0.0000, -0.3968,  0.0000,  0.0000, -0.5688,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0432, -0.2556,  0.0102,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4213,  0.0000,  0.0000, -0.1999,  0.0000, -0.3577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4512,  0.0000, -0.0681,  0.0979,  2.1546,  0.0697,  0.0000,  0.0000,
-        -0.0045, -0.0317,  0.0148,  0.0000, -0.3875, -0.5027,  0.0000,  0.0000,
-         0.0000, -0.0793,  0.0000,  0.0000,  0.0738,  0.0378,  0.0000, -0.8489,
-        -0.0543,  0.0000,  0.2845,  0.2194,  0.2387,  0.0000, -0.1265, -1.1929,
-         0.0000,  0.1913,  0.0000,  0.0000,  0.0000,  0.0000,  0.0848,  0.5108,
-         0.2050,  0.0000, -0.3968,  0.0000,  0.0000, -0.5688,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0432, -0.2556,  0.0102,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4213,  0.0000,  0.0000, -0.1999,  0.0000, -0.3577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5410e-01, -2.0118e-09, -5.3254e-02,  9.1695e-02,  2.1544e+00,
-         5.7660e-02,  1.2178e-18, -9.6158e-12, -1.7433e-02, -8.7435e-03,
-         3.6701e-02,  2.1892e-10, -3.9375e-01, -4.9025e-01, -2.1838e-15,
-        -1.6231e-12, -1.5966e-13, -7.8029e-02, -3.6927e-16,  4.4399e-12,
-         6.7282e-02,  2.6136e-02, -7.8420e-16, -8.6148e-01, -3.3450e-02,
-         6.3647e-12,  2.7109e-01,  2.2843e-01,  2.3383e-01,  1.1405e-15,
-        -9.2931e-02, -1.1907e+00, -2.9195e-19,  2.0636e-01,  1.4491e-13,
-         7.8187e-12,  4.8338e-15,  0.0000e+00,  7.1849e-02,  5.3892e-01,
-         2.1767e-01, -6.4564e-13, -3.8712e-01,  1.1896e-15,  2.9098e-13,
-        -5.4051e-01, -4.1056e-11,  1.3248e-04,  1.3811e-10, -7.5735e-17,
-        -2.4504e-07, -1.0451e-02, -2.5879e-01, -2.8148e-03, -2.7753e-08,
-         1.5762e-20, -6.1947e-15,  2.2278e-18,  4.0812e-01, -3.2175e-10,
-         5.9510e-18, -1.7739e-01,  5.0000e-06, -3.8542e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4541,  0.0000, -0.0533,  0.0917,  2.1544,  0.0577,  0.0000,  0.0000,
-        -0.0174, -0.0087,  0.0367,  0.0000, -0.3937, -0.4903,  0.0000,  0.0000,
-         0.0000, -0.0780,  0.0000,  0.0000,  0.0673,  0.0261,  0.0000, -0.8615,
-        -0.0335,  0.0000,  0.2711,  0.2284,  0.2338,  0.0000, -0.0929, -1.1907,
-         0.0000,  0.2064,  0.0000,  0.0000,  0.0000,  0.0000,  0.0718,  0.5389,
-         0.2177,  0.0000, -0.3871,  0.0000,  0.0000, -0.5405,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0105, -0.2588, -0.0028,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4081,  0.0000,  0.0000, -0.1774,  0.0000, -0.3854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4541,  0.0000, -0.0533,  0.0917,  2.1544,  0.0577,  0.0000,  0.0000,
-        -0.0174, -0.0087,  0.0367,  0.0000, -0.3937, -0.4903,  0.0000,  0.0000,
-         0.0000, -0.0780,  0.0000,  0.0000,  0.0673,  0.0261,  0.0000, -0.8615,
-        -0.0335,  0.0000,  0.2711,  0.2284,  0.2338,  0.0000, -0.0929, -1.1907,
-         0.0000,  0.2064,  0.0000,  0.0000,  0.0000,  0.0000,  0.0718,  0.5389,
-         0.2177,  0.0000, -0.3871,  0.0000,  0.0000, -0.5405,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0105, -0.2588, -0.0028,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4081,  0.0000,  0.0000, -0.1774,  0.0000, -0.3854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3936e-01, -1.8250e-09, -3.3250e-02,  6.8244e-02,  2.1550e+00,
-         3.5034e-02,  1.1047e-18, -8.7231e-12, -2.6099e-02, -1.1488e-02,
-         8.1688e-02,  1.9860e-10, -3.8277e-01, -4.8096e-01, -1.9811e-15,
-        -1.4724e-12, -1.4484e-13, -9.5814e-02, -3.3499e-16,  4.0277e-12,
-         4.6840e-02, -1.3753e-02, -7.1140e-16, -8.7648e-01, -2.8910e-02,
-         5.7738e-12,  2.5531e-01,  2.1836e-01,  2.2211e-01,  1.0346e-15,
-        -8.4111e-02, -1.1892e+00, -2.6484e-19,  2.1351e-01,  1.3146e-13,
-         7.0929e-12,  4.3851e-15,  0.0000e+00,  8.6293e-02,  5.7013e-01,
-         2.2374e-01, -5.8570e-13, -3.7380e-01,  1.0792e-15,  2.6396e-13,
-        -5.0942e-01, -3.7244e-11,  1.2018e-04,  1.2529e-10, -6.8705e-17,
-        -2.2230e-07,  2.0181e-02, -2.4948e-01, -6.8828e-02, -2.5177e-08,
-         1.4299e-20, -5.6196e-15,  2.0210e-18,  4.0100e-01, -2.9188e-10,
-         5.3986e-18, -1.6735e-01,  4.5359e-06, -4.0563e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4394,  0.0000, -0.0333,  0.0682,  2.1550,  0.0350,  0.0000,  0.0000,
-        -0.0261, -0.0115,  0.0817,  0.0000, -0.3828, -0.4810,  0.0000,  0.0000,
-         0.0000, -0.0958,  0.0000,  0.0000,  0.0468, -0.0138,  0.0000, -0.8765,
-        -0.0289,  0.0000,  0.2553,  0.2184,  0.2221,  0.0000, -0.0841, -1.1892,
-         0.0000,  0.2135,  0.0000,  0.0000,  0.0000,  0.0000,  0.0863,  0.5701,
-         0.2237,  0.0000, -0.3738,  0.0000,  0.0000, -0.5094,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0202, -0.2495, -0.0688,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4010,  0.0000,  0.0000, -0.1673,  0.0000, -0.4056],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4394,  0.0000, -0.0333,  0.0682,  2.1550,  0.0350,  0.0000,  0.0000,
-        -0.0261, -0.0115,  0.0817,  0.0000, -0.3828, -0.4810,  0.0000,  0.0000,
-         0.0000, -0.0958,  0.0000,  0.0000,  0.0468, -0.0138,  0.0000, -0.8765,
-        -0.0289,  0.0000,  0.2553,  0.2184,  0.2221,  0.0000, -0.0841, -1.1892,
-         0.0000,  0.2135,  0.0000,  0.0000,  0.0000,  0.0000,  0.0863,  0.5701,
-         0.2237,  0.0000, -0.3738,  0.0000,  0.0000, -0.5094,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0202, -0.2495, -0.0688,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4010,  0.0000,  0.0000, -0.1673,  0.0000, -0.4056],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1626e-01, -1.6561e-09, -5.6330e-03,  5.2609e-02,  2.1543e+00,
-         1.0344e-02,  1.0025e-18, -7.9158e-12, -2.8198e-02, -1.2641e-02,
-         1.3531e-01,  1.8022e-10, -3.6197e-01, -4.8003e-01, -1.7977e-15,
-        -1.3361e-12, -1.3143e-13, -1.1626e-01, -3.0399e-16,  3.6550e-12,
-         2.6060e-02, -7.2710e-02, -6.4556e-16, -8.9298e-01, -2.3719e-02,
-         5.2394e-12,  2.5809e-01,  2.1408e-01,  1.9543e-01,  9.3888e-16,
-        -5.8618e-02, -1.1883e+00, -2.4033e-19,  2.0416e-01,  1.1929e-13,
-         6.4364e-12,  3.9792e-15,  0.0000e+00,  8.9009e-02,  5.9202e-01,
-         2.1928e-01, -5.3149e-13, -3.5344e-01,  9.7933e-16,  2.3953e-13,
-        -4.5633e-01, -3.3797e-11,  1.0906e-04,  1.1369e-10, -6.2346e-17,
-        -2.0172e-07,  5.1264e-02, -2.3042e-01, -1.2935e-01, -2.2847e-08,
-         1.2975e-20, -5.0995e-15,  1.8340e-18,  4.0559e-01, -2.6486e-10,
-         4.8989e-18, -1.3574e-01,  4.1161e-06, -4.1011e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4163,  0.0000, -0.0056,  0.0526,  2.1543,  0.0103,  0.0000,  0.0000,
-        -0.0282, -0.0126,  0.1353,  0.0000, -0.3620, -0.4800,  0.0000,  0.0000,
-         0.0000, -0.1163,  0.0000,  0.0000,  0.0261, -0.0727,  0.0000, -0.8930,
-        -0.0237,  0.0000,  0.2581,  0.2141,  0.1954,  0.0000, -0.0586, -1.1883,
-         0.0000,  0.2042,  0.0000,  0.0000,  0.0000,  0.0000,  0.0890,  0.5920,
-         0.2193,  0.0000, -0.3534,  0.0000,  0.0000, -0.4563,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0513, -0.2304, -0.1293,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4056,  0.0000,  0.0000, -0.1357,  0.0000, -0.4101],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4163,  0.0000, -0.0056,  0.0526,  2.1543,  0.0103,  0.0000,  0.0000,
-        -0.0282, -0.0126,  0.1353,  0.0000, -0.3620, -0.4800,  0.0000,  0.0000,
-         0.0000, -0.1163,  0.0000,  0.0000,  0.0261, -0.0727,  0.0000, -0.8930,
-        -0.0237,  0.0000,  0.2581,  0.2141,  0.1954,  0.0000, -0.0586, -1.1883,
-         0.0000,  0.2042,  0.0000,  0.0000,  0.0000,  0.0000,  0.0890,  0.5920,
-         0.2193,  0.0000, -0.3534,  0.0000,  0.0000, -0.4563,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0513, -0.2304, -0.1293,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4056,  0.0000,  0.0000, -0.1357,  0.0000, -0.4101],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9457e-01, -1.5033e-09,  1.2561e-02,  3.3995e-02,  2.1531e+00,
-        -2.3256e-02,  9.0999e-19, -7.1854e-12, -1.0905e-02, -1.6835e-02,
-         1.9659e-01,  1.6359e-10, -3.4158e-01, -4.7913e-01, -1.6319e-15,
-        -1.2129e-12, -1.1930e-13, -1.4442e-01, -2.7594e-16,  3.3177e-12,
-        -2.4849e-03, -1.2110e-01, -5.8600e-16, -9.0865e-01, -1.7995e-02,
-         4.7560e-12,  2.6374e-01,  1.9614e-01,  1.4722e-01,  8.5225e-16,
-        -3.6158e-02, -1.1871e+00, -2.1816e-19,  1.8914e-01,  1.0828e-13,
-         5.8426e-12,  3.6121e-15,  0.0000e+00,  1.0360e-01,  6.1012e-01,
-         2.1354e-01, -4.8245e-13, -3.3471e-01,  8.8897e-16,  2.1743e-13,
-        -3.8156e-01, -3.0679e-11,  9.8993e-05,  1.0320e-10, -5.6593e-17,
-        -1.8311e-07,  7.4471e-02, -2.0557e-01, -1.7670e-01, -2.0739e-08,
-         1.1778e-20, -4.6290e-15,  1.6647e-18,  4.0100e-01, -2.4043e-10,
-         4.4469e-18, -9.1986e-02,  3.7363e-06, -4.1120e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3946,  0.0000,  0.0126,  0.0340,  2.1531, -0.0233,  0.0000,  0.0000,
-        -0.0109, -0.0168,  0.1966,  0.0000, -0.3416, -0.4791,  0.0000,  0.0000,
-         0.0000, -0.1444,  0.0000,  0.0000, -0.0025, -0.1211,  0.0000, -0.9087,
-        -0.0180,  0.0000,  0.2637,  0.1961,  0.1472,  0.0000, -0.0362, -1.1871,
-         0.0000,  0.1891,  0.0000,  0.0000,  0.0000,  0.0000,  0.1036,  0.6101,
-         0.2135,  0.0000, -0.3347,  0.0000,  0.0000, -0.3816,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0745, -0.2056, -0.1767,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4010,  0.0000,  0.0000, -0.0920,  0.0000, -0.4112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3946,  0.0000,  0.0126,  0.0340,  2.1531, -0.0233,  0.0000,  0.0000,
-        -0.0109, -0.0168,  0.1966,  0.0000, -0.3416, -0.4791,  0.0000,  0.0000,
-         0.0000, -0.1444,  0.0000,  0.0000, -0.0025, -0.1211,  0.0000, -0.9087,
-        -0.0180,  0.0000,  0.2637,  0.1961,  0.1472,  0.0000, -0.0362, -1.1871,
-         0.0000,  0.1891,  0.0000,  0.0000,  0.0000,  0.0000,  0.1036,  0.6101,
-         0.2135,  0.0000, -0.3347,  0.0000,  0.0000, -0.3816,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0745, -0.2056, -0.1767,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4010,  0.0000,  0.0000, -0.0920,  0.0000, -0.4112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7359e-01, -1.3650e-09,  3.3985e-02,  2.3786e-02,  2.1518e+00,
-        -6.9313e-02,  8.2628e-19, -6.5244e-12,  6.9137e-03, -2.1519e-02,
-         2.3598e-01,  1.4854e-10, -3.2082e-01, -4.8588e-01, -1.4818e-15,
-        -1.1013e-12, -1.0833e-13, -1.4051e-01, -2.5056e-16,  3.0125e-12,
-        -2.7577e-02, -1.4907e-01, -5.3209e-16, -9.1802e-01, -2.5576e-02,
-         4.3185e-12,  2.7453e-01,  1.8112e-01,  1.0130e-01,  7.7385e-16,
-        -1.4312e-02, -1.1850e+00, -1.9809e-19,  1.7486e-01,  9.8323e-14,
-         5.3051e-12,  3.2798e-15,  0.0000e+00,  1.3591e-01,  6.3475e-01,
-         1.8978e-01, -4.3807e-13, -3.2194e-01,  8.0719e-16,  1.9743e-13,
-        -3.2082e-01, -2.7857e-11,  8.9887e-05,  9.3711e-11, -5.1388e-17,
-        -1.6627e-07,  8.6494e-02, -1.7494e-01, -1.6283e-01, -1.8831e-08,
-         1.0695e-20, -4.2032e-15,  1.5116e-18,  3.9472e-01, -2.1831e-10,
-         4.0379e-18, -3.9465e-02,  3.3926e-06, -4.0800e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3736,  0.0000,  0.0340,  0.0238,  2.1518, -0.0693,  0.0000,  0.0000,
-         0.0069, -0.0215,  0.2360,  0.0000, -0.3208, -0.4859,  0.0000,  0.0000,
-         0.0000, -0.1405,  0.0000,  0.0000, -0.0276, -0.1491,  0.0000, -0.9180,
-        -0.0256,  0.0000,  0.2745,  0.1811,  0.1013,  0.0000, -0.0143, -1.1850,
-         0.0000,  0.1749,  0.0000,  0.0000,  0.0000,  0.0000,  0.1359,  0.6347,
-         0.1898,  0.0000, -0.3219,  0.0000,  0.0000, -0.3208,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0865, -0.1749, -0.1628,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3947,  0.0000,  0.0000, -0.0395,  0.0000, -0.4080],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3736,  0.0000,  0.0340,  0.0238,  2.1518, -0.0693,  0.0000,  0.0000,
-         0.0069, -0.0215,  0.2360,  0.0000, -0.3208, -0.4859,  0.0000,  0.0000,
-         0.0000, -0.1405,  0.0000,  0.0000, -0.0276, -0.1491,  0.0000, -0.9180,
-        -0.0256,  0.0000,  0.2745,  0.1811,  0.1013,  0.0000, -0.0143, -1.1850,
-         0.0000,  0.1749,  0.0000,  0.0000,  0.0000,  0.0000,  0.1359,  0.6347,
-         0.1898,  0.0000, -0.3219,  0.0000,  0.0000, -0.3208,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0865, -0.1749, -0.1628,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3947,  0.0000,  0.0000, -0.0395,  0.0000, -0.4080],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4993e-01, -1.2398e-09,  3.9284e-02,  1.5900e-02,  2.1497e+00,
-        -1.1369e-01,  7.5050e-19, -5.9261e-12,  3.3193e-02, -2.6449e-02,
-         2.6991e-01,  1.3492e-10, -3.1109e-01, -4.8825e-01, -1.3459e-15,
-        -1.0003e-12, -9.8394e-14, -1.3512e-01, -2.2758e-16,  2.7363e-12,
-        -4.3216e-02, -1.5421e-01, -4.8329e-16, -9.1541e-01, -2.5324e-02,
-         3.9225e-12,  2.7860e-01,  1.6416e-01,  5.7777e-02,  7.0288e-16,
-        -9.3030e-03, -1.1830e+00, -1.7992e-19,  1.6200e-01,  8.9305e-14,
-         4.8186e-12,  2.9790e-15,  0.0000e+00,  1.3628e-01,  6.4147e-01,
-         1.6949e-01, -3.9790e-13, -3.2301e-01,  7.3316e-16,  1.7932e-13,
-        -2.6976e-01, -2.5302e-11,  8.1643e-05,  8.5116e-11, -4.6675e-17,
-        -1.5102e-07,  9.0189e-02, -1.4838e-01, -1.6012e-01, -1.7104e-08,
-         9.7138e-21, -3.8177e-15,  1.3730e-18,  3.9614e-01, -1.9829e-10,
-         3.6676e-18,  3.5202e-03,  3.0814e-06, -4.0020e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3499,  0.0000,  0.0393,  0.0159,  2.1497, -0.1137,  0.0000,  0.0000,
-         0.0332, -0.0264,  0.2699,  0.0000, -0.3111, -0.4882,  0.0000,  0.0000,
-         0.0000, -0.1351,  0.0000,  0.0000, -0.0432, -0.1542,  0.0000, -0.9154,
-        -0.0253,  0.0000,  0.2786,  0.1642,  0.0578,  0.0000, -0.0093, -1.1830,
-         0.0000,  0.1620,  0.0000,  0.0000,  0.0000,  0.0000,  0.1363,  0.6415,
-         0.1695,  0.0000, -0.3230,  0.0000,  0.0000, -0.2698,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0902, -0.1484, -0.1601,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3961,  0.0000,  0.0000,  0.0035,  0.0000, -0.4002],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3499,  0.0000,  0.0393,  0.0159,  2.1497, -0.1137,  0.0000,  0.0000,
-         0.0332, -0.0264,  0.2699,  0.0000, -0.3111, -0.4882,  0.0000,  0.0000,
-         0.0000, -0.1351,  0.0000,  0.0000, -0.0432, -0.1542,  0.0000, -0.9154,
-        -0.0253,  0.0000,  0.2786,  0.1642,  0.0578,  0.0000, -0.0093, -1.1830,
-         0.0000,  0.1620,  0.0000,  0.0000,  0.0000,  0.0000,  0.1363,  0.6415,
-         0.1695,  0.0000, -0.3230,  0.0000,  0.0000, -0.2698,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0902, -0.1484, -0.1601,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3961,  0.0000,  0.0000,  0.0035,  0.0000, -0.4002],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2802e-01, -1.1265e-09,  4.1982e-02, -1.0714e-02,  2.1484e+00,
-        -1.6262e-01,  6.8188e-19, -5.3842e-12,  6.7824e-02, -3.8583e-02,
-         2.9242e-01,  1.2258e-10, -3.0277e-01, -4.9186e-01, -1.2228e-15,
-        -9.0883e-13, -8.9397e-14, -1.3027e-01, -2.0677e-16,  2.4861e-12,
-        -4.7100e-02, -1.5671e-01, -4.3910e-16, -9.1194e-01, -3.2335e-02,
-         3.5638e-12,  2.7844e-01,  1.4389e-01,  2.7577e-02,  6.3861e-16,
-        -1.6871e-02, -1.1807e+00, -1.6347e-19,  1.6246e-01,  8.1139e-14,
-         4.3780e-12,  2.7066e-15,  0.0000e+00,  1.3042e-01,  6.4096e-01,
-         1.5434e-01, -3.6151e-13, -3.2892e-01,  6.6612e-16,  1.6293e-13,
-        -2.4004e-01, -2.2989e-11,  7.4178e-05,  7.7334e-11, -4.2407e-17,
-        -1.3721e-07,  1.0037e-01, -1.0142e-01, -1.6728e-01, -1.5540e-08,
-         8.8256e-21, -3.4686e-15,  1.2474e-18,  4.0574e-01, -1.8016e-10,
-         3.3322e-18,  3.5225e-02,  2.7997e-06, -3.8769e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3280,  0.0000,  0.0420, -0.0107,  2.1484, -0.1626,  0.0000,  0.0000,
-         0.0678, -0.0386,  0.2924,  0.0000, -0.3028, -0.4919,  0.0000,  0.0000,
-         0.0000, -0.1303,  0.0000,  0.0000, -0.0471, -0.1567,  0.0000, -0.9119,
-        -0.0323,  0.0000,  0.2784,  0.1439,  0.0276,  0.0000, -0.0169, -1.1807,
-         0.0000,  0.1625,  0.0000,  0.0000,  0.0000,  0.0000,  0.1304,  0.6410,
-         0.1543,  0.0000, -0.3289,  0.0000,  0.0000, -0.2400,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1004, -0.1014, -0.1673,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4057,  0.0000,  0.0000,  0.0352,  0.0000, -0.3877],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3280,  0.0000,  0.0420, -0.0107,  2.1484, -0.1626,  0.0000,  0.0000,
-         0.0678, -0.0386,  0.2924,  0.0000, -0.3028, -0.4919,  0.0000,  0.0000,
-         0.0000, -0.1303,  0.0000,  0.0000, -0.0471, -0.1567,  0.0000, -0.9119,
-        -0.0323,  0.0000,  0.2784,  0.1439,  0.0276,  0.0000, -0.0169, -1.1807,
-         0.0000,  0.1625,  0.0000,  0.0000,  0.0000,  0.0000,  0.1304,  0.6410,
-         0.1543,  0.0000, -0.3289,  0.0000,  0.0000, -0.2400,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1004, -0.1014, -0.1673,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4057,  0.0000,  0.0000,  0.0352,  0.0000, -0.3877],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9778e-01, -1.0238e-09,  3.5561e-02, -4.6637e-02,  2.1467e+00,
-        -2.2620e-01,  6.1971e-19, -4.8933e-12,  8.7436e-02, -4.8246e-02,
-         3.1137e-01,  1.1141e-10, -3.0412e-01, -4.9416e-01, -1.1113e-15,
-        -8.2597e-13, -8.1247e-14, -1.0412e-01, -1.8792e-16,  2.2594e-12,
-        -3.9505e-02, -1.3160e-01, -3.9907e-16, -8.9415e-01, -2.0927e-02,
-         3.2389e-12,  2.6271e-01,  1.2991e-01,  4.2825e-03,  5.8039e-16,
-        -3.5290e-02, -1.1795e+00, -1.4857e-19,  1.5355e-01,  7.3742e-14,
-         3.9788e-12,  2.4599e-15,  0.0000e+00,  1.0704e-01,  6.4569e-01,
-         1.6329e-01, -3.2856e-13, -3.4696e-01,  6.0540e-16,  1.4807e-13,
-        -2.3996e-01, -2.0893e-11,  6.7415e-05,  7.0283e-11, -3.8541e-17,
-        -1.2470e-07,  1.2206e-01, -6.7998e-02, -1.3638e-01, -1.4123e-08,
-         8.0210e-21, -3.1524e-15,  1.1337e-18,  4.0806e-01, -1.6373e-10,
-         3.0284e-18,  5.0050e-02,  2.5444e-06, -3.8032e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2978,  0.0000,  0.0356, -0.0466,  2.1467, -0.2262,  0.0000,  0.0000,
-         0.0874, -0.0482,  0.3114,  0.0000, -0.3041, -0.4942,  0.0000,  0.0000,
-         0.0000, -0.1041,  0.0000,  0.0000, -0.0395, -0.1316,  0.0000, -0.8941,
-        -0.0209,  0.0000,  0.2627,  0.1299,  0.0043,  0.0000, -0.0353, -1.1795,
-         0.0000,  0.1535,  0.0000,  0.0000,  0.0000,  0.0000,  0.1070,  0.6457,
-         0.1633,  0.0000, -0.3470,  0.0000,  0.0000, -0.2400,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1221, -0.0680, -0.1364,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4081,  0.0000,  0.0000,  0.0501,  0.0000, -0.3803],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2978,  0.0000,  0.0356, -0.0466,  2.1467, -0.2262,  0.0000,  0.0000,
-         0.0874, -0.0482,  0.3114,  0.0000, -0.3041, -0.4942,  0.0000,  0.0000,
-         0.0000, -0.1041,  0.0000,  0.0000, -0.0395, -0.1316,  0.0000, -0.8941,
-        -0.0209,  0.0000,  0.2627,  0.1299,  0.0043,  0.0000, -0.0353, -1.1795,
-         0.0000,  0.1535,  0.0000,  0.0000,  0.0000,  0.0000,  0.1070,  0.6457,
-         0.1633,  0.0000, -0.3470,  0.0000,  0.0000, -0.2400,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1221, -0.0680, -0.1364,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4081,  0.0000,  0.0000,  0.0501,  0.0000, -0.3803],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6461e-01, -9.3072e-10,  3.4359e-02, -5.7816e-02,  2.1451e+00,
-        -2.8821e-01,  5.6338e-19, -4.4485e-12,  1.0462e-01, -5.7450e-02,
-         3.2412e-01,  1.0128e-10, -3.0550e-01, -4.9935e-01, -1.0103e-15,
-        -7.5089e-13, -7.3862e-14, -5.2230e-02, -1.7084e-16,  2.0540e-12,
-        -3.4066e-02, -1.0074e-01, -3.6279e-16, -8.8205e-01, -1.1467e-02,
-         2.9445e-12,  2.5590e-01,  1.0575e-01, -2.7657e-02,  5.2763e-16,
-        -7.1769e-02, -1.1788e+00, -1.3506e-19,  1.3755e-01,  6.7039e-14,
-         3.6172e-12,  2.2363e-15,  0.0000e+00,  9.8881e-02,  6.5189e-01,
-         1.6798e-01, -2.9869e-13, -3.6732e-01,  5.5036e-16,  1.3461e-13,
-        -2.4572e-01, -1.8994e-11,  6.1287e-05,  6.3894e-11, -3.5037e-17,
-        -1.1336e-07,  1.2151e-01, -3.0953e-02, -5.6835e-02, -1.2840e-08,
-         7.2919e-21, -2.8658e-15,  1.0307e-18,  4.1274e-01, -1.4885e-10,
-         2.7531e-18,  5.8276e-02,  2.3132e-06, -3.7071e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2646,  0.0000,  0.0344, -0.0578,  2.1451, -0.2882,  0.0000,  0.0000,
-         0.1046, -0.0574,  0.3241,  0.0000, -0.3055, -0.4993,  0.0000,  0.0000,
-         0.0000, -0.0522,  0.0000,  0.0000, -0.0341, -0.1007,  0.0000, -0.8821,
-        -0.0115,  0.0000,  0.2559,  0.1058, -0.0277,  0.0000, -0.0718, -1.1788,
-         0.0000,  0.1375,  0.0000,  0.0000,  0.0000,  0.0000,  0.0989,  0.6519,
-         0.1680,  0.0000, -0.3673,  0.0000,  0.0000, -0.2457,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1215, -0.0310, -0.0568,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4127,  0.0000,  0.0000,  0.0583,  0.0000, -0.3707],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2646,  0.0000,  0.0344, -0.0578,  2.1451, -0.2882,  0.0000,  0.0000,
-         0.1046, -0.0574,  0.3241,  0.0000, -0.3055, -0.4993,  0.0000,  0.0000,
-         0.0000, -0.0522,  0.0000,  0.0000, -0.0341, -0.1007,  0.0000, -0.8821,
-        -0.0115,  0.0000,  0.2559,  0.1058, -0.0277,  0.0000, -0.0718, -1.1788,
-         0.0000,  0.1375,  0.0000,  0.0000,  0.0000,  0.0000,  0.0989,  0.6519,
-         0.1680,  0.0000, -0.3673,  0.0000,  0.0000, -0.2457,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1215, -0.0310, -0.0568,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4127,  0.0000,  0.0000,  0.0583,  0.0000, -0.3707],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3627e-01, -8.4636e-10,  3.9656e-02, -4.8489e-02,  2.1446e+00,
-        -3.2983e-01,  5.1232e-19, -4.0453e-12,  1.0507e-01, -6.4026e-02,
-         3.1913e-01,  9.2100e-11, -2.9565e-01, -5.0587e-01, -9.1873e-16,
-        -6.8283e-13, -6.7167e-14, -1.4495e-02, -1.5535e-16,  1.8679e-12,
-        -2.6238e-02, -5.9370e-02, -3.2991e-16, -8.6904e-01,  2.5548e-02,
-         2.6776e-12,  2.6142e-01,  9.6673e-02, -4.9599e-02,  4.7981e-16,
-        -9.6917e-02, -1.1805e+00, -1.2282e-19,  1.1171e-01,  6.0963e-14,
-         3.2893e-12,  2.0336e-15,  0.0000e+00,  9.2156e-02,  6.5867e-01,
-         1.6305e-01, -2.7162e-13, -3.8061e-01,  5.0048e-16,  1.2241e-13,
-        -2.6370e-01, -1.7272e-11,  5.5732e-05,  5.8103e-11, -3.1862e-17,
-        -1.0309e-07,  1.2435e-01, -3.1969e-02,  5.8059e-02, -1.1676e-08,
-         6.6310e-21, -2.6061e-15,  9.3724e-19,  4.1830e-01, -1.3536e-10,
-         2.5036e-18,  6.9783e-02,  2.1035e-06, -3.7293e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2363,  0.0000,  0.0397, -0.0485,  2.1446, -0.3298,  0.0000,  0.0000,
-         0.1051, -0.0640,  0.3191,  0.0000, -0.2957, -0.5059,  0.0000,  0.0000,
-         0.0000, -0.0145,  0.0000,  0.0000, -0.0262, -0.0594,  0.0000, -0.8690,
-         0.0255,  0.0000,  0.2614,  0.0967, -0.0496,  0.0000, -0.0969, -1.1805,
-         0.0000,  0.1117,  0.0000,  0.0000,  0.0000,  0.0000,  0.0922,  0.6587,
-         0.1631,  0.0000, -0.3806,  0.0000,  0.0000, -0.2637,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1244, -0.0320,  0.0581,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4183,  0.0000,  0.0000,  0.0698,  0.0000, -0.3729],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2363,  0.0000,  0.0397, -0.0485,  2.1446, -0.3298,  0.0000,  0.0000,
-         0.1051, -0.0640,  0.3191,  0.0000, -0.2957, -0.5059,  0.0000,  0.0000,
-         0.0000, -0.0145,  0.0000,  0.0000, -0.0262, -0.0594,  0.0000, -0.8690,
-         0.0255,  0.0000,  0.2614,  0.0967, -0.0496,  0.0000, -0.0969, -1.1805,
-         0.0000,  0.1117,  0.0000,  0.0000,  0.0000,  0.0000,  0.0922,  0.6587,
-         0.1631,  0.0000, -0.3806,  0.0000,  0.0000, -0.2637,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1244, -0.0320,  0.0581,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4183,  0.0000,  0.0000,  0.0698,  0.0000, -0.3729],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1179e-01, -7.6987e-10,  4.2731e-02, -2.6399e-02,  2.1445e+00,
-        -3.6167e-01,  4.6602e-19, -3.6797e-12,  1.0061e-01, -4.0530e-02,
-         2.9783e-01,  8.3777e-11, -2.8906e-01, -4.9924e-01, -8.3570e-16,
-        -6.2112e-13, -6.1097e-14, -2.6944e-02, -1.4131e-16,  1.6990e-12,
-        -3.4883e-02, -3.6277e-03, -3.0010e-16, -8.6364e-01,  9.1054e-02,
-         2.4356e-12,  2.6050e-01,  7.6615e-02, -7.4913e-02,  4.3645e-16,
-        -1.2978e-01, -1.1846e+00, -1.1172e-19,  1.0069e-01,  5.5453e-14,
-         2.9920e-12,  1.8498e-15,  0.0000e+00,  8.1028e-02,  6.7074e-01,
-         1.6098e-01, -2.4707e-13, -3.8723e-01,  4.5525e-16,  1.1135e-13,
-        -2.8661e-01, -1.5711e-11,  5.0696e-05,  5.2852e-11, -2.8982e-17,
-        -9.3773e-08,  1.4992e-01, -5.8000e-02,  1.5812e-01, -1.0621e-08,
-         6.0317e-21, -2.3706e-15,  8.5254e-19,  4.1808e-01, -1.2312e-10,
-         2.2773e-18,  6.0177e-02,  1.9134e-06, -3.7393e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2118,  0.0000,  0.0427, -0.0264,  2.1445, -0.3617,  0.0000,  0.0000,
-         0.1006, -0.0405,  0.2978,  0.0000, -0.2891, -0.4992,  0.0000,  0.0000,
-         0.0000, -0.0269,  0.0000,  0.0000, -0.0349, -0.0036,  0.0000, -0.8636,
-         0.0911,  0.0000,  0.2605,  0.0766, -0.0749,  0.0000, -0.1298, -1.1846,
-         0.0000,  0.1007,  0.0000,  0.0000,  0.0000,  0.0000,  0.0810,  0.6707,
-         0.1610,  0.0000, -0.3872,  0.0000,  0.0000, -0.2866,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1499, -0.0580,  0.1581,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4181,  0.0000,  0.0000,  0.0602,  0.0000, -0.3739],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2118,  0.0000,  0.0427, -0.0264,  2.1445, -0.3617,  0.0000,  0.0000,
-         0.1006, -0.0405,  0.2978,  0.0000, -0.2891, -0.4992,  0.0000,  0.0000,
-         0.0000, -0.0269,  0.0000,  0.0000, -0.0349, -0.0036,  0.0000, -0.8636,
-         0.0911,  0.0000,  0.2605,  0.0766, -0.0749,  0.0000, -0.1298, -1.1846,
-         0.0000,  0.1007,  0.0000,  0.0000,  0.0000,  0.0000,  0.0810,  0.6707,
-         0.1610,  0.0000, -0.3872,  0.0000,  0.0000, -0.2866,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1499, -0.0580,  0.1581,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4181,  0.0000,  0.0000,  0.0602,  0.0000, -0.3739],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.9244e-01, -7.0049e-10,  5.1404e-02, -7.5829e-03,  2.1454e+00,
-        -3.8066e-01,  4.2402e-19, -3.3481e-12,  8.0606e-02, -7.8519e-03,
-         2.7277e-01,  7.6227e-11, -2.9122e-01, -4.8479e-01, -7.6039e-16,
-        -5.6515e-13, -5.5591e-14, -8.1403e-02, -1.2858e-16,  1.5459e-12,
-        -4.2187e-02,  5.4476e-02, -2.7305e-16, -8.6340e-01,  1.5457e-01,
-         2.2161e-12,  2.6065e-01,  5.7025e-02, -9.0545e-02,  3.9712e-16,
-        -1.6179e-01, -1.1868e+00, -1.0165e-19,  9.3628e-02,  5.0456e-14,
-         2.7224e-12,  1.6831e-15,  0.0000e+00,  9.3812e-02,  6.8901e-01,
-         1.5859e-01, -2.2481e-13, -3.8839e-01,  4.1422e-16,  1.0132e-13,
-        -3.0363e-01, -1.4295e-11,  4.6127e-05,  4.8089e-11, -2.6370e-17,
-        -8.5322e-08,  1.7372e-01, -8.5100e-02,  2.1206e-01, -9.6635e-09,
-         5.4881e-21, -2.1569e-15,  7.7571e-19,  4.0676e-01, -1.1203e-10,
-         2.0721e-18,  2.9324e-02,  1.7410e-06, -3.7920e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1924,  0.0000,  0.0514, -0.0076,  2.1454, -0.3807,  0.0000,  0.0000,
-         0.0806, -0.0079,  0.2728,  0.0000, -0.2912, -0.4848,  0.0000,  0.0000,
-         0.0000, -0.0814,  0.0000,  0.0000, -0.0422,  0.0545,  0.0000, -0.8634,
-         0.1546,  0.0000,  0.2606,  0.0570, -0.0905,  0.0000, -0.1618, -1.1868,
-         0.0000,  0.0936,  0.0000,  0.0000,  0.0000,  0.0000,  0.0938,  0.6890,
-         0.1586,  0.0000, -0.3884,  0.0000,  0.0000, -0.3036,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1737, -0.0851,  0.2121,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4068,  0.0000,  0.0000,  0.0293,  0.0000, -0.3792],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1924,  0.0000,  0.0514, -0.0076,  2.1454, -0.3807,  0.0000,  0.0000,
-         0.0806, -0.0079,  0.2728,  0.0000, -0.2912, -0.4848,  0.0000,  0.0000,
-         0.0000, -0.0814,  0.0000,  0.0000, -0.0422,  0.0545,  0.0000, -0.8634,
-         0.1546,  0.0000,  0.2606,  0.0570, -0.0905,  0.0000, -0.1618, -1.1868,
-         0.0000,  0.0936,  0.0000,  0.0000,  0.0000,  0.0000,  0.0938,  0.6890,
-         0.1586,  0.0000, -0.3884,  0.0000,  0.0000, -0.3036,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1737, -0.0851,  0.2121,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4068,  0.0000,  0.0000,  0.0293,  0.0000, -0.3792],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6704e-01, -6.3754e-10,  6.5814e-02,  6.8782e-03,  2.1463e+00,
-        -3.8971e-01,  3.8592e-19, -3.0472e-12,  4.6401e-02,  4.5575e-02,
-         2.5054e-01,  6.9377e-11, -2.9432e-01, -4.6655e-01, -6.9206e-16,
-        -5.1436e-13, -5.0595e-14, -1.5744e-01, -1.1702e-16,  1.4070e-12,
-        -4.7998e-02,  8.9611e-02, -2.4851e-16, -8.6931e-01,  1.9614e-01,
-         2.0170e-12,  2.6614e-01,  3.9340e-02, -1.0889e-01,  3.6143e-16,
-        -1.9513e-01, -1.1891e+00, -9.2519e-20,  9.8201e-02,  4.5922e-14,
-         2.4778e-12,  1.5318e-15,  0.0000e+00,  1.2043e-01,  7.1315e-01,
-         1.5618e-01, -2.0460e-13, -3.9147e-01,  3.7700e-16,  9.2210e-14,
-        -3.0804e-01, -1.3011e-11,  4.1982e-05,  4.3768e-11, -2.4001e-17,
-        -7.7655e-08,  1.9388e-01, -9.1519e-02,  2.3886e-01, -8.7951e-09,
-         4.9949e-21, -1.9631e-15,  7.0600e-19,  3.9639e-01, -1.0196e-10,
-         1.8859e-18, -2.8760e-03,  1.5845e-06, -3.7148e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1670,  0.0000,  0.0658,  0.0069,  2.1463, -0.3897,  0.0000,  0.0000,
-         0.0464,  0.0456,  0.2505,  0.0000, -0.2943, -0.4665,  0.0000,  0.0000,
-         0.0000, -0.1574,  0.0000,  0.0000, -0.0480,  0.0896,  0.0000, -0.8693,
-         0.1961,  0.0000,  0.2661,  0.0393, -0.1089,  0.0000, -0.1951, -1.1891,
-         0.0000,  0.0982,  0.0000,  0.0000,  0.0000,  0.0000,  0.1204,  0.7132,
-         0.1562,  0.0000, -0.3915,  0.0000,  0.0000, -0.3080,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1939, -0.0915,  0.2389,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3964,  0.0000,  0.0000, -0.0029,  0.0000, -0.3715],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1670,  0.0000,  0.0658,  0.0069,  2.1463, -0.3897,  0.0000,  0.0000,
-         0.0464,  0.0456,  0.2505,  0.0000, -0.2943, -0.4665,  0.0000,  0.0000,
-         0.0000, -0.1574,  0.0000,  0.0000, -0.0480,  0.0896,  0.0000, -0.8693,
-         0.1961,  0.0000,  0.2661,  0.0393, -0.1089,  0.0000, -0.1951, -1.1891,
-         0.0000,  0.0982,  0.0000,  0.0000,  0.0000,  0.0000,  0.1204,  0.7132,
-         0.1562,  0.0000, -0.3915,  0.0000,  0.0000, -0.3080,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1939, -0.0915,  0.2389,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3964,  0.0000,  0.0000, -0.0029,  0.0000, -0.3715],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4078e-01, -5.8041e-10,  8.7758e-02,  1.2731e-02,  2.1467e+00,
-        -3.8858e-01,  3.5133e-19, -2.7742e-12,  8.9883e-03,  1.0028e-01,
-         2.1505e-01,  6.3160e-11, -2.9750e-01, -4.4950e-01, -6.3004e-16,
-        -4.6827e-13, -4.6061e-14, -2.3712e-01, -1.0654e-16,  1.2809e-12,
-        -5.5962e-02,  1.0640e-01, -2.2624e-16, -8.6686e-01,  2.2099e-01,
-         1.8362e-12,  2.8258e-01,  2.6917e-02, -1.1922e-01,  3.2904e-16,
-        -2.1498e-01, -1.1913e+00, -8.4228e-20,  1.0322e-01,  4.1807e-14,
-         2.2557e-12,  1.3946e-15,  0.0000e+00,  1.4477e-01,  7.3679e-01,
-         1.3892e-01, -1.8627e-13, -3.9335e-01,  3.4322e-16,  8.3947e-14,
-        -3.0817e-01, -1.1845e-11,  3.8220e-05,  3.9846e-11, -2.1850e-17,
-        -7.0696e-08,  1.8989e-01, -9.0739e-02,  2.3116e-01, -8.0069e-09,
-         4.5473e-21, -1.7872e-15,  6.4274e-19,  3.8628e-01, -9.2825e-11,
-         1.7169e-18, -3.4931e-02,  1.4425e-06, -3.5624e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1408,  0.0000,  0.0878,  0.0127,  2.1467, -0.3886,  0.0000,  0.0000,
-         0.0090,  0.1003,  0.2150,  0.0000, -0.2975, -0.4495,  0.0000,  0.0000,
-         0.0000, -0.2371,  0.0000,  0.0000, -0.0560,  0.1064,  0.0000, -0.8669,
-         0.2210,  0.0000,  0.2826,  0.0269, -0.1192,  0.0000, -0.2150, -1.1913,
-         0.0000,  0.1032,  0.0000,  0.0000,  0.0000,  0.0000,  0.1448,  0.7368,
-         0.1389,  0.0000, -0.3934,  0.0000,  0.0000, -0.3082,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1899, -0.0907,  0.2312,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3863,  0.0000,  0.0000, -0.0349,  0.0000, -0.3562],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1408,  0.0000,  0.0878,  0.0127,  2.1467, -0.3886,  0.0000,  0.0000,
-         0.0090,  0.1003,  0.2150,  0.0000, -0.2975, -0.4495,  0.0000,  0.0000,
-         0.0000, -0.2371,  0.0000,  0.0000, -0.0560,  0.1064,  0.0000, -0.8669,
-         0.2210,  0.0000,  0.2826,  0.0269, -0.1192,  0.0000, -0.2150, -1.1913,
-         0.0000,  0.1032,  0.0000,  0.0000,  0.0000,  0.0000,  0.1448,  0.7368,
-         0.1389,  0.0000, -0.3934,  0.0000,  0.0000, -0.3082,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1899, -0.0907,  0.2312,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3863,  0.0000,  0.0000, -0.0349,  0.0000, -0.3562],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.7483e-02, -5.2854e-10,  1.0256e-01, -2.7314e-03,  2.1470e+00,
-        -3.8713e-01,  3.1994e-19, -2.5263e-12, -2.3947e-02,  1.3782e-01,
-         1.9597e-01,  5.7516e-11, -3.0850e-01, -4.3681e-01, -5.7374e-16,
-        -4.2642e-13, -4.1945e-14, -3.1641e-01, -9.7016e-17,  1.1665e-12,
-        -7.1537e-02,  9.1505e-02, -2.0603e-16, -8.6555e-01,  2.2780e-01,
-         1.6721e-12,  3.0204e-01,  3.0161e-03, -1.3962e-01,  2.9964e-16,
-        -2.3189e-01, -1.1920e+00, -7.6701e-20,  1.0858e-01,  3.8071e-14,
-         2.0541e-12,  1.2699e-15,  0.0000e+00,  1.5873e-01,  7.5553e-01,
-         9.9521e-02, -1.6962e-13, -3.9809e-01,  3.1254e-16,  7.6445e-14,
-        -2.9705e-01, -1.0786e-11,  3.4804e-05,  3.6285e-11, -1.9897e-17,
-        -6.4378e-08,  1.5534e-01, -7.6028e-02,  1.7240e-01, -7.2914e-09,
-         4.1410e-21, -1.6275e-15,  5.8530e-19,  3.8610e-01, -8.4529e-11,
-         1.5635e-18, -6.6524e-02,  1.3136e-06, -3.3394e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0975,  0.0000,  0.1026, -0.0027,  2.1470, -0.3871,  0.0000,  0.0000,
-        -0.0239,  0.1378,  0.1960,  0.0000, -0.3085, -0.4368,  0.0000,  0.0000,
-         0.0000, -0.3164,  0.0000,  0.0000, -0.0715,  0.0915,  0.0000, -0.8656,
-         0.2278,  0.0000,  0.3020,  0.0030, -0.1396,  0.0000, -0.2319, -1.1920,
-         0.0000,  0.1086,  0.0000,  0.0000,  0.0000,  0.0000,  0.1587,  0.7555,
-         0.0995,  0.0000, -0.3981,  0.0000,  0.0000, -0.2971,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1553, -0.0760,  0.1724,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3861,  0.0000,  0.0000, -0.0665,  0.0000, -0.3339],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0975,  0.0000,  0.1026, -0.0027,  2.1470, -0.3871,  0.0000,  0.0000,
-        -0.0239,  0.1378,  0.1960,  0.0000, -0.3085, -0.4368,  0.0000,  0.0000,
-         0.0000, -0.3164,  0.0000,  0.0000, -0.0715,  0.0915,  0.0000, -0.8656,
-         0.2278,  0.0000,  0.3020,  0.0030, -0.1396,  0.0000, -0.2319, -1.1920,
-         0.0000,  0.1086,  0.0000,  0.0000,  0.0000,  0.0000,  0.1587,  0.7555,
-         0.0995,  0.0000, -0.3981,  0.0000,  0.0000, -0.2971,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1553, -0.0760,  0.1724,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3861,  0.0000,  0.0000, -0.0665,  0.0000, -0.3339],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2685e-02, -4.8144e-10,  1.1574e-01, -9.1928e-03,  2.1490e+00,
-        -3.7819e-01,  2.9142e-19, -2.3011e-12, -5.4455e-02,  1.6112e-01,
-         1.5590e-01,  5.2390e-11, -3.2171e-01, -4.2674e-01, -5.2260e-16,
-        -3.8842e-13, -3.8207e-14, -3.7707e-01, -8.8370e-17,  1.0625e-12,
-        -6.9257e-02,  6.7403e-02, -1.8766e-16, -8.6868e-01,  2.1879e-01,
-         1.5231e-12,  3.2590e-01, -9.2153e-03, -1.5168e-01,  2.7293e-16,
-        -2.4828e-01, -1.1916e+00, -6.9865e-20,  1.2722e-01,  3.4678e-14,
-         1.8711e-12,  1.1568e-15,  0.0000e+00,  1.5851e-01,  7.6587e-01,
-         6.4815e-02, -1.5451e-13, -4.0097e-01,  2.8469e-16,  6.9632e-14,
-        -2.9165e-01, -9.8249e-12,  3.1702e-05,  3.3051e-11, -1.8124e-17,
-        -5.8641e-08,  1.1995e-01, -6.7949e-02,  1.1452e-01, -6.6416e-09,
-         3.7719e-21, -1.4824e-15,  5.3313e-19,  3.9978e-01, -7.6996e-11,
-         1.4241e-18, -9.2934e-02,  1.1965e-06, -3.0418e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0527,  0.0000,  0.1157, -0.0092,  2.1490, -0.3782,  0.0000,  0.0000,
-        -0.0545,  0.1611,  0.1559,  0.0000, -0.3217, -0.4267,  0.0000,  0.0000,
-         0.0000, -0.3771,  0.0000,  0.0000, -0.0693,  0.0674,  0.0000, -0.8687,
-         0.2188,  0.0000,  0.3259, -0.0092, -0.1517,  0.0000, -0.2483, -1.1916,
-         0.0000,  0.1272,  0.0000,  0.0000,  0.0000,  0.0000,  0.1585,  0.7659,
-         0.0648,  0.0000, -0.4010,  0.0000,  0.0000, -0.2917,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1200, -0.0679,  0.1145,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3998,  0.0000,  0.0000, -0.0929,  0.0000, -0.3042],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0527,  0.0000,  0.1157, -0.0092,  2.1490, -0.3782,  0.0000,  0.0000,
-        -0.0545,  0.1611,  0.1559,  0.0000, -0.3217, -0.4267,  0.0000,  0.0000,
-         0.0000, -0.3771,  0.0000,  0.0000, -0.0693,  0.0674,  0.0000, -0.8687,
-         0.2188,  0.0000,  0.3259, -0.0092, -0.1517,  0.0000, -0.2483, -1.1916,
-         0.0000,  0.1272,  0.0000,  0.0000,  0.0000,  0.0000,  0.1585,  0.7659,
-         0.0648,  0.0000, -0.4010,  0.0000,  0.0000, -0.2917,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1200, -0.0679,  0.1145,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3998,  0.0000,  0.0000, -0.0929,  0.0000, -0.3042],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.4764e-03, -4.3865e-10,  1.3154e-01, -5.4274e-03,  2.1502e+00,
-        -3.6951e-01,  2.6552e-19, -2.0966e-12, -9.1067e-02,  1.6959e-01,
-         1.0208e-01,  4.7733e-11, -3.3124e-01, -4.1762e-01, -4.7615e-16,
-        -3.5389e-13, -3.4811e-14, -4.2202e-01, -8.0515e-17,  9.6806e-13,
-        -6.3539e-02,  2.4248e-02, -1.7098e-16, -8.7085e-01,  1.9370e-01,
-         1.3877e-12,  3.5402e-01, -2.4667e-02, -1.5775e-01,  2.4867e-16,
-        -2.5737e-01, -1.1900e+00, -6.3655e-20,  1.3929e-01,  3.1595e-14,
-         1.7048e-12,  1.0540e-15,  0.0000e+00,  1.5117e-01,  7.6563e-01,
-         3.0670e-02, -1.4077e-13, -3.9329e-01,  2.5939e-16,  6.3443e-14,
-        -2.8141e-01, -8.9517e-12,  2.8885e-05,  3.0113e-11, -1.6513e-17,
-        -5.3429e-08,  8.7375e-02, -5.2453e-02,  6.2866e-02, -6.0513e-09,
-         3.4367e-21, -1.3507e-15,  4.8575e-19,  4.1630e-01, -7.0152e-11,
-         1.2975e-18, -1.0201e-01,  1.0902e-06, -2.7193e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0095,  0.0000,  0.1315, -0.0054,  2.1502, -0.3695,  0.0000,  0.0000,
-        -0.0911,  0.1696,  0.1021,  0.0000, -0.3312, -0.4176,  0.0000,  0.0000,
-         0.0000, -0.4220,  0.0000,  0.0000, -0.0635,  0.0242,  0.0000, -0.8709,
-         0.1937,  0.0000,  0.3540, -0.0247, -0.1577,  0.0000, -0.2574, -1.1900,
-         0.0000,  0.1393,  0.0000,  0.0000,  0.0000,  0.0000,  0.1512,  0.7656,
-         0.0307,  0.0000, -0.3933,  0.0000,  0.0000, -0.2814,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0874, -0.0525,  0.0629,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4163,  0.0000,  0.0000, -0.1020,  0.0000, -0.2719],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0095,  0.0000,  0.1315, -0.0054,  2.1502, -0.3695,  0.0000,  0.0000,
-        -0.0911,  0.1696,  0.1021,  0.0000, -0.3312, -0.4176,  0.0000,  0.0000,
-         0.0000, -0.4220,  0.0000,  0.0000, -0.0635,  0.0242,  0.0000, -0.8709,
-         0.1937,  0.0000,  0.3540, -0.0247, -0.1577,  0.0000, -0.2574, -1.1900,
-         0.0000,  0.1393,  0.0000,  0.0000,  0.0000,  0.0000,  0.1512,  0.7656,
-         0.0307,  0.0000, -0.3933,  0.0000,  0.0000, -0.2814,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0874, -0.0525,  0.0629,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4163,  0.0000,  0.0000, -0.1020,  0.0000, -0.2719],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.4154e-02, -3.9976e-10,  1.5097e-01,  3.7738e-03,  2.1515e+00,
-        -3.5779e-01,  2.4198e-19, -1.9107e-12, -1.1135e-01,  1.7970e-01,
-         6.1167e-02,  4.3502e-11, -3.3361e-01, -4.1352e-01, -4.3395e-16,
-        -3.2252e-13, -3.1725e-14, -4.4052e-01, -7.3378e-17,  8.8225e-13,
-        -5.3080e-02, -1.0377e-02, -1.5583e-16, -8.7003e-01,  1.6351e-01,
-         1.2647e-12,  3.8086e-01, -3.5121e-02, -1.7059e-01,  2.2663e-16,
-        -2.6254e-01, -1.1882e+00, -5.8013e-20,  1.5469e-01,  2.8795e-14,
-         1.5537e-12,  9.6052e-16,  0.0000e+00,  1.4929e-01,  7.6028e-01,
-         2.0515e-03, -1.2829e-13, -3.8470e-01,  2.3639e-16,  5.7819e-14,
-        -2.7684e-01, -8.1582e-12,  2.6324e-05,  2.7444e-11, -1.5049e-17,
-        -4.8693e-08,  5.1521e-02, -2.5694e-02,  1.6081e-02, -5.5148e-09,
-         3.1320e-21, -1.2309e-15,  4.4269e-19,  4.3138e-01, -6.3934e-11,
-         1.1825e-18, -1.0056e-01,  9.9355e-07, -2.3883e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-2.4154e-02,  0.0000e+00,  1.5097e-01,  3.7738e-03,  2.1515e+00,
-        -3.5779e-01,  0.0000e+00,  0.0000e+00, -1.1135e-01,  1.7970e-01,
-         6.1167e-02,  0.0000e+00, -3.3361e-01, -4.1352e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.4052e-01,  0.0000e+00,  0.0000e+00,
-        -5.3080e-02, -1.0377e-02,  0.0000e+00, -8.7003e-01,  1.6351e-01,
-         0.0000e+00,  3.8086e-01, -3.5121e-02, -1.7059e-01,  0.0000e+00,
-        -2.6254e-01, -1.1882e+00,  0.0000e+00,  1.5469e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.4929e-01,  7.6028e-01,
-         2.0515e-03,  0.0000e+00, -3.8470e-01,  0.0000e+00,  0.0000e+00,
-        -2.7684e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.1521e-02, -2.5694e-02,  1.6081e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3138e-01,  0.0000e+00,
-         0.0000e+00, -1.0056e-01,  0.0000e+00, -2.3883e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([-2.4154e-02,  0.0000e+00,  1.5097e-01,  3.7738e-03,  2.1515e+00,
-        -3.5779e-01,  0.0000e+00,  0.0000e+00, -1.1135e-01,  1.7970e-01,
-         6.1167e-02,  0.0000e+00, -3.3361e-01, -4.1352e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.4052e-01,  0.0000e+00,  0.0000e+00,
-        -5.3080e-02, -1.0377e-02,  0.0000e+00, -8.7003e-01,  1.6351e-01,
-         0.0000e+00,  3.8086e-01, -3.5121e-02, -1.7059e-01,  0.0000e+00,
-        -2.6254e-01, -1.1882e+00,  0.0000e+00,  1.5469e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.4929e-01,  7.6028e-01,
-         2.0515e-03,  0.0000e+00, -3.8470e-01,  0.0000e+00,  0.0000e+00,
-        -2.7684e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.1521e-02, -2.5694e-02,  1.6081e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3138e-01,  0.0000e+00,
-         0.0000e+00, -1.0056e-01,  0.0000e+00, -2.3883e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([-3.6557e-02, -3.6442e-10,  1.6995e-01,  1.9801e-02,  2.1527e+00,
-        -3.6694e-01,  2.2059e-19, -1.7418e-12, -1.2546e-01,  1.7737e-01,
-         8.2475e-03,  3.9656e-11, -3.1799e-01, -4.1102e-01, -3.9558e-16,
-        -2.9401e-13, -2.8920e-14, -4.3985e-01, -6.6890e-17,  8.0425e-13,
-        -3.0839e-02, -5.0506e-02, -1.4205e-16, -8.6265e-01,  1.0964e-01,
-         1.1529e-12,  4.0652e-01, -3.2921e-02, -1.7634e-01,  2.0659e-16,
-        -2.7274e-01, -1.1827e+00, -5.2883e-20,  1.5588e-01,  2.6249e-14,
-         1.4163e-12,  8.7560e-16,  0.0000e+00,  1.5386e-01,  7.5340e-01,
-        -1.6659e-03, -1.1695e-13, -3.7068e-01,  2.1549e-16,  5.2707e-14,
-        -2.7665e-01, -7.4369e-12,  2.3997e-05,  2.5018e-11, -1.3719e-17,
-        -4.4388e-08,  2.7079e-02, -1.0155e-02, -1.4541e-02, -5.0273e-09,
-         2.8551e-21, -1.1221e-15,  4.0355e-19,  4.4536e-01, -5.8281e-11,
-         1.0780e-18, -8.2575e-02,  9.0571e-07, -2.2572e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-3.6557e-02,  0.0000e+00,  1.6995e-01,  1.9801e-02,  2.1527e+00,
-        -3.6694e-01,  0.0000e+00,  0.0000e+00, -1.2546e-01,  1.7737e-01,
-         8.2475e-03,  0.0000e+00, -3.1799e-01, -4.1102e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.3985e-01,  0.0000e+00,  0.0000e+00,
-        -3.0839e-02, -5.0506e-02,  0.0000e+00, -8.6265e-01,  1.0964e-01,
-         0.0000e+00,  4.0652e-01, -3.2921e-02, -1.7634e-01,  0.0000e+00,
-        -2.7274e-01, -1.1827e+00,  0.0000e+00,  1.5588e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5386e-01,  7.5340e-01,
-        -1.6659e-03,  0.0000e+00, -3.7068e-01,  0.0000e+00,  0.0000e+00,
-        -2.7665e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.7079e-02, -1.0155e-02, -1.4541e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.4536e-01,  0.0000e+00,
-         0.0000e+00, -8.2575e-02,  0.0000e+00, -2.2572e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([-3.6557e-02,  0.0000e+00,  1.6995e-01,  1.9801e-02,  2.1527e+00,
-        -3.6694e-01,  0.0000e+00,  0.0000e+00, -1.2546e-01,  1.7737e-01,
-         8.2475e-03,  0.0000e+00, -3.1799e-01, -4.1102e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.3985e-01,  0.0000e+00,  0.0000e+00,
-        -3.0839e-02, -5.0506e-02,  0.0000e+00, -8.6265e-01,  1.0964e-01,
-         0.0000e+00,  4.0652e-01, -3.2921e-02, -1.7634e-01,  0.0000e+00,
-        -2.7274e-01, -1.1827e+00,  0.0000e+00,  1.5588e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5386e-01,  7.5340e-01,
-        -1.6659e-03,  0.0000e+00, -3.7068e-01,  0.0000e+00,  0.0000e+00,
-        -2.7665e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.7079e-02, -1.0155e-02, -1.4541e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.4536e-01,  0.0000e+00,
-         0.0000e+00, -8.2575e-02,  0.0000e+00, -2.2572e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([-4.5450e-02, -3.3228e-10,  1.8053e-01,  3.9076e-02,  2.1534e+00,
-        -3.7920e-01,  2.0114e-19, -1.5882e-12, -1.2794e-01,  1.6626e-01,
-        -3.6370e-02,  3.6159e-11, -3.0345e-01, -4.0418e-01, -3.6070e-16,
-        -2.6808e-13, -2.6370e-14, -4.2671e-01, -6.0992e-17,  7.3332e-13,
-        -3.7073e-03, -7.7802e-02, -1.2952e-16, -8.5346e-01,  5.9516e-02,
-         1.0512e-12,  4.1956e-01, -2.4087e-02, -1.7476e-01,  1.8837e-16,
-        -2.9059e-01, -1.1764e+00, -4.8220e-20,  1.5453e-01,  2.3934e-14,
-         1.2914e-12,  7.9839e-16,  0.0000e+00,  1.6487e-01,  7.4646e-01,
-         5.0171e-03, -1.0664e-13, -3.6133e-01,  1.9649e-16,  4.8059e-14,
-        -2.8539e-01, -6.7811e-12,  2.1881e-05,  2.2811e-11, -1.2509e-17,
-        -4.0473e-08,  4.0412e-04,  9.6613e-04, -2.6334e-02, -4.5839e-09,
-         2.6033e-21, -1.0232e-15,  3.6796e-19,  4.6256e-01, -5.3142e-11,
-         9.8291e-19, -6.0307e-02,  8.2584e-07, -2.1923e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-4.5450e-02,  0.0000e+00,  1.8053e-01,  3.9076e-02,  2.1534e+00,
-        -3.7920e-01,  0.0000e+00,  0.0000e+00, -1.2794e-01,  1.6626e-01,
-        -3.6370e-02,  0.0000e+00, -3.0345e-01, -4.0418e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.2671e-01,  0.0000e+00,  0.0000e+00,
-        -3.7073e-03, -7.7802e-02,  0.0000e+00, -8.5346e-01,  5.9516e-02,
-         0.0000e+00,  4.1956e-01, -2.4087e-02, -1.7476e-01,  0.0000e+00,
-        -2.9059e-01, -1.1764e+00,  0.0000e+00,  1.5453e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6487e-01,  7.4646e-01,
-         5.0171e-03,  0.0000e+00, -3.6133e-01,  0.0000e+00,  0.0000e+00,
-        -2.8539e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  4.0412e-04,  9.6613e-04, -2.6334e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6256e-01,  0.0000e+00,
-         0.0000e+00, -6.0307e-02,  0.0000e+00, -2.1923e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([-4.5450e-02,  0.0000e+00,  1.8053e-01,  3.9076e-02,  2.1534e+00,
-        -3.7920e-01,  0.0000e+00,  0.0000e+00, -1.2794e-01,  1.6626e-01,
-        -3.6370e-02,  0.0000e+00, -3.0345e-01, -4.0418e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.2671e-01,  0.0000e+00,  0.0000e+00,
-        -3.7073e-03, -7.7802e-02,  0.0000e+00, -8.5346e-01,  5.9516e-02,
-         0.0000e+00,  4.1956e-01, -2.4087e-02, -1.7476e-01,  0.0000e+00,
-        -2.9059e-01, -1.1764e+00,  0.0000e+00,  1.5453e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6487e-01,  7.4646e-01,
-         5.0171e-03,  0.0000e+00, -3.6133e-01,  0.0000e+00,  0.0000e+00,
-        -2.8539e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  4.0412e-04,  9.6613e-04, -2.6334e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6256e-01,  0.0000e+00,
-         0.0000e+00, -6.0307e-02,  0.0000e+00, -2.1923e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([-5.0422e-02, -3.0306e-10,  2.0123e-01,  6.9000e-02,  2.1540e+00,
-        -3.8677e-01,  1.8344e-19, -1.4485e-12, -1.2315e-01,  1.6198e-01,
-        -7.4725e-02,  3.2978e-11, -2.7860e-01, -3.9351e-01, -3.2897e-16,
-        -2.4450e-13, -2.4050e-14, -4.1267e-01, -5.5627e-17,  6.6882e-13,
-         2.9638e-02, -1.0240e-01, -1.1813e-16, -8.4261e-01,  6.1633e-03,
-         9.5876e-13,  4.3187e-01, -7.6450e-03, -1.6694e-01,  1.7181e-16,
-        -3.0906e-01, -1.1715e+00, -4.3979e-20,  1.4965e-01,  2.1829e-14,
-         1.1778e-12,  7.2816e-16,  0.0000e+00,  1.7446e-01,  7.3607e-01,
-         1.9023e-02, -9.7258e-14, -3.4858e-01,  1.7921e-16,  4.3832e-14,
-        -2.9203e-01, -6.1846e-12,  1.9956e-05,  2.0805e-11, -1.1409e-17,
-        -3.6913e-08, -5.5604e-03,  6.2544e-03, -1.6613e-02, -4.1807e-09,
-         2.3743e-21, -9.3316e-16,  3.3560e-19,  4.7682e-01, -4.8467e-11,
-         8.9646e-19, -3.5430e-02,  7.5320e-07, -2.1709e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0504,  0.0000,  0.2012,  0.0690,  2.1540, -0.3868,  0.0000,  0.0000,
-        -0.1232,  0.1620, -0.0747,  0.0000, -0.2786, -0.3935,  0.0000,  0.0000,
-         0.0000, -0.4127,  0.0000,  0.0000,  0.0296, -0.1024,  0.0000, -0.8426,
-         0.0062,  0.0000,  0.4319, -0.0076, -0.1669,  0.0000, -0.3091, -1.1715,
-         0.0000,  0.1497,  0.0000,  0.0000,  0.0000,  0.0000,  0.1745,  0.7361,
-         0.0190,  0.0000, -0.3486,  0.0000,  0.0000, -0.2920,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0056,  0.0063, -0.0166,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4768,  0.0000,  0.0000, -0.0354,  0.0000, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0504,  0.0000,  0.2012,  0.0690,  2.1540, -0.3868,  0.0000,  0.0000,
-        -0.1232,  0.1620, -0.0747,  0.0000, -0.2786, -0.3935,  0.0000,  0.0000,
-         0.0000, -0.4127,  0.0000,  0.0000,  0.0296, -0.1024,  0.0000, -0.8426,
-         0.0062,  0.0000,  0.4319, -0.0076, -0.1669,  0.0000, -0.3091, -1.1715,
-         0.0000,  0.1497,  0.0000,  0.0000,  0.0000,  0.0000,  0.1745,  0.7361,
-         0.0190,  0.0000, -0.3486,  0.0000,  0.0000, -0.2920,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0056,  0.0063, -0.0166,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4768,  0.0000,  0.0000, -0.0354,  0.0000, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-5.2054e-02, -2.7647e-10,  2.0843e-01,  9.5849e-02,  2.1529e+00,
-        -4.0205e-01,  1.6735e-19, -1.3214e-12, -9.2790e-02,  1.4358e-01,
-        -9.1880e-02,  3.0085e-11, -2.5558e-01, -3.8210e-01, -3.0011e-16,
-        -2.2305e-13, -2.1940e-14, -3.8091e-01, -5.0746e-17,  6.1014e-13,
-         7.7906e-02, -1.0692e-01, -1.0777e-16, -8.2601e-01, -3.9649e-02,
-         8.7464e-13,  4.3742e-01,  2.6251e-02, -1.5530e-01,  1.5673e-16,
-        -3.4090e-01, -1.1654e+00, -4.0120e-20,  1.4218e-01,  1.9914e-14,
-         1.0745e-12,  6.6427e-16,  0.0000e+00,  1.8967e-01,  7.2518e-01,
-         5.2489e-02, -8.8725e-14, -3.4090e-01,  1.6348e-16,  3.9986e-14,
-        -3.0935e-01, -5.6420e-12,  1.8205e-05,  1.8980e-11, -1.0408e-17,
-        -3.3675e-08,  2.8685e-03,  2.3738e-02,  3.0304e-03, -3.8139e-09,
-         2.1660e-21, -8.5128e-16,  3.0615e-19,  4.9381e-01, -4.4215e-11,
-         8.1780e-19, -1.1091e-02,  6.8711e-07, -2.1446e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0521,  0.0000,  0.2084,  0.0958,  2.1529, -0.4021,  0.0000,  0.0000,
-        -0.0928,  0.1436, -0.0919,  0.0000, -0.2556, -0.3821,  0.0000,  0.0000,
-         0.0000, -0.3809,  0.0000,  0.0000,  0.0779, -0.1069,  0.0000, -0.8260,
-        -0.0396,  0.0000,  0.4374,  0.0263, -0.1553,  0.0000, -0.3409, -1.1654,
-         0.0000,  0.1422,  0.0000,  0.0000,  0.0000,  0.0000,  0.1897,  0.7252,
-         0.0525,  0.0000, -0.3409,  0.0000,  0.0000, -0.3093,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0029,  0.0237,  0.0030,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4938,  0.0000,  0.0000, -0.0111,  0.0000, -0.2145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0521,  0.0000,  0.2084,  0.0958,  2.1529, -0.4021,  0.0000,  0.0000,
-        -0.0928,  0.1436, -0.0919,  0.0000, -0.2556, -0.3821,  0.0000,  0.0000,
-         0.0000, -0.3809,  0.0000,  0.0000,  0.0779, -0.1069,  0.0000, -0.8260,
-        -0.0396,  0.0000,  0.4374,  0.0263, -0.1553,  0.0000, -0.3409, -1.1654,
-         0.0000,  0.1422,  0.0000,  0.0000,  0.0000,  0.0000,  0.1897,  0.7252,
-         0.0525,  0.0000, -0.3409,  0.0000,  0.0000, -0.3093,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0029,  0.0237,  0.0030,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4938,  0.0000,  0.0000, -0.0111,  0.0000, -0.2145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-4.3099e-02, -2.5227e-10,  2.1312e-01,  1.3581e-01,  2.1521e+00,
-        -4.2299e-01,  1.5270e-19, -1.2058e-12, -5.6188e-02,  1.1032e-01,
-        -1.0366e-01,  2.7452e-11, -2.2529e-01, -3.7640e-01, -2.7384e-16,
-        -2.0353e-13, -2.0020e-14, -3.4255e-01, -4.6305e-17,  5.5674e-13,
-         1.3157e-01, -9.4695e-02, -9.8334e-17, -8.1510e-01, -7.8293e-02,
-         7.9809e-13,  4.3895e-01,  6.2635e-02, -1.5406e-01,  1.4301e-16,
-        -3.8517e-01, -1.1584e+00, -3.6609e-20,  1.2690e-01,  1.8171e-14,
-         9.8042e-13,  6.0613e-16,  0.0000e+00,  2.0814e-01,  7.1367e-01,
-         9.0432e-02, -8.0959e-14, -3.3412e-01,  1.4917e-16,  3.6487e-14,
-        -3.2648e-01, -5.1482e-12,  1.6612e-05,  1.7318e-11, -9.4968e-18,
-        -3.0727e-08,  3.3837e-02,  3.3222e-02,  6.5868e-02, -3.4801e-09,
-         1.9764e-21, -7.7678e-16,  2.7936e-19,  5.0747e-01, -4.0345e-11,
-         7.4623e-19,  1.8098e-02,  6.2698e-07, -2.2340e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0431,  0.0000,  0.2131,  0.1358,  2.1521, -0.4230,  0.0000,  0.0000,
-        -0.0562,  0.1103, -0.1037,  0.0000, -0.2253, -0.3764,  0.0000,  0.0000,
-         0.0000, -0.3425,  0.0000,  0.0000,  0.1316, -0.0947,  0.0000, -0.8151,
-        -0.0783,  0.0000,  0.4390,  0.0626, -0.1541,  0.0000, -0.3852, -1.1584,
-         0.0000,  0.1269,  0.0000,  0.0000,  0.0000,  0.0000,  0.2081,  0.7137,
-         0.0904,  0.0000, -0.3341,  0.0000,  0.0000, -0.3265,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0338,  0.0332,  0.0659,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5075,  0.0000,  0.0000,  0.0181,  0.0000, -0.2234],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0431,  0.0000,  0.2131,  0.1358,  2.1521, -0.4230,  0.0000,  0.0000,
-        -0.0562,  0.1103, -0.1037,  0.0000, -0.2253, -0.3764,  0.0000,  0.0000,
-         0.0000, -0.3425,  0.0000,  0.0000,  0.1316, -0.0947,  0.0000, -0.8151,
-        -0.0783,  0.0000,  0.4390,  0.0626, -0.1541,  0.0000, -0.3852, -1.1584,
-         0.0000,  0.1269,  0.0000,  0.0000,  0.0000,  0.0000,  0.2081,  0.7137,
-         0.0904,  0.0000, -0.3341,  0.0000,  0.0000, -0.3265,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0338,  0.0332,  0.0659,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5075,  0.0000,  0.0000,  0.0181,  0.0000, -0.2234],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-4.3115e-02, -2.3024e-10,  2.0856e-01,  1.6451e-01,  2.1503e+00,
-        -4.4330e-01,  1.3937e-19, -1.1005e-12, -2.6483e-02,  6.7923e-02,
-        -9.8113e-02,  2.5055e-11, -1.9889e-01, -3.7533e-01, -2.4993e-16,
-        -1.8576e-13, -1.8272e-14, -3.1217e-01, -4.2262e-17,  5.0813e-13,
-         1.7155e-01, -8.6881e-02, -8.9749e-17, -8.0749e-01, -1.0977e-01,
-         7.2841e-13,  4.3450e-01,  8.7299e-02, -1.5848e-01,  1.3053e-16,
-        -4.2430e-01, -1.1526e+00, -3.3412e-20,  1.1429e-01,  1.6584e-14,
-         8.9482e-13,  5.5321e-16,  0.0000e+00,  2.2349e-01,  7.0390e-01,
-         1.2516e-01, -7.3891e-14, -3.2363e-01,  1.3615e-16,  3.3301e-14,
-        -3.2902e-01, -4.6987e-12,  1.5161e-05,  1.5806e-11, -8.6676e-18,
-        -2.8044e-08,  4.1987e-02,  4.4683e-02,  9.9377e-02, -3.1763e-09,
-         1.8039e-21, -7.0896e-16,  2.5497e-19,  5.2572e-01, -3.6823e-11,
-         6.8107e-19,  4.5752e-02,  5.7223e-07, -2.1962e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0431,  0.0000,  0.2086,  0.1645,  2.1503, -0.4433,  0.0000,  0.0000,
-        -0.0265,  0.0679, -0.0981,  0.0000, -0.1989, -0.3753,  0.0000,  0.0000,
-         0.0000, -0.3122,  0.0000,  0.0000,  0.1715, -0.0869,  0.0000, -0.8075,
-        -0.1098,  0.0000,  0.4345,  0.0873, -0.1585,  0.0000, -0.4243, -1.1526,
-         0.0000,  0.1143,  0.0000,  0.0000,  0.0000,  0.0000,  0.2235,  0.7039,
-         0.1252,  0.0000, -0.3236,  0.0000,  0.0000, -0.3290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0420,  0.0447,  0.0994,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5257,  0.0000,  0.0000,  0.0458,  0.0000, -0.2196],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0431,  0.0000,  0.2086,  0.1645,  2.1503, -0.4433,  0.0000,  0.0000,
-        -0.0265,  0.0679, -0.0981,  0.0000, -0.1989, -0.3753,  0.0000,  0.0000,
-         0.0000, -0.3122,  0.0000,  0.0000,  0.1715, -0.0869,  0.0000, -0.8075,
-        -0.1098,  0.0000,  0.4345,  0.0873, -0.1585,  0.0000, -0.4243, -1.1526,
-         0.0000,  0.1143,  0.0000,  0.0000,  0.0000,  0.0000,  0.2235,  0.7039,
-         0.1252,  0.0000, -0.3236,  0.0000,  0.0000, -0.3290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0420,  0.0447,  0.0994,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5257,  0.0000,  0.0000,  0.0458,  0.0000, -0.2196],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-3.9428e-02, -2.1019e-10,  2.0668e-01,  1.6689e-01,  2.1490e+00,
-        -4.5697e-01,  1.2723e-19, -1.0046e-12, -8.9598e-03,  1.8142e-02,
-        -8.1625e-02,  2.2873e-11, -1.7724e-01, -3.7258e-01, -2.2816e-16,
-        -1.6958e-13, -1.6681e-14, -3.0432e-01, -3.8581e-17,  4.6387e-13,
-         1.9262e-01, -8.5359e-02, -8.1932e-17, -8.0709e-01, -1.3269e-01,
-         6.6496e-13,  4.1849e-01,  9.6670e-02, -1.6290e-01,  1.1916e-16,
-        -4.5090e-01, -1.1488e+00, -3.0502e-20,  1.0589e-01,  1.5140e-14,
-         8.1688e-13,  5.0503e-16,  0.0000e+00,  2.4152e-01,  6.9473e-01,
-         1.4278e-01, -6.7455e-14, -3.1083e-01,  1.2429e-16,  3.0400e-14,
-        -3.1487e-01, -4.2894e-12,  1.3841e-05,  1.4430e-11, -7.9127e-18,
-        -2.5602e-08,  3.5147e-02,  4.7582e-02,  1.1530e-01, -2.8996e-09,
-         1.6468e-21, -6.4720e-16,  2.3276e-19,  5.4160e-01, -3.3615e-11,
-         6.2175e-19,  6.3851e-02,  5.2239e-07, -2.1885e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0394,  0.0000,  0.2067,  0.1669,  2.1490, -0.4570,  0.0000,  0.0000,
-        -0.0090,  0.0181, -0.0816,  0.0000, -0.1772, -0.3726,  0.0000,  0.0000,
-         0.0000, -0.3043,  0.0000,  0.0000,  0.1926, -0.0854,  0.0000, -0.8071,
-        -0.1327,  0.0000,  0.4185,  0.0967, -0.1629,  0.0000, -0.4509, -1.1488,
-         0.0000,  0.1059,  0.0000,  0.0000,  0.0000,  0.0000,  0.2415,  0.6947,
-         0.1428,  0.0000, -0.3108,  0.0000,  0.0000, -0.3149,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0351,  0.0476,  0.1153,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5416,  0.0000,  0.0000,  0.0639,  0.0000, -0.2189],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0394,  0.0000,  0.2067,  0.1669,  2.1490, -0.4570,  0.0000,  0.0000,
-        -0.0090,  0.0181, -0.0816,  0.0000, -0.1772, -0.3726,  0.0000,  0.0000,
-         0.0000, -0.3043,  0.0000,  0.0000,  0.1926, -0.0854,  0.0000, -0.8071,
-        -0.1327,  0.0000,  0.4185,  0.0967, -0.1629,  0.0000, -0.4509, -1.1488,
-         0.0000,  0.1059,  0.0000,  0.0000,  0.0000,  0.0000,  0.2415,  0.6947,
-         0.1428,  0.0000, -0.3108,  0.0000,  0.0000, -0.3149,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0351,  0.0476,  0.1153,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5416,  0.0000,  0.0000,  0.0639,  0.0000, -0.2189],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-3.0696e-02, -1.9192e-10,  1.9706e-01,  1.5856e-01,  2.1472e+00,
-        -4.7323e-01,  1.1617e-19, -9.1733e-13,  1.0055e-02, -2.4654e-02,
-        -3.7781e-02,  2.0885e-11, -1.5986e-01, -3.6647e-01, -2.0833e-16,
-        -1.5484e-13, -1.5231e-14, -3.0346e-01, -3.5228e-17,  4.2356e-13,
-         1.9455e-01, -9.0991e-02, -7.4812e-17, -8.1423e-01, -1.4478e-01,
-         6.0718e-13,  3.8979e-01,  8.6172e-02, -1.7861e-01,  1.0880e-16,
-        -4.7636e-01, -1.1478e+00, -2.7851e-20,  9.9336e-02,  1.3824e-14,
-         7.4590e-13,  4.6114e-16,  0.0000e+00,  2.5616e-01,  6.8105e-01,
-         1.4747e-01, -6.1593e-14, -2.9800e-01,  1.1349e-16,  2.7759e-14,
-        -3.0813e-01, -3.9167e-12,  1.2638e-05,  1.3176e-11, -7.2251e-18,
-        -2.3377e-08,  2.9594e-02,  5.7604e-02,  9.0989e-02, -2.6476e-09,
-         1.5037e-21, -5.9096e-16,  2.1253e-19,  5.5719e-01, -3.0694e-11,
-         5.6772e-19,  6.9361e-02,  4.7700e-07, -2.0085e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0307,  0.0000,  0.1971,  0.1586,  2.1472, -0.4732,  0.0000,  0.0000,
-         0.0101, -0.0247, -0.0378,  0.0000, -0.1599, -0.3665,  0.0000,  0.0000,
-         0.0000, -0.3035,  0.0000,  0.0000,  0.1946, -0.0910,  0.0000, -0.8142,
-        -0.1448,  0.0000,  0.3898,  0.0862, -0.1786,  0.0000, -0.4764, -1.1478,
-         0.0000,  0.0993,  0.0000,  0.0000,  0.0000,  0.0000,  0.2562,  0.6811,
-         0.1475,  0.0000, -0.2980,  0.0000,  0.0000, -0.3081,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0296,  0.0576,  0.0910,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5572,  0.0000,  0.0000,  0.0694,  0.0000, -0.2008],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0307,  0.0000,  0.1971,  0.1586,  2.1472, -0.4732,  0.0000,  0.0000,
-         0.0101, -0.0247, -0.0378,  0.0000, -0.1599, -0.3665,  0.0000,  0.0000,
-         0.0000, -0.3035,  0.0000,  0.0000,  0.1946, -0.0910,  0.0000, -0.8142,
-        -0.1448,  0.0000,  0.3898,  0.0862, -0.1786,  0.0000, -0.4764, -1.1478,
-         0.0000,  0.0993,  0.0000,  0.0000,  0.0000,  0.0000,  0.2562,  0.6811,
-         0.1475,  0.0000, -0.2980,  0.0000,  0.0000, -0.3081,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0296,  0.0576,  0.0910,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5572,  0.0000,  0.0000,  0.0694,  0.0000, -0.2008],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.5844e-02, -1.7528e-10,  1.8161e-01,  1.4227e-01,  2.1463e+00,
-        -4.7706e-01,  1.0610e-19, -8.3780e-13,  1.7089e-02, -6.5402e-02,
-         4.1537e-03,  1.9074e-11, -1.5951e-01, -3.6642e-01, -1.9027e-16,
-        -1.4142e-13, -1.3911e-14, -3.0134e-01, -3.2174e-17,  3.8684e-13,
-         1.8607e-01, -9.6346e-02, -6.8326e-17, -8.2146e-01, -1.4576e-01,
-         5.5454e-13,  3.6105e-01,  7.6856e-02, -1.9387e-01,  9.9370e-17,
-        -4.8753e-01, -1.1474e+00, -2.5437e-20,  1.0290e-01,  1.2626e-14,
-         6.8123e-13,  4.2116e-16,  0.0000e+00,  2.4056e-01,  6.6699e-01,
-         1.3606e-01, -5.6253e-14, -2.9439e-01,  1.0365e-16,  2.5352e-14,
-        -2.9696e-01, -3.5771e-12,  1.1542e-05,  1.2033e-11, -6.5987e-18,
-        -2.1350e-08,  5.7720e-03,  7.0826e-02,  5.3869e-02, -2.4181e-09,
-         1.3733e-21, -5.3973e-16,  1.9411e-19,  5.7409e-01, -2.8033e-11,
-         5.1850e-19,  7.3435e-02,  4.3564e-07, -1.8535e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0258,  0.0000,  0.1816,  0.1423,  2.1463, -0.4771,  0.0000,  0.0000,
-         0.0171, -0.0654,  0.0042,  0.0000, -0.1595, -0.3664,  0.0000,  0.0000,
-         0.0000, -0.3013,  0.0000,  0.0000,  0.1861, -0.0963,  0.0000, -0.8215,
-        -0.1458,  0.0000,  0.3611,  0.0769, -0.1939,  0.0000, -0.4875, -1.1474,
-         0.0000,  0.1029,  0.0000,  0.0000,  0.0000,  0.0000,  0.2406,  0.6670,
-         0.1361,  0.0000, -0.2944,  0.0000,  0.0000, -0.2970,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0058,  0.0708,  0.0539,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5741,  0.0000,  0.0000,  0.0734,  0.0000, -0.1854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0258,  0.0000,  0.1816,  0.1423,  2.1463, -0.4771,  0.0000,  0.0000,
-         0.0171, -0.0654,  0.0042,  0.0000, -0.1595, -0.3664,  0.0000,  0.0000,
-         0.0000, -0.3013,  0.0000,  0.0000,  0.1861, -0.0963,  0.0000, -0.8215,
-        -0.1458,  0.0000,  0.3611,  0.0769, -0.1939,  0.0000, -0.4875, -1.1474,
-         0.0000,  0.1029,  0.0000,  0.0000,  0.0000,  0.0000,  0.2406,  0.6670,
-         0.1361,  0.0000, -0.2944,  0.0000,  0.0000, -0.2970,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0058,  0.0708,  0.0539,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5741,  0.0000,  0.0000,  0.0734,  0.0000, -0.1854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.2700e-02, -1.6012e-10,  1.5520e-01,  1.2577e-01,  2.1454e+00,
-        -4.7525e-01,  9.6924e-20, -7.6533e-13,  2.3327e-02, -9.6334e-02,
-         4.3837e-02,  1.7424e-11, -1.6909e-01, -3.6731e-01, -1.7381e-16,
-        -1.2918e-13, -1.2707e-14, -3.0029e-01, -2.9391e-17,  3.5338e-13,
-         1.6497e-01, -9.8499e-02, -6.2415e-17, -8.3118e-01, -1.2879e-01,
-         5.0657e-13,  3.2574e-01,  5.9599e-02, -2.0626e-01,  9.0774e-17,
-        -4.9161e-01, -1.1465e+00, -2.3236e-20,  9.5888e-02,  1.1533e-14,
-         6.2230e-13,  3.8473e-16,  0.0000e+00,  2.0944e-01,  6.5696e-01,
-         1.1950e-01, -5.1387e-14, -2.9184e-01,  9.4685e-17,  2.3159e-14,
-        -2.8075e-01, -3.2677e-12,  1.0544e-05,  1.0992e-11, -6.0278e-18,
-        -1.9503e-08, -8.5395e-03,  7.0098e-02,  7.4599e-03, -2.2089e-09,
-         1.2545e-21, -4.9304e-16,  1.7731e-19,  5.9330e-01, -2.5608e-11,
-         4.7365e-19,  6.8798e-02,  3.9796e-07, -1.7288e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0227,  0.0000,  0.1552,  0.1258,  2.1454, -0.4753,  0.0000,  0.0000,
-         0.0233, -0.0963,  0.0438,  0.0000, -0.1691, -0.3673,  0.0000,  0.0000,
-         0.0000, -0.3003,  0.0000,  0.0000,  0.1650, -0.0985,  0.0000, -0.8312,
-        -0.1288,  0.0000,  0.3257,  0.0596, -0.2063,  0.0000, -0.4916, -1.1465,
-         0.0000,  0.0959,  0.0000,  0.0000,  0.0000,  0.0000,  0.2094,  0.6570,
-         0.1195,  0.0000, -0.2918,  0.0000,  0.0000, -0.2807,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0085,  0.0701,  0.0075,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5933,  0.0000,  0.0000,  0.0688,  0.0000, -0.1729],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0227,  0.0000,  0.1552,  0.1258,  2.1454, -0.4753,  0.0000,  0.0000,
-         0.0233, -0.0963,  0.0438,  0.0000, -0.1691, -0.3673,  0.0000,  0.0000,
-         0.0000, -0.3003,  0.0000,  0.0000,  0.1650, -0.0985,  0.0000, -0.8312,
-        -0.1288,  0.0000,  0.3257,  0.0596, -0.2063,  0.0000, -0.4916, -1.1465,
-         0.0000,  0.0959,  0.0000,  0.0000,  0.0000,  0.0000,  0.2094,  0.6570,
-         0.1195,  0.0000, -0.2918,  0.0000,  0.0000, -0.2807,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0085,  0.0701,  0.0075,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5933,  0.0000,  0.0000,  0.0688,  0.0000, -0.1729],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.9143e-02, -1.4630e-10,  1.3067e-01,  9.7382e-02,  2.1452e+00,
-        -4.7232e-01,  8.8558e-20, -6.9927e-13,  2.6456e-02, -1.2613e-01,
-         7.3280e-02,  1.5920e-11, -1.7443e-01, -3.7455e-01, -1.5881e-16,
-        -1.1803e-13, -1.1610e-14, -3.0510e-01, -2.6854e-17,  3.2287e-13,
-         1.4248e-01, -1.0646e-01, -5.7028e-17, -8.3584e-01, -1.1888e-01,
-         4.6284e-13,  2.8563e-01,  3.2295e-02, -2.1407e-01,  8.2939e-17,
-        -4.9498e-01, -1.1447e+00, -2.1231e-20,  8.3201e-02,  1.0538e-14,
-         5.6859e-13,  3.5152e-16,  0.0000e+00,  1.7002e-01,  6.5270e-01,
-         9.0548e-02, -4.6951e-14, -2.8586e-01,  8.6512e-17,  2.1160e-14,
-        -2.6376e-01, -2.9856e-12,  9.6338e-06,  1.0044e-11, -5.5076e-18,
-        -1.7820e-08, -2.9287e-02,  4.3958e-02, -5.0029e-02, -2.0183e-09,
-         1.1462e-21, -4.5048e-16,  1.6201e-19,  6.0566e-01, -2.3398e-11,
-         4.3277e-19,  4.8068e-02,  3.6361e-07, -1.6446e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0191,  0.0000,  0.1307,  0.0974,  2.1452, -0.4723,  0.0000,  0.0000,
-         0.0265, -0.1261,  0.0733,  0.0000, -0.1744, -0.3745,  0.0000,  0.0000,
-         0.0000, -0.3051,  0.0000,  0.0000,  0.1425, -0.1065,  0.0000, -0.8358,
-        -0.1189,  0.0000,  0.2856,  0.0323, -0.2141,  0.0000, -0.4950, -1.1447,
-         0.0000,  0.0832,  0.0000,  0.0000,  0.0000,  0.0000,  0.1700,  0.6527,
-         0.0905,  0.0000, -0.2859,  0.0000,  0.0000, -0.2638,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0293,  0.0440, -0.0500,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6057,  0.0000,  0.0000,  0.0481,  0.0000, -0.1645],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0191,  0.0000,  0.1307,  0.0974,  2.1452, -0.4723,  0.0000,  0.0000,
-         0.0265, -0.1261,  0.0733,  0.0000, -0.1744, -0.3745,  0.0000,  0.0000,
-         0.0000, -0.3051,  0.0000,  0.0000,  0.1425, -0.1065,  0.0000, -0.8358,
-        -0.1189,  0.0000,  0.2856,  0.0323, -0.2141,  0.0000, -0.4950, -1.1447,
-         0.0000,  0.0832,  0.0000,  0.0000,  0.0000,  0.0000,  0.1700,  0.6527,
-         0.0905,  0.0000, -0.2859,  0.0000,  0.0000, -0.2638,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0293,  0.0440, -0.0500,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6057,  0.0000,  0.0000,  0.0481,  0.0000, -0.1645],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.7603e-02, -1.3370e-10,  1.1024e-01,  8.2304e-02,  2.1439e+00,
-        -4.6162e-01,  8.0931e-20, -6.3904e-13,  3.0839e-02, -1.3842e-01,
-         9.1792e-02,  1.4549e-11, -1.7731e-01, -3.8494e-01, -1.4513e-16,
-        -1.0787e-13, -1.0610e-14, -3.0447e-01, -2.4541e-17,  2.9507e-13,
-         1.1242e-01, -1.0512e-01, -5.2116e-17, -8.4015e-01, -1.0444e-01,
-         4.2298e-13,  2.5211e-01,  2.5508e-03, -2.2386e-01,  7.5796e-17,
-        -4.8415e-01, -1.1439e+00, -1.9402e-20,  6.3344e-02,  9.6303e-15,
-         5.1961e-13,  3.2124e-16,  0.0000e+00,  1.2410e-01,  6.4254e-01,
-         6.2077e-02, -4.2908e-14, -2.7537e-01,  7.9061e-17,  1.9338e-14,
-        -2.3863e-01, -2.7285e-12,  8.8041e-06,  9.1786e-12, -5.0332e-18,
-        -1.6285e-08, -3.6714e-02,  4.1163e-03, -8.7170e-02, -1.8444e-09,
-         1.0475e-21, -4.1168e-16,  1.4806e-19,  6.1333e-01, -2.1383e-11,
-         3.9549e-19,  3.0658e-02,  3.3229e-07, -1.6494e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0176,  0.0000,  0.1102,  0.0823,  2.1439, -0.4616,  0.0000,  0.0000,
-         0.0308, -0.1384,  0.0918,  0.0000, -0.1773, -0.3849,  0.0000,  0.0000,
-         0.0000, -0.3045,  0.0000,  0.0000,  0.1124, -0.1051,  0.0000, -0.8401,
-        -0.1044,  0.0000,  0.2521,  0.0026, -0.2239,  0.0000, -0.4841, -1.1439,
-         0.0000,  0.0633,  0.0000,  0.0000,  0.0000,  0.0000,  0.1241,  0.6425,
-         0.0621,  0.0000, -0.2754,  0.0000,  0.0000, -0.2386,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0367,  0.0041, -0.0872,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6133,  0.0000,  0.0000,  0.0307,  0.0000, -0.1649],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0176,  0.0000,  0.1102,  0.0823,  2.1439, -0.4616,  0.0000,  0.0000,
-         0.0308, -0.1384,  0.0918,  0.0000, -0.1773, -0.3849,  0.0000,  0.0000,
-         0.0000, -0.3045,  0.0000,  0.0000,  0.1124, -0.1051,  0.0000, -0.8401,
-        -0.1044,  0.0000,  0.2521,  0.0026, -0.2239,  0.0000, -0.4841, -1.1439,
-         0.0000,  0.0633,  0.0000,  0.0000,  0.0000,  0.0000,  0.1241,  0.6425,
-         0.0621,  0.0000, -0.2754,  0.0000,  0.0000, -0.2386,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0367,  0.0041, -0.0872,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6133,  0.0000,  0.0000,  0.0307,  0.0000, -0.1649],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-5.9668e-03, -1.2221e-10,  9.9170e-02,  8.2513e-02,  2.1428e+00,
-        -4.5169e-01,  7.3975e-20, -5.8412e-13,  3.5229e-02, -1.3577e-01,
-         1.0008e-01,  1.3299e-11, -1.7365e-01, -4.0070e-01, -1.3266e-16,
-        -9.8596e-14, -9.6985e-15, -2.9294e-01, -2.2432e-17,  2.6971e-13,
-         8.1958e-02, -1.0664e-01, -4.7637e-17, -8.4698e-01, -8.9948e-02,
-         3.8663e-13,  2.3582e-01, -1.7131e-02, -2.3758e-01,  6.9281e-17,
-        -4.7091e-01, -1.1449e+00, -1.7735e-20,  4.4544e-02,  8.8026e-15,
-         4.7496e-13,  2.9363e-16,  0.0000e+00,  8.0630e-02,  6.3308e-01,
-         2.8752e-02, -3.9220e-14, -2.6035e-01,  7.2266e-17,  1.7676e-14,
-        -2.0988e-01, -2.4940e-12,  8.0474e-06,  8.3897e-12, -4.6006e-18,
-        -1.4885e-08, -4.9874e-02, -4.0534e-02, -6.7422e-02, -1.6859e-09,
-         9.5746e-22, -3.7630e-16,  1.3533e-19,  6.1971e-01, -1.9545e-11,
-         3.6150e-19,  2.7653e-02,  3.0373e-07, -1.5613e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0060,  0.0000,  0.0992,  0.0825,  2.1428, -0.4517,  0.0000,  0.0000,
-         0.0352, -0.1358,  0.1001,  0.0000, -0.1737, -0.4007,  0.0000,  0.0000,
-         0.0000, -0.2929,  0.0000,  0.0000,  0.0820, -0.1066,  0.0000, -0.8470,
-        -0.0899,  0.0000,  0.2358, -0.0171, -0.2376,  0.0000, -0.4709, -1.1449,
-         0.0000,  0.0445,  0.0000,  0.0000,  0.0000,  0.0000,  0.0806,  0.6331,
-         0.0288,  0.0000, -0.2604,  0.0000,  0.0000, -0.2099,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0499, -0.0405, -0.0674,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6197,  0.0000,  0.0000,  0.0277,  0.0000, -0.1561],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0060,  0.0000,  0.0992,  0.0825,  2.1428, -0.4517,  0.0000,  0.0000,
-         0.0352, -0.1358,  0.1001,  0.0000, -0.1737, -0.4007,  0.0000,  0.0000,
-         0.0000, -0.2929,  0.0000,  0.0000,  0.0820, -0.1066,  0.0000, -0.8470,
-        -0.0899,  0.0000,  0.2358, -0.0171, -0.2376,  0.0000, -0.4709, -1.1449,
-         0.0000,  0.0445,  0.0000,  0.0000,  0.0000,  0.0000,  0.0806,  0.6331,
-         0.0288,  0.0000, -0.2604,  0.0000,  0.0000, -0.2099,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0499, -0.0405, -0.0674,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6197,  0.0000,  0.0000,  0.0277,  0.0000, -0.1561],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1360e-02, -1.1173e-10,  9.8507e-02,  8.5128e-02,  2.1421e+00,
-        -4.4148e-01,  6.7630e-20, -5.3402e-13,  3.4002e-02, -1.3562e-01,
-         9.6247e-02,  1.2158e-11, -1.6195e-01, -4.2726e-01, -1.2128e-16,
-        -9.0140e-14, -8.8666e-15, -2.7467e-01, -2.0508e-17,  2.4657e-13,
-         5.9268e-02, -1.1204e-01, -4.3551e-17, -8.5830e-01, -8.4626e-02,
-         3.5346e-13,  2.3062e-01, -3.2361e-02, -2.5209e-01,  6.3339e-17,
-        -4.5675e-01, -1.1456e+00, -1.6213e-20,  1.2895e-02,  8.0476e-15,
-         4.3422e-13,  2.6845e-16,  0.0000e+00,  4.4299e-02,  6.2817e-01,
-        -4.4018e-03, -3.5856e-14, -2.3954e-01,  6.6068e-17,  1.6159e-14,
-        -1.8098e-01, -2.2801e-12,  7.3571e-06,  7.6701e-12, -4.2060e-18,
-        -1.3609e-08, -6.4432e-02, -9.8332e-02, -2.6827e-02, -1.5413e-09,
-         8.7534e-22, -3.4402e-16,  1.2372e-19,  6.2409e-01, -1.7868e-11,
-         3.3049e-19,  3.1434e-02,  2.7768e-07, -1.4609e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0114,  0.0000,  0.0985,  0.0851,  2.1421, -0.4415,  0.0000,  0.0000,
-         0.0340, -0.1356,  0.0962,  0.0000, -0.1619, -0.4273,  0.0000,  0.0000,
-         0.0000, -0.2747,  0.0000,  0.0000,  0.0593, -0.1120,  0.0000, -0.8583,
-        -0.0846,  0.0000,  0.2306, -0.0324, -0.2521,  0.0000, -0.4568, -1.1456,
-         0.0000,  0.0129,  0.0000,  0.0000,  0.0000,  0.0000,  0.0443,  0.6282,
-        -0.0044,  0.0000, -0.2395,  0.0000,  0.0000, -0.1810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0644, -0.0983, -0.0268,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6241,  0.0000,  0.0000,  0.0314,  0.0000, -0.1461],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0114,  0.0000,  0.0985,  0.0851,  2.1421, -0.4415,  0.0000,  0.0000,
-         0.0340, -0.1356,  0.0962,  0.0000, -0.1619, -0.4273,  0.0000,  0.0000,
-         0.0000, -0.2747,  0.0000,  0.0000,  0.0593, -0.1120,  0.0000, -0.8583,
-        -0.0846,  0.0000,  0.2306, -0.0324, -0.2521,  0.0000, -0.4568, -1.1456,
-         0.0000,  0.0129,  0.0000,  0.0000,  0.0000,  0.0000,  0.0443,  0.6282,
-        -0.0044,  0.0000, -0.2395,  0.0000,  0.0000, -0.1810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0644, -0.0983, -0.0268,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6241,  0.0000,  0.0000,  0.0314,  0.0000, -0.1461],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9978e-02, -1.0216e-10,  9.5246e-02,  9.1153e-02,  2.1413e+00,
-        -4.1727e-01,  6.1841e-20, -4.8830e-13,  2.5556e-02, -1.2243e-01,
-         7.5153e-02,  1.1117e-11, -1.6467e-01, -4.5455e-01, -1.1090e-16,
-        -8.2424e-14, -8.1076e-15, -2.5944e-01, -1.8752e-17,  2.2547e-13,
-         3.3026e-02, -1.1284e-01, -3.9823e-17, -8.7466e-01, -6.9704e-02,
-         3.2321e-13,  2.3654e-01, -3.9488e-02, -2.6784e-01,  5.7917e-17,
-        -4.3159e-01, -1.1475e+00, -1.4826e-20, -5.3266e-03,  7.3587e-15,
-         3.9705e-13,  2.4547e-16,  0.0000e+00,  1.8045e-02,  6.2669e-01,
-        -5.1119e-02, -3.2786e-14, -2.3077e-01,  6.0412e-17,  1.4776e-14,
-        -1.3519e-01, -2.0849e-12,  6.7273e-06,  7.0135e-12, -3.8460e-18,
-        -1.2444e-08, -8.6351e-02, -1.4530e-01,  1.8949e-02, -1.4094e-09,
-         8.0041e-22, -3.1457e-16,  1.1313e-19,  6.2870e-01, -1.6339e-11,
-         3.0220e-19,  4.2261e-02,  2.5391e-07, -1.2688e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0300,  0.0000,  0.0952,  0.0912,  2.1413, -0.4173,  0.0000,  0.0000,
-         0.0256, -0.1224,  0.0752,  0.0000, -0.1647, -0.4545,  0.0000,  0.0000,
-         0.0000, -0.2594,  0.0000,  0.0000,  0.0330, -0.1128,  0.0000, -0.8747,
-        -0.0697,  0.0000,  0.2365, -0.0395, -0.2678,  0.0000, -0.4316, -1.1475,
-         0.0000, -0.0053,  0.0000,  0.0000,  0.0000,  0.0000,  0.0180,  0.6267,
-        -0.0511,  0.0000, -0.2308,  0.0000,  0.0000, -0.1352,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0864, -0.1453,  0.0189,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6287,  0.0000,  0.0000,  0.0423,  0.0000, -0.1269],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0300,  0.0000,  0.0952,  0.0912,  2.1413, -0.4173,  0.0000,  0.0000,
-         0.0256, -0.1224,  0.0752,  0.0000, -0.1647, -0.4545,  0.0000,  0.0000,
-         0.0000, -0.2594,  0.0000,  0.0000,  0.0330, -0.1128,  0.0000, -0.8747,
-        -0.0697,  0.0000,  0.2365, -0.0395, -0.2678,  0.0000, -0.4316, -1.1475,
-         0.0000, -0.0053,  0.0000,  0.0000,  0.0000,  0.0000,  0.0180,  0.6267,
-        -0.0511,  0.0000, -0.2308,  0.0000,  0.0000, -0.1352,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0864, -0.1453,  0.0189,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6287,  0.0000,  0.0000,  0.0423,  0.0000, -0.1269],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6584e-02, -9.3434e-11,  9.2716e-02,  1.0304e-01,  2.1416e+00,
-        -3.9071e-01,  5.6557e-20, -4.4659e-13,  1.2456e-02, -1.0636e-01,
-         4.3122e-02,  1.0167e-11, -1.6943e-01, -4.6888e-01, -1.0142e-16,
-        -7.5382e-14, -7.4149e-15, -2.4774e-01, -1.7150e-17,  2.0620e-13,
-         2.0526e-02, -1.0985e-01, -3.6421e-17, -8.9091e-01, -4.6599e-02,
-         2.9559e-13,  2.3927e-01, -3.6055e-02, -2.7976e-01,  5.2969e-17,
-        -4.0888e-01, -1.1492e+00, -1.3559e-20, -2.5773e-02,  6.7300e-15,
-         3.6313e-13,  2.2450e-16,  0.0000e+00, -9.7608e-03,  6.3074e-01,
-        -8.9807e-02, -2.9985e-14, -2.2922e-01,  5.5251e-17,  1.3514e-14,
-        -9.4096e-02, -1.9068e-12,  6.1526e-06,  6.4143e-12, -3.5174e-18,
-        -1.1381e-08, -9.6414e-02, -2.0167e-01,  7.4660e-02, -1.2890e-09,
-         7.3203e-22, -2.8770e-16,  1.0347e-19,  6.3316e-01, -1.4943e-11,
-         2.7638e-19,  4.5313e-02,  2.3222e-07, -1.2279e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0466,  0.0000,  0.0927,  0.1030,  2.1416, -0.3907,  0.0000,  0.0000,
-         0.0125, -0.1064,  0.0431,  0.0000, -0.1694, -0.4689,  0.0000,  0.0000,
-         0.0000, -0.2477,  0.0000,  0.0000,  0.0205, -0.1099,  0.0000, -0.8909,
-        -0.0466,  0.0000,  0.2393, -0.0361, -0.2798,  0.0000, -0.4089, -1.1492,
-         0.0000, -0.0258,  0.0000,  0.0000,  0.0000,  0.0000, -0.0098,  0.6307,
-        -0.0898,  0.0000, -0.2292,  0.0000,  0.0000, -0.0941,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0964, -0.2017,  0.0747,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6332,  0.0000,  0.0000,  0.0453,  0.0000, -0.1228],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0466,  0.0000,  0.0927,  0.1030,  2.1416, -0.3907,  0.0000,  0.0000,
-         0.0125, -0.1064,  0.0431,  0.0000, -0.1694, -0.4689,  0.0000,  0.0000,
-         0.0000, -0.2477,  0.0000,  0.0000,  0.0205, -0.1099,  0.0000, -0.8909,
-        -0.0466,  0.0000,  0.2393, -0.0361, -0.2798,  0.0000, -0.4089, -1.1492,
-         0.0000, -0.0258,  0.0000,  0.0000,  0.0000,  0.0000, -0.0098,  0.6307,
-        -0.0898,  0.0000, -0.2292,  0.0000,  0.0000, -0.0941,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0964, -0.2017,  0.0747,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6332,  0.0000,  0.0000,  0.0453,  0.0000, -0.1228],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.0399e-02, -8.5467e-11,  9.7680e-02,  1.1211e-01,  2.1408e+00,
-        -3.6795e-01,  5.1734e-20, -4.0850e-13,  3.6946e-03, -8.6520e-02,
-         3.3709e-02,  9.3004e-12, -1.7439e-01, -4.7614e-01, -9.2775e-17,
-        -6.8953e-14, -6.7826e-15, -2.3493e-01, -1.5688e-17,  1.8862e-13,
-         7.8452e-03, -1.0925e-01, -3.3315e-17, -9.0147e-01, -2.6068e-02,
-         2.7039e-13,  2.4601e-01, -3.5070e-02, -2.8701e-01,  4.8452e-17,
-        -3.8935e-01, -1.1503e+00, -1.2403e-20, -4.9814e-02,  6.1561e-15,
-         3.3216e-13,  2.0535e-16,  0.0000e+00, -3.0692e-02,  6.3261e-01,
-        -1.2431e-01, -2.7428e-14, -2.2725e-01,  5.0539e-17,  1.2361e-14,
-        -6.2387e-02, -1.7442e-12,  5.6279e-06,  5.8673e-12, -3.2174e-18,
-        -1.0410e-08, -9.9996e-02, -2.4509e-01,  1.1927e-01, -1.1790e-09,
-         6.6960e-22, -2.6317e-16,  9.4644e-20,  6.3055e-01, -1.3669e-11,
-         2.5282e-19,  4.0573e-02,  2.1241e-07, -1.1328e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0604,  0.0000,  0.0977,  0.1121,  2.1408, -0.3679,  0.0000,  0.0000,
-         0.0037, -0.0865,  0.0337,  0.0000, -0.1744, -0.4761,  0.0000,  0.0000,
-         0.0000, -0.2349,  0.0000,  0.0000,  0.0078, -0.1092,  0.0000, -0.9015,
-        -0.0261,  0.0000,  0.2460, -0.0351, -0.2870,  0.0000, -0.3893, -1.1503,
-         0.0000, -0.0498,  0.0000,  0.0000,  0.0000,  0.0000, -0.0307,  0.6326,
-        -0.1243,  0.0000, -0.2273,  0.0000,  0.0000, -0.0624,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.1000, -0.2451,  0.1193,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6306,  0.0000,  0.0000,  0.0406,  0.0000, -0.1133],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0604,  0.0000,  0.0977,  0.1121,  2.1408, -0.3679,  0.0000,  0.0000,
-         0.0037, -0.0865,  0.0337,  0.0000, -0.1744, -0.4761,  0.0000,  0.0000,
-         0.0000, -0.2349,  0.0000,  0.0000,  0.0078, -0.1092,  0.0000, -0.9015,
-        -0.0261,  0.0000,  0.2460, -0.0351, -0.2870,  0.0000, -0.3893, -1.1503,
-         0.0000, -0.0498,  0.0000,  0.0000,  0.0000,  0.0000, -0.0307,  0.6326,
-        -0.1243,  0.0000, -0.2273,  0.0000,  0.0000, -0.0624,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.1000, -0.2451,  0.1193,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6306,  0.0000,  0.0000,  0.0406,  0.0000, -0.1133],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 7.2794e-02, -7.8192e-11,  1.0193e-01,  1.1351e-01,  2.1397e+00,
-        -3.4792e-01,  4.7331e-20, -3.7373e-13, -8.1407e-04, -7.4476e-02,
-         4.5637e-02,  8.5088e-12, -1.7901e-01, -4.7733e-01, -8.4878e-17,
-        -6.3084e-14, -6.2053e-15, -2.2577e-01, -1.4352e-17,  1.7256e-13,
-         1.5301e-03, -1.1389e-01, -3.0479e-17, -9.0575e-01, -9.7372e-03,
-         2.4737e-13,  2.3762e-01, -3.1597e-02, -2.8705e-01,  4.4328e-17,
-        -3.7925e-01, -1.1502e+00, -1.1347e-20, -7.3869e-02,  5.6321e-15,
-         3.0389e-13,  1.8787e-16,  0.0000e+00, -3.8378e-02,  6.3319e-01,
-        -1.5407e-01, -2.5094e-14, -2.2702e-01,  4.6237e-17,  1.1309e-14,
-        -3.5466e-02, -1.5957e-12,  5.1489e-06,  5.3679e-12, -2.9436e-18,
-        -9.5240e-09, -9.1298e-02, -2.7637e-01,  1.4938e-01, -1.0787e-09,
-         6.1261e-22, -2.4076e-16,  8.6588e-20,  6.2359e-01, -1.2505e-11,
-         2.3130e-19,  2.8711e-02,  1.9433e-07, -1.0001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 7.2794e-02,  0.0000e+00,  1.0193e-01,  1.1351e-01,  2.1397e+00,
-        -3.4792e-01,  0.0000e+00,  0.0000e+00, -8.1407e-04, -7.4476e-02,
-         4.5637e-02,  0.0000e+00, -1.7901e-01, -4.7733e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.2577e-01,  0.0000e+00,  0.0000e+00,
-         1.5301e-03, -1.1389e-01,  0.0000e+00, -9.0575e-01, -9.7372e-03,
-         0.0000e+00,  2.3762e-01, -3.1597e-02, -2.8705e-01,  0.0000e+00,
-        -3.7925e-01, -1.1502e+00,  0.0000e+00, -7.3869e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -3.8378e-02,  6.3319e-01,
-        -1.5407e-01,  0.0000e+00, -2.2702e-01,  0.0000e+00,  0.0000e+00,
-        -3.5466e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -9.1298e-02, -2.7637e-01,  1.4938e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.2359e-01,  0.0000e+00,
-         0.0000e+00,  2.8711e-02,  0.0000e+00, -1.0001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 7.2794e-02,  0.0000e+00,  1.0193e-01,  1.1351e-01,  2.1397e+00,
-        -3.4792e-01,  0.0000e+00,  0.0000e+00, -8.1407e-04, -7.4476e-02,
-         4.5637e-02,  0.0000e+00, -1.7901e-01, -4.7733e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.2577e-01,  0.0000e+00,  0.0000e+00,
-         1.5301e-03, -1.1389e-01,  0.0000e+00, -9.0575e-01, -9.7372e-03,
-         0.0000e+00,  2.3762e-01, -3.1597e-02, -2.8705e-01,  0.0000e+00,
-        -3.7925e-01, -1.1502e+00,  0.0000e+00, -7.3869e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -3.8378e-02,  6.3319e-01,
-        -1.5407e-01,  0.0000e+00, -2.2702e-01,  0.0000e+00,  0.0000e+00,
-        -3.5466e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -9.1298e-02, -2.7637e-01,  1.4938e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.2359e-01,  0.0000e+00,
-         0.0000e+00,  2.8711e-02,  0.0000e+00, -1.0001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.1795e-02, -7.1548e-11,  9.3781e-02,  1.0702e-01,  2.1384e+00,
-        -3.3166e-01,  4.3309e-20, -3.4197e-13, -1.8115e-03, -5.1305e-02,
-         5.6134e-02,  7.7858e-12, -1.8903e-01, -4.7747e-01, -7.7665e-17,
-        -5.7724e-14, -5.6780e-15, -2.1564e-01, -1.3133e-17,  1.5790e-13,
-        -9.9353e-03, -1.0878e-01, -2.7889e-17, -9.0229e-01,  1.3225e-02,
-         2.2635e-13,  2.1973e-01, -3.1266e-02, -2.8252e-01,  4.0561e-17,
-        -3.6999e-01, -1.1511e+00, -1.0383e-20, -8.8463e-02,  5.1535e-15,
-         2.7806e-13,  1.7191e-16,  0.0000e+00, -4.2173e-02,  6.3065e-01,
-        -1.6614e-01, -2.2961e-14, -2.3213e-01,  4.2308e-17,  1.0348e-14,
-        -2.6745e-02, -1.4601e-12,  4.7114e-06,  4.9118e-12, -2.6934e-18,
-        -8.7148e-09, -8.4654e-02, -2.9846e-01,  1.5209e-01, -9.8702e-10,
-         5.6055e-22, -2.2031e-16,  7.9230e-20,  6.1289e-01, -1.1443e-11,
-         2.1164e-19,  4.0104e-03,  1.7782e-07, -8.6307e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.1795e-02,  0.0000e+00,  9.3781e-02,  1.0702e-01,  2.1384e+00,
-        -3.3166e-01,  0.0000e+00,  0.0000e+00, -1.8115e-03, -5.1305e-02,
-         5.6134e-02,  0.0000e+00, -1.8903e-01, -4.7747e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.1564e-01,  0.0000e+00,  0.0000e+00,
-        -9.9353e-03, -1.0878e-01,  0.0000e+00, -9.0229e-01,  1.3225e-02,
-         0.0000e+00,  2.1973e-01, -3.1266e-02, -2.8252e-01,  0.0000e+00,
-        -3.6999e-01, -1.1511e+00,  0.0000e+00, -8.8463e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -4.2173e-02,  6.3065e-01,
-        -1.6614e-01,  0.0000e+00, -2.3213e-01,  0.0000e+00,  0.0000e+00,
-        -2.6745e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -8.4654e-02, -2.9846e-01,  1.5209e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.1289e-01,  0.0000e+00,
-         0.0000e+00,  4.0104e-03,  0.0000e+00, -8.6307e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.1795e-02,  0.0000e+00,  9.3781e-02,  1.0702e-01,  2.1384e+00,
-        -3.3166e-01,  0.0000e+00,  0.0000e+00, -1.8115e-03, -5.1305e-02,
-         5.6134e-02,  0.0000e+00, -1.8903e-01, -4.7747e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.1564e-01,  0.0000e+00,  0.0000e+00,
-        -9.9353e-03, -1.0878e-01,  0.0000e+00, -9.0229e-01,  1.3225e-02,
-         0.0000e+00,  2.1973e-01, -3.1266e-02, -2.8252e-01,  0.0000e+00,
-        -3.6999e-01, -1.1511e+00,  0.0000e+00, -8.8463e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -4.2173e-02,  6.3065e-01,
-        -1.6614e-01,  0.0000e+00, -2.3213e-01,  0.0000e+00,  0.0000e+00,
-        -2.6745e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -8.4654e-02, -2.9846e-01,  1.5209e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.1289e-01,  0.0000e+00,
-         0.0000e+00,  4.0104e-03,  0.0000e+00, -8.6307e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 7.6617e-02, -6.5478e-11,  8.1871e-02,  7.8570e-02,  2.1369e+00,
-        -3.1530e-01,  3.9635e-20, -3.1297e-13, -4.6049e-03, -4.0404e-02,
-         8.0625e-02,  7.1253e-12, -2.0550e-01, -4.7474e-01, -7.1077e-17,
-        -5.2827e-14, -5.1964e-15, -2.2200e-01, -1.2019e-17,  1.4451e-13,
-        -2.5884e-02, -1.0994e-01, -2.5524e-17, -8.9622e-01,  3.2426e-02,
-         2.0715e-13,  1.8764e-01, -3.6660e-02, -2.6669e-01,  3.7120e-17,
-        -3.6852e-01, -1.1529e+00, -9.5021e-21, -1.0279e-01,  4.7164e-15,
-         2.5448e-13,  1.5733e-16,  0.0000e+00, -4.3549e-02,  6.2225e-01,
-        -1.7442e-01, -2.1014e-14, -2.3692e-01,  3.8720e-17,  9.4704e-15,
-        -2.2162e-02, -1.3363e-12,  4.3117e-06,  4.4951e-12, -2.4650e-18,
-        -7.9755e-09, -7.7465e-02, -3.0937e-01,  9.6399e-02, -9.0329e-10,
-         5.1300e-22, -2.0162e-16,  7.2509e-20,  6.0223e-01, -1.0472e-11,
-         1.9369e-19, -3.0259e-02,  1.6274e-07, -7.3742e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0766,  0.0000,  0.0819,  0.0786,  2.1369, -0.3153,  0.0000,  0.0000,
-        -0.0046, -0.0404,  0.0806,  0.0000, -0.2055, -0.4747,  0.0000,  0.0000,
-         0.0000, -0.2220,  0.0000,  0.0000, -0.0259, -0.1099,  0.0000, -0.8962,
-         0.0324,  0.0000,  0.1876, -0.0367, -0.2667,  0.0000, -0.3685, -1.1529,
-         0.0000, -0.1028,  0.0000,  0.0000,  0.0000,  0.0000, -0.0435,  0.6223,
-        -0.1744,  0.0000, -0.2369,  0.0000,  0.0000, -0.0222,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0775, -0.3094,  0.0964,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6022,  0.0000,  0.0000, -0.0303,  0.0000, -0.0737],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0766,  0.0000,  0.0819,  0.0786,  2.1369, -0.3153,  0.0000,  0.0000,
-        -0.0046, -0.0404,  0.0806,  0.0000, -0.2055, -0.4747,  0.0000,  0.0000,
-         0.0000, -0.2220,  0.0000,  0.0000, -0.0259, -0.1099,  0.0000, -0.8962,
-         0.0324,  0.0000,  0.1876, -0.0367, -0.2667,  0.0000, -0.3685, -1.1529,
-         0.0000, -0.1028,  0.0000,  0.0000,  0.0000,  0.0000, -0.0435,  0.6223,
-        -0.1744,  0.0000, -0.2369,  0.0000,  0.0000, -0.0222,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0775, -0.3094,  0.0964,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6022,  0.0000,  0.0000, -0.0303,  0.0000, -0.0737],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.6938e-02, -5.9933e-11,  6.1463e-02,  4.6548e-02,  2.1357e+00,
-        -3.0125e-01,  3.6279e-20, -2.8646e-13, -9.6954e-03, -3.8910e-02,
-         8.9126e-02,  6.5219e-12, -2.2150e-01, -4.6246e-01, -6.5058e-17,
-        -4.8353e-14, -4.7563e-15, -2.3334e-01, -1.1001e-17,  1.3227e-13,
-        -3.8196e-02, -9.9417e-02, -2.3362e-17, -8.8880e-01,  4.9475e-02,
-         1.8961e-13,  1.4863e-01, -3.7596e-02, -2.5014e-01,  3.3977e-17,
-        -3.6574e-01, -1.1539e+00, -8.6973e-21, -1.2601e-01,  4.3169e-15,
-         2.3293e-13,  1.4400e-16,  0.0000e+00, -3.4922e-02,  6.1636e-01,
-        -1.6892e-01, -1.9234e-14, -2.4364e-01,  3.5440e-17,  8.6684e-15,
-        -2.1936e-02, -1.2231e-12,  3.9466e-06,  4.1145e-12, -2.2562e-18,
-        -7.3001e-09, -7.0178e-02, -3.1832e-01,  1.7714e-02, -8.2679e-10,
-         4.6956e-22, -1.8454e-16,  6.6369e-20,  5.9207e-01, -9.5851e-12,
-         1.7729e-19, -7.1038e-02,  1.4895e-07, -6.7960e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0669,  0.0000,  0.0615,  0.0465,  2.1357, -0.3013,  0.0000,  0.0000,
-        -0.0097, -0.0389,  0.0891,  0.0000, -0.2215, -0.4625,  0.0000,  0.0000,
-         0.0000, -0.2333,  0.0000,  0.0000, -0.0382, -0.0994,  0.0000, -0.8888,
-         0.0495,  0.0000,  0.1486, -0.0376, -0.2501,  0.0000, -0.3657, -1.1539,
-         0.0000, -0.1260,  0.0000,  0.0000,  0.0000,  0.0000, -0.0349,  0.6164,
-        -0.1689,  0.0000, -0.2436,  0.0000,  0.0000, -0.0219,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0702, -0.3183,  0.0177,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5921,  0.0000,  0.0000, -0.0710,  0.0000, -0.0680],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0669,  0.0000,  0.0615,  0.0465,  2.1357, -0.3013,  0.0000,  0.0000,
-        -0.0097, -0.0389,  0.0891,  0.0000, -0.2215, -0.4625,  0.0000,  0.0000,
-         0.0000, -0.2333,  0.0000,  0.0000, -0.0382, -0.0994,  0.0000, -0.8888,
-         0.0495,  0.0000,  0.1486, -0.0376, -0.2501,  0.0000, -0.3657, -1.1539,
-         0.0000, -0.1260,  0.0000,  0.0000,  0.0000,  0.0000, -0.0349,  0.6164,
-        -0.1689,  0.0000, -0.2436,  0.0000,  0.0000, -0.0219,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0702, -0.3183,  0.0177,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5921,  0.0000,  0.0000, -0.0710,  0.0000, -0.0680],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.3507e-02, -5.4865e-11,  4.3173e-02,  5.5495e-03,  2.1346e+00,
-        -2.9161e-01,  3.3211e-20, -2.6224e-13, -1.3052e-02, -4.9939e-02,
-         9.6492e-02,  5.9704e-12, -2.2917e-01, -4.4722e-01, -5.9557e-17,
-        -4.4265e-14, -4.3541e-15, -2.5368e-01, -1.0071e-17,  1.2108e-13,
-        -4.0178e-02, -8.7360e-02, -2.1387e-17, -8.7857e-01,  5.9436e-02,
-         1.7358e-13,  9.6969e-02, -2.9862e-02, -2.2555e-01,  3.1104e-17,
-        -3.7213e-01, -1.1565e+00, -7.9619e-21, -1.5184e-01,  3.9519e-15,
-         2.1323e-13,  1.3183e-16,  0.0000e+00, -2.0449e-02,  6.1691e-01,
-        -1.4920e-01, -1.7608e-14, -2.4758e-01,  3.2444e-17,  7.9354e-15,
-        -3.9224e-02, -1.1197e-12,  3.6129e-06,  3.7666e-12, -2.0654e-18,
-        -6.6828e-09, -5.0729e-02, -3.1867e-01, -7.4174e-02, -7.5688e-10,
-         4.2985e-22, -1.6894e-16,  6.0757e-20,  5.7622e-01, -8.7746e-12,
-         1.6230e-19, -1.1159e-01,  1.3636e-07, -6.9165e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0635,  0.0000,  0.0432,  0.0055,  2.1346, -0.2916,  0.0000,  0.0000,
-        -0.0131, -0.0499,  0.0965,  0.0000, -0.2292, -0.4472,  0.0000,  0.0000,
-         0.0000, -0.2537,  0.0000,  0.0000, -0.0402, -0.0874,  0.0000, -0.8786,
-         0.0594,  0.0000,  0.0970, -0.0299, -0.2256,  0.0000, -0.3721, -1.1565,
-         0.0000, -0.1518,  0.0000,  0.0000,  0.0000,  0.0000, -0.0204,  0.6169,
-        -0.1492,  0.0000, -0.2476,  0.0000,  0.0000, -0.0392,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0507, -0.3187, -0.0742,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5762,  0.0000,  0.0000, -0.1116,  0.0000, -0.0692],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0635,  0.0000,  0.0432,  0.0055,  2.1346, -0.2916,  0.0000,  0.0000,
-        -0.0131, -0.0499,  0.0965,  0.0000, -0.2292, -0.4472,  0.0000,  0.0000,
-         0.0000, -0.2537,  0.0000,  0.0000, -0.0402, -0.0874,  0.0000, -0.8786,
-         0.0594,  0.0000,  0.0970, -0.0299, -0.2256,  0.0000, -0.3721, -1.1565,
-         0.0000, -0.1518,  0.0000,  0.0000,  0.0000,  0.0000, -0.0204,  0.6169,
-        -0.1492,  0.0000, -0.2476,  0.0000,  0.0000, -0.0392,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0507, -0.3187, -0.0742,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5762,  0.0000,  0.0000, -0.1116,  0.0000, -0.0692],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.2678e-02, -5.0233e-11,  2.6497e-02, -2.3576e-02,  2.1333e+00,
-        -2.8006e-01,  3.0407e-20, -2.4010e-13, -1.9638e-02, -4.4359e-02,
-         1.0088e-01,  5.4663e-12, -2.3938e-01, -4.3522e-01, -5.4528e-17,
-        -4.0528e-14, -3.9865e-15, -2.6727e-01, -9.2205e-18,  1.1086e-13,
-        -4.2476e-02, -7.5340e-02, -1.9581e-17, -8.7837e-01,  6.6038e-02,
-         1.5892e-13,  5.5152e-02, -2.0502e-02, -1.9720e-01,  2.8478e-17,
-        -3.7086e-01, -1.1595e+00, -7.2897e-21, -1.8150e-01,  3.6183e-15,
-         1.9523e-13,  1.2070e-16,  0.0000e+00,  2.6202e-03,  6.2035e-01,
-        -1.2951e-01, -1.6121e-14, -2.4932e-01,  2.9705e-17,  7.2654e-15,
-        -5.1791e-02, -1.0251e-12,  3.3078e-06,  3.4485e-12, -1.8911e-18,
-        -6.1186e-09, -4.7629e-02, -3.1592e-01, -1.3980e-01, -6.9298e-10,
-         3.9356e-22, -1.5468e-16,  5.5627e-20,  5.6041e-01, -8.0338e-12,
-         1.4859e-19, -1.4523e-01,  1.2485e-07, -7.1836e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0627,  0.0000,  0.0265, -0.0236,  2.1333, -0.2801,  0.0000,  0.0000,
-        -0.0196, -0.0444,  0.1009,  0.0000, -0.2394, -0.4352,  0.0000,  0.0000,
-         0.0000, -0.2673,  0.0000,  0.0000, -0.0425, -0.0753,  0.0000, -0.8784,
-         0.0660,  0.0000,  0.0552, -0.0205, -0.1972,  0.0000, -0.3709, -1.1595,
-         0.0000, -0.1815,  0.0000,  0.0000,  0.0000,  0.0000,  0.0026,  0.6204,
-        -0.1295,  0.0000, -0.2493,  0.0000,  0.0000, -0.0518,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0476, -0.3159, -0.1398,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5604,  0.0000,  0.0000, -0.1452,  0.0000, -0.0718],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0627,  0.0000,  0.0265, -0.0236,  2.1333, -0.2801,  0.0000,  0.0000,
-        -0.0196, -0.0444,  0.1009,  0.0000, -0.2394, -0.4352,  0.0000,  0.0000,
-         0.0000, -0.2673,  0.0000,  0.0000, -0.0425, -0.0753,  0.0000, -0.8784,
-         0.0660,  0.0000,  0.0552, -0.0205, -0.1972,  0.0000, -0.3709, -1.1595,
-         0.0000, -0.1815,  0.0000,  0.0000,  0.0000,  0.0000,  0.0026,  0.6204,
-        -0.1295,  0.0000, -0.2493,  0.0000,  0.0000, -0.0518,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0476, -0.3159, -0.1398,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5604,  0.0000,  0.0000, -0.1452,  0.0000, -0.0718],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 7.2268e-02, -4.5998e-11,  1.0641e-02, -3.4315e-02,  2.1319e+00,
-        -2.6252e-01,  2.7843e-20, -2.1986e-13, -1.8884e-02, -3.3902e-02,
-         8.7996e-02,  5.0055e-12, -2.4638e-01, -4.2388e-01, -4.9931e-17,
-        -3.7111e-14, -3.6504e-15, -2.6899e-01, -8.4431e-18,  1.0151e-13,
-        -4.1280e-02, -5.0595e-02, -1.7930e-17, -8.7522e-01,  8.1586e-02,
-         1.4552e-13,  2.3555e-02, -1.3291e-03, -1.6956e-01,  2.6077e-17,
-        -3.6784e-01, -1.1636e+00, -6.6751e-21, -2.0497e-01,  3.3132e-15,
-         1.7877e-13,  1.1052e-16,  0.0000e+00,  2.0155e-02,  6.2339e-01,
-        -1.1199e-01, -1.4762e-14, -2.5498e-01,  2.7200e-17,  6.6529e-15,
-        -6.0775e-02, -9.3871e-13,  3.0289e-06,  3.1578e-12, -1.7316e-18,
-        -5.6027e-09, -3.4230e-02, -3.2406e-01, -1.6817e-01, -6.3456e-10,
-         3.6038e-22, -1.4164e-16,  5.0937e-20,  5.4433e-01, -7.3564e-12,
-         1.3607e-19, -1.6799e-01,  1.1432e-07, -7.8988e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 7.2268e-02,  0.0000e+00,  1.0641e-02, -3.4315e-02,  2.1319e+00,
-        -2.6252e-01,  0.0000e+00,  0.0000e+00, -1.8884e-02, -3.3902e-02,
-         8.7996e-02,  0.0000e+00, -2.4638e-01, -4.2388e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6899e-01,  0.0000e+00,  0.0000e+00,
-        -4.1280e-02, -5.0595e-02,  0.0000e+00, -8.7522e-01,  8.1586e-02,
-         0.0000e+00,  2.3555e-02, -1.3291e-03, -1.6956e-01,  0.0000e+00,
-        -3.6784e-01, -1.1636e+00,  0.0000e+00, -2.0497e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0155e-02,  6.2339e-01,
-        -1.1199e-01,  0.0000e+00, -2.5498e-01,  0.0000e+00,  0.0000e+00,
-        -6.0775e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -3.4230e-02, -3.2406e-01, -1.6817e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4433e-01,  0.0000e+00,
-         0.0000e+00, -1.6799e-01,  0.0000e+00, -7.8988e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 7.2268e-02,  0.0000e+00,  1.0641e-02, -3.4315e-02,  2.1319e+00,
-        -2.6252e-01,  0.0000e+00,  0.0000e+00, -1.8884e-02, -3.3902e-02,
-         8.7996e-02,  0.0000e+00, -2.4638e-01, -4.2388e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6899e-01,  0.0000e+00,  0.0000e+00,
-        -4.1280e-02, -5.0595e-02,  0.0000e+00, -8.7522e-01,  8.1586e-02,
-         0.0000e+00,  2.3555e-02, -1.3291e-03, -1.6956e-01,  0.0000e+00,
-        -3.6784e-01, -1.1636e+00,  0.0000e+00, -2.0497e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0155e-02,  6.2339e-01,
-        -1.1199e-01,  0.0000e+00, -2.5498e-01,  0.0000e+00,  0.0000e+00,
-        -6.0775e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -3.4230e-02, -3.2406e-01, -1.6817e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4433e-01,  0.0000e+00,
-         0.0000e+00, -1.6799e-01,  0.0000e+00, -7.8988e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.5388e-02, -4.2125e-11,  9.5886e-04, -3.9037e-02,  2.1305e+00,
-        -2.4920e-01,  2.5499e-20, -2.0135e-13, -1.6963e-02, -5.2321e-03,
-         8.6718e-02,  4.5841e-12, -2.5206e-01, -4.1977e-01, -4.5727e-17,
-        -3.3986e-14, -3.3431e-15, -2.6092e-01, -7.7323e-18,  9.2968e-14,
-        -4.5885e-02, -2.1350e-02, -1.6420e-17, -8.7672e-01,  9.1119e-02,
-         1.3327e-13,  3.8002e-03,  7.9093e-03, -1.5021e-01,  2.3881e-17,
-        -3.6437e-01, -1.1689e+00, -6.1131e-21, -2.2313e-01,  3.0343e-15,
-         1.6372e-13,  1.0122e-16,  0.0000e+00,  5.7679e-02,  6.2209e-01,
-        -1.0363e-01, -1.3519e-14, -2.5795e-01,  2.4910e-17,  6.0928e-15,
-        -7.2189e-02, -8.5967e-13,  2.7739e-06,  2.8919e-12, -1.5858e-18,
-        -5.1310e-09, -2.7932e-02, -3.2513e-01, -1.7361e-01, -5.8113e-10,
-         3.3004e-22, -1.2971e-16,  4.6649e-20,  5.2857e-01, -6.7371e-12,
-         1.2461e-19, -1.8695e-01,  1.0470e-07, -7.2964e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.5388e-02,  0.0000e+00,  9.5886e-04, -3.9037e-02,  2.1305e+00,
-        -2.4920e-01,  0.0000e+00,  0.0000e+00, -1.6963e-02, -5.2321e-03,
-         8.6718e-02,  0.0000e+00, -2.5206e-01, -4.1977e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6092e-01,  0.0000e+00,  0.0000e+00,
-        -4.5885e-02, -2.1350e-02,  0.0000e+00, -8.7672e-01,  9.1119e-02,
-         0.0000e+00,  3.8002e-03,  7.9093e-03, -1.5021e-01,  0.0000e+00,
-        -3.6437e-01, -1.1689e+00,  0.0000e+00, -2.2313e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.7679e-02,  6.2209e-01,
-        -1.0363e-01,  0.0000e+00, -2.5795e-01,  0.0000e+00,  0.0000e+00,
-        -7.2189e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -2.7932e-02, -3.2513e-01, -1.7361e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.2857e-01,  0.0000e+00,
-         0.0000e+00, -1.8695e-01,  0.0000e+00, -7.2964e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.5388e-02,  0.0000e+00,  9.5886e-04, -3.9037e-02,  2.1305e+00,
-        -2.4920e-01,  0.0000e+00,  0.0000e+00, -1.6963e-02, -5.2321e-03,
-         8.6718e-02,  0.0000e+00, -2.5206e-01, -4.1977e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6092e-01,  0.0000e+00,  0.0000e+00,
-        -4.5885e-02, -2.1350e-02,  0.0000e+00, -8.7672e-01,  9.1119e-02,
-         0.0000e+00,  3.8002e-03,  7.9093e-03, -1.5021e-01,  0.0000e+00,
-        -3.6437e-01, -1.1689e+00,  0.0000e+00, -2.2313e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.7679e-02,  6.2209e-01,
-        -1.0363e-01,  0.0000e+00, -2.5795e-01,  0.0000e+00,  0.0000e+00,
-        -7.2189e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -2.7932e-02, -3.2513e-01, -1.7361e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.2857e-01,  0.0000e+00,
-         0.0000e+00, -1.8695e-01,  0.0000e+00, -7.2964e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0531e-01, -3.8583e-11, -5.4389e-03, -3.6414e-02,  2.1289e+00,
-        -2.3607e-01,  2.3355e-20, -1.8441e-13, -1.2856e-02,  1.9750e-02,
-         1.0177e-01,  4.1986e-12, -2.5550e-01, -4.1275e-01, -4.1882e-17,
-        -3.1128e-14, -3.0620e-15, -2.5575e-01, -7.0821e-18,  8.5150e-14,
-        -3.9874e-02,  3.1199e-03, -1.5040e-17, -8.7837e-01,  9.9017e-02,
-         1.2206e-13, -1.7523e-02,  2.3170e-02, -1.2691e-01,  2.1873e-17,
-        -3.7357e-01, -1.1745e+00, -5.5991e-21, -2.3933e-01,  2.7791e-15,
-         1.4995e-13,  9.2705e-17,  0.0000e+00,  1.0245e-01,  6.2703e-01,
-        -8.1577e-02, -1.2382e-14, -2.6015e-01,  2.2816e-17,  5.5804e-15,
-        -8.9105e-02, -7.8739e-13,  2.5407e-06,  2.6488e-12, -1.4525e-18,
-        -4.6996e-09, -1.4716e-02, -3.2407e-01, -1.8132e-01, -5.3227e-10,
-         3.0229e-22, -1.1880e-16,  4.2726e-20,  5.0799e-01, -6.1706e-12,
-         1.1413e-19, -1.9706e-01,  9.5893e-08, -6.1009e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1053,  0.0000, -0.0054, -0.0364,  2.1289, -0.2361,  0.0000,  0.0000,
-        -0.0129,  0.0198,  0.1018,  0.0000, -0.2555, -0.4127,  0.0000,  0.0000,
-         0.0000, -0.2558,  0.0000,  0.0000, -0.0399,  0.0031,  0.0000, -0.8784,
-         0.0990,  0.0000, -0.0175,  0.0232, -0.1269,  0.0000, -0.3736, -1.1745,
-         0.0000, -0.2393,  0.0000,  0.0000,  0.0000,  0.0000,  0.1025,  0.6270,
-        -0.0816,  0.0000, -0.2601,  0.0000,  0.0000, -0.0891,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0147, -0.3241, -0.1813,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5080,  0.0000,  0.0000, -0.1971,  0.0000, -0.0610],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1053,  0.0000, -0.0054, -0.0364,  2.1289, -0.2361,  0.0000,  0.0000,
-        -0.0129,  0.0198,  0.1018,  0.0000, -0.2555, -0.4127,  0.0000,  0.0000,
-         0.0000, -0.2558,  0.0000,  0.0000, -0.0399,  0.0031,  0.0000, -0.8784,
-         0.0990,  0.0000, -0.0175,  0.0232, -0.1269,  0.0000, -0.3736, -1.1745,
-         0.0000, -0.2393,  0.0000,  0.0000,  0.0000,  0.0000,  0.1025,  0.6270,
-        -0.0816,  0.0000, -0.2601,  0.0000,  0.0000, -0.0891,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0147, -0.3241, -0.1813,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5080,  0.0000,  0.0000, -0.1971,  0.0000, -0.0610],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.2991e-01, -3.5343e-11, -1.3292e-02, -1.6000e-02,  2.1276e+00,
-        -2.2168e-01,  2.1394e-20, -1.6893e-13,  2.9234e-03,  6.2336e-02,
-         1.1895e-01,  3.8460e-12, -2.6316e-01, -4.1309e-01, -3.8365e-17,
-        -2.8514e-14, -2.8048e-15, -2.3212e-01, -6.4873e-18,  7.7999e-14,
-        -2.5214e-02,  2.9430e-02, -1.3777e-17, -8.7864e-01,  1.1061e-01,
-         1.1181e-13, -3.4542e-02,  5.3287e-02, -1.0590e-01,  2.0036e-17,
-        -3.8517e-01, -1.1800e+00, -5.1288e-21, -2.4183e-01,  2.5457e-15,
-         1.3736e-13,  8.4919e-17,  0.0000e+00,  1.3589e-01,  6.2837e-01,
-        -6.4070e-02, -1.1342e-14, -2.6718e-01,  2.0899e-17,  5.1118e-15,
-        -1.2432e-01, -7.2126e-13,  2.3273e-06,  2.4263e-12, -1.3305e-18,
-        -4.3049e-09,  1.3193e-02, -3.3216e-01, -1.4112e-01, -4.8756e-10,
-         2.7690e-22, -1.0883e-16,  3.9138e-20,  4.9321e-01, -5.6523e-12,
-         1.0455e-19, -1.9811e-01,  8.7839e-08, -5.0117e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1299,  0.0000, -0.0133, -0.0160,  2.1276, -0.2217,  0.0000,  0.0000,
-         0.0029,  0.0623,  0.1189,  0.0000, -0.2632, -0.4131,  0.0000,  0.0000,
-         0.0000, -0.2321,  0.0000,  0.0000, -0.0252,  0.0294,  0.0000, -0.8786,
-         0.1106,  0.0000, -0.0345,  0.0533, -0.1059,  0.0000, -0.3852, -1.1800,
-         0.0000, -0.2418,  0.0000,  0.0000,  0.0000,  0.0000,  0.1359,  0.6284,
-        -0.0641,  0.0000, -0.2672,  0.0000,  0.0000, -0.1243,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0132, -0.3322, -0.1411,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4932,  0.0000,  0.0000, -0.1981,  0.0000, -0.0501],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1299,  0.0000, -0.0133, -0.0160,  2.1276, -0.2217,  0.0000,  0.0000,
-         0.0029,  0.0623,  0.1189,  0.0000, -0.2632, -0.4131,  0.0000,  0.0000,
-         0.0000, -0.2321,  0.0000,  0.0000, -0.0252,  0.0294,  0.0000, -0.8786,
-         0.1106,  0.0000, -0.0345,  0.0533, -0.1059,  0.0000, -0.3852, -1.1800,
-         0.0000, -0.2418,  0.0000,  0.0000,  0.0000,  0.0000,  0.1359,  0.6284,
-        -0.0641,  0.0000, -0.2672,  0.0000,  0.0000, -0.1243,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0132, -0.3322, -0.1411,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4932,  0.0000,  0.0000, -0.1981,  0.0000, -0.0501],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6119e-01, -3.2378e-11, -1.3247e-02,  2.1350e-02,  2.1265e+00,
-        -2.0430e-01,  1.9599e-20, -1.5476e-13,  2.0546e-02,  1.1199e-01,
-         1.4149e-01,  3.5233e-12, -2.6821e-01, -4.1558e-01, -3.5146e-17,
-        -2.6122e-14, -2.5695e-15, -2.0677e-01, -5.9431e-18,  7.1456e-14,
-        -4.1493e-03,  5.0063e-02, -1.2621e-17, -8.8213e-01,  1.1772e-01,
-         1.0243e-13, -4.2162e-02,  8.7606e-02, -8.4986e-02,  1.8355e-17,
-        -3.9485e-01, -1.1855e+00, -4.6986e-21, -2.4327e-01,  2.3322e-15,
-         1.2583e-13,  7.7795e-17,  0.0000e+00,  1.5602e-01,  6.2846e-01,
-        -4.5897e-02, -1.0391e-14, -2.6596e-01,  1.9146e-17,  4.6829e-15,
-        -1.5828e-01, -6.6075e-13,  2.1321e-06,  2.2228e-12, -1.2189e-18,
-        -3.9437e-09,  3.0760e-02, -3.4051e-01, -8.3875e-02, -4.4666e-10,
-         2.5367e-22, -9.9697e-17,  3.5855e-20,  4.8226e-01, -5.1782e-12,
-         9.5776e-20, -1.9304e-01,  8.0470e-08, -3.1624e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1612,  0.0000, -0.0132,  0.0214,  2.1265, -0.2043,  0.0000,  0.0000,
-         0.0205,  0.1120,  0.1415,  0.0000, -0.2682, -0.4156,  0.0000,  0.0000,
-         0.0000, -0.2068,  0.0000,  0.0000, -0.0041,  0.0501,  0.0000, -0.8821,
-         0.1177,  0.0000, -0.0422,  0.0876, -0.0850,  0.0000, -0.3948, -1.1855,
-         0.0000, -0.2433,  0.0000,  0.0000,  0.0000,  0.0000,  0.1560,  0.6285,
-        -0.0459,  0.0000, -0.2660,  0.0000,  0.0000, -0.1583,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0308, -0.3405, -0.0839,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4823,  0.0000,  0.0000, -0.1930,  0.0000, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1612,  0.0000, -0.0132,  0.0214,  2.1265, -0.2043,  0.0000,  0.0000,
-         0.0205,  0.1120,  0.1415,  0.0000, -0.2682, -0.4156,  0.0000,  0.0000,
-         0.0000, -0.2068,  0.0000,  0.0000, -0.0041,  0.0501,  0.0000, -0.8821,
-         0.1177,  0.0000, -0.0422,  0.0876, -0.0850,  0.0000, -0.3948, -1.1855,
-         0.0000, -0.2433,  0.0000,  0.0000,  0.0000,  0.0000,  0.1560,  0.6285,
-        -0.0459,  0.0000, -0.2660,  0.0000,  0.0000, -0.1583,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0308, -0.3405, -0.0839,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4823,  0.0000,  0.0000, -0.1930,  0.0000, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.8729e-01, -2.9665e-11, -1.4310e-02,  4.7926e-02,  2.1254e+00,
-        -1.8388e-01,  1.7956e-20, -1.4179e-13,  3.2952e-02,  1.5043e-01,
-         1.6229e-01,  3.2281e-12, -2.7787e-01, -4.2311e-01, -3.2201e-17,
-        -2.3933e-14, -2.3542e-15, -1.8170e-01, -5.4450e-18,  6.5468e-14,
-         1.6194e-02,  6.7957e-02, -1.1563e-17, -8.8113e-01,  1.2352e-01,
-         9.3849e-14, -4.9049e-02,  1.1706e-01, -6.4289e-02,  1.6817e-17,
-        -4.0237e-01, -1.1901e+00, -4.3048e-21, -2.3653e-01,  2.1367e-15,
-         1.1529e-13,  7.1276e-17,  0.0000e+00,  1.6296e-01,  6.2234e-01,
-        -3.7369e-02, -9.5201e-15, -2.6667e-01,  1.7542e-17,  4.2905e-15,
-        -1.8154e-01, -6.0538e-13,  1.9534e-06,  2.0365e-12, -1.1167e-18,
-        -3.6133e-09,  4.5636e-02, -3.4729e-01, -3.3345e-02, -4.0923e-10,
-         2.3241e-22, -9.1342e-17,  3.2850e-20,  4.7249e-01, -4.7442e-12,
-         8.7750e-20, -1.8797e-01,  7.3727e-08, -1.3423e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1873,  0.0000, -0.0143,  0.0479,  2.1254, -0.1839,  0.0000,  0.0000,
-         0.0330,  0.1504,  0.1623,  0.0000, -0.2779, -0.4231,  0.0000,  0.0000,
-         0.0000, -0.1817,  0.0000,  0.0000,  0.0162,  0.0680,  0.0000, -0.8811,
-         0.1235,  0.0000, -0.0490,  0.1171, -0.0643,  0.0000, -0.4024, -1.1901,
-         0.0000, -0.2365,  0.0000,  0.0000,  0.0000,  0.0000,  0.1630,  0.6223,
-        -0.0374,  0.0000, -0.2667,  0.0000,  0.0000, -0.1815,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0456, -0.3473, -0.0333,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4725,  0.0000,  0.0000, -0.1880,  0.0000, -0.0134],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1873,  0.0000, -0.0143,  0.0479,  2.1254, -0.1839,  0.0000,  0.0000,
-         0.0330,  0.1504,  0.1623,  0.0000, -0.2779, -0.4231,  0.0000,  0.0000,
-         0.0000, -0.1817,  0.0000,  0.0000,  0.0162,  0.0680,  0.0000, -0.8811,
-         0.1235,  0.0000, -0.0490,  0.1171, -0.0643,  0.0000, -0.4024, -1.1901,
-         0.0000, -0.2365,  0.0000,  0.0000,  0.0000,  0.0000,  0.1630,  0.6223,
-        -0.0374,  0.0000, -0.2667,  0.0000,  0.0000, -0.1815,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0456, -0.3473, -0.0333,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4725,  0.0000,  0.0000, -0.1880,  0.0000, -0.0134],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1064e-01, -2.7181e-11, -1.7090e-02,  6.1898e-02,  2.1245e+00,
-        -1.6672e-01,  1.6453e-20, -1.2992e-13,  4.2495e-02,  1.7704e-01,
-         1.7838e-01,  2.9578e-12, -2.8816e-01, -4.2982e-01, -2.9505e-17,
-        -2.1929e-14, -2.1571e-15, -1.6110e-01, -4.9892e-18,  5.9987e-14,
-         3.4422e-02,  7.5510e-02, -1.0595e-17, -8.8204e-01,  1.1236e-01,
-         8.5992e-14, -4.7655e-02,  1.3437e-01, -5.2759e-02,  1.5409e-17,
-        -4.1108e-01, -1.1933e+00, -3.9444e-21, -2.2899e-01,  1.9578e-15,
-         1.0564e-13,  6.5309e-17,  0.0000e+00,  1.6950e-01,  6.1554e-01,
-        -3.5682e-02, -8.7231e-15, -2.6982e-01,  1.6073e-17,  3.9313e-15,
-        -2.0172e-01, -5.5470e-13,  1.7899e-06,  1.8660e-12, -1.0232e-18,
-        -3.3108e-09,  5.1494e-02, -3.4543e-01,  8.3509e-04, -3.7497e-10,
-         2.1295e-22, -8.3695e-17,  3.0100e-20,  4.6548e-01, -4.3470e-12,
-         8.0403e-20, -1.7846e-01,  6.7554e-08,  1.7027e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.1064e-01,  0.0000e+00, -1.7090e-02,  6.1898e-02,  2.1245e+00,
-        -1.6672e-01,  0.0000e+00,  0.0000e+00,  4.2495e-02,  1.7704e-01,
-         1.7838e-01,  0.0000e+00, -2.8816e-01, -4.2982e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6110e-01,  0.0000e+00,  0.0000e+00,
-         3.4422e-02,  7.5510e-02,  0.0000e+00, -8.8204e-01,  1.1236e-01,
-         0.0000e+00, -4.7655e-02,  1.3437e-01, -5.2759e-02,  0.0000e+00,
-        -4.1108e-01, -1.1933e+00,  0.0000e+00, -2.2899e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6950e-01,  6.1554e-01,
-        -3.5682e-02,  0.0000e+00, -2.6982e-01,  0.0000e+00,  0.0000e+00,
-        -2.0172e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.1494e-02, -3.4543e-01,  8.3509e-04,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6548e-01,  0.0000e+00,
-         0.0000e+00, -1.7846e-01,  0.0000e+00,  1.7027e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.1064e-01,  0.0000e+00, -1.7090e-02,  6.1898e-02,  2.1245e+00,
-        -1.6672e-01,  0.0000e+00,  0.0000e+00,  4.2495e-02,  1.7704e-01,
-         1.7838e-01,  0.0000e+00, -2.8816e-01, -4.2982e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6110e-01,  0.0000e+00,  0.0000e+00,
-         3.4422e-02,  7.5510e-02,  0.0000e+00, -8.8204e-01,  1.1236e-01,
-         0.0000e+00, -4.7655e-02,  1.3437e-01, -5.2759e-02,  0.0000e+00,
-        -4.1108e-01, -1.1933e+00,  0.0000e+00, -2.2899e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6950e-01,  6.1554e-01,
-        -3.5682e-02,  0.0000e+00, -2.6982e-01,  0.0000e+00,  0.0000e+00,
-        -2.0172e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.1494e-02, -3.4543e-01,  8.3509e-04,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6548e-01,  0.0000e+00,
-         0.0000e+00, -1.7846e-01,  0.0000e+00,  1.7027e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3743e-01, -2.4908e-11, -1.5570e-02,  6.4393e-02,  2.1239e+00,
-        -1.4270e-01,  1.5077e-20, -1.1905e-13,  5.1032e-02,  1.8997e-01,
-         1.9111e-01,  2.7104e-12, -2.9630e-01, -4.3265e-01, -2.7037e-17,
-        -2.0095e-14, -1.9767e-15, -1.5406e-01, -4.5719e-18,  5.4969e-14,
-         5.6851e-02,  7.8984e-02, -9.7090e-18, -8.8359e-01,  1.0068e-01,
-         7.8799e-14, -4.5478e-02,  1.5726e-01, -3.5843e-02,  1.4120e-17,
-        -4.2117e-01, -1.1961e+00, -3.6145e-21, -2.1787e-01,  1.7941e-15,
-         9.6801e-14,  5.9846e-17,  0.0000e+00,  1.6234e-01,  6.0937e-01,
-        -3.0819e-02, -7.9934e-15, -2.6953e-01,  1.4729e-17,  3.6025e-15,
-        -2.1821e-01, -5.0830e-13,  1.6401e-06,  1.7099e-12, -9.3766e-19,
-        -3.0338e-09,  6.5184e-02, -3.4123e-01,  2.1036e-02, -3.4361e-10,
-         1.9514e-22, -7.6694e-17,  2.7582e-20,  4.6362e-01, -3.9834e-12,
-         7.3678e-20, -1.7018e-01,  6.1904e-08,  3.1652e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2374,  0.0000, -0.0156,  0.0644,  2.1239, -0.1427,  0.0000,  0.0000,
-         0.0510,  0.1900,  0.1911,  0.0000, -0.2963, -0.4326,  0.0000,  0.0000,
-         0.0000, -0.1541,  0.0000,  0.0000,  0.0569,  0.0790,  0.0000, -0.8836,
-         0.1007,  0.0000, -0.0455,  0.1573, -0.0358,  0.0000, -0.4212, -1.1961,
-         0.0000, -0.2179,  0.0000,  0.0000,  0.0000,  0.0000,  0.1623,  0.6094,
-        -0.0308,  0.0000, -0.2695,  0.0000,  0.0000, -0.2182,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0652, -0.3412,  0.0210,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4636,  0.0000,  0.0000, -0.1702,  0.0000,  0.0317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2374,  0.0000, -0.0156,  0.0644,  2.1239, -0.1427,  0.0000,  0.0000,
-         0.0510,  0.1900,  0.1911,  0.0000, -0.2963, -0.4326,  0.0000,  0.0000,
-         0.0000, -0.1541,  0.0000,  0.0000,  0.0569,  0.0790,  0.0000, -0.8836,
-         0.1007,  0.0000, -0.0455,  0.1573, -0.0358,  0.0000, -0.4212, -1.1961,
-         0.0000, -0.2179,  0.0000,  0.0000,  0.0000,  0.0000,  0.1623,  0.6094,
-        -0.0308,  0.0000, -0.2695,  0.0000,  0.0000, -0.2182,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0652, -0.3412,  0.0210,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4636,  0.0000,  0.0000, -0.1702,  0.0000,  0.0317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6488e-01, -2.2826e-11, -1.0590e-02,  6.9818e-02,  2.1233e+00,
-        -1.1707e-01,  1.3817e-20, -1.0910e-13,  4.9659e-02,  2.1036e-01,
-         2.1109e-01,  2.4839e-12, -2.9897e-01, -4.3478e-01, -2.4778e-17,
-        -1.8416e-14, -1.8115e-15, -1.4978e-01, -4.1898e-18,  5.0375e-14,
-         7.6163e-02,  7.8269e-02, -8.8975e-18, -8.9004e-01,  8.7098e-02,
-         7.2213e-14, -3.4419e-02,  1.8296e-01, -2.5840e-02,  1.2940e-17,
-        -4.3015e-01, -1.1987e+00, -3.3124e-21, -2.1217e-01,  1.6441e-15,
-         8.8711e-14,  5.4844e-17,  0.0000e+00,  1.4588e-01,  6.0002e-01,
-        -2.3997e-02, -7.3253e-15, -2.6465e-01,  1.3498e-17,  3.3014e-15,
-        -2.1552e-01, -4.6582e-13,  1.5031e-06,  1.5670e-12, -8.5929e-19,
-        -2.7803e-09,  6.3127e-02, -3.2429e-01,  3.7251e-02, -3.1489e-10,
-         1.7883e-22, -7.0284e-17,  2.5277e-20,  4.6588e-01, -3.6505e-12,
-         6.7520e-20, -1.5533e-01,  5.6730e-08,  5.5068e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2649,  0.0000, -0.0106,  0.0698,  2.1233, -0.1171,  0.0000,  0.0000,
-         0.0497,  0.2104,  0.2111,  0.0000, -0.2990, -0.4348,  0.0000,  0.0000,
-         0.0000, -0.1498,  0.0000,  0.0000,  0.0762,  0.0783,  0.0000, -0.8900,
-         0.0871,  0.0000, -0.0344,  0.1830, -0.0258,  0.0000, -0.4302, -1.1987,
-         0.0000, -0.2122,  0.0000,  0.0000,  0.0000,  0.0000,  0.1459,  0.6000,
-        -0.0240,  0.0000, -0.2646,  0.0000,  0.0000, -0.2155,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0631, -0.3243,  0.0373,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4659,  0.0000,  0.0000, -0.1553,  0.0000,  0.0551],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2649,  0.0000, -0.0106,  0.0698,  2.1233, -0.1171,  0.0000,  0.0000,
-         0.0497,  0.2104,  0.2111,  0.0000, -0.2990, -0.4348,  0.0000,  0.0000,
-         0.0000, -0.1498,  0.0000,  0.0000,  0.0762,  0.0783,  0.0000, -0.8900,
-         0.0871,  0.0000, -0.0344,  0.1830, -0.0258,  0.0000, -0.4302, -1.1987,
-         0.0000, -0.2122,  0.0000,  0.0000,  0.0000,  0.0000,  0.1459,  0.6000,
-        -0.0240,  0.0000, -0.2646,  0.0000,  0.0000, -0.2155,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0631, -0.3243,  0.0373,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4659,  0.0000,  0.0000, -0.1553,  0.0000,  0.0551],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8760e-01, -2.0919e-11, -1.8564e-03,  8.1031e-02,  2.1229e+00,
-        -8.9586e-02,  1.2663e-20, -9.9988e-14,  4.0205e-02,  2.2851e-01,
-         2.1765e-01,  2.2764e-12, -2.9845e-01, -4.3250e-01, -2.2708e-17,
-        -1.6878e-14, -1.6602e-15, -1.5500e-01, -3.8398e-18,  4.6168e-14,
-         9.8736e-02,  7.3316e-02, -8.1544e-18, -8.9838e-01,  7.2188e-02,
-         6.6182e-14, -1.6430e-02,  2.0759e-01, -2.0161e-02,  1.1859e-17,
-        -4.3503e-01, -1.2022e+00, -3.0358e-21, -2.0397e-01,  1.5068e-15,
-         8.1302e-14,  5.0264e-17,  0.0000e+00,  1.3713e-01,  5.9293e-01,
-        -2.5686e-02, -6.7135e-15, -2.6156e-01,  1.2370e-17,  3.0257e-15,
-        -2.0058e-01, -4.2691e-13,  1.3775e-06,  1.4361e-12, -7.8752e-19,
-        -2.5481e-09,  5.6708e-02, -3.0576e-01,  6.0895e-02, -2.8859e-10,
-         1.6390e-22, -6.4414e-17,  2.3166e-20,  4.6811e-01, -3.3456e-12,
-         6.1881e-20, -1.3988e-01,  5.1992e-08,  7.3246e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.8760e-01,  0.0000e+00, -1.8564e-03,  8.1031e-02,  2.1229e+00,
-        -8.9586e-02,  0.0000e+00,  0.0000e+00,  4.0205e-02,  2.2851e-01,
-         2.1765e-01,  0.0000e+00, -2.9845e-01, -4.3250e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.5500e-01,  0.0000e+00,  0.0000e+00,
-         9.8736e-02,  7.3316e-02,  0.0000e+00, -8.9838e-01,  7.2188e-02,
-         0.0000e+00, -1.6430e-02,  2.0759e-01, -2.0161e-02,  0.0000e+00,
-        -4.3503e-01, -1.2022e+00,  0.0000e+00, -2.0397e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3713e-01,  5.9293e-01,
-        -2.5686e-02,  0.0000e+00, -2.6156e-01,  0.0000e+00,  0.0000e+00,
-        -2.0058e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.6708e-02, -3.0576e-01,  6.0895e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6811e-01,  0.0000e+00,
-         0.0000e+00, -1.3988e-01,  0.0000e+00,  7.3246e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.8760e-01,  0.0000e+00, -1.8564e-03,  8.1031e-02,  2.1229e+00,
-        -8.9586e-02,  0.0000e+00,  0.0000e+00,  4.0205e-02,  2.2851e-01,
-         2.1765e-01,  0.0000e+00, -2.9845e-01, -4.3250e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.5500e-01,  0.0000e+00,  0.0000e+00,
-         9.8736e-02,  7.3316e-02,  0.0000e+00, -8.9838e-01,  7.2188e-02,
-         0.0000e+00, -1.6430e-02,  2.0759e-01, -2.0161e-02,  0.0000e+00,
-        -4.3503e-01, -1.2022e+00,  0.0000e+00, -2.0397e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3713e-01,  5.9293e-01,
-        -2.5686e-02,  0.0000e+00, -2.6156e-01,  0.0000e+00,  0.0000e+00,
-        -2.0058e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.6708e-02, -3.0576e-01,  6.0895e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6811e-01,  0.0000e+00,
-         0.0000e+00, -1.3988e-01,  0.0000e+00,  7.3246e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0861e-01, -1.9173e-11,  3.6097e-03,  9.1203e-02,  2.1222e+00,
-        -5.9552e-02,  1.1606e-20, -9.1642e-14,  2.4046e-02,  2.4126e-01,
-         2.2376e-01,  2.0864e-12, -3.0076e-01, -4.2984e-01, -2.0813e-17,
-        -1.5469e-14, -1.5216e-15, -1.6453e-01, -3.5193e-18,  4.2314e-14,
-         1.2224e-01,  7.2768e-02, -7.4738e-18, -9.0374e-01,  6.0192e-02,
-         6.0658e-14,  5.1806e-03,  2.3297e-01, -1.1880e-02,  1.0870e-17,
-        -4.4051e-01, -1.2048e+00, -2.7824e-21, -1.9254e-01,  1.3810e-15,
-         7.4516e-14,  4.6068e-17,  0.0000e+00,  1.1477e-01,  5.8233e-01,
-        -2.7829e-02, -6.1532e-15, -2.6212e-01,  1.1338e-17,  2.7731e-15,
-        -1.7758e-01, -3.9128e-13,  1.2626e-06,  1.3163e-12, -7.2179e-19,
-        -2.3354e-09,  5.1271e-02, -2.9037e-01,  8.1541e-02, -2.6450e-10,
-         1.5022e-22, -5.9038e-17,  2.1232e-20,  4.7216e-01, -3.0664e-12,
-         5.6716e-20, -1.1784e-01,  4.7652e-08,  9.3388e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3086,  0.0000,  0.0036,  0.0912,  2.1222, -0.0596,  0.0000,  0.0000,
-         0.0240,  0.2413,  0.2238,  0.0000, -0.3008, -0.4298,  0.0000,  0.0000,
-         0.0000, -0.1645,  0.0000,  0.0000,  0.1222,  0.0728,  0.0000, -0.9037,
-         0.0602,  0.0000,  0.0052,  0.2330, -0.0119,  0.0000, -0.4405, -1.2048,
-         0.0000, -0.1925,  0.0000,  0.0000,  0.0000,  0.0000,  0.1148,  0.5823,
-        -0.0278,  0.0000, -0.2621,  0.0000,  0.0000, -0.1776,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0513, -0.2904,  0.0815,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4722,  0.0000,  0.0000, -0.1178,  0.0000,  0.0934],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3086,  0.0000,  0.0036,  0.0912,  2.1222, -0.0596,  0.0000,  0.0000,
-         0.0240,  0.2413,  0.2238,  0.0000, -0.3008, -0.4298,  0.0000,  0.0000,
-         0.0000, -0.1645,  0.0000,  0.0000,  0.1222,  0.0728,  0.0000, -0.9037,
-         0.0602,  0.0000,  0.0052,  0.2330, -0.0119,  0.0000, -0.4405, -1.2048,
-         0.0000, -0.1925,  0.0000,  0.0000,  0.0000,  0.0000,  0.1148,  0.5823,
-        -0.0278,  0.0000, -0.2621,  0.0000,  0.0000, -0.1776,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0513, -0.2904,  0.0815,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4722,  0.0000,  0.0000, -0.1178,  0.0000,  0.0934],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2832e-01, -1.7574e-11,  1.4626e-02,  8.9968e-02,  2.1216e+00,
-        -3.4520e-02,  1.0638e-20, -8.3997e-14,  4.3057e-03,  2.5007e-01,
-         2.3039e-01,  1.9124e-12, -2.9691e-01, -4.3178e-01, -1.9077e-17,
-        -1.4178e-14, -1.3947e-15, -1.7841e-01, -3.2257e-18,  3.8784e-14,
-         1.3500e-01,  6.8073e-02, -6.8503e-18, -9.0655e-01,  4.2481e-02,
-         5.5598e-14,  3.0253e-02,  2.4845e-01, -1.3431e-02,  9.9628e-18,
-        -4.3919e-01, -1.2070e+00, -2.5503e-21, -1.7949e-01,  1.2658e-15,
-         6.8299e-14,  4.2225e-17,  0.0000e+00,  9.7051e-02,  5.7282e-01,
-        -4.1986e-02, -5.6399e-15, -2.5741e-01,  1.0392e-17,  2.5418e-15,
-        -1.5957e-01, -3.5864e-13,  1.1572e-06,  1.2065e-12, -6.6158e-19,
-        -2.1406e-09,  3.9283e-02, -2.7184e-01,  9.2667e-02, -2.4244e-10,
-         1.3769e-22, -5.4113e-17,  1.9461e-20,  4.7425e-01, -2.8106e-12,
-         5.1985e-20, -9.9447e-02,  4.3677e-08,  1.1818e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3283,  0.0000,  0.0146,  0.0900,  2.1216, -0.0345,  0.0000,  0.0000,
-         0.0043,  0.2501,  0.2304,  0.0000, -0.2969, -0.4318,  0.0000,  0.0000,
-         0.0000, -0.1784,  0.0000,  0.0000,  0.1350,  0.0681,  0.0000, -0.9065,
-         0.0425,  0.0000,  0.0303,  0.2485, -0.0134,  0.0000, -0.4392, -1.2070,
-         0.0000, -0.1795,  0.0000,  0.0000,  0.0000,  0.0000,  0.0971,  0.5728,
-        -0.0420,  0.0000, -0.2574,  0.0000,  0.0000, -0.1596,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0393, -0.2718,  0.0927,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4743,  0.0000,  0.0000, -0.0994,  0.0000,  0.1182],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3283,  0.0000,  0.0146,  0.0900,  2.1216, -0.0345,  0.0000,  0.0000,
-         0.0043,  0.2501,  0.2304,  0.0000, -0.2969, -0.4318,  0.0000,  0.0000,
-         0.0000, -0.1784,  0.0000,  0.0000,  0.1350,  0.0681,  0.0000, -0.9065,
-         0.0425,  0.0000,  0.0303,  0.2485, -0.0134,  0.0000, -0.4392, -1.2070,
-         0.0000, -0.1795,  0.0000,  0.0000,  0.0000,  0.0000,  0.0971,  0.5728,
-        -0.0420,  0.0000, -0.2574,  0.0000,  0.0000, -0.1596,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0393, -0.2718,  0.0927,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4743,  0.0000,  0.0000, -0.0994,  0.0000,  0.1182],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4188e-01, -1.6108e-11,  2.4281e-02,  7.7558e-02,  2.1210e+00,
-        -1.4899e-02,  9.7507e-21, -7.6993e-14, -1.5258e-02,  2.5409e-01,
-         2.4262e-01,  1.7529e-12, -2.9676e-01, -4.3655e-01, -1.7486e-17,
-        -1.2996e-14, -1.2784e-15, -1.8861e-01, -2.9568e-18,  3.5550e-14,
-         1.3922e-01,  5.9578e-02, -6.2791e-18, -9.0797e-01,  1.7404e-02,
-         5.0961e-14,  4.9174e-02,  2.5185e-01, -2.0437e-02,  9.1320e-18,
-        -4.3302e-01, -1.2096e+00, -2.3376e-21, -1.5695e-01,  1.1603e-15,
-         6.2604e-14,  3.8704e-17,  0.0000e+00,  8.2697e-02,  5.6413e-01,
-        -6.9713e-02, -5.1696e-15, -2.5345e-01,  9.5254e-18,  2.3298e-15,
-        -1.4961e-01, -3.2873e-13,  1.0607e-06,  1.1059e-12, -6.0641e-19,
-        -1.9621e-09,  1.8041e-02, -2.4882e-01,  8.3043e-02, -2.2222e-10,
-         1.2620e-22, -4.9600e-17,  1.7838e-20,  4.7757e-01, -2.5762e-12,
-         4.7650e-20, -8.3220e-02,  4.0035e-08,  1.3650e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3419,  0.0000,  0.0243,  0.0776,  2.1210, -0.0149,  0.0000,  0.0000,
-        -0.0153,  0.2541,  0.2426,  0.0000, -0.2968, -0.4365,  0.0000,  0.0000,
-         0.0000, -0.1886,  0.0000,  0.0000,  0.1392,  0.0596,  0.0000, -0.9080,
-         0.0174,  0.0000,  0.0492,  0.2519, -0.0204,  0.0000, -0.4330, -1.2096,
-         0.0000, -0.1570,  0.0000,  0.0000,  0.0000,  0.0000,  0.0827,  0.5641,
-        -0.0697,  0.0000, -0.2534,  0.0000,  0.0000, -0.1496,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0180, -0.2488,  0.0830,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4776,  0.0000,  0.0000, -0.0832,  0.0000,  0.1365],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3419,  0.0000,  0.0243,  0.0776,  2.1210, -0.0149,  0.0000,  0.0000,
-        -0.0153,  0.2541,  0.2426,  0.0000, -0.2968, -0.4365,  0.0000,  0.0000,
-         0.0000, -0.1886,  0.0000,  0.0000,  0.1392,  0.0596,  0.0000, -0.9080,
-         0.0174,  0.0000,  0.0492,  0.2519, -0.0204,  0.0000, -0.4330, -1.2096,
-         0.0000, -0.1570,  0.0000,  0.0000,  0.0000,  0.0000,  0.0827,  0.5641,
-        -0.0697,  0.0000, -0.2534,  0.0000,  0.0000, -0.1496,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0180, -0.2488,  0.0830,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4776,  0.0000,  0.0000, -0.0832,  0.0000,  0.1365],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5279e-01, -1.4766e-11,  3.1721e-02,  5.7192e-02,  2.1201e+00,
-         4.9814e-03,  8.9379e-21, -7.0575e-14, -3.3182e-02,  2.5181e-01,
-         2.4820e-01,  1.6068e-12, -2.9578e-01, -4.4083e-01, -1.6028e-17,
-        -1.1913e-14, -1.1718e-15, -2.0381e-01, -2.7103e-18,  3.2587e-14,
-         1.4161e-01,  4.9020e-02, -5.7557e-18, -9.0792e-01, -8.4028e-03,
-         4.6713e-14,  6.8413e-02,  2.5013e-01, -2.9449e-02,  8.3708e-18,
-        -4.2867e-01, -1.2125e+00, -2.1428e-21, -1.3591e-01,  1.0636e-15,
-         5.7386e-14,  3.5478e-17,  0.0000e+00,  7.4793e-02,  5.5590e-01,
-        -1.0408e-01, -4.7387e-15, -2.4997e-01,  8.7314e-18,  2.1356e-15,
-        -1.4318e-01, -3.0133e-13,  9.7231e-07,  1.0137e-12, -5.5586e-19,
-        -1.7985e-09, -1.6715e-03, -2.2677e-01,  5.7274e-02, -2.0370e-10,
-         1.1568e-22, -4.5466e-17,  1.6351e-20,  4.7850e-01, -2.3615e-12,
-         4.3678e-20, -7.2717e-02,  3.6698e-08,  1.4979e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5279e-01,  0.0000e+00,  3.1721e-02,  5.7192e-02,  2.1201e+00,
-         4.9814e-03,  0.0000e+00,  0.0000e+00, -3.3182e-02,  2.5181e-01,
-         2.4820e-01,  0.0000e+00, -2.9578e-01, -4.4083e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.0381e-01,  0.0000e+00,  0.0000e+00,
-         1.4161e-01,  4.9020e-02,  0.0000e+00, -9.0792e-01, -8.4028e-03,
-         0.0000e+00,  6.8413e-02,  2.5013e-01, -2.9449e-02,  0.0000e+00,
-        -4.2867e-01, -1.2125e+00,  0.0000e+00, -1.3591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.4793e-02,  5.5590e-01,
-        -1.0408e-01,  0.0000e+00, -2.4997e-01,  0.0000e+00,  0.0000e+00,
-        -1.4318e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -1.6715e-03, -2.2677e-01,  5.7274e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.7850e-01,  0.0000e+00,
-         0.0000e+00, -7.2717e-02,  0.0000e+00,  1.4979e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5279e-01,  0.0000e+00,  3.1721e-02,  5.7192e-02,  2.1201e+00,
-         4.9814e-03,  0.0000e+00,  0.0000e+00, -3.3182e-02,  2.5181e-01,
-         2.4820e-01,  0.0000e+00, -2.9578e-01, -4.4083e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.0381e-01,  0.0000e+00,  0.0000e+00,
-         1.4161e-01,  4.9020e-02,  0.0000e+00, -9.0792e-01, -8.4028e-03,
-         0.0000e+00,  6.8413e-02,  2.5013e-01, -2.9449e-02,  0.0000e+00,
-        -4.2867e-01, -1.2125e+00,  0.0000e+00, -1.3591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.4793e-02,  5.5590e-01,
-        -1.0408e-01,  0.0000e+00, -2.4997e-01,  0.0000e+00,  0.0000e+00,
-        -1.4318e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -1.6715e-03, -2.2677e-01,  5.7274e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.7850e-01,  0.0000e+00,
-         0.0000e+00, -7.2717e-02,  0.0000e+00,  1.4979e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5778e-01, -1.3535e-11,  4.0296e-02,  3.9735e-02,  2.1188e+00,
-         1.9505e-02,  8.1930e-21, -6.4693e-14, -4.5695e-02,  2.4841e-01,
-         2.5081e-01,  1.4729e-12, -2.9121e-01, -4.4567e-01, -1.4692e-17,
-        -1.0920e-14, -1.0741e-15, -2.1407e-01, -2.4844e-18,  2.9871e-14,
-         1.3973e-01,  4.4073e-02, -5.2760e-18, -9.0854e-01, -2.8704e-02,
-         4.2820e-14,  8.8075e-02,  2.4534e-01, -4.0658e-02,  7.6732e-18,
-        -4.2580e-01, -1.2150e+00, -1.9642e-21, -1.2181e-01,  9.7493e-16,
-         5.2603e-14,  3.2521e-17,  0.0000e+00,  5.9261e-02,  5.4629e-01,
-        -1.3305e-01, -4.3437e-15, -2.4330e-01,  8.0038e-18,  1.9576e-15,
-        -1.3959e-01, -2.7622e-13,  8.9128e-07,  9.2920e-13, -5.0954e-19,
-        -1.6486e-09, -5.8048e-03, -2.1683e-01,  3.3231e-02, -1.8672e-10,
-         1.0604e-22, -4.1677e-17,  1.4988e-20,  4.8154e-01, -2.1647e-12,
-         4.0038e-20, -6.1493e-02,  3.3639e-08,  1.6257e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3578,  0.0000,  0.0403,  0.0397,  2.1188,  0.0195,  0.0000,  0.0000,
-        -0.0457,  0.2484,  0.2508,  0.0000, -0.2912, -0.4457,  0.0000,  0.0000,
-         0.0000, -0.2141,  0.0000,  0.0000,  0.1397,  0.0441,  0.0000, -0.9085,
-        -0.0287,  0.0000,  0.0881,  0.2453, -0.0407,  0.0000, -0.4258, -1.2150,
-         0.0000, -0.1218,  0.0000,  0.0000,  0.0000,  0.0000,  0.0593,  0.5463,
-        -0.1330,  0.0000, -0.2433,  0.0000,  0.0000, -0.1396,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0058, -0.2168,  0.0332,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4815,  0.0000,  0.0000, -0.0615,  0.0000,  0.1626],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3578,  0.0000,  0.0403,  0.0397,  2.1188,  0.0195,  0.0000,  0.0000,
-        -0.0457,  0.2484,  0.2508,  0.0000, -0.2912, -0.4457,  0.0000,  0.0000,
-         0.0000, -0.2141,  0.0000,  0.0000,  0.1397,  0.0441,  0.0000, -0.9085,
-        -0.0287,  0.0000,  0.0881,  0.2453, -0.0407,  0.0000, -0.4258, -1.2150,
-         0.0000, -0.1218,  0.0000,  0.0000,  0.0000,  0.0000,  0.0593,  0.5463,
-        -0.1330,  0.0000, -0.2433,  0.0000,  0.0000, -0.1396,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0058, -0.2168,  0.0332,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4815,  0.0000,  0.0000, -0.0615,  0.0000,  0.1626],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5650e-01, -1.2407e-11,  4.1074e-02,  2.1614e-02,  2.1173e+00,
-         3.3108e-02,  7.5103e-21, -5.9303e-14, -6.0062e-02,  2.3582e-01,
-         2.4454e-01,  1.3502e-12, -2.9251e-01, -4.4646e-01, -1.3468e-17,
-        -1.0010e-14, -9.8464e-16, -2.2177e-01, -2.2774e-18,  2.7382e-14,
-         1.3496e-01,  4.8117e-02, -4.8364e-18, -9.0648e-01, -3.9553e-02,
-         3.9252e-14,  9.8255e-02,  2.3805e-01, -4.8752e-02,  7.0338e-18,
-        -4.2091e-01, -1.2175e+00, -1.8005e-21, -1.0962e-01,  8.9369e-16,
-         4.8220e-14,  2.9811e-17,  0.0000e+00,  3.7229e-02,  5.3753e-01,
-        -1.5245e-01, -3.9818e-15, -2.4309e-01,  7.3368e-18,  1.7945e-15,
-        -1.4374e-01, -2.5320e-13,  8.1701e-07,  8.5177e-13, -4.6708e-19,
-        -1.5112e-09, -4.5652e-03, -2.1610e-01,  6.6716e-03, -1.7116e-10,
-         9.7207e-23, -3.8204e-17,  1.3740e-20,  4.8207e-01, -1.9843e-12,
-         3.6702e-20, -5.8882e-02,  3.0836e-08,  1.6471e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3565,  0.0000,  0.0411,  0.0216,  2.1173,  0.0331,  0.0000,  0.0000,
-        -0.0601,  0.2358,  0.2445,  0.0000, -0.2925, -0.4465,  0.0000,  0.0000,
-         0.0000, -0.2218,  0.0000,  0.0000,  0.1350,  0.0481,  0.0000, -0.9065,
-        -0.0396,  0.0000,  0.0983,  0.2380, -0.0488,  0.0000, -0.4209, -1.2175,
-         0.0000, -0.1096,  0.0000,  0.0000,  0.0000,  0.0000,  0.0372,  0.5375,
-        -0.1525,  0.0000, -0.2431,  0.0000,  0.0000, -0.1437,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0046, -0.2161,  0.0067,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4821,  0.0000,  0.0000, -0.0589,  0.0000,  0.1647],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3565,  0.0000,  0.0411,  0.0216,  2.1173,  0.0331,  0.0000,  0.0000,
-        -0.0601,  0.2358,  0.2445,  0.0000, -0.2925, -0.4465,  0.0000,  0.0000,
-         0.0000, -0.2218,  0.0000,  0.0000,  0.1350,  0.0481,  0.0000, -0.9065,
-        -0.0396,  0.0000,  0.0983,  0.2380, -0.0488,  0.0000, -0.4209, -1.2175,
-         0.0000, -0.1096,  0.0000,  0.0000,  0.0000,  0.0000,  0.0372,  0.5375,
-        -0.1525,  0.0000, -0.2431,  0.0000,  0.0000, -0.1437,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0046, -0.2161,  0.0067,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4821,  0.0000,  0.0000, -0.0589,  0.0000,  0.1647],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5444e-01, -1.1373e-11,  4.2188e-02,  7.7516e-05,  2.1158e+00,
-         4.6635e-02,  6.8845e-21, -5.4361e-14, -6.8509e-02,  2.2330e-01,
-         2.3022e-01,  1.2376e-12, -2.9157e-01, -4.4475e-01, -1.2346e-17,
-        -9.1760e-15, -9.0260e-16, -2.3360e-01, -2.0876e-18,  2.5100e-14,
-         1.3170e-01,  5.9434e-02, -4.4334e-18, -9.0369e-01, -4.6607e-02,
-         3.5982e-14,  1.0392e-01,  2.3432e-01, -5.6825e-02,  6.4477e-18,
-        -4.1316e-01, -1.2193e+00, -1.6505e-21, -9.5199e-02,  8.1922e-16,
-         4.4202e-14,  2.7327e-17,  0.0000e+00,  1.8286e-02,  5.3136e-01,
-        -1.6535e-01, -3.6500e-15, -2.4317e-01,  6.7255e-18,  1.6450e-15,
-        -1.5168e-01, -2.3210e-13,  7.4893e-07,  7.8079e-13, -4.2816e-19,
-        -1.3853e-09,  5.2105e-03, -2.2137e-01, -1.2249e-02, -1.5690e-10,
-         8.9107e-23, -3.5021e-17,  1.2595e-20,  4.8326e-01, -1.8189e-12,
-         3.3643e-20, -5.7022e-02,  2.8267e-08,  1.5922e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5444e-01,  0.0000e+00,  4.2188e-02,  7.7516e-05,  2.1158e+00,
-         4.6635e-02,  0.0000e+00,  0.0000e+00, -6.8509e-02,  2.2330e-01,
-         2.3022e-01,  0.0000e+00, -2.9157e-01, -4.4475e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.3360e-01,  0.0000e+00,  0.0000e+00,
-         1.3170e-01,  5.9434e-02,  0.0000e+00, -9.0369e-01, -4.6607e-02,
-         0.0000e+00,  1.0392e-01,  2.3432e-01, -5.6825e-02,  0.0000e+00,
-        -4.1316e-01, -1.2193e+00,  0.0000e+00, -9.5199e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.8286e-02,  5.3136e-01,
-        -1.6535e-01,  0.0000e+00, -2.4317e-01,  0.0000e+00,  0.0000e+00,
-        -1.5168e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.2105e-03, -2.2137e-01, -1.2249e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.8326e-01,  0.0000e+00,
-         0.0000e+00, -5.7022e-02,  0.0000e+00,  1.5922e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5444e-01,  0.0000e+00,  4.2188e-02,  7.7516e-05,  2.1158e+00,
-         4.6635e-02,  0.0000e+00,  0.0000e+00, -6.8509e-02,  2.2330e-01,
-         2.3022e-01,  0.0000e+00, -2.9157e-01, -4.4475e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.3360e-01,  0.0000e+00,  0.0000e+00,
-         1.3170e-01,  5.9434e-02,  0.0000e+00, -9.0369e-01, -4.6607e-02,
-         0.0000e+00,  1.0392e-01,  2.3432e-01, -5.6825e-02,  0.0000e+00,
-        -4.1316e-01, -1.2193e+00,  0.0000e+00, -9.5199e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.8286e-02,  5.3136e-01,
-        -1.6535e-01,  0.0000e+00, -2.4317e-01,  0.0000e+00,  0.0000e+00,
-        -1.5168e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.2105e-03, -2.2137e-01, -1.2249e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.8326e-01,  0.0000e+00,
-         0.0000e+00, -5.7022e-02,  0.0000e+00,  1.5922e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5013e-01, -1.0426e-11,  4.1648e-02, -2.3349e-02,  2.1152e+00,
-         5.8595e-02,  6.3108e-21, -4.9831e-14, -7.0808e-02,  2.0968e-01,
-         2.0366e-01,  1.1345e-12, -2.9171e-01, -4.4061e-01, -1.1317e-17,
-        -8.4113e-15, -8.2738e-16, -2.2751e-01, -1.9137e-18,  2.3009e-14,
-         1.2688e-01,  7.6131e-02, -4.0639e-18, -8.9639e-01, -4.7831e-02,
-         3.2983e-14,  1.0725e-01,  2.2228e-01, -6.1396e-02,  5.9104e-18,
-        -4.0580e-01, -1.2207e+00, -1.5129e-21, -7.3318e-02,  7.5095e-16,
-         4.0519e-14,  2.5050e-17,  0.0000e+00, -8.4725e-04,  5.2204e-01,
-        -1.7427e-01, -3.3459e-15, -2.4837e-01,  6.1650e-18,  1.5079e-15,
-        -1.6617e-01, -2.1276e-13,  6.8652e-07,  7.1573e-13, -3.9248e-19,
-        -1.2699e-09,  3.1572e-02, -2.3856e-01, -2.2193e-02, -1.4382e-10,
-         8.1682e-23, -3.2102e-17,  1.1545e-20,  4.8518e-01, -1.6674e-12,
-         3.0840e-20, -5.6478e-02,  2.5911e-08,  1.4634e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5013e-01,  0.0000e+00,  4.1648e-02, -2.3349e-02,  2.1152e+00,
-         5.8595e-02,  0.0000e+00,  0.0000e+00, -7.0808e-02,  2.0968e-01,
-         2.0366e-01,  0.0000e+00, -2.9171e-01, -4.4061e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.2751e-01,  0.0000e+00,  0.0000e+00,
-         1.2688e-01,  7.6131e-02,  0.0000e+00, -8.9639e-01, -4.7831e-02,
-         0.0000e+00,  1.0725e-01,  2.2228e-01, -6.1396e-02,  0.0000e+00,
-        -4.0580e-01, -1.2207e+00,  0.0000e+00, -7.3318e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -8.4725e-04,  5.2204e-01,
-        -1.7427e-01,  0.0000e+00, -2.4837e-01,  0.0000e+00,  0.0000e+00,
-        -1.6617e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  3.1572e-02, -2.3856e-01, -2.2193e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.8518e-01,  0.0000e+00,
-         0.0000e+00, -5.6478e-02,  0.0000e+00,  1.4634e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5013e-01,  0.0000e+00,  4.1648e-02, -2.3349e-02,  2.1152e+00,
-         5.8595e-02,  0.0000e+00,  0.0000e+00, -7.0808e-02,  2.0968e-01,
-         2.0366e-01,  0.0000e+00, -2.9171e-01, -4.4061e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.2751e-01,  0.0000e+00,  0.0000e+00,
-         1.2688e-01,  7.6131e-02,  0.0000e+00, -8.9639e-01, -4.7831e-02,
-         0.0000e+00,  1.0725e-01,  2.2228e-01, -6.1396e-02,  0.0000e+00,
-        -4.0580e-01, -1.2207e+00,  0.0000e+00, -7.3318e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -8.4725e-04,  5.2204e-01,
-        -1.7427e-01,  0.0000e+00, -2.4837e-01,  0.0000e+00,  0.0000e+00,
-        -1.6617e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  3.1572e-02, -2.3856e-01, -2.2193e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.8518e-01,  0.0000e+00,
-         0.0000e+00, -5.6478e-02,  0.0000e+00,  1.4634e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4584e-01, -9.5567e-12,  3.8066e-02, -4.5897e-02,  2.1146e+00,
-         6.5201e-02,  5.7848e-21, -4.5678e-14, -7.0340e-02,  1.9432e-01,
-         1.7341e-01,  1.0400e-12, -2.9069e-01, -4.3911e-01, -1.0374e-17,
-        -7.7102e-15, -7.5842e-16, -2.1360e-01, -1.7542e-18,  2.1091e-14,
-         1.2304e-01,  9.3750e-02, -3.7252e-18, -8.8891e-01, -4.8145e-02,
-         3.0234e-14,  1.0685e-01,  2.1180e-01, -6.0212e-02,  5.4178e-18,
-        -4.0235e-01, -1.2231e+00, -1.3868e-21, -5.6660e-02,  6.8836e-16,
-         3.7141e-14,  2.2962e-17,  0.0000e+00, -1.9719e-02,  5.1301e-01,
-        -1.7461e-01, -3.0670e-15, -2.5298e-01,  5.6512e-18,  1.3822e-15,
-        -1.8097e-01, -1.9503e-13,  6.2930e-07,  6.5608e-13, -3.5977e-19,
-        -1.1640e-09,  6.1424e-02, -2.5794e-01, -2.1567e-02, -1.3184e-10,
-         7.4874e-23, -2.9427e-17,  1.0583e-20,  4.8630e-01, -1.5284e-12,
-         2.8269e-20, -5.8850e-02,  2.3752e-08,  1.2417e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3458,  0.0000,  0.0381, -0.0459,  2.1146,  0.0652,  0.0000,  0.0000,
-        -0.0703,  0.1943,  0.1734,  0.0000, -0.2907, -0.4391,  0.0000,  0.0000,
-         0.0000, -0.2136,  0.0000,  0.0000,  0.1230,  0.0938,  0.0000, -0.8889,
-        -0.0481,  0.0000,  0.1068,  0.2118, -0.0602,  0.0000, -0.4024, -1.2231,
-         0.0000, -0.0567,  0.0000,  0.0000,  0.0000,  0.0000, -0.0197,  0.5130,
-        -0.1746,  0.0000, -0.2530,  0.0000,  0.0000, -0.1810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0614, -0.2579, -0.0216,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4863,  0.0000,  0.0000, -0.0589,  0.0000,  0.1242],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3458,  0.0000,  0.0381, -0.0459,  2.1146,  0.0652,  0.0000,  0.0000,
-        -0.0703,  0.1943,  0.1734,  0.0000, -0.2907, -0.4391,  0.0000,  0.0000,
-         0.0000, -0.2136,  0.0000,  0.0000,  0.1230,  0.0938,  0.0000, -0.8889,
-        -0.0481,  0.0000,  0.1068,  0.2118, -0.0602,  0.0000, -0.4024, -1.2231,
-         0.0000, -0.0567,  0.0000,  0.0000,  0.0000,  0.0000, -0.0197,  0.5130,
-        -0.1746,  0.0000, -0.2530,  0.0000,  0.0000, -0.1810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0614, -0.2579, -0.0216,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4863,  0.0000,  0.0000, -0.0589,  0.0000,  0.1242],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4584e-01, -8.7599e-12,  3.2733e-02, -5.7945e-02,  2.1143e+00,
-         6.8495e-02,  5.3025e-21, -4.1870e-14, -6.7369e-02,  1.8307e-01,
-         1.5097e-01,  9.5325e-13, -2.8736e-01, -4.3510e-01, -9.5090e-18,
-        -7.0674e-15, -6.9519e-16, -1.9362e-01, -1.6079e-18,  1.9333e-14,
-         1.2384e-01,  1.2006e-01, -3.4146e-18, -8.8110e-01, -3.9032e-02,
-         2.7713e-14,  1.0022e-01,  2.1465e-01, -5.5457e-02,  4.9661e-18,
-        -4.0243e-01, -1.2249e+00, -1.2712e-21, -4.8196e-02,  6.3097e-16,
-         3.4045e-14,  2.1048e-17,  0.0000e+00, -3.2097e-02,  5.0603e-01,
-        -1.6133e-01, -2.8113e-15, -2.5909e-01,  5.1800e-18,  1.2670e-15,
-        -1.9287e-01, -1.7877e-13,  5.7684e-07,  6.0138e-13, -3.2977e-19,
-        -1.0670e-09,  9.3388e-02, -2.7698e-01, -6.6370e-03, -1.2085e-10,
-         6.8631e-23, -2.6973e-17,  9.7006e-21,  4.8277e-01, -1.4010e-12,
-         2.5912e-20, -5.8011e-02,  2.1771e-08,  9.8686e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3458,  0.0000,  0.0327, -0.0579,  2.1143,  0.0685,  0.0000,  0.0000,
-        -0.0674,  0.1831,  0.1510,  0.0000, -0.2874, -0.4351,  0.0000,  0.0000,
-         0.0000, -0.1936,  0.0000,  0.0000,  0.1238,  0.1201,  0.0000, -0.8811,
-        -0.0390,  0.0000,  0.1002,  0.2146, -0.0555,  0.0000, -0.4024, -1.2249,
-         0.0000, -0.0482,  0.0000,  0.0000,  0.0000,  0.0000, -0.0321,  0.5060,
-        -0.1613,  0.0000, -0.2591,  0.0000,  0.0000, -0.1929,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0934, -0.2770, -0.0066,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4828,  0.0000,  0.0000, -0.0580,  0.0000,  0.0987],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3458,  0.0000,  0.0327, -0.0579,  2.1143,  0.0685,  0.0000,  0.0000,
-        -0.0674,  0.1831,  0.1510,  0.0000, -0.2874, -0.4351,  0.0000,  0.0000,
-         0.0000, -0.1936,  0.0000,  0.0000,  0.1238,  0.1201,  0.0000, -0.8811,
-        -0.0390,  0.0000,  0.1002,  0.2146, -0.0555,  0.0000, -0.4024, -1.2249,
-         0.0000, -0.0482,  0.0000,  0.0000,  0.0000,  0.0000, -0.0321,  0.5060,
-        -0.1613,  0.0000, -0.2591,  0.0000,  0.0000, -0.1929,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0934, -0.2770, -0.0066,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4828,  0.0000,  0.0000, -0.0580,  0.0000,  0.0987],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4533e-01, -8.0293e-12,  2.4431e-02, -6.6858e-02,  2.1141e+00,
-         6.5681e-02,  4.8603e-21, -3.8377e-14, -6.3023e-02,  1.6374e-01,
-         1.3649e-01,  8.7374e-13, -2.8086e-01, -4.3218e-01, -8.7158e-18,
-        -6.4779e-15, -6.3720e-16, -1.7501e-01, -1.4738e-18,  1.7720e-14,
-         1.2946e-01,  1.4350e-01, -3.1298e-18, -8.7252e-01, -3.2404e-02,
-         2.5402e-14,  8.9816e-02,  2.2108e-01, -4.7282e-02,  4.5519e-18,
-        -4.0793e-01, -1.2261e+00, -1.1652e-21, -4.5358e-02,  5.7834e-16,
-         3.1205e-14,  1.9292e-17,  0.0000e+00, -3.4429e-02,  5.0218e-01,
-        -1.4215e-01, -2.5768e-15, -2.6487e-01,  4.7480e-18,  1.1613e-15,
-        -2.0645e-01, -1.6386e-13,  5.2872e-07,  5.5122e-13, -3.0227e-19,
-        -9.7799e-10,  1.2429e-01, -2.9319e-01,  6.0327e-03, -1.1077e-10,
-         6.2907e-23, -2.4723e-17,  8.8914e-21,  4.7794e-01, -1.2841e-12,
-         2.3751e-20, -5.7497e-02,  1.9956e-08,  7.1081e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3453,  0.0000,  0.0244, -0.0669,  2.1141,  0.0657,  0.0000,  0.0000,
-        -0.0630,  0.1637,  0.1365,  0.0000, -0.2809, -0.4322,  0.0000,  0.0000,
-         0.0000, -0.1750,  0.0000,  0.0000,  0.1295,  0.1435,  0.0000, -0.8725,
-        -0.0324,  0.0000,  0.0898,  0.2211, -0.0473,  0.0000, -0.4079, -1.2261,
-         0.0000, -0.0454,  0.0000,  0.0000,  0.0000,  0.0000, -0.0344,  0.5022,
-        -0.1421,  0.0000, -0.2649,  0.0000,  0.0000, -0.2064,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1243, -0.2932,  0.0060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4779,  0.0000,  0.0000, -0.0575,  0.0000,  0.0711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3453,  0.0000,  0.0244, -0.0669,  2.1141,  0.0657,  0.0000,  0.0000,
-        -0.0630,  0.1637,  0.1365,  0.0000, -0.2809, -0.4322,  0.0000,  0.0000,
-         0.0000, -0.1750,  0.0000,  0.0000,  0.1295,  0.1435,  0.0000, -0.8725,
-        -0.0324,  0.0000,  0.0898,  0.2211, -0.0473,  0.0000, -0.4079, -1.2261,
-         0.0000, -0.0454,  0.0000,  0.0000,  0.0000,  0.0000, -0.0344,  0.5022,
-        -0.1421,  0.0000, -0.2649,  0.0000,  0.0000, -0.2064,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1243, -0.2932,  0.0060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4779,  0.0000,  0.0000, -0.0575,  0.0000,  0.0711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4585e-01, -7.3592e-12,  1.8012e-02, -7.1499e-02,  2.1139e+00,
-         6.5921e-02,  4.4546e-21, -3.5175e-14, -6.5567e-02,  1.4121e-01,
-         1.1891e-01,  8.0082e-13, -2.7479e-01, -4.2717e-01, -7.9884e-18,
-        -5.9373e-15, -5.8403e-16, -1.6319e-01, -1.3508e-18,  1.6241e-14,
-         1.3760e-01,  1.6643e-01, -2.8686e-18, -8.6228e-01, -2.7125e-02,
-         2.3282e-14,  8.1310e-02,  2.2910e-01, -3.6544e-02,  4.1720e-18,
-        -4.1313e-01, -1.2270e+00, -1.0679e-21, -4.7359e-02,  5.3008e-16,
-         2.8601e-14,  1.7682e-17,  0.0000e+00, -3.4913e-02,  4.9876e-01,
-        -1.2283e-01, -2.3617e-15, -2.6991e-01,  4.3517e-18,  1.0644e-15,
-        -2.1964e-01, -1.5018e-13,  4.8460e-07,  5.0521e-13, -2.7704e-19,
-        -8.9638e-10,  1.5108e-01, -3.1118e-01,  2.2342e-02, -1.0152e-10,
-         5.7657e-23, -2.2660e-17,  8.1494e-21,  4.7314e-01, -1.1769e-12,
-         2.1769e-20, -5.5402e-02,  1.8290e-08,  4.3383e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3458,  0.0000,  0.0180, -0.0715,  2.1139,  0.0659,  0.0000,  0.0000,
-        -0.0656,  0.1412,  0.1189,  0.0000, -0.2748, -0.4272,  0.0000,  0.0000,
-         0.0000, -0.1632,  0.0000,  0.0000,  0.1376,  0.1664,  0.0000, -0.8623,
-        -0.0271,  0.0000,  0.0813,  0.2291, -0.0365,  0.0000, -0.4131, -1.2270,
-         0.0000, -0.0474,  0.0000,  0.0000,  0.0000,  0.0000, -0.0349,  0.4988,
-        -0.1228,  0.0000, -0.2699,  0.0000,  0.0000, -0.2196,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1511, -0.3112,  0.0223,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4731,  0.0000,  0.0000, -0.0554,  0.0000,  0.0434],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3458,  0.0000,  0.0180, -0.0715,  2.1139,  0.0659,  0.0000,  0.0000,
-        -0.0656,  0.1412,  0.1189,  0.0000, -0.2748, -0.4272,  0.0000,  0.0000,
-         0.0000, -0.1632,  0.0000,  0.0000,  0.1376,  0.1664,  0.0000, -0.8623,
-        -0.0271,  0.0000,  0.0813,  0.2291, -0.0365,  0.0000, -0.4131, -1.2270,
-         0.0000, -0.0474,  0.0000,  0.0000,  0.0000,  0.0000, -0.0349,  0.4988,
-        -0.1228,  0.0000, -0.2699,  0.0000,  0.0000, -0.2196,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1511, -0.3112,  0.0223,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4731,  0.0000,  0.0000, -0.0554,  0.0000,  0.0434],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4646e-01, -6.7446e-12,  1.2027e-02, -7.5426e-02,  2.1135e+00,
-         6.2781e-02,  4.0826e-21, -3.2237e-14, -6.6253e-02,  1.2671e-01,
-         1.0003e-01,  7.3394e-13, -2.6916e-01, -4.2358e-01, -7.3213e-18,
-        -5.4415e-15, -5.3525e-16, -1.4710e-01, -1.2380e-18,  1.4885e-14,
-         1.4662e-01,  1.8836e-01, -2.6291e-18, -8.5356e-01, -2.2917e-02,
-         2.1338e-14,  7.6154e-02,  2.3988e-01, -2.3864e-02,  3.8236e-18,
-        -4.1967e-01, -1.2276e+00, -9.7876e-22, -5.2811e-02,  4.8581e-16,
-         2.6212e-14,  1.6205e-17,  0.0000e+00, -4.0792e-02,  4.9390e-01,
-        -1.0414e-01, -2.1645e-15, -2.7435e-01,  3.9883e-18,  9.7550e-16,
-        -2.3270e-01, -1.3764e-13,  4.4413e-07,  4.6302e-13, -2.5390e-19,
-        -8.2152e-10,  1.7887e-01, -3.2484e-01,  4.0133e-02, -9.3044e-11,
-         5.2842e-23, -2.0768e-17,  7.4688e-21,  4.6924e-01, -1.0787e-12,
-         1.9951e-20, -5.1011e-02,  1.6763e-08,  2.3442e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3465,  0.0000,  0.0120, -0.0754,  2.1135,  0.0628,  0.0000,  0.0000,
-        -0.0663,  0.1267,  0.1000,  0.0000, -0.2692, -0.4236,  0.0000,  0.0000,
-         0.0000, -0.1471,  0.0000,  0.0000,  0.1466,  0.1884,  0.0000, -0.8536,
-        -0.0229,  0.0000,  0.0762,  0.2399, -0.0239,  0.0000, -0.4197, -1.2276,
-         0.0000, -0.0528,  0.0000,  0.0000,  0.0000,  0.0000, -0.0408,  0.4939,
-        -0.1041,  0.0000, -0.2744,  0.0000,  0.0000, -0.2327,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1789, -0.3248,  0.0401,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4692,  0.0000,  0.0000, -0.0510,  0.0000,  0.0234],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3465,  0.0000,  0.0120, -0.0754,  2.1135,  0.0628,  0.0000,  0.0000,
-        -0.0663,  0.1267,  0.1000,  0.0000, -0.2692, -0.4236,  0.0000,  0.0000,
-         0.0000, -0.1471,  0.0000,  0.0000,  0.1466,  0.1884,  0.0000, -0.8536,
-        -0.0229,  0.0000,  0.0762,  0.2399, -0.0239,  0.0000, -0.4197, -1.2276,
-         0.0000, -0.0528,  0.0000,  0.0000,  0.0000,  0.0000, -0.0408,  0.4939,
-        -0.1041,  0.0000, -0.2744,  0.0000,  0.0000, -0.2327,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1789, -0.3248,  0.0401,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4692,  0.0000,  0.0000, -0.0510,  0.0000,  0.0234],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4915e-01, -6.1809e-12,  1.1792e-02, -7.4549e-02,  2.1128e+00,
-         6.3420e-02,  3.7414e-21, -2.9543e-14, -7.0485e-02,  1.1816e-01,
-         9.3475e-02,  6.7260e-13, -2.5964e-01, -4.2373e-01, -6.7094e-18,
-        -4.9867e-15, -4.9052e-16, -1.3414e-01, -1.1345e-18,  1.3641e-14,
-         1.5264e-01,  2.0565e-01, -2.4093e-18, -8.4949e-01, -2.2858e-02,
-         1.9554e-14,  7.7852e-02,  2.5303e-01, -1.6422e-02,  3.5040e-18,
-        -4.2444e-01, -1.2284e+00, -8.9696e-22, -6.1953e-02,  4.4521e-16,
-         2.4022e-14,  1.4851e-17,  0.0000e+00, -3.5356e-02,  4.9051e-01,
-        -9.0057e-02, -1.9836e-15, -2.7249e-01,  3.6550e-18,  8.9397e-16,
-        -2.3519e-01, -1.2614e-13,  4.0701e-07,  4.2432e-13, -2.3268e-19,
-        -7.5286e-10,  1.9331e-01, -3.2496e-01,  6.2684e-02, -8.5267e-11,
-         4.8425e-23, -1.9032e-17,  6.8446e-21,  4.6431e-01, -9.8851e-13,
-         1.8284e-20, -4.2452e-02,  1.5362e-08,  1.7480e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3491,  0.0000,  0.0118, -0.0745,  2.1128,  0.0634,  0.0000,  0.0000,
-        -0.0705,  0.1182,  0.0935,  0.0000, -0.2596, -0.4237,  0.0000,  0.0000,
-         0.0000, -0.1341,  0.0000,  0.0000,  0.1526,  0.2056,  0.0000, -0.8495,
-        -0.0229,  0.0000,  0.0779,  0.2530, -0.0164,  0.0000, -0.4244, -1.2284,
-         0.0000, -0.0620,  0.0000,  0.0000,  0.0000,  0.0000, -0.0354,  0.4905,
-        -0.0901,  0.0000, -0.2725,  0.0000,  0.0000, -0.2352,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1933, -0.3250,  0.0627,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4643,  0.0000,  0.0000, -0.0425,  0.0000,  0.0175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3491,  0.0000,  0.0118, -0.0745,  2.1128,  0.0634,  0.0000,  0.0000,
-        -0.0705,  0.1182,  0.0935,  0.0000, -0.2596, -0.4237,  0.0000,  0.0000,
-         0.0000, -0.1341,  0.0000,  0.0000,  0.1526,  0.2056,  0.0000, -0.8495,
-        -0.0229,  0.0000,  0.0779,  0.2530, -0.0164,  0.0000, -0.4244, -1.2284,
-         0.0000, -0.0620,  0.0000,  0.0000,  0.0000,  0.0000, -0.0354,  0.4905,
-        -0.0901,  0.0000, -0.2725,  0.0000,  0.0000, -0.2352,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1933, -0.3250,  0.0627,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4643,  0.0000,  0.0000, -0.0425,  0.0000,  0.0175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5177e-01, -5.6638e-12,  1.2570e-02, -7.0952e-02,  2.1121e+00,
-         6.4830e-02,  3.4284e-21, -2.7071e-14, -7.5429e-02,  1.1411e-01,
-         9.8885e-02,  6.1634e-13, -2.5321e-01, -4.2497e-01, -6.1481e-18,
-        -4.5695e-15, -4.4948e-16, -1.2743e-01, -1.0396e-18,  1.2500e-14,
-         1.5322e-01,  2.1473e-01, -2.2078e-18, -8.5050e-01, -2.6477e-02,
-         1.7918e-14,  8.2865e-02,  2.6152e-01, -1.1179e-02,  3.2109e-18,
-        -4.3363e-01, -1.2302e+00, -8.2192e-22, -7.3463e-02,  4.0796e-16,
-         2.2012e-14,  1.3609e-17,  0.0000e+00, -3.2633e-02,  4.8754e-01,
-        -7.6060e-02, -1.8177e-15, -2.7172e-01,  3.3492e-18,  8.1918e-16,
-        -2.3226e-01, -1.1558e-13,  3.7296e-07,  3.8883e-13, -2.1322e-19,
-        -6.8988e-10,  1.9504e-01, -3.1217e-01,  8.2218e-02, -7.8134e-11,
-         4.4374e-23, -1.7440e-17,  6.2720e-21,  4.5946e-01, -9.0581e-13,
-         1.6754e-20, -3.5049e-02,  1.4077e-08,  2.1967e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3518,  0.0000,  0.0126, -0.0710,  2.1121,  0.0648,  0.0000,  0.0000,
-        -0.0754,  0.1141,  0.0989,  0.0000, -0.2532, -0.4250,  0.0000,  0.0000,
-         0.0000, -0.1274,  0.0000,  0.0000,  0.1532,  0.2147,  0.0000, -0.8505,
-        -0.0265,  0.0000,  0.0829,  0.2615, -0.0112,  0.0000, -0.4336, -1.2302,
-         0.0000, -0.0735,  0.0000,  0.0000,  0.0000,  0.0000, -0.0326,  0.4875,
-        -0.0761,  0.0000, -0.2717,  0.0000,  0.0000, -0.2323,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1950, -0.3122,  0.0822,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4595,  0.0000,  0.0000, -0.0350,  0.0000,  0.0220],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3518,  0.0000,  0.0126, -0.0710,  2.1121,  0.0648,  0.0000,  0.0000,
-        -0.0754,  0.1141,  0.0989,  0.0000, -0.2532, -0.4250,  0.0000,  0.0000,
-         0.0000, -0.1274,  0.0000,  0.0000,  0.1532,  0.2147,  0.0000, -0.8505,
-        -0.0265,  0.0000,  0.0829,  0.2615, -0.0112,  0.0000, -0.4336, -1.2302,
-         0.0000, -0.0735,  0.0000,  0.0000,  0.0000,  0.0000, -0.0326,  0.4875,
-        -0.0761,  0.0000, -0.2717,  0.0000,  0.0000, -0.2323,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1950, -0.3122,  0.0822,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4595,  0.0000,  0.0000, -0.0350,  0.0000,  0.0220],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5822e-01, -5.1895e-12,  1.6191e-02, -6.9945e-02,  2.1113e+00,
-         6.2055e-02,  3.1413e-21, -2.4804e-14, -7.9788e-02,  1.1033e-01,
-         1.1460e-01,  5.6472e-13, -2.4385e-01, -4.2552e-01, -5.6333e-18,
-        -4.1868e-15, -4.1184e-16, -1.2471e-01, -9.5256e-19,  1.1453e-14,
-         1.4724e-01,  2.1991e-01, -2.0229e-18, -8.5195e-01, -3.3335e-02,
-         1.6418e-14,  8.6836e-02,  2.6604e-01, -8.2903e-03,  2.9420e-18,
-        -4.4341e-01, -1.2319e+00, -7.5309e-22, -8.8913e-02,  3.7380e-16,
-         2.0169e-14,  1.2469e-17,  0.0000e+00, -1.5010e-02,  4.8820e-01,
-        -6.0024e-02, -1.6654e-15, -2.6687e-01,  3.0687e-18,  7.5058e-16,
-        -2.3308e-01, -1.0591e-13,  3.4173e-07,  3.5626e-13, -1.9536e-19,
-        -6.3210e-10,  1.8943e-01, -2.9808e-01,  9.5377e-02, -7.1591e-11,
-         4.0658e-23, -1.5979e-17,  5.7468e-21,  4.5093e-01, -8.2996e-13,
-         1.5351e-20, -3.2771e-02,  1.2898e-08,  3.0725e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3582,  0.0000,  0.0162, -0.0699,  2.1113,  0.0621,  0.0000,  0.0000,
-        -0.0798,  0.1103,  0.1146,  0.0000, -0.2438, -0.4255,  0.0000,  0.0000,
-         0.0000, -0.1247,  0.0000,  0.0000,  0.1472,  0.2199,  0.0000, -0.8519,
-        -0.0333,  0.0000,  0.0868,  0.2660, -0.0083,  0.0000, -0.4434, -1.2319,
-         0.0000, -0.0889,  0.0000,  0.0000,  0.0000,  0.0000, -0.0150,  0.4882,
-        -0.0600,  0.0000, -0.2669,  0.0000,  0.0000, -0.2331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1894, -0.2981,  0.0954,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4509,  0.0000,  0.0000, -0.0328,  0.0000,  0.0307],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3582,  0.0000,  0.0162, -0.0699,  2.1113,  0.0621,  0.0000,  0.0000,
-        -0.0798,  0.1103,  0.1146,  0.0000, -0.2438, -0.4255,  0.0000,  0.0000,
-         0.0000, -0.1247,  0.0000,  0.0000,  0.1472,  0.2199,  0.0000, -0.8519,
-        -0.0333,  0.0000,  0.0868,  0.2660, -0.0083,  0.0000, -0.4434, -1.2319,
-         0.0000, -0.0889,  0.0000,  0.0000,  0.0000,  0.0000, -0.0150,  0.4882,
-        -0.0600,  0.0000, -0.2669,  0.0000,  0.0000, -0.2331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1894, -0.2981,  0.0954,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4509,  0.0000,  0.0000, -0.0328,  0.0000,  0.0307],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6520e-01, -4.7544e-12,  2.1500e-02, -6.4923e-02,  2.1108e+00,
-         6.1163e-02,  2.8779e-21, -2.2724e-14, -8.4447e-02,  1.0694e-01,
-         1.3714e-01,  5.1737e-13, -2.3451e-01, -4.2682e-01, -5.1609e-18,
-        -3.8358e-15, -3.7731e-16, -1.1894e-01, -8.7269e-19,  1.0493e-14,
-         1.3941e-01,  2.1873e-01, -1.8533e-18, -8.5362e-01, -4.1685e-02,
-         1.5041e-14,  9.3821e-02,  2.6947e-01, -8.8483e-03,  2.6953e-18,
-        -4.5522e-01, -1.2337e+00, -6.8994e-22, -1.0540e-01,  3.4246e-16,
-         1.8478e-14,  1.1424e-17,  0.0000e+00,  7.3369e-03,  4.8980e-01,
-        -4.4973e-02, -1.5258e-15, -2.6181e-01,  2.8114e-18,  6.8765e-16,
-        -2.3373e-01, -9.7025e-14,  3.1307e-07,  3.2639e-13, -1.7898e-19,
-        -5.7910e-10,  1.7738e-01, -2.7877e-01,  1.1228e-01, -6.5588e-11,
-         3.7249e-23, -1.4640e-17,  5.2649e-21,  4.4313e-01, -7.6037e-13,
-         1.4064e-20, -2.8914e-02,  1.1816e-08,  4.3356e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3652,  0.0000,  0.0215, -0.0649,  2.1108,  0.0612,  0.0000,  0.0000,
-        -0.0844,  0.1069,  0.1371,  0.0000, -0.2345, -0.4268,  0.0000,  0.0000,
-         0.0000, -0.1189,  0.0000,  0.0000,  0.1394,  0.2187,  0.0000, -0.8536,
-        -0.0417,  0.0000,  0.0938,  0.2695, -0.0088,  0.0000, -0.4552, -1.2337,
-         0.0000, -0.1054,  0.0000,  0.0000,  0.0000,  0.0000,  0.0073,  0.4898,
-        -0.0450,  0.0000, -0.2618,  0.0000,  0.0000, -0.2337,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1774, -0.2788,  0.1123,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4431,  0.0000,  0.0000, -0.0289,  0.0000,  0.0434],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3652,  0.0000,  0.0215, -0.0649,  2.1108,  0.0612,  0.0000,  0.0000,
-        -0.0844,  0.1069,  0.1371,  0.0000, -0.2345, -0.4268,  0.0000,  0.0000,
-         0.0000, -0.1189,  0.0000,  0.0000,  0.1394,  0.2187,  0.0000, -0.8536,
-        -0.0417,  0.0000,  0.0938,  0.2695, -0.0088,  0.0000, -0.4552, -1.2337,
-         0.0000, -0.1054,  0.0000,  0.0000,  0.0000,  0.0000,  0.0073,  0.4898,
-        -0.0450,  0.0000, -0.2618,  0.0000,  0.0000, -0.2337,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1774, -0.2788,  0.1123,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4431,  0.0000,  0.0000, -0.0289,  0.0000,  0.0434],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6910e-01, -4.3552e-12,  2.1715e-02, -5.9685e-02,  2.1101e+00,
-         5.6170e-02,  2.6363e-21, -2.0816e-14, -9.2591e-02,  1.0353e-01,
-         1.6875e-01,  4.7393e-13, -2.3019e-01, -4.2299e-01, -4.7276e-18,
-        -3.5137e-15, -3.4563e-16, -1.2259e-01, -7.9942e-19,  9.6116e-15,
-         1.2646e-01,  2.2059e-01, -1.6977e-18, -8.5743e-01, -4.4427e-02,
-         1.3778e-14,  9.2554e-02,  2.6941e-01, -1.2206e-02,  2.4690e-18,
-        -4.6901e-01, -1.2362e+00, -6.3202e-22, -1.2213e-01,  3.1370e-16,
-         1.6926e-14,  1.0464e-17,  0.0000e+00,  3.4122e-02,  4.8848e-01,
-        -2.2763e-02, -1.3977e-15, -2.5805e-01,  2.5754e-18,  6.2991e-16,
-        -2.3063e-01, -8.8879e-14,  2.8679e-07,  2.9899e-13, -1.6395e-19,
-        -5.3048e-10,  1.6694e-01, -2.4919e-01,  1.1693e-01, -6.0081e-11,
-         3.4122e-23, -1.3410e-17,  4.8229e-21,  4.3704e-01, -6.9653e-13,
-         1.2883e-20, -3.0326e-02,  1.0824e-08,  5.5178e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3691,  0.0000,  0.0217, -0.0597,  2.1101,  0.0562,  0.0000,  0.0000,
-        -0.0926,  0.1035,  0.1687,  0.0000, -0.2302, -0.4230,  0.0000,  0.0000,
-         0.0000, -0.1226,  0.0000,  0.0000,  0.1265,  0.2206,  0.0000, -0.8574,
-        -0.0444,  0.0000,  0.0926,  0.2694, -0.0122,  0.0000, -0.4690, -1.2362,
-         0.0000, -0.1221,  0.0000,  0.0000,  0.0000,  0.0000,  0.0341,  0.4885,
-        -0.0228,  0.0000, -0.2580,  0.0000,  0.0000, -0.2306,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1669, -0.2492,  0.1169,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4370,  0.0000,  0.0000, -0.0303,  0.0000,  0.0552],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3691,  0.0000,  0.0217, -0.0597,  2.1101,  0.0562,  0.0000,  0.0000,
-        -0.0926,  0.1035,  0.1687,  0.0000, -0.2302, -0.4230,  0.0000,  0.0000,
-         0.0000, -0.1226,  0.0000,  0.0000,  0.1265,  0.2206,  0.0000, -0.8574,
-        -0.0444,  0.0000,  0.0926,  0.2694, -0.0122,  0.0000, -0.4690, -1.2362,
-         0.0000, -0.1221,  0.0000,  0.0000,  0.0000,  0.0000,  0.0341,  0.4885,
-        -0.0228,  0.0000, -0.2580,  0.0000,  0.0000, -0.2306,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1669, -0.2492,  0.1169,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4370,  0.0000,  0.0000, -0.0303,  0.0000,  0.0552],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7135e-01, -3.9890e-12,  2.0851e-02, -5.3417e-02,  2.1096e+00,
-         4.9808e-02,  2.4146e-21, -1.9066e-14, -1.0088e-01,  1.0150e-01,
-         2.0725e-01,  4.3408e-13, -2.2730e-01, -4.1900e-01, -4.3301e-18,
-        -3.2183e-15, -3.1657e-16, -1.2873e-01, -7.3220e-19,  8.8034e-15,
-         1.0767e-01,  2.1667e-01, -1.5549e-18, -8.6378e-01, -4.4602e-02,
-         1.2620e-14,  8.9562e-02,  2.6228e-01, -1.8857e-02,  2.2614e-18,
-        -4.8378e-01, -1.2392e+00, -5.7887e-22, -1.3914e-01,  2.8733e-16,
-         1.5503e-14,  9.5845e-18,  0.0000e+00,  6.2128e-02,  4.8801e-01,
-        -3.9472e-03, -1.2802e-15, -2.5432e-01,  2.3588e-18,  5.7695e-16,
-        -2.2498e-01, -8.1405e-14,  2.6267e-07,  2.7385e-13, -1.5017e-19,
-        -4.8587e-10,  1.4925e-01, -2.1477e-01,  1.0855e-01, -5.5029e-11,
-         3.1252e-23, -1.2283e-17,  4.4173e-21,  4.3140e-01, -6.3796e-13,
-         1.1800e-20, -3.5841e-02,  9.9140e-09,  6.8748e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3714,  0.0000,  0.0209, -0.0534,  2.1096,  0.0498,  0.0000,  0.0000,
-        -0.1009,  0.1015,  0.2072,  0.0000, -0.2273, -0.4190,  0.0000,  0.0000,
-         0.0000, -0.1287,  0.0000,  0.0000,  0.1077,  0.2167,  0.0000, -0.8638,
-        -0.0446,  0.0000,  0.0896,  0.2623, -0.0189,  0.0000, -0.4838, -1.2392,
-         0.0000, -0.1391,  0.0000,  0.0000,  0.0000,  0.0000,  0.0621,  0.4880,
-        -0.0039,  0.0000, -0.2543,  0.0000,  0.0000, -0.2250,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1493, -0.2148,  0.1086,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4314,  0.0000,  0.0000, -0.0358,  0.0000,  0.0687],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3714,  0.0000,  0.0209, -0.0534,  2.1096,  0.0498,  0.0000,  0.0000,
-        -0.1009,  0.1015,  0.2072,  0.0000, -0.2273, -0.4190,  0.0000,  0.0000,
-         0.0000, -0.1287,  0.0000,  0.0000,  0.1077,  0.2167,  0.0000, -0.8638,
-        -0.0446,  0.0000,  0.0896,  0.2623, -0.0189,  0.0000, -0.4838, -1.2392,
-         0.0000, -0.1391,  0.0000,  0.0000,  0.0000,  0.0000,  0.0621,  0.4880,
-        -0.0039,  0.0000, -0.2543,  0.0000,  0.0000, -0.2250,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1493, -0.2148,  0.1086,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4314,  0.0000,  0.0000, -0.0358,  0.0000,  0.0687],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7358e-01, -3.6530e-12,  1.8609e-02, -4.6791e-02,  2.1089e+00,
-         4.0635e-02,  2.2112e-21, -1.7460e-14, -1.0792e-01,  1.0170e-01,
-         2.4187e-01,  3.9752e-13, -2.2301e-01, -4.1660e-01, -3.9654e-18,
-        -2.9472e-15, -2.8990e-16, -1.3082e-01, -6.7053e-19,  8.0620e-15,
-         8.7363e-02,  2.1270e-01, -1.4240e-18, -8.6986e-01, -4.4197e-02,
-         1.1557e-14,  8.3915e-02,  2.5128e-01, -2.5102e-02,  2.0709e-18,
-        -4.9764e-01, -1.2423e+00, -5.3012e-22, -1.5288e-01,  2.6313e-16,
-         1.4197e-14,  8.7772e-18,  0.0000e+00,  8.4237e-02,  4.8693e-01,
-         1.6294e-02, -1.1723e-15, -2.5130e-01,  2.1602e-18,  5.2835e-16,
-        -2.1971e-01, -7.4549e-14,  2.4055e-07,  2.5078e-13, -1.3752e-19,
-        -4.4495e-10,  1.3297e-01, -1.8000e-01,  9.9593e-02, -5.0394e-11,
-         2.8620e-23, -1.1248e-17,  4.0453e-21,  4.2654e-01, -5.8423e-13,
-         1.0806e-20, -4.3850e-02,  9.0790e-09,  7.9838e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3736,  0.0000,  0.0186, -0.0468,  2.1089,  0.0406,  0.0000,  0.0000,
-        -0.1079,  0.1017,  0.2419,  0.0000, -0.2230, -0.4166,  0.0000,  0.0000,
-         0.0000, -0.1308,  0.0000,  0.0000,  0.0874,  0.2127,  0.0000, -0.8699,
-        -0.0442,  0.0000,  0.0839,  0.2513, -0.0251,  0.0000, -0.4976, -1.2423,
-         0.0000, -0.1529,  0.0000,  0.0000,  0.0000,  0.0000,  0.0842,  0.4869,
-         0.0163,  0.0000, -0.2513,  0.0000,  0.0000, -0.2197,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1330, -0.1800,  0.0996,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4265,  0.0000,  0.0000, -0.0439,  0.0000,  0.0798],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3736,  0.0000,  0.0186, -0.0468,  2.1089,  0.0406,  0.0000,  0.0000,
-        -0.1079,  0.1017,  0.2419,  0.0000, -0.2230, -0.4166,  0.0000,  0.0000,
-         0.0000, -0.1308,  0.0000,  0.0000,  0.0874,  0.2127,  0.0000, -0.8699,
-        -0.0442,  0.0000,  0.0839,  0.2513, -0.0251,  0.0000, -0.4976, -1.2423,
-         0.0000, -0.1529,  0.0000,  0.0000,  0.0000,  0.0000,  0.0842,  0.4869,
-         0.0163,  0.0000, -0.2513,  0.0000,  0.0000, -0.2197,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1330, -0.1800,  0.0996,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4265,  0.0000,  0.0000, -0.0439,  0.0000,  0.0798],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7573e-01, -3.3448e-12,  1.7836e-02, -4.1880e-02,  2.1080e+00,
-         3.2902e-02,  2.0247e-21, -1.5987e-14, -1.1348e-01,  9.5617e-02,
-         2.7301e-01,  3.6398e-13, -2.1792e-01, -4.1456e-01, -3.6308e-18,
-        -2.6985e-15, -2.6544e-16, -1.2751e-01, -6.1395e-19,  7.3817e-15,
-         6.8378e-02,  2.1118e-01, -1.3038e-18, -8.7556e-01, -4.1638e-02,
-         1.0582e-14,  7.7475e-02,  2.3927e-01, -3.2292e-02,  1.8962e-18,
-        -5.1222e-01, -1.2455e+00, -4.8539e-22, -1.6532e-01,  2.4092e-16,
-         1.2999e-14,  8.0366e-18,  0.0000e+00,  9.7730e-02,  4.8370e-01,
-         3.3704e-02, -1.0734e-15, -2.4930e-01,  1.9779e-18,  4.8377e-16,
-        -2.2012e-01, -6.8259e-14,  2.2025e-07,  2.2962e-13, -1.2592e-19,
-        -4.0741e-10,  1.2212e-01, -1.5391e-01,  9.4415e-02, -4.6142e-11,
-         2.6205e-23, -1.0299e-17,  3.7039e-21,  4.2129e-01, -5.3493e-13,
-         9.8941e-21, -5.4038e-02,  8.3130e-09,  8.7099e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3757,  0.0000,  0.0178, -0.0419,  2.1080,  0.0329,  0.0000,  0.0000,
-        -0.1135,  0.0956,  0.2730,  0.0000, -0.2179, -0.4146,  0.0000,  0.0000,
-         0.0000, -0.1275,  0.0000,  0.0000,  0.0684,  0.2112,  0.0000, -0.8756,
-        -0.0416,  0.0000,  0.0775,  0.2393, -0.0323,  0.0000, -0.5122, -1.2455,
-         0.0000, -0.1653,  0.0000,  0.0000,  0.0000,  0.0000,  0.0977,  0.4837,
-         0.0337,  0.0000, -0.2493,  0.0000,  0.0000, -0.2201,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1221, -0.1539,  0.0944,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4213,  0.0000,  0.0000, -0.0540,  0.0000,  0.0871],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3757,  0.0000,  0.0178, -0.0419,  2.1080,  0.0329,  0.0000,  0.0000,
-        -0.1135,  0.0956,  0.2730,  0.0000, -0.2179, -0.4146,  0.0000,  0.0000,
-         0.0000, -0.1275,  0.0000,  0.0000,  0.0684,  0.2112,  0.0000, -0.8756,
-        -0.0416,  0.0000,  0.0775,  0.2393, -0.0323,  0.0000, -0.5122, -1.2455,
-         0.0000, -0.1653,  0.0000,  0.0000,  0.0000,  0.0000,  0.0977,  0.4837,
-         0.0337,  0.0000, -0.2493,  0.0000,  0.0000, -0.2201,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1221, -0.1539,  0.0944,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4213,  0.0000,  0.0000, -0.0540,  0.0000,  0.0871],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7569e-01, -3.0620e-12,  1.6142e-02, -3.8454e-02,  2.1068e+00,
-         2.7333e-02,  1.8535e-21, -1.4635e-14, -1.2173e-01,  8.6506e-02,
-         2.9220e-01,  3.3321e-13, -2.1504e-01, -4.1113e-01, -3.3238e-18,
-        -2.4704e-15, -2.4300e-16, -1.2720e-01, -5.6204e-19,  6.7576e-15,
-         5.0116e-02,  2.0853e-01, -1.1936e-18, -8.7850e-01, -3.9427e-02,
-         9.6871e-15,  7.1011e-02,  2.2695e-01, -3.7943e-02,  1.7359e-18,
-        -5.2447e-01, -1.2479e+00, -4.4435e-22, -1.7194e-01,  2.2055e-16,
-         1.1900e-14,  7.3572e-18,  0.0000e+00,  1.0602e-01,  4.7742e-01,
-         4.4115e-02, -9.8267e-16, -2.4756e-01,  1.8107e-18,  4.4287e-16,
-        -2.2511e-01, -6.2488e-14,  2.0163e-07,  2.1021e-13, -1.1527e-19,
-        -3.7296e-10,  1.1441e-01, -1.3893e-01,  8.4395e-02, -4.2241e-11,
-         2.3990e-23, -9.4284e-18,  3.3908e-21,  4.1846e-01, -4.8970e-13,
-         9.0576e-21, -6.3752e-02,  7.6102e-09,  9.0301e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3757,  0.0000,  0.0161, -0.0385,  2.1068,  0.0273,  0.0000,  0.0000,
-        -0.1217,  0.0865,  0.2922,  0.0000, -0.2150, -0.4111,  0.0000,  0.0000,
-         0.0000, -0.1272,  0.0000,  0.0000,  0.0501,  0.2085,  0.0000, -0.8785,
-        -0.0394,  0.0000,  0.0710,  0.2270, -0.0379,  0.0000, -0.5245, -1.2479,
-         0.0000, -0.1719,  0.0000,  0.0000,  0.0000,  0.0000,  0.1060,  0.4774,
-         0.0441,  0.0000, -0.2476,  0.0000,  0.0000, -0.2251,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1144, -0.1389,  0.0844,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4185,  0.0000,  0.0000, -0.0638,  0.0000,  0.0903],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3757,  0.0000,  0.0161, -0.0385,  2.1068,  0.0273,  0.0000,  0.0000,
-        -0.1217,  0.0865,  0.2922,  0.0000, -0.2150, -0.4111,  0.0000,  0.0000,
-         0.0000, -0.1272,  0.0000,  0.0000,  0.0501,  0.2085,  0.0000, -0.8785,
-        -0.0394,  0.0000,  0.0710,  0.2270, -0.0379,  0.0000, -0.5245, -1.2479,
-         0.0000, -0.1719,  0.0000,  0.0000,  0.0000,  0.0000,  0.1060,  0.4774,
-         0.0441,  0.0000, -0.2476,  0.0000,  0.0000, -0.2251,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1144, -0.1389,  0.0844,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4185,  0.0000,  0.0000, -0.0638,  0.0000,  0.0903],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7524e-01, -2.8026e-12,  1.4143e-02, -3.5058e-02,  2.1059e+00,
-         2.3657e-02,  1.6965e-21, -1.3395e-14, -1.3056e-01,  8.3285e-02,
-         2.9900e-01,  3.0498e-13, -2.1473e-01, -4.0783e-01, -3.0422e-18,
-        -2.2611e-15, -2.2241e-16, -1.2500e-01, -5.1442e-19,  6.1851e-15,
-         3.4761e-02,  2.0312e-01, -1.0924e-18, -8.7960e-01, -4.1105e-02,
-         8.8664e-15,  6.6773e-02,  2.1438e-01, -4.0827e-02,  1.5888e-18,
-        -5.3355e-01, -1.2501e+00, -4.0670e-22, -1.7092e-01,  2.0187e-16,
-         1.0892e-14,  6.7339e-18,  0.0000e+00,  1.1793e-01,  4.7404e-01,
-         4.6601e-02, -8.9942e-16, -2.4722e-01,  1.6573e-18,  4.0535e-16,
-        -2.2978e-01, -5.7194e-14,  1.8455e-07,  1.9240e-13, -1.0550e-19,
-        -3.4136e-10,  1.0309e-01, -1.2520e-01,  7.3557e-02, -3.8662e-11,
-         2.1957e-23, -8.6296e-18,  3.1035e-21,  4.1526e-01, -4.4821e-13,
-         8.2902e-21, -7.1553e-02,  6.9654e-09,  9.1182e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3752,  0.0000,  0.0141, -0.0351,  2.1059,  0.0237,  0.0000,  0.0000,
-        -0.1306,  0.0833,  0.2990,  0.0000, -0.2147, -0.4078,  0.0000,  0.0000,
-         0.0000, -0.1250,  0.0000,  0.0000,  0.0348,  0.2031,  0.0000, -0.8796,
-        -0.0411,  0.0000,  0.0668,  0.2144, -0.0408,  0.0000, -0.5335, -1.2501,
-         0.0000, -0.1709,  0.0000,  0.0000,  0.0000,  0.0000,  0.1179,  0.4740,
-         0.0466,  0.0000, -0.2472,  0.0000,  0.0000, -0.2298,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1031, -0.1252,  0.0736,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4153,  0.0000,  0.0000, -0.0716,  0.0000,  0.0912],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3752,  0.0000,  0.0141, -0.0351,  2.1059,  0.0237,  0.0000,  0.0000,
-        -0.1306,  0.0833,  0.2990,  0.0000, -0.2147, -0.4078,  0.0000,  0.0000,
-         0.0000, -0.1250,  0.0000,  0.0000,  0.0348,  0.2031,  0.0000, -0.8796,
-        -0.0411,  0.0000,  0.0668,  0.2144, -0.0408,  0.0000, -0.5335, -1.2501,
-         0.0000, -0.1709,  0.0000,  0.0000,  0.0000,  0.0000,  0.1179,  0.4740,
-         0.0466,  0.0000, -0.2472,  0.0000,  0.0000, -0.2298,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1031, -0.1252,  0.0736,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4153,  0.0000,  0.0000, -0.0716,  0.0000,  0.0912],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7258e-01, -2.5646e-12,  9.7755e-03, -3.1829e-02,  2.1048e+00,
-         2.1110e-02,  1.5524e-21, -1.2258e-14, -1.3855e-01,  8.2203e-02,
-         2.9882e-01,  2.7908e-13, -2.1846e-01, -4.0462e-01, -2.7839e-18,
-        -2.0691e-15, -2.0353e-16, -1.1714e-01, -4.7074e-19,  5.6599e-15,
-         2.2466e-02,  1.9989e-01, -9.9968e-19, -8.8077e-01, -4.0946e-02,
-         8.1135e-15,  6.2548e-02,  2.0300e-01, -4.1858e-02,  1.4539e-18,
-        -5.4371e-01, -1.2525e+00, -3.7217e-22, -1.6575e-01,  1.8473e-16,
-         9.9671e-15,  6.1620e-18,  0.0000e+00,  1.1894e-01,  4.6526e-01,
-         4.6520e-02, -8.2304e-16, -2.4955e-01,  1.5165e-18,  3.7093e-16,
-        -2.3612e-01, -5.2337e-14,  1.6888e-07,  1.7606e-13, -9.6545e-20,
-        -3.1238e-10,  9.1299e-02, -1.1017e-01,  6.5578e-02, -3.5379e-11,
-         2.0093e-23, -7.8968e-18,  2.8400e-21,  4.1369e-01, -4.1015e-13,
-         7.5862e-21, -7.9332e-02,  6.3739e-09,  9.1857e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3726,  0.0000,  0.0098, -0.0318,  2.1048,  0.0211,  0.0000,  0.0000,
-        -0.1385,  0.0822,  0.2988,  0.0000, -0.2185, -0.4046,  0.0000,  0.0000,
-         0.0000, -0.1171,  0.0000,  0.0000,  0.0225,  0.1999,  0.0000, -0.8808,
-        -0.0409,  0.0000,  0.0625,  0.2030, -0.0419,  0.0000, -0.5437, -1.2525,
-         0.0000, -0.1658,  0.0000,  0.0000,  0.0000,  0.0000,  0.1189,  0.4653,
-         0.0465,  0.0000, -0.2495,  0.0000,  0.0000, -0.2361,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0913, -0.1102,  0.0656,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4137,  0.0000,  0.0000, -0.0793,  0.0000,  0.0919],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3726,  0.0000,  0.0098, -0.0318,  2.1048,  0.0211,  0.0000,  0.0000,
-        -0.1385,  0.0822,  0.2988,  0.0000, -0.2185, -0.4046,  0.0000,  0.0000,
-         0.0000, -0.1171,  0.0000,  0.0000,  0.0225,  0.1999,  0.0000, -0.8808,
-        -0.0409,  0.0000,  0.0625,  0.2030, -0.0419,  0.0000, -0.5437, -1.2525,
-         0.0000, -0.1658,  0.0000,  0.0000,  0.0000,  0.0000,  0.1189,  0.4653,
-         0.0465,  0.0000, -0.2495,  0.0000,  0.0000, -0.2361,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0913, -0.1102,  0.0656,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4137,  0.0000,  0.0000, -0.0793,  0.0000,  0.0919],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6953e-01, -2.3463e-12,  5.5697e-03, -3.0639e-02,  2.1037e+00,
-         1.7444e-02,  1.4202e-21, -1.1214e-14, -1.4510e-01,  7.6353e-02,
-         2.8507e-01,  2.5532e-13, -2.2070e-01, -4.0172e-01, -2.5469e-18,
-        -1.8929e-15, -1.8620e-16, -1.0801e-01, -4.3067e-19,  5.1780e-15,
-         1.2969e-02,  1.9565e-01, -9.1457e-19, -8.7833e-01, -4.0853e-02,
-         7.4228e-15,  5.5642e-02,  1.9201e-01, -3.9804e-02,  1.3301e-18,
-        -5.5064e-01, -1.2545e+00, -3.4048e-22, -1.5999e-01,  1.6900e-16,
-         9.1186e-15,  5.6374e-18,  0.0000e+00,  1.2044e-01,  4.5878e-01,
-         4.1836e-02, -7.5297e-16, -2.5341e-01,  1.3874e-18,  3.3935e-16,
-        -2.4242e-01, -4.7881e-14,  1.5450e-07,  1.6107e-13, -8.8326e-20,
-        -2.8578e-10,  8.2796e-02, -1.0401e-01,  5.9458e-02, -3.2367e-11,
-         1.8382e-23, -7.2245e-18,  2.5982e-21,  4.1391e-01, -3.7524e-13,
-         6.9404e-21, -8.6162e-02,  5.8313e-09,  8.6706e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3695,  0.0000,  0.0056, -0.0306,  2.1037,  0.0174,  0.0000,  0.0000,
-        -0.1451,  0.0764,  0.2851,  0.0000, -0.2207, -0.4017,  0.0000,  0.0000,
-         0.0000, -0.1080,  0.0000,  0.0000,  0.0130,  0.1956,  0.0000, -0.8783,
-        -0.0409,  0.0000,  0.0556,  0.1920, -0.0398,  0.0000, -0.5506, -1.2545,
-         0.0000, -0.1600,  0.0000,  0.0000,  0.0000,  0.0000,  0.1204,  0.4588,
-         0.0418,  0.0000, -0.2534,  0.0000,  0.0000, -0.2424,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0828, -0.1040,  0.0595,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4139,  0.0000,  0.0000, -0.0862,  0.0000,  0.0867],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3695,  0.0000,  0.0056, -0.0306,  2.1037,  0.0174,  0.0000,  0.0000,
-        -0.1451,  0.0764,  0.2851,  0.0000, -0.2207, -0.4017,  0.0000,  0.0000,
-         0.0000, -0.1080,  0.0000,  0.0000,  0.0130,  0.1956,  0.0000, -0.8783,
-        -0.0409,  0.0000,  0.0556,  0.1920, -0.0398,  0.0000, -0.5506, -1.2545,
-         0.0000, -0.1600,  0.0000,  0.0000,  0.0000,  0.0000,  0.1204,  0.4588,
-         0.0418,  0.0000, -0.2534,  0.0000,  0.0000, -0.2424,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0828, -0.1040,  0.0595,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4139,  0.0000,  0.0000, -0.0862,  0.0000,  0.0867],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6481e-01, -2.1460e-12,  2.4234e-03, -2.4726e-02,  2.1025e+00,
-         1.7278e-02,  1.2990e-21, -1.0257e-14, -1.4834e-01,  6.8507e-02,
-         2.6214e-01,  2.3353e-13, -2.2672e-01, -3.9673e-01, -2.3295e-18,
-        -1.7314e-15, -1.7031e-16, -1.0428e-01, -3.9390e-19,  4.7360e-15,
-         2.5478e-03,  1.9256e-01, -8.3651e-19, -8.7571e-01, -3.8005e-02,
-         6.7892e-15,  5.2835e-02,  1.7717e-01, -4.0167e-02,  1.2166e-18,
-        -5.5514e-01, -1.2565e+00, -3.1142e-22, -1.5315e-01,  1.5457e-16,
-         8.3402e-15,  5.1562e-18,  0.0000e+00,  1.1818e-01,  4.5012e-01,
-         2.9829e-02, -6.8870e-16, -2.5780e-01,  1.2690e-18,  3.1038e-16,
-        -2.4680e-01, -4.3794e-14,  1.4131e-07,  1.4732e-13, -8.0787e-20,
-        -2.6139e-10,  7.8275e-02, -1.0511e-01,  5.2340e-02, -2.9605e-11,
-         1.6813e-23, -6.6078e-18,  2.3764e-21,  4.1634e-01, -3.4321e-13,
-         6.3480e-21, -9.3164e-02,  5.3335e-09,  8.0490e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3648,  0.0000,  0.0024, -0.0247,  2.1025,  0.0173,  0.0000,  0.0000,
-        -0.1483,  0.0685,  0.2621,  0.0000, -0.2267, -0.3967,  0.0000,  0.0000,
-         0.0000, -0.1043,  0.0000,  0.0000,  0.0025,  0.1926,  0.0000, -0.8757,
-        -0.0380,  0.0000,  0.0528,  0.1772, -0.0402,  0.0000, -0.5551, -1.2565,
-         0.0000, -0.1531,  0.0000,  0.0000,  0.0000,  0.0000,  0.1182,  0.4501,
-         0.0298,  0.0000, -0.2578,  0.0000,  0.0000, -0.2468,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0783, -0.1051,  0.0523,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4163,  0.0000,  0.0000, -0.0932,  0.0000,  0.0805],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3648,  0.0000,  0.0024, -0.0247,  2.1025,  0.0173,  0.0000,  0.0000,
-        -0.1483,  0.0685,  0.2621,  0.0000, -0.2267, -0.3967,  0.0000,  0.0000,
-         0.0000, -0.1043,  0.0000,  0.0000,  0.0025,  0.1926,  0.0000, -0.8757,
-        -0.0380,  0.0000,  0.0528,  0.1772, -0.0402,  0.0000, -0.5551, -1.2565,
-         0.0000, -0.1531,  0.0000,  0.0000,  0.0000,  0.0000,  0.1182,  0.4501,
-         0.0298,  0.0000, -0.2578,  0.0000,  0.0000, -0.2468,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0783, -0.1051,  0.0523,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4163,  0.0000,  0.0000, -0.0932,  0.0000,  0.0805],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5820e-01, -1.9623e-12,  3.2151e-04, -1.9326e-02,  2.1013e+00,
-         2.0961e-02,  1.1878e-21, -9.3791e-15, -1.5247e-01,  5.6304e-02,
-         2.3008e-01,  2.1353e-13, -2.3446e-01, -3.9320e-01, -2.1301e-18,
-        -1.5832e-15, -1.5573e-16, -9.6629e-02, -3.6019e-19,  4.3306e-15,
-        -3.7235e-03,  1.8604e-01, -7.6490e-19, -8.7223e-01, -3.9777e-02,
-         6.2080e-15,  5.0239e-02,  1.6444e-01, -3.5931e-02,  1.1124e-18,
-        -5.5711e-01, -1.2577e+00, -2.8476e-22, -1.4508e-01,  1.4134e-16,
-         7.6263e-15,  4.7149e-18,  0.0000e+00,  1.1444e-01,  4.4116e-01,
-         1.8849e-02, -6.2975e-16, -2.6078e-01,  1.1604e-18,  2.8381e-16,
-        -2.5063e-01, -4.0045e-14,  1.2922e-07,  1.3471e-13, -7.3871e-20,
-        -2.3901e-10,  7.1437e-02, -1.1015e-01,  4.4962e-02, -2.7070e-11,
-         1.5374e-23, -6.0422e-18,  2.1730e-21,  4.2008e-01, -3.1383e-13,
-         5.8046e-21, -9.9873e-02,  4.8770e-09,  6.8962e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5820e-01,  0.0000e+00,  3.2151e-04, -1.9326e-02,  2.1013e+00,
-         2.0961e-02,  0.0000e+00,  0.0000e+00, -1.5247e-01,  5.6304e-02,
-         2.3008e-01,  0.0000e+00, -2.3446e-01, -3.9320e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -9.6629e-02,  0.0000e+00,  0.0000e+00,
-        -3.7235e-03,  1.8604e-01,  0.0000e+00, -8.7223e-01, -3.9777e-02,
-         0.0000e+00,  5.0239e-02,  1.6444e-01, -3.5931e-02,  0.0000e+00,
-        -5.5711e-01, -1.2577e+00,  0.0000e+00, -1.4508e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1444e-01,  4.4116e-01,
-         1.8849e-02,  0.0000e+00, -2.6078e-01,  0.0000e+00,  0.0000e+00,
-        -2.5063e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.1437e-02, -1.1015e-01,  4.4962e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2008e-01,  0.0000e+00,
-         0.0000e+00, -9.9873e-02,  0.0000e+00,  6.8962e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5820e-01,  0.0000e+00,  3.2151e-04, -1.9326e-02,  2.1013e+00,
-         2.0961e-02,  0.0000e+00,  0.0000e+00, -1.5247e-01,  5.6304e-02,
-         2.3008e-01,  0.0000e+00, -2.3446e-01, -3.9320e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -9.6629e-02,  0.0000e+00,  0.0000e+00,
-        -3.7235e-03,  1.8604e-01,  0.0000e+00, -8.7223e-01, -3.9777e-02,
-         0.0000e+00,  5.0239e-02,  1.6444e-01, -3.5931e-02,  0.0000e+00,
-        -5.5711e-01, -1.2577e+00,  0.0000e+00, -1.4508e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1444e-01,  4.4116e-01,
-         1.8849e-02,  0.0000e+00, -2.6078e-01,  0.0000e+00,  0.0000e+00,
-        -2.5063e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.1437e-02, -1.1015e-01,  4.4962e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2008e-01,  0.0000e+00,
-         0.0000e+00, -9.9873e-02,  0.0000e+00,  6.8962e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5197e-01, -1.7938e-12, -7.5829e-04, -1.5311e-02,  2.1003e+00,
-         2.5982e-02,  1.0858e-21, -8.5738e-15, -1.5766e-01,  4.6765e-02,
-         1.9186e-01,  1.9520e-13, -2.4244e-01, -3.9172e-01, -1.9472e-18,
-        -1.4472e-15, -1.4236e-16, -8.4226e-02, -3.2926e-19,  3.9588e-15,
-        -9.0963e-03,  1.8219e-01, -6.9922e-19, -8.6863e-01, -4.0227e-02,
-         5.6750e-15,  4.9138e-02,  1.5414e-01, -3.2122e-02,  1.0169e-18,
-        -5.5699e-01, -1.2586e+00, -2.6031e-22, -1.3206e-01,  1.2921e-16,
-         6.9715e-15,  4.3100e-18,  0.0000e+00,  1.1167e-01,  4.3562e-01,
-         8.8359e-03, -5.7567e-16, -2.6368e-01,  1.0607e-18,  2.5944e-16,
-        -2.5368e-01, -3.6607e-14,  1.1812e-07,  1.2315e-13, -6.7529e-20,
-        -2.1849e-10,  6.6593e-02, -1.1932e-01,  4.4862e-02, -2.4746e-11,
-         1.4054e-23, -5.5234e-18,  1.9864e-21,  4.2456e-01, -2.8688e-13,
-         5.3062e-21, -1.0455e-01,  4.4582e-09,  5.4885e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5197e-01,  0.0000e+00, -7.5829e-04, -1.5311e-02,  2.1003e+00,
-         2.5982e-02,  0.0000e+00,  0.0000e+00, -1.5766e-01,  4.6765e-02,
-         1.9186e-01,  0.0000e+00, -2.4244e-01, -3.9172e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -8.4226e-02,  0.0000e+00,  0.0000e+00,
-        -9.0963e-03,  1.8219e-01,  0.0000e+00, -8.6863e-01, -4.0227e-02,
-         0.0000e+00,  4.9138e-02,  1.5414e-01, -3.2122e-02,  0.0000e+00,
-        -5.5699e-01, -1.2586e+00,  0.0000e+00, -1.3206e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1167e-01,  4.3562e-01,
-         8.8359e-03,  0.0000e+00, -2.6368e-01,  0.0000e+00,  0.0000e+00,
-        -2.5368e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.6593e-02, -1.1932e-01,  4.4862e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2456e-01,  0.0000e+00,
-         0.0000e+00, -1.0455e-01,  0.0000e+00,  5.4885e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5197e-01,  0.0000e+00, -7.5829e-04, -1.5311e-02,  2.1003e+00,
-         2.5982e-02,  0.0000e+00,  0.0000e+00, -1.5766e-01,  4.6765e-02,
-         1.9186e-01,  0.0000e+00, -2.4244e-01, -3.9172e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -8.4226e-02,  0.0000e+00,  0.0000e+00,
-        -9.0963e-03,  1.8219e-01,  0.0000e+00, -8.6863e-01, -4.0227e-02,
-         0.0000e+00,  4.9138e-02,  1.5414e-01, -3.2122e-02,  0.0000e+00,
-        -5.5699e-01, -1.2586e+00,  0.0000e+00, -1.3206e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1167e-01,  4.3562e-01,
-         8.8359e-03,  0.0000e+00, -2.6368e-01,  0.0000e+00,  0.0000e+00,
-        -2.5368e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.6593e-02, -1.1932e-01,  4.4862e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2456e-01,  0.0000e+00,
-         0.0000e+00, -1.0455e-01,  0.0000e+00,  5.4885e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4443e-01, -1.6393e-12, -2.1729e-03, -1.2307e-02,  2.0995e+00,
-         3.2250e-02,  9.9229e-22, -7.8352e-15, -1.6298e-01,  3.7598e-02,
-         1.5230e-01,  1.7839e-13, -2.5082e-01, -3.9040e-01, -1.7795e-18,
-        -1.3226e-15, -1.3009e-16, -7.3678e-02, -3.0090e-19,  3.6178e-15,
-        -1.5462e-02,  1.7988e-01, -6.3899e-19, -8.6480e-01, -3.9734e-02,
-         5.1861e-15,  4.7885e-02,  1.4311e-01, -2.7588e-02,  9.2933e-19,
-        -5.5598e-01, -1.2597e+00, -2.3789e-22, -1.1899e-01,  1.1808e-16,
-         6.3709e-15,  3.9387e-18,  0.0000e+00,  1.0430e-01,  4.2841e-01,
-         9.4905e-04, -5.2609e-16, -2.6628e-01,  9.6936e-19,  2.3710e-16,
-        -2.5953e-01, -3.3454e-14,  1.0795e-07,  1.1254e-13, -6.1712e-20,
-        -1.9967e-10,  6.3458e-02, -1.3199e-01,  4.3065e-02, -2.2614e-11,
-         1.2843e-23, -5.0476e-18,  1.8153e-21,  4.2981e-01, -2.6217e-13,
-         4.8491e-21, -1.0960e-01,  4.0742e-09,  3.9359e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.4443e-01,  0.0000e+00, -2.1729e-03, -1.2307e-02,  2.0995e+00,
-         3.2250e-02,  0.0000e+00,  0.0000e+00, -1.6298e-01,  3.7598e-02,
-         1.5230e-01,  0.0000e+00, -2.5082e-01, -3.9040e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -7.3678e-02,  0.0000e+00,  0.0000e+00,
-        -1.5462e-02,  1.7988e-01,  0.0000e+00, -8.6480e-01, -3.9734e-02,
-         0.0000e+00,  4.7885e-02,  1.4311e-01, -2.7588e-02,  0.0000e+00,
-        -5.5598e-01, -1.2597e+00,  0.0000e+00, -1.1899e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0430e-01,  4.2841e-01,
-         9.4905e-04,  0.0000e+00, -2.6628e-01,  0.0000e+00,  0.0000e+00,
-        -2.5953e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.3458e-02, -1.3199e-01,  4.3065e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2981e-01,  0.0000e+00,
-         0.0000e+00, -1.0960e-01,  0.0000e+00,  3.9359e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.4443e-01,  0.0000e+00, -2.1729e-03, -1.2307e-02,  2.0995e+00,
-         3.2250e-02,  0.0000e+00,  0.0000e+00, -1.6298e-01,  3.7598e-02,
-         1.5230e-01,  0.0000e+00, -2.5082e-01, -3.9040e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -7.3678e-02,  0.0000e+00,  0.0000e+00,
-        -1.5462e-02,  1.7988e-01,  0.0000e+00, -8.6480e-01, -3.9734e-02,
-         0.0000e+00,  4.7885e-02,  1.4311e-01, -2.7588e-02,  0.0000e+00,
-        -5.5598e-01, -1.2597e+00,  0.0000e+00, -1.1899e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0430e-01,  4.2841e-01,
-         9.4905e-04,  0.0000e+00, -2.6628e-01,  0.0000e+00,  0.0000e+00,
-        -2.5953e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.3458e-02, -1.3199e-01,  4.3065e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2981e-01,  0.0000e+00,
-         0.0000e+00, -1.0960e-01,  0.0000e+00,  3.9359e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3598e-01, -1.4976e-12, -1.5898e-03, -7.6163e-03,  2.0984e+00,
-         4.0337e-02,  9.0651e-22, -7.1580e-15, -1.6624e-01,  2.4136e-02,
-         1.1244e-01,  1.6297e-13, -2.5515e-01, -3.8654e-01, -1.6256e-18,
-        -1.2082e-15, -1.1885e-16, -6.7193e-02, -2.7489e-19,  3.3051e-15,
-        -1.9767e-02,  1.7957e-01, -5.8376e-19, -8.5995e-01, -3.6186e-02,
-         4.7378e-15,  4.4966e-02,  1.3375e-01, -2.4861e-02,  8.4900e-19,
-        -5.5384e-01, -1.2603e+00, -2.1733e-22, -1.0840e-01,  1.0787e-16,
-         5.8202e-15,  3.5983e-18,  0.0000e+00,  9.8809e-02,  4.2213e-01,
-        -5.4346e-03, -4.8061e-16, -2.6675e-01,  8.8557e-19,  2.1660e-16,
-        -2.6435e-01, -3.0562e-14,  9.8615e-08,  1.0281e-13, -5.6377e-20,
-        -1.8241e-10,  6.7144e-02, -1.4860e-01,  4.1649e-02, -2.0660e-11,
-         1.1733e-23, -4.6113e-18,  1.6584e-21,  4.3496e-01, -2.3951e-13,
-         4.4299e-21, -1.1482e-01,  3.7220e-09,  2.1484e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.3598e-01,  0.0000e+00, -1.5898e-03, -7.6163e-03,  2.0984e+00,
-         4.0337e-02,  0.0000e+00,  0.0000e+00, -1.6624e-01,  2.4136e-02,
-         1.1244e-01,  0.0000e+00, -2.5515e-01, -3.8654e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.7193e-02,  0.0000e+00,  0.0000e+00,
-        -1.9767e-02,  1.7957e-01,  0.0000e+00, -8.5995e-01, -3.6186e-02,
-         0.0000e+00,  4.4966e-02,  1.3375e-01, -2.4861e-02,  0.0000e+00,
-        -5.5384e-01, -1.2603e+00,  0.0000e+00, -1.0840e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  9.8809e-02,  4.2213e-01,
-        -5.4346e-03,  0.0000e+00, -2.6675e-01,  0.0000e+00,  0.0000e+00,
-        -2.6435e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.7144e-02, -1.4860e-01,  4.1649e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3496e-01,  0.0000e+00,
-         0.0000e+00, -1.1482e-01,  0.0000e+00,  2.1484e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.3598e-01,  0.0000e+00, -1.5898e-03, -7.6163e-03,  2.0984e+00,
-         4.0337e-02,  0.0000e+00,  0.0000e+00, -1.6624e-01,  2.4136e-02,
-         1.1244e-01,  0.0000e+00, -2.5515e-01, -3.8654e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.7193e-02,  0.0000e+00,  0.0000e+00,
-        -1.9767e-02,  1.7957e-01,  0.0000e+00, -8.5995e-01, -3.6186e-02,
-         0.0000e+00,  4.4966e-02,  1.3375e-01, -2.4861e-02,  0.0000e+00,
-        -5.5384e-01, -1.2603e+00,  0.0000e+00, -1.0840e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  9.8809e-02,  4.2213e-01,
-        -5.4346e-03,  0.0000e+00, -2.6675e-01,  0.0000e+00,  0.0000e+00,
-        -2.6435e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.7144e-02, -1.4860e-01,  4.1649e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3496e-01,  0.0000e+00,
-         0.0000e+00, -1.1482e-01,  0.0000e+00,  2.1484e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2789e-01, -1.3677e-12,  1.2657e-03, -2.4517e-03,  2.0976e+00,
-         4.9379e-02,  8.2787e-22, -6.5370e-15, -1.7040e-01,  1.2490e-02,
-         7.4740e-02,  1.4883e-13, -2.5726e-01, -3.8518e-01, -1.4846e-18,
-        -1.1034e-15, -1.0854e-16, -6.4675e-02, -2.5104e-19,  3.0183e-15,
-        -2.4024e-02,  1.7828e-01, -5.3311e-19, -8.5657e-01, -3.1500e-02,
-         4.3268e-15,  4.5143e-02,  1.2314e-01, -2.2946e-02,  7.7534e-19,
-        -5.4934e-01, -1.2608e+00, -1.9847e-22, -1.0246e-01,  9.8511e-17,
-         5.3153e-15,  3.2861e-18,  0.0000e+00,  9.4034e-02,  4.1692e-01,
-        -1.2646e-02, -4.3891e-16, -2.6501e-01,  8.0874e-19,  1.9781e-16,
-        -2.6421e-01, -2.7910e-14,  9.0059e-08,  9.3891e-14, -5.1486e-20,
-        -1.6659e-10,  7.1764e-02, -1.6459e-01,  3.8303e-02, -1.8867e-11,
-         1.0715e-23, -4.2112e-18,  1.5145e-21,  4.3869e-01, -2.1873e-13,
-         4.0456e-21, -1.2081e-01,  3.3991e-09,  4.8951e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.2789e-01,  0.0000e+00,  1.2657e-03, -2.4517e-03,  2.0976e+00,
-         4.9379e-02,  0.0000e+00,  0.0000e+00, -1.7040e-01,  1.2490e-02,
-         7.4740e-02,  0.0000e+00, -2.5726e-01, -3.8518e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.4675e-02,  0.0000e+00,  0.0000e+00,
-        -2.4024e-02,  1.7828e-01,  0.0000e+00, -8.5657e-01, -3.1500e-02,
-         0.0000e+00,  4.5143e-02,  1.2314e-01, -2.2946e-02,  0.0000e+00,
-        -5.4934e-01, -1.2608e+00,  0.0000e+00, -1.0246e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  9.4034e-02,  4.1692e-01,
-        -1.2646e-02,  0.0000e+00, -2.6501e-01,  0.0000e+00,  0.0000e+00,
-        -2.6421e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.1764e-02, -1.6459e-01,  3.8303e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3869e-01,  0.0000e+00,
-         0.0000e+00, -1.2081e-01,  0.0000e+00,  4.8951e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.2789e-01,  0.0000e+00,  1.2657e-03, -2.4517e-03,  2.0976e+00,
-         4.9379e-02,  0.0000e+00,  0.0000e+00, -1.7040e-01,  1.2490e-02,
-         7.4740e-02,  0.0000e+00, -2.5726e-01, -3.8518e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.4675e-02,  0.0000e+00,  0.0000e+00,
-        -2.4024e-02,  1.7828e-01,  0.0000e+00, -8.5657e-01, -3.1500e-02,
-         0.0000e+00,  4.5143e-02,  1.2314e-01, -2.2946e-02,  0.0000e+00,
-        -5.4934e-01, -1.2608e+00,  0.0000e+00, -1.0246e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  9.4034e-02,  4.1692e-01,
-        -1.2646e-02,  0.0000e+00, -2.6501e-01,  0.0000e+00,  0.0000e+00,
-        -2.6421e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.1764e-02, -1.6459e-01,  3.8303e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3869e-01,  0.0000e+00,
-         0.0000e+00, -1.2081e-01,  0.0000e+00,  4.8951e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1720e-01, -1.2485e-12,  3.7605e-03,  2.4614e-03,  2.0967e+00,
-         5.4613e-02,  7.5576e-22, -5.9676e-15, -1.7414e-01,  3.9396e-03,
-         5.1275e-02,  1.3587e-13, -2.6045e-01, -3.8550e-01, -1.3553e-18,
-        -1.0073e-15, -9.9084e-17, -6.2618e-02, -2.2917e-19,  2.7554e-15,
-        -2.9798e-02,  1.7863e-01, -4.8668e-19, -8.5379e-01, -2.6081e-02,
-         3.9500e-15,  4.6201e-02,  1.1365e-01, -2.2987e-02,  7.0781e-19,
-        -5.4629e-01, -1.2616e+00, -1.8119e-22, -9.7161e-02,  8.9932e-17,
-         4.8524e-15,  2.9999e-18,  0.0000e+00,  8.8029e-02,  4.1006e-01,
-        -2.1547e-02, -4.0069e-16, -2.6291e-01,  7.3830e-19,  1.8058e-16,
-        -2.6403e-01, -2.5480e-14,  8.2216e-08,  8.5713e-14, -4.7002e-20,
-        -1.5208e-10,  7.4839e-02, -1.7510e-01,  3.3851e-02, -1.7224e-11,
-         9.7819e-24, -3.8445e-18,  1.3826e-21,  4.4234e-01, -1.9968e-13,
-         3.6933e-21, -1.2717e-01,  3.1031e-09, -6.2694e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3172,  0.0000,  0.0038,  0.0025,  2.0967,  0.0546,  0.0000,  0.0000,
-        -0.1741,  0.0039,  0.0513,  0.0000, -0.2604, -0.3855,  0.0000,  0.0000,
-         0.0000, -0.0626,  0.0000,  0.0000, -0.0298,  0.1786,  0.0000, -0.8538,
-        -0.0261,  0.0000,  0.0462,  0.1137, -0.0230,  0.0000, -0.5463, -1.2616,
-         0.0000, -0.0972,  0.0000,  0.0000,  0.0000,  0.0000,  0.0880,  0.4101,
-        -0.0215,  0.0000, -0.2629,  0.0000,  0.0000, -0.2640,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0748, -0.1751,  0.0339,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4423,  0.0000,  0.0000, -0.1272,  0.0000, -0.0063],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3172,  0.0000,  0.0038,  0.0025,  2.0967,  0.0546,  0.0000,  0.0000,
-        -0.1741,  0.0039,  0.0513,  0.0000, -0.2604, -0.3855,  0.0000,  0.0000,
-         0.0000, -0.0626,  0.0000,  0.0000, -0.0298,  0.1786,  0.0000, -0.8538,
-        -0.0261,  0.0000,  0.0462,  0.1137, -0.0230,  0.0000, -0.5463, -1.2616,
-         0.0000, -0.0972,  0.0000,  0.0000,  0.0000,  0.0000,  0.0880,  0.4101,
-        -0.0215,  0.0000, -0.2629,  0.0000,  0.0000, -0.2640,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0748, -0.1751,  0.0339,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4423,  0.0000,  0.0000, -0.1272,  0.0000, -0.0063],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0896e-01, -1.1394e-12,  8.9681e-03,  4.0967e-03,  2.0957e+00,
-         5.9799e-02,  6.8967e-22, -5.4457e-15, -1.7904e-01, -5.8454e-03,
-         3.3818e-02,  1.2398e-13, -2.6278e-01, -3.8730e-01, -1.2368e-18,
-        -9.1922e-16, -9.0419e-17, -6.1595e-02, -2.0913e-19,  2.5145e-15,
-        -3.6718e-02,  1.7530e-01, -4.4412e-19, -8.5079e-01, -2.4585e-02,
-         3.6045e-15,  4.9239e-02,  1.0103e-01, -2.3903e-02,  6.4591e-19,
-        -5.4337e-01, -1.2622e+00, -1.6534e-22, -9.2914e-02,  8.2067e-17,
-         4.4280e-15,  2.7376e-18,  0.0000e+00,  8.4866e-02,  4.0620e-01,
-        -3.2085e-02, -3.6565e-16, -2.5837e-01,  6.7374e-19,  1.6479e-16,
-        -2.6700e-01, -2.3251e-14,  7.5026e-08,  7.8217e-14, -4.2892e-20,
-        -1.3878e-10,  7.1584e-02, -1.9056e-01,  2.5945e-02, -1.5718e-11,
-         8.9265e-24, -3.5082e-18,  1.2617e-21,  4.4414e-01, -1.8222e-13,
-         3.3703e-21, -1.3364e-01,  2.8317e-09, -1.4528e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3090,  0.0000,  0.0090,  0.0041,  2.0957,  0.0598,  0.0000,  0.0000,
-        -0.1790, -0.0058,  0.0338,  0.0000, -0.2628, -0.3873,  0.0000,  0.0000,
-         0.0000, -0.0616,  0.0000,  0.0000, -0.0367,  0.1753,  0.0000, -0.8508,
-        -0.0246,  0.0000,  0.0492,  0.1010, -0.0239,  0.0000, -0.5434, -1.2622,
-         0.0000, -0.0929,  0.0000,  0.0000,  0.0000,  0.0000,  0.0849,  0.4062,
-        -0.0321,  0.0000, -0.2584,  0.0000,  0.0000, -0.2670,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716, -0.1906,  0.0259,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4441,  0.0000,  0.0000, -0.1336,  0.0000, -0.0145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3090,  0.0000,  0.0090,  0.0041,  2.0957,  0.0598,  0.0000,  0.0000,
-        -0.1790, -0.0058,  0.0338,  0.0000, -0.2628, -0.3873,  0.0000,  0.0000,
-         0.0000, -0.0616,  0.0000,  0.0000, -0.0367,  0.1753,  0.0000, -0.8508,
-        -0.0246,  0.0000,  0.0492,  0.1010, -0.0239,  0.0000, -0.5434, -1.2622,
-         0.0000, -0.0929,  0.0000,  0.0000,  0.0000,  0.0000,  0.0849,  0.4062,
-        -0.0321,  0.0000, -0.2584,  0.0000,  0.0000, -0.2670,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716, -0.1906,  0.0259,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4441,  0.0000,  0.0000, -0.1336,  0.0000, -0.0145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0148e-01, -1.0393e-12,  1.4269e-02,  7.6697e-03,  2.0947e+00,
-         6.3072e-02,  6.2909e-22, -4.9674e-15, -1.8241e-01, -1.2662e-02,
-         2.2120e-02,  1.1309e-13, -2.6403e-01, -3.8822e-01, -1.1281e-18,
-        -8.3848e-16, -8.2477e-17, -6.0704e-02, -1.9076e-19,  2.2936e-15,
-        -4.3446e-02,  1.7471e-01, -4.0511e-19, -8.4693e-01, -2.0560e-02,
-         3.2879e-15,  5.1424e-02,  8.8464e-02, -2.5438e-02,  5.8918e-19,
-        -5.4152e-01, -1.2621e+00, -1.5082e-22, -8.9325e-02,  7.4859e-17,
-         4.0391e-15,  2.4971e-18,  0.0000e+00,  8.4010e-02,  4.0333e-01,
-        -4.2808e-02, -3.3353e-16, -2.5354e-01,  6.1456e-19,  1.5032e-16,
-        -2.7064e-01, -2.1209e-14,  6.8436e-08,  7.1347e-14, -3.9124e-20,
-        -1.2659e-10,  7.1597e-02, -2.0416e-01,  1.8862e-02, -1.4337e-11,
-         8.1424e-24, -3.2001e-18,  1.1509e-21,  4.4617e-01, -1.6621e-13,
-         3.0743e-21, -1.3942e-01,  2.5830e-09, -1.9292e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3015,  0.0000,  0.0143,  0.0077,  2.0947,  0.0631,  0.0000,  0.0000,
-        -0.1824, -0.0127,  0.0221,  0.0000, -0.2640, -0.3882,  0.0000,  0.0000,
-         0.0000, -0.0607,  0.0000,  0.0000, -0.0434,  0.1747,  0.0000, -0.8469,
-        -0.0206,  0.0000,  0.0514,  0.0885, -0.0254,  0.0000, -0.5415, -1.2621,
-         0.0000, -0.0893,  0.0000,  0.0000,  0.0000,  0.0000,  0.0840,  0.4033,
-        -0.0428,  0.0000, -0.2535,  0.0000,  0.0000, -0.2706,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716, -0.2042,  0.0189,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4462,  0.0000,  0.0000, -0.1394,  0.0000, -0.0193],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3015,  0.0000,  0.0143,  0.0077,  2.0947,  0.0631,  0.0000,  0.0000,
-        -0.1824, -0.0127,  0.0221,  0.0000, -0.2640, -0.3882,  0.0000,  0.0000,
-         0.0000, -0.0607,  0.0000,  0.0000, -0.0434,  0.1747,  0.0000, -0.8469,
-        -0.0206,  0.0000,  0.0514,  0.0885, -0.0254,  0.0000, -0.5415, -1.2621,
-         0.0000, -0.0893,  0.0000,  0.0000,  0.0000,  0.0000,  0.0840,  0.4033,
-        -0.0428,  0.0000, -0.2535,  0.0000,  0.0000, -0.2706,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716, -0.2042,  0.0189,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4462,  0.0000,  0.0000, -0.1394,  0.0000, -0.0193],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9396e-01, -9.4758e-13,  1.8215e-02,  9.7023e-03,  2.0938e+00,
-         6.4198e-02,  5.7358e-22, -4.5291e-15, -1.8744e-01, -1.7959e-02,
-         1.4392e-02,  1.0311e-13, -2.6694e-01, -3.8918e-01, -1.0286e-18,
-        -7.6449e-16, -7.5200e-17, -5.9900e-02, -1.7393e-19,  2.0912e-15,
-        -5.1187e-02,  1.7087e-01, -3.6937e-19, -8.4397e-01, -1.9693e-02,
-         2.9978e-15,  5.5533e-02,  7.5367e-02, -2.7458e-02,  5.3719e-19,
-        -5.4007e-01, -1.2621e+00, -1.3751e-22, -8.5249e-02,  6.8253e-17,
-         3.6827e-15,  2.2768e-18,  0.0000e+00,  8.2562e-02,  4.0136e-01,
-        -5.4306e-02, -3.0410e-16, -2.4833e-01,  5.6033e-19,  1.3705e-16,
-        -2.7231e-01, -1.9338e-14,  6.2397e-08,  6.5052e-14, -3.5672e-20,
-        -1.1542e-10,  6.7050e-02, -2.1400e-01,  1.1176e-02, -1.3072e-11,
-         7.4239e-24, -2.9177e-18,  1.0493e-21,  4.4850e-01, -1.5154e-13,
-         2.8030e-21, -1.4395e-01,  2.3551e-09, -1.9831e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2940,  0.0000,  0.0182,  0.0097,  2.0938,  0.0642,  0.0000,  0.0000,
-        -0.1874, -0.0180,  0.0144,  0.0000, -0.2669, -0.3892,  0.0000,  0.0000,
-         0.0000, -0.0599,  0.0000,  0.0000, -0.0512,  0.1709,  0.0000, -0.8440,
-        -0.0197,  0.0000,  0.0555,  0.0754, -0.0275,  0.0000, -0.5401, -1.2621,
-         0.0000, -0.0852,  0.0000,  0.0000,  0.0000,  0.0000,  0.0826,  0.4014,
-        -0.0543,  0.0000, -0.2483,  0.0000,  0.0000, -0.2723,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0671, -0.2140,  0.0112,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4485,  0.0000,  0.0000, -0.1440,  0.0000, -0.0198],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2940,  0.0000,  0.0182,  0.0097,  2.0938,  0.0642,  0.0000,  0.0000,
-        -0.1874, -0.0180,  0.0144,  0.0000, -0.2669, -0.3892,  0.0000,  0.0000,
-         0.0000, -0.0599,  0.0000,  0.0000, -0.0512,  0.1709,  0.0000, -0.8440,
-        -0.0197,  0.0000,  0.0555,  0.0754, -0.0275,  0.0000, -0.5401, -1.2621,
-         0.0000, -0.0852,  0.0000,  0.0000,  0.0000,  0.0000,  0.0826,  0.4014,
-        -0.0543,  0.0000, -0.2483,  0.0000,  0.0000, -0.2723,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0671, -0.2140,  0.0112,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4485,  0.0000,  0.0000, -0.1440,  0.0000, -0.0198],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8364e-01, -8.6356e-13,  2.3995e-02,  1.3488e-02,  2.0936e+00,
-         6.5073e-02,  5.2273e-22, -4.1275e-15, -1.9212e-01, -2.2407e-02,
-         1.6205e-02,  9.3972e-14, -2.6933e-01, -3.8867e-01, -9.3740e-19,
-        -6.9671e-16, -6.8532e-17, -5.2110e-02, -1.5851e-19,  1.9058e-15,
-        -6.0010e-02,  1.6931e-01, -3.3661e-19, -8.3944e-01, -1.8648e-02,
-         2.7320e-15,  5.8916e-02,  6.1481e-02, -3.2871e-02,  4.8956e-19,
-        -5.3908e-01, -1.2617e+00, -1.2532e-22, -8.2074e-02,  6.2201e-17,
-         3.3562e-15,  2.0749e-18,  0.0000e+00,  8.2516e-02,  3.9997e-01,
-        -6.5787e-02, -2.7714e-16, -2.4107e-01,  5.1065e-19,  1.2490e-16,
-        -2.7515e-01, -1.7623e-14,  5.6865e-08,  5.9284e-14, -3.2509e-20,
-        -1.0518e-10,  6.5321e-02, -2.2398e-01,  6.6516e-03, -1.1913e-11,
-         6.7657e-24, -2.6590e-18,  9.5628e-22,  4.5156e-01, -1.3811e-13,
-         2.5545e-21, -1.4627e-01,  2.1462e-09, -2.4427e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Sparsity at the end of epoch 2: 50.00%
-Final Sparsity: 50.00
-Sparsity in Conv2d 2: 1.56%
-Sparsity in Conv2d 8: 1.56%
-Sparsity in Conv2d 11: 1.56%
-Sparsity in Conv2d 14: 1.56%
-Sparsity in Conv2d 17: 1.56%
-Sparsity in Conv2d 21: 0.78%
-Sparsity in Conv2d 24: 0.78%
-Sparsity in Conv2d 27: 0.78%
-Sparsity in Conv2d 30: 0.78%
-Sparsity in Conv2d 33: 0.78%
-Sparsity in Conv2d 37: 0.39%
-Sparsity in Conv2d 40: 0.39%
-Sparsity in Conv2d 43: 0.39%
-Sparsity in Conv2d 46: 0.39%
-Sparsity in Conv2d 49: 0.39%
-Sparsity in Conv2d 53: 0.20%
-Sparsity in Conv2d 56: 0.20%
-Sparsity in Conv2d 59: 0.20%
-Sparsity in Conv2d 62: 0.20%
-Sparsity in Conv2d 65: 0.20%
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
model.model.conv1.weight.sum(dim=(1,2,3))
-
- -
-
-
- -
-
- -
- - - -
-
tensor([ 2.9347e-01, -2.7638e-15, -4.3429e-01,  1.5531e-20, -1.2244e-01,
-         6.7792e-02,  3.4213e-24,  9.2662e-15, -2.5555e-01, -6.6723e-11,
-        -1.1368e-02, -5.4554e-18,  2.7437e-02, -3.6576e-12,  1.6695e-18,
-        -8.0519e-02,  6.7549e-18,  6.4657e-02,  5.7248e-18,  3.1335e-17,
-        -6.9838e-14, -1.6188e-02,  1.2506e-20,  5.0455e-01,  1.3777e-13,
-        -6.4526e-19, -3.7569e-02, -1.2282e-14,  6.2495e-02, -1.4700e-18,
-        -2.6848e-01,  9.4839e-02,  9.6079e-22,  1.5481e-01, -4.7590e-19,
-         2.1518e-14, -7.0799e-16,  0.0000e+00,  1.6172e+00,  5.7085e-01,
-        -6.2181e-02, -3.7426e-01,  1.1096e-01, -6.0660e-16, -5.0897e-22,
-        -1.4613e-01, -2.6145e-12, -1.7860e-08,  3.6786e-10, -3.4189e-17,
-         5.0733e-13,  1.2981e-01, -9.3539e-01, -1.3682e-01, -5.1219e-01,
-        -2.5171e-02, -9.8362e-02, -3.2823e-23, -1.1528e-15, -1.0429e+00,
-        -1.0777e-19, -1.6025e-01,  1.1684e-02,  8.0589e-02],
-       grad_fn=<SumBackward1>)
-
- -
- -
-
- -
- {% endraw %} - -
- - diff --git a/docs/sparsify_callback-Copy1.html b/docs/sparsify_callback-Copy1.html deleted file mode 100644 index ebff007..0000000 --- a/docs/sparsify_callback-Copy1.html +++ /dev/null @@ -1,7956 +0,0 @@ ---- - -title: SparsifyCallback - - -keywords: fastai -sidebar: home_sidebar - -summary: "Use the sparsifier in fastai Callback system" -description: "Use the sparsifier in fastai Callback system" -nb_path: "nbs/02_sparsify_callback-Copy1.ipynb" ---- - - -
- - {% raw %} - -
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
 
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
path = untar_data(URLs.PETS)
-files = get_image_files(path/"images")
-
-def label_func(f): return f[0].isupper()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
- -
-
- -
- - -
-

class SparsifyCallback[source]

SparsifyCallback(sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=Conv2d) :: Callback

-
-

Basic class handling tweaks of the training loop by changing a Learner in various events

- -
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
- -
- {% endraw %} - -
-
-

The most important part of our Callback happens in before_batch. There, we first compute the sparsity of our network according to our schedule and then we remove the parameters accordingly.

- -
-
-
- {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
- -
-
-
- -
-
- -
- -
-
/home/HubensN/miniconda3/envs/deep/lib/python3.8/site-packages/fastai/vision/learner.py:265: UserWarning: `cnn_learner` has been renamed to `vision_learner` -- please update your code
-  warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.fit_one_cycle(5)
-
- -
-
-
- -
-
- -
- - -
- -
- -
- -
- - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime
00.6560140.5388870.85521000:08
10.3768430.2446170.89309900:07
20.2280750.3104310.89039200:07
30.1361980.1619780.94249000:07
40.0690740.1669800.94181300:07
- -
- -
-
- -
- {% endraw %} - -
-
-

Let's now try adding some sparsity in our model

- -
-
-
- {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
- -
-
-
- -
-
- -
- -
-
/home/HubensN/miniconda3/envs/deep/lib/python3.8/site-packages/fastai/vision/learner.py:265: UserWarning: `cnn_learner` has been renamed to `vision_learner` -- please update your code
-  warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")
-
-
-
- -
-
- -
- {% endraw %} - -
-
-

The SparsifyCallback requires a new argument compared to the Sparsifier. Indeed, we need to know the pruning schedule that we should follow during training in order to prune the parameters accordingly.

- -
-
-
-
-
-

You can use any scheduling function already available in fastai or come up with your own ! For more information about the pruning schedules, take a look at the Schedules section.

- -
-
-
- {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=updating_movmag, schedule=cos)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.fit(5, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
Pruning of weight until a sparsity of [50]%
-Saving Weights at epoch 0
-
-
-
- -
- - -
- -
- -
- -
- - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime
00.5755920.4965580.85385700:08
10.3650920.2765310.88701000:08
20.2681070.2554630.89174600:08
30.2196080.4642340.83355900:08
40.1857350.2450780.89309900:08
- -
- -
- -
-
Sparsity at the end of epoch 0: [4.77]%
-Sparsity at the end of epoch 1: [17.27]%
-Sparsity at the end of epoch 2: [32.73]%
-Sparsity at the end of epoch 3: [45.23]%
-Sparsity at the end of epoch 4: [50.0]%
-Final Sparsity: [50.0]%
-Sparsity in Conv2d 2: 50.00%
-Sparsity in Conv2d 8: 50.00%
-Sparsity in Conv2d 11: 50.00%
-Sparsity in Conv2d 14: 50.00%
-Sparsity in Conv2d 17: 50.00%
-Sparsity in Conv2d 21: 50.00%
-Sparsity in Conv2d 24: 50.00%
-Sparsity in Conv2d 27: 50.00%
-Sparsity in Conv2d 30: 50.00%
-Sparsity in Conv2d 33: 50.00%
-Sparsity in Conv2d 37: 50.00%
-Sparsity in Conv2d 40: 50.00%
-Sparsity in Conv2d 43: 50.00%
-Sparsity in Conv2d 46: 50.00%
-Sparsity in Conv2d 49: 50.00%
-Sparsity in Conv2d 53: 50.00%
-Sparsity in Conv2d 56: 50.00%
-Sparsity in Conv2d 59: 50.00%
-Sparsity in Conv2d 62: 50.00%
-Sparsity in Conv2d 65: 50.00%
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
zrs, els = 0,0
-for k,m in enumerate(learn.model.modules()):
-        if isinstance(m, nn.Conv2d):
-            zrs += torch.sum(m.weight == 0)
-            els += m.weight.nelement() 
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-
-    def __call__(self, m, g):
-        self.min_value=None
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            wf = self.f(m.weight)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            #wf = self.f(m.weight[None].mean(dim=dim, keepdim=True)).squeeze(0)
-            if self.needs_init: wi = self.f(m._init_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_update: wi = self.f(m._old_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            #if self.needs_update: wi = self.f(m._old_weights[None].mean(dim=dim, keepdim=True)).squeeze(0)
-
-        else: raise NameError('Invalid Granularity')
-            
-        if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(wf)) # Put the mask into a buffer
-
-        if self.output_f: output = self.output_f(wf, wi)
-        elif self.return_init: output = wi
-        else: output = wf
-        return output
-    
-    def rescale(self, w, min_value=None):
-        self.min_value = min_value if min_value else w.view(-1)[w.argmin()]
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        return output
-    
-    def descale(self, w):
-        output = w - self.min_value.abs()
-        return output
-        
-    def update_weights(self, m):
-        if self.needs_update: 
-            m._old_weights = m.weight.data.clone() # The current value becomes the old one for the next iteration
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
from fasterai.sparse.granularity import *
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
random = Criteria(torch.rand_like)
-large_final = Criteria(torch.abs)
-updating_magnitude_increase = Criteria(torch.abs, needs_update=True, output_f= torch.sub)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
torch.rand_like(learn.model[0][0].weight)[None].mean(dim=0, keepdim=True).squeeze(0).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.5009)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
random(learn.model[0][0], 'weight').mean()
-
- -
-
-
- -
-
- -
- -
-
tensor(0.4979)
-tensor(0.4964)
-
-
-
- -
- - - -
-
tensor(0.4964)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
torch.abs(learn.model[0][5][0].downsample[0].weight)[None].sum(dim=0, keepdim=True).squeeze(0).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.0240, device='cuda:0', grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
torch.abs(learn.model[0][0].weight).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.0372, device='cuda:0', grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
torch.abs(learn.model[0][5][0].downsample[0].weight).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.0240, device='cuda:0', grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
torch.abs(learn.model[0][7][1].conv1.weight).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.0076, device='cuda:0', grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
large_final(learn.model[0][0], 'filter').mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(11.1946, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
large_final(learn.model[0][5][0].downsample[0], 'filter').mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(2.8424, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
large_final(learn.model[0][7][1].conv1, 'filter').mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(64.6641, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallback(sparsity=50, granularity='filter', context='global', criteria=large_final, schedule=cos)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.fit(5, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
Pruning of filter until a sparsity of [50]%
-Saving Weights at epoch 0
-
-
-
- -
- - -
- -
- -
- -
- - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime
00.5883330.3406100.85521000:07
10.3820230.3086270.87144800:07
20.2738230.3172290.86941800:07
30.2093440.2277010.89986500:07
40.1841720.2652350.90189400:07
- -
- -
- -
-
Sparsity at the end of epoch 0: [4.77]%
-Sparsity at the end of epoch 1: [17.27]%
-Sparsity at the end of epoch 2: [32.73]%
-Sparsity at the end of epoch 3: [45.23]%
-Sparsity at the end of epoch 4: [50.0]%
-Final Sparsity: [50.0]%
-Sparsity in Conv2d 2: 12.50%
-Sparsity in Conv2d 8: 0.00%
-Sparsity in Conv2d 11: 0.00%
-Sparsity in Conv2d 14: 0.00%
-Sparsity in Conv2d 17: 0.00%
-Sparsity in Conv2d 21: 0.00%
-Sparsity in Conv2d 24: 0.00%
-Sparsity in Conv2d 27: 0.78%
-Sparsity in Conv2d 30: 0.00%
-Sparsity in Conv2d 33: 1.56%
-Sparsity in Conv2d 37: 0.00%
-Sparsity in Conv2d 40: 2.34%
-Sparsity in Conv2d 43: 3.12%
-Sparsity in Conv2d 46: 54.69%
-Sparsity in Conv2d 49: 66.80%
-Sparsity in Conv2d 53: 99.61%
-Sparsity in Conv2d 56: 99.80%
-Sparsity in Conv2d 59: 3.32%
-Sparsity in Conv2d 62: 99.80%
-Sparsity in Conv2d 65: 99.80%
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.fit(5, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
Pruning of filter until a sparsity of [50]%
-Saving Weights at epoch 0
-
-
-
- -
- - -
- -
- -
- -
- - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime
00.7787441.1399400.67050100:08
10.7543050.6795140.66847100:08
20.6910040.6616150.66847100:08
30.6410021.5615770.67659000:08
40.6427960.7284560.66847100:08
- -
- -
- -
-
Sparsity at the end of epoch 0: [4.77]%
-Sparsity at the end of epoch 1: [17.27]%
-Sparsity at the end of epoch 2: [32.73]%
-Sparsity at the end of epoch 3: [45.23]%
-Sparsity at the end of epoch 4: [50.0]%
-Final Sparsity: [50.0]%
-Sparsity in Conv2d 2: 98.44%
-Sparsity in Conv2d 8: 98.63%
-Sparsity in Conv2d 11: 90.62%
-Sparsity in Conv2d 14: 79.69%
-Sparsity in Conv2d 17: 56.25%
-Sparsity in Conv2d 21: 72.66%
-Sparsity in Conv2d 24: 53.91%
-Sparsity in Conv2d 27: 83.59%
-Sparsity in Conv2d 30: 40.62%
-Sparsity in Conv2d 33: 25.00%
-Sparsity in Conv2d 37: 50.78%
-Sparsity in Conv2d 40: 53.12%
-Sparsity in Conv2d 43: 67.19%
-Sparsity in Conv2d 46: 46.48%
-Sparsity in Conv2d 49: 23.44%
-Sparsity in Conv2d 53: 47.07%
-Sparsity in Conv2d 56: 38.48%
-Sparsity in Conv2d 59: 41.21%
-Sparsity in Conv2d 62: 60.74%
-Sparsity in Conv2d 65: 10.74%
-
-
-
- -
-
- -
- {% endraw %} - -
-
-

To remove

-
-
-
- {% raw %} - -
-
- -
-
-
from fasterai.sparse.granularity import *
-
-# Cell
-class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-
-    def __call__(self, m, g):
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            wf = self.f(m.weight)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_init: wi = self.f(m._init_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_update: wi = self.f(m._old_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-
-        else: raise NameError('Invalid Granularity')
-
-        if self.needs_update: m._old_weights = m.weight.clone() # The current value becomes the old one for the next iteration
-
-        if self.output_f: return self.output_f(wf, wi)
-        elif self.return_init: return wi
-        else: return wf
-
-# Cell
-random = Criteria(torch.randn_like)
-
-# Cell
-large_final = Criteria(torch.abs)
-
-# Cell
-squared_final = Criteria(torch.square)
-
-# Cell
-small_final = Criteria(compose(torch.abs, torch.neg))
-
-# Cell
-large_init = Criteria(torch.abs, needs_init=True, return_init=True)
-
-# Cell
-small_init = Criteria(compose(torch.abs, torch.neg), needs_init=True, return_init=True)
-
-# Cell
-large_init_large_final = Criteria(torch.abs, needs_init=True, output_f=torch.min)
-
-# Cell
-small_init_small_final = Criteria(torch.abs, needs_init=True, output_f=lambda x,y: torch.neg(torch.max(x,y)))
-
-# Cell
-magnitude_increase = Criteria(torch.abs, needs_init=True, output_f= torch.sub)
-
-# Cell
-movement = Criteria(noop, needs_init=True, output_f= lambda x,y: torch.abs(torch.sub(x,y)))
-
-# Cell
-updating_magnitude_increase = Criteria(torch.abs, needs_update=True, output_f= torch.sub)
-
-# Cell
-updating_movement = Criteria(noop, needs_update=True, output_f= lambda x,y: torch.abs(torch.sub(x,y)))
-
-# Cell
-updating_movmag = Criteria(noop, needs_update=True, output_f=lambda x,y: torch.abs(torch.mul(x, torch.sub(x,y))))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Sparsifier():
-    def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):
-        store_attr()
-        self._save_weights() # Save the original weights
-
-    def prune_layer(self, m, sparsity, round_to=None):
-        weight = self.criteria(m, self.granularity)
-        mask = self._compute_mask(weight, sparsity, round_to)
-        m.register_buffer("_mask", mask) # Put the mask into a buffer
-        self._apply(m)
-
-    def prune_model(self, sparsity, round_to=None):
-        self.threshold=None
-        sparsity_list = listify(sparsity)
-        if len(sparsity_list)>1: assert self.context=='local', f"A list of sparsities cannot be passed using: {self.context}"
-        sparsities = cycle(sparsity_list) if len(sparsity_list)==1 else iter(sparsity_list)
-        mods = list(self.model.modules())
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                sp = next(sparsities)
-                self.prune_layer(m, sp, round_to)
-                if isinstance(mods[k+1], nn.modules.batchnorm._BatchNorm): self.prune_batchnorm(m, mods[k+1])
-
-    def prune_batchnorm(self, m, bn):
-        mask = getattr(m, "_mask", None)
-        if self.granularity == 'filter' and mask is not None:
-            bn.weight.data.mul_(mask.squeeze())
-            bn.bias.data.mul_(mask.squeeze())
-
-
-    def _apply_masks(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type):
-                self._apply(m)
-
-    def _apply(self, m):
-        mask = getattr(m, "_mask", None)
-        if mask is not None: m.weight.data.mul_(mask)
-        if self.granularity == 'filter' and m.bias is not None:
-            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters
-
-    def _mask_grad(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type) and hasattr(m, '_mask'):
-                mask = getattr(m, "_mask")
-                if m.weight.grad is not None: m.weight.grad.mul_(mask)
-                if self.granularity == 'filter' and m.bias is not None:
-                    if m.bias.grad is not None: m.bias.grad.mul_(mask.squeeze())
-
-
-    def _reset_weights(self, model=None): # Reset non-pruned weights
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                init_weights = getattr(m, "_init_weights", m.weight)
-                init_biases = getattr(m, "_init_biases", m.bias)
-                with torch.no_grad():
-                    if m.weight is not None: m.weight.copy_(init_weights)
-                    if m.bias is not None: m.bias.copy_(init_biases)
-                self._apply(m)
-            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
-
-    def _save_weights(self):
-        for m in self.model.modules():
-            if hasattr(m, 'weight'):
-                m.register_buffer("_init_weights", m.weight.clone())
-                b = getattr(m, 'bias', None)
-                if b is not None: m.register_buffer("_init_biases", b.clone())
-
-    def save_model(self, path, model=None):
-        if not model: model=self.model
-        tmp_model = pickle.loads(pickle.dumps(model))
-        self._reset_weights(tmp_model)
-        self._clean_buffers(tmp_model)
-        torch.save(tmp_model, path)
-
-    def _clean_buffers(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                if hasattr(m, '_mask'): del m._buffers["_mask"]
-                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
-                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
-
-    def _compute_threshold(self, weight, sparsity):
-        if self.context == 'global':
-            global_weight = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)])
-            if self.threshold is None: self.threshold = torch.quantile(global_weight, sparsity/100) # Compute the threshold globally (only once per model pruning)
-            return self.threshold
-        elif self.context == 'local':
-            return torch.quantile(weight.view(-1), sparsity/100) # Compute the threshold locally
-        else: raise NameError('Invalid Context')
-
-    def _rounded_sparsity(self, n_to_prune, round_to):
-        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)
-
-    def _compute_mask(self, weight, sparsity, round_to):
-        threshold = self._compute_threshold(weight, sparsity)
-        if round_to:
-            n_to_keep = sum(weight.ge(threshold)).squeeze()
-            threshold = torch.topk(weight.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
-        if threshold > weight.max(): threshold = weight.max() # Make sure we don't remove every weight of a given layer
-        return weight.ge(threshold).to(dtype=weight.dtype)
-
-    def print_sparsity(self):
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCallback(Callback):
-    def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        store_attr()
-        self.sparsity = listify(self.sparsity)
-
-    def before_fit(self):
-        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
-        assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
-        model = self.model if self.model else self.learn.model
-        self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.layer_type)
-
-    def before_epoch(self):
-        if self.epoch == self.rewind_epoch:
-            print(f'Saving Weights at epoch {self.epoch}')
-            self.sparsifier._save_weights()
-
-    def before_batch(self):
-        self.current_sparsity = self.schedule(self.sparsity, round(self.pct_train,3))
-        if self.schedule.pruned and self.training:
-            if self.lth and self.save_tickets:
-                print('Saving Intermediate Ticket')
-                self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-    def after_step(self):
-        if self.lth and self.schedule.pruned:
-            print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-            self.sparsifier._reset_weights(self.learn.model)
-        self.schedule.after_pruned()
-        self.sparsifier._apply_masks()
-
-    def after_epoch(self):
-        sparsity_str = [float(f"%0.2f"%sp) for sp in self.current_sparsity]
-        print(f'Sparsity at the end of epoch {self.epoch}: {sparsity_str}%')
-
-    def after_fit(self):
-        if self.save_tickets:
-            print('Saving Final Ticket')
-            self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-        print(f'Final Sparsity: {self.schedule.current_sparsity:}%')
-        if self.reset_end: self.sparsifier._reset_weights()
-        self.sparsifier._clean_buffers()
-        self.schedule.reset()
-        self.sparsifier.print_sparsity()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
-sp_cb = SparsifyCallback(sparsity=50, granularity='filter', context='global', criteria=large_final, schedule=cos)
-learn.fit(5, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
Pruning of filter until a sparsity of [50]%
-Saving Weights at epoch 0
-
-
-
- -
- - -
- -
- -
- -
- - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime
00.5580050.3927340.82882300:09
10.3666870.3254900.86332900:09
20.2579370.4965810.82882300:10
30.2212620.2314590.90392400:09
40.1803290.3520360.87347800:09
- -
- -
- -
-
Sparsity at the end of epoch 0: [4.77]%
-Sparsity at the end of epoch 1: [17.27]%
-Sparsity at the end of epoch 2: [32.73]%
-Sparsity at the end of epoch 3: [45.23]%
-Sparsity at the end of epoch 4: [50.0]%
-Final Sparsity: [50.0]%
-Sparsity in Conv2d 2: 12.50%
-Sparsity in Conv2d 8: 0.00%
-Sparsity in Conv2d 11: 0.00%
-Sparsity in Conv2d 14: 0.00%
-Sparsity in Conv2d 17: 0.00%
-Sparsity in Conv2d 21: 0.00%
-Sparsity in Conv2d 24: 0.00%
-Sparsity in Conv2d 27: 0.78%
-Sparsity in Conv2d 30: 0.00%
-Sparsity in Conv2d 33: 2.34%
-Sparsity in Conv2d 37: 0.00%
-Sparsity in Conv2d 40: 3.12%
-Sparsity in Conv2d 43: 2.34%
-Sparsity in Conv2d 46: 54.30%
-Sparsity in Conv2d 49: 67.19%
-Sparsity in Conv2d 53: 99.22%
-Sparsity in Conv2d 56: 99.80%
-Sparsity in Conv2d 59: 3.71%
-Sparsity in Conv2d 62: 99.80%
-Sparsity in Conv2d 65: 99.80%
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
%debug
-
- -
-
-
- -
-
- -
- -
-
> /tmp/ipykernel_1952159/3551873890.py(27)granularize()
-     25             dim = granularities[m.__class__.__name__][g]
-     26             #if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(scores[None].mean(dim=dim, keepdim=True).squeeze(0))) # Put the mask into a buffer
----> 27             scores = self.rescale(scores, min_value)[None].mean(dim=dim, keepdim=True).squeeze(0).mul_(m._mask)
-     28         else: raise NameError('Invalid Granularity')
-     29         return scores
-
-ipdb> self.rescale(scores, min_value)[None].mean(dim=dim, keepdim=True).squeeze(0) * m._mask
-tensor([[[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]]],
-
-
-        [[[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],
-
-         [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],
-
-         [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]]],
-
-
-        [[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]]],
-
-
-        ...,
-
-
-        [[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]]],
-
-
-        [[[0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          ...,
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008]],
-
-         [[0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          ...,
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008]],
-
-         [[0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          ...,
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008]]],
-
-
-        [[[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],
-
-         [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],
-
-         [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]]]],
-       device='cuda:0', grad_fn=<MulBackward0>)
-ipdb> self.rescale(scores, min_value)[None].mean(dim=dim, keepdim=True).squeeze(0).mul(m._mask)
-tensor([[[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]]],
-
-
-        [[[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],
-
-         [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],
-
-         [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]]],
-
-
-        [[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]]],
-
-
-        ...,
-
-
-        [[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]]],
-
-
-        [[[0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          ...,
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008]],
-
-         [[0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          ...,
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008]],
-
-         [[0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          ...,
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0008, 0.0008, 0.0008,  ..., 0.0008, 0.0008, 0.0008]]],
-
-
-        [[[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],
-
-         [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],
-
-         [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          ...,
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]]]],
-       device='cuda:0', grad_fn=<MulBackward0>)
-ipdb> self.rescale(scores, min_value).mul_(m._mask)[None].mean(dim=dim, keepdim=True).squeeze(0)
-tensor([[[[0.0010]]],
-
-
-        [[[0.0009]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0008]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0009]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0009]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0009]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0008]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0012]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0009]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0009]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0011]]],
-
-
-        [[[0.0008]]],
-
-
-        [[[0.0009]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0009]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0010]]],
-
-
-        [[[0.0008]]],
-
-
-        [[[0.0009]]]], device='cuda:0', grad_fn=<SqueezeBackward1>)
-ipdb> quit
-
-
-
- -
-
- -
- {% endraw %} - -
-
-

New

-
-
-
- {% raw %} - -
-
- -
-
-
from fastai.vision.all import *
-from fastai.callback.all import *
-from fasterai.sparse.sparsifier import *
-from fasterai.sparse.criteria import *
-from fasterai.sparse.schedule import *
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
torch.quantile(torch.Tensor([0.9, 1.3, -0.1, 0.4]), 0.5)
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.6500)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Sparsifier():
-    def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):
-        store_attr()
-        self._save_weights() # Save the original weights
-
-    def prune_layer(self, m, sparsity, round_to=None):
-        weight = self.criteria(m, self.granularity)
-        if self.context=="local": weight = self.criteria.rescale(weight).mul_(m._mask) # We don't want to scale individual layers in global
-        weight = weight.mul_(m._mask) # We don't want to scale individual layers in global
-        setattr(m, '_mask', self._compute_mask(m, weight, sparsity, round_to))
-        self._apply(m)
-        self.criteria.update_weights(m)
-
-    def prune_model(self, sparsity, round_to=None):
-        self.threshold=None
-        sparsity_list = listify(sparsity)
-        if len(sparsity_list)>1: assert self.context=='local', f"A list of sparsities cannot be passed using: {self.context}"
-        sparsities = cycle(sparsity_list) if len(sparsity_list)==1 else iter(sparsity_list)
-        mods = list(self.model.modules())
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type): 
-                sp = next(sparsities)
-                self.prune_layer(m, sp, round_to)
-                if isinstance(mods[k+1], nn.modules.batchnorm._BatchNorm): self.prune_batchnorm(m, mods[k+1])
-                
-    def prune_batchnorm(self, m, bn):
-        mask = getattr(m, "_mask", None)
-        if self.granularity == 'filter' and mask is not None:
-            bn.weight.data.mul_(mask.squeeze())
-            bn.bias.data.mul_(mask.squeeze())
-            
-    def _apply_masks(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type):
-                self._apply(m)
-        
-    def _apply(self, m):
-        mask = getattr(m, "_mask", None)
-        if mask is not None: m.weight.data.mul_(mask)
-        if self.granularity == 'filter' and m.bias is not None:
-            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters
-    
-    def _reset_weights(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                init_weights = getattr(m, "_init_weights", m.weight)
-                init_biases = getattr(m, "_init_biases", m.bias)
-                with torch.no_grad():
-                    if m.weight is not None: m.weight.copy_(init_weights)
-                    if m.bias is not None: m.bias.copy_(init_biases)
-                self._apply(m)
-            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
-                
-    def _save_weights(self):
-        for m in self.model.modules():
-            if hasattr(m, 'weight'):              
-                m.register_buffer("_init_weights", m.weight.clone())
-                b = getattr(m, 'bias', None)
-                if b is not None: m.register_buffer("_init_biases", b.clone())
-                    
-    def save_model(self, path, model=None):
-        if not model: model=self.model
-        tmp_model = pickle.loads(pickle.dumps(model))
-        self._reset_weights(tmp_model)
-        self._clean_buffers(tmp_model)
-        torch.save(tmp_model, path)
-
-    def _clean_buffers(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                if hasattr(m, '_mask'): del m._buffers["_mask"]
-                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
-                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
-    
-    def _compute_threshold(self, m, weight, sparsity):
-        if self.context == 'global':
-            if self.threshold is None: 
-                global_weight = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all weights
-                global_mask = torch.cat([m._mask.view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all masks
-                global_weight = self.criteria.descale(self.criteria.rescale(global_weight).mul_(global_mask.squeeze())).mul_(global_mask.squeeze()) # Rescale all weights, apply the mask, then descale
-                #global_weight = self.criteria.rescale(global_weight).mul_(global_mask.squeeze()) # Rescale all weights, apply the mask, then descale
-                self.threshold = torch.quantile(global_weight, sparsity/100) # Compute the threshold globally (only once per model pruning)
-                #self.threshold-=(self.criteria.min_value+torch.finfo(torch.float32).eps)
-            return self.threshold
-        elif self.context == 'local':
-            return torch.quantile(weight.view(-1), sparsity/100) # Compute the threshold locally
-        else: raise NameError('Invalid Context')
-
-    def _rounded_sparsity(self, n_to_prune, round_to):
-        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)
-
-    def _compute_mask(self, m, weight, sparsity, round_to):
-        threshold = self._compute_threshold(m, weight, sparsity)
-        if round_to:
-            n_to_keep = sum(weight.ge(threshold)).squeeze()
-            threshold = torch.topk(weight.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
-        if threshold > weight.max(): threshold = weight.max() # Make sure we don't remove every weight of a given layer
-        return weight.gt(threshold).to(dtype=weight.dtype)
-    
-    def print_sparsity(self):
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-        self.min_value=0
-
-    def __call__(self, m, g):
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            wf = self.f(m.weight)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_init: wi = self.f(m._init_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_update: wi = self.f(m._old_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-
-        else: raise NameError('Invalid Granularity')
-            
-        if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(wf)) # Put the mask into a buffer
-
-        if self.output_f: output = self.output_f(wf, wi)
-        elif self.return_init: output = wi
-        else: output = wf
-        return output
-    
-    def rescale(self, w):
-        self.min_value = w.view(-1)[w.argmin()]
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        return output
-    
-    def descale(self, w):
-        output = w - self.min_value.abs() - torch.finfo(torch.float32).eps
-        return output
-        
-    def update_weights(self, m):
-        if self.needs_update: 
-            m._old_weights = m.weight.data.clone() # The current value becomes the old one for the next iteration
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase = Criteria(torch.abs, needs_update=True, output_f= torch.sub)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCallback(Callback):
-    def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        store_attr()
-        self.sparsity = listify(self.sparsity)
-
-    def before_fit(self):
-        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
-        assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
-        model = self.model if self.model else self.learn.model
-        self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.layer_type)
-
-    def before_epoch(self):
-        if self.epoch == self.rewind_epoch:
-            print(f'Saving Weights at epoch {self.epoch}')
-            self.sparsifier._save_weights()
-
-    def before_batch(self):
-        self.current_sparsity = self.schedule(self.sparsity, round(self.pct_train,3))
-        if self.schedule.pruned and self.training:
-            if self.lth and self.save_tickets:
-                print('Saving Intermediate Ticket')
-                self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-    def after_step(self):
-        if self.lth and self.schedule.pruned:
-            print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-            self.sparsifier._reset_weights(self.learn.model)
-        self.schedule.after_pruned()
-        self.sparsifier._apply_masks()
-
-    def after_epoch(self):
-        sparsity_str = [float(f"%0.2f"%sp) for sp in self.current_sparsity]
-        print(f'Sparsity at the end of epoch {self.epoch}: {sparsity_str}%')
-
-    def after_fit(self):
-        if self.save_tickets:
-            print('Saving Final Ticket')
-            self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-        print(f'Final Sparsity: {self.schedule.current_sparsity:}%')
-        if self.reset_end: self.sparsifier._reset_weights()
-        self.sparsifier._clean_buffers()
-        self.schedule.reset()
-        self.sparsifier.print_sparsity()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
-sp_cb = SparsifyCallback(sparsity=50, granularity='filter', context='global', criteria=updating_magnitude_increase, schedule=cos)
-learn.fit(5, cbs=sp_cb)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-        self.min_value=0
-
-    def __call__(self, m, g):
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            wf = self.f(m.weight)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_init: wi = self.f(m._init_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_update: wi = self.f(m._old_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-
-        else: raise NameError('Invalid Granularity')
-            
-        if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(wf)) # Put the mask into a buffer
-
-        if self.output_f: output = self.output_f(wf, wi)
-        elif self.return_init: output = wi
-        else: output = wf
-        return output
-    
-    def rescale(self, w):
-        self.min_value = w.view(-1)[w.argmin()]
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        #print(self.min_value.abs())
-        return output
-    
-    def descale(self, w):
-        output = w - self.min_value.abs()
-        #print(self.min_value.abs())
-        return output
-        
-    def update_weights(self, m):
-        if self.needs_update: 
-            m._old_weights = m.weight.data.clone() # The current value becomes the old one for the next iteration
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase = Criteria(torch.abs, needs_update=True, output_f= torch.sub)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
from fasterai.sparse.granularity import *
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Sparsifier():
-    def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):
-        store_attr()
-        self._save_weights() # Save the original weights
-
-    def prune_layer(self, m, sparsity, round_to=None):
-        scores = self.criteria(m, self.granularity)
-        if self.context=="local": scores = self.criteria.rescale(scores).mul_(m._mask) # We don't want to scale individual layers in global
-        scores = scores.mul_(m._mask) # We don't want to scale individual layers in global
-        setattr(m, '_mask', self._compute_mask(m, scores, sparsity, round_to))
-        self._apply(m)
-        self.criteria.update_weights(m)
-
-    def prune_model(self, sparsity, round_to=None):
-        self.threshold=None
-        sparsity_list = listify(sparsity)
-        if len(sparsity_list)>1: assert self.context=='local', f"A list of sparsities cannot be passed using: {self.context}"
-        sparsities = cycle(sparsity_list) if len(sparsity_list)==1 else iter(sparsity_list)
-        mods = list(self.model.modules())
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type): 
-                sp = next(sparsities)
-                self.prune_layer(m, sp, round_to)
-                if isinstance(mods[k+1], nn.modules.batchnorm._BatchNorm): self.prune_batchnorm(m, mods[k+1])
-                
-    def prune_batchnorm(self, m, bn):
-        mask = getattr(m, "_mask", None)
-        if self.granularity == 'filter' and mask is not None:
-            bn.weight.data.mul_(mask.squeeze())
-            bn.bias.data.mul_(mask.squeeze())
-            
-    def _apply_masks(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type):
-                self._apply(m)
-        
-    def _apply(self, m):
-        mask = getattr(m, "_mask", None)
-        if mask is not None: m.weight.data.mul_(mask)
-        if self.granularity == 'filter' and m.bias is not None:
-            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters
-    
-    def _reset_weights(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                init_weights = getattr(m, "_init_weights", m.weight)
-                init_biases = getattr(m, "_init_biases", m.bias)
-                with torch.no_grad():
-                    if m.weight is not None: m.weight.copy_(init_weights)
-                    if m.bias is not None: m.bias.copy_(init_biases)
-                self._apply(m)
-            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
-                
-    def _save_weights(self):
-        for m in self.model.modules():
-            if hasattr(m, 'weight'):              
-                m.register_buffer("_init_weights", m.weight.clone())
-                b = getattr(m, 'bias', None)
-                if b is not None: m.register_buffer("_init_biases", b.clone())
-                    
-    def save_model(self, path, model=None):
-        if not model: model=self.model
-        tmp_model = pickle.loads(pickle.dumps(model))
-        self._reset_weights(tmp_model)
-        self._clean_buffers(tmp_model)
-        torch.save(tmp_model, path)
-
-    def _clean_buffers(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                if hasattr(m, '_mask'): del m._buffers["_mask"]
-                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
-                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
-    
-    def _compute_threshold(self, m, scores, sparsity):
-        if self.context == 'global':
-            if self.threshold is None: 
-                global_scores = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all scores
-                #print('GW init', global_scores)
-                global_mask = torch.cat([m._mask.view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all masks
-                #print(global_mask)
-                global_scores = self.criteria.rescale(global_scores)
-                #print('GW after rescale', global_scores)
-                global_scores = global_scores.mul_(global_mask.squeeze())# Rescale all scores, apply the mask, then descale
-                #print('GW after mask', global_scores)
-                global_scores = self.criteria.descale(global_scores)
-                #print('GW after dscale', global_scores)
-                self.threshold = torch.quantile(global_scores, sparsity/100) # Compute the threshold globally (only once per model pruning)
-            return self.threshold
-        elif self.context == 'local':
-            return torch.quantile(scores.view(-1), sparsity/100) # Compute the threshold locally
-        else: raise NameError('Invalid Context')
-
-    def _rounded_sparsity(self, n_to_prune, round_to):
-        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)
-
-    def _compute_mask(self, m, scores, sparsity, round_to):
-        threshold = self._compute_threshold(m, scores, sparsity)
-        #print('Thresh', threshold)
-        #print('Weight',scores)
-        if round_to:
-            n_to_keep = sum(scores.ge(threshold)).squeeze()
-            threshold = torch.topk(scores.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
-        if threshold > scores.max(): threshold = scores.max() # Make sure we don't remove every weight of a given layer
-        return scores.ge(threshold).to(dtype=scores.dtype)
-    
-    def print_sparsity(self):
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-        self.min_value=0
-
-    def __call__(self, m, g):
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            wf = self.f(m.weight)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_init: wi = self.f(m._init_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_update: wi = self.f(m._old_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-
-        else: raise NameError('Invalid Granularity')
-            
-        if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(wf)) # Put the mask into a buffer
-
-        if self.output_f: output = self.output_f(wf, wi)
-        elif self.return_init: output = wi
-        else: output = wf
-        return output
-    
-    def rescale(self, w, min_value=None):
-        self.min_value = min_value if min_value else w.view(-1)[w.argmin()]
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        return output, self.min_value
-    
-    def descale(self, w):
-        output = w - self.min_value.abs()
-        return output
-        
-    def update_weights(self, m):
-        if self.needs_update: 
-            m._old_weights = m.weight.data.clone() # The current value becomes the old one for the next iteration
-            
-
-
-
-class Sparsifier():
-    def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):
-        store_attr()
-        self._save_weights() # Save the original weights
-
-    def prune_layer(self, m, sparsity, round_to=None, k=0):
-        scores = self.criteria(m, self.granularity)
-        #scores = scores.mul_(m._mask) # We don't want to scale individual layers in global
-        #if k==2: 
-        #    print(m._mask[0].squeeze())
-        #    print(scores[0])
-        setattr(m, '_mask', self._compute_mask(m, scores, sparsity, round_to))
-        self._apply(m)
-        self.criteria.update_weights(m)
-
-    def prune_model(self, sparsity, round_to=None):
-        self.threshold=None
-        sparsity_list = listify(sparsity)
-        if len(sparsity_list)>1: assert self.context=='local', f"A list of sparsities cannot be passed using: {self.context}"
-        sparsities = cycle(sparsity_list) if len(sparsity_list)==1 else iter(sparsity_list)
-        mods = list(self.model.modules())
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type): 
-                sp = next(sparsities)
-                self.prune_layer(m, sp, round_to,k)
-                if isinstance(mods[k+1], nn.modules.batchnorm._BatchNorm): self.prune_batchnorm(m, mods[k+1])
-                
-    def prune_batchnorm(self, m, bn):
-        mask = getattr(m, "_mask", None)
-        if self.granularity == 'filter' and mask is not None:
-            bn.weight.data.mul_(mask.squeeze())
-            bn.bias.data.mul_(mask.squeeze())
-            
-    def _apply_masks(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type):
-                self._apply(m)
-        
-    def _apply(self, m):
-        mask = getattr(m, "_mask", None)
-        if mask is not None: m.weight.data.mul_(mask)
-        if self.granularity == 'filter' and m.bias is not None:
-            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters
-    
-    def _reset_weights(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                init_weights = getattr(m, "_init_weights", m.weight)
-                init_biases = getattr(m, "_init_biases", m.bias)
-                with torch.no_grad():
-                    if m.weight is not None: m.weight.copy_(init_weights)
-                    if m.bias is not None: m.bias.copy_(init_biases)
-                self._apply(m)
-            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
-                
-    def _save_weights(self):
-        for m in self.model.modules():
-            if hasattr(m, 'weight'):              
-                m.register_buffer("_init_weights", m.weight.clone())
-                b = getattr(m, 'bias', None)
-                if b is not None: m.register_buffer("_init_biases", b.clone())
-                    
-    def save_model(self, path, model=None):
-        if not model: model=self.model
-        tmp_model = pickle.loads(pickle.dumps(model))
-        self._reset_weights(tmp_model)
-        self._clean_buffers(tmp_model)
-        torch.save(tmp_model, path)
-
-    def _clean_buffers(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                if hasattr(m, '_mask'): del m._buffers["_mask"]
-                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
-                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
-    
-    def _compute_threshold(self, scores, sparsity):
-        if self.context == 'global':
-            if self.threshold is None: 
-                global_scores = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all scores
-                #print('GS', global_scores)
-                global_mask = torch.cat([m._mask.view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all masks
-                global_scores, self.min_value = self.criteria.rescale(global_scores)
-                #print('GS rescaled', global_scores)
-                #print('min_value', self.min_value)
-                global_scores = global_scores.mul_(global_mask.squeeze())# Rescale all scores, apply the mask, then descale
-                #print('GS pruned', global_scores)
-                global_scores = self.criteria.descale(global_scores)
-                #print('GS descaled', global_scores)
-                self.threshold = torch.quantile(global_scores, sparsity/100) # Compute the threshold globally (only once per model pruning)
-            return self.threshold
-        elif self.context == 'local':
-            return torch.quantile(scores.view(-1), sparsity/100) # Compute the threshold locally
-        else: raise NameError('Invalid Context')
-
-    def _rounded_sparsity(self, n_to_prune, round_to):
-        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)
-
-    def _compute_mask(self, m, scores, sparsity, round_to):
-        if self.context=='local': scores = self.criteria.descale(self.criteria.rescale(scores)[0].mul_(m._mask)) # We don't want to scale individual layers in globa
-        threshold = self._compute_threshold(scores, sparsity)
-        if self.context=='global': scores = self.criteria.descale(self.criteria.rescale(scores, self.min_value)[0].mul_(m._mask))
-        #print('Scores', scores)
-        #print('Thresh', threshold)
-        if round_to:
-            n_to_keep = sum(scores.ge(threshold)).squeeze()
-            threshold = torch.topk(scores.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
-        if threshold > scores.max(): threshold = scores.max() # Make sure we don't remove every weight of a given layer
-        return scores.ge(threshold).to(dtype=scores.dtype)
-    
-    def print_sparsity(self):
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-        self.min_value=0
-
-    def __call__(self, m, g):
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            wf = self.f(m.weight)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_init: wi = self.f(m._init_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_update: wi = self.f(m._old_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-
-        else: raise NameError('Invalid Granularity')
-            
-        if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(wf)) # Put the mask into a buffer
-
-        if self.output_f: output = self.output_f(wf, wi)
-        elif self.return_init: output = wi
-        else: output = wf
-        return output
-    
-    def rescale(self, w, min_value=None):
-        self.min_value = min_value if min_value else w.view(-1)[w.argmin()]
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        return output, self.min_value
-    
-    def descale(self, w):
-        output = w - self.min_value.abs()
-        return output
-        
-    def update_weights(self, m):
-        if self.needs_update: 
-            m._old_weights = m.weight.data.clone() # The current value becomes the old one for the next iteration
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-        self.min_value=None
-        
-    def __call__(self, m):
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        wf = self.f(m.weight)
-        if self.needs_init: wi = self.f(m._init_weights)
-        if self.needs_update: wi = self.f(m._old_weights)
-
-        if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(wf)) # Put the mask into a buffer
-
-        if self.output_f: output = self.output_f(wf, wi)
-        elif self.return_init: output = wi
-        else: output = wf
-        return output
-    
-    def granularize(self, m, scores, g):
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            scores = scores[None].mean(dim=dim, keepdim=True).squeeze(0)
-        else: raise NameError('Invalid Granularity')
-        return scores
-    
-    def get_scores(self, m, scores, g, min_value=None):  
-        scores = self.granularize(m, self.rescale(scores, min_value).mul_(m._mask), g)
-        return scores
-    
-    def rescale(self, w, min_value=None):
-        self.min_value = min_value if min_value else w.min()
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        return output
-    
-    #def descale(self, w):
-    #    output = w - self.min_value.abs()
-    #    return output
-    
-    #def get_min(self, w):
-    #    return w.view(-1)[w.argmin()]
-        
-    def update_weights(self, m):
-        if self.needs_update: 
-            m._old_weights = m.weight.data.clone() # The current value becomes the old one for the next iteration
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Sparsifier():
-    def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):
-        store_attr()
-        self._save_weights() # Save the original weights
-
-    def prune_layer(self, m, sparsity, round_to=None, k=0):
-        scores = self.criteria(m)
-        setattr(m, '_mask', self._compute_mask(m, scores, sparsity, round_to))
-        self._apply(m)
-        self.criteria.update_weights(m)
-
-    def prune_model(self, sparsity, round_to=None):
-        self.threshold=None
-        sparsity_list = listify(sparsity)
-        if len(sparsity_list)>1: assert self.context=='local', f"A list of sparsities cannot be passed using: {self.context}"
-        sparsities = cycle(sparsity_list) if len(sparsity_list)==1 else iter(sparsity_list)
-        mods = list(self.model.modules())
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type): 
-                sp = next(sparsities)
-                self.prune_layer(m, sp, round_to,k)
-                if isinstance(mods[k+1], nn.modules.batchnorm._BatchNorm): self.prune_batchnorm(m, mods[k+1])
-                
-    def prune_batchnorm(self, m, bn):
-        mask = getattr(m, "_mask", None)
-        if self.granularity == 'filter' and mask is not None:
-            bn.weight.data.mul_(mask.squeeze())
-            bn.bias.data.mul_(mask.squeeze())
-            
-    def _apply_masks(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type):
-                self._apply(m)
-        
-    def _apply(self, m):
-        mask = getattr(m, "_mask", None)
-        if mask is not None: m.weight.data.mul_(mask)
-        if self.granularity == 'filter' and m.bias is not None:
-            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters
-    
-    def _reset_weights(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                init_weights = getattr(m, "_init_weights", m.weight)
-                init_biases = getattr(m, "_init_biases", m.bias)
-                with torch.no_grad():
-                    if m.weight is not None: m.weight.copy_(init_weights)
-                    if m.bias is not None: m.bias.copy_(init_biases)
-                self._apply(m)
-            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
-                
-    def _save_weights(self):
-        for m in self.model.modules():
-            if hasattr(m, 'weight'):              
-                m.register_buffer("_init_weights", m.weight.clone())
-                b = getattr(m, 'bias', None)
-                if b is not None: m.register_buffer("_init_biases", b.clone())
-                    
-    def save_model(self, path, model=None):
-        if not model: model=self.model
-        tmp_model = pickle.loads(pickle.dumps(model))
-        self._reset_weights(tmp_model)
-        self._clean_buffers(tmp_model)
-        torch.save(tmp_model, path)
-
-    def _clean_buffers(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                if hasattr(m, '_mask'): del m._buffers["_mask"]
-                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
-                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
-    
-    def _compute_threshold(self, m, scores, sparsity):
-        if self.context == 'global':
-            if self.threshold is None: 
-                global_criteria = torch.cat([self.criteria(m).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all scores
-                global_scores = torch.cat([self.criteria.get_scores(m, self.criteria(m), self.granularity, global_criteria.min()).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)])
-                self.threshold = torch.quantile(global_scores, sparsity/100) # Compute the threshold globally (only once per model pruning)
-            scores = self.criteria.get_scores(m, scores, self.granularity, self.criteria.min_value) # min_value is computed only once per prune_model
-            return self.threshold, scores
-        elif self.context == 'local':
-            scores = self.criteria.get_scores(m, scores, self.granularity)
-            return torch.quantile(scores.view(-1), sparsity/100), scores
-        else: raise NameError('Invalid Context')
-
-    def _rounded_sparsity(self, n_to_prune, round_to):
-        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)
-    
-    def _compute_mask(self, m, scores, sparsity, round_to):
-        self.threshold, scores = self._compute_threshold(m, scores, sparsity)
-        if round_to:
-            n_to_keep = sum(scores.ge(self.threshold)).squeeze()
-            self.threshold = torch.topk(scores.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
-        if self.threshold > scores.max(): self.threshold = scores.max() # Make sure we don't remove every weight of a given layer
-        return scores.ge(self.threshold).to(dtype=scores.dtype)
-    
-    def print_sparsity(self):
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCallback(Callback):
-    def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        store_attr()
-        self.sparsity = listify(self.sparsity)
-
-    def before_fit(self):
-        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
-        assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
-        model = self.model if self.model else self.learn.model
-        self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.layer_type)
-
-    def before_epoch(self):
-        if self.epoch == self.rewind_epoch:
-            print(f'Saving Weights at epoch {self.epoch}')
-            self.sparsifier._save_weights()
-
-    def before_batch(self):
-        self.current_sparsity = self.schedule(self.sparsity, round(self.pct_train,3))
-        if self.schedule.pruned and self.training:
-            if self.lth and self.save_tickets:
-                print('Saving Intermediate Ticket')
-                self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-    def after_step(self):
-        if self.lth and self.schedule.pruned:
-            print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-            self.sparsifier._reset_weights(self.learn.model)
-        self.schedule.after_pruned()
-        self.sparsifier._apply_masks()
-
-    def after_epoch(self):
-        sparsity_str = [float(f"%0.2f"%sp) for sp in self.current_sparsity]
-        print(f'Sparsity at the end of epoch {self.epoch}: {sparsity_str}%')
-
-    def after_fit(self):
-        if self.save_tickets:
-            print('Saving Final Ticket')
-            self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-        print(f'Final Sparsity: {self.schedule.current_sparsity:}%')
-        if self.reset_end: self.sparsifier._reset_weights()
-        self.sparsifier._clean_buffers()
-        self.schedule.reset()
-        self.sparsifier.print_sparsity()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_movement = Criteria(noop, needs_update=True, output_f= lambda x,y: torch.abs(torch.sub(x,y)))
-updating_magnitude_increase = Criteria(torch.abs, needs_update=True, output_f= lambda x,y: torch.sub(x,y))
-magnitude_increase = Criteria(torch.abs, needs_init=True, output_f= lambda x,y: torch.sub(x,y))
-large_final = Criteria(torch.abs)
-random = Criteria(torch.rand_like)
-#updating_magnitude_increase = Criteria(torch.abs, needs_update=True, output_f= lambda x,y: torch.sub(x,y))
-updating_movmag = Criteria(noop, needs_update=True, output_f=lambda x,y: torch.abs(torch.mul(x, torch.sub(x,y))))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
from fasterai.sparse.granularity import *
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
-learn.unfreeze()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='global', criteria=magnitude_increase, schedule=cos)
-learn.fit(10, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
Pruning of weight until a sparsity of [50]%
-Saving Weights at epoch 0
-
-
-
- -
- - -
- -
- -
- -
- - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime
00.5966000.5630790.69688800:09
10.5703300.5079120.74221900:09
20.5220230.5027960.74763200:09
30.4791690.4757540.76319300:09
40.4467570.4423240.78822700:09
50.3911620.3828150.83152900:09
60.3704180.3929720.81461400:09
70.3124890.3525640.83761800:09
80.2865200.3921770.82476300:09
90.2376630.3719760.85994600:09
- -
- -
- -
-
Sparsity at the end of epoch 0: [1.22]%
-Sparsity at the end of epoch 1: [4.77]%
-Sparsity at the end of epoch 2: [10.31]%
-Sparsity at the end of epoch 3: [17.27]%
-Sparsity at the end of epoch 4: [25.0]%
-Sparsity at the end of epoch 5: [32.73]%
-Sparsity at the end of epoch 6: [39.69]%
-Sparsity at the end of epoch 7: [45.23]%
-Sparsity at the end of epoch 8: [48.78]%
-Sparsity at the end of epoch 9: [50.0]%
-Final Sparsity: [50.0]%
-Sparsity in Conv2d 1: 57.26%
-Sparsity in Conv2d 7: 66.60%
-Sparsity in Conv2d 10: 66.29%
-Sparsity in Conv2d 13: 67.00%
-Sparsity in Conv2d 16: 66.91%
-Sparsity in Conv2d 20: 66.38%
-Sparsity in Conv2d 23: 67.30%
-Sparsity in Conv2d 26: 70.37%
-Sparsity in Conv2d 29: 66.94%
-Sparsity in Conv2d 32: 66.56%
-Sparsity in Conv2d 36: 66.06%
-Sparsity in Conv2d 39: 63.82%
-Sparsity in Conv2d 42: 70.28%
-Sparsity in Conv2d 45: 61.56%
-Sparsity in Conv2d 48: 57.46%
-Sparsity in Conv2d 52: 52.97%
-Sparsity in Conv2d 55: 46.10%
-Sparsity in Conv2d 58: 63.64%
-Sparsity in Conv2d 61: 44.09%
-Sparsity in Conv2d 64: 42.23%
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sc = torch.load('scores')
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase.min_value
-
- -
-
-
- -
-
- -
- - - -
-
tensor(-0.0010, device='cuda:0', grad_fn=<SelectBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
a = (torch.abs(learn.model[0][0].weight)-torch.abs(learn.model[0][0]._old_weights))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
global_scs = torch.cat([updating_magnitude_increase(m).view(-1) for m in learn.model.modules() if isinstance(m, nn.Conv2d)])
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
global_scores = torch.cat([self.criteria.get_scores(m, self.criteria(m), self.granularity, self.criteria.get_min(global_criteria)).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)])
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
    def granularize(self, m, scores, g):
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            scores = scores[None].mean(dim=dim, keepdim=True).squeeze(0)
-        else: raise NameError('Invalid Granularity')
-        return scores
-    
-    def get_scores(self, m, scores, g, min_value=None):  
-        scores = self.granularize(m, self.rescale(scores, min_value).mul_(m._mask), g)
-        return scores
-    
-    def rescale(self, w, min_value=None):
-        self.min_value = min_value if min_value else self.get_min(w)
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        return output
-    
-    def descale(self, w):
-        output = w - self.min_value.abs()
-        return output
-    
-    def get_min(self, w):
-        return w.view(-1)[w.argmin()]
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase(learn.model[0][0])
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[[[-3.3144e-04, -3.1918e-04, -3.4335e-04,  ...,  3.3004e-04,
-            1.9391e-04, -2.5945e-04],
-          [ 2.0925e-04,  1.7628e-04, -2.0480e-04,  ..., -2.7353e-04,
-           -2.5964e-04,  2.1820e-04],
-          [-2.0804e-04,  2.3791e-04,  2.4399e-04,  ...,  2.9409e-04,
-            2.4548e-04,  2.1187e-04],
-          ...,
-          [-8.4434e-05,  8.2493e-05, -1.1325e-06,  ..., -2.8819e-04,
-           -2.4536e-04, -2.1890e-04],
-          [ 6.2760e-05,  7.0795e-05, -4.2699e-05,  ...,  2.3955e-04,
-            1.3757e-04,  1.8336e-04],
-          [-4.3441e-05, -5.4070e-05, -1.0069e-04,  ..., -1.4922e-04,
-           -3.3379e-05, -1.5076e-04]],
-
-         [[-4.3890e-04, -3.7444e-04, -3.8284e-04,  ...,  3.9039e-04,
-            2.8816e-04, -3.7799e-04],
-          [ 2.9835e-04,  2.5048e-04, -2.8243e-04,  ..., -3.5775e-04,
-           -3.8518e-04, -3.3346e-04],
-          [-2.6954e-04,  2.9063e-04,  3.2023e-04,  ...,  3.7491e-04,
-            3.6588e-04,  3.0197e-04],
-          ...,
-          [-1.2656e-04, -7.2401e-05,  4.1192e-05,  ..., -3.9968e-04,
-           -3.7676e-04, -3.1322e-04],
-          [ 9.2570e-05,  4.0483e-05, -5.3547e-05,  ...,  0.0000e+00,
-            2.2745e-04,  2.5685e-04],
-          [ 9.8230e-05,  4.7730e-05, -7.1051e-05,  ..., -2.1550e-04,
-           -1.2256e-04, -2.1886e-04]],
-
-         [[ 7.8092e-04, -6.6559e-04,  6.4549e-04,  ...,  5.8065e-04,
-            5.1550e-04, -5.7230e-04],
-          [ 6.6384e-04, -5.7723e-04, -5.9432e-04,  ..., -5.8512e-04,
-           -6.2011e-04, -5.5517e-04],
-          [ 6.2386e-04,  5.8434e-04,  5.8521e-04,  ...,  5.9587e-04,
-            5.8930e-04,  5.1679e-04],
-          ...,
-          [-4.3719e-04,  3.5231e-04,  3.0484e-04,  ..., -6.5505e-04,
-           -5.9533e-04, -5.1524e-04],
-          [ 3.6376e-04, -2.8929e-04, -2.2778e-04,  ...,  6.1174e-04,
-            4.9141e-04,  4.7650e-04],
-          [ 3.3179e-04,  2.4332e-04, -2.8578e-04,  ..., -4.9752e-04,
-           -3.5613e-04, -3.9281e-04]]],
-
-
-        [[[-1.2585e-04, -2.5886e-04,  3.2665e-04,  ..., -3.4992e-04,
-           -2.0191e-04, -1.8906e-04],
-          [ 2.6297e-05,  1.5686e-04,  1.7621e-04,  ...,  2.2529e-04,
-            1.5993e-04,  1.0419e-04],
-          [-8.0564e-05,  2.4889e-04,  1.5916e-04,  ...,  3.0969e-04,
-            1.9268e-04,  1.4362e-04],
-          ...,
-          [-1.2355e-04, -2.3226e-04, -5.7667e-05,  ..., -3.2130e-04,
-           -2.4065e-04, -2.3525e-04],
-          [ 1.0460e-04, -2.4335e-04, -2.2093e-04,  ..., -4.8937e-04,
-           -3.5599e-04, -3.1879e-04],
-          [ 1.1960e-04,  1.8635e-04,  2.3850e-04,  ...,  5.5231e-04,
-            4.0008e-04,  3.3582e-04]],
-
-         [[-7.2374e-05,  2.6280e-04,  3.5218e-04,  ...,  3.7677e-04,
-            2.3657e-04, -2.6416e-04],
-          [ 7.0084e-05,  2.0251e-04,  2.3527e-04,  ...,  2.7841e-04,
-            2.3589e-04,  1.9839e-04],
-          [-1.1152e-04, -2.6606e-04,  2.0768e-04,  ...,  4.0168e-04,
-            2.7700e-04,  2.3864e-04],
-          ...,
-          [-1.4636e-04, -2.5934e-04, -6.7025e-05,  ..., -4.8709e-04,
-           -3.3769e-04, -3.3838e-04],
-          [ 1.8742e-04, -3.1423e-04, -3.0451e-04,  ..., -6.9997e-04,
-           -5.1052e-04, -4.1980e-04],
-          [ 2.3142e-04,  2.8688e-04,  3.3598e-04,  ...,  7.2668e-04,
-            5.6131e-04,  4.2578e-04]],
-
-         [[-2.2763e-05, -1.5359e-04,  2.0038e-04,  ...,  2.7895e-04,
-            2.6812e-04, -3.6456e-04],
-          [ 1.3581e-04,  2.0230e-04,  2.1207e-04,  ...,  3.1996e-04,
-            3.4207e-04,  3.6031e-04],
-          [-1.9113e-04, -2.5232e-04,  1.9227e-04,  ...,  4.5109e-04,
-            3.9276e-04,  3.7660e-04],
-          ...,
-          [-2.1584e-04, -2.6898e-04, -7.8976e-05,  ..., -6.0733e-04,
-           -5.6171e-04, -4.8898e-04],
-          [ 2.3721e-04, -2.4289e-04, -2.2273e-04,  ..., -7.1803e-04,
-           -6.5264e-04, -4.9591e-04],
-          [ 2.8685e-04,  2.4354e-04,  2.1915e-04,  ...,  5.4258e-04,
-            5.1527e-04,  4.2753e-04]]],
-
-
-        [[[-7.1054e-13, -6.4659e-13, -7.3896e-13,  ..., -9.8055e-13,
-           -1.0942e-12, -8.3844e-13],
-          [-6.1284e-14, -2.0650e-14, -8.0824e-14,  ..., -4.9738e-13,
-           -4.4054e-13, -3.0642e-14],
-          [-7.1765e-13, -7.6028e-13, -5.9330e-13,  ..., -9.7700e-14,
-           -1.0991e-14, -4.2633e-13],
-          ...,
-          [-9.5923e-13, -1.0019e-12, -7.9581e-13,  ..., -1.7586e-13,
-           -4.7606e-13, -1.3323e-13],
-          [-1.2932e-12, -1.4779e-12, -1.7479e-12,  ..., -1.3216e-12,
-           -1.0658e-12, -9.3792e-13],
-          [-1.2506e-12, -1.3642e-12, -1.8474e-12,  ..., -2.1458e-12,
-           -1.7764e-12, -1.7195e-12]],
-
-         [[-1.2648e-12, -9.5923e-13, -1.0374e-12,  ..., -1.1795e-12,
-           -1.3358e-12, -1.0800e-12],
-          [-5.7554e-13, -2.5047e-13, -3.0198e-13,  ..., -7.3186e-13,
-           -6.6791e-13, -2.2560e-13],
-          [-2.1849e-13, -4.8672e-13, -3.1264e-13,  ..., -1.8652e-13,
-           -7.9936e-14, -3.9790e-13],
-          ...,
-          [-5.6133e-13, -7.5318e-13, -4.4409e-13,  ..., -4.4054e-13,
-           -5.9686e-13, -1.8296e-13],
-          [-7.7449e-13, -9.8765e-13, -1.0445e-12,  ..., -6.3238e-13,
-           -4.1922e-13, -4.5830e-13],
-          [-5.9686e-13, -7.1054e-13, -9.0239e-13,  ..., -1.1653e-12,
-           -8.7397e-13, -9.8765e-13]],
-
-         [[-4.3698e-13, -1.3323e-13, -7.8160e-14,  ..., -5.9064e-14,
-           -2.6290e-13, -1.5632e-13],
-          [-4.1922e-13, -1.0800e-12, -1.0942e-12,  ..., -7.6739e-13,
-           -7.1765e-13, -9.8055e-13],
-          [-1.0445e-12, -1.6627e-12, -1.5916e-12,  ..., -1.3500e-12,
-           -1.3500e-12, -1.6485e-12],
-          ...,
-          [-9.8765e-13, -1.5064e-12, -1.2506e-12,  ..., -6.8212e-13,
-           -6.8212e-13, -1.1369e-12],
-          [-9.1660e-13, -1.3642e-12, -1.3785e-12,  ..., -1.1724e-12,
-           -1.1724e-12, -1.4353e-12],
-          [-6.2528e-13, -8.8107e-13, -1.0445e-12,  ..., -1.3927e-12,
-           -1.3358e-12, -1.5916e-12]]],
-
-
-        ...,
-
-
-        [[[-8.6904e-04, -8.4894e-04,  7.9927e-04,  ...,  7.9363e-04,
-           -4.5269e-04, -6.0248e-04],
-          [-8.1012e-04,  8.2050e-04,  7.9072e-04,  ...,  7.0001e-04,
-            2.5748e-04,  2.6945e-05],
-          [-7.6672e-04,  6.7008e-04,  6.0956e-04,  ...,  4.5993e-04,
-            1.2256e-04, -7.0671e-05],
-          ...,
-          [ 6.5641e-04,  1.6475e-04,  2.5021e-04,  ..., -5.6136e-05,
-            9.5177e-05, -1.3975e-04],
-          [-7.4311e-04, -2.5973e-04,  2.0902e-04,  ..., -4.2551e-04,
-            2.6818e-04,  2.2633e-04],
-          [-6.2624e-04, -1.1007e-04,  2.8012e-04,  ..., -2.7156e-04,
-           -2.6215e-04,  2.6566e-04]],
-
-         [[-7.7405e-04, -5.4897e-04,  4.8165e-04,  ...,  5.2210e-04,
-           -2.2935e-04, -2.8875e-04],
-          [-8.2406e-04,  7.0529e-04,  6.1416e-04,  ...,  6.1457e-04,
-            2.5998e-04, -1.1078e-04],
-          [-7.8529e-04,  7.2722e-04,  4.9793e-04,  ...,  5.6629e-04,
-           -2.2814e-05, -1.7570e-04],
-          ...,
-          [ 5.9857e-04,  5.6561e-05,  2.3614e-04,  ..., -2.3817e-04,
-            5.5708e-05,  3.2766e-04],
-          [-5.5145e-04,  2.3667e-04,  9.6943e-05,  ..., -3.6980e-04,
-           -2.0518e-04,  9.0559e-05],
-          [-5.3308e-04,  1.6965e-04, -2.5730e-04,  ..., -3.1097e-04,
-           -2.1714e-04,  1.9961e-04]],
-
-         [[-7.8655e-04, -6.6515e-04,  6.4316e-04,  ...,  6.8756e-04,
-           -3.9407e-04, -4.1086e-04],
-          [-8.6101e-04,  8.1302e-04,  7.9565e-04,  ...,  8.4542e-04,
-            2.6228e-04, -3.8486e-05],
-          [-7.4761e-04,  6.8086e-04,  5.9993e-04,  ...,  4.7606e-04,
-            1.8442e-04, -2.1681e-06],
-          ...,
-          [-5.4715e-04,  1.9849e-04,  3.9772e-04,  ...,  1.9206e-04,
-           -2.6702e-04, -5.5816e-04],
-          [-5.1322e-04, -1.9236e-04,  1.9207e-04,  ..., -8.0660e-05,
-            3.6270e-05, -8.8178e-05],
-          [-4.5972e-04,  1.5248e-04, -5.4374e-05,  ..., -1.6878e-04,
-           -1.0406e-04,  4.2112e-05]]],
-
-
-        [[[ 1.1352e-04, -2.3189e-04, -2.7893e-04,  ..., -5.8590e-04,
-           -1.7842e-04,  2.8087e-05],
-          [ 7.9070e-05,  5.1564e-05,  1.9996e-04,  ..., -6.8858e-04,
-           -1.1953e-04,  4.0517e-04],
-          [-2.1073e-04, -3.9864e-04, -1.1776e-04,  ..., -4.5086e-04,
-            9.6641e-05,  4.8853e-04],
-          ...,
-          [ 4.1895e-04,  4.8359e-04,  3.8038e-05,  ...,  4.1355e-06,
-            2.0843e-04,  3.2214e-04],
-          [ 3.9389e-04,  2.3630e-04,  2.5790e-05,  ...,  1.9319e-04,
-            3.3343e-04,  4.1148e-04],
-          [ 5.6018e-04,  4.9015e-04,  4.2643e-04,  ...,  3.6970e-04,
-            5.0580e-04,  4.8329e-04]],
-
-         [[ 7.4025e-05, -9.4887e-05, -1.0080e-04,  ...,  5.0498e-04,
-            2.8386e-05,  2.2475e-04],
-          [-5.0457e-05, -1.1731e-04, -7.9520e-05,  ..., -6.5706e-04,
-           -5.5134e-07,  5.2929e-04],
-          [-1.6361e-04, -5.9307e-04, -4.1223e-04,  ..., -4.3448e-04,
-            1.2499e-04,  5.7054e-04],
-          ...,
-          [ 3.7687e-04, -7.7112e-04, -1.7964e-04,  ...,  6.0303e-07,
-            2.7341e-04,  3.8639e-04],
-          [ 3.6994e-04,  4.5240e-04,  1.2206e-04,  ...,  1.2392e-04,
-            3.5541e-04,  4.4945e-04],
-          [-5.7531e-04,  6.1655e-04,  4.7375e-04,  ..., -2.9809e-04,
-            4.7820e-04,  4.8137e-04]],
-
-         [[ 3.0337e-04,  2.7763e-04, -1.7796e-04,  ...,  4.8358e-04,
-           -1.5804e-04, -3.5433e-04],
-          [ 3.7647e-04,  3.5135e-04,  3.0509e-04,  ...,  5.7623e-04,
-           -2.1577e-05, -4.4085e-04],
-          [ 1.8969e-04,  2.4454e-04,  2.2849e-04,  ...,  5.0501e-04,
-           -1.2124e-04, -5.2574e-04],
-          ...,
-          [ 9.7818e-05, -7.4751e-05,  2.5805e-04,  ...,  3.1158e-04,
-           -1.3677e-04, -2.4750e-04],
-          [ 6.8858e-05,  3.1424e-05, -2.7046e-04,  ...,  2.2599e-04,
-            1.9263e-04, -2.7891e-04],
-          [-3.1162e-04, -3.2638e-04, -1.7051e-04,  ..., -4.0129e-05,
-           -2.9694e-04, -3.3495e-04]]],
-
-
-        [[[ 4.1216e-04, -1.9831e-04, -1.6797e-04,  ..., -1.1450e-05,
-            1.9640e-05, -5.0569e-05],
-          [ 2.5212e-04, -2.0190e-04, -1.3407e-04,  ...,  1.2434e-04,
-           -1.3850e-04, -1.2483e-04],
-          [-2.1975e-04, -2.0221e-04, -8.8252e-05,  ..., -6.8083e-05,
-           -3.0100e-05, -7.2725e-05],
-          ...,
-          [ 1.1750e-04,  5.8115e-05,  1.4338e-04,  ...,  4.8280e-06,
-            4.3541e-05,  1.3473e-04],
-          [-1.7955e-04, -1.9152e-04,  1.0860e-04,  ...,  1.2682e-04,
-           -1.0862e-04, -2.0244e-04],
-          [-4.0270e-04,  2.5880e-04,  1.6352e-04,  ..., -2.1666e-04,
-           -2.3279e-04, -2.9883e-04]],
-
-         [[ 3.9849e-04, -7.9185e-05, -2.8191e-05,  ..., -1.2506e-04,
-           -1.9356e-04, -1.9540e-04],
-          [-2.8281e-04, -1.2228e-04, -1.9073e-06,  ..., -2.3163e-04,
-           -2.4356e-04, -1.8345e-04],
-          [-2.6762e-04,  0.0000e+00, -3.7756e-05,  ..., -1.3402e-04,
-           -1.4585e-04, -1.7186e-04],
-          ...,
-          [ 1.2365e-04, -5.0783e-05,  5.5879e-05,  ...,  2.7269e-05,
-            1.1384e-04,  2.1859e-04],
-          [-2.4480e-04, -3.0971e-04,  2.2112e-04,  ...,  2.0769e-04,
-            2.3259e-04, -3.0828e-04],
-          [ 4.1378e-04,  3.3271e-04,  2.5587e-04,  ...,  3.1947e-04,
-           -3.3649e-04, -3.5989e-04]],
-
-         [[ 1.4121e-04,  1.2737e-04, -1.3521e-04,  ...,  1.8655e-04,
-            1.5655e-04,  1.9725e-04],
-          [-1.2920e-06,  1.0746e-04,  1.7328e-04,  ...,  2.5312e-04,
-           -2.4189e-04,  2.1137e-04],
-          [ 4.2211e-05,  1.3388e-04,  2.4034e-04,  ..., -2.3255e-04,
-           -2.2399e-04, -2.2176e-04],
-          ...,
-          [-2.1140e-04, -3.4189e-04, -2.5313e-04,  ...,  2.0485e-04,
-            2.5709e-04,  2.8732e-04],
-          [-3.8829e-04, -4.5652e-04,  3.7189e-04,  ...,  3.3320e-04,
-            3.5398e-04, -3.7044e-04],
-          [-4.1519e-04, -4.4337e-04,  3.6917e-04,  ..., -3.9689e-04,
-           -4.0812e-04, -4.0868e-04]]]], device='cuda:0',
-       grad_fn=<SubBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
(updating_magnitude_increase(learn.model[0][0])+global_scs.min().abs())[0][0][0][0]<(updating_magnitude_increase(learn.model[0][0])+global_scs.min().abs())[0][0][0][1]
-
- -
-
-
- -
-
- -
- - - -
-
tensor(True, device='cuda:0')
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
print(updating_magnitude_increase.rescale(updating_magnitude_increase(learn.model[0][0]), global_scs.min()))
-
- -
-
-
- -
-
- -
- -
-
tensor([[[[0.0007, 0.0007, 0.0007,  ..., 0.0013, 0.0012, 0.0007],
-          [0.0012, 0.0012, 0.0008,  ..., 0.0007, 0.0007, 0.0012],
-          [0.0008, 0.0012, 0.0012,  ..., 0.0013, 0.0012, 0.0012],
-          ...,
-          [0.0009, 0.0011, 0.0010,  ..., 0.0007, 0.0008, 0.0008],
-          [0.0011, 0.0011, 0.0010,  ..., 0.0012, 0.0011, 0.0012],
-          [0.0010, 0.0009, 0.0009,  ..., 0.0008, 0.0010, 0.0008]],
-
-         [[0.0006, 0.0006, 0.0006,  ..., 0.0014, 0.0013, 0.0006],
-          [0.0013, 0.0012, 0.0007,  ..., 0.0006, 0.0006, 0.0007],
-          [0.0007, 0.0013, 0.0013,  ..., 0.0014, 0.0014, 0.0013],
-          ...,
-          [0.0009, 0.0009, 0.0010,  ..., 0.0006, 0.0006, 0.0007],
-          [0.0011, 0.0010, 0.0009,  ..., 0.0010, 0.0012, 0.0013],
-          [0.0011, 0.0010, 0.0009,  ..., 0.0008, 0.0009, 0.0008]],
-
-         [[0.0018, 0.0003, 0.0016,  ..., 0.0016, 0.0015, 0.0004],
-          [0.0017, 0.0004, 0.0004,  ..., 0.0004, 0.0004, 0.0004],
-          [0.0016, 0.0016, 0.0016,  ..., 0.0016, 0.0016, 0.0015],
-          ...,
-          [0.0006, 0.0014, 0.0013,  ..., 0.0003, 0.0004, 0.0005],
-          [0.0014, 0.0007, 0.0008,  ..., 0.0016, 0.0015, 0.0015],
-          [0.0013, 0.0012, 0.0007,  ..., 0.0005, 0.0006, 0.0006]]],
-
-
-        [[[0.0009, 0.0007, 0.0013,  ..., 0.0006, 0.0008, 0.0008],
-          [0.0010, 0.0012, 0.0012,  ..., 0.0012, 0.0012, 0.0011],
-          [0.0009, 0.0012, 0.0012,  ..., 0.0013, 0.0012, 0.0011],
-          ...,
-          [0.0009, 0.0008, 0.0009,  ..., 0.0007, 0.0008, 0.0008],
-          [0.0011, 0.0008, 0.0008,  ..., 0.0005, 0.0006, 0.0007],
-          [0.0011, 0.0012, 0.0012,  ..., 0.0016, 0.0014, 0.0013]],
-
-         [[0.0009, 0.0013, 0.0014,  ..., 0.0014, 0.0012, 0.0007],
-          [0.0011, 0.0012, 0.0012,  ..., 0.0013, 0.0012, 0.0012],
-          [0.0009, 0.0007, 0.0012,  ..., 0.0014, 0.0013, 0.0012],
-          ...,
-          [0.0009, 0.0007, 0.0009,  ..., 0.0005, 0.0007, 0.0007],
-          [0.0012, 0.0007, 0.0007,  ..., 0.0003, 0.0005, 0.0006],
-          [0.0012, 0.0013, 0.0013,  ..., 0.0017, 0.0016, 0.0014]],
-
-         [[0.0010, 0.0008, 0.0012,  ..., 0.0013, 0.0013, 0.0006],
-          [0.0011, 0.0012, 0.0012,  ..., 0.0013, 0.0013, 0.0014],
-          [0.0008, 0.0007, 0.0012,  ..., 0.0015, 0.0014, 0.0014],
-          ...,
-          [0.0008, 0.0007, 0.0009,  ..., 0.0004, 0.0004, 0.0005],
-          [0.0012, 0.0008, 0.0008,  ..., 0.0003, 0.0003, 0.0005],
-          [0.0013, 0.0012, 0.0012,  ..., 0.0015, 0.0015, 0.0014]]],
-
-
-        [[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]]],
-
-
-        ...,
-
-
-        [[[0.0001, 0.0002, 0.0018,  ..., 0.0018, 0.0005, 0.0004],
-          [0.0002, 0.0018, 0.0018,  ..., 0.0017, 0.0013, 0.0010],
-          [0.0002, 0.0017, 0.0016,  ..., 0.0015, 0.0011, 0.0009],
-          ...,
-          [0.0017, 0.0012, 0.0012,  ..., 0.0009, 0.0011, 0.0009],
-          [0.0003, 0.0007, 0.0012,  ..., 0.0006, 0.0013, 0.0012],
-          [0.0004, 0.0009, 0.0013,  ..., 0.0007, 0.0007, 0.0013]],
-
-         [[0.0002, 0.0005, 0.0015,  ..., 0.0015, 0.0008, 0.0007],
-          [0.0002, 0.0017, 0.0016,  ..., 0.0016, 0.0013, 0.0009],
-          [0.0002, 0.0017, 0.0015,  ..., 0.0016, 0.0010, 0.0008],
-          ...,
-          [0.0016, 0.0011, 0.0012,  ..., 0.0008, 0.0011, 0.0013],
-          [0.0004, 0.0012, 0.0011,  ..., 0.0006, 0.0008, 0.0011],
-          [0.0005, 0.0012, 0.0007,  ..., 0.0007, 0.0008, 0.0012]],
-
-         [[0.0002, 0.0003, 0.0016,  ..., 0.0017, 0.0006, 0.0006],
-          [0.0001, 0.0018, 0.0018,  ..., 0.0018, 0.0013, 0.0010],
-          [0.0003, 0.0017, 0.0016,  ..., 0.0015, 0.0012, 0.0010],
-          ...,
-          [0.0005, 0.0012, 0.0014,  ..., 0.0012, 0.0007, 0.0004],
-          [0.0005, 0.0008, 0.0012,  ..., 0.0009, 0.0010, 0.0009],
-          [0.0005, 0.0012, 0.0009,  ..., 0.0008, 0.0009, 0.0010]]],
-
-
-        [[[0.0011, 0.0008, 0.0007,  ..., 0.0004, 0.0008, 0.0010],
-          [0.0011, 0.0011, 0.0012,  ..., 0.0003, 0.0009, 0.0014],
-          [0.0008, 0.0006, 0.0009,  ..., 0.0005, 0.0011, 0.0015],
-          ...,
-          [0.0014, 0.0015, 0.0010,  ..., 0.0010, 0.0012, 0.0013],
-          [0.0014, 0.0012, 0.0010,  ..., 0.0012, 0.0013, 0.0014],
-          [0.0016, 0.0015, 0.0014,  ..., 0.0014, 0.0015, 0.0015]],
-
-         [[0.0011, 0.0009, 0.0009,  ..., 0.0015, 0.0010, 0.0012],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0003, 0.0010, 0.0015],
-          [0.0008, 0.0004, 0.0006,  ..., 0.0006, 0.0011, 0.0016],
-          ...,
-          [0.0014, 0.0002, 0.0008,  ..., 0.0010, 0.0013, 0.0014],
-          [0.0014, 0.0015, 0.0011,  ..., 0.0011, 0.0014, 0.0014],
-          [0.0004, 0.0016, 0.0015,  ..., 0.0007, 0.0015, 0.0015]],
-
-         [[0.0013, 0.0013, 0.0008,  ..., 0.0015, 0.0008, 0.0006],
-          [0.0014, 0.0014, 0.0013,  ..., 0.0016, 0.0010, 0.0006],
-          [0.0012, 0.0012, 0.0012,  ..., 0.0015, 0.0009, 0.0005],
-          ...,
-          [0.0011, 0.0009, 0.0013,  ..., 0.0013, 0.0009, 0.0008],
-          [0.0011, 0.0010, 0.0007,  ..., 0.0012, 0.0012, 0.0007],
-          [0.0007, 0.0007, 0.0008,  ..., 0.0010, 0.0007, 0.0007]]],
-
-
-        [[[0.0014, 0.0008, 0.0008,  ..., 0.0010, 0.0010, 0.0009],
-          [0.0013, 0.0008, 0.0009,  ..., 0.0011, 0.0009, 0.0009],
-          [0.0008, 0.0008, 0.0009,  ..., 0.0009, 0.0010, 0.0009],
-          ...,
-          [0.0011, 0.0011, 0.0011,  ..., 0.0010, 0.0010, 0.0011],
-          [0.0008, 0.0008, 0.0011,  ..., 0.0011, 0.0009, 0.0008],
-          [0.0006, 0.0013, 0.0012,  ..., 0.0008, 0.0008, 0.0007]],
-
-         [[0.0014, 0.0009, 0.0010,  ..., 0.0009, 0.0008, 0.0008],
-          [0.0007, 0.0009, 0.0010,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0007, 0.0010, 0.0010,  ..., 0.0009, 0.0009, 0.0008],
-          ...,
-          [0.0011, 0.0009, 0.0011,  ..., 0.0010, 0.0011, 0.0012],
-          [0.0008, 0.0007, 0.0012,  ..., 0.0012, 0.0012, 0.0007],
-          [0.0014, 0.0013, 0.0013,  ..., 0.0013, 0.0007, 0.0006]],
-
-         [[0.0011, 0.0011, 0.0009,  ..., 0.0012, 0.0012, 0.0012],
-          [0.0010, 0.0011, 0.0012,  ..., 0.0013, 0.0008, 0.0012],
-          [0.0010, 0.0011, 0.0012,  ..., 0.0008, 0.0008, 0.0008],
-          ...,
-          [0.0008, 0.0007, 0.0007,  ..., 0.0012, 0.0013, 0.0013],
-          [0.0006, 0.0005, 0.0014,  ..., 0.0013, 0.0014, 0.0006],
-          [0.0006, 0.0006, 0.0014,  ..., 0.0006, 0.0006, 0.0006]]]],
-       device='cuda:0', grad_fn=<AddBackward0>)
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
print(updating_magnitude_increase.rescale(updating_magnitude_increase(learn.model[0][0]), global_scs.min()).mul_(learn.model[0][0]._mask))
-
- -
-
-
- -
-
- -
- -
-
tensor([[[[0.0007, 0.0007, 0.0007,  ..., 0.0013, 0.0012, 0.0007],
-          [0.0012, 0.0012, 0.0008,  ..., 0.0007, 0.0007, 0.0012],
-          [0.0008, 0.0012, 0.0012,  ..., 0.0013, 0.0012, 0.0012],
-          ...,
-          [0.0009, 0.0011, 0.0010,  ..., 0.0007, 0.0008, 0.0008],
-          [0.0011, 0.0011, 0.0010,  ..., 0.0012, 0.0011, 0.0012],
-          [0.0010, 0.0009, 0.0009,  ..., 0.0008, 0.0010, 0.0008]],
-
-         [[0.0006, 0.0006, 0.0006,  ..., 0.0014, 0.0013, 0.0006],
-          [0.0013, 0.0012, 0.0007,  ..., 0.0006, 0.0006, 0.0007],
-          [0.0007, 0.0013, 0.0013,  ..., 0.0014, 0.0014, 0.0013],
-          ...,
-          [0.0009, 0.0009, 0.0010,  ..., 0.0006, 0.0006, 0.0007],
-          [0.0011, 0.0010, 0.0009,  ..., 0.0000, 0.0012, 0.0013],
-          [0.0011, 0.0010, 0.0009,  ..., 0.0008, 0.0009, 0.0008]],
-
-         [[0.0018, 0.0003, 0.0016,  ..., 0.0016, 0.0015, 0.0004],
-          [0.0017, 0.0004, 0.0004,  ..., 0.0004, 0.0004, 0.0004],
-          [0.0016, 0.0016, 0.0016,  ..., 0.0016, 0.0016, 0.0015],
-          ...,
-          [0.0006, 0.0014, 0.0013,  ..., 0.0003, 0.0004, 0.0005],
-          [0.0014, 0.0007, 0.0008,  ..., 0.0016, 0.0015, 0.0015],
-          [0.0013, 0.0012, 0.0007,  ..., 0.0005, 0.0006, 0.0006]]],
-
-
-        [[[0.0009, 0.0007, 0.0013,  ..., 0.0006, 0.0008, 0.0008],
-          [0.0010, 0.0012, 0.0012,  ..., 0.0012, 0.0012, 0.0011],
-          [0.0009, 0.0012, 0.0012,  ..., 0.0013, 0.0012, 0.0011],
-          ...,
-          [0.0009, 0.0008, 0.0009,  ..., 0.0007, 0.0008, 0.0008],
-          [0.0011, 0.0008, 0.0008,  ..., 0.0005, 0.0006, 0.0007],
-          [0.0011, 0.0012, 0.0012,  ..., 0.0016, 0.0014, 0.0013]],
-
-         [[0.0009, 0.0013, 0.0014,  ..., 0.0014, 0.0012, 0.0007],
-          [0.0011, 0.0012, 0.0012,  ..., 0.0013, 0.0012, 0.0012],
-          [0.0009, 0.0007, 0.0012,  ..., 0.0014, 0.0013, 0.0012],
-          ...,
-          [0.0009, 0.0007, 0.0009,  ..., 0.0005, 0.0007, 0.0007],
-          [0.0012, 0.0007, 0.0007,  ..., 0.0003, 0.0005, 0.0006],
-          [0.0012, 0.0013, 0.0013,  ..., 0.0017, 0.0016, 0.0014]],
-
-         [[0.0010, 0.0008, 0.0012,  ..., 0.0013, 0.0013, 0.0006],
-          [0.0011, 0.0012, 0.0012,  ..., 0.0013, 0.0013, 0.0014],
-          [0.0008, 0.0007, 0.0012,  ..., 0.0015, 0.0014, 0.0014],
-          ...,
-          [0.0008, 0.0007, 0.0009,  ..., 0.0004, 0.0004, 0.0005],
-          [0.0012, 0.0008, 0.0008,  ..., 0.0003, 0.0003, 0.0005],
-          [0.0013, 0.0012, 0.0012,  ..., 0.0015, 0.0015, 0.0014]]],
-
-
-        [[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
-
-         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          ...,
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
-          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]]],
-
-
-        ...,
-
-
-        [[[0.0001, 0.0002, 0.0018,  ..., 0.0018, 0.0005, 0.0004],
-          [0.0002, 0.0018, 0.0018,  ..., 0.0017, 0.0013, 0.0010],
-          [0.0002, 0.0017, 0.0016,  ..., 0.0015, 0.0011, 0.0009],
-          ...,
-          [0.0017, 0.0012, 0.0012,  ..., 0.0009, 0.0011, 0.0009],
-          [0.0003, 0.0007, 0.0012,  ..., 0.0006, 0.0013, 0.0012],
-          [0.0004, 0.0009, 0.0013,  ..., 0.0007, 0.0007, 0.0013]],
-
-         [[0.0002, 0.0005, 0.0015,  ..., 0.0015, 0.0008, 0.0007],
-          [0.0002, 0.0017, 0.0016,  ..., 0.0016, 0.0013, 0.0009],
-          [0.0002, 0.0017, 0.0015,  ..., 0.0016, 0.0010, 0.0008],
-          ...,
-          [0.0016, 0.0011, 0.0012,  ..., 0.0008, 0.0011, 0.0013],
-          [0.0004, 0.0012, 0.0011,  ..., 0.0006, 0.0008, 0.0011],
-          [0.0005, 0.0012, 0.0007,  ..., 0.0007, 0.0008, 0.0012]],
-
-         [[0.0002, 0.0003, 0.0016,  ..., 0.0017, 0.0006, 0.0006],
-          [0.0001, 0.0018, 0.0018,  ..., 0.0018, 0.0013, 0.0010],
-          [0.0003, 0.0017, 0.0016,  ..., 0.0015, 0.0012, 0.0010],
-          ...,
-          [0.0005, 0.0012, 0.0014,  ..., 0.0012, 0.0007, 0.0004],
-          [0.0005, 0.0008, 0.0012,  ..., 0.0009, 0.0010, 0.0009],
-          [0.0005, 0.0012, 0.0009,  ..., 0.0008, 0.0009, 0.0010]]],
-
-
-        [[[0.0011, 0.0008, 0.0007,  ..., 0.0004, 0.0008, 0.0010],
-          [0.0011, 0.0011, 0.0012,  ..., 0.0003, 0.0009, 0.0014],
-          [0.0008, 0.0006, 0.0009,  ..., 0.0005, 0.0011, 0.0015],
-          ...,
-          [0.0014, 0.0015, 0.0010,  ..., 0.0010, 0.0012, 0.0013],
-          [0.0014, 0.0012, 0.0010,  ..., 0.0012, 0.0013, 0.0014],
-          [0.0016, 0.0015, 0.0014,  ..., 0.0014, 0.0015, 0.0015]],
-
-         [[0.0011, 0.0009, 0.0009,  ..., 0.0015, 0.0010, 0.0012],
-          [0.0009, 0.0009, 0.0009,  ..., 0.0003, 0.0010, 0.0015],
-          [0.0008, 0.0004, 0.0006,  ..., 0.0006, 0.0011, 0.0016],
-          ...,
-          [0.0014, 0.0002, 0.0008,  ..., 0.0010, 0.0013, 0.0014],
-          [0.0014, 0.0015, 0.0011,  ..., 0.0011, 0.0014, 0.0014],
-          [0.0004, 0.0016, 0.0015,  ..., 0.0007, 0.0015, 0.0015]],
-
-         [[0.0013, 0.0013, 0.0008,  ..., 0.0015, 0.0008, 0.0006],
-          [0.0014, 0.0014, 0.0013,  ..., 0.0016, 0.0010, 0.0006],
-          [0.0012, 0.0012, 0.0012,  ..., 0.0015, 0.0009, 0.0005],
-          ...,
-          [0.0011, 0.0009, 0.0013,  ..., 0.0013, 0.0009, 0.0008],
-          [0.0011, 0.0010, 0.0007,  ..., 0.0012, 0.0012, 0.0007],
-          [0.0007, 0.0007, 0.0008,  ..., 0.0010, 0.0007, 0.0007]]],
-
-
-        [[[0.0014, 0.0008, 0.0008,  ..., 0.0010, 0.0010, 0.0009],
-          [0.0013, 0.0008, 0.0009,  ..., 0.0011, 0.0009, 0.0009],
-          [0.0008, 0.0008, 0.0009,  ..., 0.0009, 0.0010, 0.0009],
-          ...,
-          [0.0011, 0.0011, 0.0011,  ..., 0.0010, 0.0010, 0.0011],
-          [0.0008, 0.0008, 0.0011,  ..., 0.0011, 0.0009, 0.0008],
-          [0.0006, 0.0013, 0.0012,  ..., 0.0008, 0.0008, 0.0007]],
-
-         [[0.0014, 0.0009, 0.0010,  ..., 0.0009, 0.0008, 0.0008],
-          [0.0007, 0.0009, 0.0010,  ..., 0.0008, 0.0008, 0.0008],
-          [0.0007, 0.0000, 0.0010,  ..., 0.0009, 0.0009, 0.0008],
-          ...,
-          [0.0011, 0.0009, 0.0011,  ..., 0.0010, 0.0011, 0.0012],
-          [0.0008, 0.0007, 0.0012,  ..., 0.0012, 0.0012, 0.0007],
-          [0.0014, 0.0013, 0.0013,  ..., 0.0013, 0.0007, 0.0006]],
-
-         [[0.0011, 0.0011, 0.0009,  ..., 0.0012, 0.0012, 0.0012],
-          [0.0010, 0.0011, 0.0012,  ..., 0.0013, 0.0008, 0.0012],
-          [0.0010, 0.0011, 0.0012,  ..., 0.0008, 0.0008, 0.0008],
-          ...,
-          [0.0008, 0.0007, 0.0007,  ..., 0.0012, 0.0013, 0.0013],
-          [0.0006, 0.0005, 0.0014,  ..., 0.0013, 0.0014, 0.0006],
-          [0.0006, 0.0006, 0.0014,  ..., 0.0006, 0.0006, 0.0006]]]],
-       device='cuda:0', grad_fn=<MulBackward0>)
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase.rescale(updating_magnitude_increase(learn.model[0][0]), global_scs.min()).mul_(learn.model[0][0]._mask)[0]
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[[0.0007, 0.0007, 0.0007, 0.0013, 0.0013, 0.0012, 0.0007],
-         [0.0012, 0.0012, 0.0008, 0.0008, 0.0007, 0.0007, 0.0012],
-         [0.0008, 0.0012, 0.0012, 0.0013, 0.0013, 0.0012, 0.0012],
-         [0.0012, 0.0009, 0.0009, 0.0008, 0.0008, 0.0008, 0.0012],
-         [0.0009, 0.0011, 0.0010, 0.0009, 0.0007, 0.0008, 0.0008],
-         [0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0011, 0.0012],
-         [0.0010, 0.0009, 0.0009, 0.0010, 0.0008, 0.0010, 0.0008]],
-
-        [[0.0006, 0.0006, 0.0006, 0.0014, 0.0014, 0.0013, 0.0006],
-         [0.0013, 0.0012, 0.0007, 0.0007, 0.0006, 0.0006, 0.0007],
-         [0.0007, 0.0013, 0.0013, 0.0013, 0.0014, 0.0014, 0.0013],
-         [0.0008, 0.0008, 0.0008, 0.0000, 0.0007, 0.0007, 0.0012],
-         [0.0009, 0.0009, 0.0010, 0.0008, 0.0006, 0.0006, 0.0007],
-         [0.0011, 0.0010, 0.0009, 0.0011, 0.0000, 0.0012, 0.0013],
-         [0.0011, 0.0010, 0.0009, 0.0010, 0.0008, 0.0009, 0.0008]],
-
-        [[0.0018, 0.0003, 0.0016, 0.0016, 0.0016, 0.0015, 0.0004],
-         [0.0017, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004, 0.0004],
-         [0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0015],
-         [0.0016, 0.0005, 0.0005, 0.0005, 0.0004, 0.0016, 0.0015],
-         [0.0006, 0.0014, 0.0013, 0.0014, 0.0003, 0.0004, 0.0005],
-         [0.0014, 0.0007, 0.0008, 0.0014, 0.0016, 0.0015, 0.0015],
-         [0.0013, 0.0012, 0.0007, 0.0008, 0.0005, 0.0006, 0.0006]]],
-       device='cuda:0', grad_fn=<SelectBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
global_scs.min()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(-0.0010, device='cuda:0', grad_fn=<MinBackward1>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
print(updating_magnitude_increase.granularize(learn.model[0][0], updating_magnitude_increase.rescale(updating_magnitude_increase(learn.model[0][0]), global_scs.min()).mul_(learn.model[0][0]._mask), 'filter').squeeze())
-
- -
-
-
- -
-
- -
- -
-
tensor([0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0011, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0011, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0011, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0011, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0011, 0.0011,
-        0.0010], device='cuda:0', grad_fn=<SqueezeBackward0>)
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase.granularize(learn.model[0][0], updating_magnitude_increase.rescale(updating_magnitude_increase(learn.model[0][0]), global_scs.min()).mul_(learn.model[0][0]._mask), 'filter').squeeze()[0]<updating_magnitude_increase.granularize(learn.model[0][0], updating_magnitude_increase.rescale(updating_magnitude_increase(learn.model[0][0]), global_scs.min()).mul_(learn.model[0][0]._mask), 'filter').squeeze()[1]
-
- -
-
-
- -
-
- -
- - - -
-
tensor(True, device='cuda:0')
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
min_value = 0
-
-for k, m in enumerate(learn.model.modules()):
-    if isinstance(m, nn.Conv2d):
-        print(k)
-        
-        #print(large_final.get_scores(m, large_final(m), 'filter', large_final.min_value).squeeze())
-        #if updating_magnitude_increase(m).view(-1)[updating_magnitude_increase(m).argmin()]<min_value: min_value=updating_magnitude_increase(m).view(-1)[updating_magnitude_increase(m).argmin()]
-        #print(updating_magnitude_increase(m))
-
- -
-
-
- -
-
- -
- -
-
2
-8
-11
-14
-17
-21
-24
-27
-30
-33
-37
-40
-43
-46
-49
-53
-56
-59
-62
-65
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
global_scores = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all scores
-
- -
-
-
- -
-
- -
- - - -
-
tensor([0.0007, 0.0007, 0.0007,  ..., 0.0008, 0.0006, 0.0006], device='cuda:0',
-       requires_grad=True)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sc[800:1200]
-
- -
-
-
- -
-
- -
- - - -
-
tensor([0.0011, 0.0007, 0.0006, 0.0005, 0.0006, 0.0007, 0.0010, 0.0010, 0.0009,
-        0.0006, 0.0005, 0.0006, 0.0007, 0.0009, 0.0009, 0.0010, 0.0007, 0.0015,
-        0.0006, 0.0012, 0.0009, 0.0009, 0.0012, 0.0007, 0.0005, 0.0007, 0.0012,
-        0.0010, 0.0011, 0.0012, 0.0014, 0.0005, 0.0007, 0.0011, 0.0014, 0.0015,
-        0.0015, 0.0004, 0.0017, 0.0018, 0.0009, 0.0013, 0.0014, 0.0004, 0.0004,
-        0.0004, 0.0004, 0.0007, 0.0011, 0.0013, 0.0005, 0.0004, 0.0003, 0.0004,
-        0.0009, 0.0012, 0.0013, 0.0014, 0.0003, 0.0004, 0.0004, 0.0009, 0.0011,
-        0.0011, 0.0013, 0.0005, 0.0004, 0.0005, 0.0010, 0.0011, 0.0012, 0.0014,
-        0.0005, 0.0005, 0.0006, 0.0011, 0.0012, 0.0013, 0.0013, 0.0005, 0.0005,
-        0.0006, 0.0007, 0.0009, 0.0010, 0.0012, 0.0012, 0.0013, 0.0013, 0.0006,
-        0.0008, 0.0010, 0.0012, 0.0013, 0.0013, 0.0013, 0.0005, 0.0008, 0.0010,
-        0.0007, 0.0005, 0.0006, 0.0014, 0.0005, 0.0012, 0.0010, 0.0009, 0.0008,
-        0.0006, 0.0014, 0.0004, 0.0006, 0.0010, 0.0010, 0.0009, 0.0007, 0.0006,
-        0.0003, 0.0005, 0.0014, 0.0011, 0.0009, 0.0004, 0.0016, 0.0002, 0.0004,
-        0.0005, 0.0009, 0.0006, 0.0002, 0.0017, 0.0013, 0.0011, 0.0010, 0.0008,
-        0.0007, 0.0006, 0.0006, 0.0014, 0.0012, 0.0011, 0.0012, 0.0014, 0.0014,
-        0.0006, 0.0015, 0.0012, 0.0011, 0.0014, 0.0017, 0.0016, 0.0015, 0.0015,
-        0.0009, 0.0012, 0.0013, 0.0014, 0.0016, 0.0016, 0.0016, 0.0006, 0.0012,
-        0.0000, 0.0000, 0.0013, 0.0015, 0.0018, 0.0004, 0.0005, 0.0010, 0.0013,
-        0.0018, 0.0017, 0.0018, 0.0017, 0.0003, 0.0013, 0.0017, 0.0019, 0.0019,
-        0.0006, 0.0008, 0.0010, 0.0014, 0.0015, 0.0016, 0.0015, 0.0005, 0.0008,
-        0.0010, 0.0007, 0.0004, 0.0005, 0.0005, 0.0005, 0.0011, 0.0007, 0.0004,
-        0.0001, 0.0002, 0.0004, 0.0004, 0.0010, 0.0006, 0.0004, 0.0004, 0.0002,
-        0.0003, 0.0003, 0.0014, 0.0007, 0.0005, 0.0007, 0.0006, 0.0005, 0.0002,
-        0.0017, 0.0016, 0.0011, 0.0007, 0.0003, 0.0003, 0.0002, 0.0018, 0.0018,
-        0.0009, 0.0000, 0.0002, 0.0002, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
-        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0005, 0.0014,
-        0.0008, 0.0012, 0.0009, 0.0011, 0.0009, 0.0004, 0.0014, 0.0013, 0.0013,
-        0.0008, 0.0011, 0.0009, 0.0004, 0.0006, 0.0014, 0.0006, 0.0009, 0.0010,
-        0.0009, 0.0015, 0.0007, 0.0014], device='cuda:0',
-       grad_fn=<SliceBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='global', criteria=updating_magnitude_increase, schedule=cos)
-learn.fit(5, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
Pruning of weight until a sparsity of [50]%
-Saving Weights at epoch 0
-
-
-
- -
- - -
- -
- -
- -
- - -
-
- - 0.00% [0/5 00:00<00:00] -
- - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime

- -

- - 6.52% [6/92 00:04<01:06 1.0906] -
-
- -
- -
- -
-
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', grad_fn=<CatBackward0>)
-tensor([ 0.0010,  0.0010,  0.0010,  ..., -0.0010, -0.0010, -0.0010],
-       device='cuda:0', grad_fn=<CatBackward0>)
-tensor([-0.0002, -0.0005, -0.0002,  ..., -0.0007, -0.0008, -0.0007],
-       device='cuda:0', grad_fn=<CatBackward0>)
-tensor([ 5.0217e-06, -4.7387e-05,  1.6059e-04,  ..., -5.1428e-04,
-        -5.0568e-04, -3.9404e-04], device='cuda:0', grad_fn=<CatBackward0>)
-tensor([-0.0005, -0.0006, -0.0005,  ..., -0.0005, -0.0006, -0.0004],
-       device='cuda:0', grad_fn=<CatBackward0>)
-tensor([-0.0002, -0.0002, -0.0001,  ..., -0.0005, -0.0006, -0.0004],
-       device='cuda:0', grad_fn=<CatBackward0>)
-tensor([-0.0003, -0.0003, -0.0003,  ..., -0.0002, -0.0004, -0.0004],
-       device='cuda:0', grad_fn=<CatBackward0>)
-
-
-
- -
- -
-
-KeyboardInterrupt
-
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sc = updating_magnitude_increase(learn.model[0][0])
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sc
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[[[ 0.0000e+00,  3.1665e-08,  3.8650e-08,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  2.2352e-08,  ...,  0.0000e+00,
-            1.4901e-08,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]]],
-
-
-        [[[ 2.1420e-08,  1.9092e-08,  0.0000e+00,  ...,  1.8626e-08,
-            2.2352e-08,  2.2352e-08],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 1.4901e-08,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            1.4901e-08,  1.4901e-08],
-          [ 0.0000e+00,  0.0000e+00,  1.1176e-08,  ...,  1.3039e-08,
-            1.6764e-08,  1.8626e-08],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 7.4506e-09,  5.5879e-09,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 7.4506e-09,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            1.4901e-08,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  7.4506e-09,
-            1.4901e-08,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]]],
-
-
-        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]]],
-
-
-        ...,
-
-
-        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            2.6077e-08,  1.1176e-08],
-          [ 1.8626e-09,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 2.2352e-08,  1.9558e-08,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 4.6566e-08,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 5.7742e-08,  5.8673e-08,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 5.9605e-08,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]]],
-
-
-        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 9.3132e-08,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            3.3528e-08,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]]],
-
-
-        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  3.7253e-09,  0.0000e+00,  ...,  0.0000e+00,
-            3.3760e-09,  1.7462e-09],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            7.4506e-09,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  2.7940e-09,  0.0000e+00,  ...,  2.0955e-09,
-            0.0000e+00,  1.8626e-09],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00, -1.3970e-09],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            7.4506e-09,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            5.5879e-09,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  3.7253e-09,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]]]], device='cuda:0',
-       grad_fn=<SubBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.model[0][0].weight
-
- -
-
-
- -
-
- -
- - - -
-
Parameter containing:
-tensor([[[[-0.0000e+00, -9.1996e-03, -4.4083e-03,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -1.1251e-01,  ..., -2.7214e-01,
-           -1.2951e-01, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          ...,
-          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
-            0.0000e+00, -0.0000e+00]],
-
-         [[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00]]],
-
-
-        [[[-4.8550e-03, -4.2851e-03, -0.0000e+00,  ..., -3.7784e-02,
-           -2.6483e-02, -4.9857e-02],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-1.1106e-02, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          ...,
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -1.4430e-01,
-           -1.1385e-01, -5.1759e-02],
-          [-0.0000e+00, -0.0000e+00, -1.5829e-02,  ..., -2.3354e-02,
-           -2.6224e-02, -3.2349e-02],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00]],
-
-         [[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-6.0392e-02, -2.4560e-02, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          ...,
-          [-1.1325e-01, -0.0000e+00,  0.0000e+00,  ..., -2.8661e-01,
-           -2.1487e-01, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -9.5230e-02,
-           -7.0820e-02, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00]],
-
-         [[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          ...,
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00]]],
-
-
-        [[[-7.0793e-08, -6.4276e-08, -7.3772e-08,  ..., -9.7954e-08,
-           -1.0900e-07, -8.3382e-08],
-          [-6.1097e-09,  2.0603e-09, -8.0885e-09,  ..., -4.9817e-08,
-           -4.3815e-08, -3.0523e-09],
-          [ 7.1919e-08,  7.5581e-08,  5.9255e-08,  ..., -9.7464e-09,
-           -1.0946e-09,  4.2423e-08],
-          ...,
-          [ 9.5844e-08,  1.0034e-07,  7.9780e-08,  ..., -1.7483e-08,
-           -4.7644e-08, -1.3259e-08],
-          [ 1.2898e-07,  1.4755e-07,  1.7468e-07,  ...,  1.3226e-07,
-            1.0623e-07,  9.3272e-08],
-          [ 1.2553e-07,  1.3638e-07,  1.8422e-07,  ...,  2.1389e-07,
-            1.7701e-07,  1.7158e-07]],
-
-         [[-1.2684e-07, -9.6094e-08, -1.0367e-07,  ..., -1.1803e-07,
-           -1.3303e-07, -1.0815e-07],
-          [-5.7386e-08, -2.5043e-08, -3.0101e-08,  ..., -7.2888e-08,
-           -6.6991e-08, -2.2564e-08],
-          [ 2.1803e-08,  4.8586e-08,  3.1207e-08,  ..., -1.8685e-08,
-           -7.9554e-09,  3.9732e-08],
-          ...,
-          [ 5.5987e-08,  7.5491e-08,  4.4475e-08,  ..., -4.4107e-08,
-           -5.9902e-08, -1.8239e-08],
-          [ 7.7578e-08,  9.8302e-08,  1.0450e-07,  ...,  6.3242e-08,
-            4.1761e-08,  4.5880e-08],
-          [ 5.9806e-08,  7.0973e-08,  9.0395e-08,  ...,  1.1649e-07,
-            8.7510e-08,  9.8791e-08]],
-
-         [[-4.3790e-08,  1.3264e-08,  7.8239e-09,  ..., -5.8777e-09,
-           -2.6205e-08, -1.5642e-08],
-          [ 4.1681e-08,  1.0773e-07,  1.0941e-07,  ...,  7.6368e-08,
-            7.1417e-08,  9.7569e-08],
-          [ 1.0431e-07,  1.6578e-07,  1.5926e-07,  ...,  1.3511e-07,
-            1.3481e-07,  1.6441e-07],
-          ...,
-          [ 9.8718e-08,  1.5065e-07,  1.2541e-07,  ...,  6.8284e-08,
-            6.8350e-08,  1.1362e-07],
-          [ 9.1392e-08,  1.3570e-07,  1.3787e-07,  ...,  1.1673e-07,
-            1.1717e-07,  1.4387e-07],
-          [ 6.2154e-08,  8.8143e-08,  1.0452e-07,  ...,  1.3934e-07,
-            1.3326e-07,  1.5837e-07]]],
-
-
-        ...,
-
-
-        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          ...,
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00]],
-
-         [[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          ...,
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00]],
-
-         [[ 0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -4.1103e-02, -5.4732e-02],
-          [-3.1011e-02, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [-4.2861e-02, -8.1197e-03, -0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [-8.1318e-03, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [-3.0157e-02, -1.1246e-02, -0.0000e+00,  ..., -0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [-1.8279e-02, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-            0.0000e+00,  0.0000e+00]]],
-
-
-        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00, -0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00,  0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00],
-          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-           -0.0000e+00, -0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 5.9157e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            1.7233e-02,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]]],
-
-
-        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00, -0.0000e+00],
-          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00, -0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[ 0.0000e+00,  3.1678e-02,  0.0000e+00,  ...,  0.0000e+00,
-            1.1399e-03,  4.9293e-04],
-          [-0.0000e+00,  0.0000e+00,  1.4514e-01,  ...,  0.0000e+00,
-            0.0000e+00, -0.0000e+00],
-          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00, -0.0000e+00],
-          ...,
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  4.0631e-01,
-            2.6102e-01,  1.3537e-01],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            8.9127e-02,  0.0000e+00],
-          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]],
-
-         [[-0.0000e+00,  1.1473e-02,  0.0000e+00,  ...,  2.0317e-03,
-            0.0000e+00,  1.3542e-02],
-          [-0.0000e+00,  4.4105e-02,  5.8867e-02,  ...,  1.2212e-02,
-            0.0000e+00,  3.9528e-03],
-          [-0.0000e+00,  1.2382e-01,  4.1229e-02,  ...,  0.0000e+00,
-            0.0000e+00, -0.0000e+00],
-          ...,
-          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  2.2744e-01,
-            1.1968e-01,  7.2456e-02],
-          [-0.0000e+00,  0.0000e+00,  8.1858e-02,  ...,  1.5512e-01,
-            2.0697e-02,  0.0000e+00],
-          [-0.0000e+00,  0.0000e+00,  3.2302e-02,  ...,  0.0000e+00,
-            0.0000e+00,  0.0000e+00]]]], device='cuda:0', requires_grad=True)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.model[0][0]._mask
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[[[0., 1., 1.,  ..., 0., 0., 0.],
-          [0., 0., 1.,  ..., 1., 1., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]]],
-
-
-        [[[1., 1., 0.,  ..., 1., 1., 1.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [1., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 1., 1., 1.],
-          [0., 0., 1.,  ..., 1., 1., 1.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [1., 1., 0.,  ..., 0., 0., 0.],
-          ...,
-          [1., 0., 0.,  ..., 1., 1., 0.],
-          [0., 0., 0.,  ..., 1., 1., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]]],
-
-
-        [[[1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          ...,
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.]],
-
-         [[1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          ...,
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.]],
-
-         [[1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          ...,
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.],
-          [1., 1., 1.,  ..., 1., 1., 1.]]],
-
-
-        ...,
-
-
-        [[[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 0., 0.,  ..., 0., 1., 1.],
-          [1., 0., 0.,  ..., 0., 0., 0.],
-          [1., 1., 0.,  ..., 0., 0., 0.],
-          ...,
-          [1., 0., 0.,  ..., 0., 0., 0.],
-          [1., 1., 0.,  ..., 0., 0., 0.],
-          [1., 0., 0.,  ..., 0., 0., 0.]]],
-
-
-        [[[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [1., 0., 0.,  ..., 0., 1., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]]],
-
-
-        [[[0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 1., 0.,  ..., 0., 1., 1.],
-          [0., 0., 1.,  ..., 0., 0., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 1., 1., 1.],
-          [0., 0., 0.,  ..., 0., 1., 0.],
-          [0., 0., 0.,  ..., 0., 0., 0.]],
-
-         [[0., 1., 0.,  ..., 1., 0., 1.],
-          [0., 1., 1.,  ..., 1., 0., 1.],
-          [0., 1., 1.,  ..., 0., 0., 0.],
-          ...,
-          [0., 0., 0.,  ..., 1., 1., 1.],
-          [0., 0., 1.,  ..., 1., 1., 0.],
-          [0., 0., 1.,  ..., 0., 0., 0.]]]], device='cuda:0')
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
%debug
-
- -
-
-
- -
-
- -
- -
-
> /tmp/ipykernel_1952159/3353689575.py(27)granularize()
-     25             dim = granularities[m.__class__.__name__][g]
-     26             #if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(scores[None].mean(dim=dim, keepdim=True).squeeze(0))) # Put the mask into a buffer
----> 27             scores = scores[None].mean(dim=dim, keepdim=True).squeeze(0)
-     28         else: raise NameError('Invalid Granularity')
-     29         return scores
-
-ipdb> scores.shape
-torch.Size([11166912])
-ipdb> quit
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCallback(Callback):
-    def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        store_attr()
-        self.sparsity = listify(self.sparsity)
-
-    def before_fit(self):
-        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
-        assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
-        model = self.model if self.model else self.learn.model
-        self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.layer_type)
-
-    def before_epoch(self):
-        if self.epoch == self.rewind_epoch:
-            print(f'Saving Weights at epoch {self.epoch}')
-            self.sparsifier._save_weights()
-
-    def before_batch(self):
-        self.current_sparsity = self.schedule(self.sparsity, round(self.pct_train,3))
-        if self.schedule.pruned and self.training:
-            if self.lth and self.save_tickets:
-                print('Saving Intermediate Ticket')
-                self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-    def after_step(self):
-        if self.lth and self.schedule.pruned:
-            print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-            self.sparsifier._reset_weights(self.learn.model)
-        self.schedule.after_pruned()
-        self.sparsifier._apply_masks()
-
-    def after_epoch(self):
-        sparsity_str = [float(f"%0.2f"%sp) for sp in self.current_sparsity]
-        print(f'Sparsity at the end of epoch {self.epoch}: {sparsity_str}%')
-
-    def after_fit(self):
-        if self.save_tickets:
-            print('Saving Final Ticket')
-            self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-        print(f'Final Sparsity: {self.schedule.current_sparsity:}%')
-        if self.reset_end: self.sparsifier._reset_weights()
-        #self.sparsifier._clean_buffers()
-        self.schedule.reset()
-        self.sparsifier.print_sparsity()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm = nn.Sequential(nn.Linear(30,30), nn.Linear(1000,1000), nn.BatchNorm2d(3))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0].weight = nn.Parameter(torch.randn_like(mmm[0].weight))
-mmm[1].weight = nn.Parameter(torch.randn_like(mmm[1].weight))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0].weight = nn.Parameter(torch.Tensor(
-        [[ 0.421,  0.398,  0.056],
-        [-0.164, -0.456, -0.131],
-        [ 0.344,  0.335, -0.400]]))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[1].weight = nn.Parameter(torch.Tensor(
-        [[ 0.299,  0.512, -0.397],
-        [ 0.050, -0.472,  0.561],
-        [-0.028, -0.224,  0.427]]))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = Sparsifier(mmm, granularity='weight', context='global', criteria=updating_magnitude_increase, layer_type=nn.Linear)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0].weight = nn.Parameter(torch.randn_like(mmm[0].weight))
-mmm[1].weight = nn.Parameter(torch.randn_like(mmm[1].weight))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0].weight = nn.Parameter(torch.Tensor(
-        [[ 0.431,  0.399,  0.036],
-        [-0.114, -0.426, -0.172],
-        [ 0.348,  0.305, -0.456]]))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[1].weight = nn.Parameter(torch.Tensor(
-        [[ 0.212,  0.535, -0.398],
-        [ 0.067, -0.421,  0.456],
-        [-0.045, -0.287,  0.356]]))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase(mmm[0])[None].mean(dim=0, keepdim=True).squeeze(0).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(-0.0334, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase(mmm[1])[None].mean(dim=0, keepdim=True).squeeze(0).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(-0.0008, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase.__call__??
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase(mmm[0]).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.0117, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase(mmm[1]).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.0003, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
global_scores = torch.cat([updating_magnitude_increase(m).view(-1) for m in mmm.modules() if isinstance(m, nn.Linear)]) # Get all scores
-global_scores = updating_magnitude_increase.descale(updating_magnitude_increase.rescale(global_scores))
-                
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase.min_value
-
- -
-
-
- -
-
- -
- - - -
-
tensor(-4.8099, grad_fn=<SelectBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase.rescale(updating_magnitude_increase(mmm[0]), updating_magnitude_increase.min_value)[None].mean(dim=1, keepdim=True).squeeze(0).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(4.8364, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase.min_value
-
- -
-
-
- -
-
- -
- - - -
-
tensor(-4.8099, grad_fn=<SelectBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_magnitude_increase.rescale(updating_magnitude_increase(mmm[1]), updating_magnitude_increase.min_value)[None].mean(dim=1, keepdim=True).squeeze(0).mean()
-
- -
-
-
- -
-
- -
- - - -
-
tensor(4.8095, grad_fn=<MeanBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb.prune_model(50)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0].weight
-
- -
-
-
- -
-
- -
- - - -
-
Parameter containing:
-tensor([[-0.4697, -0.0000,  0.0000,  0.6964, -0.0000, -0.0000,  0.4197,  0.6107,
-         -0.3471, -1.5264, -0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
-         -0.0000,  0.0000, -0.0000,  0.0000, -0.7765,  0.0000,  0.3007,  0.0000,
-          1.3610,  0.0000, -0.0000, -1.0565,  2.1248,  0.6461],
-        [ 1.3334,  0.0000, -0.0000, -0.7524, -0.0000,  0.0000,  2.1995, -1.8947,
-          0.0000, -0.0000,  1.3457,  0.5020, -0.0000,  0.0000, -0.5550,  0.0000,
-          0.7765,  1.5074,  0.0000,  0.7520, -0.8449, -0.0000, -0.0000,  1.1098,
-         -0.0000, -1.3247,  0.3156, -0.8737,  2.6677,  0.0000],
-        [-1.1004, -0.0000,  0.0000,  0.0000, -0.9998,  0.0000, -1.7682, -0.5072,
-         -0.0000,  0.0000, -1.1983, -0.5300, -0.1178,  0.0000,  1.9009,  0.0000,
-         -1.1917,  0.1676, -2.0444,  0.0000,  1.2305,  0.0000, -0.2910, -0.0000,
-          0.0000,  0.0000,  1.6175, -1.0137, -0.0000,  0.0000],
-        [ 0.0000, -0.0000,  2.1317,  0.0000, -0.0000, -0.0000,  0.0000,  0.0000,
-          0.6983, -0.0000,  0.0000,  1.3517, -0.0000, -0.0000, -0.0000,  0.0000,
-         -1.1758,  0.0000, -0.0000,  0.0000,  0.0000,  0.0000,  1.9039,  1.8416,
-         -1.8862, -0.0000,  0.0000,  1.0628,  1.3996, -0.0000],
-        [-1.7224,  0.0000, -0.0000,  0.4172,  0.0000,  1.4383,  0.9212,  0.0000,
-         -4.0557, -0.9632, -0.0000,  0.0000,  0.1785,  1.1216,  0.0000,  0.4738,
-         -0.2778,  0.0000, -1.4880, -0.0000, -0.0000, -0.4614,  0.0000, -1.1963,
-         -0.6461, -0.9679,  0.4609,  2.0274, -0.0000,  0.5464],
-        [-0.0000, -1.0287, -0.0000,  0.0000, -1.1861, -0.0000,  0.0000, -0.0000,
-          1.3118,  1.0544, -0.3401, -0.0000, -0.4149, -0.0000, -0.0000,  0.0000,
-         -0.0000, -1.7844, -0.0000, -1.1136, -0.0000,  0.0000,  0.0000, -0.8604,
-         -2.2147, -0.8733, -0.0000,  0.0000, -1.4399, -0.6658],
-        [-1.1911,  1.1144, -1.6733,  1.8671,  2.7911, -1.0089,  0.0000, -0.0000,
-          0.6511, -0.7973,  0.3811, -0.0000, -0.3755, -0.0000, -1.3389,  0.0000,
-         -0.0000, -0.0000,  0.0000, -0.0000, -0.0000,  0.0000, -0.0000,  1.2961,
-          1.4071,  1.4808,  0.7312,  0.0000, -1.0309, -1.7718],
-        [ 0.0000, -0.0000, -0.4198, -0.0000, -0.0000, -0.0000, -0.0000,  1.6534,
-         -0.0000,  0.5711,  0.0000, -0.0000, -0.0000,  0.8588, -0.8003,  0.7455,
-          0.0000, -0.6631, -1.9123,  1.5362,  0.0000, -0.0000,  0.0000, -0.0000,
-          2.4530, -0.6909,  0.0000,  0.0000,  0.0000, -0.0000],
-        [-2.3368,  0.0000, -0.4693, -0.0000, -0.0000,  0.0000,  0.0000, -0.0000,
-         -0.5719,  0.0000, -0.0000,  0.4853,  0.0000, -2.0872, -0.0000,  0.0000,
-          0.0000, -1.1446, -1.0138,  0.0000, -0.0000, -0.0000, -1.9052,  0.3071,
-         -0.0000, -0.0000,  1.5891, -1.5744, -0.3440,  0.0000],
-        [ 2.8317,  0.0000, -1.6325,  0.0000, -0.0000, -0.0000, -0.2990,  1.6706,
-         -0.0000,  2.4660,  0.6963,  1.1858,  0.3070,  0.0000,  0.0000,  1.1636,
-         -0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -1.8598, -0.0000, -0.0000,
-          0.7630, -0.7048, -0.2939, -0.0000, -0.0000,  1.0744],
-        [ 0.7770, -1.7871,  0.0000,  0.0000,  0.0000, -0.5985, -0.0000,  1.4921,
-          0.0000,  0.0000, -0.1751, -0.0000,  0.8796,  0.7243, -2.2937, -0.0000,
-          0.0000, -1.0126,  1.3886, -1.0761, -2.1231, -1.7882,  0.0000, -0.0000,
-         -0.0000, -0.0000,  1.3518, -0.0000, -0.0000, -0.0000],
-        [-1.3437,  0.5383, -0.6913, -0.0000,  0.7758,  0.0000, -0.0000,  1.0919,
-          1.2038,  0.0000, -0.0000,  1.6652,  0.0000,  0.0000, -0.5265, -1.1390,
-          0.0000, -0.0000,  1.5858, -0.5444,  0.0000,  0.0000, -0.0000,  1.0638,
-          1.3580, -0.0000, -0.0000, -0.0000, -0.5559,  0.0000],
-        [ 0.0000, -2.2599, -0.0000,  0.0000, -1.0718, -0.0000, -1.3622, -0.0000,
-         -0.0000,  1.4162,  0.0000, -0.0000,  0.0000, -1.4348, -0.0000, -1.1684,
-         -2.0178,  1.7049,  0.0000,  1.5161, -0.2162,  1.5940,  0.7751, -2.7666,
-          0.0000, -2.0388,  1.4220,  0.4973, -0.6650, -0.0000],
-        [-1.0671, -0.3545, -0.0000, -0.0000, -0.0000, -0.0000,  0.0000,  0.9151,
-         -0.0000,  0.0000, -0.4876,  2.1664, -0.7885, -0.0000,  0.0000,  0.0000,
-         -0.0000,  0.0000, -1.2401, -0.0000, -0.1766,  1.2207, -0.7934, -0.0000,
-          2.7660,  0.0000, -0.4780, -1.3265, -0.9518, -0.0000],
-        [ 1.0592,  0.0000,  0.5933,  0.0000, -0.7153,  2.5650, -0.9966, -0.0000,
-         -0.7905,  0.0000, -0.0000, -1.1730, -0.0000,  0.0000,  0.0000, -0.0000,
-          1.4408, -0.0000,  1.0600,  0.0000, -0.0000, -0.0000, -0.0000,  0.0000,
-         -0.2164, -0.0000,  0.7504, -0.0000,  0.1538, -2.4244],
-        [-1.7428,  0.7574, -0.0000,  1.4356, -0.0000, -0.0000,  0.0000,  0.0000,
-          0.0000,  0.0000,  0.0000, -1.0921,  0.6738,  0.0000, -1.1633, -0.0000,
-          0.0000, -0.0000, -1.4767, -1.1387,  0.0000, -0.0000, -0.4164, -0.7833,
-          0.0000, -0.0000, -0.3589, -0.0000,  0.6142,  0.0000],
-        [ 0.0000,  0.0000,  0.0000,  1.2298, -0.7466,  1.7485,  0.4276, -2.8522,
-         -0.0000,  0.7918,  1.2715,  0.0000, -0.0000, -0.0000,  0.0000,  0.0000,
-         -1.0455,  0.7355, -1.0363,  1.2118, -0.8099,  1.0770, -1.9903,  2.3046,
-         -1.3997, -0.2232, -1.2704,  0.0000,  0.9377,  0.2910],
-        [ 0.0000,  1.0140, -0.0000, -1.6978,  1.6673,  0.3279,  0.0000,  0.8625,
-         -1.3830, -0.0000, -0.7740, -0.7179,  0.2901,  0.0000, -0.0000, -0.0000,
-          0.0000, -0.0000,  0.0000,  0.0000,  0.0000, -1.4254,  2.4915,  2.4178,
-         -0.0000,  0.0000, -0.0000, -0.4720,  0.0000,  0.9163],
-        [ 0.8333,  0.0000,  0.0000, -0.0000,  0.5730,  1.2512, -0.0000, -0.0786,
-         -0.0000, -0.0000,  1.2939, -1.5655,  0.0000, -0.0000, -0.0000, -0.0000,
-          0.8519,  1.0708,  0.0000,  1.1720, -0.8262, -0.9850,  1.3785,  0.0000,
-          0.0000, -0.0000,  0.0000, -0.0000,  0.8973, -0.0000],
-        [-0.0000, -2.0345, -0.4941, -0.2799,  0.0000, -0.0000, -1.1654,  0.0000,
-          0.0000,  0.0000,  0.0000,  1.1161,  1.6155,  0.9361,  1.0208,  0.0000,
-          0.0000,  1.8135, -1.3932, -0.0000,  0.0000, -0.0000,  1.1717,  0.9664,
-          0.0000,  1.1022,  0.4991, -1.6445, -1.3458, -0.0000],
-        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4807,  0.0000,  1.8799,
-         -0.0000,  1.6223, -0.5247,  0.0000, -0.0000,  0.4769, -0.3047, -0.8768,
-         -1.1120, -1.4647,  2.7974,  2.2982, -0.0000, -2.3688, -0.0000, -2.0119,
-         -2.1981, -0.8121, -0.0000,  0.0000,  0.9923,  0.3450],
-        [-0.0000,  2.7514, -0.0000,  0.9442,  0.0000,  1.5887,  0.0000, -0.0000,
-         -0.0000, -0.0000, -0.0000, -0.0000,  2.5753, -0.0000,  0.8475,  0.0000,
-          1.1907, -1.2603, -1.3330, -0.0000,  0.0000,  0.0000,  0.2910, -1.2485,
-         -0.0000,  1.7944,  0.0000, -0.4205, -0.0000, -0.0000],
-        [ 2.7456,  0.0000,  1.3690, -0.0000,  0.9224, -0.0000, -0.4663,  0.9413,
-          0.0000, -1.5284,  0.0000, -0.0000, -0.3151, -0.4256,  0.0000, -0.5944,
-          0.0000,  0.9653,  1.1356, -1.0759, -0.4597,  0.4099, -1.7619, -0.0000,
-         -0.0000,  0.9033, -0.1668,  1.6416, -0.0000, -1.0526],
-        [-0.0000,  0.0000, -0.0000,  2.3039,  0.0000,  0.0000, -0.0000, -0.0000,
-         -0.0000, -0.0000,  0.8914, -0.0000, -2.4315,  0.0000,  0.0000, -0.0000,
-         -0.9459, -2.2884,  0.5098, -0.8177, -0.0000, -0.0000, -0.9245, -0.0000,
-          1.4152,  0.2603, -0.5356, -0.0000, -4.0684, -1.9600],
-        [-0.0000, -0.0000,  0.0000,  2.2072, -2.6049, -0.0000,  0.2980,  1.8473,
-         -1.7405, -0.0000, -0.9643,  0.0000,  1.2898, -0.0925, -0.8666, -0.9252,
-         -1.5343,  1.3880, -0.4778,  0.9656, -0.0000,  0.8843,  1.2774,  0.0000,
-         -0.0000,  0.0000,  0.0000,  0.0000, -0.2223,  0.0000],
-        [ 0.0000,  0.9122,  0.0000,  0.0000, -0.0000, -0.0000, -0.6227, -0.0000,
-         -0.5466,  0.0000,  1.3575,  0.0000, -1.2804,  0.0000,  1.3495,  1.1257,
-          0.7999,  0.0000, -0.7655, -0.0000,  0.0000,  0.0000,  0.0000, -0.0000,
-          0.0000,  0.9295, -0.0000, -1.4763, -2.0517,  0.8248],
-        [-0.0000, -0.0000, -0.5591,  1.4653, -1.9468, -0.0000, -0.0000,  1.6211,
-          0.7414,  0.2067,  0.0000,  0.0000, -1.4623, -1.3966,  0.3445, -0.3798,
-          0.3531,  0.9742, -2.7739, -0.2327, -0.7980, -1.3762, -1.5635,  0.0000,
-          0.7634, -1.3027, -0.0000,  0.0000,  0.4137,  0.6250],
-        [ 1.2535,  0.0000,  0.5963,  0.7589,  0.0000,  0.0000, -0.0000, -2.1696,
-         -0.9564,  1.0425, -0.0000, -0.5696,  1.1524,  1.4433, -0.0000, -0.0000,
-         -0.0000, -1.6302,  0.0000,  0.0000, -1.3254,  0.1585, -0.2137,  0.3994,
-         -0.0000,  0.0000, -0.8294, -1.8059, -1.6328, -0.0000],
-        [-1.3658,  0.0000,  0.5649,  1.3609, -0.2278,  0.7200, -0.1217,  0.9329,
-          0.0000, -0.2575,  0.0000, -0.4812, -0.6969, -0.0000, -0.0000, -0.0000,
-          0.0000,  0.8209, -0.0000,  0.8714,  0.4689,  0.0000,  0.0000,  0.0000,
-         -0.0000, -0.6730, -0.6239,  0.0000,  0.8790, -0.6795],
-        [ 0.3636, -0.0000, -0.0000, -0.5781,  0.4282, -0.0000, -0.7778,  1.1833,
-         -0.0000,  1.5292, -2.2093,  0.0000, -0.0000, -0.0000,  1.5653,  0.3961,
-         -1.0770, -0.0000, -0.0000,  0.0000,  0.0000,  1.0421, -0.0000, -0.0000,
-          2.1614,  0.8494,  0.0000, -0.0000, -0.0000,  0.0000]],
-       requires_grad=True)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[1].weight
-
- -
-
-
- -
-
- -
- - - -
-
Parameter containing:
-tensor([[ 0.0000,  0.5350, -0.3980],
-        [ 0.0670, -0.0000,  0.0000],
-        [-0.0450, -0.2870,  0.0000]], requires_grad=True)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb.model[0]._old_weights
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[ 0.4310,  0.3990,  0.0000],
-        [-0.0000, -0.0000, -0.1720],
-        [ 0.3480,  0.0000, -0.4560]])
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb.model[1]._old_weights
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[ 0.0000,  0.5350, -0.3980],
-        [ 0.0670, -0.0000,  0.0000],
-        [-0.0450, -0.2870,  0.0000]])
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0].weight = nn.Parameter(torch.Tensor(
-        [[ 0.4112,  0.3787,  0.0563],
-        [-0.1742, -0.4459, -0.2345],
-        [ 0.4584,  0.3128, -0.7657]]))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[1].weight = nn.Parameter(torch.Tensor(
-        [[ 0.1982,  0.4983, -0.2781],
-        [ 0.0819, -0.2375,  0.],
-        [-0.0975, -0.1970,  0.2969]]))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_movement(mmm[0], 'weight')
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[0.0198, 0.0203, 0.0563],
-        [0.1742, 0.4459, 0.0625],
-        [0.1104, 0.3128, 0.3097]], grad_fn=<AbsBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
updating_movement(mmm[1], 'weight')
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[0.1982, 0.0367, 0.1199],
-        [0.0149, 0.2375, 0.0000],
-        [0.0525, 0.0900, 0.2969]], grad_fn=<AbsBackward0>)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0]._mask
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[1., 1., 0.],
-        [0., 0., 1.],
-        [1., 0., 1.]])
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[1]._mask
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[0., 1., 1.],
-        [1., 0., 0.],
-        [1., 1., 0.]])
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb.prune_model(70)
-
- -
-
-
- -
-
- -
- -
-
tensor(-0.1199, grad_fn=<SelectBackward0>)
-tensor(-0.1199, grad_fn=<SelectBackward0>)
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0].weight
-
- -
-
-
- -
-
- -
- - - -
-
Parameter containing:
-tensor([[ 0.4112,  0.0000,  0.0000],
-        [-0.0000, -0.0000, -0.2345],
-        [ 0.4584,  0.0000, -0.7657]], requires_grad=True)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[1].weight
-
- -
-
-
- -
-
- -
- - - -
-
Parameter containing:
-tensor([[ 0.0000,  0.0000, -0.0000],
-        [ 0.0819, -0.0000,  0.0000],
-        [-0.0975, -0.0000,  0.0000]], requires_grad=True)
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[0]._mask 
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[1., 0., 0.],
-        [0., 0., 1.],
-        [1., 0., 1.]])
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
mmm[1]._mask
-
- -
-
-
- -
-
- -
- - - -
-
tensor([[0., 0., 0.],
-        [1., 0., 0.],
-        [1., 0., 0.]])
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-        self.min_value=0
-
-    def __call__(self, m, g):
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            wf = self.f(m.weight)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_init: wi = self.f(m._init_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_update: wi = self.f(m._old_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-
-        else: raise NameError('Invalid Granularity')
-            
-        if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(wf)) # Put the mask into a buffer
-
-        if self.output_f: output = self.output_f(wf, wi)
-        elif self.return_init: output = wi
-        else: output = wf
-        return output
-    
-    def rescale(self, w, min_value=None):
-        self.min_value = min_value if min_value else w.view(-1)[w.argmin()]
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        return output
-    
-    def descale(self, w):
-        output = w - self.min_value.abs()
-        return output
-        
-    def update_weights(self, m):
-        if self.needs_update: 
-            m._old_weights = m.weight.data.clone() # The current value becomes the old one for the next iteration
-
-updating_magnitude_increase = Criteria(torch.abs, needs_update=True, output_f= torch.sub)
-magnitude_increase = Criteria(torch.abs, needs_init=True, output_f= torch.sub)
-
-
-class SparsifyCallback(Callback):
-    def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        store_attr()
-        self.sparsity = listify(self.sparsity)
-
-    def before_fit(self):
-        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
-        assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
-        model = self.model if self.model else self.learn.model
-        self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.layer_type)
-
-    def before_epoch(self):
-        if self.epoch == self.rewind_epoch:
-            print(f'Saving Weights at epoch {self.epoch}')
-            self.sparsifier._save_weights()
-
-    def before_batch(self):
-        self.current_sparsity = self.schedule(self.sparsity, round(self.pct_train,3))
-        if self.schedule.pruned and self.training:
-            if self.lth and self.save_tickets:
-                print('Saving Intermediate Ticket')
-                self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-    def after_step(self):
-        if self.lth and self.schedule.pruned:
-            print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-            self.sparsifier._reset_weights(self.learn.model)
-        self.schedule.after_pruned()
-        self.sparsifier._apply_masks()
-
-    def after_epoch(self):
-        sparsity_str = [float(f"%0.2f"%sp) for sp in self.current_sparsity]
-        print(f'Sparsity at the end of epoch {self.epoch}: {sparsity_str}%')
-        self.sparsifier.print_sparsity()
-
-    def after_fit(self):
-        if self.save_tickets:
-            print('Saving Final Ticket')
-            self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
-        print(f'Final Sparsity: {self.schedule.current_sparsity:}%')
-        if self.reset_end: self.sparsifier._reset_weights()
-        self.sparsifier._clean_buffers()
-        self.schedule.reset()
-        self.sparsifier.print_sparsity()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
large_final = Criteria(torch.abs)
-movement = Criteria(noop, needs_init=True, output_f= lambda x,y: torch.abs(torch.sub(x,y)))
-updating_movement = Criteria(noop, needs_update=True, output_f= lambda x,y: torch.abs(torch.sub(x,y)))
-large_init_large_final = Criteria(torch.abs, needs_init=True, output_f=torch.min)
-large_init = Criteria(torch.abs, needs_init=True, return_init=True)
-#updating_movmag = Criteria(noop, needs_update=True, output_f=lambda x,y: torch.mul(x, torch.sub(x,y)))
-#updating_movmag = Criteria(noop, needs_update=True, output_f=lambda x,y: torch.mul(x, torch.sub(y,x)))
-updating_movmag = Criteria(noop, needs_update=True, output_f=lambda x,y: torch.abs(torch.mul(x, torch.sub(x,y))))
-squared_final = Criteria(torch.square)
-updating_movement2 = Criteria(noop, needs_update=True, output_f= lambda x,y: -torch.abs(torch.sub(x,y)))
-updating_magnitude_increase = Criteria(torch.abs, needs_update=True, output_f= lambda x,y: torch.abs(torch.sub(x,y)))
-updating_magnitude_increase = Criteria(torch.square, needs_update=True, output_f= lambda x,y: torch.abs(torch.sub(x,y)))
-updating_movement = Criteria(torch.square, needs_update=True, output_f= lambda x,y: torch.abs(torch.sub(y,x)))
-updating_movement = Criteria(torch.abs, needs_update=True, output_f= lambda x,y: torch.abs(torch.sub(-x,y)))
-updating_movement = Criteria(noop, needs_init=True, output_f= lambda x,y: torch.abs(torch.sub(x,-y)))
-updating_movmag = Criteria(noop, needs_update=True, output_f=lambda x,y: torch.abs(torch.mul(torch.square(x), torch.sub(x,y))))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
-sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='global', criteria=updating_magnitude_increase, schedule=cos)
-learn.fit_one_cycle(10, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
Pruning of weight until a sparsity of [50]%
-Saving Weights at epoch 0
-
-
-
- -
- - -
- -
- -
- -
- - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime
00.7615390.5710410.82611600:09
10.4898180.3420610.84167800:09
20.3360390.3503400.86671200:09
30.2510310.2706180.88159700:09
40.1824630.2286120.90324800:09
50.1423170.2104870.90866000:09
60.1159570.1982450.92692800:09
70.0696800.2204490.92692800:09
80.0475300.2280660.92151600:09
90.0460270.2368400.92016200:09
- -
- -
- -
-
Sparsity at the end of epoch 0: [1.22]%
-Sparsity in Conv2d 2: 13.57%
-Sparsity in Conv2d 8: 13.23%
-Sparsity in Conv2d 11: 0.87%
-Sparsity in Conv2d 14: 0.84%
-Sparsity in Conv2d 17: 0.76%
-Sparsity in Conv2d 21: 0.84%
-Sparsity in Conv2d 24: 0.81%
-Sparsity in Conv2d 27: 0.87%
-Sparsity in Conv2d 30: 0.82%
-Sparsity in Conv2d 33: 0.84%
-Sparsity in Conv2d 37: 0.90%
-Sparsity in Conv2d 40: 0.96%
-Sparsity in Conv2d 43: 0.79%
-Sparsity in Conv2d 46: 1.01%
-Sparsity in Conv2d 49: 1.07%
-Sparsity in Conv2d 53: 1.06%
-Sparsity in Conv2d 56: 1.15%
-Sparsity in Conv2d 59: 0.85%
-Sparsity in Conv2d 62: 1.11%
-Sparsity in Conv2d 65: 1.47%
-Sparsity at the end of epoch 1: [4.77]%
-Sparsity in Conv2d 2: 14.39%
-Sparsity in Conv2d 8: 14.35%
-Sparsity in Conv2d 11: 2.14%
-Sparsity in Conv2d 14: 2.13%
-Sparsity in Conv2d 17: 1.91%
-Sparsity in Conv2d 21: 2.07%
-Sparsity in Conv2d 24: 2.40%
-Sparsity in Conv2d 27: 1.94%
-Sparsity in Conv2d 30: 2.42%
-Sparsity in Conv2d 33: 2.47%
-Sparsity in Conv2d 37: 2.68%
-Sparsity in Conv2d 40: 3.05%
-Sparsity in Conv2d 43: 2.48%
-Sparsity in Conv2d 46: 3.47%
-Sparsity in Conv2d 49: 3.77%
-Sparsity in Conv2d 53: 3.85%
-Sparsity in Conv2d 56: 5.24%
-Sparsity in Conv2d 59: 2.57%
-Sparsity in Conv2d 62: 4.87%
-Sparsity in Conv2d 65: 6.36%
-Sparsity at the end of epoch 2: [10.31]%
-Sparsity in Conv2d 2: 14.77%
-Sparsity in Conv2d 8: 15.15%
-Sparsity in Conv2d 11: 3.03%
-Sparsity in Conv2d 14: 3.05%
-Sparsity in Conv2d 17: 2.92%
-Sparsity in Conv2d 21: 3.01%
-Sparsity in Conv2d 24: 3.60%
-Sparsity in Conv2d 27: 2.69%
-Sparsity in Conv2d 30: 3.66%
-Sparsity in Conv2d 33: 3.88%
-Sparsity in Conv2d 37: 4.21%
-Sparsity in Conv2d 40: 4.93%
-Sparsity in Conv2d 43: 3.99%
-Sparsity in Conv2d 46: 5.91%
-Sparsity in Conv2d 49: 6.53%
-Sparsity in Conv2d 53: 7.01%
-Sparsity in Conv2d 56: 12.08%
-Sparsity in Conv2d 59: 4.33%
-Sparsity in Conv2d 62: 11.75%
-Sparsity in Conv2d 65: 14.74%
-Sparsity at the end of epoch 3: [17.27]%
-Sparsity in Conv2d 2: 15.51%
-Sparsity in Conv2d 8: 16.17%
-Sparsity in Conv2d 11: 4.27%
-Sparsity in Conv2d 14: 4.19%
-Sparsity in Conv2d 17: 4.13%
-Sparsity in Conv2d 21: 4.23%
-Sparsity in Conv2d 24: 5.04%
-Sparsity in Conv2d 27: 3.80%
-Sparsity in Conv2d 30: 5.23%
-Sparsity in Conv2d 33: 5.60%
-Sparsity in Conv2d 37: 6.10%
-Sparsity in Conv2d 40: 7.26%
-Sparsity in Conv2d 43: 5.93%
-Sparsity in Conv2d 46: 8.78%
-Sparsity in Conv2d 49: 9.91%
-Sparsity in Conv2d 53: 10.95%
-Sparsity in Conv2d 56: 19.87%
-Sparsity in Conv2d 59: 6.63%
-Sparsity in Conv2d 62: 21.28%
-Sparsity in Conv2d 65: 25.43%
-Sparsity at the end of epoch 4: [25.0]%
-Sparsity in Conv2d 2: 16.67%
-Sparsity in Conv2d 8: 17.75%
-Sparsity in Conv2d 11: 5.93%
-Sparsity in Conv2d 14: 5.75%
-Sparsity in Conv2d 17: 5.85%
-Sparsity in Conv2d 21: 6.14%
-Sparsity in Conv2d 24: 7.22%
-Sparsity in Conv2d 27: 5.31%
-Sparsity in Conv2d 30: 7.46%
-Sparsity in Conv2d 33: 8.03%
-Sparsity in Conv2d 37: 8.63%
-Sparsity in Conv2d 40: 10.37%
-Sparsity in Conv2d 43: 8.40%
-Sparsity in Conv2d 46: 12.57%
-Sparsity in Conv2d 49: 14.23%
-Sparsity in Conv2d 53: 15.97%
-Sparsity in Conv2d 56: 28.28%
-Sparsity in Conv2d 59: 9.90%
-Sparsity in Conv2d 62: 31.62%
-Sparsity in Conv2d 65: 36.79%
-Sparsity at the end of epoch 5: [32.73]%
-Sparsity in Conv2d 2: 17.86%
-Sparsity in Conv2d 8: 19.73%
-Sparsity in Conv2d 11: 8.12%
-Sparsity in Conv2d 14: 8.07%
-Sparsity in Conv2d 17: 7.99%
-Sparsity in Conv2d 21: 8.58%
-Sparsity in Conv2d 24: 9.92%
-Sparsity in Conv2d 27: 7.19%
-Sparsity in Conv2d 30: 10.24%
-Sparsity in Conv2d 33: 11.03%
-Sparsity in Conv2d 37: 11.78%
-Sparsity in Conv2d 40: 14.03%
-Sparsity in Conv2d 43: 11.42%
-Sparsity in Conv2d 46: 17.01%
-Sparsity in Conv2d 49: 19.24%
-Sparsity in Conv2d 53: 21.69%
-Sparsity in Conv2d 56: 36.67%
-Sparsity in Conv2d 59: 13.92%
-Sparsity in Conv2d 62: 41.53%
-Sparsity in Conv2d 65: 47.52%
-Sparsity at the end of epoch 6: [39.69]%
-Sparsity in Conv2d 2: 19.31%
-Sparsity in Conv2d 8: 22.28%
-Sparsity in Conv2d 11: 11.07%
-Sparsity in Conv2d 14: 10.78%
-Sparsity in Conv2d 17: 10.89%
-Sparsity in Conv2d 21: 11.64%
-Sparsity in Conv2d 24: 13.37%
-Sparsity in Conv2d 27: 9.96%
-Sparsity in Conv2d 30: 13.68%
-Sparsity in Conv2d 33: 14.68%
-Sparsity in Conv2d 37: 15.57%
-Sparsity in Conv2d 40: 18.30%
-Sparsity in Conv2d 43: 15.03%
-Sparsity in Conv2d 46: 21.92%
-Sparsity in Conv2d 49: 24.58%
-Sparsity in Conv2d 53: 27.60%
-Sparsity in Conv2d 56: 44.14%
-Sparsity in Conv2d 59: 18.41%
-Sparsity in Conv2d 62: 49.92%
-Sparsity in Conv2d 65: 56.39%
-Sparsity at the end of epoch 7: [45.23]%
-Sparsity in Conv2d 2: 21.10%
-Sparsity in Conv2d 8: 24.27%
-Sparsity in Conv2d 11: 13.47%
-Sparsity in Conv2d 14: 13.09%
-Sparsity in Conv2d 17: 13.26%
-Sparsity in Conv2d 21: 14.07%
-Sparsity in Conv2d 24: 16.16%
-Sparsity in Conv2d 27: 12.49%
-Sparsity in Conv2d 30: 16.46%
-Sparsity in Conv2d 33: 17.55%
-Sparsity in Conv2d 37: 18.62%
-Sparsity in Conv2d 40: 21.72%
-Sparsity in Conv2d 43: 18.00%
-Sparsity in Conv2d 46: 25.96%
-Sparsity in Conv2d 49: 28.86%
-Sparsity in Conv2d 53: 32.31%
-Sparsity in Conv2d 56: 49.68%
-Sparsity in Conv2d 59: 22.11%
-Sparsity in Conv2d 62: 57.19%
-Sparsity in Conv2d 65: 63.16%
-Sparsity at the end of epoch 8: [48.78]%
-Sparsity in Conv2d 2: 23.46%
-Sparsity in Conv2d 8: 26.24%
-Sparsity in Conv2d 11: 15.71%
-Sparsity in Conv2d 14: 15.49%
-Sparsity in Conv2d 17: 15.37%
-Sparsity in Conv2d 21: 16.57%
-Sparsity in Conv2d 24: 18.71%
-Sparsity in Conv2d 27: 14.66%
-Sparsity in Conv2d 30: 19.02%
-Sparsity in Conv2d 33: 20.18%
-Sparsity in Conv2d 37: 21.35%
-Sparsity in Conv2d 40: 24.63%
-Sparsity in Conv2d 43: 20.71%
-Sparsity in Conv2d 46: 29.16%
-Sparsity in Conv2d 49: 32.13%
-Sparsity in Conv2d 53: 35.85%
-Sparsity in Conv2d 56: 53.37%
-Sparsity in Conv2d 59: 25.14%
-Sparsity in Conv2d 62: 61.19%
-Sparsity in Conv2d 65: 67.01%
-Sparsity at the end of epoch 9: [50.0]%
-Sparsity in Conv2d 2: 25.31%
-Sparsity in Conv2d 8: 27.31%
-Sparsity in Conv2d 11: 16.86%
-Sparsity in Conv2d 14: 16.77%
-Sparsity in Conv2d 17: 16.53%
-Sparsity in Conv2d 21: 17.72%
-Sparsity in Conv2d 24: 19.72%
-Sparsity in Conv2d 27: 15.84%
-Sparsity in Conv2d 30: 20.10%
-Sparsity in Conv2d 33: 21.17%
-Sparsity in Conv2d 37: 22.36%
-Sparsity in Conv2d 40: 25.59%
-Sparsity in Conv2d 43: 21.91%
-Sparsity in Conv2d 46: 30.15%
-Sparsity in Conv2d 49: 33.14%
-Sparsity in Conv2d 53: 36.91%
-Sparsity in Conv2d 56: 54.49%
-Sparsity in Conv2d 59: 26.49%
-Sparsity in Conv2d 62: 62.55%
-Sparsity in Conv2d 65: 68.04%
-Final Sparsity: [50.0]%
-Sparsity in Conv2d 2: 25.31%
-Sparsity in Conv2d 8: 27.31%
-Sparsity in Conv2d 11: 16.86%
-Sparsity in Conv2d 14: 16.77%
-Sparsity in Conv2d 17: 16.53%
-Sparsity in Conv2d 21: 17.72%
-Sparsity in Conv2d 24: 19.72%
-Sparsity in Conv2d 27: 15.84%
-Sparsity in Conv2d 30: 20.10%
-Sparsity in Conv2d 33: 21.17%
-Sparsity in Conv2d 37: 22.36%
-Sparsity in Conv2d 40: 25.59%
-Sparsity in Conv2d 43: 21.91%
-Sparsity in Conv2d 46: 30.15%
-Sparsity in Conv2d 49: 33.14%
-Sparsity in Conv2d 53: 36.91%
-Sparsity in Conv2d 56: 54.49%
-Sparsity in Conv2d 59: 26.49%
-Sparsity in Conv2d 62: 62.55%
-Sparsity in Conv2d 65: 68.04%
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
Sparsifier._compute_threshold??
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
zs, ps = 0,0
-for k,m in enumerate(learn.model.modules()):
-    if isinstance(m, nn.Conv2d):
-        zs += torch.sum(m.weight == 0)
-        ps += float(m.weight.nelement())
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
large_final weight local: 93.3
-upd mag increase weight local: 88.3
-mag increase weight local: 87.3
-    
-upd mag increase weight global: 
-mag increase weight global: 86.9
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
zs/ps
-
- -
-
-
- -
-
- -
- - - -
-
tensor(0.3000, device='cuda:0')
-
- -
- -
-
- -
- {% endraw %} - -
-
-

Last version

-
-
-
- {% raw %} - -
-
- -
-
-
class Criteria():
-    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
-        store_attr()
-        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."
-
-    def __call__(self, m, g):
-        self.min_value=None
-        if self.needs_update and hasattr(m, '_old_weights') == False:
-            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value
-
-        if g in granularities[m.__class__.__name__]:
-            dim = granularities[m.__class__.__name__][g]
-            wf = self.f(m.weight)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_init: wi = self.f(m._init_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-            if self.needs_update: wi = self.f(m._old_weights)[None].mean(dim=dim, keepdim=True).squeeze(0)
-
-        else: raise NameError('Invalid Granularity')
-            
-        if hasattr(m, '_mask') == False: m.register_buffer("_mask", torch.ones_like(wf)) # Put the mask into a buffer
-
-        if self.output_f: output = self.output_f(wf, wi)
-        elif self.return_init: output = wi
-        else: output = wf
-        return output
-    
-    def rescale(self, w, min_value=None):
-        self.min_value = min_value if min_value else w.view(-1)[w.argmin()]
-        output =  w + self.min_value.abs() + torch.finfo(torch.float32).eps
-        return output
-    
-    def descale(self, w):
-        output = w - self.min_value.abs()
-        return output
-        
-    def update_weights(self, m):
-        if self.needs_update: 
-            m._old_weights = m.weight.data.clone() # The current value becomes the old one for the next iteration
-            
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Sparsifier():
-    def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):
-        store_attr()
-        self._save_weights() # Save the original weights
-
-    def prune_layer(self, m, sparsity, round_to=None, k=0):
-        scores = self.criteria(m, self.granularity)
-        setattr(m, '_mask', self._compute_mask(m, scores, sparsity, round_to))
-        self._apply(m)
-        self.criteria.update_weights(m)
-
-    def prune_model(self, sparsity, round_to=None):
-        self.threshold=None
-        sparsity_list = listify(sparsity)
-        if len(sparsity_list)>1: assert self.context=='local', f"A list of sparsities cannot be passed using: {self.context}"
-        sparsities = cycle(sparsity_list) if len(sparsity_list)==1 else iter(sparsity_list)
-        mods = list(self.model.modules())
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type): 
-                sp = next(sparsities)
-                self.prune_layer(m, sp, round_to,k)
-                if isinstance(mods[k+1], nn.modules.batchnorm._BatchNorm): self.prune_batchnorm(m, mods[k+1])
-                
-    def prune_batchnorm(self, m, bn):
-        mask = getattr(m, "_mask", None)
-        if self.granularity == 'filter' and mask is not None:
-            bn.weight.data.mul_(mask.squeeze())
-            bn.bias.data.mul_(mask.squeeze())
-            
-    def _apply_masks(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type):
-                self._apply(m)
-        
-    def _apply(self, m):
-        mask = getattr(m, "_mask", None)
-        if mask is not None: m.weight.data.mul_(mask)
-        if self.granularity == 'filter' and m.bias is not None:
-            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters
-    
-    def _reset_weights(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                init_weights = getattr(m, "_init_weights", m.weight)
-                init_biases = getattr(m, "_init_biases", m.bias)
-                with torch.no_grad():
-                    if m.weight is not None: m.weight.copy_(init_weights)
-                    if m.bias is not None: m.bias.copy_(init_biases)
-                self._apply(m)
-            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
-                
-    def _save_weights(self):
-        for m in self.model.modules():
-            if hasattr(m, 'weight'):              
-                m.register_buffer("_init_weights", m.weight.clone())
-                b = getattr(m, 'bias', None)
-                if b is not None: m.register_buffer("_init_biases", b.clone())
-                    
-    def save_model(self, path, model=None):
-        if not model: model=self.model
-        tmp_model = pickle.loads(pickle.dumps(model))
-        self._reset_weights(tmp_model)
-        self._clean_buffers(tmp_model)
-        torch.save(tmp_model, path)
-
-    def _clean_buffers(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                if hasattr(m, '_mask'): del m._buffers["_mask"]
-                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
-                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
-    
-    def _compute_threshold(self, scores, sparsity):
-        if self.context == 'global':
-            if self.threshold is None: 
-                global_scores = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all scores
-                global_mask = torch.cat([m._mask.view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all masks
-                global_scores, self.min_value = self.criteria.rescale(global_scores).mul_(global_mask.squeeze())
-                global_scores = self.criteria.descale(global_scores.mul_(global_mask.squeeze()))
-                self.threshold = torch.quantile(global_scores, sparsity/100) # Compute the threshold globally (only once per model pruning)
-            return self.threshold
-        elif self.context == 'local':
-            return torch.quantile(scores.view(-1), sparsity/100) # Compute the threshold locally
-        else: raise NameError('Invalid Context')
-
-    def _rounded_sparsity(self, n_to_prune, round_to):
-        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)
-
-    def _compute_mask(self, m, scores, sparsity, round_to):
-        if self.context=='local': scores = self.criteria.descale(self.criteria.rescale(scores)[0].mul_(m._mask)) # We don't want to scale individual layers in global
-
-        self.threshold = self._compute_threshold(scores, sparsity)
-        if self.context=='global': scores = self.criteria.descale(self.criteria.rescale(scores, self.min_value)[0].mul_(m._mask))
-        
-        if round_to:
-            n_to_keep = sum(scores.ge(self.threshold)).squeeze()
-            self.threshold = torch.topk(scores.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
-        if self.threshold > scores.max(): self.threshold = scores.max() # Make sure we don't remove every weight of a given layer
-        return scores.ge(self.threshold).to(dtype=scores.dtype)
-    
-    def print_sparsity(self):
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Sparsifier():
-    def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):
-        store_attr()
-        self._save_weights() # Save the original weights
-
-    def prune_layer(self, m, sparsity, round_to=None, k=0):
-        scores = self.criteria(m, self.granularity)
-        setattr(m, '_mask', self._compute_mask(m, scores, sparsity, round_to))
-        self._apply(m)
-        self.criteria.update_weights(m)
-
-    def prune_model(self, sparsity, round_to=None):
-        self.threshold=None
-        sparsity_list = listify(sparsity)
-        if len(sparsity_list)>1: assert self.context=='local', f"A list of sparsities cannot be passed using: {self.context}"
-        sparsities = cycle(sparsity_list) if len(sparsity_list)==1 else iter(sparsity_list)
-        mods = list(self.model.modules())
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type): 
-                sp = next(sparsities)
-                self.prune_layer(m, sp, round_to,k)
-                if isinstance(mods[k+1], nn.modules.batchnorm._BatchNorm): self.prune_batchnorm(m, mods[k+1])
-                
-    def prune_batchnorm(self, m, bn):
-        mask = getattr(m, "_mask", None)
-        if self.granularity == 'filter' and mask is not None:
-            bn.weight.data.mul_(mask.squeeze())
-            bn.bias.data.mul_(mask.squeeze())
-            
-    def _apply_masks(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type):
-                self._apply(m)
-        
-    def _apply(self, m):
-        mask = getattr(m, "_mask", None)
-        if mask is not None: m.weight.data.mul_(mask)
-        if self.granularity == 'filter' and m.bias is not None:
-            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters
-    
-    def _reset_weights(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                init_weights = getattr(m, "_init_weights", m.weight)
-                init_biases = getattr(m, "_init_biases", m.bias)
-                with torch.no_grad():
-                    if m.weight is not None: m.weight.copy_(init_weights)
-                    if m.bias is not None: m.bias.copy_(init_biases)
-                self._apply(m)
-            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
-                
-    def _save_weights(self):
-        for m in self.model.modules():
-            if hasattr(m, 'weight'):              
-                m.register_buffer("_init_weights", m.weight.clone())
-                b = getattr(m, 'bias', None)
-                if b is not None: m.register_buffer("_init_biases", b.clone())
-                    
-    def save_model(self, path, model=None):
-        if not model: model=self.model
-        tmp_model = pickle.loads(pickle.dumps(model))
-        self._reset_weights(tmp_model)
-        self._clean_buffers(tmp_model)
-        torch.save(tmp_model, path)
-
-    def _clean_buffers(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                if hasattr(m, '_mask'): del m._buffers["_mask"]
-                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
-                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
-    
-    def _compute_threshold(self, scores, sparsity):
-        if self.context == 'global':
-            if self.threshold is None: 
-                global_scores = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all scores
-                global_mask = torch.cat([m._mask.view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all masks
-                global_scores = self.criteria.descale(self.criteria.rescale(global_scores).mul_(global_mask.squeeze()))
-                self.threshold = torch.quantile(global_scores, sparsity/100) # Compute the threshold globally (only once per model pruning)
-            return self.threshold
-        elif self.context == 'local':
-            return torch.quantile(scores.view(-1), sparsity/100) # Compute the threshold locally
-        else: raise NameError('Invalid Context')
-
-    def _rounded_sparsity(self, n_to_prune, round_to):
-        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)
-    
-    
-    
-    
-
-    def _compute_mask(self, m, scores, sparsity, round_to):
-        if self.context=='local': scores = self.criteria.descale(self.criteria.rescale(scores).mul_(m._mask))
-
-        self.threshold = self._compute_threshold(scores, sparsity)
-        if self.context=='global': 
-            print(self.criteria.min_value)
-            scores = self.criteria.descale(self.criteria.rescale(scores, self.criteria.min_value).mul_(m._mask))
-        
-        if round_to:
-            n_to_keep = sum(scores.ge(self.threshold)).squeeze()
-            self.threshold = torch.topk(scores.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
-        if self.threshold > scores.max(): self.threshold = scores.max() # Make sure we don't remove every weight of a given layer
-        return scores.ge(self.threshold).to(dtype=scores.dtype)
-    
-    def print_sparsity(self):
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class Sparsifier():
-    def __init__(self, model, granularity, context, criteria, layer_type=nn.Conv2d):
-        store_attr()
-        self._save_weights() # Save the original weights
-
-    def prune_layer(self, m, sparsity, round_to=None, k=0):
-        scores = self.criteria(m, self.granularity)
-        setattr(m, '_mask', self._compute_mask(m, scores, sparsity, round_to))
-        self._apply(m)
-        self.criteria.update_weights(m)
-
-    def prune_model(self, sparsity, round_to=None):
-        self.threshold=None
-        sparsity_list = listify(sparsity)
-        if len(sparsity_list)>1: assert self.context=='local', f"A list of sparsities cannot be passed using: {self.context}"
-        sparsities = cycle(sparsity_list) if len(sparsity_list)==1 else iter(sparsity_list)
-        mods = list(self.model.modules())
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type): 
-                sp = next(sparsities)
-                self.prune_layer(m, sp, round_to,k)
-                if isinstance(mods[k+1], nn.modules.batchnorm._BatchNorm): self.prune_batchnorm(m, mods[k+1])
-                
-    def prune_batchnorm(self, m, bn):
-        mask = getattr(m, "_mask", None)
-        if self.granularity == 'filter' and mask is not None:
-            bn.weight.data.mul_(mask.squeeze())
-            bn.bias.data.mul_(mask.squeeze())
-            
-    def _apply_masks(self):
-        for m in self.model.modules():
-            if isinstance(m, self.layer_type):
-                self._apply(m)
-        
-    def _apply(self, m):
-        mask = getattr(m, "_mask", None)
-        if mask is not None: m.weight.data.mul_(mask)
-        if self.granularity == 'filter' and m.bias is not None:
-            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters
-    
-    def _reset_weights(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                init_weights = getattr(m, "_init_weights", m.weight)
-                init_biases = getattr(m, "_init_biases", m.bias)
-                with torch.no_grad():
-                    if m.weight is not None: m.weight.copy_(init_weights)
-                    if m.bias is not None: m.bias.copy_(init_biases)
-                self._apply(m)
-            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
-                
-    def _save_weights(self):
-        for m in self.model.modules():
-            if hasattr(m, 'weight'):              
-                m.register_buffer("_init_weights", m.weight.clone())
-                b = getattr(m, 'bias', None)
-                if b is not None: m.register_buffer("_init_biases", b.clone())
-                    
-    def save_model(self, path, model=None):
-        if not model: model=self.model
-        tmp_model = pickle.loads(pickle.dumps(model))
-        self._reset_weights(tmp_model)
-        self._clean_buffers(tmp_model)
-        torch.save(tmp_model, path)
-
-    def _clean_buffers(self, model=None):
-        if not model: model=self.model
-        for m in model.modules():
-            if hasattr(m, 'weight'):
-                if hasattr(m, '_mask'): del m._buffers["_mask"]
-                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
-                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
-    
-    def _compute_threshold(self, m, scores, sparsity):
-        if self.context == 'global':
-            if self.threshold is None: 
-                global_scores = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all scores
-                global_mask = torch.cat([m._mask.view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)]) # Get all masks
-                global_scores = self.criteria.descale(self.criteria.rescale(global_scores).mul_(global_mask.squeeze()))
-                self.threshold = torch.quantile(global_scores, sparsity/100) # Compute the threshold globally (only once per model pruning)
-            scores = self.criteria.descale(self.criteria.rescale(scores, self.criteria.min_value).mul_(m._mask)) # min_value is computed only once per prune_model
-            return self.threshold, scores
-        elif self.context == 'local':
-            scores = self.criteria.descale(self.criteria.rescale(scores).mul_(m._mask))
-            return torch.quantile(scores.view(-1), sparsity/100), scores
-        else: raise NameError('Invalid Context')
-
-    def _rounded_sparsity(self, n_to_prune, round_to):
-        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)
-    
-    def _compute_mask(self, m, scores, sparsity, round_to):
-        self.threshold, scores = self._compute_threshold(m, scores, sparsity)
-        if round_to:
-            n_to_keep = sum(scores.ge(self.threshold)).squeeze()
-            self.threshold = torch.topk(scores.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
-        if self.threshold > scores.max(): self.threshold = scores.max() # Make sure we don't remove every weight of a given layer
-        return scores.ge(self.threshold).to(dtype=scores.dtype)
-    
-    def print_sparsity(self):
-        for k,m in enumerate(self.model.modules()):
-            if isinstance(m, self.layer_type):
-                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")
-
- -
-
-
- -
- {% endraw %} - -
-
-

up to here

-
-
-
-
-
-

Surprisingly, our network, in which $50 \%$ of the weights are zero, performs reasonably well compared to our plain, dense network.

- -
-
-
-
-
-

The SparsifyCallback also accepts a list of sparsities, one for each layer of type layer_type, applied in the order in which those layers appear in the model (a list can only be used with the local context). Below, we show how to prune only the intermediate layers of ResNet-18.

- -
-
-
- {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sparsities = [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallback(sparsity=sparsities, granularity='weight', context='local', criteria=updating_magnitude_increase, schedule=cos)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.fit_one_cycle(5, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
-Saving Weights at epoch 0
-
-
-
- -
- - -
- -
- -
- -
- - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
epochtrain_lossvalid_lossaccuracytime
00.6436220.5821410.82002700:09
10.4658080.4298110.83355900:09
20.3594320.3559550.84438400:09
30.2665830.2570440.89648200:09
40.1777830.2231900.91272000:09
- -
- -
- -
-
Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
-Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
-Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
-Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
-Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
-Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
-Sparsity in Conv2d 2: 0.00%
-Sparsity in Conv2d 8: 0.00%
-Sparsity in Conv2d 11: 0.00%
-Sparsity in Conv2d 14: 0.00%
-Sparsity in Conv2d 17: 0.00%
-Sparsity in Conv2d 21: 0.00%
-Sparsity in Conv2d 24: 50.00%
-Sparsity in Conv2d 27: 50.00%
-Sparsity in Conv2d 30: 50.00%
-Sparsity in Conv2d 33: 50.00%
-Sparsity in Conv2d 37: 50.00%
-Sparsity in Conv2d 40: 50.00%
-Sparsity in Conv2d 43: 50.00%
-Sparsity in Conv2d 46: 50.00%
-Sparsity in Conv2d 49: 0.00%
-Sparsity in Conv2d 53: 0.00%
-Sparsity in Conv2d 56: 0.00%
-Sparsity in Conv2d 59: 0.00%
-Sparsity in Conv2d 62: 0.00%
-Sparsity in Conv2d 65: 0.00%
-
-
-
- -
-
- -
- {% endraw %} - -
-
-

On top of that, the SparsifyCallback can also take many optional arguments:

-
    -
  • start_sparsity: the sparsity that the schedule will use as a starting point (defaults to 0)
  • -
  • start_epoch: the epoch at which the schedule will start pruning (defaults to 0)
  • -
  • end_epoch: the epoch at which the schedule will stop pruning (defaults to the number of training epochs passed to fit)
  • -
  • lth: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their original values at each pruning step (more information in the Lottery Ticket Hypothesis section)
  • -
  • rewind_epoch: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
  • -
  • reset_end: whether you want to reset the weights to their original values after training (pruning masks are still applied)
  • -
  • save_tickets: whether to save intermediate winning tickets.
  • -
  • model: pass a model or a part of the model if you don't want to apply pruning on the whole model trained.
  • -
  • round_to: if specified, the number of elements kept in each layer (at the chosen granularity) is rounded up to the nearest multiple of round_to.
  • -
  • layer_type: the type of layer that you want to apply pruning to (defaults to nn.Conv2d)
  • -
- -
-
-
-
-
-

For example, we correctly pruned the convolution layers of our model, but we could just as well imagine pruning the Linear layers, or even only the BatchNorm ones!

- -
-
-
-
- - diff --git a/docs/tutorial.pytorch_lightning.html b/docs/tutorial.pytorch_lightning.html deleted file mode 100644 index 0a7f077..0000000 --- a/docs/tutorial.pytorch_lightning.html +++ /dev/null @@ -1,16717 +0,0 @@ ---- - -title: Pytorch Lightning - - -keywords: fastai -sidebar: home_sidebar - -summary: "Prune models with pytorch Lightning" -description: "Prune models with pytorch Lightning" -nb_path: "nbs/04d_tutorial.pytorch_lightning.ipynb" ---- - - -
- - {% raw %} - -
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
 
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
import os
-
-
-from fastcore.all import typedispatch
-import fastai
-from fasterai.sparse.criteria import *
-from fasterai.sparse.schedule import *
-
-import pytorch_lightning
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision
-from pl_bolts.datamodules import CIFAR10DataModule
-from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
-from pytorch_lightning import LightningModule, Trainer, seed_everything
-from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorch_lightning.loggers import TensorBoardLogger
-from torch.optim.lr_scheduler import OneCycleLR
-from torch.optim.swa_utils import AveragedModel, update_bn
-from torchmetrics.functional import accuracy
-
-from fasterai.sparse.all import Sparsifier
-from fastcore.basics import store_attr
-
-seed_everything(7)
-
-PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
-AVAIL_GPUS = min(1, torch.cuda.device_count())
-BATCH_SIZE = 256 if AVAIL_GPUS else 64
-NUM_WORKERS = int(os.cpu_count() / 2)
-
- -
-
-
- -
-
- -
- -
-
Global seed set to 7
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
train_transforms = torchvision.transforms.Compose(
-    [
-        torchvision.transforms.RandomCrop(32, padding=4),
-        torchvision.transforms.RandomHorizontalFlip(),
-        torchvision.transforms.ToTensor(),
-        cifar10_normalization(),
-    ]
-)
-
-test_transforms = torchvision.transforms.Compose(
-    [
-        torchvision.transforms.ToTensor(),
-        cifar10_normalization(),
-    ]
-)
-
-cifar10_dm = CIFAR10DataModule(
-    data_dir=PATH_DATASETS,
-    batch_size=BATCH_SIZE,
-    num_workers=NUM_WORKERS,
-    train_transforms=train_transforms,
-    test_transforms=test_transforms,
-    val_transforms=test_transforms,
-)
-
- -
-
-
- -
-
- -
- -
-
/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py:74: LightningDeprecationWarning: DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
-  "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
-/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py:78: LightningDeprecationWarning: DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.
-  "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
-/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py:82: LightningDeprecationWarning: DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7.
-  "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
def create_model():
-    model = torchvision.models.resnet18(pretrained=False, num_classes=10)
-    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
-    model.maxpool = nn.Identity()
-    return model
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class LitResnet(LightningModule):
-    def __init__(self, lr=0.05):
-        super().__init__()
-
-        self.save_hyperparameters()
-        self.model = create_model()
-
-    def forward(self, x):
-        out = self.model(x)
-        return F.log_softmax(out, dim=1)
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        logits = self(x)
-        loss = F.nll_loss(logits, y)
-        self.log("train_loss", loss)
-        return loss
-
-    def evaluate(self, batch, stage=None):
-        x, y = batch
-        logits = self(x)
-        loss = F.nll_loss(logits, y)
-        preds = torch.argmax(logits, dim=1)
-        acc = accuracy(preds, y)
-
-        if stage:
-            self.log(f"{stage}_loss", loss, prog_bar=True)
-            self.log(f"{stage}_acc", acc, prog_bar=True)
-
-    def validation_step(self, batch, batch_idx):
-        self.evaluate(batch, "val")
-
-    def test_step(self, batch, batch_idx):
-        self.evaluate(batch, "test")
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.SGD(
-            self.parameters(),
-            lr=self.hparams.lr,
-            momentum=0.9,
-            weight_decay=5e-4,
-        )
-        steps_per_epoch = 45000 // BATCH_SIZE
-        scheduler_dict = {
-            "scheduler": OneCycleLR(
-                optimizer,
-                0.1,
-                epochs=self.trainer.max_epochs,
-                steps_per_epoch=steps_per_epoch,
-            ),
-            "interval": "step",
-        }
-        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
import fastai
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCallback(fastai.callback.all.Callback):
-
-    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        self.end_sparsity = end_sparsity
-        self.granularity, self.method, self.criteria, self.sched_func = granularity, method, criteria, sched_func
-        self.start_sparsity, self.start_epoch, self.end_epoch = start_sparsity, start_epoch, end_epoch
-        self.lth, self.rewind_epoch, self.reset_end = lth, rewind_epoch, reset_end
-        self.model = model
-        self.round_to = round_to
-        self.layer_type = layer_type
-        self.train_iter = 0
-        self.current_sparsity, self.previous_sparsity = 0, 0
-        assert self.start_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
-        print("Starting to init trainer!")
-        
-    def before_fit(self):
-        print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
-        self.end_epoch = self.n_epoch if self.end_epoch is None else self.end_epoch
-        assert self.end_epoch <= self.n_epoch, 'Your end_epoch must be smaller than total number of epoch'
-
-        model = self.learn.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
-        self.sparsifier = Sparsifier(model, self.granularity, self.method, self.criteria, self.layer_type)
-        self.n_batches = math.floor(len(self.learn.dls.dataset)/self.learn.dls.bs)
-        self.total_iters = self.end_epoch * self.n_batches
-        self.start_iter = self.start_epoch * self.n_batches
-
-    def before_epoch(self):
-        if self.epoch == self.rewind_epoch:
-            print(f'Saving Weights at epoch {self.epoch}')
-            self.sparsifier._save_weights()
-
-    def before_batch(self):
-        if self.epoch>=self.start_epoch:
-            if self.epoch < self.end_epoch: self._set_sparsity()
-            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-            if self.lth and self.current_sparsity!=self.previous_sparsity: # If sparsity has changed, the network has been pruned
-                    print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-                    self.sparsifier._reset_weights()
-
-            self.previous_sparsity = self.current_sparsity
-
-    def before_step(self):
-        if self.epoch>=self.start_epoch:
-            self.sparsifier._mask_grad()
-
-    def after_epoch(self):
-        print(f'Sparsity at the end of epoch {self.epoch}: {self.current_sparsity:.2f}%')
-
-    def after_fit(self):
-        print(f'Final Sparsity: {self.current_sparsity:.2f}')
-        if self.reset_end:
-            self.sparsifier._reset_weights()
-        self.sparsifier._clean_buffers() # Remove buffers at the end of training
-        self.sparsifier.print_sparsity()
-
-    def _set_sparsity(self):
-        self.current_sparsity = self.sched_func(start=self.start_sparsity, end=self.end_sparsity, pos=(self.train_iter-self.start_iter)/(self.total_iters-self.start_iter))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallback(end_sparsity=50, granularity='filter', method='local', criteria=large_final, sched_func=sched_onecycle)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallbackFlash(end_sparsity=50, granularity='filter', method='local', criteria=large_final, sched_func=sched_onecycle)
-
- -
-
-
- -
-
- -
- -
-
Starting to init trainer!
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
from fasterai.sparse.all import *
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
store_attr??
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
import pytorch_lightning
-import torch.nn as nn
-import math
-
-class SparsifyCallbackFlash(pytorch_lightning.callbacks.Callback):
-    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        store_attr('end_sparsity, granularity, method, criteria, sched_func, start_sparsity, start_epoch, end_epoch, lth, rewind_epoch, reset_end, model, round_to, layer_type')
-        self.train_iter = 0
-        self.current_sparsity, self.previous_sparsity = 0, 0
-        assert self.start_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
-        print("Starting to init trainer!")
-
-
-    def on_fit_start(self, trainer, pl_module):
-        print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
-        self.end_epoch = trainer.max_epochs if self.end_epoch is None else self.end_epoch
-        assert self.end_epoch <= trainer.max_epochs, 'Your end_epoch must be smaller than total number of epoch'
-
-        model = trainer.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
-        self.sparsifier = Sparsifier(model, self.granularity, self.method, self.criteria, self.layer_type)
-        self.n_batches = math.floor(len(trainer.datamodule.dataset_train)/trainer.datamodule.batch_size)
-        self.total_iters = self.end_epoch * self.n_batches
-        self.start_iter = self.start_epoch * self.n_batches
-        
-    def on_fit_end(self, trainer, pl_module):
-        print(f'Final Sparsity: {self.current_sparsity:.2f}')
-        if self.reset_end:
-            self.sparsifier._reset_weights()
-        self.sparsifier._clean_buffers() # Remove buffers at the end of training
-        self.sparsifier.print_sparsity()
-        
-    def on_train_epoch_start(self, trainer, pl_module):
-        if trainer.current_epoch == self.rewind_epoch:
-            print(f'Saving Weights at epoch {trainer.current_epoch}')
-            self.sparsifier._save_weights()
-        
-    def on_train_epoch_end(self, trainer, pl_module):
-        print(f'Sparsity at the end of epoch {trainer.current_epoch}: {self.current_sparsity:.2f}%')
-
-    def on_batch_start(self, trainer, pl_module):
-        self.train_iter+=1
-        if trainer.current_epoch>=self.start_epoch:
-            if trainer.current_epoch < self.end_epoch: self._set_sparsity()
-            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-            if self.lth and self.current_sparsity!=self.previous_sparsity: # If sparsity has changed, the network has been pruned
-                    print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-                    self.sparsifier._reset_weights()
-
-            self.previous_sparsity = self.current_sparsity
-            
-    #def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-    #    if trainer.current_epoch>=self.start_epoch:
-    #        self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-    #        if self.lth and self.current_sparsity!=self.previous_sparsity: # If sparsity has changed, the network has been pruned
-    #                print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-    #                self.sparsifier._reset_weights()
-        
-    def on_after_backward(self, trainer, pl_module): #, optimizer, opt_idx
-        if trainer.current_epoch>=self.start_epoch:
-            #print('After BW', model.model.conv1.weight.grad.sum())
-            self.sparsifier._mask_grad()
-            #print('After BW, After Prune', model.model.conv1.weight.grad.sum())
-        
-    def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
-        print('After Batch', model.model.conv1.weight.sum(dim=(1,2,3)))
-    
-    def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx):
-        print('Before Step', model.model.conv1.weight.sum(dim=(1,2,3)))
-        
-    def on_before_zero_grad(self, trainer, pl_module, optimizer):
-        print('After Step', model.model.conv1.weight.sum(dim=(1,2,3)))
-        
-    
-    def _set_sparsity(self):
-        self.current_sparsity = self.sched_func(start=self.start_sparsity, end=self.end_sparsity, pos=(self.train_iter-self.start_iter)/(self.total_iters-self.start_iter))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallbackFlash(end_sparsity=50, granularity='filter', method='local', criteria=large_final, sched_func=sched_onecycle)
-
- -
-
-
- -
-
- -
- -
-
Starting to init trainer!
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
model = LitResnet(lr=0.05)
-model.datamodule = cifar10_dm
-
-trainer = Trainer(
-    progress_bar_refresh_rate=10,
-    max_epochs=3,
-    gpus=AVAIL_GPUS,
-    logger=TensorBoardLogger("lightning_logs/", name="resnet"),
-    callbacks=[LearningRateMonitor(logging_interval="step"), sp_cb],
-)
-
-trainer.fit(model, cifar10_dm)
-#trainer.test(model, datamodule=cifar10_dm)
-
- -
-
-
- -
-
- -
- -
-
GPU available: True, used: True
-TPU available: False, using: 0 TPU cores
-IPU available: False, using: 0 IPUs
-LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
-
-  | Name  | Type   | Params
----------------------------------
-0 | model | ResNet | 11.2 M
----------------------------------
-11.2 M    Trainable params
-0         Non-trainable params
-11.2 M    Total params
-44.696    Total estimated model params size (MB)
-
-
-
- -
- -
-
Pruning of filter until a sparsity of 50%
-
-
-
- -
- -
-
Global seed set to 7
-
-
-
- -
- -
-
Saving Weights at epoch 0
-After Step tensor([ 0.9081,  0.6772, -0.3952,  0.2494,  0.1716,  0.0284, -0.5729, -1.1187,
-        -0.6003, -0.5951, -0.2932, -0.2708,  0.0602,  0.3945, -0.0567, -0.4998,
-        -0.7118,  0.6990, -0.8625, -0.6608, -0.3355, -0.0327, -0.6007, -0.2280,
-        -0.0635,  0.3191,  0.1373,  0.1476,  0.1205,  1.4984, -0.9119, -0.0762,
-        -0.7782, -0.1663, -0.9118, -0.2271,  0.5274,  0.0000,  0.1409,  0.5117,
-         0.3650, -0.2874, -0.0246,  0.7153,  0.5553,  0.5945,  0.1937,  0.4063,
-         0.5245, -0.5398,  0.5487, -0.1477,  0.4366,  0.2581,  0.7108,  0.5323,
-        -0.1485,  1.1084, -0.1958, -0.6098,  1.0361, -0.1469, -0.2773, -0.2078],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9081,  0.6772, -0.3952,  0.2494,  0.1716,  0.0284, -0.5729, -1.1187,
-        -0.6003, -0.5951, -0.2932, -0.2708,  0.0602,  0.3945, -0.0567, -0.4998,
-        -0.7118,  0.6990, -0.8625, -0.6608, -0.3355, -0.0327, -0.6007, -0.2280,
-        -0.0635,  0.3191,  0.1373,  0.1476,  0.1205,  1.4984, -0.9119, -0.0762,
-        -0.7782, -0.1663, -0.9118, -0.2271,  0.5274,  0.0000,  0.1409,  0.5117,
-         0.3650, -0.2874, -0.0246,  0.7153,  0.5553,  0.5945,  0.1937,  0.4063,
-         0.5245, -0.5398,  0.5487, -0.1477,  0.4366,  0.2581,  0.7108,  0.5323,
-        -0.1485,  1.1084, -0.1958, -0.6098,  1.0361, -0.1469, -0.2773, -0.2078],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9077,  0.6766, -0.3968,  0.2510,  0.1723,  0.0253, -0.5725, -1.1187,
-        -0.5994, -0.5963, -0.2932, -0.2697,  0.0583,  0.3911, -0.0608, -0.5002,
-        -0.7128,  0.6987, -0.8623, -0.6603, -0.3353, -0.0314, -0.6016, -0.2284,
-        -0.0652,  0.3181,  0.1375,  0.1443,  0.1189,  1.4984, -0.9123, -0.0740,
-        -0.7780, -0.1686, -0.9115, -0.2282,  0.5280,  0.0000,  0.1393,  0.5112,
-         0.3673, -0.2886, -0.0255,  0.7156,  0.5556,  0.5940,  0.1920,  0.4071,
-         0.5256, -0.5393,  0.5502, -0.1461,  0.4368,  0.2590,  0.7105,  0.5324,
-        -0.1518,  1.1084, -0.1961, -0.6084,  1.0361, -0.1472, -0.2784, -0.2083],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9077,  0.6766, -0.3968,  0.2510,  0.1723,  0.0253, -0.5725, -1.1187,
-        -0.5994, -0.5963, -0.2932, -0.2697,  0.0583,  0.3911, -0.0608, -0.5002,
-        -0.7128,  0.6987, -0.8623, -0.6603, -0.3353, -0.0314, -0.6016, -0.2284,
-        -0.0652,  0.3181,  0.1375,  0.1443,  0.1189,  1.4984, -0.9123, -0.0740,
-        -0.7780, -0.1686, -0.9115, -0.2282,  0.5280,  0.0000,  0.1393,  0.5112,
-         0.3673, -0.2886, -0.0255,  0.7156,  0.5556,  0.5940,  0.1920,  0.4071,
-         0.5256, -0.5393,  0.5502, -0.1461,  0.4368,  0.2590,  0.7105,  0.5324,
-        -0.1518,  1.1084, -0.1961, -0.6084,  1.0361, -0.1472, -0.2784, -0.2083],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9077,  0.6766, -0.3968,  0.2510,  0.1723,  0.0253, -0.5725, -1.1187,
-        -0.5994, -0.5963, -0.2932, -0.2697,  0.0583,  0.3911, -0.0608, -0.5002,
-        -0.7128,  0.6987, -0.8623, -0.6603, -0.3353, -0.0314, -0.6016, -0.2284,
-        -0.0652,  0.3181,  0.1375,  0.1443,  0.1189,  1.4984, -0.9123, -0.0740,
-        -0.7780, -0.1686, -0.9115, -0.2282,  0.5280,  0.0000,  0.1393,  0.5112,
-         0.3673, -0.2886, -0.0255,  0.7156,  0.5556,  0.5940,  0.1920,  0.4071,
-         0.5256, -0.5393,  0.5502, -0.1461,  0.4368,  0.2590,  0.7105,  0.5324,
-        -0.1518,  1.1084, -0.1961, -0.6084,  1.0361, -0.1472, -0.2784, -0.2083],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9067,  0.6768, -0.3994,  0.2519,  0.1738,  0.0215, -0.5718, -1.1184,
-        -0.5988, -0.5976, -0.2888, -0.2670,  0.0592,  0.3892, -0.0651, -0.5004,
-        -0.7149,  0.6978, -0.8622, -0.6604, -0.3347, -0.0293, -0.6020, -0.2296,
-        -0.0675,  0.3181,  0.1361,  0.1391,  0.1179,  1.4982, -0.9125, -0.0722,
-        -0.7785, -0.1713, -0.9110, -0.2300,  0.5292,  0.0000,  0.1344,  0.5093,
-         0.3678, -0.2896, -0.0236,  0.7161,  0.5565,  0.5941,  0.1913,  0.4073,
-         0.5262, -0.5394,  0.5518, -0.1395,  0.4371,  0.2598,  0.7110,  0.5323,
-        -0.1547,  1.1088, -0.1971, -0.6074,  1.0365, -0.1461, -0.2777, -0.2076],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9067,  0.6768, -0.3994,  0.2519,  0.1738,  0.0215, -0.5718, -1.1184,
-        -0.5988, -0.5976, -0.2888, -0.2670,  0.0592,  0.3892, -0.0651, -0.5004,
-        -0.7149,  0.6978, -0.8622, -0.6604, -0.3347, -0.0293, -0.6020, -0.2296,
-        -0.0675,  0.3181,  0.1361,  0.1391,  0.1179,  1.4982, -0.9125, -0.0722,
-        -0.7785, -0.1713, -0.9110, -0.2300,  0.5292,  0.0000,  0.1344,  0.5093,
-         0.3678, -0.2896, -0.0236,  0.7161,  0.5565,  0.5941,  0.1913,  0.4073,
-         0.5262, -0.5394,  0.5518, -0.1395,  0.4371,  0.2598,  0.7110,  0.5323,
-        -0.1547,  1.1088, -0.1971, -0.6074,  1.0365, -0.1461, -0.2777, -0.2076],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9067,  0.6768, -0.3994,  0.2519,  0.1738,  0.0215, -0.5718, -1.1184,
-        -0.5988, -0.5976, -0.2888, -0.2670,  0.0592,  0.3892, -0.0651, -0.5004,
-        -0.7149,  0.6978, -0.8622, -0.6604, -0.3347, -0.0293, -0.6020, -0.2296,
-        -0.0675,  0.3181,  0.1361,  0.1391,  0.1179,  1.4982, -0.9125, -0.0722,
-        -0.7785, -0.1713, -0.9110, -0.2300,  0.5292,  0.0000,  0.1344,  0.5093,
-         0.3678, -0.2896, -0.0236,  0.7161,  0.5565,  0.5941,  0.1913,  0.4073,
-         0.5262, -0.5394,  0.5518, -0.1395,  0.4371,  0.2598,  0.7110,  0.5323,
-        -0.1547,  1.1088, -0.1971, -0.6074,  1.0365, -0.1461, -0.2777, -0.2076],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9054,  0.6777, -0.4027,  0.2506,  0.1768,  0.0169, -0.5707, -1.1182,
-        -0.5983, -0.5986, -0.2849, -0.2662,  0.0603,  0.3884, -0.0682, -0.5005,
-        -0.7170,  0.6972, -0.8620, -0.6612, -0.3324, -0.0273, -0.6027, -0.2344,
-        -0.0713,  0.3180,  0.1321,  0.1368,  0.1171,  1.4979, -0.9123, -0.0706,
-        -0.7794, -0.1755, -0.9104, -0.2303,  0.5306,  0.0000,  0.1321,  0.5081,
-         0.3678, -0.2889, -0.0219,  0.7161,  0.5574,  0.5945,  0.1923,  0.4087,
-         0.5273, -0.5389,  0.5536, -0.1342,  0.4376,  0.2593,  0.7122,  0.5322,
-        -0.1581,  1.1091, -0.1973, -0.6073,  1.0371, -0.1443, -0.2750, -0.2057],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9054,  0.6777, -0.4027,  0.2506,  0.1768,  0.0169, -0.5707, -1.1182,
-        -0.5983, -0.5986, -0.2849, -0.2662,  0.0603,  0.3884, -0.0682, -0.5005,
-        -0.7170,  0.6972, -0.8620, -0.6612, -0.3324, -0.0273, -0.6027, -0.2344,
-        -0.0713,  0.3180,  0.1321,  0.1368,  0.1171,  1.4979, -0.9123, -0.0706,
-        -0.7794, -0.1755, -0.9104, -0.2303,  0.5306,  0.0000,  0.1321,  0.5081,
-         0.3678, -0.2889, -0.0219,  0.7161,  0.5574,  0.5945,  0.1923,  0.4087,
-         0.5273, -0.5389,  0.5536, -0.1342,  0.4376,  0.2593,  0.7122,  0.5322,
-        -0.1581,  1.1091, -0.1973, -0.6073,  1.0371, -0.1443, -0.2750, -0.2057],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9054,  0.6777, -0.4027,  0.2506,  0.1768,  0.0169, -0.5707, -1.1182,
-        -0.5983, -0.5986, -0.2849, -0.2662,  0.0603,  0.3884, -0.0682, -0.5005,
-        -0.7170,  0.6972, -0.8620, -0.6612, -0.3324, -0.0273, -0.6027, -0.2344,
-        -0.0713,  0.3180,  0.1321,  0.1368,  0.1171,  1.4979, -0.9123, -0.0706,
-        -0.7794, -0.1755, -0.9104, -0.2303,  0.5306,  0.0000,  0.1321,  0.5081,
-         0.3678, -0.2889, -0.0219,  0.7161,  0.5574,  0.5945,  0.1923,  0.4087,
-         0.5273, -0.5389,  0.5536, -0.1342,  0.4376,  0.2593,  0.7122,  0.5322,
-        -0.1581,  1.1091, -0.1973, -0.6073,  1.0371, -0.1443, -0.2750, -0.2057],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9047,  0.6789, -0.4069,  0.2467,  0.1825,  0.0134, -0.5705, -1.1180,
-        -0.5983, -0.5991, -0.2800, -0.2661,  0.0600,  0.3891, -0.0648, -0.5013,
-        -0.7192,  0.6964, -0.8612, -0.6632, -0.3316, -0.0273, -0.6032, -0.2386,
-        -0.0747,  0.3168,  0.1275,  0.1369,  0.1147,  1.4978, -0.9122, -0.0674,
-        -0.7799, -0.1790, -0.9100, -0.2305,  0.5335,  0.0000,  0.1285,  0.5066,
-         0.3676, -0.2888, -0.0226,  0.7160,  0.5588,  0.5953,  0.1949,  0.4107,
-         0.5263, -0.5385,  0.5556, -0.1313,  0.4397,  0.2578,  0.7134,  0.5325,
-        -0.1588,  1.1093, -0.1960, -0.6065,  1.0374, -0.1432, -0.2723, -0.2028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9047,  0.6789, -0.4069,  0.2467,  0.1825,  0.0134, -0.5705, -1.1180,
-        -0.5983, -0.5991, -0.2800, -0.2661,  0.0600,  0.3891, -0.0648, -0.5013,
-        -0.7192,  0.6964, -0.8612, -0.6632, -0.3316, -0.0273, -0.6032, -0.2386,
-        -0.0747,  0.3168,  0.1275,  0.1369,  0.1147,  1.4978, -0.9122, -0.0674,
-        -0.7799, -0.1790, -0.9100, -0.2305,  0.5335,  0.0000,  0.1285,  0.5066,
-         0.3676, -0.2888, -0.0226,  0.7160,  0.5588,  0.5953,  0.1949,  0.4107,
-         0.5263, -0.5385,  0.5556, -0.1313,  0.4397,  0.2578,  0.7134,  0.5325,
-        -0.1588,  1.1093, -0.1960, -0.6065,  1.0374, -0.1432, -0.2723, -0.2028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9047,  0.6789, -0.4069,  0.2467,  0.1825,  0.0134, -0.5705, -1.1180,
-        -0.5983, -0.5991, -0.2800, -0.2661,  0.0600,  0.3891, -0.0648, -0.5013,
-        -0.7192,  0.6964, -0.8612, -0.6632, -0.3316, -0.0273, -0.6032, -0.2386,
-        -0.0747,  0.3168,  0.1275,  0.1369,  0.1147,  1.4978, -0.9122, -0.0674,
-        -0.7799, -0.1790, -0.9100, -0.2305,  0.5335,  0.0000,  0.1285,  0.5066,
-         0.3676, -0.2888, -0.0226,  0.7160,  0.5588,  0.5953,  0.1949,  0.4107,
-         0.5263, -0.5385,  0.5556, -0.1313,  0.4397,  0.2578,  0.7134,  0.5325,
-        -0.1588,  1.1093, -0.1960, -0.6065,  1.0374, -0.1432, -0.2723, -0.2028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9039,  0.6800, -0.4124,  0.2452,  0.1900,  0.0067, -0.5710, -1.1178,
-        -0.5980, -0.5994, -0.2733, -0.2691,  0.0571,  0.3893, -0.0627, -0.5025,
-        -0.7218,  0.6962, -0.8608, -0.6641, -0.3306, -0.0264, -0.6032, -0.2454,
-        -0.0787,  0.3152,  0.1191,  0.1348,  0.1121,  1.4975, -0.9119, -0.0649,
-        -0.7805, -0.1829, -0.9098, -0.2318,  0.5354,  0.0000,  0.1287,  0.5045,
-         0.3671, -0.2912, -0.0235,  0.7155,  0.5597,  0.5969,  0.1956,  0.4132,
-         0.5252, -0.5378,  0.5577, -0.1288,  0.4389,  0.2536,  0.7148,  0.5332,
-        -0.1587,  1.1094, -0.1920, -0.6054,  1.0375, -0.1405, -0.2689, -0.1993],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9039,  0.6800, -0.4124,  0.2452,  0.1900,  0.0067, -0.5710, -1.1178,
-        -0.5980, -0.5994, -0.2733, -0.2691,  0.0571,  0.3893, -0.0627, -0.5025,
-        -0.7218,  0.6962, -0.8608, -0.6641, -0.3306, -0.0264, -0.6032, -0.2454,
-        -0.0787,  0.3152,  0.1191,  0.1348,  0.1121,  1.4975, -0.9119, -0.0649,
-        -0.7805, -0.1829, -0.9098, -0.2318,  0.5354,  0.0000,  0.1287,  0.5045,
-         0.3671, -0.2912, -0.0235,  0.7155,  0.5597,  0.5969,  0.1956,  0.4132,
-         0.5252, -0.5378,  0.5577, -0.1288,  0.4389,  0.2536,  0.7148,  0.5332,
-        -0.1587,  1.1094, -0.1920, -0.6054,  1.0375, -0.1405, -0.2689, -0.1993],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9039,  0.6800, -0.4124,  0.2452,  0.1900,  0.0067, -0.5710, -1.1178,
-        -0.5980, -0.5994, -0.2733, -0.2691,  0.0571,  0.3893, -0.0627, -0.5025,
-        -0.7218,  0.6962, -0.8608, -0.6641, -0.3306, -0.0264, -0.6032, -0.2454,
-        -0.0787,  0.3152,  0.1191,  0.1348,  0.1121,  1.4975, -0.9119, -0.0649,
-        -0.7805, -0.1829, -0.9098, -0.2318,  0.5354,  0.0000,  0.1287,  0.5045,
-         0.3671, -0.2912, -0.0235,  0.7155,  0.5597,  0.5969,  0.1956,  0.4132,
-         0.5252, -0.5378,  0.5577, -0.1288,  0.4389,  0.2536,  0.7148,  0.5332,
-        -0.1587,  1.1094, -0.1920, -0.6054,  1.0375, -0.1405, -0.2689, -0.1993],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9035,  0.6803, -0.4170,  0.2432,  0.1989,  0.0034, -0.5718, -1.1175,
-        -0.5966, -0.5992, -0.2676, -0.2730,  0.0548,  0.3880, -0.0651, -0.5043,
-        -0.7242,  0.6950, -0.8603, -0.6645, -0.3307, -0.0262, -0.6033, -0.2486,
-        -0.0795,  0.3143,  0.1072,  0.1324,  0.1107,  1.4973, -0.9114, -0.0671,
-        -0.7813, -0.1838, -0.9096, -0.2323,  0.5361,  0.0000,  0.1296,  0.5025,
-         0.3656, -0.2931, -0.0241,  0.7144,  0.5599,  0.5982,  0.1948,  0.4162,
-         0.5250, -0.5372,  0.5586, -0.1258,  0.4399,  0.2505,  0.7168,  0.5335,
-        -0.1593,  1.1096, -0.1864, -0.6047,  1.0376, -0.1376, -0.2682, -0.1942],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9035,  0.6803, -0.4170,  0.2432,  0.1989,  0.0034, -0.5718, -1.1175,
-        -0.5966, -0.5992, -0.2676, -0.2730,  0.0548,  0.3880, -0.0651, -0.5043,
-        -0.7242,  0.6950, -0.8603, -0.6645, -0.3307, -0.0262, -0.6033, -0.2486,
-        -0.0795,  0.3143,  0.1072,  0.1324,  0.1107,  1.4973, -0.9114, -0.0671,
-        -0.7813, -0.1838, -0.9096, -0.2323,  0.5361,  0.0000,  0.1296,  0.5025,
-         0.3656, -0.2931, -0.0241,  0.7144,  0.5599,  0.5982,  0.1948,  0.4162,
-         0.5250, -0.5372,  0.5586, -0.1258,  0.4399,  0.2505,  0.7168,  0.5335,
-        -0.1593,  1.1096, -0.1864, -0.6047,  1.0376, -0.1376, -0.2682, -0.1942],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9035,  0.6803, -0.4170,  0.2432,  0.1989,  0.0034, -0.5718, -1.1175,
-        -0.5966, -0.5992, -0.2676, -0.2730,  0.0548,  0.3880, -0.0651, -0.5043,
-        -0.7242,  0.6950, -0.8603, -0.6645, -0.3307, -0.0262, -0.6033, -0.2486,
-        -0.0795,  0.3143,  0.1072,  0.1324,  0.1107,  1.4973, -0.9114, -0.0671,
-        -0.7813, -0.1838, -0.9096, -0.2323,  0.5361,  0.0000,  0.1296,  0.5025,
-         0.3656, -0.2931, -0.0241,  0.7144,  0.5599,  0.5982,  0.1948,  0.4162,
-         0.5250, -0.5372,  0.5586, -0.1258,  0.4399,  0.2505,  0.7168,  0.5335,
-        -0.1593,  1.1096, -0.1864, -0.6047,  1.0376, -0.1376, -0.2682, -0.1942],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0326e-01,  6.8058e-01, -4.2272e-01,  2.3786e-01,  2.0892e-01,
-        -6.7803e-04, -5.7346e-01, -1.1174e+00, -5.9490e-01, -5.9974e-01,
-        -2.6319e-01, -2.7499e-01,  5.0632e-02,  3.8441e-01, -6.4690e-02,
-        -5.0749e-01, -7.2754e-01,  6.9292e-01, -8.5993e-01, -6.6510e-01,
-        -3.3036e-01, -2.8271e-02, -6.0367e-01, -2.5165e-01, -8.1150e-02,
-         3.1468e-01,  9.5137e-02,  1.2634e-01,  1.1046e-01,  1.4974e+00,
-        -9.1108e-01, -6.9237e-02, -7.8212e-01, -1.8506e-01, -9.0969e-01,
-        -2.3252e-01,  5.3505e-01,  0.0000e+00,  1.3171e-01,  5.0027e-01,
-         3.6314e-01, -2.9196e-01, -2.6607e-02,  7.1387e-01,  5.6020e-01,
-         6.0096e-01,  2.0155e-01,  4.1924e-01,  5.2513e-01, -5.3732e-01,
-         5.5881e-01, -1.2220e-01,  4.4163e-01,  2.4590e-01,  7.1860e-01,
-         5.3458e-01, -1.5731e-01,  1.1098e+00, -1.8395e-01, -6.0554e-01,
-         1.0378e+00, -1.3325e-01, -2.6667e-01, -1.8969e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0326e-01,  6.8058e-01, -4.2272e-01,  2.3786e-01,  2.0892e-01,
-        -6.7803e-04, -5.7346e-01, -1.1174e+00, -5.9490e-01, -5.9974e-01,
-        -2.6319e-01, -2.7499e-01,  5.0632e-02,  3.8441e-01, -6.4690e-02,
-        -5.0749e-01, -7.2754e-01,  6.9292e-01, -8.5993e-01, -6.6510e-01,
-        -3.3036e-01, -2.8271e-02, -6.0367e-01, -2.5165e-01, -8.1150e-02,
-         3.1468e-01,  9.5137e-02,  1.2634e-01,  1.1046e-01,  1.4974e+00,
-        -9.1108e-01, -6.9237e-02, -7.8212e-01, -1.8506e-01, -9.0969e-01,
-        -2.3252e-01,  5.3505e-01,  0.0000e+00,  1.3171e-01,  5.0027e-01,
-         3.6314e-01, -2.9196e-01, -2.6607e-02,  7.1387e-01,  5.6020e-01,
-         6.0096e-01,  2.0155e-01,  4.1924e-01,  5.2513e-01, -5.3732e-01,
-         5.5881e-01, -1.2220e-01,  4.4163e-01,  2.4590e-01,  7.1860e-01,
-         5.3458e-01, -1.5731e-01,  1.1098e+00, -1.8395e-01, -6.0554e-01,
-         1.0378e+00, -1.3325e-01, -2.6667e-01, -1.8969e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0326e-01,  6.8058e-01, -4.2272e-01,  2.3786e-01,  2.0892e-01,
-        -6.7803e-04, -5.7346e-01, -1.1174e+00, -5.9490e-01, -5.9974e-01,
-        -2.6319e-01, -2.7499e-01,  5.0632e-02,  3.8441e-01, -6.4690e-02,
-        -5.0749e-01, -7.2754e-01,  6.9292e-01, -8.5993e-01, -6.6510e-01,
-        -3.3036e-01, -2.8271e-02, -6.0367e-01, -2.5165e-01, -8.1150e-02,
-         3.1468e-01,  9.5137e-02,  1.2634e-01,  1.1046e-01,  1.4974e+00,
-        -9.1108e-01, -6.9237e-02, -7.8212e-01, -1.8506e-01, -9.0969e-01,
-        -2.3252e-01,  5.3505e-01,  0.0000e+00,  1.3171e-01,  5.0027e-01,
-         3.6314e-01, -2.9196e-01, -2.6607e-02,  7.1387e-01,  5.6020e-01,
-         6.0096e-01,  2.0155e-01,  4.1924e-01,  5.2513e-01, -5.3732e-01,
-         5.5881e-01, -1.2220e-01,  4.4163e-01,  2.4590e-01,  7.1860e-01,
-         5.3458e-01, -1.5731e-01,  1.1098e+00, -1.8395e-01, -6.0554e-01,
-         1.0378e+00, -1.3325e-01, -2.6667e-01, -1.8969e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0331e-01,  6.8092e-01, -4.2676e-01,  2.3341e-01,  2.2045e-01,
-        -1.1952e-03, -5.7586e-01, -1.1174e+00, -5.9263e-01, -5.9908e-01,
-        -2.5908e-01, -2.7426e-01,  4.6200e-02,  3.7879e-01, -6.2461e-02,
-        -5.1004e-01, -7.3088e-01,  6.9179e-01, -8.5985e-01, -6.6591e-01,
-        -3.2887e-01, -3.4696e-02, -6.0402e-01, -2.5573e-01, -8.5363e-02,
-         3.1356e-01,  8.0069e-02,  1.1617e-01,  1.1121e-01,  1.4977e+00,
-        -9.1085e-01, -7.4302e-02, -7.8335e-01, -1.8587e-01, -9.0964e-01,
-        -2.3405e-01,  5.3210e-01,  0.0000e+00,  1.3593e-01,  4.9847e-01,
-         3.6028e-01, -2.9464e-01, -2.9597e-02,  7.1263e-01,  5.5947e-01,
-         6.0425e-01,  2.0614e-01,  4.2172e-01,  5.2841e-01, -5.3748e-01,
-         5.5806e-01, -1.1630e-01,  4.4254e-01,  2.4232e-01,  7.2057e-01,
-         5.3573e-01, -1.5028e-01,  1.1101e+00, -1.8262e-01, -6.0682e-01,
-         1.0382e+00, -1.2586e-01, -2.6385e-01, -1.8694e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0331e-01,  6.8092e-01, -4.2676e-01,  2.3341e-01,  2.2045e-01,
-        -1.1952e-03, -5.7586e-01, -1.1174e+00, -5.9263e-01, -5.9908e-01,
-        -2.5908e-01, -2.7426e-01,  4.6200e-02,  3.7879e-01, -6.2461e-02,
-        -5.1004e-01, -7.3088e-01,  6.9179e-01, -8.5985e-01, -6.6591e-01,
-        -3.2887e-01, -3.4696e-02, -6.0402e-01, -2.5573e-01, -8.5363e-02,
-         3.1356e-01,  8.0069e-02,  1.1617e-01,  1.1121e-01,  1.4977e+00,
-        -9.1085e-01, -7.4302e-02, -7.8335e-01, -1.8587e-01, -9.0964e-01,
-        -2.3405e-01,  5.3210e-01,  0.0000e+00,  1.3593e-01,  4.9847e-01,
-         3.6028e-01, -2.9464e-01, -2.9597e-02,  7.1263e-01,  5.5947e-01,
-         6.0425e-01,  2.0614e-01,  4.2172e-01,  5.2841e-01, -5.3748e-01,
-         5.5806e-01, -1.1630e-01,  4.4254e-01,  2.4232e-01,  7.2057e-01,
-         5.3573e-01, -1.5028e-01,  1.1101e+00, -1.8262e-01, -6.0682e-01,
-         1.0382e+00, -1.2586e-01, -2.6385e-01, -1.8694e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0331e-01,  6.8092e-01, -4.2676e-01,  2.3341e-01,  2.2045e-01,
-        -1.1952e-03, -5.7586e-01, -1.1174e+00, -5.9263e-01, -5.9908e-01,
-        -2.5908e-01, -2.7426e-01,  4.6200e-02,  3.7879e-01, -6.2461e-02,
-        -5.1004e-01, -7.3088e-01,  6.9179e-01, -8.5985e-01, -6.6591e-01,
-        -3.2887e-01, -3.4696e-02, -6.0402e-01, -2.5573e-01, -8.5363e-02,
-         3.1356e-01,  8.0069e-02,  1.1617e-01,  1.1121e-01,  1.4977e+00,
-        -9.1085e-01, -7.4302e-02, -7.8335e-01, -1.8587e-01, -9.0964e-01,
-        -2.3405e-01,  5.3210e-01,  0.0000e+00,  1.3593e-01,  4.9847e-01,
-         3.6028e-01, -2.9464e-01, -2.9597e-02,  7.1263e-01,  5.5947e-01,
-         6.0425e-01,  2.0614e-01,  4.2172e-01,  5.2841e-01, -5.3748e-01,
-         5.5806e-01, -1.1630e-01,  4.4254e-01,  2.4232e-01,  7.2057e-01,
-         5.3573e-01, -1.5028e-01,  1.1101e+00, -1.8262e-01, -6.0682e-01,
-         1.0382e+00, -1.2586e-01, -2.6385e-01, -1.8694e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0362e-01,  6.8231e-01, -4.3062e-01,  2.2729e-01,  2.3439e-01,
-        -4.5037e-04, -5.7814e-01, -1.1173e+00, -5.8928e-01, -5.9883e-01,
-        -2.5595e-01, -2.7521e-01,  4.2081e-02,  3.7035e-01, -6.8336e-02,
-        -5.1345e-01, -7.3419e-01,  6.9036e-01, -8.5969e-01, -6.6726e-01,
-        -3.2734e-01, -3.9906e-02, -6.0411e-01, -2.5952e-01, -8.9613e-02,
-         3.1316e-01,  6.8056e-02,  1.0623e-01,  1.1009e-01,  1.4980e+00,
-        -9.1055e-01, -7.3296e-02, -7.8475e-01, -1.8351e-01, -9.0979e-01,
-        -2.3460e-01,  5.2860e-01,  0.0000e+00,  1.4435e-01,  4.9718e-01,
-         3.5915e-01, -2.9532e-01, -2.8910e-02,  7.1120e-01,  5.5895e-01,
-         6.0758e-01,  2.1014e-01,  4.2308e-01,  5.3264e-01, -5.3784e-01,
-         5.5652e-01, -1.1029e-01,  4.4389e-01,  2.3881e-01,  7.2261e-01,
-         5.3638e-01, -1.4419e-01,  1.1106e+00, -1.8342e-01, -6.0779e-01,
-         1.0385e+00, -1.1912e-01, -2.6227e-01, -1.8456e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0362e-01,  6.8231e-01, -4.3062e-01,  2.2729e-01,  2.3439e-01,
-        -4.5037e-04, -5.7814e-01, -1.1173e+00, -5.8928e-01, -5.9883e-01,
-        -2.5595e-01, -2.7521e-01,  4.2081e-02,  3.7035e-01, -6.8336e-02,
-        -5.1345e-01, -7.3419e-01,  6.9036e-01, -8.5969e-01, -6.6726e-01,
-        -3.2734e-01, -3.9906e-02, -6.0411e-01, -2.5952e-01, -8.9613e-02,
-         3.1316e-01,  6.8056e-02,  1.0623e-01,  1.1009e-01,  1.4980e+00,
-        -9.1055e-01, -7.3296e-02, -7.8475e-01, -1.8351e-01, -9.0979e-01,
-        -2.3460e-01,  5.2860e-01,  0.0000e+00,  1.4435e-01,  4.9718e-01,
-         3.5915e-01, -2.9532e-01, -2.8910e-02,  7.1120e-01,  5.5895e-01,
-         6.0758e-01,  2.1014e-01,  4.2308e-01,  5.3264e-01, -5.3784e-01,
-         5.5652e-01, -1.1029e-01,  4.4389e-01,  2.3881e-01,  7.2261e-01,
-         5.3638e-01, -1.4419e-01,  1.1106e+00, -1.8342e-01, -6.0779e-01,
-         1.0385e+00, -1.1912e-01, -2.6227e-01, -1.8456e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0362e-01,  6.8231e-01, -4.3062e-01,  2.2729e-01,  2.3439e-01,
-        -4.5037e-04, -5.7814e-01, -1.1173e+00, -5.8928e-01, -5.9883e-01,
-        -2.5595e-01, -2.7521e-01,  4.2081e-02,  3.7035e-01, -6.8336e-02,
-        -5.1345e-01, -7.3419e-01,  6.9036e-01, -8.5969e-01, -6.6726e-01,
-        -3.2734e-01, -3.9906e-02, -6.0411e-01, -2.5952e-01, -8.9613e-02,
-         3.1316e-01,  6.8056e-02,  1.0623e-01,  1.1009e-01,  1.4980e+00,
-        -9.1055e-01, -7.3296e-02, -7.8475e-01, -1.8351e-01, -9.0979e-01,
-        -2.3460e-01,  5.2860e-01,  0.0000e+00,  1.4435e-01,  4.9718e-01,
-         3.5915e-01, -2.9532e-01, -2.8910e-02,  7.1120e-01,  5.5895e-01,
-         6.0758e-01,  2.1014e-01,  4.2308e-01,  5.3264e-01, -5.3784e-01,
-         5.5652e-01, -1.1029e-01,  4.4389e-01,  2.3881e-01,  7.2261e-01,
-         5.3638e-01, -1.4419e-01,  1.1106e+00, -1.8342e-01, -6.0779e-01,
-         1.0385e+00, -1.1912e-01, -2.6227e-01, -1.8456e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9046,  0.6840, -0.4340,  0.2203,  0.2470,  0.0019, -0.5806, -1.1174,
-        -0.5855, -0.5996, -0.2532, -0.2769,  0.0369,  0.3613, -0.0722, -0.5172,
-        -0.7381,  0.6895, -0.8589, -0.6684, -0.3252, -0.0450, -0.6036, -0.2670,
-        -0.0944,  0.3126,  0.0525,  0.0969,  0.1118,  1.4988, -0.9097, -0.0727,
-        -0.7858, -0.1827, -0.9101, -0.2380,  0.5255,  0.0000,  0.1522,  0.4964,
-         0.3578, -0.2947, -0.0253,  0.7100,  0.5594,  0.6102,  0.2206,  0.4255,
-         0.5360, -0.5375,  0.5551, -0.1036,  0.4456,  0.2389,  0.7249,  0.5359,
-        -0.1330,  1.1106, -0.1843, -0.6086,  1.0390, -0.1126, -0.2634, -0.1828],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9046,  0.6840, -0.4340,  0.2203,  0.2470,  0.0019, -0.5806, -1.1174,
-        -0.5855, -0.5996, -0.2532, -0.2769,  0.0369,  0.3613, -0.0722, -0.5172,
-        -0.7381,  0.6895, -0.8589, -0.6684, -0.3252, -0.0450, -0.6036, -0.2670,
-        -0.0944,  0.3126,  0.0525,  0.0969,  0.1118,  1.4988, -0.9097, -0.0727,
-        -0.7858, -0.1827, -0.9101, -0.2380,  0.5255,  0.0000,  0.1522,  0.4964,
-         0.3578, -0.2947, -0.0253,  0.7100,  0.5594,  0.6102,  0.2206,  0.4255,
-         0.5360, -0.5375,  0.5551, -0.1036,  0.4456,  0.2389,  0.7249,  0.5359,
-        -0.1330,  1.1106, -0.1843, -0.6086,  1.0390, -0.1126, -0.2634, -0.1828],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9046,  0.6840, -0.4340,  0.2203,  0.2470,  0.0019, -0.5806, -1.1174,
-        -0.5855, -0.5996, -0.2532, -0.2769,  0.0369,  0.3613, -0.0722, -0.5172,
-        -0.7381,  0.6895, -0.8589, -0.6684, -0.3252, -0.0450, -0.6036, -0.2670,
-        -0.0944,  0.3126,  0.0525,  0.0969,  0.1118,  1.4988, -0.9097, -0.0727,
-        -0.7858, -0.1827, -0.9101, -0.2380,  0.5255,  0.0000,  0.1522,  0.4964,
-         0.3578, -0.2947, -0.0253,  0.7100,  0.5594,  0.6102,  0.2206,  0.4255,
-         0.5360, -0.5375,  0.5551, -0.1036,  0.4456,  0.2389,  0.7249,  0.5359,
-        -0.1330,  1.1106, -0.1843, -0.6086,  1.0390, -0.1126, -0.2634, -0.1828],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9053,  0.6864, -0.4376,  0.2118,  0.2558,  0.0022, -0.5828, -1.1176,
-        -0.5831, -0.5999, -0.2494, -0.2795,  0.0338,  0.3495, -0.0730, -0.5217,
-        -0.7424,  0.6882, -0.8579, -0.6685, -0.3240, -0.0504, -0.6026, -0.2754,
-        -0.0995,  0.3148,  0.0358,  0.0861,  0.1081,  1.4994, -0.9089, -0.0729,
-        -0.7869, -0.1824, -0.9107, -0.2411,  0.5223,  0.0000,  0.1603,  0.4957,
-         0.3577, -0.2915, -0.0220,  0.7095,  0.5597,  0.6133,  0.2308,  0.4271,
-         0.5402, -0.5375,  0.5538, -0.0994,  0.4472,  0.2405,  0.7276,  0.5359,
-        -0.1212,  1.1109, -0.1841, -0.6085,  1.0396, -0.1053, -0.2622, -0.1839],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9053,  0.6864, -0.4376,  0.2118,  0.2558,  0.0022, -0.5828, -1.1176,
-        -0.5831, -0.5999, -0.2494, -0.2795,  0.0338,  0.3495, -0.0730, -0.5217,
-        -0.7424,  0.6882, -0.8579, -0.6685, -0.3240, -0.0504, -0.6026, -0.2754,
-        -0.0995,  0.3148,  0.0358,  0.0861,  0.1081,  1.4994, -0.9089, -0.0729,
-        -0.7869, -0.1824, -0.9107, -0.2411,  0.5223,  0.0000,  0.1603,  0.4957,
-         0.3577, -0.2915, -0.0220,  0.7095,  0.5597,  0.6133,  0.2308,  0.4271,
-         0.5402, -0.5375,  0.5538, -0.0994,  0.4472,  0.2405,  0.7276,  0.5359,
-        -0.1212,  1.1109, -0.1841, -0.6085,  1.0396, -0.1053, -0.2622, -0.1839],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9053,  0.6864, -0.4376,  0.2118,  0.2558,  0.0022, -0.5828, -1.1176,
-        -0.5831, -0.5999, -0.2494, -0.2795,  0.0338,  0.3495, -0.0730, -0.5217,
-        -0.7424,  0.6882, -0.8579, -0.6685, -0.3240, -0.0504, -0.6026, -0.2754,
-        -0.0995,  0.3148,  0.0358,  0.0861,  0.1081,  1.4994, -0.9089, -0.0729,
-        -0.7869, -0.1824, -0.9107, -0.2411,  0.5223,  0.0000,  0.1603,  0.4957,
-         0.3577, -0.2915, -0.0220,  0.7095,  0.5597,  0.6133,  0.2308,  0.4271,
-         0.5402, -0.5375,  0.5538, -0.0994,  0.4472,  0.2405,  0.7276,  0.5359,
-        -0.1212,  1.1109, -0.1841, -0.6085,  1.0396, -0.1053, -0.2622, -0.1839],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9058,  0.6879, -0.4411,  0.2004,  0.2613,  0.0032, -0.5840, -1.1177,
-        -0.5818, -0.6004, -0.2454, -0.2772,  0.0304,  0.3370, -0.0675, -0.5260,
-        -0.7459,  0.6865, -0.8573, -0.6697, -0.3234, -0.0565, -0.6018, -0.2808,
-        -0.0975,  0.3192,  0.0183,  0.0736,  0.1054,  1.5001, -0.9085, -0.0747,
-        -0.7877, -0.1845, -0.9111, -0.2469,  0.5200,  0.0000,  0.1673,  0.4956,
-         0.3596, -0.2896, -0.0168,  0.7094,  0.5603,  0.6160,  0.2407,  0.4282,
-         0.5458, -0.5378,  0.5521, -0.0961,  0.4486,  0.2419,  0.7297,  0.5338,
-        -0.1104,  1.1113, -0.1862, -0.6085,  1.0398, -0.0969, -0.2615, -0.1872],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9058,  0.6879, -0.4411,  0.2004,  0.2613,  0.0032, -0.5840, -1.1177,
-        -0.5818, -0.6004, -0.2454, -0.2772,  0.0304,  0.3370, -0.0675, -0.5260,
-        -0.7459,  0.6865, -0.8573, -0.6697, -0.3234, -0.0565, -0.6018, -0.2808,
-        -0.0975,  0.3192,  0.0183,  0.0736,  0.1054,  1.5001, -0.9085, -0.0747,
-        -0.7877, -0.1845, -0.9111, -0.2469,  0.5200,  0.0000,  0.1673,  0.4956,
-         0.3596, -0.2896, -0.0168,  0.7094,  0.5603,  0.6160,  0.2407,  0.4282,
-         0.5458, -0.5378,  0.5521, -0.0961,  0.4486,  0.2419,  0.7297,  0.5338,
-        -0.1104,  1.1113, -0.1862, -0.6085,  1.0398, -0.0969, -0.2615, -0.1872],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9058,  0.6879, -0.4411,  0.2004,  0.2613,  0.0032, -0.5840, -1.1177,
-        -0.5818, -0.6004, -0.2454, -0.2772,  0.0304,  0.3370, -0.0675, -0.5260,
-        -0.7459,  0.6865, -0.8573, -0.6697, -0.3234, -0.0565, -0.6018, -0.2808,
-        -0.0975,  0.3192,  0.0183,  0.0736,  0.1054,  1.5001, -0.9085, -0.0747,
-        -0.7877, -0.1845, -0.9111, -0.2469,  0.5200,  0.0000,  0.1673,  0.4956,
-         0.3596, -0.2896, -0.0168,  0.7094,  0.5603,  0.6160,  0.2407,  0.4282,
-         0.5458, -0.5378,  0.5521, -0.0961,  0.4486,  0.2419,  0.7297,  0.5338,
-        -0.1104,  1.1113, -0.1862, -0.6085,  1.0398, -0.0969, -0.2615, -0.1872],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9066,  0.6897, -0.4438,  0.1929,  0.2643,  0.0043, -0.5846, -1.1180,
-        -0.5805, -0.6003, -0.2405, -0.2728,  0.0349,  0.3279, -0.0627, -0.5295,
-        -0.7491,  0.6861, -0.8571, -0.6700, -0.3220, -0.0618, -0.6008, -0.2866,
-        -0.0927,  0.3199,  0.0120,  0.0559,  0.1042,  1.5007, -0.9084, -0.0724,
-        -0.7880, -0.1825, -0.9118, -0.2538,  0.5168,  0.0000,  0.1815,  0.4954,
-         0.3621, -0.2863, -0.0147,  0.7091,  0.5611,  0.6199,  0.2499,  0.4305,
-         0.5527, -0.5380,  0.5504, -0.0929,  0.4498,  0.2489,  0.7318,  0.5301,
-        -0.1047,  1.1114, -0.1914, -0.6082,  1.0401, -0.0886, -0.2607, -0.1922],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9066,  0.6897, -0.4438,  0.1929,  0.2643,  0.0043, -0.5846, -1.1180,
-        -0.5805, -0.6003, -0.2405, -0.2728,  0.0349,  0.3279, -0.0627, -0.5295,
-        -0.7491,  0.6861, -0.8571, -0.6700, -0.3220, -0.0618, -0.6008, -0.2866,
-        -0.0927,  0.3199,  0.0120,  0.0559,  0.1042,  1.5007, -0.9084, -0.0724,
-        -0.7880, -0.1825, -0.9118, -0.2538,  0.5168,  0.0000,  0.1815,  0.4954,
-         0.3621, -0.2863, -0.0147,  0.7091,  0.5611,  0.6199,  0.2499,  0.4305,
-         0.5527, -0.5380,  0.5504, -0.0929,  0.4498,  0.2489,  0.7318,  0.5301,
-        -0.1047,  1.1114, -0.1914, -0.6082,  1.0401, -0.0886, -0.2607, -0.1922],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9066,  0.6897, -0.4438,  0.1929,  0.2643,  0.0043, -0.5846, -1.1180,
-        -0.5805, -0.6003, -0.2405, -0.2728,  0.0349,  0.3279, -0.0627, -0.5295,
-        -0.7491,  0.6861, -0.8571, -0.6700, -0.3220, -0.0618, -0.6008, -0.2866,
-        -0.0927,  0.3199,  0.0120,  0.0559,  0.1042,  1.5007, -0.9084, -0.0724,
-        -0.7880, -0.1825, -0.9118, -0.2538,  0.5168,  0.0000,  0.1815,  0.4954,
-         0.3621, -0.2863, -0.0147,  0.7091,  0.5611,  0.6199,  0.2499,  0.4305,
-         0.5527, -0.5380,  0.5504, -0.0929,  0.4498,  0.2489,  0.7318,  0.5301,
-        -0.1047,  1.1114, -0.1914, -0.6082,  1.0401, -0.0886, -0.2607, -0.1922],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9067,  0.6917, -0.4460,  0.1880,  0.2689,  0.0043, -0.5849, -1.1180,
-        -0.5788, -0.6010, -0.2378, -0.2721,  0.0393,  0.3190, -0.0577, -0.5324,
-        -0.7523,  0.6853, -0.8570, -0.6697, -0.3200, -0.0677, -0.5992, -0.2895,
-        -0.0892,  0.3176,  0.0076,  0.0378,  0.1030,  1.5013, -0.9082, -0.0667,
-        -0.7890, -0.1853, -0.9124, -0.2588,  0.5127,  0.0000,  0.1970,  0.4963,
-         0.3651, -0.2819, -0.0120,  0.7088,  0.5620,  0.6233,  0.2587,  0.4317,
-         0.5611, -0.5381,  0.5482, -0.0874,  0.4539,  0.2583,  0.7338,  0.5264,
-        -0.0967,  1.1117, -0.1930, -0.6087,  1.0404, -0.0814, -0.2584, -0.1984],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9067,  0.6917, -0.4460,  0.1880,  0.2689,  0.0043, -0.5849, -1.1180,
-        -0.5788, -0.6010, -0.2378, -0.2721,  0.0393,  0.3190, -0.0577, -0.5324,
-        -0.7523,  0.6853, -0.8570, -0.6697, -0.3200, -0.0677, -0.5992, -0.2895,
-        -0.0892,  0.3176,  0.0076,  0.0378,  0.1030,  1.5013, -0.9082, -0.0667,
-        -0.7890, -0.1853, -0.9124, -0.2588,  0.5127,  0.0000,  0.1970,  0.4963,
-         0.3651, -0.2819, -0.0120,  0.7088,  0.5620,  0.6233,  0.2587,  0.4317,
-         0.5611, -0.5381,  0.5482, -0.0874,  0.4539,  0.2583,  0.7338,  0.5264,
-        -0.0967,  1.1117, -0.1930, -0.6087,  1.0404, -0.0814, -0.2584, -0.1984],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9067,  0.6917, -0.4460,  0.1880,  0.2689,  0.0043, -0.5849, -1.1180,
-        -0.5788, -0.6010, -0.2378, -0.2721,  0.0393,  0.3190, -0.0577, -0.5324,
-        -0.7523,  0.6853, -0.8570, -0.6697, -0.3200, -0.0677, -0.5992, -0.2895,
-        -0.0892,  0.3176,  0.0076,  0.0378,  0.1030,  1.5013, -0.9082, -0.0667,
-        -0.7890, -0.1853, -0.9124, -0.2588,  0.5127,  0.0000,  0.1970,  0.4963,
-         0.3651, -0.2819, -0.0120,  0.7088,  0.5620,  0.6233,  0.2587,  0.4317,
-         0.5611, -0.5381,  0.5482, -0.0874,  0.4539,  0.2583,  0.7338,  0.5264,
-        -0.0967,  1.1117, -0.1930, -0.6087,  1.0404, -0.0814, -0.2584, -0.1984],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9069,  0.6938, -0.4478,  0.1783,  0.2721,  0.0034, -0.5845, -1.1178,
-        -0.5775, -0.6023, -0.2337, -0.2710,  0.0433,  0.3107, -0.0527, -0.5356,
-        -0.7558,  0.6853, -0.8567, -0.6677, -0.3183, -0.0739, -0.5968, -0.2924,
-        -0.0841,  0.3135,  0.0068,  0.0194,  0.1056,  1.5023, -0.9080, -0.0594,
-        -0.7897, -0.1899, -0.9130, -0.2638,  0.5088,  0.0000,  0.2112,  0.4976,
-         0.3681, -0.2765, -0.0132,  0.7077,  0.5634,  0.6259,  0.2659,  0.4319,
-         0.5705, -0.5390,  0.5457, -0.0838,  0.4562,  0.2672,  0.7353,  0.5230,
-        -0.0871,  1.1115, -0.1971, -0.6104,  1.0405, -0.0757, -0.2551, -0.2043],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9069,  0.6938, -0.4478,  0.1783,  0.2721,  0.0034, -0.5845, -1.1178,
-        -0.5775, -0.6023, -0.2337, -0.2710,  0.0433,  0.3107, -0.0527, -0.5356,
-        -0.7558,  0.6853, -0.8567, -0.6677, -0.3183, -0.0739, -0.5968, -0.2924,
-        -0.0841,  0.3135,  0.0068,  0.0194,  0.1056,  1.5023, -0.9080, -0.0594,
-        -0.7897, -0.1899, -0.9130, -0.2638,  0.5088,  0.0000,  0.2112,  0.4976,
-         0.3681, -0.2765, -0.0132,  0.7077,  0.5634,  0.6259,  0.2659,  0.4319,
-         0.5705, -0.5390,  0.5457, -0.0838,  0.4562,  0.2672,  0.7353,  0.5230,
-        -0.0871,  1.1115, -0.1971, -0.6104,  1.0405, -0.0757, -0.2551, -0.2043],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9069,  0.6938, -0.4478,  0.1783,  0.2721,  0.0034, -0.5845, -1.1178,
-        -0.5775, -0.6023, -0.2337, -0.2710,  0.0433,  0.3107, -0.0527, -0.5356,
-        -0.7558,  0.6853, -0.8567, -0.6677, -0.3183, -0.0739, -0.5968, -0.2924,
-        -0.0841,  0.3135,  0.0068,  0.0194,  0.1056,  1.5023, -0.9080, -0.0594,
-        -0.7897, -0.1899, -0.9130, -0.2638,  0.5088,  0.0000,  0.2112,  0.4976,
-         0.3681, -0.2765, -0.0132,  0.7077,  0.5634,  0.6259,  0.2659,  0.4319,
-         0.5705, -0.5390,  0.5457, -0.0838,  0.4562,  0.2672,  0.7353,  0.5230,
-        -0.0871,  1.1115, -0.1971, -0.6104,  1.0405, -0.0757, -0.2551, -0.2043],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9077,  0.6957, -0.4499,  0.1688,  0.2772,  0.0024, -0.5842, -1.1173,
-        -0.5771, -0.6048, -0.2298, -0.2716,  0.0449,  0.3024, -0.0538, -0.5387,
-        -0.7594,  0.6847, -0.8564, -0.6661, -0.3166, -0.0810, -0.5943, -0.2956,
-        -0.0836,  0.3080,  0.0083, -0.0025,  0.1035,  1.5033, -0.9076, -0.0521,
-        -0.7905, -0.2004, -0.9134, -0.2695,  0.5074,  0.0000,  0.2219,  0.5001,
-         0.3719, -0.2699, -0.0144,  0.7056,  0.5650,  0.6304,  0.2773,  0.4329,
-         0.5810, -0.5393,  0.5426, -0.0798,  0.4616,  0.2751,  0.7372,  0.5189,
-        -0.0831,  1.1112, -0.1955, -0.6127,  1.0407, -0.0740, -0.2541, -0.2065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9077,  0.6957, -0.4499,  0.1688,  0.2772,  0.0024, -0.5842, -1.1173,
-        -0.5771, -0.6048, -0.2298, -0.2716,  0.0449,  0.3024, -0.0538, -0.5387,
-        -0.7594,  0.6847, -0.8564, -0.6661, -0.3166, -0.0810, -0.5943, -0.2956,
-        -0.0836,  0.3080,  0.0083, -0.0025,  0.1035,  1.5033, -0.9076, -0.0521,
-        -0.7905, -0.2004, -0.9134, -0.2695,  0.5074,  0.0000,  0.2219,  0.5001,
-         0.3719, -0.2699, -0.0144,  0.7056,  0.5650,  0.6304,  0.2773,  0.4329,
-         0.5810, -0.5393,  0.5426, -0.0798,  0.4616,  0.2751,  0.7372,  0.5189,
-        -0.0831,  1.1112, -0.1955, -0.6127,  1.0407, -0.0740, -0.2541, -0.2065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9077,  0.6957, -0.4499,  0.1688,  0.2772,  0.0024, -0.5842, -1.1173,
-        -0.5771, -0.6048, -0.2298, -0.2716,  0.0449,  0.3024, -0.0538, -0.5387,
-        -0.7594,  0.6847, -0.8564, -0.6661, -0.3166, -0.0810, -0.5943, -0.2956,
-        -0.0836,  0.3080,  0.0083, -0.0025,  0.1035,  1.5033, -0.9076, -0.0521,
-        -0.7905, -0.2004, -0.9134, -0.2695,  0.5074,  0.0000,  0.2219,  0.5001,
-         0.3719, -0.2699, -0.0144,  0.7056,  0.5650,  0.6304,  0.2773,  0.4329,
-         0.5810, -0.5393,  0.5426, -0.0798,  0.4616,  0.2751,  0.7372,  0.5189,
-        -0.0831,  1.1112, -0.1955, -0.6127,  1.0407, -0.0740, -0.2541, -0.2065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9081,  0.6975, -0.4515,  0.1596,  0.2813,  0.0029, -0.5841, -1.1171,
-        -0.5775, -0.6065, -0.2239, -0.2703,  0.0394,  0.2900, -0.0527, -0.5429,
-        -0.7630,  0.6845, -0.8573, -0.6653, -0.3175, -0.0901, -0.5918, -0.2999,
-        -0.0875,  0.3023,  0.0116, -0.0226,  0.1005,  1.5046, -0.9070, -0.0471,
-        -0.7927, -0.2117, -0.9136, -0.2742,  0.5035,  0.0000,  0.2337,  0.5030,
-         0.3754, -0.2605, -0.0142,  0.7043,  0.5671,  0.6361,  0.2890,  0.4311,
-         0.5928, -0.5399,  0.5401, -0.0779,  0.4663,  0.2834,  0.7389,  0.5157,
-        -0.0760,  1.1110, -0.1953, -0.6170,  1.0409, -0.0695, -0.2545, -0.2089],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9081,  0.6975, -0.4515,  0.1596,  0.2813,  0.0029, -0.5841, -1.1171,
-        -0.5775, -0.6065, -0.2239, -0.2703,  0.0394,  0.2900, -0.0527, -0.5429,
-        -0.7630,  0.6845, -0.8573, -0.6653, -0.3175, -0.0901, -0.5918, -0.2999,
-        -0.0875,  0.3023,  0.0116, -0.0226,  0.1005,  1.5046, -0.9070, -0.0471,
-        -0.7927, -0.2117, -0.9136, -0.2742,  0.5035,  0.0000,  0.2337,  0.5030,
-         0.3754, -0.2605, -0.0142,  0.7043,  0.5671,  0.6361,  0.2890,  0.4311,
-         0.5928, -0.5399,  0.5401, -0.0779,  0.4663,  0.2834,  0.7389,  0.5157,
-        -0.0760,  1.1110, -0.1953, -0.6170,  1.0409, -0.0695, -0.2545, -0.2089],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9081,  0.6975, -0.4515,  0.1596,  0.2813,  0.0029, -0.5841, -1.1171,
-        -0.5775, -0.6065, -0.2239, -0.2703,  0.0394,  0.2900, -0.0527, -0.5429,
-        -0.7630,  0.6845, -0.8573, -0.6653, -0.3175, -0.0901, -0.5918, -0.2999,
-        -0.0875,  0.3023,  0.0116, -0.0226,  0.1005,  1.5046, -0.9070, -0.0471,
-        -0.7927, -0.2117, -0.9136, -0.2742,  0.5035,  0.0000,  0.2337,  0.5030,
-         0.3754, -0.2605, -0.0142,  0.7043,  0.5671,  0.6361,  0.2890,  0.4311,
-         0.5928, -0.5399,  0.5401, -0.0779,  0.4663,  0.2834,  0.7389,  0.5157,
-        -0.0760,  1.1110, -0.1953, -0.6170,  1.0409, -0.0695, -0.2545, -0.2089],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0809e-01,  6.9898e-01, -4.5331e-01,  1.5216e-01,  2.8274e-01,
-         3.8534e-05, -5.8375e-01, -1.1168e+00, -5.7785e-01, -6.0879e-01,
-        -2.1652e-01, -2.7356e-01,  3.4214e-02,  2.7478e-01, -6.3727e-02,
-        -5.4732e-01, -7.6711e-01,  6.8492e-01, -8.5857e-01, -6.6502e-01,
-        -3.1773e-01, -9.8530e-02, -5.8796e-01, -3.0606e-01, -9.1878e-02,
-         2.9933e-01,  1.8453e-02, -3.8549e-02,  1.0132e-01,  1.5058e+00,
-        -9.0663e-01, -5.0223e-02, -7.9554e-01, -2.2241e-01, -9.1341e-01,
-        -2.8046e-01,  5.0073e-01,  0.0000e+00,  2.4553e-01,  5.0247e-01,
-         3.7934e-01, -2.4644e-01, -1.0619e-02,  7.0329e-01,  5.7052e-01,
-         6.4158e-01,  2.9682e-01,  4.2901e-01,  6.0502e-01, -5.4004e-01,
-         5.3717e-01, -7.9878e-02,  4.6994e-01,  2.9041e-01,  7.4008e-01,
-         5.1300e-01, -6.4303e-02,  1.1109e+00, -1.9746e-01, -6.1998e-01,
-         1.0418e+00, -6.0271e-02, -2.5628e-01, -2.0947e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0809e-01,  6.9898e-01, -4.5331e-01,  1.5216e-01,  2.8274e-01,
-         3.8534e-05, -5.8375e-01, -1.1168e+00, -5.7785e-01, -6.0879e-01,
-        -2.1652e-01, -2.7356e-01,  3.4214e-02,  2.7478e-01, -6.3727e-02,
-        -5.4732e-01, -7.6711e-01,  6.8492e-01, -8.5857e-01, -6.6502e-01,
-        -3.1773e-01, -9.8530e-02, -5.8796e-01, -3.0606e-01, -9.1878e-02,
-         2.9933e-01,  1.8453e-02, -3.8549e-02,  1.0132e-01,  1.5058e+00,
-        -9.0663e-01, -5.0223e-02, -7.9554e-01, -2.2241e-01, -9.1341e-01,
-        -2.8046e-01,  5.0073e-01,  0.0000e+00,  2.4553e-01,  5.0247e-01,
-         3.7934e-01, -2.4644e-01, -1.0619e-02,  7.0329e-01,  5.7052e-01,
-         6.4158e-01,  2.9682e-01,  4.2901e-01,  6.0502e-01, -5.4004e-01,
-         5.3717e-01, -7.9878e-02,  4.6994e-01,  2.9041e-01,  7.4008e-01,
-         5.1300e-01, -6.4303e-02,  1.1109e+00, -1.9746e-01, -6.1998e-01,
-         1.0418e+00, -6.0271e-02, -2.5628e-01, -2.0947e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0809e-01,  6.9898e-01, -4.5331e-01,  1.5216e-01,  2.8274e-01,
-         3.8534e-05, -5.8375e-01, -1.1168e+00, -5.7785e-01, -6.0879e-01,
-        -2.1652e-01, -2.7356e-01,  3.4214e-02,  2.7478e-01, -6.3727e-02,
-        -5.4732e-01, -7.6711e-01,  6.8492e-01, -8.5857e-01, -6.6502e-01,
-        -3.1773e-01, -9.8530e-02, -5.8796e-01, -3.0606e-01, -9.1878e-02,
-         2.9933e-01,  1.8453e-02, -3.8549e-02,  1.0132e-01,  1.5058e+00,
-        -9.0663e-01, -5.0223e-02, -7.9554e-01, -2.2241e-01, -9.1341e-01,
-        -2.8046e-01,  5.0073e-01,  0.0000e+00,  2.4553e-01,  5.0247e-01,
-         3.7934e-01, -2.4644e-01, -1.0619e-02,  7.0329e-01,  5.7052e-01,
-         6.4158e-01,  2.9682e-01,  4.2901e-01,  6.0502e-01, -5.4004e-01,
-         5.3717e-01, -7.9878e-02,  4.6994e-01,  2.9041e-01,  7.4008e-01,
-         5.1300e-01, -6.4303e-02,  1.1109e+00, -1.9746e-01, -6.1998e-01,
-         1.0418e+00, -6.0271e-02, -2.5628e-01, -2.0947e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9082,  0.7016, -0.4534,  0.1388,  0.2857, -0.0019, -0.5843, -1.1163,
-        -0.5804, -0.6104, -0.2084, -0.2757,  0.0238,  0.2603, -0.0733, -0.5544,
-        -0.7707,  0.6854, -0.8598, -0.6638, -0.3173, -0.1110, -0.5853, -0.3088,
-        -0.0968,  0.2967,  0.0226, -0.0596,  0.0984,  1.5074, -0.9062, -0.0448,
-        -0.7989, -0.2336, -0.9136, -0.2848,  0.4983,  0.0000,  0.2580,  0.5003,
-         0.3830, -0.2286, -0.0051,  0.7022,  0.5734,  0.6475,  0.3064,  0.4254,
-         0.6158, -0.5419,  0.5330, -0.0805,  0.4748,  0.2967,  0.7402,  0.5106,
-        -0.0460,  1.1109, -0.2011, -0.6216,  1.0421, -0.0527, -0.2584, -0.2088],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9082,  0.7016, -0.4534,  0.1388,  0.2857, -0.0019, -0.5843, -1.1163,
-        -0.5804, -0.6104, -0.2084, -0.2757,  0.0238,  0.2603, -0.0733, -0.5544,
-        -0.7707,  0.6854, -0.8598, -0.6638, -0.3173, -0.1110, -0.5853, -0.3088,
-        -0.0968,  0.2967,  0.0226, -0.0596,  0.0984,  1.5074, -0.9062, -0.0448,
-        -0.7989, -0.2336, -0.9136, -0.2848,  0.4983,  0.0000,  0.2580,  0.5003,
-         0.3830, -0.2286, -0.0051,  0.7022,  0.5734,  0.6475,  0.3064,  0.4254,
-         0.6158, -0.5419,  0.5330, -0.0805,  0.4748,  0.2967,  0.7402,  0.5106,
-        -0.0460,  1.1109, -0.2011, -0.6216,  1.0421, -0.0527, -0.2584, -0.2088],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9082,  0.7016, -0.4534,  0.1388,  0.2857, -0.0019, -0.5843, -1.1163,
-        -0.5804, -0.6104, -0.2084, -0.2757,  0.0238,  0.2603, -0.0733, -0.5544,
-        -0.7707,  0.6854, -0.8598, -0.6638, -0.3173, -0.1110, -0.5853, -0.3088,
-        -0.0968,  0.2967,  0.0226, -0.0596,  0.0984,  1.5074, -0.9062, -0.0448,
-        -0.7989, -0.2336, -0.9136, -0.2848,  0.4983,  0.0000,  0.2580,  0.5003,
-         0.3830, -0.2286, -0.0051,  0.7022,  0.5734,  0.6475,  0.3064,  0.4254,
-         0.6158, -0.5419,  0.5330, -0.0805,  0.4748,  0.2967,  0.7402,  0.5106,
-        -0.0460,  1.1109, -0.2011, -0.6216,  1.0421, -0.0527, -0.2584, -0.2088],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0735e-01,  7.0306e-01, -4.5280e-01,  1.2427e-01,  2.8572e-01,
-        -4.0952e-03, -5.8513e-01, -1.1155e+00, -5.8369e-01, -6.1238e-01,
-        -2.0482e-01, -2.7752e-01,  1.4553e-02,  2.5039e-01, -8.9902e-02,
-        -5.5956e-01, -7.7304e-01,  6.8676e-01, -8.6044e-01, -6.6298e-01,
-        -3.1876e-01, -1.2073e-01, -5.8386e-01, -3.1067e-01, -1.0285e-01,
-         2.9348e-01,  2.2847e-02, -8.3241e-02,  9.3427e-02,  1.5093e+00,
-        -9.0596e-01, -3.1366e-02, -8.0256e-01, -2.4658e-01, -9.1331e-01,
-        -2.8910e-01,  4.9493e-01,  0.0000e+00,  2.6856e-01,  4.9764e-01,
-         3.8877e-01, -2.0599e-01,  6.7335e-04,  7.0219e-01,  5.7672e-01,
-         6.5394e-01,  3.1667e-01,  4.2309e-01,  6.2805e-01, -5.4364e-01,
-         5.3175e-01, -8.0631e-02,  4.7960e-01,  3.0852e-01,  7.4047e-01,
-         5.0808e-01, -3.1520e-02,  1.1107e+00, -2.0246e-01, -6.2141e-01,
-         1.0426e+00, -4.7125e-02, -2.5999e-01, -2.1075e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0735e-01,  7.0306e-01, -4.5280e-01,  1.2427e-01,  2.8572e-01,
-        -4.0952e-03, -5.8513e-01, -1.1155e+00, -5.8369e-01, -6.1238e-01,
-        -2.0482e-01, -2.7752e-01,  1.4553e-02,  2.5039e-01, -8.9902e-02,
-        -5.5956e-01, -7.7304e-01,  6.8676e-01, -8.6044e-01, -6.6298e-01,
-        -3.1876e-01, -1.2073e-01, -5.8386e-01, -3.1067e-01, -1.0285e-01,
-         2.9348e-01,  2.2847e-02, -8.3241e-02,  9.3427e-02,  1.5093e+00,
-        -9.0596e-01, -3.1366e-02, -8.0256e-01, -2.4658e-01, -9.1331e-01,
-        -2.8910e-01,  4.9493e-01,  0.0000e+00,  2.6856e-01,  4.9764e-01,
-         3.8877e-01, -2.0599e-01,  6.7335e-04,  7.0219e-01,  5.7672e-01,
-         6.5394e-01,  3.1667e-01,  4.2309e-01,  6.2805e-01, -5.4364e-01,
-         5.3175e-01, -8.0631e-02,  4.7960e-01,  3.0852e-01,  7.4047e-01,
-         5.0808e-01, -3.1520e-02,  1.1107e+00, -2.0246e-01, -6.2141e-01,
-         1.0426e+00, -4.7125e-02, -2.5999e-01, -2.1075e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0735e-01,  7.0306e-01, -4.5280e-01,  1.2427e-01,  2.8572e-01,
-        -4.0952e-03, -5.8513e-01, -1.1155e+00, -5.8369e-01, -6.1238e-01,
-        -2.0482e-01, -2.7752e-01,  1.4553e-02,  2.5039e-01, -8.9902e-02,
-        -5.5956e-01, -7.7304e-01,  6.8676e-01, -8.6044e-01, -6.6298e-01,
-        -3.1876e-01, -1.2073e-01, -5.8386e-01, -3.1067e-01, -1.0285e-01,
-         2.9348e-01,  2.2847e-02, -8.3241e-02,  9.3427e-02,  1.5093e+00,
-        -9.0596e-01, -3.1366e-02, -8.0256e-01, -2.4658e-01, -9.1331e-01,
-        -2.8910e-01,  4.9493e-01,  0.0000e+00,  2.6856e-01,  4.9764e-01,
-         3.8877e-01, -2.0599e-01,  6.7335e-04,  7.0219e-01,  5.7672e-01,
-         6.5394e-01,  3.1667e-01,  4.2309e-01,  6.2805e-01, -5.4364e-01,
-         5.3175e-01, -8.0631e-02,  4.7960e-01,  3.0852e-01,  7.4047e-01,
-         5.0808e-01, -3.1520e-02,  1.1107e+00, -2.0246e-01, -6.2141e-01,
-         1.0426e+00, -4.7125e-02, -2.5999e-01, -2.1075e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9071,  0.7055, -0.4502,  0.1171,  0.2856, -0.0041, -0.5851, -1.1147,
-        -0.5870, -0.6136, -0.1971, -0.2846,  0.0083,  0.2404, -0.1021, -0.5651,
-        -0.7755,  0.6884, -0.8611, -0.6613, -0.3226, -0.1312, -0.5830, -0.3122,
-        -0.1121,  0.2915,  0.0292, -0.1095,  0.0877,  1.5109, -0.9060, -0.0140,
-        -0.8059, -0.2547, -0.9129, -0.2935,  0.4930,  0.0000,  0.2806,  0.4951,
-         0.3935, -0.1778,  0.0052,  0.7016,  0.5800,  0.6624,  0.3287,  0.4221,
-         0.6403, -0.5465,  0.5309, -0.0773,  0.4824,  0.3177,  0.7390,  0.5066,
-        -0.0134,  1.1100, -0.1999, -0.6191,  1.0424, -0.0449, -0.2615, -0.2139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9071,  0.7055, -0.4502,  0.1171,  0.2856, -0.0041, -0.5851, -1.1147,
-        -0.5870, -0.6136, -0.1971, -0.2846,  0.0083,  0.2404, -0.1021, -0.5651,
-        -0.7755,  0.6884, -0.8611, -0.6613, -0.3226, -0.1312, -0.5830, -0.3122,
-        -0.1121,  0.2915,  0.0292, -0.1095,  0.0877,  1.5109, -0.9060, -0.0140,
-        -0.8059, -0.2547, -0.9129, -0.2935,  0.4930,  0.0000,  0.2806,  0.4951,
-         0.3935, -0.1778,  0.0052,  0.7016,  0.5800,  0.6624,  0.3287,  0.4221,
-         0.6403, -0.5465,  0.5309, -0.0773,  0.4824,  0.3177,  0.7390,  0.5066,
-        -0.0134,  1.1100, -0.1999, -0.6191,  1.0424, -0.0449, -0.2615, -0.2139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9071,  0.7055, -0.4502,  0.1171,  0.2856, -0.0041, -0.5851, -1.1147,
-        -0.5870, -0.6136, -0.1971, -0.2846,  0.0083,  0.2404, -0.1021, -0.5651,
-        -0.7755,  0.6884, -0.8611, -0.6613, -0.3226, -0.1312, -0.5830, -0.3122,
-        -0.1121,  0.2915,  0.0292, -0.1095,  0.0877,  1.5109, -0.9060, -0.0140,
-        -0.8059, -0.2547, -0.9129, -0.2935,  0.4930,  0.0000,  0.2806,  0.4951,
-         0.3935, -0.1778,  0.0052,  0.7016,  0.5800,  0.6624,  0.3287,  0.4221,
-         0.6403, -0.5465,  0.5309, -0.0773,  0.4824,  0.3177,  0.7390,  0.5066,
-        -0.0134,  1.1100, -0.1999, -0.6191,  1.0424, -0.0449, -0.2615, -0.2139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0613e-01,  7.0808e-01, -4.4644e-01,  9.6410e-02,  2.8943e-01,
-        -2.7042e-03, -5.8521e-01, -1.1138e+00, -5.9066e-01, -6.1504e-01,
-        -1.9424e-01, -2.9477e-01,  2.2738e-03,  2.3198e-01, -1.0466e-01,
-        -5.7044e-01, -7.7745e-01,  6.9018e-01, -8.6190e-01, -6.6058e-01,
-        -3.2516e-01, -1.3661e-01, -5.8217e-01, -3.1276e-01, -1.2280e-01,
-         2.8869e-01,  3.7664e-02, -1.3780e-01,  8.6353e-02,  1.5123e+00,
-        -9.0595e-01,  3.7749e-03, -8.0916e-01, -2.5962e-01, -9.1271e-01,
-        -2.9870e-01,  4.9041e-01,  0.0000e+00,  2.9293e-01,  4.9208e-01,
-         3.9696e-01, -1.5024e-01,  1.0647e-02,  6.9866e-01,  5.8363e-01,
-         6.6923e-01,  3.4186e-01,  4.2001e-01,  6.5021e-01, -5.4899e-01,
-         5.3154e-01, -7.9079e-02,  4.8539e-01,  3.2534e-01,  7.3732e-01,
-         5.0484e-01,  9.3496e-04,  1.1090e+00, -1.9319e-01, -6.1569e-01,
-         1.0422e+00, -4.6211e-02, -2.6277e-01, -2.1266e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 9.0613e-01,  7.0808e-01, -4.4644e-01,  9.6410e-02,  2.8943e-01,
-        -2.7042e-03, -5.8521e-01, -1.1138e+00, -5.9066e-01, -6.1504e-01,
-        -1.9424e-01, -2.9477e-01,  2.2738e-03,  2.3198e-01, -1.0466e-01,
-        -5.7044e-01, -7.7745e-01,  6.9018e-01, -8.6190e-01, -6.6058e-01,
-        -3.2516e-01, -1.3661e-01, -5.8217e-01, -3.1276e-01, -1.2280e-01,
-         2.8869e-01,  3.7664e-02, -1.3780e-01,  8.6353e-02,  1.5123e+00,
-        -9.0595e-01,  3.7749e-03, -8.0916e-01, -2.5962e-01, -9.1271e-01,
-        -2.9870e-01,  4.9041e-01,  0.0000e+00,  2.9293e-01,  4.9208e-01,
-         3.9696e-01, -1.5024e-01,  1.0647e-02,  6.9866e-01,  5.8363e-01,
-         6.6923e-01,  3.4186e-01,  4.2001e-01,  6.5021e-01, -5.4899e-01,
-         5.3154e-01, -7.9079e-02,  4.8539e-01,  3.2534e-01,  7.3732e-01,
-         5.0484e-01,  9.3496e-04,  1.1090e+00, -1.9319e-01, -6.1569e-01,
-         1.0422e+00, -4.6211e-02, -2.6277e-01, -2.1266e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 9.0613e-01,  7.0808e-01, -4.4644e-01,  9.6410e-02,  2.8943e-01,
-        -2.7042e-03, -5.8521e-01, -1.1138e+00, -5.9066e-01, -6.1504e-01,
-        -1.9424e-01, -2.9477e-01,  2.2738e-03,  2.3198e-01, -1.0466e-01,
-        -5.7044e-01, -7.7745e-01,  6.9018e-01, -8.6190e-01, -6.6058e-01,
-        -3.2516e-01, -1.3661e-01, -5.8217e-01, -3.1276e-01, -1.2280e-01,
-         2.8869e-01,  3.7664e-02, -1.3780e-01,  8.6353e-02,  1.5123e+00,
-        -9.0595e-01,  3.7749e-03, -8.0916e-01, -2.5962e-01, -9.1271e-01,
-        -2.9870e-01,  4.9041e-01,  0.0000e+00,  2.9293e-01,  4.9208e-01,
-         3.9696e-01, -1.5024e-01,  1.0647e-02,  6.9866e-01,  5.8363e-01,
-         6.6923e-01,  3.4186e-01,  4.2001e-01,  6.5021e-01, -5.4899e-01,
-         5.3154e-01, -7.9079e-02,  4.8539e-01,  3.2534e-01,  7.3732e-01,
-         5.0484e-01,  9.3496e-04,  1.1090e+00, -1.9319e-01, -6.1569e-01,
-         1.0422e+00, -4.6211e-02, -2.6277e-01, -2.1266e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9050,  0.7090, -0.4428,  0.0621,  0.2917, -0.0021, -0.5851, -1.1131,
-        -0.5942, -0.6155, -0.2006, -0.3052, -0.0033,  0.2221, -0.0993, -0.5760,
-        -0.7811,  0.6914, -0.8627, -0.6572, -0.3274, -0.1419, -0.5826, -0.3159,
-        -0.1268,  0.2859,  0.0527, -0.1725,  0.0850,  1.5138, -0.9058,  0.0181,
-        -0.8133, -0.2710, -0.9126, -0.3047,  0.4894,  0.0000,  0.3094,  0.4910,
-         0.3997, -0.1157,  0.0187,  0.6959,  0.5887,  0.6769,  0.3532,  0.4233,
-         0.6589, -0.5515,  0.5308, -0.0902,  0.4869,  0.3354,  0.7359,  0.5037,
-         0.0155,  1.1071, -0.1871, -0.6131,  1.0427, -0.0476, -0.2588, -0.2121],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9050,  0.7090, -0.4428,  0.0621,  0.2917, -0.0021, -0.5851, -1.1131,
-        -0.5942, -0.6155, -0.2006, -0.3052, -0.0033,  0.2221, -0.0993, -0.5760,
-        -0.7811,  0.6914, -0.8627, -0.6572, -0.3274, -0.1419, -0.5826, -0.3159,
-        -0.1268,  0.2859,  0.0527, -0.1725,  0.0850,  1.5138, -0.9058,  0.0181,
-        -0.8133, -0.2710, -0.9126, -0.3047,  0.4894,  0.0000,  0.3094,  0.4910,
-         0.3997, -0.1157,  0.0187,  0.6959,  0.5887,  0.6769,  0.3532,  0.4233,
-         0.6589, -0.5515,  0.5308, -0.0902,  0.4869,  0.3354,  0.7359,  0.5037,
-         0.0155,  1.1071, -0.1871, -0.6131,  1.0427, -0.0476, -0.2588, -0.2121],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9050,  0.7090, -0.4428,  0.0621,  0.2917, -0.0021, -0.5851, -1.1131,
-        -0.5942, -0.6155, -0.2006, -0.3052, -0.0033,  0.2221, -0.0993, -0.5760,
-        -0.7811,  0.6914, -0.8627, -0.6572, -0.3274, -0.1419, -0.5826, -0.3159,
-        -0.1268,  0.2859,  0.0527, -0.1725,  0.0850,  1.5138, -0.9058,  0.0181,
-        -0.8133, -0.2710, -0.9126, -0.3047,  0.4894,  0.0000,  0.3094,  0.4910,
-         0.3997, -0.1157,  0.0187,  0.6959,  0.5887,  0.6769,  0.3532,  0.4233,
-         0.6589, -0.5515,  0.5308, -0.0902,  0.4869,  0.3354,  0.7359,  0.5037,
-         0.0155,  1.1071, -0.1871, -0.6131,  1.0427, -0.0476, -0.2588, -0.2121],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.9024,  0.7099, -0.4382,  0.0253,  0.2915,  0.0029, -0.5854, -1.1123,
-        -0.5988, -0.6146, -0.2091, -0.3161, -0.0127,  0.2118, -0.0929, -0.5797,
-        -0.7847,  0.6928, -0.8638, -0.6533, -0.3271, -0.1491, -0.5842, -0.3145,
-        -0.1287,  0.2864,  0.0636, -0.2015,  0.0826,  1.5153, -0.9048,  0.0401,
-        -0.8174, -0.2842, -0.9127, -0.3088,  0.4886,  0.0000,  0.3236,  0.4893,
-         0.4019, -0.0817,  0.0194,  0.6937,  0.5930,  0.6844,  0.3661,  0.4270,
-         0.6675, -0.5538,  0.5277, -0.1057,  0.4880,  0.3447,  0.7345,  0.5035,
-         0.0239,  1.1051, -0.1804, -0.6096,  1.0432, -0.0457, -0.2525, -0.2147],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.9024,  0.7099, -0.4382,  0.0253,  0.2915,  0.0029, -0.5854, -1.1123,
-        -0.5988, -0.6146, -0.2091, -0.3161, -0.0127,  0.2118, -0.0929, -0.5797,
-        -0.7847,  0.6928, -0.8638, -0.6533, -0.3271, -0.1491, -0.5842, -0.3145,
-        -0.1287,  0.2864,  0.0636, -0.2015,  0.0826,  1.5153, -0.9048,  0.0401,
-        -0.8174, -0.2842, -0.9127, -0.3088,  0.4886,  0.0000,  0.3236,  0.4893,
-         0.4019, -0.0817,  0.0194,  0.6937,  0.5930,  0.6844,  0.3661,  0.4270,
-         0.6675, -0.5538,  0.5277, -0.1057,  0.4880,  0.3447,  0.7345,  0.5035,
-         0.0239,  1.1051, -0.1804, -0.6096,  1.0432, -0.0457, -0.2525, -0.2147],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.9024,  0.7099, -0.4382,  0.0253,  0.2915,  0.0029, -0.5854, -1.1123,
-        -0.5988, -0.6146, -0.2091, -0.3161, -0.0127,  0.2118, -0.0929, -0.5797,
-        -0.7847,  0.6928, -0.8638, -0.6533, -0.3271, -0.1491, -0.5842, -0.3145,
-        -0.1287,  0.2864,  0.0636, -0.2015,  0.0826,  1.5153, -0.9048,  0.0401,
-        -0.8174, -0.2842, -0.9127, -0.3088,  0.4886,  0.0000,  0.3236,  0.4893,
-         0.4019, -0.0817,  0.0194,  0.6937,  0.5930,  0.6844,  0.3661,  0.4270,
-         0.6675, -0.5538,  0.5277, -0.1057,  0.4880,  0.3447,  0.7345,  0.5035,
-         0.0239,  1.1051, -0.1804, -0.6096,  1.0432, -0.0457, -0.2525, -0.2147],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.9953e-01,  7.0912e-01, -4.3275e-01, -1.4428e-03,  2.9315e-01,
-         1.0506e-02, -5.8402e-01, -1.1115e+00, -6.0515e-01, -6.1241e-01,
-        -2.2131e-01, -3.2038e-01, -1.4874e-02,  1.9834e-01, -8.6796e-02,
-        -5.8208e-01, -7.8766e-01,  6.9730e-01, -8.6626e-01, -6.5007e-01,
-        -3.2225e-01, -1.5535e-01, -5.8389e-01, -3.1289e-01, -1.3179e-01,
-         2.8526e-01,  6.9646e-02, -2.2773e-01,  8.0369e-02,  1.5170e+00,
-        -9.0445e-01,  4.7837e-02, -8.2241e-01, -2.9930e-01, -9.1316e-01,
-        -3.1202e-01,  4.8882e-01,  0.0000e+00,  3.3642e-01,  4.8637e-01,
-         4.0246e-01, -4.2253e-02,  2.0668e-02,  6.9341e-01,  5.9713e-01,
-         6.9088e-01,  3.7560e-01,  4.3223e-01,  6.7582e-01, -5.5512e-01,
-         5.2483e-01, -1.2274e-01,  4.8765e-01,  3.4926e-01,  7.3391e-01,
-         5.0577e-01,  4.3959e-02,  1.1035e+00, -1.7291e-01, -6.0415e-01,
-         1.0449e+00, -4.4390e-02, -2.4873e-01, -2.1702e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.9953e-01,  7.0912e-01, -4.3275e-01, -1.4428e-03,  2.9315e-01,
-         1.0506e-02, -5.8402e-01, -1.1115e+00, -6.0515e-01, -6.1241e-01,
-        -2.2131e-01, -3.2038e-01, -1.4874e-02,  1.9834e-01, -8.6796e-02,
-        -5.8208e-01, -7.8766e-01,  6.9730e-01, -8.6626e-01, -6.5007e-01,
-        -3.2225e-01, -1.5535e-01, -5.8389e-01, -3.1289e-01, -1.3179e-01,
-         2.8526e-01,  6.9646e-02, -2.2773e-01,  8.0369e-02,  1.5170e+00,
-        -9.0445e-01,  4.7837e-02, -8.2241e-01, -2.9930e-01, -9.1316e-01,
-        -3.1202e-01,  4.8882e-01,  0.0000e+00,  3.3642e-01,  4.8637e-01,
-         4.0246e-01, -4.2253e-02,  2.0668e-02,  6.9341e-01,  5.9713e-01,
-         6.9088e-01,  3.7560e-01,  4.3223e-01,  6.7582e-01, -5.5512e-01,
-         5.2483e-01, -1.2274e-01,  4.8765e-01,  3.4926e-01,  7.3391e-01,
-         5.0577e-01,  4.3959e-02,  1.1035e+00, -1.7291e-01, -6.0415e-01,
-         1.0449e+00, -4.4390e-02, -2.4873e-01, -2.1702e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.9953e-01,  7.0912e-01, -4.3275e-01, -1.4428e-03,  2.9315e-01,
-         1.0506e-02, -5.8402e-01, -1.1115e+00, -6.0515e-01, -6.1241e-01,
-        -2.2131e-01, -3.2038e-01, -1.4874e-02,  1.9834e-01, -8.6796e-02,
-        -5.8208e-01, -7.8766e-01,  6.9730e-01, -8.6626e-01, -6.5007e-01,
-        -3.2225e-01, -1.5535e-01, -5.8389e-01, -3.1289e-01, -1.3179e-01,
-         2.8526e-01,  6.9646e-02, -2.2773e-01,  8.0369e-02,  1.5170e+00,
-        -9.0445e-01,  4.7837e-02, -8.2241e-01, -2.9930e-01, -9.1316e-01,
-        -3.1202e-01,  4.8882e-01,  0.0000e+00,  3.3642e-01,  4.8637e-01,
-         4.0246e-01, -4.2253e-02,  2.0668e-02,  6.9341e-01,  5.9713e-01,
-         6.9088e-01,  3.7560e-01,  4.3223e-01,  6.7582e-01, -5.5512e-01,
-         5.2483e-01, -1.2274e-01,  4.8765e-01,  3.4926e-01,  7.3391e-01,
-         5.0577e-01,  4.3959e-02,  1.1035e+00, -1.7291e-01, -6.0415e-01,
-         1.0449e+00, -4.4390e-02, -2.4873e-01, -2.1702e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8966,  0.7082, -0.4260, -0.0222,  0.2990,  0.0209, -0.5841, -1.1111,
-        -0.6092, -0.6098, -0.2316, -0.3238, -0.0132,  0.1833, -0.0753, -0.5849,
-        -0.7902,  0.7016, -0.8677, -0.6484, -0.3175, -0.1594, -0.5821, -0.3115,
-        -0.1319,  0.2791,  0.0780, -0.2524,  0.0773,  1.5186, -0.9038,  0.0445,
-        -0.8272, -0.3124, -0.9131, -0.3129,  0.4872,  0.0000,  0.3490,  0.4820,
-         0.4056,  0.0119,  0.0243,  0.6942,  0.6015,  0.6988,  0.3768,  0.4351,
-         0.6832, -0.5567,  0.5229, -0.1428,  0.4861,  0.3546,  0.7328,  0.5069,
-         0.0747,  1.1016, -0.1584, -0.5987,  1.0473, -0.0401, -0.2403, -0.2175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8966,  0.7082, -0.4260, -0.0222,  0.2990,  0.0209, -0.5841, -1.1111,
-        -0.6092, -0.6098, -0.2316, -0.3238, -0.0132,  0.1833, -0.0753, -0.5849,
-        -0.7902,  0.7016, -0.8677, -0.6484, -0.3175, -0.1594, -0.5821, -0.3115,
-        -0.1319,  0.2791,  0.0780, -0.2524,  0.0773,  1.5186, -0.9038,  0.0445,
-        -0.8272, -0.3124, -0.9131, -0.3129,  0.4872,  0.0000,  0.3490,  0.4820,
-         0.4056,  0.0119,  0.0243,  0.6942,  0.6015,  0.6988,  0.3768,  0.4351,
-         0.6832, -0.5567,  0.5229, -0.1428,  0.4861,  0.3546,  0.7328,  0.5069,
-         0.0747,  1.1016, -0.1584, -0.5987,  1.0473, -0.0401, -0.2403, -0.2175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8966,  0.7082, -0.4260, -0.0222,  0.2990,  0.0209, -0.5841, -1.1111,
-        -0.6092, -0.6098, -0.2316, -0.3238, -0.0132,  0.1833, -0.0753, -0.5849,
-        -0.7902,  0.7016, -0.8677, -0.6484, -0.3175, -0.1594, -0.5821, -0.3115,
-        -0.1319,  0.2791,  0.0780, -0.2524,  0.0773,  1.5186, -0.9038,  0.0445,
-        -0.8272, -0.3124, -0.9131, -0.3129,  0.4872,  0.0000,  0.3490,  0.4820,
-         0.4056,  0.0119,  0.0243,  0.6942,  0.6015,  0.6988,  0.3768,  0.4351,
-         0.6832, -0.5567,  0.5229, -0.1428,  0.4861,  0.3546,  0.7328,  0.5069,
-         0.0747,  1.1016, -0.1584, -0.5987,  1.0473, -0.0401, -0.2403, -0.2175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8921,  0.7066, -0.4191, -0.0540,  0.3021,  0.0305, -0.5838, -1.1109,
-        -0.6133, -0.6066, -0.2385, -0.3286, -0.0122,  0.1706, -0.0707, -0.5865,
-        -0.7919,  0.7057, -0.8701, -0.6469, -0.3092, -0.1595, -0.5791, -0.3110,
-        -0.1364,  0.2734,  0.0856, -0.2766,  0.0720,  1.5202, -0.9028,  0.0405,
-        -0.8321, -0.3168, -0.9132, -0.3138,  0.4866,  0.0000,  0.3643,  0.4784,
-         0.4109,  0.0609,  0.0203,  0.6947,  0.6057,  0.7060,  0.3785,  0.4384,
-         0.6923, -0.5580,  0.5187, -0.1607,  0.4832,  0.3625,  0.7316,  0.5056,
-         0.1058,  1.0999, -0.1462, -0.5906,  1.0499, -0.0356, -0.2278, -0.2146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8921,  0.7066, -0.4191, -0.0540,  0.3021,  0.0305, -0.5838, -1.1109,
-        -0.6133, -0.6066, -0.2385, -0.3286, -0.0122,  0.1706, -0.0707, -0.5865,
-        -0.7919,  0.7057, -0.8701, -0.6469, -0.3092, -0.1595, -0.5791, -0.3110,
-        -0.1364,  0.2734,  0.0856, -0.2766,  0.0720,  1.5202, -0.9028,  0.0405,
-        -0.8321, -0.3168, -0.9132, -0.3138,  0.4866,  0.0000,  0.3643,  0.4784,
-         0.4109,  0.0609,  0.0203,  0.6947,  0.6057,  0.7060,  0.3785,  0.4384,
-         0.6923, -0.5580,  0.5187, -0.1607,  0.4832,  0.3625,  0.7316,  0.5056,
-         0.1058,  1.0999, -0.1462, -0.5906,  1.0499, -0.0356, -0.2278, -0.2146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8921,  0.7066, -0.4191, -0.0540,  0.3021,  0.0305, -0.5838, -1.1109,
-        -0.6133, -0.6066, -0.2385, -0.3286, -0.0122,  0.1706, -0.0707, -0.5865,
-        -0.7919,  0.7057, -0.8701, -0.6469, -0.3092, -0.1595, -0.5791, -0.3110,
-        -0.1364,  0.2734,  0.0856, -0.2766,  0.0720,  1.5202, -0.9028,  0.0405,
-        -0.8321, -0.3168, -0.9132, -0.3138,  0.4866,  0.0000,  0.3643,  0.4784,
-         0.4109,  0.0609,  0.0203,  0.6947,  0.6057,  0.7060,  0.3785,  0.4384,
-         0.6923, -0.5580,  0.5187, -0.1607,  0.4832,  0.3625,  0.7316,  0.5056,
-         0.1058,  1.0999, -0.1462, -0.5906,  1.0499, -0.0356, -0.2278, -0.2146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.8733e-01,  7.0062e-01, -4.1233e-01, -7.7788e-02,  3.1067e-01,
-         4.2623e-02, -5.8435e-01, -1.1099e+00, -6.1735e-01, -6.0454e-01,
-        -2.4445e-01, -3.3174e-01,  1.1368e-03,  1.5474e-01, -6.9680e-02,
-        -5.9007e-01, -7.9353e-01,  7.1062e-01, -8.7218e-01, -6.4871e-01,
-        -2.9790e-01, -1.5554e-01, -5.7366e-01, -3.1225e-01, -1.3725e-01,
-         2.6970e-01,  8.2326e-02, -2.9924e-01,  5.9300e-02,  1.5219e+00,
-        -9.0248e-01,  1.8600e-02, -8.3580e-01, -3.2375e-01, -9.1369e-01,
-        -3.1447e-01,  4.8534e-01,  0.0000e+00,  3.8451e-01,  4.7385e-01,
-         4.1585e-01,  9.6695e-02,  1.9922e-02,  6.9430e-01,  6.0903e-01,
-         7.1289e-01,  3.7740e-01,  4.4350e-01,  7.0198e-01, -5.5883e-01,
-         5.1737e-01, -1.7847e-01,  4.7834e-01,  3.6765e-01,  7.2901e-01,
-         5.0385e-01,  1.3124e-01,  1.0983e+00, -1.3111e-01, -5.8349e-01,
-         1.0523e+00, -2.6574e-02, -2.0960e-01, -2.1107e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.8733e-01,  7.0062e-01, -4.1233e-01, -7.7788e-02,  3.1067e-01,
-         4.2623e-02, -5.8435e-01, -1.1099e+00, -6.1735e-01, -6.0454e-01,
-        -2.4445e-01, -3.3174e-01,  1.1368e-03,  1.5474e-01, -6.9680e-02,
-        -5.9007e-01, -7.9353e-01,  7.1062e-01, -8.7218e-01, -6.4871e-01,
-        -2.9790e-01, -1.5554e-01, -5.7366e-01, -3.1225e-01, -1.3725e-01,
-         2.6970e-01,  8.2326e-02, -2.9924e-01,  5.9300e-02,  1.5219e+00,
-        -9.0248e-01,  1.8600e-02, -8.3580e-01, -3.2375e-01, -9.1369e-01,
-        -3.1447e-01,  4.8534e-01,  0.0000e+00,  3.8451e-01,  4.7385e-01,
-         4.1585e-01,  9.6695e-02,  1.9922e-02,  6.9430e-01,  6.0903e-01,
-         7.1289e-01,  3.7740e-01,  4.4350e-01,  7.0198e-01, -5.5883e-01,
-         5.1737e-01, -1.7847e-01,  4.7834e-01,  3.6765e-01,  7.2901e-01,
-         5.0385e-01,  1.3124e-01,  1.0983e+00, -1.3111e-01, -5.8349e-01,
-         1.0523e+00, -2.6574e-02, -2.0960e-01, -2.1107e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.8733e-01,  7.0062e-01, -4.1233e-01, -7.7788e-02,  3.1067e-01,
-         4.2623e-02, -5.8435e-01, -1.1099e+00, -6.1735e-01, -6.0454e-01,
-        -2.4445e-01, -3.3174e-01,  1.1368e-03,  1.5474e-01, -6.9680e-02,
-        -5.9007e-01, -7.9353e-01,  7.1062e-01, -8.7218e-01, -6.4871e-01,
-        -2.9790e-01, -1.5554e-01, -5.7366e-01, -3.1225e-01, -1.3725e-01,
-         2.6970e-01,  8.2326e-02, -2.9924e-01,  5.9300e-02,  1.5219e+00,
-        -9.0248e-01,  1.8600e-02, -8.3580e-01, -3.2375e-01, -9.1369e-01,
-        -3.1447e-01,  4.8534e-01,  0.0000e+00,  3.8451e-01,  4.7385e-01,
-         4.1585e-01,  9.6695e-02,  1.9922e-02,  6.9430e-01,  6.0903e-01,
-         7.1289e-01,  3.7740e-01,  4.4350e-01,  7.0198e-01, -5.5883e-01,
-         5.1737e-01, -1.7847e-01,  4.7834e-01,  3.6765e-01,  7.2901e-01,
-         5.0385e-01,  1.3124e-01,  1.0983e+00, -1.3111e-01, -5.8349e-01,
-         1.0523e+00, -2.6574e-02, -2.0960e-01, -2.1107e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8829,  0.6941, -0.4042, -0.0971,  0.3114,  0.0577, -0.5844, -1.1088,
-        -0.6198, -0.5994, -0.2424, -0.3353,  0.0118,  0.1344, -0.0753, -0.5916,
-        -0.7958,  0.7153, -0.8761, -0.6528, -0.2810, -0.1473, -0.5676, -0.3071,
-        -0.1384,  0.2696,  0.0605, -0.3190,  0.0451,  1.5236, -0.9026, -0.0098,
-        -0.8394, -0.3253, -0.9138, -0.3149,  0.4830,  0.0000,  0.3991,  0.4678,
-         0.4197,  0.1129,  0.0153,  0.6925,  0.6120,  0.7189,  0.3778,  0.4490,
-         0.7110, -0.5598,  0.5173, -0.1967,  0.4752,  0.3731,  0.7269,  0.5036,
-         0.1593,  1.0955, -0.1055, -0.5749,  1.0551, -0.0143, -0.1960, -0.2098],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8829,  0.6941, -0.4042, -0.0971,  0.3114,  0.0577, -0.5844, -1.1088,
-        -0.6198, -0.5994, -0.2424, -0.3353,  0.0118,  0.1344, -0.0753, -0.5916,
-        -0.7958,  0.7153, -0.8761, -0.6528, -0.2810, -0.1473, -0.5676, -0.3071,
-        -0.1384,  0.2696,  0.0605, -0.3190,  0.0451,  1.5236, -0.9026, -0.0098,
-        -0.8394, -0.3253, -0.9138, -0.3149,  0.4830,  0.0000,  0.3991,  0.4678,
-         0.4197,  0.1129,  0.0153,  0.6925,  0.6120,  0.7189,  0.3778,  0.4490,
-         0.7110, -0.5598,  0.5173, -0.1967,  0.4752,  0.3731,  0.7269,  0.5036,
-         0.1593,  1.0955, -0.1055, -0.5749,  1.0551, -0.0143, -0.1960, -0.2098],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8829,  0.6941, -0.4042, -0.0971,  0.3114,  0.0577, -0.5844, -1.1088,
-        -0.6198, -0.5994, -0.2424, -0.3353,  0.0118,  0.1344, -0.0753, -0.5916,
-        -0.7958,  0.7153, -0.8761, -0.6528, -0.2810, -0.1473, -0.5676, -0.3071,
-        -0.1384,  0.2696,  0.0605, -0.3190,  0.0451,  1.5236, -0.9026, -0.0098,
-        -0.8394, -0.3253, -0.9138, -0.3149,  0.4830,  0.0000,  0.3991,  0.4678,
-         0.4197,  0.1129,  0.0153,  0.6925,  0.6120,  0.7189,  0.3778,  0.4490,
-         0.7110, -0.5598,  0.5173, -0.1967,  0.4752,  0.3731,  0.7269,  0.5036,
-         0.1593,  1.0955, -0.1055, -0.5749,  1.0551, -0.0143, -0.1960, -0.2098],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8779,  0.6892, -0.3961, -0.1161,  0.3117,  0.0732, -0.5847, -1.1081,
-        -0.6209, -0.5901, -0.2408, -0.3417,  0.0198,  0.1118, -0.1008, -0.5939,
-        -0.7971,  0.7210, -0.8806, -0.6565, -0.2615, -0.1348, -0.5635, -0.3020,
-        -0.1348,  0.2640,  0.0343, -0.3340,  0.0354,  1.5255, -0.9018, -0.0435,
-        -0.8422, -0.3188, -0.9144, -0.3172,  0.4829,  0.0000,  0.4122,  0.4616,
-         0.4241,  0.1365,  0.0117,  0.6883,  0.6151,  0.7234,  0.3807,  0.4526,
-         0.7203, -0.5618,  0.5181, -0.2124,  0.4740,  0.3762,  0.7257,  0.5043,
-         0.1846,  1.0937, -0.0772, -0.5641,  1.0571, -0.0026, -0.1845, -0.2138],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8779,  0.6892, -0.3961, -0.1161,  0.3117,  0.0732, -0.5847, -1.1081,
-        -0.6209, -0.5901, -0.2408, -0.3417,  0.0198,  0.1118, -0.1008, -0.5939,
-        -0.7971,  0.7210, -0.8806, -0.6565, -0.2615, -0.1348, -0.5635, -0.3020,
-        -0.1348,  0.2640,  0.0343, -0.3340,  0.0354,  1.5255, -0.9018, -0.0435,
-        -0.8422, -0.3188, -0.9144, -0.3172,  0.4829,  0.0000,  0.4122,  0.4616,
-         0.4241,  0.1365,  0.0117,  0.6883,  0.6151,  0.7234,  0.3807,  0.4526,
-         0.7203, -0.5618,  0.5181, -0.2124,  0.4740,  0.3762,  0.7257,  0.5043,
-         0.1846,  1.0937, -0.0772, -0.5641,  1.0571, -0.0026, -0.1845, -0.2138],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8779,  0.6892, -0.3961, -0.1161,  0.3117,  0.0732, -0.5847, -1.1081,
-        -0.6209, -0.5901, -0.2408, -0.3417,  0.0198,  0.1118, -0.1008, -0.5939,
-        -0.7971,  0.7210, -0.8806, -0.6565, -0.2615, -0.1348, -0.5635, -0.3020,
-        -0.1348,  0.2640,  0.0343, -0.3340,  0.0354,  1.5255, -0.9018, -0.0435,
-        -0.8422, -0.3188, -0.9144, -0.3172,  0.4829,  0.0000,  0.4122,  0.4616,
-         0.4241,  0.1365,  0.0117,  0.6883,  0.6151,  0.7234,  0.3807,  0.4526,
-         0.7203, -0.5618,  0.5181, -0.2124,  0.4740,  0.3762,  0.7257,  0.5043,
-         0.1846,  1.0937, -0.0772, -0.5641,  1.0571, -0.0026, -0.1845, -0.2138],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8739,  0.6856, -0.3867, -0.1387,  0.3066,  0.0839, -0.5852, -1.1074,
-        -0.6220, -0.5830, -0.2374, -0.3520,  0.0312,  0.0856, -0.1342, -0.5966,
-        -0.7993,  0.7242, -0.8850, -0.6611, -0.2406, -0.1247, -0.5604, -0.2984,
-        -0.1307,  0.2561,  0.0038, -0.3484,  0.0280,  1.5269, -0.8999, -0.0892,
-        -0.8443, -0.3069, -0.9151, -0.3222,  0.4856,  0.0000,  0.4274,  0.4577,
-         0.4249,  0.1660,  0.0061,  0.6843,  0.6194,  0.7249,  0.3858,  0.4515,
-         0.7284, -0.5639,  0.5183, -0.2260,  0.4725,  0.3767,  0.7249,  0.5065,
-         0.2092,  1.0912, -0.0403, -0.5509,  1.0583,  0.0183, -0.1769, -0.2185],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8739,  0.6856, -0.3867, -0.1387,  0.3066,  0.0839, -0.5852, -1.1074,
-        -0.6220, -0.5830, -0.2374, -0.3520,  0.0312,  0.0856, -0.1342, -0.5966,
-        -0.7993,  0.7242, -0.8850, -0.6611, -0.2406, -0.1247, -0.5604, -0.2984,
-        -0.1307,  0.2561,  0.0038, -0.3484,  0.0280,  1.5269, -0.8999, -0.0892,
-        -0.8443, -0.3069, -0.9151, -0.3222,  0.4856,  0.0000,  0.4274,  0.4577,
-         0.4249,  0.1660,  0.0061,  0.6843,  0.6194,  0.7249,  0.3858,  0.4515,
-         0.7284, -0.5639,  0.5183, -0.2260,  0.4725,  0.3767,  0.7249,  0.5065,
-         0.2092,  1.0912, -0.0403, -0.5509,  1.0583,  0.0183, -0.1769, -0.2185],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8739,  0.6856, -0.3867, -0.1387,  0.3066,  0.0839, -0.5852, -1.1074,
-        -0.6220, -0.5830, -0.2374, -0.3520,  0.0312,  0.0856, -0.1342, -0.5966,
-        -0.7993,  0.7242, -0.8850, -0.6611, -0.2406, -0.1247, -0.5604, -0.2984,
-        -0.1307,  0.2561,  0.0038, -0.3484,  0.0280,  1.5269, -0.8999, -0.0892,
-        -0.8443, -0.3069, -0.9151, -0.3222,  0.4856,  0.0000,  0.4274,  0.4577,
-         0.4249,  0.1660,  0.0061,  0.6843,  0.6194,  0.7249,  0.3858,  0.4515,
-         0.7284, -0.5639,  0.5183, -0.2260,  0.4725,  0.3767,  0.7249,  0.5065,
-         0.2092,  1.0912, -0.0403, -0.5509,  1.0583,  0.0183, -0.1769, -0.2185],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8680,  0.6834, -0.3778, -0.1633,  0.2926,  0.0991, -0.5862, -1.1067,
-        -0.6225, -0.5777, -0.2308, -0.3628,  0.0319,  0.0578, -0.1767, -0.6005,
-        -0.8013,  0.7279, -0.8890, -0.6664, -0.2191, -0.1067, -0.5558, -0.2917,
-        -0.1206,  0.2519, -0.0161, -0.3618,  0.0190,  1.5280, -0.8977, -0.1352,
-        -0.8464, -0.2889, -0.9153, -0.3272,  0.4932,  0.0000,  0.4445,  0.4547,
-         0.4261,  0.1892, -0.0083,  0.6790,  0.6249,  0.7243,  0.3969,  0.4589,
-         0.7368, -0.5658,  0.5168, -0.2377,  0.4727,  0.3698,  0.7230,  0.5116,
-         0.2350,  1.0883, -0.0025, -0.5301,  1.0601,  0.0364, -0.1706, -0.2170],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8680,  0.6834, -0.3778, -0.1633,  0.2926,  0.0991, -0.5862, -1.1067,
-        -0.6225, -0.5777, -0.2308, -0.3628,  0.0319,  0.0578, -0.1767, -0.6005,
-        -0.8013,  0.7279, -0.8890, -0.6664, -0.2191, -0.1067, -0.5558, -0.2917,
-        -0.1206,  0.2519, -0.0161, -0.3618,  0.0190,  1.5280, -0.8977, -0.1352,
-        -0.8464, -0.2889, -0.9153, -0.3272,  0.4932,  0.0000,  0.4445,  0.4547,
-         0.4261,  0.1892, -0.0083,  0.6790,  0.6249,  0.7243,  0.3969,  0.4589,
-         0.7368, -0.5658,  0.5168, -0.2377,  0.4727,  0.3698,  0.7230,  0.5116,
-         0.2350,  1.0883, -0.0025, -0.5301,  1.0601,  0.0364, -0.1706, -0.2170],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8680,  0.6834, -0.3778, -0.1633,  0.2926,  0.0991, -0.5862, -1.1067,
-        -0.6225, -0.5777, -0.2308, -0.3628,  0.0319,  0.0578, -0.1767, -0.6005,
-        -0.8013,  0.7279, -0.8890, -0.6664, -0.2191, -0.1067, -0.5558, -0.2917,
-        -0.1206,  0.2519, -0.0161, -0.3618,  0.0190,  1.5280, -0.8977, -0.1352,
-        -0.8464, -0.2889, -0.9153, -0.3272,  0.4932,  0.0000,  0.4445,  0.4547,
-         0.4261,  0.1892, -0.0083,  0.6790,  0.6249,  0.7243,  0.3969,  0.4589,
-         0.7368, -0.5658,  0.5168, -0.2377,  0.4727,  0.3698,  0.7230,  0.5116,
-         0.2350,  1.0883, -0.0025, -0.5301,  1.0601,  0.0364, -0.1706, -0.2170],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8616,  0.6784, -0.3682, -0.1885,  0.2735,  0.1160, -0.5898, -1.1059,
-        -0.6250, -0.5733, -0.2214, -0.3744,  0.0364,  0.0171, -0.2265, -0.6058,
-        -0.8023,  0.7294, -0.8921, -0.6732, -0.1979, -0.0874, -0.5504, -0.2777,
-        -0.1009,  0.2484, -0.0350, -0.3753, -0.0033,  1.5290, -0.8963, -0.1809,
-        -0.8487, -0.2754, -0.9149, -0.3340,  0.5018,  0.0000,  0.4606,  0.4533,
-         0.4260,  0.2003, -0.0252,  0.6753,  0.6305,  0.7240,  0.4099,  0.4680,
-         0.7452, -0.5672,  0.5151, -0.2510,  0.4737,  0.3600,  0.7202,  0.5161,
-         0.2677,  1.0844,  0.0226, -0.5053,  1.0608,  0.0481, -0.1564, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8616,  0.6784, -0.3682, -0.1885,  0.2735,  0.1160, -0.5898, -1.1059,
-        -0.6250, -0.5733, -0.2214, -0.3744,  0.0364,  0.0171, -0.2265, -0.6058,
-        -0.8023,  0.7294, -0.8921, -0.6732, -0.1979, -0.0874, -0.5504, -0.2777,
-        -0.1009,  0.2484, -0.0350, -0.3753, -0.0033,  1.5290, -0.8963, -0.1809,
-        -0.8487, -0.2754, -0.9149, -0.3340,  0.5018,  0.0000,  0.4606,  0.4533,
-         0.4260,  0.2003, -0.0252,  0.6753,  0.6305,  0.7240,  0.4099,  0.4680,
-         0.7452, -0.5672,  0.5151, -0.2510,  0.4737,  0.3600,  0.7202,  0.5161,
-         0.2677,  1.0844,  0.0226, -0.5053,  1.0608,  0.0481, -0.1564, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8616,  0.6784, -0.3682, -0.1885,  0.2735,  0.1160, -0.5898, -1.1059,
-        -0.6250, -0.5733, -0.2214, -0.3744,  0.0364,  0.0171, -0.2265, -0.6058,
-        -0.8023,  0.7294, -0.8921, -0.6732, -0.1979, -0.0874, -0.5504, -0.2777,
-        -0.1009,  0.2484, -0.0350, -0.3753, -0.0033,  1.5290, -0.8963, -0.1809,
-        -0.8487, -0.2754, -0.9149, -0.3340,  0.5018,  0.0000,  0.4606,  0.4533,
-         0.4260,  0.2003, -0.0252,  0.6753,  0.6305,  0.7240,  0.4099,  0.4680,
-         0.7452, -0.5672,  0.5151, -0.2510,  0.4737,  0.3600,  0.7202,  0.5161,
-         0.2677,  1.0844,  0.0226, -0.5053,  1.0608,  0.0481, -0.1564, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8549,  0.6751, -0.3571, -0.2131,  0.2501,  0.1341, -0.5920, -1.1055,
-        -0.6260, -0.5680, -0.2080, -0.3959,  0.0295, -0.0249, -0.2799, -0.6121,
-        -0.8053,  0.7285, -0.8933, -0.6800, -0.1791, -0.0660, -0.5444, -0.2574,
-        -0.0784,  0.2468, -0.0395, -0.3894, -0.0221,  1.5304, -0.8946, -0.2240,
-        -0.8512, -0.2574, -0.9139, -0.3364,  0.5116,  0.0000,  0.4743,  0.4500,
-         0.4231,  0.2076, -0.0407,  0.6681,  0.6368,  0.7255,  0.4250,  0.4769,
-         0.7513, -0.5687,  0.5142, -0.2660,  0.4720,  0.3462,  0.7168,  0.5217,
-         0.3055,  1.0805,  0.0521, -0.4805,  1.0611,  0.0531, -0.1400, -0.2094],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8549,  0.6751, -0.3571, -0.2131,  0.2501,  0.1341, -0.5920, -1.1055,
-        -0.6260, -0.5680, -0.2080, -0.3959,  0.0295, -0.0249, -0.2799, -0.6121,
-        -0.8053,  0.7285, -0.8933, -0.6800, -0.1791, -0.0660, -0.5444, -0.2574,
-        -0.0784,  0.2468, -0.0395, -0.3894, -0.0221,  1.5304, -0.8946, -0.2240,
-        -0.8512, -0.2574, -0.9139, -0.3364,  0.5116,  0.0000,  0.4743,  0.4500,
-         0.4231,  0.2076, -0.0407,  0.6681,  0.6368,  0.7255,  0.4250,  0.4769,
-         0.7513, -0.5687,  0.5142, -0.2660,  0.4720,  0.3462,  0.7168,  0.5217,
-         0.3055,  1.0805,  0.0521, -0.4805,  1.0611,  0.0531, -0.1400, -0.2094],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8549,  0.6751, -0.3571, -0.2131,  0.2501,  0.1341, -0.5920, -1.1055,
-        -0.6260, -0.5680, -0.2080, -0.3959,  0.0295, -0.0249, -0.2799, -0.6121,
-        -0.8053,  0.7285, -0.8933, -0.6800, -0.1791, -0.0660, -0.5444, -0.2574,
-        -0.0784,  0.2468, -0.0395, -0.3894, -0.0221,  1.5304, -0.8946, -0.2240,
-        -0.8512, -0.2574, -0.9139, -0.3364,  0.5116,  0.0000,  0.4743,  0.4500,
-         0.4231,  0.2076, -0.0407,  0.6681,  0.6368,  0.7255,  0.4250,  0.4769,
-         0.7513, -0.5687,  0.5142, -0.2660,  0.4720,  0.3462,  0.7168,  0.5217,
-         0.3055,  1.0805,  0.0521, -0.4805,  1.0611,  0.0531, -0.1400, -0.2094],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8488,  0.6713, -0.3420, -0.2311,  0.2361,  0.1492, -0.5959, -1.1047,
-        -0.6284, -0.5644, -0.2015, -0.4121,  0.0281, -0.0554, -0.3284, -0.6185,
-        -0.8091,  0.7275, -0.8953, -0.6835, -0.1505, -0.0481, -0.5416, -0.2526,
-        -0.0563,  0.2499, -0.0381, -0.4028, -0.0447,  1.5320, -0.8926, -0.2716,
-        -0.8546, -0.2521, -0.9124, -0.3346,  0.5216,  0.0000,  0.4868,  0.4423,
-         0.4171,  0.2135, -0.0556,  0.6617,  0.6424,  0.7270,  0.4396,  0.4858,
-         0.7557, -0.5690,  0.5154, -0.2789,  0.4696,  0.3254,  0.7104,  0.5248,
-         0.3405,  1.0751,  0.0709, -0.4601,  1.0613,  0.0571, -0.1228, -0.1892],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8488,  0.6713, -0.3420, -0.2311,  0.2361,  0.1492, -0.5959, -1.1047,
-        -0.6284, -0.5644, -0.2015, -0.4121,  0.0281, -0.0554, -0.3284, -0.6185,
-        -0.8091,  0.7275, -0.8953, -0.6835, -0.1505, -0.0481, -0.5416, -0.2526,
-        -0.0563,  0.2499, -0.0381, -0.4028, -0.0447,  1.5320, -0.8926, -0.2716,
-        -0.8546, -0.2521, -0.9124, -0.3346,  0.5216,  0.0000,  0.4868,  0.4423,
-         0.4171,  0.2135, -0.0556,  0.6617,  0.6424,  0.7270,  0.4396,  0.4858,
-         0.7557, -0.5690,  0.5154, -0.2789,  0.4696,  0.3254,  0.7104,  0.5248,
-         0.3405,  1.0751,  0.0709, -0.4601,  1.0613,  0.0571, -0.1228, -0.1892],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8488,  0.6713, -0.3420, -0.2311,  0.2361,  0.1492, -0.5959, -1.1047,
-        -0.6284, -0.5644, -0.2015, -0.4121,  0.0281, -0.0554, -0.3284, -0.6185,
-        -0.8091,  0.7275, -0.8953, -0.6835, -0.1505, -0.0481, -0.5416, -0.2526,
-        -0.0563,  0.2499, -0.0381, -0.4028, -0.0447,  1.5320, -0.8926, -0.2716,
-        -0.8546, -0.2521, -0.9124, -0.3346,  0.5216,  0.0000,  0.4868,  0.4423,
-         0.4171,  0.2135, -0.0556,  0.6617,  0.6424,  0.7270,  0.4396,  0.4858,
-         0.7557, -0.5690,  0.5154, -0.2789,  0.4696,  0.3254,  0.7104,  0.5248,
-         0.3405,  1.0751,  0.0709, -0.4601,  1.0613,  0.0571, -0.1228, -0.1892],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8435,  0.6673, -0.3222, -0.2449,  0.2230,  0.1679, -0.5980, -1.1045,
-        -0.6295, -0.5626, -0.1915, -0.4256,  0.0411, -0.0720, -0.3773, -0.6272,
-        -0.8136,  0.7240, -0.8970, -0.6882, -0.1252, -0.0254, -0.5410, -0.2429,
-        -0.0341,  0.2560, -0.0352, -0.4209, -0.0699,  1.5331, -0.8895, -0.3166,
-        -0.8590, -0.2454, -0.9102, -0.3352,  0.5339,  0.0000,  0.4966,  0.4346,
-         0.4086,  0.2417, -0.0645,  0.6585,  0.6485,  0.7280,  0.4546,  0.4933,
-         0.7576, -0.5691,  0.5169, -0.2875,  0.4734,  0.3007,  0.7026,  0.5277,
-         0.3785,  1.0689,  0.0888, -0.4440,  1.0613,  0.0757, -0.1090, -0.1641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8435,  0.6673, -0.3222, -0.2449,  0.2230,  0.1679, -0.5980, -1.1045,
-        -0.6295, -0.5626, -0.1915, -0.4256,  0.0411, -0.0720, -0.3773, -0.6272,
-        -0.8136,  0.7240, -0.8970, -0.6882, -0.1252, -0.0254, -0.5410, -0.2429,
-        -0.0341,  0.2560, -0.0352, -0.4209, -0.0699,  1.5331, -0.8895, -0.3166,
-        -0.8590, -0.2454, -0.9102, -0.3352,  0.5339,  0.0000,  0.4966,  0.4346,
-         0.4086,  0.2417, -0.0645,  0.6585,  0.6485,  0.7280,  0.4546,  0.4933,
-         0.7576, -0.5691,  0.5169, -0.2875,  0.4734,  0.3007,  0.7026,  0.5277,
-         0.3785,  1.0689,  0.0888, -0.4440,  1.0613,  0.0757, -0.1090, -0.1641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8435,  0.6673, -0.3222, -0.2449,  0.2230,  0.1679, -0.5980, -1.1045,
-        -0.6295, -0.5626, -0.1915, -0.4256,  0.0411, -0.0720, -0.3773, -0.6272,
-        -0.8136,  0.7240, -0.8970, -0.6882, -0.1252, -0.0254, -0.5410, -0.2429,
-        -0.0341,  0.2560, -0.0352, -0.4209, -0.0699,  1.5331, -0.8895, -0.3166,
-        -0.8590, -0.2454, -0.9102, -0.3352,  0.5339,  0.0000,  0.4966,  0.4346,
-         0.4086,  0.2417, -0.0645,  0.6585,  0.6485,  0.7280,  0.4546,  0.4933,
-         0.7576, -0.5691,  0.5169, -0.2875,  0.4734,  0.3007,  0.7026,  0.5277,
-         0.3785,  1.0689,  0.0888, -0.4440,  1.0613,  0.0757, -0.1090, -0.1641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8391,  0.6691, -0.3004, -0.2495,  0.2097,  0.1873, -0.6013, -1.1041,
-        -0.6269, -0.5639, -0.1777, -0.4454,  0.0375, -0.0682, -0.4275, -0.6342,
-        -0.8192,  0.7200, -0.8989, -0.6911, -0.1006, -0.0115, -0.5464, -0.2337,
-        -0.0160,  0.2618, -0.0305, -0.4366, -0.0957,  1.5344, -0.8857, -0.3558,
-        -0.8624, -0.2297, -0.9083, -0.3328,  0.5464,  0.0000,  0.5053,  0.4262,
-         0.3974,  0.2753, -0.0859,  0.6565,  0.6551,  0.7265,  0.4721,  0.5000,
-         0.7602, -0.5681,  0.5192, -0.2972,  0.4786,  0.2767,  0.6934,  0.5301,
-         0.4225,  1.0631,  0.1032, -0.4325,  1.0621,  0.1015, -0.1037, -0.1397],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8391,  0.6691, -0.3004, -0.2495,  0.2097,  0.1873, -0.6013, -1.1041,
-        -0.6269, -0.5639, -0.1777, -0.4454,  0.0375, -0.0682, -0.4275, -0.6342,
-        -0.8192,  0.7200, -0.8989, -0.6911, -0.1006, -0.0115, -0.5464, -0.2337,
-        -0.0160,  0.2618, -0.0305, -0.4366, -0.0957,  1.5344, -0.8857, -0.3558,
-        -0.8624, -0.2297, -0.9083, -0.3328,  0.5464,  0.0000,  0.5053,  0.4262,
-         0.3974,  0.2753, -0.0859,  0.6565,  0.6551,  0.7265,  0.4721,  0.5000,
-         0.7602, -0.5681,  0.5192, -0.2972,  0.4786,  0.2767,  0.6934,  0.5301,
-         0.4225,  1.0631,  0.1032, -0.4325,  1.0621,  0.1015, -0.1037, -0.1397],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8391,  0.6691, -0.3004, -0.2495,  0.2097,  0.1873, -0.6013, -1.1041,
-        -0.6269, -0.5639, -0.1777, -0.4454,  0.0375, -0.0682, -0.4275, -0.6342,
-        -0.8192,  0.7200, -0.8989, -0.6911, -0.1006, -0.0115, -0.5464, -0.2337,
-        -0.0160,  0.2618, -0.0305, -0.4366, -0.0957,  1.5344, -0.8857, -0.3558,
-        -0.8624, -0.2297, -0.9083, -0.3328,  0.5464,  0.0000,  0.5053,  0.4262,
-         0.3974,  0.2753, -0.0859,  0.6565,  0.6551,  0.7265,  0.4721,  0.5000,
-         0.7602, -0.5681,  0.5192, -0.2972,  0.4786,  0.2767,  0.6934,  0.5301,
-         0.4225,  1.0631,  0.1032, -0.4325,  1.0621,  0.1015, -0.1037, -0.1397],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8361,  0.6720, -0.2792, -0.2438,  0.2012,  0.2060, -0.6036, -1.1035,
-        -0.6218, -0.5627, -0.1545, -0.4718,  0.0380, -0.0572, -0.4797, -0.6399,
-        -0.8268,  0.7146, -0.9006, -0.6946, -0.0867,  0.0023, -0.5481, -0.2410,
-         0.0043,  0.2607, -0.0221, -0.4514, -0.1151,  1.5356, -0.8815, -0.3965,
-        -0.8659, -0.2050, -0.9076, -0.3322,  0.5594,  0.0000,  0.5113,  0.4191,
-         0.3891,  0.3217, -0.1103,  0.6559,  0.6601,  0.7266,  0.4872,  0.5030,
-         0.7630, -0.5683,  0.5195, -0.3081,  0.4850,  0.2481,  0.6843,  0.5312,
-         0.4574,  1.0570,  0.1117, -0.4299,  1.0638,  0.1134, -0.1064, -0.1164],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8361,  0.6720, -0.2792, -0.2438,  0.2012,  0.2060, -0.6036, -1.1035,
-        -0.6218, -0.5627, -0.1545, -0.4718,  0.0380, -0.0572, -0.4797, -0.6399,
-        -0.8268,  0.7146, -0.9006, -0.6946, -0.0867,  0.0023, -0.5481, -0.2410,
-         0.0043,  0.2607, -0.0221, -0.4514, -0.1151,  1.5356, -0.8815, -0.3965,
-        -0.8659, -0.2050, -0.9076, -0.3322,  0.5594,  0.0000,  0.5113,  0.4191,
-         0.3891,  0.3217, -0.1103,  0.6559,  0.6601,  0.7266,  0.4872,  0.5030,
-         0.7630, -0.5683,  0.5195, -0.3081,  0.4850,  0.2481,  0.6843,  0.5312,
-         0.4574,  1.0570,  0.1117, -0.4299,  1.0638,  0.1134, -0.1064, -0.1164],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8361,  0.6720, -0.2792, -0.2438,  0.2012,  0.2060, -0.6036, -1.1035,
-        -0.6218, -0.5627, -0.1545, -0.4718,  0.0380, -0.0572, -0.4797, -0.6399,
-        -0.8268,  0.7146, -0.9006, -0.6946, -0.0867,  0.0023, -0.5481, -0.2410,
-         0.0043,  0.2607, -0.0221, -0.4514, -0.1151,  1.5356, -0.8815, -0.3965,
-        -0.8659, -0.2050, -0.9076, -0.3322,  0.5594,  0.0000,  0.5113,  0.4191,
-         0.3891,  0.3217, -0.1103,  0.6559,  0.6601,  0.7266,  0.4872,  0.5030,
-         0.7630, -0.5683,  0.5195, -0.3081,  0.4850,  0.2481,  0.6843,  0.5312,
-         0.4574,  1.0570,  0.1117, -0.4299,  1.0638,  0.1134, -0.1064, -0.1164],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8342,  0.6717, -0.2580, -0.2372,  0.2124,  0.2299, -0.6076, -1.1030,
-        -0.6166, -0.5620, -0.1467, -0.4903,  0.0508, -0.0319, -0.5286, -0.6476,
-        -0.8355,  0.7117, -0.9024, -0.6978, -0.0690,  0.0206, -0.5494, -0.2538,
-         0.0121,  0.2630, -0.0340, -0.4613, -0.1419,  1.5365, -0.8791, -0.4363,
-        -0.8703, -0.1913, -0.9066, -0.3341,  0.5699,  0.0000,  0.5182,  0.4144,
-         0.3789,  0.3720, -0.1299,  0.6548,  0.6638,  0.7298,  0.5034,  0.5029,
-         0.7638, -0.5697,  0.5211, -0.3174,  0.4900,  0.2090,  0.6759,  0.5319,
-         0.4914,  1.0517,  0.1078, -0.4367,  1.0664,  0.1249, -0.0992, -0.0911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8342,  0.6717, -0.2580, -0.2372,  0.2124,  0.2299, -0.6076, -1.1030,
-        -0.6166, -0.5620, -0.1467, -0.4903,  0.0508, -0.0319, -0.5286, -0.6476,
-        -0.8355,  0.7117, -0.9024, -0.6978, -0.0690,  0.0206, -0.5494, -0.2538,
-         0.0121,  0.2630, -0.0340, -0.4613, -0.1419,  1.5365, -0.8791, -0.4363,
-        -0.8703, -0.1913, -0.9066, -0.3341,  0.5699,  0.0000,  0.5182,  0.4144,
-         0.3789,  0.3720, -0.1299,  0.6548,  0.6638,  0.7298,  0.5034,  0.5029,
-         0.7638, -0.5697,  0.5211, -0.3174,  0.4900,  0.2090,  0.6759,  0.5319,
-         0.4914,  1.0517,  0.1078, -0.4367,  1.0664,  0.1249, -0.0992, -0.0911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8342,  0.6717, -0.2580, -0.2372,  0.2124,  0.2299, -0.6076, -1.1030,
-        -0.6166, -0.5620, -0.1467, -0.4903,  0.0508, -0.0319, -0.5286, -0.6476,
-        -0.8355,  0.7117, -0.9024, -0.6978, -0.0690,  0.0206, -0.5494, -0.2538,
-         0.0121,  0.2630, -0.0340, -0.4613, -0.1419,  1.5365, -0.8791, -0.4363,
-        -0.8703, -0.1913, -0.9066, -0.3341,  0.5699,  0.0000,  0.5182,  0.4144,
-         0.3789,  0.3720, -0.1299,  0.6548,  0.6638,  0.7298,  0.5034,  0.5029,
-         0.7638, -0.5697,  0.5211, -0.3174,  0.4900,  0.2090,  0.6759,  0.5319,
-         0.4914,  1.0517,  0.1078, -0.4367,  1.0664,  0.1249, -0.0992, -0.0911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8320,  0.6782, -0.2257, -0.2280,  0.2331,  0.2634, -0.6142, -1.1017,
-        -0.6103, -0.5659, -0.1389, -0.5033,  0.0693, -0.0228, -0.5789, -0.6532,
-        -0.8451,  0.7116, -0.9060, -0.6979, -0.0542,  0.0530, -0.5582, -0.2583,
-         0.0170,  0.2651, -0.0416, -0.4759, -0.1735,  1.5370, -0.8771, -0.4708,
-        -0.8720, -0.1885, -0.9062, -0.3400,  0.5778,  0.0000,  0.5220,  0.4084,
-         0.3728,  0.4303, -0.1370,  0.6557,  0.6660,  0.7345,  0.5201,  0.4988,
-         0.7588, -0.5712,  0.5153, -0.3158,  0.5050,  0.1578,  0.6675,  0.5338,
-         0.5305,  1.0481,  0.1046, -0.4353,  1.0706,  0.1359, -0.0899, -0.0652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8320,  0.6782, -0.2257, -0.2280,  0.2331,  0.2634, -0.6142, -1.1017,
-        -0.6103, -0.5659, -0.1389, -0.5033,  0.0693, -0.0228, -0.5789, -0.6532,
-        -0.8451,  0.7116, -0.9060, -0.6979, -0.0542,  0.0530, -0.5582, -0.2583,
-         0.0170,  0.2651, -0.0416, -0.4759, -0.1735,  1.5370, -0.8771, -0.4708,
-        -0.8720, -0.1885, -0.9062, -0.3400,  0.5778,  0.0000,  0.5220,  0.4084,
-         0.3728,  0.4303, -0.1370,  0.6557,  0.6660,  0.7345,  0.5201,  0.4988,
-         0.7588, -0.5712,  0.5153, -0.3158,  0.5050,  0.1578,  0.6675,  0.5338,
-         0.5305,  1.0481,  0.1046, -0.4353,  1.0706,  0.1359, -0.0899, -0.0652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8320,  0.6782, -0.2257, -0.2280,  0.2331,  0.2634, -0.6142, -1.1017,
-        -0.6103, -0.5659, -0.1389, -0.5033,  0.0693, -0.0228, -0.5789, -0.6532,
-        -0.8451,  0.7116, -0.9060, -0.6979, -0.0542,  0.0530, -0.5582, -0.2583,
-         0.0170,  0.2651, -0.0416, -0.4759, -0.1735,  1.5370, -0.8771, -0.4708,
-        -0.8720, -0.1885, -0.9062, -0.3400,  0.5778,  0.0000,  0.5220,  0.4084,
-         0.3728,  0.4303, -0.1370,  0.6557,  0.6660,  0.7345,  0.5201,  0.4988,
-         0.7588, -0.5712,  0.5153, -0.3158,  0.5050,  0.1578,  0.6675,  0.5338,
-         0.5305,  1.0481,  0.1046, -0.4353,  1.0706,  0.1359, -0.0899, -0.0652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8328,  0.6827, -0.2034, -0.2186,  0.2607,  0.3025, -0.6211, -1.1005,
-        -0.6012, -0.5698, -0.1202, -0.5168,  0.0740, -0.0193, -0.6270, -0.6569,
-        -0.8544,  0.7149, -0.9099, -0.6954, -0.0441,  0.0788, -0.5598, -0.2495,
-         0.0209,  0.2696, -0.0472, -0.4878, -0.2115,  1.5370, -0.8757, -0.5042,
-        -0.8744, -0.1809, -0.9056, -0.3410,  0.5850,  0.0000,  0.5241,  0.3953,
-         0.3655,  0.4912, -0.1603,  0.6543,  0.6681,  0.7375,  0.5376,  0.4919,
-         0.7544, -0.5737,  0.5116, -0.3163,  0.5206,  0.1029,  0.6601,  0.5383,
-         0.5704,  1.0450,  0.0972, -0.4387,  1.0750,  0.1367, -0.0733, -0.0433],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8328,  0.6827, -0.2034, -0.2186,  0.2607,  0.3025, -0.6211, -1.1005,
-        -0.6012, -0.5698, -0.1202, -0.5168,  0.0740, -0.0193, -0.6270, -0.6569,
-        -0.8544,  0.7149, -0.9099, -0.6954, -0.0441,  0.0788, -0.5598, -0.2495,
-         0.0209,  0.2696, -0.0472, -0.4878, -0.2115,  1.5370, -0.8757, -0.5042,
-        -0.8744, -0.1809, -0.9056, -0.3410,  0.5850,  0.0000,  0.5241,  0.3953,
-         0.3655,  0.4912, -0.1603,  0.6543,  0.6681,  0.7375,  0.5376,  0.4919,
-         0.7544, -0.5737,  0.5116, -0.3163,  0.5206,  0.1029,  0.6601,  0.5383,
-         0.5704,  1.0450,  0.0972, -0.4387,  1.0750,  0.1367, -0.0733, -0.0433],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8328,  0.6827, -0.2034, -0.2186,  0.2607,  0.3025, -0.6211, -1.1005,
-        -0.6012, -0.5698, -0.1202, -0.5168,  0.0740, -0.0193, -0.6270, -0.6569,
-        -0.8544,  0.7149, -0.9099, -0.6954, -0.0441,  0.0788, -0.5598, -0.2495,
-         0.0209,  0.2696, -0.0472, -0.4878, -0.2115,  1.5370, -0.8757, -0.5042,
-        -0.8744, -0.1809, -0.9056, -0.3410,  0.5850,  0.0000,  0.5241,  0.3953,
-         0.3655,  0.4912, -0.1603,  0.6543,  0.6681,  0.7375,  0.5376,  0.4919,
-         0.7544, -0.5737,  0.5116, -0.3163,  0.5206,  0.1029,  0.6601,  0.5383,
-         0.5704,  1.0450,  0.0972, -0.4387,  1.0750,  0.1367, -0.0733, -0.0433],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8351,  0.6848, -0.1897, -0.2048,  0.2838,  0.3446, -0.6252, -1.0997,
-        -0.5887, -0.5760, -0.0915, -0.5324,  0.0763, -0.0127, -0.6672, -0.6593,
-        -0.8647,  0.7176, -0.9129, -0.6936, -0.0493,  0.0977, -0.5597, -0.2150,
-         0.0413,  0.2816, -0.0432, -0.5008, -0.2447,  1.5366, -0.8754, -0.5314,
-        -0.8746, -0.1532, -0.9049, -0.3450,  0.5902,  0.0000,  0.5304,  0.3827,
-         0.3599,  0.5477, -0.1978,  0.6502,  0.6690,  0.7393,  0.5509,  0.4826,
-         0.7515, -0.5756,  0.5046, -0.3172,  0.5358,  0.0467,  0.6544,  0.5398,
-         0.6094,  1.0428,  0.0791, -0.4364,  1.0782,  0.1284, -0.0567, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8351,  0.6848, -0.1897, -0.2048,  0.2838,  0.3446, -0.6252, -1.0997,
-        -0.5887, -0.5760, -0.0915, -0.5324,  0.0763, -0.0127, -0.6672, -0.6593,
-        -0.8647,  0.7176, -0.9129, -0.6936, -0.0493,  0.0977, -0.5597, -0.2150,
-         0.0413,  0.2816, -0.0432, -0.5008, -0.2447,  1.5366, -0.8754, -0.5314,
-        -0.8746, -0.1532, -0.9049, -0.3450,  0.5902,  0.0000,  0.5304,  0.3827,
-         0.3599,  0.5477, -0.1978,  0.6502,  0.6690,  0.7393,  0.5509,  0.4826,
-         0.7515, -0.5756,  0.5046, -0.3172,  0.5358,  0.0467,  0.6544,  0.5398,
-         0.6094,  1.0428,  0.0791, -0.4364,  1.0782,  0.1284, -0.0567, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8351,  0.6848, -0.1897, -0.2048,  0.2838,  0.3446, -0.6252, -1.0997,
-        -0.5887, -0.5760, -0.0915, -0.5324,  0.0763, -0.0127, -0.6672, -0.6593,
-        -0.8647,  0.7176, -0.9129, -0.6936, -0.0493,  0.0977, -0.5597, -0.2150,
-         0.0413,  0.2816, -0.0432, -0.5008, -0.2447,  1.5366, -0.8754, -0.5314,
-        -0.8746, -0.1532, -0.9049, -0.3450,  0.5902,  0.0000,  0.5304,  0.3827,
-         0.3599,  0.5477, -0.1978,  0.6502,  0.6690,  0.7393,  0.5509,  0.4826,
-         0.7515, -0.5756,  0.5046, -0.3172,  0.5358,  0.0467,  0.6544,  0.5398,
-         0.6094,  1.0428,  0.0791, -0.4364,  1.0782,  0.1284, -0.0567, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8425,  0.6875, -0.1784, -0.1889,  0.3025,  0.3909, -0.6248, -1.0981,
-        -0.5758, -0.5842, -0.0779, -0.5512,  0.0906,  0.0071, -0.7033, -0.6596,
-        -0.8756,  0.7186, -0.9150, -0.6917, -0.0557,  0.0995, -0.5576, -0.1973,
-         0.0740,  0.3121, -0.0289, -0.5090, -0.2735,  1.5361, -0.8731, -0.5574,
-        -0.8751, -0.1274, -0.9027, -0.3456,  0.5930,  0.0000,  0.5420,  0.3823,
-         0.3565,  0.6008, -0.2406,  0.6478,  0.6726,  0.7348,  0.5655,  0.4840,
-         0.7448, -0.5772,  0.4974, -0.3192,  0.5496,  0.0103,  0.6487,  0.5401,
-         0.6463,  1.0404,  0.0565, -0.4403,  1.0826,  0.0867, -0.0219, -0.0071],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8425,  0.6875, -0.1784, -0.1889,  0.3025,  0.3909, -0.6248, -1.0981,
-        -0.5758, -0.5842, -0.0779, -0.5512,  0.0906,  0.0071, -0.7033, -0.6596,
-        -0.8756,  0.7186, -0.9150, -0.6917, -0.0557,  0.0995, -0.5576, -0.1973,
-         0.0740,  0.3121, -0.0289, -0.5090, -0.2735,  1.5361, -0.8731, -0.5574,
-        -0.8751, -0.1274, -0.9027, -0.3456,  0.5930,  0.0000,  0.5420,  0.3823,
-         0.3565,  0.6008, -0.2406,  0.6478,  0.6726,  0.7348,  0.5655,  0.4840,
-         0.7448, -0.5772,  0.4974, -0.3192,  0.5496,  0.0103,  0.6487,  0.5401,
-         0.6463,  1.0404,  0.0565, -0.4403,  1.0826,  0.0867, -0.0219, -0.0071],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8425,  0.6875, -0.1784, -0.1889,  0.3025,  0.3909, -0.6248, -1.0981,
-        -0.5758, -0.5842, -0.0779, -0.5512,  0.0906,  0.0071, -0.7033, -0.6596,
-        -0.8756,  0.7186, -0.9150, -0.6917, -0.0557,  0.0995, -0.5576, -0.1973,
-         0.0740,  0.3121, -0.0289, -0.5090, -0.2735,  1.5361, -0.8731, -0.5574,
-        -0.8751, -0.1274, -0.9027, -0.3456,  0.5930,  0.0000,  0.5420,  0.3823,
-         0.3565,  0.6008, -0.2406,  0.6478,  0.6726,  0.7348,  0.5655,  0.4840,
-         0.7448, -0.5772,  0.4974, -0.3192,  0.5496,  0.0103,  0.6487,  0.5401,
-         0.6463,  1.0404,  0.0565, -0.4403,  1.0826,  0.0867, -0.0219, -0.0071],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8496,  0.6888, -0.1583, -0.1660,  0.3228,  0.4375, -0.6218, -1.0961,
-        -0.5621, -0.5938, -0.0463, -0.5668,  0.1274,  0.0353, -0.7413, -0.6554,
-        -0.8861,  0.7168, -0.9164, -0.6914, -0.0501,  0.0992, -0.5501, -0.1789,
-         0.1146,  0.3535, -0.0037, -0.5176, -0.2938,  1.5358, -0.8712, -0.5825,
-        -0.8759, -0.0872, -0.9012, -0.3440,  0.5928,  0.0000,  0.5583,  0.3760,
-         0.3504,  0.6540, -0.2843,  0.6521,  0.6763,  0.7279,  0.5770,  0.4851,
-         0.7340, -0.5807,  0.4902, -0.3197,  0.5594, -0.0185,  0.6470,  0.5375,
-         0.6816,  1.0392,  0.0452, -0.4440,  1.0880,  0.0386,  0.0155,  0.0200],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8496,  0.6888, -0.1583, -0.1660,  0.3228,  0.4375, -0.6218, -1.0961,
-        -0.5621, -0.5938, -0.0463, -0.5668,  0.1274,  0.0353, -0.7413, -0.6554,
-        -0.8861,  0.7168, -0.9164, -0.6914, -0.0501,  0.0992, -0.5501, -0.1789,
-         0.1146,  0.3535, -0.0037, -0.5176, -0.2938,  1.5358, -0.8712, -0.5825,
-        -0.8759, -0.0872, -0.9012, -0.3440,  0.5928,  0.0000,  0.5583,  0.3760,
-         0.3504,  0.6540, -0.2843,  0.6521,  0.6763,  0.7279,  0.5770,  0.4851,
-         0.7340, -0.5807,  0.4902, -0.3197,  0.5594, -0.0185,  0.6470,  0.5375,
-         0.6816,  1.0392,  0.0452, -0.4440,  1.0880,  0.0386,  0.0155,  0.0200],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8496,  0.6888, -0.1583, -0.1660,  0.3228,  0.4375, -0.6218, -1.0961,
-        -0.5621, -0.5938, -0.0463, -0.5668,  0.1274,  0.0353, -0.7413, -0.6554,
-        -0.8861,  0.7168, -0.9164, -0.6914, -0.0501,  0.0992, -0.5501, -0.1789,
-         0.1146,  0.3535, -0.0037, -0.5176, -0.2938,  1.5358, -0.8712, -0.5825,
-        -0.8759, -0.0872, -0.9012, -0.3440,  0.5928,  0.0000,  0.5583,  0.3760,
-         0.3504,  0.6540, -0.2843,  0.6521,  0.6763,  0.7279,  0.5770,  0.4851,
-         0.7340, -0.5807,  0.4902, -0.3197,  0.5594, -0.0185,  0.6470,  0.5375,
-         0.6816,  1.0392,  0.0452, -0.4440,  1.0880,  0.0386,  0.0155,  0.0200],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8540,  0.6826, -0.1418, -0.1293,  0.3483,  0.4703, -0.6192, -1.0949,
-        -0.5517, -0.5965, -0.0411, -0.5794,  0.1471,  0.0631, -0.7827, -0.6545,
-        -0.8879,  0.7126, -0.9167, -0.6927, -0.0452,  0.1059, -0.5423, -0.1787,
-         0.1443,  0.3862,  0.0142, -0.5258, -0.3166,  1.5357, -0.8689, -0.6053,
-        -0.8779, -0.0517, -0.9016, -0.3421,  0.5954,  0.0000,  0.5813,  0.3591,
-         0.3447,  0.7102, -0.3187,  0.6618,  0.6797,  0.7215,  0.5902,  0.4847,
-         0.7254, -0.5833,  0.4949, -0.3267,  0.5613, -0.0523,  0.6475,  0.5382,
-         0.7158,  1.0384,  0.0236, -0.4734,  1.0927, -0.0100,  0.0488,  0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8540,  0.6826, -0.1418, -0.1293,  0.3483,  0.4703, -0.6192, -1.0949,
-        -0.5517, -0.5965, -0.0411, -0.5794,  0.1471,  0.0631, -0.7827, -0.6545,
-        -0.8879,  0.7126, -0.9167, -0.6927, -0.0452,  0.1059, -0.5423, -0.1787,
-         0.1443,  0.3862,  0.0142, -0.5258, -0.3166,  1.5357, -0.8689, -0.6053,
-        -0.8779, -0.0517, -0.9016, -0.3421,  0.5954,  0.0000,  0.5813,  0.3591,
-         0.3447,  0.7102, -0.3187,  0.6618,  0.6797,  0.7215,  0.5902,  0.4847,
-         0.7254, -0.5833,  0.4949, -0.3267,  0.5613, -0.0523,  0.6475,  0.5382,
-         0.7158,  1.0384,  0.0236, -0.4734,  1.0927, -0.0100,  0.0488,  0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8540,  0.6826, -0.1418, -0.1293,  0.3483,  0.4703, -0.6192, -1.0949,
-        -0.5517, -0.5965, -0.0411, -0.5794,  0.1471,  0.0631, -0.7827, -0.6545,
-        -0.8879,  0.7126, -0.9167, -0.6927, -0.0452,  0.1059, -0.5423, -0.1787,
-         0.1443,  0.3862,  0.0142, -0.5258, -0.3166,  1.5357, -0.8689, -0.6053,
-        -0.8779, -0.0517, -0.9016, -0.3421,  0.5954,  0.0000,  0.5813,  0.3591,
-         0.3447,  0.7102, -0.3187,  0.6618,  0.6797,  0.7215,  0.5902,  0.4847,
-         0.7254, -0.5833,  0.4949, -0.3267,  0.5613, -0.0523,  0.6475,  0.5382,
-         0.7158,  1.0384,  0.0236, -0.4734,  1.0927, -0.0100,  0.0488,  0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8557,  0.6765, -0.1242, -0.0930,  0.3557,  0.4981, -0.6172, -1.0929,
-        -0.5442, -0.5942, -0.0296, -0.5883,  0.1710,  0.0779, -0.8210, -0.6484,
-        -0.8914,  0.7124, -0.9192, -0.6930, -0.0256,  0.1264, -0.5343, -0.1782,
-         0.1817,  0.3988,  0.0091, -0.5358, -0.3322,  1.5361, -0.8679, -0.6253,
-        -0.8805, -0.0115, -0.9016, -0.3420,  0.5969,  0.0000,  0.5990,  0.3315,
-         0.3420,  0.7627, -0.3453,  0.6644,  0.6847,  0.7174,  0.6001,  0.4729,
-         0.7136, -0.5859,  0.4988, -0.3285,  0.5610, -0.0685,  0.6426,  0.5384,
-         0.7475,  1.0379, -0.0088, -0.4867,  1.0961, -0.0454,  0.0844,  0.0482],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8557,  0.6765, -0.1242, -0.0930,  0.3557,  0.4981, -0.6172, -1.0929,
-        -0.5442, -0.5942, -0.0296, -0.5883,  0.1710,  0.0779, -0.8210, -0.6484,
-        -0.8914,  0.7124, -0.9192, -0.6930, -0.0256,  0.1264, -0.5343, -0.1782,
-         0.1817,  0.3988,  0.0091, -0.5358, -0.3322,  1.5361, -0.8679, -0.6253,
-        -0.8805, -0.0115, -0.9016, -0.3420,  0.5969,  0.0000,  0.5990,  0.3315,
-         0.3420,  0.7627, -0.3453,  0.6644,  0.6847,  0.7174,  0.6001,  0.4729,
-         0.7136, -0.5859,  0.4988, -0.3285,  0.5610, -0.0685,  0.6426,  0.5384,
-         0.7475,  1.0379, -0.0088, -0.4867,  1.0961, -0.0454,  0.0844,  0.0482],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8557,  0.6765, -0.1242, -0.0930,  0.3557,  0.4981, -0.6172, -1.0929,
-        -0.5442, -0.5942, -0.0296, -0.5883,  0.1710,  0.0779, -0.8210, -0.6484,
-        -0.8914,  0.7124, -0.9192, -0.6930, -0.0256,  0.1264, -0.5343, -0.1782,
-         0.1817,  0.3988,  0.0091, -0.5358, -0.3322,  1.5361, -0.8679, -0.6253,
-        -0.8805, -0.0115, -0.9016, -0.3420,  0.5969,  0.0000,  0.5990,  0.3315,
-         0.3420,  0.7627, -0.3453,  0.6644,  0.6847,  0.7174,  0.6001,  0.4729,
-         0.7136, -0.5859,  0.4988, -0.3285,  0.5610, -0.0685,  0.6426,  0.5384,
-         0.7475,  1.0379, -0.0088, -0.4867,  1.0961, -0.0454,  0.0844,  0.0482],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.5266e-01,  6.7580e-01, -9.5389e-02, -5.8647e-02,  3.4499e-01,
-         5.2972e-01, -6.1499e-01, -1.0908e+00, -5.3981e-01, -6.0411e-01,
-        -1.2679e-02, -5.8782e-01,  1.8367e-01,  1.0575e-01, -8.5817e-01,
-        -6.4225e-01, -9.0060e-01,  7.1460e-01, -9.2264e-01, -6.9305e-01,
-        -5.1725e-04,  1.5005e-01, -5.3724e-01, -1.5387e-01,  2.2525e-01,
-         4.1215e-01, -2.4788e-02, -5.4276e-01, -3.3418e-01,  1.5373e+00,
-        -8.6478e-01, -6.4354e-01, -8.8071e-01,  4.0049e-02, -9.0071e-01,
-        -3.4485e-01,  5.9956e-01,  0.0000e+00,  6.2284e-01,  3.1167e-01,
-         3.4139e-01,  8.0222e-01, -3.4693e-01,  6.6650e-01,  6.8761e-01,
-         7.1025e-01,  6.0540e-01,  4.5924e-01,  7.0868e-01, -5.8703e-01,
-         4.9695e-01, -3.2038e-01,  5.6034e-01, -8.1232e-02,  6.3611e-01,
-         5.4369e-01,  7.7647e-01,  1.0382e+00, -3.0172e-02, -4.6999e-01,
-         1.1006e+00, -6.4157e-02,  9.5164e-02,  3.3943e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.5266e-01,  6.7580e-01, -9.5389e-02, -5.8647e-02,  3.4499e-01,
-         5.2972e-01, -6.1499e-01, -1.0908e+00, -5.3981e-01, -6.0411e-01,
-        -1.2679e-02, -5.8782e-01,  1.8367e-01,  1.0575e-01, -8.5817e-01,
-        -6.4225e-01, -9.0060e-01,  7.1460e-01, -9.2264e-01, -6.9305e-01,
-        -5.1725e-04,  1.5005e-01, -5.3724e-01, -1.5387e-01,  2.2525e-01,
-         4.1215e-01, -2.4788e-02, -5.4276e-01, -3.3418e-01,  1.5373e+00,
-        -8.6478e-01, -6.4354e-01, -8.8071e-01,  4.0049e-02, -9.0071e-01,
-        -3.4485e-01,  5.9956e-01,  0.0000e+00,  6.2284e-01,  3.1167e-01,
-         3.4139e-01,  8.0222e-01, -3.4693e-01,  6.6650e-01,  6.8761e-01,
-         7.1025e-01,  6.0540e-01,  4.5924e-01,  7.0868e-01, -5.8703e-01,
-         4.9695e-01, -3.2038e-01,  5.6034e-01, -8.1232e-02,  6.3611e-01,
-         5.4369e-01,  7.7647e-01,  1.0382e+00, -3.0172e-02, -4.6999e-01,
-         1.1006e+00, -6.4157e-02,  9.5164e-02,  3.3943e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.5266e-01,  6.7580e-01, -9.5389e-02, -5.8647e-02,  3.4499e-01,
-         5.2972e-01, -6.1499e-01, -1.0908e+00, -5.3981e-01, -6.0411e-01,
-        -1.2679e-02, -5.8782e-01,  1.8367e-01,  1.0575e-01, -8.5817e-01,
-        -6.4225e-01, -9.0060e-01,  7.1460e-01, -9.2264e-01, -6.9305e-01,
-        -5.1725e-04,  1.5005e-01, -5.3724e-01, -1.5387e-01,  2.2525e-01,
-         4.1215e-01, -2.4788e-02, -5.4276e-01, -3.3418e-01,  1.5373e+00,
-        -8.6478e-01, -6.4354e-01, -8.8071e-01,  4.0049e-02, -9.0071e-01,
-        -3.4485e-01,  5.9956e-01,  0.0000e+00,  6.2284e-01,  3.1167e-01,
-         3.4139e-01,  8.0222e-01, -3.4693e-01,  6.6650e-01,  6.8761e-01,
-         7.1025e-01,  6.0540e-01,  4.5924e-01,  7.0868e-01, -5.8703e-01,
-         4.9695e-01, -3.2038e-01,  5.6034e-01, -8.1232e-02,  6.3611e-01,
-         5.4369e-01,  7.7647e-01,  1.0382e+00, -3.0172e-02, -4.6999e-01,
-         1.1006e+00, -6.4157e-02,  9.5164e-02,  3.3943e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8481,  0.6777, -0.0667, -0.0233,  0.3218,  0.5614, -0.6137, -1.0897,
-        -0.5371, -0.6205,  0.0116, -0.5840,  0.1812,  0.1305, -0.8841, -0.6349,
-        -0.9096,  0.7181, -0.9248, -0.6915,  0.0243,  0.1978, -0.5450, -0.1431,
-         0.2955,  0.4249, -0.0832, -0.5469, -0.3278,  1.5390, -0.8622, -0.6591,
-        -0.8821,  0.0986, -0.8981, -0.3505,  0.5981,  0.0000,  0.6303,  0.2840,
-         0.3475,  0.8376, -0.3497,  0.6621,  0.6887,  0.7069,  0.6090,  0.4262,
-         0.7096, -0.5890,  0.4958, -0.3156,  0.5619, -0.1015,  0.6207,  0.5497,
-         0.8033,  1.0353, -0.0656, -0.4447,  1.1025, -0.0874,  0.1014, -0.0045],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8481,  0.6777, -0.0667, -0.0233,  0.3218,  0.5614, -0.6137, -1.0897,
-        -0.5371, -0.6205,  0.0116, -0.5840,  0.1812,  0.1305, -0.8841, -0.6349,
-        -0.9096,  0.7181, -0.9248, -0.6915,  0.0243,  0.1978, -0.5450, -0.1431,
-         0.2955,  0.4249, -0.0832, -0.5469, -0.3278,  1.5390, -0.8622, -0.6591,
-        -0.8821,  0.0986, -0.8981, -0.3505,  0.5981,  0.0000,  0.6303,  0.2840,
-         0.3475,  0.8376, -0.3497,  0.6621,  0.6887,  0.7069,  0.6090,  0.4262,
-         0.7096, -0.5890,  0.4958, -0.3156,  0.5619, -0.1015,  0.6207,  0.5497,
-         0.8033,  1.0353, -0.0656, -0.4447,  1.1025, -0.0874,  0.1014, -0.0045],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8481,  0.6777, -0.0667, -0.0233,  0.3218,  0.5614, -0.6137, -1.0897,
-        -0.5371, -0.6205,  0.0116, -0.5840,  0.1812,  0.1305, -0.8841, -0.6349,
-        -0.9096,  0.7181, -0.9248, -0.6915,  0.0243,  0.1978, -0.5450, -0.1431,
-         0.2955,  0.4249, -0.0832, -0.5469, -0.3278,  1.5390, -0.8622, -0.6591,
-        -0.8821,  0.0986, -0.8981, -0.3505,  0.5981,  0.0000,  0.6303,  0.2840,
-         0.3475,  0.8376, -0.3497,  0.6621,  0.6887,  0.7069,  0.6090,  0.4262,
-         0.7096, -0.5890,  0.4958, -0.3156,  0.5619, -0.1015,  0.6207,  0.5497,
-         0.8033,  1.0353, -0.0656, -0.4447,  1.1025, -0.0874,  0.1014, -0.0045],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8453,  0.6709, -0.0316,  0.0031,  0.3022,  0.5885, -0.6127, -1.0888,
-        -0.5377, -0.6351,  0.0136, -0.5838,  0.1487,  0.1609, -0.9119, -0.6284,
-        -0.9182,  0.7209, -0.9274, -0.6892,  0.0227,  0.2179, -0.5582, -0.1325,
-         0.3503,  0.4398, -0.1383, -0.5490, -0.3169,  1.5414, -0.8586, -0.6740,
-        -0.8825,  0.1398, -0.8967, -0.3510,  0.5993,  0.0000,  0.6398,  0.2602,
-         0.3466,  0.8698, -0.3477,  0.6601,  0.6898,  0.7015,  0.6142,  0.4043,
-         0.7144, -0.5898,  0.4980, -0.3139,  0.5599, -0.1175,  0.6093,  0.5526,
-         0.8292,  1.0344, -0.0941, -0.4420,  1.1051, -0.1116,  0.0880, -0.0505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8453,  0.6709, -0.0316,  0.0031,  0.3022,  0.5885, -0.6127, -1.0888,
-        -0.5377, -0.6351,  0.0136, -0.5838,  0.1487,  0.1609, -0.9119, -0.6284,
-        -0.9182,  0.7209, -0.9274, -0.6892,  0.0227,  0.2179, -0.5582, -0.1325,
-         0.3503,  0.4398, -0.1383, -0.5490, -0.3169,  1.5414, -0.8586, -0.6740,
-        -0.8825,  0.1398, -0.8967, -0.3510,  0.5993,  0.0000,  0.6398,  0.2602,
-         0.3466,  0.8698, -0.3477,  0.6601,  0.6898,  0.7015,  0.6142,  0.4043,
-         0.7144, -0.5898,  0.4980, -0.3139,  0.5599, -0.1175,  0.6093,  0.5526,
-         0.8292,  1.0344, -0.0941, -0.4420,  1.1051, -0.1116,  0.0880, -0.0505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8453,  0.6709, -0.0316,  0.0031,  0.3022,  0.5885, -0.6127, -1.0888,
-        -0.5377, -0.6351,  0.0136, -0.5838,  0.1487,  0.1609, -0.9119, -0.6284,
-        -0.9182,  0.7209, -0.9274, -0.6892,  0.0227,  0.2179, -0.5582, -0.1325,
-         0.3503,  0.4398, -0.1383, -0.5490, -0.3169,  1.5414, -0.8586, -0.6740,
-        -0.8825,  0.1398, -0.8967, -0.3510,  0.5993,  0.0000,  0.6398,  0.2602,
-         0.3466,  0.8698, -0.3477,  0.6601,  0.6898,  0.7015,  0.6142,  0.4043,
-         0.7144, -0.5898,  0.4980, -0.3139,  0.5599, -0.1175,  0.6093,  0.5526,
-         0.8292,  1.0344, -0.0941, -0.4420,  1.1051, -0.1116,  0.0880, -0.0505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.4785e-01,  6.6183e-01,  1.8025e-03,  1.4265e-03,  3.0244e-01,
-         6.0873e-01, -6.0872e-01, -1.0883e+00, -5.4400e-01, -6.4973e-01,
-        -1.6679e-03, -5.7666e-01,  1.2513e-01,  2.0507e-01, -9.3640e-01,
-        -6.2306e-01, -9.2391e-01,  7.2304e-01, -9.2803e-01, -6.8419e-01,
-         9.0225e-03,  2.0134e-01, -5.7122e-01, -1.1008e-01,  3.9814e-01,
-         4.5700e-01, -1.6610e-01, -5.4718e-01, -3.0446e-01,  1.5438e+00,
-        -8.5262e-01, -6.8574e-01, -8.8078e-01,  1.5274e-01, -8.9266e-01,
-        -3.4508e-01,  6.0414e-01,  0.0000e+00,  6.4873e-01,  2.4528e-01,
-         3.4118e-01,  9.0364e-01, -3.5487e-01,  6.5901e-01,  6.8997e-01,
-         6.9338e-01,  6.2017e-01,  3.8932e-01,  7.1688e-01, -5.9111e-01,
-         4.9627e-01, -3.2175e-01,  5.6234e-01, -1.3454e-01,  5.9622e-01,
-         5.5684e-01,  8.5444e-01,  1.0331e+00, -1.2141e-01, -4.6627e-01,
-         1.1095e+00, -1.4739e-01,  7.7974e-02, -1.0855e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.4785e-01,  6.6183e-01,  1.8025e-03,  1.4265e-03,  3.0244e-01,
-         6.0873e-01, -6.0872e-01, -1.0883e+00, -5.4400e-01, -6.4973e-01,
-        -1.6679e-03, -5.7666e-01,  1.2513e-01,  2.0507e-01, -9.3640e-01,
-        -6.2306e-01, -9.2391e-01,  7.2304e-01, -9.2803e-01, -6.8419e-01,
-         9.0225e-03,  2.0134e-01, -5.7122e-01, -1.1008e-01,  3.9814e-01,
-         4.5700e-01, -1.6610e-01, -5.4718e-01, -3.0446e-01,  1.5438e+00,
-        -8.5262e-01, -6.8574e-01, -8.8078e-01,  1.5274e-01, -8.9266e-01,
-        -3.4508e-01,  6.0414e-01,  0.0000e+00,  6.4873e-01,  2.4528e-01,
-         3.4118e-01,  9.0364e-01, -3.5487e-01,  6.5901e-01,  6.8997e-01,
-         6.9338e-01,  6.2017e-01,  3.8932e-01,  7.1688e-01, -5.9111e-01,
-         4.9627e-01, -3.2175e-01,  5.6234e-01, -1.3454e-01,  5.9622e-01,
-         5.5684e-01,  8.5444e-01,  1.0331e+00, -1.2141e-01, -4.6627e-01,
-         1.1095e+00, -1.4739e-01,  7.7974e-02, -1.0855e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.4785e-01,  6.6183e-01,  1.8025e-03,  1.4265e-03,  3.0244e-01,
-         6.0873e-01, -6.0872e-01, -1.0883e+00, -5.4400e-01, -6.4973e-01,
-        -1.6679e-03, -5.7666e-01,  1.2513e-01,  2.0507e-01, -9.3640e-01,
-        -6.2306e-01, -9.2391e-01,  7.2304e-01, -9.2803e-01, -6.8419e-01,
-         9.0225e-03,  2.0134e-01, -5.7122e-01, -1.1008e-01,  3.9814e-01,
-         4.5700e-01, -1.6610e-01, -5.4718e-01, -3.0446e-01,  1.5438e+00,
-        -8.5262e-01, -6.8574e-01, -8.8078e-01,  1.5274e-01, -8.9266e-01,
-        -3.4508e-01,  6.0414e-01,  0.0000e+00,  6.4873e-01,  2.4528e-01,
-         3.4118e-01,  9.0364e-01, -3.5487e-01,  6.5901e-01,  6.8997e-01,
-         6.9338e-01,  6.2017e-01,  3.8932e-01,  7.1688e-01, -5.9111e-01,
-         4.9627e-01, -3.2175e-01,  5.6234e-01, -1.3454e-01,  5.9622e-01,
-         5.5684e-01,  8.5444e-01,  1.0331e+00, -1.2141e-01, -4.6627e-01,
-         1.1095e+00, -1.4739e-01,  7.7974e-02, -1.0855e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8480,  0.6467,  0.0505,  0.0129,  0.3521,  0.6176, -0.6035, -1.0859,
-        -0.5493, -0.6594, -0.0316, -0.5608,  0.1000,  0.2710, -0.9691, -0.6158,
-        -0.9280,  0.7162, -0.9277, -0.6774, -0.0052,  0.1552, -0.5815, -0.1002,
-         0.4326,  0.4630, -0.1855, -0.5458, -0.2998,  1.5459, -0.8451, -0.6956,
-        -0.8774,  0.1352, -0.8911, -0.3328,  0.6248,  0.0000,  0.6657,  0.2678,
-         0.3230,  0.9347, -0.3515,  0.6685,  0.6924,  0.6883,  0.6275,  0.3795,
-         0.7224, -0.5930,  0.4989, -0.3273,  0.5579, -0.1341,  0.5868,  0.5589,
-         0.8795,  1.0309, -0.1550, -0.5100,  1.1168, -0.1660,  0.0439, -0.1603],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8480,  0.6467,  0.0505,  0.0129,  0.3521,  0.6176, -0.6035, -1.0859,
-        -0.5493, -0.6594, -0.0316, -0.5608,  0.1000,  0.2710, -0.9691, -0.6158,
-        -0.9280,  0.7162, -0.9277, -0.6774, -0.0052,  0.1552, -0.5815, -0.1002,
-         0.4326,  0.4630, -0.1855, -0.5458, -0.2998,  1.5459, -0.8451, -0.6956,
-        -0.8774,  0.1352, -0.8911, -0.3328,  0.6248,  0.0000,  0.6657,  0.2678,
-         0.3230,  0.9347, -0.3515,  0.6685,  0.6924,  0.6883,  0.6275,  0.3795,
-         0.7224, -0.5930,  0.4989, -0.3273,  0.5579, -0.1341,  0.5868,  0.5589,
-         0.8795,  1.0309, -0.1550, -0.5100,  1.1168, -0.1660,  0.0439, -0.1603],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8480,  0.6467,  0.0505,  0.0129,  0.3521,  0.6176, -0.6035, -1.0859,
-        -0.5493, -0.6594, -0.0316, -0.5608,  0.1000,  0.2710, -0.9691, -0.6158,
-        -0.9280,  0.7162, -0.9277, -0.6774, -0.0052,  0.1552, -0.5815, -0.1002,
-         0.4326,  0.4630, -0.1855, -0.5458, -0.2998,  1.5459, -0.8451, -0.6956,
-        -0.8774,  0.1352, -0.8911, -0.3328,  0.6248,  0.0000,  0.6657,  0.2678,
-         0.3230,  0.9347, -0.3515,  0.6685,  0.6924,  0.6883,  0.6275,  0.3795,
-         0.7224, -0.5930,  0.4989, -0.3273,  0.5579, -0.1341,  0.5868,  0.5589,
-         0.8795,  1.0309, -0.1550, -0.5100,  1.1168, -0.1660,  0.0439, -0.1603],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8516,  0.6250,  0.1032,  0.0291,  0.4085,  0.6249, -0.5966, -1.0827,
-        -0.5561, -0.6681, -0.0772, -0.5355,  0.0922,  0.3283, -1.0032, -0.6110,
-        -0.9278,  0.7076, -0.9278, -0.6685, -0.0113,  0.0982, -0.5864, -0.0968,
-         0.4542,  0.4567, -0.2037, -0.5439, -0.2914,  1.5472, -0.8413, -0.7019,
-        -0.8749,  0.0927, -0.8897, -0.3183,  0.6463,  0.0000,  0.6891,  0.3010,
-         0.2956,  0.9626, -0.3477,  0.6820,  0.6967,  0.6807,  0.6386,  0.3650,
-         0.7230, -0.5933,  0.4977, -0.3283,  0.5453, -0.1070,  0.5833,  0.5609,
-         0.9040,  1.0290, -0.1999, -0.5570,  1.1253, -0.1857,  0.0091, -0.2012],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8516,  0.6250,  0.1032,  0.0291,  0.4085,  0.6249, -0.5966, -1.0827,
-        -0.5561, -0.6681, -0.0772, -0.5355,  0.0922,  0.3283, -1.0032, -0.6110,
-        -0.9278,  0.7076, -0.9278, -0.6685, -0.0113,  0.0982, -0.5864, -0.0968,
-         0.4542,  0.4567, -0.2037, -0.5439, -0.2914,  1.5472, -0.8413, -0.7019,
-        -0.8749,  0.0927, -0.8897, -0.3183,  0.6463,  0.0000,  0.6891,  0.3010,
-         0.2956,  0.9626, -0.3477,  0.6820,  0.6967,  0.6807,  0.6386,  0.3650,
-         0.7230, -0.5933,  0.4977, -0.3283,  0.5453, -0.1070,  0.5833,  0.5609,
-         0.9040,  1.0290, -0.1999, -0.5570,  1.1253, -0.1857,  0.0091, -0.2012],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8516,  0.6250,  0.1032,  0.0291,  0.4085,  0.6249, -0.5966, -1.0827,
-        -0.5561, -0.6681, -0.0772, -0.5355,  0.0922,  0.3283, -1.0032, -0.6110,
-        -0.9278,  0.7076, -0.9278, -0.6685, -0.0113,  0.0982, -0.5864, -0.0968,
-         0.4542,  0.4567, -0.2037, -0.5439, -0.2914,  1.5472, -0.8413, -0.7019,
-        -0.8749,  0.0927, -0.8897, -0.3183,  0.6463,  0.0000,  0.6891,  0.3010,
-         0.2956,  0.9626, -0.3477,  0.6820,  0.6967,  0.6807,  0.6386,  0.3650,
-         0.7230, -0.5933,  0.4977, -0.3283,  0.5453, -0.1070,  0.5833,  0.5609,
-         0.9040,  1.0290, -0.1999, -0.5570,  1.1253, -0.1857,  0.0091, -0.2012],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8531,  0.6081,  0.1566,  0.0508,  0.4595,  0.6326, -0.5882, -1.0798,
-        -0.5604, -0.6778, -0.1220, -0.5030,  0.0868,  0.3818, -1.0357, -0.6079,
-        -0.9257,  0.7018, -0.9292, -0.6589, -0.0235,  0.0500, -0.5933, -0.0724,
-         0.4748,  0.4460, -0.2289, -0.5414, -0.2876,  1.5483, -0.8386, -0.7071,
-        -0.8714,  0.0432, -0.8889, -0.2982,  0.6698,  0.0000,  0.7122,  0.3247,
-         0.2632,  0.9882, -0.3348,  0.6976,  0.7009,  0.6716,  0.6463,  0.3542,
-         0.7302, -0.5936,  0.4987, -0.3238,  0.5302, -0.0843,  0.5755,  0.5640,
-         0.9278,  1.0256, -0.2392, -0.6060,  1.1331, -0.2147, -0.0266, -0.2313],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8531,  0.6081,  0.1566,  0.0508,  0.4595,  0.6326, -0.5882, -1.0798,
-        -0.5604, -0.6778, -0.1220, -0.5030,  0.0868,  0.3818, -1.0357, -0.6079,
-        -0.9257,  0.7018, -0.9292, -0.6589, -0.0235,  0.0500, -0.5933, -0.0724,
-         0.4748,  0.4460, -0.2289, -0.5414, -0.2876,  1.5483, -0.8386, -0.7071,
-        -0.8714,  0.0432, -0.8889, -0.2982,  0.6698,  0.0000,  0.7122,  0.3247,
-         0.2632,  0.9882, -0.3348,  0.6976,  0.7009,  0.6716,  0.6463,  0.3542,
-         0.7302, -0.5936,  0.4987, -0.3238,  0.5302, -0.0843,  0.5755,  0.5640,
-         0.9278,  1.0256, -0.2392, -0.6060,  1.1331, -0.2147, -0.0266, -0.2313],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8531,  0.6081,  0.1566,  0.0508,  0.4595,  0.6326, -0.5882, -1.0798,
-        -0.5604, -0.6778, -0.1220, -0.5030,  0.0868,  0.3818, -1.0357, -0.6079,
-        -0.9257,  0.7018, -0.9292, -0.6589, -0.0235,  0.0500, -0.5933, -0.0724,
-         0.4748,  0.4460, -0.2289, -0.5414, -0.2876,  1.5483, -0.8386, -0.7071,
-        -0.8714,  0.0432, -0.8889, -0.2982,  0.6698,  0.0000,  0.7122,  0.3247,
-         0.2632,  0.9882, -0.3348,  0.6976,  0.7009,  0.6716,  0.6463,  0.3542,
-         0.7302, -0.5936,  0.4987, -0.3238,  0.5302, -0.0843,  0.5755,  0.5640,
-         0.9278,  1.0256, -0.2392, -0.6060,  1.1331, -0.2147, -0.0266, -0.2313],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8520,  0.6035,  0.1812,  0.0575,  0.4987,  0.6428, -0.5787, -1.0763,
-        -0.5627, -0.6876, -0.1697, -0.4773,  0.0595,  0.4090, -1.0658, -0.6038,
-        -0.9232,  0.6984, -0.9313, -0.6519, -0.0357,  0.0149, -0.5956, -0.0668,
-         0.4934,  0.4378, -0.2524, -0.5389, -0.2868,  1.5499, -0.8385, -0.7119,
-        -0.8683,  0.0256, -0.8870, -0.2823,  0.6978,  0.0000,  0.7296,  0.3361,
-         0.2282,  1.0128, -0.3342,  0.7094,  0.7043,  0.6589,  0.6542,  0.3241,
-         0.7465, -0.5933,  0.4978, -0.3146,  0.5213, -0.0610,  0.5602,  0.5637,
-         0.9503,  1.0209, -0.2637, -0.6393,  1.1405, -0.2535, -0.0682, -0.2747],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8520,  0.6035,  0.1812,  0.0575,  0.4987,  0.6428, -0.5787, -1.0763,
-        -0.5627, -0.6876, -0.1697, -0.4773,  0.0595,  0.4090, -1.0658, -0.6038,
-        -0.9232,  0.6984, -0.9313, -0.6519, -0.0357,  0.0149, -0.5956, -0.0668,
-         0.4934,  0.4378, -0.2524, -0.5389, -0.2868,  1.5499, -0.8385, -0.7119,
-        -0.8683,  0.0256, -0.8870, -0.2823,  0.6978,  0.0000,  0.7296,  0.3361,
-         0.2282,  1.0128, -0.3342,  0.7094,  0.7043,  0.6589,  0.6542,  0.3241,
-         0.7465, -0.5933,  0.4978, -0.3146,  0.5213, -0.0610,  0.5602,  0.5637,
-         0.9503,  1.0209, -0.2637, -0.6393,  1.1405, -0.2535, -0.0682, -0.2747],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8520,  0.6035,  0.1812,  0.0575,  0.4987,  0.6428, -0.5787, -1.0763,
-        -0.5627, -0.6876, -0.1697, -0.4773,  0.0595,  0.4090, -1.0658, -0.6038,
-        -0.9232,  0.6984, -0.9313, -0.6519, -0.0357,  0.0149, -0.5956, -0.0668,
-         0.4934,  0.4378, -0.2524, -0.5389, -0.2868,  1.5499, -0.8385, -0.7119,
-        -0.8683,  0.0256, -0.8870, -0.2823,  0.6978,  0.0000,  0.7296,  0.3361,
-         0.2282,  1.0128, -0.3342,  0.7094,  0.7043,  0.6589,  0.6542,  0.3241,
-         0.7465, -0.5933,  0.4978, -0.3146,  0.5213, -0.0610,  0.5602,  0.5637,
-         0.9503,  1.0209, -0.2637, -0.6393,  1.1405, -0.2535, -0.0682, -0.2747],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8515,  0.6070,  0.2001,  0.0773,  0.5316,  0.6571, -0.5678, -1.0729,
-        -0.5543, -0.6971, -0.1996, -0.4722,  0.0395,  0.4368, -1.0937, -0.5989,
-        -0.9209,  0.6960, -0.9357, -0.6486, -0.0348,  0.0029, -0.5945, -0.0713,
-         0.5130,  0.4299, -0.2678, -0.5333, -0.2836,  1.5520, -0.8362, -0.7143,
-        -0.8646,  0.0457, -0.8843, -0.2621,  0.7215,  0.0000,  0.7394,  0.3244,
-         0.2036,  1.0351, -0.3345,  0.7191,  0.7049,  0.6445,  0.6627,  0.3014,
-         0.7634, -0.5925,  0.4956, -0.2897,  0.5141, -0.0160,  0.5418,  0.5622,
-         0.9717,  1.0174, -0.2860, -0.6613,  1.1461, -0.2893, -0.1124, -0.2999],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8515,  0.6070,  0.2001,  0.0773,  0.5316,  0.6571, -0.5678, -1.0729,
-        -0.5543, -0.6971, -0.1996, -0.4722,  0.0395,  0.4368, -1.0937, -0.5989,
-        -0.9209,  0.6960, -0.9357, -0.6486, -0.0348,  0.0029, -0.5945, -0.0713,
-         0.5130,  0.4299, -0.2678, -0.5333, -0.2836,  1.5520, -0.8362, -0.7143,
-        -0.8646,  0.0457, -0.8843, -0.2621,  0.7215,  0.0000,  0.7394,  0.3244,
-         0.2036,  1.0351, -0.3345,  0.7191,  0.7049,  0.6445,  0.6627,  0.3014,
-         0.7634, -0.5925,  0.4956, -0.2897,  0.5141, -0.0160,  0.5418,  0.5622,
-         0.9717,  1.0174, -0.2860, -0.6613,  1.1461, -0.2893, -0.1124, -0.2999],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8515,  0.6070,  0.2001,  0.0773,  0.5316,  0.6571, -0.5678, -1.0729,
-        -0.5543, -0.6971, -0.1996, -0.4722,  0.0395,  0.4368, -1.0937, -0.5989,
-        -0.9209,  0.6960, -0.9357, -0.6486, -0.0348,  0.0029, -0.5945, -0.0713,
-         0.5130,  0.4299, -0.2678, -0.5333, -0.2836,  1.5520, -0.8362, -0.7143,
-        -0.8646,  0.0457, -0.8843, -0.2621,  0.7215,  0.0000,  0.7394,  0.3244,
-         0.2036,  1.0351, -0.3345,  0.7191,  0.7049,  0.6445,  0.6627,  0.3014,
-         0.7634, -0.5925,  0.4956, -0.2897,  0.5141, -0.0160,  0.5418,  0.5622,
-         0.9717,  1.0174, -0.2860, -0.6613,  1.1461, -0.2893, -0.1124, -0.2999],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8491,  0.6108,  0.2011,  0.0875,  0.5533,  0.6765, -0.5541, -1.0689,
-        -0.5547, -0.7052, -0.2310, -0.4666,  0.0126,  0.4657, -1.1206, -0.5940,
-        -0.9187,  0.6926, -0.9401, -0.6459, -0.0500, -0.0171, -0.5942, -0.0899,
-         0.5299,  0.4219, -0.2783, -0.5267, -0.2695,  1.5535, -0.8322, -0.7162,
-        -0.8616,  0.0367, -0.8810, -0.2461,  0.7414,  0.0000,  0.7476,  0.3008,
-         0.1767,  1.0551, -0.3252,  0.7272,  0.7003,  0.6273,  0.6711,  0.2753,
-         0.7719, -0.5910,  0.4902, -0.2761,  0.5119,  0.0158,  0.5210,  0.5580,
-         0.9936,  1.0151, -0.2927, -0.6728,  1.1517, -0.3129, -0.1469, -0.3052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8491,  0.6108,  0.2011,  0.0875,  0.5533,  0.6765, -0.5541, -1.0689,
-        -0.5547, -0.7052, -0.2310, -0.4666,  0.0126,  0.4657, -1.1206, -0.5940,
-        -0.9187,  0.6926, -0.9401, -0.6459, -0.0500, -0.0171, -0.5942, -0.0899,
-         0.5299,  0.4219, -0.2783, -0.5267, -0.2695,  1.5535, -0.8322, -0.7162,
-        -0.8616,  0.0367, -0.8810, -0.2461,  0.7414,  0.0000,  0.7476,  0.3008,
-         0.1767,  1.0551, -0.3252,  0.7272,  0.7003,  0.6273,  0.6711,  0.2753,
-         0.7719, -0.5910,  0.4902, -0.2761,  0.5119,  0.0158,  0.5210,  0.5580,
-         0.9936,  1.0151, -0.2927, -0.6728,  1.1517, -0.3129, -0.1469, -0.3052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8491,  0.6108,  0.2011,  0.0875,  0.5533,  0.6765, -0.5541, -1.0689,
-        -0.5547, -0.7052, -0.2310, -0.4666,  0.0126,  0.4657, -1.1206, -0.5940,
-        -0.9187,  0.6926, -0.9401, -0.6459, -0.0500, -0.0171, -0.5942, -0.0899,
-         0.5299,  0.4219, -0.2783, -0.5267, -0.2695,  1.5535, -0.8322, -0.7162,
-        -0.8616,  0.0367, -0.8810, -0.2461,  0.7414,  0.0000,  0.7476,  0.3008,
-         0.1767,  1.0551, -0.3252,  0.7272,  0.7003,  0.6273,  0.6711,  0.2753,
-         0.7719, -0.5910,  0.4902, -0.2761,  0.5119,  0.0158,  0.5210,  0.5580,
-         0.9936,  1.0151, -0.2927, -0.6728,  1.1517, -0.3129, -0.1469, -0.3052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8476,  0.6133,  0.1960,  0.0940,  0.5631,  0.6911, -0.5442, -1.0651,
-        -0.5504, -0.7142, -0.2784, -0.4500, -0.0025,  0.4767, -1.1409, -0.5943,
-        -0.9175,  0.6890, -0.9441, -0.6340, -0.0892, -0.0555, -0.6063, -0.1418,
-         0.5391,  0.4140, -0.2930, -0.5231, -0.2627,  1.5544, -0.8277, -0.7152,
-        -0.8604,  0.0117, -0.8763, -0.2268,  0.7590,  0.0000,  0.7540,  0.2741,
-         0.1441,  1.0712, -0.3007,  0.7313,  0.6955,  0.6136,  0.6799,  0.2500,
-         0.7742, -0.5901,  0.4855, -0.2693,  0.5099,  0.0352,  0.5001,  0.5517,
-         1.0148,  1.0122, -0.3105, -0.6823,  1.1570, -0.3296, -0.1802, -0.2976],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8476,  0.6133,  0.1960,  0.0940,  0.5631,  0.6911, -0.5442, -1.0651,
-        -0.5504, -0.7142, -0.2784, -0.4500, -0.0025,  0.4767, -1.1409, -0.5943,
-        -0.9175,  0.6890, -0.9441, -0.6340, -0.0892, -0.0555, -0.6063, -0.1418,
-         0.5391,  0.4140, -0.2930, -0.5231, -0.2627,  1.5544, -0.8277, -0.7152,
-        -0.8604,  0.0117, -0.8763, -0.2268,  0.7590,  0.0000,  0.7540,  0.2741,
-         0.1441,  1.0712, -0.3007,  0.7313,  0.6955,  0.6136,  0.6799,  0.2500,
-         0.7742, -0.5901,  0.4855, -0.2693,  0.5099,  0.0352,  0.5001,  0.5517,
-         1.0148,  1.0122, -0.3105, -0.6823,  1.1570, -0.3296, -0.1802, -0.2976],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8476,  0.6133,  0.1960,  0.0940,  0.5631,  0.6911, -0.5442, -1.0651,
-        -0.5504, -0.7142, -0.2784, -0.4500, -0.0025,  0.4767, -1.1409, -0.5943,
-        -0.9175,  0.6890, -0.9441, -0.6340, -0.0892, -0.0555, -0.6063, -0.1418,
-         0.5391,  0.4140, -0.2930, -0.5231, -0.2627,  1.5544, -0.8277, -0.7152,
-        -0.8604,  0.0117, -0.8763, -0.2268,  0.7590,  0.0000,  0.7540,  0.2741,
-         0.1441,  1.0712, -0.3007,  0.7313,  0.6955,  0.6136,  0.6799,  0.2500,
-         0.7742, -0.5901,  0.4855, -0.2693,  0.5099,  0.0352,  0.5001,  0.5517,
-         1.0148,  1.0122, -0.3105, -0.6823,  1.1570, -0.3296, -0.1802, -0.2976],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8467,  0.6123,  0.1899,  0.1074,  0.5671,  0.7027, -0.5378, -1.0605,
-        -0.5415, -0.7226, -0.3128, -0.4210,  0.0053,  0.4820, -1.1586, -0.5926,
-        -0.9217,  0.6874, -0.9492, -0.6188, -0.1328, -0.0850, -0.6290, -0.2181,
-         0.5469,  0.4040, -0.3106, -0.5161, -0.2652,  1.5559, -0.8244, -0.7136,
-        -0.8584, -0.0254, -0.8724, -0.2042,  0.7771,  0.0000,  0.7605,  0.2528,
-         0.1130,  1.0846, -0.2796,  0.7376,  0.6908,  0.6007,  0.6901,  0.2177,
-         0.7747, -0.5897,  0.4794, -0.2650,  0.5076,  0.0404,  0.4811,  0.5427,
-         1.0338,  1.0093, -0.3320, -0.6788,  1.1636, -0.3396, -0.2094, -0.2607],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8467,  0.6123,  0.1899,  0.1074,  0.5671,  0.7027, -0.5378, -1.0605,
-        -0.5415, -0.7226, -0.3128, -0.4210,  0.0053,  0.4820, -1.1586, -0.5926,
-        -0.9217,  0.6874, -0.9492, -0.6188, -0.1328, -0.0850, -0.6290, -0.2181,
-         0.5469,  0.4040, -0.3106, -0.5161, -0.2652,  1.5559, -0.8244, -0.7136,
-        -0.8584, -0.0254, -0.8724, -0.2042,  0.7771,  0.0000,  0.7605,  0.2528,
-         0.1130,  1.0846, -0.2796,  0.7376,  0.6908,  0.6007,  0.6901,  0.2177,
-         0.7747, -0.5897,  0.4794, -0.2650,  0.5076,  0.0404,  0.4811,  0.5427,
-         1.0338,  1.0093, -0.3320, -0.6788,  1.1636, -0.3396, -0.2094, -0.2607],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8467,  0.6123,  0.1899,  0.1074,  0.5671,  0.7027, -0.5378, -1.0605,
-        -0.5415, -0.7226, -0.3128, -0.4210,  0.0053,  0.4820, -1.1586, -0.5926,
-        -0.9217,  0.6874, -0.9492, -0.6188, -0.1328, -0.0850, -0.6290, -0.2181,
-         0.5469,  0.4040, -0.3106, -0.5161, -0.2652,  1.5559, -0.8244, -0.7136,
-        -0.8584, -0.0254, -0.8724, -0.2042,  0.7771,  0.0000,  0.7605,  0.2528,
-         0.1130,  1.0846, -0.2796,  0.7376,  0.6908,  0.6007,  0.6901,  0.2177,
-         0.7747, -0.5897,  0.4794, -0.2650,  0.5076,  0.0404,  0.4811,  0.5427,
-         1.0338,  1.0093, -0.3320, -0.6788,  1.1636, -0.3396, -0.2094, -0.2607],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8474,  0.6132,  0.1829,  0.1168,  0.5688,  0.7116, -0.5337, -1.0566,
-        -0.5277, -0.7316, -0.3507, -0.3867,  0.0144,  0.4779, -1.1747, -0.5925,
-        -0.9294,  0.6856, -0.9538, -0.6004, -0.1819, -0.1123, -0.6513, -0.2942,
-         0.5534,  0.3851, -0.3150, -0.5152, -0.2635,  1.5576, -0.8218, -0.7147,
-        -0.8588, -0.0687, -0.8700, -0.1843,  0.7952,  0.0000,  0.7638,  0.2224,
-         0.0771,  1.0977, -0.2675,  0.7410,  0.6850,  0.5844,  0.6953,  0.1815,
-         0.7815, -0.5911,  0.4724, -0.2703,  0.5090,  0.0521,  0.4607,  0.5361,
-         1.0518,  1.0065, -0.3401, -0.6790,  1.1713, -0.3394, -0.2196, -0.1918],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8474,  0.6132,  0.1829,  0.1168,  0.5688,  0.7116, -0.5337, -1.0566,
-        -0.5277, -0.7316, -0.3507, -0.3867,  0.0144,  0.4779, -1.1747, -0.5925,
-        -0.9294,  0.6856, -0.9538, -0.6004, -0.1819, -0.1123, -0.6513, -0.2942,
-         0.5534,  0.3851, -0.3150, -0.5152, -0.2635,  1.5576, -0.8218, -0.7147,
-        -0.8588, -0.0687, -0.8700, -0.1843,  0.7952,  0.0000,  0.7638,  0.2224,
-         0.0771,  1.0977, -0.2675,  0.7410,  0.6850,  0.5844,  0.6953,  0.1815,
-         0.7815, -0.5911,  0.4724, -0.2703,  0.5090,  0.0521,  0.4607,  0.5361,
-         1.0518,  1.0065, -0.3401, -0.6790,  1.1713, -0.3394, -0.2196, -0.1918],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8474,  0.6132,  0.1829,  0.1168,  0.5688,  0.7116, -0.5337, -1.0566,
-        -0.5277, -0.7316, -0.3507, -0.3867,  0.0144,  0.4779, -1.1747, -0.5925,
-        -0.9294,  0.6856, -0.9538, -0.6004, -0.1819, -0.1123, -0.6513, -0.2942,
-         0.5534,  0.3851, -0.3150, -0.5152, -0.2635,  1.5576, -0.8218, -0.7147,
-        -0.8588, -0.0687, -0.8700, -0.1843,  0.7952,  0.0000,  0.7638,  0.2224,
-         0.0771,  1.0977, -0.2675,  0.7410,  0.6850,  0.5844,  0.6953,  0.1815,
-         0.7815, -0.5911,  0.4724, -0.2703,  0.5090,  0.0521,  0.4607,  0.5361,
-         1.0518,  1.0065, -0.3401, -0.6790,  1.1713, -0.3394, -0.2196, -0.1918],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8490,  0.6039,  0.1740,  0.1185,  0.5610,  0.7151, -0.5294, -1.0527,
-        -0.5123, -0.7409, -0.3927, -0.3595,  0.0355,  0.4668, -1.1904, -0.5921,
-        -0.9386,  0.6779, -0.9575, -0.5849, -0.2253, -0.1374, -0.6737, -0.3559,
-         0.5566,  0.3647, -0.3065, -0.5126, -0.2671,  1.5598, -0.8186, -0.7171,
-        -0.8621, -0.1111, -0.8682, -0.1585,  0.8146,  0.0000,  0.7626,  0.2079,
-         0.0339,  1.1099, -0.2296,  0.7422,  0.6785,  0.5683,  0.6977,  0.1529,
-         0.7865, -0.5919,  0.4633, -0.2607,  0.5098,  0.0791,  0.4369,  0.5255,
-         1.0672,  1.0001, -0.3243, -0.6882,  1.1785, -0.3314, -0.2240, -0.1222],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8490,  0.6039,  0.1740,  0.1185,  0.5610,  0.7151, -0.5294, -1.0527,
-        -0.5123, -0.7409, -0.3927, -0.3595,  0.0355,  0.4668, -1.1904, -0.5921,
-        -0.9386,  0.6779, -0.9575, -0.5849, -0.2253, -0.1374, -0.6737, -0.3559,
-         0.5566,  0.3647, -0.3065, -0.5126, -0.2671,  1.5598, -0.8186, -0.7171,
-        -0.8621, -0.1111, -0.8682, -0.1585,  0.8146,  0.0000,  0.7626,  0.2079,
-         0.0339,  1.1099, -0.2296,  0.7422,  0.6785,  0.5683,  0.6977,  0.1529,
-         0.7865, -0.5919,  0.4633, -0.2607,  0.5098,  0.0791,  0.4369,  0.5255,
-         1.0672,  1.0001, -0.3243, -0.6882,  1.1785, -0.3314, -0.2240, -0.1222],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8490,  0.6039,  0.1740,  0.1185,  0.5610,  0.7151, -0.5294, -1.0527,
-        -0.5123, -0.7409, -0.3927, -0.3595,  0.0355,  0.4668, -1.1904, -0.5921,
-        -0.9386,  0.6779, -0.9575, -0.5849, -0.2253, -0.1374, -0.6737, -0.3559,
-         0.5566,  0.3647, -0.3065, -0.5126, -0.2671,  1.5598, -0.8186, -0.7171,
-        -0.8621, -0.1111, -0.8682, -0.1585,  0.8146,  0.0000,  0.7626,  0.2079,
-         0.0339,  1.1099, -0.2296,  0.7422,  0.6785,  0.5683,  0.6977,  0.1529,
-         0.7865, -0.5919,  0.4633, -0.2607,  0.5098,  0.0791,  0.4369,  0.5255,
-         1.0672,  1.0001, -0.3243, -0.6882,  1.1785, -0.3314, -0.2240, -0.1222],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8501,  0.5965,  0.1676,  0.1304,  0.5592,  0.7166, -0.5240, -1.0491,
-        -0.4926, -0.7490, -0.4272, -0.3435,  0.0586,  0.4549, -1.2062, -0.5850,
-        -0.9468,  0.6708, -0.9595, -0.5716, -0.2552, -0.1329, -0.6914, -0.4000,
-         0.5622,  0.3327, -0.2902, -0.5076, -0.2614,  1.5620, -0.8155, -0.7168,
-        -0.8646, -0.1306, -0.8657, -0.1393,  0.8346,  0.0000,  0.7606,  0.2072,
-        -0.0147,  1.1181, -0.1869,  0.7424,  0.6721,  0.5478,  0.6995,  0.1294,
-         0.7959, -0.5925,  0.4493, -0.2360,  0.5167,  0.1159,  0.4207,  0.5121,
-         1.0822,  0.9943, -0.2969, -0.6928,  1.1852, -0.3301, -0.2213, -0.0560],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8501,  0.5965,  0.1676,  0.1304,  0.5592,  0.7166, -0.5240, -1.0491,
-        -0.4926, -0.7490, -0.4272, -0.3435,  0.0586,  0.4549, -1.2062, -0.5850,
-        -0.9468,  0.6708, -0.9595, -0.5716, -0.2552, -0.1329, -0.6914, -0.4000,
-         0.5622,  0.3327, -0.2902, -0.5076, -0.2614,  1.5620, -0.8155, -0.7168,
-        -0.8646, -0.1306, -0.8657, -0.1393,  0.8346,  0.0000,  0.7606,  0.2072,
-        -0.0147,  1.1181, -0.1869,  0.7424,  0.6721,  0.5478,  0.6995,  0.1294,
-         0.7959, -0.5925,  0.4493, -0.2360,  0.5167,  0.1159,  0.4207,  0.5121,
-         1.0822,  0.9943, -0.2969, -0.6928,  1.1852, -0.3301, -0.2213, -0.0560],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8501,  0.5965,  0.1676,  0.1304,  0.5592,  0.7166, -0.5240, -1.0491,
-        -0.4926, -0.7490, -0.4272, -0.3435,  0.0586,  0.4549, -1.2062, -0.5850,
-        -0.9468,  0.6708, -0.9595, -0.5716, -0.2552, -0.1329, -0.6914, -0.4000,
-         0.5622,  0.3327, -0.2902, -0.5076, -0.2614,  1.5620, -0.8155, -0.7168,
-        -0.8646, -0.1306, -0.8657, -0.1393,  0.8346,  0.0000,  0.7606,  0.2072,
-        -0.0147,  1.1181, -0.1869,  0.7424,  0.6721,  0.5478,  0.6995,  0.1294,
-         0.7959, -0.5925,  0.4493, -0.2360,  0.5167,  0.1159,  0.4207,  0.5121,
-         1.0822,  0.9943, -0.2969, -0.6928,  1.1852, -0.3301, -0.2213, -0.0560],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8535,  0.5810,  0.1600,  0.1479,  0.5540,  0.7116, -0.5180, -1.0450,
-        -0.4726, -0.7553, -0.4579, -0.3374,  0.0603,  0.4231, -1.2211, -0.5807,
-        -0.9495,  0.6631, -0.9600, -0.5568, -0.2692, -0.1153, -0.7009, -0.4492,
-         0.5679,  0.3091, -0.2787, -0.5002, -0.2573,  1.5641, -0.8098, -0.7163,
-        -0.8668, -0.1216, -0.8616, -0.1226,  0.8540,  0.0000,  0.7552,  0.2185,
-        -0.0477,  1.1243, -0.1507,  0.7398,  0.6655,  0.5236,  0.7004,  0.1113,
-         0.8041, -0.5938,  0.4341, -0.1815,  0.5201,  0.1599,  0.4045,  0.5024,
-         1.0964,  0.9882, -0.2748, -0.7010,  1.1931, -0.3249, -0.2172, -0.0146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8535,  0.5810,  0.1600,  0.1479,  0.5540,  0.7116, -0.5180, -1.0450,
-        -0.4726, -0.7553, -0.4579, -0.3374,  0.0603,  0.4231, -1.2211, -0.5807,
-        -0.9495,  0.6631, -0.9600, -0.5568, -0.2692, -0.1153, -0.7009, -0.4492,
-         0.5679,  0.3091, -0.2787, -0.5002, -0.2573,  1.5641, -0.8098, -0.7163,
-        -0.8668, -0.1216, -0.8616, -0.1226,  0.8540,  0.0000,  0.7552,  0.2185,
-        -0.0477,  1.1243, -0.1507,  0.7398,  0.6655,  0.5236,  0.7004,  0.1113,
-         0.8041, -0.5938,  0.4341, -0.1815,  0.5201,  0.1599,  0.4045,  0.5024,
-         1.0964,  0.9882, -0.2748, -0.7010,  1.1931, -0.3249, -0.2172, -0.0146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8535,  0.5810,  0.1600,  0.1479,  0.5540,  0.7116, -0.5180, -1.0450,
-        -0.4726, -0.7553, -0.4579, -0.3374,  0.0603,  0.4231, -1.2211, -0.5807,
-        -0.9495,  0.6631, -0.9600, -0.5568, -0.2692, -0.1153, -0.7009, -0.4492,
-         0.5679,  0.3091, -0.2787, -0.5002, -0.2573,  1.5641, -0.8098, -0.7163,
-        -0.8668, -0.1216, -0.8616, -0.1226,  0.8540,  0.0000,  0.7552,  0.2185,
-        -0.0477,  1.1243, -0.1507,  0.7398,  0.6655,  0.5236,  0.7004,  0.1113,
-         0.8041, -0.5938,  0.4341, -0.1815,  0.5201,  0.1599,  0.4045,  0.5024,
-         1.0964,  0.9882, -0.2748, -0.7010,  1.1931, -0.3249, -0.2172, -0.0146],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8580,  0.5682,  0.1519,  0.1486,  0.5568,  0.7077, -0.5037, -1.0415,
-        -0.4545, -0.7562, -0.4813, -0.3469,  0.0300,  0.3686, -1.2326, -0.5846,
-        -0.9479,  0.6553, -0.9596, -0.5484, -0.2972, -0.0854, -0.6926, -0.4905,
-         0.5666,  0.2776, -0.2707, -0.4858, -0.2472,  1.5675, -0.8091, -0.7165,
-        -0.8655, -0.0885, -0.8584, -0.1248,  0.8695,  0.0000,  0.7445,  0.1974,
-        -0.0750,  1.1339, -0.1468,  0.7349,  0.6593,  0.5000,  0.7032,  0.0968,
-         0.8104, -0.5956,  0.4206, -0.1424,  0.5201,  0.1960,  0.3719,  0.4885,
-         1.1095,  0.9795, -0.2756, -0.7006,  1.1990, -0.3140, -0.2239, -0.0107],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8580,  0.5682,  0.1519,  0.1486,  0.5568,  0.7077, -0.5037, -1.0415,
-        -0.4545, -0.7562, -0.4813, -0.3469,  0.0300,  0.3686, -1.2326, -0.5846,
-        -0.9479,  0.6553, -0.9596, -0.5484, -0.2972, -0.0854, -0.6926, -0.4905,
-         0.5666,  0.2776, -0.2707, -0.4858, -0.2472,  1.5675, -0.8091, -0.7165,
-        -0.8655, -0.0885, -0.8584, -0.1248,  0.8695,  0.0000,  0.7445,  0.1974,
-        -0.0750,  1.1339, -0.1468,  0.7349,  0.6593,  0.5000,  0.7032,  0.0968,
-         0.8104, -0.5956,  0.4206, -0.1424,  0.5201,  0.1960,  0.3719,  0.4885,
-         1.1095,  0.9795, -0.2756, -0.7006,  1.1990, -0.3140, -0.2239, -0.0107],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8580,  0.5682,  0.1519,  0.1486,  0.5568,  0.7077, -0.5037, -1.0415,
-        -0.4545, -0.7562, -0.4813, -0.3469,  0.0300,  0.3686, -1.2326, -0.5846,
-        -0.9479,  0.6553, -0.9596, -0.5484, -0.2972, -0.0854, -0.6926, -0.4905,
-         0.5666,  0.2776, -0.2707, -0.4858, -0.2472,  1.5675, -0.8091, -0.7165,
-        -0.8655, -0.0885, -0.8584, -0.1248,  0.8695,  0.0000,  0.7445,  0.1974,
-        -0.0750,  1.1339, -0.1468,  0.7349,  0.6593,  0.5000,  0.7032,  0.0968,
-         0.8104, -0.5956,  0.4206, -0.1424,  0.5201,  0.1960,  0.3719,  0.4885,
-         1.1095,  0.9795, -0.2756, -0.7006,  1.1990, -0.3140, -0.2239, -0.0107],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8615,  0.5498,  0.1426,  0.1437,  0.5609,  0.7063, -0.4853, -1.0392,
-        -0.4422, -0.7566, -0.5024, -0.3469, -0.0150,  0.3130, -1.2433, -0.5908,
-        -0.9448,  0.6501, -0.9575, -0.5335, -0.3347, -0.0239, -0.6772, -0.5361,
-         0.5668,  0.2451, -0.2727, -0.4706, -0.2271,  1.5718, -0.8091, -0.7202,
-        -0.8642, -0.0333, -0.8540, -0.1359,  0.8822,  0.0000,  0.7337,  0.1696,
-        -0.0924,  1.1415, -0.1539,  0.7312,  0.6514,  0.4783,  0.7057,  0.0821,
-         0.8094, -0.5977,  0.4134, -0.1033,  0.5177,  0.2258,  0.3406,  0.4769,
-         1.1206,  0.9733, -0.3000, -0.7001,  1.2049, -0.3125, -0.2168, -0.0075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8615,  0.5498,  0.1426,  0.1437,  0.5609,  0.7063, -0.4853, -1.0392,
-        -0.4422, -0.7566, -0.5024, -0.3469, -0.0150,  0.3130, -1.2433, -0.5908,
-        -0.9448,  0.6501, -0.9575, -0.5335, -0.3347, -0.0239, -0.6772, -0.5361,
-         0.5668,  0.2451, -0.2727, -0.4706, -0.2271,  1.5718, -0.8091, -0.7202,
-        -0.8642, -0.0333, -0.8540, -0.1359,  0.8822,  0.0000,  0.7337,  0.1696,
-        -0.0924,  1.1415, -0.1539,  0.7312,  0.6514,  0.4783,  0.7057,  0.0821,
-         0.8094, -0.5977,  0.4134, -0.1033,  0.5177,  0.2258,  0.3406,  0.4769,
-         1.1206,  0.9733, -0.3000, -0.7001,  1.2049, -0.3125, -0.2168, -0.0075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8615,  0.5498,  0.1426,  0.1437,  0.5609,  0.7063, -0.4853, -1.0392,
-        -0.4422, -0.7566, -0.5024, -0.3469, -0.0150,  0.3130, -1.2433, -0.5908,
-        -0.9448,  0.6501, -0.9575, -0.5335, -0.3347, -0.0239, -0.6772, -0.5361,
-         0.5668,  0.2451, -0.2727, -0.4706, -0.2271,  1.5718, -0.8091, -0.7202,
-        -0.8642, -0.0333, -0.8540, -0.1359,  0.8822,  0.0000,  0.7337,  0.1696,
-        -0.0924,  1.1415, -0.1539,  0.7312,  0.6514,  0.4783,  0.7057,  0.0821,
-         0.8094, -0.5977,  0.4134, -0.1033,  0.5177,  0.2258,  0.3406,  0.4769,
-         1.1206,  0.9733, -0.3000, -0.7001,  1.2049, -0.3125, -0.2168, -0.0075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8640,  0.5320,  0.1126,  0.1348,  0.5654,  0.7066, -0.4646, -1.0361,
-        -0.4352, -0.7581, -0.5252, -0.3422, -0.0782,  0.2571, -1.2543, -0.5938,
-        -0.9410,  0.6455, -0.9536, -0.5214, -0.3666,  0.0523, -0.6580, -0.5772,
-         0.5655,  0.2158, -0.2773, -0.4560, -0.2098,  1.5758, -0.8074, -0.7258,
-        -0.8647,  0.0232, -0.8497, -0.1368,  0.8938,  0.0000,  0.7224,  0.1350,
-        -0.0992,  1.1476, -0.1670,  0.7284,  0.6425,  0.4515,  0.7073,  0.0663,
-         0.8090, -0.5992,  0.4038, -0.0688,  0.5143,  0.2458,  0.3057,  0.4591,
-         1.1309,  0.9683, -0.3125, -0.6936,  1.2092, -0.2963, -0.2008, -0.0061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8640,  0.5320,  0.1126,  0.1348,  0.5654,  0.7066, -0.4646, -1.0361,
-        -0.4352, -0.7581, -0.5252, -0.3422, -0.0782,  0.2571, -1.2543, -0.5938,
-        -0.9410,  0.6455, -0.9536, -0.5214, -0.3666,  0.0523, -0.6580, -0.5772,
-         0.5655,  0.2158, -0.2773, -0.4560, -0.2098,  1.5758, -0.8074, -0.7258,
-        -0.8647,  0.0232, -0.8497, -0.1368,  0.8938,  0.0000,  0.7224,  0.1350,
-        -0.0992,  1.1476, -0.1670,  0.7284,  0.6425,  0.4515,  0.7073,  0.0663,
-         0.8090, -0.5992,  0.4038, -0.0688,  0.5143,  0.2458,  0.3057,  0.4591,
-         1.1309,  0.9683, -0.3125, -0.6936,  1.2092, -0.2963, -0.2008, -0.0061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8640,  0.5320,  0.1126,  0.1348,  0.5654,  0.7066, -0.4646, -1.0361,
-        -0.4352, -0.7581, -0.5252, -0.3422, -0.0782,  0.2571, -1.2543, -0.5938,
-        -0.9410,  0.6455, -0.9536, -0.5214, -0.3666,  0.0523, -0.6580, -0.5772,
-         0.5655,  0.2158, -0.2773, -0.4560, -0.2098,  1.5758, -0.8074, -0.7258,
-        -0.8647,  0.0232, -0.8497, -0.1368,  0.8938,  0.0000,  0.7224,  0.1350,
-        -0.0992,  1.1476, -0.1670,  0.7284,  0.6425,  0.4515,  0.7073,  0.0663,
-         0.8090, -0.5992,  0.4038, -0.0688,  0.5143,  0.2458,  0.3057,  0.4591,
-         1.1309,  0.9683, -0.3125, -0.6936,  1.2092, -0.2963, -0.2008, -0.0061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8656,  0.5115,  0.0939,  0.1335,  0.5718,  0.7014, -0.4412, -1.0328,
-        -0.4289, -0.7569, -0.5427, -0.3206, -0.1134,  0.2304, -1.2674, -0.5922,
-        -0.9348,  0.6377, -0.9505, -0.5183, -0.4073,  0.1053, -0.6441, -0.6184,
-         0.5584,  0.1972, -0.2746, -0.4430, -0.1966,  1.5780, -0.8044, -0.7284,
-        -0.8676,  0.0553, -0.8461, -0.1422,  0.9068,  0.0000,  0.7189,  0.1351,
-        -0.0901,  1.1528, -0.1569,  0.7295,  0.6330,  0.4215,  0.7081,  0.0564,
-         0.8018, -0.5989,  0.3900, -0.0458,  0.5043,  0.2656,  0.2728,  0.4339,
-         1.1404,  0.9667, -0.3008, -0.6888,  1.2145, -0.2633, -0.1884,  0.0172],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8656,  0.5115,  0.0939,  0.1335,  0.5718,  0.7014, -0.4412, -1.0328,
-        -0.4289, -0.7569, -0.5427, -0.3206, -0.1134,  0.2304, -1.2674, -0.5922,
-        -0.9348,  0.6377, -0.9505, -0.5183, -0.4073,  0.1053, -0.6441, -0.6184,
-         0.5584,  0.1972, -0.2746, -0.4430, -0.1966,  1.5780, -0.8044, -0.7284,
-        -0.8676,  0.0553, -0.8461, -0.1422,  0.9068,  0.0000,  0.7189,  0.1351,
-        -0.0901,  1.1528, -0.1569,  0.7295,  0.6330,  0.4215,  0.7081,  0.0564,
-         0.8018, -0.5989,  0.3900, -0.0458,  0.5043,  0.2656,  0.2728,  0.4339,
-         1.1404,  0.9667, -0.3008, -0.6888,  1.2145, -0.2633, -0.1884,  0.0172],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8656,  0.5115,  0.0939,  0.1335,  0.5718,  0.7014, -0.4412, -1.0328,
-        -0.4289, -0.7569, -0.5427, -0.3206, -0.1134,  0.2304, -1.2674, -0.5922,
-        -0.9348,  0.6377, -0.9505, -0.5183, -0.4073,  0.1053, -0.6441, -0.6184,
-         0.5584,  0.1972, -0.2746, -0.4430, -0.1966,  1.5780, -0.8044, -0.7284,
-        -0.8676,  0.0553, -0.8461, -0.1422,  0.9068,  0.0000,  0.7189,  0.1351,
-        -0.0901,  1.1528, -0.1569,  0.7295,  0.6330,  0.4215,  0.7081,  0.0564,
-         0.8018, -0.5989,  0.3900, -0.0458,  0.5043,  0.2656,  0.2728,  0.4339,
-         1.1404,  0.9667, -0.3008, -0.6888,  1.2145, -0.2633, -0.1884,  0.0172],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8669,  0.4829,  0.0972,  0.1472,  0.5625,  0.6824, -0.4260, -1.0295,
-        -0.4305, -0.7566, -0.5594, -0.2943, -0.0971,  0.2563, -1.2820, -0.5911,
-        -0.9294,  0.6229, -0.9481, -0.5163, -0.4501,  0.1345, -0.6443, -0.6559,
-         0.5519,  0.1859, -0.2537, -0.4325, -0.1762,  1.5793, -0.7995, -0.7330,
-        -0.8718,  0.0407, -0.8443, -0.1425,  0.9189,  0.0000,  0.7189,  0.1953,
-        -0.0826,  1.1557, -0.1064,  0.7344,  0.6248,  0.3809,  0.7055,  0.0529,
-         0.7995, -0.5960,  0.3688, -0.0358,  0.4932,  0.2845,  0.2478,  0.4123,
-         1.1489,  0.9678, -0.2617, -0.6868,  1.2202, -0.2200, -0.1738,  0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8669,  0.4829,  0.0972,  0.1472,  0.5625,  0.6824, -0.4260, -1.0295,
-        -0.4305, -0.7566, -0.5594, -0.2943, -0.0971,  0.2563, -1.2820, -0.5911,
-        -0.9294,  0.6229, -0.9481, -0.5163, -0.4501,  0.1345, -0.6443, -0.6559,
-         0.5519,  0.1859, -0.2537, -0.4325, -0.1762,  1.5793, -0.7995, -0.7330,
-        -0.8718,  0.0407, -0.8443, -0.1425,  0.9189,  0.0000,  0.7189,  0.1953,
-        -0.0826,  1.1557, -0.1064,  0.7344,  0.6248,  0.3809,  0.7055,  0.0529,
-         0.7995, -0.5960,  0.3688, -0.0358,  0.4932,  0.2845,  0.2478,  0.4123,
-         1.1489,  0.9678, -0.2617, -0.6868,  1.2202, -0.2200, -0.1738,  0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8669,  0.4829,  0.0972,  0.1472,  0.5625,  0.6824, -0.4260, -1.0295,
-        -0.4305, -0.7566, -0.5594, -0.2943, -0.0971,  0.2563, -1.2820, -0.5911,
-        -0.9294,  0.6229, -0.9481, -0.5163, -0.4501,  0.1345, -0.6443, -0.6559,
-         0.5519,  0.1859, -0.2537, -0.4325, -0.1762,  1.5793, -0.7995, -0.7330,
-        -0.8718,  0.0407, -0.8443, -0.1425,  0.9189,  0.0000,  0.7189,  0.1953,
-        -0.0826,  1.1557, -0.1064,  0.7344,  0.6248,  0.3809,  0.7055,  0.0529,
-         0.7995, -0.5960,  0.3688, -0.0358,  0.4932,  0.2845,  0.2478,  0.4123,
-         1.1489,  0.9678, -0.2617, -0.6868,  1.2202, -0.2200, -0.1738,  0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8688,  0.4379,  0.0994,  0.1414,  0.5560,  0.6651, -0.4047, -1.0253,
-        -0.4308, -0.7580, -0.5686, -0.2633, -0.0554,  0.2868, -1.2924, -0.5931,
-        -0.9239,  0.6103, -0.9454, -0.5114, -0.4967,  0.1284, -0.6470, -0.6966,
-         0.5330,  0.1831, -0.2245, -0.4145, -0.1705,  1.5809, -0.7938, -0.7410,
-        -0.8770, -0.0184, -0.8423, -0.1393,  0.9317,  0.0000,  0.7166,  0.2511,
-        -0.0946,  1.1601, -0.0619,  0.7387,  0.6193,  0.3414,  0.7009,  0.0547,
-         0.8046, -0.5957,  0.3517, -0.0678,  0.4722,  0.2912,  0.2280,  0.3978,
-         1.1567,  0.9689, -0.2237, -0.6884,  1.2253, -0.1650, -0.1686,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8688,  0.4379,  0.0994,  0.1414,  0.5560,  0.6651, -0.4047, -1.0253,
-        -0.4308, -0.7580, -0.5686, -0.2633, -0.0554,  0.2868, -1.2924, -0.5931,
-        -0.9239,  0.6103, -0.9454, -0.5114, -0.4967,  0.1284, -0.6470, -0.6966,
-         0.5330,  0.1831, -0.2245, -0.4145, -0.1705,  1.5809, -0.7938, -0.7410,
-        -0.8770, -0.0184, -0.8423, -0.1393,  0.9317,  0.0000,  0.7166,  0.2511,
-        -0.0946,  1.1601, -0.0619,  0.7387,  0.6193,  0.3414,  0.7009,  0.0547,
-         0.8046, -0.5957,  0.3517, -0.0678,  0.4722,  0.2912,  0.2280,  0.3978,
-         1.1567,  0.9689, -0.2237, -0.6884,  1.2253, -0.1650, -0.1686,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8688,  0.4379,  0.0994,  0.1414,  0.5560,  0.6651, -0.4047, -1.0253,
-        -0.4308, -0.7580, -0.5686, -0.2633, -0.0554,  0.2868, -1.2924, -0.5931,
-        -0.9239,  0.6103, -0.9454, -0.5114, -0.4967,  0.1284, -0.6470, -0.6966,
-         0.5330,  0.1831, -0.2245, -0.4145, -0.1705,  1.5809, -0.7938, -0.7410,
-        -0.8770, -0.0184, -0.8423, -0.1393,  0.9317,  0.0000,  0.7166,  0.2511,
-        -0.0946,  1.1601, -0.0619,  0.7387,  0.6193,  0.3414,  0.7009,  0.0547,
-         0.8046, -0.5957,  0.3517, -0.0678,  0.4722,  0.2912,  0.2280,  0.3978,
-         1.1567,  0.9689, -0.2237, -0.6884,  1.2253, -0.1650, -0.1686,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8696,  0.3970,  0.1143,  0.1278,  0.5514,  0.6516, -0.3824, -1.0206,
-        -0.4309, -0.7595, -0.5691, -0.2167,  0.0160,  0.3043, -1.3013, -0.5937,
-        -0.9225,  0.5987, -0.9449, -0.5065, -0.5425,  0.1188, -0.6465, -0.7227,
-         0.5062,  0.1883, -0.1891, -0.4015, -0.1601,  1.5817, -0.7894, -0.7479,
-        -0.8811, -0.0756, -0.8405, -0.1365,  0.9439,  0.0000,  0.7094,  0.3116,
-        -0.1168,  1.1625,  0.0049,  0.7410,  0.6145,  0.3042,  0.6945,  0.0510,
-         0.8084, -0.5940,  0.3379, -0.0902,  0.4497,  0.2889,  0.2170,  0.3915,
-         1.1637,  0.9719, -0.1641, -0.6805,  1.2303, -0.1053, -0.1621,  0.1267],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8696,  0.3970,  0.1143,  0.1278,  0.5514,  0.6516, -0.3824, -1.0206,
-        -0.4309, -0.7595, -0.5691, -0.2167,  0.0160,  0.3043, -1.3013, -0.5937,
-        -0.9225,  0.5987, -0.9449, -0.5065, -0.5425,  0.1188, -0.6465, -0.7227,
-         0.5062,  0.1883, -0.1891, -0.4015, -0.1601,  1.5817, -0.7894, -0.7479,
-        -0.8811, -0.0756, -0.8405, -0.1365,  0.9439,  0.0000,  0.7094,  0.3116,
-        -0.1168,  1.1625,  0.0049,  0.7410,  0.6145,  0.3042,  0.6945,  0.0510,
-         0.8084, -0.5940,  0.3379, -0.0902,  0.4497,  0.2889,  0.2170,  0.3915,
-         1.1637,  0.9719, -0.1641, -0.6805,  1.2303, -0.1053, -0.1621,  0.1267],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8696,  0.3970,  0.1143,  0.1278,  0.5514,  0.6516, -0.3824, -1.0206,
-        -0.4309, -0.7595, -0.5691, -0.2167,  0.0160,  0.3043, -1.3013, -0.5937,
-        -0.9225,  0.5987, -0.9449, -0.5065, -0.5425,  0.1188, -0.6465, -0.7227,
-         0.5062,  0.1883, -0.1891, -0.4015, -0.1601,  1.5817, -0.7894, -0.7479,
-        -0.8811, -0.0756, -0.8405, -0.1365,  0.9439,  0.0000,  0.7094,  0.3116,
-        -0.1168,  1.1625,  0.0049,  0.7410,  0.6145,  0.3042,  0.6945,  0.0510,
-         0.8084, -0.5940,  0.3379, -0.0902,  0.4497,  0.2889,  0.2170,  0.3915,
-         1.1637,  0.9719, -0.1641, -0.6805,  1.2303, -0.1053, -0.1621,  0.1267],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8698,  0.3611,  0.1099,  0.1090,  0.5275,  0.6427, -0.3652, -1.0166,
-        -0.4233, -0.7580, -0.5605, -0.1878,  0.0627,  0.3002, -1.3084, -0.5952,
-        -0.9189,  0.5894, -0.9454, -0.5083, -0.5770,  0.1144, -0.6397, -0.7255,
-         0.4741,  0.2004, -0.1680, -0.3891, -0.1354,  1.5824, -0.7827, -0.7523,
-        -0.8859, -0.0990, -0.8373, -0.1441,  0.9571,  0.0000,  0.7045,  0.3622,
-        -0.1115,  1.1636,  0.0488,  0.7433,  0.6091,  0.2751,  0.6872,  0.0442,
-         0.8123, -0.5903,  0.3236, -0.0826,  0.4268,  0.2764,  0.2027,  0.3790,
-         1.1707,  0.9743, -0.0914, -0.6543,  1.2349, -0.0513, -0.1588,  0.1596],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8698,  0.3611,  0.1099,  0.1090,  0.5275,  0.6427, -0.3652, -1.0166,
-        -0.4233, -0.7580, -0.5605, -0.1878,  0.0627,  0.3002, -1.3084, -0.5952,
-        -0.9189,  0.5894, -0.9454, -0.5083, -0.5770,  0.1144, -0.6397, -0.7255,
-         0.4741,  0.2004, -0.1680, -0.3891, -0.1354,  1.5824, -0.7827, -0.7523,
-        -0.8859, -0.0990, -0.8373, -0.1441,  0.9571,  0.0000,  0.7045,  0.3622,
-        -0.1115,  1.1636,  0.0488,  0.7433,  0.6091,  0.2751,  0.6872,  0.0442,
-         0.8123, -0.5903,  0.3236, -0.0826,  0.4268,  0.2764,  0.2027,  0.3790,
-         1.1707,  0.9743, -0.0914, -0.6543,  1.2349, -0.0513, -0.1588,  0.1596],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8698,  0.3611,  0.1099,  0.1090,  0.5275,  0.6427, -0.3652, -1.0166,
-        -0.4233, -0.7580, -0.5605, -0.1878,  0.0627,  0.3002, -1.3084, -0.5952,
-        -0.9189,  0.5894, -0.9454, -0.5083, -0.5770,  0.1144, -0.6397, -0.7255,
-         0.4741,  0.2004, -0.1680, -0.3891, -0.1354,  1.5824, -0.7827, -0.7523,
-        -0.8859, -0.0990, -0.8373, -0.1441,  0.9571,  0.0000,  0.7045,  0.3622,
-        -0.1115,  1.1636,  0.0488,  0.7433,  0.6091,  0.2751,  0.6872,  0.0442,
-         0.8123, -0.5903,  0.3236, -0.0826,  0.4268,  0.2764,  0.2027,  0.3790,
-         1.1707,  0.9743, -0.0914, -0.6543,  1.2349, -0.0513, -0.1588,  0.1596],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8704,  0.3286,  0.1077,  0.0956,  0.4878,  0.6359, -0.3478, -1.0133,
-        -0.4153, -0.7529, -0.5467, -0.1702,  0.1164,  0.3137, -1.3147, -0.5960,
-        -0.9108,  0.5797, -0.9463, -0.5160, -0.5981,  0.1140, -0.6267, -0.7357,
-         0.4419,  0.2234, -0.1568, -0.3778, -0.1140,  1.5825, -0.7775, -0.7555,
-        -0.8920, -0.0965, -0.8347, -0.1692,  0.9711,  0.0000,  0.7129,  0.4328,
-        -0.0760,  1.1650,  0.0857,  0.7479,  0.6016,  0.2381,  0.6803,  0.0424,
-         0.8087, -0.5866,  0.3149, -0.0707,  0.3964,  0.2709,  0.1915,  0.3780,
-         1.1764,  0.9833, -0.0193, -0.6249,  1.2403,  0.0029, -0.1500,  0.1741],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8704,  0.3286,  0.1077,  0.0956,  0.4878,  0.6359, -0.3478, -1.0133,
-        -0.4153, -0.7529, -0.5467, -0.1702,  0.1164,  0.3137, -1.3147, -0.5960,
-        -0.9108,  0.5797, -0.9463, -0.5160, -0.5981,  0.1140, -0.6267, -0.7357,
-         0.4419,  0.2234, -0.1568, -0.3778, -0.1140,  1.5825, -0.7775, -0.7555,
-        -0.8920, -0.0965, -0.8347, -0.1692,  0.9711,  0.0000,  0.7129,  0.4328,
-        -0.0760,  1.1650,  0.0857,  0.7479,  0.6016,  0.2381,  0.6803,  0.0424,
-         0.8087, -0.5866,  0.3149, -0.0707,  0.3964,  0.2709,  0.1915,  0.3780,
-         1.1764,  0.9833, -0.0193, -0.6249,  1.2403,  0.0029, -0.1500,  0.1741],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8704,  0.3286,  0.1077,  0.0956,  0.4878,  0.6359, -0.3478, -1.0133,
-        -0.4153, -0.7529, -0.5467, -0.1702,  0.1164,  0.3137, -1.3147, -0.5960,
-        -0.9108,  0.5797, -0.9463, -0.5160, -0.5981,  0.1140, -0.6267, -0.7357,
-         0.4419,  0.2234, -0.1568, -0.3778, -0.1140,  1.5825, -0.7775, -0.7555,
-        -0.8920, -0.0965, -0.8347, -0.1692,  0.9711,  0.0000,  0.7129,  0.4328,
-        -0.0760,  1.1650,  0.0857,  0.7479,  0.6016,  0.2381,  0.6803,  0.0424,
-         0.8087, -0.5866,  0.3149, -0.0707,  0.3964,  0.2709,  0.1915,  0.3780,
-         1.1764,  0.9833, -0.0193, -0.6249,  1.2403,  0.0029, -0.1500,  0.1741],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8679,  0.3139,  0.0929,  0.0541,  0.4245,  0.6272, -0.3308, -1.0107,
-        -0.4154, -0.7463, -0.5303, -0.1639,  0.1226,  0.3212, -1.3224, -0.5940,
-        -0.9024,  0.5757, -0.9478, -0.5246, -0.6207,  0.1331, -0.6115, -0.7372,
-         0.4204,  0.2472, -0.1458, -0.3620, -0.1023,  1.5820, -0.7730, -0.7567,
-        -0.8995, -0.1088, -0.8316, -0.1923,  0.9837,  0.0000,  0.7192,  0.4981,
-        -0.0179,  1.1649,  0.0936,  0.7506,  0.5945,  0.2056,  0.6726,  0.0391,
-         0.8051, -0.5816,  0.3122, -0.0646,  0.3859,  0.2674,  0.2138,  0.3733,
-         1.1825,  0.9923,  0.0642, -0.6004,  1.2448,  0.0722, -0.1372,  0.1711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8679,  0.3139,  0.0929,  0.0541,  0.4245,  0.6272, -0.3308, -1.0107,
-        -0.4154, -0.7463, -0.5303, -0.1639,  0.1226,  0.3212, -1.3224, -0.5940,
-        -0.9024,  0.5757, -0.9478, -0.5246, -0.6207,  0.1331, -0.6115, -0.7372,
-         0.4204,  0.2472, -0.1458, -0.3620, -0.1023,  1.5820, -0.7730, -0.7567,
-        -0.8995, -0.1088, -0.8316, -0.1923,  0.9837,  0.0000,  0.7192,  0.4981,
-        -0.0179,  1.1649,  0.0936,  0.7506,  0.5945,  0.2056,  0.6726,  0.0391,
-         0.8051, -0.5816,  0.3122, -0.0646,  0.3859,  0.2674,  0.2138,  0.3733,
-         1.1825,  0.9923,  0.0642, -0.6004,  1.2448,  0.0722, -0.1372,  0.1711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8679,  0.3139,  0.0929,  0.0541,  0.4245,  0.6272, -0.3308, -1.0107,
-        -0.4154, -0.7463, -0.5303, -0.1639,  0.1226,  0.3212, -1.3224, -0.5940,
-        -0.9024,  0.5757, -0.9478, -0.5246, -0.6207,  0.1331, -0.6115, -0.7372,
-         0.4204,  0.2472, -0.1458, -0.3620, -0.1023,  1.5820, -0.7730, -0.7567,
-        -0.8995, -0.1088, -0.8316, -0.1923,  0.9837,  0.0000,  0.7192,  0.4981,
-        -0.0179,  1.1649,  0.0936,  0.7506,  0.5945,  0.2056,  0.6726,  0.0391,
-         0.8051, -0.5816,  0.3122, -0.0646,  0.3859,  0.2674,  0.2138,  0.3733,
-         1.1825,  0.9923,  0.0642, -0.6004,  1.2448,  0.0722, -0.1372,  0.1711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8672,  0.3030,  0.0648,  0.0193,  0.3642,  0.6214, -0.3156, -1.0073,
-        -0.4101, -0.7380, -0.5107, -0.1470,  0.1068,  0.3307, -1.3287, -0.5852,
-        -0.8922,  0.5744, -0.9503, -0.5357, -0.6376,  0.1431, -0.5926, -0.7413,
-         0.3965,  0.2771, -0.1334, -0.3441, -0.1018,  1.5810, -0.7673, -0.7574,
-        -0.9065, -0.1061, -0.8263, -0.2126,  0.9943,  0.0000,  0.7328,  0.5667,
-         0.0477,  1.1632,  0.0676,  0.7547,  0.5879,  0.1770,  0.6652,  0.0417,
-         0.8048, -0.5754,  0.3123, -0.0399,  0.3696,  0.2673,  0.2639,  0.3743,
-         1.1871,  1.0038,  0.1352, -0.5730,  1.2491,  0.1186, -0.1312,  0.1505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8672,  0.3030,  0.0648,  0.0193,  0.3642,  0.6214, -0.3156, -1.0073,
-        -0.4101, -0.7380, -0.5107, -0.1470,  0.1068,  0.3307, -1.3287, -0.5852,
-        -0.8922,  0.5744, -0.9503, -0.5357, -0.6376,  0.1431, -0.5926, -0.7413,
-         0.3965,  0.2771, -0.1334, -0.3441, -0.1018,  1.5810, -0.7673, -0.7574,
-        -0.9065, -0.1061, -0.8263, -0.2126,  0.9943,  0.0000,  0.7328,  0.5667,
-         0.0477,  1.1632,  0.0676,  0.7547,  0.5879,  0.1770,  0.6652,  0.0417,
-         0.8048, -0.5754,  0.3123, -0.0399,  0.3696,  0.2673,  0.2639,  0.3743,
-         1.1871,  1.0038,  0.1352, -0.5730,  1.2491,  0.1186, -0.1312,  0.1505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8672,  0.3030,  0.0648,  0.0193,  0.3642,  0.6214, -0.3156, -1.0073,
-        -0.4101, -0.7380, -0.5107, -0.1470,  0.1068,  0.3307, -1.3287, -0.5852,
-        -0.8922,  0.5744, -0.9503, -0.5357, -0.6376,  0.1431, -0.5926, -0.7413,
-         0.3965,  0.2771, -0.1334, -0.3441, -0.1018,  1.5810, -0.7673, -0.7574,
-        -0.9065, -0.1061, -0.8263, -0.2126,  0.9943,  0.0000,  0.7328,  0.5667,
-         0.0477,  1.1632,  0.0676,  0.7547,  0.5879,  0.1770,  0.6652,  0.0417,
-         0.8048, -0.5754,  0.3123, -0.0399,  0.3696,  0.2673,  0.2639,  0.3743,
-         1.1871,  1.0038,  0.1352, -0.5730,  1.2491,  0.1186, -0.1312,  0.1505],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8614,  0.3050,  0.0419, -0.0145,  0.3315,  0.6065, -0.3018, -1.0037,
-        -0.4066, -0.7240, -0.4917, -0.1261,  0.0862,  0.3117, -1.3317, -0.5752,
-        -0.8818,  0.5694, -0.9537, -0.5408, -0.6529,  0.1675, -0.5670, -0.7463,
-         0.3711,  0.2870, -0.1297, -0.3210, -0.1083,  1.5799, -0.7640, -0.7552,
-        -0.9134, -0.0812, -0.8187, -0.2431,  1.0043,  0.0000,  0.7388,  0.6328,
-         0.1135,  1.1619,  0.0513,  0.7566,  0.5826,  0.1513,  0.6596,  0.0444,
-         0.8066, -0.5697,  0.3179,  0.0136,  0.3618,  0.2705,  0.3170,  0.3688,
-         1.1913,  1.0124,  0.1925, -0.5293,  1.2533,  0.1649, -0.1249,  0.1178],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8614,  0.3050,  0.0419, -0.0145,  0.3315,  0.6065, -0.3018, -1.0037,
-        -0.4066, -0.7240, -0.4917, -0.1261,  0.0862,  0.3117, -1.3317, -0.5752,
-        -0.8818,  0.5694, -0.9537, -0.5408, -0.6529,  0.1675, -0.5670, -0.7463,
-         0.3711,  0.2870, -0.1297, -0.3210, -0.1083,  1.5799, -0.7640, -0.7552,
-        -0.9134, -0.0812, -0.8187, -0.2431,  1.0043,  0.0000,  0.7388,  0.6328,
-         0.1135,  1.1619,  0.0513,  0.7566,  0.5826,  0.1513,  0.6596,  0.0444,
-         0.8066, -0.5697,  0.3179,  0.0136,  0.3618,  0.2705,  0.3170,  0.3688,
-         1.1913,  1.0124,  0.1925, -0.5293,  1.2533,  0.1649, -0.1249,  0.1178],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8614,  0.3050,  0.0419, -0.0145,  0.3315,  0.6065, -0.3018, -1.0037,
-        -0.4066, -0.7240, -0.4917, -0.1261,  0.0862,  0.3117, -1.3317, -0.5752,
-        -0.8818,  0.5694, -0.9537, -0.5408, -0.6529,  0.1675, -0.5670, -0.7463,
-         0.3711,  0.2870, -0.1297, -0.3210, -0.1083,  1.5799, -0.7640, -0.7552,
-        -0.9134, -0.0812, -0.8187, -0.2431,  1.0043,  0.0000,  0.7388,  0.6328,
-         0.1135,  1.1619,  0.0513,  0.7566,  0.5826,  0.1513,  0.6596,  0.0444,
-         0.8066, -0.5697,  0.3179,  0.0136,  0.3618,  0.2705,  0.3170,  0.3688,
-         1.1913,  1.0124,  0.1925, -0.5293,  1.2533,  0.1649, -0.1249,  0.1178],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8566,  0.3113,  0.0420, -0.0425,  0.3402,  0.5804, -0.3005, -0.9985,
-        -0.4009, -0.7087, -0.4680, -0.0930,  0.0949,  0.2933, -1.3351, -0.5605,
-        -0.8733,  0.5662, -0.9572, -0.5380, -0.6701,  0.1948, -0.5460, -0.7614,
-         0.3433,  0.2954, -0.1110, -0.3044, -0.0967,  1.5793, -0.7585, -0.7508,
-        -0.9182, -0.0417, -0.8074, -0.2594,  1.0123,  0.0000,  0.7393,  0.6916,
-         0.1578,  1.1618,  0.0555,  0.7544,  0.5799,  0.1336,  0.6559,  0.0417,
-         0.8091, -0.5644,  0.3254,  0.1106,  0.3542,  0.2757,  0.3596,  0.3612,
-         1.1942,  1.0189,  0.2592, -0.4954,  1.2575,  0.2151, -0.1040,  0.1039],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8566,  0.3113,  0.0420, -0.0425,  0.3402,  0.5804, -0.3005, -0.9985,
-        -0.4009, -0.7087, -0.4680, -0.0930,  0.0949,  0.2933, -1.3351, -0.5605,
-        -0.8733,  0.5662, -0.9572, -0.5380, -0.6701,  0.1948, -0.5460, -0.7614,
-         0.3433,  0.2954, -0.1110, -0.3044, -0.0967,  1.5793, -0.7585, -0.7508,
-        -0.9182, -0.0417, -0.8074, -0.2594,  1.0123,  0.0000,  0.7393,  0.6916,
-         0.1578,  1.1618,  0.0555,  0.7544,  0.5799,  0.1336,  0.6559,  0.0417,
-         0.8091, -0.5644,  0.3254,  0.1106,  0.3542,  0.2757,  0.3596,  0.3612,
-         1.1942,  1.0189,  0.2592, -0.4954,  1.2575,  0.2151, -0.1040,  0.1039],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8566,  0.3113,  0.0420, -0.0425,  0.3402,  0.5804, -0.3005, -0.9985,
-        -0.4009, -0.7087, -0.4680, -0.0930,  0.0949,  0.2933, -1.3351, -0.5605,
-        -0.8733,  0.5662, -0.9572, -0.5380, -0.6701,  0.1948, -0.5460, -0.7614,
-         0.3433,  0.2954, -0.1110, -0.3044, -0.0967,  1.5793, -0.7585, -0.7508,
-        -0.9182, -0.0417, -0.8074, -0.2594,  1.0123,  0.0000,  0.7393,  0.6916,
-         0.1578,  1.1618,  0.0555,  0.7544,  0.5799,  0.1336,  0.6559,  0.0417,
-         0.8091, -0.5644,  0.3254,  0.1106,  0.3542,  0.2757,  0.3596,  0.3612,
-         1.1942,  1.0189,  0.2592, -0.4954,  1.2575,  0.2151, -0.1040,  0.1039],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8527,  0.3075,  0.0538, -0.0763,  0.3537,  0.5605, -0.2848, -0.9932,
-        -0.3962, -0.6895, -0.4484, -0.0556,  0.1198,  0.2800, -1.3415, -0.5452,
-        -0.8655,  0.5664, -0.9609, -0.5298, -0.6882,  0.2132, -0.5262, -0.7765,
-         0.3103,  0.2877, -0.0869, -0.2927, -0.0813,  1.5782, -0.7500, -0.7489,
-        -0.9225, -0.0154, -0.7969, -0.2784,  1.0154,  0.0000,  0.7373,  0.7465,
-         0.1914,  1.1638,  0.0705,  0.7518,  0.5802,  0.1237,  0.6503,  0.0413,
-         0.8054, -0.5587,  0.3282,  0.1911,  0.3487,  0.2673,  0.4117,  0.3498,
-         1.1965,  1.0253,  0.3289, -0.4747,  1.2622,  0.2682, -0.0792,  0.1003],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8527,  0.3075,  0.0538, -0.0763,  0.3537,  0.5605, -0.2848, -0.9932,
-        -0.3962, -0.6895, -0.4484, -0.0556,  0.1198,  0.2800, -1.3415, -0.5452,
-        -0.8655,  0.5664, -0.9609, -0.5298, -0.6882,  0.2132, -0.5262, -0.7765,
-         0.3103,  0.2877, -0.0869, -0.2927, -0.0813,  1.5782, -0.7500, -0.7489,
-        -0.9225, -0.0154, -0.7969, -0.2784,  1.0154,  0.0000,  0.7373,  0.7465,
-         0.1914,  1.1638,  0.0705,  0.7518,  0.5802,  0.1237,  0.6503,  0.0413,
-         0.8054, -0.5587,  0.3282,  0.1911,  0.3487,  0.2673,  0.4117,  0.3498,
-         1.1965,  1.0253,  0.3289, -0.4747,  1.2622,  0.2682, -0.0792,  0.1003],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8527,  0.3075,  0.0538, -0.0763,  0.3537,  0.5605, -0.2848, -0.9932,
-        -0.3962, -0.6895, -0.4484, -0.0556,  0.1198,  0.2800, -1.3415, -0.5452,
-        -0.8655,  0.5664, -0.9609, -0.5298, -0.6882,  0.2132, -0.5262, -0.7765,
-         0.3103,  0.2877, -0.0869, -0.2927, -0.0813,  1.5782, -0.7500, -0.7489,
-        -0.9225, -0.0154, -0.7969, -0.2784,  1.0154,  0.0000,  0.7373,  0.7465,
-         0.1914,  1.1638,  0.0705,  0.7518,  0.5802,  0.1237,  0.6503,  0.0413,
-         0.8054, -0.5587,  0.3282,  0.1911,  0.3487,  0.2673,  0.4117,  0.3498,
-         1.1965,  1.0253,  0.3289, -0.4747,  1.2622,  0.2682, -0.0792,  0.1003],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8489,  0.3025,  0.0578, -0.1292,  0.3443,  0.5440, -0.2672, -0.9873,
-        -0.4080, -0.6745, -0.4315,  0.0092,  0.1380,  0.2704, -1.3479, -0.5388,
-        -0.8563,  0.5678, -0.9633, -0.5123, -0.7147,  0.2302, -0.5135, -0.7945,
-         0.2699,  0.2711, -0.0704, -0.2985, -0.0735,  1.5769, -0.7365, -0.7470,
-        -0.9264, -0.0130, -0.7855, -0.2727,  1.0135,  0.0000,  0.7217,  0.7883,
-         0.2323,  1.1661,  0.0815,  0.7420,  0.5806,  0.1357,  0.6404,  0.0411,
-         0.8017, -0.5518,  0.3246,  0.2735,  0.3493,  0.2570,  0.4444,  0.3352,
-         1.1985,  1.0291,  0.3760, -0.4897,  1.2661,  0.3257, -0.0445,  0.1112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8489,  0.3025,  0.0578, -0.1292,  0.3443,  0.5440, -0.2672, -0.9873,
-        -0.4080, -0.6745, -0.4315,  0.0092,  0.1380,  0.2704, -1.3479, -0.5388,
-        -0.8563,  0.5678, -0.9633, -0.5123, -0.7147,  0.2302, -0.5135, -0.7945,
-         0.2699,  0.2711, -0.0704, -0.2985, -0.0735,  1.5769, -0.7365, -0.7470,
-        -0.9264, -0.0130, -0.7855, -0.2727,  1.0135,  0.0000,  0.7217,  0.7883,
-         0.2323,  1.1661,  0.0815,  0.7420,  0.5806,  0.1357,  0.6404,  0.0411,
-         0.8017, -0.5518,  0.3246,  0.2735,  0.3493,  0.2570,  0.4444,  0.3352,
-         1.1985,  1.0291,  0.3760, -0.4897,  1.2661,  0.3257, -0.0445,  0.1112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8489,  0.3025,  0.0578, -0.1292,  0.3443,  0.5440, -0.2672, -0.9873,
-        -0.4080, -0.6745, -0.4315,  0.0092,  0.1380,  0.2704, -1.3479, -0.5388,
-        -0.8563,  0.5678, -0.9633, -0.5123, -0.7147,  0.2302, -0.5135, -0.7945,
-         0.2699,  0.2711, -0.0704, -0.2985, -0.0735,  1.5769, -0.7365, -0.7470,
-        -0.9264, -0.0130, -0.7855, -0.2727,  1.0135,  0.0000,  0.7217,  0.7883,
-         0.2323,  1.1661,  0.0815,  0.7420,  0.5806,  0.1357,  0.6404,  0.0411,
-         0.8017, -0.5518,  0.3246,  0.2735,  0.3493,  0.2570,  0.4444,  0.3352,
-         1.1985,  1.0291,  0.3760, -0.4897,  1.2661,  0.3257, -0.0445,  0.1112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8425,  0.2976,  0.0388, -0.1585,  0.3094,  0.5290, -0.2437, -0.9823,
-        -0.4177, -0.6654, -0.4195,  0.0475,  0.1327,  0.2572, -1.3518, -0.5333,
-        -0.8482,  0.5650, -0.9676, -0.5022, -0.7319,  0.2373, -0.5030, -0.8043,
-         0.2128,  0.2431, -0.0533, -0.2969, -0.0725,  1.5760, -0.7242, -0.7457,
-        -0.9287,  0.0127, -0.7745, -0.2500,  1.0096,  0.0000,  0.6997,  0.8206,
-         0.2632,  1.1646,  0.0770,  0.7333,  0.5830,  0.1496,  0.6319,  0.0279,
-         0.8071, -0.5447,  0.3186,  0.2982,  0.3509,  0.2218,  0.4740,  0.3247,
-         1.2011,  1.0300,  0.4202, -0.4719,  1.2688,  0.3675, -0.0295,  0.1025],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8425,  0.2976,  0.0388, -0.1585,  0.3094,  0.5290, -0.2437, -0.9823,
-        -0.4177, -0.6654, -0.4195,  0.0475,  0.1327,  0.2572, -1.3518, -0.5333,
-        -0.8482,  0.5650, -0.9676, -0.5022, -0.7319,  0.2373, -0.5030, -0.8043,
-         0.2128,  0.2431, -0.0533, -0.2969, -0.0725,  1.5760, -0.7242, -0.7457,
-        -0.9287,  0.0127, -0.7745, -0.2500,  1.0096,  0.0000,  0.6997,  0.8206,
-         0.2632,  1.1646,  0.0770,  0.7333,  0.5830,  0.1496,  0.6319,  0.0279,
-         0.8071, -0.5447,  0.3186,  0.2982,  0.3509,  0.2218,  0.4740,  0.3247,
-         1.2011,  1.0300,  0.4202, -0.4719,  1.2688,  0.3675, -0.0295,  0.1025],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8425,  0.2976,  0.0388, -0.1585,  0.3094,  0.5290, -0.2437, -0.9823,
-        -0.4177, -0.6654, -0.4195,  0.0475,  0.1327,  0.2572, -1.3518, -0.5333,
-        -0.8482,  0.5650, -0.9676, -0.5022, -0.7319,  0.2373, -0.5030, -0.8043,
-         0.2128,  0.2431, -0.0533, -0.2969, -0.0725,  1.5760, -0.7242, -0.7457,
-        -0.9287,  0.0127, -0.7745, -0.2500,  1.0096,  0.0000,  0.6997,  0.8206,
-         0.2632,  1.1646,  0.0770,  0.7333,  0.5830,  0.1496,  0.6319,  0.0279,
-         0.8071, -0.5447,  0.3186,  0.2982,  0.3509,  0.2218,  0.4740,  0.3247,
-         1.2011,  1.0300,  0.4202, -0.4719,  1.2688,  0.3675, -0.0295,  0.1025],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8375,  0.2833, -0.0045, -0.1650,  0.2715,  0.5186, -0.2251, -0.9785,
-        -0.4229, -0.6607, -0.4137,  0.0199,  0.0827,  0.2623, -1.3560, -0.5316,
-        -0.8404,  0.5644, -0.9704, -0.4959, -0.7453,  0.2144, -0.4864, -0.8171,
-         0.1376,  0.2319, -0.0331, -0.2929, -0.0817,  1.5764, -0.7156, -0.7394,
-        -0.9297,  0.0513, -0.7637, -0.2271,  1.0110,  0.0000,  0.6878,  0.8470,
-         0.2675,  1.1607,  0.0411,  0.7291,  0.5805,  0.1486,  0.6226,  0.0385,
-         0.8169, -0.5387,  0.3069,  0.2822,  0.3429,  0.1732,  0.4923,  0.3309,
-         1.2040,  1.0328,  0.4562, -0.4657,  1.2727,  0.3888, -0.0144,  0.0939],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8375,  0.2833, -0.0045, -0.1650,  0.2715,  0.5186, -0.2251, -0.9785,
-        -0.4229, -0.6607, -0.4137,  0.0199,  0.0827,  0.2623, -1.3560, -0.5316,
-        -0.8404,  0.5644, -0.9704, -0.4959, -0.7453,  0.2144, -0.4864, -0.8171,
-         0.1376,  0.2319, -0.0331, -0.2929, -0.0817,  1.5764, -0.7156, -0.7394,
-        -0.9297,  0.0513, -0.7637, -0.2271,  1.0110,  0.0000,  0.6878,  0.8470,
-         0.2675,  1.1607,  0.0411,  0.7291,  0.5805,  0.1486,  0.6226,  0.0385,
-         0.8169, -0.5387,  0.3069,  0.2822,  0.3429,  0.1732,  0.4923,  0.3309,
-         1.2040,  1.0328,  0.4562, -0.4657,  1.2727,  0.3888, -0.0144,  0.0939],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8375,  0.2833, -0.0045, -0.1650,  0.2715,  0.5186, -0.2251, -0.9785,
-        -0.4229, -0.6607, -0.4137,  0.0199,  0.0827,  0.2623, -1.3560, -0.5316,
-        -0.8404,  0.5644, -0.9704, -0.4959, -0.7453,  0.2144, -0.4864, -0.8171,
-         0.1376,  0.2319, -0.0331, -0.2929, -0.0817,  1.5764, -0.7156, -0.7394,
-        -0.9297,  0.0513, -0.7637, -0.2271,  1.0110,  0.0000,  0.6878,  0.8470,
-         0.2675,  1.1607,  0.0411,  0.7291,  0.5805,  0.1486,  0.6226,  0.0385,
-         0.8169, -0.5387,  0.3069,  0.2822,  0.3429,  0.1732,  0.4923,  0.3309,
-         1.2040,  1.0328,  0.4562, -0.4657,  1.2727,  0.3888, -0.0144,  0.0939],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.3317e-01,  2.7207e-01, -5.6172e-02, -1.6518e-01,  2.6192e-01,
-         5.1705e-01, -2.3862e-01, -9.7664e-01, -4.2490e-01, -6.5969e-01,
-        -4.0968e-01, -4.4241e-03,  4.9412e-02,  2.7385e-01, -1.3576e+00,
-        -5.3574e-01, -8.3907e-01,  5.5970e-01, -9.7616e-01, -4.9480e-01,
-        -7.5499e-01,  1.6080e-01, -4.7468e-01, -8.2654e-01,  4.3112e-02,
-         2.3072e-01, -1.5595e-02, -2.9040e-01, -8.6214e-02,  1.5775e+00,
-        -7.1273e-01, -7.3383e-01, -9.3075e-01,  5.6695e-02, -7.5314e-01,
-        -2.0307e-01,  1.0139e+00,  0.0000e+00,  6.7313e-01,  8.7261e-01,
-         2.5058e-01,  1.1540e+00,  5.1844e-03,  7.2862e-01,  5.7412e-01,
-         1.4032e-01,  6.1620e-01,  4.9613e-02,  8.2662e-01, -5.3755e-01,
-         2.8893e-01,  2.3396e-01,  3.2864e-01,  1.0818e-01,  5.0813e-01,
-         3.3818e-01,  1.2068e+00,  1.0357e+00,  4.8912e-01, -4.2889e-01,
-         1.2759e+00,  3.9636e-01, -3.2507e-04,  7.9335e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.3317e-01,  2.7207e-01, -5.6172e-02, -1.6518e-01,  2.6192e-01,
-         5.1705e-01, -2.3862e-01, -9.7664e-01, -4.2490e-01, -6.5969e-01,
-        -4.0968e-01, -4.4241e-03,  4.9412e-02,  2.7385e-01, -1.3576e+00,
-        -5.3574e-01, -8.3907e-01,  5.5970e-01, -9.7616e-01, -4.9480e-01,
-        -7.5499e-01,  1.6080e-01, -4.7468e-01, -8.2654e-01,  4.3112e-02,
-         2.3072e-01, -1.5595e-02, -2.9040e-01, -8.6214e-02,  1.5775e+00,
-        -7.1273e-01, -7.3383e-01, -9.3075e-01,  5.6695e-02, -7.5314e-01,
-        -2.0307e-01,  1.0139e+00,  0.0000e+00,  6.7313e-01,  8.7261e-01,
-         2.5058e-01,  1.1540e+00,  5.1844e-03,  7.2862e-01,  5.7412e-01,
-         1.4032e-01,  6.1620e-01,  4.9613e-02,  8.2662e-01, -5.3755e-01,
-         2.8893e-01,  2.3396e-01,  3.2864e-01,  1.0818e-01,  5.0813e-01,
-         3.3818e-01,  1.2068e+00,  1.0357e+00,  4.8912e-01, -4.2889e-01,
-         1.2759e+00,  3.9636e-01, -3.2507e-04,  7.9335e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.3317e-01,  2.7207e-01, -5.6172e-02, -1.6518e-01,  2.6192e-01,
-         5.1705e-01, -2.3862e-01, -9.7664e-01, -4.2490e-01, -6.5969e-01,
-        -4.0968e-01, -4.4241e-03,  4.9412e-02,  2.7385e-01, -1.3576e+00,
-        -5.3574e-01, -8.3907e-01,  5.5970e-01, -9.7616e-01, -4.9480e-01,
-        -7.5499e-01,  1.6080e-01, -4.7468e-01, -8.2654e-01,  4.3112e-02,
-         2.3072e-01, -1.5595e-02, -2.9040e-01, -8.6214e-02,  1.5775e+00,
-        -7.1273e-01, -7.3383e-01, -9.3075e-01,  5.6695e-02, -7.5314e-01,
-        -2.0307e-01,  1.0139e+00,  0.0000e+00,  6.7313e-01,  8.7261e-01,
-         2.5058e-01,  1.1540e+00,  5.1844e-03,  7.2862e-01,  5.7412e-01,
-         1.4032e-01,  6.1620e-01,  4.9613e-02,  8.2662e-01, -5.3755e-01,
-         2.8893e-01,  2.3396e-01,  3.2864e-01,  1.0818e-01,  5.0813e-01,
-         3.3818e-01,  1.2068e+00,  1.0357e+00,  4.8912e-01, -4.2889e-01,
-         1.2759e+00,  3.9636e-01, -3.2507e-04,  7.9335e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.3011e-01,  2.6457e-01, -1.0603e-01, -1.8032e-01,  2.2555e-01,
-         5.2373e-01, -2.6862e-01, -9.7533e-01, -4.2992e-01, -6.5976e-01,
-        -4.0496e-01, -2.8469e-02,  3.5085e-02,  2.7971e-01, -1.3568e+00,
-        -5.4298e-01, -8.4320e-01,  5.5562e-01, -9.8359e-01, -4.8837e-01,
-        -7.5816e-01,  6.6743e-02, -4.5767e-01, -8.2435e-01, -5.9995e-02,
-         2.2646e-01,  1.3474e-02, -2.7961e-01, -8.0743e-02,  1.5793e+00,
-        -7.1144e-01, -7.2887e-01, -9.3148e-01,  4.1963e-02, -7.4577e-01,
-        -1.7196e-01,  1.0116e+00,  0.0000e+00,  6.3948e-01,  8.9107e-01,
-         2.2907e-01,  1.1469e+00, -3.3750e-02,  7.2373e-01,  5.7125e-01,
-         1.4722e-01,  6.1618e-01,  5.0592e-02,  8.3478e-01, -5.3777e-01,
-         2.7441e-01,  1.5093e-01,  3.1038e-01,  2.9085e-02,  5.2145e-01,
-         3.4857e-01,  1.2088e+00,  1.0340e+00,  5.4875e-01, -3.8305e-01,
-         1.2771e+00,  4.0332e-01,  4.0163e-04,  6.3698e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.3011e-01,  2.6457e-01, -1.0603e-01, -1.8032e-01,  2.2555e-01,
-         5.2373e-01, -2.6862e-01, -9.7533e-01, -4.2992e-01, -6.5976e-01,
-        -4.0496e-01, -2.8469e-02,  3.5085e-02,  2.7971e-01, -1.3568e+00,
-        -5.4298e-01, -8.4320e-01,  5.5562e-01, -9.8359e-01, -4.8837e-01,
-        -7.5816e-01,  6.6743e-02, -4.5767e-01, -8.2435e-01, -5.9995e-02,
-         2.2646e-01,  1.3474e-02, -2.7961e-01, -8.0743e-02,  1.5793e+00,
-        -7.1144e-01, -7.2887e-01, -9.3148e-01,  4.1963e-02, -7.4577e-01,
-        -1.7196e-01,  1.0116e+00,  0.0000e+00,  6.3948e-01,  8.9107e-01,
-         2.2907e-01,  1.1469e+00, -3.3750e-02,  7.2373e-01,  5.7125e-01,
-         1.4722e-01,  6.1618e-01,  5.0592e-02,  8.3478e-01, -5.3777e-01,
-         2.7441e-01,  1.5093e-01,  3.1038e-01,  2.9085e-02,  5.2145e-01,
-         3.4857e-01,  1.2088e+00,  1.0340e+00,  5.4875e-01, -3.8305e-01,
-         1.2771e+00,  4.0332e-01,  4.0163e-04,  6.3698e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.3011e-01,  2.6457e-01, -1.0603e-01, -1.8032e-01,  2.2555e-01,
-         5.2373e-01, -2.6862e-01, -9.7533e-01, -4.2992e-01, -6.5976e-01,
-        -4.0496e-01, -2.8469e-02,  3.5085e-02,  2.7971e-01, -1.3568e+00,
-        -5.4298e-01, -8.4320e-01,  5.5562e-01, -9.8359e-01, -4.8837e-01,
-        -7.5816e-01,  6.6743e-02, -4.5767e-01, -8.2435e-01, -5.9995e-02,
-         2.2646e-01,  1.3474e-02, -2.7961e-01, -8.0743e-02,  1.5793e+00,
-        -7.1144e-01, -7.2887e-01, -9.3148e-01,  4.1963e-02, -7.4577e-01,
-        -1.7196e-01,  1.0116e+00,  0.0000e+00,  6.3948e-01,  8.9107e-01,
-         2.2907e-01,  1.1469e+00, -3.3750e-02,  7.2373e-01,  5.7125e-01,
-         1.4722e-01,  6.1618e-01,  5.0592e-02,  8.3478e-01, -5.3777e-01,
-         2.7441e-01,  1.5093e-01,  3.1038e-01,  2.9085e-02,  5.2145e-01,
-         3.4857e-01,  1.2088e+00,  1.0340e+00,  5.4875e-01, -3.8305e-01,
-         1.2771e+00,  4.0332e-01,  4.0163e-04,  6.3698e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8308,  0.2676, -0.1373, -0.1864,  0.1869,  0.5278, -0.2908, -0.9722,
-        -0.4321, -0.6585, -0.3927, -0.0707,  0.0173,  0.3022, -1.3598, -0.5338,
-        -0.8415,  0.5534, -0.9877, -0.4840, -0.7626,  0.0057, -0.4433, -0.8223,
-        -0.1539,  0.2197,  0.0460, -0.2677, -0.0824,  1.5808, -0.7011, -0.7204,
-        -0.9312,  0.0565, -0.7350, -0.1329,  1.0096,  0.0000,  0.6162,  0.9091,
-         0.2304,  1.1432, -0.0595,  0.7203,  0.5743,  0.1464,  0.6192,  0.0437,
-         0.8385, -0.5341,  0.2543,  0.1349,  0.2806, -0.0346,  0.5379,  0.3585,
-         1.2096,  1.0335,  0.5980, -0.3524,  1.2792,  0.4118, -0.0103,  0.0577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8308,  0.2676, -0.1373, -0.1864,  0.1869,  0.5278, -0.2908, -0.9722,
-        -0.4321, -0.6585, -0.3927, -0.0707,  0.0173,  0.3022, -1.3598, -0.5338,
-        -0.8415,  0.5534, -0.9877, -0.4840, -0.7626,  0.0057, -0.4433, -0.8223,
-        -0.1539,  0.2197,  0.0460, -0.2677, -0.0824,  1.5808, -0.7011, -0.7204,
-        -0.9312,  0.0565, -0.7350, -0.1329,  1.0096,  0.0000,  0.6162,  0.9091,
-         0.2304,  1.1432, -0.0595,  0.7203,  0.5743,  0.1464,  0.6192,  0.0437,
-         0.8385, -0.5341,  0.2543,  0.1349,  0.2806, -0.0346,  0.5379,  0.3585,
-         1.2096,  1.0335,  0.5980, -0.3524,  1.2792,  0.4118, -0.0103,  0.0577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8308,  0.2676, -0.1373, -0.1864,  0.1869,  0.5278, -0.2908, -0.9722,
-        -0.4321, -0.6585, -0.3927, -0.0707,  0.0173,  0.3022, -1.3598, -0.5338,
-        -0.8415,  0.5534, -0.9877, -0.4840, -0.7626,  0.0057, -0.4433, -0.8223,
-        -0.1539,  0.2197,  0.0460, -0.2677, -0.0824,  1.5808, -0.7011, -0.7204,
-        -0.9312,  0.0565, -0.7350, -0.1329,  1.0096,  0.0000,  0.6162,  0.9091,
-         0.2304,  1.1432, -0.0595,  0.7203,  0.5743,  0.1464,  0.6192,  0.0437,
-         0.8385, -0.5341,  0.2543,  0.1349,  0.2806, -0.0346,  0.5379,  0.3585,
-         1.2096,  1.0335,  0.5980, -0.3524,  1.2792,  0.4118, -0.0103,  0.0577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8302,  0.2822, -0.1337, -0.1589,  0.1910,  0.5145, -0.3252, -0.9699,
-        -0.4267, -0.6557, -0.3788, -0.1377,  0.0518,  0.2831, -1.3629, -0.5251,
-        -0.8399,  0.5518, -0.9897, -0.4851, -0.7653, -0.0256, -0.4319, -0.8279,
-        -0.2263,  0.1962,  0.0757, -0.2480, -0.0960,  1.5799, -0.6887, -0.7079,
-        -0.9315,  0.1065, -0.7306, -0.0948,  1.0061,  0.0000,  0.5844,  0.9245,
-         0.2386,  1.1414, -0.0495,  0.7103,  0.5786,  0.1463,  0.6234,  0.0293,
-         0.8337, -0.5312,  0.2455,  0.1503,  0.2513, -0.0896,  0.5649,  0.3551,
-         1.2097,  1.0361,  0.6554, -0.3470,  1.2819,  0.4263, -0.0158,  0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8302,  0.2822, -0.1337, -0.1589,  0.1910,  0.5145, -0.3252, -0.9699,
-        -0.4267, -0.6557, -0.3788, -0.1377,  0.0518,  0.2831, -1.3629, -0.5251,
-        -0.8399,  0.5518, -0.9897, -0.4851, -0.7653, -0.0256, -0.4319, -0.8279,
-        -0.2263,  0.1962,  0.0757, -0.2480, -0.0960,  1.5799, -0.6887, -0.7079,
-        -0.9315,  0.1065, -0.7306, -0.0948,  1.0061,  0.0000,  0.5844,  0.9245,
-         0.2386,  1.1414, -0.0495,  0.7103,  0.5786,  0.1463,  0.6234,  0.0293,
-         0.8337, -0.5312,  0.2455,  0.1503,  0.2513, -0.0896,  0.5649,  0.3551,
-         1.2097,  1.0361,  0.6554, -0.3470,  1.2819,  0.4263, -0.0158,  0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8302,  0.2822, -0.1337, -0.1589,  0.1910,  0.5145, -0.3252, -0.9699,
-        -0.4267, -0.6557, -0.3788, -0.1377,  0.0518,  0.2831, -1.3629, -0.5251,
-        -0.8399,  0.5518, -0.9897, -0.4851, -0.7653, -0.0256, -0.4319, -0.8279,
-        -0.2263,  0.1962,  0.0757, -0.2480, -0.0960,  1.5799, -0.6887, -0.7079,
-        -0.9315,  0.1065, -0.7306, -0.0948,  1.0061,  0.0000,  0.5844,  0.9245,
-         0.2386,  1.1414, -0.0495,  0.7103,  0.5786,  0.1463,  0.6234,  0.0293,
-         0.8337, -0.5312,  0.2455,  0.1503,  0.2513, -0.0896,  0.5649,  0.3551,
-         1.2097,  1.0361,  0.6554, -0.3470,  1.2819,  0.4263, -0.0158,  0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8318,  0.2883, -0.1146, -0.1102,  0.1919,  0.5028, -0.3566, -0.9678,
-        -0.4070, -0.6508, -0.3568, -0.1717,  0.1102,  0.3183, -1.3649, -0.5066,
-        -0.8352,  0.5470, -0.9922, -0.4837, -0.7667, -0.0673, -0.4216, -0.8422,
-        -0.3096,  0.2279,  0.1206, -0.2180, -0.1016,  1.5786, -0.6769, -0.6946,
-        -0.9323,  0.1557, -0.7235, -0.0620,  1.0039,  0.0000,  0.5900,  0.9460,
-         0.2591,  1.1412, -0.0132,  0.7118,  0.5809,  0.1091,  0.6265,  0.0519,
-         0.8268, -0.5302,  0.2381,  0.1969,  0.1762, -0.1172,  0.6031,  0.3743,
-         1.2096,  1.0447,  0.7099, -0.3388,  1.2854,  0.4243, -0.0400,  0.0817],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8318,  0.2883, -0.1146, -0.1102,  0.1919,  0.5028, -0.3566, -0.9678,
-        -0.4070, -0.6508, -0.3568, -0.1717,  0.1102,  0.3183, -1.3649, -0.5066,
-        -0.8352,  0.5470, -0.9922, -0.4837, -0.7667, -0.0673, -0.4216, -0.8422,
-        -0.3096,  0.2279,  0.1206, -0.2180, -0.1016,  1.5786, -0.6769, -0.6946,
-        -0.9323,  0.1557, -0.7235, -0.0620,  1.0039,  0.0000,  0.5900,  0.9460,
-         0.2591,  1.1412, -0.0132,  0.7118,  0.5809,  0.1091,  0.6265,  0.0519,
-         0.8268, -0.5302,  0.2381,  0.1969,  0.1762, -0.1172,  0.6031,  0.3743,
-         1.2096,  1.0447,  0.7099, -0.3388,  1.2854,  0.4243, -0.0400,  0.0817],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8318,  0.2883, -0.1146, -0.1102,  0.1919,  0.5028, -0.3566, -0.9678,
-        -0.4070, -0.6508, -0.3568, -0.1717,  0.1102,  0.3183, -1.3649, -0.5066,
-        -0.8352,  0.5470, -0.9922, -0.4837, -0.7667, -0.0673, -0.4216, -0.8422,
-        -0.3096,  0.2279,  0.1206, -0.2180, -0.1016,  1.5786, -0.6769, -0.6946,
-        -0.9323,  0.1557, -0.7235, -0.0620,  1.0039,  0.0000,  0.5900,  0.9460,
-         0.2591,  1.1412, -0.0132,  0.7118,  0.5809,  0.1091,  0.6265,  0.0519,
-         0.8268, -0.5302,  0.2381,  0.1969,  0.1762, -0.1172,  0.6031,  0.3743,
-         1.2096,  1.0447,  0.7099, -0.3388,  1.2854,  0.4243, -0.0400,  0.0817],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8324,  0.3084, -0.1140, -0.0927,  0.2257,  0.4785, -0.3787, -0.9660,
-        -0.3880, -0.6482, -0.3367, -0.2236,  0.1311,  0.3343, -1.3668, -0.4901,
-        -0.8232,  0.5425, -0.9922, -0.4758, -0.7741, -0.0960, -0.4016, -0.8652,
-        -0.3790,  0.2669,  0.1363, -0.1703, -0.1168,  1.5772, -0.6629, -0.6820,
-        -0.9342,  0.2361, -0.7144, -0.0455,  1.0014,  0.0000,  0.5871,  0.9672,
-         0.2949,  1.1430, -0.0031,  0.7144,  0.5802,  0.0471,  0.6338,  0.0639,
-         0.8227, -0.5275,  0.2252,  0.2986,  0.0797, -0.1300,  0.6462,  0.3868,
-         1.2096,  1.0554,  0.7715, -0.3669,  1.2890,  0.4188, -0.0700,  0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8324,  0.3084, -0.1140, -0.0927,  0.2257,  0.4785, -0.3787, -0.9660,
-        -0.3880, -0.6482, -0.3367, -0.2236,  0.1311,  0.3343, -1.3668, -0.4901,
-        -0.8232,  0.5425, -0.9922, -0.4758, -0.7741, -0.0960, -0.4016, -0.8652,
-        -0.3790,  0.2669,  0.1363, -0.1703, -0.1168,  1.5772, -0.6629, -0.6820,
-        -0.9342,  0.2361, -0.7144, -0.0455,  1.0014,  0.0000,  0.5871,  0.9672,
-         0.2949,  1.1430, -0.0031,  0.7144,  0.5802,  0.0471,  0.6338,  0.0639,
-         0.8227, -0.5275,  0.2252,  0.2986,  0.0797, -0.1300,  0.6462,  0.3868,
-         1.2096,  1.0554,  0.7715, -0.3669,  1.2890,  0.4188, -0.0700,  0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8324,  0.3084, -0.1140, -0.0927,  0.2257,  0.4785, -0.3787, -0.9660,
-        -0.3880, -0.6482, -0.3367, -0.2236,  0.1311,  0.3343, -1.3668, -0.4901,
-        -0.8232,  0.5425, -0.9922, -0.4758, -0.7741, -0.0960, -0.4016, -0.8652,
-        -0.3790,  0.2669,  0.1363, -0.1703, -0.1168,  1.5772, -0.6629, -0.6820,
-        -0.9342,  0.2361, -0.7144, -0.0455,  1.0014,  0.0000,  0.5871,  0.9672,
-         0.2949,  1.1430, -0.0031,  0.7144,  0.5802,  0.0471,  0.6338,  0.0639,
-         0.8227, -0.5275,  0.2252,  0.2986,  0.0797, -0.1300,  0.6462,  0.3868,
-         1.2096,  1.0554,  0.7715, -0.3669,  1.2890,  0.4188, -0.0700,  0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8299,  0.3172, -0.1362, -0.0863,  0.2383,  0.4549, -0.3951, -0.9638,
-        -0.3753, -0.6448, -0.3211, -0.2770,  0.0905,  0.3530, -1.3695, -0.4662,
-        -0.7966,  0.5398, -0.9883, -0.4648, -0.7841, -0.1032, -0.3854, -0.8868,
-        -0.4519,  0.3328,  0.1162, -0.1212, -0.1492,  1.5746, -0.6425, -0.6668,
-        -0.9350,  0.2830, -0.7067, -0.0357,  1.0032,  0.0000,  0.6057,  0.9896,
-         0.3258,  1.1443, -0.0269,  0.7251,  0.5763, -0.0517,  0.6365,  0.0748,
-         0.8145, -0.5216,  0.2281,  0.3837, -0.0431, -0.1352,  0.6870,  0.3980,
-         1.2093,  1.0656,  0.8145, -0.4099,  1.2934,  0.3886, -0.1161, -0.0050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8299,  0.3172, -0.1362, -0.0863,  0.2383,  0.4549, -0.3951, -0.9638,
-        -0.3753, -0.6448, -0.3211, -0.2770,  0.0905,  0.3530, -1.3695, -0.4662,
-        -0.7966,  0.5398, -0.9883, -0.4648, -0.7841, -0.1032, -0.3854, -0.8868,
-        -0.4519,  0.3328,  0.1162, -0.1212, -0.1492,  1.5746, -0.6425, -0.6668,
-        -0.9350,  0.2830, -0.7067, -0.0357,  1.0032,  0.0000,  0.6057,  0.9896,
-         0.3258,  1.1443, -0.0269,  0.7251,  0.5763, -0.0517,  0.6365,  0.0748,
-         0.8145, -0.5216,  0.2281,  0.3837, -0.0431, -0.1352,  0.6870,  0.0000,
-         1.2093,  1.0656,  0.8145, -0.4099,  1.2934,  0.3886, -0.1161, -0.0050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8299,  0.3172, -0.1362, -0.0863,  0.2383,  0.4549, -0.3951, -0.9638,
-        -0.3753, -0.6448, -0.3211, -0.2770,  0.0905,  0.3530, -1.3695, -0.4662,
-        -0.7966,  0.5398, -0.9883, -0.4648, -0.7841, -0.1032, -0.3854, -0.8868,
-        -0.4519,  0.3328,  0.1162, -0.1212, -0.1492,  1.5746, -0.6425, -0.6668,
-        -0.9350,  0.2830, -0.7067, -0.0357,  1.0032,  0.0000,  0.6057,  0.9896,
-         0.3258,  1.1443, -0.0269,  0.7251,  0.5763, -0.0517,  0.6365,  0.0748,
-         0.8145, -0.5216,  0.2281,  0.3837, -0.0431, -0.1352,  0.6870,  0.0000,
-         1.2093,  1.0656,  0.8145, -0.4099,  1.2934,  0.3886, -0.1161, -0.0050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8277,  0.3282, -0.1847, -0.0867,  0.1869,  0.4438, -0.4151, -0.9648,
-        -0.3726, -0.6420, -0.3063, -0.3449,  0.0154,  0.3714, -1.3723, -0.4558,
-        -0.7803,  0.5386, -0.9848, -0.4636, -0.7790, -0.0577, -0.3554, -0.8901,
-        -0.4966,  0.4008,  0.0869, -0.0636, -0.1590,  1.5714, -0.6243, -0.6510,
-        -0.9407,  0.2869, -0.6977, -0.0151,  1.0039,  0.0000,  0.6233,  1.0065,
-         0.3538,  1.1359, -0.0781,  0.7325,  0.5699, -0.1461,  0.6402,  0.0927,
-         0.8081, -0.5141,  0.2201,  0.4008, -0.1555, -0.1140,  0.7227,  0.0101,
-         1.2087,  1.0748,  0.8352, -0.4293,  1.2969,  0.3564, -0.1577, -0.0931],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8277,  0.3282, -0.1847, -0.0867,  0.1869,  0.4438, -0.4151, -0.9648,
-        -0.3726, -0.6420, -0.3063, -0.3449,  0.0154,  0.3714, -1.3723, -0.4558,
-        -0.7803,  0.5386, -0.9848, -0.4636, -0.7790, -0.0577, -0.3554, -0.8901,
-        -0.4966,  0.4008,  0.0869, -0.0636, -0.1590,  1.5714, -0.6243, -0.6510,
-        -0.9407,  0.2869, -0.6977, -0.0151,  1.0039,  0.0000,  0.6233,  1.0065,
-         0.3538,  1.1359, -0.0781,  0.7325,  0.5699, -0.1461,  0.6402,  0.0927,
-         0.8081, -0.5141,  0.2201,  0.4008, -0.1555, -0.1140,  0.7227,  0.0000,
-         1.2087,  1.0748,  0.8352, -0.4293,  1.2969,  0.3564, -0.1577, -0.0931],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8277,  0.3282, -0.1847, -0.0867,  0.1869,  0.4438, -0.4151, -0.9648,
-        -0.3726, -0.6420, -0.3063, -0.3449,  0.0154,  0.3714, -1.3723, -0.4558,
-        -0.7803,  0.5386, -0.9848, -0.4636, -0.7790, -0.0577, -0.3554, -0.8901,
-        -0.4966,  0.4008,  0.0869, -0.0636, -0.1590,  1.5714, -0.6243, -0.6510,
-        -0.9407,  0.2869, -0.6977, -0.0151,  1.0039,  0.0000,  0.6233,  1.0065,
-         0.3538,  1.1359, -0.0781,  0.7325,  0.5699, -0.1461,  0.6402,  0.0927,
-         0.8081, -0.5141,  0.2201,  0.4008, -0.1555, -0.1140,  0.7227,  0.0000,
-         1.2087,  1.0748,  0.8352, -0.4293,  1.2969,  0.3564, -0.1577, -0.0931],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8163,  0.3677, -0.2186, -0.0644,  0.1750,  0.4349, -0.4269, -0.9644,
-        -0.3529, -0.6419, -0.2813, -0.4204, -0.0730,  0.3680, -1.3768, -0.4452,
-        -0.7744,  0.5421, -0.9793, -0.4663, -0.7561,  0.0263, -0.3356, -0.8687,
-        -0.5284,  0.4499,  0.0362, -0.0018, -0.1413,  1.5714, -0.6200, -0.6351,
-        -0.9481,  0.3071, -0.6917,  0.0193,  1.0055,  0.0000,  0.6108,  1.0186,
-         0.4011,  1.1223, -0.0880,  0.7320,  0.5602, -0.2445,  0.6359,  0.0903,
-         0.8175, -0.5061,  0.2037,  0.3850, -0.2454, -0.0705,  0.7483,  0.0092,
-         1.2080,  1.0775,  0.8402, -0.4152,  1.2979,  0.3260, -0.1909, -0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8163,  0.3677, -0.2186, -0.0644,  0.1750,  0.4349, -0.4269, -0.9644,
-        -0.3529, -0.6419, -0.2813, -0.4204, -0.0730,  0.3680, -1.3768, -0.4452,
-        -0.7744,  0.5421, -0.9793, -0.4663, -0.7561,  0.0263, -0.3356, -0.8687,
-        -0.5284,  0.4499,  0.0362, -0.0018, -0.1413,  1.5714, -0.6200, -0.6351,
-        -0.9481,  0.3071, -0.6917,  0.0193,  1.0055,  0.0000,  0.6108,  1.0186,
-         0.4011,  1.1223, -0.0880,  0.7320,  0.5602, -0.2445,  0.6359,  0.0903,
-         0.8175, -0.5061,  0.2037,  0.3850, -0.2454, -0.0705,  0.7483,  0.0000,
-         1.2080,  1.0775,  0.8402, -0.4152,  1.2979,  0.3260, -0.1909, -0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8163,  0.3677, -0.2186, -0.0644,  0.1750,  0.4349, -0.4269, -0.9644,
-        -0.3529, -0.6419, -0.2813, -0.4204, -0.0730,  0.3680, -1.3768, -0.4452,
-        -0.7744,  0.5421, -0.9793, -0.4663, -0.7561,  0.0263, -0.3356, -0.8687,
-        -0.5284,  0.4499,  0.0362, -0.0018, -0.1413,  1.5714, -0.6200, -0.6351,
-        -0.9481,  0.3071, -0.6917,  0.0193,  1.0055,  0.0000,  0.6108,  1.0186,
-         0.4011,  1.1223, -0.0880,  0.7320,  0.5602, -0.2445,  0.6359,  0.0903,
-         0.8175, -0.5061,  0.2037,  0.3850, -0.2454, -0.0705,  0.7483,  0.0000,
-         1.2080,  1.0775,  0.8402, -0.4152,  1.2979,  0.3260, -0.1909, -0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.8025,  0.4118, -0.2302, -0.0094,  0.1471,  0.4366, -0.4306, -0.9641,
-        -0.3293, -0.6451, -0.2429, -0.5076, -0.1077,  0.3623, -1.3778, -0.4466,
-        -0.7796,  0.5443, -0.9769, -0.4724, -0.7279,  0.0813, -0.3262, -0.8360,
-        -0.5554,  0.4835,  0.0035,  0.0411, -0.1214,  1.5719, -0.6231, -0.6259,
-        -0.9584,  0.2911, -0.6889,  0.0633,  1.0046,  0.0000,  0.5714,  1.0237,
-         0.4378,  1.1088, -0.0708,  0.7295,  0.5524, -0.3352,  0.6294,  0.0943,
-         0.8291, -0.4971,  0.1899,  0.3341, -0.3264, -0.0299,  0.7654,  0.0083,
-         1.2076,  1.0754,  0.8448, -0.3788,  1.2982,  0.3052, -0.2177, -0.1983],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.8025,  0.4118, -0.2302, -0.0094,  0.1471,  0.4366, -0.4306, -0.9641,
-        -0.3293, -0.6451, -0.2429, -0.5076, -0.1077,  0.3623, -1.3778, -0.4466,
-        -0.7796,  0.5443, -0.9769, -0.4724, -0.7279,  0.0813, -0.3262, -0.8360,
-        -0.5554,  0.4835,  0.0035,  0.0411, -0.1214,  1.5719, -0.6231, -0.6259,
-        -0.9584,  0.2911, -0.6889,  0.0633,  1.0046,  0.0000,  0.5714,  1.0237,
-         0.4378,  1.1088, -0.0708,  0.7295,  0.5524, -0.3352,  0.6294,  0.0943,
-         0.8291, -0.4971,  0.1899,  0.3341, -0.3264, -0.0299,  0.7654,  0.0000,
-         1.2076,  1.0754,  0.8448, -0.3788,  1.2982,  0.3052, -0.2177, -0.1983],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.8025,  0.4118, -0.2302, -0.0094,  0.1471,  0.4366, -0.4306, -0.9641,
-        -0.3293, -0.6451, -0.2429, -0.5076, -0.1077,  0.3623, -1.3778, -0.4466,
-        -0.7796,  0.5443, -0.9769, -0.4724, -0.7279,  0.0813, -0.3262, -0.8360,
-        -0.5554,  0.4835,  0.0035,  0.0411, -0.1214,  1.5719, -0.6231, -0.6259,
-        -0.9584,  0.2911, -0.6889,  0.0633,  1.0046,  0.0000,  0.5714,  1.0237,
-         0.4378,  1.1088, -0.0708,  0.7295,  0.5524, -0.3352,  0.6294,  0.0943,
-         0.8291, -0.4971,  0.1899,  0.3341, -0.3264, -0.0299,  0.7654,  0.0000,
-         1.2076,  1.0754,  0.8448, -0.3788,  1.2982,  0.3052, -0.2177, -0.1983],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7887,  0.4436, -0.2378,  0.0274,  0.1600,  0.4436, -0.4368, -0.9652,
-        -0.3115, -0.6452, -0.2043, -0.5608, -0.1474,  0.3518, -1.3752, -0.4499,
-        -0.7877,  0.5459, -0.9744, -0.4727, -0.7155,  0.0808, -0.3216, -0.8139,
-        -0.5740,  0.5045, -0.0283,  0.0667, -0.1386,  1.5723, -0.6255, -0.6150,
-        -0.9667,  0.2640, -0.6815,  0.0953,  1.0031,  0.0000,  0.5104,  1.0241,
-         0.4688,  1.1001, -0.0565,  0.7172,  0.5430, -0.4056,  0.6237,  0.1092,
-         0.8417, -0.4868,  0.2018,  0.2831, -0.3978, -0.0181,  0.7699,  0.0075,
-         1.2074,  1.0659,  0.8463, -0.3743,  1.2988,  0.2909, -0.2348, -0.2306],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7887,  0.4436, -0.2378,  0.0274,  0.1600,  0.4436, -0.4368, -0.9652,
-        -0.3115, -0.6452, -0.2043, -0.5608, -0.1474,  0.3518, -1.3752, -0.4499,
-        -0.7877,  0.5459, -0.9744, -0.4727, -0.7155,  0.0808, -0.3216, -0.8139,
-        -0.5740,  0.5045, -0.0283,  0.0667, -0.1386,  1.5723, -0.6255, -0.6150,
-        -0.9667,  0.2640, -0.6815,  0.0953,  1.0031,  0.0000,  0.5104,  1.0241,
-         0.4688,  1.1001, -0.0565,  0.7172,  0.5430, -0.4056,  0.6237,  0.1092,
-         0.8417, -0.4868,  0.2018,  0.2831, -0.3978, -0.0181,  0.7699,  0.0000,
-         1.2074,  1.0659,  0.8463, -0.3743,  1.2988,  0.2909, -0.2348, -0.2306],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7887,  0.4436, -0.2378,  0.0274,  0.1600,  0.4436, -0.4368, -0.9652,
-        -0.3115, -0.6452, -0.2043, -0.5608, -0.1474,  0.3518, -1.3752, -0.4499,
-        -0.7877,  0.5459, -0.9744, -0.4727, -0.7155,  0.0808, -0.3216, -0.8139,
-        -0.5740,  0.5045, -0.0283,  0.0667, -0.1386,  1.5723, -0.6255, -0.6150,
-        -0.9667,  0.2640, -0.6815,  0.0953,  1.0031,  0.0000,  0.5104,  1.0241,
-         0.4688,  1.1001, -0.0565,  0.7172,  0.5430, -0.4056,  0.6237,  0.1092,
-         0.8417, -0.4868,  0.2018,  0.2831, -0.3978, -0.0181,  0.7699,  0.0000,
-         1.2074,  1.0659,  0.8463, -0.3743,  1.2988,  0.2909, -0.2348, -0.2306],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7767,  0.4693, -0.2234,  0.0584,  0.2462,  0.4369, -0.4370, -0.9640,
-        -0.2956, -0.6449, -0.1623, -0.5844, -0.1290,  0.3267, -1.3768, -0.4635,
-        -0.7871,  0.5390, -0.9756, -0.4746, -0.7173,  0.0412, -0.3197, -0.8180,
-        -0.5873,  0.5159, -0.0433,  0.0667, -0.1841,  1.5714, -0.6320, -0.6012,
-        -0.9700,  0.2367, -0.6762,  0.1011,  1.0000,  0.0000,  0.4544,  1.0202,
-         0.4831,  1.0983, -0.0366,  0.7058,  0.5382, -0.4648,  0.6133,  0.1355,
-         0.8423, -0.4782,  0.2149,  0.2499, -0.4574, -0.0104,  0.7702,  0.0068,
-         1.2072,  1.0570,  0.8440, -0.4041,  1.3003,  0.2828, -0.2440, -0.2283],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7767,  0.4693, -0.2234,  0.0584,  0.2462,  0.4369, -0.4370, -0.9640,
-        -0.2956, -0.6449, -0.1623, -0.5844, -0.1290,  0.3267, -1.3768, -0.4635,
-        -0.7871,  0.5390, -0.9756, -0.4746, -0.7173,  0.0412, -0.3197, -0.8180,
-        -0.5873,  0.5159, -0.0433,  0.0667, -0.1841,  1.5714, -0.6320, -0.6012,
-        -0.9700,  0.2367, -0.6762,  0.1011,  1.0000,  0.0000,  0.4544,  1.0202,
-         0.4831,  1.0983, -0.0366,  0.7058,  0.5382, -0.4648,  0.6133,  0.1355,
-         0.8423, -0.4782,  0.2149,  0.2499, -0.4574, -0.0104,  0.7702,  0.0000,
-         1.2072,  1.0570,  0.8440, -0.4041,  1.3003,  0.2828, -0.2440, -0.2283],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7767,  0.4693, -0.2234,  0.0584,  0.2462,  0.4369, -0.4370, -0.9640,
-        -0.2956, -0.6449, -0.1623, -0.5844, -0.1290,  0.3267, -1.3768, -0.4635,
-        -0.7871,  0.5390, -0.9756, -0.4746, -0.7173,  0.0412, -0.3197, -0.8180,
-        -0.5873,  0.5159, -0.0433,  0.0667, -0.1841,  1.5714, -0.6320, -0.6012,
-        -0.9700,  0.2367, -0.6762,  0.1011,  1.0000,  0.0000,  0.4544,  1.0202,
-         0.4831,  1.0983, -0.0366,  0.7058,  0.5382, -0.4648,  0.6133,  0.1355,
-         0.8423, -0.4782,  0.2149,  0.2499, -0.4574, -0.0104,  0.7702,  0.0000,
-         1.2072,  1.0570,  0.8440, -0.4041,  1.3003,  0.2828, -0.2440, -0.2283],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7623,  0.4791, -0.1915,  0.1059,  0.3411,  0.4290, -0.4388, -0.9620,
-        -0.2750, -0.6369, -0.1058, -0.5792, -0.0805,  0.3223, -1.3784, -0.4670,
-        -0.7863,  0.5326, -0.9764, -0.4681, -0.7218, -0.0032, -0.3203, -0.8200,
-        -0.6056,  0.5299, -0.0640,  0.0630, -0.2125,  1.5694, -0.6414, -0.5857,
-        -0.9748,  0.2057, -0.6683,  0.1073,  0.9950,  0.0000,  0.4127,  1.0137,
-         0.4832,  1.0998, -0.0139,  0.6940,  0.5331, -0.5273,  0.6015,  0.1758,
-         0.8410, -0.4738,  0.2250,  0.2178, -0.5148,  0.0082,  0.7708,  0.0061,
-         1.2073,  1.0476,  0.8404, -0.4347,  1.3025,  0.2677, -0.2484, -0.2204],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7623,  0.4791, -0.1915,  0.1059,  0.3411,  0.4290, -0.4388, -0.9620,
-        -0.2750, -0.6369, -0.1058, -0.5792, -0.0805,  0.3223, -1.3784, -0.4670,
-        -0.7863,  0.5326, -0.9764, -0.4681, -0.7218, -0.0032, -0.3203, -0.8200,
-        -0.6056,  0.5299, -0.0640,  0.0630, -0.2125,  1.5694, -0.6414, -0.5857,
-        -0.9748,  0.2057, -0.6683,  0.1073,  0.9950,  0.0000,  0.4127,  1.0137,
-         0.4832,  1.0998, -0.0139,  0.6940,  0.5331, -0.5273,  0.6015,  0.1758,
-         0.8410, -0.4738,  0.2250,  0.2178, -0.5148,  0.0082,  0.7708,  0.0000,
-         1.2073,  1.0476,  0.8404, -0.4347,  1.3025,  0.2677, -0.2484, -0.2204],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7623,  0.4791, -0.1915,  0.1059,  0.3411,  0.4290, -0.4388, -0.9620,
-        -0.2750, -0.6369, -0.1058, -0.5792, -0.0805,  0.3223, -1.3784, -0.4670,
-        -0.7863,  0.5326, -0.9764, -0.4681, -0.7218, -0.0032, -0.3203, -0.8200,
-        -0.6056,  0.5299, -0.0640,  0.0630, -0.2125,  1.5694, -0.6414, -0.5857,
-        -0.9748,  0.2057, -0.6683,  0.1073,  0.9950,  0.0000,  0.4127,  1.0137,
-         0.4832,  1.0998, -0.0139,  0.6940,  0.5331, -0.5273,  0.6015,  0.1758,
-         0.8410, -0.4738,  0.2250,  0.2178, -0.5148,  0.0082,  0.7708,  0.0000,
-         1.2073,  1.0476,  0.8404, -0.4347,  1.3025,  0.2677, -0.2484, -0.2204],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7489,  0.4920, -0.1744,  0.1343,  0.3842,  0.4058, -0.4467, -0.9578,
-        -0.2606, -0.6295, -0.0653, -0.5724, -0.0502,  0.3123, -1.3794, -0.4577,
-        -0.7817,  0.5279, -0.9776, -0.4546, -0.7167, -0.0193, -0.3095, -0.8039,
-        -0.6164,  0.5401, -0.1128,  0.0595, -0.2329,  1.5671, -0.6505, -0.5671,
-        -0.9782,  0.1780, -0.6553,  0.1061,  0.9941,  0.0000,  0.3796,  1.0085,
-         0.4856,  1.0998, -0.0187,  0.6779,  0.5283, -0.5838,  0.5920,  0.2017,
-         0.8391, -0.4696,  0.2370,  0.1961, -0.5607,  0.0315,  0.7685,  0.0055,
-         1.2078,  1.0433,  0.8305, -0.4443,  1.3048,  0.2536, -0.2400, -0.2525],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7489,  0.4920, -0.1744,  0.1343,  0.3842,  0.4058, -0.4467, -0.9578,
-        -0.2606, -0.6295, -0.0653, -0.5724, -0.0502,  0.3123, -1.3794, -0.4577,
-        -0.7817,  0.5279, -0.9776, -0.4546, -0.7167, -0.0193, -0.3095, -0.8039,
-        -0.6164,  0.5401, -0.1128,  0.0595, -0.2329,  1.5671, -0.6505, -0.5671,
-        -0.9782,  0.1780, -0.6553,  0.1061,  0.9941,  0.0000,  0.3796,  1.0085,
-         0.4856,  1.0998, -0.0187,  0.6779,  0.5283, -0.5838,  0.5920,  0.2017,
-         0.8391, -0.4696,  0.2370,  0.1961, -0.5607,  0.0315,  0.7685,  0.0000,
-         1.2078,  1.0433,  0.8305, -0.4443,  1.3048,  0.2536, -0.2400, -0.2525],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7489,  0.4920, -0.1744,  0.1343,  0.3842,  0.4058, -0.4467, -0.9578,
-        -0.2606, -0.6295, -0.0653, -0.5724, -0.0502,  0.3123, -1.3794, -0.4577,
-        -0.7817,  0.5279, -0.9776, -0.4546, -0.7167, -0.0193, -0.3095, -0.8039,
-        -0.6164,  0.5401, -0.1128,  0.0595, -0.2329,  1.5671, -0.6505, -0.5671,
-        -0.9782,  0.1780, -0.6553,  0.1061,  0.9941,  0.0000,  0.3796,  1.0085,
-         0.4856,  1.0998, -0.0187,  0.6779,  0.5283, -0.5838,  0.5920,  0.2017,
-         0.8391, -0.4696,  0.2370,  0.1961, -0.5607,  0.0315,  0.7685,  0.0000,
-         1.2078,  1.0433,  0.8305, -0.4443,  1.3048,  0.2536, -0.2400, -0.2525],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.7205,  0.5222, -0.1530,  0.1550,  0.4126,  0.3960, -0.4590, -0.9528,
-        -0.2547, -0.6189, -0.0316, -0.5702, -0.0458,  0.2591, -1.3793, -0.4546,
-        -0.7812,  0.5213, -0.9767, -0.4498, -0.7058,  0.0328, -0.3017, -0.7928,
-        -0.6130,  0.5416, -0.1469,  0.0830, -0.2618,  1.5645, -0.6570, -0.5486,
-        -0.9839,  0.1876, -0.6433,  0.0850,  0.9905,  0.0000,  0.2882,  0.9983,
-         0.4979,  1.0988, -0.0507,  0.6555,  0.5255, -0.6291,  0.5893,  0.1969,
-         0.8536, -0.4656,  0.2262,  0.1734, -0.5982,  0.0413,  0.7642,  0.0049,
-         1.2074,  1.0368,  0.8223, -0.4417,  1.3040,  0.2426, -0.2408, -0.2619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.7205,  0.5222, -0.1530,  0.1550,  0.4126,  0.3960, -0.4590, -0.9528,
-        -0.2547, -0.6189, -0.0316, -0.5702, -0.0458,  0.2591, -1.3793, -0.4546,
-        -0.7812,  0.5213, -0.9767, -0.4498, -0.7058,  0.0328, -0.3017, -0.7928,
-        -0.6130,  0.5416, -0.1469,  0.0830, -0.2618,  1.5645, -0.6570, -0.5486,
-        -0.9839,  0.1876, -0.6433,  0.0850,  0.9905,  0.0000,  0.2882,  0.9983,
-         0.4979,  1.0988, -0.0507,  0.6555,  0.5255, -0.6291,  0.5893,  0.1969,
-         0.8536, -0.4656,  0.2262,  0.1734, -0.5982,  0.0413,  0.7642,  0.0000,
-         1.2074,  1.0368,  0.8223, -0.4417,  1.3040,  0.2426, -0.2408, -0.2619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.7205,  0.5222, -0.1530,  0.1550,  0.4126,  0.3960, -0.4590, -0.9528,
-        -0.2547, -0.6189, -0.0316, -0.5702, -0.0458,  0.2591, -1.3793, -0.4546,
-        -0.7812,  0.5213, -0.9767, -0.4498, -0.7058,  0.0328, -0.3017, -0.7928,
-        -0.6130,  0.5416, -0.1469,  0.0830, -0.2618,  1.5645, -0.6570, -0.5486,
-        -0.9839,  0.1876, -0.6433,  0.0850,  0.9905,  0.0000,  0.2882,  0.9983,
-         0.4979,  1.0988, -0.0507,  0.6555,  0.5255, -0.6291,  0.5893,  0.1969,
-         0.8536, -0.4656,  0.2262,  0.1734, -0.5982,  0.0413,  0.7642,  0.0000,
-         1.2074,  1.0368,  0.8223, -0.4417,  1.3040,  0.2426, -0.2408, -0.2619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.6848,  0.5551, -0.1236,  0.1538,  0.4349,  0.3923, -0.4691, -0.9467,
-        -0.2398, -0.6073, -0.0110, -0.5766, -0.0535,  0.1671, -1.3782, -0.4541,
-        -0.7812,  0.5111, -0.9752, -0.4510, -0.6966,  0.1244, -0.2817, -0.7816,
-        -0.5993,  0.5374, -0.1961,  0.1127, -0.2874,  1.5630, -0.6686, -0.5304,
-        -0.9887,  0.2162, -0.6334,  0.0405,  0.9871,  0.0000,  0.1592,  0.9894,
-         0.5114,  1.0993, -0.0869,  0.6345,  0.5232, -0.6698,  0.5903,  0.1972,
-         0.8739, -0.4563,  0.1931,  0.1418, -0.6319,  0.0392,  0.7599,  0.0044,
-         1.2064,  1.0363,  0.8116, -0.4307,  1.3036,  0.2288, -0.2439, -0.2580],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.6848,  0.5551, -0.1236,  0.1538,  0.4349,  0.3923, -0.4691, -0.9467,
-        -0.2398, -0.6073, -0.0110, -0.5766, -0.0535,  0.1671, -1.3782, -0.4541,
-        -0.7812,  0.5111, -0.9752, -0.4510, -0.6966,  0.1244, -0.2817, -0.7816,
-        -0.5993,  0.5374, -0.1961,  0.1127, -0.2874,  1.5630, -0.6686, -0.5304,
-        -0.9887,  0.2162, -0.6334,  0.0405,  0.9871,  0.0000,  0.1592,  0.9894,
-         0.5114,  1.0993, -0.0869,  0.6345,  0.5232, -0.6698,  0.5903,  0.1972,
-         0.8739, -0.4563,  0.1931,  0.1418, -0.6319,  0.0392,  0.7599,  0.0000,
-         1.2064,  1.0363,  0.8116, -0.4307,  1.3036,  0.2288, -0.2439, -0.2580],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6848,  0.5551, -0.1236,  0.1538,  0.4349,  0.3923, -0.4691, -0.9467,
-        -0.2398, -0.6073, -0.0110, -0.5766, -0.0535,  0.1671, -1.3782, -0.4541,
-        -0.7812,  0.5111, -0.9752, -0.4510, -0.6966,  0.1244, -0.2817, -0.7816,
-        -0.5993,  0.5374, -0.1961,  0.1127, -0.2874,  1.5630, -0.6686, -0.5304,
-        -0.9887,  0.2162, -0.6334,  0.0405,  0.9871,  0.0000,  0.1592,  0.9894,
-         0.5114,  1.0993, -0.0869,  0.6345,  0.5232, -0.6698,  0.5903,  0.1972,
-         0.8739, -0.4563,  0.1931,  0.1418, -0.6319,  0.0392,  0.7599,  0.0000,
-         1.2064,  1.0363,  0.8116, -0.4307,  1.3036,  0.2288, -0.2439, -0.2580],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5870e-01,  5.5715e-01, -6.9219e-02,  1.7682e-01,  4.3876e-01,
-         3.8071e-01, -4.7506e-01, -9.4113e-01, -2.1942e-01, -5.9914e-01,
-         1.1191e-03, -5.6578e-01, -1.4353e-02,  1.7588e-01, -1.3771e+00,
-        -4.6068e-01, -7.7570e-01,  5.0082e-01, -9.7435e-01, -4.5474e-01,
-        -6.8725e-01,  1.3975e-01, -2.7406e-01, -7.6694e-01, -5.9270e-01,
-         5.4300e-01, -2.0805e-01,  1.2052e-01, -3.0651e-01,  1.5608e+00,
-        -6.7597e-01, -5.0654e-01, -9.9108e-01,  2.1175e-01, -6.2769e-01,
-        -1.4722e-02,  9.9377e-01,  0.0000e+00,  1.4855e-01,  9.8261e-01,
-         5.0991e-01,  1.1008e+00, -1.0448e-01,  6.2125e-01,  5.2014e-01,
-        -6.9215e-01,  5.9101e-01,  2.3236e-01,  8.8575e-01, -4.4348e-01,
-         1.5843e-01,  1.0954e-01, -6.5193e-01,  3.5356e-02,  7.6581e-01,
-         3.9493e-03,  1.2065e+00,  1.0422e+00,  7.9736e-01, -4.1187e-01,
-         1.3048e+00,  2.0540e-01, -2.5482e-01, -2.2250e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 6.5870e-01,  5.5715e-01, -6.9219e-02,  1.7682e-01,  4.3876e-01,
-         3.8071e-01, -4.7506e-01, -9.4113e-01, -2.1942e-01, -5.9914e-01,
-         1.1191e-03, -5.6578e-01, -1.4353e-02,  1.7588e-01, -1.3771e+00,
-        -4.6068e-01, -7.7570e-01,  5.0082e-01, -9.7435e-01, -4.5474e-01,
-        -6.8725e-01,  1.3975e-01, -2.7406e-01, -7.6694e-01, -5.9270e-01,
-         5.4300e-01, -2.0805e-01,  1.2052e-01, -3.0651e-01,  1.5608e+00,
-        -6.7597e-01, -5.0654e-01, -9.9108e-01,  2.1175e-01, -6.2769e-01,
-        -1.4722e-02,  9.9377e-01,  0.0000e+00,  1.4855e-01,  9.8261e-01,
-         5.0991e-01,  1.1008e+00, -1.0448e-01,  6.2125e-01,  5.2014e-01,
-        -6.9215e-01,  5.9101e-01,  2.3236e-01,  8.8575e-01, -4.4348e-01,
-         1.5843e-01,  1.0954e-01, -6.5193e-01,  3.5356e-02,  7.6581e-01,
-         0.0000e+00,  1.2065e+00,  1.0422e+00,  7.9736e-01, -4.1187e-01,
-         1.3048e+00,  2.0540e-01, -2.5482e-01, -2.2250e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 6.5870e-01,  5.5715e-01, -6.9219e-02,  1.7682e-01,  4.3876e-01,
-         3.8071e-01, -4.7506e-01, -9.4113e-01, -2.1942e-01, -5.9914e-01,
-         1.1191e-03, -5.6578e-01, -1.4353e-02,  1.7588e-01, -1.3771e+00,
-        -4.6068e-01, -7.7570e-01,  5.0082e-01, -9.7435e-01, -4.5474e-01,
-        -6.8725e-01,  1.3975e-01, -2.7406e-01, -7.6694e-01, -5.9270e-01,
-         5.4300e-01, -2.0805e-01,  1.2052e-01, -3.0651e-01,  1.5608e+00,
-        -6.7597e-01, -5.0654e-01, -9.9108e-01,  2.1175e-01, -6.2769e-01,
-        -1.4722e-02,  9.9377e-01,  0.0000e+00,  1.4855e-01,  9.8261e-01,
-         5.0991e-01,  1.1008e+00, -1.0448e-01,  6.2125e-01,  5.2014e-01,
-        -6.9215e-01,  5.9101e-01,  2.3236e-01,  8.8575e-01, -4.4348e-01,
-         1.5843e-01,  1.0954e-01, -6.5193e-01,  3.5356e-02,  7.6581e-01,
-         0.0000e+00,  1.2065e+00,  1.0422e+00,  7.9736e-01, -4.1187e-01,
-         1.3048e+00,  2.0540e-01, -2.5482e-01, -2.2250e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 0.6382,  0.5408,  0.0427,  0.2101,  0.4241,  0.3461, -0.4705, -0.9347,
-        -0.1797, -0.5914,  0.0109, -0.5453,  0.1068,  0.2282, -1.3768, -0.4758,
-        -0.7751,  0.4818, -0.9741, -0.4543, -0.6869,  0.0983, -0.2858, -0.7501,
-        -0.5831,  0.5540, -0.1615,  0.1113, -0.3177,  1.5577, -0.6703, -0.4796,
-        -0.9940,  0.1809, -0.6268, -0.0727,  1.0035,  0.0000,  0.2040,  0.9783,
-         0.4860,  1.1042, -0.0764,  0.6149,  0.5125, -0.7006,  0.5841,  0.2719,
-         0.8924, -0.4291,  0.1192,  0.0809, -0.6587,  0.0225,  0.7770,  0.0035,
-         1.2057,  1.0478,  0.7713, -0.4022,  1.3072,  0.1673, -0.2472, -0.1263],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.6382,  0.5408,  0.0427,  0.2101,  0.4241,  0.3461, -0.4705, -0.9347,
-        -0.1797, -0.5914,  0.0109, -0.5453,  0.1068,  0.2282, -1.3768, -0.4758,
-        -0.7751,  0.4818, -0.9741, -0.4543, -0.6869,  0.0983, -0.2858, -0.7501,
-        -0.5831,  0.5540, -0.1615,  0.1113, -0.3177,  1.5577, -0.6703, -0.4796,
-        -0.9940,  0.1809, -0.6268, -0.0727,  1.0035,  0.0000,  0.2040,  0.9783,
-         0.4860,  1.1042, -0.0764,  0.6149,  0.5125, -0.7006,  0.5841,  0.2719,
-         0.8924, -0.4291,  0.1192,  0.0809, -0.6587,  0.0225,  0.7770,  0.0000,
-         1.2057,  1.0478,  0.7713, -0.4022,  1.3072,  0.1673, -0.2472, -0.1263],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6382,  0.5408,  0.0427,  0.2101,  0.4241,  0.3461, -0.4705, -0.9347,
-        -0.1797, -0.5914,  0.0109, -0.5453,  0.1068,  0.2282, -1.3768, -0.4758,
-        -0.7751,  0.4818, -0.9741, -0.4543, -0.6869,  0.0983, -0.2858, -0.7501,
-        -0.5831,  0.5540, -0.1615,  0.1113, -0.3177,  1.5577, -0.6703, -0.4796,
-        -0.9940,  0.1809, -0.6268, -0.0727,  1.0035,  0.0000,  0.2040,  0.9783,
-         0.4860,  1.1042, -0.0764,  0.6149,  0.5125, -0.7006,  0.5841,  0.2719,
-         0.8924, -0.4291,  0.1192,  0.0809, -0.6587,  0.0225,  0.7770,  0.0000,
-         1.2057,  1.0478,  0.7713, -0.4022,  1.3072,  0.1673, -0.2472, -0.1263],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.6299,  0.5402,  0.1409,  0.2143,  0.4311,  0.3010, -0.4659, -0.9314,
-        -0.1441, -0.5819,  0.0179, -0.5103,  0.2177,  0.2475, -1.3752, -0.4903,
-        -0.7693,  0.4631, -0.9757, -0.4491, -0.6893,  0.0584, -0.3036, -0.7383,
-        -0.5942,  0.5652, -0.1529,  0.0918, -0.3400,  1.5559, -0.6581, -0.4661,
-        -0.9935,  0.1553, -0.6217, -0.1191,  1.0134,  0.0000,  0.2537,  0.9748,
-         0.4589,  1.1063, -0.0611,  0.5946,  0.5056, -0.7167,  0.5784,  0.2913,
-         0.9033, -0.4131,  0.1245,  0.0359, -0.6750, -0.0107,  0.7910,  0.0032,
-         1.2050,  1.0504,  0.7484, -0.3984,  1.3102,  0.1164, -0.2541, -0.0773],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.6299,  0.5402,  0.1409,  0.2143,  0.4311,  0.3010, -0.4659, -0.9314,
-        -0.1441, -0.5819,  0.0179, -0.5103,  0.2177,  0.2475, -1.3752, -0.4903,
-        -0.7693,  0.4631, -0.9757, -0.4491, -0.6893,  0.0584, -0.3036, -0.7383,
-        -0.5942,  0.5652, -0.1529,  0.0918, -0.3400,  1.5559, -0.6581, -0.4661,
-        -0.9935,  0.1553, -0.6217, -0.1191,  1.0134,  0.0000,  0.2537,  0.9748,
-         0.4589,  1.1063, -0.0611,  0.5946,  0.5056, -0.7167,  0.5784,  0.2913,
-         0.9033, -0.4131,  0.1245,  0.0359, -0.6750, -0.0107,  0.7910,  0.0000,
-         1.2050,  1.0504,  0.7484, -0.3984,  1.3102,  0.1164, -0.2541, -0.0773],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6299,  0.5402,  0.1409,  0.2143,  0.4311,  0.3010, -0.4659, -0.9314,
-        -0.1441, -0.5819,  0.0179, -0.5103,  0.2177,  0.2475, -1.3752, -0.4903,
-        -0.7693,  0.4631, -0.9757, -0.4491, -0.6893,  0.0584, -0.3036, -0.7383,
-        -0.5942,  0.5652, -0.1529,  0.0918, -0.3400,  1.5559, -0.6581, -0.4661,
-        -0.9935,  0.1553, -0.6217, -0.1191,  1.0134,  0.0000,  0.2537,  0.9748,
-         0.4589,  1.1063, -0.0611,  0.5946,  0.5056, -0.7167,  0.5784,  0.2913,
-         0.9033, -0.4131,  0.1245,  0.0359, -0.6750, -0.0107,  0.7910,  0.0000,
-         1.2050,  1.0504,  0.7484, -0.3984,  1.3102,  0.1164, -0.2541, -0.0773],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.6209,  0.5493,  0.2001,  0.2187,  0.4439,  0.2660, -0.4628, -0.9284,
-        -0.1247, -0.5698,  0.0491, -0.4859,  0.2783,  0.2530, -1.3742, -0.5011,
-        -0.7558,  0.4420, -0.9762, -0.4517, -0.6882,  0.0881, -0.2835, -0.7221,
-        -0.5781,  0.5690, -0.1657,  0.0790, -0.3528,  1.5555, -0.6453, -0.4532,
-        -0.9911,  0.1473, -0.6176, -0.1869,  1.0270,  0.0000,  0.3172,  0.9688,
-         0.4431,  1.1128, -0.0593,  0.5775,  0.4988, -0.7158,  0.5743,  0.2983,
-         0.9107, -0.3982,  0.1645,  0.0152, -0.6827, -0.0529,  0.8020,  0.0028,
-         1.2040,  1.0507,  0.7201, -0.3739,  1.3138,  0.0781, -0.2506, -0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.6209,  0.5493,  0.2001,  0.2187,  0.4439,  0.2660, -0.4628, -0.9284,
-        -0.1247, -0.5698,  0.0491, -0.4859,  0.2783,  0.2530, -1.3742, -0.5011,
-        -0.7558,  0.4420, -0.9762, -0.4517, -0.6882,  0.0881, -0.2835, -0.7221,
-        -0.5781,  0.5690, -0.1657,  0.0790, -0.3528,  1.5555, -0.6453, -0.4532,
-        -0.9911,  0.1473, -0.6176, -0.1869,  1.0270,  0.0000,  0.3172,  0.9688,
-         0.4431,  1.1128, -0.0593,  0.5775,  0.4988, -0.7158,  0.5743,  0.2983,
-         0.9107, -0.3982,  0.1645,  0.0152, -0.6827, -0.0529,  0.8020,  0.0000,
-         1.2040,  1.0507,  0.7201, -0.3739,  1.3138,  0.0781, -0.2506, -0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6209,  0.5493,  0.2001,  0.2187,  0.4439,  0.2660, -0.4628, -0.9284,
-        -0.1247, -0.5698,  0.0491, -0.4859,  0.2783,  0.2530, -1.3742, -0.5011,
-        -0.7558,  0.4420, -0.9762, -0.4517, -0.6882,  0.0881, -0.2835, -0.7221,
-        -0.5781,  0.5690, -0.1657,  0.0790, -0.3528,  1.5555, -0.6453, -0.4532,
-        -0.9911,  0.1473, -0.6176, -0.1869,  1.0270,  0.0000,  0.3172,  0.9688,
-         0.4431,  1.1128, -0.0593,  0.5775,  0.4988, -0.7158,  0.5743,  0.2983,
-         0.9107, -0.3982,  0.1645,  0.0152, -0.6827, -0.0529,  0.8020,  0.0000,
-         1.2040,  1.0507,  0.7201, -0.3739,  1.3138,  0.0781, -0.2506, -0.0702],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5955,  0.5675,  0.2330,  0.2362,  0.4632,  0.2458, -0.4477, -0.9226,
-        -0.0987, -0.5497,  0.1231, -0.4790,  0.3334,  0.2342, -1.3730, -0.4931,
-        -0.7453,  0.4307, -0.9776, -0.4503, -0.6800,  0.1665, -0.2474, -0.7028,
-        -0.5253,  0.5699, -0.1896,  0.0797, -0.3371,  1.5565, -0.6445, -0.4397,
-        -0.9890,  0.1524, -0.6059, -0.2657,  1.0405,  0.0000,  0.3600,  0.9622,
-         0.4365,  1.1166, -0.0735,  0.5589,  0.4895, -0.7132,  0.5781,  0.3026,
-         0.9273, -0.3784,  0.2104,  0.0228, -0.6957, -0.0547,  0.8135,  0.0025,
-         1.2031,  1.0502,  0.6972, -0.3073,  1.3171,  0.0594, -0.2293, -0.0916],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5955,  0.5675,  0.2330,  0.2362,  0.4632,  0.2458, -0.4477, -0.9226,
-        -0.0987, -0.5497,  0.1231, -0.4790,  0.3334,  0.2342, -1.3730, -0.4931,
-        -0.7453,  0.4307, -0.9776, -0.4503, -0.6800,  0.1665, -0.2474, -0.7028,
-        -0.5253,  0.5699, -0.1896,  0.0797, -0.3371,  1.5565, -0.6445, -0.4397,
-        -0.9890,  0.1524, -0.6059, -0.2657,  1.0405,  0.0000,  0.3600,  0.9622,
-         0.4365,  1.1166, -0.0735,  0.5589,  0.4895, -0.7132,  0.5781,  0.3026,
-         0.9273, -0.3784,  0.2104,  0.0228, -0.6957, -0.0547,  0.8135,  0.0000,
-         1.2031,  1.0502,  0.6972, -0.3073,  1.3171,  0.0594, -0.2293, -0.0916],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5955,  0.5675,  0.2330,  0.2362,  0.4632,  0.2458, -0.4477, -0.9226,
-        -0.0987, -0.5497,  0.1231, -0.4790,  0.3334,  0.2342, -1.3730, -0.4931,
-        -0.7453,  0.4307, -0.9776, -0.4503, -0.6800,  0.1665, -0.2474, -0.7028,
-        -0.5253,  0.5699, -0.1896,  0.0797, -0.3371,  1.5565, -0.6445, -0.4397,
-        -0.9890,  0.1524, -0.6059, -0.2657,  1.0405,  0.0000,  0.3600,  0.9622,
-         0.4365,  1.1166, -0.0735,  0.5589,  0.4895, -0.7132,  0.5781,  0.3026,
-         0.9273, -0.3784,  0.2104,  0.0228, -0.6957, -0.0547,  0.8135,  0.0000,
-         1.2031,  1.0502,  0.6972, -0.3073,  1.3171,  0.0594, -0.2293, -0.0916],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5813,  0.5651,  0.2753,  0.2543,  0.4302,  0.2038, -0.4176, -0.9174,
-        -0.0976, -0.5289,  0.1641, -0.4786,  0.3949,  0.2580, -1.3714, -0.4751,
-        -0.7186,  0.4255, -0.9798, -0.4561, -0.6782,  0.1596, -0.2272, -0.6777,
-        -0.5192,  0.5775, -0.1882,  0.0613, -0.3343,  1.5559, -0.6304, -0.4257,
-        -0.9888,  0.1033, -0.5965, -0.3489,  1.0535,  0.0000,  0.4118,  0.9567,
-         0.4148,  1.1235, -0.0830,  0.5351,  0.4784, -0.7118,  0.5742,  0.3149,
-         0.9327, -0.3449,  0.2411, -0.0334, -0.7113, -0.0721,  0.8276,  0.0022,
-         1.2025,  1.0502,  0.6623, -0.2959,  1.3227,  0.0145, -0.2248, -0.0907],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5813,  0.5651,  0.2753,  0.2543,  0.4302,  0.2038, -0.4176, -0.9174,
-        -0.0976, -0.5289,  0.1641, -0.4786,  0.3949,  0.2580, -1.3714, -0.4751,
-        -0.7186,  0.4255, -0.9798, -0.4561, -0.6782,  0.1596, -0.2272, -0.6777,
-        -0.5192,  0.5775, -0.1882,  0.0613, -0.3343,  1.5559, -0.6304, -0.4257,
-        -0.9888,  0.1033, -0.5965, -0.3489,  1.0535,  0.0000,  0.4118,  0.9567,
-         0.4148,  1.1235, -0.0830,  0.5351,  0.4784, -0.7118,  0.5742,  0.3149,
-         0.9327, -0.3449,  0.2411, -0.0334, -0.7113, -0.0721,  0.8276,  0.0000,
-         1.2025,  1.0502,  0.6623, -0.2959,  1.3227,  0.0145, -0.2248, -0.0907],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5813,  0.5651,  0.2753,  0.2543,  0.4302,  0.2038, -0.4176, -0.9174,
-        -0.0976, -0.5289,  0.1641, -0.4786,  0.3949,  0.2580, -1.3714, -0.4751,
-        -0.7186,  0.4255, -0.9798, -0.4561, -0.6782,  0.1596, -0.2272, -0.6777,
-        -0.5192,  0.5775, -0.1882,  0.0613, -0.3343,  1.5559, -0.6304, -0.4257,
-        -0.9888,  0.1033, -0.5965, -0.3489,  1.0535,  0.0000,  0.4118,  0.9567,
-         0.4148,  1.1235, -0.0830,  0.5351,  0.4784, -0.7118,  0.5742,  0.3149,
-         0.9327, -0.3449,  0.2411, -0.0334, -0.7113, -0.0721,  0.8276,  0.0000,
-         1.2025,  1.0502,  0.6623, -0.2959,  1.3227,  0.0145, -0.2248, -0.0907],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5591,  0.5713,  0.3191,  0.2683,  0.4031,  0.1885, -0.4046, -0.9093,
-        -0.0924, -0.5008,  0.1842, -0.4614,  0.4461,  0.2507, -1.3643, -0.4583,
-        -0.6997,  0.4122, -0.9827, -0.4529, -0.6729,  0.1595, -0.2354, -0.6300,
-        -0.5200,  0.5701, -0.1655,  0.0521, -0.2947,  1.5550, -0.6139, -0.4070,
-        -0.9856,  0.0863, -0.5962, -0.3972,  1.0654,  0.0000,  0.4179,  0.9464,
-         0.3835,  1.1266, -0.0424,  0.5093,  0.4666, -0.6983,  0.5687,  0.3391,
-         0.9374, -0.3153,  0.2666, -0.0591, -0.7222, -0.0757,  0.8320,  0.0020,
-         1.2009,  1.0451,  0.6168, -0.2452,  1.3270, -0.0100, -0.1978, -0.0275],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5591,  0.5713,  0.3191,  0.2683,  0.4031,  0.1885, -0.4046, -0.9093,
-        -0.0924, -0.5008,  0.1842, -0.4614,  0.4461,  0.2507, -1.3643, -0.4583,
-        -0.6997,  0.4122, -0.9827, -0.4529, -0.6729,  0.1595, -0.2354, -0.6300,
-        -0.5200,  0.5701, -0.1655,  0.0521, -0.2947,  1.5550, -0.6139, -0.4070,
-        -0.9856,  0.0863, -0.5962, -0.3972,  1.0654,  0.0000,  0.4179,  0.9464,
-         0.3835,  1.1266, -0.0424,  0.5093,  0.4666, -0.6983,  0.5687,  0.3391,
-         0.9374, -0.3153,  0.2666, -0.0591, -0.7222, -0.0757,  0.8320,  0.0000,
-         1.2009,  1.0451,  0.6168, -0.2452,  1.3270, -0.0100, -0.1978, -0.0275],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5591,  0.5713,  0.3191,  0.2683,  0.4031,  0.1885, -0.4046, -0.9093,
-        -0.0924, -0.5008,  0.1842, -0.4614,  0.4461,  0.2507, -1.3643, -0.4583,
-        -0.6997,  0.4122, -0.9827, -0.4529, -0.6729,  0.1595, -0.2354, -0.6300,
-        -0.5200,  0.5701, -0.1655,  0.0521, -0.2947,  1.5550, -0.6139, -0.4070,
-        -0.9856,  0.0863, -0.5962, -0.3972,  1.0654,  0.0000,  0.4179,  0.9464,
-         0.3835,  1.1266, -0.0424,  0.5093,  0.4666, -0.6983,  0.5687,  0.3391,
-         0.9374, -0.3153,  0.2666, -0.0591, -0.7222, -0.0757,  0.8320,  0.0000,
-         1.2009,  1.0451,  0.6168, -0.2452,  1.3270, -0.0100, -0.1978, -0.0275],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5294,  0.5662,  0.3447,  0.2772,  0.4133,  0.1618, -0.4182, -0.9009,
-        -0.0963, -0.4656,  0.1341, -0.4548,  0.4800,  0.2360, -1.3569, -0.4797,
-        -0.7023,  0.4031, -0.9835, -0.4663, -0.6635,  0.1190, -0.2683, -0.5804,
-        -0.5047,  0.5566, -0.0453,  0.0423, -0.2403,  1.5543, -0.5938, -0.3708,
-        -0.9854,  0.0826, -0.6210, -0.4279,  1.0800,  0.0000,  0.4280,  0.9370,
-         0.3243,  1.1275,  0.0757,  0.4907,  0.4579, -0.6528,  0.5595,  0.3604,
-         0.9395, -0.2901,  0.2684, -0.0788, -0.7184, -0.0177,  0.8324,  0.0018,
-         1.1991,  1.0430,  0.5760, -0.1971,  1.3300, -0.0120, -0.1227,  0.1753],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5294,  0.5662,  0.3447,  0.2772,  0.4133,  0.1618, -0.4182, -0.9009,
-        -0.0963, -0.4656,  0.1341, -0.4548,  0.4800,  0.2360, -1.3569, -0.4797,
-        -0.7023,  0.4031, -0.9835, -0.4663, -0.6635,  0.1190, -0.2683, -0.5804,
-        -0.5047,  0.5566, -0.0453,  0.0423, -0.2403,  1.5543, -0.5938, -0.3708,
-        -0.9854,  0.0826, -0.6210, -0.4279,  1.0800,  0.0000,  0.4280,  0.9370,
-         0.3243,  1.1275,  0.0757,  0.4907,  0.4579, -0.6528,  0.5595,  0.3604,
-         0.9395, -0.2901,  0.2684, -0.0788, -0.7184, -0.0177,  0.8324,  0.0000,
-         1.1991,  1.0430,  0.5760, -0.1971,  1.3300, -0.0120, -0.1227,  0.1753],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5294,  0.5662,  0.3447,  0.2772,  0.4133,  0.1618, -0.4182, -0.9009,
-        -0.0963, -0.4656,  0.1341, -0.4548,  0.4800,  0.2360, -1.3569, -0.4797,
-        -0.7023,  0.4031, -0.9835, -0.4663, -0.6635,  0.1190, -0.2683, -0.5804,
-        -0.5047,  0.5566, -0.0453,  0.0423, -0.2403,  1.5543, -0.5938, -0.3708,
-        -0.9854,  0.0826, -0.6210, -0.4279,  1.0800,  0.0000,  0.4280,  0.9370,
-         0.3243,  1.1275,  0.0757,  0.4907,  0.4579, -0.6528,  0.5595,  0.3604,
-         0.9395, -0.2901,  0.2684, -0.0788, -0.7184, -0.0177,  0.8324,  0.0000,
-         1.1991,  1.0430,  0.5760, -0.1971,  1.3300, -0.0120, -0.1227,  0.1753],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 0.5005,  0.5507,  0.3612,  0.2846,  0.4012,  0.1246, -0.4335, -0.8915,
-        -0.1007, -0.4370,  0.0806, -0.4378,  0.5013,  0.2206, -1.3530, -0.4994,
-        -0.6981,  0.3951, -0.9832, -0.4623, -0.6551,  0.0596, -0.2913, -0.5655,
-        -0.4720,  0.5376,  0.0360,  0.0245, -0.2119,  1.5503, -0.5772, -0.3315,
-        -0.9879,  0.0964, -0.6416, -0.4411,  1.0923,  0.0000,  0.4323,  0.9236,
-         0.2467,  1.1289,  0.1601,  0.4739,  0.4536, -0.6137,  0.5505,  0.3725,
-         0.9316, -0.2821,  0.2657, -0.0886, -0.7127,  0.0501,  0.8308,  0.0016,
-         1.1966,  1.0436,  0.5378, -0.1836,  1.3361, -0.0289, -0.0317,  0.3285],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Step tensor([ 0.5005,  0.5507,  0.3612,  0.2846,  0.4012,  0.1246, -0.4335, -0.8915,
-        -0.1007, -0.4370,  0.0806, -0.4378,  0.5013,  0.2206, -1.3530, -0.4994,
-        -0.6981,  0.3951, -0.9832, -0.4623, -0.6551,  0.0596, -0.2913, -0.5655,
-        -0.4720,  0.5376,  0.0360,  0.0245, -0.2119,  1.5503, -0.5772, -0.3315,
-        -0.9879,  0.0964, -0.6416, -0.4411,  1.0923,  0.0000,  0.4323,  0.9236,
-         0.2467,  1.1289,  0.1601,  0.4739,  0.4536, -0.6137,  0.5505,  0.3725,
-         0.9316, -0.2821,  0.2657, -0.0886, -0.7127,  0.0501,  0.8308,  0.0000,
-         1.1966,  1.0436,  0.5378, -0.1836,  1.3361, -0.0289, -0.0317,  0.3285],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5005,  0.5507,  0.3612,  0.2846,  0.4012,  0.1246, -0.4335, -0.8915,
-        -0.1007, -0.4370,  0.0806, -0.4378,  0.5013,  0.2206, -1.3530, -0.4994,
-        -0.6981,  0.3951, -0.9832, -0.4623, -0.6551,  0.0596, -0.2913, -0.5655,
-        -0.4720,  0.5376,  0.0360,  0.0245, -0.2119,  1.5503, -0.5772, -0.3315,
-        -0.9879,  0.0964, -0.6416, -0.4411,  1.0923,  0.0000,  0.4323,  0.9236,
-         0.2467,  1.1289,  0.1601,  0.4739,  0.4536, -0.6137,  0.5505,  0.3725,
-         0.9316, -0.2821,  0.2657, -0.0886, -0.7127,  0.0501,  0.8308,  0.0000,
-         1.1966,  1.0436,  0.5378, -0.1836,  1.3361, -0.0289, -0.0317,  0.3285],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.7434e-01,  5.3136e-01,  3.4447e-01,  2.6340e-01,  3.8525e-01,
-         1.7047e-02, -4.3931e-01, -8.8958e-01, -1.2484e-01, -4.2201e-01,
-         5.8387e-02, -4.2758e-01,  4.8165e-01,  2.4041e-01, -1.3549e+00,
-        -5.0273e-01, -6.8400e-01,  3.8064e-01, -9.7913e-01, -4.5901e-01,
-        -6.4912e-01, -2.2920e-02, -2.8090e-01, -5.8313e-01, -4.4491e-01,
-         5.3314e-01, -2.4882e-04, -3.0611e-02, -2.1588e-01,  1.5445e+00,
-        -5.4475e-01, -2.9967e-01, -9.9086e-01,  9.5095e-02, -6.5237e-01,
-        -4.6093e-01,  1.1036e+00,  0.0000e+00,  4.5985e-01,  9.1634e-01,
-         1.6312e-01,  1.1300e+00,  1.8601e-01,  4.7185e-01,  4.4466e-01,
-        -5.9200e-01,  5.4432e-01,  3.6106e-01,  9.1395e-01, -2.7625e-01,
-         2.7235e-01, -1.5361e-01, -7.0213e-01,  7.8308e-02,  8.3940e-01,
-         1.3793e-03,  1.1936e+00,  1.0505e+00,  5.0011e-01, -2.3361e-01,
-         1.3433e+00, -1.1271e-01,  6.2518e-02,  4.0680e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.7434e-01,  5.3136e-01,  3.4447e-01,  2.6340e-01,  3.8525e-01,
-         1.7047e-02, -4.3931e-01, -8.8958e-01, -1.2484e-01, -4.2201e-01,
-         5.8387e-02, -4.2758e-01,  4.8165e-01,  2.4041e-01, -1.3549e+00,
-        -5.0273e-01, -6.8400e-01,  3.8064e-01, -9.7913e-01, -4.5901e-01,
-        -6.4912e-01, -2.2920e-02, -2.8090e-01, -5.8313e-01, -4.4491e-01,
-         5.3314e-01, -2.4882e-04, -3.0611e-02, -2.1588e-01,  1.5445e+00,
-        -5.4475e-01, -2.9967e-01, -9.9086e-01,  9.5095e-02, -6.5237e-01,
-        -4.6093e-01,  1.1036e+00,  0.0000e+00,  4.5985e-01,  9.1634e-01,
-         1.6312e-01,  1.1300e+00,  1.8601e-01,  4.7185e-01,  4.4466e-01,
-        -5.9200e-01,  5.4432e-01,  3.6106e-01,  9.1395e-01, -2.7625e-01,
-         2.7235e-01, -1.5361e-01, -7.0213e-01,  7.8308e-02,  8.3940e-01,
-         0.0000e+00,  1.1936e+00,  1.0505e+00,  5.0011e-01, -2.3361e-01,
-         1.3433e+00, -1.1271e-01,  6.2518e-02,  4.0680e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.7434e-01,  5.3136e-01,  3.4447e-01,  2.6340e-01,  3.8525e-01,
-         1.7047e-02, -4.3931e-01, -8.8958e-01, -1.2484e-01, -4.2201e-01,
-         5.8387e-02, -4.2758e-01,  4.8165e-01,  2.4041e-01, -1.3549e+00,
-        -5.0273e-01, -6.8400e-01,  3.8064e-01, -9.7913e-01, -4.5901e-01,
-        -6.4912e-01, -2.2920e-02, -2.8090e-01, -5.8313e-01, -4.4491e-01,
-         5.3314e-01, -2.4882e-04, -3.0611e-02, -2.1588e-01,  1.5445e+00,
-        -5.4475e-01, -2.9967e-01, -9.9086e-01,  9.5095e-02, -6.5237e-01,
-        -4.6093e-01,  1.1036e+00,  0.0000e+00,  4.5985e-01,  9.1634e-01,
-         1.6312e-01,  1.1300e+00,  1.8601e-01,  4.7185e-01,  4.4466e-01,
-        -5.9200e-01,  5.4432e-01,  3.6106e-01,  9.1395e-01, -2.7625e-01,
-         2.7235e-01, -1.5361e-01, -7.0213e-01,  7.8308e-02,  8.3940e-01,
-         0.0000e+00,  1.1936e+00,  1.0505e+00,  5.0011e-01, -2.3361e-01,
-         1.3433e+00, -1.1271e-01,  6.2518e-02,  4.0680e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4686e-01,  4.9925e-01,  3.1454e-01,  2.2944e-01,  2.8628e-01,
-        -1.1367e-01, -4.4661e-01, -8.8967e-01, -1.7278e-01, -4.0789e-01,
-         5.8748e-02, -4.5186e-01,  4.7216e-01,  2.7655e-01, -1.3624e+00,
-        -5.0357e-01, -6.6253e-01,  3.6791e-01, -9.7459e-01, -4.7744e-01,
-        -6.2327e-01, -5.5031e-02, -2.3543e-01, -5.7939e-01, -3.9417e-01,
-         5.3328e-01, -7.1986e-02, -1.0139e-01, -2.0998e-01,  1.5374e+00,
-        -5.0120e-01, -2.8643e-01, -9.9618e-01,  6.2610e-02, -6.5971e-01,
-        -4.8190e-01,  1.1151e+00,  0.0000e+00,  5.0631e-01,  9.0601e-01,
-         7.9952e-02,  1.1289e+00,  1.5917e-01,  4.9404e-01,  4.2534e-01,
-        -5.6014e-01,  5.3539e-01,  3.5149e-01,  8.9095e-01, -2.6897e-01,
-         2.8291e-01, -2.1534e-01, -6.8663e-01,  9.5797e-02,  8.4580e-01,
-         1.2193e-03,  1.1923e+00,  1.0567e+00,  4.4160e-01, -2.8526e-01,
-         1.3486e+00, -2.1306e-01,  1.3604e-01,  4.8053e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4469,  0.4993,  0.3145,  0.2294,  0.2863, -0.1137, -0.4466, -0.8897,
-        -0.1728, -0.4079,  0.0587, -0.4519,  0.4722,  0.2765, -1.3624, -0.5036,
-        -0.6625,  0.3679, -0.9746, -0.4774, -0.6233, -0.0550, -0.2354, -0.5794,
-        -0.3942,  0.5333, -0.0720, -0.1014, -0.2100,  1.5374, -0.5012, -0.2864,
-        -0.9962,  0.0626, -0.6597, -0.4819,  1.1151,  0.0000,  0.5063,  0.9060,
-         0.0800,  1.1289,  0.1592,  0.4940,  0.4253, -0.5601,  0.5354,  0.3515,
-         0.8909, -0.2690,  0.2829, -0.2153, -0.6866,  0.0958,  0.8458,  0.0000,
-         1.1923,  1.0567,  0.4416, -0.2853,  1.3486, -0.2131,  0.1360,  0.4805],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4469,  0.4993,  0.3145,  0.2294,  0.2863, -0.1137, -0.4466, -0.8897,
-        -0.1728, -0.4079,  0.0587, -0.4519,  0.4722,  0.2765, -1.3624, -0.5036,
-        -0.6625,  0.3679, -0.9746, -0.4774, -0.6233, -0.0550, -0.2354, -0.5794,
-        -0.3942,  0.5333, -0.0720, -0.1014, -0.2100,  1.5374, -0.5012, -0.2864,
-        -0.9962,  0.0626, -0.6597, -0.4819,  1.1151,  0.0000,  0.5063,  0.9060,
-         0.0800,  1.1289,  0.1592,  0.4940,  0.4253, -0.5601,  0.5354,  0.3515,
-         0.8909, -0.2690,  0.2829, -0.2153, -0.6866,  0.0958,  0.8458,  0.0000,
-         1.1923,  1.0567,  0.4416, -0.2853,  1.3486, -0.2131,  0.1360,  0.4805],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1489e-01,  4.6735e-01,  2.4744e-01,  2.2695e-01,  1.4792e-01,
-        -2.3412e-01, -4.4970e-01, -8.9424e-01, -2.1878e-01, -3.8687e-01,
-        -4.1711e-03, -5.1513e-01,  4.3972e-01,  3.3494e-01, -1.3702e+00,
-        -4.9991e-01, -6.4546e-01,  3.4746e-01, -9.7031e-01, -5.1515e-01,
-        -6.0155e-01, -2.4765e-02, -1.5839e-01, -5.6952e-01, -3.2878e-01,
-         5.3044e-01, -1.3007e-01, -1.5043e-01, -2.1795e-01,  1.5329e+00,
-        -4.4278e-01, -2.7795e-01, -1.0083e+00,  4.5158e-02, -6.6397e-01,
-        -5.2204e-01,  1.1237e+00,  0.0000e+00,  5.4498e-01,  8.9738e-01,
-         1.9331e-02,  1.1237e+00,  9.0092e-02,  5.1898e-01,  4.0142e-01,
-        -5.2836e-01,  5.2314e-01,  3.3275e-01,  8.7549e-01, -2.5520e-01,
-         2.7865e-01, -2.6177e-01, -6.8787e-01,  8.1429e-02,  8.4915e-01,
-         1.0766e-03,  1.1914e+00,  1.0611e+00,  4.0930e-01, -3.1299e-01,
-         1.3517e+00, -2.8304e-01,  1.8263e-01,  5.2554e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4149,  0.4673,  0.2474,  0.2270,  0.1479, -0.2341, -0.4497, -0.8942,
-        -0.2188, -0.3869, -0.0042, -0.5151,  0.4397,  0.3349, -1.3702, -0.4999,
-        -0.6455,  0.3475, -0.9703, -0.5152, -0.6015, -0.0248, -0.1584, -0.5695,
-        -0.3288,  0.5304, -0.1301, -0.1504, -0.2179,  1.5329, -0.4428, -0.2779,
-        -1.0083,  0.0452, -0.6640, -0.5220,  1.1237,  0.0000,  0.5450,  0.8974,
-         0.0193,  1.1237,  0.0901,  0.5190,  0.4014, -0.5284,  0.5231,  0.3328,
-         0.8755, -0.2552,  0.2786, -0.2618, -0.6879,  0.0814,  0.8492,  0.0000,
-         1.1914,  1.0611,  0.4093, -0.3130,  1.3517, -0.2830,  0.1826,  0.5255],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4149,  0.4673,  0.2474,  0.2270,  0.1479, -0.2341, -0.4497, -0.8942,
-        -0.2188, -0.3869, -0.0042, -0.5151,  0.4397,  0.3349, -1.3702, -0.4999,
-        -0.6455,  0.3475, -0.9703, -0.5152, -0.6015, -0.0248, -0.1584, -0.5695,
-        -0.3288,  0.5304, -0.1301, -0.1504, -0.2179,  1.5329, -0.4428, -0.2779,
-        -1.0083,  0.0452, -0.6640, -0.5220,  1.1237,  0.0000,  0.5450,  0.8974,
-         0.0193,  1.1237,  0.0901,  0.5190,  0.4014, -0.5284,  0.5231,  0.3328,
-         0.8755, -0.2552,  0.2786, -0.2618, -0.6879,  0.0814,  0.8492,  0.0000,
-         1.1914,  1.0611,  0.4093, -0.3130,  1.3517, -0.2830,  0.1826,  0.5255],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.8658e-01,  4.3207e-01,  2.0286e-01,  2.3765e-01,  7.5984e-02,
-        -3.3305e-01, -4.5326e-01, -8.9755e-01, -2.5361e-01, -3.7180e-01,
-        -5.4420e-02, -5.6576e-01,  4.2756e-01,  3.6967e-01, -1.3699e+00,
-        -5.0082e-01, -6.3573e-01,  3.1747e-01, -9.7002e-01, -5.3799e-01,
-        -5.8383e-01,  1.8670e-02, -9.7602e-02, -5.7139e-01, -2.7084e-01,
-         5.1331e-01, -1.9138e-01, -1.9863e-01, -2.2566e-01,  1.5297e+00,
-        -4.0576e-01, -2.9125e-01, -1.0203e+00,  4.0524e-02, -6.6513e-01,
-        -5.5201e-01,  1.1272e+00,  0.0000e+00,  5.6757e-01,  8.8297e-01,
-        -2.8388e-02,  1.1182e+00,  5.9735e-02,  5.3286e-01,  3.7572e-01,
-        -5.0664e-01,  5.1590e-01,  3.2802e-01,  8.7402e-01, -2.5172e-01,
-         3.0577e-01, -2.8437e-01, -6.9558e-01,  1.0266e-01,  8.4576e-01,
-         9.4944e-04,  1.1903e+00,  1.0571e+00,  3.9179e-01, -3.5700e-01,
-         1.3521e+00, -3.3488e-01,  2.3143e-01,  5.7027e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3866,  0.4321,  0.2029,  0.2376,  0.0760, -0.3331, -0.4533, -0.8976,
-        -0.2536, -0.3718, -0.0544, -0.5658,  0.4276,  0.3697, -1.3699, -0.5008,
-        -0.6357,  0.3175, -0.9700, -0.5380, -0.5838,  0.0187, -0.0976, -0.5714,
-        -0.2708,  0.5133, -0.1914, -0.1986, -0.2257,  1.5297, -0.4058, -0.2913,
-        -1.0203,  0.0405, -0.6651, -0.5520,  1.1272,  0.0000,  0.5676,  0.8830,
-        -0.0284,  1.1182,  0.0597,  0.5329,  0.3757, -0.5066,  0.5159,  0.3280,
-         0.8740, -0.2517,  0.3058, -0.2844, -0.6956,  0.1027,  0.8458,  0.0000,
-         1.1903,  1.0571,  0.3918, -0.3570,  1.3521, -0.3349,  0.2314,  0.5703],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3866,  0.4321,  0.2029,  0.2376,  0.0760, -0.3331, -0.4533, -0.8976,
-        -0.2536, -0.3718, -0.0544, -0.5658,  0.4276,  0.3697, -1.3699, -0.5008,
-        -0.6357,  0.3175, -0.9700, -0.5380, -0.5838,  0.0187, -0.0976, -0.5714,
-        -0.2708,  0.5133, -0.1914, -0.1986, -0.2257,  1.5297, -0.4058, -0.2913,
-        -1.0203,  0.0405, -0.6651, -0.5520,  1.1272,  0.0000,  0.5676,  0.8830,
-        -0.0284,  1.1182,  0.0597,  0.5329,  0.3757, -0.5066,  0.5159,  0.3280,
-         0.8740, -0.2517,  0.3058, -0.2844, -0.6956,  0.1027,  0.8458,  0.0000,
-         1.1903,  1.0571,  0.3918, -0.3570,  1.3521, -0.3349,  0.2314,  0.5703],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5832e-01,  4.1469e-01,  1.6834e-01,  2.3787e-01,  6.2973e-02,
-        -4.0067e-01, -4.6883e-01, -8.9622e-01, -2.5055e-01, -3.6187e-01,
-        -6.0464e-02, -5.8041e-01,  4.0624e-01,  3.8544e-01, -1.3688e+00,
-        -5.1530e-01, -6.4254e-01,  2.9577e-01, -9.7209e-01, -5.3324e-01,
-        -5.5106e-01,  7.3299e-02, -6.7736e-02, -5.7136e-01, -2.1825e-01,
-         4.7969e-01, -2.5544e-01, -2.5601e-01, -1.9099e-01,  1.5274e+00,
-        -4.1321e-01, -2.9038e-01, -1.0316e+00,  7.1851e-02, -6.6025e-01,
-        -5.6114e-01,  1.1279e+00,  0.0000e+00,  5.6557e-01,  8.6631e-01,
-        -8.3516e-02,  1.1080e+00,  7.3494e-02,  5.3421e-01,  3.6870e-01,
-        -5.0147e-01,  5.1789e-01,  3.1655e-01,  8.7763e-01, -2.7093e-01,
-         3.4129e-01, -2.3053e-01, -7.0188e-01,  1.2035e-01,  8.4321e-01,
-         8.3629e-04,  1.1892e+00,  1.0512e+00,  3.9486e-01, -3.7781e-01,
-         1.3520e+00, -3.9737e-01,  2.6004e-01,  6.2106e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3583,  0.4147,  0.1683,  0.2379,  0.0630, -0.4007, -0.4688, -0.8962,
-        -0.2505, -0.3619, -0.0605, -0.5804,  0.4062,  0.3854, -1.3688, -0.5153,
-        -0.6425,  0.2958, -0.9721, -0.5332, -0.5511,  0.0733, -0.0677, -0.5714,
-        -0.2183,  0.4797, -0.2554, -0.2560, -0.1910,  1.5274, -0.4132, -0.2904,
-        -1.0316,  0.0719, -0.6603, -0.5611,  1.1279,  0.0000,  0.5656,  0.8663,
-        -0.0835,  1.1080,  0.0735,  0.5342,  0.3687, -0.5015,  0.5179,  0.3166,
-         0.8776, -0.2709,  0.3413, -0.2305, -0.7019,  0.1203,  0.8432,  0.0000,
-         1.1892,  1.0512,  0.3949, -0.3778,  1.3520, -0.3974,  0.2600,  0.6211],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3583,  0.4147,  0.1683,  0.2379,  0.0630, -0.4007, -0.4688, -0.8962,
-        -0.2505, -0.3619, -0.0605, -0.5804,  0.4062,  0.3854, -1.3688, -0.5153,
-        -0.6425,  0.2958, -0.9721, -0.5332, -0.5511,  0.0733, -0.0677, -0.5714,
-        -0.2183,  0.4797, -0.2554, -0.2560, -0.1910,  1.5274, -0.4132, -0.2904,
-        -1.0316,  0.0719, -0.6603, -0.5611,  1.1279,  0.0000,  0.5656,  0.8663,
-        -0.0835,  1.1080,  0.0735,  0.5342,  0.3687, -0.5015,  0.5179,  0.3166,
-         0.8776, -0.2709,  0.3413, -0.2305, -0.7019,  0.1203,  0.8432,  0.0000,
-         1.1892,  1.0512,  0.3949, -0.3778,  1.3520, -0.3974,  0.2600,  0.6211],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2526e-01,  4.0956e-01,  1.2180e-01,  1.8539e-01, -5.2661e-03,
-        -4.5072e-01, -4.7692e-01, -8.8702e-01, -2.0640e-01, -3.6031e-01,
-        -5.8457e-02, -5.6045e-01,  3.3965e-01,  3.4610e-01, -1.3687e+00,
-        -5.3847e-01, -6.5621e-01,  2.5606e-01, -9.7466e-01, -5.1314e-01,
-        -5.2988e-01,  9.2788e-02, -6.9188e-02, -5.6227e-01, -2.0315e-01,
-         4.3293e-01, -2.7233e-01, -3.0383e-01, -1.6673e-01,  1.5235e+00,
-        -4.3152e-01, -2.6211e-01, -1.0452e+00,  1.0224e-01, -6.5439e-01,
-        -5.4628e-01,  1.1253e+00,  0.0000e+00,  5.3577e-01,  8.4370e-01,
-        -1.7305e-01,  1.0959e+00,  9.2268e-02,  5.2217e-01,  3.6916e-01,
-        -4.9427e-01,  5.2949e-01,  3.0132e-01,  8.8039e-01, -2.8069e-01,
-         3.8209e-01, -1.7640e-01, -6.9605e-01,  1.3251e-01,  8.4039e-01,
-         7.3576e-04,  1.1878e+00,  1.0463e+00,  4.1179e-01, -4.3116e-01,
-         1.3512e+00, -4.4218e-01,  2.5608e-01,  6.5839e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3253,  0.4096,  0.1218,  0.1854, -0.0053, -0.4507, -0.4769, -0.8870,
-        -0.2064, -0.3603, -0.0585, -0.5605,  0.3397,  0.3461, -1.3687, -0.5385,
-        -0.6562,  0.2561, -0.9747, -0.5131, -0.5299,  0.0928, -0.0692, -0.5623,
-        -0.2032,  0.4329, -0.2723, -0.3038, -0.1667,  1.5235, -0.4315, -0.2621,
-         0.0000,  0.1022, -0.6544, -0.5463,  1.1253,  0.0000,  0.5358,  0.8437,
-        -0.1730,  1.0959,  0.0923,  0.5222,  0.3692, -0.4943,  0.5295,  0.3013,
-         0.8804, -0.2807,  0.3821, -0.1764, -0.6960,  0.1325,  0.8404,  0.0000,
-         1.1878,  1.0463,  0.4118, -0.4312,  1.3512, -0.4422,  0.2561,  0.6584],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3253,  0.4096,  0.1218,  0.1854, -0.0053, -0.4507, -0.4769, -0.8870,
-        -0.2064, -0.3603, -0.0585, -0.5605,  0.3397,  0.3461, -1.3687, -0.5385,
-        -0.6562,  0.2561, -0.9747, -0.5131, -0.5299,  0.0928, -0.0692, -0.5623,
-        -0.2032,  0.4329, -0.2723, -0.3038, -0.1667,  1.5235, -0.4315, -0.2621,
-         0.0000,  0.1022, -0.6544, -0.5463,  1.1253,  0.0000,  0.5358,  0.8437,
-        -0.1730,  1.0959,  0.0923,  0.5222,  0.3692, -0.4943,  0.5295,  0.3013,
-         0.8804, -0.2807,  0.3821, -0.1764, -0.6960,  0.1325,  0.8404,  0.0000,
-         1.1878,  1.0463,  0.4118, -0.4312,  1.3512, -0.4422,  0.2561,  0.6584],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8335e-01,  3.7645e-01,  7.2242e-02,  9.0517e-02, -9.5988e-02,
-        -4.8812e-01, -4.7566e-01, -8.7639e-01, -1.5875e-01, -3.6006e-01,
-        -4.5103e-02, -5.1877e-01,  2.3706e-01,  3.1634e-01, -1.3701e+00,
-        -5.6253e-01, -6.5865e-01,  1.8215e-01, -9.7692e-01, -4.7606e-01,
-        -5.1803e-01,  3.9137e-02, -7.8780e-02, -5.5568e-01, -2.0589e-01,
-         3.8830e-01, -2.7659e-01, -3.4333e-01, -1.3575e-01,  1.5201e+00,
-        -4.5991e-01, -2.4152e-01, -1.1976e-02,  8.1002e-02, -6.4323e-01,
-        -5.2810e-01,  1.1240e+00,  0.0000e+00,  5.3377e-01,  8.2152e-01,
-        -2.8555e-01,  1.0840e+00,  8.8123e-02,  5.1602e-01,  3.7535e-01,
-        -4.8682e-01,  5.4024e-01,  2.8516e-01,  8.8603e-01, -2.7155e-01,
-         3.9116e-01, -1.3488e-01, -6.8366e-01,  4.1572e-02,  8.3335e-01,
-         6.4655e-04,  1.1855e+00,  1.0445e+00,  4.4608e-01, -5.2146e-01,
-         1.3502e+00, -4.7163e-01,  2.6009e-01,  6.6427e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2833,  0.3764,  0.0722,  0.0905, -0.0960, -0.4881, -0.4757, -0.8764,
-        -0.1588, -0.3601, -0.0451, -0.5188,  0.2371,  0.3163, -1.3701, -0.5625,
-        -0.6586,  0.1821, -0.9769, -0.4761, -0.5180,  0.0391, -0.0788, -0.5557,
-        -0.2059,  0.3883, -0.2766, -0.3433, -0.1358,  1.5201, -0.4599, -0.2415,
-         0.0000,  0.0810, -0.6432, -0.5281,  1.1240,  0.0000,  0.5338,  0.8215,
-        -0.2855,  1.0840,  0.0881,  0.5160,  0.3754, -0.4868,  0.5402,  0.2852,
-         0.8860, -0.2715,  0.3912, -0.1349, -0.6837,  0.0416,  0.8334,  0.0000,
-         1.1855,  1.0445,  0.4461, -0.5215,  1.3502, -0.4716,  0.2601,  0.6643],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2833,  0.3764,  0.0722,  0.0905, -0.0960, -0.4881, -0.4757, -0.8764,
-        -0.1588, -0.3601, -0.0451, -0.5188,  0.2371,  0.3163, -1.3701, -0.5625,
-        -0.6586,  0.1821, -0.9769, -0.4761, -0.5180,  0.0391, -0.0788, -0.5557,
-        -0.2059,  0.3883, -0.2766, -0.3433, -0.1358,  1.5201, -0.4599, -0.2415,
-         0.0000,  0.0810, -0.6432, -0.5281,  1.1240,  0.0000,  0.5338,  0.8215,
-        -0.2855,  1.0840,  0.0881,  0.5160,  0.3754, -0.4868,  0.5402,  0.2852,
-         0.8860, -0.2715,  0.3912, -0.1349, -0.6837,  0.0416,  0.8334,  0.0000,
-         1.1855,  1.0445,  0.4461, -0.5215,  1.3502, -0.4716,  0.2601,  0.6643],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4234e-01,  3.3863e-01,  3.4895e-02,  2.5590e-02, -1.3559e-01,
-        -5.2509e-01, -4.6994e-01, -8.6440e-01, -1.1404e-01, -3.5915e-01,
-        -7.6912e-03, -4.9026e-01,  1.2877e-01,  3.0387e-01, -1.3708e+00,
-        -5.8762e-01, -6.5281e-01,  8.4058e-02, -9.7731e-01, -4.4821e-01,
-        -4.9260e-01, -6.7715e-02, -9.7038e-02, -5.2340e-01, -2.1808e-01,
-         3.5036e-01, -2.6479e-01, -3.6615e-01, -6.8418e-02,  1.5187e+00,
-        -4.6735e-01, -2.1986e-01, -1.0512e-02,  2.9859e-02, -6.3482e-01,
-        -5.1298e-01,  1.1219e+00,  0.0000e+00,  5.2084e-01,  7.9900e-01,
-        -3.8237e-01,  1.0663e+00,  7.0612e-02,  5.1228e-01,  3.7495e-01,
-        -4.7571e-01,  5.4339e-01,  2.8009e-01,  8.9329e-01, -2.4979e-01,
-         3.5966e-01, -1.0902e-01, -6.6832e-01, -6.5344e-02,  8.2540e-01,
-         5.6750e-04,  1.1828e+00,  1.0378e+00,  4.8432e-01, -5.6467e-01,
-         1.3480e+00, -4.7184e-01,  2.4507e-01,  6.6420e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2423,  0.3386,  0.0349,  0.0256, -0.1356, -0.5251, -0.4699, -0.8644,
-        -0.1140, -0.3591, -0.0077, -0.4903,  0.1288,  0.3039, -1.3708, -0.5876,
-        -0.6528,  0.0841, -0.9773, -0.4482, -0.4926, -0.0677, -0.0970, -0.5234,
-        -0.2181,  0.3504, -0.2648, -0.3661, -0.0684,  1.5187, -0.4673, -0.2199,
-         0.0000,  0.0299, -0.6348, -0.5130,  1.1219,  0.0000,  0.5208,  0.7990,
-        -0.3824,  1.0663,  0.0706,  0.5123,  0.3750, -0.4757,  0.5434,  0.2801,
-         0.8933, -0.2498,  0.3597, -0.1090, -0.6683, -0.0653,  0.8254,  0.0000,
-         1.1828,  1.0378,  0.4843, -0.5647,  1.3480, -0.4718,  0.2451,  0.6642],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2423,  0.3386,  0.0349,  0.0256, -0.1356, -0.5251, -0.4699, -0.8644,
-        -0.1140, -0.3591, -0.0077, -0.4903,  0.1288,  0.3039, -1.3708, -0.5876,
-        -0.6528,  0.0841, -0.9773, -0.4482, -0.4926, -0.0677, -0.0970, -0.5234,
-        -0.2181,  0.3504, -0.2648, -0.3661, -0.0684,  1.5187, -0.4673, -0.2199,
-         0.0000,  0.0299, -0.6348, -0.5130,  1.1219,  0.0000,  0.5208,  0.7990,
-        -0.3824,  1.0663,  0.0706,  0.5123,  0.3750, -0.4757,  0.5434,  0.2801,
-         0.8933, -0.2498,  0.3597, -0.1090, -0.6683, -0.0653,  0.8254,  0.0000,
-         1.1828,  1.0378,  0.4843, -0.5647,  1.3480, -0.4718,  0.2451,  0.6642],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.2510e-01,  3.0338e-01,  1.1492e-02, -1.2641e-02, -1.3784e-01,
-        -5.4989e-01, -4.5570e-01, -8.4936e-01, -8.4999e-02, -3.5397e-01,
-         5.5501e-02, -4.6702e-01,  1.1279e-02,  3.0916e-01, -1.3711e+00,
-        -6.0596e-01, -6.4047e-01, -4.1169e-02, -9.7885e-01, -4.2125e-01,
-        -4.5465e-01, -1.2476e-01, -1.0983e-01, -4.8359e-01, -1.8692e-01,
-         2.9702e-01, -2.3227e-01, -3.7419e-01,  3.2989e-02,  1.5181e+00,
-        -4.8642e-01, -2.0648e-01, -9.2159e-03, -2.4378e-02, -6.2438e-01,
-        -4.9296e-01,  1.1179e+00,  0.0000e+00,  5.2893e-01,  7.7672e-01,
-        -4.4759e-01,  1.0488e+00,  6.8724e-02,  5.2019e-01,  3.5787e-01,
-        -4.7100e-01,  5.5387e-01,  2.8443e-01,  9.1257e-01, -2.1086e-01,
-         3.3195e-01, -5.1742e-02, -6.5144e-01, -1.1947e-01,  8.0996e-01,
-         4.9755e-04,  1.1798e+00,  1.0248e+00,  5.3054e-01, -5.8665e-01,
-         1.3438e+00, -4.5631e-01,  2.5315e-01,  6.8225e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2251,  0.3034,  0.0115, -0.0126, -0.1378, -0.5499, -0.4557, -0.8494,
-        -0.0850, -0.3540,  0.0555, -0.4670,  0.0113,  0.3092, -1.3711, -0.6060,
-        -0.6405, -0.0412, -0.9789, -0.4212, -0.4547, -0.1248, -0.1098, -0.4836,
-        -0.1869,  0.2970, -0.2323, -0.3742,  0.0330,  1.5181, -0.4864, -0.2065,
-         0.0000, -0.0244, -0.6244, -0.4930,  1.1179,  0.0000,  0.5289,  0.7767,
-        -0.4476,  1.0488,  0.0687,  0.5202,  0.3579, -0.4710,  0.5539,  0.2844,
-         0.9126, -0.2109,  0.3319, -0.0517, -0.6514, -0.1195,  0.8100,  0.0000,
-         1.1798,  1.0248,  0.5305, -0.5867,  1.3438, -0.4563,  0.2532,  0.6823],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2251,  0.3034,  0.0115, -0.0126, -0.1378, -0.5499, -0.4557, -0.8494,
-        -0.0850, -0.3540,  0.0555, -0.4670,  0.0113,  0.3092, -1.3711, -0.6060,
-        -0.6405, -0.0412, -0.9789, -0.4212, -0.4547, -0.1248, -0.1098, -0.4836,
-        -0.1869,  0.2970, -0.2323, -0.3742,  0.0330,  1.5181, -0.4864, -0.2065,
-         0.0000, -0.0244, -0.6244, -0.4930,  1.1179,  0.0000,  0.5289,  0.7767,
-        -0.4476,  1.0488,  0.0687,  0.5202,  0.3579, -0.4710,  0.5539,  0.2844,
-         0.9126, -0.2109,  0.3319, -0.0517, -0.6514, -0.1195,  0.8100,  0.0000,
-         1.1798,  1.0248,  0.5305, -0.5867,  1.3438, -0.4563,  0.2532,  0.6823],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3037e-01,  2.9201e-01, -3.1206e-02, -4.7282e-02, -6.9462e-02,
-        -5.7195e-01, -4.2314e-01, -8.3676e-01, -4.6083e-02, -3.5037e-01,
-         1.1067e-01, -4.5478e-01, -1.2709e-01,  2.9887e-01, -1.3727e+00,
-        -6.0974e-01, -6.1831e-01, -1.3539e-01, -9.7930e-01, -4.0291e-01,
-        -4.0710e-01, -1.1956e-01, -1.1495e-01, -4.2807e-01, -1.0415e-01,
-         2.6321e-01, -1.8016e-01, -3.6382e-01,  1.2106e-01,  1.5182e+00,
-        -4.9538e-01, -2.0564e-01, -8.0709e-03, -1.8849e-02, -6.1295e-01,
-        -4.7795e-01,  1.1156e+00,  0.0000e+00,  5.4083e-01,  7.5445e-01,
-        -4.8830e-01,  1.0327e+00,  8.3477e-02,  5.3046e-01,  3.3649e-01,
-        -4.4017e-01,  5.6419e-01,  2.9533e-01,  9.3467e-01, -1.5506e-01,
-         3.1528e-01,  1.9320e-02, -6.2648e-01, -1.0191e-01,  7.9118e-01,
-         4.3574e-04,  1.1760e+00,  1.0123e+00,  5.8793e-01, -5.6897e-01,
-         1.3386e+00, -4.3941e-01,  2.6283e-01,  7.1293e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2304,  0.2920, -0.0312, -0.0473, -0.0695, -0.5719, -0.4231, -0.8368,
-        -0.0461, -0.3504,  0.1107, -0.4548, -0.1271,  0.2989, -1.3727, -0.6097,
-        -0.6183, -0.1354, -0.9793, -0.4029, -0.4071, -0.1196, -0.1149, -0.4281,
-        -0.1041,  0.2632, -0.1802, -0.3638,  0.1211,  1.5182, -0.4954, -0.2056,
-         0.0000, -0.0188, -0.6130, -0.4779,  1.1156,  0.0000,  0.5408,  0.7545,
-        -0.4883,  1.0327,  0.0835,  0.5305,  0.3365, -0.4402,  0.5642,  0.2953,
-         0.9347, -0.1551,  0.3153,  0.0193, -0.6265, -0.1019,  0.7912,  0.0000,
-         1.1760,  1.0123,  0.5879, -0.5690,  1.3386, -0.4394,  0.2628,  0.7129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2304,  0.2920, -0.0312, -0.0473, -0.0695, -0.5719, -0.4231, -0.8368,
-        -0.0461, -0.3504,  0.1107, -0.4548, -0.1271,  0.2989, -1.3727, -0.6097,
-        -0.6183, -0.1354, -0.9793, -0.4029, -0.4071, -0.1196, -0.1149, -0.4281,
-        -0.1041,  0.2632, -0.1802, -0.3638,  0.1211,  1.5182, -0.4954, -0.2056,
-         0.0000, -0.0188, -0.6130, -0.4779,  1.1156,  0.0000,  0.5408,  0.7545,
-        -0.4883,  1.0327,  0.0835,  0.5305,  0.3365, -0.4402,  0.5642,  0.2953,
-         0.9347, -0.1551,  0.3153,  0.0193, -0.6265, -0.1019,  0.7912,  0.0000,
-         1.1760,  1.0123,  0.5879, -0.5690,  1.3386, -0.4394,  0.2628,  0.7129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.5148e-01,  2.3096e-01, -7.0476e-02, -6.4816e-02, -9.3666e-03,
-        -5.8261e-01, -3.8003e-01, -8.2577e-01, -1.6536e-03, -3.2015e-01,
-         8.9324e-02, -4.4685e-01, -2.5598e-01,  3.6009e-01, -1.3754e+00,
-        -6.1005e-01, -6.0531e-01, -2.1327e-01, -9.7892e-01, -4.0289e-01,
-        -3.4515e-01, -1.9114e-02, -8.7746e-02, -4.1921e-01,  3.0078e-02,
-         2.6468e-01, -8.9257e-02, -3.2142e-01,  1.7207e-01,  1.5187e+00,
-        -5.0466e-01, -2.0565e-01, -7.0603e-03,  3.9014e-02, -6.0936e-01,
-        -4.9118e-01,  1.1190e+00,  0.0000e+00,  6.3535e-01,  7.5230e-01,
-        -4.8854e-01,  1.0188e+00,  7.2376e-02,  5.7060e-01,  3.0097e-01,
-        -3.7380e-01,  5.6350e-01,  3.4461e-01,  9.5013e-01, -1.0100e-01,
-         2.7893e-01,  1.1064e-01, -5.8045e-01,  7.0618e-03,  7.7870e-01,
-         3.8117e-04,  1.1715e+00,  1.0237e+00,  6.3909e-01, -5.3812e-01,
-         1.3357e+00, -3.9157e-01,  2.7676e-01,  7.2371e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2515,  0.2310, -0.0705, -0.0648, -0.0094, -0.5826, -0.3800, -0.8258,
-        -0.0017, -0.3201,  0.0893, -0.4469, -0.2560,  0.3601, -1.3754, -0.6100,
-        -0.6053, -0.2133, -0.9789, -0.4029, -0.3451, -0.0191, -0.0877, -0.4192,
-         0.0301,  0.2647, -0.0893, -0.3214,  0.1721,  1.5187, -0.5047, -0.2057,
-         0.0000,  0.0390, -0.6094, -0.4912,  1.1190,  0.0000,  0.6353,  0.7523,
-        -0.4885,  1.0188,  0.0724,  0.5706,  0.3010, -0.3738,  0.5635,  0.3446,
-         0.9501, -0.1010,  0.2789,  0.1106, -0.5805,  0.0071,  0.7787,  0.0000,
-         1.1715,  1.0237,  0.6391, -0.5381,  1.3357, -0.3916,  0.2768,  0.7237],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2515,  0.2310, -0.0705, -0.0648, -0.0094, -0.5826, -0.3800, -0.8258,
-        -0.0017, -0.3201,  0.0893, -0.4469, -0.2560,  0.3601, -1.3754, -0.6100,
-        -0.6053, -0.2133, -0.9789, -0.4029, -0.3451, -0.0191, -0.0877, -0.4192,
-         0.0301,  0.2647, -0.0893, -0.3214,  0.1721,  1.5187, -0.5047, -0.2057,
-         0.0000,  0.0390, -0.6094, -0.4912,  1.1190,  0.0000,  0.6353,  0.7523,
-        -0.4885,  1.0188,  0.0724,  0.5706,  0.3010, -0.3738,  0.5635,  0.3446,
-         0.9501, -0.1010,  0.2789,  0.1106, -0.5805,  0.0071,  0.7787,  0.0000,
-         1.1715,  1.0237,  0.6391, -0.5381,  1.3357, -0.3916,  0.2768,  0.7237],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8784e-01,  1.9606e-01, -3.0975e-02, -6.8456e-02,  7.4342e-02,
-        -5.8470e-01, -3.3447e-01, -8.1293e-01,  4.6377e-02, -3.0117e-01,
-         1.0643e-01, -4.1663e-01, -2.9100e-01,  4.1550e-01, -1.3795e+00,
-        -6.0941e-01, -5.9117e-01, -2.8134e-01, -9.7999e-01, -3.9381e-01,
-        -2.8723e-01,  1.5130e-02, -9.7310e-02, -3.9246e-01,  9.1345e-02,
-         2.7114e-01,  2.5741e-02, -2.9561e-01,  2.3816e-01,  1.5207e+00,
-        -5.1283e-01, -1.8798e-01, -6.1695e-03,  4.0847e-02, -6.0354e-01,
-        -5.0087e-01,  1.1245e+00,  0.0000e+00,  7.2577e-01,  7.5191e-01,
-        -4.8876e-01,  1.0086e+00,  9.4448e-02,  5.9994e-01,  2.7256e-01,
-        -3.2912e-01,  5.6066e-01,  3.9535e-01,  9.5921e-01, -5.0561e-03,
-         2.6296e-01,  1.6482e-01, -5.5183e-01,  1.1016e-01,  7.6922e-01,
-         3.3308e-04,  1.1670e+00,  1.0388e+00,  6.9050e-01, -5.1265e-01,
-         1.3343e+00, -3.3464e-01,  2.9953e-01,  7.2662e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2878,  0.1961, -0.0310, -0.0685,  0.0743, -0.5847, -0.3345, -0.8129,
-         0.0464, -0.3012,  0.1064, -0.4166, -0.2910,  0.4155, -1.3795, -0.6094,
-        -0.5912, -0.2813, -0.9800, -0.3938, -0.2872,  0.0151, -0.0973, -0.3925,
-         0.0913,  0.2711,  0.0257, -0.2956,  0.2382,  1.5207, -0.5128, -0.1880,
-         0.0000,  0.0408, -0.6035, -0.5009,  1.1245,  0.0000,  0.7258,  0.7519,
-        -0.4888,  1.0086,  0.0944,  0.5999,  0.2726, -0.3291,  0.5607,  0.3953,
-         0.9592, -0.0051,  0.2630,  0.1648, -0.5518,  0.1102,  0.7692,  0.0000,
-         1.1670,  1.0388,  0.6905, -0.5126,  1.3343, -0.3346,  0.2995,  0.7266],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2878,  0.1961, -0.0310, -0.0685,  0.0743, -0.5847, -0.3345, -0.8129,
-         0.0464, -0.3012,  0.1064, -0.4166, -0.2910,  0.4155, -1.3795, -0.6094,
-        -0.5912, -0.2813, -0.9800, -0.3938, -0.2872,  0.0151, -0.0973, -0.3925,
-         0.0913,  0.2711,  0.0257, -0.2956,  0.2382,  1.5207, -0.5128, -0.1880,
-         0.0000,  0.0408, -0.6035, -0.5009,  1.1245,  0.0000,  0.7258,  0.7519,
-        -0.4888,  1.0086,  0.0944,  0.5999,  0.2726, -0.3291,  0.5607,  0.3953,
-         0.9592, -0.0051,  0.2630,  0.1648, -0.5518,  0.1102,  0.7692,  0.0000,
-         1.1670,  1.0388,  0.6905, -0.5126,  1.3343, -0.3346,  0.2995,  0.7266],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2016e-01,  1.6902e-01,  6.5960e-02, -5.2243e-02,  5.4671e-02,
-        -5.8142e-01, -2.9816e-01, -7.9898e-01,  9.6247e-02, -2.7137e-01,
-         1.3916e-01, -3.9343e-01, -3.1070e-01,  4.5726e-01, -1.3844e+00,
-        -6.0476e-01, -5.8012e-01, -3.5720e-01, -9.8188e-01, -3.8313e-01,
-        -2.2709e-01,  4.4916e-02, -9.1694e-02, -4.2204e-01,  1.6743e-01,
-         2.7368e-01,  1.5190e-01, -2.5995e-01,  3.1660e-01,  1.5222e+00,
-        -5.1727e-01, -1.6358e-01, -5.3853e-03,  4.5528e-02, -5.9365e-01,
-        -5.0034e-01,  1.1287e+00,  0.0000e+00,  8.0447e-01,  7.4696e-01,
-        -4.7955e-01,  1.0071e+00,  1.3333e-01,  6.2559e-01,  2.7621e-01,
-        -2.8693e-01,  5.5761e-01,  4.3043e-01,  9.5421e-01,  1.2754e-01,
-         2.5680e-01,  2.3589e-01, -5.3667e-01,  1.7479e-01,  7.6241e-01,
-         2.9074e-04,  1.1624e+00,  1.0563e+00,  7.3782e-01, -5.1307e-01,
-         1.3352e+00, -2.6328e-01,  3.2791e-01,  7.1292e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3202,  0.1690,  0.0660, -0.0522,  0.0547, -0.5814, -0.2982, -0.7990,
-         0.0962, -0.2714,  0.1392, -0.3934, -0.3107,  0.4573, -1.3844, -0.6048,
-        -0.5801, -0.3572, -0.9819, -0.3831, -0.2271,  0.0449, -0.0917, -0.4220,
-         0.1674,  0.2737,  0.1519, -0.2600,  0.3166,  1.5222, -0.5173, -0.1636,
-         0.0000,  0.0455, -0.5936, -0.5003,  1.1287,  0.0000,  0.8045,  0.7470,
-        -0.4796,  1.0071,  0.1333,  0.6256,  0.2762, -0.2869,  0.5576,  0.4304,
-         0.9542,  0.1275,  0.2568,  0.2359, -0.5367,  0.1748,  0.7624,  0.0000,
-         1.1624,  1.0563,  0.7378, -0.5131,  1.3352, -0.2633,  0.3279,  0.7129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3202,  0.1690,  0.0660, -0.0522,  0.0547, -0.5814, -0.2982, -0.7990,
-         0.0962, -0.2714,  0.1392, -0.3934, -0.3107,  0.4573, -1.3844, -0.6048,
-        -0.5801, -0.3572, -0.9819, -0.3831, -0.2271,  0.0449, -0.0917, -0.4220,
-         0.1674,  0.2737,  0.1519, -0.2600,  0.3166,  1.5222, -0.5173, -0.1636,
-         0.0000,  0.0455, -0.5936, -0.5003,  1.1287,  0.0000,  0.8045,  0.7470,
-        -0.4796,  1.0071,  0.1333,  0.6256,  0.2762, -0.2869,  0.5576,  0.4304,
-         0.9542,  0.1275,  0.2568,  0.2359, -0.5367,  0.1748,  0.7624,  0.0000,
-         1.1624,  1.0563,  0.7378, -0.5131,  1.3352, -0.2633,  0.3279,  0.7129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3870e-01,  1.9632e-01,  1.2374e-01, -2.8232e-02,  1.9848e-02,
-        -5.7453e-01, -2.6936e-01, -7.8782e-01,  1.5332e-01, -2.4755e-01,
-         1.0010e-01, -3.4842e-01, -3.3096e-01,  4.5476e-01, -1.3890e+00,
-        -5.9188e-01, -5.6845e-01, -4.2543e-01, -9.8422e-01, -3.6692e-01,
-        -1.7175e-01,  6.1644e-02, -9.7776e-02, -4.2134e-01,  2.1071e-01,
-         2.5768e-01,  2.6931e-01, -2.3119e-01,  3.5724e-01,  1.5225e+00,
-        -4.9909e-01, -1.5195e-01, -4.6958e-03,  1.4643e-03, -5.7885e-01,
-        -4.9063e-01,  1.1356e+00,  0.0000e+00,  8.6379e-01,  7.3703e-01,
-        -4.7238e-01,  1.0119e+00,  1.5884e-01,  6.3738e-01,  2.8576e-01,
-        -2.4326e-01,  5.5101e-01,  4.3891e-01,  9.4219e-01,  2.1345e-01,
-         2.6593e-01,  2.6503e-01, -5.3963e-01,  2.4435e-01,  7.5653e-01,
-         2.5352e-04,  1.1580e+00,  1.0762e+00,  7.7085e-01, -5.1182e-01,
-         1.3385e+00, -1.9836e-01,  3.3044e-01,  7.1001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.3870e-01,  1.9632e-01,  1.2374e-01, -2.8232e-02,  1.9848e-02,
-        -5.7453e-01, -2.6936e-01, -7.8782e-01,  1.5332e-01, -2.4755e-01,
-         1.0010e-01, -3.4842e-01, -3.3096e-01,  4.5476e-01, -1.3890e+00,
-        -5.9188e-01, -5.6845e-01, -4.2543e-01, -9.8422e-01, -3.6692e-01,
-        -1.7175e-01,  6.1644e-02, -9.7776e-02, -4.2134e-01,  2.1071e-01,
-         2.5768e-01,  2.6931e-01, -2.3119e-01,  3.5724e-01,  1.5225e+00,
-        -4.9909e-01, -1.5195e-01,  0.0000e+00,  1.4643e-03, -5.7885e-01,
-        -4.9063e-01,  1.1356e+00,  0.0000e+00,  8.6379e-01,  7.3703e-01,
-        -4.7238e-01,  1.0119e+00,  1.5884e-01,  6.3738e-01,  2.8576e-01,
-        -2.4326e-01,  5.5101e-01,  4.3891e-01,  9.4219e-01,  2.1345e-01,
-         2.6593e-01,  2.6503e-01, -5.3963e-01,  2.4435e-01,  7.5653e-01,
-         0.0000e+00,  1.1580e+00,  1.0762e+00,  7.7085e-01, -5.1182e-01,
-         1.3385e+00, -1.9836e-01,  3.3044e-01,  7.1001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.3870e-01,  1.9632e-01,  1.2374e-01, -2.8232e-02,  1.9848e-02,
-        -5.7453e-01, -2.6936e-01, -7.8782e-01,  1.5332e-01, -2.4755e-01,
-         1.0010e-01, -3.4842e-01, -3.3096e-01,  4.5476e-01, -1.3890e+00,
-        -5.9188e-01, -5.6845e-01, -4.2543e-01, -9.8422e-01, -3.6692e-01,
-        -1.7175e-01,  6.1644e-02, -9.7776e-02, -4.2134e-01,  2.1071e-01,
-         2.5768e-01,  2.6931e-01, -2.3119e-01,  3.5724e-01,  1.5225e+00,
-        -4.9909e-01, -1.5195e-01,  0.0000e+00,  1.4643e-03, -5.7885e-01,
-        -4.9063e-01,  1.1356e+00,  0.0000e+00,  8.6379e-01,  7.3703e-01,
-        -4.7238e-01,  1.0119e+00,  1.5884e-01,  6.3738e-01,  2.8576e-01,
-        -2.4326e-01,  5.5101e-01,  4.3891e-01,  9.4219e-01,  2.1345e-01,
-         2.6593e-01,  2.6503e-01, -5.3963e-01,  2.4435e-01,  7.5653e-01,
-         0.0000e+00,  1.1580e+00,  1.0762e+00,  7.7085e-01, -5.1182e-01,
-         1.3385e+00, -1.9836e-01,  3.3044e-01,  7.1001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7427e-01,  2.3862e-01,  1.1569e-01, -1.1472e-04, -4.9408e-02,
-        -5.6334e-01, -2.7061e-01, -7.7616e-01,  1.8241e-01, -2.4802e-01,
-        -9.7585e-02, -2.3705e-01, -2.7083e-01,  4.4199e-01, -1.3913e+00,
-        -5.8010e-01, -5.5689e-01, -5.0147e-01, -9.8966e-01, -3.4045e-01,
-        -1.3759e-01,  6.7824e-02, -1.2590e-01, -3.5126e-01,  2.3472e-01,
-         2.4121e-01,  2.6258e-01, -2.5120e-01,  3.2060e-01,  1.5240e+00,
-        -4.6840e-01, -1.3575e-01, -4.0903e-03, -6.3776e-02, -5.6554e-01,
-        -4.8846e-01,  1.1431e+00,  0.0000e+00,  9.1539e-01,  7.3696e-01,
-        -5.0517e-01,  1.0125e+00,  1.0421e-01,  6.3970e-01,  2.9006e-01,
-        -2.6576e-01,  5.4095e-01,  4.0371e-01,  9.2974e-01,  2.4050e-01,
-         2.3361e-01,  2.7879e-01, -5.4637e-01,  2.5070e-01,  7.5700e-01,
-         2.2083e-04,  1.1535e+00,  1.1003e+00,  7.8545e-01, -5.0040e-01,
-         1.3426e+00, -1.6117e-01,  3.2458e-01,  6.5906e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.7427e-01,  2.3862e-01,  1.1569e-01, -1.1472e-04, -4.9408e-02,
-        -5.6334e-01, -2.7061e-01, -7.7616e-01,  1.8241e-01, -2.4802e-01,
-        -9.7585e-02, -2.3705e-01, -2.7083e-01,  4.4199e-01, -1.3913e+00,
-        -5.8010e-01, -5.5689e-01, -5.0147e-01, -9.8966e-01, -3.4045e-01,
-        -1.3759e-01,  6.7824e-02, -1.2590e-01, -3.5126e-01,  2.3472e-01,
-         2.4121e-01,  2.6258e-01, -2.5120e-01,  3.2060e-01,  1.5240e+00,
-        -4.6840e-01, -1.3575e-01,  0.0000e+00, -6.3776e-02, -5.6554e-01,
-        -4.8846e-01,  1.1431e+00,  0.0000e+00,  9.1539e-01,  7.3696e-01,
-        -5.0517e-01,  1.0125e+00,  1.0421e-01,  6.3970e-01,  2.9006e-01,
-        -2.6576e-01,  5.4095e-01,  4.0371e-01,  9.2974e-01,  2.4050e-01,
-         2.3361e-01,  2.7879e-01, -5.4637e-01,  2.5070e-01,  7.5700e-01,
-         0.0000e+00,  1.1535e+00,  1.1003e+00,  7.8545e-01, -5.0040e-01,
-         1.3426e+00, -1.6117e-01,  3.2458e-01,  6.5906e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.7427e-01,  2.3862e-01,  1.1569e-01, -1.1472e-04, -4.9408e-02,
-        -5.6334e-01, -2.7061e-01, -7.7616e-01,  1.8241e-01, -2.4802e-01,
-        -9.7585e-02, -2.3705e-01, -2.7083e-01,  4.4199e-01, -1.3913e+00,
-        -5.8010e-01, -5.5689e-01, -5.0147e-01, -9.8966e-01, -3.4045e-01,
-        -1.3759e-01,  6.7824e-02, -1.2590e-01, -3.5126e-01,  2.3472e-01,
-         2.4121e-01,  2.6258e-01, -2.5120e-01,  3.2060e-01,  1.5240e+00,
-        -4.6840e-01, -1.3575e-01,  0.0000e+00, -6.3776e-02, -5.6554e-01,
-        -4.8846e-01,  1.1431e+00,  0.0000e+00,  9.1539e-01,  7.3696e-01,
-        -5.0517e-01,  1.0125e+00,  1.0421e-01,  6.3970e-01,  2.9006e-01,
-        -2.6576e-01,  5.4095e-01,  4.0371e-01,  9.2974e-01,  2.4050e-01,
-         2.3361e-01,  2.7879e-01, -5.4637e-01,  2.5070e-01,  7.5700e-01,
-         0.0000e+00,  1.1535e+00,  1.1003e+00,  7.8545e-01, -5.0040e-01,
-         1.3426e+00, -1.6117e-01,  3.2458e-01,  6.5906e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0275e-01,  3.0295e-01,  8.7616e-02, -6.8473e-03, -2.3629e-03,
-        -5.5596e-01, -3.0572e-01, -7.6919e-01,  2.0723e-01, -2.5831e-01,
-        -2.8551e-01, -1.6403e-01, -2.0650e-01,  4.3972e-01, -1.3927e+00,
-        -5.8607e-01, -5.4892e-01, -5.6838e-01, -9.9597e-01, -3.0391e-01,
-        -1.1824e-01,  4.8019e-02, -1.5658e-01, -2.7805e-01,  2.4039e-01,
-         2.0225e-01,  2.0949e-01, -2.7274e-01,  2.8724e-01,  1.5262e+00,
-        -4.3072e-01, -9.9075e-02, -3.5592e-03, -1.1320e-01, -5.5375e-01,
-        -4.6086e-01,  1.1487e+00,  0.0000e+00,  9.4047e-01,  7.3973e-01,
-        -5.2195e-01,  1.0119e+00,  2.7601e-02,  6.4003e-01,  2.8693e-01,
-        -2.7280e-01,  5.3535e-01,  3.3758e-01,  9.2064e-01,  2.3349e-01,
-         1.7206e-01,  3.0498e-01, -5.5013e-01,  2.1792e-01,  7.5538e-01,
-         1.9216e-04,  1.1493e+00,  1.1191e+00,  7.8044e-01, -4.9833e-01,
-         1.3452e+00, -1.6168e-01,  3.0656e-01,  6.0137e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4027,  0.3029,  0.0876, -0.0068, -0.0024, -0.5560, -0.3057, -0.7692,
-         0.2072, -0.2583, -0.2855, -0.1640, -0.2065,  0.4397, -1.3927, -0.5861,
-        -0.5489, -0.5684, -0.9960, -0.3039, -0.1182,  0.0480, -0.1566, -0.2780,
-         0.2404,  0.2023,  0.2095, -0.2727,  0.2872,  1.5262, -0.4307, -0.0991,
-         0.0000, -0.1132, -0.5538, -0.4609,  1.1487,  0.0000,  0.9405,  0.7397,
-        -0.5219,  1.0119,  0.0276,  0.6400,  0.2869, -0.2728,  0.5353,  0.3376,
-         0.9206,  0.2335,  0.1721,  0.3050, -0.5501,  0.2179,  0.7554,  0.0000,
-         1.1493,  1.1191,  0.7804, -0.4983,  1.3452, -0.1617,  0.3066,  0.6014],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4027,  0.3029,  0.0876, -0.0068, -0.0024, -0.5560, -0.3057, -0.7692,
-         0.2072, -0.2583, -0.2855, -0.1640, -0.2065,  0.4397, -1.3927, -0.5861,
-        -0.5489, -0.5684, -0.9960, -0.3039, -0.1182,  0.0480, -0.1566, -0.2780,
-         0.2404,  0.2023,  0.2095, -0.2727,  0.2872,  1.5262, -0.4307, -0.0991,
-         0.0000, -0.1132, -0.5538, -0.4609,  1.1487,  0.0000,  0.9405,  0.7397,
-        -0.5219,  1.0119,  0.0276,  0.6400,  0.2869, -0.2728,  0.5353,  0.3376,
-         0.9206,  0.2335,  0.1721,  0.3050, -0.5501,  0.2179,  0.7554,  0.0000,
-         1.1493,  1.1191,  0.7804, -0.4983,  1.3452, -0.1617,  0.3066,  0.6014],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3692e-01,  3.4762e-01,  7.3149e-02, -2.2991e-02, -2.9774e-02,
-        -5.4568e-01, -3.6261e-01, -7.6319e-01,  2.4172e-01, -2.5998e-01,
-        -4.0015e-01, -1.8170e-01, -1.8017e-01,  4.4415e-01, -1.3868e+00,
-        -6.0537e-01, -5.6873e-01, -6.0433e-01, -1.0051e+00, -2.8043e-01,
-        -7.5100e-02,  5.6049e-02, -1.8048e-01, -2.6297e-01,  2.4726e-01,
-         1.7628e-01,  1.7304e-01, -2.7591e-01,  2.5943e-01,  1.5298e+00,
-        -4.2249e-01, -6.5333e-02, -3.0940e-03, -1.1292e-01, -5.4387e-01,
-        -4.3574e-01,  1.1507e+00,  0.0000e+00,  9.4352e-01,  7.3697e-01,
-        -4.9464e-01,  1.0097e+00,  7.3701e-04,  6.3928e-01,  2.5555e-01,
-        -2.0775e-01,  5.3094e-01,  2.6114e-01,  9.1628e-01,  2.4760e-01,
-         1.2261e-01,  3.1895e-01, -5.3048e-01,  1.8111e-01,  7.4288e-01,
-         1.6704e-04,  1.1459e+00,  1.1251e+00,  7.8041e-01, -4.8126e-01,
-         1.3467e+00, -1.4445e-01,  2.8659e-01,  5.4464e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.3692e-01,  3.4762e-01,  7.3149e-02, -2.2991e-02, -2.9774e-02,
-        -5.4568e-01, -3.6261e-01, -7.6319e-01,  2.4172e-01, -2.5998e-01,
-        -4.0015e-01, -1.8170e-01, -1.8017e-01,  4.4415e-01, -1.3868e+00,
-        -6.0537e-01, -5.6873e-01, -6.0433e-01, -1.0051e+00, -2.8043e-01,
-        -7.5100e-02,  5.6049e-02, -1.8048e-01, -2.6297e-01,  2.4726e-01,
-         1.7628e-01,  1.7304e-01, -2.7591e-01,  2.5943e-01,  1.5298e+00,
-        -4.2249e-01, -6.5333e-02,  0.0000e+00, -1.1292e-01, -5.4387e-01,
-        -4.3574e-01,  1.1507e+00,  0.0000e+00,  9.4352e-01,  7.3697e-01,
-        -4.9464e-01,  1.0097e+00,  7.3701e-04,  6.3928e-01,  2.5555e-01,
-        -2.0775e-01,  5.3094e-01,  2.6114e-01,  9.1628e-01,  2.4760e-01,
-         1.2261e-01,  3.1895e-01, -5.3048e-01,  1.8111e-01,  7.4288e-01,
-         0.0000e+00,  1.1459e+00,  1.1251e+00,  7.8041e-01, -4.8126e-01,
-         1.3467e+00, -1.4445e-01,  2.8659e-01,  5.4464e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.3692e-01,  3.4762e-01,  7.3149e-02, -2.2991e-02, -2.9774e-02,
-        -5.4568e-01, -3.6261e-01, -7.6319e-01,  2.4172e-01, -2.5998e-01,
-        -4.0015e-01, -1.8170e-01, -1.8017e-01,  4.4415e-01, -1.3868e+00,
-        -6.0537e-01, -5.6873e-01, -6.0433e-01, -1.0051e+00, -2.8043e-01,
-        -7.5100e-02,  5.6049e-02, -1.8048e-01, -2.6297e-01,  2.4726e-01,
-         1.7628e-01,  1.7304e-01, -2.7591e-01,  2.5943e-01,  1.5298e+00,
-        -4.2249e-01, -6.5333e-02,  0.0000e+00, -1.1292e-01, -5.4387e-01,
-        -4.3574e-01,  1.1507e+00,  0.0000e+00,  9.4352e-01,  7.3697e-01,
-        -4.9464e-01,  1.0097e+00,  7.3701e-04,  6.3928e-01,  2.5555e-01,
-        -2.0775e-01,  5.3094e-01,  2.6114e-01,  9.1628e-01,  2.4760e-01,
-         1.2261e-01,  3.1895e-01, -5.3048e-01,  1.8111e-01,  7.4288e-01,
-         0.0000e+00,  1.1459e+00,  1.1251e+00,  7.8041e-01, -4.8126e-01,
-         1.3467e+00, -1.4445e-01,  2.8659e-01,  5.4464e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5853e-01,  3.8525e-01,  3.5848e-02, -2.6446e-02, -9.0685e-02,
-        -5.3688e-01, -4.1273e-01, -7.5595e-01,  2.5259e-01, -2.6868e-01,
-        -5.0205e-01, -1.7252e-01, -1.9034e-01,  4.5152e-01, -1.3821e+00,
-        -6.2082e-01, -5.9211e-01, -6.3704e-01, -1.0111e+00, -2.4640e-01,
-        -5.6272e-02,  6.1906e-02, -2.0459e-01, -3.5353e-01,  2.4431e-01,
-         1.6795e-01,  1.5231e-01, -2.7926e-01,  2.7038e-01,  1.5328e+00,
-        -4.3813e-01, -5.7218e-02, -2.6870e-03, -1.3538e-01, -5.2806e-01,
-        -4.0913e-01,  1.1564e+00,  0.0000e+00,  9.3936e-01,  7.3211e-01,
-        -4.6745e-01,  1.0101e+00, -3.9769e-02,  6.3343e-01,  2.5223e-01,
-        -1.0816e-01,  5.3117e-01,  1.8499e-01,  9.1631e-01,  2.5180e-01,
-         6.4771e-02,  3.1120e-01, -5.0622e-01,  1.5747e-01,  7.2205e-01,
-         1.4506e-04,  1.1412e+00,  1.1265e+00,  7.8545e-01, -4.7025e-01,
-         1.3465e+00, -1.0869e-01,  2.5313e-01,  5.0347e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4585,  0.3853,  0.0358, -0.0264, -0.0907, -0.5369, -0.4127, -0.7559,
-         0.2526, -0.2687, -0.5021, -0.1725, -0.1903,  0.4515, -1.3821, -0.6208,
-        -0.5921, -0.6370, -1.0111, -0.2464, -0.0563,  0.0619, -0.2046, -0.3535,
-         0.2443,  0.1679,  0.1523, -0.2793,  0.2704,  1.5328, -0.4381, -0.0572,
-         0.0000, -0.1354, -0.5281, -0.4091,  1.1564,  0.0000,  0.9394,  0.7321,
-        -0.4675,  1.0101, -0.0398,  0.6334,  0.2522, -0.1082,  0.5312,  0.1850,
-         0.9163,  0.2518,  0.0648,  0.3112, -0.5062,  0.1575,  0.7220,  0.0000,
-         1.1412,  1.1265,  0.7854, -0.4703,  1.3465, -0.1087,  0.2531,  0.5035],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4585,  0.3853,  0.0358, -0.0264, -0.0907, -0.5369, -0.4127, -0.7559,
-         0.2526, -0.2687, -0.5021, -0.1725, -0.1903,  0.4515, -1.3821, -0.6208,
-        -0.5921, -0.6370, -1.0111, -0.2464, -0.0563,  0.0619, -0.2046, -0.3535,
-         0.2443,  0.1679,  0.1523, -0.2793,  0.2704,  1.5328, -0.4381, -0.0572,
-         0.0000, -0.1354, -0.5281, -0.4091,  1.1564,  0.0000,  0.9394,  0.7321,
-        -0.4675,  1.0101, -0.0398,  0.6334,  0.2522, -0.1082,  0.5312,  0.1850,
-         0.9163,  0.2518,  0.0648,  0.3112, -0.5062,  0.1575,  0.7220,  0.0000,
-         1.1412,  1.1265,  0.7854, -0.4703,  1.3465, -0.1087,  0.2531,  0.5035],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9695e-01,  3.8668e-01,  3.3073e-02, -4.3990e-02, -1.1958e-01,
-        -5.2729e-01, -4.6085e-01, -7.5641e-01,  2.9098e-01, -2.9789e-01,
-        -6.0404e-01, -1.9196e-01, -1.8827e-01,  4.2683e-01, -1.3776e+00,
-        -6.3927e-01, -6.0448e-01, -6.7030e-01, -1.0144e+00, -2.3420e-01,
-        -5.8769e-02, -1.3312e-02, -2.3290e-01, -4.5635e-01,  1.8024e-01,
-         1.3665e-01,  1.3583e-01, -2.8286e-01,  2.7064e-01,  1.5337e+00,
-        -4.5571e-01, -5.2556e-02, -2.3312e-03, -1.8505e-01, -5.2214e-01,
-        -3.9583e-01,  1.1628e+00,  0.0000e+00,  9.2557e-01,  7.2819e-01,
-        -4.4290e-01,  1.0157e+00, -5.9706e-02,  6.3392e-01,  2.4110e-01,
-         2.9521e-02,  5.3333e-01,  1.3844e-01,  9.0546e-01,  2.6556e-01,
-         1.5511e-02,  2.7189e-01, -4.5555e-01,  7.7601e-02,  6.9984e-01,
-         1.2586e-04,  1.1370e+00,  1.1301e+00,  7.9383e-01, -4.7887e-01,
-         1.3456e+00, -7.7985e-02,  2.1311e-01,  4.9110e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4970,  0.3867,  0.0331, -0.0440, -0.1196, -0.5273, -0.4609, -0.7564,
-         0.2910, -0.2979, -0.6040, -0.1920, -0.1883,  0.4268, -1.3776, -0.6393,
-        -0.6045, -0.6703, -1.0144, -0.2342, -0.0588, -0.0133, -0.2329, -0.4564,
-         0.1802,  0.1367,  0.1358, -0.2829,  0.2706,  1.5337, -0.4557, -0.0526,
-         0.0000, -0.1851, -0.5221, -0.3958,  1.1628,  0.0000,  0.9256,  0.7282,
-        -0.4429,  1.0157, -0.0597,  0.6339,  0.2411,  0.0295,  0.5333,  0.1384,
-         0.9055,  0.2656,  0.0155,  0.2719, -0.4555,  0.0776,  0.6998,  0.0000,
-         1.1370,  1.1301,  0.7938, -0.4789,  1.3456, -0.0780,  0.2131,  0.4911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4970,  0.3867,  0.0331, -0.0440, -0.1196, -0.5273, -0.4609, -0.7564,
-         0.2910, -0.2979, -0.6040, -0.1920, -0.1883,  0.4268, -1.3776, -0.6393,
-        -0.6045, -0.6703, -1.0144, -0.2342, -0.0588, -0.0133, -0.2329, -0.4564,
-         0.1802,  0.1367,  0.1358, -0.2829,  0.2706,  1.5337, -0.4557, -0.0526,
-         0.0000, -0.1851, -0.5221, -0.3958,  1.1628,  0.0000,  0.9256,  0.7282,
-        -0.4429,  1.0157, -0.0597,  0.6339,  0.2411,  0.0295,  0.5333,  0.1384,
-         0.9055,  0.2656,  0.0155,  0.2719, -0.4555,  0.0776,  0.6998,  0.0000,
-         1.1370,  1.1301,  0.7938, -0.4789,  1.3456, -0.0780,  0.2131,  0.4911],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.3471e-01,  3.9068e-01,  5.0698e-02, -7.6962e-02, -1.3458e-01,
-        -5.2232e-01, -4.8205e-01, -7.5373e-01,  3.5691e-01, -3.1490e-01,
-        -6.9891e-01, -1.5452e-01, -1.6866e-01,  3.8841e-01, -1.3773e+00,
-        -6.3968e-01, -6.0274e-01, -6.9783e-01, -1.0155e+00, -1.8802e-01,
-        -4.3873e-02, -1.1034e-01, -2.6079e-01, -5.5225e-01,  9.6189e-02,
-         9.8750e-02,  1.6891e-01, -3.0405e-01,  2.7231e-01,  1.5338e+00,
-        -4.6014e-01, -3.3035e-02, -2.0206e-03, -2.6875e-01, -4.9604e-01,
-        -3.6592e-01,  1.1648e+00,  0.0000e+00,  9.0979e-01,  7.2409e-01,
-        -4.3818e-01,  1.0187e+00, -6.6932e-02,  6.2557e-01,  2.1692e-01,
-         1.4740e-01,  5.4063e-01,  9.4134e-02,  8.9168e-01,  2.8511e-01,
-        -3.7572e-02,  2.1620e-01, -4.0274e-01,  9.4125e-04,  6.9243e-01,
-         1.0909e-04,  1.1328e+00,  1.1293e+00,  8.0724e-01, -4.6717e-01,
-         1.3443e+00, -4.4625e-02,  1.3921e-01,  5.0773e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.3471e-01,  3.9068e-01,  5.0698e-02, -7.6962e-02, -1.3458e-01,
-        -5.2232e-01, -4.8205e-01, -7.5373e-01,  3.5691e-01, -3.1490e-01,
-        -6.9891e-01, -1.5452e-01, -1.6866e-01,  3.8841e-01, -1.3773e+00,
-        -6.3968e-01, -6.0274e-01, -6.9783e-01, -1.0155e+00, -1.8802e-01,
-        -4.3873e-02, -1.1034e-01, -2.6079e-01, -5.5225e-01,  9.6189e-02,
-         9.8750e-02,  1.6891e-01, -3.0405e-01,  2.7231e-01,  1.5338e+00,
-        -4.6014e-01, -3.3035e-02,  0.0000e+00, -2.6875e-01, -4.9604e-01,
-        -3.6592e-01,  1.1648e+00,  0.0000e+00,  9.0979e-01,  7.2409e-01,
-        -4.3818e-01,  1.0187e+00, -6.6932e-02,  6.2557e-01,  2.1692e-01,
-         1.4740e-01,  5.4063e-01,  9.4134e-02,  8.9168e-01,  2.8511e-01,
-        -3.7572e-02,  2.1620e-01, -4.0274e-01,  9.4125e-04,  6.9243e-01,
-         0.0000e+00,  1.1328e+00,  1.1293e+00,  8.0724e-01, -4.6717e-01,
-         1.3443e+00, -4.4625e-02,  1.3921e-01,  5.0773e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.3471e-01,  3.9068e-01,  5.0698e-02, -7.6962e-02, -1.3458e-01,
-        -5.2232e-01, -4.8205e-01, -7.5373e-01,  3.5691e-01, -3.1490e-01,
-        -6.9891e-01, -1.5452e-01, -1.6866e-01,  3.8841e-01, -1.3773e+00,
-        -6.3968e-01, -6.0274e-01, -6.9783e-01, -1.0155e+00, -1.8802e-01,
-        -4.3873e-02, -1.1034e-01, -2.6079e-01, -5.5225e-01,  9.6189e-02,
-         9.8750e-02,  1.6891e-01, -3.0405e-01,  2.7231e-01,  1.5338e+00,
-        -4.6014e-01, -3.3035e-02,  0.0000e+00, -2.6875e-01, -4.9604e-01,
-        -3.6592e-01,  1.1648e+00,  0.0000e+00,  9.0979e-01,  7.2409e-01,
-        -4.3818e-01,  1.0187e+00, -6.6932e-02,  6.2557e-01,  2.1692e-01,
-         1.4740e-01,  5.4063e-01,  9.4134e-02,  8.9168e-01,  2.8511e-01,
-        -3.7572e-02,  2.1620e-01, -4.0274e-01,  9.4125e-04,  6.9243e-01,
-         0.0000e+00,  1.1328e+00,  1.1293e+00,  8.0724e-01, -4.6717e-01,
-         1.3443e+00, -4.4625e-02,  1.3921e-01,  5.0773e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.6860e-01,  4.2000e-01,  3.7833e-03, -9.2718e-02, -1.2073e-01,
-        -5.2624e-01, -5.1463e-01, -7.4893e-01,  4.1616e-01, -3.2364e-01,
-        -7.6966e-01, -2.1930e-01, -2.1248e-01,  3.6566e-01, -1.3799e+00,
-        -6.4627e-01, -6.1324e-01, -7.1704e-01, -1.0153e+00, -1.9522e-01,
-        -2.1602e-02, -8.9539e-02, -2.6052e-01, -6.1614e-01,  7.5086e-02,
-         1.4300e-02,  1.4274e-01, -2.8074e-01,  2.4918e-01,  1.5312e+00,
-        -4.7389e-01, -4.7057e-02, -1.7498e-03, -2.4997e-01, -4.8610e-01,
-        -3.6019e-01,  1.1694e+00,  0.0000e+00,  9.1383e-01,  7.4334e-01,
-        -3.8021e-01,  1.0194e+00, -1.1672e-01,  6.2444e-01,  1.6246e-01,
-         1.7479e-01,  5.3856e-01,  2.3581e-02,  8.7764e-01,  2.8786e-01,
-        -6.7085e-02,  1.6407e-01, -3.3969e-01, -3.0659e-02,  7.1903e-01,
-         9.4467e-05,  1.1293e+00,  1.1426e+00,  8.0919e-01, -4.4248e-01,
-         1.3436e+00, -9.8711e-02,  5.9063e-02,  5.0502e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5686,  0.4200,  0.0038, -0.0927, -0.1207, -0.5262, -0.5146, -0.7489,
-         0.4162, -0.3236, -0.7697, -0.2193, -0.2125,  0.3657, -1.3799, -0.6463,
-        -0.6132, -0.7170, -1.0153, -0.1952, -0.0216, -0.0895, -0.2605, -0.6161,
-         0.0751,  0.0143,  0.1427, -0.2807,  0.2492,  1.5312, -0.4739, -0.0471,
-         0.0000, -0.2500, -0.4861, -0.3602,  1.1694,  0.0000,  0.9138,  0.7433,
-        -0.3802,  1.0194, -0.1167,  0.6244,  0.1625,  0.1748,  0.5386,  0.0236,
-         0.8776,  0.2879, -0.0671,  0.1641, -0.3397, -0.0307,  0.7190,  0.0000,
-         1.1293,  0.0000,  0.8092, -0.4425,  1.3436, -0.0987,  0.0591,  0.5050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5686,  0.4200,  0.0038, -0.0927, -0.1207, -0.5262, -0.5146, -0.7489,
-         0.4162, -0.3236, -0.7697, -0.2193, -0.2125,  0.3657, -1.3799, -0.6463,
-        -0.6132, -0.7170, -1.0153, -0.1952, -0.0216, -0.0895, -0.2605, -0.6161,
-         0.0751,  0.0143,  0.1427, -0.2807,  0.2492,  1.5312, -0.4739, -0.0471,
-         0.0000, -0.2500, -0.4861, -0.3602,  1.1694,  0.0000,  0.9138,  0.7433,
-        -0.3802,  1.0194, -0.1167,  0.6244,  0.1625,  0.1748,  0.5386,  0.0236,
-         0.8776,  0.2879, -0.0671,  0.1641, -0.3397, -0.0307,  0.7190,  0.0000,
-         1.1293,  0.0000,  0.8092, -0.4425,  1.3436, -0.0987,  0.0591,  0.5050],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.8695e-01,  5.1189e-01, -1.0347e-01, -6.3838e-02, -2.4159e-01,
-        -5.1535e-01, -5.4151e-01, -7.4287e-01,  4.5508e-01, -3.0577e-01,
-        -8.0039e-01, -3.0833e-01, -2.9383e-01,  3.1051e-01, -1.3850e+00,
-        -6.4033e-01, -6.4625e-01, -7.1002e-01, -1.0137e+00, -2.1913e-01,
-         4.5566e-02,  1.1794e-01, -1.9909e-01, -6.6377e-01,  1.5014e-01,
-        -5.2223e-02,  8.8079e-02, -2.2341e-01,  2.5711e-01,  1.5286e+00,
-        -5.2051e-01, -3.6771e-02, -1.5138e-03, -6.0769e-02, -4.7096e-01,
-        -3.4831e-01,  1.1756e+00,  0.0000e+00,  9.2941e-01,  7.6742e-01,
-        -2.2049e-01,  1.0226e+00, -1.7491e-01,  6.1707e-01,  1.9571e-01,
-         2.0024e-01,  5.3333e-01, -4.6268e-02,  8.6270e-01,  2.6868e-01,
-        -4.8197e-02,  1.8411e-01, -2.6200e-01,  7.5846e-02,  7.4545e-01,
-         8.1730e-05,  1.1260e+00,  1.1552e-02,  8.0656e-01, -3.8243e-01,
-         1.3437e+00, -1.3187e-01,  2.8380e-02,  4.6475e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5870,  0.5119, -0.1035, -0.0638, -0.2416, -0.5153, -0.5415, -0.7429,
-         0.4551, -0.3058, -0.8004, -0.3083, -0.2938,  0.3105, -1.3850, -0.6403,
-        -0.6463, -0.7100, -1.0137, -0.2191,  0.0456,  0.1179, -0.1991, -0.6638,
-         0.1501, -0.0522,  0.0881, -0.2234,  0.2571,  1.5286, -0.5205, -0.0368,
-         0.0000, -0.0608, -0.4710, -0.3483,  1.1756,  0.0000,  0.9294,  0.7674,
-        -0.2205,  1.0226, -0.1749,  0.6171,  0.1957,  0.2002,  0.5333, -0.0463,
-         0.8627,  0.2687, -0.0482,  0.1841, -0.2620,  0.0758,  0.7455,  0.0000,
-         1.1260,  0.0000,  0.8066, -0.3824,  1.3437, -0.1319,  0.0284,  0.4648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5870,  0.5119, -0.1035, -0.0638, -0.2416, -0.5153, -0.5415, -0.7429,
-         0.4551, -0.3058, -0.8004, -0.3083, -0.2938,  0.3105, -1.3850, -0.6403,
-        -0.6463, -0.7100, -1.0137, -0.2191,  0.0456,  0.1179, -0.1991, -0.6638,
-         0.1501, -0.0522,  0.0881, -0.2234,  0.2571,  1.5286, -0.5205, -0.0368,
-         0.0000, -0.0608, -0.4710, -0.3483,  1.1756,  0.0000,  0.9294,  0.7674,
-        -0.2205,  1.0226, -0.1749,  0.6171,  0.1957,  0.2002,  0.5333, -0.0463,
-         0.8627,  0.2687, -0.0482,  0.1841, -0.2620,  0.0758,  0.7455,  0.0000,
-         1.1260,  0.0000,  0.8066, -0.3824,  1.3437, -0.1319,  0.0284,  0.4648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.2043e-01,  5.7769e-01, -8.7914e-02,  8.1186e-03, -2.9066e-01,
-        -4.7513e-01, -5.7123e-01, -7.3882e-01,  4.9853e-01, -3.1102e-01,
-        -8.1509e-01, -3.3172e-01, -2.2306e-01,  2.3820e-01, -1.3906e+00,
-        -6.3904e-01, -6.8925e-01, -6.8894e-01, -1.0131e+00, -2.2469e-01,
-         6.7976e-02,  2.4498e-01, -2.1487e-01, -6.9443e-01,  1.7633e-01,
-        -8.9693e-02,  7.7979e-02, -1.8881e-01,  2.7238e-01,  1.5232e+00,
-        -5.2430e-01,  2.4418e-02, -1.3086e-03,  3.6408e-02, -4.6946e-01,
-        -3.1815e-01,  1.1821e+00,  0.0000e+00,  9.5164e-01,  7.8662e-01,
-        -1.2208e-01,  1.0336e+00, -1.5476e-01,  6.1636e-01,  1.9148e-01,
-         1.9525e-01,  5.2666e-01, -3.8751e-02,  8.3619e-01,  2.3729e-01,
-        -3.1254e-02,  2.1893e-01, -2.2052e-01,  1.3929e-01,  7.7025e-01,
-         7.0648e-05,  1.1222e+00,  9.9856e-03,  8.0597e-01, -3.9623e-01,
-         1.3461e+00, -1.1712e-01,  1.8521e-02,  3.9878e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6204,  0.5777, -0.0879,  0.0081, -0.2907, -0.4751, -0.5712, -0.7388,
-         0.4985, -0.3110, -0.8151, -0.3317, -0.2231,  0.2382, -1.3906, -0.6390,
-        -0.6892, -0.6889, -1.0131, -0.2247,  0.0680,  0.2450, -0.2149, -0.6944,
-         0.1763, -0.0897,  0.0780, -0.1888,  0.2724,  1.5232, -0.5243,  0.0244,
-         0.0000,  0.0364, -0.4695, -0.3181,  1.1821,  0.0000,  0.9516,  0.7866,
-        -0.1221,  1.0336, -0.1548,  0.6164,  0.1915,  0.1953,  0.5267, -0.0388,
-         0.8362,  0.2373, -0.0313,  0.2189, -0.2205,  0.1393,  0.7703,  0.0000,
-         1.1222,  0.0000,  0.8060, -0.3962,  1.3461, -0.1171,  0.0185,  0.3988],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6204,  0.5777, -0.0879,  0.0081, -0.2907, -0.4751, -0.5712, -0.7388,
-         0.4985, -0.3110, -0.8151, -0.3317, -0.2231,  0.2382, -1.3906, -0.6390,
-        -0.6892, -0.6889, -1.0131, -0.2247,  0.0680,  0.2450, -0.2149, -0.6944,
-         0.1763, -0.0897,  0.0780, -0.1888,  0.2724,  1.5232, -0.5243,  0.0244,
-         0.0000,  0.0364, -0.4695, -0.3181,  1.1821,  0.0000,  0.9516,  0.7866,
-        -0.1221,  1.0336, -0.1548,  0.6164,  0.1915,  0.1953,  0.5267, -0.0388,
-         0.8362,  0.2373, -0.0313,  0.2189, -0.2205,  0.1393,  0.7703,  0.0000,
-         1.1222,  0.0000,  0.8060, -0.3962,  1.3461, -0.1171,  0.0185,  0.3988],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5761e-01,  5.8153e-01,  1.7425e-02,  1.3276e-01, -2.4271e-01,
-        -4.0866e-01, -5.9760e-01, -7.3592e-01,  5.0567e-01, -3.3411e-01,
-        -8.3008e-01, -2.9983e-01,  8.6295e-03,  2.9145e-01, -1.3984e+00,
-        -6.4478e-01, -7.2627e-01, -6.6784e-01, -1.0158e+00, -2.0468e-01,
-         1.1341e-02,  1.8765e-01, -2.9849e-01, -7.0205e-01,  7.2261e-02,
-        -7.8381e-02,  1.3910e-01, -1.8352e-01,  1.8774e-01,  1.5193e+00,
-        -4.8954e-01,  1.1920e-01, -1.1301e-03, -1.5725e-02, -4.6944e-01,
-        -3.0583e-01,  1.1898e+00,  0.0000e+00,  1.0006e+00,  8.0854e-01,
-        -8.0463e-02,  1.0350e+00, -9.9988e-02,  6.5440e-01,  1.2036e-01,
-         1.4133e-01,  5.1159e-01,  7.5100e-02,  8.1171e-01,  1.0939e-01,
-        -4.5023e-02,  1.1455e-01, -2.4772e-01,  2.3981e-02,  7.9062e-01,
-         6.1015e-05,  1.1189e+00,  8.6240e-03,  7.9341e-01, -4.5599e-01,
-         1.3490e+00, -6.9826e-02,  3.2040e-02,  3.5432e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6576,  0.5815,  0.0174,  0.1328, -0.2427, -0.4087, -0.5976, -0.7359,
-         0.5057, -0.3341, -0.8301, -0.2998,  0.0086,  0.2915, -1.3984, -0.6448,
-        -0.7263, -0.6678, -1.0158, -0.2047,  0.0113,  0.1877, -0.2985, -0.7020,
-         0.0723, -0.0784,  0.1391, -0.1835,  0.1877,  1.5193, -0.4895,  0.1192,
-         0.0000, -0.0157, -0.4694, -0.3058,  1.1898,  0.0000,  1.0006,  0.8085,
-        -0.0805,  1.0350, -0.1000,  0.6544,  0.1204,  0.1413,  0.5116,  0.0751,
-         0.8117,  0.1094, -0.0450,  0.1146, -0.2477,  0.0240,  0.7906,  0.0000,
-         1.1189,  0.0000,  0.7934, -0.4560,  1.3490, -0.0698,  0.0320,  0.3543],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6576,  0.5815,  0.0174,  0.1328, -0.2427, -0.4087, -0.5976, -0.7359,
-         0.5057, -0.3341, -0.8301, -0.2998,  0.0086,  0.2915, -1.3984, -0.6448,
-        -0.7263, -0.6678, -1.0158, -0.2047,  0.0113,  0.1877, -0.2985, -0.7020,
-         0.0723, -0.0784,  0.1391, -0.1835,  0.1877,  1.5193, -0.4895,  0.1192,
-         0.0000, -0.0157, -0.4694, -0.3058,  1.1898,  0.0000,  1.0006,  0.8085,
-        -0.0805,  1.0350, -0.1000,  0.6544,  0.1204,  0.1413,  0.5116,  0.0751,
-         0.8117,  0.1094, -0.0450,  0.1146, -0.2477,  0.0240,  0.7906,  0.0000,
-         1.1189,  0.0000,  0.7934, -0.4560,  1.3490, -0.0698,  0.0320,  0.3543],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5791e-01,  5.9043e-01, -2.2695e-02,  2.2607e-01, -6.7192e-02,
-        -3.6700e-01, -6.1101e-01, -7.3357e-01,  4.8594e-01, -3.4695e-01,
-        -8.5088e-01, -3.1682e-01,  9.3813e-02,  3.0281e-01, -1.4076e+00,
-        -6.4638e-01, -7.4318e-01, -6.4954e-01, -1.0198e+00, -2.2862e-01,
-        -6.6806e-02,  1.6071e-01, -3.5988e-01, -6.6233e-01, -3.4655e-02,
-        -1.0185e-01,  1.6433e-01, -1.5656e-01,  3.7063e-02,  1.5152e+00,
-        -4.3781e-01,  1.6938e-01, -9.7521e-04, -2.9853e-02, -4.8813e-01,
-        -3.4032e-01,  1.1980e+00,  0.0000e+00,  1.0326e+00,  8.1171e-01,
-        -2.6099e-03,  1.0324e+00, -8.1139e-02,  6.8462e-01,  1.1865e-02,
-         1.7951e-01,  4.9846e-01,  1.3710e-01,  7.9437e-01, -1.0549e-02,
-        -6.6282e-02, -8.2287e-02, -2.8151e-01, -4.7410e-02,  7.8354e-01,
-         5.2650e-05,  1.1149e+00,  7.4418e-03,  7.7509e-01, -4.7077e-01,
-         1.3503e+00, -9.8224e-03,  5.0041e-02,  2.5901e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6579,  0.5904, -0.0227,  0.2261, -0.0672, -0.3670, -0.6110, -0.7336,
-         0.4859, -0.3470, -0.8509, -0.3168,  0.0938,  0.3028, -1.4076, -0.6464,
-        -0.7432, -0.6495, -1.0198, -0.2286, -0.0668,  0.1607, -0.3599, -0.6623,
-        -0.0347, -0.1018,  0.1643, -0.1566,  0.0371,  1.5152, -0.4378,  0.1694,
-         0.0000, -0.0299, -0.4881, -0.3403,  1.1980,  0.0000,  1.0326,  0.8117,
-        -0.0026,  1.0324, -0.0811,  0.6846,  0.0119,  0.1795,  0.4985,  0.1371,
-         0.7944, -0.0105, -0.0663, -0.0823, -0.2815, -0.0474,  0.7835,  0.0000,
-         1.1149,  0.0000,  0.7751, -0.4708,  1.3503, -0.0098,  0.0500,  0.2590],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6579,  0.5904, -0.0227,  0.2261, -0.0672, -0.3670, -0.6110, -0.7336,
-         0.4859, -0.3470, -0.8509, -0.3168,  0.0938,  0.3028, -1.4076, -0.6464,
-        -0.7432, -0.6495, -1.0198, -0.2286, -0.0668,  0.1607, -0.3599, -0.6623,
-        -0.0347, -0.1018,  0.1643, -0.1566,  0.0371,  1.5152, -0.4378,  0.1694,
-         0.0000, -0.0299, -0.4881, -0.3403,  1.1980,  0.0000,  1.0326,  0.8117,
-        -0.0026,  1.0324, -0.0811,  0.6846,  0.0119,  0.1795,  0.4985,  0.1371,
-         0.7944, -0.0105, -0.0663, -0.0823, -0.2815, -0.0474,  0.7835,  0.0000,
-         1.1149,  0.0000,  0.7751, -0.4708,  1.3503, -0.0098,  0.0500,  0.2590],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5276e-01,  6.1241e-01, -1.1744e-01,  2.8536e-01,  7.7322e-02,
-        -3.4502e-01, -6.1922e-01, -7.2868e-01,  4.6212e-01, -3.6640e-01,
-        -8.7519e-01, -3.3555e-01,  1.2206e-01,  2.7276e-01, -1.4147e+00,
-        -6.3029e-01, -7.4412e-01, -6.4414e-01, -1.0199e+00, -2.4335e-01,
-        -1.3920e-01,  1.5055e-01, -3.9791e-01, -6.2108e-01, -1.2426e-01,
-        -1.3863e-01,  1.8760e-01, -1.3276e-01, -7.3575e-02,  1.5109e+00,
-        -3.9482e-01,  1.9435e-01, -8.4082e-04, -5.3609e-02, -4.8687e-01,
-        -3.5447e-01,  1.2015e+00,  0.0000e+00,  1.0460e+00,  8.0637e-01,
-         6.3426e-02,  1.0320e+00, -1.0275e-01,  7.0502e-01, -2.1281e-02,
-         2.2734e-01,  4.7988e-01,  1.5777e-01,  7.9043e-01, -9.4417e-02,
-        -8.0935e-02, -2.5989e-01, -2.9725e-01, -1.7429e-01,  7.7661e-01,
-         4.5395e-05,  1.1106e+00,  6.4162e-03,  7.5577e-01, -4.9298e-01,
-         1.3510e+00,  3.2247e-02,  5.6908e-02,  1.6271e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6528,  0.6124, -0.1174,  0.2854,  0.0773, -0.3450, -0.6192, -0.7287,
-         0.4621, -0.3664, -0.8752, -0.3356,  0.1221,  0.2728, -1.4147, -0.6303,
-        -0.7441, -0.6441, -1.0199, -0.2434, -0.1392,  0.1506, -0.3979, -0.6211,
-        -0.1243, -0.1386,  0.1876, -0.1328, -0.0736,  1.5109, -0.3948,  0.1944,
-         0.0000, -0.0536, -0.4869, -0.3545,  1.2015,  0.0000,  1.0460,  0.8064,
-         0.0634,  1.0320, -0.1027,  0.7050, -0.0213,  0.2273,  0.4799,  0.1578,
-         0.7904, -0.0944, -0.0809, -0.2599, -0.2972, -0.1743,  0.7766,  0.0000,
-         1.1106,  0.0000,  0.7558, -0.4930,  1.3510,  0.0322,  0.0569,  0.1627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6528,  0.6124, -0.1174,  0.2854,  0.0773, -0.3450, -0.6192, -0.7287,
-         0.4621, -0.3664, -0.8752, -0.3356,  0.1221,  0.2728, -1.4147, -0.6303,
-        -0.7441, -0.6441, -1.0199, -0.2434, -0.1392,  0.1506, -0.3979, -0.6211,
-        -0.1243, -0.1386,  0.1876, -0.1328, -0.0736,  1.5109, -0.3948,  0.1944,
-         0.0000, -0.0536, -0.4869, -0.3545,  1.2015,  0.0000,  1.0460,  0.8064,
-         0.0634,  1.0320, -0.1027,  0.7050, -0.0213,  0.2273,  0.4799,  0.1578,
-         0.7904, -0.0944, -0.0809, -0.2599, -0.2972, -0.1743,  0.7766,  0.0000,
-         1.1106,  0.0000,  0.7558, -0.4930,  1.3510,  0.0322,  0.0569,  0.1627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5549e-01,  6.4026e-01, -1.5770e-01,  3.4894e-01,  1.6809e-01,
-        -3.0708e-01, -6.1959e-01, -7.1941e-01,  4.2907e-01, -3.6012e-01,
-        -8.9371e-01, -3.5282e-01,  1.2659e-01,  2.1759e-01, -1.4199e+00,
-        -6.0804e-01, -7.4914e-01, -6.3329e-01, -1.0186e+00, -2.4863e-01,
-        -1.6980e-01,  1.9601e-01, -4.1304e-01, -6.0764e-01, -1.5422e-01,
-        -1.6428e-01,  2.2579e-01, -8.1978e-02, -1.7287e-01,  1.5077e+00,
-        -3.7021e-01,  2.2123e-01, -7.2436e-04, -1.9138e-02, -4.7626e-01,
-        -3.5783e-01,  1.1999e+00,  0.0000e+00,  1.0484e+00,  7.9066e-01,
-         1.5014e-01,  1.0316e+00, -1.0207e-01,  7.2148e-01, -4.9782e-02,
-         2.8153e-01,  4.5881e-01,  1.6924e-01,  7.9796e-01, -1.5615e-01,
-        -7.0846e-02, -3.7062e-01, -3.3120e-01, -2.0414e-01,  7.5795e-01,
-         3.9107e-05,  1.1057e+00,  5.5276e-03,  7.4733e-01, -4.9619e-01,
-         1.3497e+00,  9.5207e-02,  4.9101e-02,  5.3319e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6555,  0.6403, -0.1577,  0.3489,  0.1681, -0.3071, -0.6196, -0.7194,
-         0.4291, -0.3601, -0.8937, -0.3528,  0.1266,  0.2176, -1.4199, -0.6080,
-        -0.7491, -0.6333, -1.0186, -0.2486, -0.1698,  0.1960, -0.4130, -0.6076,
-        -0.1542, -0.1643,  0.2258, -0.0820, -0.1729,  1.5077, -0.3702,  0.2212,
-         0.0000, -0.0191, -0.4763, -0.3578,  1.1999,  0.0000,  1.0484,  0.7907,
-         0.1501,  1.0316, -0.1021,  0.7215, -0.0498,  0.2815,  0.4588,  0.1692,
-         0.7980, -0.1561, -0.0708, -0.3706, -0.3312, -0.2041,  0.7579,  0.0000,
-         1.1057,  0.0000,  0.7473, -0.4962,  1.3497,  0.0952,  0.0491,  0.0533],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6555,  0.6403, -0.1577,  0.3489,  0.1681, -0.3071, -0.6196, -0.7194,
-         0.4291, -0.3601, -0.8937, -0.3528,  0.1266,  0.2176, -1.4199, -0.6080,
-        -0.7491, -0.6333, -1.0186, -0.2486, -0.1698,  0.1960, -0.4130, -0.6076,
-        -0.1542, -0.1643,  0.2258, -0.0820, -0.1729,  1.5077, -0.3702,  0.2212,
-         0.0000, -0.0191, -0.4763, -0.3578,  1.1999,  0.0000,  1.0484,  0.7907,
-         0.1501,  1.0316, -0.1021,  0.7215, -0.0498,  0.2815,  0.4588,  0.1692,
-         0.7980, -0.1561, -0.0708, -0.3706, -0.3312, -0.2041,  0.7579,  0.0000,
-         1.1057,  0.0000,  0.7473, -0.4962,  1.3497,  0.0952,  0.0491,  0.0533],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5394e-01,  6.6149e-01, -1.6717e-01,  3.9510e-01,  3.1591e-01,
-        -2.7043e-01, -6.1437e-01, -7.0739e-01,  3.9745e-01, -3.3005e-01,
-        -8.9599e-01, -3.6284e-01,  1.5788e-01,  1.2871e-01, -1.4203e+00,
-        -5.8095e-01, -7.6507e-01, -6.0874e-01, -1.0172e+00, -2.4270e-01,
-        -1.5161e-01,  2.5262e-01, -3.9610e-01, -6.0651e-01, -1.0586e-01,
-        -1.6273e-01,  2.6542e-01, -1.1012e-02, -2.6522e-01,  1.5065e+00,
-        -4.0852e-01,  2.4853e-01, -6.2354e-04,  5.7846e-02, -4.4987e-01,
-        -3.4467e-01,  1.1922e+00,  0.0000e+00,  1.0315e+00,  7.6163e-01,
-         2.2823e-01,  1.0298e+00, -9.6764e-02,  7.2481e-01, -1.1064e-01,
-         2.9725e-01,  4.3652e-01,  1.6715e-01,  8.0285e-01, -2.5267e-01,
-        -1.2949e-02, -4.0625e-01, -3.8560e-01, -1.6056e-01,  7.2701e-01,
-         3.3664e-05,  1.0993e+00,  4.7582e-03,  7.4120e-01, -4.9245e-01,
-         1.3472e+00,  1.4747e-01,  6.7337e-02, -4.8052e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6539,  0.6615, -0.1672,  0.3951,  0.3159, -0.2704, -0.6144, -0.7074,
-         0.3974, -0.3300, -0.8960, -0.3628,  0.1579,  0.1287, -1.4203, -0.5809,
-        -0.7651, -0.6087, -1.0172, -0.2427, -0.1516,  0.2526, -0.3961, -0.6065,
-        -0.1059, -0.1627,  0.2654, -0.0110, -0.2652,  1.5065, -0.4085,  0.2485,
-         0.0000,  0.0578, -0.4499, -0.3447,  1.1922,  0.0000,  1.0315,  0.7616,
-         0.2282,  1.0298, -0.0968,  0.7248, -0.1106,  0.2972,  0.4365,  0.1671,
-         0.8029, -0.2527, -0.0129, -0.4063, -0.3856, -0.1606,  0.7270,  0.0000,
-         1.0993,  0.0000,  0.7412, -0.4925,  1.3472,  0.1475,  0.0673, -0.0481],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6539,  0.6615, -0.1672,  0.3951,  0.3159, -0.2704, -0.6144, -0.7074,
-         0.3974, -0.3300, -0.8960, -0.3628,  0.1579,  0.1287, -1.4203, -0.5809,
-        -0.7651, -0.6087, -1.0172, -0.2427, -0.1516,  0.2526, -0.3961, -0.6065,
-        -0.1059, -0.1627,  0.2654, -0.0110, -0.2652,  1.5065, -0.4085,  0.2485,
-         0.0000,  0.0578, -0.4499, -0.3447,  1.1922,  0.0000,  1.0315,  0.7616,
-         0.2282,  1.0298, -0.0968,  0.7248, -0.1106,  0.2972,  0.4365,  0.1671,
-         0.8029, -0.2527, -0.0129, -0.4063, -0.3856, -0.1606,  0.7270,  0.0000,
-         1.0993,  0.0000,  0.7412, -0.4925,  1.3472,  0.1475,  0.0673, -0.0481],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.5679e-01,  6.6606e-01, -2.0509e-01,  4.5879e-01,  4.4541e-01,
-        -2.1834e-01, -6.0882e-01, -6.9886e-01,  3.8084e-01, -3.1058e-01,
-        -8.9759e-01, -3.4203e-01,  1.9194e-01,  6.1897e-02, -1.4182e+00,
-        -5.4835e-01, -7.8447e-01, -5.8274e-01, -1.0150e+00, -2.3620e-01,
-        -1.2336e-01,  2.7482e-01, -3.7972e-01, -6.1043e-01, -5.8284e-02,
-        -1.5507e-01,  3.2312e-01,  5.6568e-02, -3.4358e-01,  1.5032e+00,
-        -4.4974e-01,  2.6548e-01, -5.3634e-04,  9.2124e-02, -4.1867e-01,
-        -3.1881e-01,  1.1841e+00,  0.0000e+00,  1.0178e+00,  7.3072e-01,
-         2.7513e-01,  1.0291e+00, -1.1836e-01,  7.2450e-01, -1.6223e-01,
-         2.9663e-01,  4.1688e-01,  1.9078e-01,  8.0279e-01, -3.3273e-01,
-         5.7784e-02, -4.1537e-01, -4.2721e-01, -1.2661e-01,  6.9766e-01,
-         2.8956e-05,  1.0930e+00,  4.0928e-03,  7.2737e-01, -4.8216e-01,
-         1.3432e+00,  1.8723e-01,  9.9279e-02, -1.4648e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6568,  0.6661, -0.2051,  0.4588,  0.4454, -0.2183, -0.6088, -0.6989,
-         0.3808, -0.3106, -0.8976, -0.3420,  0.1919,  0.0619, -1.4182, -0.5484,
-        -0.7845, -0.5827, -1.0150, -0.2362, -0.1234,  0.2748, -0.3797, -0.6104,
-        -0.0583, -0.1551,  0.3231,  0.0566, -0.3436,  1.5032, -0.4497,  0.2655,
-         0.0000,  0.0921, -0.4187, -0.3188,  1.1841,  0.0000,  1.0178,  0.7307,
-         0.2751,  1.0291, -0.1184,  0.7245, -0.1622,  0.2966,  0.4169,  0.1908,
-         0.8028, -0.3327,  0.0578, -0.4154, -0.4272, -0.1266,  0.6977,  0.0000,
-         1.0930,  0.0000,  0.7274, -0.4822,  1.3432,  0.1872,  0.0993, -0.1465],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6568,  0.6661, -0.2051,  0.4588,  0.4454, -0.2183, -0.6088, -0.6989,
-         0.3808, -0.3106, -0.8976, -0.3420,  0.1919,  0.0619, -1.4182, -0.5484,
-        -0.7845, -0.5827, -1.0150, -0.2362, -0.1234,  0.2748, -0.3797, -0.6104,
-        -0.0583, -0.1551,  0.3231,  0.0566, -0.3436,  1.5032, -0.4497,  0.2655,
-         0.0000,  0.0921, -0.4187, -0.3188,  1.1841,  0.0000,  1.0178,  0.7307,
-         0.2751,  1.0291, -0.1184,  0.7245, -0.1622,  0.2966,  0.4169,  0.1908,
-         0.8028, -0.3327,  0.0578, -0.4154, -0.4272, -0.1266,  0.6977,  0.0000,
-         1.0930,  0.0000,  0.7274, -0.4822,  1.3432,  0.1872,  0.0993, -0.1465],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.6211e-01,  6.5763e-01, -1.9453e-01,  4.5795e-01,  5.3745e-01,
-        -1.6628e-01, -6.0263e-01, -6.8882e-01,  3.4290e-01, -2.9219e-01,
-        -8.9423e-01, -2.9939e-01,  2.7980e-01,  4.7185e-02, -1.4150e+00,
-        -5.2399e-01, -7.8968e-01, -5.6064e-01, -1.0131e+00, -2.2467e-01,
-        -1.1156e-01,  1.8769e-01, -3.8474e-01, -6.1338e-01, -9.8566e-02,
-        -1.5622e-01,  3.5904e-01,  5.8445e-02, -4.3411e-01,  1.5020e+00,
-        -4.6262e-01,  2.8525e-01, -4.6099e-04,  2.7604e-02, -3.9669e-01,
-        -2.8849e-01,  1.1730e+00,  0.0000e+00,  9.9839e-01,  7.1344e-01,
-         2.5722e-01,  1.0200e+00, -1.0944e-01,  7.1377e-01, -1.5735e-01,
-         2.7550e-01,  4.0078e-01,  2.2695e-01,  8.1072e-01, -3.6981e-01,
-         1.0578e-01, -4.6131e-01, -4.4949e-01, -1.5753e-01,  6.8657e-01,
-         2.4888e-05,  1.0876e+00,  3.5178e-03,  7.2118e-01, -4.8478e-01,
-         1.3372e+00,  1.9023e-01,  1.2497e-01, -2.4141e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6621,  0.6576, -0.1945,  0.4580,  0.5375, -0.1663, -0.6026, -0.6888,
-         0.3429, -0.2922, -0.8942, -0.2994,  0.2798,  0.0472, -1.4150, -0.5240,
-        -0.7897, -0.5606, -1.0131, -0.2247, -0.1116,  0.1877, -0.3847, -0.6134,
-        -0.0986, -0.1562,  0.3590,  0.0584, -0.4341,  1.5020, -0.4626,  0.2853,
-         0.0000,  0.0276, -0.3967, -0.2885,  1.1730,  0.0000,  0.9984,  0.7134,
-         0.2572,  1.0200, -0.1094,  0.7138, -0.1574,  0.2755,  0.4008,  0.2269,
-         0.8107, -0.3698,  0.1058, -0.4613, -0.4495, -0.1575,  0.6866,  0.0000,
-         1.0876,  0.0000,  0.7212, -0.4848,  1.3372,  0.1902,  0.1250, -0.2414],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6621,  0.6576, -0.1945,  0.4580,  0.5375, -0.1663, -0.6026, -0.6888,
-         0.3429, -0.2922, -0.8942, -0.2994,  0.2798,  0.0472, -1.4150, -0.5240,
-        -0.7897, -0.5606, -1.0131, -0.2247, -0.1116,  0.1877, -0.3847, -0.6134,
-        -0.0986, -0.1562,  0.3590,  0.0584, -0.4341,  1.5020, -0.4626,  0.2853,
-         0.0000,  0.0276, -0.3967, -0.2885,  1.1730,  0.0000,  0.9984,  0.7134,
-         0.2572,  1.0200, -0.1094,  0.7138, -0.1574,  0.2755,  0.4008,  0.2269,
-         0.8107, -0.3698,  0.1058, -0.4613, -0.4495, -0.1575,  0.6866,  0.0000,
-         1.0876,  0.0000,  0.7212, -0.4848,  1.3372,  0.1902,  0.1250, -0.2414],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.6349e-01,  6.4107e-01, -2.3389e-01,  4.5054e-01,  5.6821e-01,
-        -1.3208e-01, -5.9898e-01, -6.8672e-01,  2.9663e-01, -2.9051e-01,
-        -8.8753e-01, -2.6570e-01,  3.3972e-01,  4.8943e-02, -1.4117e+00,
-        -5.0666e-01, -7.8431e-01, -5.4074e-01, -1.0111e+00, -1.9586e-01,
-        -7.4469e-02,  8.8066e-02, -3.8997e-01, -5.8413e-01, -1.1882e-01,
-        -1.6900e-01,  3.7506e-01,  4.5228e-02, -4.7427e-01,  1.5011e+00,
-        -4.8375e-01,  3.0744e-01, -3.9594e-04, -3.4891e-02, -3.6805e-01,
-        -2.4369e-01,  1.1643e+00,  0.0000e+00,  9.8065e-01,  6.9848e-01,
-         2.3319e-01,  1.0101e+00, -1.5108e-01,  7.0082e-01, -1.2771e-01,
-         2.4815e-01,  3.7265e-01,  2.8436e-01,  8.2575e-01, -3.9981e-01,
-         1.5613e-01, -4.8922e-01, -4.7689e-01, -1.9987e-01,  6.7648e-01,
-         2.1376e-05,  1.0839e+00,  3.0214e-03,  7.0114e-01, -4.5708e-01,
-         1.3324e+00,  1.7674e-01,  1.3306e-01, -3.5182e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6635,  0.6411, -0.2339,  0.4505,  0.5682, -0.1321, -0.5990, -0.6867,
-         0.2966, -0.2905, -0.8875, -0.2657,  0.3397,  0.0489, -1.4117, -0.5067,
-        -0.7843, -0.5407, -1.0111, -0.1959, -0.0745,  0.0881, -0.3900, -0.5841,
-        -0.1188, -0.1690,  0.3751,  0.0452, -0.4743,  1.5011, -0.4837,  0.3074,
-         0.0000, -0.0349, -0.3681, -0.2437,  1.1643,  0.0000,  0.9807,  0.6985,
-         0.2332,  1.0101, -0.1511,  0.7008, -0.1277,  0.2481,  0.3727,  0.2844,
-         0.8257, -0.3998,  0.1561, -0.4892, -0.4769, -0.1999,  0.6765,  0.0000,
-         1.0839,  0.0000,  0.7011, -0.4571,  1.3324,  0.1767,  0.1331, -0.3518],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6635,  0.6411, -0.2339,  0.4505,  0.5682, -0.1321, -0.5990, -0.6867,
-         0.2966, -0.2905, -0.8875, -0.2657,  0.3397,  0.0489, -1.4117, -0.5067,
-        -0.7843, -0.5407, -1.0111, -0.1959, -0.0745,  0.0881, -0.3900, -0.5841,
-        -0.1188, -0.1690,  0.3751,  0.0452, -0.4743,  1.5011, -0.4837,  0.3074,
-         0.0000, -0.0349, -0.3681, -0.2437,  1.1643,  0.0000,  0.9807,  0.6985,
-         0.2332,  1.0101, -0.1511,  0.7008, -0.1277,  0.2481,  0.3727,  0.2844,
-         0.8257, -0.3998,  0.1561, -0.4892, -0.4769, -0.1999,  0.6765,  0.0000,
-         1.0839,  0.0000,  0.7011, -0.4571,  1.3324,  0.1767,  0.1331, -0.3518],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.4914e-01,  6.4136e-01, -2.8685e-01,  4.2981e-01,  5.6004e-01,
-        -9.2208e-02, -5.9756e-01, -6.7507e-01,  2.3357e-01, -2.6646e-01,
-        -8.7867e-01, -2.8872e-01,  3.0000e-01,  3.4929e-03, -1.4068e+00,
-        -4.8711e-01, -7.6802e-01, -5.1273e-01, -1.0099e+00, -2.0459e-01,
-        -3.0585e-02,  1.0186e-01, -3.6468e-01, -5.2879e-01, -8.4328e-02,
-        -1.8924e-01,  3.2117e-01,  5.8324e-02, -4.7095e-01,  1.5046e+00,
-        -4.9192e-01,  2.7044e-01, -3.3983e-04, -6.0518e-03, -3.7204e-01,
-        -2.1853e-01,  1.1504e+00,  0.0000e+00,  9.3357e-01,  6.5178e-01,
-         2.5520e-01,  9.9338e-01, -1.3638e-01,  6.7179e-01, -1.2109e-01,
-         2.3878e-01,  3.3854e-01,  2.7951e-01,  8.4611e-01, -4.0436e-01,
-         2.2040e-01, -4.8488e-01, -5.0362e-01, -1.1984e-01,  6.4257e-01,
-         1.8347e-05,  1.0797e+00,  2.5932e-03,  6.7367e-01, -4.1221e-01,
-         1.3279e+00,  1.0727e-01,  1.3902e-01, -4.3714e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6491,  0.6414, -0.2869,  0.4298,  0.5600, -0.0922,  0.0000, -0.6751,
-         0.2336, -0.2665, -0.8787, -0.2887,  0.3000,  0.0035, -1.4068, -0.4871,
-        -0.7680, -0.5127, -1.0099, -0.2046, -0.0306,  0.1019, -0.3647, -0.5288,
-        -0.0843, -0.1892,  0.3212,  0.0583, -0.4709,  1.5046, -0.4919,  0.2704,
-         0.0000, -0.0061, -0.3720, -0.2185,  1.1504,  0.0000,  0.9336,  0.6518,
-         0.2552,  0.9934, -0.1364,  0.6718, -0.1211,  0.2388,  0.3385,  0.2795,
-         0.8461, -0.4044,  0.2204, -0.4849, -0.5036, -0.1198,  0.6426,  0.0000,
-         1.0797,  0.0000,  0.6737, -0.4122,  1.3279,  0.1073,  0.1390, -0.4371],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6491,  0.6414, -0.2869,  0.4298,  0.5600, -0.0922,  0.0000, -0.6751,
-         0.2336, -0.2665, -0.8787, -0.2887,  0.3000,  0.0035, -1.4068, -0.4871,
-        -0.7680, -0.5127, -1.0099, -0.2046, -0.0306,  0.1019, -0.3647, -0.5288,
-        -0.0843, -0.1892,  0.3212,  0.0583, -0.4709,  1.5046, -0.4919,  0.2704,
-         0.0000, -0.0061, -0.3720, -0.2185,  1.1504,  0.0000,  0.9336,  0.6518,
-         0.2552,  0.9934, -0.1364,  0.6718, -0.1211,  0.2388,  0.3385,  0.2795,
-         0.8461, -0.4044,  0.2204, -0.4849, -0.5036, -0.1198,  0.6426,  0.0000,
-         1.0797,  0.0000,  0.6737, -0.4122,  1.3279,  0.1073,  0.1390, -0.4371],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.4252e-01,  6.3402e-01, -2.8374e-01,  3.7684e-01,  5.3533e-01,
-        -3.9781e-02,  1.2158e-03, -6.6409e-01,  1.6676e-01, -2.5957e-01,
-        -8.7363e-01, -2.6339e-01,  2.4215e-01, -4.3323e-03, -1.4038e+00,
-        -4.9160e-01, -7.5233e-01, -4.9835e-01, -1.0068e+00, -2.1774e-01,
-         2.1627e-02,  7.2175e-02, -3.4702e-01, -4.5177e-01, -6.4269e-02,
-        -1.9702e-01,  2.6134e-01,  4.6130e-02, -4.5929e-01,  1.5048e+00,
-        -4.9724e-01,  2.2649e-01, -2.9147e-04, -2.3297e-03, -4.0629e-01,
-        -1.8653e-01,  1.1439e+00,  0.0000e+00,  9.1770e-01,  6.2277e-01,
-         2.6115e-01,  9.8423e-01, -1.0076e-01,  6.5142e-01, -5.1071e-02,
-         2.1387e-01,  3.2306e-01,  2.7766e-01,  8.3834e-01, -4.2490e-01,
-         2.6851e-01, -4.9343e-01, -5.3250e-01, -2.4451e-02,  6.2822e-01,
-         1.5736e-05,  1.0755e+00,  2.2242e-03,  6.2390e-01, -4.3738e-01,
-         1.3263e+00,  2.5933e-02,  1.4347e-01, -4.3168e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6425,  0.6340, -0.2837,  0.3768,  0.5353, -0.0398,  0.0000, -0.6641,
-         0.1668, -0.2596, -0.8736, -0.2634,  0.2422, -0.0043, -1.4038, -0.4916,
-        -0.7523, -0.4983, -1.0068, -0.2177,  0.0216,  0.0722, -0.3470, -0.4518,
-        -0.0643, -0.1970,  0.2613,  0.0461, -0.4593,  1.5048, -0.4972,  0.2265,
-         0.0000, -0.0023, -0.4063, -0.1865,  1.1439,  0.0000,  0.9177,  0.6228,
-         0.2612,  0.9842, -0.1008,  0.6514, -0.0511,  0.2139,  0.3231,  0.2777,
-         0.8383, -0.4249,  0.2685, -0.4934, -0.5325, -0.0245,  0.6282,  0.0000,
-         1.0755,  0.0000,  0.6239, -0.4374,  1.3263,  0.0259,  0.1435, -0.4317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6425,  0.6340, -0.2837,  0.3768,  0.5353, -0.0398,  0.0000, -0.6641,
-         0.1668, -0.2596, -0.8736, -0.2634,  0.2422, -0.0043, -1.4038, -0.4916,
-        -0.7523, -0.4983, -1.0068, -0.2177,  0.0216,  0.0722, -0.3470, -0.4518,
-        -0.0643, -0.1970,  0.2613,  0.0461, -0.4593,  1.5048, -0.4972,  0.2265,
-         0.0000, -0.0023, -0.4063, -0.1865,  1.1439,  0.0000,  0.9177,  0.6228,
-         0.2612,  0.9842, -0.1008,  0.6514, -0.0511,  0.2139,  0.3231,  0.2777,
-         0.8383, -0.4249,  0.2685, -0.4934, -0.5325, -0.0245,  0.6282,  0.0000,
-         1.0755,  0.0000,  0.6239, -0.4374,  1.3263,  0.0259,  0.1435, -0.4317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.4036e-01,  6.3036e-01, -2.4734e-01,  3.6162e-01,  5.1393e-01,
-         3.8762e-02,  1.0421e-03, -6.7201e-01,  8.2848e-02, -2.2102e-01,
-        -8.6657e-01, -2.5956e-01,  1.7677e-01,  1.7074e-02, -1.4025e+00,
-        -5.1630e-01, -7.4886e-01, -4.8601e-01, -1.0030e+00, -2.8233e-01,
-         5.1506e-02,  4.2051e-02, -3.2945e-01, -3.8630e-01, -3.4360e-02,
-        -1.8149e-01,  2.2100e-01,  3.6813e-02, -4.4745e-01,  1.5070e+00,
-        -5.1985e-01,  2.1595e-01, -2.4983e-04, -3.9169e-03, -4.5877e-01,
-        -2.1876e-01,  1.1396e+00,  0.0000e+00,  9.1592e-01,  5.9060e-01,
-         2.6293e-01,  9.8552e-01, -9.9852e-03,  6.3599e-01,  2.3314e-02,
-         1.8204e-01,  3.1664e-01,  2.8439e-01,  8.3383e-01, -4.5513e-01,
-         2.9897e-01, -5.2158e-01, -5.4139e-01,  1.9295e-01,  6.0206e-01,
-         1.3488e-05,  1.0722e+00,  1.9064e-03,  5.7847e-01, -4.7437e-01,
-         1.3252e+00,  1.2519e-02,  1.2625e-01, -3.5588e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6404,  0.6304, -0.2473,  0.3616,  0.5139,  0.0388,  0.0000, -0.6720,
-         0.0828, -0.2210, -0.8666, -0.2596,  0.1768,  0.0171, -1.4025, -0.5163,
-        -0.7489, -0.4860, -1.0030, -0.2823,  0.0515,  0.0421, -0.3294, -0.3863,
-        -0.0344, -0.1815,  0.2210,  0.0368, -0.4474,  1.5070, -0.5199,  0.2159,
-         0.0000, -0.0039, -0.4588, -0.2188,  1.1396,  0.0000,  0.9159,  0.5906,
-         0.2629,  0.9855, -0.0100,  0.6360,  0.0233,  0.1820,  0.3166,  0.2844,
-         0.8338, -0.4551,  0.2990, -0.5216, -0.5414,  0.1930,  0.6021,  0.0000,
-         1.0722,  0.0000,  0.5785, -0.4744,  1.3252,  0.0125,  0.1262, -0.3559],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6404,  0.6304, -0.2473,  0.3616,  0.5139,  0.0388,  0.0000, -0.6720,
-         0.0828, -0.2210, -0.8666, -0.2596,  0.1768,  0.0171, -1.4025, -0.5163,
-        -0.7489, -0.4860, -1.0030, -0.2823,  0.0515,  0.0421, -0.3294, -0.3863,
-        -0.0344, -0.1815,  0.2210,  0.0368, -0.4474,  1.5070, -0.5199,  0.2159,
-         0.0000, -0.0039, -0.4588, -0.2188,  1.1396,  0.0000,  0.9159,  0.5906,
-         0.2629,  0.9855, -0.0100,  0.6360,  0.0233,  0.1820,  0.3166,  0.2844,
-         0.8338, -0.4551,  0.2990, -0.5216, -0.5414,  0.1930,  0.6021,  0.0000,
-         1.0722,  0.0000,  0.5785, -0.4744,  1.3252,  0.0125,  0.1262, -0.3559],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.3500e-01,  6.1131e-01, -1.7671e-01,  3.4817e-01,  4.0812e-01,
-         1.5163e-01,  8.9264e-04, -6.8258e-01,  3.3915e-02, -2.3383e-01,
-        -8.4980e-01, -2.3349e-01,  1.8038e-01,  8.7006e-02, -1.4035e+00,
-        -5.4962e-01, -7.4221e-01, -4.8086e-01, -1.0001e+00, -3.4251e-01,
-         6.5892e-02, -5.0297e-02, -3.4186e-01, -3.2445e-01, -3.4199e-02,
-        -1.4665e-01,  2.1284e-01,  1.6382e-02, -3.8811e-01,  1.5050e+00,
-        -5.4709e-01,  2.2840e-01, -2.1400e-04, -5.0070e-02, -4.9590e-01,
-        -2.6100e-01,  1.1433e+00,  0.0000e+00,  9.2475e-01,  5.7539e-01,
-         2.2568e-01,  9.8637e-01,  7.6764e-02,  6.3672e-01,  7.2475e-02,
-         1.3156e-01,  3.1048e-01,  3.1114e-01,  8.3001e-01, -4.9028e-01,
-         3.0501e-01, -5.3834e-01, -5.4012e-01,  3.3862e-01,  5.8820e-01,
-         1.1553e-05,  1.0710e+00,  1.6330e-03,  5.4950e-01, -4.8809e-01,
-         1.3290e+00,  1.7673e-02,  9.6163e-02, -2.8290e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6350,  0.6113, -0.1767,  0.3482,  0.4081,  0.1516,  0.0000, -0.6826,
-         0.0339, -0.2338, -0.8498, -0.2335,  0.1804,  0.0870, -1.4035, -0.5496,
-        -0.7422, -0.4809, -1.0001, -0.3425,  0.0659, -0.0503, -0.3419, -0.3244,
-        -0.0342, -0.1467,  0.2128,  0.0164, -0.3881,  1.5050, -0.5471,  0.2284,
-         0.0000, -0.0501, -0.4959, -0.2610,  1.1433,  0.0000,  0.9248,  0.5754,
-         0.2257,  0.9864,  0.0768,  0.6367,  0.0725,  0.1316,  0.3105,  0.3111,
-         0.8300, -0.4903,  0.3050, -0.5383, -0.5401,  0.3386,  0.5882,  0.0000,
-         1.0710,  0.0000,  0.5495, -0.4881,  1.3290,  0.0177,  0.0962, -0.2829],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6350,  0.6113, -0.1767,  0.3482,  0.4081,  0.1516,  0.0000, -0.6826,
-         0.0339, -0.2338, -0.8498, -0.2335,  0.1804,  0.0870, -1.4035, -0.5496,
-        -0.7422, -0.4809, -1.0001, -0.3425,  0.0659, -0.0503, -0.3419, -0.3244,
-        -0.0342, -0.1467,  0.2128,  0.0164, -0.3881,  1.5050, -0.5471,  0.2284,
-         0.0000, -0.0501, -0.4959, -0.2610,  1.1433,  0.0000,  0.9248,  0.5754,
-         0.2257,  0.9864,  0.0768,  0.6367,  0.0725,  0.1316,  0.3105,  0.3111,
-         0.8300, -0.4903,  0.3050, -0.5383, -0.5401,  0.3386,  0.5882,  0.0000,
-         1.0710,  0.0000,  0.5495, -0.4881,  1.3290,  0.0177,  0.0962, -0.2829],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.2145e-01,  5.9023e-01, -9.7670e-02,  3.1163e-01,  3.7370e-01,
-         2.5237e-01,  7.6415e-04, -6.8939e-01,  2.7698e-03, -3.0795e-01,
-        -8.3914e-01, -1.7875e-01,  1.9994e-01,  2.1118e-01, -1.4038e+00,
-        -5.6700e-01, -7.3355e-01, -4.6738e-01, -9.9813e-01, -3.8493e-01,
-         6.0937e-02, -1.4467e-01, -3.8431e-01, -2.7310e-01, -5.7281e-02,
-        -8.0635e-02,  1.7740e-01, -2.6739e-02, -3.5758e-01,  1.5028e+00,
-        -5.5109e-01,  2.4288e-01, -1.8320e-04, -8.5688e-02, -5.2144e-01,
-        -2.5870e-01,  1.1417e+00,  0.0000e+00,  9.1919e-01,  5.4971e-01,
-         1.7333e-01,  9.8054e-01,  1.5245e-01,  6.4214e-01,  7.2957e-02,
-         1.0104e-01,  2.7743e-01,  3.3944e-01,  8.3388e-01, -5.0729e-01,
-         2.7222e-01, -5.5482e-01, -5.0834e-01,  4.2531e-01,  5.6785e-01,
-         9.8905e-06,  1.0701e+00,  1.3980e-03,  5.3191e-01, -4.9820e-01,
-         1.3314e+00,  2.5425e-02,  6.9219e-02, -2.2708e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6215,  0.5902, -0.0977,  0.3116,  0.3737,  0.2524,  0.0000, -0.6894,
-         0.0028, -0.3080, -0.8391, -0.1788,  0.1999,  0.2112, -1.4038, -0.5670,
-        -0.7336, -0.4674, -0.9981, -0.3849,  0.0609, -0.1447, -0.3843, -0.2731,
-        -0.0573, -0.0806,  0.1774, -0.0267, -0.3576,  1.5028, -0.5511,  0.2429,
-         0.0000, -0.0857, -0.5214, -0.2587,  1.1417,  0.0000,  0.9192,  0.5497,
-         0.1733,  0.9805,  0.1525,  0.6421,  0.0730,  0.1010,  0.2774,  0.3394,
-         0.8339, -0.5073,  0.2722, -0.5548, -0.5083,  0.4253,  0.5678,  0.0000,
-         1.0701,  0.0000,  0.5319, -0.4982,  1.3314,  0.0254,  0.0692, -0.2271],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6215,  0.5902, -0.0977,  0.3116,  0.3737,  0.2524,  0.0000, -0.6894,
-         0.0028, -0.3080, -0.8391, -0.1788,  0.1999,  0.2112, -1.4038, -0.5670,
-        -0.7336, -0.4674, -0.9981, -0.3849,  0.0609, -0.1447, -0.3843, -0.2731,
-        -0.0573, -0.0806,  0.1774, -0.0267, -0.3576,  1.5028, -0.5511,  0.2429,
-         0.0000, -0.0857, -0.5214, -0.2587,  1.1417,  0.0000,  0.9192,  0.5497,
-         0.1733,  0.9805,  0.1525,  0.6421,  0.0730,  0.1010,  0.2774,  0.3394,
-         0.8339, -0.5073,  0.2722, -0.5548, -0.5083,  0.4253,  0.5678,  0.0000,
-         1.0701,  0.0000,  0.5319, -0.4982,  1.3314,  0.0254,  0.0692, -0.2271],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.1668e-01,  5.6268e-01, -5.7647e-02,  2.6522e-01,  3.4746e-01,
-         2.9227e-01,  6.5377e-04, -6.9172e-01, -2.5842e-02, -3.7660e-01,
-        -8.1673e-01, -1.2168e-01,  2.0623e-01,  3.1455e-01, -1.4028e+00,
-        -5.7764e-01, -7.2087e-01, -4.4396e-01, -9.9791e-01, -4.0508e-01,
-         9.6803e-02, -1.8390e-01, -4.0050e-01, -2.0532e-01, -1.4204e-02,
-         1.8732e-02,  9.4744e-02, -4.2403e-02, -3.0900e-01,  1.5004e+00,
-        -5.5411e-01,  2.1999e-01, -1.5673e-04, -5.1963e-02, -5.2070e-01,
-        -2.4102e-01,  1.1441e+00,  0.0000e+00,  9.2264e-01,  5.6226e-01,
-         1.2789e-01,  9.7059e-01,  1.6548e-01,  6.5120e-01,  6.5211e-02,
-         6.1442e-02,  2.4294e-01,  3.7029e-01,  8.3851e-01, -5.2723e-01,
-         2.5153e-01, -5.2954e-01, -4.7323e-01,  4.3459e-01,  5.9014e-01,
-         8.4618e-06,  1.0697e+00,  1.1960e-03,  5.0578e-01, -5.0195e-01,
-         1.3360e+00, -8.6477e-02,  6.4690e-02, -2.5386e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6167,  0.5627, -0.0576,  0.2652,  0.3475,  0.2923,  0.0000, -0.6917,
-        -0.0258, -0.3766, -0.8167, -0.1217,  0.2062,  0.3145, -1.4028, -0.5776,
-        -0.7209, -0.4440, -0.9979, -0.4051,  0.0968, -0.1839, -0.4005, -0.2053,
-        -0.0142,  0.0187,  0.0947, -0.0424, -0.3090,  1.5004, -0.5541,  0.2200,
-         0.0000, -0.0520, -0.5207, -0.2410,  1.1441,  0.0000,  0.9226,  0.5623,
-         0.1279,  0.9706,  0.1655,  0.6512,  0.0652,  0.0614,  0.2429,  0.3703,
-         0.8385, -0.5272,  0.2515, -0.5295, -0.4732,  0.4346,  0.5901,  0.0000,
-         1.0697,  0.0000,  0.5058, -0.5020,  1.3360, -0.0865,  0.0647, -0.2539],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6167,  0.5627, -0.0576,  0.2652,  0.3475,  0.2923,  0.0000, -0.6917,
-        -0.0258, -0.3766, -0.8167, -0.1217,  0.2062,  0.3145, -1.4028, -0.5776,
-        -0.7209, -0.4440, -0.9979, -0.4051,  0.0968, -0.1839, -0.4005, -0.2053,
-        -0.0142,  0.0187,  0.0947, -0.0424, -0.3090,  1.5004, -0.5541,  0.2200,
-         0.0000, -0.0520, -0.5207, -0.2410,  1.1441,  0.0000,  0.9226,  0.5623,
-         0.1279,  0.9706,  0.1655,  0.6512,  0.0652,  0.0614,  0.2429,  0.3703,
-         0.8385, -0.5272,  0.2515, -0.5295, -0.4732,  0.4346,  0.5901,  0.0000,
-         1.0697,  0.0000,  0.5058, -0.5020,  1.3360, -0.0865,  0.0647, -0.2539],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.0525e-01,  5.4585e-01, -5.4705e-02,  2.1522e-01,  2.6450e-01,
-         2.8843e-01,  5.5902e-04, -6.9545e-01, -5.7872e-02, -3.8318e-01,
-        -7.8899e-01, -1.9815e-01,  1.2692e-01,  3.3820e-01, -1.3989e+00,
-        -5.7985e-01, -7.1642e-01, -3.9627e-01, -9.9945e-01, -4.3921e-01,
-         1.0305e-01, -1.2271e-01, -3.8790e-01, -1.4685e-01,  9.0309e-02,
-         8.0191e-02, -2.8660e-02, -3.3127e-02, -2.3066e-01,  1.5036e+00,
-        -5.4713e-01,  1.4020e-01, -1.3402e-04,  8.0386e-02, -5.2785e-01,
-        -2.6344e-01,  1.1405e+00,  0.0000e+00,  9.0622e-01,  5.7846e-01,
-         1.0356e-01,  9.5395e-01,  1.7825e-01,  6.4828e-01,  8.1883e-02,
-         6.3103e-02,  2.0066e-01,  3.6093e-01,  8.6294e-01, -5.4557e-01,
-         2.7658e-01, -4.7054e-01, -4.5655e-01,  3.9055e-01,  6.1684e-01,
-         7.2354e-06,  1.0687e+00,  1.0227e-03,  4.9583e-01, -4.8175e-01,
-         1.3393e+00, -1.9640e-01,  8.7251e-02, -2.8885e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.6052,  0.5458, -0.0547,  0.2152,  0.2645,  0.2884,  0.0000, -0.6955,
-        -0.0579, -0.3832, -0.7890, -0.1982,  0.1269,  0.3382, -1.3989, -0.5799,
-        -0.7164, -0.3963, -0.9995, -0.4392,  0.1030, -0.1227, -0.3879, -0.1468,
-         0.0903,  0.0802, -0.0287, -0.0331, -0.2307,  1.5036, -0.5471,  0.1402,
-         0.0000,  0.0804, -0.5278, -0.2634,  1.1405,  0.0000,  0.9062,  0.5785,
-         0.1036,  0.9540,  0.1783,  0.6483,  0.0819,  0.0631,  0.2007,  0.3609,
-         0.8629, -0.5456,  0.2766, -0.4705, -0.4566,  0.3905,  0.6168,  0.0000,
-         1.0687,  0.0000,  0.4958, -0.4817,  1.3393, -0.1964,  0.0873, -0.2889],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.6052,  0.5458, -0.0547,  0.2152,  0.2645,  0.2884,  0.0000, -0.6955,
-        -0.0579, -0.3832, -0.7890, -0.1982,  0.1269,  0.3382, -1.3989, -0.5799,
-        -0.7164, -0.3963, -0.9995, -0.4392,  0.1030, -0.1227, -0.3879, -0.1468,
-         0.0903,  0.0802, -0.0287, -0.0331, -0.2307,  1.5036, -0.5471,  0.1402,
-         0.0000,  0.0804, -0.5278, -0.2634,  1.1405,  0.0000,  0.9062,  0.5785,
-         0.1036,  0.9540,  0.1783,  0.6483,  0.0819,  0.0631,  0.2007,  0.3609,
-         0.8629, -0.5456,  0.2766, -0.4705, -0.4566,  0.3905,  0.6168,  0.0000,
-         1.0687,  0.0000,  0.4958, -0.4817,  1.3393, -0.1964,  0.0873, -0.2889],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.7844e-01,  5.2429e-01, -6.7794e-02,  1.8611e-01,  2.4306e-01,
-         3.1547e-01,  4.7773e-04, -7.0102e-01, -9.0168e-02, -3.5620e-01,
-        -7.5371e-01, -3.1503e-01,  2.5413e-02,  2.8830e-01, -1.4024e+00,
-        -5.7404e-01, -7.0919e-01, -3.2607e-01, -1.0011e+00, -5.0673e-01,
-         5.1226e-02, -8.1347e-03, -3.6101e-01, -1.1391e-01,  1.9696e-01,
-         7.4553e-02, -9.9145e-02, -3.0621e-02, -1.7908e-01,  1.5074e+00,
-        -5.4155e-01,  2.3425e-02, -1.1453e-04,  1.7663e-01, -5.2747e-01,
-        -3.2004e-01,  1.1319e+00,  0.0000e+00,  8.5451e-01,  5.5273e-01,
-         8.4971e-02,  9.3705e-01,  1.9437e-01,  6.3211e-01, -2.8614e-02,
-         9.1126e-02,  1.6426e-01,  3.2153e-01,  8.9943e-01, -5.5338e-01,
-         2.9050e-01, -3.9890e-01, -4.4031e-01,  3.4651e-01,  6.1473e-01,
-         6.1833e-06,  1.0662e+00,  8.7397e-04,  5.0349e-01, -4.3674e-01,
-         1.3410e+00, -2.2385e-01,  6.1650e-02, -2.6741e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5784,  0.5243, -0.0678,  0.1861,  0.2431,  0.3155,  0.0000, -0.7010,
-        -0.0902, -0.3562, -0.7537, -0.3150,  0.0254,  0.2883, -1.4024, -0.5740,
-        -0.7092, -0.3261, -1.0011, -0.5067,  0.0512, -0.0081, -0.3610, -0.1139,
-         0.1970,  0.0746, -0.0991, -0.0306, -0.1791,  1.5074, -0.5416,  0.0234,
-         0.0000,  0.1766, -0.5275, -0.3200,  1.1319,  0.0000,  0.8545,  0.5527,
-         0.0850,  0.9370,  0.1944,  0.6321, -0.0286,  0.0911,  0.1643,  0.3215,
-         0.8994, -0.5534,  0.2905, -0.3989, -0.4403,  0.3465,  0.6147,  0.0000,
-         1.0662,  0.0000,  0.5035, -0.4367,  1.3410, -0.2239,  0.0617, -0.2674],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5784,  0.5243, -0.0678,  0.1861,  0.2431,  0.3155,  0.0000, -0.7010,
-        -0.0902, -0.3562, -0.7537, -0.3150,  0.0254,  0.2883, -1.4024, -0.5740,
-        -0.7092, -0.3261, -1.0011, -0.5067,  0.0512, -0.0081, -0.3610, -0.1139,
-         0.1970,  0.0746, -0.0991, -0.0306, -0.1791,  1.5074, -0.5416,  0.0234,
-         0.0000,  0.1766, -0.5275, -0.3200,  1.1319,  0.0000,  0.8545,  0.5527,
-         0.0850,  0.9370,  0.1944,  0.6321, -0.0286,  0.0911,  0.1643,  0.3215,
-         0.8994, -0.5534,  0.2905, -0.3989, -0.4403,  0.3465,  0.6147,  0.0000,
-         1.0662,  0.0000,  0.5035, -0.4367,  1.3410, -0.2239,  0.0617, -0.2674],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.6626e-01,  5.0818e-01, -4.1618e-02,  1.2371e-01,  3.6499e-01,
-         3.7498e-01,  4.0805e-04, -7.0309e-01, -1.2660e-01, -3.5420e-01,
-        -7.2753e-01, -3.6262e-01, -1.6776e-02,  2.6766e-01, -1.4090e+00,
-        -5.7065e-01, -6.9769e-01, -2.9078e-01, -1.0003e+00, -5.4069e-01,
-         1.6912e-02, -2.2876e-02, -3.5790e-01, -1.0318e-01,  2.0936e-01,
-         5.1014e-02, -7.7842e-02, -4.2371e-02, -1.6923e-01,  1.5040e+00,
-        -5.1757e-01, -2.1990e-02, -9.7826e-05,  1.5210e-01, -5.2883e-01,
-        -3.3943e-01,  1.1228e+00,  0.0000e+00,  7.9887e-01,  5.0730e-01,
-        -4.9438e-03,  9.2725e-01,  2.3362e-01,  6.1538e-01, -1.4305e-01,
-         1.0326e-01,  1.5295e-01,  2.8285e-01,  9.2284e-01, -5.5858e-01,
-         2.6296e-01, -3.6377e-01, -3.9113e-01,  2.9999e-01,  5.8458e-01,
-         5.2815e-06,  1.0659e+00,  7.4650e-04,  5.1879e-01, -4.1963e-01,
-         1.3430e+00, -2.1477e-01,  2.2848e-02, -2.4526e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5663,  0.5082, -0.0416,  0.1237,  0.3650,  0.3750,  0.0000, -0.7031,
-        -0.1266, -0.3542, -0.7275, -0.3626, -0.0168,  0.2677, -1.4090, -0.5707,
-        -0.6977, -0.2908, -1.0003, -0.5407,  0.0169, -0.0229, -0.3579, -0.1032,
-         0.2094,  0.0510, -0.0778, -0.0424, -0.1692,  1.5040, -0.5176, -0.0220,
-         0.0000,  0.1521, -0.5288, -0.3394,  1.1228,  0.0000,  0.7989,  0.5073,
-        -0.0049,  0.9273,  0.2336,  0.6154, -0.1430,  0.1033,  0.1529,  0.2829,
-         0.9228, -0.5586,  0.2630, -0.3638, -0.3911,  0.3000,  0.5846,  0.0000,
-         1.0659,  0.0000,  0.5188, -0.4196,  0.0000, -0.2148,  0.0228, -0.2453],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5663,  0.5082, -0.0416,  0.1237,  0.3650,  0.3750,  0.0000, -0.7031,
-        -0.1266, -0.3542, -0.7275, -0.3626, -0.0168,  0.2677, -1.4090, -0.5707,
-        -0.6977, -0.2908, -1.0003, -0.5407,  0.0169, -0.0229, -0.3579, -0.1032,
-         0.2094,  0.0510, -0.0778, -0.0424, -0.1692,  1.5040, -0.5176, -0.0220,
-         0.0000,  0.1521, -0.5288, -0.3394,  1.1228,  0.0000,  0.7989,  0.5073,
-        -0.0049,  0.9273,  0.2336,  0.6154, -0.1430,  0.1033,  0.1529,  0.2829,
-         0.9228, -0.5586,  0.2630, -0.3638, -0.3911,  0.3000,  0.5846,  0.0000,
-         1.0659,  0.0000,  0.5188, -0.4196,  0.0000, -0.2148,  0.0228, -0.2453],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.5728e-01,  4.9672e-01, -2.2435e-03,  1.6039e-02,  5.2010e-01,
-         3.5580e-01,  3.4836e-04, -6.8731e-01, -1.3660e-01, -3.6136e-01,
-        -6.9511e-01, -3.4699e-01, -1.4176e-02,  2.0941e-01, -1.4173e+00,
-        -5.5032e-01, -6.7155e-01, -2.6646e-01, -9.9879e-01, -5.5331e-01,
-         2.6910e-03, -4.0174e-02, -3.2042e-01, -8.5846e-02,  2.2005e-01,
-         3.1165e-02, -4.2465e-02, -5.8996e-02, -1.7584e-01,  1.4961e+00,
-        -5.2029e-01, -1.0416e-01, -8.3516e-05,  1.2926e-01, -5.1236e-01,
-        -3.2901e-01,  1.1158e+00,  0.0000e+00,  7.6130e-01,  4.8167e-01,
-        -1.0942e-01,  9.2122e-01,  1.8350e-01,  5.9760e-01, -3.1080e-01,
-         9.3059e-02,  2.0638e-01,  2.4425e-01,  9.3663e-01, -5.5527e-01,
-         2.4943e-01, -2.7359e-01, -3.6909e-01,  7.4561e-02,  5.6233e-01,
-         4.5089e-06,  1.0656e+00,  6.3730e-04,  5.4558e-01, -3.8663e-01,
-         1.7024e-03, -2.0978e-01,  3.4123e-02, -1.9346e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5573,  0.4967, -0.0022,  0.0160,  0.5201,  0.3558,  0.0000, -0.6873,
-        -0.1366, -0.3614, -0.6951, -0.3470, -0.0142,  0.2094, -1.4173, -0.5503,
-        -0.6716, -0.2665, -0.9988, -0.5533,  0.0027, -0.0402, -0.3204, -0.0858,
-         0.2200,  0.0312, -0.0425, -0.0590, -0.1758,  1.4961, -0.5203, -0.1042,
-         0.0000,  0.1293, -0.5124, -0.3290,  1.1158,  0.0000,  0.7613,  0.4817,
-        -0.1094,  0.9212,  0.1835,  0.5976, -0.3108,  0.0931,  0.2064,  0.2443,
-         0.9366, -0.5553,  0.2494, -0.2736, -0.3691,  0.0746,  0.5623,  0.0000,
-         1.0656,  0.0000,  0.5456, -0.3866,  0.0000, -0.2098,  0.0341, -0.1935],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5573,  0.4967, -0.0022,  0.0160,  0.5201,  0.3558,  0.0000, -0.6873,
-        -0.1366, -0.3614, -0.6951, -0.3470, -0.0142,  0.2094, -1.4173, -0.5503,
-        -0.6716, -0.2665, -0.9988, -0.5533,  0.0027, -0.0402, -0.3204, -0.0858,
-         0.2200,  0.0312, -0.0425, -0.0590, -0.1758,  1.4961, -0.5203, -0.1042,
-         0.0000,  0.1293, -0.5124, -0.3290,  1.1158,  0.0000,  0.7613,  0.4817,
-        -0.1094,  0.9212,  0.1835,  0.5976, -0.3108,  0.0931,  0.2064,  0.2443,
-         0.9366, -0.5553,  0.2494, -0.2736, -0.3691,  0.0746,  0.5623,  0.0000,
-         1.0656,  0.0000,  0.5456, -0.3866,  0.0000, -0.2098,  0.0341, -0.1935],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.5047e-01,  4.4973e-01,  6.2542e-02, -5.7362e-02,  6.3128e-01,
-         3.3603e-01,  2.9726e-04, -6.7263e-01, -1.4934e-01, -3.8050e-01,
-        -6.6589e-01, -3.0309e-01,  4.8251e-02,  1.9018e-01, -1.4266e+00,
-        -5.3952e-01, -6.3903e-01, -2.4409e-01, -9.9123e-01, -5.6829e-01,
-        -3.6255e-02, -6.8751e-02, -2.7236e-01, -6.8772e-02,  2.1357e-01,
-        -1.3793e-04,  9.2802e-03, -6.0042e-02, -1.9401e-01,  1.4867e+00,
-        -5.1789e-01, -1.3741e-01, -7.1265e-05,  1.1398e-01, -5.1325e-01,
-        -3.3397e-01,  1.1126e+00,  0.0000e+00,  7.5114e-01,  4.5934e-01,
-        -1.9083e-01,  9.1833e-01,  1.7224e-01,  5.9478e-01, -4.4776e-01,
-         4.4938e-02,  2.2342e-01,  2.1155e-01,  9.4658e-01, -5.5168e-01,
-         2.2345e-01, -1.6398e-01, -3.0004e-01, -8.0197e-02,  5.3234e-01,
-         3.8475e-06,  1.0655e+00,  5.4382e-04,  5.7665e-01, -3.4530e-01,
-         1.4527e-03, -1.6444e-01,  4.3380e-02, -1.3788e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.5047e-01,  4.4973e-01,  6.2542e-02, -5.7362e-02,  6.3128e-01,
-         3.3603e-01,  0.0000e+00, -6.7263e-01, -1.4934e-01, -3.8050e-01,
-        -6.6589e-01, -3.0309e-01,  4.8251e-02,  1.9018e-01, -1.4266e+00,
-        -5.3952e-01, -6.3903e-01, -2.4409e-01, -9.9123e-01, -5.6829e-01,
-        -3.6255e-02, -6.8751e-02, -2.7236e-01, -6.8772e-02,  2.1357e-01,
-        -1.3793e-04,  9.2802e-03, -6.0042e-02, -1.9401e-01,  1.4867e+00,
-        -5.1789e-01, -1.3741e-01,  0.0000e+00,  1.1398e-01, -5.1325e-01,
-        -3.3397e-01,  1.1126e+00,  0.0000e+00,  7.5114e-01,  4.5934e-01,
-        -1.9083e-01,  9.1833e-01,  1.7224e-01,  5.9478e-01, -4.4776e-01,
-         4.4938e-02,  2.2342e-01,  2.1155e-01,  9.4658e-01, -5.5168e-01,
-         2.2345e-01, -1.6398e-01, -3.0004e-01, -8.0197e-02,  5.3234e-01,
-         0.0000e+00,  1.0655e+00,  0.0000e+00,  5.7665e-01, -3.4530e-01,
-         0.0000e+00, -1.6444e-01,  4.3380e-02, -1.3788e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.5047e-01,  4.4973e-01,  6.2542e-02, -5.7362e-02,  6.3128e-01,
-         3.3603e-01,  0.0000e+00, -6.7263e-01, -1.4934e-01, -3.8050e-01,
-        -6.6589e-01, -3.0309e-01,  4.8251e-02,  1.9018e-01, -1.4266e+00,
-        -5.3952e-01, -6.3903e-01, -2.4409e-01, -9.9123e-01, -5.6829e-01,
-        -3.6255e-02, -6.8751e-02, -2.7236e-01, -6.8772e-02,  2.1357e-01,
-        -1.3793e-04,  9.2802e-03, -6.0042e-02, -1.9401e-01,  1.4867e+00,
-        -5.1789e-01, -1.3741e-01,  0.0000e+00,  1.1398e-01, -5.1325e-01,
-        -3.3397e-01,  1.1126e+00,  0.0000e+00,  7.5114e-01,  4.5934e-01,
-        -1.9083e-01,  9.1833e-01,  1.7224e-01,  5.9478e-01, -4.4776e-01,
-         4.4938e-02,  2.2342e-01,  2.1155e-01,  9.4658e-01, -5.5168e-01,
-         2.2345e-01, -1.6398e-01, -3.0004e-01, -8.0197e-02,  5.3234e-01,
-         0.0000e+00,  1.0655e+00,  0.0000e+00,  5.7665e-01, -3.4530e-01,
-         0.0000e+00, -1.6444e-01,  4.3380e-02, -1.3788e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.5140e-01,  3.8380e-01,  1.0333e-01, -1.0257e-01,  6.3202e-01,
-         3.5682e-01,  2.5354e-04, -6.6595e-01, -1.5545e-01, -4.0760e-01,
-        -6.3484e-01, -2.8992e-01,  9.1408e-02,  1.5846e-01, -1.4380e+00,
-        -5.1524e-01, -6.0738e-01, -2.2113e-01, -9.8586e-01, -5.9243e-01,
-        -5.8980e-02, -5.1422e-02, -2.0942e-01, -8.4843e-02,  2.4475e-01,
-        -2.8406e-02,  6.3014e-02, -2.6862e-02, -1.8045e-01,  1.4787e+00,
-        -5.3108e-01, -1.6195e-01, -6.0784e-05,  8.3846e-02, -5.1943e-01,
-        -3.4078e-01,  1.1146e+00,  0.0000e+00,  7.5436e-01,  4.5692e-01,
-        -2.0929e-01,  9.1523e-01,  1.5030e-01,  6.0322e-01, -5.6487e-01,
-         3.0204e-03,  2.3032e-01,  1.7768e-01,  9.4351e-01, -5.6330e-01,
-         2.2243e-01, -5.4021e-02, -2.2558e-01, -1.2689e-01,  5.0926e-01,
-         3.2816e-06,  1.0652e+00,  4.6384e-04,  5.7319e-01, -2.9187e-01,
-         1.2390e-03, -1.3754e-01,  4.7827e-02, -9.4046e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5514,  0.3838,  0.1033, -0.1026,  0.6320,  0.3568,  0.0000, -0.6660,
-        -0.1554, -0.4076, -0.6348, -0.2899,  0.0914,  0.1585, -1.4380, -0.5152,
-        -0.6074, -0.2211, -0.9859, -0.5924, -0.0590, -0.0514, -0.2094, -0.0848,
-         0.2448, -0.0284,  0.0630, -0.0269, -0.1805,  1.4787, -0.5311, -0.1619,
-         0.0000,  0.0838, -0.5194, -0.3408,  1.1146,  0.0000,  0.7544,  0.4569,
-        -0.2093,  0.9152,  0.1503,  0.6032, -0.5649,  0.0030,  0.2303,  0.1777,
-         0.9435, -0.5633,  0.2224, -0.0540, -0.2256, -0.1269,  0.5093,  0.0000,
-         1.0652,  0.0000,  0.5732, -0.2919,  0.0000, -0.1375,  0.0478, -0.0940],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5514,  0.3838,  0.1033, -0.1026,  0.6320,  0.3568,  0.0000, -0.6660,
-        -0.1554, -0.4076, -0.6348, -0.2899,  0.0914,  0.1585, -1.4380, -0.5152,
-        -0.6074, -0.2211, -0.9859, -0.5924, -0.0590, -0.0514, -0.2094, -0.0848,
-         0.2448, -0.0284,  0.0630, -0.0269, -0.1805,  1.4787, -0.5311, -0.1619,
-         0.0000,  0.0838, -0.5194, -0.3408,  1.1146,  0.0000,  0.7544,  0.4569,
-        -0.2093,  0.9152,  0.1503,  0.6032, -0.5649,  0.0030,  0.2303,  0.1777,
-         0.9435, -0.5633,  0.2224, -0.0540, -0.2256, -0.1269,  0.5093,  0.0000,
-         1.0652,  0.0000,  0.5732, -0.2919,  0.0000, -0.1375,  0.0478, -0.0940],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.3905e-01,  3.4065e-01,  1.1721e-01, -1.1936e-01,  6.3459e-01,
-         3.9768e-01,  2.1616e-04, -6.5812e-01, -1.9629e-01, -4.0594e-01,
-        -6.1425e-01, -2.8467e-01,  9.6524e-02,  1.0301e-01, -1.4489e+00,
-        -4.9337e-01, -5.8914e-01, -2.0104e-01, -9.8298e-01, -6.1570e-01,
-        -1.0453e-01, -1.7080e-02, -1.6030e-01, -1.6458e-01,  2.7094e-01,
-        -9.7508e-02,  1.4582e-01,  1.0226e-02, -1.9795e-01,  1.4738e+00,
-        -5.2536e-01, -1.7267e-01, -5.1822e-05,  8.6860e-02, -5.3917e-01,
-        -3.6035e-01,  1.1141e+00,  0.0000e+00,  7.4927e-01,  4.2505e-01,
-        -2.0808e-01,  9.1462e-01,  1.2448e-01,  6.0344e-01, -6.7064e-01,
-         1.2047e-02,  2.3145e-01,  1.4974e-01,  9.4994e-01, -5.6551e-01,
-         2.3764e-01,  1.4328e-02, -1.6489e-01, -1.0922e-01,  4.7046e-01,
-         2.7978e-06,  1.0651e+00,  3.9545e-04,  5.7799e-01, -2.6864e-01,
-         1.0563e-03, -6.4826e-02,  3.6531e-02, -1.0345e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5391,  0.3406,  0.1172, -0.1194,  0.6346,  0.3977,  0.0000, -0.6581,
-        -0.1963, -0.4059, -0.6143, -0.2847,  0.0965,  0.1030, -1.4489, -0.4934,
-        -0.5891, -0.2010, -0.9830, -0.6157, -0.1045, -0.0171, -0.1603, -0.1646,
-         0.2709, -0.0975,  0.1458,  0.0102, -0.1980,  1.4738, -0.5254, -0.1727,
-         0.0000,  0.0869, -0.5392, -0.3603,  1.1141,  0.0000,  0.7493,  0.4250,
-        -0.2081,  0.9146,  0.1245,  0.6034, -0.6706,  0.0120,  0.2314,  0.1497,
-         0.9499, -0.5655,  0.2376,  0.0143, -0.1649, -0.1092,  0.4705,  0.0000,
-         1.0651,  0.0000,  0.5780, -0.2686,  0.0000, -0.0648,  0.0365, -0.0103],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5391,  0.3406,  0.1172, -0.1194,  0.6346,  0.3977,  0.0000, -0.6581,
-        -0.1963, -0.4059, -0.6143, -0.2847,  0.0965,  0.1030, -1.4489, -0.4934,
-        -0.5891, -0.2010, -0.9830, -0.6157, -0.1045, -0.0171, -0.1603, -0.1646,
-         0.2709, -0.0975,  0.1458,  0.0102, -0.1980,  1.4738, -0.5254, -0.1727,
-         0.0000,  0.0869, -0.5392, -0.3603,  1.1141,  0.0000,  0.7493,  0.4250,
-        -0.2081,  0.9146,  0.1245,  0.6034, -0.6706,  0.0120,  0.2314,  0.1497,
-         0.9499, -0.5655,  0.2376,  0.0143, -0.1649, -0.1092,  0.4705,  0.0000,
-         1.0651,  0.0000,  0.5780, -0.2686,  0.0000, -0.0648,  0.0365, -0.0103],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2527e-01,  3.3718e-01,  9.2360e-02, -1.1930e-01,  6.4721e-01,
-         4.2618e-01,  1.8421e-04, -6.4675e-01, -1.9705e-01, -3.8463e-01,
-        -5.8859e-01, -3.3711e-01,  7.3762e-02, -3.5781e-02, -1.4572e+00,
-        -4.6714e-01, -5.5992e-01, -1.7509e-01, -9.8324e-01, -6.3746e-01,
-        -1.3227e-01,  5.4825e-02, -1.2348e-01, -2.2187e-01,  2.8305e-01,
-        -1.5737e-01,  2.2050e-01,  3.8638e-02, -2.0444e-01,  1.4683e+00,
-        -5.1914e-01, -1.7522e-01, -4.4163e-05,  9.8886e-02, -5.4426e-01,
-        -3.9320e-01,  1.1082e+00,  0.0000e+00,  7.2997e-01,  3.8877e-01,
-        -1.9213e-01,  9.0984e-01,  7.4562e-02,  5.9224e-01, -7.6301e-01,
-         2.6643e-02,  2.6192e-01,  9.4339e-02,  9.6233e-01, -5.5871e-01,
-         2.3592e-01,  3.9413e-02, -1.2800e-01, -5.8657e-02,  4.3572e-01,
-         2.3843e-06,  1.0661e+00,  3.3701e-04,  5.8931e-01, -2.4155e-01,
-         9.0022e-04, -2.3956e-02,  1.1601e-02,  3.6838e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5253,  0.3372,  0.0924, -0.1193,  0.6472,  0.4262,  0.0000, -0.6468,
-        -0.1971, -0.3846, -0.5886, -0.3371,  0.0738, -0.0358, -1.4572, -0.4671,
-        -0.5599, -0.1751, -0.9832, -0.6375, -0.1323,  0.0548, -0.1235, -0.2219,
-         0.2830, -0.1574,  0.2205,  0.0386, -0.2044,  1.4683, -0.5191, -0.1752,
-         0.0000,  0.0989, -0.5443, -0.3932,  1.1082,  0.0000,  0.7300,  0.3888,
-        -0.1921,  0.9098,  0.0746,  0.5922, -0.7630,  0.0266,  0.2619,  0.0943,
-         0.9623, -0.5587,  0.2359,  0.0394, -0.1280, -0.0587,  0.4357,  0.0000,
-         1.0661,  0.0000,  0.5893, -0.2415,  0.0000, -0.0240,  0.0116,  0.0368],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5253,  0.3372,  0.0924, -0.1193,  0.6472,  0.4262,  0.0000, -0.6468,
-        -0.1971, -0.3846, -0.5886, -0.3371,  0.0738, -0.0358, -1.4572, -0.4671,
-        -0.5599, -0.1751, -0.9832, -0.6375, -0.1323,  0.0548, -0.1235, -0.2219,
-         0.2830, -0.1574,  0.2205,  0.0386, -0.2044,  1.4683, -0.5191, -0.1752,
-         0.0000,  0.0989, -0.5443, -0.3932,  1.1082,  0.0000,  0.7300,  0.3888,
-        -0.1921,  0.9098,  0.0746,  0.5922, -0.7630,  0.0266,  0.2619,  0.0943,
-         0.9623, -0.5587,  0.2359,  0.0394, -0.1280, -0.0587,  0.4357,  0.0000,
-         1.0661,  0.0000,  0.5893, -0.2415,  0.0000, -0.0240,  0.0116,  0.0368],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1107e-01,  3.6013e-01,  1.0424e-01, -8.3970e-02,  6.8822e-01,
-         4.3232e-01,  1.5693e-04, -6.2801e-01, -1.9123e-01, -3.6385e-01,
-        -5.6183e-01, -3.6256e-01,  1.0312e-01, -1.5182e-01, -1.4616e+00,
-        -4.3924e-01, -5.4203e-01, -1.5586e-01, -9.8493e-01, -6.3676e-01,
-        -1.4500e-01,  1.0214e-01, -1.1260e-01, -2.2932e-01,  2.4679e-01,
-        -1.9054e-01,  2.7658e-01,  6.7748e-02, -2.2385e-01,  1.4634e+00,
-        -5.1848e-01, -1.5862e-01, -3.7622e-05,  1.1791e-01, -5.3438e-01,
-        -3.9995e-01,  1.0952e+00,  0.0000e+00,  7.0000e-01,  3.6749e-01,
-        -2.1312e-01,  9.0526e-01,  6.7397e-02,  5.6471e-01, -8.4994e-01,
-         5.0062e-02,  3.0251e-01,  2.6352e-02,  9.6509e-01, -5.3954e-01,
-         2.1624e-01,  7.4626e-02, -1.3976e-01, -2.8365e-03,  4.3100e-01,
-         2.0312e-06,  1.0666e+00,  2.8709e-04,  6.1937e-01, -2.5057e-01,
-         7.6688e-04, -9.4253e-03,  9.5602e-03,  6.4166e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5111,  0.3601,  0.1042, -0.0840,  0.6882,  0.4323,  0.0000, -0.6280,
-        -0.1912, -0.3638, -0.5618, -0.3626,  0.1031, -0.1518, -1.4616, -0.4392,
-        -0.5420, -0.1559, -0.9849, -0.6368, -0.1450,  0.1021, -0.1126, -0.2293,
-         0.2468, -0.1905,  0.2766,  0.0677, -0.2239,  1.4634, -0.5185, -0.1586,
-         0.0000,  0.1179, -0.5344, -0.3999,  1.0952,  0.0000,  0.7000,  0.3675,
-        -0.2131,  0.9053,  0.0674,  0.5647, -0.8499,  0.0501,  0.3025,  0.0264,
-         0.9651, -0.5395,  0.2162,  0.0746, -0.1398, -0.0028,  0.4310,  0.0000,
-         1.0666,  0.0000,  0.6194, -0.2506,  0.0000, -0.0094,  0.0096,  0.0642],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5111,  0.3601,  0.1042, -0.0840,  0.6882,  0.4323,  0.0000, -0.6280,
-        -0.1912, -0.3638, -0.5618, -0.3626,  0.1031, -0.1518, -1.4616, -0.4392,
-        -0.5420, -0.1559, -0.9849, -0.6368, -0.1450,  0.1021, -0.1126, -0.2293,
-         0.2468, -0.1905,  0.2766,  0.0677, -0.2239,  1.4634, -0.5185, -0.1586,
-         0.0000,  0.1179, -0.5344, -0.3999,  1.0952,  0.0000,  0.7000,  0.3675,
-        -0.2131,  0.9053,  0.0674,  0.5647, -0.8499,  0.0501,  0.3025,  0.0264,
-         0.9651, -0.5395,  0.2162,  0.0746, -0.1398, -0.0028,  0.4310,  0.0000,
-         1.0666,  0.0000,  0.6194, -0.2506,  0.0000, -0.0094,  0.0096,  0.0642],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9059e-01,  3.3983e-01,  1.3530e-01, -6.8357e-02,  6.9423e-01,
-         4.3732e-01,  1.3364e-04, -6.0342e-01, -1.7385e-01, -3.5521e-01,
-        -5.3148e-01, -3.7287e-01,  1.3630e-01, -2.9825e-01, -1.4656e+00,
-        -4.1021e-01, -5.1826e-01, -1.3938e-01, -9.8922e-01, -6.3435e-01,
-        -1.4256e-01,  6.8182e-02, -1.4576e-01, -2.5938e-01,  1.5539e-01,
-        -2.0452e-01,  3.4185e-01,  4.7995e-02, -2.2470e-01,  1.4552e+00,
-        -5.1370e-01, -1.6977e-01, -3.2038e-05,  4.6933e-02, -5.1883e-01,
-        -3.9791e-01,  1.0887e+00,  0.0000e+00,  7.0005e-01,  3.8581e-01,
-        -2.5009e-01,  8.9414e-01,  6.2919e-02,  5.6291e-01, -9.2316e-01,
-        -1.3646e-02,  3.2716e-01, -7.6990e-04,  9.7335e-01, -5.2262e-01,
-         1.3059e-01,  5.4989e-02, -1.3847e-01, -1.1867e-02,  4.2832e-01,
-         1.7297e-06,  1.0667e+00,  2.4448e-04,  6.4024e-01, -2.2205e-01,
-         6.5306e-04, -1.2438e-02, -1.1463e-02,  1.3325e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.9059e-01,  3.3983e-01,  1.3530e-01, -6.8357e-02,  6.9423e-01,
-         4.3732e-01,  0.0000e+00, -6.0342e-01, -1.7385e-01, -3.5521e-01,
-        -5.3148e-01, -3.7287e-01,  1.3630e-01, -2.9825e-01, -1.4656e+00,
-        -4.1021e-01, -5.1826e-01, -1.3938e-01, -9.8922e-01, -6.3435e-01,
-        -1.4256e-01,  6.8182e-02, -1.4576e-01, -2.5938e-01,  1.5539e-01,
-        -2.0452e-01,  3.4185e-01,  4.7995e-02, -2.2470e-01,  1.4552e+00,
-        -5.1370e-01, -1.6977e-01,  0.0000e+00,  4.6933e-02, -5.1883e-01,
-        -3.9791e-01,  1.0887e+00,  0.0000e+00,  7.0005e-01,  3.8581e-01,
-        -2.5009e-01,  8.9414e-01,  6.2919e-02,  5.6291e-01, -9.2316e-01,
-        -1.3646e-02,  3.2716e-01, -7.6990e-04,  9.7335e-01, -5.2262e-01,
-         1.3059e-01,  5.4989e-02, -1.3847e-01, -1.1867e-02,  4.2832e-01,
-         0.0000e+00,  1.0667e+00,  0.0000e+00,  6.4024e-01, -2.2205e-01,
-         0.0000e+00, -1.2438e-02, -1.1463e-02,  1.3325e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.9059e-01,  3.3983e-01,  1.3530e-01, -6.8357e-02,  6.9423e-01,
-         4.3732e-01,  0.0000e+00, -6.0342e-01, -1.7385e-01, -3.5521e-01,
-        -5.3148e-01, -3.7287e-01,  1.3630e-01, -2.9825e-01, -1.4656e+00,
-        -4.1021e-01, -5.1826e-01, -1.3938e-01, -9.8922e-01, -6.3435e-01,
-        -1.4256e-01,  6.8182e-02, -1.4576e-01, -2.5938e-01,  1.5539e-01,
-        -2.0452e-01,  3.4185e-01,  4.7995e-02, -2.2470e-01,  1.4552e+00,
-        -5.1370e-01, -1.6977e-01,  0.0000e+00,  4.6933e-02, -5.1883e-01,
-        -3.9791e-01,  1.0887e+00,  0.0000e+00,  7.0005e-01,  3.8581e-01,
-        -2.5009e-01,  8.9414e-01,  6.2919e-02,  5.6291e-01, -9.2316e-01,
-        -1.3646e-02,  3.2716e-01, -7.6990e-04,  9.7335e-01, -5.2262e-01,
-         1.3059e-01,  5.4989e-02, -1.3847e-01, -1.1867e-02,  4.2832e-01,
-         0.0000e+00,  1.0667e+00,  0.0000e+00,  6.4024e-01, -2.2205e-01,
-         0.0000e+00, -1.2438e-02, -1.1463e-02,  1.3325e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.7081e-01,  2.3883e-01,  1.8326e-01, -7.8665e-02,  6.9662e-01,
-         4.2747e-01,  1.1377e-04, -5.8964e-01, -1.7300e-01, -3.7769e-01,
-        -5.2078e-01, -3.4602e-01,  1.6947e-01, -4.6670e-01, -1.4683e+00,
-        -3.9910e-01, -4.8490e-01, -1.3589e-01, -9.9218e-01, -6.2765e-01,
-        -1.5998e-01, -7.2776e-02, -2.1902e-01, -3.5279e-01,  4.6658e-03,
-        -2.0815e-01,  3.9863e-01, -4.5269e-02, -2.0966e-01,  1.4476e+00,
-        -4.9207e-01, -1.8722e-01, -2.7274e-05, -1.0340e-01, -5.1482e-01,
-        -3.7614e-01,  1.0923e+00,  0.0000e+00,  7.7747e-01,  4.9859e-01,
-        -3.3901e-01,  8.8331e-01,  5.0554e-02,  6.0221e-01, -9.9149e-01,
-        -1.2255e-01,  3.2519e-01,  2.3015e-02,  9.8171e-01, -5.0829e-01,
-         2.4330e-03, -1.2207e-02, -1.3796e-01, -1.2129e-01,  4.5706e-01,
-         1.4725e-06,  1.0683e+00,  2.0812e-04,  6.4925e-01, -2.5422e-01,
-         5.5595e-04, -1.3995e-02, -3.8960e-02,  1.9881e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4708,  0.2388,  0.1833, -0.0787,  0.6966,  0.4275,  0.0000, -0.5896,
-        -0.1730, -0.3777, -0.5208, -0.3460,  0.1695, -0.4667, -1.4683, -0.3991,
-        -0.4849, -0.1359, -0.9922, -0.6277, -0.1600, -0.0728,  0.0000, -0.3528,
-         0.0047, -0.2082,  0.3986, -0.0453, -0.2097,  1.4476, -0.4921, -0.1872,
-         0.0000, -0.1034, -0.5148, -0.3761,  1.0923,  0.0000,  0.7775,  0.4986,
-        -0.3390,  0.8833,  0.0506,  0.6022, -0.9915, -0.1226,  0.3252,  0.0230,
-         0.9817, -0.5083,  0.0024, -0.0122, -0.1380, -0.1213,  0.4571,  0.0000,
-         1.0683,  0.0000,  0.6492, -0.2542,  0.0000, -0.0140, -0.0390,  0.1988],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4708,  0.2388,  0.1833, -0.0787,  0.6966,  0.4275,  0.0000, -0.5896,
-        -0.1730, -0.3777, -0.5208, -0.3460,  0.1695, -0.4667, -1.4683, -0.3991,
-        -0.4849, -0.1359, -0.9922, -0.6277, -0.1600, -0.0728,  0.0000, -0.3528,
-         0.0047, -0.2082,  0.3986, -0.0453, -0.2097,  1.4476, -0.4921, -0.1872,
-         0.0000, -0.1034, -0.5148, -0.3761,  1.0923,  0.0000,  0.7775,  0.4986,
-        -0.3390,  0.8833,  0.0506,  0.6022, -0.9915, -0.1226,  0.3252,  0.0230,
-         0.9817, -0.5083,  0.0024, -0.0122, -0.1380, -0.1213,  0.4571,  0.0000,
-         1.0683,  0.0000,  0.6492, -0.2542,  0.0000, -0.0140, -0.0390,  0.1988],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4753e-01,  1.7046e-01,  2.1648e-01, -1.0346e-01,  7.2184e-01,
-         4.1942e-01,  9.6817e-05, -5.6937e-01, -1.5476e-01, -3.8554e-01,
-        -4.9998e-01, -2.9748e-01,  1.6451e-01, -5.6919e-01, -1.4679e+00,
-        -3.8453e-01, -4.6549e-01, -1.0699e-01, -9.9176e-01, -6.0293e-01,
-        -1.5204e-01, -1.5665e-01, -6.2346e-02, -4.6023e-01, -1.0676e-01,
-        -2.0659e-01,  4.3018e-01, -8.0905e-02, -1.9533e-01,  1.4393e+00,
-        -4.6111e-01, -1.9065e-01, -2.3211e-05, -2.0491e-01, -5.0935e-01,
-        -3.4284e-01,  1.0918e+00,  0.0000e+00,  8.1896e-01,  5.8240e-01,
-        -3.9636e-01,  8.7870e-01,  1.4896e-02,  6.1872e-01, -1.0375e+00,
-        -1.6153e-01,  3.5192e-01,  1.2937e-02,  9.8651e-01, -5.0111e-01,
-        -9.9456e-02, -2.7428e-02, -9.6373e-02, -1.7430e-01,  4.7531e-01,
-         1.2531e-06,  1.0710e+00,  1.7712e-04,  6.6802e-01, -2.6326e-01,
-         4.7313e-04, -1.5476e-02, -4.5682e-02,  2.3322e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4475,  0.1705,  0.2165, -0.1035,  0.7218,  0.4194,  0.0000, -0.5694,
-        -0.1548, -0.3855, -0.5000, -0.2975,  0.1645, -0.5692, -1.4679, -0.3845,
-        -0.4655, -0.1070, -0.9918, -0.6029, -0.1520, -0.1566,  0.0000, -0.4602,
-        -0.1068, -0.2066,  0.4302, -0.0809, -0.1953,  1.4393, -0.4611, -0.1907,
-         0.0000, -0.2049, -0.5094, -0.3428,  1.0918,  0.0000,  0.8190,  0.5824,
-        -0.3964,  0.8787,  0.0149,  0.6187, -1.0375, -0.1615,  0.3519,  0.0129,
-         0.9865, -0.5011, -0.0995, -0.0274, -0.0964, -0.1743,  0.4753,  0.0000,
-         1.0710,  0.0000,  0.6680, -0.2633,  0.0000, -0.0155, -0.0457,  0.2332],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4475,  0.1705,  0.2165, -0.1035,  0.7218,  0.4194,  0.0000, -0.5694,
-        -0.1548, -0.3855, -0.5000, -0.2975,  0.1645, -0.5692, -1.4679, -0.3845,
-        -0.4655, -0.1070, -0.9918, -0.6029, -0.1520, -0.1566,  0.0000, -0.4602,
-        -0.1068, -0.2066,  0.4302, -0.0809, -0.1953,  1.4393, -0.4611, -0.1907,
-         0.0000, -0.2049, -0.5094, -0.3428,  1.0918,  0.0000,  0.8190,  0.5824,
-        -0.3964,  0.8787,  0.0149,  0.6187, -1.0375, -0.1615,  0.3519,  0.0129,
-         0.9865, -0.5011, -0.0995, -0.0274, -0.0964, -0.1743,  0.4753,  0.0000,
-         1.0710,  0.0000,  0.6680, -0.2633,  0.0000, -0.0155, -0.0457,  0.2332],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3259e-01,  1.9163e-01,  1.8969e-01, -1.6569e-01,  7.1753e-01,
-         3.8239e-01,  8.2370e-05, -5.6074e-01, -1.0451e-01, -3.8877e-01,
-        -4.4906e-01, -2.9815e-01,  9.2326e-02, -6.3893e-01, -1.4654e+00,
-        -3.5690e-01, -4.9529e-01, -2.1740e-02, -9.8927e-01, -5.9025e-01,
-        -1.5437e-01, -1.1209e-01, -5.3043e-02, -5.5274e-01, -1.1057e-01,
-        -2.0467e-01,  3.9509e-01, -7.6677e-02, -1.3375e-01,  1.4351e+00,
-        -4.3399e-01, -1.8138e-01, -1.9747e-05, -2.2238e-01, -5.0814e-01,
-        -3.1810e-01,  1.0938e+00,  0.0000e+00,  8.2395e-01,  6.1429e-01,
-        -3.5502e-01,  8.7460e-01, -9.1705e-02,  6.1877e-01, -1.0736e+00,
-        -1.6666e-01,  3.7361e-01, -4.3820e-02,  9.8897e-01, -5.2061e-01,
-        -1.2917e-01, -2.1119e-02, -3.5087e-02, -1.1197e-01,  4.6225e-01,
-         1.0661e-06,  1.0726e+00,  1.5069e-04,  6.7779e-01, -2.4294e-01,
-         4.0253e-04, -4.4036e-02, -3.6369e-02,  1.4107e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4326,  0.1916,  0.1897, -0.1657,  0.7175,  0.3824,  0.0000, -0.5607,
-        -0.1045, -0.3888, -0.4491, -0.2981,  0.0923, -0.6389, -1.4654, -0.3569,
-        -0.4953, -0.0217, -0.9893, -0.5903, -0.1544, -0.1121,  0.0000, -0.5527,
-        -0.1106, -0.2047,  0.3951, -0.0767, -0.1337,  1.4351, -0.4340, -0.1814,
-         0.0000, -0.2224, -0.5081, -0.3181,  1.0938,  0.0000,  0.8239,  0.6143,
-        -0.3550,  0.8746, -0.0917,  0.6188, -1.0736, -0.1667,  0.3736, -0.0438,
-         0.9890, -0.5206, -0.1292, -0.0211, -0.0351, -0.1120,  0.4622,  0.0000,
-         1.0726,  0.0000,  0.6778, -0.2429,  0.0000, -0.0440, -0.0364,  0.1411],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4326,  0.1916,  0.1897, -0.1657,  0.7175,  0.3824,  0.0000, -0.5607,
-        -0.1045, -0.3888, -0.4491, -0.2981,  0.0923, -0.6389, -1.4654, -0.3569,
-        -0.4953, -0.0217, -0.9893, -0.5903, -0.1544, -0.1121,  0.0000, -0.5527,
-        -0.1106, -0.2047,  0.3951, -0.0767, -0.1337,  1.4351, -0.4340, -0.1814,
-         0.0000, -0.2224, -0.5081, -0.3181,  1.0938,  0.0000,  0.8239,  0.6143,
-        -0.3550,  0.8746, -0.0917,  0.6188, -1.0736, -0.1667,  0.3736, -0.0438,
-         0.9890, -0.5206, -0.1292, -0.0211, -0.0351, -0.1120,  0.4622,  0.0000,
-         1.0726,  0.0000,  0.6778, -0.2429,  0.0000, -0.0440, -0.0364,  0.1411],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4532e-01,  2.3895e-01,  1.7110e-01, -2.3145e-01,  7.2178e-01,
-         3.8681e-01,  7.0061e-05, -5.6339e-01, -3.9849e-02, -3.8142e-01,
-        -3.6340e-01, -3.1950e-01,  6.3203e-02, -7.0342e-01, -1.4591e+00,
-        -3.1612e-01, -5.6265e-01,  1.1765e-01, -9.8531e-01, -5.8519e-01,
-        -1.8281e-01,  2.5840e-02, -4.5116e-02, -6.3061e-01, -5.1264e-02,
-        -2.2680e-01,  3.9858e-01, -4.5774e-02, -7.3069e-02,  1.4343e+00,
-        -4.2128e-01, -1.3238e-01, -1.6796e-05, -1.8789e-01, -5.2675e-01,
-        -3.0426e-01,  1.0835e+00,  0.0000e+00,  8.0481e-01,  6.0208e-01,
-        -2.5056e-01,  8.7287e-01, -1.5560e-01,  5.9806e-01, -1.0983e+00,
-        -7.5374e-02,  4.0696e-01, -1.3451e-01,  9.8887e-01, -5.4693e-01,
-        -8.9774e-02, -3.8470e-02,  1.2505e-01,  7.1373e-02,  4.1805e-01,
-         9.0680e-07,  1.0734e+00,  1.2817e-04,  6.8414e-01, -1.8182e-01,
-         3.4237e-04, -2.1057e-02,  1.8580e-03,  8.6332e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4453,  0.2389,  0.1711, -0.2314,  0.7218,  0.3868,  0.0000, -0.5634,
-        -0.0398, -0.3814, -0.3634, -0.3195,  0.0632, -0.7034, -1.4591, -0.3161,
-        -0.5627,  0.1176, -0.9853, -0.5852, -0.1828,  0.0258,  0.0000, -0.6306,
-        -0.0513, -0.2268,  0.3986, -0.0458, -0.0731,  1.4343, -0.4213, -0.1324,
-         0.0000, -0.1879, -0.5268, -0.3043,  1.0835,  0.0000,  0.8048,  0.6021,
-        -0.2506,  0.8729, -0.1556,  0.5981, -1.0983, -0.0754,  0.4070, -0.1345,
-         0.9889, -0.5469, -0.0898, -0.0385,  0.1250,  0.0714,  0.4181,  0.0000,
-         1.0734,  0.0000,  0.6841, -0.1818,  0.0000, -0.0211,  0.0019,  0.0863],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4453,  0.2389,  0.1711, -0.2314,  0.7218,  0.3868,  0.0000, -0.5634,
-        -0.0398, -0.3814, -0.3634, -0.3195,  0.0632, -0.7034, -1.4591, -0.3161,
-        -0.5627,  0.1176, -0.9853, -0.5852, -0.1828,  0.0258,  0.0000, -0.6306,
-        -0.0513, -0.2268,  0.3986, -0.0458, -0.0731,  1.4343, -0.4213, -0.1324,
-         0.0000, -0.1879, -0.5268, -0.3043,  1.0835,  0.0000,  0.8048,  0.6021,
-        -0.2506,  0.8729, -0.1556,  0.5981, -1.0983, -0.0754,  0.4070, -0.1345,
-         0.9889, -0.5469, -0.0898, -0.0385,  0.1250,  0.0714,  0.4181,  0.0000,
-         1.0734,  0.0000,  0.6841, -0.1818,  0.0000, -0.0211,  0.0019,  0.0863],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6785e-01,  2.4444e-01,  1.9509e-01, -2.8371e-01,  7.2764e-01,
-         4.1348e-01,  5.9576e-05, -5.7969e-01, -1.1963e-02, -3.9287e-01,
-        -3.1136e-01, -3.0052e-01,  6.9259e-02, -7.7039e-01, -1.4514e+00,
-        -3.0913e-01, -6.2357e-01,  2.5411e-01, -9.8252e-01, -5.8364e-01,
-        -2.1796e-01,  8.1617e-02, -3.8365e-02, -7.0534e-01, -5.0387e-02,
-        -2.4585e-01,  4.5036e-01, -4.2769e-02, -5.1513e-02,  1.4313e+00,
-        -4.0993e-01, -1.0435e-01, -1.4283e-05, -2.1790e-01, -5.4341e-01,
-        -2.7610e-01,  1.0780e+00,  0.0000e+00,  7.8962e-01,  5.9004e-01,
-        -1.9473e-01,  8.7561e-01, -2.5912e-01,  5.7900e-01, -1.1183e+00,
-        -1.9432e-02,  4.2397e-01, -1.7701e-01,  9.8784e-01, -5.7921e-01,
-        -5.3255e-02, -9.5059e-02,  2.2315e-01,  3.0519e-01,  3.7591e-01,
-         7.7110e-07,  1.0727e+00,  1.0899e-04,  6.8030e-01, -1.5751e-01,
-         2.9114e-04,  5.7634e-02,  1.5826e-02,  9.1401e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4679,  0.2444,  0.1951, -0.2837,  0.7276,  0.4135,  0.0000, -0.5797,
-        -0.0120, -0.3929, -0.3114, -0.3005,  0.0693, -0.7704, -1.4514, -0.3091,
-        -0.6236,  0.2541, -0.9825, -0.5836, -0.2180,  0.0816,  0.0000, -0.7053,
-        -0.0504, -0.2458,  0.4504, -0.0428, -0.0515,  1.4313, -0.4099, -0.1044,
-         0.0000, -0.2179, -0.5434, -0.2761,  1.0780,  0.0000,  0.7896,  0.5900,
-        -0.1947,  0.8756, -0.2591,  0.5790, -1.1183, -0.0194,  0.4240, -0.1770,
-         0.9878, -0.5792, -0.0533, -0.0951,  0.2231,  0.3052,  0.3759,  0.0000,
-         1.0727,  0.0000,  0.6803, -0.1575,  0.0000,  0.0576,  0.0158,  0.0914],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4679,  0.2444,  0.1951, -0.2837,  0.7276,  0.4135,  0.0000, -0.5797,
-        -0.0120, -0.3929, -0.3114, -0.3005,  0.0693, -0.7704, -1.4514, -0.3091,
-        -0.6236,  0.2541, -0.9825, -0.5836, -0.2180,  0.0816,  0.0000, -0.7053,
-        -0.0504, -0.2458,  0.4504, -0.0428, -0.0515,  1.4313, -0.4099, -0.1044,
-         0.0000, -0.2179, -0.5434, -0.2761,  1.0780,  0.0000,  0.7896,  0.5900,
-        -0.1947,  0.8756, -0.2591,  0.5790, -1.1183, -0.0194,  0.4240, -0.1770,
-         0.9878, -0.5792, -0.0533, -0.0951,  0.2231,  0.3052,  0.3759,  0.0000,
-         1.0727,  0.0000,  0.6803, -0.1575,  0.0000,  0.0576,  0.0158,  0.0914],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5840e-01,  2.9057e-01,  1.7091e-01, -3.8891e-01,  6.6411e-01,
-         3.9543e-01,  5.0650e-05, -5.9420e-01,  2.8812e-02, -4.0198e-01,
-        -2.8684e-01, -2.7969e-01,  5.0558e-02, -8.2597e-01, -1.4440e+00,
-        -2.5839e-01, -6.3633e-01,  3.8163e-01, -9.8341e-01, -5.8471e-01,
-        -2.2842e-01,  1.0183e-01, -3.2616e-02, -7.3844e-01, -1.1496e-01,
-        -2.5931e-01,  4.7596e-01, -9.4712e-02, -1.3716e-02,  1.4321e+00,
-        -3.8447e-01, -9.2772e-02, -1.2143e-05, -2.8732e-01, -5.3446e-01,
-        -2.4303e-01,  1.0679e+00,  0.0000e+00,  7.6617e-01,  5.9346e-01,
-        -2.1270e-01,  8.5549e-01, -3.9992e-01,  5.5886e-01, -1.1256e+00,
-         3.6758e-04,  4.3037e-01, -2.1781e-01,  9.9584e-01, -5.9193e-01,
-        -3.9790e-02, -1.7109e-01,  3.3329e-01,  5.6285e-01,  2.9734e-01,
-         6.5556e-07,  1.0739e+00,  9.2659e-05,  6.5510e-01, -5.8074e-03,
-         2.4751e-04,  6.2377e-02,  9.9500e-03,  2.4163e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Sparsity at the end of epoch 0: 10.68%
-After Step tensor([ 4.5840e-01,  2.9057e-01,  1.7091e-01, -3.8891e-01,  6.6411e-01,
-         3.9543e-01,  0.0000e+00, -5.9420e-01,  2.8812e-02, -4.0198e-01,
-        -2.8684e-01, -2.7969e-01,  5.0558e-02, -8.2597e-01, -1.4440e+00,
-        -2.5839e-01, -6.3633e-01,  3.8163e-01, -9.8341e-01, -5.8471e-01,
-        -2.2842e-01,  1.0183e-01,  0.0000e+00, -7.3844e-01, -1.1496e-01,
-        -2.5931e-01,  4.7596e-01, -9.4712e-02, -1.3716e-02,  1.4321e+00,
-        -3.8447e-01, -9.2772e-02,  0.0000e+00, -2.8732e-01, -5.3446e-01,
-        -2.4303e-01,  1.0679e+00,  0.0000e+00,  7.6617e-01,  5.9346e-01,
-        -2.1270e-01,  8.5549e-01, -3.9992e-01,  5.5886e-01, -1.1256e+00,
-         3.6758e-04,  4.3037e-01, -2.1781e-01,  9.9584e-01, -5.9193e-01,
-        -3.9790e-02, -1.7109e-01,  3.3329e-01,  5.6285e-01,  2.9734e-01,
-         0.0000e+00,  1.0739e+00,  0.0000e+00,  6.5510e-01, -5.8074e-03,
-         0.0000e+00,  6.2377e-02,  9.9500e-03,  2.4163e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.5840e-01,  2.9057e-01,  1.7091e-01, -3.8891e-01,  6.6411e-01,
-         3.9543e-01,  0.0000e+00, -5.9420e-01,  2.8812e-02, -4.0198e-01,
-        -2.8684e-01, -2.7969e-01,  5.0558e-02, -8.2597e-01, -1.4440e+00,
-        -2.5839e-01, -6.3633e-01,  3.8163e-01, -9.8341e-01, -5.8471e-01,
-        -2.2842e-01,  1.0183e-01,  0.0000e+00, -7.3844e-01, -1.1496e-01,
-        -2.5931e-01,  4.7596e-01, -9.4712e-02, -1.3716e-02,  1.4321e+00,
-        -3.8447e-01, -9.2772e-02,  0.0000e+00, -2.8732e-01, -5.3446e-01,
-        -2.4303e-01,  1.0679e+00,  0.0000e+00,  7.6617e-01,  5.9346e-01,
-        -2.1270e-01,  8.5549e-01, -3.9992e-01,  5.5886e-01, -1.1256e+00,
-         3.6758e-04,  4.3037e-01, -2.1781e-01,  9.9584e-01, -5.9193e-01,
-        -3.9790e-02, -1.7109e-01,  3.3329e-01,  5.6285e-01,  2.9734e-01,
-         0.0000e+00,  1.0739e+00,  0.0000e+00,  6.5510e-01, -5.8074e-03,
-         0.0000e+00,  6.2377e-02,  9.9500e-03,  2.4163e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3286e-01,  2.8000e-01,  1.6923e-01, -4.8403e-01,  5.7640e-01,
-         3.2788e-01,  4.3053e-05, -5.9901e-01,  3.5814e-02, -4.1091e-01,
-        -2.9351e-01, -2.3261e-01,  4.8610e-02, -8.9256e-01, -1.4412e+00,
-        -2.1577e-01, -6.2868e-01,  4.5008e-01, -9.8364e-01, -5.7130e-01,
-        -2.5340e-01,  3.5375e-02, -2.7724e-02, -7.5156e-01, -2.0321e-01,
-        -2.5749e-01,  4.3597e-01, -1.6754e-01,  1.3792e-02,  1.4283e+00,
-        -3.4021e-01, -8.7801e-02, -1.0321e-05, -3.5330e-01, -5.1770e-01,
-        -1.7677e-01,  1.0708e+00,  0.0000e+00,  7.8555e-01,  6.6360e-01,
-        -2.6467e-01,  8.3659e-01, -5.0640e-01,  5.7120e-01, -1.1257e+00,
-        -4.0641e-02,  4.0581e-01, -2.2135e-01,  9.9107e-01, -5.9461e-01,
-        -6.3748e-02, -2.2799e-01,  3.6500e-01,  7.3063e-01,  3.0955e-01,
-         5.5724e-07,  1.0745e+00,  7.8762e-05,  5.7942e-01,  1.0340e-01,
-         2.1039e-04, -4.0106e-02,  6.1098e-02, -5.6981e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4329,  0.2800,  0.1692, -0.4840,  0.5764,  0.3279,  0.0000, -0.5990,
-         0.0358, -0.4109, -0.2935, -0.2326,  0.0486, -0.8926, -1.4412, -0.2158,
-        -0.6287,  0.4501, -0.9836, -0.5713, -0.2534,  0.0354,  0.0000, -0.7516,
-        -0.2032, -0.2575,  0.4360, -0.1675,  0.0138,  1.4283, -0.3402, -0.0878,
-         0.0000, -0.3533, -0.5177, -0.1768,  1.0708,  0.0000,  0.7856,  0.6636,
-        -0.2647,  0.8366, -0.5064,  0.5712, -1.1257, -0.0406,  0.4058, -0.2214,
-         0.9911,  0.0000, -0.0637, -0.2280,  0.3650,  0.7306,  0.3096,  0.0000,
-         1.0745,  0.0000,  0.5794,  0.1034,  0.0000, -0.0401,  0.0611, -0.0570],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4329,  0.2800,  0.1692, -0.4840,  0.5764,  0.3279,  0.0000, -0.5990,
-         0.0358, -0.4109, -0.2935, -0.2326,  0.0486, -0.8926, -1.4412, -0.2158,
-        -0.6287,  0.4501, -0.9836, -0.5713, -0.2534,  0.0354,  0.0000, -0.7516,
-        -0.2032, -0.2575,  0.4360, -0.1675,  0.0138,  1.4283, -0.3402, -0.0878,
-         0.0000, -0.3533, -0.5177, -0.1768,  1.0708,  0.0000,  0.7856,  0.6636,
-        -0.2647,  0.8366, -0.5064,  0.5712, -1.1257, -0.0406,  0.4058, -0.2214,
-         0.9911,  0.0000, -0.0637, -0.2280,  0.3650,  0.7306,  0.3096,  0.0000,
-         1.0745,  0.0000,  0.5794,  0.1034,  0.0000, -0.0401,  0.0611, -0.0570],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6872e-01,  2.3193e-01,  1.3710e-01, -5.2275e-01,  5.0012e-01,
-         2.4770e-01,  3.6594e-05, -5.9221e-01,  7.0436e-02, -4.1994e-01,
-        -2.3946e-01, -1.5612e-01,  3.1458e-02, -9.5569e-01, -1.4419e+00,
-        -1.9207e-01, -6.3548e-01,  4.9663e-01, -9.8374e-01, -5.5298e-01,
-        -2.7169e-01, -7.0868e-03, -2.3565e-02, -8.0094e-01, -2.5023e-01,
-        -2.4441e-01,  3.4917e-01, -1.9131e-01, -5.3381e-03,  1.4293e+00,
-        -3.2000e-01, -6.2051e-02, -8.7730e-06, -4.4616e-01, -4.9482e-01,
-        -9.1986e-02,  1.0754e+00,  0.0000e+00,  8.0484e-01,  7.2812e-01,
-        -2.9912e-01,  8.3145e-01, -5.6387e-01,  5.9547e-01, -1.1286e+00,
-        -7.2772e-02,  3.7698e-01, -2.1293e-01,  9.9072e-01, -2.2758e-03,
-        -8.7015e-02, -2.7108e-01,  3.9484e-01,  8.3613e-01,  3.5962e-01,
-         4.7364e-07,  1.0747e+00,  6.6946e-05,  5.0648e-01,  1.8811e-01,
-         1.7883e-04, -1.3825e-01,  1.3570e-01, -4.0340e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3687,  0.2319,  0.1371, -0.5227,  0.5001,  0.2477,  0.0000, -0.5922,
-         0.0704, -0.4199, -0.2395, -0.1561,  0.0315, -0.9557, -1.4419, -0.1921,
-        -0.6355,  0.4966, -0.9837, -0.5530, -0.2717, -0.0071,  0.0000, -0.8009,
-        -0.2502, -0.2444,  0.3492, -0.1913, -0.0053,  1.4293, -0.3200, -0.0621,
-         0.0000, -0.4462, -0.4948, -0.0920,  1.0754,  0.0000,  0.8048,  0.7281,
-        -0.2991,  0.8315, -0.5639,  0.5955, -1.1286, -0.0728,  0.3770, -0.2129,
-         0.9907,  0.0000, -0.0870, -0.2711,  0.3948,  0.8361,  0.3596,  0.0000,
-         1.0747,  0.0000,  0.5065,  0.1881,  0.0000, -0.1382,  0.1357, -0.0403],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3687,  0.2319,  0.1371, -0.5227,  0.5001,  0.2477,  0.0000, -0.5922,
-         0.0704, -0.4199, -0.2395, -0.1561,  0.0315, -0.9557, -1.4419, -0.1921,
-        -0.6355,  0.4966, -0.9837, -0.5530, -0.2717, -0.0071,  0.0000, -0.8009,
-        -0.2502, -0.2444,  0.3492, -0.1913, -0.0053,  1.4293, -0.3200, -0.0621,
-         0.0000, -0.4462, -0.4948, -0.0920,  1.0754,  0.0000,  0.8048,  0.7281,
-        -0.2991,  0.8315, -0.5639,  0.5955, -1.1286, -0.0728,  0.3770, -0.2129,
-         0.9907,  0.0000, -0.0870, -0.2711,  0.3948,  0.8361,  0.3596,  0.0000,
-         1.0747,  0.0000,  0.5065,  0.1881,  0.0000, -0.1382,  0.1357, -0.0403],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0481e-01,  1.6211e-01,  1.7190e-01, -4.8729e-01,  4.4929e-01,
-         2.1156e-01,  3.1103e-05, -5.7031e-01,  1.0767e-01, -4.2435e-01,
-        -1.3468e-01, -1.0302e-01,  6.8509e-02, -9.9568e-01, -1.4393e+00,
-        -1.6964e-01, -6.6665e-01,  5.1785e-01, -9.8483e-01, -5.3094e-01,
-        -2.4917e-01, -6.3964e-02, -2.0029e-02, -8.5677e-01, -2.5904e-01,
-        -1.9273e-01,  3.1915e-01, -1.4614e-01, -3.0751e-02,  1.4261e+00,
-        -2.9671e-01,  4.8122e-02, -7.4566e-06, -4.9071e-01, -4.5982e-01,
-        -9.4314e-03,  1.0702e+00,  0.0000e+00,  8.1290e-01,  7.7479e-01,
-        -2.9592e-01,  8.2922e-01, -5.8070e-01,  6.3264e-01, -1.1387e+00,
-        -1.5261e-01,  3.2977e-01, -1.7383e-01,  9.7892e-01, -1.9343e-03,
-        -1.0327e-01, -3.3896e-01,  4.1536e-01,  9.2507e-01,  4.2543e-01,
-         4.0257e-07,  1.0728e+00,  5.6900e-05,  4.2439e-01,  3.4125e-01,
-         1.5199e-04, -1.4543e-01,  1.9537e-01,  7.0624e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3048,  0.1621,  0.1719, -0.4873,  0.4493,  0.2116,  0.0000, -0.5703,
-         0.1077, -0.4243, -0.1347, -0.1030,  0.0685, -0.9957, -1.4393, -0.1696,
-        -0.6667,  0.5178, -0.9848, -0.5309, -0.2492, -0.0640,  0.0000, -0.8568,
-        -0.2590, -0.1927,  0.3192, -0.1461, -0.0308,  1.4261, -0.2967,  0.0481,
-         0.0000, -0.4907, -0.4598, -0.0094,  1.0702,  0.0000,  0.8129,  0.7748,
-        -0.2959,  0.8292, -0.5807,  0.6326, -1.1387, -0.1526,  0.3298, -0.1738,
-         0.9789,  0.0000, -0.1033, -0.3390,  0.4154,  0.9251,  0.4254,  0.0000,
-         1.0728,  0.0000,  0.4244,  0.3413,  0.0000, -0.1454,  0.1954,  0.0706],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3048,  0.1621,  0.1719, -0.4873,  0.4493,  0.2116,  0.0000, -0.5703,
-         0.1077, -0.4243, -0.1347, -0.1030,  0.0685, -0.9957, -1.4393, -0.1696,
-        -0.6667,  0.5178, -0.9848, -0.5309, -0.2492, -0.0640,  0.0000, -0.8568,
-        -0.2590, -0.1927,  0.3192, -0.1461, -0.0308,  1.4261, -0.2967,  0.0481,
-         0.0000, -0.4907, -0.4598, -0.0094,  1.0702,  0.0000,  0.8129,  0.7748,
-        -0.2959,  0.8292, -0.5807,  0.6326, -1.1387, -0.1526,  0.3298, -0.1738,
-         0.9789,  0.0000, -0.1033, -0.3390,  0.4154,  0.9251,  0.4254,  0.0000,
-         1.0728,  0.0000,  0.4244,  0.3413,  0.0000, -0.1454,  0.1954,  0.0706],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6365e-01,  1.3955e-01,  1.8071e-01, -3.9785e-01,  3.7089e-01,
-         2.1298e-01,  2.6435e-05, -5.4954e-01,  1.8309e-01, -3.7393e-01,
-         1.3167e-02, -1.6227e-01,  6.0437e-02, -1.0358e+00, -1.4446e+00,
-        -1.1734e-01, -6.8033e-01,  5.3122e-01, -9.8540e-01, -5.4216e-01,
-        -2.2457e-01,  4.6959e-02, -1.7023e-02, -8.9690e-01, -1.1204e-01,
-        -1.8204e-01,  3.1812e-01, -2.0414e-02, -3.7023e-02,  1.4214e+00,
-        -2.4547e-01,  1.4408e-01, -6.3376e-06, -4.2121e-01, -4.6593e-01,
-         2.2624e-02,  1.0569e+00,  0.0000e+00,  7.9468e-01,  8.0103e-01,
-        -1.3462e-01,  8.3528e-01, -6.0019e-01,  6.6424e-01, -1.1410e+00,
-        -2.1555e-01,  3.0108e-01, -1.2877e-01,  9.5973e-01, -1.6441e-03,
-        -7.8448e-02, -3.3698e-01,  4.3697e-01,  9.9099e-01,  4.3861e-01,
-         3.4215e-07,  1.0697e+00,  4.8361e-05,  3.3012e-01,  5.1944e-01,
-         1.2918e-04, -1.2110e-01,  2.3327e-01,  1.6198e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2637,  0.1396,  0.1807, -0.3979,  0.3709,  0.2130,  0.0000, -0.5495,
-         0.1831, -0.3739,  0.0132, -0.1623,  0.0604, -1.0358, -1.4446, -0.1173,
-        -0.6803,  0.5312, -0.9854, -0.5422, -0.2246,  0.0470,  0.0000, -0.8969,
-        -0.1120, -0.1820,  0.3181, -0.0204, -0.0370,  1.4214, -0.2455,  0.1441,
-         0.0000, -0.4212, -0.4659,  0.0226,  1.0569,  0.0000,  0.7947,  0.8010,
-        -0.1346,  0.8353, -0.6002,  0.6642, -1.1410, -0.2155,  0.3011, -0.1288,
-         0.9597,  0.0000, -0.0784, -0.3370,  0.4370,  0.9910,  0.4386,  0.0000,
-         1.0697,  0.0000,  0.3301,  0.5194,  0.0000, -0.1211,  0.2333,  0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2637,  0.1396,  0.1807, -0.3979,  0.3709,  0.2130,  0.0000, -0.5495,
-         0.1831, -0.3739,  0.0132, -0.1623,  0.0604, -1.0358, -1.4446, -0.1173,
-        -0.6803,  0.5312, -0.9854, -0.5422, -0.2246,  0.0470,  0.0000, -0.8969,
-        -0.1120, -0.1820,  0.3181, -0.0204, -0.0370,  1.4214, -0.2455,  0.1441,
-         0.0000, -0.4212, -0.4659,  0.0226,  1.0569,  0.0000,  0.7947,  0.8010,
-        -0.1346,  0.8353, -0.6002,  0.6642, -1.1410, -0.2155,  0.3011, -0.1288,
-         0.9597,  0.0000, -0.0784, -0.3370,  0.4370,  0.9910,  0.4386,  0.0000,
-         1.0697,  0.0000,  0.3301,  0.5194,  0.0000, -0.1211,  0.2333,  0.1620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.0751e-01,  1.5776e-01,  1.4082e-01, -3.2038e-01,  3.0326e-01,
-         1.8234e-01,  2.2468e-05, -5.2783e-01,  2.5449e-01, -3.0431e-01,
-         1.5517e-01, -2.5459e-01,  1.3105e-02, -1.0532e+00, -1.4452e+00,
-        -4.1518e-02, -6.7700e-01,  5.2327e-01, -9.8569e-01, -5.5120e-01,
-        -1.8370e-01,  1.5838e-01, -1.4468e-02, -9.2182e-01,  6.7250e-02,
-        -1.3455e-01,  2.8411e-01,  6.4854e-02, -6.5996e-03,  1.4225e+00,
-        -2.0421e-01,  1.8590e-01, -5.3864e-06, -3.2187e-01, -4.4783e-01,
-         2.3858e-02,  1.0340e+00,  0.0000e+00,  7.3643e-01,  8.0524e-01,
-         4.3743e-02,  8.3487e-01, -6.4782e-01,  6.6160e-01, -1.1440e+00,
-        -2.9977e-01,  3.1974e-01, -6.1461e-02,  9.5476e-01, -1.3973e-03,
-        -4.2827e-02, -3.0855e-01,  4.7512e-01,  1.0292e+00,  4.3427e-01,
-         2.9080e-07,  1.0655e+00,  4.1103e-05,  3.1277e-01,  6.6393e-01,
-         1.0980e-04, -1.4363e-01,  2.3713e-01,  1.6046e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2075,  0.1578,  0.1408, -0.3204,  0.3033,  0.1823,  0.0000, -0.5278,
-         0.2545, -0.3043,  0.1552, -0.2546,  0.0131, -1.0532, -1.4452, -0.0415,
-        -0.6770,  0.5233, -0.9857, -0.5512, -0.1837,  0.1584,  0.0000, -0.9218,
-         0.0673, -0.1345,  0.2841,  0.0649, -0.0066,  1.4225, -0.2042,  0.1859,
-         0.0000, -0.3219, -0.4478,  0.0239,  1.0340,  0.0000,  0.7364,  0.8052,
-         0.0437,  0.8349, -0.6478,  0.6616, -1.1440, -0.2998,  0.3197, -0.0615,
-         0.9548,  0.0000, -0.0428, -0.3085,  0.4751,  1.0292,  0.4343,  0.0000,
-         1.0655,  0.0000,  0.3128,  0.6639,  0.0000, -0.1436,  0.2371,  0.1605],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2075,  0.1578,  0.1408, -0.3204,  0.3033,  0.1823,  0.0000, -0.5278,
-         0.2545, -0.3043,  0.1552, -0.2546,  0.0131, -1.0532, -1.4452, -0.0415,
-        -0.6770,  0.5233, -0.9857, -0.5512, -0.1837,  0.1584,  0.0000, -0.9218,
-         0.0673, -0.1345,  0.2841,  0.0649, -0.0066,  1.4225, -0.2042,  0.1859,
-         0.0000, -0.3219, -0.4478,  0.0239,  1.0340,  0.0000,  0.7364,  0.8052,
-         0.0437,  0.8349, -0.6478,  0.6616, -1.1440, -0.2998,  0.3197, -0.0615,
-         0.9548,  0.0000, -0.0428, -0.3085,  0.4751,  1.0292,  0.4343,  0.0000,
-         1.0655,  0.0000,  0.3128,  0.6639,  0.0000, -0.1436,  0.2371,  0.1605],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4978e-01,  1.8740e-01,  7.5682e-03, -3.0770e-01,  3.4717e-01,
-         1.9704e-01,  1.9095e-05, -5.2046e-01,  3.0162e-01, -2.3308e-01,
-         1.7117e-01, -3.4409e-01, -1.5220e-01, -1.0243e+00, -1.4433e+00,
-        -1.1345e-01, -6.7411e-01,  4.9056e-01, -9.8131e-01, -5.8536e-01,
-        -2.0477e-01,  2.4996e-01, -1.2297e-02, -9.4856e-01,  1.8053e-01,
-        -3.3959e-02,  2.5085e-01,  1.2361e-01,  3.2509e-02,  1.4306e+00,
-        -1.1625e-01,  1.6701e-01, -4.5779e-06, -2.6260e-01, -4.5310e-01,
-        -4.5735e-02,  1.0152e+00,  0.0000e+00,  6.5909e-01,  7.7738e-01,
-         1.3558e-01,  8.3990e-01, -6.9321e-01,  6.5666e-01, -1.1399e+00,
-        -3.3061e-01,  3.5541e-01,  5.2146e-03,  9.6493e-01, -1.1876e-03,
-        -5.0998e-02, -3.0815e-01,  4.9805e-01,  1.0572e+00,  3.6031e-01,
-         2.4715e-07,  1.0605e+00,  3.4933e-05,  3.0198e-01,  7.9845e-01,
-         9.3315e-05, -1.8349e-01,  2.0132e-01,  7.9328e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1498,  0.1874,  0.0076, -0.3077,  0.3472,  0.1970,  0.0000, -0.5205,
-         0.3016, -0.2331,  0.1712, -0.3441, -0.1522, -1.0243, -1.4433, -0.1134,
-        -0.6741,  0.4906, -0.9813, -0.5854, -0.2048,  0.2500,  0.0000, -0.9486,
-         0.1805, -0.0340,  0.2509,  0.1236,  0.0325,  1.4306, -0.1162,  0.1670,
-         0.0000, -0.2626, -0.4531, -0.0457,  1.0152,  0.0000,  0.6591,  0.7774,
-         0.1356,  0.8399, -0.6932,  0.6567, -1.1399, -0.3306,  0.3554,  0.0052,
-         0.9649,  0.0000, -0.0510, -0.3082,  0.4980,  1.0572,  0.3603,  0.0000,
-         1.0605,  0.0000,  0.3020,  0.7984,  0.0000, -0.1835,  0.2013,  0.0793],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1498,  0.1874,  0.0076, -0.3077,  0.3472,  0.1970,  0.0000, -0.5205,
-         0.3016, -0.2331,  0.1712, -0.3441, -0.1522, -1.0243, -1.4433, -0.1134,
-        -0.6741,  0.4906, -0.9813, -0.5854, -0.2048,  0.2500,  0.0000, -0.9486,
-         0.1805, -0.0340,  0.2509,  0.1236,  0.0325,  1.4306, -0.1162,  0.1670,
-         0.0000, -0.2626, -0.4531, -0.0457,  1.0152,  0.0000,  0.6591,  0.7774,
-         0.1356,  0.8399, -0.6932,  0.6567, -1.1399, -0.3306,  0.3554,  0.0052,
-         0.9649,  0.0000, -0.0510, -0.3082,  0.4980,  1.0572,  0.3603,  0.0000,
-         1.0605,  0.0000,  0.3020,  0.7984,  0.0000, -0.1835,  0.2013,  0.0793],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.5810e-01,  1.7891e-01, -4.3123e-02, -2.5934e-01,  3.1817e-01,
-         2.1187e-01,  1.6229e-05, -5.2316e-01,  3.1411e-01, -1.9375e-01,
-         5.0494e-02, -3.7392e-01, -1.9656e-01, -1.0103e+00, -1.4390e+00,
-        -3.2594e-01, -6.9573e-01,  4.4236e-01, -9.7929e-01, -6.1097e-01,
-        -2.6155e-01,  1.9261e-01, -1.0451e-02, -9.8562e-01,  1.6377e-01,
-         1.8883e-01,  2.4317e-01,  9.5977e-02, -3.6470e-02,  1.4446e+00,
-        -4.3615e-02,  8.9426e-02, -3.8907e-06, -2.9347e-01, -4.8162e-01,
-        -1.1294e-01,  1.0058e+00,  0.0000e+00,  6.2938e-01,  7.7551e-01,
-         1.1438e-01,  8.2889e-01, -6.9170e-01,  6.7252e-01, -1.1390e+00,
-        -3.1564e-01,  3.7512e-01,  9.2685e-02,  9.7982e-01, -1.0093e-03,
-        -8.5232e-02, -4.0931e-01,  5.1520e-01,  1.0725e+00,  3.5285e-01,
-         2.1005e-07,  1.0562e+00,  2.9690e-05,  3.0220e-01,  8.9759e-01,
-         7.9308e-05, -1.7013e-01,  1.0625e-01,  7.2991e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1581,  0.1789, -0.0431, -0.2593,  0.3182,  0.2119,  0.0000, -0.5232,
-         0.3141, -0.1937,  0.0505, -0.3739, -0.1966, -1.0103, -1.4390, -0.3259,
-        -0.6957,  0.4424, -0.9793, -0.6110, -0.2615,  0.1926,  0.0000, -0.9856,
-         0.1638,  0.1888,  0.2432,  0.0960, -0.0365,  1.4446, -0.0436,  0.0894,
-         0.0000, -0.2935, -0.4816, -0.1129,  1.0058,  0.0000,  0.6294,  0.7755,
-         0.1144,  0.8289, -0.6917,  0.0000, -1.1390, -0.3156,  0.3751,  0.0927,
-         0.9798,  0.0000, -0.0852, -0.4093,  0.5152,  1.0725,  0.3528,  0.0000,
-         1.0562,  0.0000,  0.3022,  0.8976,  0.0000, -0.1701,  0.1062,  0.0730],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1581,  0.1789, -0.0431, -0.2593,  0.3182,  0.2119,  0.0000, -0.5232,
-         0.3141, -0.1937,  0.0505, -0.3739, -0.1966, -1.0103, -1.4390, -0.3259,
-        -0.6957,  0.4424, -0.9793, -0.6110, -0.2615,  0.1926,  0.0000, -0.9856,
-         0.1638,  0.1888,  0.2432,  0.0960, -0.0365,  1.4446, -0.0436,  0.0894,
-         0.0000, -0.2935, -0.4816, -0.1129,  1.0058,  0.0000,  0.6294,  0.7755,
-         0.1144,  0.8289, -0.6917,  0.0000, -1.1390, -0.3156,  0.3751,  0.0927,
-         0.9798,  0.0000, -0.0852, -0.4093,  0.5152,  1.0725,  0.3528,  0.0000,
-         1.0562,  0.0000,  0.3022,  0.8976,  0.0000, -0.1701,  0.1062,  0.0730],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3907e-01,  1.0897e-01, -2.4078e-03, -2.1051e-01,  2.7389e-01,
-         2.8915e-01,  1.3793e-05, -5.0355e-01,  3.0349e-01, -1.8780e-01,
-        -4.9766e-02, -3.2657e-01, -1.3447e-01, -9.9413e-01, -1.4344e+00,
-        -5.1470e-01, -7.2613e-01,  4.0641e-01, -9.7538e-01, -6.1477e-01,
-        -2.9949e-01,  8.1025e-02, -8.8820e-03, -1.0141e+00,  9.6718e-02,
-         3.8095e-01,  2.7685e-01,  1.1766e-01, -1.2483e-01,  1.4497e+00,
-        -3.9679e-02,  1.2039e-01, -3.3066e-06, -3.3132e-01, -4.9620e-01,
-        -1.2674e-01,  9.9396e-01,  0.0000e+00,  5.9157e-01,  7.7126e-01,
-         7.7502e-02,  8.0705e-01, -6.6338e-01,  1.3474e-02, -1.1341e+00,
-        -2.8312e-01,  3.4949e-01,  1.9550e-01,  9.7938e-01, -8.5779e-04,
-        -1.1370e-01, -4.8413e-01,  4.8924e-01,  1.0884e+00,  3.3630e-01,
-         1.7852e-07,  1.0544e+00,  2.5233e-05,  2.8714e-01,  9.7866e-01,
-         6.7402e-05, -1.6686e-01,  1.9690e-02,  1.5694e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1391,  0.1090, -0.0024, -0.2105,  0.2739,  0.2892,  0.0000, -0.5035,
-         0.3035, -0.1878, -0.0498, -0.3266, -0.1345, -0.9941, -1.4344, -0.5147,
-        -0.7261,  0.4064, -0.9754, -0.6148, -0.2995,  0.0810,  0.0000, -1.0141,
-         0.0967,  0.3810,  0.2769,  0.1177, -0.1248,  1.4497, -0.0397,  0.1204,
-         0.0000, -0.3313, -0.4962, -0.1267,  0.9940,  0.0000,  0.5916,  0.7713,
-         0.0775,  0.8071, -0.6634,  0.0000, -1.1341, -0.2831,  0.3495,  0.1955,
-         0.9794,  0.0000, -0.1137, -0.4841,  0.4892,  1.0884,  0.3363,  0.0000,
-         1.0544,  0.0000,  0.2871,  0.9787,  0.0000, -0.1669,  0.0197,  0.1569],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1391,  0.1090, -0.0024, -0.2105,  0.2739,  0.2892,  0.0000, -0.5035,
-         0.3035, -0.1878, -0.0498, -0.3266, -0.1345, -0.9941, -1.4344, -0.5147,
-        -0.7261,  0.4064, -0.9754, -0.6148, -0.2995,  0.0810,  0.0000, -1.0141,
-         0.0967,  0.3810,  0.2769,  0.1177, -0.1248,  1.4497, -0.0397,  0.1204,
-         0.0000, -0.3313, -0.4962, -0.1267,  0.9940,  0.0000,  0.5916,  0.7713,
-         0.0775,  0.8071, -0.6634,  0.0000, -1.1341, -0.2831,  0.3495,  0.1955,
-         0.9794,  0.0000, -0.1137, -0.4841,  0.4892,  1.0884,  0.3363,  0.0000,
-         1.0544,  0.0000,  0.2871,  0.9787,  0.0000, -0.1669,  0.0197,  0.1569],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4294e-01,  5.7035e-02,  3.6000e-02, -1.6203e-01,  2.0263e-01,
-         3.6383e-01,  1.1722e-05, -4.8752e-01,  2.8077e-01, -1.7213e-01,
-        -1.4671e-01, -2.8663e-01, -6.4207e-02, -9.8960e-01, -1.4323e+00,
-        -6.7548e-01, -7.4609e-01,  3.5597e-01, -9.7366e-01, -6.0121e-01,
-        -3.2768e-01, -2.4364e-02, -7.5487e-03, -1.0263e+00,  3.5703e-02,
-         5.4627e-01,  2.5879e-01,  1.0493e-01, -1.7433e-01,  1.4490e+00,
-        -2.1982e-02,  6.5968e-02, -2.8103e-06, -3.2079e-01, -4.8514e-01,
-        -1.5897e-01,  9.8545e-01,  0.0000e+00,  5.6668e-01,  7.7991e-01,
-         2.6700e-02,  7.9484e-01, -6.5106e-01,  1.1451e-02, -1.1302e+00,
-        -3.0412e-01,  3.3777e-01,  2.8504e-01,  9.7216e-01, -7.2902e-04,
-        -1.4906e-01, -4.6675e-01,  4.9053e-01,  1.0953e+00,  3.3842e-01,
-         1.5172e-07,  1.0527e+00,  2.1445e-05,  2.8793e-01,  1.0501e+00,
-         5.7284e-05, -1.8349e-01, -6.0638e-02,  1.8071e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1429,  0.0570,  0.0360, -0.1620,  0.2026,  0.3638,  0.0000, -0.4875,
-         0.2808, -0.1721, -0.1467, -0.2866, -0.0642, -0.9896, -1.4323, -0.6755,
-        -0.7461,  0.3560, -0.9737, -0.6012, -0.3277, -0.0244,  0.0000, -1.0263,
-         0.0357,  0.5463,  0.2588,  0.1049, -0.1743,  1.4490, -0.0220,  0.0660,
-         0.0000, -0.3208, -0.4851, -0.1590,  0.9854,  0.0000,  0.5667,  0.7799,
-         0.0267,  0.7948, -0.6511,  0.0000, -1.1302, -0.3041,  0.3378,  0.2850,
-         0.9722,  0.0000, -0.1491, -0.4668,  0.4905,  1.0953,  0.3384,  0.0000,
-         1.0527,  0.0000,  0.2879,  1.0501,  0.0000, -0.1835, -0.0606,  0.1807],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1429,  0.0570,  0.0360, -0.1620,  0.2026,  0.3638,  0.0000, -0.4875,
-         0.2808, -0.1721, -0.1467, -0.2866, -0.0642, -0.9896, -1.4323, -0.6755,
-        -0.7461,  0.3560, -0.9737, -0.6012, -0.3277, -0.0244,  0.0000, -1.0263,
-         0.0357,  0.5463,  0.2588,  0.1049, -0.1743,  1.4490, -0.0220,  0.0660,
-         0.0000, -0.3208, -0.4851, -0.1590,  0.9854,  0.0000,  0.5667,  0.7799,
-         0.0267,  0.7948, -0.6511,  0.0000, -1.1302, -0.3041,  0.3378,  0.2850,
-         0.9722,  0.0000, -0.1491, -0.4668,  0.4905,  1.0953,  0.3384,  0.0000,
-         1.0527,  0.0000,  0.2879,  1.0501,  0.0000, -0.1835, -0.0606,  0.1807],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3240e-01,  8.7308e-02, -2.7938e-02, -1.3434e-01,  5.9964e-02,
-         2.6087e-01,  9.9626e-06, -4.7517e-01,  1.6981e-01, -1.3111e-01,
-        -1.8134e-01, -2.7601e-01, -5.3990e-02, -9.9595e-01, -1.4359e+00,
-        -7.9622e-01, -7.5244e-01,  3.1461e-01, -9.7651e-01, -5.9997e-01,
-        -3.4420e-01, -6.1874e-02, -6.4155e-03, -1.0214e+00,  4.3662e-03,
-         6.8517e-01,  4.5901e-02,  2.7926e-02, -1.3818e-01,  1.4562e+00,
-         3.8682e-02, -2.3016e-01, -2.3884e-06, -2.7575e-01, -4.7148e-01,
-        -2.0629e-01,  1.0038e+00,  0.0000e+00,  5.4027e-01,  8.0768e-01,
-         1.4057e-02,  7.7089e-01, -6.2107e-01,  9.7324e-03, -1.1329e+00,
-        -3.2476e-01,  3.6568e-01,  3.0654e-01,  9.7098e-01, -6.1959e-04,
-        -1.5225e-01, -3.8552e-01,  5.1552e-01,  1.0618e+00,  4.4740e-01,
-         1.2895e-07,  1.0508e+00,  1.8226e-05,  3.3540e-01,  1.0993e+00,
-         4.8685e-05, -1.9179e-01, -7.7134e-02, -8.3810e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2324,  0.0873, -0.0279, -0.1343,  0.0600,  0.2609,  0.0000, -0.4752,
-         0.1698, -0.1311, -0.1813, -0.2760, -0.0540, -0.9959, -1.4359, -0.7962,
-        -0.7524,  0.3146, -0.9765, -0.6000, -0.3442, -0.0619,  0.0000, -1.0214,
-         0.0044,  0.6852,  0.0459,  0.0279, -0.1382,  1.4562,  0.0387, -0.2302,
-         0.0000, -0.2758, -0.4715, -0.2063,  1.0038,  0.0000,  0.5403,  0.8077,
-         0.0141,  0.7709, -0.6211,  0.0000, -1.1329, -0.3248,  0.3657,  0.3065,
-         0.9710,  0.0000, -0.1522, -0.3855,  0.5155,  1.0618,  0.4474,  0.0000,
-         1.0508,  0.0000,  0.3354,  1.0993,  0.0000, -0.1918, -0.0771, -0.0838],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2324,  0.0873, -0.0279, -0.1343,  0.0600,  0.2609,  0.0000, -0.4752,
-         0.1698, -0.1311, -0.1813, -0.2760, -0.0540, -0.9959, -1.4359, -0.7962,
-        -0.7524,  0.3146, -0.9765, -0.6000, -0.3442, -0.0619,  0.0000, -1.0214,
-         0.0044,  0.6852,  0.0459,  0.0279, -0.1382,  1.4562,  0.0387, -0.2302,
-         0.0000, -0.2758, -0.4715, -0.2063,  1.0038,  0.0000,  0.5403,  0.8077,
-         0.0141,  0.7709, -0.6211,  0.0000, -1.1329, -0.3248,  0.3657,  0.3065,
-         0.9710,  0.0000, -0.1522, -0.3855,  0.5155,  1.0618,  0.4474,  0.0000,
-         1.0508,  0.0000,  0.3354,  1.0993,  0.0000, -0.1918, -0.0771, -0.0838],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7245e-01,  1.4616e-01, -4.8475e-02, -5.8708e-02,  3.2515e-02,
-         2.3619e-01,  8.4671e-06, -4.5635e-01,  8.2299e-02, -1.6468e-01,
-        -1.3310e-01, -2.5594e-01, -2.0201e-02, -9.9734e-01, -1.4334e+00,
-        -8.8470e-01, -7.6120e-01,  2.3122e-01, -9.8036e-01, -5.8429e-01,
-        -3.1775e-01, -1.0592e-01, -5.4525e-03, -1.0157e+00, -5.7045e-03,
-         7.7844e-01, -1.2500e-01, -2.5594e-02, -9.9459e-02,  1.4618e+00,
-         1.1155e-01, -4.6705e-01, -2.0299e-06, -2.5208e-01, -4.4062e-01,
-        -2.3959e-01,  1.0081e+00,  0.0000e+00,  4.7566e-01,  8.2243e-01,
-        -1.2147e-04,  7.4697e-01, -5.9488e-01,  8.2715e-03, -1.1340e+00,
-        -3.6533e-01,  4.0094e-01,  3.3818e-01,  9.8065e-01, -5.2658e-04,
-        -1.2866e-01, -2.8538e-01,  5.3223e-01,  1.0291e+00,  5.2272e-01,
-         1.0959e-07,  1.0486e+00,  1.5490e-05,  4.0599e-01,  1.1393e+00,
-         4.1377e-05, -2.1226e-01, -6.8684e-02, -2.7417e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.7245e-01,  1.4616e-01, -4.8475e-02, -5.8708e-02,  3.2515e-02,
-         2.3619e-01,  0.0000e+00, -4.5635e-01,  8.2299e-02, -1.6468e-01,
-        -1.3310e-01, -2.5594e-01, -2.0201e-02, -9.9734e-01, -1.4334e+00,
-        -8.8470e-01, -7.6120e-01,  2.3122e-01, -9.8036e-01, -5.8429e-01,
-        -3.1775e-01, -1.0592e-01,  0.0000e+00, -1.0157e+00, -5.7045e-03,
-         7.7844e-01, -1.2500e-01, -2.5594e-02, -9.9459e-02,  1.4618e+00,
-         1.1155e-01, -4.6705e-01,  0.0000e+00, -2.5208e-01, -4.4062e-01,
-        -2.3959e-01,  1.0081e+00,  0.0000e+00,  4.7566e-01,  8.2243e-01,
-        -1.2147e-04,  7.4697e-01, -5.9488e-01,  0.0000e+00, -1.1340e+00,
-        -3.6533e-01,  4.0094e-01,  3.3818e-01,  9.8065e-01,  0.0000e+00,
-        -1.2866e-01, -2.8538e-01,  5.3223e-01,  1.0291e+00,  5.2272e-01,
-         0.0000e+00,  1.0486e+00,  0.0000e+00,  4.0599e-01,  1.1393e+00,
-         0.0000e+00, -2.1226e-01, -6.8684e-02, -2.7417e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.7245e-01,  1.4616e-01, -4.8475e-02, -5.8708e-02,  3.2515e-02,
-         2.3619e-01,  0.0000e+00, -4.5635e-01,  8.2299e-02, -1.6468e-01,
-        -1.3310e-01, -2.5594e-01, -2.0201e-02, -9.9734e-01, -1.4334e+00,
-        -8.8470e-01, -7.6120e-01,  2.3122e-01, -9.8036e-01, -5.8429e-01,
-        -3.1775e-01, -1.0592e-01,  0.0000e+00, -1.0157e+00, -5.7045e-03,
-         7.7844e-01, -1.2500e-01, -2.5594e-02, -9.9459e-02,  1.4618e+00,
-         1.1155e-01, -4.6705e-01,  0.0000e+00, -2.5208e-01, -4.4062e-01,
-        -2.3959e-01,  1.0081e+00,  0.0000e+00,  4.7566e-01,  8.2243e-01,
-        -1.2147e-04,  7.4697e-01, -5.9488e-01,  0.0000e+00, -1.1340e+00,
-        -3.6533e-01,  4.0094e-01,  3.3818e-01,  9.8065e-01,  0.0000e+00,
-        -1.2866e-01, -2.8538e-01,  5.3223e-01,  1.0291e+00,  5.2272e-01,
-         0.0000e+00,  1.0486e+00,  0.0000e+00,  4.0599e-01,  1.1393e+00,
-         0.0000e+00, -2.1226e-01, -6.8684e-02, -2.7417e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1723e-01,  1.8953e-01, -4.1519e-02,  6.6128e-02, -3.8059e-02,
-         2.2362e-01,  7.1962e-06, -4.2616e-01,  2.0236e-02, -1.5731e-01,
-        -5.0043e-02, -2.2810e-01,  6.0120e-02, -9.9804e-01, -1.4341e+00,
-        -9.4834e-01, -7.6613e-01,  1.8811e-01, -9.8254e-01, -5.7600e-01,
-        -2.7366e-01, -4.6454e-02, -4.6340e-03, -1.0066e+00,  7.4466e-02,
-         8.4078e-01, -2.1329e-01, -7.5612e-03, -8.7748e-02,  1.4672e+00,
-         1.9695e-01, -6.6891e-01, -1.7252e-06, -1.8003e-01, -4.1023e-01,
-        -2.6113e-01,  1.0044e+00,  0.0000e+00,  3.7794e-01,  8.1377e-01,
-         2.3878e-02,  7.3331e-01, -5.5837e-01,  7.0299e-03, -1.1330e+00,
-        -3.9198e-01,  4.1058e-01,  3.6122e-01,  9.8599e-01, -4.4754e-04,
-        -7.5417e-02, -1.5858e-01,  5.2204e-01,  1.0050e+00,  5.7755e-01,
-         9.3140e-08,  1.0461e+00,  1.3165e-05,  4.6594e-01,  1.1767e+00,
-         3.5166e-05, -1.5269e-01, -2.7973e-02, -4.0337e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3172,  0.1895, -0.0415,  0.0661, -0.0381,  0.2236,  0.0000, -0.4262,
-         0.0202, -0.1573, -0.0500, -0.2281,  0.0601, -0.9980, -1.4341, -0.9483,
-        -0.7661,  0.1881,  0.0000, -0.5760, -0.2737, -0.0465,  0.0000, -1.0066,
-         0.0745,  0.8408, -0.2133, -0.0076, -0.0877,  1.4672,  0.1970, -0.6689,
-         0.0000, -0.1800, -0.4102, -0.2611,  1.0044,  0.0000,  0.3779,  0.8138,
-         0.0239,  0.7333, -0.5584,  0.0000, -1.1330, -0.3920,  0.4106,  0.3612,
-         0.9860,  0.0000, -0.0754, -0.1586,  0.5220,  1.0050,  0.5775,  0.0000,
-         1.0461,  0.0000,  0.4659,  1.1767,  0.0000, -0.1527, -0.0280, -0.4034],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3172,  0.1895, -0.0415,  0.0661, -0.0381,  0.2236,  0.0000, -0.4262,
-         0.0202, -0.1573, -0.0500, -0.2281,  0.0601, -0.9980, -1.4341, -0.9483,
-        -0.7661,  0.1881,  0.0000, -0.5760, -0.2737, -0.0465,  0.0000, -1.0066,
-         0.0745,  0.8408, -0.2133, -0.0076, -0.0877,  1.4672,  0.1970, -0.6689,
-         0.0000, -0.1800, -0.4102, -0.2611,  1.0044,  0.0000,  0.3779,  0.8138,
-         0.0239,  0.7333, -0.5584,  0.0000, -1.1330, -0.3920,  0.4106,  0.3612,
-         0.9860,  0.0000, -0.0754, -0.1586,  0.5220,  1.0050,  0.5775,  0.0000,
-         1.0461,  0.0000,  0.4659,  1.1767,  0.0000, -0.1527, -0.0280, -0.4034],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7397e-01,  1.8845e-01, -6.0694e-03,  2.2073e-01, -9.1157e-02,
-         2.3191e-01,  6.1161e-06, -4.1566e-01, -2.2269e-02, -1.6865e-01,
-         2.3742e-02, -1.9176e-01,  1.7879e-01, -9.9931e-01, -1.4416e+00,
-        -1.0074e+00, -7.7152e-01,  1.8252e-01, -1.8546e-03, -5.6476e-01,
-        -2.3775e-01, -1.1553e-03, -3.9385e-03, -9.9364e-01,  1.3416e-01,
-         8.8688e-01, -2.8534e-01,  3.3260e-02, -6.8379e-02,  1.4686e+00,
-         2.3371e-01, -8.3692e-01, -1.4663e-06, -1.1921e-01, -3.8358e-01,
-        -2.8047e-01,  1.0040e+00,  0.0000e+00,  2.7674e-01,  8.0635e-01,
-         6.2674e-02,  7.1552e-01, -5.3071e-01,  5.9748e-03, -1.1290e+00,
-        -3.8029e-01,  4.0189e-01,  3.7826e-01,  9.6939e-01, -3.8037e-04,
-        -3.0626e-02, -4.4263e-02,  4.8054e-01,  9.8598e-01,  6.3348e-01,
-         7.9161e-08,  1.0437e+00,  1.1189e-05,  4.8936e-01,  1.2055e+00,
-         2.9888e-05, -8.3759e-02,  1.8840e-02, -5.0062e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.7397e-01,  1.8845e-01, -6.0694e-03,  2.2073e-01, -9.1157e-02,
-         2.3191e-01,  0.0000e+00, -4.1566e-01, -2.2269e-02, -1.6865e-01,
-         2.3742e-02, -1.9176e-01,  1.7879e-01, -9.9931e-01, -1.4416e+00,
-        -1.0074e+00, -7.7152e-01,  1.8252e-01,  0.0000e+00, -5.6476e-01,
-        -2.3775e-01, -1.1553e-03,  0.0000e+00, -9.9364e-01,  1.3416e-01,
-         8.8688e-01, -2.8534e-01,  3.3260e-02, -6.8379e-02,  1.4686e+00,
-         2.3371e-01, -8.3692e-01,  0.0000e+00, -1.1921e-01, -3.8358e-01,
-        -2.8047e-01,  1.0040e+00,  0.0000e+00,  2.7674e-01,  8.0635e-01,
-         6.2674e-02,  7.1552e-01, -5.3071e-01,  0.0000e+00, -1.1290e+00,
-        -3.8029e-01,  4.0189e-01,  3.7826e-01,  9.6939e-01,  0.0000e+00,
-        -3.0626e-02, -4.4263e-02,  4.8054e-01,  9.8598e-01,  6.3348e-01,
-         0.0000e+00,  1.0437e+00,  0.0000e+00,  4.8936e-01,  1.2055e+00,
-         0.0000e+00, -8.3759e-02,  1.8840e-02, -5.0062e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.7397e-01,  1.8845e-01, -6.0694e-03,  2.2073e-01, -9.1157e-02,
-         2.3191e-01,  0.0000e+00, -4.1566e-01, -2.2269e-02, -1.6865e-01,
-         2.3742e-02, -1.9176e-01,  1.7879e-01, -9.9931e-01, -1.4416e+00,
-        -1.0074e+00, -7.7152e-01,  1.8252e-01,  0.0000e+00, -5.6476e-01,
-        -2.3775e-01, -1.1553e-03,  0.0000e+00, -9.9364e-01,  1.3416e-01,
-         8.8688e-01, -2.8534e-01,  3.3260e-02, -6.8379e-02,  1.4686e+00,
-         2.3371e-01, -8.3692e-01,  0.0000e+00, -1.1921e-01, -3.8358e-01,
-        -2.8047e-01,  1.0040e+00,  0.0000e+00,  2.7674e-01,  8.0635e-01,
-         6.2674e-02,  7.1552e-01, -5.3071e-01,  0.0000e+00, -1.1290e+00,
-        -3.8029e-01,  4.0189e-01,  3.7826e-01,  9.6939e-01,  0.0000e+00,
-        -3.0626e-02, -4.4263e-02,  4.8054e-01,  9.8598e-01,  6.3348e-01,
-         0.0000e+00,  1.0437e+00,  0.0000e+00,  4.8936e-01,  1.2055e+00,
-         0.0000e+00, -8.3759e-02,  1.8840e-02, -5.0062e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0766e-01,  1.8012e-01,  4.8314e-03,  3.2518e-01, -3.5104e-02,
-         2.2786e-01,  5.1983e-06, -4.2209e-01, -8.0147e-02, -2.2458e-01,
-         7.4818e-02, -1.7145e-01,  2.3669e-01, -1.0024e+00, -1.4472e+00,
-        -1.0546e+00, -7.6293e-01,  1.3986e-01, -1.5763e-03, -5.4820e-01,
-        -1.8418e-01,  3.4674e-02, -3.3475e-03, -9.7653e-01,  1.8591e-01,
-         9.3103e-01, -3.5723e-01,  5.3907e-02,  3.0640e-03,  1.4679e+00,
-         2.2266e-01, -9.7920e-01, -1.2462e-06, -6.4136e-02, -3.3288e-01,
-        -2.9960e-01,  1.0235e+00,  0.0000e+00,  2.0724e-01,  8.2444e-01,
-         6.9693e-02,  7.0790e-01, -5.2680e-01,  5.0782e-03, -1.1269e+00,
-        -3.7162e-01,  4.0907e-01,  3.9645e-01,  9.4605e-01, -3.2329e-04,
-        -8.4121e-03,  5.1231e-02,  4.3888e-01,  9.5349e-01,  6.9666e-01,
-         6.7281e-08,  1.0416e+00,  9.5098e-06,  5.2024e-01,  1.2266e+00,
-         2.5403e-05, -8.5101e-02,  5.7277e-02, -5.6608e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4077,  0.1801,  0.0048,  0.3252, -0.0351,  0.2279,  0.0000, -0.4221,
-        -0.0801, -0.2246,  0.0748, -0.1715,  0.2367, -1.0024, -1.4472, -1.0546,
-        -0.7629,  0.1399,  0.0000, -0.5482, -0.1842,  0.0347,  0.0000, -0.9765,
-         0.1859,  0.9310, -0.3572,  0.0539,  0.0031,  1.4679,  0.2227, -0.9792,
-         0.0000, -0.0641, -0.3329, -0.2996,  1.0235,  0.0000,  0.2072,  0.8244,
-         0.0697,  0.7079, -0.5268,  0.0000, -1.1269, -0.3716,  0.4091,  0.3965,
-         0.9460,  0.0000, -0.0084,  0.0512,  0.4389,  0.9535,  0.6967,  0.0000,
-         1.0416,  0.0000,  0.5202,  1.2266,  0.0000, -0.0851,  0.0573, -0.5661],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4077,  0.1801,  0.0048,  0.3252, -0.0351,  0.2279,  0.0000, -0.4221,
-        -0.0801, -0.2246,  0.0748, -0.1715,  0.2367, -1.0024, -1.4472, -1.0546,
-        -0.7629,  0.1399,  0.0000, -0.5482, -0.1842,  0.0347,  0.0000, -0.9765,
-         0.1859,  0.9310, -0.3572,  0.0539,  0.0031,  1.4679,  0.2227, -0.9792,
-         0.0000, -0.0641, -0.3329, -0.2996,  1.0235,  0.0000,  0.2072,  0.8244,
-         0.0697,  0.7079, -0.5268,  0.0000, -1.1269, -0.3716,  0.4091,  0.3965,
-         0.9460,  0.0000, -0.0084,  0.0512,  0.4389,  0.9535,  0.6967,  0.0000,
-         1.0416,  0.0000,  0.5202,  1.2266,  0.0000, -0.0851,  0.0573, -0.5661],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.2423e-01,  1.5502e-01,  5.2634e-02,  4.5070e-01, -6.2584e-02,
-         2.2690e-01,  4.4183e-06, -4.1950e-01, -1.1429e-01, -2.7776e-01,
-         6.6411e-02, -1.4908e-01,  3.3049e-01, -9.9083e-01, -1.4555e+00,
-        -1.0897e+00, -7.5201e-01,  1.0624e-01, -1.3398e-03, -5.3843e-01,
-        -1.3009e-01,  5.5861e-03, -2.8452e-03, -9.4983e-01,  1.8258e-01,
-         9.7111e-01, -3.9392e-01,  8.0380e-02,  5.9078e-03,  1.4665e+00,
-         2.0823e-01, -1.0944e+00, -1.0592e-06, -2.5834e-02, -2.9611e-01,
-        -3.1231e-01,  1.0370e+00,  0.0000e+00,  1.4405e-01,  8.1867e-01,
-         6.9479e-02,  6.9087e-01, -5.2009e-01,  4.3162e-03, -1.1221e+00,
-        -3.9937e-01,  3.9411e-01,  3.9697e-01,  9.2194e-01, -2.7478e-04,
-        -1.5683e-02,  1.2875e-01,  4.0314e-01,  9.2555e-01,  7.3061e-01,
-         5.7186e-08,  1.0391e+00,  8.0828e-06,  5.3779e-01,  1.2377e+00,
-         2.1591e-05, -3.3102e-02,  7.2290e-02, -5.9700e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4242,  0.1550,  0.0526,  0.4507, -0.0626,  0.2269,  0.0000, -0.4195,
-        -0.1143, -0.2778,  0.0664, -0.1491,  0.3305, -0.9908, -1.4555, -1.0897,
-        -0.7520,  0.1062,  0.0000, -0.5384, -0.1301,  0.0056,  0.0000, -0.9498,
-         0.1826,  0.9711, -0.3939,  0.0804,  0.0059,  1.4665,  0.2082, -1.0944,
-         0.0000, -0.0258, -0.2961, -0.3123,  1.0370,  0.0000,  0.1440,  0.8187,
-         0.0695,  0.6909, -0.5201,  0.0000, -1.1221, -0.3994,  0.3941,  0.3970,
-         0.9219,  0.0000, -0.0157,  0.1288,  0.4031,  0.9256,  0.7306,  0.0000,
-         1.0391,  0.0000,  0.5378,  1.2377,  0.0000, -0.0331,  0.0723, -0.5970],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4242,  0.1550,  0.0526,  0.4507, -0.0626,  0.2269,  0.0000, -0.4195,
-        -0.1143, -0.2778,  0.0664, -0.1491,  0.3305, -0.9908, -1.4555, -1.0897,
-        -0.7520,  0.1062,  0.0000, -0.5384, -0.1301,  0.0056,  0.0000, -0.9498,
-         0.1826,  0.9711, -0.3939,  0.0804,  0.0059,  1.4665,  0.2082, -1.0944,
-         0.0000, -0.0258, -0.2961, -0.3123,  1.0370,  0.0000,  0.1440,  0.8187,
-         0.0695,  0.6909, -0.5201,  0.0000, -1.1221, -0.3994,  0.3941,  0.3970,
-         0.9219,  0.0000, -0.0157,  0.1288,  0.4031,  0.9256,  0.7306,  0.0000,
-         1.0391,  0.0000,  0.5378,  1.2377,  0.0000, -0.0331,  0.0723, -0.5970],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7248e-01,  1.7891e-01, -8.4729e-02,  4.3486e-01, -1.1811e-01,
-         2.6720e-01,  3.7554e-06, -4.3366e-01, -1.3607e-01, -2.8671e-01,
-        -2.9530e-02, -1.8005e-01,  2.4557e-01, -9.8832e-01, -1.4608e+00,
-        -1.1149e+00, -7.2996e-01, -7.9751e-03, -1.1388e-03, -5.5468e-01,
-        -1.1021e-01, -4.9983e-02, -2.4183e-03, -9.2346e-01,  1.1904e-01,
-         1.0111e+00, -4.2528e-01, -5.5788e-02,  4.3783e-02,  1.4679e+00,
-         1.7003e-01, -1.1849e+00, -9.0031e-07,  1.0205e-02, -2.8241e-01,
-        -3.4664e-01,  1.0456e+00,  0.0000e+00,  2.3582e-01,  8.1166e-01,
-        -6.0303e-03,  6.5958e-01, -4.7602e-01,  3.6686e-03, -1.1173e+00,
-        -3.8422e-01,  3.7250e-01,  4.9312e-01,  9.3944e-01, -2.3355e-04,
-        -2.6169e-02,  7.4678e-02,  3.8696e-01,  8.6395e-01,  7.3319e-01,
-         4.8606e-08,  1.0360e+00,  6.8702e-06,  5.7537e-01,  1.2453e+00,
-         1.8352e-05, -4.8375e-02,  1.1372e-01, -6.2985e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3725,  0.1789, -0.0847,  0.4349, -0.1181,  0.2672,  0.0000, -0.4337,
-        -0.1361, -0.2867, -0.0295, -0.1801,  0.2456, -0.9883, -1.4608, -1.1149,
-        -0.7300, -0.0080,  0.0000, -0.5547, -0.1102, -0.0500,  0.0000, -0.9235,
-         0.1190,  1.0111, -0.4253, -0.0558,  0.0438,  1.4679,  0.1700, -1.1849,
-         0.0000,  0.0102, -0.2824, -0.3466,  1.0456,  0.0000,  0.2358,  0.8117,
-        -0.0060,  0.6596, -0.4760,  0.0000, -1.1173, -0.3842,  0.3725,  0.4931,
-         0.9394,  0.0000, -0.0262,  0.0747,  0.3870,  0.8639,  0.7332,  0.0000,
-         1.0360,  0.0000,  0.5754,  1.2453,  0.0000, -0.0484,  0.1137, -0.6298],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3725,  0.1789, -0.0847,  0.4349, -0.1181,  0.2672,  0.0000, -0.4337,
-        -0.1361, -0.2867, -0.0295, -0.1801,  0.2456, -0.9883, -1.4608, -1.1149,
-        -0.7300, -0.0080,  0.0000, -0.5547, -0.1102, -0.0500,  0.0000, -0.9235,
-         0.1190,  1.0111, -0.4253, -0.0558,  0.0438,  1.4679,  0.1700, -1.1849,
-         0.0000,  0.0102, -0.2824, -0.3466,  1.0456,  0.0000,  0.2358,  0.8117,
-        -0.0060,  0.6596, -0.4760,  0.0000, -1.1173, -0.3842,  0.3725,  0.4931,
-         0.9394,  0.0000, -0.0262,  0.0747,  0.3870,  0.8639,  0.7332,  0.0000,
-         1.0360,  0.0000,  0.5754,  1.2453,  0.0000, -0.0484,  0.1137, -0.6298],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1747e-01,  2.1618e-01, -2.0078e-01,  4.2589e-01, -1.0982e-01,
-         3.0594e-01,  3.1921e-06, -4.3070e-01, -1.3890e-01, -3.0253e-01,
-        -8.5740e-02, -2.0321e-01,  1.6348e-01, -9.9061e-01, -1.4694e+00,
-        -1.1350e+00, -7.0274e-01, -9.8771e-02, -9.6795e-04, -5.6233e-01,
-        -8.6700e-02, -8.8027e-02, -2.0556e-03, -8.9777e-01,  6.3903e-02,
-         1.0498e+00, -4.2763e-01, -1.2969e-01,  7.6184e-02,  1.4682e+00,
-         1.7568e-01, -1.2637e+00, -7.6526e-07,  3.5928e-02, -2.6151e-01,
-        -3.8249e-01,  1.0583e+00,  0.0000e+00,  2.6328e-01,  8.0901e-01,
-        -5.8050e-02,  6.3151e-01, -4.4432e-01,  3.1183e-03, -1.1119e+00,
-        -3.9846e-01,  3.5773e-01,  5.5995e-01,  9.4448e-01, -1.9852e-04,
-        -6.4543e-02,  2.4372e-02,  3.7686e-01,  8.0046e-01,  7.4493e-01,
-         4.1315e-08,  1.0331e+00,  5.8396e-06,  5.9819e-01,  1.2520e+00,
-         1.5599e-05, -6.4074e-02,  1.5937e-01, -6.3664e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3175,  0.2162, -0.2008,  0.4259, -0.1098,  0.3059,  0.0000, -0.4307,
-        -0.1389, -0.3025, -0.0857, -0.2032,  0.1635, -0.9906, -1.4694, -1.1350,
-        -0.7027, -0.0988,  0.0000, -0.5623, -0.0867, -0.0880,  0.0000, -0.8978,
-         0.0639,  1.0498, -0.4276, -0.1297,  0.0762,  1.4682,  0.1757, -1.2637,
-         0.0000,  0.0359, -0.2615, -0.3825,  0.0000,  0.0000,  0.2633,  0.8090,
-        -0.0580,  0.6315, -0.4443,  0.0000, -1.1119, -0.3985,  0.3577,  0.5599,
-         0.9445,  0.0000, -0.0645,  0.0244,  0.3769,  0.8005,  0.7449,  0.0000,
-         1.0331,  0.0000,  0.5982,  1.2520,  0.0000, -0.0641,  0.1594, -0.6366],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3175,  0.2162, -0.2008,  0.4259, -0.1098,  0.3059,  0.0000, -0.4307,
-        -0.1389, -0.3025, -0.0857, -0.2032,  0.1635, -0.9906, -1.4694, -1.1350,
-        -0.7027, -0.0988,  0.0000, -0.5623, -0.0867, -0.0880,  0.0000, -0.8978,
-         0.0639,  1.0498, -0.4276, -0.1297,  0.0762,  1.4682,  0.1757, -1.2637,
-         0.0000,  0.0359, -0.2615, -0.3825,  0.0000,  0.0000,  0.2633,  0.8090,
-        -0.0580,  0.6315, -0.4443,  0.0000, -1.1119, -0.3985,  0.3577,  0.5599,
-         0.9445,  0.0000, -0.0645,  0.0244,  0.3769,  0.8005,  0.7449,  0.0000,
-         1.0331,  0.0000,  0.5982,  1.2520,  0.0000, -0.0641,  0.1594, -0.6366],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8663e-01,  2.4004e-01, -3.1423e-01,  4.0739e-01, -1.0538e-01,
-         2.8262e-01,  2.7134e-06, -4.1694e-01, -1.4115e-01, -3.2163e-01,
-        -1.3257e-01, -2.3433e-01,  6.1114e-02, -9.8192e-01, -1.4783e+00,
-        -1.1509e+00, -6.8160e-01, -1.7221e-01, -8.2279e-04, -5.7397e-01,
-        -6.2393e-02, -9.8204e-02, -1.7473e-03, -8.7714e-01,  3.5847e-02,
-         1.0862e+00, -3.9834e-01, -1.6417e-01,  1.4894e-01,  1.4694e+00,
-         1.6205e-01, -1.3323e+00, -6.5049e-07,  1.0482e-01, -2.6235e-01,
-        -4.0769e-01,  1.0770e-02,  0.0000e+00,  3.3322e-01,  8.1219e-01,
-        -7.2262e-02,  6.0511e-01, -3.9093e-01,  2.6507e-03, -1.1008e+00,
-        -3.9997e-01,  3.3870e-01,  6.1496e-01,  9.5528e-01, -1.6875e-04,
-        -1.0452e-01,  1.6829e-02,  3.4087e-01,  7.2939e-01,  7.7597e-01,
-         3.5119e-08,  1.0296e+00,  4.9639e-06,  6.1914e-01,  1.2555e+00,
-         1.3260e-05, -7.1478e-02,  2.2837e-01, -6.2838e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2866,  0.2400, -0.3142,  0.4074, -0.1054,  0.2826,  0.0000, -0.4169,
-        -0.1412, -0.3216, -0.1326, -0.2343,  0.0611, -0.9819, -1.4783, -1.1509,
-        -0.6816, -0.1722,  0.0000, -0.5740, -0.0624, -0.0982,  0.0000, -0.8771,
-         0.0358,  1.0862, -0.3983, -0.1642,  0.1489,  1.4694,  0.1620, -1.3323,
-         0.0000,  0.1048, -0.2623, -0.4077,  0.0000,  0.0000,  0.3332,  0.8122,
-        -0.0723,  0.6051, -0.3909,  0.0000, -1.1008, -0.4000,  0.3387,  0.6150,
-         0.9553,  0.0000, -0.1045,  0.0168,  0.3409,  0.7294,  0.7760,  0.0000,
-         1.0296,  0.0000,  0.6191,  1.2555,  0.0000, -0.0715,  0.2284, -0.6284],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2866,  0.2400, -0.3142,  0.4074, -0.1054,  0.2826,  0.0000, -0.4169,
-        -0.1412, -0.3216, -0.1326, -0.2343,  0.0611, -0.9819, -1.4783, -1.1509,
-        -0.6816, -0.1722,  0.0000, -0.5740, -0.0624, -0.0982,  0.0000, -0.8771,
-         0.0358,  1.0862, -0.3983, -0.1642,  0.1489,  1.4694,  0.1620, -1.3323,
-         0.0000,  0.1048, -0.2623, -0.4077,  0.0000,  0.0000,  0.3332,  0.8122,
-        -0.0723,  0.6051, -0.3909,  0.0000, -1.1008, -0.4000,  0.3387,  0.6150,
-         0.9553,  0.0000, -0.1045,  0.0168,  0.3409,  0.7294,  0.7760,  0.0000,
-         1.0296,  0.0000,  0.6191,  1.2555,  0.0000, -0.0715,  0.2284, -0.6284],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7710e-01,  2.5357e-01, -3.0476e-01,  4.0096e-01, -1.5157e-01,
-         2.5870e-01,  2.3065e-06, -4.0984e-01, -8.6176e-02, -2.9764e-01,
-        -1.2434e-01, -2.5842e-01,  4.9551e-02, -9.8766e-01, -1.4860e+00,
-        -1.1623e+00, -6.8572e-01, -1.1095e-01, -6.9942e-04, -6.0045e-01,
-        -4.9358e-02, -7.9095e-02, -1.4853e-03, -8.7888e-01, -1.0741e-02,
-         1.1150e+00, -2.8817e-01, -1.7277e-01,  1.9748e-01,  1.4679e+00,
-         1.3157e-01, -1.3900e+00, -5.5296e-07,  1.0230e-01, -3.0096e-01,
-        -4.2655e-01,  9.1555e-03,  0.0000e+00,  5.1114e-01,  8.1287e-01,
-        -5.2508e-02,  5.6405e-01, -2.6095e-01,  2.2532e-03, -1.0956e+00,
-        -3.8752e-01,  2.6601e-01,  6.8077e-01,  9.5167e-01, -1.4345e-04,
-        -1.0040e-01, -9.3248e-02,  2.8176e-01,  6.5283e-01,  7.9745e-01,
-         2.9854e-08,  1.0262e+00,  4.2196e-06,  6.0573e-01,  1.2600e+00,
-         1.1272e-05,  2.7964e-02,  2.5955e-01, -5.5223e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2771,  0.2536, -0.3048,  0.4010, -0.1516,  0.2587,  0.0000, -0.4098,
-        -0.0862, -0.2976, -0.1243, -0.2584,  0.0496, -0.9877, -1.4860, -1.1623,
-        -0.6857, -0.1110,  0.0000, -0.6004, -0.0494, -0.0791,  0.0000, -0.8789,
-        -0.0107,  1.1150, -0.2882, -0.1728,  0.1975,  1.4679,  0.1316, -1.3900,
-         0.0000,  0.1023, -0.3010, -0.4265,  0.0000,  0.0000,  0.5111,  0.8129,
-        -0.0525,  0.5641, -0.2610,  0.0000, -1.0956, -0.3875,  0.2660,  0.6808,
-         0.9517,  0.0000, -0.1004, -0.0932,  0.2818,  0.6528,  0.7974,  0.0000,
-         1.0262,  0.0000,  0.6057,  1.2600,  0.0000,  0.0280,  0.2595, -0.5522],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2771,  0.2536, -0.3048,  0.4010, -0.1516,  0.2587,  0.0000, -0.4098,
-        -0.0862, -0.2976, -0.1243, -0.2584,  0.0496, -0.9877, -1.4860, -1.1623,
-        -0.6857, -0.1110,  0.0000, -0.6004, -0.0494, -0.0791,  0.0000, -0.8789,
-        -0.0107,  1.1150, -0.2882, -0.1728,  0.1975,  1.4679,  0.1316, -1.3900,
-         0.0000,  0.1023, -0.3010, -0.4265,  0.0000,  0.0000,  0.5111,  0.8129,
-        -0.0525,  0.5641, -0.2610,  0.0000, -1.0956, -0.3875,  0.2660,  0.6808,
-         0.9517,  0.0000, -0.1004, -0.0932,  0.2818,  0.6528,  0.7974,  0.0000,
-         1.0262,  0.0000,  0.6057,  1.2600,  0.0000,  0.0280,  0.2595, -0.5522],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6906e-01,  2.7432e-01, -3.3735e-01,  3.6916e-01, -2.1648e-01,
-         2.0848e-01,  1.9608e-06, -4.1097e-01, -4.6847e-02, -2.5810e-01,
-        -1.1246e-01, -2.7069e-01,  2.5278e-02, -9.8426e-01, -1.4895e+00,
-        -1.1739e+00, -6.7094e-01, -7.6075e-02, -5.9458e-04, -6.1241e-01,
-        -4.8307e-02, -3.9564e-02, -1.2627e-03, -8.8377e-01, -2.6163e-02,
-         1.1347e+00, -1.9886e-01, -1.7442e-01,  2.1113e-01,  1.4682e+00,
-         5.7084e-02, -1.4393e+00, -4.7007e-07,  6.4219e-02, -3.0255e-01,
-        -4.4013e-01,  7.7831e-03,  0.0000e+00,  6.6712e-01,  8.1175e-01,
-         5.8324e-03,  5.2554e-01, -2.2149e-01,  1.9155e-03, -1.0879e+00,
-        -3.7975e-01,  1.9139e-01,  7.2331e-01,  9.4877e-01, -1.2194e-04,
-        -7.1009e-02, -1.5646e-01,  2.5408e-01,  5.7617e-01,  8.0616e-01,
-         2.5379e-08,  1.0227e+00,  3.5871e-06,  5.7252e-01,  1.2629e+00,
-         9.5819e-06,  7.6519e-02,  2.9040e-01, -5.2407e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2691,  0.2743, -0.3373,  0.3692, -0.2165,  0.2085,  0.0000, -0.4110,
-        -0.0468, -0.2581, -0.1125, -0.2707,  0.0253, -0.9843, -1.4895, -1.1739,
-        -0.6709, -0.0761,  0.0000, -0.6124, -0.0483, -0.0396,  0.0000, -0.8838,
-        -0.0262,  1.1347, -0.1989, -0.1744,  0.2111,  1.4682,  0.0571, -1.4393,
-         0.0000,  0.0642, -0.3026, -0.4401,  0.0000,  0.0000,  0.6671,  0.8117,
-         0.0058,  0.5255, -0.2215,  0.0000, -1.0879, -0.3797,  0.1914,  0.7233,
-         0.9488,  0.0000, -0.0710, -0.1565,  0.2541,  0.5762,  0.8062,  0.0000,
-         1.0227,  0.0000,  0.5725,  1.2629,  0.0000,  0.0765,  0.2904, -0.5241],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2691,  0.2743, -0.3373,  0.3692, -0.2165,  0.2085,  0.0000, -0.4110,
-        -0.0468, -0.2581, -0.1125, -0.2707,  0.0253, -0.9843, -1.4895, -1.1739,
-        -0.6709, -0.0761,  0.0000, -0.6124, -0.0483, -0.0396,  0.0000, -0.8838,
-        -0.0262,  1.1347, -0.1989, -0.1744,  0.2111,  1.4682,  0.0571, -1.4393,
-         0.0000,  0.0642, -0.3026, -0.4401,  0.0000,  0.0000,  0.6671,  0.8117,
-         0.0058,  0.5255, -0.2215,  0.0000, -1.0879, -0.3797,  0.1914,  0.7233,
-         0.9488,  0.0000, -0.0710, -0.1565,  0.2541,  0.5762,  0.8062,  0.0000,
-         1.0227,  0.0000,  0.5725,  1.2629,  0.0000,  0.0765,  0.2904, -0.5241],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4671e-01,  3.2366e-01, -3.8324e-01,  3.3630e-01, -2.8084e-01,
-         2.0885e-01,  1.6669e-06, -4.1262e-01, -1.8214e-02, -2.2662e-01,
-        -8.0455e-02, -2.6811e-01, -2.6612e-02, -9.5699e-01, -1.4952e+00,
-        -1.1825e+00, -6.6125e-01, -7.6611e-03, -5.0548e-04, -6.1966e-01,
-        -3.1876e-02,  2.6328e-02, -1.0734e-03, -8.8899e-01, -1.6099e-02,
-         1.1416e+00, -1.2119e-01, -1.6389e-01,  2.7078e-01,  1.4686e+00,
-         3.4764e-02, -1.4803e+00, -3.9963e-07,  1.0847e-01, -2.8863e-01,
-        -4.4095e-01,  6.6167e-03,  0.0000e+00,  7.5865e-01,  7.8764e-01,
-         6.7472e-02,  5.1305e-01, -2.0686e-01,  1.6284e-03, -1.0724e+00,
-        -3.4105e-01,  1.6787e-01,  7.3608e-01,  9.6035e-01, -1.0367e-04,
-        -4.1138e-02, -1.7857e-01,  2.2107e-01,  5.3210e-01,  7.9217e-01,
-         2.1575e-08,  1.0181e+00,  3.0495e-06,  5.4714e-01,  1.2688e+00,
-         8.1460e-06,  1.5342e-01,  3.3696e-01, -5.1437e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2467,  0.3237, -0.3832,  0.3363, -0.2808,  0.2088,  0.0000, -0.4126,
-        -0.0182, -0.2266, -0.0805, -0.2681, -0.0266, -0.9570, -1.4952, -1.1825,
-        -0.6612, -0.0077,  0.0000, -0.6197, -0.0319,  0.0263,  0.0000, -0.8890,
-        -0.0161,  1.1416, -0.1212, -0.1639,  0.2708,  1.4686,  0.0348, -1.4803,
-         0.0000,  0.1085, -0.2886, -0.4410,  0.0000,  0.0000,  0.7586,  0.7876,
-         0.0675,  0.5131, -0.2069,  0.0000, -1.0724, -0.3411,  0.1679,  0.7361,
-         0.9603,  0.0000, -0.0411, -0.1786,  0.2211,  0.5321,  0.7922,  0.0000,
-         1.0181,  0.0000,  0.5471,  1.2688,  0.0000,  0.1534,  0.3370, -0.5144],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2467,  0.3237, -0.3832,  0.3363, -0.2808,  0.2088,  0.0000, -0.4126,
-        -0.0182, -0.2266, -0.0805, -0.2681, -0.0266, -0.9570, -1.4952, -1.1825,
-        -0.6612, -0.0077,  0.0000, -0.6197, -0.0319,  0.0263,  0.0000, -0.8890,
-        -0.0161,  1.1416, -0.1212, -0.1639,  0.2708,  1.4686,  0.0348, -1.4803,
-         0.0000,  0.1085, -0.2886, -0.4410,  0.0000,  0.0000,  0.7586,  0.7876,
-         0.0675,  0.5131, -0.2069,  0.0000, -1.0724, -0.3411,  0.1679,  0.7361,
-         0.9603,  0.0000, -0.0411, -0.1786,  0.2211,  0.5321,  0.7922,  0.0000,
-         1.0181,  0.0000,  0.5471,  1.2688,  0.0000,  0.1534,  0.3370, -0.5144],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.0450e-01,  3.8653e-01, -4.3324e-01,  2.9120e-01, -3.0803e-01,
-         2.3172e-01,  1.4172e-06, -4.0774e-01, -1.5615e-02, -2.2557e-01,
-        -4.5453e-02, -2.6491e-01, -1.1066e-01, -9.2156e-01, -1.4978e+00,
-        -1.1907e+00, -6.4491e-01,  3.9629e-02, -4.2975e-04, -6.1957e-01,
-        -1.2950e-02,  6.7810e-02, -9.1263e-04, -8.6309e-01, -1.9376e-02,
-         1.1352e+00, -9.2835e-02, -1.6902e-01,  3.0166e-01,  1.4740e+00,
-         1.3071e-01, -1.5147e+00, -3.3976e-07,  1.5398e-01, -2.5772e-01,
-        -4.4679e-01,  5.6254e-03,  0.0000e+00,  8.3091e-01,  7.3524e-01,
-         9.6237e-02,  4.9842e-01, -2.1212e-01,  1.3845e-03, -1.0591e+00,
-        -2.6375e-01,  1.2135e-01,  7.3000e-01,  9.9782e-01, -8.8139e-05,
-         6.3435e-03, -1.9331e-01,  2.1788e-01,  5.0094e-01,  7.3956e-01,
-         1.8343e-08,  1.0124e+00,  2.5927e-06,  5.0806e-01,  1.2772e+00,
-         6.9256e-06,  2.5321e-01,  3.8438e-01, -5.6880e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2045,  0.3865, -0.4332,  0.2912, -0.3080,  0.2317,  0.0000, -0.4077,
-        -0.0156, -0.2256, -0.0455, -0.2649, -0.1107, -0.9216,  0.0000, -1.1907,
-        -0.6449,  0.0396,  0.0000, -0.6196, -0.0129,  0.0678,  0.0000, -0.8631,
-        -0.0194,  1.1352, -0.0928, -0.1690,  0.3017,  1.4740,  0.1307, -1.5147,
-         0.0000,  0.1540, -0.2577, -0.4468,  0.0000,  0.0000,  0.8309,  0.7352,
-         0.0962,  0.4984, -0.2121,  0.0000, -1.0591, -0.2638,  0.1214,  0.7300,
-         0.9978,  0.0000,  0.0063, -0.1933,  0.2179,  0.5009,  0.7396,  0.0000,
-         1.0124,  0.0000,  0.5081,  1.2772,  0.0000,  0.2532,  0.3844, -0.5688],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2045,  0.3865, -0.4332,  0.2912, -0.3080,  0.2317,  0.0000, -0.4077,
-        -0.0156, -0.2256, -0.0455, -0.2649, -0.1107, -0.9216,  0.0000, -1.1907,
-        -0.6449,  0.0396,  0.0000, -0.6196, -0.0129,  0.0678,  0.0000, -0.8631,
-        -0.0194,  1.1352, -0.0928, -0.1690,  0.3017,  1.4740,  0.1307, -1.5147,
-         0.0000,  0.1540, -0.2577, -0.4468,  0.0000,  0.0000,  0.8309,  0.7352,
-         0.0962,  0.4984, -0.2121,  0.0000, -1.0591, -0.2638,  0.1214,  0.7300,
-         0.9978,  0.0000,  0.0063, -0.1933,  0.2179,  0.5009,  0.7396,  0.0000,
-         1.0124,  0.0000,  0.5081,  1.2772,  0.0000,  0.2532,  0.3844, -0.5688],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.9911e-01,  4.1270e-01, -4.1082e-01,  2.4686e-01, -4.0191e-01,
-         2.6366e-01,  1.2050e-06, -4.0798e-01, -4.8034e-02, -1.8587e-01,
-         4.0765e-03, -2.8117e-01, -1.0909e-01, -8.8879e-01, -2.1608e-03,
-        -1.1955e+00, -6.3248e-01,  1.2646e-01, -3.6539e-04, -6.1041e-01,
-        -2.2560e-02,  5.8135e-02, -7.7595e-04, -8.4232e-01, -6.5855e-02,
-         1.1247e+00, -3.8617e-03, -1.8366e-01,  2.7254e-01,  1.4743e+00,
-         1.8549e-01, -1.5396e+00, -2.8888e-07,  1.0823e-01, -2.8904e-01,
-        -4.3374e-01,  4.7830e-03,  0.0000e+00,  8.8842e-01,  6.7009e-01,
-         1.1701e-01,  4.8218e-01, -1.7180e-01,  1.1771e-03, -1.0522e+00,
-        -1.7936e-01,  6.5480e-02,  7.0441e-01,  1.0151e+00, -7.4939e-05,
-         5.1340e-02, -3.5034e-01,  2.2090e-01,  4.6109e-01,  6.8294e-01,
-         1.5596e-08,  1.0067e+00,  2.2044e-06,  4.5685e-01,  1.2870e+00,
-         5.8884e-06,  3.3556e-01,  4.0601e-01, -5.4193e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1991,  0.4127, -0.4108,  0.2469, -0.4019,  0.2637,  0.0000, -0.4080,
-        -0.0480, -0.1859,  0.0041, -0.2812, -0.1091, -0.8888,  0.0000, -1.1955,
-        -0.6325,  0.1265,  0.0000, -0.6104, -0.0226,  0.0581,  0.0000, -0.8423,
-        -0.0659,  1.1247, -0.0039, -0.1837,  0.2725,  1.4743,  0.1855, -1.5396,
-         0.0000,  0.1082, -0.2890, -0.4337,  0.0000,  0.0000,  0.8884,  0.6701,
-         0.1170,  0.4822, -0.1718,  0.0000, -1.0522, -0.1794,  0.0655,  0.7044,
-         1.0151,  0.0000,  0.0513, -0.3503,  0.2209,  0.4611,  0.6829,  0.0000,
-         1.0067,  0.0000,  0.4569,  1.2870,  0.0000,  0.3356,  0.4060, -0.5419],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1991,  0.4127, -0.4108,  0.2469, -0.4019,  0.2637,  0.0000, -0.4080,
-        -0.0480, -0.1859,  0.0041, -0.2812, -0.1091, -0.8888,  0.0000, -1.1955,
-        -0.6325,  0.1265,  0.0000, -0.6104, -0.0226,  0.0581,  0.0000, -0.8423,
-        -0.0659,  1.1247, -0.0039, -0.1837,  0.2725,  1.4743,  0.1855, -1.5396,
-         0.0000,  0.1082, -0.2890, -0.4337,  0.0000,  0.0000,  0.8884,  0.6701,
-         0.1170,  0.4822, -0.1718,  0.0000, -1.0522, -0.1794,  0.0655,  0.7044,
-         1.0151,  0.0000,  0.0513, -0.3503,  0.2209,  0.4611,  0.6829,  0.0000,
-         1.0067,  0.0000,  0.4569,  1.2870,  0.0000,  0.3356,  0.4060, -0.5419],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1676e-01,  4.2598e-01, -4.2784e-01,  1.9675e-01, -4.1641e-01,
-         2.4639e-01,  1.0246e-06, -4.0349e-01, -5.3918e-02, -1.6192e-01,
-         5.5418e-02, -2.8205e-01, -1.3553e-01, -8.6746e-01, -1.8373e-03,
-        -1.1961e+00, -5.8981e-01,  2.1806e-01, -3.1069e-04, -5.9949e-01,
-        -7.7951e-03,  2.3842e-02, -6.5978e-04, -8.2589e-01, -7.6146e-02,
-         1.1198e+00,  6.4716e-02, -1.8078e-01,  2.3203e-01,  1.4768e+00,
-         1.9749e-01, -1.5601e+00, -2.4563e-07,  6.5391e-02, -2.8827e-01,
-        -4.2160e-01,  4.0669e-03,  0.0000e+00,  9.4164e-01,  6.1129e-01,
-         1.4747e-01,  4.8560e-01, -1.6879e-01,  1.0009e-03, -1.0415e+00,
-        -1.2621e-01,  3.3697e-02,  6.7923e-01,  1.0233e+00, -6.3719e-05,
-         9.2279e-02, -4.5275e-01,  2.0929e-01,  4.0579e-01,  6.5067e-01,
-         1.3261e-08,  1.0026e+00,  1.8744e-06,  4.2463e-01,  1.2920e+00,
-         5.0069e-06,  3.6161e-01,  4.2049e-01, -5.1519e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2168,  0.4260, -0.4278,  0.1967, -0.4164,  0.2464,  0.0000, -0.4035,
-        -0.0539, -0.1619,  0.0554, -0.2820, -0.1355, -0.8675,  0.0000, -1.1961,
-        -0.5898,  0.2181,  0.0000, -0.5995, -0.0078,  0.0238,  0.0000, -0.8259,
-        -0.0761,  1.1198,  0.0647, -0.1808,  0.2320,  1.4768,  0.1975, -1.5601,
-         0.0000,  0.0654, -0.2883, -0.4216,  0.0000,  0.0000,  0.9416,  0.6113,
-         0.1475,  0.4856, -0.1688,  0.0000, -1.0415, -0.1262,  0.0337,  0.6792,
-         1.0233,  0.0000,  0.0923, -0.4528,  0.2093,  0.4058,  0.6507,  0.0000,
-         1.0026,  0.0000,  0.4246,  1.2920,  0.0000,  0.3616,  0.4205, -0.5152],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2168,  0.4260, -0.4278,  0.1967, -0.4164,  0.2464,  0.0000, -0.4035,
-        -0.0539, -0.1619,  0.0554, -0.2820, -0.1355, -0.8675,  0.0000, -1.1961,
-        -0.5898,  0.2181,  0.0000, -0.5995, -0.0078,  0.0238,  0.0000, -0.8259,
-        -0.0761,  1.1198,  0.0647, -0.1808,  0.2320,  1.4768,  0.1975, -1.5601,
-         0.0000,  0.0654, -0.2883, -0.4216,  0.0000,  0.0000,  0.9416,  0.6113,
-         0.1475,  0.4856, -0.1688,  0.0000, -1.0415, -0.1262,  0.0337,  0.6792,
-         1.0233,  0.0000,  0.0923, -0.4528,  0.2093,  0.4058,  0.6507,  0.0000,
-         1.0026,  0.0000,  0.4246,  1.2920,  0.0000,  0.3616,  0.4205, -0.5152],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.2528e-01,  3.9173e-01, -4.6102e-01,  1.2023e-01, -3.8303e-01,
-         1.8642e-01,  8.7124e-07, -4.1169e-01, -2.9308e-02, -1.3733e-01,
-         1.2244e-01, -2.6160e-01, -2.0499e-01, -8.4228e-01, -1.5624e-03,
-        -1.1955e+00, -5.4102e-01,  3.0136e-01, -2.6419e-04, -5.7509e-01,
-         5.4192e-03, -2.1356e-02, -5.6104e-04, -8.1247e-01, -6.3046e-02,
-         1.1198e+00,  6.1938e-02, -1.7873e-01,  1.6628e-01,  1.4791e+00,
-         1.5640e-01, -1.5784e+00, -2.0887e-07,  2.0082e-02, -2.2989e-01,
-        -3.8808e-01,  3.4583e-03,  0.0000e+00,  9.7142e-01,  5.5755e-01,
-         1.7094e-01,  4.9515e-01, -2.1285e-01,  8.5111e-04, -1.0287e+00,
-        -9.0370e-02,  5.7504e-02,  6.6488e-01,  1.0310e+00, -5.4183e-05,
-         1.2804e-01, -5.0345e-01,  1.9570e-01,  2.8804e-01,  6.3300e-01,
-         1.1276e-08,  9.9878e-01,  1.5939e-06,  3.6431e-01,  1.2910e+00,
-         4.2576e-06,  3.1806e-01,  4.0787e-01, -5.0625e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2253,  0.3917, -0.4610,  0.1202, -0.3830,  0.1864,  0.0000, -0.4117,
-        -0.0293, -0.1373,  0.1224, -0.2616, -0.2050, -0.8423,  0.0000, -1.1955,
-        -0.5410,  0.3014,  0.0000, -0.5751,  0.0054, -0.0214,  0.0000, -0.8125,
-        -0.0630,  1.1198,  0.0619, -0.1787,  0.1663,  1.4791,  0.1564, -1.5784,
-         0.0000,  0.0201, -0.2299, -0.3881,  0.0000,  0.0000,  0.9714,  0.5576,
-         0.1709,  0.4952, -0.2128,  0.0000, -1.0287, -0.0904,  0.0575,  0.6649,
-         1.0310,  0.0000,  0.1280, -0.5035,  0.1957,  0.2880,  0.6330,  0.0000,
-         0.9988,  0.0000,  0.3643,  1.2910,  0.0000,  0.3181,  0.4079, -0.5062],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2253,  0.3917, -0.4610,  0.1202, -0.3830,  0.1864,  0.0000, -0.4117,
-        -0.0293, -0.1373,  0.1224, -0.2616, -0.2050, -0.8423,  0.0000, -1.1955,
-        -0.5410,  0.3014,  0.0000, -0.5751,  0.0054, -0.0214,  0.0000, -0.8125,
-        -0.0630,  1.1198,  0.0619, -0.1787,  0.1663,  1.4791,  0.1564, -1.5784,
-         0.0000,  0.0201, -0.2299, -0.3881,  0.0000,  0.0000,  0.9714,  0.5576,
-         0.1709,  0.4952, -0.2128,  0.0000, -1.0287, -0.0904,  0.0575,  0.6649,
-         1.0310,  0.0000,  0.1280, -0.5035,  0.1957,  0.2880,  0.6330,  0.0000,
-         0.9988,  0.0000,  0.3643,  1.2910,  0.0000,  0.3181,  0.4079, -0.5062],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1895e-01,  3.3101e-01, -4.8818e-01,  7.9390e-02, -1.9813e-01,
-         1.2785e-01,  7.4090e-07, -4.0593e-01,  1.9324e-02, -9.4112e-02,
-         1.5895e-01, -2.5351e-01, -2.7666e-01, -8.2803e-01, -1.3286e-03,
-        -1.1934e+00, -4.7233e-01,  3.6835e-01, -2.2467e-04, -5.3621e-01,
-         6.1393e-02, -2.0156e-02, -4.7711e-04, -7.6934e-01, -1.6347e-02,
-         1.1220e+00,  3.6155e-02, -1.6542e-01,  5.1708e-02,  1.4793e+00,
-         9.1605e-02, -1.5940e+00, -1.7762e-07, -2.0850e-03, -1.2948e-01,
-        -3.5155e-01,  2.9409e-03,  0.0000e+00,  1.0038e+00,  5.3806e-01,
-         1.9905e-01,  4.9157e-01, -2.3445e-01,  7.2379e-04, -1.0194e+00,
-        -6.7632e-02,  8.4455e-02,  6.4223e-01,  1.0221e+00, -4.6078e-05,
-         1.5364e-01, -5.0494e-01,  1.9634e-01,  1.3452e-01,  6.3572e-01,
-         9.5895e-09,  9.9585e-01,  1.3554e-06,  2.9400e-01,  1.2841e+00,
-         3.6206e-06,  2.4998e-01,  3.6583e-01, -4.9202e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2189,  0.3310, -0.4882,  0.0794, -0.1981,  0.1279,  0.0000, -0.4059,
-         0.0193, -0.0941,  0.1590, -0.2535, -0.2767, -0.8280,  0.0000, -1.1934,
-        -0.4723,  0.3684,  0.0000, -0.5362,  0.0614, -0.0202,  0.0000, -0.7693,
-        -0.0163,  1.1220,  0.0362, -0.1654,  0.0517,  1.4793,  0.0916, -1.5940,
-         0.0000, -0.0021, -0.1295, -0.3516,  0.0000,  0.0000,  1.0038,  0.5381,
-         0.1991,  0.4916, -0.2345,  0.0000, -1.0194, -0.0676,  0.0845,  0.6422,
-         1.0221,  0.0000,  0.1536, -0.5049,  0.1963,  0.1345,  0.6357,  0.0000,
-         0.9959,  0.0000,  0.2940,  1.2841,  0.0000,  0.2500,  0.3658, -0.4920],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2189,  0.3310, -0.4882,  0.0794, -0.1981,  0.1279,  0.0000, -0.4059,
-         0.0193, -0.0941,  0.1590, -0.2535, -0.2767, -0.8280,  0.0000, -1.1934,
-        -0.4723,  0.3684,  0.0000, -0.5362,  0.0614, -0.0202,  0.0000, -0.7693,
-        -0.0163,  1.1220,  0.0362, -0.1654,  0.0517,  1.4793,  0.0916, -1.5940,
-         0.0000, -0.0021, -0.1295, -0.3516,  0.0000,  0.0000,  1.0038,  0.5381,
-         0.1991,  0.4916, -0.2345,  0.0000, -1.0194, -0.0676,  0.0845,  0.6422,
-         1.0221,  0.0000,  0.1536, -0.5049,  0.1963,  0.1345,  0.6357,  0.0000,
-         0.9959,  0.0000,  0.2940,  1.2841,  0.0000,  0.2500,  0.3658, -0.4920],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6823e-01,  2.4932e-01, -5.0622e-01,  4.8536e-02,  2.7289e-02,
-         8.2699e-02,  6.3011e-07, -4.1112e-01,  2.3412e-02, -4.8491e-02,
-         2.1688e-01, -2.1784e-01, -3.1499e-01, -8.1232e-01, -1.1300e-03,
-        -1.1921e+00, -4.1126e-01,  4.2602e-01, -1.9107e-04, -4.9716e-01,
-         1.0292e-01, -4.4837e-02, -4.0577e-04, -7.3375e-01, -6.3962e-03,
-         1.1193e+00,  6.3727e-03, -1.7109e-01, -4.2706e-02,  1.4783e+00,
-         3.8724e-02, -1.6065e+00, -1.5106e-07, -1.5706e-02, -5.4500e-02,
-        -2.9778e-01,  2.5011e-03,  0.0000e+00,  1.0069e+00,  4.8246e-01,
-         2.1752e-01,  4.8798e-01, -2.1480e-01,  6.1555e-04, -1.0053e+00,
-        -5.4059e-02,  1.0883e-01,  6.0521e-01,  1.0072e+00, -3.9188e-05,
-         1.8768e-01, -5.0143e-01,  1.8498e-01, -3.9155e-03,  6.1525e-01,
-         8.1556e-09,  9.9224e-01,  1.1527e-06,  2.3943e-01,  1.2766e+00,
-         3.0792e-06,  2.2606e-01,  3.3816e-01, -4.9383e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1682,  0.2493, -0.5062,  0.0485,  0.0273,  0.0827,  0.0000, -0.4111,
-         0.0234, -0.0485,  0.2169, -0.2178, -0.3150, -0.8123,  0.0000, -1.1921,
-        -0.4113,  0.4260,  0.0000, -0.4972,  0.1029, -0.0448,  0.0000, -0.7338,
-        -0.0064,  1.1193,  0.0064, -0.1711, -0.0427,  1.4783,  0.0387, -1.6065,
-         0.0000, -0.0157,  0.0000, -0.2978,  0.0000,  0.0000,  1.0069,  0.4825,
-         0.2175,  0.4880, -0.2148,  0.0000, -1.0053, -0.0541,  0.1088,  0.6052,
-         1.0072,  0.0000,  0.1877, -0.5014,  0.1850, -0.0039,  0.6153,  0.0000,
-         0.9922,  0.0000,  0.2394,  1.2766,  0.0000,  0.2261,  0.3382, -0.4938],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1682,  0.2493, -0.5062,  0.0485,  0.0273,  0.0827,  0.0000, -0.4111,
-         0.0234, -0.0485,  0.2169, -0.2178, -0.3150, -0.8123,  0.0000, -1.1921,
-        -0.4113,  0.4260,  0.0000, -0.4972,  0.1029, -0.0448,  0.0000, -0.7338,
-        -0.0064,  1.1193,  0.0064, -0.1711, -0.0427,  1.4783,  0.0387, -1.6065,
-         0.0000, -0.0157,  0.0000, -0.2978,  0.0000,  0.0000,  1.0069,  0.4825,
-         0.2175,  0.4880, -0.2148,  0.0000, -1.0053, -0.0541,  0.1088,  0.6052,
-         1.0072,  0.0000,  0.1877, -0.5014,  0.1850, -0.0039,  0.6153,  0.0000,
-         0.9922,  0.0000,  0.2394,  1.2766,  0.0000,  0.2261,  0.3382, -0.4938],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3446e-01,  1.5256e-01, -4.5825e-01,  7.0556e-02,  2.4293e-01,
-         6.2367e-02,  5.3593e-07, -4.4126e-01, -2.5409e-02, -7.6537e-03,
-         2.5336e-01, -1.4038e-01, -2.4451e-01, -7.9405e-01, -9.6108e-04,
-        -1.1958e+00, -4.3450e-01,  4.7246e-01, -1.6251e-04, -4.6013e-01,
-         1.3634e-01, -8.8417e-02, -3.4512e-04, -7.3336e-01, -7.0718e-02,
-         1.1171e+00,  2.8329e-02, -2.3609e-01, -1.3459e-01,  1.4740e+00,
-        -3.7122e-02, -1.6164e+00, -1.2848e-07, -5.7969e-02,  6.3773e-02,
-        -2.1439e-01,  2.1273e-03,  0.0000e+00,  1.0164e+00,  4.4847e-01,
-         2.2650e-01,  4.7697e-01, -1.1380e-01,  5.2355e-04, -1.0026e+00,
-        -1.0419e-01,  1.5495e-01,  5.8041e-01,  9.9253e-01, -3.3330e-05,
-         1.9590e-01, -5.3405e-01,  1.3513e-01, -5.9950e-02,  5.9235e-01,
-         6.9366e-09,  9.8961e-01,  9.8044e-07,  2.1472e-01,  1.2694e+00,
-         2.6190e-06,  2.4879e-01,  2.9015e-01, -4.5651e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1345,  0.1526, -0.4582,  0.0706,  0.2429,  0.0624,  0.0000, -0.4413,
-        -0.0254, -0.0077,  0.2534, -0.1404, -0.2445, -0.7940,  0.0000, -1.1958,
-        -0.4345,  0.4725,  0.0000, -0.4601,  0.1363, -0.0884,  0.0000, -0.7334,
-        -0.0707,  1.1171,  0.0283, -0.2361, -0.1346,  1.4740, -0.0371, -1.6164,
-         0.0000, -0.0580,  0.0000, -0.2144,  0.0000,  0.0000,  1.0164,  0.4485,
-         0.2265,  0.4770, -0.1138,  0.0000, -1.0026, -0.1042,  0.1550,  0.5804,
-         0.9925,  0.0000,  0.1959, -0.5340,  0.1351, -0.0600,  0.5924,  0.0000,
-         0.9896,  0.0000,  0.2147,  1.2694,  0.0000,  0.2488,  0.2902, -0.4565],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1345,  0.1526, -0.4582,  0.0706,  0.2429,  0.0624,  0.0000, -0.4413,
-        -0.0254, -0.0077,  0.2534, -0.1404, -0.2445, -0.7940,  0.0000, -1.1958,
-        -0.4345,  0.4725,  0.0000, -0.4601,  0.1363, -0.0884,  0.0000, -0.7334,
-        -0.0707,  1.1171,  0.0283, -0.2361, -0.1346,  1.4740, -0.0371, -1.6164,
-         0.0000, -0.0580,  0.0000, -0.2144,  0.0000,  0.0000,  1.0164,  0.4485,
-         0.2265,  0.4770, -0.1138,  0.0000, -1.0026, -0.1042,  0.1550,  0.5804,
-         0.9925,  0.0000,  0.1959, -0.5340,  0.1351, -0.0600,  0.5924,  0.0000,
-         0.9896,  0.0000,  0.2147,  1.2694,  0.0000,  0.2488,  0.2902, -0.4565],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3631e-01,  1.3478e-01, -4.4397e-01,  1.0368e-01,  7.2166e-01,
-         6.8913e-02,  4.5586e-07, -4.4755e-01,  3.5794e-02, -1.1683e-03,
-         2.4836e-01, -1.5215e-01, -2.2407e-01, -7.6111e-01, -8.1749e-04,
-        -1.1898e+00, -4.2365e-01,  5.0827e-01, -1.3823e-04, -4.4726e-01,
-         1.8396e-01, -4.5938e-02, -2.9356e-04, -7.1478e-01, -1.8016e-02,
-         1.1059e+00,  5.7891e-02, -1.9684e-01, -2.2225e-01,  1.4731e+00,
-        -1.0174e-01, -1.6264e+00, -1.0929e-07, -1.2148e-02,  5.4245e-02,
-        -1.9544e-01,  1.8095e-03,  0.0000e+00,  9.9276e-01,  3.8274e-01,
-         2.5883e-01,  4.8607e-01, -4.4119e-02,  4.4533e-04, -9.9732e-01,
-        -1.4942e-01,  2.3184e-01,  5.1739e-01,  9.8125e-01, -2.8351e-05,
-         1.8413e-01, -5.3049e-01,  4.4841e-02,  3.6133e-02,  5.5851e-01,
-         5.9003e-09,  9.8711e-01,  8.3396e-07,  2.0635e-01,  1.2667e+00,
-         2.2277e-06,  2.6415e-01,  2.1800e-01, -4.2162e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.3631e-01,  1.3478e-01, -4.4397e-01,  1.0368e-01,  7.2166e-01,
-         6.8913e-02,  0.0000e+00, -4.4755e-01,  3.5794e-02, -1.1683e-03,
-         2.4836e-01, -1.5215e-01, -2.2407e-01, -7.6111e-01,  0.0000e+00,
-        -1.1898e+00, -4.2365e-01,  5.0827e-01,  0.0000e+00, -4.4726e-01,
-         1.8396e-01, -4.5938e-02,  0.0000e+00, -7.1478e-01, -1.8016e-02,
-         1.1059e+00,  5.7891e-02, -1.9684e-01, -2.2225e-01,  1.4731e+00,
-        -1.0174e-01, -1.6264e+00,  0.0000e+00, -1.2148e-02,  0.0000e+00,
-        -1.9544e-01,  0.0000e+00,  0.0000e+00,  9.9276e-01,  3.8274e-01,
-         2.5883e-01,  4.8607e-01, -4.4119e-02,  0.0000e+00, -9.9732e-01,
-        -1.4942e-01,  2.3184e-01,  5.1739e-01,  9.8125e-01,  0.0000e+00,
-         1.8413e-01, -5.3049e-01,  4.4841e-02,  3.6133e-02,  5.5851e-01,
-         0.0000e+00,  9.8711e-01,  0.0000e+00,  2.0635e-01,  1.2667e+00,
-         0.0000e+00,  2.6415e-01,  2.1800e-01, -4.2162e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.3631e-01,  1.3478e-01, -4.4397e-01,  1.0368e-01,  7.2166e-01,
-         6.8913e-02,  0.0000e+00, -4.4755e-01,  3.5794e-02, -1.1683e-03,
-         2.4836e-01, -1.5215e-01, -2.2407e-01, -7.6111e-01,  0.0000e+00,
-        -1.1898e+00, -4.2365e-01,  5.0827e-01,  0.0000e+00, -4.4726e-01,
-         1.8396e-01, -4.5938e-02,  0.0000e+00, -7.1478e-01, -1.8016e-02,
-         1.1059e+00,  5.7891e-02, -1.9684e-01, -2.2225e-01,  1.4731e+00,
-        -1.0174e-01, -1.6264e+00,  0.0000e+00, -1.2148e-02,  0.0000e+00,
-        -1.9544e-01,  0.0000e+00,  0.0000e+00,  9.9276e-01,  3.8274e-01,
-         2.5883e-01,  4.8607e-01, -4.4119e-02,  0.0000e+00, -9.9732e-01,
-        -1.4942e-01,  2.3184e-01,  5.1739e-01,  9.8125e-01,  0.0000e+00,
-         1.8413e-01, -5.3049e-01,  4.4841e-02,  3.6133e-02,  5.5851e-01,
-         0.0000e+00,  9.8711e-01,  0.0000e+00,  2.0635e-01,  1.2667e+00,
-         0.0000e+00,  2.6415e-01,  2.1800e-01, -4.2162e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6583e-01,  1.2397e-01, -4.4260e-01,  1.3279e-01,  1.0911e+00,
-         2.5085e-02,  3.8779e-07, -4.4176e-01,  1.1504e-01, -1.5695e-02,
-         2.0457e-01, -1.8718e-01, -2.2734e-01, -7.4127e-01, -6.9542e-04,
-        -1.1871e+00, -4.1626e-01,  5.2450e-01, -1.1759e-04, -4.1688e-01,
-         2.0424e-01,  7.2009e-02, -2.4972e-04, -7.1322e-01,  1.0383e-01,
-         1.1043e+00,  4.1592e-02, -8.3131e-02, -2.7766e-01,  1.4741e+00,
-        -1.5271e-01, -1.6341e+00, -9.2968e-08,  5.4633e-02,  4.6145e-02,
-        -1.9134e-01,  1.5393e-03,  0.0000e+00,  9.9243e-01,  3.7926e-01,
-         2.9846e-01,  4.7983e-01, -1.9601e-02,  3.7883e-04, -1.0101e+00,
-        -2.2535e-01,  2.7962e-01,  4.5148e-01,  9.7706e-01, -2.4117e-05,
-         1.9067e-01, -4.7775e-01, -6.9137e-02,  1.4932e-01,  5.3364e-01,
-         5.0192e-09,  9.8377e-01,  7.0943e-07,  2.0043e-01,  1.2634e+00,
-         1.8951e-06,  2.0444e-01,  1.6510e-01, -3.9053e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1658,  0.1240, -0.4426,  0.1328,  1.0911,  0.0251,  0.0000, -0.4418,
-         0.1150, -0.0157,  0.2046, -0.1872, -0.2273, -0.7413,  0.0000, -1.1871,
-        -0.4163,  0.5245,  0.0000, -0.4169,  0.2042,  0.0720,  0.0000, -0.7132,
-         0.1038,  1.1043,  0.0416, -0.0831, -0.2777,  1.4741, -0.1527, -1.6341,
-         0.0000,  0.0546,  0.0000, -0.1913,  0.0000,  0.0000,  0.9924,  0.3793,
-         0.2985,  0.4798, -0.0196,  0.0000, -1.0101, -0.2254,  0.2796,  0.4515,
-         0.9771,  0.0000,  0.1907, -0.4777, -0.0691,  0.1493,  0.5336,  0.0000,
-         0.9838,  0.0000,  0.2004,  1.2634,  0.0000,  0.2044,  0.1651, -0.3905],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1658,  0.1240, -0.4426,  0.1328,  1.0911,  0.0251,  0.0000, -0.4418,
-         0.1150, -0.0157,  0.2046, -0.1872, -0.2273, -0.7413,  0.0000, -1.1871,
-        -0.4163,  0.5245,  0.0000, -0.4169,  0.2042,  0.0720,  0.0000, -0.7132,
-         0.1038,  1.1043,  0.0416, -0.0831, -0.2777,  1.4741, -0.1527, -1.6341,
-         0.0000,  0.0546,  0.0000, -0.1913,  0.0000,  0.0000,  0.9924,  0.3793,
-         0.2985,  0.4798, -0.0196,  0.0000, -1.0101, -0.2254,  0.2796,  0.4515,
-         0.9771,  0.0000,  0.1907, -0.4777, -0.0691,  0.1493,  0.5336,  0.0000,
-         0.9838,  0.0000,  0.2004,  1.2634,  0.0000,  0.2044,  0.1651, -0.3905],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1900e-01,  5.2798e-02, -3.5580e-01,  1.8259e-01,  1.3894e+00,
-         2.4168e-02,  3.2991e-07, -4.7006e-01,  1.6102e-01, -4.2962e-02,
-         1.6505e-01, -2.4828e-01, -1.5037e-01, -7.3758e-01, -5.9163e-04,
-        -1.1924e+00, -4.8332e-01,  5.2812e-01, -1.0004e-04, -4.2177e-01,
-         1.8140e-01,  1.5832e-01, -2.1245e-04, -7.5679e-01,  1.8709e-01,
-         1.1161e+00,  4.6190e-02, -2.8562e-02, -2.9361e-01,  1.4744e+00,
-        -2.7806e-01, -1.6374e+00, -7.9093e-08,  3.5944e-02,  3.9258e-02,
-        -2.0798e-01,  1.3096e-03,  0.0000e+00,  1.0250e+00,  4.5422e-01,
-         3.1980e-01,  4.5691e-01,  1.3818e-02,  3.2229e-04, -1.0282e+00,
-        -3.2373e-01,  3.1616e-01,  4.0764e-01,  9.7834e-01, -2.0518e-05,
-         2.0336e-01, -4.7354e-01, -1.4257e-01,  1.9826e-01,  5.3389e-01,
-         4.2701e-09,  9.8012e-01,  6.0355e-07,  1.2736e-01,  1.2593e+00,
-         1.6122e-06,  1.1002e-01,  1.2529e-01, -3.3120e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2190,  0.0528, -0.3558,  0.1826,  1.3894,  0.0242,  0.0000, -0.4701,
-         0.1610, -0.0430,  0.1651, -0.2483, -0.1504, -0.7376,  0.0000, -1.1924,
-        -0.4833,  0.5281,  0.0000, -0.4218,  0.1814,  0.1583,  0.0000, -0.7568,
-         0.1871,  1.1161,  0.0462, -0.0286, -0.2936,  0.0000, -0.2781, -1.6374,
-         0.0000,  0.0359,  0.0000, -0.2080,  0.0000,  0.0000,  1.0250,  0.4542,
-         0.3198,  0.4569,  0.0138,  0.0000, -1.0282, -0.3237,  0.3162,  0.4076,
-         0.9783,  0.0000,  0.2034, -0.4735, -0.1426,  0.1983,  0.5339,  0.0000,
-         0.9801,  0.0000,  0.1274,  1.2593,  0.0000,  0.1100,  0.1253, -0.3312],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2190,  0.0528, -0.3558,  0.1826,  1.3894,  0.0242,  0.0000, -0.4701,
-         0.1610, -0.0430,  0.1651, -0.2483, -0.1504, -0.7376,  0.0000, -1.1924,
-        -0.4833,  0.5281,  0.0000, -0.4218,  0.1814,  0.1583,  0.0000, -0.7568,
-         0.1871,  1.1161,  0.0462, -0.0286, -0.2936,  0.0000, -0.2781, -1.6374,
-         0.0000,  0.0359,  0.0000, -0.2080,  0.0000,  0.0000,  1.0250,  0.4542,
-         0.3198,  0.4569,  0.0138,  0.0000, -1.0282, -0.3237,  0.3162,  0.4076,
-         0.9783,  0.0000,  0.2034, -0.4735, -0.1426,  0.1983,  0.5339,  0.0000,
-         0.9801,  0.0000,  0.1274,  1.2593,  0.0000,  0.1100,  0.1253, -0.3312],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0415e-01,  7.4420e-03, -3.2425e-01,  1.9624e-01,  1.6281e+00,
-         8.8227e-02,  2.8070e-07, -5.0675e-01,  1.1959e-01, -1.1284e-01,
-         8.7176e-02, -3.0875e-01, -1.4085e-01, -7.3765e-01, -5.0338e-04,
-        -1.1975e+00, -4.9939e-01,  5.1279e-01, -8.5119e-05, -4.7277e-01,
-         8.1324e-02,  1.1157e-01, -1.8076e-04, -7.4715e-01,  1.5967e-01,
-         1.1260e+00,  8.6875e-02, -5.5487e-02, -2.8028e-01,  2.6289e-04,
-        -3.9827e-01, -1.6367e+00, -6.7295e-08, -4.9916e-02,  3.3402e-02,
-        -2.5813e-01,  1.1142e-03,  0.0000e+00,  1.0368e+00,  4.6072e-01,
-         2.8860e-01,  4.1323e-01,  5.3528e-02,  2.7422e-04, -1.0534e+00,
-        -3.4397e-01,  3.0453e-01,  3.6023e-01,  9.5915e-01, -1.7457e-05,
-         1.9155e-01, -5.3709e-01, -1.4238e-01,  2.1263e-01,  5.0316e-01,
-         3.6331e-09,  9.7900e-01,  5.1352e-07,  6.9916e-02,  1.2494e+00,
-         1.3717e-06,  1.2472e-01,  6.7086e-02, -2.7898e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3042,  0.0074, -0.3242,  0.1962,  1.6281,  0.0882,  0.0000, -0.5067,
-         0.1196, -0.1128,  0.0872, -0.3088, -0.1409, -0.7376,  0.0000, -1.1975,
-        -0.4994,  0.5128,  0.0000, -0.4728,  0.0813,  0.1116,  0.0000, -0.7472,
-         0.1597,  1.1260,  0.0869, -0.0555, -0.2803,  0.0000, -0.3983, -1.6367,
-         0.0000, -0.0499,  0.0000, -0.2581,  0.0000,  0.0000,  1.0368,  0.4607,
-         0.2886,  0.4132,  0.0535,  0.0000, -1.0534, -0.3440,  0.3045,  0.3602,
-         0.9591,  0.0000,  0.1916, -0.5371, -0.1424,  0.2126,  0.5032,  0.0000,
-         0.9790,  0.0000,  0.0699,  1.2494,  0.0000,  0.1247,  0.0671, -0.2790],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3042,  0.0074, -0.3242,  0.1962,  1.6281,  0.0882,  0.0000, -0.5067,
-         0.1196, -0.1128,  0.0872, -0.3088, -0.1409, -0.7376,  0.0000, -1.1975,
-        -0.4994,  0.5128,  0.0000, -0.4728,  0.0813,  0.1116,  0.0000, -0.7472,
-         0.1597,  1.1260,  0.0869, -0.0555, -0.2803,  0.0000, -0.3983, -1.6367,
-         0.0000, -0.0499,  0.0000, -0.2581,  0.0000,  0.0000,  1.0368,  0.4607,
-         0.2886,  0.4132,  0.0535,  0.0000, -1.0534, -0.3440,  0.3045,  0.3602,
-         0.9591,  0.0000,  0.1916, -0.5371, -0.1424,  0.2126,  0.5032,  0.0000,
-         0.9790,  0.0000,  0.0699,  1.2494,  0.0000,  0.1247,  0.0671, -0.2790],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4086e-01,  5.4015e-03, -3.6387e-01,  1.0157e-01,  1.8300e+00,
-         1.2201e-01,  2.3885e-07, -5.3257e-01,  4.2014e-02, -2.0836e-01,
-         5.4233e-04, -3.5044e-01, -1.9528e-01, -7.3592e-01, -4.2833e-04,
-        -1.1988e+00, -4.8245e-01,  4.9522e-01, -7.2429e-05, -5.4065e-01,
-        -3.7614e-02,  1.1575e-03, -1.5381e-04, -7.3150e-01,  9.9349e-02,
-         1.1403e+00,  8.0624e-02, -1.1415e-01, -2.2133e-01,  2.2370e-04,
-        -4.7816e-01, -1.6330e+00, -5.7262e-08, -1.4431e-01,  2.8422e-02,
-        -3.0511e-01,  9.4810e-04,  0.0000e+00,  1.0416e+00,  4.8529e-01,
-         2.3109e-01,  3.5306e-01,  4.6353e-02,  2.3334e-04, -1.0707e+00,
-        -3.2892e-01,  3.0438e-01,  2.8920e-01,  9.4983e-01, -1.4855e-05,
-         1.5028e-01, -6.2934e-01, -7.0323e-02,  1.3862e-01,  4.9088e-01,
-         3.0915e-09,  9.7782e-01,  4.3696e-07,  7.4061e-02,  1.2442e+00,
-         1.1672e-06,  1.1241e-01, -3.9282e-02, -2.6930e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.4086e-01,  5.4015e-03, -3.6387e-01,  1.0157e-01,  1.8300e+00,
-         1.2201e-01,  0.0000e+00, -5.3257e-01,  4.2014e-02, -2.0836e-01,
-         5.4233e-04, -3.5044e-01, -1.9528e-01, -7.3592e-01,  0.0000e+00,
-        -1.1988e+00, -4.8245e-01,  4.9522e-01,  0.0000e+00, -5.4065e-01,
-        -3.7614e-02,  1.1575e-03,  0.0000e+00, -7.3150e-01,  9.9349e-02,
-         1.1403e+00,  8.0624e-02, -1.1415e-01, -2.2133e-01,  0.0000e+00,
-        -4.7816e-01, -1.6330e+00,  0.0000e+00, -1.4431e-01,  0.0000e+00,
-        -3.0511e-01,  0.0000e+00,  0.0000e+00,  1.0416e+00,  4.8529e-01,
-         2.3109e-01,  3.5306e-01,  4.6353e-02,  0.0000e+00, -1.0707e+00,
-        -3.2892e-01,  3.0438e-01,  2.8920e-01,  9.4983e-01,  0.0000e+00,
-         1.5028e-01, -6.2934e-01, -7.0323e-02,  1.3862e-01,  4.9088e-01,
-         0.0000e+00,  9.7782e-01,  0.0000e+00,  7.4061e-02,  1.2442e+00,
-         0.0000e+00,  1.1241e-01, -3.9282e-02, -2.6930e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.4086e-01,  5.4015e-03, -3.6387e-01,  1.0157e-01,  1.8300e+00,
-         1.2201e-01,  0.0000e+00, -5.3257e-01,  4.2014e-02, -2.0836e-01,
-         5.4233e-04, -3.5044e-01, -1.9528e-01, -7.3592e-01,  0.0000e+00,
-        -1.1988e+00, -4.8245e-01,  4.9522e-01,  0.0000e+00, -5.4065e-01,
-        -3.7614e-02,  1.1575e-03,  0.0000e+00, -7.3150e-01,  9.9349e-02,
-         1.1403e+00,  8.0624e-02, -1.1415e-01, -2.2133e-01,  0.0000e+00,
-        -4.7816e-01, -1.6330e+00,  0.0000e+00, -1.4431e-01,  0.0000e+00,
-        -3.0511e-01,  0.0000e+00,  0.0000e+00,  1.0416e+00,  4.8529e-01,
-         2.3109e-01,  3.5306e-01,  4.6353e-02,  0.0000e+00, -1.0707e+00,
-        -3.2892e-01,  3.0438e-01,  2.8920e-01,  9.4983e-01,  0.0000e+00,
-         1.5028e-01, -6.2934e-01, -7.0323e-02,  1.3862e-01,  4.9088e-01,
-         0.0000e+00,  9.7782e-01,  0.0000e+00,  7.4061e-02,  1.2442e+00,
-         0.0000e+00,  1.1241e-01, -3.9282e-02, -2.6930e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3909e-01,  4.2286e-02, -4.2047e-01,  1.1162e-02,  1.9997e+00,
-         9.6435e-02,  2.0327e-07, -5.4634e-01,  2.5081e-02, -2.7292e-01,
-        -6.4461e-02, -3.5200e-01, -2.7293e-01, -7.2464e-01, -3.6451e-04,
-        -1.1986e+00, -4.8132e-01,  5.0541e-01, -6.1637e-05, -5.8542e-01,
-        -9.9705e-02, -8.8148e-03, -1.3089e-04, -7.2482e-01,  1.3656e-01,
-         1.1437e+00,  1.3523e-02, -1.1533e-01, -1.7721e-01,  1.9037e-04,
-        -5.5571e-01, -1.6329e+00, -4.8730e-08, -1.5270e-01,  2.4187e-02,
-        -3.4839e-01,  8.0683e-04,  0.0000e+00,  1.0440e+00,  5.0701e-01,
-         2.2123e-01,  3.0777e-01,  7.9101e-03,  1.9857e-04, -1.0720e+00,
-        -2.6580e-01,  3.5074e-01,  2.0285e-01,  9.7042e-01, -1.2641e-05,
-         1.3243e-01, -6.7739e-01, -2.4416e-02,  2.0713e-02,  4.9084e-01,
-         2.6309e-09,  9.7478e-01,  3.7186e-07,  7.1613e-02,  1.2405e+00,
-         9.9331e-07, -7.8198e-04, -8.1368e-02, -3.3318e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.3909e-01,  4.2286e-02, -4.2047e-01,  1.1162e-02,  1.9997e+00,
-         9.6435e-02,  0.0000e+00, -5.4634e-01,  2.5081e-02, -2.7292e-01,
-        -6.4461e-02, -3.5200e-01, -2.7293e-01, -7.2464e-01,  0.0000e+00,
-        -1.1986e+00, -4.8132e-01,  5.0541e-01,  0.0000e+00, -5.8542e-01,
-        -9.9705e-02, -8.8148e-03,  0.0000e+00, -7.2482e-01,  1.3656e-01,
-         1.1437e+00,  1.3523e-02, -1.1533e-01, -1.7721e-01,  0.0000e+00,
-        -5.5571e-01, -1.6329e+00,  0.0000e+00, -1.5270e-01,  0.0000e+00,
-        -3.4839e-01,  0.0000e+00,  0.0000e+00,  1.0440e+00,  5.0701e-01,
-         2.2123e-01,  3.0777e-01,  7.9101e-03,  0.0000e+00, -1.0720e+00,
-        -2.6580e-01,  3.5074e-01,  2.0285e-01,  9.7042e-01,  0.0000e+00,
-         1.3243e-01, -6.7739e-01, -2.4416e-02,  2.0713e-02,  4.9084e-01,
-         0.0000e+00,  9.7478e-01,  0.0000e+00,  7.1613e-02,  1.2405e+00,
-         0.0000e+00, -7.8198e-04, -8.1368e-02, -3.3318e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.3909e-01,  4.2286e-02, -4.2047e-01,  1.1162e-02,  1.9997e+00,
-         9.6435e-02,  0.0000e+00, -5.4634e-01,  2.5081e-02, -2.7292e-01,
-        -6.4461e-02, -3.5200e-01, -2.7293e-01, -7.2464e-01,  0.0000e+00,
-        -1.1986e+00, -4.8132e-01,  5.0541e-01,  0.0000e+00, -5.8542e-01,
-        -9.9705e-02, -8.8148e-03,  0.0000e+00, -7.2482e-01,  1.3656e-01,
-         1.1437e+00,  1.3523e-02, -1.1533e-01, -1.7721e-01,  0.0000e+00,
-        -5.5571e-01, -1.6329e+00,  0.0000e+00, -1.5270e-01,  0.0000e+00,
-        -3.4839e-01,  0.0000e+00,  0.0000e+00,  1.0440e+00,  5.0701e-01,
-         2.2123e-01,  3.0777e-01,  7.9101e-03,  0.0000e+00, -1.0720e+00,
-        -2.6580e-01,  3.5074e-01,  2.0285e-01,  9.7042e-01,  0.0000e+00,
-         1.3243e-01, -6.7739e-01, -2.4416e-02,  2.0713e-02,  4.9084e-01,
-         0.0000e+00,  9.7478e-01,  0.0000e+00,  7.1613e-02,  1.2405e+00,
-         0.0000e+00, -7.8198e-04, -8.1368e-02, -3.3318e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1839e-01,  4.9494e-02, -4.3702e-01, -9.2403e-02,  2.1408e+00,
-         5.2306e-02,  1.7300e-07, -5.4616e-01,  1.7769e-02, -2.9559e-01,
-        -1.0003e-01, -3.3665e-01, -2.9175e-01, -7.0869e-01, -3.1023e-04,
-        -1.1973e+00, -4.7154e-01,  4.9720e-01, -5.2459e-05, -6.1184e-01,
-        -1.5541e-01, -2.9958e-02, -1.1140e-04, -7.0071e-01,  1.2534e-01,
-         1.1459e+00, -4.0897e-02, -1.3729e-01, -1.3579e-01,  1.6202e-04,
-        -6.1646e-01, -1.6324e+00, -4.1474e-08, -1.4651e-01,  2.0586e-02,
-        -3.7422e-01,  6.8669e-04,  0.0000e+00,  1.0397e+00,  5.0325e-01,
-         1.9470e-01,  2.6690e-01, -1.7530e-02,  1.6900e-04, -1.0752e+00,
-        -2.1830e-01,  3.9894e-01,  1.3599e-01,  9.8972e-01, -1.0759e-05,
-         1.1556e-01, -6.9798e-01,  1.6282e-02, -1.3618e-01,  4.7704e-01,
-         2.2391e-09,  9.7017e-01,  3.1648e-07,  1.2185e-01,  1.2381e+00,
-         8.4540e-07, -9.9388e-02, -1.2542e-01, -3.8339e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3184,  0.0495, -0.4370, -0.0924,  2.1408,  0.0523,  0.0000, -0.5462,
-         0.0178, -0.2956, -0.1000, -0.3367, -0.2917, -0.7087,  0.0000, -1.1973,
-        -0.4715,  0.4972,  0.0000, -0.6118, -0.1554, -0.0300,  0.0000, -0.7007,
-         0.1253,  1.1459, -0.0409, -0.1373, -0.1358,  0.0000, -0.6165, -1.6324,
-         0.0000, -0.1465,  0.0000, -0.3742,  0.0000,  0.0000,  1.0397,  0.5032,
-         0.1947,  0.2669, -0.0175,  0.0000, -1.0752, -0.2183,  0.3989,  0.1360,
-         0.9897,  0.0000,  0.1156, -0.6980,  0.0163, -0.1362,  0.4770,  0.0000,
-         0.9702,  0.0000,  0.1219,  1.2381,  0.0000, -0.0994, -0.1254, -0.3834],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3184,  0.0495, -0.4370, -0.0924,  2.1408,  0.0523,  0.0000, -0.5462,
-         0.0178, -0.2956, -0.1000, -0.3367, -0.2917, -0.7087,  0.0000, -1.1973,
-        -0.4715,  0.4972,  0.0000, -0.6118, -0.1554, -0.0300,  0.0000, -0.7007,
-         0.1253,  1.1459, -0.0409, -0.1373, -0.1358,  0.0000, -0.6165, -1.6324,
-         0.0000, -0.1465,  0.0000, -0.3742,  0.0000,  0.0000,  1.0397,  0.5032,
-         0.1947,  0.2669, -0.0175,  0.0000, -1.0752, -0.2183,  0.3989,  0.1360,
-         0.9897,  0.0000,  0.1156, -0.6980,  0.0163, -0.1362,  0.4770,  0.0000,
-         0.9702,  0.0000,  0.1219,  1.2381,  0.0000, -0.0994, -0.1254, -0.3834],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3846e-01,  8.5701e-02, -3.8076e-01, -1.5493e-01,  2.2583e+00,
-         2.9116e-02,  1.4725e-07, -4.9830e-01,  1.2159e-02, -2.5460e-01,
-        -6.4146e-02, -3.4491e-01, -2.0445e-01, -6.9929e-01, -2.6407e-04,
-        -1.2032e+00, -4.9085e-01,  4.5801e-01, -4.4652e-05, -6.4056e-01,
-        -1.9054e-01,  3.3600e-02, -9.4825e-05, -6.8600e-01,  1.3122e-01,
-         1.1501e+00, -3.9226e-02, -1.5371e-01, -1.2543e-01,  1.3791e-04,
-        -6.6067e-01, -1.6262e+00, -3.5302e-08, -1.4764e-01,  1.7522e-02,
-        -3.6687e-01,  5.8450e-04,  0.0000e+00,  1.0379e+00,  5.1756e-01,
-         1.9051e-01,  2.5289e-01,  1.0016e-02,  1.4385e-04, -1.0831e+00,
-        -1.6463e-01,  4.5604e-01,  1.2719e-01,  1.0012e+00, -9.1579e-06,
-         1.3146e-01, -7.1119e-01, -1.3704e-02, -9.0878e-02,  5.1530e-01,
-         1.9059e-09,  9.6498e-01,  2.6939e-07,  1.6945e-01,  1.2336e+00,
-         7.1960e-07, -1.0596e-01, -9.9868e-02, -3.4194e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3385,  0.0857, -0.3808, -0.1549,  2.2583,  0.0291,  0.0000, -0.4983,
-         0.0122, -0.2546, -0.0641, -0.3449, -0.2044, -0.6993,  0.0000, -1.2032,
-         0.0000,  0.4580,  0.0000, -0.6406, -0.1905,  0.0336,  0.0000, -0.6860,
-         0.1312,  1.1501, -0.0392, -0.1537, -0.1254,  0.0000, -0.6607, -1.6262,
-         0.0000, -0.1476,  0.0000, -0.3669,  0.0000,  0.0000,  1.0379,  0.5176,
-         0.1905,  0.2529,  0.0100,  0.0000, -1.0831, -0.1646,  0.4560,  0.1272,
-         1.0012,  0.0000,  0.1315, -0.7112, -0.0137, -0.0909,  0.5153,  0.0000,
-         0.9650,  0.0000,  0.1694,  1.2336,  0.0000, -0.1060, -0.0999, -0.3419],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3385,  0.0857, -0.3808, -0.1549,  2.2583,  0.0291,  0.0000, -0.4983,
-         0.0122, -0.2546, -0.0641, -0.3449, -0.2044, -0.6993,  0.0000, -1.2032,
-         0.0000,  0.4580,  0.0000, -0.6406, -0.1905,  0.0336,  0.0000, -0.6860,
-         0.1312,  1.1501, -0.0392, -0.1537, -0.1254,  0.0000, -0.6607, -1.6262,
-         0.0000, -0.1476,  0.0000, -0.3669,  0.0000,  0.0000,  1.0379,  0.5176,
-         0.1905,  0.2529,  0.0100,  0.0000, -1.0831, -0.1646,  0.4560,  0.1272,
-         1.0012,  0.0000,  0.1315, -0.7112, -0.0137, -0.0909,  0.5153,  0.0000,
-         0.9650,  0.0000,  0.1694,  1.2336,  0.0000, -0.1060, -0.0999, -0.3419],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6194e-01,  1.1277e-01, -3.0877e-01, -1.9019e-01,  2.3539e+00,
-         5.3199e-02,  1.2535e-07, -4.6421e-01, -2.6619e-02, -2.0506e-01,
-        -3.1760e-03, -3.1940e-01, -8.3008e-02, -7.0362e-01, -2.2480e-04,
-        -1.2106e+00, -1.6435e-02,  4.1444e-01, -3.8012e-05, -6.6979e-01,
-        -2.2993e-01,  1.0003e-01, -8.0723e-05, -7.1550e-01,  1.3287e-01,
-         1.1521e+00, -3.6087e-03, -1.7563e-01, -1.9636e-01,  1.1740e-04,
-        -6.6187e-01, -1.6196e+00, -3.0052e-08, -1.9940e-01,  1.4916e-02,
-        -3.3522e-01,  4.9758e-04,  0.0000e+00,  1.0336e+00,  5.2806e-01,
-         1.7454e-01,  2.4990e-01,  6.4877e-02,  1.2246e-04, -1.0895e+00,
-        -8.0136e-02,  5.0763e-01,  1.4117e-01,  9.9709e-01, -7.7960e-06,
-         1.5494e-01, -7.5460e-01, -2.3992e-02,  1.9358e-02,  5.4570e-01,
-         1.6225e-09,  9.6025e-01,  2.2933e-07,  2.1642e-01,  1.2384e+00,
-         6.1258e-07, -3.7178e-02, -3.9137e-02, -2.5710e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3619,  0.1128, -0.3088, -0.1902,  2.3539,  0.0532,  0.0000, -0.4642,
-        -0.0266, -0.2051, -0.0032, -0.3194, -0.0830, -0.7036,  0.0000, -1.2106,
-         0.0000,  0.4144,  0.0000, -0.6698, -0.2299,  0.1000,  0.0000, -0.7155,
-         0.1329,  1.1521, -0.0036, -0.1756, -0.1964,  0.0000, -0.6619, -1.6196,
-         0.0000, -0.1994,  0.0000, -0.3352,  0.0000,  0.0000,  1.0336,  0.5281,
-         0.1745,  0.2499,  0.0649,  0.0000, -1.0895, -0.0801,  0.5076,  0.1412,
-         0.9971,  0.0000,  0.1549, -0.7546, -0.0240,  0.0194,  0.5457,  0.0000,
-         0.9602,  0.0000,  0.2164,  1.2384,  0.0000, -0.0372, -0.0391, -0.2571],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3619,  0.1128, -0.3088, -0.1902,  2.3539,  0.0532,  0.0000, -0.4642,
-        -0.0266, -0.2051, -0.0032, -0.3194, -0.0830, -0.7036,  0.0000, -1.2106,
-         0.0000,  0.4144,  0.0000, -0.6698, -0.2299,  0.1000,  0.0000, -0.7155,
-         0.1329,  1.1521, -0.0036, -0.1756, -0.1964,  0.0000, -0.6619, -1.6196,
-         0.0000, -0.1994,  0.0000, -0.3352,  0.0000,  0.0000,  1.0336,  0.5281,
-         0.1745,  0.2499,  0.0649,  0.0000, -1.0895, -0.0801,  0.5076,  0.1412,
-         0.9971,  0.0000,  0.1549, -0.7546, -0.0240,  0.0194,  0.5457,  0.0000,
-         0.9602,  0.0000,  0.2164,  1.2384,  0.0000, -0.0372, -0.0391, -0.2571],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6533e-01,  1.6307e-01, -2.5847e-01, -2.2187e-01,  2.4372e+00,
-         4.2116e-02,  1.0673e-07, -4.0954e-01, -3.9911e-02, -2.0979e-01,
-         2.1614e-02, -2.6262e-01,  1.3342e-02, -6.7814e-01, -1.9139e-04,
-        -1.2060e+00, -1.3992e-02,  3.9265e-01, -3.2363e-05, -6.8025e-01,
-        -2.3821e-01,  1.1741e-01, -6.8727e-05, -7.3586e-01,  1.1249e-01,
-         1.1511e+00,  6.4921e-03, -1.7071e-01, -2.6219e-01,  9.9953e-05,
-        -6.6007e-01, -1.6137e+00, -2.5586e-08, -1.9484e-01,  1.2700e-02,
-        -2.9337e-01,  4.2363e-04,  0.0000e+00,  1.0189e+00,  5.2192e-01,
-         1.5995e-01,  2.4062e-01,  1.1596e-01,  1.0426e-04, -1.0960e+00,
-         9.5587e-03,  5.5293e-01,  1.2229e-01,  9.9635e-01, -6.6374e-06,
-         1.7653e-01, -7.8136e-01, -2.5412e-02,  1.0333e-01,  5.6480e-01,
-         1.3813e-09,  9.5769e-01,  1.9524e-07,  2.4322e-01,  1.2378e+00,
-         5.2154e-07,  1.2371e-02,  3.7374e-03, -1.9089e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3653,  0.1631, -0.2585, -0.2219,  2.4372,  0.0421,  0.0000, -0.4095,
-        -0.0399, -0.2098,  0.0216, -0.2626,  0.0133, -0.6781,  0.0000, -1.2060,
-         0.0000,  0.3926,  0.0000, -0.6803, -0.2382,  0.1174,  0.0000, -0.7359,
-         0.1125,  1.1511,  0.0065, -0.1707, -0.2622,  0.0000, -0.6601, -1.6137,
-         0.0000, -0.1948,  0.0000, -0.2934,  0.0000,  0.0000,  1.0189,  0.5219,
-         0.1600,  0.2406,  0.1160,  0.0000, -1.0960,  0.0096,  0.5529,  0.1223,
-         0.9963,  0.0000,  0.1765, -0.7814, -0.0254,  0.1033,  0.5648,  0.0000,
-         0.9577,  0.0000,  0.2432,  1.2378,  0.0000,  0.0124,  0.0037, -0.1909],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3653,  0.1631, -0.2585, -0.2219,  2.4372,  0.0421,  0.0000, -0.4095,
-        -0.0399, -0.2098,  0.0216, -0.2626,  0.0133, -0.6781,  0.0000, -1.2060,
-         0.0000,  0.3926,  0.0000, -0.6803, -0.2382,  0.1174,  0.0000, -0.7359,
-         0.1125,  1.1511,  0.0065, -0.1707, -0.2622,  0.0000, -0.6601, -1.6137,
-         0.0000, -0.1948,  0.0000, -0.2934,  0.0000,  0.0000,  1.0189,  0.5219,
-         0.1600,  0.2406,  0.1160,  0.0000, -1.0960,  0.0096,  0.5529,  0.1223,
-         0.9963,  0.0000,  0.1765, -0.7814, -0.0254,  0.1033,  0.5648,  0.0000,
-         0.9577,  0.0000,  0.2432,  1.2378,  0.0000,  0.0124,  0.0037, -0.1909],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2938e-01,  2.4265e-01, -2.8657e-01, -2.7689e-01,  2.5049e+00,
-        -4.0681e-03,  9.0875e-08, -3.5372e-01, -3.7643e-02, -2.5501e-01,
-         1.1189e-01, -2.3593e-01,  4.9567e-02, -6.4470e-01, -1.6297e-04,
-        -1.2033e+00, -1.1914e-02,  3.6899e-01, -2.7557e-05, -6.8378e-01,
-        -2.5067e-01,  1.6948e-01, -5.8520e-05, -7.2651e-01,  1.4675e-01,
-         1.1527e+00, -4.5062e-02, -1.0511e-01, -2.9978e-01,  8.5109e-05,
-        -6.8064e-01, -1.6080e+00, -2.1786e-08, -1.4734e-01,  1.0814e-02,
-        -2.7368e-01,  3.6072e-04,  0.0000e+00,  1.0006e+00,  4.9126e-01,
-         1.8493e-01,  2.5698e-01,  1.0678e-01,  8.8776e-05, -1.1010e+00,
-         7.0966e-02,  5.9990e-01,  7.6303e-02,  1.0003e+00, -5.6517e-06,
-         2.3084e-01, -7.5335e-01, -5.8585e-03,  1.1081e-01,  5.5174e-01,
-         1.1762e-09,  9.5664e-01,  1.6625e-07,  2.5351e-01,  1.2324e+00,
-         4.4409e-07, -5.3832e-02,  6.6364e-02, -1.8481e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3294,  0.2426, -0.2866, -0.2769,  2.5049, -0.0041,  0.0000, -0.3537,
-        -0.0376, -0.2550,  0.1119, -0.2359,  0.0496, -0.6447,  0.0000, -1.2033,
-         0.0000,  0.3690,  0.0000, -0.6838, -0.2507,  0.1695,  0.0000, -0.7265,
-         0.1467,  1.1527, -0.0451, -0.1051, -0.2998,  0.0000, -0.6806, -1.6080,
-         0.0000, -0.1473,  0.0000, -0.2737,  0.0000,  0.0000,  1.0006,  0.4913,
-         0.1849,  0.2570,  0.1068,  0.0000, -1.1010,  0.0710,  0.5999,  0.0763,
-         1.0003,  0.0000,  0.2308, -0.7533, -0.0059,  0.1108,  0.5517,  0.0000,
-         0.9566,  0.0000,  0.2535,  1.2324,  0.0000, -0.0538,  0.0664, -0.1848],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3294,  0.2426, -0.2866, -0.2769,  2.5049, -0.0041,  0.0000, -0.3537,
-        -0.0376, -0.2550,  0.1119, -0.2359,  0.0496, -0.6447,  0.0000, -1.2033,
-         0.0000,  0.3690,  0.0000, -0.6838, -0.2507,  0.1695,  0.0000, -0.7265,
-         0.1467,  1.1527, -0.0451, -0.1051, -0.2998,  0.0000, -0.6806, -1.6080,
-         0.0000, -0.1473,  0.0000, -0.2737,  0.0000,  0.0000,  1.0006,  0.4913,
-         0.1849,  0.2570,  0.1068,  0.0000, -1.1010,  0.0710,  0.5999,  0.0763,
-         1.0003,  0.0000,  0.2308, -0.7533, -0.0059,  0.1108,  0.5517,  0.0000,
-         0.9566,  0.0000,  0.2535,  1.2324,  0.0000, -0.0538,  0.0664, -0.1848],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8029e-01,  2.9624e-01, -3.3060e-01, -3.3703e-01,  2.5575e+00,
-        -1.9141e-02,  7.7389e-08, -3.5035e-01, -3.0959e-02, -2.8140e-01,
-         1.3617e-01, -2.2096e-01,  4.9927e-02, -6.2027e-01, -1.3878e-04,
-        -1.2006e+00, -1.0146e-02,  3.1638e-01, -2.3467e-05, -6.9843e-01,
-        -2.6589e-01,  1.4600e-01, -4.9836e-05, -6.9897e-01,  1.3408e-01,
-         1.1594e+00, -6.9759e-02, -7.6027e-02, -3.5340e-01,  7.2479e-05,
-        -6.9241e-01, -1.6017e+00, -1.8553e-08, -1.0877e-01,  9.2089e-03,
-        -2.6793e-01,  3.0719e-04,  0.0000e+00,  9.8183e-01,  4.5764e-01,
-         1.4231e-01,  2.5142e-01,  9.7194e-02,  7.5601e-05, -1.1013e+00,
-         1.9355e-01,  6.2879e-01,  3.6543e-02,  9.9136e-01, -4.8129e-06,
-         2.3380e-01, -7.3635e-01,  4.1772e-02,  1.0736e-01,  4.9048e-01,
-         1.0017e-09,  9.5624e-01,  1.4158e-07,  2.4830e-01,  1.2273e+00,
-         3.7819e-07, -8.2222e-02,  5.4294e-02, -2.0486e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2803,  0.2962, -0.3306, -0.3370,  2.5575, -0.0191,  0.0000, -0.3504,
-        -0.0310, -0.2814,  0.1362, -0.2210,  0.0499, -0.6203,  0.0000, -1.2006,
-         0.0000,  0.3164,  0.0000, -0.6984, -0.2659,  0.1460,  0.0000, -0.6990,
-         0.1341,  1.1594, -0.0698, -0.0760, -0.3534,  0.0000, -0.6924, -1.6017,
-         0.0000, -0.1088,  0.0000, -0.2679,  0.0000,  0.0000,  0.9818,  0.4576,
-         0.1423,  0.2514,  0.0972,  0.0000, -1.1013,  0.1936,  0.6288,  0.0365,
-         0.9914,  0.0000,  0.2338, -0.7364,  0.0418,  0.1074,  0.4905,  0.0000,
-         0.0000,  0.0000,  0.2483,  1.2273,  0.0000, -0.0822,  0.0543, -0.2049],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2803,  0.2962, -0.3306, -0.3370,  2.5575, -0.0191,  0.0000, -0.3504,
-        -0.0310, -0.2814,  0.1362, -0.2210,  0.0499, -0.6203,  0.0000, -1.2006,
-         0.0000,  0.3164,  0.0000, -0.6984, -0.2659,  0.1460,  0.0000, -0.6990,
-         0.1341,  1.1594, -0.0698, -0.0760, -0.3534,  0.0000, -0.6924, -1.6017,
-         0.0000, -0.1088,  0.0000, -0.2679,  0.0000,  0.0000,  0.9818,  0.4576,
-         0.1423,  0.2514,  0.0972,  0.0000, -1.1013,  0.1936,  0.6288,  0.0365,
-         0.9914,  0.0000,  0.2338, -0.7364,  0.0418,  0.1074,  0.4905,  0.0000,
-         0.0000,  0.0000,  0.2483,  1.2273,  0.0000, -0.0822,  0.0543, -0.2049],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6082e-01,  3.2753e-01, -3.5603e-01, -3.8333e-01,  2.5988e+00,
-        -9.7076e-03,  6.5913e-08, -3.5006e-01,  2.5504e-02, -2.6638e-01,
-         1.0657e-01, -2.2174e-01, -3.9076e-03, -6.0264e-01, -1.1820e-04,
-        -1.1991e+00, -8.6415e-03,  2.8947e-01, -1.9987e-05, -7.2047e-01,
-        -2.6836e-01,  1.1269e-01, -4.2445e-05, -6.6228e-01,  1.5729e-01,
-         1.1725e+00, -9.2198e-02, -3.4408e-02, -3.6270e-01,  6.1731e-05,
-        -6.8494e-01, -1.5955e+00, -1.5802e-08, -6.1525e-02,  7.8433e-03,
-        -2.8633e-01,  2.6163e-04,  0.0000e+00,  9.5653e-01,  3.9588e-01,
-         7.0901e-02,  2.2389e-01,  8.1896e-02,  6.4390e-05, -1.1007e+00,
-         3.1306e-01,  6.4397e-01, -1.6756e-02,  9.9268e-01, -4.0992e-06,
-         2.0297e-01, -7.0532e-01,  6.1119e-02,  1.6094e-01,  4.0214e-01,
-         8.5312e-10, -3.3529e-04,  1.2058e-07,  2.0217e-01,  1.2169e+00,
-         3.2210e-07, -1.1626e-01,  3.8073e-03, -2.1409e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2608,  0.3275, -0.3560, -0.3833,  2.5988, -0.0097,  0.0000, -0.3501,
-         0.0255, -0.2664,  0.1066, -0.2217, -0.0039, -0.6026,  0.0000, -1.1991,
-         0.0000,  0.2895,  0.0000, -0.7205, -0.2684,  0.1127,  0.0000, -0.6623,
-         0.1573,  1.1725, -0.0922, -0.0344, -0.3627,  0.0000, -0.6849, -1.5955,
-         0.0000, -0.0615,  0.0000, -0.2863,  0.0000,  0.0000,  0.9565,  0.3959,
-         0.0709,  0.2239,  0.0819,  0.0000, -1.1007,  0.3131,  0.6440, -0.0168,
-         0.9927,  0.0000,  0.2030, -0.7053,  0.0611,  0.1609,  0.4021,  0.0000,
-         0.0000,  0.0000,  0.2022,  1.2169,  0.0000, -0.1163,  0.0038, -0.2141],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2608,  0.3275, -0.3560, -0.3833,  2.5988, -0.0097,  0.0000, -0.3501,
-         0.0255, -0.2664,  0.1066, -0.2217, -0.0039, -0.6026,  0.0000, -1.1991,
-         0.0000,  0.2895,  0.0000, -0.7205, -0.2684,  0.1127,  0.0000, -0.6623,
-         0.1573,  1.1725, -0.0922, -0.0344, -0.3627,  0.0000, -0.6849, -1.5955,
-         0.0000, -0.0615,  0.0000, -0.2863,  0.0000,  0.0000,  0.9565,  0.3959,
-         0.0709,  0.2239,  0.0819,  0.0000, -1.1007,  0.3131,  0.6440, -0.0168,
-         0.9927,  0.0000,  0.2030, -0.7053,  0.0611,  0.1609,  0.4021,  0.0000,
-         0.0000,  0.0000,  0.2022,  1.2169,  0.0000, -0.1163,  0.0038, -0.2141],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8554e-01,  3.2674e-01, -3.5165e-01, -4.3247e-01,  2.6344e+00,
-         1.6420e-02,  5.6146e-08, -3.5790e-01,  1.2655e-03, -2.2194e-01,
-         8.4136e-02, -1.8196e-01,  1.9811e-02, -6.0217e-01, -1.0069e-04,
-        -1.2045e+00, -7.3610e-03,  2.5757e-01, -1.7026e-05, -7.2488e-01,
-        -2.6226e-01,  6.7459e-02, -3.6156e-05, -6.5102e-01,  1.6992e-01,
-         1.1790e+00, -8.0194e-02,  5.8209e-03, -3.8123e-01,  5.2584e-05,
-        -6.8024e-01, -1.5894e+00, -1.3460e-08, -9.2689e-02,  6.6811e-03,
-        -2.7350e-01,  2.2287e-04,  0.0000e+00,  9.5817e-01,  4.2480e-01,
-         2.3028e-03,  1.9750e-01,  3.5224e-02,  5.4849e-05, -1.0994e+00,
-         3.6369e-01,  6.5611e-01, -1.1833e-03,  9.9217e-01, -3.4918e-06,
-         1.5225e-01, -6.7281e-01,  3.8428e-02,  1.1926e-01,  3.8971e-01,
-         7.2670e-10, -2.8561e-04,  1.0271e-07,  1.6975e-01,  1.2041e+00,
-         2.7438e-07, -1.6735e-01, -1.6385e-02, -1.9701e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.8554e-01,  3.2674e-01, -3.5165e-01, -4.3247e-01,  2.6344e+00,
-         1.6420e-02,  0.0000e+00, -3.5790e-01,  1.2655e-03, -2.2194e-01,
-         8.4136e-02, -1.8196e-01,  1.9811e-02, -6.0217e-01,  0.0000e+00,
-        -1.2045e+00,  0.0000e+00,  2.5757e-01,  0.0000e+00, -7.2488e-01,
-        -2.6226e-01,  6.7459e-02,  0.0000e+00, -6.5102e-01,  1.6992e-01,
-         1.1790e+00, -8.0194e-02,  5.8209e-03, -3.8123e-01,  0.0000e+00,
-        -6.8024e-01, -1.5894e+00,  0.0000e+00, -9.2689e-02,  0.0000e+00,
-        -2.7350e-01,  0.0000e+00,  0.0000e+00,  9.5817e-01,  4.2480e-01,
-         2.3028e-03,  1.9750e-01,  3.5224e-02,  0.0000e+00, -1.0994e+00,
-         3.6369e-01,  6.5611e-01, -1.1833e-03,  9.9217e-01,  0.0000e+00,
-         1.5225e-01, -6.7281e-01,  3.8428e-02,  1.1926e-01,  3.8971e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6975e-01,  1.2041e+00,
-         0.0000e+00, -1.6735e-01, -1.6385e-02, -1.9701e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.8554e-01,  3.2674e-01, -3.5165e-01, -4.3247e-01,  2.6344e+00,
-         1.6420e-02,  0.0000e+00, -3.5790e-01,  1.2655e-03, -2.2194e-01,
-         8.4136e-02, -1.8196e-01,  1.9811e-02, -6.0217e-01,  0.0000e+00,
-        -1.2045e+00,  0.0000e+00,  2.5757e-01,  0.0000e+00, -7.2488e-01,
-        -2.6226e-01,  6.7459e-02,  0.0000e+00, -6.5102e-01,  1.6992e-01,
-         1.1790e+00, -8.0194e-02,  5.8209e-03, -3.8123e-01,  0.0000e+00,
-        -6.8024e-01, -1.5894e+00,  0.0000e+00, -9.2689e-02,  0.0000e+00,
-        -2.7350e-01,  0.0000e+00,  0.0000e+00,  9.5817e-01,  4.2480e-01,
-         2.3028e-03,  1.9750e-01,  3.5224e-02,  0.0000e+00, -1.0994e+00,
-         3.6369e-01,  6.5611e-01, -1.1833e-03,  9.9217e-01,  0.0000e+00,
-         1.5225e-01, -6.7281e-01,  3.8428e-02,  1.1926e-01,  3.8971e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6975e-01,  1.2041e+00,
-         0.0000e+00, -1.6735e-01, -1.6385e-02, -1.9701e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3586e-01,  2.8995e-01, -3.1029e-01, -4.6964e-01,  2.6653e+00,
-         6.6093e-02,  4.7833e-08, -3.4545e-01, -3.5952e-02, -1.4463e-01,
-         1.1457e-01, -8.3715e-02,  1.3084e-01, -5.9214e-01, -8.5779e-05,
-        -1.2053e+00, -6.2712e-03,  2.5915e-01, -1.4505e-05, -7.2159e-01,
-        -2.3406e-01, -8.8243e-03, -3.0803e-05, -6.4234e-01,  1.3680e-01,
-         1.1795e+00, -5.6621e-02,  1.6773e-02, -3.7962e-01,  4.4798e-05,
-        -6.8693e-01, -1.5827e+00, -1.1467e-08, -1.7331e-01,  5.6919e-03,
-        -2.3338e-01,  1.8987e-04,  0.0000e+00,  9.5467e-01,  4.6410e-01,
-        -5.6642e-02,  1.7913e-01,  1.0881e-02,  4.6728e-05, -1.0976e+00,
-         3.8881e-01,  6.7180e-01,  3.9993e-02,  9.9133e-01, -2.9748e-06,
-         1.2660e-01, -6.2717e-01,  2.4263e-02,  5.3831e-02,  3.9157e-01,
-         6.1911e-10, -2.4332e-04,  8.7507e-08,  1.0545e-01,  1.1941e+00,
-         2.3375e-07, -1.8611e-01, -1.6598e-02, -1.4129e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3359,  0.2900, -0.3103, -0.4696,  2.6653,  0.0661,  0.0000, -0.3455,
-        -0.0360, -0.1446,  0.1146, -0.0837,  0.1308, -0.5921,  0.0000, -1.2053,
-         0.0000,  0.2592,  0.0000, -0.7216, -0.2341, -0.0088,  0.0000, -0.6423,
-         0.1368,  1.1795, -0.0566,  0.0168, -0.3796,  0.0000, -0.6869, -1.5827,
-         0.0000, -0.1733,  0.0000, -0.2334,  0.0000,  0.0000,  0.9547,  0.4641,
-        -0.0566,  0.1791,  0.0109,  0.0000, -1.0976,  0.3888,  0.6718,  0.0400,
-         0.9913,  0.0000,  0.1266, -0.6272,  0.0243,  0.0538,  0.3916,  0.0000,
-         0.0000,  0.0000,  0.1055,  1.1941,  0.0000, -0.1861, -0.0166, -0.1413],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3359,  0.2900, -0.3103, -0.4696,  2.6653,  0.0661,  0.0000, -0.3455,
-        -0.0360, -0.1446,  0.1146, -0.0837,  0.1308, -0.5921,  0.0000, -1.2053,
-         0.0000,  0.2592,  0.0000, -0.7216, -0.2341, -0.0088,  0.0000, -0.6423,
-         0.1368,  1.1795, -0.0566,  0.0168, -0.3796,  0.0000, -0.6869, -1.5827,
-         0.0000, -0.1733,  0.0000, -0.2334,  0.0000,  0.0000,  0.9547,  0.4641,
-        -0.0566,  0.1791,  0.0109,  0.0000, -1.0976,  0.3888,  0.6718,  0.0400,
-         0.9913,  0.0000,  0.1266, -0.6272,  0.0243,  0.0538,  0.3916,  0.0000,
-         0.0000,  0.0000,  0.1055,  1.1941,  0.0000, -0.1861, -0.0166, -0.1413],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7629e-01,  2.5589e-01, -2.6619e-01, -5.1612e-01,  2.6927e+00,
-         1.1796e-01,  4.0757e-08, -3.0627e-01, -9.3462e-02, -5.6244e-02,
-         2.3380e-01,  3.7126e-02,  2.4282e-01, -5.9189e-01, -7.3089e-05,
-        -1.2102e+00, -5.3434e-03,  2.3564e-01, -1.2359e-05, -7.1261e-01,
-        -2.0341e-01, -9.1648e-02, -2.6246e-05, -6.2226e-01,  8.4062e-02,
-         1.1810e+00, -3.8947e-02, -7.0219e-03, -3.7447e-01,  3.8171e-05,
-        -6.9077e-01, -1.5769e+00, -9.7710e-09, -2.7326e-01,  4.8498e-03,
-        -1.6254e-01,  1.6178e-04,  0.0000e+00,  9.6262e-01,  5.4382e-01,
-        -1.0668e-01,  1.5753e-01, -1.6687e-02,  3.9815e-05, -1.0889e+00,
-         4.0298e-01,  6.8975e-01,  8.8506e-02,  9.8066e-01, -2.5347e-06,
-         1.4641e-01, -5.9084e-01,  4.2961e-02, -7.3256e-02,  4.1980e-01,
-         5.2752e-10, -2.0732e-04,  7.4561e-08,  1.0493e-02,  1.1820e+00,
-         1.9917e-07, -2.1601e-01,  1.3343e-03, -7.4563e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.7629e-01,  2.5589e-01, -2.6619e-01, -5.1612e-01,  2.6927e+00,
-         1.1796e-01,  0.0000e+00, -3.0627e-01, -9.3462e-02, -5.6244e-02,
-         2.3380e-01,  3.7126e-02,  2.4282e-01, -5.9189e-01,  0.0000e+00,
-        -1.2102e+00,  0.0000e+00,  2.3564e-01,  0.0000e+00, -7.1261e-01,
-        -2.0341e-01, -9.1648e-02,  0.0000e+00, -6.2226e-01,  8.4062e-02,
-         1.1810e+00, -3.8947e-02, -7.0219e-03, -3.7447e-01,  0.0000e+00,
-        -6.9077e-01, -1.5769e+00,  0.0000e+00, -2.7326e-01,  0.0000e+00,
-        -1.6254e-01,  0.0000e+00,  0.0000e+00,  9.6262e-01,  5.4382e-01,
-        -1.0668e-01,  0.0000e+00, -1.6687e-02,  0.0000e+00, -1.0889e+00,
-         4.0298e-01,  6.8975e-01,  8.8506e-02,  9.8066e-01,  0.0000e+00,
-         1.4641e-01, -5.9084e-01,  4.2961e-02, -7.3256e-02,  4.1980e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0493e-02,  1.1820e+00,
-         0.0000e+00, -2.1601e-01,  1.3343e-03, -7.4563e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.7629e-01,  2.5589e-01, -2.6619e-01, -5.1612e-01,  2.6927e+00,
-         1.1796e-01,  0.0000e+00, -3.0627e-01, -9.3462e-02, -5.6244e-02,
-         2.3380e-01,  3.7126e-02,  2.4282e-01, -5.9189e-01,  0.0000e+00,
-        -1.2102e+00,  0.0000e+00,  2.3564e-01,  0.0000e+00, -7.1261e-01,
-        -2.0341e-01, -9.1648e-02,  0.0000e+00, -6.2226e-01,  8.4062e-02,
-         1.1810e+00, -3.8947e-02, -7.0219e-03, -3.7447e-01,  0.0000e+00,
-        -6.9077e-01, -1.5769e+00,  0.0000e+00, -2.7326e-01,  0.0000e+00,
-        -1.6254e-01,  0.0000e+00,  0.0000e+00,  9.6262e-01,  5.4382e-01,
-        -1.0668e-01,  0.0000e+00, -1.6687e-02,  0.0000e+00, -1.0889e+00,
-         4.0298e-01,  6.8975e-01,  8.8506e-02,  9.8066e-01,  0.0000e+00,
-         1.4641e-01, -5.9084e-01,  4.2961e-02, -7.3256e-02,  4.1980e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0493e-02,  1.1820e+00,
-         0.0000e+00, -2.1601e-01,  1.3343e-03, -7.4563e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9597e-01,  2.9150e-01, -2.0785e-01, -5.1453e-01,  2.7153e+00,
-         2.6733e-01,  3.4733e-08, -2.4367e-01, -8.0091e-02,  5.8067e-02,
-         3.2736e-01,  8.9504e-02,  2.8826e-01, -5.8848e-01, -6.2285e-05,
-        -1.2186e+00, -4.5536e-03,  2.2786e-01, -1.0532e-05, -7.1105e-01,
-        -1.4868e-01,  2.3451e-02, -2.2366e-05, -5.9915e-01,  1.8202e-01,
-         1.1816e+00,  3.2583e-03,  5.1164e-02, -3.4627e-01,  3.2529e-05,
-        -7.0374e-01, -1.5713e+00, -8.3267e-09, -2.3613e-01,  4.1330e-03,
-        -1.5218e-01,  1.3787e-04,  0.0000e+00,  9.5472e-01,  5.7578e-01,
-        -2.2663e-02, -1.8414e-02,  4.4136e-03,  3.3930e-05, -1.0809e+00,
-         3.1290e-01,  7.0048e-01,  9.6657e-02,  9.8773e-01, -2.1601e-06,
-         1.9494e-01, -4.8882e-01,  2.4142e-02,  5.7301e-02,  4.1918e-01,
-         4.4955e-10, -1.7668e-04,  6.3540e-08, -5.7581e-04,  1.1687e+00,
-         1.6973e-07, -1.9104e-01,  5.2621e-02,  3.9439e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.9597e-01,  2.9150e-01, -2.0785e-01, -5.1453e-01,  2.7153e+00,
-         2.6733e-01,  0.0000e+00, -2.4367e-01, -8.0091e-02,  5.8067e-02,
-         3.2736e-01,  8.9504e-02,  2.8826e-01, -5.8848e-01,  0.0000e+00,
-        -1.2186e+00,  0.0000e+00,  2.2786e-01,  0.0000e+00, -7.1105e-01,
-        -1.4868e-01,  2.3451e-02,  0.0000e+00, -5.9915e-01,  1.8202e-01,
-         1.1816e+00,  3.2583e-03,  5.1164e-02, -3.4627e-01,  0.0000e+00,
-        -7.0374e-01, -1.5713e+00,  0.0000e+00, -2.3613e-01,  0.0000e+00,
-        -1.5218e-01,  0.0000e+00,  0.0000e+00,  9.5472e-01,  5.7578e-01,
-        -2.2663e-02,  0.0000e+00,  4.4136e-03,  0.0000e+00, -1.0809e+00,
-         3.1290e-01,  7.0048e-01,  9.6657e-02,  9.8773e-01,  0.0000e+00,
-         1.9494e-01, -4.8882e-01,  2.4142e-02,  5.7301e-02,  4.1918e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -5.7581e-04,  1.1687e+00,
-         0.0000e+00, -1.9104e-01,  5.2621e-02,  3.9439e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.9597e-01,  2.9150e-01, -2.0785e-01, -5.1453e-01,  2.7153e+00,
-         2.6733e-01,  0.0000e+00, -2.4367e-01, -8.0091e-02,  5.8067e-02,
-         3.2736e-01,  8.9504e-02,  2.8826e-01, -5.8848e-01,  0.0000e+00,
-        -1.2186e+00,  0.0000e+00,  2.2786e-01,  0.0000e+00, -7.1105e-01,
-        -1.4868e-01,  2.3451e-02,  0.0000e+00, -5.9915e-01,  1.8202e-01,
-         1.1816e+00,  3.2583e-03,  5.1164e-02, -3.4627e-01,  0.0000e+00,
-        -7.0374e-01, -1.5713e+00,  0.0000e+00, -2.3613e-01,  0.0000e+00,
-        -1.5218e-01,  0.0000e+00,  0.0000e+00,  9.5472e-01,  5.7578e-01,
-        -2.2663e-02,  0.0000e+00,  4.4136e-03,  0.0000e+00, -1.0809e+00,
-         3.1290e-01,  7.0048e-01,  9.6657e-02,  9.8773e-01,  0.0000e+00,
-         1.9494e-01, -4.8882e-01,  2.4142e-02,  5.7301e-02,  4.1918e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -5.7581e-04,  1.1687e+00,
-         0.0000e+00, -1.9104e-01,  5.2621e-02,  3.9439e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7466e-01,  3.6420e-01, -1.3762e-01, -4.8241e-01,  2.7323e+00,
-         4.3250e-01,  2.9603e-08, -1.8445e-01, -2.2609e-02,  1.4945e-01,
-         4.4482e-01,  9.2314e-02,  2.7332e-01, -5.6656e-01, -5.3087e-05,
-        -1.2279e+00, -3.8811e-03,  2.7367e-01, -8.9767e-06, -7.0644e-01,
-        -9.7683e-02,  2.0208e-01, -1.9063e-05, -5.6923e-01,  3.5015e-01,
-         1.1801e+00,  3.8817e-02,  1.3571e-01, -3.1424e-01,  2.7725e-05,
-        -6.9715e-01, -1.5675e+00, -7.0970e-09, -1.6069e-01,  3.5226e-03,
-        -1.4147e-01,  1.1751e-04,  0.0000e+00,  9.1792e-01,  5.3769e-01,
-         6.4764e-02, -1.5695e-02,  5.3009e-02,  2.8919e-05, -1.0720e+00,
-         1.7646e-01,  7.0780e-01,  5.5209e-02,  1.0160e+00, -1.8411e-06,
-         2.3260e-01, -3.8116e-01,  5.2245e-03,  2.8750e-01,  3.4297e-01,
-         3.8315e-10, -1.5059e-04,  5.4156e-08,  3.4867e-02,  1.1618e+00,
-         1.4466e-07, -8.3979e-02,  1.0529e-01,  1.3709e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3747,  0.3642, -0.1376, -0.4824,  2.7323,  0.4325,  0.0000, -0.1845,
-        -0.0226,  0.1494,  0.4448,  0.0923,  0.2733, -0.5666,  0.0000, -1.2279,
-         0.0000,  0.2737,  0.0000, -0.7064, -0.0977,  0.2021,  0.0000, -0.5692,
-         0.3501,  1.1801,  0.0388,  0.1357, -0.3142,  0.0000, -0.6972, -1.5675,
-         0.0000, -0.1607,  0.0000, -0.1415,  0.0000,  0.0000,  0.9179,  0.5377,
-         0.0648,  0.0000,  0.0530,  0.0000, -1.0720,  0.1765,  0.7078,  0.0552,
-         1.0160,  0.0000,  0.2326, -0.3812,  0.0052,  0.2875,  0.3430,  0.0000,
-         0.0000,  0.0000,  0.0349,  1.1618,  0.0000, -0.0840,  0.1053,  0.1371],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3747,  0.3642, -0.1376, -0.4824,  2.7323,  0.4325,  0.0000, -0.1845,
-        -0.0226,  0.1494,  0.4448,  0.0923,  0.2733, -0.5666,  0.0000, -1.2279,
-         0.0000,  0.2737,  0.0000, -0.7064, -0.0977,  0.2021,  0.0000, -0.5692,
-         0.3501,  1.1801,  0.0388,  0.1357, -0.3142,  0.0000, -0.6972, -1.5675,
-         0.0000, -0.1607,  0.0000, -0.1415,  0.0000,  0.0000,  0.9179,  0.5377,
-         0.0648,  0.0000,  0.0530,  0.0000, -1.0720,  0.1765,  0.7078,  0.0552,
-         1.0160,  0.0000,  0.2326, -0.3812,  0.0052,  0.2875,  0.3430,  0.0000,
-         0.0000,  0.0000,  0.0349,  1.1618,  0.0000, -0.0840,  0.1053,  0.1371],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4041e-01,  4.2409e-01, -1.1027e-01, -4.6686e-01,  2.7436e+00,
-         4.4869e-01,  2.5235e-08, -1.3507e-01,  2.9859e-02,  1.2966e-01,
-         3.9204e-01,  9.2712e-02,  1.2342e-01, -5.3730e-01, -4.5254e-05,
-        -1.2375e+00, -3.3084e-03,  1.3596e-01, -7.6521e-06, -7.0240e-01,
-        -3.1623e-02,  2.7886e-01, -1.6250e-05, -5.2852e-01,  4.6081e-01,
-         1.1898e+00, -1.1048e-01,  1.6364e-01, -2.1910e-01,  2.3634e-05,
-        -6.7570e-01, -1.5635e+00, -6.0498e-09, -1.0816e-01,  3.0028e-03,
-        -9.9207e-02,  1.0017e-04,  0.0000e+00,  9.2648e-01,  5.9719e-01,
-         1.6802e-01, -1.3379e-02, -7.2615e-02,  2.4652e-05, -1.0622e+00,
-        -6.6422e-02,  7.1615e-01,  3.3434e-02,  1.0383e+00, -1.5694e-06,
-         3.2607e-01, -2.6873e-01,  3.5969e-02,  2.0436e-01,  4.6561e-01,
-         3.2662e-10, -1.2837e-04,  4.6165e-08,  1.1141e-01,  1.1541e+00,
-         1.2332e-07, -1.3430e-01,  9.7005e-02, -6.4754e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3404,  0.4241, -0.1103, -0.4669,  2.7436,  0.4487,  0.0000, -0.1351,
-         0.0299,  0.1297,  0.3920,  0.0927,  0.1234, -0.5373,  0.0000, -1.2375,
-         0.0000,  0.1360,  0.0000, -0.7024, -0.0316,  0.2789,  0.0000, -0.5285,
-         0.4608,  1.1898, -0.1105,  0.1636, -0.2191,  0.0000, -0.6757, -1.5635,
-         0.0000, -0.1082,  0.0000, -0.0992,  0.0000,  0.0000,  0.9265,  0.5972,
-         0.1680,  0.0000, -0.0726,  0.0000, -1.0622, -0.0664,  0.7162,  0.0334,
-         1.0383,  0.0000,  0.3261, -0.2687,  0.0360,  0.2044,  0.4656,  0.0000,
-         0.0000,  0.0000,  0.1114,  1.1541,  0.0000, -0.1343,  0.0970, -0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3404,  0.4241, -0.1103, -0.4669,  2.7436,  0.4487,  0.0000, -0.1351,
-         0.0299,  0.1297,  0.3920,  0.0927,  0.1234, -0.5373,  0.0000, -1.2375,
-         0.0000,  0.1360,  0.0000, -0.7024, -0.0316,  0.2789,  0.0000, -0.5285,
-         0.4608,  1.1898, -0.1105,  0.1636, -0.2191,  0.0000, -0.6757, -1.5635,
-         0.0000, -0.1082,  0.0000, -0.0992,  0.0000,  0.0000,  0.9265,  0.5972,
-         0.1680,  0.0000, -0.0726,  0.0000, -1.0622, -0.0664,  0.7162,  0.0334,
-         1.0383,  0.0000,  0.3261, -0.2687,  0.0360,  0.2044,  0.4656,  0.0000,
-         0.0000,  0.0000,  0.1114,  1.1541,  0.0000, -0.1343,  0.0970, -0.0648],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6143e-01,  4.1953e-01, -6.0128e-02, -4.6700e-01,  2.7505e+00,
-         4.3393e-01,  2.1515e-08, -1.7290e-01,  1.7515e-02,  8.8390e-02,
-         2.6861e-01,  1.1855e-01,  3.5438e-02, -5.3719e-01, -3.8582e-05,
-        -1.2416e+00, -2.8207e-03, -4.3583e-02, -6.5241e-06, -6.9885e-01,
-         1.4810e-02,  1.7807e-01, -1.3855e-05, -5.3262e-01,  4.4661e-01,
-         1.1984e+00, -2.3884e-01,  8.4848e-02, -1.8175e-01,  2.0150e-05,
-        -6.5850e-01, -1.5584e+00, -5.1579e-09, -1.1465e-01,  2.5601e-03,
-        -3.2203e-02,  8.5400e-05,  0.0000e+00,  9.5603e-01,  7.1225e-01,
-         1.1748e-01, -1.1407e-02, -1.9632e-01,  2.1018e-05, -1.0570e+00,
-        -2.3788e-01,  7.1901e-01,  3.3129e-02,  1.0428e+00, -1.3380e-06,
-         3.5566e-01, -2.4345e-01,  2.0594e-02,  1.3216e-02,  6.0806e-01,
-         2.7847e-10, -1.0944e-04,  3.9360e-08,  1.6219e-01,  1.1468e+00,
-         1.0514e-07, -1.5184e-01,  3.2521e-02, -2.2179e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3614,  0.4195, -0.0601, -0.4670,  2.7505,  0.4339,  0.0000, -0.1729,
-         0.0175,  0.0884,  0.2686,  0.1185,  0.0354, -0.5372,  0.0000, -1.2416,
-         0.0000, -0.0436,  0.0000, -0.6989,  0.0148,  0.1781,  0.0000, -0.5326,
-         0.4466,  1.1984, -0.2388,  0.0848, -0.1818,  0.0000, -0.6585, -1.5584,
-         0.0000, -0.1147,  0.0000, -0.0322,  0.0000,  0.0000,  0.9560,  0.7123,
-         0.1175,  0.0000, -0.1963,  0.0000,  0.0000, -0.2379,  0.7190,  0.0331,
-         1.0428,  0.0000,  0.3557, -0.2435,  0.0206,  0.0132,  0.6081,  0.0000,
-         0.0000,  0.0000,  0.1622,  1.1468,  0.0000, -0.1518,  0.0325, -0.2218],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3614,  0.4195, -0.0601, -0.4670,  2.7505,  0.4339,  0.0000, -0.1729,
-         0.0175,  0.0884,  0.2686,  0.1185,  0.0354, -0.5372,  0.0000, -1.2416,
-         0.0000, -0.0436,  0.0000, -0.6989,  0.0148,  0.1781,  0.0000, -0.5326,
-         0.4466,  1.1984, -0.2388,  0.0848, -0.1818,  0.0000, -0.6585, -1.5584,
-         0.0000, -0.1147,  0.0000, -0.0322,  0.0000,  0.0000,  0.9560,  0.7123,
-         0.1175,  0.0000, -0.1963,  0.0000,  0.0000, -0.2379,  0.7190,  0.0331,
-         1.0428,  0.0000,  0.3557, -0.2435,  0.0206,  0.0132,  0.6081,  0.0000,
-         0.0000,  0.0000,  0.1622,  1.1468,  0.0000, -0.1518,  0.0325, -0.2218],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1297e-01,  3.9638e-01,  3.2786e-02, -4.9246e-01,  2.7602e+00,
-         4.0556e-01,  1.8346e-08, -1.9282e-01, -1.6123e-02,  6.9374e-02,
-         1.2359e-01,  1.4276e-01,  3.6331e-02, -5.2735e-01, -3.2900e-05,
-        -1.2450e+00, -2.4053e-03, -2.1134e-01, -5.5632e-06, -6.9071e-01,
-         1.7863e-02, -4.3029e-04, -1.1814e-05, -5.6250e-01,  3.5425e-01,
-         1.2000e+00, -3.2394e-01, -1.9351e-02, -2.3215e-01,  1.7182e-05,
-        -6.4611e-01, -1.5547e+00, -4.3982e-09, -1.6431e-01,  2.1831e-03,
-         2.1611e-02,  7.2822e-05,  0.0000e+00,  9.6652e-01,  8.0444e-01,
-        -2.8414e-03, -9.7267e-03, -2.9018e-01,  1.7922e-05,  4.3836e-03,
-        -3.9751e-01,  7.1886e-01,  4.5415e-02,  1.0505e+00, -1.1410e-06,
-         3.4028e-01, -2.6347e-01, -3.1480e-02, -1.6453e-01,  7.0453e-01,
-         2.3745e-10, -9.3324e-05,  3.3563e-08,  2.4443e-01,  1.1372e+00,
-         8.9654e-08, -1.2922e-01, -4.2801e-02, -3.5938e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.1297e-01,  3.9638e-01,  3.2786e-02, -4.9246e-01,  2.7602e+00,
-         4.0556e-01,  0.0000e+00, -1.9282e-01, -1.6123e-02,  6.9374e-02,
-         1.2359e-01,  1.4276e-01,  3.6331e-02, -5.2735e-01,  0.0000e+00,
-        -1.2450e+00,  0.0000e+00, -2.1134e-01,  0.0000e+00, -6.9071e-01,
-         1.7863e-02, -4.3029e-04,  0.0000e+00, -5.6250e-01,  3.5425e-01,
-         1.2000e+00, -3.2394e-01, -1.9351e-02, -2.3215e-01,  0.0000e+00,
-        -6.4611e-01, -1.5547e+00,  0.0000e+00, -1.6431e-01,  0.0000e+00,
-         2.1611e-02,  0.0000e+00,  0.0000e+00,  9.6652e-01,  8.0444e-01,
-        -2.8414e-03,  0.0000e+00, -2.9018e-01,  0.0000e+00,  0.0000e+00,
-        -3.9751e-01,  7.1886e-01,  4.5415e-02,  1.0505e+00,  0.0000e+00,
-         3.4028e-01, -2.6347e-01, -3.1480e-02, -1.6453e-01,  7.0453e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.4443e-01,  1.1372e+00,
-         0.0000e+00, -1.2922e-01, -4.2801e-02, -3.5938e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.1297e-01,  3.9638e-01,  3.2786e-02, -4.9246e-01,  2.7602e+00,
-         4.0556e-01,  0.0000e+00, -1.9282e-01, -1.6123e-02,  6.9374e-02,
-         1.2359e-01,  1.4276e-01,  3.6331e-02, -5.2735e-01,  0.0000e+00,
-        -1.2450e+00,  0.0000e+00, -2.1134e-01,  0.0000e+00, -6.9071e-01,
-         1.7863e-02, -4.3029e-04,  0.0000e+00, -5.6250e-01,  3.5425e-01,
-         1.2000e+00, -3.2394e-01, -1.9351e-02, -2.3215e-01,  0.0000e+00,
-        -6.4611e-01, -1.5547e+00,  0.0000e+00, -1.6431e-01,  0.0000e+00,
-         2.1611e-02,  0.0000e+00,  0.0000e+00,  9.6652e-01,  8.0444e-01,
-        -2.8414e-03,  0.0000e+00, -2.9018e-01,  0.0000e+00,  0.0000e+00,
-        -3.9751e-01,  7.1886e-01,  4.5415e-02,  1.0505e+00,  0.0000e+00,
-         3.4028e-01, -2.6347e-01, -3.1480e-02, -1.6453e-01,  7.0453e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.4443e-01,  1.1372e+00,
-         0.0000e+00, -1.2922e-01, -4.2801e-02, -3.5938e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6937e-01,  3.6499e-01,  1.3808e-01, -4.8514e-01,  2.7673e+00,
-         3.7000e-01,  1.5647e-08, -2.8914e-01, -1.2178e-01,  8.0286e-03,
-         8.1254e-03,  1.4301e-01,  8.6609e-02, -5.2454e-01, -2.8059e-05,
-        -1.2568e+00, -2.0513e-03, -2.4853e-01, -4.7446e-06, -7.1146e-01,
-        -4.2529e-02, -1.6778e-01, -1.0076e-05, -5.8546e-01,  2.4426e-01,
-         1.1884e+00, -3.4728e-01, -1.3313e-01, -3.1291e-01,  1.4654e-05,
-        -6.2628e-01, -1.5486e+00, -3.7511e-09, -1.7319e-01,  1.8619e-03,
-         2.0965e-02,  6.2107e-05,  0.0000e+00,  9.5427e-01,  8.6896e-01,
-        -8.2772e-02, -8.2955e-03, -2.7866e-01,  1.5285e-05,  3.7386e-03,
-        -4.4846e-01,  7.1678e-01,  2.7530e-02,  1.0587e+00, -9.7309e-07,
-         3.4633e-01, -3.3524e-01, -1.0180e-01, -6.8606e-02,  7.7251e-01,
-         2.0252e-10, -7.9592e-05,  2.8624e-08,  3.1954e-01,  1.1377e+00,
-         7.6462e-08,  2.3156e-02, -6.1328e-02, -4.5453e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4694,  0.3650,  0.1381, -0.4851,  2.7673,  0.3700,  0.0000, -0.2891,
-        -0.1218,  0.0080,  0.0081,  0.1430,  0.0866, -0.5245,  0.0000, -1.2568,
-         0.0000, -0.2485,  0.0000, -0.7115, -0.0425, -0.1678,  0.0000, -0.5855,
-         0.2443,  1.1884, -0.3473, -0.1331, -0.3129,  0.0000, -0.6263, -1.5486,
-         0.0000, -0.1732,  0.0000,  0.0210,  0.0000,  0.0000,  0.9543,  0.8690,
-        -0.0828,  0.0000, -0.2787,  0.0000,  0.0000, -0.4485,  0.7168,  0.0275,
-         1.0587,  0.0000,  0.3463, -0.3352, -0.1018, -0.0686,  0.7725,  0.0000,
-         0.0000,  0.0000,  0.3195,  1.1377,  0.0000,  0.0232, -0.0613, -0.4545],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4694,  0.3650,  0.1381, -0.4851,  2.7673,  0.3700,  0.0000, -0.2891,
-        -0.1218,  0.0080,  0.0081,  0.1430,  0.0866, -0.5245,  0.0000, -1.2568,
-         0.0000, -0.2485,  0.0000, -0.7115, -0.0425, -0.1678,  0.0000, -0.5855,
-         0.2443,  1.1884, -0.3473, -0.1331, -0.3129,  0.0000, -0.6263, -1.5486,
-         0.0000, -0.1732,  0.0000,  0.0210,  0.0000,  0.0000,  0.9543,  0.8690,
-        -0.0828,  0.0000, -0.2787,  0.0000,  0.0000, -0.4485,  0.7168,  0.0275,
-         1.0587,  0.0000,  0.3463, -0.3352, -0.1018, -0.0686,  0.7725,  0.0000,
-         0.0000,  0.0000,  0.3195,  1.1377,  0.0000,  0.0232, -0.0613, -0.4545],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9422e-01,  3.6734e-01,  2.0589e-01, -4.7460e-01,  2.7732e+00,
-         3.6803e-01,  1.3347e-08, -3.7806e-01, -1.8199e-01, -6.1560e-02,
-        -5.9970e-02,  8.5337e-02,  6.7126e-02, -4.9847e-01, -2.3934e-05,
-        -1.2665e+00, -1.7498e-03, -2.6879e-01, -4.0472e-06, -7.1967e-01,
-        -1.1370e-01, -2.6580e-01, -8.5947e-06, -5.9533e-01,  1.4969e-01,
-         1.1714e+00, -3.7923e-01, -1.8964e-01, -4.1811e-01,  1.2500e-05,
-        -5.8865e-01, -1.5401e+00, -3.1997e-09, -1.3318e-01,  1.5882e-03,
-        -2.8527e-02,  5.2978e-05,  0.0000e+00,  9.3028e-01,  9.1600e-01,
-        -1.0179e-01, -7.0761e-03, -2.9570e-01,  1.3038e-05,  3.1890e-03,
-        -4.5386e-01,  7.1190e-01,  2.3013e-02,  1.0699e+00, -8.3005e-07,
-         3.7022e-01, -3.3616e-01, -1.7351e-01,  4.5539e-02,  8.2437e-01,
-         1.7275e-10, -6.7892e-05,  2.4417e-08,  4.0194e-01,  1.1362e+00,
-         6.5222e-08,  1.4399e-01, -8.3551e-02, -5.4565e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4942,  0.3673,  0.2059, -0.4746,  2.7732,  0.3680,  0.0000, -0.3781,
-        -0.1820, -0.0616, -0.0600,  0.0853,  0.0671, -0.4985,  0.0000, -1.2665,
-         0.0000, -0.2688,  0.0000, -0.7197, -0.1137, -0.2658,  0.0000, -0.5953,
-         0.1497,  1.1714, -0.3792, -0.1896, -0.4181,  0.0000, -0.5887, -1.5401,
-         0.0000, -0.1332,  0.0000, -0.0285,  0.0000,  0.0000,  0.9303,  0.9160,
-        -0.1018,  0.0000, -0.2957,  0.0000,  0.0000, -0.4539,  0.7119,  0.0230,
-         1.0699,  0.0000,  0.3702, -0.3362, -0.1735,  0.0455,  0.8244,  0.0000,
-         0.0000,  0.0000,  0.4019,  1.1362,  0.0000,  0.1440, -0.0836, -0.5457],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4942,  0.3673,  0.2059, -0.4746,  2.7732,  0.3680,  0.0000, -0.3781,
-        -0.1820, -0.0616, -0.0600,  0.0853,  0.0671, -0.4985,  0.0000, -1.2665,
-         0.0000, -0.2688,  0.0000, -0.7197, -0.1137, -0.2658,  0.0000, -0.5953,
-         0.1497,  1.1714, -0.3792, -0.1896, -0.4181,  0.0000, -0.5887, -1.5401,
-         0.0000, -0.1332,  0.0000, -0.0285,  0.0000,  0.0000,  0.9303,  0.9160,
-        -0.1018,  0.0000, -0.2957,  0.0000,  0.0000, -0.4539,  0.7119,  0.0230,
-         1.0699,  0.0000,  0.3702, -0.3362, -0.1735,  0.0455,  0.8244,  0.0000,
-         0.0000,  0.0000,  0.4019,  1.1362,  0.0000,  0.1440, -0.0836, -0.5457],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9972e-01,  3.8914e-01,  2.3318e-01, -4.5085e-01,  2.7777e+00,
-         3.6829e-01,  1.1387e-08, -4.7032e-01, -1.6243e-01, -1.4417e-01,
-        -2.1745e-02,  8.2538e-05, -3.9310e-02, -4.8265e-01, -2.0420e-05,
-        -1.2738e+00, -1.4928e-03, -2.2417e-01, -3.4528e-06, -7.2172e-01,
-        -1.2450e-01, -2.4052e-01, -7.3326e-06, -6.0175e-01,  1.2311e-01,
-         1.1580e+00, -4.3335e-01, -1.5594e-01, -5.0574e-01,  1.0664e-05,
-        -5.5115e-01, -1.5326e+00, -2.7298e-09, -5.3620e-02,  1.3550e-03,
-        -7.3203e-02,  4.5198e-05,  0.0000e+00,  8.8960e-01,  9.3610e-01,
-        -8.5673e-03, -6.0369e-03, -3.5212e-01,  1.1124e-05,  2.7207e-03,
-        -4.9676e-01,  7.0752e-01,  1.7112e-02,  1.0574e+00, -7.0815e-07,
-         4.1044e-01, -2.5439e-01, -2.4381e-01,  9.3941e-02,  8.5840e-01,
-         1.4738e-10, -5.7922e-05,  2.0831e-08,  4.4358e-01,  1.1380e+00,
-         5.5644e-08,  1.9653e-01, -5.8764e-02, -6.1996e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.9972e-01,  3.8914e-01,  2.3318e-01, -4.5085e-01,  2.7777e+00,
-         3.6829e-01,  0.0000e+00, -4.7032e-01, -1.6243e-01, -1.4417e-01,
-        -2.1745e-02,  8.2538e-05, -3.9310e-02, -4.8265e-01,  0.0000e+00,
-        -1.2738e+00,  0.0000e+00, -2.2417e-01,  0.0000e+00, -7.2172e-01,
-        -1.2450e-01, -2.4052e-01,  0.0000e+00, -6.0175e-01,  1.2311e-01,
-         1.1580e+00, -4.3335e-01, -1.5594e-01, -5.0574e-01,  0.0000e+00,
-        -5.5115e-01, -1.5326e+00,  0.0000e+00, -5.3620e-02,  0.0000e+00,
-        -7.3203e-02,  0.0000e+00,  0.0000e+00,  8.8960e-01,  9.3610e-01,
-        -8.5673e-03,  0.0000e+00, -3.5212e-01,  0.0000e+00,  0.0000e+00,
-        -4.9676e-01,  7.0752e-01,  1.7112e-02,  1.0574e+00,  0.0000e+00,
-         4.1044e-01, -2.5439e-01, -2.4381e-01,  9.3941e-02,  8.5840e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.4358e-01,  1.1380e+00,
-         0.0000e+00,  1.9653e-01, -5.8764e-02, -6.1996e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.9972e-01,  3.8914e-01,  2.3318e-01, -4.5085e-01,  2.7777e+00,
-         3.6829e-01,  0.0000e+00, -4.7032e-01, -1.6243e-01, -1.4417e-01,
-        -2.1745e-02,  8.2538e-05, -3.9310e-02, -4.8265e-01,  0.0000e+00,
-        -1.2738e+00,  0.0000e+00, -2.2417e-01,  0.0000e+00, -7.2172e-01,
-        -1.2450e-01, -2.4052e-01,  0.0000e+00, -6.0175e-01,  1.2311e-01,
-         1.1580e+00, -4.3335e-01, -1.5594e-01, -5.0574e-01,  0.0000e+00,
-        -5.5115e-01, -1.5326e+00,  0.0000e+00, -5.3620e-02,  0.0000e+00,
-        -7.3203e-02,  0.0000e+00,  0.0000e+00,  8.8960e-01,  9.3610e-01,
-        -8.5673e-03,  0.0000e+00, -3.5212e-01,  0.0000e+00,  0.0000e+00,
-        -4.9676e-01,  7.0752e-01,  1.7112e-02,  1.0574e+00,  0.0000e+00,
-         4.1044e-01, -2.5439e-01, -2.4381e-01,  9.3941e-02,  8.5840e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.4358e-01,  1.1380e+00,
-         0.0000e+00,  1.9653e-01, -5.8764e-02, -6.1996e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0827e-01,  4.3162e-01,  2.6909e-01, -4.1004e-01,  2.7808e+00,
-         3.7558e-01,  9.7163e-09, -5.4704e-01, -1.1955e-01, -2.1143e-01,
-         5.1767e-02, -7.6659e-02, -1.0658e-01, -4.5141e-01, -1.7424e-05,
-        -1.2792e+00, -1.2738e-03, -1.7721e-01, -2.9463e-06, -7.2132e-01,
-        -1.2764e-01, -1.9195e-01, -6.2569e-06, -6.0222e-01,  1.1984e-01,
-         1.1428e+00, -4.5046e-01, -1.0375e-01, -5.7382e-01,  9.0998e-06,
-        -5.3920e-01, -1.5266e+00, -2.3294e-09,  7.6146e-03,  1.1562e-03,
-        -1.2388e-01,  3.8568e-05,  0.0000e+00,  8.2192e-01,  9.4845e-01,
-         1.1968e-01, -5.1513e-03, -4.1203e-01,  9.4918e-06,  2.3216e-03,
-        -5.4156e-01,  7.0101e-01,  2.3639e-02,  1.0465e+00, -6.0427e-07,
-         4.5336e-01, -1.5822e-01, -3.1862e-01,  2.0222e-01,  8.9239e-01,
-         1.2576e-10, -4.9425e-05,  1.7775e-08,  4.6229e-01,  1.1411e+00,
-         4.7482e-08,  2.8186e-01,  1.2398e-03, -6.8419e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.0827e-01,  4.3162e-01,  2.6909e-01, -4.1004e-01,  2.7808e+00,
-         3.7558e-01,  0.0000e+00,  0.0000e+00, -1.1955e-01, -2.1143e-01,
-         5.1767e-02, -7.6659e-02, -1.0658e-01, -4.5141e-01,  0.0000e+00,
-        -1.2792e+00,  0.0000e+00, -1.7721e-01,  0.0000e+00, -7.2132e-01,
-        -1.2764e-01, -1.9195e-01,  0.0000e+00, -6.0222e-01,  1.1984e-01,
-         1.1428e+00, -4.5046e-01, -1.0375e-01, -5.7382e-01,  0.0000e+00,
-        -5.3920e-01, -1.5266e+00,  0.0000e+00,  7.6146e-03,  0.0000e+00,
-        -1.2388e-01,  0.0000e+00,  0.0000e+00,  8.2192e-01,  9.4845e-01,
-         1.1968e-01,  0.0000e+00, -4.1203e-01,  0.0000e+00,  0.0000e+00,
-        -5.4156e-01,  7.0101e-01,  2.3639e-02,  1.0465e+00,  0.0000e+00,
-         4.5336e-01, -1.5822e-01, -3.1862e-01,  2.0222e-01,  8.9239e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6229e-01,  1.1411e+00,
-         0.0000e+00,  2.8186e-01,  1.2398e-03, -6.8419e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.0827e-01,  4.3162e-01,  2.6909e-01, -4.1004e-01,  2.7808e+00,
-         3.7558e-01,  0.0000e+00,  0.0000e+00, -1.1955e-01, -2.1143e-01,
-         5.1767e-02, -7.6659e-02, -1.0658e-01, -4.5141e-01,  0.0000e+00,
-        -1.2792e+00,  0.0000e+00, -1.7721e-01,  0.0000e+00, -7.2132e-01,
-        -1.2764e-01, -1.9195e-01,  0.0000e+00, -6.0222e-01,  1.1984e-01,
-         1.1428e+00, -4.5046e-01, -1.0375e-01, -5.7382e-01,  0.0000e+00,
-        -5.3920e-01, -1.5266e+00,  0.0000e+00,  7.6146e-03,  0.0000e+00,
-        -1.2388e-01,  0.0000e+00,  0.0000e+00,  8.2192e-01,  9.4845e-01,
-         1.1968e-01,  0.0000e+00, -4.1203e-01,  0.0000e+00,  0.0000e+00,
-        -5.4156e-01,  7.0101e-01,  2.3639e-02,  1.0465e+00,  0.0000e+00,
-         4.5336e-01, -1.5822e-01, -3.1862e-01,  2.0222e-01,  8.9239e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6229e-01,  1.1411e+00,
-         0.0000e+00,  2.8186e-01,  1.2398e-03, -6.8419e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1439e-01,  4.5427e-01,  3.0043e-01, -3.5376e-01,  2.7791e+00,
-         3.9175e-01,  8.2924e-09, -6.5478e-02, -6.6721e-02, -2.6445e-01,
-         8.2392e-02, -1.7072e-01, -1.6402e-01, -4.1670e-01, -1.4871e-05,
-        -1.2814e+00, -1.0872e-03, -1.2216e-01, -2.5146e-06, -7.1646e-01,
-        -1.2343e-01, -1.6177e-01, -5.3400e-06, -6.1710e-01,  1.0442e-01,
-         1.1263e+00, -4.4463e-01, -3.9453e-02, -6.0450e-01,  7.7662e-06,
-        -5.4570e-01, -1.5231e+00, -1.9880e-09,  3.4398e-02,  9.8675e-04,
-        -1.9172e-01,  3.2916e-05,  0.0000e+00,  7.5341e-01,  9.6419e-01,
-         2.4059e-01, -4.3964e-03, -4.6245e-01,  8.1008e-06,  1.9814e-03,
-        -5.9481e-01,  6.9491e-01,  2.9752e-02,  1.0399e+00, -5.1572e-07,
-         4.7976e-01, -7.3564e-02, -3.7715e-01,  2.7514e-01,  9.2753e-01,
-         1.0733e-10, -4.2182e-05,  1.5170e-08,  4.7595e-01,  1.1448e+00,
-         4.0523e-08,  3.5228e-01,  2.4484e-02, -7.1817e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5144,  0.4543,  0.3004, -0.3538,  2.7791,  0.3917,  0.0000,  0.0000,
-        -0.0667, -0.2645,  0.0824, -0.1707, -0.1640, -0.4167,  0.0000, -1.2814,
-         0.0000, -0.1222,  0.0000, -0.7165, -0.1234, -0.1618,  0.0000, -0.6171,
-         0.1044,  1.1263, -0.4446, -0.0395, -0.6045,  0.0000, -0.5457, -1.5231,
-         0.0000,  0.0344,  0.0000, -0.1917,  0.0000,  0.0000,  0.7534,  0.9642,
-         0.2406,  0.0000, -0.4624,  0.0000,  0.0000, -0.5948,  0.6949,  0.0298,
-         1.0399,  0.0000,  0.4798, -0.0736, -0.3772,  0.2751,  0.9275,  0.0000,
-         0.0000,  0.0000,  0.4760,  1.1448,  0.0000,  0.3523,  0.0245, -0.7182],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5144,  0.4543,  0.3004, -0.3538,  2.7791,  0.3917,  0.0000,  0.0000,
-        -0.0667, -0.2645,  0.0824, -0.1707, -0.1640, -0.4167,  0.0000, -1.2814,
-         0.0000, -0.1222,  0.0000, -0.7165, -0.1234, -0.1618,  0.0000, -0.6171,
-         0.1044,  1.1263, -0.4446, -0.0395, -0.6045,  0.0000, -0.5457, -1.5231,
-         0.0000,  0.0344,  0.0000, -0.1917,  0.0000,  0.0000,  0.7534,  0.9642,
-         0.2406,  0.0000, -0.4624,  0.0000,  0.0000, -0.5948,  0.6949,  0.0298,
-         1.0399,  0.0000,  0.4798, -0.0736, -0.3772,  0.2751,  0.9275,  0.0000,
-         0.0000,  0.0000,  0.4760,  1.1448,  0.0000,  0.3523,  0.0245, -0.7182],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2155e-01,  4.5685e-01,  3.6684e-01, -2.7621e-01,  2.7747e+00,
-         4.3421e-01,  7.0785e-09, -5.5893e-02, -1.1054e-02, -2.7233e-01,
-         6.5258e-02, -2.3059e-01, -1.6990e-01, -3.9763e-01, -1.2694e-05,
-        -1.2837e+00, -9.2803e-04, -1.3970e-02, -2.1465e-06, -7.0910e-01,
-        -1.2577e-01, -1.0847e-01, -4.5583e-06, -6.1807e-01,  6.8812e-02,
-         1.1185e+00, -3.9099e-01,  2.9979e-02, -5.8527e-01,  6.6294e-06,
-        -5.4355e-01, -1.5167e+00, -1.6970e-09,  6.6920e-02,  8.4230e-04,
-        -2.4991e-01,  2.8097e-05,  0.0000e+00,  6.9759e-01,  9.8444e-01,
-         3.2262e-01, -3.7528e-03, -4.8418e-01,  6.9150e-06,  1.6913e-03,
-        -6.5493e-01,  6.7905e-01,  3.9363e-02,  1.0226e+00, -4.4022e-07,
-         4.7303e-01, -2.8497e-02, -4.6316e-01,  3.5240e-01,  9.5883e-01,
-         9.1618e-11, -3.6007e-05,  1.2950e-08,  4.7477e-01,  1.1391e+00,
-         3.4591e-08,  4.2892e-01,  5.0684e-02, -7.1826e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5215,  0.4569,  0.3668, -0.2762,  2.7747,  0.4342,  0.0000,  0.0000,
-        -0.0111, -0.2723,  0.0653, -0.2306, -0.1699, -0.3976,  0.0000, -1.2837,
-         0.0000, -0.0140,  0.0000, -0.7091, -0.1258, -0.1085,  0.0000, -0.6181,
-         0.0688,  1.1185, -0.3910,  0.0300, -0.5853,  0.0000, -0.5435, -1.5167,
-         0.0000,  0.0669,  0.0000, -0.2499,  0.0000,  0.0000,  0.6976,  0.9844,
-         0.3226,  0.0000, -0.4842,  0.0000,  0.0000, -0.6549,  0.6791,  0.0394,
-         1.0226,  0.0000,  0.4730, -0.0285, -0.4632,  0.3524,  0.9588,  0.0000,
-         0.0000,  0.0000,  0.4748,  1.1391,  0.0000,  0.4289,  0.0507, -0.7183],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5215,  0.4569,  0.3668, -0.2762,  2.7747,  0.4342,  0.0000,  0.0000,
-        -0.0111, -0.2723,  0.0653, -0.2306, -0.1699, -0.3976,  0.0000, -1.2837,
-         0.0000, -0.0140,  0.0000, -0.7091, -0.1258, -0.1085,  0.0000, -0.6181,
-         0.0688,  1.1185, -0.3910,  0.0300, -0.5853,  0.0000, -0.5435, -1.5167,
-         0.0000,  0.0669,  0.0000, -0.2499,  0.0000,  0.0000,  0.6976,  0.9844,
-         0.3226,  0.0000, -0.4842,  0.0000,  0.0000, -0.6549,  0.6791,  0.0394,
-         1.0226,  0.0000,  0.4730, -0.0285, -0.4632,  0.3524,  0.9588,  0.0000,
-         0.0000,  0.0000,  0.4748,  1.1391,  0.0000,  0.4289,  0.0507, -0.7183],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2336e-01,  4.7795e-01,  3.9967e-01, -2.0265e-01,  2.7708e+00,
-         4.6473e-01,  6.0434e-09, -4.7720e-02,  5.4610e-02, -2.8867e-01,
-         6.8803e-02, -2.6840e-01, -2.1967e-01, -3.7408e-01, -1.0838e-05,
-        -1.2791e+00, -7.9232e-04,  7.5216e-02, -1.8326e-06, -6.9101e-01,
-        -1.0784e-01, -5.8703e-02, -3.8917e-06, -5.9960e-01,  4.5075e-02,
-         1.1187e+00, -3.5392e-01,  8.4930e-02, -5.2633e-01,  5.6600e-06,
-        -5.2566e-01, -1.5091e+00, -1.4488e-09,  6.7229e-02,  7.1913e-04,
-        -2.8939e-01,  2.3989e-05,  0.0000e+00,  6.8981e-01,  1.0045e+00,
-         3.8778e-01, -3.2041e-03, -5.2108e-01,  5.9038e-06,  1.4440e-03,
-        -6.8890e-01,  6.5725e-01,  6.4291e-02,  1.0101e+00, -3.7585e-07,
-         4.6655e-01,  1.2703e-02, -5.2192e-01,  3.9178e-01,  9.9423e-01,
-         7.8221e-11, -3.0742e-05,  1.1056e-08,  4.6088e-01,  1.1249e+00,
-         2.9533e-08,  4.3932e-01,  1.1424e-01, -7.1390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5234,  0.4779,  0.3997, -0.2027,  2.7708,  0.4647,  0.0000,  0.0000,
-         0.0546, -0.2887,  0.0688, -0.2684, -0.2197, -0.3741,  0.0000, -1.2791,
-         0.0000,  0.0752,  0.0000, -0.6910, -0.1078, -0.0587,  0.0000, -0.5996,
-         0.0451,  1.1187, -0.3539,  0.0849, -0.5263,  0.0000, -0.5257, -1.5091,
-         0.0000,  0.0672,  0.0000, -0.2894,  0.0000,  0.0000,  0.6898,  1.0045,
-         0.3878,  0.0000, -0.5211,  0.0000,  0.0000, -0.6889,  0.6573,  0.0643,
-         1.0101,  0.0000,  0.4666,  0.0127, -0.5219,  0.3918,  0.9942,  0.0000,
-         0.0000,  0.0000,  0.4609,  1.1249,  0.0000,  0.4393,  0.1142, -0.7139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5234,  0.4779,  0.3997, -0.2027,  2.7708,  0.4647,  0.0000,  0.0000,
-         0.0546, -0.2887,  0.0688, -0.2684, -0.2197, -0.3741,  0.0000, -1.2791,
-         0.0000,  0.0752,  0.0000, -0.6910, -0.1078, -0.0587,  0.0000, -0.5996,
-         0.0451,  1.1187, -0.3539,  0.0849, -0.5263,  0.0000, -0.5257, -1.5091,
-         0.0000,  0.0672,  0.0000, -0.2894,  0.0000,  0.0000,  0.6898,  1.0045,
-         0.3878,  0.0000, -0.5211,  0.0000,  0.0000, -0.6889,  0.6573,  0.0643,
-         1.0101,  0.0000,  0.4666,  0.0127, -0.5219,  0.3918,  0.9942,  0.0000,
-         0.0000,  0.0000,  0.4609,  1.1249,  0.0000,  0.4393,  0.1142, -0.7139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9066e-01,  4.6157e-01,  3.7517e-01, -1.6273e-01,  2.7631e+00,
-         4.4793e-01,  5.1607e-09, -4.0750e-02,  1.1650e-01, -3.2733e-01,
-         1.1794e-01, -2.9736e-01, -3.1262e-01, -3.4747e-01, -9.2546e-06,
-        -1.2700e+00, -6.7659e-04,  1.2106e-01, -1.5649e-06, -6.7220e-01,
-        -6.6253e-02, -2.5054e-02, -3.3233e-06, -5.5994e-01,  3.0166e-02,
-         1.1239e+00, -3.4683e-01,  1.0869e-01, -4.4463e-01,  4.8332e-06,
-        -5.0706e-01, -1.5024e+00, -1.2372e-09,  6.0613e-02,  6.1409e-04,
-        -3.0794e-01,  2.0485e-05,  0.0000e+00,  6.8260e-01,  1.0259e+00,
-         4.4690e-01, -2.7361e-03, -5.5278e-01,  5.0415e-06,  1.2331e-03,
-        -6.8503e-01,  6.3657e-01,  1.1944e-01,  9.9027e-01, -3.2095e-07,
-         4.9176e-01,  5.2283e-02, -5.4418e-01,  3.1771e-01,  1.0237e+00,
-         6.6795e-11, -2.6252e-05,  9.4411e-09,  4.4419e-01,  1.1176e+00,
-         2.5219e-08,  3.7193e-01,  2.0000e-01, -7.0621e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4907,  0.4616,  0.3752, -0.1627,  2.7631,  0.4479,  0.0000,  0.0000,
-         0.1165, -0.3273,  0.1179, -0.2974, -0.3126, -0.3475,  0.0000, -1.2700,
-         0.0000,  0.1211,  0.0000,  0.0000, -0.0663, -0.0251,  0.0000, -0.5599,
-         0.0302,  1.1239, -0.3468,  0.1087, -0.4446,  0.0000, -0.5071, -1.5024,
-         0.0000,  0.0606,  0.0000, -0.3079,  0.0000,  0.0000,  0.6826,  1.0259,
-         0.4469,  0.0000, -0.5528,  0.0000,  0.0000, -0.6850,  0.6366,  0.1194,
-         0.9903,  0.0000,  0.4918,  0.0523, -0.5442,  0.3177,  1.0237,  0.0000,
-         0.0000,  0.0000,  0.4442,  1.1176,  0.0000,  0.3719,  0.2000, -0.7062],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4907,  0.4616,  0.3752, -0.1627,  2.7631,  0.4479,  0.0000,  0.0000,
-         0.1165, -0.3273,  0.1179, -0.2974, -0.3126, -0.3475,  0.0000, -1.2700,
-         0.0000,  0.1211,  0.0000,  0.0000, -0.0663, -0.0251,  0.0000, -0.5599,
-         0.0302,  1.1239, -0.3468,  0.1087, -0.4446,  0.0000, -0.5071, -1.5024,
-         0.0000,  0.0606,  0.0000, -0.3079,  0.0000,  0.0000,  0.6826,  1.0259,
-         0.4469,  0.0000, -0.5528,  0.0000,  0.0000, -0.6850,  0.6366,  0.1194,
-         0.9903,  0.0000,  0.4918,  0.0523, -0.5442,  0.3177,  1.0237,  0.0000,
-         0.0000,  0.0000,  0.4442,  1.1176,  0.0000,  0.3719,  0.2000, -0.7062],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4564e-01,  4.3686e-01,  3.5586e-01, -1.0115e-01,  2.7514e+00,
-         4.0052e-01,  4.4078e-09, -3.4804e-02,  1.9066e-01, -3.4893e-01,
-         1.2744e-01, -3.3050e-01, -4.0061e-01, -2.8619e-01, -7.9044e-06,
-        -1.2569e+00, -5.7788e-04,  1.9167e-01, -1.3366e-06,  1.6070e-02,
-        -2.4684e-02, -4.7605e-02, -2.8384e-06, -4.9052e-01, -7.8256e-03,
-         1.1309e+00, -3.5191e-01,  1.5242e-01, -3.5952e-01,  4.1281e-06,
-        -4.4362e-01, -1.4992e+00, -1.0567e-09,  4.0074e-02,  5.2450e-04,
-        -3.2096e-01,  1.7496e-05,  0.0000e+00,  6.5954e-01,  1.0469e+00,
-         5.0606e-01, -2.3369e-03, -5.8621e-01,  4.3059e-06,  1.0532e-03,
-        -6.4734e-01,  6.2154e-01,  1.5981e-01,  9.6530e-01, -2.7413e-07,
-         5.0901e-01,  9.2816e-02, -5.4733e-01,  2.3990e-01,  1.0482e+00,
-         5.7050e-11, -2.2422e-05,  8.0637e-09,  4.0656e-01,  1.1102e+00,
-         2.1540e-08,  3.0995e-01,  2.0298e-01, -7.0515e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4456,  0.4369,  0.3559, -0.1012,  2.7514,  0.4005,  0.0000,  0.0000,
-         0.1907, -0.3489,  0.1274, -0.3305, -0.4006, -0.2862,  0.0000, -1.2569,
-         0.0000,  0.1917,  0.0000,  0.0000, -0.0247, -0.0476,  0.0000, -0.4905,
-        -0.0078,  1.1309, -0.3519,  0.1524, -0.3595,  0.0000, -0.4436, -1.4992,
-         0.0000,  0.0401,  0.0000, -0.3210,  0.0000,  0.0000,  0.6595,  1.0469,
-         0.5061,  0.0000, -0.5862,  0.0000,  0.0000, -0.6473,  0.6215,  0.1598,
-         0.9653,  0.0000,  0.5090,  0.0928, -0.5473,  0.2399,  1.0482,  0.0000,
-         0.0000,  0.0000,  0.4066,  1.1102,  0.0000,  0.3100,  0.2030, -0.7052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4456,  0.4369,  0.3559, -0.1012,  2.7514,  0.4005,  0.0000,  0.0000,
-         0.1907, -0.3489,  0.1274, -0.3305, -0.4006, -0.2862,  0.0000, -1.2569,
-         0.0000,  0.1917,  0.0000,  0.0000, -0.0247, -0.0476,  0.0000, -0.4905,
-        -0.0078,  1.1309, -0.3519,  0.1524, -0.3595,  0.0000, -0.4436, -1.4992,
-         0.0000,  0.0401,  0.0000, -0.3210,  0.0000,  0.0000,  0.6595,  1.0469,
-         0.5061,  0.0000, -0.5862,  0.0000,  0.0000, -0.6473,  0.6215,  0.1598,
-         0.9653,  0.0000,  0.5090,  0.0928, -0.5473,  0.2399,  1.0482,  0.0000,
-         0.0000,  0.0000,  0.4066,  1.1102,  0.0000,  0.3100,  0.2030, -0.7052],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0675e-01,  4.3495e-01,  3.6935e-01, -4.9371e-02,  2.7408e+00,
-         4.2849e-01,  3.7654e-09, -2.9732e-02,  2.6867e-01, -3.7072e-01,
-         1.1136e-01, -3.1791e-01, -4.0417e-01, -2.2272e-01, -6.7525e-06,
-        -1.2451e+00, -4.9367e-04,  2.9456e-01, -1.1418e-06,  1.3728e-02,
-         8.9688e-03, -7.6323e-02, -2.4248e-06, -4.6161e-01, -6.2462e-02,
-         1.1251e+00, -3.1951e-01,  1.8377e-01, -3.0984e-01,  3.5265e-06,
-        -3.5314e-01, -1.4933e+00, -9.0272e-10,  1.3515e-04,  4.4807e-04,
-        -2.9705e-01,  1.4946e-05,  0.0000e+00,  6.0204e-01,  1.0398e+00,
-         5.3325e-01, -1.9963e-03, -5.4726e-01,  3.6785e-06,  8.9971e-04,
-        -5.8262e-01,  5.8942e-01,  1.8665e-01,  9.3911e-01, -2.3418e-07,
-         5.0487e-01,  2.1886e-02, -5.1786e-01,  2.2792e-01,  1.0523e+00,
-         4.8736e-11, -1.9154e-05,  6.8886e-09,  3.7068e-01,  1.1044e+00,
-         1.8401e-08,  3.3121e-01,  1.6405e-01, -6.8330e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 4.0675e-01,  4.3495e-01,  3.6935e-01, -4.9371e-02,  2.7408e+00,
-         4.2849e-01,  0.0000e+00,  0.0000e+00,  2.6867e-01, -3.7072e-01,
-         1.1136e-01, -3.1791e-01, -4.0417e-01, -2.2272e-01,  0.0000e+00,
-        -1.2451e+00,  0.0000e+00,  2.9456e-01,  0.0000e+00,  0.0000e+00,
-         8.9688e-03, -7.6323e-02,  0.0000e+00, -4.6161e-01, -6.2462e-02,
-         1.1251e+00, -3.1951e-01,  1.8377e-01, -3.0984e-01,  0.0000e+00,
-        -3.5314e-01, -1.4933e+00,  0.0000e+00,  1.3515e-04,  0.0000e+00,
-        -2.9705e-01,  0.0000e+00,  0.0000e+00,  6.0204e-01,  1.0398e+00,
-         5.3325e-01,  0.0000e+00, -5.4726e-01,  0.0000e+00,  0.0000e+00,
-        -5.8262e-01,  5.8942e-01,  1.8665e-01,  9.3911e-01,  0.0000e+00,
-         5.0487e-01,  2.1886e-02, -5.1786e-01,  2.2792e-01,  1.0523e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.7068e-01,  1.1044e+00,
-         0.0000e+00,  3.3121e-01,  1.6405e-01, -6.8330e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 4.0675e-01,  4.3495e-01,  3.6935e-01, -4.9371e-02,  2.7408e+00,
-         4.2849e-01,  0.0000e+00,  0.0000e+00,  2.6867e-01, -3.7072e-01,
-         1.1136e-01, -3.1791e-01, -4.0417e-01, -2.2272e-01,  0.0000e+00,
-        -1.2451e+00,  0.0000e+00,  2.9456e-01,  0.0000e+00,  0.0000e+00,
-         8.9688e-03, -7.6323e-02,  0.0000e+00, -4.6161e-01, -6.2462e-02,
-         1.1251e+00, -3.1951e-01,  1.8377e-01, -3.0984e-01,  0.0000e+00,
-        -3.5314e-01, -1.4933e+00,  0.0000e+00,  1.3515e-04,  0.0000e+00,
-        -2.9705e-01,  0.0000e+00,  0.0000e+00,  6.0204e-01,  1.0398e+00,
-         5.3325e-01,  0.0000e+00, -5.4726e-01,  0.0000e+00,  0.0000e+00,
-        -5.8262e-01,  5.8942e-01,  1.8665e-01,  9.3911e-01,  0.0000e+00,
-         5.0487e-01,  2.1886e-02, -5.1786e-01,  2.2792e-01,  1.0523e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.7068e-01,  1.1044e+00,
-         0.0000e+00,  3.3121e-01,  1.6405e-01, -6.8330e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3334e-01,  4.5853e-01,  3.5917e-01, -5.7213e-02,  2.7320e+00,
-         4.8615e-01,  3.2174e-09, -2.5405e-02,  3.4827e-01, -3.9466e-01,
-         9.9675e-02, -3.3736e-01, -4.1812e-01, -1.8233e-01, -5.7697e-06,
-        -1.2296e+00, -4.2181e-04,  3.2746e-01, -9.7562e-07,  1.1730e-02,
-         6.4807e-02, -1.1550e-01, -2.0719e-06, -4.4305e-01, -1.1371e-01,
-         1.1172e+00, -3.0572e-01,  2.2500e-01, -2.5798e-01,  3.0132e-06,
-        -2.4887e-01, -1.4862e+00, -7.7132e-10, -5.9355e-02,  3.8285e-04,
-        -2.7560e-01,  1.2771e-05,  0.0000e+00,  5.3156e-01,  1.0299e+00,
-         5.4878e-01, -1.7058e-03, -5.1723e-01,  3.1430e-06,  7.6875e-04,
-        -5.2316e-01,  5.6412e-01,  1.6917e-01,  9.1337e-01, -2.0009e-07,
-         4.9189e-01, -1.1376e-01, -4.6885e-01,  8.6124e-02,  1.0471e+00,
-         4.1643e-11, -1.6366e-05,  5.8859e-09,  3.4401e-01,  1.1021e+00,
-         1.5723e-08,  3.5308e-01,  9.0090e-02, -6.6843e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3333,  0.4585,  0.3592, -0.0572,  2.7320,  0.4862,  0.0000,  0.0000,
-         0.3483, -0.3947,  0.0997, -0.3374, -0.4181, -0.1823,  0.0000, -1.2296,
-         0.0000,  0.3275,  0.0000,  0.0000,  0.0648, -0.1155,  0.0000, -0.4430,
-        -0.1137,  1.1172, -0.3057,  0.2250, -0.2580,  0.0000, -0.2489, -1.4862,
-         0.0000, -0.0594,  0.0000, -0.2756,  0.0000,  0.0000,  0.5316,  1.0299,
-         0.5488,  0.0000, -0.5172,  0.0000,  0.0000, -0.5232,  0.5641,  0.1692,
-         0.9134,  0.0000,  0.4919, -0.1138, -0.4688,  0.0861,  1.0471,  0.0000,
-         0.0000,  0.0000,  0.3440,  1.1021,  0.0000,  0.3531,  0.0901, -0.6684],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3333,  0.4585,  0.3592, -0.0572,  2.7320,  0.4862,  0.0000,  0.0000,
-         0.3483, -0.3947,  0.0997, -0.3374, -0.4181, -0.1823,  0.0000, -1.2296,
-         0.0000,  0.3275,  0.0000,  0.0000,  0.0648, -0.1155,  0.0000, -0.4430,
-        -0.1137,  1.1172, -0.3057,  0.2250, -0.2580,  0.0000, -0.2489, -1.4862,
-         0.0000, -0.0594,  0.0000, -0.2756,  0.0000,  0.0000,  0.5316,  1.0299,
-         0.5488,  0.0000, -0.5172,  0.0000,  0.0000, -0.5232,  0.5641,  0.1692,
-         0.9134,  0.0000,  0.4919, -0.1138, -0.4688,  0.0861,  1.0471,  0.0000,
-         0.0000,  0.0000,  0.3440,  1.1021,  0.0000,  0.3531,  0.0901, -0.6684],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9627e-01,  4.5092e-01,  3.4652e-01, -4.6569e-02,  2.7260e+00,
-         5.2229e-01,  2.7496e-09, -2.1711e-02,  3.9071e-01, -4.2237e-01,
-         1.2604e-02, -3.3097e-01, -3.9669e-01, -1.1215e-01, -4.9309e-06,
-        -1.2185e+00, -3.6049e-04,  3.4274e-01, -8.3379e-07,  1.0025e-02,
-         1.0582e-01, -1.4972e-01, -1.7707e-06, -4.0933e-01, -1.6046e-01,
-         1.1069e+00, -2.8464e-01,  2.6046e-01, -2.0897e-01,  2.5752e-06,
-        -1.3702e-01, -1.4749e+00, -6.5919e-10, -7.0743e-02,  3.2719e-04,
-        -2.6231e-01,  1.0914e-05,  0.0000e+00,  4.5097e-01,  1.0114e+00,
-         5.6462e-01, -1.4578e-03, -4.4888e-01,  2.6861e-06,  6.5699e-04,
-        -4.4035e-01,  5.3199e-01,  1.3148e-01,  8.8902e-01, -1.7100e-07,
-         4.7687e-01, -2.0492e-01, -4.3648e-01, -5.7009e-02,  1.0389e+00,
-         3.5589e-11, -1.3987e-05,  5.0302e-09,  3.0275e-01,  1.1000e+00,
-         1.3437e-08,  3.6486e-01,  1.0215e-02, -6.5347e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2963,  0.4509,  0.3465, -0.0466,  2.7260,  0.5223,  0.0000,  0.0000,
-         0.3907, -0.4224,  0.0126, -0.3310, -0.3967, -0.1122,  0.0000, -1.2185,
-         0.0000,  0.3427,  0.0000,  0.0000,  0.1058, -0.1497,  0.0000, -0.4093,
-        -0.1605,  1.1069, -0.2846,  0.2605, -0.2090,  0.0000, -0.1370, -1.4749,
-         0.0000, -0.0707,  0.0000, -0.2623,  0.0000,  0.0000,  0.4510,  1.0114,
-         0.5646,  0.0000, -0.4489,  0.0000,  0.0000, -0.4404,  0.5320,  0.1315,
-         0.8890,  0.0000,  0.4769, -0.2049, -0.4365, -0.0570,  1.0389,  0.0000,
-         0.0000,  0.0000,  0.3028,  1.1000,  0.0000,  0.3649,  0.0102, -0.6535],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2963,  0.4509,  0.3465, -0.0466,  2.7260,  0.5223,  0.0000,  0.0000,
-         0.3907, -0.4224,  0.0126, -0.3310, -0.3967, -0.1122,  0.0000, -1.2185,
-         0.0000,  0.3427,  0.0000,  0.0000,  0.1058, -0.1497,  0.0000, -0.4093,
-        -0.1605,  1.1069, -0.2846,  0.2605, -0.2090,  0.0000, -0.1370, -1.4749,
-         0.0000, -0.0707,  0.0000, -0.2623,  0.0000,  0.0000,  0.4510,  1.0114,
-         0.5646,  0.0000, -0.4489,  0.0000,  0.0000, -0.4404,  0.5320,  0.1315,
-         0.8890,  0.0000,  0.4769, -0.2049, -0.4365, -0.0570,  1.0389,  0.0000,
-         0.0000,  0.0000,  0.3028,  1.1000,  0.0000,  0.3649,  0.0102, -0.6535],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4912e-01,  4.2760e-01,  3.2308e-01, -5.6127e-02,  2.7220e+00,
-         5.2645e-01,  2.3504e-09, -1.8559e-02,  4.1562e-01, -4.5729e-01,
-        -2.4381e-02, -3.1222e-01, -3.9895e-01, -8.5055e-02, -4.2149e-06,
-        -1.2031e+00, -3.0815e-04,  3.3275e-01, -7.1272e-07,  8.5693e-03,
-         1.1523e-01, -1.6045e-01, -1.5136e-06, -3.8545e-01, -1.6394e-01,
-         1.0957e+00, -2.8261e-01,  2.6544e-01, -1.3976e-01,  2.2012e-06,
-        -8.6239e-02, -1.4628e+00, -5.6348e-10, -7.9504e-02,  2.7968e-04,
-        -2.4722e-01,  9.3295e-06,  0.0000e+00,  3.9389e-01,  9.8736e-01,
-         5.6027e-01, -1.2461e-03, -4.0912e-01,  2.2961e-06,  5.6160e-04,
-        -3.6911e-01,  5.0858e-01,  1.2223e-01,  8.8733e-01, -1.4617e-07,
-         4.7767e-01, -2.7861e-01, -3.8287e-01, -2.7437e-01,  1.0284e+00,
-         3.0421e-11, -1.1956e-05,  4.2998e-09,  3.2058e-01,  1.0982e+00,
-         1.1486e-08,  3.3467e-01, -2.0896e-02, -6.3487e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2491,  0.4276,  0.3231, -0.0561,  2.7220,  0.5264,  0.0000,  0.0000,
-         0.4156, -0.4573, -0.0244, -0.3122, -0.3990, -0.0851,  0.0000, -1.2031,
-         0.0000,  0.3327,  0.0000,  0.0000,  0.1152, -0.1605,  0.0000, -0.3854,
-        -0.1639,  1.0957, -0.2826,  0.2654, -0.1398,  0.0000, -0.0862, -1.4628,
-         0.0000, -0.0795,  0.0000,  0.0000,  0.0000,  0.0000,  0.3939,  0.9874,
-         0.5603,  0.0000, -0.4091,  0.0000,  0.0000, -0.3691,  0.5086,  0.1222,
-         0.8873,  0.0000,  0.4777, -0.2786, -0.3829, -0.2744,  1.0284,  0.0000,
-         0.0000,  0.0000,  0.3206,  1.0982,  0.0000,  0.3347, -0.0209, -0.6349],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2491,  0.4276,  0.3231, -0.0561,  2.7220,  0.5264,  0.0000,  0.0000,
-         0.4156, -0.4573, -0.0244, -0.3122, -0.3990, -0.0851,  0.0000, -1.2031,
-         0.0000,  0.3327,  0.0000,  0.0000,  0.1152, -0.1605,  0.0000, -0.3854,
-        -0.1639,  1.0957, -0.2826,  0.2654, -0.1398,  0.0000, -0.0862, -1.4628,
-         0.0000, -0.0795,  0.0000,  0.0000,  0.0000,  0.0000,  0.3939,  0.9874,
-         0.5603,  0.0000, -0.4091,  0.0000,  0.0000, -0.3691,  0.5086,  0.1222,
-         0.8873,  0.0000,  0.4777, -0.2786, -0.3829, -0.2744,  1.0284,  0.0000,
-         0.0000,  0.0000,  0.3206,  1.0982,  0.0000,  0.3347, -0.0209, -0.6349],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.8975e-01,  3.8360e-01,  3.0019e-01, -7.0464e-03,  2.7199e+00,
-         4.8507e-01,  2.0095e-09, -1.5868e-02,  4.2737e-01, -4.2401e-01,
-        -8.3435e-02, -3.3731e-01, -3.5443e-01, -1.7457e-01, -3.6037e-06,
-        -1.1933e+00, -2.6346e-04,  3.1275e-01, -6.0936e-07,  7.3266e-03,
-         2.7356e-02,  3.3620e-02, -1.2941e-06, -4.1268e-01, -1.0802e-01,
-         1.0780e+00, -2.1754e-01,  2.5102e-01, -8.1246e-02,  1.8820e-06,
-        -5.7902e-03, -1.4553e+00, -4.8176e-10,  4.7904e-02,  2.3912e-04,
-         1.2902e-02,  7.9766e-06,  0.0000e+00,  3.6391e-01,  9.6330e-01,
-         5.2022e-01, -1.0654e-03, -2.8936e-01,  1.9631e-06,  4.8016e-04,
-        -3.0912e-01,  5.0856e-01,  9.7507e-02,  8.9250e-01, -1.2498e-07,
-         4.1062e-01, -3.9516e-01, -4.0363e-01,  5.3886e-02,  1.0152e+00,
-         2.6010e-11, -1.0222e-05,  3.6763e-09,  3.6290e-01,  1.1037e+00,
-         9.8202e-09,  2.3932e-01, -6.0232e-02, -5.8555e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1898,  0.3836,  0.3002, -0.0070,  2.7199,  0.4851,  0.0000,  0.0000,
-         0.4274, -0.4240, -0.0834, -0.3373, -0.3544, -0.1746,  0.0000, -1.1933,
-         0.0000,  0.3127,  0.0000,  0.0000,  0.0274,  0.0336,  0.0000, -0.4127,
-        -0.1080,  1.0780, -0.2175,  0.2510, -0.0812,  0.0000, -0.0058, -1.4553,
-         0.0000,  0.0479,  0.0000,  0.0000,  0.0000,  0.0000,  0.3639,  0.9633,
-         0.5202,  0.0000, -0.2894,  0.0000,  0.0000, -0.3091,  0.5086,  0.0975,
-         0.8925,  0.0000,  0.4106, -0.3952, -0.4036,  0.0539,  1.0152,  0.0000,
-         0.0000,  0.0000,  0.3629,  1.1037,  0.0000,  0.2393, -0.0602, -0.5855],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1898,  0.3836,  0.3002, -0.0070,  2.7199,  0.4851,  0.0000,  0.0000,
-         0.4274, -0.4240, -0.0834, -0.3373, -0.3544, -0.1746,  0.0000, -1.1933,
-         0.0000,  0.3127,  0.0000,  0.0000,  0.0274,  0.0336,  0.0000, -0.4127,
-        -0.1080,  1.0780, -0.2175,  0.2510, -0.0812,  0.0000, -0.0058, -1.4553,
-         0.0000,  0.0479,  0.0000,  0.0000,  0.0000,  0.0000,  0.3639,  0.9633,
-         0.5202,  0.0000, -0.2894,  0.0000,  0.0000, -0.3091,  0.5086,  0.0975,
-         0.8925,  0.0000,  0.4106, -0.3952, -0.4036,  0.0539,  1.0152,  0.0000,
-         0.0000,  0.0000,  0.3629,  1.1037,  0.0000,  0.2393, -0.0602, -0.5855],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.5906e-01,  3.3782e-01,  3.0480e-01, -6.2721e-02,  2.7197e+00,
-         4.9805e-01,  1.7185e-09, -1.3569e-02,  4.2031e-01, -3.6243e-01,
-         4.8423e-02, -3.3672e-01, -3.4525e-01, -2.7846e-01, -3.0817e-06,
-        -1.1947e+00, -2.2530e-04,  2.1786e-01, -5.2111e-07,  6.2654e-03,
-        -7.4090e-02,  2.2070e-01, -1.1066e-06, -4.2115e-01,  4.7580e-02,
-         1.0761e+00, -2.1757e-01,  2.1229e-01,  5.7261e-02,  1.6094e-06,
-        -2.0782e-03, -1.4484e+00, -4.1199e-10,  2.2591e-01,  2.0449e-04,
-         1.1034e-02,  6.8213e-06,  0.0000e+00,  4.2499e-01,  9.7126e-01,
-         4.7087e-01, -9.1110e-04, -2.4300e-01,  1.6788e-06,  4.1061e-04,
-        -3.2012e-01,  4.9045e-01,  1.4869e-01,  8.9514e-01, -1.0688e-07,
-         3.6222e-01, -4.3391e-01, -4.2441e-01,  1.8858e-01,  1.0138e+00,
-         2.2243e-11, -8.7417e-06,  3.1438e-09,  3.6721e-01,  1.1092e+00,
-         8.3979e-09,  4.4375e-02, -8.7367e-03, -5.7737e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.5906e-01,  3.3782e-01,  3.0480e-01, -6.2721e-02,  2.7197e+00,
-         4.9805e-01,  0.0000e+00,  0.0000e+00,  4.2031e-01, -3.6243e-01,
-         4.8423e-02, -3.3672e-01, -3.4525e-01, -2.7846e-01,  0.0000e+00,
-        -1.1947e+00,  0.0000e+00,  2.1786e-01,  0.0000e+00,  0.0000e+00,
-        -7.4090e-02,  2.2070e-01,  0.0000e+00, -4.2115e-01,  4.7580e-02,
-         1.0761e+00, -2.1757e-01,  2.1229e-01,  5.7261e-02,  0.0000e+00,
-        -2.0782e-03, -1.4484e+00,  0.0000e+00,  2.2591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2499e-01,  9.7126e-01,
-         4.7087e-01,  0.0000e+00, -2.4300e-01,  0.0000e+00,  0.0000e+00,
-        -3.2012e-01,  4.9045e-01,  1.4869e-01,  8.9514e-01,  0.0000e+00,
-         3.6222e-01, -4.3391e-01, -4.2441e-01,  1.8858e-01,  1.0138e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6721e-01,  1.1092e+00,
-         0.0000e+00,  4.4375e-02, -8.7367e-03, -5.7737e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.5906e-01,  3.3782e-01,  3.0480e-01, -6.2721e-02,  2.7197e+00,
-         4.9805e-01,  0.0000e+00,  0.0000e+00,  4.2031e-01, -3.6243e-01,
-         4.8423e-02, -3.3672e-01, -3.4525e-01, -2.7846e-01,  0.0000e+00,
-        -1.1947e+00,  0.0000e+00,  2.1786e-01,  0.0000e+00,  0.0000e+00,
-        -7.4090e-02,  2.2070e-01,  0.0000e+00, -4.2115e-01,  4.7580e-02,
-         1.0761e+00, -2.1757e-01,  2.1229e-01,  5.7261e-02,  0.0000e+00,
-        -2.0782e-03, -1.4484e+00,  0.0000e+00,  2.2591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2499e-01,  9.7126e-01,
-         4.7087e-01,  0.0000e+00, -2.4300e-01,  0.0000e+00,  0.0000e+00,
-        -3.2012e-01,  4.9045e-01,  1.4869e-01,  8.9514e-01,  0.0000e+00,
-         3.6222e-01, -4.3391e-01, -4.2441e-01,  1.8858e-01,  1.0138e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6721e-01,  1.1092e+00,
-         0.0000e+00,  4.4375e-02, -8.7367e-03, -5.7737e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4747e-01,  2.5880e-01,  2.7776e-01, -1.2920e-01,  2.7231e+00,
-         4.7706e-01,  1.4699e-09, -1.1607e-02,  4.0159e-01, -3.1927e-01,
-         6.1222e-02, -2.8541e-01, -3.1573e-01, -2.4676e-01, -2.6360e-06,
-        -1.1963e+00, -1.9271e-04,  1.1561e-01, -4.4573e-07,  5.3592e-03,
-        -2.5868e-01,  3.2565e-01, -9.4657e-07, -4.4370e-01,  1.2411e-01,
-         1.0794e+00, -2.5003e-01,  1.8258e-01,  4.7644e-02,  1.3767e-06,
-         2.3525e-02, -1.4454e+00, -3.5240e-10,  3.0549e-01,  1.7491e-04,
-         9.4376e-03,  5.8347e-06,  0.0000e+00,  4.8075e-01,  9.7776e-01,
-         4.1098e-01, -7.7932e-04, -2.2798e-01,  1.4360e-06,  3.5122e-04,
-        -2.7862e-01,  4.9180e-01,  1.8175e-01,  9.0644e-01, -9.1416e-08,
-         3.0758e-01, -4.2692e-01, -4.1507e-01,  2.2718e-01,  1.0155e+00,
-         1.9025e-11, -7.4773e-06,  2.6891e-09,  3.3442e-01,  1.1188e+00,
-         7.1832e-09, -1.6147e-01,  5.4861e-02, -5.6404e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1475,  0.2588,  0.2778, -0.1292,  2.7231,  0.4771,  0.0000,  0.0000,
-         0.4016, -0.3193,  0.0612, -0.2854, -0.3157, -0.2468,  0.0000, -1.1963,
-         0.0000,  0.1156,  0.0000,  0.0000, -0.2587,  0.3257,  0.0000, -0.4437,
-         0.1241,  1.0794, -0.2500,  0.1826,  0.0476,  0.0000,  0.0235, -1.4454,
-         0.0000,  0.3055,  0.0000,  0.0000,  0.0000,  0.0000,  0.4808,  0.9778,
-         0.4110,  0.0000, -0.2280,  0.0000,  0.0000, -0.2786,  0.4918,  0.1818,
-         0.9064,  0.0000,  0.3076, -0.4269, -0.4151,  0.2272,  1.0155,  0.0000,
-         0.0000,  0.0000,  0.3344,  1.1188,  0.0000, -0.1615,  0.0549, -0.5640],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1475,  0.2588,  0.2778, -0.1292,  2.7231,  0.4771,  0.0000,  0.0000,
-         0.4016, -0.3193,  0.0612, -0.2854, -0.3157, -0.2468,  0.0000, -1.1963,
-         0.0000,  0.1156,  0.0000,  0.0000, -0.2587,  0.3257,  0.0000, -0.4437,
-         0.1241,  1.0794, -0.2500,  0.1826,  0.0476,  0.0000,  0.0235, -1.4454,
-         0.0000,  0.3055,  0.0000,  0.0000,  0.0000,  0.0000,  0.4808,  0.9778,
-         0.4110,  0.0000, -0.2280,  0.0000,  0.0000, -0.2786,  0.4918,  0.1818,
-         0.9064,  0.0000,  0.3076, -0.4269, -0.4151,  0.2272,  1.0155,  0.0000,
-         0.0000,  0.0000,  0.3344,  1.1188,  0.0000, -0.1615,  0.0549, -0.5640],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3775e-01,  1.7042e-01,  2.4988e-01, -1.5140e-01,  2.7253e+00,
-         4.3696e-01,  1.2576e-09, -9.9301e-03,  3.6596e-01, -2.9710e-01,
-         7.8235e-02, -2.0984e-01, -2.5707e-01, -1.6060e-01, -2.2552e-06,
-        -1.1980e+00, -1.6488e-04, -2.4578e-02, -3.8134e-07,  4.5850e-03,
-        -3.9738e-01,  4.0200e-01, -8.0984e-07, -4.7869e-01,  1.6679e-01,
-         1.0756e+00, -2.4501e-01,  1.3527e-01,  1.1726e-02,  1.1778e-06,
-        -1.4915e-02, -1.4448e+00, -3.0149e-10,  3.4816e-01,  1.4965e-04,
-         8.0743e-03,  4.9918e-06,  0.0000e+00,  5.1952e-01,  9.8283e-01,
-         3.7718e-01, -6.6674e-04, -1.7695e-01,  1.2285e-06,  3.0049e-04,
-        -2.7716e-01,  4.9637e-01,  2.6612e-01,  9.0918e-01, -7.8211e-08,
-         2.9908e-01, -4.1023e-01, -3.9323e-01,  1.7808e-01,  1.0273e+00,
-         1.6277e-11, -6.3971e-06,  2.3006e-09,  3.0368e-01,  1.1245e+00,
-         6.1456e-09, -3.2575e-01,  1.8428e-01, -5.1446e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1378,  0.1704,  0.2499, -0.1514,  2.7253,  0.4370,  0.0000,  0.0000,
-         0.3660, -0.2971,  0.0782, -0.2098, -0.2571, -0.1606,  0.0000,  0.0000,
-         0.0000, -0.0246,  0.0000,  0.0000, -0.3974,  0.4020,  0.0000, -0.4787,
-         0.1668,  1.0756, -0.2450,  0.1353,  0.0117,  0.0000, -0.0149, -1.4448,
-         0.0000,  0.3482,  0.0000,  0.0000,  0.0000,  0.0000,  0.5195,  0.9828,
-         0.3772,  0.0000, -0.1769,  0.0000,  0.0000, -0.2772,  0.4964,  0.2661,
-         0.9092,  0.0000,  0.2991, -0.4102, -0.3932,  0.1781,  1.0273,  0.0000,
-         0.0000,  0.0000,  0.3037,  1.1245,  0.0000, -0.3257,  0.1843, -0.5145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1378,  0.1704,  0.2499, -0.1514,  2.7253,  0.4370,  0.0000,  0.0000,
-         0.3660, -0.2971,  0.0782, -0.2098, -0.2571, -0.1606,  0.0000,  0.0000,
-         0.0000, -0.0246,  0.0000,  0.0000, -0.3974,  0.4020,  0.0000, -0.4787,
-         0.1668,  1.0756, -0.2450,  0.1353,  0.0117,  0.0000, -0.0149, -1.4448,
-         0.0000,  0.3482,  0.0000,  0.0000,  0.0000,  0.0000,  0.5195,  0.9828,
-         0.3772,  0.0000, -0.1769,  0.0000,  0.0000, -0.2772,  0.4964,  0.2661,
-         0.9092,  0.0000,  0.2991, -0.4102, -0.3932,  0.1781,  1.0273,  0.0000,
-         0.0000,  0.0000,  0.3037,  1.1245,  0.0000, -0.3257,  0.1843, -0.5145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4441e-01,  5.3463e-02,  2.4305e-01, -9.3927e-02,  2.7255e+00,
-         4.2809e-01,  1.0762e-09, -8.4976e-03,  3.1551e-01, -2.9110e-01,
-        -4.5661e-02, -1.6388e-01, -2.2305e-01, -3.0147e-02, -1.9299e-06,
-        -1.4344e-03, -1.4109e-04, -1.5792e-01, -3.2633e-07,  3.9236e-03,
-        -5.1015e-01,  4.6752e-01, -6.9301e-07, -5.4081e-01,  1.7424e-01,
-         1.0796e+00, -1.9255e-01,  1.1037e-01, -9.3539e-02,  1.0079e-06,
-        -2.8732e-02, -1.4460e+00, -2.5800e-10,  3.4710e-01,  1.2806e-04,
-         6.9095e-03,  4.2717e-06,  0.0000e+00,  5.3135e-01,  9.8822e-01,
-         2.9936e-01, -5.7056e-04, -9.4871e-02,  1.0513e-06,  2.5714e-04,
-        -3.2398e-01,  5.0008e-01,  3.2340e-01,  9.1593e-01, -6.6928e-08,
-         2.4303e-01, -4.6766e-01, -4.1479e-01,  2.2508e-01,  1.0406e+00,
-         1.3929e-11, -5.4743e-06,  1.9688e-09,  2.6128e-01,  1.1341e+00,
-         5.2590e-09, -4.5035e-01,  2.7184e-01, -4.5637e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1444,  0.0535,  0.2431, -0.0939,  2.7255,  0.4281,  0.0000,  0.0000,
-         0.3155, -0.2911, -0.0457, -0.1639, -0.2230, -0.0301,  0.0000,  0.0000,
-         0.0000, -0.1579,  0.0000,  0.0000, -0.5102,  0.4675,  0.0000, -0.5408,
-         0.1742,  1.0796, -0.1926,  0.1104, -0.0935,  0.0000, -0.0287, -1.4460,
-         0.0000,  0.3471,  0.0000,  0.0000,  0.0000,  0.0000,  0.5314,  0.9882,
-         0.2994,  0.0000, -0.0949,  0.0000,  0.0000, -0.3240,  0.5001,  0.3234,
-         0.9159,  0.0000,  0.2430, -0.4677, -0.4148,  0.2251,  1.0406,  0.0000,
-         0.0000,  0.0000,  0.2613,  1.1341,  0.0000, -0.4503,  0.2718, -0.4564],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1444,  0.0535,  0.2431, -0.0939,  2.7255,  0.4281,  0.0000,  0.0000,
-         0.3155, -0.2911, -0.0457, -0.1639, -0.2230, -0.0301,  0.0000,  0.0000,
-         0.0000, -0.1579,  0.0000,  0.0000, -0.5102,  0.4675,  0.0000, -0.5408,
-         0.1742,  1.0796, -0.1926,  0.1104, -0.0935,  0.0000, -0.0287, -1.4460,
-         0.0000,  0.3471,  0.0000,  0.0000,  0.0000,  0.0000,  0.5314,  0.9882,
-         0.2994,  0.0000, -0.0949,  0.0000,  0.0000, -0.3240,  0.5001,  0.3234,
-         0.9159,  0.0000,  0.2430, -0.4677, -0.4148,  0.2251,  1.0406,  0.0000,
-         0.0000,  0.0000,  0.2613,  1.1341,  0.0000, -0.4503,  0.2718, -0.4564],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0892e-01, -5.0623e-02,  1.9014e-01, -6.4007e-02,  2.7256e+00,
-         4.0215e-01,  9.2113e-10, -7.2734e-03,  2.5854e-01, -2.2957e-01,
-        -2.0593e-01, -2.1027e-01, -2.2740e-01,  2.1232e-02, -1.6519e-06,
-        -1.2277e-03, -1.2076e-04, -2.4978e-01, -2.7932e-07,  3.3584e-03,
-        -6.0895e-01,  5.6245e-01, -5.9317e-07, -6.0964e-01,  2.0902e-01,
-         1.0860e+00, -1.6120e-01,  7.7727e-02, -2.2549e-01,  8.6269e-07,
-        -1.6265e-02, -1.4469e+00, -2.2083e-10,  3.4150e-01,  1.0961e-04,
-         5.9141e-03,  3.6563e-06,  0.0000e+00,  5.0018e-01,  9.8313e-01,
-         2.5468e-01, -4.8836e-04, -4.7183e-02,  8.9985e-07,  2.2009e-04,
-        -3.3648e-01,  5.0457e-01,  3.3175e-01,  9.2041e-01, -5.7287e-08,
-         2.0069e-01, -5.0893e-01, -3.9296e-01,  2.6548e-01,  1.0424e+00,
-         1.1922e-11, -4.6857e-06,  1.6851e-09,  2.3007e-01,  1.1444e+00,
-         4.5014e-09, -4.9671e-01,  3.2431e-01, -4.4017e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1089, -0.0506,  0.1901, -0.0640,  2.7256,  0.4022,  0.0000,  0.0000,
-         0.2585, -0.2296, -0.2059, -0.2103, -0.2274,  0.0212,  0.0000,  0.0000,
-         0.0000, -0.2498,  0.0000,  0.0000, -0.6090,  0.5625,  0.0000, -0.6096,
-         0.2090,  1.0860, -0.1612,  0.0777, -0.2255,  0.0000, -0.0163, -1.4469,
-         0.0000,  0.3415,  0.0000,  0.0000,  0.0000,  0.0000,  0.5002,  0.9831,
-         0.2547,  0.0000, -0.0472,  0.0000,  0.0000, -0.3365,  0.5046,  0.3318,
-         0.9204,  0.0000,  0.2007, -0.5089, -0.3930,  0.2655,  1.0424,  0.0000,
-         0.0000,  0.0000,  0.2301,  1.1444,  0.0000, -0.4967,  0.3243, -0.4402],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1089, -0.0506,  0.1901, -0.0640,  2.7256,  0.4022,  0.0000,  0.0000,
-         0.2585, -0.2296, -0.2059, -0.2103, -0.2274,  0.0212,  0.0000,  0.0000,
-         0.0000, -0.2498,  0.0000,  0.0000, -0.6090,  0.5625,  0.0000, -0.6096,
-         0.2090,  1.0860, -0.1612,  0.0777, -0.2255,  0.0000, -0.0163, -1.4469,
-         0.0000,  0.3415,  0.0000,  0.0000,  0.0000,  0.0000,  0.5002,  0.9831,
-         0.2547,  0.0000, -0.0472,  0.0000,  0.0000, -0.3365,  0.5046,  0.3318,
-         0.9204,  0.0000,  0.2007, -0.5089, -0.3930,  0.2655,  1.0424,  0.0000,
-         0.0000,  0.0000,  0.2301,  1.1444,  0.0000, -0.4967,  0.3243, -0.4402],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0077e-01, -9.7602e-02,  1.6039e-01, -5.8889e-02,  2.7230e+00,
-         3.7489e-01,  7.8862e-10, -6.2270e-03,  2.3456e-01, -2.4798e-01,
-        -5.4535e-02, -2.3648e-01, -2.1779e-01, -3.6407e-02, -1.4142e-06,
-        -1.0511e-03, -1.0339e-04, -1.7114e-01, -2.3914e-07,  2.8752e-03,
-        -6.8931e-01,  6.5400e-01, -5.0784e-07, -6.1865e-01,  2.6618e-01,
-         1.0904e+00, -1.9180e-01,  5.2599e-02, -2.5756e-01,  7.3858e-07,
-        -1.0909e-01, -1.4444e+00, -1.8906e-10,  3.1947e-01,  9.3841e-05,
-         5.0633e-03,  3.1303e-06,  0.0000e+00,  4.4377e-01,  9.6500e-01,
-         2.2279e-01, -4.1811e-04,  2.1944e-02,  7.7040e-07,  1.8843e-04,
-        -3.0929e-01,  5.0590e-01,  3.0948e-01,  9.2250e-01, -4.9045e-08,
-         2.2910e-01, -4.6234e-01, -3.3872e-01,  2.0925e-01,  1.0285e+00,
-         1.0207e-11, -4.0116e-06,  1.4427e-09,  2.0616e-01,  1.1426e+00,
-         3.8538e-09, -5.4856e-01,  3.7626e-01, -4.1290e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1008, -0.0976,  0.1604, -0.0589,  2.7230,  0.3749,  0.0000,  0.0000,
-         0.2346, -0.2480, -0.0545, -0.2365, -0.2178, -0.0364,  0.0000,  0.0000,
-         0.0000, -0.1711,  0.0000,  0.0000, -0.6893,  0.6540,  0.0000, -0.6187,
-         0.2662,  1.0904, -0.1918,  0.0526, -0.2576,  0.0000, -0.1091, -1.4444,
-         0.0000,  0.3195,  0.0000,  0.0000,  0.0000,  0.0000,  0.4438,  0.9650,
-         0.2228,  0.0000,  0.0219,  0.0000,  0.0000, -0.3093,  0.5059,  0.3095,
-         0.9225,  0.0000,  0.2291, -0.4623, -0.3387,  0.2092,  1.0285,  0.0000,
-         0.0000,  0.0000,  0.2062,  1.1426,  0.0000, -0.5486,  0.3763, -0.4129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1008, -0.0976,  0.1604, -0.0589,  2.7230,  0.3749,  0.0000,  0.0000,
-         0.2346, -0.2480, -0.0545, -0.2365, -0.2178, -0.0364,  0.0000,  0.0000,
-         0.0000, -0.1711,  0.0000,  0.0000, -0.6893,  0.6540,  0.0000, -0.6187,
-         0.2662,  1.0904, -0.1918,  0.0526, -0.2576,  0.0000, -0.1091, -1.4444,
-         0.0000,  0.3195,  0.0000,  0.0000,  0.0000,  0.0000,  0.4438,  0.9650,
-         0.2228,  0.0000,  0.0219,  0.0000,  0.0000, -0.3093,  0.5059,  0.3095,
-         0.9225,  0.0000,  0.2291, -0.4623, -0.3387,  0.2092,  1.0285,  0.0000,
-         0.0000,  0.0000,  0.2062,  1.1426,  0.0000, -0.5486,  0.3763, -0.4129],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0843e-01, -6.9421e-02,  1.4900e-01, -1.0442e-01,  2.7227e+00,
-         3.5106e-01,  6.7533e-10, -5.3325e-03,  1.7863e-01, -2.2125e-01,
-         2.0709e-01, -2.9004e-01, -2.3644e-01, -1.9440e-01, -1.2111e-06,
-        -9.0010e-04, -8.8538e-05, -9.1089e-02, -2.0478e-07,  2.4622e-03,
-        -7.7383e-01,  7.2631e-01, -4.3488e-07, -6.4622e-01,  3.0499e-01,
-         1.0867e+00, -2.3092e-01, -1.9291e-02, -3.1787e-01,  6.3248e-07,
-        -2.5053e-01, -1.4455e+00, -1.6190e-10,  2.6691e-01,  8.0360e-05,
-         4.3359e-03,  2.6806e-06,  0.0000e+00,  3.5932e-01,  9.3889e-01,
-         1.8620e-01, -3.5804e-04,  9.0788e-02,  6.5972e-07,  1.6136e-04,
-        -2.7528e-01,  4.8699e-01,  2.9125e-01,  9.2230e-01, -4.1999e-08,
-         2.9150e-01, -4.2251e-01, -2.6801e-01,  2.2634e-01,  1.0095e+00,
-         8.7408e-12, -3.4353e-06,  1.2355e-09,  1.8177e-01,  1.1431e+00,
-         3.3002e-09, -5.6880e-01,  4.3173e-01, -4.1387e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1084, -0.0694,  0.1490, -0.1044,  2.7227,  0.3511,  0.0000,  0.0000,
-         0.1786, -0.2213,  0.2071, -0.2900, -0.2364, -0.1944,  0.0000,  0.0000,
-         0.0000, -0.0911,  0.0000,  0.0000, -0.7738,  0.7263,  0.0000, -0.6462,
-         0.3050,  1.0867, -0.2309, -0.0193, -0.3179,  0.0000, -0.2505, -1.4455,
-         0.0000,  0.2669,  0.0000,  0.0000,  0.0000,  0.0000,  0.3593,  0.9389,
-         0.1862,  0.0000,  0.0908,  0.0000,  0.0000, -0.2753,  0.4870,  0.2913,
-         0.9223,  0.0000,  0.2915, -0.4225, -0.2680,  0.2263,  1.0095,  0.0000,
-         0.0000,  0.0000,  0.1818,  1.1431,  0.0000, -0.5688,  0.4317, -0.4139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1084, -0.0694,  0.1490, -0.1044,  2.7227,  0.3511,  0.0000,  0.0000,
-         0.1786, -0.2213,  0.2071, -0.2900, -0.2364, -0.1944,  0.0000,  0.0000,
-         0.0000, -0.0911,  0.0000,  0.0000, -0.7738,  0.7263,  0.0000, -0.6462,
-         0.3050,  1.0867, -0.2309, -0.0193, -0.3179,  0.0000, -0.2505, -1.4455,
-         0.0000,  0.2669,  0.0000,  0.0000,  0.0000,  0.0000,  0.3593,  0.9389,
-         0.1862,  0.0000,  0.0908,  0.0000,  0.0000, -0.2753,  0.4870,  0.2913,
-         0.9223,  0.0000,  0.2915, -0.4225, -0.2680,  0.2263,  1.0095,  0.0000,
-         0.0000,  0.0000,  0.1818,  1.1431,  0.0000, -0.5688,  0.4317, -0.4139],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.7547e-01, -4.1972e-02,  1.4288e-01, -1.0000e-01,  2.7193e+00,
-         3.3817e-01,  5.7845e-10, -4.5675e-03,  9.2131e-02, -9.5394e-02,
-         2.9351e-01, -3.9305e-01, -2.4236e-01, -3.2972e-01, -1.0373e-06,
-        -7.7098e-04, -7.5837e-05, -1.1179e-02, -1.7541e-07,  2.1090e-03,
-        -8.2351e-01,  7.9083e-01, -3.7250e-07, -6.2581e-01,  3.1280e-01,
-         1.0822e+00, -2.0750e-01, -7.4404e-02, -3.1729e-01,  5.4174e-07,
-        -2.2476e-01, -1.4466e+00, -1.3868e-10,  1.9755e-01,  6.8832e-05,
-         3.7139e-03,  2.2961e-06,  0.0000e+00,  2.6670e-01,  8.9971e-01,
-         1.0476e-01, -3.0668e-04,  1.5639e-01,  5.6509e-07,  1.3821e-04,
-        -3.3379e-01,  4.6749e-01,  2.2909e-01,  9.4535e-01, -3.5975e-08,
-         2.8409e-01, -3.6624e-01, -2.4897e-01,  3.5014e-01,  9.8552e-01,
-         7.4869e-12, -2.9425e-06,  1.0582e-09,  1.3642e-01,  1.1371e+00,
-         2.8268e-09, -5.3266e-01,  4.4751e-01, -4.2522e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1755, -0.0420,  0.1429, -0.1000,  2.7193,  0.3382,  0.0000,  0.0000,
-         0.0921, -0.0954,  0.2935, -0.3931, -0.2424, -0.3297,  0.0000,  0.0000,
-         0.0000, -0.0112,  0.0000,  0.0000, -0.8235,  0.7908,  0.0000, -0.6258,
-         0.3128,  1.0822, -0.2075, -0.0744, -0.3173,  0.0000, -0.2248, -1.4466,
-         0.0000,  0.1975,  0.0000,  0.0000,  0.0000,  0.0000,  0.2667,  0.8997,
-         0.1048,  0.0000,  0.1564,  0.0000,  0.0000, -0.3338,  0.0000,  0.2291,
-         0.9453,  0.0000,  0.2841, -0.3662, -0.2490,  0.3501,  0.9855,  0.0000,
-         0.0000,  0.0000,  0.1364,  1.1371,  0.0000, -0.5327,  0.4475, -0.4252],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1755, -0.0420,  0.1429, -0.1000,  2.7193,  0.3382,  0.0000,  0.0000,
-         0.0921, -0.0954,  0.2935, -0.3931, -0.2424, -0.3297,  0.0000,  0.0000,
-         0.0000, -0.0112,  0.0000,  0.0000, -0.8235,  0.7908,  0.0000, -0.6258,
-         0.3128,  1.0822, -0.2075, -0.0744, -0.3173,  0.0000, -0.2248, -1.4466,
-         0.0000,  0.1975,  0.0000,  0.0000,  0.0000,  0.0000,  0.2667,  0.8997,
-         0.1048,  0.0000,  0.1564,  0.0000,  0.0000, -0.3338,  0.0000,  0.2291,
-         0.9453,  0.0000,  0.2841, -0.3662, -0.2490,  0.3501,  0.9855,  0.0000,
-         0.0000,  0.0000,  0.1364,  1.1371,  0.0000, -0.5327,  0.4475, -0.4252],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1371e-01, -2.6869e-02,  7.7750e-02, -1.0547e-01,  2.7176e+00,
-         3.2486e-01,  4.9559e-10, -3.9132e-03, -4.2963e-02, -2.9258e-02,
-         2.9280e-01, -4.4309e-01, -2.6921e-01, -3.9869e-01, -8.8873e-07,
-        -6.6054e-04, -6.4974e-05,  1.3126e-02, -1.5028e-07,  1.8069e-03,
-        -8.5762e-01,  7.3957e-01, -3.1914e-07, -6.6854e-01,  1.8483e-01,
-         1.0770e+00, -1.8161e-01, -1.4929e-01, -3.3667e-01,  4.6414e-07,
-        -2.3617e-01, -1.4476e+00, -1.1881e-10, -8.7138e-03,  5.8972e-05,
-         3.1819e-03,  1.9672e-06,  0.0000e+00,  2.7619e-01,  8.6511e-01,
-        -1.6681e-02, -2.6275e-04,  1.9125e-01,  4.8414e-07,  1.1842e-04,
-        -2.6737e-01, -1.6708e-02,  3.0226e-01,  9.7715e-01, -3.0821e-08,
-         2.5838e-01, -3.5294e-01, -1.9338e-01,  2.2639e-01,  9.6249e-01,
-         6.4144e-12, -2.5210e-06,  9.0664e-10,  1.6967e-01,  1.1374e+00,
-         2.4218e-09, -4.7681e-01,  4.3146e-01, -4.2991e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2137, -0.0269,  0.0778, -0.1055,  2.7176,  0.3249,  0.0000,  0.0000,
-        -0.0430, -0.0293,  0.2928, -0.4431, -0.2692, -0.3987,  0.0000,  0.0000,
-         0.0000,  0.0131,  0.0000,  0.0000, -0.8576,  0.7396,  0.0000, -0.6685,
-         0.1848,  1.0770, -0.1816, -0.1493, -0.3367,  0.0000, -0.2362, -1.4476,
-         0.0000, -0.0087,  0.0000,  0.0000,  0.0000,  0.0000,  0.2762,  0.8651,
-        -0.0167,  0.0000,  0.1912,  0.0000,  0.0000, -0.2674,  0.0000,  0.3023,
-         0.9772,  0.0000,  0.2584, -0.3529, -0.1934,  0.2264,  0.9625,  0.0000,
-         0.0000,  0.0000,  0.1697,  1.1374,  0.0000, -0.4768,  0.4315, -0.4299],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2137, -0.0269,  0.0778, -0.1055,  2.7176,  0.3249,  0.0000,  0.0000,
-        -0.0430, -0.0293,  0.2928, -0.4431, -0.2692, -0.3987,  0.0000,  0.0000,
-         0.0000,  0.0131,  0.0000,  0.0000, -0.8576,  0.7396,  0.0000, -0.6685,
-         0.1848,  1.0770, -0.1816, -0.1493, -0.3367,  0.0000, -0.2362, -1.4476,
-         0.0000, -0.0087,  0.0000,  0.0000,  0.0000,  0.0000,  0.2762,  0.8651,
-        -0.0167,  0.0000,  0.1912,  0.0000,  0.0000, -0.2674,  0.0000,  0.3023,
-         0.9772,  0.0000,  0.2584, -0.3529, -0.1934,  0.2264,  0.9625,  0.0000,
-         0.0000,  0.0000,  0.1697,  1.1374,  0.0000, -0.4768,  0.4315, -0.4299],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4871e-01, -4.7986e-02,  6.1407e-02, -4.7337e-02,  2.7156e+00,
-         3.0784e-01,  4.2470e-10, -3.3535e-03, -1.2093e-01, -1.3794e-02,
-         2.0347e-01, -4.5561e-01, -2.6295e-01, -4.3799e-01, -7.6162e-07,
-        -5.6606e-04, -5.5681e-05,  1.0214e-02, -1.2879e-07,  1.5484e-03,
-        -8.7576e-01,  6.6823e-01, -2.7349e-07, -7.0658e-01,  4.2610e-02,
-         1.0705e+00, -1.3507e-01, -2.1437e-01, -3.4740e-01,  3.9776e-07,
-        -2.2285e-01, -1.4453e+00, -1.0182e-10, -1.5795e-01,  5.0537e-05,
-         2.7268e-03,  1.6858e-06,  0.0000e+00,  2.8467e-01,  8.4643e-01,
-        -1.7527e-01, -2.2517e-04,  2.2805e-01,  4.1489e-07,  1.0148e-04,
-        -1.7095e-01, -1.4318e-02,  3.6932e-01,  9.9764e-01, -2.6413e-08,
-         1.7752e-01, -3.1618e-01, -1.4959e-01,  1.0415e-01,  9.6082e-01,
-         5.4970e-12, -2.1604e-06,  7.7696e-10,  1.9751e-01,  1.1420e+00,
-         2.0754e-09, -3.6488e-01,  3.8071e-01, -3.8700e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2487, -0.0480,  0.0614, -0.0473,  2.7156,  0.3078,  0.0000,  0.0000,
-        -0.1209, -0.0138,  0.2035, -0.4556, -0.2629, -0.4380,  0.0000,  0.0000,
-         0.0000,  0.0102,  0.0000,  0.0000, -0.8758,  0.6682,  0.0000, -0.7066,
-         0.0426,  1.0705, -0.1351, -0.2144, -0.3474,  0.0000, -0.2229, -1.4453,
-         0.0000, -0.1579,  0.0000,  0.0000,  0.0000,  0.0000,  0.2847,  0.8464,
-        -0.1753,  0.0000,  0.2281,  0.0000,  0.0000, -0.1709,  0.0000,  0.3693,
-         0.9976,  0.0000,  0.1775, -0.3162, -0.1496,  0.1042,  0.9608,  0.0000,
-         0.0000,  0.0000,  0.1975,  1.1420,  0.0000, -0.3649,  0.3807, -0.3870],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2487, -0.0480,  0.0614, -0.0473,  2.7156,  0.3078,  0.0000,  0.0000,
-        -0.1209, -0.0138,  0.2035, -0.4556, -0.2629, -0.4380,  0.0000,  0.0000,
-         0.0000,  0.0102,  0.0000,  0.0000, -0.8758,  0.6682,  0.0000, -0.7066,
-         0.0426,  1.0705, -0.1351, -0.2144, -0.3474,  0.0000, -0.2229, -1.4453,
-         0.0000, -0.1579,  0.0000,  0.0000,  0.0000,  0.0000,  0.2847,  0.8464,
-        -0.1753,  0.0000,  0.2281,  0.0000,  0.0000, -0.1709,  0.0000,  0.3693,
-         0.9976,  0.0000,  0.1775, -0.3162, -0.1496,  0.1042,  0.9608,  0.0000,
-         0.0000,  0.0000,  0.1975,  1.1420,  0.0000, -0.3649,  0.3807, -0.3870],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.2867e-01, -1.5250e-02, -3.7444e-02, -4.9020e-02,  2.7123e+00,
-         2.4129e-01,  3.6405e-10, -2.8746e-03, -2.2854e-01,  1.5173e-02,
-         1.0161e-01, -4.9490e-01, -3.0574e-01, -4.6745e-01, -6.5284e-07,
-        -4.8522e-04, -4.7729e-05, -4.2990e-02, -1.1039e-07,  1.3273e-03,
-        -8.7459e-01,  6.4497e-01, -2.3443e-07, -7.5085e-01,  7.6758e-03,
-         1.0710e+00, -8.2192e-02, -2.4681e-01, -3.0131e-01,  3.4095e-07,
-        -1.3223e-01, -1.4415e+00, -8.7276e-11, -2.8222e-01,  4.3320e-05,
-         2.3374e-03,  1.4450e-06,  0.0000e+00,  3.0921e-01,  8.1886e-01,
-        -2.2082e-01, -1.9301e-04,  2.3073e-01,  3.5564e-07,  8.6985e-05,
-        -1.0220e-01, -1.2273e-02,  3.9944e-01,  1.0259e+00, -2.2641e-08,
-         9.8165e-02, -2.4832e-01, -9.6035e-02, -7.4285e-02,  9.5476e-01,
-         4.7119e-12, -1.8519e-06,  6.6600e-10,  2.0447e-01,  1.1440e+00,
-         1.7790e-09, -3.5145e-01,  3.6884e-01, -3.9327e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2287, -0.0153, -0.0374, -0.0490,  2.7123,  0.2413,  0.0000,  0.0000,
-        -0.2285,  0.0152,  0.1016, -0.4949, -0.3057, -0.4675,  0.0000,  0.0000,
-         0.0000, -0.0430,  0.0000,  0.0000, -0.8746,  0.6450,  0.0000, -0.7509,
-         0.0077,  1.0710, -0.0822, -0.2468, -0.3013,  0.0000, -0.1322, -1.4415,
-         0.0000, -0.2822,  0.0000,  0.0000,  0.0000,  0.0000,  0.3092,  0.8189,
-        -0.2208,  0.0000,  0.2307,  0.0000,  0.0000, -0.1022,  0.0000,  0.3994,
-         1.0259,  0.0000,  0.0982, -0.2483, -0.0960, -0.0743,  0.9548,  0.0000,
-         0.0000,  0.0000,  0.2045,  1.1440,  0.0000, -0.3514,  0.3688, -0.3933],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2287, -0.0153, -0.0374, -0.0490,  2.7123,  0.2413,  0.0000,  0.0000,
-        -0.2285,  0.0152,  0.1016, -0.4949, -0.3057, -0.4675,  0.0000,  0.0000,
-         0.0000, -0.0430,  0.0000,  0.0000, -0.8746,  0.6450,  0.0000, -0.7509,
-         0.0077,  1.0710, -0.0822, -0.2468, -0.3013,  0.0000, -0.1322, -1.4415,
-         0.0000, -0.2822,  0.0000,  0.0000,  0.0000,  0.0000,  0.3092,  0.8189,
-        -0.2208,  0.0000,  0.2307,  0.0000,  0.0000, -0.1022,  0.0000,  0.3994,
-         1.0259,  0.0000,  0.0982, -0.2483, -0.0960, -0.0743,  0.9548,  0.0000,
-         0.0000,  0.0000,  0.2045,  1.1440,  0.0000, -0.3514,  0.3688, -0.3933],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4716e-01, -1.0358e-03, -9.9892e-02, -1.9029e-05,  2.7086e+00,
-         2.1204e-01,  3.1214e-10, -2.4647e-03, -3.0539e-01,  3.6697e-02,
-         3.3641e-02, -5.1736e-01, -3.1680e-01, -4.8411e-01, -5.5975e-07,
-        -4.1602e-04, -4.0922e-05, -6.5540e-02, -9.4651e-08,  1.1380e-03,
-        -8.6322e-01,  6.2180e-01, -2.0100e-07, -7.7742e-01, -2.8044e-02,
-         1.0749e+00, -1.6849e-02, -2.5378e-01, -2.0055e-01,  2.9233e-07,
-        -5.5885e-02, -1.4341e+00, -7.4831e-11, -3.9203e-01,  3.7142e-05,
-         2.0041e-03,  1.2390e-06,  0.0000e+00,  3.3545e-01,  8.0387e-01,
-        -2.5088e-01, -1.6549e-04,  2.4577e-01,  3.0492e-07,  7.4581e-05,
-        -1.2534e-01, -1.0523e-02,  4.1276e-01,  1.0505e+00, -1.9412e-08,
-        -4.5639e-03, -1.5887e-01, -1.2574e-01, -2.1929e-01,  9.5060e-01,
-         4.0400e-12, -1.5878e-06,  5.7103e-10,  1.9028e-01,  1.1417e+00,
-         1.5253e-09, -3.4495e-01,  3.4884e-01, -4.0268e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.4716e-01, -1.0358e-03, -9.9892e-02, -1.9029e-05,  2.7086e+00,
-         2.1204e-01,  0.0000e+00,  0.0000e+00, -3.0539e-01,  3.6697e-02,
-         3.3641e-02, -5.1736e-01, -3.1680e-01, -4.8411e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.5540e-02,  0.0000e+00,  0.0000e+00,
-        -8.6322e-01,  6.2180e-01,  0.0000e+00, -7.7742e-01, -2.8044e-02,
-         1.0749e+00, -1.6849e-02, -2.5378e-01, -2.0055e-01,  0.0000e+00,
-        -5.5885e-02, -1.4341e+00,  0.0000e+00, -3.9203e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.3545e-01,  8.0387e-01,
-        -2.5088e-01,  0.0000e+00,  2.4577e-01,  0.0000e+00,  0.0000e+00,
-        -1.2534e-01,  0.0000e+00,  4.1276e-01,  1.0505e+00,  0.0000e+00,
-        -4.5639e-03, -1.5887e-01, -1.2574e-01, -2.1929e-01,  9.5060e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9028e-01,  1.1417e+00,
-         0.0000e+00, -3.4495e-01,  3.4884e-01, -4.0268e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.4716e-01, -1.0358e-03, -9.9892e-02, -1.9029e-05,  2.7086e+00,
-         2.1204e-01,  0.0000e+00,  0.0000e+00, -3.0539e-01,  3.6697e-02,
-         3.3641e-02, -5.1736e-01, -3.1680e-01, -4.8411e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.5540e-02,  0.0000e+00,  0.0000e+00,
-        -8.6322e-01,  6.2180e-01,  0.0000e+00, -7.7742e-01, -2.8044e-02,
-         1.0749e+00, -1.6849e-02, -2.5378e-01, -2.0055e-01,  0.0000e+00,
-        -5.5885e-02, -1.4341e+00,  0.0000e+00, -3.9203e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.3545e-01,  8.0387e-01,
-        -2.5088e-01,  0.0000e+00,  2.4577e-01,  0.0000e+00,  0.0000e+00,
-        -1.2534e-01,  0.0000e+00,  4.1276e-01,  1.0505e+00,  0.0000e+00,
-        -4.5639e-03, -1.5887e-01, -1.2574e-01, -2.1929e-01,  9.5060e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9028e-01,  1.1417e+00,
-         0.0000e+00, -3.4495e-01,  3.4884e-01, -4.0268e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9035e-01,  3.3389e-03, -8.9543e-02,  1.4478e-01,  2.7035e+00,
-         1.9249e-01,  2.6769e-10, -2.1137e-03, -3.4407e-01,  2.2397e-02,
-         2.5902e-02, -5.1962e-01, -2.5383e-01, -4.8357e-01, -4.8005e-07,
-        -3.5679e-04, -3.5096e-05, -1.2338e-02, -8.1174e-08,  9.7598e-04,
-        -8.4771e-01,  6.2627e-01, -1.7238e-07, -8.2172e-01, -2.3617e-02,
-         1.0763e+00,  4.3004e-02, -2.1612e-01, -1.2854e-01,  2.5071e-07,
-        -1.0376e-01, -1.4263e+00, -6.4176e-11, -4.5838e-01,  3.1854e-05,
-         1.7187e-03,  1.0626e-06,  0.0000e+00,  3.3162e-01,  7.8801e-01,
-        -2.3809e-01, -1.4192e-04,  2.5768e-01,  2.6151e-07,  6.3962e-05,
-        -1.5032e-01, -9.0249e-03,  3.9810e-01,  1.0645e+00, -1.6648e-08,
-        -7.8823e-02, -5.0953e-02, -2.3022e-01, -2.2067e-01,  9.4399e-01,
-         3.4648e-12, -1.3617e-06,  4.8972e-10,  1.0936e-01,  1.1370e+00,
-         1.3082e-09, -3.2614e-01,  3.0548e-01, -3.8484e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2903,  0.0033, -0.0895,  0.1448,  2.7035,  0.1925,  0.0000,  0.0000,
-        -0.3441,  0.0224,  0.0259, -0.5196, -0.2538, -0.4836,  0.0000,  0.0000,
-         0.0000, -0.0123,  0.0000,  0.0000, -0.8477,  0.6263,  0.0000, -0.8217,
-        -0.0236,  0.0000,  0.0430, -0.2161, -0.1285,  0.0000, -0.1038, -1.4263,
-         0.0000, -0.4584,  0.0000,  0.0000,  0.0000,  0.0000,  0.3316,  0.7880,
-        -0.2381,  0.0000,  0.2577,  0.0000,  0.0000, -0.1503,  0.0000,  0.3981,
-         1.0645,  0.0000, -0.0788, -0.0510, -0.2302, -0.2207,  0.9440,  0.0000,
-         0.0000,  0.0000,  0.1094,  1.1370,  0.0000, -0.3261,  0.3055, -0.3848],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2903,  0.0033, -0.0895,  0.1448,  2.7035,  0.1925,  0.0000,  0.0000,
-        -0.3441,  0.0224,  0.0259, -0.5196, -0.2538, -0.4836,  0.0000,  0.0000,
-         0.0000, -0.0123,  0.0000,  0.0000, -0.8477,  0.6263,  0.0000, -0.8217,
-        -0.0236,  0.0000,  0.0430, -0.2161, -0.1285,  0.0000, -0.1038, -1.4263,
-         0.0000, -0.4584,  0.0000,  0.0000,  0.0000,  0.0000,  0.3316,  0.7880,
-        -0.2381,  0.0000,  0.2577,  0.0000,  0.0000, -0.1503,  0.0000,  0.3981,
-         1.0645,  0.0000, -0.0788, -0.0510, -0.2302, -0.2207,  0.9440,  0.0000,
-         0.0000,  0.0000,  0.1094,  1.1370,  0.0000, -0.3261,  0.3055, -0.3848],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9569e-01,  3.4081e-03, -1.0599e-01,  2.0427e-01,  2.6993e+00,
-         1.4784e-01,  2.2964e-10, -1.8133e-03, -3.6219e-01, -2.4117e-02,
-         1.1007e-01, -4.9826e-01, -2.1290e-01, -4.9086e-01, -4.1181e-07,
-        -3.0607e-04, -3.0107e-05,  6.6141e-02, -6.9635e-08,  8.3724e-04,
-        -8.2476e-01,  6.1421e-01, -1.4788e-07, -8.5635e-01,  4.1950e-03,
-         1.2002e-03,  7.2389e-02, -1.5932e-01, -6.7514e-02,  2.1507e-07,
-        -1.7342e-01, -1.4171e+00, -5.5053e-11, -4.9000e-01,  2.7326e-05,
-         1.4744e-03,  9.1152e-07,  0.0000e+00,  3.0484e-01,  7.9335e-01,
-        -1.7200e-01, -1.2175e-04,  2.3440e-01,  2.2433e-07,  5.4870e-05,
-        -1.4847e-01, -7.7420e-03,  3.9458e-01,  1.0804e+00, -1.4282e-08,
-        -1.2079e-01,  7.9625e-02, -3.0224e-01, -2.2134e-01,  9.4535e-01,
-         2.9722e-12, -1.1681e-06,  4.2011e-10,  1.7522e-02,  1.1286e+00,
-         1.1222e-09, -3.1601e-01,  3.0646e-01, -3.8071e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2957,  0.0034, -0.1060,  0.2043,  2.6993,  0.1478,  0.0000,  0.0000,
-        -0.3622, -0.0241,  0.1101, -0.4983, -0.2129, -0.4909,  0.0000,  0.0000,
-         0.0000,  0.0661,  0.0000,  0.0000, -0.8248,  0.6142,  0.0000, -0.8563,
-         0.0042,  0.0000,  0.0724, -0.1593, -0.0675,  0.0000, -0.1734, -1.4171,
-         0.0000, -0.4900,  0.0000,  0.0000,  0.0000,  0.0000,  0.3048,  0.7933,
-        -0.1720,  0.0000,  0.2344,  0.0000,  0.0000, -0.1485,  0.0000,  0.3946,
-         1.0804,  0.0000, -0.1208,  0.0796, -0.3022, -0.2213,  0.9454,  0.0000,
-         0.0000,  0.0000,  0.0175,  1.1286,  0.0000, -0.3160,  0.3065, -0.3807],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2957,  0.0034, -0.1060,  0.2043,  2.6993,  0.1478,  0.0000,  0.0000,
-        -0.3622, -0.0241,  0.1101, -0.4983, -0.2129, -0.4909,  0.0000,  0.0000,
-         0.0000,  0.0661,  0.0000,  0.0000, -0.8248,  0.6142,  0.0000, -0.8563,
-         0.0042,  0.0000,  0.0724, -0.1593, -0.0675,  0.0000, -0.1734, -1.4171,
-         0.0000, -0.4900,  0.0000,  0.0000,  0.0000,  0.0000,  0.3048,  0.7933,
-        -0.1720,  0.0000,  0.2344,  0.0000,  0.0000, -0.1485,  0.0000,  0.3946,
-         1.0804,  0.0000, -0.1208,  0.0796, -0.3022, -0.2213,  0.9454,  0.0000,
-         0.0000,  0.0000,  0.0175,  1.1286,  0.0000, -0.3160,  0.3065, -0.3807],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8308e-01, -8.8828e-03, -1.1465e-01,  2.1035e-01,  2.6929e+00,
-         1.5552e-01,  1.9705e-10, -1.5559e-03, -3.7568e-01, -1.0373e-01,
-         2.8768e-01, -4.4622e-01, -1.7766e-01, -4.6545e-01, -3.5336e-07,
-        -2.6263e-04, -2.5834e-05,  8.4867e-02, -5.9751e-08,  7.1841e-04,
-        -7.8696e-01,  5.6132e-01, -1.2689e-07, -8.8006e-01, -4.8216e-04,
-         1.0298e-03,  1.0271e-01, -8.6071e-02, -5.2635e-03,  1.8454e-07,
-        -3.2834e-01, -1.4089e+00, -4.7239e-11, -5.2398e-01,  2.3447e-05,
-         1.2651e-03,  7.8215e-07,  0.0000e+00,  3.0027e-01,  7.9244e-01,
-        -6.1583e-02, -1.0447e-04,  2.0611e-01,  1.9249e-07,  4.7082e-05,
-        -1.8831e-01, -6.6431e-03,  4.0418e-01,  1.0928e+00, -1.2255e-08,
-        -1.1987e-01,  2.1263e-01, -3.4670e-01, -2.9571e-01,  9.3391e-01,
-         2.5504e-12, -1.0023e-06,  3.6048e-10, -1.0035e-01,  1.1228e+00,
-         9.6292e-10, -3.5352e-01,  2.8800e-01, -3.9390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.8308e-01, -8.8828e-03, -1.1465e-01,  2.1035e-01,  2.6929e+00,
-         1.5552e-01,  0.0000e+00,  0.0000e+00, -3.7568e-01, -1.0373e-01,
-         2.8768e-01, -4.4622e-01, -1.7766e-01, -4.6545e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  8.4867e-02,  0.0000e+00,  0.0000e+00,
-        -7.8696e-01,  5.6132e-01,  0.0000e+00, -8.8006e-01, -4.8216e-04,
-         0.0000e+00,  1.0271e-01, -8.6071e-02, -5.2635e-03,  0.0000e+00,
-        -3.2834e-01, -1.4089e+00,  0.0000e+00, -5.2398e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0027e-01,  7.9244e-01,
-        -6.1583e-02,  0.0000e+00,  2.0611e-01,  0.0000e+00,  0.0000e+00,
-        -1.8831e-01,  0.0000e+00,  4.0418e-01,  1.0928e+00,  0.0000e+00,
-        -1.1987e-01,  2.1263e-01, -3.4670e-01, -2.9571e-01,  9.3391e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0035e-01,  1.1228e+00,
-         0.0000e+00, -3.5352e-01,  2.8800e-01, -3.9390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.8308e-01, -8.8828e-03, -1.1465e-01,  2.1035e-01,  2.6929e+00,
-         1.5552e-01,  0.0000e+00,  0.0000e+00, -3.7568e-01, -1.0373e-01,
-         2.8768e-01, -4.4622e-01, -1.7766e-01, -4.6545e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  8.4867e-02,  0.0000e+00,  0.0000e+00,
-        -7.8696e-01,  5.6132e-01,  0.0000e+00, -8.8006e-01, -4.8216e-04,
-         0.0000e+00,  1.0271e-01, -8.6071e-02, -5.2635e-03,  0.0000e+00,
-        -3.2834e-01, -1.4089e+00,  0.0000e+00, -5.2398e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0027e-01,  7.9244e-01,
-        -6.1583e-02,  0.0000e+00,  2.0611e-01,  0.0000e+00,  0.0000e+00,
-        -1.8831e-01,  0.0000e+00,  4.0418e-01,  1.0928e+00,  0.0000e+00,
-        -1.1987e-01,  2.1263e-01, -3.4670e-01, -2.9571e-01,  9.3391e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0035e-01,  1.1228e+00,
-         0.0000e+00, -3.5352e-01,  2.8800e-01, -3.9390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7054e-01, -9.1906e-03, -1.0666e-01,  1.5802e-01,  2.6860e+00,
-         1.7085e-01,  1.6912e-10, -1.3354e-03, -3.7799e-01, -1.6671e-01,
-         4.6542e-01, -4.1445e-01, -1.6099e-01, -4.4486e-01, -3.0329e-07,
-        -2.2541e-04, -2.2173e-05,  4.0656e-02, -5.1284e-08,  6.1661e-04,
-        -7.5333e-01,  4.7912e-01, -1.0891e-07, -9.1430e-01, -1.6345e-02,
-         8.8392e-04,  1.6487e-01, -2.5692e-02,  1.7473e-02,  1.5839e-07,
-        -4.6980e-01, -1.4026e+00, -4.0545e-11, -5.5277e-01,  2.0125e-05,
-         1.0859e-03,  6.7132e-07,  0.0000e+00,  2.9724e-01,  7.8694e-01,
-         5.2851e-02, -8.9665e-05,  1.6731e-01,  1.6522e-07,  4.0410e-05,
-        -1.5864e-01, -5.7018e-03,  4.3923e-01,  1.1110e+00, -1.0518e-08,
-        -8.3497e-02,  2.7116e-01, -3.4538e-01, -3.0322e-01,  9.1715e-01,
-         2.1890e-12, -8.6031e-07,  3.0940e-10, -1.9250e-01,  1.1194e+00,
-         8.2648e-10, -3.4191e-01,  2.7031e-01, -3.8746e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2705, -0.0092, -0.1067,  0.1580,  2.6860,  0.1709,  0.0000,  0.0000,
-        -0.3780, -0.1667,  0.4654, -0.4144, -0.1610, -0.4449,  0.0000,  0.0000,
-         0.0000,  0.0407,  0.0000,  0.0000, -0.7533,  0.4791,  0.0000, -0.9143,
-        -0.0163,  0.0000,  0.1649, -0.0257,  0.0175,  0.0000, -0.4698, -1.4026,
-         0.0000, -0.5528,  0.0000,  0.0000,  0.0000,  0.0000,  0.2972,  0.7869,
-         0.0529,  0.0000,  0.1673,  0.0000,  0.0000, -0.1586,  0.0000,  0.4392,
-         1.1110,  0.0000, -0.0835,  0.2712, -0.3454, -0.3032,  0.9172,  0.0000,
-         0.0000,  0.0000, -0.1925,  1.1194,  0.0000, -0.3419,  0.2703, -0.3875],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2705, -0.0092, -0.1067,  0.1580,  2.6860,  0.1709,  0.0000,  0.0000,
-        -0.3780, -0.1667,  0.4654, -0.4144, -0.1610, -0.4449,  0.0000,  0.0000,
-         0.0000,  0.0407,  0.0000,  0.0000, -0.7533,  0.4791,  0.0000, -0.9143,
-        -0.0163,  0.0000,  0.1649, -0.0257,  0.0175,  0.0000, -0.4698, -1.4026,
-         0.0000, -0.5528,  0.0000,  0.0000,  0.0000,  0.0000,  0.2972,  0.7869,
-         0.0529,  0.0000,  0.1673,  0.0000,  0.0000, -0.1586,  0.0000,  0.4392,
-         1.1110,  0.0000, -0.0835,  0.2712, -0.3454, -0.3032,  0.9172,  0.0000,
-         0.0000,  0.0000, -0.1925,  1.1194,  0.0000, -0.3419,  0.2703, -0.3875],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8211e-01, -1.8207e-02, -5.0423e-02,  1.8016e-01,  2.6772e+00,
-         2.7024e-01,  1.4520e-10, -1.1465e-03, -3.8419e-01, -1.7157e-01,
-         5.9033e-01, -4.1022e-01, -1.3199e-01, -4.5059e-01, -2.6038e-07,
-        -1.9352e-04, -1.9036e-05,  1.7195e-02, -4.4029e-08,  5.2938e-04,
-        -7.3625e-01,  4.1476e-01, -9.3502e-08, -9.3738e-01, -3.7001e-02,
-         7.5887e-04,  2.6493e-01,  4.7494e-02, -6.4919e-02,  1.3598e-07,
-        -5.5556e-01, -1.4015e+00, -3.4809e-11, -5.7963e-01,  1.7278e-05,
-         9.3224e-04,  5.7634e-07,  0.0000e+00,  2.3386e-01,  7.6617e-01,
-         9.7468e-02, -7.6980e-05,  1.4542e-01,  1.4184e-07,  3.4693e-05,
-        -6.4191e-02, -4.8952e-03,  4.4443e-01,  1.1285e+00, -9.0301e-09,
-        -1.1233e-01,  2.4889e-01, -3.8594e-01, -1.1542e-01,  8.9278e-01,
-         1.8793e-12, -7.3860e-07,  2.6563e-10, -2.1836e-01,  1.1199e+00,
-         7.0955e-10, -2.2768e-01,  1.9248e-01, -3.4609e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2821, -0.0182, -0.0504,  0.1802,  2.6772,  0.2702,  0.0000,  0.0000,
-        -0.3842, -0.1716,  0.5903, -0.4102, -0.1320, -0.4506,  0.0000,  0.0000,
-         0.0000,  0.0172,  0.0000,  0.0000, -0.7363,  0.4148,  0.0000, -0.9374,
-        -0.0370,  0.0000,  0.2649,  0.0475, -0.0649,  0.0000, -0.5556, -1.4015,
-         0.0000, -0.5796,  0.0000,  0.0000,  0.0000,  0.0000,  0.2339,  0.7662,
-         0.0975,  0.0000,  0.1454,  0.0000,  0.0000, -0.0642,  0.0000,  0.4444,
-         1.1285,  0.0000, -0.1123,  0.2489, -0.3859, -0.1154,  0.8928,  0.0000,
-         0.0000,  0.0000, -0.2184,  1.1199,  0.0000, -0.2277,  0.1925, -0.3461],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2821, -0.0182, -0.0504,  0.1802,  2.6772,  0.2702,  0.0000,  0.0000,
-        -0.3842, -0.1716,  0.5903, -0.4102, -0.1320, -0.4506,  0.0000,  0.0000,
-         0.0000,  0.0172,  0.0000,  0.0000, -0.7363,  0.4148,  0.0000, -0.9374,
-        -0.0370,  0.0000,  0.2649,  0.0475, -0.0649,  0.0000, -0.5556, -1.4015,
-         0.0000, -0.5796,  0.0000,  0.0000,  0.0000,  0.0000,  0.2339,  0.7662,
-         0.0975,  0.0000,  0.1454,  0.0000,  0.0000, -0.0642,  0.0000,  0.4444,
-         1.1285,  0.0000, -0.1123,  0.2489, -0.3859, -0.1154,  0.8928,  0.0000,
-         0.0000,  0.0000, -0.2184,  1.1199,  0.0000, -0.2277,  0.1925, -0.3461],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1033e-01, -1.5015e-02,  6.6583e-03,  2.7459e-01,  2.6686e+00,
-         3.8367e-01,  1.2469e-10, -9.8457e-04, -3.6747e-01, -1.6769e-01,
-         6.4531e-01, -3.8695e-01, -1.0737e-01, -4.2403e-01, -2.2361e-07,
-        -1.6619e-04, -1.6347e-05,  4.2570e-02, -3.7811e-08,  4.5461e-04,
-        -7.2399e-01,  4.0067e-01, -8.0296e-08, -9.1474e-01,  2.0353e-03,
-         6.5169e-04,  3.3559e-01,  1.8437e-01, -1.4613e-01,  1.1678e-07,
-        -6.0815e-01, -1.4009e+00, -2.9893e-11, -5.7324e-01,  1.4837e-05,
-         8.0057e-04,  4.9494e-07,  0.0000e+00,  1.7548e-01,  7.3800e-01,
-         1.4706e-01, -6.6108e-05,  1.3747e-01,  1.2181e-07,  2.9793e-05,
-         1.1668e-02, -4.2038e-03,  4.1253e-01,  1.1332e+00, -7.7547e-09,
-        -1.8378e-01,  1.9610e-01, -4.3716e-01,  1.1799e-01,  8.7546e-01,
-         1.6139e-12, -6.3428e-07,  2.2811e-10, -2.2326e-01,  1.1162e+00,
-         6.0934e-10, -1.3604e-01,  8.3063e-02, -3.2780e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.1033e-01, -1.5015e-02,  6.6583e-03,  2.7459e-01,  2.6686e+00,
-         3.8367e-01,  0.0000e+00,  0.0000e+00, -3.6747e-01, -1.6769e-01,
-         6.4531e-01, -3.8695e-01, -1.0737e-01, -4.2403e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  4.2570e-02,  0.0000e+00,  0.0000e+00,
-        -7.2399e-01,  4.0067e-01,  0.0000e+00, -9.1474e-01,  2.0353e-03,
-         0.0000e+00,  3.3559e-01,  1.8437e-01, -1.4613e-01,  0.0000e+00,
-        -6.0815e-01, -1.4009e+00,  0.0000e+00, -5.7324e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.7548e-01,  7.3800e-01,
-         1.4706e-01,  0.0000e+00,  1.3747e-01,  0.0000e+00,  0.0000e+00,
-         1.1668e-02,  0.0000e+00,  4.1253e-01,  1.1332e+00,  0.0000e+00,
-        -1.8378e-01,  1.9610e-01, -4.3716e-01,  1.1799e-01,  8.7546e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -2.2326e-01,  1.1162e+00,
-         0.0000e+00, -1.3604e-01,  8.3063e-02, -3.2780e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.1033e-01, -1.5015e-02,  6.6583e-03,  2.7459e-01,  2.6686e+00,
-         3.8367e-01,  0.0000e+00,  0.0000e+00, -3.6747e-01, -1.6769e-01,
-         6.4531e-01, -3.8695e-01, -1.0737e-01, -4.2403e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  4.2570e-02,  0.0000e+00,  0.0000e+00,
-        -7.2399e-01,  4.0067e-01,  0.0000e+00, -9.1474e-01,  2.0353e-03,
-         0.0000e+00,  3.3559e-01,  1.8437e-01, -1.4613e-01,  0.0000e+00,
-        -6.0815e-01, -1.4009e+00,  0.0000e+00, -5.7324e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.7548e-01,  7.3800e-01,
-         1.4706e-01,  0.0000e+00,  1.3747e-01,  0.0000e+00,  0.0000e+00,
-         1.1668e-02,  0.0000e+00,  4.1253e-01,  1.1332e+00,  0.0000e+00,
-        -1.8378e-01,  1.9610e-01, -4.3716e-01,  1.1799e-01,  8.7546e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -2.2326e-01,  1.1162e+00,
-         0.0000e+00, -1.3604e-01,  8.3063e-02, -3.2780e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1219e-01,  2.9300e-02, -2.0959e-02,  2.6107e-01,  2.6620e+00,
-         4.6770e-01,  1.0711e-10, -8.4575e-04, -3.4298e-01, -1.6129e-01,
-         6.5810e-01, -3.6769e-01, -1.3176e-01, -4.1456e-01, -1.9208e-07,
-        -1.4276e-04, -1.4042e-05, -4.3844e-02, -3.2479e-08,  3.9051e-04,
-        -7.1437e-01,  3.7055e-01, -6.8974e-08, -8.8042e-01,  7.7549e-02,
-         5.5980e-04,  3.7640e-01,  2.5000e-01, -1.9239e-01,  1.0031e-07,
-        -6.5497e-01, -1.3965e+00, -2.5678e-11, -5.4421e-01,  1.2745e-05,
-         6.8769e-04,  4.2515e-07,  0.0000e+00,  2.1125e-01,  7.0329e-01,
-         2.0211e-01, -5.6787e-05,  6.2976e-02,  1.0463e-07,  2.5592e-05,
-        -2.6796e-02, -3.6110e-03,  4.1877e-01,  1.1257e+00, -6.6612e-09,
-        -1.7620e-01,  1.8543e-01, -4.1950e-01,  5.5684e-02,  8.5993e-01,
-         1.3863e-12, -5.4485e-07,  1.9595e-10, -1.7314e-01,  1.1132e+00,
-         5.2342e-10, -1.5250e-01,  5.6999e-02, -3.3169e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3122,  0.0293, -0.0210,  0.2611,  2.6620,  0.4677,  0.0000,  0.0000,
-        -0.3430, -0.1613,  0.6581,  0.0000, -0.1318, -0.4146,  0.0000,  0.0000,
-         0.0000, -0.0438,  0.0000,  0.0000, -0.7144,  0.3706,  0.0000, -0.8804,
-         0.0775,  0.0000,  0.3764,  0.2500, -0.1924,  0.0000, -0.6550, -1.3965,
-         0.0000, -0.5442,  0.0000,  0.0000,  0.0000,  0.0000,  0.2113,  0.7033,
-         0.2021,  0.0000,  0.0630,  0.0000,  0.0000, -0.0268,  0.0000,  0.4188,
-         1.1257,  0.0000, -0.1762,  0.1854, -0.4195,  0.0557,  0.8599,  0.0000,
-         0.0000,  0.0000, -0.1731,  1.1132,  0.0000, -0.1525,  0.0570, -0.3317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3122,  0.0293, -0.0210,  0.2611,  2.6620,  0.4677,  0.0000,  0.0000,
-        -0.3430, -0.1613,  0.6581,  0.0000, -0.1318, -0.4146,  0.0000,  0.0000,
-         0.0000, -0.0438,  0.0000,  0.0000, -0.7144,  0.3706,  0.0000, -0.8804,
-         0.0775,  0.0000,  0.3764,  0.2500, -0.1924,  0.0000, -0.6550, -1.3965,
-         0.0000, -0.5442,  0.0000,  0.0000,  0.0000,  0.0000,  0.2113,  0.7033,
-         0.2021,  0.0000,  0.0630,  0.0000,  0.0000, -0.0268,  0.0000,  0.4188,
-         1.1257,  0.0000, -0.1762,  0.1854, -0.4195,  0.0557,  0.8599,  0.0000,
-         0.0000,  0.0000, -0.1731,  1.1132,  0.0000, -0.1525,  0.0570, -0.3317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7915e-01,  9.1450e-02, -1.2974e-01,  8.5408e-02,  2.6555e+00,
-         4.3824e-01,  9.2032e-11, -7.2670e-04, -3.2821e-01, -1.3109e-01,
-         5.9990e-01,  1.6545e-02, -1.9055e-01, -4.1613e-01, -1.6504e-07,
-        -1.2266e-04, -1.2066e-05, -1.3184e-01, -2.7907e-08,  3.3554e-04,
-        -7.0439e-01,  3.3906e-01, -5.9265e-08, -8.2825e-01,  1.8413e-01,
-         4.8100e-04,  3.0777e-01,  2.2547e-01, -1.3378e-01,  8.6192e-08,
-        -6.7021e-01, -1.3889e+00, -2.2064e-11, -4.9412e-01,  1.0951e-05,
-         5.9089e-04,  3.6531e-07,  0.0000e+00,  3.2341e-01,  7.2898e-01,
-         2.6847e-01, -4.8793e-05, -8.9372e-02,  8.9906e-08,  2.1990e-05,
-        -2.1449e-01, -3.1027e-03,  4.1155e-01,  1.1156e+00, -5.7236e-09,
-        -1.1582e-01,  2.2942e-01, -4.0623e-01, -3.1000e-01,  8.7647e-01,
-         1.1912e-12, -4.6815e-07,  1.6836e-10, -1.2026e-01,  1.1004e+00,
-         4.4974e-10, -2.5656e-01,  1.2526e-01, -3.8036e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2791,  0.0915, -0.1297,  0.0854,  2.6555,  0.4382,  0.0000,  0.0000,
-        -0.3282, -0.1311,  0.5999,  0.0000, -0.1905, -0.4161,  0.0000,  0.0000,
-         0.0000, -0.1318,  0.0000,  0.0000, -0.7044,  0.3391,  0.0000, -0.8283,
-         0.1841,  0.0000,  0.3078,  0.2255, -0.1338,  0.0000, -0.6702, -1.3889,
-         0.0000, -0.4941,  0.0000,  0.0000,  0.0000,  0.0000,  0.3234,  0.7290,
-         0.2685,  0.0000, -0.0894,  0.0000,  0.0000, -0.2145,  0.0000,  0.4115,
-         1.1156,  0.0000, -0.1158,  0.2294, -0.4062, -0.3100,  0.8765,  0.0000,
-         0.0000,  0.0000, -0.1203,  1.1004,  0.0000, -0.2566,  0.1253, -0.3804],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2791,  0.0915, -0.1297,  0.0854,  2.6555,  0.4382,  0.0000,  0.0000,
-        -0.3282, -0.1311,  0.5999,  0.0000, -0.1905, -0.4161,  0.0000,  0.0000,
-         0.0000, -0.1318,  0.0000,  0.0000, -0.7044,  0.3391,  0.0000, -0.8283,
-         0.1841,  0.0000,  0.3078,  0.2255, -0.1338,  0.0000, -0.6702, -1.3889,
-         0.0000, -0.4941,  0.0000,  0.0000,  0.0000,  0.0000,  0.3234,  0.7290,
-         0.2685,  0.0000, -0.0894,  0.0000,  0.0000, -0.2145,  0.0000,  0.4115,
-         1.1156,  0.0000, -0.1158,  0.2294, -0.4062, -0.3100,  0.8765,  0.0000,
-         0.0000,  0.0000, -0.1203,  1.1004,  0.0000, -0.2566,  0.1253, -0.3804],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9986e-01,  1.7076e-01, -1.4639e-01, -1.7220e-02,  2.6504e+00,
-         4.0107e-01,  7.9100e-11, -6.2458e-04, -2.9674e-01, -2.0105e-02,
-         5.2917e-01,  1.4220e-02, -2.1059e-01, -4.1158e-01, -1.4185e-07,
-        -1.0543e-04, -1.0370e-05, -2.3769e-01, -2.3986e-08,  2.8839e-04,
-        -7.1288e-01,  2.9436e-01, -5.0937e-08, -8.0589e-01,  2.7268e-01,
-         4.1341e-04,  2.7770e-01,  1.8201e-01, -9.5636e-02,  7.4081e-08,
-        -6.5317e-01, -1.3796e+00, -1.8963e-11, -4.5579e-01,  9.4124e-06,
-         5.0786e-04,  3.1398e-07,  0.0000e+00,  5.1522e-01,  7.7031e-01,
-         3.0255e-01, -4.1937e-05, -2.0089e-01,  7.7272e-08,  1.8900e-05,
-        -3.7161e-01, -2.6667e-03,  4.5057e-01,  1.1019e+00, -4.9193e-09,
-        -8.0208e-02,  2.0031e-01, -4.0204e-01, -4.6599e-01,  8.9759e-01,
-         1.0238e-12, -4.0237e-07,  1.4471e-10, -9.4082e-02,  1.0884e+00,
-         3.8654e-10, -2.5072e-01,  1.3785e-01, -3.8612e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2999,  0.1708, -0.1464, -0.0172,  2.6504,  0.4011,  0.0000,  0.0000,
-        -0.2967, -0.0201,  0.5292,  0.0000, -0.2106, -0.4116,  0.0000,  0.0000,
-         0.0000, -0.2377,  0.0000,  0.0000, -0.7129,  0.2944,  0.0000, -0.8059,
-         0.2727,  0.0000,  0.2777,  0.1820, -0.0956,  0.0000, -0.6532, -1.3796,
-         0.0000, -0.4558,  0.0000,  0.0000,  0.0000,  0.0000,  0.5152,  0.7703,
-         0.3026,  0.0000, -0.2009,  0.0000,  0.0000, -0.3716,  0.0000,  0.4506,
-         1.1019,  0.0000, -0.0802,  0.2003, -0.4020, -0.4660,  0.8976,  0.0000,
-         0.0000,  0.0000, -0.0941,  1.0884,  0.0000, -0.2507,  0.1379, -0.3861],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2999,  0.1708, -0.1464, -0.0172,  2.6504,  0.4011,  0.0000,  0.0000,
-        -0.2967, -0.0201,  0.5292,  0.0000, -0.2106, -0.4116,  0.0000,  0.0000,
-         0.0000, -0.2377,  0.0000,  0.0000, -0.7129,  0.2944,  0.0000, -0.8059,
-         0.2727,  0.0000,  0.2777,  0.1820, -0.0956,  0.0000, -0.6532, -1.3796,
-         0.0000, -0.4558,  0.0000,  0.0000,  0.0000,  0.0000,  0.5152,  0.7703,
-         0.3026,  0.0000, -0.2009,  0.0000,  0.0000, -0.3716,  0.0000,  0.4506,
-         1.1019,  0.0000, -0.0802,  0.2003, -0.4020, -0.4660,  0.8976,  0.0000,
-         0.0000,  0.0000, -0.0941,  1.0884,  0.0000, -0.2507,  0.1379, -0.3861],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1589e-01,  2.4695e-01, -1.3026e-01, -4.1884e-02,  2.6430e+00,
-         3.5168e-01,  6.8004e-11, -5.3697e-04, -2.5217e-01,  1.0416e-01,
-         4.6424e-01,  1.2225e-02, -1.9891e-01, -4.3790e-01, -1.2195e-07,
-        -9.0638e-05, -8.9156e-06, -2.6892e-01, -2.0621e-08,  2.4794e-04,
-        -7.2984e-01,  2.6256e-01, -4.3792e-08, -8.0923e-01,  3.5193e-01,
-         3.5542e-04,  2.8743e-01,  1.5142e-01, -6.4233e-02,  6.3689e-08,
-        -6.3132e-01, -1.3727e+00, -1.6303e-11, -4.4096e-01,  8.0921e-06,
-         4.3662e-04,  2.6993e-07,  0.0000e+00,  5.9528e-01,  7.5479e-01,
-         3.3337e-01, -3.6054e-05, -2.7858e-01,  6.6433e-08,  1.6249e-05,
-        -4.3562e-01, -2.2927e-03,  4.5986e-01,  1.0795e+00, -4.2293e-09,
-        -4.6698e-02,  1.5005e-01, -4.2041e-01, -4.3854e-01,  8.9134e-01,
-         8.8018e-13, -3.4593e-07,  1.2441e-10,  4.6207e-02,  1.0784e+00,
-         3.3232e-10, -1.7998e-01,  7.0286e-02, -3.5704e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3159,  0.2469, -0.1303, -0.0419,  2.6430,  0.3517,  0.0000,  0.0000,
-        -0.2522,  0.1042,  0.4642,  0.0000, -0.1989, -0.4379,  0.0000,  0.0000,
-         0.0000, -0.2689,  0.0000,  0.0000, -0.7298,  0.2626,  0.0000, -0.8092,
-         0.3519,  0.0000,  0.2874,  0.1514, -0.0642,  0.0000, -0.6313, -1.3727,
-         0.0000, -0.4410,  0.0000,  0.0000,  0.0000,  0.0000,  0.5953,  0.7548,
-         0.3334,  0.0000, -0.2786,  0.0000,  0.0000, -0.4356,  0.0000,  0.4599,
-         1.0795,  0.0000, -0.0467,  0.1500, -0.4204, -0.4385,  0.8913,  0.0000,
-         0.0000,  0.0000,  0.0462,  1.0784,  0.0000, -0.1800,  0.0703, -0.3570],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3159,  0.2469, -0.1303, -0.0419,  2.6430,  0.3517,  0.0000,  0.0000,
-        -0.2522,  0.1042,  0.4642,  0.0000, -0.1989, -0.4379,  0.0000,  0.0000,
-         0.0000, -0.2689,  0.0000,  0.0000, -0.7298,  0.2626,  0.0000, -0.8092,
-         0.3519,  0.0000,  0.2874,  0.1514, -0.0642,  0.0000, -0.6313, -1.3727,
-         0.0000, -0.4410,  0.0000,  0.0000,  0.0000,  0.0000,  0.5953,  0.7548,
-         0.3334,  0.0000, -0.2786,  0.0000,  0.0000, -0.4356,  0.0000,  0.4599,
-         1.0795,  0.0000, -0.0467,  0.1500, -0.4204, -0.4385,  0.8913,  0.0000,
-         0.0000,  0.0000,  0.0462,  1.0784,  0.0000, -0.1800,  0.0703, -0.3570],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3446e-01,  3.0970e-01, -7.9786e-02, -3.2658e-02,  2.6354e+00,
-         3.0823e-01,  5.8482e-11, -4.6178e-04, -2.1346e-01,  1.9379e-01,
-         3.8426e-01,  1.0513e-02, -1.6644e-01, -4.6282e-01, -1.0487e-07,
-        -7.7946e-05, -7.6672e-06, -2.5127e-01, -1.7734e-08,  2.1322e-04,
-        -7.3610e-01,  2.1041e-01, -3.7660e-08, -8.2115e-01,  3.9415e-01,
-         3.0565e-04,  3.0349e-01,  1.0883e-01, -4.6496e-02,  5.4771e-08,
-        -6.1660e-01, -1.3634e+00, -1.4020e-11, -4.3904e-01,  6.9590e-06,
-         3.7548e-04,  2.3213e-07,  0.0000e+00,  6.6014e-01,  7.4616e-01,
-         3.5960e-01, -3.1005e-05, -3.1356e-01,  5.7130e-08,  1.3974e-05,
-        -4.0975e-01, -1.9716e-03,  4.5151e-01,  1.0613e+00, -3.6370e-09,
-        -3.5793e-03,  7.0003e-02, -4.1092e-01, -4.1323e-01,  8.8169e-01,
-         7.5693e-13, -2.9749e-07,  1.0699e-10,  2.1963e-01,  1.0669e+00,
-         2.8579e-10, -7.5800e-02,  3.7141e-02, -2.9723e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3345,  0.3097, -0.0798, -0.0327,  2.6354,  0.3082,  0.0000,  0.0000,
-        -0.2135,  0.1938,  0.3843,  0.0000, -0.1664, -0.4628,  0.0000,  0.0000,
-         0.0000, -0.2513,  0.0000,  0.0000, -0.7361,  0.2104,  0.0000, -0.8212,
-         0.3941,  0.0000,  0.3035,  0.1088, -0.0465,  0.0000, -0.6166, -1.3634,
-         0.0000, -0.4390,  0.0000,  0.0000,  0.0000,  0.0000,  0.6601,  0.7462,
-         0.3596,  0.0000, -0.3136,  0.0000,  0.0000, -0.4098,  0.0000,  0.4515,
-         1.0613,  0.0000, -0.0036,  0.0700, -0.4109, -0.4132,  0.8817,  0.0000,
-         0.0000,  0.0000,  0.2196,  1.0669,  0.0000, -0.0758,  0.0371, -0.2972],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3345,  0.3097, -0.0798, -0.0327,  2.6354,  0.3082,  0.0000,  0.0000,
-        -0.2135,  0.1938,  0.3843,  0.0000, -0.1664, -0.4628,  0.0000,  0.0000,
-         0.0000, -0.2513,  0.0000,  0.0000, -0.7361,  0.2104,  0.0000, -0.8212,
-         0.3941,  0.0000,  0.3035,  0.1088, -0.0465,  0.0000, -0.6166, -1.3634,
-         0.0000, -0.4390,  0.0000,  0.0000,  0.0000,  0.0000,  0.6601,  0.7462,
-         0.3596,  0.0000, -0.3136,  0.0000,  0.0000, -0.4098,  0.0000,  0.4515,
-         1.0613,  0.0000, -0.0036,  0.0700, -0.4109, -0.4132,  0.8817,  0.0000,
-         0.0000,  0.0000,  0.2196,  1.0669,  0.0000, -0.0758,  0.0371, -0.2972],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4276e-01,  3.1536e-01, -3.5144e-03,  6.4016e-03,  2.6285e+00,
-         2.7322e-01,  5.0307e-11, -3.9723e-04, -1.4724e-01,  1.9810e-01,
-         2.9762e-01,  9.0438e-03, -7.5566e-02, -4.9886e-01, -9.0215e-08,
-        -6.7051e-05, -6.5955e-06, -1.7927e-01, -1.5255e-08,  1.8341e-04,
-        -7.3683e-01,  1.3379e-01, -3.2396e-08, -8.3544e-01,  4.0887e-01,
-         2.6293e-04,  3.0234e-01,  7.3373e-02, -5.8009e-02,  4.7115e-08,
-        -6.0160e-01, -1.3567e+00, -1.2060e-11, -4.5221e-01,  5.9863e-06,
-         3.2300e-04,  1.9969e-07,  0.0000e+00,  6.9583e-01,  7.1754e-01,
-         3.5014e-01, -2.6672e-05, -3.1348e-01,  4.9145e-08,  1.2020e-05,
-        -3.0537e-01, -1.6960e-03,  4.0880e-01,  1.0504e+00, -3.1287e-09,
-        -2.2029e-03, -1.0204e-02, -3.6491e-01, -3.9096e-01,  8.6109e-01,
-         6.5113e-13, -2.5590e-07,  9.2033e-11,  3.9798e-01,  1.0620e+00,
-         2.4584e-10,  8.0830e-02, -9.2811e-03, -2.2349e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.4276e-01,  3.1536e-01, -3.5144e-03,  6.4016e-03,  2.6285e+00,
-         2.7322e-01,  0.0000e+00,  0.0000e+00, -1.4724e-01,  1.9810e-01,
-         2.9762e-01,  0.0000e+00, -7.5566e-02, -4.9886e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.7927e-01,  0.0000e+00,  0.0000e+00,
-        -7.3683e-01,  1.3379e-01,  0.0000e+00, -8.3544e-01,  4.0887e-01,
-         0.0000e+00,  3.0234e-01,  7.3373e-02, -5.8009e-02,  0.0000e+00,
-        -6.0160e-01, -1.3567e+00,  0.0000e+00, -4.5221e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.9583e-01,  7.1754e-01,
-         3.5014e-01,  0.0000e+00, -3.1348e-01,  0.0000e+00,  0.0000e+00,
-        -3.0537e-01,  0.0000e+00,  4.0880e-01,  1.0504e+00,  0.0000e+00,
-        -2.2029e-03, -1.0204e-02, -3.6491e-01, -3.9096e-01,  8.6109e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.9798e-01,  1.0620e+00,
-         0.0000e+00,  8.0830e-02, -9.2811e-03, -2.2349e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.4276e-01,  3.1536e-01, -3.5144e-03,  6.4016e-03,  2.6285e+00,
-         2.7322e-01,  0.0000e+00,  0.0000e+00, -1.4724e-01,  1.9810e-01,
-         2.9762e-01,  0.0000e+00, -7.5566e-02, -4.9886e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.7927e-01,  0.0000e+00,  0.0000e+00,
-        -7.3683e-01,  1.3379e-01,  0.0000e+00, -8.3544e-01,  4.0887e-01,
-         0.0000e+00,  3.0234e-01,  7.3373e-02, -5.8009e-02,  0.0000e+00,
-        -6.0160e-01, -1.3567e+00,  0.0000e+00, -4.5221e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.9583e-01,  7.1754e-01,
-         3.5014e-01,  0.0000e+00, -3.1348e-01,  0.0000e+00,  0.0000e+00,
-        -3.0537e-01,  0.0000e+00,  4.0880e-01,  1.0504e+00,  0.0000e+00,
-        -2.2029e-03, -1.0204e-02, -3.6491e-01, -3.9096e-01,  8.6109e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.9798e-01,  1.0620e+00,
-         0.0000e+00,  8.0830e-02, -9.2811e-03, -2.2349e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2637e-01,  3.0079e-01,  6.8790e-02,  7.7466e-02,  2.6213e+00,
-         1.9715e-01,  4.3288e-11, -3.4181e-04, -7.1124e-02,  2.1444e-01,
-         2.6414e-01,  7.7820e-03, -1.2519e-02, -5.3058e-01, -7.7628e-08,
-        -5.7696e-05, -5.6752e-06, -9.7288e-02, -1.3126e-08,  1.5782e-04,
-        -7.2624e-01,  8.9622e-02, -2.7876e-08, -8.3681e-01,  4.3087e-01,
-         2.2624e-04,  3.1630e-01,  2.8983e-02, -3.3566e-02,  4.0541e-08,
-        -5.8602e-01, -1.3515e+00, -1.0378e-11, -4.1797e-01,  5.1510e-06,
-         2.7793e-04,  1.7183e-07,  0.0000e+00,  7.0137e-01,  6.9736e-01,
-         3.4796e-01, -2.2950e-05, -2.8276e-01,  4.2288e-08,  1.0343e-05,
-        -2.2141e-01, -1.4594e-03,  3.5084e-01,  1.0358e+00, -2.6921e-09,
-         1.8981e-03, -2.4629e-02, -3.0311e-01, -3.9143e-01,  8.4325e-01,
-         5.6028e-13, -2.2020e-07,  7.9192e-11,  5.5549e-01,  1.0506e+00,
-         2.1154e-10,  2.1586e-01, -3.3342e-02, -1.3558e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.2637e-01,  3.0079e-01,  6.8790e-02,  7.7466e-02,  2.6213e+00,
-         1.9715e-01,  0.0000e+00,  0.0000e+00, -7.1124e-02,  2.1444e-01,
-         2.6414e-01,  0.0000e+00, -1.2519e-02, -5.3058e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -9.7288e-02,  0.0000e+00,  0.0000e+00,
-        -7.2624e-01,  8.9622e-02,  0.0000e+00, -8.3681e-01,  4.3087e-01,
-         0.0000e+00,  3.1630e-01,  2.8983e-02, -3.3566e-02,  0.0000e+00,
-        -5.8602e-01, -1.3515e+00,  0.0000e+00, -4.1797e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.0137e-01,  6.9736e-01,
-         3.4796e-01,  0.0000e+00, -2.8276e-01,  0.0000e+00,  0.0000e+00,
-        -2.2141e-01,  0.0000e+00,  3.5084e-01,  1.0358e+00,  0.0000e+00,
-         1.8981e-03, -2.4629e-02, -3.0311e-01, -3.9143e-01,  8.4325e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.5549e-01,  0.0000e+00,
-         0.0000e+00,  2.1586e-01, -3.3342e-02, -1.3558e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.2637e-01,  3.0079e-01,  6.8790e-02,  7.7466e-02,  2.6213e+00,
-         1.9715e-01,  0.0000e+00,  0.0000e+00, -7.1124e-02,  2.1444e-01,
-         2.6414e-01,  0.0000e+00, -1.2519e-02, -5.3058e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -9.7288e-02,  0.0000e+00,  0.0000e+00,
-        -7.2624e-01,  8.9622e-02,  0.0000e+00, -8.3681e-01,  4.3087e-01,
-         0.0000e+00,  3.1630e-01,  2.8983e-02, -3.3566e-02,  0.0000e+00,
-        -5.8602e-01, -1.3515e+00,  0.0000e+00, -4.1797e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.0137e-01,  6.9736e-01,
-         3.4796e-01,  0.0000e+00, -2.8276e-01,  0.0000e+00,  0.0000e+00,
-        -2.2141e-01,  0.0000e+00,  3.5084e-01,  1.0358e+00,  0.0000e+00,
-         1.8981e-03, -2.4629e-02, -3.0311e-01, -3.9143e-01,  8.4325e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.5549e-01,  0.0000e+00,
-         0.0000e+00,  2.1586e-01, -3.3342e-02, -1.3558e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7827e-01,  2.7671e-01,  7.0787e-02,  1.3648e-01,  2.6155e+00,
-         1.0596e-01,  3.7259e-11, -2.9420e-04,  1.4822e-02,  1.9266e-01,
-         2.4061e-01,  6.6982e-03,  8.6958e-03, -5.6034e-01, -6.6816e-08,
-        -4.9660e-05, -4.8848e-06, -2.3332e-02, -1.1298e-08,  1.3584e-04,
-        -7.0285e-01,  8.9440e-02, -2.3993e-08, -8.2644e-01,  4.6786e-01,
-         1.9473e-04,  2.8796e-01,  2.7942e-03,  3.0316e-03,  3.4895e-08,
-        -5.6220e-01, -1.3471e+00, -8.9324e-12, -4.0848e-01,  4.4336e-06,
-         2.3922e-04,  1.4790e-07,  0.0000e+00,  6.7348e-01,  6.6133e-01,
-         3.5311e-01, -1.9754e-05, -2.7728e-01,  3.6398e-08,  8.9027e-06,
-        -1.6202e-01, -1.2561e-03,  2.9813e-01,  1.0379e+00, -2.3172e-09,
-         1.4623e-02, -3.3228e-02, -2.5407e-01, -3.8555e-01,  8.3178e-01,
-         4.8225e-13, -1.8953e-07,  6.8162e-11,  6.6804e-01, -9.8441e-03,
-         1.8208e-10,  2.7831e-01, -1.8815e-02, -6.2175e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2783,  0.2767,  0.0708,  0.1365,  2.6155,  0.1060,  0.0000,  0.0000,
-         0.0148,  0.1927,  0.2406,  0.0000,  0.0087, -0.5603,  0.0000,  0.0000,
-         0.0000, -0.0233,  0.0000,  0.0000, -0.7028,  0.0894,  0.0000, -0.8264,
-         0.4679,  0.0000,  0.2880,  0.0028,  0.0030,  0.0000, -0.5622, -1.3471,
-         0.0000, -0.4085,  0.0000,  0.0000,  0.0000,  0.0000,  0.6735,  0.6613,
-         0.3531,  0.0000, -0.2773,  0.0000,  0.0000, -0.1620,  0.0000,  0.2981,
-         1.0379,  0.0000,  0.0146, -0.0332, -0.2541, -0.3856,  0.8318,  0.0000,
-         0.0000,  0.0000,  0.6680,  0.0000,  0.0000,  0.2783, -0.0188, -0.0622],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2783,  0.2767,  0.0708,  0.1365,  2.6155,  0.1060,  0.0000,  0.0000,
-         0.0148,  0.1927,  0.2406,  0.0000,  0.0087, -0.5603,  0.0000,  0.0000,
-         0.0000, -0.0233,  0.0000,  0.0000, -0.7028,  0.0894,  0.0000, -0.8264,
-         0.4679,  0.0000,  0.2880,  0.0028,  0.0030,  0.0000, -0.5622, -1.3471,
-         0.0000, -0.4085,  0.0000,  0.0000,  0.0000,  0.0000,  0.6735,  0.6613,
-         0.3531,  0.0000, -0.2773,  0.0000,  0.0000, -0.1620,  0.0000,  0.2981,
-         1.0379,  0.0000,  0.0146, -0.0332, -0.2541, -0.3856,  0.8318,  0.0000,
-         0.0000,  0.0000,  0.6680,  0.0000,  0.0000,  0.2783, -0.0188, -0.0622],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.7936e-01,  3.0453e-01,  1.5097e-02,  1.7338e-01,  2.6121e+00,
-         1.3216e-02,  3.2080e-11, -2.5330e-04,  1.8587e-02,  1.7673e-01,
-         2.1569e-01,  5.7670e-03, -6.0712e-02, -5.8702e-01, -5.7528e-08,
-        -4.2757e-05, -4.2058e-06, -4.6727e-02, -9.7277e-09,  1.1696e-04,
-        -6.8272e-01,  1.1866e-01, -2.0658e-08, -8.0133e-01,  5.6137e-01,
-         1.6766e-04,  2.3355e-01, -1.3338e-03,  9.5496e-02,  3.0044e-08,
-        -5.4001e-01, -1.3467e+00, -7.6907e-12, -3.7930e-01,  3.8173e-06,
-         2.0597e-04,  1.2734e-07,  0.0000e+00,  6.0747e-01,  5.9366e-01,
-         4.1607e-01, -1.7008e-05, -2.9327e-01,  3.1338e-08,  7.6651e-06,
-        -1.0513e-01, -1.0815e-03,  2.2642e-01,  1.0519e+00, -1.9951e-09,
-         8.3533e-02, -4.9148e-03, -2.6531e-01, -3.1990e-01,  8.0637e-01,
-         4.1521e-13, -1.6318e-07,  5.8687e-11,  7.4716e-01, -8.4757e-03,
-         1.5677e-10,  2.4216e-01,  4.6870e-04, -3.9941e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.7936e-01,  3.0453e-01,  1.5097e-02,  1.7338e-01,  2.6121e+00,
-         1.3216e-02,  0.0000e+00,  0.0000e+00,  1.8587e-02,  1.7673e-01,
-         2.1569e-01,  0.0000e+00, -6.0712e-02, -5.8702e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.6727e-02,  0.0000e+00,  0.0000e+00,
-        -6.8272e-01,  1.1866e-01,  0.0000e+00, -8.0133e-01,  5.6137e-01,
-         0.0000e+00,  2.3355e-01, -1.3338e-03,  9.5496e-02,  0.0000e+00,
-        -5.4001e-01, -1.3467e+00,  0.0000e+00, -3.7930e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0747e-01,  5.9366e-01,
-         4.1607e-01,  0.0000e+00, -2.9327e-01,  0.0000e+00,  0.0000e+00,
-        -1.0513e-01,  0.0000e+00,  2.2642e-01,  1.0519e+00,  0.0000e+00,
-         8.3533e-02, -4.9148e-03, -2.6531e-01, -3.1990e-01,  8.0637e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.4716e-01,  0.0000e+00,
-         0.0000e+00,  2.4216e-01,  4.6870e-04, -3.9941e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.7936e-01,  3.0453e-01,  1.5097e-02,  1.7338e-01,  2.6121e+00,
-         1.3216e-02,  0.0000e+00,  0.0000e+00,  1.8587e-02,  1.7673e-01,
-         2.1569e-01,  0.0000e+00, -6.0712e-02, -5.8702e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.6727e-02,  0.0000e+00,  0.0000e+00,
-        -6.8272e-01,  1.1866e-01,  0.0000e+00, -8.0133e-01,  5.6137e-01,
-         0.0000e+00,  2.3355e-01, -1.3338e-03,  9.5496e-02,  0.0000e+00,
-        -5.4001e-01, -1.3467e+00,  0.0000e+00, -3.7930e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0747e-01,  5.9366e-01,
-         4.1607e-01,  0.0000e+00, -2.9327e-01,  0.0000e+00,  0.0000e+00,
-        -1.0513e-01,  0.0000e+00,  2.2642e-01,  1.0519e+00,  0.0000e+00,
-         8.3533e-02, -4.9148e-03, -2.6531e-01, -3.1990e-01,  8.0637e-01,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.4716e-01,  0.0000e+00,
-         0.0000e+00,  2.4216e-01,  4.6870e-04, -3.9941e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 7.0065e-02,  3.1371e-01, -1.9506e-02,  2.0379e-01,  2.6102e+00,
-        -1.9865e-02,  2.7628e-11, -2.1816e-04,  5.2595e-02,  1.5022e-01,
-         1.5863e-01,  4.9668e-03, -1.3202e-01, -5.9518e-01, -4.9546e-08,
-        -3.6824e-05, -3.6222e-06, -1.2782e-01, -8.3779e-09,  1.0073e-04,
-        -6.4142e-01,  8.8004e-02, -1.7792e-08, -7.9065e-01,  6.0451e-01,
-         1.4440e-04,  1.7498e-01, -1.3778e-02,  2.0686e-01,  2.5875e-08,
-        -5.4047e-01, -1.3431e+00, -6.6236e-12, -3.4356e-01,  3.2876e-06,
-         1.7739e-04,  1.0967e-07,  0.0000e+00,  5.7402e-01,  5.6533e-01,
-         4.3342e-01, -1.4648e-05, -3.1958e-01,  2.6990e-08,  6.6015e-06,
-        -8.7567e-02, -9.3146e-04,  1.6705e-01,  1.0717e+00, -1.7182e-09,
-         9.8817e-02,  5.5946e-02, -2.5770e-01, -3.1858e-01,  7.8982e-01,
-         3.5760e-13, -1.4054e-07,  5.0544e-11,  7.7452e-01, -7.2996e-03,
-         1.3501e-10,  2.0570e-01, -3.0742e-03, -2.8856e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0701,  0.3137, -0.0195,  0.2038,  2.6102, -0.0199,  0.0000,  0.0000,
-         0.0526,  0.1502,  0.1586,  0.0000, -0.1320, -0.5952,  0.0000,  0.0000,
-         0.0000, -0.1278,  0.0000,  0.0000, -0.6414,  0.0880,  0.0000, -0.7907,
-         0.6045,  0.0000,  0.1750, -0.0138,  0.2069,  0.0000, -0.5405, -1.3431,
-         0.0000, -0.3436,  0.0000,  0.0000,  0.0000,  0.0000,  0.5740,  0.5653,
-         0.4334,  0.0000, -0.3196,  0.0000,  0.0000, -0.0876,  0.0000,  0.1670,
-         1.0717,  0.0000,  0.0988,  0.0559, -0.2577, -0.3186,  0.7898,  0.0000,
-         0.0000,  0.0000,  0.7745,  0.0000,  0.0000,  0.2057, -0.0031, -0.0289],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0701,  0.3137, -0.0195,  0.2038,  2.6102, -0.0199,  0.0000,  0.0000,
-         0.0526,  0.1502,  0.1586,  0.0000, -0.1320, -0.5952,  0.0000,  0.0000,
-         0.0000, -0.1278,  0.0000,  0.0000, -0.6414,  0.0880,  0.0000, -0.7907,
-         0.6045,  0.0000,  0.1750, -0.0138,  0.2069,  0.0000, -0.5405, -1.3431,
-         0.0000, -0.3436,  0.0000,  0.0000,  0.0000,  0.0000,  0.5740,  0.5653,
-         0.4334,  0.0000, -0.3196,  0.0000,  0.0000, -0.0876,  0.0000,  0.1670,
-         1.0717,  0.0000,  0.0988,  0.0559, -0.2577, -0.3186,  0.7898,  0.0000,
-         0.0000,  0.0000,  0.7745,  0.0000,  0.0000,  0.2057, -0.0031, -0.0289],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.3838e-02,  2.8430e-01,  4.1292e-02,  2.1230e-01,  2.6069e+00,
-         9.0751e-02,  2.3802e-11, -1.8794e-04,  1.3358e-01,  6.4146e-02,
-        -2.6545e-02,  4.2790e-03, -1.1373e-01, -5.9493e-01, -4.2684e-08,
-        -3.1724e-05, -3.1206e-06, -1.0975e-01, -7.2176e-09,  8.6780e-05,
-        -6.1596e-01, -4.0593e-02, -1.5328e-08, -8.0359e-01,  5.1597e-01,
-         1.2440e-04,  1.2774e-01, -6.3832e-03,  3.6516e-01,  2.2292e-08,
-        -5.1830e-01, -1.3407e+00, -5.7063e-12, -3.0675e-01,  2.8323e-06,
-         1.5282e-04,  9.4479e-08,  0.0000e+00,  6.0493e-01,  5.7349e-01,
-         3.4717e-01, -1.2619e-05, -3.0151e-01,  2.3252e-08,  5.6872e-06,
-        -1.8385e-01, -8.0246e-04,  1.9560e-01,  1.0823e+00, -1.4803e-09,
-         1.8315e-02, -4.5105e-02, -3.5821e-01, -1.9945e-01,  7.8836e-01,
-         3.0807e-13, -1.2108e-07,  4.3544e-11,  7.3948e-01, -6.2887e-03,
-         1.1632e-10,  2.3044e-01, -6.8403e-02, -2.8809e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0938,  0.2843,  0.0413,  0.2123,  2.6069,  0.0908,  0.0000,  0.0000,
-         0.1336,  0.0641, -0.0265,  0.0000, -0.1137, -0.5949,  0.0000,  0.0000,
-         0.0000, -0.1098,  0.0000,  0.0000, -0.6160, -0.0406,  0.0000, -0.8036,
-         0.5160,  0.0000,  0.1277, -0.0064,  0.3652,  0.0000, -0.5183, -1.3407,
-         0.0000, -0.3067,  0.0000,  0.0000,  0.0000,  0.0000,  0.6049,  0.5735,
-         0.3472,  0.0000, -0.3015,  0.0000,  0.0000, -0.1839,  0.0000,  0.1956,
-         1.0823,  0.0000,  0.0183, -0.0451, -0.3582, -0.1994,  0.7884,  0.0000,
-         0.0000,  0.0000,  0.7395,  0.0000,  0.0000,  0.2304, -0.0684, -0.0288],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0938,  0.2843,  0.0413,  0.2123,  2.6069,  0.0908,  0.0000,  0.0000,
-         0.1336,  0.0641, -0.0265,  0.0000, -0.1137, -0.5949,  0.0000,  0.0000,
-         0.0000, -0.1098,  0.0000,  0.0000, -0.6160, -0.0406,  0.0000, -0.8036,
-         0.5160,  0.0000,  0.1277, -0.0064,  0.3652,  0.0000, -0.5183, -1.3407,
-         0.0000, -0.3067,  0.0000,  0.0000,  0.0000,  0.0000,  0.6049,  0.5735,
-         0.3472,  0.0000, -0.3015,  0.0000,  0.0000, -0.1839,  0.0000,  0.1956,
-         1.0823,  0.0000,  0.0183, -0.0451, -0.3582, -0.1994,  0.7884,  0.0000,
-         0.0000,  0.0000,  0.7395,  0.0000,  0.0000,  0.2304, -0.0684, -0.0288],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.3765e-02,  2.4966e-01,  2.6514e-02,  1.8825e-01,  2.6063e+00,
-         1.6574e-01,  2.0512e-11, -1.6197e-04,  2.1275e-01, -2.5420e-02,
-        -1.2117e-01,  3.6875e-03, -1.1914e-01, -5.6019e-01, -3.6784e-08,
-        -2.7339e-05, -2.6892e-06, -2.0505e-01, -6.2200e-09,  7.4785e-05,
-        -5.5346e-01, -1.6983e-01, -1.3209e-08, -8.2769e-01,  4.2376e-01,
-         1.0720e-04,  5.3785e-02, -4.1806e-03,  4.9748e-01,  1.9210e-08,
-        -4.7906e-01, -1.3382e+00, -4.9175e-12, -3.2176e-01,  2.4408e-06,
-         1.3170e-04,  8.1420e-08,  0.0000e+00,  6.0367e-01,  5.7964e-01,
-         2.7926e-01, -1.0875e-05, -3.5318e-01,  2.0038e-08,  4.9011e-06,
-        -2.1787e-01, -6.9153e-04,  2.0333e-01,  1.0902e+00, -1.2757e-09,
-        -7.4045e-02, -8.8392e-02, -3.5711e-01, -2.9035e-01,  7.7763e-01,
-         2.6549e-13, -1.0434e-07,  3.7525e-11,  6.9522e-01, -5.4194e-03,
-         1.0024e-10,  2.2369e-01, -1.1674e-01, -5.7528e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0838,  0.2497,  0.0265,  0.1883,  2.6063,  0.1657,  0.0000,  0.0000,
-         0.2127, -0.0254, -0.1212,  0.0000, -0.1191, -0.5602,  0.0000,  0.0000,
-         0.0000, -0.2051,  0.0000,  0.0000, -0.5535, -0.1698,  0.0000, -0.8277,
-         0.4238,  0.0000,  0.0538, -0.0042,  0.4975,  0.0000, -0.4791, -1.3382,
-         0.0000, -0.3218,  0.0000,  0.0000,  0.0000,  0.0000,  0.6037,  0.5796,
-         0.2793,  0.0000, -0.3532,  0.0000,  0.0000, -0.2179,  0.0000,  0.2033,
-         1.0902,  0.0000, -0.0740, -0.0884, -0.3571, -0.2904,  0.7776,  0.0000,
-         0.0000,  0.0000,  0.6952,  0.0000,  0.0000,  0.2237, -0.1167, -0.0575],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0838,  0.2497,  0.0265,  0.1883,  2.6063,  0.1657,  0.0000,  0.0000,
-         0.2127, -0.0254, -0.1212,  0.0000, -0.1191, -0.5602,  0.0000,  0.0000,
-         0.0000, -0.2051,  0.0000,  0.0000, -0.5535, -0.1698,  0.0000, -0.8277,
-         0.4238,  0.0000,  0.0538, -0.0042,  0.4975,  0.0000, -0.4791, -1.3382,
-         0.0000, -0.3218,  0.0000,  0.0000,  0.0000,  0.0000,  0.6037,  0.5796,
-         0.2793,  0.0000, -0.3532,  0.0000,  0.0000, -0.2179,  0.0000,  0.2033,
-         1.0902,  0.0000, -0.0740, -0.0884, -0.3571, -0.2904,  0.7776,  0.0000,
-         0.0000,  0.0000,  0.6952,  0.0000,  0.0000,  0.2237, -0.1167, -0.0575],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.1670e-02,  2.7265e-01, -4.7193e-02,  7.5304e-02,  2.6037e+00,
-         1.6776e-01,  1.7682e-11, -1.3962e-04,  2.5282e-01, -9.9736e-02,
-        -1.2384e-02,  3.1788e-03, -1.9863e-01, -5.6802e-01, -3.1709e-08,
-        -2.3567e-05, -2.3182e-06, -2.3874e-01, -5.3619e-09,  6.4467e-05,
-        -4.5752e-01, -1.7830e-01, -1.1387e-08, -7.8851e-01,  3.3238e-01,
-         9.2415e-05, -3.0496e-02, -3.2712e-02,  6.4108e-01,  1.6560e-08,
-        -4.6195e-01, -1.3296e+00, -4.2391e-12, -2.2793e-01,  2.1041e-06,
-         1.1353e-04,  7.0187e-08,  0.0000e+00,  5.6857e-01,  5.5359e-01,
-         2.3883e-01, -9.3746e-06, -4.3676e-01,  1.7274e-08,  4.2250e-06,
-        -2.7387e-01, -5.9613e-04,  2.2122e-01,  1.0894e+00, -1.0997e-09,
-        -2.8585e-02, -1.7422e-02, -3.4544e-01, -3.3349e-01,  7.4652e-01,
-         2.2886e-13, -8.9946e-08,  3.2348e-11,  6.3989e-01, -4.6717e-03,
-         8.6409e-11,  1.7249e-01, -1.4442e-01, -1.0016e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0117,  0.2726, -0.0472,  0.0753,  2.6037,  0.1678,  0.0000,  0.0000,
-         0.2528, -0.0997, -0.0124,  0.0000, -0.1986, -0.5680,  0.0000,  0.0000,
-         0.0000, -0.2387,  0.0000,  0.0000, -0.4575, -0.1783,  0.0000, -0.7885,
-         0.3324,  0.0000, -0.0305, -0.0327,  0.6411,  0.0000, -0.4619, -1.3296,
-         0.0000, -0.2279,  0.0000,  0.0000,  0.0000,  0.0000,  0.5686,  0.5536,
-         0.2388,  0.0000, -0.4368,  0.0000,  0.0000, -0.2739,  0.0000,  0.2212,
-         1.0894,  0.0000, -0.0286, -0.0174, -0.3454, -0.3335,  0.7465,  0.0000,
-         0.0000,  0.0000,  0.6399,  0.0000,  0.0000,  0.1725, -0.1444, -0.1002],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0117,  0.2726, -0.0472,  0.0753,  2.6037,  0.1678,  0.0000,  0.0000,
-         0.2528, -0.0997, -0.0124,  0.0000, -0.1986, -0.5680,  0.0000,  0.0000,
-         0.0000, -0.2387,  0.0000,  0.0000, -0.4575, -0.1783,  0.0000, -0.7885,
-         0.3324,  0.0000, -0.0305, -0.0327,  0.6411,  0.0000, -0.4619, -1.3296,
-         0.0000, -0.2279,  0.0000,  0.0000,  0.0000,  0.0000,  0.5686,  0.5536,
-         0.2388,  0.0000, -0.4368,  0.0000,  0.0000, -0.2739,  0.0000,  0.2212,
-         1.0894,  0.0000, -0.0286, -0.0174, -0.3454, -0.3335,  0.7465,  0.0000,
-         0.0000,  0.0000,  0.6399,  0.0000,  0.0000,  0.1725, -0.1444, -0.1002],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.7204e-01,  3.2178e-01, -1.4703e-01, -2.8551e-02,  2.5997e+00,
-         1.2472e-01,  1.5247e-11, -1.2040e-04,  2.3038e-01, -1.3700e-01,
-         9.3311e-02,  2.7411e-03, -3.2469e-01, -5.8426e-01, -2.7343e-08,
-        -2.0322e-05, -1.9990e-06, -2.1220e-01, -4.6236e-09,  5.5591e-05,
-        -3.7621e-01, -1.3197e-01, -9.8188e-09, -7.6244e-01,  2.8318e-01,
-         7.9690e-05, -8.8889e-02, -6.1810e-02,  7.5561e-01,  1.4280e-08,
-        -4.2168e-01, -1.3220e+00, -3.6554e-12, -1.5515e-01,  1.8144e-06,
-         9.7896e-05,  6.0523e-08,  0.0000e+00,  4.9392e-01,  5.0235e-01,
-         2.3983e-01, -8.0838e-06, -5.0152e-01,  1.4895e-08,  3.6432e-06,
-        -2.9534e-01, -5.1405e-04,  2.2077e-01,  1.0912e+00, -9.4826e-10,
-         5.0534e-02,  2.8403e-02, -3.5502e-01, -2.7906e-01,  7.0893e-01,
-         1.9735e-13, -7.7562e-08,  2.7894e-11,  5.7576e-01, -4.0285e-03,
-         7.4511e-11,  1.1739e-01, -1.6277e-01, -1.0950e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1720,  0.3218, -0.1470, -0.0286,  2.5997,  0.1247,  0.0000,  0.0000,
-         0.2304, -0.1370,  0.0933,  0.0000, -0.3247, -0.5843,  0.0000,  0.0000,
-         0.0000, -0.2122,  0.0000,  0.0000, -0.3762, -0.1320,  0.0000, -0.7624,
-         0.2832,  0.0000, -0.0889, -0.0618,  0.7556,  0.0000, -0.4217, -1.3220,
-         0.0000, -0.1552,  0.0000,  0.0000,  0.0000,  0.0000,  0.4939,  0.5023,
-         0.2398,  0.0000, -0.5015,  0.0000,  0.0000, -0.2953,  0.0000,  0.2208,
-         0.0000,  0.0000,  0.0505,  0.0284, -0.3550, -0.2791,  0.7089,  0.0000,
-         0.0000,  0.0000,  0.5758,  0.0000,  0.0000,  0.1174, -0.1628, -0.1095],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1720,  0.3218, -0.1470, -0.0286,  2.5997,  0.1247,  0.0000,  0.0000,
-         0.2304, -0.1370,  0.0933,  0.0000, -0.3247, -0.5843,  0.0000,  0.0000,
-         0.0000, -0.2122,  0.0000,  0.0000, -0.3762, -0.1320,  0.0000, -0.7624,
-         0.2832,  0.0000, -0.0889, -0.0618,  0.7556,  0.0000, -0.4217, -1.3220,
-         0.0000, -0.1552,  0.0000,  0.0000,  0.0000,  0.0000,  0.4939,  0.5023,
-         0.2398,  0.0000, -0.5015,  0.0000,  0.0000, -0.2953,  0.0000,  0.2208,
-         0.0000,  0.0000,  0.0505,  0.0284, -0.3550, -0.2791,  0.7089,  0.0000,
-         0.0000,  0.0000,  0.5758,  0.0000,  0.0000,  0.1174, -0.1628, -0.1095],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.5002e-01,  3.3807e-01, -1.8290e-01, -4.8839e-02,  2.5935e+00,
-         6.1193e-02,  1.3152e-11, -1.0385e-04,  2.1490e-01, -1.4038e-01,
-         6.9259e-02,  2.3644e-03, -3.9389e-01, -6.0241e-01, -2.3586e-08,
-        -1.7530e-05, -1.7243e-06, -8.1251e-02, -3.9882e-09,  4.7952e-05,
-        -3.8821e-01, -1.5784e-02, -8.4695e-09, -7.7198e-01,  2.7099e-01,
-         6.8739e-05, -1.1067e-01, -8.7536e-02,  8.3280e-01,  1.2318e-08,
-        -3.5070e-01, -1.3209e+00, -3.1531e-12, -9.8398e-02,  1.5650e-06,
-         8.4443e-05,  5.2206e-08,  0.0000e+00,  4.1829e-01,  4.3972e-01,
-         2.1830e-01, -6.9730e-06, -5.0779e-01,  1.2848e-08,  3.1426e-06,
-        -2.7465e-01, -4.4341e-04,  2.1188e-01,  1.4916e-03, -8.1795e-10,
-         6.9243e-02,  3.1586e-02, -4.2104e-01, -5.9834e-02,  6.8751e-01,
-         1.7023e-13, -6.6903e-08,  2.4061e-11,  5.3913e-01, -3.4749e-03,
-         6.4272e-11,  8.7096e-02, -1.6642e-01, -1.3174e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2500,  0.3381, -0.1829, -0.0488,  2.5935,  0.0612,  0.0000,  0.0000,
-         0.2149, -0.1404,  0.0693,  0.0000, -0.3939, -0.6024,  0.0000,  0.0000,
-         0.0000, -0.0813,  0.0000,  0.0000, -0.3882, -0.0158,  0.0000, -0.7720,
-         0.2710,  0.0000, -0.1107, -0.0875,  0.8328,  0.0000, -0.3507, -1.3209,
-         0.0000, -0.0984,  0.0000,  0.0000,  0.0000,  0.0000,  0.4183,  0.4397,
-         0.2183,  0.0000, -0.5078,  0.0000,  0.0000, -0.2747,  0.0000,  0.2119,
-         0.0000,  0.0000,  0.0692,  0.0316, -0.4210, -0.0598,  0.6875,  0.0000,
-         0.0000,  0.0000,  0.5391,  0.0000,  0.0000,  0.0871, -0.1664, -0.1317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2500,  0.3381, -0.1829, -0.0488,  2.5935,  0.0612,  0.0000,  0.0000,
-         0.2149, -0.1404,  0.0693,  0.0000, -0.3939, -0.6024,  0.0000,  0.0000,
-         0.0000, -0.0813,  0.0000,  0.0000, -0.3882, -0.0158,  0.0000, -0.7720,
-         0.2710,  0.0000, -0.1107, -0.0875,  0.8328,  0.0000, -0.3507, -1.3209,
-         0.0000, -0.0984,  0.0000,  0.0000,  0.0000,  0.0000,  0.4183,  0.4397,
-         0.2183,  0.0000, -0.5078,  0.0000,  0.0000, -0.2747,  0.0000,  0.2119,
-         0.0000,  0.0000,  0.0692,  0.0316, -0.4210, -0.0598,  0.6875,  0.0000,
-         0.0000,  0.0000,  0.5391,  0.0000,  0.0000,  0.0871, -0.1664, -0.1317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.4532e-01,  3.0665e-01, -1.3084e-01, -7.1319e-02,  2.5867e+00,
-         5.4689e-02,  1.1348e-11, -8.9609e-05,  2.1722e-01, -1.8086e-01,
-        -1.6117e-02,  2.0401e-03, -3.7858e-01, -6.1900e-01, -2.0351e-08,
-        -1.5126e-05, -1.4878e-06, -1.1687e-02, -3.4412e-09,  4.1375e-05,
-        -4.2142e-01,  5.0488e-02, -7.3079e-09, -7.9645e-01,  2.1084e-01,
-         5.9312e-05, -1.1058e-01, -1.1382e-01,  8.6231e-01,  1.0628e-08,
-        -2.9587e-01, -1.3244e+00, -2.7206e-12, -8.1925e-02,  1.3504e-06,
-         7.2862e-05,  4.5046e-08,  0.0000e+00,  4.5492e-01,  4.5656e-01,
-         1.5089e-01, -6.0167e-06, -4.6471e-01,  1.1086e-08,  2.7116e-06,
-        -2.0358e-01, -3.8260e-04,  2.3888e-01,  1.2871e-03, -7.0577e-10,
-         4.9486e-02, -5.5265e-03, -4.2209e-01,  1.4293e-01,  7.0266e-01,
-         1.4688e-13, -5.7728e-08,  2.0761e-11,  5.1543e-01, -2.9983e-03,
-         5.5457e-11,  9.3411e-02, -1.0523e-01, -1.6950e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2453,  0.3066, -0.1308, -0.0713,  2.5867,  0.0547,  0.0000,  0.0000,
-         0.2172, -0.1809, -0.0161,  0.0000, -0.3786, -0.6190,  0.0000,  0.0000,
-         0.0000, -0.0117,  0.0000,  0.0000, -0.4214,  0.0505,  0.0000, -0.7964,
-         0.2108,  0.0000, -0.1106, -0.1138,  0.8623,  0.0000, -0.2959, -1.3244,
-         0.0000, -0.0819,  0.0000,  0.0000,  0.0000,  0.0000,  0.4549,  0.4566,
-         0.1509,  0.0000, -0.4647,  0.0000,  0.0000, -0.2036,  0.0000,  0.2389,
-         0.0000,  0.0000,  0.0495, -0.0055, -0.4221,  0.1429,  0.7027,  0.0000,
-         0.0000,  0.0000,  0.5154,  0.0000,  0.0000,  0.0934, -0.1052, -0.1695],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2453,  0.3066, -0.1308, -0.0713,  2.5867,  0.0547,  0.0000,  0.0000,
-         0.2172, -0.1809, -0.0161,  0.0000, -0.3786, -0.6190,  0.0000,  0.0000,
-         0.0000, -0.0117,  0.0000,  0.0000, -0.4214,  0.0505,  0.0000, -0.7964,
-         0.2108,  0.0000, -0.1106, -0.1138,  0.8623,  0.0000, -0.2959, -1.3244,
-         0.0000, -0.0819,  0.0000,  0.0000,  0.0000,  0.0000,  0.4549,  0.4566,
-         0.1509,  0.0000, -0.4647,  0.0000,  0.0000, -0.2036,  0.0000,  0.2389,
-         0.0000,  0.0000,  0.0495, -0.0055, -0.4221,  0.1429,  0.7027,  0.0000,
-         0.0000,  0.0000,  0.5154,  0.0000,  0.0000,  0.0934, -0.1052, -0.1695],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.3053e-01,  2.7186e-01, -7.1899e-02, -1.1368e-01,  2.5834e+00,
-         8.6707e-02,  9.7952e-12, -7.7344e-05,  2.3148e-01, -2.2406e-01,
-        -7.3512e-02,  1.7609e-03, -3.6414e-01, -6.3759e-01, -1.7566e-08,
-        -1.3055e-05, -1.2842e-06, -2.2711e-02, -2.9702e-09,  3.5712e-05,
-        -4.7398e-01,  8.8438e-02, -6.3077e-09, -8.3519e-01,  1.5209e-01,
-         5.1194e-05, -1.0593e-01, -1.6086e-01,  8.9108e-01,  9.1737e-09,
-        -2.6953e-01, -1.3296e+00, -2.3483e-12, -7.8036e-02,  1.1656e-06,
-         6.2890e-05,  3.8881e-08,  0.0000e+00,  4.6880e-01,  4.5075e-01,
-         5.4863e-02, -5.1932e-06, -4.2603e-01,  9.5689e-09,  2.3405e-06,
-        -1.1213e-01, -3.3023e-04,  2.5037e-01,  1.1109e-03, -6.0918e-10,
-         1.7240e-02, -7.7790e-02, -3.7834e-01,  2.1872e-01,  6.9589e-01,
-         1.2678e-13, -4.9827e-08,  1.7919e-11,  5.1353e-01, -2.5880e-03,
-         4.7867e-11,  9.9574e-02, -3.6641e-02, -1.8956e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2305,  0.2719, -0.0719, -0.1137,  2.5834,  0.0867,  0.0000,  0.0000,
-         0.2315, -0.2241, -0.0735,  0.0000, -0.3641, -0.6376,  0.0000,  0.0000,
-         0.0000, -0.0227,  0.0000,  0.0000, -0.4740,  0.0884,  0.0000, -0.8352,
-         0.1521,  0.0000, -0.1059, -0.1609,  0.8911,  0.0000, -0.2695, -1.3296,
-         0.0000, -0.0780,  0.0000,  0.0000,  0.0000,  0.0000,  0.4688,  0.4508,
-         0.0549,  0.0000, -0.4260,  0.0000,  0.0000, -0.1121,  0.0000,  0.2504,
-         0.0000,  0.0000,  0.0172, -0.0778, -0.3783,  0.2187,  0.6959,  0.0000,
-         0.0000,  0.0000,  0.5135,  0.0000,  0.0000,  0.0996, -0.0366, -0.1896],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2305,  0.2719, -0.0719, -0.1137,  2.5834,  0.0867,  0.0000,  0.0000,
-         0.2315, -0.2241, -0.0735,  0.0000, -0.3641, -0.6376,  0.0000,  0.0000,
-         0.0000, -0.0227,  0.0000,  0.0000, -0.4740,  0.0884,  0.0000, -0.8352,
-         0.1521,  0.0000, -0.1059, -0.1609,  0.8911,  0.0000, -0.2695, -1.3296,
-         0.0000, -0.0780,  0.0000,  0.0000,  0.0000,  0.0000,  0.4688,  0.4508,
-         0.0549,  0.0000, -0.4260,  0.0000,  0.0000, -0.1121,  0.0000,  0.2504,
-         0.0000,  0.0000,  0.0172, -0.0778, -0.3783,  0.2187,  0.6959,  0.0000,
-         0.0000,  0.0000,  0.5135,  0.0000,  0.0000,  0.0996, -0.0366, -0.1896],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.3780e-01,  2.1329e-01,  1.9273e-02, -1.3147e-01,  2.5747e+00,
-         1.3677e-01,  8.4573e-12, -6.6780e-05,  2.6610e-01, -2.4826e-01,
-        -4.2789e-02,  1.5204e-03, -3.3520e-01, -6.1838e-01, -1.5166e-08,
-        -1.1272e-05, -1.1088e-06, -5.3885e-02, -2.5645e-09,  3.0834e-05,
-        -4.7221e-01,  1.0094e-01, -5.4461e-09, -8.4031e-01,  8.0027e-02,
-         4.4201e-05, -9.4712e-02, -1.8224e-01,  8.9433e-01,  7.9206e-09,
-        -2.9764e-01, -1.3322e+00, -2.0275e-12, -8.2277e-02,  1.0064e-06,
-         5.4300e-05,  3.3570e-08,  0.0000e+00,  4.9957e-01,  4.6972e-01,
-        -8.2387e-03, -4.4838e-06, -4.0756e-01,  8.2619e-09,  2.0208e-06,
-        -5.2432e-02, -2.8513e-04,  2.3684e-01,  9.5916e-04, -5.2597e-10,
-         3.2168e-02, -9.4077e-02, -3.3381e-01,  2.1378e-01,  6.9074e-01,
-         1.0946e-13, -4.3021e-08,  1.5472e-11,  5.2364e-01, -2.2345e-03,
-         4.1329e-11,  1.1983e-01,  3.7000e-02, -1.6981e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2378,  0.2133,  0.0193, -0.1315,  2.5747,  0.1368,  0.0000,  0.0000,
-         0.2661, -0.2483, -0.0428,  0.0000, -0.3352, -0.6184,  0.0000,  0.0000,
-         0.0000, -0.0539,  0.0000,  0.0000, -0.4722,  0.1009,  0.0000, -0.8403,
-         0.0800,  0.0000, -0.0947, -0.1822,  0.8943,  0.0000, -0.2976, -1.3322,
-         0.0000, -0.0823,  0.0000,  0.0000,  0.0000,  0.0000,  0.4996,  0.4697,
-        -0.0082,  0.0000, -0.4076,  0.0000,  0.0000, -0.0524,  0.0000,  0.2368,
-         0.0000,  0.0000,  0.0322, -0.0941, -0.3338,  0.2138,  0.6907,  0.0000,
-         0.0000,  0.0000,  0.5236,  0.0000,  0.0000,  0.1198,  0.0370, -0.1698],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2378,  0.2133,  0.0193, -0.1315,  2.5747,  0.1368,  0.0000,  0.0000,
-         0.2661, -0.2483, -0.0428,  0.0000, -0.3352, -0.6184,  0.0000,  0.0000,
-         0.0000, -0.0539,  0.0000,  0.0000, -0.4722,  0.1009,  0.0000, -0.8403,
-         0.0800,  0.0000, -0.0947, -0.1822,  0.8943,  0.0000, -0.2976, -1.3322,
-         0.0000, -0.0823,  0.0000,  0.0000,  0.0000,  0.0000,  0.4996,  0.4697,
-        -0.0082,  0.0000, -0.4076,  0.0000,  0.0000, -0.0524,  0.0000,  0.2368,
-         0.0000,  0.0000,  0.0322, -0.0941, -0.3338,  0.2138,  0.6907,  0.0000,
-         0.0000,  0.0000,  0.5236,  0.0000,  0.0000,  0.1198,  0.0370, -0.1698],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.4799e-01,  1.9348e-01,  7.2727e-02, -1.3215e-01,  2.5653e+00,
-         1.7513e-01,  7.3045e-12, -5.7677e-05,  2.9346e-01, -2.6757e-01,
-         5.5927e-02,  1.3131e-03, -3.2097e-01, -5.9302e-01, -1.3099e-08,
-        -9.7356e-06, -9.5765e-07, -4.7899e-02, -2.2150e-09,  2.6631e-05,
-        -4.2798e-01,  1.0752e-01, -4.7038e-09, -8.2876e-01,  2.9705e-02,
-         3.8176e-05, -8.1453e-02, -1.7743e-01,  8.8498e-01,  6.8410e-09,
-        -3.7272e-01, -1.3324e+00, -1.7512e-12, -9.7805e-02,  8.6919e-07,
-         4.6898e-05,  2.8994e-08,  0.0000e+00,  4.8838e-01,  4.7618e-01,
-        -5.6818e-03, -3.8726e-06, -3.9282e-01,  7.1357e-09,  1.7453e-06,
-        -4.8389e-02, -2.4626e-04,  2.4342e-01,  8.2842e-04, -4.5427e-10,
-         7.9226e-02, -5.3006e-02, -3.3095e-01,  1.4087e-01,  6.7124e-01,
-         9.4542e-14, -3.7157e-08,  1.3363e-11,  5.2888e-01, -1.9299e-03,
-         3.5695e-11,  1.1663e-01,  8.2148e-02, -1.1581e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2480,  0.1935,  0.0727, -0.1321,  2.5653,  0.1751,  0.0000,  0.0000,
-         0.2935, -0.2676,  0.0559,  0.0000, -0.3210, -0.5930,  0.0000,  0.0000,
-         0.0000, -0.0479,  0.0000,  0.0000, -0.4280,  0.1075,  0.0000, -0.8288,
-         0.0297,  0.0000, -0.0815, -0.1774,  0.8850,  0.0000, -0.3727, -1.3324,
-         0.0000, -0.0978,  0.0000,  0.0000,  0.0000,  0.0000,  0.4884,  0.4762,
-        -0.0057,  0.0000, -0.3928,  0.0000,  0.0000, -0.0484,  0.0000,  0.2434,
-         0.0000,  0.0000,  0.0792, -0.0530, -0.3310,  0.1409,  0.6712,  0.0000,
-         0.0000,  0.0000,  0.5289,  0.0000,  0.0000,  0.1166,  0.0821, -0.1158],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2480,  0.1935,  0.0727, -0.1321,  2.5653,  0.1751,  0.0000,  0.0000,
-         0.2935, -0.2676,  0.0559,  0.0000, -0.3210, -0.5930,  0.0000,  0.0000,
-         0.0000, -0.0479,  0.0000,  0.0000, -0.4280,  0.1075,  0.0000, -0.8288,
-         0.0297,  0.0000, -0.0815, -0.1774,  0.8850,  0.0000, -0.3727, -1.3324,
-         0.0000, -0.0978,  0.0000,  0.0000,  0.0000,  0.0000,  0.4884,  0.4762,
-        -0.0057,  0.0000, -0.3928,  0.0000,  0.0000, -0.0484,  0.0000,  0.2434,
-         0.0000,  0.0000,  0.0792, -0.0530, -0.3310,  0.1409,  0.6712,  0.0000,
-         0.0000,  0.0000,  0.5289,  0.0000,  0.0000,  0.1166,  0.0821, -0.1158],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.5398e-01,  1.8674e-01,  1.0285e-01, -1.1424e-01,  2.5552e+00,
-         1.5971e-01,  6.3109e-12, -4.9831e-05,  3.3553e-01, -3.2314e-01,
-         1.7727e-01,  1.1345e-03, -3.0748e-01, -5.6069e-01, -1.1317e-08,
-        -8.4113e-06, -8.2738e-07, -5.1191e-03, -1.9137e-09,  2.3009e-05,
-        -3.4238e-01,  1.3296e-01, -4.0639e-09, -7.9836e-01,  3.7046e-02,
-         3.2983e-05, -1.1039e-01, -1.3809e-01,  8.7258e-01,  5.9104e-09,
-        -4.6906e-01, -1.3287e+00, -1.5130e-12, -1.0270e-01,  7.5096e-07,
-         4.0519e-05,  2.5050e-08,  0.0000e+00,  4.6951e-01,  4.7802e-01,
-         3.4243e-02, -3.3459e-06, -3.9607e-01,  6.1651e-09,  1.5079e-06,
-        -6.8945e-02, -2.1276e-04,  2.3163e-01,  7.1573e-04, -3.9248e-10,
-         1.3234e-01,  7.2952e-02, -3.3770e-01,  2.7725e-02,  6.4730e-01,
-         8.1682e-14, -3.2102e-08,  1.1545e-11,  5.4423e-01, -1.6674e-03,
-         3.0840e-11,  8.0261e-02,  1.4813e-01, -8.1414e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2540,  0.1867,  0.1029, -0.1142,  2.5552,  0.1597,  0.0000,  0.0000,
-         0.3355, -0.3231,  0.1773,  0.0000, -0.3075, -0.5607,  0.0000,  0.0000,
-         0.0000, -0.0051,  0.0000,  0.0000, -0.3424,  0.1330,  0.0000, -0.7984,
-         0.0370,  0.0000, -0.1104, -0.1381,  0.8726,  0.0000, -0.4691, -1.3287,
-         0.0000, -0.1027,  0.0000,  0.0000,  0.0000,  0.0000,  0.4695,  0.4780,
-         0.0342,  0.0000, -0.3961,  0.0000,  0.0000, -0.0689,  0.0000,  0.2316,
-         0.0000,  0.0000,  0.1323,  0.0730, -0.3377,  0.0277,  0.6473,  0.0000,
-         0.0000,  0.0000,  0.5442,  0.0000,  0.0000,  0.0803,  0.1481, -0.0814],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2540,  0.1867,  0.1029, -0.1142,  2.5552,  0.1597,  0.0000,  0.0000,
-         0.3355, -0.3231,  0.1773,  0.0000, -0.3075, -0.5607,  0.0000,  0.0000,
-         0.0000, -0.0051,  0.0000,  0.0000, -0.3424,  0.1330,  0.0000, -0.7984,
-         0.0370,  0.0000, -0.1104, -0.1381,  0.8726,  0.0000, -0.4691, -1.3287,
-         0.0000, -0.1027,  0.0000,  0.0000,  0.0000,  0.0000,  0.4695,  0.4780,
-         0.0342,  0.0000, -0.3961,  0.0000,  0.0000, -0.0689,  0.0000,  0.2316,
-         0.0000,  0.0000,  0.1323,  0.0730, -0.3377,  0.0277,  0.6473,  0.0000,
-         0.0000,  0.0000,  0.5442,  0.0000,  0.0000,  0.0803,  0.1481, -0.0814],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.6107e-01,  1.7947e-01,  1.0251e-01, -8.1097e-02,  2.5502e+00,
-         1.6180e-01,  5.4542e-12, -4.3067e-05,  3.5964e-01, -3.7686e-01,
-         2.3703e-01,  9.8052e-04, -3.1621e-01, -4.9321e-01, -9.7810e-09,
-        -7.2696e-06, -7.1507e-07,  3.1041e-02, -1.6539e-09,  1.9886e-05,
-        -2.0964e-01,  1.6771e-01, -3.5123e-09, -7.3877e-01,  7.8275e-02,
-         2.8506e-05, -1.6951e-01, -6.0264e-02,  8.6436e-01,  5.1081e-09,
-        -4.9292e-01, -1.3182e+00, -1.3076e-12, -1.0996e-01,  6.4902e-07,
-         3.5019e-05,  2.1650e-08,  0.0000e+00,  4.0392e-01,  4.6579e-01,
-         5.1996e-02, -2.8917e-06, -4.2489e-01,  5.3282e-09,  1.3032e-06,
-        -1.6157e-01, -1.8388e-04,  1.9535e-01,  6.1858e-04, -3.3921e-10,
-         1.3066e-01,  2.2634e-01, -3.6806e-01, -1.1509e-01,  6.2738e-01,
-         7.0594e-14, -2.7745e-08,  9.9781e-12,  5.7131e-01, -1.4410e-03,
-         2.6654e-11, -2.5570e-02,  1.3253e-01, -4.7074e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2611,  0.1795,  0.1025, -0.0811,  2.5502,  0.1618,  0.0000,  0.0000,
-         0.3596, -0.3769,  0.2370,  0.0000, -0.3162, -0.4932,  0.0000,  0.0000,
-         0.0000,  0.0310,  0.0000,  0.0000, -0.2096,  0.1677,  0.0000, -0.7388,
-         0.0783,  0.0000, -0.1695, -0.0603,  0.8644,  0.0000, -0.4929, -1.3182,
-         0.0000, -0.1100,  0.0000,  0.0000,  0.0000,  0.0000,  0.4039,  0.4658,
-         0.0520,  0.0000, -0.4249,  0.0000,  0.0000, -0.1616,  0.0000,  0.1954,
-         0.0000,  0.0000,  0.1307,  0.2263, -0.3681, -0.1151,  0.6274,  0.0000,
-         0.0000,  0.0000,  0.5713,  0.0000,  0.0000, -0.0256,  0.1325, -0.0471],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2611,  0.1795,  0.1025, -0.0811,  2.5502,  0.1618,  0.0000,  0.0000,
-         0.3596, -0.3769,  0.2370,  0.0000, -0.3162, -0.4932,  0.0000,  0.0000,
-         0.0000,  0.0310,  0.0000,  0.0000, -0.2096,  0.1677,  0.0000, -0.7388,
-         0.0783,  0.0000, -0.1695, -0.0603,  0.8644,  0.0000, -0.4929, -1.3182,
-         0.0000, -0.1100,  0.0000,  0.0000,  0.0000,  0.0000,  0.4039,  0.4658,
-         0.0520,  0.0000, -0.4249,  0.0000,  0.0000, -0.1616,  0.0000,  0.1954,
-         0.0000,  0.0000,  0.1307,  0.2263, -0.3681, -0.1151,  0.6274,  0.0000,
-         0.0000,  0.0000,  0.5713,  0.0000,  0.0000, -0.0256,  0.1325, -0.0471],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.2958e-01,  1.7168e-01,  1.5807e-01, -4.9534e-02,  2.5436e+00,
-         2.0280e-01,  4.7154e-12, -3.7234e-05,  3.1880e-01, -3.4245e-01,
-         1.8493e-01,  8.4771e-04, -2.7036e-01, -4.5175e-01, -8.4561e-09,
-        -6.2849e-06, -6.1822e-07,  1.0559e-01, -1.4299e-09,  1.7192e-05,
-        -1.6724e-01,  1.9071e-01, -3.0366e-09, -7.2094e-01,  5.4915e-02,
-         2.4645e-05, -1.4277e-01, -4.0272e-02,  8.4157e-01,  4.4162e-09,
-        -4.2830e-01, -1.3113e+00, -1.1305e-12, -1.1241e-01,  5.6111e-07,
-         3.0275e-05,  1.8717e-08,  0.0000e+00,  3.0168e-01,  4.2977e-01,
-         1.9050e-02, -2.5000e-06, -3.9633e-01,  4.6065e-09,  1.1267e-06,
-        -1.6378e-01, -1.5897e-04,  1.8566e-01,  5.3479e-04, -2.9326e-10,
-         1.3747e-01,  2.4789e-01, -3.7385e-01, -6.7709e-02,  5.9899e-01,
-         6.1032e-14, -2.3987e-08,  8.6265e-12,  6.0711e-01, -1.2459e-03,
-         2.3043e-11, -4.4203e-02,  6.8054e-02, -2.9286e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.2296,  0.0000,  0.1581, -0.0495,  2.5436,  0.2028,  0.0000,  0.0000,
-         0.3188, -0.3425,  0.1849,  0.0000, -0.2704, -0.4518,  0.0000,  0.0000,
-         0.0000,  0.1056,  0.0000,  0.0000, -0.1672,  0.1907,  0.0000, -0.7209,
-         0.0549,  0.0000, -0.1428, -0.0403,  0.8416,  0.0000, -0.4283, -1.3113,
-         0.0000, -0.1124,  0.0000,  0.0000,  0.0000,  0.0000,  0.3017,  0.4298,
-         0.0191,  0.0000, -0.3963,  0.0000,  0.0000, -0.1638,  0.0000,  0.1857,
-         0.0000,  0.0000,  0.1375,  0.2479, -0.3739, -0.0677,  0.5990,  0.0000,
-         0.0000,  0.0000,  0.6071,  0.0000,  0.0000, -0.0442,  0.0681, -0.0293],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.2296,  0.0000,  0.1581, -0.0495,  2.5436,  0.2028,  0.0000,  0.0000,
-         0.3188, -0.3425,  0.1849,  0.0000, -0.2704, -0.4518,  0.0000,  0.0000,
-         0.0000,  0.1056,  0.0000,  0.0000, -0.1672,  0.1907,  0.0000, -0.7209,
-         0.0549,  0.0000, -0.1428, -0.0403,  0.8416,  0.0000, -0.4283, -1.3113,
-         0.0000, -0.1124,  0.0000,  0.0000,  0.0000,  0.0000,  0.3017,  0.4298,
-         0.0191,  0.0000, -0.3963,  0.0000,  0.0000, -0.1638,  0.0000,  0.1857,
-         0.0000,  0.0000,  0.1375,  0.2479, -0.3739, -0.0677,  0.5990,  0.0000,
-         0.0000,  0.0000,  0.6071,  0.0000,  0.0000, -0.0442,  0.0681, -0.0293],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.8218e-01, -6.7371e-03,  2.5157e-01,  2.5279e-02,  2.5382e+00,
-         2.7428e-01,  4.0781e-12, -3.2201e-05,  2.4645e-01, -3.0264e-01,
-         3.0325e-02,  7.3313e-04, -2.0910e-01, -4.0136e-01, -7.3132e-09,
-        -5.4354e-06, -5.3466e-07,  2.4469e-01, -1.2366e-09,  1.4868e-05,
-        -1.2418e-01,  2.1882e-01, -2.6261e-09, -6.8648e-01,  3.0166e-02,
-         2.1314e-05, -1.0788e-01, -4.2918e-03,  8.1985e-01,  3.8193e-09,
-        -3.4597e-01, -1.3038e+00, -9.7767e-13, -1.0764e-01,  4.8527e-07,
-         2.6183e-05,  1.6187e-08,  0.0000e+00,  1.8625e-01,  3.9266e-01,
-        -3.0853e-02, -2.1621e-06, -3.7359e-01,  3.9839e-09,  9.7441e-07,
-        -9.9827e-02, -1.3749e-04,  1.7511e-01,  4.6251e-04, -2.5362e-10,
-         1.6434e-01,  2.2562e-01, -4.0487e-01,  1.0930e-01,  5.8738e-01,
-         5.2783e-14, -2.0745e-08,  7.4605e-12,  6.4430e-01, -1.0775e-03,
-         1.9929e-11, -3.3465e-02, -4.2269e-03, -6.1928e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1822,  0.0000,  0.2516,  0.0253,  2.5382,  0.2743,  0.0000,  0.0000,
-         0.2465, -0.3026,  0.0303,  0.0000, -0.2091, -0.4014,  0.0000,  0.0000,
-         0.0000,  0.2447,  0.0000,  0.0000, -0.1242,  0.2188,  0.0000, -0.6865,
-         0.0302,  0.0000, -0.1079, -0.0043,  0.8198,  0.0000, -0.3460, -1.3038,
-         0.0000, -0.1076,  0.0000,  0.0000,  0.0000,  0.0000,  0.1863,  0.3927,
-        -0.0309,  0.0000, -0.3736,  0.0000,  0.0000, -0.0998,  0.0000,  0.1751,
-         0.0000,  0.0000,  0.1643,  0.2256, -0.4049,  0.1093,  0.5874,  0.0000,
-         0.0000,  0.0000,  0.6443,  0.0000,  0.0000, -0.0335, -0.0042, -0.0619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1822,  0.0000,  0.2516,  0.0253,  2.5382,  0.2743,  0.0000,  0.0000,
-         0.2465, -0.3026,  0.0303,  0.0000, -0.2091, -0.4014,  0.0000,  0.0000,
-         0.0000,  0.2447,  0.0000,  0.0000, -0.1242,  0.2188,  0.0000, -0.6865,
-         0.0302,  0.0000, -0.1079, -0.0043,  0.8198,  0.0000, -0.3460, -1.3038,
-         0.0000, -0.1076,  0.0000,  0.0000,  0.0000,  0.0000,  0.1863,  0.3927,
-        -0.0309,  0.0000, -0.3736,  0.0000,  0.0000, -0.0998,  0.0000,  0.1751,
-         0.0000,  0.0000,  0.1643,  0.2256, -0.4049,  0.1093,  0.5874,  0.0000,
-         0.0000,  0.0000,  0.6443,  0.0000,  0.0000, -0.0335, -0.0042, -0.0619],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.6913e-01, -5.8285e-03,  2.5380e-01, -6.8480e-03,  2.5342e+00,
-         3.3634e-01,  3.5281e-12, -2.7858e-05,  1.6857e-01, -2.3391e-01,
-        -5.0993e-02,  6.3425e-04, -2.3246e-01, -4.0379e-01, -6.3268e-09,
-        -4.7023e-06, -4.6255e-07,  2.0262e-01, -1.0698e-09,  1.2863e-05,
-        -1.6629e-01,  1.9073e-01, -2.2719e-09, -6.7686e-01, -1.4812e-02,
-         1.8439e-05, -8.9147e-02, -7.6753e-02,  7.7041e-01,  3.3042e-09,
-        -1.8817e-01, -1.2990e+00, -8.4581e-13, -1.1002e-01,  4.1982e-07,
-         2.2652e-05,  1.4004e-08,  0.0000e+00,  6.6865e-02,  3.8757e-01,
-        -1.0797e-01, -1.8705e-06, -3.6869e-01,  3.4466e-09,  8.4299e-07,
-        -5.4075e-02, -1.1894e-04,  2.2282e-01,  4.0013e-04, -2.1942e-10,
-         1.9317e-01,  1.6503e-01, -4.0102e-01,  3.7400e-03,  5.6429e-01,
-         4.5664e-14, -1.7947e-08,  6.4543e-12,  6.8280e-01, -9.3214e-04,
-         1.7241e-11, -1.0368e-01, -5.5675e-02, -1.1798e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1691,  0.0000,  0.2538, -0.0068,  2.5342,  0.3363,  0.0000,  0.0000,
-         0.1686, -0.2339, -0.0510,  0.0000, -0.2325, -0.4038,  0.0000,  0.0000,
-         0.0000,  0.2026,  0.0000,  0.0000, -0.1663,  0.1907,  0.0000, -0.6769,
-        -0.0148,  0.0000, -0.0891, -0.0768,  0.7704,  0.0000, -0.1882, -1.2990,
-         0.0000, -0.1100,  0.0000,  0.0000,  0.0000,  0.0000,  0.0669,  0.3876,
-        -0.1080,  0.0000, -0.3687,  0.0000,  0.0000, -0.0541,  0.0000,  0.2228,
-         0.0000,  0.0000,  0.1932,  0.1650, -0.4010,  0.0037,  0.5643,  0.0000,
-         0.0000,  0.0000,  0.6828,  0.0000,  0.0000, -0.1037, -0.0557, -0.1180],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1691,  0.0000,  0.2538, -0.0068,  2.5342,  0.3363,  0.0000,  0.0000,
-         0.1686, -0.2339, -0.0510,  0.0000, -0.2325, -0.4038,  0.0000,  0.0000,
-         0.0000,  0.2026,  0.0000,  0.0000, -0.1663,  0.1907,  0.0000, -0.6769,
-        -0.0148,  0.0000, -0.0891, -0.0768,  0.7704,  0.0000, -0.1882, -1.2990,
-         0.0000, -0.1100,  0.0000,  0.0000,  0.0000,  0.0000,  0.0669,  0.3876,
-        -0.1080,  0.0000, -0.3687,  0.0000,  0.0000, -0.0541,  0.0000,  0.2228,
-         0.0000,  0.0000,  0.1932,  0.1650, -0.4010,  0.0037,  0.5643,  0.0000,
-         0.0000,  0.0000,  0.6828,  0.0000,  0.0000, -0.1037, -0.0557, -0.1180],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.4038e-01, -5.0441e-03,  3.0227e-01, -3.6768e-02,  2.5311e+00,
-         3.6597e-01,  3.0533e-12, -2.4109e-05,  8.3160e-02, -1.4387e-01,
-        -7.1530e-02,  5.4889e-04, -2.2538e-01, -4.4711e-01, -5.4754e-09,
-        -4.0695e-06, -4.0030e-07,  1.0543e-01, -9.2586e-10,  1.1132e-05,
-        -2.0940e-01,  1.6196e-01, -1.9662e-09, -6.8606e-01, -7.8738e-02,
-         1.5958e-05, -1.9378e-02, -1.4695e-01,  7.2591e-01,  2.8595e-09,
-        -6.1849e-02, -1.2948e+00, -7.3198e-13, -1.2903e-01,  3.6332e-07,
-         1.9603e-05,  1.2120e-08,  0.0000e+00,  1.8895e-02,  3.7956e-01,
-        -1.9649e-01, -1.6188e-06, -3.4145e-01,  2.9827e-09,  7.2955e-07,
-        -6.5218e-02, -1.0294e-04,  2.7597e-01,  3.4628e-04, -1.8989e-10,
-         2.4721e-01,  5.9790e-02, -3.6061e-01, -1.0040e-01,  5.3084e-01,
-         3.9519e-14, -1.5532e-08,  5.5857e-12,  7.1089e-01, -8.0670e-04,
-         1.4921e-11, -1.1839e-01, -6.9219e-02, -1.2872e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1404,  0.0000,  0.3023, -0.0368,  2.5311,  0.3660,  0.0000,  0.0000,
-         0.0832, -0.1439, -0.0715,  0.0000, -0.2254, -0.4471,  0.0000,  0.0000,
-         0.0000,  0.1054,  0.0000,  0.0000, -0.2094,  0.1620,  0.0000, -0.6861,
-        -0.0787,  0.0000, -0.0194, -0.1470,  0.7259,  0.0000, -0.0618, -1.2948,
-         0.0000, -0.1290,  0.0000,  0.0000,  0.0000,  0.0000,  0.0189,  0.3796,
-        -0.1965,  0.0000, -0.3414,  0.0000,  0.0000, -0.0652,  0.0000,  0.2760,
-         0.0000,  0.0000,  0.2472,  0.0598, -0.3606, -0.1004,  0.5308,  0.0000,
-         0.0000,  0.0000,  0.7109,  0.0000,  0.0000, -0.1184, -0.0692, -0.1287],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1404,  0.0000,  0.3023, -0.0368,  2.5311,  0.3660,  0.0000,  0.0000,
-         0.0832, -0.1439, -0.0715,  0.0000, -0.2254, -0.4471,  0.0000,  0.0000,
-         0.0000,  0.1054,  0.0000,  0.0000, -0.2094,  0.1620,  0.0000, -0.6861,
-        -0.0787,  0.0000, -0.0194, -0.1470,  0.7259,  0.0000, -0.0618, -1.2948,
-         0.0000, -0.1290,  0.0000,  0.0000,  0.0000,  0.0000,  0.0189,  0.3796,
-        -0.1965,  0.0000, -0.3414,  0.0000,  0.0000, -0.0652,  0.0000,  0.2760,
-         0.0000,  0.0000,  0.2472,  0.0598, -0.3606, -0.1004,  0.5308,  0.0000,
-         0.0000,  0.0000,  0.7109,  0.0000,  0.0000, -0.1184, -0.0692, -0.1287],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.4393e-01, -4.3668e-03,  2.3422e-01, -3.1640e-02,  2.5276e+00,
-         3.1684e-01,  2.6433e-12, -2.0872e-05,  9.3822e-02, -1.0375e-01,
-        -1.0227e-01,  4.7519e-04, -2.2226e-01, -4.4907e-01, -4.7401e-09,
-        -3.5230e-06, -3.4655e-07, -6.3109e-02, -8.0153e-10,  9.6371e-06,
-        -1.7706e-01,  1.3193e-01, -1.7022e-09, -6.7220e-01, -1.1599e-01,
-         1.3815e-05,  1.1126e-02, -1.6739e-01,  6.1117e-01,  2.4756e-09,
-        -3.8017e-02, -1.2894e+00, -6.3369e-13, -1.2322e-01,  3.1453e-07,
-         1.6971e-05,  1.0492e-08,  0.0000e+00,  8.2770e-02,  4.3232e-01,
-        -2.4077e-01, -1.4014e-06, -3.4612e-01,  2.5822e-09,  6.3158e-07,
-        -1.2113e-01, -8.9115e-05,  3.1910e-01,  2.9978e-04, -1.6439e-10,
-         2.7838e-01,  1.2167e-01, -3.1262e-01, -2.9200e-01,  5.2683e-01,
-         3.4212e-14, -1.3446e-08,  4.8357e-12,  7.5005e-01, -6.9837e-04,
-         1.2917e-11, -1.4300e-01, -7.6767e-02, -1.4968e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1439,  0.0000,  0.2342, -0.0316,  2.5276,  0.3168,  0.0000,  0.0000,
-         0.0938, -0.1037, -0.1023,  0.0000, -0.2223, -0.4491,  0.0000,  0.0000,
-         0.0000, -0.0631,  0.0000,  0.0000, -0.1771,  0.1319,  0.0000, -0.6722,
-        -0.1160,  0.0000,  0.0111, -0.1674,  0.6112,  0.0000, -0.0380, -1.2894,
-         0.0000, -0.1232,  0.0000,  0.0000,  0.0000,  0.0000,  0.0828,  0.4323,
-        -0.2408,  0.0000, -0.3461,  0.0000,  0.0000, -0.1211,  0.0000,  0.3191,
-         0.0000,  0.0000,  0.2784,  0.1217, -0.3126, -0.2920,  0.5268,  0.0000,
-         0.0000,  0.0000,  0.7501,  0.0000,  0.0000, -0.1430, -0.0768, -0.1497],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1439,  0.0000,  0.2342, -0.0316,  2.5276,  0.3168,  0.0000,  0.0000,
-         0.0938, -0.1037, -0.1023,  0.0000, -0.2223, -0.4491,  0.0000,  0.0000,
-         0.0000, -0.0631,  0.0000,  0.0000, -0.1771,  0.1319,  0.0000, -0.6722,
-        -0.1160,  0.0000,  0.0111, -0.1674,  0.6112,  0.0000, -0.0380, -1.2894,
-         0.0000, -0.1232,  0.0000,  0.0000,  0.0000,  0.0000,  0.0828,  0.4323,
-        -0.2408,  0.0000, -0.3461,  0.0000,  0.0000, -0.1211,  0.0000,  0.3191,
-         0.0000,  0.0000,  0.2784,  0.1217, -0.3126, -0.2920,  0.5268,  0.0000,
-         0.0000,  0.0000,  0.7501,  0.0000,  0.0000, -0.1430, -0.0768, -0.1497],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.1126e-01, -3.7817e-03,  1.4912e-01,  1.8317e-02,  2.5252e+00,
-         2.1958e-01,  2.2891e-12, -1.8075e-05,  1.1120e-01, -7.3115e-02,
-        -8.0167e-02,  4.1152e-04, -2.1007e-01, -4.3419e-01, -4.1050e-09,
-        -3.0510e-06, -3.0011e-07, -1.0469e-01, -6.9414e-10,  8.3459e-06,
-        -1.3957e-01,  1.3047e-01, -1.4741e-09, -6.5851e-01, -8.8649e-02,
-         1.1964e-05,  4.2691e-02, -1.4890e-01,  5.0792e-01,  2.1439e-09,
-        -3.7300e-02, -1.2826e+00, -5.4879e-13, -1.1072e-01,  2.7239e-07,
-         1.4697e-05,  9.0864e-09,  0.0000e+00,  2.1582e-01,  4.9127e-01,
-        -2.1746e-01, -1.2136e-06, -3.5114e-01,  2.2362e-09,  5.4696e-07,
-        -2.7280e-01, -7.7175e-05,  3.1763e-01,  2.5962e-04, -1.4236e-10,
-         3.1748e-01,  2.0836e-01, -3.1745e-01, -2.6720e-01,  5.4042e-01,
-         2.9628e-14, -1.1644e-08,  4.1878e-12,  7.9562e-01, -6.0480e-04,
-         1.1186e-11, -1.5003e-01, -7.8529e-02, -1.8572e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.1113,  0.0000,  0.1491,  0.0183,  2.5252,  0.2196,  0.0000,  0.0000,
-         0.1112, -0.0731, -0.0802,  0.0000, -0.2101, -0.4342,  0.0000,  0.0000,
-         0.0000, -0.1047,  0.0000,  0.0000, -0.1396,  0.1305,  0.0000, -0.6585,
-        -0.0886,  0.0000,  0.0427, -0.1489,  0.5079,  0.0000, -0.0373, -1.2826,
-         0.0000, -0.1107,  0.0000,  0.0000,  0.0000,  0.0000,  0.2158,  0.4913,
-        -0.2175,  0.0000, -0.3511,  0.0000,  0.0000, -0.2728,  0.0000,  0.3176,
-         0.0000,  0.0000,  0.3175,  0.2084, -0.3175, -0.2672,  0.5404,  0.0000,
-         0.0000,  0.0000,  0.7956,  0.0000,  0.0000, -0.1500, -0.0785, -0.1857],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.1113,  0.0000,  0.1491,  0.0183,  2.5252,  0.2196,  0.0000,  0.0000,
-         0.1112, -0.0731, -0.0802,  0.0000, -0.2101, -0.4342,  0.0000,  0.0000,
-         0.0000, -0.1047,  0.0000,  0.0000, -0.1396,  0.1305,  0.0000, -0.6585,
-        -0.0886,  0.0000,  0.0427, -0.1489,  0.5079,  0.0000, -0.0373, -1.2826,
-         0.0000, -0.1107,  0.0000,  0.0000,  0.0000,  0.0000,  0.2158,  0.4913,
-        -0.2175,  0.0000, -0.3511,  0.0000,  0.0000, -0.2728,  0.0000,  0.3176,
-         0.0000,  0.0000,  0.3175,  0.2084, -0.3175, -0.2672,  0.5404,  0.0000,
-         0.0000,  0.0000,  0.7956,  0.0000,  0.0000, -0.1500, -0.0785, -0.1857],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-7.4348e-02, -3.2761e-03,  8.9666e-02,  6.3554e-02,  2.5239e+00,
-         9.6962e-02,  1.9831e-12, -1.5659e-05,  1.3405e-01, -7.3333e-02,
-        -2.7665e-02,  3.5651e-04, -1.8784e-01, -3.9717e-01, -3.5563e-09,
-        -2.6431e-06, -2.5999e-07, -1.0759e-01, -6.0135e-10,  7.2302e-06,
-        -1.2357e-01,  1.4251e-01, -1.2770e-09, -6.6153e-01, -3.9639e-02,
-         1.0365e-05,  6.6478e-02, -1.1190e-01,  4.1908e-01,  1.8573e-09,
-        -8.0067e-02, -1.2772e+00, -4.7542e-13, -1.1524e-01,  2.3598e-07,
-         1.2732e-05,  7.8717e-09,  0.0000e+00,  3.4971e-01,  5.4227e-01,
-        -1.6532e-01, -1.0514e-06, -3.6391e-01,  1.9373e-09,  4.7384e-07,
-        -3.3040e-01, -6.6858e-05,  2.8158e-01,  2.2491e-04, -1.2333e-10,
-         3.3232e-01,  2.6588e-01, -2.6032e-01, -2.2444e-01,  5.5369e-01,
-         2.5667e-14, -1.0088e-08,  3.6279e-12,  8.1652e-01, -5.2395e-04,
-         9.6910e-12, -1.5454e-01, -3.2972e-02, -2.1313e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0743,  0.0000,  0.0897,  0.0636,  2.5239,  0.0970,  0.0000,  0.0000,
-         0.1341, -0.0733, -0.0277,  0.0000, -0.1878, -0.3972,  0.0000,  0.0000,
-         0.0000, -0.1076,  0.0000,  0.0000, -0.1236,  0.1425,  0.0000, -0.6615,
-        -0.0396,  0.0000,  0.0665, -0.1119,  0.4191,  0.0000, -0.0801, -1.2772,
-         0.0000, -0.1152,  0.0000,  0.0000,  0.0000,  0.0000,  0.3497,  0.5423,
-        -0.1653,  0.0000, -0.3639,  0.0000,  0.0000, -0.3304,  0.0000,  0.2816,
-         0.0000,  0.0000,  0.3323,  0.2659, -0.2603, -0.2244,  0.5537,  0.0000,
-         0.0000,  0.0000,  0.8165,  0.0000,  0.0000, -0.1545, -0.0330, -0.2131],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0743,  0.0000,  0.0897,  0.0636,  2.5239,  0.0970,  0.0000,  0.0000,
-         0.1341, -0.0733, -0.0277,  0.0000, -0.1878, -0.3972,  0.0000,  0.0000,
-         0.0000, -0.1076,  0.0000,  0.0000, -0.1236,  0.1425,  0.0000, -0.6615,
-        -0.0396,  0.0000,  0.0665, -0.1119,  0.4191,  0.0000, -0.0801, -1.2772,
-         0.0000, -0.1152,  0.0000,  0.0000,  0.0000,  0.0000,  0.3497,  0.5423,
-        -0.1653,  0.0000, -0.3639,  0.0000,  0.0000, -0.3304,  0.0000,  0.2816,
-         0.0000,  0.0000,  0.3323,  0.2659, -0.2603, -0.2244,  0.5537,  0.0000,
-         0.0000,  0.0000,  0.8165,  0.0000,  0.0000, -0.1545, -0.0330, -0.2131],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-4.1126e-02, -2.8392e-03,  6.3950e-02,  1.3227e-01,  2.5214e+00,
-        -2.4282e-02,  1.7186e-12, -1.3570e-05,  1.4902e-01, -6.1021e-02,
-         4.0137e-02,  3.0896e-04, -1.3395e-01, -3.7437e-01, -3.0819e-09,
-        -2.2906e-06, -2.2532e-07, -5.1089e-02, -5.2114e-10,  6.2658e-06,
-        -1.4539e-01,  1.8223e-01, -1.1067e-09, -6.5880e-01,  2.5406e-02,
-         8.9821e-06,  8.3458e-02, -6.7024e-02,  3.4539e-01,  1.6095e-09,
-        -1.2759e-01, -1.2689e+00, -4.1201e-13, -1.0228e-01,  2.0450e-07,
-         1.1034e-05,  6.8217e-09,  0.0000e+00,  4.3872e-01,  5.8801e-01,
-        -9.6998e-02, -9.1116e-07, -3.4837e-01,  1.6789e-09,  4.1064e-07,
-        -3.6726e-01, -5.7940e-05,  2.2368e-01,  1.9491e-04, -1.0688e-10,
-         3.4796e-01,  3.0719e-01, -2.0768e-01, -1.0475e-01,  5.5695e-01,
-         2.2244e-14, -8.7422e-09,  3.1440e-12,  8.2090e-01, -4.5406e-04,
-         8.3984e-12, -1.3146e-01,  8.6258e-03, -2.4756e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0411,  0.0000,  0.0639,  0.1323,  2.5214, -0.0243,  0.0000,  0.0000,
-         0.1490, -0.0610,  0.0401,  0.0000, -0.1339, -0.3744,  0.0000,  0.0000,
-         0.0000, -0.0511,  0.0000,  0.0000, -0.1454,  0.1822,  0.0000, -0.6588,
-         0.0254,  0.0000,  0.0835, -0.0670,  0.3454,  0.0000, -0.1276, -1.2689,
-         0.0000, -0.1023,  0.0000,  0.0000,  0.0000,  0.0000,  0.4387,  0.5880,
-        -0.0970,  0.0000, -0.3484,  0.0000,  0.0000, -0.3673,  0.0000,  0.2237,
-         0.0000,  0.0000,  0.3480,  0.3072, -0.2077, -0.1047,  0.5569,  0.0000,
-         0.0000,  0.0000,  0.8209,  0.0000,  0.0000, -0.1315,  0.0086, -0.2476],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0411,  0.0000,  0.0639,  0.1323,  2.5214, -0.0243,  0.0000,  0.0000,
-         0.1490, -0.0610,  0.0401,  0.0000, -0.1339, -0.3744,  0.0000,  0.0000,
-         0.0000, -0.0511,  0.0000,  0.0000, -0.1454,  0.1822,  0.0000, -0.6588,
-         0.0254,  0.0000,  0.0835, -0.0670,  0.3454,  0.0000, -0.1276, -1.2689,
-         0.0000, -0.1023,  0.0000,  0.0000,  0.0000,  0.0000,  0.4387,  0.5880,
-        -0.0970,  0.0000, -0.3484,  0.0000,  0.0000, -0.3673,  0.0000,  0.2237,
-         0.0000,  0.0000,  0.3480,  0.3072, -0.2077, -0.1047,  0.5569,  0.0000,
-         0.0000,  0.0000,  0.8209,  0.0000,  0.0000, -0.1315,  0.0086, -0.2476],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-9.2761e-03, -2.4613e-03,  6.6971e-02,  1.7295e-01,  2.5158e+00,
-        -1.1136e-01,  1.4899e-12, -1.1764e-05,  1.6696e-01, -2.7409e-02,
-         1.1097e-01,  2.6784e-04, -6.8867e-02, -3.3438e-01, -2.6718e-09,
-        -1.9858e-06, -1.9533e-07,  3.6382e-02, -4.5179e-10,  5.4320e-06,
-        -1.4220e-01,  1.9196e-01, -9.5943e-10, -6.2879e-01,  4.9738e-02,
-         7.7868e-06,  1.0693e-01, -2.1481e-02,  2.7979e-01,  1.3954e-09,
-        -1.3504e-01, -1.2557e+00, -3.5718e-13, -9.2548e-02,  1.7729e-07,
-         9.5658e-06,  5.9139e-09,  0.0000e+00,  4.6606e-01,  5.9581e-01,
-        -4.5729e-02, -7.8990e-07, -3.2793e-01,  1.4555e-09,  3.5599e-07,
-        -3.7052e-01, -5.0230e-05,  1.4700e-01,  1.6897e-04, -9.2658e-11,
-         3.6108e-01,  3.0518e-01, -1.5397e-01,  1.2649e-02,  5.3693e-01,
-         1.9284e-14, -7.5788e-09,  2.7256e-12,  8.2298e-01, -3.9364e-04,
-         7.2808e-12, -7.6354e-02,  4.0111e-02, -2.4973e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0093,  0.0000,  0.0670,  0.1729,  2.5158, -0.1114,  0.0000,  0.0000,
-         0.1670, -0.0274,  0.1110,  0.0000, -0.0689, -0.3344,  0.0000,  0.0000,
-         0.0000,  0.0364,  0.0000,  0.0000, -0.1422,  0.1920,  0.0000, -0.6288,
-         0.0497,  0.0000,  0.1069, -0.0215,  0.2798,  0.0000, -0.1350, -1.2557,
-         0.0000, -0.0925,  0.0000,  0.0000,  0.0000,  0.0000,  0.4661,  0.5958,
-        -0.0457,  0.0000, -0.3279,  0.0000,  0.0000, -0.3705,  0.0000,  0.1470,
-         0.0000,  0.0000,  0.3611,  0.3052, -0.1540,  0.0126,  0.5369,  0.0000,
-         0.0000,  0.0000,  0.8230,  0.0000,  0.0000, -0.0764,  0.0401, -0.2497],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0093,  0.0000,  0.0670,  0.1729,  2.5158, -0.1114,  0.0000,  0.0000,
-         0.1670, -0.0274,  0.1110,  0.0000, -0.0689, -0.3344,  0.0000,  0.0000,
-         0.0000,  0.0364,  0.0000,  0.0000, -0.1422,  0.1920,  0.0000, -0.6288,
-         0.0497,  0.0000,  0.1069, -0.0215,  0.2798,  0.0000, -0.1350, -1.2557,
-         0.0000, -0.0925,  0.0000,  0.0000,  0.0000,  0.0000,  0.4661,  0.5958,
-        -0.0457,  0.0000, -0.3279,  0.0000,  0.0000, -0.3705,  0.0000,  0.1470,
-         0.0000,  0.0000,  0.3611,  0.3052, -0.1540,  0.0126,  0.5369,  0.0000,
-         0.0000,  0.0000,  0.8230,  0.0000,  0.0000, -0.0764,  0.0401, -0.2497],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.0663e-02, -2.1345e-03,  8.7744e-02,  2.1527e-01,  2.5096e+00,
-        -1.5477e-01,  1.2921e-12, -1.0202e-05,  1.4808e-01,  2.9020e-02,
-         1.4692e-01,  2.3228e-04, -1.9470e-02, -3.1970e-01, -2.3171e-09,
-        -1.7221e-06, -1.6940e-07,  9.0340e-02, -3.9180e-10,  4.7108e-06,
-        -1.8600e-01,  1.7457e-01, -8.3205e-10, -6.1681e-01,  6.9971e-02,
-         6.7529e-06,  1.6285e-01, -6.9178e-03,  2.0175e-01,  1.2101e-09,
-        -1.4846e-01, -1.2470e+00, -3.0976e-13, -7.2128e-02,  1.5375e-07,
-         8.2957e-06,  5.1287e-09,  0.0000e+00,  4.4875e-01,  5.9658e-01,
-        -1.9150e-02, -6.8503e-07, -3.0062e-01,  1.2622e-09,  3.0873e-07,
-        -3.8446e-01, -4.3561e-05,  8.5129e-02,  1.4654e-04, -8.0356e-11,
-         3.7102e-01,  2.5898e-01, -1.4415e-01,  1.7834e-01,  5.0748e-01,
-         1.6723e-14, -6.5726e-09,  2.3637e-12,  8.2629e-01, -3.4138e-04,
-         6.3141e-12, -2.4695e-02,  6.2514e-02, -2.4754e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0207,  0.0000,  0.0877,  0.2153,  2.5096, -0.1548,  0.0000,  0.0000,
-         0.1481,  0.0290,  0.1469,  0.0000, -0.0195, -0.3197,  0.0000,  0.0000,
-         0.0000,  0.0903,  0.0000,  0.0000, -0.1860,  0.1746,  0.0000, -0.6168,
-         0.0700,  0.0000,  0.1629, -0.0069,  0.2018,  0.0000, -0.1485, -1.2470,
-         0.0000, -0.0721,  0.0000,  0.0000,  0.0000,  0.0000,  0.4488,  0.5966,
-        -0.0192,  0.0000, -0.3006,  0.0000,  0.0000, -0.3845,  0.0000,  0.0851,
-         0.0000,  0.0000,  0.3710,  0.2590, -0.1441,  0.1783,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8263,  0.0000,  0.0000, -0.0247,  0.0625, -0.2475],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0207,  0.0000,  0.0877,  0.2153,  2.5096, -0.1548,  0.0000,  0.0000,
-         0.1481,  0.0290,  0.1469,  0.0000, -0.0195, -0.3197,  0.0000,  0.0000,
-         0.0000,  0.0903,  0.0000,  0.0000, -0.1860,  0.1746,  0.0000, -0.6168,
-         0.0700,  0.0000,  0.1629, -0.0069,  0.2018,  0.0000, -0.1485, -1.2470,
-         0.0000, -0.0721,  0.0000,  0.0000,  0.0000,  0.0000,  0.4488,  0.5966,
-        -0.0192,  0.0000, -0.3006,  0.0000,  0.0000, -0.3845,  0.0000,  0.0851,
-         0.0000,  0.0000,  0.3710,  0.2590, -0.1441,  0.1783,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8263,  0.0000,  0.0000, -0.0247,  0.0625, -0.2475],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4064e-02, -1.8518e-03,  6.3442e-02,  2.4105e-01,  2.5040e+00,
-        -1.4085e-01,  1.1209e-12, -8.8510e-06,  9.5242e-02,  7.0510e-02,
-         1.4029e-01,  2.0151e-04, -1.3558e-02, -3.0734e-01, -2.0101e-09,
-        -1.4940e-06, -1.4696e-07,  9.3523e-03, -3.3990e-10,  4.0868e-06,
-        -2.4902e-01,  8.7881e-02, -7.2183e-10, -6.3134e-01,  4.5070e-02,
-         5.8584e-06,  2.3216e-01, -4.8558e-02,  1.0426e-01,  1.0498e-09,
-        -1.9902e-01, -1.2431e+00, -2.6873e-13, -6.1991e-02,  1.3338e-07,
-         7.1969e-06,  4.4494e-09,  0.0000e+00,  4.1948e-01,  5.7870e-01,
-        -8.8463e-02, -5.9429e-07, -2.6605e-01,  1.0950e-09,  2.6783e-07,
-        -3.3064e-01, -3.7790e-05,  8.3697e-02,  1.2713e-04, -6.9712e-11,
-         3.5882e-01,  1.8360e-01, -1.4795e-01,  1.2895e-01, -2.5546e-02,
-         1.4508e-14, -5.7020e-09,  2.0506e-12,  8.6247e-01, -2.9616e-04,
-         5.4777e-12, -3.2043e-03,  2.9837e-02, -2.2936e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0341,  0.0000,  0.0634,  0.2411,  2.5040, -0.1408,  0.0000,  0.0000,
-         0.0952,  0.0705,  0.1403,  0.0000, -0.0136, -0.3073,  0.0000,  0.0000,
-         0.0000,  0.0094,  0.0000,  0.0000, -0.2490,  0.0879,  0.0000, -0.6313,
-         0.0451,  0.0000,  0.2322, -0.0486,  0.1043,  0.0000, -0.1990, -1.2431,
-         0.0000, -0.0620,  0.0000,  0.0000,  0.0000,  0.0000,  0.4195,  0.5787,
-        -0.0885,  0.0000, -0.2660,  0.0000,  0.0000, -0.3306,  0.0000,  0.0837,
-         0.0000,  0.0000,  0.3588,  0.1836, -0.1480,  0.1290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8625,  0.0000,  0.0000, -0.0032,  0.0298, -0.2294],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0341,  0.0000,  0.0634,  0.2411,  2.5040, -0.1408,  0.0000,  0.0000,
-         0.0952,  0.0705,  0.1403,  0.0000, -0.0136, -0.3073,  0.0000,  0.0000,
-         0.0000,  0.0094,  0.0000,  0.0000, -0.2490,  0.0879,  0.0000, -0.6313,
-         0.0451,  0.0000,  0.2322, -0.0486,  0.1043,  0.0000, -0.1990, -1.2431,
-         0.0000, -0.0620,  0.0000,  0.0000,  0.0000,  0.0000,  0.4195,  0.5787,
-        -0.0885,  0.0000, -0.2660,  0.0000,  0.0000, -0.3306,  0.0000,  0.0837,
-         0.0000,  0.0000,  0.3588,  0.1836, -0.1480,  0.1290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8625,  0.0000,  0.0000, -0.0032,  0.0298, -0.2294],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7907e-02, -1.6071e-03,  1.5507e-02,  2.3172e-01,  2.5004e+00,
-        -8.6071e-02,  9.7279e-13, -7.6813e-06,  4.0515e-02,  6.5664e-02,
-         1.1892e-01,  1.7488e-04, -3.9326e-02, -2.5496e-01, -1.7445e-09,
-        -1.2966e-06, -1.2754e-07, -1.5872e-01, -2.9498e-10,  3.5467e-06,
-        -2.8134e-01, -3.5351e-02, -6.2644e-10, -6.2779e-01, -2.8667e-02,
-         5.0842e-06,  2.7250e-01, -8.4589e-02, -8.3811e-03,  9.1106e-10,
-        -2.4681e-01, -1.2432e+00, -2.3321e-13, -8.6443e-02,  1.1576e-07,
-         6.2458e-06,  3.8614e-09,  0.0000e+00,  3.3571e-01,  5.5891e-01,
-        -1.5798e-01, -5.1575e-07, -2.5925e-01,  9.5031e-10,  2.3244e-07,
-        -2.3895e-01, -3.2796e-05,  7.8146e-02,  1.1033e-04, -6.0499e-11,
-         3.3324e-01,  1.1871e-01, -1.2166e-01, -1.0950e-01, -2.2170e-02,
-         1.2591e-14, -4.9484e-09,  1.7796e-12,  8.8481e-01, -2.5702e-04,
-         4.7538e-12, -4.0737e-02,  2.1610e-03, -1.9896e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.7907e-02,  0.0000e+00,  1.5507e-02,  2.3172e-01,  2.5004e+00,
-        -8.6071e-02,  0.0000e+00,  0.0000e+00,  4.0515e-02,  6.5664e-02,
-         1.1892e-01,  0.0000e+00, -3.9326e-02, -2.5496e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.5872e-01,  0.0000e+00,  0.0000e+00,
-        -2.8134e-01, -3.5351e-02,  0.0000e+00, -6.2779e-01, -2.8667e-02,
-         0.0000e+00,  2.7250e-01, -8.4589e-02, -8.3811e-03,  0.0000e+00,
-        -2.4681e-01, -1.2432e+00,  0.0000e+00, -8.6443e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.3571e-01,  5.5891e-01,
-        -1.5798e-01,  0.0000e+00, -2.5925e-01,  0.0000e+00,  0.0000e+00,
-        -2.3895e-01,  0.0000e+00,  7.8146e-02,  0.0000e+00,  0.0000e+00,
-         3.3324e-01,  1.1871e-01, -1.2166e-01, -1.0950e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.8481e-01,  0.0000e+00,
-         0.0000e+00, -4.0737e-02,  2.1610e-03, -1.9896e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.7907e-02,  0.0000e+00,  1.5507e-02,  2.3172e-01,  2.5004e+00,
-        -8.6071e-02,  0.0000e+00,  0.0000e+00,  4.0515e-02,  6.5664e-02,
-         1.1892e-01,  0.0000e+00, -3.9326e-02, -2.5496e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.5872e-01,  0.0000e+00,  0.0000e+00,
-        -2.8134e-01, -3.5351e-02,  0.0000e+00, -6.2779e-01, -2.8667e-02,
-         0.0000e+00,  2.7250e-01, -8.4589e-02, -8.3811e-03,  0.0000e+00,
-        -2.4681e-01, -1.2432e+00,  0.0000e+00, -8.6443e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.3571e-01,  5.5891e-01,
-        -1.5798e-01,  0.0000e+00, -2.5925e-01,  0.0000e+00,  0.0000e+00,
-        -2.3895e-01,  0.0000e+00,  7.8146e-02,  0.0000e+00,  0.0000e+00,
-         3.3324e-01,  1.1871e-01, -1.2166e-01, -1.0950e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.8481e-01,  0.0000e+00,
-         0.0000e+00, -4.0737e-02,  2.1610e-03, -1.9896e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.0087e-02, -1.3952e-03,  5.3678e-03,  2.3880e-01,  2.4956e+00,
-        -4.2729e-02,  8.4453e-13, -6.6686e-06, -2.4407e-02,  5.9434e-02,
-         9.4507e-02,  1.5182e-04, -3.4909e-02, -1.9836e-01, -1.5145e-09,
-        -1.1256e-06, -1.1072e-07, -2.6628e-01, -2.5609e-10,  3.0791e-06,
-        -3.1666e-01, -8.9567e-02, -5.4385e-10, -6.3856e-01, -7.1245e-02,
-         4.4139e-06,  3.1759e-01, -8.4985e-02, -7.6956e-02,  7.9095e-10,
-        -3.0108e-01, -1.2420e+00, -2.0247e-13, -9.3832e-02,  1.0049e-07,
-         5.4223e-06,  3.3523e-09,  0.0000e+00,  3.0745e-01,  5.7250e-01,
-        -1.8593e-01, -4.4775e-07, -2.5160e-01,  8.2502e-10,  2.0179e-07,
-        -1.8650e-01, -2.8472e-05,  9.5961e-02,  9.5781e-05, -5.2523e-11,
-         3.3479e-01,  4.7218e-02, -1.4074e-01, -2.1112e-01, -1.9247e-02,
-         1.0931e-14, -4.2960e-09,  1.5450e-12,  8.9971e-01, -2.2313e-04,
-         4.1271e-12, -9.7421e-02, -2.3452e-03, -1.8672e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.0087e-02,  0.0000e+00,  5.3678e-03,  2.3880e-01,  2.4956e+00,
-        -4.2729e-02,  0.0000e+00,  0.0000e+00, -2.4407e-02,  5.9434e-02,
-         9.4507e-02,  0.0000e+00, -3.4909e-02, -1.9836e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6628e-01,  0.0000e+00,  0.0000e+00,
-        -3.1666e-01, -8.9567e-02,  0.0000e+00, -6.3856e-01, -7.1245e-02,
-         0.0000e+00,  3.1759e-01, -8.4985e-02, -7.6956e-02,  0.0000e+00,
-        -3.0108e-01, -1.2420e+00,  0.0000e+00, -9.3832e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0745e-01,  5.7250e-01,
-        -1.8593e-01,  0.0000e+00, -2.5160e-01,  0.0000e+00,  0.0000e+00,
-        -1.8650e-01,  0.0000e+00,  9.5961e-02,  0.0000e+00,  0.0000e+00,
-         3.3479e-01,  4.7218e-02, -1.4074e-01, -2.1112e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.9971e-01,  0.0000e+00,
-         0.0000e+00, -9.7421e-02, -2.3452e-03, -1.8672e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.0087e-02,  0.0000e+00,  5.3678e-03,  2.3880e-01,  2.4956e+00,
-        -4.2729e-02,  0.0000e+00,  0.0000e+00, -2.4407e-02,  5.9434e-02,
-         9.4507e-02,  0.0000e+00, -3.4909e-02, -1.9836e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6628e-01,  0.0000e+00,  0.0000e+00,
-        -3.1666e-01, -8.9567e-02,  0.0000e+00, -6.3856e-01, -7.1245e-02,
-         0.0000e+00,  3.1759e-01, -8.4985e-02, -7.6956e-02,  0.0000e+00,
-        -3.0108e-01, -1.2420e+00,  0.0000e+00, -9.3832e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0745e-01,  5.7250e-01,
-        -1.8593e-01,  0.0000e+00, -2.5160e-01,  0.0000e+00,  0.0000e+00,
-        -1.8650e-01,  0.0000e+00,  9.5961e-02,  0.0000e+00,  0.0000e+00,
-         3.3479e-01,  4.7218e-02, -1.4074e-01, -2.1112e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.9971e-01,  0.0000e+00,
-         0.0000e+00, -9.7421e-02, -2.3452e-03, -1.8672e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6361e-01, -1.2117e-03,  4.4280e-03,  2.6936e-01,  2.4925e+00,
-        -4.0941e-02,  7.3345e-13, -5.7915e-06, -1.2273e-01,  9.7050e-02,
-         7.5004e-02,  1.3185e-04, -5.2414e-02, -2.0891e-01, -1.3153e-09,
-        -9.7757e-07, -9.6159e-08, -1.8499e-01, -2.2241e-10,  2.6741e-06,
-        -3.5158e-01,  1.7810e-02, -4.7232e-10, -7.0881e-01,  6.8123e-03,
-         3.8334e-06,  3.4221e-01, -1.6662e-02, -7.0062e-02,  6.8692e-10,
-        -3.3963e-01, -1.2434e+00, -1.7584e-13, -5.0052e-02,  8.7277e-08,
-         4.7091e-06,  2.9114e-09,  0.0000e+00,  3.0802e-01,  6.0847e-01,
-        -7.5855e-02, -3.8886e-07, -2.8159e-01,  7.1651e-10,  1.7525e-07,
-        -2.3231e-01, -2.4727e-05,  7.3876e-02,  8.3183e-05, -4.5615e-11,
-         3.8240e-01,  2.3471e-02, -2.4097e-01,  1.4952e-02, -1.6716e-02,
-         9.4931e-15, -3.7310e-09,  1.3418e-12,  8.8883e-01, -1.9378e-04,
-         3.5842e-12, -1.5778e-01,  1.8561e-03, -2.1321e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.6361e-01,  0.0000e+00,  4.4280e-03,  2.6936e-01,  2.4925e+00,
-        -4.0941e-02,  0.0000e+00,  0.0000e+00, -1.2273e-01,  9.7050e-02,
-         7.5004e-02,  0.0000e+00, -5.2414e-02, -2.0891e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.8499e-01,  0.0000e+00,  0.0000e+00,
-        -3.5158e-01,  1.7810e-02,  0.0000e+00, -7.0881e-01,  6.8123e-03,
-         0.0000e+00,  3.4221e-01, -1.6662e-02, -7.0062e-02,  0.0000e+00,
-        -3.3963e-01, -1.2434e+00,  0.0000e+00, -5.0052e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0802e-01,  6.0847e-01,
-        -7.5855e-02,  0.0000e+00, -2.8159e-01,  0.0000e+00,  0.0000e+00,
-        -2.3231e-01,  0.0000e+00,  7.3876e-02,  0.0000e+00,  0.0000e+00,
-         3.8240e-01,  2.3471e-02, -2.4097e-01,  1.4952e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.8883e-01,  0.0000e+00,
-         0.0000e+00, -1.5778e-01,  1.8561e-03, -2.1321e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.6361e-01,  0.0000e+00,  4.4280e-03,  2.6936e-01,  2.4925e+00,
-        -4.0941e-02,  0.0000e+00,  0.0000e+00, -1.2273e-01,  9.7050e-02,
-         7.5004e-02,  0.0000e+00, -5.2414e-02, -2.0891e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.8499e-01,  0.0000e+00,  0.0000e+00,
-        -3.5158e-01,  1.7810e-02,  0.0000e+00, -7.0881e-01,  6.8123e-03,
-         0.0000e+00,  3.4221e-01, -1.6662e-02, -7.0062e-02,  0.0000e+00,
-        -3.3963e-01, -1.2434e+00,  0.0000e+00, -5.0052e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0802e-01,  6.0847e-01,
-        -7.5855e-02,  0.0000e+00, -2.8159e-01,  0.0000e+00,  0.0000e+00,
-        -2.3231e-01,  0.0000e+00,  7.3876e-02,  0.0000e+00,  0.0000e+00,
-         3.8240e-01,  2.3471e-02, -2.4097e-01,  1.4952e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  8.8883e-01,  0.0000e+00,
-         0.0000e+00, -1.5778e-01,  1.8561e-03, -2.1321e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4381e-01, -1.0527e-03, -2.8344e-02,  3.0493e-01,  2.4914e+00,
-        -7.1854e-02,  6.3721e-13, -5.0315e-06, -1.9979e-01,  1.6784e-01,
-         7.2574e-02,  1.1455e-04, -1.0102e-01, -1.8049e-01, -1.1427e-09,
-        -8.4930e-07, -8.3542e-08, -8.1820e-02, -1.9323e-10,  2.3232e-06,
-        -3.6335e-01,  1.4870e-01, -4.1034e-10, -7.4969e-01,  1.2738e-01,
-         3.3304e-06,  3.6242e-01,  7.8344e-02, -5.1309e-02,  5.9678e-10,
-        -4.0609e-01, -1.2366e+00, -1.5276e-13,  1.3363e-02,  7.5825e-08,
-         4.0912e-06,  2.5293e-09,  0.0000e+00,  2.9638e-01,  6.3573e-01,
-         6.0021e-02, -3.3784e-07, -3.2784e-01,  6.2249e-10,  1.5226e-07,
-        -2.7983e-01, -2.1483e-05,  2.5683e-02,  7.2268e-05, -3.9629e-11,
-         4.2440e-01,  5.5823e-02, -3.3879e-01,  1.5079e-01, -1.4522e-02,
-         8.2475e-15, -3.2414e-09,  1.1657e-12,  8.7495e-01, -1.6836e-04,
-         3.1139e-12, -2.6324e-01,  6.6831e-03, -2.2706e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2438,  0.0000, -0.0283,  0.3049,  2.4914, -0.0719,  0.0000,  0.0000,
-        -0.1998,  0.1678,  0.0726,  0.0000, -0.1010, -0.1805,  0.0000,  0.0000,
-         0.0000, -0.0818,  0.0000,  0.0000, -0.3633,  0.1487,  0.0000, -0.7497,
-         0.1274,  0.0000,  0.3624,  0.0783, -0.0513,  0.0000, -0.4061, -1.2366,
-         0.0000,  0.0134,  0.0000,  0.0000,  0.0000,  0.0000,  0.2964,  0.6357,
-         0.0600,  0.0000, -0.3278,  0.0000,  0.0000, -0.2798,  0.0000,  0.0257,
-         0.0000,  0.0000,  0.4244,  0.0558, -0.3388,  0.1508,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8749,  0.0000,  0.0000, -0.2632,  0.0067, -0.2271],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2438,  0.0000, -0.0283,  0.3049,  2.4914, -0.0719,  0.0000,  0.0000,
-        -0.1998,  0.1678,  0.0726,  0.0000, -0.1010, -0.1805,  0.0000,  0.0000,
-         0.0000, -0.0818,  0.0000,  0.0000, -0.3633,  0.1487,  0.0000, -0.7497,
-         0.1274,  0.0000,  0.3624,  0.0783, -0.0513,  0.0000, -0.4061, -1.2366,
-         0.0000,  0.0134,  0.0000,  0.0000,  0.0000,  0.0000,  0.2964,  0.6357,
-         0.0600,  0.0000, -0.3278,  0.0000,  0.0000, -0.2798,  0.0000,  0.0257,
-         0.0000,  0.0000,  0.4244,  0.0558, -0.3388,  0.1508,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8749,  0.0000,  0.0000, -0.2632,  0.0067, -0.2271],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0691e-01, -9.1490e-04, -1.1710e-02,  3.3378e-01,  2.4903e+00,
-        -8.3662e-02,  5.5381e-13, -4.3729e-06, -2.4844e-01,  2.1124e-01,
-         1.5990e-02,  9.9559e-05, -8.3334e-02, -1.2504e-01, -9.9313e-10,
-        -7.3813e-07, -7.2607e-08,  1.6695e-02, -1.6793e-10,  2.0191e-06,
-        -3.4264e-01,  1.9348e-01, -3.5663e-10, -7.3740e-01,  1.8205e-01,
-         2.8944e-06,  3.6632e-01,  1.5813e-01, -9.3960e-03,  5.1867e-10,
-        -5.0606e-01, -1.2223e+00, -1.3277e-13,  6.5983e-02,  6.5900e-08,
-         3.5557e-06,  2.1983e-09,  0.0000e+00,  3.6673e-01,  6.8220e-01,
-         7.8350e-02, -2.9361e-07, -3.3107e-01,  5.4101e-10,  1.3233e-07,
-        -2.9119e-01, -1.8671e-05,  3.8104e-02,  6.2809e-05, -3.4442e-11,
-         4.5595e-01,  1.1895e-01, -4.4283e-01,  2.2395e-01, -1.2621e-02,
-         7.1679e-15, -2.8171e-09,  1.0131e-12,  8.4282e-01, -1.4632e-04,
-         2.7063e-12, -3.7391e-01, -2.9510e-02, -2.2758e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3069,  0.0000, -0.0117,  0.3338,  2.4903, -0.0837,  0.0000,  0.0000,
-        -0.2484,  0.2112,  0.0160,  0.0000, -0.0833, -0.1250,  0.0000,  0.0000,
-         0.0000,  0.0167,  0.0000,  0.0000, -0.3426,  0.1935,  0.0000, -0.7374,
-         0.1820,  0.0000,  0.3663,  0.1581, -0.0094,  0.0000, -0.5061, -1.2223,
-         0.0000,  0.0660,  0.0000,  0.0000,  0.0000,  0.0000,  0.3667,  0.6822,
-         0.0784,  0.0000, -0.3311,  0.0000,  0.0000, -0.2912,  0.0000,  0.0381,
-         0.0000,  0.0000,  0.4559,  0.1190, -0.4428,  0.2239,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8428,  0.0000,  0.0000, -0.3739, -0.0295, -0.2276],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3069,  0.0000, -0.0117,  0.3338,  2.4903, -0.0837,  0.0000,  0.0000,
-        -0.2484,  0.2112,  0.0160,  0.0000, -0.0833, -0.1250,  0.0000,  0.0000,
-         0.0000,  0.0167,  0.0000,  0.0000, -0.3426,  0.1935,  0.0000, -0.7374,
-         0.1820,  0.0000,  0.3663,  0.1581, -0.0094,  0.0000, -0.5061, -1.2223,
-         0.0000,  0.0660,  0.0000,  0.0000,  0.0000,  0.0000,  0.3667,  0.6822,
-         0.0784,  0.0000, -0.3311,  0.0000,  0.0000, -0.2912,  0.0000,  0.0381,
-         0.0000,  0.0000,  0.4559,  0.1190, -0.4428,  0.2239,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.8428,  0.0000,  0.0000, -0.3739, -0.0295, -0.2276],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6450e-01, -7.9544e-04,  2.3147e-02,  3.4645e-01,  2.4844e+00,
-        -7.6422e-02,  4.8149e-13, -3.8019e-06, -2.5304e-01,  2.7517e-01,
-        -3.6314e-02,  8.6559e-05, -2.4694e-02, -8.3301e-02, -8.6345e-10,
-        -6.4175e-07, -6.3126e-08,  1.1796e-01, -1.4601e-10,  1.7555e-06,
-        -3.0096e-01,  1.8460e-01, -3.1006e-10, -6.5703e-01,  1.6038e-01,
-         2.5165e-06,  3.6726e-01,  2.3145e-01,  1.0745e-02,  4.5094e-10,
-        -6.5267e-01, -1.2056e+00, -1.1543e-13,  8.6846e-02,  5.7295e-08,
-         3.0914e-06,  1.9112e-09,  0.0000e+00,  4.1689e-01,  7.3506e-01,
-         5.2292e-02, -2.5527e-07, -3.1394e-01,  4.7037e-10,  1.1505e-07,
-        -2.6739e-01, -1.6233e-05,  6.3986e-02,  5.4607e-05, -2.9945e-11,
-         4.7691e-01,  1.7959e-01, -5.0076e-01,  2.1365e-01, -1.0973e-02,
-         6.2320e-15, -2.4493e-09,  8.8085e-13,  7.9530e-01, -1.2721e-04,
-         2.3529e-12, -4.6626e-01, -8.2155e-02, -2.2701e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3645,  0.0000,  0.0231,  0.3465,  2.4844, -0.0764,  0.0000,  0.0000,
-        -0.2530,  0.2752, -0.0363,  0.0000, -0.0247, -0.0833,  0.0000,  0.0000,
-         0.0000,  0.1180,  0.0000,  0.0000, -0.3010,  0.1846,  0.0000, -0.6570,
-         0.1604,  0.0000,  0.3673,  0.2315,  0.0107,  0.0000, -0.6527, -1.2056,
-         0.0000,  0.0868,  0.0000,  0.0000,  0.0000,  0.0000,  0.4169,  0.7351,
-         0.0523,  0.0000, -0.3139,  0.0000,  0.0000, -0.2674,  0.0000,  0.0640,
-         0.0000,  0.0000,  0.4769,  0.1796, -0.5008,  0.2137,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7953,  0.0000,  0.0000, -0.4663, -0.0822, -0.2270],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3645,  0.0000,  0.0231,  0.3465,  2.4844, -0.0764,  0.0000,  0.0000,
-        -0.2530,  0.2752, -0.0363,  0.0000, -0.0247, -0.0833,  0.0000,  0.0000,
-         0.0000,  0.1180,  0.0000,  0.0000, -0.3010,  0.1846,  0.0000, -0.6570,
-         0.1604,  0.0000,  0.3673,  0.2315,  0.0107,  0.0000, -0.6527, -1.2056,
-         0.0000,  0.0868,  0.0000,  0.0000,  0.0000,  0.0000,  0.4169,  0.7351,
-         0.0523,  0.0000, -0.3139,  0.0000,  0.0000, -0.2674,  0.0000,  0.0640,
-         0.0000,  0.0000,  0.4769,  0.1796, -0.5008,  0.2137,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7953,  0.0000,  0.0000, -0.4663, -0.0822, -0.2270],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7323e-01, -6.9182e-04,  1.8806e-02,  3.3616e-01,  2.4777e+00,
-        -4.9277e-02,  4.1877e-13, -3.3067e-06, -2.5814e-01,  3.6309e-01,
-        -1.2003e-02,  7.5284e-05, -1.8831e-02, -6.4794e-02, -7.5098e-10,
-        -5.5815e-07, -5.4903e-08,  1.8021e-01, -1.2699e-10,  1.5268e-06,
-        -2.5448e-01,  1.6603e-01, -2.6967e-10, -5.9410e-01,  1.3207e-01,
-         2.1887e-06,  3.7357e-01,  2.8971e-01,  1.7733e-02,  3.9220e-10,
-        -7.7564e-01, -1.1958e+00, -1.0040e-13,  1.0530e-01,  4.9832e-08,
-         2.6887e-06,  1.6623e-09,  0.0000e+00,  4.4953e-01,  7.5786e-01,
-         4.1089e-02, -2.2202e-07, -3.1670e-01,  4.0910e-10,  1.0006e-07,
-        -2.0051e-01, -1.4118e-05,  9.4722e-02,  4.7494e-05, -2.6044e-11,
-         4.9095e-01,  1.9697e-01, -5.0386e-01,  1.3463e-01, -9.5439e-03,
-         5.4202e-15, -2.1302e-09,  7.6611e-13,  7.3903e-01, -1.1064e-04,
-         2.0465e-12, -5.3985e-01, -1.0053e-01, -2.0777e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3732,  0.0000,  0.0188,  0.3362,  2.4777, -0.0493,  0.0000,  0.0000,
-        -0.2581,  0.3631, -0.0120,  0.0000, -0.0188, -0.0648,  0.0000,  0.0000,
-         0.0000,  0.1802,  0.0000,  0.0000, -0.2545,  0.1660,  0.0000, -0.5941,
-         0.1321,  0.0000,  0.3736,  0.2897,  0.0177,  0.0000, -0.7756, -1.1958,
-         0.0000,  0.1053,  0.0000,  0.0000,  0.0000,  0.0000,  0.4495,  0.7579,
-         0.0411,  0.0000, -0.3167,  0.0000,  0.0000, -0.2005,  0.0000,  0.0947,
-         0.0000,  0.0000,  0.4910,  0.1970, -0.5039,  0.1346,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7390,  0.0000,  0.0000, -0.5398, -0.1005, -0.2078],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3732,  0.0000,  0.0188,  0.3362,  2.4777, -0.0493,  0.0000,  0.0000,
-        -0.2581,  0.3631, -0.0120,  0.0000, -0.0188, -0.0648,  0.0000,  0.0000,
-         0.0000,  0.1802,  0.0000,  0.0000, -0.2545,  0.1660,  0.0000, -0.5941,
-         0.1321,  0.0000,  0.3736,  0.2897,  0.0177,  0.0000, -0.7756, -1.1958,
-         0.0000,  0.1053,  0.0000,  0.0000,  0.0000,  0.0000,  0.4495,  0.7579,
-         0.0411,  0.0000, -0.3167,  0.0000,  0.0000, -0.2005,  0.0000,  0.0947,
-         0.0000,  0.0000,  0.4910,  0.1970, -0.5039,  0.1346,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7390,  0.0000,  0.0000, -0.5398, -0.1005, -0.2078],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7560e-01, -6.0193e-04,  6.4413e-02,  3.1489e-01,  2.4705e+00,
-         1.2096e-02,  3.6436e-13, -2.8770e-06, -2.7910e-01,  4.3189e-01,
-         5.3175e-02,  6.5502e-05,  1.1170e-02, -8.9337e-02, -6.5340e-10,
-        -4.8563e-07, -4.7769e-08,  1.9462e-01, -1.1049e-10,  1.3284e-06,
-        -2.4614e-01,  1.0148e-01, -2.3463e-10, -5.6785e-01,  5.8958e-02,
-         1.9043e-06,  3.8619e-01,  2.9406e-01, -1.0002e-02,  3.4124e-10,
-        -8.4260e-01, -1.1884e+00, -8.7351e-14,  1.0299e-01,  4.3357e-08,
-         2.3394e-06,  1.4463e-09,  0.0000e+00,  4.8192e-01,  7.7157e-01,
-        -3.3535e-02, -1.9317e-07, -2.9051e-01,  3.5594e-10,  8.7060e-08,
-        -8.0814e-02, -1.2284e-05,  1.5022e-01,  4.1323e-05, -2.2660e-11,
-         4.9823e-01,  1.1855e-01, -4.8418e-01,  2.3343e-02, -8.3038e-03,
-         4.7159e-15, -1.8534e-09,  6.6657e-13,  6.8350e-01, -9.6266e-05,
-         1.7806e-12, -5.7469e-01, -1.0293e-01, -1.5016e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3756,  0.0000,  0.0644,  0.3149,  2.4705,  0.0121,  0.0000,  0.0000,
-        -0.2791,  0.4319,  0.0532,  0.0000,  0.0112, -0.0893,  0.0000,  0.0000,
-         0.0000,  0.1946,  0.0000,  0.0000, -0.2461,  0.1015,  0.0000, -0.5679,
-         0.0590,  0.0000,  0.3862,  0.2941, -0.0100,  0.0000, -0.8426, -1.1884,
-         0.0000,  0.1030,  0.0000,  0.0000,  0.0000,  0.0000,  0.4819,  0.7716,
-        -0.0335,  0.0000, -0.2905,  0.0000,  0.0000, -0.0808,  0.0000,  0.1502,
-         0.0000,  0.0000,  0.4982,  0.1185, -0.4842,  0.0233,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6835,  0.0000,  0.0000, -0.5747, -0.1029, -0.1502],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3756,  0.0000,  0.0644,  0.3149,  2.4705,  0.0121,  0.0000,  0.0000,
-        -0.2791,  0.4319,  0.0532,  0.0000,  0.0112, -0.0893,  0.0000,  0.0000,
-         0.0000,  0.1946,  0.0000,  0.0000, -0.2461,  0.1015,  0.0000, -0.5679,
-         0.0590,  0.0000,  0.3862,  0.2941, -0.0100,  0.0000, -0.8426, -1.1884,
-         0.0000,  0.1030,  0.0000,  0.0000,  0.0000,  0.0000,  0.4819,  0.7716,
-        -0.0335,  0.0000, -0.2905,  0.0000,  0.0000, -0.0808,  0.0000,  0.1502,
-         0.0000,  0.0000,  0.4982,  0.1185, -0.4842,  0.0233,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6835,  0.0000,  0.0000, -0.5747, -0.1029, -0.1502],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5162e-01, -5.2391e-04,  7.1525e-02,  3.3173e-01,  2.4652e+00,
-         3.3652e-02,  3.1713e-13, -2.5041e-06, -3.1245e-01,  4.3749e-01,
-         6.1354e-02,  5.7012e-05, -3.0386e-02, -1.5073e-01, -5.6871e-10,
-        -4.2269e-07, -4.1578e-08,  1.7622e-01, -9.6166e-11,  1.1562e-06,
-        -2.3195e-01,  5.4875e-02, -2.0422e-10, -5.7149e-01,  1.5336e-02,
-         1.6575e-06,  3.7587e-01,  2.8823e-01, -4.3335e-02,  2.9701e-10,
-        -8.7417e-01, -1.1849e+00, -7.6029e-14,  1.2475e-01,  3.7737e-08,
-         2.0362e-06,  1.2588e-09,  0.0000e+00,  4.6752e-01,  7.6607e-01,
-        -9.5252e-02, -1.6814e-07, -2.6439e-01,  3.0981e-10,  7.5776e-08,
-         4.6638e-02, -1.0692e-05,  1.8988e-01,  3.5967e-05, -1.9723e-11,
-         4.5973e-01,  2.6329e-02, -4.6245e-01, -1.2783e-02, -7.2275e-03,
-         4.1047e-15, -1.6132e-09,  5.8017e-13,  6.5549e-01, -8.3789e-05,
-         1.5498e-12, -5.8915e-01, -1.2263e-01, -5.0584e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3516,  0.0000,  0.0715,  0.3317,  2.4652,  0.0337,  0.0000,  0.0000,
-        -0.3124,  0.4375,  0.0614,  0.0000, -0.0304, -0.1507,  0.0000,  0.0000,
-         0.0000,  0.1762,  0.0000,  0.0000, -0.2319,  0.0549,  0.0000, -0.5715,
-         0.0153,  0.0000,  0.3759,  0.2882, -0.0433,  0.0000, -0.8742, -1.1849,
-         0.0000,  0.1247,  0.0000,  0.0000,  0.0000,  0.0000,  0.4675,  0.7661,
-        -0.0953,  0.0000, -0.2644,  0.0000,  0.0000,  0.0466,  0.0000,  0.1899,
-         0.0000,  0.0000,  0.4597,  0.0263, -0.4624, -0.0128,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6555,  0.0000,  0.0000, -0.5891, -0.1226, -0.0506],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3516,  0.0000,  0.0715,  0.3317,  2.4652,  0.0337,  0.0000,  0.0000,
-        -0.3124,  0.4375,  0.0614,  0.0000, -0.0304, -0.1507,  0.0000,  0.0000,
-         0.0000,  0.1762,  0.0000,  0.0000, -0.2319,  0.0549,  0.0000, -0.5715,
-         0.0153,  0.0000,  0.3759,  0.2882, -0.0433,  0.0000, -0.8742, -1.1849,
-         0.0000,  0.1247,  0.0000,  0.0000,  0.0000,  0.0000,  0.4675,  0.7661,
-        -0.0953,  0.0000, -0.2644,  0.0000,  0.0000,  0.0466,  0.0000,  0.1899,
-         0.0000,  0.0000,  0.4597,  0.0263, -0.4624, -0.0128,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6555,  0.0000,  0.0000, -0.5891, -0.1226, -0.0506],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2621e-01, -4.5618e-04,  4.4235e-02,  3.6739e-01,  2.4605e+00,
-         4.0816e-02,  2.7613e-13, -2.1804e-06, -3.3749e-01,  4.0230e-01,
-        -4.3735e-02,  4.9641e-05, -1.0444e-01, -1.7251e-01, -4.9519e-10,
-        -3.6804e-07, -3.6202e-08,  1.7472e-01, -8.3733e-11,  1.0068e-06,
-        -1.9741e-01,  4.6291e-02, -1.7782e-10, -5.8698e-01,  1.6987e-02,
-         1.4432e-06,  3.4957e-01,  2.7991e-01, -6.7149e-02,  2.5861e-10,
-        -8.6671e-01, -1.1834e+00, -6.6199e-14,  1.4602e-01,  3.2858e-08,
-         1.7729e-06,  1.0961e-09,  0.0000e+00,  4.6088e-01,  7.6480e-01,
-        -1.4547e-01, -1.4640e-07, -2.6508e-01,  2.6975e-10,  6.5979e-08,
-         1.1161e-01, -9.3095e-06,  2.2269e-01,  3.1317e-05, -1.7173e-11,
-         4.0807e-01,  3.5387e-03, -4.4618e-01,  4.6432e-03, -6.2931e-03,
-         3.5740e-15, -1.4046e-09,  5.0516e-13,  6.2427e-01, -7.2956e-05,
-         1.3494e-12, -5.9692e-01, -1.1404e-01,  2.8934e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3262,  0.0000,  0.0442,  0.3674,  2.4605,  0.0408,  0.0000,  0.0000,
-        -0.3375,  0.4023, -0.0437,  0.0000, -0.1044, -0.1725,  0.0000,  0.0000,
-         0.0000,  0.1747,  0.0000,  0.0000, -0.1974,  0.0463,  0.0000, -0.5870,
-         0.0170,  0.0000,  0.3496,  0.2799, -0.0671,  0.0000, -0.8667, -1.1834,
-         0.0000,  0.1460,  0.0000,  0.0000,  0.0000,  0.0000,  0.4609,  0.7648,
-        -0.1455,  0.0000, -0.2651,  0.0000,  0.0000,  0.1116,  0.0000,  0.2227,
-         0.0000,  0.0000,  0.4081,  0.0035, -0.4462,  0.0046,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6243,  0.0000,  0.0000, -0.5969, -0.1140,  0.0289],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3262,  0.0000,  0.0442,  0.3674,  2.4605,  0.0408,  0.0000,  0.0000,
-        -0.3375,  0.4023, -0.0437,  0.0000, -0.1044, -0.1725,  0.0000,  0.0000,
-         0.0000,  0.1747,  0.0000,  0.0000, -0.1974,  0.0463,  0.0000, -0.5870,
-         0.0170,  0.0000,  0.3496,  0.2799, -0.0671,  0.0000, -0.8667, -1.1834,
-         0.0000,  0.1460,  0.0000,  0.0000,  0.0000,  0.0000,  0.4609,  0.7648,
-        -0.1455,  0.0000, -0.2651,  0.0000,  0.0000,  0.1116,  0.0000,  0.2227,
-         0.0000,  0.0000,  0.4081,  0.0035, -0.4462,  0.0046,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6243,  0.0000,  0.0000, -0.5969, -0.1140,  0.0289],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1206e-01, -3.9735e-04,  1.7211e-02,  4.2150e-01,  2.4545e+00,
-         2.7120e-02,  2.4052e-13, -1.8992e-06, -3.6543e-01,  3.5268e-01,
-        -1.3115e-01,  4.3239e-05, -1.4794e-01, -1.8585e-01, -4.3133e-10,
-        -3.2058e-07, -3.1534e-08,  1.9539e-01, -7.2935e-11,  8.7692e-07,
-        -1.4195e-01,  5.3537e-02, -1.5489e-10, -5.6023e-01,  7.0294e-02,
-         1.2571e-06,  2.9595e-01,  2.8069e-01, -7.8406e-02,  2.2526e-10,
-        -8.4769e-01, -1.1783e+00, -5.7662e-14,  1.6793e-01,  2.8621e-08,
-         1.5443e-06,  9.5472e-10,  0.0000e+00,  4.3764e-01,  7.5494e-01,
-        -1.7561e-01, -1.2752e-07, -2.6375e-01,  2.3497e-10,  5.7470e-08,
-         1.0425e-01, -8.1089e-06,  2.0564e-01,  2.7278e-05, -1.4958e-11,
-         3.2348e-01,  5.8941e-02, -4.5283e-01,  5.7191e-03, -5.4815e-03,
-         3.1131e-15, -1.2235e-09,  4.4002e-13,  5.8036e-01, -6.3548e-05,
-         1.1754e-12, -6.0392e-01, -1.5750e-01,  1.0564e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3121,  0.0000,  0.0172,  0.4215,  2.4545,  0.0271,  0.0000,  0.0000,
-        -0.3654,  0.3527, -0.1312,  0.0000, -0.1479, -0.1858,  0.0000,  0.0000,
-         0.0000,  0.1954,  0.0000,  0.0000, -0.1419,  0.0535,  0.0000, -0.5602,
-         0.0703,  0.0000,  0.2959,  0.2807, -0.0784,  0.0000, -0.8477, -1.1783,
-         0.0000,  0.1679,  0.0000,  0.0000,  0.0000,  0.0000,  0.4376,  0.7549,
-        -0.1756,  0.0000, -0.2637,  0.0000,  0.0000,  0.1042,  0.0000,  0.2056,
-         0.0000,  0.0000,  0.3235,  0.0589, -0.4528,  0.0057,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5804,  0.0000,  0.0000, -0.6039, -0.1575,  0.1056],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3121,  0.0000,  0.0172,  0.4215,  2.4545,  0.0271,  0.0000,  0.0000,
-        -0.3654,  0.3527, -0.1312,  0.0000, -0.1479, -0.1858,  0.0000,  0.0000,
-         0.0000,  0.1954,  0.0000,  0.0000, -0.1419,  0.0535,  0.0000, -0.5602,
-         0.0703,  0.0000,  0.2959,  0.2807, -0.0784,  0.0000, -0.8477, -1.1783,
-         0.0000,  0.1679,  0.0000,  0.0000,  0.0000,  0.0000,  0.4376,  0.7549,
-        -0.1756,  0.0000, -0.2637,  0.0000,  0.0000,  0.1042,  0.0000,  0.2056,
-         0.0000,  0.0000,  0.3235,  0.0589, -0.4528,  0.0057,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5804,  0.0000,  0.0000, -0.6039, -0.1575,  0.1056],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8278e-01, -3.4624e-04,  1.7099e-02,  4.3274e-01,  2.4493e+00,
-         1.3791e-02,  2.0958e-13, -1.6549e-06, -3.8574e-01,  3.0898e-01,
-        -1.3869e-01,  3.7677e-05, -1.6834e-01, -2.1448e-01, -3.7584e-10,
-        -2.7934e-07, -2.7477e-08,  1.7465e-01, -6.3553e-11,  7.6412e-07,
-        -6.4627e-02,  5.6005e-02, -1.3496e-10, -5.4024e-01,  9.7093e-02,
-         1.0954e-06,  2.5339e-01,  2.7647e-01, -6.6673e-02,  1.9629e-10,
-        -8.4090e-01, -1.1693e+00, -5.0245e-14,  1.6247e-01,  2.4939e-08,
-         1.3456e-06,  8.3192e-10,  0.0000e+00,  4.5215e-01,  7.6420e-01,
-        -1.6367e-01, -1.1112e-07, -2.5733e-01,  2.0474e-10,  5.0078e-08,
-         8.7597e-02, -7.0658e-06,  2.0308e-01,  2.3769e-05, -1.3034e-11,
-         2.8131e-01,  8.3393e-02, -4.2288e-01, -2.4764e-02, -4.7764e-03,
-         2.7127e-15, -1.0661e-09,  3.8342e-13,  5.3834e-01, -5.5374e-05,
-         1.0242e-12, -6.0539e-01, -1.6750e-01,  2.0464e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2828,  0.0000,  0.0171,  0.4327,  2.4493,  0.0138,  0.0000,  0.0000,
-        -0.3857,  0.3090, -0.1387,  0.0000, -0.1683, -0.2145,  0.0000,  0.0000,
-         0.0000,  0.1746,  0.0000,  0.0000, -0.0646,  0.0560,  0.0000, -0.5402,
-         0.0971,  0.0000,  0.2534,  0.2765, -0.0667,  0.0000, -0.8409, -1.1693,
-         0.0000,  0.1625,  0.0000,  0.0000,  0.0000,  0.0000,  0.4521,  0.7642,
-        -0.1637,  0.0000, -0.2573,  0.0000,  0.0000,  0.0876,  0.0000,  0.2031,
-         0.0000,  0.0000,  0.0000,  0.0834, -0.4229, -0.0248,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5383,  0.0000,  0.0000, -0.6054, -0.1675,  0.2046],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2828,  0.0000,  0.0171,  0.4327,  2.4493,  0.0138,  0.0000,  0.0000,
-        -0.3857,  0.3090, -0.1387,  0.0000, -0.1683, -0.2145,  0.0000,  0.0000,
-         0.0000,  0.1746,  0.0000,  0.0000, -0.0646,  0.0560,  0.0000, -0.5402,
-         0.0971,  0.0000,  0.2534,  0.2765, -0.0667,  0.0000, -0.8409, -1.1693,
-         0.0000,  0.1625,  0.0000,  0.0000,  0.0000,  0.0000,  0.4521,  0.7642,
-        -0.1637,  0.0000, -0.2573,  0.0000,  0.0000,  0.0876,  0.0000,  0.2031,
-         0.0000,  0.0000,  0.0000,  0.0834, -0.4229, -0.0248,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5383,  0.0000,  0.0000, -0.6054, -0.1675,  0.2046],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.5784e-01, -3.0182e-04, -5.0099e-02,  4.1206e-01,  2.4444e+00,
-        -1.0291e-03,  1.8269e-13, -1.4426e-06, -4.0500e-01,  2.5124e-01,
-        -1.0460e-01,  3.2843e-05, -2.2089e-01, -2.0591e-01, -3.2762e-10,
-        -2.4350e-07, -2.3952e-08,  1.4153e-01, -5.5399e-11,  6.6608e-07,
-         6.7749e-02,  5.6480e-02, -1.1765e-10, -4.7403e-01,  1.6743e-01,
-         9.5484e-07,  1.6749e-01,  3.1232e-01, -1.4279e-02,  1.7110e-10,
-        -8.3694e-01, -1.1571e+00, -4.3799e-14,  1.6326e-01,  2.1740e-08,
-         1.1730e-06,  7.2518e-10,  0.0000e+00,  4.2368e-01,  7.7885e-01,
-        -7.8586e-02, -9.6860e-08, -2.7378e-01,  1.7847e-10,  4.3653e-08,
-         5.4021e-02, -6.1593e-06,  1.5157e-01,  2.0720e-05, -1.1362e-11,
-        -3.6762e-02,  1.4779e-01, -3.6994e-01, -1.5105e-01, -4.1636e-03,
-         2.3646e-15, -9.2934e-10,  3.3422e-13,  5.1101e-01, -4.8269e-05,
-         8.9279e-13, -6.4367e-01, -1.7654e-01,  2.7615e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.5784e-01,  0.0000e+00, -5.0099e-02,  4.1206e-01,  2.4444e+00,
-        -1.0291e-03,  0.0000e+00,  0.0000e+00, -4.0500e-01,  2.5124e-01,
-        -1.0460e-01,  0.0000e+00, -2.2089e-01, -2.0591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  1.4153e-01,  0.0000e+00,  0.0000e+00,
-         6.7749e-02,  5.6480e-02,  0.0000e+00, -4.7403e-01,  1.6743e-01,
-         0.0000e+00,  1.6749e-01,  3.1232e-01, -1.4279e-02,  0.0000e+00,
-        -8.3694e-01, -1.1571e+00,  0.0000e+00,  1.6326e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2368e-01,  7.7885e-01,
-        -7.8586e-02,  0.0000e+00, -2.7378e-01,  0.0000e+00,  0.0000e+00,
-         5.4021e-02,  0.0000e+00,  1.5157e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.4779e-01, -3.6994e-01, -1.5105e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.1101e-01,  0.0000e+00,
-         0.0000e+00, -6.4367e-01, -1.7654e-01,  2.7615e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.5784e-01,  0.0000e+00, -5.0099e-02,  4.1206e-01,  2.4444e+00,
-        -1.0291e-03,  0.0000e+00,  0.0000e+00, -4.0500e-01,  2.5124e-01,
-        -1.0460e-01,  0.0000e+00, -2.2089e-01, -2.0591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  1.4153e-01,  0.0000e+00,  0.0000e+00,
-         6.7749e-02,  5.6480e-02,  0.0000e+00, -4.7403e-01,  1.6743e-01,
-         0.0000e+00,  1.6749e-01,  3.1232e-01, -1.4279e-02,  0.0000e+00,
-        -8.3694e-01, -1.1571e+00,  0.0000e+00,  1.6326e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2368e-01,  7.7885e-01,
-        -7.8586e-02,  0.0000e+00, -2.7378e-01,  0.0000e+00,  0.0000e+00,
-         5.4021e-02,  0.0000e+00,  1.5157e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.4779e-01, -3.6994e-01, -1.5105e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.1101e-01,  0.0000e+00,
-         0.0000e+00, -6.4367e-01, -1.7654e-01,  2.7615e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7038e-01, -2.6319e-04, -6.6353e-02,  3.7236e-01,  2.4423e+00,
-         1.6916e-02,  1.5931e-13, -1.2580e-06, -3.9687e-01,  1.7376e-01,
-        -9.1570e-02,  2.8640e-05, -1.8875e-01, -2.0240e-01, -2.8570e-10,
-        -2.1234e-07, -2.0887e-08,  9.8325e-02, -4.8310e-11,  5.8085e-07,
-         1.3039e-01,  6.1700e-02, -1.0259e-10, -5.2420e-01,  2.1224e-01,
-         8.3265e-07,  7.1878e-02,  3.0130e-01,  4.8567e-02,  1.4921e-10,
-        -8.4034e-01, -1.1401e+00, -3.8194e-14,  1.5669e-01,  1.8958e-08,
-         1.0229e-06,  6.3238e-10,  0.0000e+00,  4.3200e-01,  8.1780e-01,
-         1.8279e-02, -8.4465e-08, -2.5922e-01,  1.5563e-10,  3.8066e-08,
-        -4.3714e-02, -5.3711e-06,  6.8732e-02,  1.8068e-05, -9.9080e-12,
-        -3.2058e-02,  1.4941e-01, -3.6091e-01, -1.7313e-01, -3.6308e-03,
-         2.0620e-15, -8.1041e-10,  2.9145e-13,  4.8550e-01, -4.2092e-05,
-         7.7854e-13, -6.4208e-01, -1.6581e-01,  2.6847e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2704,  0.0000, -0.0664,  0.3724,  2.4423,  0.0169,  0.0000,  0.0000,
-        -0.3969,  0.1738, -0.0916,  0.0000, -0.1887, -0.2024,  0.0000,  0.0000,
-         0.0000,  0.0983,  0.0000,  0.0000,  0.1304,  0.0617,  0.0000, -0.5242,
-         0.2122,  0.0000,  0.0719,  0.3013,  0.0486,  0.0000, -0.8403, -1.1401,
-         0.0000,  0.1567,  0.0000,  0.0000,  0.0000,  0.0000,  0.4320,  0.8178,
-         0.0183,  0.0000, -0.2592,  0.0000,  0.0000, -0.0437,  0.0000,  0.0687,
-         0.0000,  0.0000,  0.0000,  0.1494, -0.3609, -0.1731,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4855,  0.0000,  0.0000, -0.6421, -0.1658,  0.2685],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2704,  0.0000, -0.0664,  0.3724,  2.4423,  0.0169,  0.0000,  0.0000,
-        -0.3969,  0.1738, -0.0916,  0.0000, -0.1887, -0.2024,  0.0000,  0.0000,
-         0.0000,  0.0983,  0.0000,  0.0000,  0.1304,  0.0617,  0.0000, -0.5242,
-         0.2122,  0.0000,  0.0719,  0.3013,  0.0486,  0.0000, -0.8403, -1.1401,
-         0.0000,  0.1567,  0.0000,  0.0000,  0.0000,  0.0000,  0.4320,  0.8178,
-         0.0183,  0.0000, -0.2592,  0.0000,  0.0000, -0.0437,  0.0000,  0.0687,
-         0.0000,  0.0000,  0.0000,  0.1494, -0.3609, -0.1731,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4855,  0.0000,  0.0000, -0.6421, -0.1658,  0.2685],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8168e-01, -2.2960e-04, -2.7975e-02,  3.4407e-01,  2.4416e+00,
-         8.7118e-02,  1.3898e-13, -1.0974e-06, -3.5528e-01,  7.1206e-02,
-        -1.5023e-01,  2.4985e-05, -1.2449e-01, -2.3496e-01, -2.4923e-10,
-        -1.8524e-07, -1.8221e-08,  4.2841e-02, -4.2144e-11,  5.0671e-07,
-         1.2747e-01,  3.0238e-02, -8.9498e-11, -5.9550e-01,  1.9769e-01,
-         7.2637e-07,  1.9658e-02,  2.5702e-01,  1.1870e-01,  1.3016e-10,
-        -8.3229e-01, -1.1321e+00, -3.3319e-14,  1.2431e-01,  1.6538e-08,
-         8.9232e-07,  5.5166e-10,  0.0000e+00,  4.5291e-01,  8.6585e-01,
-         4.8801e-02, -7.3684e-08, -2.1440e-01,  1.3577e-10,  3.3208e-08,
-        -8.8934e-02, -4.6855e-06,  1.2802e-02,  1.5762e-05, -8.6434e-12,
-        -2.7966e-02,  8.2602e-02, -3.8278e-01, -4.6579e-02, -3.1674e-03,
-         1.7988e-15, -7.0697e-10,  2.5425e-13,  4.8864e-01, -3.6720e-05,
-         6.7917e-13, -5.7040e-01, -1.9821e-01,  2.0057e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2817,  0.0000, -0.0280,  0.3441,  2.4416,  0.0871,  0.0000,  0.0000,
-        -0.3553,  0.0712, -0.1502,  0.0000, -0.1245, -0.2350,  0.0000,  0.0000,
-         0.0000,  0.0428,  0.0000,  0.0000,  0.1275,  0.0302,  0.0000, -0.5955,
-         0.1977,  0.0000,  0.0197,  0.2570,  0.1187,  0.0000, -0.8323, -1.1321,
-         0.0000,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000,  0.4529,  0.8658,
-         0.0488,  0.0000, -0.2144,  0.0000,  0.0000, -0.0889,  0.0000,  0.0128,
-         0.0000,  0.0000,  0.0000,  0.0826, -0.3828, -0.0466,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4886,  0.0000,  0.0000, -0.5704, -0.1982,  0.2006],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2817,  0.0000, -0.0280,  0.3441,  2.4416,  0.0871,  0.0000,  0.0000,
-        -0.3553,  0.0712, -0.1502,  0.0000, -0.1245, -0.2350,  0.0000,  0.0000,
-         0.0000,  0.0428,  0.0000,  0.0000,  0.1275,  0.0302,  0.0000, -0.5955,
-         0.1977,  0.0000,  0.0197,  0.2570,  0.1187,  0.0000, -0.8323, -1.1321,
-         0.0000,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000,  0.4529,  0.8658,
-         0.0488,  0.0000, -0.2144,  0.0000,  0.0000, -0.0889,  0.0000,  0.0128,
-         0.0000,  0.0000,  0.0000,  0.0826, -0.3828, -0.0466,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4886,  0.0000,  0.0000, -0.5704, -0.1982,  0.2006],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8532e-01, -2.0037e-04,  4.2665e-02,  3.2617e-01,  2.4423e+00,
-         2.0871e-01,  1.2129e-13, -9.5770e-07, -3.0621e-01,  4.2552e-03,
-        -1.7522e-01,  2.1804e-05, -2.4579e-02, -2.6876e-01, -2.1750e-10,
-        -1.6166e-07, -1.5901e-08, -3.0743e-02, -3.6779e-11,  4.4220e-07,
-         7.4929e-02, -1.2829e-02, -7.8104e-11, -7.1196e-01,  1.7774e-01,
-         6.3390e-07,  7.0560e-04,  1.9939e-01,  1.9148e-01,  1.1359e-10,
-        -8.2765e-01, -1.1334e+00, -2.9077e-14,  6.0190e-02,  1.4433e-08,
-         7.7872e-07,  4.8144e-10,  0.0000e+00,  4.6752e-01,  8.9653e-01,
-         4.6748e-02, -6.4304e-08, -1.6300e-01,  1.1849e-10,  2.8980e-08,
-        -1.4117e-01, -4.0890e-06,  2.3222e-03,  1.3756e-05, -7.5430e-12,
-        -2.4406e-02, -6.7295e-03, -4.0895e-01,  1.1349e-01, -2.7642e-03,
-         1.5698e-15, -6.1697e-10,  2.2189e-13,  5.0067e-01, -3.2045e-05,
-         5.9271e-13, -5.0443e-01, -2.3119e-01,  1.0409e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.8532e-01,  0.0000e+00,  4.2665e-02,  3.2617e-01,  2.4423e+00,
-         2.0871e-01,  0.0000e+00,  0.0000e+00, -3.0621e-01,  4.2552e-03,
-        -1.7522e-01,  0.0000e+00, -2.4579e-02, -2.6876e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.0743e-02,  0.0000e+00,  0.0000e+00,
-         7.4929e-02, -1.2829e-02,  0.0000e+00, -7.1196e-01,  1.7774e-01,
-         0.0000e+00,  7.0560e-04,  1.9939e-01,  1.9148e-01,  0.0000e+00,
-        -8.2765e-01, -1.1334e+00,  0.0000e+00,  6.0190e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6752e-01,  8.9653e-01,
-         4.6748e-02,  0.0000e+00, -1.6300e-01,  0.0000e+00,  0.0000e+00,
-        -1.4117e-01,  0.0000e+00,  2.3222e-03,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -6.7295e-03, -4.0895e-01,  1.1349e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.0067e-01,  0.0000e+00,
-         0.0000e+00, -5.0443e-01, -2.3119e-01,  1.0409e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.8532e-01,  0.0000e+00,  4.2665e-02,  3.2617e-01,  2.4423e+00,
-         2.0871e-01,  0.0000e+00,  0.0000e+00, -3.0621e-01,  4.2552e-03,
-        -1.7522e-01,  0.0000e+00, -2.4579e-02, -2.6876e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.0743e-02,  0.0000e+00,  0.0000e+00,
-         7.4929e-02, -1.2829e-02,  0.0000e+00, -7.1196e-01,  1.7774e-01,
-         0.0000e+00,  7.0560e-04,  1.9939e-01,  1.9148e-01,  0.0000e+00,
-        -8.2765e-01, -1.1334e+00,  0.0000e+00,  6.0190e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6752e-01,  8.9653e-01,
-         4.6748e-02,  0.0000e+00, -1.6300e-01,  0.0000e+00,  0.0000e+00,
-        -1.4117e-01,  0.0000e+00,  2.3222e-03,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -6.7295e-03, -4.0895e-01,  1.1349e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.0067e-01,  0.0000e+00,
-         0.0000e+00, -5.0443e-01, -2.3119e-01,  1.0409e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6792e-01, -1.7493e-04,  6.4759e-02,  2.7333e-01,  2.4425e+00,
-         3.5665e-01,  1.0589e-13, -8.3611e-07, -2.5456e-01, -7.8516e-03,
-        -1.3802e-01,  1.9036e-05,  1.8049e-02, -3.1434e-01, -1.8989e-10,
-        -1.4113e-07, -1.3882e-08, -1.0430e-01, -3.2109e-11,  3.8606e-07,
-        -2.8989e-02, -1.2378e-02, -6.8188e-11, -7.8751e-01,  2.0740e-01,
-         5.5342e-07, -2.1021e-02,  1.2559e-01,  3.0989e-01,  9.9169e-11,
-        -8.3519e-01, -1.1378e+00, -2.5385e-14,  1.7936e-02,  1.2600e-08,
-         6.7985e-07,  4.2031e-10,  0.0000e+00,  4.3104e-01,  8.9997e-01,
-         1.2016e-01, -5.6139e-08, -1.3737e-01,  1.0344e-10,  2.5301e-08,
-        -1.6142e-01, -3.5699e-06, -3.2400e-02,  1.2009e-05, -6.5853e-12,
-        -2.1307e-02, -8.1226e-02, -4.0657e-01,  2.3315e-01, -2.4132e-03,
-         1.3705e-15, -5.3864e-10,  1.9371e-13,  5.0346e-01, -2.7976e-05,
-         5.1745e-13, -4.6304e-01, -2.6264e-01,  5.2330e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2679,  0.0000,  0.0648,  0.2733,  2.4425,  0.3567,  0.0000,  0.0000,
-        -0.2546, -0.0079, -0.1380,  0.0000,  0.0180, -0.3143,  0.0000,  0.0000,
-         0.0000, -0.1043,  0.0000,  0.0000, -0.0290, -0.0124,  0.0000, -0.7875,
-         0.2074,  0.0000, -0.0210,  0.1256,  0.3099,  0.0000, -0.8352, -1.1378,
-         0.0000,  0.0179,  0.0000,  0.0000,  0.0000,  0.0000,  0.4310,  0.9000,
-         0.1202,  0.0000, -0.1374,  0.0000,  0.0000, -0.1614,  0.0000, -0.0324,
-         0.0000,  0.0000,  0.0000, -0.0812, -0.4066,  0.2331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5035,  0.0000,  0.0000, -0.4630, -0.2626,  0.0523],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2679,  0.0000,  0.0648,  0.2733,  2.4425,  0.3567,  0.0000,  0.0000,
-        -0.2546, -0.0079, -0.1380,  0.0000,  0.0180, -0.3143,  0.0000,  0.0000,
-         0.0000, -0.1043,  0.0000,  0.0000, -0.0290, -0.0124,  0.0000, -0.7875,
-         0.2074,  0.0000, -0.0210,  0.1256,  0.3099,  0.0000, -0.8352, -1.1378,
-         0.0000,  0.0179,  0.0000,  0.0000,  0.0000,  0.0000,  0.4310,  0.9000,
-         0.1202,  0.0000, -0.1374,  0.0000,  0.0000, -0.1614,  0.0000, -0.0324,
-         0.0000,  0.0000,  0.0000, -0.0812, -0.4066,  0.2331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5035,  0.0000,  0.0000, -0.4630, -0.2626,  0.0523],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.2234e-01, -1.5278e-04,  1.1193e-02,  1.9859e-01,  2.4409e+00,
-         4.7460e-01,  9.2479e-14, -7.3023e-07, -1.9530e-01, -2.6922e-02,
-        -7.2796e-02,  1.6625e-05, -5.9594e-03, -3.4675e-01, -1.6584e-10,
-        -1.2326e-07, -1.2124e-08, -1.3939e-01, -2.8043e-11,  3.3717e-07,
-        -1.4151e-01,  2.9938e-02, -5.9553e-11, -8.3891e-01,  2.4572e-01,
-         4.8334e-07, -7.2800e-02,  4.7188e-02,  3.9825e-01,  8.6611e-11,
-        -8.4847e-01, -1.1413e+00, -2.2171e-14, -1.5381e-03,  1.1005e-08,
-         5.9376e-07,  3.6708e-10,  0.0000e+00,  3.8435e-01,  8.9743e-01,
-         2.3247e-01, -4.9030e-08, -1.4596e-01,  9.0343e-11,  2.2097e-08,
-        -1.7051e-01, -3.1178e-06, -4.4113e-02,  1.0488e-05, -5.7514e-12,
-        -1.8609e-02, -1.0802e-01, -3.7511e-01,  2.6346e-01, -2.1076e-03,
-         1.1970e-15, -4.7043e-10,  1.6918e-13,  5.2875e-01, -2.4434e-05,
-         4.5193e-13, -4.6779e-01, -2.4735e-01, -4.7853e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.2234e-01,  0.0000e+00,  1.1193e-02,  1.9859e-01,  2.4409e+00,
-         4.7460e-01,  0.0000e+00,  0.0000e+00, -1.9530e-01, -2.6922e-02,
-        -7.2796e-02,  0.0000e+00, -5.9594e-03, -3.4675e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.3939e-01,  0.0000e+00,  0.0000e+00,
-        -1.4151e-01,  2.9938e-02,  0.0000e+00, -8.3891e-01,  2.4572e-01,
-         0.0000e+00, -7.2800e-02,  4.7188e-02,  3.9825e-01,  0.0000e+00,
-        -8.4847e-01, -1.1413e+00,  0.0000e+00, -1.5381e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.8435e-01,  8.9743e-01,
-         2.3247e-01,  0.0000e+00, -1.4596e-01,  0.0000e+00,  0.0000e+00,
-        -1.7051e-01,  0.0000e+00, -4.4113e-02,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -1.0802e-01, -3.7511e-01,  2.6346e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.2875e-01,  0.0000e+00,
-         0.0000e+00, -4.6779e-01, -2.4735e-01, -4.7853e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.2234e-01,  0.0000e+00,  1.1193e-02,  1.9859e-01,  2.4409e+00,
-         4.7460e-01,  0.0000e+00,  0.0000e+00, -1.9530e-01, -2.6922e-02,
-        -7.2796e-02,  0.0000e+00, -5.9594e-03, -3.4675e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.3939e-01,  0.0000e+00,  0.0000e+00,
-        -1.4151e-01,  2.9938e-02,  0.0000e+00, -8.3891e-01,  2.4572e-01,
-         0.0000e+00, -7.2800e-02,  4.7188e-02,  3.9825e-01,  0.0000e+00,
-        -8.4847e-01, -1.1413e+00,  0.0000e+00, -1.5381e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.8435e-01,  8.9743e-01,
-         2.3247e-01,  0.0000e+00, -1.4596e-01,  0.0000e+00,  0.0000e+00,
-        -1.7051e-01,  0.0000e+00, -4.4113e-02,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -1.0802e-01, -3.7511e-01,  2.6346e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.2875e-01,  0.0000e+00,
-         0.0000e+00, -4.6779e-01, -2.4735e-01, -4.7853e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.7595e-01, -1.3348e-04, -6.4304e-02,  1.0458e-01,  2.4381e+00,
-         5.6884e-01,  8.0800e-14, -6.3801e-07, -1.2509e-01, -6.2898e-02,
-        -7.8560e-04,  1.4526e-05, -5.7334e-02, -3.6389e-01, -1.4490e-10,
-        -1.0769e-07, -1.0593e-08, -1.6299e-01, -2.4501e-11,  2.9459e-07,
-        -2.4164e-01,  8.5751e-02, -5.2032e-11, -8.4927e-01,  2.7630e-01,
-         4.2230e-07, -1.4648e-01, -2.7657e-02,  4.6871e-01,  7.5673e-11,
-        -8.6652e-01, -1.1462e+00, -1.9371e-14, -1.0880e-02,  9.6147e-09,
-         5.1877e-07,  3.2072e-10,  0.0000e+00,  3.1876e-01,  8.9231e-01,
-         3.2730e-01, -4.2838e-08, -1.6573e-01,  7.8933e-11,  1.9306e-08,
-        -1.7528e-01, -2.7241e-06, -6.5885e-02,  9.1637e-06, -5.0251e-12,
-        -1.6259e-02, -8.6654e-02, -3.2876e-01,  2.3614e-01, -1.8414e-03,
-         1.0458e-15, -4.1102e-10,  1.4782e-13,  5.4292e-01, -2.1348e-05,
-         3.9485e-13, -4.8493e-01, -2.1685e-01, -8.1347e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.7595e-01,  0.0000e+00, -6.4304e-02,  1.0458e-01,  2.4381e+00,
-         5.6884e-01,  0.0000e+00,  0.0000e+00, -1.2509e-01, -6.2898e-02,
-        -7.8560e-04,  0.0000e+00, -5.7334e-02, -3.6389e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6299e-01,  0.0000e+00,  0.0000e+00,
-        -2.4164e-01,  8.5751e-02,  0.0000e+00, -8.4927e-01,  2.7630e-01,
-         0.0000e+00, -1.4648e-01, -2.7657e-02,  4.6871e-01,  0.0000e+00,
-        -8.6652e-01, -1.1462e+00,  0.0000e+00, -1.0880e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.1876e-01,  8.9231e-01,
-         3.2730e-01,  0.0000e+00, -1.6573e-01,  0.0000e+00,  0.0000e+00,
-        -1.7528e-01,  0.0000e+00, -6.5885e-02,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -8.6654e-02, -3.2876e-01,  2.3614e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4292e-01,  0.0000e+00,
-         0.0000e+00, -4.8493e-01, -2.1685e-01, -8.1347e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.7595e-01,  0.0000e+00, -6.4304e-02,  1.0458e-01,  2.4381e+00,
-         5.6884e-01,  0.0000e+00,  0.0000e+00, -1.2509e-01, -6.2898e-02,
-        -7.8560e-04,  0.0000e+00, -5.7334e-02, -3.6389e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6299e-01,  0.0000e+00,  0.0000e+00,
-        -2.4164e-01,  8.5751e-02,  0.0000e+00, -8.4927e-01,  2.7630e-01,
-         0.0000e+00, -1.4648e-01, -2.7657e-02,  4.6871e-01,  0.0000e+00,
-        -8.6652e-01, -1.1462e+00,  0.0000e+00, -1.0880e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.1876e-01,  8.9231e-01,
-         3.2730e-01,  0.0000e+00, -1.6573e-01,  0.0000e+00,  0.0000e+00,
-        -1.7528e-01,  0.0000e+00, -6.5885e-02,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -8.6654e-02, -3.2876e-01,  2.3614e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4292e-01,  0.0000e+00,
-         0.0000e+00, -4.8493e-01, -2.1685e-01, -8.1347e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4915e-01, -1.1667e-04, -1.1199e-01,  6.5085e-02,  2.4350e+00,
-         6.7004e-01,  7.0623e-14, -5.5765e-07, -1.4394e-02, -7.0757e-02,
-        -1.8389e-02,  1.2696e-05, -9.0452e-02, -3.5275e-01, -1.2665e-10,
-        -9.4128e-08, -9.2590e-09, -1.0583e-01, -2.1415e-11,  2.5748e-07,
-        -2.7518e-01,  1.5880e-01, -4.5478e-11, -8.5793e-01,  3.0170e-01,
-         3.6911e-07, -1.9879e-01, -4.8126e-02,  5.0637e-01,  6.6142e-11,
-        -8.8356e-01, -1.1490e+00, -1.6931e-14, -2.4941e-02,  8.4037e-09,
-         4.5343e-07,  2.8033e-10,  0.0000e+00,  2.7823e-01,  8.9262e-01,
-         3.9162e-01, -3.7443e-08, -1.8320e-01,  6.8991e-11,  1.6875e-08,
-        -2.0704e-01, -2.3810e-06, -6.5140e-02,  8.0095e-06, -4.3921e-12,
-        -1.4211e-02, -2.5641e-02, -2.9163e-01,  2.5251e-01, -1.6095e-03,
-         9.1408e-16, -3.5925e-10,  1.2920e-13,  5.4782e-01, -1.8659e-05,
-         3.4512e-13, -4.6805e-01, -1.9369e-01, -1.4413e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1491,  0.0000, -0.1120,  0.0651,  2.4350,  0.6700,  0.0000,  0.0000,
-        -0.0144, -0.0708, -0.0184,  0.0000, -0.0905, -0.3528,  0.0000,  0.0000,
-         0.0000, -0.1058,  0.0000,  0.0000, -0.2752,  0.1588,  0.0000, -0.8579,
-         0.3017,  0.0000, -0.1988, -0.0481,  0.5064,  0.0000, -0.8836, -1.1490,
-         0.0000, -0.0249,  0.0000,  0.0000,  0.0000,  0.0000,  0.2782,  0.8926,
-         0.3916,  0.0000, -0.1832,  0.0000,  0.0000, -0.2070,  0.0000, -0.0651,
-         0.0000,  0.0000,  0.0000, -0.0256, -0.2916,  0.2525,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5478,  0.0000,  0.0000, -0.4681, -0.1937, -0.1441],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1491,  0.0000, -0.1120,  0.0651,  2.4350,  0.6700,  0.0000,  0.0000,
-        -0.0144, -0.0708, -0.0184,  0.0000, -0.0905, -0.3528,  0.0000,  0.0000,
-         0.0000, -0.1058,  0.0000,  0.0000, -0.2752,  0.1588,  0.0000, -0.8579,
-         0.3017,  0.0000, -0.1988, -0.0481,  0.5064,  0.0000, -0.8836, -1.1490,
-         0.0000, -0.0249,  0.0000,  0.0000,  0.0000,  0.0000,  0.2782,  0.8926,
-         0.3916,  0.0000, -0.1832,  0.0000,  0.0000, -0.2070,  0.0000, -0.0651,
-         0.0000,  0.0000,  0.0000, -0.0256, -0.2916,  0.2525,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5478,  0.0000,  0.0000, -0.4681, -0.1937, -0.1441],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1592e-01, -1.0202e-04, -1.5458e-01,  5.5854e-02,  2.4321e+00,
-         7.7636e-01,  6.1752e-14, -4.8760e-07,  9.4582e-02, -7.3868e-02,
-        -7.9312e-02,  1.1101e-05, -1.0773e-01, -3.2577e-01, -1.1074e-10,
-        -8.2305e-08, -8.0960e-09, -3.4433e-02, -1.8725e-11,  2.2514e-07,
-        -2.7779e-01,  2.1047e-01, -3.9766e-11, -8.7433e-01,  2.8746e-01,
-         3.2274e-07, -2.3442e-01, -4.7876e-02,  5.2118e-01,  5.7834e-11,
-        -8.7955e-01, -1.1453e+00, -1.4804e-14, -5.4417e-02,  7.3481e-09,
-         3.9648e-07,  2.4512e-10,  0.0000e+00,  2.2905e-01,  8.8025e-01,
-         4.0828e-01, -3.2739e-08, -2.0316e-01,  6.0325e-11,  1.4755e-08,
-        -2.4311e-01, -2.0819e-06, -2.4420e-02,  7.0034e-06, -3.8404e-12,
-        -1.2426e-02,  1.9836e-02, -2.7734e-01,  2.6641e-01, -1.4073e-03,
-         7.9926e-16, -3.1412e-10,  1.1297e-13,  5.6321e-01, -1.6315e-05,
-         3.0177e-13, -4.2650e-01, -1.7179e-01, -1.8152e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1159,  0.0000, -0.1546,  0.0559,  2.4321,  0.7764,  0.0000,  0.0000,
-         0.0946, -0.0739, -0.0793,  0.0000, -0.1077, -0.3258,  0.0000,  0.0000,
-         0.0000, -0.0344,  0.0000,  0.0000, -0.2778,  0.2105,  0.0000, -0.8743,
-         0.2875,  0.0000, -0.2344, -0.0479,  0.5212,  0.0000, -0.8796, -1.1453,
-         0.0000, -0.0544,  0.0000,  0.0000,  0.0000,  0.0000,  0.2290,  0.8802,
-         0.4083,  0.0000, -0.2032,  0.0000,  0.0000, -0.2431,  0.0000, -0.0244,
-         0.0000,  0.0000,  0.0000,  0.0198, -0.2773,  0.2664,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5632,  0.0000,  0.0000, -0.4265, -0.1718, -0.1815],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1159,  0.0000, -0.1546,  0.0559,  2.4321,  0.7764,  0.0000,  0.0000,
-         0.0946, -0.0739, -0.0793,  0.0000, -0.1077, -0.3258,  0.0000,  0.0000,
-         0.0000, -0.0344,  0.0000,  0.0000, -0.2778,  0.2105,  0.0000, -0.8743,
-         0.2875,  0.0000, -0.2344, -0.0479,  0.5212,  0.0000, -0.8796, -1.1453,
-         0.0000, -0.0544,  0.0000,  0.0000,  0.0000,  0.0000,  0.2290,  0.8802,
-         0.4083,  0.0000, -0.2032,  0.0000,  0.0000, -0.2431,  0.0000, -0.0244,
-         0.0000,  0.0000,  0.0000,  0.0198, -0.2773,  0.2664,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5632,  0.0000,  0.0000, -0.4265, -0.1718, -0.1815],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3958e-02, -8.9236e-05, -2.0953e-01,  1.7173e-02,  2.4273e+00,
-         8.7973e-01,  5.4016e-14, -4.2652e-07,  1.8594e-01, -1.3627e-01,
-        -1.3098e-01,  9.7106e-06, -1.3139e-01, -2.9550e-01, -9.6866e-11,
-        -7.1995e-08, -7.0818e-09, -6.1339e-02, -1.6380e-11,  1.9694e-07,
-        -2.4098e-01,  2.2425e-01, -3.4784e-11, -8.7828e-01,  2.4837e-01,
-         2.8231e-07, -2.6941e-01, -6.0304e-02,  5.5472e-01,  5.0589e-11,
-        -9.0584e-01, -1.1381e+00, -1.2950e-14, -8.2552e-02,  6.4276e-09,
-         3.4681e-07,  2.1441e-10,  0.0000e+00,  1.3433e-01,  8.7020e-01,
-         4.1816e-01, -2.8638e-08, -2.2235e-01,  5.2768e-11,  1.2907e-08,
-        -2.6979e-01, -1.8211e-06,  4.4270e-02,  6.1261e-06, -3.3593e-12,
-        -1.0869e-02,  7.4551e-02, -2.4394e-01,  1.9431e-01, -1.2310e-03,
-         6.9914e-16, -2.7477e-10,  9.8818e-14,  5.7626e-01, -1.4271e-05,
-         2.6397e-13, -4.1195e-01, -1.5126e-01, -2.1864e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0440,  0.0000, -0.2095,  0.0172,  2.4273,  0.8797,  0.0000,  0.0000,
-         0.1859, -0.1363, -0.1310,  0.0000, -0.1314, -0.2955,  0.0000,  0.0000,
-         0.0000, -0.0613,  0.0000,  0.0000, -0.2410,  0.2242,  0.0000, -0.8783,
-         0.2484,  0.0000, -0.2694, -0.0603,  0.5547,  0.0000, -0.9058, -1.1381,
-         0.0000, -0.0826,  0.0000,  0.0000,  0.0000,  0.0000,  0.1343,  0.8702,
-         0.4182,  0.0000, -0.2224,  0.0000,  0.0000, -0.2698,  0.0000,  0.0443,
-         0.0000,  0.0000,  0.0000,  0.0746, -0.2439,  0.1943,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5763,  0.0000,  0.0000, -0.4119, -0.1513, -0.2186],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0440,  0.0000, -0.2095,  0.0172,  2.4273,  0.8797,  0.0000,  0.0000,
-         0.1859, -0.1363, -0.1310,  0.0000, -0.1314, -0.2955,  0.0000,  0.0000,
-         0.0000, -0.0613,  0.0000,  0.0000, -0.2410,  0.2242,  0.0000, -0.8783,
-         0.2484,  0.0000, -0.2694, -0.0603,  0.5547,  0.0000, -0.9058, -1.1381,
-         0.0000, -0.0826,  0.0000,  0.0000,  0.0000,  0.0000,  0.1343,  0.8702,
-         0.4182,  0.0000, -0.2224,  0.0000,  0.0000, -0.2698,  0.0000,  0.0443,
-         0.0000,  0.0000,  0.0000,  0.0746, -0.2439,  0.1943,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5763,  0.0000,  0.0000, -0.4119, -0.1513, -0.2186],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.0467e-03, -7.8088e-05, -2.5551e-01, -3.5744e-02,  2.4193e+00,
-         9.4598e-01,  4.7268e-14, -3.7324e-07,  2.5536e-01, -1.7691e-01,
-        -1.5319e-01,  8.4975e-06, -1.4986e-01, -2.8131e-01, -8.4765e-11,
-        -6.3001e-08, -6.1971e-09, -9.3451e-02, -1.4333e-11,  1.7233e-07,
-        -2.1692e-01,  1.9491e-01, -3.0439e-11, -8.7789e-01,  1.6699e-01,
-         2.4704e-07, -2.5154e-01, -1.0055e-01,  5.5471e-01,  4.4269e-11,
-        -9.1178e-01, -1.1384e+00, -1.1332e-14, -1.3928e-01,  5.6246e-09,
-         3.0348e-07,  1.8762e-10,  0.0000e+00,  2.4316e-02,  8.5069e-01,
-         4.1049e-01, -2.5060e-08, -2.2885e-01,  4.6176e-11,  1.1294e-08,
-        -2.8354e-01, -1.5936e-06,  1.0864e-01,  5.3608e-06, -2.9397e-12,
-        -9.5114e-03,  6.5526e-02, -1.8425e-01,  1.0941e-01, -1.0772e-03,
-         6.1179e-16, -2.4045e-10,  8.6473e-14,  5.7382e-01, -1.2489e-05,
-         2.3099e-13, -3.6636e-01, -1.1059e-01, -2.3845e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0090,  0.0000, -0.2555, -0.0357,  2.4193,  0.9460,  0.0000,  0.0000,
-         0.2554, -0.1769, -0.1532,  0.0000, -0.1499, -0.2813,  0.0000,  0.0000,
-         0.0000, -0.0935,  0.0000,  0.0000, -0.2169,  0.1949,  0.0000, -0.8779,
-         0.1670,  0.0000, -0.2515, -0.1005,  0.5547,  0.0000, -0.9118, -1.1384,
-         0.0000, -0.1393,  0.0000,  0.0000,  0.0000,  0.0000,  0.0243,  0.8507,
-         0.4105,  0.0000, -0.2288,  0.0000,  0.0000, -0.2835,  0.0000,  0.1086,
-         0.0000,  0.0000,  0.0000,  0.0655, -0.1842,  0.1094,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5738,  0.0000,  0.0000, -0.3664, -0.1106, -0.2385],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0090,  0.0000, -0.2555, -0.0357,  2.4193,  0.9460,  0.0000,  0.0000,
-         0.2554, -0.1769, -0.1532,  0.0000, -0.1499, -0.2813,  0.0000,  0.0000,
-         0.0000, -0.0935,  0.0000,  0.0000, -0.2169,  0.1949,  0.0000, -0.8779,
-         0.1670,  0.0000, -0.2515, -0.1005,  0.5547,  0.0000, -0.9118, -1.1384,
-         0.0000, -0.1393,  0.0000,  0.0000,  0.0000,  0.0000,  0.0243,  0.8507,
-         0.4105,  0.0000, -0.2288,  0.0000,  0.0000, -0.2835,  0.0000,  0.1086,
-         0.0000,  0.0000,  0.0000,  0.0655, -0.1842,  0.1094,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5738,  0.0000,  0.0000, -0.3664, -0.1106, -0.2385],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.2516e-03, -6.8360e-05, -2.7619e-01, -7.0575e-02,  2.4119e+00,
-         9.9005e-01,  4.1379e-14, -3.2674e-07,  3.1606e-01, -1.7276e-01,
-        -1.5195e-01,  7.4389e-06, -1.7682e-01, -2.8361e-01, -7.4205e-11,
-        -5.5152e-08, -5.4250e-09, -1.0549e-01, -1.2548e-11,  1.5086e-07,
-        -2.1127e-01,  1.7789e-01, -2.6647e-11, -8.8727e-01,  8.8601e-02,
-         2.1627e-07, -2.1889e-01, -1.4150e-01,  5.4091e-01,  3.8754e-11,
-        -9.0785e-01, -1.1489e+00, -9.9202e-15, -1.7601e-01,  4.9239e-09,
-         2.6568e-07,  1.6425e-10,  0.0000e+00, -9.0658e-02,  8.2874e-01,
-         3.9359e-01, -2.1938e-08, -2.2207e-01,  4.0423e-11,  9.8871e-09,
-        -2.6641e-01, -1.3951e-06,  1.6908e-01,  4.6929e-06, -2.5734e-12,
-        -8.3265e-03,  4.4850e-02, -1.3156e-01,  8.6804e-02, -9.4304e-04,
-         5.3558e-16, -2.1049e-10,  7.5700e-14,  5.5913e-01, -1.0933e-05,
-         2.0221e-13, -3.0562e-01, -8.5440e-02, -2.6519e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0063,  0.0000, -0.2762, -0.0706,  2.4119,  0.9901,  0.0000,  0.0000,
-         0.3161, -0.1728, -0.1520,  0.0000, -0.1768, -0.2836,  0.0000,  0.0000,
-         0.0000, -0.1055,  0.0000,  0.0000, -0.2113,  0.1779,  0.0000, -0.8873,
-         0.0886,  0.0000, -0.2189, -0.1415,  0.5409,  0.0000, -0.9079, -1.1489,
-         0.0000, -0.1760,  0.0000,  0.0000,  0.0000,  0.0000, -0.0907,  0.8287,
-         0.3936,  0.0000, -0.2221,  0.0000,  0.0000, -0.2664,  0.0000,  0.1691,
-         0.0000,  0.0000,  0.0000,  0.0448, -0.1316,  0.0868,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5591,  0.0000,  0.0000, -0.3056, -0.0854, -0.2652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0063,  0.0000, -0.2762, -0.0706,  2.4119,  0.9901,  0.0000,  0.0000,
-         0.3161, -0.1728, -0.1520,  0.0000, -0.1768, -0.2836,  0.0000,  0.0000,
-         0.0000, -0.1055,  0.0000,  0.0000, -0.2113,  0.1779,  0.0000, -0.8873,
-         0.0886,  0.0000, -0.2189, -0.1415,  0.5409,  0.0000, -0.9079, -1.1489,
-         0.0000, -0.1760,  0.0000,  0.0000,  0.0000,  0.0000, -0.0907,  0.8287,
-         0.3936,  0.0000, -0.2221,  0.0000,  0.0000, -0.2664,  0.0000,  0.1691,
-         0.0000,  0.0000,  0.0000,  0.0448, -0.1316,  0.0868,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5591,  0.0000,  0.0000, -0.3056, -0.0854, -0.2652],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3257e-02, -5.9867e-05, -2.3028e-01, -8.6201e-02,  2.4038e+00,
-         1.0251e+00,  3.6239e-14, -2.8614e-07,  3.5762e-01, -1.1913e-01,
-        -1.1953e-01,  6.5147e-06, -1.6636e-01, -3.0213e-01, -6.4986e-11,
-        -4.8300e-08, -4.7510e-09, -9.3260e-02, -1.0989e-11,  1.3212e-07,
-        -2.2403e-01,  1.6458e-01, -2.3336e-11, -9.0503e-01,  1.1615e-02,
-         1.8940e-07, -1.2897e-01, -2.0925e-01,  5.1959e-01,  3.3939e-11,
-        -8.9600e-01, -1.1643e+00, -8.6877e-15, -2.0290e-01,  4.3122e-09,
-         2.3267e-07,  1.4384e-10,  0.0000e+00, -1.6583e-01,  8.0938e-01,
-         3.6108e-01, -1.9213e-08, -1.7094e-01,  3.5401e-11,  8.6588e-09,
-        -2.4792e-01, -1.2217e-06,  2.4315e-01,  4.1099e-06, -2.2537e-12,
-        -7.2920e-03, -5.4695e-02, -1.2421e-01,  2.0123e-01, -8.2588e-04,
-         4.6904e-16, -1.8434e-10,  6.6295e-14,  5.3295e-01, -9.5745e-06,
-         1.7709e-13, -2.1142e-01, -5.8128e-02, -2.8156e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0233,  0.0000, -0.2303, -0.0862,  2.4038,  1.0251,  0.0000,  0.0000,
-         0.3576, -0.1191, -0.1195,  0.0000, -0.1664, -0.3021,  0.0000,  0.0000,
-         0.0000, -0.0933,  0.0000,  0.0000, -0.2240,  0.1646,  0.0000, -0.9050,
-         0.0116,  0.0000, -0.1290, -0.2093,  0.5196,  0.0000, -0.8960, -1.1643,
-         0.0000, -0.2029,  0.0000,  0.0000,  0.0000,  0.0000, -0.1658,  0.8094,
-         0.3611,  0.0000, -0.1709,  0.0000,  0.0000, -0.2479,  0.0000,  0.2432,
-         0.0000,  0.0000,  0.0000, -0.0547, -0.1242,  0.2012,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5330,  0.0000,  0.0000, -0.2114, -0.0581, -0.2816],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0233,  0.0000, -0.2303, -0.0862,  2.4038,  1.0251,  0.0000,  0.0000,
-         0.3576, -0.1191, -0.1195,  0.0000, -0.1664, -0.3021,  0.0000,  0.0000,
-         0.0000, -0.0933,  0.0000,  0.0000, -0.2240,  0.1646,  0.0000, -0.9050,
-         0.0116,  0.0000, -0.1290, -0.2093,  0.5196,  0.0000, -0.8960, -1.1643,
-         0.0000, -0.2029,  0.0000,  0.0000,  0.0000,  0.0000, -0.1658,  0.8094,
-         0.3611,  0.0000, -0.1709,  0.0000,  0.0000, -0.2479,  0.0000,  0.2432,
-         0.0000,  0.0000,  0.0000, -0.0547, -0.1242,  0.2012,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5330,  0.0000,  0.0000, -0.2114, -0.0581, -0.2816],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6425e-02, -5.2450e-05, -2.3463e-01, -9.8281e-02,  2.3968e+00,
-         1.0455e+00,  3.1749e-14, -2.5069e-07,  3.8281e-01, -1.4376e-02,
-        -4.7228e-02,  5.7076e-06, -1.9548e-01, -3.2274e-01, -5.6935e-11,
-        -4.2316e-08, -4.1624e-09,  1.3500e-02, -9.6274e-12,  1.1575e-07,
-        -2.1498e-01,  2.2398e-01, -2.0445e-11, -8.6986e-01,  4.0855e-02,
-         1.6593e-07, -1.0179e-01, -2.1226e-01,  5.0298e-01,  2.9734e-11,
-        -8.6452e-01, -1.1736e+00, -7.6114e-15, -1.7156e-01,  3.7779e-09,
-         2.0384e-07,  1.2602e-10,  0.0000e+00, -2.8165e-01,  7.9452e-01,
-         3.6620e-01, -1.6833e-08, -1.6410e-01,  3.1015e-11,  7.5861e-09,
-        -2.7531e-01, -1.0704e-06,  2.6693e-01,  3.6007e-06, -1.9745e-12,
-        -6.3886e-03, -2.4725e-02, -1.8576e-01,  3.4751e-01, -7.2356e-04,
-         4.1093e-16, -1.6150e-10,  5.8082e-14,  5.0384e-01, -8.3883e-06,
-         1.5515e-13, -1.2888e-01, -5.0114e-02, -3.0642e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0164,  0.0000, -0.2346, -0.0983,  2.3968,  1.0455,  0.0000,  0.0000,
-         0.3828, -0.0144, -0.0472,  0.0000, -0.1955, -0.3227,  0.0000,  0.0000,
-         0.0000,  0.0135,  0.0000,  0.0000, -0.2150,  0.2240,  0.0000, -0.8699,
-         0.0409,  0.0000, -0.1018, -0.2123,  0.5030,  0.0000, -0.8645, -1.1736,
-         0.0000, -0.1716,  0.0000,  0.0000,  0.0000,  0.0000, -0.2817,  0.7945,
-         0.3662,  0.0000, -0.1641,  0.0000,  0.0000, -0.2753,  0.0000,  0.2669,
-         0.0000,  0.0000,  0.0000, -0.0247, -0.1858,  0.3475,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5038,  0.0000,  0.0000, -0.1289, -0.0501, -0.3064],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0164,  0.0000, -0.2346, -0.0983,  2.3968,  1.0455,  0.0000,  0.0000,
-         0.3828, -0.0144, -0.0472,  0.0000, -0.1955, -0.3227,  0.0000,  0.0000,
-         0.0000,  0.0135,  0.0000,  0.0000, -0.2150,  0.2240,  0.0000, -0.8699,
-         0.0409,  0.0000, -0.1018, -0.2123,  0.5030,  0.0000, -0.8645, -1.1736,
-         0.0000, -0.1716,  0.0000,  0.0000,  0.0000,  0.0000, -0.2817,  0.7945,
-         0.3662,  0.0000, -0.1641,  0.0000,  0.0000, -0.2753,  0.0000,  0.2669,
-         0.0000,  0.0000,  0.0000, -0.0247, -0.1858,  0.3475,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5038,  0.0000,  0.0000, -0.1289, -0.0501, -0.3064],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.9391e-02, -4.5970e-05, -2.1353e-01, -9.6219e-02,  2.3897e+00,
-         1.0626e+00,  2.7827e-14, -2.1972e-07,  4.0798e-01,  4.8122e-02,
-        -2.8404e-02,  5.0025e-06, -1.9660e-01, -3.2711e-01, -4.9901e-11,
-        -3.7088e-08, -3.6482e-09,  1.3583e-01, -8.4380e-12,  1.0145e-07,
-        -1.7190e-01,  2.6453e-01, -1.7919e-11, -8.2562e-01,  5.3005e-02,
-         1.4543e-07, -6.0023e-02, -1.9627e-01,  4.7956e-01,  2.6061e-11,
-        -8.3877e-01, -1.1797e+00, -6.6711e-15, -1.3835e-01,  3.3112e-09,
-         1.7866e-07,  1.1045e-10,  0.0000e+00, -3.9511e-01,  7.8790e-01,
-         3.5450e-01, -1.4753e-08, -1.3269e-01,  2.7184e-11,  6.6489e-09,
-        -2.6405e-01, -9.3814e-07,  3.0655e-01,  3.1559e-06, -1.7306e-12,
-        -5.5993e-03,  7.5969e-03, -2.5795e-01,  4.7813e-01, -6.3417e-04,
-         3.6016e-16, -1.4155e-10,  5.0907e-14,  4.6094e-01, -7.3520e-06,
-         1.3598e-13, -5.3259e-02, -4.3792e-02, -3.4981e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0194,  0.0000, -0.2135, -0.0962,  2.3897,  1.0626,  0.0000,  0.0000,
-         0.4080,  0.0481, -0.0284,  0.0000, -0.1966, -0.3271,  0.0000,  0.0000,
-         0.0000,  0.1358,  0.0000,  0.0000, -0.1719,  0.2645,  0.0000, -0.8256,
-         0.0530,  0.0000, -0.0600, -0.1963,  0.4796,  0.0000, -0.8388, -1.1797,
-         0.0000, -0.1383,  0.0000,  0.0000,  0.0000,  0.0000, -0.3951,  0.7879,
-         0.3545,  0.0000, -0.1327,  0.0000,  0.0000, -0.2640,  0.0000,  0.3066,
-         0.0000,  0.0000,  0.0000,  0.0076, -0.2579,  0.4781,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4609,  0.0000,  0.0000, -0.0533, -0.0438, -0.3498],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0194,  0.0000, -0.2135, -0.0962,  2.3897,  1.0626,  0.0000,  0.0000,
-         0.4080,  0.0481, -0.0284,  0.0000, -0.1966, -0.3271,  0.0000,  0.0000,
-         0.0000,  0.1358,  0.0000,  0.0000, -0.1719,  0.2645,  0.0000, -0.8256,
-         0.0530,  0.0000, -0.0600, -0.1963,  0.4796,  0.0000, -0.8388, -1.1797,
-         0.0000, -0.1383,  0.0000,  0.0000,  0.0000,  0.0000, -0.3951,  0.7879,
-         0.3545,  0.0000, -0.1327,  0.0000,  0.0000, -0.2640,  0.0000,  0.3066,
-         0.0000,  0.0000,  0.0000,  0.0076, -0.2579,  0.4781,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4609,  0.0000,  0.0000, -0.0533, -0.0438, -0.3498],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.8239e-02, -4.0307e-05, -2.2231e-01, -1.4754e-01,  2.3806e+00,
-         1.0545e+00,  2.4399e-14, -1.9265e-07,  4.2140e-01,  2.1091e-02,
-         2.2928e-02,  4.3862e-06, -1.9737e-01, -3.3502e-01, -4.3754e-11,
-        -3.2519e-08, -3.1988e-09,  1.0856e-01, -7.3985e-12,  8.8955e-08,
-        -1.4906e-01,  2.6437e-01, -1.5712e-11, -7.7408e-01,  4.1049e-02,
-         1.2752e-07, -6.1224e-02, -1.9710e-01,  4.3032e-01,  2.2850e-11,
-        -8.1056e-01, -1.1840e+00, -5.8493e-15, -1.1118e-01,  2.9033e-09,
-         1.5665e-07,  9.6847e-11,  0.0000e+00, -4.5921e-01,  7.8824e-01,
-         3.1632e-01, -1.2936e-08, -1.0376e-01,  2.3835e-11,  5.8298e-09,
-        -2.4681e-01, -8.2257e-07,  3.6023e-01,  2.7671e-06, -1.5174e-12,
-        -4.9095e-03,  6.5767e-02, -2.7375e-01,  3.8492e-01, -5.5605e-04,
-         3.1579e-16, -1.2411e-10,  4.4635e-14,  4.7279e-01, -6.4463e-06,
-         1.1923e-13, -6.5143e-02,  2.6754e-02, -4.0376e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0382,  0.0000, -0.2223, -0.1475,  2.3806,  1.0545,  0.0000,  0.0000,
-         0.4214,  0.0211,  0.0229,  0.0000, -0.1974, -0.3350,  0.0000,  0.0000,
-         0.0000,  0.1086,  0.0000,  0.0000, -0.1491,  0.2644,  0.0000, -0.7741,
-         0.0410,  0.0000, -0.0612, -0.1971,  0.4303,  0.0000, -0.8106, -1.1840,
-         0.0000, -0.1112,  0.0000,  0.0000,  0.0000,  0.0000, -0.4592,  0.7882,
-         0.3163,  0.0000, -0.1038,  0.0000,  0.0000, -0.2468,  0.0000,  0.3602,
-         0.0000,  0.0000,  0.0000,  0.0658, -0.2737,  0.3849,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4728,  0.0000,  0.0000, -0.0651,  0.0268, -0.4038],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0382,  0.0000, -0.2223, -0.1475,  2.3806,  1.0545,  0.0000,  0.0000,
-         0.4214,  0.0211,  0.0229,  0.0000, -0.1974, -0.3350,  0.0000,  0.0000,
-         0.0000,  0.1086,  0.0000,  0.0000, -0.1491,  0.2644,  0.0000, -0.7741,
-         0.0410,  0.0000, -0.0612, -0.1971,  0.4303,  0.0000, -0.8106, -1.1840,
-         0.0000, -0.1112,  0.0000,  0.0000,  0.0000,  0.0000, -0.4592,  0.7882,
-         0.3163,  0.0000, -0.1038,  0.0000,  0.0000, -0.2468,  0.0000,  0.3602,
-         0.0000,  0.0000,  0.0000,  0.0658, -0.2737,  0.3849,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4728,  0.0000,  0.0000, -0.0651,  0.0268, -0.4038],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.4047e-02, -3.5356e-05, -2.2597e-01, -2.3628e-01,  2.3731e+00,
-         1.0383e+00,  2.1401e-14, -1.6899e-07,  4.3383e-01, -1.6966e-02,
-         1.0988e-01,  3.8474e-06, -1.7462e-01, -3.3765e-01, -3.8379e-11,
-        -2.8524e-08, -2.8058e-09, -1.5402e-02, -6.4897e-12,  7.8027e-08,
-        -1.5063e-01,  2.2487e-01, -1.3782e-11, -7.6497e-01,  2.8476e-03,
-         1.1185e-07, -4.4339e-02, -2.0760e-01,  3.7381e-01,  2.0043e-11,
-        -7.8712e-01, -1.1889e+00, -5.1307e-15, -1.2125e-01,  2.5466e-09,
-         1.3741e-07,  8.4950e-11,  0.0000e+00, -4.6232e-01,  7.7459e-01,
-         2.9288e-01, -1.1346e-08, -8.4920e-02,  2.0907e-11,  5.1136e-09,
-        -2.1839e-01, -7.2152e-07,  4.3129e-01,  2.4272e-06, -1.3310e-12,
-        -4.3064e-03,  8.7193e-02, -1.9976e-01,  2.0103e-01, -4.8774e-04,
-         2.7700e-16, -1.0887e-10,  3.9152e-14,  5.2590e-01, -5.6544e-06,
-         1.0458e-13, -7.6494e-02,  1.5081e-01, -4.2434e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0540,  0.0000, -0.2260, -0.2363,  2.3731,  1.0383,  0.0000,  0.0000,
-         0.4338, -0.0170,  0.1099,  0.0000, -0.1746, -0.3377,  0.0000,  0.0000,
-         0.0000, -0.0154,  0.0000,  0.0000, -0.1506,  0.2249,  0.0000, -0.7650,
-         0.0028,  0.0000, -0.0443, -0.2076,  0.3738,  0.0000, -0.7871, -1.1889,
-         0.0000, -0.1213,  0.0000,  0.0000,  0.0000,  0.0000, -0.4623,  0.7746,
-         0.2929,  0.0000, -0.0849,  0.0000,  0.0000, -0.2184,  0.0000,  0.4313,
-         0.0000,  0.0000,  0.0000,  0.0872, -0.1998,  0.2010,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5259,  0.0000,  0.0000, -0.0765,  0.1508, -0.4243],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0540,  0.0000, -0.2260, -0.2363,  2.3731,  1.0383,  0.0000,  0.0000,
-         0.4338, -0.0170,  0.1099,  0.0000, -0.1746, -0.3377,  0.0000,  0.0000,
-         0.0000, -0.0154,  0.0000,  0.0000, -0.1506,  0.2249,  0.0000, -0.7650,
-         0.0028,  0.0000, -0.0443, -0.2076,  0.3738,  0.0000, -0.7871, -1.1889,
-         0.0000, -0.1213,  0.0000,  0.0000,  0.0000,  0.0000, -0.4623,  0.7746,
-         0.2929,  0.0000, -0.0849,  0.0000,  0.0000, -0.2184,  0.0000,  0.4313,
-         0.0000,  0.0000,  0.0000,  0.0872, -0.1998,  0.2010,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5259,  0.0000,  0.0000, -0.0765,  0.1508, -0.4243],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 7.0772e-02, -3.1025e-05, -2.3516e-01, -3.0062e-01,  2.3682e+00,
-         1.0211e+00,  1.8780e-14, -1.4829e-07,  4.5154e-01,  1.8601e-02,
-         2.7509e-01,  3.3761e-06, -1.8352e-01, -3.6824e-01, -3.3678e-11,
-        -2.5030e-08, -2.4621e-09, -1.0720e-01, -5.6947e-12,  6.8469e-08,
-        -1.8629e-01,  2.1725e-01, -1.2093e-11, -7.6849e-01,  8.1073e-03,
-         9.8152e-08,  3.6126e-03, -2.1353e-01,  3.2327e-01,  1.7588e-11,
-        -7.6642e-01, -1.2001e+00, -4.5022e-15, -1.2712e-01,  2.2347e-09,
-         1.2058e-07,  7.4544e-11,  0.0000e+00, -4.2872e-01,  7.4901e-01,
-         2.8230e-01, -9.9566e-09, -1.0769e-01,  1.8346e-11,  4.4872e-09,
-        -2.0814e-01, -6.3314e-07,  4.8023e-01,  2.1299e-06, -1.1679e-12,
-        -3.7789e-03,  9.4982e-02, -1.3134e-01,  6.3441e-02, -4.2800e-04,
-         2.4307e-16, -9.5530e-11,  3.4356e-14,  5.6405e-01, -4.9618e-06,
-         9.1773e-14, -6.9067e-02,  2.2792e-01, -4.2764e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0708,  0.0000, -0.2352, -0.3006,  2.3682,  1.0211,  0.0000,  0.0000,
-         0.4515,  0.0186,  0.2751,  0.0000, -0.1835, -0.3682,  0.0000,  0.0000,
-         0.0000, -0.1072,  0.0000,  0.0000, -0.1863,  0.2172,  0.0000, -0.7685,
-         0.0081,  0.0000,  0.0036, -0.2135,  0.3233,  0.0000, -0.7664, -1.2001,
-         0.0000, -0.1271,  0.0000,  0.0000,  0.0000,  0.0000, -0.4287,  0.7490,
-         0.2823,  0.0000, -0.1077,  0.0000,  0.0000, -0.2081,  0.0000,  0.4802,
-         0.0000,  0.0000,  0.0000,  0.0950, -0.1313,  0.0634,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5641,  0.0000,  0.0000, -0.0691,  0.0000, -0.4276],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0708,  0.0000, -0.2352, -0.3006,  2.3682,  1.0211,  0.0000,  0.0000,
-         0.4515,  0.0186,  0.2751,  0.0000, -0.1835, -0.3682,  0.0000,  0.0000,
-         0.0000, -0.1072,  0.0000,  0.0000, -0.1863,  0.2172,  0.0000, -0.7685,
-         0.0081,  0.0000,  0.0036, -0.2135,  0.3233,  0.0000, -0.7664, -1.2001,
-         0.0000, -0.1271,  0.0000,  0.0000,  0.0000,  0.0000, -0.4287,  0.7490,
-         0.2823,  0.0000, -0.1077,  0.0000,  0.0000, -0.2081,  0.0000,  0.4802,
-         0.0000,  0.0000,  0.0000,  0.0950, -0.1313,  0.0634,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5641,  0.0000,  0.0000, -0.0691,  0.0000, -0.4276],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1413e-01, -2.7235e-05, -2.4560e-01, -3.1330e-01,  2.3641e+00,
-         1.0200e+00,  1.6486e-14, -1.3018e-07,  4.8288e-01,  4.9718e-02,
-         3.3757e-01,  2.9637e-06, -1.6411e-01, -4.1985e-01, -2.9564e-11,
-        -2.1973e-08, -2.1614e-09, -1.3354e-01, -4.9992e-12,  6.0107e-08,
-        -1.8563e-01,  2.1448e-01, -1.0616e-11, -7.3568e-01,  2.5259e-02,
-         8.6163e-08,  5.7455e-02, -1.9026e-01,  2.9704e-01,  1.5440e-11,
-        -7.7001e-01, -1.2091e+00, -3.9523e-15, -1.2450e-01,  1.9617e-09,
-         1.0585e-07,  6.5439e-11,  0.0000e+00, -3.8022e-01,  7.1287e-01,
-         2.6441e-01, -8.7405e-09, -1.3062e-01,  1.6105e-11,  3.9392e-09,
-        -2.1744e-01, -5.5581e-07,  5.1216e-01,  1.8697e-06, -1.0253e-12,
-        -3.3174e-03,  1.7549e-01, -9.3439e-02, -2.9112e-02, -3.7572e-04,
-         2.1338e-16, -8.3862e-11,  3.0160e-14,  6.0475e-01, -4.3557e-06,
-         8.0564e-14,  1.7507e-03,  6.7689e-02, -4.2621e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.1413e-01,  0.0000e+00, -2.4560e-01, -3.1330e-01,  2.3641e+00,
-         1.0200e+00,  0.0000e+00,  0.0000e+00,  4.8288e-01,  4.9718e-02,
-         3.3757e-01,  0.0000e+00, -1.6411e-01, -4.1985e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.3354e-01,  0.0000e+00,  0.0000e+00,
-        -1.8563e-01,  2.1448e-01,  0.0000e+00, -7.3568e-01,  2.5259e-02,
-         0.0000e+00,  5.7455e-02, -1.9026e-01,  2.9704e-01,  0.0000e+00,
-        -7.7001e-01, -1.2091e+00,  0.0000e+00, -1.2450e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -3.8022e-01,  7.1287e-01,
-         2.6441e-01,  0.0000e+00, -1.3062e-01,  0.0000e+00,  0.0000e+00,
-        -2.1744e-01,  0.0000e+00,  5.1216e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.7549e-01, -9.3439e-02, -2.9112e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0475e-01,  0.0000e+00,
-         0.0000e+00,  1.7507e-03,  0.0000e+00, -4.2621e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.1413e-01,  0.0000e+00, -2.4560e-01, -3.1330e-01,  2.3641e+00,
-         1.0200e+00,  0.0000e+00,  0.0000e+00,  4.8288e-01,  4.9718e-02,
-         3.3757e-01,  0.0000e+00, -1.6411e-01, -4.1985e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.3354e-01,  0.0000e+00,  0.0000e+00,
-        -1.8563e-01,  2.1448e-01,  0.0000e+00, -7.3568e-01,  2.5259e-02,
-         0.0000e+00,  5.7455e-02, -1.9026e-01,  2.9704e-01,  0.0000e+00,
-        -7.7001e-01, -1.2091e+00,  0.0000e+00, -1.2450e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -3.8022e-01,  7.1287e-01,
-         2.6441e-01,  0.0000e+00, -1.3062e-01,  0.0000e+00,  0.0000e+00,
-        -2.1744e-01,  0.0000e+00,  5.1216e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.7549e-01, -9.3439e-02, -2.9112e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0475e-01,  0.0000e+00,
-         0.0000e+00,  1.7507e-03,  0.0000e+00, -4.2621e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.8272e-01, -2.3918e-05, -2.3074e-01, -2.8774e-01,  2.3592e+00,
-         1.0047e+00,  1.4478e-14, -1.1432e-07,  5.3675e-01,  2.0737e-02,
-         3.1615e-01,  2.6028e-06, -8.9359e-02, -4.3104e-01, -2.5964e-11,
-        -1.9297e-08, -1.8982e-09, -1.0937e-01, -4.3903e-12,  5.2786e-08,
-        -1.1761e-01,  2.2608e-01, -9.3234e-12, -7.0709e-01,  4.4547e-02,
-         7.5670e-08,  9.0018e-02, -1.2772e-01,  3.0415e-01,  1.3560e-11,
-        -7.6986e-01, -1.2192e+00, -3.4710e-15, -1.0290e-01,  1.7228e-09,
-         9.2957e-08,  5.7470e-11,  0.0000e+00, -3.2945e-01,  6.9400e-01,
-         2.3770e-01, -7.6760e-09, -1.2133e-01,  1.4144e-11,  3.4594e-09,
-        -1.9025e-01, -4.8812e-07,  5.1451e-01,  1.6420e-06, -9.0042e-13,
-        -2.9134e-03,  3.2089e-01, -1.0767e-01, -9.7170e-02, -3.2996e-04,
-         1.8739e-16, -7.3649e-11,  2.6487e-14,  6.4373e-01, -3.8253e-06,
-         7.0752e-14,  1.1892e-01,  5.9446e-02, -4.2944e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1827,  0.0000, -0.2307, -0.2877,  2.3592,  1.0047,  0.0000,  0.0000,
-         0.5367,  0.0207,  0.3161,  0.0000, -0.0894, -0.4310,  0.0000,  0.0000,
-         0.0000, -0.1094,  0.0000,  0.0000, -0.1176,  0.2261,  0.0000, -0.7071,
-         0.0445,  0.0000,  0.0900, -0.1277,  0.3041,  0.0000, -0.7699, -1.2192,
-         0.0000, -0.1029,  0.0000,  0.0000,  0.0000,  0.0000, -0.3295,  0.6940,
-         0.2377,  0.0000, -0.1213,  0.0000,  0.0000, -0.1902,  0.0000,  0.5145,
-         0.0000,  0.0000,  0.0000,  0.3209, -0.1077, -0.0972,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6437,  0.0000,  0.0000,  0.1189,  0.0000, -0.4294],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1827,  0.0000, -0.2307, -0.2877,  2.3592,  1.0047,  0.0000,  0.0000,
-         0.5367,  0.0207,  0.3161,  0.0000, -0.0894, -0.4310,  0.0000,  0.0000,
-         0.0000, -0.1094,  0.0000,  0.0000, -0.1176,  0.2261,  0.0000, -0.7071,
-         0.0445,  0.0000,  0.0900, -0.1277,  0.3041,  0.0000, -0.7699, -1.2192,
-         0.0000, -0.1029,  0.0000,  0.0000,  0.0000,  0.0000, -0.3295,  0.6940,
-         0.2377,  0.0000, -0.1213,  0.0000,  0.0000, -0.1902,  0.0000,  0.5145,
-         0.0000,  0.0000,  0.0000,  0.3209, -0.1077, -0.0972,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6437,  0.0000,  0.0000,  0.1189,  0.0000, -0.4294],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.5192e-01, -2.1014e-05, -2.0999e-01, -2.1881e-01,  2.3565e+00,
-         1.0079e+00,  1.2720e-14, -1.0044e-07,  5.8410e-01, -2.8010e-03,
-         2.5244e-01,  2.2867e-06, -1.4856e-02, -4.3181e-01, -2.2811e-11,
-        -1.6954e-08, -1.6677e-09, -2.4564e-02, -3.8572e-12,  4.6376e-08,
-        -6.0285e-02,  2.2609e-01, -8.1912e-12, -7.4030e-01,  3.8425e-02,
-         6.6481e-08,  1.4288e-01, -4.3570e-02,  3.1212e-01,  1.1913e-11,
-        -7.5299e-01, -1.2319e+00, -3.0495e-15, -1.1795e-01,  1.5136e-09,
-         8.1669e-08,  5.0491e-11,  0.0000e+00, -2.6142e-01,  6.7584e-01,
-         1.8574e-01, -6.7439e-09, -1.1088e-01,  1.2426e-11,  3.0393e-09,
-        -1.5506e-01, -4.2884e-07,  5.3022e-01,  1.4426e-06, -7.9108e-13,
-        -2.5596e-03,  4.0743e-01, -1.1330e-01, -9.4854e-02, -2.8989e-04,
-         1.6464e-16, -6.4705e-11,  2.3270e-14,  6.7547e-01, -3.3607e-06,
-         6.2161e-14,  2.4534e-01,  5.2227e-02, -4.0746e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2519,  0.0000, -0.2100, -0.2188,  2.3565,  1.0079,  0.0000,  0.0000,
-         0.5841, -0.0028,  0.2524,  0.0000, -0.0149, -0.4318,  0.0000,  0.0000,
-         0.0000, -0.0246,  0.0000,  0.0000, -0.0603,  0.2261,  0.0000, -0.7403,
-         0.0384,  0.0000,  0.1429, -0.0436,  0.3121,  0.0000, -0.7530, -1.2319,
-         0.0000, -0.1179,  0.0000,  0.0000,  0.0000,  0.0000, -0.2614,  0.6758,
-         0.1857,  0.0000, -0.1109,  0.0000,  0.0000, -0.1551,  0.0000,  0.5302,
-         0.0000,  0.0000,  0.0000,  0.4074, -0.1133, -0.0949,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6755,  0.0000,  0.0000,  0.2453,  0.0000, -0.4075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2519,  0.0000, -0.2100, -0.2188,  2.3565,  1.0079,  0.0000,  0.0000,
-         0.5841, -0.0028,  0.2524,  0.0000, -0.0149, -0.4318,  0.0000,  0.0000,
-         0.0000, -0.0246,  0.0000,  0.0000, -0.0603,  0.2261,  0.0000, -0.7403,
-         0.0384,  0.0000,  0.1429, -0.0436,  0.3121,  0.0000, -0.7530, -1.2319,
-         0.0000, -0.1179,  0.0000,  0.0000,  0.0000,  0.0000, -0.2614,  0.6758,
-         0.1857,  0.0000, -0.1109,  0.0000,  0.0000, -0.1551,  0.0000,  0.5302,
-         0.0000,  0.0000,  0.0000,  0.4074, -0.1133, -0.0949,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6755,  0.0000,  0.0000,  0.2453,  0.0000, -0.4075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0708e-01, -1.8470e-05, -1.9715e-01, -1.6019e-01,  2.3552e+00,
-         1.0071e+00,  1.1180e-14, -8.8278e-08,  6.0517e-01,  1.5929e-02,
-         2.2271e-01,  2.0098e-06,  2.5622e-02, -4.4087e-01, -2.0049e-11,
-        -1.4901e-08, -1.4657e-09,  7.4419e-02, -3.3902e-12,  4.0761e-08,
-        -3.8309e-02,  2.4293e-01, -7.1994e-12, -7.6226e-01,  4.4067e-02,
-         5.8431e-08,  1.7541e-01,  1.9587e-02,  3.0306e-01,  1.0471e-11,
-        -7.1693e-01, -1.2452e+00, -2.6803e-15, -1.2406e-01,  1.3303e-09,
-         7.1781e-08,  4.4377e-11,  0.0000e+00, -2.1211e-01,  6.7650e-01,
-         1.6891e-01, -5.9273e-09, -1.0333e-01,  1.0922e-11,  2.6713e-09,
-        -1.3797e-01, -3.7692e-07,  5.3680e-01,  1.2679e-06, -6.9530e-13,
-        -2.2497e-03,  4.3727e-01, -1.1832e-01, -5.9126e-02, -2.5479e-04,
-         1.4470e-16, -5.6871e-11,  2.0453e-14,  6.8783e-01, -2.9538e-06,
-         5.4634e-14,  3.2991e-01,  4.5903e-02, -3.6408e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3071,  0.0000, -0.1972, -0.1602,  2.3552,  1.0071,  0.0000,  0.0000,
-         0.6052,  0.0159,  0.2227,  0.0000,  0.0256, -0.4409,  0.0000,  0.0000,
-         0.0000,  0.0744,  0.0000,  0.0000, -0.0383,  0.2429,  0.0000, -0.7623,
-         0.0441,  0.0000,  0.1754,  0.0196,  0.3031,  0.0000, -0.7169, -1.2452,
-         0.0000, -0.1241,  0.0000,  0.0000,  0.0000,  0.0000, -0.2121,  0.6765,
-         0.1689,  0.0000, -0.1033,  0.0000,  0.0000, -0.1380,  0.0000,  0.5368,
-         0.0000,  0.0000,  0.0000,  0.4373, -0.1183, -0.0591,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6878,  0.0000,  0.0000,  0.3299,  0.0000, -0.3641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3071,  0.0000, -0.1972, -0.1602,  2.3552,  1.0071,  0.0000,  0.0000,
-         0.6052,  0.0159,  0.2227,  0.0000,  0.0256, -0.4409,  0.0000,  0.0000,
-         0.0000,  0.0744,  0.0000,  0.0000, -0.0383,  0.2429,  0.0000, -0.7623,
-         0.0441,  0.0000,  0.1754,  0.0196,  0.3031,  0.0000, -0.7169, -1.2452,
-         0.0000, -0.1241,  0.0000,  0.0000,  0.0000,  0.0000, -0.2121,  0.6765,
-         0.1689,  0.0000, -0.1033,  0.0000,  0.0000, -0.1380,  0.0000,  0.5368,
-         0.0000,  0.0000,  0.0000,  0.4373, -0.1183, -0.0591,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6878,  0.0000,  0.0000,  0.3299,  0.0000, -0.3641],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7227e-01, -1.6240e-05, -1.7951e-01, -1.0737e-01,  2.3520e+00,
-         1.0009e+00,  9.8302e-15, -7.7621e-08,  5.9401e-01,  8.7087e-02,
-         2.5269e-01,  1.7672e-06,  1.7841e-02, -4.5200e-01, -1.7628e-11,
-        -1.3102e-08, -1.2888e-09,  1.4896e-01, -2.9809e-12,  3.5840e-08,
-        -5.7395e-02,  2.7750e-01, -6.3303e-12, -7.8236e-01,  6.1801e-02,
-         5.1377e-08,  2.0872e-01,  5.3016e-02,  2.6825e-01,  9.2065e-12,
-        -6.8231e-01, -1.2583e+00, -2.3567e-15, -1.4287e-01,  1.1697e-09,
-         6.3115e-08,  3.9020e-11,  0.0000e+00, -1.3304e-01,  6.7933e-01,
-         1.6668e-01, -5.2118e-09, -9.0407e-02,  9.6031e-12,  2.3488e-09,
-        -1.2710e-01, -3.3141e-07,  5.2418e-01,  1.1149e-06, -6.1136e-13,
-        -1.9781e-03,  4.2775e-01, -9.6875e-02,  1.1704e-02, -2.2403e-04,
-         1.2723e-16, -5.0005e-11,  1.7984e-14,  6.9138e-01, -2.5972e-06,
-         4.8038e-14,  3.8778e-01,  4.0362e-02, -3.0653e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3723,  0.0000, -0.1795, -0.1074,  2.3520,  1.0009,  0.0000,  0.0000,
-         0.5940,  0.0871,  0.2527,  0.0000,  0.0178, -0.4520,  0.0000,  0.0000,
-         0.0000,  0.1490,  0.0000,  0.0000, -0.0574,  0.2775,  0.0000, -0.7824,
-         0.0618,  0.0000,  0.2087,  0.0530,  0.2682,  0.0000, -0.6823, -1.2583,
-         0.0000, -0.1429,  0.0000,  0.0000,  0.0000,  0.0000, -0.1330,  0.6793,
-         0.1667,  0.0000, -0.0904,  0.0000,  0.0000, -0.1271,  0.0000,  0.5242,
-         0.0000,  0.0000,  0.0000,  0.4278, -0.0969,  0.0117,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6914,  0.0000,  0.0000,  0.3878,  0.0000, -0.3065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3723,  0.0000, -0.1795, -0.1074,  2.3520,  1.0009,  0.0000,  0.0000,
-         0.5940,  0.0871,  0.2527,  0.0000,  0.0178, -0.4520,  0.0000,  0.0000,
-         0.0000,  0.1490,  0.0000,  0.0000, -0.0574,  0.2775,  0.0000, -0.7824,
-         0.0618,  0.0000,  0.2087,  0.0530,  0.2682,  0.0000, -0.6823, -1.2583,
-         0.0000, -0.1429,  0.0000,  0.0000,  0.0000,  0.0000, -0.1330,  0.6793,
-         0.1667,  0.0000, -0.0904,  0.0000,  0.0000, -0.1271,  0.0000,  0.5242,
-         0.0000,  0.0000,  0.0000,  0.4278, -0.0969,  0.0117,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6914,  0.0000,  0.0000,  0.3878,  0.0000, -0.3065],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1533e-01, -1.4285e-05, -1.7200e-01, -1.0260e-01,  2.3492e+00,
-         9.8657e-01,  8.6470e-15, -6.8278e-08,  5.4818e-01,  1.5997e-01,
-         3.3378e-01,  1.5545e-06, -7.0387e-02, -4.7939e-01, -1.5506e-11,
-        -1.1525e-08, -1.1337e-09,  1.3386e-01, -2.6221e-12,  3.1526e-08,
-        -1.2474e-01,  2.8858e-01, -5.5683e-12, -7.8400e-01,  6.0521e-02,
-         4.5193e-08,  2.2652e-01,  1.2839e-02,  1.9591e-01,  8.0983e-12,
-        -6.3596e-01, -1.2680e+00, -2.0730e-15, -1.4949e-01,  1.0289e-09,
-         5.5518e-08,  3.4323e-11,  0.0000e+00, -7.3279e-02,  6.7040e-01,
-         1.7007e-01, -4.5844e-09, -7.4071e-02,  8.4472e-12,  2.0661e-09,
-        -1.4813e-01, -2.9152e-07,  4.9151e-01,  9.8068e-07, -5.3777e-13,
-        -1.7400e-03,  3.6991e-01, -6.5993e-02, -8.8378e-03, -1.9707e-04,
-         1.1192e-16, -4.3986e-11,  1.5819e-14,  6.9313e-01, -2.2846e-06,
-         4.2256e-14,  4.0053e-01,  3.5503e-02, -2.5000e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4153,  0.0000, -0.1720, -0.1026,  2.3492,  0.9866,  0.0000,  0.0000,
-         0.5482,  0.1600,  0.3338,  0.0000, -0.0704, -0.4794,  0.0000,  0.0000,
-         0.0000,  0.1339,  0.0000,  0.0000, -0.1247,  0.2886,  0.0000, -0.7840,
-         0.0605,  0.0000,  0.2265,  0.0128,  0.1959,  0.0000, -0.6360, -1.2680,
-         0.0000, -0.1495,  0.0000,  0.0000,  0.0000,  0.0000, -0.0733,  0.6704,
-         0.1701,  0.0000, -0.0741,  0.0000,  0.0000, -0.1481,  0.0000,  0.4915,
-         0.0000,  0.0000,  0.0000,  0.3699, -0.0660, -0.0088,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6931,  0.0000,  0.0000,  0.4005,  0.0000, -0.2500],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4153,  0.0000, -0.1720, -0.1026,  2.3492,  0.9866,  0.0000,  0.0000,
-         0.5482,  0.1600,  0.3338,  0.0000, -0.0704, -0.4794,  0.0000,  0.0000,
-         0.0000,  0.1339,  0.0000,  0.0000, -0.1247,  0.2886,  0.0000, -0.7840,
-         0.0605,  0.0000,  0.2265,  0.0128,  0.1959,  0.0000, -0.6360, -1.2680,
-         0.0000, -0.1495,  0.0000,  0.0000,  0.0000,  0.0000, -0.0733,  0.6704,
-         0.1701,  0.0000, -0.0741,  0.0000,  0.0000, -0.1481,  0.0000,  0.4915,
-         0.0000,  0.0000,  0.0000,  0.3699, -0.0660, -0.0088,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6931,  0.0000,  0.0000,  0.4005,  0.0000, -0.2500],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6677e-01, -1.2571e-05, -1.5594e-01, -1.0024e-01,  2.3504e+00,
-         9.7484e-01,  7.6092e-15, -6.0084e-08,  4.9806e-01,  1.6105e-01,
-         2.4589e-01,  1.3679e-06, -1.5337e-01, -4.7843e-01, -1.3646e-11,
-        -1.0142e-08, -9.9761e-10,  1.3810e-01, -2.3074e-12,  2.7742e-08,
-        -1.0715e-01,  2.5661e-01, -4.9000e-12, -7.8881e-01,  5.4050e-02,
-         3.9769e-08,  2.0259e-01, -2.7546e-02,  1.4280e-01,  7.1264e-12,
-        -5.8620e-01, -1.2694e+00, -1.8242e-15, -1.3681e-01,  9.0545e-10,
-         4.8855e-08,  3.0204e-11,  0.0000e+00, -1.5987e-02,  7.0753e-01,
-         1.6494e-01, -4.0342e-09, -5.1281e-02,  7.4334e-12,  1.8181e-09,
-        -1.7663e-01, -2.5653e-07,  4.9033e-01,  8.6298e-07, -4.7323e-13,
-        -1.5311e-03,  3.1824e-01, -1.3724e-01, -2.9192e-02, -1.7342e-04,
-         9.8487e-17, -3.8707e-11,  1.3920e-14,  6.9884e-01, -2.0104e-06,
-         3.7185e-14,  3.8833e-01,  3.1242e-02, -2.5436e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4668,  0.0000, -0.1559, -0.1002,  2.3504,  0.9748,  0.0000,  0.0000,
-         0.4981,  0.1610,  0.2459,  0.0000, -0.1534, -0.4784,  0.0000,  0.0000,
-         0.0000,  0.1381,  0.0000,  0.0000, -0.1071,  0.2566,  0.0000, -0.7888,
-         0.0541,  0.0000,  0.2026, -0.0275,  0.1428,  0.0000, -0.5862, -1.2694,
-         0.0000, -0.1368,  0.0000,  0.0000,  0.0000,  0.0000, -0.0160,  0.7075,
-         0.1649,  0.0000, -0.0513,  0.0000,  0.0000, -0.1766,  0.0000,  0.4903,
-         0.0000,  0.0000,  0.0000,  0.3182, -0.1372, -0.0292,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6988,  0.0000,  0.0000,  0.3883,  0.0000, -0.2544],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4668,  0.0000, -0.1559, -0.1002,  2.3504,  0.9748,  0.0000,  0.0000,
-         0.4981,  0.1610,  0.2459,  0.0000, -0.1534, -0.4784,  0.0000,  0.0000,
-         0.0000,  0.1381,  0.0000,  0.0000, -0.1071,  0.2566,  0.0000, -0.7888,
-         0.0541,  0.0000,  0.2026, -0.0275,  0.1428,  0.0000, -0.5862, -1.2694,
-         0.0000, -0.1368,  0.0000,  0.0000,  0.0000,  0.0000, -0.0160,  0.7075,
-         0.1649,  0.0000, -0.0513,  0.0000,  0.0000, -0.1766,  0.0000,  0.4903,
-         0.0000,  0.0000,  0.0000,  0.3182, -0.1372, -0.0292,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6988,  0.0000,  0.0000,  0.3883,  0.0000, -0.2544],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0352e-01, -1.1066e-05, -1.3806e-01, -1.2082e-01,  2.3510e+00,
-         9.6690e-01,  6.6987e-15, -5.2894e-08,  4.1665e-01,  1.8644e-01,
-         2.5002e-01,  1.2042e-06, -2.2726e-01, -4.8709e-01, -1.2013e-11,
-        -8.9283e-09, -8.7824e-10,  9.8670e-02, -2.0313e-12,  2.4423e-08,
-        -1.2151e-01,  2.0525e-01, -4.3137e-12, -7.7584e-01,  2.8900e-02,
-         3.5011e-08,  1.9878e-01, -9.7896e-02,  1.0834e-01,  6.2737e-12,
-        -5.5776e-01, -1.2660e+00, -1.6059e-15, -1.3382e-01,  7.9711e-10,
-         4.3009e-08,  2.6590e-11,  0.0000e+00,  3.2679e-02,  7.2355e-01,
-         1.5632e-01, -3.5515e-09, -2.1112e-02,  6.5440e-12,  1.6006e-09,
-        -2.0694e-01, -2.2584e-07,  4.7992e-01,  7.5972e-07, -4.1660e-13,
-        -1.3479e-03,  2.1605e-01, -1.7241e-01, -9.3092e-02, -1.5267e-04,
-         8.6702e-17, -3.4075e-11,  1.2255e-14,  7.0635e-01, -1.7699e-06,
-         3.2735e-14,  3.5260e-01,  2.7504e-02, -2.3832e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5035,  0.0000, -0.1381, -0.1208,  2.3510,  0.9669,  0.0000,  0.0000,
-         0.4167,  0.1864,  0.2500,  0.0000, -0.2273, -0.4871,  0.0000,  0.0000,
-         0.0000,  0.0987,  0.0000,  0.0000, -0.1215,  0.2052,  0.0000, -0.7758,
-         0.0289,  0.0000,  0.1988, -0.0979,  0.1083,  0.0000, -0.5578, -1.2660,
-         0.0000, -0.1338,  0.0000,  0.0000,  0.0000,  0.0000,  0.0327,  0.7235,
-         0.1563,  0.0000, -0.0211,  0.0000,  0.0000, -0.2069,  0.0000,  0.4799,
-         0.0000,  0.0000,  0.0000,  0.2161, -0.1724, -0.0931,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7063,  0.0000,  0.0000,  0.3526,  0.0000, -0.2383],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5035,  0.0000, -0.1381, -0.1208,  2.3510,  0.9669,  0.0000,  0.0000,
-         0.4167,  0.1864,  0.2500,  0.0000, -0.2273, -0.4871,  0.0000,  0.0000,
-         0.0000,  0.0987,  0.0000,  0.0000, -0.1215,  0.2052,  0.0000, -0.7758,
-         0.0289,  0.0000,  0.1988, -0.0979,  0.1083,  0.0000, -0.5578, -1.2660,
-         0.0000, -0.1338,  0.0000,  0.0000,  0.0000,  0.0000,  0.0327,  0.7235,
-         0.1563,  0.0000, -0.0211,  0.0000,  0.0000, -0.2069,  0.0000,  0.4799,
-         0.0000,  0.0000,  0.0000,  0.2161, -0.1724, -0.0931,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7063,  0.0000,  0.0000,  0.3526,  0.0000, -0.2383],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1706e-01, -9.7463e-06, -1.2696e-01, -1.4366e-01,  2.3517e+00,
-         9.5595e-01,  5.8996e-15, -4.6584e-08,  3.4555e-01,  2.1688e-01,
-         1.9804e-01,  1.0606e-06, -2.8907e-01, -4.8232e-01, -1.0580e-11,
-        -7.8632e-09, -7.7346e-10,  5.6382e-02, -1.7890e-12,  2.1509e-08,
-        -1.1559e-01,  1.5015e-01, -3.7991e-12, -7.5118e-01,  1.3157e-02,
-         3.0834e-08,  2.0103e-01, -1.5387e-01,  9.3221e-02,  5.5253e-12,
-        -5.3771e-01, -1.2635e+00, -1.4144e-15, -1.2594e-01,  7.0202e-10,
-         3.7878e-08,  2.3418e-11,  0.0000e+00,  1.3018e-01,  7.2430e-01,
-         1.4567e-01, -3.1278e-09,  8.7943e-03,  5.7633e-12,  1.4096e-09,
-        -2.2328e-01, -1.9890e-07,  4.9053e-01,  6.6909e-07, -3.6690e-13,
-        -1.1871e-03,  1.1435e-01, -2.0203e-01, -1.2635e-01, -1.3445e-04,
-         7.6359e-17, -3.0010e-11,  1.0793e-14,  6.9683e-01, -1.5587e-06,
-         2.8830e-14,  3.3215e-01,  2.4223e-02, -2.2867e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5171,  0.0000, -0.1270, -0.1437,  2.3517,  0.9559,  0.0000,  0.0000,
-         0.3455,  0.2169,  0.1980,  0.0000, -0.2891, -0.4823,  0.0000,  0.0000,
-         0.0000,  0.0564,  0.0000,  0.0000, -0.1156,  0.1502,  0.0000, -0.7512,
-         0.0132,  0.0000,  0.2010, -0.1539,  0.0932,  0.0000, -0.5377, -1.2635,
-         0.0000, -0.1259,  0.0000,  0.0000,  0.0000,  0.0000,  0.1302,  0.7243,
-         0.1457,  0.0000,  0.0088,  0.0000,  0.0000, -0.2233,  0.0000,  0.4905,
-         0.0000,  0.0000,  0.0000,  0.1143, -0.2020, -0.1264,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6968,  0.0000,  0.0000,  0.3322,  0.0000, -0.2287],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5171,  0.0000, -0.1270, -0.1437,  2.3517,  0.9559,  0.0000,  0.0000,
-         0.3455,  0.2169,  0.1980,  0.0000, -0.2891, -0.4823,  0.0000,  0.0000,
-         0.0000,  0.0564,  0.0000,  0.0000, -0.1156,  0.1502,  0.0000, -0.7512,
-         0.0132,  0.0000,  0.2010, -0.1539,  0.0932,  0.0000, -0.5377, -1.2635,
-         0.0000, -0.1259,  0.0000,  0.0000,  0.0000,  0.0000,  0.1302,  0.7243,
-         0.1457,  0.0000,  0.0088,  0.0000,  0.0000, -0.2233,  0.0000,  0.4905,
-         0.0000,  0.0000,  0.0000,  0.1143, -0.2020, -0.1264,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6968,  0.0000,  0.0000,  0.3322,  0.0000, -0.2287],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2252e-01, -8.5871e-06, -9.9585e-02, -1.2198e-01,  2.3513e+00,
-         9.4473e-01,  5.1979e-15, -4.1043e-08,  2.9694e-01,  2.5381e-01,
-         1.0537e-01,  9.3444e-07, -3.2595e-01, -4.5994e-01, -9.3213e-12,
-        -6.9279e-09, -6.8147e-10,  5.1780e-02, -1.5762e-12,  1.8951e-08,
-        -5.6011e-02,  1.3185e-01, -3.3472e-12, -7.4366e-01,  2.8747e-02,
-         2.7166e-08,  2.1261e-01, -1.4032e-01,  8.2189e-02,  4.8681e-12,
-        -5.4267e-01, -1.2573e+00, -1.2461e-15, -8.5196e-02,  6.1852e-10,
-         3.3373e-08,  2.0632e-11,  0.0000e+00,  2.4013e-01,  6.9613e-01,
-         1.3280e-01, -2.7558e-09,  3.7309e-02,  5.0778e-12,  1.2420e-09,
-        -2.5367e-01, -1.7524e-07,  4.9054e-01,  5.8951e-07, -3.2326e-13,
-        -1.0459e-03,  1.1721e-01, -2.5591e-01, -7.8590e-02, -1.1846e-04,
-         6.7277e-17, -2.6441e-11,  9.5091e-15,  7.0517e-01, -1.3733e-06,
-         2.5401e-14,  3.3370e-01,  2.1342e-02, -2.1430e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5225,  0.0000, -0.0996, -0.1220,  2.3513,  0.9447,  0.0000,  0.0000,
-         0.2969,  0.2538,  0.1054,  0.0000, -0.3259, -0.4599,  0.0000,  0.0000,
-         0.0000,  0.0518,  0.0000,  0.0000, -0.0560,  0.1319,  0.0000, -0.7437,
-         0.0287,  0.0000,  0.2126, -0.1403,  0.0822,  0.0000, -0.5427, -1.2573,
-         0.0000, -0.0852,  0.0000,  0.0000,  0.0000,  0.0000,  0.2401,  0.6961,
-         0.1328,  0.0000,  0.0373,  0.0000,  0.0000, -0.2537,  0.0000,  0.4905,
-         0.0000,  0.0000,  0.0000,  0.1172, -0.2559, -0.0786,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7052,  0.0000,  0.0000,  0.3337,  0.0000, -0.2143],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5225,  0.0000, -0.0996, -0.1220,  2.3513,  0.9447,  0.0000,  0.0000,
-         0.2969,  0.2538,  0.1054,  0.0000, -0.3259, -0.4599,  0.0000,  0.0000,
-         0.0000,  0.0518,  0.0000,  0.0000, -0.0560,  0.1319,  0.0000, -0.7437,
-         0.0287,  0.0000,  0.2126, -0.1403,  0.0822,  0.0000, -0.5427, -1.2573,
-         0.0000, -0.0852,  0.0000,  0.0000,  0.0000,  0.0000,  0.2401,  0.6961,
-         0.1328,  0.0000,  0.0373,  0.0000,  0.0000, -0.2537,  0.0000,  0.4905,
-         0.0000,  0.0000,  0.0000,  0.1172, -0.2559, -0.0786,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7052,  0.0000,  0.0000,  0.3337,  0.0000, -0.2143],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1172e-01, -7.5688e-06, -6.6259e-02, -8.1769e-02,  2.3482e+00,
-         9.2472e-01,  4.5815e-15, -3.6176e-08,  2.6120e-01,  2.6335e-01,
-         1.2724e-02,  8.2363e-07, -3.3312e-01, -4.4191e-01, -8.2160e-12,
-        -6.1064e-09, -6.0066e-10,  6.2712e-02, -1.3893e-12,  1.6704e-08,
-         6.1756e-02,  1.3277e-01, -2.9503e-12, -7.1028e-01,  4.9281e-02,
-         2.3945e-08,  1.9987e-01, -6.0068e-02,  1.0444e-01,  4.2908e-12,
-        -5.6810e-01, -1.2497e+00, -1.0984e-15, -5.9465e-02,  5.4518e-10,
-         2.9416e-08,  1.8186e-11,  0.0000e+00,  2.8695e-01,  6.5724e-01,
-         1.3701e-01, -2.4290e-09,  8.2626e-02,  4.4757e-12,  1.0947e-09,
-        -2.9378e-01, -1.5446e-07,  4.6593e-01,  5.1960e-07, -2.8493e-13,
-        -9.2191e-04,  1.8246e-01, -3.2119e-01, -7.2704e-02, -1.0441e-04,
-         5.9299e-17, -2.3306e-11,  8.3815e-15,  7.2027e-01, -1.2105e-06,
-         2.2389e-14,  3.3204e-01,  1.8811e-02, -1.9969e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5117,  0.0000, -0.0663, -0.0818,  2.3482,  0.9247,  0.0000,  0.0000,
-         0.2612,  0.2633,  0.0127,  0.0000, -0.3331, -0.4419,  0.0000,  0.0000,
-         0.0000,  0.0627,  0.0000,  0.0000,  0.0618,  0.1328,  0.0000, -0.7103,
-         0.0493,  0.0000,  0.1999, -0.0601,  0.1044,  0.0000, -0.5681, -1.2497,
-         0.0000, -0.0595,  0.0000,  0.0000,  0.0000,  0.0000,  0.2870,  0.6572,
-         0.1370,  0.0000,  0.0826,  0.0000,  0.0000, -0.2938,  0.0000,  0.4659,
-         0.0000,  0.0000,  0.0000,  0.1825, -0.3212, -0.0727,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7203,  0.0000,  0.0000,  0.3320,  0.0000, -0.1997],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5117,  0.0000, -0.0663, -0.0818,  2.3482,  0.9247,  0.0000,  0.0000,
-         0.2612,  0.2633,  0.0127,  0.0000, -0.3331, -0.4419,  0.0000,  0.0000,
-         0.0000,  0.0627,  0.0000,  0.0000,  0.0618,  0.1328,  0.0000, -0.7103,
-         0.0493,  0.0000,  0.1999, -0.0601,  0.1044,  0.0000, -0.5681, -1.2497,
-         0.0000, -0.0595,  0.0000,  0.0000,  0.0000,  0.0000,  0.2870,  0.6572,
-         0.1370,  0.0000,  0.0826,  0.0000,  0.0000, -0.2938,  0.0000,  0.4659,
-         0.0000,  0.0000,  0.0000,  0.1825, -0.3212, -0.0727,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7203,  0.0000,  0.0000,  0.3320,  0.0000, -0.1997],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0905e-01, -6.6740e-06, -4.1537e-02, -4.2190e-02,  2.3451e+00,
-         9.0445e-01,  4.0399e-15, -3.1900e-08,  2.3241e-01,  2.8265e-01,
-         4.6823e-03,  7.2626e-07, -3.3368e-01, -4.1274e-01, -7.2447e-12,
-        -5.3845e-09, -5.2965e-10,  7.0325e-02, -1.2250e-12,  1.4729e-08,
-         1.6560e-01,  1.3888e-01, -2.6015e-12, -6.5201e-01,  8.0529e-02,
-         2.1114e-08,  1.8959e-01,  3.2678e-02,  1.3194e-01,  3.7836e-12,
-        -6.1452e-01, -1.2450e+00, -9.6852e-16, -4.2310e-02,  4.8072e-10,
-         2.5938e-08,  1.6036e-11,  0.0000e+00,  3.1728e-01,  6.2498e-01,
-         1.6149e-01, -2.1419e-09,  1.1774e-01,  3.9466e-12,  9.6529e-10,
-        -3.2236e-01, -1.3620e-07,  4.4197e-01,  4.5818e-07, -2.5125e-13,
-        -8.1292e-04,  2.3801e-01, -3.5430e-01, -8.3684e-02, -9.2070e-05,
-         5.2289e-17, -2.0550e-11,  7.3907e-15,  7.3011e-01, -1.0674e-06,
-         1.9742e-14,  3.4774e-01,  1.6587e-02, -1.7754e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5090,  0.0000, -0.0415, -0.0422,  2.3451,  0.9045,  0.0000,  0.0000,
-         0.2324,  0.2826,  0.0047,  0.0000, -0.3337, -0.4127,  0.0000,  0.0000,
-         0.0000,  0.0703,  0.0000,  0.0000,  0.1656,  0.1389,  0.0000, -0.6520,
-         0.0805,  0.0000,  0.1896,  0.0327,  0.1319,  0.0000, -0.6145, -1.2450,
-         0.0000, -0.0423,  0.0000,  0.0000,  0.0000,  0.0000,  0.3173,  0.6250,
-         0.1615,  0.0000,  0.1177,  0.0000,  0.0000, -0.3224,  0.0000,  0.4420,
-         0.0000,  0.0000,  0.0000,  0.2380, -0.3543, -0.0837,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7301,  0.0000,  0.0000,  0.3477,  0.0000, -0.1775],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5090,  0.0000, -0.0415, -0.0422,  2.3451,  0.9045,  0.0000,  0.0000,
-         0.2324,  0.2826,  0.0047,  0.0000, -0.3337, -0.4127,  0.0000,  0.0000,
-         0.0000,  0.0703,  0.0000,  0.0000,  0.1656,  0.1389,  0.0000, -0.6520,
-         0.0805,  0.0000,  0.1896,  0.0327,  0.1319,  0.0000, -0.6145, -1.2450,
-         0.0000, -0.0423,  0.0000,  0.0000,  0.0000,  0.0000,  0.3173,  0.6250,
-         0.1615,  0.0000,  0.1177,  0.0000,  0.0000, -0.3224,  0.0000,  0.4420,
-         0.0000,  0.0000,  0.0000,  0.2380, -0.3543, -0.0837,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7301,  0.0000,  0.0000,  0.3477,  0.0000, -0.1775],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0374e-01, -5.8874e-06, -2.9388e-02, -9.9043e-03,  2.3427e+00,
-         8.8646e-01,  3.5638e-15, -2.8140e-08,  1.7854e-01,  2.9530e-01,
-         3.3002e-02,  6.4066e-07, -3.5830e-01, -3.8373e-01, -6.3908e-12,
-        -4.7499e-09, -4.6722e-10,  7.2009e-02, -1.0807e-12,  1.2993e-08,
-         2.2293e-01,  1.5603e-01, -2.2949e-12, -6.2750e-01,  1.1240e-01,
-         1.8626e-08,  2.0123e-01,  1.0313e-01,  1.4659e-01,  3.3376e-12,
-        -6.3825e-01, -1.2496e+00, -8.5437e-16, -2.9356e-02,  4.2407e-10,
-         2.2881e-08,  1.4146e-11,  0.0000e+00,  3.2662e-01,  5.7728e-01,
-         1.7883e-01, -1.8894e-09,  1.3003e-01,  3.4814e-12,  8.5152e-10,
-        -3.1574e-01, -1.2015e-07,  4.1138e-01,  4.0417e-07, -2.2163e-13,
-        -7.1711e-04,  2.5832e-01, -3.6877e-01, -5.9825e-02, -8.1218e-05,
-         4.6126e-17, -1.8128e-11,  6.5196e-15,  7.4360e-01, -9.4157e-07,
-         1.7415e-14,  3.5281e-01,  1.4632e-02, -1.3226e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5037,  0.0000, -0.0294, -0.0099,  2.3427,  0.8865,  0.0000,  0.0000,
-         0.1785,  0.2953,  0.0330,  0.0000, -0.3583, -0.3837,  0.0000,  0.0000,
-         0.0000,  0.0720,  0.0000,  0.0000,  0.2229,  0.1560,  0.0000, -0.6275,
-         0.1124,  0.0000,  0.2012,  0.1031,  0.1466,  0.0000, -0.6382, -1.2496,
-         0.0000, -0.0294,  0.0000,  0.0000,  0.0000,  0.0000,  0.3266,  0.5773,
-         0.1788,  0.0000,  0.1300,  0.0000,  0.0000, -0.3157,  0.0000,  0.4114,
-         0.0000,  0.0000,  0.0000,  0.2583, -0.3688, -0.0598,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7436,  0.0000,  0.0000,  0.3528,  0.0000, -0.1323],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5037,  0.0000, -0.0294, -0.0099,  2.3427,  0.8865,  0.0000,  0.0000,
-         0.1785,  0.2953,  0.0330,  0.0000, -0.3583, -0.3837,  0.0000,  0.0000,
-         0.0000,  0.0720,  0.0000,  0.0000,  0.2229,  0.1560,  0.0000, -0.6275,
-         0.1124,  0.0000,  0.2012,  0.1031,  0.1466,  0.0000, -0.6382, -1.2496,
-         0.0000, -0.0294,  0.0000,  0.0000,  0.0000,  0.0000,  0.3266,  0.5773,
-         0.1788,  0.0000,  0.1300,  0.0000,  0.0000, -0.3157,  0.0000,  0.4114,
-         0.0000,  0.0000,  0.0000,  0.2583, -0.3688, -0.0598,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7436,  0.0000,  0.0000,  0.3528,  0.0000, -0.1323],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.0636e-01, -5.1956e-06,  2.3213e-02,  3.7369e-02,  2.3416e+00,
-         8.5972e-01,  3.1450e-15, -2.4833e-08,  1.2882e-01,  3.1971e-01,
-         1.6869e-01,  5.6539e-07, -3.7553e-01, -3.5003e-01, -5.6399e-12,
-        -4.1918e-09, -4.1233e-10,  4.0336e-02, -9.5368e-13,  1.1466e-08,
-         2.5255e-01,  1.3169e-01, -2.0253e-12, -6.5628e-01,  7.7795e-02,
-         1.6437e-08,  2.2244e-01,  1.5080e-01,  1.3718e-01,  2.9455e-12,
-        -6.9508e-01, -1.2479e+00, -7.5398e-16, -2.7531e-02,  3.7424e-10,
-         2.0192e-08,  1.2484e-11,  0.0000e+00,  3.4812e-01,  5.2528e-01,
-         1.7922e-01, -1.6674e-09,  1.7927e-01,  3.0724e-12,  7.5147e-10,
-        -3.6138e-01, -1.0603e-07,  3.6561e-01,  3.5668e-07, -1.9559e-13,
-        -6.3285e-04,  2.7168e-01, -3.5850e-01, -4.4886e-02, -7.1675e-05,
-         4.0706e-17, -1.5998e-11,  5.7535e-15,  7.3065e-01, -8.3093e-07,
-         1.5369e-14,  3.7831e-01,  1.2913e-02, -5.1602e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Sparsity at the end of epoch 1: 48.39%
-After Step tensor([ 0.5064,  0.0000,  0.0232,  0.0374,  2.3416,  0.8597,  0.0000,  0.0000,
-         0.1288,  0.3197,  0.1687,  0.0000, -0.3755, -0.3500,  0.0000,  0.0000,
-         0.0000,  0.0403,  0.0000,  0.0000,  0.2525,  0.1317,  0.0000, -0.6563,
-         0.0778,  0.0000,  0.2224,  0.1508,  0.1372,  0.0000, -0.6951, -1.2479,
-         0.0000, -0.0275,  0.0000,  0.0000,  0.0000,  0.0000,  0.3481,  0.5253,
-         0.1792,  0.0000,  0.1793,  0.0000,  0.0000, -0.3614,  0.0000,  0.3656,
-         0.0000,  0.0000,  0.0000,  0.2717, -0.3585, -0.0449,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7307,  0.0000,  0.0000,  0.3783,  0.0000, -0.0516],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5064,  0.0000,  0.0232,  0.0374,  2.3416,  0.8597,  0.0000,  0.0000,
-         0.1288,  0.3197,  0.1687,  0.0000, -0.3755, -0.3500,  0.0000,  0.0000,
-         0.0000,  0.0403,  0.0000,  0.0000,  0.2525,  0.1317,  0.0000, -0.6563,
-         0.0778,  0.0000,  0.2224,  0.1508,  0.1372,  0.0000, -0.6951, -1.2479,
-         0.0000, -0.0275,  0.0000,  0.0000,  0.0000,  0.0000,  0.3481,  0.5253,
-         0.1792,  0.0000,  0.1793,  0.0000,  0.0000, -0.3614,  0.0000,  0.3656,
-         0.0000,  0.0000,  0.0000,  0.2717, -0.3585, -0.0449,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7307,  0.0000,  0.0000,  0.3783,  0.0000, -0.0516],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1017e-01, -4.5870e-06,  5.5963e-02,  5.8613e-02,  2.3396e+00,
-         8.2662e-01,  2.7766e-15, -2.1925e-08,  8.4063e-02,  3.3392e-01,
-         2.4563e-01,  4.9916e-07, -4.0126e-01, -3.2466e-01, -4.9793e-12,
-        -3.7008e-09, -3.6403e-10, -3.2571e-02, -8.4197e-13,  1.0123e-08,
-         2.3642e-01,  7.5095e-02, -1.7880e-12, -6.9629e-01,  2.1137e-02,
-         1.4512e-08,  2.2716e-01,  1.6111e-01,  9.4385e-02,  2.6004e-12,
-        -7.3967e-01, -1.2517e+00, -6.6566e-16, -4.8594e-02,  3.3040e-10,
-         1.7827e-08,  1.1021e-11,  0.0000e+00,  3.6436e-01,  4.8578e-01,
-         1.7299e-01, -1.4721e-09,  1.9960e-01,  2.7125e-12,  6.6344e-10,
-        -3.6394e-01, -9.3610e-08,  3.3499e-01,  3.1490e-07, -1.7268e-13,
-        -5.5872e-04,  2.4021e-01, -3.1761e-01, -5.5102e-02, -6.3279e-05,
-         3.5938e-17, -1.4124e-11,  5.0796e-15,  7.2253e-01, -7.3360e-07,
-         1.3569e-14,  3.8404e-01,  1.1400e-02,  4.8220e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5102,  0.0000,  0.0560,  0.0586,  2.3396,  0.8266,  0.0000,  0.0000,
-         0.0841,  0.3339,  0.2456,  0.0000, -0.4013, -0.3247,  0.0000,  0.0000,
-         0.0000, -0.0326,  0.0000,  0.0000,  0.2364,  0.0751,  0.0000, -0.6963,
-         0.0211,  0.0000,  0.2272,  0.1611,  0.0944,  0.0000, -0.7397, -1.2517,
-         0.0000, -0.0486,  0.0000,  0.0000,  0.0000,  0.0000,  0.3644,  0.4858,
-         0.1730,  0.0000,  0.1996,  0.0000,  0.0000, -0.3639,  0.0000,  0.3350,
-         0.0000,  0.0000,  0.0000,  0.2402, -0.3176, -0.0551,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7225,  0.0000,  0.0000,  0.3840,  0.0000,  0.0048],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5102,  0.0000,  0.0560,  0.0586,  2.3396,  0.8266,  0.0000,  0.0000,
-         0.0841,  0.3339,  0.2456,  0.0000, -0.4013, -0.3247,  0.0000,  0.0000,
-         0.0000, -0.0326,  0.0000,  0.0000,  0.2364,  0.0751,  0.0000, -0.6963,
-         0.0211,  0.0000,  0.2272,  0.1611,  0.0944,  0.0000, -0.7397, -1.2517,
-         0.0000, -0.0486,  0.0000,  0.0000,  0.0000,  0.0000,  0.3644,  0.4858,
-         0.1730,  0.0000,  0.1996,  0.0000,  0.0000, -0.3639,  0.0000,  0.3350,
-         0.0000,  0.0000,  0.0000,  0.2402, -0.3176, -0.0551,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.7225,  0.0000,  0.0000,  0.3840,  0.0000,  0.0048],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2471e-01, -4.0514e-06,  1.0786e-01,  1.0295e-01,  2.3384e+00,
-         7.9501e-01,  2.4524e-15, -1.9364e-08,  6.1956e-02,  3.4676e-01,
-         2.3148e-01,  4.4087e-07, -4.1695e-01, -3.2150e-01, -4.3978e-12,
-        -3.2686e-09, -3.2152e-10, -9.1195e-02, -7.4364e-13,  8.9411e-09,
-         2.8330e-01,  2.8586e-02, -1.5792e-12, -7.1171e-01, -4.3213e-02,
-         1.2817e-08,  2.3696e-01,  2.1385e-01,  8.7035e-02,  2.2968e-12,
-        -7.7860e-01, -1.2473e+00, -5.8792e-16, -5.4487e-02,  2.9182e-10,
-         1.5745e-08,  9.7343e-12,  0.0000e+00,  4.0454e-01,  4.6507e-01,
-         1.4993e-01, -1.3002e-09,  2.3597e-01,  2.3957e-12,  5.8597e-10,
-        -3.8866e-01, -8.2678e-08,  3.0897e-01,  2.7813e-07, -1.5252e-13,
-        -4.9347e-04,  2.3813e-01, -3.0145e-01, -3.9031e-02, -5.5890e-05,
-         3.1741e-17, -1.2475e-11,  4.4864e-15,  6.9425e-01, -6.4793e-07,
-         1.1984e-14,  4.1229e-01,  1.0069e-02,  2.1657e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5247,  0.0000,  0.1079,  0.1030,  2.3384,  0.7950,  0.0000,  0.0000,
-         0.0620,  0.3468,  0.2315,  0.0000, -0.4169, -0.3215,  0.0000,  0.0000,
-         0.0000, -0.0912,  0.0000,  0.0000,  0.2833,  0.0286,  0.0000, -0.7117,
-        -0.0432,  0.0000,  0.2370,  0.2138,  0.0870,  0.0000, -0.7786, -1.2473,
-         0.0000, -0.0545,  0.0000,  0.0000,  0.0000,  0.0000,  0.4045,  0.4651,
-         0.1499,  0.0000,  0.2360,  0.0000,  0.0000, -0.3887,  0.0000,  0.3090,
-         0.0000,  0.0000,  0.0000,  0.2381, -0.3014, -0.0390,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6943,  0.0000,  0.0000,  0.4123,  0.0000,  0.0217],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5247,  0.0000,  0.1079,  0.1030,  2.3384,  0.7950,  0.0000,  0.0000,
-         0.0620,  0.3468,  0.2315,  0.0000, -0.4169, -0.3215,  0.0000,  0.0000,
-         0.0000, -0.0912,  0.0000,  0.0000,  0.2833,  0.0286,  0.0000, -0.7117,
-        -0.0432,  0.0000,  0.2370,  0.2138,  0.0870,  0.0000, -0.7786, -1.2473,
-         0.0000, -0.0545,  0.0000,  0.0000,  0.0000,  0.0000,  0.4045,  0.4651,
-         0.1499,  0.0000,  0.2360,  0.0000,  0.0000, -0.3887,  0.0000,  0.3090,
-         0.0000,  0.0000,  0.0000,  0.2381, -0.3014, -0.0390,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6943,  0.0000,  0.0000,  0.4123,  0.0000,  0.0217],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1716e-01, -3.5797e-06,  1.4272e-01,  1.1016e-01,  2.3386e+00,
-         7.8149e-01,  2.1669e-15, -1.7110e-08,  4.2773e-02,  3.2777e-01,
-         1.4361e-01,  3.8954e-07, -4.3792e-01, -3.2725e-01, -3.8858e-12,
-        -2.8881e-09, -2.8409e-10, -1.7170e-01, -6.5707e-13,  7.9002e-09,
-         3.0059e-01, -4.3050e-02, -1.3954e-12, -7.3913e-01, -1.1889e-01,
-         1.1325e-08,  2.3617e-01,  2.2207e-01,  4.8044e-02,  2.0294e-12,
-        -8.0718e-01, -1.2483e+00, -5.1948e-16, -7.7130e-02,  2.5785e-10,
-         1.3912e-08,  8.6011e-12,  0.0000e+00,  4.1691e-01,  4.4858e-01,
-         1.0628e-01, -1.1488e-09,  2.8512e-01,  2.1168e-12,  5.1775e-10,
-        -3.9264e-01, -7.3053e-08,  3.0944e-01,  2.4575e-07, -1.3476e-13,
-        -4.3602e-04,  1.7708e-01, -2.8940e-01, -1.6737e-02, -4.9383e-05,
-         2.8046e-17, -1.1023e-11,  3.9641e-15,  6.7380e-01, -5.7250e-07,
-         1.0589e-14,  4.1227e-01,  8.8969e-03,  4.2595e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5172,  0.0000,  0.1427,  0.1102,  2.3386,  0.7815,  0.0000,  0.0000,
-         0.0428,  0.3278,  0.1436,  0.0000, -0.4379, -0.3273,  0.0000,  0.0000,
-         0.0000, -0.1717,  0.0000,  0.0000,  0.3006, -0.0431,  0.0000, -0.7391,
-        -0.1189,  0.0000,  0.2362,  0.2221,  0.0480,  0.0000, -0.8072, -1.2483,
-         0.0000, -0.0771,  0.0000,  0.0000,  0.0000,  0.0000,  0.4169,  0.4486,
-         0.1063,  0.0000,  0.2851,  0.0000,  0.0000, -0.3926,  0.0000,  0.3094,
-         0.0000,  0.0000,  0.0000,  0.1771, -0.2894, -0.0167,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6738,  0.0000,  0.0000,  0.4123,  0.0000,  0.0426],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5172,  0.0000,  0.1427,  0.1102,  2.3386,  0.7815,  0.0000,  0.0000,
-         0.0428,  0.3278,  0.1436,  0.0000, -0.4379, -0.3273,  0.0000,  0.0000,
-         0.0000, -0.1717,  0.0000,  0.0000,  0.3006, -0.0431,  0.0000, -0.7391,
-        -0.1189,  0.0000,  0.2362,  0.2221,  0.0480,  0.0000, -0.8072, -1.2483,
-         0.0000, -0.0771,  0.0000,  0.0000,  0.0000,  0.0000,  0.4169,  0.4486,
-         0.1063,  0.0000,  0.2851,  0.0000,  0.0000, -0.3926,  0.0000,  0.3094,
-         0.0000,  0.0000,  0.0000,  0.1771, -0.2894, -0.0167,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6738,  0.0000,  0.0000,  0.4123,  0.0000,  0.0426],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1651e-01, -3.1643e-06,  1.0399e-01,  1.3593e-01,  2.3378e+00,
-         7.5344e-01,  1.9154e-15, -1.5124e-08,  2.6139e-02,  3.2669e-01,
-         8.5215e-02,  3.4434e-07, -4.7656e-01, -3.2884e-01, -3.4349e-12,
-        -2.5529e-09, -2.5112e-10, -2.4629e-01, -5.8082e-13,  6.9834e-09,
-         2.6978e-01, -5.8811e-02, -1.2334e-12, -7.6683e-01, -1.2705e-01,
-         1.0011e-08,  1.8226e-01,  2.3235e-01,  1.9107e-02,  1.7939e-12,
-        -8.3220e-01, -1.2491e+00, -4.5919e-16, -1.0012e-01,  2.2792e-10,
-         1.2298e-08,  7.6029e-12,  0.0000e+00,  4.1873e-01,  4.4239e-01,
-         1.3370e-01, -1.0155e-09,  2.6697e-01,  1.8711e-12,  4.5766e-10,
-        -3.8335e-01, -6.4575e-08,  2.5043e-01,  2.1723e-07, -1.1912e-13,
-        -3.8542e-04,  1.6244e-01, -2.9563e-01,  2.2056e-02, -4.3652e-05,
-         2.4791e-17, -9.7434e-12,  3.5041e-15,  6.6253e-01, -5.0606e-07,
-         9.3602e-15,  3.7359e-01,  7.8644e-03,  6.1951e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5165,  0.0000,  0.1040,  0.1359,  2.3378,  0.7534,  0.0000,  0.0000,
-         0.0261,  0.3267,  0.0852,  0.0000, -0.4766, -0.3288,  0.0000,  0.0000,
-         0.0000, -0.2463,  0.0000,  0.0000,  0.2698, -0.0588,  0.0000, -0.7668,
-        -0.1270,  0.0000,  0.1823,  0.2324,  0.0191,  0.0000, -0.8322, -1.2491,
-         0.0000, -0.1001,  0.0000,  0.0000,  0.0000,  0.0000,  0.4187,  0.4424,
-         0.1337,  0.0000,  0.2670,  0.0000,  0.0000, -0.3833,  0.0000,  0.2504,
-         0.0000,  0.0000,  0.0000,  0.1624, -0.2956,  0.0221,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6625,  0.0000,  0.0000,  0.3736,  0.0000,  0.0620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5165,  0.0000,  0.1040,  0.1359,  2.3378,  0.7534,  0.0000,  0.0000,
-         0.0261,  0.3267,  0.0852,  0.0000, -0.4766, -0.3288,  0.0000,  0.0000,
-         0.0000, -0.2463,  0.0000,  0.0000,  0.2698, -0.0588,  0.0000, -0.7668,
-        -0.1270,  0.0000,  0.1823,  0.2324,  0.0191,  0.0000, -0.8322, -1.2491,
-         0.0000, -0.1001,  0.0000,  0.0000,  0.0000,  0.0000,  0.4187,  0.4424,
-         0.1337,  0.0000,  0.2670,  0.0000,  0.0000, -0.3833,  0.0000,  0.2504,
-         0.0000,  0.0000,  0.0000,  0.1624, -0.2956,  0.0221,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6625,  0.0000,  0.0000,  0.3736,  0.0000,  0.0620],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2290e-01, -2.7982e-06,  1.6282e-02,  8.2067e-02,  2.3345e+00,
-         7.0914e-01,  1.6938e-15, -1.3375e-08,  2.5124e-02,  2.9947e-01,
-         1.0676e-03,  3.0450e-07, -5.2485e-01, -3.0306e-01, -3.0375e-12,
-        -2.2576e-09, -2.2207e-10, -3.2260e-01, -5.1362e-13,  6.1754e-09,
-         2.0268e-01, -2.1891e-02, -1.0907e-12, -7.7422e-01, -8.4374e-02,
-         8.8526e-09,  7.0355e-02,  2.0147e-01, -2.0748e-02,  1.5863e-12,
-        -8.3997e-01, -1.2487e+00, -4.0607e-16, -1.3191e-01,  2.0155e-10,
-         1.0875e-08,  6.7233e-12,  0.0000e+00,  4.0574e-01,  4.2783e-01,
-         1.7422e-01, -8.9801e-10,  2.1918e-01,  1.6547e-12,  4.0472e-10,
-        -3.6775e-01, -5.7104e-08,  1.9214e-01,  1.9210e-07, -1.0534e-13,
-        -3.4083e-04,  1.8853e-01, -2.9962e-01, -2.0618e-02, -3.8602e-05,
-         2.1923e-17, -8.6161e-12,  3.0987e-15,  6.5934e-01, -4.4752e-07,
-         8.2773e-15,  2.7873e-01,  6.9545e-03,  2.4204e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.2290e-01,  0.0000e+00,  1.6282e-02,  8.2067e-02,  2.3345e+00,
-         7.0914e-01,  0.0000e+00,  0.0000e+00,  2.5124e-02,  2.9947e-01,
-         1.0676e-03,  0.0000e+00, -5.2485e-01, -3.0306e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.2260e-01,  0.0000e+00,  0.0000e+00,
-         2.0268e-01, -2.1891e-02,  0.0000e+00, -7.7422e-01, -8.4374e-02,
-         0.0000e+00,  7.0355e-02,  2.0147e-01, -2.0748e-02,  0.0000e+00,
-        -8.3997e-01, -1.2487e+00,  0.0000e+00, -1.3191e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.0574e-01,  4.2783e-01,
-         1.7422e-01,  0.0000e+00,  2.1918e-01,  0.0000e+00,  0.0000e+00,
-        -3.6775e-01,  0.0000e+00,  1.9214e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.8853e-01, -2.9962e-01, -2.0618e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.5934e-01,  0.0000e+00,
-         0.0000e+00,  2.7873e-01,  0.0000e+00,  2.4204e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.2290e-01,  0.0000e+00,  1.6282e-02,  8.2067e-02,  2.3345e+00,
-         7.0914e-01,  0.0000e+00,  0.0000e+00,  2.5124e-02,  2.9947e-01,
-         1.0676e-03,  0.0000e+00, -5.2485e-01, -3.0306e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.2260e-01,  0.0000e+00,  0.0000e+00,
-         2.0268e-01, -2.1891e-02,  0.0000e+00, -7.7422e-01, -8.4374e-02,
-         0.0000e+00,  7.0355e-02,  2.0147e-01, -2.0748e-02,  0.0000e+00,
-        -8.3997e-01, -1.2487e+00,  0.0000e+00, -1.3191e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.0574e-01,  4.2783e-01,
-         1.7422e-01,  0.0000e+00,  2.1918e-01,  0.0000e+00,  0.0000e+00,
-        -3.6775e-01,  0.0000e+00,  1.9214e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.8853e-01, -2.9962e-01, -2.0618e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.5934e-01,  0.0000e+00,
-         0.0000e+00,  2.7873e-01,  0.0000e+00,  2.4204e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.3343e-01, -2.4755e-06, -6.2621e-02,  3.4592e-02,  2.3306e+00,
-         6.7341e-01,  1.4985e-15, -1.1832e-08,  1.4102e-02,  2.8403e-01,
-        -2.6111e-02,  2.6938e-07, -5.7926e-01, -2.7802e-01, -2.6872e-12,
-        -1.9972e-09, -1.9646e-10, -3.5141e-01, -4.5439e-13,  5.4632e-09,
-         1.3933e-01,  2.8377e-02, -9.6495e-13, -7.6704e-01, -3.4198e-02,
-         7.8316e-09, -3.4219e-02,  1.5725e-01, -6.1825e-02,  1.4034e-12,
-        -8.4573e-01, -1.2488e+00, -3.5924e-16, -1.5102e-01,  1.7831e-10,
-         9.6208e-09,  5.9480e-12,  0.0000e+00,  3.9646e-01,  4.0924e-01,
-         2.2378e-01, -7.9445e-10,  1.7596e-01,  1.4638e-12,  3.5804e-10,
-        -3.5254e-01, -5.0519e-08,  1.2542e-01,  1.6994e-07, -9.3191e-14,
-        -3.0152e-04,  1.8364e-01, -2.8855e-01, -2.5956e-02, -3.4150e-05,
-         1.9395e-17, -7.6225e-12,  2.7413e-15,  6.6134e-01, -3.9590e-07,
-         7.3227e-15,  1.8435e-01,  6.1525e-03,  3.0309e-05], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 5.3343e-01,  0.0000e+00, -6.2621e-02,  3.4592e-02,  2.3306e+00,
-         6.7341e-01,  0.0000e+00,  0.0000e+00,  1.4102e-02,  2.8403e-01,
-        -2.6111e-02,  0.0000e+00, -5.7926e-01, -2.7802e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.5141e-01,  0.0000e+00,  0.0000e+00,
-         1.3933e-01,  2.8377e-02,  0.0000e+00, -7.6704e-01, -3.4198e-02,
-         0.0000e+00, -3.4219e-02,  1.5725e-01, -6.1825e-02,  0.0000e+00,
-        -8.4573e-01, -1.2488e+00,  0.0000e+00, -1.5102e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.9646e-01,  4.0924e-01,
-         2.2378e-01,  0.0000e+00,  1.7596e-01,  0.0000e+00,  0.0000e+00,
-        -3.5254e-01,  0.0000e+00,  1.2542e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.8364e-01, -2.8855e-01, -2.5956e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.6134e-01,  0.0000e+00,
-         0.0000e+00,  1.8435e-01,  0.0000e+00,  3.0309e-05], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 5.3343e-01,  0.0000e+00, -6.2621e-02,  3.4592e-02,  2.3306e+00,
-         6.7341e-01,  0.0000e+00,  0.0000e+00,  1.4102e-02,  2.8403e-01,
-        -2.6111e-02,  0.0000e+00, -5.7926e-01, -2.7802e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.5141e-01,  0.0000e+00,  0.0000e+00,
-         1.3933e-01,  2.8377e-02,  0.0000e+00, -7.6704e-01, -3.4198e-02,
-         0.0000e+00, -3.4219e-02,  1.5725e-01, -6.1825e-02,  0.0000e+00,
-        -8.4573e-01, -1.2488e+00,  0.0000e+00, -1.5102e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.9646e-01,  4.0924e-01,
-         2.2378e-01,  0.0000e+00,  1.7596e-01,  0.0000e+00,  0.0000e+00,
-        -3.5254e-01,  0.0000e+00,  1.2542e-01,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.8364e-01, -2.8855e-01, -2.5956e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.6134e-01,  0.0000e+00,
-         0.0000e+00,  1.8435e-01,  0.0000e+00,  3.0309e-05], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 5.3367e-01, -2.1909e-06, -1.2739e-01, -7.4849e-03,  2.3277e+00,
-         6.2178e-01,  1.3262e-15, -1.0472e-08,  1.0277e-02,  2.6290e-01,
-        -1.3892e-02,  2.3841e-07, -6.2328e-01, -2.5429e-01, -2.3782e-12,
-        -1.7676e-09, -1.7387e-10, -3.8543e-01, -4.0215e-13,  4.8352e-09,
-         9.2043e-02,  4.9267e-02, -8.5401e-13, -7.6019e-01, -1.3747e-02,
-         6.9313e-09, -1.3801e-01,  1.0381e-01, -7.9089e-02,  1.2420e-12,
-        -8.5951e-01, -1.2467e+00, -3.1794e-16, -1.2381e-01,  1.5781e-10,
-         8.5148e-09,  5.2641e-12,  0.0000e+00,  4.0329e-01,  3.8639e-01,
-         2.7855e-01, -7.0311e-10,  1.5345e-01,  1.2956e-12,  3.1688e-10,
-        -3.3909e-01, -4.4711e-08,  7.9086e-02,  1.5041e-07, -8.2478e-14,
-        -2.6686e-04,  1.9910e-01, -2.5912e-01, -8.0970e-02, -3.0224e-05,
-         1.7165e-17, -6.7461e-12,  2.4262e-15,  6.7391e-01, -3.5039e-07,
-         6.4808e-15,  8.6988e-02,  5.4452e-03, -2.6834e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5337,  0.0000, -0.1274, -0.0075,  2.3277,  0.6218,  0.0000,  0.0000,
-         0.0103,  0.2629, -0.0139,  0.0000, -0.6233, -0.2543,  0.0000,  0.0000,
-         0.0000, -0.3854,  0.0000,  0.0000,  0.0920,  0.0493,  0.0000, -0.7602,
-        -0.0137,  0.0000, -0.1380,  0.1038, -0.0791,  0.0000, -0.8595, -1.2467,
-         0.0000, -0.1238,  0.0000,  0.0000,  0.0000,  0.0000,  0.4033,  0.3864,
-         0.2785,  0.0000,  0.1534,  0.0000,  0.0000, -0.3391,  0.0000,  0.0791,
-         0.0000,  0.0000,  0.0000,  0.1991, -0.2591, -0.0810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6739,  0.0000,  0.0000,  0.0870,  0.0000, -0.0268],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5337,  0.0000, -0.1274, -0.0075,  2.3277,  0.6218,  0.0000,  0.0000,
-         0.0103,  0.2629, -0.0139,  0.0000, -0.6233, -0.2543,  0.0000,  0.0000,
-         0.0000, -0.3854,  0.0000,  0.0000,  0.0920,  0.0493,  0.0000, -0.7602,
-        -0.0137,  0.0000, -0.1380,  0.1038, -0.0791,  0.0000, -0.8595, -1.2467,
-         0.0000, -0.1238,  0.0000,  0.0000,  0.0000,  0.0000,  0.4033,  0.3864,
-         0.2785,  0.0000,  0.1534,  0.0000,  0.0000, -0.3391,  0.0000,  0.0791,
-         0.0000,  0.0000,  0.0000,  0.1991, -0.2591, -0.0810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6739,  0.0000,  0.0000,  0.0870,  0.0000, -0.0268],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2074e-01, -1.9398e-06, -1.6239e-01, -1.1877e-02,  2.3231e+00,
-         5.6153e-01,  1.1742e-15, -9.2717e-09, -2.0190e-02,  2.3701e-01,
-        -3.9129e-02,  2.1109e-07, -6.5988e-01, -2.5931e-01, -2.1057e-12,
-        -1.5650e-09, -1.5394e-10, -3.8703e-01, -3.5606e-13,  4.2810e-09,
-         6.3678e-02,  3.8120e-02, -7.5614e-13, -7.3322e-01, -2.6244e-02,
-         6.1369e-09, -1.8675e-01,  5.8830e-02, -9.3253e-02,  1.0997e-12,
-        -8.5180e-01, -1.2417e+00, -2.8150e-16, -7.6939e-02,  1.3972e-10,
-         7.5390e-09,  4.6609e-12,  0.0000e+00,  4.0118e-01,  3.6808e-01,
-         2.8113e-01, -6.2254e-10,  1.7496e-01,  1.1471e-12,  2.8056e-10,
-        -3.4745e-01, -3.9587e-08,  2.9692e-02,  1.3317e-07, -7.3026e-14,
-        -2.3628e-04,  1.7388e-01, -2.5522e-01, -1.2145e-02, -2.6760e-05,
-         1.5198e-17, -5.9730e-12,  2.1481e-15,  6.8095e-01, -3.1023e-07,
-         5.7381e-15,  4.5986e-02,  4.8211e-03, -5.1947e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5207,  0.0000, -0.1624, -0.0119,  2.3231,  0.5615,  0.0000,  0.0000,
-        -0.0202,  0.2370, -0.0391,  0.0000, -0.6599, -0.2593,  0.0000,  0.0000,
-         0.0000, -0.3870,  0.0000,  0.0000,  0.0637,  0.0381,  0.0000, -0.7332,
-        -0.0262,  0.0000, -0.1868,  0.0588, -0.0933,  0.0000, -0.8518, -1.2417,
-         0.0000, -0.0769,  0.0000,  0.0000,  0.0000,  0.0000,  0.4012,  0.3681,
-         0.2811,  0.0000,  0.1750,  0.0000,  0.0000, -0.3475,  0.0000,  0.0297,
-         0.0000,  0.0000,  0.0000,  0.1739, -0.2552, -0.0121,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6810,  0.0000,  0.0000,  0.0460,  0.0000, -0.0519],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5207,  0.0000, -0.1624, -0.0119,  2.3231,  0.5615,  0.0000,  0.0000,
-        -0.0202,  0.2370, -0.0391,  0.0000, -0.6599, -0.2593,  0.0000,  0.0000,
-         0.0000, -0.3870,  0.0000,  0.0000,  0.0637,  0.0381,  0.0000, -0.7332,
-        -0.0262,  0.0000, -0.1868,  0.0588, -0.0933,  0.0000, -0.8518, -1.2417,
-         0.0000, -0.0769,  0.0000,  0.0000,  0.0000,  0.0000,  0.4012,  0.3681,
-         0.2811,  0.0000,  0.1750,  0.0000,  0.0000, -0.3475,  0.0000,  0.0297,
-         0.0000,  0.0000,  0.0000,  0.1739, -0.2552, -0.0121,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6810,  0.0000,  0.0000,  0.0460,  0.0000, -0.0519],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.1310e-01, -1.7182e-06, -1.6530e-01, -2.0789e-02,  2.3205e+00,
-         5.0595e-01,  1.0401e-15, -8.2125e-09, -5.3121e-02,  1.8237e-01,
-        -4.9279e-02,  1.8698e-07, -6.6630e-01, -2.7404e-01, -1.8651e-12,
-        -1.3862e-09, -1.3636e-10, -3.5952e-01, -3.1539e-13,  3.7920e-09,
-         5.1737e-02, -4.3761e-03, -6.6976e-13, -7.2262e-01, -7.2184e-02,
-         5.4358e-09, -2.1712e-01,  2.3085e-02, -1.0737e-01,  9.7407e-13,
-        -8.4424e-01, -1.2354e+00, -2.4934e-16, -4.2077e-02,  1.2376e-10,
-         6.6777e-09,  4.1284e-12,  0.0000e+00,  4.0212e-01,  3.4037e-01,
-         2.5400e-01, -5.5142e-10,  2.2478e-01,  1.0160e-12,  2.4851e-10,
-        -3.3536e-01, -3.5065e-08,  5.9840e-03,  1.1796e-07, -6.4683e-14,
-        -2.0929e-04,  1.1539e-01, -2.4890e-01,  1.1372e-01, -2.3703e-05,
-         1.3462e-17, -5.2907e-12,  1.9027e-15,  6.8002e-01, -2.7479e-07,
-         5.0826e-15,  3.4305e-02,  4.2704e-03, -6.5036e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.5131,  0.0000, -0.1653, -0.0208,  2.3205,  0.5060,  0.0000,  0.0000,
-        -0.0531,  0.1824, -0.0493,  0.0000, -0.6663, -0.2740,  0.0000,  0.0000,
-         0.0000, -0.3595,  0.0000,  0.0000,  0.0517, -0.0044,  0.0000, -0.7226,
-        -0.0722,  0.0000, -0.2171,  0.0231, -0.1074,  0.0000, -0.8442, -1.2354,
-         0.0000, -0.0421,  0.0000,  0.0000,  0.0000,  0.0000,  0.4021,  0.3404,
-         0.2540,  0.0000,  0.2248,  0.0000,  0.0000, -0.3354,  0.0000,  0.0060,
-         0.0000,  0.0000,  0.0000,  0.1154, -0.2489,  0.1137,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6800,  0.0000,  0.0000,  0.0343,  0.0000, -0.0650],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.5131,  0.0000, -0.1653, -0.0208,  2.3205,  0.5060,  0.0000,  0.0000,
-        -0.0531,  0.1824, -0.0493,  0.0000, -0.6663, -0.2740,  0.0000,  0.0000,
-         0.0000, -0.3595,  0.0000,  0.0000,  0.0517, -0.0044,  0.0000, -0.7226,
-        -0.0722,  0.0000, -0.2171,  0.0231, -0.1074,  0.0000, -0.8442, -1.2354,
-         0.0000, -0.0421,  0.0000,  0.0000,  0.0000,  0.0000,  0.4021,  0.3404,
-         0.2540,  0.0000,  0.2248,  0.0000,  0.0000, -0.3354,  0.0000,  0.0060,
-         0.0000,  0.0000,  0.0000,  0.1154, -0.2489,  0.1137,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6800,  0.0000,  0.0000,  0.0343,  0.0000, -0.0650],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.9016e-01, -1.5226e-06, -1.2168e-01, -3.1924e-02,  2.3188e+00,
-         4.4942e-01,  9.2163e-16, -7.2773e-09, -1.0751e-01,  1.1453e-01,
-        -5.6487e-02,  1.6568e-07, -6.4972e-01, -3.1001e-01, -1.6527e-12,
-        -1.2284e-09, -1.2083e-10, -3.1265e-01, -2.7947e-13,  3.3602e-09,
-         3.0972e-02, -5.6796e-02, -5.9349e-13, -7.1835e-01, -1.4696e-01,
-         4.8168e-09, -2.3618e-01, -3.0939e-02, -1.1312e-01,  8.6315e-13,
-        -8.3655e-01, -1.2304e+00, -2.2095e-16,  5.8476e-03,  1.0967e-10,
-         5.9173e-09,  3.6583e-12,  0.0000e+00,  4.3381e-01,  3.2868e-01,
-         1.9088e-01, -4.8863e-10,  3.0511e-01,  9.0034e-13,  2.2021e-10,
-        -3.2036e-01, -3.1072e-08,  4.0201e-03,  1.0452e-07, -5.7317e-14,
-        -1.8545e-04,  7.1840e-02, -2.4482e-01,  2.4166e-01, -2.1004e-05,
-         1.1929e-17, -4.6882e-12,  1.6860e-15,  6.6414e-01, -2.4350e-07,
-         4.5038e-15,  4.6850e-02,  3.7841e-03, -8.8436e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4902,  0.0000, -0.1217, -0.0319,  2.3188,  0.4494,  0.0000,  0.0000,
-        -0.1075,  0.1145, -0.0565,  0.0000, -0.6497, -0.3100,  0.0000,  0.0000,
-         0.0000, -0.3126,  0.0000,  0.0000,  0.0310, -0.0568,  0.0000, -0.7183,
-        -0.1470,  0.0000, -0.2362, -0.0309, -0.1131,  0.0000, -0.8365, -1.2304,
-         0.0000,  0.0058,  0.0000,  0.0000,  0.0000,  0.0000,  0.4338,  0.3287,
-         0.1909,  0.0000,  0.3051,  0.0000,  0.0000, -0.3204,  0.0000,  0.0040,
-         0.0000,  0.0000,  0.0000,  0.0718, -0.2448,  0.2417,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6641,  0.0000,  0.0000,  0.0469,  0.0000, -0.0884],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4902,  0.0000, -0.1217, -0.0319,  2.3188,  0.4494,  0.0000,  0.0000,
-        -0.1075,  0.1145, -0.0565,  0.0000, -0.6497, -0.3100,  0.0000,  0.0000,
-         0.0000, -0.3126,  0.0000,  0.0000,  0.0310, -0.0568,  0.0000, -0.7183,
-        -0.1470,  0.0000, -0.2362, -0.0309, -0.1131,  0.0000, -0.8365, -1.2304,
-         0.0000,  0.0058,  0.0000,  0.0000,  0.0000,  0.0000,  0.4338,  0.3287,
-         0.1909,  0.0000,  0.3051,  0.0000,  0.0000, -0.3204,  0.0000,  0.0040,
-         0.0000,  0.0000,  0.0000,  0.0718, -0.2448,  0.2417,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6641,  0.0000,  0.0000,  0.0469,  0.0000, -0.0884],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6475e-01, -1.3497e-06, -8.7065e-02, -6.2907e-02,  2.3167e+00,
-         3.8527e-01,  8.1701e-16, -6.4513e-09, -1.2696e-01,  5.7173e-02,
-        -1.0061e-01,  1.4688e-07, -6.4049e-01, -3.2327e-01, -1.4651e-12,
-        -1.0889e-09, -1.0711e-10, -2.7764e-01, -2.4775e-13,  2.9788e-09,
-         2.5421e-02, -9.6904e-02, -5.2612e-13, -7.0980e-01, -2.1056e-01,
-         4.2701e-09, -2.5221e-01, -7.9663e-02, -1.0382e-01,  7.6517e-13,
-        -8.3130e-01, -1.2260e+00, -1.9587e-16,  5.4681e-02,  9.7220e-11,
-         5.2456e-09,  3.2430e-12,  0.0000e+00,  4.7579e-01,  3.5424e-01,
-         1.0865e-01, -4.3316e-10,  3.9423e-01,  7.9814e-13,  1.9522e-10,
-        -3.1116e-01, -2.7545e-08,  2.2827e-02,  9.2660e-08, -5.0811e-14,
-        -1.6440e-04,  5.9195e-02, -2.2801e-01,  3.1154e-01, -1.8620e-05,
-         1.0575e-17, -4.1560e-12,  1.4947e-15,  6.3417e-01, -2.1586e-07,
-         3.9926e-15,  5.8835e-02,  3.3545e-03, -1.0345e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4648,  0.0000, -0.0871, -0.0629,  2.3167,  0.3853,  0.0000,  0.0000,
-        -0.1270,  0.0572, -0.1006,  0.0000, -0.6405, -0.3233,  0.0000,  0.0000,
-         0.0000, -0.2776,  0.0000,  0.0000,  0.0254, -0.0969,  0.0000, -0.7098,
-        -0.2106,  0.0000, -0.2522, -0.0797, -0.1038,  0.0000, -0.8313, -1.2260,
-         0.0000,  0.0547,  0.0000,  0.0000,  0.0000,  0.0000,  0.4758,  0.3542,
-         0.1087,  0.0000,  0.3942,  0.0000,  0.0000, -0.3112,  0.0000,  0.0228,
-         0.0000,  0.0000,  0.0000,  0.0592, -0.2280,  0.3115,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6342,  0.0000,  0.0000,  0.0588,  0.0000, -0.1034],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4648,  0.0000, -0.0871, -0.0629,  2.3167,  0.3853,  0.0000,  0.0000,
-        -0.1270,  0.0572, -0.1006,  0.0000, -0.6405, -0.3233,  0.0000,  0.0000,
-         0.0000, -0.2776,  0.0000,  0.0000,  0.0254, -0.0969,  0.0000, -0.7098,
-        -0.2106,  0.0000, -0.2522, -0.0797, -0.1038,  0.0000, -0.8313, -1.2260,
-         0.0000,  0.0547,  0.0000,  0.0000,  0.0000,  0.0000,  0.4758,  0.3542,
-         0.1087,  0.0000,  0.3942,  0.0000,  0.0000, -0.3112,  0.0000,  0.0228,
-         0.0000,  0.0000,  0.0000,  0.0592, -0.2280,  0.3115,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6342,  0.0000,  0.0000,  0.0588,  0.0000, -0.1034],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3809e-01, -1.1970e-06, -9.1794e-02, -1.1416e-01,  2.3139e+00,
-         3.4374e-01,  7.2457e-16, -5.7213e-09, -1.5650e-01,  1.2470e-02,
-        -9.4811e-02,  1.3026e-07, -6.5185e-01, -3.4543e-01, -1.2994e-12,
-        -9.6573e-10, -9.4995e-11, -2.3618e-01, -2.1972e-13,  2.6417e-09,
-         3.3661e-02, -1.0059e-01, -4.6659e-13, -6.7618e-01, -2.3017e-01,
-         3.7869e-09, -2.8014e-01, -9.4572e-02, -6.9279e-02,  6.7860e-13,
-        -8.1366e-01, -1.2193e+00, -1.7371e-16,  1.1151e-01,  8.6220e-11,
-         4.6521e-09,  2.8761e-12,  0.0000e+00,  4.7927e-01,  3.5028e-01,
-         9.1960e-02, -3.8415e-10,  4.4188e-01,  7.0783e-13,  1.7313e-10,
-        -3.2384e-01, -2.4428e-08,  7.2347e-03,  8.2176e-08, -4.5062e-14,
-        -1.4580e-04,  6.6175e-02, -2.1190e-01,  3.2904e-01, -1.6513e-05,
-         9.3782e-18, -3.6858e-12,  1.3255e-15,  6.1267e-01, -1.9144e-07,
-         3.5408e-15,  1.5464e-02,  2.9750e-03, -1.1476e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4381,  0.0000, -0.0918, -0.1142,  2.3139,  0.3437,  0.0000,  0.0000,
-        -0.1565,  0.0125, -0.0948,  0.0000, -0.6519, -0.3454,  0.0000,  0.0000,
-         0.0000, -0.2362,  0.0000,  0.0000,  0.0337, -0.1006,  0.0000, -0.6762,
-        -0.2302,  0.0000, -0.2801, -0.0946, -0.0693,  0.0000, -0.8137, -1.2193,
-         0.0000,  0.1115,  0.0000,  0.0000,  0.0000,  0.0000,  0.4793,  0.3503,
-         0.0920,  0.0000,  0.4419,  0.0000,  0.0000, -0.3238,  0.0000,  0.0072,
-         0.0000,  0.0000,  0.0000,  0.0662, -0.2119,  0.3290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6127,  0.0000,  0.0000,  0.0155,  0.0000, -0.1148],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4381,  0.0000, -0.0918, -0.1142,  2.3139,  0.3437,  0.0000,  0.0000,
-        -0.1565,  0.0125, -0.0948,  0.0000, -0.6519, -0.3454,  0.0000,  0.0000,
-         0.0000, -0.2362,  0.0000,  0.0000,  0.0337, -0.1006,  0.0000, -0.6762,
-        -0.2302,  0.0000, -0.2801, -0.0946, -0.0693,  0.0000, -0.8137, -1.2193,
-         0.0000,  0.1115,  0.0000,  0.0000,  0.0000,  0.0000,  0.4793,  0.3503,
-         0.0920,  0.0000,  0.4419,  0.0000,  0.0000, -0.3238,  0.0000,  0.0072,
-         0.0000,  0.0000,  0.0000,  0.0662, -0.2119,  0.3290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6127,  0.0000,  0.0000,  0.0155,  0.0000, -0.1148],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9665e-01, -1.0620e-06, -1.1883e-01, -1.7697e-01,  2.3128e+00,
-         3.0663e-01,  6.4285e-16, -5.0760e-09, -1.8207e-01, -6.9289e-03,
-        -1.7013e-02,  1.1557e-07, -6.5844e-01, -3.5072e-01, -1.1528e-12,
-        -8.5681e-10, -8.4281e-11, -1.9958e-01, -1.9494e-13,  2.3438e-09,
-         5.7166e-03, -8.7546e-02, -4.1397e-13, -6.7957e-01, -2.3569e-01,
-         3.3598e-09, -3.2114e-01, -1.0877e-01, -4.7134e-02,  6.0206e-13,
-        -7.9370e-01, -1.2125e+00, -1.5412e-16,  1.1716e-01,  7.6496e-11,
-         4.1274e-09,  2.5517e-12,  0.0000e+00,  4.4195e-01,  2.9210e-01,
-         1.0382e-01, -3.4082e-10,  4.6339e-01,  6.2800e-13,  1.5360e-10,
-        -3.1956e-01, -2.1673e-08, -6.8540e-03,  7.2907e-08, -3.9980e-14,
-        -1.2936e-04,  4.3288e-02, -1.2778e-01,  2.9932e-01, -1.4651e-05,
-         8.3205e-18, -3.2701e-12,  1.1760e-15,  6.0421e-01, -1.6985e-07,
-         3.1415e-15, -5.1588e-02,  2.6394e-03, -9.5423e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3966,  0.0000, -0.1188, -0.1770,  2.3128,  0.3066,  0.0000,  0.0000,
-        -0.1821, -0.0069, -0.0170,  0.0000, -0.6584, -0.3507,  0.0000,  0.0000,
-         0.0000, -0.1996,  0.0000,  0.0000,  0.0057, -0.0875,  0.0000, -0.6796,
-        -0.2357,  0.0000, -0.3211, -0.1088, -0.0471,  0.0000, -0.7937, -1.2125,
-         0.0000,  0.1172,  0.0000,  0.0000,  0.0000,  0.0000,  0.4419,  0.2921,
-         0.1038,  0.0000,  0.4634,  0.0000,  0.0000, -0.3196,  0.0000, -0.0069,
-         0.0000,  0.0000,  0.0000,  0.0433, -0.1278,  0.2993,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6042,  0.0000,  0.0000, -0.0516,  0.0000, -0.0954],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3966,  0.0000, -0.1188, -0.1770,  2.3128,  0.3066,  0.0000,  0.0000,
-        -0.1821, -0.0069, -0.0170,  0.0000, -0.6584, -0.3507,  0.0000,  0.0000,
-         0.0000, -0.1996,  0.0000,  0.0000,  0.0057, -0.0875,  0.0000, -0.6796,
-        -0.2357,  0.0000, -0.3211, -0.1088, -0.0471,  0.0000, -0.7937, -1.2125,
-         0.0000,  0.1172,  0.0000,  0.0000,  0.0000,  0.0000,  0.4419,  0.2921,
-         0.1038,  0.0000,  0.4634,  0.0000,  0.0000, -0.3196,  0.0000, -0.0069,
-         0.0000,  0.0000,  0.0000,  0.0433, -0.1278,  0.2993,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6042,  0.0000,  0.0000, -0.0516,  0.0000, -0.0954],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6626e-01, -9.4261e-07, -1.3988e-01, -2.4234e-01,  2.3118e+00,
-         2.5761e-01,  5.7058e-16, -4.5054e-09, -1.9098e-01,  1.4487e-02,
-         9.5630e-02,  1.0257e-07, -6.5812e-01, -3.4965e-01, -1.0232e-12,
-        -7.6049e-10, -7.4806e-11, -1.6146e-01, -1.7302e-13,  2.0803e-09,
-        -5.9953e-02, -4.2160e-02, -3.6743e-13, -7.4778e-01, -1.9503e-01,
-         2.9821e-09, -3.5435e-01, -1.3060e-01, -6.9464e-02,  5.3438e-13,
-        -7.6912e-01, -1.2134e+00, -1.3679e-16,  9.4074e-02,  6.7896e-11,
-         3.6634e-09,  2.2648e-12,  0.0000e+00,  4.4734e-01,  2.8128e-01,
-         1.2618e-01, -3.0251e-10,  4.7720e-01,  5.5740e-13,  1.3633e-10,
-        -2.6979e-01, -1.9236e-08,  2.8863e-03,  6.4711e-08, -3.5485e-14,
-        -1.1481e-04,  7.1209e-03,  2.5226e-03,  3.3270e-01, -1.3004e-05,
-         7.3851e-18, -2.9025e-12,  1.0438e-15,  5.9260e-01, -1.5075e-07,
-         2.7883e-15, -8.8787e-02,  2.3427e-03, -5.9164e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3663,  0.0000, -0.1399, -0.2423,  2.3118,  0.2576,  0.0000,  0.0000,
-        -0.1910,  0.0145,  0.0956,  0.0000, -0.6581, -0.3496,  0.0000,  0.0000,
-         0.0000, -0.1615,  0.0000,  0.0000, -0.0600, -0.0422,  0.0000, -0.7478,
-        -0.1950,  0.0000, -0.3543, -0.1306, -0.0695,  0.0000, -0.7691, -1.2134,
-         0.0000,  0.0941,  0.0000,  0.0000,  0.0000,  0.0000,  0.4473,  0.2813,
-         0.1262,  0.0000,  0.4772,  0.0000,  0.0000, -0.2698,  0.0000,  0.0029,
-         0.0000,  0.0000,  0.0000,  0.0071,  0.0025,  0.3327,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5926,  0.0000,  0.0000, -0.0888,  0.0000, -0.0592],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3663,  0.0000, -0.1399, -0.2423,  2.3118,  0.2576,  0.0000,  0.0000,
-        -0.1910,  0.0145,  0.0956,  0.0000, -0.6581, -0.3496,  0.0000,  0.0000,
-         0.0000, -0.1615,  0.0000,  0.0000, -0.0600, -0.0422,  0.0000, -0.7478,
-        -0.1950,  0.0000, -0.3543, -0.1306, -0.0695,  0.0000, -0.7691, -1.2134,
-         0.0000,  0.0941,  0.0000,  0.0000,  0.0000,  0.0000,  0.4473,  0.2813,
-         0.1262,  0.0000,  0.4772,  0.0000,  0.0000, -0.2698,  0.0000,  0.0029,
-         0.0000,  0.0000,  0.0000,  0.0071,  0.0025,  0.3327,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5926,  0.0000,  0.0000, -0.0888,  0.0000, -0.0592],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5163e-01, -8.3699e-07, -1.4101e-01, -3.0766e-01,  2.3092e+00,
-         2.0814e-01,  5.0664e-16, -4.0005e-09, -1.7778e-01,  4.4130e-02,
-         1.4313e-01,  9.1080e-08, -6.4981e-01, -3.3032e-01, -9.0855e-13,
-        -6.7527e-10, -6.6423e-11, -1.2170e-01, -1.5363e-13,  1.8472e-09,
-        -1.4326e-01,  2.1316e-02, -3.2626e-13, -8.2848e-01, -1.4314e-01,
-         2.6479e-09, -3.8219e-01, -1.4869e-01, -1.0369e-01,  4.7450e-13,
-        -7.3434e-01, -1.2157e+00, -1.2146e-16,  1.4744e-02,  6.0288e-11,
-         3.2529e-09,  2.0111e-12,  0.0000e+00,  4.8548e-01,  3.0567e-01,
-         1.5153e-01, -2.6861e-10,  4.9691e-01,  4.9494e-13,  1.2106e-10,
-        -2.1026e-01, -1.7081e-08,  2.0250e-02,  5.7460e-08, -3.1509e-14,
-        -1.0195e-04, -1.9277e-02,  1.1692e-01,  3.3699e-01, -1.1546e-05,
-         6.5575e-18, -2.5772e-12,  9.2686e-16,  5.5769e-01, -1.3386e-07,
-         2.4759e-15, -1.1481e-01,  2.0802e-03, -2.4919e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3516,  0.0000, -0.1410, -0.3077,  2.3092,  0.2081,  0.0000,  0.0000,
-        -0.1778,  0.0441,  0.1431,  0.0000, -0.6498, -0.3303,  0.0000,  0.0000,
-         0.0000, -0.1217,  0.0000,  0.0000, -0.1433,  0.0213,  0.0000, -0.8285,
-        -0.1431,  0.0000, -0.3822, -0.1487, -0.1037,  0.0000, -0.7343, -1.2157,
-         0.0000,  0.0147,  0.0000,  0.0000,  0.0000,  0.0000,  0.4855,  0.3057,
-         0.1515,  0.0000,  0.4969,  0.0000,  0.0000, -0.2103,  0.0000,  0.0202,
-         0.0000,  0.0000,  0.0000, -0.0193,  0.1169,  0.3370,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5577,  0.0000,  0.0000, -0.1148,  0.0000, -0.0249],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3516,  0.0000, -0.1410, -0.3077,  2.3092,  0.2081,  0.0000,  0.0000,
-        -0.1778,  0.0441,  0.1431,  0.0000, -0.6498, -0.3303,  0.0000,  0.0000,
-         0.0000, -0.1217,  0.0000,  0.0000, -0.1433,  0.0213,  0.0000, -0.8285,
-        -0.1431,  0.0000, -0.3822, -0.1487, -0.1037,  0.0000, -0.7343, -1.2157,
-         0.0000,  0.0147,  0.0000,  0.0000,  0.0000,  0.0000,  0.4855,  0.3057,
-         0.1515,  0.0000,  0.4969,  0.0000,  0.0000, -0.2103,  0.0000,  0.0202,
-         0.0000,  0.0000,  0.0000, -0.0193,  0.1169,  0.3370,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5577,  0.0000,  0.0000, -0.1148,  0.0000, -0.0249],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2026e-01, -7.4350e-07, -1.2260e-01, -3.6153e-01,  2.3048e+00,
-         1.7831e-01,  4.5005e-16, -3.5537e-09, -1.4908e-01,  5.6622e-02,
-         1.7111e-02,  8.0907e-08, -6.1887e-01, -2.9864e-01, -8.0707e-13,
-        -5.9985e-10, -5.9004e-11, -8.5926e-02, -1.3647e-13,  1.6408e-09,
-        -2.0223e-01,  8.6843e-02, -2.8982e-13, -8.9501e-01, -1.1746e-01,
-         2.3522e-09, -4.0143e-01, -1.5125e-01, -1.3226e-01,  4.2150e-13,
-        -6.9859e-01, -1.2175e+00, -1.0789e-16, -6.5371e-02,  5.3554e-11,
-         2.8896e-09,  1.7864e-12,  0.0000e+00,  5.4746e-01,  3.3139e-01,
-         1.3536e-01, -2.3861e-10,  5.2889e-01,  4.3966e-13,  1.0754e-10,
-        -1.7312e-01, -1.5173e-08,  4.7647e-02,  5.1042e-08, -2.7989e-14,
-        -9.0561e-05, -2.2379e-02,  1.3828e-01,  3.0602e-01, -1.0257e-05,
-         5.8251e-18, -2.2894e-12,  8.2333e-16,  5.2799e-01, -1.1891e-07,
-         2.1993e-15, -1.2794e-01,  1.8479e-03, -4.7571e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3203,  0.0000, -0.1226, -0.3615,  2.3048,  0.1783,  0.0000,  0.0000,
-        -0.1491,  0.0566,  0.0171,  0.0000, -0.6189, -0.2986,  0.0000,  0.0000,
-         0.0000, -0.0859,  0.0000,  0.0000, -0.2022,  0.0868,  0.0000, -0.8950,
-        -0.1175,  0.0000, -0.4014, -0.1512, -0.1323,  0.0000, -0.6986, -1.2175,
-         0.0000, -0.0654,  0.0000,  0.0000,  0.0000,  0.0000,  0.5475,  0.3314,
-         0.1354,  0.0000,  0.5289,  0.0000,  0.0000, -0.1731,  0.0000,  0.0476,
-         0.0000,  0.0000,  0.0000, -0.0224,  0.1383,  0.3060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5280,  0.0000,  0.0000, -0.1279,  0.0000, -0.0476],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3203,  0.0000, -0.1226, -0.3615,  2.3048,  0.1783,  0.0000,  0.0000,
-        -0.1491,  0.0566,  0.0171,  0.0000, -0.6189, -0.2986,  0.0000,  0.0000,
-         0.0000, -0.0859,  0.0000,  0.0000, -0.2022,  0.0868,  0.0000, -0.8950,
-        -0.1175,  0.0000, -0.4014, -0.1512, -0.1323,  0.0000, -0.6986, -1.2175,
-         0.0000, -0.0654,  0.0000,  0.0000,  0.0000,  0.0000,  0.5475,  0.3314,
-         0.1354,  0.0000,  0.5289,  0.0000,  0.0000, -0.1731,  0.0000,  0.0476,
-         0.0000,  0.0000,  0.0000, -0.0224,  0.1383,  0.3060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5280,  0.0000,  0.0000, -0.1279,  0.0000, -0.0476],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9664e-01, -6.6072e-07, -1.0824e-01, -4.1468e-01,  2.3013e+00,
-         1.6635e-01,  3.9995e-16, -3.1580e-09, -1.0513e-01,  7.6155e-02,
-        -1.2161e-01,  7.1900e-08, -5.8379e-01, -2.5193e-01, -7.1722e-13,
-        -5.3306e-10, -5.2435e-11, -6.7573e-02, -1.2128e-13,  1.4582e-09,
-        -2.5201e-01,  1.2034e-01, -2.5755e-13, -9.5043e-01, -1.0458e-01,
-         2.0903e-09, -4.0914e-01, -1.6788e-01, -1.7086e-01,  3.7457e-13,
-        -6.5554e-01, -1.2178e+00, -9.5883e-17, -1.3219e-01,  4.7592e-11,
-         2.5679e-09,  1.5875e-12,  0.0000e+00,  5.9793e-01,  3.3468e-01,
-         8.9060e-02, -2.1204e-10,  5.5515e-01,  3.9071e-13,  9.5563e-11,
-        -1.1487e-01, -1.3484e-08,  7.8198e-02,  4.5359e-08, -2.4873e-14,
-        -8.0479e-05, -2.6247e-02,  1.6040e-01,  2.5762e-01, -9.1149e-06,
-         5.1766e-18, -2.0345e-12,  7.3167e-16,  5.0643e-01, -1.0567e-07,
-         1.9545e-15, -1.2414e-01,  1.6421e-03, -5.9310e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2966,  0.0000, -0.1082, -0.4147,  2.3013,  0.1664,  0.0000,  0.0000,
-        -0.1051,  0.0762, -0.1216,  0.0000, -0.5838, -0.2519,  0.0000,  0.0000,
-         0.0000, -0.0676,  0.0000,  0.0000, -0.2520,  0.1203,  0.0000, -0.9504,
-        -0.1046,  0.0000, -0.4091, -0.1679, -0.1709,  0.0000, -0.6555, -1.2178,
-         0.0000, -0.1322,  0.0000,  0.0000,  0.0000,  0.0000,  0.5979,  0.3347,
-         0.0891,  0.0000,  0.5552,  0.0000,  0.0000, -0.1149,  0.0000,  0.0782,
-         0.0000,  0.0000,  0.0000, -0.0262,  0.1604,  0.2576,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5064,  0.0000,  0.0000, -0.1241,  0.0000, -0.0593],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2966,  0.0000, -0.1082, -0.4147,  2.3013,  0.1664,  0.0000,  0.0000,
-        -0.1051,  0.0762, -0.1216,  0.0000, -0.5838, -0.2519,  0.0000,  0.0000,
-         0.0000, -0.0676,  0.0000,  0.0000, -0.2520,  0.1203,  0.0000, -0.9504,
-        -0.1046,  0.0000, -0.4091, -0.1679, -0.1709,  0.0000, -0.6555, -1.2178,
-         0.0000, -0.1322,  0.0000,  0.0000,  0.0000,  0.0000,  0.5979,  0.3347,
-         0.0891,  0.0000,  0.5552,  0.0000,  0.0000, -0.1149,  0.0000,  0.0782,
-         0.0000,  0.0000,  0.0000, -0.0262,  0.1604,  0.2576,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5064,  0.0000,  0.0000, -0.1241,  0.0000, -0.0593],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8280e-01, -5.8740e-07, -1.0122e-01, -4.4841e-01,  2.3000e+00,
-         1.7765e-01,  3.5557e-16, -2.8076e-09, -7.4145e-02,  8.3770e-02,
-        -2.5788e-01,  6.3921e-08, -5.4600e-01, -2.1871e-01, -6.3763e-13,
-        -4.7391e-10, -4.6616e-11, -4.8803e-02, -1.0782e-13,  1.2964e-09,
-        -2.6685e-01,  1.2018e-01, -2.2897e-13, -9.9958e-01, -1.1855e-01,
-         1.8583e-09, -4.0038e-01, -1.4934e-01, -1.9834e-01,  3.3301e-13,
-        -6.1313e-01, -1.2155e+00, -8.5243e-17, -1.8024e-01,  4.2310e-11,
-         2.2829e-09,  1.4114e-12,  0.0000e+00,  6.3831e-01,  3.5807e-01,
-         5.0984e-02, -1.8851e-10,  5.7329e-01,  3.4735e-13,  8.4959e-11,
-        -8.3609e-02, -1.1987e-08,  1.0715e-01,  4.0326e-08, -2.2113e-14,
-        -7.1548e-05, -6.2678e-02,  1.6188e-01,  1.6080e-01, -8.1034e-06,
-         4.6021e-18, -1.8087e-12,  6.5048e-16,  5.0180e-01, -9.3943e-08,
-         1.7376e-15, -1.1097e-01,  1.4599e-03, -5.3954e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2828,  0.0000, -0.1012, -0.4484,  2.3000,  0.1776,  0.0000,  0.0000,
-        -0.0741,  0.0838, -0.2579,  0.0000, -0.5460, -0.2187,  0.0000,  0.0000,
-         0.0000, -0.0488,  0.0000,  0.0000, -0.2668,  0.1202,  0.0000, -0.9996,
-        -0.1186,  0.0000, -0.4004, -0.1493, -0.1983,  0.0000, -0.6131, -1.2155,
-         0.0000, -0.1802,  0.0000,  0.0000,  0.0000,  0.0000,  0.6383,  0.3581,
-         0.0510,  0.0000,  0.5733,  0.0000,  0.0000, -0.0836,  0.0000,  0.1072,
-         0.0000,  0.0000,  0.0000, -0.0627,  0.1619,  0.1608,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5018,  0.0000,  0.0000, -0.1110,  0.0000, -0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2828,  0.0000, -0.1012, -0.4484,  2.3000,  0.1776,  0.0000,  0.0000,
-        -0.0741,  0.0838, -0.2579,  0.0000, -0.5460, -0.2187,  0.0000,  0.0000,
-         0.0000, -0.0488,  0.0000,  0.0000, -0.2668,  0.1202,  0.0000, -0.9996,
-        -0.1186,  0.0000, -0.4004, -0.1493, -0.1983,  0.0000, -0.6131, -1.2155,
-         0.0000, -0.1802,  0.0000,  0.0000,  0.0000,  0.0000,  0.6383,  0.3581,
-         0.0510,  0.0000,  0.5733,  0.0000,  0.0000, -0.0836,  0.0000,  0.1072,
-         0.0000,  0.0000,  0.0000, -0.0627,  0.1619,  0.1608,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5018,  0.0000,  0.0000, -0.1110,  0.0000, -0.0540],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8112e-01, -5.2243e-07, -1.2803e-01, -4.6797e-01,  2.2969e+00,
-         1.9022e-01,  3.1624e-16, -2.4971e-09, -5.8704e-02,  1.3999e-01,
-        -2.7864e-01,  5.6851e-08, -5.3901e-01, -2.0743e-01, -5.6711e-13,
-        -4.2149e-10, -4.1460e-11, -1.5949e-02, -9.5895e-14,  1.1530e-09,
-        -2.7831e-01,  1.5445e-01, -2.0364e-13, -1.0245e+00, -7.3648e-02,
-         1.6528e-09, -4.1163e-01, -9.1099e-02, -2.4096e-01,  2.9617e-13,
-        -5.7702e-01, -1.2132e+00, -7.5814e-17, -2.3411e-01,  3.7631e-11,
-         2.0304e-09,  1.2553e-12,  0.0000e+00,  6.5304e-01,  3.8866e-01,
-         8.6875e-02, -1.6766e-10,  5.7269e-01,  3.0893e-13,  7.5562e-11,
-        -7.6500e-02, -1.0662e-08,  1.1793e-01,  3.5865e-08, -1.9667e-14,
-        -6.3634e-05, -7.7651e-02,  1.7544e-01,  4.7074e-02, -7.2071e-06,
-         4.0931e-18, -1.6087e-12,  5.7853e-16,  4.7874e-01, -8.3552e-08,
-         1.5454e-15, -1.1302e-01,  1.2984e-03, -1.3100e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2811,  0.0000, -0.1280, -0.4680,  2.2969,  0.1902,  0.0000,  0.0000,
-        -0.0587,  0.1400, -0.2786,  0.0000, -0.5390, -0.2074,  0.0000,  0.0000,
-         0.0000, -0.0159,  0.0000,  0.0000, -0.2783,  0.1544,  0.0000, -1.0245,
-        -0.0736,  0.0000, -0.4116, -0.0911, -0.2410,  0.0000, -0.5770, -1.2132,
-         0.0000, -0.2341,  0.0000,  0.0000,  0.0000,  0.0000,  0.6530,  0.3887,
-         0.0869,  0.0000,  0.5727,  0.0000,  0.0000, -0.0765,  0.0000,  0.1179,
-         0.0000,  0.0000,  0.0000, -0.0777,  0.1754,  0.0471,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4787,  0.0000,  0.0000, -0.1130,  0.0000, -0.0131],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2811,  0.0000, -0.1280, -0.4680,  2.2969,  0.1902,  0.0000,  0.0000,
-        -0.0587,  0.1400, -0.2786,  0.0000, -0.5390, -0.2074,  0.0000,  0.0000,
-         0.0000, -0.0159,  0.0000,  0.0000, -0.2783,  0.1544,  0.0000, -1.0245,
-        -0.0736,  0.0000, -0.4116, -0.0911, -0.2410,  0.0000, -0.5770, -1.2132,
-         0.0000, -0.2341,  0.0000,  0.0000,  0.0000,  0.0000,  0.6530,  0.3887,
-         0.0869,  0.0000,  0.5727,  0.0000,  0.0000, -0.0765,  0.0000,  0.1179,
-         0.0000,  0.0000,  0.0000, -0.0777,  0.1754,  0.0471,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4787,  0.0000,  0.0000, -0.1130,  0.0000, -0.0131],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8642e-01, -4.6484e-07, -1.4974e-01, -4.7644e-01,  2.2911e+00,
-         2.0491e-01,  2.8138e-16, -2.2218e-09, -2.3243e-02,  1.8925e-01,
-        -2.9361e-01,  5.0583e-08, -5.3654e-01, -1.9141e-01, -5.0459e-13,
-        -3.7503e-10, -3.6890e-11,  3.6324e-02, -8.5323e-14,  1.0259e-09,
-        -2.7546e-01,  1.9098e-01, -1.8119e-13, -1.0431e+00, -9.8616e-03,
-         1.4706e-09, -4.1867e-01, -1.8085e-02, -2.7881e-01,  2.6352e-13,
-        -5.4406e-01, -1.2111e+00, -6.7456e-17, -2.5393e-01,  3.3482e-11,
-         1.8066e-09,  1.1169e-12,  0.0000e+00,  6.5451e-01,  4.0938e-01,
-         1.2548e-01, -1.4918e-10,  5.7783e-01,  2.7487e-13,  6.7232e-11,
-        -9.8518e-02, -9.4862e-09,  9.9108e-02,  3.1912e-08, -1.7499e-14,
-        -5.6619e-05, -6.0057e-02,  1.6498e-01, -4.3738e-03, -6.4126e-06,
-         3.6419e-18, -1.4313e-12,  5.1475e-16,  4.7400e-01, -7.4341e-08,
-         1.3750e-15, -9.8273e-02,  1.1553e-03,  1.3533e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2864,  0.0000, -0.1497, -0.4764,  2.2911,  0.2049,  0.0000,  0.0000,
-        -0.0232,  0.1893, -0.2936,  0.0000, -0.5365, -0.1914,  0.0000,  0.0000,
-         0.0000,  0.0363,  0.0000,  0.0000, -0.2755,  0.1910,  0.0000, -1.0431,
-        -0.0099,  0.0000, -0.4187, -0.0181, -0.2788,  0.0000, -0.5441, -1.2111,
-         0.0000, -0.2539,  0.0000,  0.0000,  0.0000,  0.0000,  0.6545,  0.4094,
-         0.1255,  0.0000,  0.5778,  0.0000,  0.0000, -0.0985,  0.0000,  0.0991,
-         0.0000,  0.0000,  0.0000, -0.0601,  0.1650, -0.0044,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4740,  0.0000,  0.0000, -0.0983,  0.0000,  0.0135],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2864,  0.0000, -0.1497, -0.4764,  2.2911,  0.2049,  0.0000,  0.0000,
-        -0.0232,  0.1893, -0.2936,  0.0000, -0.5365, -0.1914,  0.0000,  0.0000,
-         0.0000,  0.0363,  0.0000,  0.0000, -0.2755,  0.1910,  0.0000, -1.0431,
-        -0.0099,  0.0000, -0.4187, -0.0181, -0.2788,  0.0000, -0.5441, -1.2111,
-         0.0000, -0.2539,  0.0000,  0.0000,  0.0000,  0.0000,  0.6545,  0.4094,
-         0.1255,  0.0000,  0.5778,  0.0000,  0.0000, -0.0985,  0.0000,  0.0991,
-         0.0000,  0.0000,  0.0000, -0.0601,  0.1650, -0.0044,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4740,  0.0000,  0.0000, -0.0983,  0.0000,  0.0135],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7586e-01, -4.1376e-07, -1.7958e-01, -4.6902e-01,  2.2849e+00,
-         2.1220e-01,  2.5046e-16, -1.9776e-09, -9.0246e-03,  2.3860e-01,
-        -2.7725e-01,  4.5025e-08, -5.4477e-01, -1.7788e-01, -4.4914e-13,
-        -3.3382e-10, -3.2836e-11,  1.0564e-01, -7.5948e-14,  9.1314e-10,
-        -2.5442e-01,  2.3527e-01, -1.6128e-13, -1.0504e+00,  5.9368e-02,
-         1.3090e-09, -4.2495e-01,  6.1479e-02, -2.8450e-01,  2.3457e-13,
-        -5.2342e-01, -1.2064e+00, -6.0044e-17, -2.6093e-01,  2.9803e-11,
-         1.6081e-09,  9.9416e-13,  0.0000e+00,  6.2987e-01,  4.0831e-01,
-         1.8343e-01, -1.3279e-10,  5.7149e-01,  2.4467e-13,  5.9844e-11,
-        -1.2375e-01, -8.4438e-09,  7.4320e-02,  2.8405e-08, -1.5576e-14,
-        -5.0398e-05, -3.3927e-02,  1.8479e-01, -3.3131e-02, -5.7080e-06,
-         3.2417e-18, -1.2740e-12,  4.5819e-16,  4.7286e-01, -6.6173e-08,
-         1.2239e-15, -7.7419e-02,  1.0283e-03,  5.6428e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2759,  0.0000, -0.1796, -0.4690,  2.2849,  0.2122,  0.0000,  0.0000,
-        -0.0090,  0.2386, -0.2773,  0.0000, -0.5448, -0.1779,  0.0000,  0.0000,
-         0.0000,  0.1056,  0.0000,  0.0000, -0.2544,  0.2353,  0.0000, -1.0504,
-         0.0594,  0.0000, -0.4249,  0.0615, -0.2845,  0.0000, -0.5234, -1.2064,
-         0.0000, -0.2609,  0.0000,  0.0000,  0.0000,  0.0000,  0.6299,  0.4083,
-         0.1834,  0.0000,  0.5715,  0.0000,  0.0000, -0.1237,  0.0000,  0.0743,
-         0.0000,  0.0000,  0.0000, -0.0339,  0.1848, -0.0331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4729,  0.0000,  0.0000, -0.0774,  0.0000,  0.0564],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2759,  0.0000, -0.1796, -0.4690,  2.2849,  0.2122,  0.0000,  0.0000,
-        -0.0090,  0.2386, -0.2773,  0.0000, -0.5448, -0.1779,  0.0000,  0.0000,
-         0.0000,  0.1056,  0.0000,  0.0000, -0.2544,  0.2353,  0.0000, -1.0504,
-         0.0594,  0.0000, -0.4249,  0.0615, -0.2845,  0.0000, -0.5234, -1.2064,
-         0.0000, -0.2609,  0.0000,  0.0000,  0.0000,  0.0000,  0.6299,  0.4083,
-         0.1834,  0.0000,  0.5715,  0.0000,  0.0000, -0.1237,  0.0000,  0.0743,
-         0.0000,  0.0000,  0.0000, -0.0339,  0.1848, -0.0331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4729,  0.0000,  0.0000, -0.0774,  0.0000,  0.0564],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6224e-01, -3.6845e-07, -1.9650e-01, -4.5142e-01,  2.2790e+00,
-         2.2052e-01,  2.2303e-16, -1.7611e-09,  1.6175e-02,  2.6875e-01,
-        -2.6246e-01,  4.0094e-08, -5.4218e-01, -1.5701e-01, -3.9995e-13,
-        -2.9726e-10, -2.9240e-11,  1.2846e-01, -6.7630e-14,  8.1313e-10,
-        -2.3278e-01,  2.6426e-01, -1.4362e-13, -1.0495e+00,  1.1155e-01,
-         1.1656e-09, -4.2862e-01,  1.1538e-01, -2.9444e-01,  2.0888e-13,
-        -5.1047e-01, -1.1998e+00, -5.3468e-17, -2.5402e-01,  2.6539e-11,
-         1.4319e-09,  8.8528e-13,  0.0000e+00,  6.0419e-01,  4.1438e-01,
-         2.3250e-01, -1.1824e-10,  5.5972e-01,  2.1787e-13,  5.3290e-11,
-        -1.5970e-01, -7.5191e-09,  7.7168e-02,  2.5294e-08, -1.3870e-14,
-        -4.4878e-05,  2.6006e-02,  1.3175e-01, -3.6325e-02, -5.0828e-06,
-         2.8867e-18, -1.1345e-12,  4.0801e-16,  4.8681e-01, -5.8925e-08,
-         1.0899e-15, -6.1548e-02,  9.1572e-04,  4.8569e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2622,  0.0000, -0.1965, -0.4514,  2.2790,  0.2205,  0.0000,  0.0000,
-         0.0162,  0.2687, -0.2625,  0.0000, -0.5422, -0.1570,  0.0000,  0.0000,
-         0.0000,  0.1285,  0.0000,  0.0000, -0.2328,  0.2643,  0.0000, -1.0495,
-         0.1115,  0.0000, -0.4286,  0.1154, -0.2944,  0.0000, -0.5105, -1.1998,
-         0.0000, -0.2540,  0.0000,  0.0000,  0.0000,  0.0000,  0.6042,  0.4144,
-         0.2325,  0.0000,  0.5597,  0.0000,  0.0000, -0.1597,  0.0000,  0.0772,
-         0.0000,  0.0000,  0.0000,  0.0260,  0.1317, -0.0363,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4868,  0.0000,  0.0000, -0.0615,  0.0000,  0.0486],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2622,  0.0000, -0.1965, -0.4514,  2.2790,  0.2205,  0.0000,  0.0000,
-         0.0162,  0.2687, -0.2625,  0.0000, -0.5422, -0.1570,  0.0000,  0.0000,
-         0.0000,  0.1285,  0.0000,  0.0000, -0.2328,  0.2643,  0.0000, -1.0495,
-         0.1115,  0.0000, -0.4286,  0.1154, -0.2944,  0.0000, -0.5105, -1.1998,
-         0.0000, -0.2540,  0.0000,  0.0000,  0.0000,  0.0000,  0.6042,  0.4144,
-         0.2325,  0.0000,  0.5597,  0.0000,  0.0000, -0.1597,  0.0000,  0.0772,
-         0.0000,  0.0000,  0.0000,  0.0260,  0.1317, -0.0363,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4868,  0.0000,  0.0000, -0.0615,  0.0000,  0.0486],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4729e-01, -3.2823e-07, -1.7901e-01, -4.0970e-01,  2.2740e+00,
-         2.4470e-01,  1.9868e-16, -1.5688e-09,  3.2867e-02,  2.8244e-01,
-        -2.6277e-01,  3.5718e-08, -5.1528e-01, -1.2659e-01, -3.5629e-13,
-        -2.6481e-10, -2.6048e-11,  1.4875e-01, -6.0247e-14,  7.2437e-10,
-        -1.9626e-01,  2.4790e-01, -1.2794e-13, -1.0417e+00,  1.2096e-01,
-         1.0384e-09, -4.1846e-01,  1.5017e-01, -2.7403e-01,  1.8608e-13,
-        -5.2147e-01, -1.1907e+00, -4.7632e-17, -2.3139e-01,  2.3642e-11,
-         1.2756e-09,  7.8864e-13,  0.0000e+00,  5.6864e-01,  4.1168e-01,
-         2.1829e-01, -1.0534e-10,  5.5237e-01,  1.9409e-13,  4.7473e-11,
-        -2.0708e-01, -6.6983e-09,  9.8782e-02,  2.2533e-08, -1.2356e-14,
-        -3.9979e-05,  7.1602e-02,  5.5880e-02, -1.7080e-02, -4.5280e-06,
-         2.5716e-18, -1.0107e-12,  3.6347e-16,  5.0786e-01, -5.2493e-08,
-         9.7092e-16, -2.6973e-02,  8.1576e-04,  3.9182e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2473,  0.0000, -0.1790, -0.4097,  2.2740,  0.2447,  0.0000,  0.0000,
-         0.0329,  0.2824, -0.2628,  0.0000, -0.5153, -0.1266,  0.0000,  0.0000,
-         0.0000,  0.1488,  0.0000,  0.0000, -0.1963,  0.2479,  0.0000, -1.0417,
-         0.1210,  0.0000, -0.4185,  0.1502, -0.2740,  0.0000, -0.5215, -1.1907,
-         0.0000, -0.2314,  0.0000,  0.0000,  0.0000,  0.0000,  0.5686,  0.4117,
-         0.2183,  0.0000,  0.5524,  0.0000,  0.0000, -0.2071,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716,  0.0559, -0.0171,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5079,  0.0000,  0.0000, -0.0270,  0.0000,  0.0392],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2473,  0.0000, -0.1790, -0.4097,  2.2740,  0.2447,  0.0000,  0.0000,
-         0.0329,  0.2824, -0.2628,  0.0000, -0.5153, -0.1266,  0.0000,  0.0000,
-         0.0000,  0.1488,  0.0000,  0.0000, -0.1963,  0.2479,  0.0000, -1.0417,
-         0.1210,  0.0000, -0.4185,  0.1502, -0.2740,  0.0000, -0.5215, -1.1907,
-         0.0000, -0.2314,  0.0000,  0.0000,  0.0000,  0.0000,  0.5686,  0.4117,
-         0.2183,  0.0000,  0.5524,  0.0000,  0.0000, -0.2071,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716,  0.0559, -0.0171,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5079,  0.0000,  0.0000, -0.0270,  0.0000,  0.0392],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4479e-01, -2.9252e-07, -1.1882e-01, -3.8066e-01,  2.2670e+00,
-         2.6597e-01,  1.7707e-16, -1.3981e-09,  4.7311e-02,  2.6543e-01,
-        -2.8268e-01,  3.1832e-08, -4.7664e-01, -1.1953e-01, -3.1753e-13,
-        -2.3600e-10, -2.3214e-11,  1.1542e-01, -5.3693e-14,  6.4556e-10,
-        -1.6699e-01,  1.9691e-01, -1.1402e-13, -1.0153e+00,  9.2774e-02,
-         9.2542e-10, -3.9104e-01,  1.4197e-01, -2.2725e-01,  1.6583e-13,
-        -5.4870e-01, -1.1838e+00, -4.2449e-17, -1.7033e-01,  2.1070e-11,
-         1.1368e-09,  7.0284e-13,  0.0000e+00,  5.0972e-01,  4.0626e-01,
-         1.5430e-01, -9.3876e-11,  5.5455e-01,  1.7298e-13,  4.2308e-11,
-        -2.3094e-01, -5.9695e-09,  1.9262e-02,  2.0082e-08, -1.1012e-14,
-        -3.5630e-05,  1.1576e-01,  1.3987e-02, -3.6454e-02, -4.0354e-06,
-         2.2918e-18, -9.0071e-13,  3.2393e-16,  5.1283e-01, -4.6782e-08,
-         8.6529e-16,  2.2108e-02,  7.2701e-04,  3.3254e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2448,  0.0000, -0.1188, -0.3807,  2.2670,  0.2660,  0.0000,  0.0000,
-         0.0473,  0.2654, -0.2827,  0.0000, -0.4766, -0.1195,  0.0000,  0.0000,
-         0.0000,  0.1154,  0.0000,  0.0000, -0.1670,  0.1969,  0.0000, -1.0153,
-         0.0928,  0.0000, -0.3910,  0.1420, -0.2272,  0.0000, -0.5487, -1.1838,
-         0.0000, -0.1703,  0.0000,  0.0000,  0.0000,  0.0000,  0.5097,  0.4063,
-         0.1543,  0.0000,  0.5546,  0.0000,  0.0000, -0.2309,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1158,  0.0140, -0.0365,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5128,  0.0000,  0.0000,  0.0221,  0.0000,  0.0333],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2448,  0.0000, -0.1188, -0.3807,  2.2670,  0.2660,  0.0000,  0.0000,
-         0.0473,  0.2654, -0.2827,  0.0000, -0.4766, -0.1195,  0.0000,  0.0000,
-         0.0000,  0.1154,  0.0000,  0.0000, -0.1670,  0.1969,  0.0000, -1.0153,
-         0.0928,  0.0000, -0.3910,  0.1420, -0.2272,  0.0000, -0.5487, -1.1838,
-         0.0000, -0.1703,  0.0000,  0.0000,  0.0000,  0.0000,  0.5097,  0.4063,
-         0.1543,  0.0000,  0.5546,  0.0000,  0.0000, -0.2309,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1158,  0.0140, -0.0365,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5128,  0.0000,  0.0000,  0.0221,  0.0000,  0.0333],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.4992e-01, -2.6080e-07, -3.5197e-02, -3.6707e-01,  2.2632e+00,
-         3.0179e-01,  1.5787e-16, -1.2465e-09,  1.8097e-02,  2.5864e-01,
-        -2.7572e-01,  2.8380e-08, -4.4054e-01, -1.3543e-01, -2.8310e-13,
-        -2.1041e-10, -2.0697e-11,  3.7034e-02, -4.7870e-14,  5.7556e-10,
-        -1.4987e-01,  1.2770e-01, -1.0166e-13, -9.7505e-01,  4.7262e-02,
-         8.2507e-10, -3.6630e-01,  1.1017e-01, -1.7136e-01,  1.4785e-13,
-        -5.7439e-01, -1.1744e+00, -3.7846e-17, -8.6629e-02,  1.8785e-11,
-         1.0136e-09,  6.2663e-13,  0.0000e+00,  4.3167e-01,  3.9812e-01,
-         6.1886e-02, -8.3696e-11,  5.6700e-01,  1.5422e-13,  3.7720e-11,
-        -2.7326e-01, -5.3222e-09,  1.7173e-02,  1.7904e-08, -9.8179e-15,
-        -3.1766e-05,  1.4870e-01, -3.1268e-02, -1.0266e-01, -3.5978e-06,
-         2.0433e-18, -8.0304e-13,  2.8880e-16,  5.0206e-01, -4.1709e-08,
-         7.7146e-16,  4.1251e-02,  6.4817e-04,  3.6656e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2499,  0.0000, -0.0352, -0.3671,  2.2632,  0.3018,  0.0000,  0.0000,
-         0.0181,  0.2586, -0.2757,  0.0000, -0.4405, -0.1354,  0.0000,  0.0000,
-         0.0000,  0.0370,  0.0000,  0.0000, -0.1499,  0.1277,  0.0000, -0.9751,
-         0.0473,  0.0000, -0.3663,  0.1102, -0.1714,  0.0000, -0.5744, -1.1744,
-         0.0000, -0.0866,  0.0000,  0.0000,  0.0000,  0.0000,  0.4317,  0.3981,
-         0.0619,  0.0000,  0.5670,  0.0000,  0.0000, -0.2733,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1487, -0.0313, -0.1027,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5021,  0.0000,  0.0000,  0.0413,  0.0000,  0.0367],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2499,  0.0000, -0.0352, -0.3671,  2.2632,  0.3018,  0.0000,  0.0000,
-         0.0181,  0.2586, -0.2757,  0.0000, -0.4405, -0.1354,  0.0000,  0.0000,
-         0.0000,  0.0370,  0.0000,  0.0000, -0.1499,  0.1277,  0.0000, -0.9751,
-         0.0473,  0.0000, -0.3663,  0.1102, -0.1714,  0.0000, -0.5744, -1.1744,
-         0.0000, -0.0866,  0.0000,  0.0000,  0.0000,  0.0000,  0.4317,  0.3981,
-         0.0619,  0.0000,  0.5670,  0.0000,  0.0000, -0.2733,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1487, -0.0313, -0.1027,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5021,  0.0000,  0.0000,  0.0413,  0.0000,  0.0367],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6803e-01, -2.3261e-07,  4.8139e-02, -3.5141e-01,  2.2595e+00,
-         3.1105e-01,  1.4080e-16, -1.1118e-09, -2.4896e-02,  2.4683e-01,
-        -2.2702e-01,  2.5313e-08, -3.9527e-01, -1.7536e-01, -2.5250e-13,
-        -1.8767e-10, -1.8460e-11, -3.8987e-02, -4.2697e-14,  5.1336e-10,
-        -1.3224e-01,  6.4320e-02, -9.0672e-14, -9.1843e-01,  1.0646e-02,
-         7.3590e-10, -3.4159e-01,  7.1687e-02, -1.1190e-01,  1.3187e-13,
-        -6.1279e-01, -1.1612e+00, -3.3756e-17,  2.5557e-03,  1.6755e-11,
-         9.0403e-10,  5.5890e-13,  0.0000e+00,  3.6115e-01,  4.0053e-01,
-        -4.1032e-04, -7.4651e-11,  5.7768e-01,  1.3755e-13,  3.3643e-11,
-        -3.3263e-01, -4.7470e-09,  1.5317e-02,  1.5969e-08, -8.7568e-15,
-        -2.8333e-05,  2.0060e-01, -5.1673e-02, -1.4239e-01, -3.2089e-06,
-         1.8224e-18, -7.1625e-13,  2.5759e-16,  4.5233e-01, -3.7201e-08,
-         6.8808e-16,  5.7055e-02,  5.7812e-04,  6.7388e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.6803e-01,  0.0000e+00,  4.8139e-02, -3.5141e-01,  2.2595e+00,
-         3.1105e-01,  0.0000e+00,  0.0000e+00, -2.4896e-02,  2.4683e-01,
-        -2.2702e-01,  0.0000e+00, -3.9527e-01, -1.7536e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.8987e-02,  0.0000e+00,  0.0000e+00,
-        -1.3224e-01,  6.4320e-02,  0.0000e+00, -9.1843e-01,  1.0646e-02,
-         0.0000e+00, -3.4159e-01,  7.1687e-02, -1.1190e-01,  0.0000e+00,
-        -6.1279e-01, -1.1612e+00,  0.0000e+00,  2.5557e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6115e-01,  4.0053e-01,
-        -4.1032e-04,  0.0000e+00,  5.7768e-01,  0.0000e+00,  0.0000e+00,
-        -3.3263e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.0060e-01, -5.1673e-02, -1.4239e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.5233e-01,  0.0000e+00,
-         0.0000e+00,  5.7055e-02,  0.0000e+00,  6.7388e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.6803e-01,  0.0000e+00,  4.8139e-02, -3.5141e-01,  2.2595e+00,
-         3.1105e-01,  0.0000e+00,  0.0000e+00, -2.4896e-02,  2.4683e-01,
-        -2.2702e-01,  0.0000e+00, -3.9527e-01, -1.7536e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.8987e-02,  0.0000e+00,  0.0000e+00,
-        -1.3224e-01,  6.4320e-02,  0.0000e+00, -9.1843e-01,  1.0646e-02,
-         0.0000e+00, -3.4159e-01,  7.1687e-02, -1.1190e-01,  0.0000e+00,
-        -6.1279e-01, -1.1612e+00,  0.0000e+00,  2.5557e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6115e-01,  4.0053e-01,
-        -4.1032e-04,  0.0000e+00,  5.7768e-01,  0.0000e+00,  0.0000e+00,
-        -3.3263e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.0060e-01, -5.1673e-02, -1.4239e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.5233e-01,  0.0000e+00,
-         0.0000e+00,  5.7055e-02,  0.0000e+00,  6.7388e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7356e-01, -2.0755e-07,  9.9293e-02, -3.2322e-01,  2.2570e+00,
-         3.3381e-01,  1.2564e-16, -9.9204e-10, -9.8128e-02,  2.4154e-01,
-        -1.8237e-01,  2.2586e-08, -4.0294e-01, -2.5021e-01, -2.2530e-13,
-        -1.6745e-10, -1.6472e-11, -5.8983e-02, -3.8097e-14,  4.5806e-10,
-        -1.3613e-01,  1.1942e-02, -8.0905e-14, -8.6811e-01, -5.4647e-03,
-         6.5663e-10, -2.9492e-01,  2.5949e-02, -4.6723e-02,  1.1766e-13,
-        -5.9608e-01, -1.1521e+00, -3.0120e-17,  8.9923e-02,  1.4950e-11,
-         8.0665e-10,  4.9870e-13,  0.0000e+00,  2.8769e-01,  4.0196e-01,
-        -5.7089e-02, -6.6609e-11,  5.7148e-01,  1.2273e-13,  3.0019e-11,
-        -3.9640e-01, -4.2357e-09,  1.3667e-02,  1.4249e-08, -7.8135e-15,
-        -2.5281e-05,  1.9112e-01, -9.4021e-02, -1.4426e-02, -2.8633e-06,
-         1.6261e-18, -6.3909e-13,  2.2984e-16,  4.0475e-01, -3.3194e-08,
-         6.1396e-16,  7.3316e-02,  5.1585e-04,  9.2102e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2736,  0.0000,  0.0993, -0.3232,  2.2570,  0.3338,  0.0000,  0.0000,
-        -0.0981,  0.2415, -0.1824,  0.0000, -0.4029, -0.2502,  0.0000,  0.0000,
-         0.0000, -0.0590,  0.0000,  0.0000, -0.1361,  0.0119,  0.0000, -0.8681,
-        -0.0055,  0.0000, -0.2949,  0.0259, -0.0467,  0.0000, -0.5961, -1.1521,
-         0.0000,  0.0899,  0.0000,  0.0000,  0.0000,  0.0000,  0.2877,  0.4020,
-        -0.0571,  0.0000,  0.5715,  0.0000,  0.0000, -0.3964,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1911, -0.0940, -0.0144,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4048,  0.0000,  0.0000,  0.0733,  0.0000,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2736,  0.0000,  0.0993, -0.3232,  2.2570,  0.3338,  0.0000,  0.0000,
-        -0.0981,  0.2415, -0.1824,  0.0000, -0.4029, -0.2502,  0.0000,  0.0000,
-         0.0000, -0.0590,  0.0000,  0.0000, -0.1361,  0.0119,  0.0000, -0.8681,
-        -0.0055,  0.0000, -0.2949,  0.0259, -0.0467,  0.0000, -0.5961, -1.1521,
-         0.0000,  0.0899,  0.0000,  0.0000,  0.0000,  0.0000,  0.2877,  0.4020,
-        -0.0571,  0.0000,  0.5715,  0.0000,  0.0000, -0.3964,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1911, -0.0940, -0.0144,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4048,  0.0000,  0.0000,  0.0733,  0.0000,  0.0921],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6572e-01, -1.8527e-07,  9.8344e-02, -2.9523e-01,  2.2545e+00,
-         3.1931e-01,  1.1215e-16, -8.8554e-10, -1.8339e-01,  2.3818e-01,
-        -1.2468e-01,  2.0161e-08, -4.5262e-01, -3.1846e-01, -2.0111e-13,
-        -1.4947e-10, -1.4703e-11, -5.5978e-02, -3.4007e-14,  4.0888e-10,
-        -1.2456e-01, -4.6265e-02, -7.2219e-14, -8.1916e-01,  2.3146e-02,
-         5.8614e-10, -2.5415e-01,  2.6548e-02,  4.6869e-02,  1.0503e-13,
-        -5.6240e-01, -1.1474e+00, -2.6886e-17,  1.5230e-01,  1.3345e-11,
-         7.2005e-10,  4.4516e-13,  0.0000e+00,  2.3427e-01,  4.1774e-01,
-        -4.9406e-02, -5.9458e-11,  5.4426e-01,  1.0956e-13,  2.6797e-11,
-        -4.4156e-01, -3.7809e-09,  1.2200e-02,  1.2719e-08, -6.9747e-15,
-        -2.2567e-05,  1.8730e-01, -1.2091e-01,  8.0518e-02, -2.5559e-06,
-         1.4515e-18, -5.7048e-13,  2.0517e-16,  3.5350e-01, -2.9630e-08,
-         5.4805e-16,  4.2613e-02,  4.6046e-04,  5.9167e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2657,  0.0000,  0.0983, -0.2952,  2.2545,  0.3193,  0.0000,  0.0000,
-        -0.1834,  0.2382, -0.1247,  0.0000, -0.4526, -0.3185,  0.0000,  0.0000,
-         0.0000, -0.0560,  0.0000,  0.0000, -0.1246, -0.0463,  0.0000, -0.8192,
-         0.0231,  0.0000, -0.2541,  0.0265,  0.0469,  0.0000, -0.5624, -1.1474,
-         0.0000,  0.1523,  0.0000,  0.0000,  0.0000,  0.0000,  0.2343,  0.4177,
-        -0.0494,  0.0000,  0.5443,  0.0000,  0.0000, -0.4416,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1873, -0.1209,  0.0805,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3535,  0.0000,  0.0000,  0.0426,  0.0000,  0.0592],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2657,  0.0000,  0.0983, -0.2952,  2.2545,  0.3193,  0.0000,  0.0000,
-        -0.1834,  0.2382, -0.1247,  0.0000, -0.4526, -0.3185,  0.0000,  0.0000,
-         0.0000, -0.0560,  0.0000,  0.0000, -0.1246, -0.0463,  0.0000, -0.8192,
-         0.0231,  0.0000, -0.2541,  0.0265,  0.0469,  0.0000, -0.5624, -1.1474,
-         0.0000,  0.1523,  0.0000,  0.0000,  0.0000,  0.0000,  0.2343,  0.4177,
-        -0.0494,  0.0000,  0.5443,  0.0000,  0.0000, -0.4416,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1873, -0.1209,  0.0805,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3535,  0.0000,  0.0000,  0.0426,  0.0000,  0.0592],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7571e-01, -1.6545e-07,  8.0222e-02, -2.6634e-01,  2.2515e+00,
-         2.8470e-01,  1.0015e-16, -7.9078e-10, -2.5568e-01,  2.0939e-01,
-        -6.5277e-02,  1.8004e-08, -4.8120e-01, -3.7952e-01, -1.7959e-13,
-        -1.3348e-10, -1.3130e-11, -8.5133e-02, -3.0369e-14,  3.6513e-10,
-        -1.0325e-01, -1.1270e-01, -6.4491e-14, -7.7804e-01,  4.6319e-02,
-         5.2342e-10, -2.1807e-01,  4.7850e-02,  1.4303e-01,  9.3794e-14,
-        -5.4595e-01, -1.1440e+00, -2.4009e-17,  1.9695e-01,  1.1917e-11,
-         6.4300e-10,  3.9753e-13,  0.0000e+00,  1.7851e-01,  4.4968e-01,
-        -1.4693e-02, -5.3096e-11,  5.0782e-01,  9.7835e-14,  2.3929e-11,
-        -4.7855e-01, -3.3764e-09,  1.0895e-02,  1.1358e-08, -6.2284e-15,
-        -2.0152e-05,  2.0034e-01, -1.3090e-01,  8.2256e-02, -2.2824e-06,
-         1.2962e-18, -5.0944e-13,  1.8321e-16,  2.9210e-01, -2.6460e-08,
-         4.8940e-16, -9.6377e-03,  4.1119e-04,  3.0739e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2757,  0.0000,  0.0802, -0.2663,  2.2515,  0.2847,  0.0000,  0.0000,
-        -0.2557,  0.2094, -0.0653,  0.0000, -0.4812, -0.3795,  0.0000,  0.0000,
-         0.0000, -0.0851,  0.0000,  0.0000, -0.1033, -0.1127,  0.0000, -0.7780,
-         0.0463,  0.0000, -0.2181,  0.0479,  0.1430,  0.0000, -0.5459, -1.1440,
-         0.0000,  0.1970,  0.0000,  0.0000,  0.0000,  0.0000,  0.1785,  0.4497,
-        -0.0147,  0.0000,  0.5078,  0.0000,  0.0000, -0.4785,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2003, -0.1309,  0.0823,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2921,  0.0000,  0.0000, -0.0096,  0.0000,  0.0031],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2757,  0.0000,  0.0802, -0.2663,  2.2515,  0.2847,  0.0000,  0.0000,
-        -0.2557,  0.2094, -0.0653,  0.0000, -0.4812, -0.3795,  0.0000,  0.0000,
-         0.0000, -0.0851,  0.0000,  0.0000, -0.1033, -0.1127,  0.0000, -0.7780,
-         0.0463,  0.0000, -0.2181,  0.0479,  0.1430,  0.0000, -0.5459, -1.1440,
-         0.0000,  0.1970,  0.0000,  0.0000,  0.0000,  0.0000,  0.1785,  0.4497,
-        -0.0147,  0.0000,  0.5078,  0.0000,  0.0000, -0.4785,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2003, -0.1309,  0.0823,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2921,  0.0000,  0.0000, -0.0096,  0.0000,  0.0031],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7288e-01, -1.4780e-07,  4.7549e-02, -2.5731e-01,  2.2471e+00,
-         2.3969e-01,  8.9468e-17, -7.0645e-10, -3.0470e-01,  1.5107e-01,
-         8.2497e-03,  1.6084e-08, -4.9956e-01, -4.3502e-01, -1.6044e-13,
-        -1.1925e-10, -1.1730e-11, -1.2998e-01, -2.7130e-14,  3.2619e-10,
-        -6.7593e-02, -1.6429e-01, -5.7614e-14, -7.3653e-01,  8.5602e-02,
-         4.6760e-10, -2.1261e-01,  7.5673e-02,  2.1109e-01,  8.3791e-14,
-        -5.4078e-01, -1.1329e+00, -2.1449e-17,  2.1160e-01,  1.0646e-11,
-         5.7443e-10,  3.5513e-13,  0.0000e+00,  1.1443e-01,  4.8148e-01,
-         3.8393e-02, -4.7434e-11,  4.6827e-01,  8.7401e-14,  2.1377e-11,
-        -5.0722e-01, -3.0163e-09,  9.7328e-03,  1.0147e-08, -5.5641e-15,
-        -1.8003e-05,  2.3194e-01, -1.3057e-01,  3.5615e-02, -2.0390e-06,
-         1.1580e-18, -4.5511e-13,  1.6367e-16,  2.3573e-01, -2.3638e-08,
-         4.3721e-16, -7.5281e-02,  3.6734e-04, -5.9831e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2729,  0.0000,  0.0475, -0.2573,  2.2471,  0.2397,  0.0000,  0.0000,
-        -0.3047,  0.1511,  0.0082,  0.0000, -0.4996, -0.4350,  0.0000,  0.0000,
-         0.0000, -0.1300,  0.0000,  0.0000, -0.0676, -0.1643,  0.0000, -0.7365,
-         0.0856,  0.0000, -0.2126,  0.0757,  0.2111,  0.0000, -0.5408, -1.1329,
-         0.0000,  0.2116,  0.0000,  0.0000,  0.0000,  0.0000,  0.1144,  0.4815,
-         0.0384,  0.0000,  0.4683,  0.0000,  0.0000, -0.5072,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2319, -0.1306,  0.0356,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2357,  0.0000,  0.0000, -0.0753,  0.0000, -0.0598],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2729,  0.0000,  0.0475, -0.2573,  2.2471,  0.2397,  0.0000,  0.0000,
-        -0.3047,  0.1511,  0.0082,  0.0000, -0.4996, -0.4350,  0.0000,  0.0000,
-         0.0000, -0.1300,  0.0000,  0.0000, -0.0676, -0.1643,  0.0000, -0.7365,
-         0.0856,  0.0000, -0.2126,  0.0757,  0.2111,  0.0000, -0.5408, -1.1329,
-         0.0000,  0.2116,  0.0000,  0.0000,  0.0000,  0.0000,  0.1144,  0.4815,
-         0.0384,  0.0000,  0.4683,  0.0000,  0.0000, -0.5072,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2319, -0.1306,  0.0356,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2357,  0.0000,  0.0000, -0.0753,  0.0000, -0.0598],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.5661e-01, -1.3209e-07,  3.8956e-02, -2.2120e-01,  2.2405e+00,
-         2.3968e-01,  7.9959e-17, -6.3137e-10, -3.4922e-01,  7.2618e-02,
-         2.0195e-02,  1.4374e-08, -5.1851e-01, -4.7075e-01, -1.4339e-13,
-        -1.0657e-10, -1.0483e-11, -1.9829e-01, -2.4246e-14,  2.9152e-10,
-        -3.9259e-02, -2.3194e-01, -5.1490e-14, -7.1924e-01,  6.7476e-02,
-         4.1790e-10, -1.7889e-01,  7.3577e-02,  2.3312e-01,  7.4885e-14,
-        -5.3612e-01, -1.1253e+00, -1.9169e-17,  1.7331e-01,  9.5146e-12,
-         5.1337e-10,  3.1739e-13,  0.0000e+00,  8.0527e-02,  5.0698e-01,
-         2.7682e-02, -4.2392e-11,  4.3361e-01,  7.8112e-14,  1.9105e-11,
-        -5.1550e-01, -2.6957e-09,  8.6983e-03,  9.0683e-09, -4.9727e-15,
-        -1.6090e-05,  2.0031e-01, -1.3471e-01,  1.5051e-02, -1.8223e-06,
-         1.0349e-18, -4.0674e-13,  1.4628e-16,  1.7542e-01, -2.1126e-08,
-         3.9074e-16, -1.2333e-01,  3.2830e-04, -9.3982e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2566,  0.0000,  0.0390, -0.2212,  2.2405,  0.2397,  0.0000,  0.0000,
-        -0.3492,  0.0726,  0.0202,  0.0000, -0.5185, -0.4708,  0.0000,  0.0000,
-         0.0000, -0.1983,  0.0000,  0.0000, -0.0393, -0.2319,  0.0000, -0.7192,
-         0.0675,  0.0000, -0.1789,  0.0736,  0.2331,  0.0000, -0.5361, -1.1253,
-         0.0000,  0.1733,  0.0000,  0.0000,  0.0000,  0.0000,  0.0805,  0.5070,
-         0.0277,  0.0000,  0.4336,  0.0000,  0.0000, -0.5155,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2003, -0.1347,  0.0151,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1754,  0.0000,  0.0000, -0.1233,  0.0000, -0.0940],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2566,  0.0000,  0.0390, -0.2212,  2.2405,  0.2397,  0.0000,  0.0000,
-        -0.3492,  0.0726,  0.0202,  0.0000, -0.5185, -0.4708,  0.0000,  0.0000,
-         0.0000, -0.1983,  0.0000,  0.0000, -0.0393, -0.2319,  0.0000, -0.7192,
-         0.0675,  0.0000, -0.1789,  0.0736,  0.2331,  0.0000, -0.5361, -1.1253,
-         0.0000,  0.1733,  0.0000,  0.0000,  0.0000,  0.0000,  0.0805,  0.5070,
-         0.0277,  0.0000,  0.4336,  0.0000,  0.0000, -0.5155,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.2003, -0.1347,  0.0151,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1754,  0.0000,  0.0000, -0.1233,  0.0000, -0.0940],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1939e-01, -1.1810e-07,  2.7222e-02, -1.8902e-01,  2.2314e+00,
-         2.6210e-01,  7.1489e-17, -5.6448e-10, -3.6548e-01, -3.9029e-03,
-         2.2403e-02,  1.2852e-08, -5.2135e-01, -5.0087e-01, -1.2820e-13,
-        -9.5282e-11, -9.3725e-12, -2.3963e-01, -2.1678e-14,  2.6064e-10,
-        -1.1604e-02, -2.7514e-01, -4.6036e-14, -7.1530e-01,  4.8151e-02,
-         3.7363e-10, -1.4721e-01,  6.8678e-02,  2.1815e-01,  6.6953e-14,
-        -5.2438e-01, -1.1221e+00, -1.7139e-17,  9.7202e-02,  8.5067e-12,
-         4.5899e-10,  2.8376e-13,  0.0000e+00,  5.6664e-02,  5.0921e-01,
-        -4.7499e-04, -3.7902e-11,  3.8851e-01,  6.9837e-14,  1.7081e-11,
-        -4.7823e-01, -2.4101e-09,  7.7769e-03,  8.1077e-09, -4.4460e-15,
-        -1.4385e-05,  1.4134e-01, -1.3079e-01,  3.9752e-02, -1.6292e-06,
-         9.2528e-19, -3.6365e-13,  1.3078e-16,  1.3548e-01, -1.8888e-08,
-         3.4935e-16, -1.4356e-01,  2.9352e-04, -1.1890e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.1939e-01,  0.0000e+00,  2.7222e-02, -1.8902e-01,  2.2314e+00,
-         2.6210e-01,  0.0000e+00,  0.0000e+00, -3.6548e-01, -3.9029e-03,
-         2.2403e-02,  0.0000e+00, -5.2135e-01, -5.0087e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.3963e-01,  0.0000e+00,  0.0000e+00,
-        -1.1604e-02, -2.7514e-01,  0.0000e+00, -7.1530e-01,  4.8151e-02,
-         0.0000e+00, -1.4721e-01,  6.8678e-02,  2.1815e-01,  0.0000e+00,
-        -5.2438e-01, -1.1221e+00,  0.0000e+00,  9.7202e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.6664e-02,  5.0921e-01,
-        -4.7499e-04,  0.0000e+00,  3.8851e-01,  0.0000e+00,  0.0000e+00,
-        -4.7823e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.4134e-01, -1.3079e-01,  3.9752e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3548e-01,  0.0000e+00,
-         0.0000e+00, -1.4356e-01,  0.0000e+00, -1.1890e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.1939e-01,  0.0000e+00,  2.7222e-02, -1.8902e-01,  2.2314e+00,
-         2.6210e-01,  0.0000e+00,  0.0000e+00, -3.6548e-01, -3.9029e-03,
-         2.2403e-02,  0.0000e+00, -5.2135e-01, -5.0087e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.3963e-01,  0.0000e+00,  0.0000e+00,
-        -1.1604e-02, -2.7514e-01,  0.0000e+00, -7.1530e-01,  4.8151e-02,
-         0.0000e+00, -1.4721e-01,  6.8678e-02,  2.1815e-01,  0.0000e+00,
-        -5.2438e-01, -1.1221e+00,  0.0000e+00,  9.7202e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.6664e-02,  5.0921e-01,
-        -4.7499e-04,  0.0000e+00,  3.8851e-01,  0.0000e+00,  0.0000e+00,
-        -4.7823e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.4134e-01, -1.3079e-01,  3.9752e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3548e-01,  0.0000e+00,
-         0.0000e+00, -1.4356e-01,  0.0000e+00, -1.1890e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.8682e-01, -1.0563e-07,  3.0226e-02, -1.5479e-01,  2.2231e+00,
-         2.9966e-01,  6.3941e-17, -5.0489e-10, -3.6898e-01, -7.6030e-02,
-         2.2792e-02,  1.1495e-08, -4.9795e-01, -5.3146e-01, -1.1466e-13,
-        -8.5223e-11, -8.3830e-12, -2.4933e-01, -1.9389e-14,  2.3312e-10,
-         2.7749e-03, -3.1650e-01, -4.1176e-14, -7.3483e-01,  3.4095e-03,
-         3.3418e-10, -9.9256e-02,  4.8092e-02,  1.7809e-01,  5.9884e-14,
-        -5.2319e-01, -1.1200e+00, -1.5329e-17, -2.6731e-02,  7.6086e-12,
-         4.1053e-10,  2.5381e-13,  0.0000e+00,  2.8062e-02,  4.9504e-01,
-        -5.0031e-02, -3.3900e-11,  3.4239e-01,  6.2464e-14,  1.5278e-11,
-        -4.3122e-01, -2.1557e-09,  6.9558e-03,  7.2517e-09, -3.9766e-15,
-        -1.2866e-05,  7.0491e-02, -1.1964e-01,  1.6164e-01, -1.4572e-06,
-         8.2759e-19, -3.2526e-13,  1.1697e-16,  1.2076e-01, -1.6894e-08,
-         3.1247e-16, -1.4281e-01,  2.6253e-04, -1.2533e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1868,  0.0000,  0.0302, -0.1548,  2.2231,  0.2997,  0.0000,  0.0000,
-        -0.3690, -0.0760,  0.0228,  0.0000, -0.4979, -0.5315,  0.0000,  0.0000,
-         0.0000, -0.2493,  0.0000,  0.0000,  0.0028, -0.3165,  0.0000, -0.7348,
-         0.0034,  0.0000, -0.0993,  0.0481,  0.1781,  0.0000, -0.5232, -1.1200,
-         0.0000, -0.0267,  0.0000,  0.0000,  0.0000,  0.0000,  0.0281,  0.4950,
-        -0.0500,  0.0000,  0.3424,  0.0000,  0.0000, -0.4312,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0705, -0.1196,  0.1616,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1208,  0.0000,  0.0000, -0.1428,  0.0000, -0.1253],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1868,  0.0000,  0.0302, -0.1548,  2.2231,  0.2997,  0.0000,  0.0000,
-        -0.3690, -0.0760,  0.0228,  0.0000, -0.4979, -0.5315,  0.0000,  0.0000,
-         0.0000, -0.2493,  0.0000,  0.0000,  0.0028, -0.3165,  0.0000, -0.7348,
-         0.0034,  0.0000, -0.0993,  0.0481,  0.1781,  0.0000, -0.5232, -1.1200,
-         0.0000, -0.0267,  0.0000,  0.0000,  0.0000,  0.0000,  0.0281,  0.4950,
-        -0.0500,  0.0000,  0.3424,  0.0000,  0.0000, -0.4312,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0705, -0.1196,  0.1616,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1208,  0.0000,  0.0000, -0.1428,  0.0000, -0.1253],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6420e-01, -9.4518e-08,  5.2378e-03, -1.3065e-01,  2.2171e+00,
-         3.1887e-01,  5.7213e-17, -4.5176e-10, -3.4461e-01, -1.3023e-01,
-         9.9437e-02,  1.0285e-08, -4.7275e-01, -5.5808e-01, -1.0260e-13,
-        -7.6256e-11, -7.5009e-12, -2.3745e-01, -1.7349e-14,  2.0859e-10,
-         3.2309e-02, -3.1658e-01, -3.6843e-14, -7.3189e-01,  2.5932e-03,
-         2.9902e-10, -7.4248e-02,  5.3871e-02,  1.5104e-01,  5.3583e-14,
-        -5.3727e-01, -1.1164e+00, -1.3716e-17, -1.4623e-01,  6.8080e-12,
-         3.6734e-10,  2.2710e-13,  0.0000e+00, -2.0704e-02,  4.7388e-01,
-        -5.2288e-02, -3.0333e-11,  2.8249e-01,  5.5891e-14,  1.3670e-11,
-        -3.6492e-01, -1.9289e-09,  6.2239e-03,  6.4887e-09, -3.5582e-15,
-        -1.1513e-05,  4.7448e-02, -8.0751e-02,  2.6581e-01, -1.3039e-06,
-         7.4051e-19, -2.9103e-13,  1.0467e-16,  1.0764e-01, -1.5116e-08,
-         2.7959e-16, -1.4402e-01,  2.3491e-04, -1.0754e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1642,  0.0000,  0.0052, -0.1307,  2.2171,  0.3189,  0.0000,  0.0000,
-        -0.3446, -0.1302,  0.0994,  0.0000, -0.4728, -0.5581,  0.0000,  0.0000,
-         0.0000, -0.2375,  0.0000,  0.0000,  0.0323, -0.3166,  0.0000, -0.7319,
-         0.0026,  0.0000, -0.0742,  0.0539,  0.1510,  0.0000, -0.5373, -1.1164,
-         0.0000, -0.1462,  0.0000,  0.0000,  0.0000,  0.0000, -0.0207,  0.4739,
-        -0.0523,  0.0000,  0.2825,  0.0000,  0.0000, -0.3649,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0474, -0.0808,  0.2658,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1076,  0.0000,  0.0000, -0.1440,  0.0000, -0.1075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1642,  0.0000,  0.0052, -0.1307,  2.2171,  0.3189,  0.0000,  0.0000,
-        -0.3446, -0.1302,  0.0994,  0.0000, -0.4728, -0.5581,  0.0000,  0.0000,
-         0.0000, -0.2375,  0.0000,  0.0000,  0.0323, -0.3166,  0.0000, -0.7319,
-         0.0026,  0.0000, -0.0742,  0.0539,  0.1510,  0.0000, -0.5373, -1.1164,
-         0.0000, -0.1462,  0.0000,  0.0000,  0.0000,  0.0000, -0.0207,  0.4739,
-        -0.0523,  0.0000,  0.2825,  0.0000,  0.0000, -0.3649,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0474, -0.0808,  0.2658,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1076,  0.0000,  0.0000, -0.1440,  0.0000, -0.1075],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.3048e-01, -8.4606e-08, -4.5497e-02, -1.1563e-01,  2.2113e+00,
-         3.3436e-01,  5.1213e-17, -4.0439e-10, -3.2504e-01, -1.7729e-01,
-         1.9512e-01,  9.2067e-09, -4.5527e-01, -5.7912e-01, -9.1840e-14,
-        -6.8259e-11, -6.7143e-12, -1.9822e-01, -1.5530e-14,  1.8672e-10,
-         1.0140e-01, -2.8702e-01, -3.2979e-14, -7.0837e-01,  2.9384e-02,
-         2.6766e-10, -7.6125e-02,  9.8569e-02,  1.6135e-01,  4.7964e-14,
-        -5.5478e-01, -1.1122e+00, -1.2278e-17, -2.4555e-01,  6.0941e-12,
-         3.2881e-10,  2.0328e-13,  0.0000e+00, -7.3913e-02,  4.6084e-01,
-         1.6038e-03, -2.7152e-11,  2.0858e-01,  5.0030e-14,  1.2237e-11,
-        -3.2910e-01, -1.7266e-09,  5.5712e-03,  5.8082e-09, -3.1850e-15,
-        -1.0305e-05,  7.0279e-02, -6.0363e-02,  3.5066e-01, -1.1672e-06,
-         6.6286e-19, -2.6051e-13,  9.3690e-17,  1.0545e-01, -1.3531e-08,
-         2.5027e-16, -1.6202e-01,  2.1027e-04, -9.9008e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.3048e-01,  0.0000e+00, -4.5497e-02, -1.1563e-01,  2.2113e+00,
-         3.3436e-01,  0.0000e+00,  0.0000e+00, -3.2504e-01, -1.7729e-01,
-         1.9512e-01,  0.0000e+00, -4.5527e-01, -5.7912e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.9822e-01,  0.0000e+00,  0.0000e+00,
-         1.0140e-01, -2.8702e-01,  0.0000e+00, -7.0837e-01,  2.9384e-02,
-         0.0000e+00, -7.6125e-02,  9.8569e-02,  1.6135e-01,  0.0000e+00,
-        -5.5478e-01, -1.1122e+00,  0.0000e+00, -2.4555e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -7.3913e-02,  4.6084e-01,
-         1.6038e-03,  0.0000e+00,  2.0858e-01,  0.0000e+00,  0.0000e+00,
-        -3.2910e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.0279e-02, -6.0363e-02,  3.5066e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0545e-01,  0.0000e+00,
-         0.0000e+00, -1.6202e-01,  0.0000e+00, -9.9008e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.3048e-01,  0.0000e+00, -4.5497e-02, -1.1563e-01,  2.2113e+00,
-         3.3436e-01,  0.0000e+00,  0.0000e+00, -3.2504e-01, -1.7729e-01,
-         1.9512e-01,  0.0000e+00, -4.5527e-01, -5.7912e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.9822e-01,  0.0000e+00,  0.0000e+00,
-         1.0140e-01, -2.8702e-01,  0.0000e+00, -7.0837e-01,  2.9384e-02,
-         0.0000e+00, -7.6125e-02,  9.8569e-02,  1.6135e-01,  0.0000e+00,
-        -5.5478e-01, -1.1122e+00,  0.0000e+00, -2.4555e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -7.3913e-02,  4.6084e-01,
-         1.6038e-03,  0.0000e+00,  2.0858e-01,  0.0000e+00,  0.0000e+00,
-        -3.2910e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.0279e-02, -6.0363e-02,  3.5066e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0545e-01,  0.0000e+00,
-         0.0000e+00, -1.6202e-01,  0.0000e+00, -9.9008e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1721e-01, -7.5763e-08, -1.1815e-01, -1.0550e-01,  2.2069e+00,
-         3.5825e-01,  4.5861e-17, -3.6212e-10, -3.0533e-01, -2.0965e-01,
-         2.6315e-01,  8.2445e-09, -4.4596e-01, -5.9796e-01, -8.2241e-14,
-        -6.1125e-11, -6.0125e-12, -1.6885e-01, -1.3907e-14,  1.6720e-10,
-         1.6714e-01, -2.3609e-01, -2.9532e-14, -6.9449e-01,  8.7718e-02,
-         2.3969e-10, -8.6926e-02,  1.6074e-01,  1.6395e-01,  4.2951e-14,
-        -5.8677e-01, -1.1115e+00, -1.0995e-17, -3.5199e-01,  5.4572e-12,
-         2.9445e-10,  1.8204e-13,  0.0000e+00, -1.1096e-01,  4.3641e-01,
-         8.0470e-02, -2.4314e-11,  1.0848e-01,  4.4801e-14,  1.0958e-11,
-        -2.9341e-01, -1.5461e-09,  4.9889e-03,  5.2012e-09, -2.8521e-15,
-        -9.2282e-06,  9.5768e-02, -3.4995e-02,  4.4252e-01, -1.0452e-06,
-         5.9358e-19, -2.3329e-13,  8.3898e-17,  1.1127e-01, -1.2117e-08,
-         2.2411e-16, -1.8644e-01,  1.8830e-04, -8.7255e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1172,  0.0000, -0.1181, -0.1055,  2.2069,  0.3583,  0.0000,  0.0000,
-        -0.3053, -0.2097,  0.2631,  0.0000, -0.4460, -0.5980,  0.0000,  0.0000,
-         0.0000, -0.1689,  0.0000,  0.0000,  0.1671, -0.2361,  0.0000, -0.6945,
-         0.0877,  0.0000, -0.0869,  0.1607,  0.1639,  0.0000, -0.5868, -1.1115,
-         0.0000, -0.3520,  0.0000,  0.0000,  0.0000,  0.0000, -0.1110,  0.4364,
-         0.0805,  0.0000,  0.1085,  0.0000,  0.0000, -0.2934,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0958, -0.0350,  0.4425,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1113,  0.0000,  0.0000, -0.1864,  0.0000, -0.0873],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1172,  0.0000, -0.1181, -0.1055,  2.2069,  0.3583,  0.0000,  0.0000,
-        -0.3053, -0.2097,  0.2631,  0.0000, -0.4460, -0.5980,  0.0000,  0.0000,
-         0.0000, -0.1689,  0.0000,  0.0000,  0.1671, -0.2361,  0.0000, -0.6945,
-         0.0877,  0.0000, -0.0869,  0.1607,  0.1639,  0.0000, -0.5868, -1.1115,
-         0.0000, -0.3520,  0.0000,  0.0000,  0.0000,  0.0000, -0.1110,  0.4364,
-         0.0805,  0.0000,  0.1085,  0.0000,  0.0000, -0.2934,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0958, -0.0350,  0.4425,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.1113,  0.0000,  0.0000, -0.1864,  0.0000, -0.0873],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0865e-01, -6.7871e-08, -1.8208e-01, -1.1466e-01,  2.2063e+00,
-         3.8849e-01,  4.1084e-17, -3.2440e-10, -2.9375e-01, -2.3723e-01,
-         2.4072e-01,  7.3857e-09, -4.3286e-01, -5.9550e-01, -7.3675e-14,
-        -5.4758e-11, -5.3863e-12, -1.6084e-01, -1.2458e-14,  1.4979e-10,
-         2.3365e-01, -2.0234e-01, -2.6456e-14, -6.8876e-01,  1.3362e-01,
-         2.1472e-10, -8.8727e-02,  2.0327e-01,  1.7660e-01,  3.8477e-14,
-        -6.2113e-01, -1.1059e+00, -9.8493e-18, -4.2517e-01,  4.8887e-12,
-         2.6378e-10,  1.6308e-13,  0.0000e+00, -7.7511e-02,  4.4660e-01,
-         1.2541e-01, -2.1782e-11,  1.4356e-03,  4.0134e-14,  9.8165e-12,
-        -2.4675e-01, -1.3851e-09,  4.4693e-03,  4.6594e-09, -2.5550e-15,
-        -8.2669e-06,  8.9244e-02, -2.6857e-02,  4.1353e-01, -9.3630e-07,
-         5.3175e-19, -2.0899e-13,  7.5159e-17,  1.5275e-01, -1.0855e-08,
-         2.0077e-16, -2.2142e-01,  1.6868e-04, -6.9252e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 1.0865e-01,  0.0000e+00, -1.8208e-01, -1.1466e-01,  2.2063e+00,
-         3.8849e-01,  0.0000e+00,  0.0000e+00, -2.9375e-01, -2.3723e-01,
-         2.4072e-01,  0.0000e+00, -4.3286e-01, -5.9550e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6084e-01,  0.0000e+00,  0.0000e+00,
-         2.3365e-01, -2.0234e-01,  0.0000e+00, -6.8876e-01,  1.3362e-01,
-         0.0000e+00, -8.8727e-02,  2.0327e-01,  1.7660e-01,  0.0000e+00,
-        -6.2113e-01, -1.1059e+00,  0.0000e+00, -4.2517e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -7.7511e-02,  4.4660e-01,
-         1.2541e-01,  0.0000e+00,  1.4356e-03,  0.0000e+00,  0.0000e+00,
-        -2.4675e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  8.9244e-02, -2.6857e-02,  4.1353e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5275e-01,  0.0000e+00,
-         0.0000e+00, -2.2142e-01,  0.0000e+00, -6.9252e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 1.0865e-01,  0.0000e+00, -1.8208e-01, -1.1466e-01,  2.2063e+00,
-         3.8849e-01,  0.0000e+00,  0.0000e+00, -2.9375e-01, -2.3723e-01,
-         2.4072e-01,  0.0000e+00, -4.3286e-01, -5.9550e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6084e-01,  0.0000e+00,  0.0000e+00,
-         2.3365e-01, -2.0234e-01,  0.0000e+00, -6.8876e-01,  1.3362e-01,
-         0.0000e+00, -8.8727e-02,  2.0327e-01,  1.7660e-01,  0.0000e+00,
-        -6.2113e-01, -1.1059e+00,  0.0000e+00, -4.2517e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -7.7511e-02,  4.4660e-01,
-         1.2541e-01,  0.0000e+00,  1.4356e-03,  0.0000e+00,  0.0000e+00,
-        -2.4675e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  8.9244e-02, -2.6857e-02,  4.1353e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5275e-01,  0.0000e+00,
-         0.0000e+00, -2.2142e-01,  0.0000e+00, -6.9252e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.7277e-02, -6.0825e-08, -2.3035e-01, -1.4992e-01,  2.2046e+00,
-         4.1104e-01,  3.6818e-17, -2.9072e-10, -2.9391e-01, -2.5233e-01,
-         1.6785e-01,  6.6189e-09, -4.3153e-01, -5.8040e-01, -6.6026e-14,
-        -4.9073e-11, -4.8271e-12, -1.8719e-01, -1.1165e-14,  1.3424e-10,
-         2.6702e-01, -1.4827e-01, -2.3710e-14, -6.6153e-01,  1.8758e-01,
-         1.9243e-10, -7.7243e-02,  1.9427e-01,  1.5553e-01,  3.4482e-14,
-        -6.3454e-01, -1.1021e+00, -8.8268e-18, -4.5784e-01,  4.3812e-12,
-         2.3639e-10,  1.4615e-13,  0.0000e+00, -6.2604e-02,  4.7636e-01,
-         1.1090e-01, -1.9520e-11, -9.6509e-02,  3.5968e-14,  8.7974e-12,
-        -2.0219e-01, -1.2413e-09,  4.0053e-03,  4.1757e-09, -2.2898e-15,
-        -7.4087e-06,  8.5606e-02, -5.7004e-02,  3.0213e-01, -8.3910e-07,
-         4.7654e-19, -1.8729e-13,  6.7356e-17,  2.0525e-01, -9.7277e-09,
-         1.7992e-16, -2.6373e-01,  1.5117e-04, -6.5413e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0873,  0.0000, -0.2304, -0.1499,  2.2046,  0.4110,  0.0000,  0.0000,
-        -0.2939, -0.2523,  0.1679,  0.0000, -0.4315, -0.5804,  0.0000,  0.0000,
-         0.0000, -0.1872,  0.0000,  0.0000,  0.2670, -0.1483,  0.0000, -0.6615,
-         0.1876,  0.0000, -0.0772,  0.1943,  0.1555,  0.0000, -0.6345, -1.1021,
-         0.0000, -0.4578,  0.0000,  0.0000,  0.0000,  0.0000, -0.0626,  0.4764,
-         0.1109,  0.0000, -0.0965,  0.0000,  0.0000, -0.2022,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0856, -0.0570,  0.3021,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2052,  0.0000,  0.0000, -0.2637,  0.0000, -0.0654],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0873,  0.0000, -0.2304, -0.1499,  2.2046,  0.4110,  0.0000,  0.0000,
-        -0.2939, -0.2523,  0.1679,  0.0000, -0.4315, -0.5804,  0.0000,  0.0000,
-         0.0000, -0.1872,  0.0000,  0.0000,  0.2670, -0.1483,  0.0000, -0.6615,
-         0.1876,  0.0000, -0.0772,  0.1943,  0.1555,  0.0000, -0.6345, -1.1021,
-         0.0000, -0.4578,  0.0000,  0.0000,  0.0000,  0.0000, -0.0626,  0.4764,
-         0.1109,  0.0000, -0.0965,  0.0000,  0.0000, -0.2022,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0856, -0.0570,  0.3021,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2052,  0.0000,  0.0000, -0.2637,  0.0000, -0.0654],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 7.7901e-02, -5.4532e-08, -2.3545e-01, -1.7069e-01,  2.2001e+00,
-         4.3889e-01,  3.3009e-17, -2.6064e-10, -3.1143e-01, -2.6091e-01,
-        -1.1294e-02,  5.9341e-09, -4.0632e-01, -5.6937e-01, -5.9195e-14,
-        -4.3996e-11, -4.3276e-12, -2.0525e-01, -1.0010e-14,  1.2035e-10,
-         3.0423e-01, -1.3081e-01, -2.1257e-14, -6.1010e-01,  1.5073e-01,
-         1.7252e-10, -2.9142e-02,  1.7123e-01,  1.5352e-01,  3.0915e-14,
-        -6.6092e-01, -1.0961e+00, -7.9135e-18, -4.7325e-01,  3.9279e-12,
-         2.1193e-10,  1.3103e-13,  0.0000e+00, -4.1359e-02,  4.8984e-01,
-         2.7288e-02, -1.7501e-11, -1.5565e-01,  3.2246e-14,  7.8872e-12,
-        -2.0647e-01, -1.1129e-09,  3.5909e-03,  3.7436e-09, -2.0529e-15,
-        -6.6422e-06,  5.4724e-02, -1.3199e-01,  2.1904e-01, -7.5228e-07,
-         4.2724e-19, -1.6791e-13,  6.0387e-17,  2.6036e-01, -8.7212e-09,
-         1.6131e-16, -3.0616e-01,  1.3553e-04, -4.8266e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0779,  0.0000, -0.2355, -0.1707,  2.2001,  0.4389,  0.0000,  0.0000,
-        -0.3114, -0.2609, -0.0113,  0.0000, -0.4063, -0.5694,  0.0000,  0.0000,
-         0.0000, -0.2052,  0.0000,  0.0000,  0.3042, -0.1308,  0.0000, -0.6101,
-         0.1507,  0.0000, -0.0291,  0.1712,  0.1535,  0.0000, -0.6609, -1.0961,
-         0.0000, -0.4733,  0.0000,  0.0000,  0.0000,  0.0000, -0.0414,  0.4898,
-         0.0273,  0.0000, -0.1556,  0.0000,  0.0000, -0.2065,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0547, -0.1320,  0.2190,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2604,  0.0000,  0.0000, -0.3062,  0.0000, -0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0779,  0.0000, -0.2355, -0.1707,  2.2001,  0.4389,  0.0000,  0.0000,
-        -0.3114, -0.2609, -0.0113,  0.0000, -0.4063, -0.5694,  0.0000,  0.0000,
-         0.0000, -0.2052,  0.0000,  0.0000,  0.3042, -0.1308,  0.0000, -0.6101,
-         0.1507,  0.0000, -0.0291,  0.1712,  0.1535,  0.0000, -0.6609, -1.0961,
-         0.0000, -0.4733,  0.0000,  0.0000,  0.0000,  0.0000, -0.0414,  0.4898,
-         0.0273,  0.0000, -0.1556,  0.0000,  0.0000, -0.2065,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0547, -0.1320,  0.2190,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2604,  0.0000,  0.0000, -0.3062,  0.0000, -0.0483],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.0849e-02, -4.8909e-08, -2.2952e-01, -1.9234e-01,  2.1929e+00,
-         4.6850e-01,  2.9605e-17, -2.3377e-10, -3.2204e-01, -2.6868e-01,
-        -1.4700e-01,  5.3222e-09, -3.8681e-01, -5.6844e-01, -5.3091e-14,
-        -3.9459e-11, -3.8814e-12, -1.9376e-01, -8.9774e-15,  1.0794e-10,
-         3.4703e-01, -1.1528e-01, -1.9065e-14, -5.4849e-01,  1.0807e-01,
-         1.5473e-10,  1.9534e-02,  1.4167e-01,  1.5303e-01,  2.7727e-14,
-        -6.7682e-01, -1.0872e+00, -7.0975e-18, -4.7177e-01,  3.5229e-12,
-         1.9008e-10,  1.1751e-13,  0.0000e+00, -2.2039e-02,  4.8278e-01,
-        -4.6581e-02, -1.5696e-11, -2.0071e-01,  2.8921e-14,  7.0739e-12,
-        -2.1677e-01, -9.9810e-10,  3.2206e-03,  3.3576e-09, -1.8412e-15,
-        -5.9573e-06,  2.7717e-02, -2.1946e-01,  1.5266e-01, -6.7471e-07,
-         3.8318e-19, -1.5060e-13,  5.4160e-17,  2.8602e-01, -7.8219e-09,
-         1.4468e-16, -3.4249e-01,  1.2156e-04, -2.9879e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0808,  0.0000, -0.2295, -0.1923,  2.1929,  0.4685,  0.0000,  0.0000,
-        -0.3220, -0.2687, -0.1470,  0.0000, -0.3868, -0.5684,  0.0000,  0.0000,
-         0.0000, -0.1938,  0.0000,  0.0000,  0.3470, -0.1153,  0.0000, -0.5485,
-         0.1081,  0.0000,  0.0195,  0.1417,  0.1530,  0.0000, -0.6768, -1.0872,
-         0.0000, -0.4718,  0.0000,  0.0000,  0.0000,  0.0000, -0.0220,  0.4828,
-        -0.0466,  0.0000, -0.2007,  0.0000,  0.0000, -0.2168,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0277, -0.2195,  0.1527,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2860,  0.0000,  0.0000, -0.3425,  0.0000, -0.0299],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0808,  0.0000, -0.2295, -0.1923,  2.1929,  0.4685,  0.0000,  0.0000,
-        -0.3220, -0.2687, -0.1470,  0.0000, -0.3868, -0.5684,  0.0000,  0.0000,
-         0.0000, -0.1938,  0.0000,  0.0000,  0.3470, -0.1153,  0.0000, -0.5485,
-         0.1081,  0.0000,  0.0195,  0.1417,  0.1530,  0.0000, -0.6768, -1.0872,
-         0.0000, -0.4718,  0.0000,  0.0000,  0.0000,  0.0000, -0.0220,  0.4828,
-        -0.0466,  0.0000, -0.2007,  0.0000,  0.0000, -0.2168,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0277, -0.2195,  0.1527,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.2860,  0.0000,  0.0000, -0.3425,  0.0000, -0.0299],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 8.8158e-02, -4.3882e-08, -2.0568e-01, -2.0216e-01,  2.1856e+00,
-         4.9726e-01,  2.6563e-17, -2.0974e-10, -3.2508e-01, -2.5369e-01,
-        -2.3415e-01,  4.7753e-09, -3.6092e-01, -5.7014e-01, -4.7635e-14,
-        -3.5404e-11, -3.4825e-12, -1.5378e-01, -8.0548e-15,  9.6845e-11,
-         3.6530e-01, -9.9443e-02, -1.7105e-14, -5.2247e-01,  8.6062e-02,
-         1.3883e-10,  6.9003e-02,  1.1497e-01,  1.5472e-01,  2.4877e-14,
-        -6.7629e-01, -1.0806e+00, -6.3681e-18, -4.6750e-01,  3.1608e-12,
-         1.7055e-10,  1.0544e-13,  0.0000e+00, -2.5522e-02,  4.5958e-01,
-        -1.2076e-01, -1.4083e-11, -2.2024e-01,  2.5949e-14,  6.3469e-12,
-        -2.0388e-01, -8.9553e-10,  2.8896e-03,  3.0126e-09, -1.6520e-15,
-        -5.3450e-06, -1.1358e-02, -2.8737e-01,  1.5933e-01, -6.0537e-07,
-         3.4380e-19, -1.3512e-13,  4.8594e-17,  3.1114e-01, -7.0181e-09,
-         1.2981e-16, -3.6035e-01,  1.0906e-04, -4.3638e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0882,  0.0000, -0.2057, -0.2022,  2.1856,  0.4973,  0.0000,  0.0000,
-        -0.3251, -0.2537, -0.2342,  0.0000, -0.3609, -0.5701,  0.0000,  0.0000,
-         0.0000, -0.1538,  0.0000,  0.0000,  0.3653, -0.0994,  0.0000, -0.5225,
-         0.0861,  0.0000,  0.0690,  0.1150,  0.1547,  0.0000, -0.6763, -1.0806,
-         0.0000, -0.4675,  0.0000,  0.0000,  0.0000,  0.0000, -0.0255,  0.4596,
-        -0.1208,  0.0000, -0.2202,  0.0000,  0.0000, -0.2039,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0114, -0.2874,  0.1593,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3111,  0.0000,  0.0000, -0.3603,  0.0000, -0.0044],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0882,  0.0000, -0.2057, -0.2022,  2.1856,  0.4973,  0.0000,  0.0000,
-        -0.3251, -0.2537, -0.2342,  0.0000, -0.3609, -0.5701,  0.0000,  0.0000,
-         0.0000, -0.1538,  0.0000,  0.0000,  0.3653, -0.0994,  0.0000, -0.5225,
-         0.0861,  0.0000,  0.0690,  0.1150,  0.1547,  0.0000, -0.6763, -1.0806,
-         0.0000, -0.4675,  0.0000,  0.0000,  0.0000,  0.0000, -0.0255,  0.4596,
-        -0.1208,  0.0000, -0.2202,  0.0000,  0.0000, -0.2039,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0114, -0.2874,  0.1593,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3111,  0.0000,  0.0000, -0.3603,  0.0000, -0.0044],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1860e-01, -3.9388e-08, -1.9417e-01, -2.1915e-01,  2.1841e+00,
-         5.0431e-01,  2.3842e-17, -1.8826e-10, -3.3885e-01, -2.2301e-01,
-        -2.6224e-01,  4.2862e-09, -3.4094e-01, -5.6932e-01, -4.2756e-14,
-        -3.1778e-11, -3.1258e-12, -1.3496e-01, -7.2298e-15,  8.6926e-11,
-         4.0009e-01, -7.4254e-02, -1.5353e-14, -4.9488e-01,  7.4346e-02,
-         1.2461e-10,  8.6249e-02,  1.0824e-01,  1.8810e-01,  2.2329e-14,
-        -6.7527e-01, -1.0768e+00, -5.7158e-18, -4.6688e-01,  2.8371e-12,
-         1.5308e-10,  9.4638e-14,  0.0000e+00,  2.2788e-03,  4.6578e-01,
-        -1.4241e-01, -1.2641e-11, -2.3910e-01,  2.3291e-14,  5.6968e-12,
-        -1.8526e-01, -8.0381e-10,  2.5937e-03,  2.7040e-09, -1.4828e-15,
-        -4.7976e-06, -2.1790e-02, -3.3516e-01,  1.6476e-01, -5.4336e-07,
-         3.0859e-19, -1.2128e-13,  4.3617e-17,  3.2496e-01, -6.2993e-09,
-         1.1651e-16, -4.0126e-01,  9.7892e-05, -1.4540e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1186,  0.0000, -0.1942, -0.2192,  2.1841,  0.5043,  0.0000,  0.0000,
-        -0.3388, -0.2230, -0.2622,  0.0000, -0.3409, -0.5693,  0.0000,  0.0000,
-         0.0000, -0.1350,  0.0000,  0.0000,  0.4001, -0.0743,  0.0000, -0.4949,
-         0.0743,  0.0000,  0.0862,  0.1082,  0.1881,  0.0000, -0.6753, -1.0768,
-         0.0000, -0.4669,  0.0000,  0.0000,  0.0000,  0.0000,  0.0023,  0.4658,
-        -0.1424,  0.0000, -0.2391,  0.0000,  0.0000, -0.1853,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0218, -0.3352,  0.1648,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3250,  0.0000,  0.0000, -0.4013,  0.0000, -0.0145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1186,  0.0000, -0.1942, -0.2192,  2.1841,  0.5043,  0.0000,  0.0000,
-        -0.3388, -0.2230, -0.2622,  0.0000, -0.3409, -0.5693,  0.0000,  0.0000,
-         0.0000, -0.1350,  0.0000,  0.0000,  0.4001, -0.0743,  0.0000, -0.4949,
-         0.0743,  0.0000,  0.0862,  0.1082,  0.1881,  0.0000, -0.6753, -1.0768,
-         0.0000, -0.4669,  0.0000,  0.0000,  0.0000,  0.0000,  0.0023,  0.4658,
-        -0.1424,  0.0000, -0.2391,  0.0000,  0.0000, -0.1853,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0218, -0.3352,  0.1648,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3250,  0.0000,  0.0000, -0.4013,  0.0000, -0.0145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.5159e-01, -3.5367e-08, -1.7159e-01, -2.0020e-01,  2.1837e+00,
-         4.9485e-01,  2.1408e-17, -1.6904e-10, -3.5114e-01, -1.8128e-01,
-        -2.7092e-01,  3.8486e-09, -3.2484e-01, -5.6438e-01, -3.8391e-14,
-        -2.8534e-11, -2.8067e-12, -1.0087e-01, -6.4918e-15,  7.8053e-11,
-         4.1468e-01, -3.3873e-02, -1.3786e-14, -4.9376e-01,  8.2251e-02,
-         1.1189e-10,  9.2754e-02,  8.9762e-02,  2.1908e-01,  2.0050e-14,
-        -6.7266e-01, -1.0762e+00, -5.1324e-18, -4.5892e-01,  2.5475e-12,
-         1.3745e-10,  8.4978e-14,  0.0000e+00,  4.8161e-02,  4.8525e-01,
-        -1.6192e-01, -1.1350e-11, -2.4874e-01,  2.0914e-14,  5.1153e-12,
-        -1.5843e-01, -7.2175e-10,  2.3289e-03,  2.4280e-09, -1.3314e-15,
-        -4.3078e-06, -2.5467e-02, -3.8119e-01,  2.1734e-01, -4.8790e-07,
-         2.7709e-19, -1.0890e-13,  3.9165e-17,  3.4215e-01, -5.6562e-09,
-         1.0462e-16, -4.2493e-01,  8.7899e-05, -3.8984e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1516,  0.0000, -0.1716, -0.2002,  2.1837,  0.4948,  0.0000,  0.0000,
-        -0.3511, -0.1813, -0.2709,  0.0000, -0.3248, -0.5644,  0.0000,  0.0000,
-         0.0000, -0.1009,  0.0000,  0.0000,  0.4147, -0.0339,  0.0000, -0.4938,
-         0.0823,  0.0000,  0.0928,  0.0898,  0.2191,  0.0000, -0.6727, -1.0762,
-         0.0000, -0.4589,  0.0000,  0.0000,  0.0000,  0.0000,  0.0482,  0.4853,
-        -0.1619,  0.0000, -0.2487,  0.0000,  0.0000, -0.1584,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0255, -0.3812,  0.2173,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3421,  0.0000,  0.0000, -0.4249,  0.0000, -0.0390],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1516,  0.0000, -0.1716, -0.2002,  2.1837,  0.4948,  0.0000,  0.0000,
-        -0.3511, -0.1813, -0.2709,  0.0000, -0.3248, -0.5644,  0.0000,  0.0000,
-         0.0000, -0.1009,  0.0000,  0.0000,  0.4147, -0.0339,  0.0000, -0.4938,
-         0.0823,  0.0000,  0.0928,  0.0898,  0.2191,  0.0000, -0.6727, -1.0762,
-         0.0000, -0.4589,  0.0000,  0.0000,  0.0000,  0.0000,  0.0482,  0.4853,
-        -0.1619,  0.0000, -0.2487,  0.0000,  0.0000, -0.1584,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0255, -0.3812,  0.2173,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3421,  0.0000,  0.0000, -0.4249,  0.0000, -0.0390],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.0056e-01, -3.1769e-08, -1.5123e-01, -1.8306e-01,  2.1831e+00,
-         4.7450e-01,  1.9230e-17, -1.5185e-10, -3.6834e-01, -1.1915e-01,
-        -2.7160e-01,  3.4571e-09, -3.1276e-01, -5.6668e-01, -3.4485e-14,
-        -2.5631e-11, -2.5212e-12, -5.3643e-02, -5.8313e-15,  7.0112e-11,
-         4.1631e-01,  2.0690e-02, -1.2384e-14, -5.1648e-01,  8.8515e-02,
-         1.0051e-10,  9.6805e-02,  6.2519e-02,  2.3528e-01,  1.8010e-14,
-        -6.6496e-01, -1.0797e+00, -4.6102e-18, -4.6288e-01,  2.2883e-12,
-         1.2347e-10,  7.6332e-14,  0.0000e+00,  1.3235e-01,  5.0308e-01,
-        -1.5819e-01, -1.0195e-11, -2.5954e-01,  1.8786e-14,  4.5949e-12,
-        -1.4248e-01, -6.4833e-10,  2.0920e-03,  2.1810e-09, -1.1960e-15,
-        -3.8696e-06, -2.8477e-02, -4.0378e-01,  3.0796e-01, -4.3826e-07,
-         2.4890e-19, -9.7822e-14,  3.5180e-17,  3.5346e-01, -5.0808e-09,
-         9.3975e-17, -4.4582e-01,  7.8957e-05, -7.5409e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2006,  0.0000, -0.1512, -0.1831,  2.1831,  0.4745,  0.0000,  0.0000,
-        -0.3683, -0.1191, -0.2716,  0.0000, -0.3128, -0.5667,  0.0000,  0.0000,
-         0.0000, -0.0536,  0.0000,  0.0000,  0.4163,  0.0207,  0.0000, -0.5165,
-         0.0885,  0.0000,  0.0968,  0.0625,  0.2353,  0.0000, -0.6650, -1.0797,
-         0.0000, -0.4629,  0.0000,  0.0000,  0.0000,  0.0000,  0.1323,  0.5031,
-        -0.1582,  0.0000, -0.2595,  0.0000,  0.0000, -0.1425,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0285, -0.4038,  0.3080,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3535,  0.0000,  0.0000, -0.4458,  0.0000, -0.0754],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2006,  0.0000, -0.1512, -0.1831,  2.1831,  0.4745,  0.0000,  0.0000,
-        -0.3683, -0.1191, -0.2716,  0.0000, -0.3128, -0.5667,  0.0000,  0.0000,
-         0.0000, -0.0536,  0.0000,  0.0000,  0.4163,  0.0207,  0.0000, -0.5165,
-         0.0885,  0.0000,  0.0968,  0.0625,  0.2353,  0.0000, -0.6650, -1.0797,
-         0.0000, -0.4629,  0.0000,  0.0000,  0.0000,  0.0000,  0.1323,  0.5031,
-        -0.1582,  0.0000, -0.2595,  0.0000,  0.0000, -0.1425,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0285, -0.4038,  0.3080,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3535,  0.0000,  0.0000, -0.4458,  0.0000, -0.0754],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3413e-01, -2.8548e-08, -1.4251e-01, -1.7883e-01,  2.1824e+00,
-         4.4955e-01,  1.7280e-17, -1.3645e-10, -3.8090e-01, -6.3280e-02,
-        -2.6438e-01,  3.1066e-09, -3.1060e-01, -5.5767e-01, -3.0989e-14,
-        -2.3032e-11, -2.2656e-12, -9.7334e-03, -5.2401e-15,  6.3003e-11,
-         4.0821e-01,  9.2159e-02, -1.1128e-14, -5.2195e-01,  1.1423e-01,
-         9.0316e-11,  9.0726e-02,  3.6295e-02,  2.5391e-01,  1.6184e-14,
-        -6.5224e-01, -1.0842e+00, -4.1428e-18, -4.6633e-01,  2.0563e-12,
-         1.1095e-10,  6.8593e-14,  0.0000e+00,  1.6752e-01,  5.1606e-01,
-        -1.3779e-01, -9.1617e-12, -2.7304e-01,  1.6881e-14,  4.1290e-12,
-        -1.1879e-01, -5.8259e-10,  1.8799e-03,  1.9598e-09, -1.0747e-15,
-        -3.4772e-06, -1.5189e-02, -4.1812e-01,  3.8017e-01, -3.9383e-07,
-         2.2366e-19, -8.7903e-14,  3.1613e-17,  3.7162e-01, -4.5656e-09,
-         8.4446e-17, -4.7260e-01,  7.0951e-05, -1.1006e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2341,  0.0000, -0.1425, -0.1788,  2.1824,  0.4496,  0.0000,  0.0000,
-        -0.3809, -0.0633, -0.2644,  0.0000, -0.3106, -0.5577,  0.0000,  0.0000,
-         0.0000, -0.0097,  0.0000,  0.0000,  0.4082,  0.0922,  0.0000, -0.5220,
-         0.1142,  0.0000,  0.0907,  0.0363,  0.2539,  0.0000, -0.6522, -1.0842,
-         0.0000, -0.4663,  0.0000,  0.0000,  0.0000,  0.0000,  0.1675,  0.5161,
-        -0.1378,  0.0000, -0.2730,  0.0000,  0.0000, -0.1188,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0152, -0.4181,  0.3802,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3716,  0.0000,  0.0000, -0.4726,  0.0000, -0.1101],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2341,  0.0000, -0.1425, -0.1788,  2.1824,  0.4496,  0.0000,  0.0000,
-        -0.3809, -0.0633, -0.2644,  0.0000, -0.3106, -0.5577,  0.0000,  0.0000,
-         0.0000, -0.0097,  0.0000,  0.0000,  0.4082,  0.0922,  0.0000, -0.5220,
-         0.1142,  0.0000,  0.0907,  0.0363,  0.2539,  0.0000, -0.6522, -1.0842,
-         0.0000, -0.4663,  0.0000,  0.0000,  0.0000,  0.0000,  0.1675,  0.5161,
-        -0.1378,  0.0000, -0.2730,  0.0000,  0.0000, -0.1188,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0152, -0.4181,  0.3802,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3716,  0.0000,  0.0000, -0.4726,  0.0000, -0.1101],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.7548e-01, -2.5663e-08, -1.3338e-01, -1.6024e-01,  2.1807e+00,
-         4.0528e-01,  1.5534e-17, -1.2266e-10, -3.7627e-01, -1.2850e-03,
-        -2.6194e-01,  2.7926e-09, -2.8646e-01, -5.5185e-01, -2.7857e-14,
-        -2.0705e-11, -2.0366e-12,  2.7113e-02, -4.7105e-15,  5.6636e-11,
-         4.0526e-01,  1.4718e-01, -1.0003e-14, -5.2873e-01,  1.2618e-01,
-         8.1189e-11,  8.1500e-02,  1.7100e-02,  2.6130e-01,  1.4549e-14,
-        -6.4595e-01, -1.0916e+00, -3.7241e-18, -4.8084e-01,  1.8485e-12,
-         9.9737e-11,  6.1661e-14,  0.0000e+00,  2.0802e-01,  5.3673e-01,
-        -1.1928e-01, -8.2359e-12, -2.7776e-01,  1.5175e-14,  3.7117e-12,
-        -8.9848e-02, -5.2372e-10,  1.6899e-03,  1.7618e-09, -9.6610e-16,
-        -3.1258e-06,  2.4375e-02, -4.1497e-01,  4.6433e-01, -3.5403e-07,
-         2.0106e-19, -7.9020e-14,  2.8419e-17,  3.8264e-01, -4.1043e-09,
-         7.5913e-17, -4.8093e-01,  6.3781e-05, -1.3466e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.7548e-01,  0.0000e+00, -1.3338e-01, -1.6024e-01,  2.1807e+00,
-         4.0528e-01,  0.0000e+00,  0.0000e+00, -3.7627e-01, -1.2850e-03,
-        -2.6194e-01,  0.0000e+00, -2.8646e-01, -5.5185e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  2.7113e-02,  0.0000e+00,  0.0000e+00,
-         4.0526e-01,  1.4718e-01,  0.0000e+00, -5.2873e-01,  1.2618e-01,
-         0.0000e+00,  8.1500e-02,  1.7100e-02,  2.6130e-01,  0.0000e+00,
-        -6.4595e-01, -1.0916e+00,  0.0000e+00, -4.8084e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0802e-01,  5.3673e-01,
-        -1.1928e-01,  0.0000e+00, -2.7776e-01,  0.0000e+00,  0.0000e+00,
-        -8.9848e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.4375e-02, -4.1497e-01,  4.6433e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.8264e-01,  0.0000e+00,
-         0.0000e+00, -4.8093e-01,  0.0000e+00, -1.3466e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.7548e-01,  0.0000e+00, -1.3338e-01, -1.6024e-01,  2.1807e+00,
-         4.0528e-01,  0.0000e+00,  0.0000e+00, -3.7627e-01, -1.2850e-03,
-        -2.6194e-01,  0.0000e+00, -2.8646e-01, -5.5185e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  2.7113e-02,  0.0000e+00,  0.0000e+00,
-         4.0526e-01,  1.4718e-01,  0.0000e+00, -5.2873e-01,  1.2618e-01,
-         0.0000e+00,  8.1500e-02,  1.7100e-02,  2.6130e-01,  0.0000e+00,
-        -6.4595e-01, -1.0916e+00,  0.0000e+00, -4.8084e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0802e-01,  5.3673e-01,
-        -1.1928e-01,  0.0000e+00, -2.7776e-01,  0.0000e+00,  0.0000e+00,
-        -8.9848e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.4375e-02, -4.1497e-01,  4.6433e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  3.8264e-01,  0.0000e+00,
-         0.0000e+00, -4.8093e-01,  0.0000e+00, -1.3466e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9162e-01, -2.3078e-08, -1.2654e-01, -1.5506e-01,  2.1791e+00,
-         3.6061e-01,  1.3970e-17, -1.1031e-10, -3.6315e-01,  4.2981e-02,
-        -2.4182e-01,  2.5114e-09, -2.6364e-01, -5.5068e-01, -2.5052e-14,
-        -1.8619e-11, -1.8315e-12,  5.2867e-02, -4.2361e-15,  5.0932e-11,
-         3.9400e-01,  2.0614e-01, -8.9960e-15, -5.3159e-01,  1.4412e-01,
-         7.3012e-11,  7.2858e-02, -1.1051e-02,  2.4810e-01,  1.3083e-14,
-        -6.4119e-01, -1.0992e+00, -3.3491e-18, -4.9478e-01,  1.6623e-12,
-         8.9692e-11,  5.5451e-14,  0.0000e+00,  2.2818e-01,  5.3236e-01,
-        -1.0247e-01, -7.4064e-12, -2.8088e-01,  1.3647e-14,  3.3379e-12,
-        -5.1483e-02, -4.7097e-10,  1.5197e-03,  1.5843e-09, -8.6880e-16,
-        -2.8110e-06,  5.6270e-02, -3.8638e-01,  4.9464e-01, -3.1837e-07,
-         1.8081e-19, -7.1062e-14,  2.5556e-17,  4.0585e-01, -3.6909e-09,
-         6.8267e-17, -4.8845e-01,  5.7358e-05, -1.4390e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2916,  0.0000, -0.1265, -0.1551,  2.1791,  0.3606,  0.0000,  0.0000,
-        -0.3631,  0.0430, -0.2418,  0.0000, -0.2636, -0.5507,  0.0000,  0.0000,
-         0.0000,  0.0529,  0.0000,  0.0000,  0.3940,  0.2061,  0.0000, -0.5316,
-         0.1441,  0.0000,  0.0729, -0.0111,  0.2481,  0.0000, -0.6412, -1.0992,
-         0.0000, -0.4948,  0.0000,  0.0000,  0.0000,  0.0000,  0.2282,  0.5324,
-        -0.1025,  0.0000, -0.2809,  0.0000,  0.0000, -0.0515,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0563, -0.3864,  0.4946,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4058,  0.0000,  0.0000, -0.4884,  0.0000, -0.1439],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2916,  0.0000, -0.1265, -0.1551,  2.1791,  0.3606,  0.0000,  0.0000,
-        -0.3631,  0.0430, -0.2418,  0.0000, -0.2636, -0.5507,  0.0000,  0.0000,
-         0.0000,  0.0529,  0.0000,  0.0000,  0.3940,  0.2061,  0.0000, -0.5316,
-         0.1441,  0.0000,  0.0729, -0.0111,  0.2481,  0.0000, -0.6412, -1.0992,
-         0.0000, -0.4948,  0.0000,  0.0000,  0.0000,  0.0000,  0.2282,  0.5324,
-        -0.1025,  0.0000, -0.2809,  0.0000,  0.0000, -0.0515,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0563, -0.3864,  0.4946,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4058,  0.0000,  0.0000, -0.4884,  0.0000, -0.1439],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1972e-01, -2.0762e-08, -1.1903e-01, -1.4196e-01,  2.1776e+00,
-         3.1615e-01,  1.2568e-17, -9.9235e-11, -3.3834e-01,  7.3959e-02,
-        -1.9676e-01,  2.2593e-09, -2.3950e-01, -5.3189e-01, -2.2537e-14,
-        -1.6750e-11, -1.6477e-12,  3.3702e-02, -3.8109e-15,  4.5820e-11,
-         3.8152e-01,  2.5769e-01, -8.0930e-15, -5.5177e-01,  1.6626e-01,
-         6.5683e-11,  4.9992e-02, -3.4816e-02,  2.2942e-01,  1.1770e-14,
-        -6.4460e-01, -1.1089e+00, -3.0129e-18, -5.0445e-01,  1.4955e-12,
-         8.0689e-11,  4.9885e-14,  0.0000e+00,  2.3938e-01,  5.2122e-01,
-        -7.7574e-02, -6.6630e-12, -2.7592e-01,  1.2277e-14,  3.0029e-12,
-        -2.8329e-02, -4.2370e-10,  1.3672e-03,  1.4253e-09, -7.8159e-16,
-        -2.5289e-06,  8.6942e-02, -3.3848e-01,  4.5782e-01, -2.8642e-07,
-         1.6266e-19, -6.3929e-14,  2.2991e-17,  4.2619e-01, -3.3204e-09,
-         6.1415e-17, -4.9603e-01,  5.1600e-05, -1.2836e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3197,  0.0000, -0.1190, -0.1420,  2.1776,  0.3162,  0.0000,  0.0000,
-        -0.3383,  0.0740, -0.1968,  0.0000, -0.2395, -0.5319,  0.0000,  0.0000,
-         0.0000,  0.0337,  0.0000,  0.0000,  0.3815,  0.2577,  0.0000, -0.5518,
-         0.1663,  0.0000,  0.0500, -0.0348,  0.2294,  0.0000, -0.6446, -1.1089,
-         0.0000, -0.5044,  0.0000,  0.0000,  0.0000,  0.0000,  0.2394,  0.5212,
-        -0.0776,  0.0000, -0.2759,  0.0000,  0.0000, -0.0283,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0869, -0.3385,  0.4578,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4262,  0.0000,  0.0000, -0.4960,  0.0000, -0.1284],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3197,  0.0000, -0.1190, -0.1420,  2.1776,  0.3162,  0.0000,  0.0000,
-        -0.3383,  0.0740, -0.1968,  0.0000, -0.2395, -0.5319,  0.0000,  0.0000,
-         0.0000,  0.0337,  0.0000,  0.0000,  0.3815,  0.2577,  0.0000, -0.5518,
-         0.1663,  0.0000,  0.0500, -0.0348,  0.2294,  0.0000, -0.6446, -1.1089,
-         0.0000, -0.5044,  0.0000,  0.0000,  0.0000,  0.0000,  0.2394,  0.5212,
-        -0.0776,  0.0000, -0.2759,  0.0000,  0.0000, -0.0283,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0869, -0.3385,  0.4578,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4262,  0.0000,  0.0000, -0.4960,  0.0000, -0.1284],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3476e-01, -1.8685e-08, -1.2629e-01, -1.4151e-01,  2.1754e+00,
-         3.0189e-01,  1.1310e-17, -8.9308e-11, -2.9716e-01,  7.7018e-02,
-        -1.4756e-01,  2.0333e-09, -2.1819e-01, -5.1534e-01, -2.0283e-14,
-        -1.5075e-11, -1.4828e-12,  3.8296e-03, -3.4297e-15,  4.1236e-11,
-         3.7151e-01,  2.8307e-01, -7.2834e-15, -5.7182e-01,  1.6477e-01,
-         5.9112e-11,  4.0789e-02, -6.1798e-02,  2.1454e-01,  1.0593e-14,
-        -6.5109e-01, -1.1187e+00, -2.7115e-18, -5.0840e-01,  1.3459e-12,
-         7.2617e-11,  4.4895e-14,  0.0000e+00,  2.2067e-01,  4.9120e-01,
-        -6.0785e-02, -5.9964e-12, -2.8235e-01,  1.1049e-14,  2.7025e-12,
-         3.2307e-03, -3.8131e-10,  1.2304e-03,  1.2827e-09, -7.0340e-16,
-        -2.2759e-06,  9.6949e-02, -2.6468e-01,  3.9122e-01, -2.5776e-07,
-         1.4639e-19, -5.7534e-14,  2.0691e-17,  4.5740e-01, -2.9883e-09,
-         5.5271e-17, -4.9029e-01,  4.6438e-05, -9.0530e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3348,  0.0000, -0.1263, -0.1415,  2.1754,  0.3019,  0.0000,  0.0000,
-        -0.2972,  0.0770, -0.1476,  0.0000, -0.2182, -0.5153,  0.0000,  0.0000,
-         0.0000,  0.0038,  0.0000,  0.0000,  0.3715,  0.2831,  0.0000, -0.5718,
-         0.1648,  0.0000,  0.0408, -0.0618,  0.2145,  0.0000, -0.6511, -1.1187,
-         0.0000, -0.5084,  0.0000,  0.0000,  0.0000,  0.0000,  0.2207,  0.4912,
-        -0.0608,  0.0000, -0.2824,  0.0000,  0.0000,  0.0032,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0969, -0.2647,  0.3912,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4574,  0.0000,  0.0000, -0.4903,  0.0000, -0.0905],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3348,  0.0000, -0.1263, -0.1415,  2.1754,  0.3019,  0.0000,  0.0000,
-        -0.2972,  0.0770, -0.1476,  0.0000, -0.2182, -0.5153,  0.0000,  0.0000,
-         0.0000,  0.0038,  0.0000,  0.0000,  0.3715,  0.2831,  0.0000, -0.5718,
-         0.1648,  0.0000,  0.0408, -0.0618,  0.2145,  0.0000, -0.6511, -1.1187,
-         0.0000, -0.5084,  0.0000,  0.0000,  0.0000,  0.0000,  0.2207,  0.4912,
-        -0.0608,  0.0000, -0.2824,  0.0000,  0.0000,  0.0032,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0969, -0.2647,  0.3912,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4574,  0.0000,  0.0000, -0.4903,  0.0000, -0.0905],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4350e-01, -1.6822e-08, -1.2135e-01, -1.3644e-01,  2.1728e+00,
-         2.8648e-01,  1.0183e-17, -8.0403e-11, -2.5700e-01,  8.6198e-02,
-        -8.1634e-02,  1.8305e-09, -1.8948e-01, -5.0418e-01, -1.8260e-14,
-        -1.3572e-11, -1.3350e-12, -2.0093e-02, -3.0877e-15,  3.7125e-11,
-         3.5317e-01,  2.9839e-01, -6.5572e-15, -6.2921e-01,  1.6288e-01,
-         5.3219e-11,  5.0730e-02, -8.6931e-02,  1.9487e-01,  9.5365e-15,
-        -6.5378e-01, -1.1307e+00, -2.4411e-18, -5.2250e-01,  1.2117e-12,
-         6.5377e-11,  4.0418e-14,  0.0000e+00,  1.9503e-01,  4.5270e-01,
-        -5.7417e-02, -5.3986e-12, -2.7732e-01,  9.9473e-15,  2.4330e-12,
-         4.7138e-02, -3.4329e-10,  1.1077e-03,  1.1548e-09, -6.3327e-16,
-        -2.0490e-06,  9.9572e-02, -1.9544e-01,  2.9516e-01, -2.3206e-07,
-         1.3179e-19, -5.1797e-14,  1.8628e-17,  4.9191e-01, -2.6903e-09,
-         4.9760e-17, -4.7307e-01,  4.1808e-05, -2.5427e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3435,  0.0000, -0.1213, -0.1364,  2.1728,  0.2865,  0.0000,  0.0000,
-        -0.2570,  0.0862, -0.0816,  0.0000, -0.1895, -0.5042,  0.0000,  0.0000,
-         0.0000, -0.0201,  0.0000,  0.0000,  0.3532,  0.2984,  0.0000, -0.6292,
-         0.1629,  0.0000,  0.0507, -0.0869,  0.1949,  0.0000, -0.6538, -1.1307,
-         0.0000, -0.5225,  0.0000,  0.0000,  0.0000,  0.0000,  0.1950,  0.4527,
-        -0.0574,  0.0000, -0.2773,  0.0000,  0.0000,  0.0471,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0996, -0.1954,  0.2952,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4919,  0.0000,  0.0000, -0.4731,  0.0000, -0.0254],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3435,  0.0000, -0.1213, -0.1364,  2.1728,  0.2865,  0.0000,  0.0000,
-        -0.2570,  0.0862, -0.0816,  0.0000, -0.1895, -0.5042,  0.0000,  0.0000,
-         0.0000, -0.0201,  0.0000,  0.0000,  0.3532,  0.2984,  0.0000, -0.6292,
-         0.1629,  0.0000,  0.0507, -0.0869,  0.1949,  0.0000, -0.6538, -1.1307,
-         0.0000, -0.5225,  0.0000,  0.0000,  0.0000,  0.0000,  0.1950,  0.4527,
-        -0.0574,  0.0000, -0.2773,  0.0000,  0.0000,  0.0471,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0996, -0.1954,  0.2952,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4919,  0.0000,  0.0000, -0.4731,  0.0000, -0.0254],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4566e-01, -1.5150e-08, -1.2021e-01, -1.3811e-01,  2.1698e+00,
-         2.7488e-01,  9.1707e-18, -7.2413e-11, -2.1275e-01,  8.5270e-02,
-        -1.6017e-02,  1.6486e-09, -1.5622e-01, -4.8785e-01, -1.6446e-14,
-        -1.2223e-11, -1.2023e-12, -4.8906e-02, -2.7809e-15,  3.3436e-11,
-         3.3580e-01,  2.9030e-01, -5.9056e-15, -6.8274e-01,  1.4093e-01,
-         4.7930e-11,  6.8274e-02, -1.0490e-01,  1.7695e-01,  8.5888e-15,
-        -6.5756e-01, -1.1418e+00, -2.1986e-18, -5.3391e-01,  1.0913e-12,
-         5.8880e-11,  3.6402e-14,  0.0000e+00,  1.7121e-01,  4.1787e-01,
-        -6.2913e-02, -4.8621e-12, -2.8114e-01,  8.9589e-15,  2.1912e-12,
-         1.1269e-01, -3.0918e-10,  9.9764e-04,  1.0401e-09, -5.7034e-16,
-        -1.8454e-06,  9.1831e-02, -9.4393e-02,  1.9276e-01, -2.0900e-07,
-         1.1870e-19, -4.6650e-14,  1.6777e-17,  5.3321e-01, -2.4230e-09,
-         4.4815e-17, -4.4573e-01,  3.7654e-05,  4.7404e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3457,  0.0000, -0.1202, -0.1381,  2.1698,  0.2749,  0.0000,  0.0000,
-        -0.2127,  0.0853, -0.0160,  0.0000, -0.1562, -0.4879,  0.0000,  0.0000,
-         0.0000, -0.0489,  0.0000,  0.0000,  0.3358,  0.2903,  0.0000, -0.6827,
-         0.1409,  0.0000,  0.0683, -0.1049,  0.1770,  0.0000, -0.6576, -1.1418,
-         0.0000, -0.5339,  0.0000,  0.0000,  0.0000,  0.0000,  0.1712,  0.4179,
-        -0.0629,  0.0000, -0.2811,  0.0000,  0.0000,  0.1127,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0918, -0.0944,  0.1928,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5332,  0.0000,  0.0000, -0.4457,  0.0000,  0.0474],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3457,  0.0000, -0.1202, -0.1381,  2.1698,  0.2749,  0.0000,  0.0000,
-        -0.2127,  0.0853, -0.0160,  0.0000, -0.1562, -0.4879,  0.0000,  0.0000,
-         0.0000, -0.0489,  0.0000,  0.0000,  0.3358,  0.2903,  0.0000, -0.6827,
-         0.1409,  0.0000,  0.0683, -0.1049,  0.1770,  0.0000, -0.6576, -1.1418,
-         0.0000, -0.5339,  0.0000,  0.0000,  0.0000,  0.0000,  0.1712,  0.4179,
-        -0.0629,  0.0000, -0.2811,  0.0000,  0.0000,  0.1127,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0918, -0.0944,  0.1928,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5332,  0.0000,  0.0000, -0.4457,  0.0000,  0.0474],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6494e-01, -1.3650e-08, -1.1858e-01, -1.2496e-01,  2.1669e+00,
-         2.6179e-01,  8.2624e-18, -6.5241e-11, -1.7382e-01,  8.7583e-02,
-         1.1389e-02,  1.4854e-09, -1.2473e-01, -4.8490e-01, -1.4817e-14,
-        -1.1012e-11, -1.0832e-12, -3.3844e-02, -2.5055e-15,  3.0124e-11,
-         3.3300e-01,  2.7813e-01, -5.3207e-15, -7.3835e-01,  9.9127e-02,
-         4.3183e-11,  8.5187e-02, -9.3523e-02,  1.6559e-01,  7.7382e-15,
-        -6.4384e-01, -1.1541e+00, -1.9808e-18, -5.2862e-01,  9.8318e-13,
-         5.3049e-11,  3.2797e-14,  0.0000e+00,  1.8341e-01,  4.2557e-01,
-        -6.4252e-02, -4.3805e-12, -2.9425e-01,  8.0716e-15,  1.9742e-12,
-         1.3845e-01, -2.7856e-10,  8.9883e-04,  9.3707e-10, -5.1385e-16,
-        -1.6626e-06,  9.9805e-02, -3.5748e-02,  1.3988e-01, -1.8830e-07,
-         1.0694e-19, -4.2030e-14,  1.5115e-17,  5.6053e-01, -2.1830e-09,
-         4.0377e-17, -4.0925e-01,  3.3924e-05,  7.4484e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3649,  0.0000, -0.1186, -0.1250,  2.1669,  0.2618,  0.0000,  0.0000,
-        -0.1738,  0.0876,  0.0114,  0.0000, -0.1247, -0.4849,  0.0000,  0.0000,
-         0.0000, -0.0338,  0.0000,  0.0000,  0.3330,  0.2781,  0.0000, -0.7384,
-         0.0991,  0.0000,  0.0852, -0.0935,  0.1656,  0.0000, -0.6438, -1.1541,
-         0.0000, -0.5286,  0.0000,  0.0000,  0.0000,  0.0000,  0.1834,  0.4256,
-        -0.0643,  0.0000, -0.2943,  0.0000,  0.0000,  0.1385,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0998, -0.0357,  0.1399,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5605,  0.0000,  0.0000, -0.4092,  0.0000,  0.0745],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3649,  0.0000, -0.1186, -0.1250,  2.1669,  0.2618,  0.0000,  0.0000,
-        -0.1738,  0.0876,  0.0114,  0.0000, -0.1247, -0.4849,  0.0000,  0.0000,
-         0.0000, -0.0338,  0.0000,  0.0000,  0.3330,  0.2781,  0.0000, -0.7384,
-         0.0991,  0.0000,  0.0852, -0.0935,  0.1656,  0.0000, -0.6438, -1.1541,
-         0.0000, -0.5286,  0.0000,  0.0000,  0.0000,  0.0000,  0.1834,  0.4256,
-        -0.0643,  0.0000, -0.2943,  0.0000,  0.0000,  0.1385,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0998, -0.0357,  0.1399,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5605,  0.0000,  0.0000, -0.4092,  0.0000,  0.0745],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.8378e-01, -1.2302e-08, -1.1654e-01, -9.4522e-02,  2.1641e+00,
-         2.4571e-01,  7.4469e-18, -5.8801e-11, -1.1570e-01,  7.7782e-02,
-         2.7506e-02,  1.3387e-09, -9.3391e-02, -4.7191e-01, -1.3354e-14,
-        -9.9254e-12, -9.7632e-13, -3.5504e-04, -2.2582e-15,  2.7151e-11,
-         3.3659e-01,  2.8729e-01, -4.7955e-15, -7.7964e-01,  6.9043e-02,
-         3.8921e-11,  1.0306e-01, -5.6698e-02,  1.7068e-01,  6.9744e-15,
-        -6.4054e-01, -1.1627e+00, -1.7853e-18, -5.1431e-01,  8.8613e-13,
-         4.7812e-11,  2.9559e-14,  0.0000e+00,  1.9784e-01,  4.3911e-01,
-        -4.6658e-02, -3.9481e-12, -3.1377e-01,  7.2748e-15,  1.7793e-12,
-         1.1974e-01, -2.5106e-10,  8.1011e-04,  8.4457e-10, -4.6313e-16,
-        -1.4985e-06,  1.1641e-01, -4.5995e-03,  1.2354e-01, -1.6971e-07,
-         9.6385e-20, -3.7881e-14,  1.3623e-17,  5.8265e-01, -1.9675e-09,
-         3.6391e-17, -3.7103e-01,  3.0576e-05,  9.5392e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.8378e-01,  0.0000e+00, -1.1654e-01, -9.4522e-02,  2.1641e+00,
-         2.4571e-01,  0.0000e+00,  0.0000e+00, -1.1570e-01,  7.7782e-02,
-         2.7506e-02,  0.0000e+00, -9.3391e-02, -4.7191e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.5504e-04,  0.0000e+00,  0.0000e+00,
-         3.3659e-01,  2.8729e-01,  0.0000e+00, -7.7964e-01,  6.9043e-02,
-         0.0000e+00,  1.0306e-01, -5.6698e-02,  1.7068e-01,  0.0000e+00,
-        -6.4054e-01, -1.1627e+00,  0.0000e+00, -5.1431e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9784e-01,  4.3911e-01,
-        -4.6658e-02,  0.0000e+00, -3.1377e-01,  0.0000e+00,  0.0000e+00,
-         1.1974e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.1641e-01, -4.5995e-03,  1.2354e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.8265e-01,  0.0000e+00,
-         0.0000e+00, -3.7103e-01,  0.0000e+00,  9.5392e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.8378e-01,  0.0000e+00, -1.1654e-01, -9.4522e-02,  2.1641e+00,
-         2.4571e-01,  0.0000e+00,  0.0000e+00, -1.1570e-01,  7.7782e-02,
-         2.7506e-02,  0.0000e+00, -9.3391e-02, -4.7191e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -3.5504e-04,  0.0000e+00,  0.0000e+00,
-         3.3659e-01,  2.8729e-01,  0.0000e+00, -7.7964e-01,  6.9043e-02,
-         0.0000e+00,  1.0306e-01, -5.6698e-02,  1.7068e-01,  0.0000e+00,
-        -6.4054e-01, -1.1627e+00,  0.0000e+00, -5.1431e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9784e-01,  4.3911e-01,
-        -4.6658e-02,  0.0000e+00, -3.1377e-01,  0.0000e+00,  0.0000e+00,
-         1.1974e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.1641e-01, -4.5995e-03,  1.2354e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.8265e-01,  0.0000e+00,
-         0.0000e+00, -3.7103e-01,  0.0000e+00,  9.5392e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0115e-01, -1.1092e-08, -1.1136e-01, -5.3675e-02,  2.1611e+00,
-         2.2618e-01,  6.7142e-18, -5.3016e-11, -5.0749e-02,  7.4129e-02,
-        -1.1032e-02,  1.2070e-09, -6.2400e-02, -4.7595e-01, -1.2041e-14,
-        -8.9489e-12, -8.8027e-13,  5.8627e-02, -2.0360e-15,  2.4479e-11,
-         3.4226e-01,  2.9376e-01, -4.3237e-15, -8.1436e-01,  3.9984e-02,
-         3.5091e-11,  1.2339e-01, -1.7668e-02,  1.7923e-01,  6.2882e-15,
-        -6.1292e-01, -1.1682e+00, -1.6097e-18, -4.8654e-01,  7.9895e-13,
-         4.3108e-11,  2.6651e-14,  0.0000e+00,  2.0491e-01,  4.4606e-01,
-        -4.9989e-02, -3.5597e-12, -3.3376e-01,  6.5591e-15,  1.6043e-12,
-         6.8220e-02, -2.2636e-10,  7.3041e-04,  7.6148e-10, -4.1757e-16,
-        -1.3511e-06,  1.4861e-01, -1.7972e-02,  1.3884e-01, -1.5302e-07,
-         8.6903e-20, -3.4154e-14,  1.2283e-17,  5.9899e-01, -1.7739e-09,
-         3.2811e-17, -3.3345e-01,  2.7568e-05,  1.0280e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4011,  0.0000, -0.1114, -0.0537,  2.1611,  0.2262,  0.0000,  0.0000,
-        -0.0507,  0.0741, -0.0110,  0.0000, -0.0624, -0.4760,  0.0000,  0.0000,
-         0.0000,  0.0586,  0.0000,  0.0000,  0.3423,  0.2938,  0.0000, -0.8144,
-         0.0400,  0.0000,  0.1234, -0.0177,  0.1792,  0.0000, -0.6129, -1.1682,
-         0.0000, -0.4865,  0.0000,  0.0000,  0.0000,  0.0000,  0.2049,  0.4461,
-        -0.0500,  0.0000, -0.3338,  0.0000,  0.0000,  0.0682,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1486, -0.0180,  0.1388,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5990,  0.0000,  0.0000, -0.3334,  0.0000,  0.1028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4011,  0.0000, -0.1114, -0.0537,  2.1611,  0.2262,  0.0000,  0.0000,
-        -0.0507,  0.0741, -0.0110,  0.0000, -0.0624, -0.4760,  0.0000,  0.0000,
-         0.0000,  0.0586,  0.0000,  0.0000,  0.3423,  0.2938,  0.0000, -0.8144,
-         0.0400,  0.0000,  0.1234, -0.0177,  0.1792,  0.0000, -0.6129, -1.1682,
-         0.0000, -0.4865,  0.0000,  0.0000,  0.0000,  0.0000,  0.2049,  0.4461,
-        -0.0500,  0.0000, -0.3338,  0.0000,  0.0000,  0.0682,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1486, -0.0180,  0.1388,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5990,  0.0000,  0.0000, -0.3334,  0.0000,  0.1028],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0911e-01, -1.0004e-08, -9.8797e-02, -7.8877e-03,  2.1569e+00,
-         2.0037e-01,  6.0558e-18, -4.7818e-11, -2.2492e-03,  6.3084e-02,
-        -7.9037e-02,  1.0887e-09, -4.0200e-02, -4.8248e-01, -1.0860e-14,
-        -8.0714e-12, -7.9395e-13,  9.6703e-02, -1.8363e-15,  2.2079e-11,
-         3.4783e-01,  2.9475e-01, -3.8997e-15, -8.4592e-01,  1.3096e-02,
-         3.1651e-11,  1.5022e-01,  1.6158e-02,  1.7999e-01,  5.6716e-15,
-        -5.7247e-01, -1.1728e+00, -1.4518e-18, -4.5991e-01,  7.2061e-13,
-         3.8881e-11,  2.4038e-14,  0.0000e+00,  2.1585e-01,  4.5347e-01,
-        -5.9902e-02, -3.2107e-12, -3.4228e-01,  5.9159e-15,  1.4470e-12,
-        -4.5267e-03, -2.0416e-10,  6.5878e-04,  6.8681e-10, -3.7662e-16,
-        -1.2186e-06,  1.7241e-01, -5.9256e-02,  1.6836e-01, -1.3801e-07,
-         7.8381e-20, -3.0805e-14,  1.1079e-17,  6.1069e-01, -1.6000e-09,
-         2.9594e-17, -2.9408e-01,  2.4864e-05,  8.2620e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4091,  0.0000, -0.0988, -0.0079,  2.1569,  0.2004,  0.0000,  0.0000,
-        -0.0022,  0.0631, -0.0790,  0.0000, -0.0402, -0.4825,  0.0000,  0.0000,
-         0.0000,  0.0967,  0.0000,  0.0000,  0.3478,  0.2947,  0.0000, -0.8459,
-         0.0131,  0.0000,  0.1502,  0.0162,  0.1800,  0.0000, -0.5725, -1.1728,
-         0.0000, -0.4599,  0.0000,  0.0000,  0.0000,  0.0000,  0.2159,  0.4535,
-        -0.0599,  0.0000, -0.3423,  0.0000,  0.0000, -0.0045,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1724, -0.0593,  0.1684,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6107,  0.0000,  0.0000, -0.2941,  0.0000,  0.0826],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4091,  0.0000, -0.0988, -0.0079,  2.1569,  0.2004,  0.0000,  0.0000,
-        -0.0022,  0.0631, -0.0790,  0.0000, -0.0402, -0.4825,  0.0000,  0.0000,
-         0.0000,  0.0967,  0.0000,  0.0000,  0.3478,  0.2947,  0.0000, -0.8459,
-         0.0131,  0.0000,  0.1502,  0.0162,  0.1800,  0.0000, -0.5725, -1.1728,
-         0.0000, -0.4599,  0.0000,  0.0000,  0.0000,  0.0000,  0.2159,  0.4535,
-        -0.0599,  0.0000, -0.3423,  0.0000,  0.0000, -0.0045,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1724, -0.0593,  0.1684,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6107,  0.0000,  0.0000, -0.2941,  0.0000,  0.0826],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0110e-01, -9.0267e-09, -8.1565e-02,  3.0888e-02,  2.1523e+00,
-         1.9345e-01,  5.4640e-18, -4.3145e-11,  4.2721e-02,  5.9515e-02,
-        -1.8127e-01,  9.8228e-10, -3.0601e-02, -4.9964e-01, -9.7985e-15,
-        -7.2826e-12, -7.1636e-13,  9.9762e-02, -1.6569e-15,  1.9921e-11,
-         3.4717e-01,  2.7908e-01, -3.5186e-15, -8.5526e-01, -2.2547e-02,
-         2.8557e-11,  1.7925e-01,  3.8190e-02,  1.8060e-01,  5.1173e-15,
-        -5.1324e-01, -1.1779e+00, -1.3099e-18, -4.1472e-01,  6.5018e-13,
-         3.5081e-11,  2.1689e-14,  0.0000e+00,  2.1105e-01,  4.5648e-01,
-        -8.6101e-02, -2.8969e-12, -3.6095e-01,  5.3378e-15,  1.3056e-12,
-        -8.3358e-02, -1.8421e-10,  5.9440e-04,  6.1969e-10, -3.3981e-16,
-        -1.0995e-06,  1.5729e-01, -1.1390e-01,  2.0171e-01, -1.2453e-07,
-         7.0721e-20, -2.7795e-14,  9.9959e-18,  6.0594e-01, -1.4436e-09,
-         2.6701e-17, -2.5595e-01,  2.2434e-05,  6.2650e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4011,  0.0000, -0.0816,  0.0309,  2.1523,  0.1935,  0.0000,  0.0000,
-         0.0427,  0.0595, -0.1813,  0.0000, -0.0306, -0.4996,  0.0000,  0.0000,
-         0.0000,  0.0998,  0.0000,  0.0000,  0.3472,  0.2791,  0.0000, -0.8553,
-        -0.0225,  0.0000,  0.1793,  0.0382,  0.1806,  0.0000, -0.5132, -1.1779,
-         0.0000, -0.4147,  0.0000,  0.0000,  0.0000,  0.0000,  0.2111,  0.4565,
-        -0.0861,  0.0000, -0.3610,  0.0000,  0.0000, -0.0834,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1573, -0.1139,  0.2017,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6059,  0.0000,  0.0000, -0.2559,  0.0000,  0.0627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4011,  0.0000, -0.0816,  0.0309,  2.1523,  0.1935,  0.0000,  0.0000,
-         0.0427,  0.0595, -0.1813,  0.0000, -0.0306, -0.4996,  0.0000,  0.0000,
-         0.0000,  0.0998,  0.0000,  0.0000,  0.3472,  0.2791,  0.0000, -0.8553,
-        -0.0225,  0.0000,  0.1793,  0.0382,  0.1806,  0.0000, -0.5132, -1.1779,
-         0.0000, -0.4147,  0.0000,  0.0000,  0.0000,  0.0000,  0.2111,  0.4565,
-        -0.0861,  0.0000, -0.3610,  0.0000,  0.0000, -0.0834,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1573, -0.1139,  0.2017,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6059,  0.0000,  0.0000, -0.2559,  0.0000,  0.0627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9704e-01, -8.1474e-09, -8.1265e-02,  5.1529e-02,  2.1474e+00,
-         1.8000e-01,  4.9318e-18, -3.8942e-11,  7.4672e-02,  6.8723e-02,
-        -2.3235e-01,  8.8659e-10, -5.4527e-02, -5.2342e-01, -8.8440e-15,
-        -6.5732e-12, -6.4658e-13,  9.6594e-02, -1.4955e-15,  1.7981e-11,
-         3.3653e-01,  2.6537e-01, -3.1759e-15, -8.7700e-01, -3.8003e-02,
-         2.5776e-11,  2.0347e-01,  4.9575e-02,  1.7915e-01,  4.6188e-15,
-        -4.5949e-01, -1.1865e+00, -1.1823e-18, -3.7282e-01,  5.8685e-13,
-         3.1664e-11,  1.9576e-14,  0.0000e+00,  1.8918e-01,  4.4122e-01,
-        -1.0750e-01, -2.6147e-12, -3.8015e-01,  4.8178e-15,  1.1784e-12,
-        -1.4649e-01, -1.6627e-10,  5.3650e-04,  5.5932e-10, -3.0671e-16,
-        -9.9238e-07,  1.1776e-01, -1.4347e-01,  2.1562e-01, -1.1240e-07,
-         6.3832e-20, -2.5087e-14,  9.0222e-18,  6.0578e-01, -1.3030e-09,
-         2.4100e-17, -2.2909e-01,  2.0249e-05,  6.2085e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3970,  0.0000, -0.0813,  0.0515,  2.1474,  0.1800,  0.0000,  0.0000,
-         0.0747,  0.0687, -0.2324,  0.0000, -0.0545, -0.5234,  0.0000,  0.0000,
-         0.0000,  0.0966,  0.0000,  0.0000,  0.3365,  0.2654,  0.0000, -0.8770,
-        -0.0380,  0.0000,  0.2035,  0.0496,  0.1791,  0.0000, -0.4595, -1.1865,
-         0.0000, -0.3728,  0.0000,  0.0000,  0.0000,  0.0000,  0.1892,  0.4412,
-        -0.1075,  0.0000, -0.3801,  0.0000,  0.0000, -0.1465,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1178, -0.1435,  0.2156,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6058,  0.0000,  0.0000, -0.2291,  0.0000,  0.0621],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3970,  0.0000, -0.0813,  0.0515,  2.1474,  0.1800,  0.0000,  0.0000,
-         0.0747,  0.0687, -0.2324,  0.0000, -0.0545, -0.5234,  0.0000,  0.0000,
-         0.0000,  0.0966,  0.0000,  0.0000,  0.3365,  0.2654,  0.0000, -0.8770,
-        -0.0380,  0.0000,  0.2035,  0.0496,  0.1791,  0.0000, -0.4595, -1.1865,
-         0.0000, -0.3728,  0.0000,  0.0000,  0.0000,  0.0000,  0.1892,  0.4412,
-        -0.1075,  0.0000, -0.3801,  0.0000,  0.0000, -0.1465,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1178, -0.1435,  0.2156,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6058,  0.0000,  0.0000, -0.2291,  0.0000,  0.0621],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0410e-01, -7.3564e-09, -7.1895e-02,  6.1716e-02,  2.1431e+00,
-         1.7677e-01,  4.4529e-18, -3.5161e-11,  9.2339e-02,  6.4859e-02,
-        -2.8683e-01,  8.0051e-10, -7.4801e-02, -5.3148e-01, -7.9854e-15,
-        -5.9350e-12, -5.8380e-13,  4.4310e-02, -1.3503e-15,  1.6235e-11,
-         3.1639e-01,  2.4145e-01, -2.8675e-15, -8.9891e-01, -4.9640e-02,
-         2.3273e-11,  2.1889e-01,  5.5079e-02,  1.7035e-01,  4.1704e-15,
-        -4.2391e-01, -1.1949e+00, -1.0675e-18, -3.2751e-01,  5.2987e-13,
-         2.8590e-11,  1.7675e-14,  0.0000e+00,  1.6976e-01,  4.4156e-01,
-        -1.2947e-01, -2.3608e-12, -3.9670e-01,  4.3501e-15,  1.0640e-12,
-        -2.0468e-01, -1.5012e-10,  4.8441e-04,  5.0502e-10, -2.7693e-16,
-        -8.9603e-07,  9.5233e-02, -1.9198e-01,  1.6677e-01, -1.0148e-07,
-         5.7635e-20, -2.2651e-14,  8.1463e-18,  6.0210e-01, -1.1765e-09,
-         2.1761e-17, -2.2801e-01,  1.8283e-05,  3.8462e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4041,  0.0000, -0.0719,  0.0617,  2.1431,  0.1768,  0.0000,  0.0000,
-         0.0923,  0.0649, -0.2868,  0.0000, -0.0748, -0.5315,  0.0000,  0.0000,
-         0.0000,  0.0443,  0.0000,  0.0000,  0.3164,  0.2414,  0.0000, -0.8989,
-        -0.0496,  0.0000,  0.2189,  0.0551,  0.1704,  0.0000, -0.4239, -1.1949,
-         0.0000, -0.3275,  0.0000,  0.0000,  0.0000,  0.0000,  0.1698,  0.4416,
-        -0.1295,  0.0000, -0.3967,  0.0000,  0.0000, -0.2047,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0952, -0.1920,  0.1668,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6021,  0.0000,  0.0000, -0.2280,  0.0000,  0.0385],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4041,  0.0000, -0.0719,  0.0617,  2.1431,  0.1768,  0.0000,  0.0000,
-         0.0923,  0.0649, -0.2868,  0.0000, -0.0748, -0.5315,  0.0000,  0.0000,
-         0.0000,  0.0443,  0.0000,  0.0000,  0.3164,  0.2414,  0.0000, -0.8989,
-        -0.0496,  0.0000,  0.2189,  0.0551,  0.1704,  0.0000, -0.4239, -1.1949,
-         0.0000, -0.3275,  0.0000,  0.0000,  0.0000,  0.0000,  0.1698,  0.4416,
-        -0.1295,  0.0000, -0.3967,  0.0000,  0.0000, -0.2047,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0952, -0.1920,  0.1668,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6021,  0.0000,  0.0000, -0.2280,  0.0000,  0.0385],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9583e-01, -6.6445e-09, -7.8200e-02,  5.3416e-02,  2.1392e+00,
-         1.7083e-01,  4.0220e-18, -3.1758e-11,  1.0095e-01,  5.8373e-02,
-        -3.2230e-01,  7.2305e-10, -1.1711e-01, -5.3943e-01, -7.2126e-15,
-        -5.3607e-12, -5.2731e-13, -1.1187e-02, -1.2196e-15,  1.4664e-11,
-         2.8215e-01,  2.1569e-01, -2.5900e-15, -9.1368e-01, -5.3962e-02,
-         2.1021e-11,  2.3058e-01,  5.4624e-02,  1.5303e-01,  3.7668e-15,
-        -3.9872e-01, -1.2010e+00, -9.6423e-19, -2.6792e-01,  4.7860e-13,
-         2.5823e-11,  1.5965e-14,  0.0000e+00,  1.3462e-01,  4.2412e-01,
-        -1.4048e-01, -2.1324e-12, -4.1408e-01,  3.9291e-15,  9.6102e-13,
-        -2.4794e-01, -1.3560e-10,  4.3753e-04,  4.5615e-10, -2.5013e-16,
-        -8.0932e-07,  7.1522e-02, -2.2935e-01,  1.0977e-01, -9.1662e-08,
-         5.2057e-20, -2.0459e-14,  7.3580e-18,  5.9550e-01, -1.0626e-09,
-         1.9655e-17, -2.3231e-01,  1.6514e-05,  2.0656e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3958,  0.0000, -0.0782,  0.0534,  2.1392,  0.1708,  0.0000,  0.0000,
-         0.1009,  0.0584, -0.3223,  0.0000, -0.1171, -0.5394,  0.0000,  0.0000,
-         0.0000, -0.0112,  0.0000,  0.0000,  0.2821,  0.2157,  0.0000, -0.9137,
-        -0.0540,  0.0000,  0.2306,  0.0546,  0.1530,  0.0000, -0.3987, -1.2010,
-         0.0000, -0.2679,  0.0000,  0.0000,  0.0000,  0.0000,  0.1346,  0.4241,
-        -0.1405,  0.0000, -0.4141,  0.0000,  0.0000, -0.2479,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0715, -0.2294,  0.1098,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5955,  0.0000,  0.0000, -0.2323,  0.0000,  0.0207],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3958,  0.0000, -0.0782,  0.0534,  2.1392,  0.1708,  0.0000,  0.0000,
-         0.1009,  0.0584, -0.3223,  0.0000, -0.1171, -0.5394,  0.0000,  0.0000,
-         0.0000, -0.0112,  0.0000,  0.0000,  0.2821,  0.2157,  0.0000, -0.9137,
-        -0.0540,  0.0000,  0.2306,  0.0546,  0.1530,  0.0000, -0.3987, -1.2010,
-         0.0000, -0.2679,  0.0000,  0.0000,  0.0000,  0.0000,  0.1346,  0.4241,
-        -0.1405,  0.0000, -0.4141,  0.0000,  0.0000, -0.2479,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0715, -0.2294,  0.1098,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5955,  0.0000,  0.0000, -0.2323,  0.0000,  0.0207],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9069e-01, -6.0036e-09, -8.5107e-02,  5.0641e-02,  2.1372e+00,
-         1.5501e-01,  3.6341e-18, -2.8695e-11,  9.3670e-02,  5.3434e-02,
-        -3.2931e-01,  6.5331e-10, -1.5960e-01, -5.5094e-01, -6.5169e-15,
-        -4.8436e-12, -4.7645e-13, -2.5864e-02, -1.1020e-15,  1.3249e-11,
-         2.3870e-01,  1.8193e-01, -2.3402e-15, -9.2260e-01, -5.7603e-02,
-         1.8993e-11,  2.3431e-01,  4.1574e-02,  1.4704e-01,  3.4035e-15,
-        -3.7446e-01, -1.2043e+00, -8.7123e-19, -2.1458e-01,  4.3243e-13,
-         2.3333e-11,  1.4425e-14,  0.0000e+00,  1.1699e-01,  4.2401e-01,
-        -1.3944e-01, -1.9267e-12, -4.1944e-01,  3.5501e-15,  8.6832e-13,
-        -2.7847e-01, -1.2252e-10,  3.9533e-04,  4.1215e-10, -2.2601e-16,
-        -7.3126e-07,  3.8147e-02, -2.3733e-01,  8.3300e-02, -8.2821e-08,
-         4.7036e-20, -1.8486e-14,  6.6482e-18,  5.8993e-01, -9.6015e-10,
-         1.7759e-17, -2.2976e-01,  1.4921e-05, -1.1387e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3907,  0.0000, -0.0851,  0.0506,  2.1372,  0.1550,  0.0000,  0.0000,
-         0.0937,  0.0534, -0.3293,  0.0000, -0.1596, -0.5509,  0.0000,  0.0000,
-         0.0000, -0.0259,  0.0000,  0.0000,  0.2387,  0.1819,  0.0000, -0.9226,
-        -0.0576,  0.0000,  0.2343,  0.0416,  0.1470,  0.0000, -0.3745, -1.2043,
-         0.0000, -0.2146,  0.0000,  0.0000,  0.0000,  0.0000,  0.1170,  0.4240,
-        -0.1394,  0.0000, -0.4194,  0.0000,  0.0000, -0.2785,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0381, -0.2373,  0.0833,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5899,  0.0000,  0.0000, -0.2298,  0.0000, -0.0114],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3907,  0.0000, -0.0851,  0.0506,  2.1372,  0.1550,  0.0000,  0.0000,
-         0.0937,  0.0534, -0.3293,  0.0000, -0.1596, -0.5509,  0.0000,  0.0000,
-         0.0000, -0.0259,  0.0000,  0.0000,  0.2387,  0.1819,  0.0000, -0.9226,
-        -0.0576,  0.0000,  0.2343,  0.0416,  0.1470,  0.0000, -0.3745, -1.2043,
-         0.0000, -0.2146,  0.0000,  0.0000,  0.0000,  0.0000,  0.1170,  0.4240,
-        -0.1394,  0.0000, -0.4194,  0.0000,  0.0000, -0.2785,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0381, -0.2373,  0.0833,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5899,  0.0000,  0.0000, -0.2298,  0.0000, -0.0114],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.8641e-01, -5.4264e-09, -8.5104e-02,  5.5911e-02,  2.1365e+00,
-         1.4341e-01,  3.2847e-18, -2.5936e-11,  7.8114e-02,  4.0593e-02,
-        -3.3944e-01,  5.9050e-10, -1.8963e-01, -5.5894e-01, -5.8904e-15,
-        -4.3780e-12, -4.3064e-13, -4.1990e-02, -9.9604e-16,  1.1976e-11,
-         2.0089e-01,  1.5649e-01, -2.1152e-15, -9.2898e-01, -4.3973e-02,
-         1.7167e-11,  2.5516e-01,  3.9154e-02,  1.5059e-01,  3.0763e-15,
-        -3.4980e-01, -1.2072e+00, -7.8747e-19, -1.7054e-01,  3.9086e-13,
-         2.1089e-11,  1.3038e-14,  0.0000e+00,  1.0565e-01,  4.1317e-01,
-        -1.2028e-01, -1.7415e-12, -4.2372e-01,  3.2088e-15,  7.8484e-13,
-        -3.1389e-01, -1.1074e-10,  3.5733e-04,  3.7253e-10, -2.0428e-16,
-        -6.6096e-07,  1.0251e-03, -2.5526e-01,  8.4742e-02, -7.4859e-08,
-         4.2514e-20, -1.6709e-14,  6.0091e-18,  5.7808e-01, -8.6784e-10,
-         1.6052e-17, -2.3319e-01,  1.3487e-05, -4.5115e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.8641e-01,  0.0000e+00, -8.5104e-02,  5.5911e-02,  2.1365e+00,
-         1.4341e-01,  0.0000e+00,  0.0000e+00,  7.8114e-02,  4.0593e-02,
-        -3.3944e-01,  0.0000e+00, -1.8963e-01, -5.5894e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.1990e-02,  0.0000e+00,  0.0000e+00,
-         2.0089e-01,  1.5649e-01,  0.0000e+00, -9.2898e-01, -4.3973e-02,
-         0.0000e+00,  2.5516e-01,  3.9154e-02,  1.5059e-01,  0.0000e+00,
-        -3.4980e-01, -1.2072e+00,  0.0000e+00, -1.7054e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0565e-01,  4.1317e-01,
-        -1.2028e-01,  0.0000e+00, -4.2372e-01,  0.0000e+00,  0.0000e+00,
-        -3.1389e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.0251e-03, -2.5526e-01,  8.4742e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.7808e-01,  0.0000e+00,
-         0.0000e+00, -2.3319e-01,  0.0000e+00, -4.5115e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.8641e-01,  0.0000e+00, -8.5104e-02,  5.5911e-02,  2.1365e+00,
-         1.4341e-01,  0.0000e+00,  0.0000e+00,  7.8114e-02,  4.0593e-02,
-        -3.3944e-01,  0.0000e+00, -1.8963e-01, -5.5894e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.1990e-02,  0.0000e+00,  0.0000e+00,
-         2.0089e-01,  1.5649e-01,  0.0000e+00, -9.2898e-01, -4.3973e-02,
-         0.0000e+00,  2.5516e-01,  3.9154e-02,  1.5059e-01,  0.0000e+00,
-        -3.4980e-01, -1.2072e+00,  0.0000e+00, -1.7054e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0565e-01,  4.1317e-01,
-        -1.2028e-01,  0.0000e+00, -4.2372e-01,  0.0000e+00,  0.0000e+00,
-        -3.1389e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  1.0251e-03, -2.5526e-01,  8.4742e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.7808e-01,  0.0000e+00,
-         0.0000e+00, -2.3319e-01,  0.0000e+00, -4.5115e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9714e-01, -4.9064e-09, -7.8195e-02,  6.7568e-02,  2.1378e+00,
-         1.3090e-01,  2.9699e-18, -2.3451e-11,  6.5038e-02,  1.5987e-02,
-        -3.2268e-01,  5.3391e-10, -2.0468e-01, -5.7340e-01, -5.3259e-15,
-        -3.9584e-12, -3.8937e-13, -3.0201e-02, -9.0059e-16,  1.0828e-11,
-         1.6396e-01,  1.2288e-01, -1.9125e-15, -9.2925e-01, -6.3366e-02,
-         1.5522e-11,  2.8130e-01,  4.2477e-02,  1.6243e-01,  2.7815e-15,
-        -3.4495e-01, -1.2086e+00, -7.1201e-19, -1.3689e-01,  3.5341e-13,
-         1.9068e-11,  1.1789e-14,  0.0000e+00,  1.0668e-01,  4.0531e-01,
-        -1.0487e-01, -1.5746e-12, -4.1730e-01,  2.9013e-15,  7.0963e-13,
-        -3.5162e-01, -1.0013e-10,  3.2308e-04,  3.3683e-10, -1.8470e-16,
-        -5.9762e-07, -5.3570e-02, -2.5745e-01,  1.0981e-01, -6.7685e-08,
-         3.8440e-20, -1.5108e-14,  5.4333e-18,  5.6280e-01, -7.8468e-10,
-         1.4513e-17, -2.2947e-01,  1.2194e-05, -6.2740e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3971,  0.0000, -0.0782,  0.0676,  2.1378,  0.1309,  0.0000,  0.0000,
-         0.0650,  0.0160, -0.3227,  0.0000, -0.2047, -0.5734,  0.0000,  0.0000,
-         0.0000, -0.0302,  0.0000,  0.0000,  0.1640,  0.1229,  0.0000, -0.9293,
-        -0.0634,  0.0000,  0.2813,  0.0425,  0.1624,  0.0000, -0.3449, -1.2086,
-         0.0000, -0.1369,  0.0000,  0.0000,  0.0000,  0.0000,  0.1067,  0.4053,
-        -0.1049,  0.0000, -0.4173,  0.0000,  0.0000, -0.3516,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0536, -0.2574,  0.1098,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5628,  0.0000,  0.0000, -0.2295,  0.0000, -0.0627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3971,  0.0000, -0.0782,  0.0676,  2.1378,  0.1309,  0.0000,  0.0000,
-         0.0650,  0.0160, -0.3227,  0.0000, -0.2047, -0.5734,  0.0000,  0.0000,
-         0.0000, -0.0302,  0.0000,  0.0000,  0.1640,  0.1229,  0.0000, -0.9293,
-        -0.0634,  0.0000,  0.2813,  0.0425,  0.1624,  0.0000, -0.3449, -1.2086,
-         0.0000, -0.1369,  0.0000,  0.0000,  0.0000,  0.0000,  0.1067,  0.4053,
-        -0.1049,  0.0000, -0.4173,  0.0000,  0.0000, -0.3516,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0536, -0.2574,  0.1098,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5628,  0.0000,  0.0000, -0.2295,  0.0000, -0.0627],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.0914e-01, -4.4378e-09, -7.3298e-02,  8.6441e-02,  2.1403e+00,
-         1.2633e-01,  2.6863e-18, -2.1211e-11,  6.3865e-02, -1.4888e-02,
-        -2.8912e-01,  4.8291e-10, -2.1821e-01, -5.8506e-01, -4.8172e-15,
-        -3.5803e-12, -3.5218e-13, -4.6212e-03, -8.1457e-16,  9.7938e-12,
-         1.3957e-01,  9.3924e-02, -1.7298e-15, -9.2136e-01, -8.5674e-02,
-         1.4040e-11,  3.0935e-01,  6.4437e-02,  1.7740e-01,  2.5158e-15,
-        -3.6428e-01, -1.2063e+00, -6.4400e-19, -9.1872e-02,  3.1965e-13,
-         1.7247e-11,  1.0663e-14,  0.0000e+00,  1.0485e-01,  3.9285e-01,
-        -7.8289e-02, -1.4242e-12, -4.1797e-01,  2.6242e-15,  6.4185e-13,
-        -4.1902e-01, -9.0564e-11,  2.9222e-04,  3.0466e-10, -1.6706e-16,
-        -5.4054e-07, -7.4744e-02, -2.6964e-01,  1.1383e-01, -6.1220e-08,
-         3.4768e-20, -1.3665e-14,  4.9143e-18,  5.4747e-01, -7.0973e-10,
-         1.3127e-17, -2.4801e-01,  1.1029e-05, -8.5391e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4091,  0.0000, -0.0733,  0.0864,  2.1403,  0.1263,  0.0000,  0.0000,
-         0.0639, -0.0149, -0.2891,  0.0000, -0.2182, -0.5851,  0.0000,  0.0000,
-         0.0000, -0.0046,  0.0000,  0.0000,  0.1396,  0.0939,  0.0000, -0.9214,
-        -0.0857,  0.0000,  0.3093,  0.0644,  0.1774,  0.0000, -0.3643, -1.2063,
-         0.0000, -0.0919,  0.0000,  0.0000,  0.0000,  0.0000,  0.1048,  0.3928,
-        -0.0783,  0.0000, -0.4180,  0.0000,  0.0000, -0.4190,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0747, -0.2696,  0.1138,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5475,  0.0000,  0.0000, -0.2480,  0.0000, -0.0854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4091,  0.0000, -0.0733,  0.0864,  2.1403,  0.1263,  0.0000,  0.0000,
-         0.0639, -0.0149, -0.2891,  0.0000, -0.2182, -0.5851,  0.0000,  0.0000,
-         0.0000, -0.0046,  0.0000,  0.0000,  0.1396,  0.0939,  0.0000, -0.9214,
-        -0.0857,  0.0000,  0.3093,  0.0644,  0.1774,  0.0000, -0.3643, -1.2063,
-         0.0000, -0.0919,  0.0000,  0.0000,  0.0000,  0.0000,  0.1048,  0.3928,
-        -0.0783,  0.0000, -0.4180,  0.0000,  0.0000, -0.4190,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0747, -0.2696,  0.1138,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5475,  0.0000,  0.0000, -0.2480,  0.0000, -0.0854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1951e-01, -4.0153e-09, -4.8519e-02,  1.2564e-01,  2.1434e+00,
-         1.4247e-01,  2.4305e-18, -1.9192e-11,  5.5981e-02, -1.8101e-02,
-        -2.4667e-01,  4.3694e-10, -2.2248e-01, -5.8814e-01, -4.3586e-15,
-        -3.2394e-12, -3.1865e-13, -5.4427e-02, -7.3701e-16,  8.8614e-12,
-         1.1645e-01,  6.5525e-02, -1.5651e-15, -9.1879e-01, -9.3407e-02,
-         1.2703e-11,  3.4264e-01,  8.6330e-02,  1.8878e-01,  2.2763e-15,
-        -3.7527e-01, -1.2056e+00, -5.8268e-19, -5.3023e-02,  2.8922e-13,
-         1.5605e-11,  9.6476e-15,  0.0000e+00,  1.4604e-01,  4.2025e-01,
-        -6.6499e-02, -1.2886e-12, -4.0602e-01,  2.3743e-15,  5.8074e-13,
-        -4.7286e-01, -8.1941e-11,  2.6440e-04,  2.7565e-10, -1.5116e-16,
-        -4.8907e-07, -8.0989e-02, -2.8446e-01,  1.3526e-01, -5.5391e-08,
-         3.1458e-20, -1.2364e-14,  4.4464e-18,  5.2404e-01, -6.4215e-10,
-         1.1877e-17, -2.6457e-01,  9.9793e-06, -1.2422e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4195,  0.0000, -0.0485,  0.1256,  2.1434,  0.1425,  0.0000,  0.0000,
-         0.0560, -0.0181, -0.2467,  0.0000, -0.2225, -0.5881,  0.0000,  0.0000,
-         0.0000, -0.0544,  0.0000,  0.0000,  0.1164,  0.0655,  0.0000, -0.9188,
-        -0.0934,  0.0000,  0.3426,  0.0863,  0.1888,  0.0000, -0.3753, -1.2056,
-         0.0000, -0.0530,  0.0000,  0.0000,  0.0000,  0.0000,  0.1460,  0.4202,
-        -0.0665,  0.0000, -0.4060,  0.0000,  0.0000, -0.4729,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0810, -0.2845,  0.1353,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5240,  0.0000,  0.0000, -0.2646,  0.0000, -0.1242],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4195,  0.0000, -0.0485,  0.1256,  2.1434,  0.1425,  0.0000,  0.0000,
-         0.0560, -0.0181, -0.2467,  0.0000, -0.2225, -0.5881,  0.0000,  0.0000,
-         0.0000, -0.0544,  0.0000,  0.0000,  0.1164,  0.0655,  0.0000, -0.9188,
-        -0.0934,  0.0000,  0.3426,  0.0863,  0.1888,  0.0000, -0.3753, -1.2056,
-         0.0000, -0.0530,  0.0000,  0.0000,  0.0000,  0.0000,  0.1460,  0.4202,
-        -0.0665,  0.0000, -0.4060,  0.0000,  0.0000, -0.4729,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0810, -0.2845,  0.1353,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5240,  0.0000,  0.0000, -0.2646,  0.0000, -0.1242],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3152e-01, -3.6342e-09, -3.6905e-02,  1.4011e-01,  2.1469e+00,
-         1.3631e-01,  2.1998e-18, -1.7370e-11,  5.4049e-02, -3.2134e-02,
-        -1.8243e-01,  3.9547e-10, -2.2764e-01, -5.8756e-01, -3.9449e-15,
-        -2.9320e-12, -2.8841e-13, -9.1807e-02, -6.6707e-16,  8.0204e-12,
-         9.0834e-02,  4.1555e-02, -1.4166e-15, -9.1402e-01, -9.7152e-02,
-         1.1497e-11,  3.4283e-01,  1.0495e-01,  2.0176e-01,  2.0603e-15,
-        -3.9133e-01, -1.2041e+00, -5.2738e-19, -1.2702e-02,  2.6177e-13,
-         1.4124e-11,  8.7320e-15,  0.0000e+00,  1.8983e-01,  4.5268e-01,
-        -1.7059e-02, -1.1663e-12, -3.9586e-01,  2.1490e-15,  5.2563e-13,
-        -5.1225e-01, -7.4165e-11,  2.3931e-04,  2.4949e-10, -1.3681e-16,
-        -4.4266e-07, -7.3561e-02, -2.7951e-01,  1.0603e-01, -5.0135e-08,
-         2.8473e-20, -1.1190e-14,  4.0244e-18,  4.9906e-01, -5.8121e-10,
-         1.0750e-17, -2.7100e-01,  9.0322e-06, -1.6562e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4315,  0.0000, -0.0369,  0.1401,  2.1469,  0.1363,  0.0000,  0.0000,
-         0.0540, -0.0321, -0.1824,  0.0000, -0.2276, -0.5876,  0.0000,  0.0000,
-         0.0000, -0.0918,  0.0000,  0.0000,  0.0908,  0.0416,  0.0000, -0.9140,
-        -0.0972,  0.0000,  0.3428,  0.1050,  0.2018,  0.0000, -0.3913, -1.2041,
-         0.0000, -0.0127,  0.0000,  0.0000,  0.0000,  0.0000,  0.1898,  0.4527,
-        -0.0171,  0.0000, -0.3959,  0.0000,  0.0000, -0.5122,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0736, -0.2795,  0.1060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4991,  0.0000,  0.0000, -0.2710,  0.0000, -0.1656],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4315,  0.0000, -0.0369,  0.1401,  2.1469,  0.1363,  0.0000,  0.0000,
-         0.0540, -0.0321, -0.1824,  0.0000, -0.2276, -0.5876,  0.0000,  0.0000,
-         0.0000, -0.0918,  0.0000,  0.0000,  0.0908,  0.0416,  0.0000, -0.9140,
-        -0.0972,  0.0000,  0.3428,  0.1050,  0.2018,  0.0000, -0.3913, -1.2041,
-         0.0000, -0.0127,  0.0000,  0.0000,  0.0000,  0.0000,  0.1898,  0.4527,
-        -0.0171,  0.0000, -0.3959,  0.0000,  0.0000, -0.5122,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0736, -0.2795,  0.1060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4991,  0.0000,  0.0000, -0.2710,  0.0000, -0.1656],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4302e-01, -3.2904e-09, -3.8418e-02,  1.3441e-01,  2.1490e+00,
-         1.1444e-01,  1.9917e-18, -1.5727e-11,  4.7702e-02, -4.8435e-02,
-        -1.1545e-01,  3.5806e-10, -2.4063e-01, -5.7813e-01, -3.5717e-15,
-        -2.6546e-12, -2.6113e-13, -1.1651e-01, -6.0396e-16,  7.2616e-12,
-         7.4061e-02,  2.7562e-02, -1.2826e-15, -9.0458e-01, -9.5855e-02,
-         1.0410e-11,  3.2987e-01,  1.1572e-01,  2.0123e-01,  1.8654e-15,
-        -3.7424e-01, -1.2018e+00, -4.7749e-19,  1.9392e-02,  2.3700e-13,
-         1.2788e-11,  7.9059e-15,  0.0000e+00,  2.1621e-01,  4.7830e-01,
-         3.9080e-02, -1.0560e-12, -3.8689e-01,  1.9457e-15,  4.7590e-13,
-        -5.3262e-01, -6.7149e-11,  2.1667e-04,  2.2589e-10, -1.2387e-16,
-        -4.0078e-07, -7.1431e-02, -2.6423e-01,  4.8083e-02, -4.5392e-08,
-         2.5779e-20, -1.0132e-14,  3.6437e-18,  4.7320e-01, -5.2623e-10,
-         9.7332e-18, -2.6505e-01,  8.1778e-06, -2.0608e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4430,  0.0000, -0.0384,  0.1344,  2.1490,  0.1144,  0.0000,  0.0000,
-         0.0477, -0.0484, -0.1155,  0.0000, -0.2406, -0.5781,  0.0000,  0.0000,
-         0.0000, -0.1165,  0.0000,  0.0000,  0.0741,  0.0276,  0.0000, -0.9046,
-        -0.0959,  0.0000,  0.3299,  0.1157,  0.2012,  0.0000, -0.3742, -1.2018,
-         0.0000,  0.0194,  0.0000,  0.0000,  0.0000,  0.0000,  0.2162,  0.4783,
-         0.0391,  0.0000, -0.3869,  0.0000,  0.0000, -0.5326,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0714, -0.2642,  0.0481,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4732,  0.0000,  0.0000, -0.2650,  0.0000, -0.2061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4430,  0.0000, -0.0384,  0.1344,  2.1490,  0.1144,  0.0000,  0.0000,
-         0.0477, -0.0484, -0.1155,  0.0000, -0.2406, -0.5781,  0.0000,  0.0000,
-         0.0000, -0.1165,  0.0000,  0.0000,  0.0741,  0.0276,  0.0000, -0.9046,
-        -0.0959,  0.0000,  0.3299,  0.1157,  0.2012,  0.0000, -0.3742, -1.2018,
-         0.0000,  0.0194,  0.0000,  0.0000,  0.0000,  0.0000,  0.2162,  0.4783,
-         0.0391,  0.0000, -0.3869,  0.0000,  0.0000, -0.5326,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0714, -0.2642,  0.0481,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4732,  0.0000,  0.0000, -0.2650,  0.0000, -0.2061],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5069e-01, -2.9801e-09, -4.7820e-02,  1.1929e-01,  2.1510e+00,
-         8.7137e-02,  1.8039e-18, -1.4244e-11,  3.7938e-02, -5.9939e-02,
-        -6.7333e-02,  3.2429e-10, -2.7175e-01, -5.6222e-01, -3.2349e-15,
-        -2.4043e-12, -2.3650e-13, -1.2331e-01, -5.4701e-16,  6.5769e-12,
-         5.9394e-02,  2.5223e-02, -1.1616e-15, -8.9001e-01, -9.8985e-02,
-         9.4280e-12,  3.0495e-01,  1.2966e-01,  1.9925e-01,  1.6895e-15,
-        -3.3493e-01, -1.2005e+00, -4.3246e-19,  5.8012e-02,  2.1465e-13,
-         1.1582e-11,  7.1604e-15,  0.0000e+00,  2.1573e-01,  4.9980e-01,
-         1.0437e-01, -9.5639e-13, -3.8438e-01,  1.7622e-15,  4.3102e-13,
-        -5.5048e-01, -6.0816e-11,  1.9624e-04,  2.0459e-10, -1.1219e-16,
-        -3.6299e-07, -7.5312e-02, -2.5040e-01,  3.0210e-02, -4.1111e-08,
-         2.3348e-20, -9.1762e-15,  3.3001e-18,  4.5615e-01, -4.7661e-10,
-         8.8153e-18, -2.5110e-01,  7.4066e-06, -2.5420e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4507,  0.0000, -0.0478,  0.1193,  2.1510,  0.0871,  0.0000,  0.0000,
-         0.0379, -0.0599, -0.0673,  0.0000, -0.2718, -0.5622,  0.0000,  0.0000,
-         0.0000, -0.1233,  0.0000,  0.0000,  0.0594,  0.0252,  0.0000, -0.8900,
-        -0.0990,  0.0000,  0.3049,  0.1297,  0.1992,  0.0000, -0.3349, -1.2005,
-         0.0000,  0.0580,  0.0000,  0.0000,  0.0000,  0.0000,  0.2157,  0.4998,
-         0.1044,  0.0000, -0.3844,  0.0000,  0.0000, -0.5505,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0753, -0.2504,  0.0302,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4561,  0.0000,  0.0000, -0.2511,  0.0000, -0.2542],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4507,  0.0000, -0.0478,  0.1193,  2.1510,  0.0871,  0.0000,  0.0000,
-         0.0379, -0.0599, -0.0673,  0.0000, -0.2718, -0.5622,  0.0000,  0.0000,
-         0.0000, -0.1233,  0.0000,  0.0000,  0.0594,  0.0252,  0.0000, -0.8900,
-        -0.0990,  0.0000,  0.3049,  0.1297,  0.1992,  0.0000, -0.3349, -1.2005,
-         0.0000,  0.0580,  0.0000,  0.0000,  0.0000,  0.0000,  0.2157,  0.4998,
-         0.1044,  0.0000, -0.3844,  0.0000,  0.0000, -0.5505,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0753, -0.2504,  0.0302,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4561,  0.0000,  0.0000, -0.2511,  0.0000, -0.2542],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.4885e-01, -2.7000e-09, -6.3930e-02,  1.0563e-01,  2.1523e+00,
-         7.1791e-02,  1.6343e-18, -1.2905e-11,  2.4376e-02, -5.7642e-02,
-        -4.2179e-02,  2.9381e-10, -3.2329e-01, -5.4201e-01, -2.9308e-15,
-        -2.1783e-12, -2.1427e-13, -9.7448e-02, -4.9559e-16,  5.9586e-12,
-         5.5199e-02,  3.6746e-02, -1.0524e-15, -8.6701e-01, -8.7867e-02,
-         8.5418e-12,  2.9679e-01,  1.5468e-01,  2.1523e-01,  1.5306e-15,
-        -2.6481e-01, -1.1969e+00, -3.9181e-19,  1.2046e-01,  1.9448e-13,
-         1.0493e-11,  6.4873e-15,  0.0000e+00,  1.7330e-01,  4.8980e-01,
-         1.4973e-01, -8.6649e-13, -3.9479e-01,  1.5966e-15,  3.9051e-13,
-        -5.7992e-01, -5.5100e-11,  1.7779e-04,  1.8535e-10, -1.0164e-16,
-        -3.2887e-07, -6.8382e-02, -2.5192e-01,  3.1946e-02, -3.7247e-08,
-         2.1153e-20, -8.3136e-15,  2.9899e-18,  4.4105e-01, -4.3180e-10,
-         7.9867e-18, -2.4006e-01,  6.7104e-06, -3.0129e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4489,  0.0000, -0.0639,  0.1056,  2.1523,  0.0718,  0.0000,  0.0000,
-         0.0244, -0.0576, -0.0422,  0.0000, -0.3233, -0.5420,  0.0000,  0.0000,
-         0.0000, -0.0974,  0.0000,  0.0000,  0.0552,  0.0367,  0.0000, -0.8670,
-        -0.0879,  0.0000,  0.2968,  0.1547,  0.2152,  0.0000, -0.2648, -1.1969,
-         0.0000,  0.1205,  0.0000,  0.0000,  0.0000,  0.0000,  0.1733,  0.4898,
-         0.1497,  0.0000, -0.3948,  0.0000,  0.0000, -0.5799,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0684, -0.2519,  0.0319,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4410,  0.0000,  0.0000, -0.2401,  0.0000, -0.3013],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4489,  0.0000, -0.0639,  0.1056,  2.1523,  0.0718,  0.0000,  0.0000,
-         0.0244, -0.0576, -0.0422,  0.0000, -0.3233, -0.5420,  0.0000,  0.0000,
-         0.0000, -0.0974,  0.0000,  0.0000,  0.0552,  0.0367,  0.0000, -0.8670,
-        -0.0879,  0.0000,  0.2968,  0.1547,  0.2152,  0.0000, -0.2648, -1.1969,
-         0.0000,  0.1205,  0.0000,  0.0000,  0.0000,  0.0000,  0.1733,  0.4898,
-         0.1497,  0.0000, -0.3948,  0.0000,  0.0000, -0.5799,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0684, -0.2519,  0.0319,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4410,  0.0000,  0.0000, -0.2401,  0.0000, -0.3013],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5226e-01, -2.4470e-09, -7.0564e-02,  9.7122e-02,  2.1540e+00,
-         6.0743e-02,  1.4812e-18, -1.1696e-11,  1.5772e-02, -4.9603e-02,
-        -3.8418e-03,  2.6628e-10, -3.5757e-01, -5.1838e-01, -2.6562e-15,
-        -1.9742e-12, -1.9419e-13, -8.3801e-02, -4.4915e-16,  5.4003e-12,
-         5.6428e-02,  4.6823e-02, -9.5383e-16, -8.5112e-01, -7.0659e-02,
-         7.7414e-12,  2.8190e-01,  1.7924e-01,  2.2288e-01,  1.3872e-15,
-        -1.9139e-01, -1.1950e+00, -3.5510e-19,  1.5786e-01,  1.7625e-13,
-         9.5099e-12,  5.8794e-15,  0.0000e+00,  1.2869e-01,  4.9997e-01,
-         1.8625e-01, -7.8529e-13, -3.9704e-01,  1.4470e-15,  3.5391e-13,
-        -5.7498e-01, -4.9936e-11,  1.6113e-04,  1.6799e-10, -9.2117e-17,
-        -2.9805e-07, -5.8489e-02, -2.5481e-01,  2.2397e-02, -3.3757e-08,
-         1.9171e-20, -7.5346e-15,  2.7097e-18,  4.2627e-01, -3.9134e-10,
-         7.2383e-18, -2.2028e-01,  6.0816e-06, -3.3484e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4523,  0.0000, -0.0706,  0.0971,  2.1540,  0.0607,  0.0000,  0.0000,
-         0.0158, -0.0496, -0.0038,  0.0000, -0.3576, -0.5184,  0.0000,  0.0000,
-         0.0000, -0.0838,  0.0000,  0.0000,  0.0564,  0.0468,  0.0000, -0.8511,
-        -0.0707,  0.0000,  0.2819,  0.1792,  0.2229,  0.0000, -0.1914, -1.1950,
-         0.0000,  0.1579,  0.0000,  0.0000,  0.0000,  0.0000,  0.1287,  0.5000,
-         0.1862,  0.0000, -0.3970,  0.0000,  0.0000, -0.5750,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0585, -0.2548,  0.0224,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4263,  0.0000,  0.0000, -0.2203,  0.0000, -0.3348],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4523,  0.0000, -0.0706,  0.0971,  2.1540,  0.0607,  0.0000,  0.0000,
-         0.0158, -0.0496, -0.0038,  0.0000, -0.3576, -0.5184,  0.0000,  0.0000,
-         0.0000, -0.0838,  0.0000,  0.0000,  0.0564,  0.0468,  0.0000, -0.8511,
-        -0.0707,  0.0000,  0.2819,  0.1792,  0.2229,  0.0000, -0.1914, -1.1950,
-         0.0000,  0.1579,  0.0000,  0.0000,  0.0000,  0.0000,  0.1287,  0.5000,
-         0.1862,  0.0000, -0.3970,  0.0000,  0.0000, -0.5750,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0585, -0.2548,  0.0224,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4263,  0.0000,  0.0000, -0.2203,  0.0000, -0.3348],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5117e-01, -2.2184e-09, -6.8056e-02,  9.7886e-02,  2.1546e+00,
-         6.9671e-02,  1.3428e-18, -1.0603e-11, -4.4980e-03, -3.1720e-02,
-         1.4815e-02,  2.4140e-10, -3.8746e-01, -5.0271e-01, -2.4081e-15,
-        -1.7898e-12, -1.7605e-13, -7.9299e-02, -4.0719e-16,  4.8958e-12,
-         7.3800e-02,  3.7792e-02, -8.6473e-16, -8.4892e-01, -5.4301e-02,
-         7.0182e-12,  2.8454e-01,  2.1942e-01,  2.3870e-01,  1.2576e-15,
-        -1.2650e-01, -1.1929e+00, -3.2193e-19,  1.9131e-01,  1.5979e-13,
-         8.6216e-12,  5.3302e-15,  0.0000e+00,  8.4757e-02,  5.1084e-01,
-         2.0495e-01, -7.1193e-13, -3.9683e-01,  1.3118e-15,  3.2085e-13,
-        -5.6878e-01, -4.5272e-11,  1.4608e-04,  1.5229e-10, -8.3512e-17,
-        -2.7021e-07, -4.3175e-02, -2.5561e-01,  1.0178e-02, -3.0603e-08,
-         1.7380e-20, -6.8308e-15,  2.4566e-18,  4.2125e-01, -3.5478e-10,
-         6.5621e-18, -1.9989e-01,  5.5135e-06, -3.5772e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4512,  0.0000, -0.0681,  0.0979,  2.1546,  0.0697,  0.0000,  0.0000,
-        -0.0045, -0.0317,  0.0148,  0.0000, -0.3875, -0.5027,  0.0000,  0.0000,
-         0.0000, -0.0793,  0.0000,  0.0000,  0.0738,  0.0378,  0.0000, -0.8489,
-        -0.0543,  0.0000,  0.2845,  0.2194,  0.2387,  0.0000, -0.1265, -1.1929,
-         0.0000,  0.1913,  0.0000,  0.0000,  0.0000,  0.0000,  0.0848,  0.5108,
-         0.2050,  0.0000, -0.3968,  0.0000,  0.0000, -0.5688,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0432, -0.2556,  0.0102,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4213,  0.0000,  0.0000, -0.1999,  0.0000, -0.3577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4512,  0.0000, -0.0681,  0.0979,  2.1546,  0.0697,  0.0000,  0.0000,
-        -0.0045, -0.0317,  0.0148,  0.0000, -0.3875, -0.5027,  0.0000,  0.0000,
-         0.0000, -0.0793,  0.0000,  0.0000,  0.0738,  0.0378,  0.0000, -0.8489,
-        -0.0543,  0.0000,  0.2845,  0.2194,  0.2387,  0.0000, -0.1265, -1.1929,
-         0.0000,  0.1913,  0.0000,  0.0000,  0.0000,  0.0000,  0.0848,  0.5108,
-         0.2050,  0.0000, -0.3968,  0.0000,  0.0000, -0.5688,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0432, -0.2556,  0.0102,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4213,  0.0000,  0.0000, -0.1999,  0.0000, -0.3577],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.5410e-01, -2.0118e-09, -5.3254e-02,  9.1695e-02,  2.1544e+00,
-         5.7660e-02,  1.2178e-18, -9.6158e-12, -1.7433e-02, -8.7435e-03,
-         3.6701e-02,  2.1892e-10, -3.9375e-01, -4.9025e-01, -2.1838e-15,
-        -1.6231e-12, -1.5966e-13, -7.8029e-02, -3.6927e-16,  4.4399e-12,
-         6.7282e-02,  2.6136e-02, -7.8420e-16, -8.6148e-01, -3.3450e-02,
-         6.3647e-12,  2.7109e-01,  2.2843e-01,  2.3383e-01,  1.1405e-15,
-        -9.2931e-02, -1.1907e+00, -2.9195e-19,  2.0636e-01,  1.4491e-13,
-         7.8187e-12,  4.8338e-15,  0.0000e+00,  7.1849e-02,  5.3892e-01,
-         2.1767e-01, -6.4564e-13, -3.8712e-01,  1.1896e-15,  2.9098e-13,
-        -5.4051e-01, -4.1056e-11,  1.3248e-04,  1.3811e-10, -7.5735e-17,
-        -2.4504e-07, -1.0451e-02, -2.5879e-01, -2.8148e-03, -2.7753e-08,
-         1.5762e-20, -6.1947e-15,  2.2278e-18,  4.0812e-01, -3.2175e-10,
-         5.9510e-18, -1.7739e-01,  5.0000e-06, -3.8542e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4541,  0.0000, -0.0533,  0.0917,  2.1544,  0.0577,  0.0000,  0.0000,
-        -0.0174, -0.0087,  0.0367,  0.0000, -0.3937, -0.4903,  0.0000,  0.0000,
-         0.0000, -0.0780,  0.0000,  0.0000,  0.0673,  0.0261,  0.0000, -0.8615,
-        -0.0335,  0.0000,  0.2711,  0.2284,  0.2338,  0.0000, -0.0929, -1.1907,
-         0.0000,  0.2064,  0.0000,  0.0000,  0.0000,  0.0000,  0.0718,  0.5389,
-         0.2177,  0.0000, -0.3871,  0.0000,  0.0000, -0.5405,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0105, -0.2588, -0.0028,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4081,  0.0000,  0.0000, -0.1774,  0.0000, -0.3854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4541,  0.0000, -0.0533,  0.0917,  2.1544,  0.0577,  0.0000,  0.0000,
-        -0.0174, -0.0087,  0.0367,  0.0000, -0.3937, -0.4903,  0.0000,  0.0000,
-         0.0000, -0.0780,  0.0000,  0.0000,  0.0673,  0.0261,  0.0000, -0.8615,
-        -0.0335,  0.0000,  0.2711,  0.2284,  0.2338,  0.0000, -0.0929, -1.1907,
-         0.0000,  0.2064,  0.0000,  0.0000,  0.0000,  0.0000,  0.0718,  0.5389,
-         0.2177,  0.0000, -0.3871,  0.0000,  0.0000, -0.5405,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0105, -0.2588, -0.0028,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4081,  0.0000,  0.0000, -0.1774,  0.0000, -0.3854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.3936e-01, -1.8250e-09, -3.3250e-02,  6.8244e-02,  2.1550e+00,
-         3.5034e-02,  1.1047e-18, -8.7231e-12, -2.6099e-02, -1.1488e-02,
-         8.1688e-02,  1.9860e-10, -3.8277e-01, -4.8096e-01, -1.9811e-15,
-        -1.4724e-12, -1.4484e-13, -9.5814e-02, -3.3499e-16,  4.0277e-12,
-         4.6840e-02, -1.3753e-02, -7.1140e-16, -8.7648e-01, -2.8910e-02,
-         5.7738e-12,  2.5531e-01,  2.1836e-01,  2.2211e-01,  1.0346e-15,
-        -8.4111e-02, -1.1892e+00, -2.6484e-19,  2.1351e-01,  1.3146e-13,
-         7.0929e-12,  4.3851e-15,  0.0000e+00,  8.6293e-02,  5.7013e-01,
-         2.2374e-01, -5.8570e-13, -3.7380e-01,  1.0792e-15,  2.6396e-13,
-        -5.0942e-01, -3.7244e-11,  1.2018e-04,  1.2529e-10, -6.8705e-17,
-        -2.2230e-07,  2.0181e-02, -2.4948e-01, -6.8828e-02, -2.5177e-08,
-         1.4299e-20, -5.6196e-15,  2.0210e-18,  4.0100e-01, -2.9188e-10,
-         5.3986e-18, -1.6735e-01,  4.5359e-06, -4.0563e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4394,  0.0000, -0.0333,  0.0682,  2.1550,  0.0350,  0.0000,  0.0000,
-        -0.0261, -0.0115,  0.0817,  0.0000, -0.3828, -0.4810,  0.0000,  0.0000,
-         0.0000, -0.0958,  0.0000,  0.0000,  0.0468, -0.0138,  0.0000, -0.8765,
-        -0.0289,  0.0000,  0.2553,  0.2184,  0.2221,  0.0000, -0.0841, -1.1892,
-         0.0000,  0.2135,  0.0000,  0.0000,  0.0000,  0.0000,  0.0863,  0.5701,
-         0.2237,  0.0000, -0.3738,  0.0000,  0.0000, -0.5094,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0202, -0.2495, -0.0688,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4010,  0.0000,  0.0000, -0.1673,  0.0000, -0.4056],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4394,  0.0000, -0.0333,  0.0682,  2.1550,  0.0350,  0.0000,  0.0000,
-        -0.0261, -0.0115,  0.0817,  0.0000, -0.3828, -0.4810,  0.0000,  0.0000,
-         0.0000, -0.0958,  0.0000,  0.0000,  0.0468, -0.0138,  0.0000, -0.8765,
-        -0.0289,  0.0000,  0.2553,  0.2184,  0.2221,  0.0000, -0.0841, -1.1892,
-         0.0000,  0.2135,  0.0000,  0.0000,  0.0000,  0.0000,  0.0863,  0.5701,
-         0.2237,  0.0000, -0.3738,  0.0000,  0.0000, -0.5094,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0202, -0.2495, -0.0688,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4010,  0.0000,  0.0000, -0.1673,  0.0000, -0.4056],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.1626e-01, -1.6561e-09, -5.6330e-03,  5.2609e-02,  2.1543e+00,
-         1.0344e-02,  1.0025e-18, -7.9158e-12, -2.8198e-02, -1.2641e-02,
-         1.3531e-01,  1.8022e-10, -3.6197e-01, -4.8003e-01, -1.7977e-15,
-        -1.3361e-12, -1.3143e-13, -1.1626e-01, -3.0399e-16,  3.6550e-12,
-         2.6060e-02, -7.2710e-02, -6.4556e-16, -8.9298e-01, -2.3719e-02,
-         5.2394e-12,  2.5809e-01,  2.1408e-01,  1.9543e-01,  9.3888e-16,
-        -5.8618e-02, -1.1883e+00, -2.4033e-19,  2.0416e-01,  1.1929e-13,
-         6.4364e-12,  3.9792e-15,  0.0000e+00,  8.9009e-02,  5.9202e-01,
-         2.1928e-01, -5.3149e-13, -3.5344e-01,  9.7933e-16,  2.3953e-13,
-        -4.5633e-01, -3.3797e-11,  1.0906e-04,  1.1369e-10, -6.2346e-17,
-        -2.0172e-07,  5.1264e-02, -2.3042e-01, -1.2935e-01, -2.2847e-08,
-         1.2975e-20, -5.0995e-15,  1.8340e-18,  4.0559e-01, -2.6486e-10,
-         4.8989e-18, -1.3574e-01,  4.1161e-06, -4.1011e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.4163,  0.0000, -0.0056,  0.0526,  2.1543,  0.0103,  0.0000,  0.0000,
-        -0.0282, -0.0126,  0.1353,  0.0000, -0.3620, -0.4800,  0.0000,  0.0000,
-         0.0000, -0.1163,  0.0000,  0.0000,  0.0261, -0.0727,  0.0000, -0.8930,
-        -0.0237,  0.0000,  0.2581,  0.2141,  0.1954,  0.0000, -0.0586, -1.1883,
-         0.0000,  0.2042,  0.0000,  0.0000,  0.0000,  0.0000,  0.0890,  0.5920,
-         0.2193,  0.0000, -0.3534,  0.0000,  0.0000, -0.4563,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0513, -0.2304, -0.1293,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4056,  0.0000,  0.0000, -0.1357,  0.0000, -0.4101],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.4163,  0.0000, -0.0056,  0.0526,  2.1543,  0.0103,  0.0000,  0.0000,
-        -0.0282, -0.0126,  0.1353,  0.0000, -0.3620, -0.4800,  0.0000,  0.0000,
-         0.0000, -0.1163,  0.0000,  0.0000,  0.0261, -0.0727,  0.0000, -0.8930,
-        -0.0237,  0.0000,  0.2581,  0.2141,  0.1954,  0.0000, -0.0586, -1.1883,
-         0.0000,  0.2042,  0.0000,  0.0000,  0.0000,  0.0000,  0.0890,  0.5920,
-         0.2193,  0.0000, -0.3534,  0.0000,  0.0000, -0.4563,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0513, -0.2304, -0.1293,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4056,  0.0000,  0.0000, -0.1357,  0.0000, -0.4101],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.9457e-01, -1.5033e-09,  1.2561e-02,  3.3995e-02,  2.1531e+00,
-        -2.3256e-02,  9.0999e-19, -7.1854e-12, -1.0905e-02, -1.6835e-02,
-         1.9659e-01,  1.6359e-10, -3.4158e-01, -4.7913e-01, -1.6319e-15,
-        -1.2129e-12, -1.1930e-13, -1.4442e-01, -2.7594e-16,  3.3177e-12,
-        -2.4849e-03, -1.2110e-01, -5.8600e-16, -9.0865e-01, -1.7995e-02,
-         4.7560e-12,  2.6374e-01,  1.9614e-01,  1.4722e-01,  8.5225e-16,
-        -3.6158e-02, -1.1871e+00, -2.1816e-19,  1.8914e-01,  1.0828e-13,
-         5.8426e-12,  3.6121e-15,  0.0000e+00,  1.0360e-01,  6.1012e-01,
-         2.1354e-01, -4.8245e-13, -3.3471e-01,  8.8897e-16,  2.1743e-13,
-        -3.8156e-01, -3.0679e-11,  9.8993e-05,  1.0320e-10, -5.6593e-17,
-        -1.8311e-07,  7.4471e-02, -2.0557e-01, -1.7670e-01, -2.0739e-08,
-         1.1778e-20, -4.6290e-15,  1.6647e-18,  4.0100e-01, -2.4043e-10,
-         4.4469e-18, -9.1986e-02,  3.7363e-06, -4.1120e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3946,  0.0000,  0.0126,  0.0340,  2.1531, -0.0233,  0.0000,  0.0000,
-        -0.0109, -0.0168,  0.1966,  0.0000, -0.3416, -0.4791,  0.0000,  0.0000,
-         0.0000, -0.1444,  0.0000,  0.0000, -0.0025, -0.1211,  0.0000, -0.9087,
-        -0.0180,  0.0000,  0.2637,  0.1961,  0.1472,  0.0000, -0.0362, -1.1871,
-         0.0000,  0.1891,  0.0000,  0.0000,  0.0000,  0.0000,  0.1036,  0.6101,
-         0.2135,  0.0000, -0.3347,  0.0000,  0.0000, -0.3816,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0745, -0.2056, -0.1767,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4010,  0.0000,  0.0000, -0.0920,  0.0000, -0.4112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3946,  0.0000,  0.0126,  0.0340,  2.1531, -0.0233,  0.0000,  0.0000,
-        -0.0109, -0.0168,  0.1966,  0.0000, -0.3416, -0.4791,  0.0000,  0.0000,
-         0.0000, -0.1444,  0.0000,  0.0000, -0.0025, -0.1211,  0.0000, -0.9087,
-        -0.0180,  0.0000,  0.2637,  0.1961,  0.1472,  0.0000, -0.0362, -1.1871,
-         0.0000,  0.1891,  0.0000,  0.0000,  0.0000,  0.0000,  0.1036,  0.6101,
-         0.2135,  0.0000, -0.3347,  0.0000,  0.0000, -0.3816,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0745, -0.2056, -0.1767,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4010,  0.0000,  0.0000, -0.0920,  0.0000, -0.4112],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7359e-01, -1.3650e-09,  3.3985e-02,  2.3786e-02,  2.1518e+00,
-        -6.9313e-02,  8.2628e-19, -6.5244e-12,  6.9137e-03, -2.1519e-02,
-         2.3598e-01,  1.4854e-10, -3.2082e-01, -4.8588e-01, -1.4818e-15,
-        -1.1013e-12, -1.0833e-13, -1.4051e-01, -2.5056e-16,  3.0125e-12,
-        -2.7577e-02, -1.4907e-01, -5.3209e-16, -9.1802e-01, -2.5576e-02,
-         4.3185e-12,  2.7453e-01,  1.8112e-01,  1.0130e-01,  7.7385e-16,
-        -1.4312e-02, -1.1850e+00, -1.9809e-19,  1.7486e-01,  9.8323e-14,
-         5.3051e-12,  3.2798e-15,  0.0000e+00,  1.3591e-01,  6.3475e-01,
-         1.8978e-01, -4.3807e-13, -3.2194e-01,  8.0719e-16,  1.9743e-13,
-        -3.2082e-01, -2.7857e-11,  8.9887e-05,  9.3711e-11, -5.1388e-17,
-        -1.6627e-07,  8.6494e-02, -1.7494e-01, -1.6283e-01, -1.8831e-08,
-         1.0695e-20, -4.2032e-15,  1.5116e-18,  3.9472e-01, -2.1831e-10,
-         4.0379e-18, -3.9465e-02,  3.3926e-06, -4.0800e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3736,  0.0000,  0.0340,  0.0238,  2.1518, -0.0693,  0.0000,  0.0000,
-         0.0069, -0.0215,  0.2360,  0.0000, -0.3208, -0.4859,  0.0000,  0.0000,
-         0.0000, -0.1405,  0.0000,  0.0000, -0.0276, -0.1491,  0.0000, -0.9180,
-        -0.0256,  0.0000,  0.2745,  0.1811,  0.1013,  0.0000, -0.0143, -1.1850,
-         0.0000,  0.1749,  0.0000,  0.0000,  0.0000,  0.0000,  0.1359,  0.6347,
-         0.1898,  0.0000, -0.3219,  0.0000,  0.0000, -0.3208,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0865, -0.1749, -0.1628,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3947,  0.0000,  0.0000, -0.0395,  0.0000, -0.4080],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3736,  0.0000,  0.0340,  0.0238,  2.1518, -0.0693,  0.0000,  0.0000,
-         0.0069, -0.0215,  0.2360,  0.0000, -0.3208, -0.4859,  0.0000,  0.0000,
-         0.0000, -0.1405,  0.0000,  0.0000, -0.0276, -0.1491,  0.0000, -0.9180,
-        -0.0256,  0.0000,  0.2745,  0.1811,  0.1013,  0.0000, -0.0143, -1.1850,
-         0.0000,  0.1749,  0.0000,  0.0000,  0.0000,  0.0000,  0.1359,  0.6347,
-         0.1898,  0.0000, -0.3219,  0.0000,  0.0000, -0.3208,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0865, -0.1749, -0.1628,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3947,  0.0000,  0.0000, -0.0395,  0.0000, -0.4080],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4993e-01, -1.2398e-09,  3.9284e-02,  1.5900e-02,  2.1497e+00,
-        -1.1369e-01,  7.5050e-19, -5.9261e-12,  3.3193e-02, -2.6449e-02,
-         2.6991e-01,  1.3492e-10, -3.1109e-01, -4.8825e-01, -1.3459e-15,
-        -1.0003e-12, -9.8394e-14, -1.3512e-01, -2.2758e-16,  2.7363e-12,
-        -4.3216e-02, -1.5421e-01, -4.8329e-16, -9.1541e-01, -2.5324e-02,
-         3.9225e-12,  2.7860e-01,  1.6416e-01,  5.7777e-02,  7.0288e-16,
-        -9.3030e-03, -1.1830e+00, -1.7992e-19,  1.6200e-01,  8.9305e-14,
-         4.8186e-12,  2.9790e-15,  0.0000e+00,  1.3628e-01,  6.4147e-01,
-         1.6949e-01, -3.9790e-13, -3.2301e-01,  7.3316e-16,  1.7932e-13,
-        -2.6976e-01, -2.5302e-11,  8.1643e-05,  8.5116e-11, -4.6675e-17,
-        -1.5102e-07,  9.0189e-02, -1.4838e-01, -1.6012e-01, -1.7104e-08,
-         9.7138e-21, -3.8177e-15,  1.3730e-18,  3.9614e-01, -1.9829e-10,
-         3.6676e-18,  3.5202e-03,  3.0814e-06, -4.0020e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3499,  0.0000,  0.0393,  0.0159,  2.1497, -0.1137,  0.0000,  0.0000,
-         0.0332, -0.0264,  0.2699,  0.0000, -0.3111, -0.4882,  0.0000,  0.0000,
-         0.0000, -0.1351,  0.0000,  0.0000, -0.0432, -0.1542,  0.0000, -0.9154,
-        -0.0253,  0.0000,  0.2786,  0.1642,  0.0578,  0.0000, -0.0093, -1.1830,
-         0.0000,  0.1620,  0.0000,  0.0000,  0.0000,  0.0000,  0.1363,  0.6415,
-         0.1695,  0.0000, -0.3230,  0.0000,  0.0000, -0.2698,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0902, -0.1484, -0.1601,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3961,  0.0000,  0.0000,  0.0035,  0.0000, -0.4002],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3499,  0.0000,  0.0393,  0.0159,  2.1497, -0.1137,  0.0000,  0.0000,
-         0.0332, -0.0264,  0.2699,  0.0000, -0.3111, -0.4882,  0.0000,  0.0000,
-         0.0000, -0.1351,  0.0000,  0.0000, -0.0432, -0.1542,  0.0000, -0.9154,
-        -0.0253,  0.0000,  0.2786,  0.1642,  0.0578,  0.0000, -0.0093, -1.1830,
-         0.0000,  0.1620,  0.0000,  0.0000,  0.0000,  0.0000,  0.1363,  0.6415,
-         0.1695,  0.0000, -0.3230,  0.0000,  0.0000, -0.2698,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0902, -0.1484, -0.1601,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3961,  0.0000,  0.0000,  0.0035,  0.0000, -0.4002],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2802e-01, -1.1265e-09,  4.1982e-02, -1.0714e-02,  2.1484e+00,
-        -1.6262e-01,  6.8188e-19, -5.3842e-12,  6.7824e-02, -3.8583e-02,
-         2.9242e-01,  1.2258e-10, -3.0277e-01, -4.9186e-01, -1.2228e-15,
-        -9.0883e-13, -8.9397e-14, -1.3027e-01, -2.0677e-16,  2.4861e-12,
-        -4.7100e-02, -1.5671e-01, -4.3910e-16, -9.1194e-01, -3.2335e-02,
-         3.5638e-12,  2.7844e-01,  1.4389e-01,  2.7577e-02,  6.3861e-16,
-        -1.6871e-02, -1.1807e+00, -1.6347e-19,  1.6246e-01,  8.1139e-14,
-         4.3780e-12,  2.7066e-15,  0.0000e+00,  1.3042e-01,  6.4096e-01,
-         1.5434e-01, -3.6151e-13, -3.2892e-01,  6.6612e-16,  1.6293e-13,
-        -2.4004e-01, -2.2989e-11,  7.4178e-05,  7.7334e-11, -4.2407e-17,
-        -1.3721e-07,  1.0037e-01, -1.0142e-01, -1.6728e-01, -1.5540e-08,
-         8.8256e-21, -3.4686e-15,  1.2474e-18,  4.0574e-01, -1.8016e-10,
-         3.3322e-18,  3.5225e-02,  2.7997e-06, -3.8769e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3280,  0.0000,  0.0420, -0.0107,  2.1484, -0.1626,  0.0000,  0.0000,
-         0.0678, -0.0386,  0.2924,  0.0000, -0.3028, -0.4919,  0.0000,  0.0000,
-         0.0000, -0.1303,  0.0000,  0.0000, -0.0471, -0.1567,  0.0000, -0.9119,
-        -0.0323,  0.0000,  0.2784,  0.1439,  0.0276,  0.0000, -0.0169, -1.1807,
-         0.0000,  0.1625,  0.0000,  0.0000,  0.0000,  0.0000,  0.1304,  0.6410,
-         0.1543,  0.0000, -0.3289,  0.0000,  0.0000, -0.2400,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1004, -0.1014, -0.1673,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4057,  0.0000,  0.0000,  0.0352,  0.0000, -0.3877],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3280,  0.0000,  0.0420, -0.0107,  2.1484, -0.1626,  0.0000,  0.0000,
-         0.0678, -0.0386,  0.2924,  0.0000, -0.3028, -0.4919,  0.0000,  0.0000,
-         0.0000, -0.1303,  0.0000,  0.0000, -0.0471, -0.1567,  0.0000, -0.9119,
-        -0.0323,  0.0000,  0.2784,  0.1439,  0.0276,  0.0000, -0.0169, -1.1807,
-         0.0000,  0.1625,  0.0000,  0.0000,  0.0000,  0.0000,  0.1304,  0.6410,
-         0.1543,  0.0000, -0.3289,  0.0000,  0.0000, -0.2400,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1004, -0.1014, -0.1673,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4057,  0.0000,  0.0000,  0.0352,  0.0000, -0.3877],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9778e-01, -1.0238e-09,  3.5561e-02, -4.6637e-02,  2.1467e+00,
-        -2.2620e-01,  6.1971e-19, -4.8933e-12,  8.7436e-02, -4.8246e-02,
-         3.1137e-01,  1.1141e-10, -3.0412e-01, -4.9416e-01, -1.1113e-15,
-        -8.2597e-13, -8.1247e-14, -1.0412e-01, -1.8792e-16,  2.2594e-12,
-        -3.9505e-02, -1.3160e-01, -3.9907e-16, -8.9415e-01, -2.0927e-02,
-         3.2389e-12,  2.6271e-01,  1.2991e-01,  4.2825e-03,  5.8039e-16,
-        -3.5290e-02, -1.1795e+00, -1.4857e-19,  1.5355e-01,  7.3742e-14,
-         3.9788e-12,  2.4599e-15,  0.0000e+00,  1.0704e-01,  6.4569e-01,
-         1.6329e-01, -3.2856e-13, -3.4696e-01,  6.0540e-16,  1.4807e-13,
-        -2.3996e-01, -2.0893e-11,  6.7415e-05,  7.0283e-11, -3.8541e-17,
-        -1.2470e-07,  1.2206e-01, -6.7998e-02, -1.3638e-01, -1.4123e-08,
-         8.0210e-21, -3.1524e-15,  1.1337e-18,  4.0806e-01, -1.6373e-10,
-         3.0284e-18,  5.0050e-02,  2.5444e-06, -3.8032e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2978,  0.0000,  0.0356, -0.0466,  2.1467, -0.2262,  0.0000,  0.0000,
-         0.0874, -0.0482,  0.3114,  0.0000, -0.3041, -0.4942,  0.0000,  0.0000,
-         0.0000, -0.1041,  0.0000,  0.0000, -0.0395, -0.1316,  0.0000, -0.8941,
-        -0.0209,  0.0000,  0.2627,  0.1299,  0.0043,  0.0000, -0.0353, -1.1795,
-         0.0000,  0.1535,  0.0000,  0.0000,  0.0000,  0.0000,  0.1070,  0.6457,
-         0.1633,  0.0000, -0.3470,  0.0000,  0.0000, -0.2400,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1221, -0.0680, -0.1364,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4081,  0.0000,  0.0000,  0.0501,  0.0000, -0.3803],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2978,  0.0000,  0.0356, -0.0466,  2.1467, -0.2262,  0.0000,  0.0000,
-         0.0874, -0.0482,  0.3114,  0.0000, -0.3041, -0.4942,  0.0000,  0.0000,
-         0.0000, -0.1041,  0.0000,  0.0000, -0.0395, -0.1316,  0.0000, -0.8941,
-        -0.0209,  0.0000,  0.2627,  0.1299,  0.0043,  0.0000, -0.0353, -1.1795,
-         0.0000,  0.1535,  0.0000,  0.0000,  0.0000,  0.0000,  0.1070,  0.6457,
-         0.1633,  0.0000, -0.3470,  0.0000,  0.0000, -0.2400,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1221, -0.0680, -0.1364,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4081,  0.0000,  0.0000,  0.0501,  0.0000, -0.3803],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6461e-01, -9.3072e-10,  3.4359e-02, -5.7816e-02,  2.1451e+00,
-        -2.8821e-01,  5.6338e-19, -4.4485e-12,  1.0462e-01, -5.7450e-02,
-         3.2412e-01,  1.0128e-10, -3.0550e-01, -4.9935e-01, -1.0103e-15,
-        -7.5089e-13, -7.3862e-14, -5.2230e-02, -1.7084e-16,  2.0540e-12,
-        -3.4066e-02, -1.0074e-01, -3.6279e-16, -8.8205e-01, -1.1467e-02,
-         2.9445e-12,  2.5590e-01,  1.0575e-01, -2.7657e-02,  5.2763e-16,
-        -7.1769e-02, -1.1788e+00, -1.3506e-19,  1.3755e-01,  6.7039e-14,
-         3.6172e-12,  2.2363e-15,  0.0000e+00,  9.8881e-02,  6.5189e-01,
-         1.6798e-01, -2.9869e-13, -3.6732e-01,  5.5036e-16,  1.3461e-13,
-        -2.4572e-01, -1.8994e-11,  6.1287e-05,  6.3894e-11, -3.5037e-17,
-        -1.1336e-07,  1.2151e-01, -3.0953e-02, -5.6835e-02, -1.2840e-08,
-         7.2919e-21, -2.8658e-15,  1.0307e-18,  4.1274e-01, -1.4885e-10,
-         2.7531e-18,  5.8276e-02,  2.3132e-06, -3.7071e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2646,  0.0000,  0.0344, -0.0578,  2.1451, -0.2882,  0.0000,  0.0000,
-         0.1046, -0.0574,  0.3241,  0.0000, -0.3055, -0.4993,  0.0000,  0.0000,
-         0.0000, -0.0522,  0.0000,  0.0000, -0.0341, -0.1007,  0.0000, -0.8821,
-        -0.0115,  0.0000,  0.2559,  0.1058, -0.0277,  0.0000, -0.0718, -1.1788,
-         0.0000,  0.1375,  0.0000,  0.0000,  0.0000,  0.0000,  0.0989,  0.6519,
-         0.1680,  0.0000, -0.3673,  0.0000,  0.0000, -0.2457,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1215, -0.0310, -0.0568,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4127,  0.0000,  0.0000,  0.0583,  0.0000, -0.3707],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2646,  0.0000,  0.0344, -0.0578,  2.1451, -0.2882,  0.0000,  0.0000,
-         0.1046, -0.0574,  0.3241,  0.0000, -0.3055, -0.4993,  0.0000,  0.0000,
-         0.0000, -0.0522,  0.0000,  0.0000, -0.0341, -0.1007,  0.0000, -0.8821,
-        -0.0115,  0.0000,  0.2559,  0.1058, -0.0277,  0.0000, -0.0718, -1.1788,
-         0.0000,  0.1375,  0.0000,  0.0000,  0.0000,  0.0000,  0.0989,  0.6519,
-         0.1680,  0.0000, -0.3673,  0.0000,  0.0000, -0.2457,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1215, -0.0310, -0.0568,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4127,  0.0000,  0.0000,  0.0583,  0.0000, -0.3707],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3627e-01, -8.4636e-10,  3.9656e-02, -4.8489e-02,  2.1446e+00,
-        -3.2983e-01,  5.1232e-19, -4.0453e-12,  1.0507e-01, -6.4026e-02,
-         3.1913e-01,  9.2100e-11, -2.9565e-01, -5.0587e-01, -9.1873e-16,
-        -6.8283e-13, -6.7167e-14, -1.4495e-02, -1.5535e-16,  1.8679e-12,
-        -2.6238e-02, -5.9370e-02, -3.2991e-16, -8.6904e-01,  2.5548e-02,
-         2.6776e-12,  2.6142e-01,  9.6673e-02, -4.9599e-02,  4.7981e-16,
-        -9.6917e-02, -1.1805e+00, -1.2282e-19,  1.1171e-01,  6.0963e-14,
-         3.2893e-12,  2.0336e-15,  0.0000e+00,  9.2156e-02,  6.5867e-01,
-         1.6305e-01, -2.7162e-13, -3.8061e-01,  5.0048e-16,  1.2241e-13,
-        -2.6370e-01, -1.7272e-11,  5.5732e-05,  5.8103e-11, -3.1862e-17,
-        -1.0309e-07,  1.2435e-01, -3.1969e-02,  5.8059e-02, -1.1676e-08,
-         6.6310e-21, -2.6061e-15,  9.3724e-19,  4.1830e-01, -1.3536e-10,
-         2.5036e-18,  6.9783e-02,  2.1035e-06, -3.7293e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2363,  0.0000,  0.0397, -0.0485,  2.1446, -0.3298,  0.0000,  0.0000,
-         0.1051, -0.0640,  0.3191,  0.0000, -0.2957, -0.5059,  0.0000,  0.0000,
-         0.0000, -0.0145,  0.0000,  0.0000, -0.0262, -0.0594,  0.0000, -0.8690,
-         0.0255,  0.0000,  0.2614,  0.0967, -0.0496,  0.0000, -0.0969, -1.1805,
-         0.0000,  0.1117,  0.0000,  0.0000,  0.0000,  0.0000,  0.0922,  0.6587,
-         0.1631,  0.0000, -0.3806,  0.0000,  0.0000, -0.2637,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1244, -0.0320,  0.0581,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4183,  0.0000,  0.0000,  0.0698,  0.0000, -0.3729],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2363,  0.0000,  0.0397, -0.0485,  2.1446, -0.3298,  0.0000,  0.0000,
-         0.1051, -0.0640,  0.3191,  0.0000, -0.2957, -0.5059,  0.0000,  0.0000,
-         0.0000, -0.0145,  0.0000,  0.0000, -0.0262, -0.0594,  0.0000, -0.8690,
-         0.0255,  0.0000,  0.2614,  0.0967, -0.0496,  0.0000, -0.0969, -1.1805,
-         0.0000,  0.1117,  0.0000,  0.0000,  0.0000,  0.0000,  0.0922,  0.6587,
-         0.1631,  0.0000, -0.3806,  0.0000,  0.0000, -0.2637,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1244, -0.0320,  0.0581,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4183,  0.0000,  0.0000,  0.0698,  0.0000, -0.3729],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1179e-01, -7.6987e-10,  4.2731e-02, -2.6399e-02,  2.1445e+00,
-        -3.6167e-01,  4.6602e-19, -3.6797e-12,  1.0061e-01, -4.0530e-02,
-         2.9783e-01,  8.3777e-11, -2.8906e-01, -4.9924e-01, -8.3570e-16,
-        -6.2112e-13, -6.1097e-14, -2.6944e-02, -1.4131e-16,  1.6990e-12,
-        -3.4883e-02, -3.6277e-03, -3.0010e-16, -8.6364e-01,  9.1054e-02,
-         2.4356e-12,  2.6050e-01,  7.6615e-02, -7.4913e-02,  4.3645e-16,
-        -1.2978e-01, -1.1846e+00, -1.1172e-19,  1.0069e-01,  5.5453e-14,
-         2.9920e-12,  1.8498e-15,  0.0000e+00,  8.1028e-02,  6.7074e-01,
-         1.6098e-01, -2.4707e-13, -3.8723e-01,  4.5525e-16,  1.1135e-13,
-        -2.8661e-01, -1.5711e-11,  5.0696e-05,  5.2852e-11, -2.8982e-17,
-        -9.3773e-08,  1.4992e-01, -5.8000e-02,  1.5812e-01, -1.0621e-08,
-         6.0317e-21, -2.3706e-15,  8.5254e-19,  4.1808e-01, -1.2312e-10,
-         2.2773e-18,  6.0177e-02,  1.9134e-06, -3.7393e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2118,  0.0000,  0.0427, -0.0264,  2.1445, -0.3617,  0.0000,  0.0000,
-         0.1006, -0.0405,  0.2978,  0.0000, -0.2891, -0.4992,  0.0000,  0.0000,
-         0.0000, -0.0269,  0.0000,  0.0000, -0.0349, -0.0036,  0.0000, -0.8636,
-         0.0911,  0.0000,  0.2605,  0.0766, -0.0749,  0.0000, -0.1298, -1.1846,
-         0.0000,  0.1007,  0.0000,  0.0000,  0.0000,  0.0000,  0.0810,  0.6707,
-         0.1610,  0.0000, -0.3872,  0.0000,  0.0000, -0.2866,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1499, -0.0580,  0.1581,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4181,  0.0000,  0.0000,  0.0602,  0.0000, -0.3739],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2118,  0.0000,  0.0427, -0.0264,  2.1445, -0.3617,  0.0000,  0.0000,
-         0.1006, -0.0405,  0.2978,  0.0000, -0.2891, -0.4992,  0.0000,  0.0000,
-         0.0000, -0.0269,  0.0000,  0.0000, -0.0349, -0.0036,  0.0000, -0.8636,
-         0.0911,  0.0000,  0.2605,  0.0766, -0.0749,  0.0000, -0.1298, -1.1846,
-         0.0000,  0.1007,  0.0000,  0.0000,  0.0000,  0.0000,  0.0810,  0.6707,
-         0.1610,  0.0000, -0.3872,  0.0000,  0.0000, -0.2866,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1499, -0.0580,  0.1581,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4181,  0.0000,  0.0000,  0.0602,  0.0000, -0.3739],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.9244e-01, -7.0049e-10,  5.1404e-02, -7.5829e-03,  2.1454e+00,
-        -3.8066e-01,  4.2402e-19, -3.3481e-12,  8.0606e-02, -7.8519e-03,
-         2.7277e-01,  7.6227e-11, -2.9122e-01, -4.8479e-01, -7.6039e-16,
-        -5.6515e-13, -5.5591e-14, -8.1403e-02, -1.2858e-16,  1.5459e-12,
-        -4.2187e-02,  5.4476e-02, -2.7305e-16, -8.6340e-01,  1.5457e-01,
-         2.2161e-12,  2.6065e-01,  5.7025e-02, -9.0545e-02,  3.9712e-16,
-        -1.6179e-01, -1.1868e+00, -1.0165e-19,  9.3628e-02,  5.0456e-14,
-         2.7224e-12,  1.6831e-15,  0.0000e+00,  9.3812e-02,  6.8901e-01,
-         1.5859e-01, -2.2481e-13, -3.8839e-01,  4.1422e-16,  1.0132e-13,
-        -3.0363e-01, -1.4295e-11,  4.6127e-05,  4.8089e-11, -2.6370e-17,
-        -8.5322e-08,  1.7372e-01, -8.5100e-02,  2.1206e-01, -9.6635e-09,
-         5.4881e-21, -2.1569e-15,  7.7571e-19,  4.0676e-01, -1.1203e-10,
-         2.0721e-18,  2.9324e-02,  1.7410e-06, -3.7920e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1924,  0.0000,  0.0514, -0.0076,  2.1454, -0.3807,  0.0000,  0.0000,
-         0.0806, -0.0079,  0.2728,  0.0000, -0.2912, -0.4848,  0.0000,  0.0000,
-         0.0000, -0.0814,  0.0000,  0.0000, -0.0422,  0.0545,  0.0000, -0.8634,
-         0.1546,  0.0000,  0.2606,  0.0570, -0.0905,  0.0000, -0.1618, -1.1868,
-         0.0000,  0.0936,  0.0000,  0.0000,  0.0000,  0.0000,  0.0938,  0.6890,
-         0.1586,  0.0000, -0.3884,  0.0000,  0.0000, -0.3036,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1737, -0.0851,  0.2121,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4068,  0.0000,  0.0000,  0.0293,  0.0000, -0.3792],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1924,  0.0000,  0.0514, -0.0076,  2.1454, -0.3807,  0.0000,  0.0000,
-         0.0806, -0.0079,  0.2728,  0.0000, -0.2912, -0.4848,  0.0000,  0.0000,
-         0.0000, -0.0814,  0.0000,  0.0000, -0.0422,  0.0545,  0.0000, -0.8634,
-         0.1546,  0.0000,  0.2606,  0.0570, -0.0905,  0.0000, -0.1618, -1.1868,
-         0.0000,  0.0936,  0.0000,  0.0000,  0.0000,  0.0000,  0.0938,  0.6890,
-         0.1586,  0.0000, -0.3884,  0.0000,  0.0000, -0.3036,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1737, -0.0851,  0.2121,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4068,  0.0000,  0.0000,  0.0293,  0.0000, -0.3792],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6704e-01, -6.3754e-10,  6.5814e-02,  6.8782e-03,  2.1463e+00,
-        -3.8971e-01,  3.8592e-19, -3.0472e-12,  4.6401e-02,  4.5575e-02,
-         2.5054e-01,  6.9377e-11, -2.9432e-01, -4.6655e-01, -6.9206e-16,
-        -5.1436e-13, -5.0595e-14, -1.5744e-01, -1.1702e-16,  1.4070e-12,
-        -4.7998e-02,  8.9611e-02, -2.4851e-16, -8.6931e-01,  1.9614e-01,
-         2.0170e-12,  2.6614e-01,  3.9340e-02, -1.0889e-01,  3.6143e-16,
-        -1.9513e-01, -1.1891e+00, -9.2519e-20,  9.8201e-02,  4.5922e-14,
-         2.4778e-12,  1.5318e-15,  0.0000e+00,  1.2043e-01,  7.1315e-01,
-         1.5618e-01, -2.0460e-13, -3.9147e-01,  3.7700e-16,  9.2210e-14,
-        -3.0804e-01, -1.3011e-11,  4.1982e-05,  4.3768e-11, -2.4001e-17,
-        -7.7655e-08,  1.9388e-01, -9.1519e-02,  2.3886e-01, -8.7951e-09,
-         4.9949e-21, -1.9631e-15,  7.0600e-19,  3.9639e-01, -1.0196e-10,
-         1.8859e-18, -2.8760e-03,  1.5845e-06, -3.7148e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1670,  0.0000,  0.0658,  0.0069,  2.1463, -0.3897,  0.0000,  0.0000,
-         0.0464,  0.0456,  0.2505,  0.0000, -0.2943, -0.4665,  0.0000,  0.0000,
-         0.0000, -0.1574,  0.0000,  0.0000, -0.0480,  0.0896,  0.0000, -0.8693,
-         0.1961,  0.0000,  0.2661,  0.0393, -0.1089,  0.0000, -0.1951, -1.1891,
-         0.0000,  0.0982,  0.0000,  0.0000,  0.0000,  0.0000,  0.1204,  0.7132,
-         0.1562,  0.0000, -0.3915,  0.0000,  0.0000, -0.3080,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1939, -0.0915,  0.2389,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3964,  0.0000,  0.0000, -0.0029,  0.0000, -0.3715],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1670,  0.0000,  0.0658,  0.0069,  2.1463, -0.3897,  0.0000,  0.0000,
-         0.0464,  0.0456,  0.2505,  0.0000, -0.2943, -0.4665,  0.0000,  0.0000,
-         0.0000, -0.1574,  0.0000,  0.0000, -0.0480,  0.0896,  0.0000, -0.8693,
-         0.1961,  0.0000,  0.2661,  0.0393, -0.1089,  0.0000, -0.1951, -1.1891,
-         0.0000,  0.0982,  0.0000,  0.0000,  0.0000,  0.0000,  0.1204,  0.7132,
-         0.1562,  0.0000, -0.3915,  0.0000,  0.0000, -0.3080,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1939, -0.0915,  0.2389,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3964,  0.0000,  0.0000, -0.0029,  0.0000, -0.3715],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.4078e-01, -5.8041e-10,  8.7758e-02,  1.2731e-02,  2.1467e+00,
-        -3.8858e-01,  3.5133e-19, -2.7742e-12,  8.9883e-03,  1.0028e-01,
-         2.1505e-01,  6.3160e-11, -2.9750e-01, -4.4950e-01, -6.3004e-16,
-        -4.6827e-13, -4.6061e-14, -2.3712e-01, -1.0654e-16,  1.2809e-12,
-        -5.5962e-02,  1.0640e-01, -2.2624e-16, -8.6686e-01,  2.2099e-01,
-         1.8362e-12,  2.8258e-01,  2.6917e-02, -1.1922e-01,  3.2904e-16,
-        -2.1498e-01, -1.1913e+00, -8.4228e-20,  1.0322e-01,  4.1807e-14,
-         2.2557e-12,  1.3946e-15,  0.0000e+00,  1.4477e-01,  7.3679e-01,
-         1.3892e-01, -1.8627e-13, -3.9335e-01,  3.4322e-16,  8.3947e-14,
-        -3.0817e-01, -1.1845e-11,  3.8220e-05,  3.9846e-11, -2.1850e-17,
-        -7.0696e-08,  1.8989e-01, -9.0739e-02,  2.3116e-01, -8.0069e-09,
-         4.5473e-21, -1.7872e-15,  6.4274e-19,  3.8628e-01, -9.2825e-11,
-         1.7169e-18, -3.4931e-02,  1.4425e-06, -3.5624e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1408,  0.0000,  0.0878,  0.0127,  2.1467, -0.3886,  0.0000,  0.0000,
-         0.0090,  0.1003,  0.2150,  0.0000, -0.2975, -0.4495,  0.0000,  0.0000,
-         0.0000, -0.2371,  0.0000,  0.0000, -0.0560,  0.1064,  0.0000, -0.8669,
-         0.2210,  0.0000,  0.2826,  0.0269, -0.1192,  0.0000, -0.2150, -1.1913,
-         0.0000,  0.1032,  0.0000,  0.0000,  0.0000,  0.0000,  0.1448,  0.7368,
-         0.1389,  0.0000, -0.3934,  0.0000,  0.0000, -0.3082,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1899, -0.0907,  0.2312,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3863,  0.0000,  0.0000, -0.0349,  0.0000, -0.3562],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1408,  0.0000,  0.0878,  0.0127,  2.1467, -0.3886,  0.0000,  0.0000,
-         0.0090,  0.1003,  0.2150,  0.0000, -0.2975, -0.4495,  0.0000,  0.0000,
-         0.0000, -0.2371,  0.0000,  0.0000, -0.0560,  0.1064,  0.0000, -0.8669,
-         0.2210,  0.0000,  0.2826,  0.0269, -0.1192,  0.0000, -0.2150, -1.1913,
-         0.0000,  0.1032,  0.0000,  0.0000,  0.0000,  0.0000,  0.1448,  0.7368,
-         0.1389,  0.0000, -0.3934,  0.0000,  0.0000, -0.3082,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1899, -0.0907,  0.2312,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3863,  0.0000,  0.0000, -0.0349,  0.0000, -0.3562],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.7483e-02, -5.2854e-10,  1.0256e-01, -2.7314e-03,  2.1470e+00,
-        -3.8713e-01,  3.1994e-19, -2.5263e-12, -2.3947e-02,  1.3782e-01,
-         1.9597e-01,  5.7516e-11, -3.0850e-01, -4.3681e-01, -5.7374e-16,
-        -4.2642e-13, -4.1945e-14, -3.1641e-01, -9.7016e-17,  1.1665e-12,
-        -7.1537e-02,  9.1505e-02, -2.0603e-16, -8.6555e-01,  2.2780e-01,
-         1.6721e-12,  3.0204e-01,  3.0161e-03, -1.3962e-01,  2.9964e-16,
-        -2.3189e-01, -1.1920e+00, -7.6701e-20,  1.0858e-01,  3.8071e-14,
-         2.0541e-12,  1.2699e-15,  0.0000e+00,  1.5873e-01,  7.5553e-01,
-         9.9521e-02, -1.6962e-13, -3.9809e-01,  3.1254e-16,  7.6445e-14,
-        -2.9705e-01, -1.0786e-11,  3.4804e-05,  3.6285e-11, -1.9897e-17,
-        -6.4378e-08,  1.5534e-01, -7.6028e-02,  1.7240e-01, -7.2914e-09,
-         4.1410e-21, -1.6275e-15,  5.8530e-19,  3.8610e-01, -8.4529e-11,
-         1.5635e-18, -6.6524e-02,  1.3136e-06, -3.3394e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0975,  0.0000,  0.1026, -0.0027,  2.1470, -0.3871,  0.0000,  0.0000,
-        -0.0239,  0.1378,  0.1960,  0.0000, -0.3085, -0.4368,  0.0000,  0.0000,
-         0.0000, -0.3164,  0.0000,  0.0000, -0.0715,  0.0915,  0.0000, -0.8656,
-         0.2278,  0.0000,  0.3020,  0.0030, -0.1396,  0.0000, -0.2319, -1.1920,
-         0.0000,  0.1086,  0.0000,  0.0000,  0.0000,  0.0000,  0.1587,  0.7555,
-         0.0995,  0.0000, -0.3981,  0.0000,  0.0000, -0.2971,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1553, -0.0760,  0.1724,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3861,  0.0000,  0.0000, -0.0665,  0.0000, -0.3339],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0975,  0.0000,  0.1026, -0.0027,  2.1470, -0.3871,  0.0000,  0.0000,
-        -0.0239,  0.1378,  0.1960,  0.0000, -0.3085, -0.4368,  0.0000,  0.0000,
-         0.0000, -0.3164,  0.0000,  0.0000, -0.0715,  0.0915,  0.0000, -0.8656,
-         0.2278,  0.0000,  0.3020,  0.0030, -0.1396,  0.0000, -0.2319, -1.1920,
-         0.0000,  0.1086,  0.0000,  0.0000,  0.0000,  0.0000,  0.1587,  0.7555,
-         0.0995,  0.0000, -0.3981,  0.0000,  0.0000, -0.2971,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1553, -0.0760,  0.1724,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3861,  0.0000,  0.0000, -0.0665,  0.0000, -0.3339],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 5.2685e-02, -4.8144e-10,  1.1574e-01, -9.1928e-03,  2.1490e+00,
-        -3.7819e-01,  2.9142e-19, -2.3011e-12, -5.4455e-02,  1.6112e-01,
-         1.5590e-01,  5.2390e-11, -3.2171e-01, -4.2674e-01, -5.2260e-16,
-        -3.8842e-13, -3.8207e-14, -3.7707e-01, -8.8370e-17,  1.0625e-12,
-        -6.9257e-02,  6.7403e-02, -1.8766e-16, -8.6868e-01,  2.1879e-01,
-         1.5231e-12,  3.2590e-01, -9.2153e-03, -1.5168e-01,  2.7293e-16,
-        -2.4828e-01, -1.1916e+00, -6.9865e-20,  1.2722e-01,  3.4678e-14,
-         1.8711e-12,  1.1568e-15,  0.0000e+00,  1.5851e-01,  7.6587e-01,
-         6.4815e-02, -1.5451e-13, -4.0097e-01,  2.8469e-16,  6.9632e-14,
-        -2.9165e-01, -9.8249e-12,  3.1702e-05,  3.3051e-11, -1.8124e-17,
-        -5.8641e-08,  1.1995e-01, -6.7949e-02,  1.1452e-01, -6.6416e-09,
-         3.7719e-21, -1.4824e-15,  5.3313e-19,  3.9978e-01, -7.6996e-11,
-         1.4241e-18, -9.2934e-02,  1.1965e-06, -3.0418e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0527,  0.0000,  0.1157, -0.0092,  2.1490, -0.3782,  0.0000,  0.0000,
-        -0.0545,  0.1611,  0.1559,  0.0000, -0.3217, -0.4267,  0.0000,  0.0000,
-         0.0000, -0.3771,  0.0000,  0.0000, -0.0693,  0.0674,  0.0000, -0.8687,
-         0.2188,  0.0000,  0.3259, -0.0092, -0.1517,  0.0000, -0.2483, -1.1916,
-         0.0000,  0.1272,  0.0000,  0.0000,  0.0000,  0.0000,  0.1585,  0.7659,
-         0.0648,  0.0000, -0.4010,  0.0000,  0.0000, -0.2917,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1200, -0.0679,  0.1145,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3998,  0.0000,  0.0000, -0.0929,  0.0000, -0.3042],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0527,  0.0000,  0.1157, -0.0092,  2.1490, -0.3782,  0.0000,  0.0000,
-        -0.0545,  0.1611,  0.1559,  0.0000, -0.3217, -0.4267,  0.0000,  0.0000,
-         0.0000, -0.3771,  0.0000,  0.0000, -0.0693,  0.0674,  0.0000, -0.8687,
-         0.2188,  0.0000,  0.3259, -0.0092, -0.1517,  0.0000, -0.2483, -1.1916,
-         0.0000,  0.1272,  0.0000,  0.0000,  0.0000,  0.0000,  0.1585,  0.7659,
-         0.0648,  0.0000, -0.4010,  0.0000,  0.0000, -0.2917,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1200, -0.0679,  0.1145,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.3998,  0.0000,  0.0000, -0.0929,  0.0000, -0.3042],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 9.4764e-03, -4.3865e-10,  1.3154e-01, -5.4274e-03,  2.1502e+00,
-        -3.6951e-01,  2.6552e-19, -2.0966e-12, -9.1067e-02,  1.6959e-01,
-         1.0208e-01,  4.7733e-11, -3.3124e-01, -4.1762e-01, -4.7615e-16,
-        -3.5389e-13, -3.4811e-14, -4.2202e-01, -8.0515e-17,  9.6806e-13,
-        -6.3539e-02,  2.4248e-02, -1.7098e-16, -8.7085e-01,  1.9370e-01,
-         1.3877e-12,  3.5402e-01, -2.4667e-02, -1.5775e-01,  2.4867e-16,
-        -2.5737e-01, -1.1900e+00, -6.3655e-20,  1.3929e-01,  3.1595e-14,
-         1.7048e-12,  1.0540e-15,  0.0000e+00,  1.5117e-01,  7.6563e-01,
-         3.0670e-02, -1.4077e-13, -3.9329e-01,  2.5939e-16,  6.3443e-14,
-        -2.8141e-01, -8.9517e-12,  2.8885e-05,  3.0113e-11, -1.6513e-17,
-        -5.3429e-08,  8.7375e-02, -5.2453e-02,  6.2866e-02, -6.0513e-09,
-         3.4367e-21, -1.3507e-15,  4.8575e-19,  4.1630e-01, -7.0152e-11,
-         1.2975e-18, -1.0201e-01,  1.0902e-06, -2.7193e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0095,  0.0000,  0.1315, -0.0054,  2.1502, -0.3695,  0.0000,  0.0000,
-        -0.0911,  0.1696,  0.1021,  0.0000, -0.3312, -0.4176,  0.0000,  0.0000,
-         0.0000, -0.4220,  0.0000,  0.0000, -0.0635,  0.0242,  0.0000, -0.8709,
-         0.1937,  0.0000,  0.3540, -0.0247, -0.1577,  0.0000, -0.2574, -1.1900,
-         0.0000,  0.1393,  0.0000,  0.0000,  0.0000,  0.0000,  0.1512,  0.7656,
-         0.0307,  0.0000, -0.3933,  0.0000,  0.0000, -0.2814,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0874, -0.0525,  0.0629,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4163,  0.0000,  0.0000, -0.1020,  0.0000, -0.2719],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0095,  0.0000,  0.1315, -0.0054,  2.1502, -0.3695,  0.0000,  0.0000,
-        -0.0911,  0.1696,  0.1021,  0.0000, -0.3312, -0.4176,  0.0000,  0.0000,
-         0.0000, -0.4220,  0.0000,  0.0000, -0.0635,  0.0242,  0.0000, -0.8709,
-         0.1937,  0.0000,  0.3540, -0.0247, -0.1577,  0.0000, -0.2574, -1.1900,
-         0.0000,  0.1393,  0.0000,  0.0000,  0.0000,  0.0000,  0.1512,  0.7656,
-         0.0307,  0.0000, -0.3933,  0.0000,  0.0000, -0.2814,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0874, -0.0525,  0.0629,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4163,  0.0000,  0.0000, -0.1020,  0.0000, -0.2719],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.4154e-02, -3.9976e-10,  1.5097e-01,  3.7738e-03,  2.1515e+00,
-        -3.5779e-01,  2.4198e-19, -1.9107e-12, -1.1135e-01,  1.7970e-01,
-         6.1167e-02,  4.3502e-11, -3.3361e-01, -4.1352e-01, -4.3395e-16,
-        -3.2252e-13, -3.1725e-14, -4.4052e-01, -7.3378e-17,  8.8225e-13,
-        -5.3080e-02, -1.0377e-02, -1.5583e-16, -8.7003e-01,  1.6351e-01,
-         1.2647e-12,  3.8086e-01, -3.5121e-02, -1.7059e-01,  2.2663e-16,
-        -2.6254e-01, -1.1882e+00, -5.8013e-20,  1.5469e-01,  2.8795e-14,
-         1.5537e-12,  9.6052e-16,  0.0000e+00,  1.4929e-01,  7.6028e-01,
-         2.0515e-03, -1.2829e-13, -3.8470e-01,  2.3639e-16,  5.7819e-14,
-        -2.7684e-01, -8.1582e-12,  2.6324e-05,  2.7444e-11, -1.5049e-17,
-        -4.8693e-08,  5.1521e-02, -2.5694e-02,  1.6081e-02, -5.5148e-09,
-         3.1320e-21, -1.2309e-15,  4.4269e-19,  4.3138e-01, -6.3934e-11,
-         1.1825e-18, -1.0056e-01,  9.9355e-07, -2.3883e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-2.4154e-02,  0.0000e+00,  1.5097e-01,  3.7738e-03,  2.1515e+00,
-        -3.5779e-01,  0.0000e+00,  0.0000e+00, -1.1135e-01,  1.7970e-01,
-         6.1167e-02,  0.0000e+00, -3.3361e-01, -4.1352e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.4052e-01,  0.0000e+00,  0.0000e+00,
-        -5.3080e-02, -1.0377e-02,  0.0000e+00, -8.7003e-01,  1.6351e-01,
-         0.0000e+00,  3.8086e-01, -3.5121e-02, -1.7059e-01,  0.0000e+00,
-        -2.6254e-01, -1.1882e+00,  0.0000e+00,  1.5469e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.4929e-01,  7.6028e-01,
-         2.0515e-03,  0.0000e+00, -3.8470e-01,  0.0000e+00,  0.0000e+00,
-        -2.7684e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.1521e-02, -2.5694e-02,  1.6081e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3138e-01,  0.0000e+00,
-         0.0000e+00, -1.0056e-01,  0.0000e+00, -2.3883e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([-2.4154e-02,  0.0000e+00,  1.5097e-01,  3.7738e-03,  2.1515e+00,
-        -3.5779e-01,  0.0000e+00,  0.0000e+00, -1.1135e-01,  1.7970e-01,
-         6.1167e-02,  0.0000e+00, -3.3361e-01, -4.1352e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.4052e-01,  0.0000e+00,  0.0000e+00,
-        -5.3080e-02, -1.0377e-02,  0.0000e+00, -8.7003e-01,  1.6351e-01,
-         0.0000e+00,  3.8086e-01, -3.5121e-02, -1.7059e-01,  0.0000e+00,
-        -2.6254e-01, -1.1882e+00,  0.0000e+00,  1.5469e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.4929e-01,  7.6028e-01,
-         2.0515e-03,  0.0000e+00, -3.8470e-01,  0.0000e+00,  0.0000e+00,
-        -2.7684e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.1521e-02, -2.5694e-02,  1.6081e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3138e-01,  0.0000e+00,
-         0.0000e+00, -1.0056e-01,  0.0000e+00, -2.3883e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([-3.6557e-02, -3.6442e-10,  1.6995e-01,  1.9801e-02,  2.1527e+00,
-        -3.6694e-01,  2.2059e-19, -1.7418e-12, -1.2546e-01,  1.7737e-01,
-         8.2475e-03,  3.9656e-11, -3.1799e-01, -4.1102e-01, -3.9558e-16,
-        -2.9401e-13, -2.8920e-14, -4.3985e-01, -6.6890e-17,  8.0425e-13,
-        -3.0839e-02, -5.0506e-02, -1.4205e-16, -8.6265e-01,  1.0964e-01,
-         1.1529e-12,  4.0652e-01, -3.2921e-02, -1.7634e-01,  2.0659e-16,
-        -2.7274e-01, -1.1827e+00, -5.2883e-20,  1.5588e-01,  2.6249e-14,
-         1.4163e-12,  8.7560e-16,  0.0000e+00,  1.5386e-01,  7.5340e-01,
-        -1.6659e-03, -1.1695e-13, -3.7068e-01,  2.1549e-16,  5.2707e-14,
-        -2.7665e-01, -7.4369e-12,  2.3997e-05,  2.5018e-11, -1.3719e-17,
-        -4.4388e-08,  2.7079e-02, -1.0155e-02, -1.4541e-02, -5.0273e-09,
-         2.8551e-21, -1.1221e-15,  4.0355e-19,  4.4536e-01, -5.8281e-11,
-         1.0780e-18, -8.2575e-02,  9.0571e-07, -2.2572e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-3.6557e-02,  0.0000e+00,  1.6995e-01,  1.9801e-02,  2.1527e+00,
-        -3.6694e-01,  0.0000e+00,  0.0000e+00, -1.2546e-01,  1.7737e-01,
-         8.2475e-03,  0.0000e+00, -3.1799e-01, -4.1102e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.3985e-01,  0.0000e+00,  0.0000e+00,
-        -3.0839e-02, -5.0506e-02,  0.0000e+00, -8.6265e-01,  1.0964e-01,
-         0.0000e+00,  4.0652e-01, -3.2921e-02, -1.7634e-01,  0.0000e+00,
-        -2.7274e-01, -1.1827e+00,  0.0000e+00,  1.5588e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5386e-01,  7.5340e-01,
-        -1.6659e-03,  0.0000e+00, -3.7068e-01,  0.0000e+00,  0.0000e+00,
-        -2.7665e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.7079e-02, -1.0155e-02, -1.4541e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.4536e-01,  0.0000e+00,
-         0.0000e+00, -8.2575e-02,  0.0000e+00, -2.2572e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([-3.6557e-02,  0.0000e+00,  1.6995e-01,  1.9801e-02,  2.1527e+00,
-        -3.6694e-01,  0.0000e+00,  0.0000e+00, -1.2546e-01,  1.7737e-01,
-         8.2475e-03,  0.0000e+00, -3.1799e-01, -4.1102e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.3985e-01,  0.0000e+00,  0.0000e+00,
-        -3.0839e-02, -5.0506e-02,  0.0000e+00, -8.6265e-01,  1.0964e-01,
-         0.0000e+00,  4.0652e-01, -3.2921e-02, -1.7634e-01,  0.0000e+00,
-        -2.7274e-01, -1.1827e+00,  0.0000e+00,  1.5588e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5386e-01,  7.5340e-01,
-        -1.6659e-03,  0.0000e+00, -3.7068e-01,  0.0000e+00,  0.0000e+00,
-        -2.7665e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  2.7079e-02, -1.0155e-02, -1.4541e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.4536e-01,  0.0000e+00,
-         0.0000e+00, -8.2575e-02,  0.0000e+00, -2.2572e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([-4.5450e-02, -3.3228e-10,  1.8053e-01,  3.9076e-02,  2.1534e+00,
-        -3.7920e-01,  2.0114e-19, -1.5882e-12, -1.2794e-01,  1.6626e-01,
-        -3.6370e-02,  3.6159e-11, -3.0345e-01, -4.0418e-01, -3.6070e-16,
-        -2.6808e-13, -2.6370e-14, -4.2671e-01, -6.0992e-17,  7.3332e-13,
-        -3.7073e-03, -7.7802e-02, -1.2952e-16, -8.5346e-01,  5.9516e-02,
-         1.0512e-12,  4.1956e-01, -2.4087e-02, -1.7476e-01,  1.8837e-16,
-        -2.9059e-01, -1.1764e+00, -4.8220e-20,  1.5453e-01,  2.3934e-14,
-         1.2914e-12,  7.9839e-16,  0.0000e+00,  1.6487e-01,  7.4646e-01,
-         5.0171e-03, -1.0664e-13, -3.6133e-01,  1.9649e-16,  4.8059e-14,
-        -2.8539e-01, -6.7811e-12,  2.1881e-05,  2.2811e-11, -1.2509e-17,
-        -4.0473e-08,  4.0412e-04,  9.6613e-04, -2.6334e-02, -4.5839e-09,
-         2.6033e-21, -1.0232e-15,  3.6796e-19,  4.6256e-01, -5.3142e-11,
-         9.8291e-19, -6.0307e-02,  8.2584e-07, -2.1923e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-4.5450e-02,  0.0000e+00,  1.8053e-01,  3.9076e-02,  2.1534e+00,
-        -3.7920e-01,  0.0000e+00,  0.0000e+00, -1.2794e-01,  1.6626e-01,
-        -3.6370e-02,  0.0000e+00, -3.0345e-01, -4.0418e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.2671e-01,  0.0000e+00,  0.0000e+00,
-        -3.7073e-03, -7.7802e-02,  0.0000e+00, -8.5346e-01,  5.9516e-02,
-         0.0000e+00,  4.1956e-01, -2.4087e-02, -1.7476e-01,  0.0000e+00,
-        -2.9059e-01, -1.1764e+00,  0.0000e+00,  1.5453e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6487e-01,  7.4646e-01,
-         5.0171e-03,  0.0000e+00, -3.6133e-01,  0.0000e+00,  0.0000e+00,
-        -2.8539e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  4.0412e-04,  9.6613e-04, -2.6334e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6256e-01,  0.0000e+00,
-         0.0000e+00, -6.0307e-02,  0.0000e+00, -2.1923e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([-4.5450e-02,  0.0000e+00,  1.8053e-01,  3.9076e-02,  2.1534e+00,
-        -3.7920e-01,  0.0000e+00,  0.0000e+00, -1.2794e-01,  1.6626e-01,
-        -3.6370e-02,  0.0000e+00, -3.0345e-01, -4.0418e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -4.2671e-01,  0.0000e+00,  0.0000e+00,
-        -3.7073e-03, -7.7802e-02,  0.0000e+00, -8.5346e-01,  5.9516e-02,
-         0.0000e+00,  4.1956e-01, -2.4087e-02, -1.7476e-01,  0.0000e+00,
-        -2.9059e-01, -1.1764e+00,  0.0000e+00,  1.5453e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6487e-01,  7.4646e-01,
-         5.0171e-03,  0.0000e+00, -3.6133e-01,  0.0000e+00,  0.0000e+00,
-        -2.8539e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  4.0412e-04,  9.6613e-04, -2.6334e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6256e-01,  0.0000e+00,
-         0.0000e+00, -6.0307e-02,  0.0000e+00, -2.1923e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([-5.0422e-02, -3.0306e-10,  2.0123e-01,  6.9000e-02,  2.1540e+00,
-        -3.8677e-01,  1.8344e-19, -1.4485e-12, -1.2315e-01,  1.6198e-01,
-        -7.4725e-02,  3.2978e-11, -2.7860e-01, -3.9351e-01, -3.2897e-16,
-        -2.4450e-13, -2.4050e-14, -4.1267e-01, -5.5627e-17,  6.6882e-13,
-         2.9638e-02, -1.0240e-01, -1.1813e-16, -8.4261e-01,  6.1633e-03,
-         9.5876e-13,  4.3187e-01, -7.6450e-03, -1.6694e-01,  1.7181e-16,
-        -3.0906e-01, -1.1715e+00, -4.3979e-20,  1.4965e-01,  2.1829e-14,
-         1.1778e-12,  7.2816e-16,  0.0000e+00,  1.7446e-01,  7.3607e-01,
-         1.9023e-02, -9.7258e-14, -3.4858e-01,  1.7921e-16,  4.3832e-14,
-        -2.9203e-01, -6.1846e-12,  1.9956e-05,  2.0805e-11, -1.1409e-17,
-        -3.6913e-08, -5.5604e-03,  6.2544e-03, -1.6613e-02, -4.1807e-09,
-         2.3743e-21, -9.3316e-16,  3.3560e-19,  4.7682e-01, -4.8467e-11,
-         8.9646e-19, -3.5430e-02,  7.5320e-07, -2.1709e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0504,  0.0000,  0.2012,  0.0690,  2.1540, -0.3868,  0.0000,  0.0000,
-        -0.1232,  0.1620, -0.0747,  0.0000, -0.2786, -0.3935,  0.0000,  0.0000,
-         0.0000, -0.4127,  0.0000,  0.0000,  0.0296, -0.1024,  0.0000, -0.8426,
-         0.0062,  0.0000,  0.4319, -0.0076, -0.1669,  0.0000, -0.3091, -1.1715,
-         0.0000,  0.1497,  0.0000,  0.0000,  0.0000,  0.0000,  0.1745,  0.7361,
-         0.0190,  0.0000, -0.3486,  0.0000,  0.0000, -0.2920,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0056,  0.0063, -0.0166,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4768,  0.0000,  0.0000, -0.0354,  0.0000, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0504,  0.0000,  0.2012,  0.0690,  2.1540, -0.3868,  0.0000,  0.0000,
-        -0.1232,  0.1620, -0.0747,  0.0000, -0.2786, -0.3935,  0.0000,  0.0000,
-         0.0000, -0.4127,  0.0000,  0.0000,  0.0296, -0.1024,  0.0000, -0.8426,
-         0.0062,  0.0000,  0.4319, -0.0076, -0.1669,  0.0000, -0.3091, -1.1715,
-         0.0000,  0.1497,  0.0000,  0.0000,  0.0000,  0.0000,  0.1745,  0.7361,
-         0.0190,  0.0000, -0.3486,  0.0000,  0.0000, -0.2920,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0056,  0.0063, -0.0166,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4768,  0.0000,  0.0000, -0.0354,  0.0000, -0.2171],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-5.2054e-02, -2.7647e-10,  2.0843e-01,  9.5849e-02,  2.1529e+00,
-        -4.0205e-01,  1.6735e-19, -1.3214e-12, -9.2790e-02,  1.4358e-01,
-        -9.1880e-02,  3.0085e-11, -2.5558e-01, -3.8210e-01, -3.0011e-16,
-        -2.2305e-13, -2.1940e-14, -3.8091e-01, -5.0746e-17,  6.1014e-13,
-         7.7906e-02, -1.0692e-01, -1.0777e-16, -8.2601e-01, -3.9649e-02,
-         8.7464e-13,  4.3742e-01,  2.6251e-02, -1.5530e-01,  1.5673e-16,
-        -3.4090e-01, -1.1654e+00, -4.0120e-20,  1.4218e-01,  1.9914e-14,
-         1.0745e-12,  6.6427e-16,  0.0000e+00,  1.8967e-01,  7.2518e-01,
-         5.2489e-02, -8.8725e-14, -3.4090e-01,  1.6348e-16,  3.9986e-14,
-        -3.0935e-01, -5.6420e-12,  1.8205e-05,  1.8980e-11, -1.0408e-17,
-        -3.3675e-08,  2.8685e-03,  2.3738e-02,  3.0304e-03, -3.8139e-09,
-         2.1660e-21, -8.5128e-16,  3.0615e-19,  4.9381e-01, -4.4215e-11,
-         8.1780e-19, -1.1091e-02,  6.8711e-07, -2.1446e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0521,  0.0000,  0.2084,  0.0958,  2.1529, -0.4021,  0.0000,  0.0000,
-        -0.0928,  0.1436, -0.0919,  0.0000, -0.2556, -0.3821,  0.0000,  0.0000,
-         0.0000, -0.3809,  0.0000,  0.0000,  0.0779, -0.1069,  0.0000, -0.8260,
-        -0.0396,  0.0000,  0.4374,  0.0263, -0.1553,  0.0000, -0.3409, -1.1654,
-         0.0000,  0.1422,  0.0000,  0.0000,  0.0000,  0.0000,  0.1897,  0.7252,
-         0.0525,  0.0000, -0.3409,  0.0000,  0.0000, -0.3093,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0029,  0.0237,  0.0030,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4938,  0.0000,  0.0000, -0.0111,  0.0000, -0.2145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0521,  0.0000,  0.2084,  0.0958,  2.1529, -0.4021,  0.0000,  0.0000,
-        -0.0928,  0.1436, -0.0919,  0.0000, -0.2556, -0.3821,  0.0000,  0.0000,
-         0.0000, -0.3809,  0.0000,  0.0000,  0.0779, -0.1069,  0.0000, -0.8260,
-        -0.0396,  0.0000,  0.4374,  0.0263, -0.1553,  0.0000, -0.3409, -1.1654,
-         0.0000,  0.1422,  0.0000,  0.0000,  0.0000,  0.0000,  0.1897,  0.7252,
-         0.0525,  0.0000, -0.3409,  0.0000,  0.0000, -0.3093,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0029,  0.0237,  0.0030,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4938,  0.0000,  0.0000, -0.0111,  0.0000, -0.2145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-4.3099e-02, -2.5227e-10,  2.1312e-01,  1.3581e-01,  2.1521e+00,
-        -4.2299e-01,  1.5270e-19, -1.2058e-12, -5.6188e-02,  1.1032e-01,
-        -1.0366e-01,  2.7452e-11, -2.2529e-01, -3.7640e-01, -2.7384e-16,
-        -2.0353e-13, -2.0020e-14, -3.4255e-01, -4.6305e-17,  5.5674e-13,
-         1.3157e-01, -9.4695e-02, -9.8334e-17, -8.1510e-01, -7.8293e-02,
-         7.9809e-13,  4.3895e-01,  6.2635e-02, -1.5406e-01,  1.4301e-16,
-        -3.8517e-01, -1.1584e+00, -3.6609e-20,  1.2690e-01,  1.8171e-14,
-         9.8042e-13,  6.0613e-16,  0.0000e+00,  2.0814e-01,  7.1367e-01,
-         9.0432e-02, -8.0959e-14, -3.3412e-01,  1.4917e-16,  3.6487e-14,
-        -3.2648e-01, -5.1482e-12,  1.6612e-05,  1.7318e-11, -9.4968e-18,
-        -3.0727e-08,  3.3837e-02,  3.3222e-02,  6.5868e-02, -3.4801e-09,
-         1.9764e-21, -7.7678e-16,  2.7936e-19,  5.0747e-01, -4.0345e-11,
-         7.4623e-19,  1.8098e-02,  6.2698e-07, -2.2340e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0431,  0.0000,  0.2131,  0.1358,  2.1521, -0.4230,  0.0000,  0.0000,
-        -0.0562,  0.1103, -0.1037,  0.0000, -0.2253, -0.3764,  0.0000,  0.0000,
-         0.0000, -0.3425,  0.0000,  0.0000,  0.1316, -0.0947,  0.0000, -0.8151,
-        -0.0783,  0.0000,  0.4390,  0.0626, -0.1541,  0.0000, -0.3852, -1.1584,
-         0.0000,  0.1269,  0.0000,  0.0000,  0.0000,  0.0000,  0.2081,  0.7137,
-         0.0904,  0.0000, -0.3341,  0.0000,  0.0000, -0.3265,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0338,  0.0332,  0.0659,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5075,  0.0000,  0.0000,  0.0181,  0.0000, -0.2234],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0431,  0.0000,  0.2131,  0.1358,  2.1521, -0.4230,  0.0000,  0.0000,
-        -0.0562,  0.1103, -0.1037,  0.0000, -0.2253, -0.3764,  0.0000,  0.0000,
-         0.0000, -0.3425,  0.0000,  0.0000,  0.1316, -0.0947,  0.0000, -0.8151,
-        -0.0783,  0.0000,  0.4390,  0.0626, -0.1541,  0.0000, -0.3852, -1.1584,
-         0.0000,  0.1269,  0.0000,  0.0000,  0.0000,  0.0000,  0.2081,  0.7137,
-         0.0904,  0.0000, -0.3341,  0.0000,  0.0000, -0.3265,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0338,  0.0332,  0.0659,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5075,  0.0000,  0.0000,  0.0181,  0.0000, -0.2234],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-4.3115e-02, -2.3024e-10,  2.0856e-01,  1.6451e-01,  2.1503e+00,
-        -4.4330e-01,  1.3937e-19, -1.1005e-12, -2.6483e-02,  6.7923e-02,
-        -9.8113e-02,  2.5055e-11, -1.9889e-01, -3.7533e-01, -2.4993e-16,
-        -1.8576e-13, -1.8272e-14, -3.1217e-01, -4.2262e-17,  5.0813e-13,
-         1.7155e-01, -8.6881e-02, -8.9749e-17, -8.0749e-01, -1.0977e-01,
-         7.2841e-13,  4.3450e-01,  8.7299e-02, -1.5848e-01,  1.3053e-16,
-        -4.2430e-01, -1.1526e+00, -3.3412e-20,  1.1429e-01,  1.6584e-14,
-         8.9482e-13,  5.5321e-16,  0.0000e+00,  2.2349e-01,  7.0390e-01,
-         1.2516e-01, -7.3891e-14, -3.2363e-01,  1.3615e-16,  3.3301e-14,
-        -3.2902e-01, -4.6987e-12,  1.5161e-05,  1.5806e-11, -8.6676e-18,
-        -2.8044e-08,  4.1987e-02,  4.4683e-02,  9.9377e-02, -3.1763e-09,
-         1.8039e-21, -7.0896e-16,  2.5497e-19,  5.2572e-01, -3.6823e-11,
-         6.8107e-19,  4.5752e-02,  5.7223e-07, -2.1962e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0431,  0.0000,  0.2086,  0.1645,  2.1503, -0.4433,  0.0000,  0.0000,
-        -0.0265,  0.0679, -0.0981,  0.0000, -0.1989, -0.3753,  0.0000,  0.0000,
-         0.0000, -0.3122,  0.0000,  0.0000,  0.1715, -0.0869,  0.0000, -0.8075,
-        -0.1098,  0.0000,  0.4345,  0.0873, -0.1585,  0.0000, -0.4243, -1.1526,
-         0.0000,  0.1143,  0.0000,  0.0000,  0.0000,  0.0000,  0.2235,  0.7039,
-         0.1252,  0.0000, -0.3236,  0.0000,  0.0000, -0.3290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0420,  0.0447,  0.0994,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5257,  0.0000,  0.0000,  0.0458,  0.0000, -0.2196],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0431,  0.0000,  0.2086,  0.1645,  2.1503, -0.4433,  0.0000,  0.0000,
-        -0.0265,  0.0679, -0.0981,  0.0000, -0.1989, -0.3753,  0.0000,  0.0000,
-         0.0000, -0.3122,  0.0000,  0.0000,  0.1715, -0.0869,  0.0000, -0.8075,
-        -0.1098,  0.0000,  0.4345,  0.0873, -0.1585,  0.0000, -0.4243, -1.1526,
-         0.0000,  0.1143,  0.0000,  0.0000,  0.0000,  0.0000,  0.2235,  0.7039,
-         0.1252,  0.0000, -0.3236,  0.0000,  0.0000, -0.3290,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0420,  0.0447,  0.0994,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5257,  0.0000,  0.0000,  0.0458,  0.0000, -0.2196],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-3.9428e-02, -2.1019e-10,  2.0668e-01,  1.6689e-01,  2.1490e+00,
-        -4.5697e-01,  1.2723e-19, -1.0046e-12, -8.9598e-03,  1.8142e-02,
-        -8.1625e-02,  2.2873e-11, -1.7724e-01, -3.7258e-01, -2.2816e-16,
-        -1.6958e-13, -1.6681e-14, -3.0432e-01, -3.8581e-17,  4.6387e-13,
-         1.9262e-01, -8.5359e-02, -8.1932e-17, -8.0709e-01, -1.3269e-01,
-         6.6496e-13,  4.1849e-01,  9.6670e-02, -1.6290e-01,  1.1916e-16,
-        -4.5090e-01, -1.1488e+00, -3.0502e-20,  1.0589e-01,  1.5140e-14,
-         8.1688e-13,  5.0503e-16,  0.0000e+00,  2.4152e-01,  6.9473e-01,
-         1.4278e-01, -6.7455e-14, -3.1083e-01,  1.2429e-16,  3.0400e-14,
-        -3.1487e-01, -4.2894e-12,  1.3841e-05,  1.4430e-11, -7.9127e-18,
-        -2.5602e-08,  3.5147e-02,  4.7582e-02,  1.1530e-01, -2.8996e-09,
-         1.6468e-21, -6.4720e-16,  2.3276e-19,  5.4160e-01, -3.3615e-11,
-         6.2175e-19,  6.3851e-02,  5.2239e-07, -2.1885e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0394,  0.0000,  0.2067,  0.1669,  2.1490, -0.4570,  0.0000,  0.0000,
-        -0.0090,  0.0181, -0.0816,  0.0000, -0.1772, -0.3726,  0.0000,  0.0000,
-         0.0000, -0.3043,  0.0000,  0.0000,  0.1926, -0.0854,  0.0000, -0.8071,
-        -0.1327,  0.0000,  0.4185,  0.0967, -0.1629,  0.0000, -0.4509, -1.1488,
-         0.0000,  0.1059,  0.0000,  0.0000,  0.0000,  0.0000,  0.2415,  0.6947,
-         0.1428,  0.0000, -0.3108,  0.0000,  0.0000, -0.3149,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0351,  0.0476,  0.1153,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5416,  0.0000,  0.0000,  0.0639,  0.0000, -0.2189],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0394,  0.0000,  0.2067,  0.1669,  2.1490, -0.4570,  0.0000,  0.0000,
-        -0.0090,  0.0181, -0.0816,  0.0000, -0.1772, -0.3726,  0.0000,  0.0000,
-         0.0000, -0.3043,  0.0000,  0.0000,  0.1926, -0.0854,  0.0000, -0.8071,
-        -0.1327,  0.0000,  0.4185,  0.0967, -0.1629,  0.0000, -0.4509, -1.1488,
-         0.0000,  0.1059,  0.0000,  0.0000,  0.0000,  0.0000,  0.2415,  0.6947,
-         0.1428,  0.0000, -0.3108,  0.0000,  0.0000, -0.3149,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0351,  0.0476,  0.1153,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5416,  0.0000,  0.0000,  0.0639,  0.0000, -0.2189],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-3.0696e-02, -1.9192e-10,  1.9706e-01,  1.5856e-01,  2.1472e+00,
-        -4.7323e-01,  1.1617e-19, -9.1733e-13,  1.0055e-02, -2.4654e-02,
-        -3.7781e-02,  2.0885e-11, -1.5986e-01, -3.6647e-01, -2.0833e-16,
-        -1.5484e-13, -1.5231e-14, -3.0346e-01, -3.5228e-17,  4.2356e-13,
-         1.9455e-01, -9.0991e-02, -7.4812e-17, -8.1423e-01, -1.4478e-01,
-         6.0718e-13,  3.8979e-01,  8.6172e-02, -1.7861e-01,  1.0880e-16,
-        -4.7636e-01, -1.1478e+00, -2.7851e-20,  9.9336e-02,  1.3824e-14,
-         7.4590e-13,  4.6114e-16,  0.0000e+00,  2.5616e-01,  6.8105e-01,
-         1.4747e-01, -6.1593e-14, -2.9800e-01,  1.1349e-16,  2.7759e-14,
-        -3.0813e-01, -3.9167e-12,  1.2638e-05,  1.3176e-11, -7.2251e-18,
-        -2.3377e-08,  2.9594e-02,  5.7604e-02,  9.0989e-02, -2.6476e-09,
-         1.5037e-21, -5.9096e-16,  2.1253e-19,  5.5719e-01, -3.0694e-11,
-         5.6772e-19,  6.9361e-02,  4.7700e-07, -2.0085e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0307,  0.0000,  0.1971,  0.1586,  2.1472, -0.4732,  0.0000,  0.0000,
-         0.0101, -0.0247, -0.0378,  0.0000, -0.1599, -0.3665,  0.0000,  0.0000,
-         0.0000, -0.3035,  0.0000,  0.0000,  0.1946, -0.0910,  0.0000, -0.8142,
-        -0.1448,  0.0000,  0.3898,  0.0862, -0.1786,  0.0000, -0.4764, -1.1478,
-         0.0000,  0.0993,  0.0000,  0.0000,  0.0000,  0.0000,  0.2562,  0.6811,
-         0.1475,  0.0000, -0.2980,  0.0000,  0.0000, -0.3081,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0296,  0.0576,  0.0910,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5572,  0.0000,  0.0000,  0.0694,  0.0000, -0.2008],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0307,  0.0000,  0.1971,  0.1586,  2.1472, -0.4732,  0.0000,  0.0000,
-         0.0101, -0.0247, -0.0378,  0.0000, -0.1599, -0.3665,  0.0000,  0.0000,
-         0.0000, -0.3035,  0.0000,  0.0000,  0.1946, -0.0910,  0.0000, -0.8142,
-        -0.1448,  0.0000,  0.3898,  0.0862, -0.1786,  0.0000, -0.4764, -1.1478,
-         0.0000,  0.0993,  0.0000,  0.0000,  0.0000,  0.0000,  0.2562,  0.6811,
-         0.1475,  0.0000, -0.2980,  0.0000,  0.0000, -0.3081,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0296,  0.0576,  0.0910,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5572,  0.0000,  0.0000,  0.0694,  0.0000, -0.2008],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.5844e-02, -1.7528e-10,  1.8161e-01,  1.4227e-01,  2.1463e+00,
-        -4.7706e-01,  1.0610e-19, -8.3780e-13,  1.7089e-02, -6.5402e-02,
-         4.1537e-03,  1.9074e-11, -1.5951e-01, -3.6642e-01, -1.9027e-16,
-        -1.4142e-13, -1.3911e-14, -3.0134e-01, -3.2174e-17,  3.8684e-13,
-         1.8607e-01, -9.6346e-02, -6.8326e-17, -8.2146e-01, -1.4576e-01,
-         5.5454e-13,  3.6105e-01,  7.6856e-02, -1.9387e-01,  9.9370e-17,
-        -4.8753e-01, -1.1474e+00, -2.5437e-20,  1.0290e-01,  1.2626e-14,
-         6.8123e-13,  4.2116e-16,  0.0000e+00,  2.4056e-01,  6.6699e-01,
-         1.3606e-01, -5.6253e-14, -2.9439e-01,  1.0365e-16,  2.5352e-14,
-        -2.9696e-01, -3.5771e-12,  1.1542e-05,  1.2033e-11, -6.5987e-18,
-        -2.1350e-08,  5.7720e-03,  7.0826e-02,  5.3869e-02, -2.4181e-09,
-         1.3733e-21, -5.3973e-16,  1.9411e-19,  5.7409e-01, -2.8033e-11,
-         5.1850e-19,  7.3435e-02,  4.3564e-07, -1.8535e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0258,  0.0000,  0.1816,  0.1423,  2.1463, -0.4771,  0.0000,  0.0000,
-         0.0171, -0.0654,  0.0042,  0.0000, -0.1595, -0.3664,  0.0000,  0.0000,
-         0.0000, -0.3013,  0.0000,  0.0000,  0.1861, -0.0963,  0.0000, -0.8215,
-        -0.1458,  0.0000,  0.3611,  0.0769, -0.1939,  0.0000, -0.4875, -1.1474,
-         0.0000,  0.1029,  0.0000,  0.0000,  0.0000,  0.0000,  0.2406,  0.6670,
-         0.1361,  0.0000, -0.2944,  0.0000,  0.0000, -0.2970,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0058,  0.0708,  0.0539,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5741,  0.0000,  0.0000,  0.0734,  0.0000, -0.1854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0258,  0.0000,  0.1816,  0.1423,  2.1463, -0.4771,  0.0000,  0.0000,
-         0.0171, -0.0654,  0.0042,  0.0000, -0.1595, -0.3664,  0.0000,  0.0000,
-         0.0000, -0.3013,  0.0000,  0.0000,  0.1861, -0.0963,  0.0000, -0.8215,
-        -0.1458,  0.0000,  0.3611,  0.0769, -0.1939,  0.0000, -0.4875, -1.1474,
-         0.0000,  0.1029,  0.0000,  0.0000,  0.0000,  0.0000,  0.2406,  0.6670,
-         0.1361,  0.0000, -0.2944,  0.0000,  0.0000, -0.2970,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0058,  0.0708,  0.0539,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5741,  0.0000,  0.0000,  0.0734,  0.0000, -0.1854],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-2.2700e-02, -1.6012e-10,  1.5520e-01,  1.2577e-01,  2.1454e+00,
-        -4.7525e-01,  9.6924e-20, -7.6533e-13,  2.3327e-02, -9.6334e-02,
-         4.3837e-02,  1.7424e-11, -1.6909e-01, -3.6731e-01, -1.7381e-16,
-        -1.2918e-13, -1.2707e-14, -3.0029e-01, -2.9391e-17,  3.5338e-13,
-         1.6497e-01, -9.8499e-02, -6.2415e-17, -8.3118e-01, -1.2879e-01,
-         5.0657e-13,  3.2574e-01,  5.9599e-02, -2.0626e-01,  9.0774e-17,
-        -4.9161e-01, -1.1465e+00, -2.3236e-20,  9.5888e-02,  1.1533e-14,
-         6.2230e-13,  3.8473e-16,  0.0000e+00,  2.0944e-01,  6.5696e-01,
-         1.1950e-01, -5.1387e-14, -2.9184e-01,  9.4685e-17,  2.3159e-14,
-        -2.8075e-01, -3.2677e-12,  1.0544e-05,  1.0992e-11, -6.0278e-18,
-        -1.9503e-08, -8.5395e-03,  7.0098e-02,  7.4599e-03, -2.2089e-09,
-         1.2545e-21, -4.9304e-16,  1.7731e-19,  5.9330e-01, -2.5608e-11,
-         4.7365e-19,  6.8798e-02,  3.9796e-07, -1.7288e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0227,  0.0000,  0.1552,  0.1258,  2.1454, -0.4753,  0.0000,  0.0000,
-         0.0233, -0.0963,  0.0438,  0.0000, -0.1691, -0.3673,  0.0000,  0.0000,
-         0.0000, -0.3003,  0.0000,  0.0000,  0.1650, -0.0985,  0.0000, -0.8312,
-        -0.1288,  0.0000,  0.3257,  0.0596, -0.2063,  0.0000, -0.4916, -1.1465,
-         0.0000,  0.0959,  0.0000,  0.0000,  0.0000,  0.0000,  0.2094,  0.6570,
-         0.1195,  0.0000, -0.2918,  0.0000,  0.0000, -0.2807,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0085,  0.0701,  0.0075,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5933,  0.0000,  0.0000,  0.0688,  0.0000, -0.1729],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0227,  0.0000,  0.1552,  0.1258,  2.1454, -0.4753,  0.0000,  0.0000,
-         0.0233, -0.0963,  0.0438,  0.0000, -0.1691, -0.3673,  0.0000,  0.0000,
-         0.0000, -0.3003,  0.0000,  0.0000,  0.1650, -0.0985,  0.0000, -0.8312,
-        -0.1288,  0.0000,  0.3257,  0.0596, -0.2063,  0.0000, -0.4916, -1.1465,
-         0.0000,  0.0959,  0.0000,  0.0000,  0.0000,  0.0000,  0.2094,  0.6570,
-         0.1195,  0.0000, -0.2918,  0.0000,  0.0000, -0.2807,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0085,  0.0701,  0.0075,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5933,  0.0000,  0.0000,  0.0688,  0.0000, -0.1729],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.9143e-02, -1.4630e-10,  1.3067e-01,  9.7382e-02,  2.1452e+00,
-        -4.7232e-01,  8.8558e-20, -6.9927e-13,  2.6456e-02, -1.2613e-01,
-         7.3280e-02,  1.5920e-11, -1.7443e-01, -3.7455e-01, -1.5881e-16,
-        -1.1803e-13, -1.1610e-14, -3.0510e-01, -2.6854e-17,  3.2287e-13,
-         1.4248e-01, -1.0646e-01, -5.7028e-17, -8.3584e-01, -1.1888e-01,
-         4.6284e-13,  2.8563e-01,  3.2295e-02, -2.1407e-01,  8.2939e-17,
-        -4.9498e-01, -1.1447e+00, -2.1231e-20,  8.3201e-02,  1.0538e-14,
-         5.6859e-13,  3.5152e-16,  0.0000e+00,  1.7002e-01,  6.5270e-01,
-         9.0548e-02, -4.6951e-14, -2.8586e-01,  8.6512e-17,  2.1160e-14,
-        -2.6376e-01, -2.9856e-12,  9.6338e-06,  1.0044e-11, -5.5076e-18,
-        -1.7820e-08, -2.9287e-02,  4.3958e-02, -5.0029e-02, -2.0183e-09,
-         1.1462e-21, -4.5048e-16,  1.6201e-19,  6.0566e-01, -2.3398e-11,
-         4.3277e-19,  4.8068e-02,  3.6361e-07, -1.6446e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0191,  0.0000,  0.1307,  0.0974,  2.1452, -0.4723,  0.0000,  0.0000,
-         0.0265, -0.1261,  0.0733,  0.0000, -0.1744, -0.3745,  0.0000,  0.0000,
-         0.0000, -0.3051,  0.0000,  0.0000,  0.1425, -0.1065,  0.0000, -0.8358,
-        -0.1189,  0.0000,  0.2856,  0.0323, -0.2141,  0.0000, -0.4950, -1.1447,
-         0.0000,  0.0832,  0.0000,  0.0000,  0.0000,  0.0000,  0.1700,  0.6527,
-         0.0905,  0.0000, -0.2859,  0.0000,  0.0000, -0.2638,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0293,  0.0440, -0.0500,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6057,  0.0000,  0.0000,  0.0481,  0.0000, -0.1645],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0191,  0.0000,  0.1307,  0.0974,  2.1452, -0.4723,  0.0000,  0.0000,
-         0.0265, -0.1261,  0.0733,  0.0000, -0.1744, -0.3745,  0.0000,  0.0000,
-         0.0000, -0.3051,  0.0000,  0.0000,  0.1425, -0.1065,  0.0000, -0.8358,
-        -0.1189,  0.0000,  0.2856,  0.0323, -0.2141,  0.0000, -0.4950, -1.1447,
-         0.0000,  0.0832,  0.0000,  0.0000,  0.0000,  0.0000,  0.1700,  0.6527,
-         0.0905,  0.0000, -0.2859,  0.0000,  0.0000, -0.2638,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0293,  0.0440, -0.0500,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6057,  0.0000,  0.0000,  0.0481,  0.0000, -0.1645],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-1.7603e-02, -1.3370e-10,  1.1024e-01,  8.2304e-02,  2.1439e+00,
-        -4.6162e-01,  8.0931e-20, -6.3904e-13,  3.0839e-02, -1.3842e-01,
-         9.1792e-02,  1.4549e-11, -1.7731e-01, -3.8494e-01, -1.4513e-16,
-        -1.0787e-13, -1.0610e-14, -3.0447e-01, -2.4541e-17,  2.9507e-13,
-         1.1242e-01, -1.0512e-01, -5.2116e-17, -8.4015e-01, -1.0444e-01,
-         4.2298e-13,  2.5211e-01,  2.5508e-03, -2.2386e-01,  7.5796e-17,
-        -4.8415e-01, -1.1439e+00, -1.9402e-20,  6.3344e-02,  9.6303e-15,
-         5.1961e-13,  3.2124e-16,  0.0000e+00,  1.2410e-01,  6.4254e-01,
-         6.2077e-02, -4.2908e-14, -2.7537e-01,  7.9061e-17,  1.9338e-14,
-        -2.3863e-01, -2.7285e-12,  8.8041e-06,  9.1786e-12, -5.0332e-18,
-        -1.6285e-08, -3.6714e-02,  4.1163e-03, -8.7170e-02, -1.8444e-09,
-         1.0475e-21, -4.1168e-16,  1.4806e-19,  6.1333e-01, -2.1383e-11,
-         3.9549e-19,  3.0658e-02,  3.3229e-07, -1.6494e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0176,  0.0000,  0.1102,  0.0823,  2.1439, -0.4616,  0.0000,  0.0000,
-         0.0308, -0.1384,  0.0918,  0.0000, -0.1773, -0.3849,  0.0000,  0.0000,
-         0.0000, -0.3045,  0.0000,  0.0000,  0.1124, -0.1051,  0.0000, -0.8401,
-        -0.1044,  0.0000,  0.2521,  0.0026, -0.2239,  0.0000, -0.4841, -1.1439,
-         0.0000,  0.0633,  0.0000,  0.0000,  0.0000,  0.0000,  0.1241,  0.6425,
-         0.0621,  0.0000, -0.2754,  0.0000,  0.0000, -0.2386,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0367,  0.0041, -0.0872,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6133,  0.0000,  0.0000,  0.0307,  0.0000, -0.1649],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0176,  0.0000,  0.1102,  0.0823,  2.1439, -0.4616,  0.0000,  0.0000,
-         0.0308, -0.1384,  0.0918,  0.0000, -0.1773, -0.3849,  0.0000,  0.0000,
-         0.0000, -0.3045,  0.0000,  0.0000,  0.1124, -0.1051,  0.0000, -0.8401,
-        -0.1044,  0.0000,  0.2521,  0.0026, -0.2239,  0.0000, -0.4841, -1.1439,
-         0.0000,  0.0633,  0.0000,  0.0000,  0.0000,  0.0000,  0.1241,  0.6425,
-         0.0621,  0.0000, -0.2754,  0.0000,  0.0000, -0.2386,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0367,  0.0041, -0.0872,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6133,  0.0000,  0.0000,  0.0307,  0.0000, -0.1649],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([-5.9668e-03, -1.2221e-10,  9.9170e-02,  8.2513e-02,  2.1428e+00,
-        -4.5169e-01,  7.3975e-20, -5.8412e-13,  3.5229e-02, -1.3577e-01,
-         1.0008e-01,  1.3299e-11, -1.7365e-01, -4.0070e-01, -1.3266e-16,
-        -9.8596e-14, -9.6985e-15, -2.9294e-01, -2.2432e-17,  2.6971e-13,
-         8.1958e-02, -1.0664e-01, -4.7637e-17, -8.4698e-01, -8.9948e-02,
-         3.8663e-13,  2.3582e-01, -1.7131e-02, -2.3758e-01,  6.9281e-17,
-        -4.7091e-01, -1.1449e+00, -1.7735e-20,  4.4544e-02,  8.8026e-15,
-         4.7496e-13,  2.9363e-16,  0.0000e+00,  8.0630e-02,  6.3308e-01,
-         2.8752e-02, -3.9220e-14, -2.6035e-01,  7.2266e-17,  1.7676e-14,
-        -2.0988e-01, -2.4940e-12,  8.0474e-06,  8.3897e-12, -4.6006e-18,
-        -1.4885e-08, -4.9874e-02, -4.0534e-02, -6.7422e-02, -1.6859e-09,
-         9.5746e-22, -3.7630e-16,  1.3533e-19,  6.1971e-01, -1.9545e-11,
-         3.6150e-19,  2.7653e-02,  3.0373e-07, -1.5613e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([-0.0060,  0.0000,  0.0992,  0.0825,  2.1428, -0.4517,  0.0000,  0.0000,
-         0.0352, -0.1358,  0.1001,  0.0000, -0.1737, -0.4007,  0.0000,  0.0000,
-         0.0000, -0.2929,  0.0000,  0.0000,  0.0820, -0.1066,  0.0000, -0.8470,
-        -0.0899,  0.0000,  0.2358, -0.0171, -0.2376,  0.0000, -0.4709, -1.1449,
-         0.0000,  0.0445,  0.0000,  0.0000,  0.0000,  0.0000,  0.0806,  0.6331,
-         0.0288,  0.0000, -0.2604,  0.0000,  0.0000, -0.2099,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0499, -0.0405, -0.0674,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6197,  0.0000,  0.0000,  0.0277,  0.0000, -0.1561],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([-0.0060,  0.0000,  0.0992,  0.0825,  2.1428, -0.4517,  0.0000,  0.0000,
-         0.0352, -0.1358,  0.1001,  0.0000, -0.1737, -0.4007,  0.0000,  0.0000,
-         0.0000, -0.2929,  0.0000,  0.0000,  0.0820, -0.1066,  0.0000, -0.8470,
-        -0.0899,  0.0000,  0.2358, -0.0171, -0.2376,  0.0000, -0.4709, -1.1449,
-         0.0000,  0.0445,  0.0000,  0.0000,  0.0000,  0.0000,  0.0806,  0.6331,
-         0.0288,  0.0000, -0.2604,  0.0000,  0.0000, -0.2099,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0499, -0.0405, -0.0674,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6197,  0.0000,  0.0000,  0.0277,  0.0000, -0.1561],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.1360e-02, -1.1173e-10,  9.8507e-02,  8.5128e-02,  2.1421e+00,
-        -4.4148e-01,  6.7630e-20, -5.3402e-13,  3.4002e-02, -1.3562e-01,
-         9.6247e-02,  1.2158e-11, -1.6195e-01, -4.2726e-01, -1.2128e-16,
-        -9.0140e-14, -8.8666e-15, -2.7467e-01, -2.0508e-17,  2.4657e-13,
-         5.9268e-02, -1.1204e-01, -4.3551e-17, -8.5830e-01, -8.4626e-02,
-         3.5346e-13,  2.3062e-01, -3.2361e-02, -2.5209e-01,  6.3339e-17,
-        -4.5675e-01, -1.1456e+00, -1.6213e-20,  1.2895e-02,  8.0476e-15,
-         4.3422e-13,  2.6845e-16,  0.0000e+00,  4.4299e-02,  6.2817e-01,
-        -4.4018e-03, -3.5856e-14, -2.3954e-01,  6.6068e-17,  1.6159e-14,
-        -1.8098e-01, -2.2801e-12,  7.3571e-06,  7.6701e-12, -4.2060e-18,
-        -1.3609e-08, -6.4432e-02, -9.8332e-02, -2.6827e-02, -1.5413e-09,
-         8.7534e-22, -3.4402e-16,  1.2372e-19,  6.2409e-01, -1.7868e-11,
-         3.3049e-19,  3.1434e-02,  2.7768e-07, -1.4609e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0114,  0.0000,  0.0985,  0.0851,  2.1421, -0.4415,  0.0000,  0.0000,
-         0.0340, -0.1356,  0.0962,  0.0000, -0.1619, -0.4273,  0.0000,  0.0000,
-         0.0000, -0.2747,  0.0000,  0.0000,  0.0593, -0.1120,  0.0000, -0.8583,
-        -0.0846,  0.0000,  0.2306, -0.0324, -0.2521,  0.0000, -0.4568, -1.1456,
-         0.0000,  0.0129,  0.0000,  0.0000,  0.0000,  0.0000,  0.0443,  0.6282,
-        -0.0044,  0.0000, -0.2395,  0.0000,  0.0000, -0.1810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0644, -0.0983, -0.0268,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6241,  0.0000,  0.0000,  0.0314,  0.0000, -0.1461],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0114,  0.0000,  0.0985,  0.0851,  2.1421, -0.4415,  0.0000,  0.0000,
-         0.0340, -0.1356,  0.0962,  0.0000, -0.1619, -0.4273,  0.0000,  0.0000,
-         0.0000, -0.2747,  0.0000,  0.0000,  0.0593, -0.1120,  0.0000, -0.8583,
-        -0.0846,  0.0000,  0.2306, -0.0324, -0.2521,  0.0000, -0.4568, -1.1456,
-         0.0000,  0.0129,  0.0000,  0.0000,  0.0000,  0.0000,  0.0443,  0.6282,
-        -0.0044,  0.0000, -0.2395,  0.0000,  0.0000, -0.1810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0644, -0.0983, -0.0268,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6241,  0.0000,  0.0000,  0.0314,  0.0000, -0.1461],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9978e-02, -1.0216e-10,  9.5246e-02,  9.1153e-02,  2.1413e+00,
-        -4.1727e-01,  6.1841e-20, -4.8830e-13,  2.5556e-02, -1.2243e-01,
-         7.5153e-02,  1.1117e-11, -1.6467e-01, -4.5455e-01, -1.1090e-16,
-        -8.2424e-14, -8.1076e-15, -2.5944e-01, -1.8752e-17,  2.2547e-13,
-         3.3026e-02, -1.1284e-01, -3.9823e-17, -8.7466e-01, -6.9704e-02,
-         3.2321e-13,  2.3654e-01, -3.9488e-02, -2.6784e-01,  5.7917e-17,
-        -4.3159e-01, -1.1475e+00, -1.4826e-20, -5.3266e-03,  7.3587e-15,
-         3.9705e-13,  2.4547e-16,  0.0000e+00,  1.8045e-02,  6.2669e-01,
-        -5.1119e-02, -3.2786e-14, -2.3077e-01,  6.0412e-17,  1.4776e-14,
-        -1.3519e-01, -2.0849e-12,  6.7273e-06,  7.0135e-12, -3.8460e-18,
-        -1.2444e-08, -8.6351e-02, -1.4530e-01,  1.8949e-02, -1.4094e-09,
-         8.0041e-22, -3.1457e-16,  1.1313e-19,  6.2870e-01, -1.6339e-11,
-         3.0220e-19,  4.2261e-02,  2.5391e-07, -1.2688e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0300,  0.0000,  0.0952,  0.0912,  2.1413, -0.4173,  0.0000,  0.0000,
-         0.0256, -0.1224,  0.0752,  0.0000, -0.1647, -0.4545,  0.0000,  0.0000,
-         0.0000, -0.2594,  0.0000,  0.0000,  0.0330, -0.1128,  0.0000, -0.8747,
-        -0.0697,  0.0000,  0.2365, -0.0395, -0.2678,  0.0000, -0.4316, -1.1475,
-         0.0000, -0.0053,  0.0000,  0.0000,  0.0000,  0.0000,  0.0180,  0.6267,
-        -0.0511,  0.0000, -0.2308,  0.0000,  0.0000, -0.1352,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0864, -0.1453,  0.0189,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6287,  0.0000,  0.0000,  0.0423,  0.0000, -0.1269],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0300,  0.0000,  0.0952,  0.0912,  2.1413, -0.4173,  0.0000,  0.0000,
-         0.0256, -0.1224,  0.0752,  0.0000, -0.1647, -0.4545,  0.0000,  0.0000,
-         0.0000, -0.2594,  0.0000,  0.0000,  0.0330, -0.1128,  0.0000, -0.8747,
-        -0.0697,  0.0000,  0.2365, -0.0395, -0.2678,  0.0000, -0.4316, -1.1475,
-         0.0000, -0.0053,  0.0000,  0.0000,  0.0000,  0.0000,  0.0180,  0.6267,
-        -0.0511,  0.0000, -0.2308,  0.0000,  0.0000, -0.1352,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0864, -0.1453,  0.0189,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6287,  0.0000,  0.0000,  0.0423,  0.0000, -0.1269],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 4.6584e-02, -9.3434e-11,  9.2716e-02,  1.0304e-01,  2.1416e+00,
-        -3.9071e-01,  5.6557e-20, -4.4659e-13,  1.2456e-02, -1.0636e-01,
-         4.3122e-02,  1.0167e-11, -1.6943e-01, -4.6888e-01, -1.0142e-16,
-        -7.5382e-14, -7.4149e-15, -2.4774e-01, -1.7150e-17,  2.0620e-13,
-         2.0526e-02, -1.0985e-01, -3.6421e-17, -8.9091e-01, -4.6599e-02,
-         2.9559e-13,  2.3927e-01, -3.6055e-02, -2.7976e-01,  5.2969e-17,
-        -4.0888e-01, -1.1492e+00, -1.3559e-20, -2.5773e-02,  6.7300e-15,
-         3.6313e-13,  2.2450e-16,  0.0000e+00, -9.7608e-03,  6.3074e-01,
-        -8.9807e-02, -2.9985e-14, -2.2922e-01,  5.5251e-17,  1.3514e-14,
-        -9.4096e-02, -1.9068e-12,  6.1526e-06,  6.4143e-12, -3.5174e-18,
-        -1.1381e-08, -9.6414e-02, -2.0167e-01,  7.4660e-02, -1.2890e-09,
-         7.3203e-22, -2.8770e-16,  1.0347e-19,  6.3316e-01, -1.4943e-11,
-         2.7638e-19,  4.5313e-02,  2.3222e-07, -1.2279e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0466,  0.0000,  0.0927,  0.1030,  2.1416, -0.3907,  0.0000,  0.0000,
-         0.0125, -0.1064,  0.0431,  0.0000, -0.1694, -0.4689,  0.0000,  0.0000,
-         0.0000, -0.2477,  0.0000,  0.0000,  0.0205, -0.1099,  0.0000, -0.8909,
-        -0.0466,  0.0000,  0.2393, -0.0361, -0.2798,  0.0000, -0.4089, -1.1492,
-         0.0000, -0.0258,  0.0000,  0.0000,  0.0000,  0.0000, -0.0098,  0.6307,
-        -0.0898,  0.0000, -0.2292,  0.0000,  0.0000, -0.0941,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0964, -0.2017,  0.0747,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6332,  0.0000,  0.0000,  0.0453,  0.0000, -0.1228],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0466,  0.0000,  0.0927,  0.1030,  2.1416, -0.3907,  0.0000,  0.0000,
-         0.0125, -0.1064,  0.0431,  0.0000, -0.1694, -0.4689,  0.0000,  0.0000,
-         0.0000, -0.2477,  0.0000,  0.0000,  0.0205, -0.1099,  0.0000, -0.8909,
-        -0.0466,  0.0000,  0.2393, -0.0361, -0.2798,  0.0000, -0.4089, -1.1492,
-         0.0000, -0.0258,  0.0000,  0.0000,  0.0000,  0.0000, -0.0098,  0.6307,
-        -0.0898,  0.0000, -0.2292,  0.0000,  0.0000, -0.0941,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0964, -0.2017,  0.0747,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6332,  0.0000,  0.0000,  0.0453,  0.0000, -0.1228],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.0399e-02, -8.5467e-11,  9.7680e-02,  1.1211e-01,  2.1408e+00,
-        -3.6795e-01,  5.1734e-20, -4.0850e-13,  3.6946e-03, -8.6520e-02,
-         3.3709e-02,  9.3004e-12, -1.7439e-01, -4.7614e-01, -9.2775e-17,
-        -6.8953e-14, -6.7826e-15, -2.3493e-01, -1.5688e-17,  1.8862e-13,
-         7.8452e-03, -1.0925e-01, -3.3315e-17, -9.0147e-01, -2.6068e-02,
-         2.7039e-13,  2.4601e-01, -3.5070e-02, -2.8701e-01,  4.8452e-17,
-        -3.8935e-01, -1.1503e+00, -1.2403e-20, -4.9814e-02,  6.1561e-15,
-         3.3216e-13,  2.0535e-16,  0.0000e+00, -3.0692e-02,  6.3261e-01,
-        -1.2431e-01, -2.7428e-14, -2.2725e-01,  5.0539e-17,  1.2361e-14,
-        -6.2387e-02, -1.7442e-12,  5.6279e-06,  5.8673e-12, -3.2174e-18,
-        -1.0410e-08, -9.9996e-02, -2.4509e-01,  1.1927e-01, -1.1790e-09,
-         6.6960e-22, -2.6317e-16,  9.4644e-20,  6.3055e-01, -1.3669e-11,
-         2.5282e-19,  4.0573e-02,  2.1241e-07, -1.1328e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0604,  0.0000,  0.0977,  0.1121,  2.1408, -0.3679,  0.0000,  0.0000,
-         0.0037, -0.0865,  0.0337,  0.0000, -0.1744, -0.4761,  0.0000,  0.0000,
-         0.0000, -0.2349,  0.0000,  0.0000,  0.0078, -0.1092,  0.0000, -0.9015,
-        -0.0261,  0.0000,  0.2460, -0.0351, -0.2870,  0.0000, -0.3893, -1.1503,
-         0.0000, -0.0498,  0.0000,  0.0000,  0.0000,  0.0000, -0.0307,  0.6326,
-        -0.1243,  0.0000, -0.2273,  0.0000,  0.0000, -0.0624,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.1000, -0.2451,  0.1193,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6306,  0.0000,  0.0000,  0.0406,  0.0000, -0.1133],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0604,  0.0000,  0.0977,  0.1121,  2.1408, -0.3679,  0.0000,  0.0000,
-         0.0037, -0.0865,  0.0337,  0.0000, -0.1744, -0.4761,  0.0000,  0.0000,
-         0.0000, -0.2349,  0.0000,  0.0000,  0.0078, -0.1092,  0.0000, -0.9015,
-        -0.0261,  0.0000,  0.2460, -0.0351, -0.2870,  0.0000, -0.3893, -1.1503,
-         0.0000, -0.0498,  0.0000,  0.0000,  0.0000,  0.0000, -0.0307,  0.6326,
-        -0.1243,  0.0000, -0.2273,  0.0000,  0.0000, -0.0624,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.1000, -0.2451,  0.1193,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6306,  0.0000,  0.0000,  0.0406,  0.0000, -0.1133],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 7.2794e-02, -7.8192e-11,  1.0193e-01,  1.1351e-01,  2.1397e+00,
-        -3.4792e-01,  4.7331e-20, -3.7373e-13, -8.1407e-04, -7.4476e-02,
-         4.5637e-02,  8.5088e-12, -1.7901e-01, -4.7733e-01, -8.4878e-17,
-        -6.3084e-14, -6.2053e-15, -2.2577e-01, -1.4352e-17,  1.7256e-13,
-         1.5301e-03, -1.1389e-01, -3.0479e-17, -9.0575e-01, -9.7372e-03,
-         2.4737e-13,  2.3762e-01, -3.1597e-02, -2.8705e-01,  4.4328e-17,
-        -3.7925e-01, -1.1502e+00, -1.1347e-20, -7.3869e-02,  5.6321e-15,
-         3.0389e-13,  1.8787e-16,  0.0000e+00, -3.8378e-02,  6.3319e-01,
-        -1.5407e-01, -2.5094e-14, -2.2702e-01,  4.6237e-17,  1.1309e-14,
-        -3.5466e-02, -1.5957e-12,  5.1489e-06,  5.3679e-12, -2.9436e-18,
-        -9.5240e-09, -9.1298e-02, -2.7637e-01,  1.4938e-01, -1.0787e-09,
-         6.1261e-22, -2.4076e-16,  8.6588e-20,  6.2359e-01, -1.2505e-11,
-         2.3130e-19,  2.8711e-02,  1.9433e-07, -1.0001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 7.2794e-02,  0.0000e+00,  1.0193e-01,  1.1351e-01,  2.1397e+00,
-        -3.4792e-01,  0.0000e+00,  0.0000e+00, -8.1407e-04, -7.4476e-02,
-         4.5637e-02,  0.0000e+00, -1.7901e-01, -4.7733e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.2577e-01,  0.0000e+00,  0.0000e+00,
-         1.5301e-03, -1.1389e-01,  0.0000e+00, -9.0575e-01, -9.7372e-03,
-         0.0000e+00,  2.3762e-01, -3.1597e-02, -2.8705e-01,  0.0000e+00,
-        -3.7925e-01, -1.1502e+00,  0.0000e+00, -7.3869e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -3.8378e-02,  6.3319e-01,
-        -1.5407e-01,  0.0000e+00, -2.2702e-01,  0.0000e+00,  0.0000e+00,
-        -3.5466e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -9.1298e-02, -2.7637e-01,  1.4938e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.2359e-01,  0.0000e+00,
-         0.0000e+00,  2.8711e-02,  0.0000e+00, -1.0001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 7.2794e-02,  0.0000e+00,  1.0193e-01,  1.1351e-01,  2.1397e+00,
-        -3.4792e-01,  0.0000e+00,  0.0000e+00, -8.1407e-04, -7.4476e-02,
-         4.5637e-02,  0.0000e+00, -1.7901e-01, -4.7733e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.2577e-01,  0.0000e+00,  0.0000e+00,
-         1.5301e-03, -1.1389e-01,  0.0000e+00, -9.0575e-01, -9.7372e-03,
-         0.0000e+00,  2.3762e-01, -3.1597e-02, -2.8705e-01,  0.0000e+00,
-        -3.7925e-01, -1.1502e+00,  0.0000e+00, -7.3869e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -3.8378e-02,  6.3319e-01,
-        -1.5407e-01,  0.0000e+00, -2.2702e-01,  0.0000e+00,  0.0000e+00,
-        -3.5466e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -9.1298e-02, -2.7637e-01,  1.4938e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.2359e-01,  0.0000e+00,
-         0.0000e+00,  2.8711e-02,  0.0000e+00, -1.0001e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.1795e-02, -7.1548e-11,  9.3781e-02,  1.0702e-01,  2.1384e+00,
-        -3.3166e-01,  4.3309e-20, -3.4197e-13, -1.8115e-03, -5.1305e-02,
-         5.6134e-02,  7.7858e-12, -1.8903e-01, -4.7747e-01, -7.7665e-17,
-        -5.7724e-14, -5.6780e-15, -2.1564e-01, -1.3133e-17,  1.5790e-13,
-        -9.9353e-03, -1.0878e-01, -2.7889e-17, -9.0229e-01,  1.3225e-02,
-         2.2635e-13,  2.1973e-01, -3.1266e-02, -2.8252e-01,  4.0561e-17,
-        -3.6999e-01, -1.1511e+00, -1.0383e-20, -8.8463e-02,  5.1535e-15,
-         2.7806e-13,  1.7191e-16,  0.0000e+00, -4.2173e-02,  6.3065e-01,
-        -1.6614e-01, -2.2961e-14, -2.3213e-01,  4.2308e-17,  1.0348e-14,
-        -2.6745e-02, -1.4601e-12,  4.7114e-06,  4.9118e-12, -2.6934e-18,
-        -8.7148e-09, -8.4654e-02, -2.9846e-01,  1.5209e-01, -9.8702e-10,
-         5.6055e-22, -2.2031e-16,  7.9230e-20,  6.1289e-01, -1.1443e-11,
-         2.1164e-19,  4.0104e-03,  1.7782e-07, -8.6307e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.1795e-02,  0.0000e+00,  9.3781e-02,  1.0702e-01,  2.1384e+00,
-        -3.3166e-01,  0.0000e+00,  0.0000e+00, -1.8115e-03, -5.1305e-02,
-         5.6134e-02,  0.0000e+00, -1.8903e-01, -4.7747e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.1564e-01,  0.0000e+00,  0.0000e+00,
-        -9.9353e-03, -1.0878e-01,  0.0000e+00, -9.0229e-01,  1.3225e-02,
-         0.0000e+00,  2.1973e-01, -3.1266e-02, -2.8252e-01,  0.0000e+00,
-        -3.6999e-01, -1.1511e+00,  0.0000e+00, -8.8463e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -4.2173e-02,  6.3065e-01,
-        -1.6614e-01,  0.0000e+00, -2.3213e-01,  0.0000e+00,  0.0000e+00,
-        -2.6745e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -8.4654e-02, -2.9846e-01,  1.5209e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.1289e-01,  0.0000e+00,
-         0.0000e+00,  4.0104e-03,  0.0000e+00, -8.6307e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.1795e-02,  0.0000e+00,  9.3781e-02,  1.0702e-01,  2.1384e+00,
-        -3.3166e-01,  0.0000e+00,  0.0000e+00, -1.8115e-03, -5.1305e-02,
-         5.6134e-02,  0.0000e+00, -1.8903e-01, -4.7747e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.1564e-01,  0.0000e+00,  0.0000e+00,
-        -9.9353e-03, -1.0878e-01,  0.0000e+00, -9.0229e-01,  1.3225e-02,
-         0.0000e+00,  2.1973e-01, -3.1266e-02, -2.8252e-01,  0.0000e+00,
-        -3.6999e-01, -1.1511e+00,  0.0000e+00, -8.8463e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -4.2173e-02,  6.3065e-01,
-        -1.6614e-01,  0.0000e+00, -2.3213e-01,  0.0000e+00,  0.0000e+00,
-        -2.6745e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -8.4654e-02, -2.9846e-01,  1.5209e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  6.1289e-01,  0.0000e+00,
-         0.0000e+00,  4.0104e-03,  0.0000e+00, -8.6307e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 7.6617e-02, -6.5478e-11,  8.1871e-02,  7.8570e-02,  2.1369e+00,
-        -3.1530e-01,  3.9635e-20, -3.1297e-13, -4.6049e-03, -4.0404e-02,
-         8.0625e-02,  7.1253e-12, -2.0550e-01, -4.7474e-01, -7.1077e-17,
-        -5.2827e-14, -5.1964e-15, -2.2200e-01, -1.2019e-17,  1.4451e-13,
-        -2.5884e-02, -1.0994e-01, -2.5524e-17, -8.9622e-01,  3.2426e-02,
-         2.0715e-13,  1.8764e-01, -3.6660e-02, -2.6669e-01,  3.7120e-17,
-        -3.6852e-01, -1.1529e+00, -9.5021e-21, -1.0279e-01,  4.7164e-15,
-         2.5448e-13,  1.5733e-16,  0.0000e+00, -4.3549e-02,  6.2225e-01,
-        -1.7442e-01, -2.1014e-14, -2.3692e-01,  3.8720e-17,  9.4704e-15,
-        -2.2162e-02, -1.3363e-12,  4.3117e-06,  4.4951e-12, -2.4650e-18,
-        -7.9755e-09, -7.7465e-02, -3.0937e-01,  9.6399e-02, -9.0329e-10,
-         5.1300e-22, -2.0162e-16,  7.2509e-20,  6.0223e-01, -1.0472e-11,
-         1.9369e-19, -3.0259e-02,  1.6274e-07, -7.3742e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0766,  0.0000,  0.0819,  0.0786,  2.1369, -0.3153,  0.0000,  0.0000,
-        -0.0046, -0.0404,  0.0806,  0.0000, -0.2055, -0.4747,  0.0000,  0.0000,
-         0.0000, -0.2220,  0.0000,  0.0000, -0.0259, -0.1099,  0.0000, -0.8962,
-         0.0324,  0.0000,  0.1876, -0.0367, -0.2667,  0.0000, -0.3685, -1.1529,
-         0.0000, -0.1028,  0.0000,  0.0000,  0.0000,  0.0000, -0.0435,  0.6223,
-        -0.1744,  0.0000, -0.2369,  0.0000,  0.0000, -0.0222,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0775, -0.3094,  0.0964,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6022,  0.0000,  0.0000, -0.0303,  0.0000, -0.0737],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0766,  0.0000,  0.0819,  0.0786,  2.1369, -0.3153,  0.0000,  0.0000,
-        -0.0046, -0.0404,  0.0806,  0.0000, -0.2055, -0.4747,  0.0000,  0.0000,
-         0.0000, -0.2220,  0.0000,  0.0000, -0.0259, -0.1099,  0.0000, -0.8962,
-         0.0324,  0.0000,  0.1876, -0.0367, -0.2667,  0.0000, -0.3685, -1.1529,
-         0.0000, -0.1028,  0.0000,  0.0000,  0.0000,  0.0000, -0.0435,  0.6223,
-        -0.1744,  0.0000, -0.2369,  0.0000,  0.0000, -0.0222,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0775, -0.3094,  0.0964,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.6022,  0.0000,  0.0000, -0.0303,  0.0000, -0.0737],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.6938e-02, -5.9933e-11,  6.1463e-02,  4.6548e-02,  2.1357e+00,
-        -3.0125e-01,  3.6279e-20, -2.8646e-13, -9.6954e-03, -3.8910e-02,
-         8.9126e-02,  6.5219e-12, -2.2150e-01, -4.6246e-01, -6.5058e-17,
-        -4.8353e-14, -4.7563e-15, -2.3334e-01, -1.1001e-17,  1.3227e-13,
-        -3.8196e-02, -9.9417e-02, -2.3362e-17, -8.8880e-01,  4.9475e-02,
-         1.8961e-13,  1.4863e-01, -3.7596e-02, -2.5014e-01,  3.3977e-17,
-        -3.6574e-01, -1.1539e+00, -8.6973e-21, -1.2601e-01,  4.3169e-15,
-         2.3293e-13,  1.4400e-16,  0.0000e+00, -3.4922e-02,  6.1636e-01,
-        -1.6892e-01, -1.9234e-14, -2.4364e-01,  3.5440e-17,  8.6684e-15,
-        -2.1936e-02, -1.2231e-12,  3.9466e-06,  4.1145e-12, -2.2562e-18,
-        -7.3001e-09, -7.0178e-02, -3.1832e-01,  1.7714e-02, -8.2679e-10,
-         4.6956e-22, -1.8454e-16,  6.6369e-20,  5.9207e-01, -9.5851e-12,
-         1.7729e-19, -7.1038e-02,  1.4895e-07, -6.7960e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0669,  0.0000,  0.0615,  0.0465,  2.1357, -0.3013,  0.0000,  0.0000,
-        -0.0097, -0.0389,  0.0891,  0.0000, -0.2215, -0.4625,  0.0000,  0.0000,
-         0.0000, -0.2333,  0.0000,  0.0000, -0.0382, -0.0994,  0.0000, -0.8888,
-         0.0495,  0.0000,  0.1486, -0.0376, -0.2501,  0.0000, -0.3657, -1.1539,
-         0.0000, -0.1260,  0.0000,  0.0000,  0.0000,  0.0000, -0.0349,  0.6164,
-        -0.1689,  0.0000, -0.2436,  0.0000,  0.0000, -0.0219,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0702, -0.3183,  0.0177,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5921,  0.0000,  0.0000, -0.0710,  0.0000, -0.0680],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0669,  0.0000,  0.0615,  0.0465,  2.1357, -0.3013,  0.0000,  0.0000,
-        -0.0097, -0.0389,  0.0891,  0.0000, -0.2215, -0.4625,  0.0000,  0.0000,
-         0.0000, -0.2333,  0.0000,  0.0000, -0.0382, -0.0994,  0.0000, -0.8888,
-         0.0495,  0.0000,  0.1486, -0.0376, -0.2501,  0.0000, -0.3657, -1.1539,
-         0.0000, -0.1260,  0.0000,  0.0000,  0.0000,  0.0000, -0.0349,  0.6164,
-        -0.1689,  0.0000, -0.2436,  0.0000,  0.0000, -0.0219,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0702, -0.3183,  0.0177,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5921,  0.0000,  0.0000, -0.0710,  0.0000, -0.0680],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.3507e-02, -5.4865e-11,  4.3173e-02,  5.5495e-03,  2.1346e+00,
-        -2.9161e-01,  3.3211e-20, -2.6224e-13, -1.3052e-02, -4.9939e-02,
-         9.6492e-02,  5.9704e-12, -2.2917e-01, -4.4722e-01, -5.9557e-17,
-        -4.4265e-14, -4.3541e-15, -2.5368e-01, -1.0071e-17,  1.2108e-13,
-        -4.0178e-02, -8.7360e-02, -2.1387e-17, -8.7857e-01,  5.9436e-02,
-         1.7358e-13,  9.6969e-02, -2.9862e-02, -2.2555e-01,  3.1104e-17,
-        -3.7213e-01, -1.1565e+00, -7.9619e-21, -1.5184e-01,  3.9519e-15,
-         2.1323e-13,  1.3183e-16,  0.0000e+00, -2.0449e-02,  6.1691e-01,
-        -1.4920e-01, -1.7608e-14, -2.4758e-01,  3.2444e-17,  7.9354e-15,
-        -3.9224e-02, -1.1197e-12,  3.6129e-06,  3.7666e-12, -2.0654e-18,
-        -6.6828e-09, -5.0729e-02, -3.1867e-01, -7.4174e-02, -7.5688e-10,
-         4.2985e-22, -1.6894e-16,  6.0757e-20,  5.7622e-01, -8.7746e-12,
-         1.6230e-19, -1.1159e-01,  1.3636e-07, -6.9165e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0635,  0.0000,  0.0432,  0.0055,  2.1346, -0.2916,  0.0000,  0.0000,
-        -0.0131, -0.0499,  0.0965,  0.0000, -0.2292, -0.4472,  0.0000,  0.0000,
-         0.0000, -0.2537,  0.0000,  0.0000, -0.0402, -0.0874,  0.0000, -0.8786,
-         0.0594,  0.0000,  0.0970, -0.0299, -0.2256,  0.0000, -0.3721, -1.1565,
-         0.0000, -0.1518,  0.0000,  0.0000,  0.0000,  0.0000, -0.0204,  0.6169,
-        -0.1492,  0.0000, -0.2476,  0.0000,  0.0000, -0.0392,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0507, -0.3187, -0.0742,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5762,  0.0000,  0.0000, -0.1116,  0.0000, -0.0692],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0635,  0.0000,  0.0432,  0.0055,  2.1346, -0.2916,  0.0000,  0.0000,
-        -0.0131, -0.0499,  0.0965,  0.0000, -0.2292, -0.4472,  0.0000,  0.0000,
-         0.0000, -0.2537,  0.0000,  0.0000, -0.0402, -0.0874,  0.0000, -0.8786,
-         0.0594,  0.0000,  0.0970, -0.0299, -0.2256,  0.0000, -0.3721, -1.1565,
-         0.0000, -0.1518,  0.0000,  0.0000,  0.0000,  0.0000, -0.0204,  0.6169,
-        -0.1492,  0.0000, -0.2476,  0.0000,  0.0000, -0.0392,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0507, -0.3187, -0.0742,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5762,  0.0000,  0.0000, -0.1116,  0.0000, -0.0692],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 6.2678e-02, -5.0233e-11,  2.6497e-02, -2.3576e-02,  2.1333e+00,
-        -2.8006e-01,  3.0407e-20, -2.4010e-13, -1.9638e-02, -4.4359e-02,
-         1.0088e-01,  5.4663e-12, -2.3938e-01, -4.3522e-01, -5.4528e-17,
-        -4.0528e-14, -3.9865e-15, -2.6727e-01, -9.2205e-18,  1.1086e-13,
-        -4.2476e-02, -7.5340e-02, -1.9581e-17, -8.7837e-01,  6.6038e-02,
-         1.5892e-13,  5.5152e-02, -2.0502e-02, -1.9720e-01,  2.8478e-17,
-        -3.7086e-01, -1.1595e+00, -7.2897e-21, -1.8150e-01,  3.6183e-15,
-         1.9523e-13,  1.2070e-16,  0.0000e+00,  2.6202e-03,  6.2035e-01,
-        -1.2951e-01, -1.6121e-14, -2.4932e-01,  2.9705e-17,  7.2654e-15,
-        -5.1791e-02, -1.0251e-12,  3.3078e-06,  3.4485e-12, -1.8911e-18,
-        -6.1186e-09, -4.7629e-02, -3.1592e-01, -1.3980e-01, -6.9298e-10,
-         3.9356e-22, -1.5468e-16,  5.5627e-20,  5.6041e-01, -8.0338e-12,
-         1.4859e-19, -1.4523e-01,  1.2485e-07, -7.1836e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.0627,  0.0000,  0.0265, -0.0236,  2.1333, -0.2801,  0.0000,  0.0000,
-        -0.0196, -0.0444,  0.1009,  0.0000, -0.2394, -0.4352,  0.0000,  0.0000,
-         0.0000, -0.2673,  0.0000,  0.0000, -0.0425, -0.0753,  0.0000, -0.8784,
-         0.0660,  0.0000,  0.0552, -0.0205, -0.1972,  0.0000, -0.3709, -1.1595,
-         0.0000, -0.1815,  0.0000,  0.0000,  0.0000,  0.0000,  0.0026,  0.6204,
-        -0.1295,  0.0000, -0.2493,  0.0000,  0.0000, -0.0518,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0476, -0.3159, -0.1398,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5604,  0.0000,  0.0000, -0.1452,  0.0000, -0.0718],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.0627,  0.0000,  0.0265, -0.0236,  2.1333, -0.2801,  0.0000,  0.0000,
-        -0.0196, -0.0444,  0.1009,  0.0000, -0.2394, -0.4352,  0.0000,  0.0000,
-         0.0000, -0.2673,  0.0000,  0.0000, -0.0425, -0.0753,  0.0000, -0.8784,
-         0.0660,  0.0000,  0.0552, -0.0205, -0.1972,  0.0000, -0.3709, -1.1595,
-         0.0000, -0.1815,  0.0000,  0.0000,  0.0000,  0.0000,  0.0026,  0.6204,
-        -0.1295,  0.0000, -0.2493,  0.0000,  0.0000, -0.0518,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0476, -0.3159, -0.1398,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5604,  0.0000,  0.0000, -0.1452,  0.0000, -0.0718],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 7.2268e-02, -4.5998e-11,  1.0641e-02, -3.4315e-02,  2.1319e+00,
-        -2.6252e-01,  2.7843e-20, -2.1986e-13, -1.8884e-02, -3.3902e-02,
-         8.7996e-02,  5.0055e-12, -2.4638e-01, -4.2388e-01, -4.9931e-17,
-        -3.7111e-14, -3.6504e-15, -2.6899e-01, -8.4431e-18,  1.0151e-13,
-        -4.1280e-02, -5.0595e-02, -1.7930e-17, -8.7522e-01,  8.1586e-02,
-         1.4552e-13,  2.3555e-02, -1.3291e-03, -1.6956e-01,  2.6077e-17,
-        -3.6784e-01, -1.1636e+00, -6.6751e-21, -2.0497e-01,  3.3132e-15,
-         1.7877e-13,  1.1052e-16,  0.0000e+00,  2.0155e-02,  6.2339e-01,
-        -1.1199e-01, -1.4762e-14, -2.5498e-01,  2.7200e-17,  6.6529e-15,
-        -6.0775e-02, -9.3871e-13,  3.0289e-06,  3.1578e-12, -1.7316e-18,
-        -5.6027e-09, -3.4230e-02, -3.2406e-01, -1.6817e-01, -6.3456e-10,
-         3.6038e-22, -1.4164e-16,  5.0937e-20,  5.4433e-01, -7.3564e-12,
-         1.3607e-19, -1.6799e-01,  1.1432e-07, -7.8988e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 7.2268e-02,  0.0000e+00,  1.0641e-02, -3.4315e-02,  2.1319e+00,
-        -2.6252e-01,  0.0000e+00,  0.0000e+00, -1.8884e-02, -3.3902e-02,
-         8.7996e-02,  0.0000e+00, -2.4638e-01, -4.2388e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6899e-01,  0.0000e+00,  0.0000e+00,
-        -4.1280e-02, -5.0595e-02,  0.0000e+00, -8.7522e-01,  8.1586e-02,
-         0.0000e+00,  2.3555e-02, -1.3291e-03, -1.6956e-01,  0.0000e+00,
-        -3.6784e-01, -1.1636e+00,  0.0000e+00, -2.0497e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0155e-02,  6.2339e-01,
-        -1.1199e-01,  0.0000e+00, -2.5498e-01,  0.0000e+00,  0.0000e+00,
-        -6.0775e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -3.4230e-02, -3.2406e-01, -1.6817e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4433e-01,  0.0000e+00,
-         0.0000e+00, -1.6799e-01,  0.0000e+00, -7.8988e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 7.2268e-02,  0.0000e+00,  1.0641e-02, -3.4315e-02,  2.1319e+00,
-        -2.6252e-01,  0.0000e+00,  0.0000e+00, -1.8884e-02, -3.3902e-02,
-         8.7996e-02,  0.0000e+00, -2.4638e-01, -4.2388e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6899e-01,  0.0000e+00,  0.0000e+00,
-        -4.1280e-02, -5.0595e-02,  0.0000e+00, -8.7522e-01,  8.1586e-02,
-         0.0000e+00,  2.3555e-02, -1.3291e-03, -1.6956e-01,  0.0000e+00,
-        -3.6784e-01, -1.1636e+00,  0.0000e+00, -2.0497e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0155e-02,  6.2339e-01,
-        -1.1199e-01,  0.0000e+00, -2.5498e-01,  0.0000e+00,  0.0000e+00,
-        -6.0775e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -3.4230e-02, -3.2406e-01, -1.6817e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4433e-01,  0.0000e+00,
-         0.0000e+00, -1.6799e-01,  0.0000e+00, -7.8988e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 8.5388e-02, -4.2125e-11,  9.5886e-04, -3.9037e-02,  2.1305e+00,
-        -2.4920e-01,  2.5499e-20, -2.0135e-13, -1.6963e-02, -5.2321e-03,
-         8.6718e-02,  4.5841e-12, -2.5206e-01, -4.1977e-01, -4.5727e-17,
-        -3.3986e-14, -3.3431e-15, -2.6092e-01, -7.7323e-18,  9.2968e-14,
-        -4.5885e-02, -2.1350e-02, -1.6420e-17, -8.7672e-01,  9.1119e-02,
-         1.3327e-13,  3.8002e-03,  7.9093e-03, -1.5021e-01,  2.3881e-17,
-        -3.6437e-01, -1.1689e+00, -6.1131e-21, -2.2313e-01,  3.0343e-15,
-         1.6372e-13,  1.0122e-16,  0.0000e+00,  5.7679e-02,  6.2209e-01,
-        -1.0363e-01, -1.3519e-14, -2.5795e-01,  2.4910e-17,  6.0928e-15,
-        -7.2189e-02, -8.5967e-13,  2.7739e-06,  2.8919e-12, -1.5858e-18,
-        -5.1310e-09, -2.7932e-02, -3.2513e-01, -1.7361e-01, -5.8113e-10,
-         3.3004e-22, -1.2971e-16,  4.6649e-20,  5.2857e-01, -6.7371e-12,
-         1.2461e-19, -1.8695e-01,  1.0470e-07, -7.2964e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 8.5388e-02,  0.0000e+00,  9.5886e-04, -3.9037e-02,  2.1305e+00,
-        -2.4920e-01,  0.0000e+00,  0.0000e+00, -1.6963e-02, -5.2321e-03,
-         8.6718e-02,  0.0000e+00, -2.5206e-01, -4.1977e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6092e-01,  0.0000e+00,  0.0000e+00,
-        -4.5885e-02, -2.1350e-02,  0.0000e+00, -8.7672e-01,  9.1119e-02,
-         0.0000e+00,  3.8002e-03,  7.9093e-03, -1.5021e-01,  0.0000e+00,
-        -3.6437e-01, -1.1689e+00,  0.0000e+00, -2.2313e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.7679e-02,  6.2209e-01,
-        -1.0363e-01,  0.0000e+00, -2.5795e-01,  0.0000e+00,  0.0000e+00,
-        -7.2189e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -2.7932e-02, -3.2513e-01, -1.7361e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.2857e-01,  0.0000e+00,
-         0.0000e+00, -1.8695e-01,  0.0000e+00, -7.2964e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 8.5388e-02,  0.0000e+00,  9.5886e-04, -3.9037e-02,  2.1305e+00,
-        -2.4920e-01,  0.0000e+00,  0.0000e+00, -1.6963e-02, -5.2321e-03,
-         8.6718e-02,  0.0000e+00, -2.5206e-01, -4.1977e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.6092e-01,  0.0000e+00,  0.0000e+00,
-        -4.5885e-02, -2.1350e-02,  0.0000e+00, -8.7672e-01,  9.1119e-02,
-         0.0000e+00,  3.8002e-03,  7.9093e-03, -1.5021e-01,  0.0000e+00,
-        -3.6437e-01, -1.1689e+00,  0.0000e+00, -2.2313e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.7679e-02,  6.2209e-01,
-        -1.0363e-01,  0.0000e+00, -2.5795e-01,  0.0000e+00,  0.0000e+00,
-        -7.2189e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -2.7932e-02, -3.2513e-01, -1.7361e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.2857e-01,  0.0000e+00,
-         0.0000e+00, -1.8695e-01,  0.0000e+00, -7.2964e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 1.0531e-01, -3.8583e-11, -5.4389e-03, -3.6414e-02,  2.1289e+00,
-        -2.3607e-01,  2.3355e-20, -1.8441e-13, -1.2856e-02,  1.9750e-02,
-         1.0177e-01,  4.1986e-12, -2.5550e-01, -4.1275e-01, -4.1882e-17,
-        -3.1128e-14, -3.0620e-15, -2.5575e-01, -7.0821e-18,  8.5150e-14,
-        -3.9874e-02,  3.1199e-03, -1.5040e-17, -8.7837e-01,  9.9017e-02,
-         1.2206e-13, -1.7523e-02,  2.3170e-02, -1.2691e-01,  2.1873e-17,
-        -3.7357e-01, -1.1745e+00, -5.5991e-21, -2.3933e-01,  2.7791e-15,
-         1.4995e-13,  9.2705e-17,  0.0000e+00,  1.0245e-01,  6.2703e-01,
-        -8.1577e-02, -1.2382e-14, -2.6015e-01,  2.2816e-17,  5.5804e-15,
-        -8.9105e-02, -7.8739e-13,  2.5407e-06,  2.6488e-12, -1.4525e-18,
-        -4.6996e-09, -1.4716e-02, -3.2407e-01, -1.8132e-01, -5.3227e-10,
-         3.0229e-22, -1.1880e-16,  4.2726e-20,  5.0799e-01, -6.1706e-12,
-         1.1413e-19, -1.9706e-01,  9.5893e-08, -6.1009e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1053,  0.0000, -0.0054, -0.0364,  2.1289, -0.2361,  0.0000,  0.0000,
-        -0.0129,  0.0198,  0.1018,  0.0000, -0.2555, -0.4127,  0.0000,  0.0000,
-         0.0000, -0.2558,  0.0000,  0.0000, -0.0399,  0.0031,  0.0000, -0.8784,
-         0.0990,  0.0000, -0.0175,  0.0232, -0.1269,  0.0000, -0.3736, -1.1745,
-         0.0000, -0.2393,  0.0000,  0.0000,  0.0000,  0.0000,  0.1025,  0.6270,
-        -0.0816,  0.0000, -0.2601,  0.0000,  0.0000, -0.0891,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0147, -0.3241, -0.1813,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5080,  0.0000,  0.0000, -0.1971,  0.0000, -0.0610],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1053,  0.0000, -0.0054, -0.0364,  2.1289, -0.2361,  0.0000,  0.0000,
-        -0.0129,  0.0198,  0.1018,  0.0000, -0.2555, -0.4127,  0.0000,  0.0000,
-         0.0000, -0.2558,  0.0000,  0.0000, -0.0399,  0.0031,  0.0000, -0.8784,
-         0.0990,  0.0000, -0.0175,  0.0232, -0.1269,  0.0000, -0.3736, -1.1745,
-         0.0000, -0.2393,  0.0000,  0.0000,  0.0000,  0.0000,  0.1025,  0.6270,
-        -0.0816,  0.0000, -0.2601,  0.0000,  0.0000, -0.0891,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0147, -0.3241, -0.1813,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.5080,  0.0000,  0.0000, -0.1971,  0.0000, -0.0610],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.2991e-01, -3.5343e-11, -1.3292e-02, -1.6000e-02,  2.1276e+00,
-        -2.2168e-01,  2.1394e-20, -1.6893e-13,  2.9234e-03,  6.2336e-02,
-         1.1895e-01,  3.8460e-12, -2.6316e-01, -4.1309e-01, -3.8365e-17,
-        -2.8514e-14, -2.8048e-15, -2.3212e-01, -6.4873e-18,  7.7999e-14,
-        -2.5214e-02,  2.9430e-02, -1.3777e-17, -8.7864e-01,  1.1061e-01,
-         1.1181e-13, -3.4542e-02,  5.3287e-02, -1.0590e-01,  2.0036e-17,
-        -3.8517e-01, -1.1800e+00, -5.1288e-21, -2.4183e-01,  2.5457e-15,
-         1.3736e-13,  8.4919e-17,  0.0000e+00,  1.3589e-01,  6.2837e-01,
-        -6.4070e-02, -1.1342e-14, -2.6718e-01,  2.0899e-17,  5.1118e-15,
-        -1.2432e-01, -7.2126e-13,  2.3273e-06,  2.4263e-12, -1.3305e-18,
-        -4.3049e-09,  1.3193e-02, -3.3216e-01, -1.4112e-01, -4.8756e-10,
-         2.7690e-22, -1.0883e-16,  3.9138e-20,  4.9321e-01, -5.6523e-12,
-         1.0455e-19, -1.9811e-01,  8.7839e-08, -5.0117e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1299,  0.0000, -0.0133, -0.0160,  2.1276, -0.2217,  0.0000,  0.0000,
-         0.0029,  0.0623,  0.1189,  0.0000, -0.2632, -0.4131,  0.0000,  0.0000,
-         0.0000, -0.2321,  0.0000,  0.0000, -0.0252,  0.0294,  0.0000, -0.8786,
-         0.1106,  0.0000, -0.0345,  0.0533, -0.1059,  0.0000, -0.3852, -1.1800,
-         0.0000, -0.2418,  0.0000,  0.0000,  0.0000,  0.0000,  0.1359,  0.6284,
-        -0.0641,  0.0000, -0.2672,  0.0000,  0.0000, -0.1243,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0132, -0.3322, -0.1411,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4932,  0.0000,  0.0000, -0.1981,  0.0000, -0.0501],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1299,  0.0000, -0.0133, -0.0160,  2.1276, -0.2217,  0.0000,  0.0000,
-         0.0029,  0.0623,  0.1189,  0.0000, -0.2632, -0.4131,  0.0000,  0.0000,
-         0.0000, -0.2321,  0.0000,  0.0000, -0.0252,  0.0294,  0.0000, -0.8786,
-         0.1106,  0.0000, -0.0345,  0.0533, -0.1059,  0.0000, -0.3852, -1.1800,
-         0.0000, -0.2418,  0.0000,  0.0000,  0.0000,  0.0000,  0.1359,  0.6284,
-        -0.0641,  0.0000, -0.2672,  0.0000,  0.0000, -0.1243,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0132, -0.3322, -0.1411,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4932,  0.0000,  0.0000, -0.1981,  0.0000, -0.0501],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.6119e-01, -3.2378e-11, -1.3247e-02,  2.1350e-02,  2.1265e+00,
-        -2.0430e-01,  1.9599e-20, -1.5476e-13,  2.0546e-02,  1.1199e-01,
-         1.4149e-01,  3.5233e-12, -2.6821e-01, -4.1558e-01, -3.5146e-17,
-        -2.6122e-14, -2.5695e-15, -2.0677e-01, -5.9431e-18,  7.1456e-14,
-        -4.1493e-03,  5.0063e-02, -1.2621e-17, -8.8213e-01,  1.1772e-01,
-         1.0243e-13, -4.2162e-02,  8.7606e-02, -8.4986e-02,  1.8355e-17,
-        -3.9485e-01, -1.1855e+00, -4.6986e-21, -2.4327e-01,  2.3322e-15,
-         1.2583e-13,  7.7795e-17,  0.0000e+00,  1.5602e-01,  6.2846e-01,
-        -4.5897e-02, -1.0391e-14, -2.6596e-01,  1.9146e-17,  4.6829e-15,
-        -1.5828e-01, -6.6075e-13,  2.1321e-06,  2.2228e-12, -1.2189e-18,
-        -3.9437e-09,  3.0760e-02, -3.4051e-01, -8.3875e-02, -4.4666e-10,
-         2.5367e-22, -9.9697e-17,  3.5855e-20,  4.8226e-01, -5.1782e-12,
-         9.5776e-20, -1.9304e-01,  8.0470e-08, -3.1624e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1612,  0.0000, -0.0132,  0.0214,  2.1265, -0.2043,  0.0000,  0.0000,
-         0.0205,  0.1120,  0.1415,  0.0000, -0.2682, -0.4156,  0.0000,  0.0000,
-         0.0000, -0.2068,  0.0000,  0.0000, -0.0041,  0.0501,  0.0000, -0.8821,
-         0.1177,  0.0000, -0.0422,  0.0876, -0.0850,  0.0000, -0.3948, -1.1855,
-         0.0000, -0.2433,  0.0000,  0.0000,  0.0000,  0.0000,  0.1560,  0.6285,
-        -0.0459,  0.0000, -0.2660,  0.0000,  0.0000, -0.1583,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0308, -0.3405, -0.0839,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4823,  0.0000,  0.0000, -0.1930,  0.0000, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1612,  0.0000, -0.0132,  0.0214,  2.1265, -0.2043,  0.0000,  0.0000,
-         0.0205,  0.1120,  0.1415,  0.0000, -0.2682, -0.4156,  0.0000,  0.0000,
-         0.0000, -0.2068,  0.0000,  0.0000, -0.0041,  0.0501,  0.0000, -0.8821,
-         0.1177,  0.0000, -0.0422,  0.0876, -0.0850,  0.0000, -0.3948, -1.1855,
-         0.0000, -0.2433,  0.0000,  0.0000,  0.0000,  0.0000,  0.1560,  0.6285,
-        -0.0459,  0.0000, -0.2660,  0.0000,  0.0000, -0.1583,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0308, -0.3405, -0.0839,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4823,  0.0000,  0.0000, -0.1930,  0.0000, -0.0316],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 1.8729e-01, -2.9665e-11, -1.4310e-02,  4.7926e-02,  2.1254e+00,
-        -1.8388e-01,  1.7956e-20, -1.4179e-13,  3.2952e-02,  1.5043e-01,
-         1.6229e-01,  3.2281e-12, -2.7787e-01, -4.2311e-01, -3.2201e-17,
-        -2.3933e-14, -2.3542e-15, -1.8170e-01, -5.4450e-18,  6.5468e-14,
-         1.6194e-02,  6.7957e-02, -1.1563e-17, -8.8113e-01,  1.2352e-01,
-         9.3849e-14, -4.9049e-02,  1.1706e-01, -6.4289e-02,  1.6817e-17,
-        -4.0237e-01, -1.1901e+00, -4.3048e-21, -2.3653e-01,  2.1367e-15,
-         1.1529e-13,  7.1276e-17,  0.0000e+00,  1.6296e-01,  6.2234e-01,
-        -3.7369e-02, -9.5201e-15, -2.6667e-01,  1.7542e-17,  4.2905e-15,
-        -1.8154e-01, -6.0538e-13,  1.9534e-06,  2.0365e-12, -1.1167e-18,
-        -3.6133e-09,  4.5636e-02, -3.4729e-01, -3.3345e-02, -4.0923e-10,
-         2.3241e-22, -9.1342e-17,  3.2850e-20,  4.7249e-01, -4.7442e-12,
-         8.7750e-20, -1.8797e-01,  7.3727e-08, -1.3423e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.1873,  0.0000, -0.0143,  0.0479,  2.1254, -0.1839,  0.0000,  0.0000,
-         0.0330,  0.1504,  0.1623,  0.0000, -0.2779, -0.4231,  0.0000,  0.0000,
-         0.0000, -0.1817,  0.0000,  0.0000,  0.0162,  0.0680,  0.0000, -0.8811,
-         0.1235,  0.0000, -0.0490,  0.1171, -0.0643,  0.0000, -0.4024, -1.1901,
-         0.0000, -0.2365,  0.0000,  0.0000,  0.0000,  0.0000,  0.1630,  0.6223,
-        -0.0374,  0.0000, -0.2667,  0.0000,  0.0000, -0.1815,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0456, -0.3473, -0.0333,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4725,  0.0000,  0.0000, -0.1880,  0.0000, -0.0134],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.1873,  0.0000, -0.0143,  0.0479,  2.1254, -0.1839,  0.0000,  0.0000,
-         0.0330,  0.1504,  0.1623,  0.0000, -0.2779, -0.4231,  0.0000,  0.0000,
-         0.0000, -0.1817,  0.0000,  0.0000,  0.0162,  0.0680,  0.0000, -0.8811,
-         0.1235,  0.0000, -0.0490,  0.1171, -0.0643,  0.0000, -0.4024, -1.1901,
-         0.0000, -0.2365,  0.0000,  0.0000,  0.0000,  0.0000,  0.1630,  0.6223,
-        -0.0374,  0.0000, -0.2667,  0.0000,  0.0000, -0.1815,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0456, -0.3473, -0.0333,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4725,  0.0000,  0.0000, -0.1880,  0.0000, -0.0134],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.1064e-01, -2.7181e-11, -1.7090e-02,  6.1898e-02,  2.1245e+00,
-        -1.6672e-01,  1.6453e-20, -1.2992e-13,  4.2495e-02,  1.7704e-01,
-         1.7838e-01,  2.9578e-12, -2.8816e-01, -4.2982e-01, -2.9505e-17,
-        -2.1929e-14, -2.1571e-15, -1.6110e-01, -4.9892e-18,  5.9987e-14,
-         3.4422e-02,  7.5510e-02, -1.0595e-17, -8.8204e-01,  1.1236e-01,
-         8.5992e-14, -4.7655e-02,  1.3437e-01, -5.2759e-02,  1.5409e-17,
-        -4.1108e-01, -1.1933e+00, -3.9444e-21, -2.2899e-01,  1.9578e-15,
-         1.0564e-13,  6.5309e-17,  0.0000e+00,  1.6950e-01,  6.1554e-01,
-        -3.5682e-02, -8.7231e-15, -2.6982e-01,  1.6073e-17,  3.9313e-15,
-        -2.0172e-01, -5.5470e-13,  1.7899e-06,  1.8660e-12, -1.0232e-18,
-        -3.3108e-09,  5.1494e-02, -3.4543e-01,  8.3509e-04, -3.7497e-10,
-         2.1295e-22, -8.3695e-17,  3.0100e-20,  4.6548e-01, -4.3470e-12,
-         8.0403e-20, -1.7846e-01,  6.7554e-08,  1.7027e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.1064e-01,  0.0000e+00, -1.7090e-02,  6.1898e-02,  2.1245e+00,
-        -1.6672e-01,  0.0000e+00,  0.0000e+00,  4.2495e-02,  1.7704e-01,
-         1.7838e-01,  0.0000e+00, -2.8816e-01, -4.2982e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6110e-01,  0.0000e+00,  0.0000e+00,
-         3.4422e-02,  7.5510e-02,  0.0000e+00, -8.8204e-01,  1.1236e-01,
-         0.0000e+00, -4.7655e-02,  1.3437e-01, -5.2759e-02,  0.0000e+00,
-        -4.1108e-01, -1.1933e+00,  0.0000e+00, -2.2899e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6950e-01,  6.1554e-01,
-        -3.5682e-02,  0.0000e+00, -2.6982e-01,  0.0000e+00,  0.0000e+00,
-        -2.0172e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.1494e-02, -3.4543e-01,  8.3509e-04,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6548e-01,  0.0000e+00,
-         0.0000e+00, -1.7846e-01,  0.0000e+00,  1.7027e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.1064e-01,  0.0000e+00, -1.7090e-02,  6.1898e-02,  2.1245e+00,
-        -1.6672e-01,  0.0000e+00,  0.0000e+00,  4.2495e-02,  1.7704e-01,
-         1.7838e-01,  0.0000e+00, -2.8816e-01, -4.2982e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.6110e-01,  0.0000e+00,  0.0000e+00,
-         3.4422e-02,  7.5510e-02,  0.0000e+00, -8.8204e-01,  1.1236e-01,
-         0.0000e+00, -4.7655e-02,  1.3437e-01, -5.2759e-02,  0.0000e+00,
-        -4.1108e-01, -1.1933e+00,  0.0000e+00, -2.2899e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6950e-01,  6.1554e-01,
-        -3.5682e-02,  0.0000e+00, -2.6982e-01,  0.0000e+00,  0.0000e+00,
-        -2.0172e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.1494e-02, -3.4543e-01,  8.3509e-04,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6548e-01,  0.0000e+00,
-         0.0000e+00, -1.7846e-01,  0.0000e+00,  1.7027e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 2.3743e-01, -2.4908e-11, -1.5570e-02,  6.4393e-02,  2.1239e+00,
-        -1.4270e-01,  1.5077e-20, -1.1905e-13,  5.1032e-02,  1.8997e-01,
-         1.9111e-01,  2.7104e-12, -2.9630e-01, -4.3265e-01, -2.7037e-17,
-        -2.0095e-14, -1.9767e-15, -1.5406e-01, -4.5719e-18,  5.4969e-14,
-         5.6851e-02,  7.8984e-02, -9.7090e-18, -8.8359e-01,  1.0068e-01,
-         7.8799e-14, -4.5478e-02,  1.5726e-01, -3.5843e-02,  1.4120e-17,
-        -4.2117e-01, -1.1961e+00, -3.6145e-21, -2.1787e-01,  1.7941e-15,
-         9.6801e-14,  5.9846e-17,  0.0000e+00,  1.6234e-01,  6.0937e-01,
-        -3.0819e-02, -7.9934e-15, -2.6953e-01,  1.4729e-17,  3.6025e-15,
-        -2.1821e-01, -5.0830e-13,  1.6401e-06,  1.7099e-12, -9.3766e-19,
-        -3.0338e-09,  6.5184e-02, -3.4123e-01,  2.1036e-02, -3.4361e-10,
-         1.9514e-22, -7.6694e-17,  2.7582e-20,  4.6362e-01, -3.9834e-12,
-         7.3678e-20, -1.7018e-01,  6.1904e-08,  3.1652e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2374,  0.0000, -0.0156,  0.0644,  2.1239, -0.1427,  0.0000,  0.0000,
-         0.0510,  0.1900,  0.1911,  0.0000, -0.2963, -0.4326,  0.0000,  0.0000,
-         0.0000, -0.1541,  0.0000,  0.0000,  0.0569,  0.0790,  0.0000, -0.8836,
-         0.1007,  0.0000, -0.0455,  0.1573, -0.0358,  0.0000, -0.4212, -1.1961,
-         0.0000, -0.2179,  0.0000,  0.0000,  0.0000,  0.0000,  0.1623,  0.6094,
-        -0.0308,  0.0000, -0.2695,  0.0000,  0.0000, -0.2182,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0652, -0.3412,  0.0210,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4636,  0.0000,  0.0000, -0.1702,  0.0000,  0.0317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2374,  0.0000, -0.0156,  0.0644,  2.1239, -0.1427,  0.0000,  0.0000,
-         0.0510,  0.1900,  0.1911,  0.0000, -0.2963, -0.4326,  0.0000,  0.0000,
-         0.0000, -0.1541,  0.0000,  0.0000,  0.0569,  0.0790,  0.0000, -0.8836,
-         0.1007,  0.0000, -0.0455,  0.1573, -0.0358,  0.0000, -0.4212, -1.1961,
-         0.0000, -0.2179,  0.0000,  0.0000,  0.0000,  0.0000,  0.1623,  0.6094,
-        -0.0308,  0.0000, -0.2695,  0.0000,  0.0000, -0.2182,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0652, -0.3412,  0.0210,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4636,  0.0000,  0.0000, -0.1702,  0.0000,  0.0317],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.6488e-01, -2.2826e-11, -1.0590e-02,  6.9818e-02,  2.1233e+00,
-        -1.1707e-01,  1.3817e-20, -1.0910e-13,  4.9659e-02,  2.1036e-01,
-         2.1109e-01,  2.4839e-12, -2.9897e-01, -4.3478e-01, -2.4778e-17,
-        -1.8416e-14, -1.8115e-15, -1.4978e-01, -4.1898e-18,  5.0375e-14,
-         7.6163e-02,  7.8269e-02, -8.8975e-18, -8.9004e-01,  8.7098e-02,
-         7.2213e-14, -3.4419e-02,  1.8296e-01, -2.5840e-02,  1.2940e-17,
-        -4.3015e-01, -1.1987e+00, -3.3124e-21, -2.1217e-01,  1.6441e-15,
-         8.8711e-14,  5.4844e-17,  0.0000e+00,  1.4588e-01,  6.0002e-01,
-        -2.3997e-02, -7.3253e-15, -2.6465e-01,  1.3498e-17,  3.3014e-15,
-        -2.1552e-01, -4.6582e-13,  1.5031e-06,  1.5670e-12, -8.5929e-19,
-        -2.7803e-09,  6.3127e-02, -3.2429e-01,  3.7251e-02, -3.1489e-10,
-         1.7883e-22, -7.0284e-17,  2.5277e-20,  4.6588e-01, -3.6505e-12,
-         6.7520e-20, -1.5533e-01,  5.6730e-08,  5.5068e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2649,  0.0000, -0.0106,  0.0698,  2.1233, -0.1171,  0.0000,  0.0000,
-         0.0497,  0.2104,  0.2111,  0.0000, -0.2990, -0.4348,  0.0000,  0.0000,
-         0.0000, -0.1498,  0.0000,  0.0000,  0.0762,  0.0783,  0.0000, -0.8900,
-         0.0871,  0.0000, -0.0344,  0.1830, -0.0258,  0.0000, -0.4302, -1.1987,
-         0.0000, -0.2122,  0.0000,  0.0000,  0.0000,  0.0000,  0.1459,  0.6000,
-        -0.0240,  0.0000, -0.2646,  0.0000,  0.0000, -0.2155,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0631, -0.3243,  0.0373,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4659,  0.0000,  0.0000, -0.1553,  0.0000,  0.0551],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2649,  0.0000, -0.0106,  0.0698,  2.1233, -0.1171,  0.0000,  0.0000,
-         0.0497,  0.2104,  0.2111,  0.0000, -0.2990, -0.4348,  0.0000,  0.0000,
-         0.0000, -0.1498,  0.0000,  0.0000,  0.0762,  0.0783,  0.0000, -0.8900,
-         0.0871,  0.0000, -0.0344,  0.1830, -0.0258,  0.0000, -0.4302, -1.1987,
-         0.0000, -0.2122,  0.0000,  0.0000,  0.0000,  0.0000,  0.1459,  0.6000,
-        -0.0240,  0.0000, -0.2646,  0.0000,  0.0000, -0.2155,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0631, -0.3243,  0.0373,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4659,  0.0000,  0.0000, -0.1553,  0.0000,  0.0551],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8760e-01, -2.0919e-11, -1.8564e-03,  8.1031e-02,  2.1229e+00,
-        -8.9586e-02,  1.2663e-20, -9.9988e-14,  4.0205e-02,  2.2851e-01,
-         2.1765e-01,  2.2764e-12, -2.9845e-01, -4.3250e-01, -2.2708e-17,
-        -1.6878e-14, -1.6602e-15, -1.5500e-01, -3.8398e-18,  4.6168e-14,
-         9.8736e-02,  7.3316e-02, -8.1544e-18, -8.9838e-01,  7.2188e-02,
-         6.6182e-14, -1.6430e-02,  2.0759e-01, -2.0161e-02,  1.1859e-17,
-        -4.3503e-01, -1.2022e+00, -3.0358e-21, -2.0397e-01,  1.5068e-15,
-         8.1302e-14,  5.0264e-17,  0.0000e+00,  1.3713e-01,  5.9293e-01,
-        -2.5686e-02, -6.7135e-15, -2.6156e-01,  1.2370e-17,  3.0257e-15,
-        -2.0058e-01, -4.2691e-13,  1.3775e-06,  1.4361e-12, -7.8752e-19,
-        -2.5481e-09,  5.6708e-02, -3.0576e-01,  6.0895e-02, -2.8859e-10,
-         1.6390e-22, -6.4414e-17,  2.3166e-20,  4.6811e-01, -3.3456e-12,
-         6.1881e-20, -1.3988e-01,  5.1992e-08,  7.3246e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 2.8760e-01,  0.0000e+00, -1.8564e-03,  8.1031e-02,  2.1229e+00,
-        -8.9586e-02,  0.0000e+00,  0.0000e+00,  4.0205e-02,  2.2851e-01,
-         2.1765e-01,  0.0000e+00, -2.9845e-01, -4.3250e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.5500e-01,  0.0000e+00,  0.0000e+00,
-         9.8736e-02,  7.3316e-02,  0.0000e+00, -8.9838e-01,  7.2188e-02,
-         0.0000e+00, -1.6430e-02,  2.0759e-01, -2.0161e-02,  0.0000e+00,
-        -4.3503e-01, -1.2022e+00,  0.0000e+00, -2.0397e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3713e-01,  5.9293e-01,
-        -2.5686e-02,  0.0000e+00, -2.6156e-01,  0.0000e+00,  0.0000e+00,
-        -2.0058e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.6708e-02, -3.0576e-01,  6.0895e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6811e-01,  0.0000e+00,
-         0.0000e+00, -1.3988e-01,  0.0000e+00,  7.3246e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 2.8760e-01,  0.0000e+00, -1.8564e-03,  8.1031e-02,  2.1229e+00,
-        -8.9586e-02,  0.0000e+00,  0.0000e+00,  4.0205e-02,  2.2851e-01,
-         2.1765e-01,  0.0000e+00, -2.9845e-01, -4.3250e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -1.5500e-01,  0.0000e+00,  0.0000e+00,
-         9.8736e-02,  7.3316e-02,  0.0000e+00, -8.9838e-01,  7.2188e-02,
-         0.0000e+00, -1.6430e-02,  2.0759e-01, -2.0161e-02,  0.0000e+00,
-        -4.3503e-01, -1.2022e+00,  0.0000e+00, -2.0397e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3713e-01,  5.9293e-01,
-        -2.5686e-02,  0.0000e+00, -2.6156e-01,  0.0000e+00,  0.0000e+00,
-        -2.0058e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.6708e-02, -3.0576e-01,  6.0895e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6811e-01,  0.0000e+00,
-         0.0000e+00, -1.3988e-01,  0.0000e+00,  7.3246e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0861e-01, -1.9173e-11,  3.6097e-03,  9.1203e-02,  2.1222e+00,
-        -5.9552e-02,  1.1606e-20, -9.1642e-14,  2.4046e-02,  2.4126e-01,
-         2.2376e-01,  2.0864e-12, -3.0076e-01, -4.2984e-01, -2.0813e-17,
-        -1.5469e-14, -1.5216e-15, -1.6453e-01, -3.5193e-18,  4.2314e-14,
-         1.2224e-01,  7.2768e-02, -7.4738e-18, -9.0374e-01,  6.0192e-02,
-         6.0658e-14,  5.1806e-03,  2.3297e-01, -1.1880e-02,  1.0870e-17,
-        -4.4051e-01, -1.2048e+00, -2.7824e-21, -1.9254e-01,  1.3810e-15,
-         7.4516e-14,  4.6068e-17,  0.0000e+00,  1.1477e-01,  5.8233e-01,
-        -2.7829e-02, -6.1532e-15, -2.6212e-01,  1.1338e-17,  2.7731e-15,
-        -1.7758e-01, -3.9128e-13,  1.2626e-06,  1.3163e-12, -7.2179e-19,
-        -2.3354e-09,  5.1271e-02, -2.9037e-01,  8.1541e-02, -2.6450e-10,
-         1.5022e-22, -5.9038e-17,  2.1232e-20,  4.7216e-01, -3.0664e-12,
-         5.6716e-20, -1.1784e-01,  4.7652e-08,  9.3388e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3086,  0.0000,  0.0036,  0.0912,  2.1222, -0.0596,  0.0000,  0.0000,
-         0.0240,  0.2413,  0.2238,  0.0000, -0.3008, -0.4298,  0.0000,  0.0000,
-         0.0000, -0.1645,  0.0000,  0.0000,  0.1222,  0.0728,  0.0000, -0.9037,
-         0.0602,  0.0000,  0.0052,  0.2330, -0.0119,  0.0000, -0.4405, -1.2048,
-         0.0000, -0.1925,  0.0000,  0.0000,  0.0000,  0.0000,  0.1148,  0.5823,
-        -0.0278,  0.0000, -0.2621,  0.0000,  0.0000, -0.1776,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0513, -0.2904,  0.0815,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4722,  0.0000,  0.0000, -0.1178,  0.0000,  0.0934],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3086,  0.0000,  0.0036,  0.0912,  2.1222, -0.0596,  0.0000,  0.0000,
-         0.0240,  0.2413,  0.2238,  0.0000, -0.3008, -0.4298,  0.0000,  0.0000,
-         0.0000, -0.1645,  0.0000,  0.0000,  0.1222,  0.0728,  0.0000, -0.9037,
-         0.0602,  0.0000,  0.0052,  0.2330, -0.0119,  0.0000, -0.4405, -1.2048,
-         0.0000, -0.1925,  0.0000,  0.0000,  0.0000,  0.0000,  0.1148,  0.5823,
-        -0.0278,  0.0000, -0.2621,  0.0000,  0.0000, -0.1776,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0513, -0.2904,  0.0815,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4722,  0.0000,  0.0000, -0.1178,  0.0000,  0.0934],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2832e-01, -1.7574e-11,  1.4626e-02,  8.9968e-02,  2.1216e+00,
-        -3.4520e-02,  1.0638e-20, -8.3997e-14,  4.3057e-03,  2.5007e-01,
-         2.3039e-01,  1.9124e-12, -2.9691e-01, -4.3178e-01, -1.9077e-17,
-        -1.4178e-14, -1.3947e-15, -1.7841e-01, -3.2257e-18,  3.8784e-14,
-         1.3500e-01,  6.8073e-02, -6.8503e-18, -9.0655e-01,  4.2481e-02,
-         5.5598e-14,  3.0253e-02,  2.4845e-01, -1.3431e-02,  9.9628e-18,
-        -4.3919e-01, -1.2070e+00, -2.5503e-21, -1.7949e-01,  1.2658e-15,
-         6.8299e-14,  4.2225e-17,  0.0000e+00,  9.7051e-02,  5.7282e-01,
-        -4.1986e-02, -5.6399e-15, -2.5741e-01,  1.0392e-17,  2.5418e-15,
-        -1.5957e-01, -3.5864e-13,  1.1572e-06,  1.2065e-12, -6.6158e-19,
-        -2.1406e-09,  3.9283e-02, -2.7184e-01,  9.2667e-02, -2.4244e-10,
-         1.3769e-22, -5.4113e-17,  1.9461e-20,  4.7425e-01, -2.8106e-12,
-         5.1985e-20, -9.9447e-02,  4.3677e-08,  1.1818e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3283,  0.0000,  0.0146,  0.0900,  2.1216, -0.0345,  0.0000,  0.0000,
-         0.0043,  0.2501,  0.2304,  0.0000, -0.2969, -0.4318,  0.0000,  0.0000,
-         0.0000, -0.1784,  0.0000,  0.0000,  0.1350,  0.0681,  0.0000, -0.9065,
-         0.0425,  0.0000,  0.0303,  0.2485, -0.0134,  0.0000, -0.4392, -1.2070,
-         0.0000, -0.1795,  0.0000,  0.0000,  0.0000,  0.0000,  0.0971,  0.5728,
-        -0.0420,  0.0000, -0.2574,  0.0000,  0.0000, -0.1596,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0393, -0.2718,  0.0927,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4743,  0.0000,  0.0000, -0.0994,  0.0000,  0.1182],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3283,  0.0000,  0.0146,  0.0900,  2.1216, -0.0345,  0.0000,  0.0000,
-         0.0043,  0.2501,  0.2304,  0.0000, -0.2969, -0.4318,  0.0000,  0.0000,
-         0.0000, -0.1784,  0.0000,  0.0000,  0.1350,  0.0681,  0.0000, -0.9065,
-         0.0425,  0.0000,  0.0303,  0.2485, -0.0134,  0.0000, -0.4392, -1.2070,
-         0.0000, -0.1795,  0.0000,  0.0000,  0.0000,  0.0000,  0.0971,  0.5728,
-        -0.0420,  0.0000, -0.2574,  0.0000,  0.0000, -0.1596,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0393, -0.2718,  0.0927,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4743,  0.0000,  0.0000, -0.0994,  0.0000,  0.1182],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4188e-01, -1.6108e-11,  2.4281e-02,  7.7558e-02,  2.1210e+00,
-        -1.4899e-02,  9.7507e-21, -7.6993e-14, -1.5258e-02,  2.5409e-01,
-         2.4262e-01,  1.7529e-12, -2.9676e-01, -4.3655e-01, -1.7486e-17,
-        -1.2996e-14, -1.2784e-15, -1.8861e-01, -2.9568e-18,  3.5550e-14,
-         1.3922e-01,  5.9578e-02, -6.2791e-18, -9.0797e-01,  1.7404e-02,
-         5.0961e-14,  4.9174e-02,  2.5185e-01, -2.0437e-02,  9.1320e-18,
-        -4.3302e-01, -1.2096e+00, -2.3376e-21, -1.5695e-01,  1.1603e-15,
-         6.2604e-14,  3.8704e-17,  0.0000e+00,  8.2697e-02,  5.6413e-01,
-        -6.9713e-02, -5.1696e-15, -2.5345e-01,  9.5254e-18,  2.3298e-15,
-        -1.4961e-01, -3.2873e-13,  1.0607e-06,  1.1059e-12, -6.0641e-19,
-        -1.9621e-09,  1.8041e-02, -2.4882e-01,  8.3043e-02, -2.2222e-10,
-         1.2620e-22, -4.9600e-17,  1.7838e-20,  4.7757e-01, -2.5762e-12,
-         4.7650e-20, -8.3220e-02,  4.0035e-08,  1.3650e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3419,  0.0000,  0.0243,  0.0776,  2.1210, -0.0149,  0.0000,  0.0000,
-        -0.0153,  0.2541,  0.2426,  0.0000, -0.2968, -0.4365,  0.0000,  0.0000,
-         0.0000, -0.1886,  0.0000,  0.0000,  0.1392,  0.0596,  0.0000, -0.9080,
-         0.0174,  0.0000,  0.0492,  0.2519, -0.0204,  0.0000, -0.4330, -1.2096,
-         0.0000, -0.1570,  0.0000,  0.0000,  0.0000,  0.0000,  0.0827,  0.5641,
-        -0.0697,  0.0000, -0.2534,  0.0000,  0.0000, -0.1496,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0180, -0.2488,  0.0830,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4776,  0.0000,  0.0000, -0.0832,  0.0000,  0.1365],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3419,  0.0000,  0.0243,  0.0776,  2.1210, -0.0149,  0.0000,  0.0000,
-        -0.0153,  0.2541,  0.2426,  0.0000, -0.2968, -0.4365,  0.0000,  0.0000,
-         0.0000, -0.1886,  0.0000,  0.0000,  0.1392,  0.0596,  0.0000, -0.9080,
-         0.0174,  0.0000,  0.0492,  0.2519, -0.0204,  0.0000, -0.4330, -1.2096,
-         0.0000, -0.1570,  0.0000,  0.0000,  0.0000,  0.0000,  0.0827,  0.5641,
-        -0.0697,  0.0000, -0.2534,  0.0000,  0.0000, -0.1496,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0180, -0.2488,  0.0830,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4776,  0.0000,  0.0000, -0.0832,  0.0000,  0.1365],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5279e-01, -1.4766e-11,  3.1721e-02,  5.7192e-02,  2.1201e+00,
-         4.9814e-03,  8.9379e-21, -7.0575e-14, -3.3182e-02,  2.5181e-01,
-         2.4820e-01,  1.6068e-12, -2.9578e-01, -4.4083e-01, -1.6028e-17,
-        -1.1913e-14, -1.1718e-15, -2.0381e-01, -2.7103e-18,  3.2587e-14,
-         1.4161e-01,  4.9020e-02, -5.7557e-18, -9.0792e-01, -8.4028e-03,
-         4.6713e-14,  6.8413e-02,  2.5013e-01, -2.9449e-02,  8.3708e-18,
-        -4.2867e-01, -1.2125e+00, -2.1428e-21, -1.3591e-01,  1.0636e-15,
-         5.7386e-14,  3.5478e-17,  0.0000e+00,  7.4793e-02,  5.5590e-01,
-        -1.0408e-01, -4.7387e-15, -2.4997e-01,  8.7314e-18,  2.1356e-15,
-        -1.4318e-01, -3.0133e-13,  9.7231e-07,  1.0137e-12, -5.5586e-19,
-        -1.7985e-09, -1.6715e-03, -2.2677e-01,  5.7274e-02, -2.0370e-10,
-         1.1568e-22, -4.5466e-17,  1.6351e-20,  4.7850e-01, -2.3615e-12,
-         4.3678e-20, -7.2717e-02,  3.6698e-08,  1.4979e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5279e-01,  0.0000e+00,  3.1721e-02,  5.7192e-02,  2.1201e+00,
-         4.9814e-03,  0.0000e+00,  0.0000e+00, -3.3182e-02,  2.5181e-01,
-         2.4820e-01,  0.0000e+00, -2.9578e-01, -4.4083e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.0381e-01,  0.0000e+00,  0.0000e+00,
-         1.4161e-01,  4.9020e-02,  0.0000e+00, -9.0792e-01, -8.4028e-03,
-         0.0000e+00,  6.8413e-02,  2.5013e-01, -2.9449e-02,  0.0000e+00,
-        -4.2867e-01, -1.2125e+00,  0.0000e+00, -1.3591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.4793e-02,  5.5590e-01,
-        -1.0408e-01,  0.0000e+00, -2.4997e-01,  0.0000e+00,  0.0000e+00,
-        -1.4318e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -1.6715e-03, -2.2677e-01,  5.7274e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.7850e-01,  0.0000e+00,
-         0.0000e+00, -7.2717e-02,  0.0000e+00,  1.4979e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5279e-01,  0.0000e+00,  3.1721e-02,  5.7192e-02,  2.1201e+00,
-         4.9814e-03,  0.0000e+00,  0.0000e+00, -3.3182e-02,  2.5181e-01,
-         2.4820e-01,  0.0000e+00, -2.9578e-01, -4.4083e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.0381e-01,  0.0000e+00,  0.0000e+00,
-         1.4161e-01,  4.9020e-02,  0.0000e+00, -9.0792e-01, -8.4028e-03,
-         0.0000e+00,  6.8413e-02,  2.5013e-01, -2.9449e-02,  0.0000e+00,
-        -4.2867e-01, -1.2125e+00,  0.0000e+00, -1.3591e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  7.4793e-02,  5.5590e-01,
-        -1.0408e-01,  0.0000e+00, -2.4997e-01,  0.0000e+00,  0.0000e+00,
-        -1.4318e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00, -1.6715e-03, -2.2677e-01,  5.7274e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.7850e-01,  0.0000e+00,
-         0.0000e+00, -7.2717e-02,  0.0000e+00,  1.4979e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5778e-01, -1.3535e-11,  4.0296e-02,  3.9735e-02,  2.1188e+00,
-         1.9505e-02,  8.1930e-21, -6.4693e-14, -4.5695e-02,  2.4841e-01,
-         2.5081e-01,  1.4729e-12, -2.9121e-01, -4.4567e-01, -1.4692e-17,
-        -1.0920e-14, -1.0741e-15, -2.1407e-01, -2.4844e-18,  2.9871e-14,
-         1.3973e-01,  4.4073e-02, -5.2760e-18, -9.0854e-01, -2.8704e-02,
-         4.2820e-14,  8.8075e-02,  2.4534e-01, -4.0658e-02,  7.6732e-18,
-        -4.2580e-01, -1.2150e+00, -1.9642e-21, -1.2181e-01,  9.7493e-16,
-         5.2603e-14,  3.2521e-17,  0.0000e+00,  5.9261e-02,  5.4629e-01,
-        -1.3305e-01, -4.3437e-15, -2.4330e-01,  8.0038e-18,  1.9576e-15,
-        -1.3959e-01, -2.7622e-13,  8.9128e-07,  9.2920e-13, -5.0954e-19,
-        -1.6486e-09, -5.8048e-03, -2.1683e-01,  3.3231e-02, -1.8672e-10,
-         1.0604e-22, -4.1677e-17,  1.4988e-20,  4.8154e-01, -2.1647e-12,
-         4.0038e-20, -6.1493e-02,  3.3639e-08,  1.6257e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3578,  0.0000,  0.0403,  0.0397,  2.1188,  0.0195,  0.0000,  0.0000,
-        -0.0457,  0.2484,  0.2508,  0.0000, -0.2912, -0.4457,  0.0000,  0.0000,
-         0.0000, -0.2141,  0.0000,  0.0000,  0.1397,  0.0441,  0.0000, -0.9085,
-        -0.0287,  0.0000,  0.0881,  0.2453, -0.0407,  0.0000, -0.4258, -1.2150,
-         0.0000, -0.1218,  0.0000,  0.0000,  0.0000,  0.0000,  0.0593,  0.5463,
-        -0.1330,  0.0000, -0.2433,  0.0000,  0.0000, -0.1396,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0058, -0.2168,  0.0332,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4815,  0.0000,  0.0000, -0.0615,  0.0000,  0.1626],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3578,  0.0000,  0.0403,  0.0397,  2.1188,  0.0195,  0.0000,  0.0000,
-        -0.0457,  0.2484,  0.2508,  0.0000, -0.2912, -0.4457,  0.0000,  0.0000,
-         0.0000, -0.2141,  0.0000,  0.0000,  0.1397,  0.0441,  0.0000, -0.9085,
-        -0.0287,  0.0000,  0.0881,  0.2453, -0.0407,  0.0000, -0.4258, -1.2150,
-         0.0000, -0.1218,  0.0000,  0.0000,  0.0000,  0.0000,  0.0593,  0.5463,
-        -0.1330,  0.0000, -0.2433,  0.0000,  0.0000, -0.1396,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0058, -0.2168,  0.0332,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4815,  0.0000,  0.0000, -0.0615,  0.0000,  0.1626],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5650e-01, -1.2407e-11,  4.1074e-02,  2.1614e-02,  2.1173e+00,
-         3.3108e-02,  7.5103e-21, -5.9303e-14, -6.0062e-02,  2.3582e-01,
-         2.4454e-01,  1.3502e-12, -2.9251e-01, -4.4646e-01, -1.3468e-17,
-        -1.0010e-14, -9.8464e-16, -2.2177e-01, -2.2774e-18,  2.7382e-14,
-         1.3496e-01,  4.8117e-02, -4.8364e-18, -9.0648e-01, -3.9553e-02,
-         3.9252e-14,  9.8255e-02,  2.3805e-01, -4.8752e-02,  7.0338e-18,
-        -4.2091e-01, -1.2175e+00, -1.8005e-21, -1.0962e-01,  8.9369e-16,
-         4.8220e-14,  2.9811e-17,  0.0000e+00,  3.7229e-02,  5.3753e-01,
-        -1.5245e-01, -3.9818e-15, -2.4309e-01,  7.3368e-18,  1.7945e-15,
-        -1.4374e-01, -2.5320e-13,  8.1701e-07,  8.5177e-13, -4.6708e-19,
-        -1.5112e-09, -4.5652e-03, -2.1610e-01,  6.6716e-03, -1.7116e-10,
-         9.7207e-23, -3.8204e-17,  1.3740e-20,  4.8207e-01, -1.9843e-12,
-         3.6702e-20, -5.8882e-02,  3.0836e-08,  1.6471e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3565,  0.0000,  0.0411,  0.0216,  2.1173,  0.0331,  0.0000,  0.0000,
-        -0.0601,  0.2358,  0.2445,  0.0000, -0.2925, -0.4465,  0.0000,  0.0000,
-         0.0000, -0.2218,  0.0000,  0.0000,  0.1350,  0.0481,  0.0000, -0.9065,
-        -0.0396,  0.0000,  0.0983,  0.2380, -0.0488,  0.0000, -0.4209, -1.2175,
-         0.0000, -0.1096,  0.0000,  0.0000,  0.0000,  0.0000,  0.0372,  0.5375,
-        -0.1525,  0.0000, -0.2431,  0.0000,  0.0000, -0.1437,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0046, -0.2161,  0.0067,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4821,  0.0000,  0.0000, -0.0589,  0.0000,  0.1647],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3565,  0.0000,  0.0411,  0.0216,  2.1173,  0.0331,  0.0000,  0.0000,
-        -0.0601,  0.2358,  0.2445,  0.0000, -0.2925, -0.4465,  0.0000,  0.0000,
-         0.0000, -0.2218,  0.0000,  0.0000,  0.1350,  0.0481,  0.0000, -0.9065,
-        -0.0396,  0.0000,  0.0983,  0.2380, -0.0488,  0.0000, -0.4209, -1.2175,
-         0.0000, -0.1096,  0.0000,  0.0000,  0.0000,  0.0000,  0.0372,  0.5375,
-        -0.1525,  0.0000, -0.2431,  0.0000,  0.0000, -0.1437,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000, -0.0046, -0.2161,  0.0067,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4821,  0.0000,  0.0000, -0.0589,  0.0000,  0.1647],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5444e-01, -1.1373e-11,  4.2188e-02,  7.7516e-05,  2.1158e+00,
-         4.6635e-02,  6.8845e-21, -5.4361e-14, -6.8509e-02,  2.2330e-01,
-         2.3022e-01,  1.2376e-12, -2.9157e-01, -4.4475e-01, -1.2346e-17,
-        -9.1760e-15, -9.0260e-16, -2.3360e-01, -2.0876e-18,  2.5100e-14,
-         1.3170e-01,  5.9434e-02, -4.4334e-18, -9.0369e-01, -4.6607e-02,
-         3.5982e-14,  1.0392e-01,  2.3432e-01, -5.6825e-02,  6.4477e-18,
-        -4.1316e-01, -1.2193e+00, -1.6505e-21, -9.5199e-02,  8.1922e-16,
-         4.4202e-14,  2.7327e-17,  0.0000e+00,  1.8286e-02,  5.3136e-01,
-        -1.6535e-01, -3.6500e-15, -2.4317e-01,  6.7255e-18,  1.6450e-15,
-        -1.5168e-01, -2.3210e-13,  7.4893e-07,  7.8079e-13, -4.2816e-19,
-        -1.3853e-09,  5.2105e-03, -2.2137e-01, -1.2249e-02, -1.5690e-10,
-         8.9107e-23, -3.5021e-17,  1.2595e-20,  4.8326e-01, -1.8189e-12,
-         3.3643e-20, -5.7022e-02,  2.8267e-08,  1.5922e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5444e-01,  0.0000e+00,  4.2188e-02,  7.7516e-05,  2.1158e+00,
-         4.6635e-02,  0.0000e+00,  0.0000e+00, -6.8509e-02,  2.2330e-01,
-         2.3022e-01,  0.0000e+00, -2.9157e-01, -4.4475e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.3360e-01,  0.0000e+00,  0.0000e+00,
-         1.3170e-01,  5.9434e-02,  0.0000e+00, -9.0369e-01, -4.6607e-02,
-         0.0000e+00,  1.0392e-01,  2.3432e-01, -5.6825e-02,  0.0000e+00,
-        -4.1316e-01, -1.2193e+00,  0.0000e+00, -9.5199e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.8286e-02,  5.3136e-01,
-        -1.6535e-01,  0.0000e+00, -2.4317e-01,  0.0000e+00,  0.0000e+00,
-        -1.5168e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.2105e-03, -2.2137e-01, -1.2249e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.8326e-01,  0.0000e+00,
-         0.0000e+00, -5.7022e-02,  0.0000e+00,  1.5922e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5444e-01,  0.0000e+00,  4.2188e-02,  7.7516e-05,  2.1158e+00,
-         4.6635e-02,  0.0000e+00,  0.0000e+00, -6.8509e-02,  2.2330e-01,
-         2.3022e-01,  0.0000e+00, -2.9157e-01, -4.4475e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.3360e-01,  0.0000e+00,  0.0000e+00,
-         1.3170e-01,  5.9434e-02,  0.0000e+00, -9.0369e-01, -4.6607e-02,
-         0.0000e+00,  1.0392e-01,  2.3432e-01, -5.6825e-02,  0.0000e+00,
-        -4.1316e-01, -1.2193e+00,  0.0000e+00, -9.5199e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.8286e-02,  5.3136e-01,
-        -1.6535e-01,  0.0000e+00, -2.4317e-01,  0.0000e+00,  0.0000e+00,
-        -1.5168e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  5.2105e-03, -2.2137e-01, -1.2249e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.8326e-01,  0.0000e+00,
-         0.0000e+00, -5.7022e-02,  0.0000e+00,  1.5922e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5013e-01, -1.0426e-11,  4.1648e-02, -2.3349e-02,  2.1152e+00,
-         5.8595e-02,  6.3108e-21, -4.9831e-14, -7.0808e-02,  2.0968e-01,
-         2.0366e-01,  1.1345e-12, -2.9171e-01, -4.4061e-01, -1.1317e-17,
-        -8.4113e-15, -8.2738e-16, -2.2751e-01, -1.9137e-18,  2.3009e-14,
-         1.2688e-01,  7.6131e-02, -4.0639e-18, -8.9639e-01, -4.7831e-02,
-         3.2983e-14,  1.0725e-01,  2.2228e-01, -6.1396e-02,  5.9104e-18,
-        -4.0580e-01, -1.2207e+00, -1.5129e-21, -7.3318e-02,  7.5095e-16,
-         4.0519e-14,  2.5050e-17,  0.0000e+00, -8.4725e-04,  5.2204e-01,
-        -1.7427e-01, -3.3459e-15, -2.4837e-01,  6.1650e-18,  1.5079e-15,
-        -1.6617e-01, -2.1276e-13,  6.8652e-07,  7.1573e-13, -3.9248e-19,
-        -1.2699e-09,  3.1572e-02, -2.3856e-01, -2.2193e-02, -1.4382e-10,
-         8.1682e-23, -3.2102e-17,  1.1545e-20,  4.8518e-01, -1.6674e-12,
-         3.0840e-20, -5.6478e-02,  2.5911e-08,  1.4634e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5013e-01,  0.0000e+00,  4.1648e-02, -2.3349e-02,  2.1152e+00,
-         5.8595e-02,  0.0000e+00,  0.0000e+00, -7.0808e-02,  2.0968e-01,
-         2.0366e-01,  0.0000e+00, -2.9171e-01, -4.4061e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.2751e-01,  0.0000e+00,  0.0000e+00,
-         1.2688e-01,  7.6131e-02,  0.0000e+00, -8.9639e-01, -4.7831e-02,
-         0.0000e+00,  1.0725e-01,  2.2228e-01, -6.1396e-02,  0.0000e+00,
-        -4.0580e-01, -1.2207e+00,  0.0000e+00, -7.3318e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -8.4725e-04,  5.2204e-01,
-        -1.7427e-01,  0.0000e+00, -2.4837e-01,  0.0000e+00,  0.0000e+00,
-        -1.6617e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  3.1572e-02, -2.3856e-01, -2.2193e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.8518e-01,  0.0000e+00,
-         0.0000e+00, -5.6478e-02,  0.0000e+00,  1.4634e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5013e-01,  0.0000e+00,  4.1648e-02, -2.3349e-02,  2.1152e+00,
-         5.8595e-02,  0.0000e+00,  0.0000e+00, -7.0808e-02,  2.0968e-01,
-         2.0366e-01,  0.0000e+00, -2.9171e-01, -4.4061e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -2.2751e-01,  0.0000e+00,  0.0000e+00,
-         1.2688e-01,  7.6131e-02,  0.0000e+00, -8.9639e-01, -4.7831e-02,
-         0.0000e+00,  1.0725e-01,  2.2228e-01, -6.1396e-02,  0.0000e+00,
-        -4.0580e-01, -1.2207e+00,  0.0000e+00, -7.3318e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00, -8.4725e-04,  5.2204e-01,
-        -1.7427e-01,  0.0000e+00, -2.4837e-01,  0.0000e+00,  0.0000e+00,
-        -1.6617e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  3.1572e-02, -2.3856e-01, -2.2193e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.8518e-01,  0.0000e+00,
-         0.0000e+00, -5.6478e-02,  0.0000e+00,  1.4634e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4584e-01, -9.5567e-12,  3.8066e-02, -4.5897e-02,  2.1146e+00,
-         6.5201e-02,  5.7848e-21, -4.5678e-14, -7.0340e-02,  1.9432e-01,
-         1.7341e-01,  1.0400e-12, -2.9069e-01, -4.3911e-01, -1.0374e-17,
-        -7.7102e-15, -7.5842e-16, -2.1360e-01, -1.7542e-18,  2.1091e-14,
-         1.2304e-01,  9.3750e-02, -3.7252e-18, -8.8891e-01, -4.8145e-02,
-         3.0234e-14,  1.0685e-01,  2.1180e-01, -6.0212e-02,  5.4178e-18,
-        -4.0235e-01, -1.2231e+00, -1.3868e-21, -5.6660e-02,  6.8836e-16,
-         3.7141e-14,  2.2962e-17,  0.0000e+00, -1.9719e-02,  5.1301e-01,
-        -1.7461e-01, -3.0670e-15, -2.5298e-01,  5.6512e-18,  1.3822e-15,
-        -1.8097e-01, -1.9503e-13,  6.2930e-07,  6.5608e-13, -3.5977e-19,
-        -1.1640e-09,  6.1424e-02, -2.5794e-01, -2.1567e-02, -1.3184e-10,
-         7.4874e-23, -2.9427e-17,  1.0583e-20,  4.8630e-01, -1.5284e-12,
-         2.8269e-20, -5.8850e-02,  2.3752e-08,  1.2417e-01], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3458,  0.0000,  0.0381, -0.0459,  2.1146,  0.0652,  0.0000,  0.0000,
-        -0.0703,  0.1943,  0.1734,  0.0000, -0.2907, -0.4391,  0.0000,  0.0000,
-         0.0000, -0.2136,  0.0000,  0.0000,  0.1230,  0.0938,  0.0000, -0.8889,
-        -0.0481,  0.0000,  0.1068,  0.2118, -0.0602,  0.0000, -0.4024, -1.2231,
-         0.0000, -0.0567,  0.0000,  0.0000,  0.0000,  0.0000, -0.0197,  0.5130,
-        -0.1746,  0.0000, -0.2530,  0.0000,  0.0000, -0.1810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0614, -0.2579, -0.0216,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4863,  0.0000,  0.0000, -0.0589,  0.0000,  0.1242],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3458,  0.0000,  0.0381, -0.0459,  2.1146,  0.0652,  0.0000,  0.0000,
-        -0.0703,  0.1943,  0.1734,  0.0000, -0.2907, -0.4391,  0.0000,  0.0000,
-         0.0000, -0.2136,  0.0000,  0.0000,  0.1230,  0.0938,  0.0000, -0.8889,
-        -0.0481,  0.0000,  0.1068,  0.2118, -0.0602,  0.0000, -0.4024, -1.2231,
-         0.0000, -0.0567,  0.0000,  0.0000,  0.0000,  0.0000, -0.0197,  0.5130,
-        -0.1746,  0.0000, -0.2530,  0.0000,  0.0000, -0.1810,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0614, -0.2579, -0.0216,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4863,  0.0000,  0.0000, -0.0589,  0.0000,  0.1242],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4584e-01, -8.7599e-12,  3.2733e-02, -5.7945e-02,  2.1143e+00,
-         6.8495e-02,  5.3025e-21, -4.1870e-14, -6.7369e-02,  1.8307e-01,
-         1.5097e-01,  9.5325e-13, -2.8736e-01, -4.3510e-01, -9.5090e-18,
-        -7.0674e-15, -6.9519e-16, -1.9362e-01, -1.6079e-18,  1.9333e-14,
-         1.2384e-01,  1.2006e-01, -3.4146e-18, -8.8110e-01, -3.9032e-02,
-         2.7713e-14,  1.0022e-01,  2.1465e-01, -5.5457e-02,  4.9661e-18,
-        -4.0243e-01, -1.2249e+00, -1.2712e-21, -4.8196e-02,  6.3097e-16,
-         3.4045e-14,  2.1048e-17,  0.0000e+00, -3.2097e-02,  5.0603e-01,
-        -1.6133e-01, -2.8113e-15, -2.5909e-01,  5.1800e-18,  1.2670e-15,
-        -1.9287e-01, -1.7877e-13,  5.7684e-07,  6.0138e-13, -3.2977e-19,
-        -1.0670e-09,  9.3388e-02, -2.7698e-01, -6.6370e-03, -1.2085e-10,
-         6.8631e-23, -2.6973e-17,  9.7006e-21,  4.8277e-01, -1.4010e-12,
-         2.5912e-20, -5.8011e-02,  2.1771e-08,  9.8686e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3458,  0.0000,  0.0327, -0.0579,  2.1143,  0.0685,  0.0000,  0.0000,
-        -0.0674,  0.1831,  0.1510,  0.0000, -0.2874, -0.4351,  0.0000,  0.0000,
-         0.0000, -0.1936,  0.0000,  0.0000,  0.1238,  0.1201,  0.0000, -0.8811,
-        -0.0390,  0.0000,  0.1002,  0.2146, -0.0555,  0.0000, -0.4024, -1.2249,
-         0.0000, -0.0482,  0.0000,  0.0000,  0.0000,  0.0000, -0.0321,  0.5060,
-        -0.1613,  0.0000, -0.2591,  0.0000,  0.0000, -0.1929,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0934, -0.2770, -0.0066,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4828,  0.0000,  0.0000, -0.0580,  0.0000,  0.0987],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3458,  0.0000,  0.0327, -0.0579,  2.1143,  0.0685,  0.0000,  0.0000,
-        -0.0674,  0.1831,  0.1510,  0.0000, -0.2874, -0.4351,  0.0000,  0.0000,
-         0.0000, -0.1936,  0.0000,  0.0000,  0.1238,  0.1201,  0.0000, -0.8811,
-        -0.0390,  0.0000,  0.1002,  0.2146, -0.0555,  0.0000, -0.4024, -1.2249,
-         0.0000, -0.0482,  0.0000,  0.0000,  0.0000,  0.0000, -0.0321,  0.5060,
-        -0.1613,  0.0000, -0.2591,  0.0000,  0.0000, -0.1929,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0934, -0.2770, -0.0066,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4828,  0.0000,  0.0000, -0.0580,  0.0000,  0.0987],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4533e-01, -8.0293e-12,  2.4431e-02, -6.6858e-02,  2.1141e+00,
-         6.5681e-02,  4.8603e-21, -3.8377e-14, -6.3023e-02,  1.6374e-01,
-         1.3649e-01,  8.7374e-13, -2.8086e-01, -4.3218e-01, -8.7158e-18,
-        -6.4779e-15, -6.3720e-16, -1.7501e-01, -1.4738e-18,  1.7720e-14,
-         1.2946e-01,  1.4350e-01, -3.1298e-18, -8.7252e-01, -3.2404e-02,
-         2.5402e-14,  8.9816e-02,  2.2108e-01, -4.7282e-02,  4.5519e-18,
-        -4.0793e-01, -1.2261e+00, -1.1652e-21, -4.5358e-02,  5.7834e-16,
-         3.1205e-14,  1.9292e-17,  0.0000e+00, -3.4429e-02,  5.0218e-01,
-        -1.4215e-01, -2.5768e-15, -2.6487e-01,  4.7480e-18,  1.1613e-15,
-        -2.0645e-01, -1.6386e-13,  5.2872e-07,  5.5122e-13, -3.0227e-19,
-        -9.7799e-10,  1.2429e-01, -2.9319e-01,  6.0327e-03, -1.1077e-10,
-         6.2907e-23, -2.4723e-17,  8.8914e-21,  4.7794e-01, -1.2841e-12,
-         2.3751e-20, -5.7497e-02,  1.9956e-08,  7.1081e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3453,  0.0000,  0.0244, -0.0669,  2.1141,  0.0657,  0.0000,  0.0000,
-        -0.0630,  0.1637,  0.1365,  0.0000, -0.2809, -0.4322,  0.0000,  0.0000,
-         0.0000, -0.1750,  0.0000,  0.0000,  0.1295,  0.1435,  0.0000, -0.8725,
-        -0.0324,  0.0000,  0.0898,  0.2211, -0.0473,  0.0000, -0.4079, -1.2261,
-         0.0000, -0.0454,  0.0000,  0.0000,  0.0000,  0.0000, -0.0344,  0.5022,
-        -0.1421,  0.0000, -0.2649,  0.0000,  0.0000, -0.2064,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1243, -0.2932,  0.0060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4779,  0.0000,  0.0000, -0.0575,  0.0000,  0.0711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3453,  0.0000,  0.0244, -0.0669,  2.1141,  0.0657,  0.0000,  0.0000,
-        -0.0630,  0.1637,  0.1365,  0.0000, -0.2809, -0.4322,  0.0000,  0.0000,
-         0.0000, -0.1750,  0.0000,  0.0000,  0.1295,  0.1435,  0.0000, -0.8725,
-        -0.0324,  0.0000,  0.0898,  0.2211, -0.0473,  0.0000, -0.4079, -1.2261,
-         0.0000, -0.0454,  0.0000,  0.0000,  0.0000,  0.0000, -0.0344,  0.5022,
-        -0.1421,  0.0000, -0.2649,  0.0000,  0.0000, -0.2064,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1243, -0.2932,  0.0060,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4779,  0.0000,  0.0000, -0.0575,  0.0000,  0.0711],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4585e-01, -7.3592e-12,  1.8012e-02, -7.1499e-02,  2.1139e+00,
-         6.5921e-02,  4.4546e-21, -3.5175e-14, -6.5567e-02,  1.4121e-01,
-         1.1891e-01,  8.0082e-13, -2.7479e-01, -4.2717e-01, -7.9884e-18,
-        -5.9373e-15, -5.8403e-16, -1.6319e-01, -1.3508e-18,  1.6241e-14,
-         1.3760e-01,  1.6643e-01, -2.8686e-18, -8.6228e-01, -2.7125e-02,
-         2.3282e-14,  8.1310e-02,  2.2910e-01, -3.6544e-02,  4.1720e-18,
-        -4.1313e-01, -1.2270e+00, -1.0679e-21, -4.7359e-02,  5.3008e-16,
-         2.8601e-14,  1.7682e-17,  0.0000e+00, -3.4913e-02,  4.9876e-01,
-        -1.2283e-01, -2.3617e-15, -2.6991e-01,  4.3517e-18,  1.0644e-15,
-        -2.1964e-01, -1.5018e-13,  4.8460e-07,  5.0521e-13, -2.7704e-19,
-        -8.9638e-10,  1.5108e-01, -3.1118e-01,  2.2342e-02, -1.0152e-10,
-         5.7657e-23, -2.2660e-17,  8.1494e-21,  4.7314e-01, -1.1769e-12,
-         2.1769e-20, -5.5402e-02,  1.8290e-08,  4.3383e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3458,  0.0000,  0.0180, -0.0715,  2.1139,  0.0659,  0.0000,  0.0000,
-        -0.0656,  0.1412,  0.1189,  0.0000, -0.2748, -0.4272,  0.0000,  0.0000,
-         0.0000, -0.1632,  0.0000,  0.0000,  0.1376,  0.1664,  0.0000, -0.8623,
-        -0.0271,  0.0000,  0.0813,  0.2291, -0.0365,  0.0000, -0.4131, -1.2270,
-         0.0000, -0.0474,  0.0000,  0.0000,  0.0000,  0.0000, -0.0349,  0.4988,
-        -0.1228,  0.0000, -0.2699,  0.0000,  0.0000, -0.2196,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1511, -0.3112,  0.0223,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4731,  0.0000,  0.0000, -0.0554,  0.0000,  0.0434],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3458,  0.0000,  0.0180, -0.0715,  2.1139,  0.0659,  0.0000,  0.0000,
-        -0.0656,  0.1412,  0.1189,  0.0000, -0.2748, -0.4272,  0.0000,  0.0000,
-         0.0000, -0.1632,  0.0000,  0.0000,  0.1376,  0.1664,  0.0000, -0.8623,
-        -0.0271,  0.0000,  0.0813,  0.2291, -0.0365,  0.0000, -0.4131, -1.2270,
-         0.0000, -0.0474,  0.0000,  0.0000,  0.0000,  0.0000, -0.0349,  0.4988,
-        -0.1228,  0.0000, -0.2699,  0.0000,  0.0000, -0.2196,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1511, -0.3112,  0.0223,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4731,  0.0000,  0.0000, -0.0554,  0.0000,  0.0434],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4646e-01, -6.7446e-12,  1.2027e-02, -7.5426e-02,  2.1135e+00,
-         6.2781e-02,  4.0826e-21, -3.2237e-14, -6.6253e-02,  1.2671e-01,
-         1.0003e-01,  7.3394e-13, -2.6916e-01, -4.2358e-01, -7.3213e-18,
-        -5.4415e-15, -5.3525e-16, -1.4710e-01, -1.2380e-18,  1.4885e-14,
-         1.4662e-01,  1.8836e-01, -2.6291e-18, -8.5356e-01, -2.2917e-02,
-         2.1338e-14,  7.6154e-02,  2.3988e-01, -2.3864e-02,  3.8236e-18,
-        -4.1967e-01, -1.2276e+00, -9.7876e-22, -5.2811e-02,  4.8581e-16,
-         2.6212e-14,  1.6205e-17,  0.0000e+00, -4.0792e-02,  4.9390e-01,
-        -1.0414e-01, -2.1645e-15, -2.7435e-01,  3.9883e-18,  9.7550e-16,
-        -2.3270e-01, -1.3764e-13,  4.4413e-07,  4.6302e-13, -2.5390e-19,
-        -8.2152e-10,  1.7887e-01, -3.2484e-01,  4.0133e-02, -9.3044e-11,
-         5.2842e-23, -2.0768e-17,  7.4688e-21,  4.6924e-01, -1.0787e-12,
-         1.9951e-20, -5.1011e-02,  1.6763e-08,  2.3442e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3465,  0.0000,  0.0120, -0.0754,  2.1135,  0.0628,  0.0000,  0.0000,
-        -0.0663,  0.1267,  0.1000,  0.0000, -0.2692, -0.4236,  0.0000,  0.0000,
-         0.0000, -0.1471,  0.0000,  0.0000,  0.1466,  0.1884,  0.0000, -0.8536,
-        -0.0229,  0.0000,  0.0762,  0.2399, -0.0239,  0.0000, -0.4197, -1.2276,
-         0.0000, -0.0528,  0.0000,  0.0000,  0.0000,  0.0000, -0.0408,  0.4939,
-        -0.1041,  0.0000, -0.2744,  0.0000,  0.0000, -0.2327,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1789, -0.3248,  0.0401,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4692,  0.0000,  0.0000, -0.0510,  0.0000,  0.0234],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3465,  0.0000,  0.0120, -0.0754,  2.1135,  0.0628,  0.0000,  0.0000,
-        -0.0663,  0.1267,  0.1000,  0.0000, -0.2692, -0.4236,  0.0000,  0.0000,
-         0.0000, -0.1471,  0.0000,  0.0000,  0.1466,  0.1884,  0.0000, -0.8536,
-        -0.0229,  0.0000,  0.0762,  0.2399, -0.0239,  0.0000, -0.4197, -1.2276,
-         0.0000, -0.0528,  0.0000,  0.0000,  0.0000,  0.0000, -0.0408,  0.4939,
-        -0.1041,  0.0000, -0.2744,  0.0000,  0.0000, -0.2327,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1789, -0.3248,  0.0401,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4692,  0.0000,  0.0000, -0.0510,  0.0000,  0.0234],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4915e-01, -6.1809e-12,  1.1792e-02, -7.4549e-02,  2.1128e+00,
-         6.3420e-02,  3.7414e-21, -2.9543e-14, -7.0485e-02,  1.1816e-01,
-         9.3475e-02,  6.7260e-13, -2.5964e-01, -4.2373e-01, -6.7094e-18,
-        -4.9867e-15, -4.9052e-16, -1.3414e-01, -1.1345e-18,  1.3641e-14,
-         1.5264e-01,  2.0565e-01, -2.4093e-18, -8.4949e-01, -2.2858e-02,
-         1.9554e-14,  7.7852e-02,  2.5303e-01, -1.6422e-02,  3.5040e-18,
-        -4.2444e-01, -1.2284e+00, -8.9696e-22, -6.1953e-02,  4.4521e-16,
-         2.4022e-14,  1.4851e-17,  0.0000e+00, -3.5356e-02,  4.9051e-01,
-        -9.0057e-02, -1.9836e-15, -2.7249e-01,  3.6550e-18,  8.9397e-16,
-        -2.3519e-01, -1.2614e-13,  4.0701e-07,  4.2432e-13, -2.3268e-19,
-        -7.5286e-10,  1.9331e-01, -3.2496e-01,  6.2684e-02, -8.5267e-11,
-         4.8425e-23, -1.9032e-17,  6.8446e-21,  4.6431e-01, -9.8851e-13,
-         1.8284e-20, -4.2452e-02,  1.5362e-08,  1.7480e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3491,  0.0000,  0.0118, -0.0745,  2.1128,  0.0634,  0.0000,  0.0000,
-        -0.0705,  0.1182,  0.0935,  0.0000, -0.2596, -0.4237,  0.0000,  0.0000,
-         0.0000, -0.1341,  0.0000,  0.0000,  0.1526,  0.2056,  0.0000, -0.8495,
-        -0.0229,  0.0000,  0.0779,  0.2530, -0.0164,  0.0000, -0.4244, -1.2284,
-         0.0000, -0.0620,  0.0000,  0.0000,  0.0000,  0.0000, -0.0354,  0.4905,
-        -0.0901,  0.0000, -0.2725,  0.0000,  0.0000, -0.2352,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1933, -0.3250,  0.0627,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4643,  0.0000,  0.0000, -0.0425,  0.0000,  0.0175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3491,  0.0000,  0.0118, -0.0745,  2.1128,  0.0634,  0.0000,  0.0000,
-        -0.0705,  0.1182,  0.0935,  0.0000, -0.2596, -0.4237,  0.0000,  0.0000,
-         0.0000, -0.1341,  0.0000,  0.0000,  0.1526,  0.2056,  0.0000, -0.8495,
-        -0.0229,  0.0000,  0.0779,  0.2530, -0.0164,  0.0000, -0.4244, -1.2284,
-         0.0000, -0.0620,  0.0000,  0.0000,  0.0000,  0.0000, -0.0354,  0.4905,
-        -0.0901,  0.0000, -0.2725,  0.0000,  0.0000, -0.2352,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1933, -0.3250,  0.0627,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4643,  0.0000,  0.0000, -0.0425,  0.0000,  0.0175],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5177e-01, -5.6638e-12,  1.2570e-02, -7.0952e-02,  2.1121e+00,
-         6.4830e-02,  3.4284e-21, -2.7071e-14, -7.5429e-02,  1.1411e-01,
-         9.8885e-02,  6.1634e-13, -2.5321e-01, -4.2497e-01, -6.1481e-18,
-        -4.5695e-15, -4.4948e-16, -1.2743e-01, -1.0396e-18,  1.2500e-14,
-         1.5322e-01,  2.1473e-01, -2.2078e-18, -8.5050e-01, -2.6477e-02,
-         1.7918e-14,  8.2865e-02,  2.6152e-01, -1.1179e-02,  3.2109e-18,
-        -4.3363e-01, -1.2302e+00, -8.2192e-22, -7.3463e-02,  4.0796e-16,
-         2.2012e-14,  1.3609e-17,  0.0000e+00, -3.2633e-02,  4.8754e-01,
-        -7.6060e-02, -1.8177e-15, -2.7172e-01,  3.3492e-18,  8.1918e-16,
-        -2.3226e-01, -1.1558e-13,  3.7296e-07,  3.8883e-13, -2.1322e-19,
-        -6.8988e-10,  1.9504e-01, -3.1217e-01,  8.2218e-02, -7.8134e-11,
-         4.4374e-23, -1.7440e-17,  6.2720e-21,  4.5946e-01, -9.0581e-13,
-         1.6754e-20, -3.5049e-02,  1.4077e-08,  2.1967e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3518,  0.0000,  0.0126, -0.0710,  2.1121,  0.0648,  0.0000,  0.0000,
-        -0.0754,  0.1141,  0.0989,  0.0000, -0.2532, -0.4250,  0.0000,  0.0000,
-         0.0000, -0.1274,  0.0000,  0.0000,  0.1532,  0.2147,  0.0000, -0.8505,
-        -0.0265,  0.0000,  0.0829,  0.2615, -0.0112,  0.0000, -0.4336, -1.2302,
-         0.0000, -0.0735,  0.0000,  0.0000,  0.0000,  0.0000, -0.0326,  0.4875,
-        -0.0761,  0.0000, -0.2717,  0.0000,  0.0000, -0.2323,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1950, -0.3122,  0.0822,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4595,  0.0000,  0.0000, -0.0350,  0.0000,  0.0220],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3518,  0.0000,  0.0126, -0.0710,  2.1121,  0.0648,  0.0000,  0.0000,
-        -0.0754,  0.1141,  0.0989,  0.0000, -0.2532, -0.4250,  0.0000,  0.0000,
-         0.0000, -0.1274,  0.0000,  0.0000,  0.1532,  0.2147,  0.0000, -0.8505,
-        -0.0265,  0.0000,  0.0829,  0.2615, -0.0112,  0.0000, -0.4336, -1.2302,
-         0.0000, -0.0735,  0.0000,  0.0000,  0.0000,  0.0000, -0.0326,  0.4875,
-        -0.0761,  0.0000, -0.2717,  0.0000,  0.0000, -0.2323,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1950, -0.3122,  0.0822,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4595,  0.0000,  0.0000, -0.0350,  0.0000,  0.0220],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5822e-01, -5.1895e-12,  1.6191e-02, -6.9945e-02,  2.1113e+00,
-         6.2055e-02,  3.1413e-21, -2.4804e-14, -7.9788e-02,  1.1033e-01,
-         1.1460e-01,  5.6472e-13, -2.4385e-01, -4.2552e-01, -5.6333e-18,
-        -4.1868e-15, -4.1184e-16, -1.2471e-01, -9.5256e-19,  1.1453e-14,
-         1.4724e-01,  2.1991e-01, -2.0229e-18, -8.5195e-01, -3.3335e-02,
-         1.6418e-14,  8.6836e-02,  2.6604e-01, -8.2903e-03,  2.9420e-18,
-        -4.4341e-01, -1.2319e+00, -7.5309e-22, -8.8913e-02,  3.7380e-16,
-         2.0169e-14,  1.2469e-17,  0.0000e+00, -1.5010e-02,  4.8820e-01,
-        -6.0024e-02, -1.6654e-15, -2.6687e-01,  3.0687e-18,  7.5058e-16,
-        -2.3308e-01, -1.0591e-13,  3.4173e-07,  3.5626e-13, -1.9536e-19,
-        -6.3210e-10,  1.8943e-01, -2.9808e-01,  9.5377e-02, -7.1591e-11,
-         4.0658e-23, -1.5979e-17,  5.7468e-21,  4.5093e-01, -8.2996e-13,
-         1.5351e-20, -3.2771e-02,  1.2898e-08,  3.0725e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3582,  0.0000,  0.0162, -0.0699,  2.1113,  0.0621,  0.0000,  0.0000,
-        -0.0798,  0.1103,  0.1146,  0.0000, -0.2438, -0.4255,  0.0000,  0.0000,
-         0.0000, -0.1247,  0.0000,  0.0000,  0.1472,  0.2199,  0.0000, -0.8519,
-        -0.0333,  0.0000,  0.0868,  0.2660, -0.0083,  0.0000, -0.4434, -1.2319,
-         0.0000, -0.0889,  0.0000,  0.0000,  0.0000,  0.0000, -0.0150,  0.4882,
-        -0.0600,  0.0000, -0.2669,  0.0000,  0.0000, -0.2331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1894, -0.2981,  0.0954,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4509,  0.0000,  0.0000, -0.0328,  0.0000,  0.0307],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3582,  0.0000,  0.0162, -0.0699,  2.1113,  0.0621,  0.0000,  0.0000,
-        -0.0798,  0.1103,  0.1146,  0.0000, -0.2438, -0.4255,  0.0000,  0.0000,
-         0.0000, -0.1247,  0.0000,  0.0000,  0.1472,  0.2199,  0.0000, -0.8519,
-        -0.0333,  0.0000,  0.0868,  0.2660, -0.0083,  0.0000, -0.4434, -1.2319,
-         0.0000, -0.0889,  0.0000,  0.0000,  0.0000,  0.0000, -0.0150,  0.4882,
-        -0.0600,  0.0000, -0.2669,  0.0000,  0.0000, -0.2331,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1894, -0.2981,  0.0954,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4509,  0.0000,  0.0000, -0.0328,  0.0000,  0.0307],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6520e-01, -4.7544e-12,  2.1500e-02, -6.4923e-02,  2.1108e+00,
-         6.1163e-02,  2.8779e-21, -2.2724e-14, -8.4447e-02,  1.0694e-01,
-         1.3714e-01,  5.1737e-13, -2.3451e-01, -4.2682e-01, -5.1609e-18,
-        -3.8358e-15, -3.7731e-16, -1.1894e-01, -8.7269e-19,  1.0493e-14,
-         1.3941e-01,  2.1873e-01, -1.8533e-18, -8.5362e-01, -4.1685e-02,
-         1.5041e-14,  9.3821e-02,  2.6947e-01, -8.8483e-03,  2.6953e-18,
-        -4.5522e-01, -1.2337e+00, -6.8994e-22, -1.0540e-01,  3.4246e-16,
-         1.8478e-14,  1.1424e-17,  0.0000e+00,  7.3369e-03,  4.8980e-01,
-        -4.4973e-02, -1.5258e-15, -2.6181e-01,  2.8114e-18,  6.8765e-16,
-        -2.3373e-01, -9.7025e-14,  3.1307e-07,  3.2639e-13, -1.7898e-19,
-        -5.7910e-10,  1.7738e-01, -2.7877e-01,  1.1228e-01, -6.5588e-11,
-         3.7249e-23, -1.4640e-17,  5.2649e-21,  4.4313e-01, -7.6037e-13,
-         1.4064e-20, -2.8914e-02,  1.1816e-08,  4.3356e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3652,  0.0000,  0.0215, -0.0649,  2.1108,  0.0612,  0.0000,  0.0000,
-        -0.0844,  0.1069,  0.1371,  0.0000, -0.2345, -0.4268,  0.0000,  0.0000,
-         0.0000, -0.1189,  0.0000,  0.0000,  0.1394,  0.2187,  0.0000, -0.8536,
-        -0.0417,  0.0000,  0.0938,  0.2695, -0.0088,  0.0000, -0.4552, -1.2337,
-         0.0000, -0.1054,  0.0000,  0.0000,  0.0000,  0.0000,  0.0073,  0.4898,
-        -0.0450,  0.0000, -0.2618,  0.0000,  0.0000, -0.2337,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1774, -0.2788,  0.1123,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4431,  0.0000,  0.0000, -0.0289,  0.0000,  0.0434],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3652,  0.0000,  0.0215, -0.0649,  2.1108,  0.0612,  0.0000,  0.0000,
-        -0.0844,  0.1069,  0.1371,  0.0000, -0.2345, -0.4268,  0.0000,  0.0000,
-         0.0000, -0.1189,  0.0000,  0.0000,  0.1394,  0.2187,  0.0000, -0.8536,
-        -0.0417,  0.0000,  0.0938,  0.2695, -0.0088,  0.0000, -0.4552, -1.2337,
-         0.0000, -0.1054,  0.0000,  0.0000,  0.0000,  0.0000,  0.0073,  0.4898,
-        -0.0450,  0.0000, -0.2618,  0.0000,  0.0000, -0.2337,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1774, -0.2788,  0.1123,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4431,  0.0000,  0.0000, -0.0289,  0.0000,  0.0434],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6910e-01, -4.3552e-12,  2.1715e-02, -5.9685e-02,  2.1101e+00,
-         5.6170e-02,  2.6363e-21, -2.0816e-14, -9.2591e-02,  1.0353e-01,
-         1.6875e-01,  4.7393e-13, -2.3019e-01, -4.2299e-01, -4.7276e-18,
-        -3.5137e-15, -3.4563e-16, -1.2259e-01, -7.9942e-19,  9.6116e-15,
-         1.2646e-01,  2.2059e-01, -1.6977e-18, -8.5743e-01, -4.4427e-02,
-         1.3778e-14,  9.2554e-02,  2.6941e-01, -1.2206e-02,  2.4690e-18,
-        -4.6901e-01, -1.2362e+00, -6.3202e-22, -1.2213e-01,  3.1370e-16,
-         1.6926e-14,  1.0464e-17,  0.0000e+00,  3.4122e-02,  4.8848e-01,
-        -2.2763e-02, -1.3977e-15, -2.5805e-01,  2.5754e-18,  6.2991e-16,
-        -2.3063e-01, -8.8879e-14,  2.8679e-07,  2.9899e-13, -1.6395e-19,
-        -5.3048e-10,  1.6694e-01, -2.4919e-01,  1.1693e-01, -6.0081e-11,
-         3.4122e-23, -1.3410e-17,  4.8229e-21,  4.3704e-01, -6.9653e-13,
-         1.2883e-20, -3.0326e-02,  1.0824e-08,  5.5178e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3691,  0.0000,  0.0217, -0.0597,  2.1101,  0.0562,  0.0000,  0.0000,
-        -0.0926,  0.1035,  0.1687,  0.0000, -0.2302, -0.4230,  0.0000,  0.0000,
-         0.0000, -0.1226,  0.0000,  0.0000,  0.1265,  0.2206,  0.0000, -0.8574,
-        -0.0444,  0.0000,  0.0926,  0.2694, -0.0122,  0.0000, -0.4690, -1.2362,
-         0.0000, -0.1221,  0.0000,  0.0000,  0.0000,  0.0000,  0.0341,  0.4885,
-        -0.0228,  0.0000, -0.2580,  0.0000,  0.0000, -0.2306,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1669, -0.2492,  0.1169,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4370,  0.0000,  0.0000, -0.0303,  0.0000,  0.0552],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3691,  0.0000,  0.0217, -0.0597,  2.1101,  0.0562,  0.0000,  0.0000,
-        -0.0926,  0.1035,  0.1687,  0.0000, -0.2302, -0.4230,  0.0000,  0.0000,
-         0.0000, -0.1226,  0.0000,  0.0000,  0.1265,  0.2206,  0.0000, -0.8574,
-        -0.0444,  0.0000,  0.0926,  0.2694, -0.0122,  0.0000, -0.4690, -1.2362,
-         0.0000, -0.1221,  0.0000,  0.0000,  0.0000,  0.0000,  0.0341,  0.4885,
-        -0.0228,  0.0000, -0.2580,  0.0000,  0.0000, -0.2306,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1669, -0.2492,  0.1169,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4370,  0.0000,  0.0000, -0.0303,  0.0000,  0.0552],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7135e-01, -3.9890e-12,  2.0851e-02, -5.3417e-02,  2.1096e+00,
-         4.9808e-02,  2.4146e-21, -1.9066e-14, -1.0088e-01,  1.0150e-01,
-         2.0725e-01,  4.3408e-13, -2.2730e-01, -4.1900e-01, -4.3301e-18,
-        -3.2183e-15, -3.1657e-16, -1.2873e-01, -7.3220e-19,  8.8034e-15,
-         1.0767e-01,  2.1667e-01, -1.5549e-18, -8.6378e-01, -4.4602e-02,
-         1.2620e-14,  8.9562e-02,  2.6228e-01, -1.8857e-02,  2.2614e-18,
-        -4.8378e-01, -1.2392e+00, -5.7887e-22, -1.3914e-01,  2.8733e-16,
-         1.5503e-14,  9.5845e-18,  0.0000e+00,  6.2128e-02,  4.8801e-01,
-        -3.9472e-03, -1.2802e-15, -2.5432e-01,  2.3588e-18,  5.7695e-16,
-        -2.2498e-01, -8.1405e-14,  2.6267e-07,  2.7385e-13, -1.5017e-19,
-        -4.8587e-10,  1.4925e-01, -2.1477e-01,  1.0855e-01, -5.5029e-11,
-         3.1252e-23, -1.2283e-17,  4.4173e-21,  4.3140e-01, -6.3796e-13,
-         1.1800e-20, -3.5841e-02,  9.9140e-09,  6.8748e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3714,  0.0000,  0.0209, -0.0534,  2.1096,  0.0498,  0.0000,  0.0000,
-        -0.1009,  0.1015,  0.2072,  0.0000, -0.2273, -0.4190,  0.0000,  0.0000,
-         0.0000, -0.1287,  0.0000,  0.0000,  0.1077,  0.2167,  0.0000, -0.8638,
-        -0.0446,  0.0000,  0.0896,  0.2623, -0.0189,  0.0000, -0.4838, -1.2392,
-         0.0000, -0.1391,  0.0000,  0.0000,  0.0000,  0.0000,  0.0621,  0.4880,
-        -0.0039,  0.0000, -0.2543,  0.0000,  0.0000, -0.2250,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1493, -0.2148,  0.1086,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4314,  0.0000,  0.0000, -0.0358,  0.0000,  0.0687],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3714,  0.0000,  0.0209, -0.0534,  2.1096,  0.0498,  0.0000,  0.0000,
-        -0.1009,  0.1015,  0.2072,  0.0000, -0.2273, -0.4190,  0.0000,  0.0000,
-         0.0000, -0.1287,  0.0000,  0.0000,  0.1077,  0.2167,  0.0000, -0.8638,
-        -0.0446,  0.0000,  0.0896,  0.2623, -0.0189,  0.0000, -0.4838, -1.2392,
-         0.0000, -0.1391,  0.0000,  0.0000,  0.0000,  0.0000,  0.0621,  0.4880,
-        -0.0039,  0.0000, -0.2543,  0.0000,  0.0000, -0.2250,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1493, -0.2148,  0.1086,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4314,  0.0000,  0.0000, -0.0358,  0.0000,  0.0687],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7358e-01, -3.6530e-12,  1.8609e-02, -4.6791e-02,  2.1089e+00,
-         4.0635e-02,  2.2112e-21, -1.7460e-14, -1.0792e-01,  1.0170e-01,
-         2.4187e-01,  3.9752e-13, -2.2301e-01, -4.1660e-01, -3.9654e-18,
-        -2.9472e-15, -2.8990e-16, -1.3082e-01, -6.7053e-19,  8.0620e-15,
-         8.7363e-02,  2.1270e-01, -1.4240e-18, -8.6986e-01, -4.4197e-02,
-         1.1557e-14,  8.3915e-02,  2.5128e-01, -2.5102e-02,  2.0709e-18,
-        -4.9764e-01, -1.2423e+00, -5.3012e-22, -1.5288e-01,  2.6313e-16,
-         1.4197e-14,  8.7772e-18,  0.0000e+00,  8.4237e-02,  4.8693e-01,
-         1.6294e-02, -1.1723e-15, -2.5130e-01,  2.1602e-18,  5.2835e-16,
-        -2.1971e-01, -7.4549e-14,  2.4055e-07,  2.5078e-13, -1.3752e-19,
-        -4.4495e-10,  1.3297e-01, -1.8000e-01,  9.9593e-02, -5.0394e-11,
-         2.8620e-23, -1.1248e-17,  4.0453e-21,  4.2654e-01, -5.8423e-13,
-         1.0806e-20, -4.3850e-02,  9.0790e-09,  7.9838e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3736,  0.0000,  0.0186, -0.0468,  2.1089,  0.0406,  0.0000,  0.0000,
-        -0.1079,  0.1017,  0.2419,  0.0000, -0.2230, -0.4166,  0.0000,  0.0000,
-         0.0000, -0.1308,  0.0000,  0.0000,  0.0874,  0.2127,  0.0000, -0.8699,
-        -0.0442,  0.0000,  0.0839,  0.2513, -0.0251,  0.0000, -0.4976, -1.2423,
-         0.0000, -0.1529,  0.0000,  0.0000,  0.0000,  0.0000,  0.0842,  0.4869,
-         0.0163,  0.0000, -0.2513,  0.0000,  0.0000, -0.2197,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1330, -0.1800,  0.0996,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4265,  0.0000,  0.0000, -0.0439,  0.0000,  0.0798],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3736,  0.0000,  0.0186, -0.0468,  2.1089,  0.0406,  0.0000,  0.0000,
-        -0.1079,  0.1017,  0.2419,  0.0000, -0.2230, -0.4166,  0.0000,  0.0000,
-         0.0000, -0.1308,  0.0000,  0.0000,  0.0874,  0.2127,  0.0000, -0.8699,
-        -0.0442,  0.0000,  0.0839,  0.2513, -0.0251,  0.0000, -0.4976, -1.2423,
-         0.0000, -0.1529,  0.0000,  0.0000,  0.0000,  0.0000,  0.0842,  0.4869,
-         0.0163,  0.0000, -0.2513,  0.0000,  0.0000, -0.2197,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1330, -0.1800,  0.0996,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4265,  0.0000,  0.0000, -0.0439,  0.0000,  0.0798],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7573e-01, -3.3448e-12,  1.7836e-02, -4.1880e-02,  2.1080e+00,
-         3.2902e-02,  2.0247e-21, -1.5987e-14, -1.1348e-01,  9.5617e-02,
-         2.7301e-01,  3.6398e-13, -2.1792e-01, -4.1456e-01, -3.6308e-18,
-        -2.6985e-15, -2.6544e-16, -1.2751e-01, -6.1395e-19,  7.3817e-15,
-         6.8378e-02,  2.1118e-01, -1.3038e-18, -8.7556e-01, -4.1638e-02,
-         1.0582e-14,  7.7475e-02,  2.3927e-01, -3.2292e-02,  1.8962e-18,
-        -5.1222e-01, -1.2455e+00, -4.8539e-22, -1.6532e-01,  2.4092e-16,
-         1.2999e-14,  8.0366e-18,  0.0000e+00,  9.7730e-02,  4.8370e-01,
-         3.3704e-02, -1.0734e-15, -2.4930e-01,  1.9779e-18,  4.8377e-16,
-        -2.2012e-01, -6.8259e-14,  2.2025e-07,  2.2962e-13, -1.2592e-19,
-        -4.0741e-10,  1.2212e-01, -1.5391e-01,  9.4415e-02, -4.6142e-11,
-         2.6205e-23, -1.0299e-17,  3.7039e-21,  4.2129e-01, -5.3493e-13,
-         9.8941e-21, -5.4038e-02,  8.3130e-09,  8.7099e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3757,  0.0000,  0.0178, -0.0419,  2.1080,  0.0329,  0.0000,  0.0000,
-        -0.1135,  0.0956,  0.2730,  0.0000, -0.2179, -0.4146,  0.0000,  0.0000,
-         0.0000, -0.1275,  0.0000,  0.0000,  0.0684,  0.2112,  0.0000, -0.8756,
-        -0.0416,  0.0000,  0.0775,  0.2393, -0.0323,  0.0000, -0.5122, -1.2455,
-         0.0000, -0.1653,  0.0000,  0.0000,  0.0000,  0.0000,  0.0977,  0.4837,
-         0.0337,  0.0000, -0.2493,  0.0000,  0.0000, -0.2201,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1221, -0.1539,  0.0944,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4213,  0.0000,  0.0000, -0.0540,  0.0000,  0.0871],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3757,  0.0000,  0.0178, -0.0419,  2.1080,  0.0329,  0.0000,  0.0000,
-        -0.1135,  0.0956,  0.2730,  0.0000, -0.2179, -0.4146,  0.0000,  0.0000,
-         0.0000, -0.1275,  0.0000,  0.0000,  0.0684,  0.2112,  0.0000, -0.8756,
-        -0.0416,  0.0000,  0.0775,  0.2393, -0.0323,  0.0000, -0.5122, -1.2455,
-         0.0000, -0.1653,  0.0000,  0.0000,  0.0000,  0.0000,  0.0977,  0.4837,
-         0.0337,  0.0000, -0.2493,  0.0000,  0.0000, -0.2201,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1221, -0.1539,  0.0944,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4213,  0.0000,  0.0000, -0.0540,  0.0000,  0.0871],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7569e-01, -3.0620e-12,  1.6142e-02, -3.8454e-02,  2.1068e+00,
-         2.7333e-02,  1.8535e-21, -1.4635e-14, -1.2173e-01,  8.6506e-02,
-         2.9220e-01,  3.3321e-13, -2.1504e-01, -4.1113e-01, -3.3238e-18,
-        -2.4704e-15, -2.4300e-16, -1.2720e-01, -5.6204e-19,  6.7576e-15,
-         5.0116e-02,  2.0853e-01, -1.1936e-18, -8.7850e-01, -3.9427e-02,
-         9.6871e-15,  7.1011e-02,  2.2695e-01, -3.7943e-02,  1.7359e-18,
-        -5.2447e-01, -1.2479e+00, -4.4435e-22, -1.7194e-01,  2.2055e-16,
-         1.1900e-14,  7.3572e-18,  0.0000e+00,  1.0602e-01,  4.7742e-01,
-         4.4115e-02, -9.8267e-16, -2.4756e-01,  1.8107e-18,  4.4287e-16,
-        -2.2511e-01, -6.2488e-14,  2.0163e-07,  2.1021e-13, -1.1527e-19,
-        -3.7296e-10,  1.1441e-01, -1.3893e-01,  8.4395e-02, -4.2241e-11,
-         2.3990e-23, -9.4284e-18,  3.3908e-21,  4.1846e-01, -4.8970e-13,
-         9.0576e-21, -6.3752e-02,  7.6102e-09,  9.0301e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3757,  0.0000,  0.0161, -0.0385,  2.1068,  0.0273,  0.0000,  0.0000,
-        -0.1217,  0.0865,  0.2922,  0.0000, -0.2150, -0.4111,  0.0000,  0.0000,
-         0.0000, -0.1272,  0.0000,  0.0000,  0.0501,  0.2085,  0.0000, -0.8785,
-        -0.0394,  0.0000,  0.0710,  0.2270, -0.0379,  0.0000, -0.5245, -1.2479,
-         0.0000, -0.1719,  0.0000,  0.0000,  0.0000,  0.0000,  0.1060,  0.4774,
-         0.0441,  0.0000, -0.2476,  0.0000,  0.0000, -0.2251,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1144, -0.1389,  0.0844,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4185,  0.0000,  0.0000, -0.0638,  0.0000,  0.0903],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3757,  0.0000,  0.0161, -0.0385,  2.1068,  0.0273,  0.0000,  0.0000,
-        -0.1217,  0.0865,  0.2922,  0.0000, -0.2150, -0.4111,  0.0000,  0.0000,
-         0.0000, -0.1272,  0.0000,  0.0000,  0.0501,  0.2085,  0.0000, -0.8785,
-        -0.0394,  0.0000,  0.0710,  0.2270, -0.0379,  0.0000, -0.5245, -1.2479,
-         0.0000, -0.1719,  0.0000,  0.0000,  0.0000,  0.0000,  0.1060,  0.4774,
-         0.0441,  0.0000, -0.2476,  0.0000,  0.0000, -0.2251,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1144, -0.1389,  0.0844,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4185,  0.0000,  0.0000, -0.0638,  0.0000,  0.0903],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7524e-01, -2.8026e-12,  1.4143e-02, -3.5058e-02,  2.1059e+00,
-         2.3657e-02,  1.6965e-21, -1.3395e-14, -1.3056e-01,  8.3285e-02,
-         2.9900e-01,  3.0498e-13, -2.1473e-01, -4.0783e-01, -3.0422e-18,
-        -2.2611e-15, -2.2241e-16, -1.2500e-01, -5.1442e-19,  6.1851e-15,
-         3.4761e-02,  2.0312e-01, -1.0924e-18, -8.7960e-01, -4.1105e-02,
-         8.8664e-15,  6.6773e-02,  2.1438e-01, -4.0827e-02,  1.5888e-18,
-        -5.3355e-01, -1.2501e+00, -4.0670e-22, -1.7092e-01,  2.0187e-16,
-         1.0892e-14,  6.7339e-18,  0.0000e+00,  1.1793e-01,  4.7404e-01,
-         4.6601e-02, -8.9942e-16, -2.4722e-01,  1.6573e-18,  4.0535e-16,
-        -2.2978e-01, -5.7194e-14,  1.8455e-07,  1.9240e-13, -1.0550e-19,
-        -3.4136e-10,  1.0309e-01, -1.2520e-01,  7.3557e-02, -3.8662e-11,
-         2.1957e-23, -8.6296e-18,  3.1035e-21,  4.1526e-01, -4.4821e-13,
-         8.2902e-21, -7.1553e-02,  6.9654e-09,  9.1182e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3752,  0.0000,  0.0141, -0.0351,  2.1059,  0.0237,  0.0000,  0.0000,
-        -0.1306,  0.0833,  0.2990,  0.0000, -0.2147, -0.4078,  0.0000,  0.0000,
-         0.0000, -0.1250,  0.0000,  0.0000,  0.0348,  0.2031,  0.0000, -0.8796,
-        -0.0411,  0.0000,  0.0668,  0.2144, -0.0408,  0.0000, -0.5335, -1.2501,
-         0.0000, -0.1709,  0.0000,  0.0000,  0.0000,  0.0000,  0.1179,  0.4740,
-         0.0466,  0.0000, -0.2472,  0.0000,  0.0000, -0.2298,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1031, -0.1252,  0.0736,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4153,  0.0000,  0.0000, -0.0716,  0.0000,  0.0912],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3752,  0.0000,  0.0141, -0.0351,  2.1059,  0.0237,  0.0000,  0.0000,
-        -0.1306,  0.0833,  0.2990,  0.0000, -0.2147, -0.4078,  0.0000,  0.0000,
-         0.0000, -0.1250,  0.0000,  0.0000,  0.0348,  0.2031,  0.0000, -0.8796,
-        -0.0411,  0.0000,  0.0668,  0.2144, -0.0408,  0.0000, -0.5335, -1.2501,
-         0.0000, -0.1709,  0.0000,  0.0000,  0.0000,  0.0000,  0.1179,  0.4740,
-         0.0466,  0.0000, -0.2472,  0.0000,  0.0000, -0.2298,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.1031, -0.1252,  0.0736,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4153,  0.0000,  0.0000, -0.0716,  0.0000,  0.0912],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.7258e-01, -2.5646e-12,  9.7755e-03, -3.1829e-02,  2.1048e+00,
-         2.1110e-02,  1.5524e-21, -1.2258e-14, -1.3855e-01,  8.2203e-02,
-         2.9882e-01,  2.7908e-13, -2.1846e-01, -4.0462e-01, -2.7839e-18,
-        -2.0691e-15, -2.0353e-16, -1.1714e-01, -4.7074e-19,  5.6599e-15,
-         2.2466e-02,  1.9989e-01, -9.9968e-19, -8.8077e-01, -4.0946e-02,
-         8.1135e-15,  6.2548e-02,  2.0300e-01, -4.1858e-02,  1.4539e-18,
-        -5.4371e-01, -1.2525e+00, -3.7217e-22, -1.6575e-01,  1.8473e-16,
-         9.9671e-15,  6.1620e-18,  0.0000e+00,  1.1894e-01,  4.6526e-01,
-         4.6520e-02, -8.2304e-16, -2.4955e-01,  1.5165e-18,  3.7093e-16,
-        -2.3612e-01, -5.2337e-14,  1.6888e-07,  1.7606e-13, -9.6545e-20,
-        -3.1238e-10,  9.1299e-02, -1.1017e-01,  6.5578e-02, -3.5379e-11,
-         2.0093e-23, -7.8968e-18,  2.8400e-21,  4.1369e-01, -4.1015e-13,
-         7.5862e-21, -7.9332e-02,  6.3739e-09,  9.1857e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3726,  0.0000,  0.0098, -0.0318,  2.1048,  0.0211,  0.0000,  0.0000,
-        -0.1385,  0.0822,  0.2988,  0.0000, -0.2185, -0.4046,  0.0000,  0.0000,
-         0.0000, -0.1171,  0.0000,  0.0000,  0.0225,  0.1999,  0.0000, -0.8808,
-        -0.0409,  0.0000,  0.0625,  0.2030, -0.0419,  0.0000, -0.5437, -1.2525,
-         0.0000, -0.1658,  0.0000,  0.0000,  0.0000,  0.0000,  0.1189,  0.4653,
-         0.0465,  0.0000, -0.2495,  0.0000,  0.0000, -0.2361,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0913, -0.1102,  0.0656,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4137,  0.0000,  0.0000, -0.0793,  0.0000,  0.0919],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3726,  0.0000,  0.0098, -0.0318,  2.1048,  0.0211,  0.0000,  0.0000,
-        -0.1385,  0.0822,  0.2988,  0.0000, -0.2185, -0.4046,  0.0000,  0.0000,
-         0.0000, -0.1171,  0.0000,  0.0000,  0.0225,  0.1999,  0.0000, -0.8808,
-        -0.0409,  0.0000,  0.0625,  0.2030, -0.0419,  0.0000, -0.5437, -1.2525,
-         0.0000, -0.1658,  0.0000,  0.0000,  0.0000,  0.0000,  0.1189,  0.4653,
-         0.0465,  0.0000, -0.2495,  0.0000,  0.0000, -0.2361,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0913, -0.1102,  0.0656,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4137,  0.0000,  0.0000, -0.0793,  0.0000,  0.0919],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6953e-01, -2.3463e-12,  5.5697e-03, -3.0639e-02,  2.1037e+00,
-         1.7444e-02,  1.4202e-21, -1.1214e-14, -1.4510e-01,  7.6353e-02,
-         2.8507e-01,  2.5532e-13, -2.2070e-01, -4.0172e-01, -2.5469e-18,
-        -1.8929e-15, -1.8620e-16, -1.0801e-01, -4.3067e-19,  5.1780e-15,
-         1.2969e-02,  1.9565e-01, -9.1457e-19, -8.7833e-01, -4.0853e-02,
-         7.4228e-15,  5.5642e-02,  1.9201e-01, -3.9804e-02,  1.3301e-18,
-        -5.5064e-01, -1.2545e+00, -3.4048e-22, -1.5999e-01,  1.6900e-16,
-         9.1186e-15,  5.6374e-18,  0.0000e+00,  1.2044e-01,  4.5878e-01,
-         4.1836e-02, -7.5297e-16, -2.5341e-01,  1.3874e-18,  3.3935e-16,
-        -2.4242e-01, -4.7881e-14,  1.5450e-07,  1.6107e-13, -8.8326e-20,
-        -2.8578e-10,  8.2796e-02, -1.0401e-01,  5.9458e-02, -3.2367e-11,
-         1.8382e-23, -7.2245e-18,  2.5982e-21,  4.1391e-01, -3.7524e-13,
-         6.9404e-21, -8.6162e-02,  5.8313e-09,  8.6706e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3695,  0.0000,  0.0056, -0.0306,  2.1037,  0.0174,  0.0000,  0.0000,
-        -0.1451,  0.0764,  0.2851,  0.0000, -0.2207, -0.4017,  0.0000,  0.0000,
-         0.0000, -0.1080,  0.0000,  0.0000,  0.0130,  0.1956,  0.0000, -0.8783,
-        -0.0409,  0.0000,  0.0556,  0.1920, -0.0398,  0.0000, -0.5506, -1.2545,
-         0.0000, -0.1600,  0.0000,  0.0000,  0.0000,  0.0000,  0.1204,  0.4588,
-         0.0418,  0.0000, -0.2534,  0.0000,  0.0000, -0.2424,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0828, -0.1040,  0.0595,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4139,  0.0000,  0.0000, -0.0862,  0.0000,  0.0867],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3695,  0.0000,  0.0056, -0.0306,  2.1037,  0.0174,  0.0000,  0.0000,
-        -0.1451,  0.0764,  0.2851,  0.0000, -0.2207, -0.4017,  0.0000,  0.0000,
-         0.0000, -0.1080,  0.0000,  0.0000,  0.0130,  0.1956,  0.0000, -0.8783,
-        -0.0409,  0.0000,  0.0556,  0.1920, -0.0398,  0.0000, -0.5506, -1.2545,
-         0.0000, -0.1600,  0.0000,  0.0000,  0.0000,  0.0000,  0.1204,  0.4588,
-         0.0418,  0.0000, -0.2534,  0.0000,  0.0000, -0.2424,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0828, -0.1040,  0.0595,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4139,  0.0000,  0.0000, -0.0862,  0.0000,  0.0867],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.6481e-01, -2.1460e-12,  2.4234e-03, -2.4726e-02,  2.1025e+00,
-         1.7278e-02,  1.2990e-21, -1.0257e-14, -1.4834e-01,  6.8507e-02,
-         2.6214e-01,  2.3353e-13, -2.2672e-01, -3.9673e-01, -2.3295e-18,
-        -1.7314e-15, -1.7031e-16, -1.0428e-01, -3.9390e-19,  4.7360e-15,
-         2.5478e-03,  1.9256e-01, -8.3651e-19, -8.7571e-01, -3.8005e-02,
-         6.7892e-15,  5.2835e-02,  1.7717e-01, -4.0167e-02,  1.2166e-18,
-        -5.5514e-01, -1.2565e+00, -3.1142e-22, -1.5315e-01,  1.5457e-16,
-         8.3402e-15,  5.1562e-18,  0.0000e+00,  1.1818e-01,  4.5012e-01,
-         2.9829e-02, -6.8870e-16, -2.5780e-01,  1.2690e-18,  3.1038e-16,
-        -2.4680e-01, -4.3794e-14,  1.4131e-07,  1.4732e-13, -8.0787e-20,
-        -2.6139e-10,  7.8275e-02, -1.0511e-01,  5.2340e-02, -2.9605e-11,
-         1.6813e-23, -6.6078e-18,  2.3764e-21,  4.1634e-01, -3.4321e-13,
-         6.3480e-21, -9.3164e-02,  5.3335e-09,  8.0490e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3648,  0.0000,  0.0024, -0.0247,  2.1025,  0.0173,  0.0000,  0.0000,
-        -0.1483,  0.0685,  0.2621,  0.0000, -0.2267, -0.3967,  0.0000,  0.0000,
-         0.0000, -0.1043,  0.0000,  0.0000,  0.0025,  0.1926,  0.0000, -0.8757,
-        -0.0380,  0.0000,  0.0528,  0.1772, -0.0402,  0.0000, -0.5551, -1.2565,
-         0.0000, -0.1531,  0.0000,  0.0000,  0.0000,  0.0000,  0.1182,  0.4501,
-         0.0298,  0.0000, -0.2578,  0.0000,  0.0000, -0.2468,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0783, -0.1051,  0.0523,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4163,  0.0000,  0.0000, -0.0932,  0.0000,  0.0805],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3648,  0.0000,  0.0024, -0.0247,  2.1025,  0.0173,  0.0000,  0.0000,
-        -0.1483,  0.0685,  0.2621,  0.0000, -0.2267, -0.3967,  0.0000,  0.0000,
-         0.0000, -0.1043,  0.0000,  0.0000,  0.0025,  0.1926,  0.0000, -0.8757,
-        -0.0380,  0.0000,  0.0528,  0.1772, -0.0402,  0.0000, -0.5551, -1.2565,
-         0.0000, -0.1531,  0.0000,  0.0000,  0.0000,  0.0000,  0.1182,  0.4501,
-         0.0298,  0.0000, -0.2578,  0.0000,  0.0000, -0.2468,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0783, -0.1051,  0.0523,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4163,  0.0000,  0.0000, -0.0932,  0.0000,  0.0805],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5820e-01, -1.9623e-12,  3.2151e-04, -1.9326e-02,  2.1013e+00,
-         2.0961e-02,  1.1878e-21, -9.3791e-15, -1.5247e-01,  5.6304e-02,
-         2.3008e-01,  2.1353e-13, -2.3446e-01, -3.9320e-01, -2.1301e-18,
-        -1.5832e-15, -1.5573e-16, -9.6629e-02, -3.6019e-19,  4.3306e-15,
-        -3.7235e-03,  1.8604e-01, -7.6490e-19, -8.7223e-01, -3.9777e-02,
-         6.2080e-15,  5.0239e-02,  1.6444e-01, -3.5931e-02,  1.1124e-18,
-        -5.5711e-01, -1.2577e+00, -2.8476e-22, -1.4508e-01,  1.4134e-16,
-         7.6263e-15,  4.7149e-18,  0.0000e+00,  1.1444e-01,  4.4116e-01,
-         1.8849e-02, -6.2975e-16, -2.6078e-01,  1.1604e-18,  2.8381e-16,
-        -2.5063e-01, -4.0045e-14,  1.2922e-07,  1.3471e-13, -7.3871e-20,
-        -2.3901e-10,  7.1437e-02, -1.1015e-01,  4.4962e-02, -2.7070e-11,
-         1.5374e-23, -6.0422e-18,  2.1730e-21,  4.2008e-01, -3.1383e-13,
-         5.8046e-21, -9.9873e-02,  4.8770e-09,  6.8962e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5820e-01,  0.0000e+00,  3.2151e-04, -1.9326e-02,  2.1013e+00,
-         2.0961e-02,  0.0000e+00,  0.0000e+00, -1.5247e-01,  5.6304e-02,
-         2.3008e-01,  0.0000e+00, -2.3446e-01, -3.9320e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -9.6629e-02,  0.0000e+00,  0.0000e+00,
-        -3.7235e-03,  1.8604e-01,  0.0000e+00, -8.7223e-01, -3.9777e-02,
-         0.0000e+00,  5.0239e-02,  1.6444e-01, -3.5931e-02,  0.0000e+00,
-        -5.5711e-01, -1.2577e+00,  0.0000e+00, -1.4508e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1444e-01,  4.4116e-01,
-         1.8849e-02,  0.0000e+00, -2.6078e-01,  0.0000e+00,  0.0000e+00,
-        -2.5063e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.1437e-02, -1.1015e-01,  4.4962e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2008e-01,  0.0000e+00,
-         0.0000e+00, -9.9873e-02,  0.0000e+00,  6.8962e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5820e-01,  0.0000e+00,  3.2151e-04, -1.9326e-02,  2.1013e+00,
-         2.0961e-02,  0.0000e+00,  0.0000e+00, -1.5247e-01,  5.6304e-02,
-         2.3008e-01,  0.0000e+00, -2.3446e-01, -3.9320e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -9.6629e-02,  0.0000e+00,  0.0000e+00,
-        -3.7235e-03,  1.8604e-01,  0.0000e+00, -8.7223e-01, -3.9777e-02,
-         0.0000e+00,  5.0239e-02,  1.6444e-01, -3.5931e-02,  0.0000e+00,
-        -5.5711e-01, -1.2577e+00,  0.0000e+00, -1.4508e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1444e-01,  4.4116e-01,
-         1.8849e-02,  0.0000e+00, -2.6078e-01,  0.0000e+00,  0.0000e+00,
-        -2.5063e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.1437e-02, -1.1015e-01,  4.4962e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2008e-01,  0.0000e+00,
-         0.0000e+00, -9.9873e-02,  0.0000e+00,  6.8962e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.5197e-01, -1.7938e-12, -7.5829e-04, -1.5311e-02,  2.1003e+00,
-         2.5982e-02,  1.0858e-21, -8.5738e-15, -1.5766e-01,  4.6765e-02,
-         1.9186e-01,  1.9520e-13, -2.4244e-01, -3.9172e-01, -1.9472e-18,
-        -1.4472e-15, -1.4236e-16, -8.4226e-02, -3.2926e-19,  3.9588e-15,
-        -9.0963e-03,  1.8219e-01, -6.9922e-19, -8.6863e-01, -4.0227e-02,
-         5.6750e-15,  4.9138e-02,  1.5414e-01, -3.2122e-02,  1.0169e-18,
-        -5.5699e-01, -1.2586e+00, -2.6031e-22, -1.3206e-01,  1.2921e-16,
-         6.9715e-15,  4.3100e-18,  0.0000e+00,  1.1167e-01,  4.3562e-01,
-         8.8359e-03, -5.7567e-16, -2.6368e-01,  1.0607e-18,  2.5944e-16,
-        -2.5368e-01, -3.6607e-14,  1.1812e-07,  1.2315e-13, -6.7529e-20,
-        -2.1849e-10,  6.6593e-02, -1.1932e-01,  4.4862e-02, -2.4746e-11,
-         1.4054e-23, -5.5234e-18,  1.9864e-21,  4.2456e-01, -2.8688e-13,
-         5.3062e-21, -1.0455e-01,  4.4582e-09,  5.4885e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.5197e-01,  0.0000e+00, -7.5829e-04, -1.5311e-02,  2.1003e+00,
-         2.5982e-02,  0.0000e+00,  0.0000e+00, -1.5766e-01,  4.6765e-02,
-         1.9186e-01,  0.0000e+00, -2.4244e-01, -3.9172e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -8.4226e-02,  0.0000e+00,  0.0000e+00,
-        -9.0963e-03,  1.8219e-01,  0.0000e+00, -8.6863e-01, -4.0227e-02,
-         0.0000e+00,  4.9138e-02,  1.5414e-01, -3.2122e-02,  0.0000e+00,
-        -5.5699e-01, -1.2586e+00,  0.0000e+00, -1.3206e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1167e-01,  4.3562e-01,
-         8.8359e-03,  0.0000e+00, -2.6368e-01,  0.0000e+00,  0.0000e+00,
-        -2.5368e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.6593e-02, -1.1932e-01,  4.4862e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2456e-01,  0.0000e+00,
-         0.0000e+00, -1.0455e-01,  0.0000e+00,  5.4885e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.5197e-01,  0.0000e+00, -7.5829e-04, -1.5311e-02,  2.1003e+00,
-         2.5982e-02,  0.0000e+00,  0.0000e+00, -1.5766e-01,  4.6765e-02,
-         1.9186e-01,  0.0000e+00, -2.4244e-01, -3.9172e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -8.4226e-02,  0.0000e+00,  0.0000e+00,
-        -9.0963e-03,  1.8219e-01,  0.0000e+00, -8.6863e-01, -4.0227e-02,
-         0.0000e+00,  4.9138e-02,  1.5414e-01, -3.2122e-02,  0.0000e+00,
-        -5.5699e-01, -1.2586e+00,  0.0000e+00, -1.3206e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1167e-01,  4.3562e-01,
-         8.8359e-03,  0.0000e+00, -2.6368e-01,  0.0000e+00,  0.0000e+00,
-        -2.5368e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.6593e-02, -1.1932e-01,  4.4862e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2456e-01,  0.0000e+00,
-         0.0000e+00, -1.0455e-01,  0.0000e+00,  5.4885e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.4443e-01, -1.6393e-12, -2.1729e-03, -1.2307e-02,  2.0995e+00,
-         3.2250e-02,  9.9229e-22, -7.8352e-15, -1.6298e-01,  3.7598e-02,
-         1.5230e-01,  1.7839e-13, -2.5082e-01, -3.9040e-01, -1.7795e-18,
-        -1.3226e-15, -1.3009e-16, -7.3678e-02, -3.0090e-19,  3.6178e-15,
-        -1.5462e-02,  1.7988e-01, -6.3899e-19, -8.6480e-01, -3.9734e-02,
-         5.1861e-15,  4.7885e-02,  1.4311e-01, -2.7588e-02,  9.2933e-19,
-        -5.5598e-01, -1.2597e+00, -2.3789e-22, -1.1899e-01,  1.1808e-16,
-         6.3709e-15,  3.9387e-18,  0.0000e+00,  1.0430e-01,  4.2841e-01,
-         9.4905e-04, -5.2609e-16, -2.6628e-01,  9.6936e-19,  2.3710e-16,
-        -2.5953e-01, -3.3454e-14,  1.0795e-07,  1.1254e-13, -6.1712e-20,
-        -1.9967e-10,  6.3458e-02, -1.3199e-01,  4.3065e-02, -2.2614e-11,
-         1.2843e-23, -5.0476e-18,  1.8153e-21,  4.2981e-01, -2.6217e-13,
-         4.8491e-21, -1.0960e-01,  4.0742e-09,  3.9359e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.4443e-01,  0.0000e+00, -2.1729e-03, -1.2307e-02,  2.0995e+00,
-         3.2250e-02,  0.0000e+00,  0.0000e+00, -1.6298e-01,  3.7598e-02,
-         1.5230e-01,  0.0000e+00, -2.5082e-01, -3.9040e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -7.3678e-02,  0.0000e+00,  0.0000e+00,
-        -1.5462e-02,  1.7988e-01,  0.0000e+00, -8.6480e-01, -3.9734e-02,
-         0.0000e+00,  4.7885e-02,  1.4311e-01, -2.7588e-02,  0.0000e+00,
-        -5.5598e-01, -1.2597e+00,  0.0000e+00, -1.1899e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0430e-01,  4.2841e-01,
-         9.4905e-04,  0.0000e+00, -2.6628e-01,  0.0000e+00,  0.0000e+00,
-        -2.5953e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.3458e-02, -1.3199e-01,  4.3065e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2981e-01,  0.0000e+00,
-         0.0000e+00, -1.0960e-01,  0.0000e+00,  3.9359e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.4443e-01,  0.0000e+00, -2.1729e-03, -1.2307e-02,  2.0995e+00,
-         3.2250e-02,  0.0000e+00,  0.0000e+00, -1.6298e-01,  3.7598e-02,
-         1.5230e-01,  0.0000e+00, -2.5082e-01, -3.9040e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -7.3678e-02,  0.0000e+00,  0.0000e+00,
-        -1.5462e-02,  1.7988e-01,  0.0000e+00, -8.6480e-01, -3.9734e-02,
-         0.0000e+00,  4.7885e-02,  1.4311e-01, -2.7588e-02,  0.0000e+00,
-        -5.5598e-01, -1.2597e+00,  0.0000e+00, -1.1899e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0430e-01,  4.2841e-01,
-         9.4905e-04,  0.0000e+00, -2.6628e-01,  0.0000e+00,  0.0000e+00,
-        -2.5953e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.3458e-02, -1.3199e-01,  4.3065e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.2981e-01,  0.0000e+00,
-         0.0000e+00, -1.0960e-01,  0.0000e+00,  3.9359e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.3598e-01, -1.4976e-12, -1.5898e-03, -7.6163e-03,  2.0984e+00,
-         4.0337e-02,  9.0651e-22, -7.1580e-15, -1.6624e-01,  2.4136e-02,
-         1.1244e-01,  1.6297e-13, -2.5515e-01, -3.8654e-01, -1.6256e-18,
-        -1.2082e-15, -1.1885e-16, -6.7193e-02, -2.7489e-19,  3.3051e-15,
-        -1.9767e-02,  1.7957e-01, -5.8376e-19, -8.5995e-01, -3.6186e-02,
-         4.7378e-15,  4.4966e-02,  1.3375e-01, -2.4861e-02,  8.4900e-19,
-        -5.5384e-01, -1.2603e+00, -2.1733e-22, -1.0840e-01,  1.0787e-16,
-         5.8202e-15,  3.5983e-18,  0.0000e+00,  9.8809e-02,  4.2213e-01,
-        -5.4346e-03, -4.8061e-16, -2.6675e-01,  8.8557e-19,  2.1660e-16,
-        -2.6435e-01, -3.0562e-14,  9.8615e-08,  1.0281e-13, -5.6377e-20,
-        -1.8241e-10,  6.7144e-02, -1.4860e-01,  4.1649e-02, -2.0660e-11,
-         1.1733e-23, -4.6113e-18,  1.6584e-21,  4.3496e-01, -2.3951e-13,
-         4.4299e-21, -1.1482e-01,  3.7220e-09,  2.1484e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.3598e-01,  0.0000e+00, -1.5898e-03, -7.6163e-03,  2.0984e+00,
-         4.0337e-02,  0.0000e+00,  0.0000e+00, -1.6624e-01,  2.4136e-02,
-         1.1244e-01,  0.0000e+00, -2.5515e-01, -3.8654e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.7193e-02,  0.0000e+00,  0.0000e+00,
-        -1.9767e-02,  1.7957e-01,  0.0000e+00, -8.5995e-01, -3.6186e-02,
-         0.0000e+00,  4.4966e-02,  1.3375e-01, -2.4861e-02,  0.0000e+00,
-        -5.5384e-01, -1.2603e+00,  0.0000e+00, -1.0840e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  9.8809e-02,  4.2213e-01,
-        -5.4346e-03,  0.0000e+00, -2.6675e-01,  0.0000e+00,  0.0000e+00,
-        -2.6435e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.7144e-02, -1.4860e-01,  4.1649e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3496e-01,  0.0000e+00,
-         0.0000e+00, -1.1482e-01,  0.0000e+00,  2.1484e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.3598e-01,  0.0000e+00, -1.5898e-03, -7.6163e-03,  2.0984e+00,
-         4.0337e-02,  0.0000e+00,  0.0000e+00, -1.6624e-01,  2.4136e-02,
-         1.1244e-01,  0.0000e+00, -2.5515e-01, -3.8654e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.7193e-02,  0.0000e+00,  0.0000e+00,
-        -1.9767e-02,  1.7957e-01,  0.0000e+00, -8.5995e-01, -3.6186e-02,
-         0.0000e+00,  4.4966e-02,  1.3375e-01, -2.4861e-02,  0.0000e+00,
-        -5.5384e-01, -1.2603e+00,  0.0000e+00, -1.0840e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  9.8809e-02,  4.2213e-01,
-        -5.4346e-03,  0.0000e+00, -2.6675e-01,  0.0000e+00,  0.0000e+00,
-        -2.6435e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  6.7144e-02, -1.4860e-01,  4.1649e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3496e-01,  0.0000e+00,
-         0.0000e+00, -1.1482e-01,  0.0000e+00,  2.1484e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.2789e-01, -1.3677e-12,  1.2657e-03, -2.4517e-03,  2.0976e+00,
-         4.9379e-02,  8.2787e-22, -6.5370e-15, -1.7040e-01,  1.2490e-02,
-         7.4740e-02,  1.4883e-13, -2.5726e-01, -3.8518e-01, -1.4846e-18,
-        -1.1034e-15, -1.0854e-16, -6.4675e-02, -2.5104e-19,  3.0183e-15,
-        -2.4024e-02,  1.7828e-01, -5.3311e-19, -8.5657e-01, -3.1500e-02,
-         4.3268e-15,  4.5143e-02,  1.2314e-01, -2.2946e-02,  7.7534e-19,
-        -5.4934e-01, -1.2608e+00, -1.9847e-22, -1.0246e-01,  9.8511e-17,
-         5.3153e-15,  3.2861e-18,  0.0000e+00,  9.4034e-02,  4.1692e-01,
-        -1.2646e-02, -4.3891e-16, -2.6501e-01,  8.0874e-19,  1.9781e-16,
-        -2.6421e-01, -2.7910e-14,  9.0059e-08,  9.3891e-14, -5.1486e-20,
-        -1.6659e-10,  7.1764e-02, -1.6459e-01,  3.8303e-02, -1.8867e-11,
-         1.0715e-23, -4.2112e-18,  1.5145e-21,  4.3869e-01, -2.1873e-13,
-         4.0456e-21, -1.2081e-01,  3.3991e-09,  4.8951e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 3.2789e-01,  0.0000e+00,  1.2657e-03, -2.4517e-03,  2.0976e+00,
-         4.9379e-02,  0.0000e+00,  0.0000e+00, -1.7040e-01,  1.2490e-02,
-         7.4740e-02,  0.0000e+00, -2.5726e-01, -3.8518e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.4675e-02,  0.0000e+00,  0.0000e+00,
-        -2.4024e-02,  1.7828e-01,  0.0000e+00, -8.5657e-01, -3.1500e-02,
-         0.0000e+00,  4.5143e-02,  1.2314e-01, -2.2946e-02,  0.0000e+00,
-        -5.4934e-01, -1.2608e+00,  0.0000e+00, -1.0246e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  9.4034e-02,  4.1692e-01,
-        -1.2646e-02,  0.0000e+00, -2.6501e-01,  0.0000e+00,  0.0000e+00,
-        -2.6421e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.1764e-02, -1.6459e-01,  3.8303e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3869e-01,  0.0000e+00,
-         0.0000e+00, -1.2081e-01,  0.0000e+00,  4.8951e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Before Step tensor([ 3.2789e-01,  0.0000e+00,  1.2657e-03, -2.4517e-03,  2.0976e+00,
-         4.9379e-02,  0.0000e+00,  0.0000e+00, -1.7040e-01,  1.2490e-02,
-         7.4740e-02,  0.0000e+00, -2.5726e-01, -3.8518e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00, -6.4675e-02,  0.0000e+00,  0.0000e+00,
-        -2.4024e-02,  1.7828e-01,  0.0000e+00, -8.5657e-01, -3.1500e-02,
-         0.0000e+00,  4.5143e-02,  1.2314e-01, -2.2946e-02,  0.0000e+00,
-        -5.4934e-01, -1.2608e+00,  0.0000e+00, -1.0246e-01,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  9.4034e-02,  4.1692e-01,
-        -1.2646e-02,  0.0000e+00, -2.6501e-01,  0.0000e+00,  0.0000e+00,
-        -2.6421e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         0.0000e+00,  7.1764e-02, -1.6459e-01,  3.8303e-02,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3869e-01,  0.0000e+00,
-         0.0000e+00, -1.2081e-01,  0.0000e+00,  4.8951e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Batch tensor([ 3.1720e-01, -1.2485e-12,  3.7605e-03,  2.4614e-03,  2.0967e+00,
-         5.4613e-02,  7.5576e-22, -5.9676e-15, -1.7414e-01,  3.9396e-03,
-         5.1275e-02,  1.3587e-13, -2.6045e-01, -3.8550e-01, -1.3553e-18,
-        -1.0073e-15, -9.9084e-17, -6.2618e-02, -2.2917e-19,  2.7554e-15,
-        -2.9798e-02,  1.7863e-01, -4.8668e-19, -8.5379e-01, -2.6081e-02,
-         3.9500e-15,  4.6201e-02,  1.1365e-01, -2.2987e-02,  7.0781e-19,
-        -5.4629e-01, -1.2616e+00, -1.8119e-22, -9.7161e-02,  8.9932e-17,
-         4.8524e-15,  2.9999e-18,  0.0000e+00,  8.8029e-02,  4.1006e-01,
-        -2.1547e-02, -4.0069e-16, -2.6291e-01,  7.3830e-19,  1.8058e-16,
-        -2.6403e-01, -2.5480e-14,  8.2216e-08,  8.5713e-14, -4.7002e-20,
-        -1.5208e-10,  7.4839e-02, -1.7510e-01,  3.3851e-02, -1.7224e-11,
-         9.7819e-24, -3.8445e-18,  1.3826e-21,  4.4234e-01, -1.9968e-13,
-         3.6933e-21, -1.2717e-01,  3.1031e-09, -6.2694e-03], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3172,  0.0000,  0.0038,  0.0025,  2.0967,  0.0546,  0.0000,  0.0000,
-        -0.1741,  0.0039,  0.0513,  0.0000, -0.2604, -0.3855,  0.0000,  0.0000,
-         0.0000, -0.0626,  0.0000,  0.0000, -0.0298,  0.1786,  0.0000, -0.8538,
-        -0.0261,  0.0000,  0.0462,  0.1137, -0.0230,  0.0000, -0.5463, -1.2616,
-         0.0000, -0.0972,  0.0000,  0.0000,  0.0000,  0.0000,  0.0880,  0.4101,
-        -0.0215,  0.0000, -0.2629,  0.0000,  0.0000, -0.2640,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0748, -0.1751,  0.0339,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4423,  0.0000,  0.0000, -0.1272,  0.0000, -0.0063],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3172,  0.0000,  0.0038,  0.0025,  2.0967,  0.0546,  0.0000,  0.0000,
-        -0.1741,  0.0039,  0.0513,  0.0000, -0.2604, -0.3855,  0.0000,  0.0000,
-         0.0000, -0.0626,  0.0000,  0.0000, -0.0298,  0.1786,  0.0000, -0.8538,
-        -0.0261,  0.0000,  0.0462,  0.1137, -0.0230,  0.0000, -0.5463, -1.2616,
-         0.0000, -0.0972,  0.0000,  0.0000,  0.0000,  0.0000,  0.0880,  0.4101,
-        -0.0215,  0.0000, -0.2629,  0.0000,  0.0000, -0.2640,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0748, -0.1751,  0.0339,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4423,  0.0000,  0.0000, -0.1272,  0.0000, -0.0063],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0896e-01, -1.1394e-12,  8.9681e-03,  4.0967e-03,  2.0957e+00,
-         5.9799e-02,  6.8967e-22, -5.4457e-15, -1.7904e-01, -5.8454e-03,
-         3.3818e-02,  1.2398e-13, -2.6278e-01, -3.8730e-01, -1.2368e-18,
-        -9.1922e-16, -9.0419e-17, -6.1595e-02, -2.0913e-19,  2.5145e-15,
-        -3.6718e-02,  1.7530e-01, -4.4412e-19, -8.5079e-01, -2.4585e-02,
-         3.6045e-15,  4.9239e-02,  1.0103e-01, -2.3903e-02,  6.4591e-19,
-        -5.4337e-01, -1.2622e+00, -1.6534e-22, -9.2914e-02,  8.2067e-17,
-         4.4280e-15,  2.7376e-18,  0.0000e+00,  8.4866e-02,  4.0620e-01,
-        -3.2085e-02, -3.6565e-16, -2.5837e-01,  6.7374e-19,  1.6479e-16,
-        -2.6700e-01, -2.3251e-14,  7.5026e-08,  7.8217e-14, -4.2892e-20,
-        -1.3878e-10,  7.1584e-02, -1.9056e-01,  2.5945e-02, -1.5718e-11,
-         8.9265e-24, -3.5082e-18,  1.2617e-21,  4.4414e-01, -1.8222e-13,
-         3.3703e-21, -1.3364e-01,  2.8317e-09, -1.4528e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3090,  0.0000,  0.0090,  0.0041,  2.0957,  0.0598,  0.0000,  0.0000,
-        -0.1790, -0.0058,  0.0338,  0.0000, -0.2628, -0.3873,  0.0000,  0.0000,
-         0.0000, -0.0616,  0.0000,  0.0000, -0.0367,  0.1753,  0.0000, -0.8508,
-        -0.0246,  0.0000,  0.0492,  0.1010, -0.0239,  0.0000, -0.5434, -1.2622,
-         0.0000, -0.0929,  0.0000,  0.0000,  0.0000,  0.0000,  0.0849,  0.4062,
-        -0.0321,  0.0000, -0.2584,  0.0000,  0.0000, -0.2670,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716, -0.1906,  0.0259,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4441,  0.0000,  0.0000, -0.1336,  0.0000, -0.0145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3090,  0.0000,  0.0090,  0.0041,  2.0957,  0.0598,  0.0000,  0.0000,
-        -0.1790, -0.0058,  0.0338,  0.0000, -0.2628, -0.3873,  0.0000,  0.0000,
-         0.0000, -0.0616,  0.0000,  0.0000, -0.0367,  0.1753,  0.0000, -0.8508,
-        -0.0246,  0.0000,  0.0492,  0.1010, -0.0239,  0.0000, -0.5434, -1.2622,
-         0.0000, -0.0929,  0.0000,  0.0000,  0.0000,  0.0000,  0.0849,  0.4062,
-        -0.0321,  0.0000, -0.2584,  0.0000,  0.0000, -0.2670,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716, -0.1906,  0.0259,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4441,  0.0000,  0.0000, -0.1336,  0.0000, -0.0145],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 3.0148e-01, -1.0393e-12,  1.4269e-02,  7.6697e-03,  2.0947e+00,
-         6.3072e-02,  6.2909e-22, -4.9674e-15, -1.8241e-01, -1.2662e-02,
-         2.2120e-02,  1.1309e-13, -2.6403e-01, -3.8822e-01, -1.1281e-18,
-        -8.3848e-16, -8.2477e-17, -6.0704e-02, -1.9076e-19,  2.2936e-15,
-        -4.3446e-02,  1.7471e-01, -4.0511e-19, -8.4693e-01, -2.0560e-02,
-         3.2879e-15,  5.1424e-02,  8.8464e-02, -2.5438e-02,  5.8918e-19,
-        -5.4152e-01, -1.2621e+00, -1.5082e-22, -8.9325e-02,  7.4859e-17,
-         4.0391e-15,  2.4971e-18,  0.0000e+00,  8.4010e-02,  4.0333e-01,
-        -4.2808e-02, -3.3353e-16, -2.5354e-01,  6.1456e-19,  1.5032e-16,
-        -2.7064e-01, -2.1209e-14,  6.8436e-08,  7.1347e-14, -3.9124e-20,
-        -1.2659e-10,  7.1597e-02, -2.0416e-01,  1.8862e-02, -1.4337e-11,
-         8.1424e-24, -3.2001e-18,  1.1509e-21,  4.4617e-01, -1.6621e-13,
-         3.0743e-21, -1.3942e-01,  2.5830e-09, -1.9292e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.3015,  0.0000,  0.0143,  0.0077,  2.0947,  0.0631,  0.0000,  0.0000,
-        -0.1824, -0.0127,  0.0221,  0.0000, -0.2640, -0.3882,  0.0000,  0.0000,
-         0.0000, -0.0607,  0.0000,  0.0000, -0.0434,  0.1747,  0.0000, -0.8469,
-        -0.0206,  0.0000,  0.0514,  0.0885, -0.0254,  0.0000, -0.5415, -1.2621,
-         0.0000, -0.0893,  0.0000,  0.0000,  0.0000,  0.0000,  0.0840,  0.4033,
-        -0.0428,  0.0000, -0.2535,  0.0000,  0.0000, -0.2706,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716, -0.2042,  0.0189,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4462,  0.0000,  0.0000, -0.1394,  0.0000, -0.0193],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.3015,  0.0000,  0.0143,  0.0077,  2.0947,  0.0631,  0.0000,  0.0000,
-        -0.1824, -0.0127,  0.0221,  0.0000, -0.2640, -0.3882,  0.0000,  0.0000,
-         0.0000, -0.0607,  0.0000,  0.0000, -0.0434,  0.1747,  0.0000, -0.8469,
-        -0.0206,  0.0000,  0.0514,  0.0885, -0.0254,  0.0000, -0.5415, -1.2621,
-         0.0000, -0.0893,  0.0000,  0.0000,  0.0000,  0.0000,  0.0840,  0.4033,
-        -0.0428,  0.0000, -0.2535,  0.0000,  0.0000, -0.2706,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0716, -0.2042,  0.0189,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4462,  0.0000,  0.0000, -0.1394,  0.0000, -0.0193],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.9396e-01, -9.4758e-13,  1.8215e-02,  9.7023e-03,  2.0938e+00,
-         6.4198e-02,  5.7358e-22, -4.5291e-15, -1.8744e-01, -1.7959e-02,
-         1.4392e-02,  1.0311e-13, -2.6694e-01, -3.8918e-01, -1.0286e-18,
-        -7.6449e-16, -7.5200e-17, -5.9900e-02, -1.7393e-19,  2.0912e-15,
-        -5.1187e-02,  1.7087e-01, -3.6937e-19, -8.4397e-01, -1.9693e-02,
-         2.9978e-15,  5.5533e-02,  7.5367e-02, -2.7458e-02,  5.3719e-19,
-        -5.4007e-01, -1.2621e+00, -1.3751e-22, -8.5249e-02,  6.8253e-17,
-         3.6827e-15,  2.2768e-18,  0.0000e+00,  8.2562e-02,  4.0136e-01,
-        -5.4306e-02, -3.0410e-16, -2.4833e-01,  5.6033e-19,  1.3705e-16,
-        -2.7231e-01, -1.9338e-14,  6.2397e-08,  6.5052e-14, -3.5672e-20,
-        -1.1542e-10,  6.7050e-02, -2.1400e-01,  1.1176e-02, -1.3072e-11,
-         7.4239e-24, -2.9177e-18,  1.0493e-21,  4.4850e-01, -1.5154e-13,
-         2.8030e-21, -1.4395e-01,  2.3551e-09, -1.9831e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-After Step tensor([ 0.2940,  0.0000,  0.0182,  0.0097,  2.0938,  0.0642,  0.0000,  0.0000,
-        -0.1874, -0.0180,  0.0144,  0.0000, -0.2669, -0.3892,  0.0000,  0.0000,
-         0.0000, -0.0599,  0.0000,  0.0000, -0.0512,  0.1709,  0.0000, -0.8440,
-        -0.0197,  0.0000,  0.0555,  0.0754, -0.0275,  0.0000, -0.5401, -1.2621,
-         0.0000, -0.0852,  0.0000,  0.0000,  0.0000,  0.0000,  0.0826,  0.4014,
-        -0.0543,  0.0000, -0.2483,  0.0000,  0.0000, -0.2723,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0671, -0.2140,  0.0112,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4485,  0.0000,  0.0000, -0.1440,  0.0000, -0.0198],
-       device='cuda:0', grad_fn=<SumBackward1>)
-Before Step tensor([ 0.2940,  0.0000,  0.0182,  0.0097,  2.0938,  0.0642,  0.0000,  0.0000,
-        -0.1874, -0.0180,  0.0144,  0.0000, -0.2669, -0.3892,  0.0000,  0.0000,
-         0.0000, -0.0599,  0.0000,  0.0000, -0.0512,  0.1709,  0.0000, -0.8440,
-        -0.0197,  0.0000,  0.0555,  0.0754, -0.0275,  0.0000, -0.5401, -1.2621,
-         0.0000, -0.0852,  0.0000,  0.0000,  0.0000,  0.0000,  0.0826,  0.4014,
-        -0.0543,  0.0000, -0.2483,  0.0000,  0.0000, -0.2723,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.0000,  0.0671, -0.2140,  0.0112,  0.0000,  0.0000,
-         0.0000,  0.0000,  0.4485,  0.0000,  0.0000, -0.1440,  0.0000, -0.0198],
-       device='cuda:0', grad_fn=<SumBackward1>)
-After Batch tensor([ 2.8364e-01, -8.6356e-13,  2.3995e-02,  1.3488e-02,  2.0936e+00,
-         6.5073e-02,  5.2273e-22, -4.1275e-15, -1.9212e-01, -2.2407e-02,
-         1.6205e-02,  9.3972e-14, -2.6933e-01, -3.8867e-01, -9.3740e-19,
-        -6.9671e-16, -6.8532e-17, -5.2110e-02, -1.5851e-19,  1.9058e-15,
-        -6.0010e-02,  1.6931e-01, -3.3661e-19, -8.3944e-01, -1.8648e-02,
-         2.7320e-15,  5.8916e-02,  6.1481e-02, -3.2871e-02,  4.8956e-19,
-        -5.3908e-01, -1.2617e+00, -1.2532e-22, -8.2074e-02,  6.2201e-17,
-         3.3562e-15,  2.0749e-18,  0.0000e+00,  8.2516e-02,  3.9997e-01,
-        -6.5787e-02, -2.7714e-16, -2.4107e-01,  5.1065e-19,  1.2490e-16,
-        -2.7515e-01, -1.7623e-14,  5.6865e-08,  5.9284e-14, -3.2509e-20,
-        -1.0518e-10,  6.5321e-02, -2.2398e-01,  6.6516e-03, -1.1913e-11,
-         6.7657e-24, -2.6590e-18,  9.5628e-22,  4.5156e-01, -1.3811e-13,
-         2.5545e-21, -1.4627e-01,  2.1462e-09, -2.4427e-02], device='cuda:0',
-       grad_fn=<SumBackward1>)
-Sparsity at the end of epoch 2: 50.00%
-Final Sparsity: 50.00
-Sparsity in Conv2d 2: 1.56%
-Sparsity in Conv2d 8: 1.56%
-Sparsity in Conv2d 11: 1.56%
-Sparsity in Conv2d 14: 1.56%
-Sparsity in Conv2d 17: 1.56%
-Sparsity in Conv2d 21: 0.78%
-Sparsity in Conv2d 24: 0.78%
-Sparsity in Conv2d 27: 0.78%
-Sparsity in Conv2d 30: 0.78%
-Sparsity in Conv2d 33: 0.78%
-Sparsity in Conv2d 37: 0.39%
-Sparsity in Conv2d 40: 0.39%
-Sparsity in Conv2d 43: 0.39%
-Sparsity in Conv2d 46: 0.39%
-Sparsity in Conv2d 49: 0.39%
-Sparsity in Conv2d 53: 0.20%
-Sparsity in Conv2d 56: 0.20%
-Sparsity in Conv2d 59: 0.20%
-Sparsity in Conv2d 62: 0.20%
-Sparsity in Conv2d 65: 0.20%
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
model.model.conv1.weight.sum(dim=(1,2,3))
-
- -
-
-
- -
-
- -
- - - -
-
tensor([ 2.9347e-01, -2.7638e-15, -4.3429e-01,  1.5531e-20, -1.2244e-01,
-         6.7792e-02,  3.4213e-24,  9.2662e-15, -2.5555e-01, -6.6723e-11,
-        -1.1368e-02, -5.4554e-18,  2.7437e-02, -3.6576e-12,  1.6695e-18,
-        -8.0519e-02,  6.7549e-18,  6.4657e-02,  5.7248e-18,  3.1335e-17,
-        -6.9838e-14, -1.6188e-02,  1.2506e-20,  5.0455e-01,  1.3777e-13,
-        -6.4526e-19, -3.7569e-02, -1.2282e-14,  6.2495e-02, -1.4700e-18,
-        -2.6848e-01,  9.4839e-02,  9.6079e-22,  1.5481e-01, -4.7590e-19,
-         2.1518e-14, -7.0799e-16,  0.0000e+00,  1.6172e+00,  5.7085e-01,
-        -6.2181e-02, -3.7426e-01,  1.1096e-01, -6.0660e-16, -5.0897e-22,
-        -1.4613e-01, -2.6145e-12, -1.7860e-08,  3.6786e-10, -3.4189e-17,
-         5.0733e-13,  1.2981e-01, -9.3539e-01, -1.3682e-01, -5.1219e-01,
-        -2.5171e-02, -9.8362e-02, -3.2823e-23, -1.1528e-15, -1.0429e+00,
-        -1.0777e-19, -1.6025e-01,  1.1684e-02,  8.0589e-02],
-       grad_fn=<SumBackward1>)
-
- -
- -
-
- -
- {% endraw %} - -
-
-

Library Agnostic Callback

-
-
-
- {% raw %} - -
-
- -
-
-
learn.__class__.__name__
-
- -
-
-
- -
-
- -
- - - -
-
'Learner'
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
trainer.__class__.__name__
-
- -
-
-
- -
-
- -
- - - -
-
'Trainer'
-
- -
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCB():
-    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        self.end_sparsity = end_sparsity
-        self.granularity, self.method, self.criteria, self.sched_func = granularity, method, criteria, sched_func
-        self.start_sparsity, self.start_epoch, self.end_epoch = start_sparsity, start_epoch, end_epoch
-        self.lth, self.rewind_epoch, self.reset_end = lth, rewind_epoch, reset_end
-        self.model = model
-        self.round_to = round_to
-        self.layer_type = layer_type
-        self.train_iter = 0
-        self.current_sparsity, self.previous_sparsity = 0, 0
-        assert self.start_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
-        print("Starting to init trainer!")
-    
-    
-    def setup(self, n_epoch, learn, n_batches):
-        print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
-        self.end_epoch = n_epoch if self.end_epoch is None else self.end_epoch
-        assert self.end_epoch <= n_epoch, 'Your end_epoch must be smaller than total number of epoch'
-
-        self.model = learn.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
-        self.sparsifier = Sparsifier(self.model, self.granularity, self.method, self.criteria, self.layer_type)
-            
-        self.total_iters = self.end_epoch * n_batches
-        self.start_iter = self.start_epoch * n_batches
-
-    def save_weigths(self, epoch):
-        if epoch == self.rewind_epoch:
-            print(f'Saving Weights at epoch {epoch}')
-            self.sparsifier._save_weights()
-
-    def prune_weights(self, epoch, train_iter):
-        if epoch>=self.start_epoch:
-            if epoch < self.end_epoch: self._set_sparsity(train_iter)
-            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
-
-            if self.lth and self.current_sparsity!=self.previous_sparsity: # If sparsity has changed, the network has been pruned
-                    print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
-                    self.sparsifier._reset_weights()
-
-            self.previous_sparsity = self.current_sparsity
-
-    def prune_gradients(self, epoch):
-        if epoch>=self.start_epoch:
-            self.sparsifier._mask_grad()
-
-    def print_sparsity(self, epoch):
-        print(f'Sparsity at the end of epoch {epoch}: {self.current_sparsity:.2f}%')
-
-    def clean(self):
-        print(f'Final Sparsity: {self.current_sparsity:.2f}')
-        if self.reset_end:
-            self.sparsifier._reset_weights()
-        self.sparsifier._clean_buffers() # Remove buffers at the end of training
-        self.sparsifier.print_sparsity()
-
-    def _set_sparsity(self, train_iter):
-        self.current_sparsity = self.sched_func(start=self.start_sparsity, end=self.end_sparsity, pos=(train_iter-self.start_iter)/(self.total_iters-self.start_iter))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCallback(fastai.callback.all.Callback):
-    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        store_attr()
-        self.cb = SparsifyCB(self.end_sparsity, self.granularity, self.method, self.criteria, self.sched_func, self.start_sparsity, self.start_epoch, self.end_epoch, self.lth, self.rewind_epoch, self.reset_end, self.model, self.round_to, self.layer_type) 
-
-    def before_fit(self):
-        n_batches = math.floor(len(self.learn.dls.dataset)/self.learn.dls.bs)
-        self.cb.setup(self.n_epoch, self.learn, n_batches)
-
-    def before_epoch(self):
-        self.cb.save_weigths(self.epoch)
-
-    def before_batch(self):
-        self.cb.prune_weights(self.epoch, self.train_iter)
-
-    def before_step(self):
-        self.cb.prune_gradients(self.epoch)
-
-    def after_epoch(self):
-        self.cb.print_sparsity(self.epoch)
-
-    def after_fit(self):
-        self.cb.clean()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
class SparsifyCallback(pytorch_lightning.callbacks.Callback):
-    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, layer_type=nn.Conv2d):
-        store_attr('end_sparsity, granularity, method, criteria, sched_func, start_sparsity, start_epoch, end_epoch, lth, rewind_epoch, reset_end, model, round_to, layer_type')
-        self.cb = SparsifyCB(self.end_sparsity, self.granularity, self.method, self.criteria, self.sched_func, self.start_sparsity, self.start_epoch, self.end_epoch, self.lth, self.rewind_epoch, self.reset_end, self.model, self.round_to, self.layer_type) 
-        self.train_iter=0
-        
-    def on_fit_start(self, trainer, pl_module):
-        n_batches = math.floor(len(trainer.datamodule.dataset_train)/trainer.datamodule.batch_size)
-        self.cb.setup(trainer.max_epochs, trainer, n_batches)
-        
-    def on_fit_end(self, trainer, pl_module):
-        self.cb.clean()
-        
-    def on_train_epoch_start(self, trainer, pl_module):
-        self.cb.save_weigths(trainer.current_epoch)
-        
-    def on_train_epoch_end(self, trainer, pl_module):
-        self.cb.print_sparsity(trainer.current_epoch)
-
-    def on_batch_start(self, trainer, pl_module):
-        self.train_iter+=1
-        self.cb.prune_weights(trainer.current_epoch, self.train_iter)
-
-    def on_after_backward(self, trainer, pl_module):
-        self.cb.prune_gradients(trainer.current_epoch)
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallback(end_sparsity=50, granularity='filter', method='local', criteria=large_final, sched_func=sched_onecycle)
-
- -
-
-
- -
-
- -
- -
-
Starting to init trainer!
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
from torchmetrics.functional.classification import accuracy
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
model = LitResnet(lr=0.05)
-model.datamodule = cifar10_dm
-
-trainer = Trainer(
-    progress_bar_refresh_rate=10,
-    max_epochs=2,
-    gpus=AVAIL_GPUS,
-    callbacks=[sp_cb],
-)
-
-trainer.fit(model, cifar10_dm)
-#trainer.test(model, datamodule=cifar10_dm)
-
- -
-
-
- -
-
- -
- -
-
/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:91: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=10)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
-  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
-GPU available: True, used: True
-TPU available: False, using: 0 TPU cores
-IPU available: False, using: 0 IPUs
-
-
-
- -
- -
-
Files already downloaded and verified
-Files already downloaded and verified
-
-
-
- -
- -
-
/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py:115: LightningDeprecationWarning: DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
-  "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
-/home/HubensN/miniconda3/envs/deep/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py:134: LightningDeprecationWarning: DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.
-  "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
-LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
-
-  | Name  | Type   | Params
----------------------------------
-0 | model | ResNet | 11.2 M
----------------------------------
-11.2 M    Trainable params
-0         Non-trainable params
-11.2 M    Total params
-44.696    Total estimated model params size (MB)
-
-
-
- -
- -
-
Pruning of filter until a sparsity of 50%
-
-
-
- -
- -
-
Global seed set to 7
-
-
-
- -
- -
-
Saving Weights at epoch 0
-Sparsity at the end of epoch 0: 37.00%
-Sparsity at the end of epoch 1: 50.00%
-Final Sparsity: 50.00
-Sparsity in Conv2d 2: 1.56%
-Sparsity in Conv2d 8: 1.56%
-Sparsity in Conv2d 11: 1.56%
-Sparsity in Conv2d 14: 1.56%
-Sparsity in Conv2d 17: 1.56%
-Sparsity in Conv2d 21: 0.78%
-Sparsity in Conv2d 24: 0.78%
-Sparsity in Conv2d 27: 0.78%
-Sparsity in Conv2d 30: 0.78%
-Sparsity in Conv2d 33: 0.78%
-Sparsity in Conv2d 37: 0.39%
-Sparsity in Conv2d 40: 0.39%
-Sparsity in Conv2d 43: 0.39%
-Sparsity in Conv2d 46: 0.39%
-Sparsity in Conv2d 49: 0.39%
-Sparsity in Conv2d 53: 0.20%
-Sparsity in Conv2d 56: 0.20%
-Sparsity in Conv2d 59: 0.20%
-Sparsity in Conv2d 62: 0.20%
-Sparsity in Conv2d 65: 0.20%
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
from fastai.vision.all import *
-from fastai.callback.all import *
-from fasterai.sparse.sparsifier import *
-from fasterai.sparse.criteria import *
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
path = untar_data(URLs.PETS)
-files = get_image_files(path/"images")
-
-def label_func(f): return f[0].isupper()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn = cnn_learner(dls, resnet18, metrics=accuracy)
-learn.unfreeze()
-
- -
-
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
sp_cb = SparsifyCallback(end_sparsity=50, granularity='filter', method='local', criteria=large_final, sched_func=sched_cos)
-
- -
-
-
- -
-
- -
- -
-
Starting to init trainer!
-
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.fit_one_cycle(2, cbs=sp_cb)
-
- -
-
-
- -
-
- -
- -
-
----------------------------------------------------------------------------
-AttributeError                            Traceback (most recent call last)
-/tmp/ipykernel_423329/1783024704.py in <module>
-----> 1 learn.fit_one_cycle(2, cbs=sp_cb)
-
-~/miniconda3/envs/deep/lib/python3.7/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
-    114     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
-    115               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
---> 116     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
-    117 
-    118 # Cell
-
-~/miniconda3/envs/deep/lib/python3.7/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
-    213 
-    214     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
---> 215         with self.added_cbs(cbs):
-    216             if reset_opt or not self.opt: self.create_opt()
-    217             if wd is None: wd = self.wd
-
-~/miniconda3/envs/deep/lib/python3.7/contextlib.py in __enter__(self)
-    110         del self.args, self.kwds, self.func
-    111         try:
---> 112             return next(self.gen)
-    113         except StopIteration:
-    114             raise RuntimeError("generator didn't yield") from None
-
-~/miniconda3/envs/deep/lib/python3.7/site-packages/fastai/learner.py in added_cbs(self, cbs)
-    128     @contextmanager
-    129     def added_cbs(self, cbs):
---> 130         self.add_cbs(cbs)
-    131         try: yield
-    132         finally: self.remove_cbs(cbs)
-
-~/miniconda3/envs/deep/lib/python3.7/site-packages/fastai/learner.py in add_cbs(self, cbs)
-    104 
-    105     def add_cbs(self, cbs):
---> 106         L(cbs).map(self.add_cb)
-    107         return self
-    108 
-
-~/miniconda3/envs/deep/lib/python3.7/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
-    152     def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
-    153 
---> 154     def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
-    155     def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
-    156     def argfirst(self, f, negate=False): return first(i for i,o in self.enumerate() if f(o))
-
-~/miniconda3/envs/deep/lib/python3.7/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
-    664     res = map(g, iterable)
-    665     if gen: return res
---> 666     return list(res)
-    667 
-    668 # Cell
-
-~/miniconda3/envs/deep/lib/python3.7/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
-    649             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
-    650         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
---> 651         return self.func(*fargs, **kwargs)
-    652 
-    653 # Cell
-
-~/miniconda3/envs/deep/lib/python3.7/site-packages/fastai/learner.py in add_cb(self, cb)
-    114         if isinstance(cb, type): cb = cb()
-    115         cb.learn = self
---> 116         setattr(self, cb.name, cb)
-    117         self.cbs.append(cb)
-    118         return self
-
-AttributeError: 'SparsifyCallback' object has no attribute 'name'
-
-
- -
-
- -
- {% endraw %} - - {% raw %} - -
-
- -
-
-
learn.model[0][0].weight.mean(dim=(1,2,3))
-
- -
-
-
- -
-
- -
- - - -
-
tensor([-5.5175e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-         4.7389e-03, -2.2442e-03,  0.0000e+00, -7.6861e-04,  0.0000e+00,
-         7.8788e-03,  3.2037e-03,  0.0000e+00,  0.0000e+00, -4.2978e-04,
-         3.6307e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
-        -4.3683e-03,  0.0000e+00,  0.0000e+00, -5.1489e-03, -1.7260e-03,
-         3.9044e-03,  2.3079e-03,  0.0000e+00, -1.3950e-03,  0.0000e+00,
-         1.9790e-03,  6.6928e-03,  0.0000e+00,  6.8169e-03,  0.0000e+00,
-         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2592e-03,
-        -1.2275e-04,  2.2800e-03, -4.0261e-03,  3.7364e-03,  3.0080e-03,
-         0.0000e+00,  0.0000e+00,  3.4525e-03,  0.0000e+00, -3.2564e-03,
-         1.9701e-03, -6.7768e-03,  0.0000e+00, -2.9157e-04,  0.0000e+00,
-        -4.0615e-03,  0.0000e+00,  0.0000e+00,  0.0000e+00, -9.0066e-05,
-        -7.1053e-03,  0.0000e+00,  0.0000e+00, -1.7526e-03], device='cuda:0',
-       grad_fn=<MeanBackward1>)
-
- -
- -
-
- -
- {% endraw %} - -
- - diff --git a/fasterai/__init__.py b/fasterai/__init__.py index 788da1f..fe404ae 100644 --- a/fasterai/__init__.py +++ b/fasterai/__init__.py @@ -1 +1 @@ -__version__ = "0.2.4" +__version__ = "0.2.5" diff --git a/fasterai/_modidx.py b/fasterai/_modidx.py index 94d613c..0657709 100644 --- a/fasterai/_modidx.py +++ b/fasterai/_modidx.py @@ -2,8 +2,8 @@ d = { 'settings': { 'branch': 'master', 'doc_baseurl': '/fasterai/', - 'doc_host': 'https://nathanhubens.github.io', - 'git_url': 'https://github.com/nathanhubens/fasterai/tree/master/', + 'doc_host': 'https://FasterAI-Labs.github.io', + 'git_url': 'https://github.com/FasterAI-Labs/fasterai/tree/master/', 'lib_path': 'fasterai'}, 'syms': { 'fasterai.core.all': {}, 'fasterai.core.criteria': { 'fasterai.core.criteria.Criteria': ('core.criteria.html#criteria', 'fasterai/core/criteria.py'), diff --git a/nbs/nbdev.yml b/nbs/nbdev.yml index 502df65..82d8225 100644 --- a/nbs/nbdev.yml +++ b/nbs/nbdev.yml @@ -3,7 +3,7 @@ project: website: title: "fasterai" - site-url: "https://nathanhubens.github.io/fasterai/" + site-url: "https://FasterAI-Labs.github.io/fasterai/" description: "A library to make neural networks lighter and faster with fastai" repo-branch: master - repo-url: "https://github.com/nathanhubens/fasterai/tree/master/" + repo-url: "https://github.com/FasterAI-Labs/fasterai/tree/master/"