import copy

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from collections import OrderedDict
from util import SUBJECT_START, OBJECT_START, SUBJECT_END, OBJECT_END, save2txt

import os
import random


class InputExample:
    """A single training/test example for span pair classification.

    Attributes:
        guid: Identifier of the example (may be ``None``).
        sentence: The sentence as a single string.
        ner1: Surface text of the first (subject) entity.
        ner2: Surface text of the second (object) entity.
        prompt: Prompt text paired with the sentence.
        tokens: Sentence tokens, including any inserted entity markers.
        label: Label of the example.
    """

    def __init__(self, guid, sentence: str, ner1: str, ner2: str, prompt: str, tokens: list, label) -> None:
        self.guid = guid
        self.sentence = sentence
        self.ner1 = ner1
        self.ner2 = ner2
        self.prompt = prompt
        self.tokens = tokens
        self.label = label

    def __repr__(self) -> str:
        # A readable repr makes logged / asserted examples debuggable.
        return (f"{type(self).__name__}(guid={self.guid!r}, sentence={self.sentence!r}, "
                f"ner1={self.ner1!r}, ner2={self.ner2!r}, prompt={self.prompt!r}, "
                f"tokens={self.tokens!r}, label={self.label!r})")


def get_examples(data_f, rel2id, mask):
    """Read relation instances from ``data_f`` and wrap them as ``InputExample`` objects.

    Every non-blank line must be a dict literal with the keys ``token`` (list
    of words), ``h`` and ``t`` (dicts whose ``pos`` entry is an end-exclusive
    ``[start, end)`` word span, optionally with ``name``) and ``relation``.

    Args:
        data_f: Path of the UTF-8 encoded data file, one instance per line.
        rel2id: Mapping from the instance's ``relation`` value to its label.
        mask: Mask token text placed between the two entities in the prompt.

    Returns:
        A list of ``InputExample``, one per non-blank line, in file order.

    Raises:
        ValueError, SyntaxError: If a line is not a valid literal.
        KeyError: If a required key or the relation is missing.
        AssertionError: If a given entity ``name`` disagrees with its span.
    """
    # Local import keeps this function self-contained.
    import ast

    examples = []
    with open(data_f, "r", encoding='utf-8') as reader:
        for line in reader:
            line = line.strip()
            if not line:
                # Blank (e.g. trailing) lines carry no instance.
                continue
            # SECURITY: the file content used to be passed to eval(), which
            # executes arbitrary expressions. literal_eval only accepts
            # literals (dicts, lists, strings, numbers, ...).
            ins = ast.literal_eval(line)
            sentence = ins['token']
            h_start, h_end = ins['h']['pos'][0], ins['h']['pos'][1]
            t_start, t_end = ins['t']['pos'][0], ins['t']['pos'][1]

            tokens = []
            for i, token in enumerate(sentence):
                if i == h_start:
                    tokens.append(SUBJECT_START)
                if i == t_start:
                    tokens.append(OBJECT_START)
                if i == h_end:
                    tokens.append(SUBJECT_END)
                if i == t_end:
                    tokens.append(OBJECT_END)
                tokens.append(token)

            # Spans are end-exclusive, so an entity covering the last word has
            # end == len(sentence); the loop index never gets there and the
            # closing marker has to be appended after the final token.
            n_words = len(sentence)
            if h_start < n_words <= h_end:
                tokens.append(SUBJECT_END)
            if t_start < n_words <= t_end:
                tokens.append(OBJECT_END)

            subject = " ".join(sentence[h_start:h_end])
            obj = " ".join(sentence[t_start:t_end])

            if 'name' in ins['h'] and 'name' in ins['t']:
                assert subject == ins['h']['name'] and obj == ins['t']['name'], \
                    f'entity names do not match their spans: {subject!r} / {obj!r}'

            prompt = f"{subject} {mask} {obj} ."

            examples.append(InputExample(guid=None,
                                         sentence=' '.join(sentence),
                                         ner1=subject,
                                         ner2=obj,
                                         prompt=prompt,
                                         tokens=tokens,
                                         label=rel2id[ins['relation']]))
    return examples


