fix circular deps
This commit is contained in:
parent
34ee1d0bc3
commit
b4269af61b
|
@ -4,7 +4,6 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
|
||||
from trainer.inject import Injector
|
||||
from utils.music_utils import get_music_codegen
|
||||
from utils.util import opt_get, load_model_from_config, pad_or_truncate
|
||||
|
@ -250,6 +249,7 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
|
|||
also_load_savepoint=False, load_path=pretrained_path).eval()
|
||||
self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
|
||||
else:
|
||||
from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
|
||||
self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
|
||||
num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()
|
||||
|
@ -369,4 +369,8 @@ class NormalizeMelInjector(Injector):
|
|||
def forward(self, state):
|
||||
mel = state[self.input]
|
||||
with torch.no_grad():
|
||||
return {self.output: normalize_mel(mel)}
|
||||
return {self.output: normalize_mel(mel)}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('hi')
|
|
@ -1,10 +1,8 @@
|
|||
import torch
|
||||
|
||||
from models.audio.mel2vec import ContrastiveTrainingWrapper
|
||||
from models.audio.music.unet_diffusion_waveform_gen_simple import DiffusionWaveformGen
|
||||
|
||||
|
||||
def get_mel2wav_model():
|
||||
from models.audio.music.unet_diffusion_waveform_gen_simple import DiffusionWaveformGen
|
||||
model = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32, channel_mult=[1,2,3,4,4],
|
||||
num_res_blocks=[3,3,2,2,1], token_conditioning_resolutions=[1,4,16], dropout=0, kernel_size=3, scale_factor=2,
|
||||
time_embed_dim_multiplier=4, unconditioned_percentage=0)
|
||||
|
@ -13,6 +11,7 @@ def get_mel2wav_model():
|
|||
return model
|
||||
|
||||
def get_music_codegen():
|
||||
from models.audio.mel2vec import ContrastiveTrainingWrapper
|
||||
model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
|
||||
mask_time_prob=0,
|
||||
mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
|
||||
|
|
Loading…
Reference in New Issue
Block a user