This commit is contained in:
mrq 2023-02-28 22:13:21 +00:00
parent 787b44807a
commit 5037752059
2 changed files with 17 additions and 7 deletions

View File

@ -600,8 +600,10 @@ class TrainingState():
self.it_rate = rate self.it_rate = rate
except Exception as e: except Exception as e:
pass pass
last_loss = ""
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}] [ETA: {self.eta_hhmmss}]' if len(self.losses) > 0:
last_loss = f'[Loss @ it. {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}]'
message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]'
if lapsed: if lapsed:
self.epoch = self.epoch + 1 self.epoch = self.epoch + 1
@ -1180,9 +1182,13 @@ def setup_args():
if os.path.isfile('./config/exec.json'): if os.path.isfile('./config/exec.json'):
with open(f'./config/exec.json', 'r', encoding="utf-8") as f: with open(f'./config/exec.json', 'r', encoding="utf-8") as f:
overrides = json.load(f) try:
for k in overrides: overrides = json.load(f)
default_arguments[k] = overrides[k] for k in overrides:
default_arguments[k] = overrides[k]
except Exception as e:
print(e)
pass
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere") parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")

View File

@ -206,6 +206,7 @@ def import_training_settings_proxy( voice ):
outdir = f'./training/{voice}-finetune/' outdir = f'./training/{voice}-finetune/'
in_config_path = f"{indir}/train.yaml" in_config_path = f"{indir}/train.yaml"
out_config_path = None
out_configs = [] out_configs = []
if os.path.isdir(outdir): if os.path.isdir(outdir):
out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ]) out_configs = sorted([d[:-5] for d in os.listdir(outdir) if d[-5:] == ".yaml" ])
@ -240,7 +241,10 @@ def import_training_settings_proxy( voice ):
save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration) save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration)
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ]) resumes = []
resume_path = None
if os.path.isdir(statedir):
resumes = sorted([int(d[:-6]) for d in os.listdir(statedir) if d[-6:] == ".state" ])
if len(resumes) > 0: if len(resumes) > 0:
resume_path = f'{statedir}/{resumes[-1]}.state' resume_path = f'{statedir}/{resumes[-1]}.state'
@ -490,7 +494,7 @@ def setup_gradio():
with gr.Row(): with gr.Row():
refresh_dataset_list = gr.Button(value="Refresh Dataset List") refresh_dataset_list = gr.Button(value="Refresh Dataset List")
import_dataset_button = gr.Button(value="Import Dataset") import_dataset_button = gr.Button(value="Reuse/Import Dataset")
with gr.Column(): with gr.Column():
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
with gr.Row(): with gr.Row():