asr_hf2: add independent position embedders
This commit is contained in:
parent
5b5cbc057c
commit
6996dfd9d5
|
@ -223,6 +223,7 @@ class GptAsrHf2(nn.Module):
|
|||
self.max_mel_frames = self.max_mel_frames
|
||||
self.mel_encoder = MelEncoder(model_dim)
|
||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||
self.text_solo_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
|
||||
seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
|
||||
self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
|
||||
|
@ -248,11 +249,11 @@ class GptAsrHf2(nn.Module):
|
|||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
def get_logits(self, mel_inputs, text_targets, get_attns=False):
|
||||
def get_logits(self, mel_inputs, text_targets, pos_emb, get_attns=False):
|
||||
# Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
|
||||
text_inputs = F.pad(text_targets, (1,0), value=self.start_token)[:, :-1]
|
||||
text_emb = self.gpt.get_input_embeddings()(text_inputs)
|
||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
||||
text_emb = text_emb + pos_emb(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
||||
if mel_inputs is None:
|
||||
emb = text_emb
|
||||
mel_len = 0
|
||||
|
@ -273,7 +274,7 @@ class GptAsrHf2(nn.Module):
|
|||
|
||||
def forward(self, mel_inputs, text_targets, return_attentions=False):
|
||||
text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
|
||||
text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
|
||||
text_logits = self.get_logits(mel_inputs, text_targets, self.text_pos_embedding, get_attns=return_attentions)
|
||||
if return_attentions:
|
||||
return text_logits # These weren't really the logits.
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
|
@ -281,7 +282,7 @@ class GptAsrHf2(nn.Module):
|
|||
|
||||
def text_only(self, text_targets):
|
||||
text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
|
||||
text_logits = self.get_logits(None, text_targets)
|
||||
text_logits = self.get_logits(None, text_targets, self.text_solo_pos_embedding)
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
return loss_text.mean(), text_logits
|
||||
|
||||
|
|
|
@ -286,7 +286,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_unified_voice.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user