Skip to content

Commit

Permalink
[wasm] Add a monitoring phase to jiterpreter traces and discard unpro…
Browse files Browse the repository at this point in the history
…ductive ones (#83432)

* Add a monitoring phase to jiterpreter traces, where we determine the average distance (in bytes) they travel. Then reject traces that have a low average distance after the monitoring period
  • Loading branch information
kg committed Mar 15, 2023
1 parent 367a360 commit 1cfda5e
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 6 deletions.
15 changes: 10 additions & 5 deletions src/mono/mono/mini/interp/interp.c
Original file line number Diff line number Diff line change
Expand Up @@ -7787,11 +7787,7 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
* (note that right now threading doesn't work, but it's worth being correct
* here so that implementing thread support will be easier later.)
*/
*mutable_ip = MINT_TIER_NOP_JITERPRETER;
mono_memory_barrier ();
*(volatile JiterpreterThunk*)(ip + 1) = prepare_result;
mono_memory_barrier ();
*mutable_ip = MINT_TIER_ENTER_JITERPRETER;
*mutable_ip = MINT_TIER_MONITOR_JITERPRETER;
// now execute the trace
// this isn't important for performance, but it makes it easier to use the
// jiterpreter early in automated tests where code only runs once
Expand All @@ -7806,6 +7802,15 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
MINT_IN_BREAK;
}

MINT_IN_CASE(MINT_TIER_MONITOR_JITERPRETER) {
// The trace is in monitoring mode, where we track how far it actually goes
// each time it is executed for a while. After N more hits, we either
// turn it into an ENTER or a NOP depending on how well it is working
ptrdiff_t offset = mono_jiterp_monitor_trace (ip, frame, locals);
// The returned offset is applied through a guint8*, i.e. it is treated as a
// byte displacement, before converting back to the guint16* instruction pointer
ip = (guint16*) (((guint8*)ip) + offset);
MINT_IN_BREAK;
}

