import sys 
sys.path.append("..") 

import argparse
import copy
import datetime
import json
import os
import time
from pathlib import Path
from rank_bm25 import BM25Okapi
from functools import partial

import numpy as np
# import timm.optim.optim_factory as optim_factory
import torch
import torch.backends.cudnn as cudnn
import util.misc as misc
from engine_finetuning import train_one_epoch, val_one_epoch, load_model, load_generator_from_raw, load_generator_from_trained
from torch.utils.data import Dataset
# from torch.utils.tensorboard import SummaryWriter
# from util.misc import NativeScalerWithGradNormCount as NativeScaler

from llama import Tokenizer
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
from tqdm import trange, tqdm

# PROMPT_DICT = {
#     "prompt_input": (
#         "Below is an instruction that describes a task, paired with an input that provides further context. "
#         "Write a response that appropriately completes the request.\n\n"
#         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
#     ),
#     "prompt_no_input": (
#         "Below is an instruction that describes a task. "
#         "Write a response that appropriately completes the request.\n\n"
#         "### Instruction:\n{instruction}\n\n### Response:"
#     ),
# }


def extract_article(text, marker="Generate a title for the following abstract of a paper: "):
    """Return the part of ``text`` that follows the first occurrence of ``marker``.

    :param text: the full task input string.
    :param marker: instruction prefix to strip; defaults to the
        title-generation prefix used by this script.
    :return: everything after the first occurrence of ``marker`` (may be empty).
    :raises ValueError: if ``marker`` does not occur in ``text``.
    """
    marker_pos = text.find(marker)
    if marker_pos == -1:
        # Give the caller something actionable instead of a bare ValueError().
        raise ValueError("marker {!r} not found in input text".format(marker))
    return text[marker_pos + len(marker):]


class InstructionDataset(Dataset):
    """Map-style dataset that tokenizes prompt / full-prompt pairs on access.

    Each record in ``data_list`` is a dict holding a ``'prompt'`` string and a
    ``'full_prompt'`` string (the prompt followed by the target text). No
    padding or truncation is done here; ``collate_fn`` handles that per batch.
    """

    def __init__(self, data_list, tokenizer_path, max_tokens=2048):
        # Records are kept as-is and tokenized lazily in __getitem__.
        self.ann = data_list
        # Stored only; __getitem__ does not truncate to this length.
        self.max_words = max_tokens
        # The tokenizer model file is expected at <tokenizer_path>/tokenizer.model.
        self.tokenizer1 = Tokenizer(model_path=tokenizer_path + "/tokenizer.model")

    def __len__(self):
        """Return the number of records."""
        return len(self.ann)

    def __getitem__(self, index):
        """Tokenize record ``index``.

        :return: ``(example, labels, prompt)`` int64 tensors. ``prompt`` holds
            the ids of ``'prompt'`` encoded with ``bos=True, eos=False``;
            ``example`` holds the ids of ``'full_prompt'`` encoded with
            ``bos=True, eos=True``; ``labels`` is an independent copy of
            ``example`` whose first ``len(prompt)`` entries are -1
            (``collate_fn`` later turns every negative label into 0).
        """
        record = self.ann[index]
        encode = self.tokenizer1.encode

        prompt_ids = torch.tensor(encode(record['prompt'], bos=True, eos=False), dtype=torch.int64)
        example_ids = torch.tensor(encode(record['full_prompt'], bos=True, eos=True), dtype=torch.int64)

        # Mark the prompt prefix with -1 in a separate copy of the example ids.
        labels = example_ids.clone()
        labels[:len(prompt_ids)] = -1
        return example_ids, labels, prompt_ids


