
Conversation

@yjoonjang commented Mar 17, 2025

ListMLELoss Implementation for Cross Encoder Trainer

This PR adds ListMLELoss functionality to the Cross Encoder Trainer feature.

Changes

  • Implemented ListMLELoss as a listwise loss function
  • Added support for weighting schemes for ListMLELoss
  • Created an example script demonstrating ListMLELoss usage

Implementation Details

ListMLELoss is a listwise loss function that optimizes the likelihood of the correct document permutation using the Plackett-Luce model. Key features include:

  1. Plackett-Luce Model: The loss function uses the Plackett-Luce model to compute the probability of a permutation, which sequentially selects items based on their scores.
  2. Position-Aware Weighting: I've implemented Position-Aware ListMLE with lambda weighting, allowing different weights to be applied to different rank positions. This is particularly useful for emphasizing the importance of correctly ranking items at top positions.
  3. Flexible Input Handling: The implementation supports variable-length document lists per query and efficiently handles mini-batching for better memory management.
  4. Input Order Respect: The implementation can either respect the original input order of documents (assuming they're already ordered by relevance) or sort them by label values.
  5. Numerical Stability: Special care has been taken to ensure numerical stability in the computation of log probabilities.

This implementation is particularly useful for information retrieval and recommendation systems where the goal is to present the most relevant items at the top of the list.

Reference: https://auai.org/uai2014/proceedings/individuals/164.pdf
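
For illustration, here is a minimal sketch of the core Plackett-Luce computation (simplified: the padding, masking, mini-batching, and lambda weighting of the actual implementation are omitted):

    import torch

    def listmle_loss(scores: torch.Tensor) -> torch.Tensor:
        # scores: [batch_size, n_docs], already sorted most-relevant-first.
        # Plackett-Luce: P(permutation) = prod_i exp(s_i) / sum_{j >= i} exp(s_j),
        # where a reversed cumulative sum yields sum_{j >= i} exp(s_j) per position.
        reversed_cumsum = torch.flip(torch.cumsum(torch.flip(scores.exp(), [1]), 1), [1])
        log_probs = scores - torch.log(reversed_cumsum + 1e-10)
        return -log_probs.sum(dim=1)  # negative log-likelihood per query

    print(listmle_loss(torch.tensor([[3.0, 2.0, 1.0]])))  # ~0.72: well-ordered list, low loss
    print(listmle_loss(torch.tensor([[1.0, 2.0, 3.0]])))  # ~3.72: reversed list, high loss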

@tomaarsen (Owner) left a comment:

Very nice work on this - I will be training some models with this to experiment. I have some questions and comments, with the goal of making this as easy as possible for users.

return self.rank_discount_fn(ranks)


class ListMLELoss(nn.Module):
@tomaarsen (Owner):

@yjoonjang @milistu

From a user perspective, it might not be immediately obvious that this supports both the "standard" ListMLE loss as well as the "improved" position-aware ListMLE loss. Beyond that, my experience is that 95% of users will use the default settings, so we want to promote the strongest options.
There are a handful of options, but my favourite is to keep having only ListMLELoss, but perhaps use ListMLELambdaWeight() as the default lambda_weight, so that the default is p-ListMLE while plain ListMLE can still be used. We can be explicit about this in the model card.

Alternatively, we can rename the ListMLELoss class to e.g. PListMLELoss and then introduce a new ListMLELoss class which simply subclasses PListMLELoss except doesn't expose a lambda_weight option - it just calls the PListMLELoss superclass with lambda_weight=None. In short, this means we introduce 2 losses with this PR.

I'm curious about both of your thoughts here!
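
For concreteness, a minimal sketch of that second option could look as follows (class names and signatures are only illustrative at this point in the discussion):

    from torch import nn

    class PListMLELoss(nn.Module):
        """Position-aware ListMLE; lambda_weight=None recovers plain ListMLE."""

        def __init__(self, model, lambda_weight=None) -> None:
            super().__init__()
            self.model = model
            self.lambda_weight = lambda_weight

    class ListMLELoss(PListMLELoss):
        """Standard ListMLE: the superclass with lambda_weight fixed to None."""

        def __init__(self, model) -> None:
            super().__init__(model, lambda_weight=None)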

@yjoonjang (Author):

I strongly agree with your opinion, @tomaarsen.
Since Position-Aware ListMLELoss (PListMLELoss sounds good) is an advanced version of ListMLELoss, using PListMLELoss would probably be better.

If @milistu agrees, I will rename it to PListMLELoss (which has a lambda_weight by default) and add a new loss named ListMLELoss which calls the PListMLELoss superclass with lambda_weight=None.

Comment on lines +76 to +78
respect_input_order (bool): Whether to respect the original input order of documents.
If True, assumes the input documents are already ordered by relevance (most relevant first).
If False, sorts documents by label values. Defaults to True.
@tomaarsen (Owner):

If respect_input_order is True, is there a difference between

docs: (doc1, doc2, doc3, doc4, doc5)
labels: (1, 1, 0, 0, 0)

and

docs: (doc1, doc2, doc5, doc3, doc4)
labels: (1, 1, 0, 0, 0)

e.g. if the ones with label=0 are swapped.

Also, if respect_input_order is True, and the inputs are e.g.:

docs: (doc1, doc2, doc3, doc4, doc5)
labels: (1, 0, 0, 1, 0)

would that give incorrect results?

Otherwise, it might make sense to do a quick check whether the docs are sorted, and only sort them if they are not (see the sketch below).
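
A hypothetical helper for such a check could be as simple as:

    def is_sorted_descending(labels: list[float]) -> bool:
        # True if every label is >= the next one, i.e. already ordered by relevance
        return all(a >= b for a, b in zip(labels, labels[1:]))

    assert is_sorted_descending([1, 1, 0, 0, 0])
    assert not is_sorted_descending([1, 0, 0, 1, 0])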

@yjoonjang (Author):

  1. If respect_input_order=True, there is a difference between the two examples you provided. Even though doc3, doc4, and doc5 all have the same label (0), their order matters in ListMLE. The loss function models the sequential selection process and will train the model to select documents in exactly the order provided. So if the order of documents with label 0 changes, the model will learn a different permutation.

  2. If respect_input_order=True and the documents are not sorted by relevance (e.g., labels are [1, 0, 0, 1, 0]), this could indeed lead to suboptimal results for ranking tasks. The model will learn to reproduce this exact order, even though typically we'd want all documents with label 1 to be ranked before documents with label 0.

The intent of the respect_input_order argument was to let the model learn a ranking among documents with the same label. This is particularly useful in scenarios where we have multiple relevant documents with the same label but different degrees of preference or importance that aren't captured by the label values alone (e.g., a tool selection task where two tools are both gold, but there is a required sequence for the tools). A hypothetical sample is sketched below.
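
For instance, a made-up sample for such a task (the tool names are placeholders):

    # Both tools are gold (label 1), but their input order encodes the preferred sequence
    train_sample = {
        "query": "Book a flight, then reserve a hotel",
        "docs": ["flight_booking_tool", "hotel_reservation_tool", "weather_tool"],
        "labels": [1, 1, 0],
    }
    # With respect_input_order=True, the loss trains the model to reproduce this exact
    # permutation, so flight_booking_tool is pushed to score above hotel_reservation_tool.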

Comment on lines 119 to 123
# Position-Aware ListMLE with custom weighting function
def custom_discount(ranks): # e.g. ranks: [1, 2, 3, 4, 5]
    return 1.0 / torch.log1p(ranks)
lambda_weight = losses.ListMLELambdaWeight(rank_discount_fn=custom_discount)
loss = losses.ListMLELoss(model, lambda_weight=lambda_weight)
@tomaarsen (Owner):

Is this custom weighting function option also proposed in the p-ListMLE paper? (https://auai.org/uai2014/proceedings/individuals/164.pdf)
Or is this 'custom'? Do we have evaluations on whether this works better than e.g. the default lambda weighting?

@yjoonjang (Author):

This custom weighting function option is proposed in the paper (Section 6, Experiments), which says: "In p-ListMLE, we set α(i) as $2^{n-i} - 1$, as guided in the above section."
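
Concretely, for a list of n = 5 documents, the paper's weights look like this:

    import torch

    n = 5                               # list size
    i = torch.arange(1, n + 1)          # rank positions 1..n
    alpha = torch.pow(2.0, n - i) - 1   # tensor([15., 7., 3., 1., 0.])
    # Note: the last rank receives weight 0 under this scheme.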

@tomaarsen (Owner) commented Mar 17, 2025:

Oh, perfect, I'll have another look at the paper.

Edit: Is the 1.0 / torch.log1p(ranks) also used anywhere? I can find details about the $2^{n-i} - 1$, but not about this one.

@yjoonjang (Author):

Oops, sorry. I thought you were pointing to this part: return torch.pow(2.0, list_size - ranks) - 1.0

The custom_discount function is a 'custom' function. The one I implemented makes the loss behave like a LambdaLoss-style function (it only uses the discount component, though).
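
For comparison, the two discounts decay very differently over a list of 5 (values rounded):

    import torch

    ranks = torch.arange(1, 6).float()
    paper_weight = torch.pow(2.0, 5 - ranks) - 1.0  # tensor([15., 7., 3., 1., 0.])
    ndcg_like = 1.0 / torch.log1p(ranks)            # tensor([1.4427, 0.9102, 0.7213, 0.6213, 0.5581])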

yjoonjang and others added 3 commits March 17, 2025 20:00
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
@tomaarsen (Owner):

I ran the training script with pListMLE:

    # Option 2: Position-Aware ListMLE with default weighting
    lambda_weight = ListMLELambdaWeight()
    loss = ListMLELoss(model, lambda_weight=lambda_weight, mini_batch_size=mini_batch_size, respect_input_order=respect_input_order) 

and the results have been okay so far - not amazing.

Here's the model: https://huggingface.co/tomaarsen/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-plistmle
Here's the same model but with LambdaLoss for comparison: https://huggingface.co/tomaarsen/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-lambdaloss

Here's some logs: *(training log screenshots)*

The loss starts at ~1400 and then falls to ~820. This is a lot higher than usual - doesn't normally matter, but it might be indicative of an issue somewhere.

- Tom Aarsen

logits = self.activation_fct(logits)

# Create output tensor filled with a very small value for padded logits
logits_matrix = torch.full((batch_size, max_docs), -1e16, device=self.model.device)
@tomaarsen (Owner) commented Mar 17, 2025:
This is not a very small value, it's a very large negative value (-10000000000000000). cc @milistu this is also the case for e.g. ListNetLoss, although it's ignored via the labels.
Maybe it was meant to be 1e-16

@yjoonjang (Author):

Wouldn't using a large negative value (-1e16) be safer for padded positions rather than a small positive value (1e-16)? The large negative value ensures padded positions have effectively zero contribution in softmax operations, while a small positive value might still have some minor influence.
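
A quick demonstration of that point (values rounded):

    import torch

    # Padding with a large negative value: exp(-1e16) == 0, so no probability mass leaks
    print(torch.tensor([2.0, 1.0, -1e16]).softmax(dim=0))  # tensor([0.7311, 0.2689, 0.0000])

    # Padding with a small positive value: exp(1e-16) ~ 1, the pad still gets ~9% of the mass
    print(torch.tensor([2.0, 1.0, 1e-16]).softmax(dim=0))  # tensor([0.6652, 0.2447, 0.0900])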

@yjoonjang (Author):
> I ran the training script with pListMLE [...] The loss starts at ~1400 and then falls to ~820. This is a lot higher than usual - doesn't normally matter, but it might be indicative of an issue somewhere.

Hmm, interesting. It looks a bit better than the model using ListNetLoss (https://huggingface.co/tomaarsen/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-listnet).

Could this be due to the data labels? I saw that in the data you used (https://huggingface.co/datasets/microsoft/ms_marco/viewer/v1.1/train?views%5B%5D=v11_train&row=60), the label (is_selected) does not look sorted in descending order, which is related to the issue you mentioned above (#6 (comment)).

If you're using respect_input_order=True, I think the labels and the data should be sorted in descending order.

@tomaarsen (Owner):

Good call! Will investigate. And indeed, it's a bit better than listnet, promising:
*(evaluation comparison screenshot)*

I don't think it's necessarily weird for this to fall between listnet and lambdaloss.

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
@tomaarsen (Owner):

I'm afraid the training script already sorts the pairs by label:

            # Pair passages with labels and sort descending by label (positives first)
            paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True)

            # Separate back to passages and labels
            sorted_passages, sorted_labels = zip(*paired) if paired else ([], [])

so it's not worse for this reason - I think it might just fall between ListNetLoss and LambdaLoss.

…eights

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
@tomaarsen (Owner) commented Mar 17, 2025:

One of the ways to avoid the weirdly high losses might be to call a softmax() on the position_weights - I'm not sure if that's a proper solution, but it does turn the losses from e.g. 900 into e.g. 13. I'm trying to figure it out using some other implementations:

@yjoonjang (Author):

> One of the ways to avoid the weirdly high losses might be to call a softmax() on the position_weights [...]

Can I have your training code or script please? I would also like to do some experiments.

@tomaarsen (Owner):

My exact local code right now is this:

Training Script
import logging
import traceback

from datasets import load_dataset
import torch

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from sentence_transformers.cross_encoder.losses import ListMLELoss, ListMLELambdaWeight
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments


def main():
    model_name = "microsoft/MiniLM-L12-H384-uncased"

    # Set the log level to INFO to get more information
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    # train_batch_size and eval_batch_size inform the size of the batches, while mini_batch_size is used by the loss
    # to subdivide the batch into smaller parts. This mini_batch_size largely informs the training speed and memory usage.
    # Keep in mind that the loss does not process `train_batch_size` pairs, but `train_batch_size * num_docs` pairs.
    train_batch_size = 16
    eval_batch_size = 16
    mini_batch_size = 16
    num_epochs = 1
    max_docs = None
    respect_input_order = True  # Whether to respect the original order of documents

    # 1. Define our CrossEncoder model
    model = CrossEncoder(model_name, num_labels=1)
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2. Load the MS MARCO dataset: https://huggingface.co/datasets/microsoft/ms_marco
    logging.info("Read train dataset")
    dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")

    def listwise_mapper(batch, max_docs: int | None = 10):
        processed_queries = []
        processed_docs = []
        processed_labels = []

        for query, passages_info in zip(batch["query"], batch["passages"]):
            # Extract passages and labels
            passages = passages_info["passage_text"]
            labels = passages_info["is_selected"]

            # Pair passages with labels and sort descending by label (positives first)
            paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True)

            # Separate back to passages and labels
            sorted_passages, sorted_labels = zip(*paired) if paired else ([], [])

            # Filter queries without any positive labels
            if max(sorted_labels) < 1.0:
                continue

            # Truncate to max_docs
            if max_docs is not None:
                sorted_passages = list(sorted_passages[:max_docs])
                sorted_labels = list(sorted_labels[:max_docs])

            processed_queries.append(query)
            processed_docs.append(sorted_passages)
            processed_labels.append(sorted_labels)

        return {
            "query": processed_queries,
            "docs": processed_docs,
            "labels": processed_labels,
        }

    # Create a dataset with a "query" column with strings, a "docs" column with lists of strings,
    # and a "labels" column with lists of floats
    dataset = dataset.map(
        lambda batch: listwise_mapper(batch=batch, max_docs=max_docs),
        batched=True,
        remove_columns=dataset.column_names,
        desc="Processing listwise samples",
    )

    dataset = dataset.train_test_split(test_size=1_000)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]
    logging.info(train_dataset)

    # 3. Define our training loss - using ListMLELoss with position-aware weighting
    
    # Option 1: Standard ListMLE loss respecting input order
    # loss = ListMLELoss(model, mini_batch_size=mini_batch_size, respect_input_order=respect_input_order)
    
    # Option 2: Position-Aware ListMLE with default weighting
    lambda_weight = ListMLELambdaWeight()
    loss = ListMLELoss(model, lambda_weight=lambda_weight, mini_batch_size=mini_batch_size, respect_input_order=respect_input_order)
    
    # Option 3: Position-Aware ListMLE with custom weighting function (NDCG-like)
    # def custom_discount(ranks):
    #     return 1.0 / torch.log1p(ranks)
    
    # lambda_weight = ListMLELambdaWeight(rank_discount_fn=custom_discount)
    # loss = ListMLELoss(
    #     model, 
    #     lambda_weight=lambda_weight, 
    #     mini_batch_size=mini_batch_size, 
    #     respect_input_order=respect_input_order
    # )

    # 4. Define the evaluator. We use the CENanoBEIREvaluator, which is a light-weight evaluator for English reranking
    evaluator = CrossEncoderNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=eval_batch_size)
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-msmarco-v1.1-{short_model_name}-plistmle"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        logging_steps=250,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
ListMLELoss.py
from __future__ import annotations

import torch
from torch import Tensor, nn

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.util import fullname


class ListMLELambdaWeight(nn.Module):
    """Base class for implementing weighting schemes in Position-Aware ListMLE Loss."""

    def __init__(self, rank_discount_fn=None) -> None:
        """
        Initialize a lambda weight for ListMLE loss.

        Args:
            rank_discount_fn: Function that computes a discount for each rank position.
                              If None, uses default discount of 2^(list_size - rank) - 1.
        """
        super().__init__()
        self.rank_discount_fn = rank_discount_fn

    def forward(self, ranks: Tensor, list_size: int) -> Tensor:
        """
        Calculate position-aware weights for the ListMLE loss.

        Args:
            ranks: A tensor of rank positions [batch_size, list_size]
            list_size: Size of the list

        Returns:
            Tensor: Weights for each position [batch_size, list_size]
        """
        if self.rank_discount_fn is not None:
            return self.rank_discount_fn(ranks)

        # Default rank discount: 2^(list_size - rank) - 1
        return torch.pow(2.0, list_size - ranks) - 1.0


class ListMLELoss(nn.Module):
    def __init__(
        self,
        model: CrossEncoder,
        lambda_weight: ListMLELambdaWeight | None = None,
        activation_fct: nn.Module | None = nn.Identity(),
        mini_batch_size: int | None = None,
        respect_input_order: bool = True,
    ) -> None:
        """
        ListMLE loss for learning to rank with position-aware weighting. This loss function implements 
        the ListMLE ranking algorithm which uses a list-wise approach based on maximum likelihood 
        estimation of permutations. It maximizes the likelihood of the permutation induced by the 
        ground truth labels with optional position-aware weighting.

        .. note::

            The number of documents per query can vary between samples with the ``ListMLELoss``.

        Args:
            model (CrossEncoder): CrossEncoder model to be trained
            lambda_weight (ListMLELambdaWeight, optional): Weighting scheme to use. When specified,
                implements Position-Aware ListMLE which applies different weights to different rank 
                positions. Default is None (standard ListMLE).
            activation_fct (:class:`~torch.nn.Module`): Activation function applied to the logits before computing the
                loss. Defaults to :class:`~torch.nn.Identity`.
            mini_batch_size (int, optional): Number of samples to process in each forward pass. This has a significant
                impact on the memory consumption and speed of the training process. Three cases are possible:

                - If ``mini_batch_size`` is None, the ``mini_batch_size`` is set to the batch size.
                - If ``mini_batch_size`` is greater than 0, the batch is split into mini-batches of size ``mini_batch_size``.
                - If ``mini_batch_size`` is <= 0, the entire batch is processed at once.

                Defaults to None.
            respect_input_order (bool): Whether to respect the original input order of documents.
                If True, assumes the input documents are already ordered by relevance (most relevant first).
                If False, sorts documents by label values. Defaults to True.

        References:
            - Learning to Rank: From Pairwise Approach to Listwise Approach: https://www.microsoft.com/en-us/research/publication/learning-to-rank-from-pairwise-approach-to-listwise-approach/
            - Position-Aware ListMLE: A Sequential Learning Process for Ranking: https://auai.org/uai2014/proceedings/individuals/164.pdf
            - `Cross Encoder > Training Examples > MS MARCO <../../../examples/cross_encoder/training/ms_marco/README.html>`_

        Requirements:
            1. Query with multiple documents (listwise approach)
            2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.

        Inputs:
            +----------------------------------------+--------------------------------+-------------------------------+
            | Texts                                  | Labels                         | Number of Model Output Labels |
            +========================================+================================+===============================+
            | (query, [doc1, doc2, ..., docN])       | [score1, score2, ..., scoreN]  | 1                             |
            +----------------------------------------+--------------------------------+-------------------------------+

        Example:
            ::

                from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
                from datasets import Dataset

                model = CrossEncoder("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "query": ["What are pandas?", "What is the capital of France?"],
                    "docs": [
                        ["Pandas are a kind of bear.", "Pandas are kind of like fish."],
                        ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
                    ],
                    "labels": [[1, 0], [1, 1, 0]],
                })
                
                # Standard ListMLE loss respecting input order
                loss = losses.ListMLELoss(model)
                
                # Position-Aware ListMLE with default weighting
                lambda_weight = losses.ListMLELambdaWeight()
                loss = losses.ListMLELoss(model, lambda_weight=lambda_weight)
                
                # Position-Aware ListMLE with custom weighting function
                def custom_discount(ranks): # e.g. ranks: [1, 2, 3, 4, 5]
                    return 1.0 / torch.log1p(ranks)
                lambda_weight = losses.ListMLELambdaWeight(rank_discount_fn=custom_discount)
                loss = losses.ListMLELoss(model, lambda_weight=lambda_weight)

                trainer = CrossEncoderTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.lambda_weight = lambda_weight
        self.activation_fct = activation_fct or nn.Identity()
        self.mini_batch_size = mini_batch_size
        self.respect_input_order = respect_input_order
        self.eps = 1e-10

        if self.model.num_labels != 1:
            raise ValueError(
                f"{self.__class__.__name__} supports a model with 1 output label, "
                f"but got a model with {self.model.num_labels} output labels."
            )

    def forward(self, inputs: list[list[str], list[list[str]]], labels: list[Tensor]) -> Tensor:
        """
        Compute ListMLE loss for a batch of queries and their documents.

        Args:
            inputs: List of (queries, documents_list)
            labels: Ground truth relevance scores, shape (batch_size, num_documents)

        Returns:
            Tensor: Mean ListMLE loss over the batch
        """
        if isinstance(labels, Tensor):
            raise ValueError(
                "ListMLELoss expects a list of labels for each sample, but got a single value for each sample."
            )

        if len(inputs) != 2:
            raise ValueError(
                f"ListMLELoss expects two inputs (queries, documents_list), but got {len(inputs)} inputs."
            )

        queries, docs_list = inputs
        docs_per_query = [len(docs) for docs in docs_list]
        max_docs = max(docs_per_query)
        batch_size = len(queries)

        if docs_per_query != [len(labels) for labels in labels]:
            raise ValueError(
                f"Number of documents per query in inputs ({docs_per_query}) does not match number of labels per query ({[len(labels) for labels in labels]})."
            )

        pairs = [(query, document) for query, docs in zip(queries, docs_list) for document in docs]

        if not pairs:
            # Handle edge case where there are no documents
            return torch.tensor(0.0, device=self.model.device, requires_grad=True)

        mini_batch_size = self.mini_batch_size or batch_size
        if mini_batch_size <= 0:
            mini_batch_size = len(pairs)

        logits_list = []
        for i in range(0, len(pairs), mini_batch_size):
            mini_batch_pairs = pairs[i : i + mini_batch_size]

            tokens = self.model.tokenizer(
                mini_batch_pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            tokens = tokens.to(self.model.device)

            logits = self.model(**tokens)[0].view(-1)
            logits_list.append(logits)

        logits = torch.cat(logits_list, dim=0)
        logits = self.activation_fct(logits)

        # Create output tensor filled with a very small value for padded logits
        logits_matrix = torch.full((batch_size, max_docs), 1e-16, device=self.model.device)

        # Place logits in the desired positions in the logit matrix
        doc_indices = torch.cat([torch.arange(len(docs)) for docs in docs_list], dim=0)
        batch_indices = torch.repeat_interleave(torch.arange(batch_size), torch.tensor(docs_per_query))
        logits_matrix[batch_indices, doc_indices] = logits

        # Create a mask for valid entries
        mask = torch.zeros_like(logits_matrix, dtype=torch.bool)
        mask[batch_indices, doc_indices] = True

        # Convert labels to tensor matrix
        labels_matrix = torch.full_like(logits_matrix, -float("inf"))
        labels_matrix[batch_indices, doc_indices] = torch.cat(labels, dim=0).float()
        
        if not torch.any(mask):
            return torch.tensor(0.0, device=self.model.device, requires_grad=True)

        if not self.respect_input_order:
            # Sort by labels in descending order if not respecting input order.
            sorted_labels, indices = labels_matrix.sort(descending=True, dim=1)
            sorted_logits = torch.gather(logits_matrix, 1, indices)
        else:
            # Use the original input order, assuming it's already ordered by relevance
            sorted_logits = logits_matrix

        # Compute log-likelihood using Plackett-Luce model
        scores = sorted_logits.exp()
        cumsum_scores = torch.flip(torch.cumsum(torch.flip(scores, [1]), 1), [1])
        log_probs = sorted_logits - torch.log(cumsum_scores + self.eps)

        # Apply position-aware lambda weights if specified
        if self.lambda_weight is not None:
            position_weights = torch.zeros_like(log_probs)
            for i, query_mask in enumerate(mask):
                list_size = query_mask.sum()
                ranks = torch.arange(1, list_size + 1, device=self.model.device)
                position_weights[i, :list_size] = self.lambda_weight(ranks, list_size)
            log_probs = log_probs * position_weights

        # Sum the log probabilities for each list and mask invalid entries
        per_query_losses = -torch.sum(log_probs, dim=1)

        if not torch.any(per_query_losses):
            return torch.tensor(0.0, device=self.model.device, requires_grad=True)
            
        # Average loss over all lists
        return torch.mean(per_query_losses)

    def get_config_dict(self) -> dict[str, float | int | str | None]:
        """
        Get configuration parameters for this loss function.

        Returns:
            Dictionary containing the configuration parameters
        """
        return {
            "lambda_weight": None if self.lambda_weight is None else fullname(self.lambda_weight),
            "activation_fct": fullname(self.activation_fct),
            "mini_batch_size": self.mini_batch_size,
            "respect_input_order": self.respect_input_order,
        }

    @property
    def citation(self) -> str:
        return """
@inproceedings{lan2013position,
    title={Position-aware ListMLE: a sequential learning process for ranking},
    author={Lan, Yanyan and Guo, Jiafeng and Cheng, Xueqi and Liu, Tie-Yan},
    booktitle={Proceedings of the Twenty-Ninth Conference on Uncertainty in Artificial Intelligence},
    pages={333--342},
    year={2013}
}
"""

In short: pretty much what I proposed + the training script that you prepared with the pListMLELoss from the paper.

- Tom Aarsen

@yjoonjang (Author):

Hi @tomaarsen. I've done some experiments and want to share the results with you.

Experiments

1. pListMLE-Identity

| Metric  | Value            |
|---------|------------------|
| map     | 0.4824 (+0.0923) |
| mrr@10  | 0.5513 (+0.0833) |
| ndcg@10 | 0.5260 (+0.0707) |

2. pListMLE-customweight-Identity

| Metric  | Value            |
|---------|------------------|
| map     | 0.4202 (+0.0302) |
| mrr@10  | 0.4766 (+0.0086) |
| ndcg@10 | 0.4636 (+0.0082) |

3. pListMLE-sigmoid

| Metric  | Value            |
|---------|------------------|
| map     | 0.4503 (+0.0603) |
| mrr@10  | 0.5051 (+0.0371) |
| ndcg@10 | 0.4915 (+0.0361) |

4. pListMLE-customweight-sigmoid

| Metric  | Value            |
|---------|------------------|
| map     | 0.1109 (-0.2791) |
| mrr@10  | 0.1556 (-0.3124) |
| ndcg@10 | 0.1356 (-0.3198) |

5. pListMLE-tanh

| Metric  | Value            |
|---------|------------------|
| map     | 0.4587 (+0.0686) |
| mrr@10  | 0.5262 (+0.0582) |
| ndcg@10 | 0.4976 (+0.0423) |

Total Evaluation Results

*(total evaluation results chart)*

Train Losses

*(train loss curves)*

Analysis

The results show that:

  1. Using customweight indeed makes the loss values smaller, but does not improve the model performance
  2. Adding activation functions (Sigmoid, Tanh) slightly reduces performance compared to no activation
  3. Combining custom weights with Sigmoid causes dramatic performance drop
  4. The simplest configuration (standard pListMLE with default weights) achieves the highest evaluation scores

Thoughts

The experiments demonstrate that ListMLELoss occupies a valuable position between ListNetLoss and LambdaLoss in terms of performance. Where I believe this loss function truly excels is in scenarios where we need to learn the relative ordering among documents with identical gold labels.

For instance, consider a tool selection task where, given a query, the model must select multiple tools in a specific sequence - all correct tools might have the same binary relevance label (1), but their ordering matters significantly. In such cases, ListMLELoss would be particularly effective as it naturally models permutation probabilities through the Plackett-Luce model, allowing it to learn these subtle ordering relationships even when the labels themselves don't differentiate between equally relevant items. This makes it an excellent choice for tasks requiring sequential decision-making or preserving specific ordering among equally "correct" options.

@tomaarsen (Owner):

To add to that, I also ran just ListMLE, i.e. no lambda weight:

ListMLE

| Metric  | Value            |
|---------|------------------|
| map     | 0.3559 (-0.0341) |
| mrr@10  | 0.3994 (-0.0686) |
| ndcg@10 | 0.3898 (-0.0655) |

In short, perhaps ListMLE does not make sense (as it presumably focuses much too heavily on the order of the negative documents).

@yjoonjang (Author) commented Mar 18, 2025:

Alright. I think we can now say that using PListMLE with the justified weight ($2^{n-i} - 1$) and the Identity activation function gives the optimal result.

@tomaarsen (Owner):

I agree

@yjoonjang (Author) commented Mar 18, 2025:

Great! I think there are a few minor things left:

  1. Changing the name to PListMLELoss (as we discussed before in Add Position-Aware ListMLELoss #6 (comment)), with the weight introduced in the paper.
  2. Adding a new loss named ListMLELoss which subclasses PListMLELoss with lambda_weight=None. Although this showed poor performance in the experiment above (Add Position-Aware ListMLELoss #6 (comment)), implementing it could still be beneficial.
  3. Updating the example training script to inform users that training data should be sorted in a defined rank order.

Is there anything else?
I will deal with the list above.

@tomaarsen (Owner):

That's right. There are some minor documentation things to finalize afterwards, but I will take care of those.

- Tom Aarsen

@yjoonjang (Author):

> That's right. There are some minor documentation things to finalize afterwards, but I will take care of those.

I actually worked on most of them, so I will just commit them. It would be great if you could check those.

@yjoonjang (Author):

I've worked on the minor changes. It would be great if you could finalize and merge it! @tomaarsen

It should be equivalent, and considerably faster (although it wasn't necessarily a bottleneck)
@tomaarsen (Owner):

@yjoonjang I vectorized the lambda weight computation in 0012d0f, now the function only gets a boolean mask with positions that are not padding. It speeds up this section, although this was not a bottleneck. What do you think? I'm also okay with reverting it, if you think sticking with the 'ranks' and list_size was important.
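
Roughly, the vectorized computation looks like this (a sketch of the idea; the exact code is in 0012d0f):

    import torch

    def position_weights_from_mask(mask: torch.Tensor) -> torch.Tensor:
        # mask: [batch_size, max_docs], True where a real (non-padding) document sits
        num_docs = mask.sum(dim=1, keepdim=True)  # document count per query
        ranks = torch.arange(1, mask.size(1) + 1, device=mask.device).expand_as(mask)
        # Default discount 2^(num_docs - rank) - 1, zeroed out at padded positions
        return (torch.pow(2.0, num_docs - ranks) - 1.0) * mask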

@yjoonjang (Author):

> I vectorized the lambda weight computation in 0012d0f [...] I'm also okay with reverting it, if you think sticking with the 'ranks' and list_size was important.

I think this is a great improvement! Moving from explicit rank positions to a boolean mask makes the code simpler and more efficient. Calculating document counts directly from the mask and ensuring weights are only applied to valid positions seems like a clearer approach to me.

@tomaarsen (Owner) commented Mar 18, 2025:

Okay, I think I'm almost done with all of my testing, but I have one more interesting case:

The issue

My reasoning is this:
The lambda_weight is a tensor of something like this:

tensor([[255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.],
        [255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.],
        [511., 255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.],
        [255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.],
        [ 63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.,  -0.,  -0.],
        [ 63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.,  -0.,  -0.],
        [127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.,  -0.],
        [127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.,  -0.],
        [255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.],
        [255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.],
        [511., 255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.],
        [255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.],
        [255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.],
        [ 63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.,  -0.,  -0.],
        [127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.,  -0.],
        [255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.]],
       device='cuda:0')

This is multiplied with the log_probs and then we apply the reductions (sum and mean).
When we sum per query:

> (log_probs * self.lambda_weight(mask)).sum(dim=1)
tensor([-1103.7828, -1107.8832, -2218.6123, -1104.4536,  -265.9218,  -267.1783,
         -544.8917,  -544.4240, -1110.4033, -1102.6829, -2228.4136, -1101.6121,
        -1105.8870,  -266.0999,  -544.8442, -1101.8113], device='cuda:0',
       grad_fn=<SumBackward1>)

Afterwards, we take the mean. In the lambda_weight case, this means that each query makes a very "slanted" contribution to the overall loss:

  • Only a few docs? Low loss (contribution)
  • Many docs? High loss (contribution)

After all, if you sum the lambda_weight per query, you get wildly different values ranging from like 120 to 1000. I'm pretty sure this means that the training is heavily leaning towards optimizing the queries with a lot of docs as they cause the high losses.

If we don't apply this lambda_weight (i.e. ListMLE), then we can sum per query:

> torch.sum(log_probs, dim=1)
tensor([-15.2067, -15.1940, -15.0684, -15.1953, -15.2408, -15.2317, -15.2240,
        -15.2154, -15.2133, -15.2076, -15.0966, -15.1417, -15.1786, -15.2057,
        -15.2024, -15.1925], device='cuda:0', grad_fn=<SumBackward1>)

And each query has pretty equal contributions, but the position weighting is not taken into consideration.

My proposal

I propose to normalize the lambda_weight such that it has the same sum per query, ideally 1. We can do this with e.g.:

  • Softmax
  • Divide by sum

When there are many docs, softmax starts failing as it only uses the first doc:

>>> torch.tensor([255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.]).softmax(0)
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

But dividing by the sum might work:

>>> weight = torch.tensor([255., 127.,  63.,  31.,  15.,   7.,   3.,   1.,   0.,  -0.])
>>> weight / weight.sum(0)
tensor([0.5080, 0.2530, 0.1255, 0.0618, 0.0299, 0.0139, 0.0060, 0.0020, 0.0000,
        -0.0000])

My models

PListMLELoss, but with .softmax(dim=1) over the lambda_weight

| Metric  | Value            |
|---------|------------------|
| map     | 0.4797 (+0.0897) |
| mrr@10  | 0.5721 (+0.1040) |
| ndcg@10 | 0.5386 (+0.0832) |

PListMLELoss, but with / lambda_weight.sum(dim=1) over the lambda_weight

| Metric  | Value            |
|---------|------------------|
| map     | 0.4669 (+0.0768) |
| mrr@10  | 0.5474 (+0.0794) |
| ndcg@10 | 0.5240 (+0.0686) |

Graphs compared to default:

*(comparison graphs)*

As you can see here, the "better" results from the softmax one are primarily just an outlier; it actually ends worse than the others if it weren't for load_best_model_at_end. I'm quite a fan of the "divide by sum" option. This is with just the 2 from this experiment; you can see the effect on the loss too:
*(loss curves)*

The softmax gets lower loss because it usually only cares about the positive document.


Also, inspecting that lambda_weight compared to the mask:

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False]],
       device='cuda:0')

It seems that the last document (the last True) always has a weight of 0. That might not be intended.

- Tom Aarsen

@tomaarsen (Owner):

I also ran the sum-to-1 script, but for the full https://huggingface.co/datasets/sentence-transformers/msmarco dataset (labeled-list subset).

| Metric  | Value            |
|---------|------------------|
| map     | 0.5511 (+0.1610) |
| mrr@10  | 0.6286 (+0.1606) |
| ndcg@10 | 0.6030 (+0.1476) |

| Metric  | Value            |
|---------|------------------|
| map     | 0.5638 (+0.1738) |
| mrr@10  | 0.6485 (+0.1805) |
| ndcg@10 | 0.6200 (+0.1647) |

I'm tempted to merge in the sum-to-1 approach:

        # Apply position-aware lambda weights if specified. If None, then this loss
        # is just ListMLE.
        if self.lambda_weight is not None:
            lambda_weight = self.lambda_weight(mask)
            # Normalize weights to sum to 1
            lambda_weight_sum = lambda_weight.sum(dim=1, keepdim=True) + self.eps
            lambda_weight = lambda_weight / lambda_weight_sum
            log_probs = log_probs * lambda_weight
- Tom Aarsen

@yjoonjang (Author):

Thank you for sharing your insights throughout the experiment, @tomaarsen.
I would like to first address the 'lambda_weight always being 0 for the last document' problem.

Problem for the last document weight

This indeed is a problem, and I propose to fix the code from
weights = torch.pow(2.0, num_docs_per_query - ranks) - 1.0 to
weights = torch.pow(2.0, num_docs_per_query - ranks + 1) - 1.0.
Adding 1 to the exponent shifts every weight up one step and prevents the last document's lambda_weight from being 0.
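
The effect of the fix, illustrated for a list of 5 documents:

    import torch

    n = 5
    ranks = torch.arange(1, n + 1)
    old = torch.pow(2.0, n - ranks) - 1.0      # tensor([15., 7., 3., 1., 0.]): last doc gets weight 0
    new = torch.pow(2.0, n - ranks + 1) - 1.0  # tensor([31., 15., 7., 3., 1.]): every doc weighted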

Normalization methods

After fixing the issue with the ranks, I did some experiments related to normalization. Since I also thought that the loss was too high, I think your approach is great. Besides your normalization methods (softmax, sum-to-1), I've experimented with other methods too: minmax, log, and temperature (the code is provided below).

code
from __future__ import annotations

import torch
from torch import Tensor, nn

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.util import fullname


class PListMLELambdaWeight(nn.Module):
    """Base class for implementing weighting schemes in Position-Aware ListMLE Loss."""

    def __init__(self, rank_discount_fn=None, normalize_weights=True, normalization_method="sum", temperature=0.5) -> None:
        """
        Initialize a lambda weight for PListMLE loss.

        Args:
            rank_discount_fn: Function that computes a discount for each rank position.
                              If None, uses default discount of 2^(num_docs - rank) - 1.
            normalize_weights: Whether to normalize weights to sum to 1 per query.
                              If True, each query has equal contribution to the loss regardless of document count.
            normalization_method: Method to use for normalization. Options:
                                 - "sum": Normalize to sum to 1 (default)
                                 - "softmax": Use softmax normalization
                                 - "minmax": Min-max normalization to [0,1] range
                                 - "log": Log normalization
                                 - "temperature": Softmax with temperature scaling
            temperature: Temperature parameter for softmax scaling (default: 0.5).
                         Lower values make distribution more peaked at high weights.
                         Only used when normalization_method="temperature".
        """
        super().__init__()
        self.rank_discount_fn = rank_discount_fn
        self.normalize_weights = normalize_weights
        self.normalization_method = normalization_method
        self.temperature = temperature

    def forward(self, mask: Tensor) -> Tensor:
        """
        Calculate position-aware weights for the PListMLE loss.

        Args:
            mask: A boolean mask indicating valid positions [batch_size, num_docs]

        Returns:
            Tensor: Weights for each position [batch_size, num_docs]
        """
        if self.rank_discount_fn is not None:
            weights = self.rank_discount_fn(mask)
        else:
            # Apply default rank discount: 2^(num_docs - rank) - 1
            num_docs_per_query = mask.sum(dim=1, keepdim=True)
            ranks = torch.arange(1, mask.size(1) + 1, device=mask.device).expand_as(mask)
            # weights = torch.pow(2.0, num_docs_per_query - ranks) - 1.0
            weights = torch.pow(2.0, num_docs_per_query - ranks + 1) - 1.0
            weights = weights * mask
        
        # Normalize weights to sum to 1 for each query if requested
        if self.normalize_weights:
            if self.normalization_method == "sum":
                # Normalize to sum to 1 (original implementation)
                weight_sums = weights.sum(dim=1, keepdim=True)
                weight_sums = torch.clamp(weight_sums, min=1e-10)
                weights = weights / weight_sums
                
            elif self.normalization_method == "softmax":
                # Use softmax normalization
                # Fix for numerical stability - convert original weights to logits for softmax
                logits = torch.log(weights + 1e-10) * mask.float()
                weights = torch.softmax(logits, dim=1) * mask.float()
                
            elif self.normalization_method == "minmax":
                # Min-max normalization per query
                min_vals, _ = torch.min(weights + (1 - mask.float()) * 1e10, dim=1, keepdim=True)
                max_vals, _ = torch.max(weights * mask.float(), dim=1, keepdim=True)
                # Avoid division by zero
                denom = torch.clamp(max_vals - min_vals, min=1e-10)
                # Normalize to [0,1] range
                weights = ((weights - min_vals) / denom) * mask.float()
                # Ensure sum to 1
                weight_sums = weights.sum(dim=1, keepdim=True)
                weight_sums = torch.clamp(weight_sums, min=1e-10)
                weights = weights / weight_sums
                
            elif self.normalization_method == "log":
                # Log normalization
                # Add small constant for log stability
                log_weights = torch.log1p(weights) * mask.float()
                weight_sums = log_weights.sum(dim=1, keepdim=True)
                weight_sums = torch.clamp(weight_sums, min=1e-10)
                weights = log_weights / weight_sums
                
            elif self.normalization_method == "temperature":
                # Softmax with temperature scaling
                logits = torch.log(weights + 1e-10) * mask.float()
                weights = torch.softmax(logits / self.temperature, dim=1) * mask.float()
        
        return weights
        
    def get_config_dict(self) -> dict[str, float | int | str | bool]:
        """
        Get configuration parameters for this lambda weight.

        Returns:
            Dictionary containing the configuration parameters
        """
        return {
            "rank_discount_fn": None if self.rank_discount_fn is None else fullname(self.rank_discount_fn),
            "normalize_weights": self.normalize_weights,
            "normalization_method": self.normalization_method,
            "temperature": self.temperature,
        }


class PListMLELoss(nn.Module):
    def __init__(
        self,
        model: CrossEncoder,
        lambda_weight: PListMLELambdaWeight | None = PListMLELambdaWeight(normalize_weights=True),
        activation_fct: nn.Module | None = nn.Identity(),
        mini_batch_size: int | None = None,
        respect_input_order: bool = True,
    ) -> None:
        """
        PListMLE loss for learning to rank with position-aware weighting. This loss function implements
        the ListMLE ranking algorithm which uses a list-wise approach based on maximum likelihood
        estimation of permutations. It maximizes the likelihood of the permutation induced by the
        ground truth labels with position-aware weighting.

        This loss is also known as Position-Aware ListMLE or p-ListMLE.

        .. note::

            The number of documents per query can vary between samples with the ``PListMLELoss``.

        Args:
            model (CrossEncoder): CrossEncoder model to be trained
            lambda_weight (PListMLELambdaWeight, optional): Weighting scheme to use. When specified,
                implements Position-Aware ListMLE which applies different weights to different rank
                positions. Default is PListMLELambdaWeight with normalize_weights=True, which ensures 
                each query contributes equally to the loss regardless of document count.
                Several normalization methods are available through the PListMLELambdaWeight's
                normalization_method parameter:
                
                - "sum": Standard normalization to sum to 1 (default)
                - "softmax": Softmax normalization for smoother distribution
                - "minmax": Min-max normalization to [0,1] range
                - "log": Log normalization for compressing large differences
                - "temperature": Softmax with temperature scaling (t=0.5)
                
                Setting normalize_weights=False will make queries with more documents contribute more 
                to the loss, which may bias training towards these queries.
            activation_fct (:class:`~torch.nn.Module`): Activation function applied to the logits before computing the
                loss. Defaults to :class:`~torch.nn.Identity`.
            mini_batch_size (int, optional): Number of samples to process in each forward pass. This has a significant
                impact on the memory consumption and speed of the training process. Three cases are possible:

                - If ``mini_batch_size`` is None, the ``mini_batch_size`` is set to the batch size.
                - If ``mini_batch_size`` is greater than 0, the batch is split into mini-batches of size ``mini_batch_size``.
                - If ``mini_batch_size`` is <= 0, the entire batch is processed at once.

                Defaults to None.
            respect_input_order (bool): Whether to respect the original input order of documents.
                If True, assumes the input documents are already ordered by relevance (most relevant first).
                If False, sorts documents by label values. Defaults to True.

        References:
            - Position-Aware ListMLE: A Sequential Learning Process for Ranking: https://auai.org/uai2014/proceedings/individuals/164.pdf
            - `Cross Encoder > Training Examples > MS MARCO <../../../examples/cross_encoder/training/ms_marco/README.html>`_

        Requirements:
            1. Query with multiple documents (listwise approach)
            2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.
            3. Documents must be sorted in a defined rank order.

        Inputs:
            +----------------------------------------+--------------------------------+-------------------------------+
            | Texts                                  | Labels                         | Number of Model Output Labels |
            +========================================+================================+===============================+
            | (query, [doc1, doc2, ..., docN])       | [score1, score2, ..., scoreN]  | 1                             |
            +----------------------------------------+--------------------------------+-------------------------------+

        Recommendations:
            - Use :class:`~sentence_transformers.util.mine_hard_negatives` with ``output_format="labeled-list"``
              to convert question-answer pairs to the required input format with hard negatives.

        Relations:
            - The :class:`~sentence_transformers.cross_encoder.losses.PListMLELoss` is an extension of the
              :class:`~sentence_transformers.cross_encoder.losses.ListMLELoss` and allows for positional weighting
              of the loss. :class:`~sentence_transformers.cross_encoder.losses.PListMLELoss` generally outperforms
              :class:`~sentence_transformers.cross_encoder.losses.ListMLELoss` and is recommended over it.
            - :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss` takes the same inputs, and generally
              outperforms this loss.

        Example:
            ::

                from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
                from datasets import Dataset

                model = CrossEncoder("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "query": ["What are pandas?", "What is the capital of France?"],
                    "docs": [
                        ["Pandas are a kind of bear.", "Pandas are kind of like fish."],
                        ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
                    ],
                    "labels": [[1, 0], [1, 1, 0]],
                })

                # Either: Position-Aware ListMLE with default weighting
                lambda_weight = losses.PListMLELambdaWeight()
                loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)

                # or: Position-Aware ListMLE with a custom weighting function.
                # Note: rank_discount_fn receives the boolean validity mask here.
                def custom_discount(mask):  # mask: [batch_size, num_docs]
                    ranks = torch.arange(1, mask.size(1) + 1, device=mask.device).expand_as(mask)
                    return (1.0 / torch.log1p(ranks)) * mask
                lambda_weight = losses.PListMLELambdaWeight(rank_discount_fn=custom_discount)
                loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)
                
                # or: Position-Aware ListMLE with weights explicitly not normalized (original behavior)
                lambda_weight = losses.PListMLELambdaWeight(normalize_weights=False)
                loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)
                
                # or: Position-Aware ListMLE with softmax normalization (may focus more on top documents)
                lambda_weight = losses.PListMLELambdaWeight(normalization_method="softmax")
                loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)

                trainer = CrossEncoderTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.lambda_weight = lambda_weight
        self.activation_fct = activation_fct or nn.Identity()
        self.mini_batch_size = mini_batch_size
        self.respect_input_order = respect_input_order
        self.eps = 1e-10

        if self.model.num_labels != 1:
            raise ValueError(
                f"{self.__class__.__name__} supports a model with 1 output label, "
                f"but got a model with {self.model.num_labels} output labels."
            )

    def forward(self, inputs: tuple[list[str], list[list[str]]], labels: list[Tensor]) -> Tensor:
        """
        Compute PListMLE loss for a batch of queries and their documents.

        Args:
            inputs: A (queries, documents_list) pair
            labels: Ground truth relevance scores, shape (batch_size, num_documents)

        Returns:
            Tensor: Mean PListMLE loss over the batch
        """
        if isinstance(labels, Tensor):
            raise ValueError(
                "PListMLELoss expects a list of labels for each sample, but got a single value for each sample."
            )

        if len(inputs) != 2:
            raise ValueError(
                f"PListMLELoss expects two inputs (queries, documents_list), but got {len(inputs)} inputs."
            )

        queries, docs_list = inputs
        docs_per_query = [len(docs) for docs in docs_list]
        max_docs = max(docs_per_query)
        batch_size = len(queries)

        if docs_per_query != [len(label_list) for label_list in labels]:
            raise ValueError(
                f"Number of documents per query in inputs ({docs_per_query}) does not match "
                f"number of labels per query ({[len(label_list) for label_list in labels]})."
            )

        pairs = [(query, document) for query, docs in zip(queries, docs_list) for document in docs]

        if not pairs:
            # Handle edge case where there are no documents
            return torch.tensor(0.0, device=self.model.device, requires_grad=True)

        mini_batch_size = self.mini_batch_size or batch_size
        if mini_batch_size <= 0:
            mini_batch_size = len(pairs)

        logits_list = []
        for i in range(0, len(pairs), mini_batch_size):
            mini_batch_pairs = pairs[i : i + mini_batch_size]

            tokens = self.model.tokenizer(
                mini_batch_pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            tokens = tokens.to(self.model.device)

            logits = self.model(**tokens)[0].view(-1)
            logits_list.append(logits)

        logits = torch.cat(logits_list, dim=0)
        logits = self.activation_fct(logits)

        # Create output tensor filled with a very small value for padded logits
        logits_matrix = torch.full((batch_size, max_docs), 1e-16, device=self.model.device)

        # Place logits in the desired positions in the logit matrix
        doc_indices = torch.cat([torch.arange(len(docs)) for docs in docs_list], dim=0)
        batch_indices = torch.repeat_interleave(torch.arange(batch_size), torch.tensor(docs_per_query))
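        # e.g. with docs_per_query = [2, 3]: batch_indices = [0, 0, 1, 1, 1] and doc_indices = [0, 1, 0, 1, 2],
        # so each logit is scattered to its (query row, document column) position in logits_matrix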
        logits_matrix[batch_indices, doc_indices] = logits

        # Create a mask for valid entries
        mask = torch.zeros_like(logits_matrix, dtype=torch.bool)
        mask[batch_indices, doc_indices] = True

        # Convert labels to tensor matrix
        labels_matrix = torch.full_like(logits_matrix, -float("inf"))
        labels_matrix[batch_indices, doc_indices] = torch.cat(labels, dim=0).float()

        if not torch.any(mask):
            return torch.tensor(0.0, device=self.model.device, requires_grad=True)

        if not self.respect_input_order:
            # Sort by labels in descending order if not respecting input order.
            sorted_labels, indices = labels_matrix.sort(descending=True, dim=1)
            sorted_logits = torch.gather(logits_matrix, 1, indices)
        else:
            # Use the original input order, assuming it's already ordered by relevance
            sorted_logits = logits_matrix

        # Compute log-likelihood using Plackett-Luce model
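        # For the item at rank i, the Plackett-Luce probability of being picked next is
        # exp(s_i) / sum_{j >= i} exp(s_j). The flip-cumsum-flip below computes the suffix sum
        # sum_{j >= i} exp(s_j) for every position in one vectorized pass; eps guards the log.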
        scores = sorted_logits.exp()
        cumsum_scores = torch.flip(torch.cumsum(torch.flip(scores, [1]), 1), [1])
        log_probs = sorted_logits - torch.log(cumsum_scores + self.eps)

        if self.lambda_weight is not None:
            log_probs = log_probs * self.lambda_weight(mask)
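            # The default PListMLELambdaWeight follows the p-ListMLE paper's 2^(n - r) - 1 rank
            # discount, so mistakes near the top of the ranking contribute more to the loss.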

        # Sum the log probabilities for each list and mask padded entries
        log_probs[~mask] = 0.0
        per_query_losses = -torch.sum(log_probs, dim=1)

        if not torch.any(per_query_losses):
            return torch.tensor(0.0, device=self.model.device, requires_grad=True)

        # Average loss over all lists
        return torch.mean(per_query_losses)

    def get_config_dict(self) -> dict[str, float | int | str | None]:
        """
        Get configuration parameters for this loss function.

        Returns:
            Dictionary containing the configuration parameters
        """
        return {
            "lambda_weight": None if self.lambda_weight is None else fullname(self.lambda_weight),
            "activation_fct": fullname(self.activation_fct),
            "mini_batch_size": self.mini_batch_size,
            "respect_input_order": self.respect_input_order,
        }

    @property
    def citation(self) -> str:
        return """
@inproceedings{lan2014position,
  title={Position-Aware ListMLE: A Sequential Learning Process for Ranking.},
  author={Lan, Yanyan and Zhu, Yadong and Guo, Jiafeng and Niu, Shuzi and Cheng, Xueqi},
  booktitle={UAI},
  volume={14},
  pages={449--458},
  year={2014}
}
"""

Results

(Note that my results can differ from yours, since I fixed the 'lambda_weight being always 0 for the last document' problem mentioned above. Each table below corresponds to one of the normalization variants ranked at the end.)

| Metric  | Value            |
|---------|------------------|
| map     | 0.4334 (+0.0433) |
| mrr@10  | 0.5103 (+0.0423) |
| ndcg@10 | 0.4818 (+0.0264) |

| Metric  | Value            |
|---------|------------------|
| map     | 0.4894 (+0.0994) |
| mrr@10  | 0.5815 (+0.1135) |
| ndcg@10 | 0.5420 (+0.0866) |

| Metric  | Value            |
|---------|------------------|
| map     | 0.4740 (+0.0839) |
| mrr@10  | 0.5602 (+0.0921) |
| ndcg@10 | 0.5174 (+0.0621) |

| Metric  | Value            |
|---------|------------------|
| map     | 0.4294 (+0.0393) |
| mrr@10  | 0.4881 (+0.0201) |
| ndcg@10 | 0.4658 (+0.0104) |

| Metric  | Value            |
|---------|------------------|
| map     | 0.4745 (+0.0844) |
| mrr@10  | 0.5479 (+0.0799) |
| ndcg@10 | 0.5206 (+0.0653) |

| Metric  | Value            |
|---------|------------------|
| map     | 0.4742 (+0.0841) |
| mrr@10  | 0.5569 (+0.0889) |
| ndcg@10 | 0.5254 (+0.0701) |

Evaluation graphs across all methods:

As you can see in the results, the ordering appears to be softmax > temperature(2) > temperature(0.5) > minmax > sum-to-1 > log in my experiment.
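Since the temperature variants aren't shown in code anywhere above, here is a minimal sketch of what they presumably look like (an assumption on my part: the log-weights are divided by T before the softmax, so T > 1 flattens the weight distribution over ranks and T < 1 sharpens it; `temperature_softmax_weights` is a hypothetical name):

    import torch

    def temperature_softmax_weights(weights: torch.Tensor, mask: torch.Tensor, T: float) -> torch.Tensor:
        # Hypothetical reconstruction of the "temperature(T)" runs: scale log-weights by 1/T before softmax
        logits = torch.log(weights + 1e-10) / T
        logits = logits.masked_fill(~mask, float("-inf"))  # padded positions receive zero weight
        return torch.softmax(logits, dim=1)

    # e.g. temperature_softmax_weights(torch.tensor([[511., 255., 127.]]), torch.ones(1, 3, dtype=torch.bool), T=2.0)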

When comparing softmax and sum-to-1, you mentioned: [screenshot of an earlier comment]
The results show that softmax performs robustly better.

I would actually like to explore the softmax normalization a bit further.
What are your thoughts?

@tomaarsen (Owner) commented Mar 19, 2025

That looks quite promising - I didn't do log before the softmax, I just applied softmax, which meant that the weights were mostly just [1, 0, 0, 0, ...]. Oops. I'll also do a test with the log-softmax!

0.54 is definitely higher than we've seen for PListMLE - we've not even seen 0.53... before, whereas your softmax run reaches that in ~5 different evals, so it doesn't seem to be an outlier.

  • Tom Aarsen

@tomaarsen (Owner)

Bad news (perhaps):
weight / (weight.sum(dim=1) + self.eps) and weight.log().softmax(dim=1) are identical:

            lambda_weight = self.lambda_weight(mask)
            # Normalize weights to sum to 1
            lambda_weight_sum_1 = lambda_weight.sum(dim=1, keepdim=True) + self.eps
            lambda_weight_1 = lambda_weight / lambda_weight_sum_1

            lambda_weight_2 = lambda_weight.log().softmax(dim=1)
(Pdb) lambda_weight_1[-1]
tensor([0.5044, 0.2517, 0.1254, 0.0622, 0.0306, 0.0148, 0.0069, 0.0030, 0.0010,
        0.0000], device='cuda:0')
(Pdb) lambda_weight_2[-1]
tensor([0.5044, 0.2517, 0.1254, 0.0622, 0.0306, 0.0148, 0.0069, 0.0030, 0.0010,
        0.0000], device='cuda:0')

So the sum-to-1 and log -> softmax should perform identically.
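A quick way to check this (minimal sketch; the identity is softmax(log w)_i = exp(log w_i) / sum_j exp(log w_j) = w_i / sum_j w_j, and the weights below are the default 2^(n - r) - 1 values for n = 10 that match the output above):

    import torch

    weights = torch.tensor([[511., 255., 127., 63., 31., 15., 7., 3., 1., 0.]])  # 2^(n - r) - 1 for n = 10

    sum_to_1 = weights / (weights.sum(dim=1, keepdim=True) + 1e-10)
    log_softmax = weights.log().softmax(dim=1)  # log(0) = -inf contributes zero weight

    print(torch.allclose(sum_to_1, log_softmax, atol=1e-6))  # True
    print(weights.softmax(dim=1)[0, :3])  # without the log: ~[1, 0, 0, ...], the collapse mentioned above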

@yjoonjang (Author)

Oh you're right. Mine's a little different though because it multiplies the mask twice:

            elif self.normalization_method == "softmax":
                # Use softmax normalization
                # Fix for numerical stability - convert original weights to logits for softmax
                logits = torch.log(weights + 1e-10) * mask.float()
                weights = torch.softmax(logits, dim=1) * mask.float()
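                # Note (reading the code): masked positions get logit log(1e-10) * 0 = 0, so exp(0) = 1
                # still enters the softmax denominator before the second mask zeroes those entries;
                # the surviving weights therefore sum to slightly less than 1, unlike plain sum-to-1.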

But I don't think this is a big deal. We could merge in the sum-to-1 approach you mentioned, together with the fix for the 'last document weight being 0' problem.

@tomaarsen (Owner)

I'm afraid it seems like the CrossEncoder initialization creates random weights for the classifier head before the seed is set in the Trainer. This means the runs are not reproducible, so the differences in results are likely either biased or caused entirely by the random initialization :(
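(A minimal workaround sketch, assuming seeding the RNGs before model construction is enough to make the head initialization reproducible:)

    from transformers import set_seed
    from sentence_transformers.cross_encoder import CrossEncoder

    set_seed(42)  # seed the RNGs before the classifier head is randomly initialized
    model = CrossEncoder("microsoft/mpnet-base")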

I'll merge the sum-to-1 approach, it does seem to do well.

@tomaarsen (Owner)

I'm going to merge this now, I quite like how it looks. Great work on this @yjoonjang, and I really appreciate the experimentation as well.

@milistu I know you're also interested in having a look: you can still do that with the merged version and give feedback if you want.

  • Tom Aarsen

@tomaarsen tomaarsen merged commit 8d0b2d3 into tomaarsen:feat/cross_encoder_trainer Mar 19, 2025
@yjoonjang (Author)

Nice work and thank you, @tomaarsen !!

tomaarsen pushed a commit that referenced this pull request May 13, 2025