For Sumatra targets, we wanted to experiment with an infrastructure for supporting deoptimization, that is, transferring execution from the compiled code running on the GPU back to the equivalent bytecodes being run through the interpreter on the CPU. The following describes some experiments with deoptimization using the HSAIL Backend, which are now available on the graal trunk.
Some reasons for deoptimization on the GPU are:
- When compile-time assumptions are violated, we can "trap" to the interpreter (this relies on the fact that the interpreter can handle anything). In addition, this gives us a way of handling certain hopefully rare events, such as throwing exceptions back to the CPU, which might be difficult to implement completely from the GPU. If statistics show that such events are not actually "rare", we can
- decide that in the future this particular lambda is not a good candidate for offload,
- or in some cases recompile and generate new code for the GPU.
- Compiled code running on the GPU might reach a point where it needs the CPU to do something before the GPU can make further progress. For example, if we are supporting allocation on the GPU, we could get to a point where we cannot allocate any new object until a GC happens. If the target does not have an easy way to spin and wait for the CPU to do the GC, one way to support this is to deoptimize. The interpreter will let the GC happen and then continue executing bytecodes from the point of the deoptimization, including finishing the allocation.
- It allows an implementation of compiler safepoints. When directed by the compiler, for example at the bottom of loops, a long-running kernel can check an external flag and deoptimize if the flag is set (a rough sketch of this poll follows below). This external flag can be set by the VM safepoint logic. And since the deoptimized workitems running in the interpreter can pause at safepoints, this mechanism allows a long-running kernel to be interrupted so that CPU threads do not have to wait so long.
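To make the safepoint-poll idea concrete, here is a minimal Java sketch of what the compiler-inserted check at the bottom of a loop amounts to. The names deoptRequested and saveStateAndDeoptimize are invented stand-ins for the real flag and deopt exit, not actual Sumatra or graal APIs.

// Hypothetical sketch of a compiler-inserted safepoint poll at a loop bottom.
// In the real system the flag is set by the VM safepoint logic and the deopt
// exit saves live registers and transfers control to the interpreter.
final class SafepointPollSketch {
    static volatile boolean deoptRequested = false;   // stand-in for the external flag

    static void kernelLoop(int[] data) {
        for (int i = 0; i < data.length; i++) {
            data[i] = data[i] * data[i];              // the kernel's loop body
            if (deoptRequested) {                     // poll at the bottom of the loop
                saveStateAndDeoptimize(i);            // hypothetical deopt exit
                return;
            }
        }
    }

    static void saveStateAndDeoptimize(int loopIndex) {
        // Placeholder: the real exit saves the live state for this workitem so the
        // interpreter can resume at the matching bytecode index.
    }
}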
A Small Example
Here is a simple example where we want to produce an output array of the squared value of a sequence of integers. But the logic that computes the output array index will cause an ArrayIndexOutOfBoundsException about halfway through the range. You can build and run this using the instructions on the Standalone Sumatra Stream API Offload Demo page, assuming you are using the latest graal and sumatra trunks.
Try running with -Dcom.amd.sumatra.offload.immediate=false (for normal JDK stream parallel operation) and -Dcom.amd.sumatra.offload.immediate=true (so the lambda will be offloaded to the GPU). Note that in both cases the stack trace shows the same trace lines for the frames within the lambda itself. (Lines further up the stack will depend on the internal mechanism used to run the lambda across the range.)
Note that on any run some output array slots contain their original -1 value, indicating that the workitem for that entry did not run. The set of workitems that did not run may differ between the GPU and CPU cases; in fact it may differ between any two GPU runs, or any two CPU runs. The semantics for an exception on a stream operation are that the first exception is reported and pending new workitems will not run. Since the lambda is executing in parallel across the range, the set of workitems that might not run because of the exception is implementation-dependent. As a further experiment, you could try removing the .parallel() call from the forEach invocation and see yet another output array configuration for a non-parallel run.
package simpledeopt;

import java.util.stream.IntStream;
import java.util.Arrays;

public class SimpleDeopt {

    public static void main(String[] args) {
        final int length = 20;
        int[] output = new int[length];
        Arrays.fill(output, -1);

        try {
            // offloadable since it is parallel
            // will trigger exception halfway thru the range
            IntStream.range(0, length).parallel().forEach(p -> {
                int outIndex = (p == length / 2 ? p + length : p);
                writeIntArray(output, outIndex, p * p);
            });
        } catch (Exception e) {
            e.printStackTrace();
        }

        // Print results - not offloadable since it is not parallel
        IntStream.range(0, length).forEach(p -> {
            System.out.println(p + ", " + output[p]);
        });
    }

    static void writeIntArray(int[] ary, int index, int val) {
        ary[index] = val;
    }
}
Implementation Notes
Compile Time
We use the graal compiler to generate the hsail code. The graal compiler has a mature infrastructure for supporting deoptimization while still achieving good code quality. See http://design.cs.iastate.edu/vmil/2013/papers/p04-Duboscq.pdf. The compiler keeps track of the deoptimization state at each deopt point, and from that we can tell which HSAIL registers or stack slots need to be saved, which of them contain oops, and so on.
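As a rough mental model only (these are not graal's actual classes, and the field names are invented for illustration), the information recorded per deopt point can be pictured as something like:

import java.util.BitSet;

// Invented illustration of the kind of per-deopt-point metadata the compiler
// records; graal's real data structures differ.
final class DeoptPointInfo {
    final int deoptId;          // identifies the deopt "pc" within the kernel
    final int bci;              // bytecode index to resume at in the interpreter
    final BitSet liveSRegs;     // which 32-bit s registers are live here
    final BitSet liveDRegs;     // which 64-bit d registers are live here
    final BitSet oopDRegs;      // which of the live d registers hold oops

    DeoptPointInfo(int deoptId, int bci, BitSet liveSRegs, BitSet liveDRegs, BitSet oopDRegs) {
        this.deoptId = deoptId;
        this.bci = bci;
        this.liveSRegs = liveSRegs;
        this.liveDRegs = liveDRegs;
        this.oopDRegs = oopDRegs;
    }
}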
In HSAIL, registers do not live across function calls, so the register saving must be inlined into each function. To avoid code bloat, we currently have one deopt exit point per kernel. In that deopt exit code we save the union of the registers and stack slots that are live at any of the infopoints. There is another reason to save only the minimum number of registers: the total number of HSAIL registers used affects the code quality when the HSAIL code is finalized to the device IL, even if some registers are used only on a "cold" path such as the deopt exit point.
Execution Time
When we dispatch a kernel to the GPU, we need to finish the execution of all the workitems even if one or more of them deoptimize. The kernel code can be executed across a possibly very large number of workitems, each of which can have its own state. The non-deoptimizing workitems can finish as they normally would, but we need to be able to save the state of the deoptimizing workitems. For a particular kernel, graal can tell us the maximum size of the registers and stack slots that will need to be saved for any deoptimization PC in that kernel. We don't know up front how many workitems are going to deoptimize and need to save their state, yet we want to avoid allocating state-saving space for the entire, possibly very large, range of workitems.
A particular HSA target has a maximum number of workitems that can be executing concurrently, so to save space we need only allocate state-saving space for this maximum number of possibly concurrent workitems. To support this, deopting workitems set a "deopt happened" flag, and before beginning execution each workitem checks this "deopt happened" flag. Workitems that see the "deopt happened" flag as true just set a simple flag indicating they never ran and exit immediately, thus not needing to save any additional state. Currently the never_ran flag array is one byte per workitem. We are looking at ways to make this smaller, but HSA devices have a lot of freedom in how they schedule workitems, which makes it difficult to figure out the never-rans.
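Here is a minimal host-side sketch of this bookkeeping, written in plain Java for readability; in reality these checks are emitted as HSAIL inside the kernel, and the names deoptHappened, neverRan, and claimDeoptSlot are invented for illustration.

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

// Invented model of the per-workitem deopt bookkeeping described above.
final class DeoptBookkeeping {
    final AtomicBoolean deoptHappened = new AtomicBoolean(false);
    final AtomicInteger nextDeoptSlot = new AtomicInteger(0); // index into the save area
    final byte[] neverRan;       // one byte per workitem, as described above

    DeoptBookkeeping(int rangeSize) {
        this.neverRan = new byte[rangeSize];
    }

    // Called at the top of each workitem.
    boolean shouldRun(int workitemId) {
        if (deoptHappened.get()) {
            neverRan[workitemId] = 1;   // record that this workitem never ran
            return false;               // exit immediately, no state to save
        }
        return true;
    }

    // Called by a workitem that hits a deopt point.
    int claimDeoptSlot() {
        deoptHappened.set(true);
        return nextDeoptSlot.getAndIncrement(); // where to store this workitem's state
    }
}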
Workitems that deopt atomically bump an index indicating where they should store their deopt data. The deopt data consists of:
- the workitem id, which is a linear index across the range and unique for each workitem
- the deoptimization actionAndReason, just as in CPU deoptimizations
- the first HSAILFrame
An HSAILFrame consists of:
- the deopt Id or "pc" offset where the deopt occurred
- number of 32-bit s registers saved
- number of 64-bit d registers saved
- number of stack-slot variables saved
- actual space for saving the s and d registers and stack slot variables
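Purely for illustration, the saved record for one deopting workitem could be modeled in Java roughly as below; the field names mirror the lists above, but the real layout lives in native memory shared with the GPU and is not these classes.

// Illustrative Java model of one workitem's saved deopt data.
final class HsailDeoptRecord {
    int workitemId;          // linear index across the range, unique per workitem
    int actionAndReason;     // same encoding as CPU deoptimizations
    HsailFrame firstFrame;   // the first saved HSAIL frame
}

final class HsailFrame {
    int deoptId;             // the "pc" offset where the deopt occurred
    int numSRegs;            // number of 32-bit s registers saved
    int numDRegs;            // number of 64-bit d registers saved
    int numStackSlots;       // number of stack-slot variables saved
    int[] sRegs;             // saved 32-bit register values
    long[] dRegs;            // saved 64-bit register values (some may be oops)
    long[] stackSlots;       // saved stack-slot values
}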
When the GPU dispatch completes, each workitem will have either finished normally, deopted, or not run at all. Still in the VM, we check whether there were any deopts. If not, we know everything completed normally and we can just return back to Java. However, if there was at least one deopt, then:
- for the workitems that finished normally, there is nothing to do
- if there are any deopted workitems, we want to run each deopting workitem through the interpreter starting from the bytecode index of the deoptimization point. Note that it is possible for different workitems to have different deoptimization "PCs". We currently re-dispatch each of the deopting workitems sequentially, although other policies are clearly possible.
- Note: Getting to the interpreter from the saved hsail state goes first through some special compiled host trampoline code infrastructure designed by Gilles Duboscq of the graal team. The trampoline host code takes the hsail deoptId and a pointer to the saved hsail frame as input and then immediately deoptimizes just as any host compiled code would deoptimize.
- for each never-ran workitem, we can just run it as a "javaCall" from the beginning of the kernel method, making sure we pass the arguments and the appropriate workitem id for each one. We currently do this sequentially, although other policies are possible. One policy that could be considered if the never-rans are contiguous is to resubmit the kernel with an offset applied to each workitemId (sketched below). For instance, if out of an original range of 10000 we see that workitems 8000-9999 did not run, these could be run as a new range of 2000, as long as each workitem adds 8000 to its raw workitemId.
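Here is a hedged sketch of that offset idea. The dispatch is simulated with a parallel IntStream rather than a real kernel resubmission, and kernelBody is a hypothetical stand-in for the original kernel logic for one workitem.

import java.util.stream.IntStream;

// Sketch of re-running a contiguous block of never-ran workitems as a smaller
// range with an offset applied to each raw workitem id.
final class NeverRanRedispatchSketch {
    static void redispatchNeverRans(int firstNeverRan, int lastNeverRanExclusive) {
        int offset = firstNeverRan;                       // e.g. 8000
        int newRange = lastNeverRanExclusive - offset;    // e.g. 2000
        // Stand-in for resubmitting the kernel over the smaller range.
        IntStream.range(0, newRange).parallel()
                 .forEach(raw -> kernelBody(raw + offset)); // add the offset to each raw id
    }

    static void kernelBody(int workitemId) {
        // Placeholder for the original kernel logic for one workitem.
    }
}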
GC Considerations
Currently, the normal kernel dispatch runs in "thread in VM" mode and thus does not need to worry about moving oops. However, each time a deopting workitem is run through the interpreter we are back in Java mode, which can cause GCs. So for each saved hsail frame, we need to know which parts of the saved state contain oops and make sure those locations are updated in the face of GC. The OopMap is calculated on the Java side and passed down at execute time. Given a deoptimization PC, one can use the OopMap to tell which 64-bit registers or stack slots are oops.
This OopMap information is used by a GC thread when, as part of its oops_do work, it goes through the list of JavaThreads and finds one that is using HSAIL execution.
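As a hedged illustration only (the names and interface below are invented, not the HotSpot API), the GC-side walk of one saved frame could look something like:

import java.util.BitSet;

// Invented sketch of walking one saved HSAIL frame during GC and visiting
// only the 64-bit slots that the oop map marks as oops.
final class HsailFrameOopsDoSketch {
    interface OopClosure { long visit(long oop); } // may return a moved oop

    // oopMap bit i set => saved 64-bit slot i holds an oop.
    static void oopsDo(long[] savedDRegsAndStackSlots, BitSet oopMap, OopClosure closure) {
        for (int i = oopMap.nextSetBit(0); i >= 0; i = oopMap.nextSetBit(i + 1)) {
            // Let the GC update the location in case the object moved.
            savedDRegsAndStackSlots[i] = closure.visit(savedDRegsAndStackSlots[i]);
        }
    }
}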