https://arxiv.org/pdf/2312.17624.pdf
목차
사전지식
ABSTRACT
INTRODUCTION
METHOD
EXPERIMENT
CONCLUSION
사전 지식
XAI : explainable artificial intelligence
사용자가 머신러닝 알고리즘으로 생성된 결과와 출력을 이해하며 신뢰할 수 있도록 하는 일련의 프로세스
=> ai 모델과 이에 대한 영향 및 편향을 설명하는데 사용됨
ai로 의사결정을 할 때, 모델 정확성, 공정성, 투명성 및 결과 특성화
조직이 ai 발전에 책임있는 접근 방식을 채택하는데 도움을 줌
AI가 발전하면 인간은 알고리즘이 어떻게 결과에 도달했는지 알 수 있어야 관리 및 통제가 가능하다
통상적으로 ML은 이해하기 힘든 블랙박스이며 bias, 학습데이터 != 생산데이터의 한계점 존재
[XAI와 AI]
AI는 ML의 결과에 대해 ai 시스템 설계자는 어떻게 ML algorithm이 해당 결과를 도출하는지 알지 못함
XAI는 ML process 중 내린 각 결정을 추적하여 설명하기 위한 특정 기술 및 방법 구현
- 예측 정확도 : 출력과 학습데이터의 비교 ex) LIME
- 추적 가능성 : 의사결정 방법 제한, ML 규칙 및 기능을 좁게 설정 ex) DeepLIFT
- 의사결정 이해 : 인적 요소
해석 가능성 : 관리자가 ai의 출력을 예상할 수 있는 정도
설명 가능성 : ai가 어떻게 결과에 도달했는지
설명 가능한 ai는 결과가 도출된 후 결과 check라면 책임있는 ai는 결과 도출 전 알고리즘 계획 단계에서 check하며 이 둘은 더 나은 ai를 만들기위해 함께 작동되기도 함
XAI의 이점
- 신뢰성과 투명성을 갖춘 ai
- 지속적 모니터링을 통한 ai 결과 도출시간 단축
- 모델 거버넌스에 대한 위험 및 비용 완화
고려사항
- 공정성 및 편견 제거
- 모델 드리프트 완화
- 모델 위험 관리
- 라이프 사이클 자동화
- 멀티클라우드 적용
In-hospital Mortality Prediction (병원 내 사망률 예측)
환자가 병원에 입원한 동안 사망할 확률 예측
⇒ 환자의 상태 평가 및 적절한 치료 계획에 유리
Multimodal ICU data
ICU:중환자실에서 수집된 다양한 형태의 데이터
ex) 환자의 생체 신호(심박수, 혈압), 의료 기록, 영상 자료(x-ray, MRI), 실험실 결과 등이 포함됨
Layer-Wise Relevance Propagation
인공 신경망의 결정을 설명하기 위해 입력 특성의 중요도를 역전파하는 방법
Transformer
NLP에서 주로 사용되는 attention 메커니즘에 기반한 모델
MIMIC-III
ICU 환자들의 대규모 의료 데이터베이스
ABSTRACT
다양한 데이터를 input으로 받아 학습하여 판단한 결과를 설명할 수 있는 모델 X-MMP 제안
multimodal transforemr framework + XAI
multimodal transformer에 대한 LRP의 확장인 Layer-wise propagation to transformer 기법을 도입한 XAI 솔루션
MIMIC-III and MIMIC-III Waveform Database Matched Subset를 기반으로 multimodal dataset을 만들어 광범위한 학습을 진행해 본 연구의 프레임워크 성능을 증명하였고 본 프레임워크는 다른 임상 작업에도 쉽게 적용이 가능하다
INTRODUCTION
- 중환자실 환자들은 감염에 취약하며 병원 내 타 부서들 중 사망률 1위
=> 사망률이 높은 환자 예측은 중요한 문제
- 병원 내 사망률 예측 관련 딥러닝 알고리즘은 현재에도 많으나 여전히 '블랙박스'로 여겨지며 예측 논리의 타당성 입증 X
=> 이로 인해 환자나 의사들은 딥러닝 알고리즘의 결과를 신뢰
딥러닝의 설명 가능성 향상을 위한 모델 행동에 대한 통찰을 제공하는 XAI 제안
NLP와 computer vision에서 XAI 연구가 진행되고 있으나 시계열 데이터와 멀티모달 데이터에 의존하는 특성으로 인해 아직 초기 단계임
∴ 병원 내 사망률 예측 인공지능 기술의 현재 한계를 나타내며 딥러닝 기반 접근법의 설명 가능성의 중요성 강조
introduction 1) multimodal의 중요성
다양한 형태의 데이터가 대량으로 발생 => 다른 형태는 데이터 처리에 난도가 올라감
기존 clinical multimodal data에 대한 XAI는 거의 없음
=> healthcare data의 다양성에 대응되는 연구, 기술이 필요하며 multimodal data를 통한 임상 결정 지원 ai 모델이 필요
introduction 2) X-MMP 모델
병원 내 사망률 예측 멀티모달 프레임워크
- multimodal : (clinical note + discrete event sequence + vital sign)
- explainable ai
- 이질적 입력에서 feature 추출하는 3개의 transformer 기반 encoder
- Gradient * input
- Layer-wise Propagation to Transformer
위 기술들을 통해 모든 하위 인코더는 explainable module
=> ∀input feature의 attribution을 알 수 있음
즉, 다양한 모달리티의 input을 통합하여 처리하며 예측이 어떻게 이루어지는지 설명 가능
introduction 3) dataset & experiment
Medical Information Mart for Intensive Care III (MIMIC-III) 및 MIMIC-III Waveform Database Matched Subset에서 multimodal dataset 구축
관련 기존 연구
Deep learning in acute care data : 급성 질환 데이터
임상 데이터가 기하급수적으로 증가하여 다양한 임상 응용 분야에서 딥러닝 알고리즘 적용
- clinical note : 환자 건강 상태 정보 제공, 증상 변화 직관적 묘사 => 텍스트 데이터
- discrete event sequence : 시간-스탬프에 생리 측정값 기록, 모든 병원에서 손쉽게 수집 가능 => 연속적, 범주형 특징
- High-density vital signs : 매분 매초 샘플링됨, 병상 모니터 => 심전도(ECG), 동맥혈압(ABP), 호흡률(RR), 손끝 광용적맥파(PPG) 등 => 양이 많아 효율적 계산 모델 필요
분류 CNN, transformer, BERT 등 적용
* 많은 healthcare 데이터들은 시계열 형식, 테이블 기반임
=> 위 형식에 맞는 XAI 방법 필요
Data Preprocessing
(1) discrete event sequences
환자가 불안정할 때와 시간별 라운드 동안의 vital signs, lab results, and intervention events를 기록한 데이터
이산적 시계열 데이터
- 환자의 ICU 입원 후 첫 24시간으로 관찰 창 제한
- 1시간 간격으로 resampling하며 시간 내 여러 값이 있는 변수는 가장 최근 값을 유지함
- 누락된 값은 forward filling, 이전 값이 기록되지 않은 경우는 ‘normal’
- ⇒ 누락을 특징으로 모델링, time step이 유효한 값 포함 / 대체된 값 포함하는지 binary mask input matrix 제공
- categorical features → one-hot vectors, continuous features → zero-mean normalization
(2) clinical notes
의료진이 환자의 ICU 입원 후 환자에 대해 기록한 텍스트 데이터
오류 태그, 특수 문자 및 비식별화된 정보- 환자의 ICU 입원 후 첫 24시간 동안 작성된 텍스트 유지
- max 512개 단어를 ICU 입원 후 24시간의 끝부터 거꾸로 수집
- 전자건강기록 내 모든 노트 유형 활용 (for 가능한 많은 샘플의 보존)
- 환자의 모든 노트 연결 → 모든 정보를 하나의 문서로 통합해 분석 효율 향상
- 예측 목표(die, dying)와 관련된 단어 제거 → 데이터 유출 방지
(3) vital signs
병원 내 모니터로 매분, 매초 수집되는 환자 데이터 ex) PULSE..
고밀도 고차원 데이터
- 모니터링 장비의 차이, 설정의 변화로 환자마다 측정되는 생명 징후의 데이터 차이 존재
- 환자의 ICU 입원 후 첫 24시간으로 관찰 창 제한
- MIMIC-III Waveform Database Matched Subset에서 21가지 유형 생명 징후 추출
- 하나의 vital sign 중 절반 이상의 값이 누락된 경우, 샘플 폐기 → 데이터의 신뢰성
- 다중 채널, 고밀도로 계산 효율성 향상을 위해 vital 측정 빈도는 3분마다 한번씩 진행 → 24h동안 단일 채널에서 샘플링된 데이터는 480개 수치값
- 누락 값은 해당 변수의 가장 최근 값으로 대체, 이전 값 기록 x → ‘normal’
- zero-mean normalization 적용
- 병상 모니터는 병원의 디지털 건강 시스템에 직접 연결 X → waveform 기록의 일부:34%만 전자건강기록과 연결됨
- ⇒ vital sign의 sample 크기는 clinical note, discrete event sequence에 비해 작음
METHOD
A. Multimodal modeling : X-MMP structure
Discrete event sequence
- 시계열 event로 순서가 중요
- 각 event의 위치 encode를 위해 위치 임베딩 사인함수 적용
- short- and long- distance dependencies를 포착하기 위해 transformer block 사용
multi-head attention은 input feature의 다른 channel에 집중할 수 있음
- 트랜스포머 블록 다음에는 pooler layer를 사용해 시퀀스의 첫 번째 시간 단계를 표현으로 추출
Clinical note
- 텍스트 데이터
- input token을 인코딩하기 위해 clinicalBERT의 embedding layer 사용
의료 분야를 위해 사전 훈련된 언어 모델
- embedding layer 이후 token matrix는 transformer로 전달됨
- downstream calculation을 위한 표현으로 첫번째 토큰 추출을 위해 pooler layer 사용
vital sign
- multi-channel and highdensity
=> event sequence encoding의 sole position과는 다른 처리 필요
: embedding을 위해 input embedding & position embedding 모두 추가
- embedding layer 이후 transformer로 전달됨
- input embedding
- 저차원 벡터 → 고차원 벡터 mapping
- sequence modeling에 유리
- position embedding
- discrete event sequence 처리와 동일
- embedding layer 이후 long-term dependencies을 위해 transformer 적용
- pooler layer를 사용해 첫 번째 시간 단계 출력
▶ Multimodal fusion
서로 다른 Modality에서 나온 출력 표현들을 concatenated
latent representation들은 feedforward neural network로 전달
병원 내 사망 확률 예측을 위해 softmax layer 적용
즉, transformer 기반으로 다양한 modality들을 융합하여 사망률 예측을 진행하게됨
B. XAI method
LRP에 기반을 두고 있음
- 입력 변수에 할당된 attridution(기여도)이 모델의 출력에 합산되어야함
- multimodal representation = output에서 predicted score 형성에 각자 기여함
- conservation axiom E : Discrete event sequence, C : clinical note, V : vital sign
But, 트랜스포머는 skip connection과 attention에 의존하므로 위 식만으로는 적절 X
설명하는 AI와 LRP
설명하는 AI
- 피처맵 시각화 방식
- 히트맵 출력 방식
기존 배경
피쳐맵 시각화 방식(LIME, filter visualization)은 모델이 입력 이미지에 어떻게 반응하는지 각 은닉층을 조사
⇒ BUT, 깊은 은닉 계층일수록 해석력 감소, 여전히 다양한 해석의 여지 존재
딥러닝 모델은 가중치, bias, 활성화 함수 등으로 이루어진 신경망의 결합
⇒ 피처 연결, 활성화 과정 non-linear하며 다양한 커널로 매핑되어 추론하기 힘듦
⇒ 기존의 XAI 방법인 필터 시각화나 민감도 기법은 딥러닝 모델에서 feed-forward로 진행
LRP (Layer-wise Relevance Propagation)
딥러닝 모델의 결과를 역추적하여 입력 이미지에 히트맵 출력
블랙박스가 어느 곳 데이터를 주목하는지 확인 가능
⇒ 피처맵 기반 방법보다 블랙박스 오인 가능성이 낮음
블랙박스가 분류한 이미지 결과를 역순으로 탐지, 분해
-> 분해된 요소들이 원본 이미지까지 도달했을 때, 원본 이미지에 상대적 기여도를 표시하여 딥러닝 해석 가능
- 타당성 전파 : Relevance Propagation
: 특정 결과가 나오게 된 원인 분해, 비중 분배
분해 과정을 마친 은닉층이 결과값 출력에 기여하는지 타당성 계산
⇒ 모든 은닉층 내 활성화 함수의 기여도를 계산 ⇒ 이미지 x에서 픽셀 별 기여도를 표시할 수 있음
- 분해 : decomposition
타당성 전파를 통해 얻어낸 원인을 가중치로 환원, 해부
입력된 피처 ‘하나’가 결과 해석에 얼마나 영향을 미치는지 해체
ex) 이미지 x에서 픽셀 k가 결과 도출에 도움이 되는지 해가 되는지
=> 보존 특성 : Conservation property
각 타당성 계층 간 총합은 동일
Gradient*input + LRP
- β-규칙 → GI로 적용하여 attribution backpropagation 도출
- 상위 계층에서 받은 attribution은 하위 계층으로 완전히 재분배
- 다양한 모달리티의 입력 특성에서 특성 예측 결과 attribution을 얻을 수 있음
=> 이를 통해 모델이 어떻게 특정 결과에 도달하였는지에 대한 명확한 설명을 제공할 수 있음
ex) clinical note의 특정 단어, event sequence의 값 기여도 등
C. Better LRP rules for transformer
앞서 말했듯, 기존 LRP로는 transformer의 self-attention & layer normalization ⇒ attribution 보존 X, 해석 성능 저하 발생
→ Imporved gradient propagation rules
: self-attention과 layer normalizatoin을 지역적으로 선형 계층으로 변환
- 입력 시퀀스 x → 출력 시퀀스 y로 mapping
- 위 개선된 gradient rule을 사용해 국소적 선형 레이어의 가중치로 판단하여 더 보수적으로 기여도를 계산하여 해석 성능 개선
EXPERIMENT
1. 완성된 모델에 대한 성능 비교 -> 각 modality에 대해 단일 모델(GRU, LSTM, attRNN, RCNN)이 사용됨
그 결과, transformer이 기존 모델들에 비해 높은 성능을 띄는 것 확인
2. multimodal의 성능 확인
bi-modal => tri-modal 비교
단일 모델보다 멀티 모달을 사용한 모델의 성능이 점차 좋아지는 것 확인
=> 사용한 3개 데이터가 상호 보완적으로 예측을 도움을 알 수 있음
3. LRPTrans의 성능 확인
LRPTrans가 모든 데이터셋 중 높은 성능을 가지고 있음을 알 수 있음
4. 예측 설명성
분석 결과 GCS total 점수가 낮을수록, arrest, SpO2 등의 단어가 나타날수록 높은 사망률과 높은 연관성을 가지는 것 확인
이 결과는 통상적으로 우리가 알고 있는 의학적 지식과 동일함
=> 모델이 사망률과 연관성이 높은 정보를 적절히 반영하는 중임 확인
CONCLUSION
본 논문은 병원 내 ICU 환자의 data를 통한 사망률 예측을 위한 XAI 모델인 X-MMP 제안
- transformer 기반의 multimodal로 향상된 LRPTrans를 통해 예측 결과에 대한 근거 입증, 시각화 가능
- 해당 framework를 다른 임상 data에 대한 작업으로 쉽게 이전할 수 있어 의료 연구에서 중요한 의미를 가짐