from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from tqdm import tqdm
import pickle
import torch
import numpy as np
import os
from config import basic_opt as opt
from utils.process_data import *
import models
from utils.utils import *
import fire
import multiprocessing
import logging as log
import time

class vp_set(Dataset):
    """Map-style dataset over pre-processed example tuples.

    Each element of ``pairs`` is a tuple laid out as either
    ``(input_seq, id, target)`` or ``(input_seq, pos_seq, ..., id, target)``:
    the input sequence is always field 0, the target is always the last field
    and the id is the second-to-last field. When the first tuple has more than
    three fields, field 1 is treated as a POS sequence and ``__getitem__``
    returns it as well.
    """

    def __init__(self, pairs):
        self.pairs = pairs
        self.input_seqs = [pair[0] for pair in pairs]
        self.target_seqs = [pair[-1] for pair in pairs]
        self.ids = [pair[-2] for pair in pairs]
        # The tuple layout is inferred from the first example; an empty
        # dataset simply has no POS field instead of raising IndexError.
        if len(pairs) > 0 and len(pairs[0]) > 3:
            self.pos_seqs = [pair[1] for pair in pairs]
        else:
            self.pos_seqs = None

    def __getitem__(self, index):
        """Return ``(input, pos, id, target)`` or ``(input, id, target)``."""
        input_seq = self.input_seqs[index]
        target_seq = self.target_seqs[index]
        item_id = self.ids[index]
        if self.pos_seqs is not None:
            pos_seq = self.pos_seqs[index]
            return input_seq, pos_seq, item_id, target_seq
        return input_seq, item_id, target_seq

    def __len__(self):
        return len(self.target_seqs)

def my_collate_fn(batch):
    """Collate ``vp_set`` items into padded numpy arrays for one batch.

    ``batch`` is sorted in place by decreasing input length. Every input (and
    POS, when items carry four fields) sequence is padded with ``pad_seq`` to
    the longest input in the batch. Targets are wrapped as a ``(1, B)`` array.

    Returns:
        ``(batched_data, lengths)`` where ``batched_data`` maps ``'input'``,
        ``'target'``, ``'id'`` and optionally ``'pos'``, and ``lengths`` lists
        the unpadded input lengths in sorted order.
    """
    batch.sort(key=lambda item: len(item[0]), reverse=True)  # longest sentence first
    has_pos = len(batch[0]) == 4
    if has_pos:
        input_seqs, pos_seqs, ids, target_verb = zip(*batch)
    else:
        input_seqs, ids, target_verb = zip(*batch)

    lengths = [len(seq) for seq in input_seqs]
    longest = max(lengths)
    input_padded = np.asarray([pad_seq(seq, longest) for seq in input_seqs])

    if not has_pos:
        batched_data = {'input': input_padded,
                        'id': ids,
                        'target': np.asarray([target_verb])}
        return batched_data, lengths

    pos_padded = np.asarray([pad_seq(seq, longest) for seq in pos_seqs])
    batched_data = {'input': input_padded,
                    'target': np.asarray([target_verb]),
                    'pos': pos_padded,
                    'id': ids}
    return batched_data, lengths

def make_weights_for_balanced_classes(pairs, nclasses):
    """Return one inverse-frequency sampling weight per example.

    Each example's label is ``pair[-1]`` (an int in ``range(nclasses)``). A
    class that occurs ``c`` times out of ``N`` examples gets weight ``N / c``,
    so under weighted sampling every present class is drawn equally often.

    Args:
        pairs: sequence of example tuples whose last field is the class label.
        nclasses: number of classes; labels must be below this value.

    Returns:
        A list of float weights aligned with ``pairs`` (empty for empty input).
    """
    count = [0] * nclasses
    for pair in pairs:
        count[pair[-1]] += 1
    total = float(sum(count))
    # Classes absent from ``pairs`` get 0.0 instead of raising ZeroDivisionError;
    # no example carries such a label, so that value is never looked up.
    weight_per_class = [total / c if c else 0.0 for c in count]
    return [weight_per_class[pair[-1]] for pair in pairs]

