forked from mrq/DL-Art-School
Align autoregressive text using start and stop tokens
parent 628569af7b
commit 2fb9ffb0aa
@@ -86,7 +86,8 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values, use_cache=use_cache)
+        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
+                                       use_cache=use_cache, expected_seq_len=150)
         if use_cache:
             hidden_states, present_key_values = out
         else:
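The only functional change above is that the cached-decoding call now passes expected_seq_len=150 to the decoder, presumably so the relative-position machinery can be sized for the full expected output rather than for the single cached step. The sketch below illustrates why an expected length matters for an ALiBi-style relative bias; it is an illustration of the idea only, not the actual x-transformers internals, and the slope value is made up.

    import torch

    def alibi_bias_for_step(query_pos, expected_seq_len, slope=0.5):
        # Illustration: with a relative (ALiBi-style) bias, the penalty on each cached
        # key position depends on its distance from the current query. Knowing the
        # expected sequence length lets the full bias be sized once up front instead
        # of being regrown on every incremental decoding step.
        key_positions = torch.arange(expected_seq_len)
        return -slope * (query_pos - key_positions).clamp(min=0).float()

    # Decoding token 10 of an expected 150-token sequence:
    bias_row = alibi_bias_for_step(query_pos=10, expected_seq_len=150)
    print(bias_row[:12])  # -5.0 ... 0.0 for keys 0..10, then 0.0 past the query position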
@@ -168,6 +169,8 @@ class AutoregressiveCodegen(nn.Module):

         self.START_TOKEN=8192
         self.STOP_TOKEN=8193
+        self.START_TEXT_TOKEN = 255
+        self.STOP_TEXT_TOKEN = 0
         self.max_text_token_id = num_text_tokens
         self.max_mel_token_id = num_mel_tokens
         self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
@@ -231,6 +234,9 @@ class AutoregressiveCodegen(nn.Module):
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
+        # Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings.
+        text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
+        text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
         _, enc_text = self.encoder(text_codes, return_hiddens=True)
         # Interleave cond_emb into the first few contexts.
         full_context = enc_text
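For reference, F.pad with a value argument on a (batch, seq) LongTensor simply prepends or appends that token id along the sequence dimension, which is all the two added lines do. A minimal sketch using the constants introduced in the earlier hunk:

    import torch
    import torch.nn.functional as F

    START_TEXT_TOKEN = 255
    STOP_TEXT_TOKEN = 0

    text_codes = torch.tensor([[12, 7, 42]])                        # (batch=1, seq=3) text token ids
    text_codes = F.pad(text_codes, (1, 0), value=START_TEXT_TOKEN)  # prepend start marker
    text_codes = F.pad(text_codes, (0, 1), value=STOP_TEXT_TOKEN)   # append stop marker
    print(text_codes)                                               # tensor([[255, 12, 7, 42, 0]])

Note that the encoder's text embedding must cover id 255 (i.e. num_text_tokens > 255), and STOP_TEXT_TOKEN reuses id 0, presumably because 0 already acts as the text stop/pad token.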
@@ -255,6 +261,8 @@ class AutoregressiveCodegen(nn.Module):
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
+        text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
+        text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
         _, enc_text = self.encoder(text_codes, return_hiddens=True)
         # Interleave cond_emb into the first few contexts.
         full_context = enc_text
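This last hunk repeats the same start/stop wrapping in the second place the model encodes text (the generation path), so the two call sites have to stay in sync by hand. A hypothetical helper, sketched below, would keep them from drifting; wrap_text_codes is an illustrative name and does not exist in the repository:

    import torch.nn.functional as F

    def wrap_text_codes(text_codes, start_token, stop_token):
        # Hypothetical refactor: prepend the start marker and append the stop marker
        # in one place so the training and generation paths always pad text identically.
        text_codes = F.pad(text_codes, (1, 0), value=start_token)
        return F.pad(text_codes, (0, 1), value=stop_token)

    # e.g. text_codes = wrap_text_codes(text_codes, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)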