diff options
author | Adam Lerer <alerer@fb.com> | 2015-09-04 01:16:34 +0300 |
---|---|---|
committer | Adam Lerer <alerer@fb.com> | 2015-09-04 01:16:34 +0300 |
commit | 7105419c8382c3a9c5c257661c8ab49872fdcddc (patch) | |
tree | 3ba71368a554420c2d71d035e04df64ac33506a2 /test.lua | |
parent | 98ecaf98858ccb373c605437422949c1d0998d7c (diff) |
getParameters: improve memory efficiency, fix bug with non-compact tensors
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 39 |
1 file changed, 39 insertions, 0 deletions
@@ -2793,6 +2793,45 @@ function nntest.Module_getParameters_8() end +function nntest.Module_getParameters_10() + -- tensors are non-contiguous but compact; they can be gathered + local L = nn.Linear(10,10) + L.weight = torch.Tensor(10,10):t():fill(1) + local tmp = torch.Tensor(10,10):fill(2) + L.bias = tmp:select(1,2) + local P = L:getParameters() + mytester:asserteq(L.weight:mean(), 1) + mytester:asserteq(L.bias:mean(), 2) + mytester:asserteq(L.weight:storage(), L.bias:storage()) + mytester:asserteq(P:nElement(), 110) + mytester:asserteq(P:storage():size(), 110) + mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size()) +end + +function nntest.Module_getParameters_11() + -- tensors are non-compact; they can't be gathered + local L = nn.Linear(10,10) + local tmp = torch.Tensor(10,10):fill(2) + L.bias = tmp:select(2,2) + local ok, err = pcall(L.getParameters, L) + mytester:assert(not ok) +end + +function nntest.Module_getParameters_12() + -- tensors are expanded (i.e. have dimension 0) + local L = nn.Linear(10,10) + L.weight = torch.Tensor(10, 1):fill(1) + torch.expand(L.weight, 10, 10) + L.bias = torch.Tensor(10):fill(2) + local P = L:getParameters() + mytester:asserteq(L.weight:mean(), 1) + mytester:asserteq(L.bias:mean(), 2) + mytester:asserteq(L.weight:storage(), L.bias:storage()) + mytester:asserteq(P:nElement(), 20) + mytester:asserteq(P:storage():size(), 20) + mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size()) +end + function nntest.Module_listModules() local batchSize = 4 local inputSize, outputSize = 7, 6 |