forked from mrq/DL-Art-School
add distributed logic for loss
This commit is contained in:
parent efe12cb816
commit 82aad335ba

@@ -3,7 +3,8 @@ from random import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import einsum
+from torch import einsum, distributed
+from torch.distributed import get_world_size
 
 from models.arch_util import AttentionBlock
 from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
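
Note: get_world_size() and distributed.all_gather() only work once a process group exists. A minimal sketch of the setup these imports assume (the launcher details are hypothetical, e.g. a torchrun-style environment, and are not part of this commit):

    import os
    import torch
    import torch.distributed as distributed

    # Hypothetical launcher-side setup (e.g. `torchrun --nproc_per_node=2 train.py`);
    # the gather path added below assumes this has already happened.
    if not distributed.is_initialized():
        distributed.init_process_group(
            backend="nccl" if torch.cuda.is_available() else "gloo",
            rank=int(os.environ.get("RANK", 0)),
            world_size=int(os.environ.get("WORLD_SIZE", 1)),
        )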

@@ -85,6 +86,7 @@ class CLVP(nn.Module):
             speech_enc_depth=6,
             speech_mask_percentage=0,
             latent_multiplier=4,
+            is_distributed=False,
     ):
         super().__init__()
         latent_dim = latent_multiplier*model_dim
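
The new flag defaults to False, so single-process training is unchanged and a caller has to opt in explicitly. A hedged construction sketch (the trainer code and the model_cfg dict are assumptions, not shown in this diff):

    import torch.distributed as distributed

    # Hypothetical trainer-side construction: model_cfg stands in for whatever
    # CLVP kwargs the existing config already supplies; only the flag is new here.
    model_cfg = dict(latent_multiplier=4)  # illustrative
    clvp = CLVP(
        **model_cfg,
        is_distributed=distributed.is_available() and distributed.is_initialized(),
    )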

@@ -100,6 +102,8 @@ class CLVP(nn.Module):
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
         self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
 
+        self.distributed = is_distributed
+
         if mel_codes is None:
             self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
         else:

@@ -139,9 +143,16 @@ class CLVP(nn.Module):
 
         text_latents = self.to_text_latent(enc_text)
         speech_latents = self.to_speech_latent(enc_speech)
+        if self.distributed:
+            ws = get_world_size()
+            text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)]
+            speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)]
+            distributed.all_gather(text_gather_cells, text_latents)
+            text_latents = torch.cat(text_gather_cells, dim=0)
+            distributed.all_gather(speech_gather_cells, speech_latents)
+            speech_latents = torch.cat(speech_gather_cells, dim=0)
 
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
 
         temp = self.temperature.exp()
 
         if not return_loss:
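
The gather turns each rank's local batch of latents into one global batch before the contrastive loss is computed, so every rank sees negatives from all ranks rather than only its own samples. A self-contained sketch of the same pattern (the function names and the loss below are illustrative, not copied from this repository); note that plain torch.distributed.all_gather returns tensors without autograd history (an autograd-aware variant exists in torch.distributed.nn):

    import torch
    import torch.nn.functional as F
    from torch import distributed
    from torch.distributed import get_world_size

    def gather_latents(latents: torch.Tensor) -> torch.Tensor:
        # Each rank contributes a (B, D) tensor and gets back the
        # (world_size * B, D) global batch, assuming an initialized process group.
        ws = get_world_size()
        cells = [torch.zeros_like(latents) for _ in range(ws)]
        distributed.all_gather(cells, latents)
        # all_gather fills `cells` with plain copies (no autograd history).
        return torch.cat(cells, dim=0)

    def clip_style_loss(text_latents, speech_latents, temperature):
        # Illustrative symmetric contrastive loss over the enlarged batch.
        text_latents = F.normalize(text_latents, p=2, dim=-1)
        speech_latents = F.normalize(speech_latents, p=2, dim=-1)
        sim = text_latents @ speech_latents.t() * temperature.exp()
        labels = torch.arange(sim.shape[0], device=sim.device)
        return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2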