import sys
import random
import time
import argparse
import pickle
from pprint import pprint

import numpy as np
import Levenshtein

import chainer
from chainer import Variable, cuda, serializers
from chainer import functions as F

from model import RNNTransducer
from data import Task1File

def align(s1, s2):
    """Distribute the characters of *s2* over the character slots of *s1*.

    Returns a list of ``len(s1) + 2`` lists of characters taken from *s2*,
    derived from the edit script ``Levenshtein.editops(s1, s2)``:

    * ``parts[0]`` -- characters inserted before ``s1[0]``;
    * ``parts[k]`` for ``1 <= k <= len(s1)`` -- the copy or replacement of
      ``s1[k-1]`` (nothing if it was deleted); characters inserted between
      ``s1[k-2]`` and ``s1[k-1]`` are placed at the front of this slot;
    * ``parts[-1]`` -- characters inserted after the end of *s1*.

    Concatenating the parts in order yields *s2*.  A final heuristic may
    move the contents of the last character's slot into ``parts[-1]``.

    Identical strings take a fast path that returns the same structure (one
    single-character list per character) without computing editops.
    """
    if s1 == s2:
        # Every character maps to itself.  Wrap each one in its own list so
        # the return type matches the general path below (previously bare
        # characters were returned here).
        return [[]] + [[c] for c in s1] + [[]]

    ops = Levenshtein.editops(s1, s2)
    parts = [[] for _ in range(len(s1)+2)]
    # pos1: current slot index in `parts` (slot k <-> s1[k-1]);
    # pos2: next unconsumed character of s2.
    pos1 = pos2 = 0

    for op, i, j in ops:
        if i == 0 and op == 'insert':
            # Insertion before the first character of s1 goes to parts[0].
            parts[0].append(s2[pos2])
            pos2 += 1
            continue
        elif i == 0 and op == 'replace':
            # s1[0] is replaced: slot 1 receives the new character.  pos1 is
            # presumably still 0 here (only i == 0 inserts can precede this
            # op), so += 2 moves past slot 1.
            parts[1].append(s2[pos2])
            pos1 += 2
            pos2 += 1
            continue
        elif i == 0 and op == 'delete':
            # s1[0] is deleted: slot 1 stays empty.
            pos1 += 2
            continue
        elif pos1 == 0:
            # First operation past s1[0]: start copying into slot 1.
            pos1 += 1

        # Take care of plain copying
        while pos1 < i+1 and pos2 < j:
            parts[pos1].append(s2[pos2])
            pos1 += 1
            pos2 += 1

        assert pos1 == i+1 and pos2 == j, (s1, s2, pos1, pos2, i, j, ops)

        # Replacement is handled in the same way as copying (1:1 mapping)
        if op == 'replace':
            parts[pos1].append(s2[pos2])
            pos1 += 1
            pos2 += 1
        elif op == 'insert':
            # pos1 is not advanced, so the inserted character is attached
            # to the front of the slot of s1[i] (or parts[-1] at the end).
            parts[pos1].append(s2[pos2])
            pos2 += 1
        elif op == 'delete':
            # Leave the slot of s1[i] empty.
            pos1 += 1

    if pos1 == 0:
        pos1 += 1
    # Copy whatever is left after the last edit operation.
    while pos1 < len(s1)+1 and pos2 < len(s2):
        parts[pos1].append(s2[pos2])
        pos1 += 1
        pos2 += 1

    assert pos1 == len(s1)+1 and pos2 == len(s2)

    # Heuristic to treat suffixes preceded by a single deleted character as
    # deletion + suffix: if nothing was appended after s1 but the slot of
    # the last s1 character holds 2+ characters, move them to the end slot.
    if len(parts[-1]) == 0 and len(parts[-2]) >= 2:
        parts[-1] = parts[-2]
        parts[-2] = []

    return parts


