# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)
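

# Shape sketch for the two helpers above (illustrative only, not part of the original
# file): for an input flattened to (B, C, L), `_sum_ft` reduces over the batch and
# spatial dimensions to a per-channel vector, and `_unsqueeze_ft` re-expands such a
# vector so it broadcasts back against the (B, C, L) input.
#
#     >>> x = torch.randn(4, 3, 8)              # (B, C, L)
#     >>> _sum_ft(x).shape                      # per-channel sums
#     torch.Size([3])
#     >>> _unsqueeze_ft(_sum_ft(x)).shape       # broadcastable against x
#     torch.Size([1, 3, 1])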


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            out = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)
            if gain is not None:
                out = out + gain
            if bias is not None:
                out = out + bias
            return out

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        # print(input_shape)
        input = input.view(input.size(0), input.size(1), -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        # print('it begins')
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
        # if self._parallel_id == 0:
        #     # print('here')
        #     sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        # else:
        #     # print('there')
        #     sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # print('how2')
        # num = sum_size
        # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu())))
        # Fix the graph
        # sum = (sum.detach() - input_sum.detach()) + input_sum
        # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum

        # mean = sum / num
        # var = ssum / num - mean ** 2
        # # var = (ssum - mean * sum) / num
        # inv_std = torch.rsqrt(var + self.eps)

        # Compute the output.
        if gain is not None:
            # print('gaining')
            # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1)
            # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1)
            # output = input * scale - shift
            output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)
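
    # Concrete shape walk-through for the synchronized path above (hypothetical numbers,
    # not from the original file): with a global batch of shape (8, 64, 32, 32) split
    # over 2 GPUs, each replica sees a (4, 64, 32, 32) chunk, flattens it to (4, 64, 1024),
    # and contributes per-channel `input_sum` / `input_ssum` vectors of shape (64,) with
    # `sum_size = 4 * 1024 = 4096`. After the reduce-and-broadcast, `mean` and `inv_std`
    # are therefore computed from 2 * 4096 = 8192 elements per channel, instead of the
    # 4096 a plain per-device BatchNorm would use.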

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
        # print('a')
        # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size)
        # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device))
        # print('b')
        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
            # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)
        # return mean, bias_var.clamp(self.eps) ** -0.5
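
    # Worked check of the statistics above, with made-up per-channel numbers (not from
    # the original file): for size = 4, sum_ = 8.0 and ssum = 20.0 we get mean = 2.0,
    # sumvar = 20.0 - 8.0 * 2.0 = 4.0, bias_var = 4.0 / 4 = 1.0 (used for normalization,
    # i.e. Var[x] = E[x^2] - E[x]^2) and unbias_var = 4.0 / 3 ≈ 1.33 (used only for the
    # running_var update, matching PyTorch's BatchNorm convention).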


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
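

# A minimal usage sketch (not part of the original file): the synchronization only
# happens when the module is replicated by a data-parallel wrapper that invokes
# `__data_parallel_replicate__`, e.g. the companion `DataParallelWithCallback` /
# `patch_replication_callback` utilities from the same repository. Assuming that
# wrapper is available, usage looks roughly like:
#
#     sync_bn = SynchronizedBatchNorm2d(64)
#     net = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, padding=1), sync_bn)
#     net = DataParallelWithCallback(net, device_ids=[0, 1]).cuda()
#     out = net(torch.randn(8, 3, 32, 32).cuda())  # statistics synced across GPU 0 and 1
#
# With a plain `nn.DataParallel`, the replication callback is never triggered and each
# replica falls back to per-device statistics.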


# From comm.py
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res
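

# A minimal illustration of the one-to-one handshake above (hypothetical usage, not part
# of the original file): one thread blocks in `get()` on the condition variable until
# another thread calls `put(value)`; the stored result is consumed (reset to None) so the
# same pair of endpoints can be reused on the next round.
#
#     fut = FutureResult()
#     threading.Thread(target=lambda: fut.put(42)).start()
#     assert fut.get() == 42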


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
    call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message
    to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
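

# Hypothetical self-test sketch (not part of the original file): emulate one master and
# one slave "device" on plain Python threads to show the SyncMaster handshake. The
# callback simply echoes each (identifier, message) pair back, whereas
# `_SynchronizedBatchNorm._data_parallel_master` above would reduce the statistics first.
if __name__ == '__main__':
    def _echo_callback(intermediates):
        # intermediates is [(0, master_msg), (identifier, slave_msg), ...]
        return [(identifier, msg) for identifier, msg in intermediates]

    master = SyncMaster(_echo_callback)
    pipe = master.register_slave(1)  # would normally be called during replication

    # The slave blocks in run_slave() until the master has collected and answered.
    slave_thread = threading.Thread(
        target=lambda: print('slave received:', pipe.run_slave('stats-from-device-1')))
    slave_thread.start()

    print('master received:', master.run_master('stats-from-device-0'))
    slave_thread.join()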