#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""

import logging
import json
import os
import sys
import torch

import numpy as np
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.logging import progress_bar
from fairseq.data import encoders
from fairseq_cli.interactive import buffered_read, make_batches

import sys
sys.path.append("/opt/tiger/sumtest/crossLingualTransfer")
from utils.visualizeEmb import linear_CKA

def main(args):
    """Validate output arguments, create the results directory and run extraction.

    Args:
        args: parsed fairseq generation arguments. Both ``path`` (model
            checkpoint(s)) and ``results_path`` (output directory) must be set.

    Returns:
        The return value of ``_main`` (which logs to ``sys.stdout``).
    """
    assert args.path is not None, '--path required for saving embedding!'
    # _main writes every output file under results_path; check it up front so a
    # missing flag fails before the expensive model loading and inference.
    assert args.results_path is not None, '--results-path required for saving embedding!'

    os.makedirs(args.results_path, exist_ok=True)
    print("args.results_path: ", args.results_path)
    return _main(args, sys.stdout)


def _save_document_embeddings(document_infos, results_path):
    """Write document embeddings to ``results_path``, ordered by ``doc_id``.

    Two files are produced:

    * ``document_embedding.jsonl`` -- one JSON object per document;
    * ``document_embedding.npy`` -- a NumPy array whose i-th row is the
      ``embedding`` of the i-th line of the JSONL file.

    Args:
        document_infos: dicts with at least ``'doc_id'`` and ``'embedding'``
            (a list of floats) keys.
        results_path: existing directory to write into.
    """
    ordered = sorted(document_infos, key=lambda info: info['doc_id'])

    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding
    # instead of depending on the platform default.
    jsonl_path = os.path.join(results_path, "document_embedding.jsonl")
    with open(jsonl_path, 'w', encoding='utf-8') as fout:
        for info in ordered:
            fout.write(json.dumps(info, ensure_ascii=False) + '\n')

    npy_path = os.path.join(results_path, "document_embedding.npy")
    with open(npy_path, 'wb') as fout:
        np.save(fout, np.array([info['embedding'] for info in ordered]))


def _main(args, output_file):
    """Compute one embedding per input document and save them to disk.

    Raw text is read from ``args.input`` in buffers of ``args.buffer_size``
    lines. It is encoded with the configured tokenizer/BPE, batched with
    fairseq's interactive ``make_batches``, and passed to
    ``task.dump_document_embedding``. Each document's ``doc_id`` is its
    position in the input stream (buffer offset + ``batch.ids``). Results are
    written under ``args.results_path`` by ``_save_document_embeddings``.

    Args:
        args: parsed fairseq generation arguments.
        output_file: stream that log records are written to.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.extract_summarization')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Some tasks raise NotImplementedError instead of exposing a source dictionary.
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None

    print("model_overrides: ")
    print(args.model_overrides)
    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary code taken from the command line.
    # This mirrors fairseq's own CLI scripts; never feed untrusted input here.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
        strict=not getattr(args, "loose_load", False)
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary).
    # Here it also selects how the source text is recovered (see below).
    align_dict = utils.load_align_dict(args.replace_unk)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    document_infos = []
    start_id = 0  # input position of the first line in the current buffer
    # NOTE(review): buffered_read() returns a generator; confirm that the
    # progress bar chosen by --log-format does not call len() on its iterable.
    progress = progress_bar.progress_bar(
        buffered_read(args.input, args.buffer_size),
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )
    for inputs in progress:
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            doc_embeddings = task.dump_document_embedding(models, sample, _model_args)

            # make_batches may reorder inputs (fairseq batches by length), so
            # batch.ids is used to map each row back to its index in `inputs`.
            for i, local_id in enumerate(batch.ids.tolist()):
                doc_id = start_id + local_id
                if align_dict is not None:
                    # Retrieve the original text from the loaded dataset.
                    full_src_str = task.dataset(args.gen_subset).src.get_original_text(doc_id)
                elif src_dict is not None:
                    # Regenerate the text from this row's tokens only, padding removed.
                    row_tokens = utils.strip_pad(batch.src_tokens[i], src_dict.pad())
                    full_src_str = decode_fn(src_dict.string(row_tokens, args.remove_bpe))
                else:
                    full_src_str = ""

                document_infos.append({
                    'embedding': doc_embeddings[i].cpu().detach().numpy().tolist(),
                    'src_str': full_src_str,
                    'doc_id': doc_id,
                })
        start_id += len(inputs)

    _save_document_embeddings(document_infos, args.results_path)
    logger.info("embeddings are saved in {}".format(args.results_path))

def cli_main():
    """Parse interactive-generation command-line arguments and hand them to main()."""
    main(options.parse_args_and_arch(options.get_generation_parser(interactive=True)))

if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    cli_main()
