2021-06-03 03:47:32 +00:00
from abc import ABC , abstractmethod
import numpy as np
import torch as th
import torch . distributed as dist
2022-03-04 18:50:50 +00:00
from torch import distributed
2021-06-03 03:47:32 +00:00
def create_named_schedule_sampler(name, diffusion):
    """
    Build a ScheduleSampler chosen by name from the built-in samplers.

    :param name: sampler identifier; either "uniform" or "loss-second-moment".
    :param diffusion: the diffusion object the sampler draws timesteps for.
    :return: a freshly constructed sampler bound to ``diffusion``.
    :raises NotImplementedError: if ``name`` is not a recognised sampler.
    """
    sampler_cls = None
    if name == "uniform":
        sampler_cls = UniformSampler
    elif name == "loss-second-moment":
        sampler_cls = LossSecondMomentResampler

    if sampler_cls is None:
        raise NotImplementedError(f"unknown schedule sampler: {name}")
    return sampler_cls(diffusion)
class ScheduleSampler(ABC):
    """
    Abstract distribution over diffusion timesteps, used to reduce the
    variance of the training objective.

    The default ``sample`` performs unbiased importance sampling: every drawn
    timestep carries a weight that cancels its non-uniform sampling
    probability, so the objective's mean matches uniform sampling. Subclasses
    may override ``sample`` to reweight the drawn terms differently, which
    does change the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Return a numpy array holding one weight per diffusion step.

        The weights need not sum to one, but every entry must be positive.
        """

    def sample(self, batch_size, device):
        """
        Draw a batch of timesteps by importance sampling.

        :param batch_size: how many timesteps to draw.
        :param device: torch device the returned tensors are moved to.
        :return: a tuple ``(timesteps, weights)``:
                 - timesteps: a long tensor of timestep indices.
                 - weights: a float tensor used to rescale the resulting losses.
        """
        raw = self.weights()
        probs = raw / np.sum(raw)
        num_steps = len(probs)
        chosen = np.random.choice(num_steps, size=(batch_size,), p=probs)
        # Inverse-probability weights relative to the uniform 1/num_steps, so a
        # uniform distribution produces all-ones weights.
        scale = 1 / (num_steps * probs[chosen])
        timesteps = th.from_numpy(chosen).long().to(device)
        return timesteps, th.from_numpy(scale).float().to(device)
class UniformSampler(ScheduleSampler):
    """
    Sampler that gives every diffusion step the same weight. Under the
    inherited ``sample`` this draws timesteps uniformly with unit weights.
    """

    def __init__(self, diffusion):
        self.diffusion = diffusion
        step_count = diffusion.num_timesteps
        self._weights = np.ones([step_count])

    def weights(self):
        """Return the constant, all-ones per-step weight array."""
        return self._weights
