forked from mrq/DL-Art-School
more simplifications
This commit is contained in:
parent
f3f391b372
commit
57da6d0ddf
|
@ -91,6 +91,7 @@ class DiffusionTtsFlat(nn.Module):
|
|||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=2,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
))
|
||||
|
@ -104,12 +105,13 @@ class DiffusionTtsFlat(nn.Module):
|
|||
attn_layers=TimestepEmbeddingAttentionLayers(
|
||||
dim=model_channels,
|
||||
timestep_dim=time_embed_dim,
|
||||
depth=3,
|
||||
depth=2,
|
||||
heads=num_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=2,
|
||||
rotary_pos_emb=True,
|
||||
layerdrop_percent=0,
|
||||
)
|
||||
|
@ -130,6 +132,7 @@ class DiffusionTtsFlat(nn.Module):
|
|||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=2,
|
||||
rotary_pos_emb=True,
|
||||
layerdrop_percent=layer_drop,
|
||||
zero_init_branch_output=True,
|
||||
|
|
|
@ -318,7 +318,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts9_mel_flat.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_gpt_tts_unified.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user