#!/usr/bin/env python3
'''
Tool to read callgrind output and return the net cost of the first call
to a given function
'''

import argparse
import json
import os
import re
import tempfile
from contextlib import suppress
from multiprocessing import Pool
from pathlib import Path
from subprocess import call

# Command-line interface. Either a JSON control file is given as the
# positional argument, or a single test is described with --font/--text
# (and optionally --line).
p = argparse.ArgumentParser()
# The control file is JSON, so decode it as UTF-8 regardless of the locale
# (the text files named in it are read as UTF-8 as well).
p.add_argument('file', nargs='?',
               type=argparse.FileType('r', encoding='utf-8'),
               help="control file")
p.add_argument('-g', '--graph', action='store_true',
               help='Display a graph instead of csv results')
p.add_argument('-s', '--single', action='store_true',
               help='Do single processing')
p.add_argument('-f', '--font',
               help='font file to use for a single test')
p.add_argument('-t', '--text',
               help='text file to use for a single test')
p.add_argument('-l', '--line', type=int, default=0,
               help='line in text file for single test')
p.add_argument('-n', '--normalize', action="store_true",
               help="Normalize results")
p.add_argument('-k', '--keep',
               help='Store intermediate callgrind results in '
                    'given directory')
p.add_argument('-u', '--usp', type=Path,
               help="Include usp results")
p.add_argument('-v', '--verbose', action='store_true',
               help='Print the commands executed')

args = p.parse_args()

# Without a control file the single test row is built from --font and
# --text; if either is missing the run would only fail later with a
# TypeError, so reject the invocation here with a proper usage message.
if args.file is None and not (args.font and args.text):
    p.error('either a control file or both --font and --text are required')


def parsecg(fname):
    '''Collect the cost recorded for each called function in a callgrind file.

    fname is a path object providing .open() (for example a pathlib.Path).

    Returns a dict mapping the callee name of every "cfn=" record to the
    second field of the cost line that follows its "calls=" line. The cost
    is returned as the string read from the file. When a callee occurs in
    several records, the last one read wins.
    NOTE(review): the module docstring speaks of the *first* call - confirm
    which of the two is intended.

    Records whose cost line is missing or has fewer than two fields (for
    example in a truncated file) are skipped instead of raising IndexError.
    '''
    idcache = {}
    # Names may be written as "(id) name" to define an id, or as "(id)" alone
    # to refer back to a name defined earlier.
    namere = re.compile(r'(?:\((\d+)\))?\s*(.*)$')

    def parseline(line):
        '''Return (id, name) for a "key=value" line, resolving cached ids.'''
        (_, s) = line.split('=', 1)
        m = namere.match(s.lstrip())
        ident = m.group(1)
        name = m.group(2)
        if not ident:
            return (None, name)
        if name:
            idcache[ident] = name
        return (ident, idcache.get(ident, name))

    fncosts = {}
    with fname.open() as fh:
        incfn = False       # True between a "cfn=" line and its cost line
        currfn = ''
        for line in fh:
            line = line.strip()
            if incfn:
                if line.startswith('calls'):
                    continue
                # split() rather than split(' ') so runs of whitespace do
                # not yield an empty cost field
                fields = line.split()
                if len(fields) > 1:
                    fncosts[currfn] = fields[1]
                incfn = False
            elif line.startswith('cfn='):
                (_, currfn) = parseline(line)
                incfn = True
            elif line.startswith('fn='):
                # parsed only for its side effect of filling the id cache
                parseline(line)
    return fncosts


# Shaper configurations, keyed by a short mode name. Each value is a tuple:
#   [0] shaper name handed to hb-shape via --shapers
#   [1] result key -> function name whose cost is looked up in the
#       callgrind output
#   [2] optional: overrides for how the measured command is launched
#       (extra valgrind parameters, working directory, executable, env)
#   [3] optional: derived result key -> callable computing it from the
#       results collected so far
modes = {
    'ot': (
        'ot',
        {'ot': 'hb_shape_plan_execute',
         'ot_GSUB': 'hb_ot_map_t::substitute(hb_ot_shape_plan_t const*, hb_font_t*, hb_buffer_t*) const',
         'ot_GPOS': 'hb_ot_map_t::position(hb_ot_shape_plan_t const*, hb_font_t*, hb_buffer_t*) const'}),
    'gr': (
        'graphite2',
        {'hbgr': 'hb_shape_plan_execute',
         'gr': 'gr_make_seg',
         'grcoll': 'graphite2::Pass::collisionShift(graphite2::Segment*, int, graphite2::json*) const',
         'grkern': 'graphite2::Pass::collisionKern(graphite2::Segment*, int, graphite2::json*) const'}),
    'usp': (
        'uniscribe',
        # NOTE(review): the 0x... entries look like raw addresses rather
        # than symbol names - presumably tied to one specific usp10 build;
        # confirm before relying on them.
        {'usp':      '0x61829dc6',
         'abc':      'GetCharABCWidthsI',
         'usp_GSUB': '0x6f90e350',
         'usp_GPOS': '0x6f90e660'},
        # Path-derived entries are stored as plain strings so the command
        # can be joined for printing and handed to subprocess unchanged;
        # they stay None when --usp was not given (the mode is skipped then).
        {'valgrind_parms': ['--trace-children=yes', '--combine-dumps=yes'],
         'dir': args.usp,
         'exe': ['wine',
                 args.usp and str(args.usp.joinpath('util', 'hb-shape.exe'))],
         'env': {
            'WINEPATH':         args.usp and str(args.usp),
            'WINEDLLPATH':      args.usp and str(args.usp),
            'WINEDLLOVERRIDES': '"usp10=n"',
            'WINEDEBUG':        'trace+dll'}},
        # Plain subtraction: the inputs are already numbers, and wrapping
        # them in int() would truncate values that have been normalised to
        # per-character floats.
        {'uspnorm': lambda x: x['usp'] - x['abc']}
    )
}


