Initialize our embeddings the same way GPT-2 initializes theirs.
This commit is contained in:
parent
8d01f7685c
commit
cd89e6b42e
|
@ -89,6 +89,12 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
||||||
self.max_conditioning_length = max_conditioning_length
|
self.max_conditioning_length = max_conditioning_length
|
||||||
|
|
||||||
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
|
for module in [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding]:
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||||
inp = F.pad(input, (1,0), value=start_token)
|
inp = F.pad(input, (1,0), value=start_token)
|
||||||
tar = F.pad(input, (0,1), value=stop_token)
|
tar = F.pad(input, (0,1), value=stop_token)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user