
github.com/torch/cunn.git
author    Soumith Chintala <soumith@gmail.com>  2016-12-16 23:34:11 +0300
committer GitHub <noreply@github.com>  2016-12-16 23:34:11 +0300
commit    0a48d8e9430c47d6074bdae3a3d5341c3c5a75ff (patch)
tree      5a3fde2607af08211e0608a910297275fa0f15dd
parent    9a338fa4a61c7622aff75ee4d92adddd2accebf8 (diff)
parent    6a2e7ef0013250fe97726251f7b8276885f4f1d0 (diff)
Merge pull request #401 from gchanan/dataParallelTablePrecision
Add support for Double tensors to DataParallelTable.
-rw-r--r--  DataParallelTable.lua       |  28
-rw-r--r--  test_DataParallelTable.lua  | 293
2 files changed, 227 insertions(+), 94 deletions(-)
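Before the hunks, one note on the recurring idiom: everywhere the code used to hardcode the torch.CudaTensor constructor, it now extracts the class name from a type string and indexes the torch table with it. A minimal standalone sketch of that lookup (constructorFor is a name introduced here for illustration; assumes torch with cutorch installed and a CUDA device present):

require 'cutorch'

-- 'torch.CudaDoubleTensor' -> 'CudaDoubleTensor' -> the constructor torch.CudaDoubleTensor
local function constructorFor(typeStr)
   return torch[typeStr:match('torch.(%a+)')]
end

local t = constructorFor('torch.CudaDoubleTensor')(4, 4):fill(1)
print(torch.type(t))  -- torch.CudaDoubleTensor

The same expression, torch[typeStr:match('torch.(%a+)')], appears inline throughout the hunks below.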
diff --git a/DataParallelTable.lua b/DataParallelTable.lua
index 102be72..e0194d4 100644
--- a/DataParallelTable.lua
+++ b/DataParallelTable.lua
@@ -47,6 +47,7 @@ function DataParallelTable:__init(dimension, flattenParams, usenccl)
error "must specify a dimension!"
end
+ self.typeStr = 'torch.CudaTensor'
self.dimension = dimension
self.modules = {}
self.gpuAssignments = {} -- Which gpuid each module sits on
@@ -101,6 +102,7 @@ end
-- this flattens parameters, so that syncParameters and accGradParameters can be much more efficient
function DataParallelTable:flattenParameters()
+ local typeStr = self.typeStr
self.flattenedParams = self.impl:exec(function(module)
local p, dp = module:parameters()
local flattened = true
@@ -112,9 +114,9 @@ function DataParallelTable:flattenParameters()
end
end
if flattened then
- local pp = torch.CudaTensor(p[1]:storage(), p[1]:storageOffset(),
+ local pp = torch[typeStr:match('torch.(%a+)')](p[1]:storage(), p[1]:storageOffset(),
p[#p]:storageOffset()+p[#p]:numel()-p[1]:storageOffset())
- local dpp = torch.CudaTensor(dp[1]:storage(), dp[1]:storageOffset(),
+ local dpp = torch[typeStr:match('torch.(%a+)')](dp[1]:storage(), dp[1]:storageOffset(),
dp[#dp]:storageOffset()+dp[#dp]:numel()
- dp[1]:storageOffset())
return {pp, dpp}
@@ -340,10 +342,12 @@ function DataParallelTable:reset(stdv)
end
function DataParallelTable:type(typeStr)
- assert(typeStr == 'torch.CudaTensor', 'DataParallelTable supports only torch.CudaTensor type')
+ assert(typeStr == 'torch.CudaHalfTensor' or typeStr == 'torch.CudaTensor' or typeStr == 'torch.CudaDoubleTensor',
+ 'DataParallelTable supports only torch.CudaHalfTensor or torch.CudaDoubleTensor or torch.CudaTensor types')
for i, m in ipairs(self.modules) do
m:type(typeStr)
end
+ self.typeStr = typeStr
return self
end
@@ -503,7 +507,7 @@ function DataParallelTable:_reduce(gradParams)
local dstGpuid = self.gpuAssignments[1]
cutorch.setDevice(dstGpuid)
- self.buffer = self.buffer or torch.CudaTensor()
+ self.buffer = self.buffer or torch[self.typeStr:match('torch.(%a+)')]()
for moduleIdx = 2, #gradParams do
for paramIdx = 1, #gradParams[moduleIdx] do
local dst = gradParams[1][paramIdx]
@@ -545,10 +549,18 @@ function DataParallelTable:_distributeTensorRecursive(dst, src, idx, n)
end
assert(torch.isTensor(src), 'input must be a tensor or table of tensors')
- assert(src:type() == 'torch.CudaTensor' or src:type() == 'torch.FloatTensor',
- 'input must be a CUDA or Float tensor')
+ if self.typeStr == 'torch.CudaHalfTensor' then
+ assert(false,
+ 'Half Tensors not supported yet by DataParallelTable')
+ elseif self.typeStr == 'torch.CudaDoubleTensor' then
+ assert(src:type() == self.typeStr or src:type() == 'torch.DoubleTensor',
+ 'input must be a CudaDouble or Double tensor')
+ else
+ assert(src:type() == 'torch.CudaTensor' or src:type() == 'torch.FloatTensor',
+ 'input must be a CUDA or Float tensor')
+ end
- dst = torch.type(dst) == 'torch.CudaTensor' and dst or torch.CudaTensor()
+ dst = torch.type(dst) == self.typeStr and dst or torch[self.typeStr:match('torch.(%a+)')]()
local srcsize = src:dim() > 0 and src:size(self.dimension) or 0
local index, size = sliceRange(srcsize, idx, n)
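In short, the scatter path now pairs each CUDA type with the CPU type it accepts as input, and rejects half until the copy path supports it. Restated as a table (a summary of the hunk above, not new behavior):

-- Accepted input tensor types per DataParallelTable type:
local acceptedInputs = {
   ['torch.CudaTensor']       = { 'torch.CudaTensor', 'torch.FloatTensor' },
   ['torch.CudaDoubleTensor'] = { 'torch.CudaDoubleTensor', 'torch.DoubleTensor' },
   ['torch.CudaHalfTensor']   = {},  -- 'Half Tensors not supported yet by DataParallelTable'
}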
@@ -590,7 +602,7 @@ function DataParallelTable:_concatTensorRecursive(dst, src)
assert(torch.isTensor(src[1]), 'input must be a tensor or table of tensors')
cutorch.setDevice(self.gpuAssignments[1])
- dst = torch.type(dst) == 'torch.CudaTensor' and dst or torch.CudaTensor()
+ dst = torch.type(dst) == self.typeStr and dst or torch[self.typeStr:match('torch.(%a+)')]()
local cumsum = sumSizes(src, self.dimension)
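With those hunks in place, a DataParallelTable is switched to double precision through the usual :type() call; only the three CUDA types are accepted, and CPU types still raise an error. A usage sketch under this patch's assumptions (cunn loaded, at least one CUDA device; names are illustrative):

require 'cunn'

local net = nn.Linear(10, 10):type('torch.CudaDoubleTensor')
local dpt = nn.DataParallelTable(1)
for i = 1, cutorch.getDeviceCount() do
   cutorch.withDevice(i, function() dpt:add(net:clone(), i) end)
end
dpt:type('torch.CudaDoubleTensor')  -- accepted after this patch

-- CPU types are still rejected:
assert(not pcall(function() dpt:type('torch.DoubleTensor') end))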
diff --git a/test_DataParallelTable.lua b/test_DataParallelTable.lua
index 2758ad7..2b25cf2 100644
--- a/test_DataParallelTable.lua
+++ b/test_DataParallelTable.lua
@@ -11,8 +11,39 @@ torch.setnumthreads(8)
cutorch.setDevice(baseGpu)
cutorch.reserveStreams(1)
+local typenames = {
+ 'torch.CudaTensor',
+ 'torch.CudaDoubleTensor',
+}
+
+local t2cpu = {
+ ['torch.CudaTensor'] = 'torch.FloatTensor',
+ ['torch.CudaDoubleTensor'] = 'torch.DoubleTensor',
+
+}
+
+local function checkHalf()
+ if cutorch.hasHalf then
+ table.insert(typenames, 'torch.CudaHalfTensor')
+ t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
+ end
+end
+
+local function half_max_error(maxabs)
+ -- arbitrarily double the precision limit
+ return 2 * ((maxabs and (2^(math.floor(math.log(maxabs) / math.log(2)))) * (2^(-10))) or 0)
+end
+
+-- Per-type comparison tolerance (replaces the fixed precision constant below)
+function precision(typename, max_error)
+ if typename == 'torch.CudaHalfTensor' then
+ return 5e-2 + half_max_error(max_error)
+ else
+ return 1e-5
+ end
+end
+
-- Create an instance of the test framework
-local precision = 1e-5
local mytester = torch.Tester()
local test = torch.TestSuite()
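The half tolerance above follows from the 10-bit significand of IEEE half precision: near a magnitude m, representable values are spaced roughly 2^floor(log2(m)) * 2^-10 apart, and half_max_error doubles that spacing for headroom before precision() adds it to the 5e-2 base. A self-contained check of the arithmetic (illustrative values only):

-- For maxabs = 3: floor(log2(3)) = 1, so the spacing near 3 is 2^1 * 2^-10 = 2^-9,
-- half_max_error(3) = 2 * 2^-9 = 2^-8, and the half threshold is 0.05 + 2^-8.
print(2 * 2^(math.floor(math.log(3) / math.log(2))) * 2^(-10))  -- 0.00390625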
@@ -73,7 +104,14 @@ local function deserialize(file)
return net
end
+
function test.DataParallelTable()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable(typename)
+ end
+end
+
+function test_DataParallelTable(gtype)
local width = 16
local height = 16
local pool = 4
@@ -93,11 +131,11 @@ function test.DataParallelTable()
numConvs)
-- Build a multi-GPU model
- local gClassifier = nn.DataParallelTable(1)
+ local gClassifier = nn.DataParallelTable(1):type(gtype)
for i = 1, numGpus do
local curGpu = math.fmod(baseGpu+(i-1)-1, cutorch.getDeviceCount()) + 1
cutorch.setDevice(curGpu)
- gClassifier:add(cpuClassifier:clone():cuda(), curGpu)
+ gClassifier:add(cpuClassifier:clone():type(gtype), curGpu)
end
cutorch.setDevice(baseGpu)
@@ -108,14 +146,14 @@ function test.DataParallelTable()
cNet:add(createSplitNetwork(2,3))
cNet:add(cpuClassifier)
cNet:add(nn.JoinTable(2))
- cNet:cuda()
+ cNet:type(gtype)
local gNet = nn.Sequential()
gNet:add(createSplitNetwork(2,3))
gNet:add(gClassifier)
- gNet:add(nn.JoinTable(2):cuda())
- gNet:get(1):cuda()
- gNet:get(3):cuda()
+ gNet:add(nn.JoinTable(2):type(gtype))
+ gNet:get(1):type(gtype)
+ gNet:get(3):type(gtype)
-- Force in a serialization / deserialization pass ------------
local file = serialize(gNet)
@@ -125,10 +163,10 @@ function test.DataParallelTable()
gNet = deserialize(file)
----------------------------------------------------------------
- local cInput = torch.rand(batchSize, 3, height, width):cuda()
- local gInput = cInput:cuda()
- local cTarget = torch.rand(batchSize, 2):cuda()
- local gTarget = cTarget:cuda():cuda()
+ local cInput = torch.rand(batchSize, 3, height, width):type(gtype)
+ local gInput = cInput:type(gtype)
+ local cTarget = torch.rand(batchSize, 2):type(gtype)
+ local gTarget = cTarget:type(gtype)
local cParams, cGradParams = cNet:getParameters()
local gParams, gGradParams = gNet:getParameters()
@@ -148,8 +186,8 @@ function test.DataParallelTable()
local optimStateGpu = copyTable(optimStateCpu)
local optimMethod = optim.sgd
- local criterionCpu = nn.MSECriterion():cuda()
- local criterionGpu = criterionCpu:clone():cuda()
+ local criterionCpu = nn.MSECriterion():type(gtype)
+ local criterionGpu = criterionCpu:clone():type(gtype)
for i = 1, numSgdSteps do
collectgarbage()
@@ -194,17 +232,18 @@ function test.DataParallelTable()
local cGradInput = cNet.gradInput
local gGradInput = gNet.gradInput
- mytester:assertlt((cOutput:float() - gOutput:float()):abs():max(),
- precision, 'fprop error ')
- mytester:assertlt((criterionCpu.gradInput:float() -
- criterionGpu.gradInput:float()):abs():max(), precision,
- 'CRITERION BPROP error ')
- mytester:assertlt((cParams:float() - gParams:float()):abs():max(),
- precision, 'parameters error ')
- mytester:assertlt((cGradParams:float() - gGradParams:float()):abs():max(),
- precision, 'BPROP error (gradParams)')
- mytester:assertlt((cGradInput:float() - gGradInput:float()):abs():max(),
- precision, 'BPROP error (gradInput)')
+ mytester:assertlt((cOutput:double() - gOutput:double()):abs():max(),
+ precision(gtype, cOutput:clone():double():abs():max()), 'fprop error ' .. gtype)
+ mytester:assertlt((criterionCpu.gradInput:double() -
+ criterionGpu.gradInput:double()):abs():max(),
+ precision(gtype, criterionGpu.gradInput:clone():double():abs():max()),
+ 'CRITERION BPROP error ' .. gtype)
+ mytester:assertlt((cParams:double() - gParams:double()):abs():max(),
+ precision(gtype, cParams:clone():double():abs():max()), 'parameters error ' .. gtype)
+ mytester:assertlt((cGradParams:double() - gGradParams:double()):abs():max(),
+ precision(gtype, cGradParams:clone():double():abs():max()), 'BPROP error (gradParams) ' .. gtype)
+ mytester:assertlt((cGradInput:double() - gGradInput:double()):abs():max(),
+ precision(gtype, cGradInput:clone():double():abs():max()), 'BPROP error (gradInput) ' .. gtype)
-- Sync the CPU and GPU weights every few "epochs" to prevent floating point
-- drift between SGD iterations (ie, they will eventually be divergent after
@@ -222,69 +261,94 @@ function test.DataParallelTable()
end
function test.DataParallelTable_smallBatch()
- local net = nn.SpatialConvolution(3, 3, 3, 5):cuda()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_smallBatch(typename)
+ end
+end
- local dpt = nn.DataParallelTable(1)
+function test_DataParallelTable_smallBatch(gtype)
+ local net = nn.SpatialConvolution(3, 3, 3, 5):type(gtype)
+
+ local dpt = nn.DataParallelTable(1):type(gtype)
for i=1,numGpus do
cutorch.withDevice(i, function()
- dpt:add(net:clone():cuda(), i)
+ dpt:add(net:clone():type(gtype), i)
end)
end
-- Check for batches that are smaller than numGpus or don't divide evenly
for _,batchSize in ipairs{numGpus-1,2*numGpus-1} do
- local input = torch.CudaTensor(batchSize,3,10,10):uniform(-1, 1)
+ local input = torch[gtype:match('torch.(%a+)')](batchSize,3,10,10):uniform(-1, 1)
-- Check that forward works as expected
local output = dpt:forward(input)
local expected = net:forward(input)
- assert((expected - output):abs():max() < precision, 'unexpected output')
+ assert((expected - output):abs():max() < precision(gtype, expected:clone():abs():max()), 'unexpected output')
local gradOutput = output:clone():uniform(-1, 1)
local gradInput = dpt:updateGradInput(input, gradOutput)
local expected = net:updateGradInput(input, gradOutput)
- assert((expected - gradInput):abs():max() < precision, 'unexpected gradInput')
+ assert((expected - gradInput):abs():max() < precision(gtype, expected:clone():abs():max()), 'unexpected gradInput')
end
end
function test.DataParallelTable_emptyTensor()
- local net = nn.Sequential():add(nn.SelectTable(2)):add(nn.Linear(10,2)):cuda()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_emptyTensor(typename)
+ end
+end
- local dpt = nn.DataParallelTable(1)
+function test_DataParallelTable_emptyTensor(gtype)
+ local net = nn.Sequential():add(nn.SelectTable(2)):add(nn.Linear(10,2)):type(gtype)
+
+ local dpt = nn.DataParallelTable(1):type(gtype)
for i=1,numGpus do
cutorch.withDevice(i, function()
- dpt:add(net:clone():cuda(), i)
+ dpt:add(net:clone():type(gtype), i)
end)
end
- local input = {torch.CudaTensor(0), torch.CudaTensor(numGpus, 10):fill(1)}
+ local input = {torch[gtype:match('torch.(%a+)')](0), torch[gtype:match('torch.(%a+)')](numGpus, 10):fill(1)}
local output = dpt:forward(input)
local expected = net:forward(input)
- assert((output - expected ):abs():max() < precision, 'unexpected output')
+ assert((output - expected ):abs():max() < precision(gtype, expected:clone():abs():max()), 'unexpected output')
local gradOutput = output:clone():uniform(-1,1)
local gradInput = dpt:backward(input, gradOutput)
local expected = net:backward(input, gradOutput)
- assert((expected[2] - gradInput[2]):abs():max() < precision, 'unexpected gradInput')
+ assert((expected[2] - gradInput[2]):abs():max() < precision(gtype, expected[2]:clone():abs():max()), 'unexpected gradInput')
end
function test.DataParallelTable_type()
- local net = nn.SpatialConvolution(3, 3, 3, 5)
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_type(typename)
+ end
+end
- local dpt = nn.DataParallelTable(1)
+function test_DataParallelTable_type(gtype)
+ local ctype = t2cpu[gtype]
+ local net = nn.SpatialConvolution(3, 3, 3, 5):type(ctype)
+
+ local dpt = nn.DataParallelTable(1):type(gtype)
for i=1,numGpus do
cutorch.withDevice(i, function()
dpt:add(net:clone(), i)
end)
end
- dpt:cuda()
+ dpt:type(gtype)
- ok = pcall(function() dpt:float() end)
- assert(not ok, 'should not be able to call DataParallelTable:float()')
+ ok = pcall(function() dpt:type(ctype) end)
+ assert(not ok, 'should not be able to call DataParallelTable:type(' .. ctype .. ')')
end
function test.DataParallelTable_sync()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_sync(typename)
+ end
+end
+
+function test_DataParallelTable_sync(gtype)
-- Test that DataParallelTable automatically syncParameters in updateOutput
-- if you forget to call :syncParameters()
local nSteps = 10
@@ -292,24 +356,24 @@ function test.DataParallelTable_sync()
:add(nn.Linear(10, 10))
:add(nn.ReLU(true))
:add(nn.Linear(10, 10))
- :cuda()
+ :type(gtype)
- local dpt = nn.DataParallelTable(1)
+ local dpt = nn.DataParallelTable(1):type(gtype)
for i=1,numGpus do
cutorch.withDevice(i, function()
dpt:add(net:clone(), i)
end)
end
- local criterion = nn.MSECriterion():cuda()
+ local criterion = nn.MSECriterion():type(gtype)
local optimState = {
learningRate = 1,
momentum = 0,
}
- local input = torch.CudaTensor(numGpus,10)
- local target = torch.CudaTensor(numGpus,10)
+ local input = torch[gtype:match('torch.(%a+)')](numGpus,10)
+ local target = torch[gtype:match('torch.(%a+)')](numGpus,10)
local function feval(net)
local params, gradParams = net:getParameters()
@@ -333,25 +397,31 @@ function test.DataParallelTable_sync()
optim.sgd(fevalBase, paramsBase, optimState)
end
- assert((paramsDpt - paramsBase):abs():max() < precision,
+ assert((paramsDpt - paramsBase):abs():max() < precision(gtype, paramsDpt:clone():abs():max()),
'parameters do not match')
end
function test.DataParallelTable_serialize()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_serialize(typename)
+ end
+end
+
+function test_DataParallelTable_serialize(gtype)
-- Test serialization after getParameters()
- local net = nn.Linear(10, 10):cuda()
+ local net = nn.Linear(10, 10):type(gtype)
- local dpt = nn.DataParallelTable(1)
+ local dpt = nn.DataParallelTable(1):type(gtype)
for i=1,numGpus do
cutorch.withDevice(i, function()
- dpt:add(net:clone():cuda(), i)
+ dpt:add(net:clone():type(gtype), i)
end)
end
dpt:getParameters()
dpt = deserialize(serialize(dpt))
- local input = torch.CudaTensor(numGpus,10):uniform(-1, 1)
+ local input = torch[gtype:match('torch.(%a+)')](numGpus,10):uniform(-1, 1)
-- Check that forward works as expected
local output = dpt:forward(input)
@@ -368,17 +438,23 @@ end
function test.DataParallelTable_flattenParameters()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_flattenParameters(typename)
+ end
+end
+
+function test_DataParallelTable_flattenParameters(gtype)
-- Wrap only a part of a network with data parallel table and
-- check if the correct number of parameters have been copied
local seq = nn.Sequential()
- local layer1 = nn.Linear(10, 10):cuda()
- local layer2 = nn.Linear(10, 5):cuda()
- local dpt = nn.DataParallelTable(1, true, true):threads():cuda()
+ local layer1 = nn.Linear(10, 10):type(gtype)
+ local layer2 = nn.Linear(10, 5):type(gtype)
+ local dpt = nn.DataParallelTable(1, true, true):threads():type(gtype)
dpt:add(layer2, torch.range(1, numGpus):totable())
seq:add(layer1):add(dpt)
seq:getParameters()
- local input = torch.randn(7, 10):cuda()
+ local input = torch.randn(7, 10):type(gtype)
seq:forward(input)
-- There are 55 parameters in layer 2 (50 + 5 bias weights)
assert(dpt.flattenedParams[1][1]:size(1) == 55, "Incorrect number of " ..
@@ -389,17 +465,23 @@ function test.DataParallelTable_flattenParameters()
end
function test.DataParallelTable_misc()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_misc(typename)
+ end
+end
+
+function test_DataParallelTable_misc(gtype)
local net = nn.Sequential()
:add(nn.Linear(3, 10))
:add(nn.ReLU())
:add(nn.Linear(10, 7))
- local dpt = nn.DataParallelTable(1)
+ local dpt = nn.DataParallelTable(1):type(gtype)
:add(net, torch.range(1, numGpus):totable())
:threads()
- :cuda()
+ :type(gtype)
- local input = torch.randn(8, 3):cuda()
+ local input = torch.randn(8, 3):type(gtype)
local output = dpt:forward(input)
-- check that clone works
@@ -417,41 +499,53 @@ function test.DataParallelTable_misc()
end
function test.DataParallelTable_noGradInput()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_noGradInput(typename)
+ end
+end
+
+function test_DataParallelTable_noGradInput(gtype)
local net = nn.Sequential()
:add(nn.LookupTable(10, 10))
:add(nn.Linear(10, 7))
:add(nn.ReLU())
- :cuda()
+ :type(gtype)
local dpt = nn.DataParallelTable(1)
:add(net, torch.range(1, numGpus):totable())
:threads()
- :cuda()
+ :type(gtype)
- local input = torch.Tensor(5):random(10):cuda()
+ local input = torch.Tensor(5):random(10):type(gtype)
local output1 = net:forward(input):clone()
local gradOutput = output1:clone():uniform(-1, 1)
local gradInput1 = net:backward(input, gradOutput):clone()
local output2 = dpt:forward(input)
local gradInput2 = dpt:backward(input, gradOutput)
- mytester:assertlt((output1 - output2):abs():max(), precision,
+ mytester:assertlt((output1 - output2):abs():max(), precision(gtype, output1:clone():abs():max()),
'forward prop error')
mytester:asserteq(gradInput2:nElement(), gradInput1:nElement())
end
function test.DataParallelTable_accGradParameters()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_accGradParameters(typename)
+ end
+end
+
+function test_DataParallelTable_accGradParameters(gtype)
local net = nn.Sequential()
:add(nn.Linear(3, 10))
:add(nn.ReLU())
:add(nn.Linear(10, 7))
- :cuda()
+ :type(gtype)
local inputs = {}
local gradOutputs = {}
for i=1,3 do
- inputs[i] = torch.randn(8, 3):cuda()
- gradOutputs[i] = torch.randn(8, 7):cuda()
+ inputs[i] = torch.randn(8, 3):type(gtype)
+ gradOutputs[i] = torch.randn(8, 7):type(gtype)
end
local configs = {
@@ -474,25 +568,31 @@ function test.DataParallelTable_accGradParameters()
for _, config in ipairs(configs) do
local dpt = nn.DataParallelTable(table.unpack(config))
- :add(net:clone(), torch.range(1, numGpus):totable())
+ :add(net:clone(), torch.range(1, numGpus):totable()):type(gtype)
accumulateGradient(dpt)
local output = dpt:forward(inputs[1])
- mytester:assertlt((output - expected):abs():max(), 1e-5, 'invalid output')
+ mytester:assertlt((output - expected):abs():max(), precision(gtype, expected:clone():abs():max()), 'invalid output ' .. gtype)
end
end
function test.DataParallelTable_apply()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_apply(typename)
+ end
+end
+
+function test_DataParallelTable_apply(gtype)
local net = nn.Sequential()
:add(nn.Linear(3, 10))
:add(nn.ReLU())
:add(nn.Linear(10, 7))
- :cuda()
+ :type(gtype)
local inputs = {}
local gradOutputs = {}
for i=1,3 do
- inputs[i] = torch.randn(8, 3):cuda()
- gradOutputs[i] = torch.randn(8, 7):cuda()
+ inputs[i] = torch.randn(8, 3):type(gtype)
+ gradOutputs[i] = torch.randn(8, 7):type(gtype)
end
local configs = {
@@ -521,13 +621,13 @@ function test.DataParallelTable_apply()
for _, usethreads in ipairs{false,true} do
for _, config in ipairs(configs) do
local dpt = nn.DataParallelTable(table.unpack(config))
- :add(net:clone(), torch.range(1, numGpus):totable())
+ :add(net:clone(), torch.range(1, numGpus):totable()):type(gtype)
if usethreads then
dpt:threads()
end
trainNetwork(dpt)
local output = dpt:forward(inputs[1])
- mytester:assertlt((output - expected):abs():max(), 1e-5,
+ mytester:assertlt((output - expected):abs():max(), precision(gtype, expected:clone():abs():max()),
'invalid output: flatten=' .. tostring(config[2]) ..
' threads=' .. tostring(usethreads))
end
@@ -535,14 +635,20 @@ function test.DataParallelTable_apply()
end
function test.DataParallelTable_streams()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_streams(typename)
+ end
+end
+
+function test_DataParallelTable_streams(gtype)
local net = nn.Sequential()
:add(nn.Linear(3, 10))
:add(nn.ReLU())
:add(nn.Linear(10, 7))
- :cuda()
+ :type(gtype)
- local input = torch.randn(8, 3):cuda()
- local gradOutput = torch.randn(8, 7):cuda()
+ local input = torch.randn(8, 3):type(gtype)
+ local gradOutput = torch.randn(8, 7):type(gtype)
local gOutput = net:forward(input):clone()
net:zeroGradParameters()
local gGradInput = net:backward(input, gradOutput):clone()
@@ -569,7 +675,7 @@ function test.DataParallelTable_streams()
for _, threads in ipairs{false, true} do
local dpt = nn.DataParallelTable(table.unpack(config))
:add(net, torch.range(1, numGpus):totable())
- :cuda()
+ :type(gtype)
if threads then
dpt:threads(function()
cutorch.reserveStreams(1)
@@ -584,6 +690,12 @@ function test.DataParallelTable_streams()
end
function test.DataParallelTable_emptyData()
+ for k, typename in ipairs(typenames) do
+ test_DataParallelTable_emptyData(typename)
+ end
+end
+
+function test_DataParallelTable_emptyData(gtype)
local function eq(a,b)
if not torch.isTensor(a) then
local res = true
@@ -601,11 +713,11 @@ function test.DataParallelTable_emptyData()
local a = nn.DataParallelTable(1)
a:add(identity, torch.range(1,numGpus):totable())
- a:cuda()
+ a:type(gtype)
- local inputs = {torch.range(1,numGpus*5):reshape(numGpus,5):cuda(),
- torch.range(1,5):reshape(1,5):cuda(),
- torch.range(1,10):reshape(2,5):cuda(),
+ local inputs = {torch.range(1,numGpus*5):reshape(numGpus,5):type(gtype),
+ torch.range(1,5):reshape(1,5):type(gtype),
+ torch.range(1,10):reshape(2,5):type(gtype),
}
for _, input in ipairs(inputs) do
@@ -617,7 +729,7 @@ function test.DataParallelTable_emptyData()
a = nn.DataParallelTable(1)
a:add(nn.ParallelTable():add(identity):add(identity), torch.range(1,numGpus):totable())
- a:cuda()
+ a:type(gtype)
for _, input in ipairs(inputs) do
input = {input, input}
@@ -630,6 +742,12 @@ end
function test.ProfileDataParallelTable()
+ for k, typename in ipairs(typenames) do
+ test_ProfileDataParallelTable(typename)
+ end
+end
+
+function test_ProfileDataParallelTable(gtype)
local width = 32
local height = 32
local pool = 4
@@ -661,17 +779,19 @@ function test.ProfileDataParallelTable()
local gNet = module(1)
if (moduleName == 'DataParallel') then
cutorch.setDevice(baseGpu)
- gNet:cuda()
+ gNet:type(gtype)
+ elseif (moduleName == 'DataParallelTable') then
+ gNet:type(gtype)
end
for i = 1, numGpus do
local curGpu = math.fmod(baseGpu+(i-1)-1, cutorch.getDeviceCount())+1
cutorch.setDevice(curGpu)
- gNet:add(cNet:clone():cuda(), curGpu)
+ gNet:add(cNet:clone():type(gtype), curGpu)
end
cutorch.setDevice(baseGpu)
- local input = torch.rand(batchSize, 3, height, width):cuda()
- local target = torch.rand(batchSize, 2):cuda()
+ local input = torch.rand(batchSize, 3, height, width):type(gtype)
+ local target = torch.rand(batchSize, 2):type(gtype)
local gParams, gGradParams
if (moduleName == 'DataParallelTable') then
@@ -695,7 +815,7 @@ function test.ProfileDataParallelTable()
nesterov = true,
}
local optimMethod = optim.sgd
- local criterion = nn.MSECriterion():cuda()
+ local criterion = nn.MSECriterion():type(gtype)
local timeGpuNet = 0
local opt
@@ -742,5 +862,6 @@ function test.ProfileDataParallelTable()
end
-- Now run the test above
+--checkHalf() -- half not enabled yet for DataParallelTable
mytester:add(test)
mytester:run()
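For reference, the suite is self-running: executing the file through the usual Torch workflow (presumably th test_DataParallelTable.lua, with cunn installed and two or more GPUs visible) runs every test once per entry in typenames. Re-enabling the commented checkHalf() call above would add torch.CudaHalfTensor to that list once DataParallelTable gains half support.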