diff options
author | soumith <soumith@fb.com> | 2016-02-29 23:44:58 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-02-29 23:44:58 +0300 |
commit | d10db73c947546c9a9a9cf264162b73ace29f161 (patch) | |
tree | 8a602bf4b988be38b9de47964e19b38a37e0646d /Squeeze.lua | |
parent | 2e498799e65bc11902f7faeaaf1f7c9ae880b733 (diff) |
adding nn.Squeeze
Diffstat (limited to 'Squeeze.lua')
-rw-r--r-- | Squeeze.lua | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/Squeeze.lua b/Squeeze.lua new file mode 100644 index 0000000..7d204a1 --- /dev/null +++ b/Squeeze.lua @@ -0,0 +1,40 @@ +local Squeeze, parent = torch.class('nn.Squeeze', 'nn.Module') + +function Squeeze:__init(dim, numInputDims) + parent.__init(self) + self.dim = dim + self:setNumInputDims(numInputDims) +end + +function Squeeze:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self +end + +function Squeeze:updateOutput(input) + assert(input and torch.isTensor(input), 'Squeeze only works on tensors') + local dim = self.dim + local addone = false + if self.numInputDims and input:dim()==(self.numInputDims+1) then + if dim then + dim = dim + 1 + elseif input:size(1) == 1 then + addone = true -- in case of minibatch of size 1. + end + end + self.output:set(dim and input:squeeze(dim) or input:squeeze()) + if addone then + local s = self.output:size():totable{} + table.insert(s, 1, 1) + self.output:set(self.output:view(torch.LongStorage(s))) + end + return self.output +end + +function Squeeze:updateGradInput(input, gradOutput) + assert(input and torch.isTensor(input), 'Squeeze only works on tensors') + assert(gradOutput and torch.isTensor(gradOutput), 'Squeeze only works on tensors') + assert(input:nElement() == gradOutput:nElement()) + self.gradInput:set(gradOutput:view(input:size())) + return self.gradInput +end |