import os
import argparse
import json

import pandas as pd
import torch.nn.functional as F
from setproctitle import setproctitle
import torch
from torchmetrics.text.rouge import ROUGEScore

from module.ModelingModule import return_model
from utils.training_utils import logger, replace_unicode_punct
from utils.file_utils import FileModule, TorchFileModule

# Module-level helper instances, created once at import time.
# `fileutils` is not referenced by the functions visible in this file
# (inference() works with a local TorchFileModule instance instead).
fileutils = FileModule()
# ROUGE metric object; inference() reads its 'rougeL_fmeasure' entry.
ROUGE = ROUGEScore()


def call_args(parent_parser=None, argv=None):
    """Build the inference command-line parser and parse the arguments.

    Args:
        parent_parser: optional ``argparse.ArgumentParser`` whose arguments
            are inherited by the parser built here.
        argv: optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the attributes ``ckpt``, ``type``,
        ``testfile``, ``savefile`` and ``gen``.
    """
    if parent_parser is not None:
        # add_help=False: a parent built with the default settings already
        # defines -h/--help, and a second definition would conflict.
        parser = argparse.ArgumentParser(
            parents=[parent_parser], add_help=False)
    else:
        parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', default=None, type=str, help='')
    parser.add_argument('--type', default=1, type=int, help='equal to training_type in training phase.')

    parser.add_argument('--testfile', default=None, type=str, help='')
    parser.add_argument('--savefile', default=None, type=str, help='')
    # 'store_true' makes --gen a boolean switch. The previous action name
    # ('s') is not a registered argparse action and raised ValueError here.
    parser.add_argument('--gen', default=False, action='store_true')

    return parser.parse_args(argv)


def make_input(args, instance, tokenizer, device, model_type):
    """Convert one test row into model inputs with a leading batch dim of 1.

    ``instance`` must provide 'content', 'question' and 'answer'; each is
    passed through ``replace_unicode_punct`` and ``tokenizer.encode``.

    ``args.type`` selects what goes to the encoder and what is the target:
        1    -> encoder: question + passage, target: answer
        2..6 -> encoder: first (type - 1) question ids (EOS appended when the
                slice does not already end with it) + passage,
                target: question + answer
        7    -> encoder: question + passage, target: question + answer
        8    -> encoder: passage, target: question + answer

    Returns:
        (output, lengths): ``output`` maps 'src_attention_mask',
        'decoder_attention_mask', 'src_ids', 'labels' and
        'decoder_input_ids' to tensors on ``device``; ``lengths`` holds
        'q_len', 'a_len' and the '###'-joined 'label_id' / 'q_input' strings.

    Raises:
        Exception: ``model_type`` contains neither 'bart' nor 't5', or
            ``args.type`` is not one of 1..8.
    """
    def encode(text):
        return tokenizer.encode(replace_unicode_punct(text))

    psg = encode(instance['content'])
    q = encode(instance['question'])
    a = encode(instance['answer'])

    is_bart = 'bart' in model_type
    if is_bart:
        # Drop the first id of every sequence; one bos_token_id is put back
        # in front of the assembled encoder input below.
        psg, q, a = psg[1:], q[1:], a[1:]
    elif 't5' in model_type:
        # [:-1] removes the last id of each encoded prefix (eos x).
        psg_prefix = encode('context: ')[:-1]
        q_prefix = encode('question: ')[:-1]
        a_prefix = encode('answer: ')[:-1]
    else:
        raise Exception('model type error')

    # training type -> number of leading question ids shown to the encoder
    hint_len_by_type = {2: 1, 3: 2, 4: 3, 5: 4, 6: 5}

    q_input = None
    if args.type in hint_len_by_type:
        q_input = q[:hint_len_by_type[args.type]]
        if q_input[-1] != tokenizer.eos_token_id:
            q_input = q_input + [tokenizer.eos_token_id]
        input_ids_tmp, label_ids_tmp = [q_input, psg], [q, a]
    elif args.type == 1:
        # psg + q ==> a
        input_ids_tmp, label_ids_tmp = [q, psg], [a]
    elif args.type == 7:
        input_ids_tmp, label_ids_tmp = [q, psg], [q, a]
    elif args.type == 8:
        # psg ==> q + a
        input_ids_tmp, label_ids_tmp = [psg], [q, a]
    else:
        raise Exception('training type error')

    if is_bart:
        # sum(..., start) concatenates the segments left to right with `+`.
        input_ids = sum(input_ids_tmp, [tokenizer.bos_token_id])
        label_ids = sum(label_ids_tmp, [])
        dec_input_ids = [tokenizer.bos_token_id] + label_ids[:-1]
    else:
        # Only 't5' can reach this point; other types were rejected above.
        if args.type != 8:
            input_ids = q_prefix + input_ids_tmp[0] + psg_prefix + input_ids_tmp[1]
        else:
            input_ids = psg_prefix + input_ids_tmp[0]

        # Each segment loses its last id; a single EOS closes the target.
        target = sum((segment[:-1] for segment in label_ids_tmp), a_prefix)
        target = target + [tokenizer.eos_token_id]
        dec_input_ids = target[:-1]
        label_ids = target[1:]

    src_ids = torch.as_tensor(input_ids, device=device)
    labels = torch.as_tensor(label_ids, device=device)
    decoder_input_ids = torch.as_tensor(dec_input_ids, device=device)

    features = {
        'src_attention_mask': src_ids.ne(tokenizer.pad_token_id).float(),
        'decoder_attention_mask': decoder_input_ids.ne(tokenizer.pad_token_id).float(),
        'src_ids': src_ids,
        'labels': labels,
        'decoder_input_ids': decoder_input_ids,
    }

    lengths = {
        'q_len': len(q),
        'a_len': len(a),
        'label_id': '###'.join(str(tok) for tok in labels.cpu().tolist()),
        'q_input': '' if q_input is None else '###'.join(str(tok) for tok in q_input),
    }

    # unsqueeze(0) adds the batch dimension (batch size 1).
    output = {
        name: torch.as_tensor(tensor, device=device).unsqueeze(0)
        for name, tensor in features.items()
    }
    return output, lengths


