diff options
author | davidemaz <davidemaz@gmail.com> | 2017-03-10 20:37:32 +0300 |
---|---|---|
committer | davidemaz <davidemaz@gmail.com> | 2017-03-10 20:37:32 +0300 |
commit | 179133fb499483555074eaff2ea875e975508df1 (patch) | |
tree | f5ad843fa396c921d77d23178411978544df2eb6 /MapTable.lua | |
parent | 1d38cbaa78a0ba7f48d168e62192292e705040ae (diff) |
MapTable layer: 'share' parameter is now boolean
Diffstat (limited to 'MapTable.lua')
-rw-r--r-- | MapTable.lua | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/MapTable.lua b/MapTable.lua index 90b439c..351eb70 100644 --- a/MapTable.lua +++ b/MapTable.lua @@ -2,7 +2,8 @@ local MapTable, parent = torch.class('nn.MapTable', 'nn.Container') function MapTable:__init(module, shared) parent.__init(self) - self.shared = shared or {'weight', 'bias', 'gradWeight', 'gradBias'} + self.shared = shared or true + self.sharedparams = {'weight', 'bias', 'gradWeight', 'gradBias'} self.output = {} self.gradInput = {} self:add(module) @@ -12,7 +13,11 @@ function MapTable:_extend(n) self.modules[1] = self.module for i = 2, n do if not self.modules[i] then - self.modules[i] = self.module:clone(table.unpack(self.shared)) + if shared then + self.modules[i] = self.module:clone(table.unpack(self.sharedparams)) + else + self.modules[i] = self.module:clone(table.unpack(self.sharedparams)) + end end end end |