Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuca Antiga <luca.antiga@orobix.com>2017-05-29 20:02:05 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-07 18:24:05 +0300
commita3724e89d906e3ad2d8911fcb466f6147f4cab00 (patch)
treecc82cbb55c453560111fe7a3745c1ff4c3e8fd8a
parent21bc88f1b6634cf742993795c8fc26f0f6782d37 (diff)
Add 3D upsampling (nearest and trilinear) with tests
-rw-r--r--lib/THNN/generic/THNN.h33
-rw-r--r--lib/THNN/generic/VolumetricUpSamplingNearest.c226
-rw-r--r--lib/THNN/generic/VolumetricUpSamplingTrilinear.c213
-rw-r--r--lib/THNN/init.c7
4 files changed, 479 insertions, 0 deletions
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index d99b43b..76a28eb 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -1465,4 +1465,37 @@ TH_API void THNN_(VolumetricReplicationPadding_updateGradInput)(
int pleft, int pright,
int ptop, int pbottom,
int pfront, int pback);
+
+TH_API void THNN_(VolumetricUpSamplingNearest_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int scale_factor);
+TH_API void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int scale_factor);
+
+TH_API void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int outputDepth,
+ int outputHeight,
+ int outputWidth);
+TH_API void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
+ THNNState *state,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int nbatch,
+ int nchannels,
+ int inputDepth,
+ int inputHeight,
+ int inputWidth,
+ int outputDepth,
+ int outputHeight,
+ int outputWidth);
+
#endif
diff --git a/lib/THNN/generic/VolumetricUpSamplingNearest.c b/lib/THNN/generic/VolumetricUpSamplingNearest.c
new file mode 100644
index 0000000..5b01a1b
--- /dev/null
+++ b/lib/THNN/generic/VolumetricUpSamplingNearest.c
@@ -0,0 +1,226 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/VolumetricUpSamplingNearest.c"
+#else
+
+
+static inline void THNN_(VolumetricUpSamplingNearest_shapeCheck)
+ (THTensor *input, THTensor *gradOutput,
+ int scale_factor) {
+ THArgCheck(input != NULL, 2, "5D input tensor expected but got NULL");
+ THArgCheck(scale_factor > 1, 4,
+ "scale_factor must be greater than 1, but got: %d", scale_factor);
+ THNN_ARGCHECK(input->nDimension == 4 || input->nDimension == 5, 2, input,
+ "4D or 5D input tensor expected but got: %s");
+ if (input->nDimension == 4) {
+ int nChannels = THTensor_(size)(input, 0);
+ int inputDepth = THTensor_(size)(input, 1);
+ int inputHeight = THTensor_(size)(input, 2);
+ int inputWidth = THTensor_(size)(input, 3);
+ int outputDepth = inputDepth * scale_factor;
+ int outputHeight = inputHeight * scale_factor;
+ int outputWidth = inputWidth * scale_factor;
+ if (gradOutput != NULL) {
+ THNN_CHECK_DIM_SIZE(gradOutput, 4, 0, nChannels);
+ THNN_CHECK_DIM_SIZE(gradOutput, 4, 1, outputDepth);
+ THNN_CHECK_DIM_SIZE(gradOutput, 4, 2, outputHeight);
+ THNN_CHECK_DIM_SIZE(gradOutput, 4, 3, outputWidth);
+ }
+ } else {
+ int nBatch = THTensor_(size)(input, 0);
+ int nChannels = THTensor_(size)(input, 1);
+ int inputDepth = THTensor_(size)(input, 2);
+ int inputHeight = THTensor_(size)(input, 3);
+ int inputWidth = THTensor_(size)(input, 4);
+ int outputDepth = inputDepth * scale_factor;
+ int outputHeight = inputHeight * scale_factor;
+ int outputWidth = inputWidth * scale_factor;
+ if (gradOutput != NULL) {
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 0, nBatch);
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 1, nChannels);
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 2, outputDepth);
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 3, outputHeight);
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 4, outputWidth);
+ }
+ }
+}
+
+void THNN_(VolumetricUpSamplingNearest_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int scale_factor)
+{
+ THNN_(VolumetricUpSamplingNearest_shapeCheck)(input, NULL, scale_factor);
+ int inputDepth = THTensor_(size)(input, input->nDimension-3);
+ int inputHeight = THTensor_(size)(input, input->nDimension-2);
+ int inputWidth = THTensor_(size)(input, input->nDimension-1);
+ int outputDepth = inputDepth * scale_factor;
+ int outputHeight = inputHeight * scale_factor;
+ int outputWidth = inputWidth * scale_factor;
+
+ if (input->nDimension == 4) {
+ THTensor_(resize4d)(output,
+ THTensor_(size)(input, 0),
+ outputDepth, outputHeight, outputWidth);
+ } else {
+ THTensor_(resize5d)(output,
+ THTensor_(size)(input, 0),
+ THTensor_(size)(input, 1),
+ outputDepth, outputHeight, outputWidth);
+ }
+
+ int dT = scale_factor;
+ int dW = scale_factor;
+ int dH = scale_factor;
+ int xDim = input->nDimension-3;
+ int yDim = input->nDimension-2;
+ int zDim = input->nDimension-1;
+
+ // dims
+ int idim = input->nDimension;
+ int osz0 = output->size[0];
+ int osz1 = output->size[1];
+ int osz2 = output->size[2];
+ int osz3 = output->size[3];
+ int osz4 = 1;
+ if (idim > 4) {
+ osz4 = output->size[4];
+ }
+
+ // get strides
+ long *is = input->stride;
+ long *os = output->stride;
+
+ // get raw pointers
+ real *pin = THTensor_(data)(input);
+ real *pout = THTensor_(data)(output);
+
+ // perform the upsampling
+ int i0, i1, i2, i3, i4, isrc, idst;
+ int iout[5]; // Output indices
+ int iin[5]; // Input indices
+
+ for (i0 = 0; i0 < osz0; i0++) {
+ iout[0] = i0;
+ iin[0] = i0;
+ for (i1 = 0; i1 < osz1; i1++) {
+ iout[1] = i1;
+ iin[1] = i1;
+ for (i2 = 0; i2 < osz2; i2++) {
+ iout[2] = i2;
+ iin[2] = i2;
+ for (i3 = 0; i3 < osz3; i3++) {
+ iout[3] = i3;
+ iin[3] = i3;
+ for (i4 = 0; i4 < osz4; i4++) {
+ iout[4] = i4;
+ iin[4] = i4;
+
+ // set the indices for the upsampled dimensions
+ iin[xDim] = iout[xDim] / dW;
+ iin[yDim] = iout[yDim] / dH;
+ iin[zDim] = iout[zDim] / dT;
+
+ idst = i0*os[0] + i1*os[1] + i2*os[2] + i3*os[3];
+ isrc = iin[0]*is[0] + iin[1]*is[1] + iin[2]*is[2] + iin[3]*is[3];
+ if (idim > 4) {
+ idst += i4*os[4];
+ isrc += iin[4]*is[4];
+ }
+
+ pout[idst] = pin[isrc];
+ }
+ }
+ }
+ }
+ }
+}
+
+void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int scale_factor)
+{
+ THNN_(VolumetricUpSamplingNearest_shapeCheck)(input, gradOutput, scale_factor);
+ THTensor_(resizeAs)(gradInput, input);
+
+ int dW = scale_factor;
+ int dH = scale_factor;
+ int dT = scale_factor;
+ int xDim = gradInput->nDimension-3;
+ int yDim = gradInput->nDimension-2;
+ int zDim = gradInput->nDimension-1;
+
+ // dims
+ int idim = gradInput->nDimension; // Guaranteed to be between 3 and 5
+ int isz0 = gradInput->size[0];
+ int isz1 = gradInput->size[1];
+ int isz2 = gradInput->size[2];
+ int isz3 = gradInput->size[3];
+ int isz4 = 1;
+ if (idim > 4) {
+ isz4 = gradInput->size[4];
+ }
+
+ // get strides
+ long *is = gradInput->stride;
+ long *os = gradOutput->stride;
+
+ // get raw pointers
+ real *pin = THTensor_(data)(gradInput);
+ real *pout = THTensor_(data)(gradOutput);
+
+ // perform the upsampling
+ int i0, i1, i2, i3, i4, isrc, idst, x, y, z;
+ int iin[5]; // Input indices
+ int iout[5]; // Output indices
+
+ THTensor_(zero)(gradInput);
+
+ for (i0 = 0; i0 < isz0; i0++) {
+ iin[0] = i0;
+ iout[0] = i0;
+ for (i1 = 0; i1 < isz1; i1++) {
+ iin[1] = i1;
+ iout[1] = i1;
+ for (i2 = 0; i2 < isz2; i2++) {
+ iin[2] = i2;
+ iout[2] = i2;
+ for (i3 = 0; i3 < isz3; i3++) {
+ iin[3] = i3;
+ iout[3] = i3;
+
+ for (i4 = 0; i4 < isz4; i4++) {
+ iin[4] = i4;
+ iout[4] = i4;
+
+ idst = i0*is[0] + i1*is[1] + i2*is[2] + i3*is[3];
+ if (idim > 4) {
+ idst += i4*is[4];
+ }
+
+ // Now accumulate the gradients from gradOutput
+ for (z = 0; z < dT; z++) {
+ for (y = 0; y < dH; y++) {
+ for (x = 0; x < dW; x++) {
+ iout[xDim] = dW * iin[xDim] + x;
+ iout[yDim] = dH * iin[yDim] + y;
+ iout[zDim] = dT * iin[zDim] + z;
+ isrc = iout[0]*os[0] + iout[1]*os[1] + iout[2]*os[2] + iout[3]*os[3];
+ if (idim > 4) {
+ isrc += iout[4]*os[4];
+ }
+ pin[idst] += pout[isrc];
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+#endif
diff --git a/lib/THNN/generic/VolumetricUpSamplingTrilinear.c b/lib/THNN/generic/VolumetricUpSamplingTrilinear.c
new file mode 100644
index 0000000..d2043cd
--- /dev/null
+++ b/lib/THNN/generic/VolumetricUpSamplingTrilinear.c
@@ -0,0 +1,213 @@
+// Adapted from interp.cpp from Caffe util by Pauline Luc
+// Originally developed by George Papandreou
+
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/VolumetricUpSamplingTrilinear.c"
+#else
+
+static inline void THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
+ (THTensor *input, THTensor *gradOutput,
+ int nBatch, int nChannels,
+ int inputDepth, int inputHeight, int inputWidth,
+ int outputDepth, int outputHeight, int outputWidth) {
+ THArgCheck(inputDepth > 0 && inputHeight > 0 && inputWidth > 0
+ && outputDepth > 0 && outputHeight > 0 && outputWidth > 0, 2,
+ "input and output sizes should be greater than 0,"
+ " but got input (D: %d, H: %d, W: %d) output (D: %d, H: %d, W: %d)",
+ inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth);
+ if (input != NULL) {
+ THNN_ARGCHECK(input->nDimension == 5, 2, input,
+ "5D input tensor expected but got: %s");
+ }
+
+ if (gradOutput != NULL) {
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 0, nBatch);
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 1, nChannels);
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 2, outputDepth);
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 3, outputHeight);
+ THNN_CHECK_DIM_SIZE(gradOutput, 5, 4, outputWidth);
+ }
+}
+
+void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int outputDepth,
+ int outputHeight,
+ int outputWidth){
+
+ int nbatch = THTensor_(size)(input, 0);
+ int channels = THTensor_(size)(input, 1);
+ int inputDepth = THTensor_(size)(input, 2);
+ int inputHeight = THTensor_(size)(input, 3);
+ int inputWidth = THTensor_(size)(input, 4);
+
+ THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
+ (input, NULL,
+ nbatch, channels,
+ inputDepth, inputHeight, inputWidth,
+ outputDepth, outputHeight, outputWidth);
+
+ input = THTensor_(newContiguous)(input);
+ THTensor_(resize5d)(output,
+ THTensor_(size)(input, 0),
+ THTensor_(size)(input, 1),
+ outputDepth, outputHeight, outputWidth);
+ THTensor_(zero)(output);
+ real *idata = THTensor_(data)(input);
+ real *odata = THTensor_(data)(output);
+ channels = nbatch * channels;
+ THAssert(inputDepth > 0 && inputHeight > 0 && inputWidth > 0 &&
+ outputDepth > 0 && outputHeight > 0 && outputWidth > 0);
+ // special case: just copy
+ if (inputDepth == outputDepth && inputHeight == outputHeight && inputWidth == outputWidth) {
+ for (int t2 = 0; t2 < outputDepth; ++t2) {
+ const int t1 = t2;
+ for (int h2 = 0; h2 < outputHeight; ++h2) {
+ const int h1 = h2;
+ for (int w2 = 0; w2 < outputWidth; ++w2) {
+ const int w1 = w2;
+ const real* pos1 = &idata[t1 * inputHeight * inputWidth + h1 * inputWidth + w1];
+ real* pos2 = &odata[t2 * outputHeight * outputWidth + h2 * outputWidth + w2];
+ for (int c = 0; c < channels; ++c) {
+ pos2[0] = pos1[0];
+ pos1 += inputWidth * inputHeight * inputDepth;
+ pos2 += outputWidth * outputHeight * outputDepth;
+ }
+ }
+ }
+ }
+ return;
+ }
+ const float rdepth = (outputDepth > 1) ? (float)(inputDepth - 1)/(outputDepth - 1) : 0.f;
+ const float rheight = (outputHeight > 1) ? (float)(inputHeight - 1)/(outputHeight - 1) : 0.f;
+ const float rwidth = (outputWidth > 1) ? (float)(inputWidth - 1) / (outputWidth - 1) : 0.f;
+ for (int t2 = 0; t2 < outputDepth; ++t2) {
+ const float t1r = rdepth * t2;
+ const int t1 = t1r;
+ const int t1p = (t1 < inputDepth - 1) ? 1 : 0;
+ const real t1lambda = t1r - t1;
+ const real t0lambda = (real)1. - t1lambda;
+ for (int h2 = 0; h2 < outputHeight; ++h2) {
+ const float h1r = rheight * h2;
+ const int h1 = h1r;
+ const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
+ const real h1lambda = h1r - h1;
+ const real h0lambda = (real)1. - h1lambda;
+ for (int w2 = 0; w2 < outputWidth; ++w2) {
+ const float w1r = rwidth * w2;
+ const int w1 = w1r;
+ const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+ const real w1lambda = w1r - w1;
+ const real w0lambda = (real)1. - w1lambda;
+ const real* pos1 = &idata[t1 * inputHeight * inputWidth + h1 * inputWidth + w1];
+ real* pos2 = &odata[t2 * outputHeight * outputWidth + h2 * outputWidth + w2];
+ for (int c = 0; c < channels; ++c) {
+ pos2[0] = t0lambda * (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p])
+ + h1lambda * (w0lambda * pos1[h1p * inputWidth]
+ + w1lambda * pos1[h1p * inputWidth + w1p]))
+ + t1lambda * (h0lambda * (w0lambda * pos1[t1p * inputHeight * inputWidth]
+ + w1lambda * pos1[t1p * inputHeight * inputWidth
+ + w1p])
+ + h1lambda * (w0lambda * pos1[t1p * inputHeight * inputWidth
+ + h1p * inputWidth]
+ + w1lambda * pos1[t1p * inputHeight * inputWidth
+ + h1p * inputWidth + w1p]));
+ pos1 += inputWidth * inputHeight * inputDepth;
+ pos2 += outputWidth * outputHeight * outputDepth;
+ }
+ }
+ }
+ }
+ THTensor_(free)(input);
+}
+
+void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
+ THNNState *state,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int nbatch,
+ int channels,
+ int inputDepth,
+ int inputHeight,
+ int inputWidth,
+ int outputDepth,
+ int outputHeight,
+ int outputWidth){
+
+ THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
+ (NULL, gradOutput,
+ nbatch, channels,
+ inputDepth, inputHeight, inputWidth,
+ outputDepth, outputHeight, outputWidth);
+
+ THTensor_(resize5d)(gradInput, nbatch, channels, inputDepth, inputHeight, inputWidth);
+ THTensor_(zero)(gradInput);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+ real *data1 = THTensor_(data)(gradInput);
+ real *data2 = THTensor_(data)(gradOutput);
+ channels = nbatch * channels;
+
+ // special case: same-size matching grids
+ if (inputDepth == outputDepth && inputHeight == outputHeight && inputWidth == outputWidth) {
+ for (int t2 = 0; t2 < outputDepth; ++t2) {
+ const int t1 = t2;
+ for (int h2 = 0; h2 < outputHeight; ++h2) {
+ const int h1 = h2;
+ for (int w2 = 0; w2 < outputWidth; ++w2) {
+ const int w1 = w2;
+ real* pos1 = &data1[t1 * inputHeight * inputWidth + h1 * inputWidth + w1];
+ const real* pos2 = &data2[t2 * outputHeight * outputWidth + h2 * outputWidth + w2];
+ for (int c = 0; c < channels; ++c) {
+ pos1[0] += pos2[0];
+ pos1 += inputWidth * inputHeight * inputDepth;
+ pos2 += outputWidth * outputHeight * outputDepth;
+ }
+ }
+ }
+ }
+ return;
+ }
+ const float rdepth = (outputDepth > 1) ? (float)(inputDepth - 1)/(outputDepth - 1) : 0.f;
+ const float rheight = (outputHeight > 1) ? (float)(inputHeight - 1)/(outputHeight - 1) : 0.f;
+ const float rwidth = (outputWidth > 1) ? (float)(inputWidth - 1)/(outputWidth - 1) : 0.f;
+ for (int t2 = 0; t2 < outputDepth; ++t2) {
+ const float t1r = rdepth * t2;
+ const int t1 = t1r;
+ const int t1p = (t1 < inputDepth - 1) ? 1 : 0;
+ const real t1lambda = t1r - t1;
+ const real t0lambda = (real)1. - t1lambda;
+ for (int h2 = 0; h2 < outputHeight; ++h2) {
+ const float h1r = rheight * h2;
+ const int h1 = h1r;
+ const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
+ const real h1lambda = h1r - h1;
+ const real h0lambda = (real)1. - h1lambda;
+ for (int w2 = 0; w2 < outputWidth; ++w2) {
+ const float w1r = rwidth * w2;
+ const int w1 = w1r;
+ const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+ const real w1lambda = w1r - w1;
+ const real w0lambda = (real)1. - w1lambda;
+ real* pos1 = &data1[t1 * inputHeight * inputWidth + h1 * inputWidth + w1];
+ const real* pos2 = &data2[t2 * outputHeight * outputWidth + h2 * outputWidth + w2];
+ for (int c = 0; c < channels; ++c) {
+ pos1[0] += t0lambda * h0lambda * w0lambda * pos2[0];
+ pos1[w1p] += t0lambda * h0lambda * w1lambda * pos2[0];
+ pos1[h1p * inputWidth] += t0lambda * h1lambda * w0lambda * pos2[0];
+ pos1[h1p * inputWidth + w1p] += t0lambda * h1lambda * w1lambda * pos2[0];
+ pos1[t1p * inputHeight * inputWidth] += t1lambda * h0lambda * w0lambda * pos2[0];
+ pos1[t1p * inputHeight * inputWidth + w1p] += t1lambda * h0lambda * w1lambda * pos2[0];
+ pos1[t1p * inputHeight * inputWidth + h1p * inputWidth] += t1lambda * h1lambda * w0lambda * pos2[0];
+ pos1[t1p * inputHeight * inputWidth + h1p * inputWidth + w1p] += t1lambda * h1lambda * w1lambda * pos2[0];
+ pos1 += inputWidth * inputHeight * inputDepth;
+ pos2 += outputWidth * outputHeight * outputDepth;
+ }
+ }
+ }
+ }
+ THTensor_(free)(gradOutput);
+}
+
+#endif
diff --git a/lib/THNN/init.c b/lib/THNN/init.c
index 6c64015..5c8c023 100644
--- a/lib/THNN/init.c
+++ b/lib/THNN/init.c
@@ -271,3 +271,10 @@
#include "generic/VolumetricReplicationPadding.c"
#include "THGenerateFloatTypes.h"
+
+#include "generic/VolumetricUpSamplingNearest.c"
+#include "THGenerateFloatTypes.h"
+
+#include "generic/VolumetricUpSamplingTrilinear.c"
+#include "THGenerateFloatTypes.h"
+