Add "pure" evaluator
Which simply computes the training loss against an eval dataset
commit 82fc69abfa (parent 080bea2f19)

@@ -150,7 +150,7 @@ class Attention(nn.Module):
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

@@ -138,6 +138,11 @@ class Trainer:
        ### Evaluators
        self.evaluators = []
        if 'eval' in opt.keys() and 'evaluators' in opt['eval'].keys():
+           # In "pure" mode, we propagate through the normal training steps, but use validation data instead and average
+           # the total loss. A validation dataloader is required.
+           if opt_get(opt, ['eval', 'pure'], False):
+               assert hasattr(self, 'val_loader')

            for ev_key, ev_opt in opt['eval']['evaluators'].items():
                self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
                                                        ev_opt, self.model.env))

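For reference, a sketch of an options fragment that would exercise this new path. Only the 'eval', 'pure', 'evaluators' and 'for' keys come from the diff above; every other name and value is invented for illustration.

    # Hypothetical options dict; the repository's real option files are YAML, but the shape is the same.
    opt = {
        'eval': {
            'pure': True,                  # average the normal training losses over the validation set
            'evaluators': {
                'example_eval': {          # made-up evaluator name
                    'for': 'generator',    # key of the network inside self.model.networks (assumed)
                    # ...whatever options create_evaluator() expects for this evaluator type...
                },
            },
        },
    }
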
@@ -213,50 +218,27 @@ class Trainer:
            shutil.copytree(self.tb_logger_path, alt_tblogger)

        #### validation
-       if opt['datasets'].get('val', None) and self.current_step % opt['train']['val_freq'] == 0:
-           if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan',
-                               'extensibletrainer'] and self.rank <= 0: # image restoration validation
-               avg_psnr = 0.
-               avg_fea_loss = 0.
-               idx = 0
-               val_tqdm = tqdm(self.val_loader)
-               for val_data in val_tqdm:
-                   idx += 1
-                   for b in range(len(val_data['HQ_path'])):
-                       img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
-                       img_dir = os.path.join(opt['path']['val_images'], img_name)

-                       util.mkdir(img_dir)

-                       self.model.feed_data(val_data, self.current_step)
-                       self.model.test()

-                       visuals = self.model.get_current_visuals()
-                       if visuals is None:
-                           continue

-                       sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
-                       # calculate PSNR
-                       if self.val_compute_psnr:
-                           gt_img = util.tensor2img(visuals['hq'][b]) # uint8
-                           sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                           avg_psnr += util.calculate_psnr(sr_img, gt_img)

-                       # Save SR images for reference
-                       img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
-                       save_img_path = os.path.join(img_dir, img_base_name)
-                       util.save_img(sr_img, save_img_path)

-               avg_psnr = avg_psnr / idx
-               avg_fea_loss = avg_fea_loss / idx

-               # log
-               self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))

-               # tensorboard logger
-               if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
-                   self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
-                   self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
+       if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
+           metrics = []
+           for val_data in tqdm(self.val_loader):
+               self.model.feed_data(val_data, self.current_step)
+               metrics.append(self.model.test())
+           reduced_metrics = {}
+           for metric in metrics:
+               for k, v in metric.as_dict().items():
+                   if isinstance(v, torch.Tensor) and len(v.shape) == 0:
+                       if k in reduced_metrics.keys():
+                           reduced_metrics[k].append(v)
+                       else:
+                           reduced_metrics[k] = [v]
+           if self.rank <= 0:
+               for k, v in reduced_metrics.items():
+                   val = torch.stack(v).mean().item()
+                   self.tb_logger.add_scalar(k, val, self.current_step)
+                   print(f">>Eval {k}: {val}")
+               if opt['wandb']:
+                   import wandb
+                   wandb.log({k: torch.stack(v).mean().item() for k,v in reduced_metrics.items()})

        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
            eval_dict = {}

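As a standalone illustration of the reduction performed above: each call to self.model.test() returns a loss accumulator (see the ExtensibleTrainer hunks later in this diff) whose as_dict() yields 0-dim loss tensors, and those per-batch values are stacked and averaged per key. A minimal example with made-up metric names and values:

    import torch

    # Two hypothetical per-batch metric dicts, as as_dict() would produce them.
    metrics = [{'loss_nll': torch.tensor(3.0)}, {'loss_nll': torch.tensor(2.0)}]
    reduced = {}
    for m in metrics:
        for k, v in m.items():
            if isinstance(v, torch.Tensor) and len(v.shape) == 0:
                reduced.setdefault(k, []).append(v)
    print({k: torch.stack(v).mean().item() for k, v in reduced.items()})  # {'loss_nll': 2.5}
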
@@ -300,7 +282,7 @@ class Trainer:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-   parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_lj.yml')
+   parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

@@ -14,6 +14,7 @@ from trainer.steps import ConfigurableStep
from trainer.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils

from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
from utils.util import opt_get, denormalize

logger = logging.getLogger('base')

@@ -312,6 +313,7 @@ class ExtensibleTrainer(BaseModel):
        for net in self.netsG.values():
            net.eval()

+       accum_metrics = InfStorageLossAccumulator()
        with torch.no_grad():
            # This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
            # Or, we run the entire chain of steps in "train" mode and use eval.output_state.

@@ -327,7 +329,7 @@ class ExtensibleTrainer(BaseModel):
            # Iterate through the steps, performing them one at a time.
            state = self.dstate
            for step_num, s in enumerate(self.steps):
-               ns = s.do_forward_backward(state, 0, step_num, train=False)
+               ns = s.do_forward_backward(state, 0, step_num, train=False, loss_accumulator=accum_metrics)
                for k, v in ns.items():
                    state[k] = [v]

@@ -340,6 +342,7 @@ class ExtensibleTrainer(BaseModel):

        for net in self.netsG.values():
            net.train()
+       return accum_metrics

    # Fetches a summary of the log.
    def get_current_log(self, step):

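Taken together, the three ExtensibleTrainer hunks above mean the eval path now builds an InfStorageLossAccumulator, threads it through every step with train=False under torch.no_grad(), and returns it to the caller, which averages the stored per-batch losses. A condensed, hedged sketch of that flow; the method name eval_pass is made up, and the 'validation injector' branch mentioned in the comment is omitted:

    def eval_pass(self):
        # Hypothetical condensation of the eval path shown in this diff.
        for net in self.netsG.values():
            net.eval()
        accum_metrics = InfStorageLossAccumulator()
        with torch.no_grad():
            state = self.dstate
            for step_num, s in enumerate(self.steps):
                ns = s.do_forward_backward(state, 0, step_num, train=False,
                                           loss_accumulator=accum_metrics)
                for k, v in ns.items():
                    state[k] = [v]
        for net in self.netsG.values():
            net.train()
        return accum_metrics  # the Trainer averages these per-batch losses over the val loader
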
@@ -165,12 +165,13 @@ class ConfigurableStep(Module):
    # Performs all forward and backward passes for this step given an input state. All input states are lists of
    # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
    # steps might use. These tensors are automatically detached and accumulated into chunks.
-   def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False):
+   def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False, loss_accumulator=None):
        local_state = {}  # <-- Will store the entire local state to be passed to injectors & losses.
        new_state = {}  # <-- Will store state values created by this step for returning to ExtensibleTrainer.
        for k, v in state.items():
            local_state[k] = v[grad_accum_step]
        local_state['train_nets'] = str(self.get_networks_trained())
+       loss_accumulator = self.loss_accumulator if loss_accumulator is None else loss_accumulator

        # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
        self.env['amp_loss_id'] = amp_loss_id

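The chunked-state contract described in the comment above can be shown with a tiny standalone sketch (the key name and tensor shape are invented):

    import torch

    # Each entry in `state` is a list of chunks, one per gradient-accumulation micro-step.
    state = {'hq': [torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)]}  # 2 micro-steps
    grad_accum_step = 1
    local_state = {k: v[grad_accum_step] for k, v in state.items()}  # the chunk for this micro-step
    print(local_state['hq'].shape)  # torch.Size([2, 3, 64, 64])
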
@@ -204,7 +205,7 @@ class ConfigurableStep(Module):
            local_state.update(injected)
            new_state.update(injected)

-       if train and len(self.losses) > 0:
+       if len(self.losses) > 0:
            # Finally, compute the losses.
            total_loss = 0
            for loss_name, loss in self.losses.items():

@@ -223,14 +224,14 @@ class ConfigurableStep(Module):
                    total_loss += l * self.weights[loss_name]
                    # Record metrics.
                    if isinstance(l, torch.Tensor):
-                       self.loss_accumulator.add_loss(loss_name, l)
+                       loss_accumulator.add_loss(loss_name, l)
                    for n, v in loss.extra_metrics():
-                       self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
+                       loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
                    loss.clear_metrics()

            # In some cases, the loss could not be set (e.g. all losses have 'after')
-           if isinstance(total_loss, torch.Tensor):
-               self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
+           if train and isinstance(total_loss, torch.Tensor):
+               loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
                reset_required = total_loss < self.min_total_loss

            # Scale the loss down by the accumulation factor.

@@ -245,7 +246,7 @@ class ConfigurableStep(Module):
                # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
                # implement it at the loss level.
                self.get_network_for_name(self.step_opt['training']).zero_grad()
-               self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
+               loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))

        self.grads_generated = True

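The net effect of the ConfigurableStep hunks above: losses are now computed whenever the step defines losses, even in eval, and they flow into whichever accumulator was passed in; the "<network>_total" metric, the skipped-step bookkeeping, and backpropagation remain train-only (the eval path additionally runs under torch.no_grad()). A toy sketch of that control flow; everything here is simplified and hypothetical, not the repository's actual step logic:

    import torch

    def compute_losses(train, losses, loss_accumulator):
        # losses: hypothetical {name: 0-dim tensor}; loss_accumulator: anything with add_loss().
        total_loss = 0
        for name, loss_value in losses.items():
            total_loss = total_loss + loss_value
            loss_accumulator.add_loss(name, loss_value)         # recorded in train and eval alike
        if train and isinstance(total_loss, torch.Tensor):
            loss_accumulator.add_loss('net_total', total_loss)  # total is train-only bookkeeping
            total_loss.backward()                               # backward never runs on the eval path
        return total_loss
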
@@ -44,3 +44,37 @@ class LossAccumulator:
        for k, v in self.counters.items():
            result[k] = v
        return result


+# Stores losses in an infinitely-sized list.
+class InfStorageLossAccumulator:
+    def __init__(self):
+        self.buffers = {}
+
+    def add_loss(self, name, tensor):
+        if name not in self.buffers.keys():
+            if "_histogram" in name:
+                tensor = torch.flatten(tensor.detach().cpu())
+                self.buffers[name] = []
+            else:
+                self.buffers[name] = []
+        buf = self.buffers[name]
+        # Can take tensors or just plain python numbers.
+        if '_histogram' in name:
+            buf.append(torch.flatten(tensor.detach().cpu()))
+        elif isinstance(tensor, torch.Tensor):
+            buf.append(tensor.detach().cpu())
+        else:
+            buf.append(tensor)
+
+    def increment_metric(self, name):
+        pass
+
+    def as_dict(self):
+        result = {}
+        for k, buf in self.buffers.items():
+            if '_histogram' in k:
+                result["loss_" + k] = torch.flatten(buf)
+            else:
+                result["loss_" + k] = torch.mean(torch.stack(buf))
+        return result

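A minimal usage sketch of the new accumulator (the metric name is invented for the example): scalar losses are stashed per key and averaged by as_dict(), keys containing "_histogram" are flattened instead, and increment_metric() is deliberately a no-op in this variant.

    import torch
    from utils.loss_accumulator import InfStorageLossAccumulator

    acc = InfStorageLossAccumulator()
    acc.add_loss('nll', torch.tensor(1.5))
    acc.add_loss('nll', torch.tensor(0.5))
    print(acc.as_dict())  # {'loss_nll': tensor(1.)}
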