Welcome to mirror list, hosted at ThFree Co, Russian Federation.

PackWeightsForConv.cc « src - github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: c81114494da7d5a6dff48e941444aac61e275bab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include <memory>
#include "fbgemm/Fbgemm.h"

namespace fbgemm {

/**
 * @brief Packs convolution weights into the layout required by whichever
 *        execution path ConvFastPath selects for the given parameters.
 *
 * Exactly one of the four packed-weight members is populated; the other
 * three are explicitly reset to nullptr:
 *   - depthwise: W_dw_3D_packed_ when SPATIAL_DIM == 3, otherwise
 *     W_dw_2D_packed_
 *   - groupwise: W_gconv_packed_
 *   - im2col:    W_im2col_packed_
 *
 * @param conv_p          Convolution parameters used to select the path and
 *                        to size the packed matrix.
 * @param sdata           Pointer to the unpacked source weights.
 * @param blocking_params Blocking factors; only forwarded to PackBMatrix on
 *                        the im2col path.
 */
template <int SPATIAL_DIM, typename T, typename accT>
PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const T* sdata,
    const BlockingFactors* blocking_params) {
  static_assert(
      SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
      "Only 2D and 3D convolutions are supported");
  // Note: The following logic should *exactly* match with what we have in
  // FbgemmConv.cc
  switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
    case optimized_conv_t::depthwise: {
      // Depthwise weights use a dedicated packed format that differs between
      // the 2D and 3D cases.
      if (SPATIAL_DIM == 3) {
        W_im2col_packed_ = nullptr;
        W_dw_2D_packed_ = nullptr;
        W_dw_3D_packed_ =
            std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
        W_gconv_packed_ = nullptr;
      } else {
        W_im2col_packed_ = nullptr;
        W_dw_2D_packed_ =
            std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
        W_dw_3D_packed_ = nullptr;
        W_gconv_packed_ = nullptr;
      }
      break;
    }
    case optimized_conv_t::groupwise: {
      W_im2col_packed_ = nullptr;
      W_dw_2D_packed_ = nullptr;
      W_dw_3D_packed_ = nullptr;
      W_gconv_packed_ =
          std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
              matrix_op_t::NoTranspose, conv_p, sdata, nullptr);
      break;
    }
    case optimized_conv_t::im2col: {
      // The packed B matrix is KDim x NDim. KDim has to span *every* spatial
      // kernel extent: using only K[0] * K[1] drops K[2] for 3D convolutions
      // and under-sizes the matrix. conv_p.K is expected to hold one extent
      // per spatial dimension.
      int kernel_prod = 1;
      for (int d = 0; d < SPATIAL_DIM; ++d) {
        kernel_prod *= conv_p.K[d];
      }
      int NDim = conv_p.OC / conv_p.G;
      int KDim = kernel_prod * conv_p.IC;
      W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>(
          matrix_op_t::NoTranspose,
          KDim,
          NDim,
          sdata,
          NDim,
          nullptr,
          conv_p.G,
          blocking_params);
      W_dw_2D_packed_ = nullptr;
      W_dw_3D_packed_ = nullptr;
      W_gconv_packed_ = nullptr;
      break;
    }
  } // switch
}

// Explicit instantiations of PackWeightsForConv for SPATIAL_DIM 2 and 3,
// both with T = int8_t and accT = int32_t.
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;

} // namespace fbgemm