1
0
Fork 0

rocm5.2 works for me desu so I bumped it back up

master
mrq 2023-03-11 17:02:56 +07:00
parent e680d84a13
commit e3fdb79b49
2 changed files with 6 additions and 6 deletions

@ -7,7 +7,7 @@ python3 -m venv venv
source ./venv/bin/activate
python3 -m pip install --upgrade pip # just to be safe
# ROCM
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 # 5.2 does not work for me desu
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
# install requirements
python3 -m pip install -r ./modules/tortoise-tts/requirements.txt # install TorToiSe requirements
python3 -m pip install -e ./modules/tortoise-tts/ # install TorToiSe

@ -1079,7 +1079,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_a
if match[0] not in previous_list:
previous_list.append(f'{match[0].split("/")[-1]}.wav')
def validate_waveform( waveform, sample_rate, name ):
def validate_waveform( waveform, sample_rate ):
if not torch.any(waveform < 0):
return False
@ -1102,8 +1102,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_a
num_channels, num_frames = waveform.shape
if not slice_audio:
if not validate_waveform( waveform, sampling_rate, name ):
print(f"Segment invalid: {name}, skipping...")
if not validate_waveform( waveform, sampling_rate ):
print(f"Invalid waveform: {basename}, skipping...")
continue
torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
@ -1120,8 +1120,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_a
sliced_waveform = waveform[:, start:end]
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
if not validate_waveform( sliced_waveform, sampling_rate, sliced_name ):
print(f"Trimmed segment invalid: {sliced_name}, skipping...")
if not validate_waveform( sliced_waveform, sampling_rate ):
print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
continue
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)