Closed
Description
The models that are now allowed following #7656 have a disconnected node in the model graph.
The sampling is as expected. It is just the graphviz representation that is incorrect.
import numpy as np
import pymc as pm
from pymc.model_graph import ModelGraph
seed = sum(map(ord, "Observed disconnected node"))
rng = np.random.default_rng(seed)
true_mu = 100
true_sigma = 30
n_obs = 10
coords = {
    "date": np.arange(n_obs),
}
dist = pm.Normal.dist(mu=true_mu, sigma=true_sigma, shape=n_obs)
data = pm.draw(dist, random_seed=rng)
scaling = data.max()
with pm.Model(coords=coords) as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")
    target = pm.Data("target", data, dims="date")
    scaled_target = target / scaling
    pm.Normal("observed", mu=mu, sigma=sigma, observed=scaled_target, dims="date")
pm.model_to_graphviz(model).render("scaled_target")
ModelGraph(model).make_compute_graph()
"observed" should be linked to "target" in the compute graph. Instead, "target" has no connections at all:
defaultdict(set,
            {'mu': set(),
             'sigma': set(),
             'target': set(),
             'observed': {'mu', 'sigma'}})
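A quick way to spot this kind of problem is to look for names in the compute graph that have no inputs and are not an input to anything else. The helper below is only a sketch: `isolated_nodes` is a hypothetical name, not part of PyMC, and it operates on the plain dictionary shown above rather than on a model.

```python
from collections import defaultdict


def isolated_nodes(compute_graph):
    """Return names with no inputs that are also not an input of any other node.

    Hypothetical helper for illustration; not part of PyMC.
    """
    # Every name that appears as an input of some node.
    used_as_input = set().union(*compute_graph.values())
    return {
        name
        for name, inputs in compute_graph.items()
        if not inputs and name not in used_as_input
    }


# The compute graph reported in this issue.
compute_graph = defaultdict(
    set,
    {
        "mu": set(),
        "sigma": set(),
        "target": set(),
        "observed": {"mu", "sigma"},
    },
)

print(isolated_nodes(compute_graph))  # {'target'}
```

Here `mu` and `sigma` have no inputs either, but they feed `observed`, so only `target` is reported as disconnected.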
Seems like it needs a fix here:
Lines 322 to 343 in af81955
if var in self.model.observed_RVs:
    obs_node = self.model.rvs_to_values[var]
    # loop created so that the elif block can go through this again
    # and remove any intermediate ops, notably dtype casting, to observations
    while True:
        obs_name = obs_node.name
        if obs_name and obs_name != var_name:
            input_map[var_name] = input_map[var_name].difference({obs_name})
            input_map[obs_name] = input_map[obs_name].union({var_name})
            break
        elif (
            # for cases where observations are cast to a certain dtype
            # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
            obs_node.owner
            and isinstance(obs_node.owner.op, Elemwise)
            and isinstance(obs_node.owner.op.scalar_op, Cast)
        ):
            # we can retrieve the observation node by going up the graph
            obs_node = obs_node.owner.inputs[0]
        else:
            break
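The loop above only walks up the graph through dtype casts, so an observation produced by any other op (here a division) falls into the final `else` and is never linked to its named data node. The toy example below sketches one possible direction: walk up through all unnamed intermediate nodes until a named ancestor is found. `Node`, `cast_only_ancestor`, and `named_ancestors` are hypothetical stand-ins written for illustration; they are not PyTensor or PyMC classes, and an actual fix may look quite different.

```python
from dataclasses import dataclass
from typing import Optional, Set, Tuple


@dataclass(frozen=True)
class Node:
    """Toy stand-in for a graph variable (not a PyTensor class)."""

    name: Optional[str] = None
    op: Optional[str] = None  # op that produced this node; None for a root
    inputs: Tuple["Node", ...] = ()


def cast_only_ancestor(node: Node, own_name: str) -> Optional[str]:
    """Mimic the loop above: only step through 'cast' ops."""
    while True:
        if node.name and node.name != own_name:
            return node.name
        if node.op == "cast":
            node = node.inputs[0]
        else:
            return None


def named_ancestors(node: Node, own_name: str) -> Set[str]:
    """Walk up through any unnamed node; collect the first named ancestor on each path."""
    found = set()
    stack = [node]
    while stack:
        current = stack.pop()
        if current.name and current.name != own_name:
            found.add(current.name)
        else:
            stack.extend(current.inputs)
    return found


# Toy version of `scaled_target = target / scaling` from the example model.
target = Node(name="target")
scaling = Node()  # unnamed constant
scaled_target = Node(op="true_div", inputs=(target, scaling))

print(cast_only_ancestor(scaled_target, own_name="observed"))  # None
print(named_ancestors(scaled_target, own_name="observed"))  # {'target'}
```

The cast-only walk gives up at the division, which matches the disconnected "target" node in the reported output, while the general walk reaches "target".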