for src in range(total_nodes):
    # Add a directed edge src -> dst for every mask entry equal to 1.
    for dst in range(total_nodes):
        if binary_mask[src, dst] == 1:
            dg.add_edge(src, dst)
    # Layout position for src: the output of number_to_type_layer, reversed
    # and scaled by 2. (presumably (type, layer) coordinates — TODO confirm).
    coords = np.array(number_to_type_layer(src, n_types))
    pos[src] = 2. * coords[::-1]
After Change
# Build the directed graph from the mask, compute a layout position per node,
# and colour nodes by role: sources -> 1.0, sinks -> 0.5, all others -> 0.0.
dg = get_digraph_from_binary_mask(nodes, binary_mask)
pos = {}
val_map = {}
sources, sinks = find_sources_and_sinks(dg)
# Materialise once so each membership test below is O(1) and does not depend
# on the helper's return type (list, set, or a one-shot iterable).
source_set = set(sources)
sink_set = set(sinks)
for i in range(total_nodes):
    # Layout position: output of number_to_type_layer, reversed, scaled by 2.
    pos[i] = 2. * np.array(number_to_type_layer(i, n_types))[::-1]
    if i in source_set:
        val_map[i] = 1.
    elif i in sink_set:
        val_map[i] = 0.5
    else:
        val_map[i] = 0.
plt.figure(figsize=(12, 12))
# nx.draw pairs node_color positionally with the graph's own node order
# (nodelist defaults to list(dg)), so derive the colour list from dg itself
# rather than from `nodes`, which may be ordered differently or contain nodes
# absent from dg. Nodes with no val_map entry fall back to 0.25.
values = [val_map.get(node, 0.25) for node in dg]
nx.draw(dg, pos, cmap=plt.get_cmap("jet"), node_color=values, node_size=7000, alpha=0.3)
# NOTE(review): networkx requires `labels` to be a dict {node: label}; this
# assumes `nodes` is such a dict — confirm against its definition.
nx.draw_networkx_labels(dg, pos, labels=nodes, font_size=18)