#!/usr/bin/python
#   Copyright (C) 2013 Canonical Ltd.
#
#   Author: Scott Moser <scott.moser@canonical.com>
#
#   Simplestreams is free software: you can redistribute it and/or modify it
#   under the terms of the GNU Affero General Public License as published by
#   the Free Software Foundation, either version 3 of the License, or (at your
#   option) any later version.
#
#   Simplestreams is distributed in the hope that it will be useful, but
#   WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
#   or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public
#   License for more details.
#
#   You should have received a copy of the GNU Affero General Public License
#   along with Simplestreams.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import hashlib
import errno
import json
import os
import os.path
import sys
from simplestreams import util
from simplestreams import contentsource as cs

try:
    # this is just python2 or python3 compatible prepping for get_url_len
    import urllib.request
    url_request = urllib.request.Request
    url_open = urllib.request.urlopen
    url_error = urllib.error
except ImportError:
    import urllib2
    url_request = urllib2.Request
    url_open = urllib2.urlopen
    url_error = urllib2

import toolutil
import sign_helper


# Base URLs that relative image paths are resolved against.  load_url()
# treats the first entry as the primary source and any further entries as
# mirrors; get_cloud_images_file_size() tries each entry in order.
#
# Other mirrors could be supported here, for example:
#   http://cloud-images-archive.ubuntu.com/
#   file:///srv/ec2-images
#
BASE_URLS = (
    "http://cloud-images.ubuntu.com/",
)

# Placeholder metadata used when REAL_DATA is disabled, keyed by an item's
# 'ftype'.  'size' is the length of the sparse file that create_fake_file()
# writes; create_fake_file() also adds checksum and combined_* entries to
# these dictionaries the first time each ftype is generated.
FAKE_DATA = {
    'root.tar.gz': {'size': 10240},
    'root.tar.xz': {
        'size': 10241,
        'mirrors': [
            'https://us.images.linuxcontainers.org/',
            'https://uk.images.linuxcontainers.org/'
        ]
    },
    'root.manifest': {'size': 10242},
    'lxd.tar.xz': {'size': 10245},
    'tar.gz': {'size': 11264},
    'disk1.img': {'size': 12288},
    'disk-kvm.img': {'size': 12289},
    'uefi1.img': {'size': 12421},
    'manifest': {'size': 10241},
    'ova': {'size': 12399},
    'squashfs': {'size': 12400},
    'squashfs.manifest': {'size': 12401},
    'img': {'size': 12402},
}

# EC2 API endpoints.  'default' is a %-format template that takes the region
# name; any other key is a region whose endpoint replaces the template
# (see create_aws_data()).
EC2_ENDPOINTS = {
    'default': 'https://ec2.%s.amazonaws.com',
    'cn-north-1': 'https://ec2.cn-north-1.amazonaws.com.cn',
    'cn-northwest-1': 'https://ec2.cn-northwest-1.amazonaws.com.cn',
    'us-gov-west-1': 'https://ec2.us-gov-west-1.amazonaws.com'
}

# Prefix of every product name; streams other than "released" get
# ".<stream>" appended by the *_load_query() functions.
UBUNTU_RDNS = "com.ubuntu.cloud"

# REAL_DATA is switched on by any non-empty environment value other than "0".
# When on, create_image_data() looks up real checksums and sizes; otherwise
# it generates fake placeholder files.
REAL_DATA = os.environ.get("REAL_DATA", "") not in ("", "0")

# In-memory cache: {dirname: {basename: {field: value}}}.  create_image_data()
# additionally stores the cache file location under the 'filename' key.
FILE_DATA = {}


def get_cache_data(path, field):
    """Return the cached value of ``field`` for ``path``, or None.

    The cache (FILE_DATA) is keyed by directory name and then by file
    name; a missing directory, file or field all yield None.
    """
    parent = os.path.dirname(path)
    leaf = os.path.basename(path)
    per_dir = FILE_DATA.get(parent, {})
    per_file = per_dir.get(leaf, {})
    return per_file.get(field)


def store_cache_data(path, field, value):
    """Record ``value`` as ``field`` of ``path`` in the FILE_DATA cache.

    Intermediate directory and file dictionaries are created on demand.
    """
    parent = os.path.dirname(path)
    leaf = os.path.basename(path)
    per_file = FILE_DATA.setdefault(parent, {}).setdefault(leaf, {})
    per_file[field] = value


