diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py
index 3aea6f6d..f71a4c88 100644
--- a/codes/models/gpt_voice/lucidrains_gpt.py
+++ b/codes/models/gpt_voice/lucidrains_gpt.py
@@ -150,7 +150,7 @@ class Attention(nn.Module):
 
         out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
+        out = self.to_out(out)
         return out
 
diff --git a/codes/train.py b/codes/train.py
index a6a799ec..d582904f 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -138,6 +138,11 @@ class Trainer:
         ### Evaluators
         self.evaluators = []
         if 'eval' in opt.keys() and 'evaluators' in opt['eval'].keys():
+            # In "pure" mode, we propagate through the normal training steps, but use validation data instead and average
+            # the total loss. A validation dataloader is required.
+            if opt_get(opt, ['eval', 'pure'], False):
+                assert hasattr(self, 'val_loader')
+
             for ev_key, ev_opt in opt['eval']['evaluators'].items():
                 self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
                                                         ev_opt, self.model.env))
@@ -213,50 +218,27 @@ class Trainer:
         #### validation
-        if opt['datasets'].get('val', None) and self.current_step % opt['train']['val_freq'] == 0:
-            if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan',
-                                'extensibletrainer'] and self.rank <= 0:  # image restoration validation
-                avg_psnr = 0.
-                avg_fea_loss = 0.
-                idx = 0
-                val_tqdm = tqdm(self.val_loader)
-                for val_data in val_tqdm:
-                    idx += 1
-                    for b in range(len(val_data['HQ_path'])):
-                        img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
-                        img_dir = os.path.join(opt['path']['val_images'], img_name)
-
-                        util.mkdir(img_dir)
-
-                        self.model.feed_data(val_data, self.current_step)
-                        self.model.test()
-
-                        visuals = self.model.get_current_visuals()
-                        if visuals is None:
-                            continue
-
-                        sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
-                        # calculate PSNR
-                        if self.val_compute_psnr:
-                            gt_img = util.tensor2img(visuals['hq'][b])  # uint8
-                            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                            avg_psnr += util.calculate_psnr(sr_img, gt_img)
-
-                        # Save SR images for reference
-                        img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
-                        save_img_path = os.path.join(img_dir, img_base_name)
-                        util.save_img(sr_img, save_img_path)
-
-                avg_psnr = avg_psnr / idx
-                avg_fea_loss = avg_fea_loss / idx
-
-                # log
-                self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
-
-                # tensorboard logger
-                if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
-                    self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
-                    self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
+        if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
+            metrics = []
+            for val_data in tqdm(self.val_loader):
+                self.model.feed_data(val_data, self.current_step)
+                metrics.append(self.model.test())
+            reduced_metrics = {}
+            for metric in metrics:
+                for k, v in metric.as_dict().items():
+                    if isinstance(v, torch.Tensor) and len(v.shape) == 0:
+                        if k in reduced_metrics.keys():
+                            reduced_metrics[k].append(v)
+                        else:
+                            reduced_metrics[k] = [v]
+            if self.rank <= 0:
+                for k, v in reduced_metrics.items():
+                    val = torch.stack(v).mean().item()
+                    self.tb_logger.add_scalar(k, val, self.current_step)
+                    print(f">>Eval {k}: {val}")
+                if opt['wandb']:
+                    import wandb
+                    wandb.log({k: torch.stack(v).mean().item() for k,v in reduced_metrics.items()})
 
         if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
             eval_dict = {}
@@ -300,7 +282,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index d88a447e..1b08b281 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -14,6 +14,7 @@ from trainer.steps import ConfigurableStep
 from trainer.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils
 
+from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
 from utils.util import opt_get, denormalize
 
 logger = logging.getLogger('base')
@@ -312,6 +313,7 @@ class ExtensibleTrainer(BaseModel):
         for net in self.netsG.values():
             net.eval()
 
+        accum_metrics = InfStorageLossAccumulator()
         with torch.no_grad():
             # This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
             # Or, we run the entire chain of steps in "train" mode and use eval.output_state.
@@ -327,7 +329,7 @@ class ExtensibleTrainer(BaseModel):
                 # Iterate through the steps, performing them one at a time.
                 state = self.dstate
                 for step_num, s in enumerate(self.steps):
-                    ns = s.do_forward_backward(state, 0, step_num, train=False)
+                    ns = s.do_forward_backward(state, 0, step_num, train=False, loss_accumulator=accum_metrics)
                     for k, v in ns.items():
                         state[k] = [v]
 
@@ -340,6 +342,7 @@ class ExtensibleTrainer(BaseModel):
 
         for net in self.netsG.values():
             net.train()
+        return accum_metrics
 
     # Fetches a summary of the log.
     def get_current_log(self, step):
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 90f5a063..2052f7f8 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -165,12 +165,13 @@ class ConfigurableStep(Module):
     # Performs all forward and backward passes for this step given an input state. All input states are lists of
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
-    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False):
+    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False, loss_accumulator=None):
         local_state = {}  # <-- Will store the entire local state to be passed to injectors & losses.
         new_state = {}  # <-- Will store state values created by this step for returning to ExtensibleTrainer.
         for k, v in state.items():
             local_state[k] = v[grad_accum_step]
         local_state['train_nets'] = str(self.get_networks_trained())
+        loss_accumulator = self.loss_accumulator if loss_accumulator is None else loss_accumulator
 
         # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
         self.env['amp_loss_id'] = amp_loss_id
@@ -204,7 +205,7 @@ class ConfigurableStep(Module):
             local_state.update(injected)
             new_state.update(injected)
 
-        if train and len(self.losses) > 0:
+        if len(self.losses) > 0:
             # Finally, compute the losses.
             total_loss = 0
             for loss_name, loss in self.losses.items():
@@ -223,14 +224,14 @@ class ConfigurableStep(Module):
                 total_loss += l * self.weights[loss_name]
                 # Record metrics.
                 if isinstance(l, torch.Tensor):
-                    self.loss_accumulator.add_loss(loss_name, l)
+                    loss_accumulator.add_loss(loss_name, l)
                 for n, v in loss.extra_metrics():
-                    self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
+                    loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
                 loss.clear_metrics()
 
             # In some cases, the loss could not be set (e.g. all losses have 'after')
-            if isinstance(total_loss, torch.Tensor):
-                self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
+            if train and isinstance(total_loss, torch.Tensor):
+                loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
                 reset_required = total_loss < self.min_total_loss
 
             # Scale the loss down by the accumulation factor.
@@ -245,7 +246,7 @@ class ConfigurableStep(Module):
                 # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
                 # implement it at the loss level.
                 self.get_network_for_name(self.step_opt['training']).zero_grad()
-                self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
+                loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
 
         self.grads_generated = True
 
diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py
index 896138e9..1e78d07d 100644
--- a/codes/utils/loss_accumulator.py
+++ b/codes/utils/loss_accumulator.py
@@ -43,4 +43,38 @@ class LossAccumulator:
                 result["loss_" + k] = torch.mean(buf[:i])
         for k, v in self.counters.items():
             result[k] = v
+        return result
+
+
+# Stores losses in an infinitely-sized list.
+class InfStorageLossAccumulator:
+    def __init__(self):
+        self.buffers = {}
+
+    def add_loss(self, name, tensor):
+        if name not in self.buffers.keys():
+            if "_histogram" in name:
+                tensor = torch.flatten(tensor.detach().cpu())
+                self.buffers[name] = []
+            else:
+                self.buffers[name] = []
+        buf = self.buffers[name]
+        # Can take tensors or just plain python numbers.
+        if '_histogram' in name:
+            buf.append(torch.flatten(tensor.detach().cpu()))
+        elif isinstance(tensor, torch.Tensor):
+            buf.append(tensor.detach().cpu())
+        else:
+            buf.append(tensor)
+
+    def increment_metric(self, name):
+        pass
+
+    def as_dict(self):
+        result = {}
+        for k, buf in self.buffers.items():
+            if '_histogram' in k:
+                result["loss_" + k] = torch.flatten(buf)
+            else:
+                result["loss_" + k] = torch.mean(torch.stack(buf))
         return result
\ No newline at end of file
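
How the pieces above fit together: with `eval: pure: true` set in the option YAML, train.py calls `ExtensibleTrainer.test()`, which threads an `InfStorageLossAccumulator` through each step's `do_forward_backward(train=False)` and returns it; train.py then averages the scalar entries across validation batches. Below is a minimal standalone sketch of that reduction (illustrative only, not part of the patch; the loss name and values are made up, and it assumes the patched codes/utils/loss_accumulator.py is importable):

# Illustrative sketch, not part of the patch. Assumes codes/ is on sys.path so
# the patched utils.loss_accumulator module can be imported.
import torch
from utils.loss_accumulator import InfStorageLossAccumulator

accum = InfStorageLossAccumulator()
# Each do_forward_backward(..., train=False, loss_accumulator=accum) call
# contributes per-loss scalars like this ('gpt_total' is a hypothetical name):
for step in range(3):
    accum.add_loss('gpt_total', torch.tensor(1.0 / (step + 1)))

# as_dict() prefixes keys with 'loss_' and returns the mean of the stacked buffer.
for k, v in accum.as_dict().items():
    if isinstance(v, torch.Tensor) and len(v.shape) == 0:  # same filter train.py applies
        print(f">>Eval {k}: {v.item()}")  # >>Eval loss_gpt_total: 0.6111...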