#!/usr/bin/python
# @lint-avoid-python-3-compatibility-imports
#
# alisyslatency  Summarize time spent in kernel context.
#                For Linux, uses BCC, eBPF.
#
# USAGE: alisyslatency [-h] [-T] [-N] [-c CPU] [-p PID] [interval] [outputs]
#
# Copyright (c) 2019 Jeffle Xu, Alibaba, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")

from __future__ import print_function
from bcc import BPF
from time import sleep, strftime
import argparse

# arguments
examples = """examples:
    ./alisyslatency            # sum time in kernel context per CPU
    ./alisyslatency -c 0       # show time in kernel context in CPU 0 only
    ./alisyslatency -p 25      # show time in kernel context in pid 25 only
    ./alisyslatency 1 10       # print 1 second summaries, 10 times
    ./alisyslatency -NT 1      # 1s summaries, nanoseconds, and timestamps
"""


def _nonnegative_int_text(value):
    """Validate a CPU/pid option and return it as canonical decimal text.

    The value is later interpolated into the BPF C source, so anything that
    is not a plain non-negative integer is rejected here with a normal
    argparse usage error instead of a compiler error dump.

    A string is returned on purpose (not an int): the rest of the script
    tests these options for truthiness, and "0" must stay truthy so that
    "-c 0" / "-p 0" keep working.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "invalid non-negative integer: %r" % value)
    if number < 0:
        raise argparse.ArgumentTypeError(
            "invalid non-negative integer: %r" % value)
    return str(number)


parser = argparse.ArgumentParser(
    description="Summarize time spent in kernel context.",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=examples)
parser.add_argument("-T", "--timestamp", action="store_true",
    help="include timestamp on output")
parser.add_argument("-N", "--nanoseconds", action="store_true",
    help="output in nanoseconds")

# -c and -p select different filters and cannot be combined.
thread_group = parser.add_mutually_exclusive_group()
thread_group.add_argument("-c", "--cpu", type=_nonnegative_int_text,
    help="output statistics on specific CPU only")
thread_group.add_argument("-p", "--pid", type=_nonnegative_int_text,
    help="output statistics of specific pid only")

# type=int makes argparse reject non-numeric values up front, before the
# BPF program is loaded; the large defaults mean "effectively forever".
parser.add_argument("interval", nargs="?", default=99999999, type=int,
    help="output interval, in seconds")
parser.add_argument("outputs", nargs="?", default=99999999, type=int,
    help="number of outputs")
# hidden option: print the generated BPF C program and exit
parser.add_argument("--ebpf", action="store_true",
    help=argparse.SUPPRESS)

args = parser.parse_args()
# number of summaries still to print
countdown = int(args.outputs)
# kernel timestamps are in nanoseconds; factor divides them into the
# unit named by label
if args.nanoseconds:
    factor = 1
    label = "nsecs"
else:
    factor = 1000
    label = "usecs"
# set to a true value to always print the generated BPF program
debug = 0

# define BPF program
#
# The C source below is a template: the tokens COND_FILTER and
# STATISTICS_PER_PROCESS are substituted by the Python code that follows
# before the text is handed to BCC.
#
# Maps:
#   start - per-CPU array with a single u64 slot holding the timestamp
#           written by the user_exit probe (0 means "no pending start").
#   res   - hash from a u32 key (low 32 bits of pid_tgid, or the CPU id,
#           depending on STATISTICS_PER_PROCESS) to a map_val_t holding the
#           event count, the summed time and the largest single interval,
#           all in nanoseconds as returned by bpf_ktime_get_ns().
#
# Probes:
#   context_tracking:user_exit  - stores the current time in start[0].
#   context_tracking:user_enter - computes the time elapsed since the stored
#           start, folds it into res[key], then resets start[0] to 0 so a
#           later user_enter without a matching user_exit is ignored.
#
# Both probes run COND_FILTER first and return early when it rejects the
# event.
#
# NOTE(review): going by the tracepoint names, user_exit/user_enter bracket
# the time spent outside user mode; these tracepoints presumably only fire
# when kernel context tracking is active -- confirm on the target kernel.
bpf_text = """
#include <uapi/linux/ptrace.h>

