more stuff

This commit is contained in:
James Betker 2022-03-25 00:03:18 -06:00
parent d4218d8443
commit 45804177b8
3 changed files with 12 additions and 15 deletions

View File

@@ -260,7 +260,7 @@ class FastPairedVoiceDebugger:
if __name__ == '__main__':
batch_sz = 256
batch_sz = 16
params = {
'mode': 'fast_paired_voice_audio',
'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
@@ -268,20 +268,19 @@ if __name__ == '__main__':
'y:/libritts/train-clean-360/transcribed-oco.tsv',
'y:/clips/books1/transcribed-w2v.tsv',
'y:/clips/books2/transcribed-w2v.tsv',
'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv'],
'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv',
'y:/clips/podcasts-1/transcribed-oco.tsv',],
'phase': 'train',
'n_workers': 0,
'batch_size': batch_sz,
'max_wav_length': 163840,
'max_text_length': 200,
'max_wav_length': 220500,
'max_text_length': 500,
'sample_rate': 22050,
'load_conditioning': True,
'num_conditioning_candidates': 1,
'conditioning_length': 66000,
'use_bpe_tokenizer': False,
'num_conditioning_candidates': 2,
'conditioning_length': 102400,
'use_bpe_tokenizer': True,
'load_aligned_codes': False,
'needs_collate': False,
'produce_ctc_metadata': False,
}
from data import create_dataset, create_dataloader
@@ -302,6 +301,8 @@ if __name__ == '__main__':
#max_repeats = max(max_repeats, b['ctc_repeats'].max())
print(f'{i} {ib} {b["real_text"][ib]}')
save(b, i, ib, 'wav')
save(b, i, ib, 'conditioning', 0)
save(b, i, ib, 'conditioning', 1)
pass
if i > 15:
break

View File

@@ -4,12 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers import Encoder
from x_transformers.x_transformers import RelativePositionBias
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model
from utils.util import checkpoint
@@ -189,7 +186,7 @@ class DiffusionTtsFlat(nn.Module):
}
return groups
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False):
"""
Apply the model to an input batch.
@@ -197,7 +194,6 @@ class DiffusionTtsFlat(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
:param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs.
"""

View File

@@ -132,7 +132,7 @@ class DiscreteTokenInjector(Injector):
super().__init__(opt, env)
cfg = opt_get(opt, ['dvae_config'], "../experiments/train_diffusion_vocoder_22k_level.yml")
dvae_name = opt_get(opt, ['dvae_name'], 'dvae')
self.dvae = load_model_from_config(cfg, dvae_name, device=env['device']).eval()
self.dvae = load_model_from_config(cfg, dvae_name, device=f'cuda:{env["device"]}').eval()
def forward(self, state):
inp = state[self.input]