diff --git a/README.md b/README.md
index 03c17c3a0..10671a60f 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,7 @@ Want to jump straight in? Get started with one of our sample applications/templa
| Semantic Image Search (server-side) | Search for images with text (Supabase) | [code](./examples/semantic-image-search/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search) |
| Vanilla JavaScript | In-browser object detection | [video](https://scrimba.com/scrim/cKm9bDAg), [code](./examples/vanilla-js/), [demo](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector) |
| React | Multilingual translation website | [code](./examples/react-translator/), [demo](https://huggingface.co/spaces/Xenova/react-translator) |
+| Text to speech (client-side) | In-browser speech synthesis | [code](./examples/text-to-speech-client/), [demo](https://huggingface.co/spaces/Xenova/text-to-speech-client) |
| Browser extension | Text classification extension | [code](./examples/extension/) |
| Electron | Text classification application | [code](./examples/electron/) |
| Next.js (client-side) | Sentiment analysis (in-browser inference) | [code](./examples/next-client/), [demo](https://huggingface.co/spaces/Xenova/next-example-app) |
@@ -222,7 +223,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
| [Audio Classification](https://huggingface.co/tasks/audio-classification) | `audio-classification` | Assigning a label or class to a given audio. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AudioClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=audio-classification&library=transformers.js) |
| [Audio-to-Audio](https://huggingface.co/tasks/audio-to-audio) | n/a | Generating audio from an input audio source. | ❌ |
| [Automatic Speech Recognition](https://huggingface.co/tasks/automatic-speech-recognition) | `automatic-speech-recognition` | Transcribing a given audio into text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AutomaticSpeechRecognitionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&library=transformers.js) |
-| [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | n/a | Generating natural-sounding speech given text input. | ❌ |
+| [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | `text-to-speech` or `text-to-audio` | Generating natural-sounding speech given text input. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.TextToAudioPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=text-to-audio&library=transformers.js) |
#### Tabular
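The new task row above is powered by the `TextToAudioPipeline` added in `src/pipelines.js` below. In the browser, the raw samples it returns can be played back with the Web Audio API; a minimal sketch (only the `pipeline()` call comes from this PR, the playback wiring is illustrative):

```js
import { pipeline } from '@xenova/transformers';

// Speaker embeddings are required for SpeechT5; this URL is the one used in the pipeline docs below.
const speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin';
const synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts', { quantized: false });

const { audio, sampling_rate } = await synthesizer('Hello, my dog is cute', { speaker_embeddings });

// Wrap the Float32Array of samples in an AudioBuffer and play it.
const ctx = new AudioContext({ sampleRate: sampling_rate });
const buffer = ctx.createBuffer(1, audio.length, sampling_rate);
buffer.copyToChannel(audio, 0);
const source = ctx.createBufferSource();
source.buffer = buffer;
source.connect(ctx.destination);
source.start();
```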
diff --git a/docs/snippets/3_examples.snippet b/docs/snippets/3_examples.snippet
index 009494342..6af3da7ec 100644
--- a/docs/snippets/3_examples.snippet
+++ b/docs/snippets/3_examples.snippet
@@ -9,6 +9,7 @@ Want to jump straight in? Get started with one of our sample applications/templa
| Semantic Image Search (server-side) | Search for images with text (Supabase) | [code](./examples/semantic-image-search/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search) |
| Vanilla JavaScript | In-browser object detection | [video](https://scrimba.com/scrim/cKm9bDAg), [code](./examples/vanilla-js/), [demo](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector) |
| React | Multilingual translation website | [code](./examples/react-translator/), [demo](https://huggingface.co/spaces/Xenova/react-translator) |
+| Text to speech (client-side) | In-browser speech synthesis | [code](./examples/text-to-speech-client/), [demo](https://huggingface.co/spaces/Xenova/text-to-speech-client) |
| Browser extension | Text classification extension | [code](./examples/extension/) |
| Electron | Text classification application | [code](./examples/electron/) |
| Next.js (client-side) | Sentiment analysis (in-browser inference) | [code](./examples/next-client/), [demo](https://huggingface.co/spaces/Xenova/next-example-app) |
diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet
index e2da1636a..002634869 100644
--- a/docs/snippets/5_supported-tasks.snippet
+++ b/docs/snippets/5_supported-tasks.snippet
@@ -38,7 +38,7 @@
| [Audio Classification](https://huggingface.co/tasks/audio-classification) | `audio-classification` | Assigning a label or class to a given audio. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AudioClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=audio-classification&library=transformers.js) |
| [Audio-to-Audio](https://huggingface.co/tasks/audio-to-audio) | n/a | Generating audio from an input audio source. | ❌ |
| [Automatic Speech Recognition](https://huggingface.co/tasks/automatic-speech-recognition) | `automatic-speech-recognition` | Transcribing a given audio into text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AutomaticSpeechRecognitionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&library=transformers.js) |
-| [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | n/a | Generating natural-sounding speech given text input. | ❌ |
+| [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | `text-to-speech` or `text-to-audio` | Generating natural-sounding speech given text input. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.TextToAudioPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=text-to-audio&library=transformers.js) |
#### Tabular
diff --git a/examples/text-to-speech-client/.eslintrc.cjs b/examples/text-to-speech-client/.eslintrc.cjs
new file mode 100644
index 000000000..4dcb43901
--- /dev/null
+++ b/examples/text-to-speech-client/.eslintrc.cjs
@@ -0,0 +1,20 @@
+module.exports = {
+ root: true,
+ env: { browser: true, es2020: true },
+ extends: [
+ 'eslint:recommended',
+ 'plugin:react/recommended',
+ 'plugin:react/jsx-runtime',
+ 'plugin:react-hooks/recommended',
+ ],
+ ignorePatterns: ['dist', '.eslintrc.cjs'],
+ parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
+ settings: { react: { version: '18.2' } },
+ plugins: ['react-refresh'],
+ rules: {
+ 'react-refresh/only-export-components': [
+ 'warn',
+ { allowConstantExport: true },
+ ],
+ },
+}
diff --git a/examples/text-to-speech-client/.gitignore b/examples/text-to-speech-client/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/text-to-speech-client/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/text-to-speech-client/README.md b/examples/text-to-speech-client/README.md
new file mode 100644
index 000000000..f768e33fc
--- /dev/null
+++ b/examples/text-to-speech-client/README.md
@@ -0,0 +1,8 @@
+# React + Vite
+
+This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
+
+Currently, two official plugins are available:
+
+- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
+- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
diff --git a/examples/text-to-speech-client/index.html b/examples/text-to-speech-client/index.html
new file mode 100644
index 000000000..26748c68f
--- /dev/null
+++ b/examples/text-to-speech-client/index.html
@@ -0,0 +1,12 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>Transformers.js - Text-to-speech demo</title>
+  </head>
+  <body>
+    <div id="root"></div>
+    <script type="module" src="/src/main.jsx"></script>
+  </body>
+</html>
diff --git a/examples/text-to-speech-client/package.json b/examples/text-to-speech-client/package.json
new file mode 100644
index 000000000..49949efa2
--- /dev/null
+++ b/examples/text-to-speech-client/package.json
@@ -0,0 +1,30 @@
+{
+ "name": "text-to-speech-client",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
+ "preview": "vite preview"
+ },
+ "dependencies": {
+ "@xenova/transformers": "^2.7.0",
+ "react": "^18.2.0",
+ "react-dom": "^18.2.0"
+ },
+ "devDependencies": {
+ "@types/react": "^18.2.15",
+ "@types/react-dom": "^18.2.7",
+ "@vitejs/plugin-react": "^4.0.3",
+ "autoprefixer": "^10.4.16",
+ "eslint": "^8.45.0",
+ "eslint-plugin-react": "^7.32.2",
+ "eslint-plugin-react-hooks": "^4.6.0",
+ "eslint-plugin-react-refresh": "^0.4.3",
+ "postcss": "^8.4.31",
+ "tailwindcss": "^3.3.3",
+ "vite": "^4.4.5"
+ }
+}
diff --git a/examples/text-to-speech-client/postcss.config.js b/examples/text-to-speech-client/postcss.config.js
new file mode 100644
index 000000000..2e7af2b7f
--- /dev/null
+++ b/examples/text-to-speech-client/postcss.config.js
@@ -0,0 +1,6 @@
+export default {
+ plugins: {
+ tailwindcss: {},
+ autoprefixer: {},
+ },
+}
diff --git a/examples/text-to-speech-client/src/App.jsx b/examples/text-to-speech-client/src/App.jsx
new file mode 100644
index 000000000..cfa53fb9a
--- /dev/null
+++ b/examples/text-to-speech-client/src/App.jsx
@@ -0,0 +1,162 @@
+import React, { useState, useEffect, useRef } from 'react';
+
+import AudioPlayer from './components/AudioPlayer';
+import Progress from './components/Progress';
+import { SPEAKERS, DEFAULT_SPEAKER } from './constants';
+
+const App = () => {
+
+ // Model loading
+ const [ready, setReady] = useState(null);
+ const [disabled, setDisabled] = useState(false);
+ const [progressItems, setProgressItems] = useState([]);
+
+ // Inputs and outputs
+ const [text, setText] = useState('I love Hugging Face!');
+ const [selectedSpeaker, setSelectedSpeaker] = useState(DEFAULT_SPEAKER);
+ const [output, setOutput] = useState(null);
+
+ // Create a reference to the worker object.
+ const worker = useRef(null);
+
+ // We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
+ useEffect(() => {
+ if (!worker.current) {
+ // Create the worker if it does not yet exist.
+ worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+ type: 'module'
+ });
+ }
+
+ // Create a callback function for messages from the worker thread.
+ const onMessageReceived = (e) => {
+ switch (e.data.status) {
+ case 'initiate':
+ // Model file start load: add a new progress item to the list.
+ setReady(false);
+ setProgressItems(prev => [...prev, e.data]);
+ break;
+
+ case 'progress':
+ // Model file progress: update one of the progress items.
+ setProgressItems(
+ prev => prev.map(item => {
+ if (item.file === e.data.file) {
+ return { ...item, progress: e.data.progress }
+ }
+ return item;
+ })
+ );
+ break;
+
+ case 'done':
+ // Model file loaded: remove the progress item from the list.
+ setProgressItems(
+ prev => prev.filter(item => item.file !== e.data.file)
+ );
+ break;
+
+ case 'ready':
+ // Pipeline ready: the worker is ready to accept messages.
+ setReady(true);
+ break;
+
+ case 'complete':
+ // Generation complete: re-enable the "Generate" button
+ setDisabled(false);
+
+ const blobUrl = URL.createObjectURL(e.data.output);
+ setOutput(blobUrl);
+ break;
+ }
+ };
+
+ // Attach the callback function as an event listener.
+ worker.current.addEventListener('message', onMessageReceived);
+
+ // Define a cleanup function for when the component is unmounted.
+ return () => worker.current.removeEventListener('message', onMessageReceived);
+ });
+
+
+ const handleGenerateSpeech = () => {
+ setDisabled(true);
+ worker.current.postMessage({
+ text,
+ speaker_id: selectedSpeaker,
+ });
+ };
+
+ const isLoading = ready === false;
+    return (
+        <div className='min-h-screen flex items-center justify-center bg-gray-100'>
+            {/* Overlay shown while model files are downloading (class names here are representative) */}
+            <div className='absolute z-50 top-0 left-0 w-full h-full flex flex-col items-center justify-center bg-black bg-opacity-80 text-white transition-all' style={{ opacity: isLoading ? 1 : 0, pointerEvents: isLoading ? 'all' : 'none' }}>
+                {isLoading && (
+                    <label className='text-xl p-3'>Loading models... (only run once)</label>
+                )}
+                {progressItems.map(data => (
+                    <Progress key={data.file} text={data.file} percentage={data.progress} />
+                ))}
+            </div>
+
+            <div className='bg-white p-8 rounded-lg shadow-lg w-full max-w-xl m-2'>
+                <h1 className='text-3xl font-semibold text-gray-800 mb-4 text-center'>In-browser Text to Speech</h1>
+
+                {/* Text to synthesize */}
+                <textarea
+                    value={text}
+                    onChange={e => setText(e.target.value)}
+                    className='w-full border border-gray-300 rounded-md p-2 mb-4'
+                    rows='4'
+                    placeholder='Enter text...'
+                ></textarea>
+
+                {/* Speaker (x-vector) selection */}
+                <select
+                    value={selectedSpeaker}
+                    onChange={e => setSelectedSpeaker(e.target.value)}
+                    className='w-full border border-gray-300 rounded-md p-2 mb-4'
+                >
+                    {Object.entries(SPEAKERS).map(([name, id]) => (
+                        <option key={id} value={id}>{name}</option>
+                    ))}
+                </select>
+
+                <button
+                    onClick={handleGenerateSpeech}
+                    disabled={disabled}
+                    className='w-full py-2 px-4 bg-blue-500 hover:bg-blue-600 text-white rounded-md disabled:opacity-50'
+                >
+                    {disabled ? 'Generating...' : 'Generate'}
+                </button>
+
+                {output && <AudioPlayer audioUrl={output} mimeType={'audio/wav'} />}
+            </div>
+        </div>
+    );
+};
+
+export default App;
diff --git a/examples/text-to-speech-client/src/components/AudioPlayer.jsx b/examples/text-to-speech-client/src/components/AudioPlayer.jsx
new file mode 100644
index 000000000..a6d2daf1f
--- /dev/null
+++ b/examples/text-to-speech-client/src/components/AudioPlayer.jsx
@@ -0,0 +1,26 @@
+import { useEffect, useRef } from "react";
+
+export default function AudioPlayer({ audioUrl, mimeType }) {
+ const audioPlayer = useRef(null);
+ const audioSource = useRef(null);
+
+ // Updates src when url changes
+ useEffect(() => {
+ if (audioPlayer.current && audioSource.current) {
+ audioSource.current.src = audioUrl;
+ audioPlayer.current.load();
+ }
+ }, [audioUrl]);
+
+    return (
+        <div className='flex relative z-10 my-4 w-full'>
+            <audio
+                ref={audioPlayer}
+                controls
+                className='w-full h-14 rounded-lg bg-white shadow-xl shadow-black/5 ring-1 ring-slate-700/10'
+            >
+                <source ref={audioSource} type={mimeType}></source>
+            </audio>
+        </div>
+    );
+}
\ No newline at end of file
diff --git a/examples/text-to-speech-client/src/components/Progress.jsx b/examples/text-to-speech-client/src/components/Progress.jsx
new file mode 100644
index 000000000..efaaf0a9a
--- /dev/null
+++ b/examples/text-to-speech-client/src/components/Progress.jsx
@@ -0,0 +1,12 @@
+
+export default function Progress({ text, percentage }) {
+ percentage ??= 0;
+    return (
+        <div className='w-full bg-gray-100 text-left rounded-lg overflow-hidden mb-0.5'>
+            <div className='bg-blue-400 whitespace-nowrap px-1 text-sm' style={{ width: `${percentage}%` }}>
+                {text} ({`${percentage.toFixed(2)}%`})
+            </div>
+        </div>
+    );
+}
+
diff --git a/examples/text-to-speech-client/src/constants.js b/examples/text-to-speech-client/src/constants.js
new file mode 100644
index 000000000..ef6d848af
--- /dev/null
+++ b/examples/text-to-speech-client/src/constants.js
@@ -0,0 +1,11 @@
+export const SPEAKERS = {
+ "US female 1": "cmu_us_slt_arctic-wav-arctic_a0001",
+ "US female 2": "cmu_us_clb_arctic-wav-arctic_a0001",
+ "US male 1": "cmu_us_bdl_arctic-wav-arctic_a0003",
+ "US male 2": "cmu_us_rms_arctic-wav-arctic_a0003",
+ "Canadian male": "cmu_us_jmk_arctic-wav-arctic_a0002",
+ "Scottish male": "cmu_us_awb_arctic-wav-arctic_b0002",
+ "Indian male": "cmu_us_ksp_arctic-wav-arctic_a0007",
+}
+
+export const DEFAULT_SPEAKER = "cmu_us_slt_arctic-wav-arctic_a0001";
diff --git a/examples/text-to-speech-client/src/index.css b/examples/text-to-speech-client/src/index.css
new file mode 100644
index 000000000..2ea01764b
--- /dev/null
+++ b/examples/text-to-speech-client/src/index.css
@@ -0,0 +1,21 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
+
+:root {
+ font-family: Inter, system-ui, Avenir, Helvetica, Arial, sans-serif;
+ line-height: 1.5;
+ font-weight: 400;
+ color: #213547;
+ background-color: #ffffff;
+
+ font-synthesis: none;
+ text-rendering: optimizeLegibility;
+ -webkit-font-smoothing: antialiased;
+ -moz-osx-font-smoothing: grayscale;
+ -webkit-text-size-adjust: 100%;
+}
+
+audio::-webkit-media-controls-panel {
+ background-color: white;
+}
\ No newline at end of file
diff --git a/examples/text-to-speech-client/src/main.jsx b/examples/text-to-speech-client/src/main.jsx
new file mode 100644
index 000000000..54b39dd1d
--- /dev/null
+++ b/examples/text-to-speech-client/src/main.jsx
@@ -0,0 +1,10 @@
+import React from 'react'
+import ReactDOM from 'react-dom/client'
+import App from './App.jsx'
+import './index.css'
+
+ReactDOM.createRoot(document.getElementById('root')).render(
+  <React.StrictMode>
+    <App />
+  </React.StrictMode>,
+)
diff --git a/examples/text-to-speech-client/src/utils.js b/examples/text-to-speech-client/src/utils.js
new file mode 100644
index 000000000..b23f88bba
--- /dev/null
+++ b/examples/text-to-speech-client/src/utils.js
@@ -0,0 +1,47 @@
+// Adapted from https://www.npmjs.com/package/audiobuffer-to-wav
+
+export function encodeWAV(samples) {
+ let offset = 44;
+ const buffer = new ArrayBuffer(offset + samples.length * 4);
+ const view = new DataView(buffer);
+ const sampleRate = 16000;
+
+ /* RIFF identifier */
+ writeString(view, 0, 'RIFF')
+ /* RIFF chunk length */
+ view.setUint32(4, 36 + samples.length * 4, true)
+ /* RIFF type */
+ writeString(view, 8, 'WAVE')
+ /* format chunk identifier */
+ writeString(view, 12, 'fmt ')
+ /* format chunk length */
+ view.setUint32(16, 16, true)
+ /* sample format (raw) */
+ view.setUint16(20, 3, true)
+ /* channel count */
+ view.setUint16(22, 1, true)
+ /* sample rate */
+ view.setUint32(24, sampleRate, true)
+ /* byte rate (sample rate * block align) */
+ view.setUint32(28, sampleRate * 4, true)
+ /* block align (channel count * bytes per sample) */
+ view.setUint16(32, 4, true)
+ /* bits per sample */
+ view.setUint16(34, 32, true)
+ /* data chunk identifier */
+ writeString(view, 36, 'data')
+ /* data chunk length */
+ view.setUint32(40, samples.length * 4, true)
+
+ for (let i = 0; i < samples.length; ++i, offset += 4) {
+ view.setFloat32(offset, samples[i], true)
+ }
+
+ return buffer
+}
+
+function writeString(view, offset, string) {
+ for (let i = 0; i < string.length; ++i) {
+ view.setUint8(offset + i, string.charCodeAt(i))
+ }
+}
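`encodeWAV` hard-codes a mono, 32-bit float header at 16 kHz, which matches the SpeechT5 output consumed in `worker.js`. A small sketch of exercising it on its own (the sine-wave input and output file name are illustrative):

```js
import fs from 'fs';
import { encodeWAV } from './utils.js';

// One second of a 440 Hz tone at the 16 kHz rate the header assumes.
const sampleRate = 16000;
const samples = new Float32Array(sampleRate);
for (let i = 0; i < samples.length; ++i) {
    samples[i] = 0.5 * Math.sin(2 * Math.PI * 440 * i / sampleRate);
}

// The result is an ArrayBuffer: a 44-byte RIFF header followed by the raw float samples.
const wav = encodeWAV(samples);
fs.writeFileSync('tone.wav', Buffer.from(wav));
```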
diff --git a/examples/text-to-speech-client/src/worker.js b/examples/text-to-speech-client/src/worker.js
new file mode 100644
index 000000000..76b8f76ef
--- /dev/null
+++ b/examples/text-to-speech-client/src/worker.js
@@ -0,0 +1,97 @@
+
+import { env, Tensor, AutoTokenizer, SpeechT5ForTextToSpeech, SpeechT5HifiGan } from '@xenova/transformers';
+import { encodeWAV } from './utils';
+
+// Disable local model checks
+env.allowLocalModels = false;
+
+
+// Use the Singleton pattern to enable lazy construction of the pipeline.
+class MyTextToSpeechPipeline {
+
+ static BASE_URL = 'https://huggingface.co/datasets/Xenova/cmu-arctic-xvectors-extracted/resolve/main/';
+
+ static model_id = 'Xenova/speecht5_tts';
+ static vocoder_id = 'Xenova/speecht5_hifigan';
+
+ static tokenizer_instance = null;
+ static model_instance = null;
+ static vocoder_instance = null;
+
+ static async getInstance(progress_callback = null) {
+ if (this.tokenizer_instance === null) {
+ this.tokenizer_instance = AutoTokenizer.from_pretrained(this.model_id, { progress_callback });
+ }
+
+ if (this.model_instance === null) {
+ this.model_instance = SpeechT5ForTextToSpeech.from_pretrained(this.model_id, {
+ quantized: false,
+ progress_callback,
+ });
+ }
+
+ if (this.vocoder_instance === null) {
+ this.vocoder_instance = SpeechT5HifiGan.from_pretrained(this.vocoder_id, {
+ quantized: false,
+ progress_callback,
+ });
+ }
+
+        // Wait for all three to load, then notify the main thread that the pipeline is ready.
+        const result = await Promise.all([
+            this.tokenizer_instance,
+            this.model_instance,
+            this.vocoder_instance,
+        ]);
+        self.postMessage({
+            status: 'ready',
+        });
+        return result;
+ }
+
+ static async getSpeakerEmbeddings(speaker_id) {
+ // e.g., `cmu_us_awb_arctic-wav-arctic_a0001`
+ const speaker_embeddings_url = `${this.BASE_URL}${speaker_id}.bin`;
+ const speaker_embeddings = new Tensor(
+ 'float32',
+ new Float32Array(await (await fetch(speaker_embeddings_url)).arrayBuffer()),
+ [1, 512]
+ )
+ return speaker_embeddings;
+ }
+}
+
+// Mapping of cached speaker embeddings
+const speaker_embeddings_cache = new Map();
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (event) => {
+ // Load the pipeline
+ const [tokenizer, model, vocoder] = await MyTextToSpeechPipeline.getInstance(x => {
+ // We also add a progress callback so that we can track model loading.
+ self.postMessage(x);
+ });
+
+ // Tokenize the input
+ const { input_ids } = tokenizer(event.data.text);
+
+ // Load the speaker embeddings
+ let speaker_embeddings = speaker_embeddings_cache.get(event.data.speaker_id);
+ if (speaker_embeddings === undefined) {
+ speaker_embeddings = await MyTextToSpeechPipeline.getSpeakerEmbeddings(event.data.speaker_id);
+ speaker_embeddings_cache.set(event.data.speaker_id, speaker_embeddings);
+ }
+
+ // Generate the waveform
+ const { waveform } = await model.generate_speech(input_ids, speaker_embeddings, { vocoder });
+
+ // Encode the waveform as a WAV file
+ const wav = encodeWAV(waveform.data);
+
+ // Send the output back to the main thread
+ self.postMessage({
+ status: 'complete',
+ output: new Blob([wav], { type: 'audio/wav' }),
+ });
+});
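The worker above expects `{ text, speaker_id }` messages and reports `initiate`/`progress`/`done`/`ready` status updates before posting a `complete` message containing a WAV `Blob`, which is how `App.jsx` drives it. A stripped-down sketch of that protocol outside React (the `<audio>` element lookup is illustrative):

```js
const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });

worker.addEventListener('message', (e) => {
    switch (e.data.status) {
        case 'progress':
            console.log(`${e.data.file}: ${e.data.progress}%`);
            break;
        case 'complete':
            // The worker returns a WAV Blob; play it through an <audio> element on the page.
            document.querySelector('audio').src = URL.createObjectURL(e.data.output);
            break;
    }
});

// Use one of the speaker ids defined in constants.js.
worker.postMessage({ text: 'Hello, my dog is cute!', speaker_id: 'cmu_us_slt_arctic-wav-arctic_a0001' });
```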
diff --git a/examples/text-to-speech-client/tailwind.config.js b/examples/text-to-speech-client/tailwind.config.js
new file mode 100644
index 000000000..d37737fc0
--- /dev/null
+++ b/examples/text-to-speech-client/tailwind.config.js
@@ -0,0 +1,12 @@
+/** @type {import('tailwindcss').Config} */
+export default {
+ content: [
+ "./index.html",
+ "./src/**/*.{js,ts,jsx,tsx}",
+ ],
+ theme: {
+ extend: {},
+ },
+ plugins: [],
+}
+
diff --git a/examples/text-to-speech-client/vite.config.js b/examples/text-to-speech-client/vite.config.js
new file mode 100644
index 000000000..5a33944a9
--- /dev/null
+++ b/examples/text-to-speech-client/vite.config.js
@@ -0,0 +1,7 @@
+import { defineConfig } from 'vite'
+import react from '@vitejs/plugin-react'
+
+// https://vitejs.dev/config/
+export default defineConfig({
+ plugins: [react()],
+})
diff --git a/scripts/convert.py b/scripts/convert.py
index b86abb9b6..0a68cd5ca 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -307,6 +307,17 @@ def main():
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)
+ elif config.model_type == 'speecht5':
+ # TODO allow user to specify vocoder path
+ export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"}
+
+ if tokenizer is not None:
+ from .extra.speecht5 import generate_tokenizer_json
+ tokenizer_json = generate_tokenizer_json(tokenizer)
+
+ with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
+ json.dump(tokenizer_json, fp, indent=4)
+
else:
pass # TODO
diff --git a/scripts/extra/speecht5.py b/scripts/extra/speecht5.py
new file mode 100644
index 000000000..e01992c5c
--- /dev/null
+++ b/scripts/extra/speecht5.py
@@ -0,0 +1,116 @@
+import json
+
+
+def generate_tokenizer_json(tokenizer):
+ vocab = tokenizer.get_vocab()
+
+ tokenizer_json = {
+ "version": "1.0",
+ "truncation": None,
+ "padding": None,
+ "added_tokens": [
+ {
+ "id": vocab[token],
+ "content": token,
+ "single_word": False,
+ "lstrip": False,
+ "rstrip": False,
+ "normalized": False,
+ "special": True
+ }
+ for token in vocab
+ if token.startswith('<') and token.endswith('>')
+ ],
+
+ "normalizer": {
+ "type": "Precompiled",
+ "precompiled_charsmap": None
+ },
+ "pre_tokenizer": {
+ "type": "Sequence",
+ "pretokenizers": [
+ {
+ "type": "WhitespaceSplit"
+ },
+ {
+ "type": "Metaspace",
+ "replacement": "▁",
+ "add_prefix_space": True
+ },
+ {
+ "type": "Split",
+ "pattern": {
+ "Regex": ""
+ },
+ "behavior": "Isolated",
+ "invert": False
+ }
+ ]
+ },
+ "post_processor": {
+ "type": "TemplateProcessing",
+ "single": [
+ {
+ "Sequence": {
+ "id": "A",
+ "type_id": 0
+ }
+ },
+ {
+ "SpecialToken": {
+ "id": "",
+ "type_id": 0
+ }
+ }
+ ],
+ "pair": [
+ {
+ "Sequence": {
+ "id": "A",
+ "type_id": 0
+ }
+ },
+ {
+ "SpecialToken": {
+ "id": "",
+ "type_id": 0
+ }
+ },
+ {
+ "Sequence": {
+ "id": "B",
+ "type_id": 0
+ }
+ },
+ {
+ "SpecialToken": {
+ "id": "",
+ "type_id": 0
+ }
+ }
+ ],
+ "special_tokens": {
+ "": {
+ "id": "",
+ "ids": [
+ 2
+ ],
+ "tokens": [
+ ""
+ ]
+ }
+ }
+ },
+ "decoder": {
+ "type": "Metaspace",
+ "replacement": "▁",
+ "add_prefix_space": True
+ },
+ 'model': {
+ # 'type': 'Char',
+ 'unk_id': 2,
+ "vocab": vocab
+ }
+ }
+
+ return tokenizer_json
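The `tokenizer.json` produced above is what `AutoTokenizer.from_pretrained` loads on the JavaScript side. Since SpeechT5 tokenizes at the character level (with a trailing `</s>` appended by the post-processor), a quick sanity check of the converted tokenizer might look like this:

```js
import { AutoTokenizer } from '@xenova/transformers';

// Load the converted tokenizer (the same one worker.js uses).
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/speecht5_tts');

// Character-level vocabulary: each word maps to several ids, plus the trailing </s> token.
const { input_ids } = tokenizer('Hello, my dog is cute');
console.log(input_ids.dims, input_ids.data);
```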
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index 56531ea8d..03711b806 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -375,6 +375,10 @@
# 'facebook/sam-vit-large',
# 'facebook/sam-vit-huge',
# ],
+ 'speecht5': [
+ # Text-to-speech
+ 'microsoft/speecht5_tts',
+ ],
'squeezebert': [
# Feature extraction
'squeezebert/squeezebert-uncased',
diff --git a/src/models.js b/src/models.js
index 8e5085f62..7c3b55964 100644
--- a/src/models.js
+++ b/src/models.js
@@ -74,6 +74,7 @@ import {
cat,
dynamicTimeWarping,
mean,
+ ones_like,
stack,
std_mean,
Tensor,
@@ -278,11 +279,7 @@ function prepareAttentionMask(self, tokens) {
)
return new Tensor('int64', data, tokens.dims)
} else {
- return new Tensor(
- 'int64',
- new BigInt64Array(tokens.data.length).fill(1n),
- tokens.dims
- )
+ return ones_like(tokens);
}
}
@@ -928,7 +925,9 @@ export class PreTrainedModel extends Callable {
const modelType = this.config.model_type;
const possibleInfo =
MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(modelType)
- ?? MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES.get(modelType)
+ ?? MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.get(modelType)
+ ?? MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.get(modelType)
+ // ?? MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES.get(modelType) // TODO
?? MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType);
if (possibleInfo) {
@@ -3563,6 +3562,151 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel {
}
}
+//////////////////////////////////////////////////
+// SpeechT5 models
+/**
+ * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
+ */
+export class SpeechT5PreTrainedModel extends PreTrainedModel { };
+
+/**
+ * The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.
+ */
+export class SpeechT5Model extends SpeechT5PreTrainedModel { };
+
+/**
+ * SpeechT5 Model with a speech encoder and a text decoder.
+ */
+export class SpeechT5ForSpeechToText extends SpeechT5PreTrainedModel { }
+
+/**
+ * SpeechT5 Model with a text encoder and a speech decoder.
+ */
+export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel {
+
+ /**
+ * Creates a new instance of the `SpeechT5ForTextToSpeech` class.
+ * @param {Object} config The model configuration.
+ * @param {any} session session for the model.
+ * @param {any} decoder_merged_session session for the decoder.
+ * @param {GenerationConfig} generation_config The generation configuration.
+ */
+ constructor(config, session, decoder_merged_session, generation_config) {
+ super(config, session);
+ this.decoder_merged_session = decoder_merged_session;
+ this.generation_config = generation_config;
+
+ this.num_decoder_layers = this.config.decoder_layers;
+ this.num_decoder_heads = this.config.decoder_attention_heads;
+ this.decoder_dim_kv = this.config.hidden_size / this.num_decoder_heads;
+
+ this.num_encoder_layers = this.config.encoder_layers;
+ this.num_encoder_heads = this.config.encoder_attention_heads;
+ this.encoder_dim_kv = this.config.hidden_size / this.num_encoder_heads;
+ }
+
+ /**
+ * @typedef {Object} SpeechOutput
+ * @property {Tensor} [spectrogram] The predicted log-mel spectrogram of shape
+ * `(output_sequence_length, config.num_mel_bins)`. Returned when no `vocoder` is provided
+ * @property {Tensor} [waveform] The predicted waveform of shape `(num_frames,)`. Returned when a `vocoder` is provided.
+ * @property {Tensor} [cross_attentions] The outputs of the decoder's cross-attention layers of shape
+ * `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, input_sequence_length)`. returned when `output_cross_attentions` is `true`.
+ */
+
+ /**
+ * Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a speech waveform using a vocoder.
+ * @param {Tensor} input_values Indices of input sequence tokens in the vocabulary.
+ * @param {Tensor} speaker_embeddings Tensor containing the speaker embeddings.
+ * @param {Object} options Optional parameters for generating speech.
+ * @param {number} [options.threshold=0.5] The generated sequence ends when the predicted stop token probability exceeds this value.
+ * @param {number} [options.minlenratio=0.0] Used to calculate the minimum required length for the output sequence.
+ * @param {number} [options.maxlenratio=20.0] Used to calculate the maximum allowed length for the output sequence.
+ * @param {Object} [options.vocoder=null] The vocoder that converts the mel spectrogram into a speech waveform. If `null`, the output is the mel spectrogram.
+ * @param {boolean} [options.output_cross_attentions=false] Whether or not to return the attentions tensors of the decoder's cross-attention layers.
+ * @returns {Promise<SpeechOutput>} A promise which resolves to an object containing the spectrogram, waveform, and cross-attention tensors.
+ */
+ async generate_speech(input_values, speaker_embeddings, {
+ threshold = 0.5,
+ minlenratio = 0.0,
+ maxlenratio = 20.0,
+ vocoder = null,
+ // output_cross_attentions = false, // TODO add
+ } = {}) {
+
+ const model_inputs = {
+ input_ids: input_values
+ }
+
+ const { encoder_outputs, encoder_attention_mask } = await encoderForward(this, model_inputs);
+
+ const r = encoder_outputs.dims[1] / this.config.reduction_factor;
+ const maxlen = Math.floor(r * maxlenratio);
+ const minlen = Math.floor(r * minlenratio);
+
+ const num_mel_bins = this.config.num_mel_bins;
+
+ let spectrogramParts = [];
+ let past_key_values = null;
+ let decoder_outputs = null;
+ let idx = 0;
+
+ while (true) {
+ ++idx;
+
+ const use_cache_branch = boolTensor(!!decoder_outputs);
+ let output_sequence;
+ if (decoder_outputs) {
+ output_sequence = decoder_outputs.output_sequence_out;
+ } else {
+ output_sequence = new Tensor(
+ 'float32',
+ new Float32Array(num_mel_bins),
+ [1, 1, num_mel_bins],
+ )
+ }
+ let decoderFeeds = {
+ use_cache_branch,
+ output_sequence,
+ encoder_attention_mask: encoder_attention_mask,
+ speaker_embeddings: speaker_embeddings,
+ encoder_hidden_states: encoder_outputs,
+ };
+
+ this.addPastKeyValues(decoderFeeds, past_key_values);
+ decoder_outputs = await sessionRun(this.decoder_merged_session, decoderFeeds);
+ past_key_values = this.getPastKeyValues(decoder_outputs, past_key_values);
+
+ const { prob, spectrum } = decoder_outputs;
+ spectrogramParts.push(spectrum);
+
+ if (idx >= minlen && (
+ // Finished when stop token or maximum length is reached.
+ Array.from(prob.data).filter(p => p >= threshold).length > 0 || idx >= maxlen
+ )) {
+ break;
+ }
+ }
+
+ const spectrogram = cat(spectrogramParts);
+ const { waveform } = await sessionRun(vocoder.session, { spectrogram });
+
+ return {
+ spectrogram,
+ waveform,
+ // cross_attentions: null, // TODO add
+ }
+ }
+}
+
+/**
+ * HiFi-GAN vocoder.
+ */
+export class SpeechT5HifiGan extends PreTrainedModel {
+ main_input_name = 'spectrogram';
+}
+//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
// AutoModels, used to simplify construction of PreTrainedModels
// (uses config to instantiate correct class)
@@ -3659,6 +3803,8 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['donut-swin', ['DonutSwinModel', DonutSwinModel]],
['yolos', ['YolosModel', YolosModel]],
+ ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],
+
['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly
]);
@@ -3689,6 +3835,15 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['opt', ['OPTModel', OPTModel]],
]);
+const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([
+ ['speecht5', ['SpeechT5ForSpeechToText', SpeechT5ForSpeechToText]],
+ ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
+])
+
+const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([
+ ['speecht5', ['SpeechT5ForTextToSpeech', SpeechT5ForTextToSpeech]],
+])
+
const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', ['BertForSequenceClassification', BertForSequenceClassification]],
['camembert', ['CamembertForSequenceClassification', CamembertForSequenceClassification]],
@@ -3718,13 +3873,12 @@ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
['xlm-roberta', ['XLMRobertaForTokenClassification', XLMRobertaForTokenClassification]],
]);
-const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([
+const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([
['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]],
['longt5', ['LongT5ForConditionalGeneration', LongT5ForConditionalGeneration]],
['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]],
['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]],
['mbart', ['MBartForConditionalGeneration', MBartForConditionalGeneration]],
- ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
['marian', ['MarianMTModel', MarianMTModel]],
['m2m_100', ['M2M100ForConditionalGeneration', M2M100ForConditionalGeneration]],
['blenderbot', ['BlenderbotForConditionalGeneration', BlenderbotForConditionalGeneration]],
@@ -3822,7 +3976,8 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES.DecoderOnly],
[MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
- [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
+ [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
+ [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
[MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -3833,6 +3988,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+ [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
];
for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
@@ -3897,7 +4053,29 @@ export class AutoModelForTokenClassification extends PretrainedMixin {
* let model = await AutoModelForSeq2SeqLM.from_pretrained('t5-small');
*/
export class AutoModelForSeq2SeqLM extends PretrainedMixin {
- static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES];
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES];
+}
+
+/**
+ * Helper class which is used to instantiate pretrained sequence-to-sequence speech-to-text models with the `from_pretrained` function.
+ * The chosen model class is determined by the type specified in the model config.
+ *
+ * @example
+ * let model = await AutoModelForSpeechSeq2Seq.from_pretrained('openai/whisper-tiny.en');
+ */
+export class AutoModelForSpeechSeq2Seq extends PretrainedMixin {
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES];
+}
+
+/**
+ * Helper class which is used to instantiate pretrained sequence-to-sequence text-to-spectrogram models with the `from_pretrained` function.
+ * The chosen model class is determined by the type specified in the model config.
+ *
+ * @example
+ * let model = await AutoModelForTextToSpectrogram.from_pretrained('microsoft/speecht5_tts');
+ */
+export class AutoModelForTextToSpectrogram extends PretrainedMixin {
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES];
}
/**
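Using the new classes directly, without the pipeline wrapper, mirrors what the example's `worker.js` does; a condensed sketch with the same model ids and speaker-embedding source as the example app:

```js
import { AutoTokenizer, SpeechT5ForTextToSpeech, SpeechT5HifiGan, Tensor } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/speecht5_tts');
const model = await SpeechT5ForTextToSpeech.from_pretrained('Xenova/speecht5_tts', { quantized: false });
const vocoder = await SpeechT5HifiGan.from_pretrained('Xenova/speecht5_hifigan', { quantized: false });

// 512-dimensional speaker x-vector, fetched as raw float32 data.
const url = 'https://huggingface.co/datasets/Xenova/cmu-arctic-xvectors-extracted/resolve/main/cmu_us_slt_arctic-wav-arctic_a0001.bin';
const speaker_embeddings = new Tensor(
    'float32',
    new Float32Array(await (await fetch(url)).arrayBuffer()),
    [1, 512],
);

const { input_ids } = tokenizer('Hello, my dog is cute');
const { waveform } = await model.generate_speech(input_ids, speaker_embeddings, { vocoder });
console.log(waveform); // Tensor of 16 kHz audio samples
```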
diff --git a/src/pipelines.js b/src/pipelines.js
index 0f6ddccee..4873f35a9 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -25,6 +25,8 @@ import {
AutoModelForQuestionAnswering,
AutoModelForMaskedLM,
AutoModelForSeq2SeqLM,
+ AutoModelForSpeechSeq2Seq,
+ AutoModelForTextToSpectrogram,
AutoModelForCTC,
AutoModelForCausalLM,
AutoModelForVision2Seq,
@@ -32,6 +34,7 @@ import {
AutoModelForImageSegmentation,
AutoModelForObjectDetection,
AutoModelForDocumentQuestionAnswering,
+ // AutoModelForTextToWaveform,
PreTrainedModel,
} from './models.js';
import {
@@ -57,6 +60,7 @@ import {
read_audio
} from './utils/audio.js';
import {
+ Tensor,
mean_pooling,
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';
@@ -1127,8 +1131,7 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline {
}
/**
- * @typedef {import('./utils/tensor.js').Tensor} Tensor
- * @typedef {{stride: number[], input_features: Tensor, is_last: boolean, tokens?: number[], token_timestamps?: number[]}} Chunk
+ * @typedef {{stride: number[], input_features: import('./utils/tensor.js').Tensor, is_last: boolean, tokens?: number[], token_timestamps?: number[]}} Chunk
*
* @callback ChunkCallback
* @param {Chunk} chunk The chunk to process.
@@ -1835,6 +1838,103 @@ export class DocumentQuestionAnsweringPipeline extends Pipeline {
}
}
+/**
+ * Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`.
+ * This pipeline generates an audio file from an input text and optional other conditional inputs.
+ *
+ * **Example:** Generate audio from text with `Xenova/speecht5_tts`.
+ * ```js
+ * let speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin';
+ * let synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts', { quantized: false });
+ * let out = await synthesizer('Hello, my dog is cute', { speaker_embeddings });
+ * // {
+ * // audio: Float32Array(26112) [-0.00005657337896991521, 0.00020583874720614403, ...],
+ * // sampling_rate: 16000
+ * // }
+ * ```
+ *
+ * You can then save the audio to a .wav file with the `wavefile` package:
+ * ```js
+ * import wavefile from 'wavefile';
+ * import fs from 'fs';
+ *
+ * let wav = new wavefile.WaveFile();
+ * wav.fromScratch(1, out.sampling_rate, '32f', out.audio);
+ * fs.writeFileSync('out.wav', wav.toBuffer());
+ * ```
+ */
+export class TextToAudioPipeline extends Pipeline {
+ DEFAULT_VOCODER_ID = "Xenova/speecht5_hifigan"
+
+ /**
+ * Create a new TextToAudioPipeline.
+ * @param {Object} options An object containing the following properties:
+ * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks.
+ * @param {PreTrainedModel} [options.model] The model to use.
+ * @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use.
+ * @param {Processor} [options.processor] The processor to use.
+ * @param {PreTrainedModel} [options.vocoder] The vocoder to use.
+ */
+ constructor(options) {
+ super(options);
+
+ // TODO: Find a better way for `pipeline` to set the default vocoder
+ this.vocoder = options.vocoder ?? null;
+ }
+
+ /**
+ * Generates speech/audio from the inputs.
+ * @param {string|string[]} text_inputs The text(s) to generate.
+ * @param {Object} options Parameters passed to the model generation/forward method.
+ * @param {PreTrainedModel} [options.vocoder=null] The vocoder to use (if the model uses one). If not provided, use the default HifiGan vocoder.
+ * @param {Tensor|Float32Array|string|URL} [options.speaker_embeddings=null]
+ * @returns {Promise