#!/usr/bin/python
# @lint-avoid-python-3-compatibility-imports
#
# alibiolatency    Summarize block device I/O latency.
#       For Linux, uses BCC, eBPF.
#
# USAGE: alibiolatency [-h] [-i interval] [-d device] [-t avg_req_time] [-T req_time] [-r]
#
# Copyright (c) 2019-2021 Alibaba Group.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 2019/07/01 Xiaoguang Wang Created this.

from __future__ import print_function
from bcc import BPF
import ctypes as ct
import argparse
import time

# arguments
examples = """examples:
    ./alibiolatency          # summarize block I/O latency(default display interval is 2s)
    ./alibiolatency -d sda3  # inspect specified device /dev/sda3
    ./alibiolatency -i 2     # specify display interval, 2s
    ./alibiolatency -t 10    # display only when average request process time is greater than 10ms
    ./alibiolatency -T 20    # dump request life cycle when single request process time is greater than 20ms
    ./alibiolatency -r       # dump every io request life cycle
"""
parser = argparse.ArgumentParser(
    description="Summarize block device I/O latency",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=examples)
parser.add_argument("-d", "--device", help="inspect specified device")
# Numeric options are converted by argparse itself (type=int): a missing or
# non-numeric value now produces a usage error instead of a TypeError/ValueError
# traceback from int() further down.
parser.add_argument("-i", "--dis_interval", type=int, default=2,
    help="specify display interval, in seconds (default: 2)")
parser.add_argument("-t", "--avg_threshold_time", type=int, default=0,
    help="display only when average request process time is greater than this value, in ms")
parser.add_argument("-T", "--threshold_time", type=int, default=999999999,
    help="dump request life cycle when single request process time is greater than this value, in ms")
parser.add_argument("-r", "--dump_raw", action="store_true", help="dump every io request life cycle")
parser.add_argument("--ebpf", action="store_true", help=argparse.SUPPRESS)
args = parser.parse_args()
threshold_time = int(args.threshold_time)          # ms, per-request latency threshold (-T)
avg_threshold_time = int(args.avg_threshold_time)  # ms, per-interval average latency threshold (-t)
dis_interval = int(args.dis_interval)              # seconds between summaries (-i)
debug = 0  # set non-zero to print the generated BPF program before loading it

# define BPF program
#
# Kernel-side half of the tool.  One struct req_stat per in-flight request is
# kept in the io_requests hash, keyed by (dev, sector), and stamped as the
# request passes the block tracepoints (all stamps are bpf_ktime_get_ns() / 1000,
# i.e. microseconds):
#   block_bio_queue   -> creates the entry; queue_us, rwbs
#   block_getrq       -> getrq_us
#   block_rq_insert   -> insert_us
#   block_rq_issue    -> issue_us, num_sectors
#   block_rq_complete -> completion_us; a copy is pushed to user space through
#                        the req_stat perf buffer and the entry is deleted
# Merges: a back-merged bio's entry is deleted; on a front merge the timestamps
# of the entry at (sector + nr_sector) are copied onto the entry at sector and
# the former is deleted; trace_elv_merge_requests (attached as a kprobe on
# elv_merge_requests further down) deletes the entry of the `next` request.
# FILTER_DEV is a placeholder replaced textually before loading ("dev != <devt>"
# with -d, otherwise "0"); events for which it evaluates true are ignored.
# NOTE(review): the kprobe reads next->rq_disk and next->__sector; confirm those
# struct request fields exist on the target kernel.
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <linux/blkdev.h>

struct req_stat {
	dev_t dev;
	u64 queue_us;
	u64 getrq_us;
	u64 insert_us;
	u64 issue_us;
	u64 completion_us;
	u64 sector;
	u32 num_sectors;
	char rwbs[8];
};

typedef struct disk_key {
	dev_t dev;
	u64 sector;
} disk_key_t;

BPF_HASH(io_requests, disk_key_t, struct req_stat);
BPF_PERF_OUTPUT(req_stat);


