import json
import os
import numpy as np
from PIL import Image
import torch
import datasets
from datasets import Dataset
from datasets.features import ClassLabel
from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D
from datasets import load_metric
from transformers import AutoProcessor
from transformers import LayoutLMv3ForTokenClassification
from torch_geometric.nn import DataParallel
from transformers import TrainingArguments, Trainer
from transformers.data.data_collator import default_data_collator


"""
Adapted from https://github.com/NielsRogge/Transformers-Tutorials
"""


def load_image(image_path):
    """Open an image file and return it in RGB mode together with its size.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A ``(image, (width, height))`` tuple: the RGB PIL image and its size
        in pixels.
    """
    # `convert` returns a new, fully loaded image, so the source file can be
    # closed deterministically here instead of lingering until garbage
    # collection (one handle per dataset page otherwise).
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    width, height = image.size
    return image, (width, height)


def normalize_bbox(bbox, size):
    """Scale a pixel-space bounding box to the 0-1000 grid used by LayoutLM.

    Args:
        bbox: ``[x0, y0, x1, y1]`` in pixels.
        size: ``(width, height)`` of the page image in pixels.

    Returns:
        ``[x0, y0, x1, y1]`` as ints, each clamped to ``[0, 1000]``.
    """
    width, height = size

    def _scale(value, extent):
        # Boxes that extend past the image edge (or start at a negative
        # offset) would otherwise produce coordinates outside the 0-1000
        # range LayoutLMv3 expects.
        return min(max(int(1000 * value / extent), 0), 1000)

    return [
        _scale(bbox[0], width),
        _scale(bbox[1], height),
        _scale(bbox[2], width),
        _scale(bbox[3], height),
    ]


def _bio_label_ids(tokens):
    """Convert per-token entity class ids into BIO label ids.

    Ids match the label order built by ``get_classes``: ``0`` is ``"O"``, and
    for a token whose ``class_id`` is ``k`` (0-based), ``2k + 1`` is ``B-``
    and ``2k + 2`` is ``I-``. Tokens with ``class_id + 1 <= 0`` keep
    ``class_id + 1`` (``0`` -> ``"O"``).

    A token is tagged ``I-`` when the immediately preceding token has the same
    ``class_id``, otherwise ``B-`` (including the very first token).
    NOTE(review): adjacent entities of the same class are therefore merged
    into one span; a per-entity id, if the OCR JSON has one, would be needed
    to separate them.
    """
    label_ids = []
    previous_class = None
    for token in tokens:
        class_id = token['class_id'] + 1
        if class_id > 0:
            continues_entity = token['class_id'] == previous_class
            class_id = class_id * 2 if continues_entity else class_id * 2 - 1
        label_ids.append(class_id)
        previous_class = token['class_id']
    return label_ids


def generate_examples(split, data_dir='dataset/buddie_v1/'):
    """Yield one LayoutLMv3 example per OCR annotation file in a split.

    Reads ``<data_dir>/<split>/ocr/<name>.json`` annotation files in sorted
    order; each is paired with the page image ``<data_dir>/<split>/images/<name>.jpg``.

    Args:
        split: Sub-directory of ``data_dir`` to read (e.g. ``"train"``).
        data_dir: Dataset root directory (defaults to the original location).

    Yields:
        Dicts with ``id``, ``tokens``, ``bboxes`` (normalized by
        ``normalize_bbox``), ``ner_tags`` (BIO label ids) and ``image``.
    """
    split_dir = os.path.join(data_dir, split)
    ann_dir = os.path.join(split_dir, "ocr")
    img_dir = os.path.join(split_dir, "images")
    ann_files = sorted(name for name in os.listdir(ann_dir) if name.endswith(".json"))
    for guid, file in enumerate(ann_files):
        with open(os.path.join(ann_dir, file), "r", encoding="utf8") as f:
            data = json.load(f)
        # Swap only the extension: str.replace("json", "jpg") would also
        # rewrite any "json" occurring elsewhere in the path.
        stem, _ = os.path.splitext(file)
        image, size = load_image(os.path.join(img_dir, stem + ".jpg"))

        ocr_tokens = data["tokens"]
        tokens = [token['text'] for token in ocr_tokens]
        bboxes = [
            normalize_bbox(
                [token['x'], token['y'],
                 token['x'] + token['width'], token['y'] + token['height']],
                size,
            )
            for token in ocr_tokens
        ]
        yield {"id": str(guid), "tokens": tokens, "bboxes": bboxes,
               "ner_tags": _bio_label_ids(ocr_tokens), "image": image}
        
        
def get_classes(path='/home/anourbak/gen/GPT-GNN/dataset/buddie_v1/labels.json'):
    """Build the BIO label vocabulary from the dataset's ``labels.json``.

    Args:
        path: Location of ``labels.json``. The default is the original
            hard-coded, machine-specific path, kept so existing calls work.

    Returns:
        ``["O", "B-<L0>", "I-<L0>", "B-<L1>", "I-<L1>", ...]`` where ``<Lk>``
        is the upper-cased ``label`` of the k-th entry of ``key_entities``.
        Each tag's position in the list is its integer label id.
    """
    # utf8 matches how generate_examples reads the OCR JSON files and avoids
    # depending on the platform's locale encoding.
    with open(path, 'r', encoding="utf8") as f:
        key_entities = json.load(f)['key_entities']
    labels = ["O"]
    for entity in key_entities:
        name = entity['label'].upper()
        labels.extend(("B-" + name, "I-" + name))
    return labels