def collate_fn(batch, max_length=2048):
    """Truncate, right-pad and mask a batch of ``InstructionDataset`` items.

    :param batch: non-empty sequence of ``(example, labels, prompt)`` tuples of
        1-D int64 token tensors, as returned by
        ``InstructionDataset.__getitem__``; ``prompt`` is not used here.
    :param max_length: sequences longer than this are truncated to it.
    :return: ``(examples, labels, example_mask)``. ``examples`` and ``labels``
        are int64 tensors of shape ``(batch, L)``, where ``L`` is the length of
        the longest truncated example. Padding positions, and every negative
        label (the prompt positions the dataset marks with -1), are set to 0.
        ``example_mask`` is a float tensor of the same shape: 1.0 on real
        example tokens and 0.0 on padding.
    :raises ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    examples, labels, _prompts = zip(*batch)
    trimmed_examples = [example[:max_length] for example in examples]
    trimmed_labels = [label[:max_length] for label in labels]

    # Every trimmed sequence is already <= max_length, so the longest one
    # determines the batch width.
    target_len = max(len(example) for example in trimmed_examples)

    def _pad(seq):
        # Right-pad with -1 on seq's own device/dtype; -1 is masked out below.
        return torch.cat((seq, seq.new_full((target_len - len(seq),), -1)))

    padded_examples = torch.stack([_pad(example) for example in trimmed_examples])
    padded_labels = torch.stack([_pad(label) for label in trimmed_labels])

    example_masks = padded_examples.ge(0)
    padded_examples[~example_masks] = 0
    padded_labels[padded_labels.lt(0)] = 0

    return padded_examples, padded_labels, example_masks.float()


def _str2bool(value):
    """Convert a command-line string such as ``"True"``, ``"false"`` or ``"1"`` to ``bool``.

    ``type=bool`` cannot be used for this: argparse would call
    ``bool("False")``, which is ``True`` for every non-empty string.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def get_args_parser():
    """Build the command-line parser for per-user LoRA fine-tuning.

    Boolean options that take a value (``--w_lora``, ``--grad_ckpt``) accept
    ``True``/``False`` (also ``1``/``0``, ``yes``/``no``). ``--k_list`` takes
    one or more integers, e.g. ``--k_list 1 2 4``.
    """
    parser = argparse.ArgumentParser("MAE pre-training", add_help=False)
    parser.add_argument(
        "--batch_size",
        default=6,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument("--epochs", default=3, type=int)
    parser.add_argument("--max_step", default=100, type=int)

    parser.add_argument("--warmup_epochs", default=0, type=int)

    parser.add_argument(
        "--accum_iter",
        default=1,
        type=int,
        help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)",
    )

    # Model parameters
    parser.add_argument("--llama_model_path", default="vol3/models/model-llama/llama-main/download/llama-2/7B", type=str, help="path of llama model")
    parser.add_argument("--tokenizer_path", default="vol3/models/model-llama/llama-main/download/llama-2", type=str, help="directory containing tokenizer.model")

    parser.add_argument("--model", default="llama7B_lora", type=str, metavar="MODEL", help="Name of model to train")

    parser.add_argument("--max_seq_len", type=int, default=3000, metavar="LENGTH", help="the maximum sequence length")

    parser.add_argument("--w_lora", type=_str2bool, default=True, help="use lora or not (True/False)")

    # Optimizer parameters
    parser.add_argument("--weight_decay", type=float, default=0.01, help="weight decay (default: 0.01)")

    parser.add_argument("--lr", type=float, default=2e-5, help="learning rate (absolute lr)")
    parser.add_argument("--clip", type=float, default=0.3, help="gradient clipping")

    parser.add_argument(
        "--blr",
        type=float,
        default=1e-3,
        metavar="LR",
        help="base learning rate: absolute_lr = base_lr * total_batch_size / 256",
    )
    parser.add_argument(
        "--min_lr", type=float, default=0.0, metavar="LR", help="lower lr bound for cyclic schedulers that hit 0"
    )

    # Dataset parameters
    parser.add_argument("--test_data_path", default="Private/LoRA-composition/LaMP_data-final/scholarly_title/user_anchor_candidate.json", type=str, help="dataset path")
    parser.add_argument("--train_data_path", default="Private/LoRA-composition/LaMP_data-final/scholarly_title/user_base_LLM.json", type=str, help="dataset path")

    parser.add_argument("--output_dir", default="./output_dir/Anchor-v2-step", help="path where to save, empty for no saving")

    parser.add_argument("--log_dir", default="./output_dir", help="path where to tensorboard log")
    parser.add_argument("--device", default="cuda", help="device to use for training / testing")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--lora_ckpt", default='Private/LoRA-composition/scholarly/output_dir-basek1-OPPUqa/base-newProfile-k1/lora_ckpt.pt', help="resume lora from checkpoint")
    parser.add_argument("--grad_ckpt", type=_str2bool, default=True, help="whether to use gradient checkpointing (True/False), recommend TRUE!!")

    parser.add_argument("--anchor_path", default='Private/LoRA-composition/scholarly/anchor_selection/history_avg/anchor_user_idx.pt', help="path to the saved anchor user indices (loaded with torch.load)")

    parser.add_argument("--num_workers", default=10, type=int)

    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    # generation hyperparameters
    parser.add_argument("--top_p", type=float, default=0.9, help="top_p")
    parser.add_argument("--temperature", type=float, default=0.1, help="temperature")
    parser.add_argument("--max_gen_len", type=int, default=10, help="maximum number of tokens to generate")

    # nargs="+" with type=int parses "--k_list 1 2 4" into [1, 2, 4];
    # type=list would have split a single argument into characters.
    parser.add_argument("--k_list", type=int, nargs="+", default=[1, 2, 4], help="numbers of retrieved history items to evaluate")
    parser.add_argument('--infer', default=False, action=argparse.BooleanOptionalAction)

    return parser

