Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-06-15 19:31:27 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-06-18 18:30:28 +0300
commitac74eef878e3ac9e8f7634595e9fee3f7c1a17cd (patch)
tree4626d5ae82bb5fda57b399f7697b3a043e0c6724 /HardTanh.lua
parentbe1a51a3cbf91b7e0a49650633375e7283e41d2b (diff)
inplace HardTanh, subclass ReLU6
Diffstat (limited to 'HardTanh.lua')
-rw-r--r--HardTanh.lua12
1 files changed, 9 insertions, 3 deletions
diff --git a/HardTanh.lua b/HardTanh.lua
index d3449a1..07cfc62 100644
--- a/HardTanh.lua
+++ b/HardTanh.lua
@@ -1,9 +1,13 @@
local HardTanh, parent = torch.class('nn.HardTanh', 'nn.Module')
-function HardTanh:__init(min_value, max_value)
+function HardTanh:__init(min_value, max_value, inplace)
parent.__init(self)
self.min_val = min_value or -1
self.max_val = max_value or 1
+ self.inplace = inplace or false
+ if (inplace and type(inplace) ~= 'boolean') then
+ error('in-place flag must be boolean')
+ end
assert(self.max_val>self.min_val, 'max_value must be larger than min_value')
end
@@ -14,7 +18,8 @@ function HardTanh:updateOutput(input)
input:cdata(),
self.output:cdata(),
self.min_val,
- self.max_val
+ self.max_val,
+ self.inplace or false
)
return self.output
end
@@ -25,7 +30,8 @@ function HardTanh:updateGradInput(input, gradOutput)
gradOutput:cdata(),
self.gradInput:cdata(),
self.min_val,
- self.max_val
+ self.max_val,
+ self.inplace or false
)
return self.gradInput
end