# -*- coding: utf-8 -*-
"""
Created on Fri Mar  6 10:34:31 2015

@author: xmb
"""
import os

import gensim, numpy, logging, multiprocessing

# Module-level logger. It is named after the module rather than being the root
# logger itself; records still propagate to the root handlers set up by
# logging.basicConfig().
logger = logging.getLogger(__name__)



def sentence_generator( filelist, startPadding = 2, stopPadding = 1 ):
    """
    A generator which yields sentences as lists of words.

    Each file is opened in binary mode and closed as soon as it has been read
    completely, or when the generator is closed before finishing it.

    Parameters
    ----------
        filelist : list of str
            the file(s) from which to read the sentences, assuming one line is
            one sentence
        startPadding : int
            how many '<s>' to be put before a sentence
        stopPadding : int
            how many '</s>' to be put after a sentence

    Yields
    ------
        l : list of str
            a sentence split on whitespace into a list of words, preceded by
            `startPadding` '<s>' tokens and followed by `stopPadding` '</s>'
            tokens. A fresh list is yielded each time.
    """
    logger.debug('%-8s %s' % ('Enter', 'sentence_generator'))
    # The padding never changes, so build it once instead of once per line.
    # `head + ... + tail` below still creates a new list for every sentence.
    head = ['<s>'] * startPadding
    tail = ['</s>'] * stopPadding
    for f in filelist:
        # `with` releases the handle; previously every file stayed open.
        # NOTE(review): binary mode yields byte strings on Python 3, while the
        # padding tokens are str. That is fine on Python 2, which this module targets.
        with open( f, 'rb' ) as fp:
            for l in fp:
                # split() with no argument already drops surrounding whitespace.
                yield head + l.split() + tail
    logger.debug('%-8s %s' % ('Exit', 'sentence_generator'))



def train_using_skip_gram( files,
                          startPadding = 2,
                          stopPadding = 1,
                          min_count = 1,
                          workers = 22,
                          size = 300,
                          window = 6,
                          bin2save = None,
                          txt2question = None ):
    """
    Return word embeddings trained with skip-gram word2vec on the corpus in 'files'.

    If 'bin2save' names an existing file, the vectors are loaded from it
    instead of being retrained. Otherwise a model is trained from scratch and,
    when 'bin2save' is given, saved to it in word2vec binary format.

    Parameters
    ----------
        files : iterable of str
            the file names that make up the corpus, one sentence per line
        startPadding : int
            how many '<s>' tokens to put before each sentence
        stopPadding : int
            how many '</s>' tokens to put after each sentence
        min_count : int
            minimum occurrence of a word for it to be kept in the vocabulary
        workers : int
            number of worker threads used for training. It is capped at the
            number of CPUs, and a falsy value means "use every CPU".
        size : int
            the dimension of the word embeddings
        window : int
            the length of the context window
        bin2save : str
            path of the word-vector file (word2vec binary format). It is
            loaded if it exists; otherwise it is written after training.
        txt2question : str
            the filename of similarity questions passed to model.accuracy

    Returns
    -------
        vocab_matrix : numpy.ndarray of float32, shape (len(idx2word), dim)
            row i holds the vector of idx2word[i]. Rows of words the model has
            no vector for (e.g. an added '<unk>') are all zeros.
        idx2word : list of str
            the vocabulary, plus '<unk>', in sorted order
        word2idx : dict of (str,int) pairs
            maps each word of idx2word to its index
    """
    logger.debug('%-8s %s' % ('Enter', 'train_using_skip_gram'))

    if not bin2save or not os.path.isfile(bin2save):
        logger.info( 'Trained vectors not found. Start to train from sketch. ' )
        # Honour the caller's 'workers', but never start more threads than
        # there are CPUs. The argument used to be silently replaced by
        # cpu_count().
        n_cpu = multiprocessing.cpu_count()
        workers = min( workers, n_cpu ) if workers else n_cpu
        # Request skip-gram explicitly (sg = 1): gensim's default for 'sg' has
        # differed between releases.
        model = gensim.models.Word2Vec( sg = 1,
                                        min_count = min_count, workers = workers,
                                        size = size, window = window )
        # The corpus is streamed twice: once to build the vocabulary and once
        # to train.
        model.build_vocab( sentence_generator( files, startPadding, stopPadding ) )
        model.train( sentence_generator( files, startPadding, stopPadding ) )

        if bin2save:
            model.save_word2vec_format( bin2save, binary = True )
    else:
        logger.info( 'Trained vectors found. Load it instead of training. ' )
        model = gensim.models.Word2Vec.load_word2vec_format( bin2save, binary = True )

    if txt2question:
        model.accuracy( txt2question )

    # Sorted vocabulary that always contains '<unk>', presumably so unknown
    # words have an index to map to.
    vocab = set( model.vocab )
    vocab.add( '<unk>' )
    idx2word = sorted( vocab )
    word2idx = dict( (w, i) for i, w in enumerate( idx2word ) )

    # zeros, not an uninitialised ndarray: rows for words the model lacks
    # (e.g. '<unk>') used to contain arbitrary memory contents.
    vocab_matrix = numpy.zeros( (len( idx2word ), model.layer1_size), dtype = numpy.float32 )
    for i, w in enumerate( idx2word ):
        if w in model:
            vocab_matrix[i] = model[w]

    logger.debug('%-8s %s' % ('Exit', 'train_using_skip_gram'))
    return vocab_matrix, idx2word, word2idx




