Go back to vanilla flavor of diffusion

James Betker 2021-10-17 17:32:46 -06:00
parent 23da073037
commit d016a2fbad
4 changed files with 25 additions and 32 deletions

File 1 of 4 (UnsupervisedAudioDataset):

@@ -141,12 +141,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
 if __name__ == '__main__':
     params = {
         'mode': 'unsupervised_audio',
-        'path': ['Z:\\split\\cleaned\\books0', 'Z:\\split\\cleaned\\books2'],
-        'cache_path': 'E:\\audio\\remote-cache.pth',
+        'path': ['\\\\192.168.5.3\\rtx3080_audio_y\\split\\books2', '\\\\192.168.5.3\\rtx3080_audio\\split\\books1', '\\\\192.168.5.3\\rtx3080_audio\\split\\cleaned-2'],
+        'cache_path': 'E:\\audio\\remote-cache2.pth',
         'sampling_rate': 22050,
-        'pad_to_seconds': 5,
+        'pad_to_samples': 40960,
         'phase': 'train',
-        'n_workers': 4,
+        'n_workers': 1,
         'batch_size': 16,
         'extra_samples': 4,
     }
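The swap from 'pad_to_seconds': 5 to 'pad_to_samples': 40960 pins clip length in samples rather than seconds: 40960 = 4096 * 10, so clip lengths stay divisible by the 4096-sample unit the test script trims to, and at 22050 Hz a clip is roughly 1.86 s. A minimal sketch of that padding rule (pad_or_truncate is a hypothetical helper, not repo code):

import torch
import torch.nn.functional as F

def pad_or_truncate(clip: torch.Tensor, n_samples: int = 40960) -> torch.Tensor:
    # Hypothetical helper: force a (1, T) mono clip to exactly n_samples.
    # 40960 = 4096 * 10 keeps the length divisible by the U-net's strides;
    # at 22050 Hz this is about 1.86 seconds of audio.
    if clip.shape[-1] >= n_samples:
        return clip[..., :n_samples]
    return F.pad(clip, (0, n_samples - clip.shape[-1]))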
@@ -156,8 +156,7 @@ if __name__ == '__main__':
     dl = create_dataloader(ds, params)
     i = 0
     for b in tqdm(dl):
-        for b_ in range(16):
-            pass
-            #torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate)
-            #torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate)
-            #i += 1
+        for b_ in range(b['clip'].shape[0]):
+            #pass
+            torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
+        i += 1

File 2 of 4 (DiffusionVocoderWithRef):

@@ -13,17 +13,10 @@ from utils.util import get_mask_from_lengths
 class DiscreteSpectrogramConditioningBlock(nn.Module):
     def __init__(self, dvae_channels, channels):
         super().__init__()
-        self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1)
-        self.intg = nn.Sequential(
-            normalization(channels*2),
-            nn.SiLU(),
-            nn.Conv1d(channels*2, channels*2, kernel_size=1),
-            normalization(channels*2),
-            nn.SiLU(),
-            nn.Conv1d(channels*2, channels, kernel_size=3, padding=1),
-            normalization(channels),
-            nn.SiLU(),
-            zero_module(nn.Conv1d(channels, channels, kernel_size=1)))
+        self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
+                                  normalization(channels),
+                                  nn.SiLU(),
+                                  nn.Conv1d(channels, channels, kernel_size=3))
 
     """
     Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
@@ -34,11 +27,9 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
     def forward(self, x, dvae_in):
         b, c, S = x.shape
         _, q, N = dvae_in.shape
-        emb = self.emb(dvae_in)
+        emb = self.intg(dvae_in)
         emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
-        together = torch.cat([x, emb], dim=1)
-        together = self.intg(together)
-        return together + x
+        return torch.cat([x, emb], dim=1)
 
 
 class DiffusionVocoderWithRef(nn.Module):
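The conditioning block is now a plain embed-and-concatenate: project the dvae codes, resize them to the signal length, and concatenate onto x. Concatenation doubles the channel count, which is what the `ch *= 2` added two hunks below accounts for. A self-contained sketch under the assumption that normalization() is a GroupNorm-style helper (the stand-ins below are not repo code):

import torch
import torch.nn as nn

def normalization(channels):
    # Stand-in for the repo's helper; assumed GroupNorm-like.
    return nn.GroupNorm(32, channels)

class CondBlock(nn.Module):
    def __init__(self, dvae_channels, channels):
        super().__init__()
        self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
                                  normalization(channels),
                                  nn.SiLU(),
                                  nn.Conv1d(channels, channels, kernel_size=3))

    def forward(self, x, dvae_in):
        emb = self.intg(dvae_in)  # (B, channels, N-2); k=3 with no padding trims 2 frames
        emb = nn.functional.interpolate(emb, size=(x.shape[-1],), mode='nearest')
        return torch.cat([x, emb], dim=1)  # 2*channels out, hence `ch *= 2` below

x = torch.randn(2, 32, 40960)     # signal at the conditioned U-net level
codes = torch.randn(2, 512, 160)  # dvae spectrogram codes
assert CondBlock(512, 32)(x, codes).shape == (2, 64, 40960)

The nearest-neighbor interpolate hides the 2-frame shrink from the unpadded kernel_size=3 conv, since the embedding is resized to x's length regardless.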
@@ -81,10 +72,10 @@ class DiffusionVocoderWithRef(nn.Module):
             dropout=0,
             # res           1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
             channel_mult=  (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
-            num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
+            num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
             # spec_cond:    1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
             # attn:         0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
-            spectrogram_conditioning_resolutions=(1,8,64,512),
+            spectrogram_conditioning_resolutions=(512,),
             attention_resolutions=(512,1024,2048),
             conv_resample=True,
             dims=1,
@@ -149,6 +140,7 @@ class DiffusionVocoderWithRef(nn.Module):
         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in spectrogram_conditioning_resolutions:
                 self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
+                ch *= 2
 
             for _ in range(num_blocks):
                 layers = [
@@ -340,7 +332,7 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
 if __name__ == '__main__':
     clip = torch.randn(2, 1, 40960)
     #spec = torch.randint(8192, (2, 40,))
-    spec = torch.randn(8,512,160)
+    spec = torch.randn(2,512,160)
     cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
     model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
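The smoke-test shapes encode the frame-to-sample ratio the inference injector depends on; a quick arithmetic check (not repo code):

clip_samples, spec_frames = 40960, 160
assert clip_samples // spec_frames == 256  # one spectrogram frame per 256 output samples
# matches the `spectrogram.shape[-1] * 256` output-shape rule in the injector below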

File 3 of 4 (diffusion vocoder test script):

@@ -14,7 +14,7 @@ from torchvision.transforms import ToTensor
 import utils
 import utils.options as option
 import utils.util as util
-from data.audio.wavfile_dataset import load_audio_from_wav
+from models.tacotron2.taco_utils import load_wav_to_torch
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from data import create_dataset, create_dataloader
 from tqdm import tqdm
@@ -55,7 +55,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-17.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -70,7 +70,9 @@ if __name__ == "__main__":
     # Load test image
     if audio_mode:
-        im = load_audio_from_wav(opt['image'], opt['sample_rate'])
+        im, sr = load_wav_to_torch(opt['image'])
+        assert sr == 22050
+        im = im.unsqueeze(0)
         im = im[:, :(im.shape[1]//4096)*4096]
     else:
         im = ToTensor()(Image.open(opt['image'])) * 2 - 1
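load_wav_to_torch is assumed to return a 1-D float tensor plus its sample rate, so the new lines assert 22050 Hz, add a channel dimension, and trim to a multiple of 4096 so every U-net stride divides the length. An equivalent standalone sketch using torchaudio instead of the repo helper:

import torchaudio

def load_clip(path, expected_sr=22050, multiple=4096):
    wav, sr = torchaudio.load(path)            # (channels, T) float tensor
    assert sr == expected_sr
    wav = wav[:1]                              # keep a single mono channel
    T = (wav.shape[1] // multiple) * multiple  # largest multiple of 4096 that fits
    return wav[:, :T]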

File 4 of 4 (GaussianDiffusionInferenceInjector):

@@ -96,9 +96,6 @@ class GaussianDiffusionInferenceInjector(Injector):
         self.use_ema_model = opt_get(opt, ['use_ema'], False)
         self.noise_style = opt_get(opt, ['noise_type'], 'random')  # 'zero', 'fixed' or 'random'
-        self.model_fn = opt_get(opt, ['model_function'], None)
-        self.model_fn = None if self.model_fn is None else getattr(self.generator, self.model_fn)
 
     def forward(self, state):
         if self.use_ema_model:
             gen = self.env['emas'][self.opt['generator']]
@@ -114,6 +111,9 @@ class GaussianDiffusionInferenceInjector(Injector):
         elif 'spectrogram' in model_inputs.keys():
             output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1]*256)
             dev = model_inputs['spectrogram'].device
+        elif 'discrete_spectrogram' in model_inputs.keys():
+            output_shape = (self.output_batch_size, 1, model_inputs['discrete_spectrogram'].shape[-1]*1024)
+            dev = model_inputs['discrete_spectrogram'].device
         else:
             raise NotImplementedError
         noise = None
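The new branch gives discrete codes their own length rule: one code spans 1024 output samples versus 256 per spectrogram frame, i.e. the code grid is 4x coarser. A paraphrase of the dispatch as a standalone sketch (not the full injector):

import torch

def infer_output_shape(model_inputs, batch_size):
    # Map each conditioning key to the number of audio samples to generate.
    if 'spectrogram' in model_inputs:
        return (batch_size, 1, model_inputs['spectrogram'].shape[-1] * 256)
    if 'discrete_spectrogram' in model_inputs:
        return (batch_size, 1, model_inputs['discrete_spectrogram'].shape[-1] * 1024)
    raise NotImplementedError

assert infer_output_shape({'discrete_spectrogram': torch.zeros(1, 8192, 40)}, 1) == (1, 1, 40960)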
@@ -123,7 +123,7 @@ class GaussianDiffusionInferenceInjector(Injector):
             if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
                 self.fixed_noise = torch.randn(output_shape, device=dev)
             noise = self.fixed_noise
-        gen = self.sampling_fn(self.model_fn, output_shape, noise=noise, model_kwargs=model_inputs, progress=True, device=dev)
+        gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True, device=dev)
         if self.undo_n1_to_1:
             gen = (gen + 1) / 2
         return {self.output: gen}
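With the model_fn plumbing removed, the sampler always receives the generator module itself and invokes its default forward. A toy illustration of that call contract, assuming sampling_fn is p_sample_loop-like (every name below is a stand-in, not repo code):

import torch
import torch.nn as nn

class ToyDenoiser(nn.Module):
    def forward(self, x, t, **kwargs):
        return torch.zeros_like(x)  # stand-in for a real noise prediction

def sampling_fn(model, shape, noise=None, model_kwargs=None, progress=False, device='cpu'):
    # p_sample_loop-like stub: repeatedly calls model(x, t, **model_kwargs).
    x = noise if noise is not None else torch.randn(shape, device=device)
    for t in reversed(range(4)):
        x = x - model(x, torch.full((shape[0],), t), **(model_kwargs or {}))
    return x

out = sampling_fn(ToyDenoiser(), (2, 1, 1024))  # the nn.Module is passed directly, not a bound method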