From b4269af61b9cf9b040dacbfbde8113289374a391 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 27 May 2022 11:44:27 -0600 Subject: [PATCH] fix circular deps --- codes/trainer/injectors/audio_injectors.py | 8 ++++++-- codes/utils/music_utils.py | 5 ++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index 578c37e3..fb0a734d 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -4,7 +4,6 @@ import torch import torch.nn.functional as F import torchaudio -from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat from trainer.inject import Injector from utils.music_utils import get_music_codegen from utils.util import opt_get, load_model_from_config, pad_or_truncate @@ -250,6 +249,7 @@ class ConditioningLatentDistributionDivergenceInjector(Injector): also_load_savepoint=False, load_path=pretrained_path).eval() self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{}) else: + from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, layer_drop=0, unconditioned_percentage=0).eval() @@ -369,4 +369,8 @@ class NormalizeMelInjector(Injector): def forward(self, state): mel = state[self.input] with torch.no_grad(): - return {self.output: normalize_mel(mel)} \ No newline at end of file + return {self.output: normalize_mel(mel)} + + +if __name__ == '__main__': + print('hi') \ No newline at end of file diff --git a/codes/utils/music_utils.py b/codes/utils/music_utils.py index 24d34e84..6b4ca673 100644 --- a/codes/utils/music_utils.py +++ b/codes/utils/music_utils.py @@ -1,10 +1,8 @@ import torch -from models.audio.mel2vec import ContrastiveTrainingWrapper -from models.audio.music.unet_diffusion_waveform_gen_simple import DiffusionWaveformGen - def get_mel2wav_model(): + from models.audio.music.unet_diffusion_waveform_gen_simple import DiffusionWaveformGen model = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32, channel_mult=[1,2,3,4,4], num_res_blocks=[3,3,2,2,1], token_conditioning_resolutions=[1,4,16], dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0) @@ -13,6 +11,7 @@ def get_mel2wav_model(): return model def get_music_codegen(): + from models.audio.mel2vec import ContrastiveTrainingWrapper model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0, mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,