import argparse
from pathlib import Path
from tqdm import tqdm

# torch
import clip
import os
from PIL import Image
from os import listdir
from os.path import isfile, join
import torch

from einops import repeat

# vision imports

from PIL import Image
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer

# argument parsing

# Command-line interface of the script. `args` is read by every later section.
parser = argparse.ArgumentParser()

parser.add_argument('--dalle_path', type = str, required = True,
                    help='path to your trained DALL-E')

parser.add_argument('--vqgan_model_path', type=str, default = None,
                    help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')

parser.add_argument('--vqgan_config_path', type=str, default = None,
                    help='path to your trained VQGAN config. This should be a .yaml file.  (only valid when taming option is enabled)')

# The value is a file path, not the prompt itself: the script reads the file
# as UTF-8 and treats every line as one prompt.
parser.add_argument('--text_file_path', type = str, required = True,
                    help='path to a UTF-8 text file containing one text prompt per line')

parser.add_argument('--num_images', type = int, default = 128, required = False,
                    help='number of images')

parser.add_argument('--batch_size', type = int, default = 4, required = False,
                    help='batch size')

parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                    help='top k filter threshold')

parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
                    help='output directory')

# The same path is handed to HugTokenizer (with --hug) or YttmTokenizer (without).
parser.add_argument('--bpe_path', type = str,
                    help='path to your BPE file (loaded with HugTokenizer when --hug is set, otherwise with YttmTokenizer)')

parser.add_argument('--hug', dest='hug', action = 'store_true',
                    help='load --bpe_path with HugTokenizer instead of YttmTokenizer')

parser.add_argument('--chinese', dest='chinese', action = 'store_true',
                    help='use ChineseTokenizer (only applies when --bpe_path is not given)')

parser.add_argument('--taming', dest='taming', action='store_true',
                    help='use a VQGAN VAE (see --vqgan_model_path / --vqgan_config_path) when the checkpoint carries no VAE params')

parser.add_argument('--gentxt', dest='gentxt', action='store_true',
                    help='let DALL-E complete each prompt and generate images for the completed text')

args = parser.parse_args()

# helper fns

def exists(val):
    """Return True for every value except None (falsy values such as 0 or '' count as existing)."""
    if val is None:
        return False
    return True

# tokenizer

# A BPE file takes priority: --hug selects HugTokenizer, otherwise YttmTokenizer
# is used. Without a BPE file, --chinese switches to ChineseTokenizer. If neither
# applies, the `tokenizer` imported from dalle_pytorch.tokenizer stays in place.
if args.bpe_path is not None:
    if args.hug:
        klass = HugTokenizer
    else:
        klass = YttmTokenizer
    tokenizer = klass(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()

# load DALL-E

dalle_path = Path(args.dalle_path)

assert dalle_path.exists(), 'trained DALL-E must exist'

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
load_obj = torch.load(str(dalle_path))
print("load_obj: ", type(load_obj), load_obj.keys ())

# The checkpoint is expected to be a dict with these three entries; they are
# popped in this order, so a missing key raises KeyError for the first one absent.
dalle_params = load_obj.pop('hparams')
vae_params = load_obj.pop('vae_params')
weights = load_obj.pop('weights')

# The VAE is passed explicitly below, so drop any 'vae' entry from the hparams.
dalle_params.pop('vae', None) # cleanup later

# VAE choice: parameters stored in the checkpoint win; otherwise --taming
# selects the VQGAN VAE and the OpenAI VAE is the fallback.
if vae_params is None:
    if args.taming:
        vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
    else:
        vae = OpenAIDiscreteVAE()
else:
    vae = DiscreteVAE(**vae_params)

# Build the model on the GPU first, then restore the trained weights.
dalle = DALLE(vae = vae, **dalle_params).cuda()
dalle.load_state_dict(weights)

# clip model
# CLIP runs on the GPU when one is available and on the CPU otherwise;
# `preprocess` is the image transform returned together with the model.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model, preprocess = clip.load("ViT-B/32", device=device)

# generate images

# Kept from the original script; not used below.
image_size = vae.image_size

# Upper bound on the number of best-scoring images moved to "top_N_images" per prompt.
top_n = 256

# One prompt per line. The context manager closes the file, and blank lines are
# skipped so they cannot turn into empty prompts with an empty directory name.
text_file_path = args.text_file_path
with open(text_file_path, encoding='utf-8', mode='r') as text_file:
    texts = [line.strip() for line in text_file]
texts = [line for line in texts if line]
print(texts)

# Wrapping the list (rather than an enumerate object) lets tqdm show a total.
for text in tqdm(texts):
    if args.gentxt:
        # Let DALL-E complete the prompt; the completed text is used from here on
        # for the directory name, the caption and the CLIP scoring.
        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k)
        text = gen_texts[0]
    else:
        text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

    # Repeat the single token row once per requested image.
    text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images)

    # Generate in chunks of --batch_size and stitch the results back together.
    outputs = []
    for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'):
        outputs.append(dalle.generate_images(text_chunk, filter_thres = args.top_k))
    outputs = torch.cat(outputs)

    # save all images
    file_name = text
    dir_name = file_name.replace(' ', '_')[:(100)]
    outputs_dir = Path(args.outputs_dir) / dir_name
    outputs_dir.mkdir(parents = True, exist_ok = True)

    for i, image in tqdm(enumerate(outputs), total = len(outputs), desc = 'saving images'):
        save_image(image, outputs_dir / f'{i}.jpg', normalize=True)

    # Written once per prompt, as UTF-8 to match the encoding the prompts were read with.
    (outputs_dir / 'caption.txt').write_text(file_name, encoding='utf-8')

    print(f'created {args.num_images} images at "{str(outputs_dir)}"')

    # Score every saved image against the prompt with CLIP. The images are read
    # back from `outputs_dir`, so a non-default --outputs_dir is honoured.
    emb_text = clip.tokenize([file_name]).to(device)
    probs = []

    with torch.no_grad():
        for i in range(len(outputs)):
            # `preprocess` is applied while the file is open; the handle is closed right after.
            with Image.open(outputs_dir / f'{i}.jpg') as pil_image:
                image = preprocess(pil_image).unsqueeze(0).to(device)

            # One image against one text: the score is the single entry [0][0].
            logits_per_image, _ = model(image, emb_text)
            probs.append(logits_per_image[0][0].item())

    max_idx = probs.index(max(probs))
    print("Best photo index is: ", max_idx)

    # Image indices ordered from highest to lowest score, truncated to `top_n`.
    top_n_index = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_n]
    print("Top_N_index", top_n_index)

    top_n_images_dir = outputs_dir / "top_N_images"
    top_n_images_dir.mkdir(parents = True, exist_ok = True)

    # Move the selected images and rename them by rank (0 = best). Path.replace
    # overwrites files left over from a previous run on every platform.
    for rank, idx in enumerate(top_n_index):
        (outputs_dir / f'{idx}.jpg').replace(top_n_images_dir / f'{rank}.jpg')