def getval(dat, key, defkey=None):
    '''Pick the per-mode value out of a setting that may be a plain value.

    dat is either a plain value, returned unchanged, or a mapping (anything
    with a .keys attribute) from mode name to value.

    For a mapping, dat[key] is returned when key is present. Otherwise the
    result is dat[defkey] when defkey is given (raising KeyError if that is
    missing too), or None when no defkey is given.

    The fallback is only looked up when it is actually needed, so a mapping
    that has key but lacks defkey no longer raises KeyError.
    '''
    if not hasattr(dat, 'keys'):
        return dat
    if key in dat:
        return dat[key]
    if defkey is not None:
        return dat[defkey]
    return None


def dorow(x):
    '''Measure one test row under every shaper listed in modes.

    x is an (index, row) pair as produced by enumerate() over the control
    data. row is a dict with at least 'font', 'textfile' and 'textline';
    'rtl', 'cmdext' and 'ignore' are optional. 'font' and 'cmdext' may be a
    plain value or a per-mode mapping (see getval).

    For each mode the selected text line is written to a scratch file,
    hb-shape is run on it under valgrind's callgrind tool, and the costs of
    the functions named in the mode are read back with parsecg.

    Returns a copy of row extended with 'char' (length of the stripped text
    line) and one entry per result key. With --normalize the costs are
    divided by 'char'; a zero-length line is left undivided. A run that
    leaves no callgrind output is reported on stderr and counted as zero.
    '''
    import sys  # local import: only needed for the warning below

    (c, r) = x
    results = r.copy()

    # The sample text is the same for every shaper, so read it only once.
    with open(r['textfile'], encoding='utf_8') as ft:
        t = ft.readlines()[int(r['textline'])].strip()
    results['char'] = len(t)

    for k, job in modes.items():
        if k == 'usp' and not args.usp:
            # Mode not requested: report zero in each of its columns,
            # including the derived ones.
            skipped = list(job[1])
            if len(job) > 3:
                skipped.extend(job[3])
            for key in skipped:
                results[key] = 0
            continue
        resf = tmpd.joinpath('callgrind_'+k+'_'+str(c)+'.out').resolve()
        tfile = tmpd.joinpath('t'+k+str(c)+'.out').resolve()
        # Written with the same encoding the text was read with, so the
        # result does not depend on the locale.
        with tfile.open('w', encoding='utf_8') as ft:
            ft.write(t+"\n")
        font = getval(r['font'], k, defkey='ot')
        cmd = ['valgrind', '-q',
                           '--tool=callgrind',
                           '--callgrind-out-file='+str(resf),
                           'hb-shape',
                           '--font-file='+font,
                           '--shapers='+job[0],
                           '--text-file='+str(tfile)]
        cwd = None
        env = os.environ.copy()
        if len(job) > 2:
            # Splice the mode's own valgrind parameters and executable in
            # place of the plain 'hb-shape' entry (cmd[4]).
            cmd = cmd[:4] + job[2]['valgrind_parms'] + job[2]['exe'] + cmd[5:]
            cwd = job[2]['dir']
            env.update((name, str(value))
                       for name, value in job[2]['env'].items())
        if r.get('rtl'):
            cmd.append('--direction=rtl')
        if 'cmdext' in r:
            f = getval(r['cmdext'], k)
            if f is not None:
                cmd.append(f)
        # Some entries may be path objects; make everything a string so the
        # command can be printed and executed alike.
        cmd = [str(part) for part in cmd]
        if args.verbose:
            print(" ".join(cmd))

        # Best effort: a failing run must not abort the remaining rows.
        # The environment prepared above is passed on explicitly; it was
        # previously built but never handed to the child process.
        with suppress(Exception):
            call(cmd, stdout=dumpfh, cwd=cwd, env=env)

        try:
            cfns = parsecg(resf)
        except FileNotFoundError:
            print('warning: no callgrind output for mode {} of row {}; '
                  'reporting zero'.format(k, c), file=sys.stderr)
            cfns = {}
        ignored = k in r.get('ignore', ())
        for rk, rf in job[1].items():
            results[rk] = 0 if ignored else int(cfns.get(rf, 0))
            # an empty text line has no per-character cost to compute
            if args.normalize and results['char']:
                results[rk] /= results['char']
        if len(job) > 3:
            for rk, rl in job[3].items():
                results[rk] = rl(results)
        if not args.keep:
            for scratch in (resf, tfile):
                with suppress(FileNotFoundError):
                    scratch.unlink()
    return results