class MyDataModule(pl.LightningDataModule):
    """Data module pairing relation-classification inputs with LLM-generated entity contexts.

    Each call to ``train_dataloader`` derives a context-generation prompt
    (``prompt_summary``), asks ``know_enhancer`` for one background text per
    entity of every training example, and returns a loader whose samples hold
    the tokenized sentence plus both tokenized contexts. Validation and test
    loaders hold the tokenized sentence only.

    The module reads ``init_pred``, ``val_stats`` and ``current_epoch`` from
    ``self.trainer.model`` and ``reload_dataloaders_every_n_epochs`` from
    ``self.trainer``.
    """

    def __init__(self, data_dir, rel2id, tokenizer, know_enhancer, context_save_dir=None,
                 max_seq_length=256, bs=2, nw=1, p_topk=2):
        """Store the configuration and load the train/val/test examples.

        Args:
            data_dir: Directory holding ``train.txt`` and ``val.txt``.
            rel2id: Mapping from relation name to label.
            tokenizer: Tokenizer exposing ``mask_token``; it is called with
                text (pairs) and its result is indexed by ``'input_ids'`` and
                ``'attention_mask'``.
            know_enhancer: Object exposing
                ``llm_generate(dialogue, max_new_tokens=...)`` returning text.
            context_save_dir: Directory for the per-epoch prompt and context
                logs. ``None`` disables saving.
            max_seq_length: Padded/truncated length of the sentence input.
            bs: Training batch size.
            nw: Number of data loader workers.
            p_topk: Upper bound on the positive and on the negative samples
                used to derive rules.
        """
        super().__init__()
        self.data_dir = data_dir
        self.rel2id = rel2id
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.batch_size = bs
        self.num_workers = nw

        self.prompt_topk = p_topk
        self.raw4context = ('Generate a short background text about entity "{}" in the sentence "{}" to assist the '
                            'relation identification task. Directly give the background text within fifty words. '
                            'Do not generate any unrelated contents.')
        self.pos_prompt = ('According to the experiment, the background texts "{}" and "{}" improve relation '
                           'identification between entities "{}" and "{}" in the sentence "{}". Please give me '
                           'the reasons to explain the improvement.')
        self.neg_prompt = ('According to the experiment, the background texts "{}" and "{}" are unfavorable for '
                           'relation identification between entities "{}" and "{}" in the sentence "{}". '
                           'Please give me the reasons.')
        self.improve_reason = 'Be precise and concise.'
        self.reason2rule = ('Please rewrite these reasons into rules to generate background texts in aid of relation '
                            'identification, using the format of “if..., then...”. Give it in sections. '
                            'Each is an independent rule. Directly give the content of the rule. '
                            'Do not answer anything else.\nReasons:\n{}')
        self.rule4context = 'Given the following rules:\n{}\n'

        self.prompt = None
        self.rule_history = []
        self.train_context_history = []

        if context_save_dir is None:
            # The default used to crash in os.path.join(None, ...); treat it
            # as "do not save logs" instead.
            self.prompt_save = None
            self.t_context_save = None
        else:
            os.makedirs(context_save_dir, exist_ok=True)
            self.prompt_save = os.path.join(context_save_dir, 'prompt_epoch_{}.txt')
            self.t_context_save = os.path.join(context_save_dir, 'train_epoch_{}.txt')

        self.load_examples()

        self.know_enhancer = know_enhancer

    def load_examples(self):
        """Load the original train, validation and test examples from disk."""
        mask = self.tokenizer.mask_token
        self.train_examples = get_examples(os.path.join(self.data_dir, 'train.txt'), self.rel2id, mask)
        self.val_examples = get_examples(os.path.join(self.data_dir, 'val.txt'), self.rel2id, mask)
        # NOTE(review): test.txt is read from the PARENT of data_dir, unlike
        # train/val — presumably a shared test set; confirm this is intended.
        self.test_examples = get_examples(os.path.join(os.path.dirname(self.data_dir), 'test.txt'),
                                          self.rel2id, mask)

    @staticmethod
    def _strip_lead_in(text):
        """Return what follows the first colon of ``text``, stripped.

        Falls back to ``text`` itself when there is no colon or nothing
        follows it. Later colons are kept, so the content is not truncated.
        """
        _, sep, tail = text.partition(':')
        tail = tail.strip()
        return tail if sep and tail else text

    def _collect_reasons(self, prompt_f, samples, template, tag, start_ind):
        """Ask the LLM why the contexts of each sample helped or hurt, logging to ``prompt_f``.

        Args:
            prompt_f: Open text file receiving the log.
            samples: ``(InputExample, [context_1, context_2])`` pairs.
            template: ``self.pos_prompt`` or ``self.neg_prompt``.
            tag: Label used in the log header (``'POS'`` / ``'NEG'``).
            start_ind: Number given to the first returned reason.

        Returns:
            A list of numbered reason strings, one per sample.
        """
        reasons = []
        for offset, (example, ctxt) in enumerate(samples):
            assert isinstance(example, InputExample)
            assert len(ctxt) == 2
            prompt_f.write(f'########## {tag} Example {offset + 1} ##########\n')
            prompt_f.write('sentence:')
            prompt_f.write(example.sentence.strip())
            prompt_f.write('\n')

            dial1 = [{"role": "user", "content": template.format(ctxt[0], ctxt[1], example.ner1, example.ner2,
                                                                 example.sentence)}]
            raw_reason = self.know_enhancer.llm_generate(dial1, max_new_tokens=512)
            prompt_f.write('raw reasons:')
            prompt_f.write(raw_reason)
            prompt_f.write('\n')

            # Second turn: ask the model to condense its own explanation.
            dial2 = dial1 + [{"role": "assistant", "content": raw_reason},
                             {"role": "user", "content": self.improve_reason}]
            reason = self.know_enhancer.llm_generate(dial2, max_new_tokens=128)
            prompt_f.write('reason:')
            prompt_f.write(reason)
            prompt_f.write('\n')

            reasons.append(f'{start_ind + offset}.{reason}\n')
        return reasons

    def _derive_rule(self, init_pred, val_stats, prompt_f, epoch):
        """Derive rule text from sampled helpful/unhelpful contexts of the last training round.

        An example counts as positive when its flag in ``init_pred`` is truthy
        (single entry in ``val_stats``) or when its value in ``val_stats[-1]``
        exceeds the one in ``val_stats[-2]``.
        """
        assert len(self.train_context_history) >= 1
        if len(val_stats) == 1:
            assert epoch == self.trainer.reload_dataloaders_every_n_epochs
            jud = init_pred
        else:
            assert len(val_stats) >= 2
            jud = [ni - pi > 0 for ni, pi in zip(val_stats[-1], val_stats[-2])]

        example_with_contexts = list(zip(self.train_examples, self.train_context_history[-1]))
        pos_examples = [example_with_contexts[i] for i, ok in enumerate(jud) if ok]
        neg_examples = [example_with_contexts[i] for i, ok in enumerate(jud) if not ok]
        # Same number of positive and negative samples, at most prompt_topk.
        # NOTE(review): when either side is empty no reasons are collected and
        # the LLM is still asked to turn an empty list into rules.
        num_samples = min(len(pos_examples), len(neg_examples), self.prompt_topk)

        pos_samples = random.sample(pos_examples, num_samples)
        neg_samples = random.sample(neg_examples, num_samples)

        reasons = self._collect_reasons(prompt_f, pos_samples, self.pos_prompt, 'POS', 1)
        reasons += self._collect_reasons(prompt_f, neg_samples, self.neg_prompt, 'NEG', len(reasons) + 1)

        dial_rule = [{"role": "user", "content": self.reason2rule.format(''.join(reasons))}]
        raw_rule = self.know_enhancer.llm_generate(dial_rule, max_new_tokens=512)

        if "the format of" not in raw_rule:
            print(f'val in epoch {epoch} derives rules of "{raw_rule}"')
        # Drop the reply's lead-in (text up to the first colon). Unlike the
        # former split/join this keeps later colons and falls back to the raw
        # reply instead of producing an empty rule when there is no colon.
        return self._strip_lead_in(raw_rule)

    def prompt_summary(self, init_pred, val_stats):
        """Build the context-generation prompt template for the coming training round.

        Args:
            init_pred: ``None`` on the first round (epoch 0); otherwise
                per-example flags used when ``val_stats`` has one entry.
            val_stats: Sequence of per-example statistics, one entry per
                evaluation round.

        Returns:
            A template with two ``{}`` placeholders (entity, sentence). On the
            first round this is ``self.raw4context``; afterwards the derived
            rules are prepended. Braces inside the rules are doubled so that
            the later ``str.format`` call leaves them intact.
        """
        epoch = self.trainer.model.current_epoch
        log_path = os.devnull if self.prompt_save is None else self.prompt_save.format(epoch)

        # The context manager closes the log even if an LLM call fails; utf-8
        # because LLM output is not limited to the locale's encoding.
        with open(log_path, 'w', encoding='utf-8') as prompt_f:
            if init_pred is None:
                assert epoch == 0
                prompt = self.raw4context
            else:
                rule = self._derive_rule(init_pred, val_stats, prompt_f, epoch)

                prompt_f.write('########## RULE ##########\n')
                prompt_f.write(rule)
                prompt_f.write('\n')

                # The prompt is later used as a str.format template, so braces
                # coming from the LLM must not be read as placeholders.
                escaped_rule = rule.replace('{', '{{').replace('}', '}}')
                prompt = self.rule4context.format(escaped_rule) + self.raw4context

            prompt_f.write('########## PROMPT ##########\n')
            prompt_f.write(prompt)
            prompt_f.write('\n')

        return prompt

    def convert_examples_to_features(self, examples, is_training=True):
        """Tokenize ``examples`` and, when training, generate and tokenize entity contexts.

        Args:
            examples: Iterable of ``InputExample``.
            is_training: If true, one background text per entity is requested
                from ``know_enhancer`` using ``self.prompt`` and added to the
                dataset.

        Returns:
            ``(dataset, contexts)``. ``dataset`` is a ``TensorDataset`` of
            ``(input_ids, attention_mask, labels)``, followed by
            ``(e1_input_ids, e1_attention_mask, e2_input_ids,
            e2_attention_mask)`` when training. ``contexts`` holds one
            ``[context_1, context_2]`` list per example (empty when not
            training).
        """
        instances = []
        contexts = []
        for ex_index, example in enumerate(examples):
            # Text pair: the prompt "SUBJECT <mask> OBJECT ." and the marked sentence.
            inputs = self.tokenizer(
                example.prompt,
                " ".join(example.tokens),
                truncation="longest_first",
                max_length=self.max_seq_length,
                padding="max_length",
                add_special_tokens=True
            )
            x = {
                'input_ids': inputs['input_ids'],
                'attention_mask': inputs['attention_mask'],
                'label': example.label,
            }

            if is_training:
                raw_contexts = []
                for entity in (example.ner1, example.ner2):
                    dial = [{"role": "user", "content": self.prompt.format(entity, example.sentence)}]
                    raw_contexts.append(self.know_enhancer.llm_generate(dial, max_new_tokens=64))

                i_contexts = []
                for raw_context in raw_contexts:
                    # Check each reply itself (the first reply used to be
                    # checked twice and the second one never).
                    if "background text" not in raw_context:
                        print(
                            f'example {ex_index} in epoch {self.trainer.model.current_epoch} derives context of "{raw_context}"')
                    # Keep everything after the lead-in; the former
                    # split(':')[1] cut the context at its second colon.
                    i_contexts.append(self._strip_lead_in(raw_context))

                contexts.append(i_contexts)

                inputs_e1 = self.tokenizer(i_contexts[0], truncation="longest_first", max_length=64,
                                           padding="max_length", add_special_tokens=True)
                inputs_e2 = self.tokenizer(i_contexts[1], truncation="longest_first", max_length=64,
                                           padding="max_length", add_special_tokens=True)
                x['e1_input_ids'] = inputs_e1['input_ids']
                x['e1_attention_mask'] = inputs_e1['attention_mask']
                x['e2_input_ids'] = inputs_e2['input_ids']
                x['e2_attention_mask'] = inputs_e2['attention_mask']
            instances.append(x)

        columns = ['input_ids', 'attention_mask', 'label']
        if is_training:
            columns += ['e1_input_ids', 'e1_attention_mask', 'e2_input_ids', 'e2_attention_mask']
        else:
            assert len(contexts) == 0

        tensors = [torch.tensor([o[name] for o in instances]) for name in columns]
        dataset = TensorDataset(*tensors)
        return dataset, contexts

    def get_train_data_loader(self):
        """Return ``(shuffled training DataLoader, generated contexts)``."""
        dataset, contexts = self.convert_examples_to_features(self.train_examples)
        dataloader = DataLoader(dataset, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers,
                                pin_memory=True)
        return dataloader, contexts

    def get_val_data_loader(self):
        """Return ``(validation DataLoader with batch size 1, empty context list)``."""
        dataset, contexts = self.convert_examples_to_features(self.val_examples, is_training=False)
        dataloader = DataLoader(dataset, shuffle=False, batch_size=1, num_workers=self.num_workers,
                                pin_memory=True)
        return dataloader, contexts

    def get_test_data_loader(self):
        """Return ``(test DataLoader with batch size 1, empty context list)``."""
        dataset, contexts = self.convert_examples_to_features(self.test_examples, is_training=False)
        dataloader = DataLoader(dataset, shuffle=False, batch_size=1, num_workers=self.num_workers,
                                pin_memory=True)
        return dataloader, contexts

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Refresh the prompt, regenerate the training contexts and return the training loader."""
        # update prompt
        init_pred = self.trainer.model.init_pred
        val_stats = self.trainer.model.val_stats
        prompt = self.prompt_summary(init_pred, val_stats)

        self.rule_history.append(prompt)
        self.prompt = prompt

        train_data, contexts = self.get_train_data_loader()
        if self.t_context_save is not None:
            save2txt(self.t_context_save.format(self.trainer.model.current_epoch), contexts)
        self.train_context_history.append(contexts)
        return train_data

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Return the validation loader."""
        val_data, _ = self.get_val_data_loader()
        return val_data

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Return the test loader."""
        test_data, _ = self.get_test_data_loader()
        return test_data
