@@ -667,7 +667,7 @@ class TrainingState():
         self.steps = int(self.info['steps'])
 
         if 'iteration_rate' in self.info:
-            it_rate = self.info['iteration_rate']
+            it_rate = self.info['iteration_rate'] / self.batch_size # why
             self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
             self.it_rates += it_rate
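For context, the hunk above divides the reported iteration rate by the batch size, then formats it as either it/s or s/it depending on whether an iteration takes under or over one second. A minimal standalone sketch of that formatting (the helper name format_rate is hypothetical, not from the codebase):

    # Sketch of the rate formatting used above; format_rate is a hypothetical name.
    def format_rate(seconds_per_it: float) -> str:
        # under one second per iteration, report iterations per second instead
        if 0 < seconds_per_it < 1:
            return f"{1 / seconds_per_it:.3f}it/s"
        return f"{seconds_per_it:.3f}s/it"

    print(format_rate(0.25)) # 4.000it/s
    print(format_rate(2.5))  # 2.500s/it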
@@ -676,6 +676,7 @@ class TrainingState():
             eta = str(timedelta(seconds=int(self.eta)))
             self.eta_hhmmss = eta
         except Exception as e:
+            self.eta_hhmmss = "?"
             pass
 
         self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
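The added line gives self.eta_hhmmss a visible "?" fallback instead of silently keeping a stale value when the ETA cannot be converted. A small sketch of the same pattern, assuming the ETA arrives as a possibly-NaN float:

    from datetime import timedelta

    def eta_hhmmss(eta_seconds) -> str:
        try:
            return str(timedelta(seconds=int(eta_seconds)))
        except Exception:
            return "?"  # mirrors the new fallback above

    print(eta_hhmmss(3725))          # 1:02:05
    print(eta_hhmmss(float("nan")))  # ? (int() raises on NaN)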
@@ -1064,13 +1065,16 @@ def whisper_transcribe( file, language=None ):
     return result
 
-def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
+def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_audio=False, progress=None ):
     unload_tts()
 
     global whisper_model
     if whisper_model is None:
         load_whisper_model(language=language)
 
+    if args.whisper_backend == "m-bain/whisperx" and slice_audio:
+        print("! CAUTION ! Slicing audio with whisperx is terrible. Please consider using a different whisper backend if you want to slice audio.")
+
     os.makedirs(f'{outdir}/audio/', exist_ok=True)
 
     results = {}
@@ -1092,6 +1096,14 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
             if match[0] not in previous_list:
                 previous_list.append(f'{match[0].split("/")[-1]}.wav')
 
+    def validate_waveform( waveform, sample_rate, name ):
+        if not torch.any(waveform < 0):
+            return False
+        if waveform.shape[-1] < (.6 * sample_rate):
+            return False
+        return True
+
     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
         basename = os.path.basename(file)
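The new validate_waveform folds the two ad-hoc checks from the old slicing loop into one place. Note the silence test is a proxy: a clip containing no negative samples is treated as silent, which would also reject all-positive DC-offset audio. A quick standalone restatement of both heuristics on synthetic tensors (illustrative values only):

    import torch

    sample_rate = 22050
    clips = {
        "silent": torch.zeros(1, sample_rate),            # no negative samples -> rejected
        "short":  torch.randn(1, int(0.5 * sample_rate)), # under 0.6s -> rejected
        "good":   torch.randn(1, sample_rate),            # ~1s of noise -> accepted
    }
    for name, wav in clips.items():
        ok = bool(torch.any(wav < 0)) and wav.shape[-1] >= 0.6 * sample_rate
        print(name, ok)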
@@ -1106,29 +1118,36 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
         waveform, sampling_rate = torchaudio.load(file)
         num_channels, num_frames = waveform.shape
 
-        idx = 0
-        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
-            start = int(segment['start'] * sampling_rate)
-            end = int(segment['end'] * sampling_rate)
-
-            sliced_waveform = waveform[:, start:end]
-            sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
-
-            if not torch.any(sliced_waveform < 0):
-                print(f"Sound file is silent: {sliced_name}, skipping...")
-                continue
-
-            if sliced_waveform.shape[-1] < (.6 * sampling_rate):
-                print(f"Sound file is too short: {sliced_name}, skipping...")
-                continue
-
-            torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
-
-            idx = idx + 1
-            line = f"audio/{sliced_name}|{segment['text'].strip()}"
-            transcription.append(line)
-            with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-                f.write(f'\n{line}')
+        if not slice_audio:
+            if not validate_waveform( waveform, sampling_rate, basename ):
+                print(f"Segment invalid: {basename}, skipping...")
+                continue
+
+            torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
+            line = f"audio/{basename}|{result['text'].strip()}"
+            transcription.append(line)
+            with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
+                f.write(f'\n{line}')
+        else:
+            idx = 0
+            for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+                start = int(segment['start'] * sampling_rate)
+                end = int(segment['end'] * sampling_rate)
+
+                sliced_waveform = waveform[:, start:end]
+                sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
+
+                if not validate_waveform( sliced_waveform, sampling_rate, sliced_name ):
+                    print(f"Trimmed segment invalid: {sliced_name}, skipping...")
+                    continue
+
+                torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
+
+                idx = idx + 1
+                line = f"audio/{sliced_name}|{segment['text'].strip()}"
+                transcription.append(line)
+                with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
+                    f.write(f'\n{line}')
 
         do_gc()
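In the else branch above, whisper's per-segment timestamps (seconds) are converted to sample indices before slicing the [channels, samples] tensor. A sketch of just that conversion, with illustrative numbers:

    import torch

    sampling_rate = 22050
    waveform = torch.randn(1, 10 * sampling_rate)    # stand-in for a 10s mono clip

    segment = {"start": 1.25, "end": 3.5}            # whisper-style timestamps
    start = int(segment["start"] * sampling_rate)    # 27562
    end = int(segment["end"] * sampling_rate)        # 77175

    sliced = waveform[:, start:end]                  # keep channels, crop the time axis
    print(sliced.shape)                              # torch.Size([1, 49613])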
@@ -1144,7 +1163,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
     return f"Processed dataset to: {outdir}\n{joined}"
 
-def prepare_validation_dataset( voice, text_length ):
+def prepare_validation_dataset( voice, text_length, audio_length ):
     indir = f'./training/{voice}/'
     infile = f'{indir}/dataset.txt'
     if not os.path.exists(infile):
@@ -1166,8 +1185,14 @@ def prepare_validation_dataset( voice, text_length ):
         split = line.split("|")
         filename = split[0]
         text = split[1]
 
-        if len(text) < text_length:
+        culled = len(text) < text_length
+        if not culled and audio_length > 0:
+            metadata = torchaudio.info(f'{indir}/{filename}')
+            duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
+            culled = duration < audio_length
+
+        if culled:
             validation.append(line.strip())
         else:
             training.append(line.strip())
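One caveat with the duration computed above: torchaudio.info reports num_frames per channel, so num_frames / sample_rate is already the clip duration, and multiplying by num_channels appears to double it for stereo files (harmless for mono datasets, but worth noting). A sketch of the per-channel form, with a hypothetical path:

    import torchaudio

    metadata = torchaudio.info("training/example/audio/example_0000.wav")  # hypothetical path
    duration = metadata.num_frames / metadata.sample_rate                  # seconds
    print(f"{duration:.2f}s, {metadata.num_channels} channel(s)")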
@@ -1178,7 +1203,7 @@ def prepare_validation_dataset( voice, text_length ):
     with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
         f.write("\n".join(validation))
 
-    msg = f"Culled {len(validation)} lines"
+    msg = f"Culled {len(validation)}/{len(lines)} lines."
     print(msg)
     return msg
@@ -1896,6 +1921,9 @@ def load_tts( restart=False, autoregressive_model=None ):
     print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
 
+    if get_device_name() == "cpu":
+        print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.")
+
     tts_loading = True
     try:
         tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)