import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from functools import partial
from paddle.nn import Linear, Conv2D
from paddle.vision import transforms

def get_clip_token_for_string(tokenizer, string):
    """Return the single token id that ``string`` encodes to.

    The string is encoded with truncation and ``"max_length"`` padding to 77
    positions, and the id at position 1 of the first row is returned.

    Args:
        tokenizer: Callable tokenizer accepting the keyword arguments used
            below and returning a mapping with an ``"input_ids"`` entry.
        string: Text expected to encode to exactly one token.

    Returns:
        The element ``tokens[0, 1]`` of the encoded ids.

    Raises:
        AssertionError: If the number of ids different from 49407 is not
            exactly two.
    """
    batch_encoding = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        # The ids are fed to paddle ops below, so request Paddle tensors
        # ("pt" is the PyTorch value and is not a Paddle tensor type).
        return_tensors="pd",
    )
    tokens = batch_encoding["input_ids"]
    # Subtracting 49407 zeroes every position holding that id, so the count
    # of non-zeros is the number of other ids. Exactly two are required.
    # NOTE(review): 49407 is presumably the end-of-text/padding id and the two
    # remaining ids the start token plus the single word token -- confirm.
    assert (
        paddle.count_nonzero(tokens - 49407) == 2
    ), f"String '{string}' maps to more than a single token. Please use another string"
    return tokens[0, 1]


def get_clip_vision_emb(encoder, processor, img):
    """Compute image embeddings for ``img`` with a processor/encoder pair.

    The image tensor is tiled three times along axis 1 and multiplied by 255
    before being handed to ``processor``; the processor output is then
    unpacked as keyword arguments into ``encoder``.

    Args:
        encoder: Callable model whose output exposes ``image_embeds``.
        processor: Callable accepting ``images=`` and ``return_tensors=``.
        img: 4-D Paddle tensor.
            NOTE(review): presumably single-channel with values in [0, 1],
            so the result is a 3-channel image in [0, 255] -- confirm.

    Returns:
        ``encoder(...).image_embeds``.
    """
    # paddle.tile is the Paddle counterpart of the torch-style
    # ``Tensor.repeat(1, 3, 1, 1)``: repeat axis 1 three times, others once.
    _img = paddle.tile(img, repeat_times=[1, 3, 1, 1]) * 255
    # Request Paddle tensors so the output can be fed to a Paddle encoder.
    inputs = processor(images=_img, return_tensors="pd")
    outputs = encoder(**inputs)
    return outputs.image_embeds


def get_recog_emb(ocr_predictor, img_list):
    """Run the OCR predictor on the images and return its second output.

    Every image is multiplied by 255 before being passed to
    ``ocr_predictor.pred_imglist`` (called with ``show_debug=False``). The
    predictor must return a pair; the first element is discarded.

    Args:
        ocr_predictor: Object exposing ``pred_imglist(images, show_debug=...)``.
        img_list: Iterable of images supporting multiplication by an int.

    Returns:
        The second element of the pair returned by ``pred_imglist``.
    """
    scaled_images = []
    for image in img_list:
        scaled_images.append(image * 255)
    _unused, neck_features = ocr_predictor.pred_imglist(
        scaled_images, show_debug=False
    )
    return neck_features


def pad_H(x):
    """Pad the height of a 4-D tensor by ``W - H`` rows.

    The ``W - H`` extra rows are split between the two sides of the height
    axis: the first side gets ``(W - H) // 2`` and the second side gets the
    remainder. The width axis is not padded. Padding uses ``F.pad`` with its
    default mode.

    Args:
        x: Tensor whose ``shape`` unpacks into exactly four values, the last
            two being height and width.

    Returns:
        The padded tensor returned by ``F.pad``.
    """
    _batch, _channels, height, width = x.shape
    missing_rows = width - height
    top = missing_rows // 2
    bottom = missing_rows - top
    # Pad spec lists the last axis first: (left, right, top, bottom).
    return F.pad(x, (0, 0, top, bottom))


