from libc.stdlib cimport malloc, calloc, free
from libc.string cimport memset
from libcpp.vector cimport vector
import cython
cimport cython
import numpy as np
cimport numpy as np
from numpy.math cimport INFINITY

from coli.span.const_tree import ConstTree, Lexicon
ctypedef unsigned char index_t  # word positions; an unsigned char cannot exceed 255, which bounds usable positions
ctypedef short label_t          # label ids
ctypedef np.float32_t score_t   # 32-bit float scores used by every chart and score memoryview below

# One chart cell / derivation-tree node filled in by the CKY decoders.
cdef struct Item:
    index_t start  # first word of the span
    index_t end    # one past the last word; leaves are (i, i + 1)
    label_t label  # label id; item_get_spans treats 0 as "no label"
    score_t score  # best score found for this cell
    long left  # cython doesn't allow pointers in memoryviews, so the left child's address is stored as long (0 = none)
    long right  # address of the right child Item, 0 if none
    # NOTE(review): `long` is 32 bits on LLP64 platforms (e.g. 64-bit Windows),
    # where these addresses would be truncated; an intptr_t-sized field would be safer.
    bint initialized  # set by item_assign; a False cell holds no derivation



cdef inline void item_assign(Item* item, index_t start, index_t end, label_t label,
                      score_t score, Item* left, Item* right) nogil:
    # Overwrite every field of `item` and mark it as holding a derivation.
    # Child pointers are stored as integer addresses; NULL is stored as 0.
    item.initialized = True
    item.left = <long> left
    item.right = <long> right
    item.label = label
    item.score = score
    item.start = start
    item.end = end


# Plain (start, end, label) triple collected by item_get_spans.
cdef struct SpanResult:
    index_t start
    index_t end
    label_t label


cdef void item_get_spans(Item* item, vector[SpanResult]& container) nogil:
    # Append (start, end, label) for every item with a non-zero label in the
    # tree rooted at `item`, in pre-order: node, left subtree, right subtree.
    # An explicit stack replaces recursion; the right child is pushed first so
    # the left subtree is emitted before it.
    cdef vector[Item*] pending
    cdef Item* node
    cdef SpanResult span

    pending.push_back(item)
    while not pending.empty():
        node = pending.back()
        pending.pop_back()

        if node.label != 0:
            span.start = node.start
            span.end = node.end
            span.label = node.label
            container.push_back(span)

        if node.right != 0:
            pending.push_back(<Item *> node.right)
        if node.left != 0:
            pending.push_back(<Item *> node.left)


cdef class ItemWrapped(object):
    """Python-side copy of a chart ``Item`` tree produced by the CKY decoders.

    ``start``/``end``/``label``/``score`` mirror the C struct; ``left`` and
    ``right`` are child ``ItemWrapped`` objects, or None for a leaf.
    """
    cdef public object start, end, label, score, left, right

    @staticmethod
    cdef create(Item* item):
        # Deep-copy a C chart item and, recursively, its children into Python
        # objects. A zero child address means "no child".
        cdef ItemWrapped node = ItemWrapped()
        cdef Item* left_ptr = <Item *> item.left
        cdef Item* right_ptr = <Item *> item.right
        node.start = item.start
        node.end = item.end
        node.label = item.label
        node.score = item.score
        node.left = None if left_ptr == NULL else ItemWrapped.create(left_ptr)
        node.right = None if right_ptr == NULL else ItemWrapped.create(right_ptr)
        return node

    def generate_scoreable_spans(self, label_map=None):
        """Yield ``(start, end, label)`` for every node whose label is non-zero.

        Nodes are visited in pre-order (node, left subtree, right subtree).
        When ``label_map`` is given the label is translated through
        ``label_map[label]``; otherwise the raw label id is yielded.
        """
        pending = [self]
        while pending:
            node = pending.pop()
            if node.label != 0:
                if label_map is None:
                    yield (node.start, node.end, node.label)
                else:
                    yield (node.start, node.end, label_map[node.label])
            # Right goes on the stack first so the left subtree comes out first.
            if node.right is not None:
                pending.append(node.right)
            if node.left is not None:
                pending.append(node.left)

    def flat_children(self):
        """Return the labelled children of this node, splicing out label-0 left nodes.

        Walking down the left spine, every left child with label 0 is replaced
        by its own flattened children; right children are kept as they are.
        """
        collected = []
        node = self
        while not (node.left is None and node.right is None):
            collected.append(node.right)
            child = node.left
            if child.label != 0:
                collected.append(child)
                break
            node = child
        collected.reverse()
        return collected

    def to_const_tree(self, label_map, words):
        """Convert this subtree to a ``ConstTree`` labelled via ``label_map``.

        One-word spans take ``words[start]`` (which must be a ``Lexicon``) as
        their only child; larger spans recurse into ``flat_children()``.
        """
        tree = ConstTree(label_map[self.label], (self.start, self.end))
        if self.end - self.start != 1:
            tree.children = [child.to_const_tree(label_map, words)
                             for child in self.flat_children()]
        else:
            leaf_word = words[self.start]
            assert isinstance(leaf_word, Lexicon)
            tree.children = [leaf_word]
        tree.score = self.score
        return tree

    def __str__(self):
        fields = (self.start, self.end, self.label)
        return "[{}, {}, {}]".format(*fields)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef Item* cky_binary_rule_decoder(int[:, :] rules,
                                   score_t[:, :] span_scores,
                                   score_t[:, :, :] label_scores,
                                   score_t[:, :] leaftag_scores,
                                   np.int32_t[:] leaftag_to_label,
                                   Item[:,:,:] table
                                   ) nogil:
    # Viterbi CKY over a binary rule set; fills `table` in place.
    #
    # rules:            one rule per row, read as (lhs, rhs1, rhs2) label ids.
    # span_scores:      added once for every internal span [start, end).
    # label_scores:     indexed [start, end, label].
    # leaftag_scores:   optional; when given, leaves are scored with
    #                   leaftag_scores[word, leaftag] instead of label_scores.
    # leaftag_to_label: label used for the leaf item of each leaftag.
    # table:            chart indexed [start, end, label]. `initialized` is only
    #                   ever set here, never cleared, so the caller must zero the
    #                   table beforehand.
    #
    # Returns a pointer into `table` to the best *initialized* whole-sentence
    # item over labels 1..label_count-1 (label 0 is never the root). If no
    # derivation exists the returned item is uninitialized, and callers must
    # check `initialized`. Assumes label_count >= 2.
    cdef int sent_length = span_scores.shape[0]
    cdef int label_count = label_scores.shape[2]
    cdef int leaftag_count = leaftag_to_label.shape[0]
    cdef int rule_count = rules.shape[0]
    cdef bint use_leaftag = leaftag_scores is not None
    cdef score_t score
    cdef index_t i, k, l, start, end
    cdef label_t label_idx, lhs, rhs1, rhs2
    cdef int rule, leaftag_idx
    cdef Item *left_item
    cdef Item *right_item
    cdef Item *i_item
    cdef Item *current_item

    with nogil:
        # Leaves. When several leaftags map to the same label, keep the
        # best-scoring one (Viterbi max) rather than whichever comes last.
        for i in range(sent_length):
            for leaftag_idx in range(leaftag_count):
                label_idx = leaftag_to_label[leaftag_idx]
                if use_leaftag:
                    score = leaftag_scores[i, leaftag_idx]
                else:
                    score = label_scores[i, i+1, label_idx]
                current_item = &table[i, i+1, label_idx]
                if not current_item.initialized or score > current_item.score:
                    item_assign(current_item, i, i+1, label_idx, score, NULL, NULL)

        for l in range(2, sent_length + 1):
            for start in range(sent_length - l + 1):
                end = start + l
                # Uninitialized internal cells compare as -inf below.
                for label_idx in range(label_count):
                    current_item = &table[start, end, label_idx]
                    current_item.score = -INFINITY

                for rule in range(rule_count):
                    lhs = rules[rule, 0]
                    rhs1 = rules[rule, 1]
                    rhs2 = rules[rule, 2]
                    current_item = &table[start, end, lhs]
                    for k in range(start+1, end):
                        left_item = &table[start, k, rhs1]
                        right_item = &table[k, end, rhs2]
                        if not left_item.initialized or not right_item.initialized:
                            continue
                        score = left_item.score + right_item.score + span_scores[start, end] + \
                                label_scores[start, end, lhs]
                        if score > current_item.score:
                            item_assign(current_item, start, end, lhs, score, left_item, right_item)

        # Root selection. Uninitialized cells are skipped explicitly: for a
        # one-word sentence they still hold the zero score from the caller's
        # reset, which would otherwise beat genuine negative leaf scores.
        current_item = &table[0, sent_length, 1]
        for label_idx in range(1, label_count):
            i_item = &table[0, sent_length, label_idx]
            if not i_item.initialized:
                continue
            if not current_item.initialized or i_item.score > current_item.score:
                current_item = i_item
        return current_item


