import requests as rq
from typing import List
from lm_eval.api.model import LM
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from tqdm import tqdm

@register_model("TSP")
class TSPLM(LM):
    """lm-evaluation-harness model that delegates text generation to an HTTP API.

    Only ``generate_until`` is implemented; the log-likelihood request types
    raise ``NotImplementedError``. Requests are POSTed as JSON in batches of
    ``batch_size`` to ``api_url``; the reply is read as
    ``{"response": [<generated text>, ...]}`` with one string per batched
    request (inferred from how the reply is consumed below — confirm against
    the server).

    Recognised keyword arguments (all optional):
        batch_size: requests per POST; defaults to 1 when not supplied.
        api_url: generation endpoint; defaults to ``DEFAULT_API_URL``.
        timeout: seconds passed to ``requests.post``; ``None`` (the default)
            waits indefinitely, as the original implementation did.
    """

    DEFAULT_API_URL = 'http://0.0.0.0:8000/api/generate/'

    def __init__(self, *args, **kwargs):
        super().__init__()
        batch_size = kwargs.get('batch_size')
        # int() also accepts numeric strings, in case the value arrives from a
        # CLI/model-arg string rather than as an int.
        self.batch_size = 1 if batch_size is None else int(batch_size)
        if self.batch_size < 1:
            # A step of 0 would make range() in generate_until raise anyway;
            # fail early with a clearer message.
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        self.api_url = kwargs.get('api_url', self.DEFAULT_API_URL)
        timeout = kwargs.get('timeout')
        self.timeout = None if timeout is None else float(timeout)

    @staticmethod
    def _normalize_until(until):
        """Return the stop sequences as a list of non-empty strings.

        ``until`` may be ``None``, a single string, or an iterable of strings.
        A bare string is treated as one stop sequence (not iterated per
        character), and empty strings are dropped because ``str.split("")``
        raises ``ValueError``.
        """
        if until is None:
            return []
        if isinstance(until, str):
            until = [until]
        return [term for term in until if term]

    @staticmethod
    def _truncate_at_stop(text, stops):
        """Cut ``text`` just before the earliest occurrence of any stop sequence.

        Applying every stop sequence in turn (rather than stopping at the first
        one found in list order) guarantees the cut happens at whichever stop
        appears earliest in ``text``.
        """
        for term in stops:
            text = text.split(term, 1)[0]
        return text

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate a completion for each request via the HTTP API.

        Args:
            requests: instances whose ``args`` are ``(prompt, gen_kwargs)``;
                ``gen_kwargs`` may contain ``max_gen_toks`` (default 8 here)
                and ``until`` (a stop string or list of stop strings).

        Returns:
            One generated string per request, in request order, each truncated
            at its own stop sequences.

        Raises:
            Exception: the API responded with a non-200 status code.
            RuntimeError: the API returned a different number of generations
                than requests in the batch.
        """
        print(f"total num of samples: {len(requests)} with bs {self.batch_size}")
        final_texts = []

        for start in tqdm(range(0, len(requests), self.batch_size), desc="Generating response..."):
            batch_requests = requests[start:start + self.batch_size]
            gen_kwargs_list = [request.args[1] for request in batch_requests]
            messages_list = [[{"role": "user", "content": request.args[0]}] for request in batch_requests]
            # The server takes a single token budget per batch, so use the
            # largest one requested; each output is trimmed locally below.
            max_gen_toks = max(gen_kwargs.get("max_gen_toks", 8) for gen_kwargs in gen_kwargs_list)
            raw_untils = [gen_kwargs.get("until", None) for gen_kwargs in gen_kwargs_list]

            data = {
                "messages_list": messages_list,
                "max_new_tokens": max_gen_toks,
                # NOTE(review): the server accepts one stop spec per batch, so
                # only the first request's value is sent; if a batch mixes
                # different stop sequences the server may stop early for some
                # requests. Per-request truncation is still applied below.
                "until": raw_untils[0],
            }

            response = rq.post(self.api_url, json=data, timeout=self.timeout)
            if response.status_code != 200:
                raise Exception(f"API call failed with status code {response.status_code}: {response.text}")

            generated_texts_batch = response.json().get("response", [])
            # lm_eval pairs results with requests by position, so a short or
            # long reply must not be silently accepted.
            if len(generated_texts_batch) != len(batch_requests):
                raise RuntimeError(
                    f"API returned {len(generated_texts_batch)} generations "
                    f"for a batch of {len(batch_requests)} requests"
                )

            for generated_text, until in zip(generated_texts_batch, raw_untils):
                stops = self._normalize_until(until)
                final_texts.append(self._truncate_at_stop(generated_text, stops))

        return final_texts

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        """Not supported by this API-backed model."""
        raise NotImplementedError()

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
        """Not supported by this API-backed model."""
        raise NotImplementedError()