def store_cache_entry(path, data):
    """Cache the 'size', 'md5' and 'sha256' entries of ``data`` for ``path``.

    Any other keys in ``data`` (for example 'content') are ignored.
    """
    cacheable = ('size', 'md5', 'sha256')
    for field, value in data.items():
        if field not in cacheable:
            continue
        store_cache_data(path, field, value)


def save_cache():
    """Write FILE_DATA to the file named by FILE_DATA['filename'].

    Does nothing while the cache is empty.
    """
    if not FILE_DATA:
        return
    with open(FILE_DATA['filename'], "wb") as cache_fp:
        payload = util.dump_data(FILE_DATA)
        cache_fp.write(payload)


def load_sums_from_sumfiles(path):
    """Populate the cache from the MD5SUMS and SHA256SUMS files in ``path``.

    For each checksum file not already cached, every "<hexsum> <name>" line
    is stored as that file's checksum and the file's size is looked up
    (without saving the cache to disk).  The checksum file itself is cached
    last, so a later call skips it.
    """
    for algo in ("md5", "sha256"):
        sums_path = path + "/%sSUMS" % algo.upper()
        if get_cache_data(sums_path, 'size'):
            # this sum file was already processed
            continue

        sums_info = load_url(sums_path)
        text = sums_info['content'].decode('utf-8')

        for entry in text.splitlines():
            digest, name = entry.split()
            # a leading "*" is stripped from the file name
            if name.startswith("*"):
                name = name[1:]
            target = path + "/" + name
            store_cache_data(target, algo, digest)
            get_cloud_images_file_size(target, save=False)

        store_cache_entry(sums_path, sums_info)


class NonExistingUrl(Exception):
    """Raised by load_url() when the requested path cannot be found."""


def load_url(path, hashes=None, base_urls=None):
    """Read ``path`` relative to the base urls and describe its content.

    path: location relative to each base url.
    hashes: hashlib algorithm names to compute; defaults to sha256 and md5.
    base_urls: sequence of url prefixes; the first is the primary source,
        the rest are passed as mirrors.  Defaults to BASE_URLS.

    Returns a dict with 'size', 'content' (the raw bytes) and one hex digest
    per requested hash name.

    Raises NonExistingUrl on HTTP 403/404, on ENOENT, or when the body
    contains '403 Forbidden'.  Other HTTP and IO errors propagate.
    """
    sources = BASE_URLS if base_urls is None else base_urls
    primary = sources[0] + path
    fallbacks = [base + path for base in sources[1:]]

    # HTTP status codes that are reported as a missing url
    not_found = {403: "%s: 403", 404: "%s: 404"}

    try:
        data = cs.UrlContentSource(primary, mirrors=fallbacks).read()
        if b'403 Forbidden' in data:
            raise NonExistingUrl("%s: 403 Forbidden (s3 404)" % path)
    except url_error.HTTPError as e:
        if e.code in not_found:
            raise NonExistingUrl(not_found[e.code] % path)
        raise
    except IOError as e:
        if e.errno == errno.ENOENT:
            raise NonExistingUrl("%s: ENOENT" % path)
        raise

    sys.stderr.write("read url %s\n" % path)

    result = {'size': len(data), 'content': data}

    wanted = ["sha256", "md5"] if hashes is None else hashes
    for algo in wanted:
        result[algo] = hashlib.new(algo, data).hexdigest()
    return result


def load_data_in_dir(path):
    """Fill the cache with size/checksum data for the files in ``path``.

    Nothing is done if the directory's ".qindex.json" is already cached.
    Otherwise the qindex file is loaded and each of its entries is cached;
    if the qindex does not exist, the MD5SUMS/SHA256SUMS files are used
    instead.  The cache is saved to disk afterwards.

    Raises ValueError if the qindex file exists but is not valid JSON.
    """
    qfile = ".qindex.json"
    qpath_loaded = False
    qpath = path + "/" + qfile

    if get_cache_data(qpath, "size"):
        sys.stderr.write("dir[cache]: %s\n" % path)
        return

    try:
        ret = load_url(qpath)
        content = ret['content'].decode("utf-8")
        try:
            for fpath, data in json.loads(content).items():
                store_cache_entry(path + "/" + fpath, data)
            qpath_loaded = True
            # cache the qindex itself so this directory is skipped next time
            store_cache_entry(qpath, ret)
        except ValueError:
            # terminate the message with a newline like every other log line
            sys.stderr.write("qindex parse failed %s\n" % path)
            raise
    except NonExistingUrl:
        # no qindex in this directory; fall through to the sum files
        pass

    # fall back to loading sumfiles and statting sizes
    if qpath_loaded:
        sys.stderr.write("dir[qindex]: %s\n" % path)
    else:
        load_sums_from_sumfiles(path)
        sys.stderr.write("dir[sumfiles]: %s\n" % path)

    save_cache()
    return


