import os
import sys
import codecs
import random
import fasttext
import subprocess

def run_command(cmd):
  """Run *cmd* through the shell and block until it has finished.

  Args:
    cmd: The complete command line as a single string; it is handed to the
      shell unchanged (``shell=True``).

  Returns:
    None. A non-zero exit status is reported on stderr instead of being
    raised, so a failing command does not abort the caller.
  """
  # subprocess.call waits through the Popen object itself, rather than
  # reaping the child behind its back with os.waitpid(p.pid, 0).
  status = subprocess.call(cmd, shell=True)
  if status != 0:
    sys.stderr.write("Command exited with status %s: %s\n" % (status, cmd))
  return


def writeFastTextCorpus(data, filename):
  """Write the samples in *data* to *filename* in random order.

  Args:
    data: Iterable of text samples, one training line each. It is not
      modified; the shuffle is done on a copy.
    filename: Path of the corpus file. It is created or overwritten and
      encoded as UTF-8.
  """
  # Shuffle a private copy so the caller's list keeps its order.
  shuffled = list(data)
  random.shuffle(shuffled)
  with codecs.open(filename, 'w', 'utf-8') as corpusF:
    for sample in shuffled:
      # A sample without a line terminator (e.g. the last line of an input
      # file) would otherwise be glued to whichever sample follows it.
      if sample and not sample.endswith('\n'):
        sample += '\n'
      corpusF.write(sample)


def trainModel(corpusFile, modelfile, testfile, tooldir, dim="10", lr="0.1",
               epoch="5", minCount="1", word_ngrams="2", bucket="10000000",
               loss="softmax"):
  """Train a supervised fastText model and report its test precision.

  The ``fasttext`` executable found in *tooldir* is run in ``supervised``
  mode, the resulting ``modelfile + '.bin'`` is loaded through the Python
  binding, and the model is evaluated on *testfile*.

  Args:
    corpusFile: Path of the training corpus passed as ``-input``.
    modelfile: Output path prefix passed as ``-output``.
    testfile: Path of the file handed to ``classifier.test``.
    tooldir: Directory that contains the ``fasttext`` executable.
    dim, lr, epoch, minCount, word_ngrams, bucket, loss: Values for the
      ``-dim``, ``-lr``, ``-epoch``, ``-minCount``, ``-wordNgrams``,
      ``-bucket`` and ``-loss`` command-line options. Strings or numbers
      are accepted; the defaults are the values that used to be hard-coded.

  Returns:
    The ``precision`` attribute of the test result.
  """
  # Option order mirrors the original command line.
  options = [
      ("-input", corpusFile),
      ("-output", modelfile),
      ("-dim", dim),
      ("-lr", lr),
      ("-epoch", epoch),
      ("-minCount", minCount),
      ("-bucket", bucket),
      ("-thread", 1),
      ("-minn", 0),
      ("-maxn", 0),
      ("-loss", loss),
      ("-wordNgrams", word_ngrams),
  ]

  # Train the classifier with the command-line tool
  run_command(tooldir + '/fasttext supervised '
              + ' '.join(flag + ' ' + str(value) for flag, value in options))

  # Load the model that the command-line tool just wrote
  classifier = fasttext.load_model(modelfile + '.bin', label_prefix='__label__')

  # Test the classifier
  result = classifier.test(testfile)
  print("Current round model precison:  %s  with model:  %s"
        % (result.precision, os.path.basename(modelfile)))
  return result.precision


