diff options
author | Ronan Collobert <ronan@collobert.com> | 2015-04-19 07:23:52 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2015-04-19 07:23:52 +0300 |
commit | d223e9d60ed04f91016867065a5ef7575d337dce (patch) | |
tree | 35c3df050899fade1d929c540eb6366dab679708 | |
parent | 83b1299a92c949ca913962e4e09d3a3581f04a5d (diff) |
updated neural network threaded training example such that it uses shared serialization
-rw-r--r-- | benchmark/threadedtrain.lua | 175 | ||||
-rw-r--r-- | benchmark/utils.lua | 37 |
2 files changed, 74 insertions, 138 deletions
diff --git a/benchmark/threadedtrain.lua b/benchmark/threadedtrain.lua index 02f8485..859b43f 100644 --- a/benchmark/threadedtrain.lua +++ b/benchmark/threadedtrain.lua @@ -1,8 +1,5 @@ -local ffi = require 'ffi' local Threads = require 'threads' -require 'utils' - local function threadedTrain(module, criterion, data, label, params) -- corner case: we are here to do batches @@ -20,108 +17,84 @@ local function threadedTrain(module, criterion, data, label, params) -- in the end, we normalize ourselves per batch-size criterion.sizeAverage = false - local weight = module:getParameters() - local weight_p = tonumber(ffi.cast('intptr_t', weight:data())) - local weight_nelem = weight:nElement() - local data_p = tonumber(ffi.cast('intptr_t', data:data())) - local label_p = tonumber(ffi.cast('intptr_t', label:data())) - - local data_size = data:size() - local data_nelem = data:nElement() - local label_size = label:size() - local label_nelem = label:nElement() - - local threads, gradweights = Threads(params.threads, - function() - require 'nn' - require 'utils' - end, - - function() - local ffi = require 'ffi' - - gmodule = module - gcriterion = criterion - - sharefloatstorage(gmodule:get(1).weight:storage(), weight_p) - gdatastorage = torch.FloatStorage() - sharefloatstorage(gdatastorage, data_p, data_nelem) - gdata = torch.FloatTensor(gdatastorage, 1, data_size) - - glabelstorage = torch.LongStorage() - sharelongstorage(glabelstorage, label_p, label_nelem) - glabel = torch.LongTensor(glabelstorage, 1, label_size) - - gdataset = {} - - local nex = glabel:size(1) - - if params.batch == 1 or params.batch == params.threads then - function gdataset:size() - return nex - end - - setmetatable(gdataset, {__index = function(self, index) - return {gdata[index], glabel[index]} - end}) - else - assert(nex % params.batch == 0, '# of examples must be divisible with batch size') - assert(params.batch % params.threads == 0, 'batch size must be divisible threads') - local n = params.batch/params.threads - function gdataset:size() - return nex/n - end - setmetatable(gdataset, {__index = function(self, index) - return {gdata:narrow(1,(index-1)*n+1, n), - glabel:narrow(1,(index-1)*n+1, n)} - end}) - end - - function gupdate(idx) - local ex = gdataset[idx] - local x, y = ex[1], ex[2] - - local z = gmodule:forward(x) - local err = gcriterion:forward(z, y) - gmodule:zeroGradParameters() - gmodule:updateGradInput(x, gcriterion:updateGradInput(gmodule.output, y)) - gmodule:accGradParameters(x, gcriterion.gradInput) - - return err - end - - return tonumber(ffi.cast('intptr_t', gmodule:get(1).gradWeight:data())) - end) - - - for i=1,params.threads do - local gradweight = torch.FloatStorage() - sharefloatstorage(gradweight, gradweights[i][1], weight_nelem) - gradweights[i] = torch.FloatTensor(gradweight) - end + Threads.serialization('threads.sharedserialize') + local threads = Threads( + params.threads, + function() + require 'nn' + end, + + function() + local module = module:clone('weights', 'bias') + local weights, dweights = module:parameters() + local criterion = criterion:clone() + local data = data + local label = label + local dataset = {} + + local nex = label:size(1) + + if params.batch == 1 then + function dataset:size() + return nex + end + + setmetatable(dataset, {__index = + function(self, index) + return {data[index], label[index]} + end}) + else + assert(nex % params.batch == 0, '# of examples must be divisible with batch size') + local batch = params.batch + function dataset:size() + return nex/batch + end + setmetatable(dataset, {__index = + function(self, index) + return { + data:narrow(1,(index-1)*batch+1, batch), + label:narrow(1,(index-1)*batch+1, batch) + } + end}) + end + + function gupdate(idx) + local ex = dataset[idx] + local x, y = ex[1], ex[2] + local z = module:forward(x) + local err = criterion:forward(z, y) + module:zeroGradParameters() + module:updateGradInput(x, criterion:updateGradInput(module.output, y)) + module:accGradParameters(x, criterion.gradInput) + return err, dweights + end + end + ) + + local weights = module:parameters() for iter=1,params.iter do local totalerr = 0 - for b=1,label:size(1)/params.batch do - for t=1,params.threads do - local idx = (b-1)*params.threads + t - - threads:addjob(function(idx) - return gupdate(idx) - end, - - function(err) - totalerr = totalerr + err - end, - - idx - ) - end - threads:synchronize() - for i=1,params.threads do - weight:add(-0.01/params.batch, gradweights[i]) - end + local idx = 1 + while idx < label:size(1)/params.batch do + + threads:addjob( + function(idx) + return gupdate(idx) + end, + + function(err, dweights) + totalerr = totalerr + err + for i=1,#weights do + weights[i]:add(-0.01, dweights[i]) + end + end, + idx + ) + + idx = idx + 1 end + threads:synchronize() print('# current error = ', totalerr/label:size(1)) end diff --git a/benchmark/utils.lua b/benchmark/utils.lua deleted file mode 100644 index 2abcd63..0000000 --- a/benchmark/utils.lua +++ /dev/null @@ -1,37 +0,0 @@ -local ffi = require 'ffi' - -local TH_STORAGE_REFCOUNTED = 1 -local TH_STORAGE_RESIZABLE = 2 -local TH_STORAGE_FREEMEM = 4 - -function sharefloatstorage(storage, data_p, sz) - local storage_p = ffi.cast('THFloatStorage*', torch.pointer(storage)) - assert(bit.band(storage_p.flag, TH_STORAGE_REFCOUNTED) ~= 0) - - if storage_p.data ~= nil then - storage_p.allocator.free(storage_p.allocatorContext, storage_p.data) - end - - storage_p.data = ffi.cast('float*', data_p) - if sz then - storage_p.size = sz - end - - storage_p.flag = TH_STORAGE_REFCOUNTED -end - -function sharelongstorage(storage, data_p, sz) - local storage_p = ffi.cast('THLongStorage*', torch.pointer(storage)) - assert(bit.band(storage_p.flag, TH_STORAGE_REFCOUNTED) ~= 0) - - if storage_p.data ~= nil then - storage_p.allocator.free(storage_p.allocatorContext, storage_p.data) - end - - storage_p.data = ffi.cast('long*', data_p) - if sz then - storage_p.size = sz - end - - storage_p.flag = TH_STORAGE_REFCOUNTED -end |