forked from mrq/DL-Art-School
Fix initialization in mhead switched rrdb
This commit is contained in:
parent
be7982b9ae
commit
0a714e8451
|
@ -85,6 +85,7 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C):
|
|||
late_stage_padding=0)
|
||||
self.heads = heads
|
||||
self.converter = nn.Conv3d(nf, nf, kernel_size=(heads, 1, 1), stride=(heads, 1, 1))
|
||||
arch_util.initialize_weights(self.converter)
|
||||
|
||||
# Accepts input of shape (b, heads, f, w, h)
|
||||
def forward(self, x):
|
||||
|
|
|
@ -13,7 +13,7 @@ def initialize_weights(net_l, scale=1):
|
|||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
|
|
|
@ -80,14 +80,7 @@ def define_G(opt, net_key='network_G'):
|
|||
'''netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
|
||||
assembler_blocks=opt_net['assembler_blocks'])'''
|
||||
netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])
|
||||
# video restoration
|
||||
elif which_model == 'EDVR':
|
||||
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
|
||||
groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
|
||||
back_RBs=opt_net['back_RBs'], center=opt_net['center'],
|
||||
predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
|
||||
w_TSA=opt_net['w_TSA'])
|
||||
netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])\
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
|
|
Loading…
Reference in New Issue
Block a user