def make_original(obj, model_type='bart'):
    """Drop the leading and trailing positions of a per-token sequence.

    For 'bart' model types the first and the last element are removed; for
    't5' model types the first two and the last element are removed.
    Sequences at or below the minimum length (2 for bart, 3 for t5) are
    returned unchanged, so the result is not emptied by the slice.

    Args:
        obj: sliceable sequence supporting ``len()`` (e.g. a list of scores).
        model_type: model name; must contain 'bart' or 't5'.

    Returns:
        ``obj`` itself when it is too short to trim, otherwise the slice.

    Raises:
        Exception: ``model_type`` contains neither 'bart' nor 't5'.
    """
    if 'bart' in model_type:
        skip_front, min_len = 1, 2
    elif 't5' in model_type:
        skip_front, min_len = 2, 3
    else:
        # Same error as make_input() instead of silently returning None.
        raise Exception('model type error')

    # `<=` for both model types: with the former `== 2` bart check, a
    # 1-element sequence was sliced to empty and its mean became NaN.
    if len(obj) <= min_len:  # min length
        return obj
    return obj[skip_front:-1]

# Number of consecutive rows averaged into one output line.
# NOTE(review): presumably every test item contributes three rows to the
# test file - confirm against the test-file layout.
_ROWS_PER_GROUP = 3


def _group_means(values, group_size=_ROWS_PER_GROUP):
    """Average consecutive, non-overlapping groups of ``group_size`` values.

    A trailing group with fewer than ``group_size`` values is dropped.
    """
    means = []
    for start in range(0, len(values) - group_size + 1, group_size):
        group = values[start:start + group_size]
        # Plain left-to-right addition keeps the float result identical to
        # the explicit ``(a + b + c) / 3`` form.
        total = group[0]
        for value in group[1:]:
            total += value
        means.append(total / group_size)
    return means


