LightGBMを使用して多クラス分類を行うサンプルコード

以下は、LightGBMを使用して多クラス分類を行うサンプルコードです。ここでは、Irisデータセットを使用して、3つの品種(クラス)を分類します。

python
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# データの読み込み
iris = load_iris()
X, y = iris.data, iris.target

# データの分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# LightGBM用のデータセットを作成
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test)

# パラメータの設定
params = {
    'objective': 'multiclass',
    'num_class': 3,  # クラスの数
    'metric': 'multi_error',  # 評価指標
    'verbosity': -1,  # ログレベル
    'seed': 42  # ランダムシード
}

# モデルの訓練
num_round = 100  # イテレーション数
model = lgb.train(params, train_data, num_round, valid_sets=[test_data])

# テストデータでの予測
y_pred = model.predict(X_test, num_iteration=model.best_iteration)
y_pred = [list(x).index(max(x)) for x in y_pred]  # 確率が最大のクラスを選択

# 精度の評価
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

このコードでは、LightGBMのlgb.Datasetを使用してデータセットを作成し、lgb.trainを使用してモデルを訓練します。その後、predictメソッドを使用してテストデータのクラスを予測し、精度を評価します。

未分類

Posted by ぼっち