diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-27 23:59:43 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-27 23:59:43 +0400 |
commit | 4fe018c20f291e00b7bca2849c4cf0c9e1415903 (patch) | |
tree | 3e954a5329c27d6651bf0c551a0016860f60418f | |
parent | 1e180fefe5018accab7b17cb67658a42a5a37fbe (diff) |
Added a filter to sample patches in a smarter way.
-rw-r--r-- | DataSetLabelMe.lua | 7 | ||||
-rw-r--r-- | generic/DataSetLabelMe.c | 51 |
2 files changed, 40 insertions, 18 deletions
diff --git a/DataSetLabelMe.lua b/DataSetLabelMe.lua index 960d07f..fed7b70 100644 --- a/DataSetLabelMe.lua +++ b/DataSetLabelMe.lua @@ -30,6 +30,7 @@ function DataSetLabelMe:__init(...) {arg='nbPatchPerSample', type='number', help='number of patches to extract from each image', default=100}, {arg='patchSize', type='number', help='size of patches to extract from images', default=64}, {arg='samplingMode', type='string', help='patch sampling method: random | equal', default='random'}, + {arg='samplingFilter', type='table', help='a filter to sample patches: {ratio=,size=,step}'}, {arg='labelType', type='string', help='type of label returned: center | pixelwise', default='center'}, {arg='labelGenerator', type='function', help='a function to generate sample+target (bypasses labelType)'}, {arg='infiniteSet', type='boolean', help='if true, the set can be indexed to infinity, looping around samples', default=false}, @@ -358,12 +359,16 @@ function DataSetLabelMe:parseMask(existing_tags) end end end + -- use filter + local filter = self.samplingFilter or {ratio=0, size=self.patchSize, step=4} + -- extract labels local mask = self.currentMask local x_start = math.ceil(self.patchSize/2) local x_end = mask:size(2) - math.ceil(self.patchSize/2) local y_start = math.ceil(self.patchSize/2) local y_end = mask:size(1) - math.ceil(self.patchSize/2) - mask.nn.DataSetLabelMe_extract(tags, mask, x_start, x_end, y_start, y_end, self.currentIndex) + mask.nn.DataSetLabelMe_extract(tags, mask, x_start, x_end, y_start, y_end, self.currentIndex, + filter.ratio, filter.size, filter.step) return tags end diff --git a/generic/DataSetLabelMe.c b/generic/DataSetLabelMe.c index 0f74419..878c44b 100644 --- a/generic/DataSetLabelMe.c +++ b/generic/DataSetLabelMe.c @@ -12,33 +12,50 @@ static int nn_(DataSetLabelMe_extract)(lua_State *L) int y_start = lua_tonumber(L, 5); int y_end = lua_tonumber(L, 6); int idx = lua_tonumber(L, 7); + float filter_ratio = lua_tonumber(L, 8); + int filter_size = lua_tonumber(L, 9); + int filter_step = lua_tonumber(L, 10); + float ratio = 1; int x,y,label,tag,size; THShortStorage *data; for (x=x_start; x<=x_end; x++) { for (y=y_start; y<=y_end; y++) { - label = THTensor_(get2d)(mask, y-1, x-1); // label = mask[x][y] - lua_rawgeti(L, tags, label); // tag = tags[label] - tag = lua_gettop(L); - lua_pushstring(L, "size"); lua_rawget(L, tag); // size = tag.size - size = lua_tonumber(L,-1); lua_pop(L,1); - lua_pushstring(L, "size"); lua_pushnumber(L, size+3); lua_rawset(L, tag); // tag.size = size + 3 - lua_pushstring(L, "data"); lua_rawget(L, tag); // data = tag.data - data = luaT_checkudata(L, -1, torch_ShortStorage_id); lua_pop(L, 1); - data->data[size] = x; // data[size+1] = x - data->data[size+1] = y; // data[size+1] = y - data->data[size+2] = idx; // data[size+1] = idx - lua_pop(L, 1); + // label = mask[x][y] + label = THTensor_(get2d)(mask, y-1, x-1); + + // optional filter: insures that at least N% of local pixels belong to the same class + if (filter_ratio > 0) { + int kx,ky,count=0,good=0; + for (kx=MAX(1,x-filter_size/2); kx<=MIN(x_end,x+filter_size/2); kx+=filter_step) { + for (ky=MAX(1,y-filter_size/2); ky<=MIN(y_end,y+filter_size/2); ky+=filter_step) { + int other = THTensor_(get2d)(mask, ky-1, kx-1); + if (other == label) good++; + count++; + } + } + ratio = (float)good/(float)count; + } + + // if filter(s) satisfied, then append label + if (ratio >= filter_ratio) { + lua_rawgeti(L, tags, label); // tag = tags[label] + tag = lua_gettop(L); + lua_pushstring(L, "size"); lua_rawget(L, tag); // size = tag.size + size = lua_tonumber(L,-1); lua_pop(L,1); + lua_pushstring(L, "size"); lua_pushnumber(L, size+3); lua_rawset(L, tag); // tag.size = size + 3 + lua_pushstring(L, "data"); lua_rawget(L, tag); // data = tag.data + data = luaT_checkudata(L, -1, torch_ShortStorage_id); lua_pop(L, 1); + data->data[size] = x; // data[size+1] = x + data->data[size+1] = y; // data[size+1] = y + data->data[size+2] = idx; // data[size+1] = idx + lua_pop(L, 1); + } } } return 0; } -static int nn_(DataSetLabelMe_backward)(lua_State *L) -{ - -} - static const struct luaL_Reg nn_(DataSetLabelMe__) [] = { {"DataSetLabelMe_extract", nn_(DataSetLabelMe_extract)}, {NULL, NULL} |