TesNet
Jiaqi Wang / Interpretable Image Recognition by Constructing Transparent Embedding Space / ICCV-2021
1. Problem Definition
Convolution Neural Network(CNN)์ ๊ฒฐ๊ณผ ํด์์ ํ๋จ์ ์ ํํ ๊ทผ๊ฑฐ๊ฐ ํ์์ ์ธ ์์จ ์ฃผํ ์๋์ฐจ์ ์ ์ง๋จ๊ณผ ๊ฐ์ ์๋ฃ ๋ถ์ผ์์ ์ค์ํ ๊ณผ์ ์ ๋๋ค. ๊ทธ๋ฌ๋ ๋ค์ํ ํ์คํฌ์์ CNN์ ์ฑ๋ฅ์ด ๋น์ฝ์ ์ผ๋ก ๋ฐ์ ํ ๋ฐ์ ๋นํด, ์ฌ์ ํ ๋คํธ์ํฌ์ output์ ์ฌ๋์ด ์ฝ๊ฒ ์ดํดํ ์ ์๋ ์๋ฏธ๋ค๋ก ํด์ํ๋ ๋ฐ์๋ ์ด๋ ค์์ด ๋ง์ต๋๋ค. ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ต๊ทผ์ CNN ๋ด๋ถ์ feature representation์ ์๊ฐํํ๋ ๋ง์ interpetableํ ๋ฐฉ๋ฒ๋ค์ด ์ ์๋์์ง๋ง, ์๊ฐํ๋ ๋คํธ์ํฌ ๋ด๋ถ feature์ ์๋ฏธ ํด์ ๊ฐ์ gap์ ์ฌ์ ํ ํฝ๋๋ค.
๋ฐ๋ผ์ interpretable image classification(ํด์ ๊ฐ๋ฅํ ์ด๋ฏธ์ง ๋ถ๋ฅ)๋ฅผ ์ํด ์ฌ๋๋ค์ด ์ฝ๊ฒ ๊ทธ ์๋ฏธ๋ฅผ ์ดํดํ ์ ์๋ input image์ concepts๋ฅผ ์ถ์ถํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์ฐ๊ตฌ๊ฐ ์ด๋ฃจ์ด์ง๊ณ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๊ธฐ์กด ๊ด๋ จ ์ฐ๊ตฌ๋ค์ด ์ ์ํ concepts๋ ์๋ก ๋ค์ฝํ์์ด output class์ ๋ํ ๊ฐ ๊ฐ๋ณ concept์ ์ํฅ์ ํด์ํ๊ธฐ ์ด๋ ต์ต๋๋ค.
๋ณธ ๋ ผ๋ฌธ์์๋ ์ด๋ฅผ ๋ฌธ์ ์ ์ผ๋ก ์ง์ ํ๋ฉฐ output class์ ๋ํ input image์ ํน์ง์ ํจ๊ณผ์ ์ผ๋ก ์ค๋ช ํ ์ ์์ผ๋ฉด์, ๋์์ ์๋ก ์ฝํ์์ง์๊ณ orthogonalํ(์ง๊ต๋ฅผ ์ด๋ฃจ๋) concepts๋ฅผ ์ถ์ถํ ์ ์๋ ๋ฐฉ๋ฒ๋ก ์ ์ ์ํฉ๋๋ค.
2. Motivation
๊ทธ๋ ๋ค๋ฉด Interpretable Concepts (ํด์์ด ์ฉ์ดํ ์ปจ์
)์ด๋ ๋ฌด์์ผ๊น์? ์ธ์ง์ ๊ด์ ์์ Interpretable Concepts๋ ๋ค์์ ์ธ ๊ฐ์ง ์กฐ๊ฑด์ ๋ง์กฑํด์ผ ํฉ๋๋ค.
(1) Informative
Input data๋ basis concept๋ค๋ก spanned๋ vector space์์์ ํจ์จ์ ์ผ๋ก ๋ํ๋ด์ ธ์ผํ๋ฉฐ, input์ essential information(์ค์ํ ์ ๋ณด)๊ฐ ์๋ก์ด representation space์์๋ ๋ณด์กด๋์ด์ผํฉ๋๋ค.
(2) Diversity
๊ฐ ๋ฐ์ดํฐ(ex.์ด๋ฏธ์ง)๋ ์๋ก ์ค๋ณต๋์ง ์๋ ์์์ basis concepts์ ๊ด๋ จ ์์ด์ผํ๋ฉฐ, ๊ฐ์ class์ ์ํ๋ ๋ฐ์ดํฐ๋ค์ ๋น์ทํ basis concepts๋ฅผ ๊ณต์ ํด์ผ ํฉ๋๋ค.
(3) Discriminative
Basis concepts๋ (1)์์ ์ธ๊ธํ basis concept vector space์์์๋ class๊ฐ ์ ๋ถ๋ฆฌ๋๋๋ก class-awareํด์ผ ํฉ๋๋ค. ์ฆ, ๊ฐ์ class์ ์ฐ๊ด๋ basis concepts๋ผ๋ฆฌ๋ ๊ทผ์ ํ๊ฒ, ๋ค๋ฅธ class์ basis concepts ๊ฐ์๋ ๋ฉ๊ฒ embedding๋์ด ์์ด์ผ ํฉ๋๋ค.
๋ฐ์ดํฐ์ concepts๋ฅผ ์ถ์ถํ๊ธฐ ์ํด ์ด์ ์ฐ๊ตฌ๋ค์ auto-encoding, prototype learning๊ณผ ๊ฐ์ด deep neural network์ high-level feature๋ฅผ ์ด์ฉํ๋ ๋ฐฉ์์ ์ ์ํ์์ต๋๋ค. ๊ทธ ์ค ํ ๋ฐฉ๋ฒ์ U-shaped Beta Distribution์ ์ด์ฉํ์ฌ basis concepts์ ๊ฐ์๋ฅผ ์ ํํจ์ผ๋ก์จ ๊ฐ input data๋ฅผ ์์์ ์๋ฏธ ์๋ basis concept๋ค๋ก ๋ํ๋ด๊ธฐ๋ ํ์์ต๋๋ค. ์ด๋ฌํ ์ฐ๊ตฌ๋ค์ Interpretable Concepts์ ์ฒซ๋ฒ์งธ ์กฐ๊ฑด์ ๋ง์กฑํ์์ง๋ง, ์์ ์ธ๊ธํ์๋ฏ์ด basis concepts๊ฐ ์๋ก ์ฝํ์์ด(entangled) input๊ณผ output์ ๋ํ ๊ฐ๋ณ concept์ ์ํฅ์ ํด์ํ๊ธฐ ์ด๋ ต๋ค๋ ๋ฌธ์ ์ ์ด ์กด์ฌํฉ๋๋ค.
๋ฐ๋ผ์, ์ด ๋
ผ๋ฌธ์์๋ ์์ ์ธ๊ฐ์ง Interpretable Concepts ์กฐ๊ฑด์ ๋ชจ๋ ์ถฉ์กฑ์ํค๋ basis concepts๋ฅผ ์ค๊ณํ๋ ๋ฐ์ ์ฃผ๋ชฉํ๊ณ ์์ต๋๋ค. ๋
ผ๋ฌธ์์ ์ค๊ณํ basis concepts๋ ๋ค์๊ณผ ๊ฐ์ ํน์ง๋ค์ ๊ฐ์ง๋๋ค.
(1) ๊ฐ class๋ ์์ ๋ง์ basis concepts๋ฅผ ๊ฐ์ง๋ฉฐ class๊ฐ ๋ค๋ฅธ ๊ฒฝ์ฐ basis concepts๋ ์ต๋ํ ๋ค๋ฆ
๋๋ค.
(2) High-level feature๊ณผ basis concepts ์ฌ์ด๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์ฐ๊ฒฐํ๋ mapping์ ์ ๊ณตํฉ๋๋ค.
(3) Input image ์์ basis concepts๋ ๊ฐ class์ ๋ํ prediction score์ ๊ณ์ฐํ๋ ๋ฐ์ ๋์์ด ๋ฉ๋๋ค.
์์ ์ธ ๊ฐ์ง ํน์ง์ ๋ง์กฑํ๋ basis concepts ์ค๊ณ๋ฅผ ์ํด, ๋ณธ ๋
ผ๋ฌธ์ ๊ธฐ์กด ์ฐ๊ตฌ๋ค๊ณผ ๋ค๋ฅด๊ฒ Grassmann manifold๋ฅผ ๋์
ํ์ฌ basis concept vector space๋ฅผ ์ ์ํฉ๋๋ค. ๋ค์์ ๊ทธ๋ฆผ์ฒ๋ผ, ๊ฐ class๋ง๋ค์ basis concepts subset์ด Grassmann manifold ์์ point๋ก ์กด์ฌํฉ๋๋ค.
Grassmann manifold๋ ์ฝ๊ฒ ๋งํ๋ฉด linear subspaces์ set(์งํฉ)์ด๋ผ๊ณ ์๊ฐํ ์ ์์ต๋๋ค. ์ฌ๊ธฐ์ subspace๋ vector space V ์ subset(๋ถ๋ถ์งํฉ) W ๊ฐ V ๋ก๋ถํฐ ๋ฌผ๋ ค๋ฐ์ ์ฐ์ฐ๋ค๋ก ์ด๋ฃจ์ด์ง ๋ ๋ค๋ฅธ ํ๋์ vector space์ผ ๋ W ๋ฅผ V ์ subspace๋ผ๊ณ ๋งํฉ๋๋ค.
๋ํ projection metric์ ํตํด ๊ฐ class์ basis concept๋ค์ ์๋ก orthogonalํ๋๋ก, ๋์์ class-awareํ basis concepts subset๋ค์ ์๋ก ๋ฉ๋ฆฌ ์์นํ๋๋ก ๊ท์ ๋ฉ๋๋ค. ์ด ๋ ๊ฐ์ง ๊ท์ ๋ฅผ ํตํด basis concepts๊ฐ ์๋ก ์ฝํ์ง ์๋๋ก ํจ์ผ๋ก์จ ๊ธฐ์กด ์ฐ๊ตฌ์ ํ๊ณ์ ์ ๊ทน๋ณตํ๊ณ ์์ต๋๋ค.
๋
ผ๋ฌธ์ ์ด๋ ๊ฒ ์ค๊ณ๋ transparent embedding space (concept vector space)๊ฐ ๋์
๋ ์๋ก์ด interpetable network, TesNet์ ์ ์ํ๊ณ ์์ต๋๋ค.
3. Method
The overview of TesNet architecture
๋ค์์ TesNet์ ์ ์ฒด์ ์ธ architecture์ ๋ชจ์ต์
๋๋ค.
๊ทธ๋ฆผ๊ณผ ๊ฐ์ด TesNet์ convolutional layers f, trasparent subspace layer , ๊ทธ๋ฆฌ๊ณ classifier h ์ด๋ ๊ฒ ์ธ ๊ฐ์ง์ ํต์ฌ ์์๋ก ์ด๋ฃจ์ด์ ธ ์์ต๋๋ค.
๊ฐ ์์๋ฅผ ํ๋์ฉ ์ดํด๋ณด๋ฉด, ๋จผ์ convloutional layers f ๋ 1X1 convolutional layer๋ค์ด ์ถ๊ฐ๋ ๊ธฐ๋ณธ CNN ๋คํธ์ํฌ(ex.ResNet) ์ ๋๋ค. ๋ feature map์ transparent embedding space์ projection์ํค๋ subspace layer์ ๋๋ค. ๊ฐ class๋ง๋ค subspace๊ฐ ์กด์ฌํ์ฌ ์ด class ๊ฐ์๋งํผ์ subspace๊ฐ ์กด์ฌํฉ๋๋ค. ๊ฐ class์ subspace๋ M๊ฐ์ basis concepts๋ก spanned ๋์ด์์ต๋๋ค. ์ด M๊ฐ์ within-class concepts(ํด๋์ค ๋ด๋ถ concepts)๋ ์๋ก orthogonalํ๋ค๊ณ ๊ฐ์ ํฉ๋๋ค. ๋ฐ๋ผ์ ์ด C๊ฐ์ class๊ฐ ์์ ๋, ๊ฐ class ๋ง๋ค M๊ฐ์ basis concepts๊ฐ ์กด์ฌํ๋ค๊ณ ๊ฐ์ ํ๋ฉด ์ ์ฒด CM๊ฐ์ basis concepts๊ฐ ์กด์ฌํฉ๋๋ค.
Embedding space learning
๊ทธ๋ ๋ค๋ฉด basis concepts๋ ์ด๋ป๊ฒ ์ ์๋์ด embedding space๋ฅผ ์ด๋ฃจ๊ณ ์๋์ง ์ดํด๋ณด๊ฒ ์ต๋๋ค.
๊ฐ basis concept์ basis vector๋ก ํํ๋ฉ๋๋ค. ์ด basis vector๋ ๋ค์ ์ธ ๊ฐ์ง ์กฐ๊ฑด์ ๋ง์กฑํด์ผํฉ๋๋ค.
(1) ๋ค๋ฅธ basis vector ์ฌ์ด์๋ ์๋ฏธ๊ฐ ์ค๋ณต๋๋ฉด ์๋ฉ๋๋ค.
(2) embedding space์์๋ ๊ฐ class๋ ๊ตฌ๋ถ๋์ด์ผ ํฉ๋๋ค.
(3) basis vector๋ค์ ๋น์ทํ high-level patch(์ฌ๋๋ค์ด ์ธ์ํ ์ ์๋ level์ image)๋ค์ ๊ตฐ์งํํ๊ณ ๋ค๋ฅธ ๊ฒ๋ค๋ผ๋ฆฌ๋ ๋ถ๋ฆฌํ ์ ์์ด์ผ ํฉ๋๋ค.
์ด ์ธ ๊ฐ์ง ์กฐ๊ฑด์ ๋ง์กฑ์ํค๊ธฐ ์ํด ์ ์ฒด architecture์์ ๋ณด์๋ convolutional layer, basis vectors, classifier layer์ weight๋ค์ด ์๋ก jointํ๊ฒ optimize(์ต์ ํ)๋ ์ ์๋๋ก joint optimization problem์ ์ ์ํ๊ณ ์์ต๋๋ค. ๋ค์์ ๊ฐ weight๋ฅผ ์ต์ ํํ๊ธฐ ์ํ Loss์ optimization ๊ณผ์ ์ ๋๋ค.
Orthonormality for Within-class Concepts
์กฐ๊ฑด (1)์ ๋ง์กฑ์ํค๊ธฐ ์ํ Loss๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
basis vector์ฌ์ด์ ์๋ฏธ๊ฐ ์ค๋ณต๋์ง ์๋๋ค๋ ๊ฒ์ ๊ฐ์ class์ ์ํ basis concepts์ด๋๋ผ๋ ๋ฐ๋์ ์๋ก ๋ค๋ฅธ ์ธก๋ฉด๋ค์ ๋ํ๋ด๊ณ ์์ด์ผํ๋ค๋ ๋ป์
๋๋ค. ๊ทธ๋ฌ๊ธฐ ์ํด์ ๊ฐ์ class์ ์ํ basis concept vectors๊ฐ ์๋ก orthogonalํด์ผ ํ๋ฏ๋ก ๊ฐ class์ basis vectors ์ฌ์ด์ orthonormality๋ฅผ ๊ท์ ํ๋ Loss๋ฅผ ์ฌ์ฉํฉ๋๋ค.
Loss ์์ ์ดํด๋ณด๋ฉด ๊ฐ class์ basis vector matrix ํ๋ ฌ๊ณฑ๊ณผ identity matrix ์ฌ์ด์ L2 norm์ ๋ชจ๋ ๋ํ๊ณ ์์ต๋๋ค. ์ฆ, ๊ฐ class์ basis vectors๊ฐ์ correlation(์๊ด ๊ด๊ณ)๋ฅผ ์ต์ํ์ํค๊ธฐ ์ํ Loss์ ๋๋ค. ์ด๋ฌํ Loss๋ฅผ ํตํด ํ์ต๋ orthonormal basis vectors๊ฐ ๊ฐ class์ subspace๋ฅผ spanํ๊ฒ ๋ฉ๋๋ค.
Separtion for Class-aware Subsapces
๋๋ฒ์งธ๋ก ์กฐ๊ฑด (2)๋ฅผ ๋ง์กฑ์ํค๊ธฐ ์ํ Loss๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
embedding space์์์ class๊ฐ ๊ตฌ๋ถ๋๊ธฐ ์ํด์๋ ๊ฐ class์ subspace๊ฐ ์๋ก ๋ฉ๋ฆฌ ์์นํด ์์ด์ผํฉ๋๋ค. ์ฆ, Grassmann manifold ์์์ class-aware subspace๋ค์ ๊ฑฐ๋ฆฌ๊ฐ ์ต๋ํ ๋ฉ์ด์ง๋๋ก ๊ท์ ํฉ๋๋ค. ๊ฐ subspace๋ Grasmann manifold์์์ uniqueํ projection์ผ๋ก ์กด์ฌํ๋ฏ๋ก, subspace ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ฅผ
projection mapping์ ์ด์ฉํ์ฌ ์์นํํ ์ ์์ต๋๋ค.
Loss ์์์ ๋ class c์ orthonormal basis vectors๋ก ์ด๋ฃจ์ด์ง matrix๋ฅผ ์๋ฏธํ๊ณ , ์ด matrix์ ํ๋ ฌ๊ณฑ์ด class c์ ์ฐ๊ด๋ subspace์ projection mapping์
๋๋ค. ๊ฒฐ๊ตญ Loss๋ ์๋ก ๋ค๋ฅธ class์ projection mapping ์ฌ์ด์ L2 norm distance๋ค์ ํฉ์ ์ต์ํ์ํค๊ธฐ ์ํ Loss๋ก ์ดํดํ ์ ์์ต๋๋ค.
High-level Patches Grouping
๋ง์ง๋ง์ผ๋ก ์กฐ๊ฑด (3)์ ๋ง์กฑ์ํค๊ธฐ ์ํ Loss์
๋๋ค.
์กฐ๊ฑด (3)์ ๊ฒฐ๊ตญ high-level ์ด๋ฏธ์ง ํจ์น๋ค์ด embedding subspace์๋ ์ projection ๋์ด์ผ ํ๋ค๋ ์๋ฏธ์
๋๋ค. ์ฆ, ์ด๋ฏธ์ง ํจ์น๋ค์ด subspace์ embedding ๋์์ ๋ ์ด๋ฏธ์ง๊ฐ ์ํ ground-truth class์ basis vectors์ ๊ทผ์ ํด์ผํฉ๋๋ค. ์ด๋ฅผ ์ํด ๋
ผ๋ฌธ์
Compactness Loss์ Separation Loss๋ฅผ ์ ์ํ๊ณ ์์ต๋๋ค.
๋จผ์ Compactness Loss์ ์์ ์ดํด๋ณด๋ฉด, ์ด๋ฏธ์ง ํจ์น์ ground-truth class์ basis vectors์ฌ์ด์ cosine distance(negative cosine similarity)๋ฅผ ์ต์ํํ๊ณ ์์ต๋๋ค. ์ด๋ ๊ฒฐ๊ตญ ์ด๋ฏธ์ง ํจ์น์ ground-truth class์ basis vectors์ฌ์ด์ cosine similarity๋ฅผ ํฌ๊ฒํ๋ ๊ฒ๊ณผ ๊ฐ์ต๋๋ค.
๋ฐ๋ฉด, Separation Loss๋ ์ด๋ฏธ์ง ํจ์น๊ฐ ground-truth๊ฐ ์๋ class์ basis vectors๊ณผ๋ ๋ฉ์ด์ง๋๋ก ๋ ์ฌ์ด์ cosine similarity๋ฅผ ์ต์ํํ๊ณ ์์ต๋๋ค.
์ด ๋ Loss๋ฅผ hyper-parameter M ์ ์ฌ์ฉํ์ฌ ๋ํจ์ผ๋ก์จ Compactness-Separation Loss๋ฅผ ์ ์ํฉ๋๋ค.
Identification
๋ง์ง๋ง์ผ๋ก classifier layer๋ฅผ optimizeํ๊ธฐ ์ํ Loss๋ก์ Cross Entropy Loss๋ฅผ ์ด์ฉํฉ๋๋ค.
์ต์ข
์ ์ผ๋ก, ์ง๊ธ๊น์ง ์ ์๋ loss๋ค์ jointly optimizeํ๊ธฐ ์ํด Total Loss for Joint Optimization์ ์ ์ํฉ๋๋ค.
hyper-parameters๋ฅผ ์ฌ์ฉํ์ฌ classification loss(cross entropy loss)์ orthonormality loss, subspace separation loss, compactness-separation loss๋ฅผ ์ ์ ํ ๋น์จ๋ก ๋ํด์ค๋๋ค. ์ด total loss์ ํจ๊ป convolutional layer, basis vectors๊ฐ ๋์์ ์ต์ ํ๋๋ฉฐ concept embedding subspace๊ฐ ํ์ต๋ฉ๋๋ค.
Concept-based classification
embedding space๊ฐ ํ์ต๋๊ณ ๋๋ฉด, convolutional layers์ basis vectors์ parameter๋ฅผ ๊ณ ์ ์ํจ ํ, ๋ง์ง๋ง ๋จ์ classifier๋ฅผ ํ์ต์ํค๊ฒ ๋ฉ๋๋ค. classifier๋ concept-class weight G ๋ฅผ ์ต์ ํํจ์ผ๋ก์จ ํ์ต์ด ๋๋๋ฐ, weight G ๋ G(c,j) ์ ๊ฐ์ด j๋ฒ์งธ unit์ด class c์ ์ํ๋ ๊ฒฝ์ฐ๋ฅผ ์ ์ธํ๊ณ ๋ชจ๋ 0์ธ sparse matrix์
๋๋ค. ์์ ์ ์ํ Identification Loss์ weight G ๋ฅผ sparseํ๊ฒ ์ ์งํ๊ฒ ํ๋ ๊ท์ ๋ฅผ ๋ํ์ฌ Loss`๋ฅผ ์ ์ํ๊ณ , ์ด Loss๋ฅผ ์ต์ํํ๋๋ก classifier๊ฐ ํ์ต๋ฉ๋๋ค.
4. Experiment
๋ณธ ๋
ผ๋ฌธ์์๋ ๋ค์ํ CNN architecture์ ๋ํ TesNet์ ๋์ ์ ์ฉ์ฑ์ ์
์ฆํ๊ธฐ ์ํด ๋ ๊ฐ์ง์ case study๋ฅผ ์งํํ์์ต๋๋ค. ๊ทธ ์ค ์ฒซ๋ฒ์งธ case study์ธ bird species identification์ ๋ํด์ ์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
Experiment setup
Dataset Caltecg-USCD Birds-200-2011 dataset์ ์ฌ์ฉํ์ฌ bird species classification ์คํ์ ์งํํ์์ต๋๋ค. dataset์ 200 ์ข (species)์ bird ์ด๋ฏธ์ง 5994+5794์ฅ์ผ๋ก ์ด๋ฃจ์ด์ก์ต๋๋ค. ๊ทธ ์ค 5994์ฅ์ training, ๋๋จธ์ง 5794์ฅ์ test์ ์ด์ฉํ์์ต๋๋ค. ๊ฐ bird class๋ง๋ค 30์ฅ์ ์ด๋ฏธ์ง๋ฐ์ ์กด์ฌํ์ง ์์, ๋ ผ๋ฌธ์์๋ random rotation, skew, shear, flip ๋ฑ์
augmentation์ ํตํด training set์ ๊ฐ class๋ง๋ค 1200์ฅ์ ์ด๋ฏธ์ง๊ฐ ์กด์ฌํ๋๋ก ๋ฐ์ดํฐ๋ฅผ ์ฆ๊ฐํ์์ต๋๋ค.baseline non-interpetableํ ๋ณธ๋
VGG16, VGG19, ResNet34, ResNet152, DenseNet121, DenseNet161๋คํธ์ํฌ๋ค์ baseline์ผ๋ก ์ผ๊ณ , ๊ฐ ๋คํธ์ํฌ์ interpetableํTesNet์ ์ ์ฉํ ๊ฒฝ์ฐ์ ๋น๊ต ์คํํ์์ต๋๋ค. ๋ํ, TesNet๊ณผ ์ ์ฌํ interpetable network architecture์ธProtoPNet์ ์ ์ฉํ ๊ฒฐ๊ณผ๋ ํจ๊ป ๋น๊ตํ์์ต๋๋ค.Evaluation Metric ์คํ์ ์ฑ๋ฅ ํ๊ฐ์งํ๋ก
classification accuracy๋ฅผ ์ฌ์ฉํ์์ต๋๋ค.
Result
Accuracy comparison with diffrent CNN architectures ์๋ ํ์์ ์ ์ ์๋ฏ์ด, baseline network์ TesNet์ ์ ์ฉํ ๊ฒฝ์ฐ ๋ถ๋ฅ ์ ํ๋๊ฐ ์ต๋ 8%์ ๋ ํฌ๊ฒ ํฅ์๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ๋ํ, TesNet์ Loss๋ฅผ ๋ค์ํ๊ฒ ์ ์ํ์ฌ ์คํํ ๊ฒฐ๊ณผ, 4๊ฐ์ง Loss๋ฅผ ๋ชจ๋ jointlyํ๊ฒ optimizeํ์์ ๋ ๊ฐ์ฅ ์ ํ๋๊ฐ ๋์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
The interpretable reasoning process ๋ค์ ๊ทธ๋ฆผ์ TesNet์ด test image์ ๋ํ์ฌ decision์ ๋ด๋ฆฌ๋ reasoning process๋ฅผ ์๊ฐํํ ๊ฒ์ ๋๋ค.
European Goldfinch๋ผ๋ class์ test image๊ฐ ์ฃผ์ด์ก๋ค๊ณ ํ ๋, TesNet์ ํ์ต๋ basis vectors๋ฅผ ํตํด feature map์ re-representํ ์ ์์ต๋๋ค. ๊ฐ class c์ ๋ํด์, ๋ชจ๋ธ์ ํ์ต๋ basis vectors๋ฅผ image patch์ re-representํจ์ผ๋ก์จ ๊ทธ image๊ฐ class c์ ์ํ score๋ฅผ ๊ณ์ฐํฉ๋๋ค.
์๋ฅผ ๋ค์ด, ์ ๊ทธ๋ฆผ์์ ๋ชจ๋ธ์ European goldfinch class์ basis vector(concept)๋ฅผ test image(original image)๊ฐ ์ด class์ ์ํ ์ง์ ๋ํ ์ฆ๊ฑฐ๋ก ํ์ฉํฉ๋๋ค. Activation map column์ ์ดํด๋ณด๋ฉด, European goldfinch class์ ์ฒซ ๋ฒ์งธ basis vector๊ฐ ์๋ฏธํ๋ 'black and yellow wing concept'์ด test image ์์์ ๊ฐ์ฅ ๋๋๋ฌ์ง๊ฒ activated(ํ์ฑํ) ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ๊ฐ์ ๋ฐฉ์์ผ๋ก ๋ ๋ฒ์งธ basis vector๊ฐ ์๋ฏธํ๋ 'head concept', ์ธ ๋ฒ์งธ basis vector๊ฐ ์๋ฏธํ๋ 'brown fur concept'์ด image์์์ ํฌ๊ฒ ํ์ฑํ๋์์ต๋๋ค.
์ด๋ฅผ ๋ฐํ์ผ๋ก ๋ชจ๋ธ์ class์ ๊ฐ basis concept vector์ test image์์์ activated๋ ๋ถ๋ถ ์ฌ์ด์ similarity(์ ์ฌ๋)๋ฅผ ๊ตฌํ๊ณ basis concept์ ์ค์๋์ ๋ฐ๋ผ ๊ฐ์ค์น๋ฅผ ๋งค๊ฒจ ๋ํจ์ผ๋ก์จ ์ต์ข ์ ์ธ European Goldfinch class์ ๋ํ score๋ฅผ ๊ตฌํฉ๋๋ค. ์ด score๋ฅผ ๋ฐํ์ผ๋ก test image์ class๋ฅผ ์์ธกํฉ๋๋ค.
์ด๋ฌํ reasoning ๊ณผ์ ์ ํตํด baseline CNN ๋ชจ๋ธ๋ค๋ณด๋ค ๋์ ๋ถ๋ฅ ์ ํ๋๋ฅผ ๋ฌ์ฑํ ์ ์์ต๋๋ค.
5. Conclusion
Summary TesNet์ ๋ค๋ฅธ CNN ๋ชจ๋ธ์ plug-in๋์ด classifiaction ์ฑ๋ฅ์ ํฅ์์ํฌ ์ ์๋ ์ ์ฉ์ฑ ๋์ architecture์ ๋๋ค. TesNet์ class-aware concepts๋ฅผ ์ค๊ณํ๊ณ ๊ฐ์ class์ ์ํ concepts๋ผ๋ฆฌ ์ฝํ์ง ์๋๋ก ํ๋ฉฐ ํจ๊ณผ์ ์ผ๋ก prediction ์ฑ๋ฅ์ ํฅ์์์ผฐ์ต๋๋ค. ๋ํ, TesNet์ image์ ์ด๋ค concept์ด CNN์ ํ์ต์ํค๊ณ ์์ธกํ๋ ๋ฐ์ ๊ทผ๊ฑฐ๋ก ์ฌ์ฉ๋๋์ง๋ฅผ ์ค๋ช ํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋, TesNet์ basis concepts๊ฐ ๋ชจ๋ flatํ๋ค๋ ์ ์ ๋ฅผ ํ๊ณ ์์ด, ์ฌ๋๋ค์ด ์ค์ ๋ก ์ฌ๋ฌผ์ ๋ถ๋ฅํ ๋์ ์ธ์ง ๊ณผ์ ๊ณผ ํฐ ์ฐจ์ด๊ฐ ์์ต๋๋ค. ๋ํ ์ค์ ๋ก real world์์์ concepts๋ ์๋ก ๊ณ์ธต์ ์ผ๋ก ์ด๋ฃจ์ด์ ธ์๊ธฐ ๋๋ฌธ์, hierarchical basis concepts๋ฅผ ํ์ตํ ์ ์๋ ๋คํธ์ํฌ์ ๋ํ ์ฐ๊ตฌ๊ฐ ํ์ํฉ๋๋ค.
Opinion CNN์ output ํด์์ ์์ด input image์ concept์ด๋ผ๋ ๊ฐ๋ ์ ์ ์ ์ํ ์ฐ๊ตฌ๋ผ๊ณ ์๊ฐํฉ๋๋ค. ํนํ basis vector, subspace, manifold์ ๊ฐ์ด ์ด๋ ต์ง์์ ์ํ์ ๊ฐ๋ ๋ค์ ์ ์ ์ฉํ์ฌ ์๋ฏธ์๋ ๊ฒฐ๊ณผ๋ฅผ ๋์ถํด๋ธ ์ ์ด ๊ต์ฅํ ์ธ์๊น์ต๋๋ค. ํ์ ์๊ณ ๋ง ์๋ ์ํ์ ๊ฐ๋ ๋ค์ neural network์์ ์ฐ๊ฒฐ ์ง์ ์ ๋ค์ ์๊ฐํด๋ณผ ์ ์๋ ๊ธฐํ์๊ณ , ๊ฐ์ธ์ ์ผ๋ก Explainable AI์ ๊ด์ฌ์ด ๋ง์ ํฅ๋ฏธ๋ก์ ์ต๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฐ interpretableํ network๊ฐ ์ฃผ๋ก ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ชฝ์ ์น์ฐ์ณ ์๋ค๋ ์ ์ด ์์ฌ์ ๊ณ audio, text ๋ฑ์๋ generalํ๊ฒ ์ฐ์ผ ์ ์๋ architecture์ ๋ํ ์ฐ๊ตฌ์ ํ์์ฑ์ ๋๊ผ์ต๋๋ค.
Author Information
TaeMi, Kim
KAIST, Industrial and Systems Engineering
Computer Vision, XAI
6. Reference & Additional materials
Github Implementation None
Reference
Chaofan Chen et al, This looks like that: deep learning for interpretable image recognition, NeurIPS, 2019.
https://en.wikipedia.org/wiki/Grassmannian
Last updated