diff --git a/codes/models/gpt_voice/gpt_asr_hf.py b/codes/models/gpt_voice/gpt_asr_hf.py deleted file mode 100644 index a120824a..00000000 --- a/codes/models/gpt_voice/gpt_asr_hf.py +++ /dev/null @@ -1,330 +0,0 @@ -from time import time - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from transformers.utils.model_parallel_utils import get_device_map, assert_device_map - -from models.tacotron2.text import symbols -from trainer.networks import register_model -from utils.util import opt_get - - -class ResBlock(nn.Module): - def __init__(self, chan): - super().__init__() - self.net = nn.Sequential( - nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan//8, chan), - nn.ReLU(), - nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan//8, chan) - ) - - def forward(self, x): - return F.relu(self.net(x) + x) - - -class MelEncoder(nn.Module): - def __init__(self, channels, mel_channels=80): - super().__init__() - self.channels = channels - self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=5, padding=2), - ResBlock(channels//4), - ResBlock(channels//4), - nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels//16, channels//2), - nn.ReLU(), - ResBlock(channels//2), - ResBlock(channels//2), - nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels//8, channels), - nn.ReLU(), - ResBlock(channels), - ResBlock(channels) - ) - - def forward(self, x): - return self.encoder(x) - - -class GPT2InferenceModel(GPT2PreTrainedModel): - def __init__(self, config, gpt, text_pos_emb, norm, linear): - super().__init__(config) - self.transformer = gpt - self.text_pos_embedding = text_pos_emb - self.lm_head = nn.Sequential(norm, linear) - - # Model parallel - self.model_parallel = False - self.device_map = None - self.cached_mel_emb = None - - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.transformer.h)) - self.transformer.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.transformer.first_device) - self.model_parallel = True - - def deparallelize(self): - self.transformer.deparallelize() - self.transformer = self.transformer.to("cpu") - self.lm_head = self.lm_head.to("cpu") - self.model_parallel = False - torch.cuda.empty_cache() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def store_mel_emb(self, mel_emb): - self.cached_mel_emb = mel_emb - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if 
past: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - assert self.cached_mel_emb is not None - assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Create embedding - mel_len = self.cached_mel_emb.shape[1] - if input_ids.shape[1] != 1: - text_inputs = input_ids[:, mel_len:] - text_emb = self.transformer.get_input_embeddings()(text_inputs) - if self.text_pos_embedding is not None: - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device)) - if self.cached_mel_emb.shape[0] != text_emb.shape[0]: - mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0) - else: - mel_emb = self.cached_mel_emb - emb = torch.cat([mel_emb, text_emb], dim=1) - else: - emb = self.transformer.get_input_embeddings()(input_ids) - if self.text_pos_embedding is not None: - emb = emb + self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0) - - transformer_outputs = self.transformer( - inputs_embeds=emb, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.transformer.first_device) - hidden_states = hidden_states.to(self.lm_head.weight.device) - - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + transformer_outputs[1:] - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - @staticmethod - def _reorder_cache(past, beam_idx): - """ - This function is used to re-order the :obj:`past_key_values` cache if - :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is - called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past - ) - - -class GptAsrHf(nn.Module): - NUMBER_SYMBOLS = len(symbols) - NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1 - - def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000, checkpointing=True): - super().__init__() - self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding. - self.max_symbols_per_phrase = max_symbols_per_phrase - - self.model_dim = model_dim - self.max_mel_frames = self.max_mel_frames - self.mel_encoder = MelEncoder(model_dim) - self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) - self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim) - seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames - self.gpt_config = GPT2Config(vocab_size=self.NUMBER_TEXT_TOKENS, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) - self.gpt = GPT2Model(self.gpt_config) - self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) - - - def get_logits(self, mel_inputs, text_targets, get_attns=False): - # Pad front and back. Pad at front is the "START" token. - text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS) - text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1])) - text_emb = self.gpt.get_input_embeddings()(text_targets) - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device)) - mel_emb = self.mel_encoder(mel_inputs) - mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1])) - mel_emb = mel_emb.permute(0,2,1).contiguous() - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - emb = torch.cat([mel_emb, text_emb], dim=1) - gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) - if get_attns: - return gpt_out.attentions - enc = gpt_out.last_hidden_state - text_logits = self.final_norm(enc[:, self.max_mel_frames:]) - text_logits = self.text_head(text_logits) - text_logits = text_logits.permute(0,2,1) - return text_logits - - def forward(self, mel_inputs, text_targets, return_attentions=False): - text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions) - if return_attentions: - return text_logits # These weren't really the logits. - loss_text = F.cross_entropy(text_logits, text_targets.long()) - return loss_text.mean(), text_logits - - def inference(self, mel_inputs, cond_text=None, do_sample=False, temperature=1.0, num_beams=8): - if not hasattr(self, 'inference_model'): - self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head) - - mel_emb = self.mel_encoder(mel_inputs) - assert mel_emb.shape[-1] <= self.max_mel_frames - mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1])) - mel_emb = mel_emb.permute(0,2,1).contiguous() - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - self.inference_model.store_mel_emb(mel_emb) - - # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above. 
- if cond_text is None: - fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) - fake_inputs[:,-1] = self.NUMBER_SYMBOLS - else: - cond_used = 10 - fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device) - fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS - fake_inputs[:, -cond_used:] = cond_text[:, :cond_used] - gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0, - max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True) - return gen[:, self.max_mel_frames:] - - -@register_model -def register_gpt_asr_hf(opt_net, opt): - return GptAsrHf(**opt_get(opt_net, ['kwargs'], {})) - - -# Quick script that loads a model and halves the number of layers, then saves that model. -def distill(): - gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8) - gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth')) - rc = 0 - i = 0 - while i < len(gpt.gpt.h): - if rc % 2 != 0: - del gpt.gpt.h[i] - else: - i += 1 - rc += 1 - torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth') - - -if __name__ == '__main__': - distill() - - ''' - gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8) - #l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100))) - start = time() - gpt.inference(torch.randn(1,80,350), num_beams=1) - print(f"Elapsed: {time()-start}") - ''' - - ''' - with torch.no_grad(): - t = torch.randn(1,80,800).cuda() - start = time() - s = gpt.inference_beam_topk(t) - print(time()-start) - - start = time() - o = gpt.inference_beam_topk(t, fn='inference_beam_opt') - print(time()-start) - ''' - diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py deleted file mode 100644 index 9acad597..00000000 --- a/codes/models/gpt_voice/gpt_asr_hf2.py +++ /dev/null @@ -1,396 +0,0 @@ -import functools - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import GPT2Model, GPT2Config, GPT2PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from transformers.utils.model_parallel_utils import get_device_map, assert_device_map - -from trainer.networks import register_model -from utils.util import opt_get - - -class ResBlock(nn.Module): - """ - Basic residual convolutional block that uses GroupNorm. - """ - def __init__(self, chan): - super().__init__() - self.net = nn.Sequential( - nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan//8, chan), - nn.ReLU(), - nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan//8, chan) - ) - - def forward(self, x): - return F.relu(self.net(x) + x) - - -class LeanMelEncoder(nn.Module): - """ - Encodes a BxCxS MEL tensor into a latent space suitable for use with a transformer. 
- """ - def __init__(self, channels, mel_channels=80, resblocks_per_reduction=1): - super().__init__() - self.channels = channels - self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//2, kernel_size=5, stride=2, padding=1), - nn.GroupNorm(channels//16, channels//2), - nn.ReLU(), - nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels//8, channels), - nn.ReLU(), - nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels, channels, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels//8, channels), - nn.ReLU(), - nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), - ) - self.reduction = 8 - - def forward(self, x): - for e in self.encoder: - x = e(x) - return x - - -def null_position_embeddings(range, dim): - """ - Helper method which simply returns a range-shaped tensor filled with zeros. Useful for emulating a no-effect - embedding. - """ - return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) - - -class GPT2InferenceModel(GPT2PreTrainedModel): - def __init__(self, config, gpt, text_pos_emb, norm, linear): - super().__init__(config) - self.transformer = gpt - self.text_pos_embedding = text_pos_emb - self.lm_head = nn.Sequential(norm, linear) - - # Model parallel - self.model_parallel = False - self.device_map = None - self.cached_mel_emb = None - - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.transformer.h)) - self.transformer.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.transformer.first_device) - self.model_parallel = True - - def deparallelize(self): - self.transformer.deparallelize() - self.transformer = self.transformer.to("cpu") - self.lm_head = self.lm_head.to("cpu") - self.model_parallel = False - torch.cuda.empty_cache() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def store_mel_emb(self, mel_emb): - self.cached_mel_emb = mel_emb - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - use_cache=None, 
- output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - assert self.cached_mel_emb is not None - assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Create embedding - mel_len = self.cached_mel_emb.shape[1] - if input_ids.shape[1] != 1: - text_inputs = input_ids[:, mel_len:] - text_emb = self.transformer.get_input_embeddings()(text_inputs) - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device)) - if self.cached_mel_emb.shape[0] != text_emb.shape[0]: - mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0) - else: - mel_emb = self.cached_mel_emb - emb = torch.cat([mel_emb, text_emb], dim=1) - else: - emb = self.transformer.get_input_embeddings()(input_ids) + \ - self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0) - - transformer_outputs = self.transformer( - inputs_embeds=emb, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.transformer.first_device) - hidden_states = hidden_states.to(self.lm_head.weight.device) - - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + transformer_outputs[1:] - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - @staticmethod - def _reorder_cache(past, beam_idx): - """ - This function is used to re-order the :obj:`past_key_values` cache if - :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is - called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past - ) - - -class GptAsrHf2(nn.Module): - """ - Core module that encapsulates a set of embeddings, a MEL encoder, a GPT-style transformer and the head needed to - make its output useful. 
- """ - def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, - checkpointing=True, number_text_tokens=512, start_token=511, stop_token=0, mel_compression=256): - super().__init__() - self.number_text_tokens = number_text_tokens - self.start_token = start_token - self.stop_token = stop_token - self.max_symbols_per_phrase = max_symbols_per_phrase - self.model_dim = model_dim - self.mel_encoder = LeanMelEncoder(model_dim) - self.max_mel_frames = max_mel_frames // self.mel_encoder.reduction - self.mel_compression = mel_compression - seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames - self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) - self.gpt = GPT2Model(self.gpt_config) - # Override the built in positional embeddings - del self.gpt.wpe - self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) - - # This model uses its own positional embeddings, which helps discriminate between text and audio MELs. - self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) - self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim) - self.text_solo_embedding = nn.Parameter(torch.randn(1,1,model_dim) * self.gpt.config.initializer_range, requires_grad=True) - - # Head layers - self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.number_text_tokens) - - # Initialize the embeddings per the GPT-2 scheme - for module in [self.text_pos_embedding, self.mel_pos_embedding]: - module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def build_aligned_inputs_and_targets(self, input, start_token, stop_token): - """ - Helper function for producing inputs and outputs for the GPT model. - """ - inp = F.pad(input, (1,0), value=start_token) - tar = F.pad(input, (0,1), value=stop_token) - return inp, tar - - def get_logits(self, mel_inputs, text_emb, get_attns=False): - """ - Helper function for producing text logits. - """ - if mel_inputs is None: - emb = text_emb - mel_len = 0 - else: - mel_emb = self.mel_encoder(mel_inputs) - assert mel_emb.shape[-1] <= self.max_mel_frames, f'{mel_emb.shape[-1]} > {self.max_mel_frames}' - mel_emb = mel_emb.permute(0,2,1).contiguous() - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - emb = torch.cat([mel_emb, text_emb], dim=1) - mel_len = mel_emb.shape[1] - gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) - if get_attns: - return gpt_out.attentions - enc = gpt_out.last_hidden_state - text_logits = self.final_norm(enc[:, mel_len:]) - text_logits = self.text_head(text_logits) - text_logits = text_logits.permute(0,2,1) - return text_logits - - def forward(self, mel_inputs, wav_lengths, text_inputs, text_lengths, return_attentions=False): - """ - "Normal" forward pass which produces a text loss when given a MEL-encoded audio clip and transcribed text - targets. - """ - assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1]) - assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max()) - - # Trim off excessive inputs to speed training. 
This might seem odd, but consider that this model is fed microbatches - # which are padded at the macro-batch level. - max_text_len = text_lengths.max() - text_inputs = text_inputs[:, :max_text_len] - max_mel_len = wav_lengths.max() // self.mel_compression - mel_inputs = mel_inputs[:, :, :max_mel_len] - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token) - text_emb = self.gpt.get_input_embeddings()(text_inputs) + \ - self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - text_logits = self.get_logits(mel_inputs, text_emb, get_attns=return_attentions) - - if return_attentions: - return text_logits # These weren't really the logits. - loss_text = F.cross_entropy(text_logits, text_targets.long()) - return loss_text.mean(), text_logits - - def text_only(self, text_inputs, text_lengths): - """ - Used to train on only text inputs. - """ - assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1]) - assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max()) - - # Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches - # which are padded at the macro-batch level. - max_text_len = text_lengths.max() - text_inputs = text_inputs[:, :max_text_len] - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token) - text_emb = self.gpt.get_input_embeddings()(text_inputs) + \ - self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \ - self.text_solo_embedding - text_logits = self.get_logits(None, text_emb) - loss_text = F.cross_entropy(text_logits, text_targets.long()) - return loss_text.mean(), text_logits - - def inference(self, mel_inputs, wav_lengths, do_sample=False, temperature=1.0, num_beams=8): - """ - Performs inference by transcribing mel_inputs into text. Returns the text tokens. - """ - if not hasattr(self, 'inference_model'): - self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head) - - # TODO: get rid of this.. - max_mel_len = wav_lengths.max() // self.mel_compression - mel_inputs = mel_inputs[:, :, :max_mel_len] - - mel_emb = self.mel_encoder(mel_inputs) - assert mel_emb.shape[-1] <= self.max_mel_frames - mel_emb = mel_emb.permute(0,2,1).contiguous() - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - self.inference_model.store_mel_emb(mel_emb) - - # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above. - fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) - fake_inputs[:,-1] = self.start_token - gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0, - max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True) - return gen[:, mel_emb.shape[1]+1:] - - -@register_model -def register_gpt_asr_hf2(opt_net, opt): - return GptAsrHf2(**opt_get(opt_net, ['kwargs'], {})) - - -# Quick script that loads a model and halves the number of layers, then saves that model. 
-def distill(): - gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8) - gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth')) - rc = 0 - i = 0 - while i < len(gpt.gpt.h): - if rc % 2 != 0: - del gpt.gpt.h[i] - else: - i += 1 - rc += 1 - torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth') - - -if __name__ == '__main__': - #distill() - - gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8) - l = gpt(torch.randn(2,80,640), torch.tensor([100*256,20*256]), torch.randint(high=100, size=(2,80)), torch.tensor([15,60])) - gpt.text_only(torch.randint(high=100, size=(2,120)), torch.tensor([30,33])) - - #start = time() - #gpt.inference(torch.randn(1,80,350), num_beams=1) - #print(f"Elapsed: {time()-start}") diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py deleted file mode 100644 index d59543b5..00000000 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import GPT2Model, GPT2Config - -from models.arch_util import AttentionBlock -from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel -from models.tacotron2.text import symbols -from trainer.networks import register_model -from utils.util import opt_get - - -class ConditioningEncoder(nn.Module): - def __init__(self, - spec_dim, - embedding_dim, - attn_blocks=6, - num_attn_heads=4, - do_checkpointing=False): - super().__init__() - attn = [] - self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) - for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) - self.attn = nn.Sequential(*attn) - self.dim = embedding_dim - self.do_checkpointing = do_checkpointing - - def forward(self, x): - h = self.init(x) - h = self.attn(h) - return h[:, :, 0] - - -class GptTtsHf(nn.Module): - NUMBER_TEXT_TOKENS = 256 # The number of tokens produced by our bespoke BPE tokenizer. 
- START_TEXT_TOKEN = 255 - STOP_TEXT_TOKEN = 0 - NUMBER_MEL_CODES = 8194 - START_MEL_TOKEN = 8192 - STOP_MEL_TOKEN = 8193 - - def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=80, max_mel_tokens=250, max_conditioning_inputs=3, - checkpointing=True, mel_length_compression=1024, max_conditioning_length=60): - super().__init__() - - - self.max_mel_tokens = max_mel_tokens - self.max_symbols_per_phrase = max_symbols_per_phrase - self.model_dim = model_dim - self.max_conditioning_inputs = max_conditioning_inputs - self.mel_length_compression = mel_length_compression - self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) - self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim) - seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens - self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) - self.gpt = GPT2Model(self.gpt_config) - self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) - self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES) - self.max_conditioning_length = max_conditioning_length - - - def build_aligned_inputs_and_targets(self, input, start_token, stop_token): - inp = F.pad(input, (1,0), value=start_token) - tar = F.pad(input, (0,1), value=stop_token) - return inp, tar - - def get_logits(self, text_inputs, cond_input, mel_inputs, get_attns=False): - text_emb = self.text_embedding(text_inputs) - cond = self.conditioning_encoder(cond_input).unsqueeze(1) - mel_emb = self.gpt.get_input_embeddings()(mel_inputs) - - emb = torch.cat([text_emb, cond, mel_emb], dim=1) - gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) - if get_attns: - return gpt_out.attentions - enc = gpt_out.last_hidden_state - - text_logits = self.final_norm(enc[:, :text_emb.shape[1]]) - text_logits = self.text_head(text_logits) - text_logits = text_logits.permute(0,2,1) - mel_logits = self.final_norm(enc[:, -mel_emb.shape[1]:]) - mel_logits = self.mel_head(mel_logits) - mel_logits = mel_logits.permute(0,2,1) - - return text_logits, mel_logits - - def forward(self, text_inputs, cond_input, mel_targets, wav_lengths, return_attentions=False): - """ - Forward pass - text_inputs: long tensor, (b,t) - cond_inputs: MEL float tensor, (b,c,80,s) - mel_targets: long tensor, (b,m) - mel_lengths: long tensor, (b,) - """ - # Set padding areas within MEL (currently it is coded with the MEL code for ). - mel_lengths = wav_lengths // self.mel_length_compression - for b in range(len(mel_lengths)): - if mel_lengths[b] < mel_targets.shape[-1]: - mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN - - # Randomly permute the conditioning spectrogram, to destroy any structure present. 
- cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])] - if cond_input.shape[-1] > self.max_conditioning_length: - cond_input = cond_input[:,:,:self.max_conditioning_length] - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN) - mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN) - text_logits, mel_logits = self.get_logits(text_inputs, cond_input, mel_inputs, get_attns=return_attentions) - if return_attentions: - return mel_logits - loss_text = F.cross_entropy(text_logits, text_targets.long()) - loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) - return loss_text.mean(), loss_mel.mean(), mel_logits - - def inference(self, text_inputs, cond_input, **hf_generate_kwargs): - if not hasattr(self, 'inference_model'): - self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head) - - text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN) - text_emb = self.text_embedding(text_inputs) - - # Randomly permute the conditioning spectrogram, to destroy any structure present. - cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])] - if cond_input.shape[-1] > self.max_conditioning_length: - cond_input = cond_input[:,:,:self.max_conditioning_length] - cond = self.conditioning_encoder(cond_input).unsqueeze(1) - - emb = torch.cat([text_emb, cond], dim=1) - self.inference_model.store_mel_emb(emb) - - fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.START_MEL_TOKEN - - gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN, - max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs) - return gen[:, fake_inputs.shape[1]:] - - -@register_model -def register_gpt_tts_hf(opt_net, opt): - return GptTtsHf(**opt_get(opt_net, ['kwargs'], {})) - - -if __name__ == '__main__': - gpt = GptTtsHf(model_dim=1024, heads=16) - l = gpt(torch.randint(high=len(symbols), size=(2,200)), - torch.arange(0, 80, 1, dtype=torch.float).view(1,80,1).repeat(2,1,800), - torch.randint(high=8192, size=(2,250)), - torch.tensor([150*256,195*256])) diff --git a/codes/models/gpt_voice/pixelshuffle_1d.py b/codes/models/gpt_voice/pixelshuffle_1d.py deleted file mode 100644 index 4ff48904..00000000 --- a/codes/models/gpt_voice/pixelshuffle_1d.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch - -# "long" and "short" denote longer and shorter samples -class PixelShuffle1D(torch.nn.Module): - """ - 1D pixel shuffler. 
https://arxiv.org/pdf/1609.05158.pdf - Upscales sample length, downscales channel length - "short" is input, "long" is output - """ - def __init__(self, upscale_factor): - super(PixelShuffle1D, self).__init__() - self.upscale_factor = upscale_factor - - def forward(self, x): - batch_size = x.shape[0] - short_channel_len = x.shape[1] - short_width = x.shape[2] - - long_channel_len = short_channel_len // self.upscale_factor - long_width = self.upscale_factor * short_width - - x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width]) - x = x.permute(0, 2, 3, 1).contiguous() - x = x.view(batch_size, long_channel_len, long_width) - - return x - -class PixelUnshuffle1D(torch.nn.Module): - """ - Inverse of 1D pixel shuffler - Upscales channel length, downscales sample length - "long" is input, "short" is output - """ - def __init__(self, downscale_factor): - super(PixelUnshuffle1D, self).__init__() - self.downscale_factor = downscale_factor - - def forward(self, x): - batch_size = x.shape[0] - long_channel_len = x.shape[1] - long_width = x.shape[2] - - short_channel_len = long_channel_len * self.downscale_factor - short_width = long_width // self.downscale_factor - - x = x.contiguous().view([batch_size, long_channel_len, short_width, self.downscale_factor]) - x = x.permute(0, 3, 1, 2).contiguous() - x = x.view([batch_size, short_channel_len, short_width]) - return x \ No newline at end of file diff --git a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref_trunc_top.py b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref_trunc_top.py deleted file mode 100644 index 8b1997e5..00000000 --- a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref_trunc_top.py +++ /dev/null @@ -1,394 +0,0 @@ -import random - -from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16 -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \ - Downsample, Upsample -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner -from trainer.networks import register_model -from utils.util import get_mask_from_lengths - - -class DiscreteSpectrogramConditioningBlock(nn.Module): - def __init__(self, dvae_channels, channels): - super().__init__() - self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1), - normalization(channels), - nn.SiLU(), - nn.Conv1d(channels, channels, kernel_size=3)) - - """ - Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape. - - :param x: bxcxS waveform latent - :param codes: bxN discrete codes, N <= S - """ - def forward(self, x, dvae_in): - b, c, S = x.shape - _, q, N = dvae_in.shape - emb = self.intg(dvae_in) - emb = nn.functional.interpolate(emb, size=(S,), mode='nearest') - return torch.cat([x, emb], dim=1) - - -class DiffusionVocoderWithRefTruncatedTop(nn.Module): - """ - The full UNet model with attention and timestep embedding. - - Customized to be conditioned on a spectrogram prior. - - :param in_channels: channels in the input Tensor. - :param spectrogram_channels: channels in the conditioning spectrogram. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. 
- :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - def __init__( - self, - model_channels, - in_channels=1, - out_channels=2, # mean and variance - discrete_codes=512, - dropout=0, - # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), - num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), - # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) - # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - spectrogram_conditioning_resolutions=(512,), - attention_resolutions=(512,1024,2048), - conv_resample=True, - dims=1, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - kernel_size=3, - scale_factor=2, - conditioning_inputs_provided=True, - conditioning_input_dim=80, - time_embed_dim_multiplier=4, - only_train_dvae_connection_layers=False, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.dtype = torch.float16 if use_fp16 else torch.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.dims = dims - - padding = 1 if kernel_size == 3 else 2 - - time_embed_dim = model_channels * time_embed_dim_multiplier - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.conditioning_enabled = conditioning_inputs_provided - if conditioning_inputs_provided: - self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1, - attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) - - self.cheater_input_block = TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels//2, kernel_size, padding=padding, stride=2)) - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, model_channels//2, model_channels, kernel_size, padding=padding) - ) - ] - ) - spectrogram_blocks = [] - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - - for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)): - if ds in 
spectrogram_conditioning_resolutions: - spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch) - self.input_blocks.append(spec_cond_block) - spectrogram_blocks.append(spec_cond_block) - ch *= 2 - - for _ in range(num_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=int(mult * model_channels), - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - ) - ] - ch = int(mult * model_channels) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - kernel_size=kernel_size, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]: - for i in range(num_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=int(model_channels * mult), - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - ) - ] - ch = int(model_channels * mult) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads_upsample, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - if level and i == num_blocks: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - kernel_size=kernel_size, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - # These are the special input and output blocks that are pseudo-disconnected from the rest of the graph, - # allowing them to be trained on a smaller subset of input. 
- self.top_inp_raw = TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) - ) - self.top_inp_blocks = nn.ModuleList([TimestepEmbedSequential(ResBlock( - model_channels, - time_embed_dim, - dropout, - out_channels=model_channels, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - )) for _ in range(num_blocks)]) - self.top_out_upsample = TimestepEmbedSequential(ResBlock( - model_channels, - time_embed_dim, - dropout, - out_channels=model_channels, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - kernel_size=kernel_size, - ) if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=model_channels, factor=scale_factor)) - self.top_out_blocks = nn.ModuleList([TimestepEmbedSequential(ResBlock( - 2 * model_channels, - time_embed_dim, - dropout, - out_channels=model_channels, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - )) for _ in range(num_blocks) - ]) - self.top_out_final = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), - ) - - if only_train_dvae_connection_layers: - for p in self.parameters(): - p.DO_NOT_TRAIN = True - p.requires_grad = False - for sb in spectrogram_blocks: - for p in sb.parameters(): - del p.DO_NOT_TRAIN - p.requires_grad = True - - def forward(self, x, timesteps, spectrogram, conditioning_input=None): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs, halved in size and the bounds of the original input that was halved. - """ - assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement. - if self.conditioning_enabled: - assert conditioning_input is not None - - emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - if self.conditioning_enabled: - emb2 = self.contextual_embedder(conditioning_input) - emb = emb1 + emb2 - else: - emb = emb1 - - # Handle the top blocks first, independently of the rest of the unet. These only process half of x. - if self.training: - rand_start = (random.randint(0, x.shape[-1] // 2) // 2) * 2 # Must be a multiple of 2, to align with the next lower layer. - rand_stop = rand_start + x.shape[-1] // 2 - else: - rand_start = 0 # When in eval, rand_start:rand_stop spans the entire input. - rand_stop = x.shape[-1] - top_blocks = [] - ht = self.top_inp_raw(x.type(self.dtype)[:, :, rand_start:rand_stop], emb) - for block in self.top_inp_blocks: - ht = block(ht, emb) - top_blocks.append(ht) - - # Now the standard unet (notice how it doesn't use ht at all, and uses a bare x fed through a strided conv. - h = self.cheater_input_block(x.type(self.dtype), emb) - hs = [] - for k, module in enumerate(self.input_blocks): - if isinstance(module, DiscreteSpectrogramConditioningBlock): - h = module(h, spectrogram) - else: - h = module(h, emb) - hs.append(h) - h = self.middle_block(h, emb) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb) - - # And finally the top output blocks, which do consume the unet's outputs as well as the cross-input blocks. First we'll need to only take a subset of the unets output. 
- hb = h[:, :, rand_start//2:rand_stop//2] - hb = self.top_out_upsample(hb, emb) - for block in self.top_out_blocks: - hb = torch.cat([hb, top_blocks.pop()], dim=1) - hb = block(hb, emb) - - hb = hb.type(x.dtype) - return self.top_out_final(hb), rand_start, rand_stop - - -@register_model -def register_unet_diffusion_vocoder_with_ref_trunc_top(opt_net, opt): - return DiffusionVocoderWithRefTruncatedTop(**opt_net['kwargs']) - - -# Test for ~4 second audio clip at 22050Hz -if __name__ == '__main__': - clip = torch.randn(2, 1, 40960) - #spec = torch.randint(8192, (2, 40,)) - spec = torch.randn(2, 512, 160) - cond = torch.randn(2, 1, 40960) - ts = torch.LongTensor([555, 556]) - model = DiffusionVocoderWithRefTruncatedTop(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8) - print(model(clip, ts, spec, cond)) diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py deleted file mode 100644 index 04ea15aa..00000000 --- a/codes/models/gpt_voice/unified_voice.py +++ /dev/null @@ -1,344 +0,0 @@ -import functools - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import GPT2Model, GPT2Config - -from models.arch_util import AttentionBlock -from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel -from models.gpt_voice.gpt_asr_hf2 import ResBlock -from models.tacotron2.text import symbols -from trainer.networks import register_model -from utils.util import opt_get - - -class ConditioningEncoder(nn.Module): - def __init__(self, - spec_dim, - embedding_dim, - attn_blocks=6, - num_attn_heads=4, - do_checkpointing=False): - super().__init__() - attn = [] - self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) - for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) - self.attn = nn.Sequential(*attn) - self.dim = embedding_dim - self.do_checkpointing = do_checkpointing - - def forward(self, x): - h = self.init(x) - h = self.attn(h) - return h[:, :, 0] - - -class MelEncoder(nn.Module): - def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): - super().__init__() - self.channels = channels - self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1), - nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels//16, channels//2), - nn.ReLU(), - nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels//8, channels), - nn.ReLU(), - nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), - ) - self.reduction = 4 - - - def forward(self, x): - for e in self.encoder: - x = e(x) - return x.permute(0,2,1) - - -def null_position_embeddings(range, dim): - return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) - - -class UnifiedGptVoice(nn.Module): - """ - Derived from GptTtsHf, but offers multiple modes of autoregressive operation: - - Text only - - Voice only - - Text conditioned on voice - - Voice conditioned on text - """ - - def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, - max_conditioning_length=60, shuffle_conditioning=True, mel_length_compression=1024, number_text_tokens=256, - start_text_token=255, stop_text_token=0, number_mel_codes=8194, 
start_mel_token=8192, - stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, - checkpointing=True): - """ - Args: - layers: Number of layers in transformer stack. - model_dim: Operating dimensions of the transformer - heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 - max_text_tokens: Maximum number of text tokens that will be encountered by model. - max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. - max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). - max_conditioning_length: Maximum length of conditioning input. Only needed if shuffle_conditioning=True - shuffle_conditioning: Whether or not the conditioning inputs will be shuffled across the sequence dimension. Useful if you want to provide the same input as conditioning and mel_codes. - mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. - number_text_tokens: - start_text_token: - stop_text_token: - number_mel_codes: - start_mel_token: - stop_mel_token: - train_solo_embeddings: - use_mel_codes_as_input: - checkpointing: - """ - super().__init__() - - self.number_text_tokens = number_text_tokens - self.start_text_token = start_text_token - self.stop_text_token = stop_text_token - self.number_mel_codes = number_mel_codes - self.start_mel_token = start_mel_token - self.stop_mel_token = stop_mel_token - self.shuffle_conditioning = shuffle_conditioning - - self.max_mel_tokens = max_mel_tokens - self.max_text_tokens = max_text_tokens - self.model_dim = model_dim - self.max_conditioning_inputs = max_conditioning_inputs - self.mel_length_compression = mel_length_compression - self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) - self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) - self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim) - self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim) - seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs - self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) - self.gpt = GPT2Model(self.gpt_config) - if train_solo_embeddings: - self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True) - self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True) - else: - self.mel_solo_embedding = 0 - self.text_solo_embedding = 0 - # Override the built in positional embeddings - del self.gpt.wpe - self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) - - if not use_mel_codes_as_input: - self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1) - - self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.number_text_tokens) - self.mel_head = nn.Linear(model_dim, self.number_mel_codes) - self.max_conditioning_length = max_conditioning_length - - # Initialize the embeddings per the GPT-2 scheme - for module in [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding]: - module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range) - if module.padding_idx is 
not None: - module.weight.data[module.padding_idx].zero_() - - def build_aligned_inputs_and_targets(self, input, start_token, stop_token): - inp = F.pad(input, (1,0), value=start_token) - tar = F.pad(input, (0,1), value=stop_token) - return inp, tar - - def set_mel_padding(self, mel_input_tokens, wav_lengths): - """ - Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in - that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required - preformatting to create a working TTS model. - """ - # Set padding areas within MEL (currently it is coded with the MEL code for ). - mel_lengths = wav_lengths // self.mel_length_compression - for b in range(len(mel_lengths)): - actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. - if actual_end < mel_input_tokens.shape[-1]: - mel_input_tokens[b, actual_end:] = self.stop_mel_token - return mel_input_tokens - - def randomly_permute_conditioning_input(self, speech_conditioning_input): - """ - Randomly permute the conditioning spectrogram, to destroy any structure present. Note that since the - conditioning input is derived from a discrete spectrogram, it does actually retain structure, but only a little - bit (actually: exactly how much we want; enough to discriminate different vocal qualities, but nothing about - what is being said). - """ - cond_input = speech_conditioning_input[:,:,torch.randperm(speech_conditioning_input.shape[-1])] - if cond_input.shape[-1] > self.max_conditioning_length: - cond_input = cond_input[:,:,:self.max_conditioning_length] - return cond_input - - def get_logits(self, speech_conditioning_input, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False): - if second_inputs is not None: - emb = torch.cat([speech_conditioning_input, first_inputs, second_inputs], dim=1) - else: - emb = torch.cat([speech_conditioning_input, first_inputs], dim=1) - - gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) - if get_attns: - return gpt_out.attentions - - enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input - enc = self.final_norm(enc) - first_logits = enc[:, :first_inputs.shape[1]] - first_logits = first_head(first_logits) - first_logits = first_logits.permute(0,2,1) - if second_inputs is not None: - second_logits = enc[:, -second_inputs.shape[1]:] - second_logits = second_head(second_logits) - second_logits = second_logits.permute(0,2,1) - return first_logits, second_logits - else: - return first_logits - - def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False): - """ - Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode - (actuated by `text_first`). - - speech_conditioning_input: MEL float tensor, (b,80,s) - text_inputs: long tensor, (b,t) - text_lengths: long tensor, (b,) - mel_inputs: long tensor, (b,m) - wav_lengths: long tensor, (b,) - raw_mels: MEL float tensor (b,80,s) - """ - assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}' - assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}' - - # This model will receive micro-batches with a ton of padding for both the text and MELs. 
Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) - max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) - if raw_mels is not None: - raw_mels = raw_mels[:, :, :max_mel_len*4] - mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - - if self.shuffle_conditioning: - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) - if raw_mels is not None: - mel_inp = F.pad(raw_mels, (0, 8)) - else: - mel_inp = mel_codes - mel_emb = self.gpt.get_input_embeddings()(mel_inp) - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - if text_first: - text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) - else: - mel_logits, text_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions) - - if return_attentions: - return mel_logits - loss_text = F.cross_entropy(text_logits, text_targets.long()) - loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) - return loss_text.mean(), loss_mel.mean(), mel_logits - - def text_forward(self, speech_conditioning_input, text_inputs, text_lengths): - """ - Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the - model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided). - """ - assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}' - - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) - - if self.shuffle_conditioning: - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + self.text_solo_embedding - text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head) - loss_text = F.cross_entropy(text_logits, text_targets.long()) - return loss_text.mean() - - def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None): - """ - Performs autoregressive modeling on only speech data. 
- """ - assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}' - - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) - mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - if raw_mels is not None: - raw_mels = raw_mels[:, :, :max_mel_len*4] - - if self.shuffle_conditioning: - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - - mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) - if raw_mels is not None: - mel_inp = F.pad(raw_mels, (0, 4)) - else: - mel_inp = mel_codes - mel_emb = self.gpt.get_input_embeddings()(mel_inp) - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding - mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head) - loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) - return loss_mel.mean() - - def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs): - if not hasattr(self, 'inference_model'): - self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head) - - text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - - if self.shuffle_conditioning: - # Randomly permute the conditioning spectrogram, to destroy any structure present. 
- speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - - emb = torch.cat([cond, text_emb], dim=1) - self.inference_model.store_mel_emb(emb) - - fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.start_mel_token - - gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, - max_length=self.gpt_config.n_positions, **hf_generate_kwargs) - return gen[:, fake_inputs.shape[1]:] - - -@register_model -def register_unified_gpt_voice(opt_net, opt): - return UnifiedGptVoice(**opt_get(opt_net, ['kwargs'], {})) - - -if __name__ == '__main__': - gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True) - l = gpt(torch.randn(2, 80, 800), - torch.randint(high=len(symbols), size=(2,80)), - torch.tensor([32, 80]), - torch.randint(high=8192, size=(2,250)), - torch.tensor([150*256,195*256])) - gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80])) diff --git a/codes/models/gpt_voice/unified_voice2.py b/codes/models/gpt_voice/unified_voice2.py index 69b61fb7..edbf52e9 100644 --- a/codes/models/gpt_voice/unified_voice2.py +++ b/codes/models/gpt_voice/unified_voice2.py @@ -8,13 +8,30 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.utils.model_parallel_utils import get_device_map, assert_device_map from models.arch_util import AttentionBlock -from models.gpt_voice.gpt_asr_hf2 import ResBlock from models.gpt_voice.transformer_builders import build_hf_gpt_transformer from models.tacotron2.text import symbols from trainer.networks import register_model from utils.util import opt_get +class ResBlock(nn.Module): + """ + Basic residual convolutional block that uses GroupNorm. 
+ """ + def __init__(self, chan): + super().__init__() + self.net = nn.Sequential( + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan//8, chan), + nn.ReLU(), + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan//8, chan) + ) + + def forward(self, x): + return F.relu(self.net(x) + x) + + class GPT2InferenceModel(GPT2PreTrainedModel): def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear): super().__init__(config) diff --git a/codes/models/gpt_voice/unified_voice_bilevel.py b/codes/models/gpt_voice/unified_voice_bilevel.py deleted file mode 100644 index 35a811ad..00000000 --- a/codes/models/gpt_voice/unified_voice_bilevel.py +++ /dev/null @@ -1,313 +0,0 @@ -import functools -from math import log - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import GPT2Model, GPT2Config - -from models.arch_util import AttentionBlock -from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel -from models.tacotron2.text import symbols -from trainer.networks import register_model -from utils.util import opt_get - - -def null_position_embeddings(range, dim): - return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) - - -class ConditioningEncoder(nn.Module): - def __init__(self, - spec_dim, - embedding_dim, - attn_blocks=6, - num_attn_heads=4, - do_checkpointing=False): - super().__init__() - attn = [] - self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) - for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) - self.attn = nn.Sequential(*attn) - self.dim = embedding_dim - self.do_checkpointing = do_checkpointing - - def forward(self, x): - h = self.init(x) - h = self.attn(h) - return h[:, :, 0] - - -class TopEncoder(nn.Module): - def __init__(self, layers, dim, heads, do_checkpointing=False, dim_reduction=16): - self.init = nn.Conv1d(dim, dim, kernel_size=1) - reduction_layers = [] - for j in range(int(log(dim_reduction, 2))): - reduction_layers.append(AttentionBlock(dim, heads, do_checkpoint=do_checkpointing)) - reduction_layers.append(nn.Conv1d(dim, dim, kernel_size=3, padding=1, stride=2)) - self.reduction_layers = nn.Sequential(*reduction_layers) - actual_layers = [AttentionBlock(dim, heads, do_checkpoint=do_checkpointing) for _ in range(layers)] - self.actual_layers = nn.Sequential(*actual_layers) - - def forward(self, x): - h = self.init(x) - h = self.reduction_layers(h) - h = self.actual_layers(h) - return h - - -class UnifiedGptVoice(nn.Module): - """ - Derived from GptTtsHf, but offers multiple modes of autoregressive operation: - - Text only - - Voice only - - Text conditioned on voice - - Voice conditioned on text - """ - - def __init__(self, top_encoder_layers=4, top_layers=8, bottom_layers=8, top_dim_reduction=16, model_dim=512, heads=8, - max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3, - checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256, - start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, - stop_mel_token=8193): - super().__init__() - - self.number_text_tokens = number_text_tokens - self.start_text_token = start_text_token - self.stop_text_token = stop_text_token - self.number_mel_codes = number_mel_codes - self.start_mel_token = start_mel_token - self.stop_mel_token = stop_mel_token - - self.max_mel_tokens = max_mel_tokens - self.max_symbols_per_phrase = 
max_symbols_per_phrase - self.max_total_tokens = max_total_tokens - self.model_dim = model_dim - self.max_conditioning_inputs = max_conditioning_inputs - self.mel_length_compression = mel_length_compression - self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) - self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) - self.text_pos_solo_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) - self.text_pos_paired_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) - self.mel_pos_solo_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim) - self.mel_pos_paired_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim) - seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs - - self.top_encoder = TopEncoder(top_encoder_layers, model_dim, heads, do_checkpointing=checkpointing, - dim_reduction=top_dim_reduction) - self.top_gpt_config = GPT2Config(vocab_size=1, - n_positions=seq_length // top_dim_reduction, - n_ctx=seq_length // top_dim_reduction, - n_embd=model_dim, - n_layer=top_layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) - self.top_gpt = GPT2Model(self.top_gpt_config) - del self.top_gpt.wte - self.top_gpt_start_embedding = nn.Parameter(torch.randn(1,1,model_dim)*self.top_gpt_config.initializer_range, - requires_grad=True) - self.top_dim_reduction = top_dim_reduction - - self.bottom_gpt_config = GPT2Config(vocab_size=self.number_mel_codes, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=model_dim, - n_layer=bottom_layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) - self.bottom_gpt = GPT2Model(self.bottom_gpt_config) - # Override the built in positional embeddings - del self.bottom_gpt.wpe - self.bottom_gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) - - self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.number_text_tokens) - self.mel_head = nn.Linear(model_dim, self.number_mel_codes) - self.max_conditioning_length = max_conditioning_length - - # Initialize the embeddings per the GPT-2 scheme - for module in [self.text_embedding, self.text_pos_solo_embedding, self.text_pos_paired_embedding, - self.mel_pos_solo_embedding, self.mel_pos_paired_embedding]: - module.weight.data.normal_(mean=0.0, std=self.bottom_gpt.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def build_aligned_inputs_and_targets(self, input, start_token, stop_token): - inp = F.pad(input, (1,0), value=start_token) - tar = F.pad(input, (0,1), value=stop_token) - return inp, tar - - def set_mel_padding(self, mel_input_tokens, wav_lengths): - """ - Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in - that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required - preformatting to create a working TTS model. - """ - # Set padding areas within MEL (currently it is coded with the MEL code for zero padding). - mel_lengths = wav_lengths // self.mel_length_compression - for b in range(len(mel_lengths)): - actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. 
- if actual_end < mel_input_tokens.shape[-1]: - mel_input_tokens[b, actual_end:] = self.stop_mel_token - return mel_input_tokens - - def randomly_permute_conditioning_input(self, speech_conditioning_input): - """ - Randomly permute the conditioning spectrogram, to destroy any structure present. Note that since the - conditioning input is derived from a discrete spectrogram, it does actually retain structure, but only a little - bit (actually: exactly how much we want; enough to discriminate different vocal qualities, but nothing about - what is being said). - """ - cond_input = speech_conditioning_input[:,:,torch.randperm(speech_conditioning_input.shape[-1])] - if cond_input.shape[-1] > self.max_conditioning_length: - cond_input = cond_input[:,:,:self.max_conditioning_length] - return cond_input - - - def get_top_embeddings(self, embedded_input): - true_embeddings = self.top_encoder(embedded_input) - inputs = torch.cat([self.top_gpt_start_embedding, true_embeddings[:,:-1]], dim=1) - top_pred = self.top_gpt(inputs_embeds=inputs, return_dict=True) - return top_pred.last_hidden_state, true_embeddings - - - def inject_top_embeddings(self, embedded_input, probability_of_true_top_embedding=.5): - pred, true = self.get_top_embeddings(embedded_input) - rand = torch.bernoulli(torch.full((1,embedded_input.shape[1]), - fill_value=probability_of_true_top_embedding)).to(embedded_input.device) - mix = pred * rand + true * (not rand) - embs = torch.chunk(embedded_input, self.top_dim_reduction, dim=1) - assert len(embs) == mix.shape[1] - rejoin = [] - for i, emb in enumerate(embs): - rejoin.append(torch.cat([mix[i], emb]), dim=1) - return torch.cat(rejoin, dim=1) - - - def get_logits(self, speech_conditioning_input, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False): - if second_inputs is not None: - emb = torch.cat([speech_conditioning_input, first_inputs, second_inputs], dim=1) - else: - emb = torch.cat([speech_conditioning_input, first_inputs], dim=1) - - gpt_out = self.bottom_gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) - if get_attns: - return gpt_out.attentions - - enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input - enc = self.final_norm(enc) - first_logits = enc[:, :first_inputs.shape[1]] - first_logits = first_head(first_logits) - first_logits = first_logits.permute(0,2,1) - if second_inputs is not None: - second_logits = enc[:, -second_inputs.shape[1]:] - second_logits = second_head(second_logits) - second_logits = second_logits.permute(0,2,1) - return first_logits, second_logits - else: - return first_logits - - def forward(self, speech_conditioning_input, text_inputs, mel_inputs, wav_lengths, text_first=True, return_attentions=False): - """ - Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode - (actuated by `text_first`). 
- - speech_conditioning_input: MEL float tensor, (b,80,s) - text_inputs: long tensor, (b,t) - mel_inputs: long tensor, (b,m) - wav_lengths: long tensor, (b,) - """ - assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}' - assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}' - assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}' - - mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths) - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token) - mel_emb = self.bottom_gpt.get_input_embeddings()(mel_inputs) - mel_emb = mel_emb + self.mel_pos_paired_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - - if text_first: - text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) - else: - mel_logits, text_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions) - - if return_attentions: - return mel_logits - loss_text = F.cross_entropy(text_logits, text_targets.long()) - loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) - return loss_text.mean(), loss_mel.mean(), mel_logits - - def text_forward(self, speech_conditioning_input, text_inputs): - """ - Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the - model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided). - """ - assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}' - - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_solo_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head) - loss_text = F.cross_entropy(text_logits, text_targets.long()) - return loss_text.mean() - - def speech_forward(self, speech_conditioning_input, mel_inputs, wav_lengths): - """ - Performs autoregressive modeling on only speech data. 
- """ - assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}' - - mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths) - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - - mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token) - mel_emb = self.bottom_gpt.get_input_embeddings()(mel_inputs) - mel_emb = mel_emb + self.mel_pos_solo_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) - mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head) - loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) - return loss_mel.mean() - - def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs): - if not hasattr(self, 'inference_model'): - self.inference_model = GPT2InferenceModel(self.bottom_gpt_config, self.bottom_gpt, self.mel_pos_paired_embedding, self.final_norm, self.mel_head) - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - - # Randomly permute the conditioning spectrogram, to destroy any structure present. - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - - emb = torch.cat([cond, text_emb], dim=1) - self.inference_model.store_mel_emb(emb) - - fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.start_mel_token - - gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, - max_length=self.bottom_gpt_config.n_positions, **hf_generate_kwargs) - return gen[:, fake_inputs.shape[1]:] - - -@register_model -def register_unified_gpt_voice_bilevel(opt_net, opt): - return UnifiedGptVoice(**opt_get(opt_net, ['kwargs'], {})) - - -if __name__ == '__main__': - gpt = UnifiedGptVoice(model_dim=256, heads=4) - l = gpt(torch.randn(2, 80, 800), - torch.randint(high=len(symbols), size=(2,80)), - torch.randint(high=8192, size=(2,250)), - torch.tensor([150*256,195*256])) diff --git a/codes/scripts/audio/asr_eval.py b/codes/scripts/audio/asr_eval.py deleted file mode 100644 index a16045a7..00000000 --- a/codes/scripts/audio/asr_eval.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -import os.path as osp -import logging -import random -import argparse - -import torchvision - -import utils -import utils.options as option -import utils.util as util -from models.tacotron2.text import sequence_to_text -from trainer.ExtensibleTrainer import ExtensibleTrainer -from data import create_dataset, create_dataloader -from tqdm import tqdm -import torch -import numpy as np -from scipy.io import wavfile - - -def forward_pass(model, data, output_dir, opt, macro_b, dataset): - with torch.no_grad(): - model.feed_data(data, 0) - model.test() - - gt_key = opt['eval']['gen_text'] - txts = [] - for b in range(model.eval_state[gt_key][0].shape[0]): - if 'real_text' in opt['eval'].keys(): - real = data[opt['eval']['real_text']][b] - print(f'{macro_b} {b} Real text: "{real}"') - - codes = 
model.eval_state[opt['eval']['gen_text']][0][b].cpu() - if hasattr(dataset, 'tokenizer'): - text = dataset.tokenizer.decode(codes.numpy()) - text = text.replace(' $$$', '') - txts.append(text) - else: - txts.append(sequence_to_text(codes)) - return txts - - -if __name__ == "__main__": - # Set seeds - torch.manual_seed(5555) - random.seed(5555) - np.random.seed(5555) - - #### options - torch.backends.cudnn.benchmark = True - want_metrics = False - parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_hf2.yml') - opt = option.parse(parser.parse_args().opt, is_train=False) - opt = option.dict_to_nonedict(opt) - utils.util.loaded_options = opt - - util.mkdirs( - (path for key, path in opt['path'].items() - if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) - util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, - screen=True, tofile=True) - logger = logging.getLogger('base') - logger.info(option.dict2str(opt)) - - dataset_opt = opt['datasets']['val'] - test_set, collate_fn = create_dataset(dataset_opt, return_collate=True) - test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn) - logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) - - model = ExtensibleTrainer(opt) - - batch = 0 - output = open('results.tsv', 'w') - dataset_dir = opt['path']['results_root'] - util.mkdir(dataset_dir) - - for data in tqdm(test_loader): - #if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255: - # continue - preds = forward_pass(model, data, dataset_dir, opt, batch, test_set) - for b, pred in enumerate(preds): - pred = pred.replace('_', '') - output.write(f'{pred}\t{os.path.basename(data["filenames"][b])}\n') - print(pred) - batch += 1 - output.flush() - diff --git a/codes/scripts/audio/compute_gpt_attention.py b/codes/scripts/audio/compute_gpt_attention.py deleted file mode 100644 index 520a6998..00000000 --- a/codes/scripts/audio/compute_gpt_attention.py +++ /dev/null @@ -1,40 +0,0 @@ -import os - -import numpy -import torch -import torch.nn as nn -from matplotlib import pyplot -from torch.utils.tensorboard import SummaryWriter - -from data.audio.unsupervised_audio_dataset import load_audio -from models.gpt_voice.gpt_asr_hf import GptAsrHf -from models.tacotron2.text import text_to_sequence -from trainer.injectors.base_injectors import MelSpectrogramInjector - -if __name__ == '__main__': - audio_data = load_audio('Z:\\split\\classified\\fine\\books1\\2_dchha03 The Organization of Peace\\00010.wav', 22050).unsqueeze(0) - audio_data = torch.nn.functional.pad(audio_data, (0, 358395-audio_data.shape[-1])) - mel_inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {}) - mel = mel_inj({'in': audio_data})['out'].cuda() - actual_text = 'and it doesn\'t take very long.' 
- labels = torch.IntTensor(text_to_sequence(actual_text, ['english_cleaners'])).unsqueeze(0).cuda() - - model = GptAsrHf(layers=12, model_dim=512, max_mel_frames=1400, max_symbols_per_phrase=250, heads=8) - model.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\31000_gpt_ema.pth')) - model = model.cuda() - - with torch.no_grad(): - attentions = model(mel, labels, return_attentions=True) - attentions = torch.stack(attentions, dim=0).permute(0,1,2,4,3)[:, :, :, -model.max_symbols_per_phrase:, :model.max_mel_frames] - attentions = attentions.sum(0).sum(1).squeeze() - - xs = [str(i) for i in range(1, model.max_mel_frames+1, 1)] - os.makedirs('results', exist_ok=True) - logger = SummaryWriter('results') - for e, character_attn in enumerate(attentions): - if e >= len(actual_text): - break - fig = pyplot.figure() - ax = fig.add_axes([0,0,1,1]) - ax.bar(xs, character_attn.cpu().numpy()) - logger.add_figure(f'{e}_{actual_text[e]}', fig) diff --git a/codes/scripts/audio/gen/use_diffuse_tts.py b/codes/scripts/audio/gen/use_diffuse_tts.py index 419b5553..a60b1b8f 100644 --- a/codes/scripts/audio/gen/use_diffuse_tts.py +++ b/codes/scripts/audio/gen/use_diffuse_tts.py @@ -114,44 +114,56 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_tts5_medium.yml') parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator') - parser.add_argument('-diffusion_model_path', type=str, help='Path to saved model weights', default='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\68500_generator_ema.pth') - # -cond "Y:\libritts/train-clean-100/103/1241/103_1241_000017_000001.wav" - parser.add_argument('-cond', type=str, help='Type of conditioning voice', default='simmons') + parser.add_argument('-diffusion_model_path', type=str, help='Path to saved model weights', default='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\73000_generator_ema.pth') + parser.add_argument('-sr_opt', type=str, help='Path to options YAML file used to train the SR diffusion model', default='X:\\dlas\\experiments\\train_diffusion_tts6_upsample.yml') + parser.add_argument('-sr_diffusion_model_name', type=str, help='Name of the SR diffusion model in opt.', default='generator') + parser.add_argument('-sr_diffusion_model_path', type=str, help='Path to saved model weights for the SR diffuser', default='X:\\dlas\\experiments\\train_diffusion_tts6_upsample\\models\\7000_generator_ema.pth') + parser.add_argument('-cond', type=str, help='Type of conditioning voice', default='carlin') parser.add_argument('-diffusion_steps', type=int, help='Number of diffusion steps to perform to create the generate. 
Lower steps reduces quality, but >40 is generally pretty good.', default=100) - parser.add_argument('-diffusion_schedule', type=str, help='Type of diffusion schedule that was used', default='cosine') parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='../results/use_diffuse_tts') - parser.add_argument('-sample_rate', type=int, help='Model sample rate', default=5500) - parser.add_argument('-cond_sample_rate', type=int, help='Conditioning sample rate', default=5500) parser.add_argument('-device', type=str, help='Device to run on', default='cuda') args = parser.parse_args() os.makedirs(args.output_path, exist_ok=True) - print("Loading Diffusion Model..") + # Fixed parameters. + base_sample_rate = 5500 + sr_sample_rate = 22050 + + print("Loading Diffusion Models..") diffusion = load_model_from_config(args.opt, args.diffusion_model_name, also_load_savepoint=False, - load_path=args.diffusion_model_path, device=args.device) - diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule=args.diffusion_schedule) - aligned_codes_compression_factor = args.sample_rate * 221 // 11025 - cond = load_audio(conditioning_clips[args.cond], args.cond_sample_rate).to(args.device) - if cond.shape[-1] > 88000: - cond = cond[:,:88000] - torchaudio.save(os.path.join(args.output_path, 'cond.wav'), cond.cpu(), args.sample_rate) + load_path=args.diffusion_model_path, device='cpu').eval() + diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='cosine') + aligned_codes_compression_factor = base_sample_rate * 221 // 11025 + sr_diffusion = load_model_from_config(args.sr_opt, args.sr_diffusion_model_name, also_load_savepoint=False, + load_path=args.sr_diffusion_model_path, device='cpu').eval() + sr_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='linear') + sr_cond = load_audio(conditioning_clips[args.cond], sr_sample_rate).to(args.device) + if sr_cond.shape[-1] > 88000: + sr_cond = sr_cond[:,:88000] + cond = audio = torchaudio.functional.resample(sr_cond, sr_sample_rate, base_sample_rate) + torchaudio.save(os.path.join(args.output_path, 'cond_base.wav'), cond.cpu(), base_sample_rate) + torchaudio.save(os.path.join(args.output_path, 'cond_sr.wav'), sr_cond.cpu(), sr_sample_rate) - for p, code in enumerate(provided_codes): - print("Loading data..") - aligned_codes = torch.tensor(code).to(args.device) + with torch.no_grad(): + for p, code in enumerate(provided_codes): + print("Loading data..") + aligned_codes = torch.tensor(code).to(args.device) - with torch.no_grad(): - print("Performing inference..") - diffusion.eval() + print("Performing initial diffusion..") output_shape = (1, 1, ceil_multiple(aligned_codes.shape[-1]*aligned_codes_compression_factor, 2048)) - - output = diffuser.p_sample_loop(diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device), + diffusion = diffusion.cuda() + output_base = diffuser.p_sample_loop(diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device), model_kwargs={'tokens': aligned_codes.unsqueeze(0), 'conditioning_input': cond.unsqueeze(0)}) - torchaudio.save(os.path.join(args.output_path, f'{p}_output_mean.wav'), output.cpu().squeeze(0), args.sample_rate) + diffusion = diffusion.cpu() + torchaudio.save(os.path.join(args.output_path, f'{p}_output_mean_base.wav'), output_base.cpu().squeeze(0), base_sample_rate) - for k in range(2): - output = diffuser.p_sample_loop(diffusion, output_shape, 
model_kwargs={'tokens': aligned_codes.unsqueeze(0), - 'conditioning_input': cond.unsqueeze(0)}) - - torchaudio.save(os.path.join(args.output_path, f'{p}_output_{k}.wav'), output.cpu().squeeze(0), args.sample_rate) + print("Performing SR diffusion..") + output_shape = (1, 1, output_base.shape[-1] * (sr_sample_rate // base_sample_rate)) + sr_diffusion = sr_diffusion.cuda() + output = diffuser.p_sample_loop(sr_diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device), + model_kwargs={'tokens': aligned_codes.unsqueeze(0), + 'conditioning_input': sr_cond.unsqueeze(0), + 'lr_input': output_base}) + sr_diffusion = sr_diffusion.cpu() + torchaudio.save(os.path.join(args.output_path, f'{p}_output_mean_sr.wav'), output.cpu().squeeze(0), sr_sample_rate) diff --git a/codes/scripts/audio/generate_quantized_mels.py b/codes/scripts/audio/generate_quantized_mels.py deleted file mode 100644 index b331ef12..00000000 --- a/codes/scripts/audio/generate_quantized_mels.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import os.path as osp -import logging -import random -import argparse - -import torchvision - -import utils -import utils.options as option -import utils.util as util -from models.waveglow.denoiser import Denoiser -from trainer.ExtensibleTrainer import ExtensibleTrainer -from data import create_dataset, create_dataloader -from tqdm import tqdm -import torch -import numpy as np -from scipy.io import wavfile - -if __name__ == "__main__": - # Set seeds - torch.manual_seed(5555) - random.seed(5555) - np.random.seed(5555) - - #### options - torch.backends.cudnn.benchmark = True - want_metrics = False - parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/generate_quantized_mels.yml') - opt = option.parse(parser.parse_args().opt, is_train=False) - opt = option.dict_to_nonedict(opt) - utils.util.loaded_options = opt - - util.mkdirs( - (path for key, path in opt['path'].items() - if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) - util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, - screen=True, tofile=True) - logger = logging.getLogger('base') - logger.info(option.dict2str(opt)) - - test_loaders = [] - for phase, dataset_opt in sorted(opt['datasets'].items()): - test_set, collate_fn = create_dataset(dataset_opt, return_collate=True) - test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn) - logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) - test_loaders.append(test_loader) - - model = ExtensibleTrainer(opt) - - outpath = opt['path']['results_root'] - os.makedirs(os.path.join(outpath, 'quantized_mels'), exist_ok=True) - for test_loader in test_loaders: - dataset_dir = opt['path']['results_root'] - util.mkdir(dataset_dir) - - tq = tqdm(test_loader) - for data in tq: - with torch.no_grad(): - model.feed_data(data, 0) - model.test() - - wavfiles = data['filenames'] - quantized = model.eval_state[opt['eval']['quantized_mels']][0] - for i, filename in enumerate(wavfiles): - qmelfile = filename.replace('wavs/', 'quantized_mels/') + '.pth' - torch.save(quantized[i], os.path.join(outpath, qmelfile)) diff --git a/codes/scripts/audio/librivox/preprocess_libritts.py b/codes/scripts/audio/librivox/preprocess_libritts.py deleted file mode 100644 index acfb2ec4..00000000 --- a/codes/scripts/audio/librivox/preprocess_libritts.py +++ /dev/null @@ -1,32 +0,0 @@ -# Combines all libriTTS 
WAV->text mappings into a single file -import os - -from tqdm import tqdm - -if __name__ == '__main__': - libri_root = 'E:\\audio\\LibriTTS' - basis = 'train-clean-360' - - readers = os.listdir(os.path.join(libri_root, basis)) - ofile = open(os.path.join(libri_root, f'{basis}_list.txt'), 'w', encoding='utf-8') - for reader_dir in tqdm(readers): - reader = os.path.join(libri_root, basis, reader_dir) - if not os.path.isdir(reader): - continue - for chapter_dir in os.listdir(reader): - chapter = os.path.join(reader, chapter_dir) - if not os.path.isdir(chapter): - continue - id = f'{os.path.basename(reader)}_{os.path.basename(chapter)}' - trans_file = f'{id}.trans.tsv' - with open(os.path.join(chapter, trans_file), encoding='utf-8') as f: - trans_lines = [line.strip().split('\t') for line in f] - for line in trans_lines: - wav_file, raw_text, normalized_text = line - wav_file = '/'.join([basis, reader_dir, chapter_dir, f'{wav_file}.wav']) - if not os.path.exists(os.path.join(libri_root, wav_file)): - print(f'!WARNING could not open {wav_file}') - else: - ofile.write(f'{wav_file}|{normalized_text}\n') - ofile.flush() - ofile.close() diff --git a/codes/scripts/audio/librivox/produce_libri_stretched_dataset.py b/codes/scripts/audio/librivox/produce_libri_stretched_dataset.py deleted file mode 100644 index cb56ba76..00000000 --- a/codes/scripts/audio/librivox/produce_libri_stretched_dataset.py +++ /dev/null @@ -1,99 +0,0 @@ -# Combines all libriTTS WAV->text mappings into a single file -import os -import random - -import audio2numpy -import torch -from scipy.io import wavfile -from tqdm import tqdm - -from utils.audio_resampler import AudioResampler - - -def secs_to_frames(secs, sr): - return int(secs*sr) - - -def get_audio_clip(audio, sr, start, end): - start = secs_to_frames(start, sr) - end = secs_to_frames(end, sr) - assert end > start - if end >= audio.shape[0]: - return None - return audio[start:end] - - -# Produces an audio clip that would produce a MEL spectrogram of length mel_length by parsing parsed_sentences starting -# at starting_index and moving forwards until the full length is finished. -# Returns: -# On failure, returns tuple: (end_index, None, [], []) -# On success: returns tuple: (end_index, clip, start_points, end_points) -# clip.shape = (,) -# start_points = list(ints) where each sentence in the clip starts -# end_points = list(ints) where each sentence in the clip ends -def gather_clip(audio, parsed_sentences, starting_index, sr, mel_length): - audio_length = (mel_length * 256) / sr # This is technically a hyperparameter, but I have no intent of changing the MEL hop length. - starts = [] - ends = [] - start, end = parsed_sentences[starting_index][4:6] - start = float(start) - end = float(end) - clipstart = max(start - random.random() * 2, 0) # Offset start backwards by up to 2 seconds - clipend = start + audio_length - clip = get_audio_clip(audio, sr, clipstart, clipend) - if clip is not None: - # Fetch the start and endpoints that go along with this clip. 
- starts.append(secs_to_frames(start-clipstart, sr)) - while end < clipend: - ends.append(secs_to_frames(end-clipstart, sr)) - starting_index += 1 - if starting_index >= len(parsed_sentences): - break - start, end = parsed_sentences[starting_index][4:6] - start = float(start) - end = float(end) - if start < clipend: - starts.append(secs_to_frames(start-clipstart, sr)) - - return starting_index+1, clip, starts, ends - - -if __name__ == '__main__': - full_book_root = 'D:\\data\\audio\\libritts\\full_books\\mp3' - libri_root = 'D:\\data\\audio\\libritts\\test-clean' - desired_mel_length = 2000 - desired_audio_sample_rate = 22050 - output_dir = 'D:\\data\\audio\\libritts\\stop_dataset_eval' - - os.makedirs(output_dir, exist_ok=True) - j = 0 - readers = os.listdir(libri_root) - for it, reader_dir in enumerate(tqdm(readers)): - #if it <= 145: # Hey idiot! If you change this, change j too! - # continue - reader = os.path.join(libri_root, reader_dir) - if not os.path.isdir(reader): - continue - for chapter_dir in os.listdir(reader): - chapter = os.path.join(reader, chapter_dir) - if not os.path.isdir(chapter): - continue - id = f'{os.path.basename(reader)}_{os.path.basename(chapter)}' - book_file = os.path.join(chapter, f'{id}.book.tsv') - if not os.path.exists(book_file): - continue - with open(book_file, encoding='utf-8') as f: - full_chapter, sr = audio2numpy.open_audio(os.path.join(full_book_root, reader_dir, chapter_dir, f'{chapter_dir}.mp3')) - full_chapter = torch.tensor(full_chapter) - if len(full_chapter.shape) > 1: - full_chapter = full_chapter[:, 0] # Only use mono-audio. - resampler = AudioResampler(sr, desired_audio_sample_rate, dtype=torch.float) - full_chapter = resampler(full_chapter.unsqueeze(0)).squeeze(0) - parsed_sentences = [line.strip().split('\t') for line in f] - i = 0 - while i < len(parsed_sentences): - i, clip, ns, ne = gather_clip(full_chapter, parsed_sentences, i, desired_audio_sample_rate, desired_mel_length) - if clip is not None: - wavfile.write(os.path.join(output_dir, f'{j}.wav'), desired_audio_sample_rate, clip.cpu().numpy()) - torch.save((ns,ne), os.path.join(output_dir, f'{j}_se.pth')) - j += 1
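For reference, the deleted UnifiedGptVoice variants above all share one padding and teacher-forcing scheme: each clip's wav length is divided by mel_length_compression to find where its real MEL codes end, the padded tail is overwritten with the stop MEL token, and the sequence is then paired into shifted inputs and targets. Below is a minimal standalone sketch of that scheme, using the constructor defaults visible in the diff (start_mel_token=8192, stop_mel_token=8193, mel_length_compression=1024); the free functions are illustrative stand-ins for the class methods, not part of the codebase.

# Illustrative sketch (not from the repo): MEL padding + input/target alignment as used above.
import torch
import torch.nn.functional as F

START_MEL_TOKEN, STOP_MEL_TOKEN, MEL_LENGTH_COMPRESSION = 8192, 8193, 1024

def set_mel_padding(mel_tokens, wav_lengths):
    # Overwrite codes that fall in the zero-padded tail of each clip with the stop token,
    # keeping one extra code past the true end as the models above do.
    mel_lengths = wav_lengths // MEL_LENGTH_COMPRESSION
    for b in range(len(mel_lengths)):
        actual_end = mel_lengths[b] + 1
        if actual_end < mel_tokens.shape[-1]:
            mel_tokens[b, actual_end:] = STOP_MEL_TOKEN
    return mel_tokens

def build_aligned_inputs_and_targets(tokens, start_token, stop_token):
    inp = F.pad(tokens, (1, 0), value=start_token)  # prepend <start>
    tar = F.pad(tokens, (0, 1), value=stop_token)   # append <stop>
    return inp, tar

codes = torch.randint(high=8192, size=(2, 10))
wav_lengths = torch.tensor([6 * 1024, 10 * 1024])
codes = set_mel_padding(codes, wav_lengths)
inp, tar = build_aligned_inputs_and_targets(codes, START_MEL_TOKEN, STOP_MEL_TOKEN)
# inp and tar are offset by one position, so the logit produced at position t is trained against tar[t].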