
github.com/torch/threads-ffi.git
author     Ronan Collobert <ronan@collobert.com>   2015-04-19 07:23:52 +0300
committer  Ronan Collobert <ronan@collobert.com>   2015-04-19 07:23:52 +0300
commit     d223e9d60ed04f91016867065a5ef7575d337dce
tree       35c3df050899fade1d929c540eb6366dab679708
parent     83b1299a92c949ca913962e4e09d3a3581f04a5d
updated the neural network threaded training example so that it uses shared serialization
-rw-r--r--  benchmark/threadedtrain.lua  175
-rw-r--r--  benchmark/utils.lua           37
2 files changed, 74 insertions(+), 138 deletions(-)
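
Before this change the example shared tensors across threads by hand: raw data pointers were cast through LuaJIT's FFI and patched into each worker's THStorage structs (see the deleted benchmark/utils.lua at the bottom of this diff). The commit switches to the library's shared serializer instead: once Threads.serialization('threads.sharedserialize') is selected, tensors and modules captured as upvalues of the thread callbacks are handed to workers by storage pointer rather than by copy. A minimal sketch of that pattern, assuming only that torch and the threads package are installed (the tensor t and the pool size are illustrative, not part of the commit):

    require 'torch'
    local Threads = require 'threads'
    Threads.serialization('threads.sharedserialize')  -- must be selected before the pool is created

    local t = torch.FloatTensor(2):zero()             -- captured below as an upvalue

    local pool = Threads(
       2,
       function()
          require 'torch'
       end,
       function(threadid)
          -- with 'threads.sharedserialize', t here shares its
          -- underlying storage with t in the main thread
          t[threadid] = threadid
       end
    )
    pool:synchronize()
    print(t)                                          -- 1 2: written in-place by the workers
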
diff --git a/benchmark/threadedtrain.lua b/benchmark/threadedtrain.lua
index 02f8485..859b43f 100644
--- a/benchmark/threadedtrain.lua
+++ b/benchmark/threadedtrain.lua
@@ -1,8 +1,5 @@
-local ffi = require 'ffi'
local Threads = require 'threads'
-require 'utils'
-
local function threadedTrain(module, criterion, data, label, params)
-- corner case: we are here to do batches
@@ -20,108 +17,84 @@ local function threadedTrain(module, criterion, data, label, params)
-- in the end, we normalize ourselves per batch-size
criterion.sizeAverage = false
- local weight = module:getParameters()
- local weight_p = tonumber(ffi.cast('intptr_t', weight:data()))
- local weight_nelem = weight:nElement()
- local data_p = tonumber(ffi.cast('intptr_t', data:data()))
- local label_p = tonumber(ffi.cast('intptr_t', label:data()))
-
- local data_size = data:size()
- local data_nelem = data:nElement()
- local label_size = label:size()
- local label_nelem = label:nElement()
-
- local threads, gradweights = Threads(params.threads,
- function()
- require 'nn'
- require 'utils'
- end,
-
- function()
- local ffi = require 'ffi'
-
- gmodule = module
- gcriterion = criterion
-
- sharefloatstorage(gmodule:get(1).weight:storage(), weight_p)
- gdatastorage = torch.FloatStorage()
- sharefloatstorage(gdatastorage, data_p, data_nelem)
- gdata = torch.FloatTensor(gdatastorage, 1, data_size)
-
- glabelstorage = torch.LongStorage()
- sharelongstorage(glabelstorage, label_p, label_nelem)
- glabel = torch.LongTensor(glabelstorage, 1, label_size)
-
- gdataset = {}
-
- local nex = glabel:size(1)
-
- if params.batch == 1 or params.batch == params.threads then
- function gdataset:size()
- return nex
- end
-
- setmetatable(gdataset, {__index = function(self, index)
- return {gdata[index], glabel[index]}
- end})
- else
- assert(nex % params.batch == 0, '# of examples must be divisible with batch size')
- assert(params.batch % params.threads == 0, 'batch size must be divisible threads')
- local n = params.batch/params.threads
- function gdataset:size()
- return nex/n
- end
- setmetatable(gdataset, {__index = function(self, index)
- return {gdata:narrow(1,(index-1)*n+1, n),
- glabel:narrow(1,(index-1)*n+1, n)}
- end})
- end
-
- function gupdate(idx)
- local ex = gdataset[idx]
- local x, y = ex[1], ex[2]
-
- local z = gmodule:forward(x)
- local err = gcriterion:forward(z, y)
- gmodule:zeroGradParameters()
- gmodule:updateGradInput(x, gcriterion:updateGradInput(gmodule.output, y))
- gmodule:accGradParameters(x, gcriterion.gradInput)
-
- return err
- end
-
- return tonumber(ffi.cast('intptr_t', gmodule:get(1).gradWeight:data()))
- end)
-
-
- for i=1,params.threads do
- local gradweight = torch.FloatStorage()
- sharefloatstorage(gradweight, gradweights[i][1], weight_nelem)
- gradweights[i] = torch.FloatTensor(gradweight)
- end
+ Threads.serialization('threads.sharedserialize')
+ local threads = Threads(
+ params.threads,
+ function()
+ require 'nn'
+ end,
+
+ function()
+ local module = module:clone('weights', 'bias')
+ local weights, dweights = module:parameters()
+ local criterion = criterion:clone()
+ local data = data
+ local label = label
+ local dataset = {}
+
+ local nex = label:size(1)
+
+ if params.batch == 1 then
+ function dataset:size()
+ return nex
+ end
+
+ setmetatable(dataset, {__index =
+ function(self, index)
+ return {data[index], label[index]}
+ end})
+ else
+ assert(nex % params.batch == 0, '# of examples must be divisible with batch size')
+ local batch = params.batch
+ function dataset:size()
+ return nex/batch
+ end
+ setmetatable(dataset, {__index =
+ function(self, index)
+ return {
+ data:narrow(1,(index-1)*batch+1, batch),
+ label:narrow(1,(index-1)*batch+1, batch)
+ }
+ end})
+ end
+
+ function gupdate(idx)
+ local ex = dataset[idx]
+ local x, y = ex[1], ex[2]
+ local z = module:forward(x)
+ local err = criterion:forward(z, y)
+ module:zeroGradParameters()
+ module:updateGradInput(x, criterion:updateGradInput(module.output, y))
+ module:accGradParameters(x, criterion.gradInput)
+ return err, dweights
+ end
+ end
+ )
+
+ local weights = module:parameters()
for iter=1,params.iter do
local totalerr = 0
- for b=1,label:size(1)/params.batch do
- for t=1,params.threads do
- local idx = (b-1)*params.threads + t
-
- threads:addjob(function(idx)
- return gupdate(idx)
- end,
-
- function(err)
- totalerr = totalerr + err
- end,
-
- idx
- )
- end
- threads:synchronize()
- for i=1,params.threads do
- weight:add(-0.01/params.batch, gradweights[i])
- end
+ local idx = 1
+ while idx < label:size(1)/params.batch do
+
+ threads:addjob(
+ function(idx)
+ return gupdate(idx)
+ end,
+
+ function(err, dweights)
+ totalerr = totalerr + err
+ for i=1,#weights do
+ weights[i]:add(-0.01, dweights[i])
+ end
+ end,
+ idx
+ )
+
+ idx = idx + 1
end
+ threads:synchronize()
print('# current error = ', totalerr/label:size(1))
end
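
Note how the rewritten loop keeps the parameter update race-free without the per-thread gradient storages of the old code: each worker clones the module with clone('weights', 'bias'), so the parameters are shared while the gradients stay thread-local, and each job returns those gradients to an end-callback that runs on the main thread and folds them into the shared weights one job at a time. The addjob/end-callback split in isolation, as a sketch with a trivial job standing in for gupdate:

    require 'torch'
    local Threads = require 'threads'
    local pool = Threads(4)           -- 4 workers, default serialization

    local total = 0
    for i = 1, 10 do
       pool:addjob(
          function()                  -- runs on a worker thread
             return i * i             -- i is an upvalue, serialized into the job
          end,
          function(sq)                -- end-callback: always runs on the main thread
             total = total + sq
          end
       )
    end
    pool:synchronize()                -- block until every queued job has finished
    print(total)                      -- 385
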
diff --git a/benchmark/utils.lua b/benchmark/utils.lua
deleted file mode 100644
index 2abcd63..0000000
--- a/benchmark/utils.lua
+++ /dev/null
@@ -1,37 +0,0 @@
-local ffi = require 'ffi'
-
-local TH_STORAGE_REFCOUNTED = 1
-local TH_STORAGE_RESIZABLE = 2
-local TH_STORAGE_FREEMEM = 4
-
-function sharefloatstorage(storage, data_p, sz)
- local storage_p = ffi.cast('THFloatStorage*', torch.pointer(storage))
- assert(bit.band(storage_p.flag, TH_STORAGE_REFCOUNTED) ~= 0)
-
- if storage_p.data ~= nil then
- storage_p.allocator.free(storage_p.allocatorContext, storage_p.data)
- end
-
- storage_p.data = ffi.cast('float*', data_p)
- if sz then
- storage_p.size = sz
- end
-
- storage_p.flag = TH_STORAGE_REFCOUNTED
-end
-
-function sharelongstorage(storage, data_p, sz)
- local storage_p = ffi.cast('THLongStorage*', torch.pointer(storage))
- assert(bit.band(storage_p.flag, TH_STORAGE_REFCOUNTED) ~= 0)
-
- if storage_p.data ~= nil then
- storage_p.allocator.free(storage_p.allocatorContext, storage_p.data)
- end
-
- storage_p.data = ffi.cast('long*', data_p)
- if sz then
- storage_p.size = sz
- end
-
- storage_p.flag = TH_STORAGE_REFCOUNTED
-end
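
The deleted helpers above re-implemented, per storage type and through the FFI, what threads.sharedserialize now does generically: they released a storage's own buffer and pointed it at memory owned by the main thread. Nothing in the benchmark needs that anymore, and for the narrower job of aliasing memory inside a single process, stock torch already offers a supported route. A sketch using only the public tensor API (names illustrative):

    require 'torch'

    local a = torch.FloatTensor(5):fill(1)
    local b = torch.FloatTensor()
    b:set(a:storage(), 1, a:size())   -- b is now a view over a's storage
    b[3] = 42
    print(a[3])                       -- 42: both tensors read the same buffer
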