Oh no. It seems the model outputs can't be accessed anywhere except inside compute_loss.
We could log them by overriding compute_loss, but that is fragile across library versions. If we did it in a callback instead, we would have to run the forward pass again ourselves just to get the outputs…
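```python
import torch
from transformers import Trainer


class CustomTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch)
    # that newer transformers versions pass to compute_loss.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        out = model(**inputs)
        loss = out.loss
        # mlm_loss / clf_loss are fields of this model's custom output class.
        self.log({
            "mlm_loss": out.mlm_loss.item(),
            "clf_loss": out.clf_loss.item(),
            "perplexity": torch.exp(out.mlm_loss).item(),
        })
        return (loss, out) if return_outputs else loss
```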
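One side effect of calling self.log inside compute_loss is that it emits an entry on every forward pass (including gradient-accumulation micro-steps and evaluation batches), independent of logging_steps. A possible mitigation, sketched here under assumptions, is to accumulate the component losses and emit averages only at logging time. The class LossAccumulator and its add/flush methods below are hypothetical helpers, not part of transformers; the sketch is framework-free so it runs on its own.

```python
import math


class LossAccumulator:
    """Hypothetical helper: collects per-step component losses and
    returns averaged metrics only when flushed."""

    def __init__(self):
        self.sums = {}
        self.count = 0

    def add(self, mlm_loss: float, clf_loss: float) -> None:
        # Keep running sums of scalar losses (already converted with .item()).
        self.sums["mlm_loss"] = self.sums.get("mlm_loss", 0.0) + mlm_loss
        self.sums["clf_loss"] = self.sums.get("clf_loss", 0.0) + clf_loss
        self.count += 1

    def flush(self) -> dict:
        # Average over the steps since the last flush, then reset.
        if self.count == 0:
            return {}
        metrics = {k: v / self.count for k, v in self.sums.items()}
        # Perplexity from the mean MLM loss, i.e. exp(mean loss),
        # not the mean of per-step exp(loss) values.
        metrics["perplexity"] = math.exp(metrics["mlm_loss"])
        self.sums = {}
        self.count = 0
        return metrics


acc = LossAccumulator()
acc.add(2.0, 0.5)
acc.add(4.0, 1.5)
print(acc.flush())  # averaged mlm_loss/clf_loss plus perplexity
```

In the trainer, compute_loss would call acc.add(...) instead of self.log, and an override of Trainer.log would merge acc.flush() into the logs dict before delegating to super().log. The exact signature of Trainer.log has changed between transformers versions, so check it against the version you use.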
Is there a better way to do this…