"""
Preprocess AMR and surface forms.

Options:

- linearise: for seq2seq learning
  - simplify: simplify graphs and lowercase surface
  - anon: same as above but with anonymisation

- triples: for graph2seq learning
  - anon: anonymise NEs

"""
import json
import os
import re
import shutil

import penman
import torch
from torch import nn

class MaxPoolLayer(nn.Module):
    """
    Max pooling over the sequence (second) dimension that ignores padding.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, mask_or_lengths):
        """
        inputs: tensor of shape (batch_size, seq_len, hidden_size)
        mask_or_lengths: either per-example lengths of shape (batch_size), or
            a mask of shape (batch_size, seq_len) that is True at positions to
            ignore

        returns: tensor of shape (batch_size, hidden_size); an example with
            every position masked yields -inf
        """
        batch_size, seq_len, _ = inputs.size()
        if mask_or_lengths.dim() != 1:
            padding = mask_or_lengths
        else:
            # Position t of example b is padding when t >= length[b].
            steps = torch.arange(seq_len, device=inputs.device)
            padding = steps.unsqueeze(0).expand(batch_size, seq_len) >= mask_or_lengths.unsqueeze(1)
        # Masked positions are set to -inf so they can never win the max.
        blocked = inputs.masked_fill(padding.unsqueeze(-1).expand_as(inputs), float('-inf'))
        values, _ = blocked.max(1)
        return values


class MeanPoolLayer(nn.Module):
    """
    A layer that averages a 2-D tensor over its first dimension.

    Every row of ``inputs`` is weighted equally; no padding mask is applied
    (see ``MeanPoolLayer_old`` for the masked, sequence-dimension variant).
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, mask_or_lengths=None):
        """
        inputs: tensor of shape (num_rows, hidden_size)
        mask_or_lengths: accepted for interface compatibility with the other
            pooling layers; ignored

        returns: tensor of shape (hidden_size,), the mean of the rows

        raises: ValueError if ``inputs`` is not 2-D
        """
        if inputs.dim() != 2:
            raise ValueError(
                "MeanPoolLayer expects a 2-D tensor (num_rows, hidden_size), "
                "got shape {}".format(tuple(inputs.size()))
            )
        num_rows = inputs.size(0)
        return inputs.sum(0) / num_rows


