James Betker 2021-12-11 08:17:26 -07:00
parent d610540ce5
commit 5a664aa56e
7 changed files with 7 additions and 53 deletions

View File

@@ -646,22 +646,6 @@ class UNetModel(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
         )
 
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-        self.output_blocks.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-        self.output_blocks.apply(convert_module_to_f32)
-
     def forward(self, x, timesteps, y=None):
         """
         Apply the model to an input batch.
@@ -891,20 +875,6 @@ class EncoderUNetModel(nn.Module):
         else:
             raise NotImplementedError(f"Unexpected {pool} pooling")
 
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-
     def forward(self, x, timesteps):
         """
         Apply the model to an input batch.

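For reference, the deleted convert_to_fp16/convert_to_fp32 methods delegate to module-level helpers that cast only the model torso (input, middle, and output blocks), leaving pieces such as the time embedding in full precision. A minimal sketch of those helpers, following the fp16_util conventions of OpenAI's guided-diffusion codebase that these models derive from; only the helper names appear in the diff, so the bodies below are an assumption:

    import torch.nn as nn

    def convert_module_to_f16(l):
        # Cast conv weights and biases to float16; other layer types are untouched.
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

    def convert_module_to_f32(l):
        # Inverse of convert_module_to_f16: cast conv parameters back to float32.
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            l.weight.data = l.weight.data.float()
            if l.bias is not None:
                l.bias.data = l.bias.data.float()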
View File

@@ -281,22 +281,6 @@ class DiffusionVocoderWithRef(nn.Module):
             del p.DO_NOT_TRAIN
             p.requires_grad = True
 
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-        self.output_blocks.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-        self.output_blocks.apply(convert_module_to_f32)
-
     def forward(self, x, timesteps, spectrogram, conditioning_input=None):
         """
         Apply the model to an input batch.

View File

@@ -15,7 +15,7 @@ def wav_to_mel(wav):
     """
     Converts an audio clip into a MEL tensor that the vocoder, DVAE and GptTts models use whenever a MEL is called for.
     """
-    return TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'normalize': True},{})({'wav': wav})['mel']
+    return TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel'},{})({'wav': wav})['mel']
 
 
 def convert_mel_to_codes(dvae_model, mel):

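Since wav_to_mel is the shared entry point for producing MELs, a hedged usage sketch of the updated helper; load_audio, the clip path, and the expected waveform shape are illustrative assumptions:

    # Hypothetical usage of wav_to_mel after this change.
    wav = load_audio('example.wav', 22050)  # assumed (1, N) float waveform
    mel = wav_to_mel(wav)  # 'normalize' now falls back to the injector's default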
View File

@@ -26,7 +26,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae.yml')
     parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
-    parser.add_argument('-diffusion_model_path', type=str, help='Path to the diffusion model checkpoint.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae\\models\\6200_generator_ema.pth')
+    parser.add_argument('-diffusion_model_path', type=str, help='Path to the diffusion model checkpoint.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae_full\\models\\200_generator_ema.pth')
     parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
     parser.add_argument('-input_file', type=str, help='Path to the input audio file.', default='Z:\\clips\\books1\\3_dchha04 Romancing The Tribes\\00036.wav')
     parser.add_argument('-cond', type=str, help='Path to the conditioning input audio file.', default=None)

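This block defines the script's CLI; a hypothetical invocation for illustration (the script's filename is not shown in the diff, so script.py stands in for it, and the clip path is assumed):

    python script.py -opt X:\dlas\experiments\train_diffusion_vocoder_with_cond_new_dvae.yml -input_file Z:\clips\example.wav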
View File

@@ -49,10 +49,7 @@ def forward_pass(model, data, output_dir, spacing, audio_mode):
 def load_image(path, audio_mode):
     # Load test image
     if audio_mode:
-        im = load_audio(path, 22050)
-        padding_needed = ((im.shape[1]//8192)+1)*8192-im.shape[1]
-        im = torch.nn.functional.pad(im, (0, padding_needed))
-        im = im[:, :(im.shape[1]//8192)*8192].unsqueeze(0)
+        im = load_audio(path, 22050).unsqueeze(0)
     else:
         im = ToTensor()(Image.open(path)) * 2 - 1
         _, h, w = im.shape

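Worth noting about the removed branch: it padded the waveform up to the next multiple of 8192 samples, after which the slice on the following line was a no-op (the padded length is already an exact multiple of 8192). A minimal sketch of that alignment step in isolation, should it be needed again; the function name is hypothetical:

    import torch.nn.functional as F

    def pad_to_block(wav, block=8192):
        # Pad the sample dimension up to the next multiple of `block`.
        # Like the removed code, this pads a full extra block when the
        # length is already aligned.
        needed = (wav.shape[1] // block + 1) * block - wav.shape[1]
        return F.pad(wav, (0, needed))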
View File

@@ -286,7 +286,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/validate_lrdvae_proper.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_bench.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

View File

@@ -627,6 +627,9 @@ def test_torch_mel_injector():
     inj = TorchMelSpectrogramInjector({'in': 'in', 'out': 'out'}, {})
     f = inj({'in': a.unsqueeze(0)})['out']
     plot_spectrogram(f[0])
+    inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {})
+    t = inj({'in': a.unsqueeze(0)})['out']
+    plot_spectrogram(t[0])
     print('Pause')
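The three added lines render a MelSpectrogramInjector spectrogram alongside the TorchMelSpectrogramInjector one for visual comparison. A hedged numeric follow-up, assuming both injectors return tensors of the same shape:

    # Hypothetical quantitative comparison between the two mel pipelines.
    assert f.shape == t.shape, (f.shape, t.shape)
    print('mean abs diff:', (f - t).abs().mean().item())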