[BUG] thread issues with evaluation #2067

@acsweet

Description

Describe the bug
There seem to be some thread data access issues introduced in recent versions of mlx (0.24.0 and up).

To Reproduce
This doesn't produce the exact same error I was experiencing, but it does produce a similar error (I think).

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import threading
import time
import traceback
import faulthandler

faulthandler.enable()

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(10, 32)
        self.dense2 = nn.Linear(32, 16)
        self.dense3 = nn.Linear(16, 1)
        
    def __call__(self, x):
        x = mx.maximum(0, self.dense1(x))
        x = mx.maximum(0, self.dense2(x))
        return self.dense3(x)

def loss_fn(model, x, y):
    pred = model(x)
    return mx.mean((pred - y) ** 2)

def train_and_convert(thread_id):
    try:
        print(f"Thread {thread_id} starting")
        
        model = SimpleModel()
        optimizer = optim.Adam(learning_rate=0.001)

        x = mx.random.normal((128, 10))
        y = mx.random.normal((128, 1))
        
        for i in range(10):
            print(f"Thread {thread_id}, Iteration {i} starting")
            loss, grads = mx.value_and_grad(loss_fn)(model, x, y)
            
            print(f"Thread {thread_id}, Iteration {i} - evaluating gradients")
            for g in grads.values():
                mx.eval(g)
            
            print(f"Thread {thread_id}, Iteration {i} - updating model")
            optimizer.update(model, grads)
            
            print(f"Thread {thread_id}, Iteration {i} - evaluating loss")
            mx.eval(loss)
            
            print(f"Thread {thread_id}, Iteration {i} - converting loss to numpy")
            # Convert loss to numpy
            np_loss = np.array(loss)
            
            print(f"Thread {thread_id}, Iteration {i} - converting gradients to numpy")
            for name, g in list(grads.items())[:2]:
                print(f"Thread {thread_id}, Iteration {i} - converting gradient {name}")
                np_grad = np.array(g)
                print(f"Thread {thread_id}, Iteration {i}, Grad {name} shape: {np_grad.shape}")
            
            print(f"Thread {thread_id}, Iteration {i}, Loss: {np_loss}")
            
            # small delay to increase chances of thread overlap
            time.sleep(0.01)
    except Exception as e:
        print(f"Thread {thread_id} failed with exception: {e}")
        print(traceback.format_exc())

threads = []
for i in range(3): # Try with more threads
    t = threading.Thread(target=train_and_convert, args=(i,))
    t.daemon = False
    threads.append(t)
    t.start()
    # Small delay between thread starts
    time.sleep(0.05)

try:
    for t in threads:
        t.join()
    print("All threads completed successfully")
except KeyboardInterrupt:
    print("Interrupted by user")
except Exception as e:
    print(f"Exception in main thread: {e}")
    print(traceback.format_exc())
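
For what it's worth, the per-gradient eval loop in the repro is just there to force evaluation at a known point; mx.eval accepts whole trees of arrays, so it could likely be collapsed into a single call (a simplification, not a fix):

# evaluate the loss and the entire gradient tree in one call
loss, grads = mx.value_and_grad(loss_fn)(model, x, y)
mx.eval(loss, grads)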

Expected behavior
When I run this with mlx 0.23.2 it executes fine, but with 0.24.0 and up it either gives a segmentation fault or a Python fatal error like:

python(15651,0x171ba3000) malloc: Double free of object 0x14ea0dbb0
python(15651,0x171ba3000) malloc: *** set a breakpoint in malloc_error_break to debug
Fatal Python error: Aborted

I'd expect it to execute with no issues, or to know whether this is expected/unavoidable behavior given the changes from version 0.24.0 onwards.
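
As a heavy-handed workaround while this gets sorted out, serializing all mlx work behind a single process-wide lock should avoid concurrent evaluation entirely. A minimal sketch, assuming the crash comes from two threads hitting the evaluator at once (the lock and helper below are mine, not part of mlx):

import threading

_mlx_lock = threading.Lock()  # one guard for all mlx work in the process

def locked_step(model, optimizer, x, y):
    # Hold the lock across graph construction, eval, and the numpy
    # conversion so no two threads touch the mlx runtime concurrently.
    with _mlx_lock:
        loss, grads = mx.value_and_grad(loss_fn)(model, x, y)
        optimizer.update(model, grads)
        mx.eval(loss, model.parameters(), optimizer.state)
        return np.array(loss)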

Desktop (please complete the following information):

  • OS Version: macOS 15.2
  • MLX Version: 0.23.2 (works); 0.24.0 and up (fails)

Additional context
This is related to the mlx backend for Keras (keras-team/keras#19571). I initially ran across this error with a simple model.fit(); it would occur when the progress bar was updating (which involves a cast to a numpy array to display the loss). That case also executed to completion with 0.23.2, but with 0.24.2 it failed with a segmentation fault or an error like this:

-[AGXG16XFamilyCommandBuffer tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:1091: failed assertion `A command encoder is already encoding to this command buffer'
Fatal Python error: Aborted

While I was trying to replicate that last error I was able to get the error above instead. I can keep trying to reproduce the command-encoder assertion with a simple code block if that's helpful, but hopefully the above is enough to identify what's happening.
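
For reference, the failing pattern in the Keras case boils down to something like this (a hypothetical sketch, not the actual Keras code; train_step stands in for the compiled training step):

# The training step returns a lazy mx.array; the progress bar then forces
# it by casting to numpy, which triggers an implicit eval on whatever
# thread the display update happens to run on.
logs = {"loss": train_step(batch)}             # lazy mx.array
progbar_loss = float(np.array(logs["loss"]))   # implicit eval happens here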