def get_cloud_images_file_info(path):
    """Return {'md5', 'sha256', 'size'} for ``path``.

    Cached values are used when all three are present.  Otherwise the
    containing directory's metadata is loaded; ".manifest" files that are
    still incomplete after that are downloaded and hashed directly.

    Raises Exception if any of the three values is still missing.
    """
    keys = ('md5', 'sha256', 'size')

    def from_cache():
        return {k: get_cache_data(path, k) for k in keys}

    info = from_cache()
    if all(info.values()):
        return info

    load_data_in_dir(os.path.dirname(path))
    info = from_cache()

    # if we were missing an md5 or a sha256 for the manifest
    # file, then get them ourselves.
    if path.endswith(".manifest") and not all(info.values()):
        fetched = load_url(path)
        store_cache_entry(path, fetched)
        save_cache()
        info = {k: fetched[k] for k in keys}

    missing = [k for k, v in info.items() if not v]
    if missing:
        raise Exception("Unable to get checksums (%s) for %s" %
                        (missing, path))

    return info


def get_url_len(url):
    """Return the size in bytes of the resource at ``url``.

    "file:///" urls and existing local paths are measured with os.stat().
    Anything else is queried with an HTTP HEAD request and the
    Content-Length header is returned (0 if the header is absent).
    """
    if url.startswith("file:///"):
        # strip only "file://" so the path keeps its leading slash
        path = url[len("file://"):]
        return os.stat(path).st_size
    if os.path.exists(url):
        return os.stat(url).st_size

    # http://stackoverflow.com/questions/4421170/
    #  python-head-request-with-urllib2
    request = url_request(url)
    request.get_method = lambda: 'HEAD'
    response = url_open(request)
    try:
        return int(response.headers.get('content-length', 0))
    finally:
        # always release the connection, even if the header is malformed
        response.close()


def get_cloud_images_file_size(path, save=True):
    """Return the size of ``path``, consulting the cache first.

    On a cache miss each entry of BASE_URLS is tried in order until one
    reports a non-zero size, which is then stored in the cache.

    path: location relative to the base urls.
    save: when true, write the cache to disk after storing the size.

    Raises the last lookup error if every base url failed, or a plain
    Exception if no base url reported a non-zero size.
    """
    size = get_cache_data(path, 'size')
    if size:
        return size

    last_error = None
    sys.stderr.write(" size: %s\n" % path)
    for burl in BASE_URLS:
        try:
            size = int(get_url_len(burl + path))
        except Exception as e:
            # best effort: remember the failure and try the next base url
            last_error = e
            sys.stderr.write("  size stat failed: %s\n" % (burl + path))
            continue
        if size:
            break

    if not size:
        if last_error is not None:
            raise last_error
        raise Exception("Unable to determine size of %s" % path)

    store_cache_data(path, 'size', size)
    if save:
        save_cache()
    return size