cdef class CKYBinaryRuleDecoder(object):
    """Rule-constrained binary CKY decoder that reuses one preallocated chart.

    The chart buffer is sized once, for sentences of up to
    ``max_sentence_size`` words and ``label_count`` labels. Each call views its
    leading part as a ``[sent_length, sent_length + 1, label_count]`` table of
    ``Item``. Inputs larger than that capacity are rejected with ValueError
    instead of overrunning the buffer.
    """
    cdef void* pool
    cdef Py_ssize_t max_sentence_size  # capacity in words
    cdef Py_ssize_t label_count        # capacity in labels

    def __cinit__(self, max_sentence_size, label_count, use_rules=None):
        # `use_rules` is accepted for signature compatibility and ignored.
        cdef size_t pool_bytes = (max_sentence_size * (max_sentence_size + 1)
                                  * label_count * sizeof(Item))
        self.max_sentence_size = max_sentence_size
        self.label_count = label_count
        self.pool = malloc(pool_bytes)
        if self.pool == NULL and pool_bytes != 0:
            raise MemoryError(
                "could not allocate {} bytes for the CKY chart".format(pool_bytes))

    def __call__(self, rules, span_scores, label_scores,
                leaftag_scores, leaftag_to_label, internal_labels, use_rules=None):
        """Decode one sentence and return the best tree as an ``ItemWrapped``.

        ``internal_labels`` and ``use_rules`` are unused; they are kept for a
        call signature shared with the other decoders.

        Raises:
            ValueError: the sentence or label set exceeds the pool capacity.
            ArithmeticError: no derivation covers the whole sentence.
        """
        cdef int sent_length = span_scores.shape[0]
        cdef int label_count = label_scores.shape[2]
        cdef Item[:,:,:] table
        cdef Item* current_item

        # `table` is a raw view over `pool`; anything bigger than the
        # construction-time capacity would read and write past the allocation.
        if sent_length > self.max_sentence_size or label_count > self.label_count:
            raise ValueError(
                "input with {} words and {} labels exceeds decoder capacity "
                "({} words, {} labels)".format(
                    sent_length, label_count,
                    self.max_sentence_size, self.label_count))

        table = <Item[:sent_length, :(sent_length+1), :label_count]> self.pool
        # The decoder relies on every `initialized` flag starting out False.
        memset(self.pool, 0, sent_length * (sent_length + 1) * label_count * sizeof(Item))

        current_item = cky_binary_rule_decoder(
            rules, span_scores, label_scores,
            leaftag_scores, leaftag_to_label, table)

        if not current_item.initialized:
            raise ArithmeticError("Can't decode.")
        return ItemWrapped.create(current_item)

    def __dealloc__(self):
        free(self.pool)


