JIT Compilation (Proteus + RAJA)
RAJA can optionally integrate with Proteus to just-in-time (JIT) compile specialized variants of kernels. This is useful when some performance-critical values are known only at runtime. For example, propagating loop bounds as runtime constants can improve loop analysis and scheduling. Propagating other values can enable further optimizations, such as branch elimination.
Warning
This capability is new and should be considered experimental.
Enabling JIT in a RAJA build
JIT support is enabled at configuration time:
cmake -DRAJA_ENABLE_JIT=On ...
Enabling RAJA_ENABLE_JIT requires an LLVM-based (Clang-family) compiler.
If you enable JIT and configure with a non-Clang compiler, configuration will
fail. See Build Configuration Options for the CMake options described here.
Proteus dependency
When RAJA_ENABLE_JIT=On, RAJA needs Proteus headers and build integration:
If you provide -DPROTEUS_INSTALL_DIR=<prefix>, RAJA will use find_package(proteus ...) using that prefix.
Otherwise, RAJA will attempt to fetch Proteus via CMake FetchContent at configure time.
LLVM installation requirement
Unless you provide a Proteus installation that is statically linked with LLVM,
Proteus support requires an LLVM 18, 19, or 20 installation, whose location you
must pass to RAJA/Proteus via LLVM_INSTALL_DIR:
cmake -DRAJA_ENABLE_JIT=On -DLLVM_INSTALL_DIR=/path/to/llvm-19 ...
An example of how to configure a JIT build of RAJA with HIP on LC machines is included
in scripts/toss4_amdclang_proteus.sh.
Marking a kernel for JIT
The user-facing interface shown in examples/forall-jit.cpp consists of:
RAJA_JIT_COMPILE: annotate a lambda or function so Proteus can identify it as a JIT compilation candidate.
RAJA_JIT_VARIABLE: wrap runtime values that should be treated as constants for specialization.
For example, specializing loop bounds and a branch condition:
RAJA::forall<policy>(RAJA::RangeSegment(0, N), [=,
    a = RAJA_JIT_VARIABLE(a),
    b = RAJA_JIT_VARIABLE(b),
    accum = RAJA_JIT_VARIABLE(accum)
  ] (int i) RAJA_JIT_COMPILE {
  for (int row = 0; row < a; ++row) {
    for (int col = 0; col < b; ++col) {
      if (!accum) {
        C(i, row, col) = A(i, row, col) * B(i, col, row);
      }
      else {
        C(i, row, col) += A(i, row, col) * B(i, col, row);
      }
    }
  }
});
When JIT is disabled (RAJA_ENABLE_JIT=Off), RAJA_JIT_COMPILE expands to
nothing. Similarly, RAJA_JIT_VARIABLE expands to proteus::jit_variable when JIT
is enabled, but simply expands to its single argument with RAJA_ENABLE_JIT=Off.
proteus::enable() and proteus::disable() manually enable and disable Proteus within a
region of source code. Currently, with RAJA_ENABLE_JIT=On, all RAJA kernels
are JIT compiled unless proteus::disable() has been called.
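The JIT-disabled expansions described above can be illustrated with a minimal sketch that does not depend on RAJA or Proteus. The macros DEMO_JIT_COMPILE and DEMO_JIT_VARIABLE and the function scale_indices are hypothetical stand-ins introduced only for illustration; RAJA's actual definitions may differ. The point is that annotated code remains ordinary C++ when the macros expand to nothing and to their argument, respectively.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins that mirror the documented JIT-disabled behavior.
// These are NOT RAJA's definitions.
// DEMO_JIT_COMPILE expands to nothing.
#define DEMO_JIT_COMPILE
// DEMO_JIT_VARIABLE expands to its single argument.
#define DEMO_JIT_VARIABLE(x) x

// Writes i * scale into out[i], or adds it when accum is true, using the
// same capture-and-annotate pattern as the RAJA kernel shown earlier.
inline void scale_indices(std::vector<int>& out, int scale, bool accum)
{
  auto body = [&out,
               scale = DEMO_JIT_VARIABLE(scale),
               accum = DEMO_JIT_VARIABLE(accum)
              ] (int i) DEMO_JIT_COMPILE {
    if (!accum) {
      out[i] = i * scale;
    } else {
      out[i] += i * scale;
    }
  };

  // Plain sequential loop standing in for RAJA::forall.
  for (int i = 0; i < static_cast<int>(out.size()); ++i) {
    body(i);
  }
}
```

Because the stand-in macros disappear during preprocessing, the lambda compiles with any C++14 compiler, which is why annotated kernels stay portable to builds configured without JIT.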
Building and running the example
The example examples/forall-jit.cpp is built from the RAJA source tree when both of the following options are enabled:
ENABLE_EXAMPLES=On
RAJA_ENABLE_JIT=On
The example takes four command-line arguments:
./bin/forall-jit <a> <b> <N> <accum>
where a and b are matrix dimensions, N is the problem size, and
accum is the branch condition (0/1) that is specialized with JIT.
The example performs N matrix multiplications, one per thread. Each multiplication
is [a x b] [b x a]. The result is either assigned to the output or added to it, depending
on the boolean flag the user provides. By forcing each thread to perform serialized
arithmetic with a simple branch condition, the example shows how JIT compilation can improve
both per-thread serial loop scheduling and branch elimination.
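To make the effect of specialization concrete, the sketch below writes the per-item loop nest twice in plain C++: once with a, b, and accum as ordinary runtime values, and once with them as template parameters. The second form is a hand-written analogue of what a JIT-specialized variant gives the optimizer (constant trip counts and a branch that can be folded away); it is not how Proteus is implemented. The function names and the flat row-major array layout are assumptions made for this illustration and are not taken from examples/forall-jit.cpp.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Generic version: a, b, and accum are runtime values, so the compiler must
// keep general loops and a real branch.
inline void item_runtime(double* C, const double* A, const double* B,
                         int a, int b, bool accum)
{
  for (int row = 0; row < a; ++row) {
    for (int col = 0; col < b; ++col) {
      // A and C are stored as [a x b], B as [b x a] (assumed layout).
      const double prod = A[row * b + col] * B[col * a + row];
      if (!accum) {
        C[row * b + col] = prod;
      } else {
        C[row * b + col] += prod;
      }
    }
  }
}

// Hand-specialized version: the same values are compile-time constants, so
// the loops have known trip counts and the branch can be folded.
template <int A_DIM, int B_DIM, bool ACCUM>
inline void item_specialized(double* C, const double* A, const double* B)
{
  for (int row = 0; row < A_DIM; ++row) {
    for (int col = 0; col < B_DIM; ++col) {
      const double prod = A[row * B_DIM + col] * B[col * A_DIM + row];
      if (!ACCUM) {
        C[row * B_DIM + col] = prod;
      } else {
        C[row * B_DIM + col] += prod;
      }
    }
  }
}

// Runs both versions over N independent items and reports whether the
// outputs match exactly.
template <bool ACCUM>
bool versions_agree(int N)
{
  constexpr int a = 3;
  constexpr int b = 2;
  const std::size_t len = static_cast<std::size_t>(N) * a * b;
  std::vector<double> A(len), B(len), C1(len, 1.0), C2(len, 1.0);
  for (std::size_t k = 0; k < len; ++k) {
    // Small integer values keep every product and sum exact.
    A[k] = static_cast<double>(k % 7);
    B[k] = static_cast<double>(k % 5 + 1);
  }
  for (int i = 0; i < N; ++i) {
    const std::size_t off = static_cast<std::size_t>(i) * a * b;
    item_runtime(&C1[off], &A[off], &B[off], a, b, ACCUM);
    item_specialized<a, b, ACCUM>(&C2[off], &A[off], &B[off]);
  }
  return C1 == C2;
}
```

Templates require the values at compile time, whereas JIT specialization defers the same decision until the values are known at runtime, at the cost of compiling during execution.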
Specializing argument indices (advanced)
RAJA also provides RAJA_JIT_COMPILE_ARGS(...) to annotate functions and
specify which 1-indexed arguments should be treated as specialization inputs:
__global__ RAJA_JIT_COMPILE_ARGS(3) void my_kernel(int x, int y, int z) { ... }
This is used internally by RAJA’s GPU back-ends; most users only need
RAJA_JIT_COMPILE on their lambdas and RAJA_JIT_VARIABLE (which expands to
proteus::jit_variable when JIT is enabled) for values captured into the lambda.
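The 1-indexed convention can be sketched on a host function. The macro DEMO_JIT_COMPILE_ARGS and the function scale_sum below are hypothetical and exist only for illustration; RAJA's RAJA_JIT_COMPILE_ARGS may be defined differently. The sketch assumes that Proteus recognizes Clang's annotate attribute carrying a "jit" tag followed by argument positions. Without the Proteus compiler plugin the attribute is inert, so the function behaves like any other.

```cpp
#include <cassert>

// Hypothetical macro, for illustration only. Assumption: Proteus looks for
// an annotate attribute of the form annotate("jit", <1-indexed positions>).
// The attribute is emitted only under Clang; elsewhere the macro is empty.
#if defined(__clang__)
#define DEMO_JIT_COMPILE_ARGS(...) __attribute__((annotate("jit", __VA_ARGS__)))
#else
#define DEMO_JIT_COMPILE_ARGS(...)
#endif

// Position 3 names the third parameter, z, as the specialization input.
// Positions 1 (x) and 2 (y) remain ordinary runtime arguments.
DEMO_JIT_COMPILE_ARGS(3)
inline int scale_sum(int x, int y, int z)
{
  return (x + y) * z;
}
```

Counting starts at 1, so an annotation of (1, 3) would select x and z, not y and the nonexistent fourth parameter.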