#!/usr/bin/env python3
import math
# Word and line counts of the translated test input (presumably the
# throughput test set — TODO confirm against the task data).
INPUT_WORDS = 19951184
INPUT_LINES = 1000000

def shorten(image_name):
  """Turn a submission's image/archive file name into a short plot label.

  Strips team, release and packaging boilerplate and abbreviates recurring
  words. Every underscore is escaped for LaTeX by putting a backslash in
  front of it. The leading zero of a "02."/"03."/"04." version prefix is
  dropped. For "stu-"/"tea-" names, the "-encNdecM" layout description is
  compacted.
  """
  # Applied strictly in this order.  Order matters: the full base-model name
  # must be rewritten before 'uedin_' is stripped, '.efh_' must become '-'
  # before 'rowcol-' is abbreviated, and underscores are escaped last.
  replacements = (
    ('uedin_release_model-finetune.intgemm.alphas.2021.bin.image', 'base'),
    ('uedin_', ''),
    ('student.', ''),
    ('.tar', ''),
    ('.image', ''),
    ('release_', ''),
    ('_student', ''),
    ('huaweitsc_cpu_1_latency_', ''),
    ('wmt21_', ''),
    ('niutrans_gpu_', ''),
    ('wmt2021_tentrans_transformer-', ''),
    ('_gpu_throughput_final', ''),
    ('_quant.alphas.bin', ''),
    ('_quant.alphas', ''),
    ('.pruned', ''),
    ('student_', ''),
    ('student-base', 'base'),
    ('student-tiny11', 'tiny11'),
    ('4bit_full', '4bit'),
    ('niutrans_cpu_', ''),
    ('.efh_', '-'),
    ('_quant', ''),
    ('12-1', '12_1'),
    ('ft8bit', 'ft8'),
    ('teacher-', 'tea-'),
    ('student-', 'stu-'),
    ('rowcol-', 'rc'),
    ('heads-', 'he'),
    ('_', '\\_'),
  )
  shortened = image_name
  for old, new in replacements:
    shortened = shortened.replace(old, new)
  if shortened.startswith(("02.", "03.", "04.")):
    # "02.x" -> "2.x"
    shortened = shortened[1:]
  if shortened.startswith(("stu-", "tea-")):
    # "-enc20dec1" -> "-20\_1".  Only encoder depths ending in 0 are matched
    # by the "0dec" pattern.
    shortened = shortened.replace('-enc', '-').replace("0dec", "0\\_")
  return shortened

class Entry:
  """One result row: a submitted system's quality scores and resource costs.

  A row holds 15 whitespace-separated fields, in this order: zscore, rscore,
  wins, wmt21bleu, wmt21chrf, wmt21comet, real, cpu, cpuram, gram, model,
  image, task, site, image_name.  ``task`` has the form
  "<hardware>_<batching>" and is split into ``hardware`` and ``batching``.
  ``site`` is mapped to its display name, and ``short`` holds the label
  produced by shorten().
  """

  # Site identifiers used in the results file -> names used in plot legends.
  SITE_NAMES = {
    "edinburgh": "Edinburgh",
    "niutrans": "NiuTrans",
    "huawei": "HuaweiTSC",
    "tencent": "TenTrans",
  }

  def __init__(self, line):
    """Parse one non-comment line of the results file.

    Numeric columns are converted to float.  ``zscore`` and ``rscore`` are
    converted only when ``wins`` is not "na"; in that case ``string_zscore``
    keeps the original text.  ``gram`` stays the string "na" when it is
    missing, and ``cpu`` is kept as a string.

    Raises:
      ValueError: if the line does not have exactly 15 fields, or if a
        numeric column does not parse.
      Exception: if the site is not a key of SITE_NAMES.
    """
    # split() with no argument tolerates repeated spaces and tabs between
    # columns.  Splitting on a single space would turn those into empty
    # fields and break the unpacking.
    (self.zscore, self.rscore, self.wins, self.wmt21bleu, self.wmt21chrf,
     self.wmt21comet, self.real, self.cpu, self.cpuram, self.gram, self.model,
     self.image, self.task, self.site, self.image_name) = line.split()
    self.real = float(self.real)
    self.wmt21bleu = float(self.wmt21bleu)
    self.wmt21chrf = float(self.wmt21chrf)
    self.wmt21comet = float(self.wmt21comet)
    self.model = float(self.model)
    self.cpuram = float(self.cpuram)
    self.image = float(self.image)
    # Human-evaluation columns are "na" for systems that were not evaluated.
    if self.wins != "na":
      self.string_zscore = self.zscore
      self.rscore = float(self.rscore)
      self.zscore = float(self.zscore)
    if self.gram != "na":
      self.gram = float(self.gram)
    self.hardware, self.batching = self.task.split('_')
    self.short = shorten(self.image_name)
    try:
      self.site = self.SITE_NAMES[self.site]
    except KeyError:
      raise Exception("Unexpected site " + self.site) from None

