diff --git a/models/.template.yaml b/models/.template.yaml index 79ea8a4..4758591 100755 --- a/models/.template.yaml +++ b/models/.template.yaml @@ -7,8 +7,6 @@ checkpointing_enabled: true fp16: ${half_p} bitsandbytes: ${bitsandbytes} gpus: ${gpus} -wandb: false -use_tb_logger: true datasets: train: @@ -135,8 +133,7 @@ eval: output_state: gen logger: - print_freq: ${print_rate} save_checkpoint_freq: ${save_rate} visuals: [gen, mel] - visual_debug_rate: ${print_rate} + visual_debug_rate: ${save_rate} is_mel_spectrogram: true \ No newline at end of file diff --git a/modules/dlas b/modules/dlas index bf94744..802c162 160000 --- a/modules/dlas +++ b/modules/dlas @@ -1 +1 @@ -Subproject commit bf94744514e0628c6e6ba21eda76d1fd71fb1252 +Subproject commit 802c162ce816ac9e824bd82f64f6282019ae15d5 diff --git a/src/utils.py b/src/utils.py index 35fa6eb..d171571 100755 --- a/src/utils.py +++ b/src/utils.py @@ -609,20 +609,10 @@ class TrainingState(): self.open_state = False self.training_started = False - self.info = {} - - self.epoch_rate = "" - self.epoch_time_start = 0 - self.epoch_time_end = 0 - self.epoch_time_deltas = 0 - self.epoch_taken = 0 + self.info = {} self.it_rate = "" - self.it_time_start = 0 - self.it_time_end = 0 - self.it_time_deltas = 0 - self.it_taken = 0 - self.last_step = 0 + self.it_rates = 0 self.eta = "?" self.eta_hhmmss = "?" @@ -655,11 +645,63 @@ class TrainingState(): print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) + def parse_metrics(self, data): + if isinstance(data, str): + if line.find('INFO: Training Metrics:') >= 0: + data = json.loads(line.split("INFO: Training Metrics:")[-1]) + data['mode'] = "training" + elif line.find('INFO: Validation Metrics:') >= 0: + data = json.loads(line.split("INFO: Validation Metrics:")[-1]) + data['mode'] = "validation" + else: + return + + self.info = data + if 'epoch' in self.info: + self.epoch = int(self.info['epoch']) + if 'it' in self.info: + self.it = int(self.info['it']) + if 'step' in self.info: + self.step = int(self.info['step']) + if 'steps' in self.info: + self.steps = int(self.info['steps']) + + if 'iteration_rate' in self.info: + it_rate = self.info['iteration_rate'] + self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it' + self.it_rates += it_rate + + self.eta = (self.its - self.it) * (self.it_rates / self.its) + try: + eta = str(timedelta(seconds=int(self.eta))) + self.eta_hhmmss = eta + except Exception as e: + pass + + self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] + if self.epochs != self.its: + self.metrics['step'].append(f"{self.it}/{self.its}") + if self.steps > 1: + self.metrics['step'].append(f"{self.step}/{self.steps}") + self.metrics['step'] = ", ".join(self.metrics['step']) + + if 'lr' in self.info: + self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate'}) + + for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']: + if k not in self.info: + continue + + self.statistics['loss'].append({'step': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' }) + if k == "loss_gpt_total": + self.losses.append( self.statistics['loss'][-1] ) + + return data + def load_statistics(self, update=False): if not os.path.isdir(f'{self.dataset_dir}/'): return - keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0'] infos = {} highest_step = self.last_info_check_at @@ -677,34 +719,23 @@ class TrainingState(): for line in lines: if line.find('INFO: Training Metrics:') >= 0: - data = line.split("INFO: Training Metrics:")[-1] - info = json.loads(data) - - step = info['it'] - if update and step <= self.last_info_check_at: - continue - - if 'lr' in info: - self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'}) - if 'loss_text_ce' in info: - self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'}) - if 'loss_mel_ce' in info: - self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'}) - if 'loss_gpt_total' in info: - self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'}) - self.losses.append( self.statistics['loss'][-1] ) - + data = json.loads(line.split("INFO: Training Metrics:")[-1]) + data['mode'] = "training" elif line.find('INFO: Validation Metrics:') >= 0: - data = line.split("INFO: Validation Metrics:")[-1] + data = json.loads(line.split("INFO: Validation Metrics:")[-1]) + data['mode'] = "validation" + else: + continue - step = info['it'] - if update and step <= self.last_info_check_at: - continue + if "it" not in data: + continue - if 'loss_text_ce' in info: - self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'}) - if 'loss_mel_ce' in info: - self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'}) + step = data['it'] + + if update and step <= self.last_info_check_at: + continue + + self.parse_metrics(data) self.last_info_check_at = highest_step @@ -741,10 +772,7 @@ class TrainingState(): # rip out iteration info elif not self.training_started: if line.find('Start training from epoch') >= 0: - self.it_time_start = time.time() - self.epoch_time_start = time.time() self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations - should_return = True match = re.findall(r'epoch: ([\d,]+)', line) if match and len(match) > 0: @@ -754,90 +782,51 @@ class TrainingState(): self.it = int(match[0].replace(",", "")) self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq']) + + should_return = True else: - lapsed = False message = None + data = None # INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028} if line.find('INFO: Training Metrics:') >= 0: - data = line.split("INFO: Training Metrics:")[-1] - self.info = json.loads(data) - - if 'epoch' in self.info: - self.epoch = int(self.info['epoch']) - if 'it' in self.info: - self.it = int(self.info['it']) - if 'step' in self.info: - self.step = int(self.info['step']) - if 'steps' in self.info: - self.steps = int(self.info['steps']) - - if self.step == self.steps: - lapsed = True - - if 'lr' in self.info: - self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'}) - if 'loss_text_ce' in self.info: - self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'}) - if 'loss_mel_ce' in self.info: - self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'}) - if 'loss_gpt_total' in self.info: - self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'}) - self.losses.append( self.statistics['loss'][-1] ) - - if 'iteration_rate' in self.info: - it_rate = self.info['iteration_rate'] - self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s' - - self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] - if self.epochs != self.its: - self.metrics['step'].append(f"{self.it}/{self.its}") - if self.steps > 1: - self.metrics['step'].append(f"{self.step}/{self.steps}") - self.metrics['step'] = ", ".join(self.metrics['step']) - - should_return = True + data = json.loads(line.split("INFO: Training Metrics:")[-1]) + data['mode'] = "training" elif line.find('INFO: Validation Metrics:') >= 0: - data = line.split("INFO: Validation Metrics:")[-1] + data = json.loads(line.split("INFO: Validation Metrics:")[-1]) + data['mode'] = "validation" - if 'loss_text_ce' in self.info: - self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'}) - if 'loss_mel_ce' in self.info: - self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'}) + if data is not None: + self.parse_metrics( data ) should_return = True - if lapsed: + if ': nan' in line and not self.nan_detected: + self.nan_detected = self.it + + """ + if self.step == self.steps and self.steps > 0: self.epoch_time_end = time.time() self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start self.epoch_time_start = time.time() try: - self.epoch_rate = f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' if self.epoch_time_delta >= 1 or self.epoch_time_delta == 0 else f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' # I doubt anyone will have it/s rates, but its here - except Exception as e: - pass - - #self.eta = (self.epochs - self.epoch) * self.epoch_time_delta - self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta - self.epoch_taken = self.epoch_taken + 1 - self.eta = (self.epochs - self.epoch) * (self.epoch_time_deltas / self.epoch_taken) - try: - eta = str(timedelta(seconds=int(self.eta))) - self.eta_hhmmss = eta + self.epoch_rate = f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' if 0 < self.epoch_time_delta and self.epoch_time_delta < 1 else f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' except Exception as e: pass + """ self.metrics['rate'] = [] + """ if self.epoch_rate: self.metrics['rate'].append(self.epoch_rate) if self.it_rate and self.epoch_rate != self.it_rate: + """ + if self.it_rate: self.metrics['rate'].append(self.it_rate) self.metrics['rate'] = ", ".join(self.metrics['rate']) - eta_hhmmss = "?" - if self.eta_hhmmss: - eta_hhmmss = self.eta_hhmmss - - self.metrics['loss'] = [] + eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?" + self.metrics['loss'] = [] if 'lr' in self.info: self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}') @@ -1126,7 +1115,11 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav") if not torch.any(sliced_waveform < 0): - print(f"Error with {sliced_name}, skipping...") + print(f"Sound file is silent: {sliced_name}, skipping...") + continue + + if sliced_waveform.shape[-1] < (.6 * sampling_rate): + print(f"Sound file is too short: {sliced_name}, skipping...") continue torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)