Welcome to mirror list, hosted at ThFree Co, Russian Federation.

BatchTrainer.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: a5b135da5a32cc09249bf73de0afbb05cab83ccf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
-- nn.BatchTrainer derives from nn.OnlineTrainer; `parent` is the parent
-- class table, used below to chain the constructor.
local BatchTrainer, parent = torch.class('nn.BatchTrainer', 'nn.OnlineTrainer')

-- Essentially similar to the OnlineTrainer, but only uses the parts of
-- that code which prepare the data and the tester. train() has been
-- replaced by nextBatch(), which moves the trainer one batch further
-- through the data. Once the first epoch is finished, the cached batches
-- are reused. Each call to optimizer:forward() in nextBatch() creates a
-- closure with the current batch as input.

--- Construct a BatchTrainer.
-- Accepts the same argument conventions as xlua.unpack_class (named table
-- or positional); recognised fields are listed in the definitions below.
-- The batch counters start unset so that the first call to next()
-- initializes them.
function BatchTrainer:__init(...)
   local packed = {...}
   -- NOTE(review): the packed table is forwarded to the parent as a single
   -- argument rather than re-expanded with `...`; confirm this matches the
   -- calling convention of nn.OnlineTrainer.__init.
   parent.__init(self, packed)

   local description = 'A modified version of the general-purpose online trainer class.\n'
      .. ' which only preps the input batch and calls optimizer to\n'
      .. ' create a closure\n'

   xlua.unpack_class(self, packed, 'BatchTrainer', description,
      {arg='trainset', type='nn.DataList', help='dataset from which to draw batches', req=true},
      {arg='module', type='nn.Module', help='a module to train', req=true},
      {arg='criterion', type='nn.Criterion', help='a criterion to estimate the error'},
      {arg='preprocessor', type='nn.Module', help='a preprocessor to prime the data before the module'},
      {arg='optimizer', type='nn.Optimization', help='an optimization method'},
      {arg='batchSize', type='number', help='[mini] batch size', default=1},
      {arg='maxEpoch', type='number', help='maximum number of epochs', default=50},
      {arg='dispProgress', type='boolean', help='display a progress bar during training/testing', default=true},
      {arg='save', type='string', help='path to save networks and log training'},
      {arg='timestamp', type='boolean', help='if true, appends a timestamp to each network saved', default=false}
   )

   -- counters: epoch starts at 1, batch/offset are left for next() to set
   self.trainOffset = nil
   self.batch = nil
   self.epoch = 1
end

--- Advance the batch/epoch counters by one batch.
-- First call: only initializes the counters (batch 1, offset 1).
-- Later calls: run hookTrainBatch on the batch being left, advance the
-- offset by batchSize, and when the offset runs past the end of the
-- trainset run hookTrainEpoch, log (if `save` is set) and wrap to the
-- next epoch. When the optimizer is parallelized, children are sent a
-- 'break' message on every batch except the first.
function BatchTrainer:next()
   if not self.batch or not self.trainOffset then
      -- initialize
      self.batch = 1
      self.trainOffset = 1
   else
      -- hook to run something on the current batch (e.g. a test on this
      -- batch) before switching to the next one
      if self.hookTrainBatch then
         self.hookTrainBatch(self)
      end

      -- simple batch increment
      self.batch = self.batch + 1
      self.trainOffset = self.trainOffset + self.batchSize

      -- test for new epoch
      if self.trainOffset > self.trainset:size() then
         -- hook to run on current epoch before switching to next
         if self.hookTrainEpoch then
            self.hookTrainEpoch(self)
         end

         if self.save then self:log() end

         self.trainOffset = 1
         self.epoch = self.epoch + 1
         self.batch = 1
      end

      -- on all but the first batch we need to reset the children.
      -- Use this trainer's own optimizer (the original read an undeclared
      -- global `optimizer`), and tolerate a missing `parallelize` field.
      local optimizer = self.optimizer
      if optimizer and optimizer.parallelize and optimizer.parallelize > 1 then
         parallel.children:send('break')
      end
   end

   -- disp progress
   if self.dispProgress then
      xlua.progress(self.trainOffset, self.trainset:size())
   end
end

-- This method plays the role of train() in the OnlineTrainer. It seems to
-- make more sense to call it nextBatch() here, as the training itself is
-- done outside of this code.

--- Advance to the next batch and hand it to the optimizer.
-- Samples for a batch are loaded (and optionally preprocessed) on the
-- first visit and cached in self.inputs / self.targets, keyed by batch
-- number; later epochs reuse the cached tables.
-- @return whatever self.optimizer:forward(inputs, targets) returns
function BatchTrainer:nextBatch()
   self:next()
   local first = self.trainOffset
   local last = math.min(first + self.batchSize - 1, self.trainset:size())

   print('<trainer> on training set:')
   print("<trainer> online epoch # " .. self.epoch
         .. ' batch # '..self.batch
         .. ' [batchSize = ' .. self.batchSize .. ']')

   -- per-batch caches, created lazily
   self.inputs = self.inputs or {}
   self.targets = self.targets or {}

   local inputs = self.inputs[self.batch]
   local targets = self.targets[self.batch]
   if not inputs then
      -- first visit of this batch: build and cache it
      inputs, targets = {}, {}
      self.inputs[self.batch] = inputs
      self.targets[self.batch] = targets

      for i = first, last do
         -- load new sample
         local sample = self.trainset[i]
         local input = sample[1]
         local target = sample[2]

         -- optional preprocess (no learning is done for that guy)
         if self.preprocessor then
            input = self.preprocessor:forward(input)
            -- nn modules return their internal output buffer, which the
            -- next forward() overwrites; copy it so cached samples do not
            -- all alias the same tensor
            if type(input) == 'userdata' and input.clone then
               input = input:clone()
            end
         end

         -- store input/target
         table.insert(inputs, input)
         table.insert(targets, target)
      end
   end

   -- set up closure batch.evaluate() for optimizer
   return self.optimizer:forward(inputs, targets)
end

--- Special test to just get results of the current batch.
-- Runs the module over every cached sample of the current batch, summing
-- the per-sample error into self.currentError. The error comes from the
-- criterion when one is set, otherwise from the second value returned by
-- module:forward(input, target). Calls hookTestSample after each sample.
-- Must be called after nextBatch() has prepared the current batch.
-- @return the accumulated error (also stored in self.currentError)
function BatchTrainer:testBatch()
   local criterion = self.criterion
   local module = self.module

   local inputs = self.inputs and self.inputs[self.batch]
   local targets = self.targets and self.targets[self.batch]
   if not inputs or not targets then
      error('<BatchTrainer> testBatch() called before nextBatch() prepared a batch', 2)
   end

   self.currentError = 0

   for i = 1, #inputs do
      local input = inputs[i]
      local target = targets[i]
      local sampleError
      if criterion then
         sampleError = criterion:forward(module:forward(input), target)
      else
         -- keep only the second return value (the error)
         sampleError = select(2, module:forward(input, target))
      end
      self.currentError = self.currentError + sampleError
      -- user hook
      if self.hookTestSample then
         self.hookTestSample(self, {input, target})
      end
   end

   return self.currentError
end