forked from mrq/DL-Art-School
hmm..
This commit is contained in:
parent
2e72fddaeb
commit
eab1162d2b
|
@ -6,7 +6,7 @@ from models.diffusion.nn import timestep_embedding, normalization, zero_module,
|
|||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
from utils.util import checkpoint, print_network
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
|
@ -248,7 +248,8 @@ if __name__ == '__main__':
|
|||
ts = torch.LongTensor([600, 600])
|
||||
clvp = torch.randn(2,768)
|
||||
type = torch.LongTensor([0,1])
|
||||
model = TransformerDiffusionTTS(model_channels=768, unconditioned_percentage=.5, in_groups=8, prenet_channels=512, block_channels=384)
|
||||
model = TransformerDiffusionTTS(model_channels=3072, unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024)
|
||||
print_network(model)
|
||||
o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type)
|
||||
#o = model(clip, ts, aligned_sequence, cond, aligned_latent)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user