from dotenv import load_dotenv
from genai import Credentials, Client
from genai.schema import TextGenerationParameters, TextGenerationReturnOptions
import jsonlines
import random
import json
from tqdm import tqdm
import os

# Seed Python's `random` module so that any use of it is reproducible.
# NOTE(review): nothing in the visible code calls `random` after this point —
# confirm whether the seed is still needed.
random.seed(42)

def get_bam_response(prompt, model_id):
    """Generate a completion for `prompt` and return the code after ``OUTPUT:``.

    Credentials are loaded from the environment (after reading a ``.env`` file
    via ``load_dotenv``), a single-input generation request is sent through the
    genai ``Client``, and the first result's text is post-processed.

    Args:
        prompt: Full prompt text sent as the only input of the request.
        model_id: Identifier of the model to run the generation with.

    Returns:
        The stripped text that follows the first ``OUTPUT:`` marker in the
        completion, or an empty string when the completion contains no such
        marker.
    """
    load_dotenv()
    credentials = Credentials.from_env()
    parameters = TextGenerationParameters(
        decoding_method="greedy",
        max_new_tokens=1600,
        temperature=0.05,
        # Stop as soon as the model starts a new "INPUT" section.
        stop_sequences=["\nINPUT"],
        include_stop_sequence=False,
    )
    client = Client(credentials=credentials)
    responses = list(
        client.text.generation.create(
            model_id=model_id,
            inputs=[prompt],
            parameters=parameters,
        )
    )
    generated_text = responses[0].results[0].generated_text

    # Discard everything from the first "Utterance:" onwards.
    # NOTE(review): presumably this guards against the model starting a new
    # example of its own — confirm against the format of the dataset inputs.
    response = generated_text.split("Utterance:")[0].strip()

    # `partition` never raises: when the model did not emit "OUTPUT:" (for
    # instance because generation stopped mid-reasoning) the marker comes back
    # empty. Indexing the result of `split` here used to raise IndexError and
    # abort the whole run on a single bad completion.
    _, marker, code = response.partition("OUTPUT:")
    if not marker:
        return ""
    return code.strip()

jsonlines_data_path = '../jsonlines_data'

# Read the whole training split once, then derive the per-domain pools from it.
with jsonlines.open(f"{jsonlines_data_path}/train.jsonl") as reader:
    all_examples = list(reader)

spotify_all_examples = [ex for ex in all_examples if ex["domain"] == "spotify"]
tmdb_all_examples = [ex for ex in all_examples if ex["domain"] == "tmdb"]

# Positions (within each domain pool) of the three examples used as few-shot
# demonstrations in the prompt templates.
_FEW_SHOT_POSITIONS = (0, 16, 25)
spotify_examples = [spotify_all_examples[pos] for pos in _FEW_SHOT_POSITIONS]
tmdb_examples = [tmdb_all_examples[pos] for pos in _FEW_SHOT_POSITIONS]


