10 Commits

Author SHA1 Message Date
Jong Wook Kim 8180fde939 Release 20230306 2023-03-06 18:53:04 -08:00
Local State c6e4e5efb3 remove auxiliary audio extension (#1021)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-03-06 17:48:14 -08:00
Jong Wook Kim b80bcf610d apply formatting with black (#1038)
* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
2023-03-06 15:50:37 -08:00
Jong Wook Kim 500d0fe966 word-level timestamps in transcribe() (#869)
* word-level timestamps in `transcribe()`

* moving to `timing.py`

* numba implementation for dtw, replacing dtw-python

* triton implementation for dtw

* add test for dtw implementations

* triton implementation of median_filter

* a simple word-level timestamps test

* add scipy as dev dependency

* installs an older version of Triton if CUDA < 11.4

* fix broken merge

* loosen nvcc version match regex

* find_alignment() function

* miscellaneous improvements

* skip median filtering when the input is too small

* Expose punctuation options in cli and transcribe() (#973)

* fix merge error

* fix merge error 2

* annotating that word_timestamps is experimental

---------

Co-authored-by: ryanheise <ryan@ryanheise.com>
2023-03-06 14:00:49 -08:00
Jong Wook Kim eab8d920ed Decoding improvements (#1033)
* suppress task tokens (transcribe/translate)

* not ignoring the last segment ending with one timestamp
2023-03-06 11:32:32 -08:00
Roman Vasilenko 3e1780fd37 Update README.md (#894)
Fixed a few typos and made general improvements for clarity.

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-03-03 16:41:59 -08:00
Andrey Chernykh 7858aa9c08 Fix infinite loop caused by incorrect timestamp tokens prediction (#914)
* Fix infinite loop caused by incorrect timestamp tokens prediction

https://github.com/openai/whisper/discussions/810

* Update decoding.py

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-02-01 15:46:51 -08:00
Jong Wook Kim 5c1a8c10e7 clarify that 3.11 is not supported 2023-01-27 00:01:49 -08:00
Jong Wook Kim 4e635c6644 Update README.md about Python 3.8+ requirement 2023-01-24 14:45:56 -08:00
Jong Wook Kim a6b36ede1f drop python 3.7 support (#889) 2023-01-24 14:05:57 -08:00
26 changed files with 1307 additions and 279 deletions
+4
@@ -0,0 +1,4 @@
[flake8]
per-file-ignores =
*/__init__.py: F401
+5 -3
@@ -21,6 +21,8 @@ jobs:
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
- uses: actions/checkout@v2
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install pytest
- run: pip install .
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
- run: pip install .["dev"]
- run: black --check --diff -t py38 --include '(\.pyi?)$' .
- run: isort --check --diff .
- run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
+26
@@ -0,0 +1,26 @@
# CHANGELOG
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
* #1021: remove auxiliary audio extension
* #1038: apply formatting with `black`, `isort`, and `flake8`
* #869: word-level timestamps in `transcribe()`
* #1033: Decoding improvements
* #894: Update README.md
* #914: Fix infinite loop caused by incorrect timestamp tokens prediction
* #889: drop python 3.7 support
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
* #887: handle printing even if sys.stdout.buffer is not available
* #228: Add TSV formatted output in transcript, using integer start/end time in milliseconds
* #333: Added `--output_format` option
* #864: Handle `XDG_CACHE_HOME` properly for `download_root`
* #867: use stdout for printing transcription progress
* #659: Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm
* #859: print '?' if a letter can't be encoded using the system default encoding
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
The first versioned release available on [PyPI](https://pypi.org/project/openai-whisper/)
+6 -6
@@ -5,19 +5,19 @@
[[Model card]](https://github.com/openai/whisper/blob/main/model-card.md)
[[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb)
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
## Approach
![Approach](https://raw.githubusercontent.com/openai/whisper/main/approach.png)
A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. All of these tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing for a single model to replace many different stages of a traditional speech processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
## Setup
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
pip install -U openai-whisper
@@ -68,9 +68,9 @@ There are five model sizes, four with English-only versions, offering speed and
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of Fleurs dataset, using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller is better.
Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of the Fleurs dataset using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller, the better.
![WER breakdown by language](https://raw.githubusercontent.com/openai/whisper/main/language-breakdown.svg)
@@ -144,4 +144,4 @@ Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussion
## License
The code and the model weights of Whisper are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
+8
@@ -0,0 +1,8 @@
[tool.black]
[tool.isort]
profile = "black"
include_trailing_comma = true
line_length = 88
multi_line_output = 3
+1
@@ -1,3 +1,4 @@
numba
numpy
torch
tqdm
+23 -3
@@ -1,7 +1,8 @@
import os
import sys
import pkg_resources
from setuptools import setup, find_packages
from setuptools import find_packages, setup
def read_version(fname="whisper/version.py"):
@@ -9,6 +10,24 @@ def read_version(fname="whisper/version.py"):
return locals()["__version__"]
requirements = []
if sys.platform.startswith("linux"):
triton_requirement = "triton>=2.0.0.dev20221202"
try:
import re
import subprocess
version_line = (
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
)
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
if (int(major), int(minor)) < (11, 4):
# the last version supporting CUDA < 11.4
triton_requirement = "triton==2.0.0.dev20221011"
except (IndexError, OSError, subprocess.SubprocessError):
pass
requirements.append(triton_requirement)
setup(
name="openai-whisper",
py_modules=["whisper"],
@@ -22,7 +41,8 @@ setup(
url="https://github.com/openai/whisper",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=[
install_requires=requirements
+ [
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@@ -32,5 +52,5 @@ setup(
"console_scripts": ["whisper=whisper.transcribe:cli"],
},
include_package_data=True,
extras_require={"dev": ["pytest"]},
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
)
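The CUDA version check in the hunk above can be tried in isolation. The following is a minimal sketch, not the packaging code itself: `parse_cuda_version` and `pick_triton_requirement` are hypothetical helper names, and the sample `nvcc --version` output is an illustrative assumption rather than captured output.

```python
import re
from typing import Optional, Tuple

# Requirement strings copied from the diff above.
DEFAULT_TRITON = "triton>=2.0.0.dev20221202"
LEGACY_TRITON = "triton==2.0.0.dev20221011"  # last version supporting CUDA < 11.4


def parse_cuda_version(nvcc_output: bytes) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from the last line of `nvcc --version` output."""
    last_line = nvcc_output.strip().split(b"\n")[-1]
    matches = re.findall(rb"(\d+)\.(\d+)", last_line)
    if not matches:
        return None  # no version-like token found
    major, minor = matches[0]
    return int(major), int(minor)


def pick_triton_requirement(cuda_version: Optional[Tuple[int, int]]) -> str:
    """Pin the older Triton build only when a CUDA version below 11.4 is detected."""
    if cuda_version is not None and cuda_version < (11, 4):
        return LEGACY_TRITON
    return DEFAULT_TRITON


# Illustrative output only; real nvcc banners may differ between releases.
sample_output = (
    b"nvcc: NVIDIA (R) Cuda compiler driver\n"
    b"Cuda compilation tools, release 11.3, V11.3.109\n"
    b"Build cuda_11.3.r11.3/compiler.29920130_0\n"
)
requirement = pick_triton_requirement(parse_cuda_version(sample_output))
```

Comparing `(major, minor)` tuples avoids the pitfalls of comparing version strings, and returning `None` when nothing matches mirrors the diff's fallback of keeping the default requirement when `nvcc` is missing or unparsable.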
+14
@@ -0,0 +1,14 @@
import random as rand
import numpy
import pytest
def pytest_configure(config):
config.addinivalue_line("markers", "requires_cuda")
@pytest.fixture
def random():
rand.seed(42)
numpy.random.seed(42)
+1 -1
@@ -2,7 +2,7 @@ import os.path
import numpy as np
from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
def test_audio():
+4 -1
@@ -1,7 +1,10 @@
import pytest
from whisper.normalizers import EnglishTextNormalizer
from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
from whisper.normalizers.english import (
EnglishNumberNormalizer,
EnglishSpellingNormalizer,
)
@pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])
+96
@@ -0,0 +1,96 @@
import numpy as np
import pytest
import scipy.ndimage
import torch
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
sizes = [
(10, 20),
(32, 16),
(123, 1500),
(234, 189),
]
shapes = [
(10,),
(1, 15),
(4, 5, 345),
(6, 12, 240, 512),
]
@pytest.mark.parametrize("N, M", sizes)
def test_dtw(N: int, M: int):
steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
np.random.shuffle(steps)
x = np.random.random((N, M)).astype(np.float32)
i, j, k = 0, 0, 0
trace = []
while True:
x[i, j] -= 1
trace.append((i, j))
if k == len(steps):
break
if k + 1 < len(steps) and steps[k] != steps[k + 1]:
i += 1
j += 1
k += 2
continue
if steps[k] == 0:
i += 1
if steps[k] == 1:
j += 1
k += 1
trace = np.array(trace).T
dtw_trace = dtw_cpu(x)
assert np.allclose(trace, dtw_trace)
@pytest.mark.requires_cuda
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_cuda_equivalence(N: int, M: int):
x_numpy = np.random.randn(N, M).astype(np.float32)
x_cuda = torch.from_numpy(x_numpy).cuda()
trace_cpu = dtw_cpu(x_numpy)
trace_cuda = dtw_cuda(x_cuda)
assert np.allclose(trace_cpu, trace_cuda)
@pytest.mark.parametrize("shape", shapes)
def test_median_filter(shape):
x = torch.randn(*shape)
for filter_width in [3, 5, 7, 13]:
filtered = median_filter(x, filter_width)
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
pad_width = filter_width // 2
padded_x = np.pad(
x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
)
scipy_filtered = scipy.ndimage.median_filter(
padded_x, [1] * (x.ndim - 1) + [filter_width]
)
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
assert np.allclose(filtered, scipy_filtered)
@pytest.mark.requires_cuda
@pytest.mark.parametrize("shape", shapes)
def test_median_filter_equivalence(shape):
x = torch.randn(*shape)
for filter_width in [3, 5, 7, 13]:
filtered_cpu = median_filter(x, filter_width)
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
assert np.allclose(filtered_cpu, filtered_gpu)
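The tests above check that `dtw_cpu` recovers a known monotonic path through a cost matrix. For readers unfamiliar with dynamic time warping, here is a pure-Python sketch of the general technique. It is not the numba or Triton implementation from `timing.py`; the name `dtw_path` and the tie-breaking order are assumptions made for illustration.

```python
import math
from typing import List, Tuple


def dtw_path(x: List[List[float]]) -> Tuple[List[int], List[int]]:
    """Return the minimum-cost monotonic path through cost matrix `x`.

    The path starts at (0, 0), ends at (N-1, M-1), and each step advances
    the row, the column, or both by one.
    """
    n, m = len(x), len(x[0])
    cost = [[math.inf] * (m + 1) for _ in range(n + 1)]
    step = [[-1] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # (accumulated cost, move id); ties resolve toward the lower move id
            best, move = min(
                (cost[i - 1][j - 1], 0),  # diagonal step
                (cost[i - 1][j], 1),      # advance row only
                (cost[i][j - 1], 2),      # advance column only
            )
            cost[i][j] = x[i - 1][j - 1] + best
            step[i][j] = move

    # Walk back from the bottom-right corner to recover the path.
    rows, cols = [], []
    i, j = n, m
    while i > 0 and j > 0:
        rows.append(i - 1)
        cols.append(j - 1)
        move = step[i][j]
        if move == 0:
            i, j = i - 1, j - 1
        elif move == 1:
            i -= 1
        else:
            j -= 1
    rows.reverse()
    cols.reverse()
    return rows, cols


# A 2x3 cost matrix whose cheapest path is (0,0) -> (0,1) -> (1,2).
rows, cols = dtw_path([[0.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
```

The two returned lists correspond to the two rows of the trace array that the tests compare against, under the assumption that the real function returns row and column indices in the same order.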
+15 -1
@@ -13,10 +13,24 @@ def test_transcribe(model_name: str):
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
language = "en" if model_name.endswith(".en") else None
result = model.transcribe(audio_path, language=language, temperature=0.0)
result = model.transcribe(
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"
transcription = result["text"].lower()
assert "my fellow americans" in transcription
assert "your country" in transcription
assert "do for you" in transcription
timing_checked = False
for segment in result["segments"]:
for timing in segment["words"]:
assert timing["start"] < timing["end"]
if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8
assert timing["end"] >= 1.8
print(timing)
timing_checked = True
assert timing_checked
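The test above relies on each segment carrying a `words` list whose entries have `word`, `start`, and `end` keys when `word_timestamps=True`. A small sketch of consuming that structure follows. The `find_word` helper and the hand-written `fake_result` are hypothetical; the timing values are made up for illustration and are not real model output.

```python
from typing import Optional


def find_word(result: dict, target: str) -> Optional[dict]:
    """Return the first word-timing entry matching `target`, ignoring case and punctuation."""
    for segment in result["segments"]:
        # "words" is only expected when word-level timestamps were requested
        for timing in segment.get("words", []):
            if timing["word"].strip(" ,.").lower() == target.lower():
                return timing
    return None


# Hand-built stand-in with the same shape the test iterates over.
fake_result = {
    "text": " And so, my fellow Americans",
    "segments": [
        {
            "words": [
                {"word": " my", "start": 0.9, "end": 1.1},
                {"word": " fellow", "start": 1.1, "end": 1.4},
                {"word": " Americans,", "start": 1.4, "end": 2.1},
            ]
        }
    ],
}

hit = find_word(fake_result, "americans")
```

Using `segment.get("words", [])` keeps the helper usable on results produced without word timestamps, where the key may be absent.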
+49 -17
@@ -10,11 +10,10 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -29,6 +28,22 @@ _MODELS = {
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
@@ -45,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
@@ -59,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
@@ -69,7 +94,12 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
) -> Whisper:
"""
Load a Whisper ASR model
@@ -94,24 +124,23 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.path.join(
os.getenv(
"XDG_CACHE_HOME",
os.path.join(
os.path.expanduser("~"), ".cache"
)
),
"whisper"
)
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
@@ -119,4 +148,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model.to(device)
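The simplified `download_root` fallback in `load_model` can be expressed as a standalone function. This sketch assumes a hypothetical helper name, `default_download_root`, and accepts an injectable environment mapping purely so the behaviour can be checked without touching the real environment.

```python
import os
from typing import Mapping, Optional


def default_download_root(environ: Optional[Mapping[str, str]] = None) -> str:
    """Resolve the model cache directory: $XDG_CACHE_HOME/whisper, else ~/.cache/whisper."""
    environ = os.environ if environ is None else environ
    default = os.path.join(os.path.expanduser("~"), ".cache")
    return os.path.join(environ.get("XDG_CACHE_HOME", default), "whisper")


with_xdg = default_download_root({"XDG_CACHE_HOME": "/tmp/xdg"})
without_xdg = default_download_root({})
```

Computing `default` first, as the diff does, keeps the nested `os.path.join` calls readable while preserving the same precedence: an explicit `XDG_CACHE_HOME` wins over the home-directory fallback.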
-1
@@ -1,4 +1,3 @@
from .transcribe import cli
cli()
+16 -4
@@ -16,7 +16,13 @@ N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
N_FRAMES = exact_div(
N_SAMPLES, HOP_LENGTH
) # 3000: number of frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
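The new timing constants in this hunk follow from simple arithmetic. The sketch below reproduces it, assuming `SAMPLE_RATE = 16000` (not shown in the hunk, but implied by `N_SAMPLES` being 480000 for a 30-second chunk) and a local stand-in for `exact_div`; `token_index_to_seconds` is a hypothetical helper added for illustration.

```python
def exact_div(x: int, y: int) -> int:
    """Integer division that fails loudly if `y` does not divide `x` (local stand-in)."""
    assert x % y == 0, f"{x} is not divisible by {y}"
    return x // y


SAMPLE_RATE = 16000  # assumption: implied by 480000 samples per 30 s chunk
HOP_LENGTH = 160
CHUNK_LENGTH = 30

N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE            # samples in one chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)       # mel frames in one chunk

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2              # stride-2 convolution halves the frame rate
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)           # one frame per 10 ms
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # one token per 20 ms


def token_index_to_seconds(index: int) -> float:
    """Convert an audio-token position into a time offset within the chunk."""
    return index / TOKENS_PER_SECOND
```

With these values a chunk holds 3000 mel frames at 100 frames per second, and the encoder output advances at 50 positions per second, which matches the 0.02-second timestamp precision mentioned in `decoding.py`.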
@@ -55,7 +61,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
@@ -85,11 +93,15 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
with np.load(
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
):
"""
Compute the log-Mel spectrogram of
+163 -66
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@@ -16,7 +16,9 @@ if TYPE_CHECKING:
@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
def detect_language(
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
@@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
):
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
)
single = mel.ndim == 2
if single:
@@ -70,31 +77,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
@dataclass(frozen=True)
class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[str] = None # language that the audio is in; uses detected language if None
# whether to perform X->X "transcribe" or X->English "translate"
task: str = "transcribe"
# language that the audio is in; uses detected language if None
language: Optional[str] = None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
# "alpha" in Google NMT, or None for length norm, when ranking generations
# to select which to return among the beams or best-of-N samples
length_penalty: Optional[float] = None
# prompt, prefix, and token suppression
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
suppress_blank: bool = True # this will suppress blank outputs
# text or tokens to feed as the prompt or the prefix; for more info:
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt: Optional[Union[str, List[int]]] = None # for the previous context
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
suppress_blank: bool = True # this will suppress blank outputs
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0
# implementation details
fp16: bool = True # use fp16 for most of the calculation
@@ -158,7 +170,9 @@ class PyTorchInference(Inference):
class SequenceRanker:
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
def rank(
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
@@ -196,7 +210,9 @@ class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
@@ -251,12 +267,13 @@ class GreedyDecoder(TokenDecoder):
self.temperature = temperature
self.eot = eot
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if self.temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
next_tokens = Categorical(logits=logits / self.temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
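The branch on `self.temperature` in `GreedyDecoder.update` switches between argmax and categorical sampling. A dependency-free sketch of that choice follows, using the standard library in place of `torch`; the function name `pick_next_token` and the explicit `rng` parameter are assumptions for illustration.

```python
import math
import random
from typing import List


def pick_next_token(logits: List[float], temperature: float, rng: random.Random) -> int:
    """Greedy argmax when temperature is 0, otherwise sample from softmax(logits / T)."""
    if temperature == 0:
        return max(range(len(logits)), key=logits.__getitem__)

    scaled = [value / temperature for value in logits]
    peak = max(scaled)
    # Subtract the maximum before exponentiating to avoid overflow.
    weights = [math.exp(value - peak) for value in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
greedy_choice = pick_next_token([0.1, 2.0, -1.0], temperature=0.0, rng=rng)
sampled_choice = pick_next_token([0.1, 2.0, -1.0], temperature=1.0, rng=rng)
```

Dividing by the temperature before the softmax flattens the distribution for `T > 1` and sharpens it for `T < 1`, which is why `T == 0` has to be special-cased as a plain argmax.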
@@ -275,7 +292,13 @@ class GreedyDecoder(TokenDecoder):
class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
def __init__(
self,
beam_size: int,
eot: int,
inference: Inference,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
@@ -283,12 +306,16 @@ class BeamSearchDecoder(TokenDecoder):
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
@@ -332,7 +359,9 @@ class BeamSearchDecoder(TokenDecoder):
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
@@ -340,7 +369,8 @@ class BeamSearchDecoder(TokenDecoder):
# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
@@ -348,7 +378,9 @@ class BeamSearchDecoder(TokenDecoder):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if len(sequences) < self.beam_size: # when not enough sequences are finished
if (
len(sequences) < self.beam_size
): # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
@@ -356,7 +388,8 @@ class BeamSearchDecoder(TokenDecoder):
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
[torch.tensor(seq) for seq in sequences.keys()]
for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
@@ -400,7 +433,10 @@ class SuppressTokens(LogitFilter):
class ApplyTimestampRules(LogitFilter):
def __init__(
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
self,
tokenizer: Tokenizer,
sample_begin: int,
max_initial_timestamp_index: Optional[int],
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
@@ -413,9 +449,14 @@ class ApplyTimestampRules(LogitFilter):
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
sampled_tokens = tokens[k, self.sample_begin :]
seq = [t for t in sampled_tokens.tolist()]
last_was_timestamp = (
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
)
penultimate_was_timestamp = (
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
)
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
@@ -423,19 +464,30 @@ class ApplyTimestampRules(LogitFilter):
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
timestamps = sampled_tokens[
sampled_tokens.ge(self.tokenizer.timestamp_begin)
]
if timestamps.numel() > 0:
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
last_allowed = (
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
)
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
dim=-1
)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
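The rule added in this hunk, that timestamps must not decrease, masks every timestamp token smaller than the last one sampled. The following sketch shows the same idea on plain Python lists; `mask_decreasing_timestamps` is a hypothetical name, and the tiny vocabulary in the example (timestamp ids starting at 5) is invented for illustration.

```python
import math
from typing import List


def mask_decreasing_timestamps(
    logits: List[float], sampled_tokens: List[int], timestamp_begin: int
) -> List[float]:
    """Return a copy of `logits` with timestamp ids below the last sampled timestamp set to -inf."""
    masked = list(logits)
    timestamps = [token for token in sampled_tokens if token >= timestamp_begin]
    if timestamps:
        # Forbid ids in [timestamp_begin, last_timestamp); the last one itself stays allowed.
        for token_id in range(timestamp_begin, timestamps[-1]):
            masked[token_id] = -math.inf
    return masked


# Toy vocabulary: ids 0-4 are text tokens, ids 5-9 are timestamp tokens.
masked = mask_decreasing_timestamps(
    logits=[0.0] * 10, sampled_tokens=[1, 6, 2, 7], timestamp_begin=5
)
```

Because the slice upper bound is exclusive, repeating the most recent timestamp remains possible, which is what allows a segment to end and the next one to begin at the same time.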
@@ -451,7 +503,9 @@ class DecodingTask:
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
@@ -491,9 +545,13 @@ class DecodingTask:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
max_initial_timestamp_index = round(
self.options.max_initial_timestamp / precision
)
self.logit_filters.append(
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
ApplyTimestampRules(
tokenizer, self.sample_begin, max_initial_timestamp_index
)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
@@ -504,30 +562,38 @@ class DecodingTask:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
if options.length_penalty is not None and not (
0 <= options.length_penalty <= 1
):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
prefix = self.options.prefix
prompt = self.options.prompt
if prefix:
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
self.tokenizer.encode(" " + prefix.strip())
if isinstance(prefix, str)
else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt:
if prompt := self.options.prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
self.tokenizer.encode(" " + prompt.strip())
if isinstance(prompt, str)
else prompt
)
tokens = (
[self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
return tuple(tokens)
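The ordering produced by `_get_initial_tokens` can be sketched without a tokenizer. The function below is a hypothetical stand-alone rewrite that takes already-encoded token lists; the token ids in the usage note are made up for illustration.

```python
def build_initial_tokens(
    sot_sequence, sot_prev, n_ctx, prompt_tokens=None, prefix_tokens=None, sample_len=None
):
    """Assemble [sot_prev, prompt tail] + sot_sequence + prefix, mirroring the hunk above."""
    tokens = list(sot_sequence)

    if prefix_tokens:
        prefix_tokens = list(prefix_tokens)
        if sample_len is not None:
            # keep only the tail of the prefix that fits in half the context
            max_prefix_len = n_ctx // 2 - sample_len
            prefix_tokens = prefix_tokens[-max_prefix_len:]
        tokens = tokens + prefix_tokens

    if prompt_tokens:
        # the prompt goes before the SOT sequence, truncated to its last n_ctx // 2 - 1 tokens
        tokens = [sot_prev] + list(prompt_tokens)[-(n_ctx // 2 - 1):] + tokens

    return tuple(tokens)
```

For example, with `n_ctx=8` a five-token prompt is cut to its last three tokens, so `build_initial_tokens((100, 101, 102), 99, 8, prompt_tokens=[1, 2, 3, 4, 5], prefix_tokens=[7, 8])` yields `(99, 3, 4, 5, 100, 101, 102, 7, 8)`.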
@@ -546,7 +612,13 @@ class DecodingTask:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
[
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
@@ -558,14 +630,21 @@ class DecodingTask:
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
if mel.shape[-2:] == (
self.model.dims.n_audio_ctx,
self.model.dims.n_audio_state,
):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
if audio_features.dtype != (
torch.float16 if self.options.fp16 else torch.float32
):
return TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}"
)
return audio_features
@@ -574,7 +653,9 @@ class DecodingTask:
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
lang_tokens, lang_probs = self.model.detect_language(
audio_features, self.tokenizer
)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
@@ -591,7 +672,9 @@ class DecodingTask:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
@@ -625,8 +708,12 @@ class DecodingTask:
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(audio_features=features, language=language, language_probs=probs)
for features, language, probs in zip(audio_features, languages, language_probs)
DecodingResult(
audio_features=features, language=language, language_probs=probs
)
for features, language, probs in zip(
audio_features, languages, language_probs
)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
@@ -647,7 +734,8 @@ class DecodingTask:
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
for s in tokens
]
# select the top-ranked sample in each group
@@ -656,9 +744,18 @@ class DecodingTask:
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
avg_logprobs: List[float] = [
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
]
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
fields = (
texts,
languages,
tokens,
audio_features,
avg_logprobs,
no_speech_probs,
)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
@@ -673,12 +770,16 @@ class DecodingTask:
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
*fields
)
]
@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
def decode(
model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -698,13 +799,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
single = mel.ndim == 2
if single:
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)
result = DecodingTask(model, options).run(mel)
if single:
result = result[0]
return result
return result[0] if single else result
+59 -18
@@ -1,15 +1,16 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
from typing import Dict, Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch import Tensor, nn
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function
@dataclass
@@ -34,12 +35,16 @@ class LayerNorm(nn.LayerNorm):
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
@@ -85,7 +90,9 @@ class MultiHeadAttention(nn.Module):
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
@@ -108,11 +115,15 @@ class ResidualAttentionBlock(nn.Module):
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
def forward(
@@ -130,7 +141,9 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
@@ -161,14 +174,19 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
@@ -183,14 +201,19 @@ class TextDecoder(nn.Module):
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
@@ -213,6 +236,21 @@ class Whisper(nn.Module):
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = torch.from_numpy(array).reshape(
self.dims.n_text_layer, self.dims.n_text_head
)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
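`set_alignment_heads()` expects a base85 string wrapping a gzip-compressed boolean array with one byte per (layer, head) entry. A standard-library-only sketch of that round trip is shown below; `dump_heads` and `load_heads` are hypothetical names, and nested lists stand in for the NumPy/torch arrays used in the diff.

```python
import base64
import gzip


def dump_heads(mask):
    """Serialize a layers x heads boolean mask: one byte per entry, gzip, then base85."""
    raw = bytes(int(flag) for row in mask for flag in row)
    return base64.b85encode(gzip.compress(raw))


def load_heads(dump, n_layer, n_head):
    """Inverse of dump_heads; mirrors the decode order in set_alignment_heads()."""
    raw = gzip.decompress(base64.b85decode(dump))
    assert len(raw) == n_layer * n_head, "dump does not match the model dimensions"
    return [[bool(raw[layer * n_head + head]) for head in range(n_head)] for layer in range(n_layer)]
```

The length check is an addition for illustration; in the diff, a mismatched dump would surface as a reshape error instead.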
@@ -220,7 +258,9 @@ class Whisper(nn.Module):
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
@@ -249,8 +289,9 @@ class Whisper(nn.Module):
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
cache[module] = output # save as-is, for the first token or cross attention
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
+2 -2
@@ -1,2 +1,2 @@
from .basic import BasicTextNormalizer
from .english import EnglishTextNormalizer
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer
+8 -3
@@ -48,13 +48,16 @@ def remove_symbols(s: str):
Replace any other markers, symbols, punctuations with a space, keeping diacritics
"""
return "".join(
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
" " if unicodedata.category(c)[0] in "MSP" else c
for c in unicodedata.normalize("NFKC", s)
)
class BasicTextNormalizer:
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
self.clean = (
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
)
self.split_letters = split_letters
def __call__(self, s: str):
@@ -66,6 +69,8 @@ class BasicTextNormalizer:
if self.split_letters:
s = " ".join(regex.findall(r"\X", s, regex.U))
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
s = re.sub(
r"\s+", " ", s
) # replace any successive whitespace characters with a space
return s
+14 -7
@@ -84,7 +84,8 @@ class EnglishNumberNormalizer:
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
}
self.tens_ordinal = {
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
name.replace("y", "ieth"): (value, "th")
for name, value in self.tens.items()
}
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
@@ -108,7 +109,10 @@ class EnglishNumberNormalizer:
self.multipliers_ordinal = {
name + "th": (value, "th") for name, value in self.multipliers.items()
}
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
self.multipliers_suffixed = {
**self.multipliers_plural,
**self.multipliers_ordinal,
}
self.decimals = {*self.ones, *self.tens, *self.zeros}
self.preceding_prefixers = {
@@ -128,7 +132,8 @@ class EnglishNumberNormalizer:
"cents": "¢",
}
self.prefixes = set(
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
list(self.preceding_prefixers.values())
+ list(self.following_prefixers.values())
)
self.suffixers = {
"per": {"cent": "%"},
@@ -218,7 +223,9 @@ class EnglishNumberNormalizer:
if value is None:
value = ones
elif isinstance(value, str) or prev in self.ones:
if prev in self.tens and ones < 10: # replace the last zero with the digit
if (
prev in self.tens and ones < 10
): # replace the last zero with the digit
assert value[-1] == "0"
value = value[:-1] + str(ones)
else:
@@ -522,14 +529,14 @@ class EnglishTextNormalizer:
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parentheses
s = re.sub(self.ignore_patterns, "", s)
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
for pattern, replacement in self.replacers.items():
s = re.sub(pattern, replacement, s)
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
s = self.standardize_numbers(s)
s = self.standardize_spellings(s)
@@ -538,6 +545,6 @@ class EnglishTextNormalizer:
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
s = re.sub(r"([^0-9])%", r"\1 ", s)
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
return s
+323
@@ -0,0 +1,323 @@
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List
import numba
import numpy as np
import torch
import torch.nn.functional as F
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer
if TYPE_CHECKING:
from .model import Whisper
def median_filter(x: torch.Tensor, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
# F.pad requires the padding width to be smaller than the input dimension
return x
if (ndim := x.ndim) <= 2:
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
x = x[None, None, :]
assert (
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
try:
from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower median kernel implementation..."
)
if result is None:
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
if ndim <= 2:
result = result[0, 0]
return result
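The fallback branch of `median_filter` (reflect-pad, take sliding windows, sort, pick the middle element) can be illustrated on a plain list. This is a hedged one-dimensional sketch, not the tensor implementation; `median_filter_1d` is a hypothetical name.

```python
def median_filter_1d(values, filter_width):
    """Median filter over a list, using the same reflect padding as F.pad(mode="reflect")."""
    assert filter_width > 0 and filter_width % 2 == 1, "filter_width should be an odd number"
    values = list(values)
    pad = filter_width // 2
    if len(values) <= pad:
        # too short to reflect-pad; return the input unchanged, as the original does
        return values

    # reflect padding mirrors around the edge element without repeating it
    left = values[1 : pad + 1][::-1]
    right = values[-pad - 1 : -1][::-1]
    padded = left + values + right

    # sort each window and take its middle element
    return [sorted(padded[i : i + filter_width])[pad] for i in range(len(values))]
```

With `[1, 9, 2, 8, 3]` and a width of 3, the padded sequence is `[9, 1, 9, 2, 8, 3, 8]` and the filtered output is `[9, 2, 8, 3, 8]`.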
@numba.jit
def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1
j = trace.shape[1] - 1
trace[0, :] = 2
trace[:, 0] = 1
result = []
while i > 0 or j > 0:
result.append((i - 1, j - 1))
if trace[i, j] == 0:
i -= 1
j -= 1
elif trace[i, j] == 1:
i -= 1
elif trace[i, j] == 2:
j -= 1
else:
raise ValueError("Unexpected trace[i, j]")
result = np.array(result)
return result[::-1, :].T
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, M + 1):
for i in range(1, N + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = x[i - 1, j - 1] + c
trace[i, j] = t
return backtrace(trace)
def dtw_cuda(x, BLOCK_SIZE=1024):
from .triton_ops import dtw_kernel
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.cuda()
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](
cost,
trace,
x_skew,
x_skew.stride(0),
cost.stride(0),
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE,
)
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
if x.is_cuda:
try:
return dtw_cuda(x)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower DTW implementation..."
)
return dtw_cpu(x.double().cpu().numpy())
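For readers without numba or Triton at hand, the recurrence in `dtw_cpu` and the walk in `backtrace` can be followed in plain Python. The sketch below uses nested lists and a hypothetical name `dtw_path`; it keeps the same tie-breaking order as the code above but makes no claim about matching its performance.

```python
import math


def dtw_path(x):
    """Return (row_indices, col_indices) of the minimum-cost monotonic path through x."""
    n, m = len(x), len(x[0])
    cost = [[math.inf] * (m + 1) for _ in range(n + 1)]
    trace = [[-1] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0

    for j in range(1, m + 1):
        for i in range(1, n + 1):
            c0 = cost[i - 1][j - 1]  # diagonal step
            c1 = cost[i - 1][j]      # step along rows only
            c2 = cost[i][j - 1]      # step along columns only
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2
            cost[i][j] = x[i - 1][j - 1] + c
            trace[i][j] = t

    # force the walk along the borders back to the origin
    for j in range(m + 1):
        trace[0][j] = 2
    for i in range(n + 1):
        trace[i][0] = 1

    i, j = n, m
    path = []
    while i > 0 or j > 0:
        path.append((i - 1, j - 1))
        if trace[i][j] == 0:
            i, j = i - 1, j - 1
        elif trace[i][j] == 1:
            i -= 1
        else:
            j -= 1
    path.reverse()
    return [p[0] for p in path], [p[1] for p in path]
```

In `find_alignment` the rows are text tokens and the columns are audio frames, so the first returned list plays the role of `text_indices` and the second of `time_indices`.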
@dataclass
class WordTiming:
word: str
tokens: List[int]
start: float
end: float
probability: float
def find_alignment(
model: "Whisper",
tokenizer: Tokenizer,
text_tokens: List[int],
mel: torch.Tensor,
num_frames: int,
*,
medfilt_width: int = 7,
qk_scale: float = 1.0,
) -> List[WordTiming]:
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
).to(model.device)
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
# heads * tokens * frames
weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
text_indices, time_indices = dtw(-matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
# hack: ensure that neither the first nor the second word is longer than twice the median word duration.
# a better segmentation algorithm based on VAD should be able to replace this.
word_durations = end_times - start_times
word_durations = word_durations[word_durations.nonzero()]
if len(word_durations) > 0:
median_duration = np.median(word_durations)
max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration:
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
end_times[0] = start_times[1] = boundary
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
start_times[0] = max(0, end_times[0] - max_duration)
return [
WordTiming(word, tokens, start, end, probability)
for word, tokens, start, end, probability in zip(
words, word_tokens, start_times, end_times, word_probabilities
)
]
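The step from the DTW path to per-word start and end times in `find_alignment` is compact, so a plain-Python sketch may help. `word_times` is a hypothetical helper; the default of 50 tokens per second is an assumption here (20 ms per audio token), standing in for `TOKENS_PER_SECOND`, and `word_token_counts` is assumed to exclude the trailing EOT token.

```python
def word_times(text_indices, time_indices, word_token_counts, tokens_per_second=50):
    """Convert a DTW path into (start, end) seconds for each word."""
    # keep the time at every position where the path advances to a new token
    jump_times = [
        time_indices[k] / tokens_per_second
        for k in range(len(text_indices))
        if k == 0 or text_indices[k] != text_indices[k - 1]
    ]

    # cumulative token counts give the index of each word's first token
    boundaries = [0]
    for count in word_token_counts:
        boundaries.append(boundaries[-1] + count)

    starts = [jump_times[b] for b in boundaries[:-1]]
    ends = [jump_times[b] for b in boundaries[1:]]
    return list(zip(starts, ends))
```

Each word ends where the next one starts, and the last word ends where the EOT token begins, which is why the path must include one more token than the words cover.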
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
# merge prepended punctuations
i = len(alignment) - 2
j = len(alignment) - 1
while i >= 0:
previous = alignment[i]
following = alignment[j]
if previous.word.startswith(" ") and previous.word.strip() in prepended:
# prepend it to the following word
following.word = previous.word + following.word
following.tokens = previous.tokens + following.tokens
previous.word = ""
previous.tokens = []
else:
j = i
i -= 1
# merge appended punctuations
i = 0
j = 1
while j < len(alignment):
previous = alignment[i]
following = alignment[j]
if not previous.word.endswith(" ") and following.word in appended:
# append it to the previous word
previous.word = previous.word + following.word
previous.tokens = previous.tokens + following.tokens
following.word = ""
following.tokens = []
else:
i = j
j += 1
def add_word_timestamps(
*,
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**kwargs,
):
if len(segments) == 0:
return
text_tokens = [t for segment in segments for t in segment["tokens"]]
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
segment_lengths = [len(s["tokens"]) for s in segments]
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
for segment in segments:
segment["words"] = []
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
for i, timing in enumerate(alignment):
if timing.word:
segment = segments[token_sources[word_boundaries[i]]]
start = round(time_offset + timing.start, 2)
end = round(time_offset + timing.end, 2)
segment["words"].append(
dict(
word=timing.word,
start=start,
end=end,
probability=timing.probability,
)
)
for segment in segments:
if len(words := segment["words"]) > 0:
# adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"]
+82 -33
@@ -1,6 +1,7 @@
import os
import string
from dataclasses import dataclass
from functools import lru_cache
from functools import cached_property, lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -137,7 +138,9 @@ class Tokenizer:
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
def decode(
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
):
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
@@ -153,50 +156,51 @@ class Tokenizer:
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)
return "".join(
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
)
@property
@lru_cache()
@cached_property
def eot(self) -> int:
return self.tokenizer.eos_token_id
@property
@lru_cache()
@cached_property
def transcribe(self) -> int:
return self._get_single_token_id("<|transcribe|>")
@cached_property
def translate(self) -> int:
return self._get_single_token_id("<|translate|>")
@cached_property
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
@property
@lru_cache()
@cached_property
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
@property
@lru_cache()
@cached_property
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
@property
@lru_cache()
@cached_property
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
@property
@lru_cache()
@cached_property
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
@property
@lru_cache()
@cached_property
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
@property
@lru_cache()
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(f"This tokenizer does not have language token configured")
raise ValueError("This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
@@ -210,8 +214,7 @@ class Tokenizer:
raise KeyError(f"Language {self.language} not found in tokenizer.")
@property
@lru_cache()
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
@@ -222,18 +225,15 @@ class Tokenizer:
result.append(token_id)
return tuple(result)
@property
@lru_cache()
@cached_property
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
@property
@lru_cache()
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@property
@lru_cache()
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
@@ -245,8 +245,10 @@ class Tokenizer:
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@@ -258,7 +260,10 @@ class Tokenizer:
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
for tokens in [
self.tokenizer.encode(symbol),
self.tokenizer.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
@@ -269,6 +274,48 @@ class Tokenizer:
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
return self.split_tokens_on_unicode(tokens)
return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]):
words = []
word_tokens = []
current_tokens = []
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if "\ufffd" not in decoded:
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
return words, word_tokens
def split_tokens_on_spaces(self, tokens: List[int]):
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []
for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
return words, word_tokens
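The grouping rule in `split_tokens_on_spaces` can be tried on already-decoded subword strings. The sketch below drops the token-id bookkeeping and the special-token check, so it is a simplification under those assumptions; `group_subwords` is a hypothetical name.

```python
import string


def group_subwords(subwords):
    """Merge decoded subwords into words: a new word starts at a leading space or punctuation."""
    words = []
    for subword in subwords:
        with_space = subword.startswith(" ")
        punctuation = subword.strip() in string.punctuation
        if with_space or punctuation or not words:
            words.append(subword)
        else:
            # continuation of the previous word
            words[-1] = words[-1] + subword
    return words
```

Punctuation stays as its own entry at this stage; attaching it to a neighbouring word is left to `merge_punctuations` in `timing.py`.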
@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
@@ -328,4 +375,6 @@ def get_tokenizer(
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
return Tokenizer(
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
)
+209 -85
@@ -1,16 +1,32 @@
import argparse
import os
import warnings
from typing import Optional, Tuple, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
from .utils import (
exact_div,
format_timestamp,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING:
from .model import Whisper
@@ -26,6 +42,10 @@ def transcribe(
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
@@ -62,6 +82,21 @@ def transcribe(
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those words correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
@@ -88,19 +123,28 @@ def transcribe(
decode_options["language"] = "en"
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language = decode_options["language"]
task = decode_options.get("task", "transcribe")
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
@@ -117,9 +161,15 @@ def transcribe(
decode_result = model.decode(segment, options)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if not needs_fallback:
@@ -138,117 +188,187 @@ def transcribe(
all_segments = []
prompt_reset_since = 0
initial_prompt = decode_options.pop("initial_prompt", None) or []
if initial_prompt:
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt)
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
if len(text.strip()) == 0: # skip empty text output
return
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
return {
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": text_tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
all_segments.append(
{
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": text,
"tokens": text_tokens.tolist(),
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
# show the progress bar when verbose is False (if True, transcribed text will be printed)
num_frames = mel.shape[-1]
previous_seek_value = seek
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
with tqdm.tqdm(
total=num_frames, unit="frames", disable=verbose is not False
) as pbar:
while seek < num_frames:
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek:]
segment_size = min(mel_segment.shape[-1], N_FRAMES)
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(segment)
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment.shape[-1] # fast-forward to the next segment boundary
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
current_tokens = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
0
].add_(1)
if (
len(consecutive) > 0
): # if the output contains two consecutive timestamp tokens
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
False,
True,
]:
consecutive = consecutive.tolist() + [len(tokens)]
last_slice = 0
for current_slice in consecutive:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_position = (
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
add_segment(
start=timestamp_offset + start_timestamp_position * time_precision,
end=timestamp_offset + end_timestamp_position * time_precision,
text_tokens=sliced_tokens[1:-1],
result=result,
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
current_tokens.append(sliced_tokens.tolist())
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_position * input_stride
if ended_with_single_timestamp:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_position * time_precision
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
text_tokens=tokens,
result=result,
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
seek += segment.shape[-1]
all_tokens.extend(tokens.tolist())
current_tokens.append(tokens.tolist())
seek += segment_size
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek_value)
previous_seek_value = seek
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(consecutive) > 0 and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
current_tokens[i] = []
all_segments.extend(current_segments)
all_tokens.extend(
[token for segment in current_tokens for token in segment]
)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)
def cli():
from . import available_models
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
@@ -276,7 +396,11 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
@@ -288,29 +412,29 @@ def cli():
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
warnings.warn(
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
)
args["language"] = "en"
temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
temperature = [temperature]
threads = args.pop("threads")
if threads > 0:
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path)
if __name__ == '__main__':
if __name__ == "__main__":
cli()
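The segment-slicing step in the `transcribe()` loop above (cut wherever two timestamp tokens are adjacent, and keep a trailing segment that ends with a single timestamp) can be illustrated with plain lists. This is a sketch under stated assumptions: `TIMESTAMP_BEGIN` and `TIME_PRECISION` are made-up stand-ins for `tokenizer.timestamp_begin` and `time_precision`, text tokens are filtered with `< TIMESTAMP_BEGIN` rather than `< tokenizer.eot`, and the branch without consecutive timestamps is left out.

```python
from typing import List, Tuple

TIMESTAMP_BEGIN = 1000  # hypothetical id of the first timestamp token
TIME_PRECISION = 0.02   # assumed seconds per timestamp step

Segment = Tuple[float, float, List[int]]  # (start, end, text tokens)


def slice_segments(tokens: List[int], time_offset: float = 0.0) -> List[Segment]:
    is_ts = [t >= TIMESTAMP_BEGIN for t in tokens]
    # Index i is a cut point when tokens[i - 1] and tokens[i] are both timestamps.
    consecutive = [i for i in range(1, len(tokens)) if is_ts[i - 1] and is_ts[i]]
    if not consecutive:
        return []  # the diff treats this case as one whole-window segment
    if is_ts[-2:] == [False, True]:
        # A single timestamp at the very end closes the last segment.
        consecutive.append(len(tokens))

    segments: List[Segment] = []
    last_slice = 0
    for current_slice in consecutive:
        sliced = tokens[last_slice:current_slice]
        start = time_offset + (sliced[0] - TIMESTAMP_BEGIN) * TIME_PRECISION
        end = time_offset + (sliced[-1] - TIMESTAMP_BEGIN) * TIME_PRECISION
        text = [t for t in sliced if t < TIMESTAMP_BEGIN]
        segments.append((start, end, text))
        last_slice = current_slice
    return segments


if __name__ == "__main__":
    for seg in slice_segments([1000, 1, 2, 1050, 1050, 3, 1100]):
        print(seg)
```

When the output does not end with a lone timestamp, the tokens after the last cut point are dropped from the result, mirroring the "ignore the unfinished segment" comment in the diff.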
+109
@@ -0,0 +1,109 @@
from functools import lru_cache
import numpy as np
import torch
try:
import triton
import triton.language as tl
except ImportError:
raise RuntimeError("triton import failed; try `pip install --pre triton`")
@triton.jit
def dtw_kernel(
cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < M
for k in range(1, N + M + 1): # k = i + j
tl.debug_barrier()
p0 = cost + (k - 1) * cost_stride
p1 = cost + k * cost_stride
p2 = cost + k * cost_stride + 1
c0 = tl.load(p0 + offsets, mask=mask)
c1 = tl.load(p1 + offsets, mask=mask)
c2 = tl.load(p2 + offsets, mask=mask)
x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
cost_ptr = cost + (k + 1) * cost_stride + 1
tl.store(cost_ptr + offsets, cost_row, mask=mask)
trace_ptr = trace + (k + 1) * trace_stride + 1
tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
@triton.jit
def kernel(
y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
): # x.shape[-1] == filter_width
row_idx = tl.program_id(0)
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < y_stride
x_ptr = x + row_idx * x_stride # noqa: F841
y_ptr = y + row_idx * y_stride
LOAD_ALL_ROWS_HERE # noqa: F821
BUBBLESORT_HERE # noqa: F821
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
kernel = triton.JITFunction(kernel.fn)
kernel.src = kernel.src.replace(
" LOAD_ALL_ROWS_HERE",
"\n".join(
[
f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
for i in range(filter_width)
]
),
)
kernel.src = kernel.src.replace(
" BUBBLESORT_HERE",
"\n\n".join(
[
"\n\n".join(
[
"\n".join(
[
f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
f" row{j} = smaller",
f" row{j + 1} = larger",
]
)
for j in range(filter_width - i - 1)
]
)
for i in range(filter_width // 2 + 1)
]
),
)
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
return kernel
def median_filter_cuda(x: torch.Tensor, filter_width: int):
"""Apply a median filter of given width along the last dimension of x"""
slices = x.contiguous().unfold(-1, filter_width, 1)
grid = np.prod(slices.shape[:-2])
kernel = median_kernel(filter_width)
y = torch.empty_like(slices[..., 0])
BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
return y
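The kernel source generated by `median_kernel` above runs `filter_width // 2 + 1` bubble-sort passes over the `row{i}` values and then reads `row{filter_width // 2}`. A pure-Python sketch of the same idea follows; the function names are hypothetical, and it operates on one list at a time rather than on Triton blocks.

```python
from typing import List


def median_by_partial_bubble_sort(window: List[float]) -> float:
    """Return the middle element after only as many passes as needed."""
    rows = list(window)
    n = len(rows)
    # Each pass moves the largest remaining value to the end, so after
    # n // 2 + 1 passes the element at index n // 2 is in its final place.
    for i in range(n // 2 + 1):
        for j in range(n - i - 1):
            smaller = min(rows[j], rows[j + 1])
            larger = max(rows[j], rows[j + 1])
            rows[j], rows[j + 1] = smaller, larger
    return rows[n // 2]


def sliding_median(x: List[float], filter_width: int) -> List[float]:
    """Median over every window of `filter_width` consecutive values (stride 1)."""
    return [
        median_by_partial_bubble_sort(x[i : i + filter_width])
        for i in range(len(x) - filter_width + 1)
    ]


if __name__ == "__main__":
    print(sliding_median([1, 9, 2, 8, 3], 3))
```

Stopping after half the passes is enough because only the middle position has to be correct; the lower half of the window can stay unsorted.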
+69 -27
@@ -7,11 +7,14 @@ from typing import Callable, TextIO
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with an '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
@@ -43,7 +46,9 @@ def compression_ratio(text) -> float:
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
@@ -57,7 +62,9 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
class ResultWriter:
@@ -68,7 +75,10 @@ class ResultWriter:
def __call__(self, result: dict, audio_path: str):
audio_basename = os.path.basename(audio_path)
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
@@ -82,37 +92,69 @@ class WriteTXT(ResultWriter):
def write_result(self, result: dict, file: TextIO):
for segment in result["segments"]:
print(segment['text'].strip(), file=file, flush=True)
print(segment["text"].strip(), file=file, flush=True)
class WriteVTT(ResultWriter):
class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
def iterate_result(self, result: dict):
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
if word_timings := segment.get("words", None):
all_words = [timing["word"] for timing in word_timings]
all_words[0] = all_words[0].strip() # remove the leading space, if any
last = segment_start
for i, this_word in enumerate(word_timings):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, segment_text
yield start, end, "".join(
[
f"<u>{word}</u>" if j == i else word
for j, word in enumerate(all_words)
]
)
last = end
if last != segment_end:
yield last, segment_end, segment_text
else:
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
return format_timestamp(
seconds=seconds,
always_include_hours=self.always_include_hours,
decimal_marker=self.decimal_marker,
)
class WriteVTT(SubtitlesWriter):
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO):
print("WEBVTT\n", file=file)
for segment in result["segments"]:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
for start, end, text in self.iterate_result(result):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteSRT(ResultWriter):
class WriteSRT(SubtitlesWriter):
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO):
for i, segment in enumerate(result["segments"], start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteTSV(ResultWriter):
@@ -124,14 +166,15 @@ class WriteTSV(ResultWriter):
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment['start']), file=file, end="\t")
print(round(1000 * segment['end']), file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
@@ -160,4 +203,3 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
return write_all
return writers[output_format](output_dir)
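To see what `SubtitlesWriter.iterate_result` above yields for a segment that carries word timings, the sketch below reproduces its control flow for a single segment. It is an illustration only: `iterate_cues` is a hypothetical name, times are kept as raw floats instead of going through `format_timestamp`, and the `-->` replacement is omitted.

```python
from typing import Dict, Iterator, Tuple

Cue = Tuple[float, float, str]


def iterate_cues(segment: Dict) -> Iterator[Cue]:
    """Yield (start, end, text) cues for one segment (hypothetical helper)."""
    text = segment["text"].strip()
    words = segment.get("words")
    if not words:
        # No word timings: one cue for the whole segment.
        yield segment["start"], segment["end"], text
        return

    all_words = [w["word"] for w in words]
    all_words[0] = all_words[0].strip()  # remove the leading space, if any
    last = segment["start"]
    for i, this_word in enumerate(words):
        start, end = this_word["start"], this_word["end"]
        if last != start:
            # Gap before this word: show the plain text without highlighting.
            yield last, start, text
        # Underline only the word currently being spoken.
        yield start, end, "".join(
            f"<u>{word}</u>" if j == i else word
            for j, word in enumerate(all_words)
        )
        last = end
    if last != segment["end"]:
        yield last, segment["end"], text


if __name__ == "__main__":
    example = {
        "start": 0.0,
        "end": 2.0,
        "text": " Hello world",
        "words": [
            {"word": " Hello", "start": 0.0, "end": 0.8},
            {"word": " world", "start": 1.0, "end": 2.0},
        ],
    }
    for cue in iterate_cues(example):
        print(cue)
```

Each word produces its own cue, and any gap between words produces an extra cue with the unhighlighted text, so a segment with word timings expands into several subtitle entries.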
+1 -1
@@ -1 +1 @@
__version__ = "20230124"
__version__ = "20230306"