From 0ea93a7f40ffc5d7f520e620c236269835422c5c Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 23 Mar 2023 01:52:26 +0000
Subject: [PATCH] more cleanup, use 24KHz for preparing for VALL-E (encodec
 will resample to 24Khz anyways, makes audio a little nicer), some other
 things

---
 src/utils.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 80822e8..0f0cfe9 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -1269,6 +1269,12 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 
 	os.makedirs(f'{indir}/audio/', exist_ok=True)
 
+	TARGET_SAMPLE_RATE = 22050
+	if args.tts_backend == "vall-e":
+		TARGET_SAMPLE_RATE = 24000
+	if tts:
+		TARGET_SAMPLE_RATE = tts.input_sample_rate
+
 	if os.path.exists(infile):
 		results = json.load(open(infile, 'r', encoding="utf-8"))
 
@@ -1300,7 +1306,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 			waveform, sample_rate = torchaudio.load(file)
 			# resample to the input rate, since it'll get resampled for training anyways
 			# this should also "help" increase throughput a bit when filling the dataloaders
-			waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050)
+			waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
 			if waveform.shape[0] == 2:
 				waveform = waveform[:1]
 			torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
@@ -1341,6 +1347,12 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 	if results is None:
 		results = json.load(open(infile, 'r', encoding="utf-8"))
 
+	TARGET_SAMPLE_RATE = 22050
+	if args.tts_backend == "vall-e":
+		TARGET_SAMPLE_RATE = 24000
+	if tts:
+		TARGET_SAMPLE_RATE = tts.input_sample_rate
+
 	files = 0
 	segments = 0
 	for filename in results:
@@ -1369,12 +1381,12 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
 			print(message)
 			messages.append(message)
 			continue
-		sliced, _ = resample( sliced, sample_rate, 22050 )
+		sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
 
 		if waveform.shape[0] == 2:
 			waveform = waveform[:1]
 
-		torchaudio.save(f"{indir}/audio/{file}", sliced, 22050, encoding="PCM_S", bits_per_sample=16)
+		torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16)
 
 		segments +=1
 
@@ -1466,7 +1478,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 	lines = { 'training': [], 'validation': [] }
 	segments = {}
 
-	for filename in results:
+	for filename in enumerate_progress(results, desc="Parsing results", progress=progress):
 		use_segment = use_segments
 
 		result = results[filename]
@@ -1636,7 +1648,6 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 				open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
 				print("Phonemized:", file)
 
-
 	training_joined = "\n".join(lines['training'])
 	validation_joined = "\n".join(lines['validation'])
 
@@ -2713,7 +2724,7 @@ def load_whisper_model(language=None, model_name=None, progress=None):
 				use_auth_token=args.hf_token,
 				device=torch.device(device),
 			)
-			whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
+			# whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
 		except Exception as e:
 			pass
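
Note: the same three-branch rate selection now appears verbatim in both transcribe_dataset and slice_dataset. A minimal sketch of how it could be factored into one helper; the name get_target_sample_rate is hypothetical, not part of this patch, and it assumes the module-level args and the optional tts object used above:

  # hypothetical helper, not in the patch; mirrors the precedence above:
  # a loaded TTS object wins, then the VALL-E backend default, then 22050
  def get_target_sample_rate( tts=None ):
      if tts:
          # the loaded backend reports the rate it expects as input
          return tts.input_sample_rate
      if args.tts_backend == "vall-e":
          # EnCodec operates at 24kHz, so prepare audio at that rate up front
          return 24000
      # TorToiSe's expected input rate
      return 22050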
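Note: the resample() helper is defined elsewhere in src/utils.py; judging by the (waveform, sample_rate) tuple it returns at these call sites, it plausibly wraps torchaudio's resampler along these lines. This is a sketch under that assumption, not the repo's actual implementation:

  import torchaudio.functional as AF

  def resample( waveform, input_sample_rate, output_sample_rate ):
      # skip the work when the audio is already at the requested rate
      if input_sample_rate == output_sample_rate:
          return waveform, input_sample_rate
      waveform = AF.resample( waveform, input_sample_rate, output_sample_rate )
      return waveform, output_sample_rate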
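Note: both call sites reduce stereo input with waveform[:1], which keeps only the first channel; in slice_dataset the channel check also still inspects waveform while the tensor being saved is sliced. Averaging the channels is a common alternative downmix, shown here as a design choice, not what this patch does:

  # downmix multi-channel audio to mono by averaging channels,
  # rather than discarding everything but channel 0
  if waveform.shape[0] > 1:
      waveform = waveform.mean(dim=0, keepdim=True)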