davidkoski
Collaborator

This is a workaround for issues seen in the latest tag, I think ultimately caused by:

This uses statically defined streams when switching the device between .cpu and .gpu. It also adds a task-scoped override of the default Device, along the lines of what we did recently for Stream and error handling.

@davidkoski davidkoski requested a review from awni May 28, 2025 22:17
@@ -31,14 +31,65 @@ class StreamTests: XCTestCase {
func testUsingDevice() {
let defaultDevice = Device.defaultDevice()

- using(device: .cpu) {
+ Device.withDefaultDevice(.cpu) {
Collaborator Author

Switch to the newer API (the older one is deprecated and is just a wrapper around this).

XCTAssertTrue(StreamOrDevice.default.description.contains("gpu"))
}
XCTAssertTrue(StreamOrDevice.default.description.contains("gpu"))
}

func testSetUnsetDefaultDevice() {
Collaborator Author

This replicates the failure from #237 -- it works with this change.

}

func disabledTestCreateStream() {
// see https://github.com/ml-explore/mlx/issues/2118
Collaborator Author

Tests to enable once ml-explore/mlx#2118 is fixed and included in a Swift build.

@@ -35,7 +35,7 @@ public struct StreamOrDevice: Sendable, CustomStringConvertible, Equatable {
/// This will be ``Device/gpu`` unless ``Device/setDefault(device:)``
/// sets it otherwise.
public static var `default`: StreamOrDevice {
- StreamOrDevice(Stream.defaultStream)
+ StreamOrDevice(Stream.defaultStream ?? Device.defaultStream())
Collaborator Author

Use the task-local stream if one is set; otherwise fall back to the current default stream attached to the device.
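The fallback described here can be sketched in isolation. This is only an illustration of the `override ?? device default` shape, not the mlx-swift implementation: `SketchStream`, `SketchDevice`, and `SketchContext` are hypothetical stand-ins, and the assumption that an unset task-local stream is represented as `nil` is inferred from the `??` in the diff.

```swift
// Hypothetical stand-ins; these are not the mlx-swift types.
struct SketchStream: Sendable, Equatable {
    let name: String
}

final class SketchDevice: Sendable {
    // Each device owns its default stream instead of sharing a global one.
    let defaultStream: SketchStream

    init(name: String) {
        self.defaultStream = SketchStream(name: "\(name)-default")
    }

    static let gpu = SketchDevice(name: "gpu")
    static let cpu = SketchDevice(name: "cpu")
}

enum SketchContext {
    // nil means no stream override is active in the current task (assumption).
    @TaskLocal static var stream: SketchStream?

    // Task-local default device; .gpu is assumed as the starting value.
    @TaskLocal static var device: SketchDevice = .gpu

    // Same shape as `Stream.defaultStream ?? Device.defaultStream()`.
    static var resolvedStream: SketchStream {
        stream ?? device.defaultStream
    }
}
```

With no override, `SketchContext.resolvedStream` follows the task's default device; binding `SketchContext.$stream` with `withValue` takes precedence for the duration of that scope.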

- public static let gpu = Stream(.gpu)
- public static let cpu = Stream(.cpu)
+ public static let gpu = Stream(mlx_default_gpu_stream_new())
+ public static let cpu = Stream(mlx_default_cpu_stream_new())
Collaborator Author

This prevents circular initialization between Device and Stream.

@@ -28,9 +28,19 @@ public enum DeviceType: String, Hashable, Sendable {
public final class Device: @unchecked Sendable, Equatable {

let ctx: mlx_device
+ let defaultStream: Stream
Collaborator Author

Now the defaultStream is scoped to the Device rather than being a global.

}

- public init() {
+ @available(*, deprecated, message: "please use defaultDevice()")
+ public convenience init() {
Collaborator Author

This still works, but why create a new one when we already have one lying around?

#endif

static public func defaultDevice() -> Device {
@TaskLocal static var _tlDefaultDevice = _resolveGlobalDefaultDevice()
Collaborator Author

A task-local Device -- this lets us manipulate the "global" in a task-safe manner, just like we have done elsewhere.
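As a rough illustration of why a `@TaskLocal` makes the "global" task-safe, here is a minimal standalone sketch. `DemoDevice` and `DemoDefaults` are hypothetical names, and the `.gpu` starting value is an assumption; only `@TaskLocal` and its `withValue(_:operation:)` binding come from the Swift standard library.

```swift
// Hypothetical stand-in for the library's Device type.
enum DemoDevice: String, Sendable {
    case cpu, gpu
}

enum DemoDefaults {
    // Task-local default; .gpu is the assumed starting value for this sketch.
    @TaskLocal static var device: DemoDevice = .gpu

    // Bind a different default only for the duration of `body`.
    // The previous value is restored automatically when `body` returns.
    static func withDefaultDevice<R>(
        _ newDevice: DemoDevice, _ body: () throws -> R
    ) rethrows -> R {
        try $device.withValue(newDevice, operation: body)
    }
}
```

Because the binding lives in the task rather than in shared mutable state, two tasks can each run under a different default without racing, and nested scopes unwind in order.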

/// - ``StreamOrDevice/default``
- static public func setDefault(device: Device) {
+ @available(*, deprecated, message: "please use withDefaultDevice()")
+ static public func setDefault(device: Device?) {
Collaborator Author

This still works but callers should use the task-scoped variant. Kept for backward compatibility.
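One way a deprecated global setter can coexist with a task-scoped variant is sketched below. This is not the code in this PR: `ToyDevice`, `ToyGlobalBox`, and `ToyDeviceDefaults` are hypothetical, and two behaviors are assumptions made for the sketch, namely that passing `nil` clears the global override (suggested by the optional parameter and the set/unset test name) and that a task-scoped value wins over the global one.

```swift
import Foundation

// Hypothetical stand-in for the library's Device type.
enum ToyDevice: Sendable {
    case cpu, gpu
}

// Lock-protected storage for the legacy process-wide override.
final class ToyGlobalBox: @unchecked Sendable {
    private let lock = NSLock()
    private var value: ToyDevice?

    func get() -> ToyDevice? {
        lock.lock(); defer { lock.unlock() }
        return value
    }

    func set(_ newValue: ToyDevice?) {
        lock.lock(); defer { lock.unlock() }
        value = newValue
    }
}

enum ToyDeviceDefaults {
    static let global = ToyGlobalBox()

    // nil means no task-scoped override is active.
    @TaskLocal static var scoped: ToyDevice?

    // Legacy API: mutates shared state; nil clears the override (assumption).
    @available(*, deprecated, message: "prefer withDefault(_:_:)")
    static func setDefault(_ device: ToyDevice?) {
        global.set(device)
    }

    // Assumed precedence: task-scoped value, then global override, then .gpu.
    static func current() -> ToyDevice {
        if let device = scoped { return device }
        return global.get() ?? .gpu
    }

    // Preferred API: the override is visible only inside `body`.
    static func withDefault<R>(
        _ device: ToyDevice, _ body: () throws -> R
    ) rethrows -> R {
        try $scoped.withValue(device, operation: body)
    }
}
```

The setter changes what every task sees until it is unset, which is why the scoped form is the safer default for callers.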

Member

@awni awni left a comment

Looks great!!

@davidkoski davidkoski merged commit b94473a into main Jun 2, 2025
1 check passed
@davidkoski davidkoski deleted the device-fix branch June 2, 2025 15:21