class Axis:
    """A plottable metric.

    Attributes:
        op: function mapping an Entry to the metric's value.
        bigbetter: True when larger values are better.
        short: short identifier, used in output file names and label tweaks.
        axis: axis label text.
        title: plot-title fragment.
        log: whether the axis is drawn on a log scale.
    """
    def __init__(self, op, bigbetter, short, axis, title, log = False):
        self.op, self.bigbetter, self.short = op, bigbetter, short
        self.axis, self.title, self.log = axis, title, log

class KeyGroup:
    """One legend entry: the entries picked by ``selector``, drawn with ``style``.

    Attributes:
        selector: predicate on an Entry; True if the entry belongs here.
        style: gnuplot point/line style text (e.g. "lt 2").
        title: legend text.
    """
    def __init__(self, selector, style, title):
        self.selector, self.style, self.title = selector, style, title

def pareto(x, y, entries):
    """Return the entries on the Pareto frontier of ``x`` against ``y``.

    ``x`` and ``y`` are Axis-like: ``op(entry)`` gives the value, and
    ``bigbetter`` tells whether larger values are better.  The result is
    ordered from best to worst ``y``.  Each kept entry has a strictly better
    ``x`` than every entry kept before it.  Among entries with equal ``y``,
    the one with the best ``x`` is considered first.  An entry that only ties
    another on ``y`` but is worse on ``x`` is therefore not reported.
    """
    def score(axis, e):
        # Normalize so that a smaller score is always better.
        value = axis.op(e)
        return -value if axis.bigbetter else value

    ret = []
    bestx = math.inf
    # Best y first; ties on y are broken by best x (sort is stable beyond that).
    for e in sorted(entries, key=lambda e: (score(y, e), score(x, e))):
        value = score(x, e)
        if value < bestx:
            bestx = value
            ret.append(e)
    return ret

def arrow(fromx, fromy, tox, toy):
  """Return a gnuplot command drawing a headless line from (fromx, fromy) to (tox, toy).

  Each coordinate is inserted verbatim through str(), so callers may pass
  plain numbers or gnuplot coordinate text such as "graph 0" or "first 1.5".
  """
  start = "%s, %s" % (fromx, fromy)
  end = "%s,%s" % (tox, toy)
  return "set arrow from " + start + " to " + end + " nohead\n"

def pareto_arrows(x, y, entries):
  """Return gnuplot arrow commands tracing the Pareto frontier as a staircase.

  The line works as follows:
  - It starts at the plot edge on the worse-x side, at the height of the
    best-y frontier point.
  - It walks through consecutive frontier points, drawing a vertical segment
    and then a horizontal one for each step.
  - It ends with a vertical drop to the worse-y edge.

  Returns "" when the frontier is empty.
  """
  frontier = pareto(x, y, entries)
  if not frontier:
    return ""
  points = [(x.op(e), y.op(e)) for e in frontier]
  worst_x_edge = "graph 0" if x.bigbetter else "graph 1"
  worst_y_edge = "graph 0" if y.bigbetter else "graph 1"

  first_x, first_y = points[0]
  # y must be tagged "first" explicitly, because x is in graph coordinates.
  pieces = [arrow(worst_x_edge, "first " + str(first_y), first_x, first_y)]
  for (prev_x, prev_y), (cur_x, cur_y) in zip(points, points[1:]):
    pieces.append(arrow(prev_x, prev_y, prev_x, cur_y))
    pieces.append(arrow(prev_x, cur_y, cur_x, cur_y))
  last_x, last_y = points[-1]
  pieces.append(arrow(last_x, last_y, last_x, worst_y_edge))
  return "".join(pieces)