def create_fake_file(prefix, item):
    """Create a sparse placeholder file for ``item`` and fill in its metadata.

    prefix: directory under which item['path'] is created.
    item: dict with at least 'path' and 'ftype'; it is updated in place
        with the FAKE_DATA entry for its ftype (size, checksums, ...).

    The checksums for an ftype are computed the first time a file of that
    type is created and are stored back into FAKE_DATA.  For "lxd.tar.xz"
    items the matching root filesystem files are created as needed and the
    combined sha256 of the metadata plus each root file is recorded.
    """
    fpath = os.path.join(prefix, item['path'])
    ftype = item['ftype']
    data = FAKE_DATA[ftype]

    util.mkdir_p(os.path.dirname(fpath))
    print("creating %s" % fpath)
    with open(fpath, "w") as fp:
        fp.truncate(data['size'])

    if 'md5' not in data:
        # load the url to get checksums and update the sparse FAKE_DATA.
        # Use the absolute path so this also works when prefix is absolute
        # (joining cwd with an absolute fpath gave a non-existent url).
        fdata = load_url(os.path.abspath(fpath), base_urls=["file://"])
        del fdata['content']
        data.update(fdata)

    if ftype == "lxd.tar.xz":
        # create a combined sha256 for lxd.tar.xz (metadata) and the root fs
        # - combined_sha256 and combined_rootxz_sha256 for the -root.tar.gz
        # - combined_squashfs_sha256 for the squashfs
        # - combined_disk1-img_sha256 for the img
        # - combined_uefi1-img_sha256 for the img (xenial)
        # - combined_disk-kvm-img_sha256 for the img
        for name, extension in (('disk1-img', '.img'),
                                ('uefi1-img', '-uefi1.img'),  # xenial
                                ('disk-kvm-img', '-disk-kvm.img'),
                                ('rootxz', '-root.tar.xz'),
                                ('squashfs', '.squashfs')):
            combined_key = 'combined_{}_sha256'.format(name)
            if combined_key in data:
                continue

            rootpath = item['path'].replace("-lxd.tar.xz", extension)
            if not os.path.exists(os.path.join(prefix, rootpath)):
                rootitem = item.copy()
                rootitem['ftype'] = extension.lstrip('-.')
                rootitem['path'] = rootpath
                if 'mirrors' in data:
                    rootitem['mirrors'] = data['mirrors']
                create_fake_file(prefix, rootitem)

            # and sha256 hash the combined file
            chash = hashlib.new('sha256')
            for member in (fpath, os.path.join(prefix, rootpath)):
                with open(member, "rb") as fp:
                    chash.update(fp.read())
            data[combined_key] = chash.hexdigest()

        # Add legacy combined_sha256 if combined_rootxz_sha256 exists
        if ('combined_sha256' not in data and
                'combined_rootxz_sha256' in data):
            data['combined_sha256'] = data['combined_rootxz_sha256']

    item.update(data)

    # drop any checksum the item carried that the fake data does not provide
    for cksum in util.CHECKSUMS:
        if cksum in item and cksum not in data:
            del item[cksum]

    return


def should_skip_due_to_unsupported(pdata):
    """Return True if the product ``pdata`` should be filtered out.

    Filtering is only active when SS_SKIP_UNSUPPORTED is set to something
    other than "0".  In that case pdata['supported'] must exist (a KeyError
    is raised otherwise), so distro-info data is required.
    """
    filtering = os.environ.get("SS_SKIP_UNSUPPORTED", "0") != "0"
    return filtering and not pdata['supported']


def dl_load_query(path):
    """Build the per-stream download product tree from the query data.

    Returns {stream: {'products': {prodname: product}}} where each product
    holds 'versions' keyed by serial, and each version holds 'label',
    'pubname' and 'items' keyed by file type.
    """
    tree = {}
    for record in toolutil.load_query_download(path):
        (stream, rel, build, label, serial, arch, filepath, _fname) = record

        if skip_for_faster_debug(serial):
            continue

        products = tree.setdefault(stream, {'products': {}})['products']

        if stream == "released":
            rdns = UBUNTU_RDNS
        else:
            rdns = UBUNTU_RDNS + "." + stream
        prodname = ':'.join([rdns, build, release_to_version(rel), arch])

        if prodname not in products:
            candidate = new_ubuntu_product_dict(rel, arch, label)
            if should_skip_due_to_unsupported(candidate):
                sys.stderr.write("skipping unsupported %s\n" % prodname)
                continue
            products[prodname] = candidate

        versions = products[prodname]['versions']
        version = versions.setdefault(serial, {'items': {}, "label": label})
        version['pubname'] = pubname(label, rel, arch, serial, build)

        # ftype: finding the unique extension is not-trivial
        # take basename of the filename, and remove up to "-<arch>?"
        # so ubuntu-12.04-server-cloudimg-armhf.tar.gz becomes
        # 'tar.gz' and 'ubuntu-12.04-server-cloudimg-armhf-disk1.img'
        # becomes 'disk1.img'
        dash_arch = "-" + arch
        start = filepath.rindex(dash_arch) + len(dash_arch) + 1
        ftype = filepath[start:]
        # the ftype for '.img' is 'disk1.img' to keep stream metadata
        # backwards compatible.
        if ftype == "img":
            ftype = "disk1.img"

        version['items'][ftype] = {'path': filepath, 'ftype': ftype}

    return tree


