import torch
import random
import numpy as np


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: Any object exposing ``seed`` (the value handed to every
            generator) and ``n_gpu`` (CUDA generators are seeded only when
            this is greater than zero).
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Only touch the CUDA generators when GPUs were requested; the
    # availability check is skipped entirely otherwise.
    gpus_requested = args.n_gpu > 0
    if gpus_requested and torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def collate_fn(batch):
    """Merge a list of per-example feature dicts into one padded batch.

    Every example must provide the keys ``input_ids``, ``sentid_mask``,
    ``labels``, ``entity_pos``, ``mention_pos``, ``hts``, ``mention_hts``,
    ``padded_mention`` and ``padded_mention_mask``.

    Args:
        batch: Non-empty list of feature dicts. ``input_ids`` is a list of
            ints; ``sentid_mask`` is a tensor that can be concatenated with
            a 1-D zero tensor (presumably 1-D itself - confirm with the
            dataset code).

    Returns:
        A 10-tuple ``(input_ids, input_mask, labels, entity_pos, hts,
        mention_pos, mention_hts, padded_mention, padded_mention_mask,
        sentid_mask)``. ``input_ids`` is a long tensor right-padded with 0,
        ``input_mask`` is a float tensor holding 1.0 on real tokens and 0.0
        on padding, ``sentid_mask`` is the stack of the zero-padded
        per-example masks, and every other entry is a plain list with one
        item per example, passed through untouched.
    """
    seq_lens = [len(example["input_ids"]) for example in batch]
    width = max(seq_lens)

    # Right-pad the token ids and build the matching 1.0/0.0 token mask.
    id_rows = []
    mask_rows = []
    for example, n_tokens in zip(batch, seq_lens):
        n_pad = width - n_tokens
        id_rows.append(example["input_ids"] + [0] * n_pad)
        mask_rows.append([1.0] * n_tokens + [0.0] * n_pad)

    # The sentence-id mask is padded relative to its own length, not the
    # length of input_ids.
    sent_rows = []
    for example in batch:
        sent_mask = example["sentid_mask"]
        filler = torch.zeros(width - len(sent_mask))
        sent_rows.append(torch.cat([sent_mask, filler]))

    def gather(key):
        # Ragged per-example fields stay as Python lists.
        return [example[key] for example in batch]

    return (
        torch.tensor(id_rows, dtype=torch.long),
        torch.tensor(mask_rows, dtype=torch.float),
        gather("labels"),
        gather("entity_pos"),
        gather("hts"),
        gather("mention_pos"),
        gather("mention_hts"),
        gather("padded_mention"),
        gather("padded_mention_mask"),
        torch.stack(sent_rows),
    )

def collate_fn_kd(batch):
    """Collate feature dicts into a padded batch with optional KD extras.

    Every example must provide ``input_ids``, ``labels``, ``entity_pos``,
    ``mention_pos``, ``hts``, ``mention_hts``, ``padded_mention`` and
    ``padded_mention_mask``. The keys ``teacher_logits``, ``segment_span``
    and ``entity_types`` are optional; their presence is decided by looking
    at the first example only, so a batch must be homogeneous in them.

    Args:
        batch: Non-empty list of feature dicts. ``input_ids`` is a list of
            ints.

    Returns:
        A 13-tuple ``(input_ids, input_mask, labels, entity_pos, hts,
        mention_pos, mention_hts, padded_mention, padded_mention_mask,
        sentid_mask, teacher_logits, entity_types, segment_spans)``.
        ``input_ids`` is a long tensor right-padded with 0 and
        ``input_mask`` is a float tensor with 1.0 on real tokens and 0.0 on
        padding. ``sentid_mask`` is always ``None``. The three optional
        entries are per-example lists, or ``None`` when the corresponding
        key is absent from the first example. All remaining entries are
        per-example lists passed through unchanged.
    """
    max_len = max(len(f["input_ids"]) for f in batch)

    input_ids = []
    input_mask = []
    for f in batch:
        n_tokens = len(f["input_ids"])
        n_pad = max_len - n_tokens
        input_ids.append(f["input_ids"] + [0] * n_pad)
        input_mask.append([1.0] * n_tokens + [0.0] * n_pad)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    input_mask = torch.tensor(input_mask, dtype=torch.float)

    labels = [f["labels"] for f in batch]
    entity_pos = [f["entity_pos"] for f in batch]
    mention_pos = [f["mention_pos"] for f in batch]
    hts = [f["hts"] for f in batch]
    mention_hts = [f["mention_hts"] for f in batch]
    padded_mention = [f["padded_mention"] for f in batch]
    padded_mention_mask = [f["padded_mention_mask"] for f in batch]

    # This variant does not collate the sentence-id mask; the slot stays in
    # the output tuple so positional unpacking by callers keeps working.
    sentid_mask = None

    def optional(key):
        # Optional fields are detected on the first example only.
        if key in batch[0]:
            return [f[key] for f in batch]
        return None

    teacher_logits = optional("teacher_logits")
    segment_spans = optional("segment_span")
    # Entity types are kept grouped per example (not flattened).
    entity_types = optional("entity_types")

    output = (
        input_ids,
        input_mask,
        labels,
        entity_pos,
        hts,
        mention_pos,
        mention_hts,
        padded_mention,
        padded_mention_mask,
        sentid_mask,
        teacher_logits,
        entity_types,
        segment_spans,
    )
    return output
