More cleanup

This commit is contained in:
James Betker 2022-02-04 11:06:17 -07:00
parent 5cc342de66
commit bb3d1ab03d
14 changed files with 58 additions and 2341 deletions

View File

@ -1,330 +0,0 @@
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
class ResBlock(nn.Module):
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan)
)
def forward(self, x):
return F.relu(self.net(x) + x)
class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=5, padding=2),
ResBlock(channels//4),
ResBlock(channels//4),
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//16, channels//2),
nn.ReLU(),
ResBlock(channels//2),
ResBlock(channels//2),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels),
nn.ReLU(),
ResBlock(channels),
ResBlock(channels)
)
def forward(self, x):
return self.encoder(x)
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, norm, linear):
super().__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.lm_head = nn.Sequential(norm, linear)
# Model parallel
self.model_parallel = False
self.device_map = None
self.cached_mel_emb = None
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.transformer.h))
self.transformer.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.transformer.first_device)
self.model_parallel = True
def deparallelize(self):
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.cached_mel_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Create embedding
mel_len = self.cached_mel_emb.shape[1]
if input_ids.shape[1] != 1:
text_inputs = input_ids[:, mel_len:]
text_emb = self.transformer.get_input_embeddings()(text_inputs)
if self.text_pos_embedding is not None:
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
else:
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
else:
emb = self.transformer.get_input_embeddings()(input_ids)
if self.text_pos_embedding is not None:
emb = emb + self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past, beam_idx):
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
class GptAsrHf(nn.Module):
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000, checkpointing=True):
super().__init__()
self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
self.max_symbols_per_phrase = max_symbols_per_phrase
self.model_dim = model_dim
self.max_mel_frames = self.max_mel_frames
self.mel_encoder = MelEncoder(model_dim)
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_TEXT_TOKENS,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
self.gpt = GPT2Model(self.gpt_config)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
def get_logits(self, mel_inputs, text_targets, get_attns=False):
# Pad front and back. Pad at front is the "START" token.
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
text_emb = self.gpt.get_input_embeddings()(text_targets)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([mel_emb, text_emb], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state
text_logits = self.final_norm(enc[:, self.max_mel_frames:])
text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1)
return text_logits
def forward(self, mel_inputs, text_targets, return_attentions=False):
text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
if return_attentions:
return text_logits # These weren't really the logits.
loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean(), text_logits
def inference(self, mel_inputs, cond_text=None, do_sample=False, temperature=1.0, num_beams=8):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
mel_emb = self.mel_encoder(mel_inputs)
assert mel_emb.shape[-1] <= self.max_mel_frames
mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
self.inference_model.store_mel_emb(mel_emb)
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
if cond_text is None:
fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs[:,-1] = self.NUMBER_SYMBOLS
else:
cond_used = 10
fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS
fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
return gen[:, self.max_mel_frames:]
@register_model
def register_gpt_asr_hf(opt_net, opt):
return GptAsrHf(**opt_get(opt_net, ['kwargs'], {}))
# Quick script that loads a model and halves the number of layers, then saves that model.
def distill():
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8)
gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth'))
rc = 0
i = 0
while i < len(gpt.gpt.h):
if rc % 2 != 0:
del gpt.gpt.h[i]
else:
i += 1
rc += 1
torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth')
if __name__ == '__main__':
distill()
'''
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
#l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
start = time()
gpt.inference(torch.randn(1,80,350), num_beams=1)
print(f"Elapsed: {time()-start}")
'''
'''
with torch.no_grad():
t = torch.randn(1,80,800).cuda()
start = time()
s = gpt.inference_beam_topk(t)
print(time()-start)
start = time()
o = gpt.inference_beam_topk(t, fn='inference_beam_opt')
print(time()-start)
'''

View File

