import json
from tqdm import tqdm
import re
import time
import os

import openai
import tensorflow as tf
import numpy as np

# Directory, relative to the working directory, that output paths are built
# under.
OUTPUT_DIR = 'outputs'

# Opening and closing delimiters wrapped around each input text and each
# rewritten text in a prompt.
D0, D1 = '{', '}'

# Prompt template that presents the rewrite as a description of the text.
TEMPLATE_GLM = ('Here is some text: {input}. '
                'Here is a rewrite of the text, which is {suggestion}.')

# Prompt template that phrases the rewrite as an instruction.
TEMPLATE_DIALOG = 'Here is some text: {input}. Rewrite it to be {suggestion}.'

# Default few-shot exemplars. Each entry is a three-element list:
#   [original text, suggestion describing the rewrite, rewritten text]
# `make_context` unpacks the entries in exactly this order.
# NOTE(review): 'la Siene' and 'peales' look like misspellings of 'la Seine'
# and 'peals'. They are part of the prompt text sent to the model, so they are
# left untouched here -- confirm before changing them.
PAIRS = [
    [
        'When the doctor asked Linda to take the medicine, he smiled and gave '
        'her a lollipop.', 'more scary',
        'When the doctor told Linda to take the medicine, there had been a '
        'malicious gleam in her eye that Linda didn\'t like at all.'
    ],
    [
        'they asked loudly, over the sound of the train.', 'more intense',
        'they yelled aggressively, over the clanging of the train.'
    ],
    [
        'When Mohammed left the theatre, it was already dark out',
        'more about the movie itself',
        'The movie was longer than Mohammed had expected, and despite the '
        'excellent ratings he was a bit disappointed when he left the theatre.'
    ],
    ['next to the path', 'about France', 'next to la Siene'],
    [
        'The man stood outside the grocery store, ringing the bell.',
        'about clowns',
        'The man stood outside the circus, holding a bunch of balloons.'
    ],
    ['the bell ringing', 'more flowery', 'the peales of the jangling bell'],
    [
        'against the tree', 'include the word "snow"',
        'against the snow-covered bark of the tree'
    ],
]

# Few-shot exemplars used when a sentiment direction is requested. Each entry
# maps 'pos' to a positive-sentiment sentence and 'neg' to a negative-sentiment
# counterpart; `make_context` picks the source and target sentence by key.
SENTIMENT_FEWSHOT_PAIRS = [
    {
        'pos': 'service has always been extremely friendly and attentive. have '
               'never had any complaints',
        'neg':
            'the service is cold and rude. i will complain to the manager next'
            ' time'
    },
    {
        'pos': 'charbroiled oysters are fairly priced',
        'neg': 'charbroiled oysters are way expensive',
    },
    {
        'pos': 'my apple ale paired perfectly with my meal',
        'neg': 'the apple ale did not go well with the meal',
    },
    {
        'pos':
            "the delicious seafood was worth the drive and I can't wait to go "
            'back',
        'neg': "the seafood was not worth the drive. I won't be going back",
    },
    {
        'pos': 'they were delicious and the bread it came with was perfect!',
        'neg': 'they were unflavorful and the bread it came with was soggy',
    },
]

def parse(txt, d0=D0, d1=D1):
  """Return the text enclosed by the first `d0` ... `d1` pair in `txt`.

  Args:
    txt: Text to search.
    d0: Opening delimiter.
    d1: Closing delimiter.

  Returns:
    The substring between the first `d0` and the first `d1` that follows it,
    or None when there is no opening delimiter, no closing delimiter after it,
    or nothing between the two.
  """
  open_idx = txt.find(d0)
  if open_idx < 0:
    # `find` reports a missing delimiter as -1; test it before adding the
    # delimiter length, otherwise the "not found" case looks like index 0.
    return None

  start_idx = open_idx + len(d0)
  # Search only after the opening delimiter so that a stray closing delimiter
  # earlier in the text does not hide a well-formed pair.
  end_idx = txt.find(d1, start_idx)
  if end_idx <= start_idx:
    # Covers both "not found" (-1) and an empty enclosed span.
    return None
  return txt[start_idx:end_idx]


