From f4484fd15588f5e8c5b896f075a75b372290630d Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 6 Jan 2022 12:38:20 -0700
Subject: [PATCH] Add "dataset_debugger" support

This allows the datasets themselves to compile statistics and report them
via tensorboard and wandb.
---
 codes/data/__init__.py                         | 11 +++++--
 .../data/audio/paired_voice_audio_dataset.py   | 33 ++++++++++++++++++-
 codes/train.py                                 | 14 ++++++--
 codes/trainer/ExtensibleTrainer.py             |  2 +-
 codes/trainer/base_model.py                    |  6 ++--
 codes/utils/convert_model.py                   |  2 +-
 6 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 1ad2762e..a12768da 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -71,13 +71,10 @@ def create_dataset(dataset_opt, return_collate=False):
         collate = C()
     elif mode == 'paired_voice_audio':
         from data.audio.paired_voice_audio_dataset import TextWavLoader as D
-        from data.audio.paired_voice_audio_dataset import TextMelCollate as C
         from models.tacotron2.hparams import create_hparams
         default_params = create_hparams()
         default_params.update(dataset_opt)
         dataset_opt = munchify(default_params)
-        if opt_get(dataset_opt, ['needs_collate'], True):
-            collate = C()
     elif mode == 'gpt_tts':
         from data.audio.gpt_tts_dataset import GptTtsDataset as D
         from data.audio.gpt_tts_dataset import GptTtsCollater as C
@@ -99,3 +96,11 @@ def create_dataset(dataset_opt, return_collate=False):
         return dataset, collate
     else:
         return dataset
+
+
+def get_dataset_debugger(dataset_opt):
+    mode = dataset_opt['mode']
+    if mode == 'paired_voice_audio':
+        from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
+        return PairedVoiceDebugger()
+    return None
\ No newline at end of file
diff --git a/codes/data/audio/paired_voice_audio_dataset.py b/codes/data/audio/paired_voice_audio_dataset.py
index 5990cc55..9d162c48 100644
--- a/codes/data/audio/paired_voice_audio_dataset.py
+++ b/codes/data/audio/paired_voice_audio_dataset.py
@@ -120,7 +120,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         try:
             tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
             cond, cond_is_self = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
-                                                    n=self.conditioning_candidates) if self.load_conditioning else None
+                                                    n=self.conditioning_candidates) if self.load_conditioning else (None, False)
         except:
             if self.skipped_items > 100:
                 raise # Rethrow if we have nested too far.
@@ -162,6 +162,37 @@ class TextWavLoader(torch.utils.data.Dataset):
         return len(self.audiopaths_and_text)
 
 
+class PairedVoiceDebugger:
+    def __init__(self):
+        self.total_items = 0
+        self.loaded_items = 0
+        self.self_conditioning_items = 0
+
+    def get_state(self):
+        return {'total_items': self.total_items,
+                'loaded_items': self.loaded_items,
+                'self_conditioning_items': self.self_conditioning_items}
+
+    def load_state(self, state):
+        if isinstance(state, dict):
+            self.total_items = opt_get(state, ['total_items'], 0)
+            self.loaded_items = opt_get(state, ['loaded_items'], 0)
+            self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
+
+    def update(self, batch):
+        self.total_items += batch['wav'].shape[0]
+        self.loaded_items += batch['skipped_items'].sum().item()
+        if 'conditioning' in batch.keys():
+            self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
+
+    def get_debugging_map(self):
+        return {
+            'total_samples_loaded': self.total_items,
+            'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
+            'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
+        }
+
+
 if __name__ == '__main__':
     batch_sz = 8
     params = {
diff --git a/codes/train.py b/codes/train.py
index 9f7349ce..a3bd9b35 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -12,7 +12,7 @@ from data.data_sampler import DistIterSampler
 from trainer.eval.evaluator import create_evaluator
 
 from utils import util, options as option
-from data import create_dataloader, create_dataset
+from data import create_dataloader, create_dataset, get_dataset_debugger
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
 from datetime import datetime
@@ -110,6 +110,9 @@ class Trainer:
         for phase, dataset_opt in opt['datasets'].items():
             if phase == 'train':
                 self.train_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
+                self.dataset_debugger = get_dataset_debugger(dataset_opt)
+                if self.dataset_debugger is not None and resume_state is not None:
+                    self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
                 train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
                 total_iters = int(opt['train']['niter'])
                 self.total_epochs = int(math.ceil(total_iters / train_size))
@@ -187,8 +190,12 @@ class Trainer:
                 _t = time()
 
                 #### log
+                if self.dataset_debugger is not None:
+                    self.dataset_debugger.update(train_data)
                 if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
                     logs = self.model.get_current_log(self.current_step)
+                    if self.dataset_debugger is not None:
+                        logs.update(self.dataset_debugger.get_debugging_map())
                     message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
                     for v in self.model.get_current_learning_rate():
                         message += '{:.3e},'.format(v)
@@ -213,7 +220,10 @@ class Trainer:
                 if self.rank <= 0:
                     self.logger.info('Saving models and training states.')
                     self.model.save(self.current_step)
-                    self.model.save_training_state(self.epoch, self.current_step)
+                    state = {'epoch': self.epoch, 'iter': self.current_step}
+                    if self.dataset_debugger is not None:
+                        state['dataset_debugger_state'] = self.dataset_debugger.get_state()
+                    self.model.save_training_state(state)
                 if 'alt_path' in opt['path'].keys():
                     import shutil
                     print("Synchronizing tb_logger to alt_path..")
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index af131394..1df4b507 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -270,7 +270,7 @@ class ExtensibleTrainer(BaseModel):
             if self.auto_recover is None:
                 print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
                 self.save(step)
-                self.save_training_state(0, step)
+                self.save_training_state({'iter': step})
                 raise ArithmeticError
             else:
                 print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py
index e66edcb8..b5563341 100644
--- a/codes/trainer/base_model.py
+++ b/codes/trainer/base_model.py
@@ -127,16 +127,16 @@ class BaseModel():
                 load_net_clean[k] = v
         network.load_state_dict(load_net_clean, strict=strict)
 
-    def save_training_state(self, epoch, iter_step):
+    def save_training_state(self, state):
         """Save training state during training, which will be used for resuming"""
-        state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
+        state.update({'schedulers': [], 'optimizers': []})
         for s in self.schedulers:
             state['schedulers'].append(s.state_dict())
         for o in self.optimizers:
             state['optimizers'].append(o.state_dict())
         if 'amp_opt_level' in self.opt.keys():
             state['amp'] = amp.state_dict()
-        save_filename = '{}.state'.format(iter_step)
+        save_filename = '{}.state'.format(utils.util.opt_get(state, ['iter'], 'no_step_provided'))
         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
         torch.save(state, save_path)
         if '__state__' not in self.save_history.keys():
diff --git a/codes/utils/convert_model.py b/codes/utils/convert_model.py
index fa0c0caf..10ada832 100644
--- a/codes/utils/convert_model.py
+++ b/codes/utils/convert_model.py
@@ -61,7 +61,7 @@ if __name__ == "__main__":
 
     # Also convert the state.
     resume_state_from = torch.load(opt_from['path']['resume_state'])
-    resume_state_to = model_to.save_training_state(0, 0, return_state=True)
+    resume_state_to = model_to.save_training_state({}, return_state=True)
    resume_state_from['optimizers'][0]['param_groups'].append(resume_state_to['optimizers'][0]['param_groups'][-1])
     torch.save(resume_state_from, "converted_state.pth")
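
A minimal sketch of the debugger lifecycle this patch wires up, assuming the patched codes/ directory is on PYTHONPATH; fake_batch is a hypothetical stand-in for the fields the real paired_voice_audio DataLoader yields:

import torch

from data import get_dataset_debugger  # factory added by this patch


def fake_batch(bsz=4):
    # Hypothetical stand-in: only the fields PairedVoiceDebugger.update() reads.
    return {
        'wav': torch.zeros(bsz, 1, 22050),
        # per-item attempt counts, as tracked by TextWavLoader.skipped_items
        'skipped_items': torch.ones(bsz, dtype=torch.long),
        'conditioning': torch.zeros(bsz, 1, 22050),
        'conditioning_contains_self': torch.zeros(bsz, dtype=torch.bool),
    }


debugger = get_dataset_debugger({'mode': 'paired_voice_audio'})
for _ in range(10):
    debugger.update(fake_batch())
print(debugger.get_debugging_map())  # what train.py merges into `logs` every print_freq steps

# Round-trip the state the way train.py does on save and resume:
saved = debugger.get_state()
resumed = get_dataset_debugger({'mode': 'paired_voice_audio'})
resumed.load_state(saved)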
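
The base_model.py change also moves checkpoint naming onto the state dict: the 'iter' key doubles as the filename stem, with a fallback for callers (the NaN-recovery path, convert_model.py) that omit fields. A small sketch of the new contract, under the same import assumption and with hypothetical values:

from utils.util import opt_get  # same helper the patched save_training_state() uses

# Callers now assemble the state dict themselves; save_training_state()
# appends 'schedulers' and 'optimizers' before serializing it.
state = {
    'epoch': 4,        # hypothetical values
    'iter': 52000,
    'dataset_debugger_state': {'total_items': 416000,
                               'loaded_items': 417300,
                               'self_conditioning_items': 12480},
}
# The filename stem comes from 'iter'; a bare dict still saves, just under
# a placeholder name.
print('{}.state'.format(opt_get(state, ['iter'], 'no_step_provided')))  # 52000.state
print('{}.state'.format(opt_get({}, ['iter'], 'no_step_provided')))     # no_step_provided.state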