import torch
import torch.nn as nn
import preprocess
import math
import pdb

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from dataset import mask, pad
from modules.pointer import PointerPredicator
from modules.attention import ScaledDotAttention
from modules.feedforward import Feedforward
from utils import pack_bidirectional_lstm_state


class BeamNode(object):
    """A node of the beam-search tree built by ``BertSQL.beam``.

    Every node carries a ``data`` dict (``path`` reads its ``'predict'``
    entry), a link to its parent, its depth and a log-score that becomes
    cumulative once the node is attached with ``add``.
    """

    def __init__(self, data, parent=None, depth=1, score=1., children=None):
        """Create a node.

        Args:
            data: Dict payload of the node; must contain ``'predict'`` for
                ``path`` to work.
            parent: Parent node, or ``None`` for a root.
            depth: Depth of the node; overwritten by ``add`` when the node
                is attached to a parent.
            score: Probability of this step, stored as ``log(score)``. A
                score of exactly ``0`` is stored as ``-inf``.
            children: Optional list that will hold the children. A fresh
                list is created when omitted.

        Raises:
            ValueError: If ``score`` is negative.
        """
        self.data = data
        self.parent = parent
        self.depth = depth
        # math.log(0) raises; an impossible step is represented by -inf instead.
        self.score = math.log(score) if score != 0 else float('-inf')
        # A literal ``[]`` default would be shared by every node created
        # without an explicit list, so the list is created per instance.
        self.children = [] if children is None else children
        self.end = False

    def add(self, child):
        """Attach ``child`` below this node and return ``self``.

        The child's score is increased by this node's score (so it becomes
        the log-score of the whole path) and its depth is set to
        ``self.depth + 1``.
        """
        child.parent = self
        child.score += self.score
        child.depth = self.depth + 1
        self.children.append(child)
        return self

    def path(self):
        """Return ``(predictions, nodes, score)`` for the path root -> self.

        ``predictions`` holds ``data['predict']`` of every node on the path,
        ``nodes`` the nodes themselves (both ordered from the root), and
        ``score`` is this node's score.
        """
        nodes = []
        node = self
        while node is not None:
            nodes.append(node)
            node = node.parent
        # Collected leaf-to-root; flip once instead of inserting at the front.
        nodes.reverse()
        predictions = [n.data['predict'] for n in nodes]
        return predictions, nodes, self.score


