Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-04-11 23:28:55 +0300
committerSoumith Chintala <soumith@fb.com>2016-04-28 08:20:35 +0300
commit59deaaac17d20db73e693c532715456024d0dd1b (patch)
tree849829dec9a619b89024133502fb1467a9057258
parent7be68b72cccaf69433af5efc7061b288fd22a4e8 (diff)
Add :replace() for gModule
-rw-r--r--gmodule.lua16
-rw-r--r--test/test_nngraph.lua36
2 files changed, 52 insertions, 0 deletions
diff --git a/gmodule.lua b/gmodule.lua
index e5bd4d4..7bdd13d 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -193,6 +193,22 @@ function gModule:__init(inputs,outputs)
end
end
+function gModule:replace(callback)
+ local out = callback(self)
+ local revmodules = {}
+ for i,m in ipairs(self.modules) do
+ revmodules[m] = i
+ end
+ for i,node in ipairs(self.forwardnodes) do
+ if node.data.module then
+ local m = node.data.module
+ node.data.module = m:replace(callback)
+ self.modules[revmodules[m]] = node.data.module
+ end
+ end
+ return out
+end
+
function gModule:map(gm, func)
for i,node in ipairs(self.forwardnodes) do
local gmnode = gm.forwardnodes[i]
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 49b5e93..a993283 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -432,4 +432,40 @@ function test.test_gradInputType()
assert(ok, "backward should succeed")
end
+ function test.test_replace()
+ local i = nn.Identity()()
+ local l1 = nn.Linear(5, 2)(i)
+ local sig = nn.Sigmoid()(l1)
+ local l2 = nn.Linear(2, 5)(sig)
+ local model = nn.gModule({i}, {l2})
+
+ local input = torch.randn(4, 5)
+ local gradOutput = torch.randn(4, 5)
+ tester:eq(model:forward(input):size(), input:size(), "inconsistent output size")
+ tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size")
+
+ model:replace(function(m)
+ if torch.type(m) == 'nn.Linear' then
+ if m.weight:size(1) == 5 then
+ return nn.Linear(2, 10)
+ elseif m.weight:size(1) == 2 then
+ return nn.Linear(10, 2)
+ end
+ elseif torch.type(m) == 'nn.Sigmoid' then
+ return nn.Tanh()
+ end
+ return m
+ end)
+
+ local input = torch.randn(4, 10)
+ local gradOutput = torch.randn(4, 10)
+ tester:eq(model:forward(input):size(), input:size(), "inconsistent output size")
+ tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size")
+
+ tester:ne(model.modules[2], l1, "gModule.modules wasn't updated")
+ tester:ne(model.modules[3], sig, "gModule.modules wasn't updated")
+ tester:eq(torch.type(model.modules[3]), 'nn.Tanh', "replace didn't update gModule.modules")
+ tester:ne(model.modules[4], l2, "gModule.modules wasn't updated")
+ end
+
tester:add(test):run()