Machine Learning, Deep Learning

Transfer learning

n.han 2021. 4. 5. 13:04

Transfer learning

transfer learning은 하나의 도메인에서 학습한 모델을 다른 도메인에 활용하는 것을 말합니다. 이미지가 어떤 사물인지 판단하기 위해 이미 학습된 classifier를 생각해보면, low-level ~ high-level feature들이 모델에 포함되어 있습니다.

low-level feature ~ high level feature. 출처: https://www.kaggle.com/aakashns/advanced-transfer-learning-starter-notebook

같은 이미지 분류 문제인데, class가 약간 달라지는 문제가 있다고 해보죠. 두 모델에 고양이와 같은 공통적인 class가 있다고 하면 기 학습된 모델을 활용할 수 있습니다. 또한 모델에는 low-level ~ high-level feature들이 두루 포함되고 있으므로, 새로운 class에도 활용될 여지가 있습니다. 이런 tranfer learning은 데이터 셋이 유사하며, 새로운 분류 문제의 데이터 셋이 적은 경우에 특히 유용합니다.

그럼 어떻게 pre-trained 된 모델을 활용할 수 있을까요? CNN과 같은 Classifier는 Feature를 추출하는 convolutional layer들과 classification 하는 fully connected layer들로 구성됩니다.

CNN은 Feature extraction을 담당하는 convolutional layer들과 classification을 담당하는 FC layer들로 구성됩니다. 출처: https://www.kaggle.com/aakashns/advanced-transfer-learning-starter-notebook

따라서 feature extraction하는 layer들까지는 활용하고, 추출된 feature들을 다른 classifier에 붙여서 fine-tuning 하는 것입니다.

pre-trained model을 활용하는 방법. 출처: https://www.kaggle.com/aakashns/advanced-transfer-learning-starter-notebook

구체적으로 transfer learning은 다음과 같은 순서로 진행하게 됩니다.
1. pre-training: 타겟 데이터 셋이 아닌 어떤 큰 데이터 셋으로 모델을 training 합니다.
2. output layer들을 randomzie (initialize) 합니다. FC Layer들은 classification 하는데 활용되기 때문에, 이런 classification layer들은 분류 class들이 달라지면 초기화를 해야 합니다.
3. fine-tuning: 우리에 타겟 데이터 셋으로 모델을 training 합니다.

transfer learning이 잘 될까요? 잘 될 수도 있고 안될 수도 있습니다. 비행기 기종을 전문적으로 분류하는 모델을 생각해보죠. 고양이와 비행기를 나누는 pre-trained classifier를 가지고 비행기 기종을 전문적으로 분류하는 것은 잘 되지 않을 것입니다. 하지만 비행기의 low-level feature들을 어느 정도 포함하고 있기 때문에 pre-trained 모델들을 사용하게 되면 정확도나 학습 속도 측면에서 이득이 있을 것 입니다.

transfer learning이 잘 될지 판단할 수 있는 간단한 가이드입니다. 물론 경우에 따라 달라질 수 있어서 단정 짓긴 어렵습니다.
1. pre-tranined 된 모델의 학습 데이터 셋과 타겟 데이터 셋이 유사한 경우: 유용합니다.
2. pre-tranined 된 모델의 학습 데이터 셋과 타겟 데이터 셋이 유사하지 않고, 타겟 데이터 셋의 데이터가 적은 경우: 도움이 되지 않을 것 입니다. 타겟 데이터 셋의 데이터를 늘리는 시도를 먼저 해봐야하지 않을까 생각합니다.
3. pre-tranined된 모델의 학습 데이터 셋과 타겟 데이터 셋이 유사하지 않고, 타겟 데이터 셋의 데이터가 많은 경우: pre-trained모델의 low-level ~ mid-level의 feature들을 사용해보는 것을 고려할 수 있습니다. pre-trained 모델을 많이 드러내야겠지요.

수많은 데이터들로 학습된 유명한 모델들은 공개되어 있기 때문에, 맨땅에서 시작하는 것보다 공개된 모델들을 활용하는 것을 고려해보세요.