Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorsoumith <soumith@gmail.com>2016-07-27 08:43:55 +0300
committersoumith <soumith@gmail.com>2016-07-27 08:43:55 +0300
commit4972614b3bb851a2234e832942d29fd7e89d3c77 (patch)
treea6eab4e234e6591ee33ada29bfbe7cd61980481e /test
parent08d0b5db1b56bbb52dfb81a233d288d3654c07d8 (diff)
fix for :type() to typecast the children as welltypefix
Diffstat (limited to 'test')
-rw-r--r--test/test_nngraph.lua18
1 files changed, 13 insertions, 5 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 8e5aadc..340ab24 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -214,6 +214,9 @@ function test.test_gradInputType()
local module = nn.gModule({in1}, {out1})
local input = torch.rand(20)
local output = module:forward(input)
+ local gradOutput = output:clone():normal()
+ local gradInput = module:backward(input, gradOutput)
+
module:backward(input, output)
tester:eq(torch.typename(output), "torch.DoubleTensor")
tester:eq(torch.typename(module.output), "torch.DoubleTensor")
@@ -222,17 +225,22 @@ function test.test_gradInputType()
tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.DoubleTensor")
module:float()
- local output = module:forward(input:float())
- tester:eq(torch.typename(output), "torch.FloatTensor")
tester:eq(torch.typename(module.output), "torch.FloatTensor")
tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.FloatTensor")
- end
+ tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.FloatTensor")
+ local output = module:forward(input:float())
+ tester:eq(torch.typename(output), "torch.FloatTensor")
+ local gradInput = module:backward(input:float(), gradOutput:float())
+ tester:eq(torch.typename(gradInput), "torch.FloatTensor")
+
+ end
function test.test_nestedGradInput()
local x = nn.Identity()()
@@ -384,8 +392,8 @@ function test.test_gradInputType()
end
function test.test_gradOutputZeroOptim()
- local unpack = function(...)
- if _G[unpack] then return _G[unpack](...)
+ local unpack = function(...)
+ if _G[unpack] then return _G[unpack](...)
else return table.unpack(...) end
end
-- Make module that produces an expanded zero gradInput tensor