class LearnSchedule:
    """Iterator over a piecewise-constant learning-rate schedule.

    The schedule is a comma-separated string of ``rate:end_step`` entries,
    e.g. ``"0.005:100, 1e-3:1000, 1e-5"``: use 0.005 until step 100, then
    1e-3 until step 1000, then 1e-5 until ``max_steps``. An entry with no
    ``:end_step`` part, or with an end step of ``-1``, runs until
    ``max_steps`` and ends the schedule.

    Iterating yields ``(rate, end_step)`` tuples. The object is its own
    iterator, so it can be consumed only once.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """Parse *learn_rate* into ``(rate, end_step)`` segments.

        Args:
            learn_rate: schedule string as described in the class docstring.
                Blank entries (e.g. from a trailing comma) are ignored.
            max_steps: last training step; end steps are clamped to it, and
                parsing stops at the first segment that reaches it.
            cur_step: step training starts/resumes from; segments whose end
                step is not past it are dropped.

        Raises:
            ValueError: if an entry has a non-numeric rate or step, or
                contains more than one ``:``.
        """
        self.rates = self._parse(learn_rate, max_steps, cur_step)
        self.it = 0
        self.maxit = len(self.rates)

    @staticmethod
    def _parse(learn_rate, max_steps, cur_step):
        """Return the list of ``(rate, end_step)`` segments still ahead of *cur_step*."""
        rates = []
        for entry in learn_rate.split(','):
            entry = entry.strip()
            if not entry:
                # Tolerate empty entries such as the one left by a trailing comma.
                continue
            parts = entry.split(':')
            if len(parts) > 2:
                raise ValueError(
                    f"Invalid learning rate schedule entry {entry!r}: "
                    "expected 'rate' or 'rate:end_step'"
                )
            try:
                rate = float(parts[0])
                step = int(parts[1]) if len(parts) == 2 else None
            except ValueError as e:
                raise ValueError(
                    f"Invalid learning rate schedule entry {entry!r}: "
                    "expected 'rate' or 'rate:end_step'"
                ) from e

            if step is None or step == -1:
                # Open-ended segment: runs to the end and terminates the schedule.
                rates.append((rate, max_steps))
                break
            if step > cur_step:
                rates.append((rate, min(step, max_steps)))
                if step >= max_steps:
                    # Any later segment would start at or after max_steps and
                    # therefore never run.
                    break
        return rates

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(rate, end_step)`` segment, or raise StopIteration."""
        if self.it >= self.maxit:
            raise StopIteration
        self.it += 1
        return self.rates[self.it - 1]