Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6dea21fd7f | |||
| 79c43e4859 | |||
| 5f9ac653b7 | |||
| ba88b8e1b3 | |||
| 671ac5a4ce | |||
| 839639a223 | |||
| ad3250a846 | |||
| c4b50c0824 | |||
| 38f2f4d99d | |||
| aac47c9834 | |||
| 26807ec6d3 | |||
| 919a713499 | |||
| 38e990d853 | |||
| 924e1f8e06 | |||
| 4b0d5e58d0 |
@@ -0,0 +1,3 @@
|
|||||||
|
# Override Jupyter notebooks in GitHub language stats for a more accurate estimate of the repo's code languages
|
||||||
|
# reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
|
||||||
|
*.ipynb linguist-generated
|
||||||
+34
-14
@@ -1,25 +1,45 @@
|
|||||||
# CHANGELOG
|
# CHANGELOG
|
||||||
|
|
||||||
|
## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
|
||||||
|
|
||||||
|
* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
|
||||||
|
* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
|
||||||
|
* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
|
||||||
|
* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
|
||||||
|
* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
|
||||||
|
|
||||||
|
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
|
||||||
|
|
||||||
|
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
|
||||||
|
* fix all_tokens handling that caused more repetitions and discrepancy in JSON ([#1060](https://github.com/openai/whisper/pull/1060))
|
||||||
|
* fix typo in CHANGELOG.md
|
||||||
|
|
||||||
|
## [v20230307](https://github.com/openai/whisper/releases/tag/v20230307)
|
||||||
|
|
||||||
|
* Fix the repetition/hallucination issue identified in #1046 ([#1052](https://github.com/openai/whisper/pull/1052))
|
||||||
|
* Use triton==2.0.0 ([#1053](https://github.com/openai/whisper/pull/1053))
|
||||||
|
* Install triton in x86_64 linux only ([#1051](https://github.com/openai/whisper/pull/1051))
|
||||||
|
* update setup.py to specify python >= 3.8 requirement
|
||||||
|
|
||||||
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
|
## [v20230306](https://github.com/openai/whisper/releases/tag/v20230306)
|
||||||
|
|
||||||
* #1021: remove auxiliary audio extension
|
* remove auxiliary audio extension ([#1021](https://github.com/openai/whisper/pull/1021))
|
||||||
* #1038: apply formatting with `black`, `isort`, and `flake8`
|
* apply formatting with `black`, `isort`, and `flake8` ([#1038](https://github.com/openai/whisper/pull/1038))
|
||||||
* #869: word-level timestamps in `transcribe()`
|
* word-level timestamps in `transcribe()` ([#869](https://github.com/openai/whisper/pull/869))
|
||||||
* #1033: Decoding improvements
|
* Decoding improvements ([#1033](https://github.com/openai/whisper/pull/1033))
|
||||||
* #894: Update README.md
|
* Update README.md ([#894](https://github.com/openai/whisper/pull/894))
|
||||||
* #914: Fix infinite loop caused by incorrect timestamp tokens prediction
|
* Fix infinite loop caused by incorrect timestamp tokens prediction ([#914](https://github.com/openai/whisper/pull/914))
|
||||||
* #889: drop python 3.7 support
|
* drop python 3.7 support ([#889](https://github.com/openai/whisper/pull/889))
|
||||||
|
|
||||||
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
|
## [v20230124](https://github.com/openai/whisper/releases/tag/v20230124)
|
||||||
|
|
||||||
* #887: handle printing even if sys.stdout.buffer is not available
|
* handle printing even if sys.stdout.buffer is not available ([#887](https://github.com/openai/whisper/pull/887))
|
||||||
* #228: Add TSV formatted output in transcript, using integer start/end time in milliseconds
|
* Add TSV formatted output in transcript, using integer start/end time in milliseconds ([#228](https://github.com/openai/whisper/pull/228))
|
||||||
* #333: Added `--output_format` option
|
* Added `--output_format` option ([#333](https://github.com/openai/whisper/pull/333))
|
||||||
* #864: Handle `XDG_CACHE_HOME` properly for `download_root`
|
* Handle `XDG_CACHE_HOME` properly for `download_root` ([#864](https://github.com/openai/whisper/pull/864))
|
||||||
* #867: use stdout for printing transcription progress
|
* use stdout for printing transcription progress ([#867](https://github.com/openai/whisper/pull/867))
|
||||||
* #659: Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm
|
* Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm ([#659](https://github.com/openai/whisper/pull/659))
|
||||||
* #859: print '?' if a letter can't be encoded using the system default encoding
|
* print '?' if a letter can't be encoded using the system default encoding ([#859](https://github.com/openai/whisper/pull/859))
|
||||||
|
|
||||||
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
|
## [v20230117](https://github.com/openai/whisper/releases/tag/v20230117)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,4 @@ include requirements.txt
|
|||||||
include README.md
|
include README.md
|
||||||
include LICENSE
|
include LICENSE
|
||||||
include whisper/assets/*
|
include whisper/assets/*
|
||||||
include whisper/assets/gpt2/*
|
|
||||||
include whisper/assets/multilingual/*
|
|
||||||
include whisper/normalizers/english.json
|
include whisper/normalizers/english.json
|
||||||
|
|||||||
+1
-1
@@ -3,5 +3,5 @@ numpy
|
|||||||
torch
|
torch
|
||||||
tqdm
|
tqdm
|
||||||
more-itertools
|
more-itertools
|
||||||
transformers>=4.19.0
|
tiktoken==0.3.1
|
||||||
ffmpeg-python==0.2.0
|
ffmpeg-python==0.2.0
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
@@ -11,22 +12,8 @@ def read_version(fname="whisper/version.py"):
|
|||||||
|
|
||||||
|
|
||||||
requirements = []
|
requirements = []
|
||||||
if sys.platform.startswith("linux"):
|
if sys.platform.startswith("linux") and platform.machine() == "x86_64":
|
||||||
triton_requirement = "triton>=2.0.0.dev20221202"
|
requirements.append("triton==2.0.0")
|
||||||
try:
|
|
||||||
import re
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
version_line = (
|
|
||||||
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
|
|
||||||
)
|
|
||||||
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
|
|
||||||
if (int(major), int(minor)) < (11, 4):
|
|
||||||
# the last version supporting CUDA < 11.4
|
|
||||||
triton_requirement = "triton==2.0.0.dev20221011"
|
|
||||||
except (IndexError, OSError, subprocess.SubprocessError):
|
|
||||||
pass
|
|
||||||
requirements.append(triton_requirement)
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="openai-whisper",
|
name="openai-whisper",
|
||||||
@@ -36,7 +23,7 @@ setup(
|
|||||||
long_description=open("README.md", encoding="utf-8").read(),
|
long_description=open("README.md", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
readme="README.md",
|
readme="README.md",
|
||||||
python_requires=">=3.7",
|
python_requires=">=3.8",
|
||||||
author="OpenAI",
|
author="OpenAI",
|
||||||
url="https://github.com/openai/whisper",
|
url="https://github.com/openai/whisper",
|
||||||
license="MIT",
|
license="MIT",
|
||||||
|
|||||||
@@ -12,3 +12,13 @@ def test_tokenizer():
|
|||||||
assert gpt2_tokenizer.decode(gpt2_tokens) == text
|
assert gpt2_tokenizer.decode(gpt2_tokens) == text
|
||||||
assert multilingual_tokenizer.decode(multilingual_tokens) == text
|
assert multilingual_tokenizer.decode(multilingual_tokens) == text
|
||||||
assert len(gpt2_tokens) > len(multilingual_tokens)
|
assert len(gpt2_tokens) > len(multilingual_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_on_unicode():
|
||||||
|
multilingual_tokenizer = get_tokenizer(multilingual=True)
|
||||||
|
|
||||||
|
tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
|
||||||
|
words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
|
||||||
|
|
||||||
|
assert words == [" elle", " est", " l", "'", "�", "é", "rit", "oire"]
|
||||||
|
assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import whisper
|
import whisper
|
||||||
|
from whisper.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_name", whisper.available_models())
|
@pytest.mark.parametrize("model_name", whisper.available_models())
|
||||||
@@ -17,12 +18,18 @@ def test_transcribe(model_name: str):
|
|||||||
audio_path, language=language, temperature=0.0, word_timestamps=True
|
audio_path, language=language, temperature=0.0, word_timestamps=True
|
||||||
)
|
)
|
||||||
assert result["language"] == "en"
|
assert result["language"] == "en"
|
||||||
|
assert result["text"] == "".join([s["text"] for s in result["segments"]])
|
||||||
|
|
||||||
transcription = result["text"].lower()
|
transcription = result["text"].lower()
|
||||||
assert "my fellow americans" in transcription
|
assert "my fellow americans" in transcription
|
||||||
assert "your country" in transcription
|
assert "your country" in transcription
|
||||||
assert "do for you" in transcription
|
assert "do for you" in transcription
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(model.is_multilingual)
|
||||||
|
all_tokens = [t for s in result["segments"] for t in s["tokens"]]
|
||||||
|
assert tokenizer.decode(all_tokens) == result["text"]
|
||||||
|
assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
|
||||||
|
|
||||||
timing_checked = False
|
timing_checked = False
|
||||||
for segment in result["segments"]:
|
for segment in result["segments"]:
|
||||||
for timing in segment["words"]:
|
for timing in segment["words"]:
|
||||||
@@ -30,7 +37,6 @@ def test_transcribe(model_name: str):
|
|||||||
if timing["word"].strip(" ,") == "Americans":
|
if timing["word"].strip(" ,") == "Americans":
|
||||||
assert timing["start"] <= 1.8
|
assert timing["start"] <= 1.8
|
||||||
assert timing["end"] >= 1.8
|
assert timing["end"] >= 1.8
|
||||||
print(timing)
|
|
||||||
timing_checked = True
|
timing_checked = True
|
||||||
|
|
||||||
assert timing_checked
|
assert timing_checked
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
|||||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
|
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
|||||||
{"<|endoftext|>": 50257}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
|||||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
|
|
||||||
File diff suppressed because one or more lines are too long
+17
-6
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import ffmpeg
|
import ffmpeg
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -15,10 +15,8 @@ N_FFT = 400
|
|||||||
N_MELS = 80
|
N_MELS = 80
|
||||||
HOP_LENGTH = 160
|
HOP_LENGTH = 160
|
||||||
CHUNK_LENGTH = 30
|
CHUNK_LENGTH = 30
|
||||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||||
N_FRAMES = exact_div(
|
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
|
||||||
N_SAMPLES, HOP_LENGTH
|
|
||||||
) # 3000: number of frames in a mel spectrogram input
|
|
||||||
|
|
||||||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
|
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
|
||||||
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||||
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
def log_mel_spectrogram(
|
def log_mel_spectrogram(
|
||||||
audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
|
audio: Union[str, np.ndarray, torch.Tensor],
|
||||||
|
n_mels: int = N_MELS,
|
||||||
|
padding: int = 0,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compute the log-Mel spectrogram of
|
Compute the log-Mel spectrogram of
|
||||||
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
|
|||||||
n_mels: int
|
n_mels: int
|
||||||
The number of Mel-frequency filters, only 80 is supported
|
The number of Mel-frequency filters, only 80 is supported
|
||||||
|
|
||||||
|
padding: int
|
||||||
|
Number of zero samples to pad to the right
|
||||||
|
|
||||||
|
device: Optional[Union[str, torch.device]]
|
||||||
|
If given, the audio tensor is moved to this device before STFT
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor, shape = (80, n_frames)
|
torch.Tensor, shape = (80, n_frames)
|
||||||
@@ -123,6 +130,10 @@ def log_mel_spectrogram(
|
|||||||
audio = load_audio(audio)
|
audio = load_audio(audio)
|
||||||
audio = torch.from_numpy(audio)
|
audio = torch.from_numpy(audio)
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
audio = audio.to(device)
|
||||||
|
if padding > 0:
|
||||||
|
audio = F.pad(audio, (0, padding))
|
||||||
window = torch.hann_window(N_FFT).to(audio.device)
|
window = torch.hann_window(N_FFT).to(audio.device)
|
||||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||||
magnitudes = stft[..., :-1].abs() ** 2
|
magnitudes = stft[..., :-1].abs() ** 2
|
||||||
|
|||||||
+8
-2
@@ -1,4 +1,4 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field, replace
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -778,7 +778,10 @@ class DecodingTask:
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decode(
|
def decode(
|
||||||
model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
|
model: "Whisper",
|
||||||
|
mel: Tensor,
|
||||||
|
options: DecodingOptions = DecodingOptions(),
|
||||||
|
**kwargs,
|
||||||
) -> Union[DecodingResult, List[DecodingResult]]:
|
) -> Union[DecodingResult, List[DecodingResult]]:
|
||||||
"""
|
"""
|
||||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||||
@@ -802,6 +805,9 @@ def decode(
|
|||||||
if single := mel.ndim == 2:
|
if single := mel.ndim == 2:
|
||||||
mel = mel.unsqueeze(0)
|
mel = mel.unsqueeze(0)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
options = replace(options, **kwargs)
|
||||||
|
|
||||||
result = DecodingTask(model, options).run(mel)
|
result = DecodingTask(model, options).run(mel)
|
||||||
|
|
||||||
return result[0] if single else result
|
return result[0] if single else result
|
||||||
|
|||||||
+26
-15
@@ -1,3 +1,4 @@
|
|||||||
|
import itertools
|
||||||
import subprocess
|
import subprocess
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -169,6 +170,9 @@ def find_alignment(
|
|||||||
medfilt_width: int = 7,
|
medfilt_width: int = 7,
|
||||||
qk_scale: float = 1.0,
|
qk_scale: float = 1.0,
|
||||||
) -> List[WordTiming]:
|
) -> List[WordTiming]:
|
||||||
|
if len(text_tokens) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
tokens = torch.tensor(
|
tokens = torch.tensor(
|
||||||
[
|
[
|
||||||
*tokenizer.sot_sequence,
|
*tokenizer.sot_sequence,
|
||||||
@@ -290,34 +294,41 @@ def add_word_timestamps(
|
|||||||
if len(segments) == 0:
|
if len(segments) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
text_tokens = [t for segment in segments for t in segment["tokens"]]
|
text_tokens_per_segment = [
|
||||||
|
[token for token in segment["tokens"] if token < tokenizer.eot]
|
||||||
|
for segment in segments
|
||||||
|
]
|
||||||
|
|
||||||
|
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
||||||
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
||||||
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
||||||
|
|
||||||
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
||||||
segment_lengths = [len(s["tokens"]) for s in segments]
|
word_index = 0
|
||||||
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
|
|
||||||
|
|
||||||
for segment in segments:
|
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
||||||
segment["words"] = []
|
saved_tokens = 0
|
||||||
|
words = []
|
||||||
|
|
||||||
|
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
||||||
|
timing = alignment[word_index]
|
||||||
|
|
||||||
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
|
|
||||||
for i, timing in enumerate(alignment):
|
|
||||||
if timing.word:
|
if timing.word:
|
||||||
segment = segments[token_sources[word_boundaries[i]]]
|
words.append(
|
||||||
start = round(time_offset + timing.start, 2)
|
|
||||||
end = round(time_offset + timing.end, 2)
|
|
||||||
segment["words"].append(
|
|
||||||
dict(
|
dict(
|
||||||
word=timing.word,
|
word=timing.word,
|
||||||
start=start,
|
start=round(time_offset + timing.start, 2),
|
||||||
end=end,
|
end=round(time_offset + timing.end, 2),
|
||||||
probability=timing.probability,
|
probability=timing.probability,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for segment in segments:
|
saved_tokens += len(timing.tokens)
|
||||||
if len(words := segment["words"]) > 0:
|
word_index += 1
|
||||||
|
|
||||||
|
if len(words) > 0:
|
||||||
# adjust the segment-level timestamps based on the word-level timestamps
|
# adjust the segment-level timestamps based on the word-level timestamps
|
||||||
segment["start"] = words[0]["start"]
|
segment["start"] = words[0]["start"]
|
||||||
segment["end"] = words[-1]["end"]
|
segment["end"] = words[-1]["end"]
|
||||||
|
|
||||||
|
segment["words"] = words
|
||||||
|
|||||||
+92
-85
@@ -1,12 +1,12 @@
|
|||||||
|
import base64
|
||||||
import os
|
import os
|
||||||
import string
|
import string
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from functools import cached_property, lru_cache
|
from functools import cached_property, lru_cache
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import tiktoken
|
||||||
import torch
|
from tiktoken_ext.openai_public import gpt2
|
||||||
from transformers import GPT2TokenizerFast
|
|
||||||
|
|
||||||
LANGUAGES = {
|
LANGUAGES = {
|
||||||
"en": "english",
|
"en": "english",
|
||||||
@@ -127,74 +127,84 @@ TO_LANGUAGE_CODE = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass
|
||||||
class Tokenizer:
|
class Tokenizer:
|
||||||
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||||
|
|
||||||
tokenizer: "GPT2TokenizerFast"
|
encoding: tiktoken.Encoding
|
||||||
language: Optional[str]
|
language: Optional[str] = None
|
||||||
sot_sequence: Tuple[int]
|
task: Optional[str] = None
|
||||||
|
sot_sequence: Tuple[int] = ()
|
||||||
|
special_tokens: Dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
for special in self.encoding.special_tokens_set:
|
||||||
|
special_token = self.encoding.encode_single_token(special)
|
||||||
|
self.special_tokens[special] = special_token
|
||||||
|
|
||||||
|
sot: int = self.special_tokens["<|startoftranscript|>"]
|
||||||
|
translate: int = self.special_tokens["<|translate|>"]
|
||||||
|
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||||
|
|
||||||
|
langs = tuple(LANGUAGES.keys())
|
||||||
|
sot_sequence = [sot]
|
||||||
|
if self.language is not None:
|
||||||
|
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||||
|
if self.task is not None:
|
||||||
|
task_token: int = transcribe if self.task == "transcribe" else translate
|
||||||
|
sot_sequence.append(task_token)
|
||||||
|
|
||||||
|
self.sot_sequence = tuple(sot_sequence)
|
||||||
|
|
||||||
def encode(self, text, **kwargs):
|
def encode(self, text, **kwargs):
|
||||||
return self.tokenizer.encode(text, **kwargs)
|
return self.encoding.encode(text, **kwargs)
|
||||||
|
|
||||||
def decode(
|
def decode(self, token_ids: List[int], **kwargs) -> str:
|
||||||
self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
|
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
||||||
):
|
return self.encoding.decode(token_ids, **kwargs)
|
||||||
return self.tokenizer.decode(token_ids, **kwargs)
|
|
||||||
|
|
||||||
def decode_with_timestamps(self, tokens) -> str:
|
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
||||||
This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
|
This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
|
||||||
"""
|
"""
|
||||||
outputs = [[]]
|
return self.encoding.decode(token_ids, **kwargs)
|
||||||
for token in tokens:
|
|
||||||
if token >= self.timestamp_begin:
|
|
||||||
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
|
||||||
outputs.append(timestamp)
|
|
||||||
outputs.append([])
|
|
||||||
else:
|
|
||||||
outputs[-1].append(token)
|
|
||||||
return "".join(
|
|
||||||
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def eot(self) -> int:
|
def eot(self) -> int:
|
||||||
return self.tokenizer.eos_token_id
|
return self.encoding.eot_token
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def transcribe(self) -> int:
|
def transcribe(self) -> int:
|
||||||
return self._get_single_token_id("<|transcribe|>")
|
return self.special_tokens["<|transcribe|>"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def translate(self) -> int:
|
def translate(self) -> int:
|
||||||
return self._get_single_token_id("<|translate|>")
|
return self.special_tokens["<|translate|>"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sot(self) -> int:
|
def sot(self) -> int:
|
||||||
return self._get_single_token_id("<|startoftranscript|>")
|
return self.special_tokens["<|startoftranscript|>"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sot_lm(self) -> int:
|
def sot_lm(self) -> int:
|
||||||
return self._get_single_token_id("<|startoflm|>")
|
return self.special_tokens["<|startoflm|>"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sot_prev(self) -> int:
|
def sot_prev(self) -> int:
|
||||||
return self._get_single_token_id("<|startofprev|>")
|
return self.special_tokens["<|startofprev|>"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def no_speech(self) -> int:
|
def no_speech(self) -> int:
|
||||||
return self._get_single_token_id("<|nospeech|>")
|
return self.special_tokens["<|nospeech|>"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def no_timestamps(self) -> int:
|
def no_timestamps(self) -> int:
|
||||||
return self._get_single_token_id("<|notimestamps|>")
|
return self.special_tokens["<|notimestamps|>"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def timestamp_begin(self) -> int:
|
def timestamp_begin(self) -> int:
|
||||||
return self.tokenizer.all_special_ids[-1] + 1
|
return self.special_tokens["<|0.00|>"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def language_token(self) -> int:
|
def language_token(self) -> int:
|
||||||
@@ -202,25 +212,15 @@ class Tokenizer:
|
|||||||
if self.language is None:
|
if self.language is None:
|
||||||
raise ValueError("This tokenizer does not have language token configured")
|
raise ValueError("This tokenizer does not have language token configured")
|
||||||
|
|
||||||
additional_tokens = dict(
|
if token := self.special_tokens.get(f"<|{self.language}|>", None):
|
||||||
zip(
|
return token
|
||||||
self.tokenizer.additional_special_tokens,
|
|
||||||
self.tokenizer.additional_special_tokens_ids,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
candidate = f"<|{self.language}|>"
|
|
||||||
if candidate in additional_tokens:
|
|
||||||
return additional_tokens[candidate]
|
|
||||||
|
|
||||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def all_language_tokens(self) -> Tuple[int]:
|
def all_language_tokens(self) -> Tuple[int]:
|
||||||
result = []
|
result = []
|
||||||
for token, token_id in zip(
|
for token, token_id in self.special_tokens.items():
|
||||||
self.tokenizer.additional_special_tokens,
|
|
||||||
self.tokenizer.additional_special_tokens_ids,
|
|
||||||
):
|
|
||||||
if token.strip("<|>") in LANGUAGES:
|
if token.strip("<|>") in LANGUAGES:
|
||||||
result.append(token_id)
|
result.append(token_id)
|
||||||
return tuple(result)
|
return tuple(result)
|
||||||
@@ -258,22 +258,17 @@ class Tokenizer:
|
|||||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||||
|
|
||||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||||
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
||||||
for symbol in symbols + list(miscellaneous):
|
for symbol in symbols + list(miscellaneous):
|
||||||
for tokens in [
|
for tokens in [
|
||||||
self.tokenizer.encode(symbol),
|
self.encoding.encode(symbol),
|
||||||
self.tokenizer.encode(" " + symbol),
|
self.encoding.encode(" " + symbol),
|
||||||
]:
|
]:
|
||||||
if len(tokens) == 1 or symbol in miscellaneous:
|
if len(tokens) == 1 or symbol in miscellaneous:
|
||||||
result.add(tokens[0])
|
result.add(tokens[0])
|
||||||
|
|
||||||
return tuple(sorted(result))
|
return tuple(sorted(result))
|
||||||
|
|
||||||
def _get_single_token_id(self, text) -> int:
|
|
||||||
tokens = self.tokenizer.encode(text)
|
|
||||||
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
|
||||||
return tokens[0]
|
|
||||||
|
|
||||||
def split_to_word_tokens(self, tokens: List[int]):
|
def split_to_word_tokens(self, tokens: List[int]):
|
||||||
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
||||||
# These languages don't typically use spaces, so it is difficult to split words
|
# These languages don't typically use spaces, so it is difficult to split words
|
||||||
@@ -284,17 +279,27 @@ class Tokenizer:
|
|||||||
return self.split_tokens_on_spaces(tokens)
|
return self.split_tokens_on_spaces(tokens)
|
||||||
|
|
||||||
def split_tokens_on_unicode(self, tokens: List[int]):
|
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||||
|
decoded_full = self.decode_with_timestamps(tokens)
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
words = []
|
words = []
|
||||||
word_tokens = []
|
word_tokens = []
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
|
unicode_offset = 0
|
||||||
|
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
decoded = self.decode_with_timestamps(current_tokens)
|
decoded = self.decode_with_timestamps(current_tokens)
|
||||||
if "\ufffd" not in decoded:
|
|
||||||
|
if (
|
||||||
|
replacement_char not in decoded
|
||||||
|
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
||||||
|
== replacement_char
|
||||||
|
):
|
||||||
words.append(decoded)
|
words.append(decoded)
|
||||||
word_tokens.append(current_tokens)
|
word_tokens.append(current_tokens)
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
|
unicode_offset += len(decoded)
|
||||||
|
|
||||||
return words, word_tokens
|
return words, word_tokens
|
||||||
|
|
||||||
@@ -318,12 +323,17 @@ class Tokenizer:
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def build_tokenizer(name: str = "gpt2"):
|
def get_encoding(name: str = "gpt2"):
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||||
path = os.path.join(os.path.dirname(__file__), "assets", name)
|
ranks = {
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained(path)
|
base64.b64decode(token): int(rank)
|
||||||
|
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||||
|
}
|
||||||
|
n_vocab = len(ranks)
|
||||||
|
special_tokens = {}
|
||||||
|
|
||||||
specials = [
|
specials = [
|
||||||
|
"<|endoftext|>",
|
||||||
"<|startoftranscript|>",
|
"<|startoftranscript|>",
|
||||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||||
"<|translate|>",
|
"<|translate|>",
|
||||||
@@ -332,18 +342,28 @@ def build_tokenizer(name: str = "gpt2"):
|
|||||||
"<|startofprev|>",
|
"<|startofprev|>",
|
||||||
"<|nospeech|>",
|
"<|nospeech|>",
|
||||||
"<|notimestamps|>",
|
"<|notimestamps|>",
|
||||||
|
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||||
]
|
]
|
||||||
|
|
||||||
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
|
for token in specials:
|
||||||
return tokenizer
|
special_tokens[token] = n_vocab
|
||||||
|
n_vocab += 1
|
||||||
|
|
||||||
|
return tiktoken.Encoding(
|
||||||
|
name=os.path.basename(vocab_path),
|
||||||
|
explicit_n_vocab=n_vocab,
|
||||||
|
pat_str=gpt2()["pat_str"],
|
||||||
|
mergeable_ranks=ranks,
|
||||||
|
special_tokens=special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def get_tokenizer(
|
def get_tokenizer(
|
||||||
multilingual: bool,
|
multilingual: bool,
|
||||||
*,
|
*,
|
||||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
|
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||||
) -> Tokenizer:
|
) -> Tokenizer:
|
||||||
if language is not None:
|
if language is not None:
|
||||||
language = language.lower()
|
language = language.lower()
|
||||||
@@ -354,27 +374,14 @@ def get_tokenizer(
|
|||||||
raise ValueError(f"Unsupported language: {language}")
|
raise ValueError(f"Unsupported language: {language}")
|
||||||
|
|
||||||
if multilingual:
|
if multilingual:
|
||||||
tokenizer_name = "multilingual"
|
encoding_name = "multilingual"
|
||||||
task = task or "transcribe"
|
|
||||||
language = language or "en"
|
language = language or "en"
|
||||||
|
task = task or "transcribe"
|
||||||
else:
|
else:
|
||||||
tokenizer_name = "gpt2"
|
encoding_name = "gpt2"
|
||||||
task = None
|
|
||||||
language = None
|
language = None
|
||||||
|
task = None
|
||||||
|
|
||||||
tokenizer = build_tokenizer(name=tokenizer_name)
|
encoding = get_encoding(name=encoding_name)
|
||||||
all_special_ids: List[int] = tokenizer.all_special_ids
|
|
||||||
sot: int = all_special_ids[1]
|
|
||||||
translate: int = all_special_ids[-6]
|
|
||||||
transcribe: int = all_special_ids[-5]
|
|
||||||
|
|
||||||
langs = tuple(LANGUAGES.keys())
|
return Tokenizer(encoding=encoding, language=language, task=task)
|
||||||
sot_sequence = [sot]
|
|
||||||
if language is not None:
|
|
||||||
sot_sequence.append(sot + 1 + langs.index(language))
|
|
||||||
if task is not None:
|
|
||||||
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
|
||||||
|
|
||||||
return Tokenizer(
|
|
||||||
tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
|
|
||||||
)
|
|
||||||
|
|||||||
+33
-31
@@ -11,6 +11,7 @@ from .audio import (
|
|||||||
FRAMES_PER_SECOND,
|
FRAMES_PER_SECOND,
|
||||||
HOP_LENGTH,
|
HOP_LENGTH,
|
||||||
N_FRAMES,
|
N_FRAMES,
|
||||||
|
N_SAMPLES,
|
||||||
SAMPLE_RATE,
|
SAMPLE_RATE,
|
||||||
log_mel_spectrogram,
|
log_mel_spectrogram,
|
||||||
pad_or_trim,
|
pad_or_trim,
|
||||||
@@ -116,7 +117,9 @@ def transcribe(
|
|||||||
if dtype == torch.float32:
|
if dtype == torch.float32:
|
||||||
decode_options["fp16"] = False
|
decode_options["fp16"] = False
|
||||||
|
|
||||||
mel = log_mel_spectrogram(audio)
|
# Pad 30-seconds of silence to the input audio, for slicing
|
||||||
|
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
||||||
|
content_frames = mel.shape[-1] - N_FRAMES
|
||||||
|
|
||||||
if decode_options.get("language", None) is None:
|
if decode_options.get("language", None) is None:
|
||||||
if not model.is_multilingual:
|
if not model.is_multilingual:
|
||||||
@@ -197,14 +200,14 @@ def transcribe(
|
|||||||
def new_segment(
|
def new_segment(
|
||||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||||
):
|
):
|
||||||
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
|
tokens = tokens.tolist()
|
||||||
|
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||||
return {
|
return {
|
||||||
"id": len(all_segments),
|
|
||||||
"seek": seek,
|
"seek": seek,
|
||||||
"start": start,
|
"start": start,
|
||||||
"end": end,
|
"end": end,
|
||||||
"text": tokenizer.decode(text_tokens),
|
"text": tokenizer.decode(text_tokens),
|
||||||
"tokens": text_tokens,
|
"tokens": tokens,
|
||||||
"temperature": result.temperature,
|
"temperature": result.temperature,
|
||||||
"avg_logprob": result.avg_logprob,
|
"avg_logprob": result.avg_logprob,
|
||||||
"compression_ratio": result.compression_ratio,
|
"compression_ratio": result.compression_ratio,
|
||||||
@@ -212,14 +215,13 @@ def transcribe(
|
|||||||
}
|
}
|
||||||
|
|
||||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||||
num_frames = mel.shape[-1]
|
|
||||||
with tqdm.tqdm(
|
with tqdm.tqdm(
|
||||||
total=num_frames, unit="frames", disable=verbose is not False
|
total=content_frames, unit="frames", disable=verbose is not False
|
||||||
) as pbar:
|
) as pbar:
|
||||||
while seek < num_frames:
|
while seek < content_frames:
|
||||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||||
mel_segment = mel[:, seek:]
|
mel_segment = mel[:, seek : seek + N_FRAMES]
|
||||||
segment_size = min(mel_segment.shape[-1], N_FRAMES)
|
segment_size = min(N_FRAMES, content_frames - seek)
|
||||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
@@ -243,23 +245,20 @@ def transcribe(
|
|||||||
|
|
||||||
previous_seek = seek
|
previous_seek = seek
|
||||||
current_segments = []
|
current_segments = []
|
||||||
current_tokens = []
|
|
||||||
|
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
|
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||||
0
|
|
||||||
].add_(1)
|
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||||
if (
|
consecutive.add_(1)
|
||||||
len(consecutive) > 0
|
if len(consecutive) > 0:
|
||||||
): # if the output contains two consecutive timestamp tokens
|
# if the output contains two consecutive timestamp tokens
|
||||||
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
|
slices = consecutive.tolist()
|
||||||
False,
|
if single_timestamp_ending:
|
||||||
True,
|
slices.append(len(tokens))
|
||||||
]:
|
|
||||||
consecutive = consecutive.tolist() + [len(tokens)]
|
|
||||||
|
|
||||||
last_slice = 0
|
last_slice = 0
|
||||||
for current_slice in consecutive:
|
for current_slice in slices:
|
||||||
sliced_tokens = tokens[last_slice:current_slice]
|
sliced_tokens = tokens[last_slice:current_slice]
|
||||||
start_timestamp_pos = (
|
start_timestamp_pos = (
|
||||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||||
@@ -275,10 +274,9 @@ def transcribe(
|
|||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
current_tokens.append(sliced_tokens.tolist())
|
|
||||||
last_slice = current_slice
|
last_slice = current_slice
|
||||||
|
|
||||||
if ended_with_single_timestamp:
|
if single_timestamp_ending:
|
||||||
# single timestamp at the end means no speech after the last timestamp.
|
# single timestamp at the end means no speech after the last timestamp.
|
||||||
seek += segment_size
|
seek += segment_size
|
||||||
else:
|
else:
|
||||||
@@ -287,7 +285,6 @@ def transcribe(
|
|||||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||||
)
|
)
|
||||||
seek += last_timestamp_pos * input_stride
|
seek += last_timestamp_pos * input_stride
|
||||||
all_tokens.extend(tokens[: last_slice + 1].tolist())
|
|
||||||
else:
|
else:
|
||||||
duration = segment_duration
|
duration = segment_duration
|
||||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||||
@@ -309,7 +306,6 @@ def transcribe(
|
|||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
current_tokens.append(tokens.tolist())
|
|
||||||
seek += segment_size
|
seek += segment_size
|
||||||
|
|
||||||
if not condition_on_previous_text or result.temperature > 0.5:
|
if not condition_on_previous_text or result.temperature > 0.5:
|
||||||
@@ -329,7 +325,7 @@ def transcribe(
|
|||||||
word_end_timestamps = [
|
word_end_timestamps = [
|
||||||
w["end"] for s in current_segments for w in s["words"]
|
w["end"] for s in current_segments for w in s["words"]
|
||||||
]
|
]
|
||||||
if len(consecutive) > 0 and len(word_end_timestamps) > 0:
|
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||||
seek_shift = round(
|
seek_shift = round(
|
||||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||||
)
|
)
|
||||||
@@ -348,15 +344,21 @@ def transcribe(
|
|||||||
segment["text"] = ""
|
segment["text"] = ""
|
||||||
segment["tokens"] = []
|
segment["tokens"] = []
|
||||||
segment["words"] = []
|
segment["words"] = []
|
||||||
current_tokens[i] = []
|
|
||||||
|
|
||||||
all_segments.extend(current_segments)
|
all_segments.extend(
|
||||||
|
[
|
||||||
|
{"id": i, **segment}
|
||||||
|
for i, segment in enumerate(
|
||||||
|
current_segments, start=len(all_segments)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
all_tokens.extend(
|
all_tokens.extend(
|
||||||
[token for segment in current_tokens for token in segment]
|
[token for segment in current_segments for token in segment["tokens"]]
|
||||||
)
|
)
|
||||||
|
|
||||||
# update progress bar
|
# update progress bar
|
||||||
pbar.update(min(num_frames, seek) - previous_seek)
|
pbar.update(min(content_frames, seek) - previous_seek)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||||
|
|||||||
+1
-1
@@ -1 +1 @@
|
|||||||
__version__ = "20230306"
|
__version__ = "20230314"
|
||||||
|
|||||||
Reference in New Issue
Block a user