Align autoregressive text using start and stop tokens

This commit is contained in:
James Betker 2022-04-08 09:41:59 -06:00
parent 628569af7b
commit 2fb9ffb0aa

View File

@ -86,7 +86,8 @@ class InferenceModel(GPT2PreTrainedModel):
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values, use_cache=use_cache)
out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
use_cache=use_cache, expected_seq_len=150)
if use_cache:
hidden_states, present_key_values = out
else:
@ -168,6 +169,8 @@ class AutoregressiveCodegen(nn.Module):
self.START_TOKEN=8192
self.STOP_TOKEN=8193
self.START_TEXT_TOKEN = 255
self.STOP_TEXT_TOKEN = 0
self.max_text_token_id = num_text_tokens
self.max_mel_token_id = num_mel_tokens
self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
@ -231,6 +234,9 @@ class AutoregressiveCodegen(nn.Module):
for i in range(conditioning_signal.shape[1]):
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
# Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings.
text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
_, enc_text = self.encoder(text_codes, return_hiddens=True)
# Interleave cond_emb into the first few contexts.
full_context = enc_text
@ -255,6 +261,8 @@ class AutoregressiveCodegen(nn.Module):
for i in range(conditioning_signal.shape[1]):
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
_, enc_text = self.encoder(text_codes, return_hiddens=True)
# Interleave cond_emb into the first few contexts.
full_context = enc_text