author     Soumith Chintala <soumith@gmail.com>    2014-12-03 12:09:23 +0300
committer  Soumith Chintala <soumith@gmail.com>    2014-12-03 12:09:23 +0300
commit     8c896314e9aa8540132f32d8e2577a57c35f39cd (patch)
tree       31d6a0130f41bfc5a93401aa9ca4562eb30c1840
parent     e88dc959d5b4c7a70c98c2418b7ef39d6b24731c (diff)
parent     b9212fde4c889d2e811534777ebb124b08ba2535 (diff)
Merge branch 'master' of github.com:torch/nn
-rw-r--r--   .luacheckrc                           |  13
-rw-r--r--   CMakeLists.txt                        |   1
-rw-r--r--   ClassNLLCriterion.lua                 |  14
-rw-r--r--   Concat.lua                            |  16
-rw-r--r--   CosineEmbeddingCriterion.lua          |   6
-rw-r--r--   CriterionTable.lua                    |   1
-rw-r--r--   DepthConcat.lua                       |   6
-rw-r--r--   Euclidean.lua                         |   2
-rw-r--r--   FlattenTable.lua                      |  12
-rw-r--r--   Identity.lua                          |   2
-rw-r--r--   Jacobian.lua                          |  11
-rw-r--r--   LookupTable.lua                       |   2
-rw-r--r--   Mean.lua                              |  14
-rw-r--r--   Module.lua                            |   2
-rw-r--r--   Parallel.lua                          |  19
-rw-r--r--   Sequential.lua                        |   1
-rw-r--r--   SoftMax.lua                           |   2
-rw-r--r--   SoftMin.lua                           |   2
-rw-r--r--   SparseJacobian.lua                    |   4
-rw-r--r--   SpatialAveragePooling.lua             |  20
-rw-r--r--   SpatialConvolutionMap.lua             |   2
-rw-r--r--   SplitTable.lua                        |   1
-rw-r--r--   Sum.lua                               |  23
-rw-r--r--   Transpose.lua                         |   1
-rw-r--r--   WeightedEuclidean.lua                 |   4
-rwxr-xr-x   doc/convolution.md                    |  12
-rw-r--r--   doc/simple.md                         |  26
-rw-r--r--   generic/LogSoftMax.c                  |  37
-rw-r--r--   generic/Max.c                         |  40
-rw-r--r--   generic/Min.c                         |  40
-rw-r--r--   generic/SpatialAveragePooling.c       | 190
-rw-r--r--   hessian.lua                           |   2
-rw-r--r--   init.c                                |   5
-rw-r--r--   init.lua                              |   1
-rw-r--r--   test.lua (renamed from test/test.lua) | 401
35 files changed, 656 insertions, 279 deletions
diff --git a/.luacheckrc b/.luacheckrc
new file mode 100644
index 0000000..3d358e9
--- /dev/null
+++ b/.luacheckrc
@@ -0,0 +1,13 @@
+-- -*- mode: lua; -*-
+std = "luajit"
+
+globals = {
+   "torch",
+   "nn",
+   "include",
+}
+
+unused_args = false
+
+
+files['test.lua'].redefined = false
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1138954..e732709 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -44,7 +44,6 @@ LINK_DIRECTORIES("${Torch_INSTALL_LIB}")
 
 SET(src init.c)
 FILE(GLOB luasrc *.lua)
-SET(luasrc ${luasrc} test/test.lua)
 
 ADD_TORCH_PACKAGE(nn "${src}" "${luasrc}")
 
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 2679e81..926e707 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -3,13 +3,20 @@ local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criter
 function ClassNLLCriterion:__init(weights)
    parent.__init(self)
    self.sizeAverage = true
+   self.outputTensor = torch.Tensor(1)
    if weights then
-      assert(weights:dim() == 1, "weights input should be 1-D Tensor")
+      assert(weights:dim() == 1, "weights input should be 1-D Tensor")
       self.weights = weights
    end
 end
 
 function ClassNLLCriterion:updateOutput(input, target)
+   if input:type() == 'torch.CudaTensor' and not self.weights then
+      input.nn.ClassNLLCriterion_updateOutput(self, input, target)
+      self.output = self.outputTensor[1]
+      return self.output
+   end
+
    if input:dim() == 1 then
       self.output = -input[target]
       if self.weights then
@@ -38,6 +45,11 @@ function ClassNLLCriterion:updateGradInput(input, target)
    self.gradInput:resizeAs(input)
    self.gradInput:zero()
 
+   if input:type() == 'torch.CudaTensor' and not self.weights then
+      input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
+      return self.gradInput
+   end
+
    if input:dim() == 1 then
       self.gradInput[target] = -1
       if self.weights then
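The `ClassNLLCriterion` hunks above add a CUDA fast path: when the input is a `torch.CudaTensor` and no class weights are set, the loss is computed by a C kernel (presumably registered by the cunn package) and read back out of `self.outputTensor`. A minimal sketch of the CPU path that the remaining Lua branch implements; the input values are made up for illustration:

```lua
require 'nn'

local crit = nn.ClassNLLCriterion()
local input = torch.randn(5)   -- stand-in for log-probabilities over 5 classes
local target = 3               -- ground-truth class index

local loss = crit:forward(input, target)   -- -input[target] for a 1D input
local grad = crit:backward(input, target)  -- zeros except grad[target] = -1
print(loss, grad[target])
```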
diff --git a/Concat.lua b/Concat.lua
@@ -63,9 +63,10 @@ function Concat:accGradParameters(input, gradOutput, scale)
    local offset = 1
    for i,module in ipairs(self.modules) do
       local currentOutput = module.output
-      local currentGradInput = module:accGradParameters(input,
-          gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)),
-          scale)
+      module:accGradParameters(
+          input,
+          gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)),
+          scale)
       offset = offset + currentOutput:size(self.dimension)
    end
 end
@@ -74,9 +75,10 @@ function Concat:accUpdateGradParameters(input, gradOutput, lr)
    local offset = 1
    for i,module in ipairs(self.modules) do
       local currentOutput = module.output
-      local currentGradInput = module:accUpdateGradParameters(input,
-          gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)),
-          lr)
+      module:accUpdateGradParameters(
+          input,
+          gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)),
+          lr)
       offset = offset + currentOutput:size(self.dimension)
    end
 end
@@ -140,7 +142,7 @@ function Concat:__tostring__()
    local ext = ' | '
    local extlast = ' '
    local last = ' ... -> '
-   local str = 'nn.Concat'
+   local str = torch.type(self)
    str = str .. ' {' .. line .. tab .. 'input'
    for i=1,#self.modules do
       if i == #self.modules then
diff --git a/CosineEmbeddingCriterion.lua b/CosineEmbeddingCriterion.lua
index 93348fb..293ae23 100644
--- a/CosineEmbeddingCriterion.lua
+++ b/CosineEmbeddingCriterion.lua
@@ -23,12 +23,6 @@ function CosineEmbeddingCriterion:updateOutput(input,y)
    return self.output
 end
 
-local function mathsign(t)
-   if t>0 then return 1; end
-   if t<0 then return -1; end
-   return 2*torch.random(2)-3;
-end
-
 function CosineEmbeddingCriterion:updateGradInput(input, y)
    local v1 = input[1]
    local v2 = input[2]
diff --git a/CriterionTable.lua b/CriterionTable.lua
index e5538f7..be00837 100644
--- a/CriterionTable.lua
+++ b/CriterionTable.lua
@@ -1,6 +1,7 @@
 local CriterionTable, parent = torch.class('nn.CriterionTable', 'nn.Module')
 
 function CriterionTable:__init(criterion)
+   parent.__init(self)
    self.criterion = criterion
    self.gradInput = {criterion.gradInput}
 end
diff --git a/DepthConcat.lua b/DepthConcat.lua
index 70646f4..7187d61 100644
--- a/DepthConcat.lua
+++ b/DepthConcat.lua
@@ -9,7 +9,7 @@
 -- this, we select the largest spatial dimensions and add zero-padding
 -- around the smaller dimensions.
 ------------------------------------------------------------------------
-local DepthConcat, parent = torch.class('nn.DepthConcat', 'nn.Concat')
+local DepthConcat, _ = torch.class('nn.DepthConcat', 'nn.Concat')
 
 function DepthConcat:windowNarrow(output, currentOutput, offset)
    local outputWindow = output:narrow(self.dimension, offset, currentOutput:size(self.dimension))
@@ -79,7 +79,7 @@ function DepthConcat:accGradParameters(input, gradOutput, scale)
    for i,module in ipairs(self.modules) do
      local currentOutput = module.output
      local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset)
-     local currentGradInput = module:accGradParameters(input, gradOutputWindow, scale)
+     module:accGradParameters(input, gradOutputWindow, scale)
      offset = offset + currentOutput:size(self.dimension)
    end
 end
@@ -89,7 +89,7 @@ function DepthConcat:accUpdateGradParameters(input, gradOutput, lr)
    for i,module in ipairs(self.modules) do
      local currentOutput = module.output
      local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset)
-     local currentGradInput = module:accUpdateGradParameters(input, gradOutputWindow, lr)
+     module:accUpdateGradParameters(input, gradOutputWindow, lr)
      offset = offset + currentOutput:size(self.dimension)
    end
 end
diff --git a/Euclidean.lua b/Euclidean.lua
index c0dd99c..229d792 100644
--- a/Euclidean.lua
+++ b/Euclidean.lua
@@ -61,7 +61,7 @@ function Euclidean:accGradParameters(input, gradOutput, scale)
       if self.output[o] ~= 0 then
          self.temp:copy(self.weight:select(2,o)):add(-1,input)
          self.temp:mul(gradOutput[o]/self.output[o])
-         self.gradWeight:select(2,o):add(self.temp)
+         self.gradWeight:select(2,o):add(scale, self.temp)
       end
    end
 end
diff --git a/FlattenTable.lua b/FlattenTable.lua
index 849daa7..3a88588 100644
--- a/FlattenTable.lua
+++ b/FlattenTable.lua
@@ -39,10 +39,10 @@ local function checkMapping(output, input, input_map)
     end
     -- forward DFS order
     for i = 1, #input do
-      ok = checkMapping(output, input[i], input_map[i])
-      if not ok then
-        return false
-      end
+      local ok = checkMapping(output, input[i], input_map[i])
+      if not ok then
+        return false
+      end
     end
     return true
   else
@@ -77,7 +77,7 @@ function FlattenTable:updateOutput(input)
     self.input_map = flatten(self.output, input)
   end
   return self.output
-end
+end
 
 function FlattenTable:updateGradInput(input, gradOutput)
   assert(type(input) == 'table', 'input must be a table')
@@ -90,7 +90,7 @@ function FlattenTable:updateGradInput(input, gradOutput)
   if not checkMapping(gradOutput, self.gradInput, self.input_map) then
     self.gradInput = inverseFlatten(gradOutput, self.input_map)
   end
-
+
   return self.gradInput
 end
diff --git a/Identity.lua b/Identity.lua
index 79b5c08..088cc34 100644
--- a/Identity.lua
+++ b/Identity.lua
@@ -1,4 +1,4 @@
-local Identity, parent = torch.class('nn.Identity', 'nn.Module')
+local Identity, _ = torch.class('nn.Identity', 'nn.Module')
 
 function Identity:updateOutput(input)
    self.output = input
diff --git a/Jacobian.lua b/Jacobian.lua
index debdcd7..c3797bd 100644
--- a/Jacobian.lua
+++ b/Jacobian.lua
@@ -52,7 +52,7 @@ function nn.Jacobian.backwardUpdate(module, input, param)
       end
       dout:zero()
       sdout[i] = 1
-      local din = module:updateGradInput(input, dout)
+      module:updateGradInput(input, dout)
       module:accUpdateGradParameters(input, dout, 1)
       jacobian:select(2,i):copy(param)
    end
@@ -170,17 +170,18 @@ function nn.Jacobian.testIO(module,input, minval, maxval)
    local bo = module.gradInput:clone()
 
    -- write module
-   local f = torch.DiskFile('tmp.bin','w'):binary()
+   local filename = os.tmpname()
+   local f = torch.DiskFile(filename, 'w'):binary()
    f:writeObject(module)
    f:close()
    -- read module
-   local m = torch.DiskFile('tmp.bin'):binary():readObject()
+   local m = torch.DiskFile(filename):binary():readObject()
    m:forward(input)
    m:zeroGradParameters()
    m:updateGradInput(input,go)
    m:accGradParameters(input,go)
    -- cleanup
-   os.remove('tmp.bin')
+   os.remove(filename)
    local fo2 = m.output:clone()
    local bo2 = m.gradInput:clone()
@@ -241,7 +242,7 @@ function nn.Jacobian.testAllUpdate(module, input, weight, gradWeight)
    macshu2:updateGradInput(input, gradOutput)
    macshu1:accUpdateGradParameters(input, gradOutput, lr)
    macshu2:accUpdateGradParameters(input, gradOutput, lr)
-   local err = (weightc-maccgp[gradWeight]*(lr*2)-macshu1[weight]):norm()
+   err = (weightc-maccgp[gradWeight]*(lr*2)-macshu1[weight]):norm()
    err = err + (weightc-maccgp[gradWeight]*(lr*2)-macshu2[weight]):norm()
    errors["accUpdateGradParameters [shared]"] = err
diff --git a/LookupTable.lua b/LookupTable.lua
index 71d7f62..5b5f565 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -117,6 +117,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
       for i=1,input:size(1) do
          local k = input[i]
         local kscale = self:scaleUpdateByKey(k)
+        self.inputs[k] = (self.inputs[k] or 0) + 1
         self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i))
       end
    elseif input:dim() == 2 then
@@ -126,6 +127,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
         for j=1,input:size(1) do
            local k = input[j]
            local kscale = self:scaleUpdateByKey(k)
+           self.inputs[k] = (self.inputs[k] or 0) + 1
            self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j))
         end
      end
diff --git a/Mean.lua b/Mean.lua
@@ -8,15 +8,23 @@ end
 
 function Mean:updateOutput(input)
    self.output:mean(input, self.dimension)
-   self.output = self.output:select(self.dimension, 1)
+   if self.output:nDimension() > 1 then
+      self.output = self.output:select(self.dimension, 1)
+   end
    return self.output
 end
 
 function Mean:updateGradInput(input, gradOutput)
    local size = gradOutput:size():totable()
    local stride = gradOutput:stride():totable()
-   table.insert(size, self.dimension, input:size(self.dimension))
-   table.insert(stride, self.dimension, 0)
+
+   if input:nDimension() > 1 then
+      table.insert(size, self.dimension, input:size(self.dimension))
+      table.insert(stride, self.dimension, 0)
+   else
+      size[1] = input:size(1)
+      stride[1] = 0
+   end
 
    self.gradInput:resizeAs(gradOutput):copy(gradOutput)
    self.gradInput:mul(1/input:size(self.dimension))
diff --git a/Module.lua b/Module.lua
@@ -171,7 +171,7 @@ function Module:getParameters()
          if storageAndOffset == nil then
             return nil
          end
-         local storage, offset = unpack(storageAndOffset)
+         local _, offset = unpack(storageAndOffset)
          return offset
       end
diff --git a/Parallel.lua b/Parallel.lua
index 547f444..3057ba2 100644
--- a/Parallel.lua
+++ b/Parallel.lua
@@ -71,10 +71,12 @@ function Parallel:accGradParameters(input, gradOutput, scale)
    for i=1,nModule do
       local module = self.modules[i];
       local currentOutput = module.output
-      local currentGradInput =
-         module:accGradParameters(input:select(self.inputDimension,i),
-                                  gradOutput:narrow(self.outputDimension,
-                                                    offset, currentOutput:size(self.outputDimension)), scale)
+      module:accGradParameters(
+          input:select(self.inputDimension,i),
+          gradOutput:narrow(
+              self.outputDimension, offset,
+              currentOutput:size(self.outputDimension)),
+          scale)
 
       offset = offset + currentOutput:size(self.outputDimension)
    end
@@ -87,10 +89,11 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
    for i=1,nModule do
       local module = self.modules[i];
       local currentOutput = module.output
-      local currentGradInput =
-         module:accUpdateGradParameters(input:select(self.inputDimension,i),
-                                        gradOutput:narrow(self.outputDimension,
-                                                          offset, currentOutput:size(self.outputDimension)), lr)
+      module:accUpdateGradParameters(
+          input:select(self.inputDimension,i),
+          gradOutput:narrow(self.outputDimension, offset,
+                            currentOutput:size(self.outputDimension)),
+          lr)
 
       offset = offset + currentOutput:size(self.outputDimension)
    end
diff --git a/Sequential.lua b/Sequential.lua
index ec3247b..97554b3 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -1,6 +1,7 @@
 local Sequential, parent = torch.class('nn.Sequential', 'nn.Module')
 
 function Sequential:__init()
+   parent.__init(self)
    self.modules = {}
 end
 
diff --git a/SoftMax.lua b/SoftMax.lua
index 609b353..22f0eda 100644
--- a/SoftMax.lua
+++ b/SoftMax.lua
@@ -1,4 +1,4 @@
-local SoftMax, parent = torch.class('nn.SoftMax', 'nn.Module')
+local SoftMax, _ = torch.class('nn.SoftMax', 'nn.Module')
 
 function SoftMax:updateOutput(input)
    return input.nn.SoftMax_updateOutput(self, input)
diff --git a/SoftMin.lua b/SoftMin.lua
index 90c6c60..7d2358c 100644
--- a/SoftMin.lua
+++ b/SoftMin.lua
@@ -1,4 +1,4 @@
-local SoftMin, parent = torch.class('nn.SoftMin', 'nn.Module')
+local SoftMin, _ = torch.class('nn.SoftMin', 'nn.Module')
 
 function SoftMin:updateOutput(input)
    self.mininput = self.mininput or input.new()
diff --git a/SparseJacobian.lua b/SparseJacobian.lua
index b778e67..19334d1 100644
--- a/SparseJacobian.lua
+++ b/SparseJacobian.lua
@@ -61,7 +61,7 @@ function nn.SparseJacobian.backwardUpdate (module, input, param)
       dout:zero()
       sdout[i] = 1
       module:zeroGradParameters()
-      local din = module:updateGradInput(input, dout)
+      module:updateGradInput(input, dout)
       module:accUpdateGradParameters(input, dout, 1)
       jacobian:select(2,i):copy(param)
    end
@@ -269,7 +269,7 @@ function nn.SparseJacobian.testAllUpdate(module, input, weight, gradWeight)
    macshu2:updateGradInput(input, gradOutput)
    macshu1:accUpdateGradParameters(input, gradOutput, lr)
    macshu2:accUpdateGradParameters(input, gradOutput, lr)
-   local err = (weightc-maccgp[gradWeight]*(lr*2)-macshu1[weight]):norm()
+   err = (weightc-maccgp[gradWeight]*(lr*2)-macshu1[weight]):norm()
    err = err + (weightc-maccgp[gradWeight]*(lr*2)-macshu2[weight]):norm()
    errors["accUpdateGradParameters [shared]"] = err
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua
new file mode 100644
index 0000000..13b6b45
--- /dev/null
+++ b/SpatialAveragePooling.lua
@@ -0,0 +1,20 @@
+local SpatialAveragePooling, parent = torch.class('nn.SpatialAveragePooling', 'nn.Module')
+
+function SpatialAveragePooling:__init(kW, kH, dW, dH)
+   parent.__init(self)
+
+   self.kW = kW
+   self.kH = kH
+   self.dW = dW or 1
+   self.dH = dH or 1
+end
+
+function SpatialAveragePooling:updateOutput(input)
+   return input.nn.SpatialAveragePooling_updateOutput(self, input)
+end
+
+function SpatialAveragePooling:updateGradInput(input, gradOutput)
+   if self.gradInput then
+      return input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput)
+   end
+end
diff --git a/SpatialConvolutionMap.lua b/SpatialConvolutionMap.lua
index e05ce6e..390ace0 100644
--- a/SpatialConvolutionMap.lua
+++ b/SpatialConvolutionMap.lua
@@ -29,9 +29,7 @@ function nn.tables.random(nin, nout, nto)
    local tbl = torch.Tensor(nker, 2)
    local fi = torch.randperm(nin)
    local frcntr = 1
-   local tocntr = 1
    local nfi = math.floor(nin/nto) -- number of distinct nto chunks
-   local rfi = math.fmod(nin,nto) -- number of remaining from maps
    local totbl = tbl:select(2,2)
    local frtbl = tbl:select(2,1)
    local fitbl = fi:narrow(1, 1, (nfi * nto)) -- part of fi that covers distinct chunks
diff --git a/SplitTable.lua b/SplitTable.lua
index 70b45f6..bd46b71 100644
--- a/SplitTable.lua
+++ b/SplitTable.lua
@@ -28,7 +28,6 @@ function SplitTable:updateGradInput(input, gradOutput)
    local slices = input:size(dimension)
 
    self.gradInput:resizeAs(input)
-   local offset = 1
    for i=1,slices do
       local currentGradInput = gradOutput[i];
       self.gradInput:select(dimension,i):copy(currentGradInput)
diff --git a/Sum.lua b/Sum.lua
@@ -11,20 +11,21 @@ function Sum:updateOutput(input)
       self.output = input.new()
    end
    self.output:sum(input, self.dimension)
-   self.output = self.output:select(self.dimension, 1)
+   if self.output:nDimension() > 1 then
+      self.output = self.output:select(self.dimension, 1)
+   end
    return self.output
 end
 
 function Sum:updateGradInput(input, gradOutput)
-   local size = gradOutput:size():totable()
-   local stride = gradOutput:stride():totable()
-   table.insert(size, self.dimension, input:size(self.dimension))
-   table.insert(stride, self.dimension, 0)
+   -- zero-strides dont work with MKL/BLAS, so
+   -- dont set self.gradInput to zero-stride tensor.
+   -- Instead, do a deepcopy
+   local size = input:size()
+   size[self.dimension] = 1
+   gradOutput = gradOutput:view(size)
+   self.gradInput:resizeAs(input)
+   self.gradInput:copy(gradOutput:expandAs(input))
 
-   self.gradInput:set(gradOutput:storage(),
-                      1,
-                      torch.LongStorage(size),
-                      torch.LongStorage(stride))
-
-   return self.gradInput
+   return self.gradInput
 end
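The rewritten `Sum:updateGradInput` above trades the old zero-stride view for an explicit copy, for the reason stated in its comment: MKL/BLAS kernels cannot safely consume tensors whose elements alias each other through a zero stride. A small sketch of the two representations, assuming a 2D input summed over dimension 1 (the sizes are illustrative):

```lua
local input      = torch.randn(4, 3)
local gradOutput = torch.randn(3)      -- gradient w.r.t. the summed output

-- Zero-stride view: all 4 rows alias the same 3 numbers (stride 0 along dim 1).
local size = input:size(); size[1] = 1
local expanded = gradOutput:view(size):expandAs(input)

-- What the new code does instead: materialize a real, contiguous copy.
local gradInput = torch.Tensor():resizeAs(input):copy(expanded)
print(expanded:stride(1), gradInput:stride(1))  -- 0 vs. 3
```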
diff --git a/Transpose.lua b/Transpose.lua
index a43729b..263db60 100644
--- a/Transpose.lua
+++ b/Transpose.lua
@@ -18,7 +18,6 @@ function Transpose:updateOutput(input)
 end
 
 function Transpose:updateGradInput(input, gradOutput)
-   local ndim = gradOutput:nDimension()
    for i = #self.permutations,1,-1 do
      local perm = self.permutations[i]
      gradOutput = gradOutput:transpose(perm[1],perm[2])
diff --git a/WeightedEuclidean.lua b/WeightedEuclidean.lua
index c4a1dbc..3808db6 100644
--- a/WeightedEuclidean.lua
+++ b/WeightedEuclidean.lua
@@ -75,13 +75,13 @@ function WeightedEuclidean:accGradParameters(input, gradOutput, scale)
          self.temp:copy(self.templates:select(2,o)):add(-1,input)
          self.temp:cmul(self.diagCov:select(2,o)):cmul(self.diagCov:select(2,o))
          self.temp:mul(gradOutput[o]/self.output[o])
-         self.gradTemplates:select(2,o):add(self.temp)
+         self.gradTemplates:select(2,o):add(scale, self.temp)
 
          self.temp:copy(self.templates:select(2,o)):add(-1,input)
          self.temp:cmul(self.temp)
          self.temp:cmul(self.diagCov:select(2,o))
          self.temp:mul(gradOutput[o]/self.output[o])
-         self.gradDiagCov:select(2,o):add(self.temp)
+         self.gradDiagCov:select(2,o):add(scale, self.temp)
       end
    end
 end
diff --git a/doc/convolution.md b/doc/convolution.md
index 5571b19..c65222d 100755
--- a/doc/convolution.md
+++ b/doc/convolution.md
@@ -12,6 +12,7 @@ A convolution is an integral that expresses the amount of overlap of one functio
     * [SpatialConvolution](#nn.SpatialConvolution) : a 2D convolution over an input image ;
     * [SpatialSubSampling](#nn.SpatialSubSampling) : a 2D sub-sampling over an input image ;
     * [SpatialMaxPooling](#nn.SpatialMaxPooling) : a 2D max-pooling operation over an input image ;
+    * [SpatialAveragePooling](#nn.SpatialAveragePooling) : a 2D average-pooling operation over an input image ;
     * [SpatialLPPooling](#nn.SpatialLPPooling) : computes the `p` norm in a convolutional manner on a set of input images ;
     * [SpatialConvolutionMap](#nn.SpatialConvolutionMap) : a 2D convolution that uses a generic connection table ;
     * [SpatialZeroPadding](#nn.SpatialZeroPadding) : padds a feature map with specified number of zeros ;
@@ -356,6 +357,17 @@ Applies 2D max-pooling operation in `kWxkH` regions by step size
 `dWxdH` steps. The number of output features is equal to the number of input
 planes.
 
+<a name="nn.SpatialAveragePooling"/>
+### SpatialAveragePooling ###
+
+```lua
+module = nn.SpatialAveragePooling(kW, kH [, dW, dH])
+```
+
+Applies 2D average-pooling operation in `kWxkH` regions by step size
+`dWxdH` steps. The number of output features is equal to the number of
+input planes.
+
 <a name="nn.SpatialSubSampling"/>
 ### SpatialSubSampling ###
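Since `nn.SpatialAveragePooling` is new in this merge, a quick usage sketch may help. The constructor signature comes from the doc hunk above, and the output resolution follows the `(in - k)/d + 1` rule in the C kernel further down. Note that, as written in this commit, the kernel accumulates a plain window sum (the new test compares it against `SpatialSubSampling` with unit weight and zero bias). The concrete sizes here are illustrative:

```lua
require 'nn'

-- 3 input planes of 8x8, pooled with 2x2 windows moved 2 pixels per step
local pool = nn.SpatialAveragePooling(2, 2, 2, 2)
local input = torch.randn(3, 8, 8)
local output = pool:forward(input)
print(output:size())  -- 3x4x4, since (8 - 2)/2 + 1 = 4
```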
diff --git a/doc/simple.md b/doc/simple.md
index 93ea08e..aa1a94d 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -12,22 +12,22 @@ and providing affine transformations :
     * [Euclidean](#nn.Euclidean) : the euclidean distance of the input to `k` mean centers ;
     * [WeightedEuclidean](#nn.WeightedEuclidean) : similar to [Euclidean](#nn.Euclidean), but additionally learns a diagonal covariance matrix ;
   * Modules that adapt basic Tensor methods :
-    * [Copy](#nn.Copy) : a [copy](https://github.com/nicholas-leonard/torch7/blob/doc/doc/tensor.md#torch.Tensor.copy) of the input with [type](https://github.com/nicholas-leonard/torch7/blob/doc/doc/tensor.md#tensor-or-string-typetype) casting ;
-    * [Narrow](#nn.Narrow) : a [narrow](https://github.com/nicholas-leonard/torch7/blob/doc/doc/tensor.md#tensor-narrowdim-index-size) operation over a given dimension ;
-    * [Replicate](#nn.Replicate) : [repeats](https://github.com/nicholas-leonard/torch7/blob/doc/doc/tensor.md#tensor-repeattensorsizes) input `n` times along its first dimension ;
+    * [Copy](#nn.Copy) : a [copy](https://github.com/torch/torch7/blob/master/doc/tensor.md#torch.Tensor.copy) of the input with [type](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-or-string-typetype) casting ;
+    * [Narrow](#nn.Narrow) : a [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation over a given dimension ;
+    * [Replicate](#nn.Replicate) : [repeats](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-repeattensorresult-sizes) input `n` times along its first dimension ;
     * [Reshape](#nn.Reshape) : a [reshape](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchreshaperes-x-m-n) of the inputs ;
-    * [View](#nn.View) : a [view](https://github.com/nicholas-leonard/torch7/blob/doc/doc/tensor.md#result-viewresult-tensor-sizes) of the inputs ;
-    * [Select](#nn.Select) : a [select](https://github.com/nicholas-leonard/torch7/blob/doc/doc/tensor.md#tensor-selectdim-index) over a given dimension ;
+    * [View](#nn.View) : a [view](https://github.com/torch/torch7/blob/master/doc/tensor.md#result-viewresult-tensor-sizes) of the inputs ;
+    * [Select](#nn.Select) : a [select](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-selectdim-index) over a given dimension ;
   * Modules that adapt mathematical Tensor methods :
-    * [Max](#nn.Max) : a [max](https://github.com/nicholas-leonard/torch7/blob/doc/doc/maths.md#torchmaxresval-resind-x-dim) operation over a given dimension ;
-    * [Min](#nn.Min) : a [min](https://github.com/nicholas-leonard/torch7/blob/doc/doc/maths.md#torchminresval-resind-x) operation over a given dimension ;
-    * [Mean](#nn.Mean) : a [mean](https://github.com/nicholas-leonard/torch7/blob/doc/doc/maths.md#res-torchmeanres-x-dim) operation over a given dimension ;
-    * [Sum](#nn.Sum) : a [sum](https://github.com/nicholas-leonard/torch7/blob/doc/doc/maths.md#res-torchsumres-x) operation over a given dimension ;
+    * [Max](#nn.Max) : a [max](https://github.com/torch/torch7/blob/master/doc/maths.md#torch.max) operation over a given dimension ;
+    * [Min](#nn.Min) : a [min](https://github.com/torch/torch7/blob/master/doc/maths.md#torchminresval-resind-x) operation over a given dimension ;
+    * [Mean](#nn.Mean) : a [mean](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchmeanres-x-dim) operation over a given dimension ;
+    * [Sum](#nn.Sum) : a [sum](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchsumres-x) operation over a given dimension ;
     * [Exp](#nn.Exp) : an element-wise [exp](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchexpres-x) operation ;
     * [Abs](#nn.Abs) : an element-wise [abs](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchabsres-x) operation ;
-    * [Power](#nn.Power) : an element-wise [pow](https://github.com/nicholas-leonard/torch7/blob/doc/doc/maths.md#res-torchpowres-x) operation ;
+    * [Power](#nn.Power) : an element-wise [pow](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchpowres-x) operation ;
     * [Square](#nn.Square) : an element-wise square operation ;
-    * [Sqrt](#nn.Sqrt) : an element-wise [sqrt](https://github.com/nicholas-leonard/torch7/blob/doc/doc/maths.md#res-torchsqrtres-x) operation ;
+    * [Sqrt](#nn.Sqrt) : an element-wise [sqrt](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchsqrtres-x) operation ;
   * Miscellaneous Modules :
     * [Identity](#nn.Identity) : forward input as-is to output (useful with [ParallelTable](table.md#nn.ParallelTable));
     * [Dropout](#nn.Dropout) : masks parts of the `input` using binary samples from a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution ;
@@ -472,7 +472,7 @@ type from `inputType` to `outputType`.
 `module` = `Narrow(dimension, offset, length)`
 
 Narrow is application of
-[narrow](https://github.com/nicholas-leonard/torch7/blob/doc/doc/tensor.md#tensor-narrowdim-index-size) operation in a
+[narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a
 module.
 
 <a name="nn.Replicate"/>
@@ -483,7 +483,7 @@ module.
 This class creates an output where the input is replicated `nFeature` times
 along its first dimension. There is no memory allocation or memory copy in this module.
 It sets the
-[stride](https://github.com/nicholas-leonard/torch7/blob/doc/doc/tensor.md#number-stridedim) along the first
+[stride](https://github.com/torch/torch7/blob/master/doc/tensor.md#torch.Tensor.stride) along the first
 dimension to zero.
 
 ```lua
diff --git a/generic/LogSoftMax.c b/generic/LogSoftMax.c
index 7741e3b..75b8587 100644
--- a/generic/LogSoftMax.c
+++ b/generic/LogSoftMax.c
@@ -26,12 +26,19 @@ static int nn_(LogSoftMax_updateOutput)(lua_State *L)
   input = THTensor_(newContiguous)(input);
   THTensor_(resizeAs)(output, input);
 
-  input_data = THTensor_(data)(input);
-  output_data = THTensor_(data)(output);
+  real* input_data0 = THTensor_(data)(input);
+  real* output_data0 = THTensor_(data)(output);
+
+  accreal logsum;
+  real maxInput;
+#pragma omp parallel for private(t, d, maxInput, logsum, input_data, \
+                                 output_data)
   for(t = 0; t < nframe; t++)
   {
-    accreal logsum = 0;
-    real maxInput = -THInf;
+    logsum = 0;
+    maxInput = -THInf;
+    input_data = input_data0 + dim*t;
+    output_data = output_data0 + dim*t;
 
     for(d = 0; d < dim; d++)
       maxInput = THMax(maxInput, input_data[d]);
@@ -42,9 +49,6 @@ static int nn_(LogSoftMax_updateOutput)(lua_State *L)
 
     for(d = 0; d < dim; d++)
       output_data[d] = input_data[d] - logsum;
-
-    input_data += dim;
-    output_data += dim;
   }
 
   THTensor_(free)(input);
@@ -75,21 +79,24 @@ static int nn_(LogSoftMax_updateGradInput)(lua_State *L)
     THError("vector or matrix expected");
 
   THTensor_(resizeAs)(gradInput, output);
-  gradInput_data = THTensor_(data)(gradInput);
-  output_data = THTensor_(data)(output);
-  gradOutput_data = THTensor_(data)(gradOutput);
+  real* gradInput_data0 = THTensor_(data)(gradInput);
+  real* output_data0 = THTensor_(data)(output);
+  real* gradOutput_data0 = THTensor_(data)(gradOutput);
 
+  accreal sum;
+#pragma omp parallel for private(t, sum, d, gradInput_data, output_data, \
+                                 gradOutput_data)
   for(t = 0; t < nframe; t++)
   {
-    accreal sum = 0;
+    sum = 0;
+    gradInput_data = gradInput_data0 + dim*t;
+    output_data = output_data0 + dim*t;
+    gradOutput_data = gradOutput_data0 + dim*t;
+
     for(d = 0; d < dim; d++)
      sum += gradOutput_data[d];
 
    for(d = 0; d < dim; d++)
      gradInput_data[d] = gradOutput_data[d] - exp(output_data[d])*sum;
-
-    gradInput_data += dim;
-    output_data += dim;
-    gradOutput_data += dim;
   }
 
   return 1;
diff --git a/generic/Max.c b/generic/Max.c
index fe76801..79823a4 100644
--- a/generic/Max.c
+++ b/generic/Max.c
@@ -36,7 +36,8 @@ static int nn_(Max_updateOutput)(lua_State *L)
                        *indices_data = theIndex+1;
                        *output_data = theMax;)
 
-  THTensor_(select)(output, NULL, dimension, 0);
+  if(output->nDimension > 1)
+    THTensor_(select)(output, NULL, dimension, 0);
 
   return 1;
 }
@@ -56,25 +57,32 @@ static int nn_(Max_updateGradInput)(lua_State *L)
   THTensor_(resizeAs)(gradInput, input);
   THTensor_(zero)(gradInput);
 
-  dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
-  str = THLongStorage_newWithSize(gradOutput->nDimension+1);
-  for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
+  if(input->nDimension > 1)
   {
-    if(j == dimension)
+    dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
+    str = THLongStorage_newWithSize(gradOutput->nDimension+1);
+    for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
     {
-      dim->data[j] = input->size[dimension];
-      str->data[j] = 0;
-      continue;
+      if(j == dimension)
+      {
+        dim->data[j] = input->size[dimension];
+        str->data[j] = 0;
+        continue;
+      }
+
+      dim->data[j] = gradOutput->size[i];
+      str->data[j] = gradOutput->stride[i];
+      i++;
     }
-
-    dim->data[j] = gradOutput->size[i];
-    str->data[j] = gradOutput->stride[i];
-    i++;
+    gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
+    THLongStorage_free(dim);
+    THLongStorage_free(str);
+  }
+  else
+  {
+    THTensor_(retain)(gradOutput);
+    gradOutputPlusOneDim = gradOutput;
   }
-
-  gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
-  THLongStorage_free(dim);
-  THLongStorage_free(str);
 
   TH_TENSOR_DIM_APPLY3(real, gradInput, real, gradOutputPlusOneDim, real, indices, dimension,
                        gradInput_data[ ((long)(*indices_data)-1)*gradInput_stride ] = *gradOutputPlusOneDim_data;)
diff --git a/generic/Min.c b/generic/Min.c
index ab19c06..020f54b 100644
--- a/generic/Min.c
+++ b/generic/Min.c
@@ -36,7 +36,8 @@ static int nn_(Min_updateOutput)(lua_State *L)
                        *indices_data = theIndex+1;
                        *output_data = theMin;)
 
-  THTensor_(select)(output, NULL, dimension, 0);
+  if(output->nDimension > 1)
+    THTensor_(select)(output, NULL, dimension, 0);
 
   return 1;
 }
@@ -56,25 +57,32 @@ static int nn_(Min_updateGradInput)(lua_State *L)
   THTensor_(resizeAs)(gradInput, input);
   THTensor_(zero)(gradInput);
 
-  dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
-  str = THLongStorage_newWithSize(gradOutput->nDimension+1);
-  for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
+  if(input->nDimension > 1)
   {
-    if(j == dimension)
+    dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
+    str = THLongStorage_newWithSize(gradOutput->nDimension+1);
+    for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
     {
-      dim->data[j] = input->size[dimension];
-      str->data[j] = 0;
-      continue;
+      if(j == dimension)
+      {
+        dim->data[j] = input->size[dimension];
+        str->data[j] = 0;
+        continue;
+      }
+
+      dim->data[j] = gradOutput->size[i];
+      str->data[j] = gradOutput->stride[i];
+      i++;
    }
-
-    dim->data[j] = gradOutput->size[i];
-    str->data[j] = gradOutput->stride[i];
-    i++;
+    gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
+    THLongStorage_free(dim);
+    THLongStorage_free(str);
+  }
+  else
+  {
+    THTensor_(retain)(gradOutput);
+    gradOutputPlusOneDim = gradOutput;
   }
-
-  gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
-  THLongStorage_free(dim);
-  THLongStorage_free(str);
 
   TH_TENSOR_DIM_APPLY3(real, gradInput, real, gradOutputPlusOneDim, real, indices, dimension,
                        gradInput_data[ ((long)(*indices_data)-1)*gradInput_stride ] = *gradOutputPlusOneDim_data;)
diff --git a/generic/SpatialAveragePooling.c b/generic/SpatialAveragePooling.c
new file mode 100644
index 0000000..2052d05
--- /dev/null
+++ b/generic/SpatialAveragePooling.c
@@ -0,0 +1,190 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/SpatialAveragePooling.c"
+#else
+
+static int nn_(SpatialAveragePooling_updateOutput)(lua_State *L)
+{
+  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+  int kW = luaT_getfieldcheckint(L, 1, "kW");
+  int kH = luaT_getfieldcheckint(L, 1, "kH");
+  int dW = luaT_getfieldcheckint(L, 1, "dW");
+  int dH = luaT_getfieldcheckint(L, 1, "dH");
+
+  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
+
+  real *output_data;
+  real *input_data;
+
+  int dimw = 2;
+  int dimh = 1;
+  int dimc = 0;
+  long nbatch = 1;
+
+  long inputWidth;
+  long inputHeight;
+  long outputWidth;
+  long outputHeight;
+  long nInputPlane; // number of channels (or colors)
+
+  long k;
+
+  luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D(batch mode) tensor expected");
+
+  if (input->nDimension == 4) {
+    nbatch = input->size[0];
+    dimw++;
+    dimh++;
+    dimc++;
+  }
+
+  inputWidth = input->size[dimw];
+  inputHeight = input->size[dimh];
+  nInputPlane = input->size[dimc];
+  outputWidth = (inputWidth - kW) / dW + 1;
+  outputHeight = (inputHeight - kH) / dH + 1;
+
+  luaL_argcheck(L, inputWidth >= kW && inputHeight >= kH, 2, "input image smaller than kernel size");
+
+  if (input->nDimension == 3)
+    THTensor_(resize3d)(output, nInputPlane, outputHeight, outputWidth);
+  else
+    THTensor_(resize4d)(output, input->size[0], nInputPlane, outputHeight, outputWidth);
+
+  input = THTensor_(newContiguous)(input);
+  input_data = THTensor_(data)(input);
+  output_data = THTensor_(data)(output);
+
+#pragma omp parallel for private(k)
+  for(k = 0; k < nInputPlane; k++)
+  {
+    long p;
+    for(p = 0; p < nbatch; p++)
+    {
+      long xx, yy;
+      /* For all output pixels... */
+      real *ptr_output = output_data + p*nInputPlane*outputWidth*outputHeight + k*outputWidth*outputHeight;
+      long i;
+      for(i = 0; i < outputWidth*outputHeight; i++)
+        ptr_output[i] = 0;
+
+      for(yy = 0; yy < outputHeight; yy++)
+      {
+        for(xx = 0; xx < outputWidth; xx++)
+        {
+          /* Compute the mean of the input image... */
+          real *ptr_input = input_data + p*nInputPlane*inputWidth*inputHeight + k*inputWidth*inputHeight + yy*dH*inputWidth+xx*dW;
+          real sum = 0;
+          long kx, ky;
+
+          for(ky = 0; ky < kH; ky++)
+          {
+            for(kx = 0; kx < kW; kx++)
+              sum += ptr_input[kx];
+            ptr_input += inputWidth; /* next input line */
+          }
+          /* Update output */
+          *ptr_output++ += sum;
+        }
+      }
+    }
+  }
+  THTensor_(free)(input);
+
+  return 1;
+}
+
+static int nn_(SpatialAveragePooling_updateGradInput)(lua_State *L)
+{
+  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+  int kW = luaT_getfieldcheckint(L, 1, "kW");
+  int kH = luaT_getfieldcheckint(L, 1, "kH");
+  int dW = luaT_getfieldcheckint(L, 1, "dW");
+  int dH = luaT_getfieldcheckint(L, 1, "dH");
+  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+
+  int dimw = 2;
+  int dimh = 1;
+  int dimc = 0;
+  long nbatch = 1;
+
+  long inputWidth;
+  long inputHeight;
+  long outputWidth;
+  long outputHeight;
+  long nInputPlane; // number of channels (or colors)
+
+  real *gradOutput_data;
+  real *input_data, *gradInput_data;
+
+  long k;
+
+  if (input->nDimension == 4) {
+    nbatch = input->size[0];
+    dimw++;
+    dimh++;
+    dimc++;
+  }
+
+  inputWidth = input->size[dimw];
+  inputHeight = input->size[dimh];
+  nInputPlane = input->size[dimc];
+  outputWidth = (inputWidth - kW) / dW + 1;
+  outputHeight = (inputHeight - kH) / dH + 1;
+
+  input_data = THTensor_(data)(input);
+
+  THTensor_(resizeAs)(gradInput, input);
+  gradInput_data = THTensor_(data)(gradInput);
+  gradOutput_data = THTensor_(data)(gradOutput);
+
+#pragma omp parallel for private(k)
+  for(k = 0; k < nInputPlane; k++)
+  {
+    long p;
+    for(p = 0; p < nbatch; p++)
+    {
+      real *ptr_gradOutput = gradOutput_data + p*nInputPlane*outputHeight*outputWidth + k*outputWidth*outputHeight;
+      long xx, yy;
+
+      real* ptr_gi = gradInput_data + p*nInputPlane*inputWidth*inputHeight + k*inputWidth*inputHeight;
+      long i;
+      for(i=0; i<inputWidth*inputHeight; i++)
+        ptr_gi[i] = 0.0;
+
+      for(yy = 0; yy < outputHeight; yy++)
+      {
+        for(xx = 0; xx < outputWidth; xx++)
+        {
+          real *ptr_gradInput = gradInput_data + p*nInputPlane*inputWidth*inputHeight + k*inputWidth*inputHeight + yy*dH*inputWidth+xx*dW;
+          real z = *ptr_gradOutput++;
+          long kx, ky;
+
+          for(ky = 0; ky < kH; ky++)
+          {
+            for(kx = 0; kx < kW; kx++)
+              ptr_gradInput[kx] += z;
+            ptr_gradInput += inputWidth;
+          }
+        }
+      }
+    }
+  }
+
+  return 1;
+}
+
+static const struct luaL_Reg nn_(SpatialAveragePooling__) [] = {
+  {"SpatialAveragePooling_updateOutput", nn_(SpatialAveragePooling_updateOutput)},
+  {"SpatialAveragePooling_updateGradInput", nn_(SpatialAveragePooling_updateGradInput)},
+  {NULL, NULL}
+};
+
+static void nn_(SpatialAveragePooling_init)(lua_State *L)
+{
+  luaT_pushmetatable(L, torch_Tensor);
+  luaT_registeratname(L, nn_(SpatialAveragePooling__), "nn");
+  lua_pop(L,1);
+}
+
+#endif
diff --git a/hessian.lua b/hessian.lua
index 3d336fe..21302cb 100644
--- a/hessian.lua
+++ b/hessian.lua
@@ -330,7 +330,7 @@ function nn.hessian.enable()
          if storageAndOffset == nil then
             return nil
          end
-         local storage, offset = unpack(storageAndOffset)
+         local _, offset = unpack(storageAndOffset)
          return offset
       end
diff --git a/init.c b/init.c
@@ -98,6 +98,9 @@
 #include "generic/SpatialMaxPooling.c"
 #include "THGenerateFloatTypes.h"
 
+#include "generic/SpatialAveragePooling.c"
+#include "THGenerateFloatTypes.h"
+
 #include "generic/VolumetricConvolution.c"
 #include "THGenerateFloatTypes.h"
 
@@ -155,6 +158,7 @@ int luaopen_libnn(lua_State *L)
   nn_FloatSpatialConvolutionMap_init(L);
   nn_FloatSpatialSubSampling_init(L);
   nn_FloatSpatialMaxPooling_init(L);
+  nn_FloatSpatialAveragePooling_init(L);
   nn_FloatVolumetricConvolution_init(L);
   nn_FloatVolumetricMaxPooling_init(L);
   nn_FloatMultiMarginCriterion_init(L);
@@ -193,6 +197,7 @@ int luaopen_libnn(lua_State *L)
   nn_DoubleSpatialConvolutionMap_init(L);
   nn_DoubleSpatialSubSampling_init(L);
   nn_DoubleSpatialMaxPooling_init(L);
+  nn_DoubleSpatialAveragePooling_init(L);
   nn_DoubleVolumetricConvolution_init(L);
   nn_DoubleVolumetricMaxPooling_init(L);
   nn_DoubleMultiMarginCriterion_init(L);
diff --git a/init.lua b/init.lua
@@ -74,6 +74,7 @@ include('SpatialSubSampling.lua')
 include('SpatialMaxPooling.lua')
 include('SpatialMaxPoolingCUDA.lua')
 include('SpatialLPPooling.lua')
+include('SpatialAveragePooling.lua')
 include('TemporalConvolution.lua')
 include('TemporalSubSampling.lua')
 include('TemporalMaxPooling.lua')
diff --git a/test/test.lua b/test.lua
rename from test/test.lua
rename to test.lua
@@ -1,4 +1,4 @@
--- you can easily test specific units like this:
+-- you can easily test specific units like this:
 -- th -lnn -e "nn.test{'LookupTable'}"
 -- th -lnn -e "nn.test{'LookupTable', 'Add'}"
 
@@ -66,7 +66,7 @@ function nntest.Add()
       local ferr,berr = jac.testIO(module,input)
       mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
       mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
-   end
+   end
 end
 
 function nntest.CMul()
@@ -160,12 +160,12 @@ function nntest.HardTanh()
    local inj = math.random(3,5)
    local ink = math.random(3,5)
    local input = torch.Tensor(ink, inj, ini):zero()
-
+
   local module = nn.HardTanh()
-
+
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision ,  'error on state ')
-
+
   local ferr, berr = jac.testIO(module, input)
   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -176,12 +176,12 @@ function nntest.Abs()
   local inj = math.random(3,5)
   local ink = math.random(3,5)
   local input = torch.Tensor(ink, inj, ini):zero()
-
+
   local module = nn.Abs()
-
+
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision , 'error on state ')
-
+
   local ferr, berr = jac.testIO(module, input)
   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -376,25 +376,25 @@ function nntest.SparseLinear()
   local ini = math.random(50,100)
   local inj = math.random(5,10)
   local numNonzero = math.random(3,5)
-
+
   local module = nn.SparseLinear(ini,inj)
 
   -- Create a random sparse vector
-  N = {}
+  local N = {}
   for i = 1, ini do N[i] = i end
-  for i = 1, numNonzero do
+  for i = 1, numNonzero do
      local j = math.random(i,ini)
      N[i], N[j] = N[j], N[i]
-  end
+  end
   local input = torch.Tensor(numNonzero, 2):zero()
   for i = 1, numNonzero do input[{i,1}] = N[i] end
   local values = input:select(2,2)
   values:copy(torch.rand(values:nElement())):mul(2):add(-1)
-
+
   -- Check output
   local actual = module:forward(input)
   local expected = torch.Tensor(inj)
-  for j = 1, inj do
+  for j = 1, inj do
      expected[j] = 0
      for i = 1,numNonzero do
        expected[j] = expected[j] + values[i] * module.weight[{j, N[i]}]
@@ -412,13 +412,13 @@ function nntest.SparseLinear()
   local err = sjac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err,precision, 'error on bias ')
-
+
   local err = sjac.testJacobianUpdateParameters(module, input, module.weight)
   mytester:assertlt(err,precision, 'error on weight [direct update] ')
 
   local err = sjac.testJacobianUpdateParameters(module, input, module.bias)
   mytester:assertlt(err,precision, 'error on bias [direct update] ')
-
+
   for t,err in pairs(sjac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
      mytester:assertlt(err, precision, string.format(
                         'error on weight [%s]', t))
@@ -473,7 +473,7 @@ end
 
 local function criterionJacobianTest1D(cri, input, target)
    local eps = 1e-6
-   local fx = cri:forward(input, target)
+   local _ = cri:forward(input, target)
    local dfdx = cri:backward(input, target)
    -- for each input perturbation, do central difference
    local centraldiff_dfdx = torch.Tensor():resizeAs(dfdx)
@@ -483,7 +483,7 @@ local function criterionJacobianTest1D(cri, input, target)
      local fx1 = cri:forward(input, target)
      -- f(xi - h)
      input[i] = input[i] - 2*eps
-     local fx2 = cri:forward(input, target)
+     local fx2 = cri:forward(input, target)
      -- f'(xi) = (f(xi + h) - f(xi - h)) / 2h
      local cdfx = (fx1 - fx2) / (2*eps)
      -- store f' in appropriate place
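The `criterionJacobianTest1D` helper touched above checks `cri:backward` against a central-difference estimate, f'(xi) ~ (f(xi + h) - f(xi - h)) / 2h. A self-contained sketch of the same check; the criterion, epsilon, and tolerance are chosen only for illustration:

```lua
require 'nn'

local eps = 1e-6
local cri = nn.MSECriterion()
local input, target = torch.rand(10), torch.rand(10)
local dfdx = cri:backward(input, target)   -- analytic gradient

for i = 1, input:size(1) do
   input[i] = input[i] + eps
   local fp = cri:forward(input, target)   -- f(xi + h)
   input[i] = input[i] - 2*eps
   local fm = cri:forward(input, target)   -- f(xi - h)
   input[i] = input[i] + eps               -- restore xi
   assert(math.abs((fp - fm)/(2*eps) - dfdx[i]) < 1e-4)
end
```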
@@ -501,14 +501,14 @@ function nntest.MSECriterion()
   local input = torch.rand(10)
   local target = input:clone():add(torch.rand(10))
   local cri = nn.MSECriterion()
-   criterionJacobianTest1D(cri, input, target)
+   criterionJacobianTest1D(cri, input, target)
 end
 
 function nntest.MarginCriterion()
   local input = torch.rand(100)
   local target = input:clone():add(torch.rand(100))
   local cri = nn.MarginCriterion()
-   criterionJacobianTest1D(cri, input, target)
+   criterionJacobianTest1D(cri, input, target)
 end
 
 function nntest.WeightedMSECriterion()
@@ -536,9 +536,9 @@ function nntest.DistKLDivCriterion()
 end
 
 function nntest.ClassNLLCriterion()
-   local numLabels = math.random(5,10)
+   local numLabels = math.random(5,10)
   local input = torch.rand(numLabels)
-   local target = math.random(1,numLabels)
+   local target = math.random(1,numLabels)
 
   -- default ClassNLLCriterion
   local cri = nn.ClassNLLCriterion()
@@ -595,6 +595,15 @@ end
 -- end
 
 function nntest.Max()
+   -- 1D
+   local ini = math.random(3,7)
+   local input = torch.Tensor(ini):zero()
+   local module = nn.Max(1)
+
+   local err = jac.testJacobian(module,input)
+   mytester:assertlt(err,precision, 'error on state ')
+
+   -- 3D
   local ini = math.random(3,5)
   local inj = math.random(3,5)
   local ink = math.random(3,5)
@@ -610,6 +619,15 @@ function nntest.Max()
 end
 
 function nntest.Min()
+   -- 1D
+   local ini = math.random(3,7)
+   local input = torch.Tensor(ini):zero()
+   local module = nn.Min(1)
+
+   local err = jac.testJacobian(module,input)
+   mytester:assertlt(err,precision, 'error on state ')
+
+   -- 3D
   local ini = math.random(3,5)
   local inj = math.random(3,5)
   local ink = math.random(3,5)
@@ -625,6 +643,15 @@ function nntest.Min()
 end
 
 function nntest.Mean()
+   -- 1D
+   local ini = math.random(3,7)
+   local input = torch.Tensor(ini):zero()
+   local module = nn.Mean(1)
+
+   local err = jac.testJacobian(module,input)
+   mytester:assertlt(err,precision, 'error on state ')
+
+   -- 3D
   local ini = math.random(3,5)
   local inj = math.random(3,5)
   local ink = math.random(3,5)
@@ -814,19 +841,19 @@ function nntest.SpatialConvolution()
   local input = torch.Tensor(from, inj, ini):zero()
 
   -- stochastic
-
+
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision, 'error on state ')
-
+
   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
   mytester:assertlt(err , precision, 'error on weight ')
-
+
   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err , precision, 'error on bias ')
 
   local err = jac.testJacobianUpdateParameters(module, input, module.weight)
   mytester:assertlt(err , precision, 'error on weight [direct update] ')
-
+
   local err = jac.testJacobianUpdateParameters(module, input, module.bias)
   mytester:assertlt(err , precision, 'error on bias [direct update] ')
 
@@ -841,7 +868,7 @@ function nntest.SpatialConvolution()
   end
 
   -- batch
-
+
   --verbose = true
   local batch = math.random(2,5)
   outi = math.random(4,8)
@@ -857,16 +884,16 @@ function nntest.SpatialConvolution()
 
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision, 'batch error on state ')
-
+
   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
   mytester:assertlt(err , precision, 'batch error on weight ')
-
+
   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err , precision, 'batch error on bias ')
 
   local err = jac.testJacobianUpdateParameters(module, input, module.weight)
   mytester:assertlt(err , precision, 'batch error on weight [direct update] ')
-
+
   local err = jac.testJacobianUpdateParameters(module, input, module.bias)
   mytester:assertlt(err , precision, 'batch error on bias [direct update] ')
 
@@ -879,7 +906,7 @@ function nntest.SpatialConvolution()
      mytester:assertlt(err, precision, string.format(
                         'batch error on bias [%s]', t))
   end
-
+
   local ferr, berr = jac.testIO(module, input)
   mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
   mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
@@ -898,19 +925,19 @@ function nntest.SpatialConvolutionMM()
   local input = torch.Tensor(from, inj, ini):zero()
 
   -- stochastic
-
+
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision, 'error on state ')
-
+
   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
   mytester:assertlt(err , precision, 'error on weight ')
-
+
   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err , precision, 'error on bias ')
 
   local err = jac.testJacobianUpdateParameters(module, input, module.weight)
   mytester:assertlt(err , precision, 'error on weight [direct update] ')
-
+
   local err = jac.testJacobianUpdateParameters(module, input, module.bias)
   mytester:assertlt(err , precision, 'error on bias [direct update] ')
 
@@ -925,7 +952,7 @@ function nntest.SpatialConvolutionMM()
   end
 
   -- batch
-
+
   --verbose = true
   local batch = math.random(2,5)
   outi = math.random(4,8)
@@ -937,16 +964,16 @@ function nntest.SpatialConvolutionMM()
 
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision, 'batch error on state ')
-
+
   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
   mytester:assertlt(err , precision, 'batch error on weight ')
-
+
   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err , precision, 'batch error on bias ')
 
   local err = jac.testJacobianUpdateParameters(module, input, module.weight)
   mytester:assertlt(err , precision, 'batch error on weight [direct update] ')
-
+
   local err = jac.testJacobianUpdateParameters(module, input, module.bias)
   mytester:assertlt(err , precision, 'batch error on bias [direct update] ')
 
@@ -959,7 +986,7 @@ function nntest.SpatialConvolutionMM()
      mytester:assertlt(err, precision, string.format(
                         'batch error on bias [%s]', t))
   end
-
+
   local ferr, berr = jac.testIO(module, input)
   mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
   mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
@@ -990,13 +1017,13 @@ function nntest.SpatialConvolutionMap()
   local module = nn.SpatialConvolutionMap(nn.tables.random(from, to, fanin), ki, kj, si, sj)
   local input = torch.Tensor(from, inj, ini):zero()
-
+
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision, 'error on state ')
-
+
   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
   mytester:assertlt(err , precision, 'error on weight ')
-
+
   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err , precision, 'error on bias ')
 
@@ -1009,7 +1036,7 @@ function nntest.SpatialConvolutionMap()
      mytester:assertlt(err, precision, string.format(
                         'error on bias [%s]', t))
   end
-
+
   local ferr, berr = jac.testIO(module, input)
   mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
   mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
@@ -1065,20 +1092,20 @@ function nntest.SpatialFullConvolution()
   local inj = math.random(5,8)
   local module = nn.SpatialFullConvolution(from, to, ki, kj, si, sj)
   local input = torch.Tensor(from, inj, ini):zero()
-
+
   -- stochastic
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision, 'error on state ')
-
+
   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
   mytester:assertlt(err , precision, 'error on weight ')
-
+
   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err , precision, 'error on bias ')
 
   local err = jac.testJacobianUpdateParameters(module, input, module.weight)
   mytester:assertlt(err , precision, 'error on weight [direct update] ')
-
+
   local err = jac.testJacobianUpdateParameters(module, input, module.bias)
   mytester:assertlt(err , precision, 'error on bias [direct update] ')
 
@@ -1101,16 +1128,16 @@ function nntest.SpatialFullConvolution()
 
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision, 'batch error on state ')
-
+
   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
   mytester:assertlt(err , precision, 'batch error on weight ')
-
+
   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err , precision, 'batch error on bias ')
 
   local err = jac.testJacobianUpdateParameters(module, input, module.weight)
   mytester:assertlt(err , precision, 'batch error on weight [direct update] ')
-
+
   local err = jac.testJacobianUpdateParameters(module, input, module.bias)
   mytester:assertlt(err , precision, 'batch error on bias [direct update] ')
 
@@ -1123,7 +1150,7 @@ function nntest.SpatialFullConvolution()
      mytester:assertlt(err, precision, string.format(
                         'batch error on bias [%s]', t))
   end
-
+
   local ferr, berr = jac.testIO(module, input)
   mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
   mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
@@ -1142,20 +1169,20 @@ function nntest.SpatialFullConvolutionMap()
   local inj = math.random(5,7)
   local module = nn.SpatialFullConvolutionMap(tt, ki, kj, si, sj)
   local input = torch.Tensor(from, inj, ini):zero()
-
+
   -- stochastic
   local err = jac.testJacobian(module, input)
   mytester:assertlt(err, precision, 'error on state ')
-
+
   local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
   mytester:assertlt(err , precision, 'error on weight ')
-
+
   local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
   mytester:assertlt(err , precision, 'error on bias ')
 
   local err = jac.testJacobianUpdateParameters(module, input, module.weight)
   mytester:assertlt(err , precision, 'error on weight [direct update] ')
-
+
   local err = jac.testJacobianUpdateParameters(module, input, module.bias)
   mytester:assertlt(err , precision, 'error on bias [direct update] ')
 
@@ -1168,7 +1195,7 @@ function nntest.SpatialFullConvolutionMap()
      mytester:assertlt(err, precision, string.format(
                        'error on bias [%s]', t))
  end
-
+
  local ferr, berr = jac.testIO(module, input)
  mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
  mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
@@ -1233,7 +1260,7 @@ local function batchcompare(smod, sin, plist)
  smod:accGradParameters(sin, sgout, 1)
  bmod:accGradParameters(bin, bgout, 1)
-
+
  mytester:assertTensorEq(sout,bout:select(1,1), 1e-8, 'batchcompare error on output')
  mytester:assertTensorEq(sgin,bgin:select(1,1), 1e-8, 'batchcompare error on gradInput')
@@ -1275,7 +1302,7 @@ function nntest.SpatialFullConvolutionBatchCompare()
 
  batchcompare(module,input, {'weight','bias','gradWeight','gradBias'})
 end
-
+
 
 
 function nntest.SpatialSubSamplingBatchCompare()
@@ -1306,19 +1333,19 @@ function nntest.SpatialSubSampling()
  local inj = (outj-1)*sj+kj
  local module = nn.SpatialSubSampling(from, ki, kj, si, sj)
  local input = torch.Tensor(from, inj, ini):zero()
-
+
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision, 'error on state ')
-
+
  local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
  mytester:assertlt(err , precision, 'error on weight ')
-
+
  local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
  mytester:assertlt(err , precision, 'error on bias ')
 
  local err = jac.testJacobianUpdateParameters(module, input, module.weight)
  mytester:assertlt(err , precision, 'error on weight [direct update] ')
-
+
  local err = jac.testJacobianUpdateParameters(module, input, module.bias)
  mytester:assertlt(err , precision, 'error on bias [direct update] ')
 
@@ -1332,7 +1359,6 @@ function nntest.SpatialSubSampling()
                       'error on bias [%s]', t))
  end
 
-  --verbose = true
  local batch = math.random(2,5)
  outi = math.random(4,8)
  outj = math.random(4,8)
@@ -1341,22 +1367,18 @@ function nntest.SpatialSubSampling()
  module = nn.SpatialSubSampling(from, ki, kj, si, sj)
  input = torch.Tensor(batch,from,inj,ini):zero()
 
---   print(from, to, ki, kj, si, sj, batch, ini, inj)
---   print(module.weight:size())
---   print(module.gradWeight:size())
-
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision, 'batch error on state ')
-
+
  local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
  mytester:assertlt(err , precision, 'batch error on weight ')
-
+
  local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
  mytester:assertlt(err , precision, 'batch error on bias ')
 
  local err = jac.testJacobianUpdateParameters(module, input, module.weight)
  mytester:assertlt(err , precision, 'batch error on weight [direct update] ')
-
+
  local err = jac.testJacobianUpdateParameters(module, input, module.bias)
  mytester:assertlt(err , precision, 'batch error on bias [direct update] ')
 
@@ -1369,7 +1391,7 @@ function nntest.SpatialSubSampling()
     mytester:assertlt(err, precision, string.format(
                        'batch error on bias [%s]', t))
  end
-
+
  local ferr, berr = jac.testIO(module, input)
  mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
  mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
@@ -1410,6 +1432,70 @@ function nntest.SpatialMaxPooling()
 end
 
+function nntest.SpatialAveragePooling()
+   local from = math.random(1,6)
+   local ki = math.random(1,5)
+   local kj = math.random(1,5)
+   local si = math.random(1,4)
+   local sj = math.random(1,4)
+   local outi = math.random(6,10)
+   local outj = math.random(6,10)
+   local ini = (outi-1)*si+ki
+   local inj = (outj-1)*sj+kj
+   local module = nn.SpatialAveragePooling(ki, kj, si, sj)
+   local input = torch.Tensor(from, inj, ini):zero()
+
+   local err = jac.testJacobian(module, input)
+   mytester:assertlt(err, precision, 'error on state ')
+
+   local ferr, berr = jac.testIO(module, input)
+   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+   local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
+   sap.weight:fill(1.0)
+   sap.bias:fill(0.0)
+
+   local output = module:forward(input)
+   local gradInput = module:backward(input, output)
+   local output2 = sap:forward(input)
+   local gradInput2 = sap:updateGradInput(input, output)
+
+   mytester:assertTensorEq(output, output2, 0.000001, torch.typename(module) .. ' forward err ')
+   mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err ')
+
+   local batch = math.random(2,5)
+   outi = math.random(4,8)
+   outj = math.random(4,8)
+   ini = (outi-1)*si+ki
+   inj = (outj-1)*sj+kj
+   module = nn.SpatialAveragePooling(ki, kj, si, sj)
+   input = torch.Tensor(batch,from,inj,ini):zero()
+
+   local err = jac.testJacobian(module, input)
+   mytester:assertlt(err, precision, 'batch error on state ')
+
+   local ferr, berr = jac.testIO(module, input)
+   mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
+   mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
+
+   local ferr, berr = jac.testIO(module, input)
+   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
+   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+
+   local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
+   sap.weight:fill(1.0)
+   sap.bias:fill(0.0)
+
+   local output = module:forward(input)
+   local gradInput = module:backward(input, output)
+   local output2 = sap:forward(input)
+   local gradInput2 = sap:updateGradInput(input, output)
+
+   mytester:assertTensorEq(output, output2, 0.000001, torch.typename(module) .. ' forward err (Batch) ')
+   mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err (Batch) ')
+end
+
 function nntest.SpatialLPPooling()
   local fanin = math.random(1,4)
   local osizex = math.random(1,4)
@@ -1433,6 +1519,15 @@ function nntest.SpatialLPPooling()
 end
 
 function nntest.Sum()
+   -- 1D
+   local ini = math.random(3,7)
+   local input = torch.Tensor(ini):zero()
+   local module = nn.Sum(1)
+
+   local err = jac.testJacobian(module,input)
+   mytester:assertlt(err,precision, 'error on state ')
+
+   -- 3D
  local ini = math.random(3,5)
  local inj = math.random(3,5)
  local ink = math.random(3,5)
@@ -1452,12 +1547,12 @@ function nntest.Tanh()
  local inj = math.random(3,5)
  local ink = math.random(3,5)
  local input = torch.Tensor(ink, inj, ini):zero()
-
+
  local module = nn.Tanh()
-
+
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision ,  'error on state ')
-
+
  local ferr, berr = jac.testIO(module, input)
  mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
  mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -1473,13 +1568,13 @@ function nntest.TemporalConvolution()
  local ini = (outi-1)*si+ki
  local module = nn.TemporalConvolution(from, to, ki,si)
  local input = torch.Tensor(ini, from):zero()
-
+
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision, 'error on state ')
-
+
  local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
  mytester:assertlt(err , precision, 'error on weight ')
-
+
  local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
  mytester:assertlt(err , precision, 'error on bias ')
 
@@ -1498,17 +1593,17 @@ function nntest.TemporalConvolution()
     mytester:assertlt(err, precision, string.format(
                        'error on bias [%s]', t))
  end
-
+
  -- 2D
  local nBatchFrame = 4
  local input = torch.Tensor(nBatchFrame, ini, from):zero()
-
+
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision, 'error on state ')
-
+
  local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
  mytester:assertlt(err , precision, 'error on weight ')
-
+
  local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
  mytester:assertlt(err , precision, 'error on bias ')
 
@@ -1527,21 +1622,21 @@ function nntest.TemporalConvolution()
     mytester:assertlt(err, precision, string.format(
                        'error on bias [%s]', t))
  end
-
+
  local ferr, berr = jac.testIO(module, input)
  mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
  mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
-
+
  -- 2D matches 1D
  local output = module:forward(input):clone()
  local outputGrad = torch.randn(output:size())
  local inputGrad = module:backward(input, outputGrad):clone()
-
+
  local input1D = input:select(1, 2)
  local output1D = module:forward(input1D)
  local outputGrad1D = outputGrad:select(1, 2)
  local inputGrad1D = module:backward(input1D, outputGrad1D)
-
+
  mytester:assertTensorEq(output:select(1,2), output1D, 0.000001, 'error on 2D vs 1D forward)')
  mytester:assertTensorEq(inputGrad:select(1,2), inputGrad1D, 0.000001, 'error on 2D vs 1D backward)')
 end
@@ -1554,19 +1649,19 @@ function nntest.TemporalSubSampling()
  local ini = (outi-1)*si+ki
  local module = nn.TemporalSubSampling(from, ki, si)
  local input = torch.Tensor(ini, from):zero()
-
+
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision, 'error on state ')
-
+
  local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
  mytester:assertlt(err , precision, 'error on weight ')
-
+
  local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
  mytester:assertlt(err , precision, 'error on bias ')
-
+
  local err = jac.testJacobianUpdateParameters(module, input, module.weight)
  mytester:assertlt(err , precision, 'error on weight [direct update] ')
-
+
  local err = jac.testJacobianUpdateParameters(module, input, module.bias)
  mytester:assertlt(err , precision, 'error on bias [direct update] ')
 
@@ -1611,17 +1706,17 @@ function nntest.TemporalMaxPooling()
  local ferr, berr = jac.testIO(module, input)
  mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
  mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
-
+
  -- 2D matches 1D
  local output = module:forward(input):clone()
  local outputGrad = torch.randn(output:size())
  local inputGrad = module:backward(input, outputGrad):clone()
-
+
  local input1D = input:select(1, 2)
  local output1D = module:forward(input1D)
  local outputGrad1D = outputGrad:select(1, 2)
  local inputGrad1D = module:backward(input1D, outputGrad1D)
-
+
  mytester:assertTensorEq(output:select(1,2), output1D, 0.000001, 'error on 2D vs 1D forward)')
  mytester:assertTensorEq(inputGrad:select(1,2), inputGrad1D, 0.000001, 'error on 2D vs 1D backward)')
 end
@@ -1643,19 +1738,19 @@ function nntest.VolumetricConvolution()
  local inj = (outj-1)*sj+kj
  local module = nn.VolumetricConvolution(from, to, kt, ki, kj, st, si, sj)
  local input = torch.Tensor(from, int, inj, ini):zero()
-
+
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision, 'error on state ')
-
+
  local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
  mytester:assertlt(err , precision, 'error on weight ')
-
+
  local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
  mytester:assertlt(err , precision, 'error on bias ')
 
  local err = jac.testJacobianUpdateParameters(module, input, module.weight)
  mytester:assertlt(err , precision, 'error on weight [direct update] ')
-
+
  local err = jac.testJacobianUpdateParameters(module, input, module.bias)
  mytester:assertlt(err , precision, 'error on bias [direct update] ')
 
@@ -1668,7 +1763,7 @@ function nntest.VolumetricConvolution()
     mytester:assertlt(err, precision, string.format(
                        'error on bias [%s]', t))
  end
-
+
  local ferr, berr = jac.testIO(module, input)
  mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
  mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
@@ -1676,7 +1771,6 @@ end
 
 function nntest.VolumetricMaxPooling()
  local from = math.random(2,3)
-  local to = from
  local kt = math.random(3,4)
  local ki = math.random(3,4)
  local kj = math.random(3,4)
@@ -1691,10 +1785,10 @@ function nntest.VolumetricMaxPooling()
  local inj = (outj-1)*sj+kj
  local module = nn.VolumetricMaxPooling(kt, ki, kj, st, si, sj)
  local input = torch.Tensor(from, int, inj, ini):zero()
-
+
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision, 'error on state ')
-
+
  local ferr, berr = jac.testIO(module, input)
  mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
  mytester:asserteq(0, berr, torch.typename(module) ..
' - i/o backward err ') @@ -1712,10 +1806,10 @@ end function nntest.Module_getParameters_2() local n = nn.Sequential() n:add( nn.Linear(10,10) ) - local p = n:getParameters() + local _ = n:getParameters() n:add( nn.Linear(10,10) ) - p = n:getParameters() + local p = n:getParameters() mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when appending new module') mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when appending new module') @@ -1746,10 +1840,10 @@ function nntest.Module_getParameters_4() local n = nn.Sequential() n:add( nn.Linear(10,10) ) n:add( n.modules[1]:clone() ) - local p = n:getParameters() + local _ = n:getParameters() n:add(nn.Linear(10,10)) - p = n:getParameters() + local p = n:getParameters() mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning') mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning') @@ -1787,10 +1881,10 @@ function nntest.Module_getParameters_6() local n = nn.Sequential() n:add( nn.Linear(10,10) ) n:add( n.modules[1]:clone('weight','bias') ) - local p = n:getParameters() + local _ = n:getParameters() n:add(nn.Linear(10,10)) - p = n:getParameters() + local p = n:getParameters() mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing') mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing') @@ -1808,10 +1902,10 @@ function nntest.Module_getParameters_7() local n = nn.Sequential() n:add( nn.Linear(10,10) ) n:add( n.modules[1]:clone('weight','bias') ) - local p = n:getParameters() + local _ = n:getParameters() n:add(nn.Linear(10,10)) - p = n:getParameters() + local _ = n:getParameters() local n1 = nn.Sequential() n1:add( nn.Linear(10,10) ) @@ -1823,7 +1917,7 @@ function nntest.Module_getParameters_7() n:add( n1 ) n:add( n2 ) - local p = n:getParameters() + local _ = n:getParameters() local nf = nn.Sequential() nf:add( n1 ) @@ -1843,12 +1937,12 @@ end function nntest.Module_getParameters_8() local function makeMLP(nin, ns) local net = nn.Sequential() - - for k,v in ipairs(ns) do + + for k,v in ipairs(ns) do net:add(nn.Linear(nin, v)) nin = v end - _,_ = net:getParameters() + local _,_ = net:getParameters() return net end @@ -1857,18 +1951,18 @@ function nntest.Module_getParameters_8() local net = nn.Sequential():add(mlp1:get(1)) :add(mlp2:get(1)) - + -- clone the second MLP to ensure that the weights before calling getParameters are preserved - mlp2 = mlp2:clone() + mlp2 = mlp2:clone() - local p, gp = net:getParameters() + local p, _ = net:getParameters() mytester:asserteq((p[{ {1,100} }] - net.modules[1].weight):norm(), 0, 'error when using partial realloc') mytester:asserteq((p[{ {111,210} }] - net.modules[2].weight):norm(), 0, 'error when using partial realloc') -- check that the weights have the same values as before get Parameters was called mytester:asserteq((net.modules[1].weight - mlp1.modules[1].weight):norm(), 0, ' error when using partial realloc') mytester:asserteq((net.modules[2].weight - mlp2.modules[1].weight):norm(), 0, ' error when using partial realloc') - + end function nntest.PairwiseDistance() @@ -1886,17 +1980,17 @@ function nntest.PairwiseDistance() local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, ' error on state ') - + local ferr,berr = jac.testIO(module,input) mytester:asserteq(ferr, 0, torch.typename(module)..' 
- i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module)..' - i/o backward err ') -- Also check that the forward prop result is correct. input = torch.rand(2, ini) - err = torch.dist(input:select(1,1), input:select(1,2), p) - + err = torch.dist(input:select(1,1), input:select(1,2), p) - module:forward(input)[1] - mytester:assertlt(err,precision, ' error on non-batch fprop ') - + mytester:assertlt(err,precision, ' error on non-batch fprop ') + -- TEST CASE 2: batch input local inj = math.random(3,5) input = torch.Tensor(2, inj, ini):zero() @@ -1915,12 +2009,12 @@ function nntest.PairwiseDistance() local inputb = torch.rand(inj,ini) local dist_manual = torch.Tensor(inj) for i=1, inputa:size(1) do - dist_manual[i] = torch.dist(inputa:select(1,i), inputb:select(1,i),p) + dist_manual[i] = torch.dist(inputa:select(1,i), inputb:select(1,i),p) end -- compare the distances to the module's fprop local dist = module:forward(torch.cat(inputa,inputb,1):resize(2,inj,ini)) - err = dist - dist_manual - mytester:assertlt(err:norm(), precision, torch.typename(module) .. + err = dist - dist_manual + mytester:assertlt(err:norm(), precision, torch.typename(module) .. ' error on batch fprop ') end end @@ -1933,7 +2027,7 @@ function nntest.LookupTable() local module = nn.LookupTable(totalIndex, entry_size) local minval = 1 local maxval = totalIndex - + local output = module:forward(input) module:backwardUpdate(input, output, 0.1) input:zero() @@ -1944,7 +2038,7 @@ function nntest.LookupTable() local err = jac.testJacobianUpdateParameters(module, input, module.weight, minval, maxval) mytester:assertlt(err,precision, '1D error on weight [direct update] ') - + module.gradWeight:zero() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( @@ -1957,7 +2051,7 @@ function nntest.LookupTable() local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight, minval, maxval) mytester:assertlt(err,precision, '2D error on weight ') - + local err = jac.testJacobianUpdateParameters(module, input, module.weight, minval, maxval) mytester:assertlt(err,precision, '2D error on weight [direct update] ') @@ -1972,7 +2066,7 @@ function nntest.LookupTable() local ferr,berr = jac.testIO(module,input,minval,maxval) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err ') - + -- accUpdate module:accUpdateOnly() mytester:assert(not module.gradWeight, 'gradWeight is nil') @@ -1980,7 +2074,7 @@ function nntest.LookupTable() local output = module:forward(input) module:backwardUpdate(input, output, 0.1) end - + function nntest.AddConstant() local nbatch = torch.random(3, 5) local f = torch.random(3, 5) @@ -2059,18 +2153,18 @@ end function nntest.SelectTable() local input = { - torch.rand(3,4,5), torch.rand(3,4,5), - {torch.rand(3,4,5)}, + torch.rand(3,4,5), torch.rand(3,4,5), + {torch.rand(3,4,5)}, {torch.rand(3,4,5), {torch.rand(3,4,5)}} } local gradOutputs = { - torch.rand(3,4,5), torch.rand(3,4,5), - {torch.rand(3,4,5)}, + torch.rand(3,4,5), torch.rand(3,4,5), + {torch.rand(3,4,5)}, {torch.rand(3,4,5), {torch.rand(3,4,5)}} } local zeros = { - torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(), - {torch.Tensor(3,4,5):zero()}, + torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(), + {torch.Tensor(3,4,5):zero()}, {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}} } local nonIdx = {2,3,4,1} @@ -2098,7 +2192,7 @@ function nntest.MixtureTable() local expertInput = torch.randn(5,3,6) local gradOutput = torch.randn(5,6) local input = { - torch.rand(5,3), + torch.rand(5,3), {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)} } local module = nn.MixtureTable() @@ -2121,13 +2215,13 @@ function nntest.MixtureTable() local gradInput = module:backward(input, gradOutput) mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture2 gater gradInput") mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture2 expert gradInput") - + --[[ 3D ]]-- local expertInput = torch.randn(5,6,3,2) local gradOutput = torch.randn(5,6,2) -- expertInput is a Table: local input = { - torch.rand(5,3), + torch.rand(5,3), {expertInput:select(3,1), expertInput:select(3,2), expertInput:select(3,3)} } local module = nn.MixtureTable() @@ -2150,13 +2244,13 @@ function nntest.MixtureTable() local gradInput = module:backward(input, gradOutput) mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture4 gater gradInput") mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture4 expert gradInput") - + --[[ 1D ]]-- -- expertInput is a Table: local expertInput = torch.randn(3,6) local gradOutput = torch.randn(6) local input = { - torch.rand(3), + torch.rand(3), {expertInput:select(1,1), expertInput:select(1,2), expertInput:select(1,3)} } local module = nn.MixtureTable() @@ -2174,7 +2268,7 @@ function nntest.MixtureTable() -- test type-cast module:float() local input2 = { - input[1]:float(), + input[1]:float(), {input[2][1]:float(), input[2][2]:float(), input[2][3]:float()} } local output = module:forward(input2) @@ -2200,13 +2294,13 @@ function nntest.MixtureTable() local gradInput = module:backward(input2, gradOutput:float()) mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture6B gater gradInput") mytester:assertTensorEq(gradInput[2], expertGradInput2:float(), 0.000001, "mixture6B expert gradInput") - + --[[ 2D gater, 1D expert]]-- -- expertInput is a Table: local expertInput = torch.randn(5,3) local gradOutput = torch.randn(5) local input = { - torch.rand(5,3), + torch.rand(5,3), {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)} } local module = nn.MixtureTable() @@ -2280,7 +2374,7 @@ function nntest.SpatialUpSamplingNearest() table.insert(shape, torch.random(2, 2+dim-1)) end - -- Check that the gradient is correct by using 
finite elements + -- Check that the gradient is correct by using finite elements local input = torch.Tensor(unpack(shape)):zero() local err = jac.testJacobian(m, input) @@ -2296,10 +2390,10 @@ function nntest.ConcatTable() -- Test tensor input local input = torch.rand(5, 5, 5) local m = nn.Sequential() - + local concat = nn.ConcatTable() concat:add(nn.Identity()) - + m:add(concat) -- Output of concat is a table of length 1 m:add(nn.JoinTable(1)) -- jac needs a tensor tensor output @@ -2318,7 +2412,7 @@ function nntest.ConcatTable() torch.randn(3,3,4):float(), torch.randn(3,3,4):float(), torch.randn(3,3,4):float() } local gradOutput = { - {_gradOutput[1][1], _gradOutput[2][1], {_gradOutput[3][1]}}, + {_gradOutput[1][1], _gradOutput[2][1], {_gradOutput[3][1]}}, {_gradOutput[1][2], _gradOutput[2][2], {_gradOutput[3][2]}}, {_gradOutput[1][3], _gradOutput[2][3], {_gradOutput[3][3]}} } @@ -2327,7 +2421,7 @@ function nntest.ConcatTable() module:add(nn.Identity()) module:add(nn.Identity()) module:float() - + local output = module:forward(input) local output2 = {input, input, input} equal(output2, output, "ConcatTable table output") @@ -2338,7 +2432,7 @@ end function nntest.FlattenTable() -- Create a nested table. Obviously we can't even stochastically test - -- the space of all possible nested tables (it's infinite), but here is a + -- the space of all possible nested tables (it's infinite), but here is a -- hand-coded one that covers all the cases we need: local input = { torch.rand(1), @@ -2381,7 +2475,7 @@ function nntest.FlattenTable() -- CASE 1: Nothing changes so the output table shouldn't be redefined local old_input_map = m.input_map local old_output = m.output - output = m:forward(input) + local _ = m:forward(input) mytester:assert(old_input_map == m.input_map and old_output == m.output) -- CASE 2: An element is added to the input table @@ -2390,7 +2484,7 @@ function nntest.FlattenTable() input[2][#(input[2])+1] = torch.rand(5) m:forward(input) mytester:assert(old_input_map ~= m.input_map and old_output ~= m.output) - + -- CASE 3: An element is removed from the input table old_input_map = m.input_map old_output = m.output @@ -2423,7 +2517,7 @@ function nntest.L1Penalty() local input = torch.rand(2,10):add(-0.5) input[1][1] = 0 - local out = m:forward(input) + local _ = m:forward(input) local grad = m:backward(input, torch.ones(input:size())) local err = input:clone():abs():sum()*weight - m.loss @@ -2456,7 +2550,6 @@ function nntest.DepthConcat() local output = torch.Tensor(2, outputSize:sum(), 12, 12):zero() -- zero for padding local narrows = { {{},{1,5},{},{}}, {{},{6,11},{2,11},{2,11}}, {{},{12,18},{2,10},{2,10}}, {{},{19,26},{3,10},{3,10}} } local gradInput = input:clone():zero() - local gradWeights = {} for i=1,4 do local conv = concat:get(i) local gradWeight = conv.gradWeight:clone() |
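
Nearly every assertion in these tests funnels through jac.testJacobian, which assembles the module's Jacobian from repeated backward calls and compares it against a finite-difference estimate obtained by perturbing the input one element at a time. The sketch below shows the idea in miniature; checkGrad, eps, and the random projection g are illustrative names of our own, not part of the nn.Jacobian API, and the input is assumed to be a contiguous DoubleTensor.

    require 'nn'

    -- Perturb each input element and compare the central-difference
    -- derivative of the scalar <output, g> against the gradient that
    -- updateGradInput reports for the same projection g.
    local function checkGrad(module, input, eps)
       eps = eps or 1e-6
       local g = torch.Tensor(module:forward(input):size()):normal()
       local analytic = module:updateGradInput(input, g):clone()

       local numeric = torch.Tensor():resizeAs(input)
       local x, dx = input:storage(), numeric:storage()
       for i = 1, input:nElement() do   -- assumes a contiguous input tensor
          local orig = x[i]
          x[i] = orig + eps
          local fplus = module:forward(input):dot(g)
          x[i] = orig - eps
          local fminus = module:forward(input):dot(g)
          x[i] = orig
          dx[i] = (fplus - fminus) / (2 * eps)
       end
       return (analytic - numeric):abs():max()
    end

    -- e.g. for the module exercised by the new SpatialAveragePooling test:
    local m = nn.SpatialAveragePooling(2, 2, 2, 2)
    print(checkGrad(m, torch.randn(3, 8, 8)))  -- expect round-off-level error

Note also what the new SpatialAveragePooling test pins down: with its weight filled to 1 and its bias to 0, nn.SpatialSubSampling reduces to a plain windowed sum, so asserting tensor equality against it fixes the new module's output as exactly that sum, in both the single-sample and the batch path.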