import transformers
import torch
import pandas as pd
from sqlalchemy import create_engine
from tqdm import tqdm
from sklearn.metrics import classification_report
import random
from pathlib import Path
import json
import argparse
from awq import AutoAWQForCausalLM

def form_sqlite_str(path):
    """Build a SQLAlchemy SQLite connection URL for the database file at *path*.

    *path* may be a string or any object whose default formatting yields the
    file path (e.g. ``pathlib.Path``).
    """
    url_template = 'sqlite:///{}'
    return url_template.format(path)

class RedditScoreClassifier:
    """Model-agnostic prompt builder for UPVOTE/DOWNVOTE prediction on Reddit comments.

    This base class only knows how to render the system prompt, the optional
    few-shot examples and the user prompt. Subclasses must provide:

    * ``generate_prompt(sub, data, examples) -> str`` -- wrap the system and
      user prompts in the model-specific chat format;
    * ``gen_text(prompt) -> str`` -- run the model on a finished prompt.
    """

    def __init__(self):
        # Label strings shown in few-shot examples and expected back from the model.
        self.UPVOTE = "UPVOTE"
        self.DOWNVOTE = "DOWNVOTE"
        # Word budgets applied to comment / self-text when rendering an example.
        self.EXAMPLE_COMMENT_WORD_LIMIT = 150
        self.EXAMPLE_SELF_TEXT_WORD_LIMIT = 75
        # self.SYSTEM_PROMPT = "You are a moderator on the subreddit r/{sub}. You are aware of the type of content that gets heavily upvoted and downvoted on r/{sub}. Each prompt is a comment on r/{sub}. Reply with only UPVOTE or DOWNVOTE based on whether you think the content will be heavily upvoted or downvoted."
        self.SYSTEM_PROMPT = """
You are a machine learning model trained to predict whether a comment on a Reddit post in the subreddit r/{SUBREDDIT_NAME} will receive a lot of upvotes or downvotes based on the POST TITLE, SELF TEXT, and the COMMENT text itself.

Consider factors such as the tone, language used, relevance of the comment to the original post when making your prediction, and community norms in r/{SUBREDDIT_NAME}.  In your analysis, try to consider how users on r/{SUBREDDIT_NAME} might typically react to a comment in the context of the specified post.

{examples}

Output either "UPVOTE" or "DOWNVOTE" to indicate your prediction for the following comment. No explanation needed. Just output "UPVOTE" or "DOWNVOTE".
"""

        self.EXAMPLES_PROMPT = """
Here are some examples:

{examples}
"""

        self.EXAMPLES_TEMPLATE = """POST TITLE: {title}
COMMENT: {comment}
SELF TEXT: {self_text}
PREDICTION: {prediction}

"""

    def truncate_str_to_word_limit(self, text:str, limit:int) -> str:
        """Return the first *limit* whitespace-separated words of *text*, joined by single spaces.

        ``None`` is treated as empty text so that a missing field (for
        example a post without self text) renders as an empty string instead
        of raising ``AttributeError``.
        """
        if text is None:
            return ""
        return " ".join(text.split()[:limit])

    def format_example(self, title:str, comment:str, self_text:str, prediction:"int | None") -> str:
        """Render one example block from ``EXAMPLES_TEMPLATE``.

        Args:
            title: Post title, inserted unchanged.
            comment: Comment body, truncated to ``EXAMPLE_COMMENT_WORD_LIMIT`` words.
            self_text: Post self text, truncated to ``EXAMPLE_SELF_TEXT_WORD_LIMIT`` words.
            prediction: ``1`` renders as UPVOTE, ``None`` renders as an empty
                label (used for the query the model must complete), and any
                other value renders as DOWNVOTE.

        Returns:
            The filled-in template string.
        """
        comment = self.truncate_str_to_word_limit(comment, self.EXAMPLE_COMMENT_WORD_LIMIT)
        self_text = self.truncate_str_to_word_limit(self_text, self.EXAMPLE_SELF_TEXT_WORD_LIMIT)
        if prediction is None:
            label = ''
        elif prediction == 1:
            label = self.UPVOTE
        else:
            label = self.DOWNVOTE
        return self.EXAMPLES_TEMPLATE.format(title=title, comment=comment, self_text=self_text, prediction=label)

    def form_system_prompt(self, sub:str, examples:list[tuple]) -> str:
        """Build the system prompt for subreddit *sub*.

        Args:
            sub: Subreddit name substituted for ``{SUBREDDIT_NAME}``.
            examples: Few-shot examples as ``(title, comment, self_text,
                prediction)`` tuples. When empty (or ``None``) the examples
                section is left blank.
        """
        if examples:
            rendered = [self.format_example(title=e[0], comment=e[1], self_text=e[2], prediction=e[3]) for e in examples]
            examples_str = self.EXAMPLES_PROMPT.format(examples="\n".join(rendered))
        else:
            examples_str = ""
        return self.SYSTEM_PROMPT.format(SUBREDDIT_NAME=sub, examples=examples_str)

    def form_user_prompt(self, data:tuple) -> str:
        """Render the query ``(title, comment, self_text)`` with an empty PREDICTION field."""
        return self.format_example(title=data[0], comment=data[1], self_text=data[2], prediction=None)

    def prompt(self, sub:str, data:tuple, examples:"list[tuple] | None"=None) -> str:
        """Build the full prompt for *data* and return the model's raw text reply.

        Relies on the subclass-provided ``generate_prompt`` and ``gen_text``.
        ``examples`` defaults to no few-shot examples; ``None`` is used as the
        default rather than a shared mutable list.
        """
        if examples is None:
            examples = []
        prompt = self.generate_prompt(sub, data, examples)
        return self.gen_text(prompt)


