import sys
import random
import time
import argparse
import pickle

import numpy as np
import Levenshtein

import chainer
from chainer import Variable, cuda, serializers
from chainer import functions as F

from model import ForwardBackward
from data import Task1File, G2pFile, ASJPFile, load_language_embeddings

def main():
    """Train a ForwardBackward model from command-line arguments.

    Steps, in order:

    * parse the command line and load the language embeddings;
    * read the training and development files, picking the reader class
      from substrings of the file names ('pron_data' -> G2pFile,
      'asjp' -> ASJPFile, anything else -> Task1File);
    * pickle the arguments, alphabet and feature list to
      ``<model>.pickle``;
    * build the model and optimizer, then train on randomly sampled
      batches in an endless loop.  Every ``--test-every`` batches the
      development set is decoded and scored; one line is appended to
      ``<model>.log``, and the weights are written to
      ``<model>.best-loss.npz`` / ``<model>.npz`` whenever the development
      loss / mean Levenshtein distance improves.

    The training loop has no stopping criterion: the function never returns
    normally and ends only when an exception (e.g. KeyboardInterrupt)
    propagates out of it.

    Raises:
        ValueError: if a language in the data cannot be matched to an entry
            of the loaded language list (see ``get_lcode``).
    """
    parser = argparse.ArgumentParser(
            description='ForwardBackward model training')
    parser.add_argument('--train', nargs='+', metavar='FILE', required=True)
    parser.add_argument('--dev', type=str, metavar='FILE', required=True)
    parser.add_argument('--wordlist', type=str, metavar='FILE',
                        help='file with a list of word forms to use for '
                             'semi-supervised training')
    # The model path is the prefix of every output file written below, so a
    # missing value is rejected here instead of failing later on
    # ``None + '.pickle'``.
    parser.add_argument('--model', type=str, metavar='FILE', required=True,
                        help='model file for saving')
    parser.add_argument('--optimizer', type=str, default='sgd',
                        choices=('sgd', 'momentum', 'adam'))
    parser.add_argument('--batch-size', type=int, metavar='N', default=64)
    parser.add_argument('--encoder-size', type=int, metavar='N', default=128)
    parser.add_argument('--decoder-size', type=int, metavar='N', default=128)
    parser.add_argument('--attention-size', type=int, metavar='N', default=128)
    parser.add_argument('--embedding-size', type=int, metavar='N', default=32)
    parser.add_argument('--features-size', type=int, metavar='N', default=128)
    parser.add_argument('--learning-rate', type=float, metavar='X',
                        default=0.1)
    parser.add_argument('--recurrent-dropout', type=float, metavar='X',
                        default=0.2)
    parser.add_argument('--dropout', type=float, metavar='X', default=0.2)
    parser.add_argument('--gpu', type=int, metavar='N', default=-1,
                        help='GPU to use (-1 to CPU)')
    parser.add_argument('--random-seed', type=int, metavar='N', default=123)
    parser.add_argument('--test-every', type=int, metavar='N', default=250)
    parser.add_argument('--bidirectional', action='store_true')
    parser.add_argument('--language-embeddings', action='store_true')

    args = parser.parse_args()

    languages, embeddings = load_language_embeddings()
    lemb_dim = embeddings[0].size
    enable_unsupervised = args.wordlist is not None

    # The reader class is chosen from a substring of the first training
    # file name; all training files are read with the same class.
    if 'pron_data' in args.train[0]:
        train_data_list = [G2pFile(fname, languages) for fname in args.train]
    elif 'asjp' in args.train[0]:
        train_data_list = [ASJPFile(fname, languages) for fname in args.train]
    else:
        train_data_list = [Task1File(fname) for fname in args.train]

    # Merge every training file into the one popped first (the last one
    # given on the command line).
    train_data = train_data_list.pop()
    while train_data_list:
        train_data.merge(train_data_list.pop())

    if 'pron_data' in args.dev:
        dev_data = G2pFile(args.dev, languages)
    elif 'asjp' in args.dev:
        dev_data = ASJPFile(args.dev, languages)
    else:
        dev_data = Task1File(args.dev)

    if enable_unsupervised:
        with open(args.wordlist, 'r', encoding='utf-8') as f:
            wordlist = f.read().split()

    # Indices 0, 1 and 2 are reserved for the special symbols.
    alphabet = ['<S>', '</S>', '<UNK>'] + train_data.get_alphabet()
    features = train_data.get_features()
    alphabet_idx = {c:i for i,c in enumerate(alphabet)}
    features_idx = {c:i for i,c in enumerate(features)}

    print('Training data size: %d' % len(train_data.data))
    print('Development data size: %d' % len(dev_data.data))
    if enable_unsupervised:
        print('Unsupervised data size: %d' % len(wordlist))
        # Keep only words made entirely of symbols seen in training.
        wordlist = [word for word in wordlist
                    if all(c in alphabet_idx for c in word)]
        print('Keeping: %d' % len(wordlist))
    else:
        print('No unsupervised data')

    unk = alphabet_idx['<UNK>']
    bos = alphabet_idx['<S>']
    eos = alphabet_idx['</S>']

    # Three consecutive pickles in one file: args, alphabet, features.
    with open(args.model + '.pickle', 'wb') as f:
        pickle.dump(args, f)
        pickle.dump(alphabet, f)
        pickle.dump(features, f)

    random.seed(args.random_seed)
    gpu = args.gpu

    model = ForwardBackward(
            alphabet, features, (languages, embeddings),
            embedding_size=args.embedding_size,
            encoder_size=args.encoder_size,
            decoder_size=args.decoder_size,
            features_size=args.features_size,
            language_embedding_size=lemb_dim,
            attention_size=args.attention_size,
            dropout=args.dropout,
            recurrent_dropout=args.recurrent_dropout,
            bidirectional=args.bidirectional,
            use_lembs=args.language_embeddings)

    if gpu >= 0:
        model.to_gpu(gpu)

    xp = model.xp

    def tokenize(s, split):
        """Split `s` into symbols.

        Whitespace-separated tokens are used only when `split` is set and
        the development file name contains 'pron'; otherwise every
        character is a symbol.
        """
        if 'pron' in args.dev and split:
            return s.split()
        return list(s)

    def encode_symbols(tokens):
        """Map symbols to alphabet indices, wrapped in <S> ... </S>."""
        return [alphabet_idx.get(c, unk)
                for c in ['<S>'] + tokens + ['</S>']]

    def encode_source(batch, split=False):
        """Encode each string of `batch` as its own int32 index Variable."""
        return [
            Variable(xp.array(encode_symbols(tokenize(s, split)),
                              dtype=xp.int32))
            for s in batch]

    def encode_target(batch, split=True):
        """Encode `batch` as one int32 matrix Variable, padded with -1."""
        token_seqs = [tokenize(s, split) for s in batch]
        # Pad to the longest *tokenised* sequence so that whitespace-split
        # data is not padded out to its raw character length.
        max_len = max(map(len, token_seqs))
        return Variable(xp.array([
            encode_symbols(tokens) + [-1]*(max_len-len(tokens))
            for tokens in token_seqs],
            dtype=xp.int32))

    def encode_features(batch, dtype=xp.float32):
        """Encode each feature collection as a 0/1 vector over `features`."""
        return Variable(xp.array([
            [int(f in feats) for f in features]
            for feats in batch], dtype=dtype))

    def get_lcode(lang, code=True):
        """Resolve `lang` to an entry '<lang>-<script>' of `languages`.

        The scripts 'latn', 'cyrl' and 'grek' are tried in that order.
        Returns the position of the entry in `languages` when `code` is
        true, otherwise the entry name itself.

        Raises:
            ValueError: if no '<lang>-<script>' entry exists.
        """
        for script in ('latn', 'cyrl', 'grek'):
            name = lang + '-' + script
            if name in languages:
                return languages.index(name) if code else name
        raise ValueError(
            'language %r not found in the language list with any of the '
            'scripts latn, cyrl, grek' % (lang,))

    def encode_language(batch):
        """Encode language names as a column of int32 language indices."""
        return Variable(xp.array([
            [get_lcode(lang)]
            for lang in batch], dtype=xp.int32))

    def decoder_loss(state, target, loss=0.0):
        """Add the teacher-forced cross-entropy over `target` to `loss`.

        `target` is an index matrix as produced by `encode_target`; each
        column is fed to the decoder state and the following column is the
        prediction target.
        """
        steps = F.transpose(target)
        for c_tm1, c_t in zip(steps, steps[1:]):
            state = state(c_tm1)
            loss += F.softmax_cross_entropy(state.p, c_t, normalize=False)
        return loss

    def translate(source, feats, langs, max_length=50):
        """Greedily decode `source`, returning one index list per item."""
        batch_size = len(source)
        target = []
        state = model.forward(source, feats, langs)
        c_t = Variable(xp.array([bos]*batch_size, dtype=xp.int32))
        # 1 while a sequence has not yet produced </S>, 0 afterwards.
        alive = xp.ones((batch_size,), dtype=xp.int32)
        while any(cuda.to_cpu(alive)) and len(target) <= max_length:
            c_tm1 = c_t
            state = state(c_tm1)
            c_t = F.argmax(state.p, axis=1)
            alive *= (c_t.data != eos)
            target.append(c_t.data * alive)
        # Zero entries (finished positions, and index 0 = '<S>') are dropped.
        return [[int(c) for c in seq if c]
                for seq in cuda.to_cpu(xp.hstack([t[:,None] for t in target]))]

    if args.optimizer == 'sgd':
        optimizer = chainer.optimizers.SGD(lr=args.learning_rate)
    elif args.optimizer == 'momentum':
        optimizer = chainer.optimizers.MomentumSGD(lr=args.learning_rate)
    elif args.optimizer == 'adam':
        optimizer = chainer.optimizers.Adam()
    else:
        # Unreachable through argparse `choices`; kept as an explicit guard
        # that is not stripped under `python -O`.
        raise ValueError('unknown optimizer: %r' % (args.optimizer,))
    optimizer.use_cleargrads()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(5.0))

    # Development items sorted by decreasing length of their first field.
    dev_batch = list(dev_data.data)
    dev_batch.sort(key=lambda t: -len(t[0]))

    train_supervised = train_data.data
    if enable_unsupervised:
        train_unsupervised = [(None,word,None) for word in wordlist]

    # Largest amount by which a target is longer than its source; only the
    # unsupervised branch below uses it (to bound the decoding length).
    max_diff = max(len(trg)-len(src) for src,trg,_,_ in train_supervised)

    # Report once which language entries occur in the training data.
    print(', '.join(sorted(
        {get_lcode(l, code=False) for _,_,_,l in train_supervised})))

    best_dev = float('inf')
    best_levenshtein = float('inf')
    batch_size = args.batch_size
    supervised = True
    n_batches = 0
    n_supervised = 0

    # The log is appended to and stays open for the whole run; the context
    # manager closes it when the loop is left through an exception.
    with open(args.model + '.log', 'a') as logf:
        while True:
            # NOTE(review): this runs on every batch (n_batches counts
            # batches, despite the '_epoch_' file name), producing one file
            # per batch -- confirm this is intended.
            if args.language_embeddings:
                print('Saving lembeds')
                # Copy to host memory first: the weights are a device array
                # when the model was moved to a GPU.
                np.save(args.model+'_epoch_{0}'.format(n_batches),
                        cuda.to_cpu(model.l_embeddings.W.array))

            # The switch that would alternate supervised and unsupervised
            # batches is disabled, so `supervised` stays True:
            # if enable_unsupervised and n_supervised > 100000:
            #     supervised = not supervised
            # NOTE(review): the unsupervised branches below are therefore
            # unreachable. They are also inconsistent with the rest of the
            # loop (items are 3-tuples while the batch encoding unpacks
            # 4-tuples, and b_langs / f_langs are only assigned for
            # supervised batches) and need rework before being re-enabled.

            if supervised:
                n_supervised += batch_size

            batch = random.sample(
                    train_supervised if supervised else train_unsupervised,
                    batch_size)

            # Backward direction: field 1 is the input, field 0 the output.
            # Sorted by decreasing input length (stable sort).
            b_batch = sorted(batch, key=lambda t: -len(t[1]))
            b_source = encode_source([trg for _,trg,_,_ in b_batch],
                                     split=False)
            if supervised:
                b_target = encode_target([src for src,_,_,_ in b_batch],
                                         split=True)
                b_feats = encode_features([f for _,_,f,_ in b_batch],
                                          dtype=xp.int32)
                b_langs = encode_language([l for _,_,_,l in b_batch])

            model.cleargrads()
            t0 = time.time()

            state = model.backward(b_source, b_langs)
            if supervised:
                b_feats_loss = F.sigmoid_cross_entropy(
                        state.p_features, b_feats)
                b_loss = decoder_loss(state, b_target, b_feats_loss)
                b_loss.backward()
            else:
                # TODO: discretize?
                f_feats = F.sigmoid(state.p_features)

                c_t = Variable(xp.array([bos]*batch_size, dtype=xp.int32))
                max_length = max_diff + 1  + max(len(trg) for _,trg,_ in b_batch)
                target = []
                for _ in range(max_length):
                    c_tm1 = c_t
                    state = state(c_tm1)
                    c_t = F.argmax(state.p, axis=1)
                    target.append(c_t)
                target.append(Variable(xp.array([eos]*batch_size, dtype=xp.int32)))
                b_target = F.transpose_sequence(target)
                # Cut every decoded sequence after its first </S>.
                b_target = [x[:xp.min(xp.flatnonzero(x.data == eos))+1]
                            for x in b_target]

            if supervised:
                # Forward direction: field 0 is the input, field 1 the
                # output. Sorted by decreasing input length (stable sort).
                f_batch = sorted(batch, key=lambda t: -len(t[0]))

                f_source = encode_source([src for src,_,_,_ in f_batch])
                f_feats = encode_features([f for _,_,f,_ in f_batch])
                f_langs = encode_language([l for _,_,_,l in f_batch])
                f_target = encode_target([trg for _,trg,_,_ in f_batch])
            else:
                f_source_idx = sorted(
                        enumerate(b_target), key=lambda x: -x[1].shape[0])
                f_source = [x for _,x in f_source_idx]
                f_target = encode_target(
                        [b_batch[idx][1] for idx,_ in f_source_idx])

            state = model.forward(f_source, f_feats, f_langs)
            f_loss = decoder_loss(state, f_target)
            f_loss.backward()

            optimizer.update()

            if supervised:
                print('TRAIN SUPERVISED %d %.3f %.3f %.3f %.3f' % (
                        n_batches, time.time() - t0,
                        f_loss.data, b_loss.data, b_feats_loss.data),
                    flush=True)
            else:
                print('TRAIN UNSUPERVISED %d %.3f %.3f' %
                        (n_batches, time.time() - t0, f_loss.data),
                      flush=True)

            n_batches += 1
            if n_batches % args.test_every == 0:
                with chainer.using_config('train', False):
                    dev_source = encode_source(
                            [src for src,_,_,_ in dev_batch])
                    dev_target = encode_target(
                            [trg for _,trg,_,_ in dev_batch])
                    dev_feats = encode_features(
                            [f for _,_,f,_ in dev_batch])
                    dev_langs = encode_language(
                            [l for _,_,_,l in dev_batch])

                    translated = [
                        ''.join(alphabet[x] for x in pred)
                        for pred in translate(
                            dev_source, dev_feats, dev_langs)]
                    for (_,trg,_,lang), pred in zip(dev_batch, translated):
                        print(trg, pred, lang, flush=True)
                    n_correct = sum(
                            trg == pred
                            for (_,trg,_,_),pred in zip(dev_batch, translated))
                    l_sum = sum(
                            Levenshtein.distance(trg, pred)
                            for (_,trg,_,_),pred in zip(dev_batch, translated))
                    accuracy = n_correct / len(translated)
                    l_mean = l_sum / len(translated)
                    print('ACCURACY %d %.2f %.3f' % (
                        n_batches, 100.0*accuracy, l_mean),
                          flush=True)

                    state = model.forward(dev_source, dev_feats, dev_langs)
                    dev_loss = decoder_loss(state, dev_target)

                    print('%d %.6f %d %.3f %.4f' % (
                            n_batches, dev_loss.data, n_correct,
                            100.0*accuracy, l_mean),
                          file=logf, flush=True)

                    print('DEV %d %.3f' % (n_batches, dev_loss.data),
                          flush=True)
                    # Two checkpoints: best development loss, and best mean
                    # Levenshtein distance.
                    if float(dev_loss.data) < best_dev:
                        best_dev = float(dev_loss.data)
                        serializers.save_npz(
                                args.model + '.best-loss.npz', model)
                    if l_mean < best_levenshtein:
                        best_levenshtein = l_mean
                        serializers.save_npz(args.model + '.npz', model)

# Run the training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
