// Modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
||
|
|
||
|
#include <torch/extension.h>
|
||
|
|
||
|
#include <cmath>
|
||
|
#include <vector>
|
||
|
|
||
|
// Deformable im2col. The definition lives outside this file (presumably the
// companion CUDA source; TODO confirm).
//
// Judging by its callers here: data_im / data_offset are one chunk of
// `parallel_imgs` images (callers pass im2col_step), and data_col is a
// preallocated (channels * ksize_h * ksize_w,
// parallel_imgs * out_h * out_w) buffer that this function is expected to
// fill in place.
void deformable_im2col(const at::Tensor data_im,
                       const at::Tensor data_offset,
                       const int channels,
                       const int height,
                       const int width,
                       const int ksize_h,
                       const int ksize_w,
                       const int pad_h,
                       const int pad_w,
                       const int stride_h,
                       const int stride_w,
                       const int dilation_h,
                       const int dilation_w,
                       const int parallel_imgs,
                       const int deformable_group,
                       at::Tensor data_col);
|
||
|
|
||
|
// Deformable col2im. The definition lives outside this file (presumably the
// companion CUDA source; TODO confirm).
//
// Callers here pass the column gradient (weight^T * gradOutput) as data_col
// and one chunk of the input gradient as grad_im. Whether it overwrites or
// accumulates into grad_im is decided by the definition; verify there.
void deformable_col2im(const at::Tensor data_col,
                       const at::Tensor data_offset,
                       const int channels,
                       const int height,
                       const int width,
                       const int ksize_h,
                       const int ksize_w,
                       const int pad_h,
                       const int pad_w,
                       const int stride_h,
                       const int stride_w,
                       const int dilation_h,
                       const int dilation_w,
                       const int parallel_imgs,
                       const int deformable_group,
                       at::Tensor grad_im);
