We perform in-batch negative sampling. There isn't any real "sampling" per se: the negatives for each example are simply the positives of the other examples in the same batch.
To compute the loss, we need the logits of the positives and the negatives, along with the indices of the positives.
The positive logits lie on the diagonal of the logits matrix, and the negatives are the off-diagonal entries.
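A minimal plain-Python sketch of how the diagonal/off-diagonal layout arises. It assumes each example has a query embedding and a key embedding, and that logits are dot products between them. The function names and toy vectors are hypothetical and are only here to illustrate the idea.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def in_batch_logits(queries: Matrix, keys: Matrix) -> Matrix:
    """Score every query against every key in the batch.

    Assumption: a plain dot product is used as the score.
    """
    return [[sum(q * k for q, k in zip(qv, kv)) for kv in keys] for qv in queries]


def split_positives_negatives(logits: Matrix) -> Tuple[List[float], List[List[float]]]:
    """Row i's positive is logits[i][i]; every other entry in row i is a negative."""
    positives = [row[i] for i, row in enumerate(logits)]
    negatives = [[x for j, x in enumerate(row) if j != i] for i, row in enumerate(logits)]
    return positives, negatives


# Hypothetical toy batch of 3 (query, key) pairs; pair i is (queries[i], keys[i]).
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
keys = [[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]]

logits = in_batch_logits(queries, keys)
pos, neg = split_positives_negatives(logits)
print(logits)  # [[2.0, 0.0, 1.0], [0.0, 3.0, 1.0], [2.0, 3.0, 2.0]]
print(pos)     # [2.0, 3.0, 2.0]
print(neg)     # [[0.0, 1.0], [0.0, 1.0], [2.0, 3.0]]
```

Every key in the batch gets scored once per query, so a batch of size B yields B positives and B * (B - 1) negatives without drawing any extra samples.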
The positive indices are therefore simply [0, 1, 2, ..., batch_size - 1]: the positive for row i sits in column i.
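A minimal sketch of the loss that consumes those indices, assuming the loss is softmax cross-entropy over each row of the logits matrix. The note does not name the loss, so treat that as an assumption. The function name `in_batch_cross_entropy` is hypothetical.

```python
import math
from typing import List


def in_batch_cross_entropy(logits: List[List[float]]) -> float:
    """Mean softmax cross-entropy where the target for row i is column i.

    Assumption: the loss is standard softmax cross-entropy.
    """
    batch_size = len(logits)
    labels = list(range(batch_size))  # [0, 1, ..., batch_size - 1]
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the row max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]  # -log softmax(row)[label]
    return total / batch_size


# Hypothetical logits where each positive (diagonal) clearly beats its negatives.
logits = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]
print(list(range(len(logits))))  # [0, 1, 2]
print(in_batch_cross_entropy(logits))
```

In PyTorch the same computation is usually written as `torch.nn.functional.cross_entropy(logits, torch.arange(batch_size))`, where `torch.arange(batch_size)` produces exactly the positive indices described above.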