Add support for distilling gpt_asr

James Betker 2021-10-27 13:10:07 -06:00
parent 5d714bc566
commit 58494b0888
3 changed files with 37 additions and 13 deletions


@@ -69,7 +69,7 @@ class GptAsr(nn.Module):
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
 
-    def forward(self, mel_inputs, text_targets):
+    def get_logits(self, mel_inputs, text_targets):
         # Pad front and back. Pad at front is the "START" token.
         text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
         text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
@@ -80,16 +80,17 @@
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
         enc = self.gpt(emb)
 
-        # Compute loss
         text_logits = self.final_norm(enc[:, self.max_mel_frames:])
         text_logits = self.text_head(text_logits)
         text_logits = text_logits.permute(0,2,1)
-        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
-        return loss_text.mean()
+        return text_logits
+
+    def forward(self, mel_inputs, text_targets):
+        text_logits = self.get_logits(mel_inputs, text_targets)
+        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
+        return loss_text.mean(), text_logits
 
     def inference_beam_topk(self, mel):
         def topk_sampler(distribution, k):
@@ -155,11 +156,26 @@ def register_gpt_asr(opt_net, opt):
     return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
 
 
+# Halves the number of layers in the provided model.
+def distill(model):
+    rc = 0
+    i = 0
+    while i < len(model.gpt.layers.layers):
+        if rc % 2 != 0:
+            del model.gpt.layers.layers[i]
+        else:
+            i += 1
+        rc += 1
+    return model
+
+
 if __name__ == '__main__':
-    gpt = GptAsr()
-    l = gpt(torch.randn(2,80,800),
-            torch.randint(high=len(symbols), size=(2,180)))
-    print(l.shape)
+    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
+    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
+    student = distill(gpt)
+    torch.save(student.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
+    #l = gpt(torch.randn(2,80,800),
+    #        torch.randint(high=len(symbols), size=(2,180)))
     #o = gpt.infer(torch.randint(high=24, size=(2,60)))
     #print(o.shape)
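
Note: the refactor in this file splits logit computation out of forward(), which now returns both the text cross-entropy loss and the raw logits; together with get_logits() that is the interface a distillation loop would consume, although the commit itself only halves the layer count and re-saves the weights. Purely as an illustration, the sketch below shows how those outputs could feed a standard soft-label distillation objective. The distillation_step() function, the teacher/student/optimizer objects, the temperature T and the weight alpha are hypothetical and not part of this change.

import torch
import torch.nn.functional as F

# Hypothetical distillation step: not part of this commit. It only sketches how
# the new (loss, logits) return value of forward() and the get_logits() helper
# could be combined into a temperature-scaled soft-label objective.
def distillation_step(teacher, student, optimizer, mel, text, T=2.0, alpha=0.5):
    teacher.eval()
    with torch.no_grad():
        # Teacher logits come straight from the new get_logits() helper.
        t_logits = teacher.get_logits(mel, text)   # shape (batch, n_symbols, seq)

    # The student's forward() now returns (hard-label loss, logits).
    hard_loss, s_logits = student(mel, text)

    # Soften both distributions over the symbol dimension (dim=1) and match them.
    soft_loss = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                         F.softmax(t_logits / T, dim=1),
                         reduction='batchmean') * (T * T)

    loss = alpha * hard_loss + (1.0 - alpha) * soft_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

Whether the distilled checkpoint saved above is actually fine-tuned this way is not shown here; the sketch only demonstrates why exposing the logits alongside the loss is useful.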


@@ -191,7 +191,7 @@ class Transformer(nn.Module):
         route_attn = ((True, False),) * depth
         attn_route_map = {'mask': route_attn}
 
-        self.layers = execute_type(layers, args_route = attn_route_map)
+        self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
 
     def forward(self, x):
         return self.layers(x)


@@ -4,6 +4,9 @@ from torch.autograd.function import Function
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 # for routing arguments into the functions of the reversible layer
+from utils.util import checkpoint
+
+
 def route_args(router, args, depth):
     routed_args = [(dict(), dict()) for _ in range(depth)]
     matched_keys = [key for key in args.keys() if key in router]
@@ -123,20 +126,25 @@ class _ReversibleFunction(Function):
         return dy, None, None
 
 class SequentialSequence(nn.Module):
-    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
+    def __init__(self, layers, args_route = {}, layer_dropout = 0., checkpoint=False):
         super().__init__()
         assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
         self.layers = layers
         self.args_route = args_route
         self.layer_dropout = layer_dropout
+        self.checkpoint = checkpoint
 
     def forward(self, x, **kwargs):
         args = route_args(self.args_route, kwargs, len(self.layers))
         layers_and_args = list(zip(self.layers, args))
 
         for (f, g), (f_args, g_args) in layers_and_args:
-            x = x + f(x, **f_args)
-            x = x + g(x, **g_args)
+            if self.checkpoint:
+                x = x + checkpoint(f, x, **f_args)
+                x = x + checkpoint(g, x, **g_args)
+            else:
+                x = x + f(x, **f_args)
+                x = x + g(x, **g_args)
         return x
 
 class ReversibleSequence(nn.Module):
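
Note: the checkpoint=True path above relies on the checkpoint helper imported from utils.util, whose implementation is not included in this diff. Purely for orientation, the sketch below shows one plausible shape for such a wrapper; the grad-mode guard and the keyword-argument handling are assumptions rather than the project's actual code, which may instead be toggled by experiment options.

import torch
import torch.utils.checkpoint as cp
from functools import partial

# Sketch only: the real helper lives in utils.util and is not shown in this
# commit. This version illustrates one way such a wrapper could (a) skip
# checkpointing when autograd is disabled, e.g. during inference, and
# (b) accept keyword arguments, which the classic re-entrant torch API does not
# forward to the wrapped function.
def checkpoint(fn, *args, **kwargs):
    if not torch.is_grad_enabled():
        # No graph is being built, so there is nothing to recompute later.
        return fn(*args, **kwargs)
    if kwargs:
        # Bind keyword arguments up front so only positional tensors reach torch.
        fn = partial(fn, **kwargs)
    return cp.checkpoint(fn, *args)

Binding keyword arguments with functools.partial is one way to reconcile the call sites above, such as checkpoint(f, x, **f_args), with torch's positional-only re-entrant checkpoint API.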