def dump_data(x, y, entries):
    """Serialize ``entries`` as an inline gnuplot data block.

    Each entry becomes one line of the form X Y "X, Y NAME".  In the quoted
    third column, the values are rounded to two decimals and followed by the
    image name; that column is the point's label text.  The block ends with
    gnuplot's "e" end-of-data line.
    """
    rows = []
    for entry in entries:
        xv, yv = x.op(entry), y.op(entry)
        rows.append("%s %s \"%s, %s %s\"\n" % (xv, yv, format(xv, "0.2f"), format(yv, "0.2f"), entry.image_name))
    return "".join(rows) + "e\n"

# Short labels placed to the right of their point on the latency-vs-COMET plot.
LATENCY_RIGHT = set(["base", "tiny11", "tiny", "sm6", "sm9", "2.12\\_1.tiny.he0.3", "4.12\\_1.micro.rc0.5"])
def label_position(entry, x, y, label_left, label_right):
  """Return the gnuplot placement text for ``entry``'s label on an x/y plot.

  The result follows "set label ... at X,Y".  It is an alignment keyword
  ("left", "right" or "center"), optionally followed by "offset dx,dy".

  Hand-tuned overrides are tried first.  They are keyed on the plot
  (``x.short``/``y.short``) and on the entry's short/image_name, site,
  hardware or batching.  Otherwise the label is placed by comparing the
  point's x value with ``label_left``/``label_right``.
  """
  xval = x.op(entry)
  if x.short == "latency" and y.short == "zscore" and entry.site == "Edinburgh" and entry.short != "base":
    return "left"
  if x.short == "latency" and y.short == "comet":
    if entry.image_name == "uedin_student.02.6-2.tied.tiny.pruned.heads.efh_0.3.image":
      return "left offset -1,-0.5"
    if entry.image_name == "wmt21_02.student.12-1.base.4bit_full.image":
      return "right offset 0,0.1"
    if entry.short in LATENCY_RIGHT:
      if entry.short == "base" and entry.site == "HuaweiTSC":
        return "right offset 0, 0.75"
      if entry.short == "base" and entry.site == "Edinburgh" and entry.hardware == "cpu1":
        return "center offset 0,0.7"
      return "right"
    else:
      return "left"
  if x.short in ["thwordspersecond", "miwordsperdollar"]:
    if y.short == "zscore":
      if entry.image_name == "wmt2021_tentrans_transformer-student-enc20dec1-h512-ffn2048_gpu_throughput_final.image":
        return "right"
      if entry.image_name == "wmt2021_tentrans_transformer-student-enc10dec1-h512-ffn2048_gpu_throughput_final.image":
        return "left offset -0.5,0.75"
    if entry.short.startswith("4.12\\_1.micro") and entry.hardware.startswith("cpu"):
      return "right"
    if entry.hardware == "gpu":
      if entry.short.startswith("2.") or entry.short.startswith("3."):
        if entry.short=="2.12\\_1.tiny.he0.3":
          return "left"
        if entry.short == "2.12\\_1.tiny.4bit" and y.short == "comet":
          return "left offset -1,-0.5"
        return "right"
      if entry.image_name == "niutrans_gpu_12_1_512.tar":
        return "right"
  if x.short == "model":
    if entry.image_name == "huaweitsc_cpu_1_latency_tiny.tar":
      return "left"
    if entry.image_name == "wmt21_02.student.12-1.tiny.4bit_full.image":
      return "left offset 0,-0.6"
    if entry.image_name == "wmt21_02.student.8-4.tied.tiny.4bit_full.image":
      return "left offset 0,0.6"
    if entry.image_name == "niutrans_cpu_6_1_512.tar" or entry.image_name == "niutrans_gpu_6_1_0.tar" or entry.image_name == "niutrans_gpu_3_1_512.tar":
      return "left"
    if entry.short == "2.6-2.tied.tiny.he0.3" or entry.short == "2.12\\_1.micro.he0.3" or entry.short == "2.12-2.tied.tiny.he0.3":
      return "left"
    # The escaped backslash is required: an unescaped "\_" is an invalid
    # escape sequence (SyntaxWarning on Python 3.12+).
    if entry.short == "3.12\\_1.large":
      return "left offset -1,1"
    if entry.image_name == "uedin_release_student_tiny11.image" and entry.hardware == "cpu1" and entry.batching == "latency":
      return "left"
  if x.short == "gram":
    if entry.image_name == "niutrans_gpu_12_1_512.tar":
      return "left offset 0,-0.3"
    if entry.image_name == "niutrans_gpu_6_1_0.tar":
      return "left offset 0,-0.3"
    if entry.image_name == "niutrans_gpu_3_1_512.tar":
      return "left offset 0,-0.3"
    if entry.short == "2.12\\_1.micro.he0.3":
      return "right offset 2.9,-0.5"
    if entry.short == "2.12\\_1.base.4bit" and entry.batching == "latency":
      return "right offset 1.5,0.7"
    if entry.short == "2.12\\_1.tiny.4bit":
      if entry.batching == "throughput":
        return "right offset 0,-0.4"
      else:
        return "left"
    if entry.short == "2.8-4.tied.tiny.4bit" and entry.batching == "throughput":
      return "right offset -0.3,0.2"
    if entry.short == "2.12\\_1.micro.rc0.5":
      return "right offset 0,0.3"
  if x.short == "cpuram":
    if entry.short == "4.12\\_1.micro.rc0.5.ft8":
      if entry.hardware == "cpu1":
        return "right offset 3,0.7"
      else:
        return "left offset 0,0"
    if entry.short == "4.12\\_1.tiny.rc0.5.ft8":
      if entry.hardware == "cpu1" and entry.batching == "latency":
        return "right"
      return "left offset 0,0.8"
    if entry.short == "4.12\\_1.micro.rc0.5":
      return "left"
    if entry.short == "base" and entry.hardware == "cpu1" and entry.site == "Edinburgh":
      return "center offset 0,-0.5"
    if entry.short == "tiny11" and entry.hardware == "cpu1" and entry.site == "Edinburgh":
      return "center offset 0,0.8"
    if entry.short == "3.12\\_1.large" and entry.hardware == "cpuall":
      return "left"
    if entry.short == "3\\_1\\_512" and entry.hardware == "cpuall":
      return "right offset 0,0.9"
  # Generic rule.  Near the right edge, anchor the label's right end at the
  # point; near the left edge, anchor its left end.  The extra vertical
  # offset applies at the end where x is best.
  if xval > label_right:
    orientation = "right"
    if x.bigbetter:
      orientation += " offset 0, 0.75"
  elif xval < label_left:
    orientation = "left"
    if not x.bigbetter:
      orientation += " offset 0, 0.6"
  else:
    if x.bigbetter:
      orientation = "left"
    else:
      orientation = "right"
  return orientation

