Add "pure" evaluator

Which simply computes the training loss against an eval dataset
This commit is contained in:
James Betker 2021-08-09 14:58:35 -06:00
parent 080bea2f19
commit 82fc69abfa
5 changed files with 74 additions and 54 deletions

View File

@ -150,7 +150,7 @@ class Attention(nn.Module):
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
out = self.to_out(out)
return out

View File

@ -138,6 +138,11 @@ class Trainer:
### Evaluators
self.evaluators = []
if 'eval' in opt.keys() and 'evaluators' in opt['eval'].keys():
# In "pure" mode, we propagate through the normal training steps, but use validation data instead and average
# the total loss. A validation dataloader is required.
if opt_get(opt, ['eval', 'pure'], False):
assert hasattr(self, 'val_loader')
for ev_key, ev_opt in opt['eval']['evaluators'].items():
self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
ev_opt, self.model.env))
@ -213,50 +218,27 @@ class Trainer:
shutil.copytree(self.tb_logger_path, alt_tblogger)
#### validation
if opt['datasets'].get('val', None) and self.current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan',
'extensibletrainer'] and self.rank <= 0: # image restoration validation
avg_psnr = 0.
avg_fea_loss = 0.
idx = 0
val_tqdm = tqdm(self.val_loader)
for val_data in val_tqdm:
idx += 1
for b in range(len(val_data['HQ_path'])):
img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir)
self.model.feed_data(val_data, self.current_step)
self.model.test()
visuals = self.model.get_current_visuals()
if visuals is None:
continue
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
save_img_path = os.path.join(img_dir, img_base_name)
util.save_img(sr_img, save_img_path)
avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
# log
self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
metrics = []
for val_data in tqdm(self.val_loader):
self.model.feed_data(val_data, self.current_step)
metrics.append(self.model.test())
reduced_metrics = {}
for metric in metrics:
for k, v in metric.as_dict().items():
if isinstance(v, torch.Tensor) and len(v.shape) == 0:
if k in reduced_metrics.keys():
reduced_metrics[k].append(v)
else:
reduced_metrics[k] = [v]
if self.rank <= 0:
for k, v in reduced_metrics.items():
val = torch.stack(v).mean().item()
self.tb_logger.add_scalar(k, val, self.current_step)
print(f">>Eval {k}: {val}")
if opt['wandb']:
import wandb
wandb.log({k: torch.stack(v).mean().item() for k,v in reduced_metrics.items()})
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
eval_dict = {}
@ -300,7 +282,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_lj.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -14,6 +14,7 @@ from trainer.steps import ConfigurableStep
from trainer.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils
from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
from utils.util import opt_get, denormalize
logger = logging.getLogger('base')
@ -312,6 +313,7 @@ class ExtensibleTrainer(BaseModel):
for net in self.netsG.values():
net.eval()
accum_metrics = InfStorageLossAccumulator()
with torch.no_grad():
# This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
# Or, we run the entire chain of steps in "train" mode and use eval.output_state.
@ -327,7 +329,7 @@ class ExtensibleTrainer(BaseModel):
# Iterate through the steps, performing them one at a time.
state = self.dstate
for step_num, s in enumerate(self.steps):
ns = s.do_forward_backward(state, 0, step_num, train=False)
ns = s.do_forward_backward(state, 0, step_num, train=False, loss_accumulator=accum_metrics)
for k, v in ns.items():
state[k] = [v]
@ -340,6 +342,7 @@ class ExtensibleTrainer(BaseModel):
for net in self.netsG.values():
net.train()
return accum_metrics
# Fetches a summary of the log.
def get_current_log(self, step):

View File

@ -165,12 +165,13 @@ class ConfigurableStep(Module):
# Performs all forward and backward passes for this step given an input state. All input states are lists of
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
# steps might use. These tensors are automatically detached and accumulated into chunks.
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False):
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False, loss_accumulator=None):
local_state = {} # <-- Will store the entire local state to be passed to injectors & losses.
new_state = {} # <-- Will store state values created by this step for returning to ExtensibleTrainer.
for k, v in state.items():
local_state[k] = v[grad_accum_step]
local_state['train_nets'] = str(self.get_networks_trained())
loss_accumulator = self.loss_accumulator if loss_accumulator is None else loss_accumulator
# Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
self.env['amp_loss_id'] = amp_loss_id
@ -204,7 +205,7 @@ class ConfigurableStep(Module):
local_state.update(injected)
new_state.update(injected)
if train and len(self.losses) > 0:
if len(self.losses) > 0:
# Finally, compute the losses.
total_loss = 0
for loss_name, loss in self.losses.items():
@ -223,14 +224,14 @@ class ConfigurableStep(Module):
total_loss += l * self.weights[loss_name]
# Record metrics.
if isinstance(l, torch.Tensor):
self.loss_accumulator.add_loss(loss_name, l)
loss_accumulator.add_loss(loss_name, l)
for n, v in loss.extra_metrics():
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
loss.clear_metrics()
# In some cases, the loss could not be set (e.g. all losses have 'after')
if isinstance(total_loss, torch.Tensor):
self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
if train and isinstance(total_loss, torch.Tensor):
loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
reset_required = total_loss < self.min_total_loss
# Scale the loss down by the accumulation factor.
@ -245,7 +246,7 @@ class ConfigurableStep(Module):
# way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
# implement it at the loss level.
self.get_network_for_name(self.step_opt['training']).zero_grad()
self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
self.grads_generated = True

View File

@ -43,4 +43,38 @@ class LossAccumulator:
result["loss_" + k] = torch.mean(buf[:i])
for k, v in self.counters.items():
result[k] = v
return result
# Stores losses in an infinitely-sized list.
class InfStorageLossAccumulator:
def __init__(self):
self.buffers = {}
def add_loss(self, name, tensor):
if name not in self.buffers.keys():
if "_histogram" in name:
tensor = torch.flatten(tensor.detach().cpu())
self.buffers[name] = []
else:
self.buffers[name] = []
buf = self.buffers[name]
# Can take tensors or just plain python numbers.
if '_histogram' in name:
buf.append(torch.flatten(tensor.detach().cpu()))
elif isinstance(tensor, torch.Tensor):
buf.append(tensor.detach().cpu())
else:
buf.append(tensor)
def increment_metric(self, name):
pass
def as_dict(self):
result = {}
for k, buf in self.buffers.items():
if '_histogram' in k:
result["loss_" + k] = torch.flatten(buf)
else:
result["loss_" + k] = torch.mean(torch.stack(buf))
return result