#!/usr/bin/env python3

# C preamble of the generated benchmark program, kept as a raw string so that
# backslash sequences such as \n are written to the C file verbatim.
# It provides:
#   - the #include list and the targeto/targetp/targeti filter pointers
#     (assigned from argv by the generated main()),
#   - alignedcalloc(): zeroed allocation whose returned pointer is rounded up
#     to a multiple of 64; the buffer is never freed,
#   - t[] and t_print(): TIMINGS+1 timestamps are turned into TIMINGS
#     differences; the median and each difference minus the median are printed,
#   - measure_cpucycles() and measure_randombytes().
# Everything appended to `output` later follows this text in the C file.
output = r'''/* WARNING: auto-generated (by autogen/speed); do not edit */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <time.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/resource.h>
#include "cpucycles.h" /* -lcpucycles */
#include "lib25519.h" /* -l25519 */
#include "randombytes.h" /* -lrandombytes */

static const char *targeto = 0;
static const char *targetp = 0;
static const char *targeti = 0;

#include "limits.inc"

static unsigned char *alignedcalloc(unsigned long long len)
{
  unsigned char *x = (unsigned char *) calloc(1,len + 128);
  if (!x) abort();
  /* will never deallocate so shifting is ok */
  x += 63 & (-(unsigned long) x);
  return x;
}

#define TIMINGS 15
static long long t[TIMINGS+1];

static void t_print(const char *op,long long impl,long long len)
{
  long long median = 0;

  printf("%s",op);
  if (impl >= 0)
    printf(" %lld",impl);
  else
    printf(" selected");
  printf(" %lld",len);
  for (long long i = 0;i < TIMINGS;++i)
    t[i] = t[i+1]-t[i];
  for (long long j = 0;j < TIMINGS;++j) {
    long long belowj = 0;
    long long abovej = 0;
    for (long long i = 0;i < TIMINGS;++i) if (t[i] < t[j]) ++belowj;
    for (long long i = 0;i < TIMINGS;++i) if (t[i] > t[j]) ++abovej;
    if (belowj*2 < TIMINGS && abovej*2 < TIMINGS) {
      median = t[j];
      break;
    }
  }
  printf(" %lld ",median);
  for (long long i = 0;i < TIMINGS;++i)
    printf("%+lld",t[i]-median);
  printf("\n");
  fflush(stdout);
}

#define MAXBATCH 16
#define MAXTEST_BYTES 65536

static void measure_cpucycles(void)
{
  printf("cpucycles selected persecond %lld\n",cpucycles_persecond());
  printf("cpucycles selected implementation %s\n",cpucycles_implementation());

  for (long long i = 0;i <= TIMINGS;++i)
    t[i] = cpucycles();
  t_print("cpucycles",-1,0);
}

static void measure_randombytes(void)
{
  unsigned char *m = alignedcalloc(MAXTEST_BYTES);
  long long mlen = 0;

  while (mlen < MAXTEST_BYTES) {
    for (long long i = 0;i <= TIMINGS;++i) {
      t[i] = cpucycles();
      randombytes(m,mlen);
    }
    t_print("randombytes",-1,mlen);
    mlen += 1+mlen/2;
  }
}
'''

