15 Commits

Author SHA1 Message Date
Jong Wook Kim 6dea21fd7f Release 20230314 2023-03-15 00:39:19 -07:00
Jong Wook Kim 79c43e4859 abort find_alignment on empty input (#1090) 2023-03-14 12:47:58 -07:00
Guillaume Klein 5f9ac653b7 Fix truncated words list when the replacement character is decoded (#1089) 2023-03-14 09:32:41 -07:00
Akash Mahajan ba88b8e1b3 fix github language stats getting dominated by jupyter notebook (#1076)
Co-authored-by: Akash Mahajan <akash.mahajan@microsoft.com>
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-03-14 00:07:09 -07:00
Guillaume Klein 671ac5a4ce Fix alignment between the segments and the list of words (#1087)
* Fix alignment between the segments and the list of words

* Ensure the word index does not overflow
2023-03-13 16:34:09 -07:00
Jong Wook Kim 839639a223 Use tiktoken (#1044)
* use tiktoken==0.3.0

* formatting

* tuple should be safer

* Update whisper/tokenizer.py

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>

* use tiktoken 0.3.1

* reflecting suggestions

* cleanup

* bypassing load_tiktoken_bpe to avoid blobfile dep

---------

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>
2023-03-13 02:34:16 -07:00
Jong Wook Kim ad3250a846 Release 20230308 2023-03-08 15:48:57 -08:00
Jong Wook Kim c4b50c0824 kwargs in decode() for convenience (#1061)
* kwargs in decode() for convenience

* formatting fix
2023-03-08 15:46:38 -08:00
Jong Wook Kim 38f2f4d99d fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060) 2023-03-08 15:34:07 -08:00
Jong Wook Kim aac47c9834 fix typo 2023-03-07 20:43:49 -08:00
Jong Wook Kim 26807ec6d3 Release 20230307 2023-03-07 20:36:29 -08:00
Jong Wook Kim 919a713499 attempt to fix the repetition/hallucination issue identified in #1046 (#1052)
* attempt to fix the repetition/hallucination issue identified in #1046

* zero-pad the audio instead of spectrogram

* formatting fix

* delete debug print
2023-03-07 20:08:45 -08:00
Jong Wook Kim 38e990d853 Use triton==2.0.0 (#1053) 2023-03-07 16:56:31 -08:00
Jong Wook Kim 924e1f8e06 Try installing triton only if linux & x86_64 (#1051) 2023-03-07 11:31:40 -08:00
Jong Wook Kim 4b0d5e58d0 Update setup.py 2023-03-07 04:47:46 -08:00
24 changed files with 100754 additions and 100188 deletions
+3
View File
@@ -0,0 +1,3 @@
# Override jupyter in Github language stats for more accurate estimate of repo code languages
# reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
*.ipynb linguist-generated
+34 -14
View File
@@ -1,25 +1,45 @@
# CHANGELOG # CHANGELOG
## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
* fix typo in CHANGELOG.md
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
* Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051))
* update setup.py to specify python >= 3.8 requirement
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306) ## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
* #1021: remove auxiliary audio extension * remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021))
* #1038: apply formatting with `black`, `isort`, and `flake8` * apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038))
* #869: word-level timestamps in `transcribe()` * word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869))
* #1033: Decoding improvements * Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033))
* #894: Update README.md * Update README.md ([#894](https://github.com/openai/whisper/pull/894))
* #914: Fix infinite loop caused by incorrect timestamp tokens prediction * Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914))
* #889: drop python 3.7 support * drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889))
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124) ## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
* #887: handle printing even if sys.stdout.buffer is not available * handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887))
* #228: Add TSV formatted output in transcript, using integer start/end time in milliseconds * Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228))
* #333: Added `--output_format` option * Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333))
* #864: Handle `XDG_CACHE_HOME` properly for `download_root` * Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864))
* #867: use stdout for printing transcription progress * use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867))
* #659: Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm * Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659))
* #859: print '?' if a letter can't be encoded using the system default encoding * print '?' if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859))
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117) ## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
-2
View File
@@ -2,6 +2,4 @@ include requirements.txt
include README.md include README.md
include LICENSE include LICENSE
include whisper/assets/* include whisper/assets/*
include whisper/assets/gpt2/*
include whisper/assets/multilingual/*
include whisper/normalizers/english.json include whisper/normalizers/english.json
+1 -1
View File
@@ -3,5 +3,5 @@ numpy
torch torch
tqdm tqdm
more-itertools more-itertools
transformers>=4.19.0 tiktoken==0.3.1
ffmpeg-python==0.2.0 ffmpeg-python==0.2.0
+4 -17
View File
@@ -1,4 +1,5 @@
import os import os
import platform
import sys import sys
import pkg_resources import pkg_resources
@@ -11,22 +12,8 @@ def read_version(fname="whisper/version.py"):
requirements = [] requirements = []
if sys.platform.startswith("linux"): if sys.platform.startswith("linux") and platform.machine() == "x86_64":
triton_requirement = "triton>=2.0.0.dev20221202" requirements.append("triton==2.0.0")
try:
import re
import subprocess
version_line = (
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
)
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
if (int(major), int(minor)) < (11, 4):
# the last version supporting CUDA < 11.4
triton_requirement = "triton==2.0.0.dev20221011"
except (IndexError, OSError, subprocess.SubprocessError):
pass
requirements.append(triton_requirement)
setup( setup(
name="openai-whisper", name="openai-whisper",
@@ -36,7 +23,7 @@ setup(
long_description=open("README.md", encoding="utf-8").read(), long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
readme="README.md", readme="README.md",
python_requires=">=3.7", python_requires=">=3.8",
author="OpenAI", author="OpenAI",
url="https://github.com/openai/whisper", url="https://github.com/openai/whisper",
license="MIT", license="MIT",
+10
View File
@@ -12,3 +12,13 @@ def test_tokenizer():
assert gpt2_tokenizer.decode(gpt2_tokens) == text assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens) assert len(gpt2_tokens) > len(multilingual_tokens)
def test_split_on_unicode():
multilingual_tokenizer = get_tokenizer(multilingual=True)
tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
assert words == [" elle", " est", " l", "'", "", "é", "rit", "oire"]
assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
+7 -1
View File
@@ -4,6 +4,7 @@ import pytest
import torch import torch
import whisper import whisper
from whisper.tokenizer import get_tokenizer
@pytest.mark.parametrize("model_name", whisper.available_models()) @pytest.mark.parametrize("model_name", whisper.available_models())
@@ -17,12 +18,18 @@ def test_transcribe(model_name: str):
audio_path, language=language, temperature=0.0, word_timestamps=True audio_path, language=language, temperature=0.0, word_timestamps=True
) )
assert result["language"] == "en" assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])
transcription = result["text"].lower() transcription = result["text"].lower()
assert "my fellow americans" in transcription assert "my fellow americans" in transcription
assert "your country" in transcription assert "your country" in transcription
assert "do for you" in transcription assert "do for you" in transcription
tokenizer = get_tokenizer(model.is_multilingual)
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
assert tokenizer.decode(all_tokens) == result["text"]
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
timing_checked = False timing_checked = False
for segment in result["segments"]: for segment in result["segments"]:
for timing in segment["words"]: for timing in segment["words"]:
@@ -30,7 +37,6 @@ def test_transcribe(model_name: str):
if timing["word"].strip(" ,") == "Americans": if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8 assert timing["start"] <= 1.8
assert timing["end"] >= 1.8 assert timing["end"] >= 1.8
print(timing)
timing_checked = True timing_checked = True
assert timing_checked assert timing_checked
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
@@ -1 +0,0 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"<|endoftext|>": 50257}
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
@@ -1 +0,0 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
File diff suppressed because one or more lines are too long
+17 -6
View File
@@ -1,6 +1,6 @@
import os import os
from functools import lru_cache from functools import lru_cache
from typing import Union from typing import Optional, Union
import ffmpeg import ffmpeg
import numpy as np import numpy as np
@@ -15,10 +15,8 @@ N_FFT = 400
N_MELS = 80 N_MELS = 80
HOP_LENGTH = 160 HOP_LENGTH = 160
CHUNK_LENGTH = 30 CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div( N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES, HOP_LENGTH
) # 3000: number of frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
def log_mel_spectrogram( def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = N_MELS,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
): ):
""" """
Compute the log-Mel spectrogram of Compute the log-Mel spectrogram of
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
n_mels: int n_mels: int
The number of Mel-frequency filters, only 80 is supported The number of Mel-frequency filters, only 80 is supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns Returns
------- -------
torch.Tensor, shape = (80, n_frames) torch.Tensor, shape = (80, n_frames)
@@ -123,6 +130,10 @@ def log_mel_spectrogram(
audio = load_audio(audio) audio = load_audio(audio)
audio = torch.from_numpy(audio) audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device) window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2 magnitudes = stft[..., :-1].abs() ** 2
+8 -2
View File
@@ -1,4 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
@@ -778,7 +778,10 @@ class DecodingTask:
@torch.no_grad() @torch.no_grad()
def decode( def decode(
model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions() model: "Whisper",
mel: Tensor,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]: ) -> Union[DecodingResult, List[DecodingResult]]:
""" """
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -802,6 +805,9 @@ def decode(
if single := mel.ndim == 2: if single := mel.ndim == 2:
mel = mel.unsqueeze(0) mel = mel.unsqueeze(0)
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel) result = DecodingTask(model, options).run(mel)
return result[0] if single else result return result[0] if single else result
+26 -15
View File
@@ -1,3 +1,4 @@
import itertools
import subprocess import subprocess
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
@@ -169,6 +170,9 @@ def find_alignment(
medfilt_width: int = 7, medfilt_width: int = 7,
qk_scale: float = 1.0, qk_scale: float = 1.0,
) -> List[WordTiming]: ) -> List[WordTiming]:
if len(text_tokens) == 0:
return []
tokens = torch.tensor( tokens = torch.tensor(
[ [
*tokenizer.sot_sequence, *tokenizer.sot_sequence,
@@ -290,34 +294,41 @@ def add_word_timestamps(
if len(segments) == 0: if len(segments) == 0:
return return
text_tokens = [t for segment in segments for t in segment["tokens"]] text_tokens_per_segment = [
[token for token in segment["tokens"] if token < tokenizer.eot]
for segment in segments
]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations) merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
segment_lengths = [len(s["tokens"]) for s in segments] word_index = 0
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
for segment in segments: for segment, text_tokens in zip(segments, text_tokens_per_segment):
segment["words"] = [] saved_tokens = 0
words = []
while word_index < len(alignment) and saved_tokens < len(text_tokens):
timing = alignment[word_index]
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
for i, timing in enumerate(alignment):
if timing.word: if timing.word:
segment = segments[token_sources[word_boundaries[i]]] words.append(
start = round(time_offset + timing.start, 2)
end = round(time_offset + timing.end, 2)
segment["words"].append(
dict( dict(
word=timing.word, word=timing.word,
start=start, start=round(time_offset + timing.start, 2),
end=end, end=round(time_offset + timing.end, 2),
probability=timing.probability, probability=timing.probability,
) )
) )
for segment in segments: saved_tokens += len(timing.tokens)
if len(words := segment["words"]) > 0: word_index += 1
if len(words) > 0:
# adjust the segment-level timestamps based on the word-level timestamps # adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"] segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"] segment["end"] = words[-1]["end"]
segment["words"] = words
+92 -85
View File
@@ -1,12 +1,12 @@
import base64
import os import os
import string import string
from dataclasses import dataclass from dataclasses import dataclass, field
from functools import cached_property, lru_cache from functools import cached_property, lru_cache
from typing import List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple
import numpy as np import tiktoken
import torch from tiktoken_ext.openai_public import gpt2
from transformers import GPT2TokenizerFast
LANGUAGES = { LANGUAGES = {
"en": "english", "en": "english",
@@ -127,74 +127,84 @@ TO_LANGUAGE_CODE = {
} }
@dataclass(frozen=True) @dataclass
class Tokenizer: class Tokenizer:
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" """A thin wrapper around `tiktoken` providing quick access to special tokens"""
tokenizer: "GPT2TokenizerFast" encoding: tiktoken.Encoding
language: Optional[str] language: Optional[str] = None
sot_sequence: Tuple[int] task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
self.sot_sequence = tuple(sot_sequence)
def encode(self, text, **kwargs): def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs) return self.encoding.encode(text, **kwargs)
def decode( def decode(self, token_ids: List[int], **kwargs) -> str:
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs token_ids = [t for t in token_ids if t < self.timestamp_begin]
): return self.encoding.decode(token_ids, **kwargs)
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str: def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
""" """
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
""" """
outputs = [[]] return self.encoding.decode(token_ids, **kwargs)
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
return "".join(
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
)
@cached_property @cached_property
def eot(self) -> int: def eot(self) -> int:
return self.tokenizer.eos_token_id return self.encoding.eot_token
@cached_property @cached_property
def transcribe(self) -> int: def transcribe(self) -> int:
return self._get_single_token_id("<|transcribe|>") return self.special_tokens["<|transcribe|>"]
@cached_property @cached_property
def translate(self) -> int: def translate(self) -> int:
return self._get_single_token_id("<|translate|>") return self.special_tokens["<|translate|>"]
@cached_property @cached_property
def sot(self) -> int: def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>") return self.special_tokens["<|startoftranscript|>"]
@cached_property @cached_property
def sot_lm(self) -> int: def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>") return self.special_tokens["<|startoflm|>"]
@cached_property @cached_property
def sot_prev(self) -> int: def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>") return self.special_tokens["<|startofprev|>"]
@cached_property @cached_property
def no_speech(self) -> int: def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>") return self.special_tokens["<|nospeech|>"]
@cached_property @cached_property
def no_timestamps(self) -> int: def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>") return self.special_tokens["<|notimestamps|>"]
@cached_property @cached_property
def timestamp_begin(self) -> int: def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1 return self.special_tokens["<|0.00|>"]
@cached_property @cached_property
def language_token(self) -> int: def language_token(self) -> int:
@@ -202,25 +212,15 @@ class Tokenizer:
if self.language is None: if self.language is None:
raise ValueError("This tokenizer does not have language token configured") raise ValueError("This tokenizer does not have language token configured")
additional_tokens = dict( if token := self.special_tokens.get(f"<|{self.language}|>", None):
zip( return token
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
)
)
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
raise KeyError(f"Language {self.language} not found in tokenizer.") raise KeyError(f"Language {self.language} not found in tokenizer.")
@cached_property @cached_property
def all_language_tokens(self) -> Tuple[int]: def all_language_tokens(self) -> Tuple[int]:
result = [] result = []
for token, token_id in zip( for token, token_id in self.special_tokens.items():
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
):
if token.strip("<|>") in LANGUAGES: if token.strip("<|>") in LANGUAGES:
result.append(token_id) result.append(token_id)
return tuple(result) return tuple(result)
@@ -258,22 +258,17 @@ class Tokenizer:
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous): for symbol in symbols + list(miscellaneous):
for tokens in [ for tokens in [
self.tokenizer.encode(symbol), self.encoding.encode(symbol),
self.tokenizer.encode(" " + symbol), self.encoding.encode(" " + symbol),
]: ]:
if len(tokens) == 1 or symbol in miscellaneous: if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0]) result.add(tokens[0])
return tuple(sorted(result)) return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text)
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
def split_to_word_tokens(self, tokens: List[int]): def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my"}: if self.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words # These languages don't typically use spaces, so it is difficult to split words
@@ -284,17 +279,27 @@ class Tokenizer:
return self.split_tokens_on_spaces(tokens) return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]): def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = [] words = []
word_tokens = [] word_tokens = []
current_tokens = [] current_tokens = []
unicode_offset = 0
for token in tokens: for token in tokens:
current_tokens.append(token) current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens) decoded = self.decode_with_timestamps(current_tokens)
if "\ufffd" not in decoded:
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
):
words.append(decoded) words.append(decoded)
word_tokens.append(current_tokens) word_tokens.append(current_tokens)
current_tokens = [] current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens return words, word_tokens
@@ -318,12 +323,17 @@ class Tokenizer:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"): def get_encoding(name: str = "gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false" vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
path = os.path.join(os.path.dirname(__file__), "assets", name) ranks = {
tokenizer = GPT2TokenizerFast.from_pretrained(path) base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [ specials = [
"<|endoftext|>",
"<|startoftranscript|>", "<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()], *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>", "<|translate|>",
@@ -332,18 +342,28 @@ def build_tokenizer(name: str = "gpt2"):
"<|startofprev|>", "<|startofprev|>",
"<|nospeech|>", "<|nospeech|>",
"<|notimestamps|>", "<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
] ]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) for token in specials:
return tokenizer special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=gpt2()["pat_str"],
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_tokenizer( def get_tokenizer(
multilingual: bool, multilingual: bool,
*, *,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
language: Optional[str] = None, language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer: ) -> Tokenizer:
if language is not None: if language is not None:
language = language.lower() language = language.lower()
@@ -354,27 +374,14 @@ def get_tokenizer(
raise ValueError(f"Unsupported language: {language}") raise ValueError(f"Unsupported language: {language}")
if multilingual: if multilingual:
tokenizer_name = "multilingual" encoding_name = "multilingual"
task = task or "transcribe"
language = language or "en" language = language or "en"
task = task or "transcribe"
else: else:
tokenizer_name = "gpt2" encoding_name = "gpt2"
task = None
language = None language = None
task = None
tokenizer = build_tokenizer(name=tokenizer_name) encoding = get_encoding(name=encoding_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
langs = tuple(LANGUAGES.keys()) return Tokenizer(encoding=encoding, language=language, task=task)
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
)
+33 -31
View File
@@ -11,6 +11,7 @@ from .audio import (
FRAMES_PER_SECOND, FRAMES_PER_SECOND,
HOP_LENGTH, HOP_LENGTH,
N_FRAMES, N_FRAMES,
N_SAMPLES,
SAMPLE_RATE, SAMPLE_RATE,
log_mel_spectrogram, log_mel_spectrogram,
pad_or_trim, pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
if dtype == torch.float32: if dtype == torch.float32:
decode_options["fp16"] = False decode_options["fp16"] = False
mel = log_mel_spectrogram(audio) # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
if decode_options.get("language", None) is None: if decode_options.get("language", None) is None:
if not model.is_multilingual: if not model.is_multilingual:
@@ -197,14 +200,14 @@ def transcribe(
def new_segment( def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
): ):
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot] tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return { return {
"id": len(all_segments),
"seek": seek, "seek": seek,
"start": start, "start": start,
"end": end, "end": end,
"text": tokenizer.decode(text_tokens), "text": tokenizer.decode(text_tokens),
"tokens": text_tokens, "tokens": tokens,
"temperature": result.temperature, "temperature": result.temperature,
"avg_logprob": result.avg_logprob, "avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio, "compression_ratio": result.compression_ratio,
@@ -212,14 +215,13 @@ def transcribe(
} }
# show the progress bar when verbose is False (if True, transcribed text will be printed) # show the progress bar when verbose is False (if True, transcribed text will be printed)
num_frames = mel.shape[-1]
with tqdm.tqdm( with tqdm.tqdm(
total=num_frames, unit="frames", disable=verbose is not False total=content_frames, unit="frames", disable=verbose is not False
) as pbar: ) as pbar:
while seek < num_frames: while seek < content_frames:
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek:] mel_segment = mel[:, seek : seek + N_FRAMES]
segment_size = min(mel_segment.shape[-1], N_FRAMES) segment_size = min(N_FRAMES, content_frames - seek)
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
@@ -243,23 +245,20 @@ def transcribe(
previous_seek = seek previous_seek = seek
current_segments = [] current_segments = []
current_tokens = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
0
].add_(1) consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
if ( consecutive.add_(1)
len(consecutive) > 0 if len(consecutive) > 0:
): # if the output contains two consecutive timestamp tokens # if the output contains two consecutive timestamp tokens
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [ slices = consecutive.tolist()
False, if single_timestamp_ending:
True, slices.append(len(tokens))
]:
consecutive = consecutive.tolist() + [len(tokens)]
last_slice = 0 last_slice = 0
for current_slice in consecutive: for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice] sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = ( start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin sliced_tokens[0].item() - tokenizer.timestamp_begin
@@ -275,10 +274,9 @@ def transcribe(
result=result, result=result,
) )
) )
current_tokens.append(sliced_tokens.tolist())
last_slice = current_slice last_slice = current_slice
if ended_with_single_timestamp: if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp. # single timestamp at the end means no speech after the last timestamp.
seek += segment_size seek += segment_size
else: else:
@@ -287,7 +285,6 @@ def transcribe(
tokens[last_slice - 1].item() - tokenizer.timestamp_begin tokens[last_slice - 1].item() - tokenizer.timestamp_begin
) )
seek += last_timestamp_pos * input_stride seek += last_timestamp_pos * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else: else:
duration = segment_duration duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()] timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -309,7 +306,6 @@ def transcribe(
result=result, result=result,
) )
) )
current_tokens.append(tokens.tolist())
seek += segment_size seek += segment_size
if not condition_on_previous_text or result.temperature > 0.5: if not condition_on_previous_text or result.temperature > 0.5:
@@ -329,7 +325,7 @@ def transcribe(
word_end_timestamps = [ word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"] w["end"] for s in current_segments for w in s["words"]
] ]
if len(consecutive) > 0 and len(word_end_timestamps) > 0: if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round( seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
) )
@@ -348,15 +344,21 @@ def transcribe(
segment["text"] = "" segment["text"] = ""
segment["tokens"] = [] segment["tokens"] = []
segment["words"] = [] segment["words"] = []
current_tokens[i] = []
all_segments.extend(current_segments) all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend( all_tokens.extend(
[token for segment in current_tokens for token in segment] [token for segment in current_segments for token in segment["tokens"]]
) )
# update progress bar # update progress bar
pbar.update(min(num_frames, seek) - previous_seek) pbar.update(min(content_frames, seek) - previous_seek)
return dict( return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
+1 -1
View File
@@ -1 +1 @@
__version__ = "20230306" __version__ = "20230314"