From 9210a62f58620fa9cec284a8c5680bc5166bb865 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 12 May 2020 10:09:45 -0600 Subject: [PATCH] Add rotating log buffer to trainer Should stabilize stats output. --- codes/models/SRGAN_model.py | 38 ++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 117cedeb..cd5387f1 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -288,15 +288,29 @@ class SRGANModel(BaseModel): # set log TODO(handle mega-batches?) if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: - self.log_dict['l_g_pix'] = l_g_pix.item() + self.add_log_entry('l_g_pix', l_g_pix.item()) if self.cri_fea: - self.log_dict['feature_weight'] = self.l_fea_w - self.log_dict['l_g_fea'] = l_g_fea.item() - self.log_dict['l_g_gan'] = l_g_gan.item() - self.log_dict['l_g_total'] = l_g_total.item() * self.mega_batch_factor - self.log_dict['l_d_real'] = l_d_real.item() * self.mega_batch_factor - self.log_dict['l_d_fake'] = l_d_fake.item() * self.mega_batch_factor - self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) + self.add_log_entry('feature_weight', self.l_fea_w) + self.add_log_entry('l_g_fea', l_g_fea.item()) + self.add_log_entry('l_g_gan', l_g_gan.item()) + self.add_log_entry('l_g_total', l_g_total.item() * self.mega_batch_factor) + self.add_log_entry('l_d_real', l_d_real.item() * self.mega_batch_factor) + self.add_log_entry('l_d_fake', l_d_fake.item() * self.mega_batch_factor) + self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach())) + + # Allows the log to serve as an easy-to-use rotating buffer. + def add_log_entry(self, key, value): + key_it = "%s_it" % (key,) + log_rotating_buffer_size = 50 + if key not in self.log_dict.keys(): + self.log_dict[key] = [] + self.log_dict[key_it] = 0 + if len(self.log_dict[key]) < log_rotating_buffer_size: + self.log_dict[key].append(value) + else: + self.log_dict[key][self.log_dict[key_it] % log_rotating_buffer_size] = value + self.log_dict[key_it] += 1 + def create_artificial_skips(self, truth_img): med_skip = F.interpolate(truth_img, scale_factor=.5) @@ -309,8 +323,14 @@ class SRGANModel(BaseModel): self.fake_GenOut = [self.netG(self.var_L[0])] self.netG.train() + # Fetches a summary of the log. def get_current_log(self): - return self.log_dict + return_log = {} + for k in self.log_dict.keys(): + if not isinstance(self.log_dict[k], list): + continue + return_log[k] = sum(self.log_dict[k]) / len(self.log_dict[k]) + return return_log def get_current_visuals(self, need_GT=True): out_dict = OrderedDict()