# Benchmark plan: one entry per operation, in the order the measure functions
# are generated and called. Each entry is (operation, variables, benches):
#   variables: (name, size) pairs. size None declares a `long long` scalar;
#     otherwise size is a C expression for an alignedcalloc() buffer, in which
#     'lib25519_<operation>' is rewritten to 'lib25519_<operation>_<primitive>'
#     by the generation loop.
#   benches: (function, arg, ...) tuples emitted as C calls. The first call to
#     each function name is timed; later entries with the same function name
#     are emitted as plain untimed calls (e.g. the second crypto_dh_keypair).
# XXX: integrate todo into api
todo = (
  ('verify',(
    ('x','lib25519_verify_BYTES'),
    ('y','lib25519_verify_BYTES'),
  ),(
    ('crypto_verify','x','y'),
  )),
  ('hashblocks',(
    ('h','lib25519_hashblocks_STATEBYTES'),
    ('m','MAXTEST_BYTES'),
    ('mlen',None),
  ),(
    ('crypto_hashblocks','h','m','mlen'),
  )),
  ('hash',(
    ('h','lib25519_hash_BYTES'),
    ('m','MAXTEST_BYTES'),
    ('mlen',None),
  ),(
    ('crypto_hash','h','m','mlen'),
  )),
  ('pow',(
    ('n','lib25519_pow_BYTES'),
    ('ne','lib25519_pow_BYTES'),
  ),(
    ('crypto_pow','ne','n'),
  )),
  ('powbatch',(
    ('n','MAXBATCH*lib25519_powbatch_BYTES'),
    ('ne','MAXBATCH*lib25519_powbatch_BYTES'),
  ),(
    ('crypto_powbatch','ne','n','batch'),
  )),
  ('nP',(
    ('n','lib25519_nP_SCALARBYTES'),
    ('P','lib25519_nP_POINTBYTES'),
    ('nP','lib25519_nP_POINTBYTES'),
  ),(
    ('crypto_nP','nP','n','P'),
  )),
  ('nPbatch',(
    ('n','MAXBATCH*lib25519_nPbatch_SCALARBYTES'),
    ('P','MAXBATCH*lib25519_nPbatch_POINTBYTES'),
    ('nP','MAXBATCH*lib25519_nPbatch_POINTBYTES'),
  ),(
    ('crypto_nPbatch','nP','n','P','batch'),
  )),
  ('nG',(
    ('n','lib25519_nG_SCALARBYTES'),
    ('nG','lib25519_nG_POINTBYTES'),
  ),(
    ('crypto_nG','nG','n'),
  )),
  ('mGnP',(
    ('mGnP','lib25519_mGnP_OUTPUTBYTES'),
    ('m','lib25519_mGnP_MBYTES'),
    ('n','lib25519_mGnP_NBYTES'),
    ('P','lib25519_mGnP_PBYTES'),
  ),(
    ('crypto_mGnP','mGnP','m','n','P'),
  )),
  ('multiscalar',(
    ('Q','MAXBATCH*lib25519_multiscalar_OUTPUTBYTES'),
    ('n','MAXBATCH*lib25519_multiscalar_SCALARBYTES'),
    ('P','MAXBATCH*lib25519_multiscalar_POINTBYTES'),
  ),(
    ('crypto_multiscalar','Q','n','P','batch'),
  )),
  ('dh',(
    ('pka','lib25519_dh_PUBLICKEYBYTES'),
    ('ska','lib25519_dh_SECRETKEYBYTES'),
    ('pkb','lib25519_dh_PUBLICKEYBYTES'),
    ('skb','lib25519_dh_SECRETKEYBYTES'),
    ('ka','lib25519_dh_BYTES'),
  ),(
    ('crypto_dh_keypair','pka','ska'),
    ('crypto_dh_keypair','pkb','skb'),
    ('crypto_dh','ka','pkb','ska'),
  )),
  ('sign',(
    ('pk','lib25519_sign_PUBLICKEYBYTES'),
    ('sk','lib25519_sign_SECRETKEYBYTES'),
    ('m','MAXTEST_BYTES+lib25519_sign_BYTES'),
    ('sm','MAXTEST_BYTES+lib25519_sign_BYTES'),
    ('m2','MAXTEST_BYTES+lib25519_sign_BYTES'),
    ('mlen',None),
    ('smlen',None),
    ('m2len',None),
  ),(
    ('crypto_sign_keypair','pk','sk'),
    ('crypto_sign','sm','&smlen','m','mlen','sk'),
    ('crypto_sign_open','m2','&m2len','sm','smlen','pk'),
  )),
)

# Tables filled from the 'api' file (read from the current directory):
operations = []   # operation names, in first-seen order
primitives = {}   # operation -> list of primitive names
sizes = {}        # (operation, primitive) -> concatenated '#define' lines
exports = {}      # operation -> function names without the 'crypto_' prefix
prototypes = {}   # operation -> list of (return type, name parts, argument text)

