forked from camenduru/ai-voice-cloning
cleanups and fixes, fix DLAS throwing errors from '''too short of sound files''' by just culling them during transcription
This commit is contained in:
parent
7f2da0f5fb
commit
2feb6da0c0
|
@ -7,8 +7,6 @@ checkpointing_enabled: true
|
|||
fp16: ${half_p}
|
||||
bitsandbytes: ${bitsandbytes}
|
||||
gpus: ${gpus}
|
||||
wandb: false
|
||||
use_tb_logger: true
|
||||
|
||||
datasets:
|
||||
train:
|
||||
|
@ -135,8 +133,7 @@ eval:
|
|||
output_state: gen
|
||||
|
||||
logger:
|
||||
print_freq: ${print_rate}
|
||||
save_checkpoint_freq: ${save_rate}
|
||||
visuals: [gen, mel]
|
||||
visual_debug_rate: ${print_rate}
|
||||
visual_debug_rate: ${save_rate}
|
||||
is_mel_spectrogram: true
|
|
@ -1 +1 @@
|
|||
Subproject commit bf94744514e0628c6e6ba21eda76d1fd71fb1252
|
||||
Subproject commit 802c162ce816ac9e824bd82f64f6282019ae15d5
|
197
src/utils.py
197
src/utils.py
|
@ -609,20 +609,10 @@ class TrainingState():
|
|||
self.open_state = False
|
||||
self.training_started = False
|
||||
|
||||
self.info = {}
|
||||
|
||||
self.epoch_rate = ""
|
||||
self.epoch_time_start = 0
|
||||
self.epoch_time_end = 0
|
||||
self.epoch_time_deltas = 0
|
||||
self.epoch_taken = 0
|
||||
self.info = {}
|
||||
|
||||
self.it_rate = ""
|
||||
self.it_time_start = 0
|
||||
self.it_time_end = 0
|
||||
self.it_time_deltas = 0
|
||||
self.it_taken = 0
|
||||
self.last_step = 0
|
||||
self.it_rates = 0
|
||||
|
||||
self.eta = "?"
|
||||
self.eta_hhmmss = "?"
|
||||
|
@ -655,11 +645,63 @@ class TrainingState():
|
|||
print("Spawning process: ", " ".join(self.cmd))
|
||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
||||
def parse_metrics(self, data):
|
||||
if isinstance(data, str):
|
||||
if line.find('INFO: Training Metrics:') >= 0:
|
||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||
data['mode'] = "training"
|
||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||
data['mode'] = "validation"
|
||||
else:
|
||||
return
|
||||
|
||||
self.info = data
|
||||
if 'epoch' in self.info:
|
||||
self.epoch = int(self.info['epoch'])
|
||||
if 'it' in self.info:
|
||||
self.it = int(self.info['it'])
|
||||
if 'step' in self.info:
|
||||
self.step = int(self.info['step'])
|
||||
if 'steps' in self.info:
|
||||
self.steps = int(self.info['steps'])
|
||||
|
||||
if 'iteration_rate' in self.info:
|
||||
it_rate = self.info['iteration_rate']
|
||||
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
|
||||
self.it_rates += it_rate
|
||||
|
||||
self.eta = (self.its - self.it) * (self.it_rates / self.its)
|
||||
try:
|
||||
eta = str(timedelta(seconds=int(self.eta)))
|
||||
self.eta_hhmmss = eta
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
||||
if self.epochs != self.its:
|
||||
self.metrics['step'].append(f"{self.it}/{self.its}")
|
||||
if self.steps > 1:
|
||||
self.metrics['step'].append(f"{self.step}/{self.steps}")
|
||||
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||
|
||||
if 'lr' in self.info:
|
||||
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
|
||||
|
||||
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||
if k == "loss_gpt_total":
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
return data
|
||||
|
||||
def load_statistics(self, update=False):
|
||||
if not os.path.isdir(f'{self.dataset_dir}/'):
|
||||
return
|
||||
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
|
||||
infos = {}
|
||||
highest_step = self.last_info_check_at
|
||||
|
||||
|
@ -677,34 +719,23 @@ class TrainingState():
|
|||
|
||||
for line in lines:
|
||||
if line.find('INFO: Training Metrics:') >= 0:
|
||||
data = line.split("INFO: Training Metrics:")[-1]
|
||||
info = json.loads(data)
|
||||
|
||||
step = info['it']
|
||||
if update and step <= self.last_info_check_at:
|
||||
continue
|
||||
|
||||
if 'lr' in info:
|
||||
self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'})
|
||||
if 'loss_text_ce' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'})
|
||||
if 'loss_mel_ce' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'})
|
||||
if 'loss_gpt_total' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'})
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||
data['mode'] = "training"
|
||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||
data = line.split("INFO: Validation Metrics:")[-1]
|
||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||
data['mode'] = "validation"
|
||||
else:
|
||||
continue
|
||||
|
||||
step = info['it']
|
||||
if update and step <= self.last_info_check_at:
|
||||
continue
|
||||
if "it" not in data:
|
||||
continue
|
||||
|
||||
if 'loss_text_ce' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
|
||||
if 'loss_mel_ce' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
|
||||
step = data['it']
|
||||
|
||||
if update and step <= self.last_info_check_at:
|
||||
continue
|
||||
|
||||
self.parse_metrics(data)
|
||||
|
||||
self.last_info_check_at = highest_step
|
||||
|
||||
|
@ -741,10 +772,7 @@ class TrainingState():
|
|||
# rip out iteration info
|
||||
elif not self.training_started:
|
||||
if line.find('Start training from epoch') >= 0:
|
||||
self.it_time_start = time.time()
|
||||
self.epoch_time_start = time.time()
|
||||
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||
should_return = True
|
||||
|
||||
match = re.findall(r'epoch: ([\d,]+)', line)
|
||||
if match and len(match) > 0:
|
||||
|
@ -754,90 +782,51 @@ class TrainingState():
|
|||
self.it = int(match[0].replace(",", ""))
|
||||
|
||||
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
|
||||
|
||||
should_return = True
|
||||
else:
|
||||
lapsed = False
|
||||
message = None
|
||||
data = None
|
||||
|
||||
# INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
|
||||
if line.find('INFO: Training Metrics:') >= 0:
|
||||
data = line.split("INFO: Training Metrics:")[-1]
|
||||
self.info = json.loads(data)
|
||||
|
||||
if 'epoch' in self.info:
|
||||
self.epoch = int(self.info['epoch'])
|
||||
if 'it' in self.info:
|
||||
self.it = int(self.info['it'])
|
||||
if 'step' in self.info:
|
||||
self.step = int(self.info['step'])
|
||||
if 'steps' in self.info:
|
||||
self.steps = int(self.info['steps'])
|
||||
|
||||
if self.step == self.steps:
|
||||
lapsed = True
|
||||
|
||||
if 'lr' in self.info:
|
||||
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'})
|
||||
if 'loss_text_ce' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'})
|
||||
if 'loss_mel_ce' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'})
|
||||
if 'loss_gpt_total' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'})
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
if 'iteration_rate' in self.info:
|
||||
it_rate = self.info['iteration_rate']
|
||||
self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s'
|
||||
|
||||
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
||||
if self.epochs != self.its:
|
||||
self.metrics['step'].append(f"{self.it}/{self.its}")
|
||||
if self.steps > 1:
|
||||
self.metrics['step'].append(f"{self.step}/{self.steps}")
|
||||
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||
|
||||
should_return = True
|
||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||
data['mode'] = "training"
|
||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||
data = line.split("INFO: Validation Metrics:")[-1]
|
||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||
data['mode'] = "validation"
|
||||
|
||||
if 'loss_text_ce' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
|
||||
if 'loss_mel_ce' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
|
||||
if data is not None:
|
||||
self.parse_metrics( data )
|
||||
should_return = True
|
||||
|
||||
if lapsed:
|
||||
if ': nan' in line and not self.nan_detected:
|
||||
self.nan_detected = self.it
|
||||
|
||||
"""
|
||||
if self.step == self.steps and self.steps > 0:
|
||||
self.epoch_time_end = time.time()
|
||||
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
|
||||
self.epoch_time_start = time.time()
|
||||
try:
|
||||
self.epoch_rate = f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' if self.epoch_time_delta >= 1 or self.epoch_time_delta == 0 else f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' # I doubt anyone will have it/s rates, but its here
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
|
||||
self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta
|
||||
self.epoch_taken = self.epoch_taken + 1
|
||||
self.eta = (self.epochs - self.epoch) * (self.epoch_time_deltas / self.epoch_taken)
|
||||
try:
|
||||
eta = str(timedelta(seconds=int(self.eta)))
|
||||
self.eta_hhmmss = eta
|
||||
self.epoch_rate = f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' if 0 < self.epoch_time_delta and self.epoch_time_delta < 1 else f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch'
|
||||
except Exception as e:
|
||||
pass
|
||||
"""
|
||||
|
||||
self.metrics['rate'] = []
|
||||
"""
|
||||
if self.epoch_rate:
|
||||
self.metrics['rate'].append(self.epoch_rate)
|
||||
if self.it_rate and self.epoch_rate != self.it_rate:
|
||||
"""
|
||||
if self.it_rate:
|
||||
self.metrics['rate'].append(self.it_rate)
|
||||
self.metrics['rate'] = ", ".join(self.metrics['rate'])
|
||||
|
||||
eta_hhmmss = "?"
|
||||
if self.eta_hhmmss:
|
||||
eta_hhmmss = self.eta_hhmmss
|
||||
|
||||
self.metrics['loss'] = []
|
||||
eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"
|
||||
|
||||
self.metrics['loss'] = []
|
||||
if 'lr' in self.info:
|
||||
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
|
||||
|
||||
|
@ -1126,7 +1115,11 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
|||
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
||||
|
||||
if not torch.any(sliced_waveform < 0):
|
||||
print(f"Error with {sliced_name}, skipping...")
|
||||
print(f"Sound file is silent: {sliced_name}, skipping...")
|
||||
continue
|
||||
|
||||
if sliced_waveform.shape[-1] < (.6 * sampling_rate):
|
||||
print(f"Sound file is too short: {sliced_name}, skipping...")
|
||||
continue
|
||||
|
||||
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
|
||||
|
|
Loading…
Reference in New Issue
Block a user