def main():
    """Parse command-line options and train an RNNTransducer.

    Outputs (all file names prefixed by ``--model``):

    * ``.pickle`` -- the parsed arguments, the alphabet and the feature list,
      pickled one after another in that order;
    * ``.log`` -- one line per dev evaluation, opened in append mode;
    * ``.best-loss.npz`` -- parameters with the lowest dev loss seen so far;
    * ``.npz`` -- parameters with the lowest mean dev Levenshtein distance
      seen so far.

    There is no stopping criterion: training loops until the process is
    interrupted, evaluating on the full dev set every ``--test-every``
    batches.
    """
    parser = argparse.ArgumentParser(
            description='RNNTransducer model training')
    parser.add_argument('--train', type=str, metavar='FILE', required=True)
    parser.add_argument('--dev', type=str, metavar='FILE', required=True)
    # Every output file name is derived from --model; without it the script
    # would crash on ``None + '.pickle'``, so make the requirement explicit.
    parser.add_argument('--model', type=str, metavar='FILE', required=True,
                        help='model file for saving')
    parser.add_argument('--optimizer', type=str, default='sgd',
                        choices=('sgd', 'momentum', 'adam'))
    parser.add_argument('--batch-size', type=int, metavar='N', default=64)
    parser.add_argument('--encoder-size', type=int, metavar='N', default=128)
    parser.add_argument('--decoder-size', type=int, metavar='N', default=128)
    parser.add_argument('--embedding-size', type=int, metavar='N', default=32)
    parser.add_argument('--features-size', type=int, metavar='N', default=128)
    parser.add_argument('--learning-rate', type=float, metavar='X',
                        default=0.1)
    # NOTE(review): --recurrent-dropout is parsed (and pickled) but is not
    # passed to RNNTransducer below.
    parser.add_argument('--recurrent-dropout', type=float, metavar='X',
                        default=0.2)
    parser.add_argument('--dropout', type=float, metavar='X', default=0.2)
    parser.add_argument('--gpu', type=int, metavar='N', default=-1,
                        help='GPU to use (-1 to CPU)')
    parser.add_argument('--random-seed', type=int, metavar='N', default=123)
    parser.add_argument('--test-every', type=int, metavar='N', default=250)

    args = parser.parse_args()

    train_data = Task1File(args.train)
    dev_data = Task1File(args.dev)

    print('Training data size: %d' % len(train_data.data))
    print('Development data size: %d' % len(dev_data.data))

    alphabet = ['<S>', '</S>', '<UNK>'] + train_data.get_alphabet()
    features = train_data.get_features()
    alphabet_idx = {c: i for i, c in enumerate(alphabet)}
    unk = alphabet_idx['<UNK>']
    bos = alphabet_idx['<S>']
    eos = alphabet_idx['</S>']

    with open(args.model + '.pickle', 'wb') as f:
        pickle.dump(args, f)
        pickle.dump(alphabet, f)
        pickle.dump(features, f)

    # `random` drives batch sampling.  NumPy's global RNG is seeded too,
    # before the model is built, presumably covering Chainer's parameter
    # initialisation (TODO confirm against model.py).
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    gpu = args.gpu
    if gpu >= 0:
        # Make the chosen GPU current so arrays created with `xp` below live
        # on the same device as the model parameters (matters for gpu != 0).
        cuda.get_device_from_id(gpu).use()

    model = RNNTransducer(
            alphabet, features,
            embedding_size=args.embedding_size,
            encoder_size=args.encoder_size,
            decoder_size=args.decoder_size,
            features_size=args.features_size,
            dropout=args.dropout)

    if gpu >= 0:
        model.to_gpu(gpu)

    xp = model.xp

    def encode_source(source):
        """Return one int32 Variable per word: <S>, its characters, </S>."""
        return [
            Variable(xp.array(
                [alphabet_idx.get(c, unk) for c in ['<S>']+list(s)+['</S>']],
                dtype=xp.int32))
            for s in source]

    def encode_target(source, target):
        """Encode aligned output parts as a padded int32 matrix.

        Each part produced by ``align`` becomes one row: <S>, the part's
        characters, </S>.  Rows are right-padded with -1 *after* the id
        lookup, so padding stays -1 (the default ``ignore_label`` of
        F.softmax_cross_entropy) instead of being mapped to <UNK>.
        """
        aligned = [[bos] + [alphabet_idx.get(c, unk) for c in part] + [eos]
                   for s, t in zip(source, target)
                   for part in align(s, t)]
        max_len = max(map(len, aligned))

        return Variable(xp.array(
            [ids + [-1]*(max_len-len(ids)) for ids in aligned],
            dtype=xp.int32))

    def encode_features(batch, dtype=xp.float32):
        """Return a (len(batch), len(features)) 0/1 feature-indicator matrix."""
        return Variable(xp.array([
            [int(f in feats) for f in features]
            for feats in batch], dtype=dtype))

    def sequence_loss(source, feats, target):
        """Summed teacher-forced cross-entropy of ``target`` given the input.

        Column k of ``target`` is fed to the decoder to predict column k+1.
        As a prediction target, padding (-1) is ignored by
        F.softmax_cross_entropy; as a decoder input it is replaced by </S>
        so the model never receives a negative id.
        """
        state = model(source, feats)
        loss = 0.0
        columns = F.transpose(target)
        for c_tm1, c_t in zip(columns, columns[1:]):
            c_in = Variable(
                xp.where(c_tm1.data < 0, eos, c_tm1.data).astype(xp.int32))
            state = state(c_in)
            loss += F.softmax_cross_entropy(state.p, c_t, normalize=False)
        return loss

    def translate(source, feats, max_length=50):
        """Greedily decode every part; return one list of ids per part.

        Decoding runs for at most ``max_length + 1`` steps and stops early
        once every row has produced </S>.  After a row emits </S> its
        outputs are zeroed (id 0 is <S>) and all zeros are dropped, which
        also removes the </S> itself.
        """
        batch_size = sum(s.shape[0] for s in source)
        target = []
        state = model(source, feats)
        c_t = Variable(xp.array([bos]*batch_size, dtype=xp.int32))
        alive = xp.ones((batch_size,), dtype=xp.int32)
        while any(cuda.to_cpu(alive)) and len(target) <= max_length:
            state = state(c_t)
            c_t = F.argmax(state.p, axis=1)
            alive *= (c_t.data != eos)
            target.append(c_t.data * alive)
        return [[int(c) for c in seq if c]
                for seq in cuda.to_cpu(xp.hstack([t[:, None] for t in target]))]

    if args.optimizer == 'sgd':
        optimizer = chainer.optimizers.SGD(lr=args.learning_rate)
    elif args.optimizer == 'momentum':
        optimizer = chainer.optimizers.MomentumSGD(lr=args.learning_rate)
    elif args.optimizer == 'adam':
        # NOTE: Adam uses Chainer's default step size; --learning-rate is
        # not passed through here.
        optimizer = chainer.optimizers.Adam()
    else:
        # Unreachable via argparse `choices`; raise instead of `assert`
        # so the check survives `python -O`.
        raise ValueError('unknown optimizer: %s' % args.optimizer)
    optimizer.use_cleargrads()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(5.0))

    dev_batch = list(dev_data.data)
    # A list, so random.sample() and concatenation below are well defined.
    train_supervised = list(train_data.data)

    # This also runs `align` (and its internal assertions) over every
    # training and dev example before training starts.
    max_part = max(len(part) for src, trg, _ in train_supervised + dev_batch
                             for part in align(src, trg))

    print('Maximum length of parts: %d' % max_part)

    def evaluate(step):
        """Decode and score the whole dev set after ``step`` batches.

        Prints every (reference, prediction) pair and an ACCURACY line, and
        returns ``(n_correct, accuracy, mean_levenshtein, dev_loss)`` with
        ``dev_loss`` as a Python float.
        """
        # Evaluation mode, and no autograd graph: nothing here is
        # back-propagated.
        with chainer.using_config('train', False), \
                chainer.no_backprop_mode():
            dev_source = encode_source([src for src, _, _ in dev_batch])
            dev_target = encode_target([src for src, _, _ in dev_batch],
                                       [trg for _, trg, _ in dev_batch])
            dev_feats = encode_features([f for _, _, f in dev_batch])

            parts = [[alphabet[x] for x in pred]
                     for pred in translate(dev_source, dev_feats, max_part)]
            # Re-assemble words: each word owns len(src) + 2 consecutive
            # parts, matching the length of its encoded source.
            translated = []
            start = 0
            for s in dev_source:
                end = start + s.shape[0]
                translated.append(''.join(c for part in parts[start:end]
                                          for c in part))
                start = end

            references = [trg for _, trg, _ in dev_batch]
            for trg, pred in zip(references, translated):
                print(trg, pred, flush=True)
            n_correct = sum(trg == pred
                            for trg, pred in zip(references, translated))
            l_sum = sum(Levenshtein.distance(trg, pred)
                        for trg, pred in zip(references, translated))
            accuracy = n_correct / len(translated)
            l_mean = l_sum / len(translated)
            print('ACCURACY %d %.2f %.3f' % (
                step, 100.0*accuracy, l_mean),
                  flush=True)

            dev_loss = float(
                sequence_loss(dev_source, dev_feats, dev_target).data)
        return n_correct, accuracy, l_mean, dev_loss

    best_dev = float('inf')
    best_levenshtein = float('inf')
    # random.sample raises ValueError when asked for more items than exist.
    batch_size = min(args.batch_size, len(train_supervised))
    n_batches = 0
    with open(args.model + '.log', 'a') as logf:
        while True:
            batch = random.sample(train_supervised, batch_size)

            source = encode_source([src for src, _, _ in batch])
            target = encode_target([src for src, _, _ in batch],
                                   [trg for _, trg, _ in batch])
            feats = encode_features([f for _, _, f in batch])

            print(len(source), target.shape, feats.shape,
                  sum(len(src)+2 for src, _, _ in batch))

            model.cleargrads()

            t0 = time.time()
            loss = sequence_loss(source, feats, target)
            loss.backward()
            optimizer.update()

            print('TRAIN %d %.3f %.3f' % (
                    n_batches, time.time() - t0, float(loss.data)),
                  flush=True)

            n_batches += 1
            if n_batches % args.test_every != 0:
                continue

            n_correct, accuracy, l_mean, dev_loss = evaluate(n_batches)

            print('%d %.6f %d %.3f %.4f' % (
                    n_batches, dev_loss, n_correct, 100.0*accuracy, l_mean),
                  file=logf, flush=True)

            print('DEV %d %.3f' % (n_batches, dev_loss), flush=True)
            if dev_loss < best_dev:
                best_dev = dev_loss
                serializers.save_npz(args.model + '.best-loss.npz', model)
            if l_mean < best_levenshtein:
                best_levenshtein = l_mean
                serializers.save_npz(args.model + '.npz', model)


if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()

