TesNet

Jiaqi Wang / Interpretable Image Recognition by Constructing Transparent Embedding Space / ICCV-2021

1. Problem Definition

Convolution Neural Network(CNN)์˜ ๊ฒฐ๊ณผ ํ•ด์„์€ ํŒ๋‹จ์˜ ์ •ํ™•ํ•œ ๊ทผ๊ฑฐ๊ฐ€ ํ•„์ˆ˜์ ์ธ ์ž์œจ ์ฃผํ–‰ ์ž๋™์ฐจ์™€ ์•” ์ง„๋‹จ๊ณผ ๊ฐ™์€ ์˜๋ฃŒ ๋ถ„์•ผ์—์„œ ์ค‘์š”ํ•œ ๊ณผ์ œ์ž…๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋‹ค์–‘ํ•œ ํƒœ์Šคํฌ์—์„œ CNN์˜ ์„ฑ๋Šฅ์ด ๋น„์•ฝ์ ์œผ๋กœ ๋ฐœ์ „ํ•œ ๋ฐ์— ๋น„ํ•ด, ์—ฌ์ „ํžˆ ๋„คํŠธ์›Œํฌ์˜ output์„ ์‚ฌ๋žŒ์ด ์‰ฝ๊ฒŒ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋Š” ์˜๋ฏธ๋“ค๋กœ ํ•ด์„ํ•˜๋Š” ๋ฐ์—๋Š” ์–ด๋ ค์›€์ด ๋งŽ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ์ตœ๊ทผ์— CNN ๋‚ด๋ถ€์˜ feature representation์„ ์‹œ๊ฐํ™”ํ•˜๋Š” ๋งŽ์€ interpetableํ•œ ๋ฐฉ๋ฒ•๋“ค์ด ์ œ์•ˆ๋˜์—ˆ์ง€๋งŒ, ์‹œ๊ฐํ™”๋œ ๋„คํŠธ์›Œํฌ ๋‚ด๋ถ€ feature์™€ ์˜๋ฏธ ํ•ด์„ ๊ฐ„์˜ gap์€ ์—ฌ์ „ํžˆ ํฝ๋‹ˆ๋‹ค.

๋”ฐ๋ผ์„œ interpretable image classification(ํ•ด์„ ๊ฐ€๋Šฅํ•œ ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜)๋ฅผ ์œ„ํ•ด ์‚ฌ๋žŒ๋“ค์ด ์‰ฝ๊ฒŒ ๊ทธ ์˜๋ฏธ๋ฅผ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋Š” input image์˜ concepts๋ฅผ ์ถ”์ถœํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ์—ฐ๊ตฌ๊ฐ€ ์ด๋ฃจ์–ด์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๊ธฐ์กด ๊ด€๋ จ ์—ฐ๊ตฌ๋“ค์ด ์ œ์•ˆํ•œ concepts๋Š” ์„œ๋กœ ๋’ค์–ฝํ˜€์žˆ์–ด output class์— ๋Œ€ํ•œ ๊ฐ ๊ฐœ๋ณ„ concept์˜ ์˜ํ–ฅ์„ ํ•ด์„ํ•˜๊ธฐ ์–ด๋ ต์Šต๋‹ˆ๋‹ค.

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” ์ด๋ฅผ ๋ฌธ์ œ์ ์œผ๋กœ ์ง€์ ํ•˜๋ฉฐ output class์— ๋Œ€ํ•œ input image์˜ ํŠน์ง•์„ ํšจ๊ณผ์ ์œผ๋กœ ์„ค๋ช…ํ•  ์ˆ˜ ์žˆ์œผ๋ฉด์„œ, ๋™์‹œ์— ์„œ๋กœ ์–ฝํ˜€์žˆ์ง€์•Š๊ณ  orthogonalํ•œ(์ง๊ต๋ฅผ ์ด๋ฃจ๋Š”) concepts๋ฅผ ์ถ”์ถœํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•๋ก ์„ ์ œ์•ˆํ•ฉ๋‹ˆ๋‹ค.

2. Motivation

So what are interpretable concepts? From a cognitive perspective, interpretable concepts should satisfy the following three conditions.

(1) Informative: input data should be represented efficiently in the vector space spanned by the basis concepts, and the essential information of the input should be preserved in the new representation space. (2) Diversity: each data point (e.g., an image) should be associated with a small number of non-overlapping basis concepts, and data belonging to the same class should share similar basis concepts. (3) Discriminative: the basis concepts should be class-aware so that classes remain well separated in the basis-concept vector space mentioned in (1). That is, basis concepts associated with the same class should be embedded close together, while basis concepts of different classes should be embedded far apart.

๋ฐ์ดํ„ฐ์˜ concepts๋ฅผ ์ถ”์ถœํ•˜๊ธฐ ์œ„ํ•ด ์ด์ „ ์—ฐ๊ตฌ๋“ค์€ auto-encoding, prototype learning๊ณผ ๊ฐ™์ด deep neural network์˜ high-level feature๋ฅผ ์ด์šฉํ•˜๋Š” ๋ฐฉ์‹์„ ์ œ์•ˆํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๊ทธ ์ค‘ ํ•œ ๋ฐฉ๋ฒ•์€ U-shaped Beta Distribution์„ ์ด์šฉํ•˜์—ฌ basis concepts์˜ ๊ฐœ์ˆ˜๋ฅผ ์ œํ•œํ•จ์œผ๋กœ์จ ๊ฐ input data๋ฅผ ์†Œ์ˆ˜์˜ ์˜๋ฏธ ์žˆ๋Š” basis concept๋“ค๋กœ ๋‚˜ํƒ€๋‚ด๊ธฐ๋„ ํ•˜์˜€์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์—ฐ๊ตฌ๋“ค์€ Interpretable Concepts์˜ ์ฒซ๋ฒˆ์งธ ์กฐ๊ฑด์„ ๋งŒ์กฑํ•˜์˜€์ง€๋งŒ, ์•ž์„œ ์–ธ๊ธ‰ํ•˜์˜€๋“ฏ์ด basis concepts๊ฐ€ ์„œ๋กœ ์–ฝํ˜€์žˆ์–ด(entangled) input๊ณผ output์— ๋Œ€ํ•œ ๊ฐœ๋ณ„ concept์˜ ์˜ํ–ฅ์„ ํ•ด์„ํ•˜๊ธฐ ์–ด๋ ต๋‹ค๋Š” ๋ฌธ์ œ์ ์ด ์กด์žฌํ•ฉ๋‹ˆ๋‹ค.

๋”ฐ๋ผ์„œ, ์ด ๋…ผ๋ฌธ์—์„œ๋Š” ์œ„์˜ ์„ธ๊ฐ€์ง€ Interpretable Concepts ์กฐ๊ฑด์„ ๋ชจ๋‘ ์ถฉ์กฑ์‹œํ‚ค๋Š” basis concepts๋ฅผ ์„ค๊ณ„ํ•˜๋Š” ๋ฐ์— ์ฃผ๋ชฉํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ๋…ผ๋ฌธ์—์„œ ์„ค๊ณ„ํ•œ basis concepts๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ํŠน์ง•๋“ค์„ ๊ฐ€์ง‘๋‹ˆ๋‹ค.

(1) ๊ฐ class๋Š” ์ž์‹ ๋งŒ์˜ basis concepts๋ฅผ ๊ฐ€์ง€๋ฉฐ class๊ฐ€ ๋‹ค๋ฅธ ๊ฒฝ์šฐ basis concepts๋„ ์ตœ๋Œ€ํ•œ ๋‹ค๋ฆ…๋‹ˆ๋‹ค. (2) High-level feature๊ณผ basis concepts ์‚ฌ์ด๋ฅผ ํšจ๊ณผ์ ์œผ๋กœ ์—ฐ๊ฒฐํ•˜๋Š” mapping์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. (3) Input image ์ƒ์˜ basis concepts๋Š” ๊ฐ class์— ๋Œ€ํ•œ prediction score์„ ๊ณ„์‚ฐํ•˜๋Š” ๋ฐ์— ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค.

์œ„์˜ ์„ธ ๊ฐ€์ง€ ํŠน์ง•์„ ๋งŒ์กฑํ•˜๋Š” basis concepts ์„ค๊ณ„๋ฅผ ์œ„ํ•ด, ๋ณธ ๋…ผ๋ฌธ์€ ๊ธฐ์กด ์—ฐ๊ตฌ๋“ค๊ณผ ๋‹ค๋ฅด๊ฒŒ Grassmann manifold๋ฅผ ๋„์ž…ํ•˜์—ฌ basis concept vector space๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ์˜ ๊ทธ๋ฆผ์ฒ˜๋Ÿผ, ๊ฐ class๋งˆ๋‹ค์˜ basis concepts subset์ด Grassmann manifold ์ƒ์˜ point๋กœ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. figure1 Grassmann manifold๋Š” ์‰ฝ๊ฒŒ ๋งํ•˜๋ฉด linear subspaces์˜ set(์ง‘ํ•ฉ)์ด๋ผ๊ณ  ์ƒ๊ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ subspace๋ž€ vector space V ์˜ subset(๋ถ€๋ถ„์ง‘ํ•ฉ) W ๊ฐ€ V ๋กœ๋ถ€ํ„ฐ ๋ฌผ๋ ค๋ฐ›์€ ์—ฐ์‚ฐ๋“ค๋กœ ์ด๋ฃจ์–ด์ง„ ๋˜ ๋‹ค๋ฅธ ํ•˜๋‚˜์˜ vector space์ผ ๋•Œ W ๋ฅผ V ์˜ subspace๋ผ๊ณ  ๋งํ•ฉ๋‹ˆ๋‹ค.

In addition, through the projection metric, the basis concepts within each class are regularized to be orthogonal to one another, while the class-aware basis-concept subsets are regularized to lie far apart from each other. These two regularizations keep the basis concepts from becoming entangled, which is how the limitation of previous work is overcome.
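
As a reference point, the projection metric between two class subspaces can be written through their projection matrices; the notation below follows standard Grassmann-geometry usage and is not necessarily the paper's exact formulation:

$$ d_P(\mathcal{S}_i, \mathcal{S}_j) = \frac{1}{\sqrt{2}} \left\| (B^i)^{\top} B^i - (B^j)^{\top} B^j \right\|_F $$

where $B^c \in \mathbb{R}^{M \times D}$ stacks the $M$ orthonormal basis vectors of class $c$ as rows, $(B^c)^{\top} B^c$ is the projection matrix onto class $c$'s subspace, and $\|\cdot\|_F$ is the Frobenius norm.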

๋…ผ๋ฌธ์€ ์ด๋ ‡๊ฒŒ ์„ค๊ณ„๋œ transparent embedding space (concept vector space)๊ฐ€ ๋„์ž…๋œ ์ƒˆ๋กœ์šด interpetable network, TesNet์„ ์ œ์•ˆํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

3. Method

The overview of TesNet architecture

๋‹ค์Œ์€ TesNet์˜ ์ „์ฒด์ ์ธ architecture์˜ ๋ชจ์Šต์ž…๋‹ˆ๋‹ค. figure2 ๊ทธ๋ฆผ๊ณผ ๊ฐ™์ด TesNet์€ convolutional layers f, trasparent subspace layer s_bs\_{b}, ๊ทธ๋ฆฌ๊ณ  classifier h ์ด๋ ‡๊ฒŒ ์„ธ ๊ฐ€์ง€์˜ ํ•ต์‹ฌ ์š”์†Œ๋กœ ์ด๋ฃจ์–ด์ ธ ์žˆ์Šต๋‹ˆ๋‹ค.

๊ฐ ์š”์†Œ๋ฅผ ํ•˜๋‚˜์”ฉ ์‚ดํŽด๋ณด๋ฉด, ๋จผ์ € convloutional layers f ๋Š” 1X1 convolutional layer๋“ค์ด ์ถ”๊ฐ€๋œ ๊ธฐ๋ณธ CNN ๋„คํŠธ์›Œํฌ(ex.ResNet) ์ž…๋‹ˆ๋‹ค. sbs_{b}๋Š” feature map์„ transparent embedding space์— projection์‹œํ‚ค๋Š” subspace layer์ž…๋‹ˆ๋‹ค. ๊ฐ class๋งˆ๋‹ค subspace๊ฐ€ ์กด์žฌํ•˜์—ฌ ์ด class ๊ฐœ์ˆ˜๋งŒํผ์˜ subspace๊ฐ€ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. ๊ฐ class์˜ subspace๋Š” M๊ฐœ์˜ basis concepts๋กœ spanned ๋˜์–ด์žˆ์Šต๋‹ˆ๋‹ค. ์ด M๊ฐœ์˜ within-class concepts(ํด๋ž˜์Šค ๋‚ด๋ถ€ concepts)๋Š” ์„œ๋กœ orthogonalํ•˜๋‹ค๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์ด C๊ฐœ์˜ class๊ฐ€ ์žˆ์„ ๋•Œ, ๊ฐ class ๋งˆ๋‹ค M๊ฐœ์˜ basis concepts๊ฐ€ ์กด์žฌํ•œ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜๋ฉด ์ „์ฒด CM๊ฐœ์˜ basis concepts๊ฐ€ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค.

Embedding space learning

Let us now look at how the basis concepts are defined and how they form the embedding space.

๊ฐ basis concept์€ basis vector๋กœ ํ‘œํ˜„๋ฉ๋‹ˆ๋‹ค. ์ด basis vector๋Š” ๋‹ค์Œ ์„ธ ๊ฐ€์ง€ ์กฐ๊ฑด์„ ๋งŒ์กฑํ•ด์•ผํ•ฉ๋‹ˆ๋‹ค. (1) ๋‹ค๋ฅธ basis vector ์‚ฌ์ด์—๋Š” ์˜๋ฏธ๊ฐ€ ์ค‘๋ณต๋˜๋ฉด ์•ˆ๋ฉ๋‹ˆ๋‹ค. (2) embedding space์—์„œ๋„ ๊ฐ class๋Š” ๊ตฌ๋ถ„๋˜์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. (3) basis vector๋“ค์€ ๋น„์Šทํ•œ high-level patch(์‚ฌ๋žŒ๋“ค์ด ์ธ์‹ํ•  ์ˆ˜ ์žˆ๋Š” level์˜ image)๋“ค์„ ๊ตฐ์ง‘ํ™”ํ•˜๊ณ  ๋‹ค๋ฅธ ๊ฒƒ๋“ค๋ผ๋ฆฌ๋Š” ๋ถ„๋ฆฌํ•  ์ˆ˜ ์žˆ์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

To satisfy these three conditions, the paper formulates a joint optimization problem so that the weights of the convolutional layers, the basis vectors, and the classifier layer seen in the overall architecture are optimized jointly. The losses and the optimization procedure for each set of weights are as follows.

Orthonormality for Within-class Concepts

์กฐ๊ฑด (1)์„ ๋งŒ์กฑ์‹œํ‚ค๊ธฐ ์œ„ํ•œ Loss๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

figure3

basis vector์‚ฌ์ด์— ์˜๋ฏธ๊ฐ€ ์ค‘๋ณต๋˜์ง€ ์•Š๋Š”๋‹ค๋Š” ๊ฒƒ์€ ๊ฐ™์€ class์— ์†ํ•œ basis concepts์ด๋”๋ผ๋„ ๋ฐ˜๋“œ์‹œ ์„œ๋กœ ๋‹ค๋ฅธ ์ธก๋ฉด๋“ค์„ ๋‚˜ํƒ€๋‚ด๊ณ  ์žˆ์–ด์•ผํ•œ๋‹ค๋Š” ๋œป์ž…๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๊ธฐ ์œ„ํ•ด์„  ๊ฐ™์€ class์— ์†ํ•œ basis concept vectors๊ฐ€ ์„œ๋กœ orthogonalํ•ด์•ผ ํ•˜๋ฏ€๋กœ ๊ฐ class์˜ basis vectors ์‚ฌ์ด์˜ orthonormality๋ฅผ ๊ทœ์ œํ•˜๋Š” Loss๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

Loss ์‹์„ ์‚ดํŽด๋ณด๋ฉด ๊ฐ class์˜ basis vector matrix ํ–‰๋ ฌ๊ณฑ๊ณผ identity matrix ์‚ฌ์ด์˜ L2 norm์„ ๋ชจ๋‘ ๋”ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ์ฆ‰, ๊ฐ class์˜ basis vectors๊ฐ„์˜ correlation(์ƒ๊ด€ ๊ด€๊ณ„)๋ฅผ ์ตœ์†Œํ™”์‹œํ‚ค๊ธฐ ์œ„ํ•œ Loss์ž…๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ Loss๋ฅผ ํ†ตํ•ด ํ•™์Šต๋œ orthonormal basis vectors๊ฐ€ ๊ฐ class์˜ subspace๋ฅผ spanํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

Separation for Class-aware Subspaces

Next, the loss for satisfying condition (2) is as follows.

figure4

For the classes to be distinguishable in the embedding space, the subspaces of the different classes must lie far apart. That is, the distance between the class-aware subspaces on the Grassmann manifold is regularized to be as large as possible. Since each subspace corresponds to a unique projection on the Grassmann manifold, the distance between subspaces can be quantified through projection mappings.

Loss ์‹์—์„œ BcB^{c}๋Š” class c์˜ orthonormal basis vectors๋กœ ์ด๋ฃจ์–ด์ง„ matrix๋ฅผ ์˜๋ฏธํ•˜๊ณ , ์ด matrix์˜ ํ–‰๋ ฌ๊ณฑ์ด class c์™€ ์—ฐ๊ด€๋œ subspace์˜ projection mapping์ž…๋‹ˆ๋‹ค. ๊ฒฐ๊ตญ Loss๋Š” ์„œ๋กœ ๋‹ค๋ฅธ class์˜ projection mapping ์‚ฌ์ด์˜ L2 norm distance๋“ค์˜ ํ•ฉ์„ ์ตœ์†Œํ™”์‹œํ‚ค๊ธฐ ์œ„ํ•œ Loss๋กœ ์ดํ•ดํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

High-level Patches Grouping

Finally, the loss for satisfying condition (3).

figure5

Condition (3) ultimately means that high-level image patches should also project well onto the embedding subspaces. That is, when image patches are embedded into the subspace, they should be close to the basis vectors of the image's ground-truth class. To this end, the paper defines a compactness loss and a separation loss.

First, the compactness loss minimizes the cosine distance (negative cosine similarity) between the image patches and the basis vectors of the ground-truth class, which is equivalent to maximizing the cosine similarity between the image patches and the ground-truth class's basis vectors.

In contrast, the separation loss minimizes the cosine similarity between the image patches and the basis vectors of the non-ground-truth classes so that they move apart.

์ด ๋‘ Loss๋ฅผ hyper-parameter M ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋”ํ•จ์œผ๋กœ์จ Compactness-Separation Loss๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.

Identification

Finally, a cross-entropy loss is used as the identification loss to optimize the classifier layer.

figure6

์ตœ์ข…์ ์œผ๋กœ, ์ง€๊ธˆ๊นŒ์ง€ ์ •์˜๋œ loss๋“ค์„ jointly optimizeํ•˜๊ธฐ ์œ„ํ•ด Total Loss for Joint Optimization์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. figure7

Using hyper-parameters, the orthonormality loss, the subspace separation loss, and the compactness-separation loss are added to the classification loss (cross-entropy) with appropriate weights. With this total loss, the convolutional layers and the basis vectors are optimized simultaneously, and the concept embedding subspace is learned.
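
Putting the pieces together, a minimal composition of the joint objective might look like the following; it reuses the loss sketches above, and the `lam_*` coefficients are placeholders rather than the values reported in the paper:

```python
import torch.nn.functional as F

def total_loss(logits, labels, basis, sim, num_classes, num_concepts,
               lam_orth=1.0, lam_ss=1.0, lam_cs=1.0):
    """Weighted sum of the four losses; coefficient values are placeholders."""
    ce = F.cross_entropy(logits, labels)                  # identification (classification) loss
    return (ce
            + lam_orth * orthonormality_loss(basis)
            + lam_ss * subspace_separation_loss(basis)
            + lam_cs * compactness_separation_loss(sim, labels, num_classes, num_concepts))
```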

Concept-based classification

Once the embedding space has been learned, the parameters of the convolutional layers and the basis vectors are frozen, and the classifier at the final stage is trained. The classifier is learned by optimizing the concept-class weight matrix G, a sparse matrix in which G(c,j) is zero unless the j-th unit belongs to class c. A regularizer that keeps G sparse is added to the identification loss defined earlier to form a new loss, and the classifier is trained to minimize it.

figure10
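
A minimal sketch of this last-layer stage, assuming an L1-style penalty on the connections between a class and concepts that do not belong to it (mirroring how prototype-based models such as ProtoPNet train their last layer); all names are illustrative:

```python
import torch

def last_layer_sparsity(G: torch.Tensor, num_classes: int, num_concepts: int) -> torch.Tensor:
    """G: (C, C*M) concept-to-class weights of the classifier h.
    Penalize connections between a class and concepts belonging to other classes,
    so G stays close to the block-sparse pattern described above."""
    concept_class = torch.arange(num_classes).repeat_interleave(num_concepts)        # (C*M,)
    belongs = concept_class.unsqueeze(0) == torch.arange(num_classes).unsqueeze(1)   # (C, C*M)
    return G[~belongs].abs().sum()

# Last-layer stage (backbone and basis frozen): minimize
#   cross_entropy(sim @ G.t(), labels) + lam * last_layer_sparsity(G, C, M)
```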

4. Experiment

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” ๋‹ค์–‘ํ•œ CNN architecture์— ๋Œ€ํ•œ TesNet์˜ ๋„“์€ ์ ์šฉ์„ฑ์„ ์ž…์ฆํ•˜๊ธฐ ์œ„ํ•ด ๋‘ ๊ฐ€์ง€์˜ case study๋ฅผ ์ง„ํ–‰ํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๊ทธ ์ค‘ ์ฒซ๋ฒˆ์งธ case study์ธ bird species identification์— ๋Œ€ํ•ด์„œ ์ž์„ธํžˆ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

Experiment setup

  • Dataset: The Caltech-UCSD Birds-200-2011 dataset is used for the bird species classification experiment. It consists of 5,994 + 5,794 images of 200 bird species; the 5,994 images are used for training and the remaining 5,794 for testing. Since each bird class has only about 30 images, the paper augments the training set with random rotation, skew, shear, flip, etc. so that each class has 1,200 training images (a rough code sketch of such an augmentation pipeline follows after this list).

  • Baseline: The original, non-interpretable VGG16, VGG19, ResNet34, ResNet152, DenseNet121, and DenseNet161 networks serve as baselines, and each is compared against the version with the interpretable TesNet applied. Results with ProtoPNet, an interpretable network architecture similar to TesNet, are compared as well.

  • Evaluation Metric: Classification accuracy is used as the performance metric of the experiments.
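
As referenced in the Dataset item above, a torchvision-style augmentation pipeline along these lines could produce the extra training images; the specific angles and magnitudes below are placeholders, not the paper's settings:

```python
import torchvision.transforms as T

augment = T.Compose([
    T.RandomRotation(degrees=15),
    T.RandomAffine(degrees=0, shear=10),
    T.RandomPerspective(distortion_scale=0.2, p=0.5),  # rough stand-in for "skew"
    T.RandomHorizontalFlip(p=0.5),
    T.Resize((224, 224)),
    T.ToTensor(),
])
```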

Result

  • Accuracy comparison with different CNN architectures: As the table below shows, applying TesNet to a baseline network improves classification accuracy noticeably, by up to about 8%. Moreover, ablations over different combinations of TesNet's losses confirm that jointly optimizing all four losses yields the highest accuracy. figure8

  • The interpretable reasoning process: The following figure visualizes the reasoning process by which TesNet reaches a decision for a test image. figure9

European Goldfinch๋ผ๋Š” class์˜ test image๊ฐ€ ์ฃผ์–ด์กŒ๋‹ค๊ณ  ํ•  ๋•Œ, TesNet์€ ํ•™์Šต๋œ basis vectors๋ฅผ ํ†ตํ•ด feature map์„ re-representํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ class c์— ๋Œ€ํ•ด์„œ, ๋ชจ๋ธ์€ ํ•™์Šต๋œ basis vectors๋ฅผ image patch์— re-representํ•จ์œผ๋กœ์จ ๊ทธ image๊ฐ€ class c์— ์†ํ•  score๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.

For example, in the figure above, the model uses the basis vectors (concepts) of the European Goldfinch class as evidence for whether the test image (original image) belongs to this class. Looking at the activation map column, the 'black and yellow wing' concept represented by the first basis vector of the European Goldfinch class is the most strongly activated on the test image. Likewise, the 'head' concept of the second basis vector and the 'brown fur' concept of the third basis vector are strongly activated on the image.

Based on this, the model computes the similarity between each basis concept vector of the class and the activated region of the test image, weights each similarity by the importance of the basis concept, and sums them to obtain the final score for the European Goldfinch class. The class of the test image is predicted from these scores.
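
In code, this scoring step is just an importance-weighted sum of concept similarities; the sketch below reuses the `sim` activations and the weight matrix `G` introduced earlier and is illustrative rather than the authors' implementation:

```python
import torch

def class_scores(sim: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
    """sim: (B, C*M) concept activations for a test image (max patch similarity per concept).
    G:   (C, C*M) learned concept-to-class weights (concept importance).
    Each class score is the importance-weighted sum of its concept similarities."""
    return sim @ G.t()                                   # (B, C)

# prediction = class_scores(sim, G).argmax(dim=1)
```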

Through this reasoning process, the model achieves higher classification accuracy than the baseline CNN models.

5. Conclusion

  • Summary: TesNet is a highly applicable architecture that can be plugged into other CNN models to improve classification performance. It designs class-aware concepts, keeps concepts of the same class from becoming entangled, and thereby effectively improves prediction performance. In addition, TesNet can explain which concepts in an image serve as the evidence for the CNN's training and predictions. However, TesNet assumes that all basis concepts are flat, which differs greatly from the cognitive process people actually use when classifying objects. Moreover, since concepts in the real world are organized hierarchically, research on networks that can learn hierarchical basis concepts is needed.

  • Opinion: I think this work defines the notion of an input image's concepts well for interpreting CNN outputs. It is particularly impressive that relatively accessible mathematical notions such as basis vectors, subspaces, and manifolds are applied so effectively to produce meaningful results. It was an opportunity to reconsider how familiar mathematical concepts connect to neural networks, and since I am personally very interested in explainable AI, it was an enjoyable read. However, it is a pity that such interpretable networks are mostly focused on image data, and I felt the need for research on architectures that can be used more generally for audio, text, and other modalities.


Author Information

  • TaeMi, Kim

    • KAIST, Industrial and Systems Engineering

    • Computer Vision, XAI

6. Reference & Additional materials

  • GitHub Implementation: None

  • Reference

    • Chaofan Chen et al., "This Looks Like That: Deep Learning for Interpretable Image Recognition," NeurIPS, 2019.

    • https://en.wikipedia.org/wiki/Grassmannian
