forked from mrq/DL-Art-School
clean up unified voice
- remove unused code
- fix inference model to use the terms "prior" and "posterior" to properly define the modeling order (they were inverted before)
- default some settings I never intend to change in the future
This commit is contained in:
parent
9118f58849
commit
b42b4e18de
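
The rename matters because the autoregressive sequence at inference time is ordered [conditioning | text | mel]: everything that precedes the tokens being generated acts as the prior, and the generated mel stream is the posterior modeled on top of it. A minimal sketch of that layout, with made-up shapes (none of these numbers come from the repo):

```python
import torch

# Illustrative shapes only: the model caches the conditioning+text
# embedding as the "prior" and generates mel tokens as the "posterior".
batch, prior_len, posterior_len, dim = 2, 50, 30, 512

prior_emb = torch.randn(batch, prior_len, dim)          # fixed during generation
posterior_emb = torch.randn(batch, posterior_len, dim)  # grows one token at a time

# The transformer always sees the prior first, then the growing posterior:
emb = torch.cat([prior_emb, posterior_emb], dim=1)
assert emb.shape == (batch, prior_len + posterior_len, dim)
```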
@@ -32,17 +32,17 @@ class ResBlock(nn.Module):

 class GPT2InferenceModel(GPT2PreTrainedModel):
-    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
+    def __init__(self, config, gpt, posterior_pos_emb, embeddings, norm, linear):
         super().__init__(config)
         self.transformer = gpt
-        self.text_pos_embedding = text_pos_emb
+        self.posterior_pos_embedding = posterior_pos_emb
         self.embeddings = embeddings
-        self.lm_head = nn.Sequential(norm, linear)
+        self.head = nn.Sequential(norm, linear)

         # Model parallel
         self.model_parallel = False
         self.device_map = None
-        self.cached_mel_emb = None
+        self.cached_prior_emb = None

     def parallelize(self, device_map=None):
         self.device_map = (
@@ -52,27 +52,26 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         )
         assert_device_map(self.device_map, len(self.transformer.h))
         self.transformer.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.transformer.first_device)
+        self.head = self.head.to(self.transformer.first_device)
         self.model_parallel = True

     def deparallelize(self):
         self.transformer.deparallelize()
         self.transformer = self.transformer.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
+        self.head = self.head.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()

     def get_output_embeddings(self):
-        return self.lm_head
+        return self.head

     def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
+        self.head = new_embeddings

-    def store_mel_emb(self, mel_emb):
-        self.cached_mel_emb = mel_emb
+    def store_prior_emb(self, mel_emb):
+        self.cached_prior_emb = mel_emb

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
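
With `store_mel_emb` renamed to `store_prior_emb`, the caching interface now says what it means: stash the prior once, then let HF `generate()` drive the posterior. A toy stand-in for that interface (the `PriorCache` class below is invented for illustration, not part of the repo):

```python
import torch

class PriorCache:
    """Toy stand-in for GPT2InferenceModel's caching interface after the rename."""
    def __init__(self):
        self.cached_prior_emb = None

    def store_prior_emb(self, emb):  # was store_mel_emb() before this commit
        self.cached_prior_emb = emb

cache = PriorCache()
cache.store_prior_emb(torch.randn(1, 50, 512))  # conditioning+text embedding
print(cache.cached_prior_emb.shape)             # torch.Size([1, 50, 512])
```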
@@ -117,25 +116,25 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
-        assert self.cached_mel_emb is not None
+        assert self.cached_prior_emb is not None
         assert inputs_embeds is None # Not supported by this inference model.
         assert labels is None # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # Create embedding
-        mel_len = self.cached_mel_emb.shape[1]
+        prior_len = self.cached_prior_emb.shape[1]
         if input_ids.shape[1] != 1:
-            text_inputs = input_ids[:, mel_len:]
-            text_emb = self.embeddings(text_inputs)
-            text_emb = text_emb + self.text_pos_embedding(text_emb)
-            if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
-                mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
+            posterior_inputs = input_ids[:, prior_len:]
+            posterior_emb = self.embeddings(posterior_inputs)
+            posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb)
+            if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]:
+                prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
             else:
-                mel_emb = self.cached_mel_emb
-            emb = torch.cat([mel_emb, text_emb], dim=1)
+                prior_emb = self.cached_prior_emb
+            emb = torch.cat([prior_emb, posterior_emb], dim=1)
         else:
             emb = self.embeddings(input_ids)
-            emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
+            emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device)

         transformer_outputs = self.transformer(
             inputs_embeds=emb,
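
The `repeat_interleave` branch exists because HF `generate()` may expand the batch (beam search, `num_return_sequences`) while the cached prior keeps its original batch size, so the cache has to be tiled to match. A standalone demonstration with arbitrary sizes:

```python
import torch

cached_prior_emb = torch.randn(2, 50, 512)  # batch of 2 cached priors
posterior_emb = torch.randn(6, 10, 512)     # generate() expanded the batch to 6 (e.g. 3 beams each)

if cached_prior_emb.shape[0] != posterior_emb.shape[0]:
    # Tile each prior 3x so row i of the prior lines up with its own beams.
    prior_emb = cached_prior_emb.repeat_interleave(
        posterior_emb.shape[0] // cached_prior_emb.shape[0], 0)
else:
    prior_emb = cached_prior_emb
assert prior_emb.shape == (6, 50, 512)
```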
@@ -156,16 +155,16 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         # Set device for model parallelism
         if self.model_parallel:
             torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
+            hidden_states = hidden_states.to(self.head.weight.device)

-        lm_logits = self.lm_head(hidden_states)
+        logits = self.head(hidden_states)

         if not return_dict:
-            return (lm_logits,) + transformer_outputs[1:]
+            return (logits,) + transformer_outputs[1:]

         return CausalLMOutputWithCrossAttentions(
             loss=None,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
@@ -239,9 +238,7 @@ class MelEncoder(nn.Module):

 class UnifiedVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
                  mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
-                 stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, start_text_token=None,
-                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
-                 types=1):
+                 stop_mel_token=8193, start_text_token=None, checkpointing=True, types=1):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -252,14 +249,10 @@ class UnifiedVoice(nn.Module):
             max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
             mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
             number_text_tokens:
             stop_text_token:
             number_mel_codes:
             start_mel_token:
             stop_mel_token:
-            train_solo_embeddings:
-            use_mel_codes_as_input:
             checkpointing:
-            average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
         """
         super().__init__()

@@ -277,41 +270,20 @@ class UnifiedVoice(nn.Module):
         self.model_dim = model_dim
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
-        self.average_conditioning_embeddings = average_conditioning_embeddings
         self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
-        if use_mel_codes_as_input:
-            self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
-        else:
-            self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
+        self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
             build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
-        if train_solo_embeddings:
-            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
-            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
-        else:
-            self.mel_solo_embedding = 0
-            self.text_solo_embedding = 0

         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)

         # Initialize the embeddings per the GPT-2 scheme
-        embeddings = [self.text_embedding]
-        if use_mel_codes_as_input:
-            embeddings.append(self.mel_embedding)
+        embeddings = [self.text_embedding, self.mel_embedding]
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)

-        if freeze_everything_but_position_embeddings:
-            for p in self.parameters():
-                p.requires_grad = False
-                p.DO_NOT_TRAIN = True
-            for m in [self.mel_pos_embedding, self.text_pos_embedding]:
-                for p in m.parameters():
-                    del p.DO_NOT_TRAIN
-                    p.requires_grad = True
-
     def get_grad_norm_parameter_groups(self):
         return {
             'conditioning_encoder': list(self.conditioning_encoder.parameters()),
@@ -338,15 +310,13 @@ class UnifiedVoice(nn.Module):
             mel_input_tokens[b, actual_end:] = self.stop_mel_token
         return mel_input_tokens

-    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
+    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, return_latent=False):
         if second_inputs is not None:
             emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
         else:
             emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)

-        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
-        if get_attns:
-            return gpt_out.attentions
+        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)

         enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
         enc = self.final_norm(enc)
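
`get_logits` always prepends the conditioning embeddings, so the transformer sees `[conds | first | second]` and the first output position is discarded because it corresponds to the conditioning input rather than to any predicted token. A shape-level sketch (dimensions invented for illustration):

```python
import torch

conds = torch.randn(2, 1, 512)            # averaged conditioning embedding
first_inputs = torch.randn(2, 40, 512)    # e.g. text embeddings when text_first=True
second_inputs = torch.randn(2, 100, 512)  # e.g. mel embeddings

emb = torch.cat([conds, first_inputs, second_inputs], dim=1)  # (2, 141, 512)
# After the GPT pass, position 0 is dropped: it attends only to the
# conditioning input, not to any token being predicted.
hidden = torch.randn(emb.shape)  # stand-in for gpt(...).last_hidden_state
enc = hidden[:, 1:]              # (2, 140, 512)
```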
@@ -372,13 +342,11 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
+        conds = conds.mean(dim=1).unsqueeze(1)
         return conds

-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
-                return_latent=False, clip_inputs=True):
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, return_latent=False):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -388,25 +356,13 @@ class UnifiedVoice(nn.Module):
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
-        raw_mels: MEL float tensor (b,80,s)

-        If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
-        If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
         # Types are expressed by expanding the text embedding space.
         if types is not None:
             text_inputs = text_inputs * (1+types).unsqueeze(-1)

-        if clip_inputs:
-            # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-            # chopping the inputs by the maximum actual length.
-            max_text_len = text_lengths.max()
-            text_inputs = text_inputs[:, :max_text_len]
-            max_mel_len = wav_lengths.max() // self.mel_length_compression
-            mel_codes = mel_codes[:, :max_mel_len]
-            if raw_mels is not None:
-                raw_mels = raw_mels[:, :, :max_mel_len*4]
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
         text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
         mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
@@ -416,86 +372,24 @@ class UnifiedVoice(nn.Module):
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
-        if raw_mels is not None:
-            mel_inp = F.pad(raw_mels, (0, 8))
-        else:
-            mel_inp = mel_codes
+        mel_inp = mel_codes
         mel_emb = self.mel_embedding(mel_inp)
         mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)

         if text_first:
-            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
+            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, return_latent=return_latent)
             if return_latent:
                 return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
         else:
-            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
+            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, return_latent=return_latent)
             if return_latent:
                 return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.

-        if return_attentions:
-            return mel_logits
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_text.mean(), loss_mel.mean(), mel_logits

-    def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
-        """
-        Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
-        model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
-        """
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
-
-        speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
-
-        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
-        text_logits = self.get_logits(conds, text_emb, self.text_head)
-        loss_text = F.cross_entropy(text_logits, text_targets.long())
-        return loss_text.mean()
-
-    def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
-        """
-        Performs autoregressive modeling on only speech data.
-        """
-        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
-
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
-        mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, :max_mel_len*4]
-
-        speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
-
-        mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
-        if raw_mels is not None:
-            mel_inp = F.pad(raw_mels, (0, 4))
-        else:
-            mel_inp = mel_codes
-        mel_emb = self.mel_embedding(mel_inp)
-        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
-        mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
-        loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
-        return loss_mel.mean()
-
-    def inference_speech(self, speech_conditioning_input, text_inputs, return_attentions=False, **hf_generate_kwargs):
+    def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also
             seq_length = 2002 # Arbitrary default.
         else:
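
`text_first` decides which modality plays the prior role during training: with `text_first=True` the sequence is `[conds | text | mel]` and mel is the posterior being predicted; with `text_first=False` the roles swap. A sketch of a caller, mirroring the `__main__` smoke test at the bottom of this file — the 0.01 text-loss weight is hypothetical (the real weighting lives in the training config, not here):

```python
import torch

# Assumes this module's UnifiedVoice is importable; tensor shapes mirror
# the __main__ smoke test below.
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
loss_text, loss_mel, mel_logits = gpt(
    torch.randn(2, 3, 80, 800),               # conditioning mels
    torch.randint(high=256, size=(2, 120)),   # text tokens (the prior here)
    torch.tensor([32, 120]),                  # text lengths
    torch.randint(high=8192, size=(2, 250)),  # mel codes (the posterior)
    torch.tensor([250 * 256, 195 * 256]),     # wav lengths
    types=torch.tensor([0, 1]))
loss = 0.01 * loss_text + loss_mel            # hypothetical weighting: posterior loss dominates
```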
@@ -522,64 +416,17 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
+        conds = conds.mean(dim=1).unsqueeze(1)

         emb = torch.cat([conds, text_emb], dim=1)
-        self.inference_model.store_mel_emb(emb)
+        self.inference_model.store_prior_emb(emb)

         fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.start_mel_token

         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=seq_length, output_attentions=return_attentions, return_dict_in_generate=True, **hf_generate_kwargs)
-        if return_attentions:
-            return gen.sequences[:, fake_inputs.shape[1]:], gen.attentions
-        else:
-            return gen.sequences[:, fake_inputs.shape[1]:]
-
-    # Turns the (utterly insane) output of HF.generate() into a far more sane output:
-    # [tensors(B,H,S,S)]. Outer=layers, B=batch, H=head, S=sequence
-    def make_hf_generate_attentions_sane(self, attentions):
-        layers = [[] for _ in range(len(attentions[0]))]
-        full_attention_size = attentions[-1][0].shape[-1]
-        for i, gen in enumerate(attentions):
-            for j, lyr in enumerate(gen):
-                layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
-        catted = []
-        for lyr in layers:
-            catted.append(torch.cat(lyr, dim=2))
-        return catted
-
-    def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds):
-        """
-        This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice.
-        """
-        text_padding = num_conds+2
-        num_text = text.shape[-1]
-        num_context = num_text + text_padding
-        assert num_context + 1 == attentions[0][0].shape[-1]
-        attentions = self.make_hf_generate_attentions_sane(attentions)
-        results = [torch.empty_like(codes) for _ in range(len(attentions))]
-        for l, layer in enumerate(attentions):
-            dec_context = layer[:, :, num_context:, :]
-            # Mask out everything that isn't text (including the start token, which gets a LOT of attention)
-            dec_context[:,:,:,:text_padding+1] = 0
-            dec_context[:,:,:,num_context:] = 0
-            for h in range(dec_context.shape[1]):
-                dec_context_indices = torch.argmax(dec_context[0,h], dim=-1)
-                print(f'layer_{l};head_{h}: ' + str(dec_context_indices))
-        for t, att_tok in enumerate(attentions):
-            combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
-            for lyr in att_tok:
-                token_to_text_attentions = lyr[:, :, -1, text_padding:(text_padding + num_text)].sum(dim=1)
-                combined_attention_weights = combined_attention_weights + token_to_text_attentions
-                break
-            most_attended_text_token = combined_attention_weights.argmax(dim=-1)
-            results[:, t] = most_attended_text_token
-        eos_token_mask = (codes != self.stop_mel_token)
-        return results * eos_token_mask
+                                            max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs)
+        return gen.sequences[:, fake_inputs.shape[1]:]


 @register_model
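
`inference_speech` hands `generate()` a dummy prompt whose length matches the cached prior so that positions line up; only the last slot carries a real token, the start-of-mel marker. A standalone illustration with small invented sizes:

```python
import torch

start_mel_token = 8192
batch, cond_len, emb_len = 2, 1, 41  # emb = conds + text, per the diff

# Dummy ids occupy the prior positions; for those slots the model only
# ever reads the embeddings cached via store_prior_emb().
fake_inputs = torch.full((batch, cond_len + emb_len), fill_value=1, dtype=torch.long)
fake_inputs[:, -1] = start_mel_token  # the one real token: start of the mel stream
print(fake_inputs.shape)              # torch.Size([2, 42])
```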
@@ -588,11 +435,10 @@ def register_unified_voice2(opt_net, opt):


 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True, types=2)
+    gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
             torch.tensor([250*256,195*256]),
             types=torch.tensor([0, 1]))
     #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))