cleanups and fixes, fix DLAS throwing errors from '''too short of sound files''' by just culling them during transcription

This commit is contained in:
mrq 2023-03-11 01:19:49 +00:00
parent 7f2da0f5fb
commit 2feb6da0c0
3 changed files with 97 additions and 107 deletions

View File

@ -7,8 +7,6 @@ checkpointing_enabled: true
fp16: ${half_p} fp16: ${half_p}
bitsandbytes: ${bitsandbytes} bitsandbytes: ${bitsandbytes}
gpus: ${gpus} gpus: ${gpus}
wandb: false
use_tb_logger: true
datasets: datasets:
train: train:
@ -135,8 +133,7 @@ eval:
output_state: gen output_state: gen
logger: logger:
print_freq: ${print_rate}
save_checkpoint_freq: ${save_rate} save_checkpoint_freq: ${save_rate}
visuals: [gen, mel] visuals: [gen, mel]
visual_debug_rate: ${print_rate} visual_debug_rate: ${save_rate}
is_mel_spectrogram: true is_mel_spectrogram: true

@ -1 +1 @@
Subproject commit bf94744514e0628c6e6ba21eda76d1fd71fb1252 Subproject commit 802c162ce816ac9e824bd82f64f6282019ae15d5

View File

@ -609,20 +609,10 @@ class TrainingState():
self.open_state = False self.open_state = False
self.training_started = False self.training_started = False
self.info = {} self.info = {}
self.epoch_rate = ""
self.epoch_time_start = 0
self.epoch_time_end = 0
self.epoch_time_deltas = 0
self.epoch_taken = 0
self.it_rate = "" self.it_rate = ""
self.it_time_start = 0 self.it_rates = 0
self.it_time_end = 0
self.it_time_deltas = 0
self.it_taken = 0
self.last_step = 0
self.eta = "?" self.eta = "?"
self.eta_hhmmss = "?" self.eta_hhmmss = "?"
@ -655,11 +645,63 @@ class TrainingState():
print("Spawning process: ", " ".join(self.cmd)) print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def parse_metrics(self, data):
if isinstance(data, str):
if line.find('INFO: Training Metrics:') >= 0:
data = json.loads(line.split("INFO: Training Metrics:")[-1])
data['mode'] = "training"
elif line.find('INFO: Validation Metrics:') >= 0:
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
data['mode'] = "validation"
else:
return
self.info = data
if 'epoch' in self.info:
self.epoch = int(self.info['epoch'])
if 'it' in self.info:
self.it = int(self.info['it'])
if 'step' in self.info:
self.step = int(self.info['step'])
if 'steps' in self.info:
self.steps = int(self.info['steps'])
if 'iteration_rate' in self.info:
it_rate = self.info['iteration_rate']
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
self.it_rates += it_rate
self.eta = (self.its - self.it) * (self.it_rates / self.its)
try:
eta = str(timedelta(seconds=int(self.eta)))
self.eta_hhmmss = eta
except Exception as e:
pass
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
if self.epochs != self.its:
self.metrics['step'].append(f"{self.it}/{self.its}")
if self.steps > 1:
self.metrics['step'].append(f"{self.step}/{self.steps}")
self.metrics['step'] = ", ".join(self.metrics['step'])
if 'lr' in self.info:
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
if k not in self.info:
continue
self.statistics['loss'].append({'step': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
if k == "loss_gpt_total":
self.losses.append( self.statistics['loss'][-1] )
return data
def load_statistics(self, update=False): def load_statistics(self, update=False):
if not os.path.isdir(f'{self.dataset_dir}/'): if not os.path.isdir(f'{self.dataset_dir}/'):
return return
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
infos = {} infos = {}
highest_step = self.last_info_check_at highest_step = self.last_info_check_at
@ -677,34 +719,23 @@ class TrainingState():
for line in lines: for line in lines:
if line.find('INFO: Training Metrics:') >= 0: if line.find('INFO: Training Metrics:') >= 0:
data = line.split("INFO: Training Metrics:")[-1] data = json.loads(line.split("INFO: Training Metrics:")[-1])
info = json.loads(data) data['mode'] = "training"
step = info['it']
if update and step <= self.last_info_check_at:
continue
if 'lr' in info:
self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'})
if 'loss_text_ce' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'})
if 'loss_mel_ce' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'})
if 'loss_gpt_total' in info:
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'})
self.losses.append( self.statistics['loss'][-1] )
elif line.find('INFO: Validation Metrics:') >= 0: elif line.find('INFO: Validation Metrics:') >= 0:
data = line.split("INFO: Validation Metrics:")[-1] data = json.loads(line.split("INFO: Validation Metrics:")[-1])
data['mode'] = "validation"
else:
continue
step = info['it'] if "it" not in data:
if update and step <= self.last_info_check_at: continue
continue
if 'loss_text_ce' in info: step = data['it']
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
if 'loss_mel_ce' in info: if update and step <= self.last_info_check_at:
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'}) continue
self.parse_metrics(data)
self.last_info_check_at = highest_step self.last_info_check_at = highest_step
@ -741,10 +772,7 @@ class TrainingState():
# rip out iteration info # rip out iteration info
elif not self.training_started: elif not self.training_started:
if line.find('Start training from epoch') >= 0: if line.find('Start training from epoch') >= 0:
self.it_time_start = time.time()
self.epoch_time_start = time.time()
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
should_return = True
match = re.findall(r'epoch: ([\d,]+)', line) match = re.findall(r'epoch: ([\d,]+)', line)
if match and len(match) > 0: if match and len(match) > 0:
@ -754,90 +782,51 @@ class TrainingState():
self.it = int(match[0].replace(",", "")) self.it = int(match[0].replace(",", ""))
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq']) self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
should_return = True
else: else:
lapsed = False
message = None message = None
data = None
# INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028} # INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
if line.find('INFO: Training Metrics:') >= 0: if line.find('INFO: Training Metrics:') >= 0:
data = line.split("INFO: Training Metrics:")[-1] data = json.loads(line.split("INFO: Training Metrics:")[-1])
self.info = json.loads(data) data['mode'] = "training"
if 'epoch' in self.info:
self.epoch = int(self.info['epoch'])
if 'it' in self.info:
self.it = int(self.info['it'])
if 'step' in self.info:
self.step = int(self.info['step'])
if 'steps' in self.info:
self.steps = int(self.info['steps'])
if self.step == self.steps:
lapsed = True
if 'lr' in self.info:
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'})
if 'loss_text_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'})
if 'loss_mel_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'})
if 'loss_gpt_total' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'})
self.losses.append( self.statistics['loss'][-1] )
if 'iteration_rate' in self.info:
it_rate = self.info['iteration_rate']
self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s'
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
if self.epochs != self.its:
self.metrics['step'].append(f"{self.it}/{self.its}")
if self.steps > 1:
self.metrics['step'].append(f"{self.step}/{self.steps}")
self.metrics['step'] = ", ".join(self.metrics['step'])
should_return = True
elif line.find('INFO: Validation Metrics:') >= 0: elif line.find('INFO: Validation Metrics:') >= 0:
data = line.split("INFO: Validation Metrics:")[-1] data = json.loads(line.split("INFO: Validation Metrics:")[-1])
data['mode'] = "validation"
if 'loss_text_ce' in self.info: if data is not None:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'}) self.parse_metrics( data )
if 'loss_mel_ce' in self.info:
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
should_return = True should_return = True
if lapsed: if ': nan' in line and not self.nan_detected:
self.nan_detected = self.it
"""
if self.step == self.steps and self.steps > 0:
self.epoch_time_end = time.time() self.epoch_time_end = time.time()
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
self.epoch_time_start = time.time() self.epoch_time_start = time.time()
try: try:
self.epoch_rate = f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' if self.epoch_time_delta >= 1 or self.epoch_time_delta == 0 else f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' # I doubt anyone will have it/s rates, but its here self.epoch_rate = f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' if 0 < self.epoch_time_delta and self.epoch_time_delta < 1 else f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch'
except Exception as e:
pass
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta
self.epoch_taken = self.epoch_taken + 1
self.eta = (self.epochs - self.epoch) * (self.epoch_time_deltas / self.epoch_taken)
try:
eta = str(timedelta(seconds=int(self.eta)))
self.eta_hhmmss = eta
except Exception as e: except Exception as e:
pass pass
"""
self.metrics['rate'] = [] self.metrics['rate'] = []
"""
if self.epoch_rate: if self.epoch_rate:
self.metrics['rate'].append(self.epoch_rate) self.metrics['rate'].append(self.epoch_rate)
if self.it_rate and self.epoch_rate != self.it_rate: if self.it_rate and self.epoch_rate != self.it_rate:
"""
if self.it_rate:
self.metrics['rate'].append(self.it_rate) self.metrics['rate'].append(self.it_rate)
self.metrics['rate'] = ", ".join(self.metrics['rate']) self.metrics['rate'] = ", ".join(self.metrics['rate'])
eta_hhmmss = "?" eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"
if self.eta_hhmmss:
eta_hhmmss = self.eta_hhmmss
self.metrics['loss'] = []
self.metrics['loss'] = []
if 'lr' in self.info: if 'lr' in self.info:
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}') self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
@ -1126,7 +1115,11 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav") sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
if not torch.any(sliced_waveform < 0): if not torch.any(sliced_waveform < 0):
print(f"Error with {sliced_name}, skipping...") print(f"Sound file is silent: {sliced_name}, skipping...")
continue
if sliced_waveform.shape[-1] < (.6 * sampling_rate):
print(f"Sound file is too short: {sliced_name}, skipping...")
continue continue
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate) torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)