forked from mrq/DL-Art-School
hmm..
This commit is contained in:
parent
2e72fddaeb
commit
eab1162d2b
|
@ -6,7 +6,7 @@ from models.diffusion.nn import timestep_embedding, normalization, zero_module,
|
||||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
||||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint, print_network
|
||||||
|
|
||||||
|
|
||||||
def is_latent(t):
|
def is_latent(t):
|
||||||
|
@ -248,7 +248,8 @@ if __name__ == '__main__':
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
clvp = torch.randn(2,768)
|
clvp = torch.randn(2,768)
|
||||||
type = torch.LongTensor([0,1])
|
type = torch.LongTensor([0,1])
|
||||||
model = TransformerDiffusionTTS(model_channels=768, unconditioned_percentage=.5, in_groups=8, prenet_channels=512, block_channels=384)
|
model = TransformerDiffusionTTS(model_channels=3072, unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024)
|
||||||
|
print_network(model)
|
||||||
o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type)
|
o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type)
|
||||||
#o = model(clip, ts, aligned_sequence, cond, aligned_latent)
|
#o = model(clip, ts, aligned_sequence, cond, aligned_latent)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user