diff --git a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py index 9d13da31..de5e11e8 100644 --- a/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py +++ b/codes/models/pixel_level_contrastive_learning/pixpro_lucidrains.py @@ -279,9 +279,11 @@ class PixelCL(nn.Module): super().__init__() DEFAULT_AUG = nn.Sequential( - RandomApply(augs.ColorJitter(0.3, 0.3, 0.3, 0.2), p=0.8), + RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8), augs.RandomGrayscale(p=0.2), - RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1) + RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), + augs.RandomSolarize(p=0.5), + # Normalize left out because it should be done at the model level. ) self.augment1 = default(augment_fn, DEFAULT_AUG) diff --git a/codes/models/pixel_level_contrastive_learning/resnet_unet.py b/codes/models/pixel_level_contrastive_learning/resnet_unet.py index e9fdbfaa..38d0227c 100644 --- a/codes/models/pixel_level_contrastive_learning/resnet_unet.py +++ b/codes/models/pixel_level_contrastive_learning/resnet_unet.py @@ -81,7 +81,7 @@ class UResNet50(torchvision.models.resnet.ResNet): def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None): + norm_layer=None, out_dim=128): super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer) if norm_layer is None: @@ -110,7 +110,7 @@ class UResNet50(torchvision.models.resnet.ResNet): conv3x3(512, 512), norm_layer(512), nn.ReLU(), - conv1x1(512, 128)) + conv1x1(512, out_dim)) del self.fc # Not used in this implementation and just consumes a ton of GPU memory. @@ -142,7 +142,7 @@ class UResNet50(torchvision.models.resnet.ResNet): @register_model def register_u_resnet50(opt_net, opt): - model = UResNet50(Bottleneck, [3, 4, 6, 3]) + model = UResNet50(Bottleneck, [3, 4, 6, 3], out_dim=opt_net['odim']) return model