Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'Mean.lua')
-rw-r--r--Mean.lua26
1 files changed, 26 insertions, 0 deletions
diff --git a/Mean.lua b/Mean.lua
new file mode 100644
index 0000000..55e7609
--- /dev/null
+++ b/Mean.lua
@@ -0,0 +1,26 @@
+local Mean, parent = torch.class('nn.Mean', 'nn.Module')
+
+function Mean:__init(dimension)
+ parent.__init(self)
+ dimension = dimension or 1
+ self.dimension = dimension
+end
+
+function Mean:updateOutput(input)
+ input.torch.mean(self.output, input, self.dimension)
+ self.output = self.output:select(self.dimension, 1)
+ return self.output
+end
+
+function Mean:updateGradInput(input, gradOutput)
+ local size = gradOutput:size():totable()
+ local stride = gradOutput:stride():totable()
+ table.insert(size, self.dimension, input:size(self.dimension))
+ table.insert(stride, self.dimension, 0)
+
+ self.gradInput:resizeAs(gradOutput):copy(gradOutput)
+ self.gradInput:mul(1/input:size(self.dimension))
+ self.gradInput:resize(torch.LongStorage(size), torch.LongStorage(stride))
+
+ return self.gradInput
+end