fix circular deps

James Betker 2022-05-27 11:44:27 -06:00
parent 34ee1d0bc3
commit b4269af61b
2 changed files with 8 additions and 5 deletions
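
The change breaks an import-time cycle by deferring imports: the module-level imports of the model classes are removed and re-imported inside the function or branch that actually constructs them, so the import runs at call time, after both sides of the cycle have finished loading. What follows is a self-contained sketch of the failure mode and of this fix; the module names mod_a and mod_b are hypothetical stand-ins, not files from this repository.

# Sketch of a circular import and the function-local import fix this commit
# applies. mod_a and mod_b are hypothetical stand-ins, not repo files.
import importlib
import os
import sys
import tempfile
import textwrap

def write_module(dirname, name, source):
    with open(os.path.join(dirname, name + '.py'), 'w') as f:
        f.write(textwrap.dedent(source))

with tempfile.TemporaryDirectory() as d:
    sys.path.insert(0, d)

    # Broken: each module imports the other at module level.
    write_module(d, 'mod_a', '''
        from mod_b import make_model      # triggers mod_b while mod_a is still loading
        def get_injector():
            return make_model()
    ''')
    write_module(d, 'mod_b', '''
        from mod_a import get_injector    # fails: mod_a is only partially initialized
        def make_model():
            return 'model'
    ''')
    try:
        import mod_a
    except ImportError as e:
        print('circular import:', e)

    # Fixed, mirroring this commit: defer one side of the cycle into the
    # function that needs it, so it resolves at call time, not import time.
    write_module(d, 'mod_b', '''
        def make_model():
            from mod_a import get_injector    # both modules are fully loaded by now
            return 'model'
    ''')
    importlib.invalidate_caches()
    import mod_a
    print(mod_a.get_injector())               # -> model

    sys.path.remove(d)

In the first changed file below, the edit has exactly this shape: the module-level DiffusionTtsFlat import is dropped and re-imported inside the else: branch that instantiates the model.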

@@ -4,7 +4,6 @@ import torch
 import torch.nn.functional as F
 import torchaudio
-from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
 from trainer.inject import Injector
 from utils.music_utils import get_music_codegen
 from utils.util import opt_get, load_model_from_config, pad_or_truncate
@@ -250,6 +249,7 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
                                                           also_load_savepoint=False, load_path=pretrained_path).eval()
             self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
         else:
+            from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
             self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                                     in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
                                                     num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()
@@ -369,4 +369,8 @@ class NormalizeMelInjector(Injector):
     def forward(self, state):
         mel = state[self.input]
         with torch.no_grad():
-            return {self.output: normalize_mel(mel)}
+            return {self.output: normalize_mel(mel)}
+
+
+if __name__ == '__main__':
+    print('hi')

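The second changed file gets the same treatment: the ContrastiveTrainingWrapper and DiffusionWaveformGen imports move from module scope into the factory functions that construct those models. Deferring them this way costs essentially nothing after the first call, because Python caches imported modules in sys.modules; a rough illustration, using the standard-library json module purely as a stand-in for the model imports:

import sys
import timeit

def build():
    # Stand-in for get_music_codegen(): the import only does real work once.
    from json import dumps
    return dumps

build()
print('json' in sys.modules)                 # True: the module is cached after the first call
print(timeit.timeit(build, number=100_000))  # later calls resolve the import from the cache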
@@ -1,10 +1,8 @@
 import torch
-from models.audio.mel2vec import ContrastiveTrainingWrapper
-from models.audio.music.unet_diffusion_waveform_gen_simple import DiffusionWaveformGen
 def get_mel2wav_model():
+    from models.audio.music.unet_diffusion_waveform_gen_simple import DiffusionWaveformGen
     model = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32, channel_mult=[1,2,3,4,4],
                                  num_res_blocks=[3,3,2,2,1], token_conditioning_resolutions=[1,4,16], dropout=0, kernel_size=3, scale_factor=2,
                                  time_embed_dim_multiplier=4, unconditioned_percentage=0)
@@ -13,6 +11,7 @@ def get_mel2wav_model():
     return model
 def get_music_codegen():
+    from models.audio.mel2vec import ContrastiveTrainingWrapper
     model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
                                        mask_time_prob=0,
                                        mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,