summaryrefslogtreecommitdiff
path: root/offload/tools/offload-tblgen/EntryPointGen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'offload/tools/offload-tblgen/EntryPointGen.cpp')
-rw-r--r--offload/tools/offload-tblgen/EntryPointGen.cpp49
1 files changed, 35 insertions, 14 deletions
diff --git a/offload/tools/offload-tblgen/EntryPointGen.cpp b/offload/tools/offload-tblgen/EntryPointGen.cpp
index 85c5c50bf2f2..4e42e4905b99 100644
--- a/offload/tools/offload-tblgen/EntryPointGen.cpp
+++ b/offload/tools/offload-tblgen/EntryPointGen.cpp
@@ -35,21 +35,30 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
}
OS << ") {\n";
- OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n";
- // Emit validation checks
- for (const auto &Return : F.getReturns()) {
- for (auto &Condition : Return.getConditions()) {
- if (Condition.starts_with("`") && Condition.ends_with("`")) {
- auto ConditionString = Condition.substr(1, Condition.size() - 2);
- OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
- OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, "
- "\"validation failure: {1}\");\n",
- Return.getUnprefixedValue(), ConditionString);
- OS << TAB_2 "}\n\n";
+ bool HasValidation = llvm::any_of(F.getReturns(), [](auto &R) {
+ return llvm::any_of(R.getConditions(), [](auto &C) {
+ return C.starts_with("`") && C.ends_with("`");
+ });
+ });
+
+ if (HasValidation) {
+ OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n";
+ // Emit validation checks
+ for (const auto &Return : F.getReturns()) {
+ for (auto &Condition : Return.getConditions()) {
+ if (Condition.starts_with("`") && Condition.ends_with("`")) {
+ auto ConditionString = Condition.substr(1, Condition.size() - 2);
+ OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
+ OS << formatv(TAB_3
+ "return createOffloadError(error::ErrorCode::{0}, "
+ "\"validation failure: {1}\");\n",
+ Return.getUnprefixedValue(), ConditionString);
+ OS << TAB_2 "}\n\n";
+ }
}
}
+ OS << TAB_1 "}\n\n";
}
- OS << TAB_1 "}\n\n";
// Perform actual function call to the implementation
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
@@ -73,8 +82,12 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
}
OS << ") {\n";
+ // Check offload is initialized
+ if (F.getName() != "olInit")
+ OS << "if (!llvm::offload::isOffloadInitialized()) return &UninitError;";
+
// Emit pre-call prints
- OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
+ OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
OS << TAB_1 "}\n\n";
@@ -85,7 +98,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
PrefixLower, F.getName(), ParamNameList);
// Emit post-call prints
- OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
+ OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
if (F.getParams().size() > 0) {
OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName());
for (const auto &Param : F.getParams()) {
@@ -134,6 +147,14 @@ static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) {
void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) {
OS << GenericHeader;
+
+ constexpr const char *UninitMessage =
+ "liboffload has not been initialized - please call olInit before using "
+ "this API";
+ OS << formatv("static {0}_error_struct_t UninitError = "
+ "{{{1}_ERRC_UNINITIALIZED, \"{2}\"};",
+ PrefixLower, PrefixUpper, UninitMessage);
+
for (auto *R : Records.getAllDerivedDefinitions("Function")) {
EmitValidationFunc(FunctionRec{R}, OS);
EmitEntryPointFunc(FunctionRec{R}, OS);