From bfd778edf878199baa24194bdfbc99efbb874185 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 30 Sep 2023 21:53:26 +0800 Subject: [PATCH] First working version for ctc decoding with H/HL/HLG (#1) --- .clang-format | 9 + .github/workflows/build-wheels-aarch64.yaml | 64 ++ .github/workflows/build-wheels-linux.yaml | 71 ++ .github/workflows/build-wheels-macos.yaml | 62 ++ .github/workflows/build-wheels-win32.yaml | 56 ++ .github/workflows/build-wheels-win64.yaml | 55 ++ .github/workflows/linux.yaml | 60 ++ .github/workflows/macos.yaml | 60 ++ .github/workflows/windows.yaml | 70 ++ .gitignore | 4 + CMakeLists.txt | 75 ++ cmake/__init__.py | 0 cmake/cmake_extension.py | 124 +++ cmake/eigen.cmake | 48 ++ cmake/googletest.cmake | 75 ++ cmake/kaldifst.cmake | 69 ++ cmake/pybind11.cmake | 43 ++ kaldi-decoder/CMakeLists.txt | 5 + kaldi-decoder/csrc/CMakeLists.txt | 47 ++ kaldi-decoder/csrc/decodable-ctc.cc | 31 + kaldi-decoder/csrc/decodable-ctc.h | 33 + kaldi-decoder/csrc/decodable-itf.h | 106 +++ kaldi-decoder/csrc/eigen-test.cc | 714 ++++++++++++++++++ kaldi-decoder/csrc/eigen.cc | 71 ++ kaldi-decoder/csrc/eigen.h | 39 + kaldi-decoder/csrc/faster-decoder.cc | 426 +++++++++++ kaldi-decoder/csrc/faster-decoder.h | 204 +++++ kaldi-decoder/csrc/hash-list-inl.h | 207 +++++ kaldi-decoder/csrc/hash-list-test.cc | 103 +++ kaldi-decoder/csrc/hash-list.h | 133 ++++ kaldi-decoder/csrc/log.h | 100 +++ kaldi-decoder/csrc/stl-utils.h | 209 +++++ kaldi-decoder/python/CMakeLists.txt | 1 + kaldi-decoder/python/csrc/CMakeLists.txt | 29 + kaldi-decoder/python/csrc/decodable-ctc.cc | 17 + kaldi-decoder/python/csrc/decodable-ctc.h | 16 + kaldi-decoder/python/csrc/decodable-itf.cc | 55 ++ kaldi-decoder/python/csrc/decodable-itf.h | 16 + kaldi-decoder/python/csrc/faster-decoder.cc | 55 ++ kaldi-decoder/python/csrc/faster-decoder.h | 16 + kaldi-decoder/python/csrc/kaldi-decoder.cc | 20 + kaldi-decoder/python/csrc/kaldi-decoder.h | 14 + .../python/kaldi_decoder/__init__.py | 6 + scripts/check_style_cpplint.sh | 112 +++ scripts/utils.sh | 19 + setup.py | 70 ++ 46 files changed, 3819 insertions(+) create mode 100644 .clang-format create mode 100644 .github/workflows/build-wheels-aarch64.yaml create mode 100644 .github/workflows/build-wheels-linux.yaml create mode 100644 .github/workflows/build-wheels-macos.yaml create mode 100644 .github/workflows/build-wheels-win32.yaml create mode 100644 .github/workflows/build-wheels-win64.yaml create mode 100644 .github/workflows/linux.yaml create mode 100644 .github/workflows/macos.yaml create mode 100644 .github/workflows/windows.yaml create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 cmake/__init__.py create mode 100644 cmake/cmake_extension.py create mode 100644 cmake/eigen.cmake create mode 100644 cmake/googletest.cmake create mode 100644 cmake/kaldifst.cmake create mode 100644 cmake/pybind11.cmake create mode 100644 kaldi-decoder/CMakeLists.txt create mode 100644 kaldi-decoder/csrc/CMakeLists.txt create mode 100644 kaldi-decoder/csrc/decodable-ctc.cc create mode 100644 kaldi-decoder/csrc/decodable-ctc.h create mode 100644 kaldi-decoder/csrc/decodable-itf.h create mode 100644 kaldi-decoder/csrc/eigen-test.cc create mode 100644 kaldi-decoder/csrc/eigen.cc create mode 100644 kaldi-decoder/csrc/eigen.h create mode 100644 kaldi-decoder/csrc/faster-decoder.cc create mode 100644 kaldi-decoder/csrc/faster-decoder.h create mode 100644 kaldi-decoder/csrc/hash-list-inl.h create mode 100644 kaldi-decoder/csrc/hash-list-test.cc create mode 100644 
kaldi-decoder/csrc/hash-list.h create mode 100644 kaldi-decoder/csrc/log.h create mode 100644 kaldi-decoder/csrc/stl-utils.h create mode 100644 kaldi-decoder/python/CMakeLists.txt create mode 100644 kaldi-decoder/python/csrc/CMakeLists.txt create mode 100644 kaldi-decoder/python/csrc/decodable-ctc.cc create mode 100644 kaldi-decoder/python/csrc/decodable-ctc.h create mode 100644 kaldi-decoder/python/csrc/decodable-itf.cc create mode 100644 kaldi-decoder/python/csrc/decodable-itf.h create mode 100644 kaldi-decoder/python/csrc/faster-decoder.cc create mode 100644 kaldi-decoder/python/csrc/faster-decoder.h create mode 100644 kaldi-decoder/python/csrc/kaldi-decoder.cc create mode 100644 kaldi-decoder/python/csrc/kaldi-decoder.h create mode 100644 kaldi-decoder/python/kaldi_decoder/__init__.py create mode 100755 scripts/check_style_cpplint.sh create mode 100644 scripts/utils.sh create mode 100644 setup.py diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..c65e772 --- /dev/null +++ b/.clang-format @@ -0,0 +1,9 @@ +--- +BasedOnStyle: Google +--- +Language: Cpp +Cpp11BracedListStyle: true +Standard: Cpp11 +DerivePointerAlignment: false +PointerAlignment: Right +--- diff --git a/.github/workflows/build-wheels-aarch64.yaml b/.github/workflows/build-wheels-aarch64.yaml new file mode 100644 index 0000000..997f583 --- /dev/null +++ b/.github/workflows/build-wheels-aarch64.yaml @@ -0,0 +1,64 @@ +name: build-wheels-aarch64 + +on: + push: + branches: + - wheel + tags: + - '*' + + workflow_dispatch: + +env: + KALDI_DECODER_IS_FOR_PYPI: 1 + +concurrency: + group: build-wheels-aarch64-${{ github.ref }} + cancel-in-progress: true + +jobs: + build_wheels_aarch64: + name: ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"] + + steps: + - uses: actions/checkout@v2 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v2 + with: + platforms: all + + # see https://cibuildwheel.readthedocs.io/en/stable/changelog/ + # for a list of versions + - name: Build wheels + uses: pypa/cibuildwheel@v2.15.0 + env: + CIBW_BUILD: "${{ matrix.python-version}}-* " + CIBW_SKIP: "cp27-* cp35-* cp36-* *-win32 pp* *-musllinux* *-manylinux_i686" + CIBW_BUILD_VERBOSITY: 3 + CIBW_ARCHS_LINUX: aarch64 + CIBW_ENVIRONMENT_LINUX: LD_LIBRARY_PATH='/project/build/lib.linux-aarch64-cpython-37/kaldi_decoder/lib:/project/build/lib.linux-aarch64-cpython-38/kaldi_decoder/lib:/project/build/lib.linux-aarch64-cpython-39/kaldi_decoder/lib:/project/build/lib.linux-aarch64-cpython-310/kaldi_decoder/lib:/project/build/lib.linux-aarch64-cpython-311/kaldi_decoder/lib:/project/build/lib.linux-aarch64-cpython-312/kaldi_decoder/lib' + + - name: Display wheels + shell: bash + run: | + ls -lh ./wheelhouse/ + + - uses: actions/upload-artifact@v2 + with: + path: ./wheelhouse/*.whl + + - name: Publish wheels to PyPI + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python3 -m pip install --upgrade pip + python3 -m pip install wheel twine setuptools + + twine upload ./wheelhouse/*.whl diff --git a/.github/workflows/build-wheels-linux.yaml b/.github/workflows/build-wheels-linux.yaml new file mode 100644 index 0000000..1833396 --- /dev/null +++ b/.github/workflows/build-wheels-linux.yaml @@ -0,0 +1,71 @@ +name: build-wheels-linux + +on: + push: + branches: + - wheel + tags: + - '*' + + workflow_dispatch: + +concurrency: + group: build-wheels-linux-${{ 
github.ref }} + cancel-in-progress: true + +jobs: + build_wheels_linux: + name: ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"] + + steps: + - uses: actions/checkout@v2 + + # see https://cibuildwheel.readthedocs.io/en/stable/changelog/ + # for a list of versions + - name: Build wheels + uses: pypa/cibuildwheel@v2.15.0 + env: + CIBW_BUILD: "${{ matrix.python-version}}-* " + CIBW_SKIP: "cp27-* cp35-* cp36-* pp* *-musllinux* *-win32 " + CIBW_BUILD_VERBOSITY: 3 + CIBW_ENVIRONMENT_LINUX: LD_LIBRARY_PATH='/project/build/lib.linux-x86_64-cpython-37/kaldi_decoder/lib:/project/build/lib.linux-i686-cpython-37/kaldi_decoder/lib:/project/build/lib.linux-x86_64-cpython-38/kaldi_decoder/lib:/project/build/lib.linux-i686-cpython-38/kaldi_decoder/lib:/project/build/lib.linux-x86_64-cpython-39/kaldi_decoder/lib:/project/build/lib.linux-i686-cpython-39/kaldi_decoder/lib:/project/build/lib.linux-x86_64-cpython-310/kaldi_decoder/lib:/project/build/lib.linux-i686-cpython-310/kaldi_decoder/lib:/project/build/lib.linux-x86_64-cpython-311/kaldi_decoder/lib:/project/build/lib.linux-i686-cpython-311/kaldi_decoder/lib:/project/build/lib.linux-x86_64-cpython-312/kaldi_decoder/lib:/project/build/lib.linux-i686-cpython-312/kaldi_decoder/lib' + + - name: Display wheels + shell: bash + run: | + ls -lh ./wheelhouse/ + + - uses: actions/upload-artifact@v2 + with: + path: ./wheelhouse/*.whl + + - name: Publish wheels to PyPI + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python3 -m pip install --upgrade pip + python3 -m pip install wheel twine setuptools + + twine upload ./wheelhouse/*.whl + + - name: Build sdist + if: matrix.python-version == 'cp38' + shell: bash + run: | + python3 setup.py sdist + ls -lh dist/* + + - name: Publish sdist to PyPI + if: matrix.python-version == 'cp38' + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + twine upload dist/kaldi-decoder-*.tar.gz diff --git a/.github/workflows/build-wheels-macos.yaml b/.github/workflows/build-wheels-macos.yaml new file mode 100644 index 0000000..59b79b8 --- /dev/null +++ b/.github/workflows/build-wheels-macos.yaml @@ -0,0 +1,62 @@ +name: build-wheels-macos + +on: + push: + branches: + - wheel + tags: + - '*' + + workflow_dispatch: + +env: + KALDI_DECODER_IS_FOR_PYPI: 1 + +concurrency: + group: build-wheels-macos-${{ github.ref }} + cancel-in-progress: true + +jobs: + build_wheels: + name: ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest] + python-version: ["cp38", "cp39", "cp310", "cp311", "cp312"] + + steps: + - uses: actions/checkout@v2 + + # see https://cibuildwheel.readthedocs.io/en/stable/changelog/ + # for a list of versions + - name: Build wheels + uses: pypa/cibuildwheel@v2.15.0 + env: + CIBW_BUILD: "${{ matrix.python-version}}-* " + CIBW_ENVIRONMENT: KALDI_DECODER_CMAKE_ARGS="-DCMAKE_OSX_ARCHITECTURES='arm64;x86_64'" + CIBW_ARCHS: "universal2" + CIBW_BUILD_VERBOSITY: 3 + + # Don't repair macOS wheels + CIBW_REPAIR_WHEEL_COMMAND_MACOS: "" + + - name: Display wheels + shell: bash + run: | + ls -lh ./wheelhouse/ + + - uses: actions/upload-artifact@v2 + with: + path: ./wheelhouse/*.whl + + - name: Publish wheels to PyPI + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | 
+ python3 -m pip install --upgrade pip + python3 -m pip install wheel twine setuptools + + twine upload ./wheelhouse/*.whl diff --git a/.github/workflows/build-wheels-win32.yaml b/.github/workflows/build-wheels-win32.yaml new file mode 100644 index 0000000..2c8e9c1 --- /dev/null +++ b/.github/workflows/build-wheels-win32.yaml @@ -0,0 +1,56 @@ +name: build-wheels-win32 + +on: + push: + branches: + - wheel + tags: + - '*' + + workflow_dispatch: + +concurrency: + group: build-wheels-win32-${{ github.ref }} + cancel-in-progress: true + +jobs: + build_wheels_win32: + name: ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [windows-latest] + python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"] + + steps: + - uses: actions/checkout@v2 + + # see https://cibuildwheel.readthedocs.io/en/stable/changelog/ + # for a list of versions + - name: Build wheels + uses: pypa/cibuildwheel@v2.15.0 + env: + CIBW_ENVIRONMENT: KALDI_DECODER_CMAKE_ARGS="-A Win32" + CIBW_BUILD: "${{ matrix.python-version}}-* " + CIBW_SKIP: "*-win_amd64" + CIBW_BUILD_VERBOSITY: 3 + + - name: Display wheels + shell: bash + run: | + ls -lh ./wheelhouse/ + + - uses: actions/upload-artifact@v2 + with: + path: ./wheelhouse/*.whl + + - name: Publish wheels to PyPI + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python3 -m pip install --upgrade pip + python3 -m pip install wheel twine setuptools + + twine upload ./wheelhouse/*.whl diff --git a/.github/workflows/build-wheels-win64.yaml b/.github/workflows/build-wheels-win64.yaml new file mode 100644 index 0000000..e376fbf --- /dev/null +++ b/.github/workflows/build-wheels-win64.yaml @@ -0,0 +1,55 @@ +name: build-wheels-win64 + +on: + push: + branches: + - wheel + tags: + - '*' + + workflow_dispatch: + +concurrency: + group: build-wheels-win64-${{ github.ref }} + cancel-in-progress: true + +jobs: + build_wheels_win64: + name: ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [windows-latest] + python-version: ["cp37", "cp38", "cp39", "cp310", "cp311", "cp312"] + + steps: + - uses: actions/checkout@v2 + + # see https://cibuildwheel.readthedocs.io/en/stable/changelog/ + # for a list of versions + - name: Build wheels + uses: pypa/cibuildwheel@v2.15.0 + env: + CIBW_BUILD: "${{ matrix.python-version}}-* " + CIBW_SKIP: "cp27-* cp35-* cp36-* pp* *-musllinux* *-win32 " + CIBW_BUILD_VERBOSITY: 3 + + - name: Display wheels + shell: bash + run: | + ls -lh ./wheelhouse/ + + - uses: actions/upload-artifact@v2 + with: + path: ./wheelhouse/*.whl + + - name: Publish wheels to PyPI + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python3 -m pip install --upgrade pip + python3 -m pip install wheel twine setuptools + + twine upload ./wheelhouse/*.whl diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml new file mode 100644 index 0000000..3347b34 --- /dev/null +++ b/.github/workflows/linux.yaml @@ -0,0 +1,60 @@ +name: linux + +on: + push: + branches: + - master + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: linux-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + linux: + name: ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.7", "3.8", "3.9", "3.10"] + + steps: + - uses: 
actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Configure CMake + shell: bash + run: | + mkdir build + cd build + cmake -D CMAKE_BUILD_TYPE=Release -DKALDI_DECODER_ENABLE_TESTS=ON .. + + - name: Build + run: | + cd build + make -j2 + + ls -lh lib + + - name: Test + run: | + cd build + ctest --output-on-failure diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml new file mode 100644 index 0000000..0246477 --- /dev/null +++ b/.github/workflows/macos.yaml @@ -0,0 +1,60 @@ +name: macos + +on: + push: + branches: + - master + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: macos-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + macos: + name: ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest] + python-version: ["3.7", "3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Configure CMake + shell: bash + run: | + mkdir build + cd build + cmake -D CMAKE_BUILD_TYPE=Release -DKALDI_DECODER_ENABLE_TESTS=ON .. + + - name: Build + run: | + cd build + make -j2 + + ls -lh lib + + - name: Test + run: | + cd build + ctest --output-on-failure diff --git a/.github/workflows/windows.yaml b/.github/workflows/windows.yaml new file mode 100644 index 0000000..59b5d27 --- /dev/null +++ b/.github/workflows/windows.yaml @@ -0,0 +1,70 @@ +name: windows + +on: + push: + branches: + - master + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: windows-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + windows: + if: false + name: ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [windows-latest] + python-version: ["3.7", "3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + # see https://github.com/microsoft/setup-msbuild + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v1.0.2 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Configure CMake + shell: bash + run: | + mkdir build + cd build + cmake -D CMAKE_BUILD_TYPE=Release -DKALDI_DECODER_ENABLE_TESTS=ON .. + + - name: Build + shell: bash + run: | + cd build + cmake --build . 
--target ALL_BUILD --config Release + ls -lh ./lib/Release/* + ls -lh ./bin/Release/* + + - name: Test + shell: bash + run: | + cd build + export PYTHONPATH=$PWD/lib/Release:$PYTHONPATH + export PYTHONPATH=$PWD/../kaldi-decoder/python:$PYTHONPATH + + ctest -C Release --output-on-failure diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ca0e221 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +build +__pycache__ +dist +*egg-info diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..6eb85ac --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,75 @@ +cmake_minimum_required(VERSION 3.13 FATAL_ERROR) +project(kaldi-decoder) + +# Disable warning about +# +# "The DOWNLOAD_EXTRACT_TIMESTAMP option was not given and policy CMP0135 is +# not set. +if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif() + +set(KALDI_DECODER_VERSION "0.2.1") + +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") + +set(CMAKE_SKIP_BUILD_RPATH FALSE) +set(BUILD_RPATH_USE_ORIGIN TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) + +if(NOT APPLE) + set(KALDI_DECODER_RPATH_ORIGIN "$ORIGIN") +else() + set(CMAKE_MACOSX_RPATH ON) + set(KALDI_DECODER_RPATH_ORIGIN "@loader_path") +endif() + +set(CMAKE_INSTALL_RPATH ${KALDI_DECODER_RPATH_ORIGIN}) +set(CMAKE_BUILD_RPATH ${KALDI_DECODER_RPATH_ORIGIN}) + +if(NOT CMAKE_BUILD_TYPE) + message(STATUS "No CMAKE_BUILD_TYPE given, default to Release") + set(CMAKE_BUILD_TYPE Release) +endif() +message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") + +set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") +set(CMAKE_CXX_EXTENSIONS OFF) + +list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) + +option(BUILD_SHARED_LIBS "Whether to build shared libraries" ON) +option(KALDI_DECODER_ENABLE_TESTS "Whether to build tests" ON) +option(KALDI_DECODER_BUILD_PYTHON "Whether to build Python" ON) + +if(BUILD_SHARED_LIBS AND MSVC) + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) +endif() + +if(NOT BUILD_SHARED_LIBS AND MSVC) + # see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html + # https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake + if(MSVC) + add_compile_options( + $<$:/MT> #---------| + $<$:/MTd> #---|-- Statically link the runtime libraries + $<$:/MT> #--| + ) + endif() +endif() + +if(KALDI_DECODER_BUILD_PYTHON) + include(pybind11) +endif() + +include(kaldifst) +include(eigen) + +if(KALDI_DECODER_ENABLE_TESTS) + enable_testing() + include(googletest) +endif() + +add_subdirectory(kaldi-decoder) diff --git a/cmake/__init__.py b/cmake/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py new file mode 100644 index 0000000..649369d --- /dev/null +++ b/cmake/cmake_extension.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021-2023 Xiaomi Corporation (author: Fangjun Kuang) +# flake8: noqa + +import os +import platform +import sys +from pathlib import Path + +import setuptools +from setuptools.command.build_ext import build_ext + + +def is_for_pypi(): + ans = os.environ.get("KALDI_DECODER_IS_FOR_PYPI", None) + return ans is not None + + +def is_macos(): + return platform.system() == "Darwin" + + +def is_windows(): + return platform.system() == "Windows" + + +try: + from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + + class bdist_wheel(_bdist_wheel): + def 
finalize_options(self): + _bdist_wheel.finalize_options(self) + # In this case, the generated wheel has a name in the form + # kaldi-decoder-xxx-pyxx-none-any.whl + if is_for_pypi() and not is_macos(): + self.root_is_pure = True + else: + # The generated wheel has a name ending with + # -linux_x86_64.whl + self.root_is_pure = False + +except ImportError: + bdist_wheel = None + + +def cmake_extension(name, *args, **kwargs) -> setuptools.Extension: + kwargs["language"] = "c++" + sources = [] + return setuptools.Extension(name, sources, *args, **kwargs) + + +class BuildExtension(build_ext): + def build_extension(self, ext: setuptools.extension.Extension): + # build/temp.linux-x86_64-3.8 + os.makedirs(self.build_temp, exist_ok=True) + + # build/lib.linux-x86_64-3.8 + os.makedirs(self.build_lib, exist_ok=True) + + install_dir = Path(self.build_lib).resolve() / "kaldi_decoder" + + kaldi_decoder_dir = Path(__file__).parent.parent.resolve() + + cmake_args = os.environ.get("KALDI_DECODER_CMAKE_ARGS", "") + make_args = os.environ.get("KALDI_DECODER_MAKE_ARGS", "") + system_make_args = os.environ.get("MAKEFLAGS", "") + + if cmake_args == "": + cmake_args = "-DCMAKE_BUILD_TYPE=Release" + + extra_cmake_args = f" -DCMAKE_INSTALL_PREFIX={install_dir} " + if is_windows(): + extra_cmake_args += f" -DBUILD_SHARED_LIBS=OFF " + else: + extra_cmake_args += f" -DBUILD_SHARED_LIBS=ON " + + extra_cmake_args += f" -DKALDI_DECODER_BUILD_PYTHON=ON " + extra_cmake_args += f" -DKALDI_DECODER_ENABLE_TESTS=OFF " + + if "PYTHON_EXECUTABLE" not in cmake_args: + print(f"Setting PYTHON_EXECUTABLE to {sys.executable}") + cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}" + + cmake_args += extra_cmake_args + + if is_windows(): + build_cmd = f""" + cmake {cmake_args} -B {self.build_temp} -S {kaldi_decoder_dir} + cmake --build {self.build_temp} --target install --config Release -- -m + """ + print(f"build command is:\n{build_cmd}") + ret = os.system( + f"cmake {cmake_args} -B {self.build_temp} -S {kaldi_decoder_dir}" + ) + if ret != 0: + raise Exception("Failed to configure kaldi-decoder") + + ret = os.system( + f"cmake --build {self.build_temp} --target install --config Release -- -m" # noqa + ) + if ret != 0: + raise Exception("Failed to build and install kaldi-decoder") + else: + if make_args == "" and system_make_args == "": + print("for fast compilation, run:") + print('export KALDI_DECODER_MAKE_ARGS="-j"; python setup.py install') + print('Setting make_args to "-j4"') + make_args = "-j4" + + build_cmd = f""" + cd {self.build_temp} + + cmake {cmake_args} {kaldi_decoder_dir} + + make {make_args} install/strip + """ + print(f"build command is:\n{build_cmd}") + + ret = os.system(build_cmd) + if ret != 0: + raise Exception( + "\nBuild kaldi-decoder failed. 
Please check the error message.\n" + "You can ask for help by creating an issue on GitHub.\n" + "\nClick:\n\thttps://github.com/k2-fsa/kaldi-decoder/issues/new\n" # noqa + ) diff --git a/cmake/eigen.cmake b/cmake/eigen.cmake new file mode 100644 index 0000000..e519b79 --- /dev/null +++ b/cmake/eigen.cmake @@ -0,0 +1,48 @@ +function(download_eigen) + include(FetchContent) + + set(eigen_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz") + set(eigen_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/eigen-3.4.0.tar.gz") + set(eigen_HASH "SHA256=8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72") + + # If you don't have access to the Internet, + # please pre-download eigen + set(possible_file_locations + $ENV{HOME}/Downloads/eigen-3.4.0.tar.gz + ${PROJECT_SOURCE_DIR}/eigen-3.4.0.tar.gz + ${PROJECT_BINARY_DIR}/eigen-3.4.0.tar.gz + /tmp/eigen-3.4.0.tar.gz + /star-fj/fangjun/download/github/eigen-3.4.0.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(eigen_URL "${f}") + file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL) + set(eigen_URL2) + break() + endif() + endforeach() + + set(BUILD_TESTING OFF CACHE BOOL "" FORCE) + set(EIGEN_BUILD_DOC OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(eigen + URL ${eigen_URL} + URL_HASH ${eigen_HASH} + ) + + FetchContent_GetProperties(eigen) + if(NOT eigen_POPULATED) + message(STATUS "Downloading eigen ${eigen_URL}") + FetchContent_Populate(eigen) + endif() + message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}") + message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}") + + + add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL) +endfunction() + +download_eigen() + diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake new file mode 100644 index 0000000..d5ff6c6 --- /dev/null +++ b/cmake/googletest.cmake @@ -0,0 +1,75 @@ +function(download_googltest) + include(FetchContent) + + set(googletest_URL "https://github.com/google/googletest/archive/refs/tags/v1.13.0.tar.gz") + set(googletest_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/googletest-1.13.0.tar.gz") + set(googletest_HASH "SHA256=ad7fdba11ea011c1d925b3289cf4af2c66a352e18d4c7264392fead75e919363") + + # If you don't have access to the Internet, + # please pre-download googletest + set(possible_file_locations + $ENV{HOME}/Downloads/googletest-1.13.0.tar.gz + ${PROJECT_SOURCE_DIR}/googletest-1.13.0.tar.gz + ${PROJECT_BINARY_DIR}/googletest-1.13.0.tar.gz + /tmp/googletest-1.13.0.tar.gz + /star-fj/fangjun/download/github/googletest-1.13.0.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(googletest_URL "${f}") + file(TO_CMAKE_PATH "${googletest_URL}" googletest_URL) + set(googletest_URL2) + break() + endif() + endforeach() + + set(BUILD_GMOCK ON CACHE BOOL "" FORCE) + set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) + set(gtest_disable_pthreads ON CACHE BOOL "" FORCE) + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + + FetchContent_Declare(googletest + URL + ${googletest_URL} + ${googletest_URL2} + URL_HASH ${googletest_HASH} + ) + + FetchContent_GetProperties(googletest) + if(NOT googletest_POPULATED) + message(STATUS "Downloading googletest from ${googletest_URL}") + FetchContent_Populate(googletest) + endif() + message(STATUS "googletest is downloaded to ${googletest_SOURCE_DIR}") + message(STATUS "googletest's binary dir is ${googletest_BINARY_DIR}") + + if(APPLE) + set(CMAKE_MACOSX_RPATH ON) # to solve the following 
warning on macOS + endif() + #[==[ + -- Generating done + Policy CMP0042 is not set: MACOSX_RPATH is enabled by default. Run "cmake + --help-policy CMP0042" for policy details. Use the cmake_policy command to + set the policy and suppress this warning. + + MACOSX_RPATH is not specified for the following targets: + + gmock + gmock_main + gtest + gtest_main + + This warning is for project developers. Use -Wno-dev to suppress it. + ]==] + + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) + + target_include_directories(gtest + INTERFACE + ${googletest_SOURCE_DIR}/googletest/include + ${googletest_SOURCE_DIR}/googlemock/include + ) +endfunction() + +download_googltest() diff --git a/cmake/kaldifst.cmake b/cmake/kaldifst.cmake new file mode 100644 index 0000000..81b35e8 --- /dev/null +++ b/cmake/kaldifst.cmake @@ -0,0 +1,69 @@ +function(download_kaldifst) + include(FetchContent) + + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz") + set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz") + set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2") + + # If you don't have access to the Internet, + # please pre-download kaldifst + set(possible_file_locations + $ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz + ${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz + ${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz + /tmp/kaldifst-1.7.6.tar.gz + /star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(kaldifst_URL "${f}") + file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL) + set(kaldifst_URL2) + break() + endif() + endforeach() + + set(KALDIFST_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(kaldifst + URL ${kaldifst_URL} + URL_HASH ${kaldifst_HASH} + ) + + FetchContent_GetProperties(kaldifst) + if(NOT kaldifst_POPULATED) + message(STATUS "Downloading kaldifst ${kaldifst_URL}") + FetchContent_Populate(kaldifst) + endif() + message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}") + message(STATUS "kaldifst's binary dir is ${kaldifst_BINARY_DIR}") + + list(APPEND CMAKE_MODULE_PATH ${kaldifst_SOURCE_DIR}/cmake) + + include_directories(${kaldifst_SOURCE_DIR}) + add_subdirectory(${kaldifst_SOURCE_DIR} ${kaldifst_BINARY_DIR}) + + target_include_directories(kaldifst_core + PUBLIC + ${kaldifst_SOURCE_DIR} + ) + + target_include_directories(fst + PUBLIC + ${openfst_SOURCE_DIR}/src/include + ) + + set_target_properties(kaldifst_core PROPERTIES OUTPUT_NAME "kaldi-decoder-kaldi-fst-core") + set_target_properties(fst PROPERTIES OUTPUT_NAME "kaldi-decoder-fst") + + if(KALDI_DECODER_BUILD_PYTHON AND WIN32) + install(TARGETS kaldifst_core fst DESTINATION ..) 
+ else() + install(TARGETS kaldifst_core fst DESTINATION lib) + endif() + +endfunction() + +download_kaldifst() diff --git a/cmake/pybind11.cmake b/cmake/pybind11.cmake new file mode 100644 index 0000000..ce08948 --- /dev/null +++ b/cmake/pybind11.cmake @@ -0,0 +1,43 @@ +function(download_pybind11) + include(FetchContent) + + set(pybind11_URL "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.2.tar.gz") + set(pybind11_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/pybind11-2.10.2.tar.gz") + set(pybind11_HASH "SHA256=93bd1e625e43e03028a3ea7389bba5d3f9f2596abc074b068e70f4ef9b1314ae") + + # If you don't have access to the Internet, + # please pre-download pybind11 + set(possible_file_locations + $ENV{HOME}/Downloads/pybind11-2.10.2.tar.gz + ${PROJECT_SOURCE_DIR}/pybind11-2.10.2.tar.gz + ${PROJECT_BINARY_DIR}/pybind11-2.10.2.tar.gz + /tmp/pybind11-2.10.2.tar.gz + /star-fj/fangjun/download/github/pybind11-2.10.2.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(pybind11_URL "${f}") + file(TO_CMAKE_PATH "${pybind11_URL}" pybind11_URL) + set(pybind11_URL2) + break() + endif() + endforeach() + + FetchContent_Declare(pybind11 + URL + ${pybind11_URL} + ${pybind11_URL2} + URL_HASH ${pybind11_HASH} + ) + + FetchContent_GetProperties(pybind11) + if(NOT pybind11_POPULATED) + message(STATUS "Downloading pybind11 from ${pybind11_URL}") + FetchContent_Populate(pybind11) + endif() + message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}") + add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL) +endfunction() + +download_pybind11() diff --git a/kaldi-decoder/CMakeLists.txt b/kaldi-decoder/CMakeLists.txt new file mode 100644 index 0000000..c464f6b --- /dev/null +++ b/kaldi-decoder/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(csrc) + +if(KALDI_DECODER_BUILD_PYTHON) + add_subdirectory(python) +endif() diff --git a/kaldi-decoder/csrc/CMakeLists.txt b/kaldi-decoder/csrc/CMakeLists.txt new file mode 100644 index 0000000..03cbbf5 --- /dev/null +++ b/kaldi-decoder/csrc/CMakeLists.txt @@ -0,0 +1,47 @@ +include_directories(${CMAKE_SOURCE_DIR}) + +# Please keep the source files alphabetically sorted +set(srcs + decodable-ctc.cc + eigen.cc + faster-decoder.cc +) + + +add_library(kaldi-decoder-core ${srcs}) + +target_link_libraries(kaldi-decoder-core PUBLIC kaldifst_core) +target_link_libraries(kaldi-decoder-core PUBLIC Eigen3::Eigen) + +if(KALDI_DECODER_ENABLE_TESTS) + set(test_srcs + eigen-test.cc + hash-list-test.cc + ) + + function(kaldi_decoder_add_test source) + get_filename_component(name ${source} NAME_WE) + set(target_name "${name}") + add_executable(${target_name} ${source}) + target_link_libraries(${target_name} + PRIVATE + gtest + gtest_main + kaldi-decoder-core + ) + add_test(NAME "Test.${target_name}" + COMMAND + $ + ) + endfunction() + + foreach(source IN LISTS test_srcs) + kaldi_decoder_add_test(${source}) + endforeach() +endif() + +if(KALDI_DECODER_BUILD_PYTHON AND WIN32) + install(TARGETS kaldi-decoder-core DESTINATION ..) 
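+    # Note: Windows has no RPATH mechanism, so for Python builds the DLL is
+    # installed one directory above the usual lib/ location (typically next to
+    # the generated Python extension) so that it can be found when the
+    # extension module is loaded.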
+else() + install(TARGETS kaldi-decoder-core DESTINATION lib) +endif() diff --git a/kaldi-decoder/csrc/decodable-ctc.cc b/kaldi-decoder/csrc/decodable-ctc.cc new file mode 100644 index 0000000..c7ba25a --- /dev/null +++ b/kaldi-decoder/csrc/decodable-ctc.cc @@ -0,0 +1,31 @@ +// kaldi-decoder/csrc/decodable-ctc.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "kaldi-decoder/csrc/decodable-ctc.h" + +#include + +namespace kaldi_decoder { + +DecodableCtc::DecodableCtc(const FloatMatrix &feats) : feature_matrix_(feats) {} + +float DecodableCtc::LogLikelihood(int32_t frame, int32_t index) { + // Note: We need to use index - 1 here since + // all the input labels of the H are incremented during graph + // construction + assert(index >= 1); + + return feature_matrix_(frame, index - 1); +} + +int32_t DecodableCtc::NumFramesReady() const { return feature_matrix_.rows(); } + +int32_t DecodableCtc::NumIndices() const { return feature_matrix_.cols(); } + +bool DecodableCtc::IsLastFrame(int32_t frame) const { + assert(frame < NumFramesReady()); + return (frame == NumFramesReady() - 1); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/csrc/decodable-ctc.h b/kaldi-decoder/csrc/decodable-ctc.h new file mode 100644 index 0000000..9cf044d --- /dev/null +++ b/kaldi-decoder/csrc/decodable-ctc.h @@ -0,0 +1,33 @@ +// kaldi-decoder/csrc/decodable-ctc.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef KALDI_DECODER_CSRC_DECODABLE_CTC_H_ +#define KALDI_DECODER_CSRC_DECODABLE_CTC_H_ + +#include "kaldi-decoder/csrc/decodable-itf.h" +#include "kaldi-decoder/csrc/eigen.h" + +namespace kaldi_decoder { + +class DecodableCtc : public DecodableInterface { + public: + explicit DecodableCtc(const FloatMatrix &feats); + + float LogLikelihood(int32_t frame, int32_t index) override; + + int32_t NumFramesReady() const override; + + // Indices are one-based! This is for compatibility with OpenFst. + int32_t NumIndices() const override; + + bool IsLastFrame(int32_t frame) const override; + + private: + // it saves log_softmax output + FloatMatrix feature_matrix_; +}; + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_DECODABLE_CTC_H_ diff --git a/kaldi-decoder/csrc/decodable-itf.h b/kaldi-decoder/csrc/decodable-itf.h new file mode 100644 index 0000000..9f58ec9 --- /dev/null +++ b/kaldi-decoder/csrc/decodable-itf.h @@ -0,0 +1,106 @@ +// kaldi-decoder/csrc/decodable-itf.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Mirko Hannemann; Go Vivace Inc.; +// 2013 Johns Hopkins University (author: Daniel Povey) +// Copyright (c) 2023 Xiaomi Corporation +// this file is copied and modified from +// kaldi/src/itf/decodable-itf.h +#ifndef KALDI_DECODER_CSRC_DECODABLE_ITF_H_ +#define KALDI_DECODER_CSRC_DECODABLE_ITF_H_ +#include "kaldi-decoder/csrc/log.h" + +namespace kaldi_decoder { + +/** + DecodableInterface provides a link between the (acoustic-modeling and + feature-processing) code and the decoder. The idea is to make this + interface as small as possible, and to make it as agnostic as possible about + the form of the acoustic model (e.g. don't assume the probabilities are a + function of just a vector of floats), and about the decoder (e.g. don't + assume it accesses frames in strict left-to-right order). For normal + models, without on-line operation, the "decodable" sub-class will just be a + wrapper around a matrix of features and an acoustic model, and it will + answer the question 'what is the acoustic likelihood for this index and this + frame?'. 
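+
+   As an illustrative sketch (added for clarity; not part of the original
+   Kaldi documentation), any DecodableInterface can be driven with nothing
+   more than LogLikelihood(), NumFramesReady() and NumIndices(), e.g. to pick
+   the best label per frame, assuming the object has already been fed the
+   whole utterance (the decoding-from-matrix case):
+   \code{.cc}
+   float GreedyScore(DecodableInterface *d) {  // hypothetical helper
+     float total = 0;
+     for (int32_t t = 0; t < d->NumFramesReady(); ++t) {
+       float best = d->LogLikelihood(t, 1);    // indices are one-based
+       for (int32_t i = 2; i <= d->NumIndices(); ++i) {
+         float ll = d->LogLikelihood(t, i);
+         if (ll > best) best = ll;
+       }
+       total += best;
+     }
+     return total;
+   }
+   \endcode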
+ + For online decoding, where the features are coming in in real time, it is + important to understand the IsLastFrame() and NumFramesReady() functions. + There are two ways these are used: the old online-decoding code, in + ../online/, and the new online-decoding code, in ../online2/. In the old + online-decoding code, the decoder would do: \code{.cc} for (int frame = 0; + !decodable.IsLastFrame(frame); frame++) { + // Process this frame + } + \endcode + and the call to IsLastFrame would block if the features had not arrived yet. + The decodable object would have to know when to terminate the decoding. This + online-decoding mode is still supported, it is what happens when you call, + for example, LatticeFasterDecoder::Decode(). + + We realized that this "blocking" mode of decoding is not very convenient + because it forces the program to be multi-threaded and makes it complex to + control endpointing. In the "new" decoding code, you don't call (for + example) LatticeFasterDecoder::Decode(), you call + LatticeFasterDecoder::InitDecoding(), and then each time you get more + features, you provide them to the decodable object, and you call + LatticeFasterDecoder::AdvanceDecoding(), which does something like this: + \code{.cc} + while (num_frames_decoded_ < decodable.NumFramesReady()) { + // Decode one more frame [increments num_frames_decoded_] + } + \endcode + So the decodable object never has IsLastFrame() called. For decoding where + you are starting with a matrix of features, the NumFramesReady() function + will always just return the number of frames in the file, and IsLastFrame() + will return true for the last frame. + + For truly online decoding, the "old" online decodable objects in ../online/ + have a "blocking" IsLastFrame() and will crash if you call NumFramesReady(). + The "new" online decodable objects in ../online2/ return the number of frames + currently accessible if you call NumFramesReady(). You will likely not need + to call IsLastFrame(), but we implement it to only return true for the last + frame of the file once we've decided to terminate decoding. +*/ +class DecodableInterface { + public: + virtual ~DecodableInterface() = default; + + /// Returns the log likelihood, which will be negated in the decoder. + /// The "frame" starts from zero. You should verify that NumFramesReady() > + /// frame before calling this. + virtual float LogLikelihood(int32_t frame, int32_t index) = 0; + + /// Returns true if this is the last frame. Frames are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). Caution: the behavior of this function in an online setting + /// is being changed somewhat. In future it may return false in cases where + /// we haven't yet decided to terminate decoding, but later true if we decide + /// to terminate decoding. The plan in future is to rely more on + /// NumFramesReady(), and in future, IsLastFrame() would always return false + /// in an online-decoding setting, and would only return true in a + /// decoding-from-matrix setting where we want to allow the last delta or LDA + /// features to be flushed out for compatibility with the baseline setup. + virtual bool IsLastFrame(int32_t frame) const = 0; + + /// The call NumFramesReady() will return the number of frames currently + /// available for this decodable object. 
This is for use in setups where you + /// don't want the decoder to block while waiting for input. This is newly + /// added as of Jan 2014, and I hope, going forward, to rely on this mechanism + /// more than IsLastFrame to know when to stop decoding. + virtual int32_t NumFramesReady() const { + KALDI_DECODER_ERR + << "NumFramesReady() not implemented for this decodable type."; + return -1; + } + + /// Returns the number of states in the acoustic model + /// (they will be indexed one-based, i.e. from 1 to NumIndices(); + /// this is for compatibility with OpenFst). + virtual int32_t NumIndices() const = 0; +}; + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_DECODABLE_ITF_H_ diff --git a/kaldi-decoder/csrc/eigen-test.cc b/kaldi-decoder/csrc/eigen-test.cc new file mode 100644 index 0000000..c80e898 --- /dev/null +++ b/kaldi-decoder/csrc/eigen-test.cc @@ -0,0 +1,714 @@ +// kaldi-decoder/csrc/eigen-test.cc + +// Copyright (c) 2023 Xiaomi Corporation + +#include "kaldi-decoder/csrc/eigen.h" + +#include +#include + +#include "gtest/gtest.h" + +// See +// +// Quick reference guide +// https://eigen.tuxfamily.org/dox/group__QuickRefPage.html +// +// Preprocessor directives +// https://eigen.tuxfamily.org/dox/TopicPreprocessorDirectives.html +// +// Understanding Eigen +// https://eigen.tuxfamily.org/dox/UserManual_UnderstandingEigen.html +// +// Using Eigen in CMake Projects +// https://eigen.tuxfamily.org/dox/TopicCMakeGuide.html + +namespace kaldi_decoder { + +TEST(Eigen, Hello) { + Eigen::MatrixXd m(2, 2); // uninitialized; contains garbage data + EXPECT_EQ(m.size(), 2 * 2); + EXPECT_EQ(m.rows(), 2); + EXPECT_EQ(m.cols(), 2); + + m(0, 0) = 3; + m(1, 0) = 2.5; + m(0, 1) = -1; + m(1, 1) = m(1, 0) + m(0, 1); + + auto m2 = m; // value semantics; create a copy + m2(0, 0) = 10; + EXPECT_EQ(m(0, 0), 3); + + Eigen::MatrixXd m3 = std::move(m2); + // now m2 is empty + EXPECT_EQ(m2.size(), 0); + EXPECT_EQ(m3(0, 0), 10); + + double *d = &m(0, 0); + d[0] = 11; + d[1] = 20; + d[2] = 30; + d[3] = 40; + EXPECT_EQ(m(0, 0), 11); // column major by default + EXPECT_EQ(m(1, 0), 20); + EXPECT_EQ(m(0, 1), 30); // it is contiguous in memory + EXPECT_EQ(m(1, 1), 40); + + // column major + EXPECT_EQ(m(0), 11); + EXPECT_EQ(m(1), 20); + EXPECT_EQ(m(2), 30); + EXPECT_EQ(m(3), 40); + + Eigen::MatrixXf a; + EXPECT_EQ(a.size(), 0); + + Eigen::Matrix3f b; // uninitialized + EXPECT_EQ(b.size(), 3 * 3); + + Eigen::MatrixXf c(2, 5); // uninitialized + EXPECT_EQ(c.size(), 2 * 5); + + EXPECT_EQ(c.rows(), 2); + EXPECT_EQ(c.cols(), 5); + + { + Eigen::Matrix f{ + {1, 2}, + {3, 4}, + }; + + // row major + EXPECT_EQ(f(0), 1); + EXPECT_EQ(f(1), 2); + EXPECT_EQ(f(2), 3); + EXPECT_EQ(f(3), 4); + + // Note: f[0] causes compilation errors + } +} + +TEST(Eigen, Identity) { + auto m = Eigen::Matrix3f::Identity(); // 3x3 identity matrix + EXPECT_EQ(m.sum(), 3); + + auto n = Eigen::MatrixXf::Identity(2, 3); // 2x3 identity matrix +#if 0 + 1 0 0 + 0 1 0 +#endif +} + +// https://eigen.tuxfamily.org/dox/classEigen_1_1DenseBase.html#ae814abb451b48ed872819192dc188c19 +TEST(Eigen, Random) { + // Random: Uniform distribution in the range [-1, 1] + auto m = Eigen::MatrixXd::Random(2, 3); +#if 0 + -0.999984 0.511211 0.0655345 + -0.736924 -0.0826997 -0.562082 +#endif + + // Note: We don't need to specify the shape for Random() in this case + auto m2 = Eigen::Matrix3d::Random(); +#if 0 + -0.999984 -0.0826997 -0.905911 + -0.736924 0.0655345 0.357729 + 0.511211 -0.562082 0.358593 +#endif +} + +TEST(Eigen, Vector) { + Eigen::VectorXd v(3); + 
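+  // Note: VectorXd is shorthand for Eigen::Matrix<double, Eigen::Dynamic, 1>,
+  // i.e. a dynamically sized column vector.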
+ // comma initializer + v << 1, 2, 3; + EXPECT_EQ(v(0), 1); + EXPECT_EQ(v(1), 2); + EXPECT_EQ(v(2), 3); + + // vector also support operator[] + EXPECT_EQ(v[0], 1); + EXPECT_EQ(v[1], 2); + EXPECT_EQ(v[2], 3); + + double *p = &v[0]; + p[0] = 10; + p[1] = 20; + p[2] = 30; + + EXPECT_EQ(v[0], 10); + EXPECT_EQ(v[1], 20); + EXPECT_EQ(v[2], 30); + + // fixed size + Eigen::Vector3d a(10, 20, 30); + EXPECT_EQ(a[0], 10); + EXPECT_EQ(a[1], 20); + EXPECT_EQ(a[2], 30); +} + +TEST(Eigen, CommaInitializer) { + // comma initializer does not depend on the storage major + { + Eigen::Matrix m(2, + 2); + m << 1, 2, 3, 4; + EXPECT_EQ(m(0, 0), 1); + EXPECT_EQ(m(0, 1), 2); + EXPECT_EQ(m(1, 0), 3); + EXPECT_EQ(m(1, 1), 4); + } + + { + Eigen::Matrix m(2, + 2); + m << 1, 2, 3, 4; + EXPECT_EQ(m(0, 0), 1); + EXPECT_EQ(m(0, 1), 2); + EXPECT_EQ(m(1, 0), 3); + EXPECT_EQ(m(1, 1), 4); + } +} + +TEST(Eigen, Resize) { + // a resize operation is a destructive operation if it changes the size. + // The original content is not copied to the resized area + Eigen::MatrixXf a(2, 3); + EXPECT_EQ(a.rows(), 2); + EXPECT_EQ(a.cols(), 3); + EXPECT_EQ(a.size(), a.rows() * a.cols()); + + a.resize(5, 6); + EXPECT_EQ(a.rows(), 5); + EXPECT_EQ(a.cols(), 6); + EXPECT_EQ(a.size(), a.rows() * a.cols()); + + Eigen::MatrixXf b; + EXPECT_EQ(b.size(), 0); + + b = a; // copy by value + EXPECT_EQ(b.rows(), 5); + EXPECT_EQ(b.cols(), 6); +} + +TEST(Eigen, MatMul) { + Eigen::MatrixXf a(2, 2); + a << 1, 2, 3, 4; + + Eigen::MatrixXf b(2, 2); + b << 3, 0, 0, 2; + + Eigen::MatrixXf c = a * b; // matrix multiplication + EXPECT_EQ(c(0, 0), a(0, 0) * b(0, 0)); + EXPECT_EQ(c(0, 1), a(0, 1) * b(1, 1)); + + EXPECT_EQ(c(1, 0), a(1, 0) * b(0, 0)); + EXPECT_EQ(c(1, 1), a(1, 1) * b(1, 1)); + + Eigen::MatrixXf d; + d.noalias() = a * b; // explicitly specify that there is no alias + + EXPECT_EQ(d(0, 0), a(0, 0) * b(0, 0)); + EXPECT_EQ(d(0, 1), a(0, 1) * b(1, 1)); + + EXPECT_EQ(d(1, 0), a(1, 0) * b(0, 0)); + EXPECT_EQ(d(1, 1), a(1, 1) * b(1, 1)); +} + +TEST(Eigen, Transpose) { + Eigen::MatrixXf a(2, 2); + a << 1, 2, 3, 4; + + // a = a.transpose(); // wrong due to alias +#if 0 + 1 2 + 2 4 +#endif + + Eigen::MatrixXf b(2, 2); + b << 1, 2, 3, 4; + b.transposeInPlace(); // correct +#if 0 + 1 3 + 2 4 +#endif +} + +TEST(Eigen, Reduction) { + Eigen::MatrixXf m(2, 2); + m << 1, 2, 3, -5; +#if 0 + 1 2 + 3 -5 +#endif + + EXPECT_EQ(m.sum(), 1); + EXPECT_EQ(m.prod(), -30); + EXPECT_EQ(m.mean(), m.sum() / m.size()); + EXPECT_EQ(m.minCoeff(), -5); + EXPECT_EQ(m.maxCoeff(), 3); + EXPECT_EQ(m.trace(), 1 + (-5)); + EXPECT_EQ(m.trace(), m.diagonal().sum()); + + std::ptrdiff_t row_id, col_id; + + float a = m.minCoeff(&row_id, &col_id); + EXPECT_EQ(a, -5); + EXPECT_EQ(row_id, 1); + EXPECT_EQ(col_id, 1); + + float b = m.maxCoeff(&row_id, &col_id); + EXPECT_EQ(b, 3); + EXPECT_EQ(row_id, 1); + EXPECT_EQ(col_id, 0); +} + +TEST(Eigen, Array) { + // Note: It is XX for a 2-D array + Eigen::ArrayXXf a(2, 3); + a << 1, 2, 3, 4, 5, 6; +#if 0 + 1 2 3 + 4 5 6 +#endif + + EXPECT_EQ(a(0, 0), 1); + EXPECT_EQ(a(0, 1), 2); + EXPECT_EQ(a(0, 2), 3); + + EXPECT_EQ(a(1, 0), 4); + EXPECT_EQ(a(1, 1), 5); + EXPECT_EQ(a(1, 2), 6); + + EXPECT_EQ(a.rows(), 2); + EXPECT_EQ(a.cols(), 3); + + Eigen::Array b; + EXPECT_EQ(b.rows(), 5); + EXPECT_EQ(b.cols(), 2); + + // 1-d array + Eigen::ArrayXf c(10); + + EXPECT_EQ(c.rows(), 10); + EXPECT_EQ(c.cols(), 1); + EXPECT_EQ(c.size(), 10); + static_assert( + std::is_same>::value, + ""); + + static_assert(std::is_same>::value, + ""); + + static_assert( + std::is_same>::value, 
+ ""); + + static_assert(std::is_same>::value, + ""); + + static_assert(std::is_same>::value, + ""); +} + +TEST(Eigen, ArrayMultiplication) { + Eigen::ArrayXXf a(2, 2); + a << 1, 2, 3, 4; + Eigen::ArrayXXf b = a * a; + + EXPECT_EQ(b(0, 0), a(0, 0) * a(0, 0)); + EXPECT_EQ(b(0, 1), a(0, 1) * a(0, 1)); + EXPECT_EQ(b(1, 0), a(1, 0) * a(1, 0)); + EXPECT_EQ(b(1, 1), a(1, 1) * a(1, 1)); + + // column-wise product + Eigen::ArrayXXf c = a.matrix().cwiseProduct(a.matrix()); + + EXPECT_EQ(c(0, 0), a(0, 0) * a(0, 0)); + EXPECT_EQ(c(0, 1), a(0, 1) * a(0, 1)); + EXPECT_EQ(c(1, 0), a(1, 0) * a(1, 0)); + EXPECT_EQ(c(1, 1), a(1, 1) * a(1, 1)); +} + +TEST(Eigen, CoefficientWise) { + Eigen::ArrayXXf a(2, 2); + a << 1, 2, 3, -4; + + EXPECT_EQ(a.abs()(1, 1), 4); + EXPECT_EQ(a.abs().sum(), 10); + + EXPECT_EQ(a.abs().sqrt()(1, 1), 2); +} + +TEST(Eigen, Row) { + Eigen::MatrixXf m(2, 3); + m << 1, 2, 3, 4, 5, 6; + + Eigen::MatrixXf a = m.row(0); // copied to a + EXPECT_EQ(a.rows(), 1); + EXPECT_EQ(a.cols(), 3); + + a(0) = 10; + EXPECT_EQ(m(0, 0), 1); + + Eigen::MatrixXf b = m.col(1); // copied to b + EXPECT_EQ(b.rows(), 2); + EXPECT_EQ(b.cols(), 1); + b(0) = 10; + EXPECT_EQ(m(0, 1), 2); + + auto c = m.row(0); // c is a proxy object; no copy is created + c(0) = 10; // also change m + EXPECT_EQ(m(0, 0), 10); + EXPECT_EQ(c.rows(), 1); + EXPECT_EQ(c.cols(), 3); + + auto d = c; // d is also a proxy + d(0) = 100; + EXPECT_EQ(c(0), 100); + EXPECT_EQ(m(0), 100); + + // N5Eigen6MatrixIfLin1ELin1ELi0ELin1ELin1EEE + // std::cout << typeid(m).name() << "\n"; + + // N5Eigen6MatrixIfLin1ELin1ELi0ELin1ELin1EEE + // std::cout << typeid(b).name() << "\n"; + + // N5Eigen5BlockINS_6MatrixIfLin1ELin1ELi0ELin1ELin1EEELi1ELin1ELb0EEE + // std::cout << typeid(c).name() << "\n"; +} + +TEST(Eigen, Sequence) { + // (start, end) + // Note that 5 is included here + auto seq = Eigen::seq(2, 5); // [2, 3, 4, 5] + EXPECT_EQ(seq.size(), 4); + for (int32_t i = 0; i != seq.size(); ++i) { + EXPECT_EQ(seq[i], i + 2); + } + + // start 2, end 5, increment 2, + // (start, end, increment), note that 5 is not included here + auto seq2 = Eigen::seq(2, 5, 2); // [2, 4] + EXPECT_EQ(seq2.size(), 2); + EXPECT_EQ(seq2[0], 2); + EXPECT_EQ(seq2[1], 4); + + // (start, sequence_length) + auto seq3 = Eigen::seqN(2, 5); // [2, 3, 4, 5, 6] + EXPECT_EQ(seq3.size(), 5); + for (int32_t i = 0; i != seq3.size(); ++i) { + EXPECT_EQ(seq3[i], i + 2); + } + + Eigen::VectorXf v(5); + v << 0, 1, 2, 3, 4; + + Eigen::VectorXf a = v(Eigen::seq(2, Eigen::last)); + EXPECT_EQ(a.size(), 3); + EXPECT_EQ(a[0], 2); + EXPECT_EQ(a[1], 3); + EXPECT_EQ(a[2], 4); + + a = v(Eigen::seq(2, Eigen::last - 1)); + EXPECT_EQ(a.size(), 2); + EXPECT_EQ(a[0], 2); + EXPECT_EQ(a[1], 3); +} + +TEST(Eigen, CopyRow) { + Eigen::MatrixXf a = Eigen::MatrixXf::Random(2, 3); + Eigen::MatrixXf b(2, 3); + b.row(0) = a.row(0); + b.row(1) = a.row(1); + for (int32_t i = 0; i != a.size(); ++i) { + EXPECT_EQ(a(i), b(i)); + } + + a = Eigen::MatrixXf::Random(5, 3); + b.resize(5, 3); + + b(Eigen::seqN(0, 3), Eigen::all) = a(Eigen::seqN(0, 3), Eigen::all); + b(Eigen::seqN(3, 2), Eigen::all) = a(Eigen::seqN(3, 2), Eigen::all); + for (int32_t i = 0; i != a.size(); ++i) { + EXPECT_EQ(a(i), b(i)); + } + + Eigen::MatrixXf c(5, 3); + c(Eigen::seqN(0, 5), Eigen::all) = a; + for (int32_t i = 0; i != a.size(); ++i) { + EXPECT_EQ(a(i), c(i)); + } +} + +TEST(Eigen, SpecialFunctions) { + Eigen::MatrixXf a(2, 3); + a.setOnes(); + for (int32_t i = 0; i != a.size(); ++i) { + EXPECT_EQ(a(i), 1); + } + + a.setZero(); + for (int32_t i = 0; i 
!= a.size(); ++i) { + EXPECT_EQ(a(i), 0); + } +} + +TEST(Eigen, LogSumExp) { + Eigen::VectorXf v(5); + v << 0.1, 0.3, 0.2, 0.15, 0.25; + auto f = LogSumExp(v); + EXPECT_NEAR(f, 1.8119, 1e-4); + + v.resize(10); + v << -0.028933119028806686, -0.8265501260757446, 0.31104734539985657, + 0.25977903604507446, 0.18070533871650696, 0.02222185768187046, + -1.4124598503112793, -0.5896500945091248, -0.17299121618270874, + -0.6516317129135132; + + f = LogSumExp(v); + EXPECT_NEAR(f, 2.1343, 1e-4); +} + +TEST(Eigen, Addmm) { + // means_invvars_: (nmix, dim) + // data: (dim,) + // loglikes: (nmix,) + // loglikes += means * inv(vars) * data. + // loglikes->AddMatVec(1.0, means_invvars_, kNoTrans, data, 1.0); + // loglikes = loglikes.unsqueeze(1); // (nmix, 1) + // loglikes.addmm_(means_invvars_, data.unsqueeze(1)); + + int32_t nmix = 3; + int32_t dim = 5; + + Eigen::MatrixXf means_invvars = Eigen::MatrixXf::Random(nmix, dim); + Eigen::VectorXf data = Eigen::VectorXf::Random(dim); + + Eigen::VectorXf loglikes = Eigen::VectorXf::Random(nmix); + + loglikes += means_invvars * data; +} + +TEST(Eigen, VectorOp) { + Eigen::VectorXf a(2); + a << 1, 2; + Eigen::RowVectorXf b(2); + b << 10, 20; + + { + Eigen::VectorXf c = a.array() + b.transpose().array(); + EXPECT_EQ(c.size(), 2); + EXPECT_EQ(c[0], a[0] + b[0]); + EXPECT_EQ(c[1], a[1] + b[1]); + } + +#if 0 + { + // Don't do this! + Eigen::VectorXf c = a.array() + b.array(); + EXPECT_EQ(c.size(), 1); + EXPECT_EQ(c[0], a[0] + b[0]); + } + + { + // Don't do this! + Eigen::RowVectorXf c(2); + c << 100, 200; + c.row(0) = a.array() + b.array(); + EXPECT_EQ(c.size(), 2); + EXPECT_EQ(c[0], a[0] + b[0]); + EXPECT_EQ(c[1], a[1] + b[0]); + } + + { + // Don't do this! + Eigen::RowVectorXf c(2); + c << 100, 200; + c.row(0) = b.array() + a.array(); + EXPECT_EQ(c.size(), 2); + EXPECT_EQ(c[0], b[0] + a[0]); + EXPECT_EQ(c[1], b[1] + a[0]); + } +#endif +} + +TEST(Eigen, VectorOp2) { + Eigen::MatrixXf m(2, 3); + m << 1, 4, 8, 16, 9, 25; + + Eigen::VectorXf v(3); + v << 10, 20, 30; + + v = v.transpose().array() * m.row(1).array().sqrt(); + EXPECT_EQ(v.size(), 3); + + EXPECT_EQ(v[0], 10 * std::sqrt(16)); + EXPECT_EQ(v[1], 20 * std::sqrt(9)); + EXPECT_EQ(v[2], 30 * std::sqrt(25)); +} + +TEST(Eigen, RowwiseSum) { + Eigen::MatrixXf m(2, 3); + m << 1, 2, 3, 4, 5, 6; + + Eigen::MatrixXf a = m.rowwise().sum(); + EXPECT_EQ(a.rows(), m.rows()); + EXPECT_EQ(a.cols(), 1); + + EXPECT_EQ(a(0), 1 + 2 + 3); + EXPECT_EQ(a(1), 4 + 5 + 6); + + Eigen::MatrixXf b = m.colwise().sum(); + EXPECT_EQ(b.rows(), 1); + EXPECT_EQ(b.cols(), m.cols()); + + EXPECT_EQ(b(0), 1 + 4); + EXPECT_EQ(b(1), 2 + 5); + EXPECT_EQ(b(2), 3 + 6); + + // assign a row vector to a vector + Eigen::VectorXf c = m.colwise().sum(); + EXPECT_EQ(c.rows(), b.cols()); + EXPECT_EQ(c.cols(), 1); + + EXPECT_EQ(c(0), 1 + 4); + EXPECT_EQ(c(1), 2 + 5); + EXPECT_EQ(c(2), 3 + 6); + + // assign a vector to a row vector + Eigen::RowVectorXf d = m.rowwise().sum(); + EXPECT_EQ(d(0), 1 + 2 + 3); + EXPECT_EQ(d(1), 4 + 5 + 6); + + // now for array + Eigen::MatrixXf a2 = m.array().rowwise().sum(); + EXPECT_EQ(a2.rows(), m.rows()); + EXPECT_EQ(a2.cols(), 1); +} + +TEST(Eigen, Replicate) { + Eigen::VectorXf v(2); + v << 1, 2; + Eigen::VectorXf a = v.replicate(3, 1); + EXPECT_EQ(a.size(), v.size() * 3); + EXPECT_EQ(a[0], v[0]); + EXPECT_EQ(a[1], v[1]); + EXPECT_EQ(a[2], v[0]); + EXPECT_EQ(a[3], v[1]); + EXPECT_EQ(a[4], v[0]); + EXPECT_EQ(a[5], v[1]); + + Eigen::MatrixXf m = v.transpose().replicate(3, 1); + // repeat the rows 3 times + EXPECT_EQ(m.rows(), 3); 
+ EXPECT_EQ(m.cols(), v.size()); + + Eigen::MatrixXf expected_m(3, 2); + expected_m << 1, 2, 1, 2, 1, 2; + for (int32_t i = 0; i != m.size(); ++i) { + EXPECT_EQ(m(i), expected_m(i)); + } +} + +TEST(Eigen, Indexes) { + Eigen::VectorXf v(5); + v << 0, 10, 20, 30, 40; + + std::vector indexes = {1, 4, 0, 2, 1}; + Eigen::VectorXf a = v(indexes); + EXPECT_EQ(a.size(), indexes.size()); + for (int32_t i = 0; i != a.size(); ++i) { + EXPECT_EQ(a[i], v[indexes[i]]); + } + + Eigen::MatrixXf m(3, 2); + m << 0, 1, 2, 3, 4, 5; + + indexes = {1, 0, 2, 1}; + Eigen::MatrixXf b = m(indexes, Eigen::all); +#if 0 + 2 3 + 0 1 + 4 5 + 2 3 +#endif +} + +TEST(Eigen, TestSoftmax) { + Eigen::VectorXf v(5); + v << 0.46589261293411255, 0.5329158902168274, 0.45468050241470337, + 0.509181022644043, 0.4529399275779724; + + Eigen::VectorXf expected(5); + expected << 0.1964813768863678, 0.21010152995586395, 0.19429071247577667, + 0.205173522233963, 0.19395282864570618; + + Eigen::VectorXf actual = Softmax(v); + for (int32_t i = 0; i != 5; ++i) { + EXPECT_NEAR(expected[i], actual[i], 1e-4); + } +} + +TEST(Eigen, Op1) { + Eigen::VectorXf a(2); + Eigen::VectorXf b(3); + + a << 10, 20; + b << 3, 5, 8; + + Eigen::MatrixXf c; + c = a * b.transpose(); + + Eigen::MatrixXf expected(2, 3); + expected << 30, 50, 80, 60, 100, 160; + for (int32_t i = 0; i != 6; ++i) { + EXPECT_EQ(c(i), expected(i)); + } +} + +TEST(Eigen, Op2) { + Eigen::MatrixXf a(2, 3); + a << 1, 2, 3, 4, 5, 6; + + Eigen::VectorXf b(2); + b << 10, 20; + + a = a.array() * b.replicate(1, a.cols()).array(); + + Eigen::MatrixXf expected(2, 3); + expected << 10, 20, 30, 80, 100, 120; + for (int32_t i = 0; i != 6; ++i) { + EXPECT_EQ(a(i), expected(i)); + } + + std::cout << a << "\n"; +} + +TEST(Eigen, Op3) { + Eigen::MatrixXf a(2, 3); + a << 1, 2, 3, 4, 5, 6; + Eigen::VectorXf b = a.row(1); + // b contains [4, 5, 6] and is a column vector + b[0] = 100; + + // it is OK to assign a column vector to a row vector + a.row(1) = b; + // a is + // 1 2 3 + // 100 5 6 +} + +TEST(Eigen, Dot) { + Eigen::VectorXf a(3); + Eigen::VectorXf b(3); + a << 1, 2, 3; + b << 4, 5, 6; + float c = a.dot(b); + EXPECT_EQ(c, 1 * 4 + 2 * 5 + 3 * 6); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/csrc/eigen.cc b/kaldi-decoder/csrc/eigen.cc new file mode 100644 index 0000000..2956241 --- /dev/null +++ b/kaldi-decoder/csrc/eigen.cc @@ -0,0 +1,71 @@ +// kaldi-decoder/csrc/eigen.cc + +// Copyright (c) 2023 Xiaomi Corporation + +#include "kaldi-decoder/csrc/eigen.h" + +#include +#include + +namespace kaldi_decoder { + +// see https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp +// log(sum(exp(x))) == log(sum(exp(x - max(x)))) + max(x) +float LogSumExp(const FloatVector &v) { + float max_v = v.maxCoeff(); + + return std::log((v.array() - max_v).exp().sum()) + max_v; +} + +FloatVector Softmax(const FloatVector &v, float *log_sum_exp /*= nullptr*/) { + float max_v = v.maxCoeff(); + + FloatVector ans = (v.array() - max_v).exp(); + + float ans_sum = ans.sum(); + + if (log_sum_exp) { + *log_sum_exp = std::log(ans_sum) + max_v; + } + + return ans / ans_sum; +} + +FloatVector RandnVector(int32_t n, float mean /*= 0*/, float stddev /*= 1*/) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution d{mean, stddev}; + + FloatVector ans(n); + + for (int32_t i = 0; i != n; ++i) { + ans[i] = d(gen); + } + + return ans; +} + +FloatMatrix RandnMatrix(int32_t rows, int32_t cols, float mean /*= 0*/, + float stddev /*= 1*/) { + std::random_device rd{}; + std::mt19937 
gen{rd()}; + std::normal_distribution d{mean, stddev}; + + FloatMatrix ans(rows, cols); + + for (int32_t i = 0; i != ans.size(); ++i) { + ans(i) = d(gen); + } + + return ans; +} + +float Randn(float mean /*= 0*/, float stddev /*= 1*/) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution d{mean, stddev}; + + return d(gen); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/csrc/eigen.h b/kaldi-decoder/csrc/eigen.h new file mode 100644 index 0000000..444ac9d --- /dev/null +++ b/kaldi-decoder/csrc/eigen.h @@ -0,0 +1,39 @@ +// kaldi-decoder/csrc/eigen.h + +// Copyright (c) 2023 Xiaomi Corporation +#ifndef KALDI_DECODER_CSRC_EIGEN_H_ +#define KALDI_DECODER_CSRC_EIGEN_H_ + +#include "Eigen/Dense" + +namespace kaldi_decoder { + +using FloatMatrix = + Eigen::Matrix; + +using DoubleMatrix = + Eigen::Matrix; + +using FloatVector = Eigen::Matrix; + +using DoubleVector = Eigen::Matrix; + +using FloatRowVector = Eigen::Matrix; + +using DoubleRowVector = Eigen::Matrix; + +float LogSumExp(const FloatVector &v); + +FloatVector Softmax(const FloatVector &v, float *log_sum_exp = nullptr); + +// A vector of normal distribution +FloatVector RandnVector(int32_t n, float mean = 0, float stddev = 1); + +FloatMatrix RandnMatrix(int32_t rows, int32_t cols, float mean = 0, + float stddev = 1); + +float Randn(float mean = 0, float stddev = 1); + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_EIGEN_H_ diff --git a/kaldi-decoder/csrc/faster-decoder.cc b/kaldi-decoder/csrc/faster-decoder.cc new file mode 100644 index 0000000..2c816d4 --- /dev/null +++ b/kaldi-decoder/csrc/faster-decoder.cc @@ -0,0 +1,426 @@ +// kaldi-decoder/csrc/faster-decoder.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2012-2013 Johns Hopkins University (author: Daniel Povey) +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/decoder/faster-decoder.cc + +#include "kaldi-decoder/csrc/faster-decoder.h" + +#include +#include +#include + +#include "kaldi-decoder/csrc/log.h" +#include "kaldifst/csrc/remove-eps-local.h" + +namespace kaldi_decoder { + +FasterDecoder::FasterDecoder(const fst::Fst &fst, + const FasterDecoderOptions &opts) + : fst_(fst), config_(opts), num_frames_decoded_(-1) { + KALDI_DECODER_ASSERT(config_.hash_ratio >= + 1.0); // less doesn't make much sense. + KALDI_DECODER_ASSERT(config_.max_active > 1); + KALDI_DECODER_ASSERT(config_.min_active >= 0 && + config_.min_active < config_.max_active); + + // just so on the first frame we do something reasonable. + toks_.SetSize(1000); +} + +void FasterDecoder::ClearToks(Elem *list) { + for (Elem *e = list, *e_tail; e != nullptr; e = e_tail) { + Token::TokenDelete(e->val); + e_tail = e->tail; + toks_.Delete(e); + } +} + +void FasterDecoder::InitDecoding() { + // clean up from last time: + ClearToks(toks_.Clear()); + StateId start_state = fst_.Start(); + + KALDI_DECODER_ASSERT(start_state != fst::kNoStateId); + + Arc dummy_arc(0, 0, Weight::One(), start_state); + + toks_.Insert(start_state, new Token(dummy_arc, nullptr)); + + ProcessNonemitting(std::numeric_limits::max()); + + num_frames_decoded_ = 0; +} + +// TODO(dan): first time we go through this, could avoid using the queue. +void FasterDecoder::ProcessNonemitting(double cutoff) { + // Processes nonemitting arcs for one frame. 
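+  // Only epsilon (ilabel == 0) arcs are followed here, so a successor's cost
+  // is just the predecessor's cost plus the arc's graph weight.  Anything
+  // above `cutoff` is pruned, and when two tokens land on the same state only
+  // the cheaper one is kept.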
+ KALDI_DECODER_ASSERT(queue_.empty()); + + for (const Elem *e = toks_.GetList(); e != nullptr; e = e->tail) { + queue_.push_back(e); + } + + while (!queue_.empty()) { + const Elem *e = queue_.back(); + queue_.pop_back(); + + StateId state = e->key; + Token *tok = e->val; // would segfault if state not + // in toks_ but this can't happen. + if (tok->cost_ > cutoff) { // Don't bother processing successors. + continue; + } + + KALDI_DECODER_ASSERT(tok != nullptr && state == tok->arc_.nextstate); + + for (fst::ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + + if (arc.ilabel != 0) { + continue; + } + + // propagate nonemitting only... + + Token *new_tok = new Token(arc, tok); + + if (new_tok->cost_ > cutoff) { // prune + Token::TokenDelete(new_tok); + continue; + } + + Elem *e_found = toks_.Insert(arc.nextstate, new_tok); + if (e_found->val == new_tok) { + // if inserted successfully + queue_.push_back(e_found); + continue; + } + + // there is another token at this state, we need + // to compare their costs and keep the one with a lower cost + + if (*(e_found->val) < *new_tok) { + // i.e., if the cost of e_found is larger than new_tok + // we keep the token with a lower cost + Token::TokenDelete(e_found->val); + e_found->val = new_tok; + queue_.push_back(e_found); + } else { + // the new token has a higher cost, remove it + Token::TokenDelete(new_tok); + } + } + } +} + +void FasterDecoder::Decode(DecodableInterface *decodable) { + InitDecoding(); + AdvanceDecoding(decodable); +} + +void FasterDecoder::AdvanceDecoding(DecodableInterface *decodable, + int32_t max_num_frames /*=-1*/) { + KALDI_DECODER_ASSERT(num_frames_decoded_ >= 0 && + "You must call InitDecoding() before AdvanceDecoding()"); + + int32_t num_frames_ready = decodable->NumFramesReady(); + + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_DECODER_ASSERT(num_frames_ready >= num_frames_decoded_); + + int32_t target_frames_decoded = num_frames_ready; + + if (max_num_frames >= 0) { + target_frames_decoded = + std::min(target_frames_decoded, num_frames_decoded_ + max_num_frames); + } + + while (num_frames_decoded_ < target_frames_decoded) { + // note: ProcessEmitting() increments num_frames_decoded_ + double weight_cutoff = ProcessEmitting(decodable); + + ProcessNonemitting(weight_cutoff); + } +} + +// ProcessEmitting returns the likelihood cutoff used. +double FasterDecoder::ProcessEmitting(DecodableInterface *decodable) { + int32_t frame = num_frames_decoded_; + Elem *last_toks = toks_.Clear(); + size_t tok_cnt; + float adaptive_beam; + Elem *best_elem = nullptr; + double weight_cutoff = + GetCutoff(last_toks, &tok_cnt, &adaptive_beam, &best_elem); + + // KALDI_DECODER_LOG << tok_cnt << " tokens active."; + + // This makes sure the hash is always big enough. + PossiblyResizeHash(tok_cnt); + + // This is the cutoff we use after adding in the log-likes (i.e. + // for the next frame). This is a bound on the cutoff we will use + // on the next frame. + double next_weight_cutoff = std::numeric_limits::infinity(); + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. 
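+  // Expanding the single best token first gives an early, cheap estimate of
+  // the smallest successor cost; next_weight_cutoff starts at that cost plus
+  // adaptive_beam and is tightened further whenever the main loop below finds
+  // a cheaper successor.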
+ if (best_elem) { + StateId state = best_elem->key; + Token *tok = best_elem->val; + for (fst::ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // we'd propagate.. + float ac_cost = -1 * decodable->LogLikelihood(frame, arc.ilabel); + double new_weight = arc.weight.Value() + tok->cost_ + ac_cost; + if (new_weight + adaptive_beam < next_weight_cutoff) + next_weight_cutoff = new_weight + adaptive_beam; + } + } + } + + // int32_t n = 0, np = 0; + + // the tokens are now owned here, in last_toks, and the hash is empty. + // 'owned' is a complex thing here; the point is we need to call TokenDelete + // on each elem 'e' to let toks_ know we're done with them. + for (Elem *e = last_toks, *e_tail; e != nullptr; + e = e_tail) { // loop this way + // n++; + // because we delete "e" as we go. + StateId state = e->key; + Token *tok = e->val; + if (tok->cost_ < weight_cutoff) { // not pruned. + // np++; + KALDI_DECODER_ASSERT(state == tok->arc_.nextstate); + for (fst::ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + float ac_cost = -1 * decodable->LogLikelihood(frame, arc.ilabel); + double new_weight = arc.weight.Value() + tok->cost_ + ac_cost; + if (new_weight < next_weight_cutoff) { // not pruned.. + Token *new_tok = new Token(arc, ac_cost, tok); + Elem *e_found = toks_.Insert(arc.nextstate, new_tok); + + if (new_weight + adaptive_beam < next_weight_cutoff) { + next_weight_cutoff = new_weight + adaptive_beam; + } + + if (e_found->val != new_tok) { + if (*(e_found->val) < *new_tok) { + // e_found has a higher cost + Token::TokenDelete(e_found->val); + e_found->val = new_tok; + } else { + // new_tok has a higher cost + Token::TokenDelete(new_tok); + } + } + } + } + } + } + + e_tail = e->tail; + Token::TokenDelete(e->val); + toks_.Delete(e); + } + + num_frames_decoded_++; + return next_weight_cutoff; +} + +// Gets the weight cutoff. Also counts the active tokens. +double FasterDecoder::GetCutoff(Elem *list_head, size_t *tok_count, + float *adaptive_beam, Elem **best_elem) { + double best_cost = std::numeric_limits::infinity(); + + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + // no constraints + for (Elem *e = list_head; e != nullptr; e = e->tail, ++count) { + double w = e->val->cost_; + if (w < best_cost) { + best_cost = w; + + if (best_elem) { + *best_elem = e; + } + } + } + + if (tok_count != nullptr) { + *tok_count = count; + } + + if (adaptive_beam != nullptr) { + *adaptive_beam = config_.beam; + } + + return best_cost + config_.beam; + } + + tmp_array_.clear(); + + for (Elem *e = list_head; e != nullptr; e = e->tail, ++count) { + double w = e->val->cost_; + tmp_array_.push_back(w); + + if (w < best_cost) { + best_cost = w; + + if (best_elem) { + *best_elem = e; + } + } + } + + if (tok_count != nullptr) { + *tok_count = count; + } + + double beam_cutoff = best_cost + config_.beam; + double min_active_cutoff = std::numeric_limits::infinity(); + double max_active_cutoff = std::numeric_limits::infinity(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. 
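+    // max_active is the binding constraint here: shrink the beam so that
+    // roughly max_active tokens survive.  The effective beam becomes
+    // (max_active_cutoff - best_cost) + beam_delta, slightly looser than the
+    // cutoff returned below.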
+ if (adaptive_beam) { + *adaptive_beam = max_active_cutoff - best_cost + config_.beam_delta; + } + + return max_active_cutoff; + } + + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) { + min_active_cutoff = best_cost; + } else { + std::nth_element( + tmp_array_.begin(), tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) + ? tmp_array_.begin() + config_.max_active + : tmp_array_.end()); + + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) { + *adaptive_beam = min_active_cutoff - best_cost + config_.beam_delta; + } + + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + + return beam_cutoff; + } +} + +void FasterDecoder::PossiblyResizeHash(size_t num_toks) { + auto new_sz = + static_cast(static_cast(num_toks) * config_.hash_ratio); + + if (new_sz > toks_.Size()) { + toks_.SetSize(new_sz); + } +} + +bool FasterDecoder::ReachedFinal() const { + for (const Elem *e = toks_.GetList(); e != nullptr; e = e->tail) { + if (e->val->cost_ != std::numeric_limits::infinity() && + fst_.Final(e->key) != Weight::Zero()) + return true; + } + return false; +} + +bool FasterDecoder::GetBestPath(fst::MutableFst *fst_out, + bool use_final_probs) { + // GetBestPath gets the decoding output. If "use_final_probs" is true + // AND we reached a final state, it limits itself to final states; + // otherwise it gets the most likely token not taking into + // account final-probs. fst_out will be empty (Start() == kNoStateId) if + // nothing was available. It returns true if it got output (thus, fst_out + // will be nonempty). + fst_out->DeleteStates(); + Token *best_tok = nullptr; + bool is_final = ReachedFinal(); + if (!is_final) { + for (const Elem *e = toks_.GetList(); e != nullptr; e = e->tail) { + if (best_tok == nullptr || *best_tok < *(e->val)) { + best_tok = e->val; + } + } + } else { + double infinity = std::numeric_limits::infinity(); + double best_cost = infinity; + + for (const Elem *e = toks_.GetList(); e != nullptr; e = e->tail) { + double this_cost = e->val->cost_ + fst_.Final(e->key).Value(); + if (this_cost < best_cost && this_cost != infinity) { + best_cost = this_cost; + best_tok = e->val; + } + } + } + + if (best_tok == nullptr) { + // No output. + return false; + } + + std::vector arcs_reverse; // arcs in reverse order. + + for (Token *tok = best_tok; tok != nullptr; tok = tok->prev_) { + float tot_cost = tok->cost_ - (tok->prev_ ? tok->prev_->cost_ : 0.0); + float graph_cost = tok->arc_.weight.Value(); + float ac_cost = tot_cost - graph_cost; + + fst::LatticeArc l_arc(tok->arc_.ilabel, tok->arc_.olabel, + fst::LatticeWeight(graph_cost, ac_cost), + tok->arc_.nextstate); + arcs_reverse.push_back(l_arc); + } + + KALDI_DECODER_ASSERT(arcs_reverse.back().nextstate == fst_.Start()); + + arcs_reverse.pop_back(); // that was a "fake" token... gives no info. 
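+  // arcs_reverse now holds the best path's arcs ordered from the final token
+  // back toward the start state; the loop below walks it in reverse to build
+  // a linear output FST whose arcs carry LatticeWeight(graph_cost, ac_cost).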
+ + StateId cur_state = fst_out->AddState(); + fst_out->SetStart(cur_state); + for (ssize_t i = static_cast(arcs_reverse.size()) - 1; i >= 0; --i) { + fst::LatticeArc arc = arcs_reverse[i]; + arc.nextstate = fst_out->AddState(); + fst_out->AddArc(cur_state, arc); + cur_state = arc.nextstate; + } + if (is_final && use_final_probs) { + Weight final_weight = fst_.Final(best_tok->arc_.nextstate); + fst_out->SetFinal(cur_state, fst::LatticeWeight(final_weight.Value(), 0.0)); + } else { + fst_out->SetFinal(cur_state, fst::LatticeWeight::One()); + } + fst::RemoveEpsLocal(fst_out); + return true; +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/csrc/faster-decoder.h b/kaldi-decoder/csrc/faster-decoder.h new file mode 100644 index 0000000..ba41dee --- /dev/null +++ b/kaldi-decoder/csrc/faster-decoder.h @@ -0,0 +1,204 @@ +// kaldi-decoder/csrc/faster-decoder.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/decoder/faster-decoder.h +#ifndef KALDI_DECODER_CSRC_FASTER_DECODER_H_ +#define KALDI_DECODER_CSRC_FASTER_DECODER_H_ + +#include +#include +#include + +#include "fst/fst.h" +#include "fst/fstlib.h" +#include "kaldi-decoder/csrc/decodable-itf.h" +#include "kaldi-decoder/csrc/hash-list.h" +#include "kaldifst/csrc/lattice-weight.h" + +namespace kaldi_decoder { + +struct FasterDecoderOptions { + // Decoding beam. Larger->slower, more accurate. + float beam; + + // Decoder max active states. Larger->slower; more accurate + int32_t max_active; + + // Decoder min active states (don't prune if #active less than this). + int32_t min_active; + + // Increment used in decoder [obscure setting] + float beam_delta; + + // Setting used in decoder to control hash behavior + float hash_ratio; + + /*implicit*/ FasterDecoderOptions( + float beam = 16.0, + int32_t max_active = std::numeric_limits::max(), + int32_t min_active = 20, float beam_delta = 0.5, float hash_ratio = 2.0) + : beam(beam), + max_active(max_active), + min_active(min_active), // This decoder mostly used for + // alignment, use small default. + beam_delta(beam_delta), + hash_ratio(hash_ratio) {} + + std::string ToString() const { + std::ostringstream os; + + os << "FasterDecoderOptions("; + os << "beam=" << beam << ", "; + os << "max_active=" << max_active << ", "; + os << "min_active=" << min_active << ", "; + os << "beam_delta=" << beam_delta << ", "; + os << "hash_ratio=" << hash_ratio << ")"; + + return os.str(); + } +}; + +class FasterDecoder { + public: + typedef fst::StdArc Arc; + typedef Arc::Label Label; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + + FasterDecoder(const fst::Fst &fst, + const FasterDecoderOptions &config); + + FasterDecoder(const FasterDecoder &) = delete; + FasterDecoder &operator=(const FasterDecoder &) = delete; + + void SetOptions(const FasterDecoderOptions &config) { config_ = config; } + + ~FasterDecoder() { ClearToks(toks_.Clear()); } + + void Decode(DecodableInterface *decodable); + + /// Returns true if a final state was active on the last frame. + bool ReachedFinal() const; + + /// GetBestPath gets the decoding traceback. If "use_final_probs" is true + /// AND we reached a final state, it limits itself to final states; + /// otherwise it gets the most likely token not taking into account + /// final-probs. 
Returns true if the output best path was not the empty + /// FST (will only return false in unusual circumstances where + /// no tokens survived). + bool GetBestPath(fst::MutableFst *fst_out, + bool use_final_probs = true); + + /// As a new alternative to Decode(), you can call InitDecoding + /// and then (possibly multiple times) AdvanceDecoding(). + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object, but if max_num_frames is >= 0 it will decode no more than + /// that many frames. + void AdvanceDecoding(DecodableInterface *decodable, + int32_t max_num_frames = -1); + + /// Returns the number of frames already decoded. + int32_t NumFramesDecoded() const { return num_frames_decoded_; } + + protected: + class Token { + public: + Arc arc_; // contains only the graph part of the cost; + // we can work out the acoustic part from difference between + // "cost_" and prev->cost_. + Token *prev_; + int32_t ref_count_; + // if you are looking for weight_ here, it was removed and now we just have + // cost_, which corresponds to ConvertToCost(weight_). + double cost_; + inline Token(const Arc &arc, float ac_cost, Token *prev) + : arc_(arc), prev_(prev), ref_count_(1) { + if (prev) { + prev->ref_count_++; + cost_ = prev->cost_ + arc.weight.Value() + ac_cost; + } else { + cost_ = arc.weight.Value() + ac_cost; + } + } + + // for epsilon transitions, i.e., non-emitting arcs + inline Token(const Arc &arc, Token *prev) + : arc_(arc), prev_(prev), ref_count_(1) { + if (prev) { + prev->ref_count_++; + cost_ = prev->cost_ + arc.weight.Value(); + } else { + cost_ = arc.weight.Value(); + } + } + + inline bool operator<(const Token &other) const { + return cost_ > other.cost_; + } + + inline static void TokenDelete(Token *tok) { + while (--tok->ref_count_ == 0) { + Token *prev = tok->prev_; + delete tok; + if (prev == nullptr) { + return; + } else { + tok = prev; + } + } + } + }; + + using Elem = HashList::Elem; + + /// Gets the weight cutoff. Also counts the active tokens. + double GetCutoff(Elem *list_head, size_t *tok_count, float *adaptive_beam, + Elem **best_elem); + + void PossiblyResizeHash(size_t num_toks); + + // ProcessEmitting returns the likelihood cutoff used. + // It decodes the frame num_frames_decoded_ of the decodable object + // and then increments num_frames_decoded_ + double ProcessEmitting(DecodableInterface *decodable); + + // TODO(dan): first time we go through this, could avoid using the queue. + void ProcessNonemitting(double cutoff); + + // HashList defined in ../hash-list.h. It actually allows us to maintain + // more than one list (e.g. for current and previous frames), but only one of + // them at a time can be indexed by StateId. + HashList toks_; + + const fst::Fst &fst_; + + FasterDecoderOptions config_; + + // temp variable used in ProcessNonemitting, + std::vector queue_; + + std::vector tmp_array_; // used in GetCutoff. + // make it class member to avoid internal new/delete. + + // Keep track of the number of frames decoded in the current file. + int32_t num_frames_decoded_; + + // It might seem unclear why we call ClearToks(toks_.Clear()). + // There are two separate cleanup tasks we need to do at when we start a new + // file. one is to delete the Token objects in the list; the other is to + // delete the Elem objects. toks_.Clear() just clears them from the hash and + // gives ownership to the caller, who then has to call toks_.Delete(e) for + // each one. 
It was designed this way for convenience in propagating tokens + // from one frame to the next. + void ClearToks(Elem *list); +}; + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_FASTER_DECODER_H_ diff --git a/kaldi-decoder/csrc/hash-list-inl.h b/kaldi-decoder/csrc/hash-list-inl.h new file mode 100644 index 0000000..a58ae22 --- /dev/null +++ b/kaldi-decoder/csrc/hash-list-inl.h @@ -0,0 +1,207 @@ +// kaldi-decoder/csrc/hash-list-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/utils/hash-list-inl.h + +#ifndef KALDI_DECODER_CSRC_HASH_LIST_INL_H_ +#define KALDI_DECODER_CSRC_HASH_LIST_INL_H_ + +// Do not include this file directly. It is included by fast-hash.h + +namespace kaldi_decoder { + +template +HashList::HashList() { + list_head_ = nullptr; + bucket_list_tail_ = static_cast(-1); // invalid. + hash_size_ = 0; + freed_head_ = nullptr; +} + +template +void HashList::SetSize(size_t size) { + hash_size_ = size; + KALDI_DECODER_ASSERT(list_head_ == nullptr && + bucket_list_tail_ == + static_cast(-1)); // make sure empty. + + if (size > buckets_.size()) { + buckets_.resize(size, HashBucket(0, nullptr)); + } +} + +template +typename HashList::Elem *HashList::Clear() { + // Clears the hashtable and gives ownership of the currently contained list + // to the user. + for (size_t cur_bucket = bucket_list_tail_; + cur_bucket != static_cast(-1); + cur_bucket = buckets_[cur_bucket].prev_bucket) { + buckets_[cur_bucket].last_elem = + nullptr; // this is how we indicate "empty". + } + bucket_list_tail_ = static_cast(-1); + Elem *ans = list_head_; + list_head_ = nullptr; + return ans; +} + +template +const typename HashList::Elem *HashList::GetList() const { + return list_head_; +} + +template +inline void HashList::Delete(Elem *e) { + e->tail = freed_head_; + freed_head_ = e; +} + +template +inline typename HashList::Elem *HashList::Find(I key) { + size_t index = (static_cast(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + if (bucket.last_elem == nullptr) { + return nullptr; // empty bucket. + } else { + Elem *head = (bucket.prev_bucket == static_cast(-1) + ? list_head_ + : buckets_[bucket.prev_bucket].last_elem->tail), + *tail = bucket.last_elem->tail; + for (Elem *e = head; e != tail; e = e->tail) { + if (e->key == key) { + return e; + } + } + + return nullptr; // Not found. + } +} + +template +inline typename HashList::Elem *HashList::New() { + if (freed_head_) { + Elem *ans = freed_head_; + freed_head_ = freed_head_->tail; + return ans; + } else { + Elem *tmp = new Elem[allocate_block_size_]; + for (size_t i = 0; i + 1 < allocate_block_size_; i++) { + tmp[i].tail = tmp + i + 1; + } + + tmp[allocate_block_size_ - 1].tail = nullptr; + freed_head_ = tmp; + allocated_.push_back(tmp); + return this->New(); + } +} + +template +HashList::~HashList() { + // First test whether we had any memory leak within the + // HashList, i.e. things for which the user did not call Delete(). 
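+  // Delete() threads every returned Elem onto freed_head_, so the free list
+  // should contain exactly allocate_block_size_ * allocated_.size() elements;
+  // a shorter list means some Elems were never handed back.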
+ size_t num_in_list = 0, num_allocated = 0; + + for (Elem *e = freed_head_; e != nullptr; e = e->tail) { + num_in_list++; + } + + for (size_t i = 0; i < allocated_.size(); i++) { + num_allocated += allocate_block_size_; + delete[] allocated_[i]; + } + + if (num_in_list != num_allocated) { + KALDI_DECODER_WARN << "Possible memory leak: " << num_in_list + << " != " << num_allocated + << ": you might have forgotten to call Delete on " + << "some Elems"; + } +} + +template +inline typename HashList::Elem *HashList::Insert(I key, T val) { + size_t index = (static_cast(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + // Check the element is existing or not. + if (bucket.last_elem != nullptr) { + Elem *head = (bucket.prev_bucket == static_cast(-1) + ? list_head_ + : buckets_[bucket.prev_bucket].last_elem->tail), + *tail = bucket.last_elem->tail; + + for (Elem *e = head; e != tail; e = e->tail) { + if (e->key == key) { + return e; + } + } + } + + // This is a new element. Insert it. + Elem *elem = New(); + elem->key = key; + elem->val = val; + + if (bucket.last_elem == nullptr) { // Unoccupied bucket. Insert at + // head of bucket list (which is tail of regular list, they go in + // opposite directions). + if (bucket_list_tail_ == static_cast(-1)) { + // list was empty so this is the first elem. + KALDI_DECODER_ASSERT(list_head_ == nullptr); + list_head_ = elem; + } else { + // link in to the chain of Elems + buckets_[bucket_list_tail_].last_elem->tail = elem; + } + elem->tail = nullptr; + bucket.last_elem = elem; + bucket.prev_bucket = bucket_list_tail_; + bucket_list_tail_ = index; + } else { + // Already-occupied bucket. Insert at tail of list of elements within + // the bucket. + elem->tail = bucket.last_elem->tail; + bucket.last_elem->tail = elem; + bucket.last_elem = elem; + } + return elem; +} + +template +void HashList::InsertMore(I key, T val) { + size_t index = (static_cast(key) % hash_size_); + HashBucket &bucket = buckets_[index]; + Elem *elem = New(); + elem->key = key; + elem->val = val; + + // assume one element is already here + KALDI_DECODER_ASSERT(bucket.last_elem != nullptr); + + if (bucket.last_elem->key == key) { // standard behavior: add as last element + elem->tail = bucket.last_elem->tail; + bucket.last_elem->tail = elem; + bucket.last_elem = elem; + return; + } + Elem *e = (bucket.prev_bucket == static_cast(-1) + ? list_head_ + : buckets_[bucket.prev_bucket].last_elem->tail); + // find place to insert in linked list + while (e != bucket.last_elem->tail && e->key != key) { + e = e->tail; + } + + KALDI_DECODER_ASSERT(e->key == key); // not found? - should not happen + elem->tail = e->tail; + e->tail = elem; +} + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_HASH_LIST_INL_H_ diff --git a/kaldi-decoder/csrc/hash-list-test.cc b/kaldi-decoder/csrc/hash-list-test.cc new file mode 100644 index 0000000..c9949c3 --- /dev/null +++ b/kaldi-decoder/csrc/hash-list-test.cc @@ -0,0 +1,103 @@ +// kaldi-decoder/csrc/hash-list-test.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/utils/hash-list-test.cc + +#include "kaldi-decoder/csrc/hash-list.h" + +#include + +#include // for baseline. 
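+// The tests below mirror every HashList operation in a std::map (the
+// "baseline" mentioned above) and compare the two containers after each
+// round of Insert / Clear / Delete.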
+ +// #include "kaldi-decoder/csrc/kaldi-math.h" +#include "gtest/gtest.h" + +namespace kaldi_decoder { + +template +void TestHashList() { + typedef typename HashList::Elem Elem; + + HashList hash; + hash.SetSize(200); // must be called before use. + std::map m1; + + for (size_t j = 0; j < 50; j++) { + Int key = rand() % 200; // NOLINT + T val = rand() % 50; // NOLINT + m1[key] = val; + Elem *e = hash.Find(key); + if (e) { + e->val = val; + } else { + hash.Insert(key, val); + } + } + + std::map m2; + + for (int i = 0; i < 100; i++) { + m2.clear(); + for (auto iter = m1.begin(); iter != m1.end(); iter++) { + m2[iter->first + 1] = iter->second; + } + std::swap(m1, m2); + + Elem *h = hash.Clear(), *tmp; + + // note, SetSize is relatively cheap + hash.SetSize(100 + rand() % 100); // NOLINT + // operation as long as we are not increasing the size more than it's ever + // previously been increased to. + + for (; h != nullptr; h = tmp) { + hash.Insert(h->key + 1, h->val); + tmp = h->tail; + hash.Delete(h); // think of this like calling delete. + } + + // Now make sure h and m2 are the same. + const Elem *list = hash.GetList(); + size_t count = 0; + for (; list != nullptr; list = list->tail, count++) { + KALDI_DECODER_ASSERT(m1[list->key] == list->val); + } + + for (size_t j = 0; j < 10; j++) { + Int key = rand() % 200; // NOLINT + bool found_m1 = (m1.find(key) != m1.end()); + + if (found_m1) m1[key]; + + Elem *e = hash.Find(key); + KALDI_DECODER_ASSERT((e != nullptr) == found_m1); + + if (found_m1) KALDI_DECODER_ASSERT(m1[key] == e->val); + } + + KALDI_DECODER_ASSERT(m1.size() == count); + } + + Elem *h = hash.Clear(), *tmp; + for (; h != nullptr; h = tmp) { + tmp = h->tail; + hash.Delete(h); // think of this like calling delete. + } +} + +TEST(HashList, Test) { + for (size_t i = 0; i < 3; i++) { + TestHashList(); + TestHashList(); + TestHashList(); + TestHashList(); + TestHashList(); + TestHashList(); + } +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/csrc/hash-list.h b/kaldi-decoder/csrc/hash-list.h new file mode 100644 index 0000000..d006ef9 --- /dev/null +++ b/kaldi-decoder/csrc/hash-list.h @@ -0,0 +1,133 @@ +// kaldi-decoder/csrc/hash-list.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) +// Copyright (c) 2023 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/utils/hash-list.h + +#ifndef KALDI_DECODER_CSRC_HASH_LIST_H_ +#define KALDI_DECODER_CSRC_HASH_LIST_H_ + +#include +#include +#include +#include +#include + +#include "kaldi-decoder/csrc/stl-utils.h" + +/* This header provides utilities for a structure that's used in a decoder (but + is quite generic in nature so we implement and test it separately). + Basically it's a singly-linked list, but implemented in such a way that we + can quickly search for elements in the list. We give it a slightly richer + interface than just a hash and a list. The idea is that we want to separate + the hash part and the list part: basically, in the decoder, we want to have a + single hash for the current frame and the next frame, because by the time we + need to access the hash for the next frame we no longer need the hash for the + previous frame. So we have an operation that clears the hash but leaves the + list structure intact. We also control memory management inside this object, + to avoid repeated new's/deletes. 
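+
+   As a rough sketch of how a decoder is expected to drive it (the names below
+   are illustrative, not copied from any particular caller):
+
+     HashList<StateId, Token*> toks;
+     toks.SetSize(1000);                  // ~2x the number of keys expected
+     Elem *e = toks.Insert(state, tok);   // Find(state) looks one up later
+     ...
+     for (Elem *e = toks.Clear(), *next; e != nullptr; e = next) {
+       next = e->tail;                    // Clear() hands the list to us;
+       toks.Delete(e);                    // each Elem must be given back
+     }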
+*/ + +namespace kaldi_decoder { + +template +class HashList { + public: + struct Elem { + I key; + T val; + Elem *tail; + }; + + /// Constructor takes no arguments. + /// Call SetSize to inform it of the likely size. + HashList(); + + /// Clears the hash and gives the head of the current list to the user; + /// ownership is transferred to the user (the user must call Delete() + /// for each element in the list, at his/her leisure). + Elem *Clear(); + + /// Gives the head of the current list to the user. Ownership retained in the + /// class. Caution: in December 2013 the return type was changed to const + /// Elem* and this function was made const. You may need to change some types + /// of local Elem* variables to const if this produces compilation errors. + const Elem *GetList() const; + + /// Think of this like delete(). It is to be called for each Elem in turn + /// after you "obtained ownership" by doing Clear(). This is not the opposite + /// of Insert, it is the opposite of New. It's really a memory operation. + inline void Delete(Elem *e); + + /// This should probably not be needed to be called directly by the user. + /// Think of it as opposite + /// to Delete(); + inline Elem *New(); + + /// Find tries to find this element in the current list using the hashtable. + /// It returns NULL if not present. The Elem it returns is not owned by the + /// user, it is part of the internal list owned by this object, but the user + /// is free to modify the "val" element. + inline Elem *Find(I key); + + /// Insert inserts a new element into the hashtable/stored list. + /// Because element keys in a hashtable are unique, this operation checks + /// whether each inserted element has a key equivalent to the one of an + /// element already in the hashtable. If so, the element is not inserted, + /// returning an pointer to this existing element. + inline Elem *Insert(I key, T val); + + /// Insert inserts another element with same key into the hashtable/ + /// stored list. + /// By calling this, the user asserts that one element with that key is + /// already present. + /// We insert it that way, that all elements with the same key + /// follow each other. + /// Find() will return the first one of the elements with the same key. + inline void InsertMore(I key, T val); + + /// SetSize tells the object how many hash buckets to allocate (should + /// typically be at least twice the number of objects we expect to go in the + /// structure, for fastest performance). It must be called while the hash + /// is empty (e.g. after Clear() or after initializing the object, but before + /// adding anything to the hash. + void SetSize(size_t sz); + + /// Returns current number of hash buckets. + inline size_t Size() const { return hash_size_; } + + ~HashList(); + + private: + struct HashBucket { + size_t prev_bucket; // index to next bucket (-1 if list tail). Note: + // list of buckets goes in opposite direction to list of Elems. + Elem *last_elem; // pointer to last element in this bucket (NULL if empty) + inline HashBucket(size_t i, Elem *e) : prev_bucket(i), last_elem(e) {} + }; + + Elem *list_head_; // head of currently stored list. + size_t bucket_list_tail_; // tail of list of active hash buckets. + + size_t hash_size_; // number of hash buckets. + + std::vector buckets_; + + Elem *freed_head_; // head of list of currently freed elements. [ready for + // allocation] + + std::vector allocated_; // list of allocated blocks. 
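+
+  // Memory strategy: New() carves Elems out of blocks of allocate_block_size_
+  // entries (declared below); Delete() never frees anything, it only pushes
+  // the Elem back onto freed_head_ for reuse, and the blocks themselves are
+  // released in the destructor.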
+ + static const size_t allocate_block_size_ = 1024; // Number of Elements to + // allocate in one block. Must be largish so storing allocated_ doesn't + // become a problem. +}; + +} // namespace kaldi_decoder + +#include "kaldi-decoder/csrc/hash-list-inl.h" + +#endif // KALDI_DECODER_CSRC_HASH_LIST_H_ diff --git a/kaldi-decoder/csrc/log.h b/kaldi-decoder/csrc/log.h new file mode 100644 index 0000000..f3ad806 --- /dev/null +++ b/kaldi-decoder/csrc/log.h @@ -0,0 +1,100 @@ +// kaldi-decoder/csrc/log.h +// +// Copyright (c) 2022 Xiaomi Corporation + +#ifndef KALDI_DECODER_CSRC_LOG_H_ +#define KALDI_DECODER_CSRC_LOG_H_ + +#include +#include +#include +#include + +namespace kaldi_decoder { + +enum class LogLevel { + kInfo = 0, + kWarn = 1, + kError = 2, // abort the program +}; + +class Logger { + public: + Logger(const char *filename, const char *func_name, uint32_t line_num, + LogLevel level) + : level_(level) { + os_ << filename << ":" << func_name << ":" << line_num << "\n"; + switch (level_) { + case LogLevel::kInfo: + os_ << "[I] "; + break; + case LogLevel::kWarn: + os_ << "[W] "; + break; + case LogLevel::kError: + os_ << "[E] "; + break; + } + } + + template + Logger &operator<<(const T &val) { + os_ << val; + return *this; + } + + ~Logger() noexcept(false) { + if (level_ == LogLevel::kError) { + // throw std::runtime_error(os_.str()); + // abort(); + throw std::runtime_error(os_.str()); + } + // fprintf(stderr, "%s\n", os_.str().c_str()); + } + + private: + std::ostringstream os_; + LogLevel level_; +}; + +class Voidifier { + public: + void operator&(const Logger &) const {} +}; + +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) || \ + defined(__PRETTY_FUNCTION__) +// for clang and GCC +#define KALDI_DECODER_FUNC __PRETTY_FUNCTION__ +#else +// for other compilers +#define KALDI_DECODER_FUNC __func__ +#endif + +#define KALDI_DECODER_LOG \ + kaldi_decoder::Logger(__FILE__, KALDI_DECODER_FUNC, __LINE__, \ + kaldi_decoder::LogLevel::kInfo) + +#define KALDI_DECODER_WARN \ + kaldi_decoder::Logger(__FILE__, KALDI_DECODER_FUNC, __LINE__, \ + kaldi_decoder::LogLevel::kWarn) + +#define KALDI_DECODER_ERR \ + kaldi_decoder::Logger(__FILE__, KALDI_DECODER_FUNC, __LINE__, \ + kaldi_decoder::LogLevel::kError) + +#define KALDI_DECODER_ASSERT(x) \ + (x) ? (void)0 \ + : kaldi_decoder::Voidifier() & KALDI_DECODER_ERR << "Check failed!\n" \ + << "x: " << #x + +#define KALDI_DECODER_PARANOID_ASSERT KALDI_DECODER_ASSERT + +#define KALDI_DECODER_DISALLOW_COPY_AND_ASSIGN(Class) \ + public: \ + Class(const Class &) = delete; \ + Class &operator=(const Class &) = delete; + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_LOG_H_ diff --git a/kaldi-decoder/csrc/stl-utils.h b/kaldi-decoder/csrc/stl-utils.h new file mode 100644 index 0000000..b06ddf8 --- /dev/null +++ b/kaldi-decoder/csrc/stl-utils.h @@ -0,0 +1,209 @@ +// kaldi-decoder/csrc/stl-utils.h +// +// Copyright (c) 2022 Xiaomi Corporation + +// this file is copied and modified from +// kaldi/src/util/stl-utils.h + +#ifndef KALDI_DECODER_CSRC_STL_UTILS_H_ +#define KALDI_DECODER_CSRC_STL_UTILS_H_ +#include +#include +#include +#include +#include + +#include "kaldi-decoder/csrc/log.h" + +namespace kaldi_decoder { + +/// Returns true if the vector is sorted. 
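+/// For example, {1, 2, 2, 3} counts as sorted here, while IsSortedAndUniq()
+/// below rejects it because of the repeated 2.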
+template +inline bool IsSorted(const std::vector &vec) { + typename std::vector::const_iterator iter = vec.begin(), end = vec.end(); + if (iter == end) return true; + while (1) { + typename std::vector::const_iterator next_iter = iter; + ++next_iter; + if (next_iter == end) return true; // end of loop and nothing out of order + if (*next_iter < *iter) return false; + iter = next_iter; + } +} + +/// Sorts and uniq's (removes duplicates) from a vector. +template +inline void SortAndUniq(std::vector *vec) { + std::sort(vec->begin(), vec->end()); + vec->erase(std::unique(vec->begin(), vec->end()), vec->end()); +} + +/// Returns true if the vector is sorted and contains each element +/// only once. +template +inline bool IsSortedAndUniq(const std::vector &vec) { + typename std::vector::const_iterator iter = vec.begin(), end = vec.end(); + if (iter == end) return true; + while (1) { + typename std::vector::const_iterator next_iter = iter; + ++next_iter; + if (next_iter == end) return true; // end of loop and nothing out of order + if (*next_iter <= *iter) return false; + iter = next_iter; + } +} + +template +inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector &v) { + // Compile time assertion that this is not called with a wrong type. + static_assert(std::is_integral::value, ""); + if (binary) { + char sz = sizeof(T); // this is currently just a check. + os.write(&sz, 1); + int32_t vecsz = static_cast(v.size()); + KALDI_DECODER_ASSERT((size_t)vecsz == v.size()); + + os.write(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (vecsz != 0) { + os.write(reinterpret_cast(&(v[0])), sizeof(T) * vecsz); + } + } else { + // focus here is on prettiness of text form rather than + // efficiency of reading-in. + // reading-in is dominated by low-level operations anyway: + // for efficiency use binary. + os << "[ "; + typename std::vector::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) { + if (sizeof(T) == 1) + os << static_cast(*iter) << " "; + else + os << *iter << " "; + } + os << "]\n"; + } + if (os.fail()) { + KALDI_DECODER_ERR << "Write failure in WriteIntegerVector."; + } +} + +template +inline void ReadIntegerVector(std::istream &is, bool binary, + std::vector *v) { + static_assert(std::is_integral::value, ""); + KALDI_DECODER_ASSERT(v != nullptr); + if (binary) { + int sz = is.peek(); + if (sz == sizeof(T)) { + is.get(); + } else { // this is currently just a check. + KALDI_DECODER_ERR << "ReadIntegerVector: expected to see type of size " + << sizeof(T) << ", saw instead " << sz + << ", at file position " << is.tellg(); + } + int32_t vecsz; + is.read(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (is.fail() || vecsz < 0) goto bad; + + v->resize(vecsz); + + if (vecsz > 0) { + is.read(reinterpret_cast(&((*v)[0])), sizeof(T) * vecsz); + } + } else { + std::vector tmp_v; // use temporary so v doesn't use extra memory + // due to resizing. + is >> std::ws; + if (is.peek() != static_cast('[')) { + KALDI_DECODER_ERR << "ReadIntegerVector: expected to see [, saw " + << is.peek() << ", at file position " << is.tellg(); + } + is.get(); // consume the '['. + is >> std::ws; // consume whitespace. + while (is.peek() != static_cast(']')) { + if (sizeof(T) == 1) { // read/write chars as numbers. + int16_t next_t; + is >> next_t >> std::ws; + if (is.fail()) + goto bad; + else + tmp_v.push_back((T)next_t); + } else { + T next_t; + is >> next_t >> std::ws; + if (is.fail()) + goto bad; + else + tmp_v.push_back(next_t); + } + } + is.get(); // get the final ']'. 
+ *v = tmp_v; // could use std::swap to use less temporary memory, but this + // uses less permanent memory. + } + if (!is.fail()) return; +bad: + KALDI_DECODER_ERR << "ReadIntegerVector: read failure at file position " + << is.tellg(); +} + +/// Deletes any non-NULL pointers in the vector v, and sets +/// the corresponding entries of v to NULL +template +void DeletePointers(std::vector *v) { + KALDI_DECODER_ASSERT(v != nullptr); + typename std::vector::iterator iter = v->begin(), end = v->end(); + for (; iter != end; ++iter) { + if (*iter != nullptr) { + delete *iter; + *iter = nullptr; // set to NULL for extra safety. + } + } +} + +/// A hashing function-object for pairs of ints +template +struct PairHasher { // hashing function for pair + size_t operator()(const std::pair &x) const noexcept { + // 7853 was chosen at random from a list of primes. + return x.first + x.second * 7853; + } + PairHasher() { // Check we're instantiated with an integer type. + static_assert(std::is_integral::value, ""); + static_assert(std::is_integral::value, ""); + } +}; + +/// Returns true if the vector of pointers contains NULL pointers. +template +bool ContainsNullPointers(const std::vector &v) { + typename std::vector::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) + if (*iter == static_cast(nullptr)) return true; + return false; +} + +/// A hashing function-object for vectors. +template +struct VectorHasher { // hashing function for vector. + size_t operator()(const std::vector &x) const noexcept { + size_t ans = 0; + auto iter = x.begin(), end = x.end(); + for (; iter != end; ++iter) { + ans *= kPrime; + ans += *iter; + } + return ans; + } + VectorHasher() { // Check we're instantiated with an integer type. + static_assert(std::is_integral::value, ""); + } + + private: + static const int kPrime = 7853; +}; + +} // namespace kaldi_decoder + +#endif // KALDI_DECODER_CSRC_STL_UTILS_H_ diff --git a/kaldi-decoder/python/CMakeLists.txt b/kaldi-decoder/python/CMakeLists.txt new file mode 100644 index 0000000..86735ca --- /dev/null +++ b/kaldi-decoder/python/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(csrc) diff --git a/kaldi-decoder/python/csrc/CMakeLists.txt b/kaldi-decoder/python/csrc/CMakeLists.txt new file mode 100644 index 0000000..9470aa0 --- /dev/null +++ b/kaldi-decoder/python/csrc/CMakeLists.txt @@ -0,0 +1,29 @@ +include_directories(${PROJECT_SOURCE_DIR}) + +set(srcs + decodable-ctc.cc + decodable-itf.cc + faster-decoder.cc + kaldi-decoder.cc +) + +pybind11_add_module(_kaldi_decoder ${srcs}) +target_link_libraries(_kaldi_decoder PRIVATE kaldi-decoder-core) + +if(APPLE) + execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE PYTHON_SITE_PACKAGE_DIR + ) + message(STATUS "PYTHON_SITE_PACKAGE_DIR: ${PYTHON_SITE_PACKAGE_DIR}") + target_link_libraries(_kaldi_decoder PRIVATE "-Wl,-rpath,${PYTHON_SITE_PACKAGE_DIR}") +endif() + +if(NOT WIN32) + target_link_libraries(_kaldi_decoder PRIVATE "-Wl,-rpath,${KALDI_DECODER_RPATH_ORIGIN}/kaldi_decoder/lib") +endif() + +install(TARGETS _kaldi_decoder + DESTINATION ../ +) diff --git a/kaldi-decoder/python/csrc/decodable-ctc.cc b/kaldi-decoder/python/csrc/decodable-ctc.cc new file mode 100644 index 0000000..7920421 --- /dev/null +++ b/kaldi-decoder/python/csrc/decodable-ctc.cc @@ -0,0 +1,17 @@ +// kaldi-decoder/python/csrc/decodable-ctc.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include 
"kaldi-decoder/python/csrc/decodable-ctc.h" + +#include "kaldi-decoder/csrc/decodable-ctc.h" + +namespace kaldi_decoder { + +void PybindDecodableCtc(py::module *m) { + using PyClass = DecodableCtc; + py::class_(*m, "DecodableCtc") + .def(py::init(), py::arg("feats")); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/python/csrc/decodable-ctc.h b/kaldi-decoder/python/csrc/decodable-ctc.h new file mode 100644 index 0000000..5ca43d6 --- /dev/null +++ b/kaldi-decoder/python/csrc/decodable-ctc.h @@ -0,0 +1,16 @@ +// kaldi-decoder/python/csrc/decodable-ctc.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef KALDI_DECODER_PYTHON_CSRC_DECODABLE_CTC_H_ +#define KALDI_DECODER_PYTHON_CSRC_DECODABLE_CTC_H_ + +#include "kaldi-decoder/python/csrc/kaldi-decoder.h" + +namespace kaldi_decoder { + +void PybindDecodableCtc(py::module *m); + +} + +#endif // KALDI_DECODER_PYTHON_CSRC_DECODABLE_CTC_H_ diff --git a/kaldi-decoder/python/csrc/decodable-itf.cc b/kaldi-decoder/python/csrc/decodable-itf.cc new file mode 100644 index 0000000..914143d --- /dev/null +++ b/kaldi-decoder/python/csrc/decodable-itf.cc @@ -0,0 +1,55 @@ +// kaldi-decoder/python/csrc/decodable-itf.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "kaldi-decoder/python/csrc/decodable-itf.h" + +#include "kaldi-decoder/csrc/decodable-itf.h" + +namespace kaldi_decoder { + +namespace { + +// +// https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtual-functions-in-python +// https://pybind11.readthedocs.io/en/stable/reference.html#inheritance +class PyDecodableInterface : public DecodableInterface { + public: + using DecodableInterface::DecodableInterface; + + float LogLikelihood(int32_t frame, int32_t index) override { + PYBIND11_OVERRIDE_PURE_NAME(float, DecodableInterface, "log_likelihood", + LogLikelihood, frame, index); + } + + bool IsLastFrame(int32_t frame) const override { + PYBIND11_OVERRIDE_PURE_NAME(bool, DecodableInterface, "is_last_frame", + IsLastFrame, frame); + } + + int32_t NumFramesReady() const override { + PYBIND11_OVERRIDE_NAME(int32_t, DecodableInterface, "num_frames_ready", + NumFramesReady); + } + + int32_t NumIndices() const override { + PYBIND11_OVERRIDE_PURE_NAME(int32_t, DecodableInterface, "num_indices", + NumIndices); + } +}; + +} // namespace + +void PybindDecodableItf(py::module *m) { + using PyClass = DecodableInterface; + + py::class_(*m, "DecodableInterface") + .def(py::init<>()) + .def("log_likelihood", &PyClass::LogLikelihood, py::arg("frame"), + py::arg("index")) + .def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame")) + .def("num_frames_ready", &PyClass::NumFramesReady) + .def("num_indices", &PyClass::NumIndices); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/python/csrc/decodable-itf.h b/kaldi-decoder/python/csrc/decodable-itf.h new file mode 100644 index 0000000..f9f8b35 --- /dev/null +++ b/kaldi-decoder/python/csrc/decodable-itf.h @@ -0,0 +1,16 @@ +// kaldi-decoder/python/csrc/decodable-itf.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef KALDI_DECODER_PYTHON_CSRC_DECODABLE_ITF_H_ +#define KALDI_DECODER_PYTHON_CSRC_DECODABLE_ITF_H_ + +#include "kaldi-decoder/python/csrc/kaldi-decoder.h" + +namespace kaldi_decoder { + +void PybindDecodableItf(py::module *m); + +} + +#endif // KALDI_DECODER_PYTHON_CSRC_DECODABLE_ITF_H_ diff --git a/kaldi-decoder/python/csrc/faster-decoder.cc b/kaldi-decoder/python/csrc/faster-decoder.cc new file mode 100644 index 0000000..2f7fb80 --- /dev/null +++ 
b/kaldi-decoder/python/csrc/faster-decoder.cc @@ -0,0 +1,55 @@ +// kaldi-decoder/python/csrc/faster-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "kaldi-decoder/python/csrc/faster-decoder.h" + +#include +#include + +#include "kaldi-decoder/csrc/faster-decoder.h" + +namespace kaldi_decoder { + +static void PybindFasterDecoderOptions(py::module *m) { + using PyClass = FasterDecoderOptions; + py::class_(*m, "FasterDecoderOptions") + .def(py::init(), + py::arg("beam") = 16.0, + py::arg("max_active") = std::numeric_limits::max(), + py::arg("min_active") = 20, py::arg("beam_delta") = 0.5, + py::arg("hash_ratio") = 2.0) + .def_readwrite("beam", &PyClass::beam) + .def_readwrite("max_active", &PyClass::max_active) + .def_readwrite("min_active", &PyClass::min_active) + .def_readwrite("beam_delta", &PyClass::beam_delta) + .def_readwrite("hash_ratio", &PyClass::hash_ratio) + .def("__str__", &PyClass::ToString); +} + +void PybindFasterDecoder(py::module *m) { + PybindFasterDecoderOptions(m); + using PyClass = FasterDecoder; + py::class_(*m, "FasterDecoder") + .def(py::init &, + const FasterDecoderOptions &>(), + py::arg("fst"), py::arg("config")) + .def("set_options", &PyClass::SetOptions, py::arg("config")) + .def("decode", &PyClass::Decode, py::arg("decodable")) + .def("reached_final", &PyClass::ReachedFinal) + .def( + "get_best_path", + [](PyClass &self, bool use_final_probs) + -> std::pair> { + fst::VectorFst fst; + bool ok = self.GetBestPath(&fst, use_final_probs); + return std::make_pair(ok, fst); + }, + py::arg("use_final_probs") = true) + .def("init_decoding", &PyClass::InitDecoding) + .def("advanced_decoding", &PyClass::AdvanceDecoding, py::arg("decodable"), + py::arg("max_num_frames") = -1) + .def("num_frames_decoded", &PyClass::NumFramesDecoded); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/python/csrc/faster-decoder.h b/kaldi-decoder/python/csrc/faster-decoder.h new file mode 100644 index 0000000..dbafed9 --- /dev/null +++ b/kaldi-decoder/python/csrc/faster-decoder.h @@ -0,0 +1,16 @@ +// kaldi-decoder/python/csrc/faster-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef KALDI_DECODER_PYTHON_CSRC_FASTER_DECODER_H_ +#define KALDI_DECODER_PYTHON_CSRC_FASTER_DECODER_H_ + +#include "kaldi-decoder/python/csrc/kaldi-decoder.h" + +namespace kaldi_decoder { + +void PybindFasterDecoder(py::module *m); + +} + +#endif // KALDI_DECODER_PYTHON_CSRC_FASTER_DECODER_H_ diff --git a/kaldi-decoder/python/csrc/kaldi-decoder.cc b/kaldi-decoder/python/csrc/kaldi-decoder.cc new file mode 100644 index 0000000..1b00418 --- /dev/null +++ b/kaldi-decoder/python/csrc/kaldi-decoder.cc @@ -0,0 +1,20 @@ +// kaldi-decoder/python/csrc/kaldi-decoder.cc +// +// Copyright (c) 2022 Xiaomi Corporation + +#include "kaldi-decoder/python/csrc/kaldi-decoder.h" + +#include "kaldi-decoder/python/csrc/decodable-ctc.h" +#include "kaldi-decoder/python/csrc/decodable-itf.h" +#include "kaldi-decoder/python/csrc/faster-decoder.h" + +namespace kaldi_decoder { + +PYBIND11_MODULE(_kaldi_decoder, m) { + m.doc() = "pybind11 binding of kaldi-decoder"; + PybindDecodableItf(&m); + PybindFasterDecoder(&m); + PybindDecodableCtc(&m); +} + +} // namespace kaldi_decoder diff --git a/kaldi-decoder/python/csrc/kaldi-decoder.h b/kaldi-decoder/python/csrc/kaldi-decoder.h new file mode 100644 index 0000000..dd2ec29 --- /dev/null +++ b/kaldi-decoder/python/csrc/kaldi-decoder.h @@ -0,0 +1,14 @@ +// kaldi-decoder/python/csrc/kaldi-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef 
KALDI_DECODER_PYTHON_CSRC_KALDI_DECODER_H_ +#define KALDI_DECODER_PYTHON_CSRC_KALDI_DECODER_H_ + +#include "pybind11/eigen.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +#endif // KALDI_DECODER_PYTHON_CSRC_KALDI_DECODER_H_ diff --git a/kaldi-decoder/python/kaldi_decoder/__init__.py b/kaldi-decoder/python/kaldi_decoder/__init__.py new file mode 100644 index 0000000..b70339a --- /dev/null +++ b/kaldi-decoder/python/kaldi_decoder/__init__.py @@ -0,0 +1,6 @@ +from _kaldi_decoder import ( + FasterDecoderOptions, + FasterDecoder, + DecodableInterface, + DecodableCtc, +) diff --git a/scripts/check_style_cpplint.sh b/scripts/check_style_cpplint.sh new file mode 100755 index 0000000..c14a80e --- /dev/null +++ b/scripts/check_style_cpplint.sh @@ -0,0 +1,112 @@ +#!/bin/bash +# +# Usage: +# +# (1) To check files of the last commit +# ./scripts/check_style_cpplint.sh +# +# (2) To check changed files not committed yet +# ./scripts/check_style_cpplint.sh 1 +# +# (3) To check all files in the project +# ./scripts/check_style_cpplint.sh 2 + + +cpplint_version="1.5.4" +cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd) +kaldi_decoder_dir=$(cd $cur_dir/.. && pwd) + +build_dir=$kaldi_decoder_dir/build +mkdir -p $build_dir + +cpplint_src=$build_dir/cpplint-${cpplint_version}/cpplint.py + +if [ ! -d "$build_dir/cpplint-${cpplint_version}" ]; then + pushd $build_dir + if command -v wget &> /dev/null; then + wget https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz + elif command -v curl &> /dev/null; then + curl -O -SL https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz + else + echo "Please install wget or curl to download cpplint" + exit 1 + fi + tar xf ${cpplint_version}.tar.gz + rm ${cpplint_version}.tar.gz + + # cpplint will report the following error for: __host__ __device__ ( + # + # Extra space before ( in function call [whitespace/parens] [4] + # + # the following patch disables the above error + sed -i "3490i\ not Search(r'__host__ __device__\\\s+\\\(', fncall) and" $cpplint_src + popd +fi + +source $kaldi_decoder_dir/scripts/utils.sh + +# return true if the given file is a c++ source file +# return false otherwise +function is_source_code_file() { + case "$1" in + *.cc|*.h|*.cu) + echo true;; + *) + echo false;; + esac +} + +function check_style() { + python3 $cpplint_src $1 || abort $1 +} + +function check_last_commit() { + files=$(git diff HEAD^1 --name-only --diff-filter=ACDMRUXB) + echo $files +} + +function check_current_dir() { + files=$(git status -s -uno --porcelain | awk '{ + if (NF == 4) { + # a file has been renamed + print $NF + } else { + print $2 + }}') + + echo $files +} + +function do_check() { + case "$1" in + 1) + echo "Check changed files" + files=$(check_current_dir) + ;; + 2) + echo "Check all files" + files=$(find $kaldi_decoder_dir/kaldi-decoder -name "*.h" -o -name "*.cc" -o -name "*.cu") + ;; + *) + echo "Check last commit" + files=$(check_last_commit) + ;; + esac + + for f in $files; do + need_check=$(is_source_code_file $f) + if $need_check; then + [[ -f $f ]] && check_style $f + fi + done +} + +function main() { + do_check $1 + + ok "Great! Style check passed!" 
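+  # do_check never returns on failure: check_style() runs cpplint with
+  # "|| abort", and abort() (from scripts/utils.sh) exits the script, so the
+  # success message above prints only when every checked file passed.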
+} + +cd $kaldi_decoder_dir + +main $1 diff --git a/scripts/utils.sh b/scripts/utils.sh new file mode 100644 index 0000000..fb424a7 --- /dev/null +++ b/scripts/utils.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +default='\033[0m' +bold='\033[1m' +red='\033[31m' +green='\033[32m' + +function ok() { + printf "${bold}${green}[OK]${default} $1\n" +} + +function error() { + printf "${bold}${red}[FAILED]${default} $1\n" +} + +function abort() { + printf "${bold}${red}[FAILED]${default} $1\n" + exit 1 +} diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..98b5c62 --- /dev/null +++ b/setup.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import os +import re +import sys +from pathlib import Path + +import setuptools + +from cmake.cmake_extension import ( + BuildExtension, + bdist_wheel, + cmake_extension, + is_windows, +) + + +def read_long_description(): + with open("README.md", encoding="utf8") as f: + readme = f.read() + return readme + + +def get_package_version(): + with open("CMakeLists.txt") as f: + content = f.read() + + match = re.search(r"set\(KALDI_DECODER_VERSION (.*)\)", content) + latest_version = match.group(1).strip('"') + return latest_version + + +package_name = "kaldi-decoder" + +with open("kaldi-decoder/python/kaldi_decoder/__init__.py", "a") as f: + f.write(f"__version__ = '{get_package_version()}'\n") + +setuptools.setup( + name=package_name, + python_requires=">=3.6", + version=get_package_version(), + author="The next-gen Kaldi development team", + author_email="csukuangfj@gmail.com", + package_dir={ + "kaldi_decoder": "kaldi-decoder/python/kaldi_decoder", + }, + packages=["kaldi_decoder"], + url="https://github.com/k2-fsa/kaldi-decoder", + long_description=read_long_description(), + long_description_content_type="text/markdown", + ext_modules=[cmake_extension("_kaldi_decoder")], + cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel}, + zip_safe=False, + classifiers=[ + "Programming Language :: C++", + "Programming Language :: Python", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + license="Apache licensed, as found in the LICENSE file", +) + +with open("kaldi-decoder/python/kaldi_decoder/__init__.py", "r") as f: + lines = f.readlines() + +with open("kaldi-decoder/python/kaldi_decoder/__init__.py", "w") as f: + for line in lines: + if "__version__" in line: + # skip __version__ = "x.x.x" + continue + f.write(line)
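A quick sanity check of the new Python bindings (not part of any file in this
patch; it only uses names that the pybind11 code above registers):

    from kaldi_decoder import FasterDecoderOptions

    opts = FasterDecoderOptions(beam=10.0, max_active=3000)
    opts.min_active = 200   # every field is exposed via def_readwrite
    print(opts)             # __str__ is bound to FasterDecoderOptions::ToString()

Building a full FasterDecoder additionally needs a decoding graph (H/HL/HLG)
and a DecodableCtc, whose binding takes a single "feats" matrix; that
end-to-end flow is omitted here because it depends on assets this patch does
not ship.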