apply formatting with black (#1038)

* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
This commit is contained in:
Jong Wook Kim
2023-03-06 18:50:37 -05:00
committed by GitHub
parent 500d0fe966
commit b80bcf610d
21 changed files with 533 additions and 227 deletions
+104 -39
View File
@@ -1,17 +1,32 @@
import argparse
import os
import warnings
from typing import Optional, Tuple, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
from .utils import (
exact_div,
format_timestamp,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING:
from .model import Whisper
@@ -29,8 +44,8 @@ def transcribe(
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"\'“¿([{-",
append_punctuations: str = "\"\'.。,!?::”)]}、",
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
@@ -108,12 +123,16 @@ def transcribe(
decode_options["language"] = "en"
else:
if verbose:
print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
@@ -123,7 +142,9 @@ def transcribe(
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
@@ -140,9 +161,15 @@ def transcribe(
decode_result = model.decode(segment, options)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if not needs_fallback:
@@ -186,7 +213,9 @@ def transcribe(
# show the progress bar when verbose is False (if True, transcribed text will be printed)
num_frames = mel.shape[-1]
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
with tqdm.tqdm(
total=num_frames, unit="frames", disable=verbose is not False
) as pbar:
while seek < num_frames:
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek:]
@@ -201,7 +230,10 @@ def transcribe(
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
@@ -214,22 +246,35 @@ def transcribe(
current_tokens = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
0
].add_(1)
if (
len(consecutive) > 0
): # if the output contains two consecutive timestamp tokens
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
False,
True,
]:
consecutive = consecutive.tolist() + [len(tokens)]
last_slice = 0
for current_slice in consecutive:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
current_segments.append(new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
))
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
current_tokens.append(sliced_tokens.tolist())
last_slice = current_slice
@@ -238,23 +283,32 @@ def transcribe(
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append(new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
))
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
current_tokens.append(tokens.tolist())
seek += segment_size
@@ -272,9 +326,13 @@ def transcribe(
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
)
word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(consecutive) > 0 and len(word_end_timestamps) > 0:
seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
@@ -293,21 +351,24 @@ def transcribe(
current_tokens[i] = []
all_segments.extend(current_segments)
all_tokens.extend([token for segment in current_tokens for token in segment])
all_tokens.extend(
[token for segment in current_tokens for token in segment]
)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language
language=language,
)
def cli():
from . import available_models
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
@@ -339,6 +400,7 @@ def cli():
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
@@ -350,7 +412,9 @@ def cli():
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
warnings.warn(
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
)
args["language"] = "en"
temperature = args.pop("temperature")
@@ -363,6 +427,7 @@ def cli():
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
@@ -371,5 +436,5 @@ def cli():
writer(result, audio_path)
if __name__ == '__main__':
if __name__ == "__main__":
cli()