import os
import sys
import json
from tqdm import tqdm
import torch
from torch.utils.data import Dataset


# Add the current working directory to the module search path. The directory
# used is wherever the process was launched from, not this file's location.
sys.path.append(os.getcwd())


class EFPDataset(Dataset):
    """Event factuality prediction dataset with one example per event mention.

    The source text of an example is the task prefix followed by the sentence
    that contains the mention, with the mention span wrapped in ``<e>`` and
    ``</e>``. The target is the mention's ``factuality`` value.

    Args:
        data_dir: Path of the JSON-lines file to read (it is opened as a file,
            one JSON document per line).
        tokenizer: Object providing ``encode_plus(text, max_length=...,
            padding=..., truncation=..., return_tensors=...)`` whose result
            has ``input_ids`` and ``attention_mask`` entries.
        max_length: Length the source text is padded/truncated to.
    """

    # Target sequences are padded/truncated to this many token ids.
    TARGET_MAX_LENGTH = 5

    def __init__(self, data_dir, tokenizer, max_length):
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.prefix = "Event factuality prediction : "
        self.texts, self.labels = self.load_data()

    @staticmethod
    def _mark_span(tokens, offset, open_mark, close_mark):
        """Return the sentence text with one token span wrapped in markers.

        ``offset`` is used as ``[start, end)``: the opening marker is glued to
        the token at ``offset[0]`` and the closing marker to the token at
        ``offset[1] - 1``. A copy of ``tokens`` is marked, so the document is
        never modified and tokens that already contain marker text survive.
        """
        marked = list(tokens)
        marked[offset[0]] = open_mark + marked[offset[0]]
        marked[offset[1] - 1] = marked[offset[1] - 1] + close_mark
        return " ".join(marked)

    def load_data(self):
        """Read the JSON-lines file and build the examples.

        Returns:
            A ``(texts, labels)`` pair of equally long lists, ordered by
            document, then event, then mention.
        """
        with open(self.data_dir, 'r', encoding='utf8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # inside json.loads.
            raw_data = [json.loads(line) for line in f if line.strip()]
        texts = []
        labels = []
        for item in tqdm(raw_data, desc='Loading data'):
            for event in item['events']:
                for mention in event['mention']:
                    sentence = self._mark_span(
                        item['tokens'][mention['sent_id']],
                        mention['offset'],
                        "<e>",
                        "</e>",
                    )
                    texts.append(self.prefix + sentence)
                    labels.append(mention['factuality'])
        return texts, labels

    def __len__(self):
        """Return the number of mention-level examples."""
        return len(self.texts)

    def _encode(self, text, max_length):
        """Tokenize ``text`` padded/truncated to exactly ``max_length``."""
        return self.tokenizer.encode_plus(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the tokenized source and target of example ``idx``.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` (from the source
            text) and ``labels`` (the target's ``input_ids``).
        """
        source = self._encode(self.texts[idx], self.max_length)
        target = self._encode(self.labels[idx], self.TARGET_MAX_LENGTH)
        # squeeze(0) drops only the leading batch axis added by
        # return_tensors='pt'; a bare squeeze() would also drop the sequence
        # axis whenever its length is 1.
        return {
            'input_ids': source['input_ids'].squeeze(0),
            'attention_mask': source['attention_mask'].squeeze(0),
            'labels': target['input_ids'].squeeze(0)
        }
    

class EFPDatasetR(Dataset):
    """Event factuality dataset whose inputs include causal-relation context.

    Besides the ``<e>``-marked sentence of each mention, an example carries
    the sentences that mention the events listed as CAUSE or PRECONDITION of
    the mention's event, marked with ``<c>``/``</c>`` and ``<p>``/``</p>``.
    A relation ``r`` is applied to the event whose id equals ``r[1]``; the
    event with id ``r[0]`` supplies the context sentences.

    Args:
        data_dir: Path of the JSON-lines file to read (it is opened as a file,
            one JSON document per line).
        tokenizer: Object providing ``encode_plus(text, max_length=...,
            padding=..., truncation=..., return_tensors=...)`` whose result
            has ``input_ids`` and ``attention_mask`` entries.
        max_length: Length the source text is padded/truncated to.
    """

    # Target sequences are padded/truncated to this many token ids.
    TARGET_MAX_LENGTH = 5

    def __init__(self, data_dir, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_dir = data_dir
        self.prefix = "Event factuality prediction : "
        self.texts, self.labels, self.causes, self.preconditions = self.load_data()

    @staticmethod
    def _mark_spans(tokens, offsets, marks):
        """Return the sentence text with one or more spans wrapped in markers.

        ``offsets`` is either a single ``[start, end)`` pair or a list of such
        pairs. ``marks`` is ``(opening, closing)``. A copy of ``tokens`` is
        marked, so the document is never modified and tokens that already
        contain marker text (``<p>`` is a common literal token) survive.
        """
        spans = [offsets] if isinstance(offsets[0], int) else offsets
        marked = list(tokens)
        for span in spans:
            marked[span[0]] = marks[0] + marked[span[0]]
            marked[span[1] - 1] = marked[span[1] - 1] + marks[1]
        return " ".join(marked)

    def _related_sentences(self, item, events_by_id, event_id, relation_type, marks):
        """Collect marked sentences of events related to ``event_id``.

        Mentions of every source event are grouped per sentence so that each
        sentence appears once with all of its related spans marked. Sentences
        keep the order in which they are first encountered.
        """
        spans_by_sentence = {}
        for relation in item["causal_relation"][relation_type]:
            if event_id != relation[1]:
                continue
            for mention in events_by_id[relation[0]]['mention']:
                spans_by_sentence.setdefault(mention['sent_id'], []).append(mention['offset'])
        return [
            self._mark_spans(item['tokens'][sent_id], spans, marks)
            for sent_id, spans in spans_by_sentence.items()
        ]

    def load_data(self):
        """Read the JSON-lines file and build the examples.

        Returns:
            ``(texts, labels, causes, preconditions)``: equally long lists
            with one entry per mention. ``causes[i]`` and ``preconditions[i]``
            are lists of marked sentences; all mentions of one event share
            the same list objects.
        """
        with open(self.data_dir, 'r', encoding='utf8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # inside json.loads.
            raw_data = [json.loads(line) for line in f if line.strip()]
        texts = []
        labels = []
        causes = []
        preconditions = []

        for item in tqdm(raw_data, desc='Loading data'):
            events = {event['id']: event for event in item['events']}
            for event in item['events']:
                cause = self._related_sentences(
                    item, events, event['id'], "CAUSE", ["<c>", "</c>"])
                precondition = self._related_sentences(
                    item, events, event['id'], "PRECONDITION", ["<p>", "</p>"])
                for mention in event['mention']:
                    sentence = self._mark_spans(
                        item['tokens'][mention['sent_id']],
                        mention['offset'],
                        ["<e>", "</e>"],
                    )
                    texts.append(self.prefix + sentence)
                    labels.append(mention['factuality'])
                    causes.append(cause)
                    preconditions.append(precondition)
        return texts, labels, causes, preconditions

    def __len__(self):
        """Return the number of mention-level examples."""
        return len(self.texts)

    def _encode(self, text, max_length):
        """Tokenize ``text`` padded/truncated to exactly ``max_length``."""
        return self.tokenizer.encode_plus(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the tokenized source and target of example ``idx``.

        The source text is the marked sentence followed, when present, by the
        cause sentences and then the precondition sentences.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` (from the source
            text) and ``labels`` (the target's ``input_ids``).
        """
        input_texts = self.texts[idx]
        # The section labels are appended without a leading separator; the
        # strings are kept exactly as they were so inputs stay unchanged.
        if self.causes[idx]:
            input_texts = input_texts + "EVENT CAUSE: " + " ".join(self.causes[idx])
        if self.preconditions[idx]:
            input_texts = input_texts + "EVENT PRECONDITION: " + " ".join(self.preconditions[idx])
        source = self._encode(input_texts, self.max_length)
        target = self._encode(self.labels[idx], self.TARGET_MAX_LENGTH)
        # squeeze(0) drops only the leading batch axis added by
        # return_tensors='pt'; a bare squeeze() would also drop the sequence
        # axis whenever its length is 1.
        return {
            'input_ids': source['input_ids'].squeeze(0),
            'attention_mask': source['attention_mask'].squeeze(0),
            'labels': target['input_ids'].squeeze(0)
        }
    

class EFPDatasetArg(Dataset):
    """Event factuality dataset whose inputs include the event's arguments.

    Besides the ``<e>``-marked sentence of each mention, an example carries a
    string that lists, for every argument of the mention's event, its type
    and the text of its first mention.

    Args:
        data_dir: Path of the JSON-lines file to read (it is opened as a file,
            one JSON document per line).
        tokenizer: Object providing ``encode_plus(text, max_length=...,
            padding=..., truncation=..., return_tensors=...)`` whose result
            has ``input_ids`` and ``attention_mask`` entries.
        max_length: Length the source text is padded/truncated to.
    """

    # Target sequences are padded/truncated to this many token ids.
    TARGET_MAX_LENGTH = 5

    def __init__(self, data_dir, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_dir = data_dir
        self.prefix = "Event factuality prediction : "
        self.texts, self.labels, self.arguments = self.load_data()

    @staticmethod
    def _mark_span(tokens, offset, marks):
        """Return the sentence text with one token span wrapped in markers.

        ``offset`` is used as ``[start, end)`` and ``marks`` is
        ``(opening, closing)``. A copy of ``tokens`` is marked, so the
        document is never modified and tokens that already contain marker
        text survive.
        """
        marked = list(tokens)
        marked[offset[0]] = marks[0] + marked[offset[0]]
        marked[offset[1] - 1] = marked[offset[1] - 1] + marks[1]
        return " ".join(marked)

    @staticmethod
    def _format_arguments(event):
        """Render the event's arguments as ``TYPE: ...; ENTITY: .... `` items.

        Only the first mention of each argument is used. The result is the
        empty string when the event has no arguments.
        """
        return "".join(
            "TYPE: " + argument['type'] + "; ENTITY: " + argument['mentions'][0]['mention'] + ". "
            for argument in event['arguments']
        )

    def load_data(self):
        """Read the JSON-lines file and build the examples.

        Returns:
            ``(texts, labels, arguments)``: equally long lists with one entry
            per mention, ordered by document, then event, then mention.
        """
        with open(self.data_dir, 'r', encoding='utf8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # inside json.loads.
            raw_data = [json.loads(line) for line in f if line.strip()]
        texts = []
        labels = []
        arguments = []

        for item in tqdm(raw_data, desc='Loading data'):
            for event in item['events']:
                for mention in event['mention']:
                    sentence = self._mark_span(
                        item['tokens'][mention['sent_id']],
                        mention['offset'],
                        ["<e>", "</e>"],
                    )
                    texts.append(self.prefix + sentence)
                    labels.append(mention['factuality'])
                    # Built per mention so that events without mentions are
                    # never required to carry an 'arguments' entry.
                    arguments.append(self._format_arguments(event))
        return texts, labels, arguments

    def __len__(self):
        """Return the number of mention-level examples."""
        return len(self.texts)

    def _encode(self, text, max_length):
        """Tokenize ``text`` padded/truncated to exactly ``max_length``."""
        return self.tokenizer.encode_plus(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the tokenized source and target of example ``idx``.

        The source text is the marked sentence followed, when the event has
        arguments, by ``ARGUMENTS: `` and the argument string.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` (from the source
            text) and ``labels`` (the target's ``input_ids``).
        """
        input_texts = self.texts[idx]
        if self.arguments[idx] != "":
            input_texts = input_texts + "ARGUMENTS: " + self.arguments[idx]
        source = self._encode(input_texts, self.max_length)
        target = self._encode(self.labels[idx], self.TARGET_MAX_LENGTH)
        # squeeze(0) drops only the leading batch axis added by
        # return_tensors='pt'; a bare squeeze() would also drop the sequence
        # axis whenever its length is 1.
        return {
            'input_ids': source['input_ids'].squeeze(0),
            'attention_mask': source['attention_mask'].squeeze(0),
            'labels': target['input_ids'].squeeze(0)
        }
        
        


class EFPDatasetArgR(Dataset):
    """Event factuality dataset with argument and causal-relation context.

    Besides the ``<e>``-marked sentence of each mention, an example carries
    the event's arguments as text and the sentences that mention the events
    listed as CAUSE or PRECONDITION of the mention's event, marked with
    ``<c>``/``</c>`` and ``<p>``/``</p>``. A relation ``r`` is applied to the
    event whose id equals ``r[1]``; the event with id ``r[0]`` supplies the
    context sentences.

    Args:
        data_dir: Path of the JSON-lines file to read (it is opened as a file,
            one JSON document per line).
        tokenizer: Object providing ``encode_plus(text, max_length=...,
            padding=..., truncation=..., return_tensors=...)`` whose result
            has ``input_ids`` and ``attention_mask`` entries.
        max_length: Length the source text is padded/truncated to.
    """

    # Target sequences are padded/truncated to this many token ids.
    TARGET_MAX_LENGTH = 5

    def __init__(self, data_dir, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_dir = data_dir
        self.prefix = "Event factuality prediction : "
        self.texts, self.labels, self.arguments, self.causes, self.preconditions = self.load_data()

    @staticmethod
    def _mark_spans(tokens, offsets, marks):
        """Return the sentence text with one or more spans wrapped in markers.

        ``offsets`` is either a single ``[start, end)`` pair or a list of such
        pairs. ``marks`` is ``(opening, closing)``. A copy of ``tokens`` is
        marked, so the document is never modified and tokens that already
        contain marker text (``<p>`` is a common literal token) survive.
        """
        spans = [offsets] if isinstance(offsets[0], int) else offsets
        marked = list(tokens)
        for span in spans:
            marked[span[0]] = marks[0] + marked[span[0]]
            marked[span[1] - 1] = marked[span[1] - 1] + marks[1]
        return " ".join(marked)

    def _related_sentences(self, item, events_by_id, event_id, relation_type, marks):
        """Collect marked sentences of events related to ``event_id``.

        Mentions of every source event are grouped per sentence so that each
        sentence appears once with all of its related spans marked. Sentences
        keep the order in which they are first encountered.
        """
        spans_by_sentence = {}
        for relation in item["causal_relation"][relation_type]:
            if event_id != relation[1]:
                continue
            for mention in events_by_id[relation[0]]['mention']:
                spans_by_sentence.setdefault(mention['sent_id'], []).append(mention['offset'])
        return [
            self._mark_spans(item['tokens'][sent_id], spans, marks)
            for sent_id, spans in spans_by_sentence.items()
        ]

    @staticmethod
    def _format_arguments(event):
        """Render the event's arguments as ``type: ... text: ...; `` items.

        Only the first mention of each argument is used. The result is the
        empty string when the event has no arguments.
        """
        return "".join(
            "type: " + argument['type'] + " text: " + argument['mentions'][0]['mention'] + "; "
            for argument in event['arguments']
        )

    def load_data(self):
        """Read the JSON-lines file and build the examples.

        Returns:
            ``(texts, labels, arguments, causes, preconditions)``: equally
            long lists with one entry per mention. ``causes[i]`` and
            ``preconditions[i]`` are lists of marked sentences; all mentions
            of one event share the same list objects.
        """
        with open(self.data_dir, 'r', encoding='utf8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # inside json.loads.
            raw_data = [json.loads(line) for line in f if line.strip()]
        texts = []
        labels = []
        arguments = []
        causes = []
        preconditions = []

        for item in tqdm(raw_data, desc='Loading data'):
            events = {event['id']: event for event in item['events']}
            for event in item['events']:
                cause = self._related_sentences(
                    item, events, event['id'], "CAUSE", ["<c>", "</c>"])
                precondition = self._related_sentences(
                    item, events, event['id'], "PRECONDITION", ["<p>", "</p>"])

                for mention in event['mention']:
                    sentence = self._mark_spans(
                        item['tokens'][mention['sent_id']],
                        mention['offset'],
                        ["<e>", "</e>"],
                    )
                    texts.append(self.prefix + sentence)
                    labels.append(mention['factuality'])
                    # Built per mention so that events without mentions are
                    # never required to carry an 'arguments' entry.
                    arguments.append(self._format_arguments(event))
                    causes.append(cause)
                    preconditions.append(precondition)
        return texts, labels, arguments, causes, preconditions

    def __len__(self):
        """Return the number of mention-level examples."""
        return len(self.texts)

    def _encode(self, text, max_length):
        """Tokenize ``text`` padded/truncated to exactly ``max_length``."""
        return self.tokenizer.encode_plus(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the tokenized source and target of example ``idx``.

        The source text is the marked sentence followed, when present, by the
        argument string, the cause sentences and the precondition sentences,
        in that order.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` (from the source
            text) and ``labels`` (the target's ``input_ids``).
        """
        input_texts = self.texts[idx]
        # The section labels are appended without a leading separator; the
        # strings are kept exactly as they were so inputs stay unchanged.
        if self.arguments[idx] != "":
            input_texts = input_texts + "ARGUMENTS: " + self.arguments[idx]
        if self.causes[idx]:
            input_texts = input_texts + "EVENT CAUSE: " + " ".join(self.causes[idx])
        if self.preconditions[idx]:
            input_texts = input_texts + "EVENT PRECONDITION: " + " ".join(self.preconditions[idx])
        source = self._encode(input_texts, self.max_length)
        target = self._encode(self.labels[idx], self.TARGET_MAX_LENGTH)
        # squeeze(0) drops only the leading batch axis added by
        # return_tensors='pt'; a bare squeeze() would also drop the sequence
        # axis whenever its length is 1.
        return {
            'input_ids': source['input_ids'].squeeze(0),
            'attention_mask': source['attention_mask'].squeeze(0),
            'labels': target['input_ids'].squeeze(0)
        }