forked from mrq/DL-Art-School
more simplifications
This commit is contained in:
parent
f3f391b372
commit
57da6d0ddf
|
@ -91,6 +91,7 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
attn_dropout=dropout,
|
attn_dropout=dropout,
|
||||||
use_rmsnorm=True,
|
use_rmsnorm=True,
|
||||||
ff_glu=True,
|
ff_glu=True,
|
||||||
|
ff_mult=2,
|
||||||
rotary_pos_emb=True,
|
rotary_pos_emb=True,
|
||||||
)
|
)
|
||||||
))
|
))
|
||||||
|
@ -104,12 +105,13 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
attn_layers=TimestepEmbeddingAttentionLayers(
|
attn_layers=TimestepEmbeddingAttentionLayers(
|
||||||
dim=model_channels,
|
dim=model_channels,
|
||||||
timestep_dim=time_embed_dim,
|
timestep_dim=time_embed_dim,
|
||||||
depth=3,
|
depth=2,
|
||||||
heads=num_heads,
|
heads=num_heads,
|
||||||
ff_dropout=dropout,
|
ff_dropout=dropout,
|
||||||
attn_dropout=dropout,
|
attn_dropout=dropout,
|
||||||
use_rmsnorm=True,
|
use_rmsnorm=True,
|
||||||
ff_glu=True,
|
ff_glu=True,
|
||||||
|
ff_mult=2,
|
||||||
rotary_pos_emb=True,
|
rotary_pos_emb=True,
|
||||||
layerdrop_percent=0,
|
layerdrop_percent=0,
|
||||||
)
|
)
|
||||||
|
@ -130,6 +132,7 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
attn_dropout=dropout,
|
attn_dropout=dropout,
|
||||||
use_rmsnorm=True,
|
use_rmsnorm=True,
|
||||||
ff_glu=True,
|
ff_glu=True,
|
||||||
|
ff_mult=2,
|
||||||
rotary_pos_emb=True,
|
rotary_pos_emb=True,
|
||||||
layerdrop_percent=layer_drop,
|
layerdrop_percent=layer_drop,
|
||||||
zero_init_branch_output=True,
|
zero_init_branch_output=True,
|
||||||
|
|
|
@ -318,7 +318,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts9_mel_flat.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_gpt_tts_unified.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = option.parse(args.opt, is_train=True)
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user