From efa737b685370a376c3830c96485bf57b964450f Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 11 May 2022 21:14:18 -0600
Subject: [PATCH] re-add distributed collect to clvp

Add codes/models/classifiers/twin_cifar_resnet.py, a CIFAR-style ResNet
(32-channel stem; stages of 32/64/128/256 channels) whose ResNet-18
variant is exposed through @register_model.

Add a distributed_collect option to CLVP. When enabled, forward()
all-gathers the normalized text and speech latents from every rank and
concatenates them along the batch dimension before the similarity is
computed.

The gather goes through torch.distributed.all_gather; there is no
torch.all_gather. The tensors all_gather writes into are detached from
autograd, so the local rank's slot is replaced with the original latents
to keep gradients flowing to the encoders.
---
 codes/models/classifiers/twin_cifar_resnet.py | 166 ++++++++++++++++++
 codes/models/clip/clvp.py                     |  12 ++
 2 files changed, 178 insertions(+)
 create mode 100644 codes/models/classifiers/twin_cifar_resnet.py

diff --git a/codes/models/classifiers/twin_cifar_resnet.py b/codes/models/classifiers/twin_cifar_resnet.py
new file mode 100644
index 00000000..ceb78064
--- /dev/null
+++ b/codes/models/classifiers/twin_cifar_resnet.py
@@ -0,0 +1,166 @@
+"""resnet in pytorch
+
+
+
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
+
+    Deep Residual Learning for Image Recognition
+    https://arxiv.org/abs/1512.03385v1
+"""
+
+import torch
+import torch.nn as nn
+
+from trainer.networks import register_model
+
+
+class BasicBlock(nn.Module):
+    """Basic Block for resnet 18 and resnet 34
+
+    """
+
+    #BasicBlock and BottleNeck block
+    #have different output size
+    #we use class attribute expansion
+    #to distinct
+    expansion = 1
+
+    def __init__(self, in_channels, out_channels, stride=1):
+        super().__init__()
+
+        #residual function
+        self.residual_function = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+        )
+
+        #shortcut
+        self.shortcut = nn.Sequential()
+
+        #the shortcut output dimension is not the same with residual function
+        #use 1*1 convolution to match the dimension
+        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+            )
+
+    def forward(self, x):
+        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+class BottleNeck(nn.Module):
+    """Residual block for resnet over 50 layers
+
+    """
+    expansion = 4
+    def __init__(self, in_channels, out_channels, stride=1):
+        super().__init__()
+        self.residual_function = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
+        )
+
+        self.shortcut = nn.Sequential()
+
+        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
+                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
+            )
+
+    def forward(self, x):
+        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, num_block, num_classes=100):
+        super().__init__()
+
+        self.in_channels = 32
+
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True))
+        #we use a different inputsize than the original paper
+        #so conv2_x's stride is 1
+        self.conv2_x = self._make_layer(block, 32, num_block[0], 1)
+        self.conv3_x = self._make_layer(block, 64, num_block[1], 2)
+        self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
+        self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(256 * block.expansion, num_classes)
+
+    def _make_layer(self, block, out_channels, num_blocks, stride):
+        """make resnet layers(by layer i didnt mean this 'layer' was the
+        same as a neuron netowork layer, ex. conv layer), one layer may
+        contain more than one residual block
+
+        Args:
+            block: block type, basic block or bottle neck block
+            out_channels: output depth channel number of this layer
+            num_blocks: how many blocks per layer
+            stride: the stride of the first block of this layer
+
+        Return:
+            return a resnet layer
+        """
+
+        # we have num_block blocks per layer, the first block
+        # could be 1 or 2, other blocks would always be 1
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_channels, out_channels, stride))
+            self.in_channels = out_channels * block.expansion
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        output = self.conv1(x)
+        output = self.conv2_x(output)
+        output = self.conv3_x(output)
+        output = self.conv4_x(output)
+        output = self.conv5_x(output)
+        output = self.avgpool(output)
+        output = output.view(output.size(0), -1)
+        output = self.fc(output)
+
+        return output
+
+@register_model
+def register_cifar_resnet18(opt_net, opt):
+    """ return a ResNet 18 object
+    """
+    return ResNet(BasicBlock, [2, 2, 2, 2])
+
+def resnet34():
+    """ return a ResNet 34 object
+    """
+    return ResNet(BasicBlock, [3, 4, 6, 3])
+
+def resnet50():
+    """ return a ResNet 50 object
+    """
+    return ResNet(BottleNeck, [3, 4, 6, 3])
+
+def resnet101():
+    """ return a ResNet 101 object
+    """
+    return ResNet(BottleNeck, [3, 4, 23, 3])
+
+def resnet152():
+    """ return a ResNet 152 object
+    """
+    return ResNet(BottleNeck, [3, 8, 36, 3])
+
+
diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py
index 7626c7ca..7f4a3461 100644
--- a/codes/models/clip/clvp.py
+++ b/codes/models/clip/clvp.py
@@ -86,6 +86,7 @@ class CLVP(nn.Module):
             speech_enc_depth=6,
             speech_mask_percentage=0,
             latent_multiplier=4,
+            distributed_collect=False,
     ):
         super().__init__()
         latent_dim = latent_multiplier*model_dim
@@ -100,6 +101,7 @@ class CLVP(nn.Module):
         self.text_emb = nn.Embedding(num_text_tokens, model_dim)
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
         self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+        self.distributed_collect = distributed_collect
 
         if mel_codes is None:
             self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
@@ -144,6 +146,16 @@ class CLVP(nn.Module):
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
         temp = self.temperature.exp()
 
+        if self.distributed_collect:
+            collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(collective, text_latents)
+            collective[torch.distributed.get_rank()] = text_latents  # gathered copies are detached; keep the local grad path
+            text_latents = torch.cat(collective, dim=0)
+            collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(collective, speech_latents)
+            collective[torch.distributed.get_rank()] = speech_latents  # gathered copies are detached; keep the local grad path
+            speech_latents = torch.cat(collective, dim=0)
+
         if not return_loss:
             sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
             return sim