Add "dataset_debugger" support
This allows the datasets themselves to compile statistics and report them via tensorboard and wandb.
parent f3cab45658
commit f4484fd155
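In rough terms, the flow this enables: the dataset-specific debugger accumulates counters from each collated batch, and the trainer folds its debugging map into the same log dictionary that already feeds the tensorboard/wandb writers. A minimal, self-contained sketch of that wiring (ToyDebugger, the fake loader and the log dict are illustrative stand-ins, not code from this commit):

    class ToyDebugger:
        def __init__(self):
            self.seen = 0

        def update(self, batch):
            # Accumulate simple per-batch statistics.
            self.seen += len(batch)

        def get_debugging_map(self):
            return {'total_samples_loaded': self.seen}

    debugger = ToyDebugger()
    for step, batch in enumerate([[1, 2], [3, 4, 5]]):   # stand-in for a DataLoader
        debugger.update(batch)
        logs = {'loss': 0.1 * step}                       # stand-in for model.get_current_log(step)
        logs.update(debugger.get_debugging_map())
        print(step, logs)
        # In the real trainer the merged logs are handed to the existing
        # tensorboard/wandb writers, e.g. tb_writer.add_scalar(k, v, step)
        # or wandb.log(logs, step=step).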
@@ -71,13 +71,10 @@ def create_dataset(dataset_opt, return_collate=False):
         collate = C()
     elif mode == 'paired_voice_audio':
         from data.audio.paired_voice_audio_dataset import TextWavLoader as D
-        from data.audio.paired_voice_audio_dataset import TextMelCollate as C
         from models.tacotron2.hparams import create_hparams
         default_params = create_hparams()
         default_params.update(dataset_opt)
         dataset_opt = munchify(default_params)
-        if opt_get(dataset_opt, ['needs_collate'], True):
-            collate = C()
     elif mode == 'gpt_tts':
         from data.audio.gpt_tts_dataset import GptTtsDataset as D
         from data.audio.gpt_tts_dataset import GptTtsCollater as C
@@ -99,3 +96,11 @@ def create_dataset(dataset_opt, return_collate=False):
         return dataset, collate
     else:
         return dataset
+
+
+def get_dataset_debugger(dataset_opt):
+    mode = dataset_opt['mode']
+    if mode == 'paired_voice_audio':
+        from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
+        return PairedVoiceDebugger()
+    return None
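A hedged usage sketch for the new factory, assuming this repository and its dependencies are importable; the option dicts below carry only the 'mode' key the dispatcher looks at:

    from data import get_dataset_debugger

    debugger = get_dataset_debugger({'mode': 'paired_voice_audio'})
    print(type(debugger).__name__)                        # PairedVoiceDebugger

    print(get_dataset_debugger({'mode': 'gpt_tts'}))      # None -- no debugger for this mode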
@@ -120,7 +120,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         try:
             tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
             cond, cond_is_self = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
-                                      n=self.conditioning_candidates) if self.load_conditioning else None
+                                      n=self.conditioning_candidates) if self.load_conditioning else None, False
         except:
             if self.skipped_items > 100:
                 raise  # Rethrow if we have nested too far.
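Worth noting about the new fallback: a Python conditional expression binds tighter than the comma, so "x if c else None, False" always parses as "(x if c else None), False", and the second assignment target receives False regardless of the condition. Whether that matters here depends on what load_similar_clips returns, which is outside this diff. A tiny standalone illustration of the parse:

    flag = True
    cond, cond_is_self = 'clip' if flag else None, False
    print(cond, cond_is_self)                              # clip False
    cond, cond_is_self = ('clip' if flag else None), False
    print(cond, cond_is_self)                              # clip False -- identical parse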
@@ -162,6 +162,37 @@ class TextWavLoader(torch.utils.data.Dataset):
         return len(self.audiopaths_and_text)
 
 
+class PairedVoiceDebugger:
+    def __init__(self):
+        self.total_items = 0
+        self.loaded_items = 0
+        self.self_conditioning_items = 0
+
+    def get_state(self):
+        return {'total_items': self.total_items,
+                'loaded_items': self.loaded_items,
+                'self_conditioning_items': self.self_conditioning_items}
+
+    def load_state(self, state):
+        if isinstance(state, dict):
+            self.total_items = opt_get(state, ['total_items'], 0)
+            self.loaded_items = opt_get(state, ['loaded_items'], 0)
+            self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
+
+    def update(self, batch):
+        self.total_items += batch['wav'].shape[0]
+        self.loaded_items += batch['skipped_items'].sum().item()
+        if 'conditioning' in batch.keys():
+            self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
+
+    def get_debugging_map(self):
+        return {
+            'total_samples_loaded': self.total_items,
+            'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
+            'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
+        }
+
+
 if __name__ == '__main__':
     batch_sz = 8
     params = {
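A standalone sketch of how the debugger is exercised, assuming the repository is importable. The fake batch only mimics the keys the debugger reads ('wav', 'skipped_items', 'conditioning', 'conditioning_contains_self'); the shapes and values are made up:

    import torch
    from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger

    debugger = PairedVoiceDebugger()
    batch = {
        'wav': torch.zeros(8, 1, 22050),                   # pretend batch of 8 clips
        'skipped_items': torch.tensor([0, 0, 1, 0, 0, 0, 2, 0]),
        'conditioning': torch.zeros(8, 1, 22050),
        'conditioning_contains_self': torch.tensor([True, False, True, False,
                                                    False, False, False, False]),
    }
    debugger.update(batch)
    print(debugger.get_debugging_map())

    # The counters survive a save/resume cycle via get_state()/load_state():
    saved = debugger.get_state()
    restored = PairedVoiceDebugger()
    restored.load_state(saved)
    print(restored.get_debugging_map())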
@@ -12,7 +12,7 @@ from data.data_sampler import DistIterSampler
 from trainer.eval.evaluator import create_evaluator
 
 from utils import util, options as option
-from data import create_dataloader, create_dataset
+from data import create_dataloader, create_dataset, get_dataset_debugger
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
 from datetime import datetime
@@ -110,6 +110,9 @@ class Trainer:
         for phase, dataset_opt in opt['datasets'].items():
             if phase == 'train':
                 self.train_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
+                self.dataset_debugger = get_dataset_debugger(dataset_opt)
+                if self.dataset_debugger is not None and resume_state is not None:
+                    self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
                 train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
                 total_iters = int(opt['train']['niter'])
                 self.total_epochs = int(math.ceil(total_iters / train_size))
@@ -187,8 +190,12 @@ class Trainer:
             _t = time()
 
             #### log
+            if self.dataset_debugger is not None:
+                self.dataset_debugger.update(train_data)
             if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
                 logs = self.model.get_current_log(self.current_step)
+                if self.dataset_debugger is not None:
+                    logs.update(self.dataset_debugger.get_debugging_map())
                 message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
                 for v in self.model.get_current_learning_rate():
                     message += '{:.3e},'.format(v)
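The writer plumbing that consumes these merged logs lives elsewhere in the trainer and is untouched by this diff. For orientation, a standalone illustration of the reporting step, assuming tensorboard and wandb are installed; 'log_dir' and 'my_project' are placeholder values:

    from torch.utils.tensorboard import SummaryWriter
    import wandb

    logs = {'loss': 0.42, 'total_samples_loaded': 1024,
            'percent_conditioning_is_self': 0.31}
    step = 100

    writer = SummaryWriter('log_dir')
    for k, v in logs.items():
        writer.add_scalar(k, v, step)   # one scalar per entry in the merged log map
    writer.close()

    wandb.init(project='my_project')
    wandb.log(logs, step=step)          # wandb accepts the whole dict at once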
@@ -213,7 +220,10 @@ class Trainer:
                 if self.rank <= 0:
                     self.logger.info('Saving models and training states.')
                     self.model.save(self.current_step)
-                    self.model.save_training_state(self.epoch, self.current_step)
+                    state = {'epoch': self.epoch, 'iter': self.current_step}
+                    if self.dataset_debugger is not None:
+                        state['dataset_debugger_state'] = self.dataset_debugger.get_state()
+                    self.model.save_training_state(state)
                 if 'alt_path' in opt['path'].keys():
                     import shutil
                     print("Synchronizing tb_logger to alt_path..")
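The state dict built here is what later resumes feed back through load_state() (see the @@ -110,6 hunk above). A sketch of the round-trip with placeholder numbers; '5000.state' is an illustrative file name:

    import torch

    state = {'epoch': 3, 'iter': 5000,
             'dataset_debugger_state': {'total_items': 40000, 'loaded_items': 120,
                                        'self_conditioning_items': 9000}}
    torch.save(state, '5000.state')

    resume_state = torch.load('5000.state')
    # On resume the trainer does roughly:
    #     self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
    print(resume_state['dataset_debugger_state'])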
@@ -270,7 +270,7 @@ class ExtensibleTrainer(BaseModel):
                 if self.auto_recover is None:
                     print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
                     self.save(step)
-                    self.save_training_state(0, step)
+                    self.save_training_state({'iter': step})
                     raise ArithmeticError
                 else:
                     print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
@@ -127,16 +127,16 @@ class BaseModel():
                 load_net_clean[k] = v
         network.load_state_dict(load_net_clean, strict=strict)
 
-    def save_training_state(self, epoch, iter_step):
+    def save_training_state(self, state):
         """Save training state during training, which will be used for resuming"""
-        state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
+        state.update({'schedulers': [], 'optimizers': []})
         for s in self.schedulers:
             state['schedulers'].append(s.state_dict())
         for o in self.optimizers:
             state['optimizers'].append(o.state_dict())
         if 'amp_opt_level' in self.opt.keys():
             state['amp'] = amp.state_dict()
-        save_filename = '{}.state'.format(iter_step)
+        save_filename = '{}.state'.format(utils.util.opt_get(state, ['iter'], 'no_step_provided'))
         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
         torch.save(state, save_path)
         if '__state__' not in self.save_history.keys():
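The checkpoint file name now comes from the caller-supplied state via opt_get. For reference, an opt_get-style helper is just a nested lookup with a default; a minimal sketch of the assumed behaviour (the actual helper lives in utils/util.py and may differ in detail):

    def opt_get(opt, keys, default=None):
        # Walk a nested dict by a list of keys; fall back to the default
        # as soon as any key is missing.
        for k in keys:
            if not isinstance(opt, dict) or k not in opt:
                return default
            opt = opt[k]
        return opt

    print(opt_get({'iter': 5000}, ['iter'], 'no_step_provided'))   # 5000
    print(opt_get({}, ['iter'], 'no_step_provided'))               # no_step_provided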
@@ -61,7 +61,7 @@ if __name__ == "__main__":
 
     # Also convert the state.
     resume_state_from = torch.load(opt_from['path']['resume_state'])
-    resume_state_to = model_to.save_training_state(0, 0, return_state=True)
+    resume_state_to = model_to.save_training_state({}, return_state=True)
     resume_state_from['optimizers'][0]['param_groups'].append(resume_state_to['optimizers'][0]['param_groups'][-1])
     torch.save(resume_state_from, "converted_state.pth")
 