import sys
import time

import torch
from torch.nn import Parameter

from few_shot_ner.losses import BaselineNerLoss, ProtoNetNerLoss
from few_shot_ner.optimizer import create_optimizer, create_optimizer_protonet


@torch.no_grad()
def proto_ner_step(meta_dataset, loss_func, optimizer, args, device):
    """Sample one task and take a single optimizer step on its loss.

    Args:
        meta_dataset: Object whose ``sample_task(batch_size, n, k)`` returns a
            task exposing ``batch_of_queries`` and ``get_supports``.
        loss_func: Callable module evaluated on the query batch and the
            supports; it is switched to training mode before use.
        optimizer: Optimizer whose gradients are reset and which is stepped once.
        args: Namespace providing ``batch_size``, ``n`` and ``k``.
        device: Forwarded to the task's batch accessors and to ``loss_func``.

    Returns:
        Tuple ``(loss, num_inst)``: the loss as a Python float and the size of
        the first dimension of the first query tensor (presumably the number
        of query examples in the batch).
    """
    task = meta_dataset.sample_task(args.batch_size, args.n, args.k)
    loss_func.train()
    # Only one batch of queries is drawn per call; the supports are used as-is.
    query_inputs, query_intents, query_slots = task.batch_of_queries(device=device)
    support_inputs, support_intents, support_slots = task.get_supports(device=device)
    loss_inputs = (
        query_inputs, query_intents, query_slots,
        support_inputs, support_intents, support_slots,
    )
    optimizer.zero_grad()
    # The function runs under no_grad, so the graph is built in an explicit
    # enable_grad scope for the forward pass.
    with torch.enable_grad():
        batch_loss = loss_func(*loss_inputs, device)
    batch_loss.backward()
    optimizer.step()
    loss_value = batch_loss.item()
    num_queries = query_inputs[0].size(0)
    return loss_value, num_queries


@torch.no_grad()
def meta_proto_ner_step(meta_dataset, proto_net, inner_proto_net, outer_optimizer, args, device):
    """Run one first-order meta-update of ``proto_net`` on a freshly sampled task.

    ``inner_proto_net`` is overwritten with the weights of ``proto_net`` and
    trained for ``args.num_steps`` steps with a new optimizer obtained from
    ``create_optimizer_protonet``. The parameter difference
    ``proto_net - inner_proto_net`` is then installed as the gradient of
    ``proto_net`` and ``outer_optimizer`` takes one step (a Reptile-like
    first-order approximation).

    Args:
        meta_dataset: Object whose ``sample_task(batch_size, n, k)`` returns a
            task exposing ``batch_of_queries`` and ``get_supports``.
        proto_net: Model updated by ``outer_optimizer``.
        inner_proto_net: Scratch model with the same parameters as
            ``proto_net``; its state is overwritten on every call.
        outer_optimizer: Optimizer stepped once with the parameter difference.
        args: Namespace providing ``batch_size``, ``n``, ``k``, ``num_steps``
            and ``decoder`` (plus whatever ``create_optimizer_protonet`` reads).
        device: Forwarded to the batch accessors and to the loss module.

    Returns:
        Tuple ``(loss, num_inst)`` recorded at the first inner step only: the
        loss as a float and the size of the first dimension of the first query
        tensor. Both are ``None`` when ``args.num_steps`` is 0.
    """
    dataset = meta_dataset.sample_task(args.batch_size, args.n, args.k)
    # Reset gradients
    outer_optimizer.zero_grad()
    # Copy the model
    inner_proto_net.load_state_dict(proto_net.state_dict())
    # Optimizer creation
    inner_optimizer = create_optimizer_protonet(inner_proto_net, args)
    # Stats (only the first inner step is reported)
    loss = None
    num_inst = None
    ner_loss = ProtoNetNerLoss(inner_proto_net, args.decoder == "softmax").to(device)
    # First order approximation
    for _ in range(args.num_steps):
        ner_loss.train()
        queries, queries_intents, queries_slots = dataset.batch_of_queries(device=device)
        supports, supports_intents, supports_slots = dataset.get_supports(device=device)
        # Reset gradients
        inner_optimizer.zero_grad()
        # The function runs under no_grad, so the forward pass needs an
        # explicit enable_grad scope to build the graph.
        with torch.enable_grad():
            step_loss = ner_loss(queries, queries_intents, queries_slots,
                                 supports, supports_intents, supports_slots, device)
        # Compute gradients
        step_loss.backward()
        # Update parameters
        inner_optimizer.step()
        if loss is None:
            loss = step_loss.item()
            num_inst = queries[0].size(0)

    for p, inner_p in zip(proto_net.parameters(), inner_proto_net.parameters()):
        # Outer "gradient" is the parameter difference. It is a plain tensor
        # that inherits dtype and device from the parameter itself, rather
        # than a float32 Parameter placed on ``device``.
        delta = p.detach() - inner_p.detach()
        if p.grad is None:
            p.grad = delta
        else:
            p.grad.copy_(delta)
    # Update parameters
    outer_optimizer.step()
    return loss, num_inst