2022-03-04 17:40:14 +00:00
class DeterministicSampler:
    """
    Returns the same equally spread-out sampling schedule every time it is
    called. Automatically handles distributed cases by sharing the load across
    all entities. reset() should be called once a full batch is completed.

    The full schedule has ``sampling_range`` evenly spaced fractions of the
    diffusion length. Rank ``r`` out of ``world_size`` owns the fractions
    ``r, r + world_size, r + 2 * world_size, ...``.
    """

    def __init__(self, diffusion, sampling_range, env):
        """
        :param diffusion: object exposing ``num_timesteps``.
        :param sampling_range: number of evenly spaced points in the full
            (all-ranks) schedule.
        :param env: mapping with a ``'rank'`` entry; negative ranks are
            clamped to 0.
        """
        super().__init__()
        self.timesteps = diffusion.num_timesteps
        self.rank = max(env['rank'], 0)
        if distributed.is_initialized():
            self.world_size = distributed.get_world_size()
        else:
            self.world_size = 1
        # The sampling range gets spread out across multiple distributed entities.
        rnge = th.arange(self.rank, sampling_range, step=self.world_size).float() / sampling_range
        self.indices = (rnge * self.timesteps).long()
        # Start at the beginning of the schedule. Previously only reset() set
        # this, so calling sample() before the first reset() raised AttributeError.
        self.counter = 0

    def sample(self, batch_size, device):
        """
        Iteratively samples across the deterministic range specified by the initialization params.

        :param batch_size: number of timesteps to return; must not exceed the
            number of indices owned by this rank.
        :param device: torch device for the returned tensors.
        :return: ``(indices, weights)``, where every weight is 1.0.
        :raises ValueError: if ``batch_size`` exceeds this rank's schedule length.
        """
        num_indices = self.indices.shape[0]
        # A batch may consume the whole schedule (sampling range == batch size),
        # so the bound is inclusive.
        if batch_size > num_indices:
            raise ValueError(
                f"batch_size ({batch_size}) exceeds the {num_indices} indices "
                f"assigned to this rank by DeterministicSampler"
            )
        if self.counter + batch_size > num_indices:
            print(
                f"Diffusion DeterministicSampler; Likely error. {self.counter}, {batch_size}, "
                f"{self.indices.shape[0]}. Did you forget to set the sampling range to your "
                "batch size for the deterministic sampler?"
            )
            self.counter = 0  # Recover by setting to 0.
        indices = self.indices[self.counter:self.counter + batch_size].to(device)
        self.counter += batch_size
        weights = th.ones_like(indices).float()
        return indices, weights

    def reset(self):
        """Rewind the schedule so the next sample() starts from the first index."""
        self.counter = 0
2022-03-04 17:40:14 +00:00
2021-06-03 03:47:32 +00:00
class LossAwareSampler(ScheduleSampler):
    """
    ScheduleSampler whose weights are updated from observed training losses.
    """

    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting. When torch.distributed is not
        initialized, the local batch is passed directly to
        update_with_all_losses.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        local_losses = local_losses.detach()
        if not (dist.is_available() and dist.is_initialized()):
            # Single-process run: nothing to synchronize.
            self.update_with_all_losses(local_ts.tolist(), local_losses.tolist())
            return

        device = local_ts.device
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=device),
        )
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        # all_gather requires every rank to contribute a tensor of identical
        # shape, so both the receive buffers AND the local send tensors are
        # zero-padded to max_bs; the padding is sliced off below via batch_sizes.
        timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
        dist.all_gather(timestep_batches, self._pad_to(local_ts, max_bs))
        dist.all_gather(loss_batches, self._pad_to(local_losses, max_bs))
        timesteps = [
            t for batch, bs in zip(timestep_batches, batch_sizes) for t in batch[:bs].tolist()
        ]
        losses = [
            loss for batch, bs in zip(loss_batches, batch_sizes) for loss in batch[:bs].tolist()
        ]
        self.update_with_all_losses(timesteps, losses)

    @staticmethod
    def _pad_to(tensor, length):
        """Return the 1D ``tensor`` zero-padded at the end to ``length`` elements."""
        padded = th.zeros(length).to(tensor)
        padded[: len(tensor)] = tensor
        return padded

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """
class LossSecondMomentResampler(LossAwareSampler):
    """
    Loss-aware sampler that weights each timestep by the root-mean-square of
    its last ``history_per_term`` recorded losses, mixed with a small uniform
    component of total mass ``uniform_prob``. Weights stay uniform until every
    timestep has a full loss history.
    """

    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        # Ring-like buffer of recent losses, one row per timestep.
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # Number of losses recorded so far per timestep, capped at history_per_term.
        # np.int was removed in NumPy 1.24; use an explicit fixed-width integer.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)

    def weights(self):
        """Return per-timestep sampling weights (uniform until warmed up)."""
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        # Blend in a uniform floor so no timestep's probability drops to zero.
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        """Record each (timestep, loss) pair, evicting the oldest loss once full."""
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        """True once every timestep has history_per_term recorded losses."""
        return (self._loss_counts == self.history_per_term).all()