master
/ Untitled.ipynb

Untitled.ipynb @masterview markup · raw · history · blame

Notebook
In [1]:
import networkx as nx
In [9]:
def bfs_search(G, max_depth, start_node, target_node):
    """Enumerate simple paths from ``start_node`` in breadth-first order.

    A path is a string made by concatenating node names, so node names are
    expected to be strings (single characters in the example graph).

    Args:
        G: Adjacency mapping where ``G[u]`` iterates over the neighbours of
            ``u`` and ``G[u][v]['weight']`` is the weight of edge ``u-v``
            (for example a weighted ``networkx.Graph``).
        max_depth: Maximum number of nodes allowed in a path.
        start_node: Node every path starts from.
        target_node: Node at which a path is considered successful.

    Returns:
        A ``(bfs_path, bfs_correct_path)`` tuple. ``bfs_path`` lists every
        visited ``(path, distance)`` pair in visiting order, and
        ``bfs_correct_path`` lists the subset whose last node is
        ``target_node``.
    """
    # Local import: a deque gives O(1) pops from the left, unlike list.pop(0).
    from collections import deque

    # Paths waiting to be visited, together with their accumulated distance.
    to_search = deque([(start_node, 0)])
    # Every path visited so far, with its distance.
    bfs_path = []
    # Visited paths that end at the target node, with their distance.
    bfs_correct_path = []
    while to_search:
        this_path, this_path_dis = to_search.popleft()
        # The queue is ordered by path length, so once one path is too long
        # all remaining ones are as well.
        if len(this_path) > max_depth:
            break
        bfs_path.append((this_path, this_path_dis))
        last_node = this_path[-1]
        if last_node == target_node:
            # A successful path is recorded and not explored any further.
            bfs_correct_path.append((this_path, this_path_dis))
            continue
        # Children of a path already at the depth limit could never be
        # recorded, so they are not generated at all.
        if len(this_path) == max_depth:
            continue
        for ne in sorted(G[last_node]):
            # Skip neighbours already on the path so no cycle is created.
            if ne not in this_path:
                to_search.append(
                    (this_path + ne,
                     this_path_dis + G[last_node][ne]['weight']))
    return bfs_path, bfs_correct_path
In [10]:
# Names of the nodes in the example graph.
node_list = list('ABCDEFG')

# Undirected edges of the example graph as (node, node, weight) triples.
weighted_edges_list = [
    ('A', 'B', 8),
    ('A', 'C', 20),
    ('B', 'F', 40),
    ('B', 'E', 30),
    ('B', 'D', 20),
    ('C', 'D', 10),
    ('D', 'G', 10),
    ('D', 'E', 10),
    ('E', 'F', 30),
    ('F', 'G', 30),
]

# (x, y) coordinates used when drawing each node.
nodes_pos = dict(
    A=(1, 1),
    B=(3, 3),
    C=(5, 0),
    D=(9, 2),
    E=(7, 4),
    F=(6, 6),
    G=(11, 5),
)

# Build the weighted undirected graph from the definitions above.
G = nx.Graph()
G.add_nodes_from(node_list)
G.add_weighted_edges_from(weighted_edges_list)
In [13]:
# Run the breadth-first search from 'A' towards 'G', allowing at most 3 nodes
# per path.
bfs_path, bfs_correct_path = bfs_search(G, 3, 'A', 'G')
In [14]:
# Keep only the path strings, dropping the distance of each visited path.
paths = [e[0] for e in bfs_path]
In [23]:
import collections
def get_search_tree_node_position(paths, max_depth=3):
    """Compute the plotting position of every node of the search tree.

    Each path is a string of node names; the parent of a path is the path
    without its last character. The root is placed at ``x = 1`` with a
    horizontal span of ``2``, and every parent divides its span evenly among
    its children (sorted by path).

    Args:
        paths: Sequence of path strings. The first character of the first
            path is used as the root of the tree.
        max_depth: Depth used to scale the vertical spacing; a path of
            ``k`` nodes has its children drawn at ``y = -k / max_depth``.

    Returns:
        Dict mapping each positioned path to ``(x, y, width)``, where
        ``width`` is the horizontal span reserved for that node. The dict is
        also printed. An empty ``paths`` yields an empty dict.
    """
    if not paths:
        tree_node_position = {}
        print(tree_node_position)
        return tree_node_position

    # Group every path under its parent path; a set makes the parent lookup
    # O(1) instead of scanning the list each time.
    known_paths = set(paths)
    path_childern = {}
    for path in paths:
        father = path[:-1]
        if father in known_paths:
            path_childern.setdefault(father, []).append(path)

    # Parents are handled in sorted order, so a path is always positioned
    # before its own children are (a prefix sorts before its extensions).
    tree_node_position = {paths[0][0]: (1, 0, 2)}
    for path, sub_paths in sorted(path_childern.items()):
        parent_x, _, parent_width = tree_node_position[path]
        y_pos = -1.0 / max_depth * len(path)
        dx = parent_width / len(sub_paths)
        for index, e_s in enumerate(sorted(sub_paths)):
            # Centre of the index-th slice of the parent's span.
            x_pos = parent_x - parent_width / 2 + dx / 2 + dx * index
            tree_node_position[e_s] = (x_pos, y_pos, dx)
    print(tree_node_position)
    return tree_node_position
