Adjustments to pixpro & resnet-unet

I'm not really satisfied with what I got out of these networks on round 1.
Lets try again..
This commit is contained in:
James Betker 2021-01-06 15:00:46 -07:00
parent 9680294430
commit 01a589e712
2 changed files with 7 additions and 5 deletions

View File

@ -279,9 +279,11 @@ class PixelCL(nn.Module):
super().__init__() super().__init__()
DEFAULT_AUG = nn.Sequential( DEFAULT_AUG = nn.Sequential(
RandomApply(augs.ColorJitter(0.3, 0.3, 0.3, 0.2), p=0.8), RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2), augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1) RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomSolarize(p=0.5),
# Normalize left out because it should be done at the model level.
) )
self.augment1 = default(augment_fn, DEFAULT_AUG) self.augment1 = default(augment_fn, DEFAULT_AUG)

View File

@ -81,7 +81,7 @@ class UResNet50(torchvision.models.resnet.ResNet):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None, groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None): norm_layer=None, out_dim=128):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
replace_stride_with_dilation, norm_layer) replace_stride_with_dilation, norm_layer)
if norm_layer is None: if norm_layer is None:
@ -110,7 +110,7 @@ class UResNet50(torchvision.models.resnet.ResNet):
conv3x3(512, 512), conv3x3(512, 512),
norm_layer(512), norm_layer(512),
nn.ReLU(), nn.ReLU(),
conv1x1(512, 128)) conv1x1(512, out_dim))
del self.fc # Not used in this implementation and just consumes a ton of GPU memory. del self.fc # Not used in this implementation and just consumes a ton of GPU memory.
@ -142,7 +142,7 @@ class UResNet50(torchvision.models.resnet.ResNet):
@register_model @register_model
def register_u_resnet50(opt_net, opt): def register_u_resnet50(opt_net, opt):
model = UResNet50(Bottleneck, [3, 4, 6, 3]) model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim'])
return model return model