Add support for distilling gpt_asr
parent 5d714bc566
commit 58494b0888
@@ -69,7 +69,7 @@ class GptAsr(nn.Module):
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
 
-    def forward(self, mel_inputs, text_targets):
+    def get_logits(self, mel_inputs, text_targets):
         # Pad front and back. Pad at front is the "START" token.
         text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
         text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
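The two F.pad calls above first prepend a START token (reusing self.NUMBER_SYMBOLS as its id) and then right-pad the sequence out to max_symbols_per_phrase. A minimal sketch of that padding in isolation, with hypothetical values standing in for the real symbol table and phrase cap:

import torch
import torch.nn.functional as F

NUMBER_SYMBOLS = 148          # hypothetical stand-in for the real symbol count
max_symbols_per_phrase = 10   # hypothetical cap, just for this illustration

text_targets = torch.tensor([[5, 9, 2]])                          # (batch=1, seq=3)
text_targets = F.pad(text_targets, (1, 0), value=NUMBER_SYMBOLS)  # prepend START id -> [[148, 5, 9, 2]]
text_targets = F.pad(text_targets, (0, max_symbols_per_phrase - text_targets.shape[1]))  # right-pad with zeros
print(text_targets)  # tensor([[148,   5,   9,   2,   0,   0,   0,   0,   0,   0]])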
@@ -80,16 +80,17 @@ class GptAsr(nn.Module):
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
 
         enc = self.gpt(emb)
 
-        # Compute loss
         text_logits = self.final_norm(enc[:, self.max_mel_frames:])
         text_logits = self.text_head(text_logits)
         text_logits = text_logits.permute(0,2,1)
-        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
+        return text_logits
 
-        return loss_text.mean()
+    def forward(self, mel_inputs, text_targets):
+        text_logits = self.get_logits(mel_inputs, text_targets)
+        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
+        return loss_text.mean(), text_logits
 
     def inference_beam_topk(self, mel):
         def topk_sampler(distribution, k):
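The old forward is split into get_logits, which stops at the raw per-token logits, and a thin forward wrapper that computes the same cross-entropy but now also returns those logits alongside the loss, so downstream code (e.g. a distillation setup) can reach them. A minimal sketch of driving the new interface, following the shapes used in the __main__ block further down:

import torch

# Assumes GptAsr is in scope, as exercised in __main__ below; 148 is a
# hypothetical stand-in for len(symbols).
model = GptAsr()
mel = torch.randn(2, 80, 800)                    # (batch, n_mel_channels, frames)
text = torch.randint(high=148, size=(2, 180))    # (batch, text length)

loss, logits = model(mel, text)   # forward now returns (mean loss, logits)
loss.backward()
print(logits.shape)               # (batch, NUMBER_TEXT_TOKENS, padded text length)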
@@ -155,11 +156,26 @@ def register_gpt_asr(opt_net, opt):
     return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
 
 
+# Halves the number of layers in the provided model.
+def distill(model):
+    rc = 0
+    i = 0
+    while i < len(model.gpt.layers.layers):
+        if rc % 2 != 0:
+            del model.gpt.layers.layers[i]
+        else:
+            i += 1
+        rc += 1
+    return model
+
+
 if __name__ == '__main__':
     gpt = GptAsr()
     l = gpt(torch.randn(2,80,800),
             torch.randint(high=len(symbols), size=(2,180)))
     print(l.shape)
+    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
+    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
+    student = distill(gpt)
+    torch.save(student.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
     #l = gpt(torch.randn(2,80,800),
     #        torch.randint(high=len(symbols), size=(2,180)))
 
     #o = gpt.infer(torch.randint(high=24, size=(2,60)))
     #print(o.shape)
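The distill helper keeps the layers at even positions and deletes the odd ones, so the 12-layer checkpoint loaded in __main__ comes back as a 6-layer student that is then re-saved. A small self-contained illustration of the same index walk on a plain nn.ModuleList (the Linear layers are only placeholders for the demo):

import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])

# Same walk as distill(): the index only advances when a layer is kept, so a
# deletion does not skip the element that slides into the current slot.
rc, i = 0, 0
while i < len(layers):
    if rc % 2 != 0:
        del layers[i]      # drop every other layer (the odd-numbered ones)
    else:
        i += 1
    rc += 1

print(len(layers))  # 6 -- layers 0, 2, 4, 6, 8, 10 of the original stack survive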
@@ -191,7 +191,7 @@ class Transformer(nn.Module):
         route_attn = ((True, False),) * depth
         attn_route_map = {'mask': route_attn}
 
-        self.layers = execute_type(layers, args_route = attn_route_map)
+        self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
 
     def forward(self, x):
         return self.layers(x)
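Passing checkpoint=True asks the sequential executor (see the SequentialSequence change below) to use activation checkpointing: intermediate activations inside each block are recomputed during the backward pass instead of being stored, trading extra compute for a smaller memory footprint. A minimal, repo-independent sketch of that trade-off using the stock torch.utils.checkpoint API:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

block = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
x = torch.randn(4, 512, requires_grad=True)

# Activations inside `block` are not kept; they are recomputed when backward() runs.
y = torch_checkpoint(block, x)
y.sum().backward()
print(x.grad.shape)   # torch.Size([4, 512])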
@@ -4,6 +4,9 @@ from torch.autograd.function import Function
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 # for routing arguments into the functions of the reversible layer
+from utils.util import checkpoint
+
+
 def route_args(router, args, depth):
     routed_args = [(dict(), dict()) for _ in range(depth)]
     matched_keys = [key for key in args.keys() if key in router]
@@ -123,20 +126,25 @@ class _ReversibleFunction(Function):
         return dy, None, None
 
 class SequentialSequence(nn.Module):
-    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
+    def __init__(self, layers, args_route = {}, layer_dropout = 0., checkpoint=False):
         super().__init__()
         assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
         self.layers = layers
         self.args_route = args_route
         self.layer_dropout = layer_dropout
+        self.checkpoint = checkpoint
 
     def forward(self, x, **kwargs):
         args = route_args(self.args_route, kwargs, len(self.layers))
         layers_and_args = list(zip(self.layers, args))
 
         for (f, g), (f_args, g_args) in layers_and_args:
-            x = x + f(x, **f_args)
-            x = x + g(x, **g_args)
+            if self.checkpoint:
+                x = x + checkpoint(f, x, **f_args)
+                x = x + checkpoint(g, x, **g_args)
+            else:
+                x = x + f(x, **f_args)
+                x = x + g(x, **g_args)
         return x
 
 class ReversibleSequence(nn.Module):
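With the flag set, SequentialSequence routes each residual branch through the checkpoint helper imported from utils.util; that helper's implementation is not part of this diff. A hedged sketch of a wrapper with the same checkpoint(f, x, **f_args) calling convention, assuming it simply defers to torch.utils.checkpoint (an assumption, not the repo's actual code):

import torch
import torch.utils.checkpoint as cp

def checkpoint(fn, *args, **kwargs):
    # Sketch only: keyword arguments (e.g. an attention mask) are closed over,
    # since the stock API only checkpoints positional arguments.
    if not torch.is_grad_enabled():
        return fn(*args, **kwargs)   # nothing to save at inference time
    return cp.checkpoint(lambda *a: fn(*a, **kwargs), *args)

# Usage matching the call sites above:  x = x + checkpoint(f, x, **f_args)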