diff --git a/sherpa/csrc/online-lstm-transducer-model.cc b/sherpa/csrc/online-lstm-transducer-model.cc index b1ef4df8a..5266117dd 100644 --- a/sherpa/csrc/online-lstm-transducer-model.cc +++ b/sherpa/csrc/online-lstm-transducer-model.cc @@ -26,10 +26,10 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( joiner_ = torch::jit::load(joiner_filename, device); joiner_.eval(); + auto conv = decoder_.attr("conv").toModule(); + context_size_ = - decoder_.hasattr("conv") - ? decoder_.attr("conv").toModule().attr("weight").toTensor().size(2) - : 1; + conv.hasattr("weight") ? conv.attr("weight").toTensor().size(2) : 1; // Use 5 here since the subsampling is ((len - 3) // 2 - 1) // 2. int32_t pad_length = 5; diff --git a/sherpa/csrc/online-zipformer-transducer-model.cc b/sherpa/csrc/online-zipformer-transducer-model.cc index e564a8477..2cbf971bd 100644 --- a/sherpa/csrc/online-zipformer-transducer-model.cc +++ b/sherpa/csrc/online-zipformer-transducer-model.cc @@ -26,10 +26,10 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( joiner_ = torch::jit::load(joiner_filename, device); joiner_.eval(); + auto conv = decoder_.attr("conv").toModule(); + context_size_ = - decoder_.hasattr("conv") - ? decoder_.attr("conv").toModule().attr("weight").toTensor().size(2) - : 1; + conv.hasattr("weight") ? conv.attr("weight").toTensor().size(2) : 1; // Use 7 here since the subsampling is ((len - 7) // 2 + 1) // 2. int32_t pad_length = 7; @@ -49,10 +49,10 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( decoder_ = model_.attr("decoder").toModule(); joiner_ = model_.attr("joiner").toModule(); + auto conv = decoder_.attr("conv").toModule(); + context_size_ = - decoder_.hasattr("conv") - ? decoder_.attr("conv").toModule().attr("weight").toTensor().size(2) - : 1; + conv.hasattr("weight") ? conv.attr("weight").toTensor().size(2) : 1; // Use 7 here since the subsampling is ((len - 7) // 2 + 1) // 2. int32_t pad_length = 7; diff --git a/sherpa/csrc/online-zipformer2-transducer-model.cc b/sherpa/csrc/online-zipformer2-transducer-model.cc index 43937b1a2..72e04c93a 100644 --- a/sherpa/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa/csrc/online-zipformer2-transducer-model.cc @@ -23,10 +23,10 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( decoder_ = model_.attr("decoder").toModule(); joiner_ = model_.attr("joiner").toModule(); + auto conv = decoder_.attr("conv").toModule(); + context_size_ = - decoder_.hasattr("conv") - ? decoder_.attr("conv").toModule().attr("weight").toTensor().size(2) - : 1; + conv.hasattr("weight") ? conv.attr("weight").toTensor().size(2) : 1; int32_t pad_length = encoder_.attr("pad_length").toInt();