
github.com/torch/nn.git
-rw-r--r--  Replicate.lua  |  37
-rwxr-xr-x  doc/simple.md  |  62
-rw-r--r--  test.lua       | 104
3 files changed, 123 insertions(+), 80 deletions(-)
diff --git a/Replicate.lua b/Replicate.lua
index d049044..5dd849a 100644
--- a/Replicate.lua
+++ b/Replicate.lua
@@ -1,20 +1,34 @@
local Replicate, parent = torch.class('nn.Replicate','nn.Module')
-function Replicate:__init(nf)
+function Replicate:__init(nf, dim)
parent.__init(self)
self.nfeatures = nf
+ self.dim = dim or 1
+ assert(self.dim > 0, "Can only replicate across positive integer dimensions.")
end
function Replicate:updateOutput(input)
+ assert(
+ self.dim <= input:dim()+1,
+ "Not enough input dimensions to replicate along dimension " ..
+ tostring(self.dim) .. ".")
local sz = torch.LongStorage(input:dim()+1)
- sz[1] = self.nfeatures
+ sz[self.dim] = self.nfeatures
for i = 1,input:dim() do
- sz[i+1] = input:size(i)
+ local offset = 0
+ if i >= self.dim then
+ offset = 1
+ end
+ sz[i+offset] = input:size(i)
end
local st = torch.LongStorage(input:dim()+1)
- st[1] = 0
+ st[self.dim] = 0
for i = 1,input:dim() do
- st[i+1] = input:stride(i)
+ local offset = 0
+ if i >= self.dim then
+ offset = 1
+ end
+ st[i+offset] = input:stride(i)
end
self.output = input.new(input:storage(),input:storageOffset(),sz,st)
return self.output
@@ -22,7 +36,16 @@ end
function Replicate:updateGradInput(input, gradOutput)
self.gradInput:resizeAs(input):zero()
- local gradInput = self.gradInput:view(1, unpack(input:size():totable()))
- gradInput:sum(gradOutput, 1)
+ local sz = torch.LongStorage(input:dim()+1)
+ sz[self.dim] = 1
+ for i = 1,input:dim() do
+ local offset = 0
+ if i >= self.dim then
+ offset = 1
+ end
+ sz[i+offset] = input:size(i)
+ end
+ local gradInput = self.gradInput:view(sz)
+ gradInput:sum(gradOutput, self.dim)
return self.gradInput
end
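
For context (not part of the commit), a minimal usage sketch of the new `dim` argument with arbitrary sizes. It shows that the output is only a zero-stride view of the input, and that backward sums gradients over the replicated dimension:

```lua
require 'nn'

local x = torch.linspace(1, 3, 3)    -- 1D input: 1, 2, 3

local r1 = nn.Replicate(4, 1)        -- replicate along dimension 1 (equivalent to the previous behaviour)
local r2 = nn.Replicate(4, 2)        -- replicate along the new dimension argument

print(r1:forward(x):size())          -- 4x3: every row equals x
print(r2:forward(x):size())          -- 3x4: every column equals x

-- no data is copied: the replicated dimension has stride 0 into the original storage
print(r1:forward(x):stride(1))       -- 0

-- backward sums gradOutput over the replicated dimension
local y  = r2:forward(x)
local gi = r2:backward(x, y)         -- gi[i] == 4 * x[i]
print(gi)
```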
diff --git a/doc/simple.md b/doc/simple.md
index 01a6140..7c97a6b 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -1,6 +1,5 @@
<a name="nn.simplelayers.dok"/>
# Simple layers #
-
Simple Modules are used for various tasks like adapting Tensor methods and providing affine transformations :
* Parameterized Modules :
* [Linear](#nn.Linear) : a linear transformation ;
@@ -36,7 +35,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [SpatialDropout](#nn.SpatialDropout) : Same as Dropout but for spatial inputs where adjacent pixels are strongly correlated ;
* [Padding](#nn.Padding) : adds padding to a dimension ;
* [L1Penalty](#nn.L1Penalty) : adds an L1 penalty to an input (for sparsity);
-
+
<a name="nn.Linear"/>
## Linear ##
@@ -114,7 +113,6 @@ x = torch.Tensor({ {1, 0.1}, {2, 0.3}, {10, 0.3}, {31, 0.2} })
The first column contains indices, the second column contains values in a vector where all other elements are zeros. The indices should not exceed the stated dimensions of the input to the layer (10000 in the example).
-
<a name="nn.Dropout"/>
## Dropout ##
@@ -123,7 +121,7 @@ module = nn.Dropout(p)
```
During training, `Dropout` masks parts of the `input` using binary samples from a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution.
-Each `input` element has a probability of `p` of being dropped, i.e having its commensurate output element be zero. This has proven an effective technique for regularization and preventing the co-adaptation of neurons (see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)).
+Each `input` element has a probability of `p` of being dropped, i.e. having its commensurate output element be zero. This has proven an effective technique for regularization and preventing the co-adaptation of neurons (see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)).
Furthermore, the outputs are scaled by a factor of `1/(1-p)` during training. This allows the `input` to be simply forwarded as-is during evaluation.
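
As a quick illustration of the behaviour described above (not part of the commit), a minimal sketch; the exact zero pattern is random:

```lua
require 'nn'

local drop = nn.Dropout(0.5)
local x = torch.ones(6)

drop:training()
print(drop:forward(x))   -- roughly half the entries are zeroed; survivors become 1/(1-0.5) = 2

drop:evaluate()
print(drop:forward(x))   -- evaluation: the input is forwarded as-is
```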
@@ -185,7 +183,6 @@ We can return to training our model by first calling [Module:training()](module.
When used, `Dropout` should normally be applied to the input of parameterized [Modules](module.md#nn.Module) like [Linear](#nn.Linear) or [SpatialConvolution](convolution.md#nn.SpatialConvolution). A `p` of `0.5` (the default) is usually okay for hidden layers. `Dropout` can sometimes be used successfully on the dataset inputs with a `p` around `0.2`. It sometimes works best following [Transfer](transfer.md) Modules like [ReLU](transfer.md#nn.ReLU). All this depends a great deal on the dataset, so it's up to the user to try different combinations.
-
<a name="nn.SpatialDropout"/>
## SpatialDropout ##
@@ -224,16 +221,16 @@ gnuplot.grid(true)
module = nn.Add(inputDimension, scalar)
```
-Applies a bias term to the incoming data, i.e. `yi = x_i + b_i`, or if `scalar = true` then uses a single bias term, `yi = x_i + b`.
+Applies a bias term to the incoming data, i.e. `y_i = x_i + b_i`, or if `scalar = true` then uses a single bias term, `y_i = x_i + b`.
Example:
```lua
y = torch.Tensor(5)
-mlp = nn.Sequential()
+mlp = nn.Sequential()
mlp:add(nn.Add(5))
-function gradUpdate(mlp, x, y, criterion, learningRate)
+function gradUpdate(mlp, x, y, criterion, learningRate)
local pred = mlp:forward(x)
local err = criterion:forward(pred, y)
local gradCriterion = criterion:backward(pred, y)
@@ -245,7 +242,7 @@ end
for i = 1, 10000 do
x = torch.rand(5)
- y:copy(x);
+ y:copy(x);
for i = 1, 5 do y[i] = y[i] + i; end
err = gradUpdate(mlp, x, y, nn.MSECriterion(), 0.01)
end
@@ -274,7 +271,7 @@ i.e. the network successfully learns the input `x` has been shifted to produce t
module = nn.Mul()
```
-Applies a _single_ scaling factor to the incoming data, i.e. `y = w x`, where `w` is a scalar.
+Applies a _single_ scaling factor to the incoming data, i.e. `y = w x`, where `w` is a scalar.
Example:
@@ -283,7 +280,7 @@ y = torch.Tensor(5)
mlp = nn.Sequential()
mlp:add(nn.Mul())
-function gradUpdate(mlp, x, y, criterion, learningRate)
+function gradUpdate(mlp, x, y, criterion, learningRate)
local pred = mlp:forward(x)
local err = criterion:forward(pred, y)
local gradCriterion = criterion:backward(pred, y)
@@ -331,7 +328,7 @@ y = torch.Tensor(5)
sc = torch.Tensor(5)
for i = 1, 5 do sc[i] = i; end -- scale input with this
-function gradUpdate(mlp, x, y, criterion, learningRate)
+function gradUpdate(mlp, x, y, criterion, learningRate)
local pred = mlp:forward(x)
local err = criterion:forward(pred, y)
local gradCriterion = criterion:backward(pred, y)
@@ -415,7 +412,7 @@ Hence, if an `nxpxq` Tensor was given as input, and `dimension` = `2` then an `n
module = nn.Euclidean(inputSize,outputSize)
```
-Outputs the Euclidean distance of the input to `outputSize` centers, i.e. this layer has the weights `w_j`, for `j` = `1`,..,`outputSize`, where `w_j` are vectors of dimension `inputSize`.
+Outputs the Euclidean distance of the input to `outputSize` centers, i.e. this layer has the weights `w_j`, for `j` = `1`,..,`outputSize`, where `w_j` are vectors of dimension `inputSize`.
The distance `y_j` between center `j` and input `x` is formulated as `y_j = || w_j - x ||`.
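
A small sketch of that formula (not part of the commit; the sizes are arbitrary), mirroring the check used in the tests below:

```lua
require 'nn'

local m = nn.Euclidean(5, 3)   -- 3 centers, each a vector of dimension 5
local x = torch.randn(5)
local y = m:forward(x)         -- 3 distances, one per center

for j = 1, 3 do
   -- y[j] should equal || w_j - x ||, with w_j the j-th column of m.weight
   print(y[j], x:dist(m.weight:select(2, j)))
end
```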
@@ -426,7 +423,7 @@ The distance `y_j` between center `j` and input `x` is formulated as `y_j = || w
module = nn.WeightedEuclidean(inputSize,outputSize)
```
-This module is similar to [Euclidean](#nn.Euclidean), but additionally learns a separate diagonal covariance matrix across the features of the input space _for each center_.
+This module is similar to [Euclidean](#nn.Euclidean), but additionally learns a separate diagonal covariance matrix across the features of the input space _for each center_.
In other words, for each of the `outputSize` centers `w_j`, there is a diagonal covariance matrix `c_j`, for `j` = `1`,..,`outputSize`, where each `c_j` is stored as a vector of size `inputSize`.
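
A short sketch of the extra per-center parameters (not part of the commit; sizes arbitrary). It assumes the covariances are stored in a `diagCov` field, which the `gradDiagCov` references in the tests below suggest:

```lua
require 'nn'

local m = nn.WeightedEuclidean(5, 3)   -- 3 centers over a 5-dimensional input
print(m.weight:size())                 -- 5x3: one center per output, stored column-wise
print(m.diagCov:size())                -- 5x3: one diagonal covariance vector per center

print(m:forward(torch.randn(5)))       -- 3 weighted distances, one per center
```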
@@ -439,8 +436,10 @@ The distance `y_j` between center `j` and input `x` is formulated as `y_j = || c
module = nn.Identity()
```
-Creates a module that returns whatever is input to it as output.
+
+Creates a module that returns whatever is input to it as output.
This is useful when combined with the module [ParallelTable](table.md#nn.ParallelTable) in case you do not wish to do anything to one of the input Tensors.
+
Example:
```lua
@@ -448,7 +447,7 @@ mlp = nn.Identity()
print(mlp:forward(torch.ones(5, 2)))
```
-gives the output:
+gives the output:
```lua
1 1
@@ -461,10 +460,10 @@ gives the output:
Here is a more useful example, where one can implement a network which also computes a Criterion using this module:
-```lua
+```lua
pred_mlp = nn.Sequential() -- A network that makes predictions given x.
-pred_mlp:add(nn.Linear(5, 4))
-pred_mlp:add(nn.Linear(4, 3))
+pred_mlp:add(nn.Linear(5, 4))
+pred_mlp:add(nn.Linear(4, 3))
xy_mlp = nn.ParallelTable() -- A network for predictions and for keeping the
xy_mlp:add(pred_mlp) -- true label for comparison with a criterion
@@ -478,14 +477,14 @@ mlp:add(cr_wrap) -- and then applies the criterion.
for i = 1, 100 do -- Do a few training iterations
x = torch.ones(5) -- Make input features.
- y = torch.Tensor(3)
+ y = torch.Tensor(3)
y:copy(x:narrow(1,1,3)) -- Make output label.
err = mlp:forward{x,y} -- Forward both input and output.
print(err) -- Print error from criterion.
- mlp:zeroGradParameters() -- Do backprop...
- mlp:backward({x, y})
- mlp:updateParameters(0.05)
+ mlp:zeroGradParameters() -- Do backprop...
+ mlp:backward({x, y})
+ mlp:updateParameters(0.05)
end
```
@@ -496,7 +495,7 @@ end
module = nn.Copy(inputType, outputType, [forceCopy, dontCast])
```
-This layer copies the input to output with type casting from input type from `inputType` to `outputType`. Unless `forceCopy` is true, when the first two arguments are the same, the input isn't copied, only transfered as the output. The default `forceCopy` is false.
+This layer copies the input to the output, casting it from `inputType` to `outputType`. Unless `forceCopy` is true, when the first two arguments are the same the input isn't copied, only transferred as the output. The default `forceCopy` is false.
When `dontCast` is true, a call to `nn.Copy:type(type)` will not cast the module's `output` and `gradInput` Tensors to the new type. The default is false.
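
A one-line sketch of the cast (not part of the commit):

```lua
require 'nn'

local toFloat = nn.Copy('torch.DoubleTensor', 'torch.FloatTensor')
local x = torch.randn(3)                  -- a DoubleTensor by default
print(torch.type(toFloat:forward(x)))     -- torch.FloatTensor
```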
<a name="nn.Narrow"/>
@@ -512,10 +511,10 @@ Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/do
## Replicate ##
```lua
-module = nn.Replicate(nFeature)
+module = nn.Replicate(nFeature, dim)
```
-This class creates an output where the input is replicated `nFeature` times along its first dimension. There is no memory allocation or memory copy in this module. It sets the [stride](https://github.com/torch/torch7/blob/master/doc/tensor.md#torch.Tensor.stride) along the first dimension to zero.
+This class creates an output where the input is replicated `nFeature` times along dimension `dim` (default 1). There is no memory allocation or memory copy in this module. It sets the [stride](https://github.com/torch/torch7/blob/master/doc/tensor.md#torch.Tensor.stride) along the `dim`th dimension to zero.
```lua
> x = torch.linspace(1, 5, 5)
@@ -556,9 +555,10 @@ This class creates an output where the input is replicated `nFeature` times alon
module = nn.Reshape(dimension1, dimension2, ... [, batchMode])
```
-Reshapes an `nxpxqx..` Tensor into a `dimension1xdimension2x...` Tensor, taking the elements column-wise.
-The optional last argument `batchMode`, when `true` forces the first dimension of the input to be considered the batch dimension, and thus keep its size fixed. This is necessary when dealing with batch sizes of one. When `false`, it forces the entire input (including the first dimension) to be reshaped to the input size. Default `batchMode=nil`, which means that the module considers inputs with more elements than the produce of provided sizes, i.e. `dimension1xdimension2x...`, to be batches.
+Reshapes an `nxpxqx..` Tensor into a `dimension1xdimension2x...` Tensor, taking the elements column-wise.
+
+The optional last argument `batchMode`, when `true`, forces the first dimension of the input to be considered the batch dimension, and thus keeps its size fixed. This is necessary when dealing with batch sizes of one. When `false`, it forces the entire input (including the first dimension) to be reshaped to the specified sizes. Default `batchMode=nil`, which means that the module considers inputs with more elements than the product of the provided sizes, i.e. `dimension1xdimension2x...`, to be batches.
Example:
@@ -766,11 +766,11 @@ This can be used in conjunction with [Concat](containers.md#nn.Concat) to emulat
```lua
mlp = nn.Sequential()
-c = nn.Concat(2)
+c = nn.Concat(2)
for i = 1, 10 do
local t = nn.Sequential()
t:add(nn.Select(1, i))
- t:add(nn.Linear(3, 2))
+ t:add(nn.Linear(3, 2))
t:add(nn.Reshape(2, 1))
c:add(t)
end
@@ -946,7 +946,7 @@ C = model.forward(A) -- C will be of size `b x m`
`module` = `nn.Padding(dim, pad [, nInputDim, value])`
This module adds `pad` units of padding to dimension `dim` of the input.
-If `pad` is negative, padding is added to the left, otherwise, it is added to the right of the dimension. When `nInputDim` is provided, inputs larger than that value will be considered batches where the actual `dim` to be padded will
+If `pad` is negative, padding is added to the left, otherwise, it is added to the right of the dimension. When `nInputDim` is provided, inputs with more dimensions than that value will be considered batches, where the actual `dim` to be padded will
be dimension `dim + 1`. When `value` is provided, the padding will be filled with that `value`. The default `value` is zero.
Example 1:
diff --git a/test.lua b/test.lua
index 27a1747..8f08809 100644
--- a/test.lua
+++ b/test.lua
@@ -113,7 +113,7 @@ function nntest.CMul()
local output = module:forward(input)
local output2 = torch.cmul(input, module.weight:view(1,ini,inj,ink):expandAs(input))
mytester:assertTensorEq(output2, output, 0.000001, 'CMul forward 2D err')
-
+
module:zeroGradParameters()
local gradWeight = module.gradWeight:clone()
local gradInput = module:backward(input, output)
@@ -122,7 +122,7 @@ function nntest.CMul()
gradInput2:view(input:size(1), -1):addcmul(1, module.weight:view(1,-1):expandAs(outputView), outputView)
mytester:assertTensorEq(gradInput2, gradInput, 0.000001, 'CMul updateGradInput 2D err')
mytester:assert(gradInput:isSameSizeAs(input), 'CMul gradInput 2D size err')
-
+
local inputView = input:view(nframe, -1)
local gradWeightView = gradWeight:view(1, -1)
for i=1,nframe do
@@ -130,7 +130,7 @@ function nntest.CMul()
end
mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'CMul accGradParameters 2D err')
mytester:assert(module.weight:isSameSizeAs(module.gradWeight), 'CMul gradWeight size err')
-
+
input:zero()
local err = jac.testJacobian(module,input)
@@ -592,18 +592,18 @@ function nntest.Euclidean()
local gradOutput = torch.randn(inj)
local module = nn.Euclidean(ini,inj)
local output = module:forward(input):clone()
-
+
local output2 = torch.Tensor(inj):zero()
for o = 1,module.weight:size(2) do
output2[o] = input:dist(module.weight:select(2,o))
end
mytester:assertTensorEq(output, output2, 0.000001, 'Euclidean forward 1D err')
-
+
local input2 = torch.randn(8, ini)
input2[2]:copy(input)
local output2 = module:forward(input2)
mytester:assertTensorEq(output2[2], output, 0.000001, 'Euclidean forward 2D err')
-
+
local output = module:forward(input):clone()
module:zeroGradParameters()
local gradInput = module:backward(input, gradOutput, 1):clone()
@@ -616,7 +616,7 @@ function nntest.Euclidean()
gradInput2:add(temp)
end
mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'Euclidean updateGradInput 1D err')
-
+
local gradWeight = module.gradWeight:clone():zero()
for o = 1,module.weight:size(2) do
temp:copy(module.weight:select(2,o)):add(-1,input)
@@ -624,16 +624,16 @@ function nntest.Euclidean()
gradWeight:select(2,o):add(1, temp)
end
mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 1D err')
-
+
local input2 = input:view(1, -1):repeatTensor(8, 1)
local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1)
local output2 = module:forward(input2)
module:zeroGradParameters()
local gradInput2 = module:backward(input2, gradOutput2, 1/8)
mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'Euclidean updateGradInput 2D err')
-
+
mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 2D err')
-
+
input:zero()
module.fastBackward = false
local err = jac.testJacobian(module,input)
@@ -655,7 +655,7 @@ function nntest.WeightedEuclidean()
local module = nn.WeightedEuclidean(ini,inj)
local output = module:forward(input):clone()
-
+
local output2 = torch.Tensor(inj):zero()
local temp = input:clone()
for o = 1,module.weight:size(2) do
@@ -665,12 +665,12 @@ function nntest.WeightedEuclidean()
output2[o] = math.sqrt(temp:sum())
end
mytester:assertTensorEq(output, output2, 0.000001, 'WeightedEuclidean forward 1D err')
-
+
local input2 = torch.randn(8, ini)
input2[2]:copy(input)
local output2 = module:forward(input2)
mytester:assertTensorEq(output2[2], output, 0.000001, 'WeightedEuclidean forward 2D err')
-
+
local output = module:forward(input):clone()
module:zeroGradParameters()
local gradInput = module:backward(input, gradOutput, 1):clone()
@@ -683,7 +683,7 @@ function nntest.WeightedEuclidean()
gradInput2:add(temp)
end
mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'WeightedEuclidean updateGradInput 1D err')
-
+
local gradWeight = module.gradWeight:clone():zero()
local gradDiagCov = module.gradDiagCov:clone():zero()
for o = 1,module.weight:size(2) do
@@ -702,20 +702,20 @@ function nntest.WeightedEuclidean()
end
mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'WeightedEuclidean accGradParameters gradWeight 1D err')
mytester:assertTensorEq(gradDiagCov, module.gradDiagCov, 0.000001, 'WeightedEuclidean accGradParameters gradDiagCov 1D err')
-
+
local input2 = input:view(1, -1):repeatTensor(8, 1)
local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1)
local output2 = module:forward(input2)
module:zeroGradParameters()
local gradInput2 = module:backward(input2, gradOutput2, 1/8)
mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'WeightedEuclidean updateGradInput 2D err')
-
+
mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'WeightedEuclidean accGradParameters gradWeight 2D err')
mytester:assertTensorEq(gradDiagCov, module.gradDiagCov, 0.000001, 'WeightedEuclidean accGradParameters gradDiagCov 2D err')
-
+
input:zero()
module.fastBackward = false
-
+
local err = jac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')
@@ -728,7 +728,7 @@ function nntest.WeightedEuclidean()
local ferr,berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
-
+
input:zero()
module:zeroGradParameters()
local err = jac.testJacobian(module,input)
@@ -1752,23 +1752,23 @@ function nntest.SpatialAveragePooling()
local inj = (outj-1)*sj+kj
local module = nn.SpatialAveragePooling(ki, kj, si, sj)
local input = torch.Tensor(from, inj, ini):zero()
-
+
local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state ')
local ferr, berr = jac.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
-
+
local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
sap.weight:fill(1.0/(ki*kj))
sap.bias:fill(0.0)
-
+
local output = module:forward(input)
local gradInput = module:backward(input, output)
local output2 = sap:forward(input)
local gradInput2 = sap:updateGradInput(input, output)
-
+
mytester:assertTensorEq(output, output2, 0.000001, torch.typename(module) .. ' forward err ')
mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err ')
@@ -1782,24 +1782,24 @@ function nntest.SpatialAveragePooling()
local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'batch error on state ')
-
+
local ferr, berr = jac.testIO(module, input)
mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
-
+
local ferr, berr = jac.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
-
+
local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
sap.weight:fill(1.0/(ki*kj))
sap.bias:fill(0.0)
-
+
local output = module:forward(input)
local gradInput = module:backward(input, output)
local output2 = sap:forward(input)
local gradInput2 = sap:updateGradInput(input, output)
-
+
mytester:assertTensorEq(output, output2, 0.000001, torch.typename(module) .. ' forward err (Batch) ')
mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err (Batch) ')
end
@@ -2125,7 +2125,7 @@ function nntest.VolumetricConvolutionBatchCompare()
local inj = (outj-1)*sj+kj
local module = nn.VolumetricConvolution(from, to, kt, ki, kj, st, si, sj)
local input = torch.randn(from, int, inj, ini)
- batchcompare(module,input, {'weight','bias','gradWeight','gradBias'})
+ batchcompare(module,input, {'weight','bias','gradWeight','gradBias'})
end
function nntest.VolumetricMaxPooling()
@@ -2346,7 +2346,7 @@ function nntest.Module_listModules()
mlp3:add(linear)
mlp3:add(tanh)
mlp3:add(reshape)
-
+
local mlp2 = nn.Sequential()
local view = nn.View(outputSize)
local linear2 = nn.Linear(outputSize, inputSize)
@@ -2355,7 +2355,7 @@ function nntest.Module_listModules()
mlp2:add(view)
mlp2:add(linear2)
mlp2:add(tanh2)
-
+
local concat = nn.ConcatTable()
local id = nn.Identity()
concat:add(mlp2)
@@ -2364,15 +2364,15 @@ function nntest.Module_listModules()
local add = nn.CAddTable()
mlp:add(concat)
mlp:add(add)
-
+
local modules2 = {mlp, concat, mlp2, mlp3, linear, tanh, reshape, view, linear2, tanh2, id, add}
local modules = mlp:listModules()
-
+
mytester:assert(#modules2 == #modules, 'missing modules error')
-
+
for i,module in ipairs(modules) do
mytester:assert(torch.type(module) == torch.type(modules2[i]), 'module error')
- end
+ end
end
function nntest.PairwiseDistance()
@@ -2772,7 +2772,7 @@ function nntest.View()
mytester:assertTableEq(module:forward(minibatch):size(1),
minibatch:size(1),
"Error in minibatch dimension with size -1")
-
+
-- Minibatch Generalization
local minibatch = torch.rand(5,2,6)
local module = nn.View(6)
@@ -2835,11 +2835,11 @@ function nntest.Parallel()
m:add(nn.View(4,5,1))
m:add(nn.View(4,5,1))
m:add(nn.View(4,5,1))
-
+
local output = m:forward(input)
local output2 = input:transpose(1,3):transpose(1,2)
mytester:assertTensorEq(output2, output, 0.000001, 'Parallel forward err')
-
+
local gradInput = m:backward(input, output2)
mytester:assertTensorEq(gradInput, input, 0.000001, 'Parallel backward err')
end
@@ -2854,11 +2854,11 @@ function nntest.ParallelTable()
m:add(nn.SplitTable(1))
m:add(p)
m:add(nn.JoinTable(3))
-
+
local output = m:forward(input)
local output2 = input:transpose(1,3):transpose(1,2)
mytester:assertTensorEq(output2, output, 0.000001, 'ParallelTable forward err')
-
+
local gradInput = m:backward(input, output2)
mytester:assertTensorEq(gradInput, input, 0.000001, 'ParallelTable backward err')
end
@@ -3223,6 +3223,26 @@ function nntest.CosineEmbeddingCriterion()
equal(grads[2], zero, 'gradient should be zero')
end
+function nntest.Replicate()
+ local vector = torch.rand(3)
+
+ local r1 = nn.Replicate(2, 1)
+ local r2 = nn.Replicate(2, 2)
+
+ local vOutput1 = r1:forward(vector):clone()
+ local vOutput2 = r2:forward(vector):clone()
+
+ local expected1 = torch.zeros(2, 3)
+ local expected2 = torch.zeros(3, 2)
+ expected1:select(1, 1):copy(vector)
+ expected1:select(1, 2):copy(vector)
+ expected2:select(2, 1):copy(vector)
+ expected2:select(2, 2):copy(vector)
+
+ mytester:assertTensorEq(vOutput1, expected1, precision, 'Wrong tiling of data when replicating vector.')
+ mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating vector.')
+end
+
function nntest.BatchNormalization()
local nframes = torch.random(50,70)
local indim = torch.random(1,10)
@@ -3339,10 +3359,10 @@ function nntest.Padding()
local input = torch.rand(fanin,sizey,sizex)
local size = input:size():totable()
size[1] = size[1] + math.abs(pad)
-
+
local output = module:forward(input)
mytester:assertTableEq(size, output:size():totable(), 0.00001, "Padding size error")
-
+
local gradInput = module:backward(input, output)
mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error")
end
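
To exercise the new `nntest.Replicate` case from this commit, something like the following should work (a sketch, assuming `nn.test` accepts a list of test names, as the torch/nn test harness does):

```lua
require 'nn'
nn.test{'Replicate'}   -- run only the Replicate unit test added above
```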