def plot(x, y, entries, selector_style):
    """Return gnuplot commands drawing ``entries`` as labelled points.

    One inline-data ('-') series is emitted for each KeyGroup in
    ``selector_style`` that selects at least one entry; groups selecting
    nothing are skipped.  The series data follows the plot command.  Each
    drawn point also gets a LaTeX ``\\tiny`` text label positioned by
    label_position().  Labels are dropped when x is the Docker-size axis.

    The x range is [0:] on linear axes.  On log axes it spans the selected
    values, widened on each side by 3% of the log range so the extreme points
    do not sit on the axis.
    """
    # The range covers only entries that some selector will draw.  An entry
    # matched by several selectors contributes once per selector.
    xs = [x.op(e) for group in selector_style for e in entries if group.selector(e)]
    lo = min(xs)
    hi = max(xs)
    if x.log:
        # Label-flip thresholds 5% in from each end, measured in log space.
        label_right = (lo ** 0.05) * (hi ** 0.95)
        label_left = (lo ** 0.95) * (hi ** 0.05)
        margin = (hi / lo) ** 0.03
        zeroing = " [" + str(lo / margin) + ":" + str(hi * margin) + "]"
    else:
        label_right = 0.95 * hi
        label_left = 0.05 * hi
        zeroing = " [0:]"

    series = []
    data_blocks = []
    label_lines = []
    for group in selector_style:
        members = [e for e in entries if group.selector(e)]
        if not members:
            continue
        series.append("'-' with labels hypertext point " + group.style + " title '" + group.title + "'")
        data_blocks.append(dump_data(x, y, members))
        for member in members:
            placement = label_position(member, x, y, label_left, label_right)
            text = member.short.replace('\\', '\\\\')
            label_lines.append('set label "\\\\tiny ' + text + '" at ' + str(x.op(member)) + "," + str(y.op(member)) + " " + placement + "\n")
    # Labels are way too big on the Docker-size plot.
    labels = "" if x.short == "docker" else "".join(label_lines)
    return ("set xlabel '" + x.axis + "' offset 0,0.5\n"
            + "set ylabel '" + y.axis + "'\n"
            + labels
            + "plot " + zeroing + " " + ','.join(series) + "\n"
            + "".join(data_blocks)
            + "unset label\n")

