From 59deaaac17d20db73e693c532715456024d0dd1b Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Mon, 11 Apr 2016 22:28:55 +0200
Subject: Add :replace() for gModule

---
 gmodule.lua           | 16 ++++++++++++++++
 test/test_nngraph.lua | 36 ++++++++++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/gmodule.lua b/gmodule.lua
index e5bd4d4..7bdd13d 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -193,6 +193,22 @@ function gModule:__init(inputs,outputs)
    end
 end
 
+function gModule:replace(callback)
+   local out = callback(self)
+   local revmodules = {}
+   for i,m in ipairs(self.modules) do
+      revmodules[m] = i
+   end
+   for i,node in ipairs(self.forwardnodes) do
+      if node.data.module then
+         local m = node.data.module
+         node.data.module = m:replace(callback)
+         self.modules[revmodules[m]] = node.data.module
+      end
+   end
+   return out
+end
+
 function gModule:map(gm, func)
    for i,node in ipairs(self.forwardnodes) do
       local gmnode = gm.forwardnodes[i]
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 49b5e93..a993283 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -432,4 +432,40 @@ function test.test_gradInputType()
    assert(ok, "backward should succeed")
 end
 
+function test.test_replace()
+   local i = nn.Identity()()
+   local l1 = nn.Linear(5, 2)(i)
+   local sig = nn.Sigmoid()(l1)
+   local l2 = nn.Linear(2, 5)(sig)
+   local model = nn.gModule({i}, {l2})
+
+   local input = torch.randn(4, 5)
+   local gradOutput = torch.randn(4, 5)
+   tester:eq(model:forward(input):size(), input:size(), "inconsistent output size")
+   tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size")
+
+   model:replace(function(m)
+      if torch.type(m) == 'nn.Linear' then
+         if m.weight:size(1) == 5 then
+            return nn.Linear(2, 10)
+         elseif m.weight:size(1) == 2 then
+            return nn.Linear(10, 2)
+         end
+      elseif torch.type(m) == 'nn.Sigmoid' then
+         return nn.Tanh()
+      end
+      return m
+   end)
+
+   local input = torch.randn(4, 10)
+   local gradOutput = torch.randn(4, 10)
+   tester:eq(model:forward(input):size(), input:size(), "inconsistent output size")
+   tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size")
+
+   tester:ne(model.modules[2], l1, "gModule.modules wasn't updated")
+   tester:ne(model.modules[3], sig, "gModule.modules wasn't updated")
+   tester:eq(torch.type(model.modules[3]), 'nn.Tanh', "replace didn't update gModule.modules")
+   tester:ne(model.modules[4], l2, "gModule.modules wasn't updated")
+end
+
 tester:add(test):run()
-- 
cgit v1.2.3
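
A minimal usage sketch of the new API (not part of the patch): gModule:replace()
applies the callback to the gModule itself, then walks the graph's forward
nodes, recursing into each wrapped module via nn.Module:replace() and keeping
self.modules in sync with whatever the callback returns. The network below is
illustrative only, assuming torch, nn, and nngraph are installed.

   require 'nn'
   require 'nngraph'

   -- Build a small graph module: Linear -> Sigmoid.
   local input = nn.Identity()()
   local h = nn.Sigmoid()(nn.Linear(10, 10)(input))
   local net = nn.gModule({input}, {h})

   -- Swap every Sigmoid for a Tanh; other modules are returned unchanged.
   net:replace(function(m)
      if torch.type(m) == 'nn.Sigmoid' then
         return nn.Tanh()
      end
      return m
   end)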