Added main function as entry point

This commit is contained in:
Heiko J Schick
2022-08-30 16:28:44 +02:00
parent 5978cb7bcf
commit 494b15fd91
+44 -40
View File
@@ -10,53 +10,57 @@ import scipy
import numpy as np
import IPython.display as ipd
def main():
"""
Defined starting point of source code.
"""
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
"facebook/fastspeech2-en-ljspeech",
arg_overrides={"vocoder": "hifigan", "fp16": False}
)
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
"facebook/fastspeech2-en-ljspeech",
arg_overrides={"vocoder": "hifigan", "fp16": False}
)
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(models, cfg)
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(models, cfg)
full_wave_file = []
rate = 44100
sentences = []
full_wave_file = []
rate = 44100
sentences = []
# Read input file
with open(f"input.txt", "r") as f:
lines = f.readlines()
# Read input file
with open(f"input.txt", "r") as f:
lines = f.readlines()
# Convert to sentences
for line in lines:
line = line.replace("-", " - ")
line = line.replace("/", ", ")
line = line.replace("", ". ")
line = line.replace(":", ". ")
line = line.replace(";", ". ")
line = line.replace("(", ". ")
line = line.replace(")", ". ")
for x in line.split(". "):
# print(x)
sentences.append(x.strip())
# print(sentences)
sentences.append("<PAUSE>")
# print(sentences)
# Convert to sentences
for line in lines:
line = line.replace("-", " - ")
line = line.replace("/", ", ")
line = line.replace("", ". ")
line = line.replace(":", ". ")
line = line.replace(";", ". ")
line = line.replace("(", ". ")
line = line.replace(")", ". ")
for x in line.split(". "):
sentences.append(x.strip())
sentences.append("<PAUSE>")
# Synthesis text
for text in sentences:
if text == "":
continue
# Synthesis text
for text in sentences:
if text == "":
continue
if text == "<PAUSE>":
full_wave_file.extend(np.zeros(rate))
continue
if text == "<PAUSE>":
full_wave_file.extend(np.zeros(rate))
continue
sample = TTSHubInterface.get_model_input(task, text)
wav, rate = TTSHubInterface.get_prediction(task, models[0], generator, sample)
sample = TTSHubInterface.get_model_input(task, text)
wav, rate = TTSHubInterface.get_prediction(task, models[0], generator, sample)
wav = wav.numpy()
full_wave_file.extend(wav)
wav = wav.numpy()
full_wave_file.extend(wav)
full_wave_file = np.array(full_wave_file, dtype=np.float32)
scipy.io.wavfile.write("test.wav", rate, full_wave_file)
full_wave_file = np.array(full_wave_file, dtype=np.float32)
scipy.io.wavfile.write("test.wav", rate, full_wave_file)
if __name__ == "__main__":
main()