Implement guidance-free diffusion in eval

And a few other fixes
This commit is contained in:
James Betker 2022-03-01 11:49:36 -07:00
parent 2134f06516
commit db0c3340ac
6 changed files with 61 additions and 39 deletions

View File

@ -263,19 +263,25 @@ if __name__ == '__main__':
batch_sz = 256 batch_sz = 256
params = { params = {
'mode': 'fast_paired_voice_audio', 'mode': 'fast_paired_voice_audio',
'path': ['Y:\\libritts\\train-clean-360\\transcribed-w2v.tsv', 'Y:\\clips\\books1\\transcribed-w2v.tsv'], 'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
'y:/libritts/train-clean-100/transcribed-oco.tsv',
'y:/libritts/train-clean-360/transcribed-oco.tsv',
'y:/clips/books1/transcribed-w2v.tsv',
'y:/clips/books2/transcribed-w2v.tsv',
'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv'],
'phase': 'train', 'phase': 'train',
'n_workers': 0, 'n_workers': 0,
'batch_size': batch_sz, 'batch_size': batch_sz,
'max_wav_length': 255995, 'max_wav_length': 163840,
'max_text_length': 200, 'max_text_length': 200,
'sample_rate': 22050, 'sample_rate': 22050,
'load_conditioning': True, 'load_conditioning': True,
'num_conditioning_candidates': 1, 'num_conditioning_candidates': 1,
'conditioning_length': 66000, 'conditioning_length': 66000,
'use_bpe_tokenizer': True, 'use_bpe_tokenizer': False,
'load_aligned_codes': True, 'load_aligned_codes': False,
'produce_ctc_metadata': True, 'needs_collate': False,
'produce_ctc_metadata': False,
} }
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
@ -294,10 +300,10 @@ if __name__ == '__main__':
for ib in range(batch_sz): for ib in range(batch_sz):
#max_pads = max(max_pads, b['ctc_pads'].max()) #max_pads = max(max_pads, b['ctc_pads'].max())
#max_repeats = max(max_repeats, b['ctc_repeats'].max()) #max_repeats = max(max_repeats, b['ctc_repeats'].max())
#print(f'{i} {ib} {b["real_text"][ib]}') print(f'{i} {ib} {b["real_text"][ib]}')
#save(b, i, ib, 'wav') save(b, i, ib, 'wav')
pass pass
#if i > 5: if i > 15:
# break break
print(max_pads, max_repeats) print(max_pads, max_repeats)

View File

@ -125,6 +125,7 @@ class GaussianDiffusion:
rescale_timesteps=False, rescale_timesteps=False,
conditioning_free=False, conditioning_free=False,
conditioning_free_k=1, conditioning_free_k=1,
ramp_conditioning_free=True,
): ):
self.model_mean_type = ModelMeanType(model_mean_type) self.model_mean_type = ModelMeanType(model_mean_type)
self.model_var_type = ModelVarType(model_var_type) self.model_var_type = ModelVarType(model_var_type)
@ -132,6 +133,7 @@ class GaussianDiffusion:
self.rescale_timesteps = rescale_timesteps self.rescale_timesteps = rescale_timesteps
self.conditioning_free = conditioning_free self.conditioning_free = conditioning_free
self.conditioning_free_k = conditioning_free_k self.conditioning_free_k = conditioning_free_k
self.ramp_conditioning_free = ramp_conditioning_free
# Use float64 for accuracy. # Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64) betas = np.array(betas, dtype=np.float64)
@ -299,7 +301,12 @@ class GaussianDiffusion:
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
if self.conditioning_free: if self.conditioning_free:
model_output = (1 + self.conditioning_free_k) * model_output - self.conditioning_free_k * model_output_no_conditioning if self.ramp_conditioning_free:
assert t.shape[0] == 1 # This should only be used in inference.
cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
else:
cfk = self.conditioning_free_k
model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
def process_xstart(x): def process_xstart(x):
if denoised_fn is not None: if denoised_fn is not None:

View File

