Implement Khatri-Rao operator #7781
Pinging @jli05 and @JeanKossaifi
@cswiercz that's a great addition and will be very useful! I was meaning to do it but didn't have time. I am not sure about the `row_wise` flag. The Khatri-Rao product is a column-wise Kronecker product; the row-wise version would also be described as a Kronecker product, not Khatri-Rao. Would it be clearer to have `khatri_rao` and maybe `row_wise_kronecker` (a long name, but unambiguous)?
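To make the distinction concrete: in the column-wise Khatri-Rao product, column `r` of the output is the Kronecker product of column `r` of each input. The sketch below assumes a hypothetical row-major `Matrix` type and a two-input `khatri_rao` function; it only illustrates the definition and is not the PR's MXNet kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal row-major dense matrix, used only for illustration
// (hypothetical; not an MXNet or mshadow type).
struct Matrix {
  std::size_t rows = 0, cols = 0;
  std::vector<double> data;  // element (i, j) lives at data[i * cols + j]

  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double& operator()(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  double operator()(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Column-wise Khatri-Rao product. For A (I x R) and B (J x R) the result is
// (I*J) x R, and column r is kron(column r of A, column r of B), so
// out(i * J + j, r) = A(i, r) * B(j, r).
Matrix khatri_rao(const Matrix& a, const Matrix& b) {
  assert(a.cols == b.cols && "khatri_rao: inputs must have the same number of columns");
  Matrix out(a.rows * b.rows, a.cols);
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < b.rows; ++j)
      for (std::size_t r = 0; r < a.cols; ++r)
        out(i * b.rows + j, r) = a(i, r) * b(j, r);
  return out;
}
```

For example, with `A = [[1, 2], [3, 4]]` and `B = [[5, 6], [7, 8], [9, 10]]` the result is 6 x 2, and its first column is `[5, 7, 9, 15, 21, 27]`, i.e. `kron([1, 3], [5, 7, 9])`.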
Yeah, I wasn't so sure about the flag. Having a separate operator sounds like a better idea. Plus it would clean up the code a bit. For consistency maybe:
Thoughts from anyone who is interested?
I'm in favor of exposing only the column-wise Khatri-Rao product as the Python op, and reserving the row-wise one for framework developers (C++).
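The row-wise variant (also known as the face-splitting product) pairs rows instead of columns, and the two are related by transposition: `khatri_rao(A, B) = transpose(row_wise_kronecker(A^T, B^T))`. The sketch below, using hypothetical `Matrix`, `transpose`, `row_wise_kronecker`, and `khatri_rao_via_rows` helpers, shows how a column-wise Python-facing op could sit on top of an internal row-wise kernel; whether MXNet's C++ code composes them exactly this way is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal row-major dense matrix (hypothetical, for illustration only).
struct Matrix {
  std::size_t rows = 0, cols = 0;
  std::vector<double> data;  // element (i, j) lives at data[i * cols + j]

  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double& operator()(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  double operator()(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Returns a new matrix with rows and columns swapped.
Matrix transpose(const Matrix& m) {
  Matrix t(m.cols, m.rows);
  for (std::size_t i = 0; i < m.rows; ++i)
    for (std::size_t j = 0; j < m.cols; ++j)
      t(j, i) = m(i, j);
  return t;
}

// Row-wise Kronecker (face-splitting) product: for A (R x I) and B (R x J)
// the result is R x (I*J) and row r is kron(row r of A, row r of B).
// In row-major storage each output row is written contiguously.
Matrix row_wise_kronecker(const Matrix& a, const Matrix& b) {
  assert(a.rows == b.rows && "row_wise_kronecker: inputs must have the same number of rows");
  Matrix out(a.rows, a.cols * b.cols);
  for (std::size_t r = 0; r < a.rows; ++r)
    for (std::size_t i = 0; i < a.cols; ++i)
      for (std::size_t j = 0; j < b.cols; ++j)
        out(r, i * b.cols + j) = a(r, i) * b(r, j);
  return out;
}

// Column-wise Khatri-Rao product expressed through the row-wise kernel.
Matrix khatri_rao_via_rows(const Matrix& a, const Matrix& b) {
  return transpose(row_wise_kronecker(transpose(a), transpose(b)));
}
```

The trade-off in this sketch is two extra transposes in exchange for a kernel whose inner loop writes one contiguous output row at a time.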
Fair enough. It would certainly simplify the code considerably. I made the appropriate changes.
@mli @piiswrong - Two MXNet contributors who know this function have approved this PR. Is there anything I'm missing?
@mli @piiswrong - If it makes any difference, this code is going into
Sorry for missing this. Next time, please ping me by email if a PR gets ignored for too long.
@@ -111,7 +111,11 @@ inline char loup(char uplo, bool invert) { return invert ? (uplo == 'U' ? 'L' :
* \param lda leading dimension of a
*/
template <typename xpu, typename DType>
inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda);
What was the previous definition for flip?
Only the `DType=float` and `DType=double` versions were implemented in the subsequent lines. I replaced the function prototype on this line with a generalized version templated on `DType`.

This change was necessary because of the use of `MSHADOW_TYPE_SWITCH`, I believe. Without this templating I was getting compile errors when testing with types other than `float` and `double`.

I think the `float`- and `double`-specific versions can actually be removed...
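The point about `MSHADOW_TYPE_SWITCH` can be illustrated outside MXNet. A runtime type switch compiles the same body once per supported element type, so every templated helper called inside it needs a definition for each of those types; a fully templated `flip` provides one. The sketch below assumes `flip` copies a row-major `m x n` matrix into its transpose (consistent with its `lda`/`ldb` leading-dimension parameters). `cpu`, `TypeFlag`, `type_switch`, and `flip_as` are hypothetical stand-ins, not mshadow's API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical device tag standing in for mshadow's cpu/gpu types.
struct cpu {};

// Fully templated helper with the same signature as the prototype above:
// writes the transpose of the m x n row-major matrix `a` (leading
// dimension lda) into `b` (leading dimension ldb).
template <typename xpu, typename DType>
inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      b[j * ldb + i] = a[i * lda + j];
}

enum class TypeFlag { kFloat32, kFloat64, kInt32 };

// Hypothetical, much-simplified runtime type switch: the generic body is
// instantiated for every listed type, so anything it calls must compile
// and link for float, double and int32_t alike.
template <typename Body>
void type_switch(TypeFlag flag, Body&& body) {
  switch (flag) {
    case TypeFlag::kFloat32: body(float{}); break;
    case TypeFlag::kFloat64: body(double{}); break;
    case TypeFlag::kInt32:   body(std::int32_t{}); break;
  }
}

// Converts `src` (m x n, row-major) to the dispatched element type,
// transposes it with flip, and returns the result converted back to double.
std::vector<double> flip_as(TypeFlag flag, int m, int n, const std::vector<double>& src) {
  assert(src.size() == static_cast<std::size_t>(m) * static_cast<std::size_t>(n));
  std::vector<double> result(src.size());
  type_switch(flag, [&](auto tag) {
    using DType = decltype(tag);
    std::vector<DType> a(src.begin(), src.end());
    std::vector<DType> b(a.size());
    // a is m x n (lda = n); b receives the n x m transpose (ldb = m).
    flip<cpu, DType>(m, n, b.data(), m, a.data(), n);
    result.assign(b.begin(), b.end());
  });
  return result;
}
```

If `flip` were only declared generically and defined for `float` and `double`, the `int32_t` branch of the switch would have no definition to resolve to, which is the kind of build failure described above.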
Yes, they don't make much sense. Please remove them.
Done. See 66afdc3 below.
No worries. I know you're busy with more pressing, non-contrib things.
@cswiercz Could you rebase onto the current master?
Create an operator for the Khatri-Rao function implemented in apache#6567. The operator accepts a variable number of input matrix arguments as well as a flag indicating whether the Khatri-Rao product should be computed row-wise or column-wise. This option is provided because the row-wise operation is more memory efficient.
Remove row-wise flag from parameter/keyword set. That is, only compute the column-wise Khatri-Rao product.
Remove unnecessary type-specific `flip` code. The fully-templatized version handles the `float` and `double` instances.
Updated from 66afdc3 to 4fb5114.
@piiswrong Rebased. All tests pass.
DMLC_REGISTER_PARAMETER(KhatriRaoParam);

NNVM_REGISTER_OP(khatri_rao)
This op is supposed to be registered under the `contrib` namespace.
* Implement Khatri-Rao operator

  Create an operator for the Khatri-Rao function implemented in apache#6567. The operator accepts a variable number of input matrix arguments as well as a flag indicating whether the Khatri-Rao product should be computed row-wise or column-wise. This option is provided because the row-wise operation is more memory efficient.

* Implement Khatri-Rao operator

  Remove row-wise flag from parameter/keyword set. That is, only compute the column-wise Khatri-Rao product.

* Implement Khatri-Rao operator

  Remove unnecessary type-specific `flip` code. The fully-templatized version handles the `float` and `double` instances.
Create an operator for the Khatri-Rao function implemented in #6567. The operator accepts a variable number of input matrix arguments as well as a flag indicating whether the Khatri-Rao product should be computed row-wise or column-wise. This option is provided because the row-wise operation is more memory efficient.
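Because the operator takes a variable number of input matrices, one simple way to reduce the general case to a two-input kernel is a left fold, `((A1 ⊙ A2) ⊙ A3) ⊙ ...`, which is valid because the Khatri-Rao product is associative. The sketch below uses hypothetical `Matrix`, `khatri_rao`, and `khatri_rao_many` names and is not necessarily how the PR's kernel iterates over its inputs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal row-major dense matrix (hypothetical, for illustration only).
struct Matrix {
  std::size_t rows = 0, cols = 0;
  std::vector<double> data;  // element (i, j) lives at data[i * cols + j]

  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double& operator()(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  double operator()(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Two-input column-wise Khatri-Rao product: out(i * J + j, r) = A(i, r) * B(j, r).
Matrix khatri_rao(const Matrix& a, const Matrix& b) {
  assert(a.cols == b.cols && "khatri_rao: inputs must have the same number of columns");
  Matrix out(a.rows * b.rows, a.cols);
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < b.rows; ++j)
      for (std::size_t r = 0; r < a.cols; ++r)
        out(i * b.rows + j, r) = a(i, r) * b(j, r);
  return out;
}

// Khatri-Rao product of any number of matrices as a left fold over the
// two-input kernel. Associativity means the grouping changes only the
// sizes of the intermediates, not the final result.
Matrix khatri_rao_many(const std::vector<Matrix>& mats) {
  assert(!mats.empty() && "khatri_rao_many: need at least one input");
  Matrix out = mats.front();
  for (std::size_t k = 1; k < mats.size(); ++k) out = khatri_rao(out, mats[k]);
  return out;
}
```

For inputs of shapes `I_1 x R, ..., I_K x R`, the result has `I_1 * I_2 * ... * I_K` rows and `R` columns.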
A Note to Reviewers
This is my first time putting MXNet code up for review, so please tear this code apart. I need to learn.
Additionally, I've knowingly omitted the backward (gradient) computation.