def release_to_version(release):
    """Return the 'version' recorded in toolutil.REL2VER for ``release``."""
    reldata = toolutil.REL2VER[release]
    return reldata['version']


def pubname(label, rel, arch, serial, build='server'):
    """Return the published image name for the given build.

    The name has the form "ubuntu-<release part>-<arch>-<build>-<serial>",
    where the release part depends on ``label``: "<rel>-daily" for daily,
    "<rel>-<version>" for release, "<rel>-<version>-<label>" for beta*
    labels and "<rel>-<label>" for anything else.
    """
    version = release_to_version(rel)

    def finish(rv_label):
        return "ubuntu-%s-%s-%s-%s" % (rv_label, arch, build, serial)

    if label == "daily":
        return finish(rel + "-daily")
    if label == "release":
        return finish("%s-%s" % (rel, version))
    if label.startswith("beta"):
        return finish("%s-%s-%s" % (rel, version, label))
    return finish("%s-%s" % (rel, label))


def new_ubuntu_product_dict(release, arch, label):
    """Return a new product dictionary for ``release`` on ``arch``.

    Release details are copied from toolutil.REL2VER.  'supported' and
    'support_eol' are included only when present there.  'aliases' is a
    sorted, comma-joined string; the "devel" alias is dropped unless
    ``label`` contains 'daily'.
    """
    reldata = toolutil.REL2VER[release]

    product = {"arch": arch, "os": "ubuntu", "release": release}
    for required in ('release_title', 'release_codename', 'version'):
        product[required] = reldata[required]
    product["versions"] = {}

    for optional in ('supported', 'support_eol'):
        if optional in reldata:
            product[optional] = reldata[optional]

    if not reldata['aliases']:
        return product

    aliases = list(reldata['aliases'])
    if 'daily' not in label and "devel" in aliases:
        aliases.remove("devel")
    if aliases:
        product['aliases'] = ','.join(sorted(aliases))

    return product


def skip_for_faster_debug(serial):
    """Return True if ``serial`` sorts before SS_DEBUG_MIN_SERIAL.

    Always False when the variable is unset or "0".  A minimum without a
    "." gets ".0" appended before the (string) comparison.  The decision is
    printed whenever the filter is active.
    """
    min_serial = os.environ.get('SS_DEBUG_MIN_SERIAL')
    if min_serial is None or min_serial == "0":
        return False

    threshold = min_serial if "." in min_serial else min_serial + ".0"
    too_old = serial < threshold
    print("min_serial: %s serial: %s -> %s" % (threshold, serial, too_old))
    return too_old


def ec2_load_query(path):
    """Build the per-stream EC2 image-id product tree from the query data.

    Returns {stream: {'products': {prodname: product}}} where each version's
    'items' maps a short key (region, virt type and root store encoded) to
    {'id', 'root_store', 'virt', 'crsn'}.
    """
    # region direction -> two letter code
    direction_codes = {
        "north": "nn",
        "northeast": "ne",
        "east": "ee",
        "southeast": "se",
        "south": "ss",
        "southwest": "sw",
        "west": "ww",
        "northwest": "nw",
        "central": "cc",
    }
    # virt type -> root store -> two letter code
    store_codes = {
        'pv': {'instance': "pi", "ebs": "pe", "ssd": "es", "io1": "eo"},
        'hvm': {'instance': "hi", "ebs": "he", "ssd": "hs", "io1": "ho"}
    }

    tree = {}
    for record in toolutil.load_query_ec2(path):
        (stream, rel, build, label, serial, store, arch, region,
         iid, _kern, _rmd, vtype) = record

        if skip_for_faster_debug(serial):
            continue

        products = tree.setdefault(stream, {'products': {}})['products']

        if stream == "released":
            rdns = UBUNTU_RDNS
        else:
            rdns = UBUNTU_RDNS + "." + stream
        prodname = ':'.join([rdns, build, release_to_version(rel), arch])

        if prodname not in products:
            candidate = new_ubuntu_product_dict(rel, arch, label)
            if should_skip_due_to_unsupported(candidate):
                sys.stderr.write("skipping unsupported %s\n" % prodname)
                continue
            products[prodname] = candidate

        versions = products[prodname]['versions']
        version = versions.setdefault(serial, {'items': {}, "label": label})
        version['pubname'] = pubname(label, rel, arch, serial, build)

        # normalise the root store: "instance-store" -> "instance",
        # otherwise keep the part after the last dash
        if store == "instance-store":
            store = 'instance'
        elif '-' in store:
            store = store.split('-')[-1]
        if vtype == "paravirtual":
            vtype = "pv"

        # create the item key:
        #  - 2 letter country code (us) . 3 for govcloud (gww)
        #  - 2 letter direction (nn=north, nw=northwest, cc=central)
        #  - 1 digit number
        #  - 1 char for virt type
        #  - 1 char for root-store type

        # Handle special case of 'gov' regions
        if '-gov-' in region:
            plain_region = region.replace('gov-', '')
            gov_prefix = "g"
        else:
            plain_region = region
            gov_prefix = ""

        (cc, direction, num) = plain_region.split("-")

        ikey = ''.join([gov_prefix, cc, direction_codes[direction], num,
                        store_codes[vtype][store]])

        version['items'][ikey] = {
            'id': iid,
            'root_store': store,
            'virt': vtype,
            'crsn': region,
        }
    return tree