TRACEPOINT_PROBE(block, block_bio_queue)
{
	disk_key_t key = {};
	struct req_stat s = {};
	dev_t dev = args->dev;

	if (FILTER_DEV)
		return 0;

	key.dev = args->dev;
	key.sector = args->sector;
	s.dev = key.dev;
	s.queue_us = bpf_ktime_get_ns() / 1000;
	bpf_probe_read(s.rwbs, 8, args->rwbs);
	io_requests.update(&key, &s);
	return 0;
}

TRACEPOINT_PROBE(block, block_getrq)
{
	struct req_stat *r;
	disk_key_t key = {};
	dev_t dev = args->dev;

	if (FILTER_DEV)
		return 0;

	key.dev = args->dev;
	key.sector = args->sector;
	r = io_requests.lookup(&key);
	if (r == NULL)
		return 0;

	r->getrq_us = bpf_ktime_get_ns() / 1000;
	return 0;
}

TRACEPOINT_PROBE(block, block_bio_backmerge)
{
	struct req_stat *r;
	disk_key_t key = {};
	dev_t dev = args->dev;

	if (FILTER_DEV)
		return 0;

	key.dev = args->dev;
	key.sector = args->sector;
	io_requests.delete(&key);

	return 0;
}

TRACEPOINT_PROBE(block, block_bio_frontmerge)
{
	struct req_stat *r;
	struct req_stat *r2;
	disk_key_t key = {};
	disk_key_t key2 = {};
	dev_t dev = args->dev;

	if (FILTER_DEV)
		return 0;

	key.dev = args->dev;
	key.sector = args->sector + args->nr_sector;
	r = io_requests.lookup(&key);
	if (r == NULL)
		return 0;

	key2.dev = args->dev;
	key2.sector = args->sector;
	r2 = io_requests.lookup(&key2);
	if (r2 == NULL) {
		io_requests.delete(&key);
		return 0;
	}

	// keep old value.
	r2->queue_us = r->queue_us;
	r2->getrq_us = r->getrq_us;
	r2->insert_us = r->insert_us;
	r2->issue_us = r->issue_us;
	io_requests.delete(&key);

	return 0;
}

TRACEPOINT_PROBE(block, block_rq_insert)
{
	struct req_stat *r;
	disk_key_t key = {};
	dev_t dev = args->dev;

	if (FILTER_DEV)
		return 0;

	key.dev = args->dev;
	key.sector = args->sector;
	r = io_requests.lookup(&key);
	if (r == NULL)
		return 0;

	r->insert_us = bpf_ktime_get_ns() / 1000;
	return 0;
}

TRACEPOINT_PROBE(block, block_rq_issue)
{
	struct req_stat *r;
	disk_key_t key = {};
	dev_t dev = args->dev;

	if (FILTER_DEV)
		return 0;

	key.dev = args->dev;
	key.sector = args->sector;
	r = io_requests.lookup(&key);
	if (r == NULL)
		return 0;

	r->issue_us = bpf_ktime_get_ns() / 1000;
	r->num_sectors = args->nr_sector;
	return 0;
}

TRACEPOINT_PROBE(block, block_rq_complete)
{
	struct req_stat *r;
	struct req_stat val = {};
	disk_key_t key = {};
	dev_t dev = args->dev;

	if (FILTER_DEV)
		return 0;

	key.dev = args->dev;
	key.sector = args->sector;
	r = io_requests.lookup(&key);
	if (r == NULL)
		return 0;

	/* could be happen.
	if (r->num_sectors != args->nr_sector)
		bpf_trace_printk("here");
	*/
	r->completion_us = bpf_ktime_get_ns() / 1000;
	/* FIXME: should we use r? */
	//val = *r;
	val.dev = r->dev;
	val.queue_us = r->queue_us;
	val.getrq_us = r->getrq_us;
	val.insert_us = r->insert_us;
	val.issue_us = r->issue_us;
	val.completion_us = r->completion_us;
	val.sector = args->sector;
	val.num_sectors = args->nr_sector;
	memcpy(val.rwbs, r->rwbs, 8);

	io_requests.delete(&key);

	req_stat.perf_submit(args, &val, sizeof(struct req_stat));
	return 0;
}

int trace_elv_merge_requests(struct pt_regs *ctx, struct request_queue *q,
		struct request *rq, struct request *next)
{
	disk_key_t key = {};
	dev_t dev = (next->rq_disk)->part0.__dev.devt;

