fix: Cannot Load LightGBM Model When Placed in a Spark Pipeline with Custom Transformers #2357
Conversation
Force-pushed from 3391beb to 1f1bfdb
@mhamilton723, @svotaw, tagging maintainers who have both interacted with this problem in the past.
```python
import os
import shutil
import unittest

from pyspark.sql import SparkSession


class LightGBMSerializationTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder \
            .appName("LightGBMSerializationTests") \
            .getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def setUp(self):
        # Define a temporary directory for saving models/pipelines.
        self.temp_dir = "./tmp/lightgbm_serialization_test"
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def tearDown(self):
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)
```
Our tests manage the Spark and Spark context objects on behalf of the user; I'm not sure you'll need this if you copy the structure of other tests in the repo.
Ah, I see. I reused their components and removed the handling functions.
```python
from pyspark.ml import Transformer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

from synapse.ml.stages import SelectColumns
import synapse.ml.lightgbm as lgbm


class StringArrayToVectorTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    ...
```
Is this a throwaway piece of code just to test serialization of a custom model?
I wasn't exactly sure how to interpret this, but I stripped a fair amount of logic from the test and changed the function name/description to better reflect its purpose. Based on the update, let me know if you think it should be reduced further. The single vs. pipeline serialization test demonstrates that standard single-model Python serialization/deserialization was not affected, while the pipeline test demonstrates that the reported issue is resolved.
/azp run
Azure Pipelines successfully started running 1 pipeline(s).
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff             @@
##           master    #2357      +/-   ##
==========================================
- Coverage   84.66%   84.60%    -0.06%
==========================================
  Files         331       331
  Lines       17179     17179
  Branches     1550      1550
==========================================
- Hits        14544     14535        -9
- Misses       2635      2644        +9
```
…dsmith111/SynapseML into dsmith111/lightgbm-fix-pipeline
/azp run
Azure Pipelines successfully started running 1 pipeline(s).
/azp run
Commenter does not have sufficient privileges for PR 2357 in repo microsoft/SynapseML
/azp run
Azure Pipelines successfully started running 1 pipeline(s).
/azp run
Azure Pipelines successfully started running 1 pipeline(s).
Related Issues/PRs
Directly related issues: #2293, #1701
Issue that appears related but used an older API: #614
What changes are proposed in this pull request?
Problem
When attempting to load a Spark pipeline that includes both a non-Java-backed ("pure Python") transformer and a LightGBM model, an AttributeError was raised, indicating that the module com.microsoft.azure.synapse.ml.lightgbm lacked the LightGBMClassificationModel attribute. The existing java_params_patch was not sufficient to resolve this, as it applied only to the JavaParams._from_java method and did not cover scenarios where DefaultParamsReader.loadParamsInstance was used.
The previous patch for JavaParams._from_java handled the conversion of Java objects to Python objects during deserialization, but did not account for the direct instantiation of classes based on metadata in DefaultParamsReader.loadParamsInstance (as is done during Pipeline/PipelineModel stage loading). As a result, the class references within the ComplexParamsMixin superclass were not correctly resolved, leading to the AttributeError in the issues above.
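The failure mode can be illustrated with a minimal, pure-Python sketch of what a metadata-driven loader does: it resolves a fully qualified class name string via importlib and raises an AttributeError when the named module does not expose the expected attribute. The function name and structure here are illustrative, not SynapseML's or pyspark's actual code:

```python
import importlib


def load_class_from_metadata(class_name: str):
    """Resolve a fully qualified class name from stage metadata,
    in the spirit of DefaultParamsReader.loadParamsInstance."""
    module_name, cls_name = class_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    # getattr raises AttributeError when the module lacks the class,
    # which is the error reported in the linked issues when the JVM
    # class name is not remapped to its Python wrapper first.
    return getattr(module, cls_name)
```

Because this path instantiates the class directly from the metadata string, any correction of the class name has to happen before the import, which is exactly what the JavaParams._from_java patch never got a chance to do.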
Solution
In a similar fashion to the existing patch, this solution adds a loadParamsInstance patch that extends the existing Python class handling condition to correct the SynapseML class name in addition to the existing pyspark class name correction.
This PR ensures that the necessary class name transformations are correctly handled during the loading process, addressing the limitation of the previous patch.
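As a rough sketch of the class-name correction described above (the function name and package prefixes are assumptions for illustration, not the exact SynapseML implementation), the patched loader effectively remaps the JVM package prefix stored in the metadata to its Python wrapper package, alongside the org.apache.spark to pyspark remap that the stock reader already performs:

```python
def to_python_class_name(java_class: str) -> str:
    """Map a JVM class name from pipeline metadata to a Python wrapper
    class name (illustrative prefixes, assumed for this sketch)."""
    synapse_jvm = "com.microsoft.azure.synapse.ml."
    if java_class.startswith(synapse_jvm):
        # e.g. com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassificationModel
        #      becomes synapse.ml.lightgbm.LightGBMClassificationModel
        return "synapse.ml." + java_class[len(synapse_jvm):]
    if java_class.startswith("org.apache.spark."):
        # pyspark's existing convention for its own classes
        return java_class.replace("org.apache.spark", "pyspark", 1)
    return java_class
```

With the corrected name, the subsequent import-and-getattr step resolves the Python wrapper class instead of failing on the JVM package path.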
How is this patch tested?
Does this PR change any dependencies?
Does this PR add a new feature? If so, have you added samples on website?
- Add a sample under the website/docs/documentation folder.
- Make sure you choose the correct class (estimators/transformers) and namespace.
- DocTable points to the correct API link.
- Run yarn run start to make sure the website renders correctly.
- Add <!--pytest-codeblocks:cont--> before each Python code block to enable auto-tests for Python samples.
- Make sure the WebsiteSamplesTests job passes in the pipeline.