def printitem(item, exdata):
    """Print ``exdata`` overlaid with ``item`` (item values win)."""
    merged = exdata.copy()
    for key, value in item.items():
        merged[key] = value
    print(merged)


def create_image_data(query_tree, out_d, streamdir):
    """Write the image-downloads product stream for each queried stream.

    query_tree: directory holding the query data and the FILE_DATA_CACHE.
    out_d: output directory; one sub-directory is used per stream.
    streamdir: path below each stream directory for the json files.

    With REAL_DATA set, every item is annotated with real checksums and
    sizes; otherwise fake files are created under out_d.  Returns the tree
    keyed by stream name.
    """
    license_url = 'http://www.canonical.com/intellectual-property-policy'
    hashcache = os.path.join(query_tree, "FILE_DATA_CACHE")
    if os.path.isfile(hashcache):
        with open(hashcache, "r") as cache_fp:
            FILE_DATA.update(json.loads(cache_fp.read()))
    # set after loading: the saved cache carries its own 'filename' entry,
    # which must not redirect where this run writes the cache.
    FILE_DATA['filename'] = hashcache

    ts = util.timestamp()
    tree = dl_load_query(query_tree)

    def update_data(item, tree, pedigree):
        path = item['path']
        item.update(get_cloud_images_file_info(path))
        if path.endswith('-lxd.tar.xz'):
            dirname = os.path.dirname(path)
            lxd = os.path.basename(path)

            # find calculated combined checksums
            for name, extension in (('disk1-img', '.img'),
                                    ('uefi1-img', '-uefi1.img'),  # xenial
                                    ('disk-kvm-img', '-disk-kvm.img'),
                                    ('rootxz', '-root.tar.xz'),
                                    ('squashfs', '.squashfs')):
                root = lxd.replace('-lxd.tar.xz', extension)
                # combined sums are cached under "<lxd>,<root>"
                combined = os.path.join(dirname, ','.join([lxd, root]))
                value = get_cache_data(combined, 'sha256')
                if value:
                    item.update({'combined_{}_sha256'.format(name): value})

            # Add legacy combined_sha256 if combined_rootxz_sha256 exists
            if 'combined_rootxz_sha256' in item:
                value = item['combined_rootxz_sha256']
                item.update({'combined_sha256': value})

    cid_fmt = "com.ubuntu.cloud:%s:download"
    for stream in tree:
        # defined per iteration and used immediately, so it sees this stream
        def create_file(item, tree, pedigree):
            create_fake_file(os.path.join(out_d, stream), item)

        cid = cid_fmt % stream
        if REAL_DATA:
            util.walk_products(tree[stream], cb_item=update_data)
        else:
            util.walk_products(tree[stream], cb_item=create_file)

        tree[stream]['format'] = "products:1.0"
        tree[stream]['updated'] = ts
        tree[stream]['content_id'] = cid
        tree[stream]['datatype'] = 'image-downloads'
        tree[stream]['license'] = license_url

        outfile = os.path.join(out_d, stream, streamdir, cid + ".json")
        util.mkdir_p(os.path.dirname(outfile))
        with open(outfile, "wb") as fp:
            sys.stderr.write("writing %s\n" % outfile)
            fp.write(util.dump_data(tree[stream]) + b"\n")

    # save hashes data
    save_cache()
    return tree


