Repository: github.com/soumith/cudnn.torch.git
author    Sergey Zagoruyko <zagoruyko2@gmail.com>    2016-02-12 16:01:09 +0300
committer Sergey Zagoruyko <zagoruyko2@gmail.com>    2016-02-12 16:02:44 +0300
commit    a2e27cc763d304a212552f4ec81ddc9e2c6fbcf5 (patch)
tree      193c7841b44b900fa9993e859261f51a648eb87f /SpatialSoftMax.lua
parent    3f7c066fae27e202925a9f7c74eaec70130e6e2f (diff)

clearState

Diffstat (limited to 'SpatialSoftMax.lua'):
 -rw-r--r--  SpatialSoftMax.lua | 20 +++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)
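The patch moves the defaults for self.mode, self.algorithm and self.iSize out of __init and into createIODescriptors, relying on Lua's `x = x or default` idiom so that instances deserialized from checkpoints written before these fields existed are filled in lazily on first use. A minimal sketch of that idiom follows; the Foo class and threshold field are hypothetical illustrations, not part of the commit:

-- Lazy-default idiom: repair a field on first use rather than in the
-- constructor, so objects loaded from old checkpoints still work.
local Foo = {}
Foo.__index = Foo

function Foo:run()
   -- An instance deserialized before this field existed lacks it;
   -- fall back to the default the first time it is needed.
   self.threshold = self.threshold or 0.5
   return self.threshold
end

local old = setmetatable({}, Foo)  -- stands in for a deserialized object
print(old:run())                   -- prints 0.5, applied lazily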
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua
index f874cd3..f180526 100644
--- a/SpatialSoftMax.lua
+++ b/SpatialSoftMax.lua
@@ -8,11 +8,14 @@ function SpatialSoftMax:__init(fast)
    else
       self.algorithm = 'CUDNN_SOFTMAX_ACCURATE'
    end
-   self.mode = 'CUDNN_SOFTMAX_MODE_CHANNEL'
-   self.iSize = torch.LongStorage(4):fill(0)
 end
 
 function SpatialSoftMax:createIODescriptors(input)
+   self.mode = self.mode or 'CUDNN_SOFTMAX_MODE_CHANNEL'
+   -- after converting from nn use accurate
+   self.algorithm = self.algorithm or 'CUDNN_SOFTMAX_ACCURATE'
+   self.iSize = self.iSize or torch.LongStorage(4):fill(0)
+
    local batch = true
    local singleDim = false
    if input:dim() == 1 then
@@ -27,6 +30,7 @@ function SpatialSoftMax:createIODescriptors(input)
       batch = false
    end
    assert(input:dim() == 4 and input:isContiguous());
+
    if not self.iDesc or not self.oDesc or
       input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
       or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
@@ -86,12 +90,22 @@ function SpatialSoftMax:updateGradInput(input, gradOutput)
    return self.gradInput
 end
 
-function SpatialSoftMax:write(f)
+function SpatialSoftMax:clearDesc()
    self.iDesc = nil
    self.oDesc = nil
+end
+
+function SpatialSoftMax:write(f)
+   self:clearDesc()
    local var = {}
    for k,v in pairs(self) do
       var[k] = v
    end
    f:writeObject(var)
 end
+
+function SpatialSoftMax:clearState()
+   self:clearDesc()
+   nn.utils.clear(self, '_gradOutput')
+   return parent.clearState(self)
+end
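For context, a minimal usage sketch of the new clearState method; this is not part of the commit and assumes the cudnn package (with cutorch/cunn and a CUDA device) is installed:

require 'cudnn'

local m = cudnn.SpatialSoftMax()
local input = torch.CudaTensor(2, 3, 8, 8):uniform()

-- forward() lazily builds the cuDNN I/O descriptors in
-- createIODescriptors(); the raw handles cannot be serialized.
local output = m:forward(input)

-- clearState(), added by this commit, drops the descriptors via
-- clearDesc() and clears the temporary _gradOutput buffer, so the
-- module can be written to disk safely.
m:clearState()
torch.save('softmax.t7', m)

After such a checkpoint is loaded, the lazy defaults at the top of createIODescriptors restore mode, algorithm and iSize, and the descriptors are rebuilt on the next forward call.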