dataset:
  # Directories to read training data from. `${voice}` is a placeholder that
  # has to be filled in before this file is loaded (presumably by whatever
  # tool generates the final config -- confirm).
  training:
    - "./training/${voice}/valle/"
  # Directories to read noise data from.
  noise:
    - "./training/valle/data/Other/noise/"

  # Source text of a Python lambda, kept as a string, that maps a path `p` to
  # a speaker name using `p.parts`. Alternative that joins two path components:
  #   "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
  speaker_name_getter: "lambda p: p.parts[-3]"

  # HDF5 settings; `hdf5_name` and `hdf5_flag` presumably only matter when
  # `use_hdf5` is enabled -- verify against the loader.
  use_hdf5: false
  hdf5_name: data.h5
  hdf5_flag: r
  validate: true

  workers: 4
  cache: false

  # [min, max] bounds.
  phones_range: [4, 64]
  duration_range: [1.0, 8.0]

  # Prompt selection settings.
  random_utterance: 1.0
  max_prompts: 3
  prompt_duration: 3.0

  sample_type: path

  # Enabled tasks. Other values seen for this list:
  #   ["tts", "ns", "sr", "tse", "cse", "nse", "tts"]
  tasks_list: ["tts"]

models:
  _max_levels: 8
  # One entry per model to build; both entries share size, task count and
  # architecture and differ only in name and level counts.
  _models:
    - name: "ar"
      size: "full"
      resp_levels: 1
      prom_levels: 2
      tasks: 8
      arch_type: "retnet"

    - name: "nar"
      size: "full"
      resp_levels: 3
      prom_levels: 4
      tasks: 8
      arch_type: "retnet"


hyperparameters:
  # `${...}` values are placeholders filled in before loading. They are left
  # unquoted on purpose so that a substituted number is parsed as a number
  # (assumes plain textual substitution -- confirm).
  batch_size: ${batch_size}
  gradient_accumulation_steps: ${gradient_accumulation_size}
  gradient_clipping: 100

  optimizer: "AdamW"
  learning_rate: 1.0e-4

  # Empty string: no scheduler type is named here.
  scheduler_type: ""

evaluation:
  # `${...}` values are placeholders filled in before loading; keep them
  # unquoted so a substituted number stays numeric (assumes plain textual
  # substitution -- confirm).
  batch_size: ${batch_size}
  frequency: ${validation_rate}
  size: 16

  steps: 300
  # Sampling temperatures, one per model name ("ar" / "nar").
  ar_temperature: 0.95
  nar_temperature: 0.25

trainer:
  # `${...}` values are placeholders filled in before loading; keep them
  # unquoted so a substituted number stays numeric (assumes plain textual
  # substitution -- confirm).
  iterations: ${iterations}

  # Checkpoint saving / exporting behaviour.
  save_tag: step
  save_on_oom: true
  save_on_quit: true
  export_on_save: true
  export_on_quit: true
  save_frequency: ${save_rate}

  keep_last_checkpoints: 4

  aggressive_optimizations: false

  # Checkpoint loading behaviour; the commented-out keys are optional
  # overrides left disabled.
  load_state_dict: true
  # strict_loading: False
  # load_tag: "9500"
  # load_states: False
  # restart_step_count: True

  # `null` is YAML's "no value". A bare `None` is NOT a YAML null: it loads as
  # the string "None". Alternative value: "global_step".
  gc_mode: null

  weight_dtype: bfloat16

  backend: deepspeed
  # Settings for the `deepspeed` backend named above.
  deepspeed:
    zero_optimization_level: 2
    use_compression_training: true

inference:
  use_vocos: true
  normalize: false

  # Note: differs from `trainer.weight_dtype` (bfloat16).
  weight_dtype: float32

bitsandbytes:
  # Master switch is off; the remaining flags presumably only take effect
  # once `enabled` is true -- verify against the consumer.
  enabled: false
  injects: true
  linear: true
  embedding: true