def pareto_plot(terminal, x, y, entries, selector_style, title, key = None):
    """Return the gnuplot commands for one complete Pareto plot of ``entries``.

    The commands do the following, in order:
    - mark the "better" top corner with ``terminal.smile()``;
    - set the title through ``terminal.title``;
    - log-scale x if the axis asks for it;
    - place the key (``key`` is a full gnuplot command that replaces the
      default placement);
    - draw the frontier staircase and the points;
    - reset arrows, labels, log scale and key for the next plot.
    """
    if x.bigbetter:
        corner, anchor = "1", "right"
    else:
        corner, anchor = "0", "left"
    out = ['set label at graph ' + corner + ', graph 1 offset 0, -1.6 ' + anchor + ' ' + terminal.smile() + '\n']
    out.append(terminal.title('set title "' + title + '"\n'))
    if x.log:
        out.append("set log x\n")
    if key is not None:
        out.append(key + "\n")
    elif x.bigbetter:
        # Default key goes on the side opposite the smiley.
        out.append("set key bottom left at graph 0.05, 0.02\n")
    else:
        out.append("set key bottom right\n")
    out.append("set key samplen 0\n")
    out.append(pareto_arrows(x, y, entries))
    out.append(plot(x, y, entries, selector_style))
    out.append("unset arrow\nunset label\nunset log\nset key default\nset key spacing 1.2\n")
    return "".join(out)