	if (FILTER_DEV)
		return 0;

	key.dev = dev;
	key.sector = next->__sector;
	io_requests.delete(&key);

	return 0;
}

"""

# Partition table: dev_t number (major << 20 | minor) -> partition name, e.g. "sda3".
dev_name = {}


def init_dev_name(path="/proc/partitions"):
    """Fill the global ``dev_name`` map from a /proc/partitions-style file.

    Data lines have the form ``major minor #blocks name``.  Keys are built as
    ``(major << 20) + minor`` so they can be compared with the ``dev`` values
    the BPF program reports.  Blank lines, the column header and any line that
    does not start with two numeric fields followed by a name are skipped.

    :param path: file to parse; defaults to /proc/partitions (overridable,
        e.g. for testing).
    """
    global dev_name

    with open(path) as partitions:
        for line in partitions:
            fields = line.split()
            # Header ("major minor #blocks name"), blank and malformed lines.
            if len(fields) < 4 or not (fields[0].isdigit() and fields[1].isdigit()):
                continue
            dev_num = (int(fields[0]) << 20) + int(fields[1])
            dev_name[dev_num] = fields[3]

init_dev_name()
# Reverse lookup (partition name -> dev_t number), used to resolve the -d argument.
name2devid = dict((name, devid) for (devid, name) in dev_name.items())

# Substitute the FILTER_DEV placeholder in the BPF source.  Events for which the
# expression is true are dropped, so "0" means "trace every device".
if args.device:
    # /proc/partitions lists bare names; also accept the "/dev/sda3" spelling.
    dev_arg = args.device
    if dev_arg.startswith("/dev/"):
        dev_arg = dev_arg[len("/dev/"):]
    if dev_arg not in name2devid:
        # parser.error() prints usage and exits instead of a bare KeyError traceback.
        parser.error("unknown device '%s' (not listed in /proc/partitions)" % args.device)
    bpf_text = bpf_text.replace('FILTER_DEV', 'dev != %u' % name2devid[dev_arg])
else:
    bpf_text = bpf_text.replace('FILTER_DEV', '0')

# 1 when every request's life cycle should be printed (-r), else 0.
dump_raw = 1 if args.dump_raw else 0

# code substitutions
# --ebpf: show the generated BPF C source and quit without loading it.
# debug: show the source, then carry on.
if args.ebpf:
    print(bpf_text)
    exit()
if debug:
    print(bpf_text)

# load BPF program; trace_elv_merge_requests has no tracepoint and is hooked
# as a kprobe on elv_merge_requests.
b = BPF(text=bpf_text)
b.attach_kprobe(event="elv_merge_requests",
                fn_name="trace_elv_merge_requests")

print("Tracing block device I/O... Hit Ctrl-C to end.")

class Data(ct.Structure):
    """User-space mirror of ``struct req_stat`` from the BPF program.

    Field order and widths must match the C struct exactly so that records
    read from the req_stat perf buffer decode correctly.  Timestamps are in
    microseconds; num_sectors counts 512-byte sectors.
    """
    _fields_ = [
        ("dev", ct.c_uint),             # dev_t
        ("queue_us", ct.c_ulonglong),
        ("getrq_us", ct.c_ulonglong),
        ("insert_us", ct.c_ulonglong),
        ("issue_us", ct.c_ulonglong),
        ("completion_us", ct.c_ulonglong),
        ("sector", ct.c_ulonglong),
        ("num_sectors", ct.c_uint),     # u32
        ("rwbs", ct.c_char * 8),        # char rwbs[8]
    ]

# Per-device req_stats for the current display interval, keyed by dev_t number.
io_stats = dict()
# completion_us that opened the current display interval (0 = not started yet).
last_dis_time = 0

class req_stats:
    """Latency statistics of one block device for a single display interval.

    Every per-type list is indexed by io_type: 0 = READ, 1 = WRITE,
    2 = anything else (printed as DISCARD).  Latencies are accumulated in
    microseconds and printed in milliseconds; sizes are in 512-byte sectors.
    """

    # Initial value of the *_min lists: larger than any real latency, so the
    # first sample always replaces it (a finite sentinel would hide very slow I/O).
    _NO_MIN = float("inf")

    def __init__(self, dev):
        self.dev = dev
        self.sectors = [0, 0, 0]
        self.num_reqs = [0, 0, 0]
        # total (queue -> completion) latency
        self.lat = [0, 0, 0]
        self.lat_max = [0, 0, 0]
        self.lat_min = [self._NO_MIN] * 3
        # insert -> issue latency
        self.i2d_lat = [0, 0, 0]
        self.i2d_lat_max = [0, 0, 0]
        self.i2d_lat_min = [self._NO_MIN] * 3
        # issue -> completion latency
        self.d2c_lat = [0, 0, 0]
        self.d2c_lat_max = [0, 0, 0]
        self.d2c_lat_min = [self._NO_MIN] * 3

    def update_req_stats(self, io_type, num_sectors, lat, i2d_lat, d2c_lat):
        """Fold one completed request into the sums and min/max of ``io_type``.

        :param io_type: 0 (read), 1 (write) or 2 (other).
        :param num_sectors: request size in 512-byte sectors.
        :param lat: total latency in microseconds.
        :param i2d_lat: insert -> issue latency in microseconds.
        :param d2c_lat: issue -> completion latency in microseconds.
        """
        self.num_reqs[io_type] += 1
        self.sectors[io_type] += num_sectors
        self.lat[io_type] += lat
        self.i2d_lat[io_type] += i2d_lat
        self.d2c_lat[io_type] += d2c_lat

        self.lat_max[io_type] = max(self.lat_max[io_type], lat)
        self.lat_min[io_type] = min(self.lat_min[io_type], lat)
        self.i2d_lat_max[io_type] = max(self.i2d_lat_max[io_type], i2d_lat)
        self.i2d_lat_min[io_type] = min(self.i2d_lat_min[io_type], i2d_lat)
        self.d2c_lat_max[io_type] = max(self.d2c_lat_max[io_type], d2c_lat)
        self.d2c_lat_min[io_type] = min(self.d2c_lat_min[io_type], d2c_lat)

    def display_req_stats(self, time_interval):
        """Print this device's statistics for the interval that just ended.

        :param time_interval: interval length in microseconds, used to turn
            counts into IOPS and MB/s.

        Types with no requests, or whose average latency (ms) is below the
        module-level ``avg_threshold_time`` (-t), are not printed.
        """
        ti = time_interval / 1000.0 / 1000.0  # microseconds -> seconds

        print("dev: %s" % dev_name[self.dev])
        for io_type, io_type_str in enumerate(("READ", "WRITE", "DISCARD")):
            total_reqs = self.num_reqs[io_type]
            if total_reqs == 0:
                continue

            iops = int(total_reqs / ti)
            avg_lat = self.lat[io_type] / 1000.0 / total_reqs
            if avg_lat < avg_threshold_time:
                continue

            speed = self.sectors[io_type] * 512 / 1024 / 1024 / ti
            avg_req_sz = self.sectors[io_type] * 512 / 1024 / total_reqs

            i2d_avg = self.i2d_lat[io_type] / 1000.0 / total_reqs
            d2c_avg = self.d2c_lat[io_type] / 1000.0 / total_reqs
            # Latencies are whole microseconds, so avg_lat can be exactly 0
            # (and passes the default -t 0); avoid dividing by zero then.
            if avg_lat > 0:
                i2d_percent = int(i2d_avg / avg_lat * 100)
                d2c_percent = int(d2c_avg / avg_lat * 100)
            else:
                i2d_percent = d2c_percent = 0

            print("  %s iops:%d avg:%.2f min:%.2f max:%.2f speed:%.2fMB/s avgrq-sz:%dKB" %
                  (io_type_str, iops, avg_lat, self.lat_min[io_type] / 1000.0,
                   self.lat_max[io_type] / 1000.0, speed, avg_req_sz))

            print("  [i2d_avg: %.2f percent:%d min:%.2f max:%.2f" % (i2d_avg,
                  i2d_percent, self.i2d_lat_min[io_type] / 1000.0,
                  self.i2d_lat_max[io_type] / 1000.0))

            print("  [d2c_avg: %.2f percent:%d min:%.2f max:%.2f" % (d2c_avg,
                  d2c_percent, self.d2c_lat_min[io_type] / 1000.0,
                  self.d2c_lat_max[io_type] / 1000.0))


def get_io_type(rwbs):
    """Classify a request from its rwbs string (see kernel blk_fill_rwbs()).

    :param rwbs: the event's rwbs field, as ``bytes`` (what a ctypes
        ``c_char`` array yields on Python 3) or ``str``.
    :returns: 0 for reads ("R..." or "FR..."), 1 for writes ("W..." or
        "FW..."), 2 for everything else, including an empty string.
    """
    # On Python 3, bytes[0] is an int and never equals 'R'; compare as text.
    if isinstance(rwbs, bytes):
        rwbs = rwbs.decode("ascii", "replace")
    # Slices instead of rwbs[0] so an empty string cannot raise IndexError.
    if rwbs[0:1] == "R" or rwbs[0:2] == "FR":
        return 0
    if rwbs[0:1] == "W" or rwbs[0:2] == "FW":
        return 1
    return 2

def is_exception_req(event):
    """Print a completed request whose getrq_us timestamp was never set.

    Such requests are treated as abnormal and reported with all their
    timestamps; nothing is returned.  Why getrq_us can remain 0 has not been
    determined yet.
    NOTE(review): the value labelled "bytes" is num_sectors (512-byte sectors).
    """
    if event.getrq_us != 0:
        return
    rwbs = event.rwbs
    # ctypes c_char arrays are bytes on Python 3; print "WS", not "b'WS'".
    if isinstance(rwbs, bytes):
        rwbs = rwbs.decode("ascii", "replace")
    print("dev:%s act:%s sector:%lu bytes:%u exception req Q:%lu G:%lu I:%lu D:%lu C:%lu" % (
          dev_name[event.dev], rwbs, event.sector, event.num_sectors, event.queue_us,
          event.getrq_us, event.insert_us, event.issue_us, event.completion_us))

def dump_raw_trace_event(dev, act, ts, event):
    """Print one life-cycle stage of a request.

    :param dev: device name to print.
    :param act: stage letter (Q, G, I, D or C).
    :param ts: stage timestamp in microseconds; printed in milliseconds.
    :param event: the request record (sector, num_sectors and rwbs are printed).
    NOTE(review): the value labelled "bytes" is num_sectors (512-byte sectors).
    """
    rwbs = event.rwbs
    # ctypes c_char arrays are bytes on Python 3; print "R", not "b'R'".
    if isinstance(rwbs, bytes):
        rwbs = rwbs.decode("ascii", "replace")
    # Integer division: same value %lu printed from the float before, made explicit.
    print("%s %s %lu sector:%lu bytes:%u act:%s" % (dev, act, ts // 1000,
          event.sector, event.num_sectors, rwbs))

def dump_timeout_req(event):
    """Print a request whose total latency exceeded the -T threshold.

    Output: the current local time, a summary line (request time in ms),
    then one dump_raw_trace_event() line per stage: Q (queue_us),
    G (getrq_us), I (insert_us), D (issue_us) and C (completion_us).
    """
    name = dev_name[event.dev]
    print("%s" % time.asctime(time.localtime(time.time())))
    print("timeouted request sector:%lu bytes:%u req time:%lu" % (
        event.sector, event.num_sectors,
        (event.completion_us - event.queue_us) / 1000))
    stages = (("Q", event.queue_us), ("G", event.getrq_us),
              ("I", event.insert_us), ("D", event.issue_us),
              ("C", event.completion_us))
    for act, stamp in stages:
        dump_raw_trace_event(name, act, stamp, event)

def dump_req_life_cycle(event):
    """Print the full life cycle of a request (used for -r).

    Output: the current local time, a summary line (request time in ms),
    then one dump_raw_trace_event() line for each stage Q, G, I, D, C, taken
    from queue_us, getrq_us, insert_us, issue_us and completion_us.
    """
    dev = dev_name[event.dev]
    now = time.asctime(time.localtime(time.time()))
    req_time = (event.completion_us - event.queue_us) / 1000

    print("%s" % now)
    print("raw request event sector:%lu bytes:%u req time:%lu"
          % (event.sector, event.num_sectors, req_time))
    for act, field in (("Q", "queue_us"), ("G", "getrq_us"), ("I", "insert_us"),
                       ("D", "issue_us"), ("C", "completion_us")):
        dump_raw_trace_event(dev, act, getattr(event, field), event)

# process event
def print_event(cpu, data, size):
    """Perf-buffer callback: account one completed request, print periodically.

    :param cpu: CPU the record came from (unused).
    :param data: pointer to a ``struct req_stat`` record, decoded via ``Data``.
    :param size: record size in bytes (unused).

    The request is added to its device's entry in ``io_stats``.  Once more
    than ``dis_interval`` seconds of completion time have passed since
    ``last_dis_time``, every device's stats are printed and ``io_stats`` is
    cleared.  Requests slower than ``threshold_time`` ms are dumped in full.
    """
    global last_dis_time

    event = ct.cast(data, ct.POINTER(Data)).contents

    is_exception_req(event)

    if dump_raw != 0:
        dump_req_life_cycle(event)

    # Total latency, queue -> completion, in microseconds.
    lat = event.completion_us - event.queue_us

    # Device latency: from issue if seen, else from the latest earlier stage
    # that was seen, else the whole latency.
    if event.issue_us != 0:
        d2c_lat = event.completion_us - event.issue_us
    elif event.insert_us != 0:
        d2c_lat = event.completion_us - event.insert_us
    elif event.getrq_us != 0:
        d2c_lat = event.completion_us - event.getrq_us
    else:
        d2c_lat = lat

    # Note: this could happen...
    # In kernel 4.9, for example, scsi_request_fn calls blk_peek_request()
    # to fetch a request, but if later scsi driver could not process this
    # request(see not_ready label in scsi_request_fn), it will call
    # blk_requeue_request() to put back this request to block elevator, now
    # insert_us will be greater than issue_us.
    #
    # Later scsi_request_fn will pick this request again, but still it could
    # not handle this request, see codes in scsi_request_fn
    #	if (!scsi_dev_queue_ready(q, sdev))
    #		break;
    # then this request will still be in block elevator, but with RQF_STARTED
    # marked.
    #
    # Finally scsi_request_fn calls blk_peek_request() to process this request,
    # because RQF_STARTED is set, and this time blk_peek_request won't call
    # trace_block_rq_issue(), then issue_us also won't be updated.
    # Now insert_us is greater than issue_us.
    if event.issue_us < event.insert_us:
        rwbs = event.rwbs
        if isinstance(rwbs, bytes):
            rwbs = rwbs.decode("ascii", "replace")
        # dev_name values are strings, hence %s: %u raised TypeError here and
        # the rest of this event's accounting was skipped.
        print("lege %s %s sector:%lu bytes:%u %lu %lu" % (dev_name[event.dev],
              rwbs, event.sector, event.num_sectors, event.insert_us,
              event.issue_us))

    # Insert -> issue latency; only meaningful when issue follows insert.
    if event.insert_us != 0 and event.issue_us > event.insert_us:
        i2d_lat = event.issue_us - event.insert_us
    else:
        i2d_lat = 0
    io_type = get_io_type(event.rwbs)

    if last_dis_time == 0:
        last_dis_time = event.completion_us

    # dict.has_key() does not exist on Python 3; `in` works on both.
    if event.dev not in io_stats:
        io_stats[event.dev] = req_stats(event.dev)

    if lat > threshold_time * 1000:
        dump_timeout_req(event)

    io_stats[event.dev].update_req_stats(io_type, event.num_sectors, lat, i2d_lat, d2c_lat)
    if event.completion_us > last_dis_time and (event.completion_us - last_dis_time) > dis_interval * 1000 * 1000:
        time_interval = event.completion_us - last_dis_time
        localtime = time.asctime(time.localtime(time.time()))

        print("%s" % localtime)
        # sorted(): dict.keys() is a view on Python 3 and has no .sort().
        for dev in sorted(io_stats):
            io_stats[dev].display_req_stats(time_interval)
        last_dis_time = event.completion_us

        print("\n")
        io_stats.clear()


# Feed every completed-request record to print_event(); Ctrl-C stops tracing.
b["req_stat"].open_perf_buffer(print_event, page_cnt=64)
try:
    while True:
        b.perf_buffer_poll()
except KeyboardInterrupt:
    exit()
