본문 바로가기
info : 유용한 정보, 공부 등

딥 러닝 GNN, Message Passing Neural Network 메세지 패싱 뉴럴 네트워크

by 퇴근길에 삼남매가 알려드림 2022. 11. 2.
728x90
반응형

여전히 지속되고 있는 Transformer 과 같은 language embedding 에 이어 요즘 많이 보이는 논문은 GNN (Graph Neural Network)에서 메세지 패싱 뉴럴 네트워크Message Passing Neural Netowrk (MPNN) 를 이용한 논문들이다.


 

GNN 은 말그대로 그래프 즉, 점 (node 또는 vertex) 과 선 (edge)으로 연결된 그래프를 이용한다. 

이를 G(V,E) 즉, 점 (v)와 선(e)으로 구성된 G 그래프.. 라고 표현된다.

 

 

직관적으로 생각해보면

sequence data 즉, 글자열이나 아미노산 서열, 화학 구조식 서열에서는 추출하지 못하는 정보들이 GNN 에서 이용된다.

GNN은 일차원이상의 구조 정보를 신경망 (neural network) 학습에 이용하는 것이다. 

 

소셜 네트워크, 그림, 분자구조, 코드 짠 글, 문장,  도로 교통망등 그래프 로 다양한 것들을 표현할 수 있다.


그래프에서,

점(노드)과 점 사이의 거리를 정의하는 방법 중 하나는

점과 점을 잇는 가장 최소의 선(엣지) 갯수이다.

 

그래프에서, 

선으로 연결된 점들은 이제 서로 독립적 (independent)하지 않다.

더 복잡해 지려면, 방향성 (direction) 도 가지게 된다. (도로 차선처럼)

(예시 점 1은 점2로 정보를 전달하지만 점2는 점1로 전달 못함. 점 1-> 점2)

 

그래프에서,

각 점이 선으로 연결되었는지 여부는 점의 갯수가 n이라고 했을 때 nxn 행렬로 나타낼 수 있다.

adjacency matrix 라고 할 수 있다.

점 i ->점 j 의 선 edge 여부를 행렬의 (i,j) 값으로 표현하는 것이다.

- 거리가 k 이내를 어떻게 표현하는지 k-nearest neighbor matrix) 나중에 업뎃하기.

 

adjacency matrix에서 다음과 같은 점을 생각해볼 수 있다.

만약, 방향성이 있는 그래프 directional graph)이면 adjacency matrix 가 비대칭일 것이다.

방향이 없으면 adjacency matrix 가 대칭일 것이다.

 

방향이 없는 경우를  permutation invariance 이라고 표현하기도 한다.

즉, 노드간에 순서가 따로 없는 것이다.

이럴 경우, 아래 다시 이야기 할테지만,

각 노드들의 주변 정보 값을 전부 합하는 방식 (summation) 으로 노드에게 주변 정보 (message )를 전달하기도 한다. 


 

그래프를 신경망으로 학습 시킬 때, 무엇을 학습 시키는 것일까?

가장 간단한 예로 들면, 학습한 결과를 이용해

두 점 (노드)이 연관된 점인지, 즉, 선 (엣지)으로 연결되는지 아닌지를 알고자 하는 게 학습 목표 중 하나이다. 

 

MPNN, Message Passing Neural Network 은 

(메세지 패싱 뉴럴 넷, 메세지 패싱 신경망, 흠 정보 전달 신경망, 쯤이 되려나..)

그래프 정보인 점이나 선을 표현하는 방법인 임베딩을 신경망을 통해서 업데이트 하는 것이다.

정말 간단하게 표현해보자면,

학습된 임베딩으로 우린 두 점이 이웃인지 아닌지를 알 수 있게 된다. 

두 점이 이웃이면 임베딩 스페이스에서 가까이에 있다..

 

 

MPNN 에서 네 가지 구성 요소가 나온다.

 

 - feature vector.  숫자로 특정 성질들을 표현하는 벡터. (흠 구지 말하면, 이건 MPNN 아니어도... )

- neighborhood / neighbor : 이웃들.

      그래프는 점들이 선으로 연결되어 있다.

      주변 '이웃'들을 통해서 자신의 정보를 업데이트 하기 때문에 어떻게 이웃을 정의하느냐가 중요하다. 

- aggregate : 어떻게 주변 이웃들의 정보를 합치는지 

- update :어떻게 합쳐진 정보를 업데이트 하는지  (또는 readout function)

 

 

node embedding 도 있고, edge embedding 도 있고,

더 나아가서 global embedding 도 있다.

 


 

 

 

일단 점 (node)를 기준으로 생각해보자.

 

그래프의 점(node) 마다 이 점을 정의/표현하는 정보값들이 있다.

이를 feature vector 라고 하자. 

feature vector 의 예시로는 이 점의 위치좌표 같은 것을 생각해 볼 수 있다.

   논문들에서 어떤 feature 를 쓰는지도 나중에 정리 해보면 재밌겠다.

그리고 여기서 안 쓰는 feature 를 찾아보는거지..

