from asyncio.log import logger
from dataclasses import asdict
from socket import TIPC_CRITICAL_IMPORTANCE
from datasets import load_dataset
from transformers import AutoTokenizer
# from models.modeling_roberta_single import RobertaPromptTuning
from transformers import HfArgumentParser
from config import modelArgument, promptArgument
import sys
import os
import torch
import logging
from data.process import processors, output_modes
from transformers import RobertaConfig, RobertaTokenizer
from models.modeling_roberta import RobertaForMaskedLM
# from models.modeling_roberta_lora import RobertaForMaskedLM
from flearn.utils.process_data import PromptDataset
from train_utils import train
from eval_utils import evaluate
from layer_utils import evaluate_layer_scores, evaluate_layer_scores_F_score
import numpy as np

logging.basicConfig(level=logging.DEBUG)

# Expected invocation: python <script> <config.json> <sort_method>
arguments = sys.argv
if len(arguments) < 3:
    # Validate with an explicit exit instead of `assert`: asserts are stripped
    # under `python -O`, which would let execution reach `arguments[1]` and
    # fail with an unhelpful IndexError.
    sys.exit("please input the config file and the sort method...")

config_path = arguments[1]
sort_method = arguments[2]

# Populate the model/prompt argument dataclasses from the JSON config file.
parser = HfArgumentParser((modelArgument, promptArgument))
model_config, prompt_config = parser.parse_json_file(config_path)
model_config.task_name = model_config.task_name.lower()

print("the prefix length:", prompt_config.num_prompt_tokens)
# Paths are built by plain string concatenation, so the configured base
# directories presumably need a trailing '/' — NOTE(review): confirm.
model_config.log_dir = model_config.log_dir + 'PROMPT/' + model_config.task_name + '/'
model_config.output_dir = model_config.output_dir + 'PROMPT/' + model_config.task_name + '/'
model_config.data_dir = model_config.data_dir + model_config.task_name + '/'

print(model_config.log_dir)
print(model_config.output_dir)

# Ensure the per-task log directory exists. `exist_ok=True` tolerates the
# directory already existing (including being created concurrently by another
# process) without silently swallowing unrelated errors such as permission
# failures, which the previous bare `except: pass` hid.
os.makedirs(model_config.log_dir, exist_ok=True)

# Single-GPU setup: the device is hard-coded to the first CUDA device.
# NOTE(review): this fails on machines without CUDA; distributed init is not used.
device = torch.device("cuda", 0)
model_config.n_gpu = 1
model_config.device = device

# Resolve the task's data processor and output mode from the registries
# imported from data.process; unregistered task names are rejected up front.
task_key = model_config.task_name
if task_key not in processors:
    raise ValueError("Task not found: %s" % (task_key,))

processor = processors[task_key]()
model_config.output_mode = output_modes[task_key]

# The label set determines the classifier size passed to the model config.
label_list = processor.get_labels()
num_labels = len(label_list)


# Build the RoBERTa config, tokenizer and masked-LM model from one checkpoint.
# `cache_dir or None` maps an empty/falsy cache setting to None.
config = RobertaConfig.from_pretrained(
    model_config.model_name_or_path,
    num_labels=num_labels,
    finetuning_task=model_config.task_name,
    cache_dir=model_config.cache_dir or None,
    # Request per-layer hidden states and attention maps from every forward
    # pass (presumably consumed by the layer-scoring utilities — confirm).
    output_hidden_states=True,
    output_attentions=True,
)

tokenizer = RobertaTokenizer.from_pretrained(
    model_config.model_name_or_path,
    do_lower_case=model_config.do_lower_case,
    cache_dir=model_config.cache_dir or None,
)

# Load the pretrained weights and move the model onto the selected device.
model = RobertaForMaskedLM.from_pretrained(
    model_config.model_name_or_path,
    config=config,
).to(device)

# Freeze every weight; in the visible code the model is only used for
# layer scoring, never updated.
for weight in model.parameters():
    weight.requires_grad = False

# Report the parameter budget (trainable is 0 after the freeze above).
total_model_params = sum(w.numel() for w in model.parameters())
num_trained_params = sum(w.numel() for w in model.parameters() if w.requires_grad)

print("Total Model Parameters: {}, Trainable Parameters: {}".format(total_model_params, num_trained_params))

# Load the training split of the current task.
train_dataset = PromptDataset(model_config, model_config.task_name, tokenizer, data_type='train')
print(len(train_dataset))

if sort_method == "ours":
    # Score the model's layers on the training data (presumably one score per
    # layer — confirm against layer_utils), then print the indices that
    # order those scores ascending.
    layer_scores = evaluate_layer_scores_F_score(model_config, train_dataset, model)
    print(layer_scores)
    score_array = np.array(layer_scores)
    layer_sort = score_array.argsort()
    print(layer_sort)






