import sys
import random
import time
import argparse
import pickle
from collections import namedtuple

import numpy as np

import chainer
from chainer import Variable, cuda, serializers
from chainer import functions as F

from model import ForwardBackward
from data import Task1File

# One beam-search hypothesis: the per-model decoder states, the cumulative
# (ensemble-averaged) log-probability score, the list of emitted token-id
# arrays, and whether it may still be extended (False once </S> is emitted).
Hypothesis = namedtuple(
    'Hypothesis', ['states', 'score', 'history', 'alive'])


def _load_models(prefixes, gpu):
    """Load an ensemble of ForwardBackward models.

    Each entry of *prefixes* names two files: ``<prefix>.pickle``, holding
    three pickled objects (the training argument namespace, the alphabet and
    the feature list, in that order), and ``<prefix>.npz``, holding the
    serialized model parameters.

    Args:
        prefixes: Iterable of file-name prefixes, one per model.
        gpu: GPU device id to move every model to, or a negative value to
            keep the models on the CPU.

    Returns:
        A ``(models, alphabet, features)`` tuple.

    Raises:
        ValueError: If a model's alphabet or feature list differs from that
            of the models loaded before it (their output distributions are
            averaged, so they must agree).
    """
    models = []
    alphabet = None
    features = None
    for prefix in prefixes:
        print('Loading model %s...' % prefix, file=sys.stderr, flush=True)
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # model files.
        with open(prefix + '.pickle', 'rb') as f:
            model_args = pickle.load(f)
            model_alphabet = pickle.load(f)
            model_features = pickle.load(f)
        if alphabet is not None and alphabet != model_alphabet:
            raise ValueError(
                'model %s uses a different alphabet than the previously '
                'loaded models' % prefix)
        if features is not None and features != model_features:
            raise ValueError(
                'model %s uses a different feature set than the previously '
                'loaded models' % prefix)
        alphabet = model_alphabet
        features = model_features

        model = ForwardBackward(
                alphabet, features,
                embedding_size=model_args.embedding_size,
                encoder_size=model_args.encoder_size,
                decoder_size=model_args.decoder_size,
                features_size=model_args.features_size,
                attention_size=model_args.attention_size,
                dropout=model_args.dropout,
                recurrent_dropout=model_args.recurrent_dropout,
                # Argument namespaces pickled without a `bidirectional`
                # attribute fall back to a unidirectional encoder.
                bidirectional=getattr(model_args, 'bidirectional', False))

        serializers.load_npz(prefix + '.npz', model)

        if gpu >= 0:
            model.to_gpu(gpu)

        models.append(model)
    return models, alphabet, features


def _hypothesis_text(token_ids, finished, alphabet):
    """Render a decoded token-id sequence as a string.

    *token_ids* starts with the ``<S>`` id. When *finished* is true its last
    element is the ``</S>`` id. Both markers are dropped. An unfinished
    hypothesis (beam search stopped at ``max_length``) keeps its final
    symbol, because it has no end marker to strip.
    """
    body = token_ids[1:-1] if finished else token_ids[1:]
    return ''.join(alphabet[i] for i in body)