def calc_loss(loss_func, prediction, target_batches):
    """Return the batch loss for the configured objective.

    Args:
        loss_func: ``'cross_entropy'`` (``nn.CrossEntropyLoss``) or
            ``'max_margin'`` (``nn.MultiMarginLoss``).
        prediction: unnormalised class scores, one row per example.
        target_batches: class indices; flattened with ``view(-1)`` so the
            ``(1, B)`` target built by ``my_collate_fn`` is accepted.

    Raises:
        ValueError: if ``loss_func`` is not one of the supported names.
    """
    criteria = {
        'cross_entropy': nn.CrossEntropyLoss,
        'max_margin': nn.MultiMarginLoss,
    }
    if loss_func not in criteria:
        raise ValueError("Unsupported loss_func %r; expected one of %s" % (loss_func, sorted(criteria)))
    criterion = criteria[loss_func]()
    return criterion(prediction, target_batches.view(-1))


def iter(batched_variables, lengths, predictor, hidden, train_flag, opt, optimizer=None):
    """Run the model on one batch; when training, also back-propagate and step.

    Args:
        batched_variables: dict of tensors built from ``my_collate_fn`` output
            (reads ``'input'`` and ``'target'``; passed whole to ``predictor``).
        lengths: per-example input lengths, forwarded to ``predictor``.
        predictor: the model; must provide ``init_hidden(batch_size)``.
        hidden: recurrent state from the caller. It is reused (through
            ``repackage_hidden(..., reset=True)``) only for full-sized batches;
            a smaller batch gets a fresh ``init_hidden`` state.
        train_flag: ``'TRAIN'`` for a training step; any other value evaluates.
        opt: config object (reads ``batch_size``, ``apply_attn``, ``loss_func``).
        optimizer: required when ``train_flag == 'TRAIN'``.

    Returns:
        ``(loss, output_dict)`` where ``loss`` is a Python float and
        ``output_dict`` holds flattened ``'pred'`` (top-1 class index),
        ``'conf'`` (its softmax probability) and ``'target'`` tensors, plus
        ``'attn'`` (the model's attention scores) when ``opt.apply_attn``.
    """
    training = train_flag == 'TRAIN'
    if training:
        predictor.train()  # init train mode
        predictor.zero_grad()  # clear the gradients before loss and backward
    else:
        predictor.eval()

    batch_size = batched_variables['input'].size(0)
    if batch_size != opt.batch_size:
        hidden = predictor.init_hidden(batch_size)
    else:
        hidden = repackage_hidden(hidden, reset=True)

    # Only build the autograd graph when we will back-propagate; evaluation
    # batches would otherwise allocate and keep it for nothing.
    with torch.set_grad_enabled(training):
        if opt.apply_attn:
            prediction, scores, hidden = predictor(batched_variables, lengths, hidden)
        else:
            prediction, hidden = predictor(batched_variables, lengths, hidden)
        loss = calc_loss(opt.loss_func, prediction, batched_variables['target'])

    if training:
        loss.backward()
        optimizer.step()

    conf, out = F.softmax(prediction, dim=1).topk(1)
    target = batched_variables['target'].detach().view(-1)

    output_dict = {'pred': out.detach().view(-1),
                   'conf': conf.detach().view(-1),
                   'target': target,
                   }

    if opt.apply_attn:
        output_dict['attn'] = scores

    return loss.item(), output_dict

