Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6dea21fd7f | |||
| 79c43e4859 | |||
| 5f9ac653b7 | |||
| ba88b8e1b3 | |||
| 671ac5a4ce | |||
| 839639a223 | |||
| ad3250a846 | |||
| c4b50c0824 | |||
| 38f2f4d99d | |||
| aac47c9834 | |||
| 26807ec6d3 | |||
| 919a713499 | |||
| 38e990d853 | |||
| 924e1f8e06 | |||
| 4b0d5e58d0 | |||
| 8180fde939 | |||
| c6e4e5efb3 | |||
| b80bcf610d | |||
| 500d0fe966 | |||
| eab8d920ed | |||
| 3e1780fd37 | |||
| 7858aa9c08 | |||
| 5c1a8c10e7 | |||
| 4e635c6644 | |||
| a6b36ede1f |
@@ -0,0 +1,3 @@
|
||||
# Override jupyter in Github language stats for more accurate estimate of repo code languages
|
||||
# reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
|
||||
*.ipynb linguist-generated
|
||||
@@ -21,6 +21,8 @@ jobs:
|
||||
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
|
||||
- uses: actions/checkout@v2
|
||||
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
||||
- run: pip install pytest
|
||||
- run: pip install .
|
||||
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
|
||||
- run: pip install .["dev"]
|
||||
- run: black --check --diff -t py38 --include '(\.pyi?)$' .
|
||||
- run: isort --check --diff .
|
||||
- run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
|
||||
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
# CHANGELOG
|
||||
|
||||
## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
|
||||
|
||||
* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
|
||||
* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
|
||||
* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
|
||||
* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
|
||||
* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
|
||||
|
||||
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
|
||||
|
||||
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
|
||||
* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
|
||||
* fix typo in CHANGELOG.md
|
||||
|
||||
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
|
||||
|
||||
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
|
||||
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
|
||||
* Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051))
|
||||
* update setup.py to specify python >= 3.8 requirement
|
||||
|
||||
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
|
||||
|
||||
* remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021))
|
||||
* apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038))
|
||||
* word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869))
|
||||
* Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033))
|
||||
* Update README.md ([#894](https://github.com/openai/whisper/pull/894))
|
||||
* Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914))
|
||||
* drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889))
|
||||
|
||||
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
|
||||
|
||||
* handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887))
|
||||
* Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228))
|
||||
* Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333))
|
||||
* Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864))
|
||||
* use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867))
|
||||
* Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659))
|
||||
* print '?' if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859))
|
||||
|
||||
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
|
||||
|
||||
The first versioned release available on [PyPI](https://pypi.org/project/openai-whisper/)
|
||||
@@ -2,6 +2,4 @@ include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include whisper/assets/*
|
||||
include whisper/assets/gpt2/*
|
||||
include whisper/assets/multilingual/*
|
||||
include whisper/normalizers/english.json
|
||||
|
||||
@@ -5,19 +5,19 @@
|
||||
[[Model card]](https://github.com/openai/whisper/blob/main/model-card.md)
|
||||
[[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb)
|
||||
|
||||
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
|
||||
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
|
||||
|
||||
|
||||
## Approach
|
||||
|
||||

|
||||
|
||||
A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. All of these tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing for a single model to replace many different stages of a traditional speech processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
|
||||
A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
|
||||
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
|
||||
|
||||
pip install -U openai-whisper
|
||||
|
||||
@@ -68,9 +68,9 @@ There are five model sizes, four with English-only versions, offering speed and
|
||||
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
|
||||
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
|
||||
|
||||
For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
||||
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
||||
|
||||
Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of Fleurs dataset, using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller is better.
|
||||
Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of the Fleurs dataset using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller, the better.
|
||||
|
||||

|
||||
|
||||
@@ -144,4 +144,4 @@ Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussion
|
||||
|
||||
## License
|
||||
|
||||
The code and the model weights of Whisper are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
|
||||
Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
[tool.black]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
include_trailing_comma = true
|
||||
line_length = 88
|
||||
multi_line_output = 3
|
||||
|
||||
+2
-1
@@ -1,6 +1,7 @@
|
||||
numba
|
||||
numpy
|
||||
torch
|
||||
tqdm
|
||||
more-itertools
|
||||
transformers>=4.19.0
|
||||
tiktoken==0.3.1
|
||||
ffmpeg-python==0.2.0
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import pkg_resources
|
||||
from setuptools import setup, find_packages
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def read_version(fname="whisper/version.py"):
|
||||
@@ -9,6 +11,10 @@ def read_version(fname="whisper/version.py"):
|
||||
return locals()["__version__"]
|
||||
|
||||
|
||||
requirements = []
|
||||
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
|
||||
requirements.append("triton==2.0.0")
|
||||
|
||||
setup(
|
||||
name="openai-whisper",
|
||||
py_modules=["whisper"],
|
||||
@@ -17,12 +23,13 @@ setup(
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
readme="README.md",
|
||||
python_requires=">=3.7",
|
||||
python_requires=">=3.8",
|
||||
author="OpenAI",
|
||||
url="https://github.com/openai/whisper",
|
||||
license="MIT",
|
||||
packages=find_packages(exclude=["tests*"]),
|
||||
install_requires=[
|
||||
install_requires=requirements
|
||||
+ [
|
||||
str(r)
|
||||
for r in pkg_resources.parse_requirements(
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||
@@ -32,5 +39,5 @@ setup(
|
||||
"console_scripts": ["whisper=whisper.transcribe:cli"],
|
||||
},
|
||||
include_package_data=True,
|
||||
extras_require={"dev": ["pytest"]},
|
||||
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import random as rand
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "requires_cuda")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random():
|
||||
rand.seed(42)
|
||||
numpy.random.seed(42)
|
||||
+1
-1
@@ -2,7 +2,7 @@ import os.path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
|
||||
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
|
||||
|
||||
def test_audio():
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from whisper.normalizers import EnglishTextNormalizer
|
||||
from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
|
||||
from whisper.normalizers.english import (
|
||||
EnglishNumberNormalizer,
|
||||
EnglishSpellingNormalizer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.ndimage
|
||||
import torch
|
||||
|
||||
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
|
||||
|
||||
sizes = [
|
||||
(10, 20),
|
||||
(32, 16),
|
||||
(123, 1500),
|
||||
(234, 189),
|
||||
]
|
||||
shapes = [
|
||||
(10,),
|
||||
(1, 15),
|
||||
(4, 5, 345),
|
||||
(6, 12, 240, 512),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("N, M", sizes)
|
||||
def test_dtw(N: int, M: int):
|
||||
steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
|
||||
np.random.shuffle(steps)
|
||||
x = np.random.random((N, M)).astype(np.float32)
|
||||
|
||||
i, j, k = 0, 0, 0
|
||||
trace = []
|
||||
while True:
|
||||
x[i, j] -= 1
|
||||
trace.append((i, j))
|
||||
|
||||
if k == len(steps):
|
||||
break
|
||||
|
||||
if k + 1 < len(steps) and steps[k] != steps[k + 1]:
|
||||
i += 1
|
||||
j += 1
|
||||
k += 2
|
||||
continue
|
||||
|
||||
if steps[k] == 0:
|
||||
i += 1
|
||||
if steps[k] == 1:
|
||||
j += 1
|
||||
k += 1
|
||||
|
||||
trace = np.array(trace).T
|
||||
dtw_trace = dtw_cpu(x)
|
||||
|
||||
assert np.allclose(trace, dtw_trace)
|
||||
|
||||
|
||||
@pytest.mark.requires_cuda
|
||||
@pytest.mark.parametrize("N, M", sizes)
|
||||
def test_dtw_cuda_equivalence(N: int, M: int):
|
||||
x_numpy = np.random.randn(N, M).astype(np.float32)
|
||||
x_cuda = torch.from_numpy(x_numpy).cuda()
|
||||
|
||||
trace_cpu = dtw_cpu(x_numpy)
|
||||
trace_cuda = dtw_cuda(x_cuda)
|
||||
|
||||
assert np.allclose(trace_cpu, trace_cuda)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", shapes)
|
||||
def test_median_filter(shape):
|
||||
x = torch.randn(*shape)
|
||||
|
||||
for filter_width in [3, 5, 7, 13]:
|
||||
filtered = median_filter(x, filter_width)
|
||||
|
||||
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
|
||||
pad_width = filter_width // 2
|
||||
padded_x = np.pad(
|
||||
x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
|
||||
)
|
||||
scipy_filtered = scipy.ndimage.median_filter(
|
||||
padded_x, [1] * (x.ndim - 1) + [filter_width]
|
||||
)
|
||||
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
|
||||
|
||||
assert np.allclose(filtered, scipy_filtered)
|
||||
|
||||
|
||||
@pytest.mark.requires_cuda
|
||||
@pytest.mark.parametrize("shape", shapes)
|
||||
def test_median_filter_equivalence(shape):
|
||||
x = torch.randn(*shape)
|
||||
|
||||
for filter_width in [3, 5, 7, 13]:
|
||||
filtered_cpu = median_filter(x, filter_width)
|
||||
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
|
||||
|
||||
assert np.allclose(filtered_cpu, filtered_gpu)
|
||||
@@ -12,3 +12,13 @@ def test_tokenizer():
|
||||
assert gpt2_tokenizer.decode(gpt2_tokens) == text
|
||||
assert multilingual_tokenizer.decode(multilingual_tokens) == text
|
||||
assert len(gpt2_tokens) > len(multilingual_tokens)
|
||||
|
||||
|
||||
def test_split_on_unicode():
|
||||
multilingual_tokenizer = get_tokenizer(multilingual=True)
|
||||
|
||||
tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
|
||||
words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
|
||||
|
||||
assert words == [" elle", " est", " l", "'", "�", "é", "rit", "oire"]
|
||||
assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import whisper
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", whisper.available_models())
|
||||
@@ -13,10 +14,29 @@ def test_transcribe(model_name: str):
|
||||
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
||||
|
||||
language = "en" if model_name.endswith(".en") else None
|
||||
result = model.transcribe(audio_path, language=language, temperature=0.0)
|
||||
result = model.transcribe(
|
||||
audio_path, language=language, temperature=0.0, word_timestamps=True
|
||||
)
|
||||
assert result["language"] == "en"
|
||||
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
||||
|
||||
transcription = result["text"].lower()
|
||||
assert "my fellow americans" in transcription
|
||||
assert "your country" in transcription
|
||||
assert "do for you" in transcription
|
||||
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
|
||||
assert tokenizer.decode(all_tokens) == result["text"]
|
||||
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
|
||||
|
||||
timing_checked = False
|
||||
for segment in result["segments"]:
|
||||
for timing in segment["words"]:
|
||||
assert timing["start"] < timing["end"]
|
||||
if timing["word"].strip(" ,") == "Americans":
|
||||
assert timing["start"] <= 1.8
|
||||
assert timing["end"] >= 1.8
|
||||
timing_checked = True
|
||||
|
||||
assert timing_checked
|
||||
|
||||
+49
-17
@@ -10,11 +10,10 @@ from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import Whisper, ModelDimensions
|
||||
from .model import ModelDimensions, Whisper
|
||||
from .transcribe import transcribe
|
||||
from .version import __version__
|
||||
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
@@ -29,6 +28,22 @@ _MODELS = {
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
||||
_ALIGNMENT_HEADS = {
|
||||
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
||||
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
||||
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
||||
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
||||
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
||||
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
||||
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
@@ -45,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
@@ -59,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
@@ -69,7 +94,12 @@ def available_models() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
@@ -94,24 +124,23 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
download_root = os.path.join(
|
||||
os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(
|
||||
os.path.expanduser("~"), ".cache"
|
||||
)
|
||||
),
|
||||
"whisper"
|
||||
)
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
alignment_heads = None
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
@@ -119,4 +148,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
|
||||
model = Whisper(dims)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from .transcribe import cli
|
||||
|
||||
|
||||
cli()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
||||
@@ -1 +0,0 @@
|
||||
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
||||
{"<|endoftext|>": 50257}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
||||
@@ -1 +0,0 @@
|
||||
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
|
||||
File diff suppressed because one or more lines are too long
+29
-6
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
@@ -15,8 +15,12 @@ N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
|
||||
|
||||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
|
||||
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
@@ -55,7 +59,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
if torch.is_tensor(array):
|
||||
if array.shape[axis] > length:
|
||||
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
|
||||
array = array.index_select(
|
||||
dim=axis, index=torch.arange(length, device=array.device)
|
||||
)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
@@ -85,11 +91,18 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
|
||||
with np.load(
|
||||
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = N_MELS,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of
|
||||
|
||||
@@ -101,6 +114,12 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
|
||||
device: Optional[Union[str, torch.device]]
|
||||
If given, the audio tensor is moved to this device before STFT
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
@@ -111,6 +130,10 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
if device is not None:
|
||||
audio = audio.to(device)
|
||||
if padding > 0:
|
||||
audio = F.pad(audio, (0, padding))
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
+171
-68
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -16,7 +16,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
||||
def detect_language(
|
||||
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
|
||||
) -> Tuple[Tensor, List[dict]]:
|
||||
"""
|
||||
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||
of the most probable language tokens and the probability distribution over all language tokens.
|
||||
@@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
||||
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
||||
if (
|
||||
tokenizer.language is None
|
||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||
):
|
||||
raise ValueError(
|
||||
"This model doesn't have language tokens so it can't perform lang id"
|
||||
)
|
||||
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
@@ -70,31 +77,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingOptions:
|
||||
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
||||
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
||||
# whether to perform X->X "transcribe" or X->English "translate"
|
||||
task: str = "transcribe"
|
||||
|
||||
# language that the audio is in; uses detected language if None
|
||||
language: Optional[str] = None
|
||||
|
||||
# sampling-related options
|
||||
temperature: float = 0.0
|
||||
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
||||
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
||||
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
|
||||
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
|
||||
|
||||
# options for ranking generations (either beams or best-of-N samples)
|
||||
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
||||
# "alpha" in Google NMT, or None for length norm, when ranking generations
|
||||
# to select which to return among the beams or best-of-N samples
|
||||
length_penalty: Optional[float] = None
|
||||
|
||||
# prompt, prefix, and token suppression
|
||||
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
# text or tokens to feed as the prompt or the prefix; for more info:
|
||||
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
prompt: Optional[Union[str, List[int]]] = None # for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
|
||||
|
||||
# list of tokens ids (or comma-separated token ids) to suppress
|
||||
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
@@ -158,7 +170,9 @@ class PyTorchInference(Inference):
|
||||
|
||||
|
||||
class SequenceRanker:
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
||||
def rank(
|
||||
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
|
||||
) -> List[int]:
|
||||
"""
|
||||
Given a list of groups of samples and their cumulative log probabilities,
|
||||
return the indices of the samples in each group to select as the final result
|
||||
@@ -196,7 +210,9 @@ class TokenDecoder:
|
||||
def reset(self):
|
||||
"""Initialize any stateful variables for decoding a new sequence"""
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
"""Specify how to select the next token, based on the current trace and logits
|
||||
|
||||
Parameters
|
||||
@@ -251,12 +267,13 @@ class GreedyDecoder(TokenDecoder):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
temperature = self.temperature
|
||||
if temperature == 0:
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
if self.temperature == 0:
|
||||
next_tokens = logits.argmax(dim=-1)
|
||||
else:
|
||||
next_tokens = Categorical(logits=logits / temperature).sample()
|
||||
next_tokens = Categorical(logits=logits / self.temperature).sample()
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||
@@ -275,7 +292,13 @@ class GreedyDecoder(TokenDecoder):
|
||||
|
||||
|
||||
class BeamSearchDecoder(TokenDecoder):
|
||||
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
|
||||
def __init__(
|
||||
self,
|
||||
beam_size: int,
|
||||
eot: int,
|
||||
inference: Inference,
|
||||
patience: Optional[float] = None,
|
||||
):
|
||||
self.beam_size = beam_size
|
||||
self.eot = eot
|
||||
self.inference = inference
|
||||
@@ -283,12 +306,16 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
self.max_candidates: int = round(beam_size * self.patience)
|
||||
self.finished_sequences = None
|
||||
|
||||
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
assert (
|
||||
self.max_candidates > 0
|
||||
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
|
||||
def reset(self):
|
||||
self.finished_sequences = None
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
if tokens.shape[0] % self.beam_size != 0:
|
||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||
|
||||
@@ -332,7 +359,9 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
|
||||
# add newly finished sequences to self.finished_sequences
|
||||
assert len(self.finished_sequences) == len(finished_sequences)
|
||||
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
||||
for previously_finished, newly_finished in zip(
|
||||
self.finished_sequences, finished_sequences
|
||||
):
|
||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||
if len(previously_finished) >= self.max_candidates:
|
||||
break # the candidate list is full
|
||||
@@ -340,7 +369,8 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
|
||||
# mark as completed if all audio has enough number of samples
|
||||
completed = all(
|
||||
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
|
||||
len(sequences) >= self.max_candidates
|
||||
for sequences in self.finished_sequences
|
||||
)
|
||||
return tokens, completed
|
||||
|
||||
@@ -348,7 +378,9 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
||||
sum_logprobs = sum_logprobs.cpu()
|
||||
for i, sequences in enumerate(self.finished_sequences):
|
||||
if len(sequences) < self.beam_size: # when not enough sequences are finished
|
||||
if (
|
||||
len(sequences) < self.beam_size
|
||||
): # when not enough sequences are finished
|
||||
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||
@@ -356,7 +388,8 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
break
|
||||
|
||||
tokens: List[List[Tensor]] = [
|
||||
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
|
||||
[torch.tensor(seq) for seq in sequences.keys()]
|
||||
for sequences in self.finished_sequences
|
||||
]
|
||||
sum_logprobs: List[List[float]] = [
|
||||
list(sequences.values()) for sequences in self.finished_sequences
|
||||
@@ -400,7 +433,10 @@ class SuppressTokens(LogitFilter):
|
||||
|
||||
class ApplyTimestampRules(LogitFilter):
|
||||
def __init__(
|
||||
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
|
||||
self,
|
||||
tokenizer: Tokenizer,
|
||||
sample_begin: int,
|
||||
max_initial_timestamp_index: Optional[int],
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
@@ -413,9 +449,14 @@ class ApplyTimestampRules(LogitFilter):
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
||||
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
sampled_tokens = tokens[k, self.sample_begin :]
|
||||
seq = [t for t in sampled_tokens.tolist()]
|
||||
last_was_timestamp = (
|
||||
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
penultimate_was_timestamp = (
|
||||
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
|
||||
if last_was_timestamp:
|
||||
if penultimate_was_timestamp: # has to be non-timestamp
|
||||
@@ -423,19 +464,30 @@ class ApplyTimestampRules(LogitFilter):
|
||||
else: # cannot be normal text tokens
|
||||
logits[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
timestamps = sampled_tokens[
|
||||
sampled_tokens.ge(self.tokenizer.timestamp_begin)
|
||||
]
|
||||
if timestamps.numel() > 0:
|
||||
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
|
||||
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if self.max_initial_timestamp_index is not None:
|
||||
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
last_allowed = (
|
||||
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
)
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
|
||||
dim=-1
|
||||
)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
@@ -451,7 +503,9 @@ class DecodingTask:
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, language=language, task=options.task
|
||||
)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
@@ -491,9 +545,13 @@ class DecodingTask:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
if options.max_initial_timestamp:
|
||||
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
|
||||
max_initial_timestamp_index = round(
|
||||
self.options.max_initial_timestamp / precision
|
||||
)
|
||||
self.logit_filters.append(
|
||||
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
|
||||
ApplyTimestampRules(
|
||||
tokenizer, self.sample_begin, max_initial_timestamp_index
|
||||
)
|
||||
)
|
||||
|
||||
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||
@@ -504,30 +562,38 @@ class DecodingTask:
|
||||
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||
if options.patience is not None and options.beam_size is None:
|
||||
raise ValueError("patience requires beam_size to be given")
|
||||
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
|
||||
if options.length_penalty is not None and not (
|
||||
0 <= options.length_penalty <= 1
|
||||
):
|
||||
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||
|
||||
return options
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
prefix = self.options.prefix
|
||||
prompt = self.options.prompt
|
||||
|
||||
if prefix:
|
||||
if prefix := self.options.prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
|
||||
self.tokenizer.encode(" " + prefix.strip())
|
||||
if isinstance(prefix, str)
|
||||
else prefix
|
||||
)
|
||||
if self.sample_len is not None:
|
||||
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt:
|
||||
if prompt := self.options.prompt:
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
|
||||
self.tokenizer.encode(" " + prompt.strip())
|
||||
if isinstance(prompt, str)
|
||||
else prompt
|
||||
)
|
||||
tokens = (
|
||||
[self.tokenizer.sot_prev]
|
||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||
+ tokens
|
||||
)
|
||||
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
@@ -546,7 +612,13 @@ class DecodingTask:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens.extend(
|
||||
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
||||
[
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
]
|
||||
)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
# no-speech probability is collected separately
|
||||
@@ -558,14 +630,21 @@ class DecodingTask:
|
||||
if self.options.fp16:
|
||||
mel = mel.half()
|
||||
|
||||
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||
if mel.shape[-2:] == (
|
||||
self.model.dims.n_audio_ctx,
|
||||
self.model.dims.n_audio_state,
|
||||
):
|
||||
# encoded audio features are given; skip audio encoding
|
||||
audio_features = mel
|
||||
else:
|
||||
audio_features = self.model.encoder(mel)
|
||||
|
||||
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
|
||||
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
||||
if audio_features.dtype != (
|
||||
torch.float16 if self.options.fp16 else torch.float32
|
||||
):
|
||||
return TypeError(
|
||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||
)
|
||||
|
||||
return audio_features
|
||||
|
||||
@@ -574,7 +653,9 @@ class DecodingTask:
|
||||
lang_probs = None
|
||||
|
||||
if self.options.language is None or self.options.task == "lang_id":
|
||||
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
||||
lang_tokens, lang_probs = self.model.detect_language(
|
||||
audio_features, self.tokenizer
|
||||
)
|
||||
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||
if self.options.language is None:
|
||||
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||
@@ -591,7 +672,9 @@ class DecodingTask:
|
||||
for i in range(self.sample_len):
|
||||
logits = self.inference.logits(tokens, audio_features)
|
||||
|
||||
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
||||
if (
|
||||
i == 0 and self.tokenizer.no_speech is not None
|
||||
): # save no_speech_probs
|
||||
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
@@ -625,8 +708,12 @@ class DecodingTask:
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
return [
|
||||
DecodingResult(audio_features=features, language=language, language_probs=probs)
|
||||
for features, language, probs in zip(audio_features, languages, language_probs)
|
||||
DecodingResult(
|
||||
audio_features=features, language=language, language_probs=probs
|
||||
)
|
||||
for features, language, probs in zip(
|
||||
audio_features, languages, language_probs
|
||||
)
|
||||
]
|
||||
|
||||
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
|
||||
@@ -647,7 +734,8 @@ class DecodingTask:
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens: List[List[Tensor]] = [
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
|
||||
for s in tokens
|
||||
]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
@@ -656,9 +744,18 @@ class DecodingTask:
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [
|
||||
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
|
||||
]
|
||||
|
||||
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
|
||||
fields = (
|
||||
texts,
|
||||
languages,
|
||||
tokens,
|
||||
audio_features,
|
||||
avg_logprobs,
|
||||
no_speech_probs,
|
||||
)
|
||||
if len(set(map(len, fields))) != 1:
|
||||
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||
|
||||
@@ -673,12 +770,19 @@ class DecodingTask:
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
|
||||
*fields
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
def decode(
|
||||
model: "Whisper",
|
||||
mel: Tensor,
|
||||
options: DecodingOptions = DecodingOptions(),
|
||||
**kwargs,
|
||||
) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
"""
|
||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||
|
||||
@@ -698,13 +802,12 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
|
||||
result: Union[DecodingResult, List[DecodingResult]]
|
||||
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||
"""
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
if single := mel.ndim == 2:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
|
||||
if single:
|
||||
result = result[0]
|
||||
if kwargs:
|
||||
options = replace(options, **kwargs)
|
||||
|
||||
return result
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
|
||||
return result[0] if single else result
|
||||
|
||||
+59
-18
@@ -1,15 +1,16 @@
|
||||
import base64
|
||||
import gzip
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from typing import Iterable, Optional
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .decoding import decode as decode_function
|
||||
from .decoding import detect_language as detect_language_function
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
from .decoding import detect_language as detect_language_function, decode as decode_function
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,12 +35,16 @@ class LayerNorm(nn.LayerNorm):
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
||||
x,
|
||||
self.weight.to(x.dtype),
|
||||
None if self.bias is None else self.bias.to(x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
||||
def _conv_forward(
|
||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
@@ -85,7 +90,9 @@ class MultiHeadAttention(nn.Module):
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
@@ -108,11 +115,15 @@ class ResidualAttentionBlock(nn.Module):
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
)
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
||||
self.mlp = nn.Sequential(
|
||||
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||
)
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
@@ -130,7 +141,9 @@ class ResidualAttentionBlock(nn.Module):
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
def __init__(
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
@@ -161,14 +174,19 @@ class AudioEncoder(nn.Module):
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
def __init__(
|
||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
||||
[
|
||||
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
)
|
||||
self.ln = LayerNorm(n_state)
|
||||
|
||||
@@ -183,14 +201,19 @@ class TextDecoder(nn.Module):
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||
logits = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
return logits
|
||||
|
||||
@@ -213,6 +236,21 @@ class Whisper(nn.Module):
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
||||
all_heads = torch.zeros(
|
||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||
)
|
||||
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
||||
|
||||
def set_alignment_heads(self, dump: bytes):
|
||||
array = np.frombuffer(
|
||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||
).copy()
|
||||
mask = torch.from_numpy(array).reshape(
|
||||
self.dims.n_text_layer, self.dims.n_text_head
|
||||
)
|
||||
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder(mel)
|
||||
@@ -220,7 +258,9 @@ class Whisper(nn.Module):
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
def forward(
|
||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
@@ -249,8 +289,9 @@ class Whisper(nn.Module):
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
|
||||
cache[module] = output # save as-is, for the first token or cross attention
|
||||
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
||||
# save as-is, for the first token or cross attention
|
||||
cache[module] = output
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .basic import BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer
|
||||
from .basic import BasicTextNormalizer as BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer as EnglishTextNormalizer
|
||||
|
||||
@@ -48,13 +48,16 @@ def remove_symbols(s: str):
|
||||
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
||||
"""
|
||||
return "".join(
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c
|
||||
for c in unicodedata.normalize("NFKC", s)
|
||||
)
|
||||
|
||||
|
||||
class BasicTextNormalizer:
|
||||
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
|
||||
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
self.clean = (
|
||||
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
)
|
||||
self.split_letters = split_letters
|
||||
|
||||
def __call__(self, s: str):
|
||||
@@ -66,6 +69,8 @@ class BasicTextNormalizer:
|
||||
if self.split_letters:
|
||||
s = " ".join(regex.findall(r"\X", s, regex.U))
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
s = re.sub(
|
||||
r"\s+", " ", s
|
||||
) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
||||
|
||||
@@ -84,7 +84,8 @@ class EnglishNumberNormalizer:
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
|
||||
name.replace("y", "ieth"): (value, "th")
|
||||
for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
@@ -108,7 +109,10 @@ class EnglishNumberNormalizer:
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
|
||||
self.multipliers_suffixed = {
|
||||
**self.multipliers_plural,
|
||||
**self.multipliers_ordinal,
|
||||
}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
@@ -128,7 +132,8 @@ class EnglishNumberNormalizer:
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
|
||||
list(self.preceding_prefixers.values())
|
||||
+ list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
@@ -218,7 +223,9 @@ class EnglishNumberNormalizer:
|
||||
if value is None:
|
||||
value = ones
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10: # replace the last zero with the digit
|
||||
if (
|
||||
prev in self.tens and ones < 10
|
||||
): # replace the last zero with the digit
|
||||
assert value[-1] == "0"
|
||||
value = value[:-1] + str(ones)
|
||||
else:
|
||||
@@ -522,14 +529,14 @@ class EnglishTextNormalizer:
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = re.sub(self.ignore_patterns, "", s)
|
||||
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
|
||||
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
|
||||
|
||||
for pattern, replacement in self.replacers.items():
|
||||
s = re.sub(pattern, replacement, s)
|
||||
|
||||
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
|
||||
|
||||
s = self.standardize_numbers(s)
|
||||
s = self.standardize_spellings(s)
|
||||
@@ -538,6 +545,6 @@ class EnglishTextNormalizer:
|
||||
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
|
||||
|
||||
return s
|
||||
|
||||
@@ -0,0 +1,334 @@
|
||||
import itertools
|
||||
import subprocess
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def median_filter(x: torch.Tensor, filter_width: int):
|
||||
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
|
||||
pad_width = filter_width // 2
|
||||
if x.shape[-1] <= pad_width:
|
||||
# F.pad requires the padding width to be smaller than the input dimension
|
||||
return x
|
||||
|
||||
if (ndim := x.ndim) <= 2:
|
||||
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
|
||||
x = x[None, None, :]
|
||||
|
||||
assert (
|
||||
filter_width > 0 and filter_width % 2 == 1
|
||||
), "`filter_width` should be an odd number"
|
||||
|
||||
result = None
|
||||
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
||||
if x.is_cuda:
|
||||
try:
|
||||
from .triton_ops import median_filter_cuda
|
||||
|
||||
result = median_filter_cuda(x, filter_width)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower median kernel implementation..."
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
||||
|
||||
if ndim <= 2:
|
||||
result = result[0, 0]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@numba.jit
|
||||
def backtrace(trace: np.ndarray):
|
||||
i = trace.shape[0] - 1
|
||||
j = trace.shape[1] - 1
|
||||
trace[0, :] = 2
|
||||
trace[:, 0] = 1
|
||||
|
||||
result = []
|
||||
while i > 0 or j > 0:
|
||||
result.append((i - 1, j - 1))
|
||||
|
||||
if trace[i, j] == 0:
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif trace[i, j] == 1:
|
||||
i -= 1
|
||||
elif trace[i, j] == 2:
|
||||
j -= 1
|
||||
else:
|
||||
raise ValueError("Unexpected trace[i, j]")
|
||||
|
||||
result = np.array(result)
|
||||
return result[::-1, :].T
|
||||
|
||||
|
||||
@numba.jit(nopython=True, parallel=True)
|
||||
def dtw_cpu(x: np.ndarray):
|
||||
N, M = x.shape
|
||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||
|
||||
cost[0, 0] = 0
|
||||
for j in range(1, M + 1):
|
||||
for i in range(1, N + 1):
|
||||
c0 = cost[i - 1, j - 1]
|
||||
c1 = cost[i - 1, j]
|
||||
c2 = cost[i, j - 1]
|
||||
|
||||
if c0 < c1 and c0 < c2:
|
||||
c, t = c0, 0
|
||||
elif c1 < c0 and c1 < c2:
|
||||
c, t = c1, 1
|
||||
else:
|
||||
c, t = c2, 2
|
||||
|
||||
cost[i, j] = x[i - 1, j - 1] + c
|
||||
trace[i, j] = t
|
||||
|
||||
return backtrace(trace)
|
||||
|
||||
|
||||
def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
from .triton_ops import dtw_kernel
|
||||
|
||||
M, N = x.shape
|
||||
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
|
||||
|
||||
x_skew = (
|
||||
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
|
||||
)
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.cuda()
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
cost,
|
||||
trace,
|
||||
x_skew,
|
||||
x_skew.stride(0),
|
||||
cost.stride(0),
|
||||
trace.stride(0),
|
||||
N,
|
||||
M,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
|
||||
:, : N + 1
|
||||
]
|
||||
return backtrace(trace.cpu().numpy())
|
||||
|
||||
|
||||
def dtw(x: torch.Tensor) -> np.ndarray:
|
||||
if x.is_cuda:
|
||||
try:
|
||||
return dtw_cuda(x)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower DTW implementation..."
|
||||
)
|
||||
|
||||
return dtw_cpu(x.double().cpu().numpy())
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordTiming:
|
||||
word: str
|
||||
tokens: List[int]
|
||||
start: float
|
||||
end: float
|
||||
probability: float
|
||||
|
||||
|
||||
def find_alignment(
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
text_tokens: List[int],
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
*,
|
||||
medfilt_width: int = 7,
|
||||
qk_scale: float = 1.0,
|
||||
) -> List[WordTiming]:
|
||||
if len(text_tokens) == 0:
|
||||
return []
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*text_tokens,
|
||||
tokenizer.eot,
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
# install hooks on the cross attention layers to retrieve the attention weights
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||
text_token_probs = text_token_probs.tolist()
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = (weights * qk_scale).softmax(dim=-1)
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
weights = (weights - mean) / std
|
||||
weights = median_filter(weights, medfilt_width)
|
||||
|
||||
matrix = weights.mean(axis=0)
|
||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||
text_indices, time_indices = dtw(-matrix)
|
||||
|
||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||
|
||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||
start_times = jump_times[word_boundaries[:-1]]
|
||||
end_times = jump_times[word_boundaries[1:]]
|
||||
word_probabilities = [
|
||||
np.mean(text_token_probs[i:j])
|
||||
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
||||
]
|
||||
|
||||
# hack: ensure the first and second word is not longer than twice the median word duration.
|
||||
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||
word_durations = end_times - start_times
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
if len(word_durations) > 0:
|
||||
median_duration = np.median(word_durations)
|
||||
max_duration = median_duration * 2
|
||||
if len(word_durations) >= 2 and word_durations[1] > max_duration:
|
||||
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
|
||||
end_times[0] = start_times[1] = boundary
|
||||
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
|
||||
start_times[0] = max(0, end_times[0] - max_duration)
|
||||
|
||||
return [
|
||||
WordTiming(word, tokens, start, end, probability)
|
||||
for word, tokens, start, end, probability in zip(
|
||||
words, word_tokens, start_times, end_times, word_probabilities
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
|
||||
# merge prepended punctuations
|
||||
i = len(alignment) - 2
|
||||
j = len(alignment) - 1
|
||||
while i >= 0:
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if previous.word.startswith(" ") and previous.word.strip() in prepended:
|
||||
# prepend it to the following word
|
||||
following.word = previous.word + following.word
|
||||
following.tokens = previous.tokens + following.tokens
|
||||
previous.word = ""
|
||||
previous.tokens = []
|
||||
else:
|
||||
j = i
|
||||
i -= 1
|
||||
|
||||
# merge appended punctuations
|
||||
i = 0
|
||||
j = 1
|
||||
while j < len(alignment):
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if not previous.word.endswith(" ") and following.word in appended:
|
||||
# append it to the previous word
|
||||
previous.word = previous.word + following.word
|
||||
previous.tokens = previous.tokens + following.tokens
|
||||
following.word = ""
|
||||
following.tokens = []
|
||||
else:
|
||||
i = j
|
||||
j += 1
|
||||
|
||||
|
||||
def add_word_timestamps(
|
||||
*,
|
||||
segments: List[dict],
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
**kwargs,
|
||||
):
|
||||
if len(segments) == 0:
|
||||
return
|
||||
|
||||
text_tokens_per_segment = [
|
||||
[token for token in segment["tokens"] if token < tokenizer.eot]
|
||||
for segment in segments
|
||||
]
|
||||
|
||||
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
||||
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
||||
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
||||
|
||||
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
||||
word_index = 0
|
||||
|
||||
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
||||
saved_tokens = 0
|
||||
words = []
|
||||
|
||||
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
||||
timing = alignment[word_index]
|
||||
|
||||
if timing.word:
|
||||
words.append(
|
||||
dict(
|
||||
word=timing.word,
|
||||
start=round(time_offset + timing.start, 2),
|
||||
end=round(time_offset + timing.end, 2),
|
||||
probability=timing.probability,
|
||||
)
|
||||
)
|
||||
|
||||
saved_tokens += len(timing.tokens)
|
||||
word_index += 1
|
||||
|
||||
if len(words) > 0:
|
||||
# adjust the segment-level timestamps based on the word-level timestamps
|
||||
segment["start"] = words[0]["start"]
|
||||
segment["end"] = words[-1]["end"]
|
||||
|
||||
segment["words"] = words
|
||||
+159
-103
@@ -1,11 +1,12 @@
|
||||
import base64
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import string
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import GPT2TokenizerFast
|
||||
import tiktoken
|
||||
from tiktoken_ext.openai_public import gpt2
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
@@ -126,114 +127,113 @@ TO_LANGUAGE_CODE = {
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass
|
||||
class Tokenizer:
|
||||
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
||||
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||
|
||||
tokenizer: "GPT2TokenizerFast"
|
||||
language: Optional[str]
|
||||
sot_sequence: Tuple[int]
|
||||
encoding: tiktoken.Encoding
|
||||
language: Optional[str] = None
|
||||
task: Optional[str] = None
|
||||
sot_sequence: Tuple[int] = ()
|
||||
special_tokens: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for special in self.encoding.special_tokens_set:
|
||||
special_token = self.encoding.encode_single_token(special)
|
||||
self.special_tokens[special] = special_token
|
||||
|
||||
sot: int = self.special_tokens["<|startoftranscript|>"]
|
||||
translate: int = self.special_tokens["<|translate|>"]
|
||||
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
sot_sequence = [sot]
|
||||
if self.language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||
if self.task is not None:
|
||||
task_token: int = transcribe if self.task == "transcribe" else translate
|
||||
sot_sequence.append(task_token)
|
||||
|
||||
self.sot_sequence = tuple(sot_sequence)
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self.tokenizer.encode(text, **kwargs)
|
||||
return self.encoding.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
|
||||
return self.tokenizer.decode(token_ids, **kwargs)
|
||||
def decode(self, token_ids: List[int], **kwargs) -> str:
|
||||
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, tokens) -> str:
|
||||
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
||||
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
"""
|
||||
outputs = [[]]
|
||||
for token in tokens:
|
||||
if token >= self.timestamp_begin:
|
||||
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
||||
outputs.append(timestamp)
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
return "".join(outputs)
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def eot(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
return self.encoding.eot_token
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def transcribe(self) -> int:
|
||||
return self.special_tokens["<|transcribe|>"]
|
||||
|
||||
@cached_property
|
||||
def translate(self) -> int:
|
||||
return self.special_tokens["<|translate|>"]
|
||||
|
||||
@cached_property
|
||||
def sot(self) -> int:
|
||||
return self._get_single_token_id("<|startoftranscript|>")
|
||||
return self.special_tokens["<|startoftranscript|>"]
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def sot_lm(self) -> int:
|
||||
return self._get_single_token_id("<|startoflm|>")
|
||||
return self.special_tokens["<|startoflm|>"]
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def sot_prev(self) -> int:
|
||||
return self._get_single_token_id("<|startofprev|>")
|
||||
return self.special_tokens["<|startofprev|>"]
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def no_speech(self) -> int:
|
||||
return self._get_single_token_id("<|nospeech|>")
|
||||
return self.special_tokens["<|nospeech|>"]
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def no_timestamps(self) -> int:
|
||||
return self._get_single_token_id("<|notimestamps|>")
|
||||
return self.special_tokens["<|notimestamps|>"]
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.tokenizer.all_special_ids[-1] + 1
|
||||
return self.special_tokens["<|0.00|>"]
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError(f"This tokenizer does not have language token configured")
|
||||
raise ValueError("This tokenizer does not have language token configured")
|
||||
|
||||
additional_tokens = dict(
|
||||
zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
)
|
||||
)
|
||||
candidate = f"<|{self.language}|>"
|
||||
if candidate in additional_tokens:
|
||||
return additional_tokens[candidate]
|
||||
if token := self.special_tokens.get(f"<|{self.language}|>", None):
|
||||
return token
|
||||
|
||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
):
|
||||
for token, token_id in self.special_tokens.items():
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
@@ -245,8 +245,10 @@ class Tokenizer:
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
|
||||
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||
symbols += (
|
||||
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
)
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
@@ -256,27 +258,82 @@ class Tokenizer:
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
|
||||
for tokens in [
|
||||
self.encoding.encode(symbol),
|
||||
self.encoding.encode(" " + symbol),
|
||||
]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def _get_single_token_id(self, text) -> int:
|
||||
tokens = self.tokenizer.encode(text)
|
||||
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
||||
return tokens[0]
|
||||
def split_to_word_tokens(self, tokens: List[int]):
|
||||
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
||||
# These languages don't typically use spaces, so it is difficult to split words
|
||||
# without morpheme analysis. Here, we instead split words at any
|
||||
# position where the tokens are decoded as valid unicode points
|
||||
return self.split_tokens_on_unicode(tokens)
|
||||
|
||||
return self.split_tokens_on_spaces(tokens)
|
||||
|
||||
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||
decoded_full = self.decode_with_timestamps(tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
words = []
|
||||
word_tokens = []
|
||||
current_tokens = []
|
||||
unicode_offset = 0
|
||||
|
||||
for token in tokens:
|
||||
current_tokens.append(token)
|
||||
decoded = self.decode_with_timestamps(current_tokens)
|
||||
|
||||
if (
|
||||
replacement_char not in decoded
|
||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
||||
== replacement_char
|
||||
):
|
||||
words.append(decoded)
|
||||
word_tokens.append(current_tokens)
|
||||
current_tokens = []
|
||||
unicode_offset += len(decoded)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
def split_tokens_on_spaces(self, tokens: List[int]):
|
||||
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
||||
words = []
|
||||
word_tokens = []
|
||||
|
||||
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
||||
special = subword_tokens[0] >= self.eot
|
||||
with_space = subword.startswith(" ")
|
||||
punctuation = subword.strip() in string.punctuation
|
||||
if special or with_space or punctuation or len(words) == 0:
|
||||
words.append(subword)
|
||||
word_tokens.append(subword_tokens)
|
||||
else:
|
||||
words[-1] = words[-1] + subword
|
||||
word_tokens[-1].extend(subword_tokens)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def build_tokenizer(name: str = "gpt2"):
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
path = os.path.join(os.path.dirname(__file__), "assets", name)
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(path)
|
||||
def get_encoding(name: str = "gpt2"):
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||
}
|
||||
n_vocab = len(ranks)
|
||||
special_tokens = {}
|
||||
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
@@ -285,18 +342,28 @@ def build_tokenizer(name: str = "gpt2"):
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
|
||||
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
|
||||
return tokenizer
|
||||
for token in specials:
|
||||
special_tokens[token] = n_vocab
|
||||
n_vocab += 1
|
||||
|
||||
return tiktoken.Encoding(
|
||||
name=os.path.basename(vocab_path),
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=gpt2()["pat_str"],
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
@@ -307,25 +374,14 @@ def get_tokenizer(
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
tokenizer_name = "multilingual"
|
||||
task = task or "transcribe"
|
||||
encoding_name = "multilingual"
|
||||
language = language or "en"
|
||||
task = task or "transcribe"
|
||||
else:
|
||||
tokenizer_name = "gpt2"
|
||||
task = None
|
||||
encoding_name = "gpt2"
|
||||
language = None
|
||||
task = None
|
||||
|
||||
tokenizer = build_tokenizer(name=tokenizer_name)
|
||||
all_special_ids: List[int] = tokenizer.all_special_ids
|
||||
sot: int = all_special_ids[1]
|
||||
translate: int = all_special_ids[-6]
|
||||
transcribe: int = all_special_ids[-5]
|
||||
encoding = get_encoding(name=encoding_name)
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
sot_sequence = [sot]
|
||||
if language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(language))
|
||||
if task is not None:
|
||||
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
||||
|
||||
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
|
||||
return Tokenizer(encoding=encoding, language=language, task=task)
|
||||
|
||||
+216
-90
@@ -1,16 +1,33 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
||||
from .audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
N_SAMPLES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .timing import add_word_timestamps
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
|
||||
from .utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
@@ -26,6 +43,10 @@ def transcribe(
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@@ -62,6 +83,21 @@ def transcribe(
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
word_timestamps: bool
|
||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
||||
and include the timestamps for each word in each segment.
|
||||
|
||||
prepend_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
||||
|
||||
append_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
||||
|
||||
initial_prompt: Optional[str]
|
||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
@@ -81,26 +117,37 @@ def transcribe(
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
mel = log_mel_spectrogram(audio)
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-1] - N_FRAMES
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if not model.is_multilingual:
|
||||
decode_options["language"] = "en"
|
||||
else:
|
||||
if verbose:
|
||||
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
|
||||
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(segment)
|
||||
print(
|
||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
||||
print(
|
||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||
)
|
||||
|
||||
language = decode_options["language"]
|
||||
task = decode_options.get("task", "transcribe")
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
|
||||
if word_timestamps and task == "translate":
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
|
||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
temperatures = (
|
||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
)
|
||||
decode_result = None
|
||||
|
||||
for t in temperatures:
|
||||
@@ -117,9 +164,15 @@ def transcribe(
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
needs_fallback = False
|
||||
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
|
||||
if (
|
||||
compression_ratio_threshold is not None
|
||||
and decode_result.compression_ratio > compression_ratio_threshold
|
||||
):
|
||||
needs_fallback = True # too repetitive
|
||||
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = True # average log probability is too low
|
||||
|
||||
if not needs_fallback:
|
||||
@@ -138,117 +191,186 @@ def transcribe(
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
initial_prompt = decode_options.pop("initial_prompt", None) or []
|
||||
if initial_prompt:
|
||||
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt)
|
||||
if initial_prompt is not None:
|
||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
def add_segment(
|
||||
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
|
||||
def new_segment(
|
||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
|
||||
if len(text.strip()) == 0: # skip empty text output
|
||||
return
|
||||
tokens = tokens.tolist()
|
||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||
return {
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": tokenizer.decode(text_tokens),
|
||||
"tokens": tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
|
||||
all_segments.append(
|
||||
{
|
||||
"id": len(all_segments),
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": text,
|
||||
"tokens": text_tokens.tolist(),
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
)
|
||||
if verbose:
|
||||
print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
|
||||
|
||||
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||
num_frames = mel.shape[-1]
|
||||
previous_seek_value = seek
|
||||
|
||||
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
||||
while seek < num_frames:
|
||||
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
|
||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
with tqdm.tqdm(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
while seek < content_frames:
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
mel_segment = mel[:, seek : seek + N_FRAMES]
|
||||
segment_size = min(N_FRAMES, content_frames - seek)
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(segment)
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and result.avg_logprob > logprob_threshold
|
||||
):
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
if should_skip:
|
||||
seek += segment.shape[-1] # fast-forward to the next segment boundary
|
||||
seek += segment_size # fast-forward to the next segment boundary
|
||||
continue
|
||||
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
|
||||
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
consecutive.add_(1)
|
||||
if len(consecutive) > 0:
|
||||
# if the output contains two consecutive timestamp tokens
|
||||
slices = consecutive.tolist()
|
||||
if single_timestamp_ending:
|
||||
slices.append(len(tokens))
|
||||
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
for current_slice in slices:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_position = (
|
||||
start_timestamp_pos = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_position = (
|
||||
end_timestamp_pos = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
add_segment(
|
||||
start=timestamp_offset + start_timestamp_position * time_precision,
|
||||
end=timestamp_offset + end_timestamp_position * time_precision,
|
||||
text_tokens=sliced_tokens[1:-1],
|
||||
result=result,
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset + start_timestamp_pos * time_precision,
|
||||
end=time_offset + end_timestamp_pos * time_precision,
|
||||
tokens=sliced_tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
last_slice = current_slice
|
||||
last_timestamp_position = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_position * input_stride
|
||||
all_tokens.extend(tokens[: last_slice + 1].tolist())
|
||||
|
||||
if single_timestamp_ending:
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
seek += segment_size
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
last_timestamp_pos = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_pos * input_stride
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
|
||||
if (
|
||||
len(timestamps) > 0
|
||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
||||
):
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
duration = last_timestamp_position * time_precision
|
||||
last_timestamp_pos = (
|
||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
duration = last_timestamp_pos * time_precision
|
||||
|
||||
add_segment(
|
||||
start=timestamp_offset,
|
||||
end=timestamp_offset + duration,
|
||||
text_tokens=tokens,
|
||||
result=result,
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset,
|
||||
end=time_offset + duration,
|
||||
tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
seek += segment.shape[-1]
|
||||
all_tokens.extend(tokens.tolist())
|
||||
seek += segment_size
|
||||
|
||||
if not condition_on_previous_text or result.temperature > 0.5:
|
||||
# do not feed the prompt tokens if a high temperature was used
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(num_frames, seek) - previous_seek_value)
|
||||
previous_seek_value = seek
|
||||
if word_timestamps:
|
||||
add_word_timestamps(
|
||||
segments=current_segments,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
mel=mel_segment,
|
||||
num_frames=segment_size,
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
)
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||
)
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
|
||||
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||
print(make_safe(line))
|
||||
|
||||
# if a segment is instantaneous or does not contain text, clear it
|
||||
for i, segment in enumerate(current_segments):
|
||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
||||
segment["text"] = ""
|
||||
segment["tokens"] = []
|
||||
segment["words"] = []
|
||||
|
||||
all_segments.extend(
|
||||
[
|
||||
{"id": i, **segment}
|
||||
for i, segment in enumerate(
|
||||
current_segments, start=len(all_segments)
|
||||
)
|
||||
]
|
||||
)
|
||||
all_tokens.extend(
|
||||
[token for segment in current_segments for token in segment["tokens"]]
|
||||
)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
segments=all_segments,
|
||||
language=language,
|
||||
)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
@@ -276,7 +398,11 @@ def cli():
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
@@ -288,29 +414,29 @@ def cli():
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
|
||||
if temperature_increment_on_fallback is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
threads = args.pop("threads")
|
||||
if threads > 0:
|
||||
if (threads := args.pop("threads")) > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
writer(result, audio_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
except ImportError:
|
||||
raise RuntimeError("triton import failed; try `pip install --pre triton`")
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dtw_kernel(
|
||||
cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < M
|
||||
|
||||
for k in range(1, N + M + 1): # k = i + j
|
||||
tl.debug_barrier()
|
||||
|
||||
p0 = cost + (k - 1) * cost_stride
|
||||
p1 = cost + k * cost_stride
|
||||
p2 = cost + k * cost_stride + 1
|
||||
|
||||
c0 = tl.load(p0 + offsets, mask=mask)
|
||||
c1 = tl.load(p1 + offsets, mask=mask)
|
||||
c2 = tl.load(p2 + offsets, mask=mask)
|
||||
|
||||
x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
|
||||
cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
|
||||
|
||||
cost_ptr = cost + (k + 1) * cost_stride + 1
|
||||
tl.store(cost_ptr + offsets, cost_row, mask=mask)
|
||||
|
||||
trace_ptr = trace + (k + 1) * trace_stride + 1
|
||||
tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
|
||||
tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
|
||||
tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def median_kernel(filter_width: int):
|
||||
@triton.jit
|
||||
def kernel(
|
||||
y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
|
||||
): # x.shape[-1] == filter_width
|
||||
row_idx = tl.program_id(0)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < y_stride
|
||||
|
||||
x_ptr = x + row_idx * x_stride # noqa: F841
|
||||
y_ptr = y + row_idx * y_stride
|
||||
|
||||
LOAD_ALL_ROWS_HERE # noqa: F821
|
||||
|
||||
BUBBLESORT_HERE # noqa: F821
|
||||
|
||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||
|
||||
kernel = triton.JITFunction(kernel.fn)
|
||||
kernel.src = kernel.src.replace(
|
||||
" LOAD_ALL_ROWS_HERE",
|
||||
"\n".join(
|
||||
[
|
||||
f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
|
||||
for i in range(filter_width)
|
||||
]
|
||||
),
|
||||
)
|
||||
kernel.src = kernel.src.replace(
|
||||
" BUBBLESORT_HERE",
|
||||
"\n\n".join(
|
||||
[
|
||||
"\n\n".join(
|
||||
[
|
||||
"\n".join(
|
||||
[
|
||||
f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
|
||||
f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
|
||||
f" row{j} = smaller",
|
||||
f" row{j + 1} = larger",
|
||||
]
|
||||
)
|
||||
for j in range(filter_width - i - 1)
|
||||
]
|
||||
)
|
||||
for i in range(filter_width // 2 + 1)
|
||||
]
|
||||
),
|
||||
)
|
||||
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||
|
||||
return kernel
|
||||
|
||||
|
||||
def median_filter_cuda(x: torch.Tensor, filter_width: int):
|
||||
"""Apply a median filter of given width along the last dimension of x"""
|
||||
slices = x.contiguous().unfold(-1, filter_width, 1)
|
||||
grid = np.prod(slices.shape[:-2])
|
||||
|
||||
kernel = median_kernel(filter_width)
|
||||
y = torch.empty_like(slices[..., 0])
|
||||
|
||||
BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
|
||||
kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
|
||||
|
||||
return y
|
||||
+69
-27
@@ -7,11 +7,14 @@ from typing import Callable, TextIO
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
if system_encoding != "utf-8":
|
||||
|
||||
def make_safe(string):
|
||||
# replaces any character not representable using the system default encoding with an '?',
|
||||
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
|
||||
return string.encode(system_encoding, errors="replace").decode(system_encoding)
|
||||
|
||||
else:
|
||||
|
||||
def make_safe(string):
|
||||
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
|
||||
return string
|
||||
@@ -43,7 +46,9 @@ def compression_ratio(text) -> float:
|
||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
||||
|
||||
|
||||
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
|
||||
def format_timestamp(
|
||||
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
||||
):
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
@@ -57,7 +62,9 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
return (
|
||||
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
)
|
||||
|
||||
|
||||
class ResultWriter:
|
||||
@@ -68,7 +75,10 @@ class ResultWriter:
|
||||
|
||||
def __call__(self, result: dict, audio_path: str):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
self.output_dir, audio_basename + "." + self.extension
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f)
|
||||
@@ -82,37 +92,69 @@ class WriteTXT(ResultWriter):
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
for segment in result["segments"]:
|
||||
print(segment['text'].strip(), file=file, flush=True)
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
|
||||
|
||||
class WriteVTT(ResultWriter):
|
||||
class SubtitlesWriter(ResultWriter):
|
||||
always_include_hours: bool
|
||||
decimal_marker: str
|
||||
|
||||
def iterate_result(self, result: dict):
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
|
||||
if word_timings := segment.get("words", None):
|
||||
all_words = [timing["word"] for timing in word_timings]
|
||||
all_words[0] = all_words[0].strip() # remove the leading space, if any
|
||||
last = segment_start
|
||||
for i, this_word in enumerate(word_timings):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, segment_text
|
||||
|
||||
yield start, end, "".join(
|
||||
[
|
||||
f"<u>{word}</u>" if j == i else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
|
||||
if last != segment_end:
|
||||
yield last, segment_end, segment_text
|
||||
else:
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
return format_timestamp(
|
||||
seconds=seconds,
|
||||
always_include_hours=self.always_include_hours,
|
||||
decimal_marker=self.decimal_marker,
|
||||
)
|
||||
|
||||
|
||||
class WriteVTT(SubtitlesWriter):
|
||||
extension: str = "vtt"
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
print("WEBVTT\n", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(
|
||||
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
for start, end, text in self.iterate_result(result):
|
||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteSRT(ResultWriter):
|
||||
class WriteSRT(SubtitlesWriter):
|
||||
extension: str = "srt"
|
||||
always_include_hours: bool = True
|
||||
decimal_marker: str = ","
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
for i, segment in enumerate(result["segments"], start=1):
|
||||
# write srt lines
|
||||
print(
|
||||
f"{i}\n"
|
||||
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
|
||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteTSV(ResultWriter):
|
||||
@@ -124,14 +166,15 @@ class WriteTSV(ResultWriter):
|
||||
an environment setting a language encoding that causes the decimal in a floating point number
|
||||
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
|
||||
"""
|
||||
|
||||
extension: str = "tsv"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment['start']), file=file, end="\t")
|
||||
print(round(1000 * segment['end']), file=file, end="\t")
|
||||
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
|
||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||
print(round(1000 * segment["end"]), file=file, end="\t")
|
||||
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
|
||||
class WriteJSON(ResultWriter):
|
||||
@@ -160,4 +203,3 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
|
||||
return write_all
|
||||
|
||||
return writers[output_format](output_dir)
|
||||
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
__version__ = "20230124"
|
||||
__version__ = "20230314"
|
||||
|
||||
Reference in New Issue
Block a user