from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras.applications import VGG19
from keras.layers import Dense, BatchNormalization
from keras.models import Model

from tensorboard import StepTensorBoard
from lr_monitor import LRMonitor
from generator import DirectoryIterator, func_batch_handle_with_multi_process

import im_utils
import utils
import os


class BaseClassifier(object):
    def __init__(self, name, image_size, learning_rate, batch_size, classes, weights_mode, optimizer, epoch,
                 path_train_images, path_val_images):
        # receive params
        self.name = name
        self.im_size = image_size
        self.lr = learning_rate
        self.batch_size = batch_size
        self.classes = classes
        self.weights_mode = weights_mode
        self.weights = None
        self.optimizer = optimizer
        self.epoch = epoch
        self.path_train_images = path_train_images
        self.path_val_images = path_val_images

        # parse context
        self.context = utils.context(self.name)
        self.path_summary = self.context['summary']
        self.path_weights = self.context['weights']

        # build model
        self.model = self.build_model()
        self._compiled = False

    def data_generator(self, path_image, train=True, random_prob=0.5, **kwargs):
        return DirectoryIterator(
            path_image, None,
            classes=['%02d' % i for i in range(self.classes)],
            target_size=(self.im_size, self.im_size),
            batch_size=self.batch_size,
            class_mode='categorical',
            batch_handler=lambda x: func_batch_handle_with_multi_process(x, train, random_prob),
            **kwargs
        )

    def build_model(self):
        if self.weights_mode not in (None, 'acc', 'loss'):
            raise ValueError("weights_mode must be None, 'acc' or 'loss', got %r." % self.weights_mode)

        model = self.create_model()

        if self.weights_mode:
            self.weights = utils.get_best_weights(os.path.dirname(self.path_weights), self.weights_mode)
            if self.weights:
                model.load_weights(self.weights)
                print('Loaded %s successfully.' % self.weights)
            else:
                print('Model params not found.')
        return model

    def create_model(self):
        # weights = 'imagenet' if self.context['load_imagenet_weights'] else None
        model_vgg19 = VGG19(include_top=False, weights=None,
                            input_shape=(self.im_size, self.im_size, 3), pooling='avg')
        for layer in model_vgg19.layers:
            layer.trainable = False
        x = model_vgg19.output
        x = Dense(4096, activation='relu', name='fc1')(x)
        x = BatchNormalization()(x)
        x = Dense(4096, activation='relu', name='fc2')(x)
        x = BatchNormalization()(x)
        x = Dense(self.classes, activation='softmax')(x)
        model = Model(inputs=model_vgg19.inputs, outputs=x)
        return model

    def compile_model(self, force=False):
        if not self._compiled or force:
            self._compiled = True
            if not self.optimizer:
                self.optimizer = Adam(self.lr)
            self.model.compile(loss='categorical_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])

    def train(self, **kwargs):
        # count image files to derive the number of steps per epoch
        file_num = utils.calculate_file_num(self.path_train_images)
        steps_train = file_num // self.batch_size
        print('%d training steps per epoch.' % steps_train)
        steps_val = utils.calculate_file_num(self.path_val_images) // self.batch_size

        # build data generator
        train_generator = self.data_generator(self.path_train_images)
        val_generator = self.data_generator(self.path_val_images, train=False)

        # compile the model if it has not been compiled yet
        self.compile_model()

        # start training
        utils.ensure_dir(os.path.dirname(self.path_weights))
        weights_info = utils.parse_weigths(self.weights) if self.weights else None
        init_epoch = weights_info[0] if weights_info else 0
        print('Starting training from epoch %d.' % init_epoch)
        init_step = init_epoch * steps_train
        try:
            self.model.fit_generator(
                train_generator,
                steps_per_epoch=steps_train,
                callbacks=[
                    ModelCheckpoint(self.path_weights, verbose=1),
                    StepTensorBoard(self.path_summary, init_steps=init_step, skip_steps=200),
                    LRMonitor(step=10),
                ],
                initial_epoch=init_epoch,
                epochs=self.epoch,
                validation_data=val_generator,
                validation_steps=steps_val,
                verbose=1,
                class_weight=utils.calculate_class_weight(self.path_train_images),
                **kwargs
            )
        except KeyboardInterrupt:
            print('\nStopped by KeyboardInterrupt, trying to save weights.')
            self.model.save_weights(self.path_weights)
            print('Saved weights successfully.')
        finally:
            im_utils.recycle_pool()
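

# Example usage (a minimal sketch): the experiment name, hyper-parameters and
# directory paths below are hypothetical and only illustrate how BaseClassifier
# is meant to be wired up; adjust them to your own project layout.
if __name__ == '__main__':
    classifier = BaseClassifier(
        name='vgg19_baseline',           # hypothetical experiment name, resolved by utils.context()
        image_size=224,
        learning_rate=1e-4,
        batch_size=32,
        classes=10,
        weights_mode='acc',              # resume from the best-accuracy checkpoint if one exists
        optimizer=None,                  # None falls back to Adam(learning_rate) in compile_model()
        epoch=50,
        path_train_images='data/train',  # hypothetical paths; expects class sub-folders '00', '01', ...
        path_val_images='data/val',
    )
    classifier.train()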