# coding: utf-8

import torch
from misc import resample, resample_fromall, resample_sents_fromall

# Every tensor created or moved in this module goes to this device:
# the GPU when CUDA is available, the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



def train_step(batch_data, encoder, model, classifier, criterion, optimizer, n_lbls, args, negsampling=False):
    ''' Run one optimisation step on a batch and return the loss as a Python float.

    Puts encoder, model and classifier in train mode, zeroes the optimizer's
    gradients, computes the loss, back-propagates and steps the optimizer.

    Input:
        batch_data: batch object exposing embeds_index, edge_index, from_mask and
            to_mask (tensors, moved to the device here) plus target_idx and target_lbls.
        encoder: called on embeds_index to produce the initial embeddings; its
            label_embeddings table supplies the label embeddings.
        model: called as model(embeds, edge_index, from_mask, to_mask).
        classifier: called as classifier(targets, lbls) to score targets against labels.
        criterion: called as criterion(digits, ans_masks).
        optimizer: optimizer holding the parameters to update.
        n_lbls: total number of classes.
        args: only args.pos_num and args.neg_num are read, and only when negsampling is True.
        negsampling: when True, score each target against the labels returned by
            resample_fromall instead of against the whole label table.
    '''
    encoder.train()
    model.train()
    classifier.train()

    optimizer.zero_grad()

    # make embeddings x from corpus, using sentence encoder and label embeddings.
    embeds_index = batch_data.embeds_index.to(device)
    embeds = encoder(embeds_index)

    # pass input data through the network
    embeds = model(embeds, batch_data.edge_index.to(device),
                   batch_data.from_mask.to(device), batch_data.to_mask.to(device))

    target_idx = batch_data.target_idx
    target_lbls = batch_data.target_lbls

    # map every index in embeds_index to its row inside `embeds`
    idx_map = {idx: row for row, idx in enumerate(embeds_index.tolist())}
    target_inds = [idx_map[idx] for idx in target_idx]
    targets = embeds[target_inds]

    if negsampling:
        # resample positives, and sample negatives within the batch
        target_lbls, ans_masks = resample_fromall(target_lbls, n_lbls, pos_num=args.pos_num, neg_num=args.neg_num)
        # label embeddings are looked up directly in the encoder's label table,
        # not taken from the network output `embeds`
        lbls = encoder.label_embeddings(torch.tensor(target_lbls, dtype=torch.long, device=device))
        ans_masks = torch.tensor(ans_masks, dtype=torch.float, device=device)
    else:
        lbls = encoder.label_embeddings.weight
        ans_masks = lbl2tensor(target_lbls, n_lbls).float().to(device)

    digits = classifier(targets, lbls).squeeze()
    # squeeze() removes *every* size-1 axis, so a batch holding a single target
    # loses its batch dimension. Restore the mask's shape when only the layout
    # (not the number of elements) differs, so criterion sees matching shapes.
    if digits.shape != ans_masks.shape and digits.numel() == ans_masks.numel():
        digits = digits.reshape(ans_masks.shape)

    loss = criterion(digits, ans_masks)
    loss.backward()
    optimizer.step()

    return loss.item()

def train_step_lbl(batch_data, encoder, model, classifier, criterion, optimizer, n_lbls, n_sents, args):
    ''' Run one optimisation step where the batch targets are labels, and return the loss as a Python float.

    The targets' embeddings are taken from the network output, while the
    sentences they are scored against are re-encoded directly with the encoder.

    Input:
        batch_data: batch object exposing embeds_index, edge_index, from_mask and
            to_mask (tensors, moved to the device here) plus target_idx and
            target_lbls. Here target_lbls temporarily carries the sentences
            related to each target label.
        encoder: called on embeds_index and on the flattened sampled sentence indices.
        model: called as model(embeds, edge_index, from_mask, to_mask).
        classifier: called as classifier(sents, lbls, moredim=True).
        criterion: called as criterion(digits, ans_masks).
        optimizer: optimizer holding the parameters to update.
        n_lbls, n_sents: forwarded to resample_sents_fromall.
        args: only args.pos_num and args.neg_num are read.
    '''
    encoder.train()
    model.train()
    classifier.train()

    optimizer.zero_grad()

    # make embeddings x from corpus, using sentence encoder and label embeddings.
    embeds_index = batch_data.embeds_index.to(device)
    embeds = encoder(embeds_index)

    # pass input data through the network
    embeds = model(embeds, batch_data.edge_index.to(device),
                   batch_data.from_mask.to(device), batch_data.to_mask.to(device))

    target_idx = batch_data.target_idx
    target_sents = batch_data.target_lbls   # temporarily use this to get related sentences of the target label

    # resample positives, and sample negatives within the batch
    target_sents, ans_masks = resample_sents_fromall(target_sents, n_lbls, n_sents, pos_num=args.pos_num, neg_num=args.neg_num)

    # map every index in embeds_index to its row inside `embeds`
    idx_map = {idx: row for row, idx in enumerate(embeds_index.tolist())}

    # retrieve label embeddings from the network output
    target_inds = [idx_map[idx] for idx in target_idx]
    lbls = embeds[target_inds]

    # sentence embeddings come straight from the encoder (not from `embeds`):
    # flatten the 2-D index matrix, encode, then restore the first two axes
    target_sents = torch.tensor(target_sents, dtype=torch.long, device=device)
    sents = encoder(target_sents.view(-1)).view(target_sents.size(0), target_sents.size(1), -1)

    ans_masks = torch.tensor(ans_masks, dtype=torch.float, device=device)

    digits = classifier(sents, lbls, moredim=True).squeeze()
    # squeeze() removes *every* size-1 axis, so a batch holding a single target
    # loses its batch dimension. Restore the mask's shape when only the layout
    # (not the number of elements) differs, so criterion sees matching shapes.
    if digits.shape != ans_masks.shape and digits.numel() == ans_masks.numel():
        digits = digits.reshape(ans_masks.shape)

    loss = criterion(digits, ans_masks)
    loss.backward()
    optimizer.step()

    return loss.item()

