master
/ src / .ipynb_checkpoints / main-checkpoint.py

main-checkpoint.py @c038f7f raw · history · blame

import os
import sys

# Make sibling modules (utils, predictor) importable regardless of the current
# working directory. This must run BEFORE those imports; inserting the path
# afterwards has no effect on them. abspath() avoids inserting '' (meaning
# "cwd") when the script is launched as `python main.py`.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import utils  # noqa: E402  (must follow the sys.path setup above)

from predictor import *  # noqa: E402,F401,F403  (provides BaseClassifier, KerasPredictor)


class sceneclassificationvgg19(object):
    """Scene-classification task wrapper around ``BaseClassifier``.

    Each public method takes a single ``input`` dict of parameters
    (``input`` shadows the builtin of the same name; the name is kept for
    caller compatibility).
    """

    def __init__(self, input=None):
        '''
        Create the wrapper. No state is initialised.

        :param input: unused; accepted for interface compatibility.
        '''

    @staticmethod
    def _count_class_dirs(path):
        """Return the number of sub-directories directly under *path*.

        Each sub-directory of the classes root is one class; plain files
        (e.g. ``.DS_Store`` or a README) are ignored so they do not inflate
        the class count.
        """
        return sum(1 for name in os.listdir(path)
                   if os.path.isdir(os.path.join(path, name)))

    def train(self, input=None):
        '''
        Sort the raw datasets into per-class folders, then train a
        ``BaseClassifier`` on them.

        :param input: dict of parameters.
            Required (used directly for training):
              - ``path_train_images``: per-class training directory (output of
                the preprocessing step).
              - ``path_val_images``: per-class validation directory.
            Raw-dataset sources forwarded to
            ``utils.split_train_and_val_to_classes`` (marked required in the
            original comments, but the sample usage omits them; their
            handling when missing is defined in ``utils`` -- TODO confirm):
              - ``path_train_orgin_images``: directory of the raw train images.
              - ``path_val_orgin_images``: directory of the raw val images.
              - ``path_orgin_json``: JSON mapping image names to categories.
              - ``path_orgin_csv``: CSV table of categories.
            Optional: ``model_name`` ("small"), ``learning_rate`` (1e-3),
            ``batch_size`` (32), ``image_size`` (224), ``optimizer`` (None;
            the original comment says this means Adam), ``weights_mode``
            ("loss"), ``epoch`` (100).
        :return: None
        :raises ValueError: if ``path_train_images`` or ``path_val_images``
            is missing. Without this check ``os.listdir(None)`` would
            silently list the current directory.
        '''
        if input is None:
            input = {}

        missing = [key for key in ("path_train_images", "path_val_images")
                   if not input.get(key)]
        if missing:
            raise ValueError("missing required input parameter(s): " + ", ".join(missing))

        # Raw-dataset sources
        path_train_orgin_images = input.get("path_train_orgin_images")
        path_val_orgin_images = input.get("path_val_orgin_images")
        path_orgin_json = input.get("path_orgin_json")
        path_orgin_csv = input.get("path_orgin_csv")
        # Per-class output directories (after sorting images into classes)
        path_train_images = input.get("path_train_images")
        path_val_images = input.get("path_val_images")

        # Optional hyper-parameters
        model_name = input.get("model_name", "small")
        learning_rate = input.get("learning_rate", 1e-3)
        batch_size = input.get("batch_size", 32)
        image_size = input.get("image_size", 224)
        optimizer = input.get("optimizer", None)  # None -> Adam, per original comment
        weights_mode = input.get("weights_mode", "loss")
        epoch = input.get("epoch", 100)

        # Preprocessing: sort raw images into per-class folders. The three
        # positional flags are passed through unchanged; their meaning is
        # defined in utils -- TODO confirm.
        utils.split_train_and_val_to_classes(path_train_images, path_train_orgin_images, path_orgin_json,
                                             path_orgin_csv, False, True, False)
        utils.split_train_and_val_to_classes(path_val_images, path_val_orgin_images, path_orgin_json,
                                             path_orgin_csv, False, True, False)

        # Training: one class per sub-directory of the training classes root.
        classes = self._count_class_dirs(path_train_images)
        classifier = BaseClassifier(name=model_name, image_size=image_size, learning_rate=learning_rate,
                                    batch_size=batch_size, classes=classes, weights_mode=weights_mode,
                                    optimizer=optimizer, epoch=epoch, path_train_images=path_train_images,
                                    path_val_images=path_val_images)
        classifier.train()

    def predict(self, input=None):
        '''
        Run a single-model prediction on one image, print it and return it.

        :param input: dict of parameters. Optional ``image_path`` (defaults
            to a bundled sample image, resolved relative to the cwd).
        :return: the value produced by the predictor with
            ``return_with_prob=True``. The same value is also printed.
        '''
        if input is None:
            input = {}
        image_path = input.get("image_path", "./image/00f076d9b6ab784f69c0e43e77853d7c24d62342.jpg")

        # Fixed configuration for inference; presumably it has to match the
        # saved model (TODO confirm). Training paths are not needed here.
        classifier = BaseClassifier(name="small", image_size=224, learning_rate=1e-3,
                                    batch_size=32,
                                    classes=80,  # the original trained model has 80 scene classes
                                    weights_mode="loss",
                                    optimizer=None,  # None -> Adam, per original comment
                                    epoch=100, path_train_images=None, path_val_images=None)

        # Single predictor. The 'val' argument's meaning is defined by
        # KerasPredictor (possibly which saved weights to use) -- TODO confirm.
        # An IntegratedPredictor over several KerasPredictors is an alternative.
        predictor = KerasPredictor(classifier, 'val')

        prediction = predictor(image_path, return_with_prob=True)
        print(prediction)
        return prediction

    def load_model(self, file=os.path.dirname(__file__) + "/checkpoint" + "/sceneclassification10.pkl"):
        '''
        Load a saved model. Not implemented: this is currently a no-op.

        :param file: path of the model file. The default,
            ``<this file's directory>/checkpoint/sceneclassification10.pkl``,
            is computed once, when the class is defined.
        :return: None
        '''
        return None


# NOTE: this __main__ block must be active (uncommented) before creating a
# crowdsourcing task.
if __name__ == '__main__':
    # Predict the top classes, with their probabilities, for the sample image.
    sceneclassificationvgg19().predict(
        input={"image_path": "./image/00f076d9b6ab784f69c0e43e77853d7c24d62342.jpg"})

    # Example training run (paths are relative to the current directory):
    # sceneclassificationvgg19().train(
    #     input={
    #         "path_train_images": "./data/ai_challenger_scene_train_20170904/classes",
    #         "path_val_images": "./data/ai_challenger_scene_validation_20170908/classes"})