#!/usr/bin/env python

from __future__ import print_function, division

import os
import sys
import itertools

# Upper bound on the number of input bytes placed in a single C++ string
# literal: write_assets starts a new std::string(...) chunk at the end of a
# line once the running count gets within MAX_LINE_LENGTH of this value.
# (Presumably chosen to respect compiler string-literal limits -- TODO confirm.)
MAX_LITERAL_SIZE = 65535
# Output column after which write_assets closes the current quoted line.
MAX_LINE_LENGTH = 82

def main():
    """Generate the C++ asset table for the directory named in sys.argv[1].

    Walks the directory recursively, then prints the C++ prelude followed by
    the encoded assets (via write_assets) to stdout.

    Exits with status 1, after printing a message to stderr, when the first
    command-line argument is missing or is not a directory.
    """
    # Validate explicitly: an `assert` would be stripped under `python -O`,
    # and a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    if len(sys.argv) < 2 or not os.path.isdir(sys.argv[1]):
        print("Error: First argument must be a directory", file=sys.stderr)
        sys.exit(1)

    assets_root = sys.argv[1]

    # List all the files in assets_root, as paths relative to it

    assets = [
        os.path.relpath(os.path.join(root, path), assets_root)
        for root, __, paths in os.walk(assets_root)
        for path in paths
    ]

    # Write the encoded files and an index to stdout

    print(prelude)

    write_assets(assets_root, assets)

# Text printed at the top of the generated C++ output: a leading blank line,
# a "generated file" notice, and the headers the asset map needs.
prelude = "\n".join([
    "",
    "// Generated by scripts/compile-web-assets.py",
    "// See CONTRIBUTING.md for information on how to generate this",
    "// or see https://github.com/rethinkdb/rethinkdb/tree/old_admin",
    "",
    "#include <map>",
    "#include <string>",
    "",
])

def write_assets(asset_root, unsorted_assets):
    """Print a C++ map initializer that embeds the contents of every asset.

    asset_root      -- directory the asset paths are relative to
    unsorted_assets -- iterable of asset paths relative to asset_root

    For each asset (in sorted order) one map entry is printed to stdout: the
    key is '/' + the relative path, the value is one or more
    std::string("...", length) chunks joined with '+'. Every chunk carries an
    explicit byte count (presumably so embedded NUL bytes survive -- TODO
    confirm against the consuming C++ code).
    """
    print('std::map<std::string, const std::string> static_web_assets = {')
    for asset in sorted(unsorted_assets):

        print('    { ' + encode('/' + asset) + ', std::string(', end='')

        # Use a context manager so the handle is closed right away instead of
        # leaking one open file per asset until garbage collection.
        with open(os.path.join(asset_root, asset), "rb") as asset_file:
            data = asset_file.read()

        position = 0 # track the position to keep lines short
        trigraph = 0 # track consecutive question marks to avoid writing trigraphs
        prev_e = None # track the previous character to avoid tacking on hex digits
        literal_size = 0 # number of characters in the current string literal

        for c in data:
            c = byte(c)
            literal_size += 1

            if position == 0:
                # start a new quoted line; state from the previous quoted
                # line does not carry over into a fresh literal token
                print('\n      "', end='')
                position = 7
                trigraph = 0
                prev_e = None

            if trigraph >= 2 and c in b"=/'()!<>-":
                # split a trigraph by closing and reopening the literal
                print('" "', end='')
                prev_e = None
                position += 3
                trigraph = 0
            elif c == b'?':
                # count the amount of question marks
                trigraph += 1
            else:
                trigraph = 0

            e = encode_char(c, prev_e)
            prev_e = e
            print(e, end='')
            position += len(e)

            if position > MAX_LINE_LENGTH or c == b'\n':
                # end a line if it gets too long and on newlines
                print('"', end='')
                position = 0
                if literal_size > MAX_LITERAL_SIZE - MAX_LINE_LENGTH:
                    # the current std::string chunk is nearly full: close it
                    # with its byte count and start another one
                    print(',')
                    print('      ' + str(literal_size) + ') + std::string(', end='')
                    literal_size = 0

        if position != 0:
            # close the quoted line that was still open when the data ended
            print('"', end='')

        if not data:
            print('""', end='')

        if literal_size:
            print(',')
            print('      ' + str(literal_size) + ' ) },')
        else:
            # nothing pending in the current chunk (empty file, or the data
            # ended exactly on a chunk boundary): emit an empty literal
            print('"") },')

    print('};')

def encode(string):
    """Return `string` as a double-quoted C++ string literal.

    The string is encoded to UTF-8 and every byte is passed through
    encode_char. The previously emitted piece is threaded through as
    `previous` so that a hex digit directly following a hex/octal escape is
    itself escaped; otherwise e.g. the bytes of u'\\xe9' followed by '1'
    would come out as "\\xc3\\xa91", which C++ reads as one over-long escape.
    """
    pieces = ['"']
    previous = None
    for raw in string.encode('utf-8'):
        previous = encode_char(byte(raw), previous)
        pieces.append(previous)
    pieces.append('"')
    return ''.join(pieces)

# Bytes that have a dedicated backslash escape in a C string literal,
# mapped to the text of that escape.
c_escapes = dict(
    (raw, '\\' + mnemonic)
    for raw, mnemonic in (
        (b'\a', 'a'),
        (b'\b', 'b'),
        (b'\f', 'f'),
        (b'\n', 'n'),
        (b'\r', 'r'),
        (b'\t', 't'),
        (b'\v', 'v'),
        (b'\\', '\\'),
        (b'"', '"'),
    )
)

# The input character c should be a single character bytes
# The return value is an ascii-compatible unicode string
def encode_char(c, previous=None):
    """Return the C string-literal spelling of the single byte `c`.

    `previous` is the text emitted for the preceding byte, if any. When it
    was a hex or octal escape, a following hex digit is written as a hex
    escape too, so it cannot be read as part of that earlier escape.
    """
    after_numeric_escape = bool(
        previous and previous[0] == '\\' and previous[1] in 'x01234567'
    )
    code = ord(c)

    mnemonic = c_escapes.get(c)
    if mnemonic is not None:
        return mnemonic

    if code < 8:
        # short octal escape for the lowest control codes
        return '\\' + str(code)

    hex_digits = b'0123456789abcdefABCDEF'
    is_printable = 32 <= code < 127
    if is_printable and not (after_numeric_escape and c in hex_digits):
        return c.decode('ascii')

    return '\\x' + format(code, 'x')

# byte() turns one element obtained by iterating over binary data into a
# length-1 bytes/str object, on both Python 2 and Python 3.
#
# Compare the numeric major version: a string comparison such as
# sys.version < '3' mis-sorts a two-digit major version ('10.0' < '3').
if sys.version_info[0] < 3:
    def byte(b):
        """Python 2: iterating a str already yields 1-character strings."""
        return b
else:
    def byte(b):
        """Python 3: wrap the int yielded by iterating bytes in a 1-byte bytes."""
        return bytes([b])

# Run the generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
