From 82aad335ba9d43df44d9a665b41112c0aaeb2288 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 15 Apr 2022 09:31:48 -0600
Subject: [PATCH] add distributed logic for loss

---
 codes/models/clip/clvp.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py
index 2518d7b1..11a16af8 100644
--- a/codes/models/clip/clvp.py
+++ b/codes/models/clip/clvp.py
@@ -3,7 +3,8 @@ from random import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import einsum
+from torch import einsum, distributed
+from torch.distributed import get_world_size
 
 from models.arch_util import AttentionBlock
 from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
@@ -85,6 +86,7 @@ class CLVP(nn.Module):
             speech_enc_depth=6,
             speech_mask_percentage=0,
             latent_multiplier=4,
+            is_distributed=False,
             ):
         super().__init__()
         latent_dim = latent_multiplier*model_dim
@@ -100,6 +102,8 @@
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
         self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
 
+        self.distributed = is_distributed
+
         if mel_codes is None:
             self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
         else:
@@ -139,9 +143,16 @@
         text_latents = self.to_text_latent(enc_text)
         speech_latents = self.to_speech_latent(enc_speech)
 
+        if self.distributed:
+            ws = get_world_size()
+            text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)]
+            speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)]
+            distributed.all_gather(text_gather_cells, text_latents)
+            text_latents = torch.cat(text_gather_cells, dim=0)
+            distributed.all_gather(speech_gather_cells, speech_latents)
+            speech_latents = torch.cat(speech_gather_cells, dim=0)
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
 
-        temp = self.temperature.exp()
 
         if not return_loss: