diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-06-18 18:44:16 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-06-18 18:44:16 +0300 |
commit | 67c87efa98fa9eb79f71e2c3b792a16077eef1fd (patch) | |
tree | 7c9dbc83f08b55c84374bc92156cbd8416d273db | |
parent | e5181dc9a318beae6028a94c4970e92f8ae76c71 (diff) | |
parent | ca50de62f02d33ec12fd481737808d548529a0a9 (diff) |
Merge pull request #300 from jonathantompson/volpad
Added VolumetricReplicationPadding.
-rw-r--r-- | lib/THCUNN/THCUNN.h | 16 | ||||
-rw-r--r-- | lib/THCUNN/VolumetricReplicationPadding.cu | 244 | ||||
-rw-r--r-- | test.lua | 109 |
3 files changed, 369 insertions, 0 deletions
diff --git a/lib/THCUNN/THCUNN.h b/lib/THCUNN/THCUNN.h
index cbc71b4..87a220a 100644
--- a/lib/THCUNN/THCUNN.h
+++ b/lib/THCUNN/THCUNN.h
@@ -937,3 +937,19 @@ TH_API void THNN_CudaSpatialReplicationPadding_updateGradInput(THCState *state,
                                                    THCudaTensor *gradInput,
                                                    int padL, int padR,
                                                    int padT, int padB);
+
+TH_API void THNN_CudaVolumetricReplicationPadding_updateOutput(
+  THCState *state,
+  THCudaTensor *input,
+  THCudaTensor *output,
+  int pleft, int pright,
+  int ptop, int pbottom,
+  int pfront, int pback);
+TH_API void THNN_CudaVolumetricReplicationPadding_updateGradInput(
+  THCState *state,
+  THCudaTensor *input,
+  THCudaTensor *gradOutput,
+  THCudaTensor *gradInput,
+  int pleft, int pright,
+  int ptop, int pbottom,
+  int pfront, int pback);
diff --git a/lib/THCUNN/VolumetricReplicationPadding.cu b/lib/THCUNN/VolumetricReplicationPadding.cu
new file mode 100644
--- /dev/null
+++ b/lib/THCUNN/VolumetricReplicationPadding.cu
@@ -0,0 +1,244 @@
+#include "THCUNN.h"
+
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCDeviceUtils.cuh"
+#include "THCReduceApplyUtils.cuh"
+
+// Forward kernel: each thread copies one output voxel of one (batch, plane)
+// slice. blockIdx.y selects the plane, blockIdx.z the batch entry, and the
+// x-dimension of the grid enumerates the flattened (z, y, x) output voxels.
+// The source voxel is the output coordinate shifted by the front/top/left
+// padding and clamped to the input extent, which replicates border values
+// for positive padding and crops for negative padding. pback, pbottom and
+// pright are not read here; they only affect the output size chosen by the
+// host wrapper.
+__global__ void VolumetricReplicationPadding_updateOutput(
+  THCDeviceTensor<float, 5> input,
+  THCDeviceTensor<float, 5> output,
+  int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
+
+  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+  // The last block of a slice may be only partially used.
+  if (outputPointId >= (output.getSize(2) * output.getSize(3) *
+                        output.getSize(4))) {
+    return;
+  }
+  int outputPointX = outputPointId % output.getSize(4);
+  int outputPointY = (outputPointId / output.getSize(4)) % output.getSize(3);
+  int outputPointZ = outputPointId / (output.getSize(3) * output.getSize(4));
+
+  // iStart* is non-zero only for negative (cropping) padding, oStart* only
+  // for positive padding; at most one of each pair is non-zero.
+  int iStartX = max(0, -pleft);
+  int iStartY = max(0, -ptop);
+  int iStartZ = max(0, -pfront);
+  int oStartX = max(0, pleft);
+  int oStartY = max(0, ptop);
+  int oStartZ = max(0, pfront);
+
+  // Per axis this is clamp(outputPoint, p, size + p - 1) - p.
+  int inputPointX = min(max(pleft, outputPointX),
+                        input.getSize(4) + pleft - 1) - oStartX + iStartX;
+  int inputPointY = min(max(ptop, outputPointY),
+                        input.getSize(3) + ptop - 1) - oStartY + iStartY;
+  int inputPointZ = min(max(pfront, outputPointZ),
+                        input.getSize(2) + pfront - 1) - oStartZ + iStartZ;
+
+  float valueToCopy =
+      input[batch][plane][inputPointZ][inputPointY][inputPointX];
+  output[batch][plane][outputPointZ][outputPointY][outputPointX] = valueToCopy;
+}
+
+// Host entry point for the forward pass. Accepts a 4-D (plane, depth, height,
+// width) or 5-D (batch, plane, depth, height, width) input, resizes output to
+// the padded shape and launches one thread per output voxel. Negative padding
+// values crop the corresponding side.
+void THNN_CudaVolumetricReplicationPadding_updateOutput(THCState *state,
+                                                        THCudaTensor *input,
+                                                        THCudaTensor *output,
+                                                        int pleft, int pright,
+                                                        int ptop, int pbottom,
+                                                        int pfront, int pback) {
+  THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
+             "input tensor must fit into 32-bit index math");
+
+  int planeDim = 0;
+  int dimd = 1;
+  int dimh = 2;
+  int dimw = 3;
+  int numBatch = 1;
+
+  int numInputDims = THCudaTensor_nDimension(state, input);
+  THArgCheck(numInputDims == 4 || numInputDims == 5, 2,
+             "input must be 4 or 5-dimensional");
+
+  if (numInputDims == 5) {
+    numBatch = THCudaTensor_size(state, input, 0);
+    planeDim++;
+    dimd++;
+    dimh++;
+    dimw++;
+  }
+
+  int numPlanes = THCudaTensor_size(state, input, planeDim);
+  int inputD = THCudaTensor_size(state, input, dimd);
+  int inputH = THCudaTensor_size(state, input, dimh);
+  int inputW = THCudaTensor_size(state, input, dimw);
+  int outputD = inputD + pfront + pback;
+  int outputH = inputH + ptop + pbottom;
+  int outputW = inputW + pleft + pright;
+
+  // Negative padding can crop a dimension away entirely. Reject that here:
+  // a non-positive extent would be passed to the resize below and would
+  // produce a zero-sized grid and block for the kernel launch.
+  THArgCheck(outputD >= 1 && outputH >= 1 && outputW >= 1, 2,
+             "input is too small for the given padding: every output "
+             "dimension must be at least 1");
+
+  THCDeviceTensor<float, 5> devInput;
+  THCDeviceTensor<float, 5> devOutput;
+
+  if (numInputDims == 4) {
+    THCudaTensor_resize4d(state, output, numPlanes, outputD, outputH, outputW);
+
+    // Non-batch input: view both tensors as 5-D so one kernel serves both.
+    devInput = toDeviceTensor<float, 4>(state, input).upcastOuter<5>();
+    devOutput = toDeviceTensor<float, 4>(state, output).upcastOuter<5>();
+  } else {
+    THCudaTensor_resize5d(state, output, numBatch, numPlanes, outputD, outputH,
+                          outputW);
+
+    devInput = toDeviceTensor<float, 5>(state, input);
+    devOutput = toDeviceTensor<float, 5>(state, output);
+  }
+
+  int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3) *
+      devOutput.getSize(4);
+  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+            devOutput.getSize(1),
+            devOutput.getSize(0));
+  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+  VolumetricReplicationPadding_updateOutput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
+    devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright);
+}
+
+// Backward kernel: each thread reads one gradOutput voxel and accumulates it
+// into the gradInput voxel that the forward pass copied from. Several output
+// voxels can map to the same border input voxel, hence the atomicAdd.
+__global__ void VolumetricReplicationPadding_updateGradInput(
+  THCDeviceTensor<float, 5> gradInput,
+  THCDeviceTensor<float, 5> gradOutput,
+  int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
+  int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
+  int plane = blockIdx.y;
+  int batch = blockIdx.z;
+
+  if (outputPointId >= (gradOutput.getSize(2) * gradOutput.getSize(3) *
+                        gradOutput.getSize(4))) {
+    return;
+  }
+  int outputPointX = outputPointId % gradOutput.getSize(4);
+  int outputPointY = (outputPointId / gradOutput.getSize(4)) %
+      gradOutput.getSize(3);
+  int outputPointZ = outputPointId / (gradOutput.getSize(3) *
+      gradOutput.getSize(4));
+
+  int iStartX = max(0, -pleft);
+  int iStartY = max(0, -ptop);
+  int iStartZ = max(0, -pfront);
+  int oStartX = max(0, pleft);
+  int oStartY = max(0, ptop);
+  int oStartZ = max(0, pfront);
+
+  // Same output -> input coordinate mapping as the forward kernel.
+  int inputPointX = min(max(pleft, outputPointX),
+                        gradInput.getSize(4) + pleft - 1) - oStartX + iStartX;
+  int inputPointY = min(max(ptop, outputPointY),
+                        gradInput.getSize(3) + ptop - 1) - oStartY + iStartY;
+  int inputPointZ = min(max(pfront, outputPointZ),
+                        gradInput.getSize(2) + pfront - 1) - oStartZ + iStartZ;
+
+  float valueToCopy =
+      gradOutput[batch][plane][outputPointZ][outputPointY][outputPointX];
+  atomicAdd(&gradInput[batch][plane][inputPointZ][inputPointY][inputPointX],
+            valueToCopy);
+}
+
+// Host entry point for the backward pass. gradInput is resized to match
+// input, zeroed, and then filled by accumulating every gradOutput voxel into
+// the input voxel it was replicated from. gradOutput must have the shape the
+// forward pass produces for this input and padding; this is verified before
+// the launch because the kernel indexes gradInput using gradOutput's batch
+// and plane extents and would otherwise access memory out of bounds.
+void THNN_CudaVolumetricReplicationPadding_updateGradInput(
+  THCState *state, THCudaTensor *input, THCudaTensor *gradOutput,
+  THCudaTensor *gradInput, int pleft, int pright, int ptop, int pbottom,
+  int pfront, int pback) {
+  THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
+             "input tensor must fit into 32-bit index math");
+  THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, gradOutput),
+             3, "output gradient tensor must fit into 32-bit index math");
+
+  int planeDim = 0;
+  int dimd = 1;
+  int dimh = 2;
+  int dimw = 3;
+
+  int numInputDims = THCudaTensor_nDimension(state, input);
+  THArgCheck(numInputDims == 4 || numInputDims == 5, 2,
+             "input must be 4 or 5-dimensional");
+  THArgCheck(THCudaTensor_nDimension(state, gradOutput) == numInputDims, 3,
+             "gradOutput must have the same number of dimensions as input");
+
+  if (numInputDims == 5) {
+    THArgCheck(THCudaTensor_size(state, gradOutput, 0) ==
+               THCudaTensor_size(state, input, 0), 3,
+               "gradOutput batch size must match input");
+    planeDim++;
+    dimd++;
+    dimh++;
+    dimw++;
+  }
+
+  THArgCheck(THCudaTensor_size(state, gradOutput, planeDim) ==
+             THCudaTensor_size(state, input, planeDim), 3,
+             "gradOutput plane count must match input");
+  THArgCheck(THCudaTensor_size(state, gradOutput, dimd) ==
+             THCudaTensor_size(state, input, dimd) + pfront + pback, 3,
+             "gradOutput depth must equal input depth + pfront + pback");
+  THArgCheck(THCudaTensor_size(state, gradOutput, dimh) ==
+             THCudaTensor_size(state, input, dimh) + ptop + pbottom, 3,
+             "gradOutput height must equal input height + ptop + pbottom");
+  THArgCheck(THCudaTensor_size(state, gradOutput, dimw) ==
+             THCudaTensor_size(state, input, dimw) + pleft + pright, 3,
+             "gradOutput width must equal input width + pleft + pright");
+
+  THCudaTensor_resizeAs(state, gradInput, input);
+  THCudaTensor_zero(state, gradInput);
+
+  THCDeviceTensor<float, 5> devGradInput;
+  THCDeviceTensor<float, 5> devGradOutput;
+
+  if (numInputDims == 4) {
+    devGradInput = toDeviceTensor<float, 4>(state, gradInput).upcastOuter<5>();
+    devGradOutput =
+        toDeviceTensor<float, 4>(state, gradOutput).upcastOuter<5>();
+  } else {
+    devGradInput = toDeviceTensor<float, 5>(state, gradInput);
+    devGradOutput = toDeviceTensor<float, 5>(state, gradOutput);
+  }
+
+  int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3) *
+      devGradOutput.getSize(4);
+  dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
+            devGradOutput.getSize(1),
+            devGradOutput.getSize(0));
+  dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+  VolumetricReplicationPadding_updateGradInput<<<gridSize, blockSize, 0, THCState_getCurrentStream(state)>>>(
+    devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright);
+}
diff --git a/test.lua b/test.lua
--- a/test.lua
+++ b/test.lua
@@ -5220,6 +5220,115 @@ function cunntest.SpatialReplicationPadding_backward()
                      precision_backward, 'error on state (backward) ')
 end
 
+function cunntest.VolumetricReplicationPadding_forward()
+   local batch = math.random(1,3)
+   local plane = math.random(1,3)
+   local sizeZ = math.random(7,16)
+   local sizeY = math.random(7,16)
+   local sizeX = math.random(7,16)
+   local pleft = math.random(-3,3)
+   local pright = math.random(-3,3)
+   local ptop = math.random(-3,3)
+   local pbottom = math.random(-3,3)
+   local pfront = math.random(-3,3)
+   local pback = math.random(-3,3)
+
+   local tm = {}
+   local title =
+      string.format(
+         'VolumetricReplicationPadding.forward %dx%dx%dx%dx%d -> ' ..
+         '%dx%dx%dx%dx%d',
+         batch, plane, sizeZ, sizeY, sizeX,
+         batch, plane, sizeZ + pfront + pback, sizeY + ptop + pbottom,
+         sizeX + pleft + pright)
+   times[title] = tm
+
+   local input = torch.rand(batch, plane, sizeZ, sizeY, sizeX)
+   local module = nn.VolumetricReplicationPadding(pleft, pright, ptop, pbottom,
+                                                  pfront, pback)
+   local groundtruth = module:forward(input)
+   local a = torch.Timer()
+   for i = 1, nloop do
+      groundtruth = module:forward(input)
+   end
+   tm.cpu = a:time().real
+
+   input = input:cuda()
+   local gmodule = nn.VolumetricReplicationPadding(pleft, pright, ptop, pbottom,
+                                                   pfront, pback):cuda()
+   local rescuda = gmodule:forward(input)
+   a:reset()
+   for i = 1, nloop do
+      rescuda = gmodule:forward(input)
+   end
+   cutorch.synchronize()
+   tm.gpu = a:time().real
+
+   local error = rescuda:float() - groundtruth
+   mytester:assertlt(error:abs():max(),
+                     precision_forward, 'error on state (forward) ')
+end
+
+function cunntest.VolumetricReplicationPadding_backward()
+   local batch = math.random(1,3)
+   local plane = math.random(1,3)
+   local sizeZ = math.random(7,16)
+   local sizeY = math.random(7,16)
+   local sizeX = math.random(7,16)
+   local pleft = math.random(-3,3)
+   local pright = math.random(-3,3)
+   local ptop = math.random(-3,3)
+   local pbottom = math.random(-3,3)
+   local pfront = math.random(-3,3)
+   local pback = math.random(-3,3)
+
+   local tm = {}
+   local title =
+      string.format(
+         'VolumetricReplicationPadding.backward %dx%dx%dx%dx%d -> ' ..
+         '%dx%dx%dx%dx%d',
+         batch, plane, sizeZ, sizeY, sizeX,
+         batch, plane, sizeZ + pfront + pback, sizeY + ptop + pbottom,
+         sizeX + pleft + pright)
+   times[title] = tm
+
+   local input = torch.rand(batch, plane, sizeZ, sizeY, sizeX)
+   local gradOutput = torch.rand(
+      batch, plane, sizeZ + pfront + pback, sizeY + ptop + pbottom,
+      sizeX + pleft + pright
+   )
+   local module = nn.VolumetricReplicationPadding(pleft, pright, ptop, pbottom,
+                                                  pfront, pback)
+   module:forward(input)
+   module:zeroGradParameters()
+   local groundgrad = module:backward(input, gradOutput)
+   local a = torch.Timer()
+   for i = 1, nloop do
+      module:zeroGradParameters()
+      groundgrad = module:backward(input, gradOutput)
+   end
+   tm.cpu = a:time().real
+
+   input = input:cuda()
+   gradOutput = gradOutput:cuda()
+   local gmodule = nn.VolumetricReplicationPadding(pleft, pright, ptop, pbottom,
+                                                   pfront, pback):cuda()
+   gmodule:forward(input)
+   gmodule:zeroGradParameters()
+   local rescuda = gmodule:backward(input, gradOutput)
+   a:reset()
+   for i = 1, nloop do
+      gmodule:zeroGradParameters()
+      rescuda = gmodule:backward(input, gradOutput)
+   end
+   cutorch.synchronize()
+   tm.gpu = a:time().real
+
+   local error = rescuda:float() - groundgrad
+   mytester:assertlt(error:abs():max(),
+                     precision_backward, 'error on state (backward) ')
+end
+
 local function setUp()
    cutorch.setDevice(1)
 end