9 Commits

Author SHA1 Message Date
Jong Wook Kim ad3250a846 Release 20230308 2023-03-08 15:48:57 -08:00
Jong Wook Kim c4b50c0824 kwargs in decode() for convenience (#1061)
* kwargs in decode() for convenience

* formatting fix
2023-03-08 15:46:38 -08:00
Jong Wook Kim 38f2f4d99d fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060) 2023-03-08 15:34:07 -08:00
Jong Wook Kim aac47c9834 fix typo 2023-03-07 20:43:49 -08:00
Jong Wook Kim 26807ec6d3 Release 20230307 2023-03-07 20:36:29 -08:00
Jong Wook Kim 919a713499 attempt to fix the repetition/hallucination issue identified in #1046 (#1052)
* attempt to fix the repetition/hallucination issue identified in #1046

* zero-pad the audio instead of spectrogram

* formatting fix

* delete debug print
2023-03-07 20:08:45 -08:00
Jong Wook Kim 38e990d853 Use triton==2.0.0 (#1053) 2023-03-07 16:56:31 -08:00
Jong Wook Kim 924e1f8e06 Try installing triton only if linux & x86_64 (#1051) 2023-03-07 11:31:40 -08:00
Jong Wook Kim 4b0d5e58d0 Update setup.py 2023-03-07 04:47:46 -08:00
8 changed files with 91 additions and 72 deletions
+26 -14
View File
@@ -1,25 +1,37 @@
# CHANGELOG
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
* fix typo in CHANGELOG.md
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
* Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051))
* update setup.py to specify python >= 3.8 requirement
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
* #1021: remove auxiliary audio extension
* #1038: apply formatting with `black`, `isort`, and `flake8`
* #869: word-level timestamps in `transcribe()`
* #1033: Decoding improvements
* #894: Update README.md
* #914: Fix infinite loop caused by incorrect timestamp tokens prediction
* #889: drop python 3.7 support
* remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021))
* apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038))
* word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869))
* Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033))
* Update README.md ([#894](https://github.com/openai/whisper/pull/894))
* Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914))
* drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889))
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
* #887: handle printing even if sys.stdout.buffer is not available
* #228: Add TSV formatted output in transcript, using integer start/end time in milliseconds
* #333: Added `--output_format` option
* #864: Handle `XDG_CACHE_HOME` properly for `download_root`
* #867: use stdout for printing transcription progress
* #659: Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm
* #859: print '?' if a letter can't be encoded using the system default encoding
* handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887))
* Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228))
* Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333))
* Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864))
* use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867))
* Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659))
* print '?' if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859))
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
+4 -17
View File
@@ -1,4 +1,5 @@
import os
import platform
import sys
import pkg_resources
@@ -11,22 +12,8 @@ def read_version(fname="whisper/version.py"):
requirements = []
if sys.platform.startswith("linux"):
triton_requirement = "triton>=2.0.0.dev20221202"
try:
import re
import subprocess
version_line = (
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
)
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
if (int(major), int(minor)) < (11, 4):
# the last version supporting CUDA < 11.4
triton_requirement = "triton==2.0.0.dev20221011"
except (IndexError, OSError, subprocess.SubprocessError):
pass
requirements.append(triton_requirement)
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
requirements.append("triton==2.0.0")
setup(
name="openai-whisper",
@@ -36,7 +23,7 @@ setup(
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
readme="README.md",
python_requires=">=3.7",
python_requires=">=3.8",
author="OpenAI",
url="https://github.com/openai/whisper",
license="MIT",
+1
View File
@@ -17,6 +17,7 @@ def test_transcribe(model_name: str):
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])
transcription = result["text"].lower()
assert "my fellow americans" in transcription
+17 -6
View File
@@ -1,6 +1,6 @@
import os
from functools import lru_cache
from typing import Union
from typing import Optional, Union
import ffmpeg
import numpy as np
@@ -15,10 +15,8 @@ N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(
N_SAMPLES, HOP_LENGTH
) # 3000: number of frames in a mel spectrogram input
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = N_MELS,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
n_mels: int
The number of Mel-frequency filters, only 80 is supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (80, n_frames)
@@ -123,6 +130,10 @@ def log_mel_spectrogram(
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
+8 -2
View File
@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
@@ -778,7 +778,10 @@ class DecodingTask:
@torch.no_grad()
def decode(
model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
model: "Whisper",
mel: Tensor,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -802,6 +805,9 @@ def decode(
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel)
return result[0] if single else result
+1 -1
View File
@@ -290,7 +290,7 @@ def add_word_timestamps(
if len(segments) == 0:
return
text_tokens = [t for segment in segments for t in segment["tokens"]]
text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+33 -31
View File
@@ -11,6 +11,7 @@ from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
if dtype == torch.float32:
decode_options["fp16"] = False
mel = log_mel_spectrogram(audio)
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
if decode_options.get("language", None) is None:
if not model.is_multilingual:
@@ -197,14 +200,14 @@ def transcribe(
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": text_tokens,
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
@@ -212,14 +215,13 @@ def transcribe(
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
num_frames = mel.shape[-1]
with tqdm.tqdm(
total=num_frames, unit="frames", disable=verbose is not False
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
while seek < num_frames:
while seek < content_frames:
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek:]
segment_size = min(mel_segment.shape[-1], N_FRAMES)
mel_segment = mel[:, seek : seek + N_FRAMES]
segment_size = min(N_FRAMES, content_frames - seek)
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
@@ -243,23 +245,20 @@ def transcribe(
previous_seek = seek
current_segments = []
current_tokens = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
0
].add_(1)
if (
len(consecutive) > 0
): # if the output contains two consecutive timestamp tokens
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
False,
True,
]:
consecutive = consecutive.tolist() + [len(tokens)]
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
last_slice = 0
for current_slice in consecutive:
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
@@ -275,10 +274,9 @@ def transcribe(
result=result,
)
)
current_tokens.append(sliced_tokens.tolist())
last_slice = current_slice
if ended_with_single_timestamp:
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
@@ -287,7 +285,6 @@ def transcribe(
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -309,7 +306,6 @@ def transcribe(
result=result,
)
)
current_tokens.append(tokens.tolist())
seek += segment_size
if not condition_on_previous_text or result.temperature > 0.5:
@@ -329,7 +325,7 @@ def transcribe(
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(consecutive) > 0 and len(word_end_timestamps) > 0:
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
@@ -348,15 +344,21 @@ def transcribe(
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
current_tokens[i] = []
all_segments.extend(current_segments)
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_tokens for token in segment]
[token for segment in current_segments for token in segment["tokens"]]
)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek)
pbar.update(min(content_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
+1 -1
View File
@@ -1 +1 @@
__version__ = "20230306"
__version__ = "20230308"