diff options
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ModuleImport.cpp')
| -rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleImport.cpp | 55 |
1 files changed, 36 insertions, 19 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 21f2050cbceb..4ff1f1135b0a 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -792,10 +792,6 @@ Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) { if (Attribute scalarAttr = getScalarConstantAsAttr(builder, constant)) return scalarAttr; - // Convert function references. - if (auto *func = dyn_cast<llvm::Function>(constant)) - return SymbolRefAttr::get(builder.getContext(), func->getName()); - // Returns the static shape of the provided type if possible. auto getConstantShape = [&](llvm::Type *type) { return llvm::dyn_cast_if_present<ShapedType>( @@ -878,6 +874,24 @@ Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) { return {}; } +FlatSymbolRefAttr +ModuleImport::getOrCreateNamelessSymbolName(llvm::GlobalVariable *globalVar) { + assert(globalVar->getName().empty() && + "expected to work with a nameless global"); + auto [it, success] = namelessGlobals.try_emplace(globalVar); + if (!success) + return it->second; + + // Make sure the symbol name does not clash with an existing symbol. + SmallString<128> globalName = SymbolTable::generateSymbolName<128>( + getNamelessGlobalPrefix(), + [this](StringRef newName) { return llvmModule->getNamedValue(newName); }, + namelessGlobalId); + auto symbolRef = FlatSymbolRefAttr::get(context, globalName); + it->getSecond() = symbolRef; + return symbolRef; +} + LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { // Insert the global after the last one or at the start of the module. OpBuilder::InsertionGuard guard(builder); @@ -911,17 +925,10 @@ LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { // Workaround to support LLVM's nameless globals. MLIR, in contrast to LLVM, // always requires a symbol name. - SmallString<128> globalName(globalVar->getName()); - if (globalName.empty()) { - // Make sure the symbol name does not clash with an existing symbol. - globalName = SymbolTable::generateSymbolName<128>( - getNamelessGlobalPrefix(), - [this](StringRef newName) { - return llvmModule->getNamedValue(newName); - }, - namelessGlobalId); - namelessGlobals[globalVar] = FlatSymbolRefAttr::get(context, globalName); - } + StringRef globalName = globalVar->getName(); + if (globalName.empty()) + globalName = getOrCreateNamelessSymbolName(globalVar).getValue(); + GlobalOp globalOp = builder.create<GlobalOp>( mlirModule.getLoc(), type, globalVar->isConstant(), convertLinkageFromLLVM(globalVar->getLinkage()), StringRef(globalName), @@ -1019,6 +1026,14 @@ ModuleImport::getConstantsToConvert(llvm::Constant *constant) { workList.insert(constant); while (!workList.empty()) { llvm::Constant *current = workList.back(); + // References of global objects are just pointers to the object. Avoid + // walking the elements of these here. + if (isa<llvm::GlobalObject>(current)) { + orderedSet.insert(current); + workList.pop_back(); + continue; + } + // Collect all dependencies of the current constant and add them to the // adjacency list if none has been computed before. auto [adjacencyIt, inserted] = adjacencyLists.try_emplace(current); @@ -1096,12 +1111,14 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) { } // Convert global variable accesses. - if (auto *globalVar = dyn_cast<llvm::GlobalVariable>(constant)) { - Type type = convertType(globalVar->getType()); - StringRef globalName = globalVar->getName(); + if (auto *globalObj = dyn_cast<llvm::GlobalObject>(constant)) { + Type type = convertType(globalObj->getType()); + StringRef globalName = globalObj->getName(); FlatSymbolRefAttr symbolRef; + // Empty names are only allowed for global variables. if (globalName.empty()) - symbolRef = namelessGlobals[globalVar]; + symbolRef = + getOrCreateNamelessSymbolName(cast<llvm::GlobalVariable>(globalObj)); else symbolRef = FlatSymbolRefAttr::get(context, globalName); return builder.create<AddressOfOp>(loc, type, symbolRef).getResult(); |
