forked from mrq/ai-voice-cloning
fixes
This commit is contained in:
parent
bbc2d26289
commit
b6f7aa6264
39
src/utils.py
39
src/utils.py
|
@ -415,7 +415,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
|
||||||
its = config['train']['niter']
|
its = config['train']['niter']
|
||||||
|
|
||||||
checkpoint = 0
|
checkpoint = 0
|
||||||
checkpoints = config['logger']['save_checkpoint_freq'] / its
|
checkpoints = its / config['logger']['save_checkpoint_freq']
|
||||||
|
|
||||||
buffer_size = 8
|
buffer_size = 8
|
||||||
open_state = False
|
open_state = False
|
||||||
|
@ -443,40 +443,35 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
|
||||||
elif progress is not None:
|
elif progress is not None:
|
||||||
if line.find(' 0%|') == 0:
|
if line.find(' 0%|') == 0:
|
||||||
open_state = True
|
open_state = True
|
||||||
it_time_start = time.time()
|
|
||||||
elif line.find('100%|') == 0 and open_state:
|
elif line.find('100%|') == 0 and open_state:
|
||||||
it_time_end = time.time()
|
|
||||||
open_state = False
|
open_state = False
|
||||||
it = it + 1
|
it = it + 1
|
||||||
|
|
||||||
|
it_time_end = time.time()
|
||||||
it_time_delta = it_time_end-it_time_start
|
it_time_delta = it_time_end-it_time_start
|
||||||
it_rate = f'[{"{:.3f}".format(it_time_delta)}s/it]' if it_time_delta >= 1 and it_time_delta != 0 else f'[{"{:.3f}".format(1/it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here
|
it_time_start = time.time()
|
||||||
|
it_rate = f'[{"{:.3f}".format(it_time_delta)}s/it]' if it_time_delta >= 1 else f'[{"{:.3f}".format(1/it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here
|
||||||
|
|
||||||
progress(it / float(its), f'[{it}/{its}] {it_rate} Training... {status}')
|
progress(it / float(its), f'[{it}/{its}] {it_rate} Training... {status}')
|
||||||
|
|
||||||
# try because I haven't tested this yet
|
|
||||||
try:
|
|
||||||
if line.find('INFO: [epoch:') >= 0:
|
|
||||||
# easily rip out our stats...
|
|
||||||
match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
|
|
||||||
if match and len(match) > 0:
|
|
||||||
for k, v in match:
|
|
||||||
info[k] = float(v)
|
|
||||||
|
|
||||||
# ...and returns our loss rate
|
|
||||||
# it would be nice for losses to be shown at every step
|
|
||||||
if 'loss_gpt_total' in info:
|
|
||||||
status = f"Total loss at step {int(info['step'])}: {info['loss_gpt_total']}"
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if line.find('Saving models and training states') >= 0:
|
if line.find('INFO: [epoch:') >= 0:
|
||||||
|
# easily rip out our stats...
|
||||||
|
match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
|
||||||
|
if match and len(match) > 0:
|
||||||
|
for k, v in match:
|
||||||
|
info[k] = float(v)
|
||||||
|
|
||||||
|
# ...and returns our loss rate
|
||||||
|
# it would be nice for losses to be shown at every step
|
||||||
|
if 'loss_gpt_total' in info:
|
||||||
|
status = f"Total loss at step {int(info['step'])}: {info['loss_gpt_total']}"
|
||||||
|
elif line.find('Saving models and training states') >= 0:
|
||||||
checkpoint = checkpoint + 1
|
checkpoint = checkpoint + 1
|
||||||
progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')
|
progress(checkpoint / float(checkpoints), f'[{checkpoint}/{checkpoints}] Saving checkpoint...')
|
||||||
|
|
||||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||||
|
|
||||||
if verbose:
|
if verbose or not training_started:
|
||||||
yield "".join(buffer[-buffer_size:])
|
yield "".join(buffer[-buffer_size:])
|
||||||
|
|
||||||
training_process.stdout.close()
|
training_process.stdout.close()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user