apply formatting with black (#1038)
* applying black (with the default 88-column limit) * add flake8 * add isort * fix isort
This commit is contained in:
+104
-39
@@ -1,17 +1,32 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
|
||||
from .audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .timing import add_word_timestamps
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
|
||||
from .utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
@@ -29,8 +44,8 @@ def transcribe(
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"\'“¿([{-",
|
||||
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
@@ -108,12 +123,16 @@ def transcribe(
|
||||
decode_options["language"] = "en"
|
||||
else:
|
||||
if verbose:
|
||||
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
|
||||
print(
|
||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
||||
print(
|
||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||
)
|
||||
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
@@ -123,7 +142,9 @@ def transcribe(
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
|
||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
temperatures = (
|
||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
)
|
||||
decode_result = None
|
||||
|
||||
for t in temperatures:
|
||||
@@ -140,9 +161,15 @@ def transcribe(
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
needs_fallback = False
|
||||
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
|
||||
if (
|
||||
compression_ratio_threshold is not None
|
||||
and decode_result.compression_ratio > compression_ratio_threshold
|
||||
):
|
||||
needs_fallback = True # too repetitive
|
||||
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = True # average log probability is too low
|
||||
|
||||
if not needs_fallback:
|
||||
@@ -186,7 +213,9 @@ def transcribe(
|
||||
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
num_frames = mel.shape[-1]
|
||||
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
||||
with tqdm.tqdm(
|
||||
total=num_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
while seek < num_frames:
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
mel_segment = mel[:, seek:]
|
||||
@@ -201,7 +230,10 @@ def transcribe(
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and result.avg_logprob > logprob_threshold
|
||||
):
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
@@ -214,22 +246,35 @@ def transcribe(
|
||||
current_tokens = []
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
|
||||
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
|
||||
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
|
||||
0
|
||||
].add_(1)
|
||||
if (
|
||||
len(consecutive) > 0
|
||||
): # if the output contains two consecutive timestamp tokens
|
||||
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
|
||||
False,
|
||||
True,
|
||||
]:
|
||||
consecutive = consecutive.tolist() + [len(tokens)]
|
||||
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
current_segments.append(new_segment(
|
||||
start=time_offset + start_timestamp_pos * time_precision,
|
||||
end=time_offset + end_timestamp_pos * time_precision,
|
||||
tokens=sliced_tokens,
|
||||
result=result,
|
||||
))
|
||||
start_timestamp_pos = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_pos = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset + start_timestamp_pos * time_precision,
|
||||
end=time_offset + end_timestamp_pos * time_precision,
|
||||
tokens=sliced_tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
current_tokens.append(sliced_tokens.tolist())
|
||||
last_slice = current_slice
|
||||
|
||||
@@ -238,23 +283,32 @@ def transcribe(
|
||||
seek += segment_size
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
last_timestamp_pos = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_pos * input_stride
|
||||
all_tokens.extend(tokens[: last_slice + 1].tolist())
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
|
||||
if (
|
||||
len(timestamps) > 0
|
||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
||||
):
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
last_timestamp_pos = (
|
||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
duration = last_timestamp_pos * time_precision
|
||||
|
||||
current_segments.append(new_segment(
|
||||
start=time_offset,
|
||||
end=time_offset + duration,
|
||||
tokens=tokens,
|
||||
result=result,
|
||||
))
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset,
|
||||
end=time_offset + duration,
|
||||
tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
current_tokens.append(tokens.tolist())
|
||||
seek += segment_size
|
||||
|
||||
@@ -272,9 +326,13 @@ def transcribe(
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
)
|
||||
word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if len(consecutive) > 0 and len(word_end_timestamps) > 0:
|
||||
seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
||||
)
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
|
||||
@@ -293,21 +351,24 @@ def transcribe(
|
||||
current_tokens[i] = []
|
||||
|
||||
all_segments.extend(current_segments)
|
||||
all_tokens.extend([token for segment in current_tokens for token in segment])
|
||||
all_tokens.extend(
|
||||
[token for segment in current_tokens for token in segment]
|
||||
)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(num_frames, seek) - previous_seek)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
segments=all_segments,
|
||||
language=language
|
||||
language=language,
|
||||
)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
@@ -339,6 +400,7 @@ def cli():
|
||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
@@ -350,7 +412,9 @@ def cli():
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
@@ -363,6 +427,7 @@ def cli():
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
@@ -371,5 +436,5 @@ def cli():
|
||||
writer(result, audio_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
Reference in New Issue
Block a user