{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.4 贝叶斯分析\n",
    "贝叶斯分析是一种根据概率统计知识对数据进行分析的方法，属于统计学分类的范畴。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.4.1 贝叶斯公式\n",
    "\n",
    "- **频率学派**：从历史数据中计算某个事件的概率，认为只要采样足够多，则事件发生的频率就可以无限逼近真实概率。\n",
    "- **贝叶斯学派**：认为某个事件发生的概率不仅与先前这个事件发生的概率相关（称为**先验概率**），也与后期计算该事件概率时所观测的“新近”信息有关（称为**似然概率**）\n",
    "\n",
    "贝叶斯概率计算公式表达：\n",
    "后验概率 = 先验概率 × 似然概率\n",
    "\n",
    "**条件概率**:\n",
    "\n",
    "$P(A|B)$：表示事件 $B$ 发生的前提下，事件 $A$ 发生的概率\n",
    "\n",
    "$$P(A|B)=\\frac{P(A ∩ B)}{P(B)}$$\n",
    "\n",
    "$P(B|A)$：表示事件 $A$ 发生的前提下，事件 $B$ 发生的概率\n",
    "\n",
    "$$P(B|A)=\\frac{P(A ∩ B)}{P(A)}$$\n",
    "\n",
    "由于：\n",
    "\n",
    "$$P(B|A){P(A)}=P(A|B){P(B)}={P(A ∩ B)}$$\n",
    "\n",
    "可得**贝叶斯公式**:\n",
    "\n",
    "$$P(A|B) = \\frac{P(B|A)P(A)}{P(B)}$$\n",
    "\n",
    "其中：\n",
    "- $P(A)$ 是事件 $A$ 发生的先验概率，与事件 $B$ 是否发生无关；\n",
    "- $P(B|A)$是事件 $A$ 发生前提下，事件 $B$ 发生的概率，也称为**似然概率**；  \n",
    "- $P(B)$ 是事件 $B$ 发生的先验概率，也称为**标准化常量**；\n",
    "- $P(A|B)$是事件 $B$ 发生前提下，事件 $A$ 发生的概率，也是 $A$ 的后验概率。  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.4.2 贝叶斯推断\n",
    "贝叶斯推断是一种基于贝叶斯公式进行分析的统计学方法。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "小例子：根据邮件中的 “红包” 字样判别该邮件是不是垃圾邮件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 广告邮件的数量 \n",
    "ad_number = 4000\n",
    "# 正常邮件的数量\n",
    "normal_number = 6000\n",
    "\n",
    "# 所有广告邮件中，出现 “红包” 关键词的邮件的数量\n",
    "ad_hongbao_number = 1000\n",
    "# 所有正常邮件中，出现 “红包” 关键词的邮件的数量\n",
    "normal_hongbao_number = 6\n",
    "\n",
    "# 用户收到广告邮件的先验概率为\n",
    "P_ad = ad_number / (ad_number + normal_number)\n",
    "print(\"用户收到广告邮件的先验概率为 \" + str(P_ad))\n",
    "\n",
    "# 用户收到正常邮件的先验概率为\n",
    "P_normal = normal_number / (ad_number + normal_number)\n",
    "print(\"用户收到正常邮件的先验概率为 \" + str(P_normal))\n",
    "\n",
    "# 红包出现的概率\n",
    "P_hongbao = (normal_hongbao_number + ad_hongbao_number) / (\n",
    "            ad_number + normal_number)\n",
    "print(\"邮件包含红包的先验概率为 \" + str(P_hongbao))\n",
    "\n",
    "# 广告邮件中出现 “红包” 关键词的条件概率\n",
    "P_hongbao_ad = ad_hongbao_number / ad_number\n",
    "print(\"广告邮件中出现 “红包” 关键词的条件概率为 \" + str(P_hongbao_ad))\n",
    "\n",
    "# 正确邮件中出现 “红包” 关键词的条件概率\n",
    "P_hongbao_normal = normal_hongbao_number / normal_number\n",
    "print(\"正常邮件中出现 “红包” 关键词的条件概率为 \" + str(P_hongbao_normal))\n",
    "\n",
    "# 根据贝叶斯定理可得\n",
    "# 当邮件中出现 “红包” ，其为广告邮件的后验概率\n",
    "P_ad_hongbao = P_ad * P_hongbao_ad / P_hongbao\n",
    "print(\"当邮件中出现 “红包” ，其为广告邮件的后验概率为 \" + str(P_ad_hongbao))\n",
    "\n",
    "# 当邮件中出现 “红包” ，其为正常邮件的后验概率\n",
    "P_normal_hongbao = P_normal * P_hongbao_normal / P_hongbao\n",
    "print(\"当邮件中出现 “红包” ，其为正常邮件的后验概率为 \" + str(P_normal_hongbao))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.4.3 朴素贝叶斯分类器  \n",
    "一种常用的分类算法，其假设**样本各个特征之间相互独立、互不影响**。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "小例子：预测同学会不会在某店铺订餐。\n",
    "\n",
    "**目标**：根据某同学的订单记录，如果向他推荐一家“价位低、口味偏甜、距离远”的店铺，判断他会下单吗？\n",
    "\n",
    "**数据**：该同学的下单记录如下\n",
    "\n",
    "|店铺价位|店铺口味|店铺距离|是否下单|\n",
    "|:--:|:--:|:--:|:--:|\n",
    "|高|偏甜|近|是|\n",
    "|高|清淡|近|否|\n",
    "|高|偏辣|远|否|\n",
    "|高|偏甜|远|否|\n",
    "|低|偏甜|近|是|\n",
    "|低|偏甜|近|是|\n",
    "|低|清淡|远|否|\n",
    "|低|偏辣|远|是|\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "该同学在收到8次推荐后，下单4次和没有下单4次，则其“下单”，“不下单”的概率：  \n",
    "$$P(下单) = \\frac{4}{8}=0.5$$  \n",
    "$$P(不下单) = \\frac{4}{8}=0.5$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "该同学对 “价位低、口味偏甜、距离远” 这次推荐的 “下单” 或 “不下单” 的似然概率为（注意基本假设是店铺价位、口味、距离这些特质中间互相独立，互不影响）：\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "&P(价位=低,口味=偏甜,距离=远|下单)\\\\\n",
    "=&P（价位=低|下单）×P（口味=偏甜|下单）×P（距离=远|下单）\\\\\n",
    "=&\\frac{3}{4}×\\frac{3}{4}×\\frac{1}{4}\\\\\n",
    "≈ & 0.141\n",
    "& \\\\\n",
    "& \\\\\n",
    "& P(价位=低,口味=偏甜,距离=远|不下单)\\\\\n",
    "=&P（价位=低|不下单）×P（口味=偏甜|不下单）×P（距离=远|不下单）\\\\\n",
    "=&\\frac{1}{4}×\\frac{1}{4}×\\frac{3}{4}\\\\\n",
    "≈ &0.047\n",
    "\\end{align}\n",
    "$$\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "根据贝叶斯公式，可以得到该同学在一家“价格低、口味偏甜、距离远”的店铺,\n",
    "\n",
    "下单的后验概率为：\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "&P(下单|价位=低,口味=偏甜,距离=远)\\\\\n",
    "=&P（下单）×P(价位=低,口味=偏甜,距离=远|下单)\\\\\n",
    "=&0.5×0.141\\\\\n",
    "= &0.0705\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "不下单的后验概率为：\n",
    "$$\n",
    "\\begin{align}\n",
    "&P(不下单|价位=低,口味=偏甜,距离=远)\\\\\n",
    "=&P（不下单）×P(价位=低,口味=偏甜,距离=远|不下单)\\\\\n",
    "=&0.5×0.047\\\\\n",
    "=&0.0235\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "\n",
    "由此可见,该同学这次会下单的概率大于不下单的概率。\n",
    "\n",
    "上面的计算过程进行了一些简化，本来应该计算如下两个公式：\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "&P(下单|价位=低,口味=偏甜,距离=远)\\\\\n",
    "=&\\frac{P（下单）×P(价位=低,口味=偏甜,距离=远|下单)}{P(价位=低,口味=偏甜,距离=远)}\\\\\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "&P(不下单|价位=低,口味=偏甜,距离=远)\\\\\n",
    "=&\\frac{P（不下单）×P(价位=低,口味=偏甜,距离=远|不下单)}{P(价位=低,口味=偏甜,距离=远)}\\\\\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "上述两个计算公式分母相同，对计算结果不影响，因此就从计算过程中略去了。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 实践与体验"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 利用朴素贝叶斯分类器解决 MNIST 手写体数字识别问题\n",
    "\n",
    "**MNIST** 是一个手写体数据集，它包含了各种各样的手写体数字图像及其对应的数字标签。其中每幅手写体图像的大小为 **28×28** ，共有 **784** 个像素点，可记为一个 **784** 维的向量，每个 **784** 维向量对应着一个标签。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本次实验我们利用 **tensorflow** 库来来进行原始数据集的解析和读取，利用 **sklearn** 库来进行特征提取和分类。更多内容可参考**tensorflow** 的[数据集部分](https://www.tensorflow.org/datasets/)，sklearn 的 [bayes部分](https://scikit-learn.org/stable/modules/naive_bayes.html)。\n",
    "  \n",
    "1.在 **Python** 中导入相应库。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import numpy as np\n",
    "from tensorflow.keras.datasets import mnist\n",
    "from sklearn.naive_bayes import BernoulliNB\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2.读取 **MNIST** 训练集和测试集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"读取数据中 ...\")\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
    "train_images = train_images.reshape(train_images.shape[0], 784)\n",
    "test_images = test_images.reshape(test_images.shape[0], 784)\n",
    "print('读取完毕!')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们使用下面的方法来查看其中几张图片。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_images(imgs):\n",
    "    \"\"\"绘制几个样本图片\n",
    "    :param show: 是否显示绘图\n",
    "    :return:\n",
    "    \"\"\"\n",
    "    sample_num = min(9, len(imgs))\n",
    "    img_figure = plt.figure(1)\n",
    "    img_figure.set_figwidth(5)\n",
    "    img_figure.set_figheight(5)\n",
    "    for index in range(0, sample_num):\n",
    "        ax = plt.subplot(3, 3, index + 1)\n",
    "        ax.imshow(imgs[index].reshape(28, 28), cmap='gray')\n",
    "        ax.grid(False)\n",
    "    plt.margins(0, 0)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_images(train_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.根据 **MNIST** 训练集训练朴素贝叶斯分类器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"初始化并训练贝叶斯模型...\")\n",
    "classifier_BNB = BernoulliNB()\n",
    "classifier_BNB.fit(train_images,train_labels)\n",
    "print('训练完成!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4.根据训练出的分类器对 **MNIST** 测试集中的图片进行识别，得到预测值。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"测试训练好的贝叶斯模型...\")\n",
    "test_predict_BNB = classifier_BNB.predict(test_images)\n",
    "print(\"测试完成!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "5.将测试图片的预测值与实际值相比较，计算并输出分类器的正确率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy = sum(test_predict_BNB==test_labels)/len(test_labels)\n",
    "print('贝叶斯分类模型在测试集上的准确率为 :',accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "6.对实验结果进行分析比较，列出 **0-9** 不同数字识别的准确率，比较其差异。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 记录每个类别的样本的个数，例如 {0：100} 即 数字为 0 的图片有 100 张 \n",
    "class_num = {}\n",
    "# 每个类别预测为 0-9 类别的个数，\n",
    "predict_num = []\n",
    "# 每个类别预测的准确率\n",
    "class_accuracy = {}\n",
    "\n",
    "for i in range(10):\n",
    "    # 找到类别是 i 的下标\n",
    "    class_is_i_index = np.where(test_labels == i)[0]\n",
    "    # 统计类别是 i 的个数\n",
    "    class_num[i] = len(class_is_i_index)\n",
    "\n",
    "    # 统计类别 i 预测为 0-9 各个类别的个数\n",
    "    predict_num.append(\n",
    "        [sum(test_predict_BNB[class_is_i_index] == e) for e in range(10)])\n",
    "\n",
    "    # 统计类别 i 预测的准确率\n",
    "    class_accuracy[i] = round(predict_num[i][i] / class_num[i], 3) * 100\n",
    "\n",
    "    print(\"数字 %s 的样本个数：%4s，预测正确的个数：%4s，准确率：%.4s%%\" % (\n",
    "    i, class_num[i], predict_num[i][i], class_accuracy[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "sns.set(rc={'figure.figsize': (12, 8)})\n",
    "np.random.seed(0)\n",
    "uniform_data = predict_num\n",
    "ax = sns.heatmap(uniform_data, cmap='YlGnBu', vmin=0, vmax=150)\n",
    "ax.set_xlabel('真实值')\n",
    "ax.set_ylabel('预测值')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过热力图，我们看到 3 经常被错认为 5 和 8， 4 和 9 经常互相错认。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们看看真实标签为 9，但是预测为 4 的错认的照片\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_imgs(images, true_labels, predict_labels, true_label,\n",
    "             predict_label):\n",
    "    \"\"\"\n",
    "    从全部图片中按真实标签和预测标签筛选出图片\n",
    "    :param images: 一组图片\n",
    "    :param true_labels: 每张图片的标签\n",
    "    :param predict_labels: 模型预测的每张图片的标签\n",
    "    :param true_label: 希望取得的图片的真实标签\n",
    "    :param predict_label: 希望取得的图片的预测标签\n",
    "    :return: \n",
    "    \"\"\"\n",
    "    # 所有类别为 true_label 的样本的 index 值\n",
    "    true_label_index = set(np.where(true_labels == true_label)[0])\n",
    "    # 所有预测类别为 predict_label 的样本的 index 值\n",
    "    predict_label_index = set(np.where(predict_labels == predict_label)[0])\n",
    "    # 取交集，即为真实类别为 true_label， 预测结果为 predict_label 的样本的 index 值\n",
    "    res = list(true_label_index & predict_label_index)\n",
    "    return images[res]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs = get_imgs(test_images, test_labels, test_predict_BNB, 9, 4)\n",
    "plot_images(imgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**问题 1**：你在上面的试验中观察到了什么？在下方列出模型对 0-9 不同数字识别的准确率，并比较其差异。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**答案 1**：（在此处填写你的答案。）"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