typedef struct map_val {
    u64 count;
    u64 time;
    u64 time_max;
} map_val_t;

BPF_PERCPU_ARRAY(start, u64, 1);
BPF_HASH(res, u32, map_val_t);


TRACEPOINT_PROBE(context_tracking, user_exit)
{
    COND_FILTER

    u32 idx = 0;
    u64 ts = bpf_ktime_get_ns();
    start.update(&idx, &ts);

    return 0;
}

TRACEPOINT_PROBE(context_tracking, user_enter)
{
    COND_FILTER

    u64 *tsp, delta;
    u32 idx = 0;
    // fetch timestamp and calculate delta
    tsp = start.lookup(&idx);
    if (!tsp || *tsp == 0) {
        return 0;   // missed start
    }

#if STATISTICS_PER_PROCESS
    u32 key = bpf_get_current_pid_tgid();
#else
    u32 key = bpf_get_smp_processor_id();
#endif

    delta = bpf_ktime_get_ns() - *tsp;
    map_val_t *valp, val;

    valp = res.lookup(&key);
    if (valp) {
        valp->count += 1;
        valp->time += delta;
        if (valp->time_max < delta) {valp->time_max = delta;}
    }
    else {
        val.count = 1;
        val.time = val.time_max = delta;
        res.update(&key, &val);
    }

    u64 zero = 0;
    start.update(&idx, &zero);

    return 0;
}
"""

# Pick the C boolean expression that decides whether an event is recorded.
# With no filter option every event passes ("1").
if args.pid:
    cond_exp = "(u32)bpf_get_current_pid_tgid() == %s" % args.pid
elif args.cpu:
    cond_exp = "bpf_get_smp_processor_id() == %s" % args.cpu
else:
    cond_exp = "1"

# Early-return guard spliced into both probes in place of COND_FILTER.
cond_filter = "if (!(%s)) {return 0;}" % cond_exp

# When a single CPU is selected, results are keyed by the low 32 bits of
# pid_tgid instead of by CPU id (see the #if in the BPF program).
per_process = "1" if args.cpu else "0"

bpf_text = (bpf_text
            .replace('COND_FILTER', cond_filter)
            .replace('STATISTICS_PER_PROCESS', per_process))


# show the final eBPF C code for debugging; --ebpf prints it and stops
if args.ebpf or debug:
    print(bpf_text)
if args.ebpf:
    exit()


# load BPF program
b = BPF(text=bpf_text)
print("Tracing time in kernel context... Hit Ctrl-C to end.")
tab = b.get_table("res")

# Rows are keyed by the low 32 bits of pid_tgid when a single CPU is
# selected (-c), and by CPU id otherwise; the first column is labelled to
# match.
key_label = "PID" if args.cpu else "CPU"


# output: one summary per interval until the requested number of outputs
# has been printed
while True:
    try:
        sleep(int(args.interval))
    except KeyboardInterrupt:
        # Ctrl-C: print one final summary, then exit below
        countdown = 1
    print()
    if args.timestamp:
        print("%-8s\n" % strftime("%H:%M:%S"), end="")

    # print header
    print("%10s %10s %10s %10s" %
            (key_label, "Count", "TOTAL_" + label, "MAX_" + label))

    # Table keys are ctypes integer objects: they neither order on
    # Python 3 nor format with %d, so sort and print their .value.
    for k, v in sorted(tab.items(), key=lambda item: item[0].value):
        # floor division keeps the arithmetic in integers on Python 3
        print("%10d %10d %10d %10d" %
            (k.value, v.count, v.time // factor, v.time_max // factor))

    # start the next interval from zero
    tab.clear()

    countdown -= 1
    if countdown == 0:
        exit()
