nocaptions -> nospeech to match the paper figure

This commit is contained in:
Jong Wook Kim
2022-09-23 15:45:32 +09:00
parent 61989529b7
commit 15ab548263
3 changed files with 27 additions and 39 deletions
+10 -22
View File
@@ -23,7 +23,7 @@ def transcribe(
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_captions_threshold: Optional[float] = 0.6,
no_speech_threshold: Optional[float] = 0.6,
**decode_options,
):
"""
@@ -50,8 +50,8 @@ def transcribe(
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_captions_threshold: float
If the no_captions probability is higher than this value AND the average log probability
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
decode_options: dict
@@ -148,7 +148,7 @@ def transcribe(
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_caption_prob": result.no_caption_prob,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
@@ -163,11 +163,11 @@ def transcribe(
result = decode_with_fallback(segment)[0]
tokens = torch.tensor(result.tokens)
if no_captions_threshold is not None:
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_caption_prob > no_captions_threshold
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
# don't skip if the logprob is high enough, despite the no_captions_prob
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
@@ -249,7 +249,7 @@ def cli():
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_caption_threshold", type=optional_float, default=0.6, help="if the probability of the <|nocaptions|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
@@ -261,12 +261,8 @@ def cli():
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
args["language"] = "en"
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
compression_ratio_threshold = args.pop("compression_ratio_threshold")
logprob_threshold = args.pop("logprob_threshold")
no_caption_threshold = args.pop("no_caption_threshold")
temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
else:
@@ -276,15 +272,7 @@ def cli():
model = load_model(model_name, device=device)
for audio_path in args.pop("audio"):
result = transcribe(
model,
audio_path,
temperature=temperature,
compression_ratio_threshold=compression_ratio_threshold,
logprob_threshold=logprob_threshold,
no_captions_threshold=no_caption_threshold,
**args,
)
result = transcribe(model, audio_path, temperature=temperature, **args)
audio_basename = os.path.basename(audio_path)