import json, os, torch, copy
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import configparser

def parse_path_cfg(cfg_file="dataset.cfg"):
    """Read the ``[PATH]`` section of an INI-style dataset config file.

    Args:
        cfg_file: config file to read. The default ``"dataset.cfg"`` is a
            relative path, so it is resolved against the current working
            directory.

    Returns:
        The ``configparser`` section proxy for ``[PATH]``; entries are
        looked up by key, e.g. ``path['VOCAB_FILE']``.

    Raises:
        KeyError: if no ``[PATH]`` section was read. ``ConfigParser.read``
            silently skips files that do not exist, so a missing config file
            also surfaces as this KeyError.
    """
    cfg = configparser.ConfigParser()
    cfg.read(cfg_file, encoding="utf-8")
    return cfg['PATH']

def get_vocab(path):
    """Load the word -> index vocabulary named by ``path['VOCAB_FILE']``.

    Args:
        path: mapping (e.g. the section returned by ``parse_path_cfg``) with a
            ``'VOCAB_FILE'`` entry pointing to a UTF-8 JSON file.

    Returns:
        The decoded JSON content (used elsewhere in this module as a
        ``dict`` of word -> id).

    Raises:
        SystemExit: with status 1 if the vocab file does not exist.
    """
    vocab_file = path['VOCAB_FILE']
    if not os.path.exists(vocab_file):
        print('Missing vocab file at ', vocab_file)
        # Non-zero status so shells/schedulers see the failure; SystemExit
        # is raised directly because the `exit` builtin only exists when
        # the `site` module is loaded.
        raise SystemExit(1)
    print('Loading from ', vocab_file)
    with open(vocab_file, 'r', encoding='utf-8') as f:
        return json.load(f)

class CitationDataset(Dataset):
    """Map-style ``Dataset`` view over an in-memory sequence.

    Indexing and ``len`` are delegated unchanged to the wrapped sequence,
    which is stored by reference (not copied).
    """

    def __init__(self, citation_dataset):
        self.citation_dataset = citation_dataset

    def __getitem__(self, item):
        items = self.citation_dataset
        return items[item]

    def __len__(self):
        items = self.citation_dataset
        return len(items)

class CitationDataLoader:
    """Batch an in-memory list of samples by batching their positions.

    A torch ``DataLoader`` runs over the integer positions
    ``0 .. len(citation_dataset) - 1`` rather than over the samples
    themselves. Each yielded batch is a plain Python list of the original
    sample objects, so no ``collate_fn`` has to handle them. Incomplete
    final batches are dropped (``drop_last=True``).
    """

    def __init__(self, citation_dataset: list, batch_size, shuffle):
        """
        Args:
            citation_dataset: indexable, sized sequence of samples.
            batch_size: number of samples per yielded batch.
            shuffle: whether the torch DataLoader shuffles the positions.
        """
        self.citation_dataset = citation_dataset
        self.indexes = DataLoader(CitationDataset(list(range(len(citation_dataset)))),
                                  batch_size=batch_size, shuffle=shuffle, drop_last=True)

    def __len__(self):
        # NOTE: this is the number of *samples*, not batches (unlike torch's
        # DataLoader); iteration yields len(self) // batch_size batches.
        return len(self.citation_dataset)

    def __iter__(self):
        for batch in self.indexes:
            # `batch` is the default-collated tensor of positions.
            yield [self.citation_dataset[i.item()] for i in batch]

def vectorize_sent(word2index, sent, max_len, is_tgt=False):
    """Map a tokenized sentence to vocabulary ids.

    Each token is lowercased before lookup; tokens not found in
    *word2index* map to ``word2index["UNK"]``.

    Args:
        word2index: vocabulary mapping, queried with the lowercased token.
        sent: iterable of word tokens.
        max_len: keep at most this many tokens; ``None`` keeps all, and
            ``0`` (or a negative value) keeps none.
        is_tgt: if True, wrap the ids in ``word2index["SOS"]`` ...
            ``word2index["EOS"]``; the two markers do not count toward
            *max_len*.

    Returns:
        list of ids.
    """
    ids = []
    for word in sent:
        # Check the cap before appending: checking only after an append can
        # never match max_len == 0, which previously disabled truncation.
        if max_len is not None and len(ids) >= max_len:
            break
        word = word.lower()
        ids.append(word2index[word] if word in word2index else word2index["UNK"])

    if is_tgt:
        ids = [word2index["SOS"]] + ids + [word2index["EOS"]]
    return ids

