More adjustments to support multi-modal training
Specifically, it looks like at least MSE loss cannot handle autocasted tensors.
parent 76789a456f
commit e9c0b9f0fd
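For context on the fix: under torch.cuda.amp.autocast, CUDA ops return fp16 tensors, and (per the commit message) at least MSE loss does not handle those autocast outputs well, so the diff casts both operands back to fp32 with .float() immediately before each criterion. Below is a minimal sketch of that pattern; the names (pixel_loss, net, lq, hq) are illustrative and not code from this repository.

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

# Hypothetical example, not code from this repo: it mirrors the .float()
# casts added throughout this commit.
def pixel_loss(net, lq, hq, fp16=True):
    criterion = nn.MSELoss()
    with autocast(enabled=fp16):
        fake = net(lq)  # fp16 on CUDA when autocast is enabled
    # Cast both operands to fp32 so the MSE is computed at full precision,
    # regardless of whether fp16 training is enabled.
    return criterion(fake.float(), hq.float())

The same cast-before-criterion pattern appears in every loss hunk below, including the cosine-embedding branches.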
@@ -54,7 +54,7 @@ class GANLoss(nn.Module):
             target_label = target_is_real
         else:
             target_label = self.get_target_label(input, target_is_real)
-        loss = self.loss(input, target_label)
+        loss = self.loss(input.float(), target_label.float())
         return loss

@@ -91,7 +91,7 @@ class PixLoss(ConfigurableLoss):
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])

     def forward(self, _, state):
-        return self.criterion(state[self.opt['fake']], state[self.opt['real']])
+        return self.criterion(state[self.opt['fake']].float(), state[self.opt['real']].float())


 class FeatureLoss(ConfigurableLoss):

@@ -105,13 +105,14 @@ class FeatureLoss(ConfigurableLoss):
             self.netF = torch.nn.parallel.DataParallel(self.netF)

     def forward(self, _, state):
-        with torch.no_grad():
-            logits_real = self.netF(state[self.opt['real']])
-        logits_fake = self.netF(state[self.opt['fake']])
+        with autocast(enabled=self.env['opt']['fp16']):
+            with torch.no_grad():
+                logits_real = self.netF(state[self.opt['real']])
+            logits_fake = self.netF(state[self.opt['fake']])
         if self.opt['criterion'] == 'cosine':
-            return self.criterion(logits_fake, logits_real, torch.ones(1, device=logits_fake.device))
+            return self.criterion(logits_fake.float(), logits_real.float(), torch.ones(1, device=logits_fake.device))
         else:
-            return self.criterion(logits_fake, logits_real)
+            return self.criterion(logits_fake.float(), logits_real.float())


 # Special form of feature loss which first computes the feature embedding for the truth space, then uses a second

@@ -132,7 +133,7 @@ class InterpretedFeatureLoss(ConfigurableLoss):
     def forward(self, _, state):
         logits_real = self.netF_real(state[self.opt['real']])
         logits_fake = self.netF_gen(state[self.opt['fake']])
-        return self.criterion(logits_fake, logits_real)
+        return self.criterion(logits_fake.float(), logits_real.float())


 class GeneratorGanLoss(ConfigurableLoss):

@@ -300,7 +301,7 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
         if self.opt['criterion'] == 'cosine':
             return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device))
         else:
-            return self.criterion(state[self.opt['real']], upsampled_altered)
+            return self.criterion(state[self.opt['real']].float(), upsampled_altered.float())


 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an

@@ -353,7 +354,7 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         if self.opt['criterion'] == 'cosine':
             return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device))
         else:
-            return self.criterion(fake_shared_output, real_shared_output)
+            return self.criterion(fake_shared_output.float(), real_shared_output.float())


 # Computes a loss repeatedly feeding the generator downsampled inputs created from its outputs. The expectation is

@@ -392,7 +393,7 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
         if self.opt['criterion'] == 'cosine':
             return self.criterion(compare_real, compare_fake, torch.ones(1, device=compare_real.device))
         else:
-            return self.criterion(compare_real, compare_fake)
+            return self.criterion(compare_real.float(), compare_fake.float())


 # Loss that pulls tensors from dim 1 of the input and repeatedly feeds them into the

@@ -207,8 +207,8 @@ class TecoGanLoss(ConfigurableLoss):
         lr = state[self.opt['lr_inputs']]
         l_total = 0
         for i in range(sequence_len - 2):
-            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
-            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
+            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
+            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
             with autocast(enabled=fp16):
                 d_fake = net(fake_sext)
                 d_real = net(real_sext)

@@ -25,7 +25,7 @@ def main(master_opt, launcher):
         for k, v in model.networks.items():
             if k in all_networks.keys() and k not in shared_networks:
                 shared_networks.append(k)
-            all_networks[k] = v
+            all_networks[k] = v.module
         trainers.append(train_gen)
     print("Networks being shared by trainers: ", shared_networks)

@@ -175,7 +175,7 @@ class Trainer:
         #### log
         if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
             logs = self.model.get_current_log(self.current_step)
-            message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, self.current_step)
+            message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
             for v in self.model.get_current_learning_rate():
                 message += '{:.3e},'.format(v)
             message += ')] '

@@ -196,7 +196,7 @@ class Trainer:
             if self.rank <= 0:
                 self.logger.info('Saving models and training states.')
                 self.model.save(self.current_step)
-                self.model.save_training_state(epoch, self.current_step)
+                self.model.save_training_state(self.epoch, self.current_step)
             if 'alt_path' in opt['path'].keys():
                 import shutil
                 print("Synchronizing tb_logger to alt_path..")

@@ -253,6 +253,7 @@ class Trainer:
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
+            self.epoch = epoch
             if opt['dist']:
                 self.train_sampler.set_epoch(epoch)
             tq_ldr = tqdm(self.train_loader)

@@ -264,6 +265,7 @@ class Trainer:
     def create_training_generator(self, index):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
+            self.epoch = epoch
             if self.opt['dist']:
                 self.train_sampler.set_epoch(epoch)
             tq_ldr = tqdm(self.train_loader, position=index)