diff options
Diffstat (limited to 'offload/tools/offload-tblgen/EntryPointGen.cpp')
| -rw-r--r-- | offload/tools/offload-tblgen/EntryPointGen.cpp | 49 |
1 files changed, 35 insertions, 14 deletions
diff --git a/offload/tools/offload-tblgen/EntryPointGen.cpp b/offload/tools/offload-tblgen/EntryPointGen.cpp index 85c5c50bf2f2..4e42e4905b99 100644 --- a/offload/tools/offload-tblgen/EntryPointGen.cpp +++ b/offload/tools/offload-tblgen/EntryPointGen.cpp @@ -35,21 +35,30 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { } OS << ") {\n"; - OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n"; - // Emit validation checks - for (const auto &Return : F.getReturns()) { - for (auto &Condition : Return.getConditions()) { - if (Condition.starts_with("`") && Condition.ends_with("`")) { - auto ConditionString = Condition.substr(1, Condition.size() - 2); - OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString); - OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, " - "\"validation failure: {1}\");\n", - Return.getUnprefixedValue(), ConditionString); - OS << TAB_2 "}\n\n"; + bool HasValidation = llvm::any_of(F.getReturns(), [](auto &R) { + return llvm::any_of(R.getConditions(), [](auto &C) { + return C.starts_with("`") && C.ends_with("`"); + }); + }); + + if (HasValidation) { + OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n"; + // Emit validation checks + for (const auto &Return : F.getReturns()) { + for (auto &Condition : Return.getConditions()) { + if (Condition.starts_with("`") && Condition.ends_with("`")) { + auto ConditionString = Condition.substr(1, Condition.size() - 2); + OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString); + OS << formatv(TAB_3 + "return createOffloadError(error::ErrorCode::{0}, " + "\"validation failure: {1}\");\n", + Return.getUnprefixedValue(), ConditionString); + OS << TAB_2 "}\n\n"; + } } } + OS << TAB_1 "}\n\n"; } - OS << TAB_1 "}\n\n"; // Perform actual function call to the implementation ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); @@ -73,8 +82,12 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { } OS << ") {\n"; + // Check offload is initialized + if (F.getName() != "olInit") + OS << "if (!llvm::offload::isOffloadInitialized()) return &UninitError;"; + // Emit pre-call prints - OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; + OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n"; OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName()); OS << TAB_1 "}\n\n"; @@ -85,7 +98,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { PrefixLower, F.getName(), ParamNameList); // Emit post-call prints - OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; + OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n"; if (F.getParams().size() > 0) { OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName()); for (const auto &Param : F.getParams()) { @@ -134,6 +147,14 @@ static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) { void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) { OS << GenericHeader; + + constexpr const char *UninitMessage = + "liboffload has not been initialized - please call olInit before using " + "this API"; + OS << formatv("static {0}_error_struct_t UninitError = " + "{{{1}_ERRC_UNINITIALIZED, \"{2}\"};", + PrefixLower, PrefixUpper, UninitMessage); + for (auto *R : Records.getAllDerivedDefinitions("Function")) { EmitValidationFunc(FunctionRec{R}, OS); EmitEntryPointFunc(FunctionRec{R}, OS); |
