diff options
-rw-r--r-- | ElementTable.lua | 35 | ||||
-rw-r--r-- | init.lua | 1 |
2 files changed, 36 insertions, 0 deletions
diff --git a/ElementTable.lua b/ElementTable.lua new file mode 100644 index 0000000..b1b28d0 --- /dev/null +++ b/ElementTable.lua @@ -0,0 +1,35 @@ +local ElementTable, parent = torch.class('nn.ElementTable', 'nn.Module') + +function ElementTable:__init(index) + parent.__init(self) + self.index = index + self.gradInput = {} +end + +function ElementTable:updateOutput(input) + self.output:set(input[self.index]) + return self.output +end + +function ElementTable:updateGradInput(input, gradOutput) + if #gradInput == 0 then + local function zeroTableCopy(t1, t2) + for k, v in pairs(t2) do + if (torch.type(v) == "table") then + t1[k] = zeroTableCopy(t1[k] or {}, t2[k]) + else + t1[k] = v:clone():zero() + end + end + return t1 + end + zeroTableCopy(self.gradInput, input) + end + self.gradInput[self.index] = gradOutput + return self.gradInput +end + +function ElementTable:type(type) + self.gradInput = {} + self.output = self.output:type(type) +end @@ -85,6 +85,7 @@ include('ParallelTable.lua') include('ConcatTable.lua') include('SplitTable.lua') include('JoinTable.lua') +include('ElementTable.lua') include('CriterionTable.lua') include('Identity.lua') |