#### general settings
# Run-wide options for the experiment named `train_tacotron2_lj`.
name: train_tacotron2_lj
# NOTE(review): "tb" presumably stands for TensorBoard — confirm against the trainer.
use_tb_logger: true
# A single device, index 0.
gpu_ids: [0]
# NOTE(review): -1 looks like a "no explicit start step" sentinel — confirm.
start_step: -1
fp16: false
checkpointing_enabled: true
wandb: false

# Training data. The file list sits under an LJSpeech-1.1 directory on a
# local Windows drive, so the path is machine-specific and must be edited
# before running elsewhere.
datasets:
  train:
    name: lj
    n_workers: 1
    batch_size: 72
    mode: nv_tacotron
    # Single-quoted so every backslash is taken literally.
    path: 'E:\4k6k\datasets\audio\LJSpeech-1.1\ljs_audio_text_train_filelist.txt'

# Network definitions. `mel_gen` is the only network declared here; it is
# referenced by `steps.generator.training`, by the `mel` injector and by the
# `val` evaluator.
networks:
  mel_gen:
    type: generator
    which_model_G: nv_tacotron2
    # NOTE(review): presumably forwarded to the nv_tacotron2 model as
    # constructor arguments — confirm against the model definition.
    args:
      # Encoder
      encoder_kernel_size: 5
      encoder_n_convolutions: 3
      encoder_embedding_dim: 512
      # Decoder and prenet
      decoder_rnn_dim: 1024
      prenet_dim: 256
      max_decoder_steps: 1000
      # Attention
      attention_rnn_dim: 1024
      attention_dim: 128
      attention_location_n_filters: 32
      attention_location_kernel_size: 31
      # Postnet
      postnet_embedding_dim: 512
      postnet_kernel_size: 5
      postnet_n_convolutions: 5

#### path
# Checkpoint loading options. Only `strict_load` is active.
# NOTE(review): the two commented-out entries below name diffusion/UNet image
# checkpoints, not a Tacotron2 run — they look like leftovers from another
# config. Replace the file names before uncommenting either line.
path:
  #pretrain_model_generator: ../experiments/diffusion_unet_128_imageset_22000.pt
  strict_load: true
  #resume_state: ../experiments/train_imgset_unet_diffusion/training_state/54000.state

# Training steps. A single step, `generator`, trains the `mel_gen` network
# declared under `networks`.
steps:
  generator:
    training: mel_gen

    optimizer: adamw
    optimizer_params:
      # Explicit !!float tags: some YAML 1.1 loaders (e.g. PyYAML) do not
      # resolve dot-less exponent literals such as `1e-6` to floats and would
      # hand the consumer a string instead.
      lr: !!float 1.2e-3
      weight_decay: !!float 1e-6
      beta1: 0.9
      beta2: 0.9999
    clip_grad_eps: 1.0

    # The `mel` injector runs `mel_gen`. The names under `out` are the keys
    # the loss below (and `logger.visuals`) refer to.
    injectors:
      mel:
        type: generator
        generator: mel_gen
        in: [padded_text, input_lengths, padded_mel, output_lengths]
        out: [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
    losses:
      tacotron_loss:
        type: nv_tacotron2_loss
        weight: 1
        # `*_output_key` values match names in the injector's `out` list.
        # `padded_mel` is also an injector input; `padded_gate` is not
        # produced anywhere in this file.
        # NOTE(review): both targets presumably come from the dataset batch
        # — confirm.
        mel_target_key: padded_mel
        mel_output_key: mel_outputs
        mel_output_postnet_key: mel_outputs_postnet
        gate_target_key: padded_gate
        gate_output_key: gate_outputs

# Training schedule.
train:
  niter: 500000
  # NOTE(review): -1 presumably disables warmup — confirm against the trainer.
  warmup_iter: -1
  # NOTE(review): presumably splits each batch (datasets.train.batch_size is
  # 72, which divides evenly by 3) into chunks — confirm.
  mega_batch_factor: 3
  # Written with a leading zero: a bare `.999` is not a float under the strict
  # YAML 1.2 JSON schema and is flagged by common linters.
  ema_rate: 0.999
  # Same interval as `logger.save_checkpoint_freq`.
  val_freq: 500

  # Learning-rate schedule.
  # NOTE(review): assuming this maps to a standard multi-step scheduler, the
  # rate is multiplied by `lr_gamma` at each listed step — confirm.
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [50000, 100000, 150000]
  lr_gamma: 0.5

# Evaluation. One evaluator, `val`, targets `mel_gen` and reads its own
# dataset from the validation file list (same machine-specific LJSpeech-1.1
# directory as the training list).
eval:
  evaluators:
    val:
      type: mel
      for: mel_gen
      batch_size: 16
      dataset:
        mode: nv_tacotron
        # Single-quoted so every backslash is taken literally.
        path: 'E:\4k6k\datasets\audio\LJSpeech-1.1\ljs_audio_text_val_filelist.txt'


# Logging and checkpoint cadence.
# NOTE(review): the *_freq / *_rate values are presumably counted in training
# steps — confirm.
logger:
  print_freq: 30
  save_checkpoint_freq: 500
  # Both names also appear in the generator step: `mel_outputs` is an injector
  # output and `padded_mel` is the loss's mel target key.
  visuals: [mel_outputs, padded_mel]
  is_mel_spectrogram: true
  visual_debug_rate: 100