728x90

 Image 학습 시 Label마다 Image 개수가 다르다. Label의 종류가 많고 개수 분포가 한쪽으로 치우쳐 있을 경우, model의 예측이 Image가 많은 Label 쪽으로 편향된다(class imbalance).

 

*이 경우

ImageDataGenerator로 Image가 부족한 Label의 Image를 증강(augmentation)해서 Label별 개수의 균형을 맞춰준다.

def balance(train_df, max_samples, min_samples, column, working_dir, image_size):
    """Oversample under-represented classes with augmented images.

    The DataFrame is first passed through ``trim`` (defined elsewhere) with
    ``max_samples``, ``min_samples`` and ``column``. For every class that
    still has fewer than ``max_samples`` rows, augmented copies of its images
    are written to ``<working_dir>/aug/<label>/`` until the class reaches
    ``max_samples``. One row per written file is appended to the result.

    Args:
        train_df: DataFrame with an ``'image_path'`` column and a label
            column. It is not modified.
        max_samples: Target number of rows per class.
        min_samples: Forwarded to ``trim``.
        column: Name of the label column (also forwarded to ``trim``).
        working_dir: Directory under which the ``aug`` folder is created.
            Any existing ``<working_dir>/aug`` is deleted first.
        image_size: ``target_size`` the augmented images are resized to.

    Returns:
        A new DataFrame holding the trimmed rows followed by one row per
        augmented image (``'image_path'`` and ``column`` filled in).
    """
    import os
    import shutil

    train_df = trim(train_df.copy(), max_samples, min_samples, column)

    aug_dir = os.path.join(working_dir, 'aug')
    # Start from an empty directory. Files left over from a previous run would
    # otherwise be collected again and inflate the class counts, and creating
    # the per-label directories would fail because they already exist.
    if os.path.isdir(aug_dir):
        shutil.rmtree(aug_dir)

    gen = ImageDataGenerator(horizontal_flip=True, rotation_range=20,
                             width_shift_range=.2, zoom_range=.2)
    total = 0
    aug_fpaths = []
    aug_labels = []
    for label, group in train_df.groupby(column):
        delta = max_samples - len(group)
        if delta <= 0:
            continue
        # str() so non-string labels (e.g. ints) can still be used as dir names.
        target_dir = os.path.join(aug_dir, str(label))
        os.makedirs(target_dir, exist_ok=True)
        # batch_size=1 with save_to_dir: every next() writes exactly one
        # augmented image file into target_dir.
        aug_gen = gen.flow_from_dataframe(
            group, x_col='image_path', y_col=None, target_size=image_size,
            class_mode=None, batch_size=1, shuffle=False,
            save_to_dir=target_dir, save_prefix='aug-', color_mode='rgb',
            save_format='jpg')
        generated = 0
        while generated < delta:
            generated += len(next(aug_gen))
        total += generated
        # Keep the original label value (not the directory name) so the label
        # dtype matches the rows from train_df.
        for fname in os.listdir(target_dir):
            aug_fpaths.append(os.path.join(target_dir, fname))
            aug_labels.append(label)
    print('Total Augment images created= ', total)

    if total > 0:
        aug_df = pd.DataFrame({'image_path': aug_fpaths, column: aug_labels})
        train_df = pd.concat([train_df, aug_df], axis=0).reset_index(drop=True)

    print(list(train_df[column].value_counts()))
    return train_df

 

# ImageDataGenerator(horizontal_flip=True, rotation_range=20, width_shift_range=.2,zoom_range=.2)
 
 
반응형
다했다