TensorFlowで画像分類モデルを構築するサンプルコード

2024年6月17日

TensorFlowを使用して画像分類モデルを構築するためのサンプルコードを提供します。以下の例では、畳み込みニューラルネットワーク(Convolutional Neural Network, CNN)を使用して、画像から商品サイズの分類を行います。

python
import tensorflow as tf
from tensorflow.keras import layers, models

# モデルの定義
def create_model(input_shape, num_classes):
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax')
    ])
    return model

# モデルのコンパイル
def compile_model(model):
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

# データの前処理
def preprocess_data(X_train, y_train, X_test, y_test):
    X_train_normalized = X_train / 255.0
    X_test_normalized = X_test / 255.0
    return X_train_normalized, y_train, X_test_normalized, y_test

# モデルの学習
def train_model(model, X_train, y_train, X_test, y_test, epochs=10):
    history = model.fit(X_train, y_train, epochs=epochs, validation_data=(X_test, y_test))
    return history

# モデルの評価
def evaluate_model(model, X_test, y_test):
    loss, accuracy = model.evaluate(X_test, y_test)
    print("Test Loss:", loss)
    print("Test Accuracy:", accuracy)

# メイン関数
def main():
    # データの読み込みや前処理
    # X_train, y_train, X_test, y_test = load_and_preprocess_data()
    # preprocess_data 関数を使って、データの前処理を実装してください

    # 仮のデータ
    X_train, y_train = ...
    X_test, y_test = ...

    # モデルの作成
    input_shape = X_train.shape[1:]
    num_classes = len(set(y_train))
    model = create_model(input_shape, num_classes)

    # モデルのコンパイル
    compile_model(model)

    # モデルの学習
    history = train_model(model, X_train, y_train, X_test, y_test)

    # モデルの評価
    evaluate_model(model, X_test, y_test)

if __name__ == "__main__":
    main()

このサンプルコードでは、データの前処理、モデルの作成、コンパイル、学習、評価の各ステップが示されています。実際のデータを読み込んで適切に処理し、必要に応じてモデルのアーキテクチャやハイパーパラメータを調整してください。

未分類

Posted by ぼっち