Add profiling to SRGAN for testing timings

James Betker 2020-06-18 11:29:10 -06:00
parent 45a900fafe
commit d2d5e097d5


@@ -167,6 +167,11 @@ class SRGANModel(BaseModel):
         self.load_random_corruptor()

     def feed_data(self, data, need_GT=True):
+        _profile = True
+        if _profile:
+            from time import time
+            _t = time()
+
         # Corrupt the data with the given corruptor, if specified.
         self.fed_LQ = data['LQ'].to(self.device)
         if self.netC and random.random() < self.corruptor_usage_prob:
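The instrumentation this commit adds is a wall-clock checkpoint pattern: stamp `_t = time()`, run a stage, print `time() - _t`, re-stamp. One caveat when reading the printed numbers: CUDA kernels launch asynchronously, so a `time()` delta around a GPU op can mostly measure launch overhead rather than execution. A minimal standalone sketch of the same pattern with an optional device synchronize (the `CheckpointTimer` class and `sync_cuda` flag are illustrative, not part of this commit):

```python
from time import time

import torch


class CheckpointTimer:
    """Wall-clock checkpoint timer mirroring the `_t = time()` pattern in this diff."""

    def __init__(self, enabled=True, sync_cuda=False):
        self.enabled = enabled
        self.sync_cuda = sync_cuda
        self._t = time()

    def checkpoint(self, label):
        if not self.enabled:
            return
        if self.sync_cuda and torch.cuda.is_available():
            # Drain queued CUDA kernels so the delta reflects GPU time, not launch time.
            torch.cuda.synchronize()
        print("%s %f" % (label, time() - self._t))
        self._t = time()
```

With `sync_cuda=False` this reproduces the commit's numbers; with `True` it attributes GPU time to the stage that actually queued it.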
@@ -183,6 +188,11 @@ class SRGANModel(BaseModel):
            self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]

     def optimize_parameters(self, step):
+        _profile = False
+        if _profile:
+            from time import time
+            _t = time()
+
         # Some generators have variants depending on the current step.
         if hasattr(self.netG.module, "update_for_step"):
             self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
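For context on what these checkpoints bracket: `feed_data` splits each tensor into `mega_batch_factor` microbatches along the batch dimension via `torch.chunk` (visible in the `self.pix` context line above), and `optimize_parameters` iterates over them, so the per-stage prints below fire once per microbatch. A toy illustration of that chunking (shapes invented for the example):

```python
import torch

mega_batch_factor = 4
batch = torch.randn(16, 3, 32, 32)        # one full batch from the dataloader
micro = torch.chunk(batch, chunks=mega_batch_factor, dim=0)
print([tuple(m.shape) for m in micro])    # four microbatches of 4 images each
```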
@@ -211,10 +221,13 @@ class SRGANModel(BaseModel):
             else:
                 noise_theta = (self.D_noise_theta - self.D_noise_theta_floor) * (self.D_noise_final - min(step, self.D_noise_final)) / self.D_noise_final + self.D_noise_theta_floor

+        if _profile:
+            print("Misc setup %f" % (time() - _t,))
+            _t = time()

         self.fake_GenOut = []
         var_ref_skips = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
             #from utils import gpu_mem_track
             #import inspect
             #gpu_tracker = gpu_mem_track.MemTracker(inspect.currentframe())
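The `noise_theta` expression in the hunk above is a linear anneal: `D_noise_theta` at step 0, decaying to `D_noise_theta_floor` by step `D_noise_final` and holding there. The same formula as a standalone function, with illustrative values (the function name and numbers are not from this repo):

```python
def anneal_noise_theta(step, theta0, floor, n_final):
    # Linear decay from theta0 (at step 0) to floor (at step >= n_final).
    return (theta0 - floor) * (n_final - min(step, n_final)) / n_final + floor

assert anneal_noise_theta(0, 1.0, 0.0, 10000) == 1.0      # full noise at the start
assert anneal_noise_theta(5000, 1.0, 0.0, 10000) == 0.5   # halfway point
assert anneal_noise_theta(20000, 1.0, 0.0, 10000) == 0.0  # clamped at the floor
```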
@@ -222,6 +235,10 @@ class SRGANModel(BaseModel):
             fake_GenOut = self.netG(var_L)
             #gpu_tracker.track()

+            if _profile:
+                print("Gen forward %f" % (time() - _t,))
+                _t = time()
+
             # Extract the image output. For generators that output skip-through connections, the master output is always
             # the first element of the tuple.
             if isinstance(fake_GenOut, tuple):
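The comment in this hunk pins down a convention the rest of the loop relies on: a generator returns either a bare image tensor or a tuple whose first element is the master output, with skip-through connections trailing it. The normalization that implies, as a sketch (not the file's exact code):

```python
# Some generators return (master_img, *skip_outputs); others a bare tensor.
gen_img = fake_GenOut[0] if isinstance(fake_GenOut, tuple) else fake_GenOut
```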
@@ -245,6 +262,10 @@ class SRGANModel(BaseModel):
                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                 l_g_total += l_g_fea

+                if _profile:
+                    print("Fea forward %f" % (time() - _t,))
+                    _t = time()
+
                 # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
                 # in the resultant image.
                 if step % self.l_fea_w_decay_steps == 0:
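The decay update itself falls outside this hunk; only its trigger (`step % self.l_fea_w_decay_steps == 0`) is visible. As a purely hypothetical sketch of the stepped schedule the comment describes (the `decay` and `minimum` parameters are assumptions, not taken from this file):

```python
def decay_feature_weight(l_fea_w, decay=0.9, minimum=0.01):
    # Hypothetical stepped decay: shrink the feature-loss weight at each
    # trigger, never below a floor, so the GAN term gradually dominates.
    return max(minimum, l_fea_w * decay)
```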
@@ -267,8 +288,17 @@ class SRGANModel(BaseModel):
             with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
                 l_g_total_scaled.backward()
+
+            if _profile:
+                print("Gen backward %f" % (time() - _t,))
+                _t = time()
+
         self.optimizer_G.step()
+
+        if _profile:
+            print("Gen step %f" % (time() - _t,))
+            _t = time()

         # D
         if self.l_gan_w > 0:
             for p in self.netD.parameters():
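A caveat specific to the "Gen backward" and "Gen step" prints: `backward()` and `optimizer.step()` also enqueue asynchronous CUDA work, so without a synchronize the cost of one stage can bleed into whichever later checkpoint happens to block. When GPU-side attribution matters, CUDA events are the usual alternative; a sketch (illustrative, not part of this commit):

```python
import torch

# CUDA events time device work without a host-side sync at every checkpoint.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
# ... the stage under measurement, e.g. a forward pass or loss.backward() ...
end.record()

torch.cuda.synchronize()                 # settle once, then read the timing
print("stage took %f ms" % start.elapsed_time(end))
```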
@@ -284,6 +314,10 @@ class SRGANModel(BaseModel):
                 # The following line detaches all generator outputs that are not None.
                 fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])

+                if _profile:
+                    print("Gen forward for disc %f" % (time() - _t,))
+                    _t = time()
+
                 # Apply noise to the inputs to slow discriminator convergence.
                 var_ref = (var_ref[0] + noise,) + var_ref[1:]
                 fake_H = (fake_H[0] + noise,) + fake_H[1:]
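The `noise` tensor added here is built outside the hunk, so its exact construction is not shown; the technique itself is instance noise, added to both real and fake discriminator inputs to slow convergence. In isolation it looks like this sketch (Gaussian noise with std `noise_theta` is an assumption based on the anneal earlier in the diff):

```python
import torch

def add_instance_noise(img, noise_theta):
    # Zero-mean Gaussian instance noise; the same tensor is added to both the
    # real reference and the generator output before the discriminator sees them.
    return img + torch.randn_like(img) * noise_theta
```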
@@ -308,15 +342,33 @@ class SRGANModel(BaseModel):
                     # l_d_total.backward()
                     pred_d_fake = self.netD(fake_H).detach()
                     pred_d_real = self.netD(var_ref)
+
+                    if _profile:
+                        print("Double disc forward (RAGAN) %f" % (time() - _t,))
+                        _t = time()
+
                     l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
                     with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                         l_d_real_scaled.backward()
+
+                    if _profile:
+                        print("Disc backward 1 (RAGAN) %f" % (time() - _t,))
+                        _t = time()
+
                     pred_d_fake = self.netD(fake_H)
                     l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
                     with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                         l_d_fake_scaled.backward()
+
+                    if _profile:
+                        print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
+                        _t = time()
+
             self.optimizer_D.step()
+
+            if _profile:
+                print("Disc step %f" % (time() - _t,))
+                _t = time()

         # Log sample images from first microbatch.
         if step % 50 == 0:
             sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
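Finally, the "(RAGAN)" checkpoints bracket the two halves of the relativistic-average GAN discriminator loss: each side's logits are scored against the mean logit of the other side, with each half weighted by 0.5 (the extra division by `mega_batch_factor` just averages gradients across microbatches). Stripped of AMP and microbatching, and assuming `self.cri_gan` wraps a BCE-with-logits criterion as is typical for this setup, the quantity being computed is:

```python
import torch
import torch.nn.functional as F

def ragan_d_loss(pred_real, pred_fake):
    # Relativistic-average discriminator loss, mirroring the hunk above:
    # real logits vs. the mean fake logit, fake logits vs. the mean real logit.
    l_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.detach().mean(), torch.ones_like(pred_real))
    l_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.detach().mean(), torch.zeros_like(pred_fake))
    return 0.5 * (l_real + l_fake)
```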