Add Position-Aware ListMLELoss #6
Conversation
Very nice work on this - I will be training some models with this to experiment. I have some questions and comments, with the goal of making this as easy as possible for users.
```python
        return self.rank_discount_fn(ranks)

class ListMLELoss(nn.Module):
```
From a user perspective, it might not be immediately obvious that this supports both the "standard" ListMLE loss as well as the "improved" position-aware ListMLE loss. Beyond that, my experience is that 95% of users will use the default settings, so we want to promote the strongest options.
I think there are a handful of options, but my favourite is to keep having only `ListMLELoss`, but perhaps use `ListMLELambdaWeight()` as the default `lambda_weight`, so the default is p-ListMLE but ListMLE can be used as well. We can be explicit about this in the model card.
Alternatively, we can rename the `ListMLELoss` class to e.g. `PListMLELoss` and then introduce a new `ListMLELoss` class which simply subclasses `PListMLELoss`, except it doesn't expose a `lambda_weight` option - it just calls the `PListMLELoss` superclass with `lambda_weight=None`. In short, this means we introduce 2 losses with this PR.
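A minimal sketch of what that subclassing could look like (hypothetical, not the final implementation; the signature is assumed from this discussion):

```python
class ListMLELoss(PListMLELoss):
    """Plain ListMLE: p-ListMLE without the positional weighting."""

    def __init__(self, model, activation_fct=None, mini_batch_size=None, respect_input_order=True):
        # Forward everything to PListMLELoss, but pin lambda_weight to None
        super().__init__(
            model,
            lambda_weight=None,
            activation_fct=activation_fct,
            mini_batch_size=mini_batch_size,
            respect_input_order=respect_input_order,
        )
```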
I'm curious about both of your thoughts here!
I strongly agree with your opinion, @tomaarsen.
Since position-aware ListMLELoss (PListMLELoss sounds good) is an advanced version of ListMLELoss, using PListMLELoss for it would probably be better.
If @milistu agrees, I will rename to `PListMLELoss` (which has `lambda_weight` set by default) and add a new loss named `ListMLELoss` which calls the `PListMLELoss` superclass with `lambda_weight=None`.
```python
respect_input_order (bool): Whether to respect the original input order of documents.
    If True, assumes the input documents are already ordered by relevance (most relevant first).
    If False, sorts documents by label values. Defaults to True.
```
If `respect_input_order` is True, is there a difference between

```
docs: (doc1, doc2, doc3, doc4, doc5)
labels: (1, 1, 0, 0, 0)
```

and

```
docs: (doc1, doc2, doc5, doc3, doc4)
labels: (1, 1, 0, 0, 0)
```

e.g. if the ones with label=0 are swapped.
Also, if `respect_input_order` is True, and the inputs are e.g.:

```
docs: (doc1, doc2, doc3, doc4, doc5)
labels: (1, 0, 0, 1, 0)
```

would that give incorrect results?
Otherwise it might make sense to do a quick check whether the docs are sorted, and sort them only if they are not.
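A rough per-query sketch of such a check (hypothetical helper, not part of the PR; assumes plain Python lists `docs` and `labels`):

```python
def ensure_sorted(docs: list[str], labels: list[float]) -> tuple[list[str], list[float]]:
    # Only re-sort when the labels are not already in descending order
    if any(a < b for a, b in zip(labels, labels[1:])):
        pairs = sorted(zip(docs, labels), key=lambda pair: pair[1], reverse=True)
        docs, labels = map(list, zip(*pairs))
    return docs, labels
```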
- If `respect_input_order=True`, there is a difference between the two examples you provided. Even though doc3, doc4, and doc5 all have the same label (0), their order matters in ListMLE. The loss function models the sequential selection process and will train the model to select documents in exactly the order provided. So if the order of documents with label 0 changes, the model will learn a different permutation.
- If `respect_input_order=True` and the documents are not sorted by relevance (e.g., labels are [1, 0, 0, 1, 0]), this could indeed lead to suboptimal results for ranking tasks. The model will learn to reproduce this exact order, even though typically we'd want all documents with label 1 to be ranked before documents with label 0.
The intent of the `respect_input_order` argument was to let the model learn a ranking among documents with the same labels. This is particularly useful in scenarios where we have multiple relevant documents with the same label but different degrees of preference or importance that aren't captured by the label values alone (e.g. a tool-selection task where two tools are gold but there is a required sequence for them).
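A toy check of this order sensitivity (a minimal standalone sketch of the Plackett-Luce log-likelihood that ListMLE uses, not the actual loss class; the scores are made up):

```python
import torch

def listmle_nll(scores: torch.Tensor) -> torch.Tensor:
    # Negative Plackett-Luce log-likelihood of selecting the docs in the given order
    suffix_sums = torch.flip(torch.cumsum(torch.flip(scores.exp(), [0]), 0), [0])
    return -(scores - suffix_sums.log()).sum()

scores = torch.tensor([2.0, 1.5, 0.3, 0.2, 0.1])  # model scores for doc1..doc5
print(listmle_nll(scores))                         # order (doc1, doc2, doc3, doc4, doc5)
print(listmle_nll(scores[[0, 1, 4, 2, 3]]))        # doc5 before doc3/doc4: a different loss
```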
```python
# Position-Aware ListMLE with custom weighting function
def custom_discount(ranks):  # e.g. ranks: [1, 2, 3, 4, 5]
    return 1.0 / torch.log1p(ranks)

lambda_weight = losses.ListMLELambdaWeight(rank_discount_fn=custom_discount)
loss = losses.ListMLELoss(model, lambda_weight=lambda_weight)
```
Is this custom weighting function option also proposed in the p-ListMLE paper? (https://auai.org/uai2014/proceedings/individuals/164.pdf)
Or is this 'custom'? Do we have evaluations on whether this works better than e.g. the default lambda weighting?
This custom weighting function option is proposed in the paper (Section 6, Experiments), which says: "In p-ListMLE, we set α(i) as ..." (the formula was an image in the original comment; it corresponds to the default rank discount 2^(list_size − rank) − 1 implemented here).
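For reference, the default discount from this implementation gives, for a 5-document list:

```python
import torch

ranks = torch.arange(1, 6)            # rank positions [1, 2, 3, 4, 5]
print(torch.pow(2.0, 5 - ranks) - 1)  # tensor([15., 7., 3., 1., 0.])
```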
Oh, perfect, I'll have another look at the paper.

Edit: Is the `1.0 / torch.log1p(ranks)` also used anywhere? I can find details about the ...
Oops, sorry. I thought you were pointing to this part: `return torch.pow(2.0, list_size - ranks) - 1.0`. The `custom_discount` function is a 'custom' function. The one I implemented makes it into a LambdaLoss-like function (it only uses the discount, though).
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
I ran the training script with pListMLE:

```python
# Option 2: Position-Aware ListMLE with default weighting
lambda_weight = ListMLELambdaWeight()
loss = ListMLELoss(model, lambda_weight=lambda_weight, mini_batch_size=mini_batch_size, respect_input_order=respect_input_order)
```

and the results have been okay so far - not amazing. Here's the model: https://huggingface.co/tomaarsen/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-plistmle

The loss starts at ~1400 and then falls to ~820. This is a lot higher than usual - that doesn't normally matter, but it might be indicative of an issue somewhere.
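One plausible explanation for the scale (an assumption, though consistent with the `lambda_weight` tensors shown later in this thread): with the default discount, the unnormalized weights of a 10-document list sum to about 10^3, which inflates an O(1)-per-position loss into the observed range:

```python
import torch

ranks = torch.arange(1, 11)               # a 10-document list
weights = torch.pow(2.0, 10 - ranks) - 1  # default p-ListMLE rank discount
print(weights.sum())                      # tensor(1013.), i.e. roughly a 1000x scale-up of the loss
```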
```python
logits = self.activation_fct(logits)

# Create output tensor filled with a very small value for padded logits
logits_matrix = torch.full((batch_size, max_docs), -1e16, device=self.model.device)
```
This is not a very small value, it's a very large negative value (-10000000000000000). cc @milistu this is also the case for e.g. ListNetLoss, although there it's ignored via the labels. Maybe it was meant to be `1e-16`?
Wouldn't using a large negative value (-1e16) be safer for padded positions rather than a small positive value (1e-16)? The large negative value ensures padded positions have effectively zero contribution in softmax operations, while a small positive value might still have some minor influence.
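A quick illustration of that point (a minimal sketch with made-up logits):

```python
import torch

padded = torch.tensor([2.0, 1.0, -1e16])   # pad with a large negative value
print(padded.softmax(0))                   # tensor([0.7311, 0.2689, 0.0000]) -> padding contributes nothing

padded = torch.tensor([2.0, 1.0, 1e-16])   # pad with a small positive value instead
print(padded.softmax(0))                   # tensor([0.6652, 0.2447, 0.0900]) -> padding still gets ~9% mass
```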
Hmm, interesting... It looks a bit better than the model using ListNetLoss (https://huggingface.co/tomaarsen/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-listnet). Could this be due to the data labels? I saw the data you used (https://huggingface.co/datasets/microsoft/ms_marco/viewer/v1.1/train?views%5B%5D=v11_train&row=60) and the label (...). If you're using ...
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
I'm afraid the training script already sorts the pairs by label:

```python
# Pair passages with labels and sort descending by label (positives first)
paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True)
# Separate back to passages and labels
sorted_passages, sorted_labels = zip(*paired) if paired else ([], [])
```

so it's not worse for this reason - I think it might just fall between ListNetLoss and LambdaLoss.
…eights Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
One of the ways to avoid the weirdly high losses might be to call a ...
Can I have your training code or script, please? I would also like to do some experiments.
My exact local code right now is this:

**Training Script**

```python
import logging
import traceback
from datasets import load_dataset
import torch
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from sentence_transformers.cross_encoder.losses import ListMLELoss, ListMLELambdaWeight
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
def main():
model_name = "microsoft/MiniLM-L12-H384-uncased"
# Set the log level to INFO to get more information
logging.basicConfig(
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
# train_batch_size and eval_batch_size inform the size of the batches, while mini_batch_size is used by the loss
# to subdivide the batch into smaller parts. This mini_batch_size largely informs the training speed and memory usage.
# Keep in mind that the loss does not process `train_batch_size` pairs, but `train_batch_size * num_docs` pairs.
train_batch_size = 16
eval_batch_size = 16
mini_batch_size = 16
num_epochs = 1
max_docs = None
respect_input_order = True # Whether to respect the original order of documents
# 1. Define our CrossEncoder model
model = CrossEncoder(model_name, num_labels=1)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
# 2. Load the MS MARCO dataset: https://huggingface.co/datasets/microsoft/ms_marco
logging.info("Read train dataset")
dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")
def listwise_mapper(batch, max_docs: int | None = 10):
processed_queries = []
processed_docs = []
processed_labels = []
for query, passages_info in zip(batch["query"], batch["passages"]):
# Extract passages and labels
passages = passages_info["passage_text"]
labels = passages_info["is_selected"]
# Pair passages with labels and sort descending by label (positives first)
paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True)
# Separate back to passages and labels
sorted_passages, sorted_labels = zip(*paired) if paired else ([], [])
# Filter queries without any positive labels
if max(sorted_labels) < 1.0:
continue
# Truncate to max_docs
if max_docs is not None:
sorted_passages = list(sorted_passages[:max_docs])
sorted_labels = list(sorted_labels[:max_docs])
processed_queries.append(query)
processed_docs.append(sorted_passages)
processed_labels.append(sorted_labels)
return {
"query": processed_queries,
"docs": processed_docs,
"labels": processed_labels,
}
# Create a dataset with a "query" column with strings, a "docs" column with lists of strings,
# and a "labels" column with lists of floats
dataset = dataset.map(
lambda batch: listwise_mapper(batch=batch, max_docs=max_docs),
batched=True,
remove_columns=dataset.column_names,
desc="Processing listwise samples",
)
dataset = dataset.train_test_split(test_size=1_000)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
logging.info(train_dataset)
# 3. Define our training loss - using ListMLELoss with position-aware weighting
# Option 1: Standard ListMLE loss respecting input order
# loss = ListMLELoss(model, mini_batch_size=mini_batch_size, respect_input_order=respect_input_order)
# Option 2: Position-Aware ListMLE with default weighting
lambda_weight = ListMLELambdaWeight()
loss = ListMLELoss(model, lambda_weight=lambda_weight, mini_batch_size=mini_batch_size, respect_input_order=respect_input_order)
# Option 3: Position-Aware ListMLE with custom weighting function (NDCG-like)
# def custom_discount(ranks):
# return 1.0 / torch.log1p(ranks)
# lambda_weight = ListMLELambdaWeight(rank_discount_fn=custom_discount)
# loss = ListMLELoss(
# model,
# lambda_weight=lambda_weight,
# mini_batch_size=mini_batch_size,
# respect_input_order=respect_input_order
# )
# 4. Define the evaluator. We use the CENanoBEIREvaluator, which is a light-weight evaluator for English reranking
evaluator = CrossEncoderNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=eval_batch_size)
evaluator(model)
# 5. Define the training arguments
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"reranker-msmarco-v1.1-{short_model_name}-plistmle"
args = CrossEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=num_epochs,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=eval_batch_size,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
load_best_model_at_end=True,
metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10",
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
save_total_limit=2,
logging_steps=250,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
seed=12,
)
# 6. Create the trainer & start training
trainer = CrossEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=evaluator,
)
trainer.train()
# 7. Evaluate the final model, useful to include these in the model card
evaluator(model)
# 8. Save the final model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)
# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
model.push_to_hub(run_name)
except Exception:
logging.error(
f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
f"and saving it using `model.push_to_hub('{run_name}')`."
)
if __name__ == "__main__":
    main()
```

**ListMLELoss.py**

```python
from __future__ import annotations
import time
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.util import fullname
class ListMLELambdaWeight(nn.Module):
"""Base class for implementing weighting schemes in Position-Aware ListMLE Loss."""
def __init__(self, rank_discount_fn=None) -> None:
"""
Initialize a lambda weight for ListMLE loss.
Args:
rank_discount_fn: Function that computes a discount for each rank position.
If None, uses default discount of 2^(list_size - rank) - 1.
"""
super().__init__()
self.rank_discount_fn = rank_discount_fn
def forward(self, ranks: Tensor, list_size: int) -> Tensor:
"""
Calculate position-aware weights for the ListMLE loss.
Args:
ranks: A tensor of rank positions [batch_size, list_size]
list_size: Size of the list
Returns:
Tensor: Weights for each position [batch_size, list_size]
"""
if self.rank_discount_fn is not None:
return self.rank_discount_fn(ranks)
# Default rank discount: 2^(list_size - rank) - 1
return torch.pow(2.0, list_size - ranks) - 1.0
class ListMLELoss(nn.Module):
def __init__(
self,
model: CrossEncoder,
lambda_weight: ListMLELambdaWeight | None = None,
activation_fct: nn.Module | None = nn.Identity(),
mini_batch_size: int | None = None,
respect_input_order: bool = True,
) -> None:
"""
ListMLE loss for learning to rank with position-aware weighting. This loss function implements
the ListMLE ranking algorithm which uses a list-wise approach based on maximum likelihood
estimation of permutations. It maximizes the likelihood of the permutation induced by the
ground truth labels with optional position-aware weighting.
.. note::
The number of documents per query can vary between samples with the ``ListMLELoss``.
Args:
model (CrossEncoder): CrossEncoder model to be trained
lambda_weight (ListMLELambdaWeight, optional): Weighting scheme to use. When specified,
implements Position-Aware ListMLE which applies different weights to different rank
positions. Default is None (standard ListMLE).
activation_fct (:class:`~torch.nn.Module`): Activation function applied to the logits before computing the
loss. Defaults to :class:`~torch.nn.Identity`.
mini_batch_size (int, optional): Number of samples to process in each forward pass. This has a significant
impact on the memory consumption and speed of the training process. Three cases are possible:
- If ``mini_batch_size`` is None, the ``mini_batch_size`` is set to the batch size.
- If ``mini_batch_size`` is greater than 0, the batch is split into mini-batches of size ``mini_batch_size``.
- If ``mini_batch_size`` is <= 0, the entire batch is processed at once.
Defaults to None.
respect_input_order (bool): Whether to respect the original input order of documents.
If True, assumes the input documents are already ordered by relevance (most relevant first).
If False, sorts documents by label values. Defaults to True.
References:
- Learning to Rank: From Pairwise Approach to Listwise Approach: https://www.microsoft.com/en-us/research/publication/learning-to-rank-from-pairwise-approach-to-listwise-approach/
- Position-Aware ListMLE: A Sequential Learning Process for Ranking: https://auai.org/uai2014/proceedings/individuals/164.pdf
- `Cross Encoder > Training Examples > MS MARCO <../../../examples/cross_encoder/training/ms_marco/README.html>`_
Requirements:
1. Query with multiple documents (listwise approach)
2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.
Inputs:
+----------------------------------------+--------------------------------+-------------------------------+
| Texts | Labels | Number of Model Output Labels |
+========================================+================================+===============================+
| (query, [doc1, doc2, ..., docN]) | [score1, score2, ..., scoreN] | 1 |
+----------------------------------------+--------------------------------+-------------------------------+
Example:
::
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset
model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
"query": ["What are pandas?", "What is the capital of France?"],
"docs": [
["Pandas are a kind of bear.", "Pandas are kind of like fish."],
["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
],
"labels": [[1, 0], [1, 1, 0]],
})
# Standard ListMLE loss respecting input order
loss = losses.ListMLELoss(model)
# Position-Aware ListMLE with default weighting
lambda_weight = losses.ListMLELambdaWeight()
loss = losses.ListMLELoss(model, lambda_weight=lambda_weight)
# Position-Aware ListMLE with custom weighting function
def custom_discount(ranks): # e.g. ranks: [1, 2, 3, 4, 5]
return 1.0 / torch.log1p(ranks)
lambda_weight = losses.ListMLELambdaWeight(rank_discount_fn=custom_discount)
loss = losses.ListMLELoss(model, lambda_weight=lambda_weight)
trainer = CrossEncoderTrainer(
model=model,
train_dataset=train_dataset,
loss=loss,
)
trainer.train()
"""
super().__init__()
self.model = model
self.lambda_weight = lambda_weight
self.activation_fct = activation_fct or nn.Identity()
self.mini_batch_size = mini_batch_size
self.respect_input_order = respect_input_order
self.eps = 1e-10
if self.model.num_labels != 1:
raise ValueError(
f"{self.__class__.__name__} supports a model with 1 output label, "
f"but got a model with {self.model.num_labels} output labels."
)
def forward(self, inputs: list[list[str], list[list[str]]], labels: list[Tensor]) -> Tensor:
"""
Compute ListMLE loss for a batch of queries and their documents.
Args:
inputs: List of (queries, documents_list)
labels: Ground truth relevance scores, shape (batch_size, num_documents)
Returns:
Tensor: Mean ListMLE loss over the batch
"""
if isinstance(labels, Tensor):
raise ValueError(
"ListMLELoss expects a list of labels for each sample, but got a single value for each sample."
)
if len(inputs) != 2:
raise ValueError(
f"ListMLELoss expects two inputs (queries, documents_list), but got {len(inputs)} inputs."
)
queries, docs_list = inputs
docs_per_query = [len(docs) for docs in docs_list]
max_docs = max(docs_per_query)
batch_size = len(queries)
if docs_per_query != [len(labels) for labels in labels]:
raise ValueError(
f"Number of documents per query in inputs ({docs_per_query}) does not match number of labels per query ({[len(labels) for labels in labels]})."
)
pairs = [(query, document) for query, docs in zip(queries, docs_list) for document in docs]
if not pairs:
# Handle edge case where there are no documents
return torch.tensor(0.0, device=self.model.device, requires_grad=True)
mini_batch_size = self.mini_batch_size or batch_size
if mini_batch_size <= 0:
mini_batch_size = len(pairs)
logits_list = []
for i in range(0, len(pairs), mini_batch_size):
mini_batch_pairs = pairs[i : i + mini_batch_size]
tokens = self.model.tokenizer(
mini_batch_pairs,
padding=True,
truncation=True,
return_tensors="pt",
)
tokens = tokens.to(self.model.device)
logits = self.model(**tokens)[0].view(-1)
logits_list.append(logits)
logits = torch.cat(logits_list, dim=0)
logits = self.activation_fct(logits)
# Create output tensor filled with a very small value for padded logits
logits_matrix = torch.full((batch_size, max_docs), 1e-16, device=self.model.device)
# Place logits in the desired positions in the logit matrix
doc_indices = torch.cat([torch.arange(len(docs)) for docs in docs_list], dim=0)
batch_indices = torch.repeat_interleave(torch.arange(batch_size), torch.tensor(docs_per_query))
logits_matrix[batch_indices, doc_indices] = logits
# Create a mask for valid entries
mask = torch.zeros_like(logits_matrix, dtype=torch.bool)
mask[batch_indices, doc_indices] = True
# Convert labels to tensor matrix
labels_matrix = torch.full_like(logits_matrix, -float("inf"))
labels_matrix[batch_indices, doc_indices] = torch.cat(labels, dim=0).float()
if not torch.any(mask):
return torch.tensor(0.0, device=self.model.device, requires_grad=True)
if not self.respect_input_order:
# Sort by labels in descending order if not respecting input order.
sorted_labels, indices = labels_matrix.sort(descending=True, dim=1)
sorted_logits = torch.gather(logits_matrix, 1, indices)
else:
# Use the original input order, assuming it's already ordered by relevance
sorted_logits = logits_matrix
# Compute log-likelihood using Plackett-Luce model
scores = sorted_logits.exp()
cumsum_scores = torch.flip(torch.cumsum(torch.flip(scores, [1]), 1), [1])
log_probs = sorted_logits - torch.log(cumsum_scores + self.eps)
# Apply position-aware lambda weights if specified
if self.lambda_weight is not None:
position_weights = torch.zeros_like(log_probs)
for i, query_mask in enumerate(mask):
list_size = query_mask.sum()
ranks = torch.arange(1, list_size + 1, device=self.model.device)
position_weights[i, :list_size] = self.lambda_weight(ranks, list_size)
log_probs = log_probs * position_weights
# Sum the log probabilities for each list and mask invalid entries
per_query_losses = -torch.sum(log_probs, dim=1)
if not torch.any(per_query_losses):
return torch.tensor(0.0, device=self.model.device, requires_grad=True)
# Average loss over all lists
return torch.mean(per_query_losses)
def get_config_dict(self) -> dict[str, float | int | str | None]:
"""
Get configuration parameters for this loss function.
Returns:
Dictionary containing the configuration parameters
"""
return {
"lambda_weight": None if self.lambda_weight is None else fullname(self.lambda_weight),
"activation_fct": fullname(self.activation_fct),
"mini_batch_size": self.mini_batch_size,
"respect_input_order": self.respect_input_order,
}
@property
def citation(self) -> str:
return """
@inproceedings{lan2013position,
title={Position-aware ListMLE: a sequential learning process for ranking},
author={Lan, Yanyan and Guo, Jiafeng and Cheng, Xueqi and Liu, Tie-Yan},
booktitle={Proceedings of the Twenty-Ninth Conference on Uncertainty in Artificial Intelligence},
pages={333--342},
year={2013}
}
""" In short: pretty much what I proposed + the training script that you prepared with the pListMLELoss from the paper.
Hi @tomaarsen. I've done some experiments and want to share them with you.

**Experiments**

1. pListMLE-Identity
2. pListMLE-customweight-Identity
3. pListMLE-sigmoid
4. pListMLE-customweight-sigmoid
5. pListMLE-tanh

**Total Evaluation Results**

**Train Losses**

(The evaluation tables and train-loss curves were shared as images.)

**Analysis**

The results show that ...

**Thoughts**

The experiments demonstrate that ListMLELoss occupies a valuable position between ListNetLoss and LambdaLoss in terms of performance effectiveness. Where I believe this loss function truly excels is in scenarios where we need to learn the relative ordering among documents with identical gold labels. For instance, consider a ...
To add to that, I also ran just ListMLE: ...

In short, perhaps ListMLE does not make sense (as it presumably focuses much too heavily on the order of the negative documents).
Alright. I think we can now say that using ...
I agree.
Great! I think there are minor things left: ...

Is there anything else?
That's right. There are some minor documentation things to finalize afterwards, but I will take care of those.
I actually worked on most of them, so I will just commit them. It would be great if you could check those.
I've worked on the minor changes. It would be great if you could finalize it! @tomaarsen
@yjoonjang I vectorized the lambda weight computation in 0012d0f; now the function only gets a boolean `mask`. It should be equivalent, and considerably faster (although it wasn't necessarily a bottleneck).
I think this is a great improvement! Moving from explicit rank positions to using a boolean mask makes the code simpler and more efficient. Calculating document counts directly from the mask and ensuring weights are only applied to valid positions seems like a clearer approach to me.
Okay, I think I'm almost done with all of my testing, but I have one more interesting case:

**The issue**

My reasoning is this:

```
tensor([[255., 127., 63., 31., 15., 7., 3., 1., 0., -0.],
[255., 127., 63., 31., 15., 7., 3., 1., 0., -0.],
[511., 255., 127., 63., 31., 15., 7., 3., 1., 0.],
[255., 127., 63., 31., 15., 7., 3., 1., 0., -0.],
[ 63., 31., 15., 7., 3., 1., 0., -0., -0., -0.],
[ 63., 31., 15., 7., 3., 1., 0., -0., -0., -0.],
[127., 63., 31., 15., 7., 3., 1., 0., -0., -0.],
[127., 63., 31., 15., 7., 3., 1., 0., -0., -0.],
[255., 127., 63., 31., 15., 7., 3., 1., 0., -0.],
[255., 127., 63., 31., 15., 7., 3., 1., 0., -0.],
[511., 255., 127., 63., 31., 15., 7., 3., 1., 0.],
[255., 127., 63., 31., 15., 7., 3., 1., 0., -0.],
[255., 127., 63., 31., 15., 7., 3., 1., 0., -0.],
[ 63., 31., 15., 7., 3., 1., 0., -0., -0., -0.],
[127., 63., 31., 15., 7., 3., 1., 0., -0., -0.],
[255., 127., 63., 31., 15., 7., 3., 1., 0., -0.]],
device='cuda:0')
```

This is multiplied with the `log_probs`:

```
> (log_probs * self.lambda_weight(mask)).sum(dim=1)
tensor([-1103.7828, -1107.8832, -2218.6123, -1104.4536, -265.9218, -267.1783,
-544.8917, -544.4240, -1110.4033, -1102.6829, -2228.4136, -1101.6121,
-1105.8870, -266.0999, -544.8442, -1101.8113], device='cuda:0',
grad_fn=<SumBackward1>)
```

Afterwards, we ...

After all, if you sum the ...

If we don't apply this weighting:

```
> torch.sum(log_probs, dim=1)
tensor([-15.2067, -15.1940, -15.0684, -15.1953, -15.2408, -15.2317, -15.2240,
-15.2154, -15.2133, -15.2076, -15.0966, -15.1417, -15.1786, -15.2057,
-15.2024, -15.1925], device='cuda:0', grad_fn=<SumBackward1>)
```

And each query has pretty equal contributions, but the position weighting is not taken into consideration.

**My proposal**

I propose to normalize the `lambda_weight` ...
When there are many docs, softmax starts failing as it only uses the first doc:

```python
>>> torch.tensor([255., 127., 63., 31., 15., 7., 3., 1., 0., -0.]).softmax(0)
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
```

But dividing by the sum might work:

```python
>>> weight = torch.tensor([255., 127., 63., 31., 15., 7., 3., 1., 0., -0.])
>>> weight / weight.sum(0)
tensor([0.5080, 0.2530, 0.1255, 0.0618, 0.0299, 0.0139, 0.0060, 0.0020, 0.0000,
        -0.0000])
```

**My models**

PListMLELoss, but with ...
| Metric | Value |
|---|---|
| map | 0.4797 (+0.0897) |
| mrr@10 | 0.5721 (+0.1040) |
| ndcg@10 | 0.5386 (+0.0832) |
PListMLELoss, but with `/ lambda_weight.sum(dim=1)` over the `lambda_weight`

- Model: https://huggingface.co/tomaarsen/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-plistmle-sum-to-1
- Final results:

| Metric | Value |
|---|---|
| map | 0.4669 (+0.0768) |
| mrr@10 | 0.5474 (+0.0794) |
| ndcg@10 | 0.5240 (+0.0686) |
Graphs compared to default:

As you can see here, the "better" results from the Sigmoid one are primarily just an outlier; it actually ends worse than the others were it not for `load_best_model_at_end`. I'm quite a fan of the "divide by sum" option. This is with just the 2 from this experiment; you can see the effect on the loss too:

The softmax gets a lower loss because it usually only cares about the positive document.
Also, inspecting that `lambda_weight` compared to the `mask`:

```
tensor([[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, True, False]],
device='cuda:0')
```

It seems that the last document (the last `True` per row) always has a weight of 0: with the default discount, rank n in an n-document list gets 2^(n−n) − 1 = 0. That might not be intended.
- Tom Aarsen
I also ran the ...
I'm tempted to merge in the sum-to-1 approach:

```python
# Apply position-aware lambda weights if specified. If None, then this loss
# is just ListMLE.
if self.lambda_weight is not None:
    lambda_weight = self.lambda_weight(mask)
    # Normalize weights to sum to 1
    lambda_weight_sum = lambda_weight.sum(dim=1, keepdim=True) + self.eps
    lambda_weight = lambda_weight / lambda_weight_sum
    log_probs = log_probs * lambda_weight
```
Thank you for sharing your insights throughout the experiment, @tomaarsen.

**Problem for the last document weight**

This indeed is a problem, and I propose to fix the code from `torch.pow(2.0, num_docs_per_query - ranks) - 1.0` to `torch.pow(2.0, num_docs_per_query - ranks + 1) - 1.0`, so the last document gets a weight of 1 instead of 0.

**Normalization methods**

After fixing the issue for the ranks, I did some experiments related to normalization. Since I also thought that the loss was too high, I think your approach is great. Besides your normalization methods (softmax, sum-to-1), I've experimented with other methods too: minmax, log, and temperature (the code is provided below).

**Code**

```python
from __future__ import annotations
import torch
from torch import Tensor, nn
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.util import fullname
class PListMLELambdaWeight(nn.Module):
"""Base class for implementing weighting schemes in Position-Aware ListMLE Loss."""
def __init__(self, rank_discount_fn=None, normalize_weights=True, normalization_method="sum", temperature=0.5) -> None:
"""
Initialize a lambda weight for PListMLE loss.
Args:
rank_discount_fn: Function that computes a discount for each rank position.
If None, uses default discount of 2^(num_docs - rank) - 1.
normalize_weights: Whether to normalize weights to sum to 1 per query.
If True, each query has equal contribution to the loss regardless of document count.
normalization_method: Method to use for normalization. Options:
- "sum": Normalize to sum to 1 (default)
- "softmax": Use softmax normalization
- "minmax": Min-max normalization to [0,1] range
- "log": Log normalization
- "temperature": Softmax with temperature scaling
temperature: Temperature parameter for softmax scaling (default: 0.5).
Lower values make distribution more peaked at high weights.
Only used when normalization_method="temperature".
"""
super().__init__()
self.rank_discount_fn = rank_discount_fn
self.normalize_weights = normalize_weights
self.normalization_method = normalization_method
self.temperature = temperature
def forward(self, mask: Tensor) -> Tensor:
"""
Calculate position-aware weights for the PListMLE loss.
Args:
mask: A boolean mask indicating valid positions [batch_size, num_docs]
Returns:
Tensor: Weights for each position [batch_size, num_docs]
"""
if self.rank_discount_fn is not None:
weights = self.rank_discount_fn(mask)
else:
# Apply default rank discount: 2^(num_docs - rank) - 1
num_docs_per_query = mask.sum(dim=1, keepdim=True)
ranks = torch.arange(1, mask.size(1) + 1, device=mask.device).expand_as(mask)
# weights = torch.pow(2.0, num_docs_per_query - ranks) - 1.0
weights = torch.pow(2.0, num_docs_per_query - ranks + 1) - 1.0
weights = weights * mask
# Normalize weights to sum to 1 for each query if requested
if self.normalize_weights:
if self.normalization_method == "sum":
# Normalize to sum to 1 (original implementation)
weight_sums = weights.sum(dim=1, keepdim=True)
weight_sums = torch.clamp(weight_sums, min=1e-10)
weights = weights / weight_sums
elif self.normalization_method == "softmax":
# Use softmax normalization
# Fix for numerical stability - convert original weights to logits for softmax
logits = torch.log(weights + 1e-10) * mask.float()
weights = torch.softmax(logits, dim=1) * mask.float()
elif self.normalization_method == "minmax":
# Min-max normalization per query
min_vals, _ = torch.min(weights + (1 - mask.float()) * 1e10, dim=1, keepdim=True)
max_vals, _ = torch.max(weights * mask.float(), dim=1, keepdim=True)
# Avoid division by zero
denom = torch.clamp(max_vals - min_vals, min=1e-10)
# Normalize to [0,1] range
weights = ((weights - min_vals) / denom) * mask.float()
# Ensure sum to 1
weight_sums = weights.sum(dim=1, keepdim=True)
weight_sums = torch.clamp(weight_sums, min=1e-10)
weights = weights / weight_sums
elif self.normalization_method == "log":
# Log normalization
# Add small constant for log stability
log_weights = torch.log1p(weights) * mask.float()
weight_sums = log_weights.sum(dim=1, keepdim=True)
weight_sums = torch.clamp(weight_sums, min=1e-10)
weights = log_weights / weight_sums
elif self.normalization_method == "temperature":
# Softmax with temperature scaling
logits = torch.log(weights + 1e-10) * mask.float()
weights = torch.softmax(logits / self.temperature, dim=1) * mask.float()
return weights
def get_config_dict(self) -> dict[str, float | int | str | bool]:
"""
Get configuration parameters for this lambda weight.
Returns:
Dictionary containing the configuration parameters
"""
return {
"rank_discount_fn": None if self.rank_discount_fn is None else fullname(self.rank_discount_fn),
"normalize_weights": self.normalize_weights,
"normalization_method": self.normalization_method,
"temperature": self.temperature,
}
class PListMLELoss(nn.Module):
def __init__(
self,
model: CrossEncoder,
lambda_weight: PListMLELambdaWeight | None = PListMLELambdaWeight(normalize_weights=True),
activation_fct: nn.Module | None = nn.Identity(),
mini_batch_size: int | None = None,
respect_input_order: bool = True,
) -> None:
"""
PListMLE loss for learning to rank with position-aware weighting. This loss function implements
the ListMLE ranking algorithm which uses a list-wise approach based on maximum likelihood
estimation of permutations. It maximizes the likelihood of the permutation induced by the
ground truth labels with position-aware weighting.
This loss is also known as Position-Aware ListMLE or p-ListMLE.
.. note::
The number of documents per query can vary between samples with the ``PListMLELoss``.
Args:
model (CrossEncoder): CrossEncoder model to be trained
lambda_weight (PListMLELambdaWeight, optional): Weighting scheme to use. When specified,
implements Position-Aware ListMLE which applies different weights to different rank
positions. Default is PListMLELambdaWeight with normalize_weights=True, which ensures
each query contributes equally to the loss regardless of document count.
Several normalization methods are available through the PListMLELambdaWeight's
normalization_method parameter:
- "sum": Standard normalization to sum to 1 (default)
- "softmax": Softmax normalization for smoother distribution
- "minmax": Min-max normalization to [0,1] range
- "log": Log normalization for compressing large differences
- "temperature": Softmax with temperature scaling (t=0.5)
Setting normalize_weights=False will make queries with more documents contribute more
to the loss, which may bias training towards these queries.
activation_fct (:class:`~torch.nn.Module`): Activation function applied to the logits before computing the
loss. Defaults to :class:`~torch.nn.Identity`.
mini_batch_size (int, optional): Number of samples to process in each forward pass. This has a significant
impact on the memory consumption and speed of the training process. Three cases are possible:
- If ``mini_batch_size`` is None, the ``mini_batch_size`` is set to the batch size.
- If ``mini_batch_size`` is greater than 0, the batch is split into mini-batches of size ``mini_batch_size``.
- If ``mini_batch_size`` is <= 0, the entire batch is processed at once.
Defaults to None.
respect_input_order (bool): Whether to respect the original input order of documents.
If True, assumes the input documents are already ordered by relevance (most relevant first).
If False, sorts documents by label values. Defaults to True.
References:
- Position-Aware ListMLE: A Sequential Learning Process for Ranking: https://auai.org/uai2014/proceedings/individuals/164.pdf
- `Cross Encoder > Training Examples > MS MARCO <../../../examples/cross_encoder/training/ms_marco/README.html>`_
Requirements:
1. Query with multiple documents (listwise approach)
2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.
3. Documents must be sorted in a defined rank order.
Inputs:
+----------------------------------------+--------------------------------+-------------------------------+
| Texts | Labels | Number of Model Output Labels |
+========================================+================================+===============================+
| (query, [doc1, doc2, ..., docN]) | [score1, score2, ..., scoreN] | 1 |
+----------------------------------------+--------------------------------+-------------------------------+
Recommendations:
- Use :class:`~sentence_transformers.util.mine_hard_negatives` with ``output_format="labeled-list"``
to convert question-answer pairs to the required input format with hard negatives.
Relations:
- The :class:`~sentence_transformers.cross_encoder.losses.PListMLELoss` is an extension of the
:class:`~sentence_transformers.cross_encoder.losses.ListMLELoss` and allows for positional weighting
of the loss. :class:`~sentence_transformers.cross_encoder.losses.PListMLELoss` generally outperforms
:class:`~sentence_transformers.cross_encoder.losses.ListMLELoss` and is recommended over it.
- :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` takes the same inputs, and generally
outperforms this loss.
Example:
::
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset
model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
"query": ["What are pandas?", "What is the capital of France?"],
"docs": [
["Pandas are a kind of bear.", "Pandas are kind of like fish."],
["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
],
"labels": [[1, 0], [1, 1, 0]],
})
# Either: Position-Aware ListMLE with default weighting
lambda_weight = losses.PListMLELambdaWeight()
loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)
# or: Position-Aware ListMLE with custom weighting function
def custom_discount(ranks): # e.g. ranks: [1, 2, 3, 4, 5]
return 1.0 / torch.log1p(ranks)
lambda_weight = losses.PListMLELambdaWeight(rank_discount_fn=custom_discount)
loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)
# or: Position-Aware ListMLE with weights explicitly not normalized (original behavior)
lambda_weight = losses.PListMLELambdaWeight(normalize_weights=False)
loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)
# or: Position-Aware ListMLE with softmax for normalization
# Softmax normalization (may focus more on top documents)
lambda_weight = losses.PListMLELambdaWeight(normalization_method="softmax")
loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)
trainer = CrossEncoderTrainer(
model=model,
train_dataset=train_dataset,
loss=loss,
)
trainer.train()
"""
super().__init__()
self.model = model
self.lambda_weight = lambda_weight
self.activation_fct = activation_fct or nn.Identity()
self.mini_batch_size = mini_batch_size
self.respect_input_order = respect_input_order
self.eps = 1e-10
if self.model.num_labels != 1:
raise ValueError(
f"{self.__class__.__name__} supports a model with 1 output label, "
f"but got a model with {self.model.num_labels} output labels."
)
def forward(self, inputs: list[list[str], list[list[str]]], labels: list[Tensor]) -> Tensor:
"""
Compute PListMLE loss for a batch of queries and their documents.
Args:
inputs: List of (queries, documents_list)
labels: Ground truth relevance scores, shape (batch_size, num_documents)
Returns:
Tensor: Mean PListMLE loss over the batch
"""
if isinstance(labels, Tensor):
raise ValueError(
"PListMLELoss expects a list of labels for each sample, but got a single value for each sample."
)
if len(inputs) != 2:
raise ValueError(
f"PListMLELoss expects two inputs (queries, documents_list), but got {len(inputs)} inputs."
)
queries, docs_list = inputs
docs_per_query = [len(docs) for docs in docs_list]
max_docs = max(docs_per_query)
batch_size = len(queries)
if docs_per_query != [len(labels) for labels in labels]:
raise ValueError(
f"Number of documents per query in inputs ({docs_per_query}) does not match number of labels per query ({[len(labels) for labels in labels]})."
)
pairs = [(query, document) for query, docs in zip(queries, docs_list) for document in docs]
if not pairs:
# Handle edge case where there are no documents
return torch.tensor(0.0, device=self.model.device, requires_grad=True)
mini_batch_size = self.mini_batch_size or batch_size
if mini_batch_size <= 0:
mini_batch_size = len(pairs)
logits_list = []
for i in range(0, len(pairs), mini_batch_size):
mini_batch_pairs = pairs[i : i + mini_batch_size]
tokens = self.model.tokenizer(
mini_batch_pairs,
padding=True,
truncation=True,
return_tensors="pt",
)
tokens = tokens.to(self.model.device)
logits = self.model(**tokens)[0].view(-1)
logits_list.append(logits)
logits = torch.cat(logits_list, dim=0)
logits = self.activation_fct(logits)
# Create output tensor filled with a very small value for padded logits
logits_matrix = torch.full((batch_size, max_docs), 1e-16, device=self.model.device)
# Place logits in the desired positions in the logit matrix
doc_indices = torch.cat([torch.arange(len(docs)) for docs in docs_list], dim=0)
batch_indices = torch.repeat_interleave(torch.arange(batch_size), torch.tensor(docs_per_query))
logits_matrix[batch_indices, doc_indices] = logits
# Create a mask for valid entries
mask = torch.zeros_like(logits_matrix, dtype=torch.bool)
mask[batch_indices, doc_indices] = True
# Convert labels to tensor matrix
labels_matrix = torch.full_like(logits_matrix, -float("inf"))
labels_matrix[batch_indices, doc_indices] = torch.cat(labels, dim=0).float()
if not torch.any(mask):
return torch.tensor(0.0, device=self.model.device, requires_grad=True)
if not self.respect_input_order:
# Sort by labels in descending order if not respecting input order.
sorted_labels, indices = labels_matrix.sort(descending=True, dim=1)
sorted_logits = torch.gather(logits_matrix, 1, indices)
else:
# Use the original input order, assuming it's already ordered by relevance
sorted_logits = logits_matrix
# Compute log-likelihood using Plackett-Luce model
scores = sorted_logits.exp()
cumsum_scores = torch.flip(torch.cumsum(torch.flip(scores, [1]), 1), [1])
log_probs = sorted_logits - torch.log(cumsum_scores + self.eps)
if self.lambda_weight is not None:
log_probs = log_probs * self.lambda_weight(mask)
# Sum the log probabilities for each list and mask padded entries
log_probs[~mask] = 0.0
per_query_losses = -torch.sum(log_probs, dim=1)
if not torch.any(per_query_losses):
return torch.tensor(0.0, device=self.model.device, requires_grad=True)
# Average loss over all lists
return torch.mean(per_query_losses)
def get_config_dict(self) -> dict[str, float | int | str | None]:
"""
Get configuration parameters for this loss function.
Returns:
Dictionary containing the configuration parameters
"""
return {
"lambda_weight": None if self.lambda_weight is None else fullname(self.lambda_weight),
"activation_fct": fullname(self.activation_fct),
"mini_batch_size": self.mini_batch_size,
"respect_input_order": self.respect_input_order,
}
@property
def citation(self) -> str:
return """
@inproceedings{lan2014position,
title={Position-Aware ListMLE: A Sequential Learning Process for Ranking.},
author={Lan, Yanyan and Zhu, Yadong and Guo, Jiafeng and Niu, Shuzi and Cheng, Xueqi},
booktitle={UAI},
volume={14},
pages={449--458},
year={2014}
}
""" Results(Note that the my results can differ with yours since I fixed the '
Evaluation graphs across all methods were shared as images.

As you see in the results, it appears to be ... When comparing ... I actually want to implement ...
That looks quite promising - I didn't do ... 0.54 is definitely higher than we've seen for PListMLE - we've not even seen 0.53... before, whereas your softmax one reaches that in ~5 different evals, so it seems like it's not an outlier.
Bad news (perhaps):

```python
lambda_weight = self.lambda_weight(mask)
# Normalize weights to sum to 1
lambda_weight_sum_1 = lambda_weight.sum(dim=1, keepdim=True) + self.eps
lambda_weight_1 = lambda_weight / lambda_weight_sum_1
lambda_weight_2 = lambda_weight.log().softmax(dim=1)
```

```
(Pdb) lambda_weight_1[-1]
tensor([0.5044, 0.2517, 0.1254, 0.0622, 0.0306, 0.0148, 0.0069, 0.0030, 0.0010,
        0.0000], device='cuda:0')
(Pdb) lambda_weight_2[-1]
tensor([0.5044, 0.2517, 0.1254, 0.0622, 0.0306, 0.0148, 0.0069, 0.0030, 0.0010,
        0.0000], device='cuda:0')
```

So the sum-to-1 and log -> softmax should perform identically: softmax(log w)_i = exp(log w_i) / Σ_j exp(log w_j) = w_i / Σ_j w_j.
Oh, you're right. Mine's a little different though, because it multiplies the mask twice:

```python
elif self.normalization_method == "softmax":
    # Use softmax normalization
    # Fix for numerical stability - convert original weights to logits for softmax
    logits = torch.log(weights + 1e-10) * mask.float()
    weights = torch.softmax(logits, dim=1) * mask.float()
```

But I don't think this is a big deal. We could merge it into the ...
I'm afraid that it seems like the ... I'll merge the ...
I'm going to merge this now; I quite like how it looks. Great work on this @yjoonjang, and I really appreciate the experimentation as well. @milistu I know you're also interested in having a look: you can still do that with the merged version and give feedback if you want.
Nice work and thank you, @tomaarsen!!
**ListMLELoss Implementation for Cross Encoder Trainer**

This PR adds ListMLELoss functionality to the Cross Encoder Trainer.

**Changes**

**Implementation Details**

ListMLELoss is a listwise loss function that optimizes the likelihood of the correct document permutation using the Plackett-Luce model. Key features include: ...

This implementation is particularly useful for information retrieval and recommendation systems where the goal is to present the most relevant items at the top of the list.
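In formula form (matching the implementation in this PR): for model scores $s$ and a target permutation $\pi$ over $n$ documents, the loss is the negative Plackett-Luce log-likelihood, with p-ListMLE weighting each position by $\alpha(i)$:

$$\mathcal{L}(s, \pi) = -\sum_{i=1}^{n} \alpha(i)\, \log \frac{\exp(s_{\pi(i)})}{\sum_{j=i}^{n} \exp(s_{\pi(j)})}, \qquad \alpha(i) = 2^{\,n-i} - 1 \ \text{(default)}$$

With $\alpha(i) = 1$ for all positions (i.e. `lambda_weight=None`), this reduces to standard ListMLE.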
Reference: https://auai.org/uai2014/proceedings/individuals/164.pdf