From 0a714e8451e1c168b689ade39eb11fe9d59747a2 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Mon, 15 Jun 2020 21:32:03 -0600
Subject: [PATCH] Fix initialization in mhead switched rrdb

---
 codes/models/archs/RRDBNet_arch.py | 1 +
 codes/models/archs/arch_util.py    | 2 +-
 codes/models/networks.py           | 9 +--------
 3 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index fc7725ce..af0139ac 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -85,6 +85,7 @@ class ResidualDenseBlock_5C_WithMheadConverter(ResidualDenseBlock_5C):
                                                                        late_stage_padding=0)
         self.heads = heads
         self.converter = nn.Conv3d(nf, nf, kernel_size=(heads, 1, 1), stride=(heads, 1, 1))
+        arch_util.initialize_weights(self.converter)
 
     # Accepts input of shape (b, heads, f, w, h)
     def forward(self, x):
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index ecb7be76..f79d4532 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -13,7 +13,7 @@ def initialize_weights(net_l, scale=1):
         net_l = [net_l]
     for net in net_l:
         for m in net.modules():
-            if isinstance(m, nn.Conv2d):
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d):
                 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                 m.weight.data *= scale  # for residual block
                 if m.bias is not None:
diff --git a/codes/models/networks.py b/codes/models/networks.py
index cc1fb6e0..1da351dc 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -80,14 +80,7 @@ def define_G(opt, net_key='network_G'):
         '''netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                 nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
                                 assembler_blocks=opt_net['assembler_blocks'])'''
-        netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])
-    # video restoration
-    elif which_model == 'EDVR':
-        netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
-                              groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
-                              back_RBs=opt_net['back_RBs'], center=opt_net['center'],
-                              predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
-                              w_TSA=opt_net['w_TSA'])
+        netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])\
 
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))