cdef class CKYRuleFreeDecoder(object):
    """Rule-free CKY decoder (two chart slots per span) over a reusable chart.

    The chart buffer holds ``(max_sentence_size + 1) ** 2 * 2`` items. Longer
    sentences are rejected with ValueError instead of overrunning it. When
    ``use_rules`` is true, a ``LabelAssigner`` is built and, by default, applied
    to each decoded tree.
    """
    cdef void* pool
    # Must be declared: cdef classes have no __dict__, so assigning an
    # undeclared attribute would raise AttributeError in __cinit__.
    cdef LabelAssigner label_assigner
    cdef Py_ssize_t max_sentence_size  # capacity in words

    def __cinit__(self, max_sentence_size, label_count, use_rules=False):
        cdef size_t pool_bytes = ((max_sentence_size + 1) * (max_sentence_size + 1)
                                  * 2 * sizeof(Item))
        self.max_sentence_size = max_sentence_size
        self.pool = malloc(pool_bytes)
        if self.pool == NULL and pool_bytes != 0:
            raise MemoryError(
                "could not allocate {} bytes for the CKY chart".format(pool_bytes))
        if use_rules:
            self.label_assigner = LabelAssigner(max_sentence_size, label_count)
        else:
            self.label_assigner = None

    def __call__(self, rules, span_scores, label_scores,
                 leaftag_scores, leaftag_to_label, internal_labels, root_rules, use_rules=None):
        """Decode one sentence and return the best tree as an ``ItemWrapped``.

        ``use_rules=None`` means "use the label assigner if one was built".

        Raises:
            ValueError: the sentence is longer than the pool capacity.
        """
        cdef int sent_length = span_scores.shape[0]
        cdef Item[:,:,:] table
        cdef Item* final_item

        if sent_length > self.max_sentence_size:
            raise ValueError(
                "sentence of {} words exceeds decoder capacity of {} words".format(
                    sent_length, self.max_sentence_size))
        if use_rules is None:
            use_rules = self.label_assigner is not None

        table = <Item[:(sent_length+1), :(sent_length+1), :2]> self.pool
        memset(self.pool, 0, (sent_length + 1) * (sent_length + 1) * 2 * sizeof(Item))
        final_item = cky_decoder_2(span_scores, label_scores, leaftag_scores,
                                   leaftag_to_label, internal_labels, table)
        ret = ItemWrapped.create(final_item)
        if use_rules and self.label_assigner is not None:
            self.label_assigner(ret, rules, label_scores, leaftag_scores, leaftag_to_label, root_rules)
        return ret

    def __dealloc__(self):
        free(self.pool)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef Item* cky_decoder_2(score_t[:, :] span_scores,
                  score_t[:, :, :] label_scores,
                  score_t[:, :] leaftag_scores,
                  np.int32_t[:] leaftag_to_label,
                  np.int32_t[:] internal_labels,
                  Item[:, :, :] table
                  ) nogil:
    # Rule-free Viterbi CKY with two chart slots per span; fills `table` in place.
    #
    # table is indexed [start, end, slot]:
    #   slot 0 - best item whose label may be any internal label, including 0
    #            ("empty or nonempty").
    #   slot 1 - best item restricted to a non-zero label ("nonempty").
    # An internal item combines a slot-0 left child with a slot-1 right child
    # and adds span_scores[start, end] plus the chosen label's score.
    # Leaves take the best leaftag (mapped through leaftag_to_label) and are
    # written identically into both slots. leaftag_scores is optional; when it
    # is None, leaves are scored with label_scores[i, i+1, label].
    #
    # Returns a pointer to slot 1 of the whole-sentence span. It is not checked
    # for `initialized` here.
    cdef int sent_length =  span_scores.shape[0]
    cdef int label_count = label_scores.shape[2]  # NOTE(review): unused
    cdef int leaftag_count = leaftag_to_label.shape[0]
    cdef int internal_count = internal_labels.shape[0]
    cdef bint use_leaftag = leaftag_scores is not None

    cdef int i, label_idx, leaftag_idx, internal_idx, span_length, start, k, end
    cdef int sub_label_idx, best_label_idx = 0, best_label_idx_nonempty = 0
    cdef score_t score, best_label_score, best_label_score_nonempty

    cdef Item* item
    cdef Item* item_nonempty
    cdef Item* left_item
    cdef Item* right_item
    cdef Item* current_item
    cdef Item* current_item_nonempty

    for i in range(sent_length):
        item = &table[i,i+1,0]  # empty or nonempty
        item_nonempty = &table[i,i+1,1]  # nonempty
        # Only slot 0 is reset; slot 1 is always written together with slot 0.
        item.score = -INFINITY
        for leaftag_idx in range(leaftag_count):
            label_idx = leaftag_to_label[leaftag_idx]
            if use_leaftag:
                score = leaftag_scores[i, leaftag_idx]
            else:
                score = label_scores[i, i+1, label_idx]
            if score > item.score:
                item_assign(item, i, i+1, label_idx, score, NULL, NULL)
                item_assign(item_nonempty, i, i+1, label_idx, score, NULL, NULL)

    for span_length in range(2, sent_length + 1):
         for start in range(sent_length - span_length + 1):
             end = start + span_length
             current_item = &table[start, end, 0]
             current_item_nonempty = &table[start, end, 1]
             current_item.score = -INFINITY
             current_item_nonempty.score = -INFINITY
             # Pick the best label overall and the best non-zero label for this span.
             # If no non-zero label beats -inf, the nonempty score below is -inf and
             # slot 1 is never assigned, so a stale best_label_idx_nonempty is harmless.
             best_label_score = -INFINITY
             best_label_score_nonempty = -INFINITY
             for internal_idx in range(internal_count):
                 label_idx = internal_labels[internal_idx]
                 score = label_scores[start, end, label_idx]
                 if score > best_label_score:
                     best_label_score = score
                     best_label_idx = label_idx
                 if score > best_label_score_nonempty and label_idx != 0:
                     best_label_score_nonempty = score
                     best_label_idx_nonempty = label_idx
             # Left child from slot 0 of [start, k], right child from slot 1 of [k, end].
             for k in range(start+1, end):
                 left_item = &table[start,k,0]
                 right_item = &table[k,end,1]
                 score = left_item.score + right_item.score + span_scores[start, end] + \
                         best_label_score
                 if score > current_item.score:
                     item_assign(current_item, start, end, best_label_idx, score, left_item, right_item)

                 score = left_item.score + right_item.score + span_scores[start, end] + \
                         best_label_score_nonempty
                 if score > current_item_nonempty.score:
                     item_assign(current_item_nonempty, start, end, best_label_idx_nonempty, score, left_item, right_item)

    return &table[0,sent_length,1]


