def fscore(nb_correct, nb_gold, nb_pred):
    """Return the F1-score, as a percentage, computed from raw counts.

    - nb_correct : number of correct predictions
    - nb_gold : number of gold items (recall denominator)
    - nb_pred : number of predicted items (precision denominator)

    Returns 0 as soon as the product of the three counts is null, which also
    guards against a division by zero.
    """
    if not nb_correct * nb_gold * nb_pred:
        return 0
    recall = nb_correct / nb_gold
    precision = nb_correct / nb_pred
    # harmonic mean of precision and recall, scaled to [0, 100]
    return 200 * recall * precision / (precision + recall)

class NbHeadsDepsRegressor(nn.Module):
    def __init__(self, indices, device, tasks,  # dict task (letter h, d, b, s) to weight
                 w_emb_size=10, #125
                 l_emb_size=None, 
                 p_emb_size=None, # 100
                 use_pretrained_w_emb=False,
                 lstm_dropout=0.33, 
                 lstm_h_size=20, # 600
                 lstm_num_layers=3, 
                 mlp_arc_o_size=25, # 600
                 mlp_lab_o_size=10, # 600
                 mlp_arc_dropout=0.25, 
                 mlp_lab_dropout=0.25,
                 use_bias=False,
                 bert_model=None, # caution: should coincide with indices.bert_tokenizer
                 freeze_bert=False):
        """
        Multi-task model: biaffine arc / label scorers plus auxiliary per-token
        predictors, all sharing one embedding + biLSTM encoder.

        - indices : vocabulary object; get_vocab_size(key) is called for 'w' and
                    'label', and for 'p', 'l', 'slabseq' when POS / lemma
                    embeddings or task 's' are used; w_emb_matrix is read when
                    use_pretrained_w_emb is True
        - device : device on which all sub-modules are placed
        - tasks : iterable of single-letter task codes (for a dict, only the keys
                  are used here):
                    'a' arcs, 'l' arc labels (requires 'a'),
                    'h' nb of heads per token, 'd' nb of dependents per token,
                    'b' bag of labels per token, 's' sorted label sequence per token,
                    'g' (requires 'h' and 'd')
        - w_emb_size / p_emb_size / l_emb_size : embedding sizes for word forms,
                  POS tags and lemmas (POS / lemma embeddings disabled if None or 0)
        - use_pretrained_w_emb : initialize (and fine-tune) word embeddings from
                  indices.w_emb_matrix, then project them linearly to w_emb_size
        - lstm_h_size, lstm_num_layers, lstm_dropout : biLSTM hyperparameters
                  (lstm_h_size is the size of each direction)
        - mlp_arc_o_size / mlp_lab_o_size, mlp_arc_dropout / mlp_lab_dropout :
                  output size and dropout of the arc / label MLPs
        - use_bias : whether to add bias in all biaffine transformations
        - bert_model : optional transformer whose output embeddings are
                  concatenated to the lexical embeddings (frozen if freeze_bert)

        Raises SystemExit if the task dependencies listed above are not met.
        """
        super(NbHeadsDepsRegressor, self).__init__()

        self.indices = indices
        self.device = device
        self.use_pretrained_w_emb = use_pretrained_w_emb

        # indices for tasks (sorted for a deterministic task -> rank mapping)
        self.tasks = sorted(tasks)
        self.nb_tasks = len(self.tasks)
        self.task2i = {task: i for i, task in enumerate(self.tasks)}

        # sys.exit (same SystemExit) rather than the site-provided builtin exit(),
        # which is meant for the interactive interpreter only
        if 'l' in self.task2i and 'a' not in self.task2i:
            sys.exit("ERROR: task a is required for task l")

        if 'g' in self.task2i and not ('d' in self.task2i and 'h' in self.task2i):
            sys.exit("ERROR: tasks d and h are required for task g")

        # ------------ dynamic weights for subtasks -----------------------
        # Kendall et al. 2018 https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf
        # log of variance set to zero (<=> variance == 1)
        # (important to put the tensor on the right device BEFORE instantiating a nn.Parameter)
        self.log_sigma2 = nn.Parameter(torch.zeros(self.nb_tasks).to(device))
        print(self.nb_tasks)

        # ------------ Encoding of sequences ------------------------------
        # size of the concatenated lexical embeddings (LSTM input size), grown below
        self.lexical_emb_size = w_emb_size
        w_vocab_size = indices.get_vocab_size('w')

        self.num_labels = indices.get_vocab_size('label')
        self.w_emb_size = w_emb_size
        self.p_emb_size = p_emb_size
        self.l_emb_size = l_emb_size
        self.lstm_h_size = lstm_h_size
        self.mlp_arc_o_size = mlp_arc_o_size
        self.mlp_lab_o_size = mlp_lab_o_size
        self.mlp_arc_dropout = mlp_arc_dropout
        self.mlp_lab_dropout = mlp_lab_dropout
        self.lstm_num_layers = lstm_num_layers
        self.lstm_dropout = lstm_dropout

        self.use_bias = use_bias # whether to add bias in all biaffine transformations
        self.freeze_bert = freeze_bert
        # -------------------------
        # word form embedding layer
        if not use_pretrained_w_emb:
            self.w_embs = nn.Embedding(w_vocab_size, w_emb_size).to(self.device)
        else:
            matrix = indices.w_emb_matrix

            #if w_emb_size != matrix.shape[1]:
            #    sys.stderr.write("Error, pretrained embeddings are of size %d whereas %d is expected"
            #                     % (matrix.shape[1], w_emb_size))
            if w_vocab_size != matrix.shape[0]:
                sys.stderr.write("Error, pretrained embeddings have a %d vocab size while indices have %d"
                                 % (matrix.shape[0], w_vocab_size))
            self.w_embs = nn.Embedding.from_pretrained(matrix, freeze = False).to(self.device)
            # linear transformation of the pre-trained embeddings (dozat 2018)
            self.w_emb_linear_reduction = nn.Linear(matrix.shape[1],w_emb_size).to(self.device)

        print("w_embs done")
        # -------------------------
        # pos tag embedding layer
        if p_emb_size:
            p_vocab_size = indices.get_vocab_size('p')
            # concatenation of embeddings hence +
            self.lexical_emb_size += p_emb_size
            self.p_embs = nn.Embedding(p_vocab_size, p_emb_size).to(self.device)
        else:
            self.p_embs = None

        # -------------------------
        # lemma embedding layer
        if l_emb_size:
            l_vocab_size = indices.get_vocab_size('l')
            self.lexical_emb_size += l_emb_size
            self.l_embs = nn.Embedding(l_vocab_size, l_emb_size).to(self.device)
        else:
            self.l_embs = None

        # -------------------------
        # bert embedding layer
        if bert_model is not None:
            self.bert_layer = bert_model.to(self.device)
            if freeze_bert:
                for p in self.bert_layer.parameters():
                    p.requires_grad = False
            # NB(review): assumes the model config exposes `emb_dim`
            # (not every transformer config does) -- TODO confirm
            self.lexical_emb_size += self.bert_layer.config.emb_dim
        else:
            self.bert_layer = None

        # -------------------------
        # recurrent LSTM bidirectional layers
        #   TODO: same mask dropout across time-steps ("locked dropout")
        self.lstm = nn.LSTM(input_size = self.lexical_emb_size, 
                            hidden_size = lstm_h_size, 
                            num_layers = lstm_num_layers, 
                            batch_first = True,
                            bidirectional = True,
                            dropout = lstm_dropout).to(self.device)

        # -------------------------
        # specialized MLP applied to biLSTM output
        #   rem: here hidden sizes = output sizes
        #   for readability:
        s = 2 * lstm_h_size   # biLSTM output size (both directions)
        a = mlp_arc_o_size
        l = mlp_lab_o_size

        self.arc_d_mlp = MLP(s, a, a, dropout=mlp_arc_dropout).to(device)
        self.arc_h_mlp = MLP(s, a, a, dropout=mlp_arc_dropout).to(device)
        self.lab_d_mlp = MLP(s, l, l, dropout=mlp_lab_dropout).to(device)
        self.lab_h_mlp = MLP(s, l, l, dropout=mlp_lab_dropout).to(device)

        # ---- BiAffine scores for arcs and labels --------
        # biaffine matrix size is num_label x d x d, with d the output size of the MLPs
        if 'a' in self.task2i:
            self.biaffine_arc = BiAffine(device, a, use_bias=self.use_bias)
            if 'l' in self.task2i:
                self.biaffine_lab = BiAffine(device, l, num_scores_per_arc=self.num_labels, use_bias=self.use_bias)

        # ----- final layers for the sub tasks ------------

        # final layer to get a single real value for nb heads / nb deps
        # (more precisely : will be interpreted as log(1+nbheads))
        if 'h' in self.task2i:
            self.final_layer_nbheads = MLP(a, a, 1).to(self.device)

        if 'd' in self.task2i:
            self.final_layer_nbdeps = MLP(a, a, 1).to(self.device)

        # final layer to get a bag of labels vector, of size num_labels
        if 'b' in self.task2i:
            self.final_layer_bag_of_labels = MLP(a, a, self.num_labels).to(self.device)

        # final layer to get a "sorted label sequence", seen as an atomic symbol
        if 's' in self.task2i:
            self.final_layer_slabseqs = MLP(a, a, self.indices.get_vocab_size('slabseq')).to(self.device)

        
    def forward(self, w_id_seqs, l_id_seqs, p_id_seqs, bert_tid_seqs, bert_ftid_rkss, b_pad_masks, lengths=None):
        """
        Inputs:
         - id sequences for word forms, lemmas and parts-of-speech for a batch of sentences
             = 3 tensors of shape [ batch_size , max_word_seq_length ]
         - bert_tid_seqs : sequences of *bert token ids (=subword ids) 
                           shape [ batch_size, max_token_seq_len ]
         - bert_ftid_rkss : ranks of first subword of each word [batch_size, max_WORD_seq_len +1] (-1+2=1 (no root, but 2 special bert tokens)
         - b_pad_masks : 0 or 1 tensor of shape batch_size , max_word_seq_len , max_word_seq_len 
                         cell [b,i,j] equals 1 iff both i and j are not padded positions in batch instance b
        If lengths is provided : (tensor) list of real lengths of sequences in batch
                                 (for packing in lstm)
        """
        w_embs = self.w_embs(w_id_seqs)
        if self.use_pretrained_w_emb:
            w_embs = self.w_emb_linear_reduction(w_embs)
            
        if self.p_embs:
            p_embs = self.p_embs(p_id_seqs)
            w_embs = torch.cat((w_embs, p_embs), dim=-1)
        if self.l_embs:
            l_embs = self.l_embs(l_id_seqs)
            w_embs = torch.cat((w_embs, l_embs), dim=-1)
        
        if bert_tid_seqs is not None:
            bert_emb_size = self.bert_layer.config.emb_dim
            bert_embs = self.bert_layer(bert_tid_seqs).last_hidden_state
            # select among the subword bert embedddings only the embeddings of the first subword of words
            #   - modify bert_ftid_rkss to serve as indices for gather:
            #     - unsqueeze to add the bert_emb dimension
            #     - repeat the token ranks index along the bert_emb dimension (expand better for memory)
            #     - gather : from bert_embs[batch_sample, all tid ranks, bert_emb_dim]
            #                to bert_embs[batch_sample, only relevant tid ranks, bert_emb_dim]
            #bert_embs = torch.gather(bert_embs, 1, bert_ftid_rkss.unsqueeze(2).repeat(1,1,bert_emb_size))
            bert_embs = torch.gather(bert_embs, 1, bert_ftid_rkss.unsqueeze(2).expand(-1,-1,bert_emb_size))
            w_embs = torch.cat((w_embs, bert_embs), dim=-1)
            
        # h0, c0 vectors are 0 vectors by default (shape batch_size, num_layers*2, lstm_h_size)

        # pack_padded_sequence to save computations
        #     (compute real length of sequences in batch)
        #     see https://gist.github.com/HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec
        #     NB:batch must be sorted in sequence length descending order
        if lengths is not None:
            lengths=lengths.cpu()
            w_embs = pack_padded_sequence(w_embs, lengths, batch_first=True)
        lstm_hidden_seq, _ = self.lstm(w_embs)
        if lengths is not None:
            lstm_hidden_seq, _ = pad_packed_sequence(lstm_hidden_seq, batch_first=True)
        
        # MLPs
        arc_h = self.arc_h_mlp(lstm_hidden_seq) # [b, max_seq_len, mlp_arc_o_size]
        arc_d = self.arc_d_mlp(lstm_hidden_seq) # [b, max_seq_len, mlp_arc_o_size]
        lab_h = self.lab_h_mlp(lstm_hidden_seq)
        lab_d = self.lab_d_mlp(lstm_hidden_seq)

        S_arc = S_lab = log_nbheads = log_nbdeps = bols = S_slabseqs = None

        # Biaffine scores for arcs
        if 'a' in self.task2i:
          S_arc = self.biaffine_arc(arc_h, arc_d) # S(k, i, j) = score of sample k, head word i, dep word j

        # Biaffine scores for labeled arcs
        if 'l' in self.task2i:
          S_lab = self.biaffine_lab(lab_h, lab_d) # S(k, l, i, j) = score of sample k, label l, head word i, dep word j

        # nb heads / nb deps (actually output will be interpreted as log(1+nb))
        if 'h' in self.task2i:
          log_nbheads = self.final_layer_nbheads(arc_d).squeeze(2) # [b, max_seq_len]

        if 'd' in self.task2i:
          log_nbdeps = self.final_layer_nbdeps(arc_h).squeeze(2)   # [b, max_seq_len]

        # bag of labels
        if 'b' in self.task2i:
          bols = self.final_layer_bag_of_labels(arc_d) # [b, max_seq_len, num_labels + 1] (+1 for NOLABEL)

        # sorted lab sequences (equivalent to bag of labels, but seen as a symbol for the whole BOL)
        if 's' in self.task2i:
          S_slabseqs = self.final_layer_slabseqs(arc_d)

        return S_arc, S_lab, log_nbheads, log_nbdeps, bols, S_slabseqs

    
    def batch_forward_and_loss(self, batch, trace_first=False, study_alt=False):
        """
        - batch of sentences (output of make_batches)

        - study_alt : study alternative ways of make predictions
        
        NB: pertains in graph mode only !!
          - in arc_adja (resp. lab_adja), cells equal 1 for gold arcs
             (0 cells for non gold or padded)
          - pad_masks : 0 cells if head OR dep is padded token and 1 otherwise
        """

        lengths, pad_masks, forms, lemmas, tags, bert_tokens, bert_ftid_rkss, arc_adja, lab_adja, bols, slabseqs = batch
            
        # forward 
        S_arc, S_lab, log_pred_nbheads, log_pred_nbdeps, log_pred_bols, scores_slabseqs = self(forms, lemmas, tags, bert_tokens, bert_ftid_rkss, pad_masks, lengths=lengths)

        # pad_masks is [b, m, m]
        # => build a simple [b, m] mask
        linear_pad_mask = pad_masks[:,0,:] 
        nb_toks = linear_pad_mask.sum().item()
        batch_size = forms.shape[0]

        task2loss = defaultdict(int)

        loss = torch.zeros(()).to(self.device)
        dyn_loss_weights = torch.exp( - self.log_sigma2 ) # if lsig2 is log(sigma2), then exp(-lsig2) = 1/sigma2

        # NB: all sents in batch start with the <root> tok (not padded)
        if 'a' in self.task2i:
          arc_loss = self.bce_loss_with_mask(S_arc, arc_adja, pad_masks)
          ti = self.task2i['a']
          task2loss['a'] = arc_loss.item()
          loss +=  (dyn_loss_weights[ti] * arc_loss) + self.log_sigma2[ti]

        if 'l' in self.task2i:
          # --- Label loss -------------------------
          # label scores : rearrange into a batch in which each sample is 
          #                - one head and dep token pair
          #                - label scores for such arc (NB: non-gold arcs will be masked)
          # S_lab is [batch, label, head, dep]
          s_labels = S_lab.transpose(2,1).transpose(3,2)             # b , h , d, l
          s_labels = torch.flatten(s_labels, start_dim=0, end_dim=2) # b * h * d, l
        
          # same for gold labels
          g_labels = torch.flatten(lab_adja) # [b, h, d] ==> [b * h * d]

          # loss with ignore_index == 0 (=> ignore padded arcs and non-gold arcs, which don't have gold labels anyway)
          # cf. Dozat et al. 2018 "back-propagating error to the labeler only through edges with a non-null gold label"
          lab_loss = self.ce_loss(s_labels, g_labels) 
          ti = self.task2i['l']
          task2loss['l'] = lab_loss.item()
          loss +=  (dyn_loss_weights[ti] * lab_loss) + self.log_sigma2[ti]
        
            
        # auxiliary tasks
        if 'h' in self.task2i:
          gold_nbheads = arc_adja.sum(dim=1).float() # [b, h, d] => [b, d]
          log_gold_nbheads = torch.log(1 + gold_nbheads)
          loss_h = self.mse_loss_with_mask(log_pred_nbheads, log_gold_nbheads, linear_pad_mask)
          task2loss['h'] = loss_h.item()
          ti = self.task2i['h']
          loss +=  (dyn_loss_weights[ti] * loss_h) + self.log_sigma2[ti]
        else:
          gold_nbheads = None

        if 'd' in self.task2i:
          gold_nbdeps = arc_adja.sum(dim=2).float()  # [b, h, d] => [b, h]
          log_gold_nbdeps = torch.log(1 + gold_nbdeps)
          loss_d = self.mse_loss_with_mask(log_pred_nbdeps, log_gold_nbdeps, linear_pad_mask)
          task2loss['d'] = loss_d.item()
          ti = self.task2i['d']
          loss +=  (dyn_loss_weights[ti] * loss_d) + self.log_sigma2[ti]
        else:
          gold_nbdeps = None

