forked from mrq/DL-Art-School
Fix inference in new autoregressive_codegen
parent 3f8d7955ef
commit 7c578eb59b

@@ -5,7 +5,7 @@ from transformers import GPT2PreTrainedModel, GPT2Config
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 from models.arch_util import AttentionBlock
-from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
+from models.lucidrains.x_transformers import TransformerWrapper, Decoder, Encoder
 from trainer.networks import register_model
 
 
@@ -86,7 +86,7 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
+        hidden_states = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True)
         logits = self.transformer.decoder.to_logits(hidden_states)
 
         if not return_dict:
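
For the forward() change above: the decoder is now fed the stored context through full_context= instead of context=. Read together with the generate() hunk below, the stored context goes from a single concatenated tensor to a list of encoder hidden states with a few entries swapped for the conditioning embedding. A minimal sketch of that difference; the batch size, sequence length, width, and layer count here are hypothetical, not values from the commit:

import torch

b, d = 2, 512                                  # hypothetical sizes

# Old behaviour: conditioning and text embeddings concatenated along the
# sequence dimension into one tensor, later passed as `context=`.
cond_emb = torch.randn(b, 1, d)                # averaged conditioning embedding
enc_text = torch.randn(b, 120, d)              # encoder output embeddings
old_context = torch.cat([cond_emb, enc_text], dim=1)   # (b, 121, d)

# New behaviour, as the generate() hunk suggests: the encoder's per-layer
# hidden states become the cross-attention context, with entries 1, 3 and 6
# replaced by the conditioning embedding; the list is passed as `full_context=`.
enc_hiddens = [torch.randn(b, 120, d) for _ in range(8)]
full_context = list(enc_hiddens)
for layer in (1, 3, 6):
    full_context[layer] = cond_emb
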
@@ -241,21 +241,24 @@ class AutoregressiveCodegen(nn.Module):
         loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
         return loss_mel
 
-    def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
-        if not hasattr(self, 'inference_model'):
-            self.inference_model = InferenceModel(self)
-
+    def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs):
+        inference_model = InferenceModel(self)
+        # Build the context
         if len(conditioning_signal.shape) != 4:
             conditioning_signal = conditioning_signal.unsqueeze(1)
         cond_embs = []
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
-        enc_text = self.encoder(text_codes, return_embeddings=True)
-        context = torch.cat([cond_emb, enc_text], dim=1)
-        self.inference_model.store_context(context)
+        _, enc_text = self.encoder(text_codes, return_hiddens=True)
+        # Interleave cond_emb into the first few contexts.
+        full_context = enc_text
+        full_context[1] = cond_emb
+        full_context[3] = cond_emb
+        full_context[6] = cond_emb
+        inference_model.store_context(full_context)
 
-        gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
+        gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
                                             max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
                                             **hf_generate_kwargs)
         return gen.sequences
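
A minimal sketch of how the patched inference path might be driven end to end. Model construction and checkpoint loading are assumed to have happened elsewhere, and the tensor shapes and sampling kwargs below are illustrative assumptions, not values taken from the commit:

import torch

# `model` is an AutoregressiveCodegen instance loaded from a checkpoint
# elsewhere; its constructor arguments are not shown in this diff.
model.eval()

conditioning_mel = torch.randn(1, 80, 400)    # assumed (batch, mel_bins, frames) reference clip
text_codes = torch.randint(0, 256, (1, 40))   # assumed tokenized text ids

with torch.no_grad():
    mel_codes = model.generate(conditioning_mel, text_codes,
                               max_tokens=256,             # new default in this commit
                               do_sample=True, top_p=0.8)  # standard HF generate kwargs
# `mel_codes` holds the generated MEL code sequences (gen.sequences).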