Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDaya Khudia <dskhudia@fb.com>2019-07-16 03:34:48 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-07-16 03:47:36 +0300
commit1568107cd14e5b2b8abaafa212156d64778660dd (patch)
treee568011b37b2a407eb7a66a6b73f63ed3362f1f9 /src
parentf08039388abf2fc9908b5086a8c884202355e649 (diff)
Assume input weights to be in transposed format for convUnified (#104)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/104 For consistency, we always assume that weights to PackWeightsForConv are in format K R S C/G, which is same as G K/G R S C/G cc: Huihan Liu: Please note this change. Reviewed By: jianyuh Differential Revision: D16186932 fbshipit-source-id: 9ca2562f213d6b296ef8bd2eca1e5b6e98c436ec
Diffstat (limited to 'src')
-rw-r--r--src/PackWeightsForConv.cc6
-rw-r--r--src/RefImplementations.cc54
2 files changed, 43 insertions, 17 deletions
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index c811144..78379af 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -42,18 +42,18 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
W_dw_3D_packed_ = nullptr;
W_gconv_packed_ =
std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
- matrix_op_t::NoTranspose, conv_p, sdata, nullptr);
+ matrix_op_t::Transpose, conv_p, sdata, nullptr);
break;
}
case optimized_conv_t::im2col: {
int NDim = conv_p.OC / conv_p.G;
int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>(
- matrix_op_t::NoTranspose,
+ matrix_op_t::Transpose,
KDim,
NDim,
sdata,
- NDim,
+ KDim / conv_p.G,
nullptr,
conv_p.G,
blocking_params);
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index b4b0c2b..e3c0eac 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -181,8 +181,7 @@ void cblas_sgemm_ref(
int ldb,
float beta,
float* Cfp32,
- int ldc
- ) {
+ int ldc) {
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
float sum = 0;
@@ -204,7 +203,6 @@ void cblas_sgemm_ref(
}
}
-
void row_offsets_u8acc32_ref(
int M,
int K,
@@ -542,21 +540,49 @@ void transposeConvWeights(
const conv_param_t<SPATIAL_DIM>& conv_p,
const std::int8_t* src,
std::int8_t* dest) {
- assert(SPATIAL_DIM == 2 && "Only 2D supported currently");
- int R = conv_p.K[0];
- int S = conv_p.K[1];
int G = conv_p.G;
int IC_per_G = conv_p.IC / conv_p.G;
int OC_per_G = conv_p.OC / conv_p.G;
- // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
- for (int r = 0; r < R; ++r) {
- for (int s = 0; s < S; ++s) {
- for (int k = 0; k < OC_per_G; ++k) {
- for (int g = 0; g < G; ++g) {
- for (int c = 0; c < IC_per_G; ++c) {
- dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] =
- src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c];
+ assert(
+ (SPATIAL_DIM == 3 || SPATIAL_DIM == 2) &&
+ "Only 2D and 3D convolutions are supported");
+ if (SPATIAL_DIM == 2) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] =
+ src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c];
+ }
+ }
+ }
+ }
+ }
+ } else {
+ // Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format.
+ int T = conv_p.K[0];
+ int R = conv_p.K[1];
+ int S = conv_p.K[2];
+ for (int t = 0; t < T; ++t) {
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ dest
+ [((((g * T + t) * R + r) * S + s) * IC_per_G + c) *
+ OC_per_G +
+ k] =
+ src[((((g * OC_per_G + k) * T + t) * R + r) * S + s) *
+ IC_per_G +
+ c];
+ }
+ }
}
}
}