import matplotlib.pyplot as plt
import collections
from IPython import display
import networkx as nx
import numpy as np
import time
class Color:
# 绘图配色
# 基本的节点配色
basic_node_color = '#6CB6FF'
# 初始节点配色
start_node_color = 'y'
# 目标节点配色
target_node_color = 'r'
# 已访问的节点配色
visited_node_color = 'g'
# 基本边的配色
basic_edge_color = 'b'
# 已访问的边的配色
visited_edge_color = 'g'
# 抵达目标节点时的配色
success_color = 'r'
class Graph(nx.Graph):
def __init__(self,
start_node=None,
target_node=None):
# 初始化父类
nx.Graph.__init__(self)
# 绘制 graph 时各个节点的坐标值
self.nodes_pos = {}
# 搜索起点
self.start_node = start_node
# 搜搜目标点
self.target_node = target_node
# 搜索的最大深度
self.max_depth = None
# bfs 的历史查找路径,例如 ['A','AB', 'AD', 'ABC']
self.bfs_paths = []
# dfs 的历史查找路径,例如 ['A','AB', 'ABC', 'AD']
self.dfs_paths = []
# 每条历史路径的分数(历史距离)
self.path_score = {}
# 绘制搜索数时每个节点的坐标值
self.tree_node_position = {}
# 额外的辅助信息值,用于进行启发式搜索(贪婪或 A-star 搜索)
self.help_info = {}
# 额外辅助信息的权重
self.help_info_weight = 1
# 原始信息的权重
self.origin_info_weight = 1
# a_star 算法的历史查找路径
self.a_star_search_paths = []
# a_star 算法的中路径的得分
self.a_star_search_scores = {}
#
self.changed = True
def set_start_node(self, start_node):
"""
设置起点
:param start_node: 起点
:return:
"""
if start_node not in self.nodes:
print('起点不在图中,请输入正确的起点')
return
self.start_node = start_node
self.path_score = {self.start_node: 0}
self.changed = True
def set_target_node(self, target_node):
"""
设置目标点
:param target_node: 目标点
:return:
"""
if target_node not in self.nodes:
print('目标点不在图中,请输入正确的目标点')
return
self.target_node = target_node
self.path_score = {self.start_node: 0}
self.changed = True
def set_nodes_pos(self, nodes_pos):
"""
设置 graph 内各个节点的坐标,绘图用,与搜索无关
:param nodes_pos: 字典,节点的坐标,
例如 {"A": (1, 1), "B": (3, 3), "C": (5, 0)}
:return:
"""
self.nodes_pos = nodes_pos
def set_max_depth(self, max_depth):
"""
设置最大搜索深度
:param max_depth: 整数,大于 0
:return:
"""
max_depth = min(max_depth, len(self.nodes))
self.max_depth = max_depth
self.changed = True
def set_help_info(self, help_info):
"""
设置辅助信息,用于进行启发式搜索(贪婪或 A-star 搜索)
:param help_info: 字典,
例如 {'A': 30, 'B': 20, 'C': 19 },为各个点到目标点的距离
:return:
"""
self.help_info = help_info
def show_graph(self, this_path=''):
"""
绘制 graph
:param this_path: 设置一条图中的路径
:return:
"""
# 当不传入路径时,默认在初始节点
this_path = this_path or self.start_node
# 根据当前你路径,处理节点和边的颜色
# 根据路径得到已访问的边,例如 'ABC' 得到 ['AB', 'BC']
visited_edges = []
for i in range(1, len(this_path)):
visited_edges.append(
frozenset([this_path[i], this_path[i - 1]]))
# 节点和边以及其显示的标签
node_labels = dict(zip(self.nodes(), self.nodes()))
edge_labels = dict(
[((u, v,), d['weight']) for u, v, d in self.edges(data=True)])
# 处理节点的颜色
node_color_map = {e_node: Color.visited_node_color for e_node in
this_path}
node_color_map[self.start_node] = Color.start_node_color
node_color_map[self.target_node] = Color.target_node_color
node_color = [node_color_map.get(node, Color.basic_node_color)
for node in self.nodes()]
# 处理每条边的颜色
edge_color = []
for edge in self.edges():
if frozenset(edge) in visited_edges:
edge_color.append(Color.visited_edge_color)
else:
edge_color.append(Color.basic_edge_color)
# 创建绘图
fig, ax = plt.subplots()
# 定义绘图的宽和高,并关闭坐标轴的显示
fig.set_figwidth(6)
fig.set_figheight(8)
plt.axis('off')
# 绘制节点及其标签
nx.draw_networkx_nodes(self, self.nodes_pos, node_size=800,
node_color=node_color, width=6.0)
nx.draw_networkx_labels(self, self.nodes_pos, node_labels,
font_size=20)
# 绘制边及其标签
nx.draw_networkx_edges(self, self.nodes_pos, edge_color=edge_color,
width=2.0, alpha=1.0)
nx.draw_networkx_edge_labels(self, self.nodes_pos,
edge_labels=edge_labels, font_size=18)
# 清除绘图区,显示新绘图
display.clear_output(wait=True)
plt.show()
def bfs_search(self):
"""
使用迭代法进行广度优先搜索
:return:
"""
# 待访问的路径
to_search = [self.start_node]
# 存储所有的已访问的路径
bfs_paths = []
# 当还有待访问的路径时
while to_search:
# 从待访问的路径中取第一个待访问路径
this_search = to_search.pop(0)
# 如果待访问的路径超过最大搜索深度,跳出循环
if len(this_search) > self.max_depth + 1:
break
# 把刚取出的路径存入已访问的路径中
bfs_paths.append(this_search)
# 如果路径的最后一个节点是目标节点,路径 AC 的最后一个节点是 C
if this_search[-1] == self.target_node:
# 其为一条正确的路径,将存入正确的路径列表中,
# 并不再继续往其子节点进行探索
continue
# 找到路径最后一个节点的相邻节点
else:
for ne in sorted(self.neighbors(this_search[-1])):
# 如果相邻节点不在路径中,即不存在回路
if ne not in this_search:
# 则加入到待访问的路径中
to_search.append(this_search + ne)
self.bfs_paths = bfs_paths
def _dfs_helper(self, node, target_node, level, dfs_paths, path):
"""
深度优先搜索的辅助函数
:param node: 当前节点
:param target_node: 目标点
:param level: 搜索深度
:param dfs_paths: dfs 的历史搜索路径
:param path: 从哪一个路径来到当前节点
:return:
"""
path += str(node)
# 更新路径的分数(距离)
if len(path) > 1:
self.path_score[path] = self.path_score[path[:-1]] + \
self.edges[path[-2], path[-1]]['weight']
# 存储 dfs 的历史搜索路径
dfs_paths.append(path)
# 找到目标,停止搜索
if node == target_node:
return
# 未达到最大搜索深度时,继续下一层搜索
if level < self.max_depth:
# 对当前节点的每一个相邻节点
for neighbor in sorted(self.neighbors(node)):
# 如果该相邻节点不在路径中,即没有出现回环,则递归调用,继续往下搜索
if str(neighbor) not in path:
self._dfs_helper(neighbor, target_node, level + 1,
dfs_paths, path)
def dfs_search(self):
"""
使用递归法进行深度优先搜索
:return:
"""
# dfs 的历史搜索路径
dfs_paths = []
this_path = ''
if self.start_node and self.target_node:
self._dfs_helper(self.start_node, self.target_node,
0, dfs_paths, this_path)
else:
print('请设置起点和目标点')
self.dfs_paths = dfs_paths
# 完成搜索后,可得到搜索树中各个节点的坐标
self.get_search_tree_node_position()
def dfs_bfs_search(self):
"""
如果对起始点,目标点或搜索深度进行了设置,需要重新绘制搜索树
:return:
"""
self.dfs_search()
self.bfs_search()
self.changed = False
def get_search_tree_node_position(self):
"""得到绘图时各个节点的坐标
"""
# 得到 dfs 的搜索路径图
paths = self.dfs_paths
# 得到每条路径的子路径
path_children = {}
for path in paths:
father = path[:-1]
if father in paths:
if father in path_children:
path_children[father].append(path)
else:
path_children[father] = [path]
# 对每条子路径排序
o_path_children = collections.OrderedDict(
sorted(path_children.items()))
# 计算每个树图中每个节点的位置
tree_node_position = {self.start_node: (1, 0, 2)}
for path, sub_paths in o_path_children.items():
y_pos = -1.0 / self.max_depth * len(path)
dx = tree_node_position[path][2] / len(sub_paths)
sub_paths.sort()
for index, e_s in enumerate(sub_paths):
x_pos = tree_node_position[path][0] - tree_node_position[path][
2] / 2 + dx / 2 + dx * index
tree_node_position[e_s] = (x_pos, y_pos, dx)
self.tree_node_position = tree_node_position
def a_star_search(self, help_info_weight=1, origin_info_weight=0):
"""
a-star 搜索, 当 origin_info_weight 为 0 时,则退化为贪婪搜索
:param help_info_weight: 辅助信息的比重
:param origin_info_weight: 原始信息的比重
:return:
"""
if not self.help_info:
print('缺少额外的辅助信息')
return
# 存到类属性中,便于绘图时使用
self.help_info_weight = help_info_weight
self.origin_info_weight = origin_info_weight
# 初始路径为起点
search_path = self.start_node
# 存储每一步的可选项及其分数,用来在动态演示时显示出来
search_scores = {}
# 当搜索路径未超过最大搜索深度
while len(search_path) <= self.max_depth:
# 当前的节点
this_node = search_path[-1]
# 当前节点的子节点
neighbour_nodes = [e_n for e_n in sorted(self.neighbors(this_node))
if e_n not in search_path]
# 如果没有子节点,则跳出循环,结束搜索
if len(neighbour_nodes) == 0:
search_scores[search_path] = {}
break
# 计算每个子节点的得分并存储
scores = {e_n: help_info_weight * self.help_info[
e_n] + origin_info_weight * self.edges[this_node, e_n][
'weight'] for e_n in
neighbour_nodes}
search_scores[search_path] = scores
# 挑选最佳的子节点,并添加到路径中
nearest_node = min(scores, key=scores.get)
search_path += nearest_node
# 如果最佳的子节点是目标节点,跳出循环,结束搜索
if nearest_node == self.target_node:
break
# 把最终路径切分为每一步,便于动态展示,例如 ABCD 变为 [A, AB, ABC, ABCD ]
self.a_star_search_paths = [search_path[0:index + 1] for index in
range(len(search_path))]
self.a_star_search_scores = search_scores
def greedy_search(self):
"""
贪婪搜索,就是 origin_info_weight权重为 0 时的 a-star 搜索
:return:
"""
self.a_star_search(help_info_weight=1, origin_info_weight=0)
@staticmethod
def show_edge_labels(ax, pos1, pos2, label):
"""
绘制搜索树的边
:param ax: 子图
:param pos1: 点 1 的坐标
:param pos2: 点 2 的坐标
:param label: 连接点 1 和点 2 的边上的文字
:return:
"""
# 点1
(x1, y1) = pos1
# 点2
(x2, y2) = pos2
# 文字的位置
(x, y) = (x1 * 0.5 + x2 * 0.5, y1 * 0.5 + y2 * 0.5)
# 文字的角度
angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
if angle > 90:
angle -= 180
if angle < - 90:
angle += 180
xy = np.array((x, y))
trans_angle = ax.transData.transform_angles(np.array((angle,)),
xy.reshape((1, 2)))[0]
# 绘制文字框和文字
bbox = dict(boxstyle='round',
ec=(1.0, 1.0, 1.0),
fc=(1.0, 1.0, 1.0),
)
label = str(label)
ax.text(x, y,
label,
size=16,
color='k',
alpha=1,
horizontalalignment='center',
verticalalignment='center',
rotation=trans_angle,
transform=ax.transData,
bbox=bbox,
zorder=1,
clip_on=True,
)
def show_search_tree(self,
animation_type='bfs',
top_text='',
bottom_text='',
this_path=None,
show_success_color=False,
):
"""
展示搜索树,动态展示搜索过程时,会调用此方法
:param animation_type: 动态演示的类型,如果是启发式搜索,边的权重需要变化
:param top_text: 上方的文字展示
:param bottom_text: 下方的文字展示
:param this_path: 当前路径
:param show_success_color: 成功找到目标点后,路径颜色的变换
:return:
"""
# 如果对起始点,目标点或搜索深度进行了设置,需要重新绘制搜索树
if self.changed is True:
self.dfs_bfs_search()
# 创建子图
fig, ax = plt.subplots()
# 定义绘图的宽度
fig.set_figwidth(15)
# 定义绘图的高度
fig.set_figheight(self.max_depth * 1.5)
# 关闭绘图中坐标轴的显示
plt.axis('off')
# 对每条路径
for path, pos in self.tree_node_position.items():
# 如果是初始点
if path[-1] == self.start_node:
node_color = Color.start_node_color
edge_color = Color.basic_edge_color
# 把当前路径的节点和边的颜色变为已访问的颜色
elif this_path and path in this_path:
# 是否显示成功找到目标点
if show_success_color:
node_color = Color.success_color
edge_color = Color.success_color
else:
node_color = Color.visited_node_color
edge_color = Color.visited_edge_color
# 如果路径的终点是目标点,改变目标点的颜色
elif path[-1] == self.target_node:
node_color = Color.target_node_color
edge_color = Color.basic_edge_color
# 其他的情况下,节点和边的颜色是正常色
else:
node_color = Color.basic_node_color
edge_color = Color.basic_edge_color
# 绘制节点
ax.scatter(pos[0], pos[1], c=node_color, s=1000, zorder=1)
# 绘制节点的标注
plt.annotate(
path[-1],
xy=(pos[0], pos[1]),
xytext=(0, 0),
textcoords='offset points',
ha='center',
va='center',
size=15, )
if len(path) > 1:
# 绘制边
plt.plot([self.tree_node_position[path[:-1]][0], pos[0]],
[self.tree_node_position[path[:-1]][1], pos[1]],
color=edge_color,
zorder=0)
# 绘制边的标注
label = self.edges[path[-2], path[-1]]['weight']
if animation_type in ['greedy', 'a_star']:
label = self.help_info_weight * self.help_info[
path[-1]] + self.origin_info_weight * label
self.show_edge_labels(ax,
self.tree_node_position[path[:-1]][
0:2], pos[0:2], label)
# 绘制上方文字
plt.text(0,
0,
top_text,
fontsize=18,
horizontalalignment='left',
verticalalignment='top', )
# 绘制下方文字
plt.text(0,
-1.1,
bottom_text,
fontsize=18,
horizontalalignment='left',
verticalalignment='top', )
# 刷新绘图
display.clear_output(wait=True)
plt.show()
def _generate_bottom_text(self, show_correct_path):
"""
生成目前找到的最佳路径的信息的文字
:param show_correct_path: 当前找到的正确的路径
:return:
"""
# 默认不展示文字
bottom_text = ""
# 对每一条找到的正确路径,增加一条展示文本s
for path in show_correct_path:
bottom_text += '找到一条路径: %-7s' % path + '。距离为:' + str(
self.path_score[path]) + '\n'
return bottom_text
def _generate_a_star_help_text(self, path):
"""
生成贪婪搜索和 a-star 动态展示时的文字
:param path: 当前路径
:return:
"""
# 如果到达目标节点
if path[-1] == self.target_node:
return '抵达目标节点' + str(self.target_node)
# 如果未抵达目标节点并且抵达了最大搜索深度
elif path not in self.a_star_search_scores:
return '未找到目标节点, 结束搜索'
# 其他情况,展示当前节点可选节点的信息,以及挑选的原因
else:
base_text = '当前可选的子节点及其信息值为 \n' + \
str(self.a_star_search_scores[path]) + '\n'
if self.target_node in self.a_star_search_scores[path]:
return base_text + '当前可选的子节点包含了目标节点,\n所以选择目标节点'
elif len(self.a_star_search_scores[path]) == 1:
return base_text + '因为只有一个子节点,所以选择此节点'
else:
return base_text + '因为' + \
str(min(self.a_star_search_scores[path],
key=self.a_star_search_scores[path].get)) + \
'的值最小,所以选择此节点'
def _generate_top_text(self,
animation_type,
this_path,
best_path=None,
finish=False):
"""
生成展示当前路径的信息文字
:param animation_type:
:param this_path:
:param best_path:
:param finish:
:return:
"""
# 如果结束搜索,展示最终的最短路径
if finish:
if this_path:
top_text = '最终最短路径为: %-7s' % this_path + '。距离为:' + str(
self.path_score[this_path]) + '\n'
else:
top_text = '未找到正确路径'
# 如果是其他路径,并且是展示 dfs 或 bfs 的搜索过程,展示当前路径的信息
elif this_path and animation_type in ['dfs', 'bfs']:
top_text = '当前路径: %-7s' % this_path + '。距离为:' + str(
self.path_score[this_path]) + '\n'
if best_path:
top_text += '当前最短路径为: %-7s' % best_path + '。距离为:' + \
str(self.path_score[best_path]) + '\n'
# 如果是其他路径,并且是展示 贪婪搜索 或 A-star 算法的搜索过程,展示当前路径的信息
elif this_path and animation_type in ['greedy', 'a_star']:
top_text = self._generate_a_star_help_text(this_path)
# 其他情况,不展示文字
else:
top_text = ''
return top_text
def animate_search_tree(self,
animation_type='dfs',
help_info_weight=1,
origin_info_weight=1,
sleep_time=0):
"""
动态演示搜索过程
:param animation_type: 可选项为 ['bfs', 'dfs', 'greedy', 'a_star']
:param help_info_weight: 附加信息的权重值
:param origin_info_weight: 原始信息的权重值
:param sleep_time: 设置每一步的等待时间
:return:
"""
# 如果对起始点,目标点或搜索深度进行了设置,需要重新绘制搜索树
if self.changed is True:
self.dfs_bfs_search()
# 根据展示的搜索方式,获取展示的路径列表
if animation_type == 'bfs':
paths = self.bfs_paths
elif animation_type == 'dfs':
paths = self.dfs_paths
elif animation_type == 'greedy':
self.greedy_search()
paths = self.a_star_search_paths
elif animation_type == 'a_star':
self.a_star_search(help_info_weight=help_info_weight,
origin_info_weight=origin_info_weight)
paths = self.a_star_search_paths
else:
print('animation_type 参数错误,请从 dfs、bfs、 greedy 或 a_star 中挑选一个')
return
if animation_type in ['bfs', 'dfs']:
show_correct_path = []
# 动态演示过程中找到的最佳路径
best_path = None
# 对路径列表中的每一个路径,绘图
for e_path in paths:
top_text = self._generate_top_text(animation_type,
e_path,
best_path=best_path,
finish=False)
bottom_text = self._generate_bottom_text(show_correct_path)
self.show_search_tree(top_text=top_text,
bottom_text=bottom_text,
this_path=e_path)
# 设置等待时间,避免切换过快
time.sleep(sleep_time)
# 如果该路径是正确路径
if e_path[-1] == self.target_node:
# 如果是第一个正确路径,则其为当前最佳路径
if not best_path:
best_path = e_path
# 如果不是,与当前最佳路径比较,
elif self.path_score[e_path] < self.path_score[best_path]:
best_path = e_path
# 增加一条最佳路径的展示
show_correct_path.append(e_path)
bottom_text = self._generate_bottom_text(show_correct_path)
self.show_search_tree(top_text=top_text,
bottom_text=bottom_text,
this_path=e_path,
show_success_color=True)
# 设置等待时间,避免切换过快
time.sleep(sleep_time)
# 搜索结束后,展示最佳路径
top_text = self._generate_top_text(animation_type,
best_path,
best_path=best_path,
finish=True
)
bottom_text = self._generate_bottom_text(show_correct_path)
self.show_search_tree(top_text=top_text,
bottom_text=bottom_text,
this_path=best_path,
show_success_color=True)
else:
# 对路径列表中的每一个路径,绘图
for e_path in paths:
top_text = self._generate_top_text(animation_type,
e_path,
best_path=False)
self.show_search_tree(top_text=top_text,
this_path=e_path)
# 设置等待时间,避免切换过快
time.sleep(sleep_time)
# 如果抵达目标点
if e_path[-1] == self.target_node:
top_text = self._generate_top_text(animation_type,
e_path,
best_path=True)
self.show_search_tree(top_text=top_text,
this_path=e_path,
show_success_color=True)