4 Commits

Author SHA1 Message Date
Jong Wook Kim ad3250a846 Release 20230308 2023-03-08 15:48:57 -08:00
Jong Wook Kim c4b50c0824 kwargs in decode() for convenience (#1061)
* kwargs in decode() for convenience

* formatting fix
2023-03-08 15:46:38 -08:00
Jong Wook Kim 38f2f4d99d fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060) 2023-03-08 15:34:07 -08:00
Jong Wook Kim aac47c9834 fix typo 2023-03-07 20:43:49 -08:00
6 changed files with 30 additions and 15 deletions
+7 -1
View File
@@ -1,6 +1,12 @@
# CHANGELOG
## [v20230307](https://github.com/openai/whisper/releases/tag/v202303067)
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
* fix typo in CHANGELOG.md
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
+1
View File
@@ -17,6 +17,7 @@ def test_transcribe(model_name: str):
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])
transcription = result["text"].lower()
assert "my fellow americans" in transcription
+8 -2
View File
@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
@@ -778,7 +778,10 @@ class DecodingTask:
@torch.no_grad()
def decode(
model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
model: "Whisper",
mel: Tensor,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -802,6 +805,9 @@ def decode(
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel)
return result[0] if single else result
+1 -1
View File
@@ -290,7 +290,7 @@ def add_word_timestamps(
if len(segments) == 0:
return
text_tokens = [t for segment in segments for t in segment["tokens"]]
text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+12 -10
View File
@@ -200,14 +200,14 @@ def transcribe(
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": text_tokens,
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
@@ -245,7 +245,6 @@ def transcribe(
previous_seek = seek
current_segments = []
current_tokens = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -275,7 +274,6 @@ def transcribe(
result=result,
)
)
current_tokens.append(sliced_tokens.tolist())
last_slice = current_slice
if single_timestamp_ending:
@@ -287,7 +285,6 @@ def transcribe(
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -309,7 +306,6 @@ def transcribe(
result=result,
)
)
current_tokens.append(tokens.tolist())
seek += segment_size
if not condition_on_previous_text or result.temperature > 0.5:
@@ -348,11 +344,17 @@ def transcribe(
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
current_tokens[i] = []
all_segments.extend(current_segments)
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_tokens for token in segment]
[token for segment in current_segments for token in segment["tokens"]]
)
# update progress bar
+1 -1
View File
@@ -1 +1 @@
__version__ = "20230307"
__version__ = "20230308"