import ast
import json
import time
import gym
import requests
from bs4 import BeautifulSoup

# import wikipedia

def clean_str(p):
    """Interpret backslash escape sequences in scraped text, keeping UTF-8 intact.

    The string is encoded to UTF-8 and decoded with the ``unicode-escape``
    codec, which turns literal sequences such as ``\\n`` into the characters
    they denote. That codec reads non-escape bytes as Latin-1, so the result
    is re-encoded as Latin-1 and decoded as UTF-8 to restore non-ASCII
    characters (e.g. ``"café"`` round-trips unchanged).

    Args:
        p: Text extracted from a web page.

    Returns:
        The unescaped text. If the round-trip fails, ``p`` is returned
        unchanged. Failure cases include a stray backslash such as
        ``C:\\xyz``, a trailing ``\\``, or a ``\\uXXXX`` escape above U+00FF.
    """
    try:
        return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
    except (UnicodeDecodeError, UnicodeEncodeError):
        # Best effort: ordinary page text containing backslashes must not
        # abort the whole search.
        return p

def get_page_obs(page, target_true=None, num_sent=5, num_word=None):
    """Build a short text observation from the beginning of a page.

    The page is split into non-empty lines (paragraphs), and each paragraph
    is split into sentences on ``". "``. Every sentence is normalized to end
    with exactly one period.

    Args:
        page: Page text, one paragraph per line.
        target_true: Optional string to scrub from the observation. Both the
            exact string and its lowercase form are removed.
        num_sent: Number of leading sentences to keep when ``num_word`` is
            not given (or is 0).
        num_word: If truthy, keep whole leading sentences until the
            observation contains more than ``num_word`` space-separated
            words, or the sentences run out. ``num_sent`` is then ignored.

    Returns:
        The observation string. It is empty if the page has no text.
    """
    paragraphs = [p.strip() for p in page.split("\n") if p.strip()]

    sentences = []
    for paragraph in paragraphs:
        for piece in paragraph.split(". "):
            piece = piece.strip()
            if piece:
                # The last sentence of a paragraph usually keeps its own
                # period; don't turn it into "..".
                sentences.append(piece if piece.endswith(".") else piece + ".")

    if num_word:
        chosen = []
        word_count = 0
        for sentence in sentences:
            chosen.append(sentence)
            # Summing per-sentence counts equals counting the space-joined
            # result, because joining adds exactly one separator per boundary.
            word_count += len(sentence.split(" "))
            if word_count > num_word:
                break
        # If the page is shorter than num_word, this is simply every sentence
        # (and "" for an empty page).
        profix = " ".join(chosen)
    else:
        profix = " ".join(sentences[:num_sent])

    if target_true:
        profix = profix.replace(target_true, "")
        profix = profix.replace(target_true.lower(), "")
    return profix

def search_step(entity, target_true=None, num_sent=5, target_new=None, num_word=None, timeout=30):
    """Search English Wikipedia for ``entity`` and return a text observation.

    Args:
        entity: Search query, typically an article title.
        target_true: Passed through to ``get_page_obs``, which scrubs it (and
            its lowercase form) from the observation.
        num_sent: Passed through to ``get_page_obs``.
        target_new: Currently unused; kept for interface compatibility.
        num_word: Passed through to ``get_page_obs``.
        timeout: Seconds to wait for Wikipedia before ``requests`` raises a
            timeout error. ``None`` waits indefinitely.

    Returns:
        One of three strings:
        - ``"Could not find ... Similar: [...]."`` when the search lands on a
          results list. Up to five result titles are included.
        - ``""`` when the page looks like a disambiguation page (any
          paragraph or list contains ``"may refer to:"``).
        - Otherwise, the observation built by ``get_page_obs`` from the
          page's ``<p>`` and ``<ul>`` text.
    """
    # Let requests URL-encode the query. A hand-built query string breaks on
    # characters such as "&", "#" or "?" in the entity name. Spaces still
    # become "+".
    response_text = requests.get(
        "https://en.wikipedia.org/w/index.php",
        params={"search": entity},
        timeout=timeout,
    ).text
    soup = BeautifulSoup(response_text, features="html.parser")

    result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
    if result_divs:
        # No exact article match: report the closest search results instead.
        result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
        return f"Could not find {entity}. Similar: {result_titles[:5]}."

    page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
    if any("may refer to:" in p for p in page):
        # Disambiguation page: there is no single article to summarize.
        print("may refer to:", entity)
        return ""

    # Keep only fragments longer than two words, one per line. Shorter
    # fragments presumably are navigation or list labels rather than prose.
    pages = "\n".join(clean_str(p) for p in page if len(p.split(" ")) > 2)
    return get_page_obs(pages, target_true, num_sent, num_word)

# obs = search_step('Singled Out', "MTV")
# print(obs)