diff options
author | Camille Couprie <ccouprie@cs.nyu.edu> | 2012-09-07 00:33:10 +0400 |
---|---|---|
committer | Camille Couprie <ccouprie@cs.nyu.edu> | 2012-09-07 00:33:10 +0400 |
commit | ad88215a5b0a3933405f5836e5c8254836633897 (patch) | |
tree | 8f8e0f69a26bd17b9bf03fd18b1f1acad5948035 /DataSetSamplingPascal.lua | |
parent | 170449aa65dcc0d0c84651764fe893c5723cd758 (diff) |
improved the sampling strategy for pascal segments dataset
Diffstat (limited to 'DataSetSamplingPascal.lua')
-rw-r--r-- | DataSetSamplingPascal.lua | 322 |
1 files changed, 125 insertions, 197 deletions
diff --git a/DataSetSamplingPascal.lua b/DataSetSamplingPascal.lua index f4104a6..4c9ed48 100644 --- a/DataSetSamplingPascal.lua +++ b/DataSetSamplingPascal.lua @@ -1,12 +1,11 @@ - - -------------------------------------------------------------------------------- --- DataSetSamplingPascal: A class to handle datasets from Pascal +-- DataSetSamplingPascal: A class to handle datasets from LabelMe (and other segmentation +-- based datasets). -- --- Provides options to cache (on disk) dataset, precompute --- segmentation masks, shuffle samples, filter class frequency ... +-- Provides lots of options to cache (on disk) datasets, precompute +-- segmentation masks, shuffle samples, extract subpatches, ... -- --- Authors: Clement Farabet, Benoit Corda, Camille Couprie +-- Authors: Clement Farabet, Benoit Corda -------------------------------------------------------------------------------- local DataSetSamplingPascal = torch.class('DataSetSamplingPascal') @@ -15,54 +14,29 @@ local path_images = 'Images' local path_annotations = 'Annotations' local path_masks = 'Masks' - -colorobject={ -[1]={224,224,192},-- unknown -[2]={0, 0, 0},-- background -[3]={128, 0, 0}, -[4]={0, 128, 0}, -[5]={128, 128, 0}, -[6]={0, 0, 128}, -[7]={128, 0, 128}, -[8]={0, 128, 128}, -[9]={128, 128, 128}, -[10]={64, 0, 0}, -[11]={192, 0, 0}, -[12]={64, 128, 0}, -[13]={192, 128, 0}, -[14]={64, 0, 128}, -[15]={192, 0, 128}, -[16]={64, 128, 128}, -[17]={192, 128, 128}, -[18]={0, 64, 0}, -[19]={128, 64, 0}, -[20]={0, 192, 0}, -[21]={128, 192, 0}, -[22]={0, 64, 128}} - - ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:__init(...) -- check args xlua.unpack_class( self, {...}, 'DataSetSamplingPascal', - 'Creates a DataSet from standard Pascal directories (Images+Annotations)', - {arg='path', type='string', help='path to Pascal directory', req=true}, + 'Creates a DataSet from standard LabelMe directories (Images+Annotations)', + {arg='path', type='string', help='path to LabelMe directory', req=true}, {arg='nbClasses', type='number', help='number of classes in dataset', default=1}, - {arg='nbSegments', type='number', help='number of segment per image in dataset', default=100}, {arg='classNames', type='table', help='list of class names', default={'no name'}}, {arg='nbRawSamples', type='number', help='number of images'}, + {arg='nbSegments', type='number', help='number of segment per image in dataset', default=100}, {arg='rawSampleMaxSize', type='number', help='resize all images to fit in a MxM window'}, {arg='rawSampleSize', type='table', help='resize all images precisely: {w=,h=}}'}, {arg='rawMaskRescale',type='boolean',help='does are the N classes spread between 0->255 in the PNG and need to be rescaled',default=true}, - {arg='samplingMode', type='string', help='segment sampling method: random | equal', default='random'}, + {arg='nbPatchPerSample', type='number', help='number of patches to extract from each image', default=100}, + {arg='patchSize', type='number', help='size of patches to extract from images', default=64}, + {arg='samplingMode', type='string', help='patch sampling method: random | equal', default='random'}, + {arg='samplingFilter', type='table', help='a filter to sample patches: {ratio=,size=,step}'}, {arg='labelType', type='string', help='type of label returned: center | pixelwise', default='center'}, + {arg='labelGenerator', type='function', help='a function to generate sample+target (bypasses labelType)'}, {arg='infiniteSet', type='boolean', help='if true, the set can be indexed to infinity, looping around samples', default=false}, - {arg='classToSkip', type='number', help='index of class to skip during sampling', default=1}, - {arg='ScClassToSkip', type='number', help='index of class to skip during sampling', default=1}, + {arg='classToSkip', type='number', help='index of class to skip during sampling', default=0}, {arg='preloadSamples', type='boolean', help='if true, all samples are preloaded in memory', default=false}, {arg='cacheFile', type='string', help='path to cache file (once cached, loading is much faster)'}, {arg='verbose', type='boolean', help='dumps information', default=false} @@ -72,13 +46,15 @@ function DataSetSamplingPascal:__init(...) self.colorMap = image.colormap(self.nbClasses) self.rawdata = {} self.currentIndex = -1 - self.realIndex = -1 - self.ctr_segment_index = 0 - self.ctr_gt_index = 0 + --location of the patch in the img + self.currentX = 0 + self.currentY = 0 + self.realIndex = -1 + self.currentSegment = 0 -- parse dir structure - print('<DataSetSamplingPascal> loading Pascal dataset from '..self.path) + print('<DataSetSamplingPascal> loading LabelMe dataset from '..self.path) for folder in paths.files(paths.concat(self.path,path_images)) do if folder ~= '.' and folder ~= '..' then -- allowing for less nesting in the data set preparation [MS] @@ -89,7 +65,7 @@ function DataSetSamplingPascal:__init(...) for file in paths.files(paths.concat(self.path,path_images,folder)) do if file ~= '.' and file ~= '..' then - self:getsizes(folder,file) + self:getsizes(folder,file) end end end @@ -110,7 +86,8 @@ function DataSetSamplingPascal:__init(...) self.maxY = self.rawdata[i].size[2] end end - self.nbSamples = self.nbRawSamples + -- and nb of samples obtainable (this is overcomplete ;-) + self.nbSamples = self.nbPatchPerSample * self.nbRawSamples -- max size ? local maxXY = math.max(self.maxX, self.maxY) @@ -131,22 +108,23 @@ function DataSetSamplingPascal:__init(...) print(self) end + -- sampling mode if self.samplingMode == 'equal' or self.samplingMode == 'random' then self:parseAllMasks() if self.samplingMode == 'random' then - -- get the number of usable segments - self.nbRandomSegments = 0 + -- get the number of usable patches + self.nbRandomPatches = 0 for i,v in ipairs(self.tags) do - if i ~= self.classToSkip and i ~= self.ScClassToSkip then - self.nbRandomSegments = self.nbRandomSegments + v.size + if i ~= self.classToSkip then + self.nbRandomPatches = self.nbRandomPatches + v.size end end -- create shuffle table - self.randomLookup = torch.ByteTensor(self.nbRandomSegments) + self.randomLookup = torch.ByteTensor(self.nbRandomPatches) local idx = 1 for i,v in ipairs(self.tags) do - if i ~= self.classToSkip and i ~= self.ScClassToSkip and v.size > 0 then + if i ~= self.classToSkip and v.size > 0 then self.randomLookup:narrow(1,idx,v.size):fill(i) idx = idx + v.size end @@ -156,14 +134,14 @@ function DataSetSamplingPascal:__init(...) error('ERROR <DataSetSamplingPascal> unknown sampling mode') end + + -- preload ? if self.preloadSamples then self:preload() end end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:getsizes(folder,file) local filepng = file:gsub('jpg$','png') local filexml = file:gsub('jpg$','xml') @@ -196,14 +174,10 @@ function DataSetSamplingPascal:getsizes(folder,file) size={size_c, size_y, size_x}}) end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:size() return self.nbSamples end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:__tostring__() local str = 'DataSetSamplingPascal:\n' str = str .. ' + path : '..self.path..'\n' @@ -211,6 +185,7 @@ function DataSetSamplingPascal:__tostring__() str = str .. ' + cache files : [path]/'..self.cacheFile..'-[tags|samples]\n' end str = str .. ' + nb samples : '..self.nbRawSamples..'\n' + str = str .. ' + nb generated patches : '..self.nbSamples..'\n' if self.infiniteSet then str = str .. ' + infinite set (actual nb of samples >> set:size())\n' end @@ -222,15 +197,16 @@ function DataSetSamplingPascal:__tostring__() str = str .. ' + imposed ratio of ' .. self.rawSampleSize.w .. 'x' .. self.rawSampleSize.h .. '\n' end end + str = str .. ' + patches size : ' .. self.patchSize .. 'x' .. self.patchSize .. '\n' if self.classToSkip ~= 0 then str = str .. ' + unused class : ' .. self.classNames[self.classToSkip] .. '\n' end - if self.ScClassToSkip ~= 0 then - str = str .. ' + unused class : ' .. self.classNames[self.ScClassToSkip] .. '\n' - end - str = str .. ' + sampling mode : ' .. self.samplingMode .. '\n' - str = str .. ' + label type : ' .. self.labelType .. '\n' + if not self.labelGenerator then + str = str .. ' + label type : ' .. self.labelType .. '\n' + else + str = str .. ' + label type : generated by user function \n' + end str = str .. ' + '..self.nbClasses..' categories : ' for i = 1,#self.classNames-1 do str = str .. self.classNames[i] .. ' | ' @@ -239,51 +215,75 @@ function DataSetSamplingPascal:__tostring__() return str end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:__index__(key) - -- generate sample + target at index 'key': if type(key)=='number' then -- select sample, according to samplingMode - + local box_size = self.patchSize local ctr_target, tag_idx if self.samplingMode == 'random' then -- get indexes from random table ctr_target = self.randomLookup[math.random(1,self.nbRandomPatches)] - tag_idx = math.floor(math.random(0,self.tags[ctr_target].size-1)/3)*3+1 + tag_idx = math.floor(math.random(0,self.tags[ctr_target].size-1)/4)*4+1 elseif self.samplingMode == 'equal' then -- equally sample each category: ctr_target = ((key-1) % (self.nbClasses)) + 1 - while self.tags[ctr_target].size == 0 or ctr_target == self.classToSkip or ctr_target == self.ScClassToSkip do + while self.tags[ctr_target].size == 0 or ctr_target == self.classToSkip do -- no sample in that class, replacing with random patch ctr_target = math.floor(torch.uniform(1,self.nbClasses)) end local nbSamplesPerClass = math.ceil(self.nbSamples / self.nbClasses) if self.infiniteSet then - tag_idx = math.random(1,self.tags[ctr_target].size/3) + tag_idx = math.random(1,self.tags[ctr_target].size/4) else tag_idx = math.floor((key-1)/self.nbClasses) + 1 end - tag_idx = ((tag_idx-1) % (self.tags[ctr_target].size/3))*3 + 1 + tag_idx = ((tag_idx-1) % (self.tags[ctr_target].size/4))*4 + 1 end -- generate patch + self:loadSample(self.tags[ctr_target].data[tag_idx+2]) - local sample = self.currentSample - local mask = self.currentMask - self.ctr_segment_index = self.tags[ctr_target].data[tag_idx] - self.ctr_gt_index = self.tags[ctr_target].data[tag_idx+1] + local full_sample = self.currentSample + local full_mask = self.currentMask + local ctr_x = self.tags[ctr_target].data[tag_idx] + local ctr_y = self.tags[ctr_target].data[tag_idx+1] + local box_x = math.floor(ctr_x - box_size/2) + 1 + self.currentX = box_x/full_sample:size(3) + local box_y = math.floor(ctr_y - box_size/2) + 1 + self.currentY = box_y/full_sample:size(2) + self.currentSegment = self.tags[ctr_target].data[tag_idx+3] + + -- extract sample + mask: + local sample = full_sample:narrow(2,box_y,box_size):narrow(3,box_x,box_size) + local mask = full_mask:narrow(1,box_y,box_size):narrow(2,box_x,box_size) + + -- finally, generate the target, either using an arbitrary user function, + -- or a built-in label type + if self.labelGenerator then + -- call user function to generate sample+label + local ret = self:labelGenerator(full_sample, full_mask, sample, mask, + ctr_target, ctr_x, ctr_y, box_x, box_y, box_size,self.currentSegment,self.realIndex) + return ret, true + + elseif self.labelType == 'center' then + -- generate label vector for patch + local vector = torch.Tensor(self.nbClasses):fill(-1) + vector[ctr_target] = 1 + return {sample, vector}, true + + elseif self.labelType == 'pixelwise' then + -- generate pixelwise annotation + return {sample, mask}, true - -- Ce serait bien de rajouter le vecteur overlap pour ne pas le recalculer - return {sample,mask,self.ctr_segment_index, self.ctr_gt_index}, true - end + else + return false + end + end return rawget(self,key) end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:loadSample(index) if self.preloadedDone then @@ -374,8 +374,6 @@ function DataSetSamplingPascal:loadSample(index) end end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:preload(saveFile) -- if cache file exists, just retrieve images from it if self.cacheFile @@ -418,130 +416,70 @@ function DataSetSamplingPascal:preload(saveFile) end end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:parseMask(existing_tags) local tags if not existing_tags then tags = {} local storage for i = 1,self.nbClasses do - storage = torch.ShortStorage(self.rawSampleMaxSize*self.rawSampleMaxSize*3) + storage = torch.ShortStorage(self.rawSampleMaxSize*self.rawSampleMaxSize*4) tags[i] = {data=storage, size=0} end - else tags = existing_tags -- make sure each tag list is large enough to hold the incoming data for i = 1,self.nbClasses do - --print('current size'..tags[i].size) - if ((tags[i].size + (self.rawSampleMaxSize*self.rawSampleMaxSize*3)) > + if ((tags[i].size + (self.rawSampleMaxSize*self.rawSampleMaxSize*4)) > tags[i].data:size()) then - tags[i].data:resize(tags[i].size+(self.rawSampleMaxSize*self.rawSampleMaxSize*3),true) + tags[i].data:resize(tags[i].size+(self.rawSampleMaxSize*self.rawSampleMaxSize*4),true) end end end + -- use filter + local filter = self.samplingFilter or {ratio=0, size=self.patchSize, step=4} -- extract labels local mask = self.currentMask - - -- we dont need much precison for what follows to we downsample images to get faster results - smallmask = torch.Tensor(100,100) - image.scale(mask, smallmask, 'simple') - - -- (1) Compute overlap score for each segment given the image - - local file = sys.concat(self.realIndex:gsub('Images','Segments'),'.mat') - local mat_path = file:gsub('/.mat$','.mat') - local loaded = mattorch.load(mat_path) - loaded = loaded.top_masks:float() - local segment1, segmenttmp - nb_segments = self.nbSegments - if self.nbSegments > loaded:size(1) then nb_segments = loaded:size(1) end - - for k=1,nb_segments do - - -- (a) load one segment mask - segment1 = loaded[k]:t() - - -- (b) resize the segment mask - segmenttmp = image.scale(segment1, 100,100) - - -- (c) compute overlap - local overlap = imgraph.overlap(segmenttmp, smallmask,#classes_pascal) - - -- (2) If overlap score ok, add the segment index to tags - for i = 1,self.nbClasses do - if overlap[i]>0.2 and overlap[i]<0.99 then - tags[i].data[tags[i].size+1] = k - if overlap[i]>0.5 then - tags[i].data[tags[i].size+2] = i+100 -- this tag corresponds to a real segmentation - else - tags[i].data[tags[i].size+2] = i - end - tags[i].data[tags[i].size+3] = self.currentIndex - tags[i].size = tags[i].size+3 - -- print('insert '..k..'; '..tags[i].size ) - end - end - - - end - - -- (2) load the object ground truth image - - file = sys.concat(self.realIndex:gsub('Images','Objects'),'.png') - local mask_path = file:gsub('/.png$','.png') - - maskobject = image.load(mask_path) - maskobject= maskobject:mul(255) - smallmaskobject = torch.Tensor(3,100,100) - image.scale(maskobject, smallmaskobject, 'simple') - - maskobject = smallmaskobject:floor() - - -- extracting segments from ground truth - -i= 3 -still_run = 1 -while still_run~=0 and i <=21 do -- 21 nb max of different objects - still_run = 0 - ii=1 - already_taged = 0 - while ii<=100 do - jj=1 - while jj<=100 do -if already_taged == 0 then - if maskobject[1][ii][jj] == colorobject[i][1] and maskobject[2][ii][jj] == colorobject[i][2] - and maskobject[3][ii][jj] == colorobject[i][3] then - k = smallmask[ii][jj] - tags[k].data[tags[k].size+1] = 0 --this tag corresponds to a ground truth object - tags[k].data[tags[k].size+2] = i-- index of a ground truth object - tags[k].data[tags[k].size+3] = self.currentIndex - tags[k].size = tags[k].size+3 - already_taged = 1 - -- print('classe'..k..'objet'..i ) - end -end - if maskobject[1][ii][jj] == colorobject[i+1][1] and maskobject[2][ii][jj] == colorobject[i+1][2] - and maskobject[3][ii][jj] == colorobject[i+1][3] then - still_run = still_run+1 - if already_taged == 1 then - ii = 100 - jj= 100 - end - end - jj=jj+1 - end - ii=ii+1 - end - i=i+1 -end + local x_start = math.ceil(self.patchSize/2) + local x_end = mask:size(2) - math.ceil(self.patchSize/2) + local y_start = math.ceil(self.patchSize/2) + local y_end = mask:size(1) - math.ceil(self.patchSize/2) + + + local file = sys.concat(self.realIndex:gsub('Images','Segments'),'.mat') + local mat_path = file:gsub('/.mat$','.mat') + local loaded = mattorch.load(mat_path) + loaded = loaded.top_masks:float() + local segment1, segmenttmp + nb_segments = self.nbSegments + if self.nbSegments > loaded:size(1) then nb_segments = loaded:size(1) end + + -- (1) load a random segment + -- for i=1,self.nbPatchPerSample do + + k = math.random(nb_segments) + segment1 = loaded[k]:t() + + segmenttmp = image.scale(segment1, width, height) + -- segmenttmp = segmenttmp:narrow(2, x_start, x_end):narrow(1,y_start, y_end) + + -- (2) mask the ground truth mask with the random segment. + segmenttmp:cmul(mask:add(-1)):add(1) + + --print('extract labels segment '..k) + self.currentSegment = k + --print('self'..self.currentSegment) + + + mask.nn.DataSetSegmentSampling_extract(tags, segmenttmp, + x_start, x_end, + y_start, y_end, self.currentIndex, self.currentSegment, + filter.ratio, filter.size, filter.step) + -- end + return tags end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:parseAllMasks(saveFile) -- if cache file exists, just retrieve tags from it if self.cacheFile and paths.filep(paths.concat(self.path,self.cacheFile..'-tags')) then @@ -557,7 +495,7 @@ function DataSetSamplingPascal:parseAllMasks(saveFile) print('<DataSetSamplingPascal> parsing all masks to generate list of tags') print('<DataSetSamplingPascal> WARNING: this operation could allocate up to '.. math.ceil(self.nbRawSamples*self.rawSampleMaxSize*self.rawSampleMaxSize* - 3*2/1024/1024)..'MB') + 4*2/1024/1024)..'MB') self.tags = nil for i = 1,self.nbRawSamples do xlua.progress(i,self.nbRawSamples) @@ -565,9 +503,9 @@ function DataSetSamplingPascal:parseAllMasks(saveFile) self.tags = self:parseMask(self.tags) end -- report - print('<DataSetSamplingPascal> nb of segment extracted per category:') + print('<DataSetSamplingPascal> nb of patches extracted per category:') for i = 1,self.nbClasses do - print(' ' .. i .. ' - ' .. self.tags[i].size/3) + print(' ' .. i .. ' - ' .. self.tags[i].size / 4) end -- optional cache file if saveFile then @@ -583,8 +521,6 @@ function DataSetSamplingPascal:parseAllMasks(saveFile) end end ---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ - function DataSetSamplingPascal:display(...) -- check args local _, title, samples, zoom = xlua.unpack( @@ -617,11 +553,3 @@ function DataSetSamplingPascal:display(...) -- display image.display{win=painter, image=allimgs, legend=title, zoom=0.5} end - - - - - - - - |