Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua39
1 files changed, 39 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 3911118..3fff151 100644
--- a/test.lua
+++ b/test.lua
@@ -2793,6 +2793,45 @@ function nntest.Module_getParameters_8()
end
+function nntest.Module_getParameters_10()
+ -- tensors are non-contiguous but compact; they can be gathered
+ local L = nn.Linear(10,10)
+ L.weight = torch.Tensor(10,10):t():fill(1)
+ local tmp = torch.Tensor(10,10):fill(2)
+ L.bias = tmp:select(1,2)
+ local P = L:getParameters()
+ mytester:asserteq(L.weight:mean(), 1)
+ mytester:asserteq(L.bias:mean(), 2)
+ mytester:asserteq(L.weight:storage(), L.bias:storage())
+ mytester:asserteq(P:nElement(), 110)
+ mytester:asserteq(P:storage():size(), 110)
+ mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
+end
+
+function nntest.Module_getParameters_11()
+ -- tensors are non-compact; they can't be gathered
+ local L = nn.Linear(10,10)
+ local tmp = torch.Tensor(10,10):fill(2)
+ L.bias = tmp:select(2,2)
+ local ok, err = pcall(L.getParameters, L)
+ mytester:assert(not ok)
+end
+
+function nntest.Module_getParameters_12()
+ -- tensors are expanded (i.e. have dimension 0)
+ local L = nn.Linear(10,10)
+ L.weight = torch.Tensor(10, 1):fill(1)
+ torch.expand(L.weight, 10, 10)
+ L.bias = torch.Tensor(10):fill(2)
+ local P = L:getParameters()
+ mytester:asserteq(L.weight:mean(), 1)
+ mytester:asserteq(L.bias:mean(), 2)
+ mytester:asserteq(L.weight:storage(), L.bias:storage())
+ mytester:asserteq(P:nElement(), 20)
+ mytester:asserteq(P:storage():size(), 20)
+ mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
+end
+
function nntest.Module_listModules()
local batchSize = 4
local inputSize, outputSize = 7, 6