From a5f3adbdd7d9b8245f7782216ac48913660e6bb5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 15:37:24 +0700 Subject: [PATCH 1/2] Allow trailing comma in learning rate --- modules/textual_inversion/learn_schedule.py | 33 +++++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 3a736065..76e611b6 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -11,23 +11,30 @@ class LearnScheduleIterator: self.rates = [] self.it = 0 self.maxit = 0 - for i, pair in enumerate(pairs): - tmp = pair.split(':') - if len(tmp) == 2: - step = int(tmp[1]) - if step > cur_step: - self.rates.append((float(tmp[0]), min(step, max_steps))) - self.maxit += 1 - if step > max_steps: + try: + for i, pair in enumerate(pairs): + if not pair.strip(): + continue + tmp = pair.split(':') + if len(tmp) == 2: + step = int(tmp[1]) + if step > cur_step: + self.rates.append((float(tmp[0]), min(step, max_steps))) + self.maxit += 1 + if step > max_steps: + return + elif step == -1: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 return - elif step == -1: + else: self.rates.append((float(tmp[0]), max_steps)) self.maxit += 1 return - else: - self.rates.append((float(tmp[0]), max_steps)) - self.maxit += 1 - return + assert self.rates + except (ValueError, AssertionError): + raise Exception("Invalid learning rate schedule") + def __iter__(self): return self From ef4c94e1cfe66299227aa95a28c2380d21cb1600 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 15:42:51 +0700 Subject: [PATCH 2/2] Improve lr schedule error message --- modules/textual_inversion/learn_schedule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 76e611b6..dd0c0ad1 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -4,7 +4,7 @@ import tqdm class LearnScheduleIterator: def __init__(self, learn_rate, max_steps, cur_step=0): """ - specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000 + specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000 """ pairs = learn_rate.split(',') @@ -33,7 +33,7 @@ class LearnScheduleIterator: return assert self.rates except (ValueError, AssertionError): - raise Exception("Invalid learning rate schedule") + raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') def __iter__(self):