github.com/soumith/cudnn.torch.git
-rw-r--r--  Pointwise.lua       |  35
-rw-r--r--  README.md           |   9
-rw-r--r--  ReLU.lua            |   4
-rw-r--r--  Sigmoid.lua         |   4
-rw-r--r--  SpatialSoftMax.lua  |   1
-rw-r--r--  Tanh.lua            |   4
-rw-r--r--  test/test.lua       | 222
7 files changed, 70 insertions, 209 deletions
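This commit adds an optional inplace argument to the pointwise activation modules (cudnn.ReLU, cudnn.Tanh, cudnn.Sigmoid), drops the gradInput:fill(1) workaround for the cudnn R2-RC1 bug (soumith/cudnn.torch#9), and collapses the duplicated nonlinearity tests into two shared helpers.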
diff --git a/Pointwise.lua b/Pointwise.lua
index 2a8b6bd..a210610 100644
--- a/Pointwise.lua
+++ b/Pointwise.lua
@@ -1,8 +1,9 @@
local Pointwise, parent = torch.class('cudnn._Pointwise','nn.Module')
local errcheck = cudnn.errcheck
-function Pointwise:__init()
+function Pointwise:__init(inplace)
parent.__init(self)
+ self.inplace = inplace or false
self.iSize = torch.LongStorage(4):fill(0)
end
@@ -14,21 +15,22 @@ function Pointwise:createIODescriptors(input)
batch = false
end
assert(input:dim() == 4 and input:isContiguous());
- if not self.iDesc or not self.oDesc or
+ if not self.iDesc or
input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
self.iSize = input:size()
- self.gradInput:resizeAs(input)
- self.output:resizeAs(input)
self.iDesc = cudnn.toDescriptor(input)
- self.oDesc = cudnn.toDescriptor(self.output)
- if not batch then
- self.gradInput = self.gradInput:view(self.gradInput:size(2),
- self.gradInput:size(3),
- self.gradInput:size(4))
- self.output = self.output:view(self.output:size(2),
- self.output:size(3),
- self.output:size(4))
+ if not self.inplace then
+ self.gradInput:resizeAs(input)
+ self.output:resizeAs(input)
+ if not batch then
+ self.gradInput = self.gradInput:view(self.gradInput:size(2),
+ self.gradInput:size(3),
+ self.gradInput:size(4))
+ self.output = self.output:view(self.output:size(2),
+ self.output:size(3),
+ self.output:size(4))
+ end
end
end
end
@@ -38,12 +40,13 @@ local zero = torch.FloatTensor({0});
function Pointwise:updateOutput(input)
self:createIODescriptors(input)
+ if self.inplace then self.output = input end
errcheck('cudnnActivationForward',
cudnn.handle[cutorch.getDevice()-1], self.mode,
one:data(),
self.iDesc[0], input:data(),
zero:data(),
- self.oDesc[0], self.output:data());
+ self.iDesc[0], self.output:data());
return self.output
end
@@ -56,12 +59,12 @@ function Pointwise:updateGradInput(input, gradOutput)
gradOutput = self._gradOutput
end
self:createIODescriptors(input)
- self.gradInput:fill(1) -- to get around bug in R2-RC1 https://github.com/soumith/cudnn.torch/issues/9
+ if self.inplace then self.output = input; self.gradInput = gradOutput end
errcheck('cudnnActivationBackward',
cudnn.handle[cutorch.getDevice()-1], self.mode,
one:data(),
- self.oDesc[0], self.output:data(),
- self.oDesc[0], gradOutput:data(),
+ self.iDesc[0], self.output:data(),
+ self.iDesc[0], gradOutput:data(),
self.iDesc[0], input:data(),
zero:data(),
self.iDesc[0], self.gradInput:data());
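The net effect of the inplace path: self.output aliases the input tensor and self.gradInput aliases gradOutput, so the module allocates no buffers of its own (which is also why a single iDesc now describes both input and output). A minimal sketch of the observable behavior, not part of the diff, assuming cudnn and cutorch are installed:

require 'cudnn'

local x = torch.randn(16, 3, 8, 8):cuda()
local relu = cudnn.ReLU(true):cuda()  -- inplace = true

local y = relu:forward(x)
-- y is the same tensor object as x: the activation overwrote the input,
-- so no separate output buffer was allocated
assert(torch.pointer(y) == torch.pointer(x))
-- caveat: the original values of x are gone; use in-place mode only when
-- the input is not needed again after the activation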
diff --git a/README.md b/README.md
index b850ff3..d46a83c 100644
--- a/README.md
+++ b/README.md
@@ -18,9 +18,12 @@ Modules are API compatible with their [`nn`](https://github.com/torch/nn) equivalents
cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
-cudnn.ReLU()
-cudnn.Tanh()
-cudnn.Sigmoid()
+
+-- The pointwise functions take an optional additional argument. If inplace=true, the operation is performed in-place, without allocating extra memory for the output.
+cudnn.ReLU(inplace[=false])
+cudnn.Tanh(inplace[=false])
+cudnn.Sigmoid(inplace[=false])
+
-- SoftMax can be run in fast mode or accurate mode. Default is accurate mode.
cudnn.SoftMax(fastMode [= false]) -- SoftMax across each image (just like nn.SoftMax)
cudnn.SpatialSoftMax(fastMode [= false]) -- SoftMax across feature-maps (per spatial location)
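A usage sketch of the new flag (not from the diff; assumes the usual nn/cudnn stack): an in-place activation directly after a convolution is the typical case, since the convolution's output is not needed again once the activation has consumed it.

local net = nn.Sequential()
net:add(cudnn.SpatialConvolution(3, 16, 5, 5))
net:add(cudnn.ReLU(true))            -- inplace = true, saves one activation-sized buffer
net:add(cudnn.SpatialMaxPooling(2, 2, 2, 2))
net:cuda()

local out = net:forward(torch.randn(8, 3, 32, 32):cuda())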
diff --git a/ReLU.lua b/ReLU.lua
index d25e69c..9edabce 100644
--- a/ReLU.lua
+++ b/ReLU.lua
@@ -1,6 +1,6 @@
local ReLU, parent = torch.class('cudnn.ReLU','cudnn._Pointwise')
-function ReLU:__init()
- parent.__init(self)
+function ReLU:__init(inplace)
+ parent.__init(self, inplace)
self.mode = 'CUDNN_ACTIVATION_RELU'
end
diff --git a/Sigmoid.lua b/Sigmoid.lua
index f3d510f..ecc880b 100644
--- a/Sigmoid.lua
+++ b/Sigmoid.lua
@@ -1,6 +1,6 @@
local Sigmoid, parent = torch.class('cudnn.Sigmoid','cudnn._Pointwise')
-function Sigmoid:__init()
- parent.__init(self)
+function Sigmoid:__init(inplace)
+ parent.__init(self, inplace)
self.mode = 'CUDNN_ACTIVATION_SIGMOID'
end
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua
index 1586882..87af4d5 100644
--- a/SpatialSoftMax.lua
+++ b/SpatialSoftMax.lua
@@ -57,7 +57,6 @@ function SpatialSoftMax:updateGradInput(input, gradOutput)
assert((gradOutput:dim() == 4 or gradOutput:dim() == 3)
and gradOutput:isContiguous());
self:createIODescriptors(input)
- self.gradInput:fill(1) -- to get around bug in R2-RC1 https://github.com/soumith/cudnn.torch/issues/9
errcheck('cudnnSoftmaxBackward',
cudnn.handle[cutorch.getDevice()-1],
self.algorithm, self.mode,
diff --git a/Tanh.lua b/Tanh.lua
index 0ebe563..4efb2d4 100644
--- a/Tanh.lua
+++ b/Tanh.lua
@@ -1,6 +1,6 @@
local Tanh, parent = torch.class('cudnn.Tanh','cudnn._Pointwise')
-function Tanh:__init()
- parent.__init(self)
+function Tanh:__init(inplace)
+ parent.__init(self, inplace)
self.mode = 'CUDNN_ACTIVATION_TANH'
end
diff --git a/test/test.lua b/test/test.lua
index 6b069bb..b27fd11 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -318,7 +318,7 @@ function cudnntest.SpatialMaxPooling_single()
'error on state (backward) ')
end
-function cudnntest.ReLU_single()
+local function nonlinSingle(nonlin)
local from = math.random(1,32)
local outi = math.random(1,64)
local outj = math.random(1,64)
@@ -327,19 +327,27 @@ function cudnntest.ReLU_single()
local input = torch.randn(from,inj,ini):cuda()
local gradOutput = torch.randn(from,outj,outi):cuda()
- local sconv = nn.ReLU():cuda()
+ local sconv = nn[nonlin]():cuda()
local groundtruth = sconv:forward(input)
local groundgrad = sconv:backward(input, gradOutput)
cutorch.synchronize()
- local gconv = cudnn.ReLU():cuda()
- local _ = gconv:forward(input)
+ -- choose in-place or out-of-place with 50% probability
+ local inplace = false
+ if math.random(0,1) == 1 then
+ inplace = true
+ end
+ local gconv = cudnn[nonlin](inplace):cuda()
+ local input__ = input:clone()
+ local _ = gconv:forward(input__)
-- serialize and deserialize
torch.save('modelTemp.t7', gconv)
gconv = torch.load('modelTemp.t7')
- local rescuda = gconv:forward(input)
- local resgrad = gconv:backward(input, gradOutput)
+ local input__ = input:clone()
+ local gradOutput__ = gradOutput:clone()
+ local rescuda = gconv:forward(input__)
+ local resgrad = gconv:backward(input__, gradOutput__)
cutorch.synchronize()
mytester:asserteq(rescuda:dim(), 3, 'error in dimension')
mytester:asserteq(resgrad:dim(), 3, 'error in dimension')
@@ -351,7 +359,7 @@ function cudnntest.ReLU_single()
'error on state (backward) ')
end
-function cudnntest.ReLU_batch()
+local function nonlinBatch(nonlin)
local bs = math.random(1,32)
local from = math.random(1,32)
local outi = math.random(1,64)
@@ -361,19 +369,27 @@ function cudnntest.ReLU_batch()
local input = torch.randn(bs,from,inj,ini):cuda()
local gradOutput = torch.randn(bs,from,outj,outi):cuda()
- local sconv = nn.ReLU():cuda()
+ local sconv = nn[nonlin]():cuda()
local groundtruth = sconv:forward(input)
local groundgrad = sconv:backward(input, gradOutput)
cutorch.synchronize()
- local gconv = cudnn.ReLU():cuda()
- local rescuda = gconv:forward(input)
+ -- choose in-place or out-of-place with 50% probability
+ local inplace = false
+ if math.random(0,1) == 1 then
+ inplace = true
+ end
+ local gconv = cudnn[nonlin](inplace):cuda()
+ local input__ = input:clone()
+ local rescuda = gconv:forward(input__)
-- serialize and deserialize
torch.save('modelTemp.t7', gconv)
gconv = torch.load('modelTemp.t7')
- local rescuda = gconv:forward(input)
- local resgrad = gconv:backward(input, gradOutput)
+ local input__ = input:clone()
+ local gradOutput__ = gradOutput:clone()
+ local rescuda = gconv:forward(input__)
+ local resgrad = gconv:backward(input__, gradOutput__)
cutorch.synchronize()
mytester:asserteq(rescuda:dim(), 4, 'error in dimension')
mytester:asserteq(resgrad:dim(), 4, 'error in dimension')
@@ -385,188 +401,28 @@ function cudnntest.ReLU_batch()
'error on state (backward) ')
end
-function cudnntest.Tanh_single()
- local from = math.random(1,32)
- local outi = math.random(1,64)
- local outj = math.random(1,64)
- local ini = outi
- local inj = outj
- local input = torch.randn(from,inj,ini):cuda()
- local gradOutput = torch.randn(from,outj,outi):cuda()
-
- local sconv = nn.Tanh():cuda()
- local groundtruth = sconv:forward(input)
- local groundgrad = sconv:backward(input, gradOutput)
- cutorch.synchronize()
- local gconv = cudnn.Tanh():cuda()
- local _ = gconv:forward(input)
+function cudnntest.ReLU_single()
+ nonlinSingle('ReLU')
+end
- -- serialize and deserialize
- torch.save('modelTemp.t7', gconv)
- gconv = torch.load('modelTemp.t7')
+function cudnntest.ReLU_batch()
+ nonlinBatch('ReLU')
+end
- local rescuda = gconv:forward(input)
- local resgrad = gconv:backward(input, gradOutput)
- cutorch.synchronize()
- mytester:asserteq(rescuda:dim(), 3, 'error in dimension')
- mytester:asserteq(resgrad:dim(), 3, 'error in dimension')
- local error = rescuda:float() - groundtruth:float()
- local errmax = error:abs():max()
- if (errmax ~= errmax) then
- local state = {}
- state.input = input
- state.gradOutput = gradOutput
- state.rescuda = rescuda
- state.resgrad = resgrad
- state.groundtruth = groundtruth
- state.groundgrad = groundgrad
- print(#input)
- torch.save('badTanh.t7', state)
- end
- mytester:assertlt(errmax, precision_forward,
- 'error on state (forward) ')
- error = resgrad:float() - groundgrad:float()
- errmax = error:abs():max()
- if (errmax ~= errmax) then
- local state = {}
- state.input = input
- state.gradOutput = gradOutput
- state.rescuda = rescuda
- state.resgrad = resgrad
- state.groundtruth = groundtruth
- state.groundgrad = groundgrad
- print(#input)
- torch.save('badTanh.t7', state)
- end
- mytester:assertlt(errmax, precision_backward,
- 'error on state (backward) ')
+function cudnntest.Tanh_single()
+ nonlinSingle('Tanh')
end
function cudnntest.Tanh_batch()
- local bs = math.random(1,32)
- local from = math.random(1,32)
- local outi = math.random(1,64)
- local outj = math.random(1,64)
- local ini = outi
- local inj = outj
- local input = torch.randn(bs,from,inj,ini):cuda()
- local gradOutput = torch.randn(bs,from,outj,outi):cuda()
-
- local sconv = nn.Tanh():cuda()
- local groundtruth = sconv:forward(input)
- local groundgrad = sconv:backward(input, gradOutput)
- cutorch.synchronize()
- local gconv = cudnn.Tanh():cuda()
- local rescuda = gconv:forward(input)
-
- -- serialize and deserialize
- torch.save('modelTemp.t7', gconv)
- gconv = torch.load('modelTemp.t7')
-
- local rescuda = gconv:forward(input)
- local resgrad = gconv:backward(input, gradOutput)
- cutorch.synchronize()
- mytester:asserteq(rescuda:dim(), 4, 'error in dimension')
- mytester:asserteq(resgrad:dim(), 4, 'error in dimension')
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(), precision_forward,
- 'error on state (forward) ')
- error = resgrad:float() - groundgrad:float()
- mytester:assertlt(error:abs():max(), precision_backward,
- 'error on state (backward) ')
+ nonlinBatch('Tanh')
end
function cudnntest.Sigmoid_single()
- local from = math.random(1,32)
- local outi = math.random(1,64)
- local outj = math.random(1,64)
- local ini = outi
- local inj = outj
- local input = torch.randn(from,inj,ini):cuda()
- local gradOutput = torch.randn(from,outj,outi):cuda()
-
- local sconv = nn.Sigmoid():cuda()
- local groundtruth = sconv:forward(input)
- local groundgrad = sconv:backward(input, gradOutput)
- cutorch.synchronize()
- local gconv = cudnn.Sigmoid():cuda()
- local _ = gconv:forward(input)
-
- -- serialize and deserialize
- torch.save('modelTemp.t7', gconv)
- gconv = torch.load('modelTemp.t7')
-
- local rescuda = gconv:forward(input)
- local resgrad = gconv:backward(input, gradOutput)
- cutorch.synchronize()
- mytester:asserteq(rescuda:dim(), 3, 'error in dimension')
- mytester:asserteq(resgrad:dim(), 3, 'error in dimension')
- local error = rescuda:float() - groundtruth:float()
- local errmax = error:abs():max()
- if (errmax ~= errmax) then
- local state = {}
- state.input = input
- state.gradOutput = gradOutput
- state.rescuda = rescuda
- state.resgrad = resgrad
- state.groundtruth = groundtruth
- state.groundgrad = groundgrad
- print(#input)
- torch.save('badSigmoid.t7', state)
- print(#input)
- end
- mytester:assertlt(errmax, precision_forward,
- 'error on state (forward) ')
- error = resgrad:float() - groundgrad:float()
- errmax = error:abs():max()
- if (errmax ~= errmax) then
- local state = {}
- state.input = input
- state.gradOutput = gradOutput
- state.rescuda = rescuda
- state.resgrad = resgrad
- state.groundtruth = groundtruth
- state.groundgrad = groundgrad
- print(#input)
- torch.save('badSigmoid.t7', state)
- print(#input)
- end
- mytester:assertlt(errmax, precision_backward,
- 'error on state (backward) ')
+ nonlinSingle('Sigmoid')
end
function cudnntest.Sigmoid_batch()
- local bs = math.random(1,32)
- local from = math.random(1,32)
- local outi = math.random(1,64)
- local outj = math.random(1,64)
- local ini = outi
- local inj = outj
- local input = torch.randn(bs,from,inj,ini):cuda()
- local gradOutput = torch.randn(bs,from,outj,outi):cuda()
-
- local sconv = nn.Sigmoid():cuda()
- local groundtruth = sconv:forward(input)
- local groundgrad = sconv:backward(input, gradOutput)
- cutorch.synchronize()
- local gconv = cudnn.Sigmoid():cuda()
- local rescuda = gconv:forward(input)
-
- -- serialize and deserialize
- torch.save('modelTemp.t7', gconv)
- gconv = torch.load('modelTemp.t7')
-
- local rescuda = gconv:forward(input)
- local resgrad = gconv:backward(input, gradOutput)
- cutorch.synchronize()
- mytester:asserteq(rescuda:dim(), 4, 'error in dimension')
- mytester:asserteq(resgrad:dim(), 4, 'error in dimension')
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(), precision_forward,
- 'error on state (forward) ')
- error = resgrad:float() - groundgrad:float()
- mytester:assertlt(error:abs():max(), precision_backward,
- 'error on state (backward) ')
+ nonlinBatch('Sigmoid')
end
function cudnntest.SoftMax_single()
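With the duplicated test bodies gone, each nonlinearity is covered by two one-line wrappers around the shared helpers. A hypothetical new pointwise module, say cudnn.Foo with an nn.Foo reference implementation, would be registered the same way (inside test/test.lua, where the helpers are in scope):

-- one-line registrations against the shared parameterized helpers
function cudnntest.Foo_single() nonlinSingle('Foo') end
function cudnntest.Foo_batch()  nonlinBatch('Foo')  end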