Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordavidemaz <davidemaz@gmail.com>2017-03-10 20:37:32 +0300
committerdavidemaz <davidemaz@gmail.com>2017-03-10 20:37:32 +0300
commit179133fb499483555074eaff2ea875e975508df1 (patch)
treef5ad843fa396c921d77d23178411978544df2eb6 /MapTable.lua
parent1d38cbaa78a0ba7f48d168e62192292e705040ae (diff)
MapTable layer: 'share' parameter is now boolean
Diffstat (limited to 'MapTable.lua')
-rw-r--r--MapTable.lua9
1 files changed, 7 insertions, 2 deletions
diff --git a/MapTable.lua b/MapTable.lua
index 90b439c..351eb70 100644
--- a/MapTable.lua
+++ b/MapTable.lua
@@ -2,7 +2,8 @@ local MapTable, parent = torch.class('nn.MapTable', 'nn.Container')
function MapTable:__init(module, shared)
parent.__init(self)
- self.shared = shared or {'weight', 'bias', 'gradWeight', 'gradBias'}
+ self.shared = shared or true
+ self.sharedparams = {'weight', 'bias', 'gradWeight', 'gradBias'}
self.output = {}
self.gradInput = {}
self:add(module)
@@ -12,7 +13,11 @@ function MapTable:_extend(n)
self.modules[1] = self.module
for i = 2, n do
if not self.modules[i] then
- self.modules[i] = self.module:clone(table.unpack(self.shared))
+ if shared then
+ self.modules[i] = self.module:clone(table.unpack(self.sharedparams))
+ else
+ self.modules[i] = self.module:clone(table.unpack(self.sharedparams))
+ end
end
end
end