# Adapted from seqeval.metrics.sequence_labeling for more control
import warnings
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset




def get_entities(seq, suffix=False):
    """Extract entity chunks from a sequence of IOBES-style tags.

    Unlike the seqeval original, every chunk also carries the tag it started
    with. This lets callers later spot spans that open with 'I' or 'E'.

    Args:
        seq: list of tags such as ['B-PER', 'I-PER', 'O', 'B-LOC'], or a list
            of such lists. Nested lists are flattened with an 'O' appended
            after each inner list.
        suffix: if True, tags are written type-first, e.g. 'PER-B'.

    Returns:
        List of (chunk_type, chunk_start, chunk_end, start_tag) tuples with
        inclusive end indices, e.g. [('PER', 0, 1, 'B'), ('LOC', 3, 3, 'B')].
    """
    def _looks_like_tag(label):
        # Bare scheme letters are always accepted.
        if label in ('O', 'B', 'I', 'E', 'S'):
            return True
        if suffix:
            return label.endswith(('-B', '-I', '-E', '-S'))
        return label.startswith(('B-', 'I-', 'E-', 'S-'))

    # Flatten nested input, separating the inner sequences with an 'O'.
    if any(isinstance(s, list) for s in seq):
        seq = [tok for sentence in seq for tok in sentence + ['O']]

    found = []
    prev_tag, prev_type = 'O', ''
    start, start_tag = 0, 'O'
    # A trailing 'O' sentinel flushes any chunk still open at the end.
    for pos, label in enumerate(seq + ['O']):
        if not _looks_like_tag(label):
            warnings.warn('{} seems not to be NE tag.'.format(label))

        if suffix:
            tag = label[-1]
            type_ = label[:-1].rsplit('-', maxsplit=1)[0] or '_'
        else:
            tag = label[0]
            type_ = label[1:].split('-', maxsplit=1)[-1] or '_'

        if end_of_chunk(prev_tag, tag, prev_type, type_):
            found.append((prev_type, start, pos - 1, start_tag))
        if start_of_chunk(prev_tag, tag, prev_type, type_):
            start, start_tag = pos, tag
        prev_tag, prev_type = tag, type_

    return found


def end_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk ended between the previous word and the current one.

    Args:
        prev_tag: scheme letter of the previous word ('B', 'I', 'E', 'S', 'O', ...).
        tag: scheme letter of the current word.
        prev_type: entity type of the previous word.
        type_: entity type of the current word.

    Returns:
        bool
    """
    # 'E' and 'S' always close the chunk they belong to.
    if prev_tag in ('E', 'S'):
        return True
    # An open B/I chunk is closed by a new start ('B'/'S') or by outside ('O').
    if prev_tag in ('B', 'I') and tag in ('B', 'S', 'O'):
        return True
    # A change of entity type closes any open chunk.
    return prev_tag not in ('O', '.') and prev_type != type_


def start_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk started between the previous word and the current one.

    Args:
        prev_tag: scheme letter of the previous word ('B', 'I', 'E', 'S', 'O', ...).
        tag: scheme letter of the current word.
        prev_type: entity type of the previous word.
        type_: entity type of the current word.

    Returns:
        bool
    """
    # 'B' and 'S' always open a new chunk.
    if tag in ('B', 'S'):
        return True
    # An 'I'/'E' that follows a closed chunk or outside opens an (ill-formed) chunk.
    if prev_tag in ('E', 'S', 'O') and tag in ('E', 'I'):
        return True
    # A change of entity type on a tagged word starts a new chunk.
    return tag not in ('O', '.') and prev_type != type_


