import os
import sys
import math
import struct
import yaml
import time
import random
import argparse
import logging
import tqdm
import pickle

import numpy as np
#import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

def word2ids(word, dataset) -> list:
    """
        Map a tokenised word to its list of ids.

        word: iterable of subword tokens, e.g. ['w', 'o', 'r', 'd', '</s>']
        dataset: object whose `vocab` attribute is indexed by subword token

        Returns one id per subword, in the same order as `word`.
    """
    return [dataset.vocab[token] for token in word]

def ids2word(ids, dataset) -> str:
    """
        Convert a sequence of ids back into a string.

        ids: iterable of ids, each a valid index/key into `dataset.id2word`
        dataset: object whose `id2word` attribute maps an id to its string piece

        Returns the pieces concatenated in order (empty string for empty `ids`).
    """
    # str.join builds the result in one pass instead of repeated `+=`,
    # and the loop variable no longer shadows the builtin `id`.
    return "".join(dataset.id2word[token_id] for token_id in ids)

def read_lines(file_path):
    """
        Read a text file and return its lines as a list.

        Each element keeps its trailing newline, exactly as stored in the file.
    """
    with open(file_path, 'r') as handle:
        # Iterating a text file yields the same lines as readlines().
        return list(handle)

def filter_lines(lines):
    """
        Select the lines worth keeping from `lines`.

        A line is kept when, after stripping surrounding whitespace, it
        has at least two whitespace-separated tokens and does not begin
        with '= ='. Kept lines are returned unmodified (not stripped),
        in their original order.
    """
    def _is_kept(raw):
        text = raw.strip()
        if text.startswith('= ='):
            return False
        return len(text.split()) >= 2

    return [raw for raw in lines if _is_kept(raw)]

def extract_frequency_table_from_lines(lines):
    """
        Count whitespace-separated words across `lines`.

        lines: iterable of strings

        Returns a list of (word, count) tuples sorted by count, highest
        first. Words with equal counts keep the order in which they were
        first seen. An empty input yields an empty list.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    # str.split() with no argument already ignores leading/trailing whitespace.
    counts = Counter(word for line in lines for word in line.split())
    # most_common() sorts by count descending and is stable for ties.
    return counts.most_common()

def read_frequency_table(file_path):
    """
        Load a frequency table from a tab-separated text file.

        Each non-blank line must look like ``word<TAB>count``. Blank lines
        (for example a trailing empty line) are skipped.

        Returns a dict mapping word -> int count. If a word appears more
        than once, the last occurrence wins.

        Raises ValueError when a non-blank line does not split into exactly
        two tab-separated fields, or when the count is not an integer.
    """
    frequency_table = {}
    with open(file_path, 'r') as f:
        # Stream the file line by line instead of loading it all at once.
        for line in f:
            entry = line.strip()
            if not entry:
                continue
            word, frequency = entry.split('\t')
            frequency_table[word] = int(frequency)
    return frequency_table

def extract_from_dataset(dataset, file_path):
    """
        Dump the 'text' field of every record to `file_path`, one per line.

        The 'train', 'validation' and 'test' splits of `dataset` are
        written in that order. Each text is stripped of surrounding
        whitespace and followed by a single newline. The output file is
        overwritten if it already exists.
    """
    split_names = ('train', 'validation', 'test')
    with open(file_path, 'w') as out:
        for name in split_names:
            records = dataset[name]
            # Index-based access is kept because the record container's
            # type is not visible here.
            for idx in range(len(records)):
                text = records[idx]['text'].strip()
                out.write(text + '\n')


def save_frequency_table(frequency_table, file_path):
    """
        Write a frequency table to `file_path` as ``word<TAB>count`` lines.

        frequency_table: either an iterable of (word, count) pairs, or a
            dict mapping word -> count (its items are written in the
            dict's iteration order).
        file_path: destination file; overwritten if it already exists.
    """
    # Iterating a dict directly yields only its keys, which would be
    # mis-unpacked as (word, freq); use its items instead.
    if isinstance(frequency_table, dict):
        pairs = frequency_table.items()
    else:
        pairs = frequency_table
    with open(file_path, 'w') as f:
        f.writelines(f"{word}\t{freq}\n" for word, freq in pairs)