diff options
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
| -rw-r--r-- | offload/liboffload/src/OffloadImpl.cpp | 59 |
1 files changed, 35 insertions, 24 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 17a2b00cb714..ffc9016bca0a 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -84,17 +84,20 @@ struct ol_program_impl_t { DeviceImage(DeviceImage) {} plugin::DeviceImageTy *Image; std::unique_ptr<llvm::MemoryBuffer> ImageData; - std::vector<std::unique_ptr<ol_symbol_impl_t>> Symbols; + std::mutex SymbolListMutex; __tgt_device_image DeviceImage; + llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols; + llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols; }; struct ol_symbol_impl_t { - ol_symbol_impl_t(GenericKernelTy *Kernel) - : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {} - ol_symbol_impl_t(GlobalTy &&Global) - : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {} + ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel) + : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {} + ol_symbol_impl_t(const char *Name, GlobalTy &&Global) + : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {} std::variant<GenericKernelTy *, GlobalTy> PluginImpl; ol_symbol_kind_t Kind; + llvm::StringRef Name; }; namespace llvm { @@ -231,7 +234,7 @@ Error olShutDown_impl() { for (auto &P : OldContext->Platforms) { // Host plugin is nullptr and has no deinit - if (!P.Plugin) + if (!P.Plugin || !P.Plugin->is_initialized()) continue; if (auto Res = P.Plugin->deinit()) @@ -714,32 +717,40 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name, ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) { auto &Device = Program->Image->getDevice(); + std::lock_guard<std::mutex> Lock{Program->SymbolListMutex}; + switch (Kind) { case OL_SYMBOL_KIND_KERNEL: { - auto KernelImpl = Device.constructKernel(Name); - if (!KernelImpl) - return KernelImpl.takeError(); + auto &Kernel = Program->KernelSymbols[Name]; + if (!Kernel) { + auto KernelImpl = Device.constructKernel(Name); + if (!KernelImpl) + return KernelImpl.takeError(); - if (auto Err = KernelImpl->init(Device, *Program->Image)) - return Err; + if (auto Err = KernelImpl->init(Device, *Program->Image)) + return Err; + + Kernel = std::make_unique<ol_symbol_impl_t>(KernelImpl->getName(), + &*KernelImpl); + } - *Symbol = - Program->Symbols - .emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl)) - .get(); + *Symbol = Kernel.get(); return Error::success(); } case OL_SYMBOL_KIND_GLOBAL_VARIABLE: { - GlobalTy GlobalObj{Name}; - if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice( - Device, *Program->Image, GlobalObj)) - return Res; - - *Symbol = Program->Symbols - .emplace_back( - std::make_unique<ol_symbol_impl_t>(std::move(GlobalObj))) - .get(); + auto &Global = Program->KernelSymbols[Name]; + if (!Global) { + GlobalTy GlobalObj{Name}; + if (auto Res = + Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice( + Device, *Program->Image, GlobalObj)) + return Res; + + Global = std::make_unique<ol_symbol_impl_t>(GlobalObj.getName().c_str(), + std::move(GlobalObj)); + } + *Symbol = Global.get(); return Error::success(); } default: |