def lws_loss(outputs, partialY, confidence, index, lw_weight, lw_weight0):
    """Compute the confidence-weighted LWS loss for token-level partial labels.

    Args:
        outputs: logits of rank 3 whose last axis is the label axis.
            Presumably (batch, seq_len, num_labels); TODO confirm with caller.
        partialY: either a rank-3 candidate-label matrix shaped like
            ``outputs``, or rank-2 integer class ids. In the rank-2 case,
            negative ids (e.g. -100) mark positions to ignore.
        confidence: per-example weights; ``confidence[index, :, :]`` must
            broadcast against ``outputs``.
        index: indices selecting this batch's rows of ``confidence``.
        lw_weight: weight of the loss on non-candidate labels.
        lw_weight0: weight of the loss on candidate labels.

    Returns:
        (total_loss, lw_weight0 * candidate_loss, lw_weight * non_candidate_loss).
        Both partial losses are averaged over tokens that have at least one
        candidate label. They are 0 when no such token exists.
    """
    device = outputs.device
    num_labels = outputs.shape[-1]

    if partialY.dim() != 3:
        # Hard labels: one-hot encode them. Ignored positions (negative ids)
        # get an all-zero row, so the mask below drops them instead of
        # counting them as class 0.
        keep = (partialY >= 0).unsqueeze(-1)
        partialY = F.one_hot(partialY.clamp(min=0), num_labels) * keep
    partialY = partialY.to(device)
    counter_onezero = 1 - partialY

    # Only tokens with at least one candidate label contribute.
    # Works for any label count; no hard-coded 7-class target is needed.
    valid = partialY.ne(0).any(dim=-1)
    # clamp avoids 0/0 -> NaN when a batch has no valid tokens.
    n_valid = valid.sum().clamp(min=1)

    conf = confidence[index, :, :]

    # Minimising sigmoid(-x) raises the logits of candidate labels.
    l1 = conf * partialY * torch.sigmoid(-outputs)
    average_loss1 = l1[valid].sum() / n_valid

    # Minimising sigmoid(x) lowers the logits of non-candidate labels.
    l2 = conf * counter_onezero * torch.sigmoid(outputs)
    average_loss2 = l2[valid].sum() / n_valid

    average_loss = lw_weight0 * average_loss1 + lw_weight * average_loss2
    return average_loss, lw_weight0 * average_loss1, lw_weight * average_loss2


def confidence_update_lw(model, confidence, inputs, index=None):
    """Refresh per-token label confidences from the model's current predictions.

    Runs ``model`` without gradients and softmaxes its first output along
    dim 2. The probabilities are then renormalised separately over candidate
    labels and over non-candidate labels. The sum of the two parts is written
    into ``confidence[index, :, :]``.

    Args:
        model: callable invoked as ``model(**inputs, confidence=..., index=...)``.
            It must return a pair whose second element's first item holds the
            logits.
        confidence: tensor updated in place at ``[index, :, :]``.
        inputs: keyword arguments for ``model``. They must include ``'labels'``:
            either a rank-3 candidate matrix, or rank-2 class ids where
            negative ids (e.g. -100) mark ignored positions.
        index: indices of this batch's rows in ``confidence``.

    Returns:
        The same ``confidence`` tensor, modified in place.
    """
    with torch.no_grad():
        _, total_outputs = model(**inputs, confidence=confidence, index=index)
        batch_outputs = total_outputs[0]
        # NOTE(review): softmax probabilities are used as the confidence source.
        sm_outputs = F.softmax(batch_outputs, dim=2)

        labels = inputs['labels']
        if labels.dim() != 3:
            # Give ignored positions an all-zero row rather than a class-0 one-hot.
            keep = (labels >= 0).unsqueeze(-1)
            labels = F.one_hot(labels.clamp(min=0), batch_outputs.shape[-1]) * keep
        # Align with the logits' device so the products below cannot mismatch.
        labels = labels.to(sm_outputs.device)
        counter_onezero = 1 - labels

        # Each part sums to ~1 per token, or to 0 when the token has no labels
        # in that group. The 1e-8 added per entry prevents division by zero.
        new_weight1 = sm_outputs * labels
        new_weight1 = new_weight1 / (new_weight1 + 1e-8).sum(dim=2, keepdim=True)
        new_weight2 = sm_outputs * counter_onezero
        new_weight2 = new_weight2 / (new_weight2 + 1e-8).sum(dim=2, keepdim=True)

        confidence[index, :, :] = new_weight1 + new_weight2
        return confidence


class gen_index_dataset(Dataset):
    """Wrap a dataset so that each item is returned together with its index.

    The index lets training code look up per-example state for the items in
    a batch, e.g. rows of a confidence tensor.
    """

    def __init__(self, dataset):
        # The wrapped dataset; any object supporting len() and [] indexing.
        self.dataset = dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(record, index)`` for position ``index``."""
        return self.dataset[index], index
