import torch


def repackage_hidden(h):
    """Return ``h`` with every tensor detached from its autograd history.

    ``h`` may be a single ``torch.Tensor`` or an (arbitrarily nested)
    iterable of tensors. Every non-tensor container is rebuilt as a tuple
    whose items have been detached recursively.
    """
    if not isinstance(h, torch.Tensor):
        # Container: recurse into each element and rebuild as a tuple.
        return tuple(map(repackage_hidden, h))
    return h.detach()


def get_batch(source, i, seq_len=None, token_data=True):
    """Slice a window of columns (dimension 1) out of ``source``.

    Args:
        source: tensor sliced as ``source[:, start:stop]``; the window is
            taken along dimension 1.
        i: first column of the window.
        seq_len: maximum window length. It is truncated so the window never
            runs past the end of ``source``. ``None`` means "use every
            remaining column".
        token_data: if true, also return the target window shifted one
            column to the right. The window is then limited so that the
            shifted target still fits.

    Returns:
        ``(data, target)`` when ``token_data`` is true, otherwise ``data``.
    """
    # With a target we need one extra column after the data window.
    if token_data:
        available = source.size(1) - 1 - i
    else:
        available = source.size(1) - i
    # min(None, int) raises TypeError, so treat None explicitly.
    seq_len = available if seq_len is None else min(seq_len, available)

    data = source[:, i:i + seq_len]
    if not token_data:
        return data
    target = source[:, i + 1:i + 1 + seq_len]
    return data, target


def evalb(pred_tree_list, targ_tree_list):
    """Score predicted trees against reference trees with the EVALB binary.

    Each tree is a nested Python list whose leaves are strings. Each one is
    converted to an ``nltk.Tree``: internal nodes are labelled ``<unk>`` and
    every leaf is wrapped in a ``<word>`` pre-terminal. The tree is then
    lower-cased, flattened onto one line and written to a temporary file.
    ``./EVALB/evalb -p ./EVALB/COLLINS.prm`` (relative to the current
    working directory) is run on the reference and predicted files. The
    first Bracketing Recall / Precision / FMeasure values in its output are
    printed.

    Pairs are formed with ``zip``, so extra trees in the longer list are
    ignored.

    Returns:
        The bracketing F-measure reported by EVALB, as a float.

    Raises:
        FileNotFoundError: if the EVALB program does not exist.
        RuntimeError: if the EVALB output does not contain all three scores.
    """
    import os
    import re
    import subprocess
    import tempfile

    import nltk

    def process_str_tree(str_tree):
        # Collapse nltk's multi-line pretty-print onto a single line. Only
        # spaces and newlines are collapsed, so tokens containing '|' are
        # left intact.
        return re.sub(r'[ \n]+', ' ', str_tree)

    def list2tree(node):
        # Lists become '<unk>'-labelled internal nodes; strings become
        # '<word>' pre-terminals. Other node types yield None.
        if isinstance(node, list):
            return nltk.Tree('<unk>', [list2tree(child) for child in node])
        elif isinstance(node, str):
            return nltk.Tree('<word>', [node])

    def to_line(tree):
        return process_str_tree(str(list2tree(tree)).lower()) + '\n'

    evalb_dir = os.path.join(os.getcwd(), "EVALB")
    evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
    evalb_program_path = os.path.join(evalb_dir, "evalb")

    evalb_recall = evalb_precision = evalb_fscore = None

    # The context manager removes the temp directory even if writing the
    # trees, running EVALB or parsing its output raises.
    with tempfile.TemporaryDirectory(prefix="evalb-") as temp_dir:
        temp_file_path = os.path.join(temp_dir, "pred_trees.txt")
        temp_targ_path = os.path.join(temp_dir, "true_trees.txt")
        temp_eval_path = os.path.join(temp_dir, "evals.txt")

        print("Temp: {}, {}".format(temp_file_path, temp_targ_path))
        with open(temp_file_path, "w", encoding="utf-8") as temp_tree_file, \
                open(temp_targ_path, "w", encoding="utf-8") as temp_targ_file:
            for pred_tree, targ_tree in zip(pred_tree_list, targ_tree_list):
                temp_tree_file.write(to_line(pred_tree))
                temp_targ_file.write(to_line(targ_tree))

        # Pass arguments as a list (no shell), so paths containing spaces or
        # shell metacharacters are handled safely. EVALB's stdout is captured
        # into the evals file.
        with open(temp_eval_path, "wb") as eval_file:
            subprocess.run(
                [evalb_program_path, "-p", evalb_param_path,
                 temp_targ_path, temp_file_path],
                stdout=eval_file)

        with open(temp_eval_path, encoding="utf-8", errors="replace") as infile:
            for line in infile:
                match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
                if match:
                    evalb_recall = float(match.group(1))
                match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
                if match:
                    evalb_precision = float(match.group(1))
                match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
                if match:
                    evalb_fscore = float(match.group(1))
                    # Keep only the first summary block's scores.
                    break

    if any(v is None for v in (evalb_recall, evalb_precision, evalb_fscore)):
        raise RuntimeError(
            "Could not parse bracketing scores from EVALB output "
            "(program: {})".format(evalb_program_path))

    print('-' * 80)
    print('Evalb Prec:', evalb_precision,
          ', Evalb Reca:', evalb_recall,
          ', Evalb F1:', evalb_fscore)

    return evalb_fscore


