Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 55f690af79 | |||
| 7f1ef223ab | |||
| f5bfe004ec | |||
| da600abd2b | |||
| 9f7aba6099 | |||
| 12e1089462 | |||
| ea1c266709 | |||
| 8135a7c31c | |||
| 9d646db9d8 | |||
| 37a4f1be6d | |||
| b9f9b433ae | |||
| f0083e7eb2 | |||
| a84191faae | |||
| b1d213c0c7 | |||
| 493dfffa37 | |||
| 0f39c89d92 | |||
| 6df3ea1fb5 | |||
| 70861c7ce3 | |||
| f82bc59f5e | |||
| 28769fcfe5 | |||
| 53807677fe | |||
| 9323b2526c | |||
| 68e44bd83c | |||
| 0b5dcfdef7 | |||
| b9265e5796 | |||
| fd8f80c8b8 | |||
| 4179ed2475 | |||
| ec1b34bb90 | |||
| eff383b27b | |||
| 02aa851a49 | |||
| 76148a56c5 | |||
| 9f70a352f9 | |||
| 7f3e408e09 | |||
| f680570016 | |||
| d18e9ea5dd | |||
| 82725cea9c | |||
| 35713c66e0 |
@@ -0,0 +1,37 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-ecosystem/action-regex-match@v2
|
||||
id: regex-match
|
||||
with:
|
||||
text: ${{ github.event.head_commit.message }}
|
||||
regex: '^Release ([^ ]+)'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.8'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install setuptools wheel twine
|
||||
- name: Release
|
||||
if: ${{ steps.regex-match.outputs.match != '' }}
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
tag_name: v${{ steps.regex-match.outputs.group1 }}
|
||||
- name: Build and publish
|
||||
if: ${{ steps.regex-match.outputs.match != '' }}
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
python setup.py sdist
|
||||
twine upload dist/*
|
||||
@@ -0,0 +1,26 @@
|
||||
name: test
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
whisper-test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.8', '3.9', '3.10']
|
||||
pytorch-version: [1.10.2, 1.13.1]
|
||||
exclude:
|
||||
- python-version: '3.10'
|
||||
pytorch-version: 1.10.2
|
||||
steps:
|
||||
- uses: conda-incubator/setup-miniconda@v2
|
||||
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
|
||||
- uses: actions/checkout@v2
|
||||
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
||||
- run: pip install pytest
|
||||
- run: pip install .
|
||||
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
|
||||
@@ -1,3 +1,6 @@
|
||||
include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include whisper/assets/*
|
||||
include whisper/assets/gpt2/*
|
||||
include whisper/assets/multilingual/*
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Whisper
|
||||
|
||||
[[Blog]](https://openai.com/blog/whisper)
|
||||
[[Paper]](https://cdn.openai.com/papers/whisper.pdf)
|
||||
[[Model card]](model-card.md)
|
||||
[[Paper]](https://arxiv.org/abs/2212.04356)
|
||||
[[Model card]](https://github.com/openai/whisper/blob/main/model-card.md)
|
||||
[[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb)
|
||||
|
||||
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
|
||||
@@ -10,17 +10,25 @@ Whisper is a general-purpose speech recognition model. It is trained on a large
|
||||
|
||||
## Approach
|
||||
|
||||

|
||||

|
||||
|
||||
A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. All of these tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing for a single model to replace many different stages of a traditional speech processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. The following command will pull and install the latest commit from this repository, along with its Python dependencies
|
||||
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
|
||||
|
||||
pip install -U openai-whisper
|
||||
|
||||
Alternatively, the following command will pull and install the latest commit from this repository, along with its Python dependencies:
|
||||
|
||||
pip install git+https://github.com/openai/whisper.git
|
||||
|
||||
To update the package to the latest version of this repository, please run:
|
||||
|
||||
pip install --upgrade --no-deps --force-reinstall git+https://github.com/openai/whisper.git
|
||||
|
||||
It also requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system, which is available from most package managers:
|
||||
|
||||
```bash
|
||||
@@ -62,9 +70,9 @@ There are five model sizes, four with English-only versions, offering speed and
|
||||
|
||||
For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
|
||||
|
||||
Whisper's performance varies widely depending on the language. The figure below shows a WER breakdown by languages of Fleurs dataset, using the `large` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://cdn.openai.com/papers/whisper.pdf).
|
||||
Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of Fleurs dataset, using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller is better.
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
|
||||
@@ -86,7 +94,7 @@ Run the following to view all available options:
|
||||
|
||||
whisper --help
|
||||
|
||||
See [tokenizer.py](whisper/tokenizer.py) for the list of all available languages.
|
||||
See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
|
||||
|
||||
|
||||
## Python usage
|
||||
@@ -136,4 +144,4 @@ Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussion
|
||||
|
||||
## License
|
||||
|
||||
The code and the model weights of Whisper are released under the MIT License. See [LICENSE](LICENSE) for further details.
|
||||
The code and the model weights of Whisper are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
|
||||
|
||||
+1501
-2532
File diff suppressed because it is too large
Load Diff
|
Before Width: | Height: | Size: 134 KiB After Width: | Height: | Size: 100 KiB |
+8
-6
@@ -2,7 +2,7 @@
|
||||
|
||||
This is the official codebase for running the automatic speech recognition (ASR) models (Whisper models) trained and released by OpenAI.
|
||||
|
||||
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://cdn.openai.com/papers/whisper.pdf).
|
||||
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2212.04356).
|
||||
|
||||
|
||||
## Model Details
|
||||
@@ -17,10 +17,12 @@ The Whisper models are trained for speech recognition and translation tasks, cap
|
||||
| medium | 769 M | ✓ | ✓ |
|
||||
| large | 1550 M | | ✓ |
|
||||
|
||||
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).
|
||||
|
||||
|
||||
### Release date
|
||||
|
||||
September 2022
|
||||
September 2022 (original series) and December 2022 (`large-v2`)
|
||||
|
||||
### Model type
|
||||
|
||||
@@ -28,7 +30,7 @@ Sequence-to-sequence ASR (automatic speech recognition) and speech translation m
|
||||
|
||||
### Paper & samples
|
||||
|
||||
[Paper](https://cdn.openai.com/papers/whisper.pdf) / [Blog](https://openai.com/blog/whisper)
|
||||
[Paper](https://arxiv.org/abs/2212.04356) / [Blog](https://openai.com/blog/whisper)
|
||||
|
||||
|
||||
## Model Use
|
||||
@@ -46,7 +48,7 @@ In particular, we caution against using Whisper models to transcribe recordings
|
||||
|
||||
The models are trained on 680,000 hours of audio and the corresponding transcripts collected from the internet. 65% of this data (or 438,000 hours) represents English-language audio and matched English transcripts, roughly 18% (or 126,000 hours) represents non-English audio and English transcripts, while the final 17% (or 117,000 hours) represents non-English audio and the corresponding transcript. This non-English data represents 98 different languages.
|
||||
|
||||
As discussed in [the accompanying paper](https://cdn.openai.com/papers/whisper.pdf), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
|
||||
As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
|
||||
|
||||
|
||||
## Performance and Limitations
|
||||
@@ -55,9 +57,9 @@ Our studies show that, over many existing ASR systems, the models exhibit improv
|
||||
|
||||
However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself.
|
||||
|
||||
Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://cdn.openai.com/papers/whisper.pdf).
|
||||
Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356).
|
||||
|
||||
In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://cdn.openai.com/papers/whisper.pdf). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages.
|
||||
In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages.
|
||||
|
||||
|
||||
## Broader Implications
|
||||
|
||||
+2989
-116
File diff suppressed because one or more lines are too long
@@ -3,12 +3,24 @@ import os
|
||||
import pkg_resources
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
|
||||
def read_version(fname="whisper/version.py"):
|
||||
exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
|
||||
return locals()["__version__"]
|
||||
|
||||
|
||||
setup(
|
||||
name="whisper",
|
||||
name="openai-whisper",
|
||||
py_modules=["whisper"],
|
||||
version="1.0",
|
||||
description="",
|
||||
version=read_version(),
|
||||
description="Robust Speech Recognition via Large-Scale Weak Supervision",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
readme="README.md",
|
||||
python_requires=">=3.7",
|
||||
author="OpenAI",
|
||||
url="https://github.com/openai/whisper",
|
||||
license="MIT",
|
||||
packages=find_packages(exclude=["tests*"]),
|
||||
install_requires=[
|
||||
str(r)
|
||||
@@ -16,9 +28,9 @@ setup(
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||
)
|
||||
],
|
||||
entry_points = {
|
||||
'console_scripts': ['whisper=whisper.transcribe:cli'],
|
||||
entry_points={
|
||||
"console_scripts": ["whisper=whisper.transcribe:cli"],
|
||||
},
|
||||
include_package_data=True,
|
||||
extras_require={'dev': ['pytest']},
|
||||
extras_require={"dev": ["pytest"]},
|
||||
)
|
||||
|
||||
@@ -84,6 +84,7 @@ def test_text_normalizer():
|
||||
assert std("he's like") == "he is like"
|
||||
assert std("she's been like") == "she has been like"
|
||||
assert std("10km") == "10 km"
|
||||
assert std("10mm") == "10 mm"
|
||||
assert std("RC232") == "rc 232"
|
||||
|
||||
assert (
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import whisper
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', whisper.available_models())
|
||||
@pytest.mark.parametrize("model_name", whisper.available_models())
|
||||
def test_transcribe(model_name: str):
|
||||
model = whisper.load_model(model_name).cuda()
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = whisper.load_model(model_name).to(device)
|
||||
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
||||
|
||||
language = "en" if model_name.endswith(".en") else None
|
||||
|
||||
+15
-3
@@ -12,6 +12,7 @@ from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import Whisper, ModelDimensions
|
||||
from .transcribe import transcribe
|
||||
from .version import __version__
|
||||
|
||||
|
||||
_MODELS = {
|
||||
@@ -23,7 +24,9 @@ _MODELS = {
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +40,8 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
@@ -90,7 +94,15 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
download_root = os.path.join(
|
||||
os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(
|
||||
os.path.expanduser("~"), ".cache"
|
||||
)
|
||||
),
|
||||
"whisper"
|
||||
)
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
|
||||
+2
-2
@@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
if torch.is_tensor(array):
|
||||
if array.shape[axis] > length:
|
||||
array = array.index_select(dim=axis, index=torch.arange(length))
|
||||
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
@@ -113,7 +113,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
|
||||
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
magnitudes = stft[:, :-1].abs() ** 2
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
filters = mel_filters(audio.device, n_mels)
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
+8
-4
@@ -423,10 +423,14 @@ class ApplyTimestampRules(LogitFilter):
|
||||
else: # cannot be normal text tokens
|
||||
logits[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
||||
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if self.max_initial_timestamp_index is not None:
|
||||
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
|
||||
+12
-11
@@ -72,18 +72,18 @@ class MultiHeadAttention(nn.Module):
|
||||
):
|
||||
q = self.query(x)
|
||||
|
||||
if kv_cache is None or xa is None:
|
||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
else:
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
k = kv_cache.get(self.key, self.key(xa))
|
||||
v = kv_cache.get(self.value, self.value(xa))
|
||||
k = kv_cache[self.key]
|
||||
v = kv_cache[self.value]
|
||||
|
||||
wv = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv)
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
@@ -214,10 +215,10 @@ class Whisper(nn.Module):
|
||||
)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder.forward(mel)
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder.forward(tokens, audio_features)
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@@ -1737,6 +1737,5 @@
|
||||
"yoghurt": "yogurt",
|
||||
"yoghurts": "yogurts",
|
||||
"mhm": "hmm",
|
||||
"mm": "hmm",
|
||||
"mmm": "hmm"
|
||||
}
|
||||
@@ -28,7 +28,7 @@ LANGUAGES = {
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"iw": "hebrew",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
|
||||
+28
-28
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
||||
from typing import Optional, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -10,7 +10,7 @@ import tqdm
|
||||
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
|
||||
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
@@ -44,7 +44,7 @@ def transcribe(
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
@@ -84,13 +84,16 @@ def transcribe(
|
||||
mel = log_mel_spectrogram(audio)
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if verbose:
|
||||
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
|
||||
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
||||
if not model.is_multilingual:
|
||||
decode_options["language"] = "en"
|
||||
else:
|
||||
if verbose:
|
||||
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
|
||||
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
||||
|
||||
language = decode_options["language"]
|
||||
task = decode_options.get("task", "transcribe")
|
||||
@@ -154,7 +157,7 @@ def transcribe(
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": text,
|
||||
"tokens": result.tokens,
|
||||
"tokens": text_tokens.tolist(),
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
@@ -162,7 +165,7 @@ def transcribe(
|
||||
}
|
||||
)
|
||||
if verbose:
|
||||
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
||||
print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
|
||||
|
||||
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||
num_frames = mel.shape[-1]
|
||||
@@ -252,6 +255,7 @@ def cli():
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
@@ -261,7 +265,7 @@ def cli():
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
@@ -272,16 +276,19 @@ def cli():
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
||||
if args["language"] is not None:
|
||||
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
@@ -291,25 +298,18 @@ def cli():
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
threads = args.pop("threads")
|
||||
if threads > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
|
||||
# save TXT
|
||||
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
|
||||
write_txt(result["segments"], file=txt)
|
||||
|
||||
# save VTT
|
||||
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
|
||||
write_vtt(result["segments"], file=vtt)
|
||||
|
||||
# save SRT
|
||||
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
writer(result, audio_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
+112
-36
@@ -1,5 +1,20 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Iterator, TextIO
|
||||
from typing import Callable, TextIO
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
if system_encoding != "utf-8":
|
||||
def make_safe(string):
|
||||
# replaces any character not representable using the system default encoding with an '?',
|
||||
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
|
||||
return string.encode(system_encoding, errors="replace").decode(system_encoding)
|
||||
else:
|
||||
def make_safe(string):
|
||||
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
|
||||
return string
|
||||
|
||||
|
||||
def exact_div(x, y):
|
||||
@@ -24,7 +39,8 @@ def optional_float(string):
|
||||
|
||||
|
||||
def compression_ratio(text) -> float:
|
||||
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
||||
text_bytes = text.encode("utf-8")
|
||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
||||
|
||||
|
||||
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
|
||||
@@ -44,44 +60,104 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
|
||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
|
||||
|
||||
def write_txt(transcript: Iterator[dict], file: TextIO):
|
||||
for segment in transcript:
|
||||
print(segment['text'].strip(), file=file, flush=True)
|
||||
class ResultWriter:
|
||||
extension: str
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def __call__(self, result: dict, audio_path: str):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f)
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
||||
print("WEBVTT\n", file=file)
|
||||
for segment in transcript:
|
||||
print(
|
||||
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
class WriteTXT(ResultWriter):
|
||||
extension: str = "txt"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
for segment in result["segments"]:
|
||||
print(segment['text'].strip(), file=file, flush=True)
|
||||
|
||||
|
||||
def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
class WriteVTT(ResultWriter):
|
||||
extension: str = "vtt"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
print("WEBVTT\n", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(
|
||||
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
class WriteSRT(ResultWriter):
|
||||
extension: str = "srt"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
for i, segment in enumerate(result["segments"], start=1):
|
||||
# write srt lines
|
||||
print(
|
||||
f"{i}\n"
|
||||
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
class WriteTSV(ResultWriter):
|
||||
"""
|
||||
Write a transcript to a file in SRT format.
|
||||
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
|
||||
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
|
||||
|
||||
Example usage:
|
||||
from pathlib import Path
|
||||
from whisper.utils import write_srt
|
||||
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
# save SRT
|
||||
audio_basename = Path(audio_path).stem
|
||||
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
Using integer milliseconds as start and end times means there's no chance of interference from
|
||||
an environment setting a language encoding that causes the decimal in a floating point number
|
||||
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
|
||||
"""
|
||||
for i, segment in enumerate(transcript, start=1):
|
||||
# write srt lines
|
||||
print(
|
||||
f"{i}\n"
|
||||
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
extension: str = "tsv"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment['start']), file=file, end="\t")
|
||||
print(round(1000 * segment['end']), file=file, end="\t")
|
||||
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
|
||||
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
json.dump(result, file)
|
||||
|
||||
|
||||
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
|
||||
writers = {
|
||||
"txt": WriteTXT,
|
||||
"vtt": WriteVTT,
|
||||
"srt": WriteSRT,
|
||||
"tsv": WriteTSV,
|
||||
"json": WriteJSON,
|
||||
}
|
||||
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
|
||||
def write_all(result: dict, file: TextIO):
|
||||
for writer in all_writers:
|
||||
writer(result, file)
|
||||
|
||||
return write_all
|
||||
|
||||
return writers[output_format](output_dir)
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = "20230124"
|
||||
Reference in New Issue
Block a user