Fix inference in new autoregressive_codegen

James Betker 2022-04-07 21:22:46 -06:00
parent 3f8d7955ef
commit 7c578eb59b

@@ -5,7 +5,7 @@ from transformers import GPT2PreTrainedModel, GPT2Config
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from models.arch_util import AttentionBlock
-from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
+from models.lucidrains.x_transformers import TransformerWrapper, Decoder, Encoder
 from trainer.networks import register_model
@@ -86,7 +86,7 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
+        hidden_states = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True)
         logits = self.transformer.decoder.to_logits(hidden_states)
         if not return_dict:
@@ -241,21 +241,24 @@ class AutoregressiveCodegen(nn.Module):
         loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
         return loss_mel

-    def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
-        if not hasattr(self, 'inference_model'):
-            self.inference_model = InferenceModel(self)
+    def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs):
+        inference_model = InferenceModel(self)
         # Build the context
         if len(conditioning_signal.shape) != 4:
             conditioning_signal = conditioning_signal.unsqueeze(1)
         cond_embs = []
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
-        enc_text = self.encoder(text_codes, return_embeddings=True)
-        context = torch.cat([cond_emb, enc_text], dim=1)
-        self.inference_model.store_context(context)
+        _, enc_text = self.encoder(text_codes, return_hiddens=True)
+        # Interleave cond_emb into the first few contexts.
+        full_context = enc_text
+        full_context[1] = cond_emb
+        full_context[3] = cond_emb
+        full_context[6] = cond_emb
+        inference_model.store_context(full_context)

-        gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
+        gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
                                       max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
                                       **hf_generate_kwargs)
         return gen.sequences
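
The core change: rather than concatenating the conditioning embedding onto the encoder output as one shared cross-attention context, generate() now keeps the encoder's per-layer hidden states and overwrites a few of them with the conditioning embedding. Below is a minimal, self-contained sketch of that wiring in plain PyTorch; the layer indices (1, 3, 6) come from the diff, while the toy shapes, layer count, and the faked encoder outputs are assumptions for illustration only.

import torch

batch, text_len, model_dim, num_layers = 2, 40, 512, 8

# Stand-in for `_, enc_text = self.encoder(text_codes, return_hiddens=True)`:
# one hidden-state tensor per encoder layer (toy values, not the real encoder).
enc_hiddens = [torch.randn(batch, text_len, model_dim) for _ in range(num_layers)]

# Stand-in for the averaged conditioning embedding (one token per batch item),
# i.e. cond_emb in the diff.
cond_emb = torch.randn(batch, 1, model_dim)

# Interleave cond_emb into the first few contexts, as the commit does.
full_context = enc_hiddens
for layer in (1, 3, 6):
    full_context[layer] = cond_emb

# Each decoder cross-attention layer can now attend to its own entry of
# full_context instead of a single shared (cond_emb ++ enc_text) sequence.
for i, ctx in enumerate(full_context):
    print(i, tuple(ctx.shape))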
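
For reference, a hedged call-site sketch for the patched generate(). Only the signature (conditioning_signal, text_codes, max_tokens, **hf_generate_kwargs) and the new max_tokens default come from the diff; the wrapper function, the shapes in the comment, and the sampling kwargs are assumptions, with extra kwargs falling through to HuggingFace's generate().

import torch

def synthesize(model, cond_mels, text_tokens):
    # cond_mels: (b, n_mels, T) or (b, n_cond, n_mels, T); 3-dim input is
    # unsqueezed to a single conditioning clip inside generate().
    with torch.no_grad():
        return model.generate(cond_mels, text_tokens,
                              max_tokens=256,  # the new default in this commit
                              do_sample=True, top_p=0.9)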