from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state
from models.layers.resample2d_package.resample2d import Resample2d
from models.steps.recurrent import RecurrentController
from models.steps.injectors import Injector
import torch


def create_teco_loss(opt, env):
    """Build a TecoGAN-specific loss from its config dict.

    Args:
        opt: loss options; ``opt['type']`` selects the loss class
            ('teco_generator_gan', 'teco_discriminator_gan' or 'teco_pingpong').
        env: training environment, passed through to the loss constructor.

    Returns:
        The constructed loss, or ``None`` if ``opt['type']`` is not one of the
        TecoGAN loss types handled here.
    """
    loss_type = opt['type']
    if loss_type == 'teco_generator_gan':
        return TecoGanGeneratorLoss(opt, env)
    elif loss_type == 'teco_discriminator_gan':
        return TecoGanDiscriminatorLoss(opt, env)
    elif loss_type == "teco_pingpong":
        return PingPongLoss(opt, env)
    return None


def create_teco_discriminator_sextuplet(input_list, index, flow_gen, resampler):
    """Build the six-image discriminator input for the triplet starting at ``index``.

    Takes ``input_list[index:index + 3]`` = [first, middle, last]. It computes
    ``flow_gen(first, middle)`` and ``flow_gen(last, middle)`` and warps ``first``
    and ``last`` with ``resampler``. It then concatenates
    [first, middle, last, warped_first, middle, warped_last] along dim 1.

    Args:
        input_list: list of image tensors (presumably (b, c, h, w) -- TODO confirm).
            It must be a list: the slices are joined with list ``+``.
        index: start of the triplet; ``index + 2`` must be a valid index.
        flow_gen: callable ``flow_gen(src, dst)`` returning a flow field.
        resampler: callable ``resampler(img, flow)`` that warps ``img`` by ``flow``.

    Returns:
        ``torch.cat`` of the six tensors along dim 1.
    """
    triplet = input_list[index:index + 3]
    first_flow = flow_gen(triplet[0], triplet[1])
    last_flow = flow_gen(triplet[2], triplet[1])
    # NOTE(review): the flow fields are not detached, so gradients from the loss can reach flow_gen. Confirm intent.
    flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)]
    return torch.cat(triplet + flow_triplet, dim=1)


# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out].
# Frames are fed through the generator forward, then backward. The last frame is not regenerated on the
# way back, so len([out]) = 2*len([in]) - 1.
# NOTE(review): computation is NOT wrapped in torch.no_grad(), so autograd records these ops. Confirm whether
# gradients are meant to flow back into the generator from downstream losses.
class RecurrentImageGeneratorSequenceInjector(Injector):
    def __init__(self, opt, env):
        super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
        # Key of the flow network inside env['generators'] (a name, not a module).
        self.flow = opt['flow_network']
        self.resample = Resample2d()

    def forward(self, state):
        """Run the recurrent generator over ``state[self.input]`` forward, then backward.

        Each step warps the previous output toward the current input frame using
        the flow network. It then calls ``gen(frame, warped_previous_output)``.
        The first forward step uses an all-zero recurrent input.

        Returns:
            ``{self.output: results}``. ``results`` holds the forward outputs for
            frames 0..n-1, followed by the backward outputs for frames n-2..0.
        """
        gen = self.env['generators'][self.opt['generator']]
        flow = self.env['generators'][self.flow]
        inputs = state[self.input]
        results = []
        recurrent_input = torch.zeros_like(inputs[0])

        # Forward pass. The first frame has no previous output to warp.
        for step, frame in enumerate(inputs):
            if step > 0:
                flowfield = flow(recurrent_input, frame)
                recurrent_input = self.resample(recurrent_input, flowfield)
            recurrent_input = gen(frame, recurrent_input)
            results.append(recurrent_input)

        # Backward pass. recurrent_input already holds the output for the last frame, so start from n-2.
        # The generator is fed the original input frame, as in the forward pass, not a previous output.
        # This lets the ping-pong loss compare forward and backward results for the same frame.
        for i in reversed(range(len(inputs) - 1)):
            frame = inputs[i]
            flowfield = flow(recurrent_input, frame)
            recurrent_input = self.resample(recurrent_input, flowfield)
            recurrent_input = gen(frame, recurrent_input)
            results.append(recurrent_input)
        return {self.output: results}