cdef class CKYRuleFreeDecoder3(object):
    """Rule-free CKY decoder with an explicit empty slot, over a reusable chart.

    The chart buffer holds ``(max_sentence_size + 1) ** 2 * 2`` items. Longer
    sentences are rejected with ValueError instead of overrunning it. When
    ``use_rules`` is true, a ``LabelAssigner`` is built and, by default, applied
    to each decoded tree.
    """
    cdef void* pool
    cdef LabelAssigner label_assigner
    cdef Py_ssize_t max_sentence_size  # capacity in words

    def __cinit__(self, max_sentence_size, label_count, use_rules=False):
        cdef size_t pool_bytes = ((max_sentence_size + 1) * (max_sentence_size + 1)
                                  * 2 * sizeof(Item))
        self.max_sentence_size = max_sentence_size
        self.pool = malloc(pool_bytes)
        if self.pool == NULL and pool_bytes != 0:
            raise MemoryError(
                "could not allocate {} bytes for the CKY chart".format(pool_bytes))
        if use_rules:
            self.label_assigner = LabelAssigner(max_sentence_size, label_count)
        else:
            self.label_assigner = None

    def __call__(self, rules, span_scores, label_scores,
                 leaftag_scores, leaftag_to_label, internal_labels, root_rules,
                 use_rules=None, return_item=True):
        """Decode one sentence.

        With ``return_item`` true, return the best tree as an ``ItemWrapped``
        (optionally relabelled by the label assigner). Otherwise return a list
        of ``(start, end, label)`` tuples for its non-zero-labelled items, in
        pre-order. ``use_rules=None`` means "use the label assigner if one was
        built".

        Raises:
            ValueError: the sentence is longer than the pool capacity.
        """
        cdef int sent_length = span_scores.shape[0]
        cdef Item[:,:,:] table
        cdef Item* final_item
        cdef vector[SpanResult] result
        cdef size_t idx

        if sent_length > self.max_sentence_size:
            raise ValueError(
                "sentence of {} words exceeds decoder capacity of {} words".format(
                    sent_length, self.max_sentence_size))
        if use_rules is None:
            use_rules = self.label_assigner is not None

        table = <Item[:(sent_length+1), :(sent_length+1), :2]> self.pool
        with nogil:
            memset(self.pool, 0, (sent_length + 1) * (sent_length + 1) * 2 * sizeof(Item))
        final_item = cky_decoder_3(span_scores, label_scores, leaftag_scores,
                                   leaftag_to_label, internal_labels, table)
        if return_item:
            ret = ItemWrapped.create(final_item)
            if use_rules and self.label_assigner is not None:
                self.label_assigner(ret, rules, label_scores, leaftag_scores, leaftag_to_label, root_rules)
            return ret

        with nogil:
            item_get_spans(final_item, result)
        ret = []
        for idx in range(result.size()):
            ret.append((result[idx].start, result[idx].end, result[idx].label))
        return ret

    def __dealloc__(self):
        free(self.pool)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef Item* cky_decoder_3(score_t[:, :] span_scores,
                  score_t[:, :, :] label_scores,
                  score_t[:, :] leaftag_scores,
                  np.int32_t[:] leaftag_to_label,
                  np.int32_t[:] internal_labels,
                  Item[:, :, :] table
                  ) nogil:
    # Rule-free Viterbi CKY with an explicit "empty" slot; fills `table` in place.
    #
    # table is indexed [start, end, slot]:
    #   slot 0 - "empty" item. Its label is best_label_idx, which is never
    #            reassigned below and so stays 0. Its score is just the sum of
    #            its children's scores (no span or label score).
    #   slot 1 - "nonempty" item. Its label is the best non-zero internal label,
    #            and its score also adds span_scores[start, end] and that
    #            label's score.
    # For each split k, the left child is the better of the two slots of
    # [start, k], and the right child is always slot 1 of [k, end].
    # Leaf empty slots stay at -inf, so leaves are only usable as labelled items.
    # leaftag_scores is optional; when it is None, leaves use
    # label_scores[i, i+1, label].
    #
    # The body runs under `with nogil`, so a caller holding the GIL releases it
    # for the duration. Returns a pointer to slot 1 of the whole-sentence span,
    # not checked for `initialized`.
    cdef int sent_length =  span_scores.shape[0]
    cdef int label_count = label_scores.shape[2]  # NOTE(review): unused
    cdef int leaftag_count = leaftag_to_label.shape[0]
    cdef int internal_count = internal_labels.shape[0]
    cdef bint use_leaftag = leaftag_scores is not None

    cdef int i, label_idx, leaftag_idx, internal_idx, span_length, start, k, end
    cdef int sub_label_idx, best_label_idx = 0, best_label_idx_nonempty = 0
    cdef score_t score, best_label_score, best_label_score_nonempty

    cdef Item* item_empty
    cdef Item* item_nonempty
    cdef Item* left_item
    cdef Item* right_item
    cdef Item* current_item_empty
    cdef Item* current_item_nonempty

    with nogil:
        for i in range(sent_length):
            item_empty = &table[i,i+1,0]
            item_empty.score = -INFINITY

            item_nonempty = &table[i,i+1,1]
            item_nonempty.score = -INFINITY
            for leaftag_idx in range(leaftag_count):
                label_idx = leaftag_to_label[leaftag_idx]
                if use_leaftag:
                    score = leaftag_scores[i, leaftag_idx]
                else:
                    score = label_scores[i, i+1, label_idx]
                if score > item_nonempty.score:
                    item_assign(item_nonempty, i, i+1, label_idx, score, NULL, NULL)

        for span_length in range(2, sent_length + 1):
             for start in range(sent_length - span_length + 1):
                 end = start + span_length
                 current_item_empty = &table[start, end, 0]
                 current_item_nonempty = &table[start, end, 1]
                 current_item_empty.score = -INFINITY
                 current_item_nonempty.score = -INFINITY
                 # Best non-zero internal label for this span.
                 best_label_score_nonempty = -INFINITY
                 for internal_idx in range(internal_count):
                     label_idx = internal_labels[internal_idx]
                     score = label_scores[start, end, label_idx]
                     if score > best_label_score_nonempty and label_idx != 0:
                         best_label_score_nonempty = score
                         best_label_idx_nonempty = label_idx
                 for k in range(start+1, end):
                     # Left child: whichever slot of [start, k] scores higher.
                     left_item = &table[start,k,0]
                     if table[start,k,1].score > left_item.score:
                         left_item = &table[start,k,1]
                     right_item = &table[k,end,1]
                     # Empty item: no span or label score.
                     score = left_item.score + right_item.score
                     if score > current_item_empty.score:
                         item_assign(current_item_empty, start, end, best_label_idx,
                                     score, left_item, right_item)

                     score = left_item.score + right_item.score + span_scores[start, end] + \
                             best_label_score_nonempty
                     if score > current_item_nonempty.score:
                         item_assign(
                             current_item_nonempty, start, end,
                             best_label_idx_nonempty, score, left_item, right_item)

        return &table[0,sent_length,1]


# Head-annotated chart cell / tree node used by the lexicalized decoder.
cdef struct Itemh:
    index_t start  # first word of the span
    index_t end    # one past the last word
    index_t head   # word index of the span's head; the decoder ranges it over [start, end)
    label_t label  # label id; 0 marks an "empty" (unlabelled) item
    score_t score
    long left  # cython doesn't allow pointers in memoryviews, so the left child's address is stored as long (0 = none)
    long right  # address of the right child Itemh, 0 if none
    # NOTE(review): `long` is 32 bits on LLP64 platforms, which would truncate these addresses.
    bint initialized  # set by itemh_assign

cdef inline void itemh_assign(
        Itemh* item, index_t start, index_t end, index_t head,
        label_t label, score_t score, Itemh* left, Itemh* right) nogil:
    # Overwrite every field of `item` and mark it initialized. Child pointers
    # are stored as integer addresses; NULL is stored as 0.
    item.initialized = True
    item.right = <long> right
    item.left = <long> left
    item.score = score
    item.label = label
    item.head = head
    item.end = end
    item.start = start


