diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-06-15 19:31:27 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-06-18 18:30:28 +0300 |
commit | ac74eef878e3ac9e8f7634595e9fee3f7c1a17cd (patch) | |
tree | 4626d5ae82bb5fda57b399f7697b3a043e0c6724 /HardTanh.lua | |
parent | be1a51a3cbf91b7e0a49650633375e7283e41d2b (diff) |
inplace HardTanh, subclass ReLU6
Diffstat (limited to 'HardTanh.lua')
-rw-r--r-- | HardTanh.lua | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/HardTanh.lua b/HardTanh.lua index d3449a1..07cfc62 100644 --- a/HardTanh.lua +++ b/HardTanh.lua @@ -1,9 +1,13 @@ local HardTanh, parent = torch.class('nn.HardTanh', 'nn.Module') -function HardTanh:__init(min_value, max_value) +function HardTanh:__init(min_value, max_value, inplace) parent.__init(self) self.min_val = min_value or -1 self.max_val = max_value or 1 + self.inplace = inplace or false + if (inplace and type(inplace) ~= 'boolean') then + error('in-place flag must be boolean') + end assert(self.max_val>self.min_val, 'max_value must be larger than min_value') end @@ -14,7 +18,8 @@ function HardTanh:updateOutput(input) input:cdata(), self.output:cdata(), self.min_val, - self.max_val + self.max_val, + self.inplace or false ) return self.output end @@ -25,7 +30,8 @@ function HardTanh:updateGradInput(input, gradOutput) gradOutput:cdata(), self.gradInput:cdata(), self.min_val, - self.max_val + self.max_val, + self.inplace or false ) return self.gradInput end |