class MeanPoolLayer_old(nn.Module):
    """
    Masked mean pooling over the sequence (second) dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, mask_or_lengths):
        """
        inputs: tensor of shape (batch_size, seq_len, hidden_size)
        mask_or_lengths: either per-example lengths of shape (batch_size), or
            a mask of shape (batch_size, seq_len) that is truthy at positions
            to ignore

        returns: tensor of shape (batch_size, hidden_size), the mean over the
            unmasked positions of each example
        """
        batch_size, seq_len, _ = inputs.size()
        if mask_or_lengths.dim() != 1:
            padding = mask_or_lengths
            # Number of unmasked positions per example.
            counts = (1 - mask_or_lengths.float()).sum(1)
        else:
            steps = torch.arange(seq_len, device=inputs.device)
            padding = steps.unsqueeze(0).expand(batch_size, seq_len) >= mask_or_lengths.unsqueeze(1)
            counts = mask_or_lengths.float()
        # Zero out masked positions so they do not contribute to the sum.
        zeroed = inputs.masked_fill(padding.unsqueeze(-1).expand_as(inputs), 0.0)
        return zeroed.sum(1) / counts.unsqueeze(-1)

def simplify(tokens, v2c):
    """
    Simplify a whitespace-tokenised PENMAN string into a bracketed token list.

    Transformations:
      * "(var" becomes "(" and "/" is dropped; the concept that follows is
        kept and its position is recorded under ``var`` in ``mapping``.
      * Role tokens (":ARG0", ...) keep only the role name.
      * Alignment suffixes ("~...") are removed from every token, and each
        ")" glued to a token becomes a separate ")" token.
      * Later mentions of a variable (re-entrancies) are replaced by that
        variable's concept when it is known, and their positions are also
        recorded in ``mapping``.
      * Concepts lose a two-digit sense suffix ("want-01" -> "want");
        double-quoted constants lose their quotes.

    Args:
        tokens: tokens of the PENMAN string, e.g. ``penman.encode(g).split()``.
        v2c: dict from variable name to concept (the concept may be None).

    Returns:
        (new_tokens, mapping): the simplified token list, and a dict from
        variable name to the set of indices in ``new_tokens`` that stand for
        that node.
    """
    sense_pattern = re.compile(r'-[0-9][0-9]$')
    mapping = {}
    new_tokens = []
    # Variable opened by the most recent "(var" token, and whether the next
    # non-role token is its concept (i.e. "/" was just seen). Initialised so
    # token lists that do not start with "(var /" no longer raise
    # UnboundLocalError.
    last_map = None
    save_map = False

    for tok in tokens:
        if tok.startswith('('):
            new_tokens.append('(')
            last_map = tok.replace("(", "")
            continue
        if tok == '/':
            save_map = True
            continue

        # Remove glued closing brackets and alignment information.
        new_tok = tok.strip(')').split('~')[0]

        if tok.startswith(':'):
            # Role token.
            new_tokens.append(new_tok)
        elif new_tok:
            # The token right after "/" is always a concept, even when it is
            # spelled like some variable (e.g. concept "i" while variable "i"
            # is "interest-01"), so only other tokens can be re-entrancies.
            if not save_map and new_tok in v2c:
                mapping.setdefault(new_tok, set()).add(len(new_tokens))
                # NOTE(review): the substituted concept keeps its sense
                # suffix, unlike the concept's first occurrence — confirm.
                if v2c[new_tok] is not None:
                    new_tok = v2c[new_tok]
            elif sense_pattern.search(new_tok):
                new_tok = new_tok[:-3]
            elif new_tok[0] == '"' and new_tok[-1] == '"':
                new_tok = new_tok[1:-1]

            if new_tok != "":
                new_tokens.append(new_tok)

            if save_map:
                mapping.setdefault(last_map, set()).add(len(new_tokens) - 1)
                save_map = False

        # Emit closing brackets even for tokens that are otherwise empty
        # (e.g. a lone ")"), so the bracket structure stays balanced.
        new_tokens.extend([')'] * tok.count(')'))

    return new_tokens, mapping



def simplify_nopar(tokens, v2c):
    """
    Simplify a whitespace-tokenised PENMAN string into a flat token list.

    Like ``simplify``, but no bracket tokens are produced: "(var" and "/" are
    dropped and closing ")" are discarded. Alignment suffixes ("~...") are
    removed, concepts lose a two-digit sense suffix, double-quoted constants
    lose their quotes, and re-entrant variables are replaced by their concept
    when it is known.

    Args:
        tokens: tokens of the PENMAN string, e.g. ``penman.encode(g).split()``.
        v2c: dict from variable name to concept (the concept may be None).

    Returns:
        (new_tokens, mapping): the simplified token list, and a dict from
        variable name to the set of indices in ``new_tokens`` that stand for
        that node.
    """
    sense_pattern = re.compile(r'-[0-9][0-9]$')
    mapping = {}
    new_tokens = []
    # Variable opened by the most recent "(var" token, and whether the next
    # non-role token is its concept. Initialised so inputs that do not start
    # with "(var /" no longer raise UnboundLocalError.
    last_map = None
    save_map = False

    for tok in tokens:
        if tok.startswith('('):
            last_map = tok.replace("(", "")
            continue
        if tok == '/':
            save_map = True
            continue

        # Remove glued closing brackets and alignment information.
        new_tok = tok.strip(')').split('~')[0]

        if tok.startswith(':'):
            # Role token.
            new_tokens.append(new_tok)
            continue
        if new_tok == "":
            continue

        # The token right after "/" is always a concept, even when it is
        # spelled like some variable, so only other tokens can be
        # re-entrancies.
        if not save_map and new_tok in v2c:
            mapping.setdefault(new_tok, set()).add(len(new_tokens))
            # NOTE(review): the substituted concept keeps its sense suffix,
            # unlike the concept's first occurrence — confirm.
            if v2c[new_tok] is not None:
                new_tok = v2c[new_tok]
        elif sense_pattern.search(new_tok):
            new_tok = new_tok[:-3]
        elif new_tok[0] == '"' and new_tok[-1] == '"':
            new_tok = new_tok[1:-1]

        if new_tok != "":
            new_tokens.append(new_tok)

        if save_map:
            mapping.setdefault(last_map, set()).add(len(new_tokens) - 1)
            save_map = False

    return new_tokens, mapping


def get_positions(new_tokens, src):
    """Return, in ascending order, every index at which ``src`` occurs in ``new_tokens``."""
    return [position for position, token in enumerate(new_tokens) if token == src]


def get_line_graph_new(graph, new_tokens, mapping, roles_in_order, amr):
    """
    Build a line-graph edge list over the positions of ``new_tokens``.

    Every position in ``new_tokens`` (the output of ``simplify`` or
    ``simplify_nopar``) is a node, and role tokens are nodes too. A triple
    (src, role, tgt) is encoded by linking every position of ``src`` to the
    role token's position, and the role token's position to every position
    of ``tgt``.

    Args:
        graph: penman graph; ``graph.triples`` is read, and ``graph`` is
            forwarded to ``get_edge`` for its error report.
        new_tokens: simplified token list.
        mapping: variable -> set of positions in ``new_tokens``. Names missing
            from it (constants) are located by exact token match with
            ``get_positions``.
        roles_in_order: role tokens in the order they appear in the
            serialized PENMAN string (":instance-of" excluded); used to detect
            roles written in inverted ("-of") form.
        amr: unused.

    Returns:
        (nodes_to_print, triples): ``nodes_to_print`` is ``new_tokens`` itself
        (the same list object). ``triples`` is a list of (from, to, label)
        with label 'd' for src -> role -> tgt edges, 'r' for their reverse,
        and 's' for the self-loop added when ``new_tokens`` is empty.

    NOTE(review): failures are printed, not raised:
      * a role/order mismatch is reported and processing continues;
      * if the position lookup raises, any of ``src_id``/``tgt_id`` not yet
        assigned in that iteration keeps its value from the previous triple
        (UnboundLocalError if there is none);
      * ``get_edge`` returns None when the role token is not found, so None
        can end up in ``triples`` and a later ``sorted(triples)`` (as in
        ``simplify_amr``) raises TypeError.
      Having fewer ``roles_in_order`` than relation triples raises IndexError.
    """
    triples = []
    # Alias, not a copy: the returned node list is new_tokens itself.
    nodes_to_print = new_tokens

    graph_triples = graph.triples

    # Position of the most recently matched role token. get_edge searches
    # forward from edge_id + 1, so roles are matched in token order.
    edge_id = -1
    # 'd' edges already emitted, used to avoid duplicates.
    triples_set = set()
    # Index into roles_in_order of the role token for the current triple.
    count_roles = 0
    for triple in graph_triples:
        src, edge, tgt = triple

        # Instance triples produce no edges.
        if edge == ':instance' or edge == ':instance-of':
            continue

        # The role token for this triple is written inverted (":X-of").
        # Presumably graph.triples holds the un-inverted form, so rewrite the
        # triple in its surface form (append "-of", swap src/tgt) to match the
        # token in new_tokens. ":...-off" roles and ":consist-of" are exempt,
        # presumably because "-of" there is part of the role name itself.
        if "-of" in roles_in_order[count_roles] and "-off" not in roles_in_order[count_roles]:
            if edge != ':consist-of':
                edge = edge + "-of"
                old_tgt = tgt
                tgt = src
                src = old_tgt

        # Sanity check that graph.triples and the serialized role order agree.
        # A mismatch is only printed (and the assert is skipped entirely
        # under `python -O`).
        try:
            assert roles_in_order[count_roles] == edge
        except:
            print(roles_in_order)
            print(count_roles)
            print(edge)
        count_roles += 1

        # :wiki links produce no edges; they still consumed a role slot above.
        if edge == ':wiki':
            continue

        # Strip quotes so constants match the unquoted tokens produced by the
        # simplify step.
        src = str(src).replace("\"", "")
        tgt = str(tgt).replace("\"", "")

        try:
            # Variables use every position recorded in mapping; anything else
            # uses every exact occurrence in new_tokens.
            if src not in mapping:
                src_id = get_positions(new_tokens, src)
            else:
                src_id = sorted(list(mapping[src]))
            # check edge to verify
            edge_id = get_edge(new_tokens, edge, edge_id, triple, mapping, graph)

            if tgt not in mapping:
                tgt_id = get_positions(new_tokens, tgt)
            else:
                tgt_id = sorted(list(mapping[tgt]))
        except:
            # NOTE(review): after this report the loop continues with
            # whatever src_id / tgt_id / edge_id currently hold.
            print(graph_triples)
            print(src, edge, tgt)
            print("error")
            print(" ".join(new_tokens))

        # src -> role ('d') plus the reverse ('r') edge, once per pair.
        for s_id in src_id:
            if (s_id, edge_id, 'd') not in triples_set:
                triples.append((s_id, edge_id, 'd'))
                triples_set.add((s_id, edge_id, 'd'))
                triples.append((edge_id, s_id, 'r'))
        # role -> tgt ('d') plus the reverse ('r') edge, once per pair.
        for t_id in tgt_id:
            if (edge_id, t_id, 'd') not in triples_set:
                triples.append((edge_id, t_id, 'd'))
                triples_set.add((edge_id, t_id, 'd'))
                triples.append((t_id, edge_id, 'r'))

    if nodes_to_print == []:
        # Fallback self-loop on position 0 when there are no tokens at all.
        # NOTE(review): the original comment called this the single-node case,
        # but it only triggers for an empty token list, where position 0 does
        # not exist — confirm intent.
        triples.append((0, 0, 's'))
    return nodes_to_print, triples


def get_edge(tokens, edge, edge_id, triple, mapping, graph):
    """
    Return the index of the first occurrence of ``edge`` in ``tokens`` after
    position ``edge_id``.

    When there is none, diagnostic context is printed and None is returned.
    ``triple``, ``mapping`` and ``graph`` (only ``graph.triples``) are used
    for that report only.
    """
    start = edge_id + 1
    match = next((pos for pos in range(start, len(tokens)) if tokens[pos] == edge), None)
    if match is not None:
        return match

    print(tokens)
    print(len(tokens))
    print(triple)
    print(graph.triples)
    print(mapping)
    print(edge, edge_id)
    print("error2")
    return None


def create_set_instances(graph_penman):
    """
    Map each variable of a penman graph to its concept.

    Built from ``graph_penman.instances()`` using each instance's ``source``
    as key and ``target`` as value; if a variable occurs in several instance
    triples, the last one wins.
    """
    return {instance.source: instance.target for instance in graph_penman.instances()}


def get_roles_penman(graph_triples, roles_in_order):
    """
    Reconstruct role labels as they are written in the serialized PENMAN string.

    For every non-instance triple, the role is taken from the triple and gets
    an "-of" suffix when its serialized token is inverted. This uses the same
    rule as ``get_line_graph_new``: the token contains "-of" but not "-off",
    and the role is not ":consist-of". Previously this function appended
    "-of" to ":consist-of" and ":...-off" roles as well.

    Args:
        graph_triples: (source, role, target) triples.
        roles_in_order: role tokens in serialization order (":instance-of"
            excluded), one per non-instance triple.

    Returns:
        list of role strings, one per non-instance triple.

    Raises:
        IndexError: if ``roles_in_order`` has fewer entries than there are
            non-instance triples.
    """
    roles_penman = []
    count_roles = 0
    for triple in graph_triples:
        role = triple[1]
        if role in (':instance', ':instance-of'):
            continue
        surface_role = roles_in_order[count_roles]
        if "-of" in surface_role and "-off" not in surface_role and role != ':consist-of':
            role = role + "-of"
        roles_penman.append(role)
        count_roles += 1

    return roles_penman


def simplify_amr(amr):
    """
    Decode an AMR string and convert it into bracketed tokens plus line-graph edges.

    Pipeline: ``penman.decode`` -> ``penman.encode`` -> ``simplify`` ->
    ``get_line_graph_new`` -> sort the edge list.

    Returns:
        (nodes, triples) on success, or (None, None) when decoding,
        simplification, or sorting the edges fails (the failure is printed).
        Exceptions raised inside ``get_line_graph_new`` itself propagate.
    """
    try:
        graph_penman = penman.decode(amr)
        v2c_penman = create_set_instances(graph_penman)
        # Split the re-encoded graph on any whitespace. The previous code
        # deleted '\t' and '\n' before splitting, which glued two tokens
        # together whenever a line break or tab was their only separator.
        tokens = penman.encode(graph_penman).split()
    except Exception:
        print('error')
        return None, None

    try:
        new_tokens, mapping = simplify(tokens, v2c_penman)
    except Exception as e:
        # Python 3 exceptions have no ``.message``; accessing it here used to
        # raise AttributeError instead of returning (None, None).
        print(e, e.args)
        print('error simply')
        return None, None

    # Role tokens in serialization order, used by get_line_graph_new to
    # recognise roles written in inverted ("-of") form.
    roles_in_order = [tok for tok in tokens if tok.startswith(":") and tok != ':instance-of']

    nodes, triples = get_line_graph_new(graph_penman, new_tokens, mapping, roles_in_order, amr)
    try:
        triples = sorted(triples)
    except TypeError:
        # get_line_graph_new can emit None edge ids (role token not found),
        # which cannot be ordered against ints.
        return None, None
    return nodes, triples


def simplify_amr_nopar(amr):
    """
    Decode an AMR string and convert it into flat tokens plus line-graph edges.

    Same pipeline as ``simplify_amr`` but linearised with ``simplify_nopar``,
    so the token list contains no bracket tokens.

    Returns:
        (nodes, triples) on success, or a single None when decoding,
        simplification, or sorting the edges fails (the failure is printed).
        Exceptions raised inside ``get_line_graph_new`` itself propagate.

    NOTE(review): the failure value is a bare None, whereas ``simplify_amr``
    returns (None, None); callers that unpack the result must check for None
    first.
    """
    try:
        graph_penman = penman.decode(amr)
        v2c_penman = create_set_instances(graph_penman)
        # Split the re-encoded graph on any whitespace (deleting '\t'/'\n'
        # first could glue neighbouring tokens together).
        tokens = penman.encode(graph_penman).split()
    except Exception:
        print('error')
        # Previously this called exit(), terminating the whole process on a
        # single malformed AMR; report failure like the other error paths.
        return None

    try:
        new_tokens, mapping = simplify_nopar(tokens, v2c_penman)
    except Exception as e:
        # Python 3 exceptions have no ``.message``; accessing it here used to
        # raise AttributeError instead of returning None.
        print(e, e.args)
        print('error simply')
        return None

    # Role tokens in serialization order, used by get_line_graph_new to
    # recognise roles written in inverted ("-of") form.
    roles_in_order = [tok for tok in tokens if tok.startswith(":") and tok != ':instance-of']

    nodes, triples = get_line_graph_new(graph_penman, new_tokens, mapping, roles_in_order, amr)
    try:
        triples = sorted(triples)
    except TypeError:
        # None edge ids (role token not found) cannot be ordered against
        # ints; treat this as a failure instead of crashing the caller.
        return None
    return nodes, triples



def simplify_amr_triples(amr):
    """
    Linearise an AMR as a flat " <H> head <R> role <T> tail ..." string.

    Every triple of the penman-decoded graph except ":instance" becomes one
    segment. Head/tail variables are replaced by their concept; other values
    are kept as-is. A triple whose parts cannot be concatenated is reported
    (exception, graph triples, instances) and skipped.
    """
    graph_penman = penman.decode(amr)
    v2c_penman = create_set_instances(graph_penman)

    segments = []
    for triple in graph_penman.triples:
        role = triple[1]
        if role == ':instance':
            continue
        try:
            head = v2c_penman[triple[0]] if triple[0] in v2c_penman else triple[0]
            tail = v2c_penman[triple[2]] if triple[2] in v2c_penman else triple[2]
            segments.append(' <H> ' + head + ' <R> ' + role + ' <T> ' + tail)
        except Exception as e:
            print(e)
            print(graph_penman.triples)
            print(graph_penman.instances())

    return ''.join(segments)


def save_metrics(metrics, output_file):
    """
    Append ``metrics`` as one JSON line to ``output_file``.

    The file is created if it does not exist yet. It is written as UTF-8
    with non-ASCII characters kept as-is rather than escaped.
    """
    mode = 'a' if os.path.exists(output_file) else 'w'
    with open(output_file, mode, encoding="utf-8") as handle:
        handle.write(json.dumps(metrics, ensure_ascii=False) + "\n")


def maybe_save_checkpoint(metrics, save_dir, global_step, model, tokenizer):
    """
    Save a checkpoint when ``metrics["bacc"]`` beats the best score so far.

    The best score and the folder holding it are read from the first line of
    ``<save_dir>/best_checkpoint.json``. A missing or empty file counts as a
    best score of 0. On improvement this:

    * creates ``<save_dir>/step_<global_step>/`` (it must not exist yet),
      saves ``model`` there as ``model.pt`` with ``torch.save``, and calls
      ``tokenizer.save_pretrained`` on the same folder;
    * deletes the previous best checkpoint folder;
    * sets ``metrics["folder_checkpoint"]`` (mutating the caller's dict) and
      rewrites ``best_checkpoint.json`` so that ``metrics`` is its only line.

    Args:
        metrics: dict with at least a numeric ``"bacc"`` entry.
        save_dir: directory holding the checkpoints and best_checkpoint.json.
        global_step: training step used to name the checkpoint folder.
        model: object passed to ``torch.save``.
        tokenizer: object with a ``save_pretrained(directory)`` method.
    """
    best_bacc = 0
    folder_checkpoint = ""

    output_file = os.path.join(save_dir, "best_checkpoint.json")
    if os.path.exists(output_file):
        with open(output_file, encoding="utf-8") as f:
            data = [json.loads(line) for line in f if line.strip()]
        # An empty record file used to raise IndexError here.
        if data:
            best_bacc = data[0]["bacc"]
            folder_checkpoint = data[0]["folder_checkpoint"]

    # Written as `not >` (rather than `<=`) so a NaN score is never treated
    # as an improvement, matching the original comparison.
    if not metrics["bacc"] > best_bacc:
        return

    save_sub_dir = os.path.join(save_dir, "step_{}".format(global_step))
    os.mkdir(save_sub_dir)
    torch.save(model, os.path.join(save_sub_dir, 'model.pt'))
    tokenizer.save_pretrained(save_sub_dir)

    # Delete the previous best without going through a shell: the old
    # os.system("rm -rf " + path) split paths on whitespace and interpreted
    # shell metacharacters, so it could delete the wrong files.
    if folder_checkpoint and os.path.abspath(folder_checkpoint) != os.path.abspath(save_sub_dir):
        shutil.rmtree(folder_checkpoint, ignore_errors=True)

    # save_metrics appends to an existing file, so drop the old record first
    # to keep exactly one line.
    if os.path.exists(output_file):
        os.remove(output_file)
    metrics["folder_checkpoint"] = save_sub_dir
    save_metrics(metrics, output_file)