Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Rotman <rotmanmi@pc-wolf111.(none)>2015-03-05 12:37:42 +0300
committerSoumith Chintala <soumith@gmail.com>2015-03-07 20:32:58 +0300
commitffaf24a73562b18cf992406b5f2d8a8174e8a4dd (patch)
tree1527f3bb08df9974dbd64d34388f710a16040495 /Power.lua
parent8be4ca8848c1fbfbeffe381cfcb9df6bdc80c3d0 (diff)
Now you can use power on inputs with 0s. (there is no longer division by 0)
Diffstat (limited to 'Power.lua')
-rw-r--r--Power.lua45
1 files changed, 24 insertions, 21 deletions
diff --git a/Power.lua b/Power.lua
index 8052b3f..76421c9 100644
--- a/Power.lua
+++ b/Power.lua
@@ -1,21 +1,24 @@
-local Power, parent = torch.class('nn.Power','nn.Module')
-
-function Power:__init(p)
- parent.__init(self)
- self.pow = p
- if not p then
- error('nn.Power(power)')
- end
-end
-
-function Power:updateOutput(input)
- self.output:resizeAs(input):copy(input)
- self.output:pow(self.pow)
- return self.output
-end
-
-function Power:updateGradInput(input, gradOutput)
- self.gradInput:resizeAs(input):copy(gradOutput)
- self.gradInput:cmul(self.output):cdiv(input):mul(self.pow)
- return self.gradInput
-end
+local Power, parent = torch.class('nn.Power','nn.Module')
+
+function Power:__init(p)
+ parent.__init(self)
+ self.pow = p
+ if not p then
+ error('nn.Power(power)')
+ end
+end
+
+function Power:updateOutput(input)
+ self.output:resizeAs(input):copy(input)
+ self.output:pow(self.pow)
+ return self.output
+end
+
+function Power:updateGradInput(input, gradOutput)
+ self.buffer = self.buffer or input.new()
+ self.buffer:resizeAs(input):copy(input)
+ self.buffer:pow(self.pow - 1)
+ self.gradInput:resizeAs(input):copy(gradOutput)
+ self.gradInput:cmul(self.buffer):mul(self.pow)
+ return self.gradInput
+end \ No newline at end of file