cdef class LexicalizedDecoder(object):
    """Head-lexicalized CKY decoder over a reusable chart.

    The chart buffer holds ``(max_sentence_size + 1) ** 3 * 2`` ``Itemh``
    cells. Longer sentences are rejected with ValueError instead of overrunning
    it.
    """
    cdef void* pool
    cdef Py_ssize_t max_sentence_size  # capacity in words

    def __cinit__(self, max_sentence_size, label_count):
        cdef size_t pool_bytes = ((max_sentence_size + 1) * (max_sentence_size + 1)
                                  * (max_sentence_size + 1) * 2 * sizeof(Itemh))
        self.max_sentence_size = max_sentence_size
        self.pool = malloc(pool_bytes)
        if self.pool == NULL and pool_bytes != 0:
            raise MemoryError(
                "could not allocate {} bytes for the lexicalized chart".format(pool_bytes))

    def __call__(self, rules, span_scores, label_scores,
                 leaftag_scores, leaftag_to_label, internal_labels,
                 dep_scores):
        """Decode one sentence and return the best tree as an ``ItemhWrapped``.

        ``rules`` is unused; it is kept for a call signature shared with the
        other decoders.

        Raises:
            ValueError: the sentence is longer than the pool capacity.
        """
        cdef int sent_length = span_scores.shape[0]
        cdef Itemh[:, :, :, :] table
        cdef Itemh* final_item

        if sent_length > self.max_sentence_size:
            raise ValueError(
                "sentence of {} words exceeds decoder capacity of {} words".format(
                    sent_length, self.max_sentence_size))

        table = <Itemh[:(sent_length+1), :(sent_length+1), :(sent_length+1), :2]> self.pool
        memset(self.pool, 0, (sent_length + 1) * (sent_length + 1) * (sent_length + 1) * 2 * sizeof(Itemh))
        final_item = lexicalized_decoder(span_scores, label_scores, leaftag_scores,
                                         leaftag_to_label, internal_labels, dep_scores, table)
        return ItemhWrapped.create(final_item)

    def __dealloc__(self):
        free(self.pool)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef Itemh* lexicalized_decoder(score_t[:, :] span_scores,
                                score_t[:, :, :] label_scores,
                                score_t[:, :] leaftag_scores,
                                np.int32_t[:] leaftag_to_label,
                                np.int32_t[:] internal_labels,
                                score_t[:, :] dep_scores,
                                Itemh[:, :, :, :] table
                                ) nogil:
    # Head-lexicalized CKY. Each binary combination also scores a dependency
    # arc between the heads of the two children. Fills `table` in place.
    #
    # table is indexed [start, end, head, slot]:
    #   slot 0 - "empty" item: label 0; score = children + dependency arc.
    #   slot 1 - labelled item: label = best non-zero internal label for the
    #            span; score additionally includes span_scores[start, end] and
    #            that same label's score.
    # dep_scores is indexed [governor + 1, dependent + 1]; index 0 is the
    # artificial root, used for the final root attachment.
    # leaftag_scores is optional; when it is None, leaves use
    # label_scores[i, i+1, label].
    #
    # Returns the best slot-1 whole-sentence item after adding its root
    # attachment score. That score is added into the table cells in place.
    #
    # Fix: slot 1 is tagged with best_label_idx_nonempty, so it is now scored
    # with best_label_score_nonempty. The previous code used the best score over
    # all labels, including the empty label 0, which overstated the labelled
    # item whenever label 0 won. This matches cky_decoder_2/cky_decoder_3.
    cdef int sent_length = span_scores.shape[0]
    cdef int leaftag_count = leaftag_to_label.shape[0]
    cdef int internal_count = internal_labels.shape[0]
    cdef bint use_leaftag = leaftag_scores is not None

    cdef int i, label_idx, leaftag_idx, internal_idx, span_length, left_status, right_status
    cdef index_t start, end, head, head_2, k
    cdef int best_label_idx_nonempty = 0
    cdef score_t l2r_score_empty, l2r_score_nonempty, r2l_score_empty, r2l_score_nonempty
    cdef score_t score, best_label_score_nonempty

    cdef Itemh* item_empty
    cdef Itemh* item_nonempty
    cdef Itemh* left_item
    cdef Itemh* right_item
    cdef Itemh* current_item_empty
    cdef Itemh* current_item_nonempty

    # Leaves: a one-word span is headed by its word. The empty slot stays at
    # -inf, so a leaf is only usable through its labelled slot.
    for i in range(sent_length):
        item_empty = &table[i, i+1, i, 0]
        item_nonempty = &table[i, i+1, i, 1]
        item_empty.score = -INFINITY
        item_nonempty.score = -INFINITY
        for leaftag_idx in range(leaftag_count):
            label_idx = leaftag_to_label[leaftag_idx]
            if use_leaftag:
                score = leaftag_scores[i, leaftag_idx]
            else:
                score = label_scores[i, i+1, label_idx]
            if score > item_nonempty.score:
                itemh_assign(item_nonempty, i, i+1, i, label_idx, score, NULL, NULL)

    for span_length in range(2, sent_length + 1):
        for start in range(sent_length - span_length + 1):
            end = start + span_length

            # Best non-zero internal label for this span.
            best_label_score_nonempty = -INFINITY
            for internal_idx in range(internal_count):
                label_idx = internal_labels[internal_idx]
                score = label_scores[start, end, label_idx]
                if score > best_label_score_nonempty and label_idx != 0:
                    best_label_score_nonempty = score
                    best_label_idx_nonempty = label_idx

            # Reset both slots for every head candidate of this span.
            for head in range(start, end):
                itemh_assign(
                    &table[start, end, head, 0], start, end, head,
                    0, -INFINITY, NULL, NULL)
                itemh_assign(
                    &table[start, end, head, 1], start, end, head,
                    best_label_idx_nonempty, -INFINITY, NULL, NULL)

            for k in range(start+1, end):
                for head in range(start, k):
                    for head_2 in range(k, end):
                        for left_status in range(2):
                            for right_status in range(2):
                                # At least one child must be a labelled item.
                                if left_status == 0 and right_status == 0:
                                    continue
                                left_item = &table[start, k, head, left_status]
                                right_item = &table[k, end, head_2, right_status]

                                # Left head governs the right head (arc head -> head_2).
                                current_item_empty = &table[start, end, head, 0]
                                current_item_nonempty = &table[start, end, head, 1]

                                l2r_score_empty = left_item.score + right_item.score + dep_scores[head+1, head_2+1]
                                if l2r_score_empty > current_item_empty.score:
                                    current_item_empty.score = l2r_score_empty
                                    current_item_empty.left = <long> left_item
                                    current_item_empty.right = <long> right_item

                                l2r_score_nonempty = l2r_score_empty + span_scores[start, end] + best_label_score_nonempty
                                if l2r_score_nonempty > current_item_nonempty.score:
                                    current_item_nonempty.score = l2r_score_nonempty
                                    current_item_nonempty.left = <long> left_item
                                    current_item_nonempty.right = <long> right_item

                                # Right head governs the left head (arc head_2 -> head).
                                current_item_empty = &table[start, end, head_2, 0]
                                current_item_nonempty = &table[start, end, head_2, 1]

                                r2l_score_empty = left_item.score + right_item.score + dep_scores[head_2+1, head+1]
                                r2l_score_nonempty = r2l_score_empty + span_scores[start, end] + best_label_score_nonempty
                                if r2l_score_empty > current_item_empty.score:
                                    current_item_empty.score = r2l_score_empty
                                    current_item_empty.left = <long> left_item
                                    current_item_empty.right = <long> right_item
                                if r2l_score_nonempty > current_item_nonempty.score:
                                    current_item_nonempty.score = r2l_score_nonempty
                                    current_item_nonempty.left = <long> left_item
                                    current_item_nonempty.right = <long> right_item

    # Attach each whole-sentence head to the artificial root and keep the best.
    current_item_nonempty = &table[0, sent_length, 0, 1]
    for head in range(0, sent_length):
        item_nonempty = &table[0, sent_length, head, 1]
        item_nonempty.score += dep_scores[0, head+1]
        if item_nonempty.score > current_item_nonempty.score:
            current_item_nonempty = item_nonempty
    return current_item_nonempty