#        # predicted global balance in each sentence, between the predicted nbheads and the predicted nbdeps
#        # which should be 0
#        if 'g' in self.task2i:
#          # for each sent, total nb heads minus total nb deps
#          pred_h_d_per_sentence = (log_pred_nbheads * linear_pad_mask).sum(dim=1) - (log_pred_nbdeps * linear_pad_mask).sum(dim=1)
#          gold_h_d_per_sentence = torch.zeros(batch_size).to(self.device)
#          # rescaling the loss : global loss is for all sentences => rescale to approx nb of tokens
#          loss_global = (nb_toks / batch_size) * self.mse_loss(pred_h_d_per_sentence, gold_h_d_per_sentence)
#          task2loss['g'] = loss_global.item()
#          ti = self.task2i['g']
#          loss +=  (dyn_loss_weights[ti] * loss_global) + self.log_sigma2[ti]

        if 'b' in self.task2i:
          # unfortunately, bincount on 1-d tensors only 
          # so computing the gold BOLs in make_batches rather
          #torch.bincount(arc_lab, minlength=self.num_labels) # +1
          # bols are [b, d, num_labels]
          log_gold_bols = torch.log(1+bols)
          #loss_bol = self.cosine_loss(log_pred_bols, log_gold_bols, linear_pad_mask)
          loss_bol = self.l2dist_loss(torch.flatten(log_pred_bols, start_dim=0, end_dim=1), # flatten from [b, d, l] to [b*d, l]
                                      torch.flatten(log_gold_bols, start_dim=0, end_dim=1),
                                      torch.flatten(linear_pad_mask, start_dim=0, end_dim=1))
          task2loss['b'] = loss_bol.item()
          ti = self.task2i['b']
          loss +=  (dyn_loss_weights[ti] * loss_bol) + self.log_sigma2[ti]

        if 's' in self.task2i:
          loss_slabseq = self.ce_loss(torch.flatten(scores_slabseqs, start_dim=0, end_dim=1),
                                      torch.flatten(slabseqs, start_dim=0, end_dim=1))
          task2loss['s'] = loss_slabseq.item()
          ti = self.task2i['s']
          loss +=  (dyn_loss_weights[ti] * loss_slabseq) + self.log_sigma2[ti]

        if trace_first:
          for ti, task in enumerate(self.tasks):
            print("dyn_loss_w of task %s : %f" %(task.upper(), dyn_loss_weights[ti]))

        # --- Prediction and evaluation --------------------------
        # provide the batch, and all the output of the forward pass
        task2nbcorrect = self.batch_predict_and_evaluate(batch, gold_nbheads, gold_nbdeps, linear_pad_mask, 
                                                         S_arc, S_lab, log_pred_nbheads, log_pred_nbdeps, log_pred_bols, scores_slabseqs,
                                                         study_alt)
 
        return loss, task2loss, task2nbcorrect, nb_toks
    
    def batch_predict_and_evaluate(self, batch, 
                                   gold_nbheads, gold_nbdeps, linear_pad_mask, # computed in batch_forward_and_loss
                                   S_arc, S_lab, log_pred_nbheads, log_pred_nbdeps, log_pred_bols, scores_slabseqs, # output by forward pass
                                   study_alt=False # whether to study other prediction algorithms
                                   ):

      lengths, pad_masks, forms, lemmas, tags, bert_tokens, bert_ftid_rkss, arc_adja, lab_adja, bols, slabseqs = batch

      task2nbcorrect = defaultdict(int)

      # --- Prediction and evaluation --------------------------
      with torch.no_grad():
        if 'a' in self.task2i:
          pred_arcs = (S_arc > 0).int() * pad_masks  # b, h, d
          nb_correct_u = torch.sum((pred_arcs * arc_adja).int()).item()
          nb_gold = torch.sum(arc_adja).item()
          nb_pred = torch.sum(pred_arcs).item()
          task2nbcorrect['a'] = (nb_correct_u, nb_gold, nb_pred)

        if 'l' in self.task2i:
            # labeled
            pred_labels = torch.argmax(S_lab, dim=1) # for all arcs (not only the predicted arcs)
            # count correct labels for the predicted arcs only
            nb_correct_u_and_l = torch.sum((pred_labels == lab_adja).float() * pred_arcs).item()
            task2nbcorrect['l'] = (nb_correct_u_and_l, nb_gold, nb_pred)

        # NB: round predicted numbers of heads / deps for evaluation only
        if 'h' in self.task2i:
          pred_nbheads = torch.round(torch.exp(log_pred_nbheads) - 1)
          task2nbcorrect['h'] = torch.sum((pred_nbheads == gold_nbheads).int() * linear_pad_mask).item()

        if 'd' in self.task2i:
          pred_nbdeps = torch.round(torch.exp(log_pred_nbdeps) - 1)
          task2nbcorrect['d'] = torch.sum((pred_nbdeps == gold_nbdeps).int() * linear_pad_mask).item()

        if 'b' in self.task2i:
          pred_bols = torch.round(torch.exp(log_pred_bols) - 1) # [b, d, num_labels+1]
          # nb of b , d pairs (token d in batch instance b) for which the full predicted bol is correct
          #   i.e. nb_toks minus the number of b,d pairs for which 
          #        there is at least (torch.any) one label dim differing (!=) between gold and predicted
          # from mask [b,d] to mask [b,d,num_labels+1]
          bol_pad_mask = linear_pad_mask.unsqueeze(2).expand(-1,-1,self.num_labels) #@@ +1)
          nb_incorrect = torch.sum(torch.any(((pred_bols != bols).int() * bol_pad_mask).bool(), dim=2).int())
          task2nbcorrect['b'] = nb_toks - nb_incorrect.item()

        if 's' in self.task2i:
          pred_slabseqs = torch.argmax(scores_slabseqs, dim=2) # [b, d]
          task2nbcorrect['s'] = torch.sum((pred_slabseqs == slabseqs).int() * linear_pad_mask).item()
          # count the unk slabseq as incorrect
          task2nbcorrect['sknown'] = task2nbcorrect['s'] - torch.sum((pred_slabseqs == UNK_ID).int() * linear_pad_mask).item()

        if study_alt:
          nbheads_from_a = torch.sum(pred_arcs, dim=1) # b, d
          nbheads_from_s, bols_from_s = self.indices.interpret_slabseqs(pred_slabseqs)
          nbheads_from_s = nbheads_from_s.to(self.device)
          uninterpretable = (nbheads_from_s == -1).int() # cases for which output slabseq is not interpretable
          interpretable = (nbheads_from_s != -1).int()
          # majority vote on the 3 kinds of prediction of nbheads 
          nbheads_from_v = torch.round(((nbheads_from_a + nbheads_from_s + pred_nbheads) * interpretable / 3) # will yield 0 if score 0 or 1, and 1 if score 2 or 3
                                        + (pred_nbheads * uninterpretable)) # when s in unavailable, use nbheads from task h
          task2nbcorrect['v'] = torch.sum((nbheads_from_v == gold_nbheads).int() * linear_pad_mask).item()

          # predict the top most arcs according to various nbheads

          s, indices = torch.sort(S_arc, dim=1) # sort the scores of the arcs
          # 3 tensors for 3 other ways to predict arcs, 
          #           according to best xxx scores for each dependent d
          #           with xxx being the nbheads predicted using tasks h, s, or v
          alt_pred_arcs = {'h':torch.zeros(S_arc.shape), 
                           's':torch.zeros(S_arc.shape), 
                           'v':torch.zeros(S_arc.shape)}
          (bs, m, m) = S_arc.shape
          for b in range(bs):
              for d in range(m): # d
                  int_nbheads_list = {'h':pred_nbheads[b,d].item(),
                                      's':nbheads_from_s[b,d].item(), 
                                      'v':nbheads_from_v[b,d].item()}
                  for h in range(m): # h
                      for t in ['h', 's', 'v']: # 3 ways to get the nbheads
                        if h < (m - int_nbheads_list[t]):
                          alt_pred_arcs[t][b, indices[b, h, d], d] = 0
                        else: 
                          alt_pred_arcs[t][b, indices[b, h, d], d] = 1          
          for t in ['h','s','v']: 
            alt_pred_arcs[t] = alt_pred_arcs[t].to(self.device) * pad_masks
            nb_correct_u = torch.sum((alt_pred_arcs[t] * arc_adja).int()).item()
            nb_pred = torch.sum(alt_pred_arcs[t]).item()
            task2nbcorrect['a' + t] = (nb_correct_u, nb_gold, nb_pred)
            # the pred labels contains the best label for all (h,d) pairs, 
            # and thus are common to any arc prediction style
            nb_correct_u_and_l = torch.sum((pred_labels == lab_adja).float() * alt_pred_arcs[t]).item()
            task2nbcorrect['l' + t] = (nb_correct_u_and_l, nb_gold, nb_pred)

      return task2nbcorrect

    def train_model(self, train_data, val_data, outdir, config_name, nb_epochs, batch_size, lr, lex_dropout):
        """
         NB: train and val_data should be GraphDataSet instances 

         task2lossw = dictionary key = task short name / val = loss weight for this task      
        """
        self.lr = lr
        self.alpha = 0.1 # exponent for loss-balancing
        self.lex_dropout = lex_dropout # proba of word / lemma / pos tag dropout
        self.batch_size = batch_size
        self.beta1 = 0.9
        self.beta2 = 0.9
        #optimizer = optim.SGD(biaffineparser.parameters(), lr=LR)
        #optimizer = optim.Adam(self.parameters(), lr=lr, betas=(0., 0.95), eps=1e-09)
        optimizer = optim.Adam(self.parameters(), lr=lr, betas=(self.beta1, self.beta2), eps=1e-09)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.95)
        
        # loss function for nbheads / nbdeps (per token)
        self.mse_loss_with_mask = MSELoss_with_mask(reduction='sum') # 'h', 'd'
        # loss function for global_h_d (per sentence)
        self.mse_loss = nn.MSELoss(reduction='sum')                  # 'g'
        #self.cosine_loss = CosineLoss_with_mask(reduction='sum')
        self.l2dist_loss =  L2DistanceLoss_with_mask(reduction='sum') # 'b'
        # for label loss, the label for padded deps is PAD_ID=0 
        #   ignoring padded dep tokens (i.e. whose label id equals PAD_ID)
        # used both for arc labels and for sorted label sequences (seen as atoms)
        self.ce_loss = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD_ID) # 'l', 's'
        # for graph mode arcs
        self.bce_loss_with_mask = BCEWithLogitsLoss_with_mask(reduction='sum') # 'a'
        # for tree mode arcs and both tree and graph mode labels
        #   (CrossEnt cf. softmax not applied yet in BiAffine output)
        #   ignoring padded dep tokens (i.e. whose head equals PAD_HEAD_RK)
        #self.ce_loss_fn_arc = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD_HEAD_RK)

        
        self.config_name = config_name
                
        out_model_file = outdir + '/' + config_name + '.model_nbarcs'
        out_log_file = outdir + '/' + config_name + '.log_nbarcs'
        log_stream = open(out_log_file, 'w')

        self.log_train_hyper(sys.stdout)
        self.log_train_hyper(log_stream)

        # losses and scores at each epoch (on train / validation)
        train_losses = []
        val_losses = []
        val_task2losses = defaultdict(list) 
        train_task2accs = defaultdict(list)
        val_task2accs = defaultdict(list)  

        # word / pos / lemma dropout of training data only
        # seems ok to re-drop at each epoch
        #train_data.lex_dropout(lex_dropout)

        for epoch in range(1,nb_epochs+1):
            i = 0
            train_loss = 0
            train_task2loss = defaultdict(int)
            train_task2nbcorrect = defaultdict(int)
            train_nb_toks = 0
            val_loss = 0
            val_task2loss = defaultdict(int)
            val_task2nbcorrect = defaultdict(int)
            val_nb_toks = 0
            
            if 'a' in self.task2i:
              train_task2nbcorrect['a'] = [0,0,0] # tasks with fscore metric we have a triplet nbcorrect, nbgold, nbpred
              val_task2nbcorrect['a'] = [0,0,0] 
            if 'l' in self.task2i:
              train_task2nbcorrect['l'] = [0,0,0]
              val_task2nbcorrect['l'] = [0,0,0] 

            # training mode (certain modules behave differently in train / eval mode)
            self.train()
            train_data.lex_dropout(lex_dropout)
            bid = 0
            trace_first = True
            for batch in train_data.make_batches(self.batch_size, shuffle_data=True, sort_dec_length=True, shuffle_batches=True):        
                self.zero_grad()
                bid += 1
                if bid % 2000 == 0:
                  print("BATCH SHAPE:", batch[2].shape, batch[5].shape)
                  print("MEMORY BEFORE BATCH FORWARD AND LOSS")
                  printm()
                loss, task2loss, task2nbcorrect, nb_toks = self.batch_forward_and_loss(batch, trace_first=trace_first)
                trace_first = False

                loss.backward()
                optimizer.step() 
                loss.detach()           

                train_loss += loss.item()
                train_nb_toks += nb_toks
                for k in self.tasks:
                  train_task2loss[k] += task2loss[k]
                  if k in ['a','l']: 
                    for i in [0,1,2]:
                      train_task2nbcorrect[k][i] += task2nbcorrect[k][i]
                  elif k != 'g':
                    train_task2nbcorrect[k] += task2nbcorrect[k]

            # for one epoch              
            print("Train: nb toks " + str(train_nb_toks) + "/ " + " / ".join([t.upper()+":"+str(train_task2nbcorrect[t]) for t in self.tasks]))              
            assert train_nb_toks == train_data.nb_words, "train_nb_toks %d should equal train_data.nb_words %d" %(train_nb_toks, train_data.nb_words)

            for k in self.tasks:
              train_task2loss[k] /= train_data.nb_words
              if k in ['a', 'l']:
                train_task2accs[k].append(fscore(*train_task2nbcorrect[k]))
              elif k != 'g':
                train_task2accs[k].append( 100 * train_task2nbcorrect[k] / train_nb_toks )            
            train_loss = train_loss/train_data.nb_words
            train_losses.append(train_loss)

            self.log_perf(log_stream, epoch, 'Train', train_loss, train_task2loss, train_task2accs)

            if val_data:
                self.eval()
                # calcul de la perte sur le validation set
                with torch.no_grad():
                    trace_first = True
                    for batch in val_data.make_batches(self.batch_size, sort_dec_length=True):
                        loss, task2loss, task2nbcorrect, nb_toks = self.batch_forward_and_loss(batch, trace_first=trace_first, study_alt=True)
                        val_loss += loss.item()
                        val_nb_toks += nb_toks
                        for k in task2nbcorrect.keys():
                          if k in task2loss:
                            val_task2loss[k] += task2loss[k]
                          if type(task2nbcorrect[k]) != int: # tuple or list 
                            if k not in val_task2nbcorrect: # those that are not registered yet are the study_alt, and are only fscore-like
                              val_task2nbcorrect[k] = [0,0,0]
                            for i in [0,1,2]:
                              val_task2nbcorrect[k][i] += task2nbcorrect[k][i]
                          elif k != 'g':
                            val_task2nbcorrect[k] += task2nbcorrect[k]
                        trace_first = False
                        
                    # for one epoch
                    print("Val: nb toks " + str(val_nb_toks) + "/ " + " / ".join([t.upper()+":"+str(val_task2nbcorrect[t]) for t in self.tasks]))              
                    assert val_nb_toks == val_data.nb_words, "val_nb_toks %d should equal val_data.nb_words %d" %(val_nb_toks, val_data.nb_words)
                    for k in task2nbcorrect.keys():#self.tasks:
                      if k in val_task2loss:
                        val_task2loss[k] /= val_data.nb_words
                      # if task if fscore-like
                      if type(val_task2nbcorrect[k]) == list:
                        val_task2accs[k].append(fscore(*val_task2nbcorrect[k]))
                      elif k != 'g':
                        val_task2accs[k].append( 100 * val_task2nbcorrect[k] / val_nb_toks )            

                    val_loss = val_loss/val_data.nb_words
                    val_losses.append(val_loss)
            
                self.log_perf(log_stream, epoch, '\tValid', val_loss, val_task2loss, val_task2accs)
    
                if epoch == 1:
                    print("saving model after first epoch\n")
                    torch.save(self, out_model_file)
                # if validation loss has decreased: save model
                # nb: when label loss comes into play, it might artificially increase the overall loss
                #     => we don't early stop at this stage 
                #elif (val_losses[-1] < val_losses[-2]) :
                elif val_task2accs['l'][-1] > val_task2accs['l'][-2] :
                    for stream in [sys.stdout, log_stream]:
                        #stream.write("Validation loss has decreased, saving model, current nb epochs = %d\n" % epoch)
                        stream.write("Validation L perf has increased, saving model, current nb epochs = %d\n" % epoch)
                    torch.save(self, out_model_file)
                # otherwise: early stopping, stop training, reload previous model
                # NB: the model at last epoch was not saved yet
                # => we can reload the model from the previous storage
                else:
                    #print("Validation loss has increased, reloading previous model, and stop training\n")
                    print("Validation L perf has decreased, reloading previous model, and stop training\n")
                    # reload
                    # REM:the loading of tensors will be done on the device they were copied from
                    # cf. https://pytorch.org/docs/stable/generated/torch.load.html#torch-load
                    self = torch.load(out_model_file)
                    # stop loop on epochs
                    break
            scheduler.step()
            # end loop on epochs
        for stream in [sys.stdout, log_stream]:
          stream.write("train losses: %s\n" % ' / '.join([ "%.4f" % x for x in train_losses]))
          stream.write("val   losses: %s\n" % ' / '.join([ "%.4f" % x for x in val_losses]))
          for k in sorted(val_task2accs.keys()):
            if k in train_task2accs:
              stream.write("train %s accs: %s\n" % (k.upper(), ' / '.join([ "%.2f" % x for x in train_task2accs[k] ])))
            stream.write("val   %s accs: %s\n" % (k.upper(), ' / '.join([ "%.2f" % x for x in val_task2accs[k] ])))

    def log_perf(self, outstream, epoch, ctype, l, task2loss, task2accs):
        """Report one epoch's losses and scores on stdout, then on outstream.

        ctype is a prefix label placed at the start of every line, l is the
        global loss, task2loss maps a task letter to its loss, and task2accs
        maps a task letter to its list of per-epoch scores; only the most
        recent score (last element) is reported. Tasks are listed in sorted
        order, with their letter upper-cased.
        """
        def _report_lines():
            # generated lazily so each stream receives exactly the same
            # sequence of writes as a line-by-line loop would produce
            yield "%s   Loss  for epoch %d: %.4f\n" % (ctype, epoch, l)
            for task in sorted(task2loss):
                yield "%s %s Loss  for epoch %d: %.4f\n" % (ctype, task.upper(), epoch, task2loss[task])
            for task in sorted(task2accs):
                yield "%s %s ACC after epoch %d : %.2f\n" % (ctype, task.upper(), epoch, task2accs[task][-1])

        for target in (sys.stdout, outstream):
            for line in _report_lines():
                target.write(line)

    def log_train_hyper(self, outstream):
        """Dump the hyperparameters and the task list to outstream.

        Model-size hyperparameters are written first, then a blank line, then
        the optimisation/training ones, then one "task X" line per task and a
        final blank line. Values are looked up directly in the instance
        __dict__, so each listed name must have been set as a plain instance
        attribute (a missing one raises KeyError).
        """
        model_hypers = ('w_emb_size', 'l_emb_size', 'p_emb_size', 'lstm_h_size',
                        'mlp_arc_o_size', 'mlp_arc_dropout', 'use_pretrained_w_emb')
        training_hypers = ('batch_size', 'beta1', 'beta2', 'lr', 'lex_dropout', 'freeze_bert')
        attrs = vars(self)  # same dict object as self.__dict__

        for name in model_hypers:
            outstream.write("%s : %s\n" % (name, str(attrs[name])))
        outstream.write("\n")
        for name in training_hypers:
            outstream.write("%s : %s\n" % (name, str(attrs[name])))
        for task in self.tasks:
            outstream.write("task %s\n" % task)
        outstream.write("\n")

    def predict_and_evaluate(self, test_data, log_stream, out_file=None):
      """ predict on test data and evaluate 
      if out_file is set, prediction will be dumped in readable format in out_file
      """
      if out_file != None:
        out_stream = open(out_file, 'w')
      
      self.eval()
      test_nb_toks = 0
      test_task2nbcorrect = defaultdict(int)
      test_task2acc = defaultdict(float)
      with torch.no_grad():
        for batch in test_data.make_batches(self.batch_size, sort_dec_length=True):
          # forward 
          S_arc, S_lab, log_pred_nbheads, log_pred_nbdeps, log_pred_bols, scores_slabseqs = self(forms, lemmas, tags, bert_tokens, bert_ftid_rkss, pad_masks, lengths=lengths)

          linear_pad_mask = pad_masks[:,0,:] # from [b, m, m] to [b, m]
          test_nb_toks += linear_pad_mask.sum().item()

          if 'h' in self.task2i:
            pred_nbheads = torch.round(torch.exp(log_pred_nbheads) - 1)

          if 'd' in self.task2i:
            pred_nbdeps = torch.round(torch.exp(log_pred_nbdeps) - 1)

          # --- Prediction and evaluation --------------------------
          # provide the batch, and all the output of the forward pass
          task2nbcorrect = self.batch_predict_and_evaluate(batch, gold_nbheads, gold_nbdeps, linear_pad_mask, S_arc, S_lab, log_pred_nbheads, log_pred_nbdeps, log_pred_bols, scores_slabseqs)
          for k in self.tasks:
            if k in ['a','l']: 
              for i in [0,1,2]:
                test_task2nbcorrect[k][i] += task2nbcorrect[k][i]
            elif k != 'g':
              test_task2nbcorrect[k] += task2nbcorrect[k]

          task2nbcorrect_alt = self.batch_alternative_predict_and_evaluate(batch, gold_nbheads, gold_nbdeps, linear_pad_mask, S_arc, S_lab, log_pred_nbheads, log_pred_nbdeps, log_pred_bols, scores_slabseqs)

        # for the full test set
        print("Test: nb toks " + str(test_nb_toks) + "/ " + " / ".join([t.upper()+":"+str(test_task2nbcorrect[t]) for t in self.tasks]))              
        for k in self.tasks:
          if k in ['a', 'l']:
            test_task2acc[k] = fscore(*test_task2nbcorrect[k])
          elif k != 'g':
            test_task2acc[k] = 100 * test_task2nbcorrect[k] / test_nb_toks
