diff --git a/fluent_pose_synthesis/.style.yapf b/fluent_pose_synthesis/.style.yapf new file mode 100644 index 0000000..4029262 --- /dev/null +++ b/fluent_pose_synthesis/.style.yapf @@ -0,0 +1,9 @@ +[style] +based_on_style = pep8 +column_limit = 120 +split_before_named_assigns = false +coalesce_brackets = true +split_before_expression_after_opening_paren = false +split_arguments_when_comma_terminated = false +each_dict_entry_on_separate_line = false +indent_dictionary_value = false \ No newline at end of file diff --git a/fluent_pose_synthesis/config/default.json b/fluent_pose_synthesis/config/default.json index baa2c39..79d0bd5 100644 --- a/fluent_pose_synthesis/config/default.json +++ b/fluent_pose_synthesis/config/default.json @@ -11,21 +11,31 @@ "dropout": 0.2, "activation": "gelu", "ablation": null, - "legacy": false + "legacy": false, + "history_len": 5 }, "diff": { "noise_schedule": "cosine", - "diffusion_steps": 32, + "diffusion_steps": 8, "sigma_small": true }, "trainer": { - "epoch": 300, + "epoch": 500, "lr": 1e-4, "batch_size": 1024, - "cond_mask_prob": 0, + "cond_mask_prob": 0.15, "use_loss_mse": true, "use_loss_vel": true, + "use_loss_accel": false, + "lambda_vel": 1.0, + "lambda_accel": 1.0, + "guidance_scale": 2.0, "workers": 4, - "load_num": 200 + "load_num": -1, + "validation_max_len": 160, + "validation_chunk_size": 40, + "validation_stop_threshold": 1e-4, + "eval_freq": 1, + "use_amp": false } } \ No newline at end of file diff --git a/fluent_pose_synthesis/config/option.py b/fluent_pose_synthesis/config/option.py index 200b782..7e4f710 100644 --- a/fluent_pose_synthesis/config/option.py +++ b/fluent_pose_synthesis/config/option.py @@ -10,22 +10,32 @@ def add_model_args(parser): parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads.') parser.add_argument('--num_layers', type=int, default=4, help='Number of model layers.') + def add_diffusion_args(parser): parser.add_argument('--noise_schedule', type=str, default='cosine', help='Noise schedule: "cosine", "linear", etc.') - parser.add_argument('--diffusion_steps', type=int, default=4, help='Number of diffusion steps.') + parser.add_argument('--diffusion_steps', type=int, default=8, help='Number of diffusion steps.') parser.add_argument('--sigma_small', action='store_true', help='Use small sigma values.') + def add_train_args(parser): - parser.add_argument('--epoch', type=int, default=300, help='Number of training epochs.') + parser.add_argument('--epoch', type=int, default=500, help='Number of training epochs.') parser.add_argument('--lr', type=float, default=0.00005, help='Learning rate.') parser.add_argument('--lr_anneal_steps', type=int, default=0, help='Annealing steps.') - parser.add_argument('--weight_decay', type=float, default=0.00, help='Weight decay.') + parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay.') parser.add_argument('--batch_size', type=int, default=1024, help='Batch size.') - parser.add_argument('--cond_mask_prob', type=float, default=0, help='Conditioning mask probability.') + parser.add_argument('--cond_mask_prob', type=float, default=0.15, help='Conditioning mask probability.') parser.add_argument('--workers', type=int, default=4, help='Data loader workers.') - parser.add_argument('--ema', default=False, type=bool, help='Use Exponential Moving Average (EMA) for model parameters.') + parser.add_argument('--ema', default=False, type=bool, + help='Use Exponential Moving Average (EMA) for model parameters.') 
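The new defaults select a cosine noise schedule with 8 diffusion steps. As a rough, self-contained illustration of what such a schedule produces (the project takes its actual schedule from CAMDM's GaussianDiffusion; the helper below is hypothetical), a squared-cosine alpha-bar curve can be converted into per-step betas like this:

```python
# Illustrative only: cosine alpha-bar schedule in the style of Nichol & Dhariwal (2021).
# The real schedule comes from CAMDM's GaussianDiffusion; this helper is hypothetical.
import math
import numpy as np

def cosine_betas(num_steps: int, s: float = 0.008, max_beta: float = 0.999) -> np.ndarray:
    """Derive one beta per diffusion step from a squared-cosine alpha_bar curve."""
    def alpha_bar(t: float) -> float:
        return math.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    betas = []
    for i in range(num_steps):
        t1, t2 = i / num_steps, (i + 1) / num_steps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

print(cosine_betas(8))  # 8 steps, matching the new diffusion_steps default
```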
parser.add_argument('--lambda_vel', type=float, default=1.0, help='Weight factor for the velocity loss term.') + parser.add_argument('--use_loss_vel', action='store_true', default=True, help='Enable velocity loss term.') + parser.add_argument('--use_loss_accel', action='store_true', default=False, help='Enable acceleration loss term.') + parser.add_argument('--lambda_accel', type=float, default=1.0, help='Weight factor for the acceleration loss term.') + parser.add_argument('--guidance_scale', type=float, default=2.0, + help='Classifier-free guidance scale for inference.') parser.add_argument('--load_num', type=int, default=-1, help='Number of models to load.') + parser.add_argument('--use_amp', action='store_true', default=False, help='Use mixed precision training (AMP).') + parser.add_argument('--eval_freq', type=int, default=1, help='Frequency of evaluation during training.') def config_parse(args): @@ -49,14 +59,17 @@ def config_parse(args): config.trainer.lr_anneal_steps = args.lr_anneal_steps config.trainer.weight_decay = args.weight_decay config.trainer.batch_size = args.batch_size - config.trainer.ema = True #if args.ema else config.trainer.ema + config.trainer.ema = True #if args.ema else config.trainer.ema config.trainer.cond_mask_prob = args.cond_mask_prob config.trainer.workers = args.workers config.trainer.save_freq = int(config.trainer.epoch // 5) config.trainer.lambda_vel = args.lambda_vel + config.trainer.use_loss_vel = args.use_loss_vel + config.trainer.use_loss_accel = args.use_loss_accel + config.trainer.lambda_accel = args.lambda_accel + config.trainer.guidance_scale = args.guidance_scale config.trainer.load_num = args.load_num - # Save directory data_prefix = args.data.split('/')[-1].split('.')[0] config.save = f'{args.save}/{args.name}_{data_prefix}' if 'debug' not in config.name else f'{args.save}/{args.name}' diff --git a/fluent_pose_synthesis/core/models.py b/fluent_pose_synthesis/core/models.py index 9858ec8..5cfd305 100644 --- a/fluent_pose_synthesis/core/models.py +++ b/fluent_pose_synthesis/core/models.py @@ -9,22 +9,19 @@ class OutputProcessMLP(nn.Module): """ Output process for the Sign Language Pose Diffusion model. """ - def __init__(self, input_feats, latent_dim, njoints, nfeats, hidden_dim=512): # add hidden_dim as parameter + + def __init__(self, input_feats, latent_dim, njoints, nfeats, hidden_dim=512): # add hidden_dim as parameter super().__init__() self.input_feats = input_feats self.latent_dim = latent_dim self.njoints = njoints self.nfeats = nfeats - self.hidden_dim = hidden_dim # store hidden dimension + self.hidden_dim = hidden_dim # store hidden dimension # MLP layers - self.mlp = nn.Sequential( - nn.Linear(self.latent_dim, self.hidden_dim), - nn.SiLU(), - nn.Linear(self.hidden_dim, self.hidden_dim // 2), - nn.SiLU(), - nn.Linear(self.hidden_dim // 2, self.input_feats) - ) + self.mlp = nn.Sequential(nn.Linear(self.latent_dim, self.hidden_dim), nn.SiLU(), + nn.Linear(self.hidden_dim, self.hidden_dim // 2), nn.SiLU(), + nn.Linear(self.hidden_dim // 2, self.input_feats)) def forward(self, output): nframes, bs, d = output.shape @@ -39,25 +36,11 @@ class SignLanguagePoseDiffusion(nn.Module): Sign Language Pose Diffusion model. 
""" - def __init__( - self, - input_feats: int, - chunk_len: int, - keypoints: int, - dims: int, - latent_dim: int = 256, - ff_size: int = 1024, - num_layers: int = 8, - num_heads: int = 4, - dropout: float = 0.2, - ablation: Optional[str] = None, - activation: str = "gelu", - legacy: bool = False, - arch: str = "trans_enc", - cond_mask_prob: float = 0, - device: Optional[torch.device] = None, - batch_first: bool = True - ): + def __init__(self, input_feats: int, chunk_len: int, keypoints: int, dims: int, latent_dim: int = 256, + ff_size: int = 1024, num_layers: int = 8, num_heads: int = 4, dropout: float = 0.2, + ablation: Optional[str] = None, activation: str = "gelu", legacy: bool = False, + arch: str = "trans_enc", cond_mask_prob: float = 0, device: Optional[torch.device] = None, + batch_first: bool = True): """ Args: input_feats (int): Number of input features (keypoints * dimensions). @@ -105,37 +88,33 @@ def __init__( # Define sequence encoder based on chosen architecture if self.arch == "trans_enc": print(f"Initializing Transformer Encoder (batch_first={self.batch_first})") - encoder_layer = nn.TransformerEncoderLayer( - d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size, - dropout=dropout, activation=activation, batch_first=self.batch_first - ) + encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size, + dropout=dropout, activation=activation, + batch_first=self.batch_first) self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) elif self.arch == "trans_dec": print(f"Initializing Transformer Decoder (batch_first={self.batch_first})") - decoder_layer = nn.TransformerDecoderLayer( - d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size, - dropout=dropout, activation=activation, batch_first=self.batch_first - ) + decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size, + dropout=dropout, activation=activation, + batch_first=self.batch_first) self.sequence_encoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers) elif self.arch == "gru": - print("Initializing GRU Encoder (batch_first=True)") - self.sequence_encoder = nn.GRU( - latent_dim, latent_dim, num_layers=num_layers, batch_first=True - ) + print("Initializing GRU Encoder (batch_first=True)") + self.sequence_encoder = nn.GRU(latent_dim, latent_dim, num_layers=num_layers, batch_first=True) else: raise ValueError("Please choose correct architecture [trans_enc, trans_dec, gru]") # Pose projection: projects latent representation back to pose space. # The OutputProcess returns (B, keypoints, dims, T); apply a post_transform to get (B, T, keypoints, dims) - self.pose_projection = OutputProcessMLP(input_feats, latent_dim, keypoints, dims, hidden_dim=512) + self.pose_projection = OutputProcessMLP(input_feats, latent_dim, keypoints, dims, hidden_dim=1024) self.to(self.device) def forward( - self, - fluent_clip: torch.Tensor, # (B, K, D, T_chunk) - disfluent_seq: torch.Tensor, # (B, K, D, T_disfl) - t: torch.Tensor, # (B,) - previous_output: Optional[torch.Tensor] = None # (B, K, D, T_hist) + self, + fluent_clip: torch.Tensor, # (B, K, D, T_chunk) + disfluent_seq: torch.Tensor, # (B, K, D, T_disfl) + t: torch.Tensor, # (B,) + previous_output: Optional[torch.Tensor] = None # (B, K, D, T_hist) ) -> torch.Tensor: # # --- DEBUG: Print Initial Input Shapes --- @@ -158,15 +137,15 @@ def forward( T_chunk = fluent_clip.shape[-1] # 1. 
Embed Timestep - _t_emb_raw = self.embed_timestep(t) # Expected (B, D) + _t_emb_raw = self.embed_timestep(t) # Expected (B, D) # print(f"[DEBUG FWD 1a] Raw t_emb shape: {_t_emb_raw.shape}") - t_emb = _t_emb_raw.permute(1, 0, 2) + t_emb = _t_emb_raw.permute(1, 0, 2).contiguous() # print(f"[DEBUG FWD 1b] Final t_emb shape: {t_emb.shape}") # 2. Embed Disfluent Sequence (Condition) - _disfluent_emb_raw = self.disfluent_encoder(disfluent_seq) # Expected (T_disfl, B, D) + _disfluent_emb_raw = self.disfluent_encoder(disfluent_seq) # Expected (T_disfl, B, D) # print(f"[DEBUG FWD 2a] Raw disfluent_emb shape: {_disfluent_emb_raw.shape}") - disfluent_emb = _disfluent_emb_raw.permute(1, 0, 2) # Expected (B, T_disfl, D) + disfluent_emb = _disfluent_emb_raw.permute(1, 0, 2).contiguous() # Expected (B, T_disfl, D) # print(f"[DEBUG FWD 2b] Final disfluent_emb shape: {disfluent_emb.shape}") # 3. Embed Previous Output (History), if available @@ -174,9 +153,9 @@ def forward( # print("[DEBUG FWD 3a] Processing previous_output...") if previous_output is not None and previous_output.shape[-1] > 0: # print(f"[DEBUG FWD 3b] History Input shape: {previous_output.shape}") - _prev_out_emb_raw = self.fluent_encoder(previous_output) # Expected (T_hist, B, D) + _prev_out_emb_raw = self.fluent_encoder(previous_output) # Expected (T_hist, B, D) # print(f"[DEBUG FWD 3c] Raw prev_out_emb shape: {_prev_out_emb_raw.shape}") - prev_out_emb = _prev_out_emb_raw.permute(1, 0, 2) # Expected (B, T_hist, D) + prev_out_emb = _prev_out_emb_raw.permute(1, 0, 2).contiguous() # Expected (B, T_hist, D) # print(f"[DEBUG FWD 3d] Final prev_out_emb shape: {prev_out_emb.shape}") embeddings_to_concat.append(prev_out_emb) else: @@ -184,9 +163,9 @@ def forward( pass # 4. Embed Current Fluent Clip (Noisy Target 'x') - _fluent_emb_raw = self.fluent_encoder(fluent_clip) # Expected (T_chunk, B, D) + _fluent_emb_raw = self.fluent_encoder(fluent_clip) # Expected (T_chunk, B, D) # print(f"[DEBUG FWD 4a] Raw fluent_emb shape: {_fluent_emb_raw.shape}") - fluent_emb = _fluent_emb_raw.permute(1, 0, 2) # Expected (B, T_chunk, D) + fluent_emb = _fluent_emb_raw.permute(1, 0, 2).contiguous() # Expected (B, T_chunk, D) # print(f"[DEBUG FWD 4b] Final fluent_emb shape: {fluent_emb.shape}") embeddings_to_concat.append(fluent_emb) @@ -198,33 +177,33 @@ def forward( # print(f"[DEBUG FWD 6a] xseq shape before PositionalEncoding: {xseq.shape}") # Adapt based on PositionalEncoding expectation (T, B, D) vs batch_first if self.batch_first: - xseq_permuted = xseq.permute(1, 0, 2) # (T_total, B, D) + xseq_permuted = xseq.permute(1, 0, 2).contiguous() # (T_total, B, D) # print(f"[DEBUG FWD 6b] xseq permuted for PosEnc: {xseq_permuted.shape}") xseq_encoded = self.sequence_pos_encoder(xseq_permuted) # print(f"[DEBUG FWD 6c] xseq after PosEnc: {xseq_encoded.shape}") - xseq = xseq_encoded.permute(1, 0, 2) # Back to (B, T_total, D) + xseq = xseq_encoded.permute(1, 0, 2) # Back to (B, T_total, D) # print(f"[DEBUG FWD 6d] xseq permuted back: {xseq.shape}") else: - # If not batch_first, assume xseq should be (T, B, D) already - # Need to adjust concatenation and permutations above if batch_first=False - xseq = xseq.permute(1, 0, 2) # Assume needs (T, B, D) + # If not batch_first, assume xseq should be (T, B, D) already + # Need to adjust concatenation and permutations above if batch_first=False + xseq = xseq.permute(1, 0, 2) # Assume needs (T, B, D) # print(f"[DEBUG FWD 6b] xseq permuted for PosEnc (batch_first=False): {xseq.shape}") - xseq = self.sequence_pos_encoder(xseq) - # 
print(f"[DEBUG FWD 6c] xseq after PosEnc (batch_first=False): {xseq.shape}") - # Keep as (T, B, D) if encoder needs it + xseq = self.sequence_pos_encoder(xseq) + # print(f"[DEBUG FWD 6c] xseq after PosEnc (batch_first=False): {xseq.shape}") + # Keep as (T, B, D) if encoder needs it # 7. Process through sequence encoder # print(f"[DEBUG FWD 7a] Input to sequence_encoder ({self.arch}) shape: {xseq.shape}") if self.arch == "trans_enc": - x_encoded = self.sequence_encoder(xseq) + x_encoded = self.sequence_encoder(xseq) elif self.arch == "gru": - x_encoded, _ = self.sequence_encoder(xseq) + x_encoded, _ = self.sequence_encoder(xseq) elif self.arch == "trans_dec": - memory = xseq - tgt = xseq - x_encoded = self.sequence_encoder(tgt=tgt, memory=memory) + memory = xseq + tgt = xseq + x_encoded = self.sequence_encoder(tgt=tgt, memory=memory) else: - raise ValueError("Unsupported architecture") + raise ValueError("Unsupported architecture") # print(f"[DEBUG FWD 7b] Output from sequence_encoder shape: {x_encoded.shape}") # 8. Extract the output corresponding to the target fluent_clip @@ -249,30 +228,11 @@ def forward( return output - - # def mingyi_forward( - # self, fluent_clip: torch.Tensor, disfluent_seq: torch.Tensor, t: torch.Tensor - # ) -> torch.Tensor: - - # batch_size, keypoints, dims, time = fluent_clip.shape - - # t_emb = self.embed_timestep(t) - # disfluent_emb = self.disfluent_encoder(disfluent_seq) - # fluent_emb = self.fluent_encoder(fluent_clip) - - # xseq = torch.cat((t_emb, disfluent_emb, fluent_emb), axis=0) - # xseq = self.sequence_pos_encoder(xseq) - - # x_out = self.sequence_encoder(xseq)[:time] - # output = self.pose_projection(x_out) - - # return output - def interface( - self, - fluent_clip: torch.Tensor, # (B, K, D, T_chunk) - t: torch.Tensor, # (B,) - y: dict[str, torch.Tensor] # Conditions dict + self, + fluent_clip: torch.Tensor, # (B, K, D, T_chunk) + t: torch.Tensor, # (B,) + y: dict[str, torch.Tensor] # Conditions dict ) -> torch.Tensor: """ Interface for Classifier-Free Guidance (CFG). Handles previous_output. 
@@ -283,13 +243,8 @@ def interface( previous_output = y.get("previous_output", None) # Apply CFG: randomly drop the condition with probability cond_mask_prob - keep_batch_idx = torch.rand(batch_size, device=disfluent_seq.device) < (1-self.cond_mask_prob) + keep_batch_idx = torch.rand(batch_size, device=disfluent_seq.device) < (1 - self.cond_mask_prob) disfluent_seq = disfluent_seq * keep_batch_idx.view((batch_size, 1, 1, 1)) # Call the forward function - return self.forward( - fluent_clip=fluent_clip, - disfluent_seq=disfluent_seq, - t=t, - previous_output=previous_output - ) \ No newline at end of file + return self.forward(fluent_clip=fluent_clip, disfluent_seq=disfluent_seq, t=t, previous_output=previous_output) diff --git a/fluent_pose_synthesis/core/training.py b/fluent_pose_synthesis/core/training.py index 13f0469..cdde02b 100644 --- a/fluent_pose_synthesis/core/training.py +++ b/fluent_pose_synthesis/core/training.py @@ -1,21 +1,53 @@ # pylint: disable=protected-access, arguments-renamed -from typing import Optional, Tuple, Dict, Any +from typing import Optional, Tuple, Dict, Any, List +from pathlib import Path +import itertools +import time import numpy as np import torch from torch import Tensor from torch.utils.data import DataLoader +from torch.utils.data import Subset, DataLoader +from tqdm import tqdm +from torch.amp import GradScaler, autocast +import torch.nn as nn from pose_format import Pose from pose_format.torch.masked.collator import zero_pad_collator from pose_format.numpy.pose_body import NumPyPoseBody from pose_format.utils.generic import normalize_pose_size from pose_anonymization.data.normalization import unnormalize_mean_std +from pose_evaluation.metrics.distance_metric import DistanceMetric +from pose_evaluation.metrics.dtw_metric import DTWDTAIImplementationDistanceMeasure +from pose_evaluation.metrics.pose_processors import NormalizePosesProcessor from CAMDM.diffusion.gaussian_diffusion import GaussianDiffusion from CAMDM.network.training import BaseTrainingPortal from CAMDM.utils.common import mkdir +class _ConditionalWrapper(nn.Module): + """Wraps a base model and a fixed conditioning dict, forwarding only (x, t).""" + + def __init__(self, base_model: nn.Module, cond: dict): + super().__init__() + self.base_model = base_model + self.cond = cond + + def forward(self, x, t, **kwargs): + # Ignore incoming kwargs, use fixed cond + return self.base_model.interface(x, t, self.cond) + + +def move_to_device(val, device): + if torch.is_tensor(val): + return val.to(device) + elif isinstance(val, dict): + return {k: move_to_device(v, device) for k, v in val.items()} + else: + return val + + def masked_l2_per_sample(x: Tensor, y: Tensor, mask: Optional[Tensor] = None, reduce: bool = True) -> Tensor: """ Compute masked L2 loss per sample. Correctly handles division by zero for fully masked samples. @@ -26,28 +58,20 @@ def masked_l2_per_sample(x: Tensor, y: Tensor, mask: Optional[Tensor] = None, re True = masked (invalid), False = valid. reduce: Whether to average the per-sample loss over the batch dimension. 
""" - diff_sq = (x - y) ** 2 # (B, K, D, T) + diff_sq = (x - y)**2 # (B, K, D, T) if mask is not None: mask = mask.bool() # Ensure boolean type # Invert mask: False (valid) -> 1.0, True (masked) -> 0.0 valid_mask_elements = (~mask).float() - # Apply mask to zero out loss contribution from invalid elements diff_sq = diff_sq * valid_mask_elements else: - # If no mask is provided, all elements are considered valid valid_mask_elements = torch.ones_like(diff_sq) - # Sum squared errors over all dimensions except batch for each sample - per_sample_loss_sum = diff_sq.flatten(start_dim=1).sum(dim=1) # Shape: (B,) + per_sample_loss_sum = diff_sq.flatten(start_dim=1).sum(dim=1) # Shape: (B,) - # Count the number of valid elements for each sample - valid_elements_count = valid_mask_elements.flatten(start_dim=1).sum(dim=1) # Shape: (B,) - - # Compute mean squared error per sample - # Clamp denominator to avoid division by zero (0/0 = NaN). - # If valid_elements_count is 0, per_sample_loss_sum is also 0, resulting in 0 loss. - per_sample_loss = per_sample_loss_sum / valid_elements_count.clamp(min=1) # Shape: (B,) + valid_elements_count = valid_mask_elements.flatten(start_dim=1).sum(dim=1) # Shape: (B,) + per_sample_loss = per_sample_loss_sum / valid_elements_count.clamp(min=1) # Shape: (B,) if reduce: # Return the average loss across the batch @@ -58,15 +82,17 @@ def masked_l2_per_sample(x: Tensor, y: Tensor, mask: Optional[Tensor] = None, re class PoseTrainingPortal(BaseTrainingPortal): + def __init__( self, config: Any, model: torch.nn.Module, diffusion: GaussianDiffusion, - dataloader: DataLoader, + dataloader: DataLoader, # Training dataloader logger: Optional[Any], tb_writer: Optional[Any], - finetune_loader: Optional[DataLoader] = None, + validation_dataloader: Optional[DataLoader] = None, + prior_loader: Optional[DataLoader] = None, ): """ Training portal specialized for pose diffusion tasks. @@ -77,19 +103,44 @@ def __init__( dataloader: The main training dataloader. logger: Logger instance (optional). tb_writer: TensorBoard writer (optional). - finetune_loader: Optional finetuning dataloader. + validation_dataloader: Optional validation dataloader. + prior_loader: Optional prior dataloader. 
""" - super().__init__( - config, model, diffusion, dataloader, logger, tb_writer, finetune_loader - ) + super().__init__(config, model, diffusion, dataloader, logger, tb_writer, prior_loader) self.pose_header = None self.device = config.device + self.validation_dataloader = validation_dataloader + self.best_validation_metric = float("inf") + + # Initialize DTW metric calculator + default_dtw_dist_val = 0.0 + self.validation_metric_calculator = DistanceMetric( + name="Validation DTW", + distance_measure=DTWDTAIImplementationDistanceMeasure( + name="dtaiDTW", + use_fast=True, + default_distance=default_dtw_dist_val, + ), + pose_preprocessors=[NormalizePosesProcessor()], + ) + self.logger.info(f"Initialized DTW metric with default_distance: {default_dtw_dist_val}") + + # Store normalization statistics from the training dataset for unnormalization + self.data_input_mean = torch.tensor(dataloader.dataset.input_mean, device=self.device, + dtype=torch.float32).squeeze() + self.data_input_std = torch.tensor(dataloader.dataset.input_std, device=self.device, + dtype=torch.float32).squeeze() + + # Store pose_header from validation dataset (for saving poses) + self.val_pose_header = self.validation_dataloader.dataset.pose_header + if self.val_pose_header: + self.logger.info("Pose header loaded from validation dataset.") def diffuse( self, - x_start: Tensor, # Target fluent chunk (from dataloader['data']) + x_start: Tensor, # Target fluent chunk (from dataloader['data']) t: Tensor, # Diffusion timesteps - cond: Dict[str, Tensor], # Conditioning inputs (from dataloader['conditions']) + cond: Dict[str, Tensor], # Conditioning inputs (from dataloader['conditions']) noise: Optional[Tensor] = None, return_loss: bool = False, ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: @@ -106,29 +157,28 @@ def diffuse( return_loss: Whether to compute and return training losses. """ # 1. Permute x_start from (B, T_chunk, K, D) to (B, K, D, T_chunk) - x_start = x_start.permute(0, 2, 3, 1).to(self.device) # (B, K, D, T_chunk) + x_start = x_start.permute(0, 2, 3, 1).to(self.device) # (B, K, D, T_chunk) if noise is None: noise = torch.randn_like(x_start) # 2. Apply forward diffusion process: q_sample(x_start, t) -> x_t - x_t = self.diffusion.q_sample(x_start, t.to(self.device), noise=noise) # (B, K, D, T_chunk) + x_t = self.diffusion.q_sample(x_start, t.to(self.device), noise=noise) # (B, K, D, T_chunk) # 3. Prepare conditions for the model processed_cond = {} for key, val in cond.items(): - processed_cond[key] = val.to(self.device) + processed_cond[key] = move_to_device(val, self.device) # Permute sequence conditions to (B, K, D, T) expected by MotionProcess encoders - processed_cond["input_sequence"] = processed_cond["input_sequence"].permute(0, 2, 3, 1) # (B, K, D, T_disfl) - processed_cond["previous_output"] = processed_cond["previous_output"].permute(0, 2, 3, 1) # (B, K, D, T_hist) + processed_cond["input_sequence"] = processed_cond["input_sequence"].permute(0, 2, 3, 1) # (B, K, D, T_disfl) + processed_cond["previous_output"] = processed_cond["previous_output"].permute(0, 2, 3, 1) # (B, K, D, T_hist) # 4. Call the model's interface model_output = self.model.interface(x_t, self.diffusion._scale_timesteps(t.to(self.device)), processed_cond) # Permute output back to (B, T_chunk, K, D) for consistency if not calculating loss model_output_original_shape = model_output.permute(0, 3, 1, 2) - # 5. 
Compute Loss (if requested) if return_loss: loss_terms = {} @@ -138,32 +188,46 @@ def diffuse( if mmt.name == "PREVIOUS_X": target = self.diffusion.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] elif mmt.name == "START_X": - target = x_start # Target is the original clean chunk + target = x_start # Target is the original clean chunk elif mmt.name == "EPSILON": - target = noise # Target is the noise added + target = noise # Target is the noise added else: raise ValueError(f"Unsupported model_mean_type: {mmt}") - assert (model_output.shape == target.shape == x_start.shape), "Shape mismatch between model output, target, and x_start" - - # Process the target_mask - mask_from_loader = processed_cond["target_mask"] - # print(f"[DEBUG diffuse] Received target_mask shape: {mask_from_loader.shape}") - - # Adapt mask shape based on loader output (assuming B, T, K, D) - mask = mask_from_loader.permute(0, 2, 3, 1) # -> (B, K, D, T_chunk) - # print(f"[DEBUG diffuse] Final mask shape for loss: {mask.shape}") + assert (model_output.shape == target.shape == + x_start.shape), "Shape mismatch between model output, target, and x_start" - # Calculate loss only for samples that have at least one valid frame/point - # Sum mask over K, D, T dimensions. Check if sum > 0 for each batch item. - batch_has_valid = (mask.float().sum(dim=(1, 2, 3)) < mask.shape[1]*mask.shape[2]*mask.shape[3]) # Check if not all masked + # Retrieve the optional target_mask which flags padded or invalid frames (True=masked) + mask_from_loader = processed_cond.get("target_mask", None) + if mask_from_loader is not None: + # Permute mask shape from (B, T_chunk, K, D) to (B, K, D, T_chunk) + mask = mask_from_loader.permute(0, 2, 3, 1) + else: + # Create mask based on original fluent lengths + original_lengths = cond.get("metadata", {}).get("fluent_pose_length", None) + if original_lengths is not None: + original_lengths = original_lengths.to(x_start.device) # Shape: (B,) + B, K, D, T_padded = x_start.shape + time_idx = torch.arange(T_padded, device=x_start.device).unsqueeze(0).expand(B, -1) # (B, T_padded) + mask_for_time = time_idx >= original_lengths.unsqueeze(1) # (B, T_padded) + mask = mask_for_time.unsqueeze(1).unsqueeze(1).expand(B, K, D, T_padded) + else: + mask = torch.zeros_like(x_start) + print("[WARNING] No target_mask provided. Using zero mask (no frames masked).") + + batch_has_valid = (mask.float().sum(dim=(1, 2, 3)) + < mask.shape[1] * mask.shape[2] * mask.shape[3]) # Check if not all masked valid_batch_indices = batch_has_valid.nonzero().squeeze() if valid_batch_indices.numel() == 0: print("[WARNING] All samples in this batch are fully masked. 
Skipping loss computation.") # Returning zero loss zero_loss = torch.tensor(0.0, device=self.device, requires_grad=False) - loss_terms = {"loss": zero_loss, "loss_data": zero_loss, "loss_data_vel": zero_loss} + loss_terms = { + "loss": zero_loss, + "loss_data": zero_loss, + "loss_data_vel": zero_loss, + } # Need to return model_output still return model_output_original_shape, loss_terms @@ -171,27 +235,46 @@ def diffuse( # Use the already computed `mask` (shape B, K, D, T) where True=masked if self.config.trainer.use_loss_mse: loss_data = masked_l2_per_sample(target, model_output, mask, reduce=True) + # loss_data = torch.nn.MSELoss()(model_output, target) loss_terms["loss_data"] = loss_data + # --- Velocity Loss --- + lambda_vel = getattr(self.config.trainer, "lambda_vel", 1.0) if self.config.trainer.use_loss_vel: - # Calculate velocity on time axis (last dimension) + # Compute first-order difference (velocity) target_vel = target[..., 1:] - target[..., :-1] model_output_vel = model_output[..., 1:] - model_output[..., :-1] - # Create mask for velocity (same shape as velocity) mask_vel = mask[..., 1:] if mask is not None else None - loss_data_vel = masked_l2_per_sample(target_vel, model_output_vel, mask_vel, reduce=True) loss_terms["loss_data_vel"] = loss_data_vel - if hasattr(self.config.trainer, "lambda_vel"): - lambda_vel = self.config.trainer.lambda_vel - - # Calulate Total Loss - total_loss = 0.0 - if self.config.trainer.use_loss_mse: - total_loss += loss_terms.get("loss_data", 0.0) - if self.config.trainer.use_loss_vel: - total_loss += lambda_vel * loss_terms.get("loss_data_vel", 0.0) + # --- Optional: Acceleration Loss --- + if getattr(self.config.trainer, "use_loss_accel", False): + lambda_accel = getattr(self.config.trainer, "lambda_accel", 1.0) + # Compute second-order difference (acceleration) + target_accel = target_vel[..., 1:] - target_vel[..., :-1] + model_output_accel = model_output_vel[..., 1:] - model_output_vel[..., :-1] + mask_accel = mask_vel[..., 1:] if mask_vel is not None else None + loss_data_accel = masked_l2_per_sample(target_accel, model_output_accel, mask_accel, reduce=True) + loss_terms["loss_data_accel"] = loss_data_accel + + # --- Compute Weighted Total Loss (Bullet-proof version) --- + # Initialize a new zero tensor for total_loss to avoid in-place side effects + total_loss = torch.tensor(0.0, device=self.device) + + # Accumulate each component without using in-place operations + if "loss_data" in loss_terms: + total_loss = total_loss + loss_terms["loss_data"] + + lambda_vel = getattr(self.config.trainer, "lambda_vel", 1.0) + if "loss_data_vel" in loss_terms: + total_loss = total_loss + (lambda_vel * loss_terms["loss_data_vel"]) + + lambda_accel = getattr(self.config.trainer, "lambda_accel", 1.0) + if "loss_data_accel" in loss_terms: + total_loss = total_loss + (lambda_accel * loss_terms["loss_data_accel"]) + + # Store the combined result loss_terms["loss"] = total_loss return model_output_original_shape, loss_terms @@ -199,9 +282,385 @@ def diffuse( # If return_loss is False, just return the model output return model_output_original_shape, None - def evaluate_sampling( - self, dataloader: DataLoader, save_folder_name: str = "init_samples" - ): + def _run_validation_epoch(self) -> Optional[float]: + """ + Runs validation using helper methods. + """ + if self.validation_dataloader is None: + if self.logger: + self.logger.info("Validation dataloader not provided. 
Skipping validation.") + return None + + self.model.eval() + with torch.no_grad(): + references, predictions = [], [] + for batch_idx, batch_data in enumerate(self.validation_dataloader): + batch_refs, batch_preds = self._process_validation_batch(batch_data, batch_idx) + references.extend(batch_refs) + predictions.extend(batch_preds) + + if not references: + if self.logger: + self.logger.warning("No poses collected during validation for DTW calculation.") + self.model.train() + return float("inf") + + if self.logger: + self.logger.info(f"Calculating DTW for {len(references)} validation samples...") + dtw_score = self._compute_dtw_score(predictions, references) + self.model.train() + return dtw_score + + def _process_validation_batch(self, batch_data: Dict[str, Any], batch_idx: int) -> Tuple[List[Pose], List[Pose]]: + """ + Process a single validation batch: generate sequences, unnormalize, and build Pose objects. + """ + with torch.no_grad(): + # 1. Autoregressive inference (extracted from original implementation) + gt_fluent_full_loader = batch_data["full_fluent_reference"].to(self.device) + disfluent_cond_seq_loader = batch_data["conditions"]["input_sequence"].to(self.device) + initial_history_loader = batch_data["conditions"]["previous_output"].to(self.device) + # Permute formats and prepare history + disfluent_cond_seq = disfluent_cond_seq_loader.permute(0, 2, 3, 1) + current_history = initial_history_loader.permute(0, 2, 3, 1) + history_len = getattr(self.config.arch, "history_len", 5) + # Trim or pad history + if current_history.shape[3] > history_len: + current_history = current_history[:, :, :, -history_len:] + elif current_history.shape[3] < history_len: + padding = torch.zeros(current_history.shape[0], current_history.shape[1], current_history.shape[2], + history_len - current_history.shape[3], device=self.device) + current_history = torch.cat([padding, current_history], dim=3) + # Prepare autoregressive generation + K = self.config.arch.keypoints + D_feat = self.config.arch.dims + max_len = getattr(self.config.trainer, "validation_max_len", 160) + chunk_size = getattr(self.config.trainer, "validation_chunk_size", self.config.arch.chunk_len) + stop_thresh = getattr(self.config.trainer, "validation_stop_threshold", 1e-4) + generated = torch.empty(current_history.shape[0], K, D_feat, 0, device=self.device) + active = torch.ones(current_history.shape[0], dtype=torch.bool, device=self.device) + num_steps = (max_len + chunk_size - 1) // chunk_size + for _ in range(num_steps): + if not active.any(): + break + n_frames = min(chunk_size, max_len - generated.shape[3]) + target_shape = (current_history.shape[0], K, D_feat, n_frames) + # --- Use classifier-free guidance --- + cond_dict = {"input_sequence": disfluent_cond_seq, "previous_output": current_history} + guidance_scale = getattr(self.config.trainer, "guidance_scale", 2.0) + # Unconditional input: zero out the disfluent sequence + uncond_disfluent_seq = torch.zeros_like(disfluent_cond_seq) + uncond_dict = {"input_sequence": uncond_disfluent_seq, "previous_output": current_history} + + # Conditional sampling + wrapped_model_cond = _ConditionalWrapper(self.model, cond_dict) + cond_chunk = self.diffusion.p_sample_loop( + model=wrapped_model_cond, shape=target_shape, + clip_denoised=getattr(self.config.diff, "clip_denoised", + False), model_kwargs={"y": cond_dict}, progress=False) + # Unconditional sampling + wrapped_model_uncond = _ConditionalWrapper(self.model, uncond_dict) + uncond_chunk = self.diffusion.p_sample_loop( + 
model=wrapped_model_uncond, shape=target_shape, + clip_denoised=getattr(self.config.diff, "clip_denoised", + False), model_kwargs={"y": uncond_dict}, progress=False) + # Combine with guidance scale + chunk = uncond_chunk + guidance_scale * (cond_chunk - uncond_chunk) + # Stop condition + if chunk.numel() > 0: + mean_abs = chunk.abs().mean(dim=(1, 2, 3)) + stopped = (mean_abs < stop_thresh) & active + chunk[stopped] = 0 + active = active & (~stopped) + generated = torch.cat([generated, chunk], dim=3) + # Update history + if generated.shape[3] >= history_len: + current_history = generated[:, :, :, -history_len:] + else: + pad = torch.zeros(generated.shape[0], K, D_feat, history_len - generated.shape[3], + device=self.device) + current_history = torch.cat([pad, generated], dim=3) + # Permute back + pred_normed = generated.permute(0, 3, 1, 2) + + # 2. Unnormalize sequences + val_ds = self.validation_dataloader.dataset + val_mean = torch.tensor(val_ds.input_mean, device=self.device).view(1, 1, K, D_feat) + val_std = torch.tensor(val_ds.input_std, device=self.device).view(1, 1, K, D_feat) + train_mean = self.data_input_mean.view(1, 1, K, D_feat) + train_std = self.data_input_std.view(1, 1, K, D_feat) + gt_unnorm = gt_fluent_full_loader * val_std + val_mean + pred_unnorm = pred_normed * train_std + train_mean + + # 3. Build Pose lists + refs, preds = [], [] + for i in range(gt_unnorm.shape[0]): + # Retrieve original reference length to truncate padded frames + original_lengths = batch_data["metadata"]["fluent_pose_length"] + current_original_length = int(original_lengths[i]) + fps = getattr(self.val_pose_header, "fps", 25.0) + # Truncate to original length before reshaping + ref_truncated_data = gt_unnorm[i, :current_original_length, :, :] + ref_np = ref_truncated_data.cpu().numpy().reshape(current_original_length, 1, K, + D_feat).astype(np.float64) + ref_body = NumPyPoseBody(fps=fps, data=ref_np, confidence=np.ones((current_original_length, 1, K))) + # Use full prediction length without truncation + pred_length = pred_unnorm.shape[1] + pred_full = pred_unnorm[i, :pred_length, :, :] # shape (pred_length, K, D_feat) + pred_np = pred_full.cpu().numpy().reshape(pred_length, 1, K, D_feat).astype(np.float64) + preds_body = NumPyPoseBody(fps=fps, data=pred_np, confidence=np.ones((pred_length, 1, K))) + refs.append(Pose(self.val_pose_header, ref_body)) + preds.append(Pose(self.val_pose_header, preds_body)) + return refs, preds + + def _compute_dtw_score(self, predictions: List[Pose], references: List[Pose]) -> float: + """ + Compute mean DTW distance between predictions and references using corpus_score, with timing. 
+ """ + # Time the corpus_score computation + start_time = time.time() + # Wrap the entire references list as a single reference corpus + wrapped_refs = [references] + mean_score = float(self.validation_metric_calculator.corpus_score(predictions, wrapped_refs)) + elapsed = time.time() - start_time + + # Log timing separately for readability + if self.logger: + elapsed_msg = f"Validation DTW corpus_score time: {elapsed:.4f}s" + self.logger.info(elapsed_msg) + + # Log the actual score + if self.logger: + score_msg = f"=== Validation DTW (corpus_score): {mean_score:.4f} ===" + self.logger.info(score_msg) + + if self.tb_writer: + self.tb_writer.add_scalar("validation/DTW_distance", mean_score, self.epoch) + return mean_score + + # Override the run_loop method to include validation + def run_loop(self, enable_profiler=False, profiler_directory="./logs/tb_profiler"): + print(">>> TRAINING CODE LOADED FROM:", __file__) + use_amp = getattr(self.config.trainer, "use_amp", False) + # Initialize gradient scaler for mixed precision + if use_amp: + scaler = GradScaler("cuda") + else: + scaler = None + if enable_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler(profiler_directory), + ) + profiler.start() + else: + profiler = None + + sampling_num = 16 + sampling_idx = np.random.randint(0, len(self.dataloader.dataset), sampling_num) + sampling_subset = DataLoader(Subset(self.dataloader.dataset, sampling_idx), batch_size=sampling_num) + self.evaluate_sampling(sampling_subset, save_folder_name="init_samples") + # Sample fixed validation indices for saving predictions across epochs + num_to_save = getattr(self.config.trainer, "validation_save_num", 30) + val_dataset = self.validation_dataloader.dataset + if len(val_dataset) > num_to_save: + self.validation_sample_indices = np.random.choice(len(val_dataset), num_to_save, replace=False).tolist() + else: + self.validation_sample_indices = list(range(len(val_dataset))) + + epoch_process_bar = tqdm(range(self.epoch, self.num_epochs), desc=f"Epoch {self.epoch}") + for epoch_idx in epoch_process_bar: + self.model.train() + self.model.training = True + self.epoch = epoch_idx + epoch_losses = {} + + data_len = len(self.dataloader) + + for datas in self.dataloader: + datas = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in datas.items()} + cond = { + key: val.to(self.device) if torch.is_tensor(val) else val + for key, val in datas["conditions"].items() + } + x_start = datas["data"] + + self.opt.zero_grad() + t, weights = self.schedule_sampler.sample(x_start.shape[0], self.device) + + if use_amp: + with autocast("cuda"): + _, losses = self.diffuse(x_start, t, cond, noise=None, return_loss=True) + total_loss = (losses["loss"] * weights).mean() + scaler.scale(total_loss).backward() + scaler.step(self.opt) + scaler.update() + else: + _, losses = self.diffuse(x_start, t, cond, noise=None, return_loss=True) + total_loss = (losses["loss"] * weights).mean() + total_loss.backward() + self.opt.step() + + if profiler: + profiler.step() + + if self.config.trainer.ema: + self.ema.update() + + for key_name in losses.keys(): + if "loss" in key_name: + if key_name not in epoch_losses.keys(): + epoch_losses[key_name] = [] + epoch_losses[key_name].append(losses[key_name].mean().item()) + + # Stop profiling after one epoch + if profiler: + profiler.stop() + profiler = None + + if self.prior_loader is not 
None: + for prior_datas in itertools.islice(self.prior_loader, data_len): + prior_datas = { + key: val.to(self.device) if torch.is_tensor(val) else val + for key, val in prior_datas.items() + } + prior_cond = { + key: val.to(self.device) if torch.is_tensor(val) else val + for key, val in prior_datas["conditions"].items() + } + prior_x_start = prior_datas["data"] + + self.opt.zero_grad() + t, weights = self.schedule_sampler.sample(prior_x_start.shape[0], self.device) + + if use_amp: + with autocast("cuda"): + _, prior_losses = self.diffuse(prior_x_start, t, prior_cond, noise=None, return_loss=True) + total_loss = (prior_losses["loss"] * weights).mean() + scaler.scale(total_loss).backward() + scaler.step(self.opt) + scaler.update() + else: + _, prior_losses = self.diffuse(prior_x_start, t, prior_cond, noise=None, return_loss=True) + total_loss = (prior_losses["loss"] * weights).mean() + total_loss.backward() + self.opt.step() + + for key_name in prior_losses.keys(): + if "loss" in key_name: + if key_name not in epoch_losses.keys(): + epoch_losses[key_name] = [] + epoch_losses[key_name].append(prior_losses[key_name].mean().item()) + + loss_str = (f"loss_data: {np.mean(epoch_losses['loss_data']):.6f}, " + f"loss_data_vel: {np.mean(epoch_losses['loss_data_vel']):.6f}, " + f"loss: {np.mean(epoch_losses['loss']):.6f}") + + epoch_avg_loss = np.mean(epoch_losses["loss"]) + + if self.epoch > 10 and epoch_avg_loss < self.best_loss: + self.save_checkpoint(filename="best") + + if epoch_avg_loss < self.best_loss: + self.best_loss = epoch_avg_loss + + epoch_process_bar.set_description( + f"Epoch {epoch_idx}/{self.config.trainer.epoch} | loss: {epoch_avg_loss:.6f} | best_loss: {self.best_loss:.6f}" + ) + self.logger.info( + f"Epoch {epoch_idx}/{self.config.trainer.epoch} | {loss_str} | best_loss: {self.best_loss:.6f}") + + save_freq = max(1, int(getattr(self.config.trainer, "save_freq", 1))) + if epoch_idx > 0 and epoch_idx % save_freq == 0: + self.save_checkpoint(filename=f"weights_{epoch_idx}") + self.evaluate_sampling(sampling_subset, save_folder_name="train_samples") + + for key_name in epoch_losses.keys(): + if "loss" in key_name: + self.tb_writer.add_scalar(f"train/{key_name}", np.mean(epoch_losses[key_name]), epoch_idx) + + self.scheduler.step() + + # Validation Phase + eval_freq = getattr(self.config.trainer, "eval_freq", 1) + if self.validation_dataloader is not None and (self.epoch % eval_freq == 0 + or self.epoch == self.config.trainer.epoch - 1): + current_validation_metric = self._run_validation_epoch() + # Log the validation metric to TensorBoard + if self.tb_writer and current_validation_metric is not None: + self.tb_writer.add_scalar("validation/DTW_distance", current_validation_metric, self.epoch) + # If the metric is better, save the best model + if (current_validation_metric is not None and current_validation_metric < self.best_validation_metric): + self.best_validation_metric = current_validation_metric + self.logger.info( + f"*** New best validation metric: {self.best_validation_metric:.4f} at epoch {self.epoch}. Saving best model. 
***" + ) + self.save_checkpoint(filename="best_model_validation") + + # Compute validation loss on chunks + val_losses = [] + for val_batch in self.validation_dataloader: + # Recursively move to device and handle dicts + val_batch_device = {} + for key, v_item in val_batch.items(): + if torch.is_tensor(v_item): + val_batch_device[key] = v_item.to(self.device) + elif isinstance(v_item, dict): + val_batch_device[key] = { + sk: sv.to(self.device) if torch.is_tensor(sv) else sv + for sk, sv in v_item.items() + } + else: + val_batch_device[key] = v_item + + current_cond_for_diffuse = {k: v for k, v in val_batch_device.get("conditions", {}).items()} + if "metadata" in val_batch_device and "fluent_pose_length" in val_batch_device["metadata"]: + if "metadata" not in current_cond_for_diffuse: + current_cond_for_diffuse["metadata"] = {} + current_cond_for_diffuse["metadata"]["fluent_pose_length"] = val_batch_device["metadata"][ + "fluent_pose_length"] + + x_start = val_batch_device["data"] + t, weights = self.schedule_sampler.sample(x_start.shape[0], self.device) + with torch.no_grad(): + _, losses = self.diffuse(x_start, t, current_cond_for_diffuse, noise=None, return_loss=True) + batch_loss = (losses["loss"] * weights).mean().item() + val_losses.append(batch_loss) + if self.tb_writer and val_losses: + avg_val_loss = np.mean(val_losses) + self.tb_writer.add_scalar("validation/loss", avg_val_loss, self.epoch) + # --- Save the fixed 30 (or as defined in validation_save_num) validation predictions and corresponding GT results --- + save_dir = Path(self.config.save) / "validation_samples" / f"epoch_{self.epoch}" + mkdir(save_dir) + sample_indices = self.validation_sample_indices + val_save_loader = DataLoader( + Subset(val_dataset, sample_indices), + batch_size=len(sample_indices), + shuffle=False, + num_workers=self.config.trainer.workers, + pin_memory=True, + collate_fn=zero_pad_collator, + ) + for batch_idx, batch_data in enumerate(val_save_loader): + refs, preds = self._process_validation_batch(batch_data, batch_idx) + for i, (ref, pred) in enumerate(zip(refs, preds)): + idx = sample_indices[batch_idx * len(preds) + i] + ref_path = save_dir / f"ref_epoch{self.epoch}_idx{idx}.pose" + with open(ref_path, "wb") as f: + ref.write(f) + pred_path = save_dir / f"pred_epoch{self.epoch}_idx{idx}.pose" + with open(pred_path, "wb") as f: + pred.write(f) + self.logger.info(f"Saved {len(sample_indices)} validation GT and predictions to {save_dir}") + + best_path = "%s/best.pt" % (self.config.save) + self.load_checkpoint(best_path) + self.evaluate_sampling(sampling_subset, save_folder_name="best") + + def evaluate_sampling(self, dataloader: DataLoader, save_folder_name: str = "init_samples"): """ Perform inference and save generated samples from the model. This currently evaluates the model in a NON-AUTOREGRESSIVE way, predicting only the first chunk based on conditions. 
@@ -235,30 +694,32 @@ def get_original_dataset(dataset): fluent_clip = datas["data"].to(self.device) - cond = { - key: (val.to(self.device) if torch.is_tensor(val) else val) - for key, val in datas["conditions"].items() - } + cond = {key: (val.to(self.device) if torch.is_tensor(val) else val) for key, val in datas["conditions"].items()} - time, _ = self.schedule_sampler.sample( - patched_dataloader.batch_size, self.device - ) + time, _ = self.schedule_sampler.sample(patched_dataloader.batch_size, self.device) with torch.no_grad(): - pred_output_tensor, _ = self.diffuse( - fluent_clip, time, cond, noise=None, return_loss=False - ) + pred_output_tensor, _ = self.diffuse(fluent_clip, time, cond, noise=None, return_loss=False) fluent_clip_array = fluent_clip.cpu().numpy() pred_output_array = pred_output_tensor.cpu().numpy() - unnormed_gt_list = self.export_samples(fluent_clip_array, f"{self.save_dir}/{save_folder_name}", "gt") - unnormed_pred_list = self.export_samples(pred_output_array, f"{self.save_dir}/{save_folder_name}", "pred") + unnormed_fluent_clip = (fluent_clip_array * dataset.input_std + dataset.input_mean) + unnormed_pred_output = (pred_output_array * dataset.input_std + dataset.input_mean) + + self.export_samples(unnormed_fluent_clip, f"{self.save_dir}/{save_folder_name}", "gt") + self.export_samples(unnormed_pred_output, f"{self.save_dir}/{save_folder_name}", "pred") # Save the normalized fluent clip and predicted output as numpy arrays - np.save(f"{self.save_dir}/{save_folder_name}/gt_output_normed.npy", fluent_clip_array) - np.save(f"{self.save_dir}/{save_folder_name}/pred_output_normed.npy", pred_output_array) + np.save( + f"{self.save_dir}/{save_folder_name}/gt_output_normed.npy", + fluent_clip_array, + ) + np.save( + f"{self.save_dir}/{save_folder_name}/pred_output_normed.npy", + pred_output_array, + ) # Save the unnormalized fluent clip and predicted output as numpy arrays - unormed_gt_batch = np.stack(unnormed_gt_list, axis=0) # (B, T, K, D) - unormed_pred_batch = np.stack(unnormed_pred_list, axis=0) # (B, T, K, D) + unormed_gt_batch = np.stack(unnormed_fluent_clip, axis=0) # (B, T, K, D) + unormed_pred_batch = np.stack(unnormed_pred_output, axis=0) # (B, T, K, D) np.save(f"{self.save_dir}/{save_folder_name}/gt_output.npy", unormed_gt_batch) np.save(f"{self.save_dir}/{save_folder_name}/pred_output.npy", unormed_pred_batch) @@ -267,7 +728,7 @@ def get_original_dataset(dataset): else: print(f"Evaluate sampling {save_folder_name} at epoch {self.epoch}") - def export_samples(self, pose_output_normalized_np: np.ndarray, save_path: str, prefix: str) -> list: + def export_samples(self, pose_output_np: np.ndarray, save_path: str, prefix: str) -> list: """ Unnormalizes pose data using unnormalize_mean_std, exports to .pose format, and returns the unnormalized numpy data. Args: @@ -275,11 +736,9 @@ def export_samples(self, pose_output_normalized_np: np.ndarray, save_path: str, save_path: Path (string) where files will be saved. prefix: Prefix for file names, e.g., "gt" or "pred". 
""" - unnormalized_arrays = [] # Store unnormalized arrays here - - for i in range(pose_output_normalized_np.shape[0]): + for i in range(pose_output_np.shape[0]): - pose_array = pose_output_normalized_np[i] # (time, keypoints, 3) + pose_array = pose_output_np[i] # (time, keypoints, 3) time, keypoints, dim = pose_array.shape pose_array = pose_array.reshape(time, 1, keypoints, dim) @@ -291,24 +750,22 @@ def export_samples(self, pose_output_normalized_np: np.ndarray, save_path: str, pose_obj = Pose(self.pose_header, pose_body) # Unnormalize the pose data and normalize its size for export - unnorm_pose = unnormalize_mean_std(pose_obj) + # unnorm_pose = unnormalize_mean_std(pose_obj) # Scale the pose back for visualization - normalize_pose_size(unnorm_pose) + # normalize_pose_size(unnorm_pose) file_path = f"{save_path}/pose_{i}.{prefix}.pose" with open(file_path, "wb") as f: - unnorm_pose.write(f) + pose_obj.write(f) # self.logger.info(f"Saved pose file: {file_path}") - # Verify the file by reading it back with open(file_path, "rb") as f_check: Pose.read(f_check.read()) # Extract and store the unnormalized numpy data - unnorm_data_np = np.array(unnorm_pose.body.data.data.astype(pose_output_normalized_np.dtype)).squeeze(1) # (T, K, D) - unnormalized_arrays.append(unnorm_data_np) + # unnorm_data_np = np.array(unnorm_pose.body.data.data.astype(pose_output_normalized_np.dtype)).squeeze(1) # (T, K, D) + # unnormalized_arrays.append(unnorm_data_np) # If error occurs, the file was not written correctly # self.logger.info(f"Pose file {file_path} read successfully.") - return unnormalized_arrays diff --git a/fluent_pose_synthesis/data/load_data.py b/fluent_pose_synthesis/data/load_data.py index b779430..1bb2281 100644 --- a/fluent_pose_synthesis/data/load_data.py +++ b/fluent_pose_synthesis/data/load_data.py @@ -6,20 +6,28 @@ import torch import numpy as np +from tqdm import tqdm from torch.utils.data import Dataset from pose_format import Pose from pose_format.torch.masked.collator import zero_pad_collator from pose_anonymization.data.normalization import normalize_mean_std +import pickle +import hashlib class SignLanguagePoseDataset(Dataset): + def __init__( self, data_dir: Path, split: str, chunk_len: int, dtype=np.float32, + history_len: int = 5, limited_num: int = -1, + use_cache: bool = True, + cache_dir: Path = None, + force_reload: bool = False, ): """ Args: @@ -32,8 +40,45 @@ def __init__( self.data_dir = data_dir self.split = split self.chunk_len = chunk_len + self.history_len = history_len + self.window_len = chunk_len + history_len self.dtype = dtype + # Cache controls + self.use_cache = use_cache + self.force_reload = force_reload + # Determine cache directory + if cache_dir is None: + self.cache_dir = self.data_dir / "cache" + else: + self.cache_dir = cache_dir + self.cache_dir.mkdir(exist_ok=True) + # Collect data file mtimes for invalidation + split_dir = self.data_dir / self.split + all_files = list(split_dir.glob(f"{split}_*_original.pose")) + \ + list(split_dir.glob(f"{split}_*_updated.pose")) + \ + list(split_dir.glob(f"{split}_*_metadata.json")) + mtimes = [f.stat().st_mtime for f in all_files if f.exists()] + data_mtime = max(mtimes) if mtimes else 0 + # Build cache key + cache_params = { + 'data_dir': str(data_dir), + 'split': split, + 'chunk_len': chunk_len, + 'history_len': history_len, + 'dtype': str(dtype), + 'limited_num': limited_num, + 'data_mtime': data_mtime, + } + cache_key = hashlib.md5(str(cache_params).encode()).hexdigest() + self.cache_file = self.cache_dir / 
f"dataset_cache_{split}_{cache_key}.pkl" + # Try loading from cache + if self.use_cache and not self.force_reload and self.cache_file.exists(): + print(f"Loading dataset from cache: {self.cache_file}") + self._load_from_cache() + print(f"Dataset loaded from cache: {len(self.examples)} samples, split={split}") + return + # Store only file paths for now, load data on-the-fly # Each sample should have fluent (original), disfluent (updated), and metadata files self.examples = [] @@ -42,19 +87,15 @@ def __init__( if limited_num > 0: fluent_files = fluent_files[:limited_num] # Limit the number of samples to load - for fluent_file in fluent_files: + for fluent_file in tqdm(fluent_files, desc=f"Loading {split} examples"): # Construct corresponding disfluent and metadata file paths based on the file name disfluent_file = fluent_file.with_name(fluent_file.name.replace("_original.pose", "_updated.pose")) metadata_file = fluent_file.with_name(fluent_file.name.replace("_original.pose", "_metadata.json")) - self.examples.append( - { - "fluent_path": fluent_file, - "disfluent_path": disfluent_file, - "metadata_path": metadata_file, - } - ) - - print(f"Dataset initialized with {len(self.examples)} samples. Split: {split}") + self.examples.append({ + "fluent_path": fluent_file, + "disfluent_path": disfluent_file, + "metadata_path": metadata_file, + }) # Initialize pose_header from the first fluent .pose file if self.examples: @@ -69,11 +110,139 @@ def __init__( else: self.pose_header = None + self.fluent_clip_list = [] + self.fluent_mask_list = [] + self.disfluent_clip_list = [] + + self.train_indices = [] + + for example_idx, example in enumerate( + tqdm(self.examples, desc=f"Processing pose files for {split}", total=len(self.examples))): + with open(example["fluent_path"], "rb") as f: + fluent_pose = Pose.read(f.read()) + with open(example["disfluent_path"], "rb") as f: + disfluent_pose = Pose.read(f.read()) + + fluent_data = np.array(fluent_pose.body.data.astype(self.dtype)) + fluent_mask = fluent_pose.body.data.mask + disfluent_data = np.array(disfluent_pose.body.data.astype(self.dtype)) + fluent_length = fluent_data.shape[0] + + self.fluent_clip_list.append(fluent_data[:, 0]) + self.fluent_mask_list.append(fluent_mask[:, 0]) + self.disfluent_clip_list.append(disfluent_data[:, 0]) + + if self.split == "validation": + self.train_indices = np.arange(len(self.examples)).reshape(-1, 1) + else: + for example_idx, example in enumerate( + tqdm(self.examples, desc=f"Building indices for {split}", total=len(self.examples))): + fluent_data = self.fluent_clip_list[example_idx] + fluent_length = fluent_data.shape[0] + + if fluent_length >= self.chunk_len: + zero_indices = np.array([-1] * self.history_len + list(range(self.chunk_len))).reshape(1, -1) + clip_indices = np.arange(0, fluent_length - self.window_len + 1, 1)[:, None] + np.arange( + self.window_len) + clip_indices = np.concatenate((zero_indices, clip_indices), axis=0) + clip_indices_with_idx = np.hstack(( + np.full( + (len(clip_indices), 1), + example_idx, + dtype=clip_indices.dtype, + ), + clip_indices, + )) + else: + zero_indices = np.array([-1] * self.history_len + list(range(fluent_length)) + [-2] * + (self.chunk_len - fluent_length)).reshape(1, -1) + clip_indices_list = [] + for i in range(self.window_len): + is_history_part = i < self.history_len + if i < fluent_length: + clip_indices_list.append(i) + else: + if is_history_part: + clip_indices_list.append(-1) + else: + clip_indices_list.append(-2) + clip_indices = 
np.array(clip_indices_list).reshape(1, -1) + clip_indices = np.concatenate((zero_indices, clip_indices), axis=0) + clip_indices_with_idx = np.hstack(( + np.full( + (len(clip_indices), 1), + example_idx, + dtype=clip_indices.dtype, + ), + clip_indices, + )) + + self.train_indices.append(clip_indices_with_idx) + + self.train_indices = np.concatenate(self.train_indices, axis=0) + + concatenated_fluent_clips = np.concatenate(self.fluent_clip_list, axis=0) + self.input_mean = concatenated_fluent_clips.mean(axis=0, keepdims=True) # axis=0 + self.input_std = concatenated_fluent_clips.std(axis=0, keepdims=True) + + concatenated_disfluent_clips = np.concatenate(self.disfluent_clip_list, axis=0) + self.condition_mean = concatenated_disfluent_clips.mean(axis=0, keepdims=True) + self.condition_std = concatenated_disfluent_clips.std(axis=0, keepdims=True) + + for i in range(len(self.examples)): + self.fluent_clip_list[i] = (self.fluent_clip_list[i] - self.input_mean) / self.input_std + self.disfluent_clip_list[i] = (self.disfluent_clip_list[i] - self.condition_mean) / self.condition_std + + # Save cache for future runs + if self.use_cache: + print(f"Saving dataset to cache: {self.cache_file}") + self._save_to_cache() + print("Dataset initialized with {} samples. Split: {}".format(len(self.examples), split)) + + def _save_to_cache(self): + """Serialize dataset to cache file.""" + data = { + 'examples': self.examples, + 'pose_header': self.pose_header, + 'fluent_clip_list': self.fluent_clip_list, + 'fluent_mask_list': self.fluent_mask_list, + 'disfluent_clip_list': self.disfluent_clip_list, + 'train_indices': self.train_indices, + 'input_mean': self.input_mean, + 'input_std': self.input_std, + 'condition_mean': self.condition_mean, + 'condition_std': self.condition_std, + } + try: + with open(self.cache_file, 'wb') as f: + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + except Exception as e: + print(f"[WARNING] Failed to save cache: {e}") + + def _load_from_cache(self): + """Load dataset from cache file.""" + try: + with open(self.cache_file, 'rb') as f: + data = pickle.load(f) + self.examples = data['examples'] + self.pose_header = data['pose_header'] + self.fluent_clip_list = data['fluent_clip_list'] + self.fluent_mask_list = data['fluent_mask_list'] + self.disfluent_clip_list = data['disfluent_clip_list'] + self.train_indices = data['train_indices'] + self.input_mean = data['input_mean'] + self.input_std = data['input_std'] + self.condition_mean = data['condition_mean'] + self.condition_std = data['condition_std'] + except Exception as e: + print(f"[WARNING] Failed to load cache, rebuilding: {e}") + # Fall back to fresh build + def __len__(self) -> int: """ Returns the number of samples in the dataset. """ - return len(self.examples) + return len(self.train_indices) def __getitem__(self, idx: int) -> Dict[str, Any]: """ @@ -82,101 +251,104 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: Args: idx (int): Index of the sample to retrieve. 
""" - - sample = self.examples[idx] - - # Load pose sequences and metadata from disk - with open(sample["fluent_path"], "rb") as f: - fluent_pose = Pose.read(f.read()) - with open(sample["disfluent_path"], "rb") as f: - disfluent_pose = Pose.read(f.read()) - with open(sample["metadata_path"], "r", encoding="utf-8") as f: - metadata = json.load(f) - - # print(f"[DEBUG][Before Norm] Fluent raw data mean: {fluent_pose.body.data.mean(axis=(0, 1, 2))} std {fluent_pose.body.data.std(axis=(0, 1, 2))}") - # print(f"[DEBUG][Before Norm] Disfluent raw data mean: {disfluent_pose.body.data.mean(axis=(0, 1, 2))} std {disfluent_pose.body.data.std(axis=(0, 1, 2))}") - - # Normalize the pose data - fluent_pose = normalize_mean_std(fluent_pose) - disfluent_pose = normalize_mean_std(disfluent_pose) - - # print(f"DEBUG][After Norm] Fluent normalized data mean:: {fluent_pose.body.data.mean(axis=(0, 1, 2))} std {fluent_pose.body.data.std(axis=(0, 1, 2))}") - # print(f"[DEBUG][After Norm] Disfluent normalized data mean: {disfluent_pose.body.data.mean(axis=(0, 1, 2))} std {disfluent_pose.body.data.std(axis=(0, 1, 2))}") - - fluent_data = np.array(fluent_pose.body.data.astype(self.dtype)) - fluent_mask = fluent_pose.body.data.mask - disfluent_data = np.array(disfluent_pose.body.data.astype(self.dtype)) - - fluent_length = len(fluent_data) - - # 1. Randomly sample the start index for the fluent (target) chunk - if fluent_length <= self.chunk_len: - start_idx = 0 - target_len = fluent_length - history_len = 0 - else: - start_idx = random.randint(0, fluent_length - self.chunk_len) - target_len = self.chunk_len - history_len = start_idx - - # 2. Extract target chunk (y_k) and history chunk (y_1, ..., y_{k-1}) - target_chunk = fluent_data[start_idx : start_idx + target_len] - target_mask = fluent_mask[start_idx : start_idx + target_len] - - if history_len > 0: - history_chunk = fluent_data[:history_len] - else: - # MODIFICATION: Force minimum length of 1 for previous_output if empty - history_chunk = np.zeros((1,) + fluent_data.shape[1:], dtype=self.dtype) # create a single empty frame - # The purpose of this is to ensure the current collate_fn works - # else: - # # No history chunk available, create an empty array with time dimension 0 - # history_chunk = np.empty((0,) + fluent_data.shape[1:], dtype=self.dtype) - - # 3. Prepare the entire disfluent sequence as condition - disfluent_seq = disfluent_data - - # 4. Pad target chunk if its actual length is less than chunk_len - if target_chunk.shape[0] < self.chunk_len: - pad_len = self.chunk_len - target_chunk.shape[0] - # Padding 0s for target chunk - padding_shape_data = (pad_len,) + target_chunk.shape[1:] - target_padding = np.zeros(padding_shape_data, dtype=self.dtype) - target_chunk = np.concatenate([target_chunk, target_padding], axis=0) - # Padding for mask (True for masked) - mask_padding = np.ones((pad_len,) + target_mask.shape[1:], dtype=bool) - target_mask = np.concatenate([target_mask, mask_padding], axis=0) - - # 5. 
Convert numpy arrays to torch tensors + # Load metadata JSON for the current example to access original lengths + motion_idx = self.train_indices[idx][0] + metadata_path = self.examples[motion_idx]["metadata_path"] + with open(metadata_path, "r", encoding="utf-8") as mf: + meta_json = json.load(mf) + orig_fluent_len = meta_json.get("fluent_pose_length", None) + orig_disfluent_len = meta_json.get("disfluent_pose_length", None) + + if self.split == "validation": + motion_idx = self.train_indices[idx][0] + full_seq = torch.from_numpy(self.fluent_clip_list[motion_idx].astype(np.float32)) + disfluent_seq = torch.from_numpy(self.disfluent_clip_list[motion_idx].astype(np.float32)) + + # Construct previous_output for validation + history_len = self.history_len + num_keypoints = self.fluent_clip_list[motion_idx].shape[1] # K + num_dims = self.fluent_clip_list[motion_idx].shape[2] # D + + # Create a zero tensor for previous_output + # Its shape should be (history_len, K, D) + previous_output = torch.zeros((history_len, num_keypoints, num_dims), dtype=full_seq.dtype) + + metadata = { + "original_example_index": int(motion_idx), "fluent_pose_length": orig_fluent_len, + "disfluent_pose_length": orig_disfluent_len + } + result = { + "data": full_seq, # Full sequence, mainly used as reference for validation metrics + "conditions": { + "input_sequence": disfluent_seq, # Disfluent sequence as condition + "previous_output": previous_output, # Now the initial history is all zeros + }, + "full_fluent_reference": full_seq, # Full reference for DTW evaluation + "metadata": metadata, + } + return result + + item_frame_indice = self.train_indices[idx] + motion_idx, frame_indices = item_frame_indice[0], item_frame_indice[1:] + history_indices = frame_indices[:self.history_len] + target_indices = frame_indices[self.history_len:] + + history_chunk = self.fluent_clip_list[motion_idx][history_indices] + disfluent_seq = self.disfluent_clip_list[motion_idx] + # Process target_chunk and target_mask, set frame at -2 to all-zero frame with mask True, others remain unchanged + target_chunk_frames = [] + target_mask_frames = [] + single_frame_shape = self.fluent_clip_list[motion_idx][0].shape + single_mask_shape = self.fluent_mask_list[motion_idx][0].shape + for t_idx in target_indices: + if t_idx == -2: + target_chunk_frames.append(np.zeros(single_frame_shape, dtype=np.float32)) + target_mask_frames.append(np.ones(single_mask_shape, dtype=bool)) + else: + target_chunk_frames.append(self.fluent_clip_list[motion_idx][t_idx]) + target_mask_frames.append(self.fluent_mask_list[motion_idx][t_idx]) + target_chunk = np.stack(target_chunk_frames, axis=0) + target_mask = np.stack(target_mask_frames, axis=0) + + history_chunk[history_indices == -1] = 0 # zero out padded history frames (index -1); .fill(0) on the fancy-indexed copy would be a no-op + + # Convert numpy arrays to torch tensors target_chunk = torch.from_numpy(target_chunk.astype(np.float32)) history_chunk = torch.from_numpy(history_chunk.astype(np.float32)) disfluent_seq = torch.from_numpy(disfluent_seq.astype(np.float32)) - target_mask = torch.from_numpy(target_mask) # Bool tensor + target_mask = torch.from_numpy(target_mask) # Bool tensor - # 6. Squeeze person dimension - target_chunk = target_chunk.squeeze(1) # (T_chunk, K, D) - history_chunk = history_chunk.squeeze(1) # (T_hist, K, D) - disfluent_seq = disfluent_seq.squeeze(1) # (T_disfl, K, D) - target_mask = target_mask.squeeze(1) # (T_chunk, K, D) - - # 7. 
Create conditions dictionary - # Later, zero_pad_collator will handle padding T_disfl and T_hist across the batch + # Create conditions dictionary conditions = { - "input_sequence": disfluent_seq, # (T_disfl, K, D) - "previous_output": history_chunk, # (T_hist, K, D) - "target_mask": target_mask # (T_chunk, K, D) + "input_sequence": disfluent_seq, # (T_disfl, K, D) + "previous_output": history_chunk, # (T_hist, K, D) + "target_mask": target_mask, # (T_chunk, K, D) } - # print(f"DEBUG Dataset idx {idx}:") - # print(f" target_chunk shape: {target_chunk.shape}") - # print(f" input_sequence shape: {disfluent_seq.shape}") - # print(f" previous_output shape: {history_chunk.shape}") - # print(f" target_mask shape: {target_mask.shape}") + # Get motion_idx, which points to the original sample index in self.examples + item_frame_indice = self.train_indices[idx] # Assume split is not 'test', or test also uses train_indices + motion_idx = item_frame_indice[0] - return { - "data": target_chunk, # (T_chunk, K, D) + # Create metadata dictionary + metadata = { + "original_example_index": int(motion_idx), # Ensure it is Python int type + "original_disfluent_filepath": str(self.examples[motion_idx]["disfluent_path"]), + "fluent_pose_length": orig_fluent_len, + "disfluent_pose_length": orig_disfluent_len + } + + # Build base return dictionary + result = { + "data": target_chunk, # (T_chunk, K, D) "conditions": conditions, + "metadata": metadata, } + # If validation set, append full fluent reference sequence + if self.split == "validation": + # Take the full sequence from pre-normalized fluent_clip_list and convert to tensor + full_seq = torch.from_numpy(self.fluent_clip_list[motion_idx].astype(np.float32)) + result["full_fluent_reference"] = full_seq # (T_full, K, D) + return result def example_dataset(): @@ -206,7 +378,7 @@ def example_dataset(): batch = next(iter(dataloader)) print("Batch Keys:", batch.keys()) - print("Conditions Keys:", batch['conditions'].keys()) + print("Conditions Keys:", batch["conditions"].keys()) print("\nShapes:") print(f" data (Target Chunk): {batch['data'].shape}") @@ -233,27 +405,3 @@ def example_dataset(): # if __name__ == '__main__': # example_dataset() - - -# Example Output: -# Dataset initialized with 128 samples. 
Split: train -# Batch Keys: dict_keys(['data', 'conditions']) -# Conditions Keys: dict_keys(['input_sequence', 'previous_output', 'target_mask']) - -# Shapes: -# data (Target Chunk): torch.Size([32, 40, 178, 3]) -# conditions['input_sequence'] (Disfluent): torch.Size([32, 359, 178, 3]) -# conditions['previous_output'] (History): torch.Size([32, 110, 178, 3]) -# conditions['target_mask']: torch.Size([32, 40, 178, 3]) - -# Normalization Stats (Shapes): -# Fluent Mean: torch.Size([1, 178, 3]) -# Fluent Std: torch.Size([1, 178, 3]) -# Disfluent Mean: torch.Size([1, 178, 3]) -# Disfluent Std: torch.Size([1, 178, 3]) - -# Sample Values (first element of first sequence): -# Target Chunk (first 5 flattened): tensor([ 9.0694e-02, 7.7781e-01, -7.0343e+02, 1.9091e-01, -8.7535e-01]) -# History Chunk (first 5 flattened): tensor([0., 0., 0., 0., 0.]) -# Disfluent Seq (first 5 flattened): tensor([ 0.1327, 1.0505, -1.7174, 0.2764, -0.7866]) -# Target Mask (first 5 flattened): tensor([False, False, False, False, False]) diff --git a/fluent_pose_synthesis/infer.py b/fluent_pose_synthesis/infer.py new file mode 100644 index 0000000..5927cdf --- /dev/null +++ b/fluent_pose_synthesis/infer.py @@ -0,0 +1,435 @@ +# Example usage: +# python -m fluent_pose_synthesis.infer \ +# -i assets/sample_dataset \ +# -c save/debug_run/4th_train_whole_dataset_continued_output/4th_train_whole_dataset_continued_output/config.json \ +# -r save/debug_run/4th_train_whole_dataset_continued_output/4th_train_whole_dataset_continued_output/best_model_validation.pt \ +# -o save/debug_run/4th_train_whole_dataset_continued_output/4th_train_whole_dataset_continued_output/infer_results_validation_progressive \ +# --batch_size 1 \ +# --chunk_size 40 \ +# --max_len 160 \ +# --stop_threshold 1e-4 \ +# --seed 1234 \ +# --progressive + +import argparse +import json +from pathlib import Path, PosixPath +from types import SimpleNamespace +import torch +from torch.utils.data import DataLoader +from pose_format.torch.masked.collator import zero_pad_collator +import numpy as np +from numpy.core.multiarray import scalar +from numpy import dtype +import torch.serialization +import logging +from tqdm import tqdm + +# CAMDM and project imports +from CAMDM.diffusion.create_diffusion import create_gaussian_diffusion +from CAMDM.utils.common import fixseed +from fluent_pose_synthesis.core.models import SignLanguagePoseDiffusion +from fluent_pose_synthesis.data.load_data import SignLanguagePoseDataset +from fluent_pose_synthesis.core.training import _ConditionalWrapper +from pose_format import Pose +from pose_format.numpy.pose_body import NumPyPoseBody + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +torch.serialization.add_safe_globals([ + SimpleNamespace, + Path, + PosixPath, + scalar, + dtype, + np.int64, + np.int32, + np.float64, + np.float32, + np.bool_, +]) + + +def dict_to_namespace(d): + if isinstance(d, dict): + return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()}) + elif isinstance(d, list): + return [dict_to_namespace(item) for item in d] + else: + return d + + +def convert_namespace_to_dict(namespace_obj): + if isinstance(namespace_obj, SimpleNamespace): + return {k: convert_namespace_to_dict(v) for k, v in vars(namespace_obj).items()} + elif isinstance(namespace_obj, dict): + return {k: convert_namespace_to_dict(v) for k, v in namespace_obj.items()} + elif isinstance(namespace_obj, (Path, PosixPath)): + return str(namespace_obj) + elif isinstance(namespace_obj, torch.device): + return 
str(namespace_obj) + else: + return namespace_obj + + +def load_checkpoint_and_config(model, checkpoint_path, device): + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + model.load_state_dict(checkpoint["state_dict"]) + print(f"Loaded model checkpoint from {checkpoint_path}") + loaded_config_dict = checkpoint.get("config", None) + if loaded_config_dict: + if isinstance(loaded_config_dict, SimpleNamespace): + loaded_config_dict = convert_namespace_to_dict(loaded_config_dict) + print("[INFO] Config found in checkpoint.") + return model, loaded_config_dict + return model, None + + +def main(): + parser = argparse.ArgumentParser(description="Autoregressive inference for fluent pose synthesis") + parser.add_argument("-i", "--input", required=True, type=str, + help="Path to disfluent input data directory (should contain a 'test' split)") + parser.add_argument("-c", "--config", required=True, type=str, + help="Path to model config JSON file (from training)") + parser.add_argument("-r", "--resume", required=True, type=str, help="Path to model checkpoint (.pt)") + parser.add_argument("-o", "--output", default="output/infer_results", type=str, + help="Directory to save generated outputs") + parser.add_argument("--batch_size", default=4, type=int, help="Batch size for inference.") + parser.add_argument("--chunk_size", default=10, type=int, + help="Number of frames to attempt to generate per autoregressive step.") + parser.add_argument("--max_len", default=100, type=int, help="Maximum total frames to generate.") + parser.add_argument("--stop_threshold", default=1e-5, type=float, + help="Threshold for mean absolute value of generated chunk to detect stop condition.") + parser.add_argument("--seed", default=42, type=int, help="Random seed.") + parser.add_argument("--progressive", action="store_true", + help="Use progressive sampling (p_sample_loop_progressive) instead of p_sample_loop") + args = parser.parse_args() + + fixseed(args.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + + # Load Configuration + try: + with open(args.config, "r", encoding="utf-8") as f: + config_dict = json.load(f) + print(f"[INFO] Successfully loaded config from {args.config} as JSON.") + except json.JSONDecodeError: + print(f"[WARNING] Could not parse {args.config} as JSON. 
Attempting to eval as SimpleNamespace string...") + with open(args.config, "r", encoding="utf-8") as f: + config_namespace = eval( + f.read(), { + "SimpleNamespace": SimpleNamespace, "namespace": SimpleNamespace, "PosixPath": PosixPath, "device": + torch.device, "Path": Path + }) # type: ignore + config_dict = convert_namespace_to_dict(config_namespace) + standard_json_path = Path(args.config).with_suffix('.fixed.json') + with open(standard_json_path, "w", encoding="utf-8") as f_json: + json.dump(config_dict, f_json, indent=4) + print(f"[INFO] Saved successfully parsed config to {standard_json_path}") + + config = dict_to_namespace(config_dict) + config.inference = SimpleNamespace(chunk_size=args.chunk_size, max_len=args.max_len, + stop_threshold=args.stop_threshold) + config.device = device + config.inference.progressive = args.progressive + + print("Loading dataset...") + dataset = SignLanguagePoseDataset( + data_dir=Path(args.input), + split="validation", + # split="train", # use trainset for now during overfitting stage + chunk_len=config.arch.chunk_len, # From training config + history_len=getattr(config.arch, "history_len", 5), # Use default from load_data.py or from config if added + dtype=np.float32, + limited_num=-1 # Load all samples from the test set + ) + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + collate_fn=zero_pad_collator, + num_workers=0 # num_workers=0 for easier debugging + ) + print(f"Dataset loaded with {len(dataset)} samples.") + + data_input_mean_np = np.array(dataset.input_mean).squeeze() + data_input_std_np = np.array(dataset.input_std).squeeze() + data_cond_mean_np = np.array(dataset.condition_mean).squeeze() + data_cond_std_np = np.array(dataset.condition_std).squeeze() + + logger.debug(f"Squeezed data_input_mean_np shape: {data_input_mean_np.shape}") + logger.debug(f"Squeezed data_input_std_np shape: {data_input_std_np.shape}") + logger.debug(f"Squeezed data_input_mean_np value (first few elements):\n{data_input_mean_np.ravel()[:5]}") + + expected_stat_shape_suffix = (config.arch.keypoints, config.arch.dims) + logger.debug(f"Expected stat shape suffix: {expected_stat_shape_suffix}") + + assert data_input_mean_np.shape[-len(expected_stat_shape_suffix):] == expected_stat_shape_suffix, \ + f"Mean shape mismatch. Actual: {data_input_mean_np.shape}, Expected suffix: {expected_stat_shape_suffix}" + assert data_input_std_np.shape[-len(expected_stat_shape_suffix):] == expected_stat_shape_suffix, \ + f"Std shape mismatch. 
Actual: {data_input_std_np.shape}, Expected suffix: {expected_stat_shape_suffix}" + + pose_header = dataset.pose_header + logger.info(f"Pose header loaded from dataset.") + + logger.info(f"Initializing model...") + input_feats = config.arch.keypoints * config.arch.dims + model = SignLanguagePoseDiffusion(input_feats=input_feats, chunk_len=config.arch.chunk_len, + keypoints=config.arch.keypoints, dims=config.arch.dims, + latent_dim=config.arch.latent_dim, ff_size=config.arch.ff_size, + num_layers=config.arch.num_layers, num_heads=config.arch.num_heads, + dropout=getattr(config.arch, "dropout", + 0.2), activation=getattr(config.arch, "activation", "gelu"), + arch=config.arch.decoder, cond_mask_prob=0, device=config.device, + batch_first=getattr(config.arch, "batch_first", True)).to(config.device) + + model, _ = load_checkpoint_and_config(model, args.resume, device) + model.eval() + + diffusion = create_gaussian_diffusion(config) + print(f"ACTUAL diffusion.num_timesteps = {diffusion.num_timesteps}") + print(f"Betas: {diffusion.betas}") + print(f"Alphas Cumprod: {diffusion.alphas_cumprod}") + print(f"Sqrt Alphas Cumprod: {diffusion.sqrt_alphas_cumprod}") + print(f"Sqrt One Minus Alphas Cumprod: {diffusion.sqrt_one_minus_alphas_cumprod}") + + save_dir = Path(args.output) + save_dir.mkdir(parents=True, exist_ok=True) + print(f"Output will be saved to: {save_dir}") + + processed_originals = {} + unique_outputs_generated = 0 # Track how many unique fluent sequences are actually generated + + print( + f"Starting autoregressive inference: target_chunk_size_per_step={args.chunk_size}, max_len={args.max_len}, stop_threshold={args.stop_threshold}" + ) + total_batches = len(dataloader) + + history_len_for_inference = getattr(config.arch, "history_len", 5) + print(f"[INFO] Using history_len_for_inference: {history_len_for_inference}") + + for batch_idx, batch_data in enumerate(tqdm(dataloader, desc="Batches")): + logger.info(f"Processing batch {batch_idx + 1}/{total_batches}") + + disfluent_cond_seq_batch_loader = batch_data["conditions"]["input_sequence"].to(device) + metadata_from_collate = batch_data.get("metadata") + + tasks_for_this_batch = [] + + current_actual_batch_size = disfluent_cond_seq_batch_loader.shape[0] + + for i in range(current_actual_batch_size): # Iterate over each logical sample in the batch + current_original_id = None + + # Extract metadata of the i-th sample from the already collated metadata_from_collate + # Prefer using file path as ID + if "original_disfluent_filepath" in metadata_from_collate and i < len( + metadata_from_collate["original_disfluent_filepath"]): + current_original_id = metadata_from_collate["original_disfluent_filepath"][i] + # If file path is missing or index is out of range, try using example_index + elif "original_example_index" in metadata_from_collate and i < len( + metadata_from_collate["original_example_index"]): + current_original_id = f"example_idx_{metadata_from_collate['original_example_index'][i]}" + + if current_original_id not in processed_originals: + individual_disfluent_seq_tensor = disfluent_cond_seq_batch_loader[i:i + 1].permute(0, 2, 3, 1) + tasks_for_this_batch.append((current_original_id, i, individual_disfluent_seq_tensor)) + processed_originals[current_original_id] = "PROCESSING" + else: + print(f" Skipping already processed or queued original sequence ID: {current_original_id}") + + if not tasks_for_this_batch: + print(" No new sequences to process in this batch.") + continue # Skip to the next batch + + # Combine all unique tasks 
collected in this batch into a new "mini-batch" for generation + # The shape of disfluent_cond_seq_for_generation will be (num_unique_tasks, K, D, T_in) + disfluent_cond_seq_for_generation = torch.cat([task[2] for task in tasks_for_this_batch], dim=0) + + current_batch_size_for_generation = disfluent_cond_seq_for_generation.shape[0] + K, D, T_in = disfluent_cond_seq_for_generation.shape[1:] # Get dimensions from the merged tensor + + print(f" Combined {current_batch_size_for_generation} unique sequences for generation from this batch.") + print( + f" Effective disfluent condition shape for model: (B_eff, K, D, T) = ({current_batch_size_for_generation}, {K}, {D}, {T_in})" + ) + + # Start autoregressive generation for this mini-batch + generated_fluent_seq = torch.empty(current_batch_size_for_generation, K, D, 0, + device=device) # (B_eff, K, D, T_gen) + active_sequences = torch.ones(current_batch_size_for_generation, dtype=torch.bool, device=device) + num_autoregressive_steps = (args.max_len + args.chunk_size - 1) // args.chunk_size + + for step in range(num_autoregressive_steps): + if not active_sequences.any(): + print(f" Step {step + 1}: All sequences stopped for this effective batch.") + break + current_generated_len = generated_fluent_seq.shape[3] + if current_generated_len >= args.max_len: + print(f" Step {step + 1}: Reached max_len ({args.max_len}).") + break + + n_frames_to_generate_this_step = min(args.chunk_size, args.max_len - current_generated_len) + target_chunk_shape = (current_batch_size_for_generation, K, D, n_frames_to_generate_this_step) + print( + f" Step {step + 1}: Gen {n_frames_to_generate_this_step} frames. Total gen: {current_generated_len}.") + + # Construct fixed-length history tensor of size history_len_for_inference + if current_generated_len < history_len_for_inference: + # Left-pad with zeros if not enough frames generated + pad_frames = history_len_for_inference - current_generated_len + padding = torch.zeros(current_batch_size_for_generation, K, D, pad_frames, device=device) + effective_previous_output = torch.cat([padding, generated_fluent_seq], dim=3) + else: + # Slice the last history_len_for_inference frames + effective_previous_output = generated_fluent_seq[:, :, :, -history_len_for_inference:] + + print(f" Effective previous_output shape for model: {effective_previous_output.shape}") + + model_kwargs_y = { + "input_sequence": disfluent_cond_seq_for_generation, "previous_output": effective_previous_output + } + # Use the same ConditionalWrapper as in training + wrapped_model = _ConditionalWrapper(model, model_kwargs_y) + + # --- Progressive or non-progressive sampling and per-step saving --- + with torch.no_grad(): + if config.inference.progressive: + # Prepare to collect all pred_xstart for each step in the progressive sampling + all_steps_pred_xstart = [] + sampler = diffusion.p_sample_loop_progressive( + model=wrapped_model, + shape=target_chunk_shape, + clip_denoised=False, + model_kwargs={"y": model_kwargs_y}, + progress=False, + ) + for prog_step, sample in enumerate(tqdm(sampler, desc=f"Prog steps batch {batch_idx+1}")): + pred_xstart = sample["pred_xstart"].cpu().numpy() # shape (B_eff, K, D, chunk) + for task_idx, (original_id, _, _) in enumerate(tasks_for_this_batch): + if "/" in original_id or "\\" in original_id: + filename_base_from_id = Path(original_id).stem + else: + filename_base_from_id = original_id + np.save(save_dir / f"pose_pred_fluent_{filename_base_from_id}_step{prog_step}.npy", + pred_xstart[task_idx]) + single_pred = 
pred_xstart[task_idx] # (K, D, chunk) + # transpose to (chunk, K, D) + single_pred = np.transpose(single_pred, (2, 0, 1)) # (chunk, K, D) + # reshape to (chunk, 1, K, D) + single_pred = single_pred.reshape(single_pred.shape[0], 1, single_pred.shape[1], + single_pred.shape[2]) # (chunk, 1, K, D) + # unnormalize + unnorm_pred = single_pred * data_input_std_np + data_input_mean_np + # confidence + confidence = np.ones((unnorm_pred.shape[0], 1, unnorm_pred.shape[2]), dtype=np.float32) + fps_to_use = pose_header.fps if hasattr(pose_header, + 'fps') and pose_header.fps > 0 else 25.0 + pose_body = NumPyPoseBody(fps=fps_to_use, data=unnorm_pred, confidence=confidence) + pose_obj = Pose(pose_header, pose_body) + with open(save_dir / f"pose_pred_fluent_{filename_base_from_id}_step{prog_step}.pose", + "wb") as f: + pose_obj.write(f) + all_steps_pred_xstart.append(pred_xstart) + # Use the last step as the generated chunk (convert to tensor, move to device) + generated_chunk = torch.tensor(all_steps_pred_xstart[-1], device=device) + else: + logger.info(f"Using non-progressive sampling for batch {batch_idx+1}, step {step+1}") + # p_sample_loop returns a tensor of shape (B, K, D, chunk) + generated_chunk = diffusion.p_sample_loop( + model=wrapped_model, + shape=target_chunk_shape, + clip_denoised=False, + model_kwargs={"y": model_kwargs_y}, + progress=False, + ) + + mean_abs_for_chunk = torch.zeros(current_batch_size_for_generation, device=device) + if generated_chunk.numel() > 0: + mean_abs_for_chunk[active_sequences] = generated_chunk[active_sequences].abs().mean(dim=(1, 2, 3)) + + newly_stopped_sequences = (mean_abs_for_chunk < args.stop_threshold) & active_sequences + if newly_stopped_sequences.any(): + print( + f" Sequences (effective batch indices) {newly_stopped_sequences.nonzero(as_tuple=True)[0].tolist()} stopped this step." + ) + + active_sequences = active_sequences & (~newly_stopped_sequences) + generated_chunk[newly_stopped_sequences] = torch.zeros_like(generated_chunk[newly_stopped_sequences]) + generated_fluent_seq = torch.cat([generated_fluent_seq, generated_chunk], dim=3) + + print( + f" Finished generation for {current_batch_size_for_generation} unique sequences. 
Final length: {generated_fluent_seq.shape[3]}" + ) + + # --- Save results --- + pred_fluent_normed_np_bEff_T_K_D = generated_fluent_seq.permute(0, 3, 1, + 2).cpu().numpy() # (B_eff, T_final, K, D) + + input_disfluent_normed_np_bEff_T_K_D = disfluent_cond_seq_for_generation.permute(0, 3, 1, 2).cpu().numpy() + + for task_idx, (original_id, original_batch_idx, _) in enumerate(tasks_for_this_batch): + # Extract current task data from the merged batch + current_pred_fluent_normed_np = pred_fluent_normed_np_bEff_T_K_D[task_idx] # (T_final, K, D) + current_input_disfluent_normed_np = input_disfluent_normed_np_bEff_T_K_D[task_idx] # (T_in, K, D) + + # --- Unnormalize and save logic --- + # Unnormalize predicted results using data_input_mean_np, data_input_std_np (from dataset) + unnorm_pred_fluent_np_tdk = current_pred_fluent_normed_np * data_input_std_np + data_input_mean_np + + # Unnormalize input using data_cond_mean_np, data_cond_std_np (from dataset) + # (Assume these stats have already been obtained from the dataset earlier in infer.py) + unnorm_input_disfluent_np_tdk = current_input_disfluent_normed_np * data_cond_std_np + data_cond_mean_np + + # Use original_id as basis for filename to ensure uniqueness + # original_id may be a file path, extract filename part + if "/" in original_id or "\\" in original_id: # If it's a file path + filename_base_from_id = Path(original_id).stem + else: # If it's "example_idx_X" or "dataloader_item_X" + filename_base_from_id = original_id + + save_prefix_pred = f"pose_pred_fluent_{filename_base_from_id}" + save_prefix_input = f"pose_input_disfluent_{filename_base_from_id}" + + # Save predicted fluent pose (.pose and .npy) + T_final_sample, K_sample, D_sample = unnorm_pred_fluent_np_tdk.shape + if T_final_sample > 0 and pose_header is not None: + unnorm_pose_data_tpkd = unnorm_pred_fluent_np_tdk.reshape(T_final_sample, 1, K_sample, D_sample) + unnorm_confidence = np.ones((T_final_sample, 1, K_sample), dtype=np.float32) + fps_to_use = pose_header.fps if hasattr(pose_header, 'fps') and pose_header.fps > 0 else 25.0 + unnorm_pose_body = NumPyPoseBody(fps=fps_to_use, data=unnorm_pose_data_tpkd, + confidence=unnorm_confidence) + unnorm_pose_obj = Pose(pose_header, unnorm_pose_body) + + with open(save_dir / f"{save_prefix_pred}.pose", "wb") as f: + unnorm_pose_obj.write(f) + np.save(save_dir / f"{save_prefix_pred}_unnormed.npy", unnorm_pred_fluent_np_tdk) + np.save(save_dir / f"{save_prefix_pred}_normed.npy", current_pred_fluent_normed_np) + + # Save original disfluent pose (.pose and .npy) + T_in_sample, K_in_sample, D_in_sample = unnorm_input_disfluent_np_tdk.shape + if T_in_sample > 0 and pose_header is not None: + unnorm_input_data_tpkd = unnorm_input_disfluent_np_tdk.reshape(T_in_sample, 1, K_in_sample, D_in_sample) + input_confidence = np.ones((T_in_sample, 1, K_in_sample), dtype=np.float32) + fps_to_use = pose_header.fps if hasattr(pose_header, 'fps') and pose_header.fps > 0 else 25.0 + unnorm_input_body = NumPyPoseBody(fps=fps_to_use, data=unnorm_input_data_tpkd, + confidence=input_confidence) + unnorm_input_obj = Pose(pose_header, unnorm_input_body) + + with open(save_dir / f"{save_prefix_input}.pose", "wb") as f: + unnorm_input_obj.write(f) + np.save(save_dir / f"{save_prefix_input}_unnormed.npy", unnorm_input_disfluent_np_tdk) + np.save(save_dir / f"{save_prefix_input}_normed.npy", current_input_disfluent_normed_np) + + # Mark this original_id as successfully processed and saved + processed_originals[original_id] = f"{save_prefix_pred}.pose" + 
unique_outputs_generated += 1 + + print( + f"\nAutoregressive inference finished. Generated and saved {unique_outputs_generated} unique fluent sequences.") + + +if __name__ == "__main__": + main() diff --git a/fluent_pose_synthesis/infer_autoregressive.py b/fluent_pose_synthesis/infer_autoregressive.py deleted file mode 100644 index b5ab56c..0000000 --- a/fluent_pose_synthesis/infer_autoregressive.py +++ /dev/null @@ -1,315 +0,0 @@ -# Example usage: -# python -m fluent_pose_synthesis.infer_autoregressive \ -# -i assets/sample_dataset \ -# -c save/debug_run/weighted_auto_step32_with_200data_output/config.json \ -# -r save/debug_run/weighted_auto_step32_with_200data_output/best.pt \ -# -o save/debug_run/weighted_auto_step32_with_200data_output/infer_output \ -# --batch_size 4 \ -# --chunk_size 30 \ -# --max_len 120 \ -# --stop_threshold 1e-4 \ -# --seed 1234 - - -import argparse -import json -import time -from pathlib import Path, PosixPath -import torch -from torch.utils.data import DataLoader -from pose_format.torch.masked.collator import zero_pad_collator -import numpy as np -from types import SimpleNamespace -from numpy.core.multiarray import scalar -from numpy import dtype -import torch.serialization - -# CAMDM and project imports -from CAMDM.diffusion.create_diffusion import create_gaussian_diffusion -from CAMDM.utils.common import fixseed -from fluent_pose_synthesis.core.models import SignLanguagePoseDiffusion -from fluent_pose_synthesis.data.load_data import SignLanguagePoseDataset -from pose_format import Pose -from pose_format.numpy.pose_body import NumPyPoseBody -from pose_anonymization.data.normalization import unnormalize_mean_std -from pose_format.utils.generic import normalize_pose_size - - -torch.serialization.add_safe_globals([ - SimpleNamespace, Path, PosixPath, scalar, dtype, - np.int64, np.int32, np.float64, np.float32, np.bool_, - ]) - - -def dict_to_namespace(d): - if isinstance(d, dict): - return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()}) - elif isinstance(d, list): - return [dict_to_namespace(item) for item in d] - else: - return d - - -def convert_namespace_to_dict(namespace_obj): - if isinstance(namespace_obj, SimpleNamespace): - return {k: convert_namespace_to_dict(v) for k, v in vars(namespace_obj).items()} - elif isinstance(namespace_obj, dict): - return {k: convert_namespace_to_dict(v) for k, v in namespace_obj.items()} - elif isinstance(namespace_obj, (Path, PosixPath)): - return str(namespace_obj) - elif isinstance(namespace_obj, torch.device): - return str(namespace_obj) - else: - return namespace_obj - - -def load_checkpoint_and_config(model, checkpoint_path, device): - checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) - model.load_state_dict(checkpoint["state_dict"]) - print(f"Loaded model checkpoint from {checkpoint_path}") - loaded_config_dict = checkpoint.get("config", None) - if loaded_config_dict: - if isinstance(loaded_config_dict, SimpleNamespace): - loaded_config_dict = convert_namespace_to_dict(loaded_config_dict) - print("[INFO] Config found in checkpoint.") - return model, loaded_config_dict - return model, None - - -def main(): - parser = argparse.ArgumentParser(description="Autoregressive inference for fluent pose synthesis") - parser.add_argument("-i", "--input", required=True, type=str, help="Path to disfluent input data directory") - parser.add_argument("-c", "--config", required=True, type=str, help="Path to model config JSON file") - parser.add_argument("-r", "--resume", 
required=True, type=str, help="Path to model checkpoint (.pt)") - parser.add_argument("-o", "--output", default="output/infer_autoregressive_results", type=str, help="Directory to save generated outputs") - parser.add_argument("--batch_size", default=1, type=int, help="Batch size for inference.") - parser.add_argument("--chunk_size", default=10, type=int, help="Number of frames to attempt to generate per autoregressive step.") - parser.add_argument("--max_len", default=100, type=int, help="Maximum total frames to generate.") - parser.add_argument("--stop_threshold", default=1e-5, type=float, help="Threshold for mean absolute value of generated chunk to detect stop condition.") - parser.add_argument("--seed", default=42, type=int, help="Random seed.") - args = parser.parse_args() - - fixseed(args.seed) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # Load Configuration - try: - with open(args.config, "r", encoding="utf-8") as f: - config_dict = json.load(f) - print(f"[INFO] Successfully loaded config from {args.config} as JSON.") - except json.JSONDecodeError: - print(f"[WARNING] Could not parse {args.config} as JSON. Attempting to eval as SimpleNamespace string...") - with open(args.config, "r", encoding="utf-8") as f: - namespace_content = f.read() - config_namespace = eval(namespace_content, {"SimpleNamespace": SimpleNamespace, "namespace": SimpleNamespace, "PosixPath": PosixPath, "device": torch.device, "Path": Path}) - config_dict = convert_namespace_to_dict(config_namespace) - standard_json_path = Path(args.config).with_suffix('.fixed.json') - with open(standard_json_path, "w", encoding="utf-8") as f_json: - json.dump(config_dict, f_json, indent=4) - print(f"[INFO] Saved successfully parsed config to {standard_json_path}") - - config = dict_to_namespace(config_dict) - config.inference = SimpleNamespace( - chunk_size=args.chunk_size, max_len=args.max_len, stop_threshold=args.stop_threshold - ) - config.device = device - - print("Loading dataset...") - dataset = SignLanguagePoseDataset( - data_dir=Path(args.input), - split="test", - chunk_len=config.arch.chunk_len, - dtype=np.float32, - limited_num=-1 - ) - dataloader = DataLoader( - dataset, batch_size=args.batch_size, shuffle=False, - collate_fn=zero_pad_collator, num_workers=0 - ) - print(f"Dataset loaded with {len(dataset)} samples.") - - if hasattr(dataset, 'pose_header') and dataset.pose_header is not None: - pose_header = dataset.pose_header - print("[INFO] Pose header loaded from dataset.") - - print("Initializing model...") - input_feats = config.arch.keypoints * config.arch.dims - model = SignLanguagePoseDiffusion( - input_feats=input_feats, chunk_len=config.arch.chunk_len, - keypoints=config.arch.keypoints, dims=config.arch.dims, - latent_dim=config.arch.latent_dim, ff_size=config.arch.ff_size, - num_layers=config.arch.num_layers, num_heads=config.arch.num_heads, - dropout=getattr(config.arch, "dropout", 0.2), - activation=getattr(config.arch, "activation", "gelu"), - arch=config.arch.decoder, cond_mask_prob=0, device=config.device - ).to(config.device) - - model, _ = load_checkpoint_and_config(model, args.resume, device) - model.eval() - - diffusion = create_gaussian_diffusion(config) - - class WrappedDiffusionModel(torch.nn.Module): - def __init__(self, model_to_wrap): - super().__init__() - self.model_to_wrap = model_to_wrap - def forward(self, x_noisy_chunk, t, **kwargs): - return self.model_to_wrap.interface(x_noisy_chunk, t, y=kwargs["y"]) - wrapped_model = 
WrappedDiffusionModel(model) - - save_dir = Path(args.output) - save_dir.mkdir(parents=True, exist_ok=True) - print(f"Output will be saved to: {save_dir}") - - print(f"Starting autoregressive inference: target_chunk_size_per_step={args.chunk_size}, max_len={args.max_len}, stop_threshold={args.stop_threshold}") - total_batches = len(dataloader) - # for batch_idx, batch_data in enumerate(dataloader): - # print(f"\nProcessing batch {batch_idx + 1}/{total_batches}") - for batch_idx, batch_data in enumerate(dataloader): - print(f"\n--- Processing batch {batch_idx + 1}/{total_batches} ---") - print(f"Batch data type: {type(batch_data)}") - if isinstance(batch_data, dict): - for key, value in batch_data.items(): - print(f" Key: '{key}'") - if isinstance(value, torch.Tensor): - print(f" Tensor shape: {value.shape}, dtype: {value.dtype}") - elif isinstance(value, dict): # 比如 'conditions' - print(f" Value is a dict:") - for sub_key, sub_value in value.items(): - print(f" Sub-key: '{sub_key}'") - if isinstance(sub_value, torch.Tensor): - print(f" Tensor shape: {sub_value.shape}, dtype: {sub_value.dtype}") - else: - print(f" Value type: {type(sub_value)}") - else: - print(f" Value type: {type(value)}") - - disfluent_cond_seq_loader = batch_data["conditions"]["input_sequence"].to(device) - disfluent_cond_seq = disfluent_cond_seq_loader.permute(0, 2, 3, 1) # (B, K, D, T_loader) - - current_batch_size, K, D, T_in = disfluent_cond_seq.shape - print(f" Initial disfluent condition shape: (B,K,D,T) = ({current_batch_size}, {K}, {D}, {T_in})") - - generated_fluent_seq = torch.empty(current_batch_size, K, D, 0, device=device) - active_sequences = torch.ones(current_batch_size, dtype=torch.bool, device=device) - num_autoregressive_steps = (args.max_len + args.chunk_size - 1) // args.chunk_size - - for step in range(num_autoregressive_steps): - if not active_sequences.any(): - print(f" Step {step + 1}: All sequences stopped. Ending batch.") - break - current_generated_len = generated_fluent_seq.shape[3] - if current_generated_len >= args.max_len: - print(f" Step {step + 1}: Reached max_len ({args.max_len}). Ending batch.") - break - n_frames_to_generate_this_step = min(args.chunk_size, args.max_len - current_generated_len) - target_chunk_shape = (current_batch_size, K, D, n_frames_to_generate_this_step) - print(f" Step {step + 1}: Gen {n_frames_to_generate_this_step} frames. Total gen: {current_generated_len}.") - - model_kwargs_y = { - "input_sequence": disfluent_cond_seq, - "previous_output": generated_fluent_seq - } - model_kwargs_for_sampler = {"y": model_kwargs_y} - - with torch.no_grad(): - generated_chunk = diffusion.p_sample_loop( - model=wrapped_model, shape=target_chunk_shape, - clip_denoised=False, model_kwargs=model_kwargs_for_sampler, - progress=False - ) - - mean_abs_for_chunk = torch.zeros(current_batch_size, device=device) - if generated_chunk.numel() > 0: - mean_abs_for_chunk[active_sequences] = generated_chunk[active_sequences].abs().mean(dim=(1, 2, 3)) - newly_stopped_sequences = (mean_abs_for_chunk < args.stop_threshold) & active_sequences - if newly_stopped_sequences.any(): - print(f" Sequences at indices {newly_stopped_sequences.nonzero(as_tuple=True)[0].tolist()} stopped.") - active_sequences = active_sequences & (~newly_stopped_sequences) - generated_chunk[newly_stopped_sequences] = torch.zeros_like(generated_chunk[newly_stopped_sequences]) - generated_fluent_seq = torch.cat([generated_fluent_seq, generated_chunk], dim=3) - - print(f" Finished generation for batch {batch_idx + 1}. 
Final length: {generated_fluent_seq.shape[3]}") - - pred_fluent_normed_np_btdk = generated_fluent_seq.permute(0, 3, 1, 2).cpu().numpy() # (B, T, D, K) -> (B, T, K, D) for NumPyPoseBody - - unnorm_pred_fluent_list_np = [] # To store unnormalized numpy arrays for combined .npy file - - for i in range(current_batch_size): - sample_idx_global = batch_idx * args.batch_size + i - normed_pose_data_tdk = pred_fluent_normed_np_btdk[i] # (T, K, D) - T_final, K_sample, D_sample = normed_pose_data_tdk.shape - - if T_final == 0: # Skip if no frames were generated for this sample - unnorm_pred_fluent_list_np.append(np.empty((0, K_sample, D_sample), dtype=normed_pose_data_tdk.dtype)) # Add empty for consistent list length - continue - - # Create Pose object from normalized data - normed_pose_data_tpkd = normed_pose_data_tdk.reshape(T_final, 1, K_sample, D_sample) # Add person dim - normed_confidence = np.ones((T_final, 1, K_sample), dtype=np.float32) - - fps_to_use = 25 - normed_pose_body = NumPyPoseBody(fps=fps_to_use, data=normed_pose_data_tpkd, confidence=normed_confidence) - normed_pose_obj = Pose(pose_header, normed_pose_body) - - # Unnormalize using the new method - unnorm_pose_obj = unnormalize_mean_std(normed_pose_obj.copy()) # Use .copy() if unnormalize_mean_std modifies in-place - normalize_pose_size(unnorm_pose_obj) # Scale for visualization consistency - - # Save unnormalized .pose file - with open(save_dir / f"pose_pred_fluent_{sample_idx_global}.pose", "wb") as f: - unnorm_pose_obj.write(f) - - # Get unnormalized numpy data for combined .npy file - unnorm_data_np_tdk = np.array(unnorm_pose_obj.body.data.data.astype(normed_pose_data_tdk.dtype)).squeeze(1) # (T, K, D) - unnorm_pred_fluent_list_np.append(unnorm_data_np_tdk) - - # Save combined .npy files for the batch (one normalized, one unnormalized) - np.save(save_dir / f"pred_fluent_normed_batch{batch_idx}.npy", pred_fluent_normed_np_btdk) - if unnorm_pred_fluent_list_np: # Only save if list is not empty - for i_un, unnorm_arr in enumerate(unnorm_pred_fluent_list_np): - sample_idx_global_un = batch_idx * args.batch_size + i_un - if unnorm_arr.shape[0] > 0 : # only save if there is data - np.save(save_dir / f"pred_fluent_unnormed_sample{sample_idx_global_un}.npy", unnorm_arr) - - print(f" Saved predicted fluent poses for batch {batch_idx + 1} to {save_dir}") - - # Save original disfluent input for reference (using similar unnormalization) - disfluent_input_normed_np_btdk = disfluent_cond_seq_loader.cpu().numpy() # (B, T_in_loader, K, D) - - # Save normalized disfluent input - np.save(save_dir / f"input_disfluent_normed_batch{batch_idx}.npy", disfluent_input_normed_np_btdk) - - # Save unnormalized disfluent input - for i in range(current_batch_size): - sample_idx_global = batch_idx * args.batch_size + i - normed_disfluent_data_tdk = disfluent_input_normed_np_btdk[i] - T_disfl, K_disfl, D_disfl = normed_disfluent_data_tdk.shape - - if T_disfl == 0: continue - - normed_disfluent_data_tpkd = normed_disfluent_data_tdk.reshape(T_disfl, 1, K_disfl, D_disfl) - disfl_confidence = np.ones((T_disfl, 1, K_disfl), dtype=np.float32) - - fps_to_use = pose_header.fps if hasattr(pose_header, 'fps') and pose_header.fps > 0 else 25 - normed_disfl_body = NumPyPoseBody(fps=fps_to_use, data=normed_disfluent_data_tpkd, confidence=disfl_confidence) - normed_disfl_obj = Pose(pose_header, normed_disfl_body) - - unnorm_disfl_obj = unnormalize_mean_std(normed_disfl_obj.copy()) - normalize_pose_size(unnorm_disfl_obj) # Optional - - # Save unnormalized disfluent .pose 
file - with open(save_dir / f"pose_input_disfluent_{sample_idx_global}.pose", "wb") as f: - unnorm_disfl_obj.write(f) - - # Save individual unnormalized disfluent .npy - unnorm_disfl_np_tdk = np.array(unnorm_disfl_obj.body.data.data.astype(normed_disfluent_data_tdk.dtype)).squeeze(1) - np.save(save_dir / f"input_disfluent_unnormed_sample{sample_idx_global}.npy", unnorm_disfl_np_tdk) - - print(f" Saved input disfluent poses for batch {batch_idx + 1} to {save_dir}") - - print("\nAutoregressive inference finished for all batches.") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/fluent_pose_synthesis/tests/test_overfit.py b/fluent_pose_synthesis/tests/test_overfit.py index b110736..a88c36d 100644 --- a/fluent_pose_synthesis/tests/test_overfit.py +++ b/fluent_pose_synthesis/tests/test_overfit.py @@ -13,6 +13,7 @@ class DummyDataset(Dataset): + def __len__(self): return 1 @@ -23,12 +24,8 @@ def __getitem__(self, idx): def get_toy_batch(batch_size=2, seq_len=40, keypoints=178): assert batch_size == 2, "get_toy_batch currently only supports batch_size=2" - base_linear = torch.linspace(0, 1, seq_len * keypoints * 3).reshape( - seq_len, 1, keypoints, 3 - ) - base_sine = ( - torch.sin(torch.linspace(0, 4 * np.pi, seq_len)).unsqueeze(1).unsqueeze(2) - ) # [T, 1, 1] + base_linear = torch.linspace(0, 1, seq_len * keypoints * 3).reshape(seq_len, 1, keypoints, 3) + base_sine = (torch.sin(torch.linspace(0, 4 * np.pi, seq_len)).unsqueeze(1).unsqueeze(2)) # [T, 1, 1] base_sine = base_sine.expand(seq_len, 1, keypoints).unsqueeze(-1) # [T, 1, K, 1] base_sine = base_sine.repeat(1, 1, 1, 3) # [T, 1, K, 3] @@ -125,11 +122,8 @@ def test_overfit_toy_batch(): # Move batch to device batch = { - k: ( - v.to(config.device) - if isinstance(v, torch.Tensor) - else {kk: vv.to(config.device) for kk, vv in v.items()} - ) + k: (v.to(config.device) if isinstance(v, torch.Tensor) else {kk: vv.to(config.device) + for kk, vv in v.items()}) for k, v in batch.items() } @@ -166,12 +160,8 @@ def test_overfit_toy_batch(): losses = [] for step in range(config.trainer.epoch): - t, weights = trainer.schedule_sampler.sample( - config.trainer.batch_size, config.device - ) - _, loss_dict = trainer.diffuse( - batch["data"], t, batch["conditions"], return_loss=True - ) + t, weights = trainer.schedule_sampler.sample(config.trainer.batch_size, config.device) + _, loss_dict = trainer.diffuse(batch["data"], t, batch["conditions"], return_loss=True) loss = (loss_dict["loss"] * weights).mean() losses.append(loss.item()) print(f"[Step {step}] Loss: {loss.item():.6f}") @@ -180,9 +170,7 @@ def test_overfit_toy_batch(): loss.backward() optimizer.step() - assert ( - losses[-1] < 1e-3 - ), "Final loss is too high. Model failed to overfit the toy batch." + assert (losses[-1] < 1e-3), "Final loss is too high. Model failed to overfit the toy batch." 
plot_loss_curve(losses, save_path="overfit_loss_curve.png") # Check model output differences @@ -209,9 +197,7 @@ def test_overfit_toy_batch(): config.arch.keypoints, config.arch.dims, ) - assert ( - out1.shape == out2.shape == expected_shape - ), f"Unexpected output shape, expected {expected_shape}" + assert (out1.shape == out2.shape == expected_shape), f"Unexpected output shape, expected {expected_shape}" # Compute multiple metrics to assess output difference l2_diff = torch.norm(out1 - out2).item() @@ -223,9 +209,7 @@ def test_overfit_toy_batch(): print(f"Cosine distance: {cosine_dist:.6f}") # Assert based on multiple metrics - assert ( - avg_kpt_error > 0.01 or cosine_dist > 0.01 - ), "Outputs are too similar. Possible collapse." + assert (avg_kpt_error > 0.01 or cosine_dist > 0.01), "Outputs are too similar. Possible collapse." print("Overfitting test passed.") @@ -252,6 +236,7 @@ def compute_average_keypoint_error(pose1, pose2): diff = torch.norm(pose1 - pose2, dim=-1) # [B, T, K] return diff.mean().item() # scalar + def compute_cosine_distance(pose1, pose2): """ Computes cosine distance between flattened pose vectors. @@ -259,7 +244,7 @@ def compute_cosine_distance(pose1, pose2): """ v1 = pose1.flatten() v2 = pose2.flatten() - cos = F.cosine_similarity(v1, v2, dim=0) # pylint: disable=not-callable + cos = F.cosine_similarity(v1, v2, dim=0) # pylint: disable=not-callable return 1 - cos.item() diff --git a/fluent_pose_synthesis/train.py b/fluent_pose_synthesis/train.py index 5f1ab98..1920044 100644 --- a/fluent_pose_synthesis/train.py +++ b/fluent_pose_synthesis/train.py @@ -1,7 +1,10 @@ import sys +import os + import time import shutil import argparse +import json from pathlib import Path, PosixPath from types import SimpleNamespace @@ -17,6 +20,7 @@ from fluent_pose_synthesis.core.models import SignLanguagePoseDiffusion from fluent_pose_synthesis.core.training import PoseTrainingPortal from fluent_pose_synthesis.data.load_data import SignLanguagePoseDataset + from fluent_pose_synthesis.config.option import ( add_model_args, add_train_args, @@ -25,23 +29,25 @@ ) # Add custom globals to torch.serialization -torch.serialization.add_safe_globals( - [ - SimpleNamespace, - PosixPath, - np.int64, - np.int32, - np.float64, - np.float32, - np.bool_, - ] -) +torch.serialization.add_safe_globals([ + SimpleNamespace, + PosixPath, + np.int64, + np.int32, + np.float64, + np.float32, + np.bool_, +]) # Patch torch.load to avoid loading weights only # This is a workaround for the issue where torch.load tries to load weights only _original_torch_load = torch.load + + def patched_torch_load(*args, **kwargs): kwargs.setdefault("weights_only", False) return _original_torch_load(*args, **kwargs) + + torch.load = patched_torch_load @@ -55,15 +61,16 @@ def train( fixseed(1024) np_dtype = select_platform(32) - logger.info("Loading dataset...") + # Training Dataset and Dataloader + logger.info("Loading training dataset...") train_dataset = SignLanguagePoseDataset( data_dir=config.data, split="train", chunk_len=config.arch.chunk_len, + history_len=getattr(config.arch, "history_len", 5), dtype=np_dtype, limited_num=config.trainer.load_num, ) - train_dataloader = DataLoader( train_dataset, batch_size=config.trainer.batch_size, @@ -73,12 +80,27 @@ def train( pin_memory=True, collate_fn=zero_pad_collator, ) - - logger.info( - f"Training Dataset includes {len(train_dataset)} samples, " - f"with {config.arch.chunk_len} fluent frames per sample." 
+ logger.info(f"Training Dataset includes {len(train_dataset)} samples, " + f"with {config.arch.chunk_len} fluent frames per sample.") + + # Validation Dataset and Dataloader + logger.info("Loading validation dataset...") + validation_dataset = SignLanguagePoseDataset(data_dir=config.data, split="validation", + chunk_len=config.arch.chunk_len, + history_len=getattr(config.arch, "history_len", 5), dtype=np_dtype, + limited_num=config.trainer.load_num) + validation_dataloader = DataLoader( + validation_dataset, + batch_size=config.trainer.batch_size, + shuffle=False, # No need to shuffle validation data + num_workers=config.trainer.workers, + drop_last=False, + pin_memory=True, + collate_fn=zero_pad_collator, ) + logger.info(f"Validation Dataset includes {len(validation_dataset)} samples.") + # Model and Diffusion Initialization diffusion = create_gaussian_diffusion(config) input_feats = config.arch.keypoints * config.arch.dims @@ -101,29 +123,34 @@ def train( ).to(config.device) logger.info(f"Model: {model}") - trainer = PoseTrainingPortal( - config, model, diffusion, train_dataloader, logger, tb_writer - ) + + # Training Portal Initialization + trainer = PoseTrainingPortal(config, model, diffusion, train_dataloader, logger, tb_writer, + validation_dataloader=validation_dataloader) if resume_path is not None: try: trainer.load_checkpoint(str(resume_path)) + logger.info(f"[DEBUG] After load_checkpoint, trainer.epoch={trainer.epoch}") except FileNotFoundError: print(f"No checkpoint found at {resume_path}") sys.exit(1) - trainer.run_loop() + custom_profiler_directory = config.save / "profiler_logs" + custom_profiler_directory.mkdir(parents=True, exist_ok=True) + + logger.info(f"Profiler output will be directed to: {custom_profiler_directory}") + + logger.info(f"[DEBUG] About to start run_loop with trainer.epoch={trainer.epoch}") + trainer.run_loop(enable_profiler=True, profiler_directory=str(custom_profiler_directory)) + # trainer.run_loop() def main(): start_time = time.time() - parser = argparse.ArgumentParser( - description="### Fluent Sign Language Pose Synthesis Training ###" - ) - parser.add_argument( - "-n", "--name", default="debug", type=str, help="The name of this training run" - ) + parser = argparse.ArgumentParser(description="### Fluent Sign Language Pose Synthesis Training ###") + parser.add_argument("-n", "--name", default="debug", type=str, help="The name of this training run") parser.add_argument( "-c", "--config", @@ -134,14 +161,11 @@ def main(): parser.add_argument( "-i", "--data", - default="assets/sample_dataset", - # default="/pose_data/output", + default="/pose_data/output", type=str, help="Path to dataset folder", ) - parser.add_argument( - "-r", "--resume", default=None, type=str, help="Path to latest checkpoint" - ) + parser.add_argument("-r", "--resume", default=None, type=str, help="Path to latest checkpoint") parser.add_argument( "-s", "--save", @@ -156,7 +180,6 @@ def main(): add_train_args(parser) args = parser.parse_args() - config = config_parse(args) # Convert key paths to Path objects @@ -175,20 +198,15 @@ def main(): config.trainer.epoch = 2000 # Handle existing folder - if ( - not args.cluster - and config.save.exists() - and "debug" not in args.name - and args.resume is None - ): + if (not args.cluster and config.save.exists() and "debug" not in args.name and args.resume is None): allow_cover = input("Model folder exists. Overwrite? 
(Y/N): ").lower() if allow_cover == "n": sys.exit(0) shutil.rmtree(config.save, ignore_errors=True) - # Auto-resume for cluster - resume_path = None - if config.save.exists() and args.resume is None: + # Use the resume path from command line argument if provided + resume_path = Path(args.resume) if args.resume else None + if resume_path is None and config.save.exists(): best_ckpt = config.save / "best.pt" if best_ckpt.exists(): resume_path = best_ckpt @@ -207,12 +225,28 @@ def main(): # Save config with open(config.save / "config.json", "w", encoding="utf-8") as f: - f.write(str(config)) + # Convert SimpleNamespace to dict for JSON serialization + json.dump(config_to_dict(config), f, indent=4) + logger.info(f"Saved final configuration to {config.save / 'config.json'}") logger.info(f"\nLaunching training with config:\n{config}") train(config, resume_path, logger, tb_writer) logger.info(f"\nTotal training time: {(time.time() - start_time) / 60:.2f} mins") +def config_to_dict(config_namespace): + """Helper to convert SimpleNamespace (recursively) to dict for JSON.""" + if isinstance(config_namespace, SimpleNamespace): + return {k: config_to_dict(v) for k, v in vars(config_namespace).items()} + elif isinstance(config_namespace, Path): + return str(config_namespace) + elif isinstance(config_namespace, (list, tuple)): + return [config_to_dict(i) for i in config_namespace] + elif isinstance(config_namespace, torch.device): + return str(config_namespace) + else: + return config_namespace + + if __name__ == "__main__": - main() \ No newline at end of file + main()