This commit is contained in:
James Betker 2022-05-29 22:32:25 -06:00
parent 2e72fddaeb
commit eab1162d2b


@@ -6,7 +6,7 @@ from models.diffusion.nn import timestep_embedding, normalization, zero_module,
 from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
 from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
 from trainer.networks import register_model
-from utils.util import checkpoint
+from utils.util import checkpoint, print_network
 def is_latent(t):
@@ -248,7 +248,8 @@ if __name__ == '__main__':
     ts = torch.LongTensor([600, 600])
     clvp = torch.randn(2,768)
     type = torch.LongTensor([0,1])
-    model = TransformerDiffusionTTS(model_channels=768, unconditioned_percentage=.5, in_groups=8, prenet_channels=512, block_channels=384)
+    model = TransformerDiffusionTTS(model_channels=3072, unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024)
+    print_network(model)
     o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type)
     #o = model(clip, ts, aligned_sequence, cond, aligned_latent)
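
For context, the newly imported print_network is presumably a small parameter-counting helper, which explains calling it right after scaling the test configuration up to model_channels=3072. A minimal sketch of such a helper, assuming it only needs to count and print parameters (the actual utils.util.print_network may report more detail):

import torch.nn as nn

def print_network(model: nn.Module):
    # Hypothetical stand-in for utils.util.print_network:
    # sums parameter counts and prints trainable vs. total.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{model.__class__.__name__}: {trainable:,} trainable / {total:,} total parameters')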