cdef class ItemhWrapped(object):
    """Python-side copy of a head-annotated ``Itemh`` tree.

    ``start``/``end``/``head``/``label``/``score`` mirror the C struct;
    ``left`` and ``right`` are child ``ItemhWrapped`` objects, or None for a leaf.
    """
    cdef public:
        object start, end, head, label, score, left, right

    @staticmethod
    cdef create(Itemh* item):
        # Recursively copy `item` and its children into Python objects; a zero
        # child address means "no child".
        cdef ItemhWrapped self = ItemhWrapped()
        self.start = item.start
        self.end = item.end
        self.head = item.head
        self.label = item.label
        self.score = item.score
        self.left = ItemhWrapped.create(<Itemh *> item.left) if item.left != 0 else None
        self.right = ItemhWrapped.create(<Itemh *> item.right) if item.right != 0 else None
        return self

    def generate_scoreable_spans(self, label_map=None):
        """Yield ``(start, end, label)`` for every node whose label is non-zero.

        Nodes are visited in pre-order (node, left subtree, right subtree).
        When ``label_map`` is given the label is translated through
        ``label_map[label]``; when it is None the raw label id is yielded.
        """
        if self.label != 0:
            if label_map is None:
                yield (self.start, self.end, self.label)
            else:
                yield (self.start, self.end, label_map[self.label])
        if self.left is not None:
            for span in self.left.generate_scoreable_spans(label_map):
                yield span

        if self.right is not None:
            for span in self.right.generate_scoreable_spans(label_map):
                yield span

    def flat_children(self):
        """Return the labelled children, splicing out label-0 (empty) children.

        Raises:
            Exception: both children are empty, which the decoder never builds.
        """
        if self.left is None and self.right is None:
            return []
        if self.left.label != 0 and self.right.label != 0:
            return [self.left, self.right]
        elif self.left.label == 0 and self.right.label != 0:
            return self.left.flat_children() + [self.right]
        elif self.left.label != 0 and self.right.label == 0:
            return [self.left] + self.right.flat_children()
        else:
            # NOTE(review): message text looks garbled (likely meant "both empty");
            # kept unchanged in case callers match on it.
            raise Exception("score_t empty!!")

    def to_const_tree(self, label_map, words):
        """Convert this subtree to a ``ConstTree`` labelled via ``label_map``.

        One-word spans take ``words[start]`` (which must be a ``Lexicon``) as
        their only child; larger spans recurse into ``flat_children()``.
        """
        ret = ConstTree(label_map[self.label], (self.start, self.end))
        if self.end - self.start == 1:
            word = words[self.start]
            assert isinstance(word, Lexicon)
            ret.children = [word]
        else:
            ret.children = [i.to_const_tree(label_map, words) for i in self.flat_children()]
        ret.score = self.score
        return ret

    def to_dependency(self, container=None):
        """Return the head array of the dependency tree encoded by this subtree.

        ``container[d + 1] = h + 1`` records that word ``d`` depends on word
        ``h``. Index 0 is reserved (set to -1), and the sentence head keeps the
        initial 0, i.e. it attaches to the root. Must be called on the
        whole-sentence item (``start == 0``) when ``container`` is omitted.
        """
        if container is None:
            assert self.start == 0
            container = np.zeros((self.end + 1,), dtype=np.int32)
            container[0] = -1
        # Binary tree invariant: either both children exist or neither does.
        assert not (self.left is not None) ^ (self.right is not None)
        if self.left is not None:
            # The child whose head is not inherited becomes a dependent.
            if self.head == self.left.head:
                container[self.right.head+1] = self.head+1
            elif self.head == self.right.head:
                container[self.left.head+1] = self.head+1
            else:
                raise Exception("Wrong head transfer.")
            self.left.to_dependency(container)
            self.right.to_dependency(container)
        return container

    def __str__(self):
        return "[{}, {}, {}]".format(self.start, self.end, self.label)


@cython.boundscheck(False)
@cython.wraparound(False)
def cky_binary_rule_decoder_unlabeled(score_t[:, :] span_scores):
    """Find the highest-scoring unlabeled binary bracketing of a sentence.

    Every internal span [start, end) of the tree contributes
    ``span_scores[start, end]``. Leaves score 0: every binary tree contains
    each one-word span exactly once, so a leaf score would be a constant that
    cannot change the argmax. All returned items have label 0.

    Returns:
        ItemWrapped for the whole-sentence span.

    Raises:
        ValueError: ``span_scores`` describes an empty sentence.
        MemoryError: the chart could not be allocated.
    """
    cdef int sent_length = span_scores.shape[0]
    cdef score_t score
    cdef int i, k, l, start, end
    cdef Item *left_item
    cdef Item *right_item
    cdef Item *current_item
    cdef void* pool
    cdef Item[:, :] table

    if sent_length <= 0:
        raise ValueError("cannot decode an empty sentence")

    # calloc zeroes the chart, so every `initialized` flag starts out False.
    pool = calloc(sent_length * (sent_length + 1), sizeof(Item))
    if pool == NULL:
        raise MemoryError()
    try:
        table = <Item[:sent_length, :(sent_length+1)]> pool
        with nogil:
            # Leaves must be initialized, otherwise every split below is
            # skipped and no tree is ever built.
            for i in range(sent_length):
                item_assign(&table[i, i+1], i, i+1, 0, 0, NULL, NULL)

            for l in range(2, sent_length + 1):
                for start in range(sent_length - l + 1):
                    end = start + l
                    current_item = &table[start, end]
                    current_item.score = -INFINITY
                    for k in range(start+1, end):
                        left_item = &table[start, k]
                        right_item = &table[k, end]
                        if not left_item.initialized or not right_item.initialized:
                            continue
                        score = left_item.score + right_item.score + span_scores[start, end]
                        if score > current_item.score:
                            item_assign(current_item, start, end, 0, score, left_item, right_item)
        # ItemWrapped copies everything it needs, so the pool can be freed afterwards.
        return ItemWrapped.create(&table[0, sent_length])
    finally:
        free(pool)


