Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 55f690af79 | |||
| 7f1ef223ab | |||
| f5bfe004ec | |||
| da600abd2b | |||
| 9f7aba6099 | |||
| 12e1089462 | |||
| ea1c266709 | |||
| 8135a7c31c | |||
| 9d646db9d8 |
@@ -23,4 +23,4 @@ jobs:
|
|||||||
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
|
||||||
- run: pip install pytest
|
- run: pip install pytest
|
||||||
- run: pip install .
|
- run: pip install .
|
||||||
- run: pytest -k 'not test_transcribe or test_transcribe[tiny]'
|
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ def test_text_normalizer():
|
|||||||
assert std("he's like") == "he is like"
|
assert std("he's like") == "he is like"
|
||||||
assert std("she's been like") == "she has been like"
|
assert std("she's been like") == "she has been like"
|
||||||
assert std("10km") == "10 km"
|
assert std("10km") == "10 km"
|
||||||
|
assert std("10mm") == "10 mm"
|
||||||
assert std("RC232") == "rc 232"
|
assert std("RC232") == "rc 232"
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
+8
-3
@@ -94,9 +94,14 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
if download_root is None:
|
if download_root is None:
|
||||||
download_root = os.getenv(
|
download_root = os.path.join(
|
||||||
"XDG_CACHE_HOME",
|
os.getenv(
|
||||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
"XDG_CACHE_HOME",
|
||||||
|
os.path.join(
|
||||||
|
os.path.expanduser("~"), ".cache"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
"whisper"
|
||||||
)
|
)
|
||||||
|
|
||||||
if name in _MODELS:
|
if name in _MODELS:
|
||||||
|
|||||||
@@ -1737,6 +1737,5 @@
|
|||||||
"yoghurt": "yogurt",
|
"yoghurt": "yogurt",
|
||||||
"yoghurts": "yogurts",
|
"yoghurts": "yogurts",
|
||||||
"mhm": "hmm",
|
"mhm": "hmm",
|
||||||
"mm": "hmm",
|
|
||||||
"mmm": "hmm"
|
"mmm": "hmm"
|
||||||
}
|
}
|
||||||
+8
-17
@@ -1,7 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
from typing import Optional, Tuple, Union, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -10,7 +10,7 @@ import tqdm
|
|||||||
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
||||||
from .decoding import DecodingOptions, DecodingResult
|
from .decoding import DecodingOptions, DecodingResult
|
||||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||||
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
|
from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .model import Whisper
|
from .model import Whisper
|
||||||
@@ -165,7 +165,7 @@ def transcribe(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
|
||||||
|
|
||||||
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||||
num_frames = mel.shape[-1]
|
num_frames = mel.shape[-1]
|
||||||
@@ -255,6 +255,7 @@ def cli():
|
|||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
|
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||||
@@ -281,6 +282,7 @@ def cli():
|
|||||||
model_name: str = args.pop("model")
|
model_name: str = args.pop("model")
|
||||||
model_dir: str = args.pop("model_dir")
|
model_dir: str = args.pop("model_dir")
|
||||||
output_dir: str = args.pop("output_dir")
|
output_dir: str = args.pop("output_dir")
|
||||||
|
output_format: str = args.pop("output_format")
|
||||||
device: str = args.pop("device")
|
device: str = args.pop("device")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -303,22 +305,11 @@ def cli():
|
|||||||
from . import load_model
|
from . import load_model
|
||||||
model = load_model(model_name, device=device, download_root=model_dir)
|
model = load_model(model_name, device=device, download_root=model_dir)
|
||||||
|
|
||||||
|
writer = get_writer(output_format, output_dir)
|
||||||
|
|
||||||
for audio_path in args.pop("audio"):
|
for audio_path in args.pop("audio"):
|
||||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||||
|
writer(result, audio_path)
|
||||||
audio_basename = os.path.basename(audio_path)
|
|
||||||
|
|
||||||
# save TXT
|
|
||||||
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
|
|
||||||
write_txt(result["segments"], file=txt)
|
|
||||||
|
|
||||||
# save VTT
|
|
||||||
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
|
|
||||||
write_vtt(result["segments"], file=vtt)
|
|
||||||
|
|
||||||
# save SRT
|
|
||||||
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
|
||||||
write_srt(result["segments"], file=srt)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
+110
-35
@@ -1,5 +1,20 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import zlib
|
import zlib
|
||||||
from typing import Iterator, TextIO
|
from typing import Callable, TextIO
|
||||||
|
|
||||||
|
system_encoding = sys.getdefaultencoding()
|
||||||
|
|
||||||
|
if system_encoding != "utf-8":
|
||||||
|
def make_safe(string):
|
||||||
|
# replaces any character not representable using the system default encoding with an '?',
|
||||||
|
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
|
||||||
|
return string.encode(system_encoding, errors="replace").decode(system_encoding)
|
||||||
|
else:
|
||||||
|
def make_safe(string):
|
||||||
|
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
def exact_div(x, y):
|
def exact_div(x, y):
|
||||||
@@ -45,44 +60,104 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
|
|||||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||||
|
|
||||||
|
|
||||||
def write_txt(transcript: Iterator[dict], file: TextIO):
|
class ResultWriter:
|
||||||
for segment in transcript:
|
extension: str
|
||||||
print(segment['text'].strip(), file=file, flush=True)
|
|
||||||
|
def __init__(self, output_dir: str):
|
||||||
|
self.output_dir = output_dir
|
||||||
|
|
||||||
|
def __call__(self, result: dict, audio_path: str):
|
||||||
|
audio_basename = os.path.basename(audio_path)
|
||||||
|
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
|
||||||
|
|
||||||
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
|
self.write_result(result, file=f)
|
||||||
|
|
||||||
|
def write_result(self, result: dict, file: TextIO):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
class WriteTXT(ResultWriter):
|
||||||
print("WEBVTT\n", file=file)
|
extension: str = "txt"
|
||||||
for segment in transcript:
|
|
||||||
print(
|
def write_result(self, result: dict, file: TextIO):
|
||||||
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
for segment in result["segments"]:
|
||||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
print(segment['text'].strip(), file=file, flush=True)
|
||||||
file=file,
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def write_srt(transcript: Iterator[dict], file: TextIO):
|
class WriteVTT(ResultWriter):
|
||||||
|
extension: str = "vtt"
|
||||||
|
|
||||||
|
def write_result(self, result: dict, file: TextIO):
|
||||||
|
print("WEBVTT\n", file=file)
|
||||||
|
for segment in result["segments"]:
|
||||||
|
print(
|
||||||
|
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||||
|
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||||
|
file=file,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteSRT(ResultWriter):
|
||||||
|
extension: str = "srt"
|
||||||
|
|
||||||
|
def write_result(self, result: dict, file: TextIO):
|
||||||
|
for i, segment in enumerate(result["segments"], start=1):
|
||||||
|
# write srt lines
|
||||||
|
print(
|
||||||
|
f"{i}\n"
|
||||||
|
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||||
|
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||||
|
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||||
|
file=file,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteTSV(ResultWriter):
|
||||||
"""
|
"""
|
||||||
Write a transcript to a file in SRT format.
|
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
|
||||||
|
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
|
||||||
|
|
||||||
Example usage:
|
Using integer milliseconds as start and end times means there's no chance of interference from
|
||||||
from pathlib import Path
|
an environment setting a language encoding that causes the decimal in a floating point number
|
||||||
from whisper.utils import write_srt
|
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
|
||||||
|
|
||||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
|
||||||
|
|
||||||
# save SRT
|
|
||||||
audio_basename = Path(audio_path).stem
|
|
||||||
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
|
||||||
write_srt(result["segments"], file=srt)
|
|
||||||
"""
|
"""
|
||||||
for i, segment in enumerate(transcript, start=1):
|
extension: str = "tsv"
|
||||||
# write srt lines
|
|
||||||
print(
|
def write_result(self, result: dict, file: TextIO):
|
||||||
f"{i}\n"
|
print("start", "end", "text", sep="\t", file=file)
|
||||||
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
for segment in result["segments"]:
|
||||||
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
print(round(1000 * segment['start']), file=file, end="\t")
|
||||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
print(round(1000 * segment['end']), file=file, end="\t")
|
||||||
file=file,
|
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
|
class WriteJSON(ResultWriter):
|
||||||
|
extension: str = "json"
|
||||||
|
|
||||||
|
def write_result(self, result: dict, file: TextIO):
|
||||||
|
json.dump(result, file)
|
||||||
|
|
||||||
|
|
||||||
|
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
|
||||||
|
writers = {
|
||||||
|
"txt": WriteTXT,
|
||||||
|
"vtt": WriteVTT,
|
||||||
|
"srt": WriteSRT,
|
||||||
|
"tsv": WriteTSV,
|
||||||
|
"json": WriteJSON,
|
||||||
|
}
|
||||||
|
|
||||||
|
if output_format == "all":
|
||||||
|
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||||
|
|
||||||
|
def write_all(result: dict, file: TextIO):
|
||||||
|
for writer in all_writers:
|
||||||
|
writer(result, file)
|
||||||
|
|
||||||
|
return write_all
|
||||||
|
|
||||||
|
return writers[output_format](output_dir)
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -1 +1 @@
|
|||||||
__version__ = "20230117"
|
__version__ = "20230124"
|
||||||
|
|||||||
Reference in New Issue
Block a user