# This is the temporal discriminator loss from TecoGAN.
#
# It has a strict contract for 'real' and 'fake' inputs:
#   'real' - Must be a list of at least 3 images drawn from the dataset.
#   'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
#
# This loss does the following:
# 1) Picks an image triplet, starting with the first 3 elements in 'real' and 'fake'.
# 2) Uses the image flow generator (specified with 'image_flow_generator') to create flow fields for the first and
#    last images in the above triplet. NOTE(review): these flow fields are not currently detached. Confirm intent.
# 3) Warps the first and last images according to their flow fields.
# 4) Concatenates the three base images with the two warped images and the middle image along the channel
#    dimension, for both real and fake (a bx18xhxw tensor for 3-channel images).
# 5) Feeds the concatenated real and fake image sets into the discriminator and accumulates a loss
#    (backward() is not called here).
# 6) Repeats from (1), sliding the triplet by one frame, until all triplets from the real sequence are exhausted.
class TecoGanDiscriminatorLoss(ConfigurableLoss):
    """Temporal (triplet-based) discriminator loss from TecoGAN."""

    def __init__(self, opt, env):
        super(TecoGanDiscriminatorLoss, self).__init__(opt, env)
        self.opt = opt
        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
        # Read from config but not currently used by forward().
        self.noise = opt.get('noise')
        self.image_flow_generator = opt['image_flow_generator']
        self.resampler = Resample2d()

    def forward(self, net, state):
        """Sum the discriminator loss over every sliding triplet of the sequence.

        Args:
            net: the temporal discriminator.
            state: dict holding the lists named by ``opt['real']`` and ``opt['fake']``.

        Returns:
            The summed loss (0 if ``real`` has fewer than 3 frames).

        Raises:
            NotImplementedError: if ``opt['gan_type']`` is not 'gan', 'pixgan' or 'ragan'.
        """
        self.metrics = []
        flow_gen = self.env['generators'][self.image_flow_generator]
        real = state[self.opt['real']]
        # Detach the fakes so that the discriminator loss cannot backpropagate into the generator that produced them.
        fake = [f.detach() for f in state[self.opt['fake']]]
        gan_type = self.opt['gan_type']
        l_total = 0
        for i in range(len(real) - 2):
            real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
            fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
            d_real = net(real_sext)
            d_fake = net(fake_sext)

            if gan_type in ['gan', 'pixgan']:
                self.metrics.append(("d_fake", torch.mean(d_fake)))
                self.metrics.append(("d_real", torch.mean(d_real)))
                l_real = self.criterion(d_real, True)
                l_fake = self.criterion(d_fake, False)
                l_total += l_real + l_fake
            elif gan_type == 'ragan':
                # Relativistic average GAN: each prediction is compared against the mean of the other set.
                d_fake_diff = d_fake - torch.mean(d_real)
                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
                l_total += (self.criterion(d_real - torch.mean(d_fake), True) +
                            self.criterion(d_fake_diff, False))
            else:
                raise NotImplementedError
        return l_total


class TecoGanGeneratorLoss(ConfigurableLoss):
    """Generator-side adversarial loss against the TecoGAN temporal discriminator."""

    def __init__(self, opt, env):
        super(TecoGanGeneratorLoss, self).__init__(opt, env)
        self.opt = opt
        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
        # TecoGAN parameters
        self.image_flow_generator = opt['image_flow_generator']
        self.resampler = Resample2d()

    def forward(self, net, state):
        """Sum the generator's adversarial loss over every sliding triplet of the sequence.

        Args:
            net: the temporal discriminator. This parameter was previously named ``_``
                while the body used ``net``, which raised a NameError.
            state: dict holding the lists named by ``opt['real']`` and ``opt['fake']``.

        Returns:
            The summed loss (0 if ``real`` has fewer than 3 frames).

        Raises:
            NotImplementedError: if ``opt['gan_type']`` is not 'gan', 'pixgan' or 'ragan'.
        """
        # Reset per call, matching TecoGanDiscriminatorLoss, so metrics do not accumulate across steps.
        self.metrics = []
        flow_gen = self.env['generators'][self.image_flow_generator]
        real = state[self.opt['real']]
        fake = state[self.opt['fake']]
        gan_type = self.opt['gan_type']
        l_total = 0
        for i in range(len(real) - 2):
            fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
            d_fake = net(fake_sext)

            if gan_type in ['gan', 'pixgan']:
                self.metrics.append(("d_fake", torch.mean(d_fake)))
                l_fake = self.criterion(d_fake, True)
                l_total += l_fake
            elif gan_type == 'ragan':
                # The real sextuplet is only needed for the relativistic loss, so build it here.
                real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
                d_real = net(real_sext)
                d_fake_diff = d_fake - torch.mean(d_real)
                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
                l_total += (self.criterion(d_real - torch.mean(d_fake), False) +
                            self.criterion(d_fake_diff, True))
            else:
                raise NotImplementedError
        return l_total


# This loss doesn't have a real entry - only fakes are used.
class PingPongLoss(ConfigurableLoss):
    """TecoGAN ping-pong loss between forward and backward generations of the same frame."""

    def __init__(self, opt, env):
        super(PingPongLoss, self).__init__(opt, env)
        self.opt = opt
        # Pixel-wise L2 between two image tensors. GANLoss cannot be used here: its second argument is a
        # real/fake flag, not a target tensor.
        self.criterion = torch.nn.MSELoss()

    def forward(self, _, state):
        """Compare each forward-pass output with the backward-pass output for the same frame.

        ``state[opt['fake']]`` is expected to be the list produced by
        RecurrentImageGeneratorSequenceInjector: [f_0, ..., f_{n-1}, b_{n-2}, ..., b_0].
        Here b_i sits at index -(i + 1), and f_{n-1} has no backward counterpart.

        Returns:
            The summed per-frame loss (0 for a sequence of length <= 2).
        """
        fake = state[self.opt['fake']]
        l_total = 0
        # Integer division: range() rejects floats.
        for i in range((len(fake) - 1) // 2):
            early = fake[i]
            # -(i + 1), not -i: with -i, i == 0 would compare fake[0] with itself.
            late = fake[-(i + 1)]
            l_total += self.criterion(early, late)
        return l_total