1
votes

The following code builds a 3×3 grid graph with `nx.grid_2d_graph` and computes its adjacency matrix. I use an array `P` to color the nodes, but an error is raised when I try to add a colorbar. I want the colorbar to display the range of the node colors.

import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt

# Build a 3x3 grid graph, then relabel its nodes with the integers 1..9
# (enumerate over G.nodes, starting at 1).
G = nx.grid_2d_graph(3,3)
new_nodes = {e: n for n, e in enumerate(G.nodes, start=1)}
new_edges = [(new_nodes[e1], new_nodes[e2]) for e1, e2 in G.edges]
# NOTE(review): rebuilding the graph from the edge list adds nodes in the
# order they first appear in new_edges, which need not be 1..9. Both the
# adjacency-matrix rows and the entries of P below follow G.nodes order,
# so P[k - 1] may not color node k — verify (nx.relabel_nodes keeps order).
G = nx.Graph()
G.add_edges_from(new_edges)
#nx.draw(G, with_labels=True)

# Dense adjacency matrix; rows/columns follow G.nodes order.
A1 = nx.adjacency_matrix(G) 
A=A1.toarray()
print([A]) 

fig, ax = plt.subplots(1, 1)
# One color value per node (9 values for 9 nodes), mapped through 'Blues'.
P=np.array([1,2,5,4,5,6,7,2,10])
# nx.draw draws onto the current Axes and returns None, so D is None here.
D=nx.draw(G, with_labels=True, node_color=[P], node_size=1000, cmap='Blues')
# This is the failing line: colorbar needs a mappable, but D is None,
# which produces the AttributeError shown in the traceback below.
fig.colorbar(D)
plt.show()

enter image description here

The error is

in <module>
    fig.colorbar(D)

  File "C:\Users\USER\anaconda3\lib\site-packages\matplotlib\figure.py", line 1176, in colorbar
    cb = cbar.Colorbar(cax, mappable, **cb_kw)

  File "C:\Users\USER\anaconda3\lib\site-packages\matplotlib\colorbar.py", line 1171, in __init__
    if mappable.get_array() is not None:

AttributeError: 'NoneType' object has no attribute 'get_array'
1

1 Answer

1
votes

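`nx.draw` returns `None` instead of a mappable, so `fig.colorbar(D)` receives `None` and fails. You can create a suitable mappable yourself with the `ScalarMappable` class in `plt.cm`, using the same colormap and a normalization that spans the range of `P`.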

enter image description here

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx

# Build a 3x3 grid graph; grid_2d_graph labels nodes with (row, col)
# tuples, so map them to the integers 1..9 in G.nodes order.
G = nx.grid_2d_graph(3, 3)
new_nodes = {e: n for n, e in enumerate(G.nodes, start=1)}
# relabel_nodes keeps the original node order, so node k is the k-th node:
# P[k - 1] colors node k and row k - 1 of A describes node k. Rebuilding
# the graph from its edge list instead would order nodes by their first
# appearance in the edges (1, 4, 2, 5, ...), silently permuting both.
G = nx.relabel_nodes(G, new_nodes)

# Dense adjacency matrix; rows/columns follow G.nodes order (1..9).
A1 = nx.adjacency_matrix(G)
A = A1.toarray()
print(A)

fig, ax = plt.subplots(1, 1)
# One color value per node, in G.nodes order.
P = np.array([1, 2, 5, 4, 5, 6, 7, 2, 10])

# Use one colormap and one value range for both the nodes and the
# colorbar so that they are guaranteed to agree.
cmap = 'Blues'
vmin, vmax = np.min(P), np.max(P)

# nx.draw returns None, so there is no mappable to hand to fig.colorbar;
# draw onto the explicit Axes with explicit color limits instead.
nx.draw(G, ax=ax, with_labels=True, node_color=P, node_size=1000,
        cmap=cmap, vmin=vmin, vmax=vmax)

# -------------------------------------
# Stand-in mappable carrying the same colormap and normalization.
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # needed by matplotlib < 3.1; harmless on newer versions
# sm is not attached to any Axes, so name the Axes the colorbar should
# take space from (recent matplotlib raises if this is omitted).
fig.colorbar(sm, ax=ax)
# -------------------------------------

plt.show()