cdef class CKYBinaryRuleDecoderUnlabeled(object):
    """CKY decoder wrapper that reuses one preallocated chart buffer.

    The chart is allocated once in ``__cinit__``. It is sized for sentences of
    at most ``max_sentence_size`` words and ``label_count`` labels, and is
    cleared and reused on every call.
    """
    cdef void* pool
    cdef int max_sentence_size
    cdef int label_count

    def __cinit__(self, max_sentence_size, label_count):
        cdef size_t n_bytes
        self.max_sentence_size = max_sentence_size
        self.label_count = label_count
        n_bytes = (<size_t> self.max_sentence_size * (self.max_sentence_size + 1)
                   * self.label_count * sizeof(Item))
        self.pool = malloc(n_bytes)
        if self.pool == NULL and n_bytes != 0:
            raise MemoryError()

    def __call__(self, rules, span_scores, label_scores,
                 leaftag_scores, leaftag_to_label, internal_labels):
        """Decode the best tree with ``cky_binary_rule_decoder``.

        NOTE(review): ``internal_labels`` is accepted but not used here.

        Returns:
            ItemWrapped for the root item returned by the decoder.

        Raises:
            ValueError: if the input exceeds the capacity set in ``__cinit__``.
            ArithmeticError: if the decoder's root item is missing or was
                never initialized.
        """
        cdef int sent_length = span_scores.shape[0]
        cdef int label_count = label_scores.shape[2]
        cdef Item[:, :, :] table
        cdef Item* current_item

        # The pool was sized in __cinit__. A larger request would make the
        # memset below, and the decoder, write past the end of the buffer.
        if sent_length > self.max_sentence_size or label_count > self.label_count:
            raise ValueError(
                "Input of {} words / {} labels exceeds decoder capacity of "
                "{} words / {} labels.".format(
                    sent_length, label_count,
                    self.max_sentence_size, self.label_count))

        table = <Item[:sent_length, :(sent_length + 1), :label_count]> self.pool
        memset(self.pool, 0,
               <size_t> sent_length * (sent_length + 1) * label_count * sizeof(Item))

        current_item = cky_binary_rule_decoder(
            rules, span_scores, label_scores,
            leaftag_scores, leaftag_to_label, table)

        if current_item == NULL or not current_item.initialized:
            raise ArithmeticError("Can't decode.")
        # ItemWrapped.create copies the items, so the pool can be reused.
        return ItemWrapped.create(current_item)

    def __dealloc__(self):
        free(self.pool)


# Record holding a pair of labels (left/right) and a leaf flag.
# NOTE(review): not referenced in this part of the file; confirm its users.
cdef struct TreeInfo:
    label_t left
    label_t right
    bint is_leaf


cdef inline void tree_info_assign(
        TreeInfo* item,
        label_t left, label_t right, bint is_leaf) nogil:
    # Overwrite every field of the TreeInfo pointed to by ``item``.
    item.is_leaf = is_leaf
    item.right = right
    item.left = left


# One cell of the LabelAssigner table: the best way found so far to give one
# tree node the label ``label``.
cdef struct TreeLabelItem:
    label_t label
    long left  # TreeLabelItem* actually (0 / NULL for leaves)
    long right  # TreeLabelItem* actually (0 / NULL for leaves)
    score_t score  # best accumulated score for this (node, label) cell
    bint initialized  # False until some assignment has reached this cell


cdef inline void tree_label_item_assign(TreeLabelItem* item,
                                        label_t label,
                                        TreeLabelItem* left,
                                        TreeLabelItem* right,
                                        score_t score
                                        ) nogil:
    # Overwrite every field of ``item`` and mark it initialized. The child
    # pointers are stored as ``long`` integers, matching the struct definition.
    item.score = score
    item.right = <long> right
    item.left = <long> left
    item.label = label
    item.initialized = True


# Result of LabelAssigner.fill_table for one subtree.
cdef struct LabelAssignerRetVal:
    TreeLabelItem[:] items  # table row (one cell per label) for the subtree root
    TreeLabelItem* best_item  # best-scoring cell in ``items``; NULL for a leaf
    int new_pointer  # first table row not used by this subtree


cdef inline LabelAssignerRetVal pack_label_assigner_retval(
        TreeLabelItem[:] items, TreeLabelItem* best_item,
        int new_pointer
        ) nogil:
    # Bundle the three outputs of LabelAssigner.fill_table into one struct.
    cdef LabelAssignerRetVal packed
    packed.new_pointer = new_pointer
    packed.best_item = best_item
    packed.items = items
    return packed


