Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJames Kirkpatrick <kirkpatrick@google.com>2014-10-01 12:16:27 +0400
committerSoumith Chintala <soumith@gmail.com>2014-10-27 17:05:16 +0300
commit8581f84f45c2baef72dd32f6336410480b617ff7 (patch)
treee590661f77f4780dee406b44cc6e11e3fdce09cf /test
parent8d7d03ebe72ae507ae716f292778123dc34e04b1 (diff)
Corrected getParamaters for partial views
Module:getParameters was incorrectly overwriting parameters that were partial views on larger storages.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua31
1 files changed, 31 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index b035928..aa1d062 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1745,6 +1745,37 @@ function nntest.Module_getParameters_7()
mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector')
end
+function nntest.Module_getParameters_8()
+ local function makeMLP(nin, ns)
+ local net = nn.Sequential()
+
+ for k,v in ipairs(ns) do
+ net:add(nn.Linear(nin, v))
+ nin = v
+ end
+ _,_ = net:getParameters()
+ return net
+ end
+
+ local mlp1 = makeMLP(10, {10,10})
+ local mlp2 = makeMLP(10, {10,10})
+
+ local net = nn.Sequential():add(mlp1:get(1))
+ :add(mlp2:get(1))
+
+ -- clone the second MLP to ensure that the weights before calling getParameters are preserved
+ mlp2 = mlp2:clone()
+
+ local p, gp = net:getParameters()
+
+ mytester:asserteq((p[{ {1,100} }] - net.modules[1].weight):norm(), 0, 'error when using partial realloc')
+ mytester:asserteq((p[{ {111,210} }] - net.modules[2].weight):norm(), 0, 'error when using partial realloc')
+ -- check that the weights have the same values as before get Parameters was called
+ mytester:asserteq((net.modules[1].weight - mlp1.modules[1].weight):norm(), 0, ' error when using partial realloc')
+ mytester:asserteq((net.modules[2].weight - mlp2.modules[1].weight):norm(), 0, ' error when using partial realloc')
+
+end
+
function nntest.PairwiseDistance()
-- Note: testJacobian doesn't support table inputs, and rather than re-write
-- it so that it does, I'll just use a split table module on the input.