@ -1,396 +0,0 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from trainer.networks import register_model
from utils.util import opt_get
class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan)
)
def forward(self, x):
return F.relu(self.net(x) + x)
class LeanMelEncoder(nn.Module):
"""
Encodes a BxCxS MEL tensor into a latent space suitable for use with a transformer.
"""
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=1):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//2, kernel_size=5, stride=2, padding=1),
nn.GroupNorm(channels//16, channels//2),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
self.reduction = 8
def forward(self, x):
for e in self.encoder:
x = e(x)
return x
def null_position_embeddings(range, dim):
"""
Helper method which simply returns a range-shaped tensor filled with zeros. Useful for emulating a no-effect
embedding.
"""
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, norm, linear):
super().__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.lm_head = nn.Sequential(norm, linear)
# Model parallel
self.model_parallel = False
self.device_map = None
self.cached_mel_emb = None
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.transformer.h))
self.transformer.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.transformer.first_device)
self.model_parallel = True
def deparallelize(self):
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.cached_mel_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Create embedding
mel_len = self.cached_mel_emb.shape[1]
if input_ids.shape[1] != 1:
text_inputs = input_ids[:, mel_len:]
text_emb = self.transformer.get_input_embeddings()(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device))
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
else:
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
else:
emb = self.transformer.get_input_embeddings()(input_ids) + \
self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past, beam_idx):
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
class GptAsrHf2(nn.Module):
"""
Core module that encapsulates a set of embeddings, a MEL encoder, a GPT-style transformer and the head needed to
make its output useful.
"""
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000,
checkpointing=True, number_text_tokens=512, start_token=511, stop_token=0, mel_compression=256):
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_token = start_token
self.stop_token = stop_token
self.max_symbols_per_phrase = max_symbols_per_phrase
self.model_dim = model_dim
self.mel_encoder = LeanMelEncoder(model_dim)
self.max_mel_frames = max_mel_frames // self.mel_encoder.reduction
self.mel_compression = mel_compression
seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
self.gpt = GPT2Model(self.gpt_config)
# Override the built in positional embeddings
del self.gpt.wpe
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# This model uses its own positional embeddings, which helps discriminate between text and audio MELs.
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
self.text_solo_embedding = nn.Parameter(torch.randn(1,1,model_dim) * self.gpt.config.initializer_range, requires_grad=True)
# Head layers
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
# Initialize the embeddings per the GPT-2 scheme
for module in [self.text_pos_embedding, self.mel_pos_embedding]:
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
"""
Helper function for producing inputs and outputs for the GPT model.
"""
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
return inp, tar
def get_logits(self, mel_inputs, text_emb, get_attns=False):
"""
Helper function for producing text logits.
"""
if mel_inputs is None:
emb = text_emb
mel_len = 0
else:
mel_emb = self.mel_encoder(mel_inputs)
assert mel_emb.shape[-1] <= self.max_mel_frames, f'{mel_emb.shape[-1]} > {self.max_mel_frames}'
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([mel_emb, text_emb], dim=1)
mel_len = mel_emb.shape[1]
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state
text_logits = self.final_norm(enc[:, mel_len:])
text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1)
return text_logits
def forward(self, mel_inputs, wav_lengths, text_inputs, text_lengths, return_attentions=False):
"""
"Normal" forward pass which produces a text loss when given a MEL-encoded audio clip and transcribed text
targets.
"""
assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())
# Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches
# which are padded at the macro-batch level.
max_text_len = text_lengths.max()
text_inputs = text_inputs[:, :max_text_len]
max_mel_len = wav_lengths.max() // self.mel_compression
mel_inputs = mel_inputs[:, :, :max_mel_len]
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
text_logits = self.get_logits(mel_inputs, text_emb, get_attns=return_attentions)
if return_attentions:
return text_logits # These weren't really the logits.
loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean(), text_logits
def text_only(self, text_inputs, text_lengths):
"""
Used to train on only text inputs.
"""
assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())
# Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches
# which are padded at the macro-batch level.
max_text_len = text_lengths.max()
text_inputs = text_inputs[:, :max_text_len]
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \
self.text_solo_embedding
text_logits = self.get_logits(None, text_emb)
loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean(), text_logits
def inference(self, mel_inputs, wav_lengths, do_sample=False, temperature=1.0, num_beams=8):
"""
Performs inference by transcribing mel_inputs into text. Returns the text tokens.
"""
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
# TODO: get rid of this..
max_mel_len = wav_lengths.max() // self.mel_compression
mel_inputs = mel_inputs[:, :, :max_mel_len]
mel_emb = self.mel_encoder(mel_inputs)
assert mel_emb.shape[-1] <= self.max_mel_frames
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
self.inference_model.store_mel_emb(mel_emb)
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
fake_inputs[:,-1] = self.start_token
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0,
max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
return gen[:, mel_emb.shape[1]+1:]
@register_model
def register_gpt_asr_hf2(opt_net, opt):
return GptAsrHf2(**opt_get(opt_net, ['kwargs'], {}))
# Quick script that loads a model and halves the number of layers, then saves that model.
def distill():
gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8)
gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth'))
rc = 0
i = 0
while i < len(gpt.gpt.h):
if rc % 2 != 0:
del gpt.gpt.h[i]
else:
i += 1
rc += 1
torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth')
if __name__ == '__main__':
#distill()
gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
l = gpt(torch.randn(2,80,640), torch.tensor([100*256,20*256]), torch.randint(high=100, size=(2,80)), torch.tensor([15,60]))
gpt.text_only(torch.randint(high=100, size=(2,120)), torch.tensor([30,33]))
#start = time()
#gpt.inference(torch.randn(1,80,350), num_beams=1)
#print(f"Elapsed: {time()-start}")

View File

