Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 26807ec6d3 | |||
| 919a713499 | |||
| 38e990d853 | |||
| 924e1f8e06 | |||
| 4b0d5e58d0 |
+20
-14
@@ -1,25 +1,31 @@
|
||||
# CHANGELOG
|
||||
|
||||
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
|
||||
|
||||
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
|
||||
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
|
||||
* Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051))
|
||||
* update setup.py to specify python >= 3.8 requirement
|
||||
|
||||
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
|
||||
|
||||
* #1021: remove auxiliary audio extension
|
||||
* #1038: apply formatting with `black`, `isort`, and `flake8`
|
||||
* #869: word-level timestamps in `transcribe()`
|
||||
* #1033: Decoding improvements
|
||||
* #894: Update README.md
|
||||
* #914: Fix infinite loop caused by incorrect timestamp tokens prediction
|
||||
* #889: drop python 3.7 support
|
||||
* remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021))
|
||||
* apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038))
|
||||
* word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869))
|
||||
* Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033))
|
||||
* Update README.md ([#894](https://github.com/openai/whisper/pull/894))
|
||||
* Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914))
|
||||
* drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889))
|
||||
|
||||
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
|
||||
|
||||
* #887: handle printing even if sys.stdout.buffer is not available
|
||||
* #228: Add TSV formatted output in transcript, using integer start/end time in milliseconds
|
||||
* #333: Added `--output_format` option
|
||||
* #864: Handle `XDG_CACHE_HOME` properly for `download_root`
|
||||
* #867: use stdout for printing transcription progress
|
||||
* #659: Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm
|
||||
* #859: print '?' if a letter can't be encoded using the system default encoding
|
||||
* handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887))
|
||||
* Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228))
|
||||
* Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333))
|
||||
* Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864))
|
||||
* use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867))
|
||||
* Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659))
|
||||
* print '?' if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859))
|
||||
|
||||
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import pkg_resources
|
||||
@@ -11,22 +12,8 @@ def read_version(fname="whisper/version.py"):
|
||||
|
||||
|
||||
requirements = []
|
||||
if sys.platform.startswith("linux"):
|
||||
triton_requirement = "triton>=2.0.0.dev20221202"
|
||||
try:
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
version_line = (
|
||||
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
|
||||
)
|
||||
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
|
||||
if (int(major), int(minor)) < (11, 4):
|
||||
# the last version supporting CUDA < 11.4
|
||||
triton_requirement = "triton==2.0.0.dev20221011"
|
||||
except (IndexError, OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
requirements.append(triton_requirement)
|
||||
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
|
||||
requirements.append("triton==2.0.0")
|
||||
|
||||
setup(
|
||||
name="openai-whisper",
|
||||
@@ -36,7 +23,7 @@ setup(
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
readme="README.md",
|
||||
python_requires=">=3.7",
|
||||
python_requires=">=3.8",
|
||||
author="OpenAI",
|
||||
url="https://github.com/openai/whisper",
|
||||
license="MIT",
|
||||
|
||||
+17
-6
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
@@ -15,10 +15,8 @@ N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
||||
N_FRAMES = exact_div(
|
||||
N_SAMPLES, HOP_LENGTH
|
||||
) # 3000: number of frames in a mel spectrogram input
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
|
||||
|
||||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
|
||||
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
|
||||
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = N_MELS,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of
|
||||
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
|
||||
device: Optional[Union[str, torch.device]]
|
||||
If given, the audio tensor is moved to this device before STFT
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
@@ -123,6 +130,10 @@ def log_mel_spectrogram(
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
if device is not None:
|
||||
audio = audio.to(device)
|
||||
if padding > 0:
|
||||
audio = F.pad(audio, (0, padding))
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
+21
-21
@@ -11,6 +11,7 @@ from .audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
N_SAMPLES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
@@ -116,7 +117,9 @@ def transcribe(
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
mel = log_mel_spectrogram(audio)
|
||||
# Pad 30 seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-1] - N_FRAMES
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if not model.is_multilingual:
|
||||
@@ -212,14 +215,13 @@ def transcribe(
|
||||
}
|
||||
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
num_frames = mel.shape[-1]
|
||||
with tqdm.tqdm(
|
||||
total=num_frames, unit="frames", disable=verbose is not False
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
while seek < num_frames:
|
||||
while seek < content_frames:
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
mel_segment = mel[:, seek:]
|
||||
segment_size = min(mel_segment.shape[-1], N_FRAMES)
|
||||
mel_segment = mel[:, seek : seek + N_FRAMES]
|
||||
segment_size = min(N_FRAMES, content_frames - seek)
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
@@ -246,20 +248,18 @@ def transcribe(
|
||||
current_tokens = []
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
|
||||
0
|
||||
].add_(1)
|
||||
if (
|
||||
len(consecutive) > 0
|
||||
): # if the output contains two consecutive timestamp tokens
|
||||
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
|
||||
False,
|
||||
True,
|
||||
]:
|
||||
consecutive = consecutive.tolist() + [len(tokens)]
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
consecutive.add_(1)
|
||||
if len(consecutive) > 0:
|
||||
# if the output contains two consecutive timestamp tokens
|
||||
slices = consecutive.tolist()
|
||||
if single_timestamp_ending:
|
||||
slices.append(len(tokens))
|
||||
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
for current_slice in slices:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_pos = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
@@ -278,7 +278,7 @@ def transcribe(
|
||||
current_tokens.append(sliced_tokens.tolist())
|
||||
last_slice = current_slice
|
||||
|
||||
if ended_with_single_timestamp:
|
||||
if single_timestamp_ending:
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
seek += segment_size
|
||||
else:
|
||||
@@ -329,7 +329,7 @@ def transcribe(
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if len(consecutive) > 0 and len(word_end_timestamps) > 0:
|
||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||
)
|
||||
@@ -356,7 +356,7 @@ def transcribe(
|
||||
)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(num_frames, seek) - previous_seek)
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
__version__ = "20230306"
|
||||
__version__ = "20230307"
|
||||
|
||||
Reference in New Issue
Block a user