def get_first_k_tokens(text, k):
    """Return the first ``k`` whitespace-separated tokens of ``text``.

    Tokens come from ``str.split()`` with no argument, so any run of
    whitespace (spaces, tabs, newlines) counts as one separator. The kept
    tokens are re-joined with single spaces. ``k`` is used as a slice bound:
    a ``k`` larger than the token count keeps every token, and a negative
    ``k`` drops tokens from the end.

    :param text: The input text string.
    :param k: The number of tokens to keep.
    :return: The kept tokens joined by single spaces.
    """
    return " ".join(text.split()[:k])


# Profile text per user, keyed by the user id as a string (process_train_data
# looks it up with str(user['user_id'])). Loaded once at import time from a
# path relative to the current working directory. JSON is UTF-8 by spec, so
# the encoding is pinned instead of depending on the platform locale.
with open('./user_profile/all_profile_id2text.json', 'r', encoding='utf-8') as f:
    all_profile = json.load(f)

import random


# Number of whitespace tokens kept from each abstract when building BM25
# documents, BM25 queries and prompts.
first_token_num = 368

def process_train_data(user, k_list=None):
    """Build the per-user fine-tuning examples for the title-generation task.

    Two kinds of examples are produced, in this order:

    1. For every profile item except the first, a retrieval-augmented prompt.
       It contains the user's profile text (from the module-level
       ``all_profile``), the single BM25-best match among the *earlier*
       profile items only, and the instruction for the current abstract.
    2. For every profile item, a plain instruction prompt with no history.

    Abstracts are truncated to ``first_token_num`` whitespace tokens.

    :param user: user record; reads ``user['user_id']`` and ``user['profile']``
        (a list of dicts with ``'abstract'`` and ``'title'`` keys).
    :param k_list: unused; kept only so existing call signatures still work.
        The default was previously a mutable list, which is why it is now None.
    :return: list of ``{"prompt": ..., "full_prompt": ...}`` dicts, where
        ``full_prompt`` is ``prompt`` followed by a space and the title.
    """
    user_profile = all_profile[str(user['user_id'])]
    profile = user['profile']

    # Number of history items retrieved per training query.
    num_retrieved = 1
    profile_str = '### User Profile:\n{}\n\n'.format(user_profile)

    # Truncate each abstract once; it is reused in history docs, queries and prompts.
    abstracts = [get_first_k_tokens(p['abstract'], first_token_num) for p in profile]
    history_all = ['Abstract: {} Title: {}\n'.format(abstract, p['title']) for abstract, p in zip(abstracts, profile)]

    train_data = []

    for idx in range(1, len(profile)):
        # Only items strictly before idx are retrievable, so no later item leaks in.
        visible_history = history_all[:idx]
        bm25 = BM25Okapi([doc.split(" ") for doc in visible_history])

        tokenized_query = 'Generate a title for the following abstract of a paper. abstract: {}'.format(abstracts[idx]).split(' ')
        retrieved_history = bm25.get_top_n(tokenized_query, visible_history, n=num_retrieved)
        history_string = "".join(retrieved_history)

        prompt = profile_str + '### User History:\n{}\n\n### User Instruction:\nGenerate a title for the following abstract of a paper.\nAbstract: {} Title:'.format(history_string, abstracts[idx])
        train_data.append(
            {
                "prompt": prompt,
                "full_prompt": prompt + ' {}'.format(profile[idx]['title']),
            }
        )

    for abstract, item in zip(abstracts, profile):
        prompt = '### User Instruction:\nGenerate a title for the following abstract of a paper.\nAbstract: {} Title:'.format(abstract)
        train_data.append(
            {
                "prompt": prompt,
                "full_prompt": prompt + ' {}'.format(item['title']),
            }
        )

    return train_data



