r3v1 (Contributor) commented Aug 8, 2023

Addresses #3134.

fritzo (Member) left a comment

Hi @r3v1, thanks for taking on this task! Could you describe your approach a bit in the PR description? 🙏

It's been almost a year since I looked into this, but I recall one issue: Pyro's visualization reuses inference machinery for dependency tracking, and the goals of these two uses are somewhat at odds. In inference, one wants to ignore deterministic dependencies and track only stochastic latent variables, because only those can be freely varied during inference. By contrast, in visualization one often wants to inspect the model, in many cases including deterministic dependencies or even compute graphs.

Again, it's been a while, but I think where I left off on this task, it seemed challenging to separate these two concerns. We might need to fork some of the provenance tracking and maybe add some conditional flags, so that one could optionally trace deterministic nodes (for visualization) or elide them (for inference). I'm not sure how your PR addresses these issues, but I think it's worth examining before making changes that affect inference. 🙂
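Here is a minimal, self-contained sketch of that trade-off. It assumes a toy setting where every value carries the set of site names it depends on. The names Tracked, ToyTracer, track_deterministic and toy_model are hypothetical and are not Pyro's provenance API.

- With track_deterministic=False, a deterministic site is elided and its provenance passes straight through. This is the inference view.
- With track_deterministic=True, the deterministic site becomes a node in its own right. This is the visualization view.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Tracked:
    """A value tagged with the names of the upstream sites it depends on."""

    value: float
    provenance: frozenset = field(default_factory=frozenset)


class ToyTracer:
    """Hypothetical dependency tracer; not Pyro's TrackProvenance."""

    def __init__(self, track_deterministic: bool):
        self.track_deterministic = track_deterministic
        self.parents = {}  # site name -> sorted list of parent site names

    def sample(self, name, value, *inputs):
        # A stochastic site always becomes a node whose parents are the
        # sites recorded in its inputs' provenance.
        upstream = frozenset().union(*(x.provenance for x in inputs))
        self.parents[name] = sorted(upstream)
        return Tracked(value, frozenset({name}))

    def deterministic(self, name, value, *inputs):
        upstream = frozenset().union(*(x.provenance for x in inputs))
        if not self.track_deterministic:
            # Inference view: elide the site and pass provenance through.
            return Tracked(value, upstream)
        # Visualization view: keep the site as a node and cut provenance there.
        self.parents[name] = sorted(upstream)
        return Tracked(value, frozenset({name}))


def toy_model(tracer):
    a = tracer.sample("a", 1.0)
    b = tracer.sample("b", 2.0)
    s = tracer.deterministic("s", a.value + b.value, a, b)
    tracer.sample("c", s.value * 3.0, s)
    return tracer.parents


print(toy_model(ToyTracer(track_deterministic=False)))
# {'a': [], 'b': [], 'c': ['a', 'b']}
print(toy_model(ToyTracer(track_deterministic=True)))
# {'a': [], 'b': [], 's': ['a', 'b'], 'c': ['s']}
```

The same model yields two different graphs depending on a single flag. This is why sharing one tracker between inference and visualization forces a choice.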

EDIT: Could you also please describe your merge intention? I see you're merging into my old branch, but maybe you should just create a new branch incorporating those old changes? And thanks again for continuing this line of work!

r3v1 (Contributor, Author) commented Aug 10, 2023

Of course, @fritzo. Here is the process I went through.

First, I studied the unit tests to see what the get_model_relations function was expected to return for each model; it returns a dictionary representing a graph. Next, I stepped through the logic you had implemented inside that method with a debugger. That helped me understand how the different parts "exchange messages" (dictionaries), and I noticed that everything happens inside the apply_stack call.
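To show what "a dictionary representing a graph" means here, below is a hand-written dictionary shaped like a subset of what pyro.infer.inspect.get_model_relations returns for a tiny model. In that model, a and b are latent, and c depends on both and is observed.

Treat the exact contents as an assumption. The key names follow that function as I understand it, but the dictionary was written by hand, not produced by Pyro.

The edges helper is hypothetical. It flattens the "sample_sample" entry, which maps each site to the upstream sample sites it depends on, into (parent, child) edges.

```python
# Hand-written, hypothetical example shaped like a subset of the dictionary
# returned by pyro.infer.inspect.get_model_relations; not real Pyro output.
relations = {
    "sample_sample": {"a": [], "b": [], "c": ["a", "b"]},
    "sample_dist": {"a": "Normal", "b": "Normal", "c": "Normal"},
    "observed": ["c"],
}


def edges(relations):
    """Flatten the child -> parents mapping into sorted (parent, child) pairs."""
    return sorted(
        (parent, child)
        for child, parents in relations["sample_sample"].items()
        for parent in parents
    )


print(edges(relations))  # [('a', 'c'), ('b', 'c')]
```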

When _postprocess_message() is called, it checks whether the TrackProvenance object has a method named "_pyro_post_{}".format(msg["type"]). Because the message type was deterministic, that method did not exist yet, so I had to implement it. I remembered seeing the same code structure somewhere in NumPyro, so that's what I applied. In the same way, I implemented _pyro_post_deterministic for TraceMessenger.
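The lookup described above is a naming convention: a handler method is found by building its name from msg["type"]. The toy below shows why a message whose type is "deterministic" silently bypasses a handler that only defines _pyro_post_sample. It is a simplified sketch of the control flow, not Pyro's actual code. ToyMessenger, ToyProvenance and toy_apply_stack are hypothetical names, not Pyro classes.

```python
class ToyMessenger:
    """Toy handler base class mimicking the "_pyro_<type>" / "_pyro_post_<type>"
    naming convention; a simplification, not Pyro's Messenger."""

    def process(self, msg):
        # Look up e.g. "_pyro_sample" for a message of type "sample".
        method = getattr(self, "_pyro_{}".format(msg["type"]), None)
        if method is not None:
            method(msg)

    def postprocess(self, msg):
        # Look up e.g. "_pyro_post_sample"; unknown types are silently skipped.
        method = getattr(self, "_pyro_post_{}".format(msg["type"]), None)
        if method is not None:
            method(msg)


class ToyProvenance(ToyMessenger):
    """Records which sites reached its post-processing hook."""

    def __init__(self):
        self.seen = []

    def _pyro_post_sample(self, msg):
        self.seen.append(("sample", msg["name"]))


def toy_apply_stack(stack, msg):
    # `stack` lists handlers outermost first. Pre-process from the innermost
    # handler outward, then post-process in the opposite order.
    for handler in reversed(stack):
        handler.process(msg)
    for handler in stack:
        handler.postprocess(msg)
    return msg


tracker = ToyProvenance()
toy_apply_stack([tracker], {"type": "sample", "name": "x"})
toy_apply_stack([tracker], {"type": "deterministic", "name": "y"})
print(tracker.seen)  # [('sample', 'x')]
```

Adding a _pyro_post_deterministic method would make the second message visible. Keeping deterministic sites as "sample" messages avoids needing a new hook at all.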

Of course, the deterministic primitive also had to be modified to match. I took another look at NumPyro's deterministic for ideas, noticed that its implementation differs, and tried implementing something similar. The submitted version is the result; it emulates the same functionality.

Finally, I made some corrections to the core of get_model_relations so that it recognizes the parents of deterministic nodes. I also rendered those nodes with the same dashed style that NumPyro uses.
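As a rough illustration of the rendering side, the sketch below emits Graphviz DOT text with deterministic sites drawn dashed. It assumes the relations have already been reduced to a child → parents mapping that includes deterministic sites. It builds the DOT string by hand to stay dependency-free. to_dot is a hypothetical helper, not the code behind pyro.render_model.

```python
def to_dot(parents, deterministic=frozenset()):
    """Render a child -> parents mapping as Graphviz DOT text.

    Sites listed in `deterministic` get a dashed outline; all others use the
    default solid style.
    """
    lines = ["digraph model {"]
    for name in parents:
        style = ' [style="dashed"]' if name in deterministic else ""
        lines.append('  "{}"{};'.format(name, style))
    for child, child_parents in parents.items():
        for parent in child_parents:
            lines.append('  "{}" -> "{}";'.format(parent, child))
    lines.append("}")
    return "\n".join(lines)


parents = {"a": [], "b": [], "s": ["a", "b"], "c": ["s"]}
print(to_dot(parents, deterministic={"s"}))
```

Here s marks a deterministic site computed from a and b, and c depends on s. Only s gets style="dashed".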

That's it. I hope it works; at least, it does the job in my project 🤓

About the merge intention: I just took the render-deterministic branch to continue the work done there. It can't be merged into the dev branch, can it? Or can it?

r3v1 (Contributor, Author) commented Sep 13, 2023

Hi, is there any progress on this?

eb8680 changed the base branch from render-deterministic to dev on September 13, 2023 15:22

eb8680 (Member) commented Sep 13, 2023

@r3v1 could you revert your changes to the body of pyro.deterministic and refactor your new handler methods for deterministic so that they fit within the existing handler methods for sample? Changing the message type of deterministic will break some things elsewhere in Pyro and downstream.

You can just use a helper function inside _pyro_sample/_pyro_post_sample to check whether a sample site came from deterministic:

```python
from pyro.poutine.messenger import Messenger


def site_is_deterministic(msg: dict) -> bool:
    return msg["type"] == "sample" and msg["infer"].get("_deterministic", False)


class DeterministicAwareMessenger(Messenger):  # placeholder name for an existing handler
    def _pyro_sample(self, msg):
        if site_is_deterministic(msg):
            # new logic for handling deterministic sites here
            return

        # old logic for handling sample sites here
```
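The helper inspects only two fields of the message, so it can be exercised on hand-built dicts without running a model. The messages below are simplified stand-ins; real Pyro messages carry many more keys. The "_deterministic" infer flag is the one the helper above reads, and the helper is repeated so the snippet runs on its own.

```python
def site_is_deterministic(msg: dict) -> bool:
    return msg["type"] == "sample" and msg["infer"].get("_deterministic", False)


# Simplified stand-ins for Pyro messages; real messages have many more keys.
latent = {"type": "sample", "name": "z", "infer": {}}
derived = {"type": "sample", "name": "loc", "infer": {"_deterministic": True}}
param = {"type": "param", "name": "w", "infer": {}}

print(site_is_deterministic(latent))   # False
print(site_is_deterministic(derived))  # True
print(site_is_deterministic(param))    # False
```

Because the type check comes first, non-sample messages short-circuit before msg["infer"] is touched.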

eb8680 (Member) commented Sep 13, 2023

@r3v1 I also just changed the base branch to dev. Can you resolve the merge conflicts and push so that CI can run the tests?

r3v1 (Contributor, Author) commented Sep 14, 2023

OK, I've made the changes in pull request #3266 against the dev branch, so this one can be closed.

eb8680 closed this Sep 14, 2023