Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-03-20 03:04:59 +0300
committersoumith <soumith@fb.com>2015-04-07 23:51:42 +0300
commitab09f77b32119e0c2de49572c8c856c81363c2a0 (patch)
treef444078b84beca2db77c90fa16c83bf45f1a12c9 /Threshold.lua
parenta7b5dcd4ec6e0b38f2c88c84e757eb42081e2aca (diff)
adds in-place ReLU and fixes a potential divide-by-zero in nn.Sqrt
Diffstat (limited to 'Threshold.lua')
-rw-r--r--Threshold.lua20
1 files changed, 19 insertions, 1 deletions
diff --git a/Threshold.lua b/Threshold.lua
index 6083957..1c35e1e 100644
--- a/Threshold.lua
+++ b/Threshold.lua
@@ -1,20 +1,38 @@
local Threshold, parent = torch.class('nn.Threshold','nn.Module')
-function Threshold:__init(th,v)
+function Threshold:__init(th,v,ip)
parent.__init(self)
self.threshold = th or 1e-6
self.val = v or 0
if (th and type(th) ~= 'number') or (v and type(v) ~= 'number') then
error('nn.Threshold(threshold, value)')
end
+ -- default for inplace is false
+ self.inplace = ip or false
+ if (ip and type(ip) ~= 'boolean') then
+ error('in-place flag must be boolean')
+ end
+ self:validateParameters()
end
function Threshold:updateOutput(input)
+ self:validateParameters()
input.nn.Threshold_updateOutput(self, input)
return self.output
end
function Threshold:updateGradInput(input, gradOutput)
+ self:validateParameters()
input.nn.Threshold_updateGradInput(self, input, gradOutput)
return self.gradInput
end
+
+function Threshold:validateParameters()
+ self.inplace = self.inplace or false -- backwards compatibility pre inplace
+ if self.inplace then
+ if self.val > self.threshold then
+ error('in-place processing requires value (' .. self.val ..
+ ') not exceed threshold (' .. self.threshold .. ')')
+ end
+ end
+end