def _write_lines(path, values):
    """Write ``str(value)`` for every value to ``path``, one per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f'{value}\n' for value in values)


@torch.no_grad()
def inference(parser):
    """Score every row of the test file with the best checkpoint.

    For each row the function
      * runs a teacher-forced forward pass and reads the softmax probability
        the model assigns to each label token,
      * runs beam-search generation (5 beams, max length 128) and computes
        the ROUGE-L F-measure against the decoded labels.

    Per-row values are averaged over groups of ``_ROWS_PER_GROUP`` rows and
    written to the 'generated' directory:
      * ``<savefile>.json``      - per-row lengths, label ids, token scores
      * ``<savefile>``           - mean token probability after make_original()
      * ``<savefile>_candidate`` - mean token probability over all label tokens
      * ``<savefile>_generated`` - ROUGE-L F-measure of the generated text

    Args:
        parser: parent ``argparse.ArgumentParser`` handed to ``call_args``.

    Raises:
        ValueError: ``--ckpt``, ``--testfile`` or ``--savefile`` is missing.
    """
    inference_args = call_args(parser)

    # Fail before the scoring loop: a missing --savefile used to surface only
    # after every row had been processed.
    missing = [name for name in ('ckpt', 'testfile', 'savefile')
               if getattr(inference_args, name) is None]
    if missing:
        raise ValueError(
            'missing required argument(s): '
            + ', '.join(f'--{name}' for name in missing))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(device)
    ckpt_files = TorchFileModule()
    logger.info('start')

    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
    # host; the model is moved to `device` afterwards.
    # NOTE(review): torch.load unpickles the file (it also carries the pickled
    # `args` object) - only load checkpoints from a trusted source. On PyTorch
    # releases where weights_only defaults to True this call may need
    # weights_only=False - confirm against the installed version.
    checkpoint = torch.load(
        os.path.join(inference_args.ckpt,
                     ckpt_files.find_best(inference_args.ckpt)),
        map_location='cpu',
    )

    model_args = checkpoint['args']
    model, tokenizer = return_model(model_args)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model = model.to(device)
    # Switch off dropout and other train-only behaviour; without this the
    # token probabilities differ from run to run.
    model.eval()

    testdata = pd.read_csv(inference_args.testfile)
    n_rows = len(testdata)

    scores = []
    scores_original = []
    scores_candidate = []
    scores_generated = []
    for row_no, (_, item) in enumerate(testdata.iterrows(), start=1):
        logger.info(f'{str(row_no)} out of {str(n_rows)}')
        prepared, lengths = make_input(
            args=inference_args,
            instance=item,
            tokenizer=tokenizer,
            device=device,
            model_type=model_args.model_type
        )

        # Gradients are already disabled by the @torch.no_grad() decorator.
        model_out = model(**prepared)
        generated = model.model.generate(
            prepared['src_ids'],
            num_beams=5,
            max_length=128,
            early_stopping=True
        )

        generated = tokenizer.batch_decode(
            generated.detach().cpu().tolist(), skip_special_tokens=True
        )
        reference = tokenizer.batch_decode(
            prepared['labels'].detach().cpu().tolist(), skip_special_tokens=True
        )
        scores_generated.append(
            ROUGE(generated, reference)['rougeL_fmeasure'].tolist())

        logits = model_out['out']['logits'][0]
        labels = prepared['labels'][0].cpu()

        logger.info('\n\n')
        # Greedy (arg-max) decoding of the teacher-forced pass, logged only.
        _, greedy_ids = torch.max(logits, dim=-1)
        logger.info(f"generated: {tokenizer.decode(greedy_ids.cpu())}")
        logger.info(f"reference: {tokenizer.decode(labels)}")

        probs = F.softmax(logits.cpu(), dim=-1, dtype=torch.float64)
        # Probability assigned to the label token at every target position.
        token_probs = [probs[pos, token_id].tolist()
                       for pos, token_id in enumerate(labels)]

        score_original = torch.mean(
            torch.as_tensor(
                make_original(token_probs, model_type=model_args.model_type),
                dtype=torch.float64)
        )
        score_candidate = torch.mean(
            torch.as_tensor(token_probs, dtype=torch.float64)
        )

        lengths['score'] = '###'.join(str(prob) for prob in token_probs)
        scores.append(lengths)

        scores_original.append(score_original.tolist())
        scores_candidate.append(score_candidate.tolist())

        logger.info(f'{score_original}')
        logger.info(labels)

    leftover = n_rows % _ROWS_PER_GROUP
    if leftover:
        # The grouped averages silently ignore an incomplete last group.
        logger.info(
            f'{leftover} trailing row(s) do not fill a group of '
            f'{_ROWS_PER_GROUP} and are left out of the averaged files')

    save_dir = 'generated'
    os.makedirs(save_dir, exist_ok=True)
    base_path = os.path.join(save_dir, inference_args.savefile)

    with open(base_path + '.json', 'w', encoding='utf-8') as f:
        json.dump(scores, f)

    # intended file (GI / GI_skip)
    _write_lines(base_path, _group_means(scores_original))
    # GI_all
    _write_lines(base_path + '_candidate', _group_means(scores_candidate))
    # GI_ss
    _write_lines(base_path + '_generated', _group_means(scores_generated))


# Script entry point: runs the full inference pass when executed directly.
if __name__ == '__main__':
    print(':)')
    # An empty parser is handed over as the parent; call_args() (invoked by
    # inference()) attaches the inference options to a child of it.
    parser = argparse.ArgumentParser()
    inference(parser)
