diff --git a/models/autoregressive.py b/models/autoregressive.py index 8d5e462..932e508 100644 --- a/models/autoregressive.py +++ b/models/autoregressive.py @@ -278,9 +278,10 @@ class MelEncoder(nn.Module): class UnifiedVoice(nn.Module): def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, mel_length_compression=1024, number_text_tokens=256, - start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, + start_text_token=None, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, - checkpointing=True, average_conditioning_embeddings=False): + checkpointing=True, average_conditioning_embeddings=False, + types=1): """ Args: layers: Number of layers in transformer stack. @@ -304,8 +305,8 @@ class UnifiedVoice(nn.Module): super().__init__() self.number_text_tokens = number_text_tokens - self.start_text_token = start_text_token - self.stop_text_token = stop_text_token + self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token + self.stop_text_token = 0 self.number_mel_codes = number_mel_codes self.start_mel_token = start_mel_token self.stop_mel_token = stop_mel_token @@ -318,7 +319,7 @@ class UnifiedVoice(nn.Module): self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.average_conditioning_embeddings = average_conditioning_embeddings - self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) + self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim) if use_mel_codes_as_input: self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) else: @@ -333,7 +334,7 @@ class UnifiedVoice(nn.Module): self.text_solo_embedding = 0 self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.number_text_tokens) + self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) self.mel_head = nn.Linear(model_dim, self.number_mel_codes) # Initialize the embeddings per the GPT-2 scheme @@ -389,7 +390,7 @@ class UnifiedVoice(nn.Module): else: return first_logits - def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False, + def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False, return_latent=False, clip_inputs=True): """ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode @@ -406,6 +407,10 @@ class UnifiedVoice(nn.Module): If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. """ + # Types are expressed by expanding the text embedding space. + if types is not None: + text_inputs = text_inputs * (1+types).unsqueeze(-1) + if clip_inputs: # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by # chopping the inputs by the maximum actual length.