Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'Optimization.lua')
-rw-r--r--Optimization.lua49
1 files changed, 2 insertions, 47 deletions
diff --git a/Optimization.lua b/Optimization.lua
index f18c635..daf0a8d 100644
--- a/Optimization.lua
+++ b/Optimization.lua
@@ -1,56 +1,11 @@
local Optimization = torch.class('nn.Optimization')
function Optimization:__init()
+ self.output = 0
end
function Optimization:forward(inputs, targets)
- self:flatten(parameters, gradParameters)
self.output = 0
- self:unflatten(parameters, gradParameters)
+ print('<Optimization> WARNING: this is a virtual function, please overload !')
return self.output
end
-
-function Optimization:flatten(parameters, gradParameters)
- if type(parameters) == 'table' then
- -- create flat parameters
- self.parameters = self.parameters or torch.Tensor()
- self.gradParameters = self.gradParameters or torch.Tensor()
- -- assuming that the parameters won't change their size,
- -- we compute offsets once
- if not self.offsets then
- self.nParameters = 0
- self.offsets = {}
- for _,param in ipairs(parameters) do
- table.insert(self.offsets, self.nParameters+1)
- self.nParameters = self.nParameters + param:nElement()
- end
- self.parameters:resize(self.nParameters)
- self.gradParameters:resize(self.nParameters)
- end
- -- copy all params in flat array
- for i = 1,#parameters do
- local nElement = parameters[i]:nElement()
- self.parameters:narrow(1,self.offsets[i],nElement):copy(parameters[i])
- self.gradParameters:narrow(1,self.offsets[i],nElement):copy(gradParameters[i])
- end
- else
- self.parameters = parameters
- self.gradParameters = gradParameters
- end
-end
-
-function Optimization:unflatten(parameters, gradParameters)
- if type(parameters) == 'table' then
- -- copy all params into unflat arrays
- local offset = 1
- for i = 1,#parameters do
- local nElement = parameters[i]:nElement()
- parameters[i]:copy(self.parameters:narrow(1,offset,nElement))
- gradParameters[i]:copy(self.gradParameters:narrow(1,offset,nElement))
- offset = offset + nElement
- end
- else
- parameters = self.parameters
- gradParameters = self.gradParameters
- end
-end