from itertools import product
# Values substituted for the {lang} wildcard in data/ and output/ paths.
LANGUAGES = [
    "guarani",
    "maya",
    "bribri",
]

# Values substituted for the {split} wildcard when building the `all` targets.
SPLITS = [
    "test",
]

# Values substituted for the {model} wildcard; forwarded to main.py via --model.
MODELS = [
    "gpt-4",
    "gpt-3.5-turbo",
    "cohere",
]
# Aggregate target: requests one results table for every combination of
# LANGUAGES x SPLITS x MODELS.
rule all:
  input:
    expand(
      "output/results_{lang}_{split}_{model}.tsv",
      lang=LANGUAGES,
      split=SPLITS,
      model=MODELS,
    )

# Build one results table by running main.py with data/{lang}-train.tsv as
# --input and data/{lang}-{split}.tsv as --pred for the requested model; the
# script's stdout is redirected into the output file.
rule results:
  input:
    train="data/{lang}-train.tsv",
    pred="data/{lang}-{split}.tsv",
    # The script is declared as an input and referenced from the command below,
    # so the paths are spelled out in exactly one place.
    script="main.py",
  output:
    # The inline constraint limits {split} to "dev" or "test".
    "output/results_{lang}_{split,dev|test}_{model}.tsv"
  shell:
    "python {input.script} --input {input.train} --pred {input.pred} --model {wildcards.model} > {output}"

# Both permuted_data (with lang=guarani) and permuted_guarani can produce
# data/guarani-perm-{split}.tsv; prefer the guarani-specific rule when both match.
ruleorder: permuted_guarani > permuted_data

# Create data/{lang}-perm-{split}.tsv by running counterfactual.py on the
# corresponding data/{lang}-{split}.tsv file together with permutation.csv;
# the script's stdout is redirected into the output file.
rule permuted_data:
  input:
    # Inputs are named so the command below states the argument order
    # explicitly instead of relying on the order of this list.
    script="counterfactual.py",
    data="data/{lang}-{split}.tsv",
    permutation="permutation.csv",
  output:
    # The inline constraint limits {split} to "train" or "dev".
    "data/{lang}-perm-{split,train|dev}.tsv"
  shell:
    "python {input.script} {input.data} {input.permutation} > {output}"

# Guarani-specific variant of the permutation step: runs counterfactual.py on
# data/guarani-{split}.tsv with permutation_guar.csv and redirects the script's
# stdout into data/guarani-perm-{split}.tsv.
rule permuted_guarani:
  input:
    # Inputs are named so the command below states the argument order
    # explicitly instead of relying on the order of this list.
    script="counterfactual.py",
    data="data/guarani-{split}.tsv",
    permutation="permutation_guar.csv",
  output:
    # The inline constraint limits {split} to "train" or "dev".
    "data/guarani-perm-{split,train|dev}.tsv"
  shell:
    "python {input.script} {input.data} {input.permutation} > {output}"

# Build one results table on permuted data by running main.py with
# data/{lang}-perm-train.tsv as --input and data/{lang}-perm-dev.tsv as --pred
# for the requested model; the script's stdout is redirected into the output file.
rule permuted_results:
  input:
    train="data/{lang}-perm-train.tsv",
    pred="data/{lang}-perm-dev.tsv",
    # The script is declared as an input and referenced from the command below,
    # so the paths are spelled out in exactly one place.
    script="main.py",
  output:
    "output/results_{lang}_perm_{model}.tsv"
  shell:
    "python {input.script} --input {input.train} --pred {input.pred} --model {wildcards.model} > {output}"

# Aggregate target: requests one permuted-data results table for every
# combination of LANGUAGES x MODELS.
rule permuted_all:
  input:
    expand(
      "output/results_{lang}_perm_{model}.tsv",
      lang=LANGUAGES,
      model=MODELS,
    )
