forgot to add 'bs / gradient accum < 2' clamp validation logic

This commit is contained in:
mrq 2023-03-04 17:37:08 +00:00
parent df24827b9a
commit 1a9d159b2a
2 changed files with 65 additions and 29 deletions

View File

@ -506,6 +506,8 @@ class TrainingState():
with open(config_path, 'r') as file:
self.config = yaml.safe_load(file)
self.killed = False
self.dataset_dir = f"./training/{self.config['name']}/"
self.batch_size = self.config['datasets']['train']['batch_size']
self.dataset_path = self.config['datasets']['train']['path']
@ -527,7 +529,6 @@ class TrainingState():
self.training_started = False
self.info = {}
self.status = "..."
self.epoch_rate = ""
self.epoch_time_start = 0
@ -651,10 +652,12 @@ class TrainingState():
print("Removing", path)
os.remove(path)
def parse(self, line, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=None ):
def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ):
self.buffer.append(f'{line}')
should_return = False
percent = 0
message = None
# rip out iteration info
if not self.training_started:
@ -679,7 +682,7 @@ class TrainingState():
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
if match and len(match) > 0:
match = match[0]
percent = int(match[0])/100.0
per_cent = int(match[0])/100.0
progressbar = match[1]
step = int(match[2])
steps = int(match[3])
@ -698,15 +701,40 @@ class TrainingState():
self.it_time_end = time.time()
self.it_time_delta = self.it_time_end-self.it_time_start
self.it_time_start = time.time()
self.it_taken = self.it_taken + 1
try:
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
self.it_rate = rate
except Exception as e:
pass
last_loss = ""
metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
metric_step = ", ".join(metric_step)
metric_rate = []
if self.epoch_rate:
metric_rate.append(self.epoch_rate)
if self.it_rate:
metric_rate.append(self.it_rate)
metric_rate = ", ".