def create_aws_data(query_tree, out_d, streamdir):
    """Write the image-ids (aws) product stream for each queried stream.

    Every region referenced by a stream's items gets an entry under
    '_aliases' -> 'crsn' with its endpoint and region name.  Returns the
    tree keyed by stream name.
    """
    tree = ec2_load_query(query_tree)
    ts = util.timestamp()
    default_endpoint = EC2_ENDPOINTS['default']

    for stream, content in tree.items():
        cid = "com.ubuntu.cloud:%s:aws" % stream

        # now add the '_alias' data
        regions = set()

        def findregions(item, tree, pedigree):
            regions.add(item['crsn'])

        util.walk_products(content, cb_item=findregions)

        crsn = {}
        content['_aliases'] = {'crsn': crsn}
        for region in regions:
            # a region listed in EC2_ENDPOINTS overrides the default template
            endpoint = EC2_ENDPOINTS.get(region, default_endpoint % region)
            crsn[region] = {'endpoint': endpoint, 'region': region}

        content['format'] = "products:1.0"
        content['datatype'] = "image-ids"
        content['updated'] = ts
        content['content_id'] = cid

        outfile = os.path.join(out_d, stream, streamdir, cid + ".json")
        util.mkdir_p(os.path.dirname(outfile))
        with open(outfile, "w") as fp:
            sys.stderr.write("writing %s\n" % outfile)
            fp.write(json.dumps(content, indent=1) + "\n")

    return tree


def main():
    """Command line entry point: generate the example content tree.

    Creates the download streams, then (unless --skip-aws) the aws streams
    and a per-stream index.json, and finally signs every generated .json
    file when --sign is given.
    """
    parser = argparse.ArgumentParser(description="create example content tree")
    parser.add_argument("query_tree", metavar='query_tree',
                        help=('read in content from /query tree. Hint: '
                              'make exdata-query'))
    parser.add_argument("out_d", metavar='out_d',
                        help=('create content under output_dir'))
    parser.add_argument('--skip-aws', action='store_true',
                        default=False, help='Skip AWS generation')
    parser.add_argument('--sign', action='store_true', default=False,
                        help='sign all generated files')
    args = parser.parse_args()

    streamdir = "streams/v1"

    dltree = create_image_data(args.query_tree, args.out_d, streamdir)

    if not args.skip_aws:
        aws_tree = create_aws_data(args.query_tree, args.out_d, streamdir)

        for streamname in aws_tree:
            index = {"index": {}, 'format': 'index:1.0',
                     'updated': util.timestamp()}
            entries = index['index']

            aws_stream = aws_tree[streamname]
            aws_cid = aws_stream['content_id']
            entries[aws_cid] = {
                'updated': aws_stream['updated'],
                'datatype': aws_stream['datatype'],
                'clouds': list(aws_stream['_aliases']['crsn'].values()),
                'cloudname': "aws",
                'path': '/'.join((streamdir, "%s.json" % aws_cid)),
                'products': sorted(aws_stream['products']),
                'format': aws_stream['format'],
            }

            dl_stream = dltree[streamname]
            dl_cid = dl_stream['content_id']
            entries[dl_cid] = {
                'updated': dl_stream['updated'],
                'datatype': dl_stream['datatype'],
                'path': '/'.join((streamdir, "%s.json" % dl_cid)),
                'products': sorted(dl_stream['products']),
                'format': dl_stream['format']
            }

            outfile = os.path.join(
                args.out_d, streamname, streamdir, 'index.json'
            )
            util.mkdir_p(os.path.dirname(outfile))
            with open(outfile, "wb") as fp:
                sys.stderr.write("writing %s\n" % outfile)
                fp.write(util.dump_data(index) + b"\n")

    if args.sign:
        def printstatus(name):
            sys.stderr.write("signing %s\n" % name)

        for root, dirs, files in os.walk(args.out_d):
            for fname in files:
                if not fname.endswith(".json"):
                    continue
                sign_helper.signjson_file(os.path.join(root, fname),
                                          status_cb=printstatus)

    return


# Run main() only when executed as a script.  main() returns None, so the
# exit status is 0 unless an exception propagates.
if __name__ == '__main__':
    sys.exit(main())

# vi: ts=4 expandtab
