Logging extra losses/metrics from a CustomTrainer(Trainer) subclass at the same frequency as Trainer, with wandb

Oh no. It seems the model output cannot be referenced from a callback; it is only available inside compute_loss.

We could log the metrics by overriding compute_loss, but that is fragile across transformers version changes, and if we did it from a callback instead, we would have to run the forward pass ourselves…

import torch
from transformers import Trainer

class CustomTrainer(Trainer):
    # `**kwargs` absorbs extra arguments that newer transformers versions
    # pass to compute_loss (e.g. num_items_in_batch).
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        out = model(**inputs)
        loss = out.loss
        # Caveat: self.log() here fires on every training step, not at
        # logging_steps, so these metrics are logged more often than
        # Trainer's own loss.
        self.log({
            "mlm_loss": out.mlm_loss.item(),
            "clf_loss": out.clf_loss.item(),
            "perplexity": torch.exp(out.mlm_loss).item(),
        })
        return (loss, out) if return_outputs else loss

Is there a better way to do this…
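One alternative sketch, assuming (as in the snippet above) a custom model whose outputs expose `mlm_loss` and `clf_loss`: stash the extra metrics inside compute_loss instead of logging them immediately, then merge them into Trainer's own log() call. Trainer invokes log() only every logging_steps, and the WandbCallback forwards whatever log() emits, so the extra metrics automatically share the built-in loss's frequency and destination.

```python
import torch
from transformers import Trainer

class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._extra_metrics = {}

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        out = model(**inputs)
        # Stash the latest values instead of logging every step. (This
        # records the most recent batch; averaging over the logging
        # window would need explicit accumulation here.)
        self._extra_metrics = {
            "mlm_loss": out.mlm_loss.item(),
            "clf_loss": out.clf_loss.item(),
            "perplexity": torch.exp(out.mlm_loss).item(),
        }
        return (out.loss, out) if return_outputs else out.loss

    def log(self, logs, *args, **kwargs):
        # Trainer calls log() at its own logging_steps cadence, so the
        # extra metrics piggyback on the built-in loss/lr entries and
        # reach wandb through the normal callback path.
        logs.update(self._extra_metrics)
        super().log(logs, *args, **kwargs)
```

The `*args, **kwargs` on log() is a hedge against signature differences between transformers versions (some pass an extra start_time argument).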