[v4] CrossEncoder Training refactor - MultiGPU, loss logging, bf16, etc. #3222

Conversation
Hi @tomaarsen I've used your fork and branch as a base and added my implementation for ListNet Loss.

Changes

🔗 Link to My Branch

Let me know what you think or if any modifications are needed!
Hello! This is excellent work, looks very solid! Have you been able to run the training script yourself so far? I can also try and run it and upload the finished model. In the coming days I can try and merge your work into this PR.
It's interesting to see that although the model does get better than the BM25 baseline, the loss effectively does not change.
Hi @tomaarsen 👋 I successfully trained the model and experimented with hyperparameters.

Trained Model
You can find the trained model here:

Observations on Loss
I noticed that the loss is slightly higher (~2.0). Through research and testing, I found that this discrepancy arises due to differences in distribution:
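A way to see this (my own sketch, not from the original comment): ListNet's objective is the cross-entropy between softmax(labels) and softmax(scores), which is bounded below by the entropy of the softmaxed labels, and with binary relevance labels that floor sits well above zero.

import torch

# Hypothetical list: one relevant document among eight candidates
labels = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
target = torch.softmax(labels, dim=0)
# Entropy of the target distribution = the minimum achievable ListNet loss
floor = -(target * target.log()).sum()
print(floor)  # ~1.99, close to the ~2.0 plateau observed here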
One possible solution is to apply a transformation to the predicted distribution to better align it with the ground truth. However, for now, I think this is sufficient. Instead of refining this approach further, I'd prefer to integrate more listwise loss functions that are known to outperform ListNet. Additionally, I assume that combining MSE and ListNet loss could yield better results by leveraging the strengths of both approaches. I can explore this further.

Issue with Missing Values in Evaluation CSV
While training, I noticed missing values in the evaluation CSVs. Example:

epoch,steps,MAP,MRR@10,NDCG@10
0.17214666896195557,2000,0.04166666666666666,0.0654804137172179
0.34429333792391115,4000,0.2629365079365079,0.32231584858096596
0.5164400068858668,6000,0.417079365079365,0.4742922970864457
0.6885866758478223,8000,0.45016666666666666,0.5006481726476146
0.8607333448097779,10000,0.49913492063492065,0.5646165118665413
1.0328800137717336,12000,0.49426984126984125,0.5609090407025468
1.205026682733689,14000,0.4636349206349206,0.5413296770100868
1.3771733516956446,16000,0.4563809523809524,0.5233662617261322

Here, we expect five columns, but only four values appear in some rows. Has this happened in your training as well? I used the same (or a very similar) setup from your MS MARCO training example.
I'll investigate the CSV issue, that one is definitely on me. Beyond that, I've implemented an
I'm definitely open to other Listwise loss implementations! I'm currently looking into improving
Hi @milistu, @tomaarsen

About ListMLELoss
ListMLE is a listwise learning-to-rank loss function that directly optimizes the likelihood of the correct permutation of documents. It models the probability of a permutation using the Plackett-Luce model, which sequentially selects items based on their scores.
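As a rough sketch of that formulation (mine, not from the PR): the Plackett-Luce likelihood of the ground-truth ordering is a product of softmax-style terms over the not-yet-selected documents, and ListMLE minimizes its negative log-likelihood.

import torch

def listmle_loss(scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Order documents by descending ground-truth relevance: the target permutation.
    s = scores[torch.argsort(labels, descending=True)]
    # Plackett-Luce: P(perm) = prod_k exp(s_k) / sum_{j >= k} exp(s_j).
    # log of each denominator, computed stably via a reversed logcumsumexp.
    log_denoms = torch.logcumsumexp(s.flip(0), dim=0).flip(0)
    return (log_denoms - s).sum()  # negative log-likelihood of the permutation

scores = torch.tensor([0.2, 1.3, -0.4])
labels = torch.tensor([2.0, 1.0, 0.0])  # doc 0 should outrank doc 1, then doc 2
print(listmle_loss(scores, labels))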
Additionally, I've implemented Position-Aware ListMLE with lambda weighting, which applies different weights to different rank positions. This allows the model to focus more on getting the top positions correct, which is often more important in ranking tasks. This loss is particularly valuable when dealing with multiple relevant documents that have a clear preference order. For example, when training a reranker for tool selection, some tools should be ranked higher than others for a given query, even though both are relevant. ListMLE can effectively learn these nuanced preferences by modeling the entire permutation probability, ensuring that the most suitable tool appears first in the ranking, followed by the second-best option, and so on.

PR
* Add ListMLELoss
* Fix input_order not being considered
* Update init.py
* Add training scripts for ListMLELoss
* Fix self.lambda_weight to ListMLELambdaWeight (Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>)
* Refactor conditional logic in ListMLELambdaWeight (Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>)
* Fix: delete unused function create_p_list_mle_lambda_weight
* Refactor mask creation using zeros_like (Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>)
* Refactor for-loop with vectorized operations when applying position weights (Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>)
* Fix reference and citation
* Refactor to separate PListMLE and ListMLE
* Refactor training scripts for PListMLE and ListMLE
* Add information of data to be sorted in a defined rank order
* Run formatting
* Ensure that paddings are excluded in the loss
* Remove lambda_weight as an option from ListMLELoss
* Update documentation throughout
* Update MS MARCO training examples docs
* Fix anecdotal ranking
* Vectorize the lambda weight computation (it should be equivalent, and considerably faster, although it wasn't necessarily a bottleneck)
* Add +1 to PListMLELambdaWeight; normalize weight by divide-by-sum
* Simplify code, remove duplicate +1
* Update get_config_dict for ListMLELoss to remove lambda_weight

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Hi @tomaarsen, @milistu
I was wondering if you would also like to implement RankNetLoss (a.k.a. pairwise logistic loss). I've actually worked on the code and did some experiments with activation functions:

from __future__ import annotations
from typing import Literal
import torch
from torch import Tensor, nn
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.util import fullname
class RankNetLoss(nn.Module):
def __init__(
self,
model: CrossEncoder,
sigma: float = 1.0,
eps: float = 1e-10,
activation_fct: nn.Module | None = nn.Identity(),
mini_batch_size: int | None = None,
) -> None:
"""
RankNet loss implementation for learning to rank. This loss function implements the RankNet algorithm,
which learns a ranking function by optimizing pairwise document comparisons using a neural network.
The implementation is optimized to handle padded documents efficiently by only processing valid
documents during model inference.
Args:
model (CrossEncoder): CrossEncoder model to be trained
sigma (float): Score difference weight used in sigmoid (default: 1.0)
eps (float): Small constant for numerical stability (default: 1e-10)
activation_fct (:class:`~torch.nn.Module`): Activation function applied to the logits before computing the
loss. Defaults to :class:`~torch.nn.Identity`.
mini_batch_size (int, optional): Number of samples to process in each forward pass. This has a significant
impact on the memory consumption and speed of the training process. Three cases are possible:
- If ``mini_batch_size`` is None, the ``mini_batch_size`` is set to the batch size.
- If ``mini_batch_size`` is greater than 0, the batch is split into mini-batches of size ``mini_batch_size``.
- If ``mini_batch_size`` is <= 0, the entire batch is processed at once.
Defaults to None.
References:
- Learning to Rank using Gradient Descent: https://icml.cc/Conferences/2015/wp-content/uploads/2015/06/icml_ranking.pdf
Requirements:
1. Query with multiple documents (pairwise approach)
2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.
Inputs:
+----------------------------------------+--------------------------------+-------------------------------+
| Texts | Labels | Number of Model Output Labels |
+========================================+================================+===============================+
| (query, [doc1, doc2, ..., docN]) | [score1, score2, ..., scoreN] | 1 |
+----------------------------------------+--------------------------------+-------------------------------+
Example:
::
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset
model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
"query": ["What are pandas?", "What is the capital of France?"],
"docs": [
["Pandas are a kind of bear.", "Pandas are kind of like fish."],
["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
],
"labels": [[1, 0], [1, 1, 0]],
})
loss = losses.RankNetLoss(model)
trainer = CrossEncoderTrainer(
model=model,
train_dataset=train_dataset,
loss=loss,
)
trainer.train()
"""
super().__init__()
self.model = model
self.sigma = sigma
self.eps = eps
self.activation_fct = activation_fct or nn.Identity()
self.mini_batch_size = mini_batch_size
if self.model.num_labels != 1:
raise ValueError(
f"{self.__class__.__name__} supports a model with 1 output label, "
f"but got a model with {self.model.num_labels} output labels."
)
def forward(self, inputs: tuple[list[str], list[list[str]]], labels: list[Tensor]) -> Tensor:
"""
Compute RankNet loss for a batch of queries and their documents.
Args:
inputs: List of (queries, documents_list)
labels: Ground truth relevance scores, shape (batch_size, num_documents)
Returns:
Tensor: RankNet loss over the batch
"""
if isinstance(labels, Tensor):
raise ValueError(
"RankNetLoss expects a list of labels for each sample, but got a single value for each sample."
)
if len(inputs) != 2:
raise ValueError(f"RankNetLoss expects two inputs (queries, documents_list), but got {len(inputs)} inputs.")
queries, docs_list = inputs
docs_per_query = [len(docs) for docs in docs_list]
max_docs = max(docs_per_query)
batch_size = len(queries)
if docs_per_query != [len(lbls) for lbls in labels]:
raise ValueError(
f"Number of documents per query in inputs ({docs_per_query}) does not match "
f"the number of labels per query ({[len(lbls) for lbls in labels]})."
)
# Create (query, document) input pairs for the model
pairs = [(query, document) for query, docs in zip(queries, docs_list) for document in docs]
if not pairs:
# Handle edge case where all documents are padded
return torch.tensor(0.0, device=self.model.device, requires_grad=True)
mini_batch_size = self.mini_batch_size or batch_size
if mini_batch_size <= 0:
mini_batch_size = len(pairs)
logits_list = []
for i in range(0, len(pairs), mini_batch_size):
mini_batch_pairs = pairs[i : i + mini_batch_size]
tokens = self.model.tokenizer(
mini_batch_pairs,
padding=True,
truncation=True,
return_tensors="pt",
)
tokens = tokens.to(self.model.device)
logits = self.model(**tokens)[0].view(-1)
logits_list.append(logits)
logits = torch.cat(logits_list, dim=0)
logits = self.activation_fct(logits)
# Create output tensor filled with a large negative value; padded positions are excluded from the loss via the labels mask below
logits_matrix = torch.full((batch_size, max_docs), -1e16, device=self.model.device)
# Place logits in the desired positions in the logit matrix
doc_indices = torch.cat([torch.arange(len(docs)) for docs in docs_list], dim=0)
batch_indices = torch.repeat_interleave(torch.arange(batch_size), torch.tensor(docs_per_query))
logits_matrix[batch_indices, doc_indices] = logits
# Create labels matrix
labels_matrix = torch.full_like(logits_matrix, float("-inf"))
labels_matrix[batch_indices, doc_indices] = torch.cat(labels, dim=0).float()
labels_matrix = labels_matrix.to(self.model.device)
# Calculate pairwise differences for scores and labels
score_diffs = logits_matrix[:, :, None] - logits_matrix[:, None, :]
label_diffs = labels_matrix[:, :, None] - labels_matrix[:, None, :]
# Create mask for valid pairs (where both documents are not padded)
valid_pairs = torch.isfinite(label_diffs)
# Create mask for pairs where l_i > l_j
positive_pairs = label_diffs > 0
# Calculate probabilities and target probabilities
P_ij = torch.sigmoid(self.sigma * score_diffs)
P_ij = torch.clamp(P_ij, min=self.eps, max=1-self.eps)
# Calculate loss only for pairs where l_i > l_j (positive_pairs)
losses = -torch.log(P_ij)
# Apply masks and compute mean loss
masked_loss = losses[valid_pairs & positive_pairs]
# Handle case when there are no positive pairs
if masked_loss.numel() == 0:
return torch.tensor(0.0, device=self.model.device, requires_grad=True)
loss = torch.mean(masked_loss)
return loss
def get_config_dict(self) -> dict[str, float | int | str | None]:
"""
Get configuration parameters for this loss function.
Returns:
Dictionary containing the configuration parameters
"""
return {
"sigma": self.sigma,
"eps": self.eps,
"activation_fct": fullname(self.activation_fct),
"mini_batch_size": self.mini_batch_size,
}
@property
def citation(self) -> str:
return """
@inproceedings{burges2005learning,
title={Learning to rank using gradient descent},
author={Burges, Chris and Shaked, Tal and Renshaw, Erin and Lazier, Ari and Deeds, Matt and Hamilton, Nicole and Hullender, Greg},
booktitle={Proceedings of the 22nd international conference on Machine learning},
pages={89--96},
year={2005}
}
""" To simplify, the best recipe for training the reranker model was just using the Identity function.
However, looking at LambdaLoss, RankNet is effectively LambdaLoss without a weighting scheme. So I trained the model with:

loss = LambdaLoss(
model=model,
weighting_scheme=NoWeightingScheme(),
mini_batch_size=mini_batch_size,
)

And the result was:
My suggestion
So if you would like to implement RankNetLoss, there are two scenarios we could take.
What are your thoughts?
I'm definitely interested - I think it might make most sense to implement it as a subclass of the LambdaLoss class.
* Add RankNetLoss and training script
* Fix ListMLELoss documentation
* Fix RankNet to class LambdaLoss
* Update training script for RankNetLoss
* Use super().get_config_dict() and remove weighting scheme (it's a bit confusing to include the weighting scheme in the config if the RankNet loss doesn't have a notion of that)
* Add to __init__.py for easier import
* Correctly capitalize citation titles
* Introduce reproducibility for the msmarco scripts
* Add more docs for RankNetLoss
* Add RankNet to Loss Overview & API Reference
* Expand on RankNet docs slightly

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Also fix version comparison for ST - I can't believe it was doing greater-than on strings for so long
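A quick illustration of that pitfall (my example, not the actual sentence-transformers code): string comparison is lexicographic, so version checks break once a component reaches two digits.

from packaging.version import Version

print("4.10.0" > "4.9.0")                    # False: lexicographic, "1" < "9"
print(Version("4.10.0") > Version("4.9.0"))  # True: proper version comparison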
In conclusion:
Blogpost coming on release day. Big thanks to @milistu and @yjoonjang for their huge roles in the learning-to-rank losses.
Hello!

Pull Request overview
- Adds the CrossEncoderTrainer, CrossEncoderTrainingArguments, and loss functions. Brings features such as multi-GPU training, loss logging, bf16, etc.

TODOs
- Rename fit to old_fit and create a new fit method that depends on CrossEncoderTrainer. Goal: no real backwards incompatibility with existing training scripts.

Details
Overall, the goal of this refactor is to introduce feature parity between the Cross Encoder training and the Sentence Transformer training. Luckily, the work done for the ST trainer can be extended rather easily, so the refactor is not as big as it was for the SentenceTransformer class in v3.0.
Notably, training now centers around:
- A Dataset or DatasetDict. This class is much more suited for sharing & efficient modifications than lists/DataLoaders of InputExample instances. A Dataset can contain multiple text columns that will be fed in order to the corresponding loss function. So, if the loss expects (anchor, positive, negative) triplets, then your dataset should also have 3 columns. The names of these columns are irrelevant at this time. If there is a "label" column, it is treated separately and used as the labels during training. A DatasetDict can be used to train with multiple datasets at once; if a DatasetDict is used, the loss parameter to the CrossEncoderTrainer must also be a dictionary with these dataset keys (see the sketch after this list).
- A SentenceEvaluator instance. These instances either return a float, or a dictionary with metric keys and values. If the latter, the class must also define evaluator.primary_metric so that e.g. the "best model" checkpointing can be based on an evaluator score. Models can now be evaluated both on an evaluation dataset with some loss function and/or a SentenceEvaluator instance.
- A CrossEncoderTrainer instance. This instance is provided with a CrossEncoder model, a CrossEncoderTrainingArguments class, a SentenceEvaluator, a training and evaluation Dataset/DatasetDict, and a loss function/dict of loss functions. Most of these parameters are optional. Once provided, all you have to do is call train().
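For instance, a minimal sketch of the multi-dataset case (the dataset names and columns here are my own illustration): the keys of the loss dict must match the keys of the DatasetDict.

from datasets import Dataset, DatasetDict
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = DatasetDict({
    "pairs": Dataset.from_dict({
        "query": ["What are pandas?"],
        "passage": ["Pandas are a kind of bear."],
        "label": [1.0],
    }),
    "listwise": Dataset.from_dict({
        "query": ["What is the capital of France?"],
        "docs": [["Paris is the capital of France.", "Paris is quite large."]],
        "labels": [[1, 0]],
    }),
})
# One loss per dataset, keyed identically to the DatasetDict
loss = {
    "pairs": losses.BinaryCrossEntropyLoss(model),
    "listwise": losses.ListNetLoss(model),
}
trainer = CrossEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()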
This is an example of an extensive training script with all of the features at play:
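The original script is not reproduced in this excerpt; below is a minimal sketch of the flow it describes (the argument values are my assumptions; CrossEncoderTrainingArguments accepts the usual transformers TrainingArguments fields):

from datasets import Dataset
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
    losses,
)

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "passage": ["Pandas are a kind of bear.", "The capital of France is Paris."],
    "label": [1.0, 1.0],
})
loss = losses.BinaryCrossEntropyLoss(model)
args = CrossEncoderTrainingArguments(
    output_dir="models/my-reranker",
    num_train_epochs=1,
    per_device_train_batch_size=32,
    bf16=True,          # one of the features added in this PR
    logging_steps=100,  # loss logging
)
trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()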
As you may note, it is very similar to the new SentenceTransformer flow: a datasets Dataset, a standalone loss with a lot more flexibility than before, a TrainingArguments and Trainer class, Evaluators much like before & as used in SentenceTransformer training, etc.

cc @milistu as you're also working on CrossEncoders
cc @LysandreJik