본문으로 건너뛰기

023 다중 분류 문제 다루기

키워드: 다중 분류, multiclass

개요

다중 분류(Multiclass Classification)는 3개 이상의 클래스 중 하나를 예측하는 문제입니다. 이 글에서는 PyCaret으로 다중 분류 문제를 다루는 방법을 배웁니다.

실습 환경

  • Python 버전: 3.11 권장
  • 필요 패키지: pycaret[full]>=3.0

다중 분류 vs 이진 분류

구분이진 분류다중 분류
클래스 수2개3개 이상
예시스팸/정상붓꽃 품종 (3종)
출력0 또는 10, 1, 2, ...

다중 분류 기본 예제

from pycaret.classification import *
from pycaret.datasets import get_data

# 023 붓꽃 데이터 (3 클래스)
data = get_data('iris')

# 023 클래스 확인
print(data['species'].value_counts())
# 023 setosa 50
# 023 versicolor 50
# 023 virginica 50

# 023 환경 설정
clf = setup(data, target='species', session_id=42)

# 023 모델 비교
best = compare_models()

다중 분류 알고리즘

대부분의 알고리즘이 다중 분류를 기본 지원:

# 023 자연스럽게 다중 분류 지원
rf = create_model('rf') # Random Forest
dt = create_model('dt') # Decision Tree
knn = create_model('knn') # K-Nearest Neighbors
nb = create_model('nb') # Naive Bayes

# 023 One-vs-Rest/One-vs-One 방식
lr = create_model('lr') # Logistic Regression
svm = create_model('svm') # Support Vector Machine

다중 분류 평가 지표

Macro Average vs Micro Average

from pycaret.classification import *
from pycaret.datasets import get_data
from sklearn.metrics import classification_report

data = get_data('iris')
clf = setup(data, target='species', session_id=42, verbose=False)

model = create_model('rf', verbose=False)
predictions = predict_model(model)

# 023 상세 분류 리포트
report = classification_report(
predictions['species'],
predictions['prediction_label']
)
print(report)

평균 방식:

  • Macro: 각 클래스별 지표의 단순 평균 (클래스 동등 취급)
  • Micro: 전체 TP, FP, FN으로 계산 (샘플 동등 취급)
  • Weighted: 클래스별 샘플 수로 가중 평균
from sklearn.metrics import f1_score

y_true = predictions['species']
y_pred = predictions['prediction_label']

macro_f1 = f1_score(y_true, y_pred, average='macro')
micro_f1 = f1_score(y_true, y_pred, average='micro')
weighted_f1 = f1_score(y_true, y_pred, average='weighted')

print(f"Macro F1: {macro_f1:.4f}")
print(f"Micro F1: {micro_f1:.4f}")
print(f"Weighted F1: {weighted_f1:.4f}")

다중 분류 시각화

from pycaret.classification import *
from pycaret.datasets import get_data

data = get_data('iris')
clf = setup(data, target='species', session_id=42, verbose=False)

model = create_model('rf', verbose=False)

# 023 혼동 행렬
plot_model(model, plot='confusion_matrix')

# 023 다중 클래스 ROC (One-vs-Rest)
plot_model(model, plot='auc')

# 023 클래스 리포트
plot_model(model, plot='class_report')

불균형 다중 분류

from pycaret.classification import *
from pycaret.datasets import get_data
import pandas as pd
import numpy as np

# 023 불균형 데이터 생성 예시
data = get_data('iris')

# 023 인위적으로 불균형 만들기
np.random.seed(42)
data_imbalanced = pd.concat([
data[data['species'] == 'setosa'], # 50개
data[data['species'] == 'versicolor'].head(30), # 30개
data[data['species'] == 'virginica'].head(10) # 10개
])

print(data_imbalanced['species'].value_counts())

# 023 불균형 처리
clf = setup(
data_imbalanced,
target='species',
fix_imbalance=True,
session_id=42,
verbose=False
)

best = compare_models(sort='F1')

One-vs-Rest (OvR) 전략

이진 분류기를 N개의 클래스에 대해 N번 학습:

from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

# 023 PyCaret은 내부적으로 OvR 적용
# 023 직접 사용하려면:

from pycaret.classification import *
from pycaret.datasets import get_data

data = get_data('iris')
clf = setup(data, target='species', session_id=42, verbose=False)

# 023 Logistic Regression은 OvR로 다중 분류
lr = create_model('lr')

One-vs-One (OvO) 전략

모든 클래스 쌍에 대해 이진 분류기 학습:

  • 클래스 수가 N이면 N(N-1)/2개의 분류기 필요
  • SVM에서 주로 사용
# 023 SVM은 기본적으로 OvO 사용
svm = create_model('svm')

다중 분류 확률 예측

from pycaret.classification import *
from pycaret.datasets import get_data

data = get_data('iris')
clf = setup(data, target='species', session_id=42, verbose=False)

model = create_model('rf', verbose=False)

# 023 예측 (확률 포함)
predictions = predict_model(model, raw_score=True)

# 023 각 클래스별 확률 확인
print(predictions[['species', 'prediction_label', 'prediction_score']].head())

# 023 raw_score=True 시 각 클래스별 확률 컬럼 생성됨

실전 예제: 손글씨 숫자 분류

from pycaret.classification import *
from sklearn.datasets import load_digits
import pandas as pd

# 023 데이터 로드
digits = load_digits()
data = pd.DataFrame(digits.data)
data['target'] = digits.target

# 10개 클래스 (0-9)
print(f"클래스 수: {data['target'].nunique()}")
print(data['target'].value_counts().sort_index())

# 023 환경 설정
clf = setup(data, target='target', session_id=42, verbose=False)

# 023 모델 비교
print("=== 모델 비교 ===")
best = compare_models(n_select=3)

# 023 최고 모델 평가
plot_model(best[0], plot='confusion_matrix')

정리

  • 다중 분류는 3개 이상의 클래스를 예측
  • 대부분의 알고리즘이 자연스럽게 지원 (OvR, OvO)
  • 평가 지표는 Macro/Micro/Weighted 평균 사용
  • 불균형 시 fix_imbalance=True 또는 F1/AUC 기준 선택
  • raw_score=True로 각 클래스별 확률 확인 가능

다음 글 예고

다음 글에서는 불균형 데이터 처리 - fix_imbalance 옵션을 다룹니다.


PyCaret 머신러닝 마스터 시리즈 #023