from typing import List, Iterable, Any, Generator, Tuple
from queue import Queue
from argparse import Namespace

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import IPython
import omegaconf
from transformers import AutoModel, AutoTokenizer, MarianTokenizer, MarianMTModel
import pandas as pd
import numpy as np

import logging

logger = logging.getLogger(__name__)


class BackTranslator:
    """Paraphrase English text by round-tripping it through a pivot language.

    Sentences are translated English -> ``target_lang`` with one Helsinki-NLP
    Marian model and then back to English with a second one; the
    back-translation is returned as the paraphrase.
    """

    # Download cache used by ``from_pretrained`` unless the caller overrides it
    # (this was previously hard-coded in ``__init__``).
    DEFAULT_CACHE_DIR = "/nas/home/qasemi/model_cache"

    def __init__(self, target_lang: str = 'fr', cache_dir: str = DEFAULT_CACHE_DIR):
        """Load the forward (en -> pivot) and backward (pivot -> en) models.

        Args:
            target_lang: Pivot language code, used to build the model names
                ``Helsinki-NLP/opus-mt-en-{lang}`` and ``Helsinki-NLP/opus-mt-{lang}-en``.
            cache_dir: Directory passed to ``from_pretrained`` as the download cache.
        """
        forward_name = f'Helsinki-NLP/opus-mt-en-{target_lang}'
        backward_name = f'Helsinki-NLP/opus-mt-{target_lang}-en'

        logger.info(f'loading en-{target_lang} model and tokenizer')
        self.enc_model = MarianMTModel.from_pretrained(forward_name, cache_dir=cache_dir)
        self.enc_tokens = MarianTokenizer.from_pretrained(
            forward_name, cache_dir=cache_dir, use_fast=False,
        )

        logger.info(f'loading {target_lang}-en model and tokenizer')
        self.dec_model = MarianMTModel.from_pretrained(backward_name, cache_dir=cache_dir)
        self.dec_tokens = MarianTokenizer.from_pretrained(
            backward_name, cache_dir=cache_dir, use_fast=False,
        )
        logger.info('BackTranslator fully initialized.')

    @staticmethod
    def _translate(sent: List[str], model, tokenizer) -> List[str]:
        """Translate a batch of sentences with ``model`` / ``tokenizer``.

        Blank (empty or whitespace-only) entries are returned unchanged instead
        of being fed to the model, and an empty or all-blank batch never reaches
        the tokenizer.

        Args:
            sent: Sentences to translate; any sequence of ``str`` (a list or a
                NumPy array of strings both work).
            model: Seq2seq model exposing ``generate(**inputs)``.
            tokenizer: Tokenizer that is callable on a list of strings and
                exposes ``decode``.

        Returns:
            One translated string per input, in input order.
        """
        texts = list(sent)
        result = list(texts)
        # Positions of entries that actually contain text.
        todo = [i for i, text in enumerate(texts) if text.strip()]
        if not todo:
            return result
        # ``tokenizer(...)`` replaces ``prepare_translation_batch`` (deprecated,
        # then removed in transformers 4): pad to the longest sentence in the
        # batch and truncate to the model's maximum length.
        inputs = tokenizer(
            [texts[i] for i in todo],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        generated = model.generate(**inputs)
        for i, token_ids in zip(todo, generated):
            result[i] = tokenizer.decode(token_ids, skip_special_tokens=True)
        return result

    def rephrase(self, sent_batch: List[str]) -> List[str]:
        """Return the en -> pivot -> en round-trip of each sentence in the batch."""
        pivot = self._translate(sent_batch, self.enc_model, self.enc_tokens)
        return self._translate(pivot, self.dec_model, self.dec_tokens)

    def rephrase_corpus(self, corpus: List[str], batch_size: int = 8, use_cuda: bool = True) -> List[str]:
        """Back-translate ``corpus`` in batches of ``batch_size`` sentences.

        Args:
            corpus: Sliceable sequence of ``str`` (list or NumPy array).
            batch_size: Number of sentences sent to the models at once; must be >= 1.
            use_cuda: Currently ignored and kept only so existing callers keep
                working. This class never moves the models to another device.

        Returns:
            One paraphrase per input sentence, in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            TypeError: If a batch contains a non-``str`` item.
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        out: List[str] = []
        for start in tqdm(range(0, len(corpus), batch_size), desc='rephrasing'):
            # Slicing past the end is safe, so no min() clamp is needed.
            batch = corpus[start:start + batch_size]
            if not all(isinstance(s, str) for s in batch):
                raise TypeError(f'{batch}')
            out.extend(self.rephrase(batch))
        return out


class Rephraser:
    """Back-translate the question and answer-choice columns of a DataFrame."""

    def __init__(self, config: 'omegaconf.dictconfig.DictConfig', back_translator: Any = None):
        """
        Args:
            config: Stored on ``self.config``; not read by this class.
            back_translator: Object with a ``rephrase_corpus(corpus, batch_size=...)``
                method. When omitted, a French ``BackTranslator`` is created on
                the first ``process`` call and reused by later calls.
        """
        self.config = config
        self.back_translator = back_translator

    def process(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with the question/choice columns paraphrased.

        Only ``question`` and ``choices_0`` .. ``choices_3`` are rewritten. In
        those columns, missing values become ``''`` and other values are cast
        to ``str`` before back-translation. All other columns are copied
        unchanged, and the input frame is not modified.
        """
        torch.set_num_threads(16)
        if self.back_translator is None:
            # Loading the two Marian models is expensive; do it once per instance.
            self.back_translator = BackTranslator(target_lang='fr')

        cols = ['question'] + [f'choices_{i}' for i in range(4)]
        df_out = df.copy()
        # Fill only the text columns so unrelated columns (labels, ids, ...)
        # keep their missing values and dtypes.
        df_out[cols] = df_out[cols].fillna('').astype(str)

        # Flatten row-major: each row's cells stay adjacent, and the reshape
        # below restores the (n_rows, len(cols)) layout.
        all_sents = df_out[cols].to_numpy().reshape(-1)

        logger.info(f'rephrasing array with shape {all_sents.shape}')

        rephrased = self.back_translator.rephrase_corpus(all_sents, batch_size=4)
        all_rephrased = np.array(rephrased).reshape(-1, len(cols))
        for i, col in enumerate(cols):
            df_out[col] = all_rephrased[:, i]
        return df_out


if __name__ == '__main__':
    # Manual smoke check: paraphrase a few sample sentences through an
    # English -> French -> English round trip and print the results.
    translator = BackTranslator(target_lang='fr')
    sentences = [
        'Net is typically  used for catching fish. When is this impossible?',
        'Net is typically  used for catching fish. When is this possible?',
        'Net is typically  used for catching fish.',
        'Accordian bag are typically used for carrying accordion. What makes this impossible?',
        'Accordian bag are typically used for carrying accordion. When is this possible?',
        'Accordian bag are typically used for carrying accordion.',
    ]
    rephrased = translator.rephrase_corpus(pd.Series(sentences).values, batch_size=2)
    print(rephrased)
