Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2013-01-03 18:11:51 +0400
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2013-01-03 18:11:51 +0400
commit1ee64da4abbafd92c6bea5121e9a11e55ef1cbcb (patch)
tree0ac888ab1f86d49439d9a720df6a3c5329f2ff58 /generic
parent9e6f3d322212a4c62d869de6733701be8e676754 (diff)
New NN classes
extra/nn/L1Cost.lua : L1 penalty extra/nn/SpatialFullConvolution.lua : full convolution extra/nn/SpatialFullConvolutionMap.lua : full convolution with connection table extra/nn/TanhShrink.lua : shrinkage with x-tanh(x) extra/nn/WeightedMSECriterion.lua : mean squared error with weighting mask on the target Add new nn classes that are used commonly for unsupervised training of convolutional auto encoders
Diffstat (limited to 'generic')
-rw-r--r--generic/L1Cost.c49
-rw-r--r--generic/SpatialConvolutionMap.c2
-rw-r--r--generic/SpatialFullConvolution.c191
-rw-r--r--generic/SpatialFullConvolutionMap.c225
4 files changed, 466 insertions, 1 deletions
diff --git a/generic/L1Cost.c b/generic/L1Cost.c
new file mode 100644
index 0000000..a450e06
--- /dev/null
+++ b/generic/L1Cost.c
@@ -0,0 +1,49 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/L1Cost.c"
+#else
+
+static int nn_(L1Cost_updateOutput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ accreal sum;
+
+ sum = 0;
+ TH_TENSOR_APPLY(real, input, sum += fabs(*input_data););
+
+ lua_pushnumber(L, sum);
+ lua_setfield(L, 1, "output");
+
+ lua_pushnumber(L, sum);
+ return 1;
+}
+
+static int nn_(L1Cost_updateGradInput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+
+ THTensor_(resizeAs)(gradInput, input);
+ TH_TENSOR_APPLY2(real, gradInput, real, input,
+ if (*input_data > 0)
+ *gradInput_data = 1;
+ else if (*input_data < 0)
+ *gradInput_data = -1;
+ else
+ *gradInput_data = 0;);
+ return 1;
+}
+
+static const struct luaL_Reg nn_(L1Cost__) [] = {
+ {"L1Cost_updateOutput", nn_(L1Cost_updateOutput)},
+ {"L1Cost_updateGradInput", nn_(L1Cost_updateGradInput)},
+ {NULL, NULL}
+};
+
+static void nn_(L1Cost_init)(lua_State *L)
+{
+ luaT_pushmetatable(L, torch_Tensor);
+ luaT_registeratname(L, nn_(L1Cost__), "nn");
+ lua_pop(L,1);
+}
+
+#endif
diff --git a/generic/SpatialConvolutionMap.c b/generic/SpatialConvolutionMap.c
index 4c289fb..a1d20bc 100644
--- a/generic/SpatialConvolutionMap.c
+++ b/generic/SpatialConvolutionMap.c
@@ -18,7 +18,7 @@ static int nn_(SpatialConvolutionMap_updateOutput)(lua_State *L)
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected");
- luaL_argcheck(L, input->size[0] == nInputPlane, 2, "invalid number of input planes");
+ luaL_argcheck(L, input->size[0] >= nInputPlane, 2, "invalid number of input planes");
luaL_argcheck(L, input->size[2] >= kW && input->size[1] >= kH, 2, "input image smaller than kernel size");
THTensor_(resize3d)(output, nOutputPlane,
diff --git a/generic/SpatialFullConvolution.c b/generic/SpatialFullConvolution.c
new file mode 100644
index 0000000..cb2e340
--- /dev/null
+++ b/generic/SpatialFullConvolution.c
@@ -0,0 +1,191 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/SpatialFullConvolution.c"
+#else
+
+static int nn_(SpatialFullConvolution_updateOutput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ int dW = luaT_getfieldcheckint(L, 1, "dW");
+ int dH = luaT_getfieldcheckint(L, 1, "dH");
+
+ THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+ THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
+ THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
+
+ luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D(batch mode) tensor expected");
+ int dimw = 2;
+ int dimh = 1;
+ if (input->nDimension == 4) {
+ dimw++;
+ dimh++;
+ }
+
+ long nOutputPlane = weight->size[1];
+ long kW = weight->size[3];
+ long kH = weight->size[2];
+ long inputWidth = input->size[dimw];
+ long inputHeight = input->size[dimh];
+ long outputWidth = (inputWidth - 1) * dW + kW;
+ long outputHeight = (inputHeight - 1) * dH + kH;
+
+ if (input->nDimension == 3)
+ {
+ THTensor_(resize3d)(output, nOutputPlane, outputHeight, outputWidth);
+ /* add bias */
+ long i;
+ real* bias_data = THTensor_(data)(bias);
+ real* output_data = THTensor_(data)(output);
+#pragma omp parallel for private(i)
+ for (i=0; i<bias->size[0]; i++)
+ {
+ /*THTensor_(select)(outn,output,0,i);*/
+ /*TH_TENSOR_APPLY(real,outn, *outn_data = bias_data[i];);*/
+ real *ptr_output = output_data + i*outputWidth*outputHeight;
+ long j;
+ for(j = 0; j < outputWidth*outputHeight; j++)
+ ptr_output[j] = bias_data[i];
+ }
+
+ /* do convolutions */
+ THTensor *tweight = THTensor_(newTranspose)(weight,0,1);
+ THTensor_(conv2Dmv)(output, 1.0, 1.0, input, tweight, dH, dW, "F", "C");
+ THTensor_(free)(tweight);
+ }
+ else
+ {
+ THTensor_(resize4d)(output, input->size[0], nOutputPlane, outputHeight, outputWidth);
+ real* bias_data = THTensor_(data)(bias);
+ real* output_data = THTensor_(data)(output);
+
+ long p;
+#pragma omp parallel for private(p)
+ for (p=0; p<input->size[0]; p++)
+ {
+ /* BIAS */
+ long i;
+ for (i=0; i<bias->size[0]; i++)
+ {
+ real *ptr_output = output_data + p*nOutputPlane*outputWidth*outputHeight + i*outputWidth*outputHeight;
+ long j;
+ for(j = 0; j < outputWidth*outputHeight; j++)
+ ptr_output[j] = bias_data[i];
+ }
+ }
+ /* do convolutions */
+ THTensor *tweight = THTensor_(newTranspose)(weight,0,1);
+ THTensor_(conv2Dmm)(output, 1.0, 1.0, input, tweight, dH, dW, "F", "C");
+ THTensor_(free)(tweight);
+ }
+ return 1;
+}
+
+
+static int nn_(SpatialFullConvolution_updateGradInput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+ int dW = luaT_getfieldcheckint(L, 1, "dW");
+ int dH = luaT_getfieldcheckint(L, 1, "dH");
+
+ THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+
+ long nOutputPlane = weight->size[1];
+ THArgCheck( nOutputPlane == gradOutput->size[input->nDimension == 4 ? 1 : 0], 1, "Number of output features is not equal to nOutputPlane" );
+
+ if (input->nDimension == 3)
+ {
+ /* gradient to input */
+ THTensor_(conv2Dmv)(gradInput, 0.0, 1.0, gradOutput, weight, dH, dW, "V", "X");
+ }
+ else
+ {
+ /* gradient to input */
+ THTensor_(conv2Dmm)(gradInput, 0.0, 1.0, gradOutput, weight, dH, dW, "V", "X");
+ }
+
+ return 1;
+}
+
+static int nn_(SpatialFullConvolution_accGradParameters)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+ real scale = luaL_optnumber(L, 4, 1);
+ int dW = luaT_getfieldcheckint(L, 1, "dW");
+ int dH = luaT_getfieldcheckint(L, 1, "dH");
+
+ THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+ THTensor *gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
+ THTensor *gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
+
+ long nOutputPlane = weight->size[1];
+ THArgCheck( nOutputPlane == gradOutput->size[input->nDimension == 4 ? 1 : 0], 1, "Number of output features is not equal to nOutputPlane" );
+
+ int dimw = 2;
+ int dimh = 1;
+
+ if (input->nDimension == 4)
+ {
+ dimw++;
+ dimh++;
+ }
+ /* gradient to bias */
+ real *gradBias_data = THTensor_(data)(gradBias);
+ real *gradOutput_data = THTensor_(data)(gradOutput);
+ long noutSlice = gradOutput->size[dimh]*gradOutput->size[dimw];
+ /*THTensor* gradOutSlice = THTensor_(new)();*/
+
+ if (input->nDimension == 3)
+ {
+ long k;
+#pragma omp parallel for private(k)
+ for(k = 0; k < nOutputPlane; k++)
+ {
+ /*THTensor_(select)(gradOutSlice, gradOutput, 0, k);*/
+ real *ptr_gradOutput = gradOutput_data + k*noutSlice;
+ long l;
+ for(l = 0; l < noutSlice; l++)
+ gradBias_data[k] += scale*ptr_gradOutput[l];
+ }
+
+ /* gradient to kernels */
+ THTensor_(conv2DRevger)(gradWeight, 1.0, scale, gradOutput, input, dH, dW);
+ }
+ else
+ {
+ long k;
+#pragma omp parallel for private(k)
+ for(k = 0; k < nOutputPlane; k++)
+ {
+ long p;
+ for(p = 0; p < input->size[0]; p++)
+ {
+ /* BIAS */
+ real *ptr_gradOutput = gradOutput_data + p*nOutputPlane*noutSlice + k*noutSlice;
+ long l;
+ for(l = 0; l < noutSlice; l++)
+ gradBias_data[k] += scale*ptr_gradOutput[l];
+ }
+ }
+ /* gradient to kernels */
+ THTensor_(conv2DRevgerm)(gradWeight, 1.0, scale, gradOutput, input, dH, dW);
+ }
+ return 0;
+}
+
+static const struct luaL_Reg nn_(SpatialFullConvolution__) [] = {
+ {"SpatialFullConvolution_updateOutput", nn_(SpatialFullConvolution_updateOutput)},
+ {"SpatialFullConvolution_updateGradInput", nn_(SpatialFullConvolution_updateGradInput)},
+ {"SpatialFullConvolution_accGradParameters", nn_(SpatialFullConvolution_accGradParameters)},
+ {NULL, NULL}
+};
+
+static void nn_(SpatialFullConvolution_init)(lua_State *L)
+{
+ luaT_pushmetatable(L, torch_Tensor);
+ luaT_registeratname(L, nn_(SpatialFullConvolution__), "nn");
+ lua_pop(L,1);
+}
+
+#endif
diff --git a/generic/SpatialFullConvolutionMap.c b/generic/SpatialFullConvolutionMap.c
new file mode 100644
index 0000000..8a5d9df
--- /dev/null
+++ b/generic/SpatialFullConvolutionMap.c
@@ -0,0 +1,225 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/SpatialFullConvolutionMap.c"
+#else
+
+static int nn_(SpatialFullConvolutionMap_updateOutput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ int kW = luaT_getfieldcheckint(L, 1, "kW");
+ int kH = luaT_getfieldcheckint(L, 1, "kH");
+ int dW = luaT_getfieldcheckint(L, 1, "dW");
+ int dH = luaT_getfieldcheckint(L, 1, "dH");
+ int nInputPlane = luaT_getfieldcheckint(L, 1, "nInputPlane");
+ int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");
+
+
+ THTensor *connTable = luaT_getfieldcheckudata(L, 1, "connTable", torch_Tensor);
+ THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+ THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
+ THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
+
+ luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected");
+ luaL_argcheck(L, input->size[0] >= nInputPlane, 2, "invalid number of input planes");
+
+
+ THTensor_(resize3d)(output, nOutputPlane,
+ (input->size[1] - 1) * dH + kH,
+ (input->size[2] - 1) * dW + kW);
+
+ // contiguous
+ input = THTensor_(newContiguous)(input);
+ output = THTensor_(newContiguous)(output);
+
+ // get raw pointers
+ real *input_data = THTensor_(data)(input);
+ real *output_data = THTensor_(data)(output);
+ real *weight_data = THTensor_(data)(weight);
+ real *bias_data = THTensor_(data)(bias);
+ real *connTable_data = THTensor_(data)(connTable);
+
+ // and dims
+ long input_h = input->size[1];
+ long input_w = input->size[2];
+ long output_h = output->size[1];
+ long output_w = output->size[2];
+ long weight_h = weight->size[1];
+ long weight_w = weight->size[2];
+
+ long p;
+#pragma omp parallel for private(p)
+ for (p = 0; p < nOutputPlane; p++) {
+ // add bias
+ real *ptr_output = output_data + p*output_w*output_h;
+ long j;
+ for(j = 0; j < output_h*output_w; j++)
+ ptr_output[j] = bias_data[p];
+
+ // convolve all maps
+ int nweight = connTable->size[0];
+ long k;
+ for (k = 0; k < nweight; k++) {
+ // get offsets for input/output
+ int o = (int)connTable_data[k*2+1]-1;
+ int i = (int)connTable_data[k*2+0]-1;
+
+ if (o == p)
+ {
+ THTensor_(fullConv2Dptr)(output_data + o*output_w*output_h,
+ 1.0,
+ input_data + i*input_w*input_h, input_h, input_w,
+ weight_data + k*weight_w*weight_h, weight_h, weight_w,
+ dH, dW);
+ }
+ }
+ }
+
+ // clean up
+ THTensor_(free)(input);
+ THTensor_(free)(output);
+
+ return 1;
+}
+
+static int nn_(SpatialFullConvolutionMap_updateGradInput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+ int dW = luaT_getfieldcheckint(L, 1, "dW");
+ int dH = luaT_getfieldcheckint(L, 1, "dH");
+ int nInputPlane = luaT_getfieldcheckint(L, 1, "nInputPlane");
+
+ THTensor *connTable = luaT_getfieldcheckudata(L, 1, "connTable", torch_Tensor);
+ THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+
+ // contiguous
+ gradInput = THTensor_(newContiguous)(gradInput);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+
+ // Resize/Zero
+ THTensor_(resizeAs)(gradInput, input);
+ THTensor_(zero)(gradInput);
+
+ // get raw pointers
+ real *gradInput_data = THTensor_(data)(gradInput);
+ real *gradOutput_data = THTensor_(data)(gradOutput);
+ real *weight_data = THTensor_(data)(weight);
+ real *connTable_data = THTensor_(data)(connTable);
+
+ // and dims
+ long input_h = input->size[1];
+ long input_w = input->size[2];
+ long output_h = gradOutput->size[1];
+ long output_w = gradOutput->size[2];
+ long weight_h = weight->size[1];
+ long weight_w = weight->size[2];
+
+ long p;
+#pragma omp parallel for private(p)
+ for(p = 0; p < nInputPlane; p++)
+ {
+ long k;
+ // backward all
+ int nkernel = connTable->size[0];
+ for(k = 0; k < nkernel; k++)
+ {
+ int o = (int)connTable_data[k*2+1]-1;
+ int i = (int)connTable_data[k*2+0]-1;
+ if (i == p)
+ {
+ // gradient to input
+ THTensor_(validXCorr2Dptr)(gradInput_data + i*input_w*input_h,
+ 1.0,
+ gradOutput_data + o*output_w*output_h, output_h, output_w,
+ weight_data + k*weight_w*weight_h, weight_h, weight_w,
+ dH, dW);
+ }
+ }
+ }
+
+ // clean up
+ THTensor_(free)(gradInput);
+ THTensor_(free)(gradOutput);
+
+ return 1;
+}
+
+static int nn_(SpatialFullConvolutionMap_accGradParameters)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+ int dW = luaT_getfieldcheckint(L, 1, "dW");
+ int dH = luaT_getfieldcheckint(L, 1, "dH");
+ int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");
+ real scale = luaL_optnumber(L, 4, 1);
+
+ THTensor *connTable = luaT_getfieldcheckudata(L, 1, "connTable", torch_Tensor);
+ THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+ THTensor *gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
+ THTensor *gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
+
+ // contiguous
+ input = THTensor_(newContiguous)(input);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+
+ // get raw pointers
+ real *input_data = THTensor_(data)(input);
+ real *gradOutput_data = THTensor_(data)(gradOutput);
+ real *gradWeight_data = THTensor_(data)(gradWeight);
+ real *gradBias_data = THTensor_(data)(gradBias);
+
+ // and dims
+ long input_h = input->size[1];
+ long input_w = input->size[2];
+ long output_h = gradOutput->size[1];
+ long output_w = gradOutput->size[2];
+ long weight_h = weight->size[1];
+ long weight_w = weight->size[2];
+
+ // gradients wrt bias
+ long k;
+#pragma omp parallel for private(k)
+ for(k = 0; k < nOutputPlane; k++) {
+ real *ptr_gradOutput = gradOutput_data + k*output_w*output_h;
+ long l;
+ for(l = 0; l < output_h*output_w; l++)
+ gradBias_data[k] += scale*ptr_gradOutput[l];
+ }
+
+ // gradients wrt weight
+ int nkernel = connTable->size[0];
+#pragma omp parallel for private(k)
+ for(k = 0; k < nkernel; k++)
+ {
+ int o = (int)THTensor_(get2d)(connTable,k,1)-1;
+ int i = (int)THTensor_(get2d)(connTable,k,0)-1;
+
+ // gradient to kernel
+ THTensor_(validXCorr2DRevptr)(gradWeight_data + k*weight_w*weight_h,
+ scale,
+ gradOutput_data + o*output_w*output_h, output_h, output_w,
+ input_data + i*input_w*input_h, input_h, input_w,
+ dH, dW);
+ }
+
+ // clean up
+ THTensor_(free)(input);
+ THTensor_(free)(gradOutput);
+ return 0;
+}
+
+static const struct luaL_Reg nn_(SpatialFullConvolutionMapStuff__) [] = {
+ {"SpatialFullConvolutionMap_updateOutput", nn_(SpatialFullConvolutionMap_updateOutput)},
+ {"SpatialFullConvolutionMap_updateGradInput", nn_(SpatialFullConvolutionMap_updateGradInput)},
+ {"SpatialFullConvolutionMap_accGradParameters", nn_(SpatialFullConvolutionMap_accGradParameters)},
+ {NULL, NULL}
+};
+
+static void nn_(SpatialFullConvolutionMap_init)(lua_State *L)
+{
+ luaT_pushmetatable(L, torch_Tensor);
+ luaT_registeratname(L, nn_(SpatialFullConvolutionMapStuff__), "nn");
+ lua_pop(L,1);
+}
+
+#endif