This commit is contained in:
mrq 2023-03-08 00:51:51 +00:00
parent e862169e7f
commit a7e0dc9127

View File

@ -802,7 +802,7 @@ class TrainingState():
if line.find('INFO: [epoch:') >= 0: if line.find('INFO: [epoch:') >= 0:
info_line = line.split("INFO:")[-1] info_line = line.split("INFO:")[-1]
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ': nan' in info_line and not self.self.nan_detected: if ': nan' in info_line and not self.nan_detected:
self.nan_detected = self.it self.nan_detected = self.it
# easily rip out our stats... # easily rip out our stats...
@ -986,11 +986,6 @@ class TrainingState():
message, message,
) )
def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
global training_state
if training_state and training_state.process:
return "Training already in progress"
try: try:
import altair as alt import altair as alt
alt.data_transformers.enable('default', max_rows=None) alt.data_transformers.enable('default', max_rows=None)
@ -998,6 +993,12 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0,
print(e) print(e)
pass pass
def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
global training_state
if training_state and training_state.process:
return "Training already in progress"
# ensure we have the dvae.pth # ensure we have the dvae.pth
get_model_path('dvae.pth') get_model_path('dvae.pth')