Add support for distilling gpt_asr
parent 5d714bc566
commit 58494b0888
@@ -69,7 +69,7 @@ class GptAsr(nn.Module):
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
 
-    def forward(self, mel_inputs, text_targets):
+    def get_logits(self, mel_inputs, text_targets):
         # Pad front and back. Pad at front is the "START" token.
         text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
         text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
@@ -80,16 +80,17 @@ class GptAsr(nn.Module):
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
 
         enc = self.gpt(emb)
-
-        # Compute loss
         text_logits = self.final_norm(enc[:, self.max_mel_frames:])
         text_logits = self.text_head(text_logits)
         text_logits = text_logits.permute(0,2,1)
-        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
+        return text_logits
 
-        return loss_text.mean()
+    def forward(self, mel_inputs, text_targets):
+        text_logits = self.get_logits(mel_inputs, text_targets)
+        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
+        return loss_text.mean(), text_logits
 
     def inference_beam_topk(self, mel):
         def topk_sampler(distribution, k):
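Splitting the old forward() into get_logits() plus a thin loss wrapper exposes the raw text logits outside the model, which is what a distillation run needs. The commit itself does not ship a distillation objective, so the snippet below is only a minimal sketch of how the new hook could feed one: the distill_loss name, the temperature T, and the KL formulation are assumptions, not part of this change.

import torch
import torch.nn.functional as F

def distill_loss(teacher, student, mel, text, T=2.0):
    # The teacher supplies soft targets; no gradients are needed on its side.
    with torch.no_grad():
        t_logits = teacher.get_logits(mel, text)
    s_logits = student.get_logits(mel, text)
    # get_logits() returns (batch, num_tokens, sequence), so the class axis is dim=1.
    # Temperature-softened KL divergence between teacher and student distributions.
    return F.kl_div(F.log_softmax(s_logits / T, dim=1),
                    F.softmax(t_logits / T, dim=1),
                    reduction='batchmean') * (T * T)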
@@ -155,11 +156,26 @@ def register_gpt_asr(opt_net, opt):
     return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
 
 
+# Halves the number of layers in the provided model.
+def distill(model):
+    rc = 0
+    i = 0
+    while i < len(model.gpt.layers.layers):
+        if rc % 2 != 0:
+            del model.gpt.layers.layers[i]
+        else:
+            i += 1
+        rc += 1
+    return model
+
+
 if __name__ == '__main__':
-    gpt = GptAsr()
-    l = gpt(torch.randn(2,80,800),
-            torch.randint(high=len(symbols), size=(2,180)))
-    print(l.shape)
+    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
+    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
+    student = distill(gpt)
+    torch.save(student.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
+    #l = gpt(torch.randn(2,80,800),
+    #        torch.randint(high=len(symbols), size=(2,180)))
 
     #o = gpt.infer(torch.randint(high=24, size=(2,60)))
     #print(o.shape)
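distill() halves the depth of the trained model by walking the GPT block list and deleting every other entry, so the surviving layers keep their teacher weights and the 12-layer model loaded above becomes a 6-layer student. The standalone sketch below replays the same index bookkeeping on a plain nn.ModuleList; it assumes model.gpt.layers.layers is a ModuleList, which is not shown in this diff.

import torch.nn as nn

# Twelve stand-in blocks; distill() would see these as model.gpt.layers.layers.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])

rc, i = 0, 0
while i < len(layers):
    if rc % 2 != 0:
        del layers[i]   # drop every odd-numbered block; later indices shift left
    else:
        i += 1          # keep even-numbered blocks in place
    rc += 1

assert len(layers) == 6  # half the original depth remains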
@@ -191,7 +191,7 @@ class Transformer(nn.Module):
         route_attn = ((True, False),) * depth
         attn_route_map = {'mask': route_attn}
 
-        self.layers = execute_type(layers, args_route = attn_route_map)
+        self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
 
     def forward(self, x):
         return self.layers(x)
@@ -4,6 +4,9 @@ from torch.autograd.function import Function
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 # for routing arguments into the functions of the reversible layer
+from utils.util import checkpoint
+
+
 def route_args(router, args, depth):
     routed_args = [(dict(), dict()) for _ in range(depth)]
     matched_keys = [key for key in args.keys() if key in router]
@@ -123,20 +126,25 @@ class _ReversibleFunction(Function):
         return dy, None, None
 
 class SequentialSequence(nn.Module):
-    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
+    def __init__(self, layers, args_route = {}, layer_dropout = 0., checkpoint=False):
         super().__init__()
         assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
         self.layers = layers
         self.args_route = args_route
         self.layer_dropout = layer_dropout
+        self.checkpoint = checkpoint
 
     def forward(self, x, **kwargs):
         args = route_args(self.args_route, kwargs, len(self.layers))
         layers_and_args = list(zip(self.layers, args))
 
         for (f, g), (f_args, g_args) in layers_and_args:
-            x = x + f(x, **f_args)
-            x = x + g(x, **g_args)
+            if self.checkpoint:
+                x = x + checkpoint(f, x, **f_args)
+                x = x + checkpoint(g, x, **g_args)
+            else:
+                x = x + f(x, **f_args)
+                x = x + g(x, **g_args)
         return x
 
 class ReversibleSequence(nn.Module):
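SequentialSequence now optionally routes each residual branch through the checkpoint helper imported from utils.util, recomputing activations during the backward pass instead of keeping them in memory; the Transformer change above turns this on for the GPT stack. The actual utils.util.checkpoint implementation is not part of this diff, so the following is only a stand-in sketch with the same call shape, built directly on torch.utils.checkpoint.

import functools
import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

def checkpoint(fn, *args, **kwargs):
    # Recompute fn during backward instead of storing its activations.
    # Keyword arguments (e.g. a routed attention mask) are bound up front so
    # only positional tensors are handed to torch's checkpoint machinery.
    wrapped = functools.partial(fn, **kwargs) if kwargs else fn
    if torch.is_grad_enabled():
        return torch_checkpoint(wrapped, *args)
    return wrapped(*args)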