Use InferenceMode to replace NoGradMode.
csukuangfj committed Jan 25, 2024
1 parent bd0047c commit a552531
Showing 28 changed files with 123 additions and 73 deletions.
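This commit replaces every torch::NoGradGuard in sherpa with an InferenceMode alias, defined in the new header sherpa/cpp_api/macros.h: the alias resolves to torch::InferenceMode on PyTorch >= 1.9 and falls back to torch::NoGradGuard on older versions. torch::InferenceMode is the cheaper guard for pure inference: besides disabling gradient computation, it also skips version-counter bumps and autograd/view metadata, at the price that tensors created under it cannot later participate in autograd. A minimal standalone sketch (not part of this commit) contrasting the two guards:

// Sketch only: contrasts the two RAII guards swapped by this commit.
// torch::InferenceMode requires PyTorch >= 1.9.
#include <iostream>

#include "torch/all.h"

int main() {
  torch::Tensor x = torch::ones({2, 2});

  {
    torch::NoGradGuard no_grad;  // disables grad, keeps autograd metadata
    torch::Tensor y = x * 2;
    std::cout << y.requires_grad() << "\n";  // 0
  }

  {
    torch::InferenceMode guard;  // no autograd bookkeeping at all
    torch::Tensor z = x * 2;
    std::cout << z.is_inference() << "\n";  // 1: z is an inference tensor
  }

  return 0;
}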
25 changes: 5 additions & 20 deletions .github/workflows/run-cpp-test.yaml
@@ -53,7 +53,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
torch: ["1.13.1"]
torch: ["2.1.2"]
python-version: ["3.8"]
build_type: ["Release", "Debug"]

@@ -94,39 +94,24 @@ jobs:
python3 -m pip install wheel twine typing_extensions
python3 -m pip install torch==${{ matrix.torch }} numpy -f https://download.pytorch.org/whl/cpu/torch_stable.html
-python3 -m pip install k2==1.23.4.dev20230319+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/index.html
+python3 -m pip install k2==1.24.4.dev20231220+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/index.html
+python3 -m pip install kaldifeat==1.25.3.dev20231221+cpu.torch${{ matrix.torch }} -f https://csukuangfj.github.io/kaldifeat/cpu.html
python3 -m torch.utils.collect_env
-- name: Cache kaldifeat
-id: my-cache-2
-uses: actions/cache@v2
-with:
-path: |
-~/tmp/kaldifeat
-key: cache-tmp-${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.torch }}-kaldifeat-v1.22
-- name: Install kaldifeat
-if: steps.my-cache-2.outputs.cache-hit != 'true'
-shell: bash
-run: |
-.github/scripts/install-kaldifeat.sh
-python3 -m torch.utils.collect_env
- name: Build sherpa
shell: bash
env:
BUILD_TYPE: ${{ matrix.build_type }}
run: |
-export KALDIFEAT_INSTALL_PREFIX=$HOME/tmp/kaldifeat/build
-echo $KALDIFEAT_INSTALL_PREFIX
-ls -lh $KALDIFEAT_INSTALL_PREFIX
echo "Build type: $BUILD_TYPE"
mkdir build
cd build
cmake \
+-DCMAKE_CXX_STANDARD=17 \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DCMAKE_BUILD_TYPE=$BUILD_TYPE \
6 changes: 6 additions & 0 deletions docs/source/sherpa/install/from_source.rst
@@ -98,6 +98,12 @@ in a single environment.
cd sherpa
mkdir build
cd build
+# For torch >= 2.0, please use
+#
+# cmake -DCMAKE_CXX_STANDARD=17 ..
+#
cmake ..
make -j
2 changes: 1 addition & 1 deletion sherpa/cpp_api/bin/offline-recognizer-microphone.cc
@@ -133,7 +133,7 @@ for more details.
// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
torch::set_num_threads(1);
torch::set_num_interop_threads(1);
-torch::NoGradGuard no_grad;
+sherpa::InferenceMode no_grad;

torch::jit::getExecutorMode() = false;
torch::jit::getProfilingMode() = false;
2 changes: 1 addition & 1 deletion sherpa/cpp_api/bin/offline-recognizer.cc
@@ -80,7 +80,7 @@ int main(int argc, char *argv[]) {
// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
torch::set_num_threads(1);
torch::set_num_interop_threads(1);
-torch::NoGradGuard no_grad;
+sherpa::InferenceMode no_grad;

torch::jit::getExecutorMode() = false;
torch::jit::getProfilingMode() = false;
2 changes: 1 addition & 1 deletion sherpa/cpp_api/bin/online-recognizer-microphone.cc
@@ -123,7 +123,7 @@ int main(int argc, char *argv[]) {
// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
torch::set_num_threads(1);
torch::set_num_interop_threads(1);
-torch::NoGradGuard no_grad;
+sherpa::InferenceMode no_grad;

torch::jit::getExecutorMode() = false;
torch::jit::getProfilingMode() = false;
2 changes: 1 addition & 1 deletion sherpa/cpp_api/bin/online-recognizer.cc
@@ -97,7 +97,7 @@ int32_t main(int32_t argc, char *argv[]) {
// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
torch::set_num_threads(1);
torch::set_num_interop_threads(1);
-torch::NoGradGuard no_grad;
+sherpa::InferenceMode no_grad;

torch::jit::getExecutorMode() = false;
torch::jit::getProfilingMode() = false;
7 changes: 3 additions & 4 deletions sherpa/cpp_api/grpc/online-grpc-server.cc
@@ -2,14 +2,13 @@
// Copyright (c) 2022 Xiaomi Corporation
// 2023 y00281951

#include "asio.hpp"
#include "grpcpp/ext/proto_server_reflection_plugin.h"
#include "grpcpp/grpcpp.h"
#include "grpcpp/health_check_service_interface.h"

#include "asio.hpp"
#include "sherpa/cpp_api/grpc/online-grpc-server-impl.h"
#include "sherpa/csrc/log.h"
#include "torch/all.h"
#include "sherpa/cpp_api/grpc/online-grpc-server-impl.h"

using grpc::Server;
using grpc::ServerBuilder;
@@ -34,7 +33,7 @@ sherpa-online-grpc-server \
int32_t main(int32_t argc, char *argv[]) {
torch::set_num_threads(1);
torch::set_num_interop_threads(1);
-torch::NoGradGuard no_grad;
+sherpa::InferenceMode no_grad;

torch::jit::getExecutorMode() = false;
torch::jit::getProfilingMode() = false;
33 changes: 33 additions & 0 deletions sherpa/cpp_api/macros.h
@@ -0,0 +1,33 @@
/**
* Copyright (c) 2024 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef SHERPA_CPP_API_MACROS_H_
#define SHERPA_CPP_API_MACROS_H_

namespace sherpa {

#if SHERPA_TORCH_VERSION_MAJOR > 1 || \
(SHERPA_TORCH_VERSION_MAJOR == 1 && SHERPA_TORCH_VERSION_MINOR >= 9)
using InferenceMode = torch::InferenceMode;
#else
using InferenceMode = torch::NoGradGuard;
#endif

} // namespace sherpa

#endif // SHERPA_CPP_API_MACROS_H_
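This header assumes SHERPA_TORCH_VERSION_MAJOR and SHERPA_TORCH_VERSION_MINOR are already defined (presumably injected at configure time from the detected Torch version; they are not set here), and that torch headers are included first, since the alias names torch::InferenceMode and torch::NoGradGuard. A hypothetical usage sketch, with the function and parameter names invented for illustration:

// Hypothetical sketch (RunForward is not a sherpa API): decoding code picks
// up the best available guard through the sherpa::InferenceMode alias.
#include "torch/all.h"

#include "sherpa/cpp_api/macros.h"

torch::Tensor RunForward(torch::jit::Module &model, torch::Tensor features) {
  sherpa::InferenceMode no_grad;  // InferenceMode on torch >= 1.9, else NoGradGuard
  return model.run_method("forward", features).toTensor();
}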
2 changes: 1 addition & 1 deletion sherpa/cpp_api/offline-recognizer-ctc-impl.h
@@ -136,7 +136,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}

void DecodeStreams(OfflineStream **ss, int32_t n) override {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

std::vector<torch::Tensor> features_vec(n);
std::vector<int64_t> features_length_vec(n);
2 changes: 1 addition & 1 deletion sherpa/cpp_api/offline-recognizer-transducer-impl.h
@@ -104,7 +104,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}

void DecodeStreams(OfflineStream **ss, int32_t n) override {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

bool has_context_graph = false;
std::vector<torch::Tensor> features_vec(n);
1 change: 1 addition & 0 deletions sherpa/cpp_api/offline-recognizer.h
@@ -11,6 +11,7 @@

#include "sherpa/cpp_api/fast-beam-search-config.h"
#include "sherpa/cpp_api/feature-config.h"
#include "sherpa/cpp_api/macros.h"
#include "sherpa/cpp_api/offline-stream.h"

namespace sherpa {
4 changes: 2 additions & 2 deletions sherpa/cpp_api/online-recognizer.cc
@@ -372,7 +372,7 @@ class OnlineRecognizer::OnlineRecognizerImpl {
}

void DecodeStreams(OnlineStream **ss, int32_t n) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

SHERPA_CHECK_GT(n, 0);

@@ -551,7 +551,7 @@ bool OnlineRecognizer::IsEndpoint(OnlineStream *s) {
}

void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;
impl_->DecodeStreams(ss, n);
}

1 change: 1 addition & 0 deletions sherpa/cpp_api/online-recognizer.h
@@ -12,6 +12,7 @@
#include "sherpa/cpp_api/endpoint.h"
#include "sherpa/cpp_api/fast-beam-search-config.h"
#include "sherpa/cpp_api/feature-config.h"
#include "sherpa/cpp_api/macros.h"
#include "sherpa/cpp_api/online-stream.h"

namespace sherpa {
2 changes: 1 addition & 1 deletion sherpa/cpp_api/websocket/offline-websocket-server.cc
@@ -34,7 +34,7 @@ sherpa-offline-websocket-server \
int32_t main(int32_t argc, char *argv[]) {
torch::set_num_threads(1);
torch::set_num_interop_threads(1);
-torch::NoGradGuard no_grad;
+sherpa::InferenceMode no_grad;

torch::jit::getExecutorMode() = false;
torch::jit::getProfilingMode() = false;
2 changes: 1 addition & 1 deletion sherpa/cpp_api/websocket/online-websocket-server.cc
@@ -27,7 +27,7 @@ sherpa-online-websocket-server \
int32_t main(int32_t argc, char *argv[]) {
torch::set_num_threads(1);
torch::set_num_interop_threads(1);
-torch::NoGradGuard no_grad;
+sherpa::InferenceMode no_grad;

torch::jit::getExecutorMode() = false;
torch::jit::getProfilingMode() = false;
Expand Down
6 changes: 4 additions & 2 deletions sherpa/csrc/offline-conformer-ctc-model.cc
@@ -6,6 +6,8 @@
#include <string>
#include <vector>

#include "sherpa/cpp_api/macros.h"

namespace sherpa {

OfflineConformerCtcModel::OfflineConformerCtcModel(
@@ -17,7 +19,7 @@ OfflineConformerCtcModel::OfflineConformerCtcModel(

torch::IValue OfflineConformerCtcModel::Forward(torch::Tensor features,
torch::Tensor features_length) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

int32_t batch_size = features.size(0);

@@ -38,7 +40,7 @@ torch::Tensor OfflineConformerCtcModel::GetLogSoftmaxOut(

torch::Tensor OfflineConformerCtcModel::GetLogSoftmaxOutLength(
torch::IValue forward_out) const {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

auto mask = forward_out.toTuple()->elements()[2].toTensor();
return (~mask).sum(1);
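The (~mask).sum(1) above works because the third element of the forward output is a boolean padding mask that is true at padded frames, so summing the inverted mask along the time axis yields each utterance's number of valid frames. A small self-contained illustration of the idiom (example values made up):

// Illustrates (~mask).sum(1): mask is true at padded positions, so the
// row-wise sum of its negation counts the valid frames per utterance.
#include <iostream>

#include "torch/all.h"

int main() {
  auto mask = torch::tensor({{false, false, false, true, true},
                             {false, false, false, false, false}});
  auto lengths = (~mask).sum(1);
  std::cout << lengths << "\n";  // 3 and 5
  return 0;
}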
8 changes: 5 additions & 3 deletions sherpa/csrc/offline-conformer-transducer-model.cc
@@ -7,6 +7,8 @@
#include <string>
#include <utility>

#include "sherpa/cpp_api/macros.h"

namespace sherpa {

OfflineConformerTransducerModel::OfflineConformerTransducerModel(
@@ -36,7 +38,7 @@ OfflineConformerTransducerModel::OfflineConformerTransducerModel(
std::pair<torch::Tensor, torch::Tensor>
OfflineConformerTransducerModel::RunEncoder(
const torch::Tensor &features, const torch::Tensor &features_length) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

auto outputs =
encoder_.run_method("forward", features, features_length).toTuple();
@@ -52,7 +54,7 @@ OfflineConformerTransducerModel::RunEncoder(

torch::Tensor OfflineConformerTransducerModel::RunDecoder(
const torch::Tensor &decoder_input) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;
auto decoder_out =
decoder_.run_method("forward", decoder_input, /*need_pad*/ false);

@@ -61,7 +63,7 @@ torch::Tensor OfflineConformerTransducerModel::RunDecoder(

torch::Tensor OfflineConformerTransducerModel::RunJoiner(
const torch::Tensor &encoder_out, const torch::Tensor &decoder_out) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;
return joiner_
.run_method("forward", encoder_out, decoder_out,
/*project_input*/ false)
3 changes: 2 additions & 1 deletion sherpa/csrc/offline-ctc-one-best-decoder.cc
@@ -6,6 +6,7 @@

#include <utility>

#include "sherpa/cpp_api/macros.h"
#include "sherpa/csrc/log.h"

namespace sherpa {
@@ -33,7 +34,7 @@ std::vector<OfflineCtcDecoderResult> OfflineCtcOneBestDecoder::Decode(
SHERPA_CHECK_EQ(log_prob.size(2), vocab_size_);
}

-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

auto lattice = k2::GetLattice(log_prob, log_prob_len.cpu(), decoding_graph_,
config_.search_beam, config_.output_beam,
4 changes: 3 additions & 1 deletion sherpa/csrc/offline-nemo-enc-dec-ctc-model-bpe.cc
@@ -4,6 +4,8 @@

#include "sherpa/csrc/offline-nemo-enc-dec-ctc-model-bpe.h"

#include "sherpa/cpp_api/macros.h"

namespace sherpa {

OfflineNeMoEncDecCTCModelBPE::OfflineNeMoEncDecCTCModelBPE(
@@ -15,7 +17,7 @@ OfflineNeMoEncDecCTCModelBPE::OfflineNeMoEncDecCTCModelBPE(

torch::IValue OfflineNeMoEncDecCTCModelBPE::Forward(
torch::Tensor features, torch::Tensor features_length) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

// Change (N, T, C) to (N, C, T)
features = features.permute({0, 2, 1});
3 changes: 2 additions & 1 deletion sherpa/csrc/offline-transducer-greedy-search-decoder.cc
@@ -8,6 +8,7 @@
#include <utility>
#include <vector>

#include "sherpa/cpp_api/macros.h"
#include "torch/all.h"

namespace sherpa {
@@ -38,7 +39,7 @@ OfflineTransducerGreedySearchDecoder::Decode(torch::Tensor encoder_out,
torch::Tensor encoder_out_length,
OfflineStream **ss /*= nullptr*/,
int32_t n /*= 0*/) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

TORCH_CHECK(encoder_out.dim() == 3, "encoder_out.dim() is ",
encoder_out.dim(), "Expected value is 3");
9 changes: 6 additions & 3 deletions sherpa/csrc/offline-wav2vec2-ctc-model.cc
@@ -3,6 +3,9 @@
// Copyright (c) 2022 Xiaomi Corporation

#include "sherpa/csrc/offline-wav2vec2-ctc-model.h"

#include "sherpa/cpp_api/macros.h"

namespace sherpa {

OfflineWav2Vec2CtcModel::OfflineWav2Vec2CtcModel(
@@ -14,23 +17,23 @@ OfflineWav2Vec2CtcModel::OfflineWav2Vec2CtcModel(

torch::IValue OfflineWav2Vec2CtcModel::Forward(torch::Tensor waveforms,
torch::Tensor lengths) {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

return model_.run_method("forward", waveforms.to(device_),
lengths.to(device_));
}

torch::Tensor OfflineWav2Vec2CtcModel::GetLogSoftmaxOut(
torch::IValue forward_out) const {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

auto logit = forward_out.toTuple()->elements()[0].toTensor();
return logit.log_softmax(-1);
}

torch::Tensor OfflineWav2Vec2CtcModel::GetLogSoftmaxOutLength(
torch::IValue forward_out) const {
-torch::NoGradGuard no_grad;
+InferenceMode no_grad;

return forward_out.toTuple()->elements()[1].toTensor();
}