Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSergio Gomez <sergomezcol@gmail.com>2014-06-26 13:09:43 +0400
committerSergio Gomez <sergomezcol@gmail.com>2014-06-26 13:09:43 +0400
commit8fd02c336bc2d03aa1fe21fced69615fd1e7b99b (patch)
tree9937e7ae5d58c5d7bd2aa5955c486e74c4822eb9 /test
parentd85c2ce7ab24a699855f41b1919a74f81def47cd (diff)
Remove setNumInputDims method in JoinTable and SplitTable
Now nInputDims is an optional parameter in the constructor of these modules.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua6
1 files changed, 2 insertions, 4 deletions
diff --git a/test/test.lua b/test/test.lua
index 2db6f2d..fc7a7e4 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1804,8 +1804,7 @@ function nntest.JoinTable()
local input = {tensor, tensor}
local module
for d = 1,tensor:dim()-1 do
- module = nn.JoinTable(d)
- module:setNumInputDims(2)
+ module = nn.JoinTable(d, 2)
mytester:asserteq(module:forward(input):size(d+1), tensor:size(d+1)*2, "dimension " .. d)
end
end
@@ -1822,8 +1821,7 @@ function nntest.SplitTable()
local input = torch.randn(3,4,5)
local module
for d = 1,input:dim()-1 do
- module = nn.SplitTable(d)
- module:setNumInputDims(2)
+ module = nn.SplitTable(d, 2)
mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d)
end
end