Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Camille Couprie <ccouprie@cs.nyu.edu> 2012-09-07 00:33:10 +0400
committer Camille Couprie <ccouprie@cs.nyu.edu> 2012-09-07 00:33:10 +0400
commit ad88215a5b0a3933405f5836e5c8254836633897 (patch)
tree 8f8e0f69a26bd17b9bf03fd18b1f1acad5948035 /DataSetSamplingPascal.lua
parent 170449aa65dcc0d0c84651764fe893c5723cd758 (diff)
improved the sampling strategy for pascal segments dataset
Diffstat (limited to 'DataSetSamplingPascal.lua')
-rw-r--r-- DataSetSamplingPascal.lua | 322
1 files changed, 125 insertions, 197 deletions
diff --git a/DataSetSamplingPascal.lua b/DataSetSamplingPascal.lua
index f4104a6..4c9ed48 100644
--- a/DataSetSamplingPascal.lua
+++ b/DataSetSamplingPascal.lua
@@ -1,12 +1,11 @@
-
-
--------------------------------------------------------------------------------
--- DataSetSamplingPascal: A class to handle datasets from Pascal
+-- DataSetSamplingPascal: A class to handle datasets from LabelMe (and other segmentation
+-- based datasets).
--
--- Provides options to cache (on disk) dataset, precompute
--- segmentation masks, shuffle samples, filter class frequency ...
+-- Provides lots of options to cache (on disk) datasets, precompute
+-- segmentation masks, shuffle samples, extract subpatches, ...
--
--- Authors: Clement Farabet, Benoit Corda, Camille Couprie
+-- Authors: Clement Farabet, Benoit Corda
--------------------------------------------------------------------------------
local DataSetSamplingPascal = torch.class('DataSetSamplingPascal')
@@ -15,54 +14,29 @@ local path_images = 'Images'
local path_annotations = 'Annotations'
local path_masks = 'Masks'
-
-colorobject={
-[1]={224,224,192},-- unknown
-[2]={0, 0, 0},-- background
-[3]={128, 0, 0},
-[4]={0, 128, 0},
-[5]={128, 128, 0},
-[6]={0, 0, 128},
-[7]={128, 0, 128},
-[8]={0, 128, 128},
-[9]={128, 128, 128},
-[10]={64, 0, 0},
-[11]={192, 0, 0},
-[12]={64, 128, 0},
-[13]={192, 128, 0},
-[14]={64, 0, 128},
-[15]={192, 0, 128},
-[16]={64, 128, 128},
-[17]={192, 128, 128},
-[18]={0, 64, 0},
-[19]={128, 64, 0},
-[20]={0, 192, 0},
-[21]={128, 192, 0},
-[22]={0, 64, 128}}
-
-
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:__init(...)
-- check args
xlua.unpack_class(
self,
{...},
'DataSetSamplingPascal',
- 'Creates a DataSet from standard Pascal directories (Images+Annotations)',
- {arg='path', type='string', help='path to Pascal directory', req=true},
+ 'Creates a DataSet from standard LabelMe directories (Images+Annotations)',
+ {arg='path', type='string', help='path to LabelMe directory', req=true},
{arg='nbClasses', type='number', help='number of classes in dataset', default=1},
- {arg='nbSegments', type='number', help='number of segment per image in dataset', default=100},
{arg='classNames', type='table', help='list of class names', default={'no name'}},
{arg='nbRawSamples', type='number', help='number of images'},
+ {arg='nbSegments', type='number', help='number of segment per image in dataset', default=100},
{arg='rawSampleMaxSize', type='number', help='resize all images to fit in a MxM window'},
{arg='rawSampleSize', type='table', help='resize all images precisely: {w=,h=}}'},
{arg='rawMaskRescale',type='boolean',help='does are the N classes spread between 0->255 in the PNG and need to be rescaled',default=true},
- {arg='samplingMode', type='string', help='segment sampling method: random | equal', default='random'},
+ {arg='nbPatchPerSample', type='number', help='number of patches to extract from each image', default=100},
+ {arg='patchSize', type='number', help='size of patches to extract from images', default=64},
+ {arg='samplingMode', type='string', help='patch sampling method: random | equal', default='random'},
+ {arg='samplingFilter', type='table', help='a filter to sample patches: {ratio=,size=,step}'},
{arg='labelType', type='string', help='type of label returned: center | pixelwise', default='center'},
+ {arg='labelGenerator', type='function', help='a function to generate sample+target (bypasses labelType)'},
{arg='infiniteSet', type='boolean', help='if true, the set can be indexed to infinity, looping around samples', default=false},
- {arg='classToSkip', type='number', help='index of class to skip during sampling', default=1},
- {arg='ScClassToSkip', type='number', help='index of class to skip during sampling', default=1},
+ {arg='classToSkip', type='number', help='index of class to skip during sampling', default=0},
{arg='preloadSamples', type='boolean', help='if true, all samples are preloaded in memory', default=false},
{arg='cacheFile', type='string', help='path to cache file (once cached, loading is much faster)'},
{arg='verbose', type='boolean', help='dumps information', default=false}
@@ -72,13 +46,15 @@ function DataSetSamplingPascal:__init(...)
self.colorMap = image.colormap(self.nbClasses)
self.rawdata = {}
self.currentIndex = -1
- self.realIndex = -1
- self.ctr_segment_index = 0
- self.ctr_gt_index = 0
+ --location of the patch in the img
+ self.currentX = 0
+ self.currentY = 0
+ self.realIndex = -1
+ self.currentSegment = 0
-- parse dir structure
- print('<DataSetSamplingPascal> loading Pascal dataset from '..self.path)
+ print('<DataSetSamplingPascal> loading LabelMe dataset from '..self.path)
for folder in paths.files(paths.concat(self.path,path_images)) do
if folder ~= '.' and folder ~= '..' then
-- allowing for less nesting in the data set preparation [MS]
@@ -89,7 +65,7 @@ function DataSetSamplingPascal:__init(...)
for file in paths.files(paths.concat(self.path,path_images,folder)) do
if file ~= '.' and file ~= '..' then
- self:getsizes(folder,file)
+ self:getsizes(folder,file)
end
end
end
@@ -110,7 +86,8 @@ function DataSetSamplingPascal:__init(...)
self.maxY = self.rawdata[i].size[2]
end
end
- self.nbSamples = self.nbRawSamples
+ -- and nb of samples obtainable (this is overcomplete ;-)
+ self.nbSamples = self.nbPatchPerSample * self.nbRawSamples
-- max size ?
local maxXY = math.max(self.maxX, self.maxY)
@@ -131,22 +108,23 @@ function DataSetSamplingPascal:__init(...)
print(self)
end
+
-- sampling mode
if self.samplingMode == 'equal' or self.samplingMode == 'random' then
self:parseAllMasks()
if self.samplingMode == 'random' then
- -- get the number of usable segments
- self.nbRandomSegments = 0
+ -- get the number of usable patches
+ self.nbRandomPatches = 0
for i,v in ipairs(self.tags) do
- if i ~= self.classToSkip and i ~= self.ScClassToSkip then
- self.nbRandomSegments = self.nbRandomSegments + v.size
+ if i ~= self.classToSkip then
+ self.nbRandomPatches = self.nbRandomPatches + v.size
end
end
-- create shuffle table
- self.randomLookup = torch.ByteTensor(self.nbRandomSegments)
+ self.randomLookup = torch.ByteTensor(self.nbRandomPatches)
local idx = 1
for i,v in ipairs(self.tags) do
- if i ~= self.classToSkip and i ~= self.ScClassToSkip and v.size > 0 then
+ if i ~= self.classToSkip and v.size > 0 then
self.randomLookup:narrow(1,idx,v.size):fill(i)
idx = idx + v.size
end
@@ -156,14 +134,14 @@ function DataSetSamplingPascal:__init(...)
error('ERROR <DataSetSamplingPascal> unknown sampling mode')
end
+
+
-- preload ?
if self.preloadSamples then
self:preload()
end
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:getsizes(folder,file)
local filepng = file:gsub('jpg$','png')
local filexml = file:gsub('jpg$','xml')
@@ -196,14 +174,10 @@ function DataSetSamplingPascal:getsizes(folder,file)
size={size_c, size_y, size_x}})
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:size()
return self.nbSamples
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:__tostring__()
local str = 'DataSetSamplingPascal:\n'
str = str .. ' + path : '..self.path..'\n'
@@ -211,6 +185,7 @@ function DataSetSamplingPascal:__tostring__()
str = str .. ' + cache files : [path]/'..self.cacheFile..'-[tags|samples]\n'
end
str = str .. ' + nb samples : '..self.nbRawSamples..'\n'
+ str = str .. ' + nb generated patches : '..self.nbSamples..'\n'
if self.infiniteSet then
str = str .. ' + infinite set (actual nb of samples >> set:size())\n'
end
@@ -222,15 +197,16 @@ function DataSetSamplingPascal:__tostring__()
str = str .. ' + imposed ratio of ' .. self.rawSampleSize.w .. 'x' .. self.rawSampleSize.h .. '\n'
end
end
+ str = str .. ' + patches size : ' .. self.patchSize .. 'x' .. self.patchSize .. '\n'
if self.classToSkip ~= 0 then
str = str .. ' + unused class : ' .. self.classNames[self.classToSkip] .. '\n'
end
- if self.ScClassToSkip ~= 0 then
- str = str .. ' + unused class : ' .. self.classNames[self.ScClassToSkip] .. '\n'
- end
-
str = str .. ' + sampling mode : ' .. self.samplingMode .. '\n'
- str = str .. ' + label type : ' .. self.labelType .. '\n'
+ if not self.labelGenerator then
+ str = str .. ' + label type : ' .. self.labelType .. '\n'
+ else
+ str = str .. ' + label type : generated by user function \n'
+ end
str = str .. ' + '..self.nbClasses..' categories : '
for i = 1,#self.classNames-1 do
str = str .. self.classNames[i] .. ' | '
@@ -239,51 +215,75 @@ function DataSetSamplingPascal:__tostring__()
return str
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:__index__(key)
-
-- generate sample + target at index 'key':
if type(key)=='number' then
-- select sample, according to samplingMode
-
+ local box_size = self.patchSize
local ctr_target, tag_idx
if self.samplingMode == 'random' then
-- get indexes from random table
ctr_target = self.randomLookup[math.random(1,self.nbRandomPatches)]
- tag_idx = math.floor(math.random(0,self.tags[ctr_target].size-1)/3)*3+1
+ tag_idx = math.floor(math.random(0,self.tags[ctr_target].size-1)/4)*4+1
elseif self.samplingMode == 'equal' then
-- equally sample each category:
ctr_target = ((key-1) % (self.nbClasses)) + 1
- while self.tags[ctr_target].size == 0 or ctr_target == self.classToSkip or ctr_target == self.ScClassToSkip do
+ while self.tags[ctr_target].size == 0 or ctr_target == self.classToSkip do
-- no sample in that class, replacing with random patch
ctr_target = math.floor(torch.uniform(1,self.nbClasses))
end
local nbSamplesPerClass = math.ceil(self.nbSamples / self.nbClasses)
if self.infiniteSet then
- tag_idx = math.random(1,self.tags[ctr_target].size/3)
+ tag_idx = math.random(1,self.tags[ctr_target].size/4)
else
tag_idx = math.floor((key-1)/self.nbClasses) + 1
end
- tag_idx = ((tag_idx-1) % (self.tags[ctr_target].size/3))*3 + 1
+ tag_idx = ((tag_idx-1) % (self.tags[ctr_target].size/4))*4 + 1
end
-- generate patch
+
self:loadSample(self.tags[ctr_target].data[tag_idx+2])
- local sample = self.currentSample
- local mask = self.currentMask
- self.ctr_segment_index = self.tags[ctr_target].data[tag_idx]
- self.ctr_gt_index = self.tags[ctr_target].data[tag_idx+1]
+ local full_sample = self.currentSample
+ local full_mask = self.currentMask
+ local ctr_x = self.tags[ctr_target].data[tag_idx]
+ local ctr_y = self.tags[ctr_target].data[tag_idx+1]
+ local box_x = math.floor(ctr_x - box_size/2) + 1
+ self.currentX = box_x/full_sample:size(3)
+ local box_y = math.floor(ctr_y - box_size/2) + 1
+ self.currentY = box_y/full_sample:size(2)
+ self.currentSegment = self.tags[ctr_target].data[tag_idx+3]
+
+ -- extract sample + mask:
+ local sample = full_sample:narrow(2,box_y,box_size):narrow(3,box_x,box_size)
+ local mask = full_mask:narrow(1,box_y,box_size):narrow(2,box_x,box_size)
+
+ -- finally, generate the target, either using an arbitrary user function,
+ -- or a built-in label type
+ if self.labelGenerator then
+ -- call user function to generate sample+label
+ local ret = self:labelGenerator(full_sample, full_mask, sample, mask,
+ ctr_target, ctr_x, ctr_y, box_x, box_y, box_size,self.currentSegment,self.realIndex)
+ return ret, true
+
+ elseif self.labelType == 'center' then
+ -- generate label vector for patch
+ local vector = torch.Tensor(self.nbClasses):fill(-1)
+ vector[ctr_target] = 1
+ return {sample, vector}, true
+
+ elseif self.labelType == 'pixelwise' then
+ -- generate pixelwise annotation
+ return {sample, mask}, true
- -- Ce serait bien de rajouter le vecteur overlap pour ne pas le recalculer
- return {sample,mask,self.ctr_segment_index, self.ctr_gt_index}, true
- end
+ else
+ return false
+ end
+ end
return rawget(self,key)
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:loadSample(index)
if self.preloadedDone then
@@ -374,8 +374,6 @@ function DataSetSamplingPascal:loadSample(index)
end
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:preload(saveFile)
-- if cache file exists, just retrieve images from it
if self.cacheFile
@@ -418,130 +416,70 @@ function DataSetSamplingPascal:preload(saveFile)
end
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:parseMask(existing_tags)
local tags
if not existing_tags then
tags = {}
local storage
for i = 1,self.nbClasses do
- storage = torch.ShortStorage(self.rawSampleMaxSize*self.rawSampleMaxSize*3)
+ storage = torch.ShortStorage(self.rawSampleMaxSize*self.rawSampleMaxSize*4)
tags[i] = {data=storage, size=0}
end
-
else
tags = existing_tags
-- make sure each tag list is large enough to hold the incoming data
for i = 1,self.nbClasses do
- --print('current size'..tags[i].size)
- if ((tags[i].size + (self.rawSampleMaxSize*self.rawSampleMaxSize*3)) >
+ if ((tags[i].size + (self.rawSampleMaxSize*self.rawSampleMaxSize*4)) >
tags[i].data:size()) then
- tags[i].data:resize(tags[i].size+(self.rawSampleMaxSize*self.rawSampleMaxSize*3),true)
+ tags[i].data:resize(tags[i].size+(self.rawSampleMaxSize*self.rawSampleMaxSize*4),true)
end
end
end
+ -- use filter
+ local filter = self.samplingFilter or {ratio=0, size=self.patchSize, step=4}
-- extract labels
local mask = self.currentMask
-
- -- we dont need much precison for what follows to we downsample images to get faster results
- smallmask = torch.Tensor(100,100)
- image.scale(mask, smallmask, 'simple')
-
- -- (1) Compute overlap score for each segment given the image
-
- local file = sys.concat(self.realIndex:gsub('Images','Segments'),'.mat')
- local mat_path = file:gsub('/.mat$','.mat')
- local loaded = mattorch.load(mat_path)
- loaded = loaded.top_masks:float()
- local segment1, segmenttmp
- nb_segments = self.nbSegments
- if self.nbSegments > loaded:size(1) then nb_segments = loaded:size(1) end
-
- for k=1,nb_segments do
-
- -- (a) load one segment mask
- segment1 = loaded[k]:t()
-
- -- (b) resize the segment mask
- segmenttmp = image.scale(segment1, 100,100)
-
- -- (c) compute overlap
- local overlap = imgraph.overlap(segmenttmp, smallmask,#classes_pascal)
-
- -- (2) If overlap score ok, add the segment index to tags
- for i = 1,self.nbClasses do
- if overlap[i]>0.2 and overlap[i]<0.99 then
- tags[i].data[tags[i].size+1] = k
- if overlap[i]>0.5 then
- tags[i].data[tags[i].size+2] = i+100 -- this tag corresponds to a real segmentation
- else
- tags[i].data[tags[i].size+2] = i
- end
- tags[i].data[tags[i].size+3] = self.currentIndex
- tags[i].size = tags[i].size+3
- -- print('insert '..k..'; '..tags[i].size )
- end
- end
-
-
- end
-
- -- (2) load the object ground truth image
-
- file = sys.concat(self.realIndex:gsub('Images','Objects'),'.png')
- local mask_path = file:gsub('/.png$','.png')
-
- maskobject = image.load(mask_path)
- maskobject= maskobject:mul(255)
- smallmaskobject = torch.Tensor(3,100,100)
- image.scale(maskobject, smallmaskobject, 'simple')
-
- maskobject = smallmaskobject:floor()
-
- -- extracting segments from ground truth
-
-i= 3
-still_run = 1
-while still_run~=0 and i <=21 do -- 21 nb max of different objects
- still_run = 0
- ii=1
- already_taged = 0
- while ii<=100 do
- jj=1
- while jj<=100 do
-if already_taged == 0 then
- if maskobject[1][ii][jj] == colorobject[i][1] and maskobject[2][ii][jj] == colorobject[i][2]
- and maskobject[3][ii][jj] == colorobject[i][3] then
- k = smallmask[ii][jj]
- tags[k].data[tags[k].size+1] = 0 --this tag corresponds to a ground truth object
- tags[k].data[tags[k].size+2] = i-- index of a ground truth object
- tags[k].data[tags[k].size+3] = self.currentIndex
- tags[k].size = tags[k].size+3
- already_taged = 1
- -- print('classe'..k..'objet'..i )
- end
-end
- if maskobject[1][ii][jj] == colorobject[i+1][1] and maskobject[2][ii][jj] == colorobject[i+1][2]
- and maskobject[3][ii][jj] == colorobject[i+1][3] then
- still_run = still_run+1
- if already_taged == 1 then
- ii = 100
- jj= 100
- end
- end
- jj=jj+1
- end
- ii=ii+1
- end
- i=i+1
-end
+ local x_start = math.ceil(self.patchSize/2)
+ local x_end = mask:size(2) - math.ceil(self.patchSize/2)
+ local y_start = math.ceil(self.patchSize/2)
+ local y_end = mask:size(1) - math.ceil(self.patchSize/2)
+
+
+ local file = sys.concat(self.realIndex:gsub('Images','Segments'),'.mat')
+ local mat_path = file:gsub('/.mat$','.mat')
+ local loaded = mattorch.load(mat_path)
+ loaded = loaded.top_masks:float()
+ local segment1, segmenttmp
+ nb_segments = self.nbSegments
+ if self.nbSegments > loaded:size(1) then nb_segments = loaded:size(1) end
+
+ -- (1) load a random segment
+ -- for i=1,self.nbPatchPerSample do
+
+ k = math.random(nb_segments)
+ segment1 = loaded[k]:t()
+
+ segmenttmp = image.scale(segment1, width, height)
+ -- segmenttmp = segmenttmp:narrow(2, x_start, x_end):narrow(1,y_start, y_end)
+
+ -- (2) mask the ground truth mask with the random segment.
+ segmenttmp:cmul(mask:add(-1)):add(1)
+
+ --print('extract labels segment '..k)
+ self.currentSegment = k
+ --print('self'..self.currentSegment)
+
+
+ mask.nn.DataSetSegmentSampling_extract(tags, segmenttmp,
+ x_start, x_end,
+ y_start, y_end, self.currentIndex, self.currentSegment,
+ filter.ratio, filter.size, filter.step)
+ -- end
+
return tags
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:parseAllMasks(saveFile)
-- if cache file exists, just retrieve tags from it
if self.cacheFile and paths.filep(paths.concat(self.path,self.cacheFile..'-tags')) then
@@ -557,7 +495,7 @@ function DataSetSamplingPascal:parseAllMasks(saveFile)
print('<DataSetSamplingPascal> parsing all masks to generate list of tags')
print('<DataSetSamplingPascal> WARNING: this operation could allocate up to '..
math.ceil(self.nbRawSamples*self.rawSampleMaxSize*self.rawSampleMaxSize*
- 3*2/1024/1024)..'MB')
+ 4*2/1024/1024)..'MB')
self.tags = nil
for i = 1,self.nbRawSamples do
xlua.progress(i,self.nbRawSamples)
@@ -565,9 +503,9 @@ function DataSetSamplingPascal:parseAllMasks(saveFile)
self.tags = self:parseMask(self.tags)
end
-- report
- print('<DataSetSamplingPascal> nb of segment extracted per category:')
+ print('<DataSetSamplingPascal> nb of patches extracted per category:')
for i = 1,self.nbClasses do
- print(' ' .. i .. ' - ' .. self.tags[i].size/3)
+ print(' ' .. i .. ' - ' .. self.tags[i].size / 4)
end
-- optional cache file
if saveFile then
@@ -583,8 +521,6 @@ function DataSetSamplingPascal:parseAllMasks(saveFile)
end
end
---$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
-
function DataSetSamplingPascal:display(...)
-- check args
local _, title, samples, zoom = xlua.unpack(
@@ -617,11 +553,3 @@ function DataSetSamplingPascal:display(...)
-- display
image.display{win=painter, image=allimgs, legend=title, zoom=0.5}
end
-
-
-
-
-
-
-
-