def process_doc(sents, word2index, max_token=200):
    """Vectorize a document's sentences into one flat id list within a token budget.

    If the document has more than *max_token* tokens in total, sentences
    are truncated. They are visited from shortest to longest, and each gets
    an equal share of the budget still remaining, so budget left unused by
    short sentences flows to longer ones. Otherwise every sentence is kept
    whole. In both cases every token is mapped to an id with
    ``vectorize_sent``. The caller's *sents* list is not modified.

    Args:
        sents: list of sentences, each a list of word tokens.
        word2index: vocabulary mapping forwarded to ``vectorize_sent``.
        max_token: maximum total number of ids in the result.

    Returns:
        (result, end_index): the concatenated ids, and for each sentence the
        index in *result* of its last id. An empty sentence repeats the
        previous sentence's end index, or -1 if nothing precedes it.
    """
    lengths = [len(s) for s in sents]
    if sum(lengths) > max_token:
        vectorized = [None] * len(sents)
        budget = max_token
        remaining = len(sents)
        for i in np.argsort(lengths):
            share = budget // remaining
            # Slice as well so the share is honoured even when it is 0.
            ids = vectorize_sent(word2index, sents[i], share)[:share]
            vectorized[i] = ids
            budget -= len(ids)
            remaining -= 1
    else:
        vectorized = [vectorize_sent(word2index, s, None) for s in sents]

    result = []
    end_index = []
    for ids in vectorized:
        result.extend(ids)
        end_index.append(len(result) - 1)
    # Invariant against the caller's budget (not the decremented remainder).
    assert len(result) <= max_token
    return result, end_index

def vectorize_sents(word2index, sents, max_len, is_context=False, max_context_sents=49):
    """Vectorize a list of tokenized sentences into one flat id list.

    Args:
        word2index: vocabulary mapping forwarded to ``vectorize_sent``.
        sents: list of sentences, each a list of word tokens.
        max_len: per-sentence token cap forwarded to ``vectorize_sent``
            (``None`` = no cap).
        is_context: if True, only the first *max_context_sents* sentences
            are used.
        max_context_sents: sentence cap applied when *is_context* is True.

    Returns:
        (ids, lengths): the concatenated ids of all kept sentences, and the
        number of ids each kept sentence contributed. Callers store the
        second value under ``*_end_idx`` keys, but it holds per-sentence
        lengths, not cumulative end offsets.
    """
    if is_context:
        sents = sents[:max_context_sents]
    ids = []
    lengths = []
    for sent in sents:
        sent_ids = vectorize_sent(word2index, sent, max_len)
        ids.extend(sent_ids)
        lengths.append(len(sent_ids))
    return ids, lengths

def create_src_map(sents, word2index):
    """Build a per-example copy vocabulary over the tokens of *sents*.

    Args:
        sents: list of sentences, each a list of word tokens.
        word2index: global vocabulary; must contain ``"UNK"`` and ``"PAD"``.

    Returns:
        (src_vocab, src_map, src_vocab_to_vocab):

        - src_vocab: token -> local id. ``"UNK"`` is 0, ``"PAD"`` is 1, and
          other tokens are numbered in order of first appearance.
        - src_map: local id of every token of the flattened *sents*.
        - src_vocab_to_vocab: local id -> global id. Tokens missing from
          *word2index* map to ``word2index["UNK"]``.
    """
    src_vocab = {"UNK": 0, "PAD": 1}
    unk_id = word2index["UNK"]
    src_vocab_to_vocab = {src_vocab["UNK"]: unk_id, src_vocab["PAD"]: word2index["PAD"]}
    src = [word for sent in sents for word in sent]
    for word in src:
        if word not in src_vocab:
            local_id = len(src_vocab)
            src_vocab[word] = local_id
            # Every new local id needs its own entry, OOV tokens included
            # (they map to the global UNK id).
            # NOTE(review): this lookup is case-sensitive while
            # vectorize_sent lowercases tokens; confirm inputs are
            # pre-lowercased.
            src_vocab_to_vocab[local_id] = word2index.get(word, unk_id)

    src_map = [src_vocab[w] for w in src]
    return src_vocab, src_map, src_vocab_to_vocab