# Few-shot prompt for the Spotify domain.
#
# This is an f-string evaluated once, when this statement runs: the "input"
# and "output" fields of the three entries in `spotify_examples` are baked
# into the text at that point. `<<input>>` is NOT an f-string field; it is a
# plain placeholder that is substituted later with `str.replace` for each test
# sample.
#
# Each demonstration is laid out as INPUT, a series of hand-written
# Thought / Observation / Action steps, and finally OUTPUT: followed by the
# example's expected code. The prompt ends on an open "Thought:" so the model
# continues in the same format; `get_bam_response` keeps only what follows the
# "OUTPUT:" marker in the completion.
#
# The `\"` escapes inside the literal render as plain double quotes.
spotify_prompt_template = f"""
You are a python code expert. Generate python code as per the user utterance as shown below:

INPUT:

{spotify_examples[0]["input"]}

Thought: 
- call function GET_me_player_currently_playing
- If response_obj is the response of GET_me_player_currently_playing, looking at the GraphQL schema of GET_me_player_currently_playing I must use response_obj.item.id as the `id` parameter of GET_tracks__id_

Observation: 
response_obj = GET_me_player_currently_playing()

Action: 
Continue

Thought: 
- call function GET_tracks__id_
- If response_obj1 is the response of GET_tracks__id_, looking at the GraphQL schema of GET_tracks__id_ I must use response_obj1.artists[0].id as the `id` parameter of PUT_me_following.

Observation: 
response_obj = GET_me_player_currently_playing()
response_obj1 = GET_tracks__id_(id=response_obj.item.id)

Action: 
Continue

Thought: 
- call function PUT_me_following
- The `type` (string) parameter of PUT_me_following must be "artist" as specified by the function signature of PUT_me_following

Observation: 
response_obj = GET_me_player_currently_playing()
response_obj1 = GET_tracks__id_(id=response_obj.item.id)
PUT_me_following(type="artist", ids=response_obj1.artists[0].id)

Action: 
Finish

OUTPUT:

{spotify_examples[0]["output"]}

INPUT:

{spotify_examples[1]["input"]}

Thought: 
- call function GET_me_following
- The parameter `type` (string) of GET_me_following must be "artist" from the user utterance

Observation: 
response_obj = GET_me_following(type=\"artist\")

Action: 
Continue

Thought: 
- call function GET_artists__id__top_tracks
- If response_obj is the response of GET_me_following, looking at the GraphQL schema of GET_me_following I must use response_obj.artists.items[0].id as the `id` (integer) parameter of GET_artists__id__top_tracks

Observation: 
response_obj = GET_me_following(type=\"artist\")
response_obj1 = GET_artists__id__top_tracks(id=response_obj.artists.items[0].id)

Action: 
Continue

Thought: 
- call function PUT_me_tracks
- If response_obj1 is the response of GET_artists__id__top_tracks, since the type of the `ids` parameter of PUT_me_tracks is string, looking at the GraphQL schema of GET_artists__id__top_tracks I must use the python code ",".join([track.id for track in response_obj1.tracks]) as the `ids` (string) parameter of PUT_me_tracks

Observation: 
response_obj = GET_me_following(type=\"artist\")
response_obj1 = GET_artists__id__top_tracks(id=response_obj.artists.items[0].id)
PUT_me_tracks(ids=\",\".join([track.id for track in response_obj1.tracks]))

Action: 
Finish

OUTPUT:

{spotify_examples[1]["output"]}

INPUT:

{spotify_examples[2]["input"]}

Thought: 
- call function GET_search
- The parameter `q` (string) of GET_search must be "quiet songs" from the user utterance
- The parameter `type` (string) of GET_search must be "track" from the user utterance

Observation: 
response_obj = GET_search(q="quiet songs", type="track")

Action: 
Continue

Thought: 
- call function PUT_me_player_play
- If response_obj is the response of GET_search, since the type of the `uris` parameter of PUT_me_player_play is array, looking at the GraphQL schema of GET_search I must use the python code [item.uri for item in response_obj.tracks.items] as the `uris` (array) parameter of PUT_me_player_play

Observation: 
response_obj = GET_search(q="quiet songs", type="track")
PUT_me_player_play(uris=[item.uri for item in response_obj.tracks.items])

Action: 
Finish

OUTPUT:

{spotify_examples[2]["output"]}

INPUT:

<<input>>

Thought:

"""