with open('api') as apifile:
  for entry in apifile:
    entry = entry.strip()

    # 'crypto_<operation>/<primitive> ...' registers a primitive.
    if entry.startswith('crypto_'):
      pieces = entry.split()[0].split('/')
      assert len(pieces) == 2
      opname = pieces[0].split('_')[1]
      if opname not in operations:
        operations.append(opname)
      primitives.setdefault(opname,[]).append(pieces[1])
      continue

    # '#define crypto_<operation>_<primitive>_<NAME> ...' is a size constant;
    # the full line is kept, grouped per (operation, primitive).
    if entry.startswith('#define '):
      macro = entry.split(' ')[1].split('_')
      assert len(macro) == 4
      assert macro[0] == 'crypto'
      key = (macro[1],macro[2])
      sizes[key] = sizes.get(key,'') + entry + '\n'
      continue

    # '<rettype> crypto_<operation>[_<suffix>](<args>);' is a prototype.
    if entry.endswith(');'):
      head,arglist = entry[:-2].split('(')
      rtype,name = head.split()
      parts = name.split('_')
      opname = parts[1]
      assert parts[0] == 'crypto'
      exports.setdefault(opname,[]).append('_'.join(parts[1:]))
      prototypes.setdefault(opname,[]).append((rtype,parts,arglist))

