Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'ConcatTable.lua')
-rw-r--r--ConcatTable.lua72
1 files changed, 72 insertions, 0 deletions
diff --git a/ConcatTable.lua b/ConcatTable.lua
new file mode 100644
index 0000000..730d95e
--- /dev/null
+++ b/ConcatTable.lua
@@ -0,0 +1,72 @@
+local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Module')
+
+function ConcatTable:__init()
+ parent.__init(self)
+ self.modules = {}
+ self.output = {}
+end
+
+function ConcatTable:add(module)
+ table.insert(self.modules, module)
+ return self
+end
+
+function ConcatTable:get(index)
+ return self.modules[index]
+end
+
+function ConcatTable:size()
+ return #self.modules
+end
+
+function ConcatTable:updateOutput(input)
+ for i=1,#self.modules do
+ self.output[i] = self.modules[i]:updateOutput(input)
+ end
+ return self.output
+end
+
+function ConcatTable:updateGradInput(input, gradOutput)
+ for i,module in ipairs(self.modules) do
+ local currentGradInput = module:updateGradInput(input, gradOutput[i])
+ if i == 1 then
+ self.gradInput:resizeAs(currentGradInput):copy(currentGradInput)
+ else
+ self.gradInput:add(currentGradInput)
+ end
+ end
+ return self.gradInput
+end
+
+function ConcatTable:accGradParameters(input, gradOutput, scale)
+ scale = scale or 1
+ for i,module in ipairs(self.modules) do
+ module:accGradParameters(input, gradOutput[i], scale)
+ end
+end
+
+function ConcatTable:accUpdateGradParameters(input, gradOutput, lr)
+ for i,module in ipairs(self.modules) do
+ module:accUpdateGradParameters(input, gradOutput[i], lr)
+ end
+end
+
+function ConcatTable:zeroGradParameters()
+ for _,module in ipairs(self.modules) do
+ module:zeroGradParameters()
+ end
+end
+
+function ConcatTable:updateParameters(learningRate)
+ for _,module in ipairs(self.modules) do
+ module:updateParameters(learningRate)
+ end
+end
+
+function ConcatTable:share(mlp,...)
+ for i=1,#self.modules do
+ self.modules[i]:share(mlp.modules[i],...);
+ end
+end
+
+