99 lines
3.1 KiB
Python
99 lines
3.1 KiB
Python
"""
|
|
Text-to-speech synthesis (TTS)
|
|
|
|
Sources:
|
|
— https://huggingface.co/facebook/fastspeech2-en-ljspeech
|
|
— https://github.com/AI-Guru/arxiv-reader
|
|
"""
|
|
import argparse
|
|
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
|
|
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
|
|
import scipy
|
|
import numpy as np
|
|
|
|
def read_input_file(name: str):
    """
    Read all lines from a UTF-8 encoded text file.

    Args:
        name: Path of the file to read.

    Returns:
        A list with one string per line; each line keeps its trailing
        newline character (if it had one).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8", but additionally
    # drops a leading byte-order mark. Without this, a BOM written by some
    # editors would end up as "\ufeff" at the start of the first line.
    with open(name, "r", encoding="utf-8-sig") as file:
        return file.readlines()
|
|
|
|
# Ordered (written form, spoken form) pairs applied to every input line.
# The replacements are plain substring substitutions executed top to bottom,
# so the order is part of the behaviour and must be kept.
_SPOKEN_FORMS = (
    # abbreviations (in alphabetical order)
    ("%", "per cent"),
    ("5G", "5 G"),
    ("CO2", "C O 2"),
    ("EUR", "Euro"),
    ("II", "2"),
    ("IBM", "I B M"),
    ("IMF", "I M F"),
    ("OECD", "O E C D"),
    ("UN", "U N"),
    ("USB", "U S B"),
    ("WHO", "W H O"),
    ("WTO", "W T O"),
    # compound words
    ("biotechnology", "bio technology"),
    ("Coronavirus", "Corona virus"),
    ("immunocompetence", "immuno competence"),
    # punctuation marks
    ("-", " - "),
    ("/", ", "),
    ("—", ". "),
    (":", ". "),
    (";", ". "),
    ("?", "?. "),
    ("(", ". "),
)

# Marker placed between input lines; rendered as one second of silence.
# A unique object is used so that no input text can be mistaken for it.
_PAUSE = object()

# Sample rate used only when no sentence was synthesised at all, i.e. when
# the model never reported its own rate.
_FALLBACK_RATE = 44100


def _to_spoken_form(line):
    """Rewrite abbreviations, compounds and punctuation of one line."""
    for written, spoken in _SPOKEN_FORMS:
        line = line.replace(written, spoken)
    return line


def _split_into_segments(lines):
    """Turn input lines into non-empty sentences, with _PAUSE after each line."""
    segments = []
    for line in lines:
        for sentence in _to_spoken_form(line).split(". "):
            sentence = sentence.strip()
            if sentence:
                segments.append(sentence)
        segments.append(_PAUSE)
    return segments


def main():
    """
    Defined starting point of source code.

    Parses the ``-i/--input`` and ``-o/--output`` command line arguments,
    reads the input text file, synthesises every sentence with the
    ``facebook/fastspeech2-en-ljspeech`` checkpoint (vocoder override
    ``hifigan``) and writes the concatenated audio as a float32 WAV file.
    One second of silence is inserted after every input line.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.io.wavfile`` is loaded.
    from scipy.io import wavfile

    # Parsing command line arguments
    parser = argparse.ArgumentParser(description='Convert text to speech.')
    parser.add_argument('-i', '--input', help='Input filename', required=True)
    parser.add_argument('-o', '--output', help='Output filename', required=True)
    args = vars(parser.parse_args())

    # Read and split the input before loading the model, so that a wrong
    # input path fails immediately instead of after the model load.
    segments = _split_into_segments(read_input_file(args['input']))

    models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
        "facebook/fastspeech2-en-ljspeech",
        arg_overrides={"vocoder": "hifigan", "fp16": False},
    )
    TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
    generator = task.build_generator(models, cfg)

    # Synthesise text. Pauses stay as markers for now because their length
    # in samples depends on the sample rate reported by the model.
    rate = None
    chunks = []
    for text in segments:
        if text is _PAUSE:
            chunks.append(_PAUSE)
            continue

        sample = TTSHubInterface.get_model_input(task, text)
        wav, rate = TTSHubInterface.get_prediction(
            task, models[0], generator, sample
        )
        # NOTE(review): assumes ``wav`` is a 1-D waveform tensor, as the
        # original sample-by-sample accumulation did -- confirm.
        chunks.append(np.asarray(wav.numpy(), dtype=np.float32))

    if rate is None:
        rate = _FALLBACK_RATE

    # ``rate`` samples of zeros are exactly one second at the output rate.
    silence = np.zeros(rate, dtype=np.float32)
    audio = [silence if chunk is _PAUSE else chunk for chunk in chunks]

    # np.concatenate rejects an empty list, so handle empty input explicitly.
    if audio:
        full_wave_file = np.concatenate(audio)
    else:
        full_wave_file = np.zeros(0, dtype=np.float32)

    wavfile.write(args['output'], rate, full_wave_file)
|
|
|
|
# Run the synthesis only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|