diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-04-19 18:06:33 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-04-19 18:06:33 +0300 |
commit | a067c6c4de406399f01ec3a2ee9bd1c3bb3d752f (patch) | |
tree | 371726cbafefd360af8a7d8c87d13cb4270ed8bd /test.lua | |
parent | cfbae887bbc3f54566905c446cedab21da65293a (diff) | |
parent | c38f0c7323e2cb1d8263dde67058829f7266cbcd (diff) |
Merge pull request #1171 from davidemaz/maptable-bugfix
MapTable: 'zeroGradParameters' and 'updateParameters' bugfix
Diffstat (limited to 'test.lua')
-rwxr-xr-x | test.lua | 18 |
1 file changed, 18 insertions, 0 deletions
@@ -6829,6 +6829,24 @@ function nntest.MapTable() == torch.pointer(map:get(1).weight:storage())) map:clearState() mytester:assert(map:size() == 1) + + -- check if gradients are correctly reset + -- share weights and gradients + map = nn.MapTable(nn.Linear(10,5)) + map:forward(input) + _, gradParams = map:getParameters() + gradParams:uniform() + map:zeroGradParameters() + mytester:assertlt(gradParams:sum(),precision) + + -- check if gradients are correctly reset + -- do not share weights and gradients + map = nn.MapTable(nn.Linear(10,5),false) + map:forward(input) + _, gradParams = map:getParameters() + gradParams:uniform() + map:zeroGradParameters() + mytester:assertlt(gradParams:sum(),precision) end function nntest.FlattenTable() |