
github.com/soumith/cudnn.torch.git - Torch FFI bindings for NVIDIA cuDNN
author    Soumith Chintala <soumith@gmail.com>  2016-02-12 18:11:24 +0300
committer Soumith Chintala <soumith@gmail.com>  2016-02-12 18:11:24 +0300
commit    28c2f6e76a0d3671ce127197c25e39c5ee4be627 (patch)
tree      bb8e1b383f5520727cd3dfa0119e13e738dd915d
parent    19adc6362cf764cde9ff82d702a061d6d367c81e (diff)
parent    a2e27cc763d304a212552f4ec81ddc9e2c6fbcf5 (diff)
Merge pull request #106 from szagoruyko/clearState
clearState
-rw-r--r--  Pointwise.lua               9
-rw-r--r--  Pooling.lua                11
-rw-r--r--  Pooling3D.lua               5
-rw-r--r--  SpatialConvolution.lua      5
-rw-r--r--  SpatialCrossMapLRN.lua      5
-rw-r--r--  SpatialSoftMax.lua         20
-rw-r--r--  TemporalConvolution.lua    11
-rw-r--r--  VolumetricConvolution.lua   6

8 files changed, 65 insertions(+), 7 deletions(-)
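
This merge gives each cudnn module a clearState() method. clearState() drops the cuDNN descriptors and temporary buffers that the modules allocate lazily during forward/backward, so a trained network can be serialized compactly and without raw C pointers inside. A minimal usage sketch, not part of the diff (the module names are standard cudnn.torch):

require 'cudnn'

local net = nn.Sequential()
   :add(cudnn.SpatialConvolution(3, 16, 5, 5))
   :add(cudnn.ReLU(true))
   :add(cudnn.SpatialMaxPooling(2, 2))
net:cuda()

-- the first forward allocates descriptors and scratch buffers lazily
local out = net:forward(torch.CudaTensor(1, 3, 32, 32))

-- drop the transient state, then save; the checkpoint stays small
net:clearState()
torch.save('net.t7', net)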
diff --git a/Pointwise.lua b/Pointwise.lua
index 652ca60..51fdcca 100644
--- a/Pointwise.lua
+++ b/Pointwise.lua
@@ -25,7 +25,7 @@ local zero = torch.FloatTensor({0});
function Pointwise:updateOutput(input)
self:createIODescriptors(input)
- if self.inplace then self.output = input end
+ if self.inplace then self.output:set(input) end
errcheck('cudnnActivationForward',
cudnn.getHandle(), self.mode,
one:data(),
@@ -42,7 +42,7 @@ function Pointwise:updateGradInput(input, gradOutput)
gradOutput = self._gradOutput
end
self:createIODescriptors(input)
- if self.inplace then self.output = input; self.gradInput = gradOutput end
+ if self.inplace then self.output:set(input); self.gradInput:set(gradOutput) end
errcheck('cudnnActivationBackward',
cudnn.getHandle(), self.mode,
one:data(),
@@ -66,3 +66,8 @@ function Pointwise:write(f)
end
f:writeObject(var)
end
+
+function Pointwise:clearState()
+ self:clearDesc()
+ return parent.clearState(self)
+end
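
The Pointwise change above is subtle. With plain assignment, self.output becomes the very tensor object the caller passed in, so a later clearState(), which resets self.output via :set(), would wipe the caller's tensor. With output:set(input), self.output stays a distinct tensor that merely shares input's storage. A standalone sketch of the difference (illustrative, not from the diff):

require 'torch'

local input  = torch.FloatTensor(4):fill(1)

-- plain assignment: 'output' IS the caller's tensor object
local output = input
output:set()               -- a clearState-style reset...
print(input:nElement())    -- ...empties the caller's tensor too: prints 0

input  = torch.FloatTensor(4):fill(1)
output = torch.FloatTensor()
output:set(input)          -- distinct object, shared storage
output:set()               -- resetting it leaves the input alone
print(input:nElement())    -- prints 4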
diff --git a/Pooling.lua b/Pooling.lua
index 4da3353..e9c9025 100644
--- a/Pooling.lua
+++ b/Pooling.lua
@@ -115,13 +115,22 @@ function Pooling:updateGradInput(input, gradOutput)
return self.gradInput
end
-function Pooling:write(f)
+function Pooling:clearDesc()
self.poolDesc = nil
self.iDesc = nil
self.oDesc = nil
+end
+
+function Pooling:write(f)
+ self:clearDesc()
local var = {}
for k,v in pairs(self) do
var[k] = v
end
f:writeObject(var)
end
+
+function Pooling:clearState()
+ self:clearDesc()
+ return parent.clearState(self)
+end
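
Pooling shows the refactoring pattern repeated through the rest of the patch: the descriptor-dropping code moves out of write() into a clearDesc() helper shared by write() (serialization) and the new clearState(). Applied to a hypothetical custom module, the shape would be (MyModule and its descriptor fields are made up for illustration):

require 'cudnn'

local MyModule, parent = torch.class('cudnn.MyModule', 'nn.Module')

function MyModule:clearDesc()
   -- descriptors wrap raw cuDNN handles; they must not be serialized
   self.iDesc = nil
   self.oDesc = nil
end

function MyModule:write(f)
   self:clearDesc()
   local var = {}
   for k, v in pairs(self) do var[k] = v end
   f:writeObject(var)
end

function MyModule:clearState()
   self:clearDesc()
   -- nn.Module.clearState also empties self.output and self.gradInput
   return parent.clearState(self)
end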
diff --git a/Pooling3D.lua b/Pooling3D.lua
index 8c5cc26..a1fd3e3 100644
--- a/Pooling3D.lua
+++ b/Pooling3D.lua
@@ -138,3 +138,8 @@ function Pooling:write(f)
end
f:writeObject(var)
end
+
+function Pooling:clearState()
+ self:clearDesc()
+ return parent.clearState(self)
+end
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 0ee250c..2597aa5 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -482,3 +482,8 @@ function SpatialConvolution:write(f)
end
f:writeObject(var)
end
+
+function SpatialConvolution:clearState()
+ self:clearDesc()
+ return nn.Module.clearState(self)
+end
diff --git a/SpatialCrossMapLRN.lua b/SpatialCrossMapLRN.lua
index c79f246..43cba69 100644
--- a/SpatialCrossMapLRN.lua
+++ b/SpatialCrossMapLRN.lua
@@ -103,3 +103,8 @@ function LRN:write(f)
end
f:writeObject(var)
end
+
+function LRN:clearState()
+ self:clearDesc()
+ return nn.Module.clearState(self)
+end
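
Note that SpatialConvolution and LRN return nn.Module.clearState(self) directly rather than parent.clearState(self). A plausible reading (an inference, not stated in the patch) is that their nn parent classes define their own clearState, which touches CPU-side buffers such as finput/fgradInput that the cudnn versions never allocate; jumping straight to nn.Module performs just the output/gradInput reset.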
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua
index f874cd3..f180526 100644
--- a/SpatialSoftMax.lua
+++ b/SpatialSoftMax.lua
@@ -8,11 +8,14 @@ function SpatialSoftMax:__init(fast)
else
self.algorithm = 'CUDNN_SOFTMAX_ACCURATE'
end
- self.mode = 'CUDNN_SOFTMAX_MODE_CHANNEL'
- self.iSize = torch.LongStorage(4):fill(0)
end
function SpatialSoftMax:createIODescriptors(input)
+ self.mode = self.mode or 'CUDNN_SOFTMAX_MODE_CHANNEL'
+ -- after converting from nn use accurate
+ self.algorithm = self.algorithm or 'CUDNN_SOFTMAX_ACCURATE'
+ self.iSize = self.iSize or torch.LongStorage(4):fill(0)
+
local batch = true
local singleDim = false
if input:dim() == 1 then
@@ -27,6 +30,7 @@ function SpatialSoftMax:createIODescriptors(input)
batch = false
end
assert(input:dim() == 4 and input:isContiguous());
+
if not self.iDesc or not self.oDesc or
input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
@@ -86,12 +90,22 @@ function SpatialSoftMax:updateGradInput(input, gradOutput)
return self.gradInput
end
-function SpatialSoftMax:write(f)
+function SpatialSoftMax:clearDesc()
self.iDesc = nil
self.oDesc = nil
+end
+
+function SpatialSoftMax:write(f)
+ self:clearDesc()
local var = {}
for k,v in pairs(self) do
var[k] = v
end
f:writeObject(var)
end
+
+function SpatialSoftMax:clearState()
+ self:clearDesc()
+ nn.utils.clear(self, '_gradOutput')
+ return parent.clearState(self)
+end
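
In SpatialSoftMax, the initialization of mode, algorithm, and iSize moves from __init into createIODescriptors as lazy 'or'-defaults. Instances whose __init never ran, e.g. modules deserialized from older checkpoints or swapped in from nn, now pick up sane values on the first forward; per the in-diff comment, converted modules default to the accurate algorithm. A conversion sketch (assuming cudnn.convert handles SoftMax at this commit):

local net = nn.Sequential():add(nn.Linear(10, 10)):add(nn.SoftMax())
net:cuda()
cudnn.convert(net, cudnn)   -- swaps nn.SoftMax for its cudnn counterpart
-- the first forward fills in mode/algorithm/iSize via the defaults above
local y = net:forward(torch.CudaTensor(8, 10))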
diff --git a/TemporalConvolution.lua b/TemporalConvolution.lua
index 3e646e5..72a87c3 100644
--- a/TemporalConvolution.lua
+++ b/TemporalConvolution.lua
@@ -104,10 +104,14 @@ function TemporalConvolution:accGradParameters(input,gradOutput,scale)
cudnn.SpatialConvolution.accGradParameters(self,_input,_gradOutput,scale)
end
-function TemporalConvolution:write(f)
+function TemporalConvolution:clearDesc()
self.buffer = nil
self._ouptut = nil
self.oSize = nil
+end
+
+function TemporalConvolution:write(f)
+ self:clearDesc()
cudnn.SpatialConvolution.clearDesc(self)
local var = {}
for k,v in pairs(self) do
@@ -115,3 +119,8 @@ function TemporalConvolution:write(f)
end
f:writeObject(var)
end
+
+function TemporalConvolution:clearState()
+ self:clearDesc()
+ return parent.clearState(self)
+end
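
(The field _ouptut cleared in the hunk above is spelled that way in the upstream source, apparently a typo for _output; the diff reproduces it verbatim.)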
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index 62237b6..db352a5 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -287,3 +287,9 @@ function VolumetricConvolution:write(f)
end
f:writeObject(var)
end
+
+function VolumetricConvolution:clearState()
+ self:clearDesc()
+ nn.utils.clear(self, 'extraBuffer')
+ return nn.Module.clearState(self)
+end
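
VolumetricConvolution additionally runs its extraBuffer scratch tensor through nn.utils.clear, which empties the named fields in place, tensors via :set(), rather than nil-ing them. A quick illustration of that helper:

require 'nn'

local m = { extraBuffer = torch.FloatTensor(1024) }
nn.utils.clear(m, 'extraBuffer')
print(m.extraBuffer:nElement())   -- prints 0: the field survives, its storage is released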