diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py index b526f5dc..689f22aa 100644 --- a/codes/models/gpt_voice/unified_voice.py +++ b/codes/models/gpt_voice/unified_voice.py @@ -89,6 +89,12 @@ class UnifiedGptVoice(nn.Module): self.mel_head = nn.Linear(model_dim, self.number_mel_codes) self.max_conditioning_length = max_conditioning_length + # Initialize the embeddings per the GPT-2 scheme + for module in [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding]: + module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + def build_aligned_inputs_and_targets(self, input, start_token, stop_token): inp = F.pad(input, (1,0), value=start_token) tar = F.pad(input, (0,1), value=stop_token)