diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-04-11 23:28:55 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@fb.com> | 2016-04-28 08:20:35 +0300 |
commit | 59deaaac17d20db73e693c532715456024d0dd1b (patch) | |
tree | 849829dec9a619b89024133502fb1467a9057258 | |
parent | 7be68b72cccaf69433af5efc7061b288fd22a4e8 (diff) |
Add :replace() for gModule
-rw-r--r-- | gmodule.lua | 16
-rw-r--r-- | test/test_nngraph.lua | 36
2 files changed, 52 insertions, 0 deletions
diff --git a/gmodule.lua b/gmodule.lua index e5bd4d4..7bdd13d 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -193,6 +193,22 @@ function gModule:__init(inputs,outputs) end end +function gModule:replace(callback) + local out = callback(self) + local revmodules = {} + for i,m in ipairs(self.modules) do + revmodules[m] = i + end + for i,node in ipairs(self.forwardnodes) do + if node.data.module then + local m = node.data.module + node.data.module = m:replace(callback) + self.modules[revmodules[m]] = node.data.module + end + end + return out +end + function gModule:map(gm, func) for i,node in ipairs(self.forwardnodes) do local gmnode = gm.forwardnodes[i] diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index 49b5e93..a993283 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -432,4 +432,40 @@ function test.test_gradInputType() assert(ok, "backward should succeed") end + function test.test_replace() + local i = nn.Identity()() + local l1 = nn.Linear(5, 2)(i) + local sig = nn.Sigmoid()(l1) + local l2 = nn.Linear(2, 5)(sig) + local model = nn.gModule({i}, {l2}) + + local input = torch.randn(4, 5) + local gradOutput = torch.randn(4, 5) + tester:eq(model:forward(input):size(), input:size(), "inconsistent output size") + tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size") + + model:replace(function(m) + if torch.type(m) == 'nn.Linear' then + if m.weight:size(1) == 5 then + return nn.Linear(2, 10) + elseif m.weight:size(1) == 2 then + return nn.Linear(10, 2) + end + elseif torch.type(m) == 'nn.Sigmoid' then + return nn.Tanh() + end + return m + end) + + local input = torch.randn(4, 10) + local gradOutput = torch.randn(4, 10) + tester:eq(model:forward(input):size(), input:size(), "inconsistent output size") + tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size") + + tester:ne(model.modules[2], l1, "gModule.modules wasn't updated") + tester:ne(model.modules[3], sig, "gModule.modules wasn't updated") + tester:eq(torch.type(model.modules[3]), 'nn.Tanh', "replace didn't update gModule.modules") + tester:ne(model.modules[4], l2, "gModule.modules wasn't updated") + end + tester:add(test):run()