def convert(inname, outname, vocab, unk = None):
    """
    Converts sentences to comma-separated word indices (one sentence per line).

    Parameters
    ----------
        inname : str
            a file of sentences, one sentence per line, words separated by
            whitespace.
        outname : str
            each line of outname corresponds to a line of sentence in 'inname',
            while words are mapped into numbers and separated by ','
        vocab : dict
            key-value pairs where keys are words and values are integers
        unk : str, optional
            if given, words missing from 'vocab' are written as vocab[unk].
            If None (the default), a missing word raises KeyError, as before.

    Raises
    ------
        KeyError
            if a word is not in 'vocab' and 'unk' is None, or if 'unk'
            itself is not in 'vocab'.
    """
    logger.debug('%-8s %s' % ('Enter', 'convert'))
    if unk is None:
        lookup = vocab.__getitem__
    else:
        # Resolve the fallback before touching any file, so a bad 'unk'
        # fails fast.
        unk_id = vocab[unk]
        lookup = lambda w: vocab.get( w, unk_id )
    # `with` closes (and flushes) both files even if a KeyError escapes
    # mid-file; previously both handles leaked in that case.
    with open( inname, 'rb' ) as infile, open( outname, 'wb' ) as outfile:
        for l in infile:
            outfile.write( ','.join( str( lookup( w ) ) for w in l.split() ) + '\n' )
    logger.debug('%-8s %s' % ('Exit', 'convert'))
    


if __name__ == '__main__':
    logging.basicConfig(format = '%(asctime)s : %(levelname)s : %(message)s',
                        level = logging.INFO)

    # Input corpus files share one prefix; all outputs share another.
    corpus = './data/enwik9.clean.80000.format'
    out_prefix = './data/wiki.8k'

    vocab_matrix, idx2word, word2idx = \
        train_using_skip_gram ( files = [ corpus + '.train' ],
                                min_count = 1,
                                size = 300,
                                bin2save = out_prefix + '.vector300',
                                txt2question = None )

    # to save the word matrix
    numpy.savetxt( out_prefix + '.vector300.csv', vocab_matrix, delimiter = ',' )

    # to save the index mapping: one "word,index" line per word, written in
    # index order so the file is deterministic. `with` closes it even on error.
    with open( out_prefix + '.map.csv', 'wb' ) as fp:
        for w, i in sorted( word2idx.items(), key = lambda kv: kv[1] ):
            fp.write( w + ',' + str( i ) + '\n' )

    # Convert each split to index CSVs (note: the '.dev' split is named 'valid').
    for src, dst in ( ('.train', '.train.csv'),
                      ('.dev', '.valid.csv'),
                      ('.test', '.test.csv') ):
        convert( corpus + src, out_prefix + dst, word2idx )
