Fix inference in new autoregressive_codegen
This commit is contained in:
parent
3f8d7955ef
commit
7c578eb59b
|
@ -5,7 +5,7 @@ from transformers import GPT2PreTrainedModel, GPT2Config
|
|||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
|
||||
from models.arch_util import AttentionBlock
|
||||
from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
|
||||
from models.lucidrains.x_transformers import TransformerWrapper, Decoder, Encoder
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
||||
|
@ -86,7 +86,7 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
assert labels is None # Training not supported by this inference model.
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
|
||||
hidden_states = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True)
|
||||
logits = self.transformer.decoder.to_logits(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
|
@ -241,21 +241,24 @@ class AutoregressiveCodegen(nn.Module):
|
|||
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
|
||||
return loss_mel
|
||||
|
||||
def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
|
||||
if not hasattr(self, 'inference_model'):
|
||||
self.inference_model = InferenceModel(self)
|
||||
|
||||
def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs):
|
||||
inference_model = InferenceModel(self)
|
||||
# Build the context
|
||||
if len(conditioning_signal.shape) != 4:
|
||||
conditioning_signal = conditioning_signal.unsqueeze(1)
|
||||
cond_embs = []
|
||||
for i in range(conditioning_signal.shape[1]):
|
||||
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
|
||||
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
|
||||
enc_text = self.encoder(text_codes, return_embeddings=True)
|
||||
context = torch.cat([cond_emb, enc_text], dim=1)
|
||||
self.inference_model.store_context(context)
|
||||
_, enc_text = self.encoder(text_codes, return_hiddens=True)
|
||||
# Interleave cond_emb into the first few contexts.
|
||||
full_context = enc_text
|
||||
full_context[1] = cond_emb
|
||||
full_context[3] = cond_emb
|
||||
full_context[6] = cond_emb
|
||||
inference_model.store_context(full_context)
|
||||
|
||||
gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
|
||||
gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
|
||||
max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
|
||||
**hf_generate_kwargs)
|
||||
return gen.sequences
|
||||
|
|
Loading…
Reference in New Issue
Block a user