diff --git a/.gitignore b/.gitignore
index 7c46d15..713d18a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,6 @@ main.ipynb
 main.py
 __pycache__
 .DS_Store
-data
\ No newline at end of file
+data
+model
+dataset
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 93fbfa9..78e3477 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -2,5 +2,7 @@
     "markdownlint.config": {
         "MD010": false,
         "MD033": false
-    }
+    },
+    "debugpy.debugJustMyCode": false,
+    // "debugpy.debugJustMyCode": true,
 }
\ No newline at end of file
diff --git a/README.md b/README.md
index a909590..62d7fc0 100644
--- a/README.md
+++ b/README.md
@@ -43,12 +43,6 @@
 | 姓名   | 职责                   | 简介               |
 | :----- | :--------------------- | :----------------- |
 | 田健翔 | 项目负责人             | 内容创作者         |
-| 于小敏 | 项目指导人             | DataWhale正式成员  |
-| 卢鑫斌 | 第1章(Datasets)贡献者  | 内容创作者         |
-| 胥佳程 | 第3章(PEFT)贡献者      | 内容创作者         |
-| 秦子涵 | 第5章(Diffusers)贡献者 | 内容创作者         |
-| 陈凯歌 | 第7章(Gradio)贡献者    | 内容创作者         |
-| 刘硕   | 第7章(Gradio)贡献者    | 内容创作者         |
 
 - PEFT
   - LoRa:@[鑫民](https://github.com/fancyboi999)
@@ -57,6 +51,8 @@
   - Prefix-Tuning:@[鑫民](https://github.com/fancyboi999)
   - prompt-Tuning:@[鑫民](https://github.com/fancyboi999)
   - P-Tuning:@[鑫民](https://github.com/fancyboi999)
+- 代码案例
+  - 图像分类: @[陈相斌](https://github.com/chenxinxi)
 
 项目保姆(o^^o):高增玉
 
diff --git a/docker-compose/Dockerfile b/docker-compose/Dockerfile
index 3e00b8a..e6fc05a 100644
--- a/docker-compose/Dockerfile
+++ b/docker-compose/Dockerfile
@@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y \
     unzip \
     inetutils-ping \
     tmux \
+    watch \
     && apt-get clean
 
 COPY --from=miniconda-stage /opt/conda /opt/conda
diff --git a/docs/chapter6/code_index.md b/docs/chapter6/code_index.md
index dd3669f..c5777ef 100644
--- a/docs/chapter6/code_index.md
+++ b/docs/chapter6/code_index.md
@@ -14,3 +14,4 @@ title: 索引
 | 文本翻译 | [中英文本翻译](./translation/translation.md) |
 | 扩散去噪 | [ddpm-unet简单去噪](./ddpm-unet-mnist/ddpm-unet-mnist.md) |
 | 文本分类 | [基金年报问答意图识别](./financial_report/financial_report.md) |
+| 图像分类 | [菜肴图像分类](./image_classification/image_classification.md) |
diff --git a/docs/chapter6/image_classification/image_classification.md b/docs/chapter6/image_classification/image_classification.md
new file mode 100644
index 0000000..d421b0d
--- /dev/null
+++ b/docs/chapter6/image_classification/image_classification.md
@@ -0,0 +1,221 @@
+---
+comments: true
+title: 菜肴图像分类
+---
+
+![image_classification](./imgs/image_classification.png)
+
+## 前言
+
+## 代码
+
+```python
+model_checkpoint = "google/vit-base-patch16-224-in21k"
+```
+
+### 导入函数库
+
+```python
+import evaluate
+import numpy as np
+import torch
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model
+from torchvision.transforms import (
+    CenterCrop,
+    Compose,
+    Normalize,
+    RandomHorizontalFlip,
+    RandomResizedCrop,
+    Resize,
+    ToTensor,
+)
+from transformers import (
+    AutoImageProcessor,
+    AutoModelForImageClassification,
+    Trainer,
+    TrainingArguments,
+)
+```
+
+### 读取数据集
+
+```python
+dataset = load_dataset("food101", split="train[:5000]")
+
+labels = dataset.features["label"].names
+
+label2id, id2label = dict(), dict()
+for i, label in enumerate(labels):
+    label2id[label] = i
+    id2label[i] = label
+```
+
+数据集`food101`的主页:<https://huggingface.co/datasets/food101>。
+
+### 加载模型
+
+```python
+model = AutoModelForImageClassification.from_pretrained(
+    model_checkpoint,
+    label2id=label2id,
+    id2label=id2label,
+    # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
+    ignore_mismatched_sizes=True,
+)
+
+config = LoraConfig(
+    r=16,
+    lora_alpha=16,
+    target_modules=["query", "value"],
+    lora_dropout=0.1,
+    bias="none",
+    modules_to_save=["classifier"],
+)
+lora_model = get_peft_model(model, config)
+```
+
+使用参数高效微调后打印可训练参数如下:
+
+```python title="lora_model.print_trainable_parameters()"
+trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.7713
+```
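+
+上面的数字也可以自行核对。下面是一段示意代码(仅作演示,沿用上文得到的 `lora_model`):
+
+```python
+# 手动统计可训练参数与总参数,验证 print_trainable_parameters() 的输出
+trainable = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
+total = sum(p.numel() for p in lora_model.parameters())
+print(f"trainable: {trainable:,} / {total:,} = {100 * trainable / total:.4f}%")
+
+# 名字中带有 lora_ 的参数,即注入到 query/value 上的低秩适配器
+for name, param in lora_model.named_parameters():
+    if "lora_" in name and param.requires_grad:
+        print(name, tuple(param.shape))
+```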
+
+### 加载预处理器
+
+```python
+image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
+```
+
+### 定义数据转换
+
+```python
+normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
+train_transforms = Compose(
+    [
+        RandomResizedCrop(image_processor.size["height"]),
+        RandomHorizontalFlip(),
+        ToTensor(),
+        normalize,
+    ]
+)
+
+val_transforms = Compose(
+    [
+        Resize(image_processor.size["height"]),
+        CenterCrop(image_processor.size["height"]),
+        ToTensor(),
+        normalize,
+    ]
+)
+
+def preprocess_train(example_batch):
+    """Apply train_transforms across a batch."""
+    example_batch["pixel_values"] = [
+        train_transforms(image.convert("RGB")) for image in example_batch["image"]
+    ]
+    return example_batch
+
+
+def preprocess_val(example_batch):
+    """Apply val_transforms across a batch."""
+    example_batch["pixel_values"] = [
+        val_transforms(image.convert("RGB")) for image in example_batch["image"]
+    ]
+    return example_batch
+```
+
+### 数据预处理
+
+```python
+splits = dataset.train_test_split(test_size=0.1)
+train_ds = splits["train"]
+val_ds = splits["test"]
+
+train_ds.set_transform(preprocess_train)
+val_ds.set_transform(preprocess_val)
+```
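+
+`set_transform`是惰性生效的:转换只在取样本时才会执行。可以取一条训练样本确认输出(示意代码;对该检查点,`image_processor.size["height"]` 通常为 224):
+
+```python
+sample = train_ds[0]
+print(sample["pixel_values"].shape)  # 预期形如 torch.Size([3, 224, 224])
+print(id2label[sample["label"]])     # 该样本对应的菜肴类别名
+```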
+
+### 定义评价指标
+
+```python
+metric = evaluate.load("accuracy")
+
+def compute_metrics(eval_pred):
+    """Computes accuracy on a batch of predictions"""
+    predictions = np.argmax(eval_pred.predictions, axis=1)
+    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
+```
+
+### 定义动态数据整理
+
+```python
+def collate_fn(examples):
+    pixel_values = torch.stack([example["pixel_values"] for example in examples])
+    labels = torch.tensor([example["label"] for example in examples])
+    return {"pixel_values": pixel_values, "labels": labels}
+```
+
+### 定义训练参数
+
+```python
+args = TrainingArguments(
+    "vit-finetuned-lora-food101",
+    remove_unused_columns=False,
+    eval_strategy="epoch",
+    save_strategy="epoch",
+    save_total_limit=2,
+    learning_rate=5e-3,
+    per_device_train_batch_size=128,
+    gradient_accumulation_steps=4,
+    per_device_eval_batch_size=128,
+    fp16=True,
+    num_train_epochs=5,
+    logging_steps=10,
+    load_best_model_at_end=True,
+    metric_for_best_model="accuracy",
+    label_names=["labels"],
+    use_cpu=False,
+)
+```
+
+### 定义训练器
+
+```python
+trainer = Trainer(
+    lora_model,
+    args,
+    train_dataset=train_ds,
+    eval_dataset=val_ds,
+    tokenizer=image_processor,
+    compute_metrics=compute_metrics,
+    data_collator=collate_fn,
+)
+```
+
+### 训练与评估
+
+```python
+trainer.train()
+trainer.evaluate(val_ds)
+```
+
+下面是训练时的过程结果。
+
+| 轮次 | 评估损失 | 评估准确率 |
+| ---- | -------- | ---------- |
+| 0.8  | 4.0372   | 0.80       |
+| 1.6  | 3.5086   | 0.876      |
+| 2.4  | 3.0289   | 0.896      |
+| 4.0  | 2.4545   | 0.908      |
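+
+训练结束后,可以取一张验证集图像做简单推理,检查模型效果。下面是一段示意代码(假设沿用上文变量,且 `preprocess_val` 保留了原始的 `image` 字段):
+
+```python
+# 取一张验证集图像,经预处理器处理后送入模型
+image = val_ds[0]["image"].convert("RGB")
+
+device = next(lora_model.parameters()).device
+inputs = image_processor(image, return_tensors="pt").to(device)
+
+with torch.no_grad():
+    logits = lora_model(**inputs).logits
+
+pred_id = logits.argmax(-1).item()
+print("预测类别:", id2label[pred_id])
+```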
+
+## 参考资料
+
+待补充
diff --git a/docs/chapter6/image_classification/image_classification.py b/docs/chapter6/image_classification/image_classification.py
new file mode 100644
index 0000000..99da9b0
--- /dev/null
+++ b/docs/chapter6/image_classification/image_classification.py
@@ -0,0 +1,148 @@
+import evaluate
+import numpy as np
+import torch
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model
+from torchvision.transforms import (
+    CenterCrop,
+    Compose,
+    Normalize,
+    RandomHorizontalFlip,
+    RandomResizedCrop,
+    Resize,
+    ToTensor,
+)
+from transformers import (
+    AutoImageProcessor,
+    AutoModelForImageClassification,
+    Trainer,
+    TrainingArguments,
+)
+
+
+model_checkpoint = "google/vit-base-patch16-224-in21k"
+
+dataset = load_dataset("food101", split="train[:5000]")
+
+labels = dataset.features["label"].names
+
+label2id, id2label = dict(), dict()
+for i, label in enumerate(labels):
+    label2id[label] = i
+    id2label[i] = label
+
+
+image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
+
+normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
+train_transforms = Compose(
+    [
+        RandomResizedCrop(image_processor.size["height"]),
+        RandomHorizontalFlip(),
+        ToTensor(),
+        normalize,
+    ]
+)
+
+val_transforms = Compose(
+    [
+        Resize(image_processor.size["height"]),
+        CenterCrop(image_processor.size["height"]),
+        ToTensor(),
+        normalize,
+    ]
+)
+
+
+def preprocess_train(example_batch):
+    """Apply train_transforms across a batch."""
+    example_batch["pixel_values"] = [
+        train_transforms(image.convert("RGB")) for image in example_batch["image"]
+    ]
+    return example_batch
+
+
+def preprocess_val(example_batch):
+    """Apply val_transforms across a batch."""
+    example_batch["pixel_values"] = [
+        val_transforms(image.convert("RGB")) for image in example_batch["image"]
+    ]
+    return example_batch
+
+
+# split up training into training + validation
+splits = dataset.train_test_split(test_size=0.1)
+train_ds = splits["train"]
+val_ds = splits["test"]
+
+train_ds.set_transform(preprocess_train)
+val_ds.set_transform(preprocess_val)
+
+model = AutoModelForImageClassification.from_pretrained(
+    model_checkpoint,
+    label2id=label2id,
+    id2label=id2label,
+    ignore_mismatched_sizes=True,
+)
+
+config = LoraConfig(
+    r=16,
+    lora_alpha=16,
+    target_modules=["query", "value"],
+    lora_dropout=0.1,
+    bias="none",
+    modules_to_save=["classifier"],
+)
+lora_model = get_peft_model(model, config)
+
+
+args = TrainingArguments(
+    "vit-finetuned-lora-food101",
+    remove_unused_columns=False,
+    eval_strategy="epoch",
+    save_strategy="epoch",
+    save_total_limit=2,
+    learning_rate=5e-3,
+    per_device_train_batch_size=128,
+    gradient_accumulation_steps=4,
+    per_device_eval_batch_size=128,
+    fp16=True,
+    num_train_epochs=5,
+    logging_steps=10,
+    load_best_model_at_end=True,
+    metric_for_best_model="accuracy",
+    label_names=["labels"],
+    use_cpu=False,
+)
+
+
+metric = evaluate.load("accuracy")
+
+
+# the compute_metrics function takes a Named Tuple as input:
+# predictions, which are the logits of the model as Numpy arrays,
+# and label_ids, which are the ground-truth labels as Numpy arrays.
+def compute_metrics(eval_pred):
+    """Computes accuracy on a batch of predictions"""
+    predictions = np.argmax(eval_pred.predictions, axis=1)
+    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
+
+
+def collate_fn(examples):
+    pixel_values = torch.stack([example["pixel_values"] for example in examples])
+    labels = torch.tensor([example["label"] for example in examples])
+    return {"pixel_values": pixel_values, "labels": labels}
+
+
+trainer = Trainer(
+    lora_model,
+    args,
+    train_dataset=train_ds,
+    eval_dataset=val_ds,
+    tokenizer=image_processor,
+    compute_metrics=compute_metrics,
+    data_collator=collate_fn,
+)
+
+trainer.train()
+trainer.evaluate(val_ds)
diff --git a/docs/chapter6/image_classification/imgs/image_classification.png b/docs/chapter6/image_classification/imgs/image_classification.png
new file mode 100644
index 0000000..81737b4
Binary files /dev/null and b/docs/chapter6/image_classification/imgs/image_classification.png differ
diff --git a/docs/index.md b/docs/index.md
index fb1d833..41dbda6 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -43,12 +43,6 @@ title: Unlock-HuggingFace
 | 姓名   | 职责                   | 简介              |
 | :----- | :--------------------- | :---------------- |
 | 田健翔 | 项目负责人             | 内容创作者        |
-| 于小敏 | 项目指导人             | DataWhale正式成员 |
-| 卢鑫斌 | 第1章(Datasets)贡献者  | 内容创作者        |
-| 胥佳程 | 第3章(PEFT)贡献者      | 内容创作者        |
-| 秦子涵 | 第5章(Diffusers)贡献者 | 内容创作者        |
-| 陈凯歌 | 第7章(Gradio)贡献者    | 内容创作者        |
-| 刘硕   | 第7章(Gradio)贡献者    | 内容创作者        |
 
 - PEFT
   - LoRa:@[鑫民](https://github.com/fancyboi999)
@@ -57,6 +51,8 @@ title: Unlock-HuggingFace
   - Prefix-Tuning:@[鑫民](https://github.com/fancyboi999)
   - prompt-Tuning:@[鑫民](https://github.com/fancyboi999)
   - P-Tuning:@[鑫民](https://github.com/fancyboi999)
+- 代码案例
+  - 图像分类: @[陈相斌](https://github.com/chenxinxi)
 
 项目保姆(o^^o):高增玉
 
diff --git a/mkdocs.yml b/mkdocs.yml
index 05e3e98..73d32c6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -71,13 +71,14 @@ nav:
       - DDPM数学原理: 'chapter5/ddpm/ddpm_math.md'
   - 代码案例:
       - 索引: 'chapter6/code_index.md'
-      - 多标签分类任务: "chapter6/mlcoftc/multi-label-classification-of-toxic-comments.md"
+      - 多标签分类: "chapter6/mlcoftc/multi-label-classification-of-toxic-comments.md"
       - 抽取式阅读理解: "chapter6/cmrc/cmrc.md"
       - 文本摘要: "chapter6/text-summary/text-summary.md"
       - 目标检测: "chapter6/container-detr/container-detr.md"
       - 文本翻译: "chapter6/translation/translation.md"
       - 一种简单的去噪方法: 'chapter6/ddpm-unet-mnist/ddpm-unet-mnist.md'
       - 文本分类: "chapter6/financial_report/financial_report.md"
+      - 图像分类: "chapter6/image_classification/image_classification.md"
   - Gradio工具:
       - 索引: 'chapter7/gradio_index.md'
       - Gradio: 'chapter7/gradio/gradio_tour.md'