diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-02-12 18:11:24 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-02-12 18:11:24 +0300 |
commit | 28c2f6e76a0d3671ce127197c25e39c5ee4be627 (patch) | |
tree | bb8e1b383f5520727cd3dfa0119e13e738dd915d | |
parent | 19adc6362cf764cde9ff82d702a061d6d367c81e (diff) | |
parent | a2e27cc763d304a212552f4ec81ddc9e2c6fbcf5 (diff) |
Merge pull request #106 from szagoruyko/clearState
clearState
-rw-r--r-- | Pointwise.lua | 9 | ||||
-rw-r--r-- | Pooling.lua | 11 | ||||
-rw-r--r-- | Pooling3D.lua | 5 | ||||
-rw-r--r-- | SpatialConvolution.lua | 5 | ||||
-rw-r--r-- | SpatialCrossMapLRN.lua | 5 | ||||
-rw-r--r-- | SpatialSoftMax.lua | 20 | ||||
-rw-r--r-- | TemporalConvolution.lua | 11 | ||||
-rw-r--r-- | VolumetricConvolution.lua | 6 |
8 files changed, 65 insertions, 7 deletions
diff --git a/Pointwise.lua b/Pointwise.lua index 652ca60..51fdcca 100644 --- a/Pointwise.lua +++ b/Pointwise.lua @@ -25,7 +25,7 @@ local zero = torch.FloatTensor({0}); function Pointwise:updateOutput(input) self:createIODescriptors(input) - if self.inplace then self.output = input end + if self.inplace then self.output:set(input) end errcheck('cudnnActivationForward', cudnn.getHandle(), self.mode, one:data(), @@ -42,7 +42,7 @@ function Pointwise:updateGradInput(input, gradOutput) gradOutput = self._gradOutput end self:createIODescriptors(input) - if self.inplace then self.output = input; self.gradInput = gradOutput end + if self.inplace then self.output:set(input); self.gradInput:set(gradOutput) end errcheck('cudnnActivationBackward', cudnn.getHandle(), self.mode, one:data(), @@ -66,3 +66,8 @@ function Pointwise:write(f) end f:writeObject(var) end + +function Pointwise:clearState() + self:clearDesc() + return parent.clearState(self) +end diff --git a/Pooling.lua b/Pooling.lua index 4da3353..e9c9025 100644 --- a/Pooling.lua +++ b/Pooling.lua @@ -115,13 +115,22 @@ function Pooling:updateGradInput(input, gradOutput) return self.gradInput end -function Pooling:write(f) +function Pooling:clearDesc() self.poolDesc = nil self.iDesc = nil self.oDesc = nil +end + +function Pooling:write(f) + self:clearDesc() local var = {} for k,v in pairs(self) do var[k] = v end f:writeObject(var) end + +function Pooling:clearState() + self:clearDesc() + return parent.clearState(self) +end diff --git a/Pooling3D.lua b/Pooling3D.lua index 8c5cc26..a1fd3e3 100644 --- a/Pooling3D.lua +++ b/Pooling3D.lua @@ -138,3 +138,8 @@ function Pooling:write(f) end f:writeObject(var) end + +function Pooling:clearState() + self:clearDesc() + return parent.clearState(self) +end diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 0ee250c..2597aa5 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -482,3 +482,8 @@ function SpatialConvolution:write(f) end f:writeObject(var) end + +function SpatialConvolution:clearState() + self:clearDesc() + return nn.Module.clearState(self) +end diff --git a/SpatialCrossMapLRN.lua b/SpatialCrossMapLRN.lua index c79f246..43cba69 100644 --- a/SpatialCrossMapLRN.lua +++ b/SpatialCrossMapLRN.lua @@ -103,3 +103,8 @@ function LRN:write(f) end f:writeObject(var) end + +function LRN:clearState() + self:clearDesc() + return nn.Module.clearState(self) +end diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index f874cd3..f180526 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -8,11 +8,14 @@ function SpatialSoftMax:__init(fast) else self.algorithm = 'CUDNN_SOFTMAX_ACCURATE' end - self.mode = 'CUDNN_SOFTMAX_MODE_CHANNEL' - self.iSize = torch.LongStorage(4):fill(0) end function SpatialSoftMax:createIODescriptors(input) + self.mode = self.mode or 'CUDNN_SOFTMAX_MODE_CHANNEL' + -- after converting from nn use accurate + self.algorithm = self.algorithm or 'CUDNN_SOFTMAX_ACCURATE' + self.iSize = self.iSize or torch.LongStorage(4):fill(0) + local batch = true local singleDim = false if input:dim() == 1 then @@ -27,6 +30,7 @@ function SpatialSoftMax:createIODescriptors(input) batch = false end assert(input:dim() == 4 and input:isContiguous()); + if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then @@ -86,12 +90,22 @@ function SpatialSoftMax:updateGradInput(input, gradOutput) return self.gradInput end -function SpatialSoftMax:write(f) +function SpatialSoftMax:clearDesc() self.iDesc = nil self.oDesc = nil +end + +function SpatialSoftMax:write(f) + self:clearDesc() local var = {} for k,v in pairs(self) do var[k] = v end f:writeObject(var) end + +function SpatialSoftMax:clearState() + self:clearDesc() + nn.utils.clear(self, '_gradOutput') + return parent.clearState(self) +end diff --git a/TemporalConvolution.lua b/TemporalConvolution.lua index 3e646e5..72a87c3 100644 --- a/TemporalConvolution.lua +++ b/TemporalConvolution.lua @@ -104,10 +104,14 @@ function TemporalConvolution:accGradParameters(input,gradOutput,scale) cudnn.SpatialConvolution.accGradParameters(self,_input,_gradOutput,scale) end -function TemporalConvolution:write(f) +function TemporalConvolution:clearDesc() self.buffer = nil self._ouptut = nil self.oSize = nil +end + +function TemporalConvolution:write(f) + self:clearDesc() cudnn.SpatialConvolution.clearDesc(self) local var = {} for k,v in pairs(self) do @@ -115,3 +119,8 @@ function TemporalConvolution:write(f) end f:writeObject(var) end + +function TemporalConvolution:clearState() + self:clearDesc() + return parent.clearState(self) +end diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua index 62237b6..db352a5 100644 --- a/VolumetricConvolution.lua +++ b/VolumetricConvolution.lua @@ -287,3 +287,9 @@ function VolumetricConvolution:write(f) end f:writeObject(var) end + +function VolumetricConvolution:clearState() + self:clearDesc() + nn.utils.clear(self, 'extraBuffer') + return nn.Module.clearState(self) +end |