diff options
author | Camille Couprie <ccouprie@cs.nyu.edu> | 2012-08-23 20:33:07 +0400 |
---|---|---|
committer | Camille Couprie <ccouprie@cs.nyu.edu> | 2012-08-23 20:33:07 +0400 |
commit | a259d971bf5d1156f47a522db24fe51c57db38a9 (patch) | |
tree | bd845c493c37e112ecf08ed10450eeba2f9761e5 /DataSetSamplingPascal.lua | |
parent | fa9574eda5b8f25992e2911d7435c3f2db132949 (diff) |
added a way to remove the background and unknown classes from sampling
Diffstat (limited to 'DataSetSamplingPascal.lua')
-rw-r--r-- | DataSetSamplingPascal.lua | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/DataSetSamplingPascal.lua b/DataSetSamplingPascal.lua index b9b98cc..021d2df 100644 --- a/DataSetSamplingPascal.lua +++ b/DataSetSamplingPascal.lua @@ -36,7 +36,8 @@ function DataSetSamplingPascal:__init(...) {arg='samplingMode', type='string', help='segment sampling method: random | equal', default='random'}, {arg='labelType', type='string', help='type of label returned: center | pixelwise', default='center'}, {arg='infiniteSet', type='boolean', help='if true, the set can be indexed to infinity, looping around samples', default=false}, - {arg='classToSkip', type='number', help='index of class to skip during sampling', default=0}, + {arg='classToSkip', type='number', help='index of class to skip during sampling', default=1}, + {arg='ScClassToSkip', type='number', help='index of class to skip during sampling', default=2}, {arg='preloadSamples', type='boolean', help='if true, all samples are preloaded in memory', default=false}, {arg='cacheFile', type='string', help='path to cache file (once cached, loading is much faster)'}, {arg='verbose', type='boolean', help='dumps information', default=false} @@ -112,7 +113,7 @@ function DataSetSamplingPascal:__init(...) -- get the number of usable segments self.nbRandomSegments = 0 for i,v in ipairs(self.tags) do - if i ~= self.classToSkip then + if i ~= self.classToSkip and i ~= self.ScClassToSkip then self.nbRandomSegments = self.nbRandomSegments + v.size end end @@ -120,7 +121,7 @@ function DataSetSamplingPascal:__init(...) self.randomLookup = torch.ByteTensor(self.nbRandomSegments) local idx = 1 for i,v in ipairs(self.tags) do - if i ~= self.classToSkip and v.size > 0 then + if i ~= self.classToSkip and i ~= self.ScClassToSkip and v.size > 0 then self.randomLookup:narrow(1,idx,v.size):fill(i) idx = idx + v.size end @@ -199,6 +200,10 @@ function DataSetSamplingPascal:__tostring__() if self.classToSkip ~= 0 then str = str .. ' + unused class : ' .. self.classNames[self.classToSkip] .. '\n' end + if self.ScClassToSkip ~= 0 then + str = str .. ' + unused class : ' .. self.classNames[self.ScClassToSkip] .. '\n' + end + str = str .. ' + sampling mode : ' .. self.samplingMode .. '\n' str = str .. ' + label type : ' .. self.labelType .. '\n' str = str .. ' + '..self.nbClasses..' categories : ' @@ -226,7 +231,7 @@ function DataSetSamplingPascal:__index__(key) elseif self.samplingMode == 'equal' then -- equally sample each category: ctr_target = ((key-1) % (self.nbClasses)) + 1 - while self.tags[ctr_target].size == 0 or ctr_target == self.classToSkip do + while self.tags[ctr_target].size == 0 or ctr_target == self.classToSkip or ctr_target == self.ScClassToSkip do -- no sample in that class, replacing with random patch ctr_target = math.floor(torch.uniform(1,self.nbClasses)) end |