diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index acfd3ce9..8b763b09 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -149,7 +149,10 @@ class RRDBNet(nn.Module): self.num_blocks = num_blocks self.blocks_per_checkpoint = blocks_per_checkpoint self.scale = scale - self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) + self.in_channels = in_channels + first_conv_stride = 1 if in_channels <= 4 else scale + first_conv_ksize = 3 if first_conv_stride == 1 else 7 + self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, 1) self.body = make_layer( body_block, num_blocks, @@ -170,7 +173,7 @@ class RRDBNet(nn.Module): ]: default_init_weights(m, 0.1) - def forward(self, x): + def forward(self, x, ref=None): """Forward function. Args: @@ -179,8 +182,12 @@ class RRDBNet(nn.Module): Returns: Tensor: Forward results. """ - - feat = self.conv_first(x) + if self.in_channels > 4: + x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic") + if ref is None: + ref = torch.zeros_like(x_lg) + x_lg = torch.cat([x_lg, ref]) + feat = self.conv_first(x_lg) body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)) feat = feat + body_feat # upsample