diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 3a139dc7..fdc7ddaf 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -18,7 +18,7 @@ class MultiConvBlock(nn.Module): self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] + [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] + [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)]) - self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init)) + self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float)) self.bias = nn.Parameter(torch.zeros(1)) def forward(self, x, noise=None):