In [24]:
# Compute (and print) the search-tree layout for the BFS paths found above.
get_search_tree_node_position(paths)
{'A': (1, 0, 2), 'ACD': (1.5, -0.6666666666666666, 1.0), 'AC': (1.5, -0.3333333333333333, 1.0), 'ABD': (0.16666666666666666, -0.6666666666666666, 0.3333333333333333), 'AB': (0.5, -0.3333333333333333, 1.0), 'ABF': (0.8333333333333333, -0.6666666666666666, 0.3333333333333333), 'ABE': (0.5, -0.6666666666666666, 0.3333333333333333)}
In [ ]:
import matplotlib.pyplot as plt
import collections
from IPython import display
import networkx as nx
import numpy as np
import time


class SearchGraph():
    """Weighted undirected graph with DFS/BFS/greedy searches and animations.

    A path is stored as a string made by concatenating node names, so node
    names are expected to be single-character strings.

    Attributes set by the searches:
        dfs_path / bfs_path: visited paths in visiting order.
        greedy_search_path: successive prefixes of the greedy path.
        search_scores: per-path candidate scores of the greedy search.
        path_score: accumulated edge weight of every path visited by DFS.
        correct_paths: paths ending at the target node, mapped to distance.
        tree_node_position: ``(x, y, width)`` layout of the search tree.
    """

    def __init__(self,
                 node_list,
                 weighted_edges_list,
                 start_node,
                 target_node,
                 max_depth=1000,
                 nodes_pos=None,
                 help_info=None,):
        """Build the graph and pre-compute the DFS and BFS results.

        Args:
            node_list: Names of the graph nodes.
            weighted_edges_list: ``(node, node, weight)`` triples.
            start_node: Node every search starts from.
            target_node: Node the searches try to reach.
            max_depth: Search depth limit; capped at ``len(node_list)``.
            nodes_pos: Mapping of node to ``(x, y)`` used by ``show_graph``.
            help_info: Optional mapping of node to a heuristic value used by
                the greedy / A* style searches.
        """
        # Graph description and search end points.
        self.node_list = node_list
        self.weighted_edges_list = weighted_edges_list
        self.start_node = start_node
        self.target_node = target_node
        self.nodes_pos = nodes_pos
        self.max_depth = min(max_depth, len(node_list))
        self.temp_best_path = None

        # Edge weights keyed by the unordered pair of end points.
        self.weighted_edges_dic = {frozenset([e[0], e[1]]): e[2]
                                   for e in weighted_edges_list}
        self.help_info = help_info
        self.path_score = {self.start_node: 0}

        self.animation_type = 'dfs'
        # Same defaults as greedy_search, so the attributes always exist.
        self.help_info_weight = 1
        self.origin_info_weight = 0

        self.basic_node_color = '#6CB6FF'
        self.start_node_color = 'y'
        self.target_node_color = 'r'
        self.visited_node_color = 'g'

        self.basic_edge_color = 'b'
        self.visited_edge_color = 'g'

        self.success_color = 'r'

        self.correct_paths = {}
        self.show_correct_path = []
        self.build_graph()
        # Runs the DFS as a side effect, which also fills path_score.
        self.get_search_tree_node_position()
        self.bfs_search()

    def build_graph(self):
        """Create ``self.G`` from the node list and weighted edge list."""
        self.G = nx.Graph()
        self.G.add_nodes_from(self.node_list)
        self.G.add_weighted_edges_from(self.weighted_edges_list)

    def get_search_tree_node_position(self):
        """Run the DFS and lay out its paths as a tree.

        Stores in ``self.tree_node_position`` a mapping of path to
        ``(x, y, width)``; every parent divides its horizontal span evenly
        among its children (sorted by path).
        """
        self.dfs_search()
        paths = self.dfs_path
        # Group every path under its parent (the path minus its last node).
        known_paths = set(paths)
        path_childern = {}
        for path in paths:
            father = path[:-1]
            if father in known_paths:
                path_childern.setdefault(father, []).append(path)
        # Parents are handled in sorted order, so a path is always positioned
        # before its own children are.
        tree_node_position = {self.start_node: (1, 0, 2)}
        for path, sub_paths in sorted(path_childern.items()):
            parent_x, _, parent_width = tree_node_position[path]
            y_pos = -1.0/self.max_depth * len(path)
            dx = parent_width/len(sub_paths)
            for index, e_s in enumerate(sorted(sub_paths)):
                x_pos = parent_x - parent_width/2 + dx/2 + dx*index
                tree_node_position[e_s] = (x_pos, y_pos, dx)
        self.tree_node_position = tree_node_position

    def show_edge_labels(self, ax, pos1, pos2, label):
        """Draw ``label`` at the midpoint of the segment ``pos1``-``pos2``.

        The text is rotated to follow the segment and kept upright.

        Args:
            ax: Matplotlib axes to draw on.
            pos1: ``(x, y)`` of one end of the edge, in data coordinates.
            pos2: ``(x, y)`` of the other end, in data coordinates.
            label: Value to display; converted with ``str``.
        """
        (x1, y1) = pos1
        (x2, y2) = pos2
        (x, y) = (x1*0.5 + x2*0.5, y1*0.5 + y2*0.5)

        # Angle of the segment in degrees, folded into [-90, 90] so the text
        # is never drawn upside down.
        angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
        if angle > 90:
            angle -= 180
        if angle < - 90:
            angle += 180
        xy = np.array((x, y))
        # Convert the data-space angle into the angle seen on screen.
        trans_angle = ax.transData.transform_angles(np.array((angle,)),
                                                    xy.reshape((1, 2)))[0]
        bbox = dict(boxstyle='round',
                    ec=(1.0, 1.0, 1.0),
                    fc=(1.0, 1.0, 1.0),
                    )
        ax.text(x, y,
                str(label),
                size=16,
                color='k',
                alpha=1,
                horizontalalignment='center',
                verticalalignment='center',
                rotation=trans_angle,
                transform=ax.transData,
                bbox=bbox,
                zorder=1,
                clip_on=True,
                )

    def show_search_tree(self,
                         this_path=None,
                         show_success_color=False,
                         best_path=None
                        ):
        """Draw the search tree, highlighting ``this_path``.

        Args:
            this_path: Path to highlight, or ``None`` for no highlight.
            show_success_color: Highlight with the success colour instead of
                the visited colour.
            best_path: When truthy, ``this_path`` is announced as the final
                shortest path.
        """
        fig, ax = plt.subplots()
        fig.set_figwidth(15)
        fig.set_figheight(self.max_depth*1.5)
        plt.axis('off')

        # Heuristic edge labels only make sense when heuristic values exist;
        # without them greedy_search scores by the raw edge weight.
        heuristic_labels = (self.animation_type in ['greedy', 'a_star']
                            and bool(self.help_info))

        for path, pos in self.tree_node_position.items():
            if path[-1] == self.start_node:
                node_color = self.start_node_color
                edge_color = self.basic_edge_color
            elif this_path and path in this_path:
                if show_success_color:
                    node_color = self.success_color
                    edge_color = self.success_color
                else:
                    node_color = self.visited_node_color
                    edge_color = self.visited_edge_color
            elif path[-1] == self.target_node:
                node_color = self.target_node_color
                edge_color = self.basic_edge_color
            else:
                node_color = self.basic_node_color
                edge_color = self.basic_edge_color
            ax.scatter(pos[0], pos[1], c=node_color, s=1000, zorder=1)
            plt.annotate(
                path[-1],
                xy=(pos[0], pos[1]),
                xytext=(0, 0),
                textcoords='offset points',
                ha='center',
                va='center',
                size=15,)
            if len(path) > 1:
                # Connect the node to its parent and label the edge.
                parent_pos = self.tree_node_position[path[:-1]]
                plt.plot([parent_pos[0], pos[0]],
                         [parent_pos[1], pos[1]],
                         color=edge_color,
                         zorder=0)
                label = self.weighted_edges_dic[frozenset([path[-2], path[-1]])]
                if heuristic_labels:
                    label = (self.help_info_weight*self.help_info[path[-1]]
                             + self.origin_info_weight*label)
                self.show_edge_labels(ax, parent_pos[0:2], pos[0:2], label)
        display.clear_output(wait=True)

        # Successful paths found so far, listed below the tree.
        show_res_text = ""
        for e_c in self.show_correct_path:
            show_res_text += '找到一条路径: %-7s' % e_c + '。距离为:' +str(self.correct_paths[e_c]) + '\n'
        plt.text(0, -1.1, show_res_text, fontsize=18, horizontalalignment='left', verticalalignment='top',)

        if best_path:
            top_text = '最终最短路径为: %-7s' % this_path + '。距离为:' +str(self.correct_paths[this_path]) + '\n'
        elif this_path and self.animation_type in ['dfs', 'bfs']:
            top_text = '当前路径: %-7s' % this_path + '。距离为:' +str(self.path_score[this_path]) + '\n'
            if self.temp_best_path:
                top_text += '当前最短路径为: %-7s' % self.temp_best_path + '。距离为:' +str(self.correct_paths[self.temp_best_path]) + '\n'
        else:
            top_text = ''

        plt.text(0, 0,
                 top_text,
                 fontsize=18,
                 horizontalalignment='left',
                 verticalalignment='top',)

        if self.animation_type in ['greedy', 'a_star']:
            show_greedy_text = self.generate_greedy_help_text(this_path)
            plt.text(0, 0, show_greedy_text, fontsize=18, horizontalalignment='left', verticalalignment='top',)
        plt.show()

    def _select_paths(self, search_method, help_info_weight, origin_info_weight):
        """Return the path sequence to animate for ``search_method``.

        The greedy and A* variants are (re)computed here; an unknown method
        yields an empty list.
        """
        if search_method == 'bfs':
            return self.bfs_path
        if search_method == 'dfs':
            return self.dfs_path
        if search_method == 'greedy':
            self.greedy_search()
            return self.greedy_search_path
        if search_method == 'a_star':
            self.a_star_search(help_info_weight=help_info_weight,
                               origin_info_weight=origin_info_weight)
            return self.greedy_search_path
        return []

    def animation_search_tree(self, search_method='dfs', help_info_weight=1, origin_info_weight=1):
        """Animate the search on the search-tree view.

        Args:
            search_method: One of ``'dfs'``, ``'bfs'``, ``'greedy'`` or
                ``'a_star'``; anything else animates nothing.
            help_info_weight: Heuristic weight, used by ``'a_star'`` only.
            origin_info_weight: Edge-weight factor, used by ``'a_star'`` only.
        """
        self.animation_type = search_method
        self.show_correct_path = []
        self.temp_best_path = None
        paths = self._select_paths(search_method, help_info_weight, origin_info_weight)
        for e_path in paths:
            self.show_search_tree(e_path)
            if e_path in self.correct_paths:
                # Track the shortest successful path seen so far.
                if not self.temp_best_path:
                    self.temp_best_path = e_path
                elif self.path_score[e_path] < self.path_score[self.temp_best_path]:
                    self.temp_best_path = e_path
                self.show_correct_path.append(e_path)
                self.show_search_tree(e_path, True)
            if search_method in ['greedy', 'a_star']:
                time.sleep(5)
        if search_method in ['bfs', 'dfs']:
            if self.correct_paths:
                best_path = min(self.correct_paths, key=self.correct_paths.get)
                self.show_search_tree(best_path, True, True)

    def animation_graph(self, search_method='bfs', help_info_weight=1, origin_info_weight=1):
        """Animate the search on the graph view.

        Args:
            search_method: One of ``'dfs'``, ``'bfs'``, ``'greedy'`` or
                ``'a_star'``; anything else animates nothing.
            help_info_weight: Heuristic weight, used by ``'a_star'`` only.
            origin_info_weight: Edge-weight factor, used by ``'a_star'`` only.
        """
        self.animation_type = search_method
        self.show_correct_path = []
        paths = self._select_paths(search_method, help_info_weight, origin_info_weight)
        for e_path in paths:
            self.show_graph(e_path)
            if e_path in self.correct_paths:
                self.show_correct_path.append(e_path)
                self.show_graph(e_path, True)
            time.sleep(5)
        # Only show a final best path when the target was actually reached;
        # min() on an empty dict would raise ValueError.
        if search_method in ['bfs', 'dfs'] and self.correct_paths:
            best_path = min(self.correct_paths, key=self.correct_paths.get)
            self.show_graph(best_path, True, True)

    def show_graph(self, this_path='',
                         show_success_color=False,
                         best_path=None):
        """Draw the graph, highlighting the nodes and edges of ``this_path``.

        Args:
            this_path: Path to highlight; an empty value highlights only the
                start node.
            show_success_color: Highlight with the success colour instead of
                the visited colour.
            best_path: Accepted for symmetry with ``show_search_tree``; not
                used by this method.
        """
        fig, ax = plt.subplots()
        fig.set_figwidth(6)
        fig.set_figheight(8)
        plt.axis('off')

        if not this_path:
            this_path = self.start_node
        path_node_list = list(this_path)
        # Edges travelled by the path, as unordered node pairs.
        visited_edges = {frozenset(pair)
                         for pair in zip(path_node_list, path_node_list[1:])}

        # Node and edge labels.
        nlabels = {node: node for node in self.node_list}
        edge_labels = {(u, v): d['weight']
                       for u, v, d in self.G.edges(data=True)}

        # Node colours: the target colour is overridden when the target lies
        # on the path, and the start node always keeps its own colour.
        path_color = (self.success_color if show_success_color
                      else self.visited_node_color)
        val_map = {self.target_node: self.target_node_color}
        for node in path_node_list:
            val_map[node] = path_color
        val_map[self.start_node] = self.start_node_color
        values = [val_map.get(node, self.basic_node_color) for node in self.G.nodes()]

        # Edge colours: edges on the path get the success or visited colour,
        # all other edges keep the basic colour.
        travelled_color = (self.success_color if show_success_color
                           else self.visited_edge_color)
        edge_colors = [travelled_color if frozenset(edge) in visited_edges
                       else self.basic_edge_color
                       for edge in self.G.edges()]

        # Nodes and their labels. ``width`` is an edge-drawing option and is
        # not accepted by draw_networkx_nodes, so it is not passed here.
        nx.draw_networkx_nodes(self.G, self.nodes_pos, node_size=800, node_color=values)
        nx.draw_networkx_labels(self.G, self.nodes_pos, nlabels, font_size=20)
        # Edges and their labels.
        nx.draw_networkx_edges(self.G, self.nodes_pos, edge_color=edge_colors, width=2.0, alpha=1.0)
        nx.draw_networkx_edge_labels(self.G, self.nodes_pos, edge_labels=edge_labels, font_size=18)

        display.clear_output(wait=True)
        plt.show()

    def _dfs_helper(self, G, node,  father, target_node,level, res, path):
        """Recursive DFS step: visit ``node`` and then its unvisited neighbours.

        Appends the extended path to ``res`` and records its accumulated
        weight in ``self.path_score``. ``father`` is not used.
        """
        path += str(node)
        if len(path) > 1:
            step = self.weighted_edges_dic[frozenset([path[-2], path[-1]])]
            self.path_score[path] = self.path_score[path[:-1]] + step
        res.append(path)
        # Stop extending this branch once the target is reached.
        if node == target_node:
            return
        if level < self.max_depth:
            for neighbor in sorted(G[node]):
                if str(neighbor) not in path:
                    self._dfs_helper(G, neighbor, node, target_node, level+1, res, path)

    def dfs_search(self):
        """Run the depth-first search from the start node.

        Stores the visiting order in ``self.dfs_path`` and adds every path
        ending at the target node to ``self.correct_paths``.
        """
        dfs_path = []
        if self.start_node:
            self._dfs_helper(self.G, self.start_node, None, self.target_node, 0, dfs_path, '')
        self.dfs_path = dfs_path
        for p in dfs_path:
            if p[-1] == self.target_node and p not in self.correct_paths:
                self.correct_paths[p] = self.cal_dis(p)

    def bfs_search(self):
        """Run the breadth-first search from the start node.

        Stores the visiting order in ``self.bfs_path`` and adds every path
        ending at the target node to ``self.correct_paths``.
        """
        # deque gives O(1) pops from the left, unlike list.pop(0).
        to_search = collections.deque([self.start_node])
        bfs_path = []
        while to_search:
            this_search = to_search.popleft()
            # The queue is ordered by length, so everything left is too long.
            if len(this_search) > self.max_depth+1:
                break
            bfs_path.append(this_search)
            # A path that reached the target is not extended.
            if this_search[-1] == self.target_node:
                continue
            # Children of a path at the length limit could never be recorded.
            if len(this_search) > self.max_depth:
                continue
            for ne in sorted(self.G[this_search[-1]]):
                if ne not in this_search:
                    to_search.append(this_search+ne)
        self.bfs_path = bfs_path
        for p in bfs_path:
            if p[-1] == self.target_node and p not in self.correct_paths:
                self.correct_paths[p] = self.cal_dis(p)

    def greedy_search(self, help_info_weight=1, origin_info_weight=0):
        """Follow, at every step, the unvisited neighbour with the lowest score.

        With ``self.help_info`` set, a neighbour's score is
        ``help_info_weight * help_info[n] + origin_info_weight * edge_weight``;
        otherwise it is the plain edge weight.

        Stores the successive prefixes of the chosen path in
        ``self.greedy_search_path`` and the candidate scores of every step in
        ``self.search_scores`` (an empty dict marks a dead end).
        """
        self.help_info_weight = help_info_weight
        self.origin_info_weight = origin_info_weight
        search_path = self.start_node
        # Candidates and their scores at every step, shown by the animation.
        search_scores = {}
        while len(search_path) <= self.max_depth:
            this_node = search_path[-1]
            neighbour_nodes = [e_n for e_n in sorted(self.G[this_node])
                               if e_n not in search_path]
            if not neighbour_nodes:
                search_scores[search_path] = {}
                break
            if self.help_info:
                scores = {
                    e_n: help_info_weight*self.help_info[e_n]
                    + origin_info_weight*self.weighted_edges_dic[frozenset([this_node, e_n])]
                    for e_n in neighbour_nodes
                }
            else:
                scores = {e_n: self.weighted_edges_dic[frozenset([this_node, e_n])]
                          for e_n in neighbour_nodes}
            search_scores[search_path] = scores
            nearest_node = min(scores, key=scores.get)
            search_path += nearest_node
            if nearest_node == self.target_node:
                break
        self.greedy_search_path = [search_path[:end]
                                   for end in range(1, len(search_path)+1)]
        self.search_scores = search_scores

    def a_star_search(self, help_info_weight=1, origin_info_weight=1):
        """Run ``greedy_search`` with the edge weight included in the score."""
        self.greedy_search(help_info_weight, origin_info_weight)

    def generate_greedy_help_text(self, path):
        """Return the explanatory text shown for ``path`` in greedy animations.

        Returns an empty string when ``path`` is empty or ``None``, and a
        dedicated message when ``path`` is a dead end with no candidates.
        """
        if not path:
            return ''
        if path[-1] == self.target_node:
            return '抵达目标节点' + str(self.target_node)
        elif path not in self.search_scores:
            return '抵达最大搜索深度,未找到目标节点'

        candidates = self.search_scores[path]
        if not candidates:
            # Dead end: greedy_search stored an empty candidate dict, and
            # min() over it would raise ValueError.
            return '当前节点没有未访问的相邻节点,搜索结束'

        base_text = '当前可选的子节点及其信息值为 \n'+ \
                    str(candidates) + '\n'
        if self.target_node in candidates:
            return base_text + '当前可选的子节点包含了目标节点,\n所以选择目标节点'
        elif len(candidates) == 1:
            return base_text + '因为只有一个子节点,所以选择此节点'
        else:
            return base_text + '因为'+ \
                    str(min(candidates, key=candidates.get)) + \
                    '的值最小,所以选择此节点'

    def cal_dis(self, path):
        """Return the total edge weight along ``path`` (0 for a single node)."""
        return sum(self.weighted_edges_dic[frozenset([a, b])]
                   for a, b in zip(path, path[1:]))