diff --git a/codes/models/gpt_voice/gpt_asr.py b/codes/models/gpt_voice/gpt_asr.py
index b0acd4f8..76a99c34 100644
--- a/codes/models/gpt_voice/gpt_asr.py
+++ b/codes/models/gpt_voice/gpt_asr.py
@@ -69,7 +69,7 @@ class GptAsr(nn.Module):
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
 
-    def forward(self, mel_inputs, text_targets):
+    def get_logits(self, mel_inputs, text_targets):
         # Pad front and back. Pad at front is the "START" token.
         text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
         text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
@@ -80,16 +80,17 @@ class GptAsr(nn.Module):
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
-        enc = self.gpt(emb)
-        # Compute loss
+        enc = self.gpt(emb)
         text_logits = self.final_norm(enc[:, self.max_mel_frames:])
         text_logits = self.text_head(text_logits)
         text_logits = text_logits.permute(0,2,1)
-        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
+        return text_logits
 
-        return loss_text.mean()
+    def forward(self, mel_inputs, text_targets):
+        text_logits = self.get_logits(mel_inputs, text_targets)
+        loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
+        return loss_text.mean(), text_logits
 
     def inference_beam_topk(self, mel):
         def topk_sampler(distribution, k):
@@ -155,11 +156,26 @@ def register_gpt_asr(opt_net, opt):
     return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
 
 
+# Halves the number of layers in the provided model.
+def distill(model):
+    rc = 0
+    i = 0
+    while i < len(model.gpt.layers.layers):
+        if rc % 2 != 0:
+            del model.gpt.layers.layers[i]
+        else:
+            i += 1
+        rc += 1
+    return model
+
+
 if __name__ == '__main__':
-    gpt = GptAsr()
-    l = gpt(torch.randn(2,80,800),
-            torch.randint(high=len(symbols), size=(2,180)))
-    print(l.shape)
+    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
+    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
+    student = distill(gpt)
+    torch.save(student.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
+    #l = gpt(torch.randn(2,80,800),
+    #        torch.randint(high=len(symbols), size=(2,180)))
     #o = gpt.infer(torch.randint(high=24, size=(2,60)))
     #print(o.shape)
diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py
index ed766384..07b1c90b 100644
--- a/codes/models/gpt_voice/lucidrains_gpt.py
+++ b/codes/models/gpt_voice/lucidrains_gpt.py
@@ -191,7 +191,7 @@ class Transformer(nn.Module):
         route_attn = ((True, False),) * depth
         attn_route_map = {'mask': route_attn}
 
-        self.layers = execute_type(layers, args_route = attn_route_map)
+        self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
 
     def forward(self, x):
         return self.layers(x)
diff --git a/codes/models/gpt_voice/reversible.py b/codes/models/gpt_voice/reversible.py
index 97a3dd64..481a2927 100644
--- a/codes/models/gpt_voice/reversible.py
+++ b/codes/models/gpt_voice/reversible.py
@@ -4,6 +4,9 @@ from torch.autograd.function import Function
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 # for routing arguments into the functions of the reversible layer
+from utils.util import checkpoint
+
+
 def route_args(router, args, depth):
     routed_args = [(dict(), dict()) for _ in range(depth)]
     matched_keys = [key for key in args.keys() if key in router]
@@ -123,20 +126,25 @@ class _ReversibleFunction(Function):
         return dy, None, None
 
 class SequentialSequence(nn.Module):
-    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
+    def __init__(self, layers, args_route = {}, layer_dropout = 0., checkpoint=False):
         super().__init__()
         assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
         self.layers = layers
         self.args_route = args_route
         self.layer_dropout = layer_dropout
+        self.checkpoint = checkpoint
 
     def forward(self, x, **kwargs):
         args = route_args(self.args_route, kwargs, len(self.layers))
         layers_and_args = list(zip(self.layers, args))
 
         for (f, g), (f_args, g_args) in layers_and_args:
-            x = x + f(x, **f_args)
-            x = x + g(x, **g_args)
+            if self.checkpoint:
+                x = x + checkpoint(f, x, **f_args)
+                x = x + checkpoint(g, x, **g_args)
+            else:
+                x = x + f(x, **f_args)
+                x = x + g(x, **g_args)
         return x
 
 class ReversibleSequence(nn.Module):
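Note on the SequentialSequence change above: the new checkpoint flag is meant to wrap each residual sublayer in activation checkpointing, so intermediate activations are recomputed during backward instead of being stored. The following is a minimal, runnable sketch of that pattern only; it uses torch.utils.checkpoint.checkpoint as a stand-in for the repo's utils.util.checkpoint wrapper, and ToyBlock / ToySequential are hypothetical names that are not part of this patch.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # A small residual sublayer standing in for the attention/FF blocks in the real model.
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class ToySequential(nn.Module):
    # Mirrors the patched SequentialSequence: when use_checkpoint is True, each sublayer's
    # activations are discarded after the forward pass and recomputed during backward.
    def __init__(self, blocks, use_checkpoint=False):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for f in self.blocks:
            if self.use_checkpoint:
                x = x + checkpoint(f, x)
            else:
                x = x + f(x)
        return x


if __name__ == '__main__':
    model = ToySequential([ToyBlock(64) for _ in range(4)], use_checkpoint=True)
    y = model(torch.randn(2, 10, 64, requires_grad=True))
    y.mean().backward()  # each block is re-run here to rebuild its saved activations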
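Note on distill() above: despite the name, the helper performs no distillation training; it walks the model's layer list and deletes every odd-positioned entry, keeping the even-positioned ones, so the saved checkpoint has half the depth. A standalone sketch of the same index-walking deletion on a plain nn.ModuleList follows; halve_layers is a hypothetical name used only for illustration.

import torch.nn as nn


def halve_layers(layers: nn.ModuleList) -> nn.ModuleList:
    # Keep entries at even positions (0, 2, 4, ...) and delete the odd-positioned ones.
    # rc counts every entry visited; i only advances when the current entry is kept.
    rc = 0
    i = 0
    while i < len(layers):
        if rc % 2 != 0:
            del layers[i]
        else:
            i += 1
        rc += 1
    return layers


if __name__ == '__main__':
    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
    halve_layers(layers)
    print(len(layers))  # 6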