[BUG] Running training while all weights are frozen leads to loss.backward failing. #29

Open
infophysics opened this issue Oct 9, 2024 · 0 comments
Labels
bug Something isn't working

Comments

@infophysics

When all weights are frozen, no parameter requires a gradient, so the forward pass builds no autograd graph. The resulting loss has nothing to backpropagate through, and loss.backward fails.
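The failure mode can be sketched in plain PyTorch, independent of SPINE. The toy `Linear` model and random input below are assumptions for illustration only, and the message in the reported traceback may differ from the one raised here.

```python
import torch

# Toy stand-in for a chain in which every module is frozen (illustrative only).
model = torch.nn.Linear(4, 2)
for param in model.parameters():
    param.requires_grad_(False)

# Neither the input nor any parameter requires a gradient, so the loss is
# not attached to an autograd graph.
loss = model(torch.randn(8, 4)).sum()
print("loss.requires_grad:", loss.requires_grad)

backward_failed = False
try:
    loss.backward()
except RuntimeError as err:
    backward_failed = True
    print("backward failed:", err)
```

Unfreezing any single parameter that contributes to the loss is enough to make `loss.backward()` succeed again.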

To Reproduce
Run any training config with freeze_weights: true set for all pieces of the chain, for example:

# Base configuration
base:
  world_size: 1
  iterations: 10000 # 200k/256 -> ~1000/epoch -> ~ 10 epochs
  seed: 0
  unwrap: false
  log_dir: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/logs/full_chain/default
  log_step: 1
  train:
    weight_prefix: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/default/snapshot
    save_step: 1000 # ~1 epoch
    optimizer:
      name: Adam
      lr: 0.001

# IO configuration
io:
  loader:
    batch_size: 256
    shuffle: false
    num_workers: 8
    collate_fn: all
    sampler: random_sequence
    dataset:
      name: larcv
      file_keys: /sdf/data/neutrino/2x2/sim/mpvmpr_v2/train_file_list.txt
      schema:
        data:
          parser: sparse3d
          sparse_event: sparse3d_pcluster
        seg_label:
          parser: sparse3d
          sparse_event: sparse3d_pcluster_semantics
        ppn_label:
          parser: particle_points
          sparse_event: sparse3d_pcluster
          particle_event: particle_pcluster
          include_point_tagging: false
        clust_label:
          parser: cluster3d
          cluster_event: cluster3d_pcluster
          particle_event: particle_pcluster
          sparse_semantics_event: sparse3d_pcluster_semantics
          add_particle_info: true
          clean_data: true
          type_include_secondary: false
          type_include_mpr: false
          primary_include_mpr: false
        coord_label:
          parser: particle_coords
          particle_event: particle_pcluster
          cluster_event: cluster3d_pcluster

# Model configuration
model:
  name: full_chain
  weight_path: null

  network_input:
    data: data
    seg_label: seg_label
    clust_label: clust_label

  loss_input:
    seg_label: seg_label
    ppn_label: ppn_label
    clust_label: clust_label
    coord_label: coord_label

  modules:
    # General chain configuration
    chain:
      deghosting: null
      charge_rescaling: null
      segmentation: uresnet
      point_proposal: ppn
      fragmentation: graph_spice
      shower_aggregation: grappa
      shower_primary: grappa
      track_aggregation: grappa
      particle_aggregation: null
      inter_aggregation: grappa
      particle_identification: grappa
      primary_identification: grappa
      orientation_identification: grappa
      calibration: null

    # Semantic segmentation + point proposal
    uresnet_ppn:
      weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/graph_spice/snapshot-999.ckpt
      freeze_weights: true
      uresnet:
        num_input: 1
        num_classes: 5
        filters: 32
        depth: 5
        reps: 2
        allow_bias: false
        activation:
          name: lrelu
          negative_slope: 0.33
        norm_layer:
          name: batch_norm
          eps: 0.0001
          momentum: 0.01
  
      ppn:
        classify_endpoints: false
  
    uresnet_ppn_loss:
      uresnet_loss:
        balance_loss: false
  
      ppn_loss:
        mask_loss: CE
        resolution: 5.0

    # Dense clustering
    graph_spice:
      weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/graph_spice/snapshot-999.ckpt
      freeze_weights: true
      shapes: [shower, track, michel, delta]
      use_raw_features: true
      invert: true
      make_clusters: true
      embedder:
        spatial_embedding_dim: 3
        feature_embedding_dim: 16
        occupancy_mode: softplus
        covariance_mode: softplus
        uresnet:
          num_input: 4 # 1 feature + 3 normalized coords
          filters: 32
          input_kernel: 5
          depth: 5
          reps: 2
          spatial_size: 320
          allow_bias: false
          activation:
            name: lrelu
            negative_slope: 0.33
          norm_layer:
            name: batch_norm
            eps: 0.0001
            momentum: 0.01
      kernel:
        name: bilinear
        num_features: 32
      constructor:
        edge_threshold: 0.1
        min_size: 3
        label_edges: true
        graph:
          name: radius
          r: 1.9
        orphan:
          mode: radius
          radius: 1.9
          iterate: true
          assign_all: true

    graph_spice_loss:
      name: edge
      loss: binary_log_dice_ce

    # Shower fragment aggregation + shower primary identification
    grappa_shower:
      weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/grappa_shower/snapshot-2999.ckpt
      freeze_weights: true
      nodes:
        source: cluster
        shapes: [shower, michel, delta]
        min_size: -1
        make_groups: true
        grouping_method: score
      graph:
        name: complete
        max_length: [500, 0, 500, 500, 0, 0, 0, 25, 0, 25]
        dist_algorithm: recursive
      node_encoder:
        name: geo
        use_numpy: true
        add_value: true
        add_shape: true
        add_points: true
        add_local_dirs: true
        dir_max_dist: 5
        add_local_dedxs: true
        dedx_max_dist: 5
      edge_encoder:
        name: geo
        use_numpy: true
      gnn_model:
        name: meta
        node_feats: 33 # 16 (geo) + 3 (extra) + 6 (points) + 6 (directions) + 2 (local dedxs)
        edge_feats: 19
        node_pred: 2
        edge_pred: 2
        edge_layer:
          name: mlp
          mlp:
            depth: 3
            width: 64
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm
        node_layer:
          name: mlp
          reduction: max
          attention: false
          message_mlp:
            depth: 3
            width: 64
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm
          aggr_mlp:
            depth: 3
            width: 64
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm

    grappa_shower_loss:
      node_loss:
        name: shower_primary
        high_purity: true
        use_group_pred: true
      edge_loss:
        name: channel
        target: group
        high_purity: true

    # Track aggregation
    grappa_track:
      weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/grappa_track/snapshot-2999.ckpt
      freeze_weights: true
      nodes:
        source: cluster
        shapes: [track]
        min_size: -1
        make_groups: true
        grouping_method: score
      graph:
        name: complete
        max_length: 100
        dist_algorithm: recursive
      node_encoder:
        name: geo
        use_numpy: true
        add_value: true
        add_shape: false
        add_points: true
        add_local_dirs: true
        dir_max_dist: 5
        add_local_dedxs: true
        dedx_max_dist: 5
      edge_encoder:
        name: geo
        use_numpy: true
      gnn_model:
        name: meta
        node_feats: 32 # 16 (geo) + 2 (extra) + 6 (points) + 6 (directions) + 2 (local dedxs)
        edge_feats: 19
        node_pred: 2
        edge_pred: 2
        edge_layer:
          name: mlp
          mlp:
            depth: 3
            width: 64
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm
        node_layer:
          name: mlp
          reduction: max
          attention: false
          message_mlp:
            depth: 3
            width: 64
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm
          aggr_mlp:
            depth: 3
            width: 64
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm

    grappa_track_loss:
      edge_loss:
        name: channel
        target: group

    # Interaction aggregation, PID, primary, orientation
    grappa_inter:
      weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/grappa_inter/balanced/snapshot-19999.ckpt
      freeze_weights: true
      model_name: ''
      nodes:
        source: group
        shapes: [shower, track, michel, delta]
        min_size: -1
        make_groups: true
      graph:
        name: complete
        max_length: [500, 500, 0, 0, 25, 25, 25, 0, 0, 0]
        dist_algorithm: recursive
      node_encoder:
        name: geo
        use_numpy: true
        add_value: true
        add_shape: true
        add_points: true
        add_local_dirs: true
        dir_max_dist: 5
        add_local_dedxs: true
        dedx_max_dist: 5
      edge_encoder:
        name: geo
        use_numpy: true
      gnn_model:
        name: meta
        node_feats: 33 # 16 (geo) + 3 (extra) + 6 (points) + 6 (directions) + 2 (local dedxs)
        edge_feats: 19
        node_pred:
          type: 6
          primary: 2
          orient: 2
          #momentum: 1
          #vertex: 5
        edge_pred: 2
        edge_layer:
          name: mlp
          mlp:
            depth: 3
            width: 128
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm
        node_layer:
          name: mlp
          reduction: max
          attention: false
          message_mlp:
            depth: 3
            width: 128
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm
          aggr_mlp:
            depth: 3
            width: 128
            activation:
              name: lrelu
              negative_slope: 0.1
            normalization: batch_norm

    grappa_inter_loss:
      node_loss:
        type:
          name: class
          target: pid
          loss: ce
          balance_loss: true
        primary:
          name: class
          target: inter_primary
          loss: ce
          balance_loss: true
        orient:
          name: orient
          loss: ce
        #momentum:
        #  name: reg
        #  target: 16
        #  loss: berhu
        #vertex:
        #  name: vertex
        #  primary_loss: ce
        #  balance_primary_loss: true
        #  regression_loss: mse
        #  only_contained: true
        #  normalize_positions: true
        #  use_anchor_points: true
        #  return_vertex_labels: true
        #  detector: icarus
      edge_loss:
        name: channel
        target: interaction

Expected behavior
Instead of training, the run fails at loss.backward with an error such as the following:

[image: error output]

Code base
Provide the following:

Possible solution (optional)
Obviously we can remove freeze_weights from our config, but ideally SPINE would check, when a training script starts, that at least some weights are being trained, and report a clear message instead of the error above (which may be difficult to parse).
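A minimal sketch of such a check follows. The function names are hypothetical and not part of SPINE; the code only assumes the model exposes `parameters()` the way `torch.nn.Module` does, with each parameter providing `requires_grad` and `numel()`.

```python
def count_trainable_parameters(model):
    """Return the number of scalar parameters that will receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def assert_has_trainable_parameters(model):
    """Raise a readable error if every parameter of the model is frozen."""
    if count_trainable_parameters(model) == 0:
        raise ValueError(
            "All model weights are frozen (is freeze_weights: true set for "
            "every module?). There is nothing to train, so loss.backward "
            "would fail."
        )
```

Calling `assert_has_trainable_parameters(model)` once, after the weights are loaded and frozen but before the first training iteration, would surface the misconfiguration up front rather than inside the backward pass.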

infophysics added the bug label Oct 9, 2024