def train_epoch(train_data, eval_data, predictor, optimizer, epoch, val_history_acc, val_history_loss, scheduler, opt):
    """Train ``predictor`` for one epoch, evaluating and checkpointing periodically.

    Every ``opt.log_interval`` batches the average training loss, running
    accuracy and average ms/batch over that window are logged. Every
    ``opt.eval_interval`` batches the model is evaluated on ``eval_data``. If
    the validation accuracy beats every value in ``val_history_acc``, the model
    is saved to ``opt.ckp_dir``. ``scheduler.step`` is then called with the
    validation loss.

    Args:
        train_data, eval_data: example tuples for ``vp_set``, or containers
            indexed by ``opt.fix_portion`` when that is positive.
        val_history_acc, val_history_loss: validation histories, appended to
            in place. ``val_history_acc`` must be non-empty (``max`` is taken).

    Returns:
        ``(history_acc, val_history_acc, val_history_loss)`` where
        ``history_acc`` is the running training accuracy after each batch.
    """
    start_time = time.time()  # start of the current logging window
    total_loss = 0  # loss summed over the current logging window
    batch_count = 0  # track batch
    correct = 0  # top-1 hits so far this epoch
    total = 0
    history_acc = []

    # get dataset
    pairs = train_data[opt.fix_portion] if opt.fix_portion > 0 else train_data
    dset = vp_set(pairs)
    total_count = len(pairs)

    # set up sampler
    if opt.sampler == 'weighted':
        # NOTE(review): assumes 50 target classes; confirm against the target vocabulary size.
        weights = torch.DoubleTensor(make_weights_for_balanced_classes(dset.pairs, 50))
        data_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
        shuffling = False
    elif opt.sampler == 'random':
        data_sampler = torch.utils.data.sampler.RandomSampler(dset)
        shuffling = False
    else:
        data_sampler = None
        shuffling = opt.shuffling

    # set up data loader
    data_loader = DataLoader(dset, opt.batch_size, shuffle=shuffling, collate_fn=my_collate_fn,
                             num_workers=4, pin_memory=True, sampler=data_sampler)

    hidden = predictor.init_hidden(opt.batch_size)  # init hidden

    for batched_data, lengths in data_loader:
        batched_variables = {k: variableFromSentence(v) for k, v in batched_data.items() if k != 'id'}
        # input variable( B x L), pos_variable ( B X L ), target_variable(1 x B)

        loss, output = iter(batched_variables, lengths, predictor, hidden, 'TRAIN', opt, optimizer=optimizer)

        total_loss += loss
        batch_count += 1  # always >= 1 below, so no extra "> 0" guard is needed

        correct += torch.sum(torch.eq(output['pred'], output['target'])).item()
        total += output['target'].size(0)
        acc = correct / total
        history_acc.append(acc)

        # log
        if batch_count % opt.log_interval == 0:
            cur_loss = total_loss / opt.log_interval
            elapsed = time.time() - start_time
            log.info('{:5d}/{:5d} batches |ms/batch {:5.2f} | loss {:5.4f} | acc {:5.4f}'
                     .format(batch_count, total_count // opt.batch_size, elapsed * 1000 / opt.log_interval, cur_loss, acc))
            total_loss = 0
            # Restart the timer only when a window closes, so ``elapsed`` spans
            # all ``log_interval`` batches it is averaged over.
            start_time = time.time()

        # eval
        if batch_count % opt.eval_interval == 0:
            eval_start = time.time()
            val_loss, val_acc = eval(eval_data, predictor, opt)  # eval

            # Save checkpoint if is a new best
            is_best = bool(val_acc > max(val_history_acc))
            if is_best:
                log.info("Getting a new best...")
                predictor.save(opt.ckp_dir)

            print_progress(epoch, batch_count, val_loss)

            val_history_acc.append(val_acc)
            val_history_loss.append(val_loss)
            scheduler.step(val_loss)  # reduce learning rate if val_loss stops decreasing
            # Keep evaluation/checkpoint time out of the training ms/batch figure.
            start_time += time.time() - eval_start

    return history_acc, val_history_acc, val_history_loss


def eval(data, predictor, opt):
    """Evaluate ``predictor`` on ``data`` and log the result.

    Args:
        data: example tuples for ``vp_set``, or a container indexed by
            ``opt.fix_portion`` when that is positive.
        predictor: the model (switched to eval mode).
        opt: config object (reads ``fix_portion`` and ``batch_size``; also
            passed to ``iter``).

    Returns:
        ``(avg_loss, avg_acc)``: mean per-batch loss and top-1 accuracy.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    predictor.eval()
    total_loss = 0
    batch_count = 0
    correct = 0
    total = 0

    # get dataset
    pairs = data[opt.fix_portion] if opt.fix_portion > 0 else data
    dset = vp_set(pairs)

    batch_size = opt.batch_size
    data_loader = DataLoader(dset, batch_size, shuffle=False, collate_fn=my_collate_fn, num_workers=4,
                             pin_memory=True)

    hidden = predictor.init_hidden(batch_size)
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for batched_data, lengths in data_loader:
            batched_variables = {k: variableFromSentence(v) for k, v in batched_data.items() if k != 'id'}

            loss, output = iter(batched_variables, lengths, predictor, hidden, 'EVAL', opt)

            total_loss += loss
            batch_count += 1

            total += output['target'].size(0)
            correct += torch.sum(torch.eq(output['pred'], output['target'])).item()

    if batch_count == 0:
        raise ValueError("eval() received an empty dataset")

    avg_loss = total_loss / batch_count  # average loss
    avg_acc = float(correct / total)
    log.info("Test average loss:%.4f, overall accuracy %.4f" % (avg_loss, avg_acc))

    return avg_loss, avg_acc


def train(**kwargs):
    """Train ``opt.model`` for ``opt.num_epochs`` epochs and save accuracy histories.

    ``kwargs`` are handed to ``opt.parse`` to override the global config. The
    test split and dictionary are passed to ``save_output`` under
    ``opt.data_dir`` (presumably the ``test.pickle`` / ``dict.pickle`` files
    that ``predict`` reads back; confirm). After the last epoch, the per-batch
    training accuracy and the validation-accuracy history are saved under
    ``opt.out_dir``.
    """
    opt.parse(kwargs)

    train_data, eval_data, test_data, dictionary = load_data(opt)
    for name, payload in (('test', test_data), ('dict', dictionary)):
        save_output(opt.data_dir, name, payload)

    model_cls = getattr(models, opt.model)
    predictor = model_cls(dictionary, opt)
    if use_cuda:
        predictor = predictor.cuda()

    optimizer = optim.Adam(predictor.parameters(), lr=opt.lr, weight_decay=1e-5)
    # Shrink the learning rate when the validation loss plateaus.
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, verbose=True)

    # val_history_acc is seeded with 0.0 so train_epoch's max() has a baseline;
    # val_history_loss starts from the placeholder 1000.
    val_history_acc, val_history_loss = [0.0], [1000]
    train_history_acc = []

    if opt.load_ckp:
        predictor.load(opt.ckp_dir)

    for epoch in range(1, opt.num_epochs + 1):
        epoch_acc, val_history_acc, val_history_loss = train_epoch(
            train_data, eval_data, predictor, optimizer, epoch,
            val_history_acc, val_history_loss, scheduler, opt)
        train_history_acc += epoch_acc

    for name, history in (('train_acc', train_history_acc), ('val_acc', val_history_acc)):
        save_output(opt.out_dir, name, history)


def predict(**kwargs):
    """Run the saved checkpoint over the test split and emit per-portion reports.

    ``kwargs`` are handed to ``opt.parse`` to override the global config. The
    dictionary is read from ``opt.data_dir/dict.pickle``. The test split is read
    from ``opt.data_dir/test.pickle`` if that succeeds; otherwise it is rebuilt
    with ``load_test_data``. For each reveal portion (``opt.fix_portion`` if
    positive, else every entry of ``opt.portion``), predictions are collected.
    The ``<p>_predicted.csv`` / ``<p>_pr_report.csv`` names are passed to
    ``get_predicted_all`` / ``get_pr_info``. Finally, the accuracy per portion
    is passed to ``save_acc_df``.

    Note: ``pickle.load`` can execute arbitrary code, so ``opt.data_dir`` must
    only contain trusted files.
    """
    opt.parse(kwargs)

    # load data
    with open(os.path.join(opt.data_dir, 'dict.pickle'), 'rb') as f:
        dictionary = pickle.load(f)

    try:
        with open(os.path.join(opt.data_dir, 'test.pickle'), 'rb') as f:
            test_data = pickle.load(f)
    except Exception as err:
        # Best effort: any failure to read the cached split falls back to
        # rebuilding it, without swallowing KeyboardInterrupt/SystemExit.
        log.warning("Could not load cached test data (%s); rebuilding it", err)
        test_data = load_test_data(opt, dictionary)

    # initialize model
    predictor = getattr(models, opt.model)(dictionary, opt)
    if use_cuda:
        predictor = predictor.cuda()

    predictor.load(opt.ckp_dir)
    acc_over_time = {}

    # predict at different reveal scale
    if opt.fix_portion > 0:
        dsets = [(opt.fix_portion, vp_set(test_data[opt.fix_portion]))]
    else:
        dsets = [(p, vp_set(test_data[p])) for p in opt.portion]

    os.makedirs(opt.out_dir, exist_ok=True)

    for p, dset in dsets:
        data_loader = DataLoader(dset, opt.batch_size, shuffle=False, collate_fn=my_collate_fn, num_workers=4,
                                 pin_memory=True)
        input_ids = []
        input_seqs = []
        predictions = []
        confidences = []
        target_seqs = []
        attn_seqs = []
        total_loss = 0
        batch_count = 0

        hidden = predictor.init_hidden(opt.batch_size)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batched_data, lengths in data_loader:
                input_ids.extend(batched_data['id'])  # keep track of the preverb id
                batched_variables = {k: variableFromSentence(v) for k, v in batched_data.items() if k != 'id'}

                loss, output = iter(batched_variables, lengths, predictor, hidden, 'EVAL', opt)
                total_loss += loss
                batch_count += 1

                # ``.cpu()`` returns the tensor itself when it is already on the
                # CPU, so one path serves both devices. Zero tokens (presumably
                # padding from ``pad_seq``) are dropped from the inputs.
                input_seqs.extend([w for w in s if w != 0] for s in batched_variables['input'].cpu().tolist())
                predictions.extend(output['pred'].cpu().tolist())
                confidences.extend(output['conf'].cpu().tolist())
                target_seqs.extend(output['target'].cpu().tolist())
                if opt.apply_attn:
                    attn_seqs.extend(output['attn'].detach().cpu().tolist())

        if batch_count == 0:
            log.warning("No test examples for portion %s; skipping", p)
            continue

        avg_loss = total_loss / batch_count  # average loss
        overall_acc = np.sum(np.asarray(predictions) == np.asarray(target_seqs)) / len(predictions)
        acc_over_time[p] = round(float(overall_acc), 4)
        log.info("Test average loss (%.2f) :%.4f, overall accuracy %.4f" %
                 (p, avg_loss, float(overall_acc)))

        get_predicted_all(input_ids, input_seqs, predictions, confidences, target_seqs, dictionary['source'],
                          dictionary['target'], opt.out_dir, '{}_predicted.csv'.format(str(p)), attn_seqs, p)
        get_pr_info(target_seqs, predictions, dictionary['target'], opt.out_dir,
                    '{}_pr_report.csv'.format(str(p)))
    save_acc_df(acc_over_time, opt)


if __name__ == "__main__":
    # ``multiprocessing.get_context('spawn')`` only returns a context object and
    # does not change the global start method, so it is not called here.
    # DataLoader workers use the platform default start method; call
    # ``multiprocessing.set_start_method('spawn')`` if spawn is really required.

    # use gpu (read as a module-level global by train/predict)
    use_cuda = torch.cuda.is_available()

    # setting up log: INFO and above go to a dated file and to the console
    logdatetime = time.strftime("%m%d")
    log_format = '  %(message)s'  # named to avoid shadowing the built-in ``format``
    handlers = [log.FileHandler('train_vp' + logdatetime + '.log'), log.StreamHandler()]
    log.basicConfig(level=log.INFO, format=log_format, handlers=handlers)
    # expose this module's functions (e.g. ``train``, ``predict``) as CLI commands
    fire.Fire()
