Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Lerer <alerer@fb.com>2015-09-04 01:16:34 +0300
committerAdam Lerer <alerer@fb.com>2015-09-04 01:16:34 +0300
commit7105419c8382c3a9c5c257661c8ab49872fdcddc (patch)
tree3ba71368a554420c2d71d035e04df64ac33506a2 /test.lua
parent98ecaf98858ccb373c605437422949c1d0998d7c (diff)
getParameters: improve memory efficiency, fix bug with non-compact tensors
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua39
1 files changed, 39 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 3911118..3fff151 100644
--- a/test.lua
+++ b/test.lua
@@ -2793,6 +2793,45 @@ function nntest.Module_getParameters_8()
end
+function nntest.Module_getParameters_10()
+ -- tensors are non-contiguous but compact; they can be gathered
+ local L = nn.Linear(10,10)
+ L.weight = torch.Tensor(10,10):t():fill(1)
+ local tmp = torch.Tensor(10,10):fill(2)
+ L.bias = tmp:select(1,2)
+ local P = L:getParameters()
+ mytester:asserteq(L.weight:mean(), 1)
+ mytester:asserteq(L.bias:mean(), 2)
+ mytester:asserteq(L.weight:storage(), L.bias:storage())
+ mytester:asserteq(P:nElement(), 110)
+ mytester:asserteq(P:storage():size(), 110)
+ mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
+end
+
+function nntest.Module_getParameters_11()
+ -- tensors are non-compact; they can't be gathered
+ local L = nn.Linear(10,10)
+ local tmp = torch.Tensor(10,10):fill(2)
+ L.bias = tmp:select(2,2)
+ local ok, err = pcall(L.getParameters, L)
+ mytester:assert(not ok)
+end
+
+function nntest.Module_getParameters_12()
+ -- tensors are expanded (i.e. have dimension 0)
+ local L = nn.Linear(10,10)
+ L.weight = torch.Tensor(10, 1):fill(1)
+ torch.expand(L.weight, 10, 10)
+ L.bias = torch.Tensor(10):fill(2)
+ local P = L:getParameters()
+ mytester:asserteq(L.weight:mean(), 1)
+ mytester:asserteq(L.bias:mean(), 2)
+ mytester:asserteq(L.weight:storage(), L.bias:storage())
+ mytester:asserteq(P:nElement(), 20)
+ mytester:asserteq(P:storage():size(), 20)
+ mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
+end
+
function nntest.Module_listModules()
local batchSize = 4
local inputSize, outputSize = 7, 6