forked from mrq/DL-Art-School
Optionally average conditioning inputs
This commit is contained in:
parent
e6a95f7c11
commit
1e87b934db
|
@@ -242,7 +242,7 @@ class UnifiedVoice(nn.Module):
|
||||||
mel_length_compression=1024, number_text_tokens=256,
|
mel_length_compression=1024, number_text_tokens=256,
|
||||||
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
||||||
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
||||||
checkpointing=True):
|
checkpointing=True, average_conditioning_embeddings=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
layers: Number of layers in transformer stack.
|
layers: Number of layers in transformer stack.
|
||||||
|
@@ -261,6 +261,7 @@ class UnifiedVoice(nn.Module):
|
||||||
train_solo_embeddings:
|
train_solo_embeddings:
|
||||||
use_mel_codes_as_input:
|
use_mel_codes_as_input:
|
||||||
checkpointing:
|
checkpointing:
|
||||||
|
average_conditioning_embeddings: Whether the conditioning embeddings should be averaged into a single embedding instead of being fed piecewise into the model.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@@ -278,6 +279,7 @@ class UnifiedVoice(nn.Module):
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||||
|
self.average_conditioning_embeddings = average_conditioning_embeddings
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||||
if use_mel_codes_as_input:
|
if use_mel_codes_as_input:
|
||||||
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
||||||
|
@@ -390,6 +392,8 @@ class UnifiedVoice(nn.Module):
|
||||||
for j in range(speech_conditioning_input.shape[1]):
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||||
conds = torch.stack(conds, dim=1)
|
conds = torch.stack(conds, dim=1)
|
||||||
|
if self.average_conditioning_embeddings:
|
||||||
|
conds = conds.mean(dim=1).unsqueeze(1)
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user