diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-11-09 00:06:52 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-11-09 00:06:52 +0300 |
commit | aa18682ea8471c9b8b1b949621149cec8d816b1d (patch) | |
tree | f9ed0b9c210ac3ac39c63b276be928b5ff020787 /test.lua | |
parent | ba27398bc5fd98ad0df4b995b725beb8cbf355d1 (diff) |
Rebase BatchNormalization.
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 190 |
1 files changed, 84 insertions, 106 deletions
@@ -820,135 +820,113 @@ end local function BatchNormalization_forward(moduleName, inputSize) local planes = inputSize[2] - local tm = {} - local title = moduleName .. '.forward ' .. table.concat(inputSize, 'x') - times[title] = tm - local input = torch.randn(table.unpack(inputSize)) - local sbnorm = nn[moduleName](planes) - local groundtruth = sbnorm:forward(input) - local a = torch.Timer() - for i = 1,nloop do - groundtruth = sbnorm:forward(input) - end - tm.cpu = a:time().real + for k, typename in ipairs(typenames) do + local input = torch.randn(table.unpack(inputSize)):type(typename) - input = input:cuda() - local gbnorm = nn[moduleName](planes):cuda() - gbnorm.weight = sbnorm.weight:cuda() - gbnorm.bias = sbnorm.bias:cuda() - local rescuda = gbnorm:forward(input) + local ctype = t2cpu[typename] + input = input:type(ctype) + local sbnorm = nn[moduleName](planes):type(ctype) + local groundtruth = sbnorm:forward(input) - a:reset() - for i = 1,nloop do - rescuda = gbnorm:forward(input) - end - cutorch.synchronize() - tm.gpu = a:time().real + input = input:type(typename) + local gbnorm = nn[moduleName](planes):type(typename) + gbnorm.weight = sbnorm.weight:type(typename) + gbnorm.bias = sbnorm.bias:type(typename) + local rescuda = gbnorm:forward(input) - local error = rescuda:float() - groundtruth - mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward)') - mytester:assertlt((gbnorm.running_mean:float() - sbnorm.running_mean):abs():max(), - precision_forward, 'error on running_mean (forward)') - mytester:assertlt((gbnorm.running_var:float() - sbnorm.running_var):abs():max(), - precision_forward, 'error on running_var (forward)') + local error = rescuda:double() - groundtruth:double() + mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename), + string.format('error on state (forward) with %s', typename)) + mytester:assertlt((gbnorm.running_mean:double() - sbnorm.running_mean:double()):abs():max(), + precision_forward_type(precision_forward, typename), + string.format('error on running_mean (forward) with %s', typenanme)) + mytester:assertlt((gbnorm.running_var:double() - sbnorm.running_var:double()):abs():max(), + precision_forward_type(precision_forward, typename), + string.format('error on running_var (forward) with %s', typename)) + end end local function BatchNormalization_forward_inference(moduleName, inputSize) local planes = inputSize[2] - local tm = {} - local title = moduleName .. '.forward (evaluate) ' .. table.concat(inputSize, 'x') - times[title] = tm - local input = torch.randn(table.unpack(inputSize)) - local sbnorm = nn[moduleName](planes) - sbnorm.running_mean:normal(1, 2) - sbnorm.running_var:uniform(1e-3, 2) - sbnorm:evaluate() - local groundtruth = sbnorm:forward(input) - local a = torch.Timer() - for i = 1,nloop do - groundtruth = sbnorm:forward(input) - end - tm.cpu = a:time().real + for k, typename in ipairs(typenames) do + local input = torch.randn(table.unpack(inputSize)):type(typename) - input = input:cuda() - local gbnorm = nn[moduleName](planes):cuda() - gbnorm:evaluate() - gbnorm.weight = sbnorm.weight:cuda() - gbnorm.bias = sbnorm.bias:cuda() - gbnorm.running_mean = sbnorm.running_mean:cuda() - gbnorm.running_var = sbnorm.running_var:cuda() - local rescuda = gbnorm:forward(input) - a:reset() - for i = 1,nloop do - rescuda = gbnorm:forward(input) - end - cutorch.synchronize() - tm.gpu = a:time().real + local ctype = t2cpu[typename] + input = input:type(ctype) + local sbnorm = nn[moduleName](planes):type(ctype) + sbnorm.running_mean:normal(1, 2) + sbnorm.running_var:uniform(1e-3, 2) + sbnorm:evaluate() + local groundtruth = sbnorm:forward(input) - local error = rescuda:float() - groundtruth - mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward evaluate)') + input = input:type(typename) + local gbnorm = nn[moduleName](planes):type(typename) + gbnorm:evaluate() + gbnorm.weight = sbnorm.weight:type(typename) + gbnorm.bias = sbnorm.bias:type(typename) + gbnorm.running_mean = sbnorm.running_mean:type(typename) + gbnorm.running_var = sbnorm.running_var:type(typename) + local rescuda = gbnorm:forward(input) + + local error = rescuda:double() - groundtruth:double() + mytester:assertlt(error:abs():max(), 3*precision_forward_type(precision_forward, typename), + string.format('error on state (forward evaluate) with %s', typename)) + end end local function BatchNormalization_backward(moduleName, mode, inputSize, backwardFn) assert(mode == 'training' or mode == 'evaluation', 'invalid mode') local planes = inputSize[2] - local tm = {} - local title = moduleName .. '.backward ' .. table.concat(inputSize, 'x') - times[title] = tm - local input = torch.randn(table.unpack(inputSize)) - local gradOutput = torch.randn(table.unpack(inputSize)) - local sbnorm = nn[moduleName](planes) - if mode == 'training' then - sbnorm:training() - else - sbnorm:evaluate() - end - sbnorm:forward(input) - sbnorm:zeroGradParameters() - local groundgrad = backwardFn(sbnorm, input, gradOutput) - local a = torch.Timer() - for i = 1,nloop do + for k, typename in ipairs(typenames) do + local input = torch.randn(table.unpack(inputSize)):type(typename) + local gradOutput = torch.randn(table.unpack(inputSize)):type(typename) + + local ctype = t2cpu[typename] + input = input:type(ctype) + gradOutput = gradOutput:type(ctype) + local sbnorm = nn[moduleName](planes):type(ctype) + if mode == 'training' then + sbnorm:training() + else + sbnorm:evaluate() + end + sbnorm:forward(input) sbnorm:zeroGradParameters() - groundgrad = backwardFn(sbnorm, input, gradOutput) - end - local groundweight = sbnorm.gradWeight - local groundbias = sbnorm.gradBias - tm.cpu = a:time().real + local groundgrad = backwardFn(sbnorm, input, gradOutput) + local groundweight = sbnorm.gradWeight + local groundbias = sbnorm.gradBias - input = input:cuda() - gradOutput = gradOutput:cuda() - local gbnorm = nn[moduleName](planes):cuda() - if mode == 'training' then - gbnorm:training() - else - gbnorm:evaluate() - end - gbnorm.weight = sbnorm.weight:cuda() - gbnorm.bias = sbnorm.bias:cuda() - gbnorm:forward(input) - gbnorm:zeroGradParameters() - local rescuda = backwardFn(gbnorm, input, gradOutput) - a:reset() - for i = 1,nloop do + input = input:type(typename) + gradOutput = gradOutput:type(typename) + local gbnorm = nn[moduleName](planes):type(typename) + if mode == 'training' then + gbnorm:training() + else + gbnorm:evaluate() + end + gbnorm.weight = sbnorm.weight:type(typename) + gbnorm.bias = sbnorm.bias:type(typename) + gbnorm:forward(input) gbnorm:zeroGradParameters() - rescuda = backwardFn(gbnorm, input, gradOutput) - end - local weightcuda = gbnorm.gradWeight - local biascuda = gbnorm.gradBias - cutorch.synchronize() - tm.gpu = a:time().real + local rescuda = backwardFn(gbnorm, input, gradOutput) + local weightcuda = gbnorm.gradWeight + local biascuda = gbnorm.gradBias - local error = rescuda:float() - groundgrad - local werror = weightcuda:float() - groundweight - local berror = biascuda:float() - groundbias + local error = rescuda:double() - groundgrad:double() + local werror = weightcuda:double() - groundweight:double() + local berror = biascuda:double() - groundbias:double() - mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') - mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ') - mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ') + mytester:assertlt(error:abs():max(), 3*precision_backward_type(precision_backward, typename), + string.format('error on state (backward) with %s', typename)) + mytester:assertlt(werror:abs():max(), 5*precision_backward_type(precision_backward, typename), + string.format('error on weight (backward) with %s', typename)) + mytester:assertlt(berror:abs():max(), 5*precision_backward_type(precision_backward, typename), + string.format('error on bias (backward) with %s', typename)) + end end local function testBatchNormalization(name, dim, k) |