@torch.no_grad()
def proto_ft_ner_step(dataset, loss_func, optimizer, device):
    """Take one fine-tuning step using a random query/support split of the supports.

    The examples returned by ``dataset.get_supports`` are shuffled; one eighth
    of them (rounded down) become queries and the rest stay supports. When that
    share rounds down to zero but at least two examples exist, a single example
    is held out so the loss is never evaluated on an empty query set.

    Args:
        dataset: Object whose ``get_supports(device=...)`` returns
            ``(examples, intents, slots)``, where ``examples`` is a sequence of
            tensors indexable along their first dimension.
        loss_func: Callable module evaluated on the split; it is switched to
            training mode before use.
        optimizer: Optimizer whose gradients are reset and which is stepped once.
        device: Forwarded to ``dataset.get_supports`` and to ``loss_func``.

    Returns:
        Tuple ``(loss, num_queries)``: the loss as a Python float and the
        number of examples placed in the query split.
    """
    loss_func.train()
    # get supports
    examples, intents, slots = dataset.get_supports(device=device)
    # split them randomly into supports and queries for finetuning
    size = examples[0].size(0)
    n_queries = size // 8
    if n_queries == 0 and size > 1:
        # Fewer than 8 examples: still hold one out as the query.
        n_queries = 1
    idx = torch.randperm(size)

    queries_idx, supports_idx = idx[:n_queries], idx[n_queries:]
    queries = tuple(elem[queries_idx] for elem in examples)
    queries_intents = intents[queries_idx]
    queries_slots = slots[queries_idx]

    supports = tuple(elem[supports_idx] for elem in examples)
    supports_intents = intents[supports_idx]
    supports_slots = slots[supports_idx]

    # Reset gradients
    optimizer.zero_grad()
    # The function runs under no_grad, so the forward pass needs an explicit
    # enable_grad scope to build the graph.
    with torch.enable_grad():
        loss = loss_func(queries, queries_intents, queries_slots, supports, supports_intents, supports_slots, device)
    # Compute gradients
    loss.backward()
    # Update parameters
    optimizer.step()
    return loss.item(), queries[0].size(0)


@torch.no_grad()
def meta_bert_ner_step(meta_dataset, base_net, inner_base_net, outer_optimizer, args, device):
    """Run one first-order meta-update of ``base_net`` on a freshly sampled task.

    ``inner_base_net`` is overwritten with the weights of ``base_net`` and
    trained for ``args.num_steps`` steps with a new optimizer obtained from
    ``create_optimizer``. The parameter difference
    ``base_net - inner_base_net`` is then installed as the gradient of
    ``base_net`` and ``outer_optimizer`` takes one step (a Reptile-like
    first-order approximation).

    Args:
        meta_dataset: Object whose
            ``sample_task(batch_size, n, num_supports=0)`` returns a task
            exposing ``batch_of_queries`` and ``tgt_slots``.
        base_net: Model updated by ``outer_optimizer``.
        inner_base_net: Scratch model with the same parameters as
            ``base_net``; its state is overwritten on every call.
        outer_optimizer: Optimizer stepped once with the parameter difference.
        args: Namespace providing ``batch_size``, ``n``, ``num_steps`` and
            ``decoder`` (plus whatever ``create_optimizer`` reads).
        device: Forwarded to ``batch_of_queries`` and to the loss module.

    Returns:
        Tuple ``(loss, num_inst)`` recorded at the first inner step only: the
        loss as a float and the size of the first dimension of the first query
        tensor. Both are ``None`` when ``args.num_steps`` is 0.
    """
    dataset = meta_dataset.sample_task(args.batch_size, args.n, num_supports=0)
    # Reset gradients
    outer_optimizer.zero_grad()
    # Copy the model
    inner_base_net.load_state_dict(base_net.state_dict())
    # Optimizer creation
    inner_optimizer = create_optimizer(inner_base_net, args)
    # Stats (only the first inner step is reported)
    loss = None
    num_inst = None
    ner_loss = BaselineNerLoss(inner_base_net, args.decoder == "softmax").to(device)
    # First order approximation
    for _ in range(args.num_steps):
        ner_loss.train()
        queries, queries_intents, queries_slots = dataset.batch_of_queries(device)
        # Relabel the task's target slots to 2, 3, ... in ``tgt_slots`` order.
        # Masks are taken from a snapshot so a label written for one slot can
        # never be re-matched (and overwritten) by a later slot id.
        original_slots = queries_slots.clone()
        for idx, slot in enumerate(dataset.tgt_slots):
            queries_slots[original_slots == slot] = idx + 2
        # Reset gradients
        inner_optimizer.zero_grad()
        # The function runs under no_grad, so the forward pass needs an
        # explicit enable_grad scope to build the graph.
        with torch.enable_grad():
            step_loss = ner_loss(queries, queries_intents, queries_slots, device)
        # Compute gradients
        step_loss.backward()
        # Update parameters
        inner_optimizer.step()
        if loss is None:
            loss = step_loss.item()
            num_inst = queries[0].size(0)
    for p, inner_p in zip(base_net.parameters(), inner_base_net.parameters()):
        # Outer "gradient" is the parameter difference. It is a plain tensor
        # that inherits dtype and device from the parameter itself, rather
        # than a float32 Parameter placed on ``device``.
        delta = p.detach() - inner_p.detach()
        if p.grad is None:
            p.grad = delta
        else:
            p.grad.copy_(delta)
    # Update parameters
    outer_optimizer.step()
    return loss, num_inst


