Skip to content

Commit

Permalink
Add option to carry initial_prompt with the sliding window (#2343)
Browse files Browse the repository at this point in the history
* Add option to carry initial_prompt with the sliding window

Add an option `carry_initial_prompt = False` to `whisper.transcribe()`.
When set to `True`, `initial_prompt` is prepended to each internal `decode()` call's `prompt`.
If there is not enough context space at the start of the prompt, the prompt is left-sliced to make space.

* Prevent redundant initial_prompt_tokens

* Revert unnecessary .gitignore change

---------

Co-authored-by: Kittsil <[email protected]>
Co-authored-by: Jong Wook Kim <[email protected]>
  • Loading branch information
3 people authored Oct 26, 2024
1 parent cdb8147 commit 5979f03
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
carry_initial_prompt: bool = False,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
Expand Down Expand Up @@ -102,6 +103,11 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.
carry_initial_prompt: bool
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
`decode()` call. If there is not enough context space at the start of the prompt, it is
left-sliced to make space.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Expand Down Expand Up @@ -227,9 +233,11 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
all_segments = []
prompt_reset_since = 0

remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
remaining_prompt_length -= len(initial_prompt_tokens)
else:
initial_prompt_tokens = []

Expand Down Expand Up @@ -275,7 +283,13 @@ def new_segment(
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

decode_options["prompt"] = all_tokens[prompt_reset_since:]
if carry_initial_prompt:
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]

result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)

Expand Down Expand Up @@ -529,6 +543,8 @@ def valid_model_name(name):

parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")

parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

Expand Down

0 comments on commit 5979f03

Please sign in to comment.