@ -408,7 +408,7 @@ class DiffusionTts(nn.Module):
) )
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None): def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
""" """
Apply the model to an input batch. Apply the model to an input batch.
@ -419,6 +419,7 @@ class DiffusionTts(nn.Module):
:param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate. :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
:param unaligned_input: A structural input that is not properly aligned with the output of the diffusion model. :param unaligned_input: A structural input that is not properly aligned with the output of the diffusion model.
Can be combined with a conditioning input to produce more robust conditioning. Can be combined with a conditioning input to produce more robust conditioning.
:param conditioning_free: When set, all conditioning inputs (including tokens, conditioning_input and unaligned_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs. :return: an [N x C x ...] Tensor of outputs.
""" """
assert conditioning_input is not None assert conditioning_input is not None
@ -430,11 +431,6 @@ class DiffusionTts(nn.Module):
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest') lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
x = torch.cat([x, lr_input], dim=1) x = torch.cat([x, lr_input], dim=1)
if self.enable_unaligned_inputs:
assert unaligned_input is not None
unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
with autocast(x.device.type): with autocast(x.device.type):
orig_x_shape = x.shape[-1] orig_x_shape = x.shape[-1]
cm = ceil_multiple(x.shape[-1], 2048) cm = ceil_multiple(x.shape[-1], 2048)
@ -447,6 +443,14 @@ class DiffusionTts(nn.Module):
hs = [] hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
else:
if self.enable_unaligned_inputs:
assert unaligned_input is not None
unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
cond_emb = self.contextual_embedder(conditioning_input) cond_emb = self.contextual_embedder(conditioning_input)
if tokens is not None: if tokens is not None:
# Mask out guidance tokens for un-guided diffusion. # Mask out guidance tokens for un-guided diffusion.
@ -466,9 +470,9 @@ class DiffusionTts(nn.Module):
code_emb = self.conditioning_encoder(code_emb) code_emb = self.conditioning_encoder(code_emb)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.unconditioned_percentage > 0: if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0],1,1), device=code_emb.device) < self.unconditioned_percentage unconditioned_batches = torch.rand((code_emb.shape[0],1,1), device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, code_emb.shape[2]), code_emb) code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1), code_emb)
first = True first = True
time_emb = time_emb.float() time_emb = time_emb.float()

View File

@ -317,7 +317,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/debug_diffusion_tts7.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -41,7 +41,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False) conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1) conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule, self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
conditioning_free_diffusion_enabled=conditioning_free_diffusion_enabled, enable_conditioning_free_guidance=conditioning_free_diffusion_enabled,
conditioning_free_k=conditioning_free_k) conditioning_free_k=conditioning_free_k)
self.dev = self.env['device'] self.dev = self.env['device']
mode = opt_get(opt_eval, ['diffusion_type'], 'tts') mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
@ -162,9 +162,10 @@ if __name__ == '__main__':
from utils.util import load_model_from_config from utils.util import load_model_from_config
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator', diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator',
also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\5500_generator_ema.pth').cuda() also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\39500_generator_ema.pth').cuda()
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50, opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
'conditioning_free': True, 'conditioning_free_k': 2,
'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'} 'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda', 'opt': {}} env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 202, 'device': 'cuda', 'opt': {}}
eval = AudioDiffusionFid(diffusion, opt_eval, env) eval = AudioDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval()) print(eval.perform_eval())

View File

@ -92,8 +92,12 @@ class GeneratorInjector(Injector):
if self.grad: if self.grad:
results = method(*params, **self.args) results = method(*params, **self.args)
else: else:
was_training = gen.training
gen.eval()
with torch.no_grad(): with torch.no_grad():
results = method(*params, **self.args) results = method(*params, **self.args)
if was_training:
gen.train()
new_state = {} new_state = {}
if isinstance(self.output, list): if isinstance(self.output, list):
# Only dereference tuples or lists, not tensors. IF YOU REACH THIS ERROR, REMOVE THE BRACES AROUND YOUR OUTPUTS IN THE YAML CONFIG # Only dereference tuples or lists, not tensors. IF YOU REACH THIS ERROR, REMOVE THE BRACES AROUND YOUR OUTPUTS IN THE YAML CONFIG