9 Commits

Author SHA1 Message Date
Jong Wook Kim 55f690af79 Release 20230124 2023-01-24 11:11:08 -08:00
Jong Wook Kim 7f1ef223ab handle printing even if sys.stdout.buffer is not available (#887) 2023-01-24 10:12:04 -08:00
Niels Mayer f5bfe004ec Add TSV formatted output in transcript, using integer start/end times in milliseconds. (#228)
* Add CSV format output in transcript, containing lines of characters formatted like: <startTime-in-integer-milliseconds>, <endTime-in-integer-milliseconds>, <transcript-including-commas>

* for easier reading by spreadsheets importing CSV, the third

column of the CSV file is delimited by quotes, and any quote
characters that might be in the transcript (which would interfere with
parsing the third column as a string) are converted to "''".

* fix syntax error

* docstring edit

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-22 00:27:17 -08:00
Aaryan YVS da600abd2b Added --output_format option (#333)
* Added --output option

--output option will help select the output files that will be generated.

Corrected the logic, which wrongly shows progress bar when verbose is set to False

* Changed output_files variable

* Changed back the tqdm verbose

* refactor output format handling

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-21 23:58:38 -08:00
zer0-x 9f7aba6099 Handle XDG_CACHE_HOME properly for download_root (#864)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-01-21 01:09:39 -08:00
Jong Wook Kim 12e1089462 use stdout for printing transcription progress (#867) 2023-01-20 00:54:05 -08:00
Markus Hennerbichler ea1c266709 Fix bug where mm is mistakenly replaced with hmm in e.g. 20mm (#659)
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
2023-01-18 10:41:11 -08:00
Jong Wook Kim 8135a7c31c verbose outputs from pytest 2023-01-18 10:30:18 -08:00
Jong Wook Kim 9d646db9d8 print '?' if a letter can't be encoded using the system default encoding (#859) 2023-01-17 23:28:36 -08:00
7 changed files with 129 additions and 58 deletions
+1 -1
View File
@@ -23,4 +23,4 @@ jobs:
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install pytest - run: pip install pytest
- run: pip install . - run: pip install .
- run: pytest -k 'not test_transcribe or test_transcribe[tiny]' - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
+1
View File
@@ -84,6 +84,7 @@ def test_text_normalizer():
assert std("he's like") == "he is like" assert std("he's like") == "he is like"
assert std("she's been like") == "she has been like" assert std("she's been like") == "she has been like"
assert std("10km") == "10 km" assert std("10km") == "10 km"
assert std("10mm") == "10 mm"
assert std("RC232") == "rc 232" assert std("RC232") == "rc 232"
assert ( assert (
+7 -2
View File
@@ -94,9 +94,14 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
if device is None: if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None: if download_root is None:
download_root = os.getenv( download_root = os.path.join(
os.getenv(
"XDG_CACHE_HOME", "XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache", "whisper") os.path.join(
os.path.expanduser("~"), ".cache"
)
),
"whisper"
) )
if name in _MODELS: if name in _MODELS:
-1
View File
@@ -1737,6 +1737,5 @@
"yoghurt": "yogurt", "yoghurt": "yogurt",
"yoghurts": "yogurts", "yoghurts": "yogurts",
"mhm": "hmm", "mhm": "hmm",
"mm": "hmm",
"mmm": "hmm" "mmm": "hmm"
} }
+8 -17
View File
@@ -1,7 +1,7 @@
import argparse import argparse
import os import os
import warnings import warnings
from typing import List, Optional, Tuple, Union, TYPE_CHECKING from typing import Optional, Tuple, Union, TYPE_CHECKING
import numpy as np import numpy as np
import torch import torch
@@ -10,7 +10,7 @@ import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
if TYPE_CHECKING: if TYPE_CHECKING:
from .model import Whisper from .model import Whisper
@@ -165,7 +165,7 @@ def transcribe(
} }
) )
if verbose: if verbose:
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}") print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
# show the progress bar when verbose is False (otherwise the transcribed text will be printed) # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1] num_frames = mel.shape[-1]
@@ -255,6 +255,7 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -281,6 +282,7 @@ def cli():
model_name: str = args.pop("model") model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir") model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir") output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device") device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
@@ -303,22 +305,11 @@ def cli():
from . import load_model from . import load_model
model = load_model(model_name, device=device, download_root=model_dir) model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args) result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path)
audio_basename = os.path.basename(audio_path)
# save TXT
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
write_txt(result["segments"], file=txt)
# save VTT
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt)
# save SRT
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
if __name__ == '__main__': if __name__ == '__main__':
+95 -20
View File
@@ -1,5 +1,20 @@
import json
import os
import sys
import zlib import zlib
from typing import Iterator, TextIO from typing import Callable, TextIO
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with an '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
def exact_div(x, y): def exact_div(x, y):
@@ -45,14 +60,37 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
def write_txt(transcript: Iterator[dict], file: TextIO): class ResultWriter:
for segment in transcript: extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str):
audio_basename = os.path.basename(audio_path)
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
def write_result(self, result: dict, file: TextIO):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(self, result: dict, file: TextIO):
for segment in result["segments"]:
print(segment['text'].strip(), file=file, flush=True) print(segment['text'].strip(), file=file, flush=True)
def write_vtt(transcript: Iterator[dict], file: TextIO): class WriteVTT(ResultWriter):
extension: str = "vtt"
def write_result(self, result: dict, file: TextIO):
print("WEBVTT\n", file=file) print("WEBVTT\n", file=file)
for segment in transcript: for segment in result["segments"]:
print( print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n", f"{segment['text'].strip().replace('-->', '->')}\n",
@@ -61,22 +99,11 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
) )
def write_srt(transcript: Iterator[dict], file: TextIO): class WriteSRT(ResultWriter):
""" extension: str = "srt"
Write a transcript to a file in SRT format.
Example usage: def write_result(self, result: dict, file: TextIO):
from pathlib import Path for i, segment in enumerate(result["segments"], start=1):
from whisper.utils import write_srt
result = transcribe(model, audio_path, temperature=temperature, **args)
# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines # write srt lines
print( print(
f"{i}\n" f"{i}\n"
@@ -86,3 +113,51 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
file=file, file=file,
flush=True, flush=True,
) )
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment['start']), file=file, end="\t")
print(round(1000 * segment['end']), file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(self, result: dict, file: TextIO):
json.dump(result, file)
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO):
for writer in all_writers:
writer(result, file)
return write_all
return writers[output_format](output_dir)
+1 -1
View File
@@ -1 +1 @@
__version__ = "20230117" __version__ = "20230124"