6 Commits

Author SHA1 Message Date
Jong Wook Kim 6dea21fd7f Release 20230314 2023-03-15 00:39:19 -07:00
Jong Wook Kim 79c43e4859 abort find_alignment on empty input (#1090) 2023-03-14 12:47:58 -07:00
Guillaume Klein 5f9ac653b7 Fix truncated words list when the replacement character is decoded (#1089) 2023-03-14 09:32:41 -07:00
Akash Mahajan ba88b8e1b3 fix github language stats getting dominated by jupyter notebook (#1076)
Co-authored-by: Akash Mahajan <akash.mahajan@microsoft.com>
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-03-14 00:07:09 -07:00
Guillaume Klein 671ac5a4ce Fix alignment between the segments and the list of words (#1087)
* Fix alignment between the segments and the list of words

* Ensure the word index does not overflow
2023-03-13 16:34:09 -07:00
Jong Wook Kim 839639a223 Use tiktoken (#1044)
* use tiktoken==0.3.0

* formatting

* tuple should be safer

* Update whisper/tokenizer.py

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>

* use tiktoken 0.3.1

* reflecting suggestions

* cleanup

* bypassing load_tiktoken_bpe to avoid blobfile dep

---------

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>
2023-03-13 02:34:16 -07:00
20 changed files with 100665 additions and 100118 deletions
+3
View File
@@ -0,0 +1,3 @@
# Override jupyter in Github language stats for more accurate estimate of repo code languages
# reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
*.ipynb linguist-generated
+8
View File
@@ -1,5 +1,13 @@
# CHANGELOG
## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
-2
View File
@@ -2,6 +2,4 @@ include requirements.txt
include README.md
include LICENSE
include whisper/assets/*
include whisper/assets/gpt2/*
include whisper/assets/multilingual/*
include whisper/normalizers/english.json
+1 -1
View File
@@ -3,5 +3,5 @@ numpy
torch
tqdm
more-itertools
transformers>=4.19.0
tiktoken==0.3.1
ffmpeg-python==0.2.0
+10
View File
@@ -12,3 +12,13 @@ def test_tokenizer():
assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens)
def test_split_on_unicode():
multilingual_tokenizer = get_tokenizer(multilingual=True)
tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
assert words == [" elle", " est", " l", "'", "", "é", "rit", "oire"]
assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
+6 -1
View File
@@ -4,6 +4,7 @@ import pytest
import torch
import whisper
from whisper.tokenizer import get_tokenizer
@pytest.mark.parametrize("model_name", whisper.available_models())
@@ -24,6 +25,11 @@ def test_transcribe(model_name: str):
assert "your country" in transcription
assert "do for you" in transcription
tokenizer = get_tokenizer(model.is_multilingual)
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
assert tokenizer.decode(all_tokens) == result["text"]
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
timing_checked = False
for segment in result["segments"]:
for timing in segment["words"]:
@@ -31,7 +37,6 @@ def test_transcribe(model_name: str):
if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8
assert timing["end"] >= 1.8
print(timing)
timing_checked = True
assert timing_checked
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
@@ -1 +0,0 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"<|endoftext|>": 50257}
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
@@ -1 +0,0 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
File diff suppressed because one or more lines are too long
+31 -20
View File
@@ -1,3 +1,4 @@
import itertools
import subprocess
import warnings
from dataclasses import dataclass
@@ -169,6 +170,9 @@ def find_alignment(
medfilt_width: int = 7,
qk_scale: float = 1.0,
) -> List[WordTiming]:
if len(text_tokens) == 0:
return []
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
@@ -290,34 +294,41 @@ def add_word_timestamps(
if len(segments) == 0:
return
text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
text_tokens_per_segment = [
[token for token in segment["tokens"] if token < tokenizer.eot]
for segment in segments
]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
segment_lengths = [len(s["tokens"]) for s in segments]
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
word_index = 0
for segment in segments:
segment["words"] = []
for segment, text_tokens in zip(segments, text_tokens_per_segment):
saved_tokens = 0
words = []
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
for i, timing in enumerate(alignment):
if timing.word:
segment = segments[token_sources[word_boundaries[i]]]
start = round(time_offset + timing.start, 2)
end = round(time_offset + timing.end, 2)
segment["words"].append(
dict(
word=timing.word,
start=start,
end=end,
probability=timing.probability,
while word_index < len(alignment) and saved_tokens < len(text_tokens):
timing = alignment[word_index]
if timing.word:
words.append(
dict(
word=timing.word,
start=round(time_offset + timing.start, 2),
end=round(time_offset + timing.end, 2),
probability=timing.probability,
)
)
)
for segment in segments:
if len(words := segment["words"]) > 0:
saved_tokens += len(timing.tokens)
word_index += 1
if len(words) > 0:
# adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"]
segment["words"] = words
+92 -85
View File
@@ -1,12 +1,12 @@
import base64
import os
import string
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from transformers import GPT2TokenizerFast
import tiktoken
from tiktoken_ext.openai_public import gpt2
LANGUAGES = {
"en": "english",
@@ -127,74 +127,84 @@ TO_LANGUAGE_CODE = {
}
@dataclass(frozen=True)
@dataclass
class Tokenizer:
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
tokenizer: "GPT2TokenizerFast"
language: Optional[str]
sot_sequence: Tuple[int]
encoding: tiktoken.Encoding
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
self.sot_sequence = tuple(sot_sequence)
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
return self.encoding.encode(text, **kwargs)
def decode(
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
):
return self.tokenizer.decode(token_ids, **kwargs)
def decode(self, token_ids: List[int], **kwargs) -> str:
token_ids = [t for t in token_ids if t < self.timestamp_begin]
return self.encoding.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
return "".join(
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
)
return self.encoding.decode(token_ids, **kwargs)
@cached_property
def eot(self) -> int:
return self.tokenizer.eos_token_id
return self.encoding.eot_token
@cached_property
def transcribe(self) -> int:
return self._get_single_token_id("<|transcribe|>")
return self.special_tokens["<|transcribe|>"]
@cached_property
def translate(self) -> int:
return self._get_single_token_id("<|translate|>")
return self.special_tokens["<|translate|>"]
@cached_property
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
return self.special_tokens["<|startoftranscript|>"]
@cached_property
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
return self.special_tokens["<|startoflm|>"]
@cached_property
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
return self.special_tokens["<|startofprev|>"]
@cached_property
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
return self.special_tokens["<|nospeech|>"]
@cached_property
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
return self.special_tokens["<|notimestamps|>"]
@cached_property
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
return self.special_tokens["<|0.00|>"]
@cached_property
def language_token(self) -> int:
@@ -202,25 +212,15 @@ class Tokenizer:
if self.language is None:
raise ValueError("This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
)
)
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
if token := self.special_tokens.get(f"<|{self.language}|>", None):
return token
raise KeyError(f"Language {self.language} not found in tokenizer.")
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
):
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
@@ -258,22 +258,17 @@ class Tokenizer:
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.tokenizer.encode(symbol),
self.tokenizer.encode(" " + symbol),
self.encoding.encode(symbol),
self.encoding.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text)
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words
@@ -284,17 +279,27 @@ class Tokenizer:
return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = []
word_tokens = []
current_tokens = []
unicode_offset = 0
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if "\ufffd" not in decoded:
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens
@@ -318,12 +323,17 @@ class Tokenizer:
@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(os.path.dirname(__file__), "assets", name)
tokenizer = GPT2TokenizerFast.from_pretrained(path)
def get_encoding(name: str = "gpt2"):
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
@@ -332,18 +342,28 @@ def build_tokenizer(name: str = "gpt2"):
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
for token in specials:
special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=gpt2()["pat_str"],
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer:
if language is not None:
language = language.lower()
@@ -354,27 +374,14 @@ def get_tokenizer(
raise ValueError(f"Unsupported language: {language}")
if multilingual:
tokenizer_name = "multilingual"
task = task or "transcribe"
encoding_name = "multilingual"
language = language or "en"
task = task or "transcribe"
else:
tokenizer_name = "gpt2"
task = None
encoding_name = "gpt2"
language = None
task = None
tokenizer = build_tokenizer(name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
encoding = get_encoding(name=encoding_name)
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
)
return Tokenizer(encoding=encoding, language=language, task=task)
+1 -1
View File
@@ -1 +1 @@
__version__ = "20230308"
__version__ = "20230314"