@ -1,159 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config
from models.arch_util import AttentionBlock
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
class ConditioningEncoder(nn.Module):
def __init__(self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
def forward(self, x):
h = self.init(x)
h = self.attn(h)
return h[:, :, 0]
class GptTtsHf(nn.Module):
NUMBER_TEXT_TOKENS = 256 # The number of tokens produced by our bespoke BPE tokenizer.
START_TEXT_TOKEN = 255
STOP_TEXT_TOKEN = 0
NUMBER_MEL_CODES = 8194
START_MEL_TOKEN = 8192
STOP_MEL_TOKEN = 8193
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=80, max_mel_tokens=250, max_conditioning_inputs=3,
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60):
super().__init__()
self.max_mel_tokens = max_mel_tokens
self.max_symbols_per_phrase = max_symbols_per_phrase
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
self.gpt = GPT2Model(self.gpt_config)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES)
self.max_conditioning_length = max_conditioning_length
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
return inp, tar
def get_logits(self, text_inputs, cond_input, mel_inputs, get_attns=False):
text_emb = self.text_embedding(text_inputs)
cond = self.conditioning_encoder(cond_input).unsqueeze(1)
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
emb = torch.cat([text_emb, cond, mel_emb], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state
text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1)
mel_logits = self.final_norm(enc[:, -mel_emb.shape[1]:])
mel_logits = self.mel_head(mel_logits)
mel_logits = mel_logits.permute(0,2,1)
return text_logits, mel_logits
def forward(self, text_inputs, cond_input, mel_targets, wav_lengths, return_attentions=False):
"""
Forward pass
text_inputs: long tensor, (b,t)
cond_inputs: MEL float tensor, (b,c,80,s)
mel_targets: long tensor, (b,m)
mel_lengths: long tensor, (b,)
"""
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
mel_lengths = wav_lengths // self.mel_length_compression
for b in range(len(mel_lengths)):
if mel_lengths[b] < mel_targets.shape[-1]:
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
# Randomly permute the conditioning spectrogram, to destroy any structure present.
cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
if cond_input.shape[-1] > self.max_conditioning_length:
cond_input = cond_input[:,:,:self.max_conditioning_length]
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
text_logits, mel_logits = self.get_logits(text_inputs, cond_input, mel_inputs, get_attns=return_attentions)
if return_attentions:
return mel_logits
loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference(self, text_inputs, cond_input, **hf_generate_kwargs):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
text_emb = self.text_embedding(text_inputs)
# Randomly permute the conditioning spectrogram, to destroy any structure present.
cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
if cond_input.shape[-1] > self.max_conditioning_length:
cond_input = cond_input[:,:,:self.max_conditioning_length]
cond = self.conditioning_encoder(cond_input).unsqueeze(1)
emb = torch.cat([text_emb, cond], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.START_MEL_TOKEN
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs)
return gen[:, fake_inputs.shape[1]:]
@register_model
def register_gpt_tts_hf(opt_net, opt):
return GptTtsHf(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
gpt = GptTtsHf(model_dim=1024, heads=16)
l = gpt(torch.randint(high=len(symbols), size=(2,200)),
torch.arange(0, 80, 1, dtype=torch.float).view(1,80,1).repeat(2,1,800),
torch.randint(high=8192, size=(2,250)),
torch.tensor([150*256,195*256]))

View File

@ -1,49 +0,0 @@
import torch
# "long" and "short" denote longer and shorter samples
class PixelShuffle1D(torch.nn.Module):
"""
1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf
Upscales sample length, downscales channel length
"short" is input, "long" is output
"""
def __init__(self, upscale_factor):
super(PixelShuffle1D, self).__init__()
self.upscale_factor = upscale_factor
def forward(self, x):
batch_size = x.shape[0]
short_channel_len = x.shape[1]
short_width = x.shape[2]
long_channel_len = short_channel_len // self.upscale_factor
long_width = self.upscale_factor * short_width
x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width])
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(batch_size, long_channel_len, long_width)
return x
class PixelUnshuffle1D(torch.nn.Module):
"""
Inverse of 1D pixel shuffler
Upscales channel length, downscales sample length
"long" is input, "short" is output
"""
def __init__(self, downscale_factor):
super(PixelUnshuffle1D, self).__init__()
self.downscale_factor = downscale_factor
def forward(self, x):
batch_size = x.shape[0]
long_channel_len = x.shape[1]
long_width = x.shape[2]
short_channel_len = long_channel_len * self.downscale_factor
short_width = long_width // self.downscale_factor
x = x.contiguous().view([batch_size, long_channel_len, short_width, self.downscale_factor])
x = x.permute(0, 3, 1, 2).contiguous()
x = x.view([batch_size, short_channel_len, short_width])
return x

View File