# Plottable axes: Axis(value function, bigger-is-better, short id, axis label,
# plot-title fragment[, log scale]).
bleu = Axis(lambda e : e.wmt21bleu, True, "bleu", "BLEU", "Quality")
comet = Axis(lambda e : e.wmt21comet, True, "comet", "COMET", "Quality")
chrf = Axis(lambda e : e.wmt21chrf, True, "chrf", "chrF", "Quality")
zscore = Axis(lambda e: e.zscore, True, "zscore", "Source-based DA $z$-score", "Quality")
# INPUT_WORDS over the wall time ``real``, in thousands.  Uses the shared
# constant rather than repeating its value as a literal.
thousandwpreals = Axis(lambda e : float(INPUT_WORDS)/e.real/1000.0, True, "thwordspersecond", "Thousand words per wall second", "Speed")
# Size/RAM divisors convert to the labelled units; the raw units of the
# results columns are not visible here (presumably bytes, and MB for gram —
# confirm).
model_size = Axis(lambda e : e.model /1048576.0, False, "model", "Model size (MB)", "Model size", log = True)
docker_size = Axis(lambda e : e.image / 1048576.0, False, "docker", "Docker size (MB)", "Docker size", log = True)
gram = Axis(lambda e : e.gram / 1024.0, False, "gram", "GPU RAM (GB)", "RAM on GPU", log = True)
cpuram = Axis(lambda e : e.cpuram / 1048576.0 / 1024.0, False, "cpuram", "CPU RAM (GB)", "RAM", log = True)
latency = Axis(lambda e : e.real / 1000.0, False, "latency", "Latency (ms)", "Latency")
def milwordsperdollar(entry):
    """Return million words translated per US dollar of cloud compute for ``entry``.

    Speed is INPUT_WORDS divided by the wall time ``entry.real`` (assumed to
    be seconds, since the price below is per second — confirm).  It is
    expressed in million words per second and then divided by the machine's
    price per second.

    Raises:
        Exception: for hardware with no known price; only "cpuall" and "gpu"
            are priced.
    """
    # Divide by one million explicitly.  INPUT_LINES merely happens to equal
    # 10**6 and is a line count, not a unit conversion.
    million_words_per_second = float(INPUT_WORDS) / entry.real / 1000000.0
    # Hourly machine prices in dollars; divided by 3600 to get a per-second price.
    hourly_price = {
        "cpuall": 2.7,  # 36 cores * $0.075 / core https://blogs.oracle.com/cloud-infrastructure/post/announcing-compute-instances-with-3rd-gen-intel-xeon-ice-lake-processors
        "gpu": 3.05,    # $3.05 / A100 GPU https://www.oracle.com/uk/cloud/partners/gpu.html
    }
    if entry.hardware not in hourly_price:
        raise Exception("Don't have dollar cost for " + entry.hardware)
    return million_words_per_second / (hourly_price[entry.hardware] / 3600.0)

millionperdollar = Axis(milwordsperdollar, True, "miwordsperdollar", "Million words per dollar", "Cost", log = False)

# One legend group per participating team, in legend order, with its gnuplot
# style.  ``site=site`` binds each team name at definition time.
participants = [
    KeyGroup(lambda e, site=site: e.site == site, style, site)
    for site, style in (
        ("Edinburgh", "lt 2"),
        ("HuaweiTSC", "lt 4"),
        ("NiuTrans", "lt 3"),
        ("TenTrans", "lt 1"),
    )
]

# Same groups, except that the Edinburgh group leaves out two ft8bit images.
participants_no_garbage = [
    KeyGroup(lambda e : e.site == "Edinburgh" and e.image_name not in ("uedin_release_04.12-1.micro.pruned.rowcol.efh_0.5.ft8bit_quant.alphas.bin.image", "uedin_release_student.04.12-1.tiny.pruned.rowcol.efh_0.5.ft8bit_quant.alphas.image"), "lt 2", "Edinburgh"),
    *participants[1:],
]

