diff options
author | Michael Rotman <rotmanmi@pc-wolf111.(none)> | 2015-03-05 12:37:42 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-03-07 20:32:58 +0300 |
commit | ffaf24a73562b18cf992406b5f2d8a8174e8a4dd (patch) | |
tree | 1527f3bb08df9974dbd64d34388f710a16040495 /Power.lua | |
parent | 8be4ca8848c1fbfbeffe381cfcb9df6bdc80c3d0 (diff) |
Now you can use power on inputs with 0s. (there is no longer division by 0)
Diffstat (limited to 'Power.lua')
-rw-r--r-- | Power.lua | 45 |
1 files changed, 24 insertions, 21 deletions
@@ -1,21 +1,24 @@ -local Power, parent = torch.class('nn.Power','nn.Module') - -function Power:__init(p) - parent.__init(self) - self.pow = p - if not p then - error('nn.Power(power)') - end -end - -function Power:updateOutput(input) - self.output:resizeAs(input):copy(input) - self.output:pow(self.pow) - return self.output -end - -function Power:updateGradInput(input, gradOutput) - self.gradInput:resizeAs(input):copy(gradOutput) - self.gradInput:cmul(self.output):cdiv(input):mul(self.pow) - return self.gradInput -end +local Power, parent = torch.class('nn.Power','nn.Module')
+
+function Power:__init(p)
+ parent.__init(self)
+ self.pow = p
+ if not p then
+ error('nn.Power(power)')
+ end
+end
+
+function Power:updateOutput(input)
+ self.output:resizeAs(input):copy(input)
+ self.output:pow(self.pow)
+ return self.output
+end
+
+function Power:updateGradInput(input, gradOutput)
+ self.buffer = self.buffer or input.new()
+ self.buffer:resizeAs(input):copy(input)
+ self.buffer:pow(self.pow - 1)
+ self.gradInput:resizeAs(input):copy(gradOutput)
+ self.gradInput:cmul(self.buffer):mul(self.pow)
+ return self.gradInput
+end
\ No newline at end of file |