Add profiling to SRGAN for testing timings
This commit is contained in:
parent
45a900fafe
commit
d2d5e097d5
|
@ -167,6 +167,11 @@ class SRGANModel(BaseModel):
|
|||
self.load_random_corruptor()
|
||||
|
||||
def feed_data(self, data, need_GT=True):
|
||||
_profile = True
|
||||
if _profile:
|
||||
from time import time
|
||||
_t = time()
|
||||
|
||||
# Corrupt the data with the given corruptor, if specified.
|
||||
self.fed_LQ = data['LQ'].to(self.device)
|
||||
if self.netC and random.random() < self.corruptor_usage_prob:
|
||||
|
@ -183,6 +188,11 @@ class SRGANModel(BaseModel):
|
|||
self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
_profile = False
|
||||
if _profile:
|
||||
from time import time
|
||||
_t = time()
|
||||
|
||||
# Some generators have variants depending on the current step.
|
||||
if hasattr(self.netG.module, "update_for_step"):
|
||||
self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||
|
@ -211,10 +221,13 @@ class SRGANModel(BaseModel):
|
|||
else:
|
||||
noise_theta = (self.D_noise_theta - self.D_noise_theta_floor) * (self.D_noise_final - min(step, self.D_noise_final)) / self.D_noise_final + self.D_noise_theta_floor
|
||||
|
||||
if _profile:
|
||||
print("Misc setup %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
self.fake_GenOut = []
|
||||
var_ref_skips = []
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
|
||||
#from utils import gpu_mem_track
|
||||
#import inspect
|
||||
#gpu_tracker = gpu_mem_track.MemTracker(inspect.currentframe())
|
||||
|
@ -222,6 +235,10 @@ class SRGANModel(BaseModel):
|
|||
fake_GenOut = self.netG(var_L)
|
||||
#gpu_tracker.track()
|
||||
|
||||
if _profile:
|
||||
print("Gen forward %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
# Extract the image output. For generators that output skip-through connections, the master output is always
|
||||
# the first element of the tuple.
|
||||
if isinstance(fake_GenOut, tuple):
|
||||
|
@ -245,6 +262,10 @@ class SRGANModel(BaseModel):
|
|||
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fea
|
||||
|
||||
if _profile:
|
||||
print("Fea forward %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
# Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
|
||||
# in the resultant image.
|
||||
if step % self.l_fea_w_decay_steps == 0:
|
||||
|
@ -267,8 +288,17 @@ class SRGANModel(BaseModel):
|
|||
|
||||
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
|
||||
l_g_total_scaled.backward()
|
||||
|
||||
if _profile:
|
||||
print("Gen backward %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
self.optimizer_G.step()
|
||||
|
||||
if _profile:
|
||||
print("Gen step %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
# D
|
||||
if self.l_gan_w > 0:
|
||||
for p in self.netD.parameters():
|
||||
|
@ -284,6 +314,10 @@ class SRGANModel(BaseModel):
|
|||
# The following line detaches all generator outputs that are not None.
|
||||
fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
|
||||
|
||||
if _profile:
|
||||
print("Gen forward for disc %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
# Apply noise to the inputs to slow discriminator convergence.
|
||||
var_ref = (var_ref[0] + noise,) + var_ref[1:]
|
||||
fake_H = (fake_H[0] + noise,) + fake_H[1:]
|
||||
|
@ -308,15 +342,33 @@ class SRGANModel(BaseModel):
|
|||
# l_d_total.backward()
|
||||
pred_d_fake = self.netD(fake_H).detach()
|
||||
pred_d_real = self.netD(var_ref)
|
||||
|
||||
if _profile:
|
||||
print("Double disc forward (RAGAN) %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 / self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
|
||||
if _profile:
|
||||
print("Disc backward 1 (RAGAN) %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
pred_d_fake = self.netD(fake_H)
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 / self.mega_batch_factor
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
|
||||
if _profile:
|
||||
print("Disc forward/backward 2 (RAGAN) %f" % (time() - _t,))
|
||||
_t = time()
|
||||
self.optimizer_D.step()
|
||||
|
||||
if _profile:
|
||||
print("Disc step %f" % (time() - _t,))
|
||||
_t = time()
|
||||
|
||||
# Log sample images from first microbatch.
|
||||
if step % 50 == 0:
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
|
||||
|
|
Loading…
Reference in New Issue
Block a user