ER-GNN
Fan Zhou / Overcoming Catastrophic Forgetting in Graph Neural Networks with Experience Replay / AAAI-2021
Overcoming Catastrophic Forgetting in Graph Neural Networks with Experience Replay
1. Problem Definition
Static한 graph setting에 맞춰져 있는 현재의 Graph Neural Networks (GNNs)는 현실의 상황과는 거리가 멀다.
Sequence of tasks에 continuously 적용될 수 있는 GNN을 고안하는 것이 본 논문의 주된 목적이다.
Continual Learning에서 발생하는 주된 문제인 catastrophic forgetting 문제도 보완한다.
2. Motivation
2.1 Continual Learning과 Catastrophic Forgetting
Graph Neural Networks (GNNs)은 많은 관심을 받고 있는 연구 분야이며, 눈에 띌만한 성장세를 보이고 있다. 현재까지의 GNN은 static한 graph setting에 초점이 맞춰져 개발되었다. 하지만 현실에서의 setting은 graph가 고정되어 있지 않고, 새로운 node와 edge 등이 끊임없이 추가된다. 이러한 상황에서 model은 정확성을 지속적으로 유지할 수 있어야 한다. 그렇다면 이러한 setting에서 새로운 task까지 잘 해내는 모델을 학습해야 한다면 어떻게 해야할까?
당연히 모델을 retraining 시켜야한다. 모델을 retraining 시키기 위해 아래 두 가지 방법을 쉽게 떠올려 볼 수 있다.
첫째, 기존 데이터에 새로운 데이터까지 추가해서 모델을 처음부터 다시 학습하는 방법이다. 이 방법이 직관적일 수 있지만, 새로운 데이터가 수집될 때마다 전체 데이터셋에 대하여 모델의 모든 가중치값들을 학습하는 것은 시간과 computational cost 측면에서 큰 손실이다.
그렇다면, 모델을 새로운 데이터로만 retraining 시키면 어떻게 될까? 이전에 학습했던 데이터와 유사한 데이터셋을 학습하더라도 아래의 그림처럼 이전의 데이터셋에 대한 정보를 잊어버리게 될 것이다. 이 문제를 일컬어 Catastrophic Forgetting 이라고 부른다.
Catastrophic Forgetting : Single task에 대해서 뛰어난 성능을 보인 모델을 활용하여 다른 task를 위해 학습했을 때 이전에 학습했던 task에 대한 성능이 현저하게 떨어지는 현상
Catastrophic forgetting은 neural network의 더욱 general한 problem인 "stability-plasticity" dilema의 결과이다. 이 때, stability는 previously acquired knowledge의 보존을 의미하고, plasticity는 new knowledge를 integrate하는 능력을 의미한다.
2.2 Limitation
Graph domain에서는 continual learning에 대한 연구가 놀랍도록 얼마 없다. 이는 몇가지 한계점이 존재하기 때문이다.
graph (non-Euclidean data) is not independent and identically distributed data.
graphs can be irregular, noisy and exhibit more complex relations among nodes.
apart from the node feature information, the topological structure in graph plays a crucial role in addressing graph-related tasks.
2.3 Purpose
새로운 task를 학습할 때 이전 task에 대한 catastrophic forgetting 방지.
새로운 task 학습을 용이하게 하기 위해 이전 task의 knowledge를 사용.
Influence function을 이용, previous task에서 영향력이 높은 node들을 buffer에 저장하여 새로운 task 학습에 함께 사용하도록 하는 "Experience Replay GNN (ER-GNN)" method 고안.
2.4 Contributions
Continual Graph Learning (CGL) paradigm을 제시하여 single task가 아닌 multiple consecutive task (continual) setting에서 node classification task를 수행할 수 있도록 함.
Continual node classification task에 기존 GNN을 적용할 때 발생하는 catastrophic forgetting 문제를 해결함.
유명한 GNN model에 적용 가능한 ER-GNN model을 개발하고, 이는 buffer로 들어갈 replay node를 선정할 때 기존 방법과는 다르게 influence function을 사용함.
3. Method
3.1 Problem Definition
Continual Node Classification (task incremental learning) setting에서 등장하는 sequence of task의 notation은 다음과 같다.
Node classification task의 정의는 아래와 같다.
Definition 1 (Node Classification)
3.2 Experience Node Replay
본 논문에서 제시한 ER-GNN의 outline은 아래의 Algorithm에서 확인 가능하다.
이러한 weight factor를 통해 재구성한 loss function은 다음과 같다.
그 이후에는 다음과 같이 loss를 최소화할 수 있는 optimal parameters를 구하면 된다.
3.2.1 Experience Selection Strategy
Replay할 node를 선정하는데 사용되는 3가지 방법을 소개하겠다.
1. Mean of Feature (MF)
2. Coverage Maximization (CM)
3. Influence Maximization (IM)
하지만, 모든 node를 제거해가면서 optimal parameter의 변화를 관찰하는 것은 computational cost 측면에서 매우 비효율적이다.
이에, 저자는 model을 retraining하지 않고 parameter의 변화를 추정할 수 있는 influence function을 적용한다.
Hessian matrix는 다음과 같이 계산된다.
본 process를 진행하는 과정에서 Hessian-vector products (HVPs)를 사용하여 아래의 식을 근사한다.
이 때, Hessian matrix는 positive semi-definite이므로 아래와 같이 식이 변형되고,
4. Experiment
4.1 Experiment setup
4.1.1 Dataset
실험에서 사용한 dataset의 구성은 아래의 표와 같다.
4.1.2 baseline
ER-GNN과의 비교를 위해 continual setting에서 아래의 GNN 모델들과 비교하였다.
저자는 GNN method 중 GAT를 사용하여 ER-GNN을 구성하였다.
위에서 설명한 3가지(MF, CM, IM) experience selection strategy에 대하여 모두 실험을 진행하였는데, 이는 ER-GNN 뒤에 표시되어 있다. (ex. ER-GAT-MF, ER-GAT-CM, ER-GAT-IM 등)
별(*) 표시가 붙어있는 방법론도 있을 것이다. 그러한 경우는 위에서 언급 하였듯, MF와 CM method를 사용할 때 attribute가 아닌 embedding을 기준으로 mean과 coverage maximization을 계산한 것을 의미한다.
4.1.3 Evaluation Metric
본 논문의 주된 목적은 continual learning에서 고질적으로 발생하는 문제인 catastrophic forgetting을 줄이기 위함이므로 이에 알맞은 evaluation metric을 저자는 제안한다.
Performance Mean (PM) : 일반적인 accuracy value이다. 단, Reddit dataset에서는 class 간의 imbalance 문제 때문에 Micro F1 score를 사용한다.
Forgetting Mean (FM) : 이후 task를 학습하고 난 뒤, task의 accuracy가 떨어지는 정도를 측정한 값이다.
4.2 Result
4.2.1 Performance Mean
GNN model들과 다른 두 model (DeepWalk, Node2Vec) 모두 일정 수준의 catastrophic forgetting은 발생하는 것을 관찰할 수 있다.
GNN model에 비해 DeepWalk와 Node2Vec은 PM의 관점에서 더 좋지 않은 결과를 보이지만, FM의 관점에서는 더 좋은 결과가 관찰된다. 이는 DeepWalk나 Node2Vec이 새로운 task를 학습하는 것을 희생하여 이전 task들의 학습을 기억하는 것에 더 초점을 맞추는 것으로 해석 가능하다.
GNN model 중 GAT는 PM과 FM의 관점 모두에서 좋은 결과를 보인다. 이는 attention mechanism이 continual graph-related task learning에서 new task 학습과 existing task 학습 내용을 기억하는데에 모두 장점이 있다는 것을 보여준다.
저자가 고안한 ER-GNN의 경우 IM strategy를 사용한 경우 가장 좋은 performance가 도출되었다. Influence function이 node를 replay하는데에 효과가 있음을 입증한다.
MF와 CM strategy에서는 embedding space를 기준으로 한 model (명칭에 *가 붙어있는)이 attribute space를 기준으로 한 model들보다 좋은 결과를 나타내었다.
Dataset 별 task가 진행됨에 따른 accuracy를 plot
Figure를 보면 세가지 dataset 모두에서 catastrophic forgetting이 발생한다.
ER-GNN model과 함께 influence function을 쓴 model이 catastrophic forgetting을 가장 잘 완화하는 결과이다.
4.2.2 Forgetting Mean
SGC와 GIN model에 대해서 ER-GNN model을 적용하였다.
위의 table과 비교해보면, ER-GNN을 적용하지 않은 natural SGC/GIN일 때보다 FM 값이 확연히 줄어든 것으로 보아 catastrophic forgetting을 줄이는데 도움을 준다는 것을 보여준다.
3가지 experience selection stragtegies 중에서 저자가 제안한 IM 방법이 가장 좋은 performance를 보인다.
Hyperparameter tuning을 통해 catastrophic forgetting과 computational cost 간의 trade-off 관계에서 균형을 찾을 필요가 있을 것이다.
5. Conclusion
5.1 Summary
Graph-based continual learning problem에 GNN을 활용하여 catastrophic forgetting을 방지할 수 있는 framework를 제안함.
Continual learning의 큰 줄기 중에서 replay 방식을 채택하였고, influence function을 접목시켜 가장 영향력이 높은 node를 buffer에 저장하도록 함.
Graph domain에서 유명한 Cora, Citeseer, Reddit 등의 datastet에 적용시켜 실험하였고, PM, FM이라는 두 가지 metric을 토대로 효과가 있음을 입증함.
5.2 Discussion
본 논문의 가장 주요한 novelty는 replay를 할 때, influence function을 접목시킨 것이다.
이렇듯 replay 방식에서는 buffer에 넣을 node를 선택하는 방법론이 매우 주요할 것이고, experience selection strategy에 대하여 연구해보는 것도 좋은 future research topic이 될 것 같다.
또한, 본 논문의 algorithm을 살펴보면 buffer의 크기에 제한이 없이 task가 진행됨에 따라 계속 node를 추가시킨다. 이러한 경우 replay되는 node에 대해 overfitting이 발생할 수 있고, real-world dataset에서는 task가 굉장히 많을 것이기 때문에 computational cost가 많이 증가할 것으로 예상된다. 이에 따라 buffer size와 buffer 내의 node 관리에 대한 연구도 필요할 것이라 생각된다.
Author Information
Seungyoon Choi
Affiliation : DSAIL@KAIST
Research Topic : GNN, Continual Learning, Active Learning
6. Reference & Additional materials
6.1 Github Implementation
6.2 Reference
Fan Zhou, and Chengtai Cao. "Overcoming Catastrophic Forgetting in Graph Neural Network with Experience Replay."
Liu, Huihui, Yiding Yang, and Xinchao Wang. "Overcoming catastrophic forgetting in graph neural networks."
Last updated