Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorAart Bik <ajcbik@google.com>2022-11-01 20:34:13 +0300
committerAart Bik <ajcbik@google.com>2022-11-01 21:13:20 +0300
commit10db57b7ea4ef756b3ab2269263bffc2dfff9310 (patch)
treefc5781c53ad4ee4b3e5e3d0f8c461cf90684a596 /mlir
parent5baa4b8e1164b3635ef9220104159988f4ee836a (diff)
[mlir][sparse] replace magic constant with symbol
Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D137177
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp9
1 files changed, 4 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 85f4c4e073ad..944139f38626 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -31,9 +31,7 @@ using namespace mlir::sparse_tensor;
namespace {
-// TODO: start using these when insertions are implemented
-// static constexpr uint64_t DimSizesIdx = 0;
-// static constexpr uint64_t DimCursorIdx = 1;
+static constexpr uint64_t DimSizesIdx = 0;
static constexpr uint64_t MemSizesIdx = 2;
static constexpr uint64_t FieldsIdx = 3;
@@ -88,11 +86,12 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
if (!ShapedType::isDynamic(shape[dim]))
return constantIndex(rewriter, loc, shape[dim]);
- // Any other query can consult the dimSizes array at field 0 using,
+ // Any other query can consult the dimSizes array at field DimSizesIdx,
// accounting for the reordering applied to the sparse storage.
auto tuple = getTuple(adaptedValue);
Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
- return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx)
+ return rewriter
+ .create<memref::LoadOp>(loc, tuple.getInputs()[DimSizesIdx], idx)
.getResult();
}