forked from mrq/DL-Art-School
new rev of ctc_code_gen with surrogate LM loss
This commit is contained in:
parent d1d1ae32a1
commit 618a20412a
@@ -12,6 +12,21 @@ from trainer.networks import register_model
 from utils.util import opt_get
 
 
+def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3):
+    """
+    Produces a masking vector of the specified shape where each element has <probability> chance of being zero.
+    lateral_expansion_radius_max neighbors of any element that is zero also have a 50% chance to be zero.
+    Effectively, this produces clusters of masks tending to be lateral_expansion_radius_max wide.
+
+    Note: This means the algorithm has a far higher output probability for zeros than <probability>.
+    """
+    mask = torch.rand(shape, device=dev)
+    mask = (mask < probability).float()
+    kernel = torch.tensor([.5 for _ in range(lateral_expansion_radius_max)] + [1] + [.5 for _ in range(lateral_expansion_radius_max)], device=dev)
+    mask = F.conv1d(mask.unsqueeze(1), kernel.view(1,1,2*lateral_expansion_radius_max+1), padding=lateral_expansion_radius_max).squeeze(1)
+    return torch.bernoulli(torch.clamp(mask, 0, 1)) == 0  # ==0 logically inverts the mask, so True means "keep".
+
+
 class CheckpointedTransformerWrapper(nn.Module):
     """
     Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
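As a rough illustration of the masking behavior added above (a minimal sketch, not part of the diff; it assumes clustered_mask is defined as in this file and that code 0 acts as the mask/blank value when the boolean mask is multiplied into a code tensor, which is how forward() uses it below):

    import torch

    mask = clustered_mask(.1, (4, 100), 'cpu')   # BoolTensor: True = keep, False = masked out
    print(mask.float().mean())                   # kept fraction is noticeably below 0.9, since masked positions cluster

    codes = torch.randint(1, 36, (4, 100))
    masked_codes = mask * codes                  # masked positions collapse to code 0, mimicking the forward() pass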
@@ -30,12 +45,14 @@ class CheckpointedTransformerWrapper(nn.Module):
 
 
 class CtcCodeGenerator(nn.Module):
-    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30):
+    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_prob=.1):
         super().__init__()
         self.max_pad = max_pad
         self.max_repeat = max_repeat
+        self.mask_probability = mask_prob
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
         self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
+        self.combiner = nn.Linear(model_dim*2, model_dim)
         self.transformer = TransformerWrapper(
             num_tokens=max_pad*max_repeat+1,
             max_seq_len=-1,  # Unneeded for rotary embeddings.
@@ -51,6 +68,9 @@ class CtcCodeGenerator(nn.Module):
             )
         )
         self.transformer.token_emb = nn.Identity()  # This class handles the initial embeddings.
+        self.transformer.to_logits = nn.Identity()
+        self.ctc_head = nn.Linear(model_dim, max_pad*max_repeat+1)
+        self.inp_head = nn.Linear(model_dim, ctc_codes)
 
     def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths):
         max_len = unpadded_lengths.max()
@@ -58,6 +78,7 @@ class CtcCodeGenerator(nn.Module):
         loss_mask = torch.ones_like(codes)
         for i, l in enumerate(unpadded_lengths):
             loss_mask[i, l:] = 0
+        codes = clustered_mask(self.mask_probability, codes.shape, codes.device) * codes
 
         if separators.max() > self.max_pad:
             print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}")
@@ -71,22 +92,19 @@ class CtcCodeGenerator(nn.Module):
         labels = separators + repeats * self.max_pad
 
         # Perform conditioning encoder in FP32, with the transformer in FP16
-        conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
-        conds = []
-        for j in range(conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
+        cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
+        h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
+        h = self.combiner(h)
 
         with torch.autocast(codes.device.type):
-            h = self.initial_embedding(codes)
-            h = torch.cat([conds, h], dim=1)
             logits = self.transformer(h)
-            # Ignore the cond outputs
-            logits = logits[:, conds.shape[1]:, :]
+            ctc_pred = self.ctc_head(logits)
+            code_pred = self.inp_head(logits)
 
-        loss = F.cross_entropy(logits.float().permute(0,2,1), labels, reduction='none')
-        loss = torch.mean(loss * loss_mask)
-        return loss
+        ctcloss = F.cross_entropy(ctc_pred.float().permute(0,2,1), labels, reduction='none')
+        ctcloss = torch.mean(ctcloss * loss_mask)
+        codeloss = F.cross_entropy(code_pred.float().permute(0,2,1), codes, reduction='none')
+        codeloss = torch.mean(codeloss * loss_mask)
+        return ctcloss, codeloss
 
     def generate(self, speech_conditioning_input, texts):
         codes = []
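For context on the label construction kept in this hunk: each (separator, repeat) pair is packed into a single class index via labels = separators + repeats * max_pad, which is why ctc_head has max_pad*max_repeat+1 outputs. A minimal sketch of the packing and its inverse follows; the decode step is an assumption based on that formula and is not shown in this diff.

    import torch

    max_pad, max_repeat = 121, 30
    separators = torch.randint(0, max_pad, (4, 300))
    repeats = torch.randint(1, max_repeat, (4, 300))

    labels = separators + repeats * max_pad      # one class index per position, always < max_pad*max_repeat+1
    # Assumed inverse: recover the pair from a predicted class index.
    assert torch.equal(labels % max_pad, separators)
    assert torch.equal(labels // max_pad, repeats)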
@@ -158,10 +176,13 @@ def inf():
 
 if __name__ == '__main__':
     #inf()
 
+    mask = clustered_mask(.1, (4,100), 'cpu')
+
     model = CtcCodeGenerator()
     inps = torch.randint(0,36, (4, 300))
     pads = torch.randint(0,100, (4,300))
     repeats = torch.randint(1,20, (4,300))
-    conds = torch.randn(4,3,80,600)
-    loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
-    print(loss.shape)
+    conds = torch.randn(4,80,600)
+    loss1, loss2 = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30]))
+    print(loss1.shape, loss2.shape)
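With forward() now returning (ctcloss, codeloss), a training step would combine the original alignment loss with the new surrogate LM loss over the masked input codes. The weighted sum below is only a sketch; the 0.5 weight and the bare backward() call are assumptions for illustration, not part of this commit or its training configuration.

    import torch

    model = CtcCodeGenerator()
    conds = torch.randn(4, 80, 600)
    inps = torch.randint(0, 36, (4, 300))
    pads = torch.randint(0, 100, (4, 300))
    repeats = torch.randint(1, 20, (4, 300))
    lengths = torch.tensor([250, 300, 280, 30])

    ctc_loss, code_loss = model(conds, inps, pads, repeats, lengths)
    total = ctc_loss + 0.5 * code_loss           # placeholder weighting between the two objectives
    total.backward()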