def main():
    """Decode a Task1File with an ensemble of models and report accuracy.

    For every example, writes a tab-separated ``source, best prediction,
    features`` line to stdout and a verbose trace (including the whole final
    beam) to stderr. When gold targets are present, summary statistics
    (accuracy, in-beam accuracy and mean reciprocal rank) go to stderr.
    """
    parser = argparse.ArgumentParser(
            description='ForwardBackward model testing')

    parser.add_argument('--model', type=str, metavar='FILE', required=True,
                        help='model file(s) to load, comma-separated list')
    parser.add_argument('--input', type=str, metavar='FILE', required=True)
    parser.add_argument('--beam-size', type=int, metavar='N', default=4)
    parser.add_argument('--gpu', type=int, metavar='N', default=-1,
                        help='GPU to use (-1 to CPU)')

    args = parser.parse_args()

    prefixes = [prefix for prefix in args.model.split(',') if prefix]
    if not prefixes:
        parser.error('--model: no model file given')

    test_data = Task1File(args.input)
    models, alphabet, features = _load_models(prefixes, args.gpu)

    alphabet_idx = {c: i for i, c in enumerate(alphabet)}
    unk = alphabet_idx['<UNK>']
    bos = alphabet_idx['<S>']
    eos = alphabet_idx['</S>']

    # Every model was moved to the same device, so any of them supplies the
    # array module (numpy or cupy).
    xp = models[0].xp

    def encode_source(batch):
        """Map each source string to an int32 id Variable wrapped in <S>/</S>."""
        return [
            Variable(xp.array(
                [alphabet_idx.get(c, unk) for c in ['<S>']+list(s)+['</S>']],
                dtype=xp.int32))
            for s in batch]

    def encode_features(batch, dtype=xp.float32):
        """Encode each feature set as a 0/1 indicator row over `features`."""
        return Variable(xp.array([
            [int(f in feats) for f in features]
            for feats in batch], dtype=dtype))

    def beam_search(states, beam_size, max_length=100):
        """Return up to `beam_size` hypotheses, highest score first.

        *states* holds one decoder state per ensemble member. At each step
        the members' next-symbol log-probabilities are averaged.
        """
        beam = [Hypothesis(
            states=states,
            score=0.0,
            history=[xp.array([bos], dtype=xp.int32)],
            alive=True)]

        for _ in range(max_length):
            alive = [h for h in beam if h.alive]
            if not alive:
                break
            # Finished hypotheses carry over unchanged and keep competing.
            beam = [h for h in beam if not h.alive]

            for hypothesis in alive:
                next_states = [state(Variable(hypothesis.history[-1]))
                               for state in hypothesis.states]
                scores = np.mean(
                        [cuda.to_cpu(F.log_softmax(state.p).data)[0]
                         for state in next_states],
                        axis=0)
                scores += hypothesis.score
                # Only the beam_size best extensions of this hypothesis can
                # survive the pruning below. The stable argsort keeps ties in
                # symbol order, so the result matches extending by every
                # symbol, without building a Hypothesis (and device array)
                # per vocabulary entry.
                best = np.argsort(-scores, kind='stable')[:beam_size]
                beam.extend(
                        Hypothesis(
                            states=next_states,
                            score=scores[i],
                            history=hypothesis.history + [
                                xp.array([i], dtype=xp.int32)],
                            alive=(i != eos))
                        for i in best.tolist())

            # list.sort is stable even with reverse=True.
            beam.sort(key=lambda h: h.score, reverse=True)
            beam = beam[:beam_size]

        return beam

    # beam_search decodes a single sequence, so examples are processed one
    # at a time (the code below relies on batch[0]).
    batch_size = 1
    n_correct = 0
    n_total = 0
    rr_sum = 0.0
    # Counts gold targets found anywhere in the final beam, i.e. the
    # beam_size-best accuracy (the summary line labels it "10BEST").
    n_correct10 = 0
    # Inference only: disable dropout and skip building the backprop graph.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        for i in range(0, len(test_data.data), batch_size):
            batch = test_data.data[i:i+batch_size]
            source = encode_source([src for src, _, _ in batch])
            feats = encode_features([f for _, _, f in batch],
                                    dtype=xp.float32)

            states = [model.forward(source, feats) for model in models]
            beam = beam_search(
                    states, args.beam_size,
                    max_length=20+max(len(src) for src, _, _ in batch))

            target = [
                _hypothesis_text(
                    [int(cuda.to_cpu(x)[0]) for x in hypothesis.history],
                    not hypothesis.alive,
                    alphabet)
                for hypothesis in beam]
            gold = batch[0][1]
            if gold is not None:
                n_correct += int(gold == target[0])
                if gold in target:
                    n_correct10 += 1
                    rr_sum += 1.0 / (1.0 + target.index(gold))
                n_total += 1
            for src, trg, f in batch:
                print(src, trg, ';'.join(f), ','.join(target),
                      file=sys.stderr)
                print('\t'.join((src, target[0], ';'.join(f))))

    if n_total:
        print('CORRECT/TOTAL/ACCURACY/10BEST-ACC/MRR',
                n_correct, n_total, n_correct/n_total,
                n_correct10/n_total, rr_sum/n_total,
              file=sys.stderr)

# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()
