From f0ba32d6667b2ea9a58c7dbd5f27984da9d335fc Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay Date: Wed, 13 Apr 2022 08:20:56 +0530 Subject: [mlir][LLVM-IR] Added support for global variable attributes This patch adds thread_local to llvm.mlir.global and adds translation for dso_local and addr_space to and from LLVM IR. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D123412 --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 69 ++++++++++++++++------------ mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 9 ++-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 5 +- 3 files changed, 49 insertions(+), 34 deletions(-) (limited to 'mlir/lib') diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index bf8d487ccb49..3fc042d17c57 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1579,38 +1579,38 @@ LogicalResult AddressOfOp::verify() { // Builder, printer and verifier for LLVM::GlobalOp. //===----------------------------------------------------------------------===// -/// Returns the name used for the linkage attribute. This *must* correspond to -/// the name of the attribute in ODS. -static StringRef getLinkageAttrName() { return "linkage"; } - -/// Returns the name used for the unnamed_addr attribute. This *must* correspond -/// to the name of the attribute in ODS. -static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; } - void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, bool isConstant, Linkage linkage, StringRef name, Attribute value, uint64_t alignment, unsigned addrSpace, - bool dsoLocal, ArrayRef attrs) { - result.addAttribute(SymbolTable::getSymbolAttrName(), + bool dsoLocal, bool threadLocal, + ArrayRef attrs) { + result.addAttribute(getSymNameAttrName(result.name), builder.getStringAttr(name)); - result.addAttribute("global_type", TypeAttr::get(type)); + result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); if (isConstant) - result.addAttribute("constant", builder.getUnitAttr()); + result.addAttribute(getConstantAttrName(result.name), + builder.getUnitAttr()); if (value) - result.addAttribute("value", value); + result.addAttribute(getValueAttrName(result.name), value); if (dsoLocal) - result.addAttribute("dso_local", builder.getUnitAttr()); + result.addAttribute(getDsoLocalAttrName(result.name), + builder.getUnitAttr()); + if (threadLocal) + result.addAttribute(getThreadLocal_AttrName(result.name), + builder.getUnitAttr()); // Only add an alignment attribute if the "alignment" input // is different from 0. The value must also be a power of two, but // this is tested in GlobalOp::verify, not here. if (alignment != 0) - result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); + result.addAttribute(getAlignmentAttrName(result.name), + builder.getI64IntegerAttr(alignment)); - result.addAttribute(::getLinkageAttrName(), + result.addAttribute(getLinkageAttrName(result.name), LinkageAttr::get(builder.getContext(), linkage)); if (addrSpace != 0) - result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace)); + result.addAttribute(getAddrSpaceAttrName(result.name), + builder.getI32IntegerAttr(addrSpace)); result.attributes.append(attrs.begin(), attrs.end()); result.addRegion(); } @@ -1622,6 +1622,8 @@ void GlobalOp::print(OpAsmPrinter &p) { if (!str.empty()) p << str << ' '; } + if (getThreadLocal_()) + p << "thread_local "; if (getConstant()) p << "constant "; p.printSymbolName(getSymName()); @@ -1632,10 +1634,11 @@ void GlobalOp::print(OpAsmPrinter &p) { // Note that the alignment attribute is printed using the // default syntax here, even though it is an inherent attribute // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) - p.printOptionalAttrDict((*this)->getAttrs(), - {SymbolTable::getSymbolAttrName(), "global_type", - "constant", "value", getLinkageAttrName(), - getUnnamedAddrAttrName()}); + p.printOptionalAttrDict( + (*this)->getAttrs(), + {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(), + getConstantAttrName(), getValueAttrName(), getLinkageAttrName(), + getUnnamedAddrAttrName(), getThreadLocal_AttrName()}); // Print the trailing type unless it's a string global. if (getValueOrNull().dyn_cast_or_null()) @@ -1702,28 +1705,35 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { MLIRContext *ctx = parser.getContext(); // Parse optional linkage, default to External. - result.addAttribute(::getLinkageAttrName(), + result.addAttribute(getLinkageAttrName(result.name), LLVM::LinkageAttr::get( ctx, parseOptionalLLVMKeyword( parser, result, LLVM::Linkage::External))); + + if (succeeded(parser.parseOptionalKeyword("thread_local"))) + result.addAttribute(getThreadLocal_AttrName(result.name), + parser.getBuilder().getUnitAttr()); + // Parse optional UnnamedAddr, default to None. - result.addAttribute(::getUnnamedAddrAttrName(), + result.addAttribute(getUnnamedAddrAttrName(result.name), parser.getBuilder().getI64IntegerAttr( parseOptionalLLVMKeyword( parser, result, LLVM::UnnamedAddr::None))); if (succeeded(parser.parseOptionalKeyword("constant"))) - result.addAttribute("constant", parser.getBuilder().getUnitAttr()); + result.addAttribute(getConstantAttrName(result.name), + parser.getBuilder().getUnitAttr()); StringAttr name; - if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), + if (parser.parseSymbolName(name, getSymNameAttrName(result.name), result.attributes) || parser.parseLParen()) return failure(); Attribute value; if (parser.parseOptionalRParen()) { - if (parser.parseAttribute(value, "value", result.attributes) || + if (parser.parseAttribute(value, getValueAttrName(result.name), + result.attributes) || parser.parseRParen()) return failure(); } @@ -1755,7 +1765,8 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } - result.addAttribute("global_type", TypeAttr::get(types[0])); + result.addAttribute(getGlobalTypeAttrName(result.name), + TypeAttr::get(types[0])); return success(); } @@ -1976,7 +1987,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, builder.getStringAttr(name)); result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); - result.addAttribute(::getLinkageAttrName(), + result.addAttribute(getLinkageAttrName(result.name), LinkageAttr::get(builder.getContext(), linkage)); result.attributes.append(attrs.begin(), attrs.end()); if (dsoLocal) @@ -2036,7 +2047,7 @@ buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef inputs, ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { // Default to external linkage if no keyword is provided. result.addAttribute( - ::getLinkageAttrName(), + getLinkageAttrName(result.name), LinkageAttr::get(parser.getContext(), parseOptionalLLVMKeyword( parser, result, LLVM::Linkage::External))); diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 2f866e10dd8a..bda695c35ddb 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -432,10 +432,11 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) { alignment = align.value(); } - GlobalOp op = - b.create(UnknownLoc::get(context), type, gv->isConstant(), - convertLinkageFromLLVM(gv->getLinkage()), - gv->getName(), valueAttr, alignment); + GlobalOp op = b.create( + UnknownLoc::get(context), type, gv->isConstant(), + convertLinkageFromLLVM(gv->getLinkage()), gv->getName(), valueAttr, + alignment, /*addr_space=*/gv->getAddressSpace(), + /*dso_local=*/gv->isDSOLocal(), /*thread_local=*/gv->isThreadLocal()); if (gv->hasInitializer() && !valueAttr) { Region &r = op.getInitializerRegion(); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 953fb2461520..127e7e15ccab 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -661,7 +661,10 @@ LogicalResult ModuleTranslation::convertGlobals() { auto *var = new llvm::GlobalVariable( *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(), - /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal, addrSpace); + /*InsertBefore=*/nullptr, + op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel + : llvm::GlobalValue::NotThreadLocal, + addrSpace); if (op.getUnnamedAddr().hasValue()) var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr())); -- cgit v1.2.3