More adjustments to support multi-modal training

Specifically, it looks like at least the MSE loss cannot handle autocast tensors, so the inputs to the loss criteria are now cast with `.float()` before the loss is computed.
This commit is contained in:
James Betker 2020-10-22 16:49:34 -06:00
parent 76789a456f
commit e9c0b9f0fd
5 changed files with 19 additions and 16 deletions

View File

@ -54,7 +54,7 @@ class GANLoss(nn.Module):
target_label = target_is_real target_label = target_is_real
else: else:
target_label = self.get_target_label(input, target_is_real) target_label = self.get_target_label(input, target_is_real)
loss = self.loss(input, target_label) loss = self.loss(input.float(), target_label.float())
return loss return loss

View File

@ -91,7 +91,7 @@ class PixLoss(ConfigurableLoss):
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
def forward(self, _, state): def forward(self, _, state):
return self.criterion(state[self.opt['fake']], state[self.opt['real']]) return self.criterion(state[self.opt['fake']].float(), state[self.opt['real']].float())
class FeatureLoss(ConfigurableLoss): class FeatureLoss(ConfigurableLoss):
@ -105,13 +105,14 @@ class FeatureLoss(ConfigurableLoss):
self.netF = torch.nn.parallel.DataParallel(self.netF) self.netF = torch.nn.parallel.DataParallel(self.netF)
def forward(self, _, state): def forward(self, _, state):
with torch.no_grad(): with autocast(enabled=self.env['opt']['fp16']):
logits_real = self.netF(state[self.opt['real']]) with torch.no_grad():
logits_fake = self.netF(state[self.opt['fake']]) logits_real = self.netF(state[self.opt['real']])
logits_fake = self.netF(state[self.opt['fake']])
if self.opt['criterion'] == 'cosine': if self.opt['criterion'] == 'cosine':
return self.criterion(logits_fake, logits_real, torch.ones(1, device=logits_fake.device)) return self.criterion(logits_fake.float(), logits_real.float(), torch.ones(1, device=logits_fake.device))
else: else:
return self.criterion(logits_fake, logits_real) return self.criterion(logits_fake.float(), logits_real.float())
# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second # Special form of feature loss which first computes the feature embedding for the truth space, then uses a second
@ -132,7 +133,7 @@ class InterpretedFeatureLoss(ConfigurableLoss):
def forward(self, _, state): def forward(self, _, state):
logits_real = self.netF_real(state[self.opt['real']]) logits_real = self.netF_real(state[self.opt['real']])
logits_fake = self.netF_gen(state[self.opt['fake']]) logits_fake = self.netF_gen(state[self.opt['fake']])
return self.criterion(logits_fake, logits_real) return self.criterion(logits_fake.float(), logits_real.float())
class GeneratorGanLoss(ConfigurableLoss): class GeneratorGanLoss(ConfigurableLoss):
@ -300,7 +301,7 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
if self.opt['criterion'] == 'cosine': if self.opt['criterion'] == 'cosine':
return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device)) return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device))
else: else:
return self.criterion(state[self.opt['real']], upsampled_altered) return self.criterion(state[self.opt['real']].float(), upsampled_altered.float())
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
@ -353,7 +354,7 @@ class TranslationInvarianceLoss(ConfigurableLoss):
if self.opt['criterion'] == 'cosine': if self.opt['criterion'] == 'cosine':
return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device)) return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device))
else: else:
return self.criterion(fake_shared_output, real_shared_output) return self.criterion(fake_shared_output.float(), real_shared_output.float())
# Computes a loss repeatedly feeding the generator downsampled inputs created from its outputs. The expectation is # Computes a loss repeatedly feeding the generator downsampled inputs created from its outputs. The expectation is
@ -392,7 +393,7 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
if self.opt['criterion'] == 'cosine': if self.opt['criterion'] == 'cosine':
return self.criterion(compare_real, compare_fake, torch.ones(1, device=compare_real.device)) return self.criterion(compare_real, compare_fake, torch.ones(1, device=compare_real.device))
else: else:
return self.criterion(compare_real, compare_fake) return self.criterion(compare_real.float(), compare_fake.float())
# Loss that pulls tensors from dim 1 of the input and repeatedly feeds them into the # Loss that pulls tensors from dim 1 of the input and repeatedly feeds them into the

View File

@ -207,8 +207,8 @@ class TecoGanLoss(ConfigurableLoss):
lr = state[self.opt['lr_inputs']] lr = state[self.opt['lr_inputs']]
l_total = 0 l_total = 0
for i in range(sequence_len - 2): for i in range(sequence_len - 2):
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16) real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16) fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
with autocast(enabled=fp16): with autocast(enabled=fp16):
d_fake = net(fake_sext) d_fake = net(fake_sext)
d_real = net(real_sext) d_real = net(real_sext)

View File

@ -25,7 +25,7 @@ def main(master_opt, launcher):
for k, v in model.networks.items(): for k, v in model.networks.items():
if k in all_networks.keys() and k not in shared_networks: if k in all_networks.keys() and k not in shared_networks:
shared_networks.append(k) shared_networks.append(k)
all_networks[k] = v all_networks[k] = v.module
trainers.append(train_gen) trainers.append(train_gen)
print("Networks being shared by trainers: ", shared_networks) print("Networks being shared by trainers: ", shared_networks)

View File

@ -175,7 +175,7 @@ class Trainer:
#### log #### log
if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0: if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
logs = self.model.get_current_log(self.current_step) logs = self.model.get_current_log(self.current_step)
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, self.current_step) message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
for v in self.model.get_current_learning_rate(): for v in self.model.get_current_learning_rate():
message += '{:.3e},'.format(v) message += '{:.3e},'.format(v)
message += ')] ' message += ')] '
@ -196,7 +196,7 @@ class Trainer:
if self.rank <= 0: if self.rank <= 0:
self.logger.info('Saving models and training states.') self.logger.info('Saving models and training states.')
self.model.save(self.current_step) self.model.save(self.current_step)
self.model.save_training_state(epoch, self.current_step) self.model.save_training_state(self.epoch, self.current_step)
if 'alt_path' in opt['path'].keys(): if 'alt_path' in opt['path'].keys():
import shutil import shutil
print("Synchronizing tb_logger to alt_path..") print("Synchronizing tb_logger to alt_path..")
@ -253,6 +253,7 @@ class Trainer:
def do_training(self): def do_training(self):
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step)) self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
for epoch in range(self.start_epoch, self.total_epochs + 1): for epoch in range(self.start_epoch, self.total_epochs + 1):
self.epoch = epoch
if opt['dist']: if opt['dist']:
self.train_sampler.set_epoch(epoch) self.train_sampler.set_epoch(epoch)
tq_ldr = tqdm(self.train_loader) tq_ldr = tqdm(self.train_loader)
@ -264,6 +265,7 @@ class Trainer:
def create_training_generator(self, index): def create_training_generator(self, index):
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step)) self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
for epoch in range(self.start_epoch, self.total_epochs + 1): for epoch in range(self.start_epoch, self.total_epochs + 1):
self.epoch = epoch
if self.opt['dist']: if self.opt['dist']:
self.train_sampler.set_epoch(epoch) self.train_sampler.set_epoch(epoch)
tq_ldr = tqdm(self.train_loader, position=index) tq_ldr = tqdm(self.train_loader, position=index)