github.com/torch/nn.git
author    Soumith Chintala <soumith@gmail.com>  2017-04-19 18:06:33 +0300
committer GitHub <noreply@github.com>           2017-04-19 18:06:33 +0300
commit    a067c6c4de406399f01ec3a2ee9bd1c3bb3d752f (patch)
tree      371726cbafefd360af8a7d8c87d13cb4270ed8bd /test.lua
parent    cfbae887bbc3f54566905c446cedab21da65293a (diff)
parent    c38f0c7323e2cb1d8263dde67058829f7266cbcd (diff)
Merge pull request #1171 from davidemaz/maptable-bugfix
MapTable: 'zeroGradParameters' and 'updateParameters' bugfix
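
The patch extends nntest.MapTable to verify that zeroGradParameters actually clears the accumulated gradients, both when the replicated modules share their parameter and gradient storage (the default) and when they do not. A minimal sketch of the behavior under test, assuming a standard torch/nn installation (the tensor sizes and input table here are illustrative, not taken from the patch):

require 'nn'

-- MapTable applies one module to every entry of an input table;
-- by default the per-entry clones share weights and gradients.
local map = nn.MapTable(nn.Linear(10, 5))
local input = {torch.randn(10), torch.randn(10)}
map:forward(input)            -- clones are created lazily on forward

local _, gradParams = map:getParameters()
gradParams:uniform()          -- dirty the flattened gradient buffer
map:zeroGradParameters()      -- the fixed code must reset it to zero
print(gradParams:sum())       -- expected: 0 (up to float precision)

Passing false as the second constructor argument (nn.MapTable(nn.Linear(10, 5), false)) disables the sharing, which is the second configuration the new assertions cover.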
Diffstat (limited to 'test.lua')
-rwxr-xr-x  test.lua | 18 ++++++++++++++++++
1 file changed, 18 insertions(+), 0 deletions(-)
diff --git a/test.lua b/test.lua
index 688bde2..e3a2782 100755
--- a/test.lua
+++ b/test.lua
@@ -6829,6 +6829,24 @@ function nntest.MapTable()
== torch.pointer(map:get(1).weight:storage()))
map:clearState()
mytester:assert(map:size() == 1)
+
+ -- check if gradients are correctly reset
+ -- share weights and gradients
+ map = nn.MapTable(nn.Linear(10,5))
+ map:forward(input)
+ _, gradParams = map:getParameters()
+ gradParams:uniform()
+ map:zeroGradParameters()
+ mytester:assertlt(gradParams:sum(),precision)
+
+ -- check if gradients are correctly reset
+ -- do not share weights and gradients
+ map = nn.MapTable(nn.Linear(10,5),false)
+ map:forward(input)
+ _, gradParams = map:getParameters()
+ gradParams:uniform()
+ map:zeroGradParameters()
+ mytester:assertlt(gradParams:sum(),precision)
end
function nntest.FlattenTable()
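
After applying the patch, the new assertions can be exercised on their own through the suite's entry point (a sketch assuming the standard torch/nn test harness, where nn.test accepts a list of test names):

require 'nn'
nn.test{'MapTable'}   -- run only the MapTable unit test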