10 Commits

Author SHA1 Message Date
Jong Wook Kim 6dea21fd7f Release 20230314 2023-03-15 00:39:19 -07:00
Jong Wook Kim 79c43e4859 abort find_alignment on empty input (#1090) 2023-03-14 12:47:58 -07:00
Guillaume Klein 5f9ac653b7 Fix truncated words list when the replacement character is decoded (#1089) 2023-03-14 09:32:41 -07:00
Akash Mahajan ba88b8e1b3 fix github language stats getting dominated by jupyter notebook (#1076)
Co-authored-by: Akash Mahajan <akash.mahajan@microsoft.com>
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-03-14 00:07:09 -07:00
Guillaume Klein 671ac5a4ce Fix alignment between the segments and the list of words (#1087)
* Fix alignment between the segments and the list of words

* Ensure the word index does not overflow
2023-03-13 16:34:09 -07:00
Jong Wook Kim 839639a223 Use tiktoken (#1044)
* use tiktoken==0.3.0

* formatting

* tuple should be safer

* Update whisper/tokenizer.py

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>

* use tiktoken 0.3.1

* reflecting suggestions

* cleanup

* bypassing load_tiktoken_bpe to avoid blobfile dep

---------

Co-authored-by: Ruhollah Majdoddin <r.majdodin@gmail.com>
2023-03-13 02:34:16 -07:00
Jong Wook Kim ad3250a846 Release 20230308 2023-03-08 15:48:57 -08:00
Jong Wook Kim c4b50c0824 kwargs in decode() for convenience (#1061)
* kwargs in decode() for convenience

* formatting fix
2023-03-08 15:46:38 -08:00
Jong Wook Kim 38f2f4d99d fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060) 2023-03-08 15:34:07 -08:00
Jong Wook Kim aac47c9834 fix typo 2023-03-07 20:43:49 -08:00
22 changed files with 100693 additions and 100131 deletions
+3
View File
@@ -0,0 +1,3 @@
# Override jupyter in Github language stats for more accurate estimate of repo code languages
# reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
*.ipynb linguist-generated
+15 -1
View File
@@ -1,6 +1,20 @@
# CHANGELOG # CHANGELOG
## [v20230307](https://github.com/openai/whisper/releases/tag/v202303067) ## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
* fix typo in CHANGELOG.md
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052)) * Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053)) * Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
-2
View File
@@ -2,6 +2,4 @@ include requirements.txt
include README.md include README.md
include LICENSE include LICENSE
include whisper/assets/* include whisper/assets/*
include whisper/assets/gpt2/*
include whisper/assets/multilingual/*
include whisper/normalizers/english.json include whisper/normalizers/english.json
+1 -1
View File
@@ -3,5 +3,5 @@ numpy
torch torch
tqdm tqdm
more-itertools more-itertools
transformers>=4.19.0 tiktoken==0.3.1
ffmpeg-python==0.2.0 ffmpeg-python==0.2.0
+10
View File
@@ -12,3 +12,13 @@ def test_tokenizer():
assert gpt2_tokenizer.decode(gpt2_tokens) == text assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens) assert len(gpt2_tokens) > len(multilingual_tokens)
def test_split_on_unicode():
multilingual_tokenizer = get_tokenizer(multilingual=True)
tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
assert words == [" elle", " est", " l", "'", "", "é", "rit", "oire"]
assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
+7 -1
View File
@@ -4,6 +4,7 @@ import pytest
import torch import torch
import whisper import whisper
from whisper.tokenizer import get_tokenizer
@pytest.mark.parametrize("model_name", whisper.available_models()) @pytest.mark.parametrize("model_name", whisper.available_models())
@@ -17,12 +18,18 @@ def test_transcribe(model_name: str):
audio_path, language=language, temperature=0.0, word_timestamps=True audio_path, language=language, temperature=0.0, word_timestamps=True
) )
assert result["language"] == "en" assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])
transcription = result["text"].lower() transcription = result["text"].lower()
assert "my fellow americans" in transcription assert "my fellow americans" in transcription
assert "your country" in transcription assert "your country" in transcription
assert "do for you" in transcription assert "do for you" in transcription
tokenizer = get_tokenizer(model.is_multilingual)
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
assert tokenizer.decode(all_tokens) == result["text"]
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
timing_checked = False timing_checked = False
for segment in result["segments"]: for segment in result["segments"]:
for timing in segment["words"]: for timing in segment["words"]:
@@ -30,7 +37,6 @@ def test_transcribe(model_name: str):
if timing["word"].strip(" ,") == "Americans": if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8 assert timing["start"] <= 1.8
assert timing["end"] >= 1.8 assert timing["end"] >= 1.8
print(timing)
timing_checked = True timing_checked = True
assert timing_checked assert timing_checked
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
@@ -1 +0,0 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"<|endoftext|>": 50257}
File diff suppressed because it is too large Load Diff
@@ -1 +0,0 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
@@ -1 +0,0 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
File diff suppressed because one or more lines are too long
+8 -2
View File
@@ -1,4 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
@@ -778,7 +778,10 @@ class DecodingTask:
@torch.no_grad() @torch.no_grad()
def decode( def decode(
model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions() model: "Whisper",
mel: Tensor,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]: ) -> Union[DecodingResult, List[DecodingResult]]:
""" """
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -802,6 +805,9 @@ def decode(
if single := mel.ndim == 2: if single := mel.ndim == 2:
mel = mel.unsqueeze(0) mel = mel.unsqueeze(0)
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel) result = DecodingTask(model, options).run(mel)
return result[0] if single else result return result[0] if single else result
+31 -20
View File
@@ -1,3 +1,4 @@
import itertools
import subprocess import subprocess
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
@@ -169,6 +170,9 @@ def find_alignment(
medfilt_width: int = 7, medfilt_width: int = 7,
qk_scale: float = 1.0, qk_scale: float = 1.0,
) -> List[WordTiming]: ) -> List[WordTiming]:
if len(text_tokens) == 0:
return []
tokens = torch.tensor( tokens = torch.tensor(
[ [
*tokenizer.sot_sequence, *tokenizer.sot_sequence,
@@ -290,34 +294,41 @@ def add_word_timestamps(
if len(segments) == 0: if len(segments) == 0:
return return
text_tokens = [t for segment in segments for t in segment["tokens"]] text_tokens_per_segment = [
[token for token in segment["tokens"] if token < tokenizer.eot]
for segment in segments
]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations) merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
segment_lengths = [len(s["tokens"]) for s in segments] word_index = 0
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
for segment in segments: for segment, text_tokens in zip(segments, text_tokens_per_segment):
segment["words"] = [] saved_tokens = 0
words = []
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0)) while word_index < len(alignment) and saved_tokens < len(text_tokens):
for i, timing in enumerate(alignment): timing = alignment[word_index]
if timing.word:
segment = segments[token_sources[word_boundaries[i]]] if timing.word:
start = round(time_offset + timing.start, 2) words.append(
end = round(time_offset + timing.end, 2) dict(
segment["words"].append( word=timing.word,
dict( start=round(time_offset + timing.start, 2),
word=timing.word, end=round(time_offset + timing.end, 2),
start=start, probability=timing.probability,
end=end, )
probability=timing.probability,
) )
)
for segment in segments: saved_tokens += len(timing.tokens)
if len(words := segment["words"]) > 0: word_index += 1
if len(words) > 0:
# adjust the segment-level timestamps based on the word-level timestamps # adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"] segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"] segment["end"] = words[-1]["end"]
segment["words"] = words
+92 -85
View File
@@ -1,12 +1,12 @@
import base64
import os import os
import string import string
from dataclasses import dataclass from dataclasses import dataclass, field
from functools import cached_property, lru_cache from functools import cached_property, lru_cache
from typing import List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple
import numpy as np import tiktoken
import torch from tiktoken_ext.openai_public import gpt2
from transformers import GPT2TokenizerFast
LANGUAGES = { LANGUAGES = {
"en": "english", "en": "english",
@@ -127,74 +127,84 @@ TO_LANGUAGE_CODE = {
} }
@dataclass(frozen=True) @dataclass
class Tokenizer: class Tokenizer:
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" """A thin wrapper around `tiktoken` providing quick access to special tokens"""
tokenizer: "GPT2TokenizerFast" encoding: tiktoken.Encoding
language: Optional[str] language: Optional[str] = None
sot_sequence: Tuple[int] task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
self.sot_sequence = tuple(sot_sequence)
def encode(self, text, **kwargs): def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs) return self.encoding.encode(text, **kwargs)
def decode( def decode(self, token_ids: List[int], **kwargs) -> str:
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs token_ids = [t for t in token_ids if t < self.timestamp_begin]
): return self.encoding.decode(token_ids, **kwargs)
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str: def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
""" """
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
""" """
outputs = [[]] return self.encoding.decode(token_ids, **kwargs)
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
return "".join(
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
)
@cached_property @cached_property
def eot(self) -> int: def eot(self) -> int:
return self.tokenizer.eos_token_id return self.encoding.eot_token
@cached_property @cached_property
def transcribe(self) -> int: def transcribe(self) -> int:
return self._get_single_token_id("<|transcribe|>") return self.special_tokens["<|transcribe|>"]
@cached_property @cached_property
def translate(self) -> int: def translate(self) -> int:
return self._get_single_token_id("<|translate|>") return self.special_tokens["<|translate|>"]
@cached_property @cached_property
def sot(self) -> int: def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>") return self.special_tokens["<|startoftranscript|>"]
@cached_property @cached_property
def sot_lm(self) -> int: def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>") return self.special_tokens["<|startoflm|>"]
@cached_property @cached_property
def sot_prev(self) -> int: def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>") return self.special_tokens["<|startofprev|>"]
@cached_property @cached_property
def no_speech(self) -> int: def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>") return self.special_tokens["<|nospeech|>"]
@cached_property @cached_property
def no_timestamps(self) -> int: def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>") return self.special_tokens["<|notimestamps|>"]
@cached_property @cached_property
def timestamp_begin(self) -> int: def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1 return self.special_tokens["<|0.00|>"]
@cached_property @cached_property
def language_token(self) -> int: def language_token(self) -> int:
@@ -202,25 +212,15 @@ class Tokenizer:
if self.language is None: if self.language is None:
raise ValueError("This tokenizer does not have language token configured") raise ValueError("This tokenizer does not have language token configured")
additional_tokens = dict( if token := self.special_tokens.get(f"<|{self.language}|>", None):
zip( return token
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
)
)
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
raise KeyError(f"Language {self.language} not found in tokenizer.") raise KeyError(f"Language {self.language} not found in tokenizer.")
@cached_property @cached_property
def all_language_tokens(self) -> Tuple[int]: def all_language_tokens(self) -> Tuple[int]:
result = [] result = []
for token, token_id in zip( for token, token_id in self.special_tokens.items():
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids,
):
if token.strip("<|>") in LANGUAGES: if token.strip("<|>") in LANGUAGES:
result.append(token_id) result.append(token_id)
return tuple(result) return tuple(result)
@@ -258,22 +258,17 @@ class Tokenizer:
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous): for symbol in symbols + list(miscellaneous):
for tokens in [ for tokens in [
self.tokenizer.encode(symbol), self.encoding.encode(symbol),
self.tokenizer.encode(" " + symbol), self.encoding.encode(" " + symbol),
]: ]:
if len(tokens) == 1 or symbol in miscellaneous: if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0]) result.add(tokens[0])
return tuple(sorted(result)) return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text)
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
def split_to_word_tokens(self, tokens: List[int]): def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my"}: if self.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words # These languages don't typically use spaces, so it is difficult to split words
@@ -284,17 +279,27 @@ class Tokenizer:
return self.split_tokens_on_spaces(tokens) return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]): def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = [] words = []
word_tokens = [] word_tokens = []
current_tokens = [] current_tokens = []
unicode_offset = 0
for token in tokens: for token in tokens:
current_tokens.append(token) current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens) decoded = self.decode_with_timestamps(current_tokens)
if "\ufffd" not in decoded:
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
):
words.append(decoded) words.append(decoded)
word_tokens.append(current_tokens) word_tokens.append(current_tokens)
current_tokens = [] current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens return words, word_tokens
@@ -318,12 +323,17 @@ class Tokenizer:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"): def get_encoding(name: str = "gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false" vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
path = os.path.join(os.path.dirname(__file__), "assets", name) ranks = {
tokenizer = GPT2TokenizerFast.from_pretrained(path) base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [ specials = [
"<|endoftext|>",
"<|startoftranscript|>", "<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()], *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>", "<|translate|>",
@@ -332,18 +342,28 @@ def build_tokenizer(name: str = "gpt2"):
"<|startofprev|>", "<|startofprev|>",
"<|nospeech|>", "<|nospeech|>",
"<|notimestamps|>", "<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
] ]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) for token in specials:
return tokenizer special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=gpt2()["pat_str"],
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_tokenizer( def get_tokenizer(
multilingual: bool, multilingual: bool,
*, *,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
language: Optional[str] = None, language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer: ) -> Tokenizer:
if language is not None: if language is not None:
language = language.lower() language = language.lower()
@@ -354,27 +374,14 @@ def get_tokenizer(
raise ValueError(f"Unsupported language: {language}") raise ValueError(f"Unsupported language: {language}")
if multilingual: if multilingual:
tokenizer_name = "multilingual" encoding_name = "multilingual"
task = task or "transcribe"
language = language or "en" language = language or "en"
task = task or "transcribe"
else: else:
tokenizer_name = "gpt2" encoding_name = "gpt2"
task = None
language = None language = None
task = None
tokenizer = build_tokenizer(name=tokenizer_name) encoding = get_encoding(name=encoding_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
langs = tuple(LANGUAGES.keys()) return Tokenizer(encoding=encoding, language=language, task=task)
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
)
+12 -10
View File
@@ -200,14 +200,14 @@ def transcribe(
def new_segment( def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
): ):
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot] tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return { return {
"id": len(all_segments),
"seek": seek, "seek": seek,
"start": start, "start": start,
"end": end, "end": end,
"text": tokenizer.decode(text_tokens), "text": tokenizer.decode(text_tokens),
"tokens": text_tokens, "tokens": tokens,
"temperature": result.temperature, "temperature": result.temperature,
"avg_logprob": result.avg_logprob, "avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio, "compression_ratio": result.compression_ratio,
@@ -245,7 +245,6 @@ def transcribe(
previous_seek = seek previous_seek = seek
current_segments = [] current_segments = []
current_tokens = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -275,7 +274,6 @@ def transcribe(
result=result, result=result,
) )
) )
current_tokens.append(sliced_tokens.tolist())
last_slice = current_slice last_slice = current_slice
if single_timestamp_ending: if single_timestamp_ending:
@@ -287,7 +285,6 @@ def transcribe(
tokens[last_slice - 1].item() - tokenizer.timestamp_begin tokens[last_slice - 1].item() - tokenizer.timestamp_begin
) )
seek += last_timestamp_pos * input_stride seek += last_timestamp_pos * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else: else:
duration = segment_duration duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()] timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -309,7 +306,6 @@ def transcribe(
result=result, result=result,
) )
) )
current_tokens.append(tokens.tolist())
seek += segment_size seek += segment_size
if not condition_on_previous_text or result.temperature > 0.5: if not condition_on_previous_text or result.temperature > 0.5:
@@ -348,11 +344,17 @@ def transcribe(
segment["text"] = "" segment["text"] = ""
segment["tokens"] = [] segment["tokens"] = []
segment["words"] = [] segment["words"] = []
current_tokens[i] = []
all_segments.extend(current_segments) all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend( all_tokens.extend(
[token for segment in current_tokens for token in segment] [token for segment in current_segments for token in segment["tokens"]]
) )
# update progress bar # update progress bar
+1 -1
View File
@@ -1 +1 @@
__version__ = "20230307" __version__ = "20230314"