diff --git a/codes/models/loss.py b/codes/models/loss.py index ef61b84c..43f6ae7d 100644 --- a/codes/models/loss.py +++ b/codes/models/loss.py @@ -54,7 +54,7 @@ class GANLoss(nn.Module): target_label = target_is_real else: target_label = self.get_target_label(input, target_is_real) - loss = self.loss(input, target_label) + loss = self.loss(input.float(), target_label.float()) return loss diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 4641ac05..e31d5a64 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -91,7 +91,7 @@ class PixLoss(ConfigurableLoss): self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) def forward(self, _, state): - return self.criterion(state[self.opt['fake']], state[self.opt['real']]) + return self.criterion(state[self.opt['fake']].float(), state[self.opt['real']].float()) class FeatureLoss(ConfigurableLoss): @@ -105,13 +105,14 @@ class FeatureLoss(ConfigurableLoss): self.netF = torch.nn.parallel.DataParallel(self.netF) def forward(self, _, state): - with torch.no_grad(): - logits_real = self.netF(state[self.opt['real']]) - logits_fake = self.netF(state[self.opt['fake']]) + with autocast(enabled=self.env['opt']['fp16']): + with torch.no_grad(): + logits_real = self.netF(state[self.opt['real']]) + logits_fake = self.netF(state[self.opt['fake']]) if self.opt['criterion'] == 'cosine': - return self.criterion(logits_fake, logits_real, torch.ones(1, device=logits_fake.device)) + return self.criterion(logits_fake.float(), logits_real.float(), torch.ones(1, device=logits_fake.device)) else: - return self.criterion(logits_fake, logits_real) + return self.criterion(logits_fake.float(), logits_real.float()) # Special form of feature loss which first computes the feature embedding for the truth space, then uses a second @@ -132,7 +133,7 @@ class InterpretedFeatureLoss(ConfigurableLoss): def forward(self, _, state): logits_real = self.netF_real(state[self.opt['real']]) logits_fake = self.netF_gen(state[self.opt['fake']]) - return self.criterion(logits_fake, logits_real) + return self.criterion(logits_fake.float(), logits_real.float()) class GeneratorGanLoss(ConfigurableLoss): @@ -300,7 +301,7 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss): if self.opt['criterion'] == 'cosine': return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device)) else: - return self.criterion(state[self.opt['real']], upsampled_altered) + return self.criterion(state[self.opt['real']].float(), upsampled_altered.float()) # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an @@ -353,7 +354,7 @@ class TranslationInvarianceLoss(ConfigurableLoss): if self.opt['criterion'] == 'cosine': return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device)) else: - return self.criterion(fake_shared_output, real_shared_output) + return self.criterion(fake_shared_output.float(), real_shared_output.float()) # Computes a loss repeatedly feeding the generator downsampled inputs created from its outputs. The expectation is @@ -392,7 +393,7 @@ class RecursiveInvarianceLoss(ConfigurableLoss): if self.opt['criterion'] == 'cosine': return self.criterion(compare_real, compare_fake, torch.ones(1, device=compare_real.device)) else: - return self.criterion(compare_real, compare_fake) + return self.criterion(compare_real.float(), compare_fake.float()) # Loss that pulls tensors from dim 1 of the input and repeatedly feeds them into the diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 9afa97d9..db301c7a 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -207,8 +207,8 @@ class TecoGanLoss(ConfigurableLoss): lr = state[self.opt['lr_inputs']] l_total = 0 for i in range(sequence_len - 2): - real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16) - fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16) + real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin) + fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin) with autocast(enabled=fp16): d_fake = net(fake_sext) d_real = net(real_sext) diff --git a/codes/multi_modal_train.py b/codes/multi_modal_train.py index 604061cd..fd32ba17 100644 --- a/codes/multi_modal_train.py +++ b/codes/multi_modal_train.py @@ -25,7 +25,7 @@ def main(master_opt, launcher): for k, v in model.networks.items(): if k in all_networks.keys() and k not in shared_networks: shared_networks.append(k) - all_networks[k] = v + all_networks[k] = v.module trainers.append(train_gen) print("Networks being shared by trainers: ", shared_networks) diff --git a/codes/train.py b/codes/train.py index 02d6173f..45b6da28 100644 --- a/codes/train.py +++ b/codes/train.py @@ -175,7 +175,7 @@ class Trainer: #### log if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0: logs = self.model.get_current_log(self.current_step) - message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, self.current_step) + message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step) for v in self.model.get_current_learning_rate(): message += '{:.3e},'.format(v) message += ')] ' @@ -196,7 +196,7 @@ class Trainer: if self.rank <= 0: self.logger.info('Saving models and training states.') self.model.save(self.current_step) - self.model.save_training_state(epoch, self.current_step) + self.model.save_training_state(self.epoch, self.current_step) if 'alt_path' in opt['path'].keys(): import shutil print("Synchronizing tb_logger to alt_path..") @@ -253,6 +253,7 @@ class Trainer: def do_training(self): self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step)) for epoch in range(self.start_epoch, self.total_epochs + 1): + self.epoch = epoch if opt['dist']: self.train_sampler.set_epoch(epoch) tq_ldr = tqdm(self.train_loader) @@ -264,6 +265,7 @@ class Trainer: def create_training_generator(self, index): self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step)) for epoch in range(self.start_epoch, self.total_epochs + 1): + self.epoch = epoch if self.opt['dist']: self.train_sampler.set_epoch(epoch) tq_ldr = tqdm(self.train_loader, position=index)