clean up unified voice

- remove unused code
- fix the inference model to use the terms "prior" and "posterior" so the names reflect the modeling order (they were inverted before); see the sketch below
- default some settings I never intend to change in the future
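
A rough sketch (not part of the commit) of what the renamed terms refer to, using the identifiers introduced by this diff and illustrative shapes: the cached "prior" embedding (speech conditioning + text) always comes first in the sequence, and the "posterior" (mel tokens) is what GPT2InferenceModel actually decodes behind it.

    import torch

    batch, prior_len, posterior_len, dim = 2, 5, 3, 8
    cached_prior_emb = torch.randn(batch, prior_len, dim)    # what store_prior_emb() caches (conditioning + text)
    posterior_emb = torch.randn(batch, posterior_len, dim)   # embeddings of the mel tokens decoded so far

    # As in GPT2InferenceModel.forward(): the prior is laid down first and the posterior is
    # appended behind it; posterior positions are counted from the end of the prior.
    emb = torch.cat([cached_prior_emb, posterior_emb], dim=1)
    assert emb.shape == (batch, prior_len + posterior_len, dim)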
James Betker 2022-05-09 14:45:49 -06:00
parent 9118f58849
commit b42b4e18de


@@ -32,17 +32,17 @@ class ResBlock(nn.Module):
 class GPT2InferenceModel(GPT2PreTrainedModel):
-    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
+    def __init__(self, config, gpt, posterior_pos_emb, embeddings, norm, linear):
         super().__init__(config)
         self.transformer = gpt
-        self.text_pos_embedding = text_pos_emb
+        self.posterior_pos_embedding = posterior_pos_emb
         self.embeddings = embeddings
-        self.lm_head = nn.Sequential(norm, linear)
+        self.head = nn.Sequential(norm, linear)

         # Model parallel
         self.model_parallel = False
         self.device_map = None
-        self.cached_mel_emb = None
+        self.cached_prior_emb = None

     def parallelize(self, device_map=None):
         self.device_map = (
@@ -52,27 +52,26 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         )
         assert_device_map(self.device_map, len(self.transformer.h))
         self.transformer.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.transformer.first_device)
+        self.head = self.head.to(self.transformer.first_device)
         self.model_parallel = True

     def deparallelize(self):
         self.transformer.deparallelize()
         self.transformer = self.transformer.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
+        self.head = self.head.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()

     def get_output_embeddings(self):
-        return self.lm_head
+        return self.head

     def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
+        self.head = new_embeddings

-    def store_mel_emb(self, mel_emb):
-        self.cached_mel_emb = mel_emb
+    def store_prior_emb(self, mel_emb):
+        self.cached_prior_emb = mel_emb

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
@@ -117,25 +116,25 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
-        assert self.cached_mel_emb is not None
+        assert self.cached_prior_emb is not None
         assert inputs_embeds is None  # Not supported by this inference model.
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # Create embedding
-        mel_len = self.cached_mel_emb.shape[1]
+        prior_len = self.cached_prior_emb.shape[1]
         if input_ids.shape[1] != 1:
-            text_inputs = input_ids[:, mel_len:]
-            text_emb = self.embeddings(text_inputs)
-            text_emb = text_emb + self.text_pos_embedding(text_emb)
-            if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
-                mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
+            posterior_inputs = input_ids[:, prior_len:]
+            posterior_emb = self.embeddings(posterior_inputs)
+            posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb)
+            if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]:
+                prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
             else:
-                mel_emb = self.cached_mel_emb
-            emb = torch.cat([mel_emb, text_emb], dim=1)
+                prior_emb = self.cached_prior_emb
+            emb = torch.cat([prior_emb, posterior_emb], dim=1)
         else:
             emb = self.embeddings(input_ids)
-            emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
+            emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device)

         transformer_outputs = self.transformer(
             inputs_embeds=emb,
@@ -156,16 +155,16 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         # Set device for model parallelism
         if self.model_parallel:
             torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
+            hidden_states = hidden_states.to(self.head.weight.device)

-        lm_logits = self.lm_head(hidden_states)
+        logits = self.head(hidden_states)

         if not return_dict:
-            return (lm_logits,) + transformer_outputs[1:]
+            return (logits,) + transformer_outputs[1:]

         return CausalLMOutputWithCrossAttentions(
             loss=None,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
@@ -239,9 +238,7 @@ class MelEncoder(nn.Module):
 class UnifiedVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
                  mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
-                 stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, start_text_token=None,
-                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
-                 types=1):
+                 stop_mel_token=8193, start_text_token=None, checkpointing=True, types=1):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -252,14 +249,10 @@ class UnifiedVoice(nn.Module):
             max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
             mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
             number_text_tokens:
-            stop_text_token:
             number_mel_codes:
             start_mel_token:
             stop_mel_token:
-            train_solo_embeddings:
-            use_mel_codes_as_input:
             checkpointing:
-            average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
         """
         super().__init__()
@@ -277,41 +270,20 @@ class UnifiedVoice(nn.Module):
         self.model_dim = model_dim
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
-        self.average_conditioning_embeddings = average_conditioning_embeddings
         self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
-        if use_mel_codes_as_input:
-            self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
-        else:
-            self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
+        self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
             build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
-        if train_solo_embeddings:
-            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
-            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
-        else:
-            self.mel_solo_embedding = 0
-            self.text_solo_embedding = 0

         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)

         # Initialize the embeddings per the GPT-2 scheme
-        embeddings = [self.text_embedding]
-        if use_mel_codes_as_input:
-            embeddings.append(self.mel_embedding)
+        embeddings = [self.text_embedding, self.mel_embedding]
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)

-        if freeze_everything_but_position_embeddings:
-            for p in self.parameters():
-                p.requires_grad = False
-                p.DO_NOT_TRAIN = True
-            for m in [self.mel_pos_embedding, self.text_pos_embedding]:
-                for p in m.parameters():
-                    del p.DO_NOT_TRAIN
-                    p.requires_grad = True
-
     def get_grad_norm_parameter_groups(self):
         return {
             'conditioning_encoder': list(self.conditioning_encoder.parameters()),
@@ -338,15 +310,13 @@ class UnifiedVoice(nn.Module):
                 mel_input_tokens[b, actual_end:] = self.stop_mel_token
         return mel_input_tokens

-    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
+    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, return_latent=False):
         if second_inputs is not None:
             emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
         else:
             emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)

-        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
-        if get_attns:
-            return gpt_out.attentions
+        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)

         enc = gpt_out.last_hidden_state[:, 1:]  # The first logit is tied to the speech_conditioning_input
         enc = self.final_norm(enc)
@@ -372,13 +342,11 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
+        conds = conds.mean(dim=1).unsqueeze(1)
         return conds

-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
-                return_latent=False, clip_inputs=True):
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, return_latent=False):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -388,25 +356,13 @@ class UnifiedVoice(nn.Module):
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
-        raw_mels: MEL float tensor (b,80,s)

-        If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
-        If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
         # Types are expressed by expanding the text embedding space.
         if types is not None:
             text_inputs = text_inputs * (1+types).unsqueeze(-1)
-        if clip_inputs:
-            # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-            # chopping the inputs by the maximum actual length.
-            max_text_len = text_lengths.max()
-            text_inputs = text_inputs[:, :max_text_len]
-            max_mel_len = wav_lengths.max() // self.mel_length_compression
-            mel_codes = mel_codes[:, :max_mel_len]
-            if raw_mels is not None:
-                raw_mels = raw_mels[:, :, :max_mel_len*4]

         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
         text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
         mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
@@ -416,86 +372,24 @@ class UnifiedVoice(nn.Module):
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
-        if raw_mels is not None:
-            mel_inp = F.pad(raw_mels, (0, 8))
-        else:
-            mel_inp = mel_codes
+        mel_inp = mel_codes
         mel_emb = self.mel_embedding(mel_inp)
         mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)

         if text_first:
-            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
+            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, return_latent=return_latent)
             if return_latent:
                 return mel_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
         else:
-            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
+            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, return_latent=return_latent)
             if return_latent:
                 return text_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.

-        if return_attentions:
-            return mel_logits
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_text.mean(), loss_mel.mean(), mel_logits

-    def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
-        """
-        Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
-        model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
-        """
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
-
-        speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
-
-        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
-        text_logits = self.get_logits(conds, text_emb, self.text_head)
-        loss_text = F.cross_entropy(text_logits, text_targets.long())
-        return loss_text.mean()
-
-    def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
-        """
-        Performs autoregressive modeling on only speech data.
-        """
-        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
-
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
-        mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, :max_mel_len*4]
-
-        speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
-
-        mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
-        if raw_mels is not None:
-            mel_inp = F.pad(raw_mels, (0, 4))
-        else:
-            mel_inp = mel_codes
-        mel_emb = self.mel_embedding(mel_inp)
-        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
-        mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
-        loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
-        return loss_mel.mean()
-
-    def inference_speech(self, speech_conditioning_input, text_inputs, return_attentions=False, **hf_generate_kwargs):
+    def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also
             seq_length = 2002  # Arbitrary default.
         else:
@@ -522,64 +416,17 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
+        conds = conds.mean(dim=1).unsqueeze(1)

         emb = torch.cat([conds, text_emb], dim=1)
-        self.inference_model.store_mel_emb(emb)
+        self.inference_model.store_prior_emb(emb)

         fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.start_mel_token

         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=seq_length, output_attentions=return_attentions, return_dict_in_generate=True, **hf_generate_kwargs)
-        if return_attentions:
-            return gen.sequences[:, fake_inputs.shape[1]:], gen.attentions
-        else:
-            return gen.sequences[:, fake_inputs.shape[1]:]
+                                            max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs)
+        return gen.sequences[:, fake_inputs.shape[1]:]

-    # Turns the (utterly insane) output of HF.generate() into a far more sane output:
-    # [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence
-    def make_hf_generate_attentions_sane(self, attentions):
-        layers = [[] for _ in range(len(attentions[0]))]
-        full_attention_size = attentions[-1][0].shape[-1]
-        for i, gen in enumerate(attentions):
-            for j, lyr in enumerate(gen):
-                layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
-        catted = []
-        for lyr in layers:
-            catted.append(torch.cat(lyr, dim=2))
-        return catted
-
-    def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds):
-        """
-        This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice.
-        """
-        text_padding = num_conds+2
-        num_text = text.shape[-1]
-        num_context = num_text + text_padding
-        assert num_context + 1 == attentions[0][0].shape[-1]
-        attentions = self.make_hf_generate_attentions_sane(attentions)
-        results = [torch.empty_like(codes) for _ in range(len(attentions))]
-        for l, layer in enumerate(attentions):
-            dec_context = layer[:, :, num_context:, :]
-            # Mask out everything that isn't text (including the start token, which gets a LOT of attention)
-            dec_context[:,:,:,:text_padding+1] = 0
-            dec_context[:,:,:,num_context:] = 0
-            for h in range(dec_context.shape[1]):
-                dec_context_indices = torch.argmax(dec_context[0,h], dim=-1)
-                print(f'layer_{l};head_{h}: ' + str(dec_context_indices))
-        for t, att_tok in enumerate(attentions):
-            combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
-            for lyr in att_tok:
-                token_to_text_attentions = lyr[:, :, -1, text_padding:(text_padding + num_text)].sum(dim=1)
-                combined_attention_weights = combined_attention_weights + token_to_text_attentions
-                break
-            most_attended_text_token = combined_attention_weights.argmax(dim=-1)
-            results[:, t] = most_attended_text_token
-        eos_token_mask = (codes != self.stop_mel_token)
-        return results * eos_token_mask

 @register_model
@@ -588,11 +435,10 @@ def register_unified_voice2(opt_net, opt):
 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True, types=2)
+    gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
             torch.tensor([250*256,195*256]),
             types=torch.tensor([0, 1]))
-    #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))