diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-27 19:59:07 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-27 19:59:07 +0400 |
commit | 07105fe36908f2782d2d2ea199ea96752e94c290 (patch) | |
tree | 30b69f7cfecc9d45eda2f58d056d3290fc6f2842 | |
parent | 8cc903ad7f2c472a0047d6f65e12f1129a4db174 (diff) |
debugging/clarifying sampling modes
-rw-r--r-- | DataSetLabelMe.lua | 7 | ||||
-rw-r--r-- | StochasticTrainer.lua | 2 |
2 files changed, 7 insertions, 2 deletions
diff --git a/DataSetLabelMe.lua b/DataSetLabelMe.lua index f5f2414..8d81448 100644 --- a/DataSetLabelMe.lua +++ b/DataSetLabelMe.lua @@ -191,8 +191,13 @@ function DataSetLabelMe:__index__(key) ctr_target = math.floor(random.uniform(1,self.nbClasses)) end local nbSamplesPerClass = math.ceil(self.nbSamples / self.nbClasses) - tag_idx = math.floor((key*nbSamplesPerClass-1)/self.nbClasses) + 1 + if self.infiniteSet then + tag_idx = math.random(1,self.tags[ctr_target].size/3) + else + tag_idx = math.floor((key-1)/self.nbClasses) + 1 + end tag_idx = ((tag_idx-1) % (self.tags[ctr_target].size/3))*3 + 1 + print('key:', key, 'tag:', tag_idx, 'label', ctr_target) end -- generate patch diff --git a/StochasticTrainer.lua b/StochasticTrainer.lua index 7576636..8c71f3e 100644 --- a/StochasticTrainer.lua +++ b/StochasticTrainer.lua @@ -72,7 +72,7 @@ function StochasticTrainer:train(dataset) else shuffledIndices = lab.randperm(dataset:size()) end - + while true do print('<trainer> on training set:') print("<trainer> stochastic gradient descent epoch # " .. self.epoch) |