def make_context(pairs,
                 template,
                 transformation,
                 fewshot_sentiment_direction=None,
                 d0=D0,
                 d1=D1):
  """Build the few-shot context as an alternating list of prompts and answers.

  Args:
    pairs: Few-shot exemplars. With `fewshot_sentiment_direction` set, each
      exemplar is indexed by the source and target keys taken from that
      string; otherwise each exemplar unpacks into
      (original, suggestion, rewritten).
    template: Format string with `{input}` and `{suggestion}` fields.
    transformation: Suggestion used for every exemplar in the sentiment case;
      ignored otherwise, where each exemplar carries its own suggestion.
    fewshot_sentiment_direction: String of the form `<source>to<target>`; it
      is split on 'to' and the first two pieces are used as keys.
    d0: Opening delimiter wrapped around each text.
    d1: Closing delimiter wrapped around each text.

  Returns:
    A list with two strings per exemplar: the filled-in template followed by
    the delimited target text.
  """

  def wrap(text):
    return d0 + text + d1

  context = []

  if not fewshot_sentiment_direction:
    # Regular exemplars: every entry brings its own suggestion.
    for orig, suggestion, edited in pairs:
      context.append(template.format(input=wrap(orig), suggestion=suggestion))
      context.append(wrap(edited))
    return context

  # Sentiment exemplars: the direction string names the keys to read from
  # and write to.
  pieces = fewshot_sentiment_direction.split('to')
  source_key = pieces[0]
  target_key = pieces[1]
  for exemplar in pairs:
    before = exemplar[source_key]
    after = exemplar[target_key]
    context.extend([
        template.format(input=wrap(before), suggestion=transformation),
        wrap(after),
    ])
  return context


def add_zeros(x):
  """Left-pad the string `x` with '0' characters to a width of five.

  Strings that are already five or more characters long come back unchanged.
  """
  return x.rjust(5, '0')


def read_dataset(input_path, n, batch_size=20, last_idx=None):
  """Read a slice of lines from a text file and group them into batches.

  Args:
    input_path: Path of the text file, one example per line.
    n: Exclusive end index of the slice of lines to keep.
    batch_size: Maximum number of lines per batch.
    last_idx: Inclusive start index of the slice; None starts at the first
      line.

  Returns:
    A list of batches, each a list of lines with surrounding whitespace
    stripped. The final batch may be shorter than `batch_size`.
  """
  # The context manager closes the file even if reading fails.
  with open(input_path, 'r') as reader:
    lines = reader.readlines()[last_idx:n]
  lines = [line.strip() for line in lines]
  return [
      lines[start:start + batch_size]
      for start in range(0, len(lines), batch_size)
  ]

def get_input(pairs,
              template,
              line,
              transformation,
              model_type,
              fewshot_sentiment_direction,
              is_zero_shot=False,
              d0=D0,
              d1=D1):
  """Build the model input for one line of text.

  Args:
    pairs: Few-shot exemplars forwarded to `make_context`.
    template: Format string with `{input}` and `{suggestion}` fields.
    line: Text to be rewritten.
    transformation: Suggestion describing the desired rewrite.
    model_type: Not used by this function; kept for interface compatibility.
    fewshot_sentiment_direction: Forwarded to `make_context`.
    is_zero_shot: When true, no few-shot context is prepended.
    d0: Opening delimiter wrapped around each text.
    d1: Closing delimiter wrapped around each text.

  Returns:
    A list of strings: the few-shot context (omitted when `is_zero_shot`)
    followed by the prompt for `line`.
  """
  prompt = template.format(input=d0 + line + d1, suggestion=transformation)

  if is_zero_shot:
    return [prompt]

  # Forward the delimiters so that the few-shot context is wrapped the same
  # way as the final prompt when a caller overrides them.
  context = make_context(pairs, template, transformation,
                         fewshot_sentiment_direction, d0=d0, d1=d1)
  return context + [prompt]


def combine_shards(input_pattern, num_shards, output_file):
  """Append the contents of every shard file to `output_file`, in shard order.

  Shard files are named `<input_pattern>-<i>-of-<num_shards>`, with both
  numbers zero-padded by `add_zeros`.

  Args:
    input_pattern: Common prefix of the shard file names.
    num_shards: Number of shards to read, indexed from 0.
    output_file: File the shard contents are appended to. It is opened in
      append mode, so existing content is kept.

  Raises:
    FileNotFoundError: If a shard file does not exist.
  """
  total = add_zeros(str(num_shards))
  with open(output_file, 'a') as writer:
    for i in tqdm(range(num_shards)):
      input_file = f'{input_pattern}-{add_zeros(str(i))}-of-{total}'
      # Close each shard as soon as it has been copied.
      with open(input_file, 'r') as reader:
        writer.writelines(reader)

  print(f'output at {output_file}')


