Add rotating log buffer to trainer
Should stabilize stats output.
This commit is contained in:
parent
f217216c81
commit
9210a62f58
|
@ -288,15 +288,29 @@ class SRGANModel(BaseModel):
|
|||
# set log TODO(handle mega-batches?)
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix:
|
||||
self.log_dict['l_g_pix'] = l_g_pix.item()
|
||||
self.add_log_entry('l_g_pix', l_g_pix.item())
|
||||
if self.cri_fea:
|
||||
self.log_dict['feature_weight'] = self.l_fea_w
|
||||
self.log_dict['l_g_fea'] = l_g_fea.item()
|
||||
self.log_dict['l_g_gan'] = l_g_gan.item()
|
||||
self.log_dict['l_g_total'] = l_g_total.item() * self.mega_batch_factor
|
||||
self.log_dict['l_d_real'] = l_d_real.item() * self.mega_batch_factor
|
||||
self.log_dict['l_d_fake'] = l_d_fake.item() * self.mega_batch_factor
|
||||
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
|
||||
self.add_log_entry('feature_weight', self.l_fea_w)
|
||||
self.add_log_entry('l_g_fea', l_g_fea.item())
|
||||
self.add_log_entry('l_g_gan', l_g_gan.item())
|
||||
self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor)
|
||||
self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
|
||||
|
||||
# Allows the log to serve as an easy-to-use rotating buffer.
|
||||
def add_log_entry(self, key, value):
|
||||
key_it = "%s_it" % (key,)
|
||||
log_rotating_buffer_size = 50
|
||||
if key not in self.log_dict.keys():
|
||||
self.log_dict[key] = []
|
||||
self.log_dict[key_it] = 0
|
||||
if len(self.log_dict[key]) < log_rotating_buffer_size:
|
||||
self.log_dict[key].append(value)
|
||||
else:
|
||||
self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value
|
||||
self.log_dict[key_it] += 1
|
||||
|
||||
|
||||
def create_artificial_skips(self, truth_img):
|
||||
med_skip = F.interpolate(truth_img, scale_factor=.5)
|
||||
|
@ -309,8 +323,14 @@ class SRGANModel(BaseModel):
|
|||
self.fake_GenOut = [self.netG(self.var_L[0])]
|
||||
self.netG.train()
|
||||
|
||||
# Fetches a summary of the log.
|
||||
def get_current_log(self):
|
||||
return self.log_dict
|
||||
return_log = {}
|
||||
for k in self.log_dict.keys():
|
||||
if not isinstance(self.log_dict[k], list):
|
||||
continue
|
||||
return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k])
|
||||
return return_log
|
||||
|
||||
def get_current_visuals(self, need_GT=True):
|
||||
out_dict = OrderedDict()
|
||||
|
|
Loading…
Reference in New Issue
Block a user