class Llama2RedditScoreClassifier(RedditScoreClassifier):
    """Classifier backend that prompts a causal LM using the ``[INST] <<SYS>>`` chat layout.

    Args:
        model_id: Model identifier passed to the ``from_quantized`` /
            ``from_pretrained`` loaders and to the tokenizer loader.
        device_map: Forwarded to the model loader.
        quantized: When true, load through ``AutoAWQForCausalLM.from_quantized``;
            otherwise load a float16 model with ``transformers``.
    """

    def __init__(self, model_id, device_map='cuda', quantized=False):
        super().__init__()
        if quantized:
            self.model = AutoAWQForCausalLM.from_quantized(model_id, fuse_layers=True,
                                trust_remote_code=False, safetensors=True, device_map=device_map)
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16,
                                trust_remote_code=False, device_map=device_map)
        # NOTE(review): truncation/max_length are given at load time here; confirm
        # they actually take effect, since gen_text does not request truncation
        # when it tokenizes the prompt.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=False, truncation=True, max_length=2048)

        self.PROMPT_TEMPLATE = '''[INST] <<SYS>>
{sys_prompt}
<</SYS>>
{prompt}[/INST]

'''

    def generate_prompt(self, sub:str, data:tuple, examples:list[tuple]) -> str:
        """Fill ``PROMPT_TEMPLATE`` with the system prompt for *sub* and the user prompt for *data*."""
        return self.PROMPT_TEMPLATE.format(
                    sys_prompt=self.form_system_prompt(sub, examples),
                    prompt=self.form_user_prompt(data)
                )

    def gen_text(self, prompt:str) -> str:
        """Sample a completion for *prompt* and return only the newly generated text.

        The completion is recovered by dropping the prompt's token ids from
        the generated sequence and decoding the remainder without special
        tokens. This avoids depending on the exact character length of the
        decoded prompt or on the presence of a trailing end-of-sequence
        marker (which is absent when generation stops at ``max_new_tokens``).
        """
        input_ids = self.tokenizer(prompt, return_tensors='pt').input_ids
        # Put the inputs where the model lives; fall back to CUDA when the
        # model wrapper does not expose a ``device`` attribute.
        input_ids = input_ids.to(getattr(self.model, "device", "cuda"))
        generation_output = self.model.generate(
            input_ids,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=40,
            max_new_tokens=512
        )
        # The returned sequence starts with the prompt tokens; keep only what follows.
        new_token_ids = generation_output[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()


class Llama3RedditScoreClassifier(RedditScoreClassifier):
    """Classifier backend driven by a ``transformers`` text-generation pipeline.

    Prompts are rendered with the tokenizer's chat template; generation stops
    at the tokenizer's EOS token or at ``<|eot_id|>``.
    """

    def __init__(self, model_id, device_map='cuda'):
        super().__init__()
        # The model is loaded in bfloat16 via the pipeline's model_kwargs.
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map=device_map,
        )

    def generate_prompt(self, sub:str, data:tuple, examples:list[tuple]) -> str:
        """Render a system + user conversation into one prompt string (not tokenized)."""
        system_turn = {"role": "system", "content": self.form_system_prompt(sub, examples)}
        user_turn = {"role": "user", "content": self.form_user_prompt(data)}
        return self.pipeline.tokenizer.apply_chat_template(
            [system_turn, user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )

    def gen_text(self, prompt:str) -> str:
        """Sample a completion for *prompt* and return the text that follows the prompt."""
        tok = self.pipeline.tokenizer
        stop_ids = [tok.eos_token_id, tok.convert_tokens_to_ids("<|eot_id|>")]
        sampling_kwargs = dict(
            max_new_tokens=128,
            eos_token_id=stop_ids,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )
        results = self.pipeline(prompt, **sampling_kwargs)
        full_text = results[0]["generated_text"]
        # Drop the leading len(prompt) characters so only the completion remains.
        return full_text[len(prompt):].strip()


def get_examples(df, row, n) -> list[tuple]:
    """Sample up to *n* few-shot examples from the same subreddit as *row*.

    Args:
        df: DataFrame with at least the columns ``p_subreddit``, ``p_title``,
            ``c_body``, ``p_selftext`` and ``c_polarity``.
        row: A row of *df* (as yielded by ``df.iterrows()``); it is excluded
            from the sample by its index label (``row.name``).
        n: Requested number of examples.

    Returns:
        A list of ``(p_title, c_body, p_selftext, c_polarity)`` tuples drawn
        without replacement. If the subreddit has fewer than *n* other rows,
        all of them are returned (in random order) instead of raising.
    """
    if n <= 0:
        return []
    columns_to_keep = ['p_title', 'c_body', 'p_selftext', 'c_polarity']
    # Candidates: same subreddit as the query row, excluding the query row itself.
    same_sub = df['p_subreddit'] == row['p_subreddit']
    not_self = df.index != row.name
    pool = df.loc[same_sub & not_self, columns_to_keep]
    # DataFrame.sample raises when asked for more rows than exist, so clamp.
    k = min(n, len(pool))
    if k == 0:
        return []
    return pool.sample(n=k).to_records(index=False).tolist()


# ---- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", required=True, type=str)
parser.add_argument("--input_db", required=True, type=str)
parser.add_argument("--out_path", required=True, type=str)
parser.add_argument("--exp_name", required=True, type=str)
parser.add_argument("--quantized", action='store_true')
parser.add_argument("--model_family", default='LLAMA2', type=str)  # LLAMA2 or LLAMA3
parser.add_argument("--num_examples", default=0, type=int)
parser.add_argument("--n_prompts_to_print", default=0, type=int)
args = parser.parse_args()

print(args)

model_id = args.model_id
input_db = Path(args.input_db)
out_path = Path(args.out_path)
exp_name = args.exp_name
CHOICES = [1, 0]  # fallback labels drawn at random when the model reply is unusable
quantized = args.quantized
model_family = args.model_family
n_prompts_to_print = args.n_prompts_to_print

# ---- Model and data ---------------------------------------------------------
# Anything other than 'LLAMA3' selects the Llama 2 backend.
if model_family == 'LLAMA3':
    clf = Llama3RedditScoreClassifier(model_id)
else:
    clf = Llama2RedditScoreClassifier(model_id, quantized=quantized)

engine = create_engine(form_sqlite_str(input_db))
df = pd.read_sql_table('test', engine)

# ---- Inference --------------------------------------------------------------
res = []
n_prompts_printed = 0
for index, row in tqdm(df.iterrows(), total=len(df)):
    subreddit = row['p_subreddit']
    data = (row['p_title'], row['c_body'], row['p_selftext'])
    examples = get_examples(df, row, args.num_examples)
    # Echo the first few fully rendered prompts for manual inspection.
    if n_prompts_printed < n_prompts_to_print:
        print(clf.generate_prompt(subreddit, data, examples))
        n_prompts_printed += 1
    res.append(clf.prompt(subreddit, data, examples))


def _to_binary_label(pred):
    """Map an exact 'UPVOTE'/'DOWNVOTE' reply to 1/0; any other reply gets a random label."""
    if pred == 'UPVOTE':
        return 1
    if pred == 'DOWNVOTE':
        return 0
    return random.choice(CHOICES)


preds_col = exp_name + '_preds'
binary_col = exp_name + '_preds_binary'
df[preds_col] = res
df[binary_col] = df[preds_col].apply(_to_binary_label)
df.to_json(exp_name + '.json')

# ---- Evaluation: macro-averaged F1 per subreddit -----------------------------
f1s = {
    sub: classification_report(
        subdf['c_polarity'],
        subdf[binary_col],
        output_dict=True,
    )['macro avg']['f1-score']
    for sub, subdf in df.groupby('p_subreddit')
}
with open(out_path, 'w') as out_file:
    json.dump(f1s, out_file)

# python similarity/llm_inference_quantized.py --model_id "meta-llama/Meta-Llama-3-8B-Instruct" --input_db /mnt/e/influence_risk_analysis/data/mix_opp_3000.db --out_path f1_sims_l3_8B_5shot.json --exp_name L3_8B_5S --model_family LLAMA3 --num_examples 5