partial fix for #237 -- switching Device exhausts many resources #242
@@ -31,14 +31,65 @@ class StreamTests: XCTestCase {
func testUsingDevice() {
let defaultDevice = Device.defaultDevice()

- using(device: .cpu) {
+ Device.withDefaultDevice(.cpu) {
Switch to the newer API (older is deprecated and is just a cover on this).
XCTAssertTrue(StreamOrDevice.default.description.contains("gpu"))
}
XCTAssertTrue(StreamOrDevice.default.description.contains("gpu"))
}

+ func testSetUnsetDefaultDevice() {
This replicates the failure from #237 -- it works with this change.
}

func disabledTestCreateStream() {
// see https://github.com/ml-explore/mlx/issues/2118
Future tests, to be enabled once ml-explore/mlx#2118 is fixed and included in a Swift build.
@@ -35,7 +35,7 @@ public struct StreamOrDevice: Sendable, CustomStringConvertible, Equatable {
/// This will be ``Device/gpu`` unless ``Device/setDefault(device:)``
/// sets it otherwise.
public static var `default`: StreamOrDevice {
- StreamOrDevice(Stream.defaultStream)
+ StreamOrDevice(Stream.defaultStream ?? Device.defaultStream())
Use a task local stream or fall back to the current default stream attached to the device.
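For illustration, the fallback described here can be sketched with stand-in types. `DemoStream`, `DemoDevice`, and `DemoContext` are hypothetical names, not the MLX API; the only idea carried over from the change is that an optional task-local override wins when present and the device's own default stream is used otherwise.

```swift
// Hypothetical stand-ins for illustration only; these are not the MLX types.
struct DemoStream: Sendable, Equatable {
    let name: String
}

struct DemoDevice: Sendable {
    // Assumed shape: each device owns its default stream.
    let defaultStream: DemoStream

    static let gpu = DemoDevice(defaultStream: DemoStream(name: "gpu-default"))
}

enum DemoContext {
    // nil means no task-scoped override is active.
    @TaskLocal static var streamOverride: DemoStream?

    // Prefer the task-local override, otherwise use the device's default stream.
    static var current: DemoStream {
        streamOverride ?? DemoDevice.gpu.defaultStream
    }
}

// Resolve the stream before, during, and after a scoped override.
let streamBefore = DemoContext.current
let streamDuring = DemoContext.$streamOverride.withValue(DemoStream(name: "scoped")) {
    DemoContext.current
}
let streamAfter = DemoContext.current
```

The override is only visible inside the `withValue` closure, so nothing has to be reset by hand afterwards.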
- public static let gpu = Stream(.gpu)
- public static let cpu = Stream(.cpu)
+ public static let gpu = Stream(mlx_default_gpu_stream_new())
+ public static let cpu = Stream(mlx_default_cpu_stream_new())
To prevent circularity in init between Device and Stream
@@ -28,9 +28,19 @@ public enum DeviceType: String, Hashable, Sendable {
public final class Device: @unchecked Sendable, Equatable {

let ctx: mlx_device
+ let defaultStream: Stream
Now the `defaultStream` is scoped to the `Device` rather than being a global.
}

- public init() {
+ @available(*, deprecated, message: "please use defaultDevice()")
+ public convenience init() {
This still works, but why create a new one when we already have one lying around?
#endif

static public func defaultDevice() -> Device {
+ @TaskLocal static var _tlDefaultDevice = _resolveGlobalDefaultDevice()
A task-local `Device`: this lets us manipulate the "global" in a task-safe manner, just like we have done elsewhere.
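As a rough sketch of how a task-scoped default behaves, here is a self-contained example built on Swift's `@TaskLocal`. `SketchDeviceKind`, `SketchDefaults`, and `recordScopedDevices` are hypothetical names, and the `withDefaultDevice` helper below only mimics the call shape used in the tests; it is not the implementation in this PR.

```swift
// Hypothetical stand-ins for illustration only; these are not the MLX types.
enum SketchDeviceKind: String, Sendable {
    case cpu, gpu
}

enum SketchDefaults {
    // The value every task sees unless a scoped override is active.
    @TaskLocal static var current: SketchDeviceKind = .gpu

    // Run `body` with `device` as the default, restoring the previous value on exit.
    static func withDefaultDevice<R>(
        _ device: SketchDeviceKind, _ body: () throws -> R
    ) rethrows -> R {
        try $current.withValue(device, operation: body)
    }
}

// Records the default device seen at each nesting level.
func recordScopedDevices() -> [String] {
    var seen: [String] = []
    seen.append(SketchDefaults.current.rawValue)
    SketchDefaults.withDefaultDevice(.cpu) {
        seen.append(SketchDefaults.current.rawValue)
        SketchDefaults.withDefaultDevice(.gpu) {
            seen.append(SketchDefaults.current.rawValue)
        }
        seen.append(SketchDefaults.current.rawValue)
    }
    seen.append(SketchDefaults.current.rawValue)
    return seen
}

let scopedDevices = recordScopedDevices()
```

Because the override lives in task-local storage, nested scopes unwind on their own and concurrent tasks do not see each other's overrides, which is the property a mutable global cannot offer.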
/// - ``StreamOrDevice/default``
- static public func setDefault(device: Device) {
+ @available(*, deprecated, message: "please use withDefaultDevice()")
+ static public func setDefault(device: Device?) {
This still works but callers should use the task-scoped variant. Kept for backward compatibility.
Looks great!!
This is a workaround for issues seen in the latest tag, I think ultimately caused by:
This uses statically defined streams when switching device between `.cpu` and `.gpu`. Additionally, it adds a task-scoped override of the `defaultDevice` along the lines of what we have done recently for `Stream` and error handling.