author     Soumith Chintala <soumith@gmail.com>   2016-12-16 23:34:11 +0300
committer  GitHub <noreply@github.com>            2016-12-16 23:34:11 +0300
commit     0a48d8e9430c47d6074bdae3a3d5341c3c5a75ff (patch)
tree       5a3fde2607af08211e0608a910297275fa0f15dd
parent     9a338fa4a61c7622aff75ee4d92adddd2accebf8 (diff)
parent     6a2e7ef0013250fe97726251f7b8276885f4f1d0 (diff)
Merge pull request #401 from gchanan/dataParallelTablePrecision
Add support for Double tensors to DataParallelTable.
-rw-r--r--  DataParallelTable.lua       |  28
-rw-r--r--  test_DataParallelTable.lua  | 293
2 files changed, 227 insertions, 94 deletions
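The change in a nutshell: DataParallelTable previously hard-coded torch.CudaTensor everywhere; the patch records the type set via :type() in self.typeStr and threads it through parameter flattening, gradient reduction, and input distribution. A minimal usage sketch of what this enables (the model, sizes, and setup below are illustrative placeholders, not part of the patch):

    require 'cunn'

    local numGpus = cutorch.getDeviceCount()
    local dpt = nn.DataParallelTable(1)
    for i = 1, numGpus do
       dpt:add(nn.Linear(10, 10):cuda(), i)
    end

    -- Previously only 'torch.CudaTensor' was accepted here; the patch also
    -- allows 'torch.CudaDoubleTensor' (and 'torch.CudaHalfTensor' at the
    -- :type() level, though half inputs are still rejected further down).
    dpt:type('torch.CudaDoubleTensor')

    -- Inputs must now match the table's type: CudaDouble or Double tensors
    -- for a CudaDouble-typed table (see _distributeTensorRecursive below).
    local input = torch.CudaDoubleTensor(numGpus, 10):uniform(-1, 1)
    local output = dpt:forward(input)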
diff --git a/DataParallelTable.lua b/DataParallelTable.lua
index 102be72..e0194d4 100644
--- a/DataParallelTable.lua
+++ b/DataParallelTable.lua
@@ -47,6 +47,7 @@ function DataParallelTable:__init(dimension, flattenParams, usenccl)
       error "must specify a dimension!"
    end
 
+   self.typeStr = 'torch.CudaTensor'
    self.dimension = dimension
    self.modules = {}
    self.gpuAssignments = {}  -- Which gpuid each module sits on
@@ -101,6 +102,7 @@ end
 -- this flattens parameters, so that syncParameters and accGradParameters can be much more efficient
 function DataParallelTable:flattenParameters()
+   local typeStr = self.typeStr
    self.flattenedParams = self.impl:exec(function(module)
       local p, dp = module:parameters()
       local flattened = true
@@ -112,9 +114,9 @@ function DataParallelTable:flattenParameters()
          end
       end
       if flattened then
-         local pp = torch.CudaTensor(p[1]:storage(), p[1]:storageOffset(),
+         local pp = torch[typeStr:match('torch.(%a+)')](p[1]:storage(), p[1]:storageOffset(),
                   p[#p]:storageOffset()+p[#p]:numel()-p[1]:storageOffset())
-         local dpp = torch.CudaTensor(dp[1]:storage(), dp[1]:storageOffset(),
+         local dpp = torch[typeStr:match('torch.(%a+)')](dp[1]:storage(), dp[1]:storageOffset(),
                   dp[#dp]:storageOffset()+dp[#dp]:numel() - dp[1]:storageOffset())
          return {pp, dpp}
@@ -340,10 +342,12 @@ function DataParallelTable:reset(stdv)
 end
 
 function DataParallelTable:type(typeStr)
-   assert(typeStr == 'torch.CudaTensor', 'DataParallelTable supports only torch.CudaTensor type')
+   assert(typeStr == 'torch.CudaHalfTensor' or typeStr == 'torch.CudaTensor' or typeStr == 'torch.CudaDoubleTensor',
+      'DataParallelTable supports only torch.CudaHalfTensor or torch.CudaDoubleTensor or torch.CudaTensor types')
    for i, m in ipairs(self.modules) do
       m:type(typeStr)
    end
+   self.typeStr = typeStr
    return self
 end
@@ -503,7 +507,7 @@ function DataParallelTable:_reduce(gradParams)
    local dstGpuid = self.gpuAssignments[1]
    cutorch.setDevice(dstGpuid)
 
-   self.buffer = self.buffer or torch.CudaTensor()
+   self.buffer = self.buffer or torch[self.typeStr:match('torch.(%a+)')]()
    for moduleIdx = 2, #gradParams do
       for paramIdx = 1, #gradParams[moduleIdx] do
          local dst = gradParams[1][paramIdx]
@@ -545,10 +549,18 @@ end
 
    assert(torch.isTensor(src), 'input must be a tensor or table of tensors')
-   assert(src:type() == 'torch.CudaTensor' or src:type() == 'torch.FloatTensor',
-      'input must be a CUDA or Float tensor')
+   if self.typeStr == 'torch.CudaHalfTensor' then
+      assert(false,
+         'Half Tensors not supported yet by DataParallelTable')
+   elseif self.typeStr == 'torch.CudaDoubleTensor' then
+      assert(src:type() == self.typeStr or src:type() == 'torch.DoubleTensor',
+         'input must be a CudaDouble or Double tensor')
+   else
+      assert(src:type() == 'torch.CudaTensor' or src:type() == 'torch.FloatTensor',
+         'input must be a CUDA or Float tensor')
+   end
 
-   dst = torch.type(dst) == 'torch.CudaTensor' and dst or torch.CudaTensor()
+   dst = torch.type(dst) == self.typeStr and dst or torch[self.typeStr:match('torch.(%a+)')]()
 
    local srcsize = src:dim() > 0 and src:size(self.dimension) or 0
    local index, size = sliceRange(srcsize, idx, n)
@@ -590,7 +602,7 @@ function DataParallelTable:_concatTensorRecursive(dst, src)
    assert(torch.isTensor(src[1]), 'input must be a tensor or table of tensors')
    cutorch.setDevice(self.gpuAssignments[1])
 
-   dst = torch.type(dst) == 'torch.CudaTensor' and dst or torch.CudaTensor()
+   dst = torch.type(dst) == self.typeStr and dst or torch[self.typeStr:match('torch.(%a+)')]()
 
    local cumsum = sumSizes(src, self.dimension)
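A note on the torch[typeStr:match('torch.(%a+)')] idiom that recurs above: Torch type strings have the form 'torch.XTensor', and the matching constructor lives at torch.XTensor, so stripping the prefix and indexing the torch table recovers the constructor for any supported type. A small standalone sketch (assuming cutorch is loaded):

    local typeStr = 'torch.CudaDoubleTensor'
    -- '%a+' captures the letters after 'torch.', i.e. 'CudaDoubleTensor'
    local ctor = torch[typeStr:match('torch.(%a+)')]
    local t = ctor(4, 4)  -- equivalent to torch.CudaDoubleTensor(4, 4)
    assert(torch.type(t) == typeStr)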
diff --git a/test_DataParallelTable.lua b/test_DataParallelTable.lua
index 2758ad7..2b25cf2 100644
--- a/test_DataParallelTable.lua
+++ b/test_DataParallelTable.lua
@@ -11,8 +11,39 @@ torch.setnumthreads(8)
 cutorch.setDevice(baseGpu)
 cutorch.reserveStreams(1)
 
+local typenames = {
+   'torch.CudaTensor',
+   'torch.CudaDoubleTensor',
+}
+
+local t2cpu = {
+   ['torch.CudaTensor'] = 'torch.FloatTensor',
+   ['torch.CudaDoubleTensor'] = 'torch.DoubleTensor',
+
+}
+
+local function checkHalf()
+   if cutorch.hasHalf then
+      table.insert(typenames, 'torch.CudaHalfTensor')
+      t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
+   end
+end
+
+local function half_max_error(maxabs)
+   -- arbitrarily double the precision limit
+   return 2 * ((maxabs and (2^(math.floor(math.log(maxabs) / math.log(2)))) * (2^(-10))) or 0)
+end
+
+-- Create an instance of the test framework
+function precision(typename, max_error)
+   if typename == 'torch.CudaHalfTensor' then
+      return 5e-2 + half_max_error(max_error)
+   else
+      return 1e-5
+   end
+end
+
 -- Create an instance of the test framework
-local precision = 1e-5
 local mytester = torch.Tester()
 local test = torch.TestSuite()
@@ -73,7 +104,14 @@ local function deserialize(file)
    return net
 end
 
+
 function test.DataParallelTable()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable(typename)
+   end
+end
+
+function test_DataParallelTable(gtype)
    local width = 16
    local height = 16
    local pool = 4
@@ -93,11 +131,11 @@ function test.DataParallelTable()
       numConvs)
 
    -- Build a multi-GPU model
-   local gClassifier = nn.DataParallelTable(1)
+   local gClassifier = nn.DataParallelTable(1):type(gtype)
    for i = 1, numGpus do
      local curGpu = math.fmod(baseGpu+(i-1)-1, cutorch.getDeviceCount()) + 1
      cutorch.setDevice(curGpu)
-      gClassifier:add(cpuClassifier:clone():cuda(), curGpu)
+      gClassifier:add(cpuClassifier:clone():type(gtype), curGpu)
    end
    cutorch.setDevice(baseGpu)
@@ -108,14 +146,14 @@ function test.DataParallelTable()
    cNet:add(createSplitNetwork(2,3))
    cNet:add(cpuClassifier)
    cNet:add(nn.JoinTable(2))
-   cNet:cuda()
+   cNet:type(gtype)
 
    local gNet = nn.Sequential()
    gNet:add(createSplitNetwork(2,3))
    gNet:add(gClassifier)
-   gNet:add(nn.JoinTable(2):cuda())
-   gNet:get(1):cuda()
-   gNet:get(3):cuda()
+   gNet:add(nn.JoinTable(2):type(gtype))
+   gNet:get(1):type(gtype)
+   gNet:get(3):type(gtype)
 
    -- Force in a serialization / deserialization pass ------------
    local file = serialize(gNet)
@@ -125,10 +163,10 @@ function test.DataParallelTable()
    gNet = deserialize(file)
    ----------------------------------------------------------------
 
-   local cInput = torch.rand(batchSize, 3, height, width):cuda()
-   local gInput = cInput:cuda()
-   local cTarget = torch.rand(batchSize, 2):cuda()
-   local gTarget = cTarget:cuda():cuda()
+   local cInput = torch.rand(batchSize, 3, height, width):type(gtype)
+   local gInput = cInput:type(gtype)
+   local cTarget = torch.rand(batchSize, 2):type(gtype)
+   local gTarget = cTarget:type(gtype):type(gtype)
 
    local cParams, cGradParams = cNet:getParameters()
    local gParams, gGradParams = gNet:getParameters()
@@ -148,8 +186,8 @@ function test.DataParallelTable()
    local optimStateGpu = copyTable(optimStateCpu)
    local optimMethod = optim.sgd
 
-   local criterionCpu = nn.MSECriterion():cuda()
-   local criterionGpu = criterionCpu:clone():cuda()
+   local criterionCpu = nn.MSECriterion():type(gtype)
+   local criterionGpu = criterionCpu:clone():type(gtype)
 
    for i = 1, numSgdSteps do
      collectgarbage()
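The half_max_error/precision helpers introduced above encode a rough ulp bound: an IEEE half float has 10 explicit mantissa bits, so the gap between adjacent representable values near magnitude m is about 2^floor(log2(m)) * 2^-10, and the helper doubles that for slack; precision() adds it to a 5e-2 base tolerance for half while the other types keep the old fixed 1e-5. A quick standalone check of the arithmetic:

    -- same formula as half_max_error above
    local function half_max_error(maxabs)
       return 2 * ((maxabs and (2^(math.floor(math.log(maxabs) / math.log(2)))) * (2^(-10))) or 0)
    end

    print(half_max_error(1))    -- 2 * 2^0 * 2^-10  ~= 0.00195
    print(half_max_error(100))  -- 2 * 2^6 * 2^-10  =  0.125
    print(half_max_error(nil))  -- 0 (no magnitude known)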
@@ -194,17 +232,18 @@ function test.DataParallelTable()
      local cGradInput = cNet.gradInput
      local gGradInput = gNet.gradInput
 
-      mytester:assertlt((cOutput:float() - gOutput:float()):abs():max(),
-         precision, 'fprop error ')
-      mytester:assertlt((criterionCpu.gradInput:float() -
-         criterionGpu.gradInput:float()):abs():max(), precision,
-         'CRITERION BPROP error ')
-      mytester:assertlt((cParams:float() - gParams:float()):abs():max(),
-         precision, 'parameters error ')
-      mytester:assertlt((cGradParams:float() - gGradParams:float()):abs():max(),
-         precision, 'BPROP error (gradParams)')
-      mytester:assertlt((cGradInput:float() - gGradInput:float()):abs():max(),
-         precision, 'BPROP error (gradInput)')
+      mytester:assertlt((cOutput:double() - gOutput:double()):abs():max(),
+         precision(gtype, cOutput:clone():double():abs():max()), 'fprop error ' .. gtype)
+      mytester:assertlt((criterionCpu.gradInput:double() -
+         criterionGpu.gradInput:double()):abs():max(),
+         precision(gtype, criterionGpu.gradInput:clone():double():abs():max()),
+         'CRITERION BPROP error ' .. gtype)
+      mytester:assertlt((cParams:double() - gParams:double()):abs():max(),
+         precision(gtype, cParams:clone():double():abs():max()), 'parameters error ' .. gtype)
+      mytester:assertlt((cGradParams:double() - gGradParams:double()):abs():max(),
+         precision(gtype, cGradParams:clone():double():abs():max()), 'BPROP error (gradParams) ' .. gtype)
+      mytester:assertlt((cGradInput:double() - gGradInput:double()):abs():max(),
+         precision(gtype, cGradInput:clone():double():abs():max()), 'BPROP error (gradInput) ' .. gtype)
 
      -- Sync the CPU and GPU weights every few "epochs" to prevent floating point
      -- drift between SGD iterations (ie, they will eventually be divergent after
@@ -222,69 +261,94 @@ function test.DataParallelTable()
 end
 
 function test.DataParallelTable_smallBatch()
-   local net = nn.SpatialConvolution(3, 3, 3, 5):cuda()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_smallBatch(typename)
+   end
+end
 
-   local dpt = nn.DataParallelTable(1)
+function test_DataParallelTable_smallBatch(gtype)
+   local net = nn.SpatialConvolution(3, 3, 3, 5):type(gtype)
+
+   local dpt = nn.DataParallelTable(1):type(gtype)
    for i=1,numGpus do
      cutorch.withDevice(i, function()
-         dpt:add(net:clone():cuda(), i)
+         dpt:add(net:clone():type(gtype), i)
      end)
    end
 
    -- Check for batches that are smaller than numGpus or don't divide evenly
    for _,batchSize in ipairs{numGpus-1,2*numGpus-1} do
-      local input = torch.CudaTensor(batchSize,3,10,10):uniform(-1, 1)
+      local input = torch[gtype:match('torch.(%a+)')](batchSize,3,10,10):uniform(-1, 1)
 
      -- Check that forward works as expected
      local output = dpt:forward(input)
      local expected = net:forward(input)
-      assert((expected - output):abs():max() < precision, 'unexpected output')
+      assert((expected - output):abs():max() < precision(gtype, expected:clone():abs():max()), 'unexpected output')
 
      local gradOutput = output:clone():uniform(-1, 1)
      local gradInput = dpt:updateGradInput(input, gradOutput)
      local expected = net:updateGradInput(input, gradOutput)
-      assert((expected - gradInput):abs():max() < precision, 'unexpected gradInput')
+      assert((expected - gradInput):abs():max() < precision(gtype, expected:clone():abs():max()), 'unexpected gradInput')
    end
 end
 
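The rewrite pattern used for every test in this file is worth calling out once: each torch.Tester entry keeps its name but becomes a thin wrapper that re-runs a plain, type-parameterized function once per entry in typenames. A schematic of the pattern (SomeCase is a placeholder name, and typenames mirrors the file's own global):

    local typenames = {'torch.CudaTensor', 'torch.CudaDoubleTensor'}
    local test = torch.TestSuite()

    function test.SomeCase()                  -- what torch.Tester registers
       for _, typename in ipairs(typenames) do
          test_SomeCase(typename)             -- the real, per-type body
       end
    end

    function test_SomeCase(gtype)
       local x = torch.rand(4, 4):type(gtype)
       -- assertions compare against precision(gtype, <max magnitude>)
       -- instead of the old fixed 1e-5 constant
    end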
 function test.DataParallelTable_emptyTensor()
-   local net = nn.Sequential():add(nn.SelectTable(2)):add(nn.Linear(10,2)):cuda()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_emptyTensor(typename)
+   end
+end
 
-   local dpt = nn.DataParallelTable(1)
+function test_DataParallelTable_emptyTensor(gtype)
+   local net = nn.Sequential():add(nn.SelectTable(2)):add(nn.Linear(10,2)):type(gtype)
+
+   local dpt = nn.DataParallelTable(1):type(gtype)
    for i=1,numGpus do
      cutorch.withDevice(i, function()
-         dpt:add(net:clone():cuda(), i)
+         dpt:add(net:clone():type(gtype), i)
      end)
    end
 
-   local input = {torch.CudaTensor(0), torch.CudaTensor(numGpus, 10):fill(1)}
+   local input = {torch[gtype:match('torch.(%a+)')](0), torch[gtype:match('torch.(%a+)')](numGpus, 10):fill(1)}
 
    local output = dpt:forward(input)
    local expected = net:forward(input)
-   assert((output - expected ):abs():max() < precision, 'unexpected output')
+   assert((output - expected ):abs():max() < precision(gtype, expected:clone():abs():max()), 'unexpected output')
 
    local gradOutput = output:clone():uniform(-1,1)
    local gradInput = dpt:backward(input, gradOutput)
    local expected = net:backward(input, gradOutput)
-   assert((expected[2] - gradInput[2]):abs():max() < precision, 'unexpected gradInput')
+   assert((expected[2] - gradInput[2]):abs():max() < precision(gtype, expected[2]:clone():abs():max()), 'unexpected gradInput')
 end
 
 function test.DataParallelTable_type()
-   local net = nn.SpatialConvolution(3, 3, 3, 5)
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_type(typename)
+   end
+end
 
-   local dpt = nn.DataParallelTable(1)
+function test_DataParallelTable_type(gtype)
+   local ctype = t2cpu[gtype]
+   local net = nn.SpatialConvolution(3, 3, 3, 5):type(ctype)
+
+   local dpt = nn.DataParallelTable(1):type(gtype)
    for i=1,numGpus do
      cutorch.withDevice(i, function()
        dpt:add(net:clone(), i)
      end)
    end
 
-   dpt:cuda()
+   dpt:type(gtype)
 
-   ok = pcall(function() dpt:float() end)
-   assert(not ok, 'should not be able to call DataParallelTable:float()')
+   ok = pcall(function() dpt:type(ctype) end)
+   assert(not ok, 'should not be able to call DataParallelTable:type(' .. ctype .. ')')
 end
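The type test above relies on pcall for the negative case: DataParallelTable:type() must raise for CPU type strings, and the t2cpu table supplies the matching CPU type for each GPU type so every pairing gets covered. A condensed sketch of the same check (single-GPU setup for brevity, assuming cunn is loaded):

    local dpt = nn.DataParallelTable(1)
    dpt:add(nn.Linear(10, 10):cuda(), 1)

    -- pcall turns the assert inside :type() into a boolean we can test
    local ok = pcall(function() dpt:type('torch.FloatTensor') end)
    assert(not ok, 'converting a DataParallelTable to a CPU type must fail')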
 
 function test.DataParallelTable_sync()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_sync(typename)
+   end
+end
+
+function test_DataParallelTable_sync(gtype)
    -- Test that DataParallelTable automatically syncParameters in updateOutput
    -- if you forget to call :syncParameters()
    local nSteps = 10
@@ -292,24 +356,24 @@ function test.DataParallelTable_sync()
      :add(nn.Linear(10, 10))
      :add(nn.ReLU(true))
      :add(nn.Linear(10, 10))
-      :cuda()
+      :type(gtype)
 
-   local dpt = nn.DataParallelTable(1)
+   local dpt = nn.DataParallelTable(1):type(gtype)
    for i=1,numGpus do
      cutorch.withDevice(i, function()
        dpt:add(net:clone(), i)
      end)
    end
 
-   local criterion = nn.MSECriterion():cuda()
+   local criterion = nn.MSECriterion():type(gtype)
 
    local optimState = {
      learningRate = 1,
      momentum = 0,
    }
 
-   local input = torch.CudaTensor(numGpus,10)
-   local target = torch.CudaTensor(numGpus,10)
+   local input = torch[gtype:match('torch.(%a+)')](numGpus,10)
+   local target = torch[gtype:match('torch.(%a+)')](numGpus,10)
 
    local function feval(net)
      local params, gradParams = net:getParameters()
@@ -333,25 +397,31 @@ function test.DataParallelTable_sync()
      optim.sgd(fevalBase, paramsBase, optimState)
    end
 
-   assert((paramsDpt - paramsBase):abs():max() < precision,
+   assert((paramsDpt - paramsBase):abs():max() < precision(gtype, paramsDpt:clone():abs():max()),
      'parameters do not match')
 end
 
 function test.DataParallelTable_serialize()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_serialize(typename)
+   end
+end
+
+function test_DataParallelTable_serialize(gtype)
    -- Test serialization after getParameters()
-   local net = nn.Linear(10, 10):cuda()
+   local net = nn.Linear(10, 10):type(gtype)
 
-   local dpt = nn.DataParallelTable(1)
+   local dpt = nn.DataParallelTable(1):type(gtype)
    for i=1,numGpus do
      cutorch.withDevice(i, function()
-         dpt:add(net:clone():cuda(), i)
+         dpt:add(net:clone():type(gtype), i)
      end)
    end
 
    dpt:getParameters()
    dpt = deserialize(serialize(dpt))
 
-   local input = torch.CudaTensor(numGpus,10):uniform(-1, 1)
+   local input = torch[gtype:match('torch.(%a+)')](numGpus,10):uniform(-1, 1)
 
    -- Check that forward works as expected
    local output = dpt:forward(input)
@@ -368,17 +438,23 @@ end
 
 function test.DataParallelTable_flattenParameters()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_flattenParameters(typename)
+   end
+end
+
+function test_DataParallelTable_flattenParameters(gtype)
    -- Wrap only a part of a network with data parallel table and
    -- check if the correct number of parameters have been copied
    local seq = nn.Sequential()
-   local layer1 = nn.Linear(10, 10):cuda()
-   local layer2 = nn.Linear(10, 5):cuda()
-   local dpt = nn.DataParallelTable(1, true, true):threads():cuda()
+   local layer1 = nn.Linear(10, 10):type(gtype)
+   local layer2 = nn.Linear(10, 5):type(gtype)
+   local dpt = nn.DataParallelTable(1, true, true):threads():type(gtype)
    dpt:add(layer2, torch.range(1, numGpus):totable())
    seq:add(layer1):add(dpt)
 
    seq:getParameters()
-   local input = torch.randn(7, 10):cuda()
+   local input = torch.randn(7, 10):type(gtype)
    seq:forward(input)
 
    -- There are 55 parameters in layer 2 (50 + 5 bias weights)
    assert(dpt.flattenedParams[1][1]:size(1) == 55, "Incorrect number of " ..
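For the magic number 55 in the flattenParameters test: nn.Linear(10, 5) holds a 5x10 weight matrix plus a 5-element bias vector, and flattening concatenates both into one storage. The count can be verified without any GPU:

    require 'nn'

    local layer = nn.Linear(10, 5)
    local n = 0
    for _, p in ipairs(layer:parameters()) do
       n = n + p:nElement()   -- weight: 50, bias: 5
    end
    assert(n == 5 * 10 + 5)   -- 55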
@@ -389,17 +465,23 @@ function test.DataParallelTable_flattenParameters()
 end
 
 function test.DataParallelTable_misc()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_misc(typename)
+   end
+end
+
+function test_DataParallelTable_misc(gtype)
    local net = nn.Sequential()
      :add(nn.Linear(3, 10))
      :add(nn.ReLU())
      :add(nn.Linear(10, 7))
 
-   local dpt = nn.DataParallelTable(1)
+   local dpt = nn.DataParallelTable(1):type(gtype)
      :add(net, torch.range(1, numGpus):totable())
      :threads()
-      :cuda()
+      :type(gtype)
 
-   local input = torch.randn(8, 3):cuda()
+   local input = torch.randn(8, 3):type(gtype)
    local output = dpt:forward(input)
 
    -- check that clone works
@@ -417,41 +499,53 @@ function test.DataParallelTable_misc()
 end
 
 function test.DataParallelTable_noGradInput()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_noGradInput(typename)
+   end
+end
+
+function test_DataParallelTable_noGradInput(gtype)
    local net = nn.Sequential()
      :add(nn.LookupTable(10, 10))
      :add(nn.Linear(10, 7))
      :add(nn.ReLU())
-      :cuda()
+      :type(gtype)
 
    local dpt = nn.DataParallelTable(1)
      :add(net, torch.range(1, numGpus):totable())
      :threads()
-      :cuda()
+      :type(gtype)
 
-   local input = torch.Tensor(5):random(10):cuda()
+   local input = torch.Tensor(5):random(10):type(gtype)
    local output1 = net:forward(input):clone()
    local gradOutput = output1:clone():uniform(-1, 1)
    local gradInput1 = net:backward(input, gradOutput):clone()
 
    local output2 = dpt:forward(input)
    local gradInput2 = dpt:backward(input, gradOutput)
 
-   mytester:assertlt((output1 - output2):abs():max(), precision,
+   mytester:assertlt((output1 - output2):abs():max(), precision(gtype, output1:clone():abs():max()),
      'forward prop error')
    mytester:asserteq(gradInput2:nElement(), gradInput1:nElement())
 end
 
 function test.DataParallelTable_accGradParameters()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_accGradParameters(typename)
+   end
+end
+
+function test_DataParallelTable_accGradParameters(gtype)
    local net = nn.Sequential()
      :add(nn.Linear(3, 10))
      :add(nn.ReLU())
      :add(nn.Linear(10, 7))
-      :cuda()
+      :type(gtype)
 
    local inputs = {}
    local gradOutputs = {}
    for i=1,3 do
-      inputs[i] = torch.randn(8, 3):cuda()
-      gradOutputs[i] = torch.randn(8, 7):cuda()
+      inputs[i] = torch.randn(8, 3):type(gtype)
+      gradOutputs[i] = torch.randn(8, 7):type(gtype)
    end
 
    local configs = {
@@ -474,25 +568,31 @@ function test.DataParallelTable_accGradParameters()
 
    for _, config in ipairs(configs) do
      local dpt = nn.DataParallelTable(table.unpack(config))
-         :add(net:clone(), torch.range(1, numGpus):totable())
+         :add(net:clone(), torch.range(1, numGpus):totable()):type(gtype)
      accumulateGradient(dpt)
      local output = dpt:forward(inputs[1])
-      mytester:assertlt((output - expected):abs():max(), 1e-5, 'invalid output')
+      mytester:assertlt((output - expected):abs():max(), precision(gtype, expected:clone():abs():max()), 'invalid output ' .. gtype)
    end
 end
 
 function test.DataParallelTable_apply()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_apply(typename)
+   end
+end
+
+function test_DataParallelTable_apply(gtype)
    local net = nn.Sequential()
      :add(nn.Linear(3, 10))
      :add(nn.ReLU())
      :add(nn.Linear(10, 7))
-      :cuda()
+      :type(gtype)
 
    local inputs = {}
    local gradOutputs = {}
    for i=1,3 do
-      inputs[i] = torch.randn(8, 3):cuda()
-      gradOutputs[i] = torch.randn(8, 7):cuda()
+      inputs[i] = torch.randn(8, 3):type(gtype)
+      gradOutputs[i] = torch.randn(8, 7):type(gtype)
    end
 
    local configs = {
@@ -521,13 +621,13 @@ function test.DataParallelTable_apply()
    for _, usethreads in ipairs{false,true} do
      for _, config in ipairs(configs) do
        local dpt = nn.DataParallelTable(table.unpack(config))
-            :add(net:clone(), torch.range(1, numGpus):totable())
+            :add(net:clone(), torch.range(1, numGpus):totable()):type(gtype)
        if usethreads then
          dpt:threads()
        end
        trainNetwork(dpt)
        local output = dpt:forward(inputs[1])
-         mytester:assertlt((output - expected):abs():max(), 1e-5,
+         mytester:assertlt((output - expected):abs():max(), precision(gtype, expected:clone():abs():max()),
          'invalid output: flatten=' .. tostring(config[2]) ..
          ' threads=' .. tostring(usethreads))
      end
@@ -535,14 +635,20 @@ function test.DataParallelTable_apply()
 end
 
 function test.DataParallelTable_streams()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_streams(typename)
+   end
+end
+
+function test_DataParallelTable_streams(gtype)
    local net = nn.Sequential()
      :add(nn.Linear(3, 10))
      :add(nn.ReLU())
      :add(nn.Linear(10, 7))
-      :cuda()
+      :type(gtype)
 
-   local input = torch.randn(8, 3):cuda()
-   local gradOutput = torch.randn(8, 7):cuda()
+   local input = torch.randn(8, 3):type(gtype)
+   local gradOutput = torch.randn(8, 7):type(gtype)
    local gOutput = net:forward(input):clone()
    net:zeroGradParameters()
    local gGradInput = net:backward(input, gradOutput):clone()
@@ -569,7 +675,7 @@ function test.DataParallelTable_streams()
    for _, threads in ipairs{false, true} do
      local dpt = nn.DataParallelTable(table.unpack(config))
        :add(net, torch.range(1, numGpus):totable())
-         :cuda()
+         :type(gtype)
      if threads then
        dpt:threads(function()
          cutorch.reserveStreams(1)
@@ -584,6 +690,12 @@ function test.DataParallelTable_streams()
 end
 
 function test.DataParallelTable_emptyData()
+   for k, typename in ipairs(typenames) do
+      test_DataParallelTable_emptyData(typename)
+   end
+end
+
+function test_DataParallelTable_emptyData(gtype)
    local function eq(a,b)
      if not torch.isTensor(a) then
        local res = true
@@ -601,11 +713,11 @@ function test.DataParallelTable_emptyData()
    local a = nn.DataParallelTable(1)
    a:add(identity, torch.range(1,numGpus):totable())
-   a:cuda()
+   a:type(gtype)
 
-   local inputs = {torch.range(1,numGpus*5):reshape(numGpus,5):cuda(),
-      torch.range(1,5):reshape(1,5):cuda(),
-      torch.range(1,10):reshape(2,5):cuda(),
+   local inputs = {torch.range(1,numGpus*5):reshape(numGpus,5):type(gtype),
+      torch.range(1,5):reshape(1,5):type(gtype),
+      torch.range(1,10):reshape(2,5):type(gtype),
    }
 
    for _, input in ipairs(inputs) do
@@ -617,7 +729,7 @@ function test.DataParallelTable_emptyData()
    a = nn.DataParallelTable(1)
    a:add(nn.ParallelTable():add(identity):add(identity),
      torch.range(1,numGpus):totable())
-   a:cuda()
+   a:type(gtype)
 
    for _, input in ipairs(inputs) do
      input = {input, input}
@@ -630,6 +742,12 @@ end
 
 function test.ProfileDataParallelTable()
+   for k, typename in ipairs(typenames) do
+      test_ProfileDataParallelTable(typename)
+   end
+end
+
+function test_ProfileDataParallelTable(gtype)
    local width = 32
    local height = 32
    local pool = 4
@@ -661,17 +779,19 @@ function test.ProfileDataParallelTable()
      local gNet = module(1)
      if (moduleName == 'DataParallel') then
        cutorch.setDevice(baseGpu)
-        gNet:cuda()
+        gNet:type(gtype)
+      elseif (moduleName == 'DataParallelTable') then
+        gNet:type(gtype)
      end
      for i = 1, numGpus do
        local curGpu = math.fmod(baseGpu+(i-1)-1, cutorch.getDeviceCount())+1
        cutorch.setDevice(curGpu)
-        gNet:add(cNet:clone():cuda(), curGpu)
+        gNet:add(cNet:clone():type(gtype), curGpu)
      end
      cutorch.setDevice(baseGpu)
 
-      local input = torch.rand(batchSize, 3, height, width):cuda()
-      local target = torch.rand(batchSize, 2):cuda()
+      local input = torch.rand(batchSize, 3, height, width):type(gtype)
+      local target = torch.rand(batchSize, 2):type(gtype)
 
      local gParams, gGradParams
      if (moduleName == 'DataParallelTable') then
@@ -695,7 +815,7 @@ function test.ProfileDataParallelTable()
        nesterov = true,
      }
      local optimMethod = optim.sgd
-      local criterion = nn.MSECriterion():cuda()
+      local criterion = nn.MSECriterion():type(gtype)
      local timeGpuNet = 0
 
      local opt
@@ -742,5 +862,6 @@ function test.ProfileDataParallelTable()
 end
 
 -- Now run the test above
+--checkHalf() -- half not enabled yet for DataParallelTable
 mytester:add(test)
 mytester:run()
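The final hunk leaves checkHalf() commented out on purpose: :type('torch.CudaHalfTensor') is accepted, but _distributeTensorRecursive still rejects half inputs, so the half test path cannot pass yet. Once half support lands, enabling it is just the uncomment, which at runtime extends the type list defined at the top of the file:

    if cutorch.hasHalf then
       table.insert(typenames, 'torch.CudaHalfTensor')
       t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
    end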