37 Commits

Author SHA1 Message Date
Jong Wook Kim 55f690af79 Release 20230124 2023-01-24 11:11:08 -08:00
Jong Wook Kim 7f1ef223ab handle printing even if sys.stdout.buffer is not available (#887) 2023-01-24 10:12:04 -08:00
Niels Mayer f5bfe004ec Add TSV formatted output in transcript, using integer start/end times in milliseconds. (#228)
* Add CSV format output in transcript, containing lines of characters formatted like: <startTime-in-integer-milliseconds>, <endTime-in-integer-milliseconds>, <transcript-including-commas>

* for easier reading by spreadsheets importing CSV, the third

column of the CSV file is delimited by quotes, and any quote
characters that might be in the transcript (which would interfere with
parsing the third column as a string) are converted to "''".

* fix syntax error

* docstring edit

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-22 00:27:17 -08:00
Aaryan YVS da600abd2b Added --output_format option (#333)
* Added --output option

--output option will help select the output files that will be generated.

Corrected the logic, which wrongly shows progress bar when verbose is set to False

* Changed output_files variable

* Changed back the tqdm verbose

* refactor output format handling

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-21 23:58:38 -08:00
zer0-x 9f7aba6099 Handle XDG_CACHE_HOME properly for download_root (#864)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-01-21 01:09:39 -08:00
Jong Wook Kim 12e1089462 use stdout for printing transcription progress (#867) 2023-01-20 00:54:05 -08:00
Markus Hennerbichler ea1c266709 Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm (#659)
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-18 10:41:11 -08:00
Jong Wook Kim 8135a7c31c verbose outputs from pytest 2023-01-18 10:30:18 -08:00
Jong Wook Kim 9d646db9d8 print '?' if a letter can't be encoded using the system default encoding (#859) 2023-01-17 23:28:36 -08:00
Jong Wook Kim 37a4f1be6d Release 20230117 2023-01-17 16:08:28 -08:00
Romain Beaumont b9f9b433ae Add github action to automatically push to pypi on Release x.y.z commit (#681)
* Add github action to automatically push to pypi on Release x.y.z commit

* some housekeeping for pypi upload

* add version.py

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-17 15:50:26 -08:00
Umar Farooqi f0083e7eb2 Use ndimage.median_filter instead of signal.medfilter (#812)
For a 30s long audio file which didn't have any silence, ndimage.median_filter took 7s where signa.medfilter took 30s.

Co-authored-by: Umar Farooqi <umar@paystash.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-17 14:43:05 -08:00
Jong Wook Kim a84191faae rename GitHub workflow 2023-01-17 13:54:40 -08:00
Jong Wook Kim b1d213c0c7 allow test_transcribe to run on CPU when CUDA is not available 2023-01-17 13:43:36 -08:00
Jong Wook Kim 493dfffa37 add github action to run pytest 2023-01-17 13:38:33 -08:00
Mikko Vedru 0f39c89d92 Update README.md (#804) 2023-01-16 23:46:42 -08:00
Markus Hennerbichler 6df3ea1fb5 Support batch-dimension in log_mel_spectogram (#839) 2023-01-16 23:46:15 -08:00
adamreis 70861c7ce3 Fix tiny transcribe() docstring typo (#857)
s/successfully/successively, which I believe was the intent.
2023-01-16 22:42:01 -08:00
Jong Wook Kim f82bc59f5e torch.concatenate -> torch.cat for compatibility 2023-01-10 10:53:18 -08:00
Jong Wook Kim 28769fcfe5 word-level timestamps in Multilingual_ASR notebook 2022-12-31 10:03:42 -07:00
Jong Wook Kim 53807677fe MultiHeadAttention to return qk as well 2022-12-30 01:53:57 -07:00
Jong Wook Kim 9323b2526c Revert "saving the qk matrix in the attention module for convenience"
This reverts commit 68e44bd83c.
2022-12-29 23:53:31 -07:00
Jong Wook Kim 68e44bd83c saving the qk matrix in the attention module for convenience 2022-12-29 23:02:52 -07:00
Jong Wook Kim 0b5dcfdef7 large-v2 figure and arxiv url update 2022-12-09 00:12:39 -05:00
altryne b9265e5796 Update Hebrew language code to he per IANA registry (#401)
* Update Hebrew language code to he per IANA registry

Per [IANA registry](https://www.iana.org/assignments/language-subtag-registry/language-subtag-registry), `iw` was deprecated as the code for Hebrew in 1989 and the preferred code is `he`

The correct subtag: 
```
%%
Type: language
Subtag: he
Description: Hebrew
Added: 2005-10-16
Suppress-Script: Hebr
%%
``` 
And the deprecation
```
%%
Type: language
Subtag: iw
Description: Hebrew
Added: 2005-10-16
Deprecated: 1989-01-01
Preferred-Value: he
Suppress-Script: Hebr
%%
```

* Update hebrew ISO code to he

Per discussion, it's ok to make this change without backwards compatibility
2022-12-07 13:45:31 -05:00
Paul Harter fd8f80c8b8 Explicitly closing model file after reading it (#630) 2022-12-06 12:07:19 -05:00
Jong Wook Kim 4179ed2475 add large-v2 model
- The "large-v2" model is trained for more epochs with regularization and shows improved performance compared to the previous large.
- It has the same architecture as the original large model.
- When `load_model("large")` is called, the "large-v2" model will be loaded.
- We will soon update the paper regarding this new model.
2022-12-05 11:07:14 -05:00
jumon ec1b34bb90 fix compression ratio function (#561) 2022-12-04 17:27:42 -06:00
Jong Wook Kim eff383b27b invoking __call__ instead of forward() 2022-11-16 04:18:50 -08:00
Jong Wook Kim 02aa851a49 fix to return only the text token ids 2022-11-15 16:25:11 -08:00
jumon 76148a56c5 suppress generating non-timestamp tokens at the beginning (#532) 2022-11-15 11:44:36 -08:00
Vicki Anand 9f70a352f9 Fix attention caching to make it actually work (#370) 2022-10-19 16:44:03 -07:00
Sumana Harihareswara 7f3e408e09 Add package metadata to setup.py (#315)
Add project summary, license, etc. for display with
"pip show" and similar Python package distribution tools.
2022-10-17 13:51:16 -07:00
Michael Monashev f680570016 Fix bug (#305)
Fix bug: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
2022-10-17 11:38:20 -07:00
Jong Wook Kim d18e9ea5dd transcribe() on English-only model won't complain when language="en" is not given 2022-10-09 02:40:12 -07:00
David Marx 82725cea9c infer download_root from XDG_CACHE_HOME if avail (#257) 2022-10-09 02:14:03 -07:00
eudoxos 35713c66e0 Add --threads option to transcribe (#278)
* Add --threads option to transcribe

Torch on CPU uses by default number_of_cores/2. This option allows to
override this default.

* Update transcribe.py

Co-authored-by: Jong Wook Kim <ilikekjw@gmail.com>
2022-10-09 02:11:15 -07:00
19 changed files with 4782 additions and 2756 deletions
+37
View File
@@ -0,0 +1,37 @@
name: Release
on:
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-ecosystem/action-regex-match@v2
id: regex-match
with:
text: ${{ github.event.head_commit.message }}
regex: '^Release ([^ ]+)'
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Release
if: ${{ steps.regex-match.outputs.match != '' }}
uses: softprops/action-gh-release@v1
with:
tag_name: v${{ steps.regex-match.outputs.group1 }}
- name: Build and publish
if: ${{ steps.regex-match.outputs.match != '' }}
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python setup.py sdist
twine upload dist/*
+26
View File
@@ -0,0 +1,26 @@
name: test
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
whisper-test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
pytorch-version: [1.10.2, 1.13.1]
exclude:
- python-version: '3.10'
pytorch-version: 1.10.2
steps:
- uses: conda-incubator/setup-miniconda@v2
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
- uses: actions/checkout@v2
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install pytest
- run: pip install .
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
+3
View File
@@ -1,3 +1,6 @@
include requirements.txt
include README.md
include LICENSE
include whisper/assets/*
include whisper/assets/gpt2/*
include whisper/assets/multilingual/*
+16 -8
View File
@@ -1,8 +1,8 @@
# Whisper
[[Blog]](https://openai.com/blog/whisper)
[[Paper]](https://cdn.openai.com/papers/whisper.pdf)
[[Model card]](model-card.md)
[[Paper]](https://arxiv.org/abs/2212.04356)
[[Model card]](https://github.com/openai/whisper/blob/main/model-card.md)
[[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb)
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
@@ -10,17 +10,25 @@ Whisper is a general-purpose speech recognition model. It is trained on a large
## Approach
![Approach](approach.png)
![Approach](https://raw.githubusercontent.com/openai/whisper/main/approach.png)
A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. All of these tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing for a single model to replace many different stages of a traditional speech processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
## Setup
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. The following command will pull and install the latest commit from this repository, along with its Python dependencies
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.7 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
pip install -U openai-whisper
Alternatively, the following command will pull and install the latest commit from this repository, along with its Python dependencies:
pip install git+https://github.com/openai/whisper.git
To update the package to the latest version of this repository, please run:
pip install --upgrade --no-deps --force-reinstall git+https://github.com/openai/whisper.git
It also requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system, which is available from most package managers:
```bash
@@ -62,9 +70,9 @@ There are five model sizes, four with English-only versions, offering speed and
For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
Whisper's performance varies widely depending on the language. The figure below shows a WER breakdown by languages of Fleurs dataset, using the `large` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://cdn.openai.com/papers/whisper.pdf).
Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of Fleurs dataset, using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D in [the paper](https://arxiv.org/abs/2212.04356). The smaller is better.
![WER breakdown by language](language-breakdown.svg)
![WER breakdown by language](https://raw.githubusercontent.com/openai/whisper/main/language-breakdown.svg)
@@ -86,7 +94,7 @@ Run the following to view all available options:
whisper --help
See [tokenizer.py](whisper/tokenizer.py) for the list of all available languages.
See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
## Python usage
@@ -136,4 +144,4 @@ Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussion
## License
The code and the model weights of Whisper are released under the MIT License. See [LICENSE](LICENSE) for further details.
The code and the model weights of Whisper are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
+1501 -2532
View File
File diff suppressed because it is too large Load Diff

Before

Width:  |  Height:  |  Size: 134 KiB

After

Width:  |  Height:  |  Size: 100 KiB

+8 -6
View File
@@ -2,7 +2,7 @@
This is the official codebase for running the automatic speech recognition (ASR) models (Whisper models) trained and released by OpenAI.
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://cdn.openai.com/papers/whisper.pdf).
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about the automatic speech recognition model. More information on how these models were trained and evaluated can be found [in the paper](https://arxiv.org/abs/2212.04356).
## Model Details
@@ -17,10 +17,12 @@ The Whisper models are trained for speech recognition and translation tasks, cap
| medium | 769 M | ✓ | ✓ |
| large | 1550 M | | ✓ |
In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).
### Release date
September 2022
September 2022 (original series) and December 2022 (`large-v2`)
### Model type
@@ -28,7 +30,7 @@ Sequence-to-sequence ASR (automatic speech recognition) and speech translation m
### Paper & samples
[Paper](https://cdn.openai.com/papers/whisper.pdf) / [Blog](https://openai.com/blog/whisper)
[Paper](https://arxiv.org/abs/2212.04356) / [Blog](https://openai.com/blog/whisper)
## Model Use
@@ -46,7 +48,7 @@ In particular, we caution against using Whisper models to transcribe recordings
The models are trained on 680,000 hours of audio and the corresponding transcripts collected from the internet. 65% of this data (or 438,000 hours) represents English-language audio and matched English transcripts, roughly 18% (or 126,000 hours) represents non-English audio and English transcripts, while the final 17% (or 117,000 hours) represents non-English audio and the corresponding transcript. This non-English data represents 98 different languages.
As discussed in [the accompanying paper](https://cdn.openai.com/papers/whisper.pdf), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
As discussed in [the accompanying paper](https://arxiv.org/abs/2212.04356), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
## Performance and Limitations
@@ -55,9 +57,9 @@ Our studies show that, over many existing ASR systems, the models exhibit improv
However, because the models are trained in a weakly supervised manner using large-scale noisy data, the predictions may include texts that are not actually spoken in the audio input (i.e. hallucination). We hypothesize that this happens because, given their general knowledge of language, the models combine trying to predict the next word in audio with trying to transcribe the audio itself.
Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://cdn.openai.com/papers/whisper.pdf).
Our models perform unevenly across languages, and we observe lower accuracy on low-resource and/or low-discoverability languages or languages where we have less training data. The models also exhibit disparate performance on different accents and dialects of particular languages, which may include higher word error rate across speakers of different genders, races, ages, or other demographic criteria. Our full evaluation results are presented in [the paper accompanying this release](https://arxiv.org/abs/2212.04356).
In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://cdn.openai.com/papers/whisper.pdf). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages.
In addition, the sequence-to-sequence architecture of the model makes it prone to generating repetitive texts, which can be mitigated to some degree by beam search and temperature scheduling but not perfectly. Further analysis on these limitations are provided in [the paper](https://arxiv.org/abs/2212.04356). It is likely that this behavior and hallucinations may be worse on lower-resource and/or lower-discoverability languages.
## Broader Implications
File diff suppressed because one or more lines are too long
+18 -6
View File
@@ -3,12 +3,24 @@ import os
import pkg_resources
from setuptools import setup, find_packages
def read_version(fname="whisper/version.py"):
exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
return locals()["__version__"]
setup(
name="whisper",
name="openai-whisper",
py_modules=["whisper"],
version="1.0",
description="",
version=read_version(),
description="Robust Speech Recognition via Large-Scale Weak Supervision",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
readme="README.md",
python_requires=">=3.7",
author="OpenAI",
url="https://github.com/openai/whisper",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=[
str(r)
@@ -16,9 +28,9 @@ setup(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
],
entry_points = {
'console_scripts': ['whisper=whisper.transcribe:cli'],
entry_points={
"console_scripts": ["whisper=whisper.transcribe:cli"],
},
include_package_data=True,
extras_require={'dev': ['pytest']},
extras_require={"dev": ["pytest"]},
)
+1
View File
@@ -84,6 +84,7 @@ def test_text_normalizer():
assert std("he's like") == "he is like"
assert std("she's been like") == "she has been like"
assert std("10km") == "10 km"
assert std("10mm") == "10 mm"
assert std("RC232") == "rc 232"
assert (
+4 -2
View File
@@ -1,13 +1,15 @@
import os
import pytest
import torch
import whisper
@pytest.mark.parametrize('model_name', whisper.available_models())
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
model = whisper.load_model(model_name).cuda()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model(model_name).to(device)
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
language = "en" if model_name.endswith(".en") else None
+15 -3
View File
@@ -12,6 +12,7 @@ from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe
from .version import __version__
_MODELS = {
@@ -23,7 +24,9 @@ _MODELS = {
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
@@ -37,7 +40,8 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
model_bytes = open(download_target, "rb").read()
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
@@ -90,7 +94,15 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
download_root = os.path.join(
os.getenv(
"XDG_CACHE_HOME",
os.path.join(
os.path.expanduser("~"), ".cache"
)
),
"whisper"
)
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+2 -2
View File
@@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length))
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
@@ -113,7 +113,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[:, :-1].abs() ** 2
magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
+8 -4
View File
@@ -423,10 +423,14 @@ class ApplyTimestampRules(LogitFilter):
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
# apply the `max_initial_timestamp` option
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits[:, last_allowed + 1 :] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
+12 -11
View File
@@ -72,18 +72,18 @@ class MultiHeadAttention(nn.Module):
):
q = self.query(x)
if kv_cache is None or xa is None:
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache.get(self.key, self.key(xa))
v = kv_cache.get(self.value, self.value(xa))
k = kv_cache[self.key]
v = kv_cache[self.value]
wv = self.qkv_attention(q, k, v, mask)
return self.out(wv)
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
n_batch, n_ctx, n_state = q.shape
@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module):
@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
@@ -214,10 +215,10 @@ class Whisper(nn.Module):
)
def embed_audio(self, mel: torch.Tensor):
return self.encoder.forward(mel)
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder.forward(tokens, audio_features)
return self.decoder(tokens, audio_features)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
-1
View File
@@ -1737,6 +1737,5 @@
"yoghurt": "yogurt",
"yoghurts": "yogurts",
"mhm": "hmm",
"mm": "hmm",
"mmm": "hmm"
}
+1 -1
View File
@@ -28,7 +28,7 @@ LANGUAGES = {
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
+28 -28
View File
@@ -1,7 +1,7 @@
import argparse
import os
import warnings
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
from typing import Optional, Tuple, Union, TYPE_CHECKING
import numpy as np
import torch
@@ -10,7 +10,7 @@ import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
if TYPE_CHECKING:
from .model import Whisper
@@ -44,7 +44,7 @@ def transcribe(
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
@@ -84,13 +84,16 @@ def transcribe(
mel = log_mel_spectrogram(audio)
if decode_options.get("language", None) is None:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
language = decode_options["language"]
task = decode_options.get("task", "transcribe")
@@ -154,7 +157,7 @@ def transcribe(
"start": start,
"end": end,
"text": text,
"tokens": result.tokens,
"tokens": text_tokens.tolist(),
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
@@ -162,7 +165,7 @@ def transcribe(
}
)
if verbose:
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
@@ -252,6 +255,7 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -261,7 +265,7 @@ def cli():
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
@@ -272,16 +276,19 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
args["language"] = "en"
temperature = args.pop("temperature")
@@ -291,25 +298,18 @@ def cli():
else:
temperature = [temperature]
threads = args.pop("threads")
if threads > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
audio_basename = os.path.basename(audio_path)
# save TXT
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
write_txt(result["segments"], file=txt)
# save VTT
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt)
# save SRT
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
writer(result, audio_path)
if __name__ == '__main__':
+112 -36
View File
@@ -1,5 +1,20 @@
import json
import os
import sys
import zlib
from typing import Iterator, TextIO
from typing import Callable, TextIO
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with an '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
def exact_div(x, y):
@@ -24,7 +39,8 @@ def optional_float(string):
def compression_ratio(text) -> float:
return len(text) / len(zlib.compress(text.encode("utf-8")))
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
@@ -44,44 +60,104 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(segment['text'].strip(), file=file, flush=True)
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str):
audio_basename = os.path.basename(audio_path)
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
def write_result(self, result: dict, file: TextIO):
raise NotImplementedError
def write_vtt(transcript: Iterator[dict], file: TextIO):
print("WEBVTT\n", file=file)
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(self, result: dict, file: TextIO):
for segment in result["segments"]:
print(segment['text'].strip(), file=file, flush=True)
def write_srt(transcript: Iterator[dict], file: TextIO):
class WriteVTT(ResultWriter):
extension: str = "vtt"
def write_result(self, result: dict, file: TextIO):
print("WEBVTT\n", file=file)
for segment in result["segments"]:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class WriteSRT(ResultWriter):
extension: str = "srt"
def write_result(self, result: dict, file: TextIO):
for i, segment in enumerate(result["segments"], start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in SRT format.
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Example usage:
from pathlib import Path
from whisper.utils import write_srt
result = transcribe(model, audio_path, temperature=temperature, **args)
# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment['start']), file=file, end="\t")
print(round(1000 * segment['end']), file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(self, result: dict, file: TextIO):
json.dump(result, file)
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO):
for writer in all_writers:
writer(result, file)
return write_all
return writers[output_format](output_dir)
+1
View File
@@ -0,0 +1 @@
__version__ = "20230124"