forked from mrq/DL-Art-School

More mods

commit c04f244802 (parent dffbfd2ec4)
```diff
@@ -543,9 +543,9 @@ class SwitchedSpsrWithRef2(nn.Module):
                                              attention_norm=True,
                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                              add_scalable_noise_to_transforms=False)
-        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
-        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
-        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False)
+        self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
+        self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
+        self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
         self.attentions = None
```
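This hunk flips `norm=False` to `norm=True` on the tail of the network: the final LR conv, every `UpconvBlock` in the upsampler, and the first HR conv now apply normalization, while the last conv (`final_hr_conv2`) stays un-normalized so the output range is unconstrained. `ConvGnLelu` is defined elsewhere in this repo; as a reading aid, here is a minimal sketch of what a Conv + GroupNorm + LeakyReLU block of this shape typically looks like. The class name, default group count, and LeakyReLU slope are assumptions, not the repo's actual implementation:

```python
import torch.nn as nn

class ConvGnLeluSketch(nn.Module):
    """Hypothetical stand-in for ConvGnLelu: Conv2d + optional GroupNorm + optional LeakyReLU.
    The real DL-Art-School block may differ in defaults and group count."""
    def __init__(self, in_nc, out_nc, kernel_size=3, norm=True, activation=True,
                 bias=True, num_groups=8):
        super().__init__()
        layers = [nn.Conv2d(in_nc, out_nc, kernel_size,
                            padding=kernel_size // 2, bias=bias)]
        if norm:
            # This is the switch the commit turns on for the tail convs.
            layers.append(nn.GroupNorm(num_groups, out_nc))
        if activation:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
```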
```diff
@@ -95,10 +95,15 @@ class ConfigurableStep(Module):
         local_state.update(injected)
         new_state.update(injected)
 
-        if train:
+        if train and len(self.losses) > 0:
             # Finally, compute the losses.
             total_loss = 0
             for loss_name, loss in self.losses.items():
+                # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
+                # be very disruptive to a generator.
+                if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
+                    continue
+
                 l = loss(self.training_net, local_state)
                 total_loss += l * self.weights[loss_name]
                 # Record metrics.
```
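Two guards are added here: the loss pass is skipped entirely when a step has no losses configured, and each individual loss can declare an `after` option that delays it until the global step passes that threshold (useful when, e.g., a proto-discriminator loss would destabilize an early generator). A self-contained sketch of the same gating pattern follows; the function and parameter names are illustrative stand-ins for `ConfigurableStep` internals, not the repo's API:

```python
# Minimal sketch of step-gated loss accumulation, assuming each loss object
# carries an options dict like loss.opt in the diff above.
def compute_total_loss(losses, weights, step, net, state):
    total = 0.0
    for name, loss in losses.items():
        opt = getattr(loss, 'opt', {})
        # Skip losses that only activate after a set number of steps.
        if 'after' in opt and opt['after'] > step:
            continue
        total += loss(net, state) * weights[name]
    return total
```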
```diff
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr3_gan.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/pretrain_spsr_switched2_psnr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
```
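The last hunk only swaps the default `-opt` config from the SPSR3 GAN recipe to the switched-SPSR PSNR pretraining recipe; the previous config can still be selected explicitly. A hedged invocation example, assuming the training entry point is the script containing `main()` (the name `train.py` is an assumption from the relative `../options/` paths):

```
python train.py -opt ../options/train_imgset_spsr3_gan.yml
```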