cdef class LabelAssigner(object):
    """Choose the best-scoring labels for the nodes of a given binary tree.

    The bracketing of ``gold_tree`` is kept. Only node labels are chosen, by
    a bottom-up dynamic program over the tree:
    - Children are combined through the binary ``rules`` (rows read as
      ``(lhs, rhs1, rhs2)``).
    - When ``root_rules`` is given, the root only uses the rules it lists.
    - Winning labels are written back into ``gold_tree.label`` in place.

    A scratch table is allocated once in ``__cinit__`` and reused across
    calls, so inputs must not exceed the sizes given there.
    """
    cdef void* pool
    cdef int max_sentence_size
    cdef int label_count

    def __cinit__(self, max_sentence_size, label_count):
        cdef size_t n_bytes
        self.max_sentence_size = max_sentence_size
        self.label_count = label_count
        # n_label * n_node = n_label * 2 * n_word
        n_bytes = (<size_t> (self.max_sentence_size + 1) * 2 * self.label_count
                   * sizeof(TreeLabelItem))
        self.pool = malloc(n_bytes)
        if self.pool == NULL and n_bytes != 0:
            raise MemoryError()

    def __dealloc__(self):
        free(self.pool)


    @cython.boundscheck(False)
    @cython.wraparound(False)
    cdef LabelAssignerRetVal fill_table(self, gold_tree,
                 int[:, :] rules,
                 score_t[:, :, :] label_scores,
                 score_t[:, :] leaftag_scores,
                 np.int32_t[:] leaftag_to_label,
                 np.int32_t[:] root_rules,
                 TreeLabelItem[:, :] table,
                 int table_pointer
                 ):
        # Fill row ``table_pointer`` (one cell per label) for the root of
        # ``gold_tree``, after recursively filling its subtrees. The subtrees
        # take rows starting at table_pointer + 1, in depth-first order.
        # Returns this node's row, its best cell (NULL for a leaf) and the
        # next free row.
        cdef int leaftag_count = leaftag_to_label.shape[0]
        cdef int rules_count = rules.shape[0]
        cdef int root_rules_count = root_rules.shape[0] if root_rules is not None else 0
        cdef index_t start = gold_tree.start
        cdef index_t end = gold_tree.end
        cdef TreeLabelItem *label_item, *left_item, *right_item, *best_item
        cdef bint use_leaftag = leaftag_scores is not None
        cdef int leaftag_idx, rule_idx, root_rule_idx, new_pointer
        cdef label_t lhs, rhs1, rhs2, label_idx
        cdef LabelAssignerRetVal left, right
        cdef score_t score, new_score

        # Only the outermost call (row 0) is the tree root.
        cdef bint restrict_root = root_rules is not None and (table_pointer == 0)

        if end == start + 1:
            # leaf node: one candidate per leaf tag, keyed by the label that
            # the tag maps to (label 0 is skipped)
            with nogil:
                for leaftag_idx in range(leaftag_count):
                    label_idx = leaftag_to_label[leaftag_idx]
                    if label_idx == 0:
                        continue
                    if use_leaftag:
                        score = leaftag_scores[start, leaftag_idx]
                    else:
                        score = label_scores[start, end, label_idx]
                    label_item = &table[table_pointer, label_idx]
                    # Several leaf tags may map to the same label: keep the
                    # best-scoring one rather than whichever happens to be last.
                    if not label_item.initialized or score > label_item.score:
                        tree_label_item_assign(label_item, label_idx, NULL, NULL, score)
            return pack_label_assigner_retval(table[table_pointer],
                                              NULL,
                                              table_pointer + 1)

        # internal node
        left = self.fill_table(
            gold_tree.left, rules, label_scores,
            leaftag_scores, leaftag_to_label, root_rules,
            table, table_pointer + 1)
        right = self.fill_table(
            gold_tree.right, rules, label_scores,
            leaftag_scores, leaftag_to_label, root_rules,
            table, left.new_pointer)
        new_pointer = right.new_pointer
        best_item = NULL
        with nogil:
            rule_idx = -1
            root_rule_idx = -1
            while True:
                if not restrict_root:
                    # iterate over all rules
                    rule_idx += 1
                    if rule_idx >= rules_count:
                        break
                else:
                    # iterate over root rules
                    root_rule_idx += 1
                    if root_rule_idx >= root_rules_count:
                        break
                    rule_idx = root_rules[root_rule_idx]
                lhs = rules[rule_idx, 0]
                rhs1 = rules[rule_idx, 1]
                rhs2 = rules[rule_idx, 2]
                label_item = &table[table_pointer, lhs]
                left_item = &left.items[rhs1]
                right_item = &right.items[rhs2]
                if not left_item.initialized or not right_item.initialized:
                    continue
                new_score = left_item.score + right_item.score + label_scores[start, end, lhs]
                if not label_item.initialized or new_score > label_item.score:
                    tree_label_item_assign(label_item, lhs, left_item, right_item, new_score)
                # label_item.score >= new_score after the update above, so the
                # best pointer always refers to a cell holding its best score.
                if best_item == NULL or new_score > best_item.score:
                    best_item = label_item
        return pack_label_assigner_retval(table[table_pointer], best_item, new_pointer)

    cdef rewrite_labels(self, gold_tree,
                        TreeLabelItem* label_item):
        # Copy the chosen labels from the back-pointer chain rooted at
        # ``label_item`` onto ``gold_tree`` and its descendants.
        # The caller must pass a non-NULL item.
        cdef TreeLabelItem* left = <TreeLabelItem*> label_item.left
        cdef TreeLabelItem* right = <TreeLabelItem*> label_item.right
        gold_tree.label = label_item.label
        if left != NULL:
            self.rewrite_labels(gold_tree.left, left)
        if right != NULL:
            self.rewrite_labels(gold_tree.right, right)

    def __call__(self, gold_tree,
                 int[:, :] rules,
                 score_t[:, :, :] label_scores,
                 score_t[:, :] leaftag_scores,
                 np.int32_t[:] leaftag_to_label,
                 np.int32_t[:] root_rules,
                ):
        """Relabel ``gold_tree`` in place with the best-scoring labels.

        Single-word sentences are left untouched.

        Raises:
            ValueError: if the input exceeds the capacity set in ``__cinit__``.
            ArithmeticError: if no rule can combine the children at the root.
        """
        cdef int sent_length = label_scores.shape[0]
        cdef int label_count = label_scores.shape[2]
        cdef TreeLabelItem[:, :] table
        cdef LabelAssignerRetVal ret

        # The pool was sized in __cinit__. A larger request would make the
        # memset and fill_table write past the end of the buffer.
        if sent_length > self.max_sentence_size or label_count > self.label_count:
            raise ValueError(
                "Input of {} words / {} labels exceeds label assigner capacity "
                "of {} words / {} labels.".format(
                    sent_length, label_count,
                    self.max_sentence_size, self.label_count))

        if sent_length == 1:
            # A lone leaf has no best_item (it would be NULL); nothing to do.
            return

        table = <TreeLabelItem[:(sent_length * 2 + 2), :label_count]> self.pool
        memset(self.pool, 0,
               <size_t> (sent_length + 1) * 2 * label_count * sizeof(TreeLabelItem))
        ret = self.fill_table(
            gold_tree, rules, label_scores, leaftag_scores, leaftag_to_label, root_rules, table, 0)
        if ret.best_item == NULL:
            # Previously this NULL was dereferenced in rewrite_labels.
            raise ArithmeticError("Can't decode.")
        self.rewrite_labels(gold_tree, ret.best_item)


# Registry: decoder-type name -> decoder class.
decoder_types = dict(
    binary=CKYBinaryRuleDecoder,
    rulefree=CKYRuleFreeDecoder,
    rulefree3=CKYRuleFreeDecoder3,
)
