Add rotating log buffer to trainer

Should stabilize stats output by averaging each logged value over a rotating buffer of the last 50 entries instead of reporting the latest value only.
James Betker 2020-05-12 10:09:45 -06:00
parent f217216c81
commit 9210a62f58

@@ -288,15 +288,29 @@ class SRGANModel(BaseModel):
         # set log TODO(handle mega-batches?)
         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
             if self.cri_pix:
-                self.log_dict['l_g_pix'] = l_g_pix.item()
+                self.add_log_entry('l_g_pix', l_g_pix.item())
             if self.cri_fea:
-                self.log_dict['feature_weight'] = self.l_fea_w
-                self.log_dict['l_g_fea'] = l_g_fea.item()
-            self.log_dict['l_g_gan'] = l_g_gan.item()
-        self.log_dict['l_g_total'] = l_g_total.item() * self.mega_batch_factor
-        self.log_dict['l_d_real'] = l_d_real.item() * self.mega_batch_factor
-        self.log_dict['l_d_fake'] = l_d_fake.item() * self.mega_batch_factor
-        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+                self.add_log_entry('feature_weight', self.l_fea_w)
+                self.add_log_entry('l_g_fea', l_g_fea.item())
+            self.add_log_entry('l_g_gan', l_g_gan.item())
+        self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor)
+        self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
+        self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
+        self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
+
+    # Allows the log to serve as an easy-to-use rotating buffer.
+    def add_log_entry(self, key, value):
+        key_it = "%s_it" % (key,)
+        log_rotating_buffer_size = 50
+        if key not in self.log_dict.keys():
+            self.log_dict[key] = []
+            self.log_dict[key_it] = 0
+        if len(self.log_dict[key]) < log_rotating_buffer_size:
+            self.log_dict[key].append(value)
+        else:
+            self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value
+        self.log_dict[key_it] += 1
 
     def create_artificial_skips(self, truth_img):
         med_skip = F.interpolate(truth_img, scale_factor=.5)
@@ -309,8 +323,14 @@ class SRGANModel(BaseModel):
             self.fake_GenOut = [self.netG(self.var_L[0])]
         self.netG.train()
 
+    # Fetches a summary of the log.
     def get_current_log(self):
-        return self.log_dict
+        return_log = {}
+        for k in self.log_dict.keys():
+            if not isinstance(self.log_dict[k], list):
+                continue
+            return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
+        return return_log
 
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
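
For reference, the change amounts to a per-key rotating buffer plus a mean over whatever each buffer currently holds. The sketch below restates that pattern outside the trainer; the RotatingLogBuffer class and its names are illustrative only (not part of this commit), and it assumes plain float values rather than tensors.

# Minimal standalone sketch of the rotating-buffer logging pattern used above.
# RotatingLogBuffer is a hypothetical helper, not code from this repository.
class RotatingLogBuffer:
    def __init__(self, buffer_size=50):
        self.buffer_size = buffer_size
        self.values = {}    # key -> list of the most recent values
        self.counters = {}  # key -> total number of values ever added for that key

    def add(self, key, value):
        if key not in self.values:
            self.values[key] = []
            self.counters[key] = 0
        if len(self.values[key]) < self.buffer_size:
            self.values[key].append(value)
        else:
            # Once the buffer is full, overwrite the oldest slot.
            self.values[key][self.counters[key] % self.buffer_size] = value
        self.counters[key] += 1

    def summary(self):
        # Mean over the values currently held for each key.
        return {k: sum(v) / len(v) for k, v in self.values.items() if v}


if __name__ == "__main__":
    log = RotatingLogBuffer(buffer_size=3)
    for loss in [1.0, 2.0, 3.0, 4.0]:
        log.add('l_g_pix', loss)
    print(log.summary())  # {'l_g_pix': 3.0} -> mean of the last three values (2, 3, 4)

Reporting the windowed mean rather than the most recent value is what smooths per-step noise out of the stats; with the commit's buffer size of 50, each reported number is an average over the last 50 logged entries.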