# Few-shot prompt for the TMDB domain.
#
# This is an f-string evaluated once, when this statement runs: the "input"
# and "output" fields of the three entries in `tmdb_examples` are baked into
# the text at that point. `<<input>>` is NOT an f-string field; it is a plain
# placeholder that is substituted later with `str.replace` for each test
# sample.
#
# Each demonstration is laid out as INPUT, a series of hand-written
# Thought / Observation / Action steps, and finally OUTPUT: followed by the
# example's expected code. The prompt ends on an open "Thought:" so the model
# continues in the same format.
#
# The reasoning steps of the second and third demonstrations were corrected so
# that every Thought describes the function it announces and every Observation
# agrees with the step that follows it (they previously contained text pasted
# from other demonstrations, including a Spotify function name).
tmdb_prompt_template = f"""
You are a python code expert. Generate python code as per the user utterance as shown below:

INPUT:

{tmdb_examples[0]["input"]}

Thought: 
- call function GET_search_movie
- The parameter `query` (string) of the first call to GET_search_movie must be "Avatar" from the user utterance

Observation: 
GET_search_movie(query="Avatar")

Action: 
Continue

Thought: 
- call function GET_search_movie
- The parameter `query` (string) of the second call to GET_search_movie must be "Avatar: The Way of Water" from the user utterance

Observation: 
GET_search_movie(query="Avatar")
GET_search_movie(query="Avatar: The Way of Water")

Action: 
Finish

OUTPUT:

{tmdb_examples[0]["output"]}

INPUT:

{tmdb_examples[1]["input"]}

Thought: 
- call function GET_search_movie
- The parameter `query` (string) of GET_search_movie must be "Django Unchained" from the user utterance

Observation: 
response_obj = GET_search_movie(query="Django Unchained")

Action: 
Continue

Thought: 
- call function GET_movie__movie_id__credits
- If response_obj is the response of GET_search_movie, looking at the GraphQL schema of GET_search_movie I must use response_obj.results[0].id as the `movie_id` (integer) parameter of GET_movie__movie_id__credits

Observation: 
response_obj = GET_search_movie(query="Django Unchained")
response_obj1 = GET_movie__movie_id__credits(movie_id=response_obj.results[0].id)

Action: 
Continue

Thought: 
- call function GET_person__person_id__movie_credits
- If response_obj1 is the response of GET_movie__movie_id__credits, since the type of the `person_id` parameter of GET_person__person_id__movie_credits is integer, looking at the GraphQL schema of GET_movie__movie_id__credits I must use the python code [person.id for person in response_obj1.crew if person.job.lower()=="director"][0] as the `person_id` (integer) parameter of GET_person__person_id__movie_credits

Observation: 
response_obj = GET_search_movie(query="Django Unchained")
response_obj1 = GET_movie__movie_id__credits(movie_id=response_obj.results[0].id)
GET_person__person_id__movie_credits(person_id=[person.id for person in response_obj1.crew if person.job.lower()=="director"][0])

Action: 
Finish

OUTPUT:

{tmdb_examples[1]["output"]}

INPUT:

{tmdb_examples[2]["input"]}

Thought: 
- call function GET_tv_popular
- The function GET_tv_popular does not require any parameters

Observation: 
response_obj = GET_tv_popular()

Action: 
Continue

Thought: 
- call function GET_tv__tv_id__credits
- If response_obj is the response of GET_tv_popular, since the type of the `tv_id` parameter of GET_tv__tv_id__credits is integer, looking at the GraphQL schema of GET_tv_popular I must use response_obj.results[0].id as the `tv_id` (integer) parameter of GET_tv__tv_id__credits

Observation: 
response_obj = GET_tv_popular()
GET_tv__tv_id__credits(tv_id=response_obj.results[0].id)

Action: 
Finish

OUTPUT:

{tmdb_examples[2]["output"]}

INPUT:

<<input>>

Thought:

"""

def _load_jsonl(path):
    """Return every record of the JSON Lines file at `path` as a list."""
    with jsonlines.open(path) as reader:
        return list(reader)


# Combined test split plus the two per-domain test splits.
test_samples = _load_jsonl(jsonlines_data_path + '/test.jsonl')
spotify_test_samples = _load_jsonl(jsonlines_data_path + '/test_spotify.jsonl')
tmdb_test_samples = _load_jsonl(jsonlines_data_path + '/test_tmdb.jsonl')


model_id = "codellama/codellama-34b-instruct"
output_folder = "codellama_output/"

# Opening a file for writing does not create missing parent directories, so
# make sure the output folder exists before the first (slow) model call.
os.makedirs(output_folder, exist_ok=True)


def _write_generations(samples, file_name, template=None):
    """Run the model on every sample and write one JSON record per line.

    Args:
        samples: Iterable of dicts with at least "input" and "output" keys
            (and a "domain" key when `template` is None).
        file_name: Name of the JSONL file created inside `output_folder`.
        template: Prompt template containing the `<<input>>` placeholder. When
            None, the template is chosen per sample: the Spotify template for
            samples whose "domain" is "spotify", the TMDB template otherwise.
    """
    with open(os.path.join(output_folder, file_name), "w") as out:
        for sample in tqdm(samples):
            if template is not None:
                chosen = template
            elif sample["domain"] == "spotify":
                chosen = spotify_prompt_template
            else:
                chosen = tmdb_prompt_template
            prompt = chosen.replace("<<input>>", sample["input"])
            code = get_bam_response(prompt, model_id)
            record = {
                "input": sample["input"],
                "expected_output": sample["output"],
                "generated_output": code,
            }
            print(json.dumps(record), file=out)


# Per-domain test sets always use their own domain's template.
_write_generations(spotify_test_samples, "output_spotify.jsonl", spotify_prompt_template)
_write_generations(tmdb_test_samples, "output_tmdb.jsonl", tmdb_prompt_template)
# The combined test set picks the template from each sample's "domain" field.
_write_generations(test_samples, "output.jsonl")