diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-12 16:01:09 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-12 16:02:44 +0300 |
commit | a2e27cc763d304a212552f4ec81ddc9e2c6fbcf5 (patch) | |
tree | 193c7841b44b900fa9993e859261f51a648eb87f /SpatialSoftMax.lua | |
parent | 3f7c066fae27e202925a9f7c74eaec70130e6e2f (diff) |
clearState
Diffstat (limited to 'SpatialSoftMax.lua')
-rw-r--r-- | SpatialSoftMax.lua | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index f874cd3..f180526 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -8,11 +8,14 @@ function SpatialSoftMax:__init(fast) else self.algorithm = 'CUDNN_SOFTMAX_ACCURATE' end - self.mode = 'CUDNN_SOFTMAX_MODE_CHANNEL' - self.iSize = torch.LongStorage(4):fill(0) end function SpatialSoftMax:createIODescriptors(input) + self.mode = self.mode or 'CUDNN_SOFTMAX_MODE_CHANNEL' + -- after converting from nn use accurate + self.algorithm = self.algorithm or 'CUDNN_SOFTMAX_ACCURATE' + self.iSize = self.iSize or torch.LongStorage(4):fill(0) + local batch = true local singleDim = false if input:dim() == 1 then @@ -27,6 +30,7 @@ function SpatialSoftMax:createIODescriptors(input) batch = false end assert(input:dim() == 4 and input:isContiguous()); + if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then @@ -86,12 +90,22 @@ function SpatialSoftMax:updateGradInput(input, gradOutput) return self.gradInput end -function SpatialSoftMax:write(f) +function SpatialSoftMax:clearDesc() self.iDesc = nil self.oDesc = nil +end + +function SpatialSoftMax:write(f) + self:clearDesc() local var = {} for k,v in pairs(self) do var[k] = v end f:writeObject(var) end + +function SpatialSoftMax:clearState() + self:clearDesc() + nn.utils.clear(self, '_gradOutput') + return parent.clearState(self) +end |