# This is a wrapper around train.py that lets you train a set of models using a variety of training
# paradigms at once. It works by using the yielding mechanism built into train.py to advance each trainer
# one step at a time while keeping the underlying models synchronized.
#
# Note that this wrapper is **EXTREMELY** simple and makes no attempt to handle many edge cases. Some issues
# you should plan for:
# 1) Each trainer will have its own optimizer for the underlying model - even when the model is shared.
# 2) Each trainer will run validation and save model states according to its own schedule. Likewise:
# 3) Each trainer will load state params for the models it controls independently, regardless of whether or not
#    those models are shared. Your best bet is to have all models save state at the same time so that they all
#    load approximately the same state when restarted.
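#
# For reference, a minimal sketch of the master YAML this wrapper expects. The only key read here is
# `trainer_options`; each entry is a path to an ordinary train.py option file, exactly as you would pass
# to train.py's -opt. The file names below are purely illustrative:
#
# trainer_options:
#   - ../options/train_model_a.yml
#   - ../options/train_model_b.yml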
import argparse

import torch
import yaml

import train
import utils.options as option
from utils.util import OrderedYaml


def main(master_opt, launcher):
    trainers = []
    all_networks = {}
    shared_networks = []
    if launcher != 'none':
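        # NCCL is the standard torch.distributed backend for multi-GPU training; it requires CUDA devices.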
        train.init_dist('nccl')
    for i, sub_opt in enumerate(master_opt['trainer_options']):
        sub_opt_parsed = option.parse(sub_opt, is_train=True)
        trainer = train.Trainer()

        #### distributed training settings
        if launcher == 'none':  # distributed training disabled
            sub_opt_parsed['dist'] = False
            trainer.rank = -1
            print('Disabled distributed training.')
        else:
            sub_opt_parsed['dist'] = True
            trainer.world_size = torch.distributed.get_world_size()
            trainer.rank = torch.distributed.get_rank()

        trainer.init(sub_opt_parsed, launcher, all_networks)
        train_gen = trainer.create_training_generator(i)
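        # Contract assumed here (matching how train.py's create_training_generator is used in this file):
        # the first next() yields the fully-initialized model, and each subsequent next() runs one training step.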
        model = next(train_gen)
        # Harvest this trainer's networks so later trainers can share them, and record any network that
        # appears under more than one trainer. The networks may or may not be wrapped in
        # (Distributed)DataParallel, so unwrap defensively.
        for k, v in model.networks.items():
            if k in all_networks.keys() and k not in shared_networks:
                shared_networks.append(k)
            all_networks[k] = getattr(v, 'module', v)
        trainers.append(train_gen)
    print("Networks being shared by trainers: ", shared_networks)

    # Now, simply "iterate" through the trainers to accomplish training.
    while True:
        for trainer in trainers:
            next(trainer)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    Loader, Dumper = OrderedYaml()
    with open(args.opt, mode='r') as f:
        opt = yaml.load(f, Loader=Loader)
    main(opt, args.launcher)
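
# Example invocations (illustrative; substitute your actual script and option file names):
#   Single process:
#     python this_script.py -opt master_opt.yml
#   Distributed, one process per GPU (torch.distributed.launch supplies --local_rank):
#     python -m torch.distributed.launch --nproc_per_node=2 this_script.py -opt master_opt.yml --launcher pytorch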