Implement guidance-free diffusion in eval
And a few other fixes
This commit is contained in:
parent
2134f06516
commit
db0c3340ac
|
@ -263,19 +263,25 @@ if __name__ == '__main__':
|
|||
batch_sz = 256
|
||||
params = {
|
||||
'mode': 'fast_paired_voice_audio',
|
||||
'path': ['Y:\\libritts\\train-clean-360\\transcribed-w2v.tsv', 'Y:\\clips\\books1\\transcribed-w2v.tsv'],
|
||||
'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
|
||||
'y:/libritts/train-clean-100/transcribed-oco.tsv',
|
||||
'y:/libritts/train-clean-360/transcribed-oco.tsv',
|
||||
'y:/clips/books1/transcribed-w2v.tsv',
|
||||
'y:/clips/books2/transcribed-w2v.tsv',
|
||||
'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv'],
|
||||
'phase': 'train',
|
||||
'n_workers': 0,
|
||||
'batch_size': batch_sz,
|
||||
'max_wav_length': 255995,
|
||||
'max_wav_length': 163840,
|
||||
'max_text_length': 200,
|
||||
'sample_rate': 22050,
|
||||
'load_conditioning': True,
|
||||
'num_conditioning_candidates': 1,
|
||||
'conditioning_length': 66000,
|
||||
'use_bpe_tokenizer': True,
|
||||
'load_aligned_codes': True,
|
||||
'produce_ctc_metadata': True,
|
||||
'use_bpe_tokenizer': False,
|
||||
'load_aligned_codes': False,
|
||||
'needs_collate': False,
|
||||
'produce_ctc_metadata': False,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
|
||||
|
@ -294,10 +300,10 @@ if __name__ == '__main__':
|
|||
for ib in range(batch_sz):
|
||||
#max_pads = max(max_pads, b['ctc_pads'].max())
|
||||
#max_repeats = max(max_repeats, b['ctc_repeats'].max())
|
||||
#print(f'{i} {ib} {b["real_text"][ib]}')
|
||||
#save(b, i, ib, 'wav')
|
||||
print(f'{i} {ib} {b["real_text"][ib]}')
|
||||
save(b, i, ib, 'wav')
|
||||
pass
|
||||
#if i > 5:
|
||||
# break
|
||||
if i > 15:
|
||||
break
|
||||
print(max_pads, max_repeats)
|
||||
|
||||
|
|
|
@ -125,6 +125,7 @@ class GaussianDiffusion:
|
|||
rescale_timesteps=False,
|
||||
conditioning_free=False,
|
||||
conditioning_free_k=1,
|
||||
ramp_conditioning_free=True,
|
||||
):
|
||||
self.model_mean_type = ModelMeanType(model_mean_type)
|
||||
self.model_var_type = ModelVarType(model_var_type)
|
||||
|
@ -132,6 +133,7 @@ class GaussianDiffusion:
|
|||
self.rescale_timesteps = rescale_timesteps
|
||||
self.conditioning_free = conditioning_free
|
||||
self.conditioning_free_k = conditioning_free_k
|
||||
self.ramp_conditioning_free = ramp_conditioning_free
|
||||
|
||||
# Use float64 for accuracy.
|
||||
betas = np.array(betas, dtype=np.float64)
|
||||
|
@ -299,7 +301,12 @@ class GaussianDiffusion:
|
|||
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
||||
|
||||
if self.conditioning_free:
|
||||
model_output = (1 + self.conditioning_free_k) * model_output - self.conditioning_free_k * model_output_no_conditioning
|
||||
if self.ramp_conditioning_free:
|
||||
assert t.shape[0] == 1 # This should only be used in inference.
|
||||
cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
|
||||
else:
|
||||
cfk = self.conditioning_free_k
|
||||
model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
|
||||
|
||||
def process_xstart(x):
|
||||
if denoised_fn is not None:
|
||||
|
|
|
@ -408,7 +408,7 @@ class DiffusionTts(nn.Module):
|
|||
)
|
||||
|
||||
|
||||
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None):
|
||||
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
|
@ -419,6 +419,7 @@ class DiffusionTts(nn.Module):
|
|||
:param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
|
||||
:param unaligned_input: A structural input that is not properly aligned with the output of the diffusion model.
|
||||
Can be combined with a conditioning input to produce more robust conditioning.
|
||||
:param conditioning_free: When set, all conditioning inputs (including tokens, conditioning_input and unaligned_input) will not be considered.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert conditioning_input is not None
|
||||
|
@ -430,11 +431,6 @@ class DiffusionTts(nn.Module):
|
|||
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
|
||||
x = torch.cat([x, lr_input], dim=1)
|
||||
|
||||
if self.enable_unaligned_inputs:
|
||||
assert unaligned_input is not None
|
||||
unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
|
||||
unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
|
||||
|
||||
with autocast(x.device.type):
|
||||
orig_x_shape = x.shape[-1]
|
||||
cm = ceil_multiple(x.shape[-1], 2048)
|
||||
|
@ -447,28 +443,36 @@ class DiffusionTts(nn.Module):
|
|||
hs = []
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
cond_emb = self.contextual_embedder(conditioning_input)
|
||||
if tokens is not None:
|
||||
# Mask out guidance tokens for un-guided diffusion.
|
||||
if self.training and self.nil_guidance_fwd_proportion > 0:
|
||||
token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
|
||||
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
||||
code_emb = self.code_embedding(tokens).permute(0,2,1)
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
|
||||
cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
||||
else:
|
||||
code_emb = cond_emb.unsqueeze(-1)
|
||||
if self.enable_unaligned_inputs:
|
||||
code_emb = self.conditioning_encoder(code_emb, context=unaligned_h)
|
||||
else:
|
||||
code_emb = self.conditioning_encoder(code_emb)
|
||||
if self.enable_unaligned_inputs:
|
||||
assert unaligned_input is not None
|
||||
unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
|
||||
unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
|
||||
|
||||
cond_emb = self.contextual_embedder(conditioning_input)
|
||||
if tokens is not None:
|
||||
# Mask out guidance tokens for un-guided diffusion.
|
||||
if self.training and self.nil_guidance_fwd_proportion > 0:
|
||||
token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
|
||||
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
||||
code_emb = self.code_embedding(tokens).permute(0,2,1)
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
|
||||
cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
|
||||
else:
|
||||
code_emb = cond_emb.unsqueeze(-1)
|
||||
if self.enable_unaligned_inputs:
|
||||
code_emb = self.conditioning_encoder(code_emb, context=unaligned_h)
|
||||
else:
|
||||
code_emb = self.conditioning_encoder(code_emb)
|
||||
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.unconditioned_percentage > 0:
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0],1,1), device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, code_emb.shape[2]), code_emb)
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1), code_emb)
|
||||
|
||||
first = True
|
||||
time_emb = time_emb.float()
|
||||
|
|
|
@ -317,7 +317,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/debug_diffusion_tts7.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -41,7 +41,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
|
||||
conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
|
||||
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
|
||||
conditioning_free_diffusion_enabled=conditioning_free_diffusion_enabled,
|
||||
enable_conditioning_free_guidance=conditioning_free_diffusion_enabled,
|
||||
conditioning_free_k=conditioning_free_k)
|
||||
self.dev = self.env['device']
|
||||
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
||||
|
@ -162,9 +162,10 @@ if __name__ == '__main__':
|
|||
from utils.util import load_model_from_config
|
||||
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text.yml', 'generator',
|
||||
also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\5500_generator_ema.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 50,
|
||||
also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_tts7_dvae_thin_with_text\\models\\39500_generator_ema.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
|
||||
'conditioning_free': True, 'conditioning_free_k': 2,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'vocoder'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 500, 'device': 'cuda', 'opt': {}}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 202, 'device': 'cuda', 'opt': {}}
|
||||
eval = AudioDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
|
@ -92,8 +92,12 @@ class GeneratorInjector(Injector):
|
|||
if self.grad:
|
||||
results = method(*params, **self.args)
|
||||
else:
|
||||
was_training = gen.training
|
||||
gen.eval()
|
||||
with torch.no_grad():
|
||||
results = method(*params, **self.args)
|
||||
if was_training:
|
||||
gen.train()
|
||||
new_state = {}
|
||||
if isinstance(self.output, list):
|
||||
# Only dereference tuples or lists, not tensors. IF YOU REACH THIS ERROR, REMOVE THE BRACES AROUND YOUR OUTPUTS IN THE YAML CONFIG
|
||||
|
|
Loading…
Reference in New Issue
Block a user