github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>    2016-09-14 01:20:02 +0300
committer  GitHub <noreply@github.com>             2016-09-14 01:20:02 +0300
commit     fde63523f149b61fc7a95be70d58b1646380966e (patch)
tree       e382e936ba163c790887a716817c60fb6e95de9d
parent     781c4dff349cba781777a2795201659502e5a4eb (diff)
parent     5246df8f52b072399b561be1fb5d1337de5aae46 (diff)
Merge pull request #965 from apaszke/mm_weight
Accept 5D weights in VolumetricConvolutionMM
-rw-r--r--  VolumetricConvolution.lua                   | 21
-rw-r--r--  lib/THNN/generic/VolumetricConvolutionMM.c  | 42
2 files changed, 32 insertions, 31 deletions
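
In short: the MM ("matrix-multiply") path of VolumetricConvolution used to require a 2D weight of shape nOutputPlane x (nInputPlane * kT * kH * kW), so the Lua layer flattened its 5D weight before every THNN call and restored the shape afterwards. After this merge the C kernels accept the natural 5D layout directly. A minimal usage sketch from the Lua side, assuming illustrative sizes and an input that takes the branch calling VolumetricConvolutionMM in the diff below (the CPU path); none of these numbers come from the patch:

    require 'nn'

    -- Sizes are hypothetical, chosen only for illustration.
    local conv = nn.VolumetricConvolution(3, 8, 2, 3, 3)  -- nInputPlane, nOutputPlane, kT, kW, kH
    print(conv.weight:size())        -- 8 x 3 x 2 x 3 x 3: the module's native 5D weight layout

    local input = torch.Tensor(3, 4, 16, 16):uniform()    -- nInputPlane x depth x height x width
    local output = conv:forward(input)                     -- the 5D weight is now handed to THNN as-is
    print(output:size())             -- 8 x 3 x 14 x 14

Callers should see no behavioural change; the reshape simply moves from Lua into C.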
diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua
index e40c90a..80d9825 100644
--- a/VolumetricConvolution.lua
+++ b/VolumetricConvolution.lua
@@ -61,21 +61,6 @@ local function makeContiguous(self, input, gradOutput)
return input, gradOutput
end
--- function to re-view the weight layout in a way that would make the MM ops happy
-local function viewWeight(self)
- self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane * self.kT * self.kH * self.kW)
- if self.gradWeight and self.gradWeight:dim() > 0 then
- self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kT * self.kH * self.kW)
- end
-end
-
-local function unviewWeight(self)
- self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kT, self.kH, self.kW)
- if self.gradWeight and self.gradWeight:dim() > 0 then
- self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kT, self.kH, self.kW)
- end
-end
-
function VolumetricConvolution:updateOutput(input)
self.finput = self.finput or input.new()
self.fgradInput = self.fgradInput or input.new()
@@ -91,7 +76,6 @@ function VolumetricConvolution:updateOutput(input)
self.padT, self.padW, self.padH
)
else
- viewWeight(self)
input = makeContiguous(self, input)
input.THNN.VolumetricConvolutionMM_updateOutput(
input:cdata(),
@@ -103,7 +87,6 @@ function VolumetricConvolution:updateOutput(input)
self.dT, self.dW, self.dH,
self.padT, self.padW, self.padH
)
- unviewWeight(self)
end
return self.output
end
@@ -122,7 +105,6 @@ function VolumetricConvolution:updateGradInput(input, gradOutput)
return self.gradInput
else
if self.gradInput then
- viewWeight(self)
input, gradOutput = makeContiguous(self, input, gradOutput)
input.THNN.VolumetricConvolutionMM_updateGradInput(
input:cdata(),
@@ -135,7 +117,6 @@ function VolumetricConvolution:updateGradInput(input, gradOutput)
self.dT, self.dW, self.dH,
self.padT, self.padW, self.padH
)
- unviewWeight(self)
return self.gradInput
end
end
@@ -156,7 +137,6 @@ function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
)
else
input, gradOutput = makeContiguous(self, input, gradOutput)
- viewWeight(self)
input.THNN.VolumetricConvolutionMM_accGradParameters(
input:cdata(),
gradOutput:cdata(),
@@ -165,7 +145,6 @@ function VolumetricConvolution:accGradParameters(input, gradOutput, scale)
self.finput:cdata(),
scale or 1
)
- unviewWeight(self)
end
end
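
The deleted viewWeight/unviewWeight helpers never copied data: weight:view() only builds a differently-shaped tensor over the same storage, which is why the module could flatten before each THNN call and restore the 5D shape afterwards for free. The equivalent reshape now happens on the C side (see VolumetricConvolutionMM.c below), so the Lua round-trip is unnecessary. A small sketch of the view semantics, with made-up sizes:

    require 'torch'

    -- Illustrative sizes only; any contiguous 5D weight behaves the same way.
    local nOutputPlane, nInputPlane, kT, kH, kW = 8, 3, 2, 3, 3
    local weight = torch.Tensor(nOutputPlane, nInputPlane, kT, kH, kW):uniform()

    -- What viewWeight() used to do: a 2D view sharing the 5D tensor's storage.
    local weight2d = weight:view(nOutputPlane, nInputPlane * kT * kH * kW)
    assert(weight2d:size(2) == 54)       -- 3 * 2 * 3 * 3

    -- Writes through the view are visible in the 5D tensor, so "unviewing"
    -- afterwards never had to copy anything back.
    weight2d[1][1] = 42
    assert(weight[1][1][1][1][1] == 42)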
diff --git a/lib/THNN/generic/VolumetricConvolutionMM.c b/lib/THNN/generic/VolumetricConvolutionMM.c
index 8fef1cf..ef49a30 100644
--- a/lib/THNN/generic/VolumetricConvolutionMM.c
+++ b/lib/THNN/generic/VolumetricConvolutionMM.c
@@ -2,6 +2,20 @@
#define TH_GENERIC_FILE "generic/VolumetricConvolutionMM.c"
#else
+static int THNN_(view_weight)(THTensor **_weight)
+{
+ THTensor *weight = *_weight;
+ THArgCheck(weight->nDimension == 2 || weight->nDimension == 5, 4,
+ "weight tensor should be 2D or 5D - got %dD", weight->nDimension);
+ if (weight->nDimension == 5) {
+ long s1 = weight->size[0];
+ long s2 = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
+ *_weight = THTensor_(newWithStorage2d)(weight->storage, 0, s1, -1, s2, -1);
+ return 1;
+ }
+ return 0;
+}
+
/* note: due to write issues, this one cannot be parallelized as well as unfolded_copy */
static void THNN_(unfolded_acc_vol)(
THTensor *finput,
@@ -243,6 +257,7 @@ void THNN_(VolumetricConvolutionMM_updateOutput)(
int dimt = 1;
int dimh = 2;
int dimw = 3;
+ int freeWeight = 0;
long nInputPlane;
long inputDepth;
@@ -283,6 +298,8 @@ void THNN_(VolumetricConvolutionMM_updateOutput)(
);
}
+ freeWeight = THNN_(view_weight)(&weight);
+
if (input->nDimension == 4)
{
THTensor_(resize2d)(finput, kT*kW*kH*nInputPlane, outputDepth*outputHeight*outputWidth);
@@ -326,6 +343,9 @@ void THNN_(VolumetricConvolutionMM_updateOutput)(
THTensor_(free)(finput_t);
}
}
+
+ if (freeWeight)
+ THTensor_(free)(weight);
}
static void THNN_(VolumetricConvolutionMM_updateGradInput_frame)(
@@ -382,23 +402,20 @@ void THNN_(VolumetricConvolutionMM_updateGradInput)(
int pW,
int pH)
{
- // number of input/output planes and kernel size is indirectly defined by the weight tensor
- THArgCheck(weight->nDimension == 2, 4,
- "2D weight tensor is expected (nOutputPlane x (nInputPlane * kT * kH * kW))"
- );
-
int nOutputPlane = (int)weight->size[0];
THArgCheck(nOutputPlane == gradOutput->size[input->nDimension == 5 ? 1 : 0], 1,
"Number of output features is not equal to nOutputPlane"
);
+ int freeWeight = THNN_(view_weight)(&weight);
+
THTensor_(resizeAs)(gradInput, input);
THTensor_(resizeAs)(fgradInput, finput);
// depending on the BLAS library, fgradInput (result tensor) might
// be left uninitialized on zero alpha, which might lead to weird behavior
// hence, to be safe, zero it
- THTensor_(zero)(fgradInput);
+ THTensor_(zero)(fgradInput);
THTensor_(transpose)(weight, weight, 0, 1);
if (input->nDimension == 4)
@@ -436,6 +453,9 @@ void THNN_(VolumetricConvolutionMM_updateGradInput)(
}
THTensor_(transpose)(weight, weight, 0, 1);
+
+ if (freeWeight)
+ THTensor_(free)(weight);
}
static void THNN_(VolumetricConvolutionMM_accGradParameters_frame)(
@@ -479,10 +499,7 @@ void THNN_(VolumetricConvolutionMM_accGradParameters)(
THTensor *finput,
real scale)
{
- THArgCheck(gradWeight->nDimension == 2, 4,
- "2D gradWeight tensor is expected (nOutputPlane x (nInputPlane * kT * kH * kW))"
- );
-
+ int freeWeight;
int nOutputPlane = (int)gradWeight->size[0];
THArgCheck(gradBias->nDimension == 1 && gradBias->size[0] == nOutputPlane, 5,
@@ -493,6 +510,8 @@ void THNN_(VolumetricConvolutionMM_accGradParameters)(
"Number of output features is not equal to nOutputPlane"
);
+ freeWeight = THNN_(view_weight)(&gradWeight);
+
if (input->nDimension == 4) // non-batch mode
{
THNN_(VolumetricConvolutionMM_accGradParameters_frame)(gradOutput, gradWeight, gradBias, finput, scale);
@@ -513,6 +532,9 @@ void THNN_(VolumetricConvolutionMM_accGradParameters)(
THTensor_(free)(finput_t);
}
}
+
+ if (freeWeight)
+ THTensor_(free)(gradWeight);
}
#endif
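
The reshape now lives in the new THNN_(view_weight) helper: a 5D weight (or gradWeight) is wrapped in a temporary 2D tensor header over the same storage, and the returned flag tells the caller to free that header afterwards (the freeWeight bookkeeping added above), while a 2D weight in the old flattened layout passes through untouched, so existing callers keep working. A rough Lua analogue of that logic, purely for illustration (viewWeight2d is a made-up name, and Lua's garbage collector stands in for the explicit THTensor_(free)):

    require 'torch'

    -- Sketch of THNN_(view_weight): accept a 2D or 5D weight, return a 2D view
    -- over the same storage plus a flag saying whether a view was created.
    local function viewWeight2d(weight)
       assert(weight:dim() == 2 or weight:dim() == 5,
              'weight tensor should be 2D or 5D - got ' .. weight:dim() .. 'D')
       if weight:dim() == 5 then
          local s1 = weight:size(1)
          local s2 = weight:size(2) * weight:size(3) * weight:size(4) * weight:size(5)
          return weight:view(s1, s2), true   -- new 2D view, no data copied
       end
       return weight, false                  -- already 2D: pass through unchanged
    end

    -- Both layouts describe the same parameters and are now accepted:
    local w5 = torch.Tensor(8, 3, 2, 3, 3):uniform()
    local w2 = w5:view(8, 3 * 2 * 3 * 3)
    local _, made5 = viewWeight2d(w5)
    local _, made2 = viewWeight2d(w2)
    assert(made5 == true and made2 == false)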