def mixAutoLabeledBaseline(workdir, tooldir, iterations=50):
  """Run the "mixed auto-labeled" baseline and write an accuracy report.

  A model is first trained on the manually labeled samples only. Then, for
  each iteration, ``enrichingCorpus`` labels the unlabeled samples with the
  current model, and the model is retrained on the labeled samples plus
  those auto-labeled ones. The precision of every model is collected.

  Args:
    workdir: Working directory. The corpus file ``TrainCorpus``, the model
      and the report are written here; the input files
      ``yahoo_answers.labeled.train``, ``yahoo_answers.unlabeled.train``
      and ``yahoo_answers.test`` are read from its parent directory.
    tooldir: Directory that contains the ``fasttext`` executable.
    iterations: Number of enrich-and-retrain rounds after the initial
      model (default 50, the previously hard-coded value).
  """
  # load manually labelled data as baseline
  with codecs.open(workdir + '/../yahoo_answers.labeled.train', 'rb', 'utf-8') as trainf:
    initLabelSamples = list(trainf)

  # load unlabeled data
  unLabeledData = []
  with codecs.open(workdir + '/../yahoo_answers.unlabeled.train', 'r', 'utf-8') as rawf:
    for line in rawf:
      # Everything up to the first comma is the label: remove and forget it.
      content = line.strip().partition(',')[2]
      if content.strip():
        unLabeledData.append(content)

  print('Total labeled data samples:  %d  , unlabeledData:  %d'
        % (len(initLabelSamples), len(unLabeledData)))

  trainFile = os.path.join(workdir, 'TrainCorpus')
  model = os.path.join(workdir, 'model')
  testfile = os.path.join(workdir, '../yahoo_answers.test')
  modelPerformance = []

  print('Initializaing model  ...')
  # Pass a copy: the corpus writer may shuffle the list it is given, and the
  # labeled baseline is reused in every iteration below.
  writeFastTextCorpus(list(initLabelSamples), trainFile)
  modelPerformance.append(trainModel(trainFile, model, testfile, tooldir))

  for itr in range(iterations):
    print('\n\n\n  iterating # %d  ....' % (itr + 1))
    # enrich corpus A
    viewASamples = enrichingCorpus(model, unLabeledData)

    trainCorpus = initLabelSamples + viewASamples
    print(' Updating  model  ...')
    writeFastTextCorpus(trainCorpus, trainFile)
    modelPerformance.append(trainModel(trainFile, model, testfile, tooldir))

  # summarize performance
  print('Interation Performances of Model are:  %s' % str(modelPerformance))
  reportPath = os.path.normpath(workdir + '/AccuracyReport.mixedAutoLabeldBaseline.txt')
  with codecs.open(reportPath, 'w', 'utf-8') as outf:
    outf.write("\n\nRaw Mixed Autolabled Baseline is:\n " + str(modelPerformance) + '\n\n\n')



def enrichingCorpus(modelFile, unlabeledData):
  """Label every unlabeled sample with the current model's top prediction.

  Note that no confidence threshold is applied: every non-blank sample for
  which the classifier returns a prediction is kept.

  Args:
    modelFile: Model path prefix; ``modelFile + '.bin'`` is loaded.
    unlabeledData: Iterable of raw text samples without labels.

  Returns:
    A list of training lines of the form
    ``'__label__' + label + '\\t , ' + sample + '\\n'``.
  """
  classifier = fasttext.load_model(modelFile + '.bin', label_prefix='__label__')
  print('enriching training corpus with unlabeled data ...')

  viewSamples = []
  for sample in unlabeledData:
    if not sample.strip():
      print('Error during enriching corpus!!!\n Sample:\n%s' % sample)
      continue
    # Each prediction entry is indexed as (label, confidence); only the
    # first (top-ranked) label is used.
    predictions = classifier.predict_proba([sample], k=10)[0]
    if not predictions:
      # Defensive: skip the sample rather than fail with an IndexError
      # should the classifier return no labels for it.
      continue
    predCat = predictions[0][0]
    viewSamples.append('__label__' + predCat + '\t , ' + sample + '\n')
  return viewSamples


if __name__ == '__main__':
  # Exactly two positional arguments are expected: WorkDir and ToolDir.
  if len(sys.argv) != 3:
    print(" usage: python mixAutoLabeledBaseline.py  WorkDir ToolDir")
    # sys.exit() instead of exit(): the latter is a helper injected by the
    # site module and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(1)
  mixAutoLabeledBaseline(sys.argv[1], sys.argv[2])