def split_batch(init_list, batch_size):
    """Split a sequence into consecutive chunks of ``batch_size`` items.

    Every chunk is a list. The last chunk holds any remainder and may be
    shorter. An empty input yields an empty list.

    :param init_list: sequence (supports ``len`` and slicing) to split.
    :param batch_size: positive chunk size.
    :return: list of list chunks, in the original order.
    :raises ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
    return [list(init_list[start:start + batch_size]) for start in range(0, len(init_list), batch_size)]



# def process_test_data(user, batch_size, k_list = [1,2,4]):
#     out_list = []
#     test_question_list = [] 
#     question_id_list = []
#     retrieval_test_question_list = [[] for _ in range(len(k_list))]

#     for q in user['query']:
#         test_question = q['input']

#         test_article = extract_article(test_question)

#         test_prompt = 'Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {} tag:'.format(test_article)
#         test_question_list.append(test_prompt)
#         question_id_list.append(q['id'])

#     history_list = ['description: {} tag: {}\n'.format(p['description'], p['tag']) for p in user['profile']]

#     tokenized_corpus = [doc.split(" ") for doc in history_list]
#     bm25 = BM25Okapi(tokenized_corpus)


#     for idx, k in enumerate(k_list):
#         for q in user['query']:
#             test_question = q['input']
#             test_article = extract_article(test_question)

#             tokenized_query = 'Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation.\n description: {}'.format(test_article).split(" ")
#             retrieved_history = bm25.get_top_n(tokenized_query, history_list, n=k)
        
#             history_string = "".join(retrieved_history)

#             test_prompt = '{}Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story]\n description: {} tag:'.format(history_string, test_article)

#             retrieval_test_question_list[idx].append(test_prompt)
#             # question_id_list.append(q['id'])
        

#     test_batch_list = split_batch(test_question_list, batch_size)
#     out_list.append(test_batch_list)

#     for i, k in enumerate(k_list):
#         out_list.append(split_batch(retrieval_test_question_list[i], batch_size))

#     all_test_question_list = [test_question_list] + retrieval_test_question_list

#     return out_list, question_id_list, all_test_question_list





def main(args):
    """Fine-tune one LoRA adapter per anchor user and save each one to disk.

    Steps:
      1. Load the anchor indices (``args.anchor_path``) and every user record
         (``args.test_data_path``).
      2. Load the model via ``load_model`` with the LoRA checkpoint
         ``args.lora_ckpt``, then call ``merge_lora_parameters()``.
      3. For each anchor user: call ``reset_lora_parameters()`` and
         ``set_lora_trainable()``, build the training set with
         ``process_train_data``, and train until ``args.cur_step`` reaches
         ``args.max_step``. Per-epoch stats are appended to
         ``<output_dir>/user_<id>/log.txt`` and the LoRA state dict is saved to
         ``<output_dir>/user_<id>/lora_ckpt.pt``.

    :param args: namespace from ``get_args_parser().parse_args()``.
        ``args.cur_step`` is reset here for each user and presumably advanced
        by ``train_one_epoch`` (confirm in engine_finetuning).
    """
    # Tensors created without an explicit device (including the DataLoader
    # generator and collate_fn padding) are placed on the GPU.
    torch.set_default_device('cuda')

    print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(", ", ",\n"))

    device = torch.device(args.device)

    # Fix the seeds for reproducibility (single-process run, no rank offset).
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # NOTE(review): torch.load unpickles the file; only point --anchor_path at
    # trusted files.
    anchor_idx = torch.load(args.anchor_path)

    with open(args.test_data_path, 'r', encoding='utf-8') as f:
        all_user_data = json.load(f)

    model = load_model(
        ckpt_dir=args.llama_model_path,
        tokenizer_path=args.tokenizer_path,
        max_seq_len=args.max_seq_len,
        max_batch_size=args.batch_size,
        lora_path=args.lora_ckpt,
        w_lora=args.w_lora,
        grad_ckpt=args.grad_ckpt
    )

    model.to(device)
    model.print_trainable_params()
    # Presumably folds the loaded LoRA weights into the base weights, so that
    # each per-user adapter below starts from the merged model — confirm.
    model.merge_lora_parameters()
    print('merged!!')

    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)

    # Only consumed by the currently disabled inference stage: one prediction
    # list for the no-retrieval setting plus one per entry of --k_list.
    pred_all = [[] for _ in range(len(args.k_list) + 1)]

    for idx in tqdm(range(len(anchor_idx))):
        # anchor_idx maps each anchor position to an index into all_user_data.
        idx_all_test = anchor_idx[idx]
        user = all_user_data[idx_all_test]

        user_out_dir = os.path.join(args.output_dir, 'user_{}'.format(user['user_id']))
        Path(user_out_dir).mkdir(parents=True, exist_ok=True)

        # Re-initialise the LoRA parameters and make them trainable before
        # training this user (per the method names — confirm in the model).
        model.reset_lora_parameters()
        model.set_lora_trainable()

        data_list = process_train_data(user)
        dataset_train = InstructionDataset(
            data_list=data_list, tokenizer_path=args.tokenizer_path, max_tokens=args.max_seq_len
        )

        # TensorBoard logging is disabled.
        log_writer = None

        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            shuffle=True,
            batch_size=args.batch_size,
            drop_last=False,
            # A freshly constructed Generator is not affected by
            # torch.manual_seed, so seed it explicitly; otherwise the shuffle
            # order would ignore --seed. It stays on CUDA as before, presumably
            # to match the default device set above.
            generator=torch.Generator(device='cuda').manual_seed(seed),
            collate_fn=partial(collate_fn, max_length=args.max_seq_len),
        )

        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)

        # Step counter for this user, stored on args so train_one_epoch can see it.
        args.cur_step = 0
        print(f"Start training for {args.max_step} steps")

        start_time = time.time()
        epoch = 0

        # NOTE(review): this loop terminates only if train_one_epoch advances
        # args.cur_step; confirm it does, otherwise it never exits.
        while args.cur_step < args.max_step:
            train_stats = train_one_epoch(
                model, data_loader_train, optimizer, device, epoch, None, log_writer=log_writer, args=args
            )
            epoch += 1

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                "epoch": epoch,
            }

            if args.output_dir:
                if log_writer is not None:
                    log_writer.flush()
                with open(os.path.join(user_out_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                    f.write(json.dumps(log_stats) + "\n")

        torch.save(model.lora_state_dict(), os.path.join(user_out_dir, 'lora_ckpt.pt'))

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("Training time {}".format(total_time_str))

        # Inference stage
        
    #     generator = load_generator_from_trained(model, args.tokenizer_path)
    #     test_batch_list, test_id_list, test_question_list = process_test_data(user, batch_size=args.batch_size)


    #     for idx, setting in enumerate(test_batch_list):
    #         all_results = []

    #         for batch in setting:
    #             results = generator.generate(batch, max_gen_len=args.max_gen_len, temperature=args.temperature, top_p=args.top_p)
    #             all_results += results
            
    #         for i in range(len(all_results)):
    #             output = all_results[i].replace(test_question_list[idx][i], "")
    #             pred_all[idx].append({
    #                 "id": test_id_list[i],
    #                 "output": output,
    #                 })

    # name_list = ['NP'] + args.k_list

    # for idx, name in enumerate(name_list):
    #     output_file = {
    #         'task': 'LaMP_8',
    #         'golds': pred_all[idx],
    #     }

    #     with open(os.path.join(args.output_dir, 'output-OPPU-k{}.json'.format(name)), 'w') as f:
    #         json.dump(output_file, f, indent=4)

    
if __name__ == "__main__":
    # Parse the command line in one step, then run training.
    args = get_args_parser().parse_args()
    # An empty --output_dir skips creating the output root here.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
