Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    Soumith Chintala <soumith@gmail.com>  2016-08-18 17:25:17 +0300
committer GitHub <noreply@github.com>           2016-08-18 17:25:17 +0300
commit    8d07b574d01f013edd8ae0a7df067788301096a2 (patch)
tree      eb6919adf6164a75163bb9949a10ee7a223ddfae
parent    5cfd34497d0cc7e180ee370767a3ca684176a10c (diff)
parent    85e46603dfada0d45af6ba5e5c3ad95c536a04e6 (diff)
Merge pull request #934 from apaszke/spatial_conv
Accept both 2D and 4D weights in SpatialConvolutionMM
-rw-r--r--  SpatialConvolution.lua                    23
-rw-r--r--  lib/THNN/generic/SpatialConvolutionMM.c   47
2 files changed, 45 insertions(+), 25 deletions(-)
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 8324f95..01a08cd 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -89,25 +89,9 @@ local function makeContiguous(self, input, gradOutput)
return input, gradOutput
end
--- function to re-view the weight layout in a way that would make the MM ops happy
-local function viewWeight(self)
- self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
- if self.gradWeight and self.gradWeight:dim() > 0 then
- self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
- end
-end
-
-local function unviewWeight(self)
- self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
- if self.gradWeight and self.gradWeight:dim() > 0 then
- self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
- end
-end
-
function SpatialConvolution:updateOutput(input)
assert(input.THNN, torch.type(input)..'.THNN backend not imported')
backCompatibility(self)
- viewWeight(self)
input = makeContiguous(self, input)
input.THNN.SpatialConvolutionMM_updateOutput(
input:cdata(),
@@ -120,7 +104,6 @@ function SpatialConvolution:updateOutput(input)
self.dW, self.dH,
self.padW, self.padH
)
- unviewWeight(self)
return self.output
end
@@ -128,20 +111,18 @@ function SpatialConvolution:updateGradInput(input, gradOutput)
assert(input.THNN, torch.type(input)..'.THNN backend not imported')
if self.gradInput then
backCompatibility(self)
- viewWeight(self)
input, gradOutput = makeContiguous(self, input, gradOutput)
input.THNN.SpatialConvolutionMM_updateGradInput(
input:cdata(),
gradOutput:cdata(),
self.gradInput:cdata(),
- self.weight:cdata(),
+ self.weight:cdata(),
self.finput:cdata(),
self.fgradInput:cdata(),
self.kW, self.kH,
self.dW, self.dH,
self.padW, self.padH
)
- unviewWeight(self)
return self.gradInput
end
end
@@ -151,7 +132,6 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
scale = scale or 1
backCompatibility(self)
input, gradOutput = makeContiguous(self, input, gradOutput)
- viewWeight(self)
input.THNN.SpatialConvolutionMM_accGradParameters(
input:cdata(),
gradOutput:cdata(),
@@ -164,7 +144,6 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale)
self.padW, self.padH,
scale
)
- unviewWeight(self)
end
function SpatialConvolution:type(type,tensorCache)
diff --git a/lib/THNN/generic/SpatialConvolutionMM.c b/lib/THNN/generic/SpatialConvolutionMM.c
index e7460c8..1ca68a9 100644
--- a/lib/THNN/generic/SpatialConvolutionMM.c
+++ b/lib/THNN/generic/SpatialConvolutionMM.c
@@ -67,9 +67,12 @@ void THNN_(SpatialConvolutionMM_updateOutput)(
long outputWidth;
long outputHeight;
+ int freeWeight = 0;
+
THArgCheck( input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor expected");
THArgCheck(kW > 0 && kH > 0, 8, "kernel size should be greater than zero");
THArgCheck(dW > 0 && dH > 0, 10, "stride should be greater than zero");
+ THArgCheck(weight->nDimension == 2 || weight->nDimension == 4, 4, "weight tensor should be 2D or 4D");
if (input->nDimension == 4) {
dimf++;
@@ -88,8 +91,19 @@ void THNN_(SpatialConvolutionMM_updateOutput)(
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);
- if (nInputPlane*kW*kH != weight->size[1])
- THError("Wrong number of input channels! Input has %d channels, expected %d",nInputPlane,weight->size[1]/(kW*kH));
+
+ int expectedWeightSize = weight->nDimension == 2 ? nInputPlane*kW*kH : nInputPlane;
+ int weightInputPlanes = weight->nDimension == 2 ? weight->size[1]/(kW*kH) : weight->size[1];
+ if (expectedWeightSize != weight->size[1])
+ THError("Wrong number of input channels! Input has %d channels, expected %d",
+ nInputPlane, weightInputPlanes);
+
+ if (weight->nDimension == 4) {
+ long s1 = weight->size[0];
+ long s2 = weight->size[1] * weight->size[2] * weight->size[3];
+ weight = THTensor_(newWithStorage2d)(weight->storage, 0, s1, -1, s2, -1);
+ freeWeight = 1;
+ }
if(input->nDimension == 3)
{
@@ -126,6 +140,9 @@ void THNN_(SpatialConvolutionMM_updateOutput)(
THTensor_(free)(finput_t);
}
}
+
+ if (freeWeight)
+ THTensor_(free)(weight);
}
static void THNN_(SpatialConvolutionMM_updateGradInput_frame)(
@@ -167,17 +184,27 @@ void THNN_(SpatialConvolutionMM_updateGradInput)(
int padH)
{
long nOutputPlane = weight->size[0];
+ int freeWeight = 0;
THArgCheck( nOutputPlane == gradOutput->size[input->nDimension == 4 ? 1 : 0], 3, "Number of output features is not equal to nOutputPlane" );
THArgCheck(kW > 0 && kH > 0, 9, "kernel size should be greater than zero");
THArgCheck(dW > 0 && dH > 0, 11, "stride should be greater than zero");
+ THArgCheck(weight->nDimension == 2 || weight->nDimension == 4, 4, "weight tensor should be 2D or 4D");
THTensor_(resizeAs)(gradInput, input);
THTensor_(resizeAs)(fgradInput, finput);
// depending on the BLAS library, fgradInput (result tensor) might
// be left uninitialized on zero alpha, which might lead to weird behavior
// hence, to be safe, zero it
- THTensor_(zero)(fgradInput);
+ THTensor_(zero)(fgradInput);
+
+ if (weight->nDimension == 4) {
+ long s1 = weight->size[0];
+ long s2 = weight->size[1] * weight->size[2] * weight->size[3];
+ weight = THTensor_(newWithStorage2d)(weight->storage, 0, s1, -1, s2, -1);
+ freeWeight = 1;
+ }
+
THTensor_(transpose)(weight, weight, 0, 1);
if(input->nDimension == 3)
@@ -205,6 +232,9 @@ void THNN_(SpatialConvolutionMM_updateGradInput)(
}
THTensor_(transpose)(weight, weight, 0, 1);
+
+ if (freeWeight)
+ THTensor_(free)(weight);
}
static void THNN_(SpatialConvolutionMM_accGradParameters_frame)(
@@ -254,10 +284,19 @@ void THNN_(SpatialConvolutionMM_accGradParameters)(
int padH,
real scale)
{
+ int freeWeight = 0;
long nOutputPlane = gradWeight->size[0];
THArgCheck( nOutputPlane == gradOutput->size[input->nDimension == 4 ? 1 : 0], 3, "Number of output features is not equal to nOutputPlane" );
THArgCheck(kW > 0 && kH > 0, 8, "kernel size should be greater than zero");
THArgCheck(dW > 0 && dH > 0, 10, "stride should be greater than zero");
+ THArgCheck(gradWeight->nDimension == 2 || gradWeight->nDimension == 4, 4, "gradWeight tensor should be 2D or 4D");
+
+ if (gradWeight->nDimension == 4) {
+ long s1 = gradWeight->size[0];
+ long s2 = gradWeight->size[1] * gradWeight->size[2] * gradWeight->size[3];
+ gradWeight = THTensor_(newWithStorage2d)(gradWeight->storage, 0, s1, -1, s2, -1);
+ freeWeight = 1;
+ }
if(input->nDimension == 3)
{
@@ -279,6 +318,8 @@ void THNN_(SpatialConvolutionMM_accGradParameters)(
THTensor_(free)(finput_t);
}
}
+ if (freeWeight)
+ THTensor_(free)(gradWeight);
}
#endif