|
||
|
|
||
|
void deformable_col2im_coord(
|
||
|
const at::Tensor data_col, const at::Tensor data_im,
|
||
|
const at::Tensor data_offset, const int channels, const int height,
|
||
|
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
||
|
const int pad_w, const int stride_h, const int stride_w,
|
||
|
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
||
|
const int deformable_group, at::Tensor grad_offset);
|
||
|
|
||
|
void modulated_deformable_im2col_cuda(
|
||
|
const at::Tensor data_im, const at::Tensor data_offset,
|
||
|
const at::Tensor data_mask, const int batch_size, const int channels,
|
||
|
const int height_im, const int width_im, const int height_col,
|
||
|
const int width_col, const int kernel_h, const int kenerl_w,
|
||
|
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||
|
const int dilation_h, const int dilation_w, const int deformable_group,
|
||
|
at::Tensor data_col);
|
||
|
|
||
|
void modulated_deformable_col2im_cuda(
|
||
|
const at::Tensor data_col, const at::Tensor data_offset,
|
||
|
const at::Tensor data_mask, const int batch_size, const int channels,
|
||
|
const int height_im, const int width_im, const int height_col,
|
||
|
const int width_col, const int kernel_h, const int kenerl_w,
|
||
|
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||
|
const int dilation_h, const int dilation_w, const int deformable_group,
|
||
|
at::Tensor grad_im);
|
||
|
|
||
|
void modulated_deformable_col2im_coord_cuda(
|
||
|
const at::Tensor data_col, const at::Tensor data_im,
|
||
|
const at::Tensor data_offset, const at::Tensor data_mask,
|
||
|
const int batch_size, const int channels, const int height_im,
|
||
|
const int width_im, const int height_col, const int width_col,
|
||
|
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
||
|
const int stride_h, const int stride_w, const int dilation_h,
|
||
|
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
||
|
at::Tensor grad_mask);
|
||
|
|
||
|
// Validates the arguments shared by the deform_conv_* entry points before
// any reshaping happens. Throws via AT_CHECK / AT_ERROR on the first failed
// condition.
//
// input      : 3D (C, H, W) or 4D (N, C, H, W).
// offset     : its trailing three dims must be
//              (deformable_group * 2 * kH * kW, outH, outW). A leading batch
//              dim is allowed; its size is checked by the callers.
// gradOutput : may be nullptr. Otherwise its channel/height/width dims
//              (indexed like input's) must be (nOutputPlane, outH, outW).
// weight     : contiguous 4D (nOutputPlane, nInputPlane / group, kH, kW).
//
// outH / outW = (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1.
//
// Note: AT_CHECK / AT_ERROR concatenate their trailing arguments into the
// message; they do not interpret printf-style specifiers.
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
                 int padW, int dilationH, int dilationW, int group,
                 int deformable_group) {
  AT_CHECK(weight.ndimension() == 4,
           "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
           "but got: ",
           weight.ndimension());

  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  AT_CHECK(kW > 0 && kH > 0,
           "kernel size should be greater than zero, but got kH: ", kH,
           " kW: ", kW);

  AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
           "kernel size should be consistent with weight, but got kH: ", kH,
           " kW: ", kW, " weight.size(2): ", weight.size(2),
           ", weight.size(3): ", weight.size(3));

  AT_CHECK(dW > 0 && dH > 0,
           "stride should be greater than zero, but got dH: ", dH, " dW: ",
           dW);

  AT_CHECK(dilationW > 0 && dilationH > 0,
           "dilation should be greater than 0, but got dilationH: ", dilationH,
           " dilationW: ", dilationW);

  // Guards the divisions / modulo operations below and in the callers.
  AT_CHECK(group > 0 && deformable_group > 0,
           "group and deformable_group should be greater than 0, but got "
           "group: ",
           group, " deformable_group: ", deformable_group);

  int ndim = input.ndimension();
  AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: ",
           ndim);

  // Position of the channel / height / width dims of input (and gradOutput).
  // A 4D tensor carries a leading batch dimension; a 3D one does not.
  int dimf = (ndim == 4) ? 1 : 0;
  int dimh = dimf + 1;
  int dimw = dimf + 2;

  long nInputPlane = weight.size(1) * group;
  long inputHeight = input.size(dimh);
  long inputWidth = input.size(dimw);
  long nOutputPlane = weight.size(0);
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

  AT_CHECK(nInputPlane % deformable_group == 0,
           "input channels must divide deformable group size");

  // The callers split weight / outputs into `group` equal parts via view().
  AT_CHECK(nOutputPlane % group == 0, "output channels (", nOutputPlane,
           ") must be divisible by group (", group, ")");

  if (outputWidth < 1 || outputHeight < 1)
    AT_ERROR("Given input size: (", nInputPlane, " x ", inputHeight, " x ",
             inputWidth, "). Calculated output size: (", nOutputPlane, " x ",
             outputHeight, " x ", outputWidth,
             "). Output size is too small");

  // Channels sit at dimf, which is 0 (not 1) for an unbatched 3D input.
  AT_CHECK(input.size(dimf) == nInputPlane,
           "invalid number of input planes, expected: ", nInputPlane,
           ", but got: ", input.size(dimf));

  AT_CHECK((inputHeight >= kH && inputWidth >= kW),
           "input image is smaller than kernel");

  // Read offset's channel / height / width from its trailing dims so that
  // both batched (4D) and unbatched (3D) offsets are checked correctly.
  AT_CHECK(offset.ndimension() >= 3,
           "offset tensor must have at least 3 dimensions, but got: ",
           offset.ndimension());

  AT_CHECK((offset.size(-2) == outputHeight && offset.size(-1) == outputWidth),
           "invalid spatial size of offset, expected height: ", outputHeight,
           " width: ", outputWidth, ", but got height: ", offset.size(-2),
           " width: ", offset.size(-1));

  AT_CHECK((offset.size(-3) == deformable_group * 2 * kH * kW),
           "invalid number of channels of offset");

  if (gradOutput != nullptr) {
    AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
             "invalid number of gradOutput planes, expected: ", nOutputPlane,
             ", but got: ", gradOutput->size(dimf));

    AT_CHECK((gradOutput->size(dimh) == outputHeight &&
              gradOutput->size(dimw) == outputWidth),
             "invalid size of gradOutput, expected height: ", outputHeight,
             " width: ", outputWidth, " , but got height: ",
             gradOutput->size(dimh), " width: ", gradOutput->size(dimw));
  }
}
|
||
|
|
||
|
// Forward pass of deformable convolution.
//
// input   : (N, C, H, W), or unbatched (C, H, W).
// weight  : (nOutputPlane, C / group, kH, kW).
// offset  : (N, deformable_group * 2 * kH * kW, outH, outW), or unbatched.
// output  : must be viewable as (N, nOutputPlane, outH, outW); the result is
//           copied into its storage.
// columns : scratch argument kept for interface compatibility; a local
//           buffer is allocated here.
// ones    : unused by this function (kept for interface compatibility).
//
// The batch is processed im2col_step images at a time, so N must be
// divisible by im2col_step. Returns 1.
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                             at::Tensor offset, at::Tensor output,
                             at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH,
                             int dilationW, int dilationH, int group,
                             int deformable_group, int im2col_step) {
  shape_check(input, offset, nullptr, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
  weight = weight.contiguous();

  int batch = 1;
  if (input.ndimension() == 3) {
    // Force batch. Use the out-of-place unsqueeze: contiguous() may have
    // returned the caller's own tensor, and unsqueeze_ would then change the
    // caller's shape.
    batch = 0;
    input = input.unsqueeze(0);
    offset = offset.unsqueeze(0);
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  AT_CHECK(im2col_step > 0 && batchSize % im2col_step == 0,
           "batch size (", batchSize, ") must be divisible by im2col_step (",
           im2col_step, ")");

  long nChunks = batchSize / im2col_step;

  output = output.view(
      {nChunks, im2col_step, nOutputPlane, outputHeight, outputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  input = input.view(
      {nChunks, im2col_step, nInputPlane, inputHeight, inputWidth});
  offset = offset.view({nChunks, im2col_step, deformable_group * 2 * kH * kW,
                        outputHeight, outputWidth});

  at::Tensor output_buffer = at::zeros(
      {nChunks, nOutputPlane, im2col_step * outputHeight, outputWidth},
      output.options());
  output_buffer = output_buffer.view(
      {output_buffer.size(0), group, output_buffer.size(1) / group,
       output_buffer.size(2), output_buffer.size(3)});

  // Per-group views of weight and columns, created once outside the loop.
  // The previous code re-split the already-split tensors on every chunk
  // without undoing the split, so view() failed from the second chunk on.
  at::Tensor weight_g = weight.view({group, weight.size(0) / group,
                                     weight.size(1), weight.size(2),
                                     weight.size(3)});
  at::Tensor columns_g =
      columns.view({group, columns.size(0) / group, columns.size(1)});

  for (int elt = 0; elt < nChunks; elt++) {
    // Fills `columns` in place; columns_g shares its storage.
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    for (int g = 0; g < group; g++) {
      output_buffer[elt][g] = output_buffer[elt][g]
                                  .flatten(1)
                                  .addmm_(weight_g[g].flatten(1), columns_g[g])
                                  .view_as(output_buffer[elt][g]);
    }
  }

  // (nChunks, nOutputPlane, im2col_step, outH, outW) -> swap the middle dims
  // to match output's (nChunks, im2col_step, nOutputPlane, outH, outW).
  output_buffer = output_buffer.view(
      {nChunks, nOutputPlane, im2col_step, outputHeight, outputWidth});
  output_buffer.transpose_(1, 2);
  output.copy_(output_buffer);
  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    output = output.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
  }

  return 1;
}
|
||
|
|
||
|
// Backward pass of deformable convolution w.r.t. the input and the offsets.
//
// gradInput / gradOffset are written through views shaped like input /
// offset, so they must be viewable with those shapes. The batch is processed
// im2col_step images at a time (N must be divisible by im2col_step).
//
// For each chunk, the column gradient columns = weight^T * gradOutput is
// formed per group. It is then handed to deformable_col2im_coord (->
// gradOffset) and deformable_col2im (-> gradInput). Both are defined
// elsewhere; whether they overwrite or accumulate is decided there.
// Returns 1.
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
                                    at::Tensor gradOutput, at::Tensor gradInput,
                                    at::Tensor gradOffset, at::Tensor weight,
                                    at::Tensor columns, int kW, int kH, int dW,
                                    int dH, int padW, int padH, int dilationW,
                                    int dilationH, int group,
                                    int deformable_group, int im2col_step) {
  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();
  weight = weight.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  AT_CHECK(im2col_step > 0 && batchSize % im2col_step == 0,
           "batch size (", batchSize, ") must be divisible by im2col_step (",
           im2col_step, ")");

  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // change order of grad output:
  // (B/s, s, C_out, H, W) -> (B/s, C_out, s, H, W)
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
                              inputHeight, inputWidth});
  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
                                deformable_group * 2 * kH * kW, outputHeight,
                                outputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  // Per-group views, created once outside the loop. Previously weight was
  // re-split on every chunk but never restored, so view() failed from the
  // second chunk on.
  at::Tensor columns_g =
      columns.view({group, columns.size(0) / group, columns.size(1)});
  at::Tensor weight_g = weight.view({group, weight.size(0) / group,
                                     weight.size(1), weight.size(2),
                                     weight.size(3)});
  at::Tensor gradOutput_g = gradOutput.view(
      {gradOutput.size(0), group, gradOutput.size(1) / group,
       gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    for (int g = 0; g < group; g++) {
      // columns[g] = weight[g]^T * gradOutput[elt][g]; beta = 0 discards the
      // previous chunk's contents. addmm_ writes in place into columns.
      columns_g[g].addmm_(weight_g[g].flatten(1).transpose(0, 1),
                          gradOutput_g[elt][g].flatten(1), 0.0f, 1.0f);
    }

    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
                            dilationH, dilationW, im2col_step, deformable_group,
                            gradOffset[elt]);

    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, gradInput[elt]);
  }

  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  gradOffset = gradOffset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
    // Use gradOffset's own sizes. offset is already 3D at this point, so
    // offset.size(3) would be out of range.
    gradOffset = gradOffset.view(
        {gradOffset.size(1), gradOffset.size(2), gradOffset.size(3)});
  }

  return 1;
}
|
||
|
|
||
|
// Backward pass of deformable convolution w.r.t. the weight.
//
// gradWeight (nOutputPlane, nInputPlane / group, kH, kW) is updated in
// place: for each chunk of im2col_step images and each group,
// gradWeight += scale * gradOutput * columns^T (addmm_ with beta = 1).
// N must be divisible by im2col_step. `ones` is not used. Returns 1.
int deform_conv_backward_parameters_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight,  // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
              padW, dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch. offset gets a batch dim too, as in the other entry
    // points, so the batch-size check and the chunked views below work.
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = gradWeight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  AT_CHECK(im2col_step > 0 && batchSize % im2col_step == 0,
           "batch size (", batchSize, ") must be divisible by im2col_step (",
           im2col_step, ")");

  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // Reorder gradOutput from (B/s, s, C_out, H, W) to (B/s, C_out, s, H, W)
  // and materialise it contiguously so that the (s, H) dims can be merged.
  // A zeros_like() buffer is not enough: in newer PyTorch it inherits the
  // transposed strides of its argument, and the merging view() then fails.
  at::Tensor gradOutputBuffer =
      gradOutput
          .view({batchSize / im2col_step, im2col_step, nOutputPlane,
                 outputHeight, outputWidth})
          .transpose(1, 2)
          .contiguous()
          .view({batchSize / im2col_step, nOutputPlane,
                 im2col_step * outputHeight, outputWidth});

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    // divide into group
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    gradWeight =
        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
                         gradWeight.size(2), gradWeight.size(3)});

    for (int g = 0; g < group; g++) {
      gradWeight[g] = gradWeight[g]
                          .flatten(1)
                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
                                  columns[g].transpose(1, 0), 1.0, scale)
                          .view_as(gradWeight[g]);
    }

    // Undo the group split so the views above are valid next iteration.
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0),
         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
                                  gradWeight.size(2), gradWeight.size(3),
                                  gradWeight.size(4)});
  }

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
  }

  return 1;
}
|
||
|
|
||
|
// Forward pass of modulated deformable convolution (offsets + mask).
//
// input  : contiguous (batch, channels, height, width).
// weight : contiguous (channels_out, channels / group, kernel_h, kernel_w).
// output : zeroed, then filled in place through a view of shape
//          (batch, channels_out, height_out, width_out).
// bias   : when with_bias, added to output broadcast over the channel dim.
//
// Images are processed one at a time through a
// (channels * kernel_h * kernel_w, height_out * width_out) column buffer
// and a per-group matrix product with weight.
void modulated_deform_conv_cuda_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  // AT_ERROR concatenates its arguments (no printf formatting). Report the
  // requested kernel size first, then the one found in weight.
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (", kernel_h, " x ",
             kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (", channels,
             " vs ", channels_kernel * group, ").");

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out}).zero_();
  // resize temporary columns
  columns =
      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
                input.options());

  output = output.view({output.size(0), group, output.size(1) / group,
                        output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    // divide into group
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }

    // Undo the group split so the views above are valid next iteration.
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  output = output.view({output.size(0), output.size(1) * output.size(2),
                        output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
}
|
||
|
|
||
|
// Backward pass of modulated deformable convolution (offsets + mask).
//
// For each image b:
//   columns = weight^T * grad_output[b]           (per group, beta = 0)
//   -> modulated_deformable_col2im_coord_cuda     (grad_offset[b], grad_mask[b])
//   -> modulated_deformable_col2im_cuda           (grad_input[b])
//   columns = im2col(input[b])                     (recomputed)
//   grad_weight += grad_output[b] * columns^T      (per group, accumulates)
//   grad_bias   += grad_output[b] * ones           (when with_bias)
// grad_weight / grad_bias accumulate (addmm_ with beta = 1), so the caller
// presumably zero-initialises them; verify at the call site.
void modulated_deform_conv_cuda_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  // AT_ERROR concatenates its arguments (no printf formatting). Report the
  // requested kernel size first, then the one found in weight.
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (", kernel_h, " x ",
             kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (", channels,
             " vs ", channels_kernel * group, ").");

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  // The bias gradient multiplies grad_output by ones.view({-1, 1}), so
  // `ones` needs exactly height_out * width_out elements. A larger buffer
  // (previously accepted) breaks the addmm_ shapes below.
  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) != height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
                      input.options());

  grad_output =
      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
                        grad_output.size(2), grad_output.size(3)});

  for (int b = 0; b < batch; b++) {
    // divide into group
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(
        columns, input[b], offset[b], mask[b], 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
        grad_mask[b]);
    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(
        columns, offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and
    // group
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
                                    grad_weight.size(1), grad_weight.size(2),
                                    grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++) {
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
              .view_as(grad_weight[g]);
      if (with_bias) {
        grad_bias[g] =
            grad_bias[g]
                .view({-1, 1})
                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
                .view(-1);
      }
    }

    // Undo the group split so the views above are valid next iteration.
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                    grad_weight.size(2), grad_weight.size(3),
                                    grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
                                  grad_output.size(2), grad_output.size(3),
                                  grad_output.size(4)});
}
|
||
|
|
||
|
// Python bindings: every entry point is exported under its C++ name in the
// extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Deformable convolution (offsets only).
  m.def("deform_conv_forward_cuda",
        &deform_conv_forward_cuda,
        "deform forward (CUDA)");
  m.def("deform_conv_backward_input_cuda",
        &deform_conv_backward_input_cuda,
        "deform_conv_backward_input (CUDA)");
  m.def("deform_conv_backward_parameters_cuda",
        &deform_conv_backward_parameters_cuda,
        "deform_conv_backward_parameters (CUDA)");

  // Modulated deformable convolution (offsets plus a mask).
  m.def("modulated_deform_conv_cuda_forward",
        &modulated_deform_conv_cuda_forward,
        "modulated deform conv forward (CUDA)");
  m.def("modulated_deform_conv_cuda_backward",
        &modulated_deform_conv_cuda_backward,
        "modulated deform conv backward (CUDA)");
}
|