Commit

Refactor training documentation and code examples for data loading and augmentation
Abraham KOLOBOE committed Oct 18, 2024
1 parent 4950d5d commit c06bd54
Showing 1 changed file with 56 additions and 33 deletions.
89 changes: 56 additions & 33 deletions Train-Models.md
@@ -23,13 +23,23 @@ We use the **`Training` folder** of the Fruits 360 dataset, containing im…
- **Why this split?** Holding out part of the data lets us check that the model generalizes to images it has not seen.

#### **Commands to Prepare the Dataset**
```bash
# Previous version (removed by this commit)
# Clone the repo and move into its directory
git clone https://github.com/fruits-360/fruits-360-100x100.git
cd fruits-360-100x100

# Split the dataset
python scripts/split_data.py --input_dir Training --split_ratio 0.25
```

```python
import logging

import keras


def load_data(data_dir, validation_split=0.25, seed=1337, image_size=(100, 100), batch_size=128, label_mode='int'):
    """Load and split the data into training and validation sets."""
    logging.info(f"Loading data from {data_dir}")
    # subset="both" returns the (training, validation) pair in a single call
    train_ds, val_ds = keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=validation_split,
        subset="both",
        seed=seed,
        image_size=image_size,
        batch_size=batch_size,
        label_mode=label_mode,
    )
    return train_ds, val_ds


# Load data
train_ds, val_ds = load_data("data/Training")
```
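
To make the role of `validation_split` and `seed` concrete, here is a minimal standard-library sketch of a seeded hold-out split. It is an illustration under assumptions, not Keras' internal implementation: `split_paths` and the file names are hypothetical, and taking the validation set from the tail of the shuffled list is a choice made for this example.

```python
import random


def split_paths(paths, validation_split=0.25, seed=1337):
    """Shuffle paths deterministically, then hold out the tail for validation."""
    shuffled = list(paths)
    # A dedicated Random instance keeps the shuffle reproducible for a given seed
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * validation_split)
    split_at = len(shuffled) - n_val
    return shuffled[:split_at], shuffled[split_at:]


# Hypothetical file names, for illustration only
paths = [f"img_{i:03d}.jpg" for i in range(100)]
train_paths, val_paths = split_paths(paths)
print(len(train_paths), len(val_paths))  # 75 25
```

Because the shuffle is seeded, calling `split_paths` again with the same arguments returns the same two lists, which keeps the training and validation sets disjoint across runs.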

---
@@ -39,24 +49,21 @@ To **enrich the dataset** and avoid overfitting, we applied…
**Transformations applied:**
- Random **rotation** between -15° and +15°
- **Horizontal flip**
- **Brightness and contrast adjustment**


#### **Code Excerpt: Data Augmentation**
```python
import albumentations as A
# Previous version (removed by this commit)
from albumentations.core.composition import OneOf
from albumentations.pytorch import ToTensorV2

# Augmentation definition
transform = A.Compose([
    A.RandomRotate90(),
    A.Flip(),
    A.OneOf([
        A.RandomBrightnessContrast(),
        A.HueSaturationValue(),
    ], p=0.5),
    ToTensorV2(),
])

# Version added by this commit: a plain list of transforms, each always applied (p=1.0)
transforms = [
    A.RandomRotate90(p=1.0),  # rotates by a random multiple of 90°
    A.Transpose(p=1.0),
    A.VerticalFlip(p=1.0),
    A.HorizontalFlip(p=1.0),
    A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=1.0),
]
```
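
As a rough illustration of what these transformations do to pixel values, the sketch below reimplements a few of them on a tiny grayscale image stored as nested lists. It assumes that each entry of the `transforms` list is applied on its own to produce one augmented copy per transform; the helper functions are hypothetical, and the brightness/contrast formula is a simplified stand-in, not Albumentations' exact computation.

```python
def horizontal_flip(image):
    """Mirror each row left to right."""
    return [row[::-1] for row in image]


def vertical_flip(image):
    """Reverse the order of the rows."""
    return image[::-1]


def transpose(image):
    """Swap rows and columns."""
    return [list(col) for col in zip(*image)]


def adjust_brightness_contrast(image, brightness=0.0, contrast=0.0):
    """Scale pixels by (1 + contrast), shift by 255 * brightness, clip to [0, 255]."""
    return [
        [max(0, min(255, round(px * (1 + contrast) + 255 * brightness))) for px in row]
        for row in image
    ]


# A 2x3 grayscale image, values in [0, 255]
image = [[0, 100, 200], [50, 150, 250]]

# One augmented copy per transform (assumed usage of a flat transform list)
transforms = [
    horizontal_flip,
    vertical_flip,
    transpose,
    lambda img: adjust_brightness_contrast(img, brightness=0.2, contrast=0.5),
]
augmented = [t(image) for t in transforms]
for result in augmented:
    print(result)
```

The geometric transforms only move pixels around, while the brightness/contrast step changes their values and saturates at 255, which is why large limits such as 0.5 can wash out bright regions.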

**Why use Albumentations?**
@@ -107,10 +114,25 @@ We tested 4 architectures:

**Code: Initializing the Pre-trained Models**
```python
import logging

import keras
from keras import layers
from tensorflow.keras.applications import ResNet50, VGG16, EfficientNetB0

# Previous version (removed by this commit): fine-tune a full ResNet50
model = ResNet50(weights='imagenet', input_shape=(100, 100, 3), include_top=False)
model.trainable = True  # Fine-tuning


# Version added by this commit: frozen EfficientNetB0 backbone with a new classification head
def create_efficientnet_model(num_classes):
    """Create an EfficientNet model."""
    logging.info("Creating EfficientNet model")
    base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(100, 100, 3))
    base_model.trainable = False  # Freeze the pre-trained backbone

    model = keras.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(name='global_avg_pooling'),
        layers.Dense(num_classes*3, activation='relu', name='dense_1'),
        layers.BatchNormalization(name='batch_norm_1'),
        layers.Dropout(0.2, name='dropout_1'),
        layers.Dense(num_classes*2, activation='relu', name='dense_2'),
        layers.BatchNormalization(name='batch_norm_2'),
        layers.Dropout(0.2, name='dropout_2'),
        layers.Dense(num_classes, activation='softmax', name='output_layer')
    ])
    return model


# num_classes is the number of fruit classes, defined elsewhere in the training code
create_efficientnet_model(num_classes)
```
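
Since the EfficientNetB0 backbone is frozen, only the new head is trained. The sketch below estimates how many trainable parameters that head has, as plain arithmetic. It assumes the backbone yields 1280 features after global average pooling (the width of EfficientNetB0's final convolution); `head_trainable_params` is a hypothetical helper, and `num_classes=10` is an illustrative value, not the dataset's real class count.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def head_trainable_params(num_classes, backbone_features=1280):
    """Trainable parameters of the Dense/BatchNorm head placed on the frozen backbone."""
    d1, d2 = num_classes * 3, num_classes * 2
    return (
        dense_params(backbone_features, d1)  # dense_1
        + 2 * d1                             # batch_norm_1: gamma and beta
        + dense_params(d1, d2)               # dense_2
        + 2 * d2                             # batch_norm_2: gamma and beta
        + dense_params(d2, num_classes)      # output_layer
    )


print(head_trainable_params(10))  # 39360
```

Dropout and pooling layers add no parameters, and the moving mean and variance of each BatchNormalization layer are not trainable, so they are left out of the count.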

---
@@ -120,14 +142,15 @@ We used a **Keras callback** to save only the **best…

#### **Code Excerpt: Callback**
```python
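import logging

import keras
from tensorflow.keras.callbacks import ModelCheckpoint

# Previous version (removed by this commit): a single checkpoint callback
checkpoint = ModelCheckpoint(
    'best_model.keras',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max'
)

# Version added by this commit; model_name and patience come from the surrounding training code
logging.info(f"Training {model_name} model")
# The monitored key must match the compiled metric name:
# 'val_acc' for metrics=['acc'], 'val_accuracy' for metrics=['accuracy']
callbacks = [
    # Keep only the best model on disk
    keras.callbacks.ModelCheckpoint(
        f"models/best_model_{model_name}.keras", save_best_only=True, monitor="val_acc", mode="max"
    ),
    # Stop after `patience` epochs without improvement and restore the best weights
    keras.callbacks.EarlyStopping(monitor='val_acc', patience=patience, mode="max", restore_best_weights=True),
    # Write per-epoch metrics to a CSV file
    keras.callbacks.CSVLogger(f'artefacts/training_log_{model_name}.csv')
]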
```
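
The interaction between `patience` and `restore_best_weights` can be sketched without Keras. The simulation below follows the documented behaviour of early stopping in `max` mode with no minimum delta: training stops once the monitored value has failed to improve for `patience` consecutive epochs, and the weights from the best epoch are the ones kept. The function name and the accuracy history are hypothetical.

```python
def simulate_early_stopping(val_acc_history, patience):
    """Return (stop_epoch, best_epoch), both 0-indexed, for a 'max' mode monitor."""
    best = float("-inf")
    best_epoch = None
    wait = 0
    for epoch, acc in enumerate(val_acc_history):
        if acc > best:
            # Strict improvement: remember this epoch and reset the counter
            best, best_epoch, wait = acc, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Training ran to completion without triggering early stopping
    return len(val_acc_history) - 1, best_epoch


# Hypothetical validation accuracies, one per epoch
history = [0.80, 0.85, 0.84, 0.86, 0.85, 0.85, 0.84, 0.90]
print(simulate_early_stopping(history, patience=3))  # (6, 3)
```

In this run the last value (0.90) is never reached: with `patience=3`, training stops at epoch 6 and the model from epoch 3 is restored, which shows the trade-off between a short patience and late improvements.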

**Why use a callback?**
@@ -170,16 +193,16 @@ print(f"F1: {f1}, AUC: {auc}, Precision: {precision}, Recall: {recall}")
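To reproduce the training:
1. **Clone the repo:**
```bash
# Previous version (removed by this commit)
git clone <url_du_repo>
cd <nom_du_repo>
# Version added by this commit, with the missing `clone` subcommand restored
# and a relative target directory so that `cd data` works
git clone https://github.com/abrahamkoloboe27/Machine-Learning-En-Production-LinkedIn.git data
cd data
```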
2. **Install the dependencies:**
```bash
pip install -r requirements.txt
```
3. **Run the training:**
```bash
# Previous version (removed by this commit)
python train_model.py
# Version added by this commit
python main.py
```

---