
Conversation

Collaborator

@keehyuna keehyuna commented May 20, 2024

Description

Dynamic shapes cannot be used in where ops because of an exception raised by torch.broadcast_shapes().
The proposed fix removes the expand() call for static-shape inputs and only prepends ones so that static- and dynamic-shape inputs have the same rank. I think an explicit broadcast is not required because broadcasting is handled by ISelectLayer (addSelect).
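
A minimal sketch of the idea in plain PyTorch, not the converter code itself: once the ranks match, the select operation broadcasts the remaining size-1 dimensions, so an explicit expand() is unnecessary (the helper name prepend_ones_shape below is illustrative only).

```python
import torch

condition = torch.tensor([True, False, True])  # rank 1, shape (3,)
x = torch.ones(2, 3)                           # rank 2, shape (2, 3)
y = torch.zeros(3)                             # rank 1, shape (3,)

def prepend_ones_shape(t: torch.Tensor, rank: int) -> torch.Tensor:
    # Prepend size-1 dims until the tensor reaches the target rank.
    return t.reshape((1,) * (rank - t.dim()) + tuple(t.shape))

rank = max(condition.dim(), x.dim(), y.dim())
cond_r = prepend_ones_shape(condition, rank)   # shape (1, 3)
y_r = prepend_ones_shape(y, rank)              # shape (1, 3)

# The select itself broadcasts the size-1 dims, so no explicit expand() is needed.
out = torch.where(cond_r, x, y_r)
assert out.shape == (2, 3)
```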

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels May 20, 2024
@github-actions github-actions bot requested a review from apbose May 20, 2024 08:43
@keehyuna keehyuna force-pushed the where_dynamic_shape branch from f72e842 to 00d1d85 Compare May 20, 2024 10:52
@github-actions github-actions bot added the component: tests Issues re: Tests label May 20, 2024
@@ -364,9 +364,13 @@ def example_tensor(
            )

        if isinstance(self.shape, dict):
            return torch.rand(self.shape[optimization_profile_field]).to(
                dtype=self.dtype.to(torch.dtype, use_default=True)
Collaborator Author

If dtype is torch.bool in the input spec, torch.rand() returns random numbers in [0, 1), and casting them to bool yields True for every nonzero value.
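
A quick illustration of the point (a sketch, not this PR's change; the randint-based alternative is just one possible workaround):

```python
import torch

shape = (2, 3)

# torch.rand() samples from [0, 1); every nonzero sample casts to True,
# so the resulting "random" bool tensor is effectively all True.
almost_all_true = torch.rand(shape).to(torch.bool)

# Sampling integers first gives genuinely random booleans instead.
random_bools = torch.randint(0, 2, shape).bool()
```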

@keehyuna keehyuna requested a review from peri044 May 21, 2024 00:53
@keehyuna keehyuna self-assigned this May 21, 2024
@keehyuna keehyuna force-pushed the where_dynamic_shape branch from fc6499e to 5e6f3bd Compare May 24, 2024 06:09
@keehyuna keehyuna marked this pull request as ready for review May 24, 2024 07:14
@keehyuna keehyuna requested review from gs-olive and narendasan May 24, 2024 07:15
@chohk88 chohk88 self-requested a review May 29, 2024 05:33
Collaborator

@peri044 peri044 left a comment


Also, please mark the torch.where converter in aten_ops_converters.py with the supports_dynamic_shapes=True flag. Example: https://github.com/pytorch/TensorRT/blob/release/2.3/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py#L62
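
For context, a hedged sketch of what that flag looks like on a converter registration. The decorator and converter signature follow the linked file, but the import path and the aten overload shown here are assumptions for illustration, not this PR's diff.

```python
import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,  # import path assumed from the repo layout
)

@dynamo_tensorrt_converter(
    torch.ops.aten.where.self,        # overload assumed for illustration
    supports_dynamic_shapes=True,     # the flag requested in this review
)
def aten_ops_where(ctx, target, args, kwargs, name):
    # Converter body elided; see aten_ops_converters.py for the real implementation.
    ...
```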

Collaborator

@chohk88 chohk88 left a comment


I have left some comments. It looks like there aren't any major issues with the functionality.

@keehyuna keehyuna force-pushed the where_dynamic_shape branch from 5e6f3bd to ac4bf90 Compare May 30, 2024 06:12

def get_axes_for_reduce_op(
    dim: Union[int, Sequence[int]],
    has_implicit_batch_dimension: bool = False,
Collaborator Author

The default parameter was used to fold in the partial definition below:
get_axes_for_reduce_op = functools.partial(
    get_axes_for_reduce_op, has_implicit_batch_dimension=False
)
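
In other words, an illustrative before/after (not the repo's exact code): the default value lets the functools.partial wrapper go away.

```python
import functools
from typing import Sequence, Union

# Before: the flag was pinned through a partial.
def _get_axes_for_reduce_op(dim, has_implicit_batch_dimension):
    ...

get_axes_for_reduce_op = functools.partial(
    _get_axes_for_reduce_op, has_implicit_batch_dimension=False
)

# After: the default parameter makes the partial unnecessary; callers that never
# pass the flag behave exactly as before.
def get_axes_for_reduce_op(
    dim: Union[int, Sequence[int]],
    has_implicit_batch_dimension: bool = False,
):
    ...
```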

Collaborator

chohk88 commented Jun 3, 2024

Looks good to me!

Collaborator

@zewenli98 zewenli98 left a comment


@peri044 I think it's reasonable to deprecate the FX converter_utils in the Dynamo converter implementation. In my opinion, when copying these helper functions from the FX utils, it would be better to change the first arg from network: TRTNetwork to ctx: ConversionContext for two reasons: 1) it is consistent with the other helper functions, and 2) it is more convenient to call them by just passing in ctx.
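
A hedged sketch of the suggested signature change (prepend_ones is one of the helpers mentioned in this thread; the import paths, parameter names, and the ctx.net attribute access are my assumptions for illustration, not the final code):

```python
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTNetwork  # paths assumed from the repo layout

# FX-style helper: the raw TensorRT network is threaded through explicitly.
def prepend_ones(network: TRTNetwork, tensor, name: str, num_prepend_ones: int):
    ...

# Dynamo-style helper: callers just pass the ConversionContext, which already
# carries the network (accessed here as ctx.net, an assumption).
def prepend_ones(ctx: ConversionContext, tensor, name: str, num_prepend_ones: int):
    network = ctx.net
    ...
```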


def get_axes_for_reduce_op(
    dim: Union[int, Sequence[int]],
    has_implicit_batch_dimension: bool = False,
Collaborator

@peri044 do we still have has_implicit_batch_dimension? Is it possible to remove the arg in dynamo's converter_utils?

Collaborator

  1. We should remove has_implicit_batch_dimension.
  2. It would be better to change the first arg from network: TRTNetwork to ctx: ConversionContext.

Yes, we should make these changes and avoid using FX code and data structures as much as possible. If we do use them, we should be consistent with the dynamo APIs.

Collaborator Author

Thanks. Removed the has_implicit_batch_dimension flag in get_axes_for_reduce_op().
The first arg of broadcast()/prepend_ones() is now ctx: ConversionContext instead of network: TRTNetwork.

@keehyuna keehyuna force-pushed the where_dynamic_shape branch 2 times, most recently from f980697 to a12754b Compare June 6, 2024 01:25
@keehyuna
Collaborator Author

Moved "chore: Better random bool values for example_tensor" change into PR2878 to combine random input changes to one PR.
436e99b

@keehyuna keehyuna requested a review from peri044 June 10, 2024 03:21
Collaborator

@peri044 peri044 left a comment


Thanks. Added a few more minor changes.

Collaborator

@peri044 peri044 left a comment


LGTM

@peri044 peri044 merged commit ac702b7 into pytorch:main Jun 12, 2024
@keehyuna keehyuna deleted the where_dynamic_shape branch August 19, 2024 03:52
Labels

  • cla signed
  • component: api [Python] Issues re: Python API
  • component: conversion Issues re: Conversion stage
  • component: converters Issues re: Specific op converters
  • component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths
  • component: tests Issues re: Tests