@ -1,394 +0,0 @@
import random
from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
Downsample, Upsample
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from trainer.networks import register_model
from utils.util import get_mask_from_lengths
class DiscreteSpectrogramConditioningBlock(nn.Module):
def __init__(self, dvae_channels, channels):
super().__init__()
self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
normalization(channels),
nn.SiLU(),
nn.Conv1d(channels, channels, kernel_size=3))
"""
Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
:param x: bxcxS waveform latent
:param codes: bxN discrete codes, N <= S
"""
def forward(self, x, dvae_in):
b, c, S = x.shape
_, q, N = dvae_in.shape
emb = self.intg(dvae_in)
emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
return torch.cat([x, emb], dim=1)
class DiffusionVocoderWithRefTruncatedTop(nn.Module):
"""
The full UNet model with attention and timestep embedding.
Customized to be conditioned on a spectrogram prior.
:param in_channels: channels in the input Tensor.
:param spectrogram_channels: channels in the conditioning spectrogram.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
model_channels,
in_channels=1,
out_channels=2, # mean and variance
discrete_codes=512,
dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
spectrogram_conditioning_resolutions=(512,),
attention_resolutions=(512,1024,2048),
conv_resample=True,
dims=1,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
kernel_size=3,
scale_factor=2,
conditioning_inputs_provided=True,
conditioning_input_dim=80,
time_embed_dim_multiplier=4,
only_train_dvae_connection_layers=False,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.dtype = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.dims = dims
padding = 1 if kernel_size == 3 else 2
time_embed_dim = model_channels * time_embed_dim_multiplier
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.conditioning_enabled = conditioning_inputs_provided
if conditioning_inputs_provided:
self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
self.cheater_input_block = TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels//2, kernel_size, padding=padding, stride=2))
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, model_channels//2, model_channels, kernel_size, padding=padding)
)
]
)
spectrogram_blocks = []
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
if ds in spectrogram_conditioning_resolutions:
spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch)
self.input_blocks.append(spec_cond_block)
spectrogram_blocks.append(spec_cond_block)
ch *= 2
for _ in range(num_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=int(mult * model_channels),
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
kernel_size=kernel_size,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
for i in range(num_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=int(model_channels * mult),
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
)
]
ch = int(model_channels * mult)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
if level and i == num_blocks:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
kernel_size=kernel_size,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
# These are the special input and output blocks that are pseudo-disconnected from the rest of the graph,
# allowing them to be trained on a smaller subset of input.
self.top_inp_raw = TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
)
self.top_inp_blocks = nn.ModuleList([TimestepEmbedSequential(ResBlock(
model_channels,
time_embed_dim,
dropout,
out_channels=model_channels,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
)) for _ in range(num_blocks)])
self.top_out_upsample = TimestepEmbedSequential(ResBlock(
model_channels,
time_embed_dim,
dropout,
out_channels=model_channels,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
kernel_size=kernel_size,
) if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=model_channels, factor=scale_factor))
self.top_out_blocks = nn.ModuleList([TimestepEmbedSequential(ResBlock(
2 * model_channels,
time_embed_dim,
dropout,
out_channels=model_channels,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size,
)) for _ in range(num_blocks)
])
self.top_out_final = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
)
if only_train_dvae_connection_layers:
for p in self.parameters():
p.DO_NOT_TRAIN = True
p.requires_grad = False
for sb in spectrogram_blocks:
for p in sb.parameters():
del p.DO_NOT_TRAIN
p.requires_grad = True
def forward(self, x, timesteps, spectrogram, conditioning_input=None):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs, halved in size and the bounds of the original input that was halved.
"""
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
if self.conditioning_enabled:
assert conditioning_input is not None
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.conditioning_enabled:
emb2 = self.contextual_embedder(conditioning_input)
emb = emb1 + emb2
else:
emb = emb1
# Handle the top blocks first, independently of the rest of the unet. These only process half of x.
if self.training:
rand_start = (random.randint(0, x.shape[-1] // 2) // 2) * 2 # Must be a multiple of 2, to align with the next lower layer.
rand_stop = rand_start + x.shape[-1] // 2
else:
rand_start = 0 # When in eval, rand_start:rand_stop spans the entire input.
rand_stop = x.shape[-1]
top_blocks = []
ht = self.top_inp_raw(x.type(self.dtype)[:, :, rand_start:rand_stop], emb)
for block in self.top_inp_blocks:
ht = block(ht, emb)
top_blocks.append(ht)
# Now the standard unet (notice how it doesn't use ht at all, and uses a bare x fed through a strided conv.
h = self.cheater_input_block(x.type(self.dtype), emb)
hs = []
for k, module in enumerate(self.input_blocks):
if isinstance(module, DiscreteSpectrogramConditioningBlock):
h = module(h, spectrogram)
else:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
# And finally the top output blocks, which do consume the unet's outputs as well as the cross-input blocks. First we'll need to only take a subset of the unets output.
hb = h[:, :, rand_start//2:rand_stop//2]
hb = self.top_out_upsample(hb, emb)
for block in self.top_out_blocks:
hb = torch.cat([hb, top_blocks.pop()], dim=1)
hb = block(hb, emb)
hb = hb.type(x.dtype)
return self.top_out_final(hb), rand_start, rand_stop
@register_model
def register_unet_diffusion_vocoder_with_ref_trunc_top(opt_net, opt):
return DiffusionVocoderWithRefTruncatedTop(**opt_net['kwargs'])
# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
clip = torch.randn(2, 1, 40960)
#spec = torch.randint(8192, (2, 40,))
spec = torch.randn(2, 512, 160)
cond = torch.randn(2, 1, 40960)
ts = torch.LongTensor([555, 556])
model = DiffusionVocoderWithRefTruncatedTop(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
print(model(clip, ts, spec, cond))

View File

@ -1,344 +0,0 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config
from models.arch_util import AttentionBlock
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
from models.gpt_voice.gpt_asr_hf2 import ResBlock
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
class ConditioningEncoder(nn.Module):
def __init__(self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
def forward(self, x):
h = self.init(x)
h = self.attn(h)
return h[:, :, 0]
class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//16, channels//2),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
self.reduction = 4
def forward(self, x):
for e in self.encoder:
x = e(x)
return x.permute(0,2,1)
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class UnifiedGptVoice(nn.Module):
"""
Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
- Text only
- Voice only
- Text conditioned on voice
- Voice conditioned on text
"""
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
max_conditioning_length=60, shuffle_conditioning=True, mel_length_compression=1024, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True):
"""
Args:
layers: Number of layers in transformer stack.
model_dim: Operating dimensions of the transformer
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
max_conditioning_length: Maximum length of conditioning input. Only needed if shuffle_conditioning=True
shuffle_conditioning: Whether or not the conditioning inputs will be shuffled across the sequence dimension. Useful if you want to provide the same input as conditioning and mel_codes.
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
start_text_token:
stop_text_token:
number_mel_codes:
start_mel_token:
stop_mel_token:
train_solo_embeddings:
use_mel_codes_as_input:
checkpointing:
"""
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.shuffle_conditioning = shuffle_conditioning
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
self.gpt = GPT2Model(self.gpt_config)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
else:
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
# Override the built in positional embeddings
del self.gpt.wpe
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
if not use_mel_codes_as_input:
self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.max_conditioning_length = max_conditioning_length
# Initialize the embeddings per the GPT-2 scheme
for module in [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding]:
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
return inp, tar
def set_mel_padding(self, mel_input_tokens, wav_lengths):
"""
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
mel_lengths = wav_lengths // self.mel_length_compression
for b in range(len(mel_lengths)):
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
def randomly_permute_conditioning_input(self, speech_conditioning_input):
"""
Randomly permute the conditioning spectrogram, to destroy any structure present. Note that since the
conditioning input is derived from a discrete spectrogram, it does actually retain structure, but only a little
bit (actually: exactly how much we want; enough to discriminate different vocal qualities, but nothing about
what is being said).
"""
cond_input = speech_conditioning_input[:,:,torch.randperm(speech_conditioning_input.shape[-1])]
if cond_input.shape[-1] > self.max_conditioning_length:
cond_input = cond_input[:,:,:self.max_conditioning_length]
return cond_input
def get_logits(self, speech_conditioning_input, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
if second_inputs is not None:
emb = torch.cat([speech_conditioning_input, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_input, first_inputs], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
enc = self.final_norm(enc)
first_logits = enc[:, :first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0,2,1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1]:]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0,2,1)
return first_logits, second_logits
else:
return first_logits
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
speech_conditioning_input: MEL float tensor, (b,80,s)
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
raw_mels: MEL float tensor (b,80,s)
"""
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
if raw_mels is not None:
mel_inp = F.pad(raw_mels, (0, 8))
else:
mel_inp = mel_codes
mel_emb = self.gpt.get_input_embeddings()(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
if text_first:
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
else:
mel_logits, text_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions)
if return_attentions:
return mel_logits
loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
"""
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
"""
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + self.text_solo_embedding
text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean()
def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
"""
Performs autoregressive modeling on only speech data.
"""
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
if raw_mels is not None:
mel_inp = F.pad(raw_mels, (0, 4))
else:
mel_inp = mel_codes
mel_emb = self.gpt.get_input_embeddings()(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_mel.mean()
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
if self.shuffle_conditioning:
# Randomly permute the conditioning spectrogram, to destroy any structure present.
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
emb = torch.cat([cond, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.start_mel_token
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
return gen[:, fake_inputs.shape[1]:]
@register_model
def register_unified_gpt_voice(opt_net, opt):
return UnifiedGptVoice(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True)
l = gpt(torch.randn(2, 80, 800),
torch.randint(high=len(symbols), size=(2,80)),
torch.tensor([32, 80]),
torch.randint(high=8192, size=(2,250)),
torch.tensor([150*256,195*256]))
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))

View File

@ -8,13 +8,30 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from models.arch_util import AttentionBlock
from models.gpt_voice.gpt_asr_hf2 import ResBlock
from models.gpt_voice.transformer_builders import build_hf_gpt_transformer
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan)
)
def forward(self, x):
return F.relu(self.net(x) + x)
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
super().__init__(config)

View File

@ -1,313 +0,0 @@
import functools
from math import log
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config
from models.arch_util import AttentionBlock
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class ConditioningEncoder(nn.Module):
def __init__(self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
def forward(self, x):
h = self.init(x)
h = self.attn(h)
return h[:, :, 0]
class TopEncoder(nn.Module):
def __init__(self, layers, dim, heads, do_checkpointing=False, dim_reduction=16):
self.init = nn.Conv1d(dim, dim, kernel_size=1)
reduction_layers = []
for j in range(int(log(dim_reduction, 2))):
reduction_layers.append(AttentionBlock(dim, heads, do_checkpoint=do_checkpointing))
reduction_layers.append(nn.Conv1d(dim, dim, kernel_size=3, padding=1, stride=2))
self.reduction_layers = nn.Sequential(*reduction_layers)
actual_layers = [AttentionBlock(dim, heads, do_checkpoint=do_checkpointing) for _ in range(layers)]
self.actual_layers = nn.Sequential(*actual_layers)
def forward(self, x):
h = self.init(x)
h = self.reduction_layers(h)
h = self.actual_layers(h)
return h
class UnifiedGptVoice(nn.Module):
"""
Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
- Text only
- Voice only
- Text conditioned on voice
- Voice conditioned on text
"""
def __init__(self, top_encoder_layers=4, top_layers=8, bottom_layers=8, top_dim_reduction=16, model_dim=512, heads=8,
max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193):
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.max_mel_tokens = max_mel_tokens
self.max_symbols_per_phrase = max_symbols_per_phrase
self.max_total_tokens = max_total_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.text_pos_solo_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.text_pos_paired_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.mel_pos_solo_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
self.mel_pos_paired_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
self.top_encoder = TopEncoder(top_encoder_layers, model_dim, heads, do_checkpointing=checkpointing,
dim_reduction=top_dim_reduction)
self.top_gpt_config = GPT2Config(vocab_size=1,
n_positions=seq_length // top_dim_reduction,
n_ctx=seq_length // top_dim_reduction,
n_embd=model_dim,
n_layer=top_layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
self.top_gpt = GPT2Model(self.top_gpt_config)
del self.top_gpt.wte
self.top_gpt_start_embedding = nn.Parameter(torch.randn(1,1,model_dim)*self.top_gpt_config.initializer_range,
requires_grad=True)
self.top_dim_reduction = top_dim_reduction
self.bottom_gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=model_dim,
n_layer=bottom_layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
self.bottom_gpt = GPT2Model(self.bottom_gpt_config)
# Override the built in positional embeddings
del self.bottom_gpt.wpe
self.bottom_gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.max_conditioning_length = max_conditioning_length
# Initialize the embeddings per the GPT-2 scheme
for module in [self.text_embedding, self.text_pos_solo_embedding, self.text_pos_paired_embedding,
self.mel_pos_solo_embedding, self.mel_pos_paired_embedding]:
module.weight.data.normal_(mean=0.0, std=self.bottom_gpt.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
return inp, tar
def set_mel_padding(self, mel_input_tokens, wav_lengths):
"""
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
mel_lengths = wav_lengths // self.mel_length_compression
for b in range(len(mel_lengths)):
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
def randomly_permute_conditioning_input(self, speech_conditioning_input):
"""
Randomly permute the conditioning spectrogram, to destroy any structure present. Note that since the
conditioning input is derived from a discrete spectrogram, it does actually retain structure, but only a little
bit (actually: exactly how much we want; enough to discriminate different vocal qualities, but nothing about
what is being said).
"""
cond_input = speech_conditioning_input[:,:,torch.randperm(speech_conditioning_input.shape[-1])]
if cond_input.shape[-1] > self.max_conditioning_length:
cond_input = cond_input[:,:,:self.max_conditioning_length]
return cond_input
def get_top_embeddings(self, embedded_input):
true_embeddings = self.top_encoder(embedded_input)
inputs = torch.cat([self.top_gpt_start_embedding, true_embeddings[:,:-1]], dim=1)
top_pred = self.top_gpt(inputs_embeds=inputs, return_dict=True)
return top_pred.last_hidden_state, true_embeddings
def inject_top_embeddings(self, embedded_input, probability_of_true_top_embedding=.5):
pred, true = self.get_top_embeddings(embedded_input)
rand = torch.bernoulli(torch.full((1,embedded_input.shape[1]),
fill_value=probability_of_true_top_embedding)).to(embedded_input.device)
mix = pred * rand + true * (not rand)
embs = torch.chunk(embedded_input, self.top_dim_reduction, dim=1)
assert len(embs) == mix.shape[1]
rejoin = []
for i, emb in enumerate(embs):
rejoin.append(torch.cat([mix[i], emb]), dim=1)
return torch.cat(rejoin, dim=1)
def get_logits(self, speech_conditioning_input, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
if second_inputs is not None:
emb = torch.cat([speech_conditioning_input, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_input, first_inputs], dim=1)
gpt_out = self.bottom_gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
enc = self.final_norm(enc)
first_logits = enc[:, :first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0,2,1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1]:]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0,2,1)
return first_logits, second_logits
else:
return first_logits
def forward(self, speech_conditioning_input, text_inputs, mel_inputs, wav_lengths, text_first=True, return_attentions=False):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
speech_conditioning_input: MEL float tensor, (b,80,s)
text_inputs: long tensor, (b,t)
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
"""
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
mel_emb = self.bottom_gpt.get_input_embeddings()(mel_inputs)
mel_emb = mel_emb + self.mel_pos_paired_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
if text_first:
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
else:
mel_logits, text_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions)
if return_attentions:
return mel_logits
loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def text_forward(self, speech_conditioning_input, text_inputs):
"""
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
"""
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_solo_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean()
def speech_forward(self, speech_conditioning_input, mel_inputs, wav_lengths):
"""
Performs autoregressive modeling on only speech data.
"""
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
mel_emb = self.bottom_gpt.get_input_embeddings()(mel_inputs)
mel_emb = mel_emb + self.mel_pos_solo_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_mel.mean()
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.bottom_gpt_config, self.bottom_gpt, self.mel_pos_paired_embedding, self.final_norm, self.mel_head)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
# Randomly permute the conditioning spectrogram, to destroy any structure present.
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
emb = torch.cat([cond, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.start_mel_token
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
max_length=self.bottom_gpt_config.n_positions, **hf_generate_kwargs)
return gen[:, fake_inputs.shape[1]:]
@register_model
def register_unified_gpt_voice_bilevel(opt_net, opt):
return UnifiedGptVoice(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
gpt = UnifiedGptVoice(model_dim=256, heads=4)
l = gpt(torch.randn(2, 80, 800),
torch.randint(high=len(symbols), size=(2,80)),
torch.randint(high=8192, size=(2,250)),
torch.tensor([150*256,195*256]))

View File

@ -1,88 +0,0 @@
import os
import os.path as osp
import logging
import random
import argparse
import torchvision
import utils
import utils.options as option
import utils.util as util
from models.tacotron2.text import sequence_to_text
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
from scipy.io import wavfile
def forward_pass(model, data, output_dir, opt, macro_b, dataset):
with torch.no_grad():
model.feed_data(data, 0)
model.test()
gt_key = opt['eval']['gen_text']
txts = []
for b in range(model.eval_state[gt_key][0].shape[0]):
if 'real_text' in opt['eval'].keys():
real = data[opt['eval']['real_text']][b]
print(f'{macro_b} {b} Real text: "{real}"')
codes = model.eval_state[opt['eval']['gen_text']][0][b].cpu()
if hasattr(dataset, 'tokenizer'):
text = dataset.tokenizer.decode(codes.numpy())
text = text.replace(' $$$', '')
txts.append(text)
else:
txts.append(sequence_to_text(codes))
return txts
if __name__ == "__main__":
# Set seeds
torch.manual_seed(5555)
random.seed(5555)
np.random.seed(5555)
#### options
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_hf2.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
util.mkdirs(
(path for key, path in opt['path'].items()
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
dataset_opt = opt['datasets']['val']
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
model = ExtensibleTrainer(opt)
batch = 0
output = open('results.tsv', 'w')
dataset_dir = opt['path']['results_root']
util.mkdir(dataset_dir)
for data in tqdm(test_loader):
#if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
# continue
preds = forward_pass(model, data, dataset_dir, opt, batch, test_set)
for b, pred in enumerate(preds):
pred = pred.replace('_', '')
output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n')
print(pred)
batch += 1
output.flush()

View File

@ -1,40 +0,0 @@
import os
import numpy
import torch
import torch.nn as nn
from matplotlib import pyplot
from torch.utils.tensorboard import SummaryWriter
from data.audio.unsupervised_audio_dataset import load_audio
from models.gpt_voice.gpt_asr_hf import GptAsrHf
from models.tacotron2.text import text_to_sequence
from trainer.injectors.base_injectors import MelSpectrogramInjector
if __name__ == '__main__':
audio_data = load_audio('Z:\\split\\classified\\fine\\books1\\2_dchha03 The Organization of Peace\\00010.wav', 22050).unsqueeze(0)
audio_data = torch.nn.functional.pad(audio_data, (0, 358395-audio_data.shape[-1]))
mel_inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {})
mel = mel_inj({'in': audio_data})['out'].cuda()
actual_text = 'and it doesn\'t take very long.'
labels = torch.IntTensor(text_to_sequence(actual_text, ['english_cleaners'])).unsqueeze(0).cuda()
model = GptAsrHf(layers=12, model_dim=512, max_mel_frames=1400, max_symbols_per_phrase=250, heads=8)
model.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\31000_gpt_ema.pth'))
model = model.cuda()
with torch.no_grad():
attentions = model(mel, labels, return_attentions=True)
attentions = torch.stack(attentions, dim=0).permute(0,1,2,4,3)[:, :, :, -model.max_symbols_per_phrase:, :model.max_mel_frames]
attentions = attentions.sum(0).sum(1).squeeze()
xs = [str(i) for i in range(1, model.max_mel_frames+1, 1)]
os.makedirs('results', exist_ok=True)
logger = SummaryWriter('results')
for e, character_attn in enumerate(attentions):
if e >= len(actual_text):
break
fig = pyplot.figure()
ax = fig.add_axes([0,0,1,1])
ax.bar(xs, character_attn.cpu().numpy())
logger.add_figure(f'{e}_{actual_text[e]}', fig)

View File

@ -114,44 +114,56 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_tts5_medium.yml')
parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
parser.add_argument('-diffusion_model_path', type=str, help='Path to saved model weights', default='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\68500_generator_ema.pth')
# -cond "Y:\libritts/train-clean-100/103/1241/103_1241_000017_000001.wav"
parser.add_argument('-cond', type=str, help='Type of conditioning voice', default='simmons')
parser.add_argument('-diffusion_model_path', type=str, help='Path to saved model weights', default='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\73000_generator_ema.pth')
parser.add_argument('-sr_opt', type=str, help='Path to options YAML file used to train the SR diffusion model', default='X:\\dlas\\experiments\\train_diffusion_tts6_upsample.yml')
parser.add_argument('-sr_diffusion_model_name', type=str, help='Name of the SR diffusion model in opt.', default='generator')
parser.add_argument('-sr_diffusion_model_path', type=str, help='Path to saved model weights for the SR diffuser', default='X:\\dlas\\experiments\\train_diffusion_tts6_upsample\\models\\7000_generator_ema.pth')
parser.add_argument('-cond', type=str, help='Type of conditioning voice', default='carlin')
parser.add_argument('-diffusion_steps', type=int, help='Number of diffusion steps to perform to create the generate. Lower steps reduces quality, but >40 is generally pretty good.', default=100)
parser.add_argument('-diffusion_schedule', type=str, help='Type of diffusion schedule that was used', default='cosine')
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='../results/use_diffuse_tts')
parser.add_argument('-sample_rate', type=int, help='Model sample rate', default=5500)
parser.add_argument('-cond_sample_rate', type=int, help='Conditioning sample rate', default=5500)
parser.add_argument('-device', type=str, help='Device to run on', default='cuda')
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)
print("Loading Diffusion Model..")
# Fixed parameters.
base_sample_rate = 5500
sr_sample_rate = 22050
print("Loading Diffusion Models..")
diffusion = load_model_from_config(args.opt, args.diffusion_model_name, also_load_savepoint=False,
load_path=args.diffusion_model_path, device=args.device)
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule=args.diffusion_schedule)
aligned_codes_compression_factor = args.sample_rate * 221 // 11025
cond = load_audio(conditioning_clips[args.cond], args.cond_sample_rate).to(args.device)
if cond.shape[-1] > 88000:
cond = cond[:,:88000]
torchaudio.save(os.path.join(args.output_path, 'cond.wav'), cond.cpu(), args.sample_rate)
load_path=args.diffusion_model_path, device='cpu').eval()
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='cosine')
aligned_codes_compression_factor = base_sample_rate * 221 // 11025
sr_diffusion = load_model_from_config(args.sr_opt, args.sr_diffusion_model_name, also_load_savepoint=False,
load_path=args.sr_diffusion_model_path, device='cpu').eval()
sr_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='linear')
sr_cond = load_audio(conditioning_clips[args.cond], sr_sample_rate).to(args.device)
if sr_cond.shape[-1] > 88000:
sr_cond = sr_cond[:,:88000]
cond = audio = torchaudio.functional.resample(sr_cond, sr_sample_rate, base_sample_rate)
torchaudio.save(os.path.join(args.output_path, 'cond_base.wav'), cond.cpu(), base_sample_rate)
torchaudio.save(os.path.join(args.output_path, 'cond_sr.wav'), sr_cond.cpu(), sr_sample_rate)
for p, code in enumerate(provided_codes):
print("Loading data..")
aligned_codes = torch.tensor(code).to(args.device)
with torch.no_grad():
for p, code in enumerate(provided_codes):
print("Loading data..")
aligned_codes = torch.tensor(code).to(args.device)
with torch.no_grad():
print("Performing inference..")
diffusion.eval()
print("Performing initial diffusion..")
output_shape = (1, 1, ceil_multiple(aligned_codes.shape[-1]*aligned_codes_compression_factor, 2048))
output = diffuser.p_sample_loop(diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device),
diffusion = diffusion.cuda()
output_base = diffuser.p_sample_loop(diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device),
model_kwargs={'tokens': aligned_codes.unsqueeze(0),
'conditioning_input': cond.unsqueeze(0)})
torchaudio.save(os.path.join(args.output_path, f'{p}_output_mean.wav'), output.cpu().squeeze(0), args.sample_rate)
diffusion = diffusion.cpu()
torchaudio.save(os.path.join(args.output_path, f'{p}_output_mean_base.wav'), output_base.cpu().squeeze(0), base_sample_rate)
for k in range(2):
output = diffuser.p_sample_loop(diffusion, output_shape, model_kwargs={'tokens': aligned_codes.unsqueeze(0),
'conditioning_input': cond.unsqueeze(0)})
torchaudio.save(os.path.join(args.output_path, f'{p}_output_{k}.wav'), output.cpu().squeeze(0), args.sample_rate)
print("Performing SR diffusion..")
output_shape = (1, 1, output_base.shape[-1] * (sr_sample_rate // base_sample_rate))
sr_diffusion = sr_diffusion.cuda()
output = diffuser.p_sample_loop(sr_diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device),
model_kwargs={'tokens': aligned_codes.unsqueeze(0),
'conditioning_input': sr_cond.unsqueeze(0),
'lr_input': output_base})
sr_diffusion = sr_diffusion.cpu()
torchaudio.save(os.path.join(args.output_path, f'{p}_output_mean_sr.wav'), output.cpu().squeeze(0), sr_sample_rate)

View File

@ -1,68 +0,0 @@
import os
import os.path as osp
import logging
import random
import argparse
import torchvision
import utils
import utils.options as option
import utils.util as util
from models.waveglow.denoiser import Denoiser
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
from scipy.io import wavfile
if __name__ == "__main__":
# Set seeds
torch.manual_seed(5555)
random.seed(5555)
np.random.seed(5555)
#### options
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/generate_quantized_mels.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
util.mkdirs(
(path for key, path in opt['path'].items()
if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
test_loaders.append(test_loader)
model = ExtensibleTrainer(opt)
outpath = opt['path']['results_root']
os.makedirs(os.path.join(outpath, 'quantized_mels'), exist_ok=True)
for test_loader in test_loaders:
dataset_dir = opt['path']['results_root']
util.mkdir(dataset_dir)
tq = tqdm(test_loader)
for data in tq:
with torch.no_grad():
model.feed_data(data, 0)
model.test()
wavfiles = data['filenames']
quantized = model.eval_state[opt['eval']['quantized_mels']][0]
for i, filename in enumerate(wavfiles):
qmelfile = filename.replace('wavs/', 'quantized_mels/') + '.pth'
torch.save(quantized[i], os.path.join(outpath, qmelfile))

View File

@ -1,32 +0,0 @@
# Combines all libriTTS WAV->text mappings into a single file
import os
from tqdm import tqdm
if __name__ == '__main__':
libri_root = 'E:\\audio\\LibriTTS'
basis = 'train-clean-360'
readers = os.listdir(os.path.join(libri_root, basis))
ofile = open(os.path.join(libri_root, f'{basis}_list.txt'), 'w', encoding='utf-8')
for reader_dir in tqdm(readers):
reader = os.path.join(libri_root, basis, reader_dir)
if not os.path.isdir(reader):
continue
for chapter_dir in os.listdir(reader):
chapter = os.path.join(reader, chapter_dir)
if not os.path.isdir(chapter):
continue
id = f'{os.path.basename(reader)}_{os.path.basename(chapter)}'
trans_file = f'{id}.trans.tsv'
with open(os.path.join(chapter, trans_file), encoding='utf-8') as f:
trans_lines = [line.strip().split('\t') for line in f]
for line in trans_lines:
wav_file, raw_text, normalized_text = line
wav_file = '/'.join([basis, reader_dir, chapter_dir, f'{wav_file}.wav'])
if not os.path.exists(os.path.join(libri_root, wav_file)):
print(f'!WARNING could not open {wav_file}')
else:
ofile.write(f'{wav_file}|{normalized_text}\n')
ofile.flush()
ofile.close()

View File

@ -1,99 +0,0 @@
# Combines all libriTTS WAV->text mappings into a single file
import os
import random
import audio2numpy
import torch
from scipy.io import wavfile
from tqdm import tqdm
from utils.audio_resampler import AudioResampler
def secs_to_frames(secs, sr):
return int(secs*sr)
def get_audio_clip(audio, sr, start, end):
start = secs_to_frames(start, sr)
end = secs_to_frames(end, sr)
assert end > start
if end >= audio.shape[0]:
return None
return audio[start:end]
# Produces an audio clip that would produce a MEL spectrogram of length mel_length by parsing parsed_sentences starting
# at starting_index and moving forwards until the full length is finished.
# Returns:
# On failure, returns tuple: (end_index, None, [], [])
# On success: returns tuple: (end_index, clip, start_points, end_points)
# clip.shape = (<mel_length*256>,)
# start_points = list(ints) where each sentence in the clip starts
# end_points = list(ints) where each sentence in the clip ends
def gather_clip(audio, parsed_sentences, starting_index, sr, mel_length):
audio_length = (mel_length * 256) / sr # This is technically a hyperparameter, but I have no intent of changing the MEL hop length.
starts = []
ends = []
start, end = parsed_sentences[starting_index][4:6]
start = float(start)
end = float(end)
clipstart = max(start - random.random() * 2, 0) # Offset start backwards by up to 2 seconds
clipend = start + audio_length
clip = get_audio_clip(audio, sr, clipstart, clipend)
if clip is not None:
# Fetch the start and endpoints that go along with this clip.
starts.append(secs_to_frames(start-clipstart, sr))
while end < clipend:
ends.append(secs_to_frames(end-clipstart, sr))
starting_index += 1
if starting_index >= len(parsed_sentences):
break
start, end = parsed_sentences[starting_index][4:6]
start = float(start)
end = float(end)
if start < clipend:
starts.append(secs_to_frames(start-clipstart, sr))
return starting_index+1, clip, starts, ends
if __name__ == '__main__':
full_book_root = 'D:\\data\\audio\\libritts\\full_books\\mp3'
libri_root = 'D:\\data\\audio\\libritts\\test-clean'
desired_mel_length = 2000
desired_audio_sample_rate = 22050
output_dir = 'D:\\data\\audio\\libritts\\stop_dataset_eval'
os.makedirs(output_dir, exist_ok=True)
j = 0
readers = os.listdir(libri_root)
for it, reader_dir in enumerate(tqdm(readers)):
#if it <= 145: # Hey idiot! If you change this, change j too!
# continue
reader = os.path.join(libri_root, reader_dir)
if not os.path.isdir(reader):
continue
for chapter_dir in os.listdir(reader):
chapter = os.path.join(reader, chapter_dir)
if not os.path.isdir(chapter):
continue
id = f'{os.path.basename(reader)}_{os.path.basename(chapter)}'
book_file = os.path.join(chapter, f'{id}.book.tsv')
if not os.path.exists(book_file):
continue
with open(book_file, encoding='utf-8') as f:
full_chapter, sr = audio2numpy.open_audio(os.path.join(full_book_root, reader_dir, chapter_dir, f'{chapter_dir}.mp3'))
full_chapter = torch.tensor(full_chapter)
if len(full_chapter.shape) > 1:
full_chapter = full_chapter[:, 0] # Only use mono-audio.
resampler = AudioResampler(sr, desired_audio_sample_rate, dtype=torch.float)
full_chapter = resampler(full_chapter.unsqueeze(0)).squeeze(0)
parsed_sentences = [line.strip().split('\t') for line in f]
i = 0
while i < len(parsed_sentences):
i, clip, ns, ne = gather_clip(full_chapter, parsed_sentences, i, desired_audio_sample_rate, desired_mel_length)
if clip is not None:
wavfile.write(os.path.join(output_dir, f'{j}.wav'), desired_audio_sample_rate, clip.cpu().numpy())
torch.save((ns,ne), os.path.join(output_dir, f'{j}_se.pth'))
j += 1