forked from mrq/ai-voice-cloning
Fixed Keep X Previous States
This commit is contained in:
parent 9e320a34c8
commit 29b3d1ae1d
@@ -752,8 +752,8 @@ class TrainingState():
 		models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
 		states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
-		remove_models = models[:-2]
-		remove_states = states[:-2]
+		remove_models = models[:-keep]
+		remove_states = states[:-keep]

 		for d in remove_models:
 			path = f'{self.dataset_dir}/models/{d}_gpt.pth'
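The hardcoded `[:-2]` slices removed everything except the two newest checkpoints no matter what the user configured; the fix slices with the `keep` argument instead. A small illustration of the slice semantics, with made-up iteration numbers:

# Illustration only (values are made up): with an ascending list of
# checkpoint iterations, [:-keep] selects the entries to delete and
# [-keep:] the ones that survive.
models = [500, 1000, 1500, 2000]
keep = 3
print(models[:-keep])  # [500]               -> pruned
print(models[-keep:])  # [1000, 1500, 2000]  -> kept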
@@ -898,6 +898,9 @@ class TrainingState():
 		if should_return:
 			result = "".join(self.buffer) if not self.training_started else message

+			if keep_x_past_checkpoints > 0:
+				self.cleanup_old(keep=keep_x_past_checkpoints)
+
 			return (
 				result,
 				percent,
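For context, a minimal sketch of how the pruning and its new call-site guard fit together, assuming the checkpoint layout shown in the hunks above (`<dataset_dir>/models/<iteration>_gpt.pth` and `<dataset_dir>/training_state/<iteration>.state`); the helper name cleanup_old_checkpoints and the example paths are illustrative, not the repository's actual API (the real method is TrainingState.cleanup_old):

# Hedged sketch, not the repository's implementation.
import os

def cleanup_old_checkpoints(dataset_dir, keep=2):
	# Treat non-positive values as "never prune", matching the call-site guard.
	if keep <= 0:
		return
	# Nothing to do if the expected directories are absent.
	if not os.path.isdir(f'{dataset_dir}/models/') or not os.path.isdir(f'{dataset_dir}/training_state/'):
		return

	# Checkpoint iterations, sorted oldest-first, parsed from the file names.
	models = sorted(int(d[:-8]) for d in os.listdir(f'{dataset_dir}/models/') if d.endswith("_gpt.pth"))
	states = sorted(int(d[:-6]) for d in os.listdir(f'{dataset_dir}/training_state/') if d.endswith(".state"))

	# Delete everything except the newest `keep` checkpoints; the bug was a
	# hardcoded [:-2] slice that ignored the configured value.
	for it in models[:-keep]:
		os.remove(f'{dataset_dir}/models/{it}_gpt.pth')
	for it in states[:-keep]:
		os.remove(f'{dataset_dir}/training_state/{it}.state')

# Call-site pattern from the second hunk: only prune when a positive
# "keep X past checkpoints" value was configured.
keep_x_past_checkpoints = 3  # illustrative value
if keep_x_past_checkpoints > 0:
	cleanup_old_checkpoints("./training/my-voice", keep=keep_x_past_checkpoints)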