def make_output_path(model_type: str,
                     transformation_name: str,
                     dataset: str,
                     fewshot_sentiment_direction=None,
                     is_zero_shot=False):
  """Return the extension-less output path for one experiment configuration.

  The path lives under `OUTPUT_DIR` and encodes the transformation name, the
  dataset and the model type, followed by `_fewshot` when
  `fewshot_sentiment_direction` is truthy and `_zeroshot` when `is_zero_shot`
  is truthy.
  """
  suffix = ''
  if fewshot_sentiment_direction:
    suffix += '_fewshot'
  if is_zero_shot:
    suffix += '_zeroshot'

  stem = '_'.join([
      f'style:{transformation_name}',
      f'dataset:{dataset}',
      f'model:{model_type}',
  ])
  return f'{OUTPUT_DIR}/{stem}{suffix}'


class ExperimentRunner():
  """Runs a style-transfer prompt over a dataset and saves the model outputs."""

  def compute_live(self, prompt, model_type, gpt3_key, gpt3_engine):
    """Query the model once and return the parsed completion.

    Args:
      prompt: Prompt forwarded unchanged to `openai.Completion.create`.
      model_type: Only 'gpt3' triggers a request.
      gpt3_key: Not used here; `run_experiment` installs the key on the
        `openai` module before calling this method.
      gpt3_engine: Engine name forwarded to the API.

    Returns:
      A one-element list holding the result of `parse` on the first
      completion (which may be None), or None when `model_type` is not
      'gpt3'.
    """
    # Compare by value: `is` tests object identity, and two equal strings are
    # not guaranteed to be the same object.
    if model_type != 'gpt3':
      return None

    # NOTE(review): `run_experiment` passes `prompt` as a list of strings
    # (see `get_input`). The legacy Completions endpoint documents a list
    # prompt as a batch of independent prompts, in which case `choices[0]` is
    # the completion of the first list element rather than of the final
    # prompt. Confirm, and if so join the pieces into one string first.
    response = openai.Completion.create(
        engine=gpt3_engine,
        prompt=prompt,
        max_tokens=64,
        logprobs=0,
        temperature=1.0,
        top_p=0.6,
        echo=False,
        n=1)
    output = response["choices"][0]["text"].strip()
    return [parse(output)]

  def run_experiment(self,
                     input_path: str,
                     model_type: str,
                     transformation: str,
                     transformation_name: str,
                     dataset: str,
                     num_examples=50,
                     fewshot_sentiment_direction=None,
                     is_zero_shot=False,
                     gpt3_key=None,
                     gpt3_engine=None):
    """Run one experiment and save its outputs as tab-separated rows.

    Every selected line of `input_path` is turned into a prompt and sent to
    the model. Each output row has four tab-separated fields: the index of the
    input line, the input line, `transformation`, and the parsed model output.
    Rows are written one shard file per batch, and the shards are then
    appended to `<base_path>_results.tsv`, where `<base_path>` comes from
    `make_output_path`.

    If the results file already exists, processing resumes with the line
    after the index stored in its last row, and returns immediately when that
    is at or beyond `num_examples`.

    Args:
      input_path: Text file with one input per line.
      model_type: Must be 'gpt3'.
      transformation: Suggestion inserted into the prompt template.
      transformation_name: Name used in the output path.
      dataset: Dataset name used in the output path.
      num_examples: Exclusive end index of the input lines to process.
      fewshot_sentiment_direction: When truthy, selects the sentiment few-shot
        exemplars and is forwarded to `get_input`.
      is_zero_shot: When true, prompts carry no few-shot context.
      gpt3_key: OpenAI API key.
      gpt3_engine: OpenAI engine name.

    Raises:
      NotImplementedError: If `model_type` is not 'gpt3'.
      ValueError: If `gpt3_key` or `gpt3_engine` is missing.
      FileNotFoundError: If some shard could not be produced after all
        retries.
    """

    print('----------------------------------------')
    print('EXPERIMENT CONFIGURATION')
    print('----------------------------------------')
    print('input_path: ', input_path)
    print('model_type: ', model_type)
    print('transformation: ', transformation)
    print('transformation_name: ', transformation_name)
    print('dataset: ', dataset)
    print('num_examples: ', num_examples)
    print('fewshot_sentiment_direction: ', fewshot_sentiment_direction)
    print('is_zero_shot: ', is_zero_shot)
    # Never echo the secret itself; logs are routinely shared.
    print('gpt3_key: ', '<provided>' if gpt3_key else '<missing>')
    print('gpt3_engine: ', gpt3_engine)
    print('----------------------------------------')

    # Compare strings by value, not identity.
    if model_type != 'gpt3':
      raise NotImplementedError('Unfortunately, we do not support other model inferrences at this time. The other models (LLM and LLM-dialog) used in the paper will be described in detail in an upcom-ing paper, and added to this code when available')
    if not gpt3_key or not gpt3_engine:
      raise ValueError(
          'Please provide both a key and an engine to use the GPT3 model.')
    openai.api_key = gpt3_key

    template = TEMPLATE_GLM if model_type == 'glm' else TEMPLATE_DIALOG
    pairs = SENTIMENT_FEWSHOT_PAIRS if fewshot_sentiment_direction else PAIRS

    base_path = make_output_path(model_type, transformation_name, dataset,
                                 fewshot_sentiment_direction, is_zero_shot)
    full_output_file = f'{base_path}_results.tsv'

    # Create the output directory up front; otherwise every shard write fails
    # and is retried with long sleeps.
    os.makedirs(os.path.dirname(base_path) or '.', exist_ok=True)

    # Resume after the last saved row if a results file is already present.
    start_idx = 0
    if os.path.exists(full_output_file):
      print(f'Experiment has already been run: {full_output_file}')
      with open(full_output_file) as f:
        saved_lines = f.readlines()
      if saved_lines:
        last_idx = int(saved_lines[-1].split('\t')[0])
        print(f'last saved datapoint index was {last_idx}')
        # `last_idx` is already saved, so the next line to process is the one
        # after it.
        start_idx = last_idx + 1
      if start_idx >= num_examples:
        return

    lines_batches = read_dataset(input_path, num_examples, last_idx=start_idx)
    num_lines_batches = len(lines_batches)
    batch_size = len(lines_batches[0]) if lines_batches else 0

    # NOTE(review): shard names depend only on the batch number and the batch
    # count, so a resumed run with the same batch count as an earlier run
    # would find the earlier shards and skip its own batches. Confirm whether
    # stale shards are cleaned up elsewhere.
    total_shards = add_zeros(str(num_lines_batches))
    shard_paths = [
        f'{base_path}.tsv-{add_zeros(str(batch_num))}-of-{total_shards}'
        for batch_num in range(num_lines_batches)
    ]

    # The batch list is walked three times; shards that already exist are
    # skipped, so later passes only re-attempt batches that previously ran out
    # of retries.
    for _ in range(3):
      for batch_num, (lines, output_path) in enumerate(
          zip(lines_batches, shard_paths)):
        if os.path.exists(output_path):
          print(f'skipping {output_path}')
          continue

        first_idx = start_idx + batch_num * batch_size
        attempts = 0
        while attempts < 10:
          try:
            # Build every row before touching the file so that a failure
            # never leaves a partial shard that later passes would skip.
            output_lines = []
            for i, line in enumerate(lines):
              model_input = get_input(pairs, template, line, transformation,
                                      model_type, fewshot_sentiment_direction,
                                      is_zero_shot)
              responses = self.compute_live(model_input, model_type, gpt3_key,
                                            gpt3_engine)
              for response in responses:
                if response is None:
                  # No delimited text in the completion: retry the batch.
                  raise ValueError(
                      f'could not parse the response for input {first_idx + i}')
                output_lines.append('\t'.join([
                    str(first_idx + i), line, transformation,
                    response.strip()
                ]) + '\n')

            with open(output_path, 'w') as writer:
              writer.writelines(output_lines)
            print(f'done!  \t{output_path}')
            time.sleep(3)
            break

          except Exception as e:
            # Best-effort retry with a growing pause between attempts.
            print(e)
            attempts += 1
            print('')
            print('%d attempts' % attempts)
            time.sleep(10 + 10 * attempts)

    # Fail before appending anything if a shard is still missing, so that the
    # results file is never left half-combined.
    missing = [path for path in shard_paths if not os.path.exists(path)]
    if missing:
      raise FileNotFoundError(
          f'{len(missing)} of {num_lines_batches} shards were never written, '
          f'e.g. {missing[0]}')

    # Combine the shards
    input_pattern = f'{base_path}.tsv'
    combine_shards(input_pattern, num_lines_batches, full_output_file)