어쨌든, 이 점들의 feature vector 를 숨겨진 hidden state 정보의 일종이라고 본다. (나중에 더 정확하게 수정하자..)

 

 

이제, 이 점은 주변의 점들로 정의 된다.

이 때 이 '주변/이웃'을 어떻게 정의 하느냐가 중요한 문제 중 하나이다.

예를 들면, 

neighboring 을 점과 최소 거리가 선 1개 라고 정의할 수 있는 것이다.

이 정의에 따르면, 여러 점들 (점1, 점2, 점3...) 이 있는 그래프에서 점 1과 선하나로 연결된 점들이 점1의 이웃들이다.

 

이제 이 이웃들이 하나의 신경망 (neural network layer)이 된다.

즉, GNN Layer 는, 그래프의 각 요소 (각 점/ 각 선 등)에 대한 개별의 신경망이다.

(이게 fully connected layer MLP 일수도 있고 다른 것일수도 있고).

 

자 이 이웃들도 각각 feature vector 를 가지고 있다.

이웃들이 가지고 있는 feature vector 들이 점1의 표현해준다.

즉, 우리는 이웃들이 가지고 있는 feature vector 를 잘 가공해야하고 이게 '메세지'이다.

이제 이 가공방법이, aggregate 이다.

이렇게 가공된 정보를 통해서 점1의 정보를 업데이트! 한다. 즉

주변이웃들의 정보/메세지 message 가 점1을 통과 passing 하는 것이다.

앞서 주변 이웃들이 하나의 신경망 neural network 레이어를 형성한다고 했으므로 

이걸 message passing neural network라고 하는 것이다.

 

이 때, 내 이웃의 이웃 정보까지가 업데이트하고 싶다면,

이건 2-hop depth 라고 표현 할 수 있다.

모델이 깊이까지 내 주위 이웃을 파고 들어가는지! 를 depth 로 표현한다. 

 

여하든  2-hop depth 를 생각하면 내 이웃의 정보가 나에게 오는데..

내 이웃은 본인들 이웃의 정보로 업데이트가 된 상태인 거다.

그리고 이런 이웃들의 메세지가, aggregate 되어서 나를 update 하는 겆..

 

 

예를 들어서, aggregate의 예시 중 하나는 그냥 다 합치는 거다.

점 1 의 이웃이 점 2 와 점 3 이라고 하자.

점1의 feature vector 는 [0,0,0],

점 2의 feature vector 는 [3,10,-4], 점 3의 feature vector 는 [1,-2,0] 이라고 하자.

 

summation 으로 aggregate 되었다는 건, 

점 1이 받는 message 메세지는 

점 2와 점3의 feature vector 의 합인 [3+1, 10+(-2), -4+0] 즉, [4,8,-4] 이다.

 

이제 점 1 이 받는 메세지 [4,8,-4]이 점1을 지나가야 passing 해야 한다. 

이 때 이 정보가 어떻게 점1을 update 업데이트 할까?

또다른 간단한 방법으로는, 평균 값 내서 그걸 기존 값에 더하는 방식으로 업데이트 하는 게 있다.

즉, 점 2와 점3 즉 2개의 점의 정보가 합쳐진 거니 메세지가 업데이트 되는건 [4/2, 8/2, -4/2] 인 [2,4,-2] 를 기존 값에 업데이트 해주면

기존값이 다 0 이니까 점1의 업데이트 된 feature vector 는 [2,4,-2]이다.

 

 

 

여기선 정말 간단하게 되었지만,

이 업데이트에 신경망 neural network 가 이용된다.

 

예를 들어서 업데이트에 attention 즉, 다른 가중치를 적용할 수도 있다.

 

이제 node embedding, edge embedding, global embedding 등을 모두 이용하면 

이 그래프의 각 요소들 사이에서도 

서로 정보를 전달해서 업데이트 하는 것도 다양한 방법으로 고민해서 구현할 수 있다. 

 

어쨌든 .. 

이걸 특정 순서만큼 반복하면 readout phase 에 이른다.  

 

 


GNN 에서도 convolution 을 이용할 수 있다.

기존 convolution neural network 을 보면, 커널로 정의된 '창'이 있고

주어진 데이터를 커널을 통해서 새로 가공한다.

(즉 mxm matrix 에서 2x2 matrix 의 데이터들을 보면서 값 중 최고 값을 추출한다던가)

이때, 이걸 그래프에서 이용하면, cnn 가 주어진 데이터를 인덱스 순서대로 미끄러져 간 것과 달리 전과 선을 기준으로 convolution한다.

 

즉, 엣지로 연결된 이웃들을 기준으로 kernel 정의하고, 여기서 풀링해서 새로운 정보값을 만드는 것이다.

이건 각 노드 그룹을 "normalization" 한다고 표현 할 수도 있다.

 

어쨌든 이렇게 학습된게 그래프에 대한 정보를 담고 있는 encoder 이다..

 

여러모로 부족한 점들이 많지만 후에  업데이트하고 추가 글을 올리도록 하겠다. 

GNN 을 이용한 다양한 데이터, 기본 개념들 에 대한 설명과 예시도 다음에 :) 

 

 

728x90
반응형

댓글


TOP