@torch.no_grad()
def bert_ner_step(dataset, loss_func, optimizer, device):
    """Take one optimizer step of the baseline loss on a batch of supports.

    Args:
        dataset: Object exposing ``batch_of_supports(device)``, which returns
            ``(supports, intents, slots)``, and ``tgt_slots``, an iterable of
            slot ids.
        loss_func: Callable module evaluated as
            ``loss_func(supports, intents, slots, device)``; it is switched to
            training mode before use.
        optimizer: Optimizer whose gradients are reset and which is stepped once.
        device: Forwarded to ``dataset.batch_of_supports`` and to ``loss_func``.

    Returns:
        Tuple ``(loss, num_inst)``: the loss as a Python float and the size of
        the first dimension of the first support tensor.
    """
    loss_func.train()
    supports, supports_intents, supports_slots = dataset.batch_of_supports(device)
    # Relabel the task's target slots to 2, 3, ... in ``tgt_slots`` order.
    # Masks are taken from a snapshot so a label written for one slot can
    # never be re-matched (and overwritten) by a later slot id.
    original_slots = supports_slots.clone()
    for idx, slot in enumerate(dataset.tgt_slots):
        supports_slots[original_slots == slot] = idx + 2
    # Reset gradients
    optimizer.zero_grad()
    # The function runs under no_grad, so the forward pass needs an explicit
    # enable_grad scope to build the graph.
    with torch.enable_grad():
        loss = loss_func(supports, supports_intents, supports_slots, device)
    # Compute gradients
    loss.backward()
    # Update parameters
    optimizer.step()
    return loss.item(), supports[0].size(0)


def _mean_loss(total_loss, total_inst):
    """Return the instance-weighted mean loss, or NaN when no instance was seen."""
    if not total_inst:
        return float("nan")
    return total_loss / total_inst


def _erase_progress(width):
    """Blank out the last ``width`` characters written to stdout."""
    sys.stdout.write("\b" * width + " " * width + "\b" * width)


def train(step_function, num_updates, log=False):
    """Call ``step_function`` ``num_updates`` times and return the mean loss.

    Args:
        step_function: Zero-argument callable returning ``(loss, num_inst)``.
        num_updates: Number of times ``step_function`` is called.
        log: When true, rewrite a one-line progress report on stdout every
            five updates and erase it once training finishes.

    Returns:
        The sum of ``loss * num_inst`` divided by the sum of ``num_inst`` over
        all updates, or NaN when no instance was processed (for example when
        ``num_updates`` is 0).
    """
    total_loss = 0
    total_inst = 0
    num_back = 0
    # Monotonic clock: the remaining-time estimate must not jump if the
    # system clock is adjusted during training.
    start_time = time.monotonic()
    for update in range(1, num_updates + 1):
        loss, num_inst = step_function()
        total_loss += loss * num_inst
        total_inst += num_inst
        # Write progress
        if log and update % 5 == 0:
            _erase_progress(num_back)
            elapsed = time.monotonic() - start_time
            log_info = "update: {:d}/{:d} loss: {:.4f}, time left (estimated): {:.2f}s".format(
                update, num_updates, _mean_loss(total_loss, total_inst),
                (num_updates - update) * elapsed / update)
            sys.stdout.write(log_info)
            sys.stdout.flush()
            num_back = len(log_info)
    if log:
        _erase_progress(num_back)
        sys.stdout.flush()
    return _mean_loss(total_loss, total_inst)
