forked from mrq/DL-Art-School
Align autoregressive text using start and stop tokens
parent 628569af7b
commit 2fb9ffb0aa
@@ -86,7 +86,8 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values, use_cache=use_cache)
+        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
+                                       use_cache=use_cache, expected_seq_len=150)
         if use_cache:
             hidden_states, present_key_values = out
         else:
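For readers of the first hunk: the decoder call now spreads over two lines and gains expected_seq_len=150, and the use_cache branch below it unpacks a (hidden_states, present_key_values) pair only when caching is enabled. Below is a minimal sketch of that call pattern, assuming a decoder that accepts the keyword arguments shown in the diff; dummy_decoder and step are illustrative stand-ins, not part of this repository, and reading expected_seq_len as the anticipated final sequence length is an assumption.

import torch

def dummy_decoder(input_ids, full_context=None, return_embeddings=True,
                  past_key_values=None, use_cache=False, expected_seq_len=None):
    # Stand-in for self.transformer.decoder: returns embeddings, plus a
    # key/value cache when use_cache is set, mirroring the branch in the diff.
    hidden = torch.zeros(input_ids.shape[0], input_ids.shape[1], 8)
    return (hidden, []) if use_cache else hidden

def step(decoder, input_ids, context, past_key_values=None, use_cache=True):
    # expected_seq_len=150 matches the value hard-coded in the commit.
    out = decoder(input_ids, full_context=context, return_embeddings=True,
                  past_key_values=past_key_values, use_cache=use_cache,
                  expected_seq_len=150)
    if use_cache:
        hidden_states, present_key_values = out
    else:
        hidden_states, present_key_values = out, None
    return hidden_states, present_key_values

hidden, cache = step(dummy_decoder, torch.zeros(1, 1, dtype=torch.long), context=None)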
@@ -168,6 +169,8 @@ class AutoregressiveCodegen(nn.Module):

         self.START_TOKEN=8192
         self.STOP_TOKEN=8193
+        self.START_TEXT_TOKEN = 255
+        self.STOP_TEXT_TOKEN = 0
         self.max_text_token_id = num_text_tokens
         self.max_mel_token_id = num_mel_tokens
         self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
@@ -231,6 +234,9 @@ class AutoregressiveCodegen(nn.Module):
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
+        # Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings.
+        text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
+        text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
         _, enc_text = self.encoder(text_codes, return_hiddens=True)
         # Interleave cond_emb into the first few contexts.
         full_context = enc_text
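The added lines wrap every text-code sequence in the new boundary ids before it reaches the encoder. Here is a minimal, self-contained illustration of what the two F.pad calls do; the constant values are taken from the diff, while the toy batch contents are made up for demonstration.

import torch
import torch.nn.functional as F

START_TEXT_TOKEN = 255   # value from the diff
STOP_TEXT_TOKEN = 0      # value from the diff

# Toy batch of two text-code sequences of length 3 (contents are arbitrary).
text_codes = torch.tensor([[12, 7, 42],
                           [ 5, 9, 31]])

# (1, 0) pads one element on the left of the last dimension -> prepend START.
text_codes = F.pad(text_codes, (1, 0), value=START_TEXT_TOKEN)
# (0, 1) pads one element on the right of the last dimension -> append STOP.
text_codes = F.pad(text_codes, (0, 1), value=STOP_TEXT_TOKEN)

print(text_codes)
# tensor([[255,  12,   7,  42,   0],
#         [255,   5,   9,  31,   0]])

Both code paths padded by this commit apply the same wrapping, so the encoder always sees text that begins and ends at fixed ids — the "permanent embeddings" the in-code comment refers to under relative positional embeddings.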
@@ -255,6 +261,8 @@ class AutoregressiveCodegen(nn.Module):
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
+        text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
+        text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
         _, enc_text = self.encoder(text_codes, return_hiddens=True)
         # Interleave cond_emb into the first few contexts.
         full_context = enc_text