diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-01-27 21:29:05 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-27 21:29:05 +0300 |
commit | 02ebb69c1838c5c63a64b374faceff40671fdaeb (patch) | |
tree | b351977839e9df3298ac14bfd831b801ca1a144c /TemporalRowConvolution.lua | |
parent | e1efc6345f3dec8b631f91f640c11a4b7dd9e012 (diff) |
Rowconv repull (#1120)
* Added TemporalRowConvolutionMM layer, tests, and documentation
Diffstat (limited to 'TemporalRowConvolution.lua')
-rw-r--r-- | TemporalRowConvolution.lua | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/TemporalRowConvolution.lua b/TemporalRowConvolution.lua new file mode 100644 index 0000000..7c9d6a2 --- /dev/null +++ b/TemporalRowConvolution.lua @@ -0,0 +1,120 @@ +local THNN = require "nn.THNN" + +local TemporalRowConvolution, parent = torch.class("nn.TemporalRowConvolution", "nn.Module") + +function TemporalRowConvolution:__init(inputFrameSize, kW, dW, featFirst) + parent.__init(self) + + self.inputFrameSize = inputFrameSize + self.kW = kW + self.dW = dW or 1 + + self.weight = torch.Tensor(inputFrameSize, 1, kW) + self.bias = torch.Tensor(inputFrameSize) + self.gradWeight = torch.Tensor(inputFrameSize, 1, kW) + self.gradBias = torch.Tensor(inputFrameSize) + + -- Set to true for batch x inputFrameSize x nInputFrame + self.featFirst = featFirst and true or false + self:reset() +end + +function TemporalRowConvolution:noBias() + self.bias = nil + self.gradBias = nil + return self +end + +function TemporalRowConvolution:reset(stdv) + if stdv then + stdv = stdv * math.sqrt(3) + else + stdv = 1 / math.sqrt(self.kW * self.inputFrameSize) + end + self.weight:uniform(-stdv, stdv) + self.bias:uniform(-stdv, stdv) +end + +function TemporalRowConvolution:updateOutput(input) + assert(input.THNN, torch.type(input)..".THNN backend not imported") + self.finput = self.finput or input.new() + self.fgradInput = self.fgradInput or input.new() + + input.THNN.TemporalRowConvolution_updateOutput( + input:cdata(), + self.output:cdata(), + self.weight:cdata(), + THNN.optionalTensor(self.bias), + self.finput:cdata(), + self.fgradInput:cdata(), + self.kW, + self.dW, + 0, -- would be self.padW + self.featFirst + ) + + return self.output +end + +function TemporalRowConvolution:updateGradInput(input, gradOutput) + assert(input.THNN, torch.type(input)..".THNN backend not imported") + + if self.gradInput then + input.THNN.TemporalRowConvolution_updateGradInput( + input:cdata(), + gradOutput:cdata(), + self.gradInput:cdata(), + self.weight:cdata(), + self.finput:cdata(), + self.fgradInput:cdata(), + self.kW, + self.dW, + 0, -- would be self.padW + self.featFirst + ) + return self.gradInput + end +end + +function TemporalRowConvolution:accGradParameters(input, gradOutput, scale) + assert(input.THNN, torch.type(input)..".THNN backend not imported") + + input.THNN.TemporalRowConvolution_accGradParameters( + input:cdata(), + gradOutput:cdata(), + self.gradWeight:cdata(), + THNN.optionalTensor(self.gradBias), + self.finput:cdata(), + self.fgradInput:cdata(), + self.kW, + self.dW, + 0, -- would be self.padW + self.featFirst, + scale or 1) +end + +function TemporalRowConvolution:type(type, tensorCache) + if self.finput then self.finput:set() end + if self.fgradInput then self.fgradInput:set() end + return parent.type(self, type, tensorCache) +end + +function TemporalRowConvolution:__tostring__() + local s = string.format("%s(%d, %d", torch.type(self), self.inputFrameSize, self.kW) + if self.dW ~= 1 then + s = s .. string.format(", %d", self.dW) + end + if self.padW and self.padW ~= 0 then -- currently padding is not supported + s = s .. ", " .. self.padW + end + if self.bias then + return s .. ")" + else + return s .. ") without bias" + end +end + +function TemporalRowConvolution:clearState() + nn.utils.clear(self, "finput", "fgradInput", "_input", "_gradOutput") + return parent.clearState(self) +end |