def generate_idx(prev_structure, structure, structure_first, nslot):
    """Compute one index per new position from an earlier position's value.

    ``prev_structure`` and ``structure`` are concatenated along the last
    dimension into ``full``. Let ``offset = prev_structure.size(1)``. For
    every column ``t`` of ``structure_first``, the absolute position is
    ``p = offset + t``. The anchor ``a = p + structure_first[:, t]`` is
    clamped to ``[0, full.size(1) - 1]``. The output is then
    ``max(full[a] - 1, running[a])``, where ``running[a]`` is the running
    maximum (starting from 0) of ``full`` over positions strictly between
    ``a`` and ``p``. "max" is computed as ``relu(x - y).long() + y``, which
    is an exact max for integer tensors.

    Columns where ``structure_first == 0`` are then set to ``nslot - 1``,
    and the result is clamped below at 0.

    Args:
        prev_structure: tensor of shape (bsz, prev_len). It is presumably an
            integer dtype matching ``structure`` — TODO confirm.
        structure: tensor of shape (bsz, L), concatenated after
            ``prev_structure``. Assumed ``L >= structure_first.size(1)``.
        structure_first: 2-D tensor of shape (bsz, length) holding per-step
            offsets to the anchor position.
        nslot: fill value source; masked columns become ``nslot - 1``.

    Returns:
        A detached tensor shaped and typed like ``structure_first``.
    """
    batch_size, seq_len = structure_first.size()
    result = torch.zeros_like(structure_first)
    offset = prev_structure.size(1)
    full = torch.cat([prev_structure, structure], dim=-1)
    running = torch.zeros_like(full)
    rows = torch.arange(batch_size)

    def absorb(pos):
        # running[:, :pos] <- elementwise max(running[:, :pos], full[:, pos]).
        column = full[:, pos, None]
        running[:, :pos] = torch.relu(running[:, :pos] - column).long() \
            + column

    # Prime the running maxima with the previous-structure positions.
    for pos in range(offset):
        absorb(pos)

    for step in range(seq_len):
        pos = offset + step
        anchor = pos + structure_first[:, step]
        anchor.clamp_(min=0, max=full.size(1) - 1)
        target = full[rows, anchor] - 1
        threshold = running[rows, anchor]
        result[:, step] = torch.relu(target - threshold).long() + threshold

        # NOTE(review): true whenever structure is at least as long as
        # structure_first.
        if pos < full.size(1):
            absorb(pos)

    result.masked_fill_(structure_first == 0, nslot - 1)
    result.clamp_(min=0)
    return result.detach()


def generate_ground_truth(prev_structure, structure_first, nslot):
    """Fill in one value per new position, feeding each result forward.

    A seed tensor shaped like ``structure_first`` is zero everywhere except
    where ``structure_first == 0``, where it holds ``nslot``. The seed is
    appended to ``prev_structure`` to form ``full``. Let
    ``offset = prev_structure.size(1)``. For every column ``t``, the
    absolute position is ``p = offset + t``. The anchor
    ``a = p + structure_first[:, t]`` is clamped into range. ``full[:, p]``
    is then overwritten with ``max(full[a] - 1, running[a])``, where
    ``running[a]`` is the running maximum (from 0) of ``full`` over
    positions strictly between ``a`` and ``p``. Because ``full`` is
    overwritten in place, later positions see earlier results. "max" is
    ``relu(x - y).long() + y``, exact for integer tensors.

    Args:
        prev_structure: tensor of shape (bsz, prev_len). It is presumably
            the same integer dtype as ``structure_first`` — TODO confirm.
        structure_first: 2-D tensor of shape (bsz, length) with per-step
            anchor offsets.
        nslot: seed value for positions where ``structure_first == 0``.

    Returns:
        A detached (bsz, length) tensor: the filled-in tail of ``full``,
        clamped below at 0.
    """
    batch_size, seq_len = structure_first.size()
    offset = prev_structure.size(1)

    seed = torch.zeros_like(structure_first)
    seed.masked_fill_(structure_first == 0, nslot)
    full = torch.cat([prev_structure, seed], dim=-1)
    running = torch.zeros_like(full)
    rows = torch.arange(batch_size)

    def absorb(pos):
        # running[:, :pos] <- elementwise max(running[:, :pos], full[:, pos]).
        column = full[:, pos, None]
        running[:, :pos] = torch.relu(running[:, :pos] - column).long() \
            + column

    # Prime the running maxima with the previous-structure positions.
    for pos in range(offset):
        absorb(pos)

    for step in range(seq_len):
        pos = offset + step
        anchor = pos + structure_first[:, step]
        anchor.clamp_(min=0, max=full.size(1) - 1)

        target = full[rows, anchor] - 1
        threshold = running[rows, anchor]
        # Write the result back so it participates in later running maxima.
        full[:, pos] = torch.relu(target - threshold).long() + threshold

        # Always true here: full.size(1) == offset + seq_len.
        if pos < full.size(1):
            absorb(pos)

    result = full[:, offset:]
    result.clamp_(min=0)
    return result.detach()
