
Scene Graph Perturbations with VG dataset

by lime9 2024. 4. 29.

Environment

  • Ubuntu 18.04
  • VS Code
  • Python 3.8.10
  • PyTorch 2.3.0 + cu121
  • VG dataset on an external SSD

References

Code: https://github.com/bknyaz/sgg?tab=readme-ov-file

 


Transformers: https://github.com/huggingface/transformers

 


GloVe: https://nlp.stanford.edu/projects/glove/

 

Scene Graph Perturbations

1. Import required packages

import matplotlib.pyplot as plt
import copy
import numpy as np
import torch
import cv2
import os
import subprocess
from dataloaders.visual_genome import VGDataLoader
from augment.sg_perturb import SceneGraphPerturb
from augment.bert import BERT
from lib.visualize import * 
from lib.pytorch_misc import set_seed
from lib.word_vectors import obj_edge_vectors
import seaborn as sns
sns.set(color_codes=True)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

print('torch version', torch.__version__)
print('gitcommit', subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip())

 

VGDataLoader is the code that loads the VG dataset after preprocessing it appropriately for SGG.

GT boxes and triplets are built from the per-image graph information provided in the h5 file.
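
As a rough illustration, here is a minimal sketch of that h5 parsing, assuming the standard Stanford VG-SGG.h5 layout (keys such as 'labels', 'boxes_1024', 'relationships'); the real VGDataLoader applies additional filtering and preprocessing on top of this.

# Minimal sketch of reading one image's GT graph from VG-SGG.h5.
# Assumes the standard Stanford split layout; the actual loader does more.
import h5py
import numpy as np

with h5py.File('VG-SGG.h5', 'r') as f:
    i = 0  # index of the image to inspect
    # object annotations for image i
    first_box, last_box = f['img_to_first_box'][i], f['img_to_last_box'][i]
    labels = f['labels'][first_box:last_box + 1].flatten()   # object class indices
    boxes = f['boxes_1024'][first_box:last_box + 1]          # (xc, yc, w, h), longer side scaled to 1024
    # relationship annotations for image i
    first_rel, last_rel = f['img_to_first_rel'][i], f['img_to_last_rel'][i]
    rels = f['relationships'][first_rel:last_rel + 1] - first_box  # (subj, obj) box indices, local to this image
    preds = f['predicates'][first_rel:last_rel + 1].flatten()
    triplets = np.column_stack([rels, preds])                # (subj, obj, predicate) rows
    print(labels, boxes.shape, triplets)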

 

BERT is defined and used on top of the Transformers library from HuggingFace, using BertForMaskedLM and BertTokenizer.

The [CLS] and [SEP] special tokens are handled through the tokenizer.
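
As a rough sketch of what masked-LM scoring looks like with these two classes (the BERT wrapper in augment/bert.py adds its own thresholds and triplet formatting on top; the phrase below is just a made-up example):

# Minimal masked-LM scoring sketch with HuggingFace Transformers,
# assuming 'bert-base-uncased'. The phrase is a hypothetical example.
import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased').eval()

inputs = tokenizer('a man is riding a [MASK]', return_tensors='pt')  # [CLS]/[SEP] added automatically
mask_pos = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero().item()

with torch.no_grad():
    logits = model(**inputs).logits  # (1, seq_len, vocab_size)
probs = logits[0, mask_pos].softmax(-1)
# probability BERT assigns to a candidate word at the masked position
print('p(horse) =', probs[tokenizer.convert_tokens_to_ids('horse')].item())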

 

obj_edge_vectors fetches the GloVe-6B word vectors and embeds the class names with them.

In my case I first ran the code with the 50-dimensional GloVe-6B vectors (running it with the 200-d vectors kept freezing Ubuntu).
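
For intuition, a minimal sketch of what a GloVe lookup like obj_edge_vectors does, assuming the plain-text glove.6B.50d.txt format (one word followed by 50 floats per line); the actual function also handles missing words and, with avg_words=True, presumably averages multi-word class names:

# Minimal GloVe lookup sketch; glove.6B.50d.txt has one word + 50 floats per line.
import torch

def load_glove(path):
    vectors = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            vectors[parts[0]] = torch.tensor([float(x) for x in parts[1:]])
    return vectors

glove = load_glove('glove.6B.50d.txt')
# a multi-word class name like 'traffic light' becomes the mean of its word vectors
vec = torch.stack([glove[w] for w in 'traffic light'.split(' ')]).mean(0)
print(vec.shape)  # torch.Size([50])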

 

2. Load VG data

# Load VG data
data_dir = '/media/jeeinkim/DATASET'  # root folder where VG and GQA data are located
train_loader, eval_loaders = VGDataLoader.splits(data_dir=data_dir,
                                                 filter_non_overlap=False,
                                                 backbone='vgg16_old')

 

Loading the STANFORD split of Visual Genome

TRAIN DATASET
subj_pred_pairs, pred_obj_pairs 3397 3542
57723 images, 405860 triplets (29283 unique triplets)
Stats: 670591 objects (min=2.0, max=62.0, mean=11.6, std=5.8), 297318 FG edges (min=1.0, max=44.0, mean=5.2, std=3.8), 9029910 BG edges (156.44 avg), graph density min=0.0, max=100.0, mean=6.5, std=8.2

VAL DATASET (ZERO-SHOTS)
722 images, 1130 triplets (851 unique triplets)
Stats: 10129 objects (min=2.0, max=44.0, mean=14.0, std=7.6), 1026 FG edges (min=1.0, max=7.0, mean=1.4, std=0.9), 173934 BG edges (240.91 avg), graph density min=0.1, max=50.0, mean=2.4, std=5.3

VAL DATASET (ALL-SHOTS)
5000 images, 33203 triplets (5043 unique triplets)
Stats: 62754 objects (min=2.0, max=52.0, mean=12.6, std=7.1), 25727 FG edges (min=1.0, max=31.0, mean=5.1, std=4.4), 976590 BG edges (195.32 avg), graph density min=0.1, max=100.0, mean=6.1, std=8.7

TEST DATASET (ZERO SHOTS)
4519 images, 7601 triplets (5278 unique triplets)
Stats: 65281 objects (min=2.0, max=55.0, mean=14.4, std=7.1), 6762 FG edges (min=1.0, max=12.0, mean=1.5, std=1.0), 1107452 BG edges (245.07 avg), graph density min=0.0, max=50.0, mean=1.9, std=4.3

TEST DATASET (10-SHOTS)
9602 images, 19077 triplets (7952 unique triplets)
Stats: 135722 objects (min=2.0, max=56.0, mean=14.1, std=7.0), 16565 FG edges (min=1.0, max=27.0, mean=1.7, std=1.3), 2246514 BG edges (233.96 avg), graph density min=0.0, max=50.0, mean=2.1, std=4.2

TEST DATASET (100-SHOTS)
16528 images, 45385 triplets (3647 unique triplets)
Stats: 224204 objects (min=2.0, max=58.0, mean=13.6, std=6.7), 37923 FG edges (min=1.0, max=32.0, mean=2.3, std=1.8), 3569324 BG edges (215.96 avg), graph density min=0.0, max=100.0, mean=2.7, std=4.9

TEST DATASET (ALL-SHOTS)
26446 images, 183642 triplets (17659 unique triplets)
Stats: 325570 objects (min=2.0, max=58.0, mean=12.3, std=6.5), 145905 FG edges (min=1.0, max=38.0, mean=5.5, std=4.3), 4806730 BG edges (181.76 avg), graph density min=0.1, max=100.0, mean=6.1, std=7.5

 

The VG data is stored on an external SSD and mounted on the local machine.

(Just downloading the dataset takes quite a while, so managing it this way seems like a good idea. Or download everything onto a server in the first place!)

 

VGDataLoader is used to load the data and split it into the individual subsets, which can be sanity-checked as below.
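
For a quick sanity check, the returned splits can be inspected like this (the exact key names are an assumption inferred from the log above and from the keys the perturbation loop uses, e.g. 'test_zs' and 'test_alls'):

# Quick sanity check on the returned splits; key names are inferred
# from the log output and the later code ('test_zs', 'test_alls', ...).
print(list(eval_loaders.keys()))
print(len(train_loader.dataset), 'training images')
print(len(eval_loaders['test_alls'].dataset), 'test (all-shots) images')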

 

3. Embedding with word vectors

embed_objs = obj_edge_vectors(train_loader.dataset.ind_to_classes,
                              wv_dir=train_loader.dataset.root,
                              wv_dim=50,
                              avg_words=True)[0]
embed_objs = embed_objs / torch.norm(embed_objs, 2, dim=1, keepdim=True)

 

loading word vectors from /media/jeeinkim/DATASET/VG/glove.6B.50d.txt: 100%|██████████| 400000/400000 [00:02<00:00, 145003.56it/s]

 

After downloading the word vectors, the words of my dataset (the train loader's class names) are embedded with them.

Don't forget the normalization! (The sketch below shows why it matters.)
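
Why the normalization matters: once every row has unit L2 norm, a plain matrix product gives the cosine similarities between class embeddings, which is presumably what the neighbor-based ('neigh') perturbation relies on to find semantically close classes. A minimal sketch with random vectors in place of the real embeddings:

# After L2 normalization, embed @ embed.t() is a cosine-similarity matrix,
# from which top-k semantic neighbors of a class can be read off.
import torch

embed = torch.randn(151, 50)  # stand-in for the real embeddings (150 VG classes + background)
embed = embed / torch.norm(embed, 2, dim=1, keepdim=True)

sim = embed @ embed.t()                  # (151, 151) cosine similarities
neighbors = sim[5].topk(11).indices[1:]  # 10 nearest neighbors of class 5 (skip the class itself)
print(neighbors)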

 

4. Load the BERT model

bert = BERT(obj_classes=train_loader.dataset.ind_to_classes, 
            rel_classes=train_loader.dataset.ind_to_predicates, 
            triplet2str=train_loader.dataset.triplet2str,
            device='cpu')

 

initializing Bert bert-base-uncased model with threshold 0.000000
tokenizer_config.json: 100% 48.0/48.0 [00:00<00:00, 1.43kB/s]
vocab.txt: 100% 232k/232k [00:00<00:00, 621kB/s]
tokenizer.json: 100% 466k/466k [00:00<00:00, 1.24MB/s]
config.json: 100% 570/570 [00:00<00:00, 35.3kB/s]
model.safetensors: 100% 440M/440M [00:17<00:00, 15.3MB/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

 

The BERT model to be used is also loaded and defined here.

 

5. Display the images and check their scene graphs and perturbations

img_list = [2350517, 2343590, 1159620]  # Used in the paper
# img_list = [2350517, 2343590, 713077, 713094, 713129, 713166, 
#            713410, 713467, 713500, 713504, 713581, 713728, 713785, 713835, 1159494, 1159620]  # More images


save = False  # save images and their graphs to the folder results
results_dir = './results'
if save and not os.path.exists(results_dir):
    os.mkdir(results_dir)

methods = ['rand', 'neigh', 'graphn']
alphas = [0, 1, 2, 3, 4, 5, 10, 20, 50]
hit_rates, bert_scores = {}, {}
for split in eval_loaders:
    if split.startswith('test'):
        hit_rates[split.split('_')[1]] = {}
c = 0
n_images_max = 1000  # try more images to reduce noise

for dataset_ind, dataset in enumerate([train_loader.dataset, eval_loaders['test_alls'].dataset]):
    for i, (im_name, gt_classes, gt_rels, boxes) in enumerate(list(zip(dataset.filenames, 
                                                                   dataset.gt_classes, 
                                                                   dataset.relationships,
                                                                   dataset.gt_boxes))):
        name = im_name.split('.')[0]
        vis = int(name) in img_list
        
        if vis:
            im_path = os.path.join(dataset.images_dir, im_name)
            im = cv2.imread(im_path)
            obj_class_names = [dataset.ind_to_classes[cls] for cls in gt_classes]

            im = draw_boxes(im[:,:,::-1], obj_class_names, boxes, fontscale=1, rels=None)
            if save:
                cv2.imwrite('{}/{}.png'.format(results_dir, name), im[:,:,::-1])
            plt.figure(figsize=(7,7))
            plt.imshow(im)
            plt.title(im_path)
            plt.grid(False)
            plt.axis('off')
            plt.show()
        
            show_nx(gt_classes, boxes, gt_rels, 
                    train_set=train_loader.dataset, 
                    test_set=eval_loaders['test_zs'].dataset,
                    name='{}/{}_sg'.format(results_dir, name) if save else None,
                    fontsize=26)
        
        if not vis and c >= n_images_max:
            continue
            
        for method in methods:
            for split in hit_rates:
                if method not in hit_rates[split]:
                    hit_rates[split][method] = {}
                if method not in bert_scores:
                    bert_scores[method] = {}
            
            topk = 10 if method == 'neigh' else 5
            L = 0.5 if vis else 0.2
            
            for alpha in (alphas if method == 'graphn' else alphas[:1]):                
                
                for split in hit_rates:
                    if alpha not in hit_rates[split][method]:
                        hit_rates[split][method][alpha] = []
                if alpha not in bert_scores[method]:
                    bert_scores[method][alpha] = []
                                
                sgp = SceneGraphPerturb(method=method,
                                        embed_objs=embed_objs,
                                        subj_pred_obj_pairs=(train_loader.dataset.subj_pred_pairs,
                                                             train_loader.dataset.pred_obj_pairs),
                                        L=L, 
                                        topk=topk, 
                                        alpha=alpha)
                
                for it in range(5 if vis else 1):  # perturb the same graph several times to check diversity
                    
                    if vis:
                        set_seed(it)  # fix to generate the same perturbations                    
                        print('Image={}, Perturbation={}({}), Seed={}'.format(name, method.upper(), alpha, it))
                    
                    gt_cls_pert = sgp.perturb(torch.cat((torch.zeros(len(gt_classes), 1).long(), torch.from_numpy(gt_classes).view(-1, 1)), 1),
                                              torch.cat((torch.zeros(len(gt_rels), 1).long(), torch.from_numpy(gt_rels).view(-1, 3)), 1))[:, 1].data.numpy()
                    perturbed_nodes = np.where(gt_classes != gt_cls_pert)[0]
                    
                    if dataset_ind == 0 and c < n_images_max:
                        # Compute hit rates only for training inputs
                        c_hit = {split: 0 for split in hit_rates}
                        bert_score = []
                        for node in perturbed_nodes:
                            for node1, node2, R in gt_rels:
                                if node in [node1, node2]:
                                    tri = '{}_{}_{}'.format(gt_cls_pert[node1], R, gt_cls_pert[node2])
                                    for split in hit_rates:
                                        if tri in eval_loaders['test_%s' % split].dataset.triplet_counts:
                                            c_hit[split] += 1
                                    bert_score.append(bert.bert_score_triplet(tri, gt_cls_pert, gt_rels, 
                                                                              node == node1, verbose=False))

                        if len(bert_score) > 0:
                            for split in hit_rates:
                                hit_rates[split][method][alpha].append(c_hit[split] / len(bert_score))
                            bert_scores[method][alpha].append(np.mean(bert_score))

                    if vis:    
                        show_nx(gt_cls_pert, boxes, gt_rels, 
                                train_set=train_loader.dataset, 
                                test_set=eval_loaders['test_zs'].dataset,
                                perturbed_nodes=perturbed_nodes, 
                                obj_names_orig=obj_class_names,
                                name='{}/{}_{}_L{}_topk{}_a{}_sg_{}'.format(
                                    results_dir, name, method, L, topk, alpha, it).replace('.', '_') 
                                if save else None, fontsize=26)

        c += 1
        if c % 100 == 0 and c > 0:
            print('%d samples are processed' % c)
print('done!')

 

Render the scene graphs and their perturbations for the three images in img_list.
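
For reference, a minimal standalone perturbation call on a toy graph, following the tensor shapes used in the loop above (an image-index column is prepended to both the object classes and the (subj, obj, predicate) triplets); the class and predicate indices here are arbitrary:

# Toy standalone SceneGraphPerturb call; class/predicate indices are arbitrary.
import torch

gt_classes = torch.tensor([57, 111, 149])         # 3 object nodes
gt_rels = torch.tensor([[0, 1, 20], [1, 2, 31]])  # (subj, obj, predicate) rows

sgp = SceneGraphPerturb(method='neigh',
                        embed_objs=embed_objs,
                        subj_pred_obj_pairs=(train_loader.dataset.subj_pred_pairs,
                                             train_loader.dataset.pred_obj_pairs),
                        L=0.2, topk=10, alpha=0)

objs = torch.cat((torch.zeros(len(gt_classes), 1).long(), gt_classes.view(-1, 1)), 1)
rels = torch.cat((torch.zeros(len(gt_rels), 1).long(), gt_rels), 1)
perturbed = sgp.perturb(objs, rels)[:, 1]  # column 1 holds the (possibly replaced) classes
print(gt_classes.numpy(), '->', perturbed.data.numpy())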