Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-02-29 23:44:58 +0300
committersoumith <soumith@fb.com>2016-02-29 23:44:58 +0300
commitd10db73c947546c9a9a9cf264162b73ace29f161 (patch)
tree8a602bf4b988be38b9de47964e19b38a37e0646d /Squeeze.lua
parent2e498799e65bc11902f7faeaaf1f7c9ae880b733 (diff)
adding nn.Squeeze
Diffstat (limited to 'Squeeze.lua')
-rw-r--r--Squeeze.lua40
1 files changed, 40 insertions, 0 deletions
diff --git a/Squeeze.lua b/Squeeze.lua
new file mode 100644
index 0000000..7d204a1
--- /dev/null
+++ b/Squeeze.lua
@@ -0,0 +1,40 @@
+local Squeeze, parent = torch.class('nn.Squeeze', 'nn.Module')
+
+function Squeeze:__init(dim, numInputDims)
+ parent.__init(self)
+ self.dim = dim
+ self:setNumInputDims(numInputDims)
+end
+
+function Squeeze:setNumInputDims(numInputDims)
+ self.numInputDims = numInputDims
+ return self
+end
+
+function Squeeze:updateOutput(input)
+ assert(input and torch.isTensor(input), 'Squeeze only works on tensors')
+ local dim = self.dim
+ local addone = false
+ if self.numInputDims and input:dim()==(self.numInputDims+1) then
+ if dim then
+ dim = dim + 1
+ elseif input:size(1) == 1 then
+ addone = true -- in case of minibatch of size 1.
+ end
+ end
+ self.output:set(dim and input:squeeze(dim) or input:squeeze())
+ if addone then
+ local s = self.output:size():totable{}
+ table.insert(s, 1, 1)
+ self.output:set(self.output:view(torch.LongStorage(s)))
+ end
+ return self.output
+end
+
+function Squeeze:updateGradInput(input, gradOutput)
+ assert(input and torch.isTensor(input), 'Squeeze only works on tensors')
+ assert(gradOutput and torch.isTensor(gradOutput), 'Squeeze only works on tensors')
+ assert(input:nElement() == gradOutput:nElement())
+ self.gradInput:set(gradOutput:view(input:size()))
+ return self.gradInput
+end