Compare commits
56 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ad3250a846 | |||
| c4b50c0824 | |||
| 38f2f4d99d | |||
| aac47c9834 | |||
| 26807ec6d3 | |||
| 919a713499 | |||
| 38e990d853 | |||
| 924e1f8e06 | |||
| 4b0d5e58d0 | |||
| 8180fde939 | |||
| c6e4e5efb3 | |||
| b80bcf610d | |||
| 500d0fe966 | |||
| eab8d920ed | |||
| 3e1780fd37 | |||
| 7858aa9c08 | |||
| 5c1a8c10e7 | |||
| 4e635c6644 | |||
| a6b36ede1f | |||
| 55f690af79 | |||
| 7f1ef223ab | |||
| f5bfe004ec | |||
| da600abd2b | |||
| 9f7aba6099 | |||
| 12e1089462 | |||
| ea1c266709 | |||
| 8135a7c31c | |||
| 9d646db9d8 | |||
| 37a4f1be6d | |||
| b9f9b433ae | |||
| f0083e7eb2 | |||
| a84191faae | |||
| b1d213c0c7 | |||
| 493dfffa37 | |||
| 0f39c89d92 | |||
| 6df3ea1fb5 | |||
| 70861c7ce3 | |||
| f82bc59f5e | |||
| 28769fcfe5 | |||
| 53807677fe | |||
| 9323b2526c | |||
| 68e44bd83c | |||
| 0b5dcfdef7 | |||
| b9265e5796 | |||
| fd8f80c8b8 | |||
| 4179ed2475 | |||
| ec1b34bb90 | |||
| eff383b27b | |||
| 02aa851a49 | |||
| 76148a56c5 | |||
| 9f70a352f9 | |||
| 7f3e408e09 | |||
| f680570016 | |||
| d18e9ea5dd | |||
| 82725cea9c | |||
| 35713c66e0 |
@@ -0,0 +1,37 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-ecosystem/action-regex-match@v2
|
||||
id: regex-match
|
||||
with:
|
||||
text: ${{ github.event.head_commit.message }}
|
||||
regex: '^Release ([^ ]+)'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.8'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install setuptools wheel twine
|
||||
- name: Release
|
||||
if: ${{ steps.regex-match.outputs.match != '' }}
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
tag_name: v${{ steps.regex-match.outputs.group1 }}
|
||||
- name: Build and publish
|
||||
if: ${{ steps.regex-match.outputs.match != '' }}
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
python setup.py sdist
|
||||
twine upload dist/*
|
||||
@@ -0,0 +1,28 @@
|
||||
name: test
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
whisper-test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.8', '3.9', '3.10']
|
||||
pytorch-version: [1.10.2, 1.13.1]
|
||||
exclude:
|
||||
- python-version: '3.10'
|
||||
pytorch-version: 1.10.2
|
||||
steps:
|
||||
- uses: conda-incubator/setup-miniconda@v2
|
||||
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
|
||||
- uses: actions/checkout@v2
|
||||
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
||||
- run: pip install .["dev"]
|
||||
- run: black --check --diff -t py38 --include '(\.pyi?)$' .
|
||||
- run: isort --check --diff .
|
||||
- run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
|
||||
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
|
||||
@@ -0,0 +1,38 @@
|
||||
# CHANGELOG
|
||||
|
||||
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
|
||||
|
||||
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
|
||||
* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
|
||||
* fix typo in CHANGELOG.md
|
||||
|
||||
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
|
||||
|
||||
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
|
||||
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
|
||||
* Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051))
|
||||
* update setup.py to specify python >= 3.8 requirement
|
||||
|
||||
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
|
||||
|
||||
* remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021))
|
||||
* apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038))
|
||||
* word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869))
|
||||
* Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033))
|
||||
* Update README.md ([#894](https://github.com/openai/whisper/pull/894))
|
||||
* Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914))
|
||||
* drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889))
|
||||
|
||||
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
|
||||
|
||||
* handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887))
|
||||
* Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228))
|
||||
* Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333))
|
||||
* Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864))
|
||||
* use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867))
|
||||
* Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659))
|
||||
* print '?' if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859))
|
||||
|
||||
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
|
||||
|
||||
The first versioned release available on [PyPI](https://pypi.org/project/openai-whisper/)
|
||||
@@ -1,3 +1,6 @@
|
||||
include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include whisper/assets/*
|
||||
include whisper/assets/gpt2/*
|
||||
include whisper/assets/multilingual/*
|
||||
|
||||
@@ -1,26 +1,34 @@
|
||||
# Whisper
|
||||
|
||||
[[Blog]](https://openai.com/blog/whisper)
|
||||
[[Paper]](https://cdn.openai.com/papers/whisper.pdf)
|
||||
[[Model card]](model-card.md)
|
||||
[[Paper]](https://arxiv.org/abs/2212.04356)
|
||||
[[Model card]](https://github.com/openai/whisper/blob/main/model-card.md)
|
||||
[[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb)
|
||||
|
||||
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
|
||||
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
|
||||
|
||||
|
||||
## Approach
|
||||
|
||||

|
||||

|
||||
|
||||
A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. All of these tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing for a single model to replace many different stages of a traditional speech processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
|
||||
A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. The following command will pull and install the latest commit from this repository, along with its Python dependencies
|
||||
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
|
||||
|
||||
pip install -U openai-whisper
|
||||
|
||||
Alternatively, the following command will pull and install the latest commit from this repository, along with its Python dependencies:
|
||||
|
||||
pip install git+https://github.com/openai/whisper.git
|
||||
|
||||
To update the package to the latest version of this repository, please run:
|
||||
|
||||
pip install --upgrade --no-deps --force-reinstall git+https://github.com/openai/whisper.git
|
||||
|
||||
It also requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system, which is available from most package managers:
|
||||
|
||||
```bash
|
||||
@@ -60,11 +68,11 @@ There are five model sizes, four with English-only versions, offering speed and
|
||||
| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
|
||||
| large | 1550 M | N/A | `large` | ~10 GB | 1x |
|
||||
|
||||
For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
||||
The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
||||
|
||||
Whisper's performance varies widely depending on the language. The figure below shows a WER breakdown by languages of Fleurs dataset, using the `large` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://cdn.openai.com/papers/whisper.pdf).
|
||||
Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of the Fleurs dataset using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller, the better.
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
|
||||
@@ -86,7 +94,7 @@ Run the following to view all available options:
|
||||
|
||||
whisper --help
|
||||
|
||||
See [tokenizer.py](whisper/tokenizer.py) for the list of all available languages.
|
||||
See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
|
||||
|
||||
|
||||
## Python usage
|
||||
@@ -136,4 +144,4 @@ Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussion
|
||||
|
||||
## License
|
||||
|
||||
The code and the model weights of Whisper are released under the MIT License. See [LICENSE](LICENSE) for further details.
|
||||
Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
|
||||
|
||||
+1501
-2532
File diff suppressed because it is too large
Load Diff
|
Before Width: | Height: | Size: 134 KiB After Width: | Height: | Size: 100 KiB |
+8
-6
@@ -2,7 +2,7 @@
|
||||
|
||||
This is the official codebase for running the automatic speech recognition (ASR) models (Whisper models) trained and released by OpenAI.
|
||||
|
||||
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://cdn.openai.com/papers/whisper.pdf).
|
||||
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2212.04356).
|
||||
|
||||
|
||||
## Model Details
|
||||
@@ -17,10 +17,12 @@ The Whisper models are trained for speech recognition and translation tasks, cap
|
||||
| medium | 769 M | ✓ | ✓ |
|
||||
| large | 1550 M | | ✓ |
|
||||
|
||||
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).
|
||||
|
||||
|
||||
### Release date
|
||||
|
||||
September 2022
|
||||
September 2022 (original series) and December 2022 (`large-v2`)
|
||||
|
||||
### Model type
|
||||
|
||||
@@ -28,7 +30,7 @@ Sequence-to-sequence ASR (automatic speech recognition) and speech translation m
|
||||
|
||||
### Paper & samples
|
||||
|
||||
[Paper](https://cdn.openai.com/papers/whisper.pdf) / [Blog](https://openai.com/blog/whisper)
|
||||
[Paper](https://arxiv.org/abs/2212.04356) / [Blog](https://openai.com/blog/whisper)
|
||||
|
||||
|
||||
## Model Use
|
||||
@@ -46,7 +48,7 @@ In particular, we caution against using Whisper models to transcribe recordings
|
||||
|
||||
The models are trained on 680,000 hours of audio and the corresponding transcripts collected from the internet. 65% of this data (or 438,000 hours) represents English-language audio and matched English transcripts, roughly 18% (or 126,000 hours) represents non-English audio and English transcripts, while the final 17% (or 117,000 hours) represents non-English audio and the corresponding transcript. This non-English data represents 98 different languages.
|
||||
|
||||
As discussed in [the accompanying paper](https://cdn.openai.com/papers/whisper.pdf), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
|
||||
As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
|
||||
|
||||
|
||||
## Performance and Limitations
|
||||
@@ -55,9 +57,9 @@ Our studies show that, over many existing ASR systems, the models exhibit improv
|
||||
|
||||
However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself.
|
||||
|
||||
Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://cdn.openai.com/papers/whisper.pdf).
|
||||
Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356).
|
||||
|
||||
In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://cdn.openai.com/papers/whisper.pdf). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages.
|
||||
In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages.
|
||||
|
||||
|
||||
## Broader Implications
|
||||
|
||||
+2989
-116
File diff suppressed because one or more lines are too long
@@ -0,0 +1,8 @@
|
||||
[tool.black]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
include_trailing_comma = true
|
||||
line_length = 88
|
||||
multi_line_output = 3
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
numba
|
||||
numpy
|
||||
torch
|
||||
tqdm
|
||||
|
||||
@@ -1,24 +1,43 @@
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import pkg_resources
|
||||
from setuptools import setup, find_packages
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def read_version(fname="whisper/version.py"):
|
||||
exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
|
||||
return locals()["__version__"]
|
||||
|
||||
|
||||
requirements = []
|
||||
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
|
||||
requirements.append("triton==2.0.0")
|
||||
|
||||
setup(
|
||||
name="whisper",
|
||||
name="openai-whisper",
|
||||
py_modules=["whisper"],
|
||||
version="1.0",
|
||||
description="",
|
||||
version=read_version(),
|
||||
description="Robust Speech Recognition via Large-Scale Weak Supervision",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
readme="README.md",
|
||||
python_requires=">=3.8",
|
||||
author="OpenAI",
|
||||
url="https://github.com/openai/whisper",
|
||||
license="MIT",
|
||||
packages=find_packages(exclude=["tests*"]),
|
||||
install_requires=[
|
||||
install_requires=requirements
|
||||
+ [
|
||||
str(r)
|
||||
for r in pkg_resources.parse_requirements(
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||
)
|
||||
],
|
||||
entry_points = {
|
||||
'console_scripts': ['whisper=whisper.transcribe:cli'],
|
||||
entry_points={
|
||||
"console_scripts": ["whisper=whisper.transcribe:cli"],
|
||||
},
|
||||
include_package_data=True,
|
||||
extras_require={'dev': ['pytest']},
|
||||
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import random as rand
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "requires_cuda")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random():
|
||||
rand.seed(42)
|
||||
numpy.random.seed(42)
|
||||
+1
-1
@@ -2,7 +2,7 @@ import os.path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
|
||||
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
|
||||
|
||||
def test_audio():
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from whisper.normalizers import EnglishTextNormalizer
|
||||
from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
|
||||
from whisper.normalizers.english import (
|
||||
EnglishNumberNormalizer,
|
||||
EnglishSpellingNormalizer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])
|
||||
@@ -84,6 +87,7 @@ def test_text_normalizer():
|
||||
assert std("he's like") == "he is like"
|
||||
assert std("she's been like") == "she has been like"
|
||||
assert std("10km") == "10 km"
|
||||
assert std("10mm") == "10 mm"
|
||||
assert std("RC232") == "rc 232"
|
||||
|
||||
assert (
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.ndimage
|
||||
import torch
|
||||
|
||||
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
|
||||
|
||||
sizes = [
|
||||
(10, 20),
|
||||
(32, 16),
|
||||
(123, 1500),
|
||||
(234, 189),
|
||||
]
|
||||
shapes = [
|
||||
(10,),
|
||||
(1, 15),
|
||||
(4, 5, 345),
|
||||
(6, 12, 240, 512),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("N, M", sizes)
|
||||
def test_dtw(N: int, M: int):
|
||||
steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
|
||||
np.random.shuffle(steps)
|
||||
x = np.random.random((N, M)).astype(np.float32)
|
||||
|
||||
i, j, k = 0, 0, 0
|
||||
trace = []
|
||||
while True:
|
||||
x[i, j] -= 1
|
||||
trace.append((i, j))
|
||||
|
||||
if k == len(steps):
|
||||
break
|
||||
|
||||
if k + 1 < len(steps) and steps[k] != steps[k + 1]:
|
||||
i += 1
|
||||
j += 1
|
||||
k += 2
|
||||
continue
|
||||
|
||||
if steps[k] == 0:
|
||||
i += 1
|
||||
if steps[k] == 1:
|
||||
j += 1
|
||||
k += 1
|
||||
|
||||
trace = np.array(trace).T
|
||||
dtw_trace = dtw_cpu(x)
|
||||
|
||||
assert np.allclose(trace, dtw_trace)
|
||||
|
||||
|
||||
@pytest.mark.requires_cuda
|
||||
@pytest.mark.parametrize("N, M", sizes)
|
||||
def test_dtw_cuda_equivalence(N: int, M: int):
|
||||
x_numpy = np.random.randn(N, M).astype(np.float32)
|
||||
x_cuda = torch.from_numpy(x_numpy).cuda()
|
||||
|
||||
trace_cpu = dtw_cpu(x_numpy)
|
||||
trace_cuda = dtw_cuda(x_cuda)
|
||||
|
||||
assert np.allclose(trace_cpu, trace_cuda)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", shapes)
|
||||
def test_median_filter(shape):
|
||||
x = torch.randn(*shape)
|
||||
|
||||
for filter_width in [3, 5, 7, 13]:
|
||||
filtered = median_filter(x, filter_width)
|
||||
|
||||
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
|
||||
pad_width = filter_width // 2
|
||||
padded_x = np.pad(
|
||||
x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
|
||||
)
|
||||
scipy_filtered = scipy.ndimage.median_filter(
|
||||
padded_x, [1] * (x.ndim - 1) + [filter_width]
|
||||
)
|
||||
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
|
||||
|
||||
assert np.allclose(filtered, scipy_filtered)
|
||||
|
||||
|
||||
@pytest.mark.requires_cuda
|
||||
@pytest.mark.parametrize("shape", shapes)
|
||||
def test_median_filter_equivalence(shape):
|
||||
x = torch.randn(*shape)
|
||||
|
||||
for filter_width in [3, 5, 7, 13]:
|
||||
filtered_cpu = median_filter(x, filter_width)
|
||||
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
|
||||
|
||||
assert np.allclose(filtered_cpu, filtered_gpu)
|
||||
@@ -1,20 +1,37 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import whisper
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', whisper.available_models())
|
||||
@pytest.mark.parametrize("model_name", whisper.available_models())
|
||||
def test_transcribe(model_name: str):
|
||||
model = whisper.load_model(model_name).cuda()
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = whisper.load_model(model_name).to(device)
|
||||
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
||||
|
||||
language = "en" if model_name.endswith(".en") else None
|
||||
result = model.transcribe(audio_path, language=language, temperature=0.0)
|
||||
result = model.transcribe(
|
||||
audio_path, language=language, temperature=0.0, word_timestamps=True
|
||||
)
|
||||
assert result["language"] == "en"
|
||||
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
||||
|
||||
transcription = result["text"].lower()
|
||||
assert "my fellow americans" in transcription
|
||||
assert "your country" in transcription
|
||||
assert "do for you" in transcription
|
||||
|
||||
timing_checked = False
|
||||
for segment in result["segments"]:
|
||||
for timing in segment["words"]:
|
||||
assert timing["start"] < timing["end"]
|
||||
if timing["word"].strip(" ,") == "Americans":
|
||||
assert timing["start"] <= 1.8
|
||||
assert timing["end"] >= 1.8
|
||||
print(timing)
|
||||
timing_checked = True
|
||||
|
||||
assert timing_checked
|
||||
|
||||
+55
-11
@@ -10,9 +10,9 @@ from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import Whisper, ModelDimensions
|
||||
from .model import ModelDimensions, Whisper
|
||||
from .transcribe import transcribe
|
||||
|
||||
from .version import __version__
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
@@ -23,7 +23,25 @@ _MODELS = {
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
||||
_ALIGNMENT_HEADS = {
|
||||
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
||||
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
||||
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
||||
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
||||
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
||||
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
||||
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
}
|
||||
|
||||
|
||||
@@ -37,14 +55,23 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
@@ -55,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
@@ -65,7 +94,12 @@ def available_models() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
@@ -90,16 +124,23 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
alignment_heads = None
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
@@ -107,4 +148,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
|
||||
model = Whisper(dims)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from .transcribe import cli
|
||||
|
||||
|
||||
cli()
|
||||
|
||||
+30
-7
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
@@ -15,8 +15,12 @@ N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
|
||||
|
||||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
|
||||
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
@@ -55,7 +59,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
if torch.is_tensor(array):
|
||||
if array.shape[axis] > length:
|
||||
array = array.index_select(dim=axis, index=torch.arange(length))
|
||||
array = array.index_select(
|
||||
dim=axis, index=torch.arange(length, device=array.device)
|
||||
)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
@@ -85,11 +91,18 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
|
||||
with np.load(
|
||||
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = N_MELS,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of
|
||||
|
||||
@@ -101,6 +114,12 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
|
||||
device: Optional[Union[str, torch.device]]
|
||||
If given, the audio tensor is moved to this device before STFT
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
@@ -111,9 +130,13 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
if device is not None:
|
||||
audio = audio.to(device)
|
||||
if padding > 0:
|
||||
audio = F.pad(audio, (0, padding))
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
magnitudes = stft[:, :-1].abs() ** 2
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
filters = mel_filters(audio.device, n_mels)
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
+178
-71
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -16,7 +16,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
||||
def detect_language(
|
||||
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
|
||||
) -> Tuple[Tensor, List[dict]]:
|
||||
"""
|
||||
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||
of the most probable language tokens and the probability distribution over all language tokens.
|
||||
@@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
||||
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
||||
if (
|
||||
tokenizer.language is None
|
||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||
):
|
||||
raise ValueError(
|
||||
"This model doesn't have language tokens so it can't perform lang id"
|
||||
)
|
||||
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
@@ -70,31 +77,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingOptions:
|
||||
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
||||
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
||||
# whether to perform X->X "transcribe" or X->English "translate"
|
||||
task: str = "transcribe"
|
||||
|
||||
# language that the audio is in; uses detected language if None
|
||||
language: Optional[str] = None
|
||||
|
||||
# sampling-related options
|
||||
temperature: float = 0.0
|
||||
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
||||
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
||||
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
|
||||
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
|
||||
|
||||
# options for ranking generations (either beams or best-of-N samples)
|
||||
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
||||
# "alpha" in Google NMT, or None for length norm, when ranking generations
|
||||
# to select which to return among the beams or best-of-N samples
|
||||
length_penalty: Optional[float] = None
|
||||
|
||||
# prompt, prefix, and token suppression
|
||||
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
# text or tokens to feed as the prompt or the prefix; for more info:
|
||||
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
prompt: Optional[Union[str, List[int]]] = None # for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
|
||||
|
||||
# list of tokens ids (or comma-separated token ids) to suppress
|
||||
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
@@ -158,7 +170,9 @@ class PyTorchInference(Inference):
|
||||
|
||||
|
||||
class SequenceRanker:
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
||||
def rank(
|
||||
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
|
||||
) -> List[int]:
|
||||
"""
|
||||
Given a list of groups of samples and their cumulative log probabilities,
|
||||
return the indices of the samples in each group to select as the final result
|
||||
@@ -196,7 +210,9 @@ class TokenDecoder:
|
||||
def reset(self):
|
||||
"""Initialize any stateful variables for decoding a new sequence"""
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
"""Specify how to select the next token, based on the current trace and logits
|
||||
|
||||
Parameters
|
||||
@@ -251,12 +267,13 @@ class GreedyDecoder(TokenDecoder):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
temperature = self.temperature
|
||||
if temperature == 0:
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
if self.temperature == 0:
|
||||
next_tokens = logits.argmax(dim=-1)
|
||||
else:
|
||||
next_tokens = Categorical(logits=logits / temperature).sample()
|
||||
next_tokens = Categorical(logits=logits / self.temperature).sample()
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||
@@ -275,7 +292,13 @@ class GreedyDecoder(TokenDecoder):
|
||||
|
||||
|
||||
class BeamSearchDecoder(TokenDecoder):
|
||||
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
|
||||
def __init__(
|
||||
self,
|
||||
beam_size: int,
|
||||
eot: int,
|
||||
inference: Inference,
|
||||
patience: Optional[float] = None,
|
||||
):
|
||||
self.beam_size = beam_size
|
||||
self.eot = eot
|
||||
self.inference = inference
|
||||
@@ -283,12 +306,16 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
self.max_candidates: int = round(beam_size * self.patience)
|
||||
self.finished_sequences = None
|
||||
|
||||
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
assert (
|
||||
self.max_candidates > 0
|
||||
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
|
||||
def reset(self):
|
||||
self.finished_sequences = None
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
if tokens.shape[0] % self.beam_size != 0:
|
||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||
|
||||
@@ -332,7 +359,9 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
|
||||
# add newly finished sequences to self.finished_sequences
|
||||
assert len(self.finished_sequences) == len(finished_sequences)
|
||||
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
||||
for previously_finished, newly_finished in zip(
|
||||
self.finished_sequences, finished_sequences
|
||||
):
|
||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||
if len(previously_finished) >= self.max_candidates:
|
||||
break # the candidate list is full
|
||||
@@ -340,7 +369,8 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
|
||||
# mark as completed if all audio has enough number of samples
|
||||
completed = all(
|
||||
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
|
||||
len(sequences) >= self.max_candidates
|
||||
for sequences in self.finished_sequences
|
||||
)
|
||||
return tokens, completed
|
||||
|
||||
@@ -348,7 +378,9 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
||||
sum_logprobs = sum_logprobs.cpu()
|
||||
for i, sequences in enumerate(self.finished_sequences):
|
||||
if len(sequences) < self.beam_size: # when not enough sequences are finished
|
||||
if (
|
||||
len(sequences) < self.beam_size
|
||||
): # when not enough sequences are finished
|
||||
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||
@@ -356,7 +388,8 @@ class BeamSearchDecoder(TokenDecoder):
|
||||
break
|
||||
|
||||
tokens: List[List[Tensor]] = [
|
||||
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
|
||||
[torch.tensor(seq) for seq in sequences.keys()]
|
||||
for sequences in self.finished_sequences
|
||||
]
|
||||
sum_logprobs: List[List[float]] = [
|
||||
list(sequences.values()) for sequences in self.finished_sequences
|
||||
@@ -400,7 +433,10 @@ class SuppressTokens(LogitFilter):
|
||||
|
||||
class ApplyTimestampRules(LogitFilter):
|
||||
def __init__(
|
||||
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
|
||||
self,
|
||||
tokenizer: Tokenizer,
|
||||
sample_begin: int,
|
||||
max_initial_timestamp_index: Optional[int],
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
@@ -413,9 +449,14 @@ class ApplyTimestampRules(LogitFilter):
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
||||
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
sampled_tokens = tokens[k, self.sample_begin :]
|
||||
seq = [t for t in sampled_tokens.tolist()]
|
||||
last_was_timestamp = (
|
||||
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
penultimate_was_timestamp = (
|
||||
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
|
||||
if last_was_timestamp:
|
||||
if penultimate_was_timestamp: # has to be non-timestamp
|
||||
@@ -423,15 +464,30 @@ class ApplyTimestampRules(LogitFilter):
|
||||
else: # cannot be normal text tokens
|
||||
logits[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
||||
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
timestamps = sampled_tokens[
|
||||
sampled_tokens.ge(self.tokenizer.timestamp_begin)
|
||||
]
|
||||
if timestamps.numel() > 0:
|
||||
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
|
||||
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if self.max_initial_timestamp_index is not None:
|
||||
last_allowed = (
|
||||
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
)
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
|
||||
dim=-1
|
||||
)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
@@ -447,7 +503,9 @@ class DecodingTask:
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, language=language, task=options.task
|
||||
)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
@@ -487,9 +545,13 @@ class DecodingTask:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
if options.max_initial_timestamp:
|
||||
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
|
||||
max_initial_timestamp_index = round(
|
||||
self.options.max_initial_timestamp / precision
|
||||
)
|
||||
self.logit_filters.append(
|
||||
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
|
||||
ApplyTimestampRules(
|
||||
tokenizer, self.sample_begin, max_initial_timestamp_index
|
||||
)
|
||||
)
|
||||
|
||||
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||
@@ -500,30 +562,38 @@ class DecodingTask:
|
||||
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||
if options.patience is not None and options.beam_size is None:
|
||||
raise ValueError("patience requires beam_size to be given")
|
||||
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
|
||||
if options.length_penalty is not None and not (
|
||||
0 <= options.length_penalty <= 1
|
||||
):
|
||||
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||
|
||||
return options
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
prefix = self.options.prefix
|
||||
prompt = self.options.prompt
|
||||
|
||||
if prefix:
|
||||
if prefix := self.options.prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
|
||||
self.tokenizer.encode(" " + prefix.strip())
|
||||
if isinstance(prefix, str)
|
||||
else prefix
|
||||
)
|
||||
if self.sample_len is not None:
|
||||
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt:
|
||||
if prompt := self.options.prompt:
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
|
||||
self.tokenizer.encode(" " + prompt.strip())
|
||||
if isinstance(prompt, str)
|
||||
else prompt
|
||||
)
|
||||
tokens = (
|
||||
[self.tokenizer.sot_prev]
|
||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||
+ tokens
|
||||
)
|
||||
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
@@ -542,7 +612,13 @@ class DecodingTask:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens.extend(
|
||||
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
||||
[
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
]
|
||||
)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
# no-speech probability is collected separately
|
||||
@@ -554,14 +630,21 @@ class DecodingTask:
|
||||
if self.options.fp16:
|
||||
mel = mel.half()
|
||||
|
||||
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||
if mel.shape[-2:] == (
|
||||
self.model.dims.n_audio_ctx,
|
||||
self.model.dims.n_audio_state,
|
||||
):
|
||||
# encoded audio features are given; skip audio encoding
|
||||
audio_features = mel
|
||||
else:
|
||||
audio_features = self.model.encoder(mel)
|
||||
|
||||
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
|
||||
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
||||
if audio_features.dtype != (
|
||||
torch.float16 if self.options.fp16 else torch.float32
|
||||
):
|
||||
return TypeError(
|
||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||
)
|
||||
|
||||
return audio_features
|
||||
|
||||
@@ -570,7 +653,9 @@ class DecodingTask:
|
||||
lang_probs = None
|
||||
|
||||
if self.options.language is None or self.options.task == "lang_id":
|
||||
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
||||
lang_tokens, lang_probs = self.model.detect_language(
|
||||
audio_features, self.tokenizer
|
||||
)
|
||||
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||
if self.options.language is None:
|
||||
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||
@@ -587,7 +672,9 @@ class DecodingTask:
|
||||
for i in range(self.sample_len):
|
||||
logits = self.inference.logits(tokens, audio_features)
|
||||
|
||||
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
||||
if (
|
||||
i == 0 and self.tokenizer.no_speech is not None
|
||||
): # save no_speech_probs
|
||||
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
@@ -621,8 +708,12 @@ class DecodingTask:
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
return [
|
||||
DecodingResult(audio_features=features, language=language, language_probs=probs)
|
||||
for features, language, probs in zip(audio_features, languages, language_probs)
|
||||
DecodingResult(
|
||||
audio_features=features, language=language, language_probs=probs
|
||||
)
|
||||
for features, language, probs in zip(
|
||||
audio_features, languages, language_probs
|
||||
)
|
||||
]
|
||||
|
||||
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
|
||||
@@ -643,7 +734,8 @@ class DecodingTask:
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens: List[List[Tensor]] = [
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
|
||||
for s in tokens
|
||||
]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
@@ -652,9 +744,18 @@ class DecodingTask:
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [
|
||||
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
|
||||
]
|
||||
|
||||
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
|
||||
fields = (
|
||||
texts,
|
||||
languages,
|
||||
tokens,
|
||||
audio_features,
|
||||
avg_logprobs,
|
||||
no_speech_probs,
|
||||
)
|
||||
if len(set(map(len, fields))) != 1:
|
||||
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||
|
||||
@@ -669,12 +770,19 @@ class DecodingTask:
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
|
||||
*fields
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
def decode(
|
||||
model: "Whisper",
|
||||
mel: Tensor,
|
||||
options: DecodingOptions = DecodingOptions(),
|
||||
**kwargs,
|
||||
) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
"""
|
||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||
|
||||
@@ -694,13 +802,12 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
|
||||
result: Union[DecodingResult, List[DecodingResult]]
|
||||
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||
"""
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
if single := mel.ndim == 2:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
|
||||
if single:
|
||||
result = result[0]
|
||||
if kwargs:
|
||||
options = replace(options, **kwargs)
|
||||
|
||||
return result
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
|
||||
return result[0] if single else result
|
||||
|
||||
+71
-29
@@ -1,15 +1,16 @@
|
||||
import base64
|
||||
import gzip
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from typing import Iterable, Optional
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .decoding import decode as decode_function
|
||||
from .decoding import detect_language as detect_language_function
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
from .decoding import detect_language as detect_language_function, decode as decode_function
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,12 +35,16 @@ class LayerNorm(nn.LayerNorm):
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
||||
x,
|
||||
self.weight.to(x.dtype),
|
||||
None if self.bias is None else self.bias.to(x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
||||
def _conv_forward(
|
||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
@@ -72,20 +77,22 @@ class MultiHeadAttention(nn.Module):
|
||||
):
|
||||
q = self.query(x)
|
||||
|
||||
if kv_cache is None or xa is None:
|
||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
else:
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
k = kv_cache.get(self.key, self.key(xa))
|
||||
v = kv_cache.get(self.value, self.value(xa))
|
||||
k = kv_cache[self.key]
|
||||
v = kv_cache[self.value]
|
||||
|
||||
wv = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv)
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
@@ -95,9 +102,10 @@ class MultiHeadAttention(nn.Module):
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
@@ -107,11 +115,15 @@ class ResidualAttentionBlock(nn.Module):
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
)
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
||||
self.mlp = nn.Sequential(
|
||||
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||
)
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
@@ -121,15 +133,17 @@ class ResidualAttentionBlock(nn.Module):
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
def __init__(
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
@@ -160,14 +174,19 @@ class AudioEncoder(nn.Module):
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
def __init__(
|
||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
||||
[
|
||||
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
)
|
||||
self.ln = LayerNorm(n_state)
|
||||
|
||||
@@ -182,14 +201,19 @@ class TextDecoder(nn.Module):
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||
logits = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
return logits
|
||||
|
||||
@@ -212,14 +236,31 @@ class Whisper(nn.Module):
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
||||
all_heads = torch.zeros(
|
||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||
)
|
||||
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
||||
|
||||
def set_alignment_heads(self, dump: bytes):
|
||||
array = np.frombuffer(
|
||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||
).copy()
|
||||
mask = torch.from_numpy(array).reshape(
|
||||
self.dims.n_text_layer, self.dims.n_text_head
|
||||
)
|
||||
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder.forward(mel)
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder.forward(tokens, audio_features)
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
def forward(
|
||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
@@ -248,8 +289,9 @@ class Whisper(nn.Module):
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
|
||||
cache[module] = output # save as-is, for the first token or cross attention
|
||||
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
||||
# save as-is, for the first token or cross attention
|
||||
cache[module] = output
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .basic import BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer
|
||||
from .basic import BasicTextNormalizer as BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer as EnglishTextNormalizer
|
||||
|
||||
@@ -48,13 +48,16 @@ def remove_symbols(s: str):
|
||||
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
||||
"""
|
||||
return "".join(
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c
|
||||
for c in unicodedata.normalize("NFKC", s)
|
||||
)
|
||||
|
||||
|
||||
class BasicTextNormalizer:
|
||||
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
|
||||
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
self.clean = (
|
||||
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
)
|
||||
self.split_letters = split_letters
|
||||
|
||||
def __call__(self, s: str):
|
||||
@@ -66,6 +69,8 @@ class BasicTextNormalizer:
|
||||
if self.split_letters:
|
||||
s = " ".join(regex.findall(r"\X", s, regex.U))
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
s = re.sub(
|
||||
r"\s+", " ", s
|
||||
) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
||||
|
||||
@@ -1737,6 +1737,5 @@
|
||||
"yoghurt": "yogurt",
|
||||
"yoghurts": "yogurts",
|
||||
"mhm": "hmm",
|
||||
"mm": "hmm",
|
||||
"mmm": "hmm"
|
||||
}
|
||||
@@ -84,7 +84,8 @@ class EnglishNumberNormalizer:
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
|
||||
name.replace("y", "ieth"): (value, "th")
|
||||
for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
@@ -108,7 +109,10 @@ class EnglishNumberNormalizer:
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
|
||||
self.multipliers_suffixed = {
|
||||
**self.multipliers_plural,
|
||||
**self.multipliers_ordinal,
|
||||
}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
@@ -128,7 +132,8 @@ class EnglishNumberNormalizer:
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
|
||||
list(self.preceding_prefixers.values())
|
||||
+ list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
@@ -218,7 +223,9 @@ class EnglishNumberNormalizer:
|
||||
if value is None:
|
||||
value = ones
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10: # replace the last zero with the digit
|
||||
if (
|
||||
prev in self.tens and ones < 10
|
||||
): # replace the last zero with the digit
|
||||
assert value[-1] == "0"
|
||||
value = value[:-1] + str(ones)
|
||||
else:
|
||||
@@ -522,14 +529,14 @@ class EnglishTextNormalizer:
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = re.sub(self.ignore_patterns, "", s)
|
||||
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
|
||||
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
|
||||
|
||||
for pattern, replacement in self.replacers.items():
|
||||
s = re.sub(pattern, replacement, s)
|
||||
|
||||
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
|
||||
|
||||
s = self.standardize_numbers(s)
|
||||
s = self.standardize_spellings(s)
|
||||
@@ -538,6 +545,6 @@ class EnglishTextNormalizer:
|
||||
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
|
||||
|
||||
return s
|
||||
|
||||
@@ -0,0 +1,323 @@
|
||||
import subprocess
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def median_filter(x: torch.Tensor, filter_width: int):
|
||||
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
|
||||
pad_width = filter_width // 2
|
||||
if x.shape[-1] <= pad_width:
|
||||
# F.pad requires the padding width to be smaller than the input dimension
|
||||
return x
|
||||
|
||||
if (ndim := x.ndim) <= 2:
|
||||
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
|
||||
x = x[None, None, :]
|
||||
|
||||
assert (
|
||||
filter_width > 0 and filter_width % 2 == 1
|
||||
), "`filter_width` should be an odd number"
|
||||
|
||||
result = None
|
||||
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
||||
if x.is_cuda:
|
||||
try:
|
||||
from .triton_ops import median_filter_cuda
|
||||
|
||||
result = median_filter_cuda(x, filter_width)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower median kernel implementation..."
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
||||
|
||||
if ndim <= 2:
|
||||
result = result[0, 0]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@numba.jit
|
||||
def backtrace(trace: np.ndarray):
|
||||
i = trace.shape[0] - 1
|
||||
j = trace.shape[1] - 1
|
||||
trace[0, :] = 2
|
||||
trace[:, 0] = 1
|
||||
|
||||
result = []
|
||||
while i > 0 or j > 0:
|
||||
result.append((i - 1, j - 1))
|
||||
|
||||
if trace[i, j] == 0:
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif trace[i, j] == 1:
|
||||
i -= 1
|
||||
elif trace[i, j] == 2:
|
||||
j -= 1
|
||||
else:
|
||||
raise ValueError("Unexpected trace[i, j]")
|
||||
|
||||
result = np.array(result)
|
||||
return result[::-1, :].T
|
||||
|
||||
|
||||
@numba.jit(nopython=True, parallel=True)
|
||||
def dtw_cpu(x: np.ndarray):
|
||||
N, M = x.shape
|
||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||
|
||||
cost[0, 0] = 0
|
||||
for j in range(1, M + 1):
|
||||
for i in range(1, N + 1):
|
||||
c0 = cost[i - 1, j - 1]
|
||||
c1 = cost[i - 1, j]
|
||||
c2 = cost[i, j - 1]
|
||||
|
||||
if c0 < c1 and c0 < c2:
|
||||
c, t = c0, 0
|
||||
elif c1 < c0 and c1 < c2:
|
||||
c, t = c1, 1
|
||||
else:
|
||||
c, t = c2, 2
|
||||
|
||||
cost[i, j] = x[i - 1, j - 1] + c
|
||||
trace[i, j] = t
|
||||
|
||||
return backtrace(trace)
|
||||
|
||||
|
||||
def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
from .triton_ops import dtw_kernel
|
||||
|
||||
M, N = x.shape
|
||||
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
|
||||
|
||||
x_skew = (
|
||||
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
|
||||
)
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.cuda()
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
cost,
|
||||
trace,
|
||||
x_skew,
|
||||
x_skew.stride(0),
|
||||
cost.stride(0),
|
||||
trace.stride(0),
|
||||
N,
|
||||
M,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
|
||||
:, : N + 1
|
||||
]
|
||||
return backtrace(trace.cpu().numpy())
|
||||
|
||||
|
||||
def dtw(x: torch.Tensor) -> np.ndarray:
|
||||
if x.is_cuda:
|
||||
try:
|
||||
return dtw_cuda(x)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower DTW implementation..."
|
||||
)
|
||||
|
||||
return dtw_cpu(x.double().cpu().numpy())
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordTiming:
|
||||
word: str
|
||||
tokens: List[int]
|
||||
start: float
|
||||
end: float
|
||||
probability: float
|
||||
|
||||
|
||||
def find_alignment(
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
text_tokens: List[int],
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
*,
|
||||
medfilt_width: int = 7,
|
||||
qk_scale: float = 1.0,
|
||||
) -> List[WordTiming]:
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*text_tokens,
|
||||
tokenizer.eot,
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
# install hooks on the cross attention layers to retrieve the attention weights
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||
text_token_probs = text_token_probs.tolist()
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = (weights * qk_scale).softmax(dim=-1)
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
weights = (weights - mean) / std
|
||||
weights = median_filter(weights, medfilt_width)
|
||||
|
||||
matrix = weights.mean(axis=0)
|
||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||
text_indices, time_indices = dtw(-matrix)
|
||||
|
||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||
|
||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||
start_times = jump_times[word_boundaries[:-1]]
|
||||
end_times = jump_times[word_boundaries[1:]]
|
||||
word_probabilities = [
|
||||
np.mean(text_token_probs[i:j])
|
||||
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
||||
]
|
||||
|
||||
# hack: ensure the first and second word is not longer than twice the median word duration.
|
||||
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||
word_durations = end_times - start_times
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
if len(word_durations) > 0:
|
||||
median_duration = np.median(word_durations)
|
||||
max_duration = median_duration * 2
|
||||
if len(word_durations) >= 2 and word_durations[1] > max_duration:
|
||||
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
|
||||
end_times[0] = start_times[1] = boundary
|
||||
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
|
||||
start_times[0] = max(0, end_times[0] - max_duration)
|
||||
|
||||
return [
|
||||
WordTiming(word, tokens, start, end, probability)
|
||||
for word, tokens, start, end, probability in zip(
|
||||
words, word_tokens, start_times, end_times, word_probabilities
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
|
||||
# merge prepended punctuations
|
||||
i = len(alignment) - 2
|
||||
j = len(alignment) - 1
|
||||
while i >= 0:
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if previous.word.startswith(" ") and previous.word.strip() in prepended:
|
||||
# prepend it to the following word
|
||||
following.word = previous.word + following.word
|
||||
following.tokens = previous.tokens + following.tokens
|
||||
previous.word = ""
|
||||
previous.tokens = []
|
||||
else:
|
||||
j = i
|
||||
i -= 1
|
||||
|
||||
# merge appended punctuations
|
||||
i = 0
|
||||
j = 1
|
||||
while j < len(alignment):
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if not previous.word.endswith(" ") and following.word in appended:
|
||||
# append it to the previous word
|
||||
previous.word = previous.word + following.word
|
||||
previous.tokens = previous.tokens + following.tokens
|
||||
following.word = ""
|
||||
following.tokens = []
|
||||
else:
|
||||
i = j
|
||||
j += 1
|
||||
|
||||
|
||||
def add_word_timestamps(
|
||||
*,
|
||||
segments: List[dict],
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
**kwargs,
|
||||
):
|
||||
if len(segments) == 0:
|
||||
return
|
||||
|
||||
text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
|
||||
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
||||
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
||||
|
||||
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
||||
segment_lengths = [len(s["tokens"]) for s in segments]
|
||||
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
|
||||
|
||||
for segment in segments:
|
||||
segment["words"] = []
|
||||
|
||||
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
|
||||
for i, timing in enumerate(alignment):
|
||||
if timing.word:
|
||||
segment = segments[token_sources[word_boundaries[i]]]
|
||||
start = round(time_offset + timing.start, 2)
|
||||
end = round(time_offset + timing.end, 2)
|
||||
segment["words"].append(
|
||||
dict(
|
||||
word=timing.word,
|
||||
start=start,
|
||||
end=end,
|
||||
probability=timing.probability,
|
||||
)
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
if len(words := segment["words"]) > 0:
|
||||
# adjust the segment-level timestamps based on the word-level timestamps
|
||||
segment["start"] = words[0]["start"]
|
||||
segment["end"] = words[-1]["end"]
|
||||
+83
-34
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -28,7 +29,7 @@ LANGUAGES = {
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"iw": "hebrew",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
@@ -137,7 +138,9 @@ class Tokenizer:
|
||||
def encode(self, text, **kwargs):
|
||||
return self.tokenizer.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
|
||||
def decode(
|
||||
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
|
||||
):
|
||||
return self.tokenizer.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, tokens) -> str:
|
||||
@@ -153,50 +156,51 @@ class Tokenizer:
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
return "".join(outputs)
|
||||
return "".join(
|
||||
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def eot(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def transcribe(self) -> int:
|
||||
return self._get_single_token_id("<|transcribe|>")
|
||||
|
||||
@cached_property
|
||||
def translate(self) -> int:
|
||||
return self._get_single_token_id("<|translate|>")
|
||||
|
||||
@cached_property
|
||||
def sot(self) -> int:
|
||||
return self._get_single_token_id("<|startoftranscript|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def sot_lm(self) -> int:
|
||||
return self._get_single_token_id("<|startoflm|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def sot_prev(self) -> int:
|
||||
return self._get_single_token_id("<|startofprev|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def no_speech(self) -> int:
|
||||
return self._get_single_token_id("<|nospeech|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def no_timestamps(self) -> int:
|
||||
return self._get_single_token_id("<|notimestamps|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.tokenizer.all_special_ids[-1] + 1
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError(f"This tokenizer does not have language token configured")
|
||||
raise ValueError("This tokenizer does not have language token configured")
|
||||
|
||||
additional_tokens = dict(
|
||||
zip(
|
||||
@@ -210,8 +214,7 @@ class Tokenizer:
|
||||
|
||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in zip(
|
||||
@@ -222,18 +225,15 @@ class Tokenizer:
|
||||
result.append(token_id)
|
||||
return tuple(result)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached_property
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
@@ -245,8 +245,10 @@ class Tokenizer:
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
|
||||
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||
symbols += (
|
||||
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
)
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
@@ -258,7 +260,10 @@ class Tokenizer:
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
|
||||
for tokens in [
|
||||
self.tokenizer.encode(symbol),
|
||||
self.tokenizer.encode(" " + symbol),
|
||||
]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
@@ -269,6 +274,48 @@ class Tokenizer:
|
||||
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
||||
return tokens[0]
|
||||
|
||||
def split_to_word_tokens(self, tokens: List[int]):
|
||||
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
||||
# These languages don't typically use spaces, so it is difficult to split words
|
||||
# without morpheme analysis. Here, we instead split words at any
|
||||
# position where the tokens are decoded as valid unicode points
|
||||
return self.split_tokens_on_unicode(tokens)
|
||||
|
||||
return self.split_tokens_on_spaces(tokens)
|
||||
|
||||
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||
words = []
|
||||
word_tokens = []
|
||||
current_tokens = []
|
||||
|
||||
for token in tokens:
|
||||
current_tokens.append(token)
|
||||
decoded = self.decode_with_timestamps(current_tokens)
|
||||
if "\ufffd" not in decoded:
|
||||
words.append(decoded)
|
||||
word_tokens.append(current_tokens)
|
||||
current_tokens = []
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
def split_tokens_on_spaces(self, tokens: List[int]):
|
||||
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
||||
words = []
|
||||
word_tokens = []
|
||||
|
||||
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
||||
special = subword_tokens[0] >= self.eot
|
||||
with_space = subword.startswith(" ")
|
||||
punctuation = subword.strip() in string.punctuation
|
||||
if special or with_space or punctuation or len(words) == 0:
|
||||
words.append(subword)
|
||||
word_tokens.append(subword_tokens)
|
||||
else:
|
||||
words[-1] = words[-1] + subword
|
||||
word_tokens[-1].extend(subword_tokens)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def build_tokenizer(name: str = "gpt2"):
|
||||
@@ -328,4 +375,6 @@ def get_tokenizer(
|
||||
if task is not None:
|
||||
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
||||
|
||||
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
|
||||
return Tokenizer(
|
||||
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
|
||||
)
|
||||
|
||||
+232
-106
@@ -1,16 +1,33 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
||||
from .audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
N_SAMPLES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .timing import add_word_timestamps
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
|
||||
from .utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
@@ -26,6 +43,10 @@ def transcribe(
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@@ -44,7 +65,7 @@ def transcribe(
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
@@ -62,6 +83,21 @@ def transcribe(
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
word_timestamps: bool
|
||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
||||
and include the timestamps for each word in each segment.
|
||||
|
||||
prepend_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
||||
|
||||
append_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
||||
|
||||
initial_prompt: Optional[str]
|
||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
@@ -81,23 +117,37 @@ def transcribe(
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
mel = log_mel_spectrogram(audio)
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-1] - N_FRAMES
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if verbose:
|
||||
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
|
||||
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
||||
if not model.is_multilingual:
|
||||
decode_options["language"] = "en"
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(
|
||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||
)
|
||||
|
||||
language = decode_options["language"]
|
||||
task = decode_options.get("task", "transcribe")
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
|
||||
if word_timestamps and task == "translate":
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
|
||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
temperatures = (
|
||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
)
|
||||
decode_result = None
|
||||
|
||||
for t in temperatures:
|
||||
@@ -114,9 +164,15 @@ def transcribe(
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
needs_fallback = False
|
||||
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
|
||||
if (
|
||||
compression_ratio_threshold is not None
|
||||
and decode_result.compression_ratio > compression_ratio_threshold
|
||||
):
|
||||
needs_fallback = True # too repetitive
|
||||
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = True # average log probability is too low
|
||||
|
||||
if not needs_fallback:
|
||||
@@ -135,123 +191,193 @@ def transcribe(
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
initial_prompt = decode_options.pop("initial_prompt", None) or []
|
||||
if initial_prompt:
|
||||
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt)
|
||||
if initial_prompt is not None:
|
||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
def add_segment(
|
||||
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
|
||||
def new_segment(
|
||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
|
||||
if len(text.strip()) == 0: # skip empty text output
|
||||
return
|
||||
tokens = tokens.tolist()
|
||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||
return {
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": tokenizer.decode(text_tokens),
|
||||
"tokens": tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
|
||||
all_segments.append(
|
||||
{
|
||||
"id": len(all_segments),
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": text,
|
||||
"tokens": result.tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
)
|
||||
if verbose:
|
||||
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
||||
|
||||
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||
num_frames = mel.shape[-1]
|
||||
previous_seek_value = seek
|
||||
|
||||
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
||||
while seek < num_frames:
|
||||
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
|
||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
with tqdm.tqdm(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
while seek < content_frames:
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
mel_segment = mel[:, seek : seek + N_FRAMES]
|
||||
segment_size = min(N_FRAMES, content_frames - seek)
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(segment)
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and result.avg_logprob > logprob_threshold
|
||||
):
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
if should_skip:
|
||||
seek += segment.shape[-1] # fast-forward to the next segment boundary
|
||||
seek += segment_size # fast-forward to the next segment boundary
|
||||
continue
|
||||
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
|
||||
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
consecutive.add_(1)
|
||||
if len(consecutive) > 0:
|
||||
# if the output contains two consecutive timestamp tokens
|
||||
slices = consecutive.tolist()
|
||||
if single_timestamp_ending:
|
||||
slices.append(len(tokens))
|
||||
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
for current_slice in slices:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_position = (
|
||||
start_timestamp_pos = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_position = (
|
||||
end_timestamp_pos = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
add_segment(
|
||||
start=timestamp_offset + start_timestamp_position * time_precision,
|
||||
end=timestamp_offset + end_timestamp_position * time_precision,
|
||||
text_tokens=sliced_tokens[1:-1],
|
||||
result=result,
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset + start_timestamp_pos * time_precision,
|
||||
end=time_offset + end_timestamp_pos * time_precision,
|
||||
tokens=sliced_tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
last_slice = current_slice
|
||||
last_timestamp_position = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_position * input_stride
|
||||
all_tokens.extend(tokens[: last_slice + 1].tolist())
|
||||
|
||||
if single_timestamp_ending:
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
seek += segment_size
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
last_timestamp_pos = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_pos * input_stride
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
|
||||
if (
|
||||
len(timestamps) > 0
|
||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
||||
):
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
duration = last_timestamp_position * time_precision
|
||||
last_timestamp_pos = (
|
||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
duration = last_timestamp_pos * time_precision
|
||||
|
||||
add_segment(
|
||||
start=timestamp_offset,
|
||||
end=timestamp_offset + duration,
|
||||
text_tokens=tokens,
|
||||
result=result,
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset,
|
||||
end=time_offset + duration,
|
||||
tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
seek += segment.shape[-1]
|
||||
all_tokens.extend(tokens.tolist())
|
||||
seek += segment_size
|
||||
|
||||
if not condition_on_previous_text or result.temperature > 0.5:
|
||||
# do not feed the prompt tokens if a high temperature was used
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(num_frames, seek) - previous_seek_value)
|
||||
previous_seek_value = seek
|
||||
if word_timestamps:
|
||||
add_word_timestamps(
|
||||
segments=current_segments,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
mel=mel_segment,
|
||||
num_frames=segment_size,
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
)
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||
)
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
|
||||
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||
print(make_safe(line))
|
||||
|
||||
# if a segment is instantaneous or does not contain text, clear it
|
||||
for i, segment in enumerate(current_segments):
|
||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
||||
segment["text"] = ""
|
||||
segment["tokens"] = []
|
||||
segment["words"] = []
|
||||
|
||||
all_segments.extend(
|
||||
[
|
||||
{"id": i, **segment}
|
||||
for i, segment in enumerate(
|
||||
current_segments, start=len(all_segments)
|
||||
)
|
||||
]
|
||||
)
|
||||
all_tokens.extend(
|
||||
[token for segment in current_segments for token in segment["tokens"]]
|
||||
)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
segments=all_segments,
|
||||
language=language,
|
||||
)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
@@ -261,7 +387,7 @@ def cli():
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
@@ -272,45 +398,45 @@ def cli():
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
||||
if args["language"] is not None:
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
|
||||
if temperature_increment_on_fallback is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
if (threads := args.pop("threads")) > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
|
||||
# save TXT
|
||||
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
|
||||
write_txt(result["segments"], file=txt)
|
||||
|
||||
# save VTT
|
||||
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
|
||||
write_vtt(result["segments"], file=vtt)
|
||||
|
||||
# save SRT
|
||||
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
writer(result, audio_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
except ImportError:
|
||||
raise RuntimeError("triton import failed; try `pip install --pre triton`")
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dtw_kernel(
|
||||
cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < M
|
||||
|
||||
for k in range(1, N + M + 1): # k = i + j
|
||||
tl.debug_barrier()
|
||||
|
||||
p0 = cost + (k - 1) * cost_stride
|
||||
p1 = cost + k * cost_stride
|
||||
p2 = cost + k * cost_stride + 1
|
||||
|
||||
c0 = tl.load(p0 + offsets, mask=mask)
|
||||
c1 = tl.load(p1 + offsets, mask=mask)
|
||||
c2 = tl.load(p2 + offsets, mask=mask)
|
||||
|
||||
x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
|
||||
cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
|
||||
|
||||
cost_ptr = cost + (k + 1) * cost_stride + 1
|
||||
tl.store(cost_ptr + offsets, cost_row, mask=mask)
|
||||
|
||||
trace_ptr = trace + (k + 1) * trace_stride + 1
|
||||
tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
|
||||
tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
|
||||
tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def median_kernel(filter_width: int):
|
||||
@triton.jit
|
||||
def kernel(
|
||||
y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
|
||||
): # x.shape[-1] == filter_width
|
||||
row_idx = tl.program_id(0)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < y_stride
|
||||
|
||||
x_ptr = x + row_idx * x_stride # noqa: F841
|
||||
y_ptr = y + row_idx * y_stride
|
||||
|
||||
LOAD_ALL_ROWS_HERE # noqa: F821
|
||||
|
||||
BUBBLESORT_HERE # noqa: F821
|
||||
|
||||
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||
|
||||
kernel = triton.JITFunction(kernel.fn)
|
||||
kernel.src = kernel.src.replace(
|
||||
" LOAD_ALL_ROWS_HERE",
|
||||
"\n".join(
|
||||
[
|
||||
f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
|
||||
for i in range(filter_width)
|
||||
]
|
||||
),
|
||||
)
|
||||
kernel.src = kernel.src.replace(
|
||||
" BUBBLESORT_HERE",
|
||||
"\n\n".join(
|
||||
[
|
||||
"\n\n".join(
|
||||
[
|
||||
"\n".join(
|
||||
[
|
||||
f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
|
||||
f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
|
||||
f" row{j} = smaller",
|
||||
f" row{j + 1} = larger",
|
||||
]
|
||||
)
|
||||
for j in range(filter_width - i - 1)
|
||||
]
|
||||
)
|
||||
for i in range(filter_width // 2 + 1)
|
||||
]
|
||||
),
|
||||
)
|
||||
kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||
|
||||
return kernel
|
||||
|
||||
|
||||
def median_filter_cuda(x: torch.Tensor, filter_width: int):
|
||||
"""Apply a median filter of given width along the last dimension of x"""
|
||||
slices = x.contiguous().unfold(-1, filter_width, 1)
|
||||
grid = np.prod(slices.shape[:-2])
|
||||
|
||||
kernel = median_kernel(filter_width)
|
||||
y = torch.empty_like(slices[..., 0])
|
||||
|
||||
BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
|
||||
kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
|
||||
|
||||
return y
|
||||
+155
-37
@@ -1,5 +1,23 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Iterator, TextIO
|
||||
from typing import Callable, TextIO
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
if system_encoding != "utf-8":
|
||||
|
||||
def make_safe(string):
|
||||
# replaces any character not representable using the system default encoding with an '?',
|
||||
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
|
||||
return string.encode(system_encoding, errors="replace").decode(system_encoding)
|
||||
|
||||
else:
|
||||
|
||||
def make_safe(string):
|
||||
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
|
||||
return string
|
||||
|
||||
|
||||
def exact_div(x, y):
|
||||
@@ -24,10 +42,13 @@ def optional_float(string):
|
||||
|
||||
|
||||
def compression_ratio(text) -> float:
|
||||
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
||||
text_bytes = text.encode("utf-8")
|
||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
||||
|
||||
|
||||
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
|
||||
def format_timestamp(
|
||||
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
||||
):
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
@@ -41,47 +62,144 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
return (
|
||||
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
)
|
||||
|
||||
|
||||
def write_txt(transcript: Iterator[dict], file: TextIO):
|
||||
for segment in transcript:
|
||||
print(segment['text'].strip(), file=file, flush=True)
|
||||
class ResultWriter:
|
||||
extension: str
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def __call__(self, result: dict, audio_path: str):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
self.output_dir, audio_basename + "." + self.extension
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f)
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
||||
print("WEBVTT\n", file=file)
|
||||
for segment in transcript:
|
||||
print(
|
||||
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
class WriteTXT(ResultWriter):
|
||||
extension: str = "txt"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
|
||||
|
||||
class SubtitlesWriter(ResultWriter):
|
||||
always_include_hours: bool
|
||||
decimal_marker: str
|
||||
|
||||
def iterate_result(self, result: dict):
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
|
||||
if word_timings := segment.get("words", None):
|
||||
all_words = [timing["word"] for timing in word_timings]
|
||||
all_words[0] = all_words[0].strip() # remove the leading space, if any
|
||||
last = segment_start
|
||||
for i, this_word in enumerate(word_timings):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, segment_text
|
||||
|
||||
yield start, end, "".join(
|
||||
[
|
||||
f"<u>{word}</u>" if j == i else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
|
||||
if last != segment_end:
|
||||
yield last, segment_end, segment_text
|
||||
else:
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
return format_timestamp(
|
||||
seconds=seconds,
|
||||
always_include_hours=self.always_include_hours,
|
||||
decimal_marker=self.decimal_marker,
|
||||
)
|
||||
|
||||
|
||||
def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
class WriteVTT(SubtitlesWriter):
|
||||
extension: str = "vtt"
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
print("WEBVTT\n", file=file)
|
||||
for start, end, text in self.iterate_result(result):
|
||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteSRT(SubtitlesWriter):
|
||||
extension: str = "srt"
|
||||
always_include_hours: bool = True
|
||||
decimal_marker: str = ","
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
|
||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteTSV(ResultWriter):
|
||||
"""
|
||||
Write a transcript to a file in SRT format.
|
||||
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
|
||||
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
|
||||
|
||||
Example usage:
|
||||
from pathlib import Path
|
||||
from whisper.utils import write_srt
|
||||
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
# save SRT
|
||||
audio_basename = Path(audio_path).stem
|
||||
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
Using integer milliseconds as start and end times means there's no chance of interference from
|
||||
an environment setting a language encoding that causes the decimal in a floating point number
|
||||
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
|
||||
"""
|
||||
for i, segment in enumerate(transcript, start=1):
|
||||
# write srt lines
|
||||
print(
|
||||
f"{i}\n"
|
||||
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
extension: str = "tsv"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||
print(round(1000 * segment["end"]), file=file, end="\t")
|
||||
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
|
||||
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
json.dump(result, file)
|
||||
|
||||
|
||||
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
|
||||
writers = {
|
||||
"txt": WriteTXT,
|
||||
"vtt": WriteVTT,
|
||||
"srt": WriteSRT,
|
||||
"tsv": WriteTSV,
|
||||
"json": WriteJSON,
|
||||
}
|
||||
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
|
||||
def write_all(result: dict, file: TextIO):
|
||||
for writer in all_writers:
|
||||
writer(result, file)
|
||||
|
||||
return write_all
|
||||
|
||||
return writers[output_format](output_dir)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = "20230308"
|
||||
Reference in New Issue
Block a user