def citation_data_reader(path):
    """Read citation samples and attach abstract / neighbour information.

    Args:
        path: mapping with three entries:
            ``'CITATION_FILE'``: one JSON object per line, each with
            ``src_paper_id`` and ``tgt_paper_id``.
            ``'PAPER_MAP_FILE'``: JSON object, paper id -> abstract
            representation (presumably bag-of-words; confirm).
            ``'CITE_NEIGHBOR_FILE'``: JSON object, paper id -> list of
            neighbour paper ids.

    Returns:
        list of sample dicts, each extended with ``src_abstract_bow``,
        ``src_neighbor``, ``tgt_abstract_bow`` and ``tgt_neighbor`` (the
        neighbour lists hold the neighbours' abstract entries).
    """
    file_path = path["CITATION_FILE"]
    with open(path["PAPER_MAP_FILE"], "r", encoding="utf-8") as f:
        paper_abs_map = json.load(f)
    with open(path["CITE_NEIGHBOR_FILE"], "r", encoding="utf-8") as f:
        cite = json.load(f)
    print('Reading from %s' % file_path, end=' ... ')
    dataset = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            sample = json.loads(line)
            src_id = sample['src_paper_id']
            tgt_id = sample['tgt_paper_id']

            sample['src_abstract_bow'] = paper_abs_map[src_id]
            sample['src_neighbor'] = [paper_abs_map[p] for p in cite[src_id]]

            sample['tgt_abstract_bow'] = paper_abs_map[tgt_id]
            sample['tgt_neighbor'] = [paper_abs_map[p] for p in cite[tgt_id]]

            dataset.append(sample)
    print('Totally %d instances' % len(dataset))
    return dataset

def _load_json_file(file_path):
    """Return the decoded contents of the UTF-8 JSON file at *file_path*."""
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def _build_copy_vocab(sents, word2index):
    """Build the copy vocabulary of *sents* via ``create_src_map``.

    Returns (stoi, src_map, record), where *record* is the
    ``{"stoi": ..., "itos": ...}`` dict stored on a sample.
    """
    stoi, src_map, _ = create_src_map(sents, word2index)
    record = {"stoi": copy.deepcopy(stoi), "itos": {idx: word for word, idx in stoi.items()}}
    return stoi, src_map, record

def _align_to_copy_vocab(words, stoi):
    """Map *words* to ids in *stoi* (unknown -> ``stoi['UNK']``), framed by 0 on both ends."""
    unk = stoi['UNK']
    return [0] + [stoi.get(w, unk) for w in words] + [0]

