//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file declares a byte-code and interpreter for pattern rewrites in MLIR. // The byte-code is constructed from the PDL Interpreter dialect. // //===----------------------------------------------------------------------===// #ifndef MLIR_REWRITE_BYTECODE_H_ #define MLIR_REWRITE_BYTECODE_H_ #include "mlir/IR/PatternMatch.h" namespace mlir { namespace pdl_interp { class RecordMatchOp; } // namespace pdl_interp namespace detail { class PDLByteCode; /// Use generic bytecode types. ByteCodeField refers to the actual bytecode /// entries. ByteCodeAddr refers to size of indices into the bytecode. using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; using OwningOpRange = llvm::OwningArrayRef; //===----------------------------------------------------------------------===// // PDLByteCodePattern //===----------------------------------------------------------------------===// /// All of the data pertaining to a specific pattern within the bytecode. class PDLByteCodePattern : public Pattern { public: static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, ByteCodeAddr rewriterAddr); /// Return the bytecode address of the rewriter for this pattern. ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } private: template PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs) : Pattern(std::forward(patternArgs)...), rewriterAddr(rewriterAddr) {} /// The address of the rewriter for this pattern. ByteCodeAddr rewriterAddr; }; //===----------------------------------------------------------------------===// // PDLByteCodeMutableState //===----------------------------------------------------------------------===// /// This class contains the mutable state of a bytecode instance. This allows /// for a bytecode instance to be cached and reused across various different /// threads/drivers. class PDLByteCodeMutableState { public: /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds /// to the position of the pattern within the range returned by /// `PDLByteCode::getPatterns`. void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); /// Cleanup any allocated state after a match/rewrite has been completed. This /// method should be called irregardless of whether the match+rewrite was a /// success or not. void cleanupAfterMatchAndRewrite(); private: /// Allow access to data fields. friend class PDLByteCode; /// The mutable block of memory used during the matching and rewriting phases /// of the bytecode. std::vector memory; /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of operations. These are always stored by /// owning references, because at no point in the execution of the byte code /// we get an indexed range (view) of operations. std::vector opRangeMemory; /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of types. std::vector typeRangeMemory; /// A set of type ranges that have been allocated by the byte code interpreter /// to provide a guaranteed lifetime. std::vector> allocatedTypeRangeMemory; /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of values. std::vector valueRangeMemory; /// A set of value ranges that have been allocated by the byte code /// interpreter to provide a guaranteed lifetime. std::vector> allocatedValueRangeMemory; /// The current index of ranges being iterated over for each level of nesting. /// These are always maintained at 0 for the loops that are not active, so we /// do not need to have a separate initialization phase for each loop. std::vector loopIndex; /// The up-to-date benefits of the patterns held by the bytecode. The order /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. std::vector currentPatternBenefits; }; //===----------------------------------------------------------------------===// // PDLByteCode //===----------------------------------------------------------------------===// /// The bytecode class is also the interpreter. Contains the bytecode itself, /// the static info, addresses of the rewriter functions, the interpreter /// memory buffer, and the execution context. class PDLByteCode { public: /// Each successful match returns a MatchResult, which contains information /// necessary to execute the rewriter and indicates the originating pattern. struct MatchResult { MatchResult(Location loc, const PDLByteCodePattern &pattern, PatternBenefit benefit) : location(loc), pattern(&pattern), benefit(benefit) {} MatchResult(const MatchResult &) = delete; MatchResult &operator=(const MatchResult &) = delete; MatchResult(MatchResult &&other) = default; MatchResult &operator=(MatchResult &&) = default; /// The location of operations to be replaced. Location location; /// Memory values defined in the matcher that are passed to the rewriter. SmallVector values; /// Memory used for the range input values. SmallVector typeRangeValues; SmallVector valueRangeValues; /// The originating pattern that was matched. This is always non-null, but /// represented with a pointer to allow for assignment. const PDLByteCodePattern *pattern; /// The current benefit of the pattern that was matched. PatternBenefit benefit; }; /// Create a ByteCode instance from the given module containing operations in /// the PDL interpreter dialect. PDLByteCode(ModuleOp module, llvm::StringMap constraintFns, llvm::StringMap rewriteFns); /// Return the patterns held by the bytecode. ArrayRef getPatterns() const { return patterns; } /// Initialize the given state such that it can be used to execute the current /// bytecode. void initializeMutableState(PDLByteCodeMutableState &state) const; /// Run the pattern matcher on the given root operation, collecting the /// matched patterns in `matches`. void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl &matches, PDLByteCodeMutableState &state) const; /// Run the rewriter of the given pattern that was previously matched in /// `match`. void rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const; private: /// Execute the given byte code starting at the provided instruction `inst`. /// `matches` is an optional field provided when this function is executed in /// a matching context. void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter, PDLByteCodeMutableState &state, SmallVectorImpl *matches) const; /// A vector containing pointers to uniqued data. The storage is intentionally /// opaque such that we can store a wide range of data types. The types of /// data stored here include: /// * Attribute, OperationName, Type std::vector uniquedData; /// A vector containing the generated bytecode for the matcher. SmallVector matcherByteCode; /// A vector containing the generated bytecode for all of the rewriters. SmallVector rewriterByteCode; /// The set of patterns contained within the bytecode. SmallVector patterns; /// A set of user defined functions invoked via PDL. std::vector constraintFunctions; std::vector rewriteFunctions; /// The maximum memory index used by a value. ByteCodeField maxValueMemoryIndex = 0; /// The maximum number of different types of ranges. ByteCodeField maxOpRangeCount = 0; ByteCodeField maxTypeRangeCount = 0; ByteCodeField maxValueRangeCount = 0; /// The maximum number of nested loops. ByteCodeField maxLoopLevel = 0; }; } // namespace detail } // namespace mlir #endif // MLIR_REWRITE_BYTECODE_H_