From 1f11a517fa63a60b885c7cf35d4d84e83b7e9f4e Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sat, 23 Sep 2023 14:43:52 +0800 Subject: [PATCH] Allow partial for intersect_dense_pruned (#1218) * allow_partial for intersect_dense_pruned * fix comment * set default value and move to the last * Fix crash and backward * Fix allow partial for online decoding * Add python demo * Minor fixes * Add is_final * Minor fixes * Minor fixes * Fix style * Fix online intersecter test * Fix test * Fix cpp style * Support decoding with a wav scp * quick fix for online decoding * Add GetFinalFrame; Fix the online decoding issue * Fix style * Fix ci * Fix ci * fix ci * fix ci --------- Co-authored-by: Guo Liyong --- .github/workflows/run-tests-cpu.yml | 3 + ...test-k2-as-third-party-lib-cuda-ubuntu.yml | 3 + k2/csrc/fsa_algo.h | 14 +- k2/csrc/intersect_dense_pruned.cu | 331 ++++++++++----- k2/csrc/intersect_dense_pruned.h | 6 +- k2/csrc/intersect_test.cu | 22 +- k2/python/csrc/torch/fsa_algo.cu | 31 +- k2/python/k2/autograd.py | 33 +- k2/python/k2/dense_fsa_vec.py | 2 +- k2/python/k2/online_dense_intersecter.py | 7 + .../tests/online_dense_intersecter_test.py | 2 +- k2/torch/bin/hlg_decode.py | 305 ++++++++++++++ k2/torch/bin/online_decode.cu | 18 +- k2/torch/bin/online_decode.py | 383 ++++++++++++++++++ k2/torch/csrc/CMakeLists.txt | 2 +- k2/torch/csrc/fsa_algo.cu | 7 +- k2/torch/csrc/fsa_algo.h | 8 +- .../github_actions/generate_build_matrix.py | 2 +- scripts/github_actions/install_cuda.sh | 3 + scripts/github_actions/install_cudnn.sh | 3 + scripts/github_actions/install_torch.sh | 7 +- 21 files changed, 1040 insertions(+), 152 deletions(-) create mode 100644 k2/torch/bin/hlg_decode.py create mode 100644 k2/torch/bin/online_decode.py diff --git a/.github/workflows/run-tests-cpu.yml b/.github/workflows/run-tests-cpu.yml index e70cd9966..e55f1fbbd 100644 --- a/.github/workflows/run-tests-cpu.yml +++ b/.github/workflows/run-tests-cpu.yml @@ -54,6 +54,9 @@ jobs: torch: ["1.13.1"] 
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] build_type: ["Release", "Debug"] + exclude: + - os: macos-latest + python-version: "3.11" steps: # refer to https://github.com/actions/checkout diff --git a/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml b/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml index 439de7d53..b79f4ac70 100644 --- a/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml +++ b/.github/workflows/test-k2-as-third-party-lib-cuda-ubuntu.yml @@ -104,6 +104,9 @@ jobs: - name: Install GCC 7 run: | + sudo apt update + sudo apt install software-properties-common + sudo add-apt-repository "deb [arch=amd64] http://archive.ubuntu.com/ubuntu focal main universe" sudo apt-get install -y gcc-7 g++-7 echo "CC=/usr/bin/gcc-7" >> $GITHUB_ENV echo "CXX=/usr/bin/g++-7" >> $GITHUB_ENV diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index cecf6940a..dc51e1f14 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -161,10 +161,10 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest, @param[in] b_fsas Input FSAs that correspond to neural network outputs (see documentation in fsa.h). @param[in] search_beam Beam for frame-synchronous beam pruning, - e.g. 20. Smaller is faster, larger is more exact - (less pruning). This is the default value; it may be - modified by {min,max}_active which dictate the minimum - or maximum allowed number of active states per frame. + e.g. 20. Smaller is faster, larger is more exact + (less pruning). This is the default value; it may be + modified by {min,max}_active which dictate the minimum + or maximum allowed number of active states per frame. @param[in] output_beam Beam with which we prune the output (analogous to lattice-beam in Kaldi), e.g. 8. We discard arcs in the output that are not on a path that's within @@ -178,6 +178,11 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest, of states are active. 
The hash size used per FSA is 4 times (this rounded up to a power of 2), so this affects memory consumption. + @param [in] allow_partial If true and there was no final state active, + we will treat all the states on the last frame + to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. @param[out] out Output vector of composed, pruned FSAs, with same Dim0() as b_fsas. Elements of it may be empty if the composition was empty, either intrinsically or due to @@ -196,6 +201,7 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest, void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam, float output_beam, int32_t min_active_states, int32_t max_active_states, + bool allow_partial, FsaVec *out, Array1 *arc_map_a, Array1 *arc_map_b); diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 6831fc074..e612921ac 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -68,6 +68,12 @@ class MultiGraphDenseIntersectPruned { intersection/composition task. This is advisory, in that it will try not to exceed that but may not always succeed. This determines the hash size. + @param [in] allow_partial If true and there was no final state active, + we will treat all the states on the last frame + to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. + @param [in] online_decoding True for online decoding (i.e. chunk by chunk decoding), false for running in batch mode. 
@@ -75,6 +81,7 @@ class MultiGraphDenseIntersectPruned { MultiGraphDenseIntersectPruned(FsaVec &a_fsas, int32_t num_seqs, float search_beam, float output_beam, int32_t min_active, int32_t max_active, + bool allow_partial, bool online_decoding) : a_fsas_(a_fsas), num_seqs_(num_seqs), @@ -82,13 +89,14 @@ class MultiGraphDenseIntersectPruned { output_beam_(output_beam), min_active_(min_active), max_active_(max_active), + allow_partial_(allow_partial), online_decoding_(online_decoding), dynamic_beams_(a_fsas.Context(), num_seqs, search_beam), forward_semaphore_(1), final_t_(a_fsas.Context(), num_seqs, 0) { NVTX_RANGE(K2_FUNC); - c_ = GetContext(a_fsas.shape); T_ = 0; + c_ = GetContext(a_fsas.shape); K2_CHECK_GT(search_beam, 0); K2_CHECK_GT(output_beam, 0); K2_CHECK_GE(min_active, 0); @@ -140,7 +148,7 @@ class MultiGraphDenseIntersectPruned { log-likes of each phone. A series of sequences of (in general) different length. */ - void Intersect(std::shared_ptr &b_fsas) { + void Intersect(DenseFsaVec *b_fsas) { /* T is the largest number of (frames+1) of neural net output, or the largest number of frames of log-likelihoods we count the final frame with (0, @@ -243,22 +251,13 @@ class MultiGraphDenseIntersectPruned { @param [in] frames The frames generated for previously decoded chunks. @param [in] beams Current search beams for each of the sequences, it has `beams.Dim() == num_seqs_`. - @return A pointer to current `frames_`, which would be usefull to + @return A pointer to current `frames_`, which would be useful to generate `DecodeStateInfo` for each sequences. */ const std::vector>* OnlineIntersect( - std::shared_ptr &b_fsas, + DenseFsaVec *b_fsas, std::vector> &frames, Array1 &beams) { - /* - T is the largest number of (frames+1) of neural net output currently - received, or the largest number of frames of log-likelihoods we count the - final frame with (0, -inf, -inf..) that is used for the final-arc. 
- The largest number of states in the fsas represented by b_fsas equals - T+1 (e.g. 1 frame would require 2 states, because that 1 frame is the arc - from state 0 to state 1). So the #states is 2 greater than the actual - number of frames in the neural-net output. - */ K2_CHECK(online_decoding_); K2_CHECK(c_->IsCompatible(*b_fsas->Context())); K2_CHECK_EQ(a_fsas_.shape.Dim0(), 1); @@ -268,18 +267,21 @@ class MultiGraphDenseIntersectPruned { b_fsas_ = b_fsas; frames_.swap(frames); dynamic_beams_ = beams.To(c_); - T_ = frames_.size(); - // -1 here because we already put the initial frame info to frames_ - int32_t T = T_ + b_fsas_->shape.MaxSize(1) - 1; + // T_ is the actual number of frames we have already processed in previous + // chunks, -1 here because frames_ includes the initial frame. + T_ = frames_.size() - 1; + // -1 here because we add extra frame to b_fsas_ (to handle -1 arc) + // see dense_fsa_vec.py for more details of converting nnet_outputs to fsas. + int32_t chunk_size = b_fsas_->shape.MaxSize(1) - 1; + int32_t T = T_ + chunk_size; - // we'll initially populate frames_[0.. T+1], but discard the one at T+1, - // which has no arcs or states, the ones we use are from 0 to T. - frames_.reserve(T + 2); + // plus initial frame, we actually have T + 1 frames. + frames_.reserve(T + 1); - if (T_ == 0) frames_.push_back(InitialFrameInfo()); - - for (int32_t t = 0; t <= b_fsas_->shape.MaxSize(1); t++) { + // we only do PropagateForward for real frames(i.e. not including the extra + // frame we added to b_fsas_. 
+ for (int32_t t = 0; t < chunk_size; t++) { if (state_map_.NumKeyBits() == 32) { frames_.push_back(PropagateForward<32>(t, frames_.back().get())); } else if (state_map_.NumKeyBits() == 36) { @@ -288,23 +290,12 @@ class MultiGraphDenseIntersectPruned { K2_CHECK_EQ(state_map_.NumKeyBits(), 40); frames_.push_back(PropagateForward<40>(t, frames_.back().get())); } - if (t == b_fsas_->shape.MaxSize(1)) { - PruneTimeRange(T_ - 1, T_ + t); + if (t == chunk_size - 1) { + int32_t start = std::max(0, T_ - 2); + PruneTimeRange(start, T_ + t + 1); } } - // The FrameInfo for time T+1 will have no states. We did that - // last PropagateForward so that the 'arcs' member of frames_[T] - // is set up (it has no arcs but we need the shape). - frames_.pop_back(); - - int32_t history_t = T_ - 1; - - T_ = T - 1; - // partial_final_frame_ is the last frame to generate partial result, - // but it should not be the start frame of next chunk decoding. - partial_final_frame_ = std::move(frames_.back()); - frames_.pop_back(); - + int32_t history_t = T_; const int32_t *b_fsas_row_splits1 = b_fsas_->shape.RowSplits(1).Data(); int32_t *final_t_data = final_t_.Data(); @@ -313,11 +304,106 @@ class MultiGraphDenseIntersectPruned { c_, num_seqs_, lambda_set_final_and_final_t, (int32_t i)->void { int32_t b_chunk_size = b_fsas_row_splits1[i + 1] - b_fsas_row_splits1[i]; - final_t_data[i] = history_t + b_chunk_size - 1; + final_t_data[i] = history_t + b_chunk_size; }); + + // T_ will be used in FormatOutput, plus 1 here because we need an extra + // frame for final arcs (i.e. the partial_final_frame returned by + // GetFinalFrame()) to construct the lattice. + T_ = T + 1; return &frames_; } + /* Propagate the last frame in b_fsas_(i.e. the extra frame containing only 0 + and -infs). See dense_fsa_vec.py to get more details of b_fsas_. + + The purpose of this function is to get the final states to construct + partial results for online decoding. 
It is supposed to be invoked in + FormatOutput when online_decoding_ is True. + + This function returns the final FrameInfo needed by the FormatOutput. The + final_frame->states contains the final state for each sequence (if it has one), + the final_frame->arcs actually contains no arc at all, but we need its + shape. + + This function also adds the arcs to frames_.back(), normally the arcs of + frames_.back() will be populated in next ForwardPass, we populate it here + so that we can get valid fsas in FormatOutput. It will not affect the + ForwardPass because the ForwardPass only needs the states in frames_.back(). + Actually we will re-expand the arcs in frames_.back() in the next + ForwardPass. + */ + std::unique_ptr GetFinalFrame() { + K2_CHECK(online_decoding_); + + // chunk_size is the index of the added extra frame. + int32_t chunk_size = b_fsas_->shape.MaxSize(1) - 1; + FrameInfo *cur_frame = frames_.back().get(); + + // These are all of the expanded arcs, actually we only need the arcs + // pointing to the final states. + auto arcs = GetArcs(chunk_size, cur_frame); + + int32_t num_fsas = NumFsas(); + + // Number of final states for each sequence, should be 0 or 1. + Array1 num_final_states(c_, num_fsas + 1, 0); + // Keep the arcs pointing to final states. 
+ Renumbering renumber_arcs(c_, arcs.NumElements()); + char *keep_this_arc_data = renumber_arcs.Keep().Data(); + const int32_t *arcs_row_ids1_data = arcs.RowIds(1).Data(), + *arcs_row_ids2_data = arcs.RowIds(2).Data(), + *fsa_row_split1_data = a_fsas_.RowSplits(1).Data(); + int32_t *num_final_states_data = num_final_states.Data(); + ArcInfo *arcs_data = arcs.values.Data(); + + K2_EVAL( + c_, arcs.NumElements(), lambda_renumber_arc, (int32_t idx012) -> void { + int32_t idx01 = arcs_row_ids2_data[idx012], + idx0 = arcs_row_ids1_data[idx01]; + ArcInfo ai = arcs_data[idx012]; + // Arcs pointing to final states have non infinity scores + if (ai.arc_loglike - ai.arc_loglike == 0) { + num_final_states_data[idx0] = 1; + keep_this_arc_data[idx012] = 1; + } else { + keep_this_arc_data[idx012] = 0; + } + }); + + int32_t num_arcs = renumber_arcs.NumNewElems(); + const int32_t *new2old_data = renumber_arcs.New2Old().Data(); + Array1 new_arcs(c_, num_arcs); + ArcInfo *new_arcs_data = new_arcs.Data(); + + K2_EVAL(c_, num_arcs, lambda_set_new_arcs, (int32_t new_idx012) -> void { + int32_t old_idx012 = new2old_data[new_idx012]; + ArcInfo old_ai = arcs_data[old_idx012]; + // Only 1 state (the final state) in next frame, so idx1 is always 0. + old_ai.u.dest_info_state_idx1 = 0; + new_arcs_data[new_idx012] = old_ai; + }); + + auto old2new_rowsplits = renumber_arcs.Old2New(true); + auto old2new_shape = RaggedShape2(&old2new_rowsplits, nullptr, num_arcs); + auto total_shape = ComposeRaggedShapes(arcs.shape, old2new_shape); + auto new_arcs_shape = RemoveAxis(total_shape, 2); + cur_frame->arcs = Ragged(new_arcs_shape, new_arcs); + + std::unique_ptr ans = std::make_unique(); + ExclusiveSum(num_final_states, &num_final_states); + auto final_state_shape = RaggedShape2( + &num_final_states, nullptr, -1); + // No arcs for final frame, but we need its shape in FormatOutput. 
+ auto state_to_arc_shape = RegularRaggedShape( + c_, final_state_shape.NumElements(), 0); + auto final_arc_shape = ComposeRaggedShapes( + final_state_shape, state_to_arc_shape); + ans->arcs = Ragged(final_arc_shape, Array1(c_, 0)); + return ans; + } + + void BackwardPass() { int32_t num_fsas = b_fsas_->shape.Dim0(), num_work_items = max_active_ * num_fsas * T_; @@ -386,19 +472,21 @@ class MultiGraphDenseIntersectPruned { } void FormatOutput(FsaVec *ofsa, Array1 *arc_map_a, - Array1 *arc_map_b, bool is_final) { + Array1 *arc_map_b) { NVTX_RANGE("FormatOutput"); bool online_decoding = online_decoding_; + bool allow_partial = allow_partial_; + std::unique_ptr partial_final_frame; if (online_decoding) { + partial_final_frame = std::move(GetFinalFrame()); K2_CHECK(arc_map_a); K2_CHECK_EQ(arc_map_b, nullptr); } else { - K2_CHECK(is_final); K2_CHECK(arc_map_a && arc_map_b); } - int32_t T = is_final ? T_ : T_ + 1; + int32_t T = T_; ContextPtr c_cpu = GetCpuContext(); Array1 arcs_data_ptrs(c_cpu, T + 1); Array1 arcs_row_splits1_ptrs(c_cpu, T + 1); @@ -406,12 +494,12 @@ class MultiGraphDenseIntersectPruned { arcs_data_ptrs.Data()[t] = frames_[t]->arcs.values.Data(); arcs_row_splits1_ptrs.Data()[t] = frames_[t]->arcs.RowSplits(1).Data(); } - arcs_data_ptrs.Data()[T] = is_final - ? frames_[T]->arcs.values.Data() - : partial_final_frame_->arcs.values.Data(); + arcs_data_ptrs.Data()[T] = online_decoding + ? partial_final_frame->arcs.values.Data() + : frames_[T]->arcs.values.Data(); arcs_row_splits1_ptrs.Data()[T] = - is_final ? frames_[T]->arcs.RowSplits(1).Data() - : partial_final_frame_->arcs.RowSplits(1).Data(); + online_decoding ? 
partial_final_frame->arcs.RowSplits(1).Data() + : frames_[T]->arcs.RowSplits(1).Data(); // transfer to GPU if we're using a GPU arcs_data_ptrs = arcs_data_ptrs.To(c_); @@ -438,11 +526,8 @@ class MultiGraphDenseIntersectPruned { Array1 num_extra_states(c_, num_fsas + 1); int32_t *num_extra_states_data = num_extra_states.Data(); K2_EVAL(c_, num_fsas, lambda_set_num_extra_states, (int32_t i) -> void { - int32_t final_t; - if (online_decoding) - final_t = is_final ? final_t_data[i] : final_t_data[i] + 1; - else - final_t = b_fsas_row_splits1[i+1] - b_fsas_row_splits1[i]; + int32_t final_t = online_decoding ? final_t_data[i] + : b_fsas_row_splits1[i+1] - b_fsas_row_splits1[i]; int32_t *arcs_row_splits1_data = arcs_row_splits1_ptrs_data[final_t]; int32_t num_states_final_t = arcs_row_splits1_data[i + 1] - @@ -477,8 +562,8 @@ class MultiGraphDenseIntersectPruned { for (int32_t t = 0; t < T; t++) arcs_shapes[t] = &(frames_[t]->arcs.shape); - arcs_shapes[T] = is_final ? &(frames_[T]->arcs.shape) - : &(partial_final_frame_->arcs.shape); + arcs_shapes[T] = online_decoding ? &(partial_final_frame->arcs.shape) + : &(frames_[T]->arcs.shape); arcs_shapes[T + 1] = &final_arcs_shape; @@ -537,7 +622,21 @@ class MultiGraphDenseIntersectPruned { int32_t dest_state_idx012 = oarc_idx01x_next + arc_info.u.dest_info_state_idx1; arc.dest_state = dest_state_idx012 - oarc_idx0xx; - arc.label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label; + int32_t arc_label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label; + arc.label = arc_label; + + int32_t final_t = online_decoding ? final_t_data[oarc_idx0] + :b_fsas_row_splits1[oarc_idx0+1] - b_fsas_row_splits1[oarc_idx0]; + if (t == final_t - 1 && arc_label != -1) { + if (allow_partial) { + arc.label = -1; + } else { + // Unreachable code. 
+ K2_LOG(FATAL) << + "arc.labe != -1 on final_arc when allow_partial==false."; + } + } + arc.score = arc_info.arc_loglike; arcs_out_data[oarc_idx0123] = arc; @@ -547,7 +646,11 @@ class MultiGraphDenseIntersectPruned { int32_t fsa_id = oarc_idx0, b_fsas_idx0x = b_fsas_row_splits1[fsa_id], b_fsas_idx01 = b_fsas_idx0x + t, - b_fsas_idx2 = (arc.label + 1), + // Use arc_label instead of arc.label to keep track of + // the original arc index in b_fsas when allow_partial == true. + // Then arc_map_b stores the "correct" arc index instead of + // the non-existent manually added arc pointing to super-final state. + b_fsas_idx2 = (arc_label + 1), b_fsas_arc_idx012 = b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2; arc_map_b_data[oarc_idx0123] = b_fsas_arc_idx012; } @@ -610,6 +713,7 @@ class MultiGraphDenseIntersectPruned { Array1 cutoffs(c_, num_fsas); float *cutoffs_data = cutoffs.Data(); + bool online_decoding = online_decoding_; K2_EVAL( c_, num_fsas, lambda_set_beam_and_cutoffs, (int32_t i)->void { float best_loglike = max_per_fsa_data[i], @@ -620,7 +724,7 @@ class MultiGraphDenseIntersectPruned { float current_min_active = min_active; // Do less pruning on the few final frames, to ensure we don't prune // away final states. - if (t + 5 >= final_t) { + if (!online_decoding && t + 5 >= final_t) { current_min_active = max(min_active, max_active / 2); } if (active_states <= max_active) { @@ -641,7 +745,7 @@ class MultiGraphDenseIntersectPruned { } else { // We modify dynamic_beam when max_active violated only if it's not // last few frames, in order to avoid final states pruning. - if (t + 5 < final_t) { + if (online_decoding || t + 5 < final_t) { // We violated the max_active constraint -> decrease beam if (dynamic_beam > default_beam) dynamic_beam = default_beam; @@ -652,7 +756,7 @@ class MultiGraphDenseIntersectPruned { } // no pruning on last frame; we want all final-arcs. // -1 because t starts from 0. 
- if (t == final_t - 1) dynamic_beam = 1.0e+10; + if (!online_decoding && t == final_t - 1) dynamic_beam = 1.0e+10; dynamic_beams_data[i] = dynamic_beam; cutoffs_data[i] = best_loglike - dynamic_beam; @@ -674,7 +778,6 @@ class MultiGraphDenseIntersectPruned { NVTX_RANGE(K2_FUNC); Ragged &states = cur_frame->states; const StateInfo *state_values = states.values.Data(); - float minus_inf = -std::numeric_limits::infinity(); // in a_fsas_ (the decoding graphs), maps from state_idx01 to arc_idx01x. const int32_t *fsa_arc_splits = a_fsas_.shape.RowSplits(2).Data(); @@ -707,6 +810,9 @@ class MultiGraphDenseIntersectPruned { const int32_t *ai_row_ids2 = ai_shape.RowIds(2).Data(); // from state_idx01 to arc_idx01x const int32_t *ai_row_splits2 = ai_shape.RowSplits(2).Data(); + + const int32_t *a_fsas_row_splits1 = a_fsas_.shape.RowSplits(1).Data(); + const int32_t *a_fsas_row_ids1 = a_fsas_.shape.RowIds(1).Data(); // from state_idx01 (into a_fsas_) to arc_idx01x (into a_fsas_) const int32_t *a_fsas_row_splits2 = a_fsas_.shape.RowSplits(2).Data(); @@ -722,6 +828,30 @@ class MultiGraphDenseIntersectPruned { Ragged ai(ai_shape); ArcInfo *ai_data = ai.values.Data(); // uninitialized + // A valid final arc means its label == -1. 
+ auto has_valid_final_arc = Array1(c_, NumFsas(), false); + bool *has_valid_final_arc_data = has_valid_final_arc.Data(); + bool allow_partial = allow_partial_; + + if (allow_partial_) { + K2_EVAL( + c_, ai.values.Dim(), set_has_non_inf_arc, (int32_t ai_arc_idx012)->void { + int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012], + ai_fsa_idx0 = ai_row_ids1[ai_state_idx01], + ai_arc_idx01x = ai_row_splits2[ai_state_idx01], + ai_arc_idx2 = ai_arc_idx012 - ai_arc_idx01x; + StateInfo sinfo = state_values[ai_state_idx01]; + int32_t a_fsas_arc_idx01x = + a_fsas_row_splits2[sinfo.a_fsas_state_idx01], + a_fsas_arc_idx012 = a_fsas_arc_idx01x + ai_arc_idx2; + Arc arc = arcs[a_fsas_arc_idx012]; + auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0]; + if (final_t - 1 == t && -1 == arc.label) { + has_valid_final_arc_data[ai_fsa_idx0] = true; + } + }); + } + K2_EVAL( c_, ai.values.Dim(), ai_lambda, (int32_t ai_arc_idx012)->void { int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012], @@ -738,13 +868,21 @@ class MultiGraphDenseIntersectPruned { scores_idx01 = scores_idx0x + t, // t == idx1 into 'scores' scores_idx2 = arc.label + 1; // the +1 is so that -1 can be handled - - // Assign negative infinity score to arc which label is out-of-range. - float acoustic_score; - if (scores_idx2 <= scores_num_cols) { - acoustic_score = scores_acc(scores_idx01, scores_idx2); - } else { - acoustic_score = minus_inf; + K2_DCHECK_LT(static_cast(scores_idx2), + static_cast(scores_num_cols)); + float acoustic_score = scores_acc(scores_idx01, scores_idx2); + auto dest_state = arc.dest_state; + auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0]; + + if (final_t - 1 == t && + (allow_partial && !has_valid_final_arc_data[ai_fsa_idx0])) { + int32_t a_fsas_idx0 = a_fsas_row_ids1[sinfo.a_fsas_state_idx01]; + // state_idx1 is 0-based. + // So "-1" is used when calculating a_fsas_final_state_idx1. 
+ int32_t a_fsas_final_state_idx1 + = a_fsas_row_splits1[a_fsas_idx0 + 1] - 1 - a_fsas_row_splits1[a_fsas_idx0]; + dest_state = a_fsas_final_state_idx1; + acoustic_score = 0.0; } ArcInfo ai; ai.a_fsas_arc_idx012 = a_fsas_arc_idx012; @@ -757,7 +895,7 @@ class MultiGraphDenseIntersectPruned { // convert to an idx01; this relies on the fact that // sinfo.abs_state_id == arc.src_state + a_fsas_fsa_idx0x. ai.u.dest_a_fsas_state_idx01 = - sinfo.a_fsas_state_idx01 + arc.dest_state - arc.src_state; + sinfo.a_fsas_state_idx01 + dest_state - arc.src_state; ai_data[ai_arc_idx012] = ai; }); return ai; @@ -949,7 +1087,6 @@ class MultiGraphDenseIntersectPruned { int32_t dest_a_fsas_state_idx01 = info.u.dest_a_fsas_state_idx01; - uint64_t state_map_idx = dest_a_fsas_state_idx01 + fsa_id * state_map_fsa_stride; uint64_t state_idx01; @@ -1504,7 +1641,7 @@ class MultiGraphDenseIntersectPruned { int32_t a_fsas_stride_; // 1 if we use a different FSA per sequence // (a_fsas_.Dim0() > 1), 0 if the decoding graph is // shared (a_fsas_.Dim0() == 1). - std::shared_ptr b_fsas_; // nnet_output to be decoded. + DenseFsaVec *b_fsas_; // nnet_output to be decoded. int32_t num_seqs_; // the number of sequences to decode at a time, // i.e. batch size for decoding. int32_t T_; // equals to b_fsas_->shape.MaxSize(1), for @@ -1514,20 +1651,19 @@ class MultiGraphDenseIntersectPruned { float output_beam_; int32_t min_active_; int32_t max_active_; + bool allow_partial_; Array1 dynamic_beams_; // dynamic beams (initially just search_beam_ // but change due to max_active/min_active // constraints). bool online_decoding_; // true for online decoding. Array1 final_t_; // record the final frame id of each DenseFsa. - std::unique_ptr partial_final_frame_; // store the final frame for // partial results int32_t state_map_fsa_stride_; // state_map_fsa_stride_ is a_fsas_.TotSize(1) // if a_fsas_.Dim0() == 1, else 0. 
- Hash state_map_; // state_map_ maps from: // key == (state_map_fsa_stride_*n) + a_fsas_state_idx01, // where n is the fsa_idx, i.e. the index into b_fsas_ @@ -1582,6 +1718,7 @@ class MultiGraphDenseIntersectPruned { void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam, float output_beam, int32_t min_active_states, int32_t max_active_states, + bool allow_partial, FsaVec *out, Array1 *arc_map_a, Array1 *arc_map_b) { NVTX_RANGE("IntersectDensePruned"); @@ -1591,22 +1728,22 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, search_beam, output_beam, min_active_states, max_active_states, + allow_partial, online_decoding); - - auto b_fsas_p = std::make_shared(b_fsas); - intersector.Intersect(b_fsas_p); - intersector.FormatOutput(out, arc_map_a, arc_map_b, true); + intersector.Intersect(&b_fsas); + intersector.FormatOutput(out, arc_map_a, arc_map_b); } OnlineDenseIntersecter::OnlineDenseIntersecter(FsaVec &a_fsas, int32_t num_seqs, float search_beam, float output_beam, - int32_t min_active_states, int32_t max_active_states) { + int32_t min_active_states, int32_t max_active_states, bool allow_partial) { bool online_decoding = true; K2_CHECK_EQ(a_fsas.NumAxes(), 3); c_ = a_fsas.Context(); search_beam_ = search_beam; impl_ = new MultiGraphDenseIntersectPruned(a_fsas, num_seqs, search_beam, - output_beam, min_active_states, max_active_states, online_decoding); + output_beam, min_active_states, max_active_states, allow_partial, + online_decoding); } OnlineDenseIntersecter::~OnlineDenseIntersecter(){ @@ -1625,10 +1762,9 @@ OnlineDenseIntersecter::~OnlineDenseIntersecter(){ } void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, - std::vector> *decode_states, + std::vector *decode_states, FsaVec *ofsa, Array1 *arc_map_a) { - auto b_fsas_p = std::make_shared(b_fsas); - int32_t num_seqs = b_fsas_p->shape.Dim0(); + int32_t num_seqs = b_fsas.shape.Dim0(); K2_CHECK_EQ(num_seqs, static_cast(decode_states->size())); @@ -1638,26 +1774,26 
@@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, Array1 beams(GetCpuContext(), num_seqs); float *beams_data = beams.Data(); for (int32_t i = 0; i < num_seqs; ++i) { - // initialization - if (!decode_states->at(i)) { - DecodeStateInfo info; + DecodeStateInfo *decode_state_ptr = decode_states->at(i); + K2_CHECK(decode_state_ptr); + // initialization; NumAxes == 1 means this is an uninitialized Ragged + if (decode_state_ptr->states.NumAxes() == 1) { StateInfo sinfo; // start state of decoding graph sinfo.a_fsas_state_idx01 = 0; sinfo.forward_loglike = FloatToOrderedInt(0.0); - info.states = Ragged( + decode_state_ptr->states = Ragged( RegularRaggedShape(c_, 1, 1), Array1(c_, std::vector{sinfo})); - info.arcs = Ragged(RaggedShape(c_, "[ [ [ x ] ] ]"), + decode_state_ptr->arcs = Ragged(RaggedShape(c_, "[ [ [ x ] ] ]"), Array1(c_, std::vector{ArcInfo()})); - info.beam = search_beam_; - decode_states->at(i) = std::make_shared(info); + decode_state_ptr->beam = search_beam_; } - seq_states_ptr_vec[i] = &(decode_states->at(i)->states); - seq_arcs_ptr_vec[i] = &(decode_states->at(i)->arcs); - beams_data[i] = decode_states->at(i)->beam; + seq_states_ptr_vec[i] = &(decode_state_ptr->states); + seq_arcs_ptr_vec[i] = &(decode_state_ptr->arcs); + beams_data[i] = decode_state_ptr->beam; } auto stack_states = Stack(0, num_seqs, seq_states_ptr_vec.data()); @@ -1685,8 +1821,10 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, frames[i] = std::make_unique(info); } - const auto new_frames = impl_->OnlineIntersect(b_fsas_p, frames, beams); - impl_->FormatOutput(ofsa, arc_map_a, nullptr/*arc_map_b*/, false); + const auto new_frames = impl_->OnlineIntersect( + &b_fsas, frames, beams); + + impl_->FormatOutput(ofsa, arc_map_a, nullptr/*arc_map_b*/); int32_t frames_num = new_frames->size(); std::vector *> frame_states_ptr_vec(frames_num); @@ -1707,11 +1845,10 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas, beams = impl_->GetBeams().To(GetCpuContext()); 
beams_data = beams.Data(); for (int32_t i = 0; i < num_seqs; ++i) { - DecodeStateInfo info; - info.states = seq_states_vec[i]; - info.arcs = seq_arcs_vec[i]; - info.beam = beams_data[i]; - decode_states->at(i) = std::make_shared(info); + DecodeStateInfo* decode_state_ptr = decode_states->at(i); + decode_state_ptr->states = seq_states_vec[i]; + decode_state_ptr->arcs = seq_arcs_vec[i]; + decode_state_ptr->beam = beams_data[i]; } } diff --git a/k2/csrc/intersect_dense_pruned.h b/k2/csrc/intersect_dense_pruned.h index 7f19405f1..2d0860724 100644 --- a/k2/csrc/intersect_dense_pruned.h +++ b/k2/csrc/intersect_dense_pruned.h @@ -117,7 +117,7 @@ class MultiGraphDenseIntersectPruned; // DecodeStateInfo contains the history decoding states for each sequence, this // is normally constructed from `frames_` in MultiGraphDenseIntersectPruned -// bu using `Stack` and `Unstack`. +// by using `Stack` and `Unstack`. struct DecodeStateInfo { // States that survived for the previously decoded frames. Indexed // [frame_idx][state_idx], state_idx just enumerates the active states @@ -170,7 +170,7 @@ class OnlineDenseIntersecter { public: OnlineDenseIntersecter(FsaVec &a_fsas, int32_t num_seqs, float search_beam, float output_beam, int32_t min_states, - int32_t max_states); + int32_t max_states, bool allow_partial = true); /* Does intersection/composition for current chunk of nnet_output(given by a DenseFsaVec), sequences in every chunk may come from different @@ -194,7 +194,7 @@ class OnlineDenseIntersecter { will have been assigned to this location. 
*/ void Decode(DenseFsaVec &b_fsas, - std::vector> *decode_states, + std::vector *decode_states, FsaVec *ofsa, Array1 *arc_map_a); ContextPtr &Context() { return c_;} diff --git a/k2/csrc/intersect_test.cu b/k2/csrc/intersect_test.cu index 391503eed..41ee7d24e 100644 --- a/k2/csrc/intersect_test.cu +++ b/k2/csrc/intersect_test.cu @@ -243,7 +243,7 @@ TEST(Intersect, RandomSingle) { K2_LOG(INFO) << "fsas_b = " << fsas_b; FsaVec out_fsas2; Array1 arc_map_a2, arc_map_b2; - // IntersectDensePruned() treats epsilons as normal symbols, so we need to + // IntersectDense() treats epsilons as normal symbols, so we need to // as well. ArcSort(&fsa); // CAUTION if you later test the arc_maps: we arc-sort here, @@ -339,7 +339,7 @@ TEST(Intersect, RandomFsaVec) { K2_LOG(INFO) << "fsas_b = " << fsas_b; FsaVec out_fsas2; Array1 arc_map_a2, arc_map_b2; - // IntersectDensePruned() treats epsilons as normal symbols, so we need to + // IntersectDense() treats epsilons as normal symbols, so we need to // as well. 
ArcSort(&fsavec); // CAUTION if you later test the arc_maps: we arc-sort @@ -485,11 +485,12 @@ TEST(IntersectPruned, Simple) { float beam = 100000; int32_t max_active = 10000, min_active = 0; + bool allow_partial = false; FsaVec out_fsas; Array1 arc_map_a, arc_map_b; IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active, - &out_fsas, &arc_map_a, &arc_map_b); + allow_partial, &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; @@ -542,11 +543,12 @@ TEST(IntersectPruned, TwoDense) { float beam = 100000; int32_t max_active = 10000, min_active = 0; + bool allow_partial = false; FsaVec out_fsas; Array1 arc_map_a, arc_map_b; IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active, - &out_fsas, &arc_map_a, &arc_map_b); + allow_partial, &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; @@ -591,11 +593,12 @@ TEST(IntersectPruned, TwoFsas) { float beam = 100000; int32_t max_active = 10000, min_active = 0; + bool allow_partial = false; FsaVec out_fsas; Array1 arc_map_a, arc_map_b; IntersectDensePruned(fsa_vec, dfsavec, beam, beam, min_active, max_active, - &out_fsas, &arc_map_a, &arc_map_b); + allow_partial, &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; @@ -659,8 +662,10 @@ TEST(IntersectPruned, RandomSingle) { FsaVec out_fsas; float beam = 1000.0; int32_t max_active = 10000, min_active = 0; + bool allow_partial = false; + IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active, - &out_fsas, &arc_map_a, &arc_map_b); + allow_partial, &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_b = " << arc_map_b; FsaVec fsas_b = ConvertDenseToFsaVec(dfsavec); @@ -763,8 +768,11 @@ TEST(IntersectPruned, RandomFsaVec) { 
FsaVec out_fsas; float search_beam = 1000.0, output_beam = 1000.0; int32_t min_active = 0, max_active = 10; + bool allow_partial = false; + IntersectDensePruned(fsavec, dfsavec, search_beam, output_beam, min_active, - max_active, &out_fsas, &arc_map_a, &arc_map_b); + max_active, allow_partial, + &out_fsas, &arc_map_a, &arc_map_b); K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a << ", arc_map_b = " << arc_map_b; diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index e8f4c01a9..f875b3212 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -202,7 +202,7 @@ static void PybindIntersectDensePruned(py::module &m) { "intersect_dense_pruned", [](FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam, float output_beam, int32_t min_active_states, - int32_t max_active_states) + int32_t max_active_states, bool allow_partial) -> std::tuple { DeviceGuard guard(a_fsas.Context()); Array1 arc_map_a; @@ -210,13 +210,15 @@ static void PybindIntersectDensePruned(py::module &m) { FsaVec out; IntersectDensePruned(a_fsas, b_fsas, search_beam, output_beam, - min_active_states, max_active_states, &out, + min_active_states, max_active_states, + allow_partial, &out, &arc_map_a, &arc_map_b); return std::make_tuple(out, ToTorch(arc_map_a), ToTorch(arc_map_b)); }, py::arg("a_fsas"), py::arg("b_fsas"), py::arg("search_beam"), py::arg("output_beam"), py::arg("min_active_states"), - py::arg("max_active_states")); + py::arg("max_active_states"), + py::arg("allow_partial") = false); } static void PybindIntersectDense(py::module &m) { @@ -751,8 +753,9 @@ static void PybindLevenshteinGraph(py::module &m) { static void PybindDecodeStateInfo(py::module &m) { using PyClass = DecodeStateInfo; - py::class_> state_info(m, - "DecodeStateInfo"); + py::class_ state_info( + m, "DecodeStateInfo"); + state_info.def(py::init<>()); } static void PybindOnlineDenseIntersecter(py::module &m) { @@ -763,26 +766,32 @@ static 
void PybindOnlineDenseIntersecter(py::module &m) { py::init([](FsaVec &decoding_graph, int32_t num_streams, float search_beam, float output_beam, int32_t min_active_states, - int32_t max_active_states) -> std::unique_ptr { + int32_t max_active_states, + bool allow_partial) -> std::unique_ptr { DeviceGuard guard(decoding_graph.Context()); return std::make_unique(decoding_graph, num_streams, search_beam, output_beam, - min_active_states, max_active_states); + min_active_states, max_active_states, + allow_partial); }), py::arg("decoding_graph"), py::arg("num_streams"), py::arg("search_beam"), py::arg("output_beam"), py::arg("min_active_states"), - py::arg("max_active_states")); + py::arg("max_active_states"), py::arg("allow_partial") = true); intersecter.def( "decode", [](PyClass &self, DenseFsaVec &dense_fsa_vec, - std::vector> &decode_states) + std::vector &decode_states) -> std::tuple>> { + std::vector> { DeviceGuard guard(self.Context()); FsaVec ofsa; Array1 arc_map; - self.Decode(dense_fsa_vec, &decode_states, &ofsa, &arc_map); + std::vector decode_states_ptr(decode_states.size()); + for (size_t i = 0; i < decode_states.size(); ++i) { + decode_states_ptr[i] = &decode_states[i]; + } + self.Decode(dense_fsa_vec, &decode_states_ptr, &ofsa, &arc_map); torch::Tensor arc_map_tensor = ToTorch(arc_map); return std::make_tuple(ofsa, arc_map_tensor, decode_states); }, diff --git a/k2/python/k2/autograd.py b/k2/python/k2/autograd.py index c82b2f322..5d62b472b 100644 --- a/k2/python/k2/autograd.py +++ b/k2/python/k2/autograd.py @@ -358,6 +358,7 @@ def forward(ctx, output_beam: float, min_active_states: int, max_active_states: int, + allow_partial: bool, unused_scores_a: torch.Tensor, unused_scores_b: torch.Tensor, seqframe_idx_name: Optional[str] = None, @@ -383,16 +384,21 @@ def forward(ctx, output_beam: Pruning beam for the output of intersection (vs. best path); equivalent to kaldi's lattice-beam. E.g. 8. 
- max_active_states: - Maximum number of FSA states that are allowed to be active on any - given frame for any given intersection/composition task. This is - advisory, in that it will try not to exceed that but may not always - succeed. You can use a very large number if no constraint is needed. min_active_states: Minimum number of FSA states that are allowed to be active on any given frame for any given intersection/composition task. This is advisory, in that it will try not to have fewer than this number active. Set it to zero if there is no constraint. + max_active_states: + Maximum number of FSA states that are allowed to be active on any + given frame for any given intersection/composition task. This is + advisory, in that it will try not to exceed that but may not always + succeed. You can use a very large number if no constraint is needed. + allow_partial If true and there was no final state active, + we will treat all the states on the + last frame to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. unused_scores_a: It equals to `a_fsas.scores` and its sole purpose is for back propagation. 
@@ -418,7 +424,8 @@ def forward(ctx, search_beam=search_beam, output_beam=output_beam, min_active_states=min_active_states, - max_active_states=max_active_states) + max_active_states=max_active_states, + allow_partial=allow_partial) out_fsa[0] = Fsa(ragged_arc) @@ -466,7 +473,7 @@ def forward(ctx, @staticmethod def backward(ctx, out_fsa_grad: torch.Tensor) \ - -> Tuple[None, None, None, None, None, None, None, torch.Tensor, torch.Tensor]: # noqa + -> Tuple[None, None, None, None, None, None, None, None, torch.Tensor, torch.Tensor, None, None]: # noqa a_scores, b_scores = ctx.saved_tensors arc_map_a = ctx.arc_map_a arc_map_b = ctx.arc_map_b @@ -493,6 +500,7 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \ None, # output_beam None, # min_active_states None, # max_active_states + None, # allow_partial grad_a, # unused_scores_a grad_b, # unused_scores_b None, # seqframe_idx_name @@ -663,7 +671,8 @@ def intersect_dense_pruned(a_fsas: Fsa, min_active_states: int, max_active_states: int, seqframe_idx_name: Optional[str] = None, - frame_idx_name: Optional[str] = None) -> Fsa: + frame_idx_name: Optional[str] = None, + allow_partial: bool = False) -> Fsa: '''Intersect array of FSAs on CPU/GPU. Caution: @@ -694,6 +703,11 @@ def intersect_dense_pruned(a_fsas: Fsa, frame for any given intersection/composition task. This is advisory, in that it will try not to exceed that but may not always succeed. You can use a very large number if no constraint is needed. + allow_partial If true and there was no final state active, + we will treat all the states on the + last frame to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. seqframe_idx_name: If set (e.g. 
to 'seqframe'), an attribute in the output will be created that encodes the sequence-index and the frame-index within that @@ -727,7 +741,8 @@ def intersect_dense_pruned(a_fsas: Fsa, # in `out_fsa[0].scores` _IntersectDensePrunedFunction.apply(a_fsas, b_fsas, out_fsa, search_beam, output_beam, min_active_states, - max_active_states, a_fsas.scores, + max_active_states, allow_partial, + a_fsas.scores, b_fsas.scores, seqframe_idx_name, frame_idx_name) return out_fsa[0] diff --git a/k2/python/k2/dense_fsa_vec.py b/k2/python/k2/dense_fsa_vec.py index 37c7e040c..580fc21ce 100644 --- a/k2/python/k2/dense_fsa_vec.py +++ b/k2/python/k2/dense_fsa_vec.py @@ -102,7 +102,7 @@ def __init__(self, segment_index, start_frame, duration = segment assert 0 <= segment_index < N assert 0 <= start_frame < T - assert duration > 0 + assert duration >= 0 assert start_frame + duration <= T + allow_truncate offset = segment_index * T end_frame = min(start_frame + duration, T) # exclusive diff --git a/k2/python/k2/online_dense_intersecter.py b/k2/python/k2/online_dense_intersecter.py index 4faebd5b0..efb751047 100644 --- a/k2/python/k2/online_dense_intersecter.py +++ b/k2/python/k2/online_dense_intersecter.py @@ -35,6 +35,7 @@ def __init__( output_beam: float, min_active_states: int, max_active_states: int, + allow_partial: bool = True, ) -> None: """Create a new online intersecter object. Args: @@ -91,6 +92,7 @@ def __init__( decode_states[1] = new_decode_states[1] ... 
""" + self.num_streams_ = num_streams self.decoding_graph = decoding_graph self.device = decoding_graph.device self.intersecter = _k2.OnlineDenseIntersecter( @@ -100,8 +102,13 @@ def __init__( output_beam, min_active_states, max_active_states, + allow_partial=allow_partial, ) + @property + def num_streams(self) -> int: + return self.num_streams_ + def decode( self, dense_fsas: DenseFsaVec, decode_states: List[DecodeStateInfo] ) -> Tuple[Fsa, List[DecodeStateInfo]]: diff --git a/k2/python/tests/online_dense_intersecter_test.py b/k2/python/tests/online_dense_intersecter_test.py index 2da700bb1..b657264c4 100644 --- a/k2/python/tests/online_dense_intersecter_test.py +++ b/k2/python/tests/online_dense_intersecter_test.py @@ -59,7 +59,7 @@ def test(self): num_chunks = 3 chunk_size = 5 - decode_states = [None] * num_streams + decode_states = [k2.DecodeStateInfo()] * num_streams for i in range(num_chunks): logits = torch.randn( diff --git a/k2/torch/bin/hlg_decode.py b/k2/torch/bin/hlg_decode.py new file mode 100644 index 000000000..d95a6014c --- /dev/null +++ b/k2/torch/bin/hlg_decode.py @@ -0,0 +1,305 @@ +import argparse +import logging +import math +import os +from typing import Any, Dict, List, Optional, Tuple + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from k2 import ( + get_lattice, + one_best_decoding, + get_aux_labels, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the jit script model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. 
+ Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + """, + ) + + parser.add_argument( + "--wav-scp", + type=str, + help="""The audio lists to transcribe in wav.scp format""", + ) + + parser.add_argument( + "--output-file", + type=str, + help=""" + The file to write out results to, only used when giving --wav-scp + """, + ) + + parser.add_argument( + "--batch-size", + type=int, + default=5, + help="The number of wavs in a batch.", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="*", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def decode_one_batch( + params: object, + batch: List[Tuple[str, str]], + model: torch.nn.Module, + feature_extractor: kaldifeat.Fbank, + decoding_graph: k2.Fsa, + token_sym_table: Optional[k2.SymbolTable] = None, + word_sym_table: Optional[k2.SymbolTable] = None, +) -> Dict[str, str]: + device = params.device + filenames = [x[1] for x in batch] + waves = read_sound_files( + filenames=filenames, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + features = feature_extractor(waves) + + feature_len = [] + for f in features: + feature_len.append(f.shape[0]) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + # Note: We don't use key padding mask for attention during decoding + nnet_output, _, _ = model(features) + + log_prob = torch.nn.functional.log_softmax(nnet_output, dim=-1) + log_prob_len = torch.tensor(feature_len) // params.subsampling_factor + log_prob_len = log_prob_len.to(device) + + lattice = get_lattice( + log_prob=log_prob, + log_prob_len=log_prob_len, + decoding_graph=decoding_graph, + subsampling_factor=params.subsampling_factor, + ) + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + + hyps = get_aux_labels(best_path) + + if params.method == "ctc-decoding": + hyps = ["".join([token_sym_table[i] for i in ids]) for ids in hyps] + else: + assert params.method == "1best", params.method + hyps = [" ".join([word_sym_table[i] for i in ids]) for ids in hyps] + + results = {} + for i, hyp in enumerate(hyps): + results[batch[i][0]] = hyp.replace("▁", " ").strip() + return results + + +def main(): + parser = get_parser() + args = parser.parse_args() + + args.sample_rate = 16000 + args.subsampling_factor = 4 + args.feature_dim = 80 + args.num_classes = 500 + + wave_list: List[Tuple[str, str]] = [] + if args.wav_scp is not None: + assert os.path.isfile( + args.wav_scp 
+ ), f"wav_scp not exists : {args.wav_scp}" + assert ( + args.output_file is not None + ), "You should provide output_file when using wav_scp" + with open(args.wav_scp, "r") as f: + for line in f: + toks = line.strip().split() + assert len(toks) == 2, toks + if not os.path.isfile(toks[1]): + logging.warning(f"File {toks[1]} not exists, skipping.") + continue + wave_list.append(toks) + else: + assert len(args.sound_files) > 0, "No wav_scp or waves provided." + for i, f in enumerate(args.sound_files): + if not os.path.isfile(f): + logging.warning(f"File {f} not exists, skipping.") + continue + wave_list.append((i, f)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + args.device = device + + logging.info(f"params : {args}") + + logging.info("Creating model") + model = torch.jit.load(args.nn_model) + model = model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = args.feature_dim + + fbank = kaldifeat.Fbank(opts) + + token_sym_table = None + word_sym_table = None + if args.method == "ctc-decoding": + logging.info("Use CTC decoding") + max_token_id = args.num_classes - 1 + decoding_graph = k2.ctc_topo( + max_token=max_token_id, + device=device, + ) + token_sym_table = k2.SymbolTable.from_file(args.tokens) + else: + assert args.method == "1best", args.method + logging.info(f"Loading HLG from {args.HLG}") + decoding_graph = k2.Fsa.from_dict( + torch.load(args.HLG, map_location="cpu") + ) + decoding_graph = decoding_graph.to(device) + word_sym_table = k2.SymbolTable.from_file(args.words_file) + decoding_graph = k2.Fsa.from_fsas([decoding_graph]) + + results = {} + start = 0 + while start < len(wave_list): # the final batch may be partial + + if start % 100 == 0: + logging.info(f"Decoding progress: 
{start}/{len(wave_list)}.") + + res = decode_one_batch( + params=args, + batch=wave_list[start: start + args.batch_size], + model=model, + feature_extractor=fbank, + decoding_graph=decoding_graph, + token_sym_table=token_sym_table, + word_sym_table=word_sym_table, + ) + start += args.batch_size + + results.update(res) + + logging.info(f"results : {results}") + + if args.wav_scp is not None: + output_dir = os.path.dirname(args.output_file) + if output_dir != "": + os.makedirs(output_dir, exist_ok=True) + with open(args.output_file, "w", encoding="utf-8") as f: + for x in wave_list: + f.write(x[0] + "\t" + results[x[0]] + "\n") + logging.info(f"Decoding results are written to {args.output_file}") + else: + s = "\n" + logging.info(f"results : {results}") + for x in wave_list: + s += f"{x[1]}:\n{results[x[0]]}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/k2/torch/bin/online_decode.cu b/k2/torch/bin/online_decode.cu index 362b1002a..f1c75fbd9 100644 --- a/k2/torch/bin/online_decode.cu +++ b/k2/torch/bin/online_decode.cu @@ -219,7 +219,7 @@ int main(int argc, char *argv[]) { FLAGS_min_activate_states, FLAGS_max_activate_states); // store decode states for each waves - std::vector> states_info(num_waves); + std::vector states_info(num_waves); // decocding results for each waves std::vector texts(num_waves, ""); @@ -227,12 +227,12 @@ int main(int argc, char *argv[]) { std::vector positions(num_waves, 0); int32_t T = nnet_output.size(1); - int32_t chunk_size = 10; // 20 frames per chunk + int32_t chunk_size = 10; // 10 frames per chunk // simulate asynchronous decoding while (true) { - std::vector> current_states_info( - FLAGS_num_streams); + k2::DecodeStateInfo dummy_state_info; + std::vector current_states_info; std::vector num_frame; std::vector 
current_nnet_output; // which waves we are decoding now @@ -242,10 +242,10 @@ int main(int argc, char *argv[]) { // this wave is done if (num_frames[i] == 0) continue; - current_states_info[current_wave_ids.size()] = states_info[i]; + current_states_info.push_back(&states_info[i]); current_wave_ids.push_back(i); - if (num_frames[i] < chunk_size * subsampling_factor) { + if (num_frames[i] <= chunk_size * subsampling_factor) { num_frame.push_back(num_frames[i]); num_frames[i] = 0; } else { @@ -280,6 +280,7 @@ int main(int argc, char *argv[]) { .device(nnet_output.device()); current_nnet_output.push_back( torch::zeros({chunk_size, nnet_output.size(2)}, opts)); + current_states_info.push_back(&dummy_state_info); } auto sub_nnet_output = torch::stack(current_nnet_output); @@ -303,11 +304,6 @@ int main(int argc, char *argv[]) { decoder.Decode(dense_fsa_vec, &current_states_info, &fsa, &graph_arc_map); - // update decoding states - for (size_t i = 0; i < current_wave_ids.size(); ++i) { - states_info[current_wave_ids[i]] = current_states_info[i]; - } - k2::FsaClass lattice(fsa); lattice.CopyAttrs(decoding_graph, k2::Array1ToTorch(graph_arc_map)); diff --git a/k2/torch/bin/online_decode.py b/k2/torch/bin/online_decode.py new file mode 100644 index 000000000..1d0bba3f6 --- /dev/null +++ b/k2/torch/bin/online_decode.py @@ -0,0 +1,383 @@ +import argparse +import logging +import math +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import k2 +import kaldifeat +import torch +import torchaudio +from k2 import ( + DecodeStateInfo, + OnlineDenseIntersecter, + one_best_decoding, + get_aux_labels, +) +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the jit script model.", + ) + + parser.add_argument( + "--words-file", + type=str, + 
help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + """, + ) + + parser.add_argument( + "--num-streams", + type=int, + default=2, + help="""The number of streams that can be run in parallel.""", + ) + + parser.add_argument( + "--wav-scp", + type=str, + help="""The audio lists to transcribe in wav.scp format""", + ) + + parser.add_argument( + "--output-file", + type=str, + help=""" + The file to write out results to, only used when giving --wav-scp + """, + ) + + parser.add_argument( + "--print-partial", + dest="print_partial", + action="store_true", + help="Whether print partial results.", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="*", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +@dataclass +class DecodeStream: + # The identifier of wavs. + utt_id: str + # The total number of frames for current nnet_output. + num_frames: int + # The output of encoder. + nnet_output: torch.Tensor + # Current position, index in to feature. + position: int + # Decode state for intersect_dense_pruned. + state_info: DecodeStateInfo + # Current decoding result. 
+ result: str + + +def decode_one_chunk( + params: object, + intersector: k2.OnlineDenseIntersecter, + streams: List[DecodeStream], + token_sym_table: Optional[k2.SymbolTable] = None, + word_sym_table: Optional[k2.SymbolTable] = None, +) -> List[int]: + assert params.num_streams == intersector.num_streams, ( + params.num_streams, + intersector.num_streams, + ) + current_state_infos = [] + current_nnet_outputs = [] + current_num_frames = [] + finised_streams = [] + for i, stream in enumerate(streams): + start = stream.position + if (stream.num_frames - stream.position) <= params.chunk_size: + current_num_frames.append(stream.num_frames - stream.position) + end = stream.num_frames + stream.position = stream.num_frames + finised_streams.append(i) + else: + current_num_frames.append(params.chunk_size) + end = stream.position + params.chunk_size + stream.position += params.chunk_size + current_state_infos.append(stream.state_info) + current_nnet_outputs.append(stream.nnet_output[start:end, :]) + + while len(current_num_frames) < params.num_streams: + current_num_frames.append(0) + current_nnet_outputs.append( + torch.zeros( + (params.chunk_size, params.num_classes), + device=params.device, + ) + ) + current_state_infos.append(DecodeStateInfo()) + + current_nnet_outputs = pad_sequence(current_nnet_outputs, batch_first=True) + supervision_segments = torch.tensor( + # seq_index, start_time, duration + [[i, 0, current_num_frames[i]] for i in range(params.num_streams)], + dtype=torch.int32, + ) + dense_fsa_vec = k2.DenseFsaVec(current_nnet_outputs, supervision_segments) + lattice, current_state_infos = intersector.decode( + dense_fsa_vec, current_state_infos + ) + + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + symbol_ids = get_aux_labels(best_path) + + if params.method == "ctc-decoding": + assert token_sym_table is not None + hyps = [ + "".join([token_sym_table[i] for i in ids]) for ids in symbol_ids + ] + else: + assert word_sym_table is not None 
+ assert params.method == "1best", params.method + hyps = [ + " ".join([word_sym_table[i] for i in ids]) for ids in symbol_ids + ] + for i, stream in enumerate(streams): + stream.state_info = current_state_infos[i] + stream.result = hyps[i].replace("▁", " ").strip() + return finised_streams + + + def decode_dataset( + params: object, + waves: List[Tuple[str, str]], + model: torch.nn.Module, + feature_extractor: kaldifeat.Fbank, + intersector: k2.OnlineDenseIntersecter, + token_sym_table: Optional[k2.SymbolTable] = None, + word_sym_table: Optional[k2.SymbolTable] = None, + ) -> Dict[str, str]: + results = {} + decode_streams = [] + wave_index = 0 + while True: + if wave_index < len(waves) and len(decode_streams) < params.num_streams: + data, sample_rate = torchaudio.load(waves[wave_index][1]) + assert sample_rate == params.sample_rate, ( + f"expected sample rate: {params.sample_rate}. " + f"Given: {sample_rate}" + ) + data = data[0].to(params.device) + feature = feature_extractor(data) + nnet_output, _, _ = model(feature.unsqueeze(0)) + decode_streams.append( + DecodeStream( + utt_id=waves[wave_index][0], + num_frames=nnet_output.shape[1], + nnet_output=nnet_output[0], + position=0, + state_info=DecodeStateInfo(), + result="", + ) + ) + wave_index += 1 + if wave_index % 100 == 0: + logging.info(f"Decoding progress: {wave_index}/{len(waves)}.") + continue + + if len(decode_streams) == 0: + break + + finised_streams = decode_one_chunk( + params=params, + intersector=intersector, + streams=decode_streams, + token_sym_table=token_sym_table, + word_sym_table=word_sym_table, + ) + + if params.print_partial: + s = "\n" + for stream in decode_streams: + s += f"{stream.utt_id}:\t{stream.result}\n\n" + logging.info(s) + + if finised_streams: + finised_streams = sorted(finised_streams, reverse=True) + for j in finised_streams: + results[decode_streams[j].utt_id] = decode_streams[j].result + del decode_streams[j] + return results + + + def main(): + parser = get_parser() + args = 
parser.parse_args() + + args.sample_rate = 16000 + args.subsampling_factor = 4 + args.feature_dim = 80 + args.num_classes = 500 + args.chunk_size = 16 + + wave_list: List[Tuple[str, str]] = [] + if args.wav_scp is not None: + assert os.path.isfile( + args.wav_scp + ), f"wav_scp not exists : {args.wav_scp}" + assert ( + args.output_file is not None + ), "You should provide output_file when using wav_scp" + with open(args.wav_scp, "r") as f: + for line in f: + toks = line.strip().split() + assert len(toks) == 2, toks + if not os.path.isfile(toks[1]): + logging.warning(f"File {toks[1]} not exists, skipping.") + continue + wave_list.append(toks) + else: + assert len(args.sound_files) > 0, "No wav_scp or waves provided." + for i, f in enumerate(args.sound_files): + if not os.path.isfile(f): + logging.warning(f"File {f} not exists, skipping.") + continue + wave_list.append((i, f)) + + # logging.info(f"wave_list : {wave_list}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + args.device = device + + logging.info(f"params : {args}") + + logging.info("Creating model") + model = torch.jit.load(args.nn_model) + model = model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = args.feature_dim + + fbank = kaldifeat.Fbank(opts) + + token_sym_table = None + word_sym_table = None + if args.method == "ctc-decoding": + logging.info("Use CTC decoding") + max_token_id = args.num_classes - 1 + decoding_graph = k2.ctc_topo( + max_token=max_token_id, + device=device, + ) + token_sym_table = k2.SymbolTable.from_file(args.tokens) + else: + assert args.method == "1best", args.method + logging.info(f"Loading HLG from {args.HLG}") + decoding_graph = k2.Fsa.from_dict( + torch.load(args.HLG, map_location="cpu") + ) + 
decoding_graph = decoding_graph.to(device) + word_sym_table = k2.SymbolTable.from_file(args.words_file) + decoding_graph = k2.Fsa.from_fsas([decoding_graph]) + + intersector = k2.OnlineDenseIntersecter( + decoding_graph=decoding_graph, + num_streams=args.num_streams, + search_beam=20, + output_beam=8, + min_active_states=30, + max_active_states=10000, + ) + + results = decode_dataset( + params=args, + waves=wave_list, + model=model, + feature_extractor=fbank, + intersector=intersector, + token_sym_table=token_sym_table, + word_sym_table=word_sym_table, + ) + + if args.wav_scp is not None: + output_dir = os.path.dirname(args.output_file) + if output_dir != "": + os.makedirs(output_dir, exist_ok=True) + with open(args.output_file, "w", encoding="utf-8") as f: + for x in wave_list: + f.write(x[0] + "\t" + results[x[0]] + "\n") + logging.info(f"Decoding results are written to {args.output_file}") + else: + s = "\n" + logging.info(f"results : {results}") + for x in wave_list: + s += f"{x[1]}:\n{results[x[0]]}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/k2/torch/csrc/CMakeLists.txt b/k2/torch/csrc/CMakeLists.txt index e4831bca6..39775a627 100644 --- a/k2/torch/csrc/CMakeLists.txt +++ b/k2/torch/csrc/CMakeLists.txt @@ -75,7 +75,7 @@ target_link_libraries(k2_torch_api PUBLIC k2_torch) if(K2_ENABLE_TESTS) add_executable(torch_api_test torch_api_test.cc) - target_link_libraries(torch_api_test PRIVATE k2_torch_api gtest gtest_main) + target_link_libraries(torch_api_test k2_torch_api gtest gtest_main) # NOTE: We set the working directory here so that # it works also on windows. 
The reason is that diff --git a/k2/torch/csrc/fsa_algo.cu b/k2/torch/csrc/fsa_algo.cu index f6583c291..adcb75946 100644 --- a/k2/torch/csrc/fsa_algo.cu +++ b/k2/torch/csrc/fsa_algo.cu @@ -47,13 +47,14 @@ FsaClass TrivialGraph(int32_t max_token, FsaClass IntersectDensePruned(FsaClass &graph, DenseFsaVec &dense, float search_beam, float output_beam, int32_t min_activate_states, - int32_t max_activate_states) { + int32_t max_activate_states, + bool allow_partial) { Array1 graph_arc_map; Array1 dense_arc_map; FsaVec fsa; IntersectDensePruned(graph.fsa, dense, search_beam, output_beam, - min_activate_states, max_activate_states, &fsa, - &graph_arc_map, &dense_arc_map); + min_activate_states, max_activate_states, allow_partial, + &fsa, &graph_arc_map, &dense_arc_map); FsaClass dest(fsa); dest.CopyAttrs(graph, Array1ToTorch(graph_arc_map)); return dest; diff --git a/k2/torch/csrc/fsa_algo.h b/k2/torch/csrc/fsa_algo.h index 04087e880..223f5051d 100644 --- a/k2/torch/csrc/fsa_algo.h +++ b/k2/torch/csrc/fsa_algo.h @@ -85,13 +85,19 @@ FsaClass TrivialGraph(int32_t max_token, torch::Device device = torch::kCPU); in that it will try not to have fewer than this number active. Set it to zero if there is no constraint. + @param [in] allow_partial If true and there was no final state active, + we will treat all the states on the last frame + to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. @return Returns an FsaClass containing the intersection of DenseFsaVec and decoding graphs with the attributes propagated. */ FsaClass IntersectDensePruned(FsaClass &graphs, DenseFsaVec &dense, float search_beam, float output_beam, int32_t min_activate_states, - int32_t max_activate_states); + int32_t max_activate_states, + bool allow_partial = false); /* Return the shortest paths as linear FSAs from the start state to the final state in the tropical semiring. 
diff --git a/scripts/github_actions/generate_build_matrix.py b/scripts/github_actions/generate_build_matrix.py index f576aab0b..fc04f6572 100755 --- a/scripts/github_actions/generate_build_matrix.py +++ b/scripts/github_actions/generate_build_matrix.py @@ -213,7 +213,7 @@ def generate_build_matrix( "torch": torch, "python-version": p, "cuda": c, - "image": f"pytorch/manylinux-builder:cuda{c}", + "image": "pytorch/manylinux-builder:cuda" + c, } ) else: diff --git a/scripts/github_actions/install_cuda.sh b/scripts/github_actions/install_cuda.sh index f7a669a45..f94e7d869 100755 --- a/scripts/github_actions/install_cuda.sh +++ b/scripts/github_actions/install_cuda.sh @@ -49,6 +49,9 @@ case "$cuda" in 11.7) url=https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run ;; + 11.8) + url=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run + ;; *) echo "Unknown cuda version: $cuda" exit 1 diff --git a/scripts/github_actions/install_cudnn.sh b/scripts/github_actions/install_cudnn.sh index d57018ce0..7bfe681e4 100755 --- a/scripts/github_actions/install_cudnn.sh +++ b/scripts/github_actions/install_cudnn.sh @@ -42,6 +42,9 @@ case $cuda in 11.7) filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz ;; + 11.8) + filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz + ;; *) echo "Unsupported cuda version: $cuda" exit 1 diff --git a/scripts/github_actions/install_torch.sh b/scripts/github_actions/install_torch.sh index 7ba74857a..84eef395f 100755 --- a/scripts/github_actions/install_torch.sh +++ b/scripts/github_actions/install_torch.sh @@ -14,8 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -torch=$TORCH_VERSION -cuda=$CUDA_VERSION +if [ -n "$TORCH_VERSION" ] && [ -n "$CUDA_VERSION" ]; then + torch=$TORCH_VERSION + cuda=$CUDA_VERSION +fi + case ${torch} in 1.5.*) case ${cuda} in