forked from mrq/DL-Art-School
fix circular deps
This commit is contained in:
parent
34ee1d0bc3
commit
b4269af61b
|
@ -4,7 +4,6 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
|
|
||||||
from trainer.inject import Injector
|
from trainer.inject import Injector
|
||||||
from utils.music_utils import get_music_codegen
|
from utils.music_utils import get_music_codegen
|
||||||
from utils.util import opt_get, load_model_from_config, pad_or_truncate
|
from utils.util import opt_get, load_model_from_config, pad_or_truncate
|
||||||
|
@ -250,6 +249,7 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
|
||||||
also_load_savepoint=False, load_path=pretrained_path).eval()
|
also_load_savepoint=False, load_path=pretrained_path).eval()
|
||||||
self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
|
self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
|
||||||
else:
|
else:
|
||||||
|
from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
|
||||||
self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
||||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
|
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
|
||||||
num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()
|
num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()
|
||||||
|
@ -369,4 +369,8 @@ class NormalizeMelInjector(Injector):
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
mel = state[self.input]
|
mel = state[self.input]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return {self.output: normalize_mel(mel)}
|
return {self.output: normalize_mel(mel)}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print('hi')
|
|
@ -1,10 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from models.audio.mel2vec import ContrastiveTrainingWrapper
|
|
||||||
from models.audio.music.unet_diffusion_waveform_gen_simple import DiffusionWaveformGen
|
|
||||||
|
|
||||||
|
|
||||||
def get_mel2wav_model():
|
def get_mel2wav_model():
|
||||||
|
from models.audio.music.unet_diffusion_waveform_gen_simple import DiffusionWaveformGen
|
||||||
model = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32, channel_mult=[1,2,3,4,4],
|
model = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32, channel_mult=[1,2,3,4,4],
|
||||||
num_res_blocks=[3,3,2,2,1], token_conditioning_resolutions=[1,4,16], dropout=0, kernel_size=3, scale_factor=2,
|
num_res_blocks=[3,3,2,2,1], token_conditioning_resolutions=[1,4,16], dropout=0, kernel_size=3, scale_factor=2,
|
||||||
time_embed_dim_multiplier=4, unconditioned_percentage=0)
|
time_embed_dim_multiplier=4, unconditioned_percentage=0)
|
||||||
|
@ -13,6 +11,7 @@ def get_mel2wav_model():
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_music_codegen():
|
def get_music_codegen():
|
||||||
|
from models.audio.mel2vec import ContrastiveTrainingWrapper
|
||||||
model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
|
model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
|
||||||
mask_time_prob=0,
|
mask_time_prob=0,
|
||||||
mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
|
mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user