Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2015-03-18 19:16:11 +0300
committerRonan Collobert <ronan@collobert.com>2015-03-18 19:16:11 +0300
commit6b0a15f3f3f811f483ecd07e91d2f89a01bf4d6d (patch)
tree1382ceec9b6c5ea4079f306cbc015f704d785ffa
parentf04886e0f3c0ca7354b71b2872a7caeca434e59a (diff)
fix special View case with -1 arg and dimension set <> 1
added test case
-rw-r--r--View.lua2
-rw-r--r--test.lua7
2 files changed, 8 insertions, 1 deletions
diff --git a/View.lua b/View.lua
index 98e53e3..a6785c8 100644
--- a/View.lua
+++ b/View.lua
@@ -30,7 +30,7 @@ local function batchsize(input, size, numInputDims, numElements)
numElements = 1
local dim = input:nDimension()
for i=1,numInputDims do
- numElements = numElements * input:size(dim-numElements+1)
+ numElements = numElements * input:size(dim-i+1)
end
else
numElements = input:nElement()
diff --git a/test.lua b/test.lua
index 3661d03..f05e502 100644
--- a/test.lua
+++ b/test.lua
@@ -2738,6 +2738,13 @@ function nntest.View()
minibatch:nElement(),
"Error in minibatch nElement with size -1")
+ -- another setNumInputDims case
+ local minibatch = torch.rand(5,4,10)
+ local module = nn.View(-1):setNumInputDims(2)
+ mytester:assertTableEq(module:forward(minibatch):size(1),
+ minibatch:size(1),
+ "Error in minibatch dimension with size -1")
+
-- Minibatch Generalization
local minibatch = torch.rand(5,2,6)
local module = nn.View(6)