From 8dd44182e63d8f45fcc0b1c5e1f7cfbc113bf13f Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 31 Jul 2020 16:56:04 -0600
Subject: [PATCH] Fix scale torch warning

---
 codes/models/archs/SwitchedResidualGenerator_arch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 3a139dc7..fdc7ddaf 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -18,7 +18,7 @@ class MultiConvBlock(nn.Module):
         self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
                                      [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
                                      [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
-        self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init))
+        self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float))
         self.bias = nn.Parameter(torch.zeros(1))
 
     def forward(self, x, noise=None):
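
A minimal sketch of the behavior this patch addresses, assuming the warning in question is the torch.full dtype-inference deprecation notice from the PyTorch 1.5/1.6 era: when fill_value is an integral Python number and no dtype is given, torch.full warns that its default result dtype will change in a future release. Pinning dtype=torch.float keeps the scale parameter floating point and silences the warning. The scale_init value below is hypothetical and chosen only for illustration.

import torch
import torch.nn as nn

scale_init = 1  # hypothetical integral init value, for illustration only

# Without an explicit dtype, an integral fill_value can trigger the
# torch.full deprecation warning on affected PyTorch versions:
#   scale = nn.Parameter(torch.full((1,), fill_value=scale_init))

# The patched form pins the dtype, so the parameter is float32 and no
# warning is emitted:
scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float))
print(scale.dtype)  # torch.float32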