diff options
author | Marco Scoffier <github@metm.org> | 2011-09-15 11:03:55 +0400 |
---|---|---|
committer | Marco Scoffier <github@metm.org> | 2011-09-15 11:03:55 +0400 |
commit | ab6e8c5fe0cec06682d6bc0bf7ae3c518e934d7f (patch) | |
tree | ce8b07a81acd93547abc32cba36365a5a8517d41 | |
parent | 28d484e2b0be59aceb3addea2b23b706523100c8 (diff) |
changes to run notMNIST with padding
-rw-r--r-- | DataList.lua | 20 | ||||
-rw-r--r-- | DataSet.lua | 20 |
2 files changed, 29 insertions, 11 deletions
diff --git a/DataList.lua b/DataList.lua index 43617ed..287d7dd 100644 --- a/DataList.lua +++ b/DataList.lua @@ -12,6 +12,7 @@ function DataList:__init() self.datasets = {} self.nbClass = 0 self.ClassName = {} + self.ClassRev = {} self.nbSamples = 0 self.targetIsProbability = false self.spatialTarget = false @@ -67,12 +68,17 @@ end function DataList:appendDataSet(dataSet,className) table.insert(self.datasets,dataSet) - if self.nbSamples == 0 then - self.nbSamples = dataSet:size() - else - self.nbSamples = math.floor(math.max(self.nbSamples/self.nbClass,dataSet:size())) + -- if self.nbSamples == 0 then + -- self.nbSamples = dataSet:size() + -- else + -- print(self.nbSamples/self.nbClass,dataSet:size()) + -- self.nbSamples = math.floor(math.max(self.nbSamples/self.nbClass, + -- dataSet:size())) + -- end + if not self.ClassRev[className] then + self.ClassRev[className] = true + self.nbClass = self.nbClass + 1 + table.insert(self.ClassName,self.nbClass,className) end - self.nbClass = self.nbClass + 1 - self.nbSamples = self.nbSamples * self.nbClass - table.insert(self.ClassName,self.nbClass,className) + self.nbSamples = self.nbSamples + dataSet:size() end diff --git a/DataSet.lua b/DataSet.lua index 36cb7c9..4bb618c 100644 --- a/DataSet.lua +++ b/DataSet.lua @@ -30,7 +30,8 @@ end function lDataSet:load(...) -- parse args - local args, dataSetFolder, nbSamplesRequired, cacheFile, channels, sampleSize + local args, dataSetFolder, nbSamplesRequired, cacheFile, channels, + sampleSize,padding = xlua.unpack( {...}, 'DataSet.load', nil, @@ -38,7 +39,8 @@ function lDataSet:load(...) {arg='nbSamplesRequired', type='number', help='number of patches to load', default='all'}, {arg='cacheFile', type='string', help='path to file to cache files'}, {arg='channels', type='number', help='nb of channels', default=1}, - {arg='sampleSize', type='table', help='resize all sample: {w,h}'} + {arg='sampleSize', type='table', help='resize all sample: {c,w,h}'}, + {arg='padding', type='boolean', help='center sample in w,h dont rescale'} ) self.cacheFileName = cacheFile or self.cacheFileName @@ -115,7 +117,7 @@ end function lDataSet:append(...) -- parse args local args, dataSetFolder, channels, nbSamplesRequired, useLabelPiped, - useDirAsLabel, nbLabels, sampleSize + useDirAsLabel, nbLabels, sampleSize, padding = xlua.unpack( {...}, 'DataSet:append', 'append a folder to the dataset object', @@ -125,7 +127,8 @@ function lDataSet:append(...) {arg='useLabelPiped', type='boolean', help='flag to use the filename as output value',default=false}, {arg='useDirAsLabel', type='boolean', help='flag to use the directory as label',default=false}, {arg='nbLabels', type='number', help='how many classes (goes with useDirAsLabel)', default=1}, - {arg='sampleSize', type='table', help='resize all sample: {w,h}'} + {arg='sampleSize', type='table', help='resize all sample: {c,w,h}'}, + {arg='padding',type='boolean',help='do we padd all the inputs in w,h'} ) -- parse args local files = sys.dir(dataSetFolder) @@ -180,6 +183,15 @@ function lDataSet:append(...) -- rescale ? if sampleSize then inputs = torch.Tensor(channels, sampleSize[2], sampleSize[3]) + if padding then + offw = math.floor((sampleSize[2] - input[2])*0.5) + offh = math.floor((sampleSize[3] - input[3])*0.5) + if offw >= 0 and offh >= 0 then + inputs:narrow(2,offw,input[2]):narrow(3,offh,input[3]):copy(input) + else + print('reverse crop not implemented w,h must be larger than all data points') + end + end image.scale(input, inputs, 'bilinear') else inputs = input |