def test_step(batch_data, encoder, model, classifier, n_lbls, threshold=0.5, hfet=True):
    ''' Evaluate one batch and return the statistics produced by get_eval_info.

    Puts encoder, model and classifier in eval mode and runs without gradients.

    Input:
        batch_data: batch object exposing embeds_index, edge_index, from_mask and
            to_mask (tensors, moved to the device here) plus target_idx and
            target_lbls (the gold labels of each target).
        encoder: called on embeds_index; encoder.label_embeddings.weight is used
            as the set of candidate label embeddings.
        model: called as model(embeds, edge_index, from_mask, to_mask).
        classifier: called as classifier(targets, labels).
        n_lbls: total number of classes (width of the gold multi-hot matrix).
        threshold: cut-off applied to the classifier scores as returned by the
            classifier (no sigmoid is applied in this function).
        hfet: when False, the label embeddings are tiled along a new leading
            axis, once per target, before being handed to the classifier.
    Returns:
        The tuple returned by get_eval_info for the binarised predictions.
    '''
    encoder.eval()
    model.eval()
    classifier.eval()

    def _prediction(outputs, threshold=0.5, predict_top=True):
        # binarise the scores; with predict_top the best-scoring label of each
        # row is always switched on, so no row ends up without a prediction
        preds = (outputs > threshold).int()
        if predict_top:
            _, highest = outputs.max(dim=1)
            rows = torch.arange(preds.size(0), device=preds.device)
            preds[rows, highest] = 1
        return preds

    with torch.no_grad():
        # make embeddings x from corpus, using sentence encoder and label embeddings.
        embeds_index = batch_data.embeds_index.to(device)
        embeds = encoder(embeds_index)

        # pass input data through the network
        embeds = model(embeds, batch_data.edge_index.to(device),
                       batch_data.from_mask.to(device), batch_data.to_mask.to(device))

        target_idx = batch_data.target_idx
        target_lbls = batch_data.target_lbls    # when testing, target labels are not in batch index

        # map every index in embeds_index to its row inside `embeds`
        idx_map = {idx: row for row, idx in enumerate(embeds_index.tolist())}
        target_inds = [idx_map[idx] for idx in target_idx]
        targets = embeds[target_inds]

        labels = encoder.label_embeddings.weight
        if not hfet:
            labels = labels.unsqueeze(0).repeat([targets.size(0), 1, 1])

        y_true = lbl2tensor(target_lbls, n_lbls)

        digits = classifier(targets, labels).squeeze()
        # squeeze() removes *every* size-1 axis, so a batch holding a single
        # target comes back 1-D and max(dim=1) would fail. Restore the gold
        # matrix's shape when only the layout (not the element count) differs.
        if digits.shape != y_true.shape and digits.numel() == y_true.numel():
            digits = digits.reshape(y_true.shape)

        y_pred = _prediction(digits, threshold, True)

        # y_pred is already a 0/1 matrix. get_eval_info thresholds its input
        # again, so give it a fixed 0.5 rather than the score threshold:
        # reusing `threshold` would wipe out (>= 1) or saturate (< 0) the
        # predictions whenever the score cut-off lies outside [0, 1).
        retr = get_eval_info(y_pred, y_true, 0.5)
    return retr



def lbl2tensor(lbls_list, n_lbls):
    ''' Build a multi-hot label matrix from per-sample label collections.
    lbls_list holds one collection of label indices per sample; n_lbls is the
    total number of classes. Returns a long tensor of shape
    [len(lbls_list), n_lbls], allocated on the module-level device, with a 1
    in every (sample, label) position that was listed and 0 elsewhere.
    '''
    multi_hot = torch.zeros(len(lbls_list), n_lbls, dtype=torch.long, device=device)
    for sample_idx, sample_lbls in enumerate(lbls_list):
        # switch on all labels of this sample in one indexed assignment
        multi_hot[sample_idx, sample_lbls] = 1
    return multi_hot



def get_eval_info(y_pred, y_true, threshold=0.5):
    ''' Collect the counts and sums needed for P, R and F1.
    Input y_pred and y_true as [B, n_class] tensors; y_pred is binarised with
    `y_pred > threshold` first.
    Returns, in order: per-class hit / predicted-positive / gold-positive
    counts (sums over the batch axis), the sum over samples of the per-sample
    precision, the same for recall, an F1 value computed from those two sums,
    and the batch size B.
    '''
    eps = 1e-10     # lower bound for denominators, avoids division by zero
    predicted = y_pred > threshold
    overlap = predicted * y_true

    # per-class statistics (reduce over the batch axis)
    hit = overlap.sum(dim=0, dtype=torch.float32)
    positive = predicted.sum(dim=0, dtype=torch.float32)
    true = y_true.sum(dim=0, dtype=torch.float32)

    # per-sample statistics (reduce over the class axis)
    sample_hit = overlap.sum(dim=1, dtype=torch.float32)
    sample_positive = predicted.sum(dim=1, dtype=torch.float32)
    sample_true = y_true.sum(dim=1, dtype=torch.float32)

    macro_P = (sample_hit / sample_positive.clamp_min(eps)).sum()
    macro_R = (sample_hit / sample_true.clamp_min(eps)).sum()
    macro_F1 = (2 * macro_P * macro_R / (macro_P + macro_R).clamp_min(eps)).sum()

    return hit, positive, true, macro_P, macro_R, macro_F1, predicted.size(0)

