import itertools
import time
from omegaconf import OmegaConf

from datasets import load_dataset
import transformers
from transformers import AutoTokenizer, Trainer, TrainingArguments
from tokenizers import Tokenizer

import wandb
import python_minifier


from utils import EvaluateOnFirstStepCallback, LogClassificationsCallback, get_compute_metrics, \
                  get_config, preprocess_logits_for_metrics, LogGenerationsCallback, DataCollatorForLanguageModeling
from base_model import CustomGPT2Config, CustomGPT2LMHeadModel, CustomGPT2ForSequenceClassification


def get_dataset_and_tokenizer(config, split=None):
    dataset = load_dataset(config.dataset_name, config.dataset_version, data_dir=config.data_dir, split=split, cache_dir="../data")
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
    if tokenizer.bos_token_id == tokenizer.vocab_size:
        print("Adding special tokens to tokenizer...")
        tokenizer.bos_token = "<bos>"
        tokenizer.pad_token = "<pad>"
        tokenizer.unk_token = "<unk>"
    else:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    num_proc = config.num_proc
    text_column = config.text_column
    label_column = config.label_column
    
    if config.filter_chess_games:
        def filter_chess_games(example):
            num_moves = len(example[text_column].split())
            return 10 <= num_moves <= 150
        dataset = dataset.filter(filter_chess_games, num_proc=num_proc, desc="Filtering chess games")
    
    if config.dataset_name == 'Skylion007/openwebtext':
        # use only a subset of the dataset
        dataset['train'] = dataset['train'].select(range(500_000))

    # split into train/test/validation if missing
    if 'validation' not in dataset:
        print("No validation set, making one...")
        split_dataset = dataset['train'].train_test_split(test_size=config.generated_test_size, seed=2357, shuffle=True)  # make validation split
        dataset['validation'] = split_dataset['test']
        dataset['train'] = split_dataset['train']

    if 'test' not in dataset:
        print("No test set, making one...")
        split_dataset = dataset['train'].train_test_split(test_size=config.generated_test_size, seed=2357, shuffle=True)  # make validation split
        dataset['test'] = split_dataset['test']
        dataset['train'] = split_dataset['train']

    if config.obfuscate_python:
        def obfuscate_function(examples):
            codes = examples[text_column]
            new_codes = []
            for code in codes:
                # try:
                #     # convert code from python2 to python3 code
                #     tree = lib2to3.refactor.RefactoringTool([]).refactor_string(code, 'code')
                #     code = str(tree)
                #     print("Success in converting code from python2 to python3 code")
                # except Exception as e:
                #     print(f"Could not convert code from python2 to python3 code: {e}")
                #     pass

                try:
                    minified_code = python_minifier.minify(
                            code,
                            remove_annotations=True,
                            remove_literal_statements=True,
                            rename_locals=True,
                            rename_globals=True,
                            preserve_shebang=True,
                        )
                    new_codes.append(minified_code)
                except Exception as e:
                    pass
            return {text_column: new_codes}
        
        dataset = dataset.map(obfuscate_function, batched=True, num_proc=num_proc, remove_columns=dataset['train'].column_names, desc="Obfuscating")

    # tokenize text data (but keep labels as they are)
    append_to_text = config.append_to_text
    def tokenize_function(examples):
        result = tokenizer([x + append_to_text for x in examples[text_column]])
        # if label_column in examples:
        #     result["labels"] = examples[label_column]
        return result

    tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=num_proc, remove_columns=dataset['train'].column_names, desc="Tokenizing")

    # mask the labels except for the last sentence
    if not config.group_texts:
        dot_idx = tokenizer.convert_tokens_to_ids(".")
        def mask_labels_function(examples):
            input_ids = examples["input_ids"]
            labels = []
            for i in range(len(input_ids)):
                idx = len(input_ids[i]) - input_ids[i][:-1][::-1].index(dot_idx) - 2
                idx = idx + 2   # dot before last sentence + "Box A"
                labels.append([-100] * (idx+1) + input_ids[i][idx+1:])
                # print(input_ids[i], idx, labels[i])
            return {"labels": labels}

        tokenized_dataset = tokenized_dataset.map(mask_labels_function, batched=True, num_proc=num_proc, desc="Masking labels")


    # in language modeling, contatenate all texts
    if config.group_texts and not config.is_classification:
        ctx_len = config.context_length
        def group_texts(examples):
            concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            total_length = (total_length // ctx_len) * ctx_len
            result = { k: [t[i: i + ctx_len] for i in range(0, total_length, ctx_len)] 
                    for k, t in concatenated_examples.items()}
            # result["labels"] = result["input_ids"].copy()
            return result

        tokenized_dataset = tokenized_dataset.map(group_texts, batched=True, num_proc=num_proc, desc="Grouping texts")

    if config.is_classification:
        # for classification, truncate (from the left!) texts to max length
        max_length = config.context_length
        def truncate_texts(examples):
            result = {k: [t[-max_length:] for t in examples[k]] for k in ["input_ids", "attention_mask"]}
            result["labels"] = examples["labels"]
            return result
        
        tokenized_dataset = tokenized_dataset.map(truncate_texts, batched=True, num_proc=num_proc, desc="Truncating texts")
    
    return tokenized_dataset, tokenizer


def load_model(config: OmegaConf, tokenizer):
    """Instantiate a custom GPT-2 model sized by ``config``, optionally warm-started.

    Architecture hyper-parameters are read from ``config``. Vocabulary size and
    special-token ids come from ``tokenizer``. ``config.is_classification``
    selects the sequence-classification head instead of the LM head.

    When ``config.load_checkpoint`` is not None, the checkpoint is loaded with the
    same head class and its weights are copied into the freshly built model. If
    the strict copy raises ``RuntimeError``, the error is printed and a
    non-strict copy is attempted instead. NOTE: PyTorch's non-strict mode only
    tolerates missing/unexpected keys; shape mismatches still raise.
    """
    # Hyper-parameters passed straight through from the experiment config.
    passthrough_fields = (
        "n_embd", "n_layer", "n_head", "n_inner", "activation_function", "custom_attention",
        "resid_pdrop", "embd_pdrop", "attn_pdrop", "reorder_and_upcast_attn",
        "scale_attn_by_inverse_layer_idx", "scale_attn_weights", "use_cache",
    )
    gpt2_config = CustomGPT2Config(
        vocab_size=tokenizer.vocab_size,
        n_positions=config.context_length,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        **{field: getattr(config, field) for field in passthrough_fields},
    )

    head_class = CustomGPT2ForSequenceClassification if config.is_classification else CustomGPT2LMHeadModel
    model = head_class(gpt2_config)

    checkpoint_path = config.load_checkpoint
    if checkpoint_path is None:
        return model

    source = head_class.from_pretrained(checkpoint_path)
    try:
        model.load_state_dict(source.state_dict())
    except RuntimeError as err:
        print("Could not load model checkpoint, disabling strict loading.")
        print(err)
        model.load_state_dict(source.state_dict(), strict=False)
    return model


def train(config: OmegaConf):
    """Build the dataset, tokenizer and model from ``config`` and run ``Trainer.train()``.

    If ``config.seed`` is None, a time-derived seed is generated and written back
    into ``config``, so it is part of the config passed to ``wandb.init``. When
    ``config.use_wandb`` is set, the main process also initializes a wandb run,
    logs the parameter count, and finishes the run after training returns.
    """
    # Reproducibility
    if config.seed is None:
        config.seed = round(time.time() * 10 % 1e6)
    print("Seed:", config.seed)
    transformers.set_seed(config.seed)

    # Dataset
    print("Loading dataset...")
    tokenized_dataset, tokenizer = get_dataset_and_tokenizer(config)
    print(tokenized_dataset)

    # Model
    print("Loading model...")
    model = load_model(config, tokenizer)

    # Training
    print("Training...")

    training_args = TrainingArguments(
        max_steps=config.max_steps,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        dataloader_num_workers=config.dataloader_num_workers,

        # AdamW optimizer
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        adam_beta1=config.adam_beta1,
        adam_beta2=config.adam_beta2,

        # Scheduler
        warmup_steps=config.warmup_steps,
        lr_scheduler_type=config.lr_scheduler_type,

        # Other
        output_dir=config.output_dir,
        overwrite_output_dir=True,
        save_steps=config.save_steps,
        save_total_limit=config.save_total_limit,
        seed=config.seed,
        fp16=config.fp16,
        bf16=config.bf16,

        # Wandb
        logging_dir=config.logging_dir,
        logging_steps=config.logging_steps,
        # "none" disables every reporting integration. Passing None makes
        # transformers (v4) fall back to "all" installed integrations, which
        # would still report to wandb/tensorboard when use_wandb is False.
        report_to="wandb" if config.use_wandb else "none",
        disable_tqdm=False,

        # Eval
        # NOTE(review): newer transformers releases renamed this argument to
        # `eval_strategy`; keep in sync with the pinned transformers version.
        evaluation_strategy=config.evaluation_strategy,
        eval_steps=config.eval_steps,
    )

    # Data collator
    if config.is_classification:
        data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer, max_length=config.context_length, padding="longest")
    else:
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        data_collator=data_collator,
        compute_metrics=get_compute_metrics(tokenizer),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        callbacks=[
            EvaluateOnFirstStepCallback(),
        ]
    )

    # Task-specific logging of model outputs.
    if config.is_classification:
        trainer.add_callback(LogClassificationsCallback(tokenized_dataset, tokenizer, model))
    else:
        trainer.add_callback(LogGenerationsCallback(tokenized_dataset, tokenizer, model, num_beams=config.num_beams, prompt_length=config.prompt_length, max_new_tokens=config.max_new_tokens))

    # Only the main process owns the wandb run.
    if config.use_wandb and trainer.is_world_process_zero():
        wandb.init(entity="efagnou", project="attention_for_causality", name="gpt2", config=OmegaConf.to_container(config, resolve=True), reinit=True)
        wandb.log({"nb_params": model.num_parameters()})

    trainer.train()

    if config.use_wandb and trainer.is_world_process_zero():
        wandb.finish()


if __name__ == "__main__":
    # Script entry point: parse the experiment configuration and launch training.
    train(get_config())