def generate(lines, terminal, yaxis):
  """Print the gnuplot commands for every plot, using ``yaxis`` as the y axis.

  ``lines`` are Entry objects.  ``terminal`` is an SVG/PNG/Tikz-style object
  that selects the output file for each plot.

  For each batching mode and hardware setup, it plots speed (plus latency
  for latency batching) and GPU or CPU RAM against ``yaxis``.  It then draws
  model size, Docker image size, cost efficiency, and a combined GPU/CPU
  latency plot.  All output goes to stdout.
  """
  print("set pointsize 2")
  print("set key spacing 1.2")
  for batching in ["throughput", "latency"]:
    # Throughput is plotted for three hardware setups and latency for two.
    # Latency batching also gets a latency axis.
    if batching == "throughput":
      hwarray = ["gpu", "cpu1", "cpuall"]
      metricsbase = [thousandwpreals]
    else:
      hwarray = ["gpu", "cpu1"]
      metricsbase = [thousandwpreals, latency]
    for hw in hwarray:
      entries = [e for e in lines if e.hardware == hw and e.batching == batching]
      if len(entries) == 0:
        continue
      # Copy, so the RAM metric appended below does not leak into metricsbase.
      metrics = metricsbase.copy()
      if hw == "gpu":
        titlebase = "GPU"
      elif hw == "cpu1":
        titlebase = "1 CPU core"
      elif hw == "cpuall":
        titlebase = "36 CPU cores"
      titlebase += " " + batching
      # GPU runs are plotted against GPU RAM, CPU runs against CPU RAM.
      if hw == "gpu":
        metrics.append(gram)
      else:
        metrics.append(cpuram)
      for x in metrics:
        # Output name, e.g. "gpu_throughput_gram_comet".
        name = hw + "_" + batching + '_' + x.short + '_' + yaxis.short
        keylocation = None
        # Hand-placed legends for two specific plots.
        if hw == 'gpu' and x.short == "thwordspersecond" and batching == "throughput":
          keylocation = "set key at graph 0.25, graph 0.90 samplen 0"
        if hw == 'gpu' and x.short == 'gram' and batching == 'latency':
          keylocation = "set key bottom left at graph 0.1,graph 0.1 samplen 0"
        # Of these metrics only the RAM axes are log-scaled; give them explicit tics.
        if x.log:
          if hw == 'gpu' and batching == "latency":
            print("set xtics (28,29,30,31,32,33,34,35,36,37,38) logscale")
          else:
            print("set xtics 0.0625,2")
        terminal.set_output(name)
        print(pareto_plot(terminal, x, yaxis, entries, participants, titlebase + ": " + x.title, key = keylocation))
        print("set xtics auto")

  # Model and Docker sizes use all entries regardless of hardware/batching.
  # The model-size plot uses participants_no_garbage.
  terminal.set_output('size_' + yaxis.short)
  print("set xtics 1,2")
  print(pareto_plot(terminal, model_size, yaxis, lines, participants_no_garbage, title = "Model size"))
  terminal.set_output("docker_" + yaxis.short)
  print("set xtics 0.0625,2")
  print(pareto_plot(terminal, docker_size, yaxis, lines, participants, title = "Docker image size"))
  print("set xtics auto")

  # Cost plot: restricted to throughput runs on "cpuall" or "gpu", the only
  # hardware milwordsperdollar has a price for.
  terminal.set_output('dollar_' + yaxis.short)
  # NOTE(review): both CPU groups share the legend text "CPU" — presumably
  # intentional for the 3-row key layout; confirm.
  dollar = [
      KeyGroup(lambda e : e.site == "Edinburgh" and e.hardware == "gpu", "lt 2", "Edinburgh: GPU"),
      KeyGroup(lambda e : e.site == "NiuTrans" and e.hardware == "gpu", "lt 3", "NiuTrans: GPU"),
      KeyGroup(lambda e : e.site == "TenTrans" and e.hardware == "gpu", "lt 1", "TenTrans: GPU"),
      KeyGroup(lambda e : e.site == "Edinburgh" and e.hardware == "cpuall", "lt 8 lc 2", "CPU"),
      KeyGroup(lambda e : e.site == "NiuTrans" and e.hardware == "cpuall", "lt 6 lc 3", "CPU"),
  ]
  print(pareto_plot(terminal, millionperdollar, yaxis, [e for e in lines if (e.hardware == "cpuall" or e.hardware == "gpu") and e.batching == "throughput"], dollar, title = "Cost efficiency", key = 'set key at graph 0.9, graph 0.05 bottom maxrows 3 samplen 0 width -8'))
  #print(pareto_plot(terminal, millionperdollar, yaxis, [e for e in lines if (e.hardware == "cpuall" or e.hardware == "gpu") and e.batching == "throughput"], dollar, title = "Cost efficiency", key = 'set key at graph 0.85, graph 0.3 maxrows 3 samplen 0 width -10'))

  # Combined GPU and CPU latency
  latency_lines = [l for l in lines if l.batching == "latency"]
  latency_options = [
      KeyGroup(lambda e : e.site == "Edinburgh" and e.hardware == "gpu", "lt 2", "Edinburgh GPU"),
      KeyGroup(lambda e : e.site == "Edinburgh" and e.hardware == "cpu1" and e.image_name not in ["uedin_release_04.12-1.micro.pruned.rowcol.efh_0.5.ft8bit_quant.alphas.bin.image", "uedin_release_student.04.12-1.tiny.pruned.rowcol.efh_0.5.ft8bit_quant.alphas.image"], "lt 8 lc 2", "Edinburgh CPU"),
      KeyGroup(lambda e : e.site == "HuaweiTSC" and e.hardware == "cpu1", "lt 4", "HuaweiTSC CPU"),
  ]
  terminal.set_output('latency_' + yaxis.short)
  print(pareto_plot(terminal, latency, yaxis, latency_lines, latency_options, title = "Latency without batching"))

