Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-09-07 18:41:37 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-09-07 18:41:37 +0400
commit2bbd53ac7e19b5c97781d152d6d8d5a9b5679a1e (patch)
tree54d9db2f7b7f92ef60fa5c080e3b30d352a73bbd
parent4e98f67a59cb7d2dcdc1191421212cfc440f7aba (diff)
Updated training code to latest Torch version: backward is now split into two.
-rw-r--r--BatchOptimization.lua2
-rw-r--r--SpatialConvolutionSparse.lua204
-rw-r--r--generic/SpatialConvolutionSparse.c147
-rw-r--r--generic/SpatialLinear.c2
-rw-r--r--init.c5
-rw-r--r--init.lua1
-rw-r--r--nnx-1.0-1.rockspec1
7 files changed, 3 insertions, 359 deletions
diff --git a/BatchOptimization.lua b/BatchOptimization.lua
index 5ebccf9..51ee075 100644
--- a/BatchOptimization.lua
+++ b/BatchOptimization.lua
@@ -66,6 +66,7 @@ function Batch:forward_sequential(inputs, targets, options)
-- estimate df/dW
local df_do = self.criterion:backward(output, targets[i])
self.module:backward(inputs[i], df_do)
+ self.module:accGradParameters(inputs[i], df_do)
-- user hook
if self.posthook then
self.posthook(self, {inputs[i], targets[i], options[i]})
@@ -298,6 +299,7 @@ function Batch:setup_mapreduce ()
-- estimate df/dW
local df_do = criterion:backward(output, targets[i])
module:backward(inputs[i], df_do)
+ module:accGradParameters(inputs[i], df_do)
-- user hook
if posthook then
posthook(optimizer, {inputs[i], targets[i], options[i]})
diff --git a/SpatialConvolutionSparse.lua b/SpatialConvolutionSparse.lua
deleted file mode 100644
index 5545af3..0000000
--- a/SpatialConvolutionSparse.lua
+++ /dev/null
@@ -1,204 +0,0 @@
-local SpatialConvolutionSparse, parent = torch.class('nn.SpatialConvolutionSparse', 'nn.Module')
-
-local help_desc =
-[[Applies a 2D convolution over an input image composed of
-several input planes. The input tensor in forward(input)
-is expected to be a 3D tensor (width x height x nInputPlane).
-
-A table of connections is used to specify the topology of the
-layer. If a plain fully connected module is enough,
-nn.SpatialConvolution should be used. This table should be
-a 2D tensor (2 x nb_kernels), where table[k][1] points to an
-input, and table[k][2] points to an output.
-
-Note that depending of the size of your kernel, several
-(of the last) columns or rows of the input image might be lost.
-It is up to the user to add proper padding in images.
-
-If the input image is a 3D tensor width x height x nInputPlane,
-the output image size will be owidth x oheight x nOutputPlane where
-
-owidth = (width - kW) / dW + 1
-oheight = (height - kH) / dH + 1 .
-
-The parameters of the convolution can be found in self.weight
-(Tensor of size kH x kW x nInputPlane x nOutputPlane) and
-self.bias (Tensor of size nOutputPlane). The corresponding
-gradients can be found in self.gradWeight and self.gradBias.]]
-
-local help_example =
-[[-- create a filter bank with 8 inputs, 32 outputs, and
--- random connections with a fanin of 4, filters are 9x9
-stimulus = lab.randn(8,500,500)
-mod = nn.SpatialConvolutionSparse(nn.tables.random(8,32,4), 9, 9)
-result = mod:forward(stimulus)]]
-
-nn.tables = nn.tables or {}
-
-function nn.tables.full(nin, nout)
- local ft = torch.Tensor(nin*nout,2)
- local p = 1
- for j=1,nout do
- for i=1,nin do
- ft[p][1] = i
- ft[p][2] = j
- p = p + 1
- end
- end
- return ft
-end
-
-function nn.tables.oneToOne(nfeat)
- local ft = torch.Tensor(nfeat,2)
- for i=1,nfeat do
- ft[i][1] = i
- ft[i][2] = i
- end
- return ft
-end
-
-function nn.tables.random(nin, nout, nto)
- local nker = nto * nout
- local tbl = torch.Tensor(nker, 2)
- local fi = lab.randperm(nin)
- local frcntr = 1
- local tocntr = 1
- local nfi = math.floor(nin/nto) -- number of distinct nto chunks
- local rfi = math.mod(nin,nto) -- number of remaining from maps
- local totbl = tbl:select(2,2)
- local frtbl = tbl:select(2,1)
- local fitbl = fi:narrow(1, 1, (nfi * nto)) -- part of fi that covers distinct chunks
- local ufrtbl= frtbl:unfold(1, nto, nto)
- local utotbl= totbl:unfold(1, nto, nto)
- local ufitbl= fitbl:unfold(1, nto, nto)
-
- -- start filling frtbl
- for i=1,nout do -- fro each unit in target map
- ufrtbl:select(1,i):copy(ufitbl:select(1,frcntr))
- frcntr = frcntr + 1
- if frcntr-1 == nfi then -- reset fi
- fi:copy(lab.randperm(nin))
- frcntr = 1
- end
- end
- for tocntr=1,utotbl:size(1) do
- utotbl:select(1,tocntr):fill(tocntr)
- end
- return tbl
-end
-
-function SpatialConvolutionSparse:__init(conMatrix, kW, kH, dW, dH)
- parent.__init(self)
-
- -- usage
- if not conMatrix or not kW or not kH or type(conMatrix) ~= 'userdata' then
- error(xlua.usage('nn.SpatialConvolutionSparse', help_desc, help_example,
- {type='torch.Tensor', help='a Nx2 array, N being the number of kernels',
- req=true},
- {type='number', help='kernel width', req=true},
- {type='number', help='kernel height', req=true},
- {type='number', help='stride width'},
- {type='number', help='stride height'}))
- end
-
- dW = dW or 1
- dH = dH or 1
-
- self.kW = kW
- self.kH = kH
- self.dW = dW
- self.dH = dH
- self.connTable = conMatrix
- self.nInputPlane = self.connTable:select(2,1):max()
- self.nOutputPlane = self.connTable:select(2,2):max()
-
- self.weight = torch.Tensor(self.connTable:size(1), kH, kW)
- self.bias = torch.Tensor(self.nOutputPlane)
- self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW)
- self.gradBias = torch.Tensor(self.nOutputPlane)
-
- self:reset()
-end
-
-function SpatialConvolutionSparse:reset(stdv)
- if stdv then
- stdv = stdv * math.sqrt(3)
- self.weight:apply(function()
- return random.uniform(-stdv, stdv)
- end)
- self.bias:apply(function()
- return random.uniform(-stdv, stdv)
- end)
- else
- local ninp = torch.Tensor(self.nOutputPlane):zero()
- for i=1,self.connTable:size(1) do ninp[self.connTable[i][2]] = ninp[self.connTable[i][2]]+1 end
- for k=1,self.connTable:size(1) do
- stdv = 1/math.sqrt(self.kW*self.kH*ninp[self.connTable[k][2]])
- self.weight:select(1,k):apply(function() return random.uniform(-stdv,stdv) end)
- end
- for k=1,self.bias:size(1) do
- stdv = 1/math.sqrt(self.kW*self.kH*ninp[k])
- self.bias[k] = random.uniform(-stdv,stdv)
- end
- end
-end
-
-function SpatialConvolutionSparse:forward(input)
- input.nn.SpatialConvolutionSparse_forward(self, input)
- return self.output
-end
-
-function SpatialConvolutionSparse:backward(input, gradOutput)
- input.nn.SpatialConvolutionSparse_backward(self, input, gradOutput)
- return self.gradInput
-end
-
-function SpatialConvolutionSparse:zeroGradParameters(momentum)
- if momentum then
- self.gradWeight:mul(momentum)
- self.gradBias:mul(momentum)
- else
- self.gradWeight:zero()
- self.gradBias:zero()
- end
-end
-
-function SpatialConvolutionSparse:updateParameters(learningRate)
- self.weight:add(-learningRate, self.gradWeight)
- self.bias:add(-learningRate, self.gradBias)
-end
-
-function SpatialConvolutionSparse:decayParameters(decay)
- self.weight:add(-decay, self.weight)
- self.bias:add(-decay, self.bias)
-end
-
-function SpatialConvolutionSparse:write(file)
- parent.write(self, file)
- file:writeInt(self.kW)
- file:writeInt(self.kH)
- file:writeInt(self.dW)
- file:writeInt(self.dH)
- file:writeInt(self.nInputPlane)
- file:writeInt(self.nOutputPlane)
- file:writeObject(self.weight)
- file:writeObject(self.bias)
- file:writeObject(self.gradWeight)
- file:writeObject(self.gradBias)
- file:writeObject(self.connTable)
-end
-
-function SpatialConvolutionSparse:read(file)
- parent.read(self, file)
- self.kW = file:readInt()
- self.kH = file:readInt()
- self.dW = file:readInt()
- self.dH = file:readInt()
- self.nInputPlane = file:readInt()
- self.nOutputPlane = file:readInt()
- self.weight = file:readObject()
- self.bias = file:readObject()
- self.gradWeight = file:readObject()
- self.gradBias = file:readObject()
- self.connTable = file:readObject()
-end
diff --git a/generic/SpatialConvolutionSparse.c b/generic/SpatialConvolutionSparse.c
deleted file mode 100644
index c3b5162..0000000
--- a/generic/SpatialConvolutionSparse.c
+++ /dev/null
@@ -1,147 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "generic/SpatialConvolutionSparse.c"
-#else
-
-static int nn_(SpatialConvolutionSparse_forward)(lua_State *L)
-{
- THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
- int kW = luaT_getfieldcheckint(L, 1, "kW");
- int kH = luaT_getfieldcheckint(L, 1, "kH");
- int dW = luaT_getfieldcheckint(L, 1, "dW");
- int dH = luaT_getfieldcheckint(L, 1, "dH");
- int nInputPlane = luaT_getfieldcheckint(L, 1, "nInputPlane");
- int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");
-
- THTensor *connTable = luaT_getfieldcheckudata(L, 1, "connTable", torch_(Tensor_id));
- THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_(Tensor_id));
- THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_(Tensor_id));
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id));
-
- luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected");
- luaL_argcheck(L, input->size[0] == nInputPlane, 2, "invalid number of input planes");
- luaL_argcheck(L, input->size[2] >= kW && input->size[1] >= kH, 2, "input image smaller than kernel size");
-
- THTensor_(resize3d)(output, nOutputPlane,
- (input->size[1] - kH) / dH + 1,
- (input->size[2] - kW) / dW + 1);
-
- THTensor *inputPlane = THTensor_(new)();
- THTensor *weightPlane = THTensor_(new)();
- THTensor *outputPlane = THTensor_(new)();
-
- /* Add bias */
- int k;
- for (k = 0; k < nOutputPlane; k++)
- {
- THTensor_(select)(outputPlane,output,0,k);
- THTensor_(fill)(outputPlane, THTensor_(get1d)(bias, k));
- }
-
- /* Convolve all maps */
- int nkernel = connTable->size[0];
- for (k = 0; k < nkernel; k++)
- {
- int outplaneid = (int)THTensor_(get2d)(connTable,k,1)-1;
- int inplaneid = (int)THTensor_(get2d)(connTable,k,0)-1;
-
- /* Get input, output and kernel*/
- THTensor_(select)(outputPlane, output, 0, outplaneid);
- THTensor_(select)(inputPlane, input, 0, inplaneid);
- THTensor_(select)(weightPlane, weight, 0, k);
-
- /* Convolve */
- THLab_(conv2Dmul)(outputPlane, 1.0, inputPlane, weightPlane, dH, dW, "vx");
- }
-
- /* Cleanup */
- THTensor_(free)(inputPlane);
- THTensor_(free)(weightPlane);
- THTensor_(free)(outputPlane);
-
- return 1;
-}
-
-static int nn_(SpatialConvolutionSparse_backward)(lua_State *L)
-{
- THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
- THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id));
- int kW = luaT_getfieldcheckint(L, 1, "kW");
- int kH = luaT_getfieldcheckint(L, 1, "kH");
- int dW = luaT_getfieldcheckint(L, 1, "dW");
- int dH = luaT_getfieldcheckint(L, 1, "dH");
- int nInputPlane = luaT_getfieldcheckint(L, 1, "nInputPlane");
- int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");
-
- THTensor *connTable = luaT_getfieldcheckudata(L, 1, "connTable", torch_(Tensor_id));
- THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_(Tensor_id));
- THTensor *gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_(Tensor_id));
- THTensor *gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_(Tensor_id));
- THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id));
-
- THTensor *gradInputPlane = THTensor_(new)();
- THTensor *inputPlane = THTensor_(new)();
- THTensor *gradOutputPlane = THTensor_(new)();
- THTensor *weightPlane = THTensor_(new)();
- THTensor *gradWeightPlane = THTensor_(new)();
-
- /* Resize/Zero */
- THTensor_(resizeAs)(gradInput, input);
- THTensor_(zero)(gradInput);
-
- /* gradients wrt bias */
- int i, k;
- real *gradBias_data = THTensor_(data)(gradBias);
- for(k = 0; k < nOutputPlane; k++)
- {
- THTensor_(select)(gradOutputPlane, gradOutput, 0, k);
- gradBias_data[k] += THTensor_(sum)(gradOutputPlane);
- }
-
- int nkernel = connTable->size[0];
- for(k = 0; k < nkernel; k++)
- {
- int outplaneid = (int)THTensor_(get2d)(connTable,k,1)-1;
- int inplaneid = (int)THTensor_(get2d)(connTable,k,0)-1;
-
- /* Select all planes */
- THTensor_(select)(inputPlane, input, 0, inplaneid);
- THTensor_(select)(gradInputPlane, gradInput, 0, inplaneid);
- THTensor_(select)(gradOutputPlane, gradOutput, 0, outplaneid);
- THTensor_(select)(weightPlane, weight, 0, k);
- THTensor_(select)(gradWeightPlane, gradWeight, 0, k);
-
- /* Gradient to kernel */
- THTensor_(resize3d)(inputPlane, 1, inputPlane->size[0], inputPlane->size[1]);
- THTensor_(resize3d)(gradOutputPlane, 1, gradOutputPlane->size[0], gradOutputPlane->size[1]);
- THTensor_(resize4d)(gradWeightPlane, 1, 1, gradWeightPlane->size[0], gradWeightPlane->size[1]);
- THLab_(conv2DRevger)(gradWeightPlane, 1.0, inputPlane, gradOutputPlane, dH, dW);
-
- /* Gradient to input */
- THTensor_(resize3d)(gradInputPlane, 1, gradInputPlane->size[0], gradInputPlane->size[1]);
- THTensor_(resize4d)(weightPlane, 1, 1, weightPlane->size[0], weightPlane->size[1]);
- THLab_(conv2Dmv)(gradInputPlane, 1.0, gradOutputPlane, weightPlane, dH, dW, "fc");
- }
-
- THTensor_(free)(gradInputPlane);
- THTensor_(free)(inputPlane);
- THTensor_(free)(gradOutputPlane);
- THTensor_(free)(weightPlane);
- THTensor_(free)(gradWeightPlane);
-
- return 1;
-}
-
-static const struct luaL_Reg nn_(SpatialConvolutionSparse__) [] = {
- {"SpatialConvolutionSparse_forward", nn_(SpatialConvolutionSparse_forward)},
- {"SpatialConvolutionSparse_backward", nn_(SpatialConvolutionSparse_backward)},
- {NULL, NULL}
-};
-
-static void nn_(SpatialConvolutionSparse_init)(lua_State *L)
-{
- luaT_pushmetaclass(L, torch_(Tensor_id));
- luaT_registeratname(L, nn_(SpatialConvolutionSparse__), "nn");
- lua_pop(L,1);
-}
-
-#endif
diff --git a/generic/SpatialLinear.c b/generic/SpatialLinear.c
index 56b4a3d..903c3b8 100644
--- a/generic/SpatialLinear.c
+++ b/generic/SpatialLinear.c
@@ -103,7 +103,7 @@ static int nn_(SpatialLinear_backward)(lua_State *L)
}
// compute dE/dI
- THTensor_(addmv)(gradInput_xy, 1, weight_t, gradOutput_xy);
+ THTensor_(addmv)(gradInput_xy, 1, 1, weight_t, gradOutput_xy);
}
}
diff --git a/init.c b/init.c
index 89f387d..8ae2b3c 100644
--- a/init.c
+++ b/init.c
@@ -18,9 +18,6 @@ static const void* torch_DoubleTensor_id = NULL;
#include "generic/SpatialLinear.c"
#include "THGenerateFloatTypes.h"
-#include "generic/SpatialConvolutionSparse.c"
-#include "THGenerateFloatTypes.h"
-
#include "generic/SpatialMaxPooling.c"
#include "THGenerateFloatTypes.h"
@@ -63,7 +60,6 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
nn_FloatHardShrink_init(L);
nn_FloatAbs_init(L);
nn_FloatThreshold_init(L);
- nn_FloatSpatialConvolutionSparse_init(L);
nn_FloatSpatialLogSoftMax_init(L);
nn_FloatSpatialMaxPooling_init(L);
nn_FloatSpatialUpSampling_init(L);
@@ -79,7 +75,6 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
nn_DoubleHardShrink_init(L);
nn_DoubleAbs_init(L);
nn_DoubleThreshold_init(L);
- nn_DoubleSpatialConvolutionSparse_init(L);
nn_DoubleSpatialLogSoftMax_init(L);
nn_DoubleSpatialMaxPooling_init(L);
nn_DoubleSpatialUpSampling_init(L);
diff --git a/init.lua b/init.lua
index 3519294..10d88ce 100644
--- a/init.lua
+++ b/init.lua
@@ -79,7 +79,6 @@ end
-- spatial (images) operators:
torch.include('nnx', 'SpatialLinear.lua')
torch.include('nnx', 'SpatialLogSoftMax.lua')
-torch.include('nnx', 'SpatialConvolutionSparse.lua')
torch.include('nnx', 'SpatialMaxPooling.lua')
torch.include('nnx', 'SpatialPadding.lua')
torch.include('nnx', 'SpatialNormalization.lua')
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index dcb8d1b..f872955 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -72,7 +72,6 @@ build = {
install_files(/lua/nnx Sqrt.lua)
install_files(/lua/nnx Threshold.lua)
install_files(/lua/nnx OmpModule.lua)
- install_files(/lua/nnx SpatialConvolutionSparse.lua)
install_files(/lua/nnx SpatialLogSoftMax.lua)
install_files(/lua/nnx SpatialMaxPooling.lua)
install_files(/lua/nnx SpatialLinear.lua)