Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-09 21:34:51 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-09 21:34:51 +0400
commit157004f72ff7fa9f9c7216ca6b0b766476b4696c (patch)
tree13c70a8625bc1fcc78e004f387795c31a9dd1495
parent0413ddd6dc0a35b5281fcaaebc73144b15f285fa (diff)
ElementTable
-rw-r--r--ElementTable.lua35
-rw-r--r--init.lua1
2 files changed, 36 insertions, 0 deletions
diff --git a/ElementTable.lua b/ElementTable.lua
new file mode 100644
index 0000000..b1b28d0
--- /dev/null
+++ b/ElementTable.lua
@@ -0,0 +1,35 @@
+local ElementTable, parent = torch.class('nn.ElementTable', 'nn.Module')
+
+function ElementTable:__init(index)
+ parent.__init(self)
+ self.index = index
+ self.gradInput = {}
+end
+
+function ElementTable:updateOutput(input)
+ self.output:set(input[self.index])
+ return self.output
+end
+
+function ElementTable:updateGradInput(input, gradOutput)
+ if #gradInput == 0 then
+ local function zeroTableCopy(t1, t2)
+ for k, v in pairs(t2) do
+ if (torch.type(v) == "table") then
+ t1[k] = zeroTableCopy(t1[k] or {}, t2[k])
+ else
+ t1[k] = v:clone():zero()
+ end
+ end
+ return t1
+ end
+ zeroTableCopy(self.gradInput, input)
+ end
+ self.gradInput[self.index] = gradOutput
+ return self.gradInput
+end
+
+function ElementTable:type(type)
+ self.gradInput = {}
+ self.output = self.output:type(type)
+end
diff --git a/init.lua b/init.lua
index 757c9ec..6424cf5 100644
--- a/init.lua
+++ b/init.lua
@@ -85,6 +85,7 @@ include('ParallelTable.lua')
include('ConcatTable.lua')
include('SplitTable.lua')
include('JoinTable.lua')
+include('ElementTable.lua')
include('CriterionTable.lua')
include('Identity.lua')