Go back to vanilla flavor of diffusion

James Betker 2021-10-17 17:32:46 -06:00
parent 23da073037
commit d016a2fbad
4 changed files with 25 additions and 32 deletions

View File

@@ -141,12 +141,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
if __name__ == '__main__':
params = {
'mode': 'unsupervised_audio',
'path': ['Z:\\split\\cleaned\\books0', 'Z:\\split\\cleaned\\books2'],
'cache_path': 'E:\\audio\\remote-cache.pth',
'path': ['\\\\192.168.5.3\\rtx3080_audio_y\\split\\books2', '\\\\192.168.5.3\\rtx3080_audio\\split\\books1', '\\\\192.168.5.3\\rtx3080_audio\\split\\cleaned-2'],
'cache_path': 'E:\\audio\\remote-cache2.pth',
'sampling_rate': 22050,
'pad_to_seconds': 5,
'pad_to_samples': 40960,
'phase': 'train',
'n_workers': 4,
'n_workers': 1,
'batch_size': 16,
'extra_samples': 4,
}
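The params above replace a time-based pad ('pad_to_seconds': 5) with an explicit sample count ('pad_to_samples': 40960). The dataset implementation itself is not part of this diff, so the following is only a minimal sketch of what padding to a fixed sample count typically looks like; the helper name is hypothetical:

```python
import torch
import torch.nn.functional as F

def pad_or_trim(clip: torch.Tensor, pad_to_samples: int = 40960) -> torch.Tensor:
    # clip: (channels, samples). Zero-pad on the right, or truncate, so every
    # item in a batch ends up with exactly `pad_to_samples` samples.
    n = clip.shape[-1]
    if n < pad_to_samples:
        return F.pad(clip, (0, pad_to_samples - n))
    return clip[..., :pad_to_samples]
```

At the configured 22050 Hz sampling rate, 40960 samples is roughly 1.86 seconds of audio, and it divides evenly by 4096, the truncation granularity used by the test script later in this commit.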
@@ -156,8 +156,7 @@ if __name__ == '__main__':
dl = create_dataloader(ds, params)
i = 0
for b in tqdm(dl):
for b_ in range(16):
pass
#torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate)
#torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate)
#i += 1
for b_ in range(b['clip'].shape[0]):
#pass
torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
i += 1

View File

@@ -13,17 +13,10 @@ from utils.util import get_mask_from_lengths
class DiscreteSpectrogramConditioningBlock(nn.Module):
def __init__(self, dvae_channels, channels):
super().__init__()
self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1)
self.intg = nn.Sequential(
normalization(channels*2),
nn.SiLU(),
nn.Conv1d(channels*2, channels*2, kernel_size=1),
normalization(channels*2),
nn.SiLU(),
nn.Conv1d(channels*2, channels, kernel_size=3, padding=1),
self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
normalization(channels),
nn.SiLU(),
zero_module(nn.Conv1d(channels, channels, kernel_size=1)))
nn.Conv1d(channels, channels, kernel_size=3))
"""
Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
@@ -34,11 +27,9 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
def forward(self, x, dvae_in):
b, c, S = x.shape
_, q, N = dvae_in.shape
emb = self.emb(dvae_in)
emb = self.intg(dvae_in)
emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
together = torch.cat([x, emb], dim=1)
together = self.intg(together)
return together + x
return torch.cat([x, emb], dim=1)
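Reassembled from the added lines above, the conditioning block after this commit collapses the old emb/intg pair into a single `intg` stack applied directly to the DVAE codes. This is a best-effort reconstruction (indentation and formatting are assumptions; `normalization` and `nn` come from the file's existing imports):

```python
class DiscreteSpectrogramConditioningBlock(nn.Module):
    def __init__(self, dvae_channels, channels):
        super().__init__()
        self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
                                  normalization(channels),
                                  nn.SiLU(),
                                  nn.Conv1d(channels, channels, kernel_size=3))

    def forward(self, x, dvae_in):
        b, c, S = x.shape
        _, q, N = dvae_in.shape
        emb = self.intg(dvae_in)                                         # (b, channels, N-2); the final conv has no padding
        emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')  # resample to x's sequence length
        return torch.cat([x, emb], dim=1)                                # (b, c + channels, S)
```

Because the block now returns the concatenation rather than a residual sum, its output carries twice the channels of `x`; the `ch *= 2` added in a later hunk accounts for this when the UNet's input blocks are built.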
class DiffusionVocoderWithRef(nn.Module):
@@ -81,10 +72,10 @@ class DiffusionVocoderWithRef(nn.Module):
dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
spectrogram_conditioning_resolutions=(1,8,64,512),
spectrogram_conditioning_resolutions=(512,),
attention_resolutions=(512,1024,2048),
conv_resample=True,
dims=1,
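A quick consistency check on the new defaults, assuming `ds` in the construction loop below is the cumulative downsampling factor (doubling per level, as the res comment above suggests): with spectrogram conditioning now injected only at resolution 512, the 40960-sample clips from the `__main__` test map to an 80-step feature sequence at the injection point, and each of the 160 spectrogram frames corresponds to 256 output samples.

```python
clip_len = 40960        # samples; matches the __main__ test at the bottom of this file
spec_frames = 160       # spectrogram frames in the same test
ds = 512                # the only remaining spectrogram_conditioning_resolution

assert clip_len // ds == 80            # feature length where the conditioning block is inserted
assert clip_len // spec_frames == 256  # samples per frame; matches the injector's
                                       # spectrogram.shape[-1] * 256 output rule later in this commit
```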
@@ -149,6 +140,7 @@ class DiffusionVocoderWithRef(nn.Module):
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
if ds in spectrogram_conditioning_resolutions:
self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
ch *= 2
for _ in range(num_blocks):
layers = [
@@ -340,7 +332,7 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 1, 40960)
#spec = torch.randint(8192, (2, 40,))
spec = torch.randn(8,512,160)
spec = torch.randn(2,512,160)
cond = torch.randn(2, 3, 80, 173)
ts = torch.LongTensor([555, 556])
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)

View File

@@ -14,7 +14,7 @@ from torchvision.transforms import ToTensor
import utils
import utils.options as option
import utils.util as util
from data.audio.wavfile_dataset import load_audio_from_wav
from models.tacotron2.taco_utils import load_wav_to_torch
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
@@ -55,7 +55,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-17.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@@ -70,7 +70,9 @@ if __name__ == "__main__":
# Load test image
if audio_mode:
im = load_audio_from_wav(opt['image'], opt['sample_rate'])
im, sr = load_wav_to_torch(opt['image'])
assert sr == 22050
im = im.unsqueeze(0)
im = im[:, :(im.shape[1]//4096)*4096]
else:
im = ToTensor()(Image.open(opt['image'])) * 2 - 1
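In the audio branch above, the loaded clip is truncated to a multiple of 4096 samples, presumably so its length divides evenly by the model's cumulative downsampling; the constant is taken directly from the diff, the reason is an inference. A worked example of that truncation:

```python
n = 100000                     # arbitrary clip length in samples
trimmed = (n // 4096) * 4096   # same arithmetic as im[:, :(im.shape[1] // 4096) * 4096]
print(trimmed)                 # 98304: the trailing 1696 samples are dropped
```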

View File

@@ -96,9 +96,6 @@ class GaussianDiffusionInferenceInjector(Injector):
self.use_ema_model = opt_get(opt, ['use_ema'], False)
self.noise_style = opt_get(opt, ['noise_type'], 'random') # 'zero', 'fixed' or 'random'
self.model_fn = opt_get(opt, ['model_function'], None)
self.model_fn = None if self.model_fn is None else getattr(self.generator, self.model_fn)
def forward(self, state):
if self.use_ema_model:
gen = self.env['emas'][self.opt['generator']]
@@ -114,6 +111,9 @@ class GaussianDiffusionInferenceInjector(Injector):
elif 'spectrogram' in model_inputs.keys():
output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1]*256)
dev = model_inputs['spectrogram'].device
elif 'discrete_spectrogram' in model_inputs.keys():
output_shape = (self.output_batch_size, 1, model_inputs['discrete_spectrogram'].shape[-1]*1024)
dev = model_inputs['discrete_spectrogram'].device
else:
raise NotImplementedError
noise = None
@@ -123,7 +123,7 @@ class GaussianDiffusionInferenceInjector(Injector):
if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
self.fixed_noise = torch.randn(output_shape, device=dev)
noise = self.fixed_noise
gen = self.sampling_fn(self.model_fn, output_shape, noise=noise, model_kwargs=model_inputs, progress=True, device=dev)
gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True, device=dev)
if self.undo_n1_to_1:
gen = (gen + 1) / 2
return {self.output: gen}
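Condensing the injector changes: the `model_fn` indirection is removed (the resolved generator `gen` is now passed straight to `sampling_fn`), and a `discrete_spectrogram` input yields an output 1024x longer than its frame count, versus 256x for a dense `spectrogram`. A sketch of the resulting shape logic, with names mirroring the diff; the surrounding Injector plumbing is assumed:

```python
def diffusion_output_shape(model_inputs: dict, batch_size: int) -> tuple:
    # Output is always mono audio: (batch, 1, samples).
    if 'spectrogram' in model_inputs:
        return (batch_size, 1, model_inputs['spectrogram'].shape[-1] * 256)
    if 'discrete_spectrogram' in model_inputs:
        return (batch_size, 1, model_inputs['discrete_spectrogram'].shape[-1] * 1024)
    raise NotImplementedError

# e.g. a 160-frame spectrogram -> 40960 output samples,
#      a 40-token discrete spectrogram -> 40960 as well.
```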