class EncodeNet(nn.Layer):
    """Small convolutional encoder producing one flat vector per sample.

    A 3x3 stem convolution lifts the input to 16 channels, four stride-2 3x3
    convolutions each double the channel count (16 -> 256), and a final 3x3
    convolution projects to ``out_channels``. SiLU follows every convolution.
    The feature map is then adaptively average-pooled to 1x1 and flattened.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        base_width = 16
        num_downsamples = 4  # number of stride-2 stages

        # Channel widths at the boundaries of the stride-2 stages.
        widths = [base_width * 2**k for k in range(num_downsamples + 1)]

        self.conv1 = Conv2D(in_channels, base_width, 3, padding=1)
        self.conv_list = nn.LayerList(
            [
                Conv2D(c_in, c_out, 3, padding=1, stride=2)
                for c_in, c_out in zip(widths, widths[1:])
            ]
        )
        self.conv2 = Conv2D(widths[-1], out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2D(1)
        self.act = nn.Silu()

    def forward(self, x):
        """Encode ``x`` and return a tensor reshaped to ``[x.shape[0], -1]``."""
        feat = self.act(self.conv1(x))
        for downsample in self.conv_list:
            feat = self.act(downsample(feat))
        feat = self.act(self.conv2(feat))
        pooled = self.avgpool(feat)
        batch = pooled.shape[0]
        return pooled.reshape([batch, -1])


class VisionEmbedding(nn.Layer):
    """Token embedding layer that splices glyph embeddings into text embeddings.

    ``forward`` embeds ``input_ids`` with the wrapped ``token_embedding`` and
    then overwrites the spans listed in ``self.tags_dic`` with the output of
    ``encode_text``.

    NOTE(review): ``encode_text`` calls ``self.process_visual_text``, which is
    not defined in this class; presumably it is attached elsewhere -- confirm.
    """

    def __init__(
        self,
        token_embedding: nn.Embedding,
        glyph_channels=20,
        position_channels=1,
        add_pos=False,
        emb_type="ocr",
        lang="ch",
        **kwargs,
    ):
        """Build the glyph (and optional position) encoders.

        Args:
            token_embedding: Embedding layer used for the regular tokens; its
                weight's last dimension defines the output width.
            glyph_channels: Input width of the glyph encoder. For ``"ocr"``
                the projection takes ``glyph_channels // self.mul`` features.
            position_channels: Input channels of the position encoder.
            add_pos: When true, a position encoder is created and its output
                is added to the glyph embedding in ``encode_text``.
            emb_type: ``"ocr"`` (linear projection of OCR features) or
                ``"conv"`` (convolutional glyph encoder).
            lang: ``"ch"`` sets ``self.mul`` to 16; any other value leaves it
                at 1.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.emb_type = emb_type
        self.token_embedding = token_embedding
        token_dim = token_embedding.weight.shape[-1]
        # Read by forward(), which iterates this attribute and indexes every
        # element as ``{text: {"span": (start, end)}}``.
        # NOTE(review): presumably callers assign a list of such mappings
        # before calling forward -- confirm.
        self.tags_dic = {}
        self.add_pos = add_pos
        self.mul = 16 if lang == "ch" else 1
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == "ocr":
            self.ocr_predictor = None  # will be initialized in model.init
            self.proj = Linear(glyph_channels // self.mul, token_dim)
        elif emb_type == "conv":
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

    def encode_text(self, visual_text):
        """Encode one visual text into embedding rows.

        The text is rendered through ``self.process_visual_text`` and converted
        with ``transforms.ToTensor``. Depending on ``self.emb_type`` the image
        is either passed through the OCR predictor and ``self.proj`` or padded
        with ``pad_H`` and passed through ``self.glyph_encoder``.

        Args:
            visual_text: Value forwarded as ``text=`` to
                ``self.process_visual_text``.

        Returns:
            The glyph embedding reshaped to ``[rows, -1]``.

        Raises:
            ValueError: If ``self.emb_type`` is neither ``"ocr"`` nor
                ``"conv"``.
        """
        img_list = [self.process_visual_text(text=visual_text)]
        process = transforms.ToTensor()
        img_list = [process(img) for img in img_list]
        if self.emb_type == "ocr":
            recog_emb = get_recog_emb(self.ocr_predictor, img_list)
            # Split each OCR feature row into ``self.mul`` rows before projecting.
            flat_emb = recog_emb.reshape([recog_emb.shape[0] * self.mul, -1])
            enc_glyph = self.proj(flat_emb.astype(self.proj.dtype))
        elif self.emb_type == "conv":
            glyph = paddle.concat(img_list, axis=0).unsqueeze(0)
            enc_glyph = self.glyph_encoder(
                pad_H(glyph.astype(self.glyph_encoder.dtype))
            )
        else:
            # Without this branch enc_glyph would be unbound below and the
            # failure would surface as an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported emb_type {self.emb_type!r}; expected 'ocr' or 'conv'"
            )
        if self.add_pos:
            enc_pos = self.position_encoder(paddle.concat(img_list, axis=0).unsqueeze(0))
            enc_glyph = enc_glyph + enc_pos
        return enc_glyph.reshape([enc_glyph.shape[0], -1])

    def forward(
        self,
        input_ids,
    ):
        """Replace visual-token embeddings with vision embeddings.

        Args:
            input_ids: Token ids; the first dimension is the batch size.

        Returns:
            The token embeddings with every span whose start and end are both
            different from -1 replaced by ``encode_text(text)``; each rebuilt
            row is truncated to 77 positions before being written back.
        """

        text_emb = self.token_embedding(input_ids)
        bsz = input_ids.shape[0]
        # Sort the tags of every entry by span.
        visual_list = [
            sorted(
                [[tags_dic[k]["span"][0], tags_dic[k]["span"][1], k] for k in tags_dic]
            )
            for tags_dic in self.tags_dic
        ]  # [start, end, text]
        # Replacing the tokens between start and end by a different number of
        # rows would shift every later token; ``bias`` is meant to compensate
        # for that shift. It is never updated here, so it stays 0.
        bias = 0
        for i in range(bsz):
            # NOTE(review): every entry of visual_list is applied to every
            # batch row i; if self.tags_dic holds one mapping per sample this
            # should probably be visual_list[i] -- confirm with the caller.
            for tags in visual_list:
                if not tags:
                    continue
                for start, end, text in tags:
                    if start != -1 and end != -1:
                        text_emb[i] = paddle.concat(
                            [
                                text_emb[i, 0 : start - bias],
                                self.encode_text(text),
                                text_emb[i, end - bias :],
                            ]
                        )[:77]
        return text_emb

    def embedding_parameters(self):
        """Return all parameters of this layer (``self.parameters()``)."""
        return self.parameters()
