LightGBMを使用して多クラス分類を行うサンプルコード
以下は、LightGBMを使用して多クラス分類を行うサンプルコードです。ここでは、Irisデータセットを使用して、3つの品種(クラス)を分類します。
python
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# データの読み込み
iris = load_iris()
X, y = iris.data, iris.target
# データの分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# LightGBM用のデータセットを作成
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test)
# パラメータの設定
params = {
'objective': 'multiclass',
'num_class': 3, # クラスの数
'metric': 'multi_error', # 評価指標
'verbosity': -1, # ログレベル
'seed': 42 # ランダムシード
}
# モデルの訓練
num_round = 100 # イテレーション数
model = lgb.train(params, train_data, num_round, valid_sets=[test_data])
# テストデータでの予測
y_pred = model.predict(X_test, num_iteration=model.best_iteration)
y_pred = [list(x).index(max(x)) for x in y_pred] # 確率が最大のクラスを選択
# 精度の評価
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
このコードでは、LightGBMのlgb.Dataset
を使用してデータセットを作成し、lgb.train
を使用してモデルを訓練します。その後、predict
メソッドを使用してテストデータのクラスを予測し、精度を評価します。
ディスカッション
コメント一覧
まだ、コメントがありません