import sys
import random
import time

import numpy as np

import chainer
from chainer import Variable, cuda
from chainer import functions as F

from model import EncoderDecoder
from data import Task1File


def main():
    """Train an encoder-decoder on one language's task-1 data, indefinitely.

    The language code is read from ``sys.argv[1]`` and selects the files
    ``../all/task1/<lang>-train-high`` and ``../all/task1/<lang>-dev``.
    Each example is unpacked as ``(src, trg, feats)``; the model input is
    ``feats + list(src)`` and the expected output is ``trg``.

    Every iteration samples a random training batch, performs one Adam step
    (gradients clipped at 5.0) and prints the step time and training loss.
    Every 200 batches the whole dev set is decoded greedily and the
    predictions, exact-match accuracy and dev loss are printed.

    The loop never terminates on its own; stop the process to end training.
    Exits with a message if no language argument is given or if the training
    file yields no examples.
    """
    if len(sys.argv) < 2:
        sys.exit('usage: %s LANG' % sys.argv[0])
    lang = sys.argv[1]
    train_data = Task1File('../all/task1/%s-train-high' % lang)
    dev_data = Task1File('../all/task1/%s-dev' % lang)
    if len(train_data.data) == 0:
        sys.exit('no training examples found for language %r' % lang)

    # Symbol inventory: special symbols first, then characters, then
    # morphological features.  '<S>' therefore always has index 0.
    alphabet = ['<S>', '</S>', '<UNK>'] + train_data.get_alphabet()
    alphabet = alphabet + train_data.get_features()
    alphabet_idx = {c: i for i, c in enumerate(alphabet)}
    unk = alphabet_idx['<UNK>']
    bos = alphabet_idx['<S>']
    eos = alphabet_idx['</S>']

    gpu = -1

    model = EncoderDecoder(alphabet)
    if gpu >= 0:
        model.to_gpu(gpu)

    xp = model.xp

    def to_ids(symbols):
        """Map symbols to ids, wrapped in '<S>' ... '</S>'; unknowns -> unk."""
        return [alphabet_idx.get(c, unk)
                for c in ['<S>'] + list(symbols) + ['</S>']]

    def source_symbols(examples):
        """Build the model input (features followed by source characters)."""
        return [feats + list(src) for src, _, feats in examples]

    def encode_source(batch):
        """Encode each source sequence as its own int32 Variable (no padding)."""
        return [Variable(xp.array(to_ids(s), dtype=xp.int32)) for s in batch]

    def encode_target(batch):
        """Encode targets as one (len(batch), max_len + 2) int32 Variable.

        Shorter rows are right-padded with -1, which is the default
        ``ignore_label`` of ``F.softmax_cross_entropy``, so padded positions
        are not counted as prediction targets.
        NOTE(review): padded -1 entries are also fed to the decoder as
        previous symbols; how the model treats them is not visible here.
        """
        max_len = max(map(len, batch))
        return Variable(xp.array(
            [to_ids(s) + [-1] * (max_len - len(s)) for s in batch],
            dtype=xp.int32))

    def sequence_loss(source, target):
        """Sum the teacher-forced cross-entropy over all target time steps."""
        # Transpose once so that iteration yields one time step per item.
        steps = F.transpose(target)
        loss = 0.0
        state = model(source)
        for c_tm1, c_t in zip(steps, steps[1:]):
            state = state(c_tm1)
            loss += F.softmax_cross_entropy(state.p, c_t, normalize=False)
        return loss

    def translate(batch, max_length=50):
        """Greedily decode ``batch`` and return one list of symbol ids each.

        Decoding stops when every sequence has produced '</S>' or after
        ``max_length`` symbols.  Id 0 ('<S>') doubles as the "finished"
        marker: outputs of finished sequences are multiplied by 0 and all
        zero entries are dropped from the result, so neither '</S>' nor a
        predicted '<S>' appears in the returned lists.
        """
        if not batch:
            return []
        source = encode_source(batch)
        target = []
        state = model(source)
        c_t = Variable(xp.array([bos] * len(batch), dtype=xp.int32))
        alive = xp.ones((len(batch),), dtype=xp.int32)
        # '<' (not '<=') so that at most max_length symbols are generated.
        while cuda.to_cpu(alive).any() and len(target) < max_length:
            c_tm1 = c_t
            state = state(c_tm1)
            c_t = F.argmax(state.p, axis=1)
            # Mark a sequence as finished at the step it emits '</S>'.
            alive *= (c_t.data != eos)
            target.append(c_t.data * alive)
        if not target:
            # Nothing was decoded (max_length <= 0); stacking would fail.
            return [[] for _ in batch]
        return [[int(c) for c in seq if c]
                for seq in cuda.to_cpu(xp.stack(target, axis=1))]

    optimizer = chainer.optimizers.Adam()
    optimizer.use_cleargrads()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(5.0))

    # Sort by descending input length (features + source characters).
    dev_batch = list(dev_data.data)
    dev_batch.sort(key=lambda t: -(len(t[2]) + len(t[0])))
    # The raw dev sequences never change, so build them once up front.
    dev_source_raw = source_symbols(dev_batch)
    dev_target_raw = [list(trg) for _, trg, _ in dev_batch]

    batch_size = 64
    n_batches = 0
    while True:
        # Cap the sample size: random.sample raises ValueError when asked
        # for more items than the population holds.
        batch = random.sample(
            train_data.data, min(batch_size, len(train_data.data)))
        batch.sort(key=lambda t: -(len(t[2]) + len(t[0])))

        source = encode_source(source_symbols(batch))
        target = encode_target([trg for _, trg, _ in batch])

        t0 = time.time()

        model.cleargrads()

        loss = sequence_loss(source, target)

        loss.backward()
        optimizer.update()

        print('TIME %.3f' % (time.time() - t0), flush=True)

        print('TRAIN', loss.data, flush=True)

        n_batches += 1
        # Skip evaluation entirely when there is no dev data; encoding an
        # empty batch and dividing by its size would both fail.
        if n_batches % 200 == 0 and dev_batch:
            # Evaluation needs no gradients, so do not retain the graph.
            with chainer.using_config('train', False), \
                    chainer.no_backprop_mode():
                dev_source = encode_source(dev_source_raw)
                dev_target = encode_target(dev_target_raw)

                translated = [''.join(alphabet[x] for x in pred)
                              for pred in translate(dev_source_raw)]
                for trg, pred in zip(dev_target_raw, translated):
                    print(''.join(trg), pred, flush=True)
                n_correct = sum(
                    ''.join(trg) == pred
                    for trg, pred in zip(dev_target_raw, translated))
                accuracy = n_correct / len(translated)
                print('Accuracy: %.2f%%' % (100.0 * accuracy), flush=True)

                dev_loss = sequence_loss(dev_source, dev_target)

                print('DEV', dev_loss.data, flush=True)

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

