from itertools import combinations

from keras.preprocessing.image import ImageDataGenerator
# Xception preprocessing (scales pixels to [-1, 1]); the generic
# imagenet_utils.preprocess_input import was redundant here and was
# shadowed by this one anyway.
from keras.applications.xception import preprocess_input
import keras.backend as K
import tensorflow as tf
import numpy as np
import os
import platform
import shutil
import json
import csv


def get_files(path):
    """Recursively collect all file paths under `path`."""
    if not os.path.exists(path):
        return []
    if os.path.isfile(path):
        return [path]
    result = []
    for subdir in os.listdir(path):
        result += get_files(os.path.join(path, subdir))
    return result


def calculate_file_num(path):
    """Recursively count all files under `path`."""
    if not os.path.exists(path):
        return 0
    if os.path.isfile(path):
        return 1
    count = 0
    for subdir in os.listdir(path):
        count += calculate_file_num(os.path.join(path, subdir))
    return count


def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def calculate_class_weight(train_path):
    """Return per-class sample proportions (in os.listdir order)."""
    if not os.path.isdir(train_path):
        raise Exception('Dir "%s" does not exist.' % train_path)
    n_classes = [len(os.listdir(os.path.join(train_path, subdir))) for subdir in os.listdir(train_path)]
    n_all = sum(n_classes)
    return [num / float(n_all) for num in n_classes]
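
# A minimal sketch of the expected output, assuming a hypothetical layout of
# train/cat (30 images) and train/dog (70 images):
#   calculate_class_weight('train')  ->  [0.3, 0.7]
# Note these are per-class sample proportions in os.listdir() order, not
# inverse-frequency weights.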


def get_best_weights(path_weights, mode='acc', postfix='.h5'):
    """Pick the best checkpoint in `path_weights`, assuming filenames shaped
    like '{epoch}-{val_loss}-{val_acc}.h5' (see context() below)."""
    if not os.path.isdir(path_weights):
        return None
    sub_files = os.listdir(path_weights)
    if not sub_files:
        return None
    target = sub_files[0]
    sub_files_with_metric = [f for f in sub_files if f.endswith(postfix) and '-' in f]
    if sub_files_with_metric:
        try:
            # The last two dash-separated fields are val_loss and val_acc.
            weights_value = [f.replace(postfix, '').split('-')[-2:] for f in sub_files_with_metric]
            key_filename = 'filename'
            kw = ['loss', 'acc']
            weights_info = []
            for filename, value in zip(sub_files_with_metric, weights_value):
                item = dict((k, float(v)) for k, v in zip(kw, value))
                item[key_filename] = filename
                weights_info.append(item)
            if mode not in kw:
                mode = 'acc'
            if mode == 'loss':
                weights_info = sorted(weights_info, key=lambda x: x['loss'])
            else:
                weights_info = sorted(weights_info, key=lambda x: x['acc'], reverse=True)
            target = weights_info[0][key_filename]
            print('The best weights file is %s, sorted by %s.' % (target, mode))
        except (ValueError, IndexError):
            print('Failed to parse metrics from filenames; choosing first file %s.' % target)
    else:
        print('No weights with metrics found; choosing first file %s.' % target)
    return os.path.join(path_weights, target)
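
# Usage sketch (hypothetical filenames, following the
# '{epoch:05d}-{val_loss:.4f}-{val_acc:.4f}.h5' pattern produced by context()):
#   params/xception/ contains '00003-0.5120-0.8310.h5' and '00007-0.4417-0.8652.h5'
#   get_best_weights('params/xception', mode='acc')
#   -> 'params/xception/00007-0.4417-0.8652.h5'   (highest val_acc wins)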


def is_multi_predictions(predictions):
    """Heuristic: True when `predictions` stacks multiple models' outputs (3-D)."""
    if isinstance(predictions, np.ndarray):
        return len(predictions.shape) == 3
    element = predictions[0][0]
    return isinstance(element, (list, tuple, np.ndarray))


def all_combines(data):
    """Return all non-empty combinations of `data`, shortest first."""
    result = []
    for i in range(len(data)):
        result.extend(combinations(data, i + 1))
    return result
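
# Example:
#   all_combines(['a', 'b', 'c'])
#   -> [('a',), ('b',), ('c',), ('a', 'b'), ('a', 'c'), ('b', 'c'), ('a', 'b', 'c')]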


def format_time(seconds):
    if seconds < 60:
        return '%.2f s' % seconds

    minutes, seconds = divmod(seconds, 60)
    if minutes < 60:
        return '%dm %.0fs' % (minutes, seconds)

    hours, minutes = divmod(minutes, 60)
    if hours < 24:
        return '%dh %dm %.0fs' % (hours, minutes, seconds)

    days, hours = divmod(hours, 24)
    return '%dd %dh %dm %.0fs' % (days, hours, minutes, seconds)
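
# Examples:
#   format_time(42.5)   -> '42.50 s'
#   format_time(3725)   -> '1h 2m 5s'
#   format_time(90061)  -> '1d 1h 1m 1s'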


def preprocess_image(im, width, height, train=True):
    # Note: building graph ops and running a session per image is slow; this
    # mirrors the original approach and is best kept off the hot path.
    size = min(im.shape[:2])
    im = tf.constant(im)
    if train:
        # Random square crop, then resize; resize_images expects (height, width).
        im = tf.random_crop(im, (size, size, 3))
        im = tf.image.resize_images(im, (height, width))
    else:
        im = tf.image.resize_image_with_crop_or_pad(im, height, width)
    im = K.get_session().run(im)
    return preprocess_input(im)
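
# Sketch with a synthetic input: center-crop/pad a 320x480 frame to 299x299
# and scale it to [-1, 1] via the Xception preprocess_input.
#   im = np.random.randint(0, 256, (320, 480, 3), dtype=np.uint8)
#   x = preprocess_image(im, 299, 299, train=False)   # -> shape (299, 299, 3)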


def image_generator(train=True, preprocess=preprocess_input):
    def wrap(value):
        # Disable augmentation entirely for validation/test generators.
        return value if train else 0.

    return ImageDataGenerator(
        # samplewise_center=True,
        # samplewise_std_normalization=True,
        channel_shift_range=wrap(25.5),
        rotation_range=wrap(15.),
        width_shift_range=wrap(0.2),
        height_shift_range=wrap(0.2),
        shear_range=wrap(0.2),
        zoom_range=wrap(0.2),
        horizontal_flip=train,
        preprocessing_function=preprocess,
    )
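
# Usage sketch (directory names are assumptions): pair the generator with
# Keras' flow_from_directory to stream augmented batches during training.
#   train_gen = image_generator(train=True).flow_from_directory(
#       'data/train', target_size=(299, 299), batch_size=32)
#   val_gen = image_generator(train=False).flow_from_directory(
#       'data/val', target_size=(299, 299), batch_size=32)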


# Helpers moved here from config.
os_name = platform.system().lower()


def is_mac():
    return os_name.startswith('darwin')


def is_windows():
    return os_name.startswith('windows')


def is_linux():
    return os_name.startswith('linux')


def context(name, **kwargs):
    return {
        'weights': 'params/%s/{epoch:05d}-{val_loss:.4f}-{val_acc:.4f}.h5' % name,
        'summary': 'log/%s' % name,
        'predictor_cache_dir': 'cache/%s' % name,
        'load_imagenet_weights': is_linux(),
        'path_json_dump': 'eval_json/%s/result%s.json' % (
            name, ('_' + kwargs['policy']) if 'policy' in kwargs else ''),
    }
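
# Example of the generated paths (model name and policy are placeholders):
#   context('xception', policy='vote')
#   -> {'weights': 'params/xception/{epoch:05d}-{val_loss:.4f}-{val_acc:.4f}.h5',
#       'summary': 'log/xception',
#       'predictor_cache_dir': 'cache/xception',
#       'load_imagenet_weights': <True on Linux>,
#       'path_json_dump': 'eval_json/xception/result_vote.json'}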


def parse_weights(weights):
    """Extract (epoch, val_loss, val_acc) from a checkpoint filename, or None."""
    if not weights or not weights.endswith('.h5') or '-' not in weights:
        return None
    try:
        # basename() handles both '/' and os.path.sep; the original mixed them.
        weights_info = os.path.basename(weights).replace('.h5', '').split('-')
        if len(weights_info) != 3:
            return None
        epoch = int(weights_info[0])
        val_loss = float(weights_info[1])
        val_acc = float(weights_info[2])
        return epoch, val_loss, val_acc
    except ValueError as e:
        raise Exception('Parse weights failure: %s' % e)
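
# Example (hypothetical checkpoint name following the pattern above):
#   parse_weights('params/xception/00007-0.4417-0.8652.h5')
#   -> (7, 0.4417, 0.8652)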


def output(obj, is_print):
    if is_print:
        if isinstance(obj, (list, tuple)):
            for i in obj:
                print(i)
        else:
            print(obj)


# Repackage the train and val datasets into one sub-directory per class,
# for subsequent model training with directory-based loaders.
def split_train_and_val_to_classes(path_save_dir, path_images, path_json, path_csv, sub_dir_with_name, is_print,
                                   mean_handle):
    # Column 1 of the CSV holds the human-readable class names.
    with open(path_csv, encoding='utf-8') as f:
        labels = [line[1] for line in csv.reader(f)]
    output(labels[:5], is_print)

    # Build label -> [image_id, ...] from the annotation JSON.
    with open(path_json) as f:
        mapping = json.load(f)
        image2label = {item['image_id']: int(item['label_id']) for item in mapping}
        label2image = {}
        for image, label in image2label.items():
            label2image.setdefault(label, []).append(image)
    output(label2image[0][:5], is_print)

    for label, images in label2image.items():
        # labels are already str under Python 3, no decoding needed.
        label_format = labels[label] if sub_dir_with_name else ('%02d' % label)
        sub_dir = os.path.join(path_save_dir, label_format)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        if mean_handle:
            # Balance every class to the mean class size.
            target_files_size = len(image2label) // len(label2image)
            if len(images) > target_files_size:
                # Over-represented: downsample without replacement.
                images = np.random.choice(images, target_files_size, replace=False).tolist()
            elif len(images) < target_files_size:
                # Under-represented: oversample by repeated random picks.
                added = []
                while len(images) + len(added) < target_files_size:
                    offset = target_files_size - len(images) - len(added)
                    if offset >= len(images):
                        added.extend(images)
                    else:
                        added.extend(np.random.choice(images, offset, replace=False).tolist())
                images.extend(added)
        for image in images:
            target_file = os.path.join(sub_dir, image)
            # Duplicates (from oversampling) get '_' inserted before the extension.
            while os.path.exists(target_file):
                target_file = target_file.replace('.', '_.')
            shutil.copyfile(os.path.join(path_images, image), target_file)
            output('Write finished %s' % image, is_print)
    output('Completed.', is_print)