Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--SpatialNormalization.lua22
-rw-r--r--init.lua3
-rw-r--r--nnx-1.0-1.rockspec1
3 files changed, 22 insertions, 4 deletions
diff --git a/SpatialNormalization.lua b/SpatialNormalization.lua
index 349098b..e0e5c60 100644
--- a/SpatialNormalization.lua
+++ b/SpatialNormalization.lua
@@ -186,14 +186,14 @@ function SpatialNormalization:forward(input)
-- auto switch to 3-channel
self.input = input
if (input:nDimension() == 2) then
- self.input = torch.Tensor(1,input:size(1),input:size(2)):copy(input)
+ self.input = input:clone():resize(1,input:size(1),input:size(2))
end
-- recompute coef only if necessary
if (self.input:size(3) ~= self.coef:size(2)) or (self.input:size(2) ~= self.coef:size(1)) then
- local intVals = torch.Tensor(self.nfeatures,self.input:size(2),self.input:size(3)):fill(1)
+ local intVals = self.input.new(self.nfeatures,self.input:size(2),self.input:size(3)):fill(1)
self.coef = self.convo:forward(intVals)
- self.coef = torch.Tensor():resizeAs(self.coef):copy(self.coef)
+ self.coef = self.coef:clone()
end
-- compute mean
@@ -221,7 +221,7 @@ function SpatialNormalization:backward(input, gradOutput)
-- auto switch to 3-channel
self.input = input
if (input:nDimension() == 2) then
- self.input = torch.Tensor(1,input:size(1),input:size(2)):copy(input)
+ self.input = input:clone():resize(1,input:size(1),input:size(2))
end
self.gradInput:resizeAs(self.input):zero()
@@ -246,6 +246,20 @@ function SpatialNormalization:backward(input, gradOutput)
return self.gradInput
end
+function SpatialNormalization:type(type)
+ parent.type(self,type)
+ self.convo:type(type)
+ self.meanDiviseMod:type(type)
+ self.subtractMod:type(type)
+ self.squareMod:type(type)
+ self.convostd:type(type)
+ self.sqrtMod:type(type)
+ self.stdDiviseMod:type(type)
+ self.thresMod:type(type)
+ self.diviseMod:type(type)
+ return self
+end
+
function SpatialNormalization:write(file)
parent.write(self,file)
file:writeObject(self.kernel)
diff --git a/init.lua b/init.lua
index d34abba..4e694db 100644
--- a/init.lua
+++ b/init.lua
@@ -108,6 +108,9 @@ torch.include('nnx', 'Trainer.lua')
torch.include('nnx', 'OnlineTrainer.lua')
torch.include('nnx', 'BatchTrainer.lua')
+-- conversion helper:
+torch.include('nnx', 'Type.lua')
+
-- datasets:
torch.include('nnx', 'DataSet.lua')
torch.include('nnx', 'DataList.lua')
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index 120cafc..38c0af2 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -103,6 +103,7 @@ build = {
install_files(/lua/nnx Probe.lua)
install_files(/lua/nnx HardShrink.lua)
install_files(/lua/nnx Narrow.lua)
+ install_files(/lua/nnx Type.lua)
install_files(/lua/nnx Power.lua)
install_files(/lua/nnx Square.lua)
install_files(/lua/nnx Sqrt.lua)