class BertSQL(nn.Module):
    def __init__(self, bert, tokenizer, config):
        """Assemble the model around a BERT encoder.

        Args:
            bert: Encoder module; ``bert.config.hidden_size`` sets the input
                width of the first bidirectional LSTM.
            tokenizer: Tokenizer handed to ``preprocess.tokenize`` by
                ``tokenize``.
            config: Mapping providing ``'d_hidden'`` and ``'dropout_p'``.
        """
        super().__init__()
        self.bert = bert
        self.tokenizer = tokenizer
        bert_width = bert.config.hidden_size
        self.d_hidden_bert = bert_width

        # Column-type tables are taken as-is from the preprocess module.
        self.column_types = preprocess.columnt
        self.column_type_rev_ids = preprocess.columnt_rev_idx

        hidden = config['d_hidden']
        self.d_hidden = hidden
        self.dropout_p = config['dropout_p']

        # The decoder input is the concatenation of a token embedding and an
        # alignment-attention context, hence twice the hidden width.
        self.dec = nn.LSTM(input_size=hidden * 2, hidden_size=hidden, batch_first=True)
        self.ali_attn = ScaledDotAttention(hidden, dropout_p=self.dropout_p)

        # Bidirectional encoders use half the width per direction.
        half = hidden // 2
        self.emb_encoder = nn.LSTM(input_size=bert_width, hidden_size=half,
                                   batch_first=True, bidirectional=True)
        self.emb_encoder_q = nn.LSTM(input_size=hidden, hidden_size=half,
                                     batch_first=True, bidirectional=True)

        self.init_sql_predicator()

    def init_sql_predicator(self):
        """Create the SQL embedding, the pointer predictor and the column-type head.

        Also caches the ids of the ``[CLS]``, ``[SEP]`` and ``[PAD]`` keywords.
        """
        vocab = preprocess.keywords
        pad_idx = 0
        self.keywords = vocab
        self.padding_idx = pad_idx

        self.sql_embedding = nn.Embedding(len(vocab), self.d_hidden, padding_idx=pad_idx)
        self.sql_predicator = PointerPredicator(
            self.sql_embedding,
            d_hidden=self.d_hidden,
            dropout_p=self.dropout_p,
            ignore_index=pad_idx,
            pointers=["token", "table", "column"],
        )

        # One logit per entry of preprocess.columnt.
        self.colt_predicator = nn.Linear(self.d_hidden, len(preprocess.columnt))

        self.CLS_TOK, self.SEP_TOK, self.PAD_TOK = [
            vocab[name] for name in ('[CLS]', '[SEP]', '[PAD]')
        ]

    def embedding(self, x, pointers=None, debug=None, ex=None):
        """Embed ``x`` by delegating to ``self.sql_predicator.embedding``.

        All arguments are forwarded unchanged; their meaning is defined by
        the predictor.
        """
        predicator = self.sql_predicator
        return predicator.embedding(x, pointers, debug, ex=ex)

    def tokenize(self, examples, is_train=True, with_align=True, with_reduce=False, with_key=True, with_col=True, with_val=True):
        """Tokenize ``examples`` via ``preprocess.tokenize`` with SQL as the target.

        Every flag is passed through unchanged; its meaning is defined by
        ``preprocess.tokenize``.
        """
        options = dict(
            with_target='sql',
            is_train=is_train,
            with_align=with_align,
            with_reduce=with_reduce,
            with_key=with_key,
            with_col=with_col,
            with_val=with_val,
        )
        return preprocess.tokenize(self.tokenizer, examples, **options)

    def encode(self, x):
        """Encode a tokenized batch and store the encoder outputs in ``x``.

        Steps:
          1. Run BERT, then ``self.emb_encoder`` over the non-padded positions.
          2. Gather question-token, table, column and alignment vectors using
             the span indices stored in ``x`` and build the matching masks.
          3. Re-encode the question tokens with ``self.emb_encoder_q``; its
             final state becomes the decoder's initial state.

        All new tensors are created on the device of the encoder output, so
        the method follows whatever device the model lives on.

        Args:
            x: Batch dict. Reads ``input_ids``, ``attention_mask``,
                ``token_type_ids``, ``q_length``, ``q_span_ids``,
                ``t_span_ids``, ``c_span_ids`` and ``a_span_ids``.

        Returns:
            The same dict ``x`` with the added keys ``tok``, ``tbl``, ``col``,
            ``ali``, ``tok_mask``, ``tbl_mask``, ``col_mask``, ``ali_mask``,
            ``initial_hidden_state`` and ``cls``.
        """
        # Part 1: encode input
        r = self.bert(input_ids=x["input_ids"],
                      attention_mask=x["attention_mask"],
                      token_type_ids=x["token_type_ids"],
                      return_dict=True,
                      output_hidden_states=True)
        emb = r.last_hidden_state

        # Pack with the true lengths so the LSTM does not run over padding.
        emb_lengths = (x['attention_mask'] == 1).sum(dim=-1).cpu()
        emb_packed = pack(emb, lengths=emb_lengths, batch_first=True, enforce_sorted=False)
        emb_packed, _ = self.emb_encoder(emb_packed)
        emb, _ = unpack(emb_packed, batch_first=True)

        cls_emb = emb[:, 0:1, :]    # [bs, 1, dim]
        bs = emb.size(0)
        dim = emb.size(-1)
        device = emb.device

        # Part 2: get branch result
        q_spans = x['q_span_ids']
        t_spans = x['t_span_ids']
        c_spans = x['c_span_ids']
        a_spans = x['a_span_ids']

        n_tok = max(x['q_length'])
        n_tbl = 1
        n_col = max(len(c_spans[b]) for b in range(bs))
        n_ali = max(len(a_spans[b]) for b in range(bs))

        tok = torch.zeros(bs, n_tok, dim, device=device)
        tbl = torch.zeros(bs, n_tbl, dim, device=device)
        col = torch.zeros(bs, n_col, dim, device=device)
        ali = torch.zeros(bs, n_ali, dim, device=device)

        for b in range(bs):
            # NOTE(review): assumes q_end - q_start never exceeds
            # max(x['q_length']) - confirm against the tokenizer output.
            q_start, q_end = q_spans[b]
            q_len = q_end - q_start
            tok[b, :q_len] = emb[b, q_start:q_end]

            # Only the first position of the first table span is used.
            tbl[b, 0] = emb[b, [t_spans[b][0][0]]]

            # Columns and alignments are represented by the first position
            # of each of their spans.
            c_size = len(c_spans[b])
            col[b, :c_size] = emb[b, [idx[0] for idx in c_spans[b]]]

            a_size = len(a_spans[b])
            ali[b, :a_size] = emb[b, [idx[0] for idx in a_spans[b]]]

        # dataset.mask presumably turns a list of lengths into a padding
        # mask - see its definition for the exact layout.
        tok_mask = mask([q_ends - q_starts for q_starts, q_ends in q_spans]).to(device)
        tbl_mask = mask([1] * bs).to(device)  # WTQ is a single-table dataset
        col_mask = mask([len(c_span_ids) for c_span_ids in c_spans]).to(device)
        ali_mask = mask([len(a_span_ids) for a_span_ids in a_spans]).to(device)

        # Part 3: encode question
        tok_length = x['q_length']
        tok_packed = pack(tok, tok_length, batch_first=True, enforce_sorted=False)
        tok_packed, tok_state = self.emb_encoder_q(tok_packed)
        tok_unpacked, _ = unpack(tok_packed, batch_first=True)

        # Convert the bidirectional final (h, c) with the project helper so
        # it can be used as the decoder's initial state.
        h = pack_bidirectional_lstm_state(tok_state[0], self.emb_encoder_q.num_layers)
        c = pack_bidirectional_lstm_state(tok_state[1], self.emb_encoder_q.num_layers)
        hidden = (h, c)

        x["tok"] = tok_unpacked
        x["tbl"] = tbl
        x["col"] = col
        x["ali"] = ali
        x["tok_mask"] = tok_mask
        x["tbl_mask"] = tbl_mask
        x["col_mask"] = col_mask
        x["ali_mask"] = ali_mask
        x["initial_hidden_state"] = hidden
        x["cls"] = cls_emb.squeeze(1)
        return x

    def decode(self, predicator, decoder, x, y, y_mask, is_train=True, debug=None, max_seq=128):
        tok = x["tok"]
        tbl = x["tbl"]
        col = x["col"]
        ali = x["ali"]

        tok_mask = x["tok_mask"]
        tbl_mask = x["tbl_mask"]
        col_mask = x["col_mask"]
        ali_mask = x["ali_mask"]

        h_t = x["initial_hidden_state"]

        if is_train:
            y_embedding = predicator.embedding(y, [tok, tbl, col], debug=debug, ex=x)
            y_len = y.size(1)
            y_out = []

            for t in range(y_len):
                emb_t = y_embedding[:, t:t+1, :]

                h_pre = h_t[0].transpose(0, 1)
                ali_t, _ = self.ali_attn(h_pre, ali, ali, k_mask=ali_mask)

                in_t = torch.cat([emb_t, ali_t], dim=-1)
                _, h_t = decoder(in_t, h_t)

                h_cur = h_t[0].transpose(0, 1)
                y_out.append(h_cur)

            y_out = torch.cat(y_out, dim=1)     # [bs, y_len, dim]
            (pred, pad_mask), y_ctx = predicator(
                y_out, mask=y_mask,
                pointers=[(tok, tok_mask), (tbl, tbl_mask), (col, col_mask)],
                with_cv=True,
                ex=x
            )

            colt_score = self.colt_predicator(y_ctx)
            colt_mask = torch.ones(colt_score.size(0), colt_score.size(-1), device=colt_score.device, dtype=torch.long)

            x["predict_columnt"] = (colt_score, colt_mask)
            x["y_mask"] = y_mask
            return pred, pad_mask
        else:  # greedy search
            raise NotImplementedError

    def beam(self, predicator, decoder, x, beam_size=16, max_seq=64):
        """Beam-search decode an encoded example and predict column types.

        The search expands a tree of ``BeamNode`` objects breadth-first. Each
        time a new depth level is reached, only the ``beam_size`` best-scoring
        nodes are kept. A hypothesis is finished when it emits ``[SEP]`` or
        reaches ``max_seq`` nodes; the search stops after ``beam_size``
        finished hypotheses or when the queue is empty.

        For the best hypothesis, every predicted id ``>= x['c_base']`` is
        mapped to a column of ``x['ex']['columns']`` and a column type is
        chosen among that column's candidates with ``self.colt_predicator``.

        NOTE(review): the step input is built as a ``[1, 1]`` tensor and the
        scores are flattened with ``view(-1)``, so this appears to assume a
        batch of one example - confirm with the caller.

        Side effects: stores the tree root in ``x['beam']`` and the list of
        ``[target_id, column_type_text]`` pairs in ``x['predict_column_type']``.

        Args:
            predicator: Object exposing ``embedding(...)`` and callable on a
                decoder state.
            decoder: Recurrent module called as ``decoder(step_input, state)``.
            x: Encoded example dict as produced by ``encode``; additionally
                reads ``ex`` and ``c_base``.
            beam_size: Beam width and number of hypotheses to collect.
            max_seq: Maximum depth of a hypothesis.

        Returns:
            List of ``(ids, nodes, score)`` tuples sorted by descending score.
        """
        tok = x["tok"]
        tbl = x["tbl"]
        col = x["col"]
        ali = x["ali"]

        tok_mask = x["tok_mask"]
        tbl_mask = x["tbl_mask"]
        col_mask = x["col_mask"]
        ali_mask = x["ali_mask"]

        h_t = x["initial_hidden_state"]

        ex = x["ex"]
        c_base = x['c_base']
        y_ct = []

        pointer_inputs = [(tok, tok_mask), (tbl, tbl_mask), (col, col_mask)]
        device = tok.device

        root = BeamNode(data={
                            "predict": self.CLS_TOK,
                            "score": 1.,
                            "current": [],
                            "current_prob": [],
                            "h_t": h_t,
                        },
                        parent=None,
                        depth=1,
                        score=1.,
                        children=[])

        queue = [root]
        current_depth = root.depth
        finished = 0
        beam_results = []

        while queue:
            if finished >= beam_size:
                break

            if queue[0].depth > current_depth:
                # First node of a new depth level: every queued node is at
                # that level now, so prune to the best ``beam_size`` of them.
                current_depth += 1
                queue.sort(key=lambda nd: nd.score, reverse=True)
                for nd in queue[beam_size:]:
                    nd.end = True
                queue = queue[:beam_size]
                continue

            node = queue.pop(0)
            y_t = node.data['predict']
            # A hypothesis ends at the length limit or when it emits [SEP].
            if node.depth >= max_seq or y_t == self.SEP_TOK:
                finished += 1
                beam_results.append(node.path())
                continue

            h_t = node.data['h_t']
            # Build the step input on the encoder's device instead of
            # assuming CUDA.
            y_t = torch.tensor([[y_t]], device=device)
            emb_in = predicator.embedding(y_t, [tok, tbl, col], debug=None, ex=x)

            h_pre = h_t[0].transpose(0, 1)
            ali_t, _ = self.ali_attn(h_pre, ali, ali, k_mask=ali_mask)

            dec_in = torch.cat([emb_in, ali_t], dim=-1)
            _, h_t = decoder(dec_in, h_t)

            dec_out = h_t[0].transpose(0, 1)
            node.data['dec_out'] = dec_out

            (score_t, pad_mask), c_t = predicator(
                dec_out, mask=None,
                pointers=pointer_inputs,
                with_cv=True,
                ex=None
            )
            node.data['c_t'] = c_t

            # topk raises when k exceeds the number of classes.
            k = min(beam_size, score_t.size(-1))
            prob_t, pred_t = score_t.softmax(-1).topk(k)
            prob_t = prob_t.view(-1).tolist()
            pred_t = pred_t.view(-1).tolist()

            for tok_id, prob in zip(pred_t, prob_t):
                if prob <= 0.0:
                    # The probability underflowed to zero: its log-score
                    # would be -inf (and math.log(0) raises), so it cannot
                    # be part of a useful hypothesis.
                    continue
                child = BeamNode(data={
                                    "predict": tok_id,
                                    "score": prob,
                                    "current": [],
                                    "current_prob": [],
                                    "h_t": h_t,
                                },
                                parent=node,
                                depth=1,
                                score=prob,
                                children=[])
                node.add(child)
                queue.append(child)

        # Each result is (ids, nodes, score); rank by the score. Sorting on
        # the node list would try to order BeamNode objects and raise
        # TypeError.
        beam_results.sort(key=lambda result: result[2], reverse=True)
        if beam_results:
            predict_ids, predict_nodes = beam_results[0][0], beam_results[0][1]
            # NOTE(review): offset between pointer ids and ex['columns']
            # indices - presumably reserved leading column slots; confirm.
            PERSIST_COL = 3

            for step_node, target_id in zip(predict_nodes, predict_ids):
                if target_id < c_base:
                    continue
                # Unexpanded nodes (the last one of a path) have no context.
                if 'c_t' not in step_node.data:
                    continue

                col_idx = target_id - c_base - PERSIST_COL
                if col_idx >= len(ex['columns']) or col_idx < 0:
                    continue

                all_candidate_text = ex['columns'][col_idx][2]
                if not all_candidate_text:
                    y_ct.append([target_id, "[KEY]"])
                    continue

                # "[KEY]" (type id 1) is always a candidate, followed by the
                # column's own candidates that are known column types.
                candidate_t = [1]
                candidate_text = ["[KEY]"]
                for c in all_candidate_text:
                    if c in self.column_types:
                        candidate_t.append(self.column_types[c])
                        candidate_text.append(c)

                columnt = self.colt_predicator(step_node.data['c_t'])
                best = columnt.view(-1)[candidate_t].argmax(dim=-1).tolist()
                y_ct.append([target_id, candidate_text[best]])

        x['beam'] = root
        x['predict_column_type'] = y_ct
        return beam_results

    def forward(self, x, is_train=True, beam_size=16, max_seq=128):
        """Encode the batch, then decode with teacher forcing or beam search.

        Args:
            x: Tokenized batch dict. In training mode it must also provide
                ``target_length``, ``target_ids`` and ``target_ids_debug``.
            is_train: When true run ``decode`` on the gold targets, otherwise
                run ``beam``.
            beam_size: Beam width for inference.
            max_seq: Maximum hypothesis length for inference.

        Returns:
            The dict ``x`` with ``x['predict']`` set to the result of
            ``decode`` (training) or ``beam`` (inference).
        """
        x = self.encode(x)

        if is_train:
            y_mask = [[1] * s_len for s_len in x["target_length"]]
            # Build the mask on the device type of the encoder output
            # ("cuda" on GPU, exactly as before) instead of hard-coding CUDA.
            y_mask = pad(y_mask, device=x["tok"].device.type)
            y = x["target_ids"]
            x['predict'] = self.decode(self.sql_predicator, self.dec, x, y, y_mask,
                                       is_train=is_train, debug=x['target_ids_debug'])
        else:
            x['predict'] = self.beam(self.sql_predicator, self.dec, x,
                                     beam_size=beam_size, max_seq=max_seq)
        return x

    def loss_columnt(self, x):
        """Return the column-type loss for a decoded training batch.

        Reads ``x["predict_columnt"]`` and ``x["y_mask"]`` (both written by
        ``decode``) and the gold ``x["target_ids_columnt"]``; positions whose
        gold id is ``-1`` are ignored.
        """
        scores, keep_mask = x["predict_columnt"]
        return self.sql_predicator.loss(
            scores,
            x["target_ids_columnt"],
            padding_mask=x["y_mask"],
            ignore_index=-1,
            ignore_mask=keep_mask,
            generate=False,
            debug=x,
        )

    def loss(self, x):
        """Return the total training loss: SQL token loss plus column-type loss.

        Reads ``x["predict"]`` and ``x["y_mask"]`` (both written during the
        training forward pass) and the gold ``x["target_ids"]``; positions
        equal to ``self.padding_idx`` are ignored in the token loss.
        """
        scores, keep_mask = x["predict"]
        token_loss = self.sql_predicator.loss(
            scores,
            x["target_ids"],
            padding_mask=x['y_mask'],
            ignore_index=self.padding_idx,
            ignore_mask=keep_mask,
            debug=x,
        )
        return token_loss + self.loss_columnt(x)
