More mods

This commit is contained in:
James Betker 2020-09-08 20:36:27 -06:00
parent dffbfd2ec4
commit c04f244802
3 changed files with 10 additions and 5 deletions

View File

@ -543,9 +543,9 @@ class SwitchedSpsrWithRef2(nn.Module):
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
self.attentions = None

View File

@ -95,10 +95,15 @@ class ConfigurableStep(Module):
local_state.update(injected)
new_state.update(injected)
if train:
if train and len(self.losses) > 0:
# Finally, compute the losses.
total_loss = 0
for loss_name, loss in self.losses.items():
# Some losses only activate after a set number of steps. For example, proto-discriminator losses can
# be very disruptive to a generator.
if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
continue
l = loss(self.training_net, local_state)
total_loss += l * self.weights[loss_name]
# Record metrics.

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)