class SVG:
  """gnuplot settings for SVG output, one ``<name>.svg`` file per plot."""
  def set_output(self, name):
    """Print the terminal and output-file commands for plot ``name``."""
    print("\n".join((
        "set terminal svg size 600,335 dynamic fontscale 1.5 mouse jsdir './'",
        "set output '" + name + ".svg'",
    )))
  def smile(self):
    """Return the label text marking the plot's "better" corner."""
    return '"😊" font "Times,30"'
  def title(self, title):
    """Return the set-title command unchanged, so SVG plots keep their titles."""
    return title

class PNG:
  """gnuplot settings for PNG output, one ``<name>.png`` file per plot."""
  def set_output(self, name):
    """Print the terminal and output-file commands for plot ``name``."""
    print("\n".join((
        "set terminal png enhanced fontscale 1.5 font \"arial,11\" lw 3 size 1024,768",
        "set output '" + name + ".png'",
    )))
  def smile(self):
    """Return the label text marking the plot's "better" corner."""
    return '"😊" font "Times,30"'
  def title(self, title):
    """Return the set-title command unchanged, so PNG plots keep their titles."""
    return title

class Tikz:
  """gnuplot settings for TikZ/LaTeX output (``<name>.tex``), sized per plot kind."""
  def set_output(self, name):
    """Print the terminal and output-file commands; the canvas size depends on ``name``."""
    if name.startswith(("latency", "dollar", "size")):
      size = "15.8,7"
    elif "_gram_" not in name and "_cpuram_" not in name:
      size = "15.8,5.8"
    else:
      # RAM plots: the GPU-throughput one is full width, the others half width.
      size = "15.8,6" if name.startswith("gpu_throughput") else "7.9,5"
    print("set terminal tikz size " + size + " lw 3")
    print("set output '" + name + ".tex'")
  def smile(self):
    """Return the LaTeX smiley used to mark the "better" corner."""
    return '"\\\\huge\\\\smiley"'
  def title(self, title):
    """Return "" so that TikZ plots get no gnuplot title."""
    return ""

# Load every result row.  Lines starting with '#' are comments, and blank
# lines are skipped because they cannot be parsed as an Entry.  The file is
# closed as soon as it has been read.
with open("wmt21-efficiency-humeval.csv", encoding="utf-8") as results_file:
  lines = [Entry(row) for row in results_file if row.strip() and not row.startswith("#")]
print("set encoding utf8")
# Alternative outputs, currently disabled:
#generate(lines, SVG(), comet)
#generate(lines, SVG(), bleu)
#generate(lines, SVG(), chrf)
generate(lines, Tikz(), comet)
# The z-score plots only include systems that were human-evaluated.
generate([e for e in lines if e.zscore != "na"], Tikz(), zscore)
#generate(lines, Tikz(), bleu)
