From fefc7aba0388d061a2db37cd74ebef7a9172439b Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 21 Feb 2023 22:13:30 +0000 Subject: [PATCH] oops --- src/utils.py | 12 ++++++++++-- src/webui.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/utils.py b/src/utils.py index eafe425..d7efdaa 100755 --- a/src/utils.py +++ b/src/utils.py @@ -153,7 +153,10 @@ def generate( # clamp it down for the insane users who want this # it would be wiser to enforce the sample size to the batch size, but this is what the user wants - if num_autoregressive_samples < args.sample_batch_size: + sample_batch_size = args.sample_batch_size + if not sample_batch_size: + sample_batch_size = tts.autoregressive_batch_size + if num_autoregressive_samples < sample_batch_size: settings['sample_batch_size'] = num_autoregressive_samples if delimiter is None: @@ -493,7 +496,12 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress progress(it / float(its), f'[{it}/{its}] {it_rate} Training... {status}') - if line.find('INFO: [epoch:') >= 0: + # it also says Start training from epoch, so it might be better to do that + if line.find('INFO: Resuming training from epoch:') >= 0: + match = re.findall(r'iter: ([\d,]+)', line) + if match and len(match) > 0: + it = int(match[0].replace(",", "")) + elif line.find('INFO: [epoch:') >= 0: # easily rip out our stats... match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) if match and len(match) > 0: diff --git a/src/webui.py b/src/webui.py index a6ab00c..e4edc55 100755 --- a/src/webui.py +++ b/src/webui.py @@ -288,7 +288,7 @@ def setup_gradio(): seed = gr.Number(value=0, precision=0, label="Seed") preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value" ) - num_autoregressive_samples = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Samples") + num_autoregressive_samples = gr.Slider(value=128, minimum=2, maximum=512, step=1, label="Samples") diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations") temperature = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")