Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2017-05-03 21:03:58 +0300
committerSoumith Chintala <soumith@gmail.com>2017-05-09 21:55:42 +0300
commit8252eb7450c6e632be7a4f37f217995059834837 (patch)
tree2fd8617b73ad2d6e969496f98b11de61a4ba2b7e
parentb9ab0f7ad41aca03605d44d54fdda5c3b2beb8ac (diff)
Add a keepdim parameter for reduction functions over a single dimension.
By default, this parameter is False -- a backwards incompatible change, but one that follows numpy semantics, e.g. numpy.sum (numpy names the parameter "keepdims" since you can pass multiple dims to reduction functions). The old behavior seems desired for normalization type operations where the tensor will immediately be expanded out again, e.g.: probs.sum(1).expand_as(probs) which no longer works because the dimension to expand is missing. This can be fixed by simply passing True as "keepdim" argument to the reduction operation, e.g: probs.sum(1, keepdim=True).expand_as(probs)
-rw-r--r--  lib/THCUNN/generic/PReLU.cu        6
-rw-r--r--  lib/THCUNN/generic/SparseLinear.cu 2
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/lib/THCUNN/generic/PReLU.cu b/lib/THCUNN/generic/PReLU.cu
index 949e3d9..05271f8 100644
--- a/lib/THCUNN/generic/PReLU.cu
+++ b/lib/THCUNN/generic/PReLU.cu
@@ -135,7 +135,7 @@ void THNN_(PReLU_accGradParameters)(
if (ndim == 2)
{
- THCTensor_(sum)(state, gradWeightBuf, gradInput, 0);
+ THCTensor_(sum)(state, gradWeightBuf, gradInput, 0, 1);
THCTensor_(cadd)(state, gradWeight, gradWeight, scale, gradWeightBuf);
}
else
@@ -147,8 +147,8 @@ void THNN_(PReLU_accGradParameters)(
}
THCTensor_(resize3d)(state, buffer, input->size[0], nOutputPlane, size3);
THCTensor_(resize2d)(state, sumbuf, input->size[0], nOutputPlane);
- THCTensor_(sum)(state, sumbuf, buffer, 2);
- THCTensor_(sum)(state, gradWeightBuf, sumbuf, 0);
+ THCTensor_(sum)(state, sumbuf, buffer, 2, 1);
+ THCTensor_(sum)(state, gradWeightBuf, sumbuf, 0, 1);
THCTensor_(cadd)(state, gradWeight, gradWeight, scale, gradWeightBuf);
THCTensor_(free)(state, buffer);
}
diff --git a/lib/THCUNN/generic/SparseLinear.cu b/lib/THCUNN/generic/SparseLinear.cu
index 70c9f5b..07eda62 100644
--- a/lib/THCUNN/generic/SparseLinear.cu
+++ b/lib/THCUNN/generic/SparseLinear.cu
@@ -206,7 +206,7 @@ void THNN_(SparseLinear_accGradParameters)(
&one, THCTensor_(data)(state, gradWeight), inDim
);
- THCTensor_(sum)(state, buf, gradOutput, 0);
+ THCTensor_(sum)(state, buf, gradOutput, 0, 1);
THCTensor_(resize1d)(state, buf, outDim);
THCTensor_(cadd)(state, gradBias, gradBias, scale, buf);