more cleanup, use 24KHz for preparing for VALL-E (encodec will resample to 24Khz anyways, makes audio a little nicer), some other things

This commit is contained in:
mrq 2023-03-23 01:52:26 +00:00
parent d2a9ab9e41
commit 0ea93a7f40

View File

@ -1269,6 +1269,12 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
os.makedirs(f'{indir}/audio/', exist_ok=True) os.makedirs(f'{indir}/audio/', exist_ok=True)
TARGET_SAMPLE_RATE = 22050
if args.tts_backend == "vall-e":
TARGET_SAMPLE_RATE = 24000
if tts:
TARGET_SAMPLE_RATE = tts.input_sample_rate
if os.path.exists(infile): if os.path.exists(infile):
results = json.load(open(infile, 'r', encoding="utf-8")) results = json.load(open(infile, 'r', encoding="utf-8"))
@ -1300,7 +1306,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
waveform, sample_rate = torchaudio.load(file) waveform, sample_rate = torchaudio.load(file)
# resample to the input rate, since it'll get resampled for training anyways # resample to the input rate, since it'll get resampled for training anyways
# this should also "help" increase throughput a bit when filling the dataloaders # this should also "help" increase throughput a bit when filling the dataloaders
waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050) waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
if waveform.shape[0] == 2: if waveform.shape[0] == 2:
waveform = waveform[:1] waveform = waveform[:1]
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16) torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
@ -1341,6 +1347,12 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
if results is None: if results is None:
results = json.load(open(infile, 'r', encoding="utf-8")) results = json.load(open(infile, 'r', encoding="utf-8"))
TARGET_SAMPLE_RATE = 22050
if args.tts_backend == "vall-e":
TARGET_SAMPLE_RATE = 24000
if tts:
TARGET_SAMPLE_RATE = tts.input_sample_rate
files = 0 files = 0
segments = 0 segments = 0
for filename in results: for filename in results:
@ -1369,12 +1381,12 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
print(message) print(message)
messages.append(message) messages.append(message)
continue continue
sliced, _ = resample( sliced, sample_rate, 22050 ) sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
if waveform.shape[0] == 2: if waveform.shape[0] == 2:
waveform = waveform[:1] waveform = waveform[:1]
torchaudio.save(f"{indir}/audio/{file}", sliced, 22050, encoding="PCM_S", bits_per_sample=16) torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16)
segments +=1 segments +=1
@ -1466,7 +1478,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
lines = { 'training': [], 'validation': [] } lines = { 'training': [], 'validation': [] }
segments = {} segments = {}
for filename in results: for filename in enumerate_progress(results, desc="Parsing results", progress=progress):
use_segment = use_segments use_segment = use_segments
result = results[filename] result = results[filename]
@ -1636,7 +1648,6 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized)) open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
print("Phonemized:", file) print("Phonemized:", file)
training_joined = "\n".join(lines['training']) training_joined = "\n".join(lines['training'])
validation_joined = "\n".join(lines['validation']) validation_joined = "\n".join(lines['validation'])
@ -2713,7 +2724,7 @@ def load_whisper_model(language=None, model_name=None, progress=None):
use_auth_token=args.hf_token, use_auth_token=args.hf_token,
device=torch.device(device), device=torch.device(device),
) )
whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token) # whisper_diarize = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",use_auth_token=args.hf_token)
except Exception as e: except Exception as e:
pass pass