From 7d64ef1119517c545b2c7325ba0155ceb11c3ad2 Mon Sep 17 00:00:00 2001
From: Edward Grefenstette
Date: Thu, 2 Apr 2015 17:23:55 +0100
Subject: Replicate now works across arbitrary dimensions.

---
 Replicate.lua |  37 +++++++++++++++++----
 doc/simple.md |  62 +++++++++++++++++-----------------
 test.lua      | 104 ++++++++++++++++++++++++++++++++++------------------------
 3 files changed, 123 insertions(+), 80 deletions(-)

diff --git a/Replicate.lua b/Replicate.lua
index d049044..5dd849a 100644
--- a/Replicate.lua
+++ b/Replicate.lua
@@ -1,20 +1,34 @@
 local Replicate, parent = torch.class('nn.Replicate','nn.Module')
 
-function Replicate:__init(nf)
+function Replicate:__init(nf, dim)
    parent.__init(self)
    self.nfeatures = nf
+   self.dim = dim or 1
+   assert(self.dim > 0, "Can only replicate across positive integer dimensions.")
 end
 
 function Replicate:updateOutput(input)
+   assert(
+      self.dim <= input:dim()+1,
+      "Not enough input dimensions to replicate along dimension " ..
+      tostring(self.dim) .. ".")
    local sz = torch.LongStorage(input:dim()+1)
-   sz[1] = self.nfeatures
+   sz[self.dim] = self.nfeatures
    for i = 1,input:dim() do
-      sz[i+1] = input:size(i)
+      local offset = 0
+      if i >= self.dim then
+         offset = 1
+      end
+      sz[i+offset] = input:size(i)
    end
    local st = torch.LongStorage(input:dim()+1)
-   st[1] = 0
+   st[self.dim] = 0
    for i = 1,input:dim() do
-      st[i+1] = input:stride(i)
+      local offset = 0
+      if i >= self.dim then
+         offset = 1
+      end
+      st[i+offset] = input:stride(i)
    end
    self.output = input.new(input:storage(),input:storageOffset(),sz,st)
    return self.output
@@ -22,7 +36,16 @@ end
 
 function Replicate:updateGradInput(input, gradOutput)
    self.gradInput:resizeAs(input):zero()
-   local gradInput = self.gradInput:view(1, unpack(input:size():totable()))
-   gradInput:sum(gradOutput, 1)
+   local sz = torch.LongStorage(input:dim()+1)
+   sz[self.dim] = 1
+   for i = 1,input:dim() do
+      local offset = 0
+      if i >= self.dim then
+         offset = 1
+      end
+      sz[i+offset] = input:size(i)
+   end
+   local gradInput = self.gradInput:view(sz)
+   gradInput:sum(gradOutput, self.dim)
    return self.gradInput
 end
diff --git a/doc/simple.md b/doc/simple.md
index 01a6140..7c97a6b 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -1,6 +1,5 @@
 # Simple layers #
-
 Simple Modules are used for various tasks like adapting Tensor methods and providing affine transformations :
   * Parameterized Modules :
     * [Linear](#nn.Linear) : a linear transformation ;
@@ -36,7 +35,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
     * [SpatialDropout](#nn.SpatialDropout) : Same as Dropout but for spatial inputs where adjacent pixels are strongly correlated ;
     * [Padding](#nn.Padding) : adds padding to a dimension ;
     * [L1Penalty](#nn.L1Penalty) : adds an L1 penalty to an input (for sparsity);
-    
+
 ## Linear ##
@@ -114,7 +113,6 @@ x = torch.Tensor({ {1, 0.1}, {2, 0.3}, {10, 0.3}, {31, 0.2} })
 The first column contains indices, the second column contains values in a a vector where all other elements are zeros. The indices should not exceed the stated dimensions of the input to the layer (10000 in the example).
-
 ## Dropout ##
@@ -123,7 +121,7 @@ module = nn.Dropout(p)
 ```
 During training, `Dropout` masks parts of the `input` using binary samples from a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution.
-Each `input` element has a probability of `p` of being dropped, i.e having its commensurate output element be zero.
This has proven an effective technique for regularization and preventing the co-adaptation of neurons (see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)). +Each `input` element has a probability of `p` of being dropped, i.e having its commensurate output element be zero. This has proven an effective technique for regularization and preventing the co-adaptation of neurons (see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)). Furthermore, the ouputs are scaled by a factor of `1/(1-p)` during training. This allows the `input` to be simply forwarded as-is during evaluation. @@ -185,7 +183,6 @@ We can return to training our model by first calling [Module:training()](module. When used, `Dropout` should normally be applied to the input of parameterized [Modules](module.md#nn.Module) like [Linear](#nn.Linear) or [SpatialConvolution](convolution.md#nn.SpatialConvolution). A `p` of `0.5` (the default) is usually okay for hidden layers. `Dropout` can sometimes be used successfully on the dataset inputs with a `p` around `0.2`. It sometimes works best following [Transfer](transfer.md) Modules like [ReLU](transfer.md#nn.ReLU). All this depends a great deal on the dataset so its up to the user to try different combinations. - ## SpatialDropout ## @@ -224,16 +221,16 @@ gnuplot.grid(true) module = nn.Add(inputDimension, scalar) ``` -Applies a bias term to the incoming data, i.e. `yi = x_i + b_i`, or if `scalar = true` then uses a single bias term, `yi = x_i + b`. +Applies a bias term to the incoming data, i.e. `yi = x_i + b_i`, or if `scalar = true` then uses a single bias term, `yi = x_i + b`. Example: ```lua y = torch.Tensor(5) -mlp = nn.Sequential() +mlp = nn.Sequential() mlp:add(nn.Add(5)) -function gradUpdate(mlp, x, y, criterion, learningRate) +function gradUpdate(mlp, x, y, criterion, learningRate) local pred = mlp:forward(x) local err = criterion:forward(pred, y) local gradCriterion = criterion:backward(pred, y) @@ -245,7 +242,7 @@ end for i = 1, 10000 do x = torch.rand(5) - y:copy(x); + y:copy(x); for i = 1, 5 do y[i] = y[i] + i; end err = gradUpdate(mlp, x, y, nn.MSECriterion(), 0.01) end @@ -274,7 +271,7 @@ i.e. the network successfully learns the input `x` has been shifted to produce t module = nn.Mul() ``` -Applies a _single_ scaling factor to the incoming data, i.e. `y = w x`, where `w` is a scalar. +Applies a _single_ scaling factor to the incoming data, i.e. `y = w x`, where `w` is a scalar. Example: @@ -283,7 +280,7 @@ y = torch.Tensor(5) mlp = nn.Sequential() mlp:add(nn.Mul()) -function gradUpdate(mlp, x, y, criterion, learningRate) +function gradUpdate(mlp, x, y, criterion, learningRate) local pred = mlp:forward(x) local err = criterion:forward(pred, y) local gradCriterion = criterion:backward(pred, y) @@ -331,7 +328,7 @@ y = torch.Tensor(5) sc = torch.Tensor(5) for i = 1, 5 do sc[i] = i; end -- scale input with this -function gradUpdate(mlp, x, y, criterion, learningRate) +function gradUpdate(mlp, x, y, criterion, learningRate) local pred = mlp:forward(x) local err = criterion:forward(pred, y) local gradCriterion = criterion:backward(pred, y) @@ -415,7 +412,7 @@ Hence, if an `nxpxq` Tensor was given as input, and `dimension` = `2` then an `n module = nn.Euclidean(inputSize,outputSize) ``` -Outputs the Euclidean distance of the input to `outputSize` centers, i.e. this layer has the weights `w_j`, for `j` = `1`,..,`outputSize`, where `w_j` are vectors of dimension `inputSize`. +Outputs the Euclidean distance of the input to `outputSize` centers, i.e. 
this layer has the weights `w_j`, for `j` = `1`,..,`outputSize`, where `w_j` are vectors of dimension `inputSize`. The distance `y_j` between center `j` and input `x` is formulated as `y_j = || w_j - x ||`. @@ -426,7 +423,7 @@ The distance `y_j` between center `j` and input `x` is formulated as `y_j = || w module = nn.WeightedEuclidean(inputSize,outputSize) ``` -This module is similar to [Euclidean](#nn.Euclidean), but additionally learns a separate diagonal covariance matrix across the features of the input space _for each center_. +This module is similar to [Euclidean](#nn.Euclidean), but additionally learns a separate diagonal covariance matrix across the features of the input space _for each center_. In other words, for each of the `outputSize` centers `w_j`, there is a diagonal covariance matrices `c_j`, for `j` = `1`,..,`outputSize`, where `c_j` are stored as vectors of size `inputSize`. @@ -439,8 +436,10 @@ The distance `y_j` between center `j` and input `x` is formulated as `y_j = || c module = nn.Identity() ``` -Creates a module that returns whatever is input to it as output. + +Creates a module that returns whatever is input to it as output. This is useful when combined with the module [ParallelTable](table.md#nn.ParallelTable) in case you do not wish to do anything to one of the input Tensors. + Example: ```lua @@ -448,7 +447,7 @@ mlp = nn.Identity() print(mlp:forward(torch.ones(5, 2))) ``` -gives the output: +gives the output: ```lua 1 1 @@ -461,10 +460,10 @@ gives the output: Here is a more useful example, where one can implement a network which also computes a Criterion using this module: -```lua +```lua pred_mlp = nn.Sequential() -- A network that makes predictions given x. -pred_mlp:add(nn.Linear(5, 4)) -pred_mlp:add(nn.Linear(4, 3)) +pred_mlp:add(nn.Linear(5, 4)) +pred_mlp:add(nn.Linear(4, 3)) xy_mlp = nn.ParallelTable() -- A network for predictions and for keeping the xy_mlp:add(pred_mlp) -- true label for comparison with a criterion @@ -478,14 +477,14 @@ mlp:add(cr_wrap) -- and then applies the criterion. for i = 1, 100 do -- Do a few training iterations x = torch.ones(5) -- Make input features. - y = torch.Tensor(3) + y = torch.Tensor(3) y:copy(x:narrow(1,1,3)) -- Make output label. err = mlp:forward{x,y} -- Forward both input and output. print(err) -- Print error from criterion. - mlp:zeroGradParameters() -- Do backprop... - mlp:backward({x, y}) - mlp:updateParameters(0.05) + mlp:zeroGradParameters() -- Do backprop... + mlp:backward({x, y}) + mlp:updateParameters(0.05) end ``` @@ -496,7 +495,7 @@ end module = nn.Copy(inputType, outputType, [forceCopy, dontCast]) ``` -This layer copies the input to output with type casting from input type from `inputType` to `outputType`. Unless `forceCopy` is true, when the first two arguments are the same, the input isn't copied, only transfered as the output. The default `forceCopy` is false. +This layer copies the input to output with type casting from input type from `inputType` to `outputType`. Unless `forceCopy` is true, when the first two arguments are the same, the input isn't copied, only transfered as the output. The default `forceCopy` is false. When `dontCast` is true, a call to `nn.Copy:type(type)` will not cast the module's `output` and `gradInput` Tensors to the new type. The default is false. 
@@ -512,10 +511,10 @@ Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/do ## Replicate ## ```lua -module = nn.Replicate(nFeature) +module = nn.Replicate(nFeature, dim) ``` -This class creates an output where the input is replicated `nFeature` times along its first dimension. There is no memory allocation or memory copy in this module. It sets the [stride](https://github.com/torch/torch7/blob/master/doc/tensor.md#torch.Tensor.stride) along the first dimension to zero. +This class creates an output where the input is replicated `nFeature` times along dimension `dim` (default 1). There is no memory allocation or memory copy in this module. It sets the [stride](https://github.com/torch/torch7/blob/master/doc/tensor.md#torch.Tensor.stride) along the `dim`th dimension to zero. ```lua > x = torch.linspace(1, 5, 5) @@ -556,9 +555,10 @@ This class creates an output where the input is replicated `nFeature` times alon module = nn.Reshape(dimension1, dimension2, ... [, batchMode]) ``` -Reshapes an `nxpxqx..` Tensor into a `dimension1xdimension2x...` Tensor, taking the elements column-wise. -The optional last argument `batchMode`, when `true` forces the first dimension of the input to be considered the batch dimension, and thus keep its size fixed. This is necessary when dealing with batch sizes of one. When `false`, it forces the entire input (including the first dimension) to be reshaped to the input size. Default `batchMode=nil`, which means that the module considers inputs with more elements than the produce of provided sizes, i.e. `dimension1xdimension2x...`, to be batches. +Reshapes an `nxpxqx..` Tensor into a `dimension1xdimension2x...` Tensor, taking the elements column-wise. + +The optional last argument `batchMode`, when `true` forces the first dimension of the input to be considered the batch dimension, and thus keep its size fixed. This is necessary when dealing with batch sizes of one. When `false`, it forces the entire input (including the first dimension) to be reshaped to the input size. Default `batchMode=nil`, which means that the module considers inputs with more elements than the produce of provided sizes, i.e. `dimension1xdimension2x...`, to be batches. Example: @@ -766,11 +766,11 @@ This can be used in conjunction with [Concat](containers.md#nn.Concat) to emulat ```lua mlp = nn.Sequential() -c = nn.Concat(2) +c = nn.Concat(2) for i = 1, 10 do local t = nn.Sequential() t:add(nn.Select(1, i)) - t:add(nn.Linear(3, 2)) + t:add(nn.Linear(3, 2)) t:add(nn.Reshape(2, 1)) c:add(t) end @@ -946,7 +946,7 @@ C = model.forward(A) -- C will be of size `b x m` `module` = `nn.Padding(dim, pad [, nInputDim, value])` This module adds `pad` units of padding to dimension `dim` of the input. -If `pad` is negative, padding is added to the left, otherwise, it is added to the right of the dimension. When `nInputDim` is provided, inputs larger than that value will be considered batches where the actual `dim` to be padded will +If `pad` is negative, padding is added to the left, otherwise, it is added to the right of the dimension. When `nInputDim` is provided, inputs larger than that value will be considered batches where the actual `dim` to be padded will be dimension `dim + 1`. When `value` is provide, the padding will be filled with that `value`. The default `value` is zero. 
Example 1: diff --git a/test.lua b/test.lua index 27a1747..8f08809 100644 --- a/test.lua +++ b/test.lua @@ -113,7 +113,7 @@ function nntest.CMul() local output = module:forward(input) local output2 = torch.cmul(input, module.weight:view(1,ini,inj,ink):expandAs(input)) mytester:assertTensorEq(output2, output, 0.000001, 'CMul forward 2D err') - + module:zeroGradParameters() local gradWeight = module.gradWeight:clone() local gradInput = module:backward(input, output) @@ -122,7 +122,7 @@ function nntest.CMul() gradInput2:view(input:size(1), -1):addcmul(1, module.weight:view(1,-1):expandAs(outputView), outputView) mytester:assertTensorEq(gradInput2, gradInput, 0.000001, 'CMul updateGradInput 2D err') mytester:assert(gradInput:isSameSizeAs(input), 'CMul gradInput 2D size err') - + local inputView = input:view(nframe, -1) local gradWeightView = gradWeight:view(1, -1) for i=1,nframe do @@ -130,7 +130,7 @@ function nntest.CMul() end mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'CMul accGradParameters 2D err') mytester:assert(module.weight:isSameSizeAs(module.gradWeight), 'CMul gradWeight size err') - + input:zero() local err = jac.testJacobian(module,input) @@ -592,18 +592,18 @@ function nntest.Euclidean() local gradOutput = torch.randn(inj) local module = nn.Euclidean(ini,inj) local output = module:forward(input):clone() - + local output2 = torch.Tensor(inj):zero() for o = 1,module.weight:size(2) do output2[o] = input:dist(module.weight:select(2,o)) end mytester:assertTensorEq(output, output2, 0.000001, 'Euclidean forward 1D err') - + local input2 = torch.randn(8, ini) input2[2]:copy(input) local output2 = module:forward(input2) mytester:assertTensorEq(output2[2], output, 0.000001, 'Euclidean forward 2D err') - + local output = module:forward(input):clone() module:zeroGradParameters() local gradInput = module:backward(input, gradOutput, 1):clone() @@ -616,7 +616,7 @@ function nntest.Euclidean() gradInput2:add(temp) end mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'Euclidean updateGradInput 1D err') - + local gradWeight = module.gradWeight:clone():zero() for o = 1,module.weight:size(2) do temp:copy(module.weight:select(2,o)):add(-1,input) @@ -624,16 +624,16 @@ function nntest.Euclidean() gradWeight:select(2,o):add(1, temp) end mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 1D err') - + local input2 = input:view(1, -1):repeatTensor(8, 1) local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1) local output2 = module:forward(input2) module:zeroGradParameters() local gradInput2 = module:backward(input2, gradOutput2, 1/8) mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'Euclidean updateGradInput 2D err') - + mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 2D err') - + input:zero() module.fastBackward = false local err = jac.testJacobian(module,input) @@ -655,7 +655,7 @@ function nntest.WeightedEuclidean() local module = nn.WeightedEuclidean(ini,inj) local output = module:forward(input):clone() - + local output2 = torch.Tensor(inj):zero() local temp = input:clone() for o = 1,module.weight:size(2) do @@ -665,12 +665,12 @@ function nntest.WeightedEuclidean() output2[o] = math.sqrt(temp:sum()) end mytester:assertTensorEq(output, output2, 0.000001, 'WeightedEuclidean forward 1D err') - + local input2 = torch.randn(8, ini) input2[2]:copy(input) local output2 = module:forward(input2) mytester:assertTensorEq(output2[2], output, 0.000001, 'WeightedEuclidean forward 2D 
err') - + local output = module:forward(input):clone() module:zeroGradParameters() local gradInput = module:backward(input, gradOutput, 1):clone() @@ -683,7 +683,7 @@ function nntest.WeightedEuclidean() gradInput2:add(temp) end mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'WeightedEuclidean updateGradInput 1D err') - + local gradWeight = module.gradWeight:clone():zero() local gradDiagCov = module.gradDiagCov:clone():zero() for o = 1,module.weight:size(2) do @@ -702,20 +702,20 @@ function nntest.WeightedEuclidean() end mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'WeightedEuclidean accGradParameters gradWeight 1D err') mytester:assertTensorEq(gradDiagCov, module.gradDiagCov, 0.000001, 'WeightedEuclidean accGradParameters gradDiagCov 1D err') - + local input2 = input:view(1, -1):repeatTensor(8, 1) local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1) local output2 = module:forward(input2) module:zeroGradParameters() local gradInput2 = module:backward(input2, gradOutput2, 1/8) mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'WeightedEuclidean updateGradInput 2D err') - + mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'WeightedEuclidean accGradParameters gradWeight 2D err') mytester:assertTensorEq(gradDiagCov, module.gradDiagCov, 0.000001, 'WeightedEuclidean accGradParameters gradDiagCov 2D err') - + input:zero() module.fastBackward = false - + local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') @@ -728,7 +728,7 @@ function nntest.WeightedEuclidean() local ferr,berr = jac.testIO(module,input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') - + input:zero() module:zeroGradParameters() local err = jac.testJacobian(module,input) @@ -1752,23 +1752,23 @@ function nntest.SpatialAveragePooling() local inj = (outj-1)*sj+kj local module = nn.SpatialAveragePooling(ki, kj, si, sj) local input = torch.Tensor(from, inj, ini):zero() - + local err = jac.testJacobian(module, input) mytester:assertlt(err, precision, 'error on state ') local ferr, berr = jac.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') - + local sap = nn.SpatialSubSampling(from, ki, kj, si, sj) sap.weight:fill(1.0/(ki*kj)) sap.bias:fill(0.0) - + local output = module:forward(input) local gradInput = module:backward(input, output) local output2 = sap:forward(input) local gradInput2 = sap:updateGradInput(input, output) - + mytester:assertTensorEq(output, output2, 0.000001, torch.typename(module) .. ' forward err ') mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err ') @@ -1782,24 +1782,24 @@ function nntest.SpatialAveragePooling() local err = jac.testJacobian(module, input) mytester:assertlt(err, precision, 'batch error on state ') - + local ferr, berr = jac.testIO(module, input) mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') - + local ferr, berr = jac.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err (Batch) ') - + local sap = nn.SpatialSubSampling(from, ki, kj, si, sj) sap.weight:fill(1.0/(ki*kj)) sap.bias:fill(0.0) - + local output = module:forward(input) local gradInput = module:backward(input, output) local output2 = sap:forward(input) local gradInput2 = sap:updateGradInput(input, output) - + mytester:assertTensorEq(output, output2, 0.000001, torch.typename(module) .. ' forward err (Batch) ') mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err (Batch) ') end @@ -2125,7 +2125,7 @@ function nntest.VolumetricConvolutionBatchCompare() local inj = (outj-1)*sj+kj local module = nn.VolumetricConvolution(from, to, kt, ki, kj, st, si, sj) local input = torch.randn(from, int, inj, ini) - batchcompare(module,input, {'weight','bias','gradWeight','gradBias'}) + batchcompare(module,input, {'weight','bias','gradWeight','gradBias'}) end function nntest.VolumetricMaxPooling() @@ -2346,7 +2346,7 @@ function nntest.Module_listModules() mlp3:add(linear) mlp3:add(tanh) mlp3:add(reshape) - + local mlp2 = nn.Sequential() local view = nn.View(outputSize) local linear2 = nn.Linear(outputSize, inputSize) @@ -2355,7 +2355,7 @@ function nntest.Module_listModules() mlp2:add(view) mlp2:add(linear2) mlp2:add(tanh2) - + local concat = nn.ConcatTable() local id = nn.Identity() concat:add(mlp2) @@ -2364,15 +2364,15 @@ function nntest.Module_listModules() local add = nn.CAddTable() mlp:add(concat) mlp:add(add) - + local modules2 = {mlp, concat, mlp2, mlp3, linear, tanh, reshape, view, linear2, tanh2, id, add} local modules = mlp:listModules() - + mytester:assert(#modules2 == #modules, 'missing modules error') - + for i,module in ipairs(modules) do mytester:assert(torch.type(module) == torch.type(modules2[i]), 'module error') - end + end end function nntest.PairwiseDistance() @@ -2772,7 +2772,7 @@ function nntest.View() mytester:assertTableEq(module:forward(minibatch):size(1), minibatch:size(1), "Error in minibatch dimension with size -1") - + -- Minibatch Generalization local minibatch = torch.rand(5,2,6) local module = nn.View(6) @@ -2835,11 +2835,11 @@ function nntest.Parallel() m:add(nn.View(4,5,1)) m:add(nn.View(4,5,1)) m:add(nn.View(4,5,1)) - + local output = m:forward(input) local output2 = input:transpose(1,3):transpose(1,2) mytester:assertTensorEq(output2, output, 0.000001, 'Parallel forward err') - + local gradInput = m:backward(input, output2) mytester:assertTensorEq(gradInput, input, 0.000001, 'Parallel backward err') end @@ -2854,11 +2854,11 @@ function nntest.ParallelTable() m:add(nn.SplitTable(1)) m:add(p) m:add(nn.JoinTable(3)) - + local output = m:forward(input) local output2 = input:transpose(1,3):transpose(1,2) mytester:assertTensorEq(output2, output, 0.000001, 'ParallelTable forward err') - + local gradInput = m:backward(input, output2) mytester:assertTensorEq(gradInput, input, 0.000001, 'ParallelTable backward err') end @@ -3223,6 +3223,26 @@ function nntest.CosineEmbeddingCriterion() equal(grads[2], zero, 'gradient should be zero') end +function nntest.Replicate() + local vector = torch.rand(3) + + local r1 = nn.Replicate(2, 1) + local r2 = nn.Replicate(2, 2) + + local vOutput1 = r1:forward(vector):clone() + local vOutput2 = r2:forward(vector):clone() + + local expected1 = torch.zeros(2, 3) + local expected2 = torch.zeros(3, 2) + expected1:select(1, 1):copy(vector) + expected1:select(1, 2):copy(vector) + expected2:select(2, 1):copy(vector) + expected2:select(2, 2):copy(vector) + + mytester:assertTensorEq(vOutput1, expected1, 
precision, 'Wrong tiling of data when replicating vector.')
+   mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating vector.')
+end
+
 function nntest.BatchNormalization()
    local nframes = torch.random(50,70)
    local indim = torch.random(1,10)
@@ -3339,10 +3359,10 @@ function nntest.Padding()
       local input = torch.rand(fanin,sizey,sizex)
       local size = input:size():totable()
       size[1] = size[1] + math.abs(pad)
-      
+
       local output = module:forward(input)
      mytester:assertTableEq(size, output:size():totable(), 0.00001, "Padding size error")
-      
+
      local gradInput = module:backward(input, output)
      mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error")
   end
--
cgit v1.2.3
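A minimal usage sketch of the new `dim` argument (not part of the patch; it assumes the patch above is applied). The sizes mirror the `nntest.Replicate` case: replicating a 3-element vector twice along dimension 1 gives a `2x3` output, and along dimension 2 a `3x2` output.

```lua
require 'nn'

-- Replicate a 3-element vector twice, either as new rows (dim = 1)
-- or as new columns (dim = 2). No memory is copied: the replicated
-- dimension is created with stride 0.
local x = torch.linspace(1, 3, 3)

local rowRep = nn.Replicate(2, 1)   -- output size 2x3
local colRep = nn.Replicate(2, 2)   -- output size 3x2

print(rowRep:forward(x))
print(colRep:forward(x))

-- updateGradInput sums gradients over the replicated dimension,
-- so a 2x3 gradOutput collapses back to a 3-element gradInput.
print(rowRep:backward(x, torch.ones(2, 3)))  -- [2, 2, 2]
```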