MINT_IN_CASE(MINT_TIER_ENTER_JITERPRETER) {
JiterpreterThunk thunk = (void*)READ32(ip + 1);
ptrdiff_t offset = thunk(frame, locals);
Expand Down
71 changes: 71 additions & 0 deletions src/mono/mono/mini/interp/jiterpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -863,13 +863,24 @@ typedef struct {
// 64-bits because it can get very high if estimate heat is turned on
gint64 hit_count;
JiterpreterThunk thunk;
int size_of_trace;
gint32 total_distance;
} TraceInfo;

// If a trace exits with an exterior backward branch, treat its distance as this value
#define TRACE_NEGATIVE_DISTANCE 64
// Don't allow a trace to increase the total distance by more than this amount, since
// it would skew the average too much
#define TRACE_DISTANCE_LIMIT 512

// The maximum number of trace segments used to store TraceInfo. This limits
// the maximum total number of traces to MAX_TRACE_SEGMENTS * TRACE_SEGMENT_SIZE
#define MAX_TRACE_SEGMENTS 256
#define TRACE_SEGMENT_SIZE 1024

static volatile gint32 trace_count = 0;
static TraceInfo *trace_segments[MAX_TRACE_SEGMENTS] = { NULL };
static gint32 traces_rejected = 0;

static TraceInfo *
trace_info_allocate_segment (gint32 index) {
Expand Down Expand Up @@ -1030,6 +1041,9 @@ mono_interp_tier_prepare_jiterpreter_fast (
frame, method, ip, (gint32)trace_index,
start_of_body, size_of_body
);
// Record the maximum size of the trace (we don't know how long it actually is here)
// which might be smaller than the function body if this trace is in the middle
trace_info->size_of_trace = size_of_body - (ip - start_of_body);
trace_info->thunk = result;
return result;
} else {
Expand Down Expand Up @@ -1309,6 +1323,63 @@ mono_jiterp_write_number_unaligned (void *dest, double value, int mode) {
}
}

ptrdiff_t
mono_jiterp_monitor_trace (const guint16 *ip, void *frame, void *locals)
{
gint32 index = READ32(ip + 1);
TraceInfo *info = trace_info_get(index);
g_assert(info);

JiterpreterThunk thunk = info->thunk;
// FIXME: This shouldn't be possible
if (((guint32)(void *)thunk) <= JITERPRETER_NOT_JITTED)
return 6;
ptrdiff_t result = thunk(frame, locals);
// Maintain an approximate sum of how far trace execution has advanced over
// the monitoring period, so we can evaluate its average later and decide
// whether to keep the trace
// Note that a result of 0 means that a loop back-branched to itself.
info->total_distance += result <= 0
? TRACE_NEGATIVE_DISTANCE
: (result > TRACE_DISTANCE_LIMIT
? TRACE_DISTANCE_LIMIT
: result
);

gint64 hit_count = info->hit_count++ - mono_opt_jiterpreter_minimum_trace_hit_count;
if (hit_count == mono_opt_jiterpreter_trace_monitoring_period) {
// Prepare to enable the trace
volatile guint16 *mutable_ip = (volatile guint16*)ip;
*mutable_ip = MINT_TIER_NOP_JITERPRETER;

mono_memory_barrier ();
gint64 average_distance = info->total_distance / hit_count;
gint64 threshold = mono_opt_jiterpreter_trace_average_distance_threshold,
low_threshold = info->size_of_trace / 2;
// Don't reject short traces as long as they run mostly to the end, we already
// decided previously that they are worth keeping for some reason
if (low_threshold < threshold)
threshold = low_threshold;

if (average_distance >= threshold) {
*(volatile JiterpreterThunk*)(ip + 1) = thunk;
mono_memory_barrier ();
*mutable_ip = MINT_TIER_ENTER_JITERPRETER;
} else {
traces_rejected++;
// g_print("trace #%d @%d rejected; average_distance==%d\n", index, ip, average_distance);
}
}

return result;
}

// Returns the current value of traces_rejected, the running count of traces
// that were discarded instead of being enabled.
EMSCRIPTEN_KEEPALIVE gint32
mono_jiterp_get_rejected_trace_count (void)
{
	return traces_rejected;
}

// HACK: fix C4206
EMSCRIPTEN_KEEPALIVE
#endif // HOST_BROWSER
Expand Down
3 changes: 3 additions & 0 deletions src/mono/mono/mini/interp/jiterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ mono_jiterp_imethod_to_ftnptr (InterpMethod *imethod);
void
mono_jiterp_enum_hasflag (MonoClass *klass, gint32 *dest, stackval *sp1, stackval *sp2);

ptrdiff_t
mono_jiterp_monitor_trace (const guint16 *ip, void *frame, void *locals);

#endif // __MONO_MINI_INTERPRETER_INTERNALS_H__

extern WasmDoJitCall jiterpreter_do_jit_call;
Expand Down
1 change: 1 addition & 0 deletions src/mono/mono/mini/interp/mintops.def
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ OPDEF(MINT_METADATA_UPDATE_LDFLDA, "metadata_update.ldflda", 5, 1, 1, MintOpTwoS
OPDEF(MINT_TIER_PREPARE_JITERPRETER, "tier_prepare_jiterpreter", 3, 0, 0, MintOpInt)
OPDEF(MINT_TIER_NOP_JITERPRETER, "tier_nop_jiterpreter", 3, 0, 0, MintOpInt)
OPDEF(MINT_TIER_ENTER_JITERPRETER, "tier_enter_jiterpreter", 3, 0, 0, MintOpInt)
OPDEF(MINT_TIER_MONITOR_JITERPRETER, "tier_monitor_jiterpreter", 3, 0, 0, MintOpInt)
#endif // HOST_BROWSER

IROPDEF(MINT_NOP, "nop", 1, 0, 0, MintOpNoArgs)
Expand Down
4 changes: 4 additions & 0 deletions src/mono/mono/utils/options-def.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ DEFINE_INT(jiterpreter_minimum_trace_length, "jiterpreter-minimum-trace-length",
DEFINE_INT(jiterpreter_minimum_distance_between_traces, "jiterpreter-minimum-distance-between-traces", 4, "Don't insert entry points closer together than this")
// once a trace entry point is inserted, we only actually JIT code for it once it's been hit this many times
DEFINE_INT(jiterpreter_minimum_trace_hit_count, "jiterpreter-minimum-trace-hit-count", 5000, "JIT trace entry points once they are hit this many times")
// trace prepares turn into a monitor opcode and stay that way for this many hits before being converted to enter or nop
DEFINE_INT(jiterpreter_trace_monitoring_period, "jiterpreter-trace-monitoring-period", 3000, "Monitor jitted traces for this many calls to determine whether to keep them")
// traces that only offset ip by less than this on average will be rejected
DEFINE_INT(jiterpreter_trace_average_distance_threshold, "jiterpreter-trace-average-distance-threshold", 52, "Traces with an average distance less than this will be discarded")
// After a do_jit_call call site is hit this many times, we will queue it to be jitted
DEFINE_INT(jiterpreter_jit_call_trampoline_hit_count, "jiterpreter-jit-call-hit-count", 1000, "Queue specialized do_jit_call trampoline for JIT after this many hits")
// After a do_jit_call call site is hit this many times without being jitted, we will flush the JIT queue
Expand Down
2 changes: 2 additions & 0 deletions src/mono/wasm/runtime/cwraps.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ const fn_signatures: SigLine[] = [
[true, "mono_jiterp_debug_count", "number", []],
[true, "mono_jiterp_get_trace_hit_count", "number", ["number"]],
[true, "mono_jiterp_get_polling_required_address", "number", []],
[true, "mono_jiterp_get_rejected_trace_count", "number", []],
...legacy_interop_cwraps
];

Expand Down Expand Up @@ -236,6 +237,7 @@ export interface t_Cwraps {
mono_jiterp_get_trace_hit_count(traceIndex: number): number;
mono_jiterp_get_polling_required_address(): Int32Ptr;
mono_jiterp_write_number_unaligned(destination: VoidPtr, value: number, mode: number): void;
mono_jiterp_get_rejected_trace_count(): number;
}

const wrapped_c_functions: t_Cwraps = <any>{};
Expand Down
3 changes: 2 additions & 1 deletion src/mono/wasm/runtime/jiterpreter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,8 @@ export function jiterpreter_dump_stats (b?: boolean, concise?: boolean) {

console.log(`// jitted ${counters.bytesGenerated} bytes; ${counters.tracesCompiled} traces (${counters.traceCandidates} candidates, ${(counters.tracesCompiled / counters.traceCandidates * 100).toFixed(1)}%); ${counters.jitCallsCompiled} jit_calls (${(counters.directJitCallsCompiled / counters.jitCallsCompiled * 100).toFixed(1)}% direct); ${counters.entryWrappersCompiled} interp_entries`);
const backBranchHitRate = (counters.backBranchesEmitted / (counters.backBranchesEmitted + counters.backBranchesNotEmitted)) * 100;
console.log(`// time: ${elapsedTimes.generation | 0}ms generating, ${elapsedTimes.compilation | 0}ms compiling wasm. ${counters.nullChecksEliminated} null checks eliminated. ${counters.backBranchesEmitted} back-branches emitted (${counters.backBranchesNotEmitted} failed, ${backBranchHitRate.toFixed(1)}%)`);
const tracesRejected = cwraps.mono_jiterp_get_rejected_trace_count();
console.log(`// time: ${elapsedTimes.generation | 0}ms generating, ${elapsedTimes.compilation | 0}ms compiling wasm. ${counters.nullChecksEliminated} cknulls removed. ${counters.backBranchesEmitted} back-branches (${counters.backBranchesNotEmitted} failed, ${backBranchHitRate.toFixed(1)}%), ${tracesRejected} traces rejected`);
if (concise)
return;

Expand Down

0 comments on commit 1cfda5e

Please sign in to comment.