From 5037752059c54d6c46c32a5f1f65b7e857bbe0dd Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 28 Feb 2023 22:13:21 +0000 Subject: [PATCH] oops --- src/utils.py | 16 +++++++++++----- src/webui.py | 8 ++++++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/utils.py b/src/utils.py index 7896f1d..33dca73 100755 --- a/src/utils.py +++ b/src/utils.py @@ -600,8 +600,10 @@ class TrainingState(): self.it_rate = rate except Exception as e: pass - - message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}] [ETA: {self.eta_hhmmss}]' + last_loss = "" + if len(self.losses) > 0: + last_loss = f'[Loss @ it. {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}]' + message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]' if lapsed: self.epoch = self.epoch + 1 @@ -1180,9 +1182,13 @@ def setup_args(): if os.path.isfile('./config/exec.json'): with open(f'./config/exec.json', 'r', encoding="utf-8") as f: - overrides = json.load(f) - for k in overrides: - default_arguments[k] = overrides[k] + try: + overrides = json.load(f) + for k in overrides: + default_arguments[k] = overrides[k] + except Exception as e: + print(e) + pass parser = argparse.ArgumentParser() parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere") diff --git a/src/webui.py b/src/webui.py index aed451d..573050d 100755 --- a/src/webui.py +++ b/src/webui.py @@ -206,6 +206,7 @@ def import_training_settings_proxy( voice ): outdir = f'./training/{voice}-finetune/' in_config_path = f"{indir}/train.yaml" + out_config_path = None out_configs = [] if os.path.isdir(outdir): out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ]) @@ -240,7 +241,10 @@ def import_training_settings_proxy( voice ): save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration) statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES - resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ]) + resumes = [] + resume_path = None + if os.path.isdir(statedir): + resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ]) if len(resumes) > 0: resume_path = f'{statedir}/{resumes[-1]}.state' @@ -490,7 +494,7 @@ def setup_gradio(): with gr.Row(): refresh_dataset_list = gr.Button(value="Refresh Dataset List") - import_dataset_button = gr.Button(value="Import Dataset") + import_dataset_button = gr.Button(value="Reuse/Import Dataset") with gr.Column(): save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) with gr.Row():