def data_reader_generation(data_type, path, word2index, max_len=20):
    """Load and vectorize one data split for citation-text generation.

    Samples whose ``context`` is empty are skipped. For every other sample
    this function:

    - attaches abstract and neighbour entries for both papers, plus
      ``relation = 1``;
    - vectorizes ``src_abstract`` and ``tgt_abstract``;
    - builds copy vocabularies, source maps and citation alignments over
      ``tgt_abstract`` and over ``context``;
    - vectorizes ``citation[0]`` (first 50 tokens) with SOS/EOS.

    Args:
        data_type: ``'train'``, ``'dev'`` or ``'test'``; selects
            ``path['<DATA_TYPE>_FILE']``.
        path: mapping with the split file plus ``'PAPER_MAP_FILE'`` and
            ``'CITE_NEIGHBOR_FILE'``.
        word2index: vocabulary; needs ``'UNK'``, ``'PAD'``, ``'SOS'`` and
            ``'EOS'`` entries.
        max_len: unused; kept for backward compatibility.

    Returns:
        list of processed sample dicts.
    """
    assert data_type in ['train', 'dev', 'test']
    # Must equal the sentence cap vectorize_sents applies with is_context=True,
    # so the context copy map covers exactly the vectorized context.
    max_context_sents = 49
    max_citation_len = 50

    print('Reading from %s' % path["PAPER_MAP_FILE"], end=' ... ')
    paper_abs_map = _load_json_file(path["PAPER_MAP_FILE"])
    print('Reading from %s' % path["CITE_NEIGHBOR_FILE"], end=' ... ')
    cite = _load_json_file(path["CITE_NEIGHBOR_FILE"])
    file_path = path['%s_FILE' % data_type.upper()]
    print('Reading from %s' % file_path, end=' ... ')
    dataset = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            sample = json.loads(line)
            if not sample['context']:
                continue

            src_id = sample['src_paper_id']
            tgt_id = sample['tgt_paper_id']
            sample['src_abstract_bow'] = paper_abs_map[src_id]
            sample['src_neighbor'] = [paper_abs_map[p] for p in cite[src_id]]
            sample['tgt_abstract_bow'] = paper_abs_map[tgt_id]
            sample['tgt_neighbor'] = [paper_abs_map[p] for p in cite[tgt_id]]
            sample['relation'] = 1

            citation_words = sample['citation'][0][:max_citation_len]

            # Abstracts: the copy vocab is built from the raw tgt_abstract
            # before it is replaced by its ids.
            sample['src_abstract'], sample['src_abs_end_idx'] = vectorize_sents(word2index, sample['src_abstract'], None)
            abs_stoi, sample['abs_src_map'], sample["abs_src_vocab"] = _build_copy_vocab(sample['tgt_abstract'], word2index)
            sample['tgt_abstract'], sample['tgt_abs_end_idx'] = vectorize_sents(word2index, sample['tgt_abstract'], None)
            sample['abs_align'] = _align_to_copy_vocab(citation_words, abs_stoi)

            # Context: truncate once so the copy map and the vectorized
            # context describe the same tokens.
            context_sents = sample['context'][:max_context_sents]
            ctx_stoi, sample['context_src_map'], sample["context_src_vocab"] = _build_copy_vocab(context_sents, word2index)
            sample['context'], sample['context_end_idx'] = vectorize_sents(word2index, context_sents, None, is_context=True)
            sample['context_align'] = _align_to_copy_vocab(citation_words, ctx_stoi)

            sample['citation'] = vectorize_sent(word2index, sample['citation'][0], max_citation_len, is_tgt=True)

            dataset.append(sample)
    print('Totally %d instances' % len(dataset))
    return dataset

def get_citation_DataLoader(path, batch_size):
    """Read every citation sample and wrap them in a shuffling batch loader.

    Args:
        path: mapping consumed by ``citation_data_reader``.
        batch_size: samples per batch (the incomplete last batch is dropped
            by ``CitationDataLoader``).

    Returns:
        A ``CitationDataLoader`` over all samples with ``shuffle=True``.
    """
    return CitationDataLoader(citation_data_reader(path), batch_size=batch_size, shuffle=True)

def get_DataLoader(path, word2index, batch_size, load_train):
    """Build the (train, test) loaders for citation-text generation.

    Args:
        path: mapping consumed by ``data_reader_generation``.
        word2index: vocabulary used for vectorization.
        batch_size: samples per batch for both loaders.
        load_train: if falsy, the training split is not read at all.

    Returns:
        (train, test): ``train`` is a shuffling loader, or ``None`` when
        *load_train* is falsy; ``test`` is never shuffled.
    """
    train = None
    if load_train:
        train = CitationDataLoader(data_reader_generation('train', path, word2index),
                                   batch_size=batch_size, shuffle=True)
    test = CitationDataLoader(data_reader_generation('test', path, word2index),
                              batch_size=batch_size, shuffle=False)
    return train, test

if __name__ == '__main__':
    # Smoke test: load config, vocabulary and both splits, then walk the
    # training loader, printing the running sample count and checking that
    # every batch is full (CitationDataLoader drops the incomplete last batch).
    batch_size = 32
    path = parse_path_cfg()
    word2index = get_vocab(path)

    train, test = get_DataLoader(path, word2index, batch_size, True)

    cnt = 0
    for data in train:
        cnt += len(data)
        print(cnt)
        assert len(data) == batch_size
    print(cnt)