def get_label_list(labels):
    """Return the sorted set of distinct labels found across all examples.

    Needed when the label column is not a ``Sequence[ClassLabel]``: the label
    vocabulary then has to be recovered by scanning every example.

    Args:
        labels: Iterable of per-example label sequences.

    Returns:
        A sorted list of the unique labels.
    """
    return sorted({tag for sequence in labels for tag in sequence})


def prepare_examples(examples):
    """Encode a batch of raw examples into LayoutLMv3 model inputs.

    The column names (``image_column_name`` etc.) and ``processor`` are
    module-level globals assigned in the ``__main__`` section.

    Args:
        examples: A batched ``datasets`` slice mapping column name -> list.

    Returns:
        The processor's encoding for the batch, truncated and padded with
        ``padding="max_length"``.
    """
    return processor(
        examples[image_column_name],
        examples[text_column_name],
        boxes=examples[boxes_column_name],
        word_labels=examples[label_column_name],
        truncation=True,
        padding="max_length",
    )


def compute_metrics(p):
    """Score token-classification predictions with the ``seqeval`` metric.

    Positions whose reference id is ``-100`` (ignored positions such as
    special tokens) are dropped before scoring. Ids are mapped to tag strings
    through the module-level ``label_list``; ``metric`` and
    ``return_entity_level_metrics`` are module-level globals as well.

    Args:
        p: ``(predictions, labels)``; predictions hold class scores along
            axis 2, labels hold the reference label ids.

    Returns:
        With ``return_entity_level_metrics`` set, every seqeval result with
        nested dicts flattened into ``"<key>_<stat>"`` entries; otherwise the
        overall precision / recall / f1 / accuracy.
    """
    logits, references = p
    predicted_ids = np.argmax(logits, axis=2)

    true_predictions, true_labels = [], []
    for pred_row, ref_row in zip(predicted_ids, references):
        kept = [(pred_id, ref_id) for pred_id, ref_id in zip(pred_row, ref_row) if ref_id != -100]
        true_predictions.append([label_list[pred_id] for pred_id, _ in kept])
        true_labels.append([label_list[ref_id] for _, ref_id in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)

    if not return_entity_level_metrics:
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    # Flatten nested per-entity dictionaries into single-level keys.
    flat = {}
    for key, value in results.items():
        if isinstance(value, dict):
            flat.update({f"{key}_{stat}": score for stat, score in value.items()})
        else:
            flat[key] = value
    return flat
    

if __name__ == "__main__":
    # Load the processor and the model from the same checkpoint, so that the
    # preprocessing cannot silently diverge from the model being fine-tuned.
    checkpoint = "microsoft/layoutlmv3-large"

    features = datasets.Features(
        {
            "id": datasets.Value("string"),
            "tokens": datasets.Sequence(datasets.Value("string")),
            "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
            "ner_tags": datasets.Sequence(
                datasets.features.ClassLabel(names=get_classes())
            ),
            "image": datasets.features.Image(),
        }
    )

    train = Dataset.from_generator(generate_examples, features=features, gen_kwargs={"split": "train"})
    test = Dataset.from_generator(generate_examples, features=features, gen_kwargs={"split": "test"})

    # The Auto API resolves the processor class from the checkpoint. OCR is
    # disabled because words and boxes come from the dataset annotations.
    processor = AutoProcessor.from_pretrained(checkpoint, apply_ocr=False)

    # These names are module-level globals read by `prepare_examples`.
    column_names = train.column_names
    print(column_names)
    image_column_name = "image"
    text_column_name = "tokens"
    boxes_column_name = "bboxes"
    label_column_name = "ner_tags"

    # `label_list` is a module-level global read by `compute_metrics`.
    if isinstance(features[label_column_name].feature, ClassLabel):
        # Labels are already integer ids into the ClassLabel names.
        label_list = features[label_column_name].feature.names
    else:
        label_list = get_label_list(train[label_column_name])
    id2label = dict(enumerate(label_list))
    label2id = {label: idx for idx, label in enumerate(label_list)}
    num_labels = len(label_list)

    # Explicit output features are needed for `set_format` (below) to work.
    encoded_features = Features({
        'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
        'input_ids': Sequence(feature=Value(dtype='int64')),
        'attention_mask': Sequence(Value(dtype='int64')),
        'bbox': Array2D(dtype="int64", shape=(512, 4)),
        'labels': Sequence(feature=Value(dtype='int64')),
    })

    train_dataset = train.map(
        prepare_examples,
        batched=True,
        remove_columns=column_names,
        features=encoded_features,
    )
    eval_dataset = test.map(
        prepare_examples,
        batched=True,
        remove_columns=column_names,
        features=encoded_features,
    )

    train_dataset.set_format("torch")
    eval_dataset.set_format("torch")

    # Module-level globals read by `compute_metrics`.
    # NOTE(review): `load_metric` is deprecated in `datasets`; the `evaluate`
    # library (`evaluate.load("seqeval")`) is its replacement.
    metric = load_metric("seqeval")
    return_entity_level_metrics = False

    model = LayoutLMv3ForTokenClassification.from_pretrained(
        checkpoint,
        id2label=id2label,
        label2id=label2id,
    )

    # Fall back to CPU instead of failing outright when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    training_args = TrainingArguments(
        output_dir="/path/to/output",
        max_steps=1000,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        learning_rate=1e-5,
        evaluation_strategy="steps",
        eval_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=processor,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.evaluate()
