github.com/clementfarabet/lua---nnx.git

-rw-r--r--  DataList.lua                4
-rw-r--r--  DataSet.lua                16
-rw-r--r--  DataSetLabelMe.lua        669
-rw-r--r--  generic/DataSetLabelMe.c   54
-rw-r--r--  init.c                      5
-rw-r--r--  init.lua                    1
-rw-r--r--  nnx-1.0-1.rockspec          1
7 files changed, 743 insertions, 7 deletions
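
The headline change is DataSetLabelMe.lua, a new 669-line dataset class for LabelMe-style segmentation data. A minimal usage sketch, pieced together from the constructor arguments and __index__ in the diff below -- the paths, class names and counts are hypothetical, and the era's toolBox/image APIs are assumed to be installed:

    require 'nnx'

    -- point it at a LabelMe-style tree: [path]/Images, [path]/Annotations, [path]/Masks
    local ds = DataSetLabelMe{path='./dataset',
                              nbClasses=2,
                              classNames={'background','object'},
                              patchSize=64,
                              samplingMode='random',   -- or 'equal'
                              labelType='center',
                              cacheFile='cached'}      -- caches tags/samples next to the data

    -- indexing extracts a patch on the fly; with labelType='center'
    -- each sample is {patch, targetVector, className}
    local patch, target, name = unpack(ds[1])
    print(ds)   -- multi-line summary via the new __tostring__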
diff --git a/DataList.lua b/DataList.lua
index 5b95c48..99b117a 100644
--- a/DataList.lua
+++ b/DataList.lua
@@ -16,8 +16,8 @@ function DataList:__init()
end
function DataList:__tostring__()
- str = 'DataList'
- str = str .. ' + nb samples : '..self.nbSamples
+ str = 'DataList:\n'
+ str = str .. ' + nb samples : '..self.nbSamples..'\n'
str = str .. ' + nb classes : '..self.nbClass
return str
end
diff --git a/DataSet.lua b/DataSet.lua
index 81e4bf3..23faf73 100644
--- a/DataSet.lua
+++ b/DataSet.lua
@@ -1,7 +1,7 @@
--------------------------------------------------------------------------------
-- DataSet: a class to handle standard datasets.
--
--- Authors: Corda, Farabet
+-- Authors: Clement Farabet, Benoit Corda
--------------------------------------------------------------------------------
local lDataSet = torch.class('nn.DataSet')
@@ -18,6 +18,16 @@ function lDataSet:size()
return self.nbSamples
end
+function lDataSet:__tostring__()
+ str = 'DataSet:\n'
+ if self.nbSamples then
+ str = str .. ' + nb samples : '..self.nbSamples
+ else
+ str = str .. ' + empty set...'
+ end
+ return str
+end
+
function lDataSet:load(...)
-- parse args
local args, dataSetFolder, nbSamplesRequired, cacheFile, channels, sampleSize
@@ -304,10 +314,6 @@ function lDataSet:display(args) -- opt args : scale, nbSamples
end
end
-function lDataSet:__show()
- self:display{nbSamples=100}
-end
-
function lDataSet:useCacheFile(fileName)
self.cacheFileName = fileName
end
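
With __show gone and __tostring__ added, printing a DataSet now yields a short text summary instead of popping up a 100-sample display window. Roughly, for a hypothetical set of 1000 loaded samples:

    > print(trainData)
    DataSet:
     + nb samples : 1000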
diff --git a/DataSetLabelMe.lua b/DataSetLabelMe.lua
new file mode 100644
index 0000000..5076c82
--- /dev/null
+++ b/DataSetLabelMe.lua
@@ -0,0 +1,669 @@
+--------------------------------------------------------------------------------
+-- DataSetLabelMe: a class to handle datasets from LabelMe (and other
+-- segmentation-based datasets).
+--
+-- Provides lots of options to cache (on disk) datasets, precompute
+-- segmentation masks, shuffle samples, extract subpatches, ...
+--
+-- Authors: Clement Farabet, Benoit Corda
+--------------------------------------------------------------------------------
+
+local DataSetLabelMe = torch.class('DataSetLabelMe')
+
+local path_images = 'Images'
+local path_annotations = 'Annotations'
+local path_masks = 'Masks'
+
+function DataSetLabelMe:__init(...)
+ -- check args
+ toolBox.unpack_class(
+ self,
+ {...},
+ 'DataSetLabelMe',
+ 'Creates a DataSet from standard LabelMe directories (Images+Annotations)',
+ {arg='path', type='string', help='path to LabelMe directory', req=true},
+ {arg='nbClasses', type='number', help='number of classes in dataset', default=1},
+ {arg='classNames', type='table', help='list of class names', default={'no name'}},
+ {arg='nbRawSamples', type='number', help='number of images'},
+ {arg='rawSampleMaxSize', type='number', help='resize all images to fit in a MxM window'},
+ {arg='rawSampleSize', type='table', help='resize all images precisely'},
+ {arg='nbPatchPerSample', type='number', help='number of patches to extract from each image', default=100},
+ {arg='patchSize', type='number', help='size of patches to extract from images', default=64},
+ {arg='samplingMode', type='string', help='patch sampling method: random | equal', default='random'},
+ {arg='labelType', type='string', help='type of label returned: center | pixelwise | fovea', default='center'},
+ {arg='fovea', type='nn.SpatialFovea', help='if passed, it will be focused automatically'},
+ {arg='infiniteSet', type='boolean', help='if true, the set can be indexed to infinity, looping around samples', default=false},
+ {arg='classToSkip', type='number', help='index of class to skip during sampling', default=0},
+ {arg='preloadSamples', type='boolean', help='if true, all samples are preloaded in memory', default=false},
+ {arg='cacheFile', type='string', help='path to cache file (once cached, loading is much faster)'},
+ {arg='processor', type='nn.fovea', help='module that postprocess the data for training'},
+ {arg='verbose', type='boolean', help='dumps information', default=false}
+ )
+
+ -- fixed parameters
+ self.colorMap = image.createColorMap(self.nbClasses)
+ self.rawdata = {}
+ self.currentIndex = -1
+ -- location of the patch in the image
+ self.currentX = 0
+ self.currentY = 0
+
+ -- parse dir structure
+ print('<DataSetLabelMe> loading LabelMe dataset from '..self.path)
+ for folder in paths.files(paths.concat(self.path,path_images)) do
+ if folder ~= '.' and folder ~= '..' then
+ for file in paths.files(paths.concat(self.path,path_images,folder)) do
+ if file ~= '.' and file ~= '..' then
+ local filepng = file:gsub('jpg$','png')
+ local filexml = file:gsub('jpg$','xml')
+ local imgf = paths.concat(self.path,path_images,folder,file)
+ local maskf = paths.concat(self.path,path_masks,folder,filepng)
+ local annotf = paths.concat(self.path,path_annotations,folder,filexml)
+ local size_x, size_y, size_c = image.getJPGsize(imgf)
+ table.insert(self.rawdata, {imgfile=imgf,
+ maskfile=maskf,
+ annotfile=annotf,
+ size={size_x, size_y, size_c}})
+ end
+ end
+ end
+ end
+
+ -- nb samples: user defined or max
+ self.nbRawSamples = self.nbRawSamples or #self.rawdata
+
+ -- extract some info (max sizes)
+ self.maxX = self.rawdata[1].size[1]
+ self.maxY = self.rawdata[1].size[2]
+ for i = 2,self.nbRawSamples do
+ if self.maxX < self.rawdata[i].size[1] then
+ self.maxX = self.rawdata[i].size[1]
+ end
+ if self.maxY < self.rawdata[i].size[2] then
+ self.maxY = self.rawdata[i].size[2]
+ end
+ end
+ -- and nb of samples obtainable (this is overcomplete ;-)
+ self.nbSamples = self.nbPatchPerSample * self.nbRawSamples
+
+ -- max size ?
+ if not self.rawSampleMaxSize then
+ self.rawSampleMaxSize = math.max(self.rawSampleSize[1],self.rawSampleSize[2])
+ end
+ local maxXY = math.max(self.maxX, self.maxY)
+ if maxXY < self.rawSampleMaxSize then
+ self.rawSampleMaxSize = maxXY
+ end
+
+ -- some info
+ if self.verbose then
+ print(self)
+ end
+
+ -- sampling mode
+ if self.samplingMode == 'equal' or self.samplingMode == 'random' then
+ self:parseAllMasks()
+ if self.samplingMode == 'random' then
+ -- get the number of usable patches
+ self.nbRandomPatches = 0
+ for i,v in ipairs(self.tags) do
+ if i ~= self.classToSkip then
+ self.nbRandomPatches = self.nbRandomPatches + v.size
+ end
+ end
+ -- create shuffle table
+ self.randomLookup = torch.ByteTensor(self.nbRandomPatches)
+ local idx = 1
+ for i,v in ipairs(self.tags) do
+ if i ~= self.classToSkip and v.size > 0 then
+ self.randomLookup:narrow(1,idx,v.size):fill(i)
+ idx = idx + v.size
+ end
+ end
+ end
+ else
+ error('ERROR <DataSetLabelMe> unknown sampling mode')
+ end
+
+ -- preload ?
+ if self.preloadSamples then
+ self:preload()
+ end
+end
+
+function DataSetLabelMe:size()
+ return self.nbSamples
+end
+
+function DataSetLabelMe:__tostring__()
+ local str = 'DataSetLabelMe:\n'
+ str = str .. ' + path : '..self.path..'\n'
+ if self.cacheFile then
+ str = str .. ' + cache files : [path]/'..self.cacheFile..'-[tags|samples]\n'
+ end
+ str = str .. ' + nb samples : '..self.nbRawSamples..'\n'
+ str = str .. ' + nb generated patches : '..self.nbSamples..'\n'
+ if self.infiniteSet then
+ str = str .. ' + infinite set (actual nb of samples >> set:size())\n'
+ end
+ if self.rawSampleMaxSize then
+ str = str .. ' + samples are resized to fit in a '
+ str = str .. self.rawSampleMaxSize .. 'x' .. self.rawSampleMaxSize .. ' tensor'
+ str = str .. ' [max raw size = ' .. self.maxX .. 'x' .. self.maxY .. ']\n'
+ if self.rawSampleSize then
+ str = str .. ' + imposed ratio of ' .. self.rawSampleSize[1] .. 'x' .. self.rawSampleSize[2] .. '\n'
+ end
+ end
+ str = str .. ' + patches size : ' .. self.patchSize .. 'x' .. self.patchSize .. '\n'
+ if self.classToSkip ~= 0 then
+ str = str .. ' + unused class : ' .. self.classNames[self.classToSkip] .. '\n'
+ end
+ str = str .. ' + sampling mode : ' .. self.samplingMode .. '\n'
+ str = str .. ' + label type : ' .. self.labelType .. '\n'
+ str = str .. ' + '..self.nbClasses..' categories : '
+ for i = 1,#self.classNames-1 do
+ str = str .. self.classNames[i] .. ' | '
+ end
+ str = str .. self.classNames[#self.classNames]
+ return str
+end
+
+function DataSetLabelMe:__index__(key)
+ if type(key)=='string' and key == 'last' then
+ xerror('deprecated','DataSetLabelMe')
+ end
+ if type(key)=='number' then
+ local which_tag
+ local tag_idx
+ if self.samplingMode == 'random' then
+ -- get indexes from random table
+ which_tag = self.randomLookup[math.random(1,self.nbRandomPatches)]
+ tag_idx = math.floor(math.random(0,self.tags[which_tag].size-1)/3)*3+1
+ elseif self.samplingMode == 'equal' then
+ -- equally sample each category:
+ which_tag = ((key-1) % (self.nbClasses)) + 1
+ while self.tags[which_tag].size == 0 or which_tag == self.classToSkip do
+ -- no sample in that class, replacing with random patch
+ which_tag = math.floor(random.uniform(1,self.nbClasses))
+ end
+ local nbSamplesPerClass = math.ceil(self.nbSamples / self.nbClasses)
+ tag_idx = math.floor((key*nbSamplesPerClass-1)/self.nbClasses) + 1
+ tag_idx = ((tag_idx-1) % (self.tags[which_tag].size/3))*3 + 1
+ end
+
+ -- generate patch
+ local subx,suby
+ self:loadSample(self.tags[which_tag].data[tag_idx+2])
+ local ctr_x = self.tags[which_tag].data[tag_idx]
+ local ctr_y = self.tags[which_tag].data[tag_idx+1]
+ local subtensor
+ if self.processor then
+ subtensor = self.processor:forward(self.currentSample,ctr_x,ctr_y)
+ else
+ subx = math.floor(ctr_x - self.patchSize/2) + 1
+ self.currentX = subx/self.currentSample:size(1)
+ suby = math.floor(ctr_y - self.patchSize/2) + 1
+ self.currentY = suby/self.currentSample:size(2)
+ subtensor = self.currentSample:narrow(1,subx,self.patchSize):narrow(2,suby,self.patchSize)
+ end
+
+ if self.labelType == 'center' then
+ -- generate label vector for patch centre
+ local vector = torch.Tensor(1,1,self.nbClasses):fill(-1)
+ vector[1][1][which_tag] = 1
+
+ -- and optional string
+ local label = self.classNames[which_tag]
+
+ -- return sample+label
+ return {subtensor, vector, label}, true
+
+ elseif self.labelType == 'pixelwise' then
+ -- generate pixelwise annotation
+ local annotation = self.currentMask:narrow(1,subx,self.patchSize):narrow(2,suby,self.patchSize)
+ return {subtensor, annotation}, true
+
+ elseif self.labelType == 'fovea' then
+ -- focus given fovea on the current patch
+ if self.fovea then
+ self.fovea:focus(ctr_x,ctr_y,self.patchSize)
+ end
+
+ -- generate label vector for patch centre
+ local vector = torch.Tensor(1,1,self.nbClasses):fill(-1)
+ vector[1][1][which_tag] = 1
+
+ -- and optional string
+ local label = self.classNames[which_tag]
+
+ -- return whole input + label
+ return {self.currentSample, vector, label}, true
+
+ elseif self.labelType == 'pixelwise+fovea' then
+ -- focus given fovea on the current patch
+ if self.fovea then
+ self.fovea:focus(ctr_x,ctr_y,self.patchSize)
+ end
+
+ -- return raw sample and annotation
+ return {self.currentSample, self.currentMask, x=ctr_x, y=ctr_y}, true
+
+ else
+ -- no label
+ return {subtensor}, true
+ end
+
+ elseif type(key)=='string' and (key == 'similar' or key == 'dissimilar') then
+ local which_tag, which_tag2
+ local tag_idx, tag_idx2
+ self.currentKey = self.currentKey or 1
+ if key == 'similar' then --for DrLim training
+ if self.samplingMode == 'random' then
+ -- get indexes from random table
+ which_tag = self.randomLookup[math.random(1,self.nbRandomPatches)]
+ which_tag2 = which_tag
+ tag_idx = math.floor(math.random(0,self.tags[which_tag].size-1)/3)*3+1
+ repeat
+ tag_idx2 = math.floor(math.random(0,self.tags[which_tag2].size-1)/3)*3+1
+ until tag_idx2 ~= tag_idx
+ elseif self.samplingMode == 'equal' then
+ key = self.currentKey
+ self.currentKey = self.currentKey + 1
+ -- equally sample each category:
+ which_tag = ((key-1) % (self.nbClasses)) + 1
+ while self.tags[which_tag].size == 0 or which_tag == self.classToSkip do
+ -- no sample in that class, replacing with random patch
+ which_tag = math.floor(random.uniform(1,self.nbClasses))
+ end
+ which_tag2 = which_tag
+
+ local nbSamplesPerClass = math.ceil(self.nbSamples / self.nbClasses)
+ tag_idx = math.floor((key*nbSamplesPerClass-1)/self.nbClasses) + 1
+ tag_idx = ((tag_idx-1) % (self.tags[which_tag].size/3))*3 + 1
+ tag_idx2 = math.floor(math.random(0,self.tags[which_tag2].size-1)/3)*3+1
+ end
+ elseif key == 'dissimilar' then --for DrLim training
+ if self.samplingMode == 'random' then
+ -- get indexes from random table
+ which_tag = self.randomLookup[math.random(1,self.nbRandomPatches)]
+ repeat
+ which_tag2 = self.randomLookup[math.random(1,self.nbRandomPatches)]
+ until which_tag2 ~= which_tag
+ tag_idx = math.floor(math.random(0,self.tags[which_tag].size-1)/3)*3+1
+ tag_idx2 = math.floor(math.random(0,self.tags[which_tag2].size-1)/3)*3+1
+ elseif self.samplingMode == 'equal' then
+ key = self.currentKey
+ self.currentKey = self.currentKey + 1
+ -- equally sample each category:
+ which_tag = ((key-1) % (self.nbClasses)) + 1
+ while self.tags[which_tag].size == 0 or which_tag == self.classToSkip do
+ -- no sample in that class, replacing with random patch
+ which_tag = math.floor(random.uniform(1,self.nbClasses))
+ end
+ repeat
+ which_tag2 = math.floor(random.uniform(1,self.nbClasses))
+ until which_tag2 ~= which_tag
+ and self.tags[which_tag2].size ~= 0
+ and which_tag2 ~= self.classToSkip
+
+ local nbSamplesPerClass = math.ceil(self.nbSamples / self.nbClasses)
+ tag_idx = math.floor((key*nbSamplesPerClass-1)/self.nbClasses) + 1
+ tag_idx = ((tag_idx-1) % (self.tags[which_tag].size/3))*3 + 1
+ tag_idx2 = math.floor(math.random(0,self.tags[which_tag2].size-1)/3)*3+1
+ end
+ end
+
+ -- now generate pair of patches and return
+ self:loadSample(self.tags[which_tag].data[tag_idx+2])
+ local ctr_x = self.tags[which_tag].data[tag_idx]
+ local ctr_y = self.tags[which_tag].data[tag_idx+1]
+ local subx = math.floor(ctr_x - self.patchSize/2) + 1
+ self.currentX = subx/self.currentSample:size(1)
+ local suby = math.floor(ctr_y - self.patchSize/2) + 1
+ self.currentY = suby/self.currentSample:size(2)
+ local subtensor = self.currentSample:narrow(1,subx,self.patchSize):narrow(2,suby,self.patchSize)
+ -- make a copy otherwise it will be overwritten
+ subtensor = torch.Tensor():resizeAs(subtensor):copy(subtensor)
+ -- generate label vector for patch centre
+ local vector = torch.Tensor(1,1,self.nbClasses):fill(-1)
+
+ -- generate pixelwise annotation
+ local annotation = self.currentMask:narrow(1,subx,self.patchSize):narrow(2,suby,self.patchSize)
+
+ -- patch2
+ self:loadSample(self.tags[which_tag2].data[tag_idx2+2])
+ local ctr_x2 = self.tags[which_tag2].data[tag_idx2]
+ local ctr_y2 = self.tags[which_tag2].data[tag_idx2+1]
+ local subx2 = math.floor(ctr_x2 - self.patchSize/2) + 1
+ self.currentX = subx2/self.currentSample:size(1)
+ local suby2 = math.floor(ctr_y2 - self.patchSize/2) + 1
+ self.currentY = suby2/self.currentSample:size(2)
+ local subtensor2 = self.currentSample:narrow(1,subx2,self.patchSize):narrow(2,suby2,self.patchSize)
+ -- make a copy otherwise it will be overwritten
+ subtensor2 = torch.Tensor():resizeAs(subtensor2):copy(subtensor2)
+ -- generate label vector for patch centre
+ local vector2 = torch.Tensor(1,1,self.nbClasses):fill(-1)
+
+ -- generate pixelwise annotation
+ local annotation2 = self.currentMask:narrow(1,subx2,self.patchSize):narrow(2,suby2,self.patchSize)
+
+ if self.labelType == 'center' then
+ vector[1][1][which_tag] = 1
+ vector2[1][1][which_tag2] = 1
+ -- and optional string
+ local label = self.classNames[which_tag]
+ local label2 = self.classNames[which_tag2]
+
+ -- return sample+label
+ return {{subtensor, vector, label},{subtensor2, vector2, label2}}, true
+
+ elseif self.labelType == 'pixelwise' then
+ return {{subtensor, annotation},{subtensor2, annotation2}}, true
+ else
+ -- no label
+ return {subtensor,subtensor2}, true
+ end
+
+ end
+ return rawget(self,key)
+end
+
+function DataSetLabelMe:loadSample(index)
+ if self.preloadedDone then
+ if index ~= self.currentIndex then
+ -- clean up
+ self.currentSample = nil
+ self.currentMask = nil
+ collectgarbage()
+ -- load new sample
+ self.currentSample = torch.Tensor(self.preloaded.samples[index]:size())
+ self.currentSample:copy(self.preloaded.samples[index]):mul(1/255)
+ self.currentMask = torch.Tensor(self.preloaded.masks[index]:size())
+ self.currentMask:copy(self.preloaded.masks[index])
+ -- remember index
+ self.currentIndex = index
+ end
+ elseif index ~= self.currentIndex then
+ -- clean up
+ self.currentSample = nil
+ self.currentMask = nil
+ collectgarbage()
+ -- load image
+ local img_loaded = image.load(self.rawdata[index].imgfile)
+ local mask_loaded = image.load(self.rawdata[index].maskfile):select(3,1)
+ -- resize ?
+ if self.rawSampleSize then
+ -- resize precisely
+ local w = self.rawSampleSize[1]
+ local h = self.rawSampleSize[2]
+ self.currentSample = torch.Tensor(w,h,img_loaded:size(3))
+ image.scale(img_loaded, self.currentSample, 'bilinear')
+ self.currentMask = torch.Tensor(w,h)
+ image.scale(mask_loaded, self.currentMask, 'simple')
+
+ elseif self.rawSampleMaxSize and (self.rawSampleMaxSize < img_loaded:size(1)
+ or self.rawSampleMaxSize < img_loaded:size(2)) then
+ -- resize to fit in bounding box
+ local w,h
+ if img_loaded:size(1) >= img_loaded:size(2) then
+ w = self.rawSampleMaxSize
+ h = math.floor((w*img_loaded:size(2))/img_loaded:size(1))
+ else
+ h = self.rawSampleMaxSize
+ w = math.floor((h*img_loaded:size(1))/img_loaded:size(2))
+ end
+ self.currentSample = torch.Tensor(w,h,img_loaded:size(3))
+ image.scale(img_loaded, self.currentSample, 'bilinear')
+ self.currentMask = torch.Tensor(w,h)
+ image.scale(mask_loaded, self.currentMask, 'simple')
+ else
+ self.currentSample = img_loaded
+ self.currentMask = mask_loaded
+ end
+ -- process mask
+ self.currentMask:mul(self.nbClasses-1):add(0.5):floor():add(1)
+ self.currentIndex = index
+ end
+end
+
+function DataSetLabelMe:preload(saveFile)
+ -- if cache file exists, just retrieve images from it
+ if self.cacheFile
+ and paths.filep(paths.concat(self.path,self.cacheFile..'-samples')) then
+ print('<DataSetLabelMe> retrieving saved samples from: '
+ .. paths.concat(self.path,self.cacheFile..'-samples')
+ .. ' [delete file to force new scan]')
+ local file = torch.DiskFile(paths.concat(self.path,self.cacheFile..'-samples'), 'r')
+ file:binary()
+ self.preloaded = file:readObject()
+ file:close()
+ self.preloadedDone = true
+ return
+ end
+ print('<DataSetLabelMe> preloading all images')
+ self.preloaded = {samples={}, masks={}}
+ for i = 1,self.nbRawSamples do
+ toolBox.dispProgress(i,self.nbRawSamples)
+ -- load samples, and store them in raw byte tensors (min memory footprint)
+ self:loadSample(i)
+ local rawTensor = torch.ByteTensor(self.currentSample:size())
+ local rawMask = torch.ByteTensor(self.currentMask:size()):copy(self.currentMask)
+ rawTensor:copy(self.currentSample:mul(255))
+ -- insert them in our list
+ table.insert(self.preloaded.samples, rawTensor)
+ table.insert(self.preloaded.masks, rawMask)
+ end
+ self.preloadedDone = true
+ -- optional cache file
+ if saveFile then
+ self.cacheFile = saveFile
+ end
+ -- if cache file given, serialize list of tags to it
+ if self.cacheFile then
+ print('<DataSetLabelMe> saving samples to cache file: '
+ .. paths.concat(self.path,self.cacheFile..'-samples'))
+ local file = torch.DiskFile(paths.concat(self.path,self.cacheFile..'-samples'), 'w')
+ file:binary()
+ file:writeObject(self.preloaded)
+ file:close()
+ end
+end
+
+function DataSetLabelMe:parseMask(existing_tags)
+ local tags
+ if not existing_tags then
+ tags = {}
+ local storage
+ for i = 1,self.nbClasses do
+ storage = torch.ShortStorage(self.rawSampleMaxSize*self.rawSampleMaxSize*3)
+ tags[i] = {data=storage, size=0}
+ end
+ else
+ tags = existing_tags
+ -- make sure each tag list is large enough to hold the incoming data
+ for i = 1,self.nbClasses do
+ if ((tags[i].size + (self.rawSampleMaxSize*self.rawSampleMaxSize*3)) >
+ tags[i].data:size()) then
+ tags[i].data:resize(tags[i].size+(self.rawSampleMaxSize*self.rawSampleMaxSize*3),true)
+ end
+ end
+ end
+ local mask = self.currentMask
+ local x_start = math.ceil(self.patchSize/2)
+ local x_end = mask:size(1) - math.ceil(self.patchSize/2)
+ local y_start = math.ceil(self.patchSize/2)
+ local y_end = mask:size(2) - math.ceil(self.patchSize/2)
+ mask.nn.DataSetLabelMe_extract(tags, mask, x_start, x_end, y_start, y_end, self.currentIndex)
+ return tags
+end
+
+function DataSetLabelMe:parseAllMasks(saveFile)
+ -- if cache file exists, just retrieve tags from it
+ if self.cacheFile and paths.filep(paths.concat(self.path,self.cacheFile..'-tags')) then
+ print('<DataSetLabelMe> retrieving saved tags from: ' .. paths.concat(self.path,self.cacheFile..'-tags')
+ .. ' [delete file to force new scan]')
+ local file = torch.DiskFile(paths.concat(self.path,self.cacheFile..'-tags'), 'r')
+ file:binary()
+ self.tags = file:readObject()
+ file:close()
+ return
+ end
+ -- parse tags, long operation
+ print('<DataSetLabelMe> parsing all masks to generate list of tags')
+ print('<DataSetLabelMe> WARNING: this operation could allocate up to '..
+ math.ceil(self.nbRawSamples*self.rawSampleMaxSize*self.rawSampleMaxSize*
+ 3*2/1024/1024)..'MB')
+ self.tags = nil
+ for i = 1,self.nbRawSamples do
+ toolBox.dispProgress(i,self.nbRawSamples)
+ self:loadSample(i)
+ self.tags = self:parseMask(self.tags)
+ end
+ -- report
+ print('<DataSetLabelMe> nb of patches extracted per category:')
+ for i = 1,self.nbClasses do
+ print(' ' .. i .. ' - ' .. self.tags[i].size / 3)
+ end
+ -- optional cache file
+ if saveFile then
+ self.cacheFile = saveFile
+ end
+ -- if cache file exists, serialize list of tags to it
+ if self.cacheFile then
+ print('<DataSetLabelMe> saving tags to cache file: ' .. paths.concat(self.path,self.cacheFile..'-tags'))
+ local file = torch.DiskFile(paths.concat(self.path,self.cacheFile..'-tags'), 'w')
+ file:binary()
+ file:writeObject(self.tags)
+ file:close()
+ end
+end
+
+function DataSetLabelMe:exportIDX(samplefile, labelfile)
+ -- current limitation for IDX files
+ local idxMaxSize = 2^31-1
+
+ if samplefile then
+ -- message
+ print('<DataSetLabelMe> exporting data to '..samplefile..'-n|N.idx')
+
+ -- check for global size
+ local channels = self.preloaded.samples[1]:size(3)
+ local exportSize = self.rawSampleMaxSize^2 * channels * self.nbRawSamples * 4
+ local nbFiles = math.ceil(exportSize / idxMaxSize)
+ local offset = 0
+
+ for n = 1,nbFiles do
+ local exported
+ local nbSamplesPerFile = math.floor(idxMaxSize / (self.rawSampleMaxSize^2 * channels * 4))
+ if n == nbFiles then
+ nbSamplesPerFile = self.nbRawSamples - ((nbFiles-1)*nbSamplesPerFile)
+ end
+ exported = torch.FloatTensor(self.rawSampleMaxSize, self.rawSampleMaxSize,
+ channels, nbSamplesPerFile)
+
+ local filename = samplefile..'-'..string.format("%05d",n)..'|'..string.format("%05d",nbFiles)..'.idx'
+ print('+ doing '..filename..' ('..exported:size(4)..' samples)')
+ if not paths.filep(filename) then
+ -- export samples
+ for i = 1,exported:size(4) do
+ toolBox.dispProgress(i,exported:size(4))
+ local sample = self.preloaded.samples[offset+i]
+ local w = sample:size(1)
+ local h = sample:size(2)
+ exported:select(4,i):narrow(1,1,w):narrow(2,1,h):copy(sample)
+ end
+ offset = offset + exported:size(4)
+
+ -- write file
+ local file = torch.DiskFile(filename,'w')
+ file:binary()
+ file:writeInt(0x1e3d4c51) -- float type
+ file:writeInt(4) -- nb dims
+ file:writeInt(exported:size(4)) -- dim[0]
+ file:writeInt(exported:size(3)) -- dim[1]
+ file:writeInt(exported:size(2)) -- dim[2]
+ file:writeInt(exported:size(1)) -- dim[3]
+ file:writeFloat(exported:storage()) -- data
+ file:close()
+ end
+ end
+ end
+ if labelfile then
+ print('<DataSetLabelMe> exporting labels to '..labelfile..'-N.idx (N=0..'..tostring(#self.tags-1)..')')
+ -- export each tag list in a separate file
+ local tags = self.tags
+ local nbtags = #tags
+ for i = 1,nbtags do
+ local exportSize = tags[i].size
+ toolBox.dispProgress(i,nbtags)
+ if exportSize ~= 0 and not paths.filep(labelfile..'-'..string.format("%05d",i-1)..'.idx') then
+ local dest = torch.ShortTensor(exportSize)
+ local src = torch.ShortTensor():set(tags[i].data,1,exportSize)
+ dest:copy(src)
+ -- subtract 1 from the whole tensor, to go to 0-based indexing
+ dest:apply(function (x) return x-1 end)
+ -- write file
+ local file = torch.DiskFile(labelfile..'-'..string.format("%05d",i-1)..'.idx', 'w')
+ file:binary()
+ file:writeInt(0x1e3d4c56) -- short type
+ file:writeInt(2) -- nb dims
+ file:writeInt(exportSize/3) -- dim[0]
+ file:writeInt(3) -- dim[1]
+ file:writeInt(1) -- unused dim
+ file:writeShort(dest:storage()) -- data
+ file:close()
+ -- garbage cleaning
+ dest = nil
+ collectgarbage()
+ end
+ end
+ end
+end
+
+function DataSetLabelMe:display(args)
+ -- parse args:
+ local title = args.title or 'DataSetLabelMe'
+ local min = args.min
+ local max = args.max
+ local nbSamples = args.nbSamples or 50
+ local scale = args.scale
+ local resX = args.resX or 1200
+ local resY = args.resY or 800
+
+ -- compute some geometry params
+ local painter = qtwidget.newwindow(resX,resY,title)
+ local step_x = 0
+ local step_y = 0
+ local sizeX = self.maxX
+ local sizeY = self.maxY
+ if not scale then
+ scale = math.sqrt(resX*resY/ (sizeX*sizeY*nbSamples))
+ local nbx = math.floor(resX/(scale*sizeX))
+ scale = resX/(sizeX*nbx)
+ end
+
+ -- load the samples and display them
+ local dispTensor = torch.Tensor(sizeX*scale,sizeY*scale,3)
+ local dispMask = torch.Tensor(sizeX*scale,sizeY*scale)
+ local displayer = Displayer()
+ for i=1,nbSamples do
+ toolBox.dispProgress(i,nbSamples)
+ self:loadSample(i)
+ image.scale(self.currentSample, dispTensor, 'simple')
+ image.scale(self.currentMask, dispMask, 'simple')
+ local displayed = image.mergeSegmentation(dispTensor, dispMask, self.colorMap)
+ if (step_x > (resX-sizeX*scale)) then
+ step_x = 0
+ step_y = step_y + sizeY*scale
+ if (step_y > (resY-sizeY*scale)) then
+ break
+ end
+ end
+ displayer:show{painter=painter,
+ tensor=displayed,
+ min=min, max=max,
+ offset_x=step_x,
+ offset_y=step_y}
+ step_x = step_x + sizeX*scale
+ end
+end
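
exportIDX above writes a simple header-plus-raw-data layout through torch.DiskFile: a magic int (0x1e3d4c51 for float data, 0x1e3d4c56 for short), the number of dimensions, the dimension sizes, then the data. A sketch of reading one exported sample file back, assuming the same DiskFile API and a hypothetical filename:

    local f = torch.DiskFile('samples-00001|00001.idx', 'r')
    f:binary()
    local magic     = f:readInt()   -- 0x1e3d4c51: float data
    local ndims     = f:readInt()   -- 4
    local nsamples  = f:readInt()   -- dim[0]
    local nchannels = f:readInt()   -- dim[1]
    local height    = f:readInt()   -- dim[2]
    local width     = f:readInt()   -- dim[3]
    local data = f:readFloat(nsamples*nchannels*height*width)
    f:close()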
diff --git a/generic/DataSetLabelMe.c b/generic/DataSetLabelMe.c
new file mode 100644
index 0000000..cea78df
--- /dev/null
+++ b/generic/DataSetLabelMe.c
@@ -0,0 +1,54 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/DataSetLabelMe.c"
+#else
+
+static int nn_(DataSetLabelMe_extract)(lua_State *L)
+{
+ const void* torch_ShortStorage_id = luaT_checktypename2id(L, "torch.ShortStorage");
+ int tags = 1;
+ THTensor *mask = luaT_checkudata(L, 2, torch_(Tensor_id));
+ int x_start = lua_tonumber(L, 3);
+ int x_end = lua_tonumber(L, 4);
+ int y_start = lua_tonumber(L, 5);
+ int y_end = lua_tonumber(L, 6);
+ int idx = lua_tonumber(L, 7);
+
+ int x,y,label,tag,size;
+ THShortStorage *data;
+ for (x=x_start; x<=x_end; x++) {
+ for (y=y_start; y<=y_end; y++) {
+ label = THTensor_(get2d)(mask, x-1, y-1); // label = mask[x][y]
+ lua_rawgeti(L, tags, label); // tag = tags[label]
+ tag = lua_gettop(L);
+ lua_pushstring(L, "size"); lua_rawget(L, tag); // size = tag.size
+ size = lua_tonumber(L,-1); lua_pop(L,1);
+ lua_pushstring(L, "size"); lua_pushnumber(L, size+3); lua_rawset(L, tag); // tag.size = size + 3
+ lua_pushstring(L, "data"); lua_rawget(L, tag); // data = tag.data
+ data = luaT_checkudata(L, -1, torch_ShortStorage_id); lua_pop(L, 1);
+ data->data[size] = x; // data[size+1] = x
+ data->data[size+1] = y; // data[size+2] = y
+ data->data[size+2] = idx; // data[size+3] = idx
+ lua_pop(L, 1);
+ }
+ }
+ return 0;
+}
+
+static int nn_(DataSetLabelMe_backward)(lua_State *L)
+{
+ // unused stub: not registered in nn_(DataSetLabelMe__) below
+ return 0;
+}
+
+static const struct luaL_Reg nn_(DataSetLabelMe__) [] = {
+ {"DataSetLabelMe_extract", nn_(DataSetLabelMe_extract)},
+ {NULL, NULL}
+};
+
+static void nn_(DataSetLabelMe_init)(lua_State *L)
+{
+ luaT_pushmetaclass(L, torch_(Tensor_id));
+ luaT_registeratname(L, nn_(DataSetLabelMe__), "nn");
+ lua_pop(L,1);
+}
+
+#endif
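
On the Lua side, each tags[class] entry filled by this routine is a pair {data=torch.ShortStorage, size=n} holding flat (x, y, image-index) triplets, which is why __index__ above always steps through tag indices in strides of 3. A sketch of walking one class's triplets (ds and c are hypothetical):

    local t = ds.tags[c]
    for i = 1,t.size,3 do
       local x, y, imgidx = t.data[i], t.data[i+1], t.data[i+2]
       -- (x,y) is a usable patch centre inside raw image imgidx
    end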
diff --git a/init.c b/init.c
index 0f0936f..51123f6 100644
--- a/init.c
+++ b/init.c
@@ -51,6 +51,9 @@ static const void* torch_DoubleTensor_id = NULL;
#include "generic/SpatialGraph.c"
#include "THGenerateFloatTypes.h"
+#include "generic/DataSetLabelMe.c"
+#include "THGenerateFloatTypes.h"
+
DLL_EXPORT int luaopen_libnnx(lua_State *L)
{
torch_FloatTensor_id = luaT_checktypename2id(L, "torch.FloatTensor");
@@ -70,6 +73,7 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
nn_FloatSpatialMSECriterion_init(L);
nn_FloatSpatialClassNLLCriterion_init(L);
nn_FloatSpatialGraph_init(L);
+ nn_FloatDataSetLabelMe_init(L);
nn_DoubleSpatialLinear_init(L);
nn_DoubleHardShrink_init(L);
@@ -85,6 +89,7 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
nn_DoubleSpatialMSECriterion_init(L);
nn_DoubleSpatialClassNLLCriterion_init(L);
nn_DoubleSpatialGraph_init(L);
+ nn_DoubleDataSetLabelMe_init(L);
return 1;
}
diff --git a/init.lua b/init.lua
index 02757c5..dd7d777 100644
--- a/init.lua
+++ b/init.lua
@@ -95,3 +95,4 @@ torch.include('nnx', 'StochasticTrainer.lua')
-- datasets:
torch.include('nnx', 'DataSet.lua')
torch.include('nnx', 'DataList.lua')
+torch.include('nnx', 'DataSetLabelMe.lua')
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index 20fe9e5..c04137c 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -78,6 +78,7 @@ build = {
install_files(/lua/nnx StochasticTrainer.lua)
install_files(/lua/nnx DataSet.lua)
install_files(/lua/nnx DataList.lua)
+ install_files(/lua/nnx DataSetLabelMe.lua)
install_files(/lua/nnx CMulTable.lua)
install_files(/lua/nnx CAddTable.lua)
install_files(/lua/nnx CDivTable.lua)