w2v_matcher mods
parent 9029e4f20c
commit 998c53ad4f
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
 
 from models.gpt_voice.mini_encoder import AudioMiniEncoder
+from trainer.networks import register_model
 
 
 class CheckpointedLayer(nn.Module):
@@ -42,6 +43,8 @@ class CheckpointedXTransformer(nn.Module):
 
 
 class Wav2VecMatcher(nn.Module):
+    W2V_COMPRESSION=320
+
     def __init__(self,
                  model_dim,
                  encoder_depth,
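Note on the new W2V_COMPRESSION constant: 320 is the total stride of the wav2vec 2.0 feature encoder at 16 kHz (conv strides 5·2·2·2·2·2·2), i.e. one w2v frame per 20 ms of audio. A quick sanity-check sketch, assuming clip_lengths is measured in raw samples as the later // division implies:

    import torch

    # Hypothetical clip lengths in raw 16 kHz samples.
    clip_lengths = torch.tensor([48000, 32000, 16000])   # 3.0 s, 2.0 s, 1.0 s

    # Mirrors the forward() computation: number of w2v frames per clip.
    w2v_lengths = clip_lengths // 320                    # -> tensor([150, 100, 50]), one frame per 20 ms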
@@ -88,7 +91,15 @@ class Wav2VecMatcher(nn.Module):
             )
         )
 
-    def forward(self, text_tokens, conditioning_clip, w2v_logits, token_lengths, w2v_lengths):
+    def get_grad_norm_parameter_groups(self):
+        return {
+            'encoder': list(self.encoder.parameters()),
+            'decoder': list(self.decoder.parameters()),
+            'heads': list(self.w2v_query_encoder.parameters()) + list(self.w2v_value_encoder.parameters()),
+            'minicoder': list(self.conditioning_encoder.parameters()),
+        }
+
+    def forward(self, text_tokens, conditioning_clip, w2v_logits, token_lengths, clip_lengths):
         # Clip off text_lengths where possible to save compute.
         max_text_len = token_lengths.max()
         text_tokens = text_tokens[:, :max_text_len]
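The new get_grad_norm_parameter_groups() presumably feeds per-group gradient-norm logging in the trainer. A rough illustration of how such a dict can be consumed; the log_grad_norms helper below is hypothetical, not the trainer's actual code:

    import torch

    def log_grad_norms(model):
        # After loss.backward(): one L2 gradient norm per named parameter group.
        norms = {}
        for name, params in model.get_grad_norm_parameter_groups().items():
            grads = [p.grad.flatten() for p in params if p.grad is not None]
            norms[name] = torch.cat(grads).norm().item() if grads else 0.0
        return norms  # e.g. {'encoder': ..., 'decoder': ..., 'heads': ..., 'minicoder': ...}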
@@ -102,31 +113,34 @@ class Wav2VecMatcher(nn.Module):
         dec_out = self.decoder(dec_inputs, context=dec_context)[:, :-1]
         w2v_queries = self.w2v_query_encoder(w2v_logits)
 
-        # Compute loss
+        # Compute losses: a CLIP-like dot-product matcher, plus a mechanism to force pad prediction.
         b,l,c = dec_out.shape
         keys_uncompressed = dec_out.reshape(b*l, c)
         queries_uncompressed = w2v_queries.reshape(b*l, c)
         dot = torch.einsum("i c, j c -> i j", keys_uncompressed, queries_uncompressed)
         labels = torch.arange(0, b*l, 1, device=dot.device)
-        # TODO: weight the cross entropy: logits from the same clip should be weighted as possible "matches" (say, share ~10% of the probability mass). Logits near
-        # the w2v logits should also get a bump in probability mass. Cross entropy is probably not the right avenue for this. This is important to enable
-        # "searching" for w2v matches from a large pool.
         ce_loss1 = F.cross_entropy(dot, labels, reduction="none")
         ce_loss2 = F.cross_entropy(dot.t(), labels, reduction="none")
         mse_pad_loss = F.mse_loss(keys_uncompressed, self.decoder_stop_embedding.repeat(b*l,1), reduction="none").sum(dim=-1)
 
         # Create a mask based on w2v_lengths that will be used to ensure the encodings of padding tokens are not considered in the cross entropy loss
         loss_mask = torch.ones((b,l), device=ce_loss1.device)
+        w2v_lengths = clip_lengths // self.W2V_COMPRESSION
         for i in range(b):
             loss_mask[i, w2v_lengths[i]:] = 0
-        loss_mask = loss_mask.reshape(b*l)
+        loss_mask_collapsed = loss_mask.reshape(b*l)
 
-        ce_loss = (ce_loss1 * loss_mask + ce_loss2 * loss_mask).mean()
-        mse_loss = (mse_pad_loss * (loss_mask == 0)).mean()
+        ce_loss = (ce_loss1 * loss_mask_collapsed + ce_loss2 * loss_mask_collapsed).mean()
+        mse_loss = (mse_pad_loss * (loss_mask_collapsed == 0)).mean()
 
         return ce_loss, mse_loss
 
 
+@register_model
+def register_w2v_matcher(opt_net, opt):
+    return Wav2VecMatcher(**opt_net['kwargs'])
+
+
 if __name__ == '__main__':
     model = Wav2VecMatcher(512, 8, 8)
     toks = torch.randint(0, 100, (4,100))
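The loss in the hunk above is a standard CLIP-style symmetric cross entropy: every decoder output position is scored against every w2v query position in the batch, and the arange labels mark the diagonal as the positives. A toy, self-contained version with made-up shapes:

    import torch
    import torch.nn.functional as F

    b, l, c = 2, 3, 8                        # toy sizes
    keys = torch.randn(b * l, c)             # stands in for dec_out.reshape(b*l, c)
    queries = torch.randn(b * l, c)          # stands in for w2v_queries.reshape(b*l, c)

    dot = keys @ queries.t()                 # same as einsum("i c, j c -> i j", ...): (b*l, b*l) similarities
    labels = torch.arange(b * l)             # position i's own query is its positive

    ce_loss1 = F.cross_entropy(dot, labels, reduction="none")      # keys -> queries direction
    ce_loss2 = F.cross_entropy(dot.t(), labels, reduction="none")  # queries -> keys direction
    print((ce_loss1 + ce_loss2).mean())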
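The per-sample loop that zeroes the mask past each clip's w2v length could equivalently be vectorized; a small sketch showing both forms agree (shapes are illustrative):

    import torch

    b, l = 4, 10
    w2v_lengths = torch.tensor([10, 7, 3, 9])

    # Loop form, as in the diff:
    loss_mask = torch.ones((b, l))
    for i in range(b):
        loss_mask[i, w2v_lengths[i]:] = 0

    # Vectorized equivalent: keep positions whose index is below the per-clip length.
    vec_mask = (torch.arange(l).unsqueeze(0) < w2v_lengths.unsqueeze(1)).float()
    assert torch.equal(loss_mask, vec_mask)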
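mse_pad_loss pushes decoder outputs at padded positions toward the learned decoder_stop_embedding, i.e. frames past the clip length should predict "stop" rather than match a w2v query. An isolated sketch of that term; the stop embedding's shape here is an assumption inferred from the .repeat(b*l, 1) call:

    import torch
    import torch.nn.functional as F

    b, l, c = 2, 5, 8
    keys_uncompressed = torch.randn(b * l, c)
    decoder_stop_embedding = torch.nn.Parameter(torch.randn(c))   # assumed shape; repeat() expands it to (b*l, c)
    loss_mask_collapsed = torch.randint(0, 2, (b * l,)).float()   # stand-in for the real length mask

    # Per-position squared distance to the stop embedding...
    mse_pad_loss = F.mse_loss(keys_uncompressed,
                              decoder_stop_embedding.repeat(b * l, 1),
                              reduction="none").sum(dim=-1)
    # ...counted only where the mask is zero, i.e. at padded positions.
    mse_loss = (mse_pad_loss * (loss_mask_collapsed == 0)).mean()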
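The @register_model hook plus register_w2v_matcher wires the class into the repo's model registry and builds it from opt_net['kwargs']. A hypothetical options blob that would reach it; every key except 'kwargs', and the name of the third constructor argument, are illustrative guesses not confirmed by this diff:

    # Hypothetical network options; register_w2v_matcher only actually reads 'kwargs'.
    opt_net = {
        'type': 'w2v_matcher',        # illustrative selector key
        'kwargs': {
            'model_dim': 512,
            'encoder_depth': 8,
            'decoder_depth': 8,       # assumed name for the third positional arg used in __main__
        },
    }
    model = register_w2v_matcher(opt_net, opt={})   # same as Wav2VecMatcher(**opt_net['kwargs'])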