More mods
This commit is contained in:
parent
dffbfd2ec4
commit
c04f244802
|
@ -543,9 +543,9 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False)
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
|
||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
|
||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
|
||||
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
|
||||
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
|
||||
self.attentions = None
|
||||
|
|
|
@ -95,10 +95,15 @@ class ConfigurableStep(Module):
|
|||
local_state.update(injected)
|
||||
new_state.update(injected)
|
||||
|
||||
if train:
|
||||
if train and len(self.losses) > 0:
|
||||
# Finally, compute the losses.
|
||||
total_loss = 0
|
||||
for loss_name, loss in self.losses.items():
|
||||
# Some losses only activate after a set number of steps. For example, proto-discriminator losses can
|
||||
# be very disruptive to a generator.
|
||||
if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
|
||||
continue
|
||||
|
||||
l = loss(self.training_net, local_state)
|
||||
total_loss += l * self.weights[loss_name]
|
||||
# Record metrics.
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user