Enable RRDB to take in reference inputs
This commit is contained in:
parent
7d38381d46
commit
3791f95ad0
|
@ -149,7 +149,10 @@ class RRDBNet(nn.Module):
|
|||
self.num_blocks = num_blocks
|
||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||
self.scale = scale
|
||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
|
||||
self.in_channels = in_channels
|
||||
first_conv_stride = 1 if in_channels <= 4 else scale
|
||||
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, 1)
|
||||
self.body = make_layer(
|
||||
body_block,
|
||||
num_blocks,
|
||||
|
@ -170,7 +173,7 @@ class RRDBNet(nn.Module):
|
|||
]:
|
||||
default_init_weights(m, 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, ref=None):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
|
@ -179,8 +182,12 @@ class RRDBNet(nn.Module):
|
|||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
|
||||
feat = self.conv_first(x)
|
||||
if self.in_channels > 4:
|
||||
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(x_lg)
|
||||
x_lg = torch.cat([x_lg, ref])
|
||||
feat = self.conv_first(x_lg)
|
||||
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
|
||||
feat = feat + body_feat
|
||||
# upsample
|
||||
|
|
Loading…
Reference in New Issue
Block a user