# CSV report layout. outtop is the header row; outformat is the matching
# str.format template filled from each result dict. The columns are grouped
# below as: row identity, ot shaper, graphite shaper, uniscribe shaper.
outtop = ('name, length, '
          'ot, ot_GSUB, ot_GPOS, '
          'hbgr, gr, grcoll, grkern, '
          'usp, abc, uspnorm, usp_GSUB, usp_GPOS')
outformat = ('{label}, {char}, '
             '{ot}, {ot_GSUB}, {ot_GPOS}, '
             '{hbgr}, {gr}, {grcoll}, {grkern}, '
             '{usp}, {abc}, {uspnorm}, {usp_GSUB}, {usp_GPOS}')

# Scratch directory for the per-run files: the directory given with --keep
# (created if it does not exist yet) or a fresh temporary directory.
if args.keep:
    tmpd = Path(args.keep)
    tmpd.mkdir(parents=True, exist_ok=True)
else:
    tmpd = Path(tempfile.mkdtemp(prefix='hbspeeds-'))

# Standard output of the measured commands is collected here and discarded.
dumpf = tmpd.joinpath('dump.out')
with dumpf.open('w') as dumpfh:
    if args.file:
        # close the control file even when it is not valid JSON
        with args.file:
            c = json.load(args.file)
    else:
        c = [{'font':     args.font,
              'textfile': args.text,
              'textline': args.line,
              'label':    'Commandline'}]

    if args.single or args.file is None:
        res = list(map(dorow, enumerate(c)))
    else:
        # NOTE(review): dorow reads the module globals tmpd and dumpfh, so
        # the worker processes must be able to see them (they are set before
        # the pool is created) - confirm this on platforms whose workers do
        # not start as a copy of this process.
        with Pool() as pool:
            res = pool.map(dorow, enumerate(c))
dumpf.unlink()

if not args.keep:
    tmpd.rmdir()

# Final report: a grouped bar chart when --graph was given, otherwise CSV
# on standard output.
if args.graph:
    import matplotlib.pyplot as plt

    def stackplot(xs, ydat, width, **kwargs):
        '''Draw stacked bars at positions xs, one plt.bar call per level.

        ydat holds one tuple of cumulative values per bar; level i is drawn
        from the previous level's value (0 for the first level) up to its
        own. Extra keyword arguments are passed to plt.bar. Returns the
        list of objects returned by plt.bar, in level order.
        '''
        layers = []
        for level in range(len(ydat[0])):
            if level > 0:
                bottoms = [y[level - 1] for y in ydat]
            else:
                bottoms = [0 for y in ydat]
            heights = [y[level] - base for y, base in zip(ydat, bottoms)]
            layers.append(plt.bar(xs, heights, width,
                                  bottom=bottoms, **kwargs))
        return layers

    def percharacter(key):
        '''Return one-element tuples of result[key] per character.'''
        if args.normalize:
            # dorow has already divided the costs by the text length
            return [(x[key], ) for x in res]
        return [(x[key] / x['char'], ) for x in res]

    width = 0.25
    fig, ax = plt.subplots(figsize=(10, 5))
    # One series per shaper; each bar currently has a single layer, though
    # stackplot accepts longer cumulative tuples as well.
    dat = [percharacter('uspnorm'),
           percharacter('ot'),
           percharacter('gr')]
    pos = list(range(len(dat[1])))
    # The three series sit side by side within each row's slot.
    offsets = [pos,
               [p + width for p in pos],
               [p + width * 2 for p in pos]]
    styles = [{'color': 'blue', 'label': 'usp'},
              {'color': 'red', 'label': 'ot'},
              {'color': 'green', 'label': 'gr'}]
    plots = []
    for xs, series, style in zip(offsets, dat, styles):
        plots.append(stackplot(xs, series, width, alpha=0.6, **style))
    ax.set_ylabel('Cycles / character')
    ax.set_title('Font Speed Test')
    ax.set_xticks([p + 1.5 * width for p in pos])
    ax.set_xticklabels([x['label'] for x in res])
    plt.xlim(min(pos) - width, max(pos) + width * 4)
    plt.legend([layers[0] for layers in plots],
               ['usp', 'ot', 'gr'],
               loc='upper left')
    plt.grid()
    plt.show()
else:
    print(outtop)
    for r in res:
        if r is None:
            continue
        print(outformat.format(**r))
