#!/usr/bin/env python3
import sys
# Per-task accumulator: task id -> list of [COMET score, LaTeX row] pairs.
# Every key starts with its own empty list; the lists are filled while the
# input is read and each one is later written out as "<task>_table.tex".
tracks = {
    task: []
    for task in (
        'gpu_throughput',
        'gpu_latency',
        'cpu1_throughput',
        'cpu1_latency',
        'cpuall_throughput',
    )
}

# Human-readable label for each task id, used for the "Condition" column of
# the rewritten human-evaluation tables.
track_names = {
    'gpu_throughput': "GPU Batch",
    'gpu_latency': "GPU Latency",
    'cpu1_throughput': "1 Core Batch",
    'cpu1_latency': "1 Core Latency",
    'cpuall_throughput': "36 Core Batch",
}

def shorten(image_name):
  """Return a short, LaTeX-safe variant label derived from an image name.

  The name is passed through a fixed list of substring replacements that
  strip team prefixes, file suffixes and other boilerplate; finally every
  underscore is escaped with a backslash so the label can be placed in a
  LaTeX table cell.  Two touch-ups follow: a leading "0" is dropped from
  labels starting with "02.", "03." or "04.", and labels starting with
  "stu-" or "tea-" get "-enc" collapsed to "-" and "0dec" turned into "0"
  followed by an escaped underscore.

  Args:
    image_name: the image name string to abbreviate.

  Returns:
    The abbreviated label as a string.
  """
  # Applied strictly in this order: later rules operate on the output of
  # earlier ones (e.g. "12-1" becomes "12_1" and is only then escaped), so
  # entries must not be reordered.
  replacements = (
      ('uedin_release_model-finetune.intgemm.alphas.2021.bin.image', 'base'),
      ('uedin_', ''),
      ('student.', ''),
      ('.tar', ''),
      ('.image', ''),
      ('release_', ''),
      ('_student', ''),
      ('huaweitsc_cpu_1_latency_', ''),
      ('wmt21_', ''),
      ('niutrans_gpu_', ''),
      ('wmt2021_tentrans_transformer-', ''),
      ('_gpu_throughput_final', ''),
      ('_quant.alphas.bin', ''),
      ('_quant.alphas', ''),
      ('.pruned', ''),
      ('student_', ''),
      ('student-base', 'base'),
      ('student-tiny11', 'tiny11'),
      ('4bit_full', '4bit'),
      ('niutrans_cpu_', ''),
      ('.efh_', '-'),
      ('_quant', ''),
      ('12-1', '12_1'),
      ('ft8bit', 'ft8'),
      ('teacher-', 'tea-'),
      ('student-', 'stu-'),
      # Escape underscores for LaTeX.  The backslash is written as an explicit
      # escaped backslash: a bare backslash-underscore inside a normal string
      # literal is an invalid escape sequence and triggers a warning.
      ('_', '\\_'),
  )
  shortened = image_name
  for old, new in replacements:
    shortened = shortened.replace(old, new)
  if shortened.startswith(("02.", "03.", "04.")):
    # Drop the leading zero of the numeric prefix.
    shortened = shortened[1:]
  if shortened.startswith(("stu-", "tea-")):
    shortened = shortened.replace('-enc', '-').replace("0dec", "0\\_")
  return shortened

class Entry:
  """One system's results, parsed from a single space-separated input line.

  The line must hold exactly 15 fields, in this order: zscore, rscore, wins,
  wmt21bleu, wmt21chrf, wmt21comet, real, cpu, cpuram, gram, model, image,
  task, site, image_name.  Each field becomes an attribute of the same name.

  Conversions applied after parsing:
    * real, wmt21bleu, wmt21chrf, wmt21comet, model, cpuram, image and cpu
      are converted to float.
    * rscore and zscore are converted to float only when wins is not "na";
      in that case string_zscore keeps the z-score exactly as it was written.
    * gram is converted to int unless it is "na".
    * site is replaced by its display name.

  Derived attributes:
    hardware, batching: the two parts of task, split on '_'.
    short: shorten(image_name).

  Raises:
    ValueError: on a wrong field count, a non-numeric value, or a task that
      does not split into exactly two parts.
    Exception: if site is not one of the known site identifiers.
  """

  # Site identifier as found in the input -> name shown in the tables.
  _SITE_DISPLAY_NAMES = {
      "edinburgh": 'Edinburgh',
      "niutrans": 'NiuTrans',
      "huawei": 'HuaweiTSC',
      "tencent": 'TenTrans',
  }

  def __init__(self, line):
    fields = line.strip().split(' ')
    (self.zscore, self.rscore, self.wins,
     self.wmt21bleu, self.wmt21chrf, self.wmt21comet,
     self.real, self.cpu, self.cpuram, self.gram,
     self.model, self.image, self.task, self.site,
     self.image_name) = fields

    # Fields that are always numeric.
    for attr in ('real', 'wmt21bleu', 'wmt21chrf', 'wmt21comet',
                 'model', 'cpuram', 'image', 'cpu'):
      setattr(self, attr, float(getattr(self, attr)))

    # Human-evaluation scores exist only when a win count was recorded.
    has_human_scores = self.wins != "na"
    if has_human_scores:
      # Keep the textual form of the z-score alongside the numeric one.
      self.string_zscore = self.zscore
      self.rscore = float(self.rscore)
      self.zscore = float(self.zscore)

    if self.gram != "na":
      self.gram = int(self.gram)

    self.hardware, self.batching = self.task.split('_')
    self.short = shorten(self.image_name)

    if self.site not in self._SITE_DISPLAY_NAMES:
      raise Exception("Unexpected site " + self.site)
    self.site = self._SITE_DISPLAY_NAMES[self.site]

