diff options
author | Marco Scoffier <github@metm.org> | 2011-09-10 22:11:35 +0400 |
---|---|---|
committer | Marco Scoffier <github@metm.org> | 2011-09-10 22:11:35 +0400 |
commit | ba0b319f123b85665e4ae4c8ae45b7a63c2fe3c0 (patch) | |
tree | 8761a4b5eeaf8fe34e85f7b8884e9cd1c1c94cee | |
parent | fcfea5cfbb171841082accd6df137910ff29fbbb (diff) |
Updated DataSet to accept PNG images and to make automatic mask rescaling optional. Defaults are identical to before.
-rw-r--r-- | DataSetLabelMe.lua | 98 |
1 files changed, 61 insertions, 37 deletions
diff --git a/DataSetLabelMe.lua b/DataSetLabelMe.lua index 5666c96..629561a 100644 --- a/DataSetLabelMe.lua +++ b/DataSetLabelMe.lua @@ -27,6 +27,7 @@ function DataSetLabelMe:__init(...) {arg='nbRawSamples', type='number', help='number of images'}, {arg='rawSampleMaxSize', type='number', help='resize all images to fit in a MxM window'}, {arg='rawSampleSize', type='table', help='resize all images precisely: {w=,h=}}'}, + {arg='rawMaskRescale',type='boolean',help='does are the N classes spread between 0->255 in the PNG and need to be rescaled',default=true}, {arg='nbPatchPerSample', type='number', help='number of patches to extract from each image', default=100}, {arg='patchSize', type='number', help='size of patches to extract from images', default=64}, {arg='samplingMode', type='string', help='patch sampling method: random | equal', default='random'}, @@ -53,42 +54,21 @@ function DataSetLabelMe:__init(...) print('<DataSetLabelMe> loading LabelMe dataset from '..self.path) for folder in paths.files(paths.concat(self.path,path_images)) do if folder ~= '.' and folder ~= '..' then - for file in paths.files(paths.concat(self.path,path_images,folder)) do - if file ~= '.' and file ~= '..' 
then - local filepng = file:gsub('jpg$','png') - local filexml = file:gsub('jpg$','xml') - local imgf = paths.concat(self.path,path_images,folder,file) - local maskf = paths.concat(self.path,path_masks,folder,filepng) - local annotf = paths.concat(self.path,path_annotations,folder,filexml) - local size_c, size_y, size_x - if file:find('.jpg$') then - size_c, size_y, size_x = image.getJPGsize(imgf) - elseif file:find('.mat$') then - if not xrequire 'mattorch' then - xerror('<DataSetLabelMe> mattorch package required to handle MAT files') - end - local loaded = mattorch.load(imgf) - for _,matrix in pairs(loaded) do loaded = matrix; break end - size_c = loaded:size(1) - size_y = loaded:size(2) - size_x = loaded:size(3) - loaded = nil - collectgarbage() - else - xerror('images must either be JPG or MAT files', 'DataSetLabelMe') - end - table.insert(self.rawdata, {imgfile=imgf, - maskfile=maskf, - annotfile=annotf, - size={size_c, size_y, size_x}}) - end - end + -- allowing for less nesting in the data set preparation [MS] + if sys.filep(paths.concat(self.path,path_images,folder)) then + self:getsizes('./',folder) + else + -- loop though nested folders + for file in paths.files(paths.concat(self.path,path_images,folder)) do + if file ~= '.' and file ~= '..' then + self:getsizes(folder,file) + end + end + end end end - -- nb samples: user defined or max self.nbRawSamples = self.nbRawSamples or #self.rawdata - -- extract some info (max sizes) self.maxY = self.rawdata[1].size[2] self.maxX = self.rawdata[1].size[3] @@ -104,10 +84,15 @@ function DataSetLabelMe:__init(...) self.nbSamples = self.nbPatchPerSample * self.nbRawSamples -- max size ? 
+ local maxXY = math.max(self.maxX, self.maxY) if not self.rawSampleMaxSize then - self.rawSampleMaxSize = math.max(self.rawSampleSize.w,self.rawSampleSize.h) + if self.rawSampleSize then + self.rawSampleMaxSize = + math.max(self.rawSampleSize.w,self.rawSampleSize.h) + else + self.rawSampleMaxSize = maxXY + end end - local maxXY = math.max(self.maxX, self.maxY) if maxXY < self.rawSampleMaxSize then self.rawSampleMaxSize = maxXY end @@ -148,6 +133,37 @@ function DataSetLabelMe:__init(...) end end +function DataSetLabelMe:getsizes(folder,file) + local filepng = file:gsub('jpg$','png') + local filexml = file:gsub('jpg$','xml') + local imgf = paths.concat(self.path,path_images,folder,file) + local maskf = paths.concat(self.path,path_masks,folder,filepng) + local annotf = paths.concat(self.path,path_annotations,folder,filexml) + local size_c, size_y, size_x + if file:find('.jpg$') then + size_c, size_y, size_x = image.getJPGsize(imgf) + elseif file:find('.png$') then + size_c, size_y, size_x = image.getPNGsize(imgf) + elseif file:find('.mat$') then + if not xrequire 'mattorch' then + xerror('<DataSetLabelMe> mattorch package required to handle MAT files') + end + local loaded = mattorch.load(imgf) + for _,matrix in pairs(loaded) do loaded = matrix; break end + size_c = loaded:size(1) + size_y = loaded:size(2) + size_x = loaded:size(3) + loaded = nil + collectgarbage() + else + xerror('images must either be JPG, PNG or MAT files', 'DataSetLabelMe') + end + table.insert(self.rawdata, {imgfile=imgf, + maskfile=maskf, + annotfile=annotf, + size={size_c, size_y, size_x}}) +end + function DataSetLabelMe:size() return self.nbSamples end @@ -331,8 +347,13 @@ function DataSetLabelMe:loadSample(index) if self.currentMask:min() == 0 then self.currentMask:add(1) end - else + elseif self.rawMaskRescale then + -- stanford dataset style (png contains 0 and 255) self.currentMask:mul(self.nbClasses-1):add(0.5):floor():add(1) + else + -- PNG already stores values at the correct classes + 
-- only holds values from 0 to nclasses + self.currentMask:mul(255):add(1) end self.currentIndex = index end @@ -400,14 +421,17 @@ function DataSetLabelMe:parseMask(existing_tags) end end -- use filter - local filter = self.samplingFilter or {ratio=0, size=self.patchSize, step=4} + local filter = self.samplingFilter or + {ratio=0, size=self.patchSize, step=4} -- extract labels local mask = self.currentMask local x_start = math.ceil(self.patchSize/2) local x_end = mask:size(2) - math.ceil(self.patchSize/2) local y_start = math.ceil(self.patchSize/2) local y_end = mask:size(1) - math.ceil(self.patchSize/2) - mask.nn.DataSetLabelMe_extract(tags, mask, x_start, x_end, y_start, y_end, self.currentIndex, + mask.nn.DataSetLabelMe_extract(tags, mask, + x_start, x_end, + y_start, y_end, self.currentIndex, filter.ratio, filter.size, filter.step) return tags end |