# Emit one `static void measure_<o>_<p>(void)` C function for every
# (operation, primitive) pair: operations come from todo, primitives from api.
# primitives[o] raises KeyError if api lists no primitive for a todo operation.
for t in todo:
  o,vars,benches = t

  for p in primitives[o]:

    output += '\n'
    output += 'static void measure_%s_%s(void)\n' % (o,p)
    output += '{\n'

    # Command-line filters: return early unless operation and primitive match.
    output += '  if (targeto && strcmp(targeto,"%s")) return;\n' % o
    output += '  if (targetp && strcmp(targetp,"%s")) return;\n' % p

    # Declare the C locals. varsize remembers each buffer's size expression
    # (specialized to this primitive) for the t_print() length column below.
    varsize = {}
    for v,size in vars:
      if size is None:
        output += '  long long %s;\n' % v
      else:
        size = size.replace('lib25519_'+o,'lib25519_'+o+'_'+p)
        output += '  unsigned char *%s = alignedcalloc(%s);\n' % (v,size)
        varsize[v] = size
    output += '\n'
  
    # impl == -1 benchmarks the library's selected implementation; impl >= 0
    # walks the dispatch entries below lib25519_numimpl_<o>_<p>().
    output += '  for (long long impl = -1;impl < lib25519_numimpl_%s_%s();++impl) {\n' % (o,p)
    # Local function pointers named like the api prototypes (crypto_<o>...),
    # so the bench calls from todo go through whichever implementation is bound.
    for rettype,fun,args in prototypes[o]:
      output += '    %s (*%s)(%s);\n' % (rettype,'_'.join(fun),args)

    # NOTE(review): this filter also runs for impl == -1; what
    # lib25519_dispatch_<o>_<p>_implementation(-1) returns is not visible here.
    output += '    if (targeti && strcmp(targeti,lib25519_dispatch_%s_%s_implementation(impl))) continue;\n' % (o,p)

    # Bind the pointers: dispatch entry `impl`, or the selected implementation.
    output += '    if (impl >= 0) {\n'
    for rettype,fun,args in prototypes[o]:
      f2 = ['lib25519','dispatch',o,p]+fun[2:]
      output += '      %s = %s(impl);\n' % ('_'.join(fun),'_'.join(f2))
    output += r'      printf("%s_%s %%lld implementation %%s compiler %%s\n",impl,lib25519_dispatch_%s_%s_implementation(impl),lib25519_dispatch_%s_%s_compiler(impl));' % (o,p,o,p,o,p)
    output += '\n'
    output += '    } else {\n'
    for rettype,fun,args in prototypes[o]:
      f2 = ['lib25519',o,p]+fun[2:]
      output += '      %s = %s;\n' % ('_'.join(fun),'_'.join(f2))
    output += r'      printf("%s_%s selected implementation %%s compiler %%s\n",lib25519_%s_%s_implementation(),lib25519_%s_%s_compiler());' % (o,p,o,p,o,p)
    output += '\n'
    output += '    }\n'

    # Fill every buffer with random bytes before measuring.
    for v,size in vars:
      if size is not None:
        size = size.replace('lib25519_'+o,'lib25519_'+o+'_'+p)
        output += '    randombytes(%s,%s);\n' % (v,size)

    # Function names that must not (or no longer) be timed; a bench whose
    # function is in this set is emitted as a plain call with no timing loop.
    alreadybenched = set()
    alreadybenched.add('assert')

    for b in benches:
      if b[0] in alreadybenched:
        output += '    %s(%s);\n' % (b[0],','.join(b[1:]))
        continue

      # Label for t_print(): '<o>_<p>' plus any suffix after 'crypto_<o>'.
      fun = b[0].split('_')
      shortfun = '_'.join([o,p]+fun[2:])
      alreadybenched.add(b[0])

      if 'mlen' in b[1:] or 'smlen' in b[1:]:
        # Message-length benchmarks: sweep mlen from 0 up to at most
        # MAXTEST_BYTES, with TIMINGS+1 timestamped calls per length.
        output += '    mlen = 0;\n'
        output += '    while (mlen <= MAXTEST_BYTES) {\n'
        output += '      randombytes(m,mlen);\n'

        # Before timing the open call, sign the fresh m so sm/smlen are set.
        if shortfun == 'sign_ed25519_open': # XXX: put this into todo
          output += '      lib25519_sign(sm,&smlen,m,mlen,sk);\n'

        output += '      for (long long i = 0;i <= TIMINGS;++i) {\n'
        output += '        t[i] = cpucycles();\n'
        output += '        %s(%s);\n' % (b[0],','.join(b[1:]))
        output += '      }\n'
        output += '      t_print("%s",impl,mlen);\n' % (shortfun)

        if shortfun == 'sign_ed25519_open': # XXX: put this into todo
          output += '      /* this is, in principle, not a test program */\n'
          output += '      /* but some checks here help validate the data flow above */\n'
          output += '      assert(m2len == mlen);\n'
          output += '      assert(!memcmp(m,m2,mlen));\n'

        # Length step: roughly 25% growth for sign, 50% for everything else.
        if o == 'sign': # XXX: put this into todo
          output += '      mlen += 1+mlen/4;\n'
        else:
          output += '      mlen += 1+mlen/2;\n'
        output += '    }\n'
      elif 'batch' in o or 'multiscalar' in o:
        # Batch benchmarks: one timing run per batch size 0..MAXBATCH; the
        # batch size is reported in the length column.
        output += '    for (long long batch = 0;batch <= MAXBATCH;++batch) {\n'
        output += '      for (long long i = 0;i <= TIMINGS;++i) {\n'
        output += '        t[i] = cpucycles();\n'
        output += '        %s(%s);\n' % (b[0],','.join(b[1:]))
        output += '      }\n'
        output += '      t_print("%s",impl,batch);\n' % shortfun
        output += '    }\n'
      else:
        # Fixed-size benchmarks: the reported length is the size expression of
        # the call's first argument (KeyError if that argument is not a buffer).
        output += '    for (long long i = 0;i <= TIMINGS;++i) {\n'
        output += '      t[i] = cpucycles();\n'
        output += '      %s(%s);\n' % (b[0],','.join(b[1:]))
        output += '    }\n'
        output += '    t_print("%s",impl,%s);\n' % (shortfun,varsize[b[1]])

    output += '  }\n'
    output += '}\n'

# Start of the generated main(): print the library version and arch, call
# print_cpuid(), take up to three optional arguments (operation, primitive,
# implementation) as filters, then call measure_cpucycles(),
# measure_randombytes() and limits(). The per-primitive measure calls and the
# closing of main() are appended after this string.
output += r'''
#include "print_cpuid.inc"

int main(int argc,char **argv)
{
  printf("lib25519 version %s\n",lib25519_version);
  printf("lib25519 arch %s\n",lib25519_arch);
  print_cpuid();

  if (*argv) ++argv;
  if (*argv) {
    targeto = *argv++;
    if (*argv) {
      targetp = *argv++;
      if (*argv) {
        targeti = *argv++;
      }
    }
  }

  measure_cpucycles();
  measure_randombytes();
  limits();
'''

# Append one measure_<operation>_<primitive>() call per generated function,
# following todo order and, within an operation, api order.
output += ''.join(
  '  measure_%s_%s();\n' % (opname,primname)
  for opname,_buffers,_benches in todo
  for primname in primitives[opname]
)

# Close main().
output += r'''
  return 0;
}
'''

# Write the complete C source; the path is relative to the current directory.
with open('command/lib25519-speed.c','w') as cfile:
  cfile.write(output)
