Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-27 19:59:07 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-27 19:59:07 +0400
commit07105fe36908f2782d2d2ea199ea96752e94c290 (patch)
tree30b69f7cfecc9d45eda2f58d056d3290fc6f2842
parent8cc903ad7f2c472a0047d6f65e12f1129a4db174 (diff)
debugging/clarifying sampling modes
-rw-r--r--DataSetLabelMe.lua7
-rw-r--r--StochasticTrainer.lua2
2 files changed, 7 insertions, 2 deletions
diff --git a/DataSetLabelMe.lua b/DataSetLabelMe.lua
index f5f2414..8d81448 100644
--- a/DataSetLabelMe.lua
+++ b/DataSetLabelMe.lua
@@ -191,8 +191,13 @@ function DataSetLabelMe:__index__(key)
ctr_target = math.floor(random.uniform(1,self.nbClasses))
end
local nbSamplesPerClass = math.ceil(self.nbSamples / self.nbClasses)
- tag_idx = math.floor((key*nbSamplesPerClass-1)/self.nbClasses) + 1
+ if self.infiniteSet then
+ tag_idx = math.random(1,self.tags[ctr_target].size/3)
+ else
+ tag_idx = math.floor((key-1)/self.nbClasses) + 1
+ end
tag_idx = ((tag_idx-1) % (self.tags[ctr_target].size/3))*3 + 1
+ print('key:', key, 'tag:', tag_idx, 'label', ctr_target)
end
-- generate patch
diff --git a/StochasticTrainer.lua b/StochasticTrainer.lua
index 7576636..8c71f3e 100644
--- a/StochasticTrainer.lua
+++ b/StochasticTrainer.lua
@@ -72,7 +72,7 @@ function StochasticTrainer:train(dataset)
else
shuffledIndices = lab.randperm(dataset:size())
end
-
+
while true do
print('<trainer> on training set:')
print("<trainer> stochastic gradient descent epoch # " .. self.epoch)