add distributed logic for loss

James Betker 2022-04-15 09:31:48 -06:00
parent efe12cb816
commit 82aad335ba


@@ -3,7 +3,8 @@ from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from torch import einsum, distributed
from torch.distributed import get_world_size
from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
@@ -85,6 +86,7 @@ class CLVP(nn.Module):
speech_enc_depth=6,
speech_mask_percentage=0,
latent_multiplier=4,
is_distributed=False,
):
super().__init__()
latent_dim = latent_multiplier*model_dim
@@ -100,6 +102,8 @@ class CLVP(nn.Module):
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
self.distributed = is_distributed
if mel_codes is None:
    self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
else:
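The new flag only changes the forward pass: when it is set, each rank's latents are gathered so the loss is computed over the global batch. It assumes the default process group has already been initialized by the launcher. A minimal usage sketch; the launch details and the constructor call are illustrative assumptions, not part of this commit:

import torch
import torch.distributed as dist

# Hypothetical per-process setup, e.g. one process per GPU started via torchrun.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Pass the flag through so forward() takes the all_gather path below.
model = CLVP(is_distributed=dist.is_initialized()).cuda()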
@@ -139,9 +143,16 @@ class CLVP(nn.Module):
text_latents = self.to_text_latent(enc_text)
speech_latents = self.to_speech_latent(enc_speech)
if self.distributed:
    # Gather latents from every rank so the contrastive loss sees the full global batch.
    ws = get_world_size()
    text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)]
    speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)]
    distributed.all_gather(text_gather_cells, text_latents)
    text_latents = torch.cat(text_gather_cells, dim=0)
    distributed.all_gather(speech_gather_cells, speech_latents)
    speech_latents = torch.cat(speech_gather_cells, dim=0)
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
temp = self.temperature.exp()
if not return_loss:
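The visible diff ends just before the loss itself. For context, a CLIP-style contrastive loss over the gathered, normalized latents and the exponentiated temperature typically looks like the sketch below; this is a generic reconstruction under those assumptions, not the file's actual continuation.

import torch
import torch.nn.functional as F

def contrastive_loss(text_latents, speech_latents, temp):
    # Inputs are assumed L2-normalized and already gathered across ranks,
    # so the similarity matrix spans the full global batch.
    sim = (text_latents @ speech_latents.t()) * temp
    labels = torch.arange(sim.shape[0], device=sim.device)
    # Symmetric cross-entropy: text-to-speech and speech-to-text directions.
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2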