From d2d5e097d5cc996c0a4983988993ba6497e8d233 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 18 Jun 2020 11:29:10 -0600
Subject: [PATCH] Add profiling to SRGAN for testing timings

---
 codes/models/SRGAN_model.py | 54 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 53 insertions(+), 1 deletion(-)

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 9a9daff8..4eb6582b 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -167,6 +167,11 @@ class SRGANModel(BaseModel):
             self.load_random_corruptor()
 
     def feed_data(self, data, need_GT=True):
+        _profile = True
+        if _profile:
+            from time import time
+            _t = time()
+
         # Corrupt the data with the given corruptor, if specified.
         self.fed_LQ = data['LQ'].to(self.device)
         if self.netC and random.random() < self.corruptor_usage_prob:
@@ -183,6 +188,11 @@ class SRGANModel(BaseModel):
             self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
 
     def optimize_parameters(self, step):
+        _profile = False
+        if _profile:
+            from time import time
+            _t = time()
+
         # Some generators have variants depending on the current step.
         if hasattr(self.netG.module, "update_for_step"):
             self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
@@ -211,10 +221,13 @@ class SRGANModel(BaseModel):
         else:
             noise_theta = (self.D_noise_theta - self.D_noise_theta_floor) * (self.D_noise_final - min(step, self.D_noise_final)) / self.D_noise_final + self.D_noise_theta_floor
 
+        if _profile:
+            print("Misc setup %f" % (time() - _t,))
+            _t = time()
+
         self.fake_GenOut = []
         var_ref_skips = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
-            #from utils import gpu_mem_track
             #import inspect
             #gpu_tracker = gpu_mem_track.MemTracker(inspect.currentframe())
 
@@ -222,6 +235,10 @@ class SRGANModel(BaseModel):
             fake_GenOut = self.netG(var_L)
             #gpu_tracker.track()
 
+            if _profile:
+                print("Gen forward %f" % (time() - _t,))
+                _t = time()
+
             # Extract the image output. For generators that output skip-through connections, the master output is always
             # the first element of the tuple.
             if isinstance(fake_GenOut, tuple):
@@ -245,6 +262,10 @@ class SRGANModel(BaseModel):
                     l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                     l_g_total += l_g_fea
 
+                    if _profile:
+                        print("Fea forward %f" % (time() - _t,))
+                        _t = time()
+
                     # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
                     # in the resultant image.
                     if step % self.l_fea_w_decay_steps == 0:
@@ -267,8 +288,17 @@ class SRGANModel(BaseModel):
 
             with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
                 l_g_total_scaled.backward()
+
+        if _profile:
+            print("Gen backward %f" % (time() - _t,))
+            _t = time()
+
         self.optimizer_G.step()
 
+        if _profile:
+            print("Gen step %f" % (time() - _t,))
+            _t = time()
+
         # D
         if self.l_gan_w > 0:
             for p in self.netD.parameters():
@@ -284,6 +314,10 @@ class SRGANModel(BaseModel):
                 # The following line detaches all generator outputs that are not None.
                 fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
 
+                if _profile:
+                    print("Gen forward for disc %f" % (time() - _t,))
+                    _t = time()
+
                 # Apply noise to the inputs to slow discriminator convergence.
                 var_ref = (var_ref[0] + noise,) + var_ref[1:]
                 fake_H = (fake_H[0] + noise,) + fake_H[1:]
@@ -308,15 +342,33 @@ class SRGANModel(BaseModel):
                     # l_d_total.backward()
                     pred_d_fake = self.netD(fake_H).detach()
                     pred_d_real = self.netD(var_ref)
+
+                    if _profile:
+                        print("Double disc forward (RAGAN) %f" % (time() - _t,))
+                        _t = time()
+
                     l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
                     with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                         l_d_real_scaled.backward()
+
+                    if _profile:
+                        print("Disc backward 1 (RAGAN) %f" % (time() - _t,))
+                        _t = time()
+
                     pred_d_fake = self.netD(fake_H)
                     l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
                     with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                         l_d_fake_scaled.backward()
+
+                    if _profile:
+                        print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
+                        _t = time()
 
             self.optimizer_D.step()
+            if _profile:
+                print("Disc step %f" % (time() - _t,))
+                _t = time()
+
         # Log sample images from first microbatch.
         if step % 50 == 0:
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
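
A note on the timing method this patch uses: CUDA kernels launch asynchronously, so wall-clock deltas taken with time.time() capture only host-side time, and GPU work can get attributed to whichever later call happens to block. Below is a minimal sketch of a synchronization-aware variant of the patch's print-and-reset pattern; the helper name _stamp is hypothetical, not something in this codebase:

    import torch
    from time import time

    def _stamp(label, _t):
        # Drain queued CUDA kernels so the delta reflects GPU execution,
        # not just host-side kernel launch overhead.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print("%s %f" % (label, time() - _t))
        return time()

Usage would mirror the patch: set _t = time() once at the top, then _t = _stamp("Gen forward", _t) at each checkpoint. Synchronizing has its own throughput cost, so it belongs behind the same _profile flag.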
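
On the "(RAGAN)" labels in the discriminator hunks: the loss there is the relativistic average GAN formulation, in which each side's logits are shifted by the mean logit of the opposing batch before the GAN criterion is applied. A self-contained sketch of that computation, assuming cri_gan wraps a BCE-with-logits criterion (the usual 'ragan' setting in MMSR-derived code); unlike the patch, which backpropagates the two halves separately between its two discriminator forwards, this combines them for clarity:

    import torch
    import torch.nn.functional as F

    def ragan_d_loss(pred_d_real, pred_d_fake, mega_batch_factor=1):
        # Real logits are measured against the mean fake logit and vice
        # versa; both mean terms are detached, mirroring the patch. The
        # 0.5 weights the two halves, and mega_batch_factor scales for
        # gradient accumulation over microbatches.
        l_real = F.binary_cross_entropy_with_logits(
            pred_d_real - pred_d_fake.detach().mean(), torch.ones_like(pred_d_real))
        l_fake = F.binary_cross_entropy_with_logits(
            pred_d_fake - pred_d_real.detach().mean(), torch.zeros_like(pred_d_fake))
        return (l_real + l_fake) * 0.5 / mega_batch_factor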