# The first stdin line is read only to skip it; every following line
# describes one system and is turned into an Entry plus a LaTeX table row.
header = sys.stdin.readline()
systems = []
for record in sys.stdin:
  entry = Entry(record)

  # Human-evaluation cells: wins / average / average z when available,
  # otherwise a single empty cell spanning the three columns.
  if entry.wins != "na":
    human_cells = [entry.wins, '{:0.1f}'.format(entry.rscore), '{:0.3f}'.format(entry.zscore)]
  else:
    human_cells = ["\\multicolumn{3}{c|}{}"]

  automatic_cells = [
      '{:0.3f}'.format(entry.wmt21comet),
      '{:0.2f}'.format(entry.wmt21bleu),
      '{:0.2f}'.format(entry.wmt21chrf),
  ]

  # Wall time is printed as given; cpu is scaled down by 1e9 and the size
  # figures by 2**20 (presumably ns -> s and bytes -> MB; confirm against
  # the producer of the input).
  resource_cells = [
      '{:0.0f}'.format(entry.real),
      '{:0.0f}'.format(entry.cpu / 1000000000.0),
      '{:0.0f}'.format(entry.model / 1048576.0),
      '{:0.0f}'.format(entry.image / 1048576.0),
      '{:0.0f}'.format(entry.cpuram / 1048576.0),
  ]

  # The last cell is filled in only for GPU tasks.
  gpu_cell = str(entry.gram) if entry.task.startswith('gpu') else ''

  cells = [entry.site, "\\small " + entry.short] + human_cells + automatic_cells + resource_cells + [gpu_cell]
  latex_row = '&'.join(cells) + "\\\\"
  # The COMET score is stored first so the rows can later be sorted by it.
  tracks[entry.task].append([entry.wmt21comet, latex_row])
  systems.append(entry)

# Write one "<track>_table.tex" fragment per track: a grouped header row, a
# column header row, then the data rows from highest to lowest COMET score.
for track, collected in tracks.items():
  # Ascending sort of the [comet, row] pairs, flipped to get descending order.
  ranked = sorted(collected)[::-1]

  headings = ["Team", "Variant", "Win", "Ave.", "Ave. $z$", "\\tiny COMET", "\\tiny BLEU", "\\tiny chrF", 'Wall', "CPU", 'Model', 'Docker', 'CPU']
  # The final column is labelled only for GPU tracks and left blank otherwise.
  headings.append("GPU" if track.startswith('gpu') else "")
  # The "RAM MB" group spans two columns for every track.
  mem_cols = 2

  group_header = (
      "&&"
      "\\multicolumn{3}{@{}c@{}|}{\\thead{Human}}"
      "&\\multicolumn{3}{@{}c@{}|}{\\thead{Automatic}}"
      "&\\multicolumn{2}{@{ }c@{ }|}{\\thead{Seconds}}"
      "&\\multicolumn{2}{@{ }c@{ }}{\\thead{Disk MB}}"
      "&\\multicolumn{" + str(mem_cols) + "}{|@{ }c@{ }|}{\\thead{RAM MB}}\\\\\n"
  )
  column_header = '&'.join(["\\thead{" + h + "}" for h in headings]) + "\\\\\\hline\n"
  # Rows are newline-separated; the closing \hline follows the last row
  # directly, without a newline in between.
  body = '\n'.join(pair[1] for pair in ranked)

  with open(track + "_table.tex", "w") as f:
    # The tabular wrapper is intentionally not written (the writes below and
    # after the body are disabled).
    # f.write("\\begin{tabular}{|@{ }l@{ }l@{ }|@{ }r@{ }r@{ }r@{ }r|r@{ }r|r@{ }r|r@{ }r@{ }|}\\hline\n")
    # f.write("\\multicolumn{12}{|c|}{\\textbf{\large NVIDIA A100 GPU}}\\\\\\hline\n")
    f.writelines([group_header, column_header, body, "\\hline\n"])
    # f.write("\\end{tabular}\n")

# Index the human-evaluation table by (wins, average z), both kept as the
# exact text printed in the table, mapping to the system name used there.
avez_to_roman = {}
with open("roman/tables/table_humaneval_stdda.tex") as f:
  # Only lines that start with a space are treated as data rows.
  data_rows = [raw.strip().split() for raw in f if raw.startswith(' ')]
for tokens in data_rows:
  # After whitespace splitting: token 0 = wins, token 2 = system name,
  # token 4 = average z (the odd tokens are presumably the '&' separators).
  key = (tokens[0], tokens[4])
  # The pair serves as a join key below, so it has to be unique.
  assert key not in avez_to_roman
  avez_to_roman[key] = tokens[2]

# Attach the table's system name to every entry that has human scores and
# build the reverse lookup from that name to the entry.
roman_to_entry = {}
for s in systems:
  if s.wins == 'na':
    continue
  s.roman_name = avez_to_roman[(s.wins, s.string_zscore)]
  roman_to_entry[s.roman_name] = s

# Two GPU-latency systems are registered by name explicitly (presumably they
# cannot be reached through the (wins, z) lookup above; confirm).  Exactly one
# entry must match each variant.
for variant, roman_name in (
    ("tiny11", "uedin-gpu-student-tiny11-latency"),
    ("base", "uedin-gpu-student-base-latency"),
):
  found = [e for e in systems if e.task == "gpu_latency" and e.short == variant]
  assert len(found) == 1
  roman_to_entry[roman_name] = found[0]


def parse_roman_line(line):
    """Split one LaTeX table row into its cell texts.

    Surrounding whitespace is stripped, the trailing row terminator (two
    backslashes) is removed, and the remainder is split on '&' with each
    cell stripped of whitespace.

    Args:
        line: a table row that ends, after stripping, with two backslashes.

    Returns:
        A list of cell strings.

    Raises:
        AssertionError: if the stripped line does not end with two backslashes.
    """
    row = line.strip()
    assert row.endswith("\\\\")
    # Drop the two-character terminator before splitting into cells.
    return [cell.strip() for cell in row[:-2].split('&')]

# Rewrite each score table from roman/tables/ into tables/fixed_<name>.tex,
# replacing the system-name column with team, variant and wall time.  The
# flag says whether a trailing "Condition" column is added as well.
for base, include_condition in [
    ("table_humaneval_stdda", True),
    ("scores.HuaweiTSC-CPU", False),
    ("scores.Latency-CPU", False),
    ("scores.Latency-GPU-vs-CPU", True),
    ("scores.NiuTrans-GPU", False),
    ("scores.Tentrans-GPU", False),
    ("scores.Throughput-GPU", False),
]:
  with open("roman/tables/" + base + ".tex") as f, open("tables/fixed_" + base + ".tex", "w") as t:
    for text in f:
      if text == "\n":
        # Blank lines are dropped from the output.
        continue
      if text == "\\begin{tabular}{rlrr}\n":
        # Widen the column spec to match the columns written below.
        t.write("\\begin{tabular}{llrrrr" + ("l" if include_condition else "") + "}\n")
        continue
      if "&" not in text:
        # Anything that is not a table row is copied through unchanged.
        t.write(text)
        continue

      wins, system, avez, ave = parse_roman_line(text)
      if wins == "Wins":
        # Header row of the source table: emit the new column titles.
        title_row = "\\textbf{Team}&\\textbf{Variant}&\\textbf{Win}&\\textbf{Ave.}&\\textbf{Ave. $z$}&\\textbf{Time (s)}"
        if include_condition:
          title_row += "&\\textbf{Condition}"
        t.write(title_row + "\\\\\n")
        continue

      # Data row: look the system up by the name used in the source table.
      entry = roman_to_entry[system]
      values = [entry.site, entry.short, wins, ave, avez, '{:0.0f}'.format(entry.real)]
      if include_condition:
        values.append(track_names[entry.task])
      t.write('&'.join(values) + "\\\\\n")
  
# Rewrite the pairwise comparison table, expanding each system name into
# team / variant / condition columns.
with open("tables/table_humaneval_pairs.tex") as s, open("tables/fixed_humaneval_pairs.tex", "w") as t:
  for text in s:
    if text == "\\begin{tabular}{lllrrrl}\n":
      # Replace the column spec with one matching the expanded columns.
      t.write("\\begin{tabular}{|lll|lll|rrrl|}\n")
      continue
    if "&" not in text:
      # Anything that is not a table row is copied through unchanged.
      t.write(text)
      continue

    number, name_a, name_b, avea, aveb, delta, pval = parse_roman_line(text)
    if number == "{}":
      # The source header row has an empty first cell; emit a two-line header.
      t.write("\\multicolumn{3}{|c|}{\\textbf{Stronger System}}&\\multicolumn{3}{|c|}{\\textbf{Weaker System}}&\\textbf{Stronger}&\\textbf{Weaker}&&\\\\\n")
      t.write("\\textbf{Team}&\\textbf{Variant}&\\textbf{Condition}&\\textbf{Team}&\\textbf{Variant}&\\textbf{Condition}&\\textbf{DA Score}&\\textbf{DA Score}&\\textbf{Delta}&\\textbf{$p$-val}\\\\\n")
      continue

    number = int(number)
    if 35 <= number <= 40:
      # Rows numbered 35 through 40 are left out of the rewritten table.
      continue

    sysa = roman_to_entry[name_a]
    sysb = roman_to_entry[name_b]
    cols = [sysa.site, sysa.short, track_names[sysa.task], sysb.site, sysb